Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
pytype / tools / traces / traces.py
Size: Mime:
"""A library for accessing pytype's inferred local types."""

import itertools
import re

from pytype import analyze
from pytype import config
from pytype import context
from pytype import errors
from pytype import load_pytd
from pytype.ast import visitor
from pytype.pytd import pytd
from pytype.pytd import pytd_utils
from pytype.pytd import visitors

from pytype.tools.traces import source

# Opcodes that touch an attribute; matched by MatchAstVisitor.match_Attribute.
_ATTR_OPS = frozenset({"LOAD_ATTR", "LOAD_METHOD", "STORE_ATTR"})

# Opcodes that perform a call; matched by MatchAstVisitor.match_Call.
_CALL_OPS = frozenset({
    "CALL_FUNCTION",
    "CALL_FUNCTION_EX",
    "CALL_FUNCTION_KW",
    "CALL_FUNCTION_VAR",
    "CALL_FUNCTION_VAR_KW",
    "CALL_METHOD",
})

# Opcodes that read a bare name; used by MatchAstVisitor.match_Name for
# names in a Load context.
_LOAD_OPS = frozenset({"LOAD_DEREF", "LOAD_FAST", "LOAD_GLOBAL", "LOAD_NAME"})

# Symbols accepted for subscript-load traces, and the opcodes those traces
# may carry; both are used by MatchAstVisitor.match_Subscript.
_LOAD_SUBSCR_METHODS = ("__getitem__", "__getslice__")
_LOAD_SUBSCR_OPS = frozenset({
    "BINARY_SUBSCR",
    "SLICE_0",
    "SLICE_1",
    "SLICE_2",
    "SLICE_3",
})

# Opcodes that bind a bare name; used by MatchAstVisitor.match_Name for
# names in a Store context.
_STORE_OPS = frozenset(
    {"STORE_DEREF", "STORE_FAST", "STORE_GLOBAL", "STORE_NAME"})


class TypeTrace(source.AbstractTrace):
  """A source trace whose data carries pytype's inferred types.

  trace() builds each raw trace's data as a tuple of resolved pytd types and
  passes this class to source.Code to wrap them.
  """


def trace(src, options=None):
  """Runs pytype inference over `src` and collects per-opcode type traces.

  Args:
    src: The source text to analyze.
    options: A pytype.config.Options object that can be used to specify options
      such as the target Python version. If omitted (or falsy), default options
      from config.Options.create() are used.

  Returns:
    A source.Code object wrapping `src` and the collected traces, using
    TypeTrace as the trace class and options.input as the filename.
  """
  error_log = errors.ErrorLog()
  options = options or config.Options.create()
  with config.verbosity_from(options):
    pytd_loader = load_pytd.create_loader(options)
    vm_ctx = context.Context(
        errorlog=error_log,
        options=options,
        generate_unknowns=options.protocols,
        loader=pytd_loader)
    inferred_ast, _ = analyze.infer_types(
        src=src,
        filename=options.input,
        errorlog=error_log,
        options=options,
        loader=pytd_loader,
        ctx=vm_ctx)
    # Each opcode trace is (op, symbol, data); convert every datum in `data`
    # to a pytd type resolved against the inferred module.
    collected = [
        (op, symbol,
         tuple(_to_pytd(datum, pytd_loader, inferred_ast) for datum in data))
        for op, symbol, data in vm_ctx.vm.opcode_traces
    ]
  return source.Code(src, collected, TypeTrace, options.input)


def _to_pytd(datum, loader, ast):
  """Converts one trace datum into a resolved pytd type.

  Args:
    datum: A possibly-empty iterable of values, each providing to_type().
    loader: The loader whose resolve_pytd() is used to resolve the result.
    ast: The pytd module to resolve the joined type against.

  Returns:
    pytd.AnythingType() if `datum` is empty/falsy; otherwise the join of the
    values' types, with unknown classes removed, resolved via the loader.
  """
  if datum:
    joined = pytd_utils.JoinTypes(value.to_type() for value in datum)
    without_unknowns = joined.Visit(visitors.RemoveUnknownClasses())
    return loader.resolve_pytd(without_unknowns, ast)
  return pytd.AnythingType()


class _SymbolMatcher:
  """Symbol predicate used by MatchAstVisitor._get_traces.

  Holds a sequence of candidates, each of which is either:
    - a compiled regular expression, tested with re.match against str(symbol)
    - any other object, tested for equality with the symbol
  A tuple of such candidates matches when any member does; candidates are
  tried in order and testing stops at the first one that accepts.
  """

  def __init__(self, matches):
    self._matches = matches

  @classmethod
  def from_one_match(cls, match):
    return cls((match,))

  @classmethod
  def from_tuple(cls, matches):
    return cls(matches)

  @classmethod
  def from_regex(cls, regex):
    return cls((re.compile(regex),))

  def match(self, symbol):
    def accepts(candidate):
      # Patterns test the symbol's string form; anything else uses ==.
      if isinstance(candidate, re.Pattern):
        return candidate.match(str(symbol))
      return candidate == symbol

    return any(accepts(candidate) for candidate in self._matches)


