Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ jit / annotations.py

import ast
import dis
import enum
import inspect
import re
import builtins
import torch
import warnings
from .._jit_internal import List, Tuple, is_tuple, is_list, Dict, is_dict, Optional, \
    is_optional, _qualified_name, Any, Future, is_future, _Await, is_await, is_ignored_fn, Union, is_union
from .._jit_internal import BroadcastingList1, BroadcastingList2, BroadcastingList3  # type: ignore[attr-defined]
from ._state import _get_script_class

from torch._C import TensorType, TupleType, FloatType, IntType, ComplexType, \
    ListType, StringType, DictType, BoolType, OptionalType, InterfaceType, AnyType, \
    NoneType, DeviceObjType, StreamObjType, FutureType, AwaitType, EnumType, UnionType, NumberType


from textwrap import dedent
from torch._sources import get_source_lines_and_file
from typing import Type

if torch.distributed.rpc.is_available():
    from .._jit_internal import RRef, is_rref
    from torch._C import RRefType

from torch._ops import OpOverloadPacket

class Module:
    """Minimal namespace object that resolves attribute access through a dict.

    ``name`` is only used to label error messages; ``members`` maps attribute
    names to the objects they resolve to.
    """

    def __init__(self, name, members):
        self.members = members
        self.name = name

    def __getattr__(self, name):
        # Python only calls __getattr__ after normal lookup has failed, so
        # ``self.name`` and ``self.members`` come from the instance dict.
        registry = self.members
        try:
            found = registry[name]
        except KeyError:
            # Report a RuntimeError and hide the underlying KeyError.
            raise RuntimeError(f"Module {self.name} has no member called {name}") from None
        return found


class EvalEnv:
    """Mapping-like object used as the local namespace when type comments are evaluated.

    Lookup order for a name: the fixed ``env`` table, then the resolution
    callback ``rcb`` (if one was given), then the ``builtins`` module.
    """

    # Class-level table shared by every instance.
    env = {
        'torch': Module('torch', {'Tensor': torch.Tensor}),
        'Tensor': torch.Tensor,
        'typing': Module('typing', {'Tuple': Tuple}),
        'Tuple': Tuple,
        'List': List,
        'Dict': Dict,
        'Optional': Optional,
        'Union': Union,
        'Future': Future,
        'Await': _Await
    }

    def __init__(self, rcb):
        self.rcb = rcb
        # RRef is only registered when RPC support is available. The write
        # goes to the shared class-level dict, not to a per-instance copy.
        if torch.distributed.rpc.is_available():
            self.env['RRef'] = RRef

    def __getitem__(self, name):
        try:
            return self.env[name]
        except KeyError:
            pass
        resolver = self.rcb
        if resolver is None:
            # No callback: fall back to builtins, yielding None for unknown names.
            return getattr(builtins, name, None)
        return resolver(name)

def get_signature(fn, rcb, loc, is_method):
    """Return ``(param_types, return_type)`` for ``fn``, or ``None`` if it is unannotated.

    Real Python annotations are tried first; if none are found, the source of
    ``fn`` is searched for a ``# type:`` comment, which is then parsed with
    ``rcb`` as the name-resolution callback.
    """
    annotated = fn.op if isinstance(fn, OpOverloadPacket) else fn
    signature = try_real_annotations(annotated, loc)

    if signature is not None:
        if is_method:
            # If this is a method, then the signature will include a type for
            # `self`, but type comments do not contain a `self`. So strip it
            # away here so everything is consistent (`inspect.ismethod` does
            # not work here since `fn` is unbound at this point)
            arg_types, ret_type = signature
            signature = (arg_types[1:], ret_type)
        return signature

    # No usable annotations: fall back to a type comment in the source.
    type_line = None
    try:
        source = dedent(''.join(get_source_lines_and_file(fn)[0]))
        type_line = get_type_line(source)
    except TypeError:
        pass

    # `type_line` may be None both because we failed to get the source of fn,
    # or because it didn't have any annotations.
    if type_line is None:
        return None
    return parse_type_line(type_line, rcb, loc)


def is_function_or_method(the_callable):
    """Return True for Python-level functions and bound methods only.

    This is stricter than ``inspect.isroutine``: built-in functions do not pass.
    """
    if inspect.isfunction(the_callable):
        return True
    return inspect.ismethod(the_callable)


