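# NOTE(review): the header below looks like package-index page residue that
# was captured together with the file; it is kept as comments so the module
# stays importable -- confirm and drop if it is not wanted.
# Repository URL to install this package:
# Version: 2022.2.8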
import os
import textwrap
from pytype import file_utils
from pytype.pyi import parser
from pytype.pytd import pytd
from pytype.pytd import pytd_utils
from pytype.pytd import serialize_ast
from pytype.pytd import visitors
from pytype.pytd.parse import parser_test_base
import unittest
class TestUtils(parser_test_base.ParserTest):
  """Test pytype.pytd.pytd_utils."""

  def _pickled(self, ast):
    """Stores `ast` with serialize_ast.StoreAst and returns the file's bytes."""
    with file_utils.Tempdir() as d:
      filename = os.path.join(d.path, "foo.pickled")
      serialize_ast.StoreAst(ast, filename)
      with open(filename, "rb") as fi:
        return fi.read()

  def test_unpack_union(self):
    """Test for UnpackUnion."""
    ast = self.Parse("""
      from typing import Union
      c1 = ... # type: Union[int, float]
      c2 = ... # type: int
      c3 = ... # type: list[Union[int, float]]""")
    c1 = ast.Lookup("c1").type
    c2 = ast.Lookup("c2").type
    c3 = ast.Lookup("c3").type
    self.assertCountEqual(pytd_utils.UnpackUnion(c1), c1.type_list)
    self.assertCountEqual(pytd_utils.UnpackUnion(c2), [c2])
    self.assertCountEqual(pytd_utils.UnpackUnion(c3), [c3])

  def test_concat(self):
    """Test for concatenating two pytd ASTs."""
    ast1 = self.Parse("""
      c1 = ... # type: int
      def f1() -> int: ...
      class Class1:
        pass
    """)
    ast2 = self.Parse("""
      c2 = ... # type: int
      def f2() -> int: ...
      class Class2:
        pass
    """)
    expected = textwrap.dedent("""
      c1 = ... # type: int
      c2 = ... # type: int
      def f1() -> int: ...
      def f2() -> int: ...
      class Class1:
          pass
      class Class2:
          pass
    """)
    combined = pytd_utils.Concat(ast1, ast2)
    self.AssertSourceEquals(combined, expected)

  def test_concat3(self):
    """Test for concatenating three pytd ASTs."""
    ast1 = self.Parse("""c1 = ... # type: int""")
    ast2 = self.Parse("""c2 = ... # type: float""")
    ast3 = self.Parse("""c3 = ... # type: bool""")
    combined = pytd_utils.Concat(ast1, ast2, ast3)
    expected = textwrap.dedent("""
      c1 = ... # type: int
      c2 = ... # type: float
      c3 = ... # type: bool
    """)
    self.AssertSourceEquals(combined, expected)

  def test_concat_type_parameters(self):
    """Test for concatenating ASTs with type parameters."""
    ast1 = self.Parse("""T = TypeVar("T")""", name="builtins")
    ast2 = self.Parse("""T = TypeVar("T")""")
    combined = pytd_utils.Concat(ast1, ast2)
    self.assertEqual(combined.Lookup("builtins.T"),
                     pytd.TypeParameter("T", scope="builtins"))
    self.assertEqual(combined.Lookup("T"), pytd.TypeParameter("T", scope=None))

  def test_join_types(self):
    """Test that JoinTypes() does recursive flattening."""
    n1, n2, n3, n4, n5, n6 = [pytd.NamedType("n%d" % i) for i in range(6)]
    # n1 or (n2 or (n3))
    nested1 = pytd.UnionType((n1, pytd.UnionType((n2, pytd.UnionType((n3,))))))
    # ((n4) or n5) or n6
    nested2 = pytd.UnionType((pytd.UnionType((pytd.UnionType((n4,)), n5)), n6))
    joined = pytd_utils.JoinTypes([nested1, nested2])
    self.assertEqual(joined.type_list,
                     (n1, n2, n3, n4, n5, n6))

  def test_join_single_type(self):
    """Test that JoinTypes() returns single types as-is."""
    a = pytd.NamedType("a")
    self.assertEqual(pytd_utils.JoinTypes([a]), a)
    self.assertEqual(pytd_utils.JoinTypes([a, a]), a)

  def test_join_nothing_type(self):
    """Test that JoinTypes() removes or collapses 'nothing'."""
    a = pytd.NamedType("a")
    nothing = pytd.NothingType()
    self.assertEqual(pytd_utils.JoinTypes([a, nothing]), a)
    self.assertEqual(pytd_utils.JoinTypes([nothing]), nothing)
    self.assertEqual(pytd_utils.JoinTypes([nothing, nothing]), nothing)

  def test_join_empty_types_to_nothing(self):
    """Test that JoinTypes() simplifies empty unions to 'nothing'."""
    self.assertIsInstance(pytd_utils.JoinTypes([]), pytd.NothingType)

  def test_join_anything_types(self):
    """Test that JoinTypes() simplifies unions containing 'Any'."""
    types = [pytd.AnythingType(), pytd.NamedType("a")]
    self.assertIsInstance(pytd_utils.JoinTypes(types), pytd.AnythingType)

  def test_join_optional_anything_types(self):
    """Test that JoinTypes() simplifies unions containing 'Any' and 'None'."""
    any_type = pytd.AnythingType()
    none_type = pytd.NamedType("builtins.NoneType")
    types = [pytd.NamedType("a"), any_type, none_type]
    self.assertEqual(pytd_utils.JoinTypes(types),
                     pytd.UnionType((any_type, none_type)))

  def test_type_matcher(self):
    """Test for the TypeMatcher class."""

    class MyTypeMatcher(pytd_utils.TypeMatcher):
      """Matcher that compares by equality and checks the extra keyword."""

      def default_match(self, t1, t2, mykeyword):
        assert mykeyword == "foobar"
        return t1 == t2

      def match_Function_against_Function(self, f1, f2, mykeyword):
        assert mykeyword == "foobar"
        return all(self.match(sig1, sig2, mykeyword)
                   for sig1, sig2 in zip(f1.signatures, f2.signatures))

    s1 = pytd.Signature((), None, None, pytd.NothingType(), (), ())
    s2 = pytd.Signature((), None, None, pytd.AnythingType(), (), ())
    self.assertTrue(MyTypeMatcher().match(
        pytd.Function("f1", (s1, s2), pytd.MethodTypes.METHOD),
        pytd.Function("f2", (s1, s2), pytd.MethodTypes.METHOD),
        mykeyword="foobar"))
    self.assertFalse(MyTypeMatcher().match(
        pytd.Function("f1", (s1, s2), pytd.MethodTypes.METHOD),
        pytd.Function("f2", (s2, s2), pytd.MethodTypes.METHOD),
        mykeyword="foobar"))

  def test_named_type_with_module(self):
    """Test NamedTypeWithModule()."""
    self.assertEqual(pytd_utils.NamedTypeWithModule("name"),
                     pytd.NamedType("name"))
    self.assertEqual(pytd_utils.NamedTypeWithModule("name", None),
                     pytd.NamedType("name"))
    self.assertEqual(pytd_utils.NamedTypeWithModule("name", "package"),
                     pytd.NamedType("package.name"))

  def test_ordered_set(self):
    """Test that OrderedSet keeps first-insertion order and drops duplicates."""
    ordered_set = pytd_utils.OrderedSet(n//2 for n in range(10))
    ordered_set.add(-42)
    ordered_set.add(3)  # Already present; must not move or duplicate.
    self.assertEqual(tuple(ordered_set), (0, 1, 2, 3, 4, -42))

  def test_wrap_type_decl_unit(self):
    """Test WrapTypeDeclUnit."""
    ast1 = self.Parse("""
      c = ... # type: int
      def f(x: int) -> int: ...
      def f(x: float) -> float: ...
      class A:
        pass
    """)
    ast2 = self.Parse("""
      c = ... # type: float
      d = ... # type: int
      def f(x: complex) -> complex: ...
      class B:
        pass
    """)
    w = pytd_utils.WrapTypeDeclUnit(
        "combined",
        ast1.classes + ast1.functions + ast1.constants +
        ast2.classes + ast2.functions + ast2.constants)
    expected = textwrap.dedent("""
      from typing import Union
      c = ... # type: Union[int, float]
      d = ... # type: int
      def f(x: int) -> int: ...
      def f(x: float) -> float: ...
      def f(x: complex) -> complex: ...
      class A:
          pass
      class B:
          pass
    """)
    self.AssertSourceEquals(w, expected)

  def test_builtin_alias(self):
    """Test that an alias of a builtin type is printed back unchanged."""
    src = "Number = int"
    ast = parser.parse_string(src, options=self.options)
    self.assertMultiLineEqual(pytd_utils.Print(ast), src)

  def test_typing_name_conflict1(self):
    """Test printing when a module-level function shares a typing name."""
    # NOTE(review): this text is compared exactly against Print(); the blank
    # lines between sections were lost in the mangled copy and are restored
    # here on the assumption that Print() separates sections with one blank
    # line -- confirm against the printer.
    src = textwrap.dedent("""
      import typing

      x: typing.List[str]

      def List() -> None: ...
    """)
    ast = parser.parse_string(src, options=self.options)
    self.assertMultiLineEqual(pytd_utils.Print(ast).strip("\n"),
                              src.strip("\n"))

  def test_typing_name_conflict2(self):
    """Test printing when a class attribute shares a typing name."""
    ast = parser.parse_string(textwrap.dedent("""
      import typing
      from typing import Any
      x = ... # type: typing.List[str]
      class MyClass:
        List = ... # type: Any
        x = ... # type: typing.List[str]
    """), options=self.options)
    # NOTE(review): blank lines and the 4-space class indent are restored on
    # the assumption that they match Print()'s layout -- confirm.
    expected = textwrap.dedent("""
      import typing
      from typing import Any, List

      x: List[str]

      class MyClass:
          List: Any
          x: typing.List[str]
    """)
    self.assertMultiLineEqual(pytd_utils.Print(ast).strip("\n"),
                              expected.strip("\n"))

  def test_dummy_method(self):
    """Test the printed form of DummyMethod() with 0, 1 and 2 parameters."""
    self.assertEqual("def foo() -> Any: ...",
                     pytd_utils.Print(pytd_utils.DummyMethod("foo")))
    self.assertEqual("def foo(x) -> Any: ...",
                     pytd_utils.Print(pytd_utils.DummyMethod("foo", "x")))
    self.assertEqual("def foo(x, y) -> Any: ...",
                     pytd_utils.Print(pytd_utils.DummyMethod("foo", "x", "y")))

  def test_load_pickle_from_file(self):
    """Test that SavePickle()/LoadPickle() round-trip a value via a file."""
    d1 = {1, 2j, "3"}
    with file_utils.Tempdir() as d:
      filename = d.create_file("foo.pickle")
      pytd_utils.SavePickle(d1, filename)
      d2 = pytd_utils.LoadPickle(filename)
    self.assertEqual(d1, d2)

  def test_load_pickle_from_compressed_file(self):
    """Test the SavePickle()/LoadPickle() round trip with compress=True."""
    d1 = {1, 2j, "3"}
    with file_utils.Tempdir() as d:
      filename = d.create_file("foo.pickle.gz")
      pytd_utils.SavePickle(d1, filename, compress=True)
      d2 = pytd_utils.LoadPickle(filename, compress=True)
    self.assertEqual(d1, d2)

  def test_diff_same_pickle(self):
    """Test that DiffNamedPickles() is falsy for identical inputs."""
    data = self._pickled(pytd.TypeDeclUnit("foo", (), (), (), (), ()))
    named_pickles = [("foo", data)]
    self.assertFalse(pytd_utils.DiffNamedPickles(named_pickles, named_pickles))

  def test_diff_pickle_name(self):
    """Test that DiffNamedPickles() is truthy when only the names differ."""
    data = self._pickled(pytd.TypeDeclUnit("foo", (), (), (), (), ()))
    named_pickles1 = [("foo", data)]
    named_pickles2 = [("bar", data)]
    self.assertTrue(pytd_utils.DiffNamedPickles(named_pickles1, named_pickles2))

  def test_diff_pickle_ast(self):
    """Test that DiffNamedPickles() is truthy when the pickled ASTs differ."""
    ast1 = pytd.TypeDeclUnit("foo", (), (), (), (), ())
    ast2 = ast1.Replace(type_params=(pytd.TypeParameter("T", (), None, None),))
    named_pickles1 = [("foo", self._pickled(ast1))]
    named_pickles2 = [("foo", self._pickled(ast2))]
    self.assertTrue(pytd_utils.DiffNamedPickles(named_pickles1, named_pickles2))

  def test_diff_pickle_length(self):
    """Test that DiffNamedPickles() is truthy for lists of unequal length."""
    data = self._pickled(pytd.TypeDeclUnit("foo", (), (), (), (), ()))
    named_pickles1 = []
    named_pickles2 = [("foo", data)]
    self.assertTrue(pytd_utils.DiffNamedPickles(named_pickles1, named_pickles2))

  def test_asteq(self):
    """Test ASTeq() on ASTs differing only in declaration and union order."""
    # This creates two ASTs that are equivalent but whose sources are slightly
    # different. The union types are different (int,str) vs (str,int) but the
    # ordering is ignored when testing for equality (which ASTeq uses).
    src1 = textwrap.dedent("""
        from typing import Union
        def foo(a: Union[int, str]) -> C: ...
        T = TypeVar('T')
        class C(typing.Generic[T], object):
            def bar(x: T) -> NoneType: ...
        CONSTANT = ... # type: C[float]
        """)
    src2 = textwrap.dedent("""
        from typing import Union
        CONSTANT = ... # type: C[float]
        T = TypeVar('T')
        class C(typing.Generic[T], object):
            def bar(x: T) -> NoneType: ...
        def foo(a: Union[str, int]) -> C: ...
        """)
    tree1 = parser.parse_string(src1, options=self.options)
    tree2 = parser.parse_string(src2, options=self.options)
    tree1.Visit(visitors.VerifyVisitor())
    tree2.Visit(visitors.VerifyVisitor())
    # Guard against a vacuous comparison: both trees must have content.
    self.assertTrue(tree1.constants)
    self.assertTrue(tree1.classes)
    self.assertTrue(tree1.functions)
    self.assertTrue(tree2.constants)
    self.assertTrue(tree2.classes)
    self.assertTrue(tree2.functions)
    self.assertIsInstance(tree1, pytd.TypeDeclUnit)
    self.assertIsInstance(tree2, pytd.TypeDeclUnit)
    # For the ==, != tests, TypeDeclUnit uses identity
    # pylint: disable=g-generic-assert
    # pylint: disable=comparison-with-itself
    self.assertTrue(tree1 == tree1)
    self.assertTrue(tree2 == tree2)
    self.assertFalse(tree1 == tree2)
    self.assertFalse(tree2 == tree1)
    self.assertFalse(tree1 != tree1)
    self.assertFalse(tree2 != tree2)
    self.assertTrue(tree1 != tree2)
    self.assertTrue(tree2 != tree1)
    # pylint: enable=g-generic-assert
    # pylint: enable=comparison-with-itself
    self.assertEqual(tree1, tree1)
    self.assertEqual(tree2, tree2)
    self.assertNotEqual(tree1, tree2)
    self.assertTrue(pytd_utils.ASTeq(tree1, tree2))
    self.assertTrue(pytd_utils.ASTeq(tree1, tree1))
    self.assertTrue(pytd_utils.ASTeq(tree2, tree1))
    self.assertTrue(pytd_utils.ASTeq(tree2, tree2))

  def test_astdiff(self):
    """Test ASTdiff() on identical and on differing ASTs."""
    src1 = textwrap.dedent("""
        a: int
        b: str""").lstrip()
    src2 = textwrap.dedent("""
        a: int
        b: float""").lstrip()
    tree1 = parser.parse_string(src1, options=self.options)
    tree2 = parser.parse_string(src2, options=self.options)

    def normalize(diff):
      # Join the diff lines and strip their common leading whitespace so an
      # all-unchanged diff can be compared with the plain source text.
      return textwrap.dedent("\n".join(diff))

    self.assertEqual(normalize(pytd_utils.ASTdiff(tree1, tree1)), src1)
    self.assertEqual(normalize(pytd_utils.ASTdiff(tree2, tree2)), src2)
    diff_pattern = r"(?s)- b.*\+ b"
    self.assertRegex(normalize(pytd_utils.ASTdiff(tree1, tree2)), diff_pattern)
    self.assertRegex(normalize(pytd_utils.ASTdiff(tree2, tree1)), diff_pattern)
class TestDataFiles(parser_test_base.ParserTest):
  """Test pytd_utils.GetPredefinedFile()."""

  # Stub subdirectory passed to GetPredefinedFile() by the tests below.
  BUILTINS = "builtins"

  def test_get_predefined_file_basic(self):
    """Smoke test: the returned path is as expected and the source is a str."""
    # smoke test, only checks that it doesn't throw, the filepath is correct,
    # and the result is a string
    path, src = pytd_utils.GetPredefinedFile(self.BUILTINS, "builtins")
    self.assertEqual(path, "stubs/builtins/builtins.pytd")
    self.assertIsInstance(src, str)

  def test_get_predefined_file_throws(self):
    """Smoke test: a missing file raises IOError."""
    # smoke test, only checks that it does throw
    with self.assertRaisesRegex(
        IOError,
        r"File not found|Resource not found|No such file or directory"):
      pytd_utils.GetPredefinedFile(self.BUILTINS, "-this-file-does-not-exist")

  def test_pytd_builtin3(self):
    """Verify GetPredefinedFile() returns the on-disk text of builtins.pytd."""
    subdir = "builtins"
    _, import_contents = pytd_utils.GetPredefinedFile(subdir, "builtins")
    # Read the same stub directly, located relative to the file_utils module.
    with open(os.path.join(os.path.dirname(file_utils.__file__), "stubs",
                           subdir, "builtins.pytd"), "r") as fi:
      file_contents = fi.read()
    self.assertMultiLineEqual(import_contents, file_contents)

  def test_pytd_builtin_is_package(self):
    """Test that as_package=True resolves to the package's __init__.pytd."""
    subdir = "builtins"
    path, _ = pytd_utils.GetPredefinedFile(subdir, "attr", as_package=True)
    self.assertEqual(path, "stubs/builtins/attr/__init__.pytd")

  def test_type_builder(self):
    """Test that a TypeBuilder is falsy until a type has been added."""
    t = pytd_utils.TypeBuilder()
    self.assertFalse(t)
    t.add_type(pytd.AnythingType())
    self.assertTrue(t)
class PrintTest(parser_test_base.ParserTest):
  """Test pytd_utils.Print."""

  # NOTE(review): several expectations below are compared exactly against
  # Print(). Blank lines and class-body indentation inside those strings were
  # lost in the mangled copy and are restored on the assumption that Print()
  # puts one blank line between sections and indents class bodies by four
  # spaces -- confirm against the printer.

  def test_smoke(self):
    """Smoketest for printing pytd."""
    ast = self.Parse("""
      from typing import Any, Union
      c1 = ... # type: int
      T = TypeVar('T')
      class A(typing.Generic[T], object):
        bar = ... # type: T
        def foo(self, x: list[int], y: T) -> Union[list[T], float]:
          raise ValueError()
      X = TypeVar('X')
      Y = TypeVar('Y')
      def bar(x: Union[X, Y]) -> Any: ...
    """)
    # Only checks that printing does not raise; the output is not inspected.
    pytd_utils.Print(ast)

  def test_literal(self):
    """Test printing of Literal types for str, bytes, int, bool and None."""
    ast = self.Parse("""
      from typing import Literal
      x1: Literal[""]
      x2: Literal[b""]
      x3: Literal[0]
      x4: Literal[True]
      x5: Literal[None]
    """)
    ast = ast.Visit(visitors.LookupBuiltins(self.loader.builtins))
    self.assertMultiLineEqual(pytd_utils.Print(ast), textwrap.dedent("""
      from typing import Literal

      x1: Literal['']
      x2: Literal[b'']
      x3: Literal[0]
      x4: Literal[True]
      x5: None
    """).strip())

  def test_literal_union(self):
    """Test that a union of Literals is printed as one merged Literal."""
    ast = self.Parse("""
      from typing import Literal, Union
      x: Union[Literal["x"], Literal["y"]]
    """)
    self.assertMultiLineEqual(pytd_utils.Print(ast), textwrap.dedent("""
      from typing import Literal

      x: Literal['x', 'y']
    """).strip())

  def test_reuse_union_name(self):
    """Test printing when a user class is itself named Union."""
    src = """
      import typing
      from typing import Callable, Iterable, Tuple

      class Node: ...

      class Union:
          _predicates: Tuple[Callable[[typing.Union[Iterable[Node], Node]], bool], ...]
          def __init__(self, *predicates: Callable[[typing.Union[Iterable[Node], Node]], bool]) -> None: ...
    """
    ast = self.Parse(src)
    self.assertMultiLineEqual(pytd_utils.Print(ast),
                              textwrap.dedent(src).strip())
# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
  unittest.main()