Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

aroundthecode / SQLAlchemy   python

Repository URL to install this package:

Version: 1.2.10 

/ util / langhelpers.py

# util/langhelpers.py
# Copyright (C) 2005-2018 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Routines to help with the creation, loading and introspection of
modules, classes, hierarchies, attributes, functions, and methods.

"""
import inspect
import itertools
import operator
import re
import sys
import types
import warnings
from functools import update_wrapper
from .. import exc
import hashlib
from . import compat
from . import _collections


def md5_hex(x):
    """Return the hexadecimal MD5 digest of ``x``.

    When running under Python 3 (``compat.py3k``), ``x`` is treated as text
    and encoded as UTF-8 before hashing; otherwise it is hashed as given.
    """
    payload = x.encode('utf-8') if compat.py3k else x
    return hashlib.md5(payload).hexdigest()


class safe_reraise(object):
    """Reraise an exception after invoking some
    handler code.

    Stores the existing exception info before
    invoking so that it is maintained across a potential
    coroutine context switch.

    Outcome on leaving the ``with`` block:

    * the block completed without raising: the exception captured on entry
      is re-raised, unless ``warn_only`` is true, in which case it is
      discarded (note: despite the name, no warning is emitted on this
      path in this code);
    * the block itself raised: that new exception propagates instead.
      When not running under Python 3 (``compat.py3k`` false) and an
      exception had been captured on entry, a warning describing the
      earlier exception is emitted first via ``warn()``.

    e.g.::

        try:
            sess.commit()
        except:
            with safe_reraise():
                sess.rollback()

    """

    # Only these two attributes are ever set on an instance.
    __slots__ = ('warn_only', '_exc_info')

    def __init__(self, warn_only=False):
        """:param warn_only: if true, a clean exit from the ``with`` block
        does not re-raise the exception captured on entry."""
        self.warn_only = warn_only

    def __enter__(self):
        """Capture ``sys.exc_info()`` for the exception currently being
        handled (the caller is expected to be inside an ``except``)."""
        self._exc_info = sys.exc_info()

    def __exit__(self, type_, value, traceback):
        """Re-raise either the captured exception or the new one; see the
        class docstring for the exact rules."""
        # see #2703 for notes
        if type_ is None:
            # Handler code ran cleanly: restore the captured exception.
            exc_type, exc_value, exc_tb = self._exc_info
            self._exc_info = None   # remove potential circular references
            if not self.warn_only:
                compat.reraise(exc_type, exc_value, exc_tb)
        else:
            # Handler code raised; the new exception wins.
            if not compat.py3k and self._exc_info and self._exc_info[1]:
                # emulate Py3K's behavior of telling us when an exception
                # occurs in an exception handler.
                warn(
                    "An exception has occurred during handling of a "
                    "previous exception.  The previous exception "
                    "is:\n %s %s\n" % (self._exc_info[0], self._exc_info[1]))
            self._exc_info = None   # remove potential circular references
            compat.reraise(type_, value, traceback)


def decode_slice(slc):
    """Return ``(start, stop, step)`` of slice ``slc`` as a tuple.

    Any component that implements ``__index__`` is replaced by the integer
    that ``__index__()`` returns; other components (including ``None``)
    are passed through unchanged.
    """
    def _as_index(part):
        if hasattr(part, '__index__'):
            return part.__index__()
        return part

    return tuple(_as_index(part) for part in (slc.start, slc.stop, slc.step))


def _unique_symbols(used, *bases):
    used = set(used)
    for base in bases:
        pool = itertools.chain((base,),
                               compat.itertools_imap(lambda i: base + str(i),
                                                     range(1000)))
        for sym in pool:
            if sym not in used:
                used.add(sym)
                yield sym
                break
        else:
            raise NameError("exhausted namespace for symbol base %s" % base)


def map_bits(fn, n):
    """Yield ``fn(bit)`` for every set bit of ``n``, lowest bit first.

    Each ``bit`` is a power of two (e.g. ``0b1010`` yields ``fn(2)`` then
    ``fn(8)``).  ``n`` should be a non-negative int: a negative value has
    infinitely many set bits in two's-complement form, so the generator
    would never finish.
    """
    remaining = n
    while remaining:
        # x & -x isolates the lowest set bit (same as x & (~x + 1)).
        lowest = remaining & -remaining
        yield fn(lowest)
        remaining ^= lowest


def decorator(target):
    """Build a signature-preserving decorator around ``target``.

    ``decorator(target)`` returns ``decorate``.  ``decorate(fn)`` generates,
    via ``exec``, a new function with ``fn``'s name and argument list whose
    body is ``return target(fn, ...)``.  The trailing call arguments are
    ``fn``'s parameters as rendered by ``format_argspec_plus`` (its
    ``apply_kw`` entry).  The generated function is passed through
    ``update_wrapper`` so that it carries ``fn``'s metadata, and
    ``decorate`` itself is wrapped with ``target``'s metadata.

    ``decorate`` raises ``Exception`` if ``fn`` is not a plain Python
    function.
    """

    def decorate(fn):
        if not inspect.isfunction(fn):
            raise Exception("not a decoratable function")

        argspec = compat.inspect_getfullargspec(fn)
        # Names the generated source must not reuse for the injected
        # target / fn references: fn's positional args, argspec[1:3]
        # (presumably the *args / **kw names) and fn's own name.
        taken = tuple(argspec[0]) + argspec[1:3] + (fn.__name__,)
        target_sym, fn_sym = _unique_symbols(taken, 'target', 'fn')

        template_vars = dict(target=target_sym, fn=fn_sym)
        template_vars.update(format_argspec_plus(argspec, grouped=False))
        template_vars['name'] = fn.__name__
        source = """\
