Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ distributed / _shard / op_registry_utils.py

import functools
from inspect import signature
from .common_op_utils import _basic_validation

"""
Common utilities to register ops on ShardedTensor, ReplicatedTensor
and PartialTensor.
"""

def _register_op(op, func, op_table):
    """
    Performs basic validation and registers the provided op in the given
    op_table.
    """
    if len(signature(func).parameters) != 4:
        raise TypeError(
            f'Custom sharded op function expects signature: '
            f'(types, args, kwargs, process_group), but received '
            f'signature: {signature(func)}')

    op_table[op] = func

def _decorator_func(wrapped_func, op, op_table):
    """
    Wrap ``wrapped_func`` with per-call argument validation and register the
    resulting wrapper for ``op`` in ``op_table``.

    Args:
        wrapped_func: op handler to wrap; it is called as
            ``wrapped_func(types, args, kwargs, process_group)``.
        op: the op being registered; also passed to ``_basic_validation``.
        op_table: dict-like table that receives the wrapper under key ``op``.

    Returns:
        The registered wrapper. Each call runs
        ``_basic_validation(op, args, kwargs)`` and then delegates to
        ``wrapped_func``, returning its result.

    Raises:
        TypeError: propagated from ``_register_op`` when the signature
            resolved for the wrapper (that of ``wrapped_func``, see below)
            does not have exactly four parameters.
    """
    def wrapper(types, args, kwargs, process_group):
        _basic_validation(op, args, kwargs)
        return wrapped_func(types, args, kwargs, process_group)

    # ``functools.wraps`` copies ``wrapped_func``'s metadata onto the wrapper
    # and sets ``wrapper.__wrapped__``. Because ``inspect.signature`` (used by
    # ``_register_op``) follows ``__wrapped__``, the arity check validates
    # ``wrapped_func`` rather than this fixed four-argument wrapper.
    registered = functools.wraps(wrapped_func)(wrapper)
    _register_op(op, registered, op_table)
    return registered