Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
pytype / tracer_vm.py
Size: Mime:
"""Code for checking and inferring types."""

import collections
import logging
import re
from typing import Any, Dict, Union

from pytype import special_builtins
from pytype import state as frame_state
from pytype import vm
from pytype.abstract import abstract
from pytype.abstract import abstract_utils
from pytype.abstract import function
from pytype.overlays import typing_overlay
from pytype.pytd import escape
from pytype.pytd import optimize
from pytype.pytd import pytd
from pytype.pytd import pytd_utils
from pytype.pytd import visitors
from pytype.typegraph import cfg

log = logging.getLogger(__name__)

# Most interpreter functions (including lambdas) need to be analyzed as
# stand-alone functions. The exceptions are comprehensions and generators, which
# have names like "<listcomp>" and "<genexpr>".
_SKIP_FUNCTION_RE = re.compile("<(?!lambda).+>$")


_CallRecord = collections.namedtuple("_CallRecord", [
    "node", "function", "signatures", "positional_arguments",
    "keyword_arguments", "return_value"
])


class _Initializing:
  pass


class CallTracer(vm.VirtualMachine):
  """Virtual machine that records all function calls."""

  _CONSTRUCTORS = ("__new__", "__init__")

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self._unknowns = {}
    self._calls = set()
    self._method_calls = set()
    # Used by init_class.
    self._instance_cache: Dict[Any, Union[_Initializing, cfg.Variable]] = {}
    # Used by call_init. Can differ from _instance_cache because we also call
    # __init__ on classes not initialized via init_class.
    self._initialized_instances = set()
    self._interpreter_functions = []
    self._interpreter_classes = []
    self._analyzed_functions = set()
    self._analyzed_classes = set()
    self._generated_classes = {}

  def create_varargs(self, node):
    value = abstract.Instance(self.ctx.convert.tuple_type, self.ctx)
    value.merge_instance_type_parameter(
        node, abstract_utils.T, self.ctx.convert.create_new_unknown(node))
    return value.to_variable(node)

  def create_kwargs(self, node):
    key_type = self.ctx.convert.primitive_class_instances[str].to_variable(node)
    value_type = self.ctx.convert.create_new_unknown(node)
    kwargs = abstract.Instance(self.ctx.convert.dict_type, self.ctx)
    kwargs.merge_instance_type_parameter(node, abstract_utils.K, key_type)
    kwargs.merge_instance_type_parameter(node, abstract_utils.V, value_type)
    return kwargs.to_variable(node)

  def create_method_arguments(self, node, method, use_defaults=False):
    """Create arguments for the given method.

    Creates Unknown objects as arguments for the given method. Note that we
    don't need to take parameter annotations into account as
    InterpreterFunction.call() will take care of that.

    Args:
      node: The current node.
      method: An abstract.InterpreterFunction.
      use_defaults: Whether to use parameter defaults for arguments. When True,
        unknown arguments are created with force=False, as it is fine to use
        Unsolvable rather than Unknown objects for type-checking defaults.

    Returns:
      A tuple of a node and a function.Args object.
    """
    args = []
    num_posargs = method.argcount(node)
    num_posargs_no_default = num_posargs - len(method.defaults)
    for i in range(num_posargs):
      default_idx = i - num_posargs_no_default
      if use_defaults and default_idx >= 0:
        arg = method.defaults[default_idx]
      else:
        arg = self.ctx.convert.create_new_unknown(node, force=not use_defaults)
      args.append(arg)
    kws = {}
    for key in method.signature.kwonly_params:
      if use_defaults and key in method.kw_defaults:
        kws[key] = method.kw_defaults[key]
      else:
        kws[key] = self.ctx.convert.create_new_unknown(
            node, force=not use_defaults)
    starargs = self.create_varargs(node) if method.has_varargs() else None
    starstarargs = self.create_kwargs(node) if method.has_kwargs() else None
    return node, function.Args(posargs=tuple(args),
                               namedargs=kws,
                               starargs=starargs,
                               starstarargs=starstarargs)

  def call_function_with_args(self, node, val, args):
    """Call a function.

    Args:
      node: The given node.
      val: A cfg.Binding containing the function.
      args: A function.Args object.

    Returns:
      A tuple of (1) a node and (2) a cfg.Variable of the return value.
    """
    assert isinstance(val.data, abstract.INTERPRETER_FUNCTION_TYPES)
    with val.data.record_calls():
      new_node, ret = self._call_function_in_frame(node, val, *args)
    return new_node, ret

  def _call_function_in_frame(self, node, val, args, kwargs,
                              starargs, starstarargs):
    # Try to get the function opcode with position information to construct the
    # frame, so that we have the data for any error messages raised before we
    # get to func.call().
    fn = val.data
    if isinstance(fn, abstract.BoundInterpreterFunction):
      # Note that using fn.underlying.def_opcode here would lead to over-caching
      # in Class._new_instance().
      opcode = None
      f_globals = fn.underlying.f_globals
    else:
      assert isinstance(fn, abstract.InterpreterFunction)
      opcode = fn.def_opcode
      f_globals = fn.f_globals
    frame = frame_state.SimpleFrame(
        node=node, opcode=opcode, f_globals=f_globals)
    # We only want this frame to show up in errors if it's at the top of the
    # stack (i.e. we haven't started analysing the actual function yet)
    frame.skip_in_tracebacks = True
    self.push_frame(frame)
    log.info("Analyzing %r", fn.name)
    state = frame_state.FrameState.init(node, self.ctx)
    state, ret = self.call_function_with_state(
        state, val.AssignToNewVariable(node), args, kwargs, starargs,
        starstarargs)
    self.pop_frame(frame)
    return state.node, ret

  def _maybe_fix_classmethod_cls_arg(self, node, cls, func, args):
    sig = func.signature
    if (args.posargs and sig.param_names and
        (sig.param_names[0] not in sig.annotations)):
      # fix "cls" parameter
      return args._replace(
          posargs=(cls.AssignToNewVariable(node),) + args.posargs[1:])
    else:
      return args

  def maybe_analyze_method(self, node, val, cls=None):
    method = val.data
    fname = val.data.name
    if isinstance(method, abstract.INTERPRETER_FUNCTION_TYPES):
      self._analyzed_functions.add(method.get_first_opcode())
      if (not self.ctx.options.analyze_annotated and
          (method.signature.has_return_annotation or method.has_overloads) and
          fname.rsplit(".", 1)[-1] not in self._CONSTRUCTORS):
        log.info("%r has annotations, not analyzing further.", fname)
      else:
        for f in method.iter_signature_functions():
          node, args = self.create_method_arguments(node, f)
          if f.is_classmethod and cls:
            args = self._maybe_fix_classmethod_cls_arg(node, cls, f, args)
          node, _ = self.call_function_with_args(node, val, args)
    return node

  def call_with_fake_args(self, node0, funcv):
    """Attempt to call the given function with made-up arguments."""
    # Note that this should only be used for functions that raised a
    # FailedFunctionCall error. This is not guaranteed to successfully call a
    # function that raised DictKeyMissing instead.
    nodes = []
    rets = []
    for funcb in funcv.bindings:
      func = funcb.data
      log.info("Trying %s with fake arguments", func)

      if isinstance(func, abstract.INTERPRETER_FUNCTION_TYPES):
        node1, args = self.create_method_arguments(node0, func)
        # Once the args are generated, try calling the function.
        # call_function will check fallback_to_unsolvable if a DictKeyMissing or
        # FailedFunctionCall error is raised when the target function is called.
        # DictKeyMissing doesn't trigger call_with_fake_args, so that shouldn't
        # be raised again, and generating fake arguments should avoid any
        # FailedFunctionCall errors. To prevent an infinite recursion loop, set
        # fallback_to_unsolvable to False just in case.
        # This means any additional errors that may be raised will be passed to
        # the call_function that called this method in the first place.
        node2, ret = function.call_function(
            self.ctx,
            node1,
            funcb.AssignToNewVariable(),
            args,
            fallback_to_unsolvable=False)
        nodes.append(node2)
        rets.append(ret)

    if nodes:
      ret = self.ctx.join_variables(node0, rets)
      node = self.ctx.join_cfg_nodes(nodes)
      if ret.bindings:
        return node, ret
    else:
      node = node0
    log.info("Unable to generate fake arguments for %s", funcv)
    return node, self.ctx.new_unsolvable(node)

  def analyze_method_var(self, node0, name, var, cls=None):
    log.info("Analyzing %s", name)
    node1 = node0.ConnectNew(name)
    for val in var.bindings:
      node2 = self.maybe_analyze_method(node1, val, cls)
      node2.ConnectTo(node0)
    return node0

  def bind_method(self, node, methodvar, instance_var):
    bound = self.ctx.program.NewVariable()
    for m in methodvar.Data(node):
      if isinstance(m, special_builtins.ClassMethodInstance):
        m = m.func.data[0]
        is_cls = True
      else:
        is_cls = (isinstance(m, abstract.InterpreterFunction) and
                  m.is_classmethod)
      bound.AddBinding(m.property_get(instance_var, is_cls), [], node)
    return bound

  def _instantiate_binding(self, node0, cls, container):
    """Instantiate a class binding."""
    node1, new = cls.data.get_own_new(node0, cls)
    if not new or (
        any(not isinstance(f, abstract.InterpreterFunction) for f in new.data)):
      # This assumes that any inherited __new__ method defined in a pyi file
      # returns an instance of the current class.
      return node0, cls.data.instantiate(node0, container=container)
    instance = self.ctx.program.NewVariable()
    nodes = []
    for b in new.bindings:
      self._analyzed_functions.add(b.data.get_first_opcode())
      node2, args = self.create_method_arguments(node1, b.data)
      args = self._maybe_fix_classmethod_cls_arg(node0, cls, b.data, args)
      node3 = node2.ConnectNew()
      node4, ret = self.call_function_with_args(node3, b, args)
      instance.PasteVariable(ret)
      nodes.append(node4)
    return self.ctx.join_cfg_nodes(nodes), instance

  def _instantiate_var(self, node, clsv, container):
    """Build an (dummy) instance from a class, for analyzing it."""
    n = self.ctx.program.NewVariable()
    for cls in clsv.Bindings(node, strict=False):
      node, var = self._instantiate_binding(node, cls, container)
      n.PasteVariable(var)
    return node, n

  def _mark_maybe_missing_members(self, values):
    """Set maybe_missing_members to True on these values and their type params.

    Args:
      values: A list of BaseValue objects. On every instance among the values,
        recursively set maybe_missing_members to True on the instance and its
        type parameters.
    """
    values = list(values)
    seen = set()
    while values:
      v = values.pop(0)
      if v not in seen:
        seen.add(v)
        if isinstance(v, abstract.SimpleValue):
          v.maybe_missing_members = True
          for child in v.instance_type_parameters.values():
            values.extend(child.data)

  def init_class(self, node, cls, container=None, extra_key=None):
    """Instantiate a class, and also call __init__.

    Calling __init__ can be expensive, so this method caches its created
    instances. If you don't need __init__ called, use cls.instantiate instead.

    Args:
      node: The current node.
      cls: The class to instantiate.
      container: Optionally, a container to pass to the class's instantiate()
        method, so that type parameters in the container's template are
        instantiated to TypeParameterInstance.
      extra_key: Optionally, extra information about the location at which the
        instantion occurs. By default, this method keys on the current opcode
        and the class, which sometimes isn't enough to disambiguate callers that
        shouldn't get back the same cached instance.

    Returns:
      A tuple of node and instance variable.
    """
    key = (self.frame and self.frame.current_opcode, extra_key, cls)
    instance = self._instance_cache.get(key)
    if not instance or isinstance(instance, _Initializing):
      clsvar = cls.to_variable(node)
      node, instance = self._instantiate_var(node, clsvar, container)
      if key in self._instance_cache:
        # We've encountered a recursive pattern such as
        # class A:
        #   def __init__(self, x: "A"): ...
        # Calling __init__ again would lead to an infinite loop, so
        # we instead create an incomplete instance that will be
        # overwritten later. Note that we have to create a new
        # instance rather than using the one that we're already in
        # the process of initializing - otherwise, setting
        # maybe_missing_members to True would cause pytype to ignore
        # all attribute errors on self in __init__.
        self._mark_maybe_missing_members(instance.data)
      else:
        self._instance_cache[key] = _Initializing()
        node = self.call_init(node, instance)
      self._instance_cache[key] = instance
    return node, instance

  def _call_method(self, node, binding, method_name):
    node, method = self.ctx.attribute_handler.get_attribute(
        node, binding.data.cls, method_name, binding)
    if method:
      bound_method = self.bind_method(
          node, method, binding.AssignToNewVariable())
      node = self.analyze_method_var(node, method_name, bound_method)
    return node

  def _call_init_on_binding(self, node, b):
    if isinstance(b.data, abstract.SimpleValue):
      for param in b.data.instance_type_parameters.values():
        node = self.call_init(node, param)
    node = self._call_method(node, b, "__init__")
    cls = b.data.cls
    if isinstance(cls, abstract.InterpreterClass):
      # Call any additional initalizers the class has registered.
      for method in cls.additional_init_methods:
        node = self._call_method(node, b, method)
    return node

  def call_init(self, node, instance):
    # Call __init__ on each binding.
    for b in instance.bindings:
      if b.data in self._initialized_instances:
        continue
      self._initialized_instances.add(b.data)
      node = self._call_init_on_binding(node, b)
    return node

  def reinitialize_if_initialized(self, node, instance):
    if instance in self._initialized_instances:
      self._call_init_on_binding(node, instance.to_binding(node))

  def analyze_class(self, node, val):
    self._analyzed_classes.add(val.data)
    node, instance = self.init_class(node, val.data)
    good_instances = [b for b in instance.bindings if val.data == b.data.cls]
    if not good_instances:
      # __new__ returned something that's not an instance of our class.
      instance = val.data.instantiate(node)
      node = self.call_init(node, instance)
    elif len(good_instances) != len(instance.bindings):
      # __new__ returned some extra possibilities we don't need.
      instance = self.ctx.join_bindings(node, good_instances)
    for instance_value in instance.data:
      val.data.register_canonical_instance(instance_value)
    methods = sorted(val.data.members.items())
    while methods:
      name, methodvar = methods.pop(0)
      if name in self._CONSTRUCTORS:
        continue  # We already called this method during initialization.
      for v in methodvar.data:
        if isinstance(v, special_builtins.PropertyInstance):
          for m in (v.fget, v.fset, v.fdel):
            if m:
              methods.insert(0, (name, m))
      b = self.bind_method(node, methodvar, instance)
      node = self.analyze_method_var(node, name, b, val)
    return node

  def analyze_function(self, node0, val):
    if val.data.is_attribute_of_class:
      # We'll analyze this function as part of a class.
      log.info("Analyze functions: Skipping class method %s", val.data.name)
    else:
      node1 = node0.ConnectNew(val.data.name)
      node2 = self.maybe_analyze_method(node1, val)
      node2.ConnectTo(node0)
    return node0

  def _should_analyze_as_interpreter_function(self, data):
    # We record analyzed functions by opcode rather than function object. The
    # two ways of recording are equivalent except for closures, which are
    # re-generated when the variables they close over change, but we don't want
    # to re-analyze them.
    return (isinstance(data, abstract.InterpreterFunction) and
            not data.is_overload and
            not data.is_class_builder and
            data.get_first_opcode() not in self._analyzed_functions and
            not _SKIP_FUNCTION_RE.search(data.name))

  def analyze_toplevel(self, node, defs):
    for name, var in sorted(defs.items()):  # sort, for determinicity
      if not self._is_typing_member(name, var):
        for value in var.bindings:
          if isinstance(value.data, abstract.InterpreterClass):
            new_node = self.analyze_class(node, value)
          elif (isinstance(value.data, abstract.INTERPRETER_FUNCTION_TYPES) and
                not value.data.is_overload):
            new_node = self.analyze_function(node, value)
          else:
            continue
          if new_node is not node:
            new_node.ConnectTo(node)
    # Now go through all functions and classes we haven't analyzed yet.
    # These are typically hidden under a decorator.
    # Go through classes first so that the `is_attribute_of_class` will
    # be set for all functions in class.
    for c in self._interpreter_classes:
      for value in c.bindings:
        if (isinstance(value.data, abstract.InterpreterClass) and
            value.data not in self._analyzed_classes):
          node = self.analyze_class(node, value)
    for f in self._interpreter_functions:
      for value in f.bindings:
        if self._should_analyze_as_interpreter_function(value.data):
          node = self.analyze_function(node, value)
    for func, opcode in self.functions_type_params_check:
      func.signature.check_type_parameter_count(self.simple_stack(opcode))
    return node

  def analyze(self, node, defs, maximum_depth):
    assert not self.frame
    self._maximum_depth = maximum_depth
    self._analyzing = True
    node = node.ConnectNew(name="Analyze")
    return self.analyze_toplevel(node, defs)

  def trace_unknown(self, name, unknown_binding):
    self._unknowns[name] = unknown_binding

  def trace_call(self, node, func, sigs, posargs, namedargs, result):
    """Add an entry into the call trace.

    Args:
      node: The CFG node right after this function call.
      func: A cfg.Binding of a function that was called.
      sigs: The signatures that the function might have been called with.
      posargs: The positional arguments, an iterable over cfg.Value.
      namedargs: The keyword arguments, a dict mapping str to cfg.Value.
      result: A Variable of the possible result values.
    """
    log.debug("Logging call to %r with %d args, return %r",
              func, len(posargs), result)
    args = tuple(posargs)
    kwargs = tuple((namedargs or {}).items())
    record = _CallRecord(node, func, sigs, args, kwargs, result)
    if isinstance(func.data, abstract.BoundPyTDFunction):
      self._method_calls.add(record)
    elif isinstance(func.data, abstract.PyTDFunction):
      self._calls.add(record)

  def trace_functiondef(self, f):
    self._interpreter_functions.append(f)

  def trace_classdef(self, c):
    self._interpreter_classes.append(c)

  def trace_namedtuple(self, nt):
    # All namedtuple instances with the same name are equal, so it's fine to
    # overwrite previous instances.
    self._generated_classes[nt.name] = nt

  def pytd_classes_for_unknowns(self):
    classes = []
    for name, val in self._unknowns.items():
      if val in val.variable.Filter(self.ctx.exitpoint, strict=False):
        classes.append(val.data.to_structural_def(self.ctx.exitpoint, name))
    return classes

  def pytd_for_types(self, defs):
    # If a variable is annotated, we'll always output that type.
    annotated_names = set()
    data = []
    annots = abstract_utils.get_annotations_dict(defs)
    for name, t in self.ctx.pytd_convert.annotations_to_instance_types(
        self.ctx.exitpoint, annots):
      annotated_names.add(name)
      data.append(pytd.Constant(name, t))
    for name, var in defs.items():
      if (name in abstract_utils.TOP_LEVEL_IGNORE or name in annotated_names or
          self._is_typing_member(name, var)):
        continue
      options = var.FilteredData(self.ctx.exitpoint, strict=False)
      if (len(options) > 1 and
          not all(isinstance(o, abstract.FUNCTION_TYPES) for o in options)):
        if all(isinstance(o, abstract.TypeParameter) for o in options):
          pytd_def = pytd_utils.JoinTypes(
              t.to_pytd_def(self.ctx.exitpoint, name) for t in options)
          if isinstance(pytd_def, pytd.TypeParameter):
            data.append(pytd_def)
          else:
            # We have multiple definitions for the same TypeVar name. There's no
            # good way to handle this.
            data.append(pytd.Constant(name, pytd.AnythingType()))
        elif all(isinstance(o, (abstract.ParameterizedClass, abstract.Union))
                 for o in options):  # type alias
          pytd_def = pytd_utils.JoinTypes(
              t.to_pytd_def(self.ctx.exitpoint, name).type for t in options)
          data.append(pytd.Alias(name, pytd_def))
        else:
          # It's ambiguous whether this is a type, a function or something
          # else, so encode it as a constant.
          combined_types = pytd_utils.JoinTypes(
              t.to_type(self.ctx.exitpoint) for t in options)
          data.append(pytd.Constant(name, combined_types))
      elif options:
        for option in options:
          try:
            d = option.to_pytd_def(self.ctx.exitpoint, name)  # Deep definition
          except NotImplementedError:
            d = option.to_type(self.ctx.exitpoint)  # Type only
            if isinstance(d, pytd.NothingType):
              if isinstance(option, abstract.Empty):
                d = pytd.AnythingType()
              else:
                assert isinstance(option, typing_overlay.NoReturn)
          if isinstance(d, pytd.Type) and not isinstance(d, pytd.TypeParameter):
            data.append(pytd.Constant(name, d))
          else:
            data.append(d)
      else:
        log.error("No visible options for %s", name)
        data.append(pytd.Constant(name, pytd.AnythingType()))
    return pytd_utils.WrapTypeDeclUnit("inferred", data)

  @staticmethod
  def _call_traces_to_function(call_traces, name_transform=lambda x: x):
    funcs = collections.defaultdict(pytd_utils.OrderedSet)
    for node, func, sigs, args, kws, retvar in call_traces:
      # The lengths may be different in the presence of optional and kw args.
      arg_names = max((sig.get_positional_names() for sig in sigs), key=len)
      for i in range(len(arg_names)):
        if not isinstance(func.data, abstract.BoundFunction) or i > 0:
          arg_names[i] = function.argname(i)
      arg_types = (a.data.to_type(node) for a in args)
      ret = pytd_utils.JoinTypes(t.to_type(node) for t in retvar.data)
      starargs = None
      starstarargs = None
      funcs[func.data.name].add(pytd.Signature(
          tuple(pytd.Parameter(n, t, False, False, None)
                for n, t in zip(arg_names, arg_types)) +
          tuple(pytd.Parameter(name, a.data.to_type(node), False, False, None)
                for name, a in kws),
          starargs, starstarargs,
          ret, exceptions=(), template=()))
    functions = []
    for name, signatures in funcs.items():
      functions.append(pytd.Function(name_transform(name), tuple(signatures),
                                     pytd.MethodTypes.METHOD))
    return functions

  def _is_typing_member(self, name, var):
    for module_name in ("typing", "typing_extensions"):
      if module_name not in self.loaded_overlays:
        continue
      overlay = self.loaded_overlays[module_name]
      if overlay:
        module = overlay.get_module(name)
        if name in module.members and module.members[name].data == var.data:
          return True
    return False

  def pytd_functions_for_call_traces(self):
    return self._call_traces_to_function(self._calls, escape.pack_partial)

  def pytd_classes_for_call_traces(self):
    class_to_records = collections.defaultdict(list)
    for call_record in self._method_calls:
      args = call_record.positional_arguments
      if not any(isinstance(a.data, abstract.Unknown) for a in args):
        # We don't need to record call signatures that don't involve
        # unknowns - there's nothing to solve for.
        continue
      cls = args[0].data.cls
      if isinstance(cls, abstract.PyTDClass):
        class_to_records[cls].append(call_record)
    classes = []
    for cls, call_records in class_to_records.items():
      full_name = cls.module + "." + cls.name if cls.module else cls.name
      classes.append(pytd.Class(
          name=escape.pack_partial(full_name),
          metaclass=None,
          bases=(pytd.NamedType("builtins.object"),),  # not used in solver
          methods=tuple(self._call_traces_to_function(call_records)),
          constants=(),
          classes=(),
          decorators=(),
          slots=None,
          template=(),
      ))
    return classes

  def pytd_classes_for_namedtuple_instances(self):
    return tuple(v.generate_ast() for v in self._generated_classes.values())

  def compute_types(self, defs):
    classes = (tuple(self.pytd_classes_for_unknowns()) +
               tuple(self.pytd_classes_for_call_traces()) +
               self.pytd_classes_for_namedtuple_instances())
    functions = tuple(self.pytd_functions_for_call_traces())
    aliases = ()  # aliases are instead recorded as constants
    ty = pytd_utils.Concat(
        self.pytd_for_types(defs),
        pytd_utils.CreateModule("unknowns", classes=classes,
                                functions=functions, aliases=aliases))
    ty = ty.Visit(optimize.CombineReturnsAndExceptions())
    ty = ty.Visit(optimize.PullInMethodClasses())
    ty = ty.Visit(
        visitors.DefaceUnresolved([ty, self.ctx.loader.concat_all()],
                                  escape.UNKNOWN))
    return ty.Visit(visitors.AdjustTypeParameters())

  def _check_return(self, node, actual, formal):
    if not self.ctx.options.report_errors:
      return True
    bad = self.ctx.matcher(node).bad_matches(actual, formal)
    if bad:
      self.ctx.errorlog.bad_return_type(self.frames, node, formal, actual, bad)
    return not bad