def is_vararg(the_callable):
    """Return True if ``the_callable`` accepts ``*args``.

    Objects that are not plain functions/methods but expose ``__call__`` are
    inspected through that ``__call__`` attribute. Anything that is still not
    a function or method afterwards yields False.
    """
    target = the_callable
    if not is_function_or_method(target) and hasattr(target, '__call__'):  # noqa: B004
        # If `the_callable` is a class, de-sugar the call so we can still get
        # the signature
        target = target.__call__

    if not is_function_or_method(target):
        return False
    return inspect.getfullargspec(target).varargs is not None


def get_param_names(fn, n_args):
    """Return the positional parameter names of ``fn``.

    When ``fn`` is not a Python function or method, the stringified indices
    ``"0" .. str(n_args - 1)`` are returned as placeholder names.
    """
    callee = fn.op if isinstance(fn, OpOverloadPacket) else fn

    desugar = (
        not is_function_or_method(callee)
        and hasattr(callee, '__call__')  # noqa: B004
        and is_function_or_method(callee.__call__)
    )
    if desugar:
        # De-sugar calls to classes
        callee = callee.__call__

    if not is_function_or_method(callee):
        # The `fn` was not a method or function (maybe a class with a __call__
        # method, so use a default param name list)
        return list(map(str, range(n_args)))

    if is_ignored_fn(callee):
        callee = inspect.unwrap(callee)
    return inspect.getfullargspec(callee).args


def check_fn(fn, loc):
    """Verify that the source of ``fn`` is exactly one top-level function definition.

    Returns silently when the source cannot be retrieved. Raises
    ``torch.jit.frontend.FrontendError`` (tagged with ``loc``) when the source
    is a class definition or anything other than a single function.
    """
    # Make sure the function definition is not a class instantiation
    try:
        raw_lines = get_source_lines_and_file(fn)[0]
        source = dedent(''.join(raw_lines))
    except (TypeError, OSError):  # IOError is an alias of OSError
        return
    if source is None:
        return

    top_level = ast.parse(source).body
    # `only_stmt` is None unless the source consists of exactly one statement.
    only_stmt = top_level[0] if len(top_level) == 1 else None

    if isinstance(only_stmt, ast.ClassDef):
        raise torch.jit.frontend.FrontendError(
            loc, f"Cannot instantiate class '{only_stmt.name}' in a script function")
    if not isinstance(only_stmt, ast.FunctionDef):
        raise torch.jit.frontend.FrontendError(loc, "Expected a single top-level function")


def _eval_no_call(stmt, glob, loc):
    """Evaluate the expression ``stmt`` unless its bytecode contains a call.

    The expression is compiled first and rejected with a ``RuntimeError`` if
    any instruction name contains ``CALL``. Otherwise it is evaluated with
    ``glob`` as globals and ``loc`` as locals.
    """
    code_obj = compile(stmt, "", mode="eval")
    has_call = any("CALL" in instr.opname for instr in dis.get_instructions(code_obj))
    if has_call:
        raise RuntimeError(f"Type annotation should not contain calls, but '{stmt}' does")
    return eval(code_obj, glob, loc)  # type: ignore[arg-type] # noqa: P204


def parse_type_line(type_line, rcb, loc):
    """Turn a type-comment line into ``(argument_types, return_type)``.

    Example inputs:
        # type: (Tensor, torch.Tensor) -> Tuple[Tensor]
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tensor

    Both halves of the comment are evaluated as call-free Python expressions
    with names resolved through ``EvalEnv(rcb)``, then converted with
    ``ann_to_type``. A ``NameError`` or ``SyntaxError`` during evaluation is
    re-raised as ``RuntimeError``.
    """
    args_src, ret_src = split_type_line(type_line)

    def _evaluate(expr, failure_message):
        # A fresh EvalEnv is created for each evaluation.
        try:
            return _eval_no_call(expr, {}, EvalEnv(rcb))
        except (NameError, SyntaxError) as e:
            raise RuntimeError(failure_message) from e

    arg_ann = _evaluate(args_src, "Failed to parse the argument list of a type annotation")
    # A single parenthesised argument evaluates to a bare object, not a tuple.
    if not isinstance(arg_ann, tuple):
        arg_ann = (arg_ann,)

    ret_ann = _evaluate(ret_src, "Failed to parse the return type of a type annotation")

    converted_args = [ann_to_type(single, loc) for single in arg_ann]
    converted_ret = ann_to_type(ret_ann, loc)
    return converted_args, converted_ret


