Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ api / translate.py

from typing import Dict, List, NoReturn, Sequence, Union

from torchgen.api.types import (
    ArrayRefCType,
    BaseCType,
    Binding,
    boolT,
    ConstRefCType,
    deviceT,
    Expr,
    intArrayRefT,
    iOptTensorListRefT,
    layoutT,
    ListCType,
    longT,
    memoryFormatT,
    MutRefCType,
    NamedCType,
    opmath_t,
    OptionalCType,
    optionalIntArrayRefT,
    optionalScalarRefT,
    optionalSymIntArrayRefT,
    optionalTensorRefT,
    scalar_t,
    scalarT,
    scalarTypeT,
    SpecialArgName,
    symIntArrayRefT,
    SymIntT,
    tensorOptionsT,
    tensorT,
    VectorCType,
)

# This file implements a small program synthesis engine that implements
# conversions from one API to another.
#
# The key data type in this file is NamedCType, short for Named C++ semantic type.  A NamedCType
# represents a C++ type, plus semantic information about what it represents.
# For example, consider the argument "bool pin_memory"; its normal C++ type is
# "bool", but its C++ semantic type also keeps track that this represents a
# "pin_memory"; you can't just use a random other boolean in a context where you
# need a "pin_memory"!
#
# The translator takes a list of needed NamedCTypes, and then figures out how
# to construct expressions with these NamedCTypes from the given bindings.  Many
# of these expressions are trivial (I need a Tensor other; there's a Tensor
# other in scope); others are nontrivial and may require packing/unpacking.
# Some examples of non-trivial action:
#
#   - Need the "dtype" binding?  Well, maybe "dtype" isn't available
#     in the context, instead, "options" is, and you need to extract
#     it from there.  (Gather)
#
#   - Need the "context" binding?  Well, maybe "context" isn't available
#     in the context, and you need to construct it from "dtype", "device",
#     etc.  (Scatter)
#
#   - Need the "memory_format" binding?  Well, actually, it's available
#     from both "memory_format" and "options", so you had better make sure
#     they are consistent.  (Join)

# Named goal for the "options" argument: a const reference to the
# tensorOptionsT base type.  The solver below looks this up directly when it
# needs to pull dtype/layout/device/pin_memory/memory_format out of "options".
options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT)))

# Named goal for the "out" argument: a const reference to the tensorT base
# type.  The solver falls back to it for dtype/layout/device/pin_memory when
# "options" is not available in the context.
out_tensor_ctype = NamedCType("out", ConstRefCType(BaseCType(tensorT)))

# Unnamed C++ types that are built once at import time so rules can pair them
# with a goal's name.  longVec_ctype and longSymVec_ctype are used by the
# IntArrayRef / SymIntArrayRef rules; the optional* variants are not referenced
# in the portion of the file visible here.
longVec_ctype = VectorCType(BaseCType(longT))
longSymVec_ctype = VectorCType(BaseCType(SymIntT))
optionalLongVec_ctype = OptionalCType(VectorCType(BaseCType(longT)))
optionalScalar_ctype = OptionalCType(BaseCType(scalarT))
optionalTensor_ctype = OptionalCType(BaseCType(tensorT))


class UnsatError(RuntimeError):
    """Raised when no expression can be synthesized for a requested goal.

    The solver in this module also catches this exception internally to move
    on to the next candidate rule, so it doubles as a backtracking signal.
    """