def %(name)s(%(args)s):
    return %(target)s(%(fn)s, %(apply_kw)s)
""" % template_vars

        wrapper = _exec_code_in_env(
            source, {target_sym: target, fn_sym: fn}, fn.__name__)
        # Give the generated function fn's actual default objects.
        wrapper.__defaults__ = getattr(fn, 'im_func', fn).__defaults__
        wrapper.__wrapped__ = fn
        return update_wrapper(wrapper, fn)

    return update_wrapper(decorate, target)


def _exec_code_in_env(code, env, fn_name):
    exec(code, env)
    return env[fn_name]


def public_factory(target, location):
    """Produce a wrapping function for the given cls or classmethod.

    Rationale here is so that the __init__ method of the
    class can serve as documentation for the function.

    :param target: either a class, in which case the generated function
     constructs an instance of it and mirrors ``target.__init__``'s
     signature, or a callable such as a classmethod, which the generated
     function calls directly with the mirrored signature.  In both cases the
     first positional argument (``self`` / ``cls``) is dropped from the
     signature.
    :param location: dotted path of the public function.  Its last token
     becomes the generated function's name.  The remainder is appended
     directly to ``"sqlalchemy"`` to form ``__module__``, so it presumably
     starts with a dot (e.g. ``".sql.expression.select"``).
    :return: the generated function, which takes over the original
     docstring.  As a side effect, the docstring of ``target`` (or of its
     ``__init__``) is replaced by a pointer to ``location``.
    """
    if isinstance(target, type):
        fn = target.__init__
        callable_ = target
        doc = "Construct a new :class:`.%s` object. \n\n"\
            "This constructor is mirrored as a public API function; "\
            "see :func:`~%s` "\
            "for a full usage and argument description." % (
                target.__name__, location, )
    else:
        fn = callable_ = target
        doc = "This function is mirrored; see :func:`~%s` "\
            "for a description of arguments." % location

    location_name = location.split(".")[-1]
    spec = compat.inspect_getfullargspec(fn)
    # The generated function takes no self / cls argument.
    del spec[0][0]
    metadata = format_argspec_plus(spec, grouped=False)
    metadata['name'] = location_name
    code = """\
def %(name)s(%(args)s):
    return cls(%(apply_kw)s)