def get_type_line(source):
    """Find the PEP 484 type comment in ``source`` and return it as a single line.

    Args:
        source: Source text of one function.

    Returns:
        ``None`` when no type comment is present; otherwise the type comment
        with surrounding whitespace stripped. For the multi-line form (one
        comment per parameter followed by a ``(...) -> ret`` line) the
        per-parameter types are joined with ``", "`` and substituted for the
        argument ellipsis of the return line.

    Raises:
        RuntimeError: If a comment looks like a misspelled type-comment
            prefix, or if several type comments exist but none of them is the
            ``(...) -> ...`` return line.
    """
    type_comment = '# type:'
    return_marker = '# type: (...) -> '

    numbered_lines = list(enumerate(source.split('\n')))

    # `type: ignore` comments may be needed in JIT'ed functions for mypy, due
    # to the hack in torch/_VF.py.

    # An ignore type comment can be of following format:
    #   1) type: ignore
    #   2) type: ignore[rule-code]
    # This ignore statement must be at the end of the line

    # adding an extra backslash before the space, to avoid triggering
    # one of the checks in .github/workflows/lint.yml
    type_pattern = re.compile("# type:\\ ignore(\\[[a-zA-Z-]+\\])?$")
    type_lines = [
        entry for entry in numbered_lines
        if type_comment in entry[1] and not type_pattern.search(entry[1])
    ]

    if not type_lines:
        # Catch common typo patterns like extra spaces, typo in 'ignore', etc.
        wrong_type_pattern = re.compile("#[\t ]*type[\t ]*(?!: ignore(\\[.*\\])?$):")
        wrong_type_lines = [entry for entry in numbered_lines if wrong_type_pattern.search(entry[1])]
        if wrong_type_lines:
            raise RuntimeError("The annotation prefix in line " + str(wrong_type_lines[0][0])
                               + " is probably invalid.\nIt must be '# type:'"
                               + "\nSee PEP 484 (https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)"  # noqa: B950
                               + "\nfor examples")
        return None
    if len(type_lines) == 1:
        # Only 1 type line, quit now
        return type_lines[0][1].strip()

    # Parse split up argument types according to PEP 484
    # https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code
    return_line = None
    parameter_type_lines = []
    for _, line in type_lines:
        if return_marker in line:
            return_line = line
            break
        # Every entry of `type_lines` contains the type-comment prefix, so any
        # line seen before the return line is a per-parameter annotation.
        parameter_type_lines.append(line)
    if return_line is None:
        raise RuntimeError(
            "Return type line '# type: (...) -> ...' not found on multiline "
            "type annotation\nfor type lines:\n" +
            '\n'.join([line[1] for line in type_lines]) +
            "\n(See PEP 484 https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)")

    def get_parameter_type(line):
        item_type = line[line.find(type_comment) + len(type_comment):]
        return item_type.strip()

    parameter_types = ", ".join(get_parameter_type(line) for line in parameter_type_lines)

    # Substitute only the argument placeholder: a blanket replace of "..."
    # would also rewrite an ellipsis that is part of the return type.
    combined = return_line.replace(return_marker, '# type: (' + parameter_types + ') -> ', 1)
    # Strip like the single-line branch does, so that the result starts with
    # the type-comment prefix regardless of how the return line was indented.
    return combined.strip()


def split_type_line(type_line):
    """Splits the comment with the type annotation into parts for argument and return types.

    For example, for an input of:
        # type: (Tensor, torch.Tensor) -> Tuple[Tensor, Tensor]

    This function will return:
        ("(Tensor, torch.Tensor)", "Tuple[Tensor, Tensor]")

    The ``# type:`` prefix is located inside the line rather than assumed to
    sit at column 0, so leading whitespace or other text before the comment
    does not leak into the argument part. If the prefix is absent, the first
    ``len('# type:')`` characters are skipped.

    Raises:
        RuntimeError: If the line contains no ``->`` after the prefix.
    """
    marker = '# type:'
    marker_pos = type_line.find(marker)
    # Fall back to the fixed offset when the prefix cannot be found.
    start_offset = len(marker) if marker_pos == -1 else marker_pos + len(marker)
    try:
        arrow_pos = type_line.index('->', start_offset)
    except ValueError:
        raise RuntimeError("Syntax error in type annotation (cound't find `->`)") from None
    return type_line[start_offset:arrow_pos].strip(), type_line[arrow_pos + 2:].strip()


def try_real_annotations(fn, loc):
    """Build ``(arg_types, return_type)`` from the Py3.5+ annotations of ``fn``.

    Returns ``None`` when ``inspect.signature`` raises ``ValueError`` for
    ``fn``, or when neither the return value nor any parameter is annotated.
    """
    try:
        # Note: anything annotated as `Optional[T]` will automatically
        # be returned as `Union[T, None]` per
        # https://github.com/python/typing/blob/master/src/typing.py#L850
        sig = inspect.signature(fn)
    except ValueError:
        return None

    params = list(sig.parameters.values())
    every_annotation = [sig.return_annotation, *(p.annotation for p in params)]
    if all(annotation is sig.empty for annotation in every_annotation):
        return None

    # Parameters are converted first, then the return annotation.
    arg_types = [ann_to_type(p.annotation, loc) for p in params]
    return arg_types, ann_to_type(sig.return_annotation, loc)


