Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

/ api / autograd.py

import copy
import re
from dataclasses import dataclass
from typing import Dict, List, Match, Optional, Sequence, Set, Tuple

from torchgen.api import cpp
from torchgen.api.types import BaseCType, Binding, NamedCType, tensorListT
from torchgen.model import (
    FunctionSchema,
    NativeFunction,
    NativeFunctionsViewGroup,
    SchemaKind,
    Type,
)
from torchgen.utils import IDENT_REGEX

# Represents a saved attribute involved in backward calculation.
# Note that it can be a derived property of an input argument, e.g.:
# we could save `other.scalar_type()` instead of the entire `other` tensor.
# Declared as a frozen dataclass, so fields cannot be reassigned after
# construction.
@dataclass(frozen=True)
class SavedAttribute:
    # The NamedCType holds the updated name and C++ type of the attribute.
    # For the name, a suffix is appended if it is a derived property,
    # e.g.: `other_scalar_type`.
    nctype: NamedCType

    # The expression to read the derived property at save time, e.g.:
    # `other.scalar_type()`.
    expr: str


# Represents a backward formula that calculates derivatives for one
# or more tensors.
# Declared as a frozen dataclass, so fields cannot be reassigned after
# construction.
@dataclass(frozen=True)
class Derivative:
    # The formula string (a legitimate C++ expression).
    # Note that expressions against input arguments have been replaced with the
    # corresponding saved attributes.
    # E.g.:
    #  raw formula: `mul_tensor_backward(grad, self, other.scalar_type())`
    #         here: `mul_tensor_backward(grad, self, other_scalar_type)`
    formula: str

    # The formula string before input argument replacement (i.e. the
    # "raw formula" in the example above).
    original_formula: str

    # Names of the arguments for which this formula calculates derivatives.
    var_names: Tuple[str, ...]

    # Saved inputs that are referenced by the formula.
    saved_inputs: Tuple[SavedAttribute, ...]

    # Saved outputs that are referenced by the formula.
    saved_outputs: Tuple[SavedAttribute, ...]

    # Gradients that are referenced by name in the formula.
    # NOTE(review): if a plain `set` is stored here, hashing a Derivative
    # raises TypeError (a frozen dataclass hashes all of its fields, and sets
    # are unhashable) -- confirm that no caller relies on hashing instances.
    named_gradients: Set[str]


# Represents a forward formula that calculates forward derivatives
# for one tensor.
# Declared as a frozen dataclass, so fields cannot be reassigned after
# construction.
@dataclass(frozen=True)
class ForwardDerivative:
    # The formula string (a legitimate C++ expression).
    # Note that special keywords such as "linear" or "element_wise" have been
    # replaced by the automatically generated formula.
    formula: str

    # Names of the output arguments for which this formula calculates forward
    # derivatives.
    var_names: Tuple[str, ...]

    # Types of the output arguments for which this formula calculates forward
    # derivatives.
    # NOTE(review): presumably parallel to `var_names` (same length and
    # order) -- confirm against the code that constructs this class.
    var_types: Tuple[Type, ...]

    # Inputs for which the forward derivatives are required for this formula.
    # May be None; the meaning of None is not visible in this file section.
    required_inputs_fw_grad: Optional[Tuple[str, ...]]

    # Inputs for which the primal is required for this formula.
    # May be None; the meaning of None is not visible in this file section.
    required_inputs_primal: Optional[Tuple[str, ...]]

    # Flag to specify if this formula requires the original value of self.
    # This is only used by inplace operations.
    required_original_self_value: bool

    # Distinguishes a formula that is specified in derivatives.yaml from the
    # case where we are re-using the out-of-place formula for inplace
    # (presumably True in the re-use case, per the field name -- TODO confirm).
    is_reusing_outplace_formula: bool