""" % metadata
    env = {'cls': callable_, 'symbol': symbol}
    exec(code, env)
    decorated = env[location_name]
    # Copy the docstring before it is overwritten below.
    decorated.__doc__ = fn.__doc__
    decorated.__module__ = "sqlalchemy" + location.rsplit(".", 1)[0]
    # Method objects do not allow setting __doc__, so the docstring is
    # written on the underlying __func__.  This applies to unbound methods
    # on Python 2 and bound classmethods everywhere.  Checking for
    # __func__ alone already covers every Python 2 method; the former
    # unconditional py2k branch only added an AttributeError for targets
    # without __func__.
    if hasattr(fn, '__func__'):
        fn.__func__.__doc__ = doc
    else:
        fn.__doc__ = doc
    return decorated


class PluginLoader(object):
    """Look up named plugin implementations belonging to an entry-point group.

    :meth:`load` resolves a name in this order:

    1. loaders already stored in ``impls``, via :meth:`register` or an
       earlier lookup;
    2. the optional ``auto_fn`` hook;
    3. ``pkg_resources`` entry points for ``group``, when that module is
       importable.

    The loader callable found is stored in ``impls``; it is invoked again
    on every :meth:`load` call, so the loader is cached, not its result.
    """

    def __init__(self, group, auto_fn=None):
        # group: entry-point group name handed to pkg_resources.
        # auto_fn: optional hook; ``auto_fn(name)`` returns a loader
        # callable, or a false value when it has nothing for ``name``.
        self.group = group
        self.impls = {}
        self.auto_fn = auto_fn

    def clear(self):
        """Forget every stored loader."""
        self.impls.clear()

    def load(self, name):
        """Return the result of calling the loader found for ``name``.

        :raises exc.NoSuchModuleError: if no source supplies a loader.
        """
        if name in self.impls:
            return self.impls[name]()

        if self.auto_fn:
            found = self.auto_fn(name)
            if found:
                self.impls[name] = found
                return found()

        try:
            import pkg_resources
        except ImportError:
            pkg_resources = None

        if pkg_resources is not None:
            for entry_point in pkg_resources.iter_entry_points(
                    self.group, name):
                # Store the loader before calling it; only the first
                # matching entry point is used.
                self.impls[name] = entry_point.load
                return entry_point.load()

        raise exc.NoSuchModuleError(
            "Can't load plugin: %s:%s" %
            (self.group, name))

    def register(self, name, modulepath, objname):
        """Register a lazy loader returning ``objname`` from ``modulepath``.

        Nothing is imported until :meth:`load` invokes the stored loader.
        """
        def load():
            module = compat.import_(modulepath)
            # compat.import_ presumably returns the top-level package (as
            # __import__ does), hence the walk down the remaining dotted
            # components.
            for attr in modulepath.split(".")[1:]:
                module = getattr(module, attr)
            return getattr(module, objname)

        self.impls[name] = load


def get_cls_kwargs(cls, _set=None):
    r"""Collect the keyword argument names accepted when constructing ``cls``.

    The ``__init__`` defined directly on ``cls`` (only when it is a plain
    Python function) contributes its argument names, as reported by
    ``inspect_func_args``.  The search continues into each of
    ``cls.__bases__`` in two cases: ``cls`` defines no such ``__init__``, or
    that ``__init__`` accepts \**kwargs.  The \**kwargs case is treated as a
    sign that unrecognized keywords are forwarded to the bases.
    No anonymous tuple arguments please !

    ``_set`` is the accumulator shared by the recursion; callers leave it
    as ``None``.  A nested call returns ``None`` when the base's
    ``__init__`` does not accept \**kwargs, which stops the scan of any
    remaining sibling bases.  The top-level call always returns the set,
    with ``'self'`` removed.
    """
    is_toplevel = _set is None
    if is_toplevel:
        _set = set()

    init = cls.__dict__.get('__init__', False)
    has_init = (init
                and isinstance(init, types.FunctionType)
                and isinstance(init.__code__, types.CodeType))

    descend = True
    if has_init:
        arg_names, descend = inspect_func_args(init)
        _set.update(arg_names)
        if not (descend or is_toplevel):
            return None

    if descend:
        for base in cls.__bases__:
            if get_cls_kwargs(base, _set) is None:
                break

    _set.discard('self')
    return _set


try:
    # CO_VARKEYWORDS is the code-object flag bit set when a function
    # declares a **kwargs parameter.
    from inspect import CO_VARKEYWORDS

    def inspect_func_args(fn):
        """Return ``(arg_names, accepts_var_keywords)`` for function ``fn``.

        ``arg_names`` is a list of the first ``co_argcount`` entries of
        ``co_varnames``, i.e. the positional parameters.  Keyword-only
        parameters and the ``*args`` / ``**kwargs`` names are not included.
        """
        code = fn.__code__
        positional = list(code.co_varnames[:code.co_argcount])
        return positional, bool(code.co_flags & CO_VARKEYWORDS)

except ImportError:
    # Fallback for an inspect module lacking CO_VARKEYWORDS.
    def inspect_func_args(fn):
        """Return ``(arg_names, accepts_var_keywords)`` using
        ``compat.inspect_getargspec``."""
        arg_names, _varargs, varkw, _defaults = \
            compat.inspect_getargspec(fn)
        return arg_names, bool(varkw)


def get_func_kwargs(func):
    """Return the argument names that ``func`` accepts.

    The names come from the first field of ``compat.inspect_getargspec``.
    Going through getargspec means methods can be passed as well as plain
    functions.
    """
    argspec = compat.inspect_getargspec(func)
    arg_names = argspec[0]
    return arg_names


def get_callable_argspec(fn, no_self=False, _is_init=False):
    """Return the argument signature for any callable.

    All pure-Python callables are accepted, including
    functions, methods, classes, objects with __call__;
    builtins and other edge cases like functools.partial() objects
    raise a TypeError.

    :param fn: the callable to inspect.
    :param no_self: when true, drop the leading argument in three cases:
     bound methods, a class's ``__init__``, and an object's ``__call__``.
    :param _is_init: internal flag set when recursing into a class's
     ``__init__``.  It ensures the first argument is stripped even when
     ``__init__`` arrives here as a plain function rather than a bound
     method.
    :return: an argspec exposing ``args``, ``varargs``, ``keywords`` and
     ``defaults``, either from ``compat.inspect_getargspec`` or rebuilt as
     ``compat.ArgSpec``.
    :raises TypeError: if ``fn`` cannot be introspected.
    """
    if inspect.isbuiltin(fn):
        raise TypeError("Can't inspect builtin: %s" % fn)
    elif inspect.isfunction(fn):
        if _is_init and no_self:
            spec = compat.inspect_getargspec(fn)
            return compat.ArgSpec(spec.args[1:], spec.varargs,
                                  spec.keywords, spec.defaults)
        else:
            return compat.inspect_getargspec(fn)
    elif inspect.ismethod(fn):
        # Compare __self__ against None rather than testing its truthiness.
        # Otherwise a method bound to a falsy instance (e.g. one whose
        # __len__ returns 0) would keep its 'self' argument.  Unbound
        # methods (Python 2) have __self__ set to None.
        if no_self and (_is_init or fn.__self__ is not None):
            spec = compat.inspect_getargspec(fn.__func__)
            return compat.ArgSpec(spec.args[1:], spec.varargs,
                                  spec.keywords, spec.defaults)
        else:
            return compat.inspect_getargspec(fn.__func__)
    elif inspect.isclass(fn):
        return get_callable_argspec(
            fn.__init__, no_self=no_self, _is_init=True)
    elif hasattr(fn, '__func__'):
        # e.g. staticmethod / classmethod wrapper objects
        return compat.inspect_getargspec(fn.__func__)
    elif hasattr(fn, '__call__'):
        if inspect.ismethod(fn.__call__):
            return get_callable_argspec(fn.__call__, no_self=no_self)
        else:
            raise TypeError("Can't inspect callable: %s" % fn)
    else:
        raise TypeError("Can't inspect callable: %s" % fn)


def format_argspec_plus(fn, grouped=True):
    """Returns a dictionary of formatted, introspected function arguments.

    A enhanced variant of inspect.formatargspec to support code generation.

    fn
       An inspectable callable or tuple of inspect getargspec() results.
    grouped
      Defaults to True; include (parens, around, argument) lists
Loading ...