# Finds common type for enum values belonging to an Enum class. If not all
# values have the same type, AnyType is returned.
def get_enum_value_type(e: Type[enum.Enum], loc):
    """Return the unified TorchScript type of the values of enum class ``e``.

    Args:
        e: The Enum class whose member values are inspected.
        loc: Source location forwarded to ``try_ann_to_type``.

    Returns:
        The result of ``torch._C.unify_type_list`` over the distinct Python
        value types, or ``AnyType`` when that result is falsy.

    Raises:
        ValueError: If ``e`` defines no members.
    """
    enum_values: List[enum.Enum] = list(e)
    if not enum_values:
        # Report the enum itself; `e.__class__` is its metaclass and would
        # print the same text for every enum.
        raise ValueError(f"No enum values defined for: '{e}'")

    types = {type(v.value) for v in enum_values}
    ir_types = [try_ann_to_type(t, loc) for t in types]

    # If Enum values are of different types, an exception will be raised here.
    # Even though Python supports this case, we chose to not implement it to
    # avoid overcomplicate logic here for a rare use case. Please report a
    # feature request if you find it necessary.
    # NOTE(review): this comment and the one above the function disagree on
    # whether mixed value types raise or yield AnyType -- confirm against
    # the behaviour of torch._C.unify_type_list.
    res = torch._C.unify_type_list(ir_types)
    if not res:
        return AnyType.get()
    return res

def is_tensor(ann):
    """Return True if the class ``ann`` should be treated as a Tensor annotation.

    Subclasses of ``torch.Tensor`` qualify directly. The dtype-specific
    tensor classes listed below qualify too, but trigger a warning.
    """
    if issubclass(ann, torch.Tensor):
        return True

    dtype_specific = (
        torch.LongTensor, torch.DoubleTensor, torch.FloatTensor,
        torch.IntTensor, torch.ShortTensor, torch.HalfTensor,
        torch.CharTensor, torch.ByteTensor, torch.BoolTensor,
    )
    if not issubclass(ann, dtype_specific):
        return False

    warnings.warn("TorchScript will treat type annotations of Tensor "
                  "dtype-specific subtypes as if they are normal Tensors. "
                  "dtype constraints are not enforced in compilation either.")
    return True