class MatchAstVisitor(visitor.BaseVisitor):
  """An AST visitor to match traces to nodes.

  match(node) dispatches to a match_<NodeClassName> method, which returns a
  list of (source.Location, trace) pairs for that node. Traces are selected by
  line number, opcode name and symbol, and each trace is handed out at most
  once while a Module is being visited.

  Attributes:
    source: The source and trace information.
  """

  def __init__(self, src_code, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.source = src_code
    # Needed for x[i] = <multiline statement>: the value being subscripted in
    # the current assignment's first target, or None.
    self._assign_subscr = None
    # For tracking already matched traces: a set of trace ids while a Module
    # is being visited, None otherwise.
    self._matched = None

  def enter_Assign(self, node):
    if isinstance(node.targets[0], self._ast.Subscript):
      self._assign_subscr = node.targets[0].value

  def leave_Assign(self, _):
    self._assign_subscr = None

  def enter_Module(self, _):
    self._matched = set()

  def leave_Module(self, _):
    self._matched = None

  def match(self, node):
    """Gets the traces for the given node, along with their locations.

    Args:
      node: An AST node.

    Returns:
      A list of (source.Location, trace) pairs.

    Raises:
      NotImplementedError: If this class has no match_<NodeClassName> method
        for the node's class (or that method raises it itself).
    """
    method = "match_" + node.__class__.__name__
    try:
      match = getattr(self, method)
    except AttributeError as e:
      raise NotImplementedError(method) from e
    return match(node)

  def match_Attribute(self, node):
    return [(self._get_match_location(node, tr.symbol), tr)
            for tr in self._get_traces(node.lineno, _ATTR_OPS, node.attr, 1)]

  def match_BinOp(self, node):
    # Only the modulo operator (e.g. "%"-style string formatting) is supported.
    if not isinstance(node.op, self._ast.Mod):
      raise NotImplementedError("match_Binop:%s" % node.op.__class__.__name__)
    op = "BINARY_MODULO"
    symbol = "__mod__"
    # The node's lineno is the first line of the operation, but the opcode's
    # lineno is the last line, so we look ahead to try to find the last line.
    # We do a long lookahead in order to support formatting of long strings.
    return [(self._get_match_location(node), tr) for tr in self._get_traces(
        node.lineno, [op], symbol, maxmatch=1, num_lines=10)]

  def match_Bytes(self, node):
    return self._match_constant(node, node.s)

  def match_Call(self, node):
    # When calling a method of a class, the node name is <value>.<method>, but
    # only the method name is traced.
    name = self._get_node_name(node).rpartition(".")[-1]
    # The node's lineno is the first line of the call, but the opcode's lineno
    # is the last line, so we look ahead to try to find the last line.
    return [(self._get_match_location(node), tr)
            for tr in self._get_traces(
                node.lineno, _CALL_OPS, name, maxmatch=1, num_lines=5)]

  def match_Constant(self, node):
    # As of Python 3.8, bools, numbers, bytes, strings, ellipsis etc are
    # all constants instead of individual ast nodes. Read `value` directly:
    # the `s`/`n` attributes of ast.Constant are deprecated aliases of it
    # (warning on 3.12+, removed in newer Pythons).
    return self._match_constant(node, node.value)

  def match_Ellipsis(self, node):
    return self._match_constant(node, Ellipsis)

  def match_FunctionDef(self, node):
    symbol = _SymbolMatcher.from_regex(r"(%s|None)" % node.name)
    return [
        (self._get_match_location(node, tr.symbol), tr)
        for tr in self._get_traces(node.lineno, ["MAKE_FUNCTION"], symbol, 1)
    ]

  def match_Import(self, node):
    return list(self._match_import(node, is_from=False))

  def match_ImportFrom(self, node):
    return list(self._match_import(node, is_from=True))

  def match_Lambda(self, node):
    sym = _SymbolMatcher.from_regex(r".*<lambda>$")
    return [(self._get_match_location(node), tr)
            for tr in self._get_traces(node.lineno, ["MAKE_FUNCTION"], sym, 1)]

  def match_Name(self, node):
    if isinstance(node.ctx, self._ast.Load):
      ops = _LOAD_OPS
    elif isinstance(node.ctx, self._ast.Store):
      ops = _STORE_OPS
    else:
      # Other contexts (e.g. Del) have no opcode set to match against.
      return []
    return [(self._get_match_location(node), tr)
            for tr in self._get_traces(node.lineno, ops, node.id, 1)]

  def match_NameConstant(self, node):
    return self._match_constant(node, node.value)

  def match_Num(self, node):
    return self._match_constant(node, node.n)

  def match_Str(self, node):
    return self._match_constant(node, node.s)

  def match_Subscript(self, node):
    return [(self._get_match_location(node), tr) for tr in self._get_traces(
        node.lineno, _LOAD_SUBSCR_OPS,
        _SymbolMatcher.from_tuple(_LOAD_SUBSCR_METHODS), 1)]

  def _get_traces(self, lineno, ops, symbol, maxmatch=-1, num_lines=1):
    """Yields matching traces.

    Each yielded trace's id is recorded in self._matched so that it is never
    yielded again; this requires self._matched to be a set, i.e. the call must
    happen while a Module is being visited.

    Args:
      lineno: A starting line number.
      ops: A list of opcode names to match on.
      symbol: A symbol or _SymbolMatcher instance to match on.
      maxmatch: The maximum number of traces to yield. -1 for no maximum.
      num_lines: The number of consecutive lines to search.

    Yields:
      Unmatched traces on lines [lineno, lineno + num_lines) whose op is in
      `ops` and whose symbol matches `symbol`, in source order.
    """
    if not isinstance(symbol, _SymbolMatcher):
      symbol = _SymbolMatcher.from_one_match(symbol)
    for tr in itertools.chain.from_iterable(
        self.source.traces[line] for line in range(lineno, lineno + num_lines)):
      if maxmatch == 0:
        break
      if (id(tr) not in self._matched and tr.op in ops and
          symbol.match(tr.symbol)):
        maxmatch -= 1
        self._matched.add(id(tr))
        yield tr

  def _get_match_location(self, node, name=None):
    """Returns the source.Location to report for `node`.

    Args:
      node: An AST node.
      name: An optional symbol name used to refine the location.

    Returns:
      Without `name`, the node's own (lineno, col_offset). For Import and
      ImportFrom nodes, the column of `name` on the node's first line when it
      can be found there. For Attribute nodes, the location returned by
      source.get_attr_location. Otherwise the node's own location.
    """
    loc = source.Location(node.lineno, node.col_offset)
    if not name:
      return loc
    if isinstance(node, (self._ast.Import, self._ast.ImportFrom)):
      # Search for imported module names. The name is escaped so that dotted
      # names (e.g. "os.path") match literally instead of "." acting as a
      # regex wildcard.
      m = re.search(r"[ ,]" + re.escape(name) + r"\b",
                    self.source.line(node.lineno))
      if m is not None:
        # The match starts at the separating space/comma; skip past it.
        c, _ = m.span()
        return source.Location(node.lineno, c + 1)
    elif isinstance(node, self._ast.Attribute):
      attr_loc, _ = self.source.get_attr_location(name, loc)
      return attr_loc
    return loc

  def _get_node_name(self, node):
    """Returns a dotted name for `node`, or its class name if it has none."""
    if isinstance(node, self._ast.Attribute):
      return "{}.{}".format(self._get_node_name(node.value), node.attr)
    elif isinstance(node, self._ast.Call):
      return self._get_node_name(node.func)
    elif isinstance(node, self._ast.Lambda):
      return "<lambda>"
    elif isinstance(node, self._ast.Name):
      return node.id
    else:
      return node.__class__.__name__

  def _match_constant(self, node, value):
    return [(self._get_match_location(node), tr)
            for tr in self._get_traces(node.lineno, ["LOAD_CONST"], value, 1)]

  def _match_import(self, node, is_from):
    """Yields (location, trace) pairs for each alias of an import node."""
    for alias in node.names:
      name = alias.asname or alias.name
      # Aliased imports and from-imports are matched on the STORE_NAME that
      # binds the local name; plain `import x` is matched on IMPORT_NAME.
      op = "STORE_NAME" if alias.asname or is_from else "IMPORT_NAME"
      for tr in self._get_traces(node.lineno, [op], name, 1):
        yield self._get_match_location(node, name), tr


class _LineNumberVisitor(visitor.BaseVisitor):
  """Keeps a running maximum of node line numbers.

  `line` starts at 0 and holds the largest `lineno` among the nodes passed to
  generic_visit; nodes without a `lineno` attribute count as 0.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.line = 0

  def generic_visit(self, node):
    self.line = max(self.line, getattr(node, "lineno", 0))