# Given a set of in-scope bindings and a set of target bindings, synthesize
# a list of expressions that uses only the in-scope bindings (bindings) that
# have all of the types of goals.  You may want to use this function if
# you're generating code for a function like:
#
#   void f({args}) {
#     g({exprs}); // g is a different API
#   }
#
# and you need to generate "exprs".
#
# Typically, a list of Bindings is convenient to get (you usually call something
# like arguments() to get them); but technically you need less information:
# for 'bindings' an (un-ordered) list of Exprs is sufficient; similarly, for
# 'goals', an (ordered) list of NamedCType goals is sufficient.  If you are doing
# something more complicated, e.g., tracking the set of bindings in a context,
# you may find using these smaller types more convenient.
def translate(
    bindings: Sequence[Union[Expr, Binding]],
    goals: Sequence[Union[NamedCType, Binding]],
    *,
    method: bool = False,
    allow_expensive_conversions: bool = False,
) -> List[Expr]:

    binding_exprs: List[Expr] = []
    for b in bindings:
        if isinstance(b, Binding):
            binding_exprs.append(
                Expr(
                    expr=b.name,
                    type=b.nctype,
                )
            )
        else:
            binding_exprs.append(b)

    goal_ctypes: List[NamedCType] = []
    for g in goals:
        if isinstance(g, Binding):
            goal_ctypes.append(g.nctype)
        else:
            goal_ctypes.append(g)

    # Add all the bindings to the context
    ctx: Dict[NamedCType, str] = {}
    for b in binding_exprs:
        ctx[b.type] = b.expr

        # While we're at it, do some simple forward inference, looking through
        # constructors.
        #
        # NB: When should you do forward inference versus backward inference?
        # The general idea:
        #
        #   - Backward inference WHEN the goal gets smaller
        #   - Forward inference WHEN the hypothesis gets smaller
        #
        # This helps ensure termination: backward inference starts with a goal
        # and tries to make it simpler and simpler until it's trivial; if the
        # goal can grow in size, we blow up to a really huge goal size.
        # Similarly, with forward inference we take hypotheses and decompose
        # them into simpler hypotheses; if hypotheses could expand in size,
        # we also have potential nontermination.  (In the code below, forward
        # inference is only ever carried out at a single step, but you could
        # imagine repeated application of forward inference being profitable.)
        #
        # A good starting point in the literature for exploring more about proof
        # search are these lecture notes
        # https://www.cs.cmu.edu/~fp/courses/oregon-m10/04-focusing.pdf
        #
        # TODO: My kingdom for a pattern matcher
        # https://www.python.org/dev/peps/pep-0634/
        #
        # TODO: This could get us in recomputation trouble if b.expr is nontrivial.
        # Fix this by implementing some sort of sharing so that if multiple
        # goals share the same expression, we only compute it once.  This seems
        # to matter in practice as compiler is often unwilling to CSE nontrivial
        # expressions like scalar.to<scalar_t>()
        t = b.type
        if (
            isinstance(t, ConstRefCType)
            and isinstance(t.elem, OptionalCType)
            and isinstance(t.elem.elem, BaseCType)
            and str(t.elem.elem.type) == "at::Tensor"
        ):
            ctx[
                NamedCType(t.elem.elem.name, ConstRefCType(BaseCType(tensorT)))
            ] = f"({b.expr}.has_value() ? *{b.expr} : at::Tensor())"

        if t.type == ConstRefCType(OptionalCType(BaseCType(tensorT))):
            ctx[
                NamedCType(t.name, BaseCType(optionalTensorRefT))
            ] = f"(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())"

        if t.type == ConstRefCType(BaseCType(scalarT)):
            ctx[NamedCType(t.name, BaseCType(opmath_t))] = f"({b.expr}).to<opmath_t>()"

        if t.type == ConstRefCType(OptionalCType(BaseCType(scalarT))):
            ctx[
                NamedCType(t.name, BaseCType(optionalScalarRefT))
            ] = f"({b.expr}.has_value() ? at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())"

        if t.type == BaseCType(scalar_t):
            ctx[
                NamedCType(t.name, BaseCType(opmath_t))
            ] = f"static_cast<opmath_t>({b.expr})"

        # [Note: IOptTensorListRef]
        if t.type == ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))):
            ctx[
                NamedCType(t.name, BaseCType(iOptTensorListRefT))
            ] = f"at::IOptTensorListRef({b.expr})"

    # Add implicit bindings if the generated code is inside a Tensor method
    if method:
        ctx[
            NamedCType("self", MutRefCType(BaseCType(tensorT)))
        ] = "const_cast<Tensor&>(*this)"
        ctx[
            NamedCType("self", ConstRefCType(BaseCType(tensorT)))
        ] = "const_cast<Tensor&>(*this)"
        # This is better!  Byte-for-byte compat
        # ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "*this"

    def unsat(goal: NamedCType) -> NoReturn:
        ctx_desc = "\n".join(
            f"  {t.cpp_type()} {t.name}; // {e}" for t, e in ctx.items()
        )
        raise UnsatError(
            f"""
Failed to synthesize the expression "{goal.cpp_type()} {goal.name}".
When I failed, the following bindings were available in the context:

{ctx_desc}

This probably means there is a missing rule in the rules of torchgen.api.translate.
Check this module for more information.
"""
        )

    # A shitty backtracking search implementation.  It's shitty because it
    # does backtracking via stack (bad idea!) and for the most part tries to
    # avoid backtracking.  In particular, if
    # direct=True, we won't try to do any fancy synthesis, just trivial
    # conversions (e.g., "T a" is OK for "const T& a").  So all of the
    # existing rules in this function simply try to solve immediately,
    # and bail if things don't work out.
    def solve(goal: NamedCType, *, direct: bool) -> str:
        def direct_solve(goal: NamedCType) -> str:
            return solve(goal, direct=True)

        if goal in ctx:
            # Trivial
            return ctx[goal]

        # const & is satisfied with mutable &
        if isinstance(goal.type, ConstRefCType):
            try:
                # WARNING: not strictly decreasing; be careful not
                # to add a direct conversion that satisfies
                # mutable& with const&
                return solve(
                    NamedCType(goal.name, MutRefCType(goal.type.elem)), direct=direct
                )
            except UnsatError:
                pass

        # mutable & is satisfied with value
        if isinstance(goal.type, MutRefCType):
            try:
                return solve(NamedCType(goal.name, goal.type.elem), direct=direct)
            except UnsatError:
                pass

        # TODO: These are referentially equal, shouldn't have to do this;
        # ensuring we don't use type synonym IntArrayRef in codegen would
        # help
        if goal.type == ArrayRefCType(BaseCType(longT)):
            return solve(NamedCType(goal.name, BaseCType(intArrayRefT)), direct=direct)

        if direct:
            unsat(goal)

        # For now, all of these rules are mutually exclusive.
        if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))):
            memory_format = direct_solve(
                NamedCType(
                    SpecialArgName.possibly_redundant_memory_format,
                    OptionalCType(BaseCType(memoryFormatT)),
                )
            )
            # No need to join "memory_format" and "options" if the target API takes "options" directly.
            # Otherwise it will cause the redundant memory_format error.
            if options_ctype in goal_ctypes:
                return memory_format
            try:
                options = direct_solve(options_ctype)
                return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})"
            except UnsatError:
                return memory_format
        elif goal == NamedCType("options", BaseCType(tensorOptionsT)):
            dtype = direct_solve(
                NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT)))
            )
            pin_memory = direct_solve(
                NamedCType("pin_memory", OptionalCType(BaseCType(boolT)))
            )
            device = direct_solve(
                NamedCType("device", OptionalCType(BaseCType(deviceT)))
            )
            layout = direct_solve(
                NamedCType("layout", OptionalCType(BaseCType(layoutT)))
            )
            return f"TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})"

        elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))):
            try:
                options = direct_solve(options_ctype)
                return f"optTypeMetaToScalarType({options}.dtype_opt())"
            except UnsatError:
                out_tensor = direct_solve(out_tensor_ctype)
                return f"{out_tensor}.scalar_type()"

        elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
            try:
                options = direct_solve(options_ctype)
                return f"{options}.layout_opt()"
            except UnsatError:
                out_tensor = direct_solve(out_tensor_ctype)
                return f"{out_tensor}.layout()"

        elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
            try:
                options = direct_solve(options_ctype)
                return f"{options}.device_opt()"
            except UnsatError:
                out_tensor = direct_solve(out_tensor_ctype)
                return f"{out_tensor}.device()"

        elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
            try:
                options = direct_solve(options_ctype)
                return f"{options}.pinned_memory_opt()"
            except UnsatError:
                # If we're calling a factory op from its out= variant,
                # We don't actually care about the value of pin_memory.
                out_tensor = direct_solve(out_tensor_ctype)
                return "c10::nullopt"

        # We can always do translations from value types to reference types, like vector<int> -> IntArrayRef
        elif goal.type == BaseCType(intArrayRefT):
            try:
                return direct_solve(NamedCType(goal.name, longVec_ctype))
            except UnsatError:
                # We can also go SymIntArrayRef -> IntArrayRef
                symIntArrayRef_type = direct_solve(
                    NamedCType(goal.name, BaseCType(symIntArrayRefT))
                )
                return f"C10_AS_INTARRAYREF_SLOW({symIntArrayRef_type})"
        elif goal.type == BaseCType(symIntArrayRefT):
            try:
                r = direct_solve(NamedCType(goal.name, BaseCType(intArrayRefT)))
                return f"c10::fromIntArrayRefSlow({r})"
            except UnsatError:
                return direct_solve(NamedCType(goal.name, longSymVec_ctype))
        elif goal.type == BaseCType(SymIntT):
Loading ...