# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import functools
from types import FunctionType, BuiltinMethodType, MethodDescriptorType, WrapperDescriptorType, GetSetDescriptorType

from functorch._C import dim as _C
# C implementation of the method wrapper; ``wrap_type`` picks it when ``use_c`` is true.
_wrap_method = _C._wrap_method

# Attribute types that ``wrap_type`` treats as callables and wraps as methods.
FUNC_TYPES = (
    FunctionType,
    MethodDescriptorType,
    BuiltinMethodType,
    WrapperDescriptorType,
)

# Attribute types that ``wrap_type`` re-exposes as (getter-only) properties.
PROPERTY_TYPES = (
    GetSetDescriptorType,
    property,
)
def _py_wrap_method(orig, __torch_function__):
    """Pure-Python counterpart of ``_C._wrap_method``.

    Returns a function that forwards every call to
    ``__torch_function__(orig, None, args, kwargs)``. Positional arguments
    are passed as a tuple and keyword arguments as a dict. The wrapper
    returns whatever the handler returns.

    The wrapper copies ``orig``'s ``__name__``, ``__qualname__``, ``__doc__``
    and related metadata (where ``orig`` has them) and exposes ``orig`` as
    ``__wrapped__``. This keeps the patched attribute introspectable instead
    of showing up as an anonymous ``impl``.

    Args:
        orig: the original callable (function, method descriptor, or a
            descriptor's bound ``__get__``) being wrapped.
        __torch_function__: handler that receives ``orig`` together with the
            call's arguments.

    Returns:
        A plain function suitable for installing on a class with ``setattr``.
    """
    # functools.wraps silently skips attributes that builtin descriptors lack.
    @functools.wraps(orig)
    def impl(*args, **kwargs):
        return __torch_function__(orig, None, args, kwargs)

    return impl
def wrap_type(use_c, to_patch, pattern, __torch_function__):
    """Mirror ``pattern``'s methods and properties onto ``to_patch`` as wrappers.

    Every attribute found on ``pattern`` or any of its bases (``object``
    excluded) is considered. A callable attribute (an instance of
    ``FUNC_TYPES``) is replaced on ``to_patch`` by a wrapper that dispatches
    to ``__torch_function__``. A property-like attribute (an instance of
    ``PROPERTY_TYPES``) is replaced by a getter-only property that does the
    same.

    An attribute is left alone when ``to_patch`` already provides its own
    version of it. An attribute that ``to_patch`` merely inherits from
    ``object`` is still patched.

    Args:
        use_c: if true, wrap with the C implementation ``_wrap_method``;
            otherwise use the pure-Python ``_py_wrap_method``.
        to_patch: the class that is modified in place via ``setattr``.
        pattern: the class whose attributes are mirrored onto ``to_patch``.
        __torch_function__: handler that the generated wrappers dispatch to.
            The pure-Python wrapper calls it as
            ``__torch_function__(orig, None, args, kwargs)``.
    """
    wrap_method = _wrap_method if use_c else _py_wrap_method

    # Attribute names that are never copied from ``pattern`` onto ``to_patch``.
    skipped_names = frozenset((
        '__dict__', '__new__', '__init__', '__repr__',
        '__weakref__', '__doc__', '__module__', '__dir__',
    ))

    # Flatten the MRO into a single namespace. Iterating base-first lets a
    # subclass's definition overwrite the one it inherits.
    attrs = {}
    for klass in reversed(pattern.mro()[:-1]):  # skip object
        attrs.update(klass.__dict__)

    def wrap_attr(orig):
        # Only the getter is wrapped, so the installed property is read-only.
        return property(wrap_method(orig.__get__, __torch_function__))

    missing = object()
    for name, obj in attrs.items():
        if name in skipped_names:
            continue
        # Skip things that have been overloaded on to_patch. Things that only
        # come from `object` (like `__eq__`) still need to be patched.
        existing = getattr(to_patch, name, missing)
        if existing is not missing and existing is not getattr(object, name, None):
            continue
        if isinstance(obj, FUNC_TYPES):
            setattr(to_patch, name, wrap_method(obj, __torch_function__))
        elif isinstance(obj, PROPERTY_TYPES):
            setattr(to_patch, name, wrap_attr(obj))