Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
pytype / tools / traces / traces_test.py
Size: Mime:
"""Tests for traces.traces."""

import ast
import collections
import textwrap
from pytype import config
from pytype import file_utils
from pytype.pytd import pytd
from pytype.pytd import pytd_utils
from pytype.tests import test_utils
from pytype.tools.traces import traces
import unittest


class _NotImplementedVisitor(traces.MatchAstVisitor):

  def visit_Module(self, node):
    self.match(node)


class _TestVisitor(traces.MatchAstVisitor):

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.traces_by_node_type = collections.defaultdict(list)

  def generic_visit(self, node):
    try:
      matches = self.match(node)
    except NotImplementedError:
      return
    self.traces_by_node_type[node.__class__].extend(matches)


class TraceTest(unittest.TestCase):
  """Tests for traces.trace."""

  def test_traces(self):
    src = traces.trace("")
    trace, = src.traces[1]
    self.assertEqual(trace.op, "LOAD_CONST")
    self.assertIsNone(trace.symbol)
    pyval, = trace.types
    self.assertEqual(pyval.name, "builtins.NoneType")
    self.assertEqual(pyval.cls.name, "builtins.NoneType")

  def test_options(self):
    src = traces.trace("", config.Options.create("rumpelstiltskin"))
    self.assertEqual(src.filename, "rumpelstiltskin")

  def test_external_type(self):
    with file_utils.Tempdir() as d:
      pyi_path = d.create_file("foo.pyi", "class Foo: ...")
      imports_info = d.create_file("imports_info", "foo %s" % pyi_path)
      src = traces.trace(
          "import foo\nx = foo.Foo()",
          config.Options.create(imports_map=imports_info))
    trace, = (x for x in src.traces[2] if x.op == "STORE_NAME")
    pyval, = trace.types
    self.assertEqual(pyval.name, "foo.Foo")
    self.assertEqual(pyval.cls.name, "foo.Foo")

  def test_py3_class(self):
    src = traces.trace(textwrap.dedent("""
      class Foo:
        pass
    """).lstrip())
    trace, = (x for x in src.traces[1] if x.op == "LOAD_BUILD_CLASS")
    pyval, = trace.types
    self.assertEqual(pyval.name, "typing.Callable")

  def test_unknown(self):
    # pytype represents unannotated function parameters as unknowns. Make sure
    # unknowns don't appear in the traced types.
    src = traces.trace("def f(x): return x")
    trace = next(x for x in src.traces[1] if x.op == "LOAD_FAST")
    pyval, = trace.types
    self.assertIsInstance(pyval, pytd.AnythingType)


class MatchAstTestCase(unittest.TestCase):
  """Base class for testing traces.MatchAstVisitor."""

  def _parse(self, text, options=None):
    text = textwrap.dedent(text).lstrip()
    return ast.parse(text), traces.trace(text, options)

  def _get_traces(self, text, node_type, options=None):
    module, src = self._parse(text, options)
    v = _TestVisitor(src, ast)
    v.visit(module)
    return v.traces_by_node_type[node_type]

  def assertTracesEqual(self, actual_traces, expected_traces):
    self.assertEqual(len(actual_traces), len(expected_traces))
    for trace, expected_trace in zip(actual_traces, expected_traces):
      loc, trace = trace
      expected_loc, expected_op, expected_symbol, expected_annots = (
          expected_trace)
      self.assertEqual(loc, expected_loc)
      self.assertEqual(trace.op, expected_op)
      self.assertEqual(trace.symbol, expected_symbol)
      self.assertEqual(len(trace.types), len(expected_annots))
      for t, annot in zip(trace.types, expected_annots):
        self.assertEqual(pytd_utils.Print(t), annot)


class MatchAstVisitorTest(MatchAstTestCase):
  """Tests for traces.MatchAstVisitor."""

  def test_not_implemented(self):
    module, src = self._parse("")
    v = _NotImplementedVisitor(src, ast)
    with self.assertRaises(NotImplementedError):
      v.visit(module)

  def test_import(self):
    matches = self._get_traces("import os, sys as tzt", ast.Import)
    self.assertTracesEqual(matches, [
        ((1, 7), "IMPORT_NAME", "os", ("import os",)),
        ((1, 18), "STORE_NAME", "tzt", ("import sys",))])

  def test_import_from(self):
    matches = self._get_traces(
        "from os import path as p, environ", ast.ImportFrom)
    self.assertTracesEqual(matches, [
        ((1, 23), "STORE_NAME", "p", ("import os.path",)),
        ((1, 26), "STORE_NAME", "environ", ("os._Environ[str]",))])