# Differentiability information attached to a single NativeFunction.
# Declared as a frozen dataclass, so fields cannot be reassigned after
# construction.
@dataclass(frozen=True)
class DifferentiabilityInfo:
    # Base name, as read from derivatives.yaml.
    name: str

    # The native function this info is matched to.
    #
    # Several NativeFunctions may share one base name:
    #  - overloads that differ in the types of their input arguments;
    #  - in-place/out/functional variants of one and the same function.
    #
    # The match is made in steps: the schema string (the 'name' key) in
    # derivatives.yaml is used to look up the NativeFunction with an identical
    # schema string; next, the in-place/out/functional variants of that
    # function are collected. From those variants, the one whose name equals
    # the derivatives.yaml entry is chosen; when no variant matches exactly,
    # the in-place variant is chosen.
    # TODO: maybe the logic to search for all variants is no longer necessary?
    func: NativeFunction

    # Name of the generated autograd function.
    # Only set when a derivative will be calculated, i.e. when
    # 'args_with_derivatives' is not empty.
    op: Optional[str]

    # Backward derivative formulae of this function.
    # The sequence has one entry per differentiable input.
    derivatives: Sequence[Derivative]

    # Forward derivative formulae of this function.
    # The sequence has one entry per differentiable output.
    forward_derivatives: Sequence[ForwardDerivative]

    # Union of the 'saved_inputs' of every entry in 'derivatives'.
    all_saved_inputs: Sequence[SavedAttribute]

    # Union of the 'saved_outputs' of every entry in 'derivatives'.
    all_saved_outputs: Sequence[SavedAttribute]

    # Every named gradient that may be used, listed in the same order
    # as in the grads vector.
    available_named_gradients: Sequence[str]

    # The named gradients actually used by at least one derivative.
    # Invariant: all(name in available_named_gradients for name in used_named_gradients)
    used_named_gradients: Set[str]

    # Input arguments of the function for which derivatives are calculated.
    # This is the union of the 'var_names' of every entry in 'derivatives',
    # ordered the way the arguments appear in the function schema.
    args_with_derivatives: Sequence[Binding]

    # Names of the arguments whose derivative formula is 'non_differentiable'.
    non_differentiable_arg_names: Sequence[str]

    # Raw data as read from derivatives.yaml.
    output_differentiability: Optional[List[bool]]

    # In derivatives.yaml, output_differentiability may instead be a list of
    # conditions stating whether each output is differentiable. The number of
    # conditions must then equal the number of outputs
    # (NB: we only support one condition right now).
    # In that case output_differentiability holds True for each condition and
    # the conditions themselves are stored here.
    output_differentiability_conditions: Optional[List[str]]

    @property
    def has_derivatives(self) -> bool:
        """Whether `args_with_derivatives` contains at least one argument."""
        num_differentiable_args = len(self.args_with_derivatives)
        return num_differentiable_args > 0

    def create_view_copy_from_view_derivative(
        self, g: NativeFunctionsViewGroup
    ) -> Optional["DifferentiabilityInfo"]:
        """Build the info for the "copy" variant of a view op.

        The "copy" variant can use the exact same derivative formulae as the
        original view op, so every derivative-related field is carried over
        unchanged and only the name, function and op name are replaced.
        See Note [Codegen'd {view}_copy Operators]

        Args:
            g: the view group whose `view_copy` function the new info targets.

        Returns:
            A new DifferentiabilityInfo for `g.view_copy`, or None when the
            group has no `view_copy` function.
        """
        view_copy_func = g.view_copy
        if view_copy_func is None:
            return None

        # "_copy" goes on the base name (the text before the first period);
        # everything after that period (the overload name) is kept verbatim.
        # Note that the period is always emitted, so a name without any period
        # produces a result that ends in "_copy.".
        base_name, _, overload_name = self.name.partition(".")
        copied_name = f"{base_name}_copy.{overload_name}"

        copied_op_name = f"{self.op}_copy" if self.op is not None else None

        return DifferentiabilityInfo(
            # Fields replaced by their "_copy" counterparts.
            name=copied_name,
            func=view_copy_func,
            op=copied_op_name,
            # Derivative information shared with the original view op.
            derivatives=self.derivatives,
            forward_derivatives=self.forward_derivatives,
            all_saved_inputs=self.all_saved_inputs,
            all_saved_outputs=self.all_saved_outputs,
            available_named_gradients=self.available_named_gradients,
            used_named_gradients=self.used_named_gradients,
            args_with_derivatives=self.args_with_derivatives,
            non_differentiable_arg_names=self.non_differentiable_arg_names,
            output_differentiability=self.output_differentiability,
            output_differentiability_conditions=self.output_differentiability_conditions,
        )