def try_ann_to_type(ann, loc):
    """Map a Python type annotation to the corresponding TorchScript type.

    The checks run in a fixed order: missing annotation, ``None``, tensor
    classes, the generic forms (Tuple, List, Dict, Optional, Union, RRef,
    Future, Await), scalar builtins, ``Any``, interfaces, device / stream /
    dtype, enums and finally arbitrary classes. Anything left over is handed
    to ``torch._C._resolve_type_from_object``.

    Args:
        ann: The annotation object to convert.
        loc: Source location; used in error messages (``loc.highlight()``)
            and forwarded to the class compiler and the final resolver.

    Returns:
        The TorchScript type for ``ann``, or whatever
        ``torch._C._resolve_type_from_object`` returns when no rule matched.

    Raises:
        ValueError: If a ``Dict`` key or value annotation cannot be resolved.
        AssertionError: If the payload of an ``Optional`` or a member of a
            ``Union`` cannot be resolved.
    """
    if ann is inspect.Signature.empty:
        return TensorType.getInferred()
    if ann is None:
        return NoneType.get()
    if inspect.isclass(ann) and is_tensor(ann):
        return TensorType.get()
    if is_tuple(ann):
        # Special case for the empty Tuple type annotation `Tuple[()]`
        if len(ann.__args__) == 1 and ann.__args__[0] == ():
            return TupleType([])
        return TupleType([try_ann_to_type(a, loc) for a in ann.__args__])
    if is_list(ann):
        elem_type = try_ann_to_type(ann.__args__[0], loc)
        # An unresolved element type falls through to the checks below.
        if elem_type:
            return ListType(elem_type)
    if is_dict(ann):
        key = try_ann_to_type(ann.__args__[0], loc)
        value = try_ann_to_type(ann.__args__[1], loc)
        # Raise error if key or value is None
        if key is None:
            raise ValueError(f"Unknown type annotation: '{ann.__args__[0]}' at {loc.highlight()}")
        if value is None:
            raise ValueError(f"Unknown type annotation: '{ann.__args__[1]}' at {loc.highlight()}")
        return DictType(key, value)
    if is_optional(ann):
        # Pick whichever of the two arguments is not the None marker. An
        # identity test is used instead of issubclass(): the other argument
        # may be a typing construct such as List[int], which is not a class
        # and makes issubclass() raise TypeError when it comes second
        # (e.g. Union[None, List[int]]).
        second = ann.__args__[1]
        if second is None or second is type(None):
            contained = ann.__args__[0]
        else:
            contained = second
        valid_type = try_ann_to_type(contained, loc)
        msg = "Unsupported annotation {} could not be resolved because {} could not be resolved."
        assert valid_type, msg.format(repr(ann), repr(contained))
        return OptionalType(valid_type)
    if is_union(ann):
        # TODO: this is hack to recognize NumberType
        if set(ann.__args__) == {int, float, complex}:
            return NumberType.get()
        inner: List = []
        # We need these extra checks because both `None` and invalid
        # values will return `None`
        # TODO: Determine if the other cases need to be fixed as well
        for a in ann.__args__:
            if a is None:
                inner.append(NoneType.get())
                # Already handled; without this the member would be
                # converted and appended a second time below.
                continue
            maybe_type = try_ann_to_type(a, loc)
            msg = "Unsupported annotation {} could not be resolved because {} could not be resolved."
            # Name the offending member `a`; `maybe_type` is always falsy
            # when this assertion fires.
            assert maybe_type, msg.format(repr(ann), repr(a))
            inner.append(maybe_type)
        return UnionType(inner)    # type: ignore[arg-type]
    if torch.distributed.rpc.is_available() and is_rref(ann):
        return RRefType(try_ann_to_type(ann.__args__[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann.__args__[0], loc))
    if is_await(ann):
        elementType = try_ann_to_type(ann.__args__[0], loc) if hasattr(ann, "__args__") else AnyType.get()
        return AwaitType(elementType)
    if ann is float:
        return FloatType.get()
    if ann is complex:
        return ComplexType.get()
    if ann is int:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(ann.__torch_script_interface__)
    if ann is torch.device:
        return DeviceObjType.get()
    if ann is torch.Stream:
        return StreamObjType.get()
    if ann is torch.dtype:
        return IntType.get()  # dtype not yet bound in as its own type
    if inspect.isclass(ann) and issubclass(ann, enum.Enum):
        # Compile the enum on first use; otherwise reuse its qualified name.
        if _get_script_class(ann) is None:
            scripted_class = torch.jit._script._recursive_compile_class(ann, loc)
            name = scripted_class.qualified_name()
        else:
            name = _qualified_name(ann)
        return EnumType(name, get_enum_value_type(ann, loc), list(ann))
    if inspect.isclass(ann):
        maybe_script_class = _get_script_class(ann)
        if maybe_script_class is not None:
            return maybe_script_class
        if torch._jit_internal.can_compile_class(ann):
            return torch.jit._script._recursive_compile_class(ann, loc)

    # Maybe resolve a NamedTuple to a Tuple Type
    def fake_rcb(key):
        return None
    return torch._C._resolve_type_from_object(ann, loc, fake_rcb)


def ann_to_type(ann, loc):
    """Convert ``ann`` with ``try_ann_to_type``, raising if it cannot be resolved.

    Raises:
        ValueError: If ``try_ann_to_type`` returns ``None`` for ``ann``.
    """
    resolved = try_ann_to_type(ann, loc)
    if resolved is None:
        raise ValueError(f"Unknown type annotation: '{ann}' at {loc.highlight()}")
    return resolved


# Names exported by a wildcard import of this module: first the typing
# helpers and TorchScript type classes imported above, then `Module` and the
# annotation-parsing functions defined in this file.
__all__ = [
    'Any',
    'List',
    'BroadcastingList1',
    'BroadcastingList2',
    'BroadcastingList3',
    'Tuple',
    'is_tuple',
    'is_list',
    'Dict',
    'is_dict',
    'is_optional',
    'is_union',
    'TensorType',
    'TupleType',
    'FloatType',
    'ComplexType',
    'IntType',
    'ListType',
    'StringType',
    'DictType',
    'AnyType',
    'Module',
    # TODO: Consider not exporting these during wildcard import (reserve
    # that for the types; for idiomatic typing code.)
    'get_signature',
    'check_fn',
    'get_param_names',
    'parse_type_line',
    'get_type_line',
    'split_type_line',
    'try_real_annotations',
    'try_ann_to_type',
    'ann_to_type',
]