class MatchAttributeTest(MatchAstTestCase):
  """Tests for traces.MatchAstVisit.match_Attribute."""

  def test_basic(self):
    matches = self._get_traces("""
      x = 0
      print(x.real)
    """, ast.Attribute)
    self.assertTracesEqual(matches, [
        ((2, 8), "LOAD_ATTR", "real", ("int", "int"))])

  def test_multi(self):
    matches = self._get_traces("""
      class Foo:
        real = True
      x = 0
      (Foo.real, x.real)
    """, ast.Attribute)
    # The second attribute is at the wrong location due to limitations of
    # source.Code.get_attr_location(), but we can at least test that we get the
    # right number of traces with the right types.
    self.assertTracesEqual(matches, [
        ((4, 5), "LOAD_ATTR", "real", ("Type[Foo]", "bool")),
        ((4, 5), "LOAD_ATTR", "real", ("int", "int"))])

  def test_property(self):
    matches = self._get_traces("""
      class Foo:
        @property
        def x(self):
          return 42
      v = Foo().x
    """, ast.Attribute)
    self.assertTracesEqual(matches, [
        ((5, 10), "LOAD_ATTR", "x", ("Foo", "int"))])


class MatchNameTest(MatchAstTestCase):
  """Tests for traces.MatchAstVisitor.match_Name."""

  def test_basic(self):
    matches = self._get_traces("x = 42", ast.Name)
    self.assertTracesEqual(matches, [((1, 0), "STORE_NAME", "x", ("int",))])

  def test_multiline(self):
    matches = self._get_traces("""
      x = (1 +
           2)
    """, ast.Name)
    self.assertTracesEqual(matches, [((1, 0), "STORE_NAME", "x", ("int",))])

  def test_multiline_subscr(self):
    matches = self._get_traces("""
      x = [0]
      x[0] = (1,
              2)
    """, ast.Name)
    x_annot = "List[Union[int, Tuple[int, int]]]"
    self.assertTracesEqual(matches, [((1, 0), "STORE_NAME", "x", (x_annot,)),
                                     ((2, 0), "LOAD_NAME", "x", (x_annot,))])


class MatchCallTest(MatchAstTestCase):
  """Tests for traces.MatchAstVisitor.match_Call."""

  def test_basic(self):
    matches = self._get_traces("""
      def f(x):
        return x + 1.0
      f(42)
    """, ast.Call)
    self.assertTracesEqual(matches, [
        ((3, 0), "CALL_FUNCTION", "f", ("Callable[[Any], Any]", "float"))])

  def test_chain(self):
    matches = self._get_traces("""
      class Foo:
        def f(self, x):
          return x
      Foo().f(42)
    """, ast.Call)
    self.assertTracesEqual(matches, [
        ((4, 0), "CALL_FUNCTION", "Foo", ("Type[Foo]", "Foo")),
        ((4, 0), "CALL_METHOD", "f", ("Callable[[Any], Any]", "int"))])

  def test_multiple_bindings(self):
    matches = self._get_traces("""
      class Foo:
        @staticmethod
        def f(x):
          return x
      class Bar:
        @staticmethod
        def f(x):
          return x + 1.0
      f = Foo.f if __random__ else Bar.f
      f(42)
    """, ast.Call)
    self.assertTracesEqual(matches, [
        ((10, 0), "CALL_FUNCTION", "f",
         ("Callable[[Any], Any]", "Union[int, float]"))])

  def test_bad_call(self):
    matches = self._get_traces("""
      def f(): pass
      f(42)
    """, ast.Call)
    self.assertTracesEqual(
        matches, [((2, 0), "CALL_FUNCTION", "f", ("Callable[[], Any]", "Any"))])

  def test_literal(self):
    matches = self._get_traces("''.upper()", ast.Call)
    self.assertTracesEqual(matches, [
        ((1, 0), "CALL_METHOD", "upper", ("Callable[[], str]", "str"))])

  def test_lookahead(self):
    matches = self._get_traces("""
      def f(x, y, z):
        return x + y + z
      f(
        0,
        1,
        2,
      )
    """, ast.Call)
    self.assertTracesEqual(matches, [
        ((3, 0), "CALL_FUNCTION", "f",
         ("Callable[[Any, Any, Any], Any]", "int"))])


class MatchConstantTest(MatchAstTestCase):

  @test_utils.skipFromPy((3, 8),
                         reason="numbers are ast.Constant in 3.8+")
  def test_num(self):
    matches = self._get_traces("v = 42", ast.Num)
    self.assertTracesEqual(matches, [((1, 4), "LOAD_CONST", 42, ("int",))])

  @test_utils.skipFromPy((3, 8),
                         reason="strings are ast.Constant in 3.8+")
  def test_str(self):
    matches = self._get_traces("v = 'hello'", ast.Str)
    self.assertTracesEqual(matches, [((1, 4), "LOAD_CONST", "hello", ("str",))])

  @test_utils.skipFromPy((3, 8),
                         reason="unicode strings are ast.Constant in 3.8+")
  def test_unicode(self):
    matches = self._get_traces("v = u'hello'", ast.Str)
    self.assertTracesEqual(matches, [((1, 4), "LOAD_CONST", "hello", ("str",))])

  @test_utils.skipFromPy((3, 8),
                         reason="byte strings are ast.Constant in 3.8+")
  def test_bytes(self):
    matches = self._get_traces("v = b'hello'", ast.Bytes)
    self.assertTracesEqual(
        matches, [((1, 4), "LOAD_CONST", b"hello", ("bytes",))])

  @test_utils.skipFromPy((3, 8),
                         reason="bools are ast.Constant in 3.8+")
  def test_bool(self):
    matches = self._get_traces("v = True", ast.NameConstant)
    self.assertTracesEqual(matches, [((1, 4), "LOAD_CONST", True, ("bool",))])

  @test_utils.skipFromPy((3, 8),
                         reason="ellipsis are ast.Constant in 3.8+")
  def test_ellipsis(self):
    matches = self._get_traces("v = ...", ast.Ellipsis)
    self.assertTracesEqual(
        matches, [((1, 4), "LOAD_CONST", Ellipsis, ("ellipsis",))])

  @test_utils.skipBeforePy((3, 8),
                           reason="numbers are ast.Constant in 3.8+")
  def test_num_38(self):
    matches = self._get_traces("v = 42", ast.Constant)
    self.assertTracesEqual(matches, [((1, 4), "LOAD_CONST", 42, ("int",))])

  @test_utils.skipBeforePy((3, 8),
                           reason="strings are ast.Constant in 3.8+")
  def test_str_38(self):
    matches = self._get_traces("v = 'hello'", ast.Constant)
    self.assertTracesEqual(matches, [((1, 4), "LOAD_CONST", "hello", ("str",))])

  @test_utils.skipBeforePy((3, 8),
                           reason="unicode strings are ast.Constant in 3.8+")
  def test_unicode_38(self):
    matches = self._get_traces("v = u'hello'", ast.Constant)
    self.assertTracesEqual(matches, [((1, 4), "LOAD_CONST", "hello", ("str",))])

  @test_utils.skipBeforePy((3, 8),
                           reason="byte strings are ast.Constant in 3.8+")
  def test_bytes_38(self):
    matches = self._get_traces("v = b'hello'", ast.Constant)
    self.assertTracesEqual(
        matches, [((1, 4), "LOAD_CONST", b"hello", ("bytes",))])

  @test_utils.skipBeforePy((3, 8),
                           reason="bools are ast.Constant in 3.8+")
  def test_bool_38(self):
    matches = self._get_traces("v = True", ast.Constant)
    self.assertTracesEqual(matches, [((1, 4), "LOAD_CONST", True, ("bool",))])

  @test_utils.skipBeforePy((3, 8),
                           reason="ellipsis are ast.Constant in 3.8+")
  def test_ellipsis_38(self):
    matches = self._get_traces("v = ...", ast.Constant)
    self.assertTracesEqual(
        matches, [((1, 4), "LOAD_CONST", Ellipsis, ("ellipsis",))])


class MatchSubscriptTest(MatchAstTestCase):

  def test_index(self):
    matches = self._get_traces("""
      v = "hello"
      print(v[0])
    """, ast.Subscript)
    self.assertTracesEqual(
        matches, [((2, 6), "BINARY_SUBSCR", "__getitem__", ("str",))])

  def test_simple_slice(self):
    matches = self._get_traces("""
      v = "hello"
      print(v[:-1])
    """, ast.Subscript)
    self.assertTracesEqual(
        matches, [((2, 6), "BINARY_SUBSCR", "__getitem__", ("str",))])

  def test_complex_slice(self):
    matches = self._get_traces("""
      v = "hello"
      print(v[0:4:2])
    """, ast.Subscript)
    self.assertTracesEqual(
        matches, [((2, 6), "BINARY_SUBSCR", "__getitem__", ("str",))])


class MatchBinOpTest(MatchAstTestCase):

  def test_modulo(self):
    matches = self._get_traces("""
      v = "hello %s"
      print(v % "world")
    """, ast.BinOp)
    self.assertTracesEqual(
        matches, [((2, 6), "BINARY_MODULO", "__mod__", ("str",))])

  def test_modulo_multiline_string(self):
    matches = self._get_traces("""
      ('%s'
       '%s' %
       (__any_object__,
        __any_object__))
    """, ast.BinOp)
    self.assertTracesEqual(
        matches, [((1, 1), "BINARY_MODULO", "__mod__", ("str",))])


class MatchLambdaTest(MatchAstTestCase):

  def test_basic(self):
    matches = self._get_traces("lambda x: x.upper()", ast.Lambda)
    sym = "<lambda>"
    self.assertTracesEqual(
        matches, [((1, 0), "MAKE_FUNCTION", sym, ("Callable[[Any], Any]",))])

  def test_function_locals(self):
    matches = self._get_traces("""
      def f():
        return lambda x: x.upper()
    """, ast.Lambda)
    sym = "f.<locals>.<lambda>"
    self.assertTracesEqual(
        matches, [((2, 9), "MAKE_FUNCTION", sym, ("Callable[[Any], Any]",))])

  def test_multiple_functions(self):
    matches = self._get_traces("""
      def f():
        return (w for w in range(3)), lambda x: x.upper(), lambda y, z: (y, z)
    """, ast.Lambda)
    sym = "f.<locals>.<lambda>"
    self.assertTracesEqual(
        matches, [
            ((2, 32), "MAKE_FUNCTION", sym, ("Callable[[Any], Any]",)),
            ((2, 53), "MAKE_FUNCTION", sym, ("Callable[[Any, Any], Any]",))])


if __name__ == "__main__":
  unittest.main()