Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

/ dim / wrap_type.py

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import functools
from types import FunctionType, BuiltinMethodType, MethodDescriptorType, WrapperDescriptorType, GetSetDescriptorType

from functorch._C import dim as _C
# Native method wrapper from the functorch C extension; this is the wrapper
# factory that ``wrap_type`` uses when called with ``use_c`` set.
_wrap_method = getattr(_C, "_wrap_method")

# Kinds of class-``__dict__`` entries that ``wrap_type`` re-installs as
# methods dispatching through ``__torch_function__``.
FUNC_TYPES = (
    FunctionType,
    MethodDescriptorType,
    BuiltinMethodType,
    WrapperDescriptorType,
)

# Kinds of descriptors that ``wrap_type`` re-installs as getter-only
# (read-only) properties.
PROPERTY_TYPES = (
    GetSetDescriptorType,
    property,
)

def _py_wrap_method(orig, __torch_function__):
    def impl(*args, **kwargs):
        return __torch_function__(orig, None, args, kwargs)
    return impl



def wrap_type(use_c, to_patch, pattern, __torch_function__):
    """Mirror the methods and attribute descriptors of ``pattern`` onto ``to_patch``.

    Every entry in the ``__dict__`` of each class in ``pattern``'s MRO
    (excluding ``object``) is considered. Subclass definitions take
    precedence over those of their bases.

    - Function-like entries (``FUNC_TYPES``) are replaced on ``to_patch`` by
      wrappers built with the selected wrap-method factory.
    - Descriptor entries (``PROPERTY_TYPES``) are replaced by getter-only
      properties whose getter is a wrapper around the descriptor's bound
      ``__get__``.

    An entry is left alone when its name is in a fixed skip list, or when
    ``to_patch`` already resolves the name to something other than the
    default inherited from ``object``.

    Args:
        use_c: if true, wrappers are built by the native ``_C._wrap_method``;
            otherwise by the pure-Python ``_py_wrap_method``.
        to_patch: class that is modified in place via ``setattr``.
        pattern: class whose attributes are mirrored onto ``to_patch``.
        __torch_function__: dispatcher passed to every wrapper. The Python
            wrappers call it as ``__torch_function__(orig, None, args,
            kwargs)``; the native path's calling convention is defined in
            ``_C``.

    Returns:
        None. ``to_patch`` is mutated.
    """
    wrap_method = _wrap_method if use_c else _py_wrap_method

    # Names never patched: construction/representation hooks and per-class
    # bookkeeping attributes. Built once rather than on every loop iteration.
    skip_names = frozenset((
        '__dict__', '__new__', '__init__', '__repr__',
        '__weakref__', '__doc__', '__module__', '__dir__',
    ))
    # Sentinel distinguishing "attribute absent" from any real value.
    missing = object()

    # Merge class dicts from the most-base class up to `pattern` itself, so a
    # subclass definition overwrites its bases'. `object` (the last MRO
    # entry) is excluded so its generic slots are not copied from `pattern`.
    members = {}
    for klass in reversed(pattern.mro()[:-1]):
        members.update(klass.__dict__)

    def wrap_attr(orig):
        # Only a getter is installed, so the resulting property is read-only.
        return property(wrap_method(orig.__get__, __torch_function__))

    for name, obj in members.items():
        if name in skip_names:
            continue

        # Skip things `to_patch` already overrides. Names it merely inherits
        # from `object` (e.g. `__eq__`) still need to be patched.
        existing = getattr(to_patch, name, missing)
        if existing is not missing and existing is not getattr(object, name, None):
            continue

        if isinstance(obj, FUNC_TYPES):
            setattr(to_patch, name, wrap_method(obj, __torch_function__))
        elif isinstance(obj, PROPERTY_TYPES):
            setattr(to_patch, name, wrap_attr(obj))