def uses_ident(info: Optional[DifferentiabilityInfo], ident: str) -> bool:
    """Report whether any backward formula of `info` matches `ident`.

    Args:
        info: differentiability info to inspect; `None` yields `False`.
        ident: name substituted into `IDENT_REGEX` to build the search pattern.

    Returns:
        `True` if the pattern is found in the `formula` of at least one entry
        of `info.derivatives`, `False` otherwise. Only `formula` is searched,
        not `original_formula`.
    """
    if info is None:
        return False
    return any(
        re.search(IDENT_REGEX.format(ident), entry.formula) is not None
        for entry in info.derivatives
    )


def uses_retain_variables(info: Optional[DifferentiabilityInfo]) -> bool:
    """Report whether `uses_ident` finds `retain_variables` in `info`."""
    return uses_ident(info, ident="retain_variables")


def uses_single_grad(info: Optional[DifferentiabilityInfo]) -> bool:
    """Report whether `uses_ident` finds `grad` in `info`."""
    return uses_ident(info, ident="grad")


# Represents a differentiable `Argument`.
# How is it different from the `Argument` type?
# - It's a processed Argument which is differentiable and only used in the
#   context of the autograd codegen;
# - It can represent SelfArgument or regular Argument but not TensorOptionsArgument;
# Declared as a frozen dataclass, so fields cannot be reassigned after
# construction.
@dataclass(frozen=True)
class DifferentiableInput:
    # Name of the argument.
    name: str
    # Type of the argument (the `Type` imported from torchgen.model).
    type: Type

    # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
    cpp_type: str


# Represents a differentiable `Return`.
# How is it different from the `Return` type?
# - The name in `Return` is optional. Here it is always populated using the same
#   `cpp.return_names()` method.
#   TODO: some cpp naming logic (e.g. resolving name conflict) might be irrelevant?
# - It's a processed Return which is differentiable, in compliance with the
#   `output_differentiability` field defined in derivatives.yaml (if specified),
#   and is only used in the context of the autograd codegen;
# Declared as a frozen dataclass, so fields cannot be reassigned after
# construction.
@dataclass(frozen=True)
class DifferentiableOutput:
    # Name of the return; always populated (see the class comment above).
    name: str
    # Type of the return (the `Type` imported from torchgen.model).
    type: Type

    # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
    cpp_type: str


# Bundles a NativeFunction with the differentiability information attached
# to it. Declared as a frozen dataclass, so fields cannot be reassigned after
# construction.
@dataclass(frozen=True)
class NativeFunctionWithDifferentiabilityInfo:
    # The native function itself.
    func: NativeFunction
    # Differentiability infos for `func`, keyed by a string, or None.
    # NOTE(review): the keys look like dispatch key names (dispatch_strategy()
    # speaks of "per-key differentiability infos") -- confirm against the
    # code that builds this mapping.
    info: Optional[Dict[str, DifferentiabilityInfo]]
    # Forward derivatives for `func`, keyed by a string, or None.
    # NOTE(review): presumably keyed the same way as `info` -- TODO confirm.
    fw_derivatives: Optional[Dict[str, Sequence[ForwardDerivative]]]


# TODO: Update comment below since it is out of date.
def dispatch_strategy(fn: NativeFunctionWithDifferentiabilityInfo) -> str:
    """Choose how the underlying implementation of a declaration gets called.

    One of two strategy names is returned:
        - "use_derived": call the implementation on CPUDoubleType (or a
          similar, derived Type instance). Those derived instances deal in
          Tensors rather than Variables (a completely different object, so it
          doesn't dispatch back to VariableType), which means code on this
          dispatch path has to wrap/unwrap tensors. When the derived
          implementation takes and returns tensors, the implementation is
          usually differentiable (although the derived dispatch path is also
          used for non-differentiable functions that we still want to
          dispatch on the derived Type instance; e.g., size()).
        - "use_type": call the implementation on Type, because it is
          implemented concretely and the functions it invokes get dispatched
          back to VariableType (which will ensure that they are
          differentiable).

    Args:
        fn: the native function together with its per-key differentiability
            infos (`fn.info` may be None).

    Returns:
        "use_derived" when `fn.func.is_abstract` is true or when any info in
        `fn.info` has derivatives; "use_type" otherwise.
    """
    # dispatch_strategy() guards the generation of fns in VariableType and
    # ADInplaceOrViewType. Those functions should be generated as soon as a
    # derivative is defined for ANY dispatch key, so fn counts as derived if
    # at least one of its per-key differentiability infos has_derivatives.

    if fn.func.is_abstract:
        # An abstract function (not implemented on at::Type) leaves no choice:
        # the implementation on the derived type must be called with unpacked
        # tensors.
        return "use_derived"

    per_key_infos = fn.info
    if per_key_infos is not None and any(
        key_info.has_derivatives for key_info in per_key_infos.values()
    ):
        # A concrete function with a derivative specified could call either
        # implementation. The derived type's implementation with unpacked
        # tensors is preferred because it is more performant in some cases:
        # any internal calls to other ATen functions won't have the history
        # tracked.
        #
        # For a function with a type dispatched argument (i.e. a factory),
        # the derived type's implementation is preferred both because it is
        # more performant and to ensure factory functions return tensors with
        # _version of 0 (probably not strictly necessary, but nice to have to
        # keep versions simple to understand).
        return "use_derived"

    # A concrete function (one we don't have to override) that was not
    # declared in derivatives.yaml is assumed to be implemented out of
    # differentiable functions. (This assumption might not hold, but then
    # you'll see gradcheck fail.)
    return "use_type"


def match_differentiability_info(
    native_functions: List[NativeFunction],
    differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
) -> List[NativeFunctionWithDifferentiabilityInfo]:
    """Sets the "derivative" key on declarations to matching autograd function
    In-place functions will use the out-of-place derivative definition if there
    is no in-place specific derivative.
    """

    functional_info_by_signature = {
        schema.signature(strip_default=True): info_dict
        for schema, info_dict in differentiability_infos.items()
        if schema.kind() == SchemaKind.functional
    }
    non_functional_info_by_signature = {
        schema.signature(strip_default=True): info_dict
        for schema, info_dict in differentiability_infos.items()
        if schema.kind() != SchemaKind.functional
    }

    def find_info(
        f: NativeFunction,
    ) -> Tuple[Optional[Dict[str, DifferentiabilityInfo]], bool]:
        # Don't bother matching info to generated out= variants
        if "generated" in f.tags and f.func.kind() == SchemaKind.out:
            return None, False

        # (1) Check for an exact match
        if f.func in differentiability_infos:
            return differentiability_infos[f.func], True

        # (2) If no exact match, check if the out-of-place variant
        # of this operator has a match.
        # i.e mul() for mul_() or mul_out()
        f_sig = f.func.signature(strip_default=True)
        if f_sig in functional_info_by_signature:
            return functional_info_by_signature[f_sig], False

        # (3) Some operators have a derivative explicitly defined for the mutable
        # variant, but get a code-generated out-of-place variant which does *not*
Loading ...