Repository URL to install this package:
|
Version:
0.4.9 ▾
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import sys
from textwrap import dedent
from typing import cast, Mapping, Sequence, Tuple
from unittest import mock
import libcst as cst
from libcst import ensure_type
from libcst.metadata import MetadataWrapper
from libcst.metadata.scope_provider import (
_gen_dotted_names,
Assignment,
BuiltinAssignment,
BuiltinScope,
ClassScope,
ComprehensionScope,
FunctionScope,
GlobalScope,
ImportAssignment,
LocalScope,
QualifiedName,
QualifiedNameSource,
Scope,
ScopeProvider,
)
from libcst.testing.utils import data_provider, UnitTest
class DependentVisitor(cst.CSTVisitor):
METADATA_DEPENDENCIES = (ScopeProvider,)
def get_scope_metadata_provider(
module_str: str,
) -> Tuple[cst.Module, Mapping[cst.CSTNode, Scope]]:
wrapper = MetadataWrapper(cst.parse_module(dedent(module_str)))
return (
wrapper.module,
cast(
Mapping[cst.CSTNode, Scope], wrapper.resolve(ScopeProvider)
), # we're sure every node has an associated scope
)
class ScopeProviderTest(UnitTest):
def test_not_in_scope(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
pass
"""
)
global_scope = scopes[m]
self.assertEqual(global_scope["not_in_scope"], set())
def test_accesses(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
foo = 'toplevel'
fn1(foo)
fn2(foo)
def fn_def():
foo = 'shadow'
fn3(foo)
"""
)
scope_of_module = scopes[m]
self.assertIsInstance(scope_of_module, GlobalScope)
global_foo_assignments = list(scope_of_module["foo"])
self.assertEqual(len(global_foo_assignments), 1)
foo_assignment = global_foo_assignments[0]
self.assertEqual(len(foo_assignment.references), 2)
fn1_call_arg = ensure_type(
ensure_type(
ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Expr
).value,
cst.Call,
).args[0]
fn2_call_arg = ensure_type(
ensure_type(
ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.Expr
).value,
cst.Call,
).args[0]
self.assertEqual(
{access.node for access in foo_assignment.references},
{fn1_call_arg.value, fn2_call_arg.value},
)
func_body = ensure_type(m.body[3], cst.FunctionDef).body
func_foo_statement = func_body.body[0]
scope_of_func_statement = scopes[func_foo_statement]
self.assertIsInstance(scope_of_func_statement, FunctionScope)
func_foo_assignments = scope_of_func_statement["foo"]
self.assertEqual(len(func_foo_assignments), 1)
foo_assignment = list(func_foo_assignments)[0]
self.assertEqual(len(foo_assignment.references), 1)
fn3_call_arg = ensure_type(
ensure_type(
ensure_type(func_body.body[1], cst.SimpleStatementLine).body[0],
cst.Expr,
).value,
cst.Call,
).args[0]
self.assertEqual(
{access.node for access in foo_assignment.references}, {fn3_call_arg.value}
)
wrapper = MetadataWrapper(cst.parse_module("from a import b\n"))
wrapper.visit(DependentVisitor())
wrapper = MetadataWrapper(cst.parse_module("def a():\n from b import c\n\n"))
wrapper.visit(DependentVisitor())
def test_fstring_accesses(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
from a import b
f"{b}" "hello"
"""
)
global_scope = scopes[m]
self.assertIsInstance(global_scope, GlobalScope)
global_accesses = list(global_scope.accesses)
self.assertEqual(len(global_accesses), 1)
import_node = ensure_type(
ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.ImportFrom
)
b_referent = list(global_accesses[0].referents)[0]
self.assertIsInstance(b_referent, Assignment)
if isinstance(b_referent, Assignment): # for the typechecker's eyes
self.assertEqual(b_referent.node, import_node)
@data_provider((("any",), ("True",), ("Exception",), ("__name__",)))
def test_builtins(self, builtin: str) -> None:
m, scopes = get_scope_metadata_provider(
"""
def fn():
pass
"""
)
scope_of_module = scopes[m]
self.assertIsInstance(scope_of_module, GlobalScope)
self.assertEqual(len(scope_of_module[builtin]), 1)
self.assertEqual(len(scope_of_module["something_not_a_builtin"]), 0)
scope_of_builtin = scope_of_module.parent
self.assertIsInstance(scope_of_builtin, BuiltinScope)
self.assertEqual(len(scope_of_builtin[builtin]), 1)
self.assertEqual(len(scope_of_builtin["something_not_a_builtin"]), 0)
func_body = ensure_type(m.body[0], cst.FunctionDef).body
func_pass_statement = func_body.body[0]
scope_of_func_statement = scopes[func_pass_statement]
self.assertIsInstance(scope_of_func_statement, FunctionScope)
self.assertEqual(len(scope_of_func_statement[builtin]), 1)
self.assertEqual(len(scope_of_func_statement["something_not_a_builtin"]), 0)
def test_import(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
import foo.bar
import fizz.buzz as fizzbuzz
import a.b.c
import d.e.f as g
"""
)
scope_of_module = scopes[m]
import_0 = cst.ensure_type(
cst.ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.Import
)
self.assertEqual(scopes[import_0], scope_of_module)
import_aliases = import_0.names
if not isinstance(import_aliases, cst.ImportStar):
for alias in import_aliases:
self.assertEqual(scopes[alias], scope_of_module)
for idx, in_scopes in enumerate(
[
["foo", "foo.bar"],
["fizzbuzz"],
["a", "a.b", "a.b.c"],
["g"],
]
):
for in_scope in in_scopes:
self.assertEqual(
len(scope_of_module[in_scope]), 1, f"{in_scope} should be in scope."
)
assignment = cast(ImportAssignment, list(scope_of_module[in_scope])[0])
self.assertEqual(
assignment.name,
in_scope,
f"ImportAssignment name {assignment.name} should equal to {in_scope}.",
)
import_node = ensure_type(m.body[idx], cst.SimpleStatementLine).body[0]
self.assertEqual(
assignment.node,
import_node,
f"The node of ImportAssignment {assignment.node} should equal to {import_node}",
)
self.assertTrue(isinstance(import_node, (cst.Import, cst.ImportFrom)))
names = import_node.names
self.assertFalse(isinstance(names, cst.ImportStar))
alias = names[0]
as_name = alias.asname.name if alias.asname else alias.name
self.assertEqual(
assignment.as_name,
as_name,
f"The alias name of ImportAssignment {assignment.as_name} should equal to {as_name}",
)
def test_dotted_import_access(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
import a.b.c, x.y
a.b.c(x.z)
"""
)
scope_of_module = scopes[m]
first_statement = ensure_type(m.body[1], cst.SimpleStatementLine)
call = ensure_type(
ensure_type(first_statement.body[0], cst.Expr).value, cst.Call
)
self.assertTrue("a.b.c" in scope_of_module)
self.assertTrue("a" in scope_of_module)
self.assertEqual(scope_of_module.accesses["a"], set())
a_b_c_assignment = cast(ImportAssignment, list(scope_of_module["a.b.c"])[0])
a_b_c_access = list(a_b_c_assignment.references)[0]
self.assertEqual(scope_of_module.accesses["a.b.c"], {a_b_c_access})
self.assertEqual(a_b_c_access.node, call.func)
x_assignment = cast(Assignment, list(scope_of_module["x"])[0])
x_access = list(x_assignment.references)[0]
self.assertEqual(scope_of_module.accesses["x"], {x_access})
self.assertEqual(
x_access.node, ensure_type(call.args[0].value, cst.Attribute).value
)
self.assertTrue("x.y" in scope_of_module)
self.assertEqual(list(scope_of_module["x.y"])[0].references, set())
self.assertEqual(scope_of_module.accesses["x.y"], set())
def test_dotted_import_with_call_access(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
import os.path
os.path.join("A", "B").lower()
"""
)
scope_of_module = scopes[m]
first_statement = ensure_type(m.body[1], cst.SimpleStatementLine)
attr = ensure_type(
ensure_type(
ensure_type(
ensure_type(
ensure_type(first_statement.body[0], cst.Expr).value, cst.Call
).func,
cst.Attribute,
).value,
cst.Call,
).func,
cst.Attribute,
).value
self.assertTrue("os.path" in scope_of_module)
self.assertTrue("os" in scope_of_module)
os_path_join_assignment = cast(
ImportAssignment, list(scope_of_module["os.path"])[0]
)
os_path_join_assignment_references = list(os_path_join_assignment.references)
self.assertNotEqual(len(os_path_join_assignment_references), 0)
os_path_join_access = os_path_join_assignment_references[0]
self.assertEqual(scope_of_module.accesses["os"], set())
self.assertEqual(scope_of_module.accesses["os.path"], {os_path_join_access})
self.assertEqual(scope_of_module.accesses["os.path.join"], set())
self.assertEqual(os_path_join_access.node, attr)
def test_import_from(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
from foo.bar import a, b as b_renamed
from . import c
from .foo import d
"""
)
scope_of_module = scopes[m]
import_from = cst.ensure_type(
cst.ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.ImportFrom
)
self.assertEqual(scopes[import_from], scope_of_module)
import_aliases = import_from.names
if not isinstance(import_aliases, cst.ImportStar):
for alias in import_aliases:
self.assertEqual(scopes[alias], scope_of_module)
for idx, in_scope, imported_object_idx in [
(0, "a", 0),
(0, "b_renamed", 1),
(1, "c", 0),
(2, "d", 0),
]:
self.assertEqual(
len(scope_of_module[in_scope]), 1, f"{in_scope} should be in scope."
)
import_assignment = cast(
ImportAssignment, list(scope_of_module[in_scope])[0]
)
self.assertEqual(
import_assignment.name,
in_scope,
f"The name of ImportAssignment {import_assignment.name} should equal to {in_scope}.",
)
import_node = ensure_type(m.body[idx], cst.SimpleStatementLine).body[0]
self.assertEqual(
import_assignment.node,
import_node,
f"The node of ImportAssignment {import_assignment.node} should equal to {import_node}",
)
self.assertTrue(isinstance(import_node, (cst.Import, cst.ImportFrom)))
names = import_node.names
self.assertFalse(isinstance(names, cst.ImportStar))
alias = names[imported_object_idx]
as_name = alias.asname.name if alias.asname else alias.name
self.assertEqual(
import_assignment.as_name,
as_name,
f"The alias name of ImportAssignment {import_assignment.as_name} should equal to {as_name}",
)
for not_in_scope in ["foo", "bar", "foo.bar", "b"]:
self.assertEqual(
len(scope_of_module[not_in_scope]),
0,
f"{not_in_scope} should not be in scope.",
)
def test_function_scope(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
global_var = None
def foo(arg, **kwargs):
local_var = 5
"""
)
scope_of_module = scopes[m]
func_def = ensure_type(m.body[1], cst.FunctionDef)
self.assertEqual(scopes[func_def], scopes[func_def.name])
func_body_statement = func_def.body.body[0]
scope_of_func = scopes[func_body_statement]
self.assertIsInstance(scope_of_func, FunctionScope)
self.assertTrue("global_var" in scope_of_module)
self.assertTrue("global_var" in scope_of_func)
self.assertTrue("arg" not in scope_of_module)
self.assertTrue("arg" in scope_of_func)
self.assertTrue("kwargs" not in scope_of_module)
self.assertTrue("kwargs" in scope_of_func)
self.assertTrue("local_var" not in scope_of_module)
self.assertTrue("local_var" in scope_of_func)
def test_class_scope(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
global_var = None
@cls_attr
class Cls(cls_attr, kwarg=cls_attr):
cls_attr = 5
def f():
pass
"""
)
scope_of_module = scopes[m]
self.assertIsInstance(scope_of_module, GlobalScope)
cls_assignments = list(scope_of_module["Cls"])
self.assertEqual(len(cls_assignments), 1)
cls_assignment = cast(Assignment, cls_assignments[0])
cls_def = ensure_type(m.body[1], cst.ClassDef)
self.assertEqual(cls_assignment.node, cls_def)
self.assertEqual(scopes[cls_def], scopes[cls_def.name])
cls_body = cls_def.body
cls_body_statement = cls_body.body[0]
scope_of_class = scopes[cls_body_statement]
self.assertIsInstance(scope_of_class, ClassScope)
func_body = ensure_type(cls_body.body[1], cst.FunctionDef).body
func_body_statement = func_body.body[0]
scope_of_func = scopes[func_body_statement]
self.assertIsInstance(scope_of_func, FunctionScope)
self.assertTrue("global_var" in scope_of_module)
self.assertTrue("global_var" in scope_of_class)
self.assertTrue("global_var" in scope_of_func)
self.assertTrue("Cls" in scope_of_module)
self.assertTrue("Cls" in scope_of_class)
self.assertTrue("Cls" in scope_of_func)
self.assertTrue("cls_attr" not in scope_of_module)
self.assertTrue("cls_attr" in scope_of_class)
self.assertTrue("cls_attr" not in scope_of_func)
def test_comprehension_scope(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
iterator = None
condition = None
[elt for target in iterator if condition]
{elt for target in iterator if condition}
{elt: target for target in iterator if condition}
(elt for target in iterator if condition)
"""
)
scope_of_module = scopes[m]
self.assertIsInstance(scope_of_module, GlobalScope)
list_comp = ensure_type(
ensure_type(
ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.Expr
).value,
cst.ListComp,
)
scope_of_list_comp = scopes[list_comp.elt]
self.assertIsInstance(scope_of_list_comp, ComprehensionScope)
set_comp = ensure_type(
ensure_type(
ensure_type(m.body[3], cst.SimpleStatementLine).body[0], cst.Expr
).value,
cst.SetComp,
)
scope_of_set_comp = scopes[set_comp.elt]
self.assertIsInstance(scope_of_set_comp, ComprehensionScope)
dict_comp = ensure_type(
ensure_type(
ensure_type(m.body[4], cst.SimpleStatementLine).body[0], cst.Expr
).value,
cst.DictComp,
)
scope_of_dict_comp = scopes[dict_comp.key]
self.assertIsInstance(scope_of_dict_comp, ComprehensionScope)
generator_expr = ensure_type(
ensure_type(
ensure_type(m.body[5], cst.SimpleStatementLine).body[0], cst.Expr
).value,
cst.GeneratorExp,
)
scope_of_generator_expr = scopes[generator_expr.elt]
self.assertIsInstance(scope_of_generator_expr, ComprehensionScope)
self.assertTrue("iterator" in scope_of_module)
self.assertTrue("iterator" in scope_of_list_comp)
self.assertTrue("iterator" in scope_of_set_comp)
self.assertTrue("iterator" in scope_of_dict_comp)
self.assertTrue("iterator" in scope_of_generator_expr)
self.assertTrue("condition" in scope_of_module)
self.assertTrue("condition" in scope_of_list_comp)
self.assertTrue("condition" in scope_of_set_comp)
self.assertTrue("condition" in scope_of_dict_comp)
self.assertTrue("condition" in scope_of_generator_expr)
self.assertTrue("elt" not in scope_of_module)
self.assertTrue("elt" not in scope_of_list_comp)
self.assertTrue("elt" not in scope_of_set_comp)
self.assertTrue("elt" not in scope_of_dict_comp)
self.assertTrue("elt" not in scope_of_generator_expr)
self.assertTrue("target" not in scope_of_module)
self.assertTrue("target" in scope_of_list_comp)
self.assertTrue("target" in scope_of_set_comp)
self.assertTrue("target" in scope_of_dict_comp)
self.assertTrue("target" in scope_of_generator_expr)
def test_nested_comprehension_scope(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
[y for x in iterator for y in x]
"""
)
scope_of_module = scopes[m]
self.assertIsInstance(scope_of_module, GlobalScope)
list_comp = ensure_type(
ensure_type(
ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.Expr
).value,
cst.ListComp,
)
scope_of_list_comp = scopes[list_comp.elt]
self.assertIsInstance(scope_of_list_comp, ComprehensionScope)
self.assertIs(scopes[list_comp], scope_of_module)
self.assertIs(scopes[list_comp.elt], scope_of_list_comp)
self.assertIs(scopes[list_comp.for_in], scope_of_module)
self.assertIs(scopes[list_comp.for_in.iter], scope_of_module)
self.assertIs(scopes[list_comp.for_in.target], scope_of_list_comp)
inner_for_in = ensure_type(list_comp.for_in.inner_for_in, cst.CompFor)
self.assertIs(scopes[inner_for_in], scope_of_list_comp)
self.assertIs(scopes[inner_for_in.iter], scope_of_list_comp)
self.assertIs(scopes[inner_for_in.target], scope_of_list_comp)
def test_global_scope_overwrites(self) -> None:
codes = (
"""
class Cls:
def f():
global var
var = ...
""",
"""
class Cls:
def f():
global var
import f as var
""",
)
for code in codes:
m, scopes = get_scope_metadata_provider(code)
scope_of_module = scopes[m]
self.assertIsInstance(scope_of_module, GlobalScope)
self.assertTrue("var" in scope_of_module)
cls = ensure_type(m.body[0], cst.ClassDef)
scope_of_cls = scopes[cls.body.body[0]]
self.assertIsInstance(scope_of_cls, ClassScope)
self.assertTrue("var" in scope_of_cls)
f = ensure_type(cls.body.body[0], cst.FunctionDef)
scope_of_f = scopes[f.body.body[0]]
self.assertIsInstance(scope_of_f, FunctionScope)
self.assertTrue("var" in scope_of_f)
self.assertEqual(scope_of_f["var"], scope_of_module["var"])
def test_nonlocal_scope_overwrites(self) -> None:
codes = (
"""
def outer_f():
var = ...
class Cls:
var = ...
def inner_f():
nonlocal var
var = ...
""",
"""
def outer_f():
import f as var
class Cls:
var = ...
def inner_f():
nonlocal var
var = ...
""",
"""
def outer_f():
var = ...
class Cls:
var = ...
def inner_f():
nonlocal var
import f as var
""",
)
for code in codes:
m, scopes = get_scope_metadata_provider(code)
scope_of_module = scopes[m]
self.assertIsInstance(scope_of_module, GlobalScope)
self.assertTrue("var" not in scope_of_module)
outer_f = ensure_type(m.body[0], cst.FunctionDef)
outer_f_body_var = ensure_type(
ensure_type(outer_f.body.body[0], cst.SimpleStatementLine).body[0],
cst.CSTNode,
)
scope_of_outer_f = scopes[outer_f_body_var]
self.assertIsInstance(scope_of_outer_f, FunctionScope)
self.assertTrue("var" in scope_of_outer_f)
self.assertEqual(len(scope_of_outer_f["var"]), 2)
cls = ensure_type(outer_f.body.body[1], cst.ClassDef)
scope_of_cls = scopes[cls.body.body[0]]
self.assertIsInstance(scope_of_cls, ClassScope)
self.assertTrue("var" in scope_of_cls)
inner_f = ensure_type(cls.body.body[1], cst.FunctionDef)
inner_f_body_var = ensure_type(
ensure_type(inner_f.body.body[1], cst.SimpleStatementLine).body[0],
cst.CSTNode,
)
scope_of_inner_f = scopes[inner_f_body_var]
self.assertIsInstance(scope_of_inner_f, FunctionScope)
self.assertTrue("var" in scope_of_inner_f)
self.assertEqual(len(scope_of_inner_f["var"]), 2)
self.assertEqual(
{
cast(Assignment, assignment).node
for assignment in scope_of_outer_f["var"]
},
{
outer_f_body_var.targets[0].target
if isinstance(outer_f_body_var, cst.Assign)
else outer_f_body_var,
inner_f_body_var.targets[0].target
if isinstance(inner_f_body_var, cst.Assign)
else inner_f_body_var,
},
)
def test_local_scope_shadowing_with_functions(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
def f():
def f():
f = ...
"""
)
scope_of_module = scopes[m]
self.assertIsInstance(scope_of_module, GlobalScope)
self.assertTrue("f" in scope_of_module)
outer_f = ensure_type(m.body[0], cst.FunctionDef)
scope_of_outer_f = scopes[outer_f.body.body[0]]
self.assertIsInstance(scope_of_outer_f, FunctionScope)
self.assertTrue("f" in scope_of_outer_f)
out_f_assignment = list(scope_of_module["f"])[0]
self.assertEqual(cast(Assignment, out_f_assignment).node, outer_f)
inner_f = ensure_type(outer_f.body.body[0], cst.FunctionDef)
scope_of_inner_f = scopes[inner_f.body.body[0]]
self.assertIsInstance(scope_of_inner_f, FunctionScope)
self.assertTrue("f" in scope_of_inner_f)
inner_f_assignment = list(scope_of_outer_f["f"])[0]
self.assertEqual(cast(Assignment, inner_f_assignment).node, inner_f)
def test_func_param_scope(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
@decorator
def f(x: T=1, *vararg, y: T=2, z, **kwarg) -> RET:
pass
"""
)
scope_of_module = scopes[m]
self.assertIsInstance(scope_of_module, GlobalScope)
self.assertTrue("f" in scope_of_module)
f = ensure_type(m.body[0], cst.FunctionDef)
scope_of_f = scopes[f.body.body[0]]
self.assertIsInstance(scope_of_f, FunctionScope)
decorator = f.decorators[0]
x = f.params.params[0]
xT = ensure_type(x.annotation, cst.Annotation)
one = ensure_type(x.default, cst.BaseExpression)
vararg = ensure_type(f.params.star_arg, cst.Param)
y = f.params.kwonly_params[0]
yT = ensure_type(y.annotation, cst.Annotation)
two = ensure_type(y.default, cst.BaseExpression)
z = f.params.kwonly_params[1]
kwarg = ensure_type(f.params.star_kwarg, cst.Param)
ret = ensure_type(f.returns, cst.Annotation).annotation
self.assertEqual(scopes[decorator], scope_of_module)
self.assertEqual(scopes[x], scope_of_f)
self.assertEqual(scopes[xT], scope_of_module)
self.assertEqual(scopes[one], scope_of_module)
self.assertEqual(scopes[vararg], scope_of_f)
self.assertEqual(scopes[y], scope_of_f)
self.assertEqual(scopes[yT], scope_of_module)
self.assertEqual(scopes[z], scope_of_f)
self.assertEqual(scopes[two], scope_of_module)
self.assertEqual(scopes[kwarg], scope_of_f)
self.assertEqual(scopes[ret], scope_of_module)
self.assertTrue("x" not in scope_of_module)
self.assertTrue("x" in scope_of_f)
self.assertTrue("vararg" not in scope_of_module)
self.assertTrue("vararg" in scope_of_f)
self.assertTrue("y" not in scope_of_module)
self.assertTrue("y" in scope_of_f)
self.assertTrue("z" not in scope_of_module)
self.assertTrue("z" in scope_of_f)
self.assertTrue("kwarg" not in scope_of_module)
self.assertTrue("kwarg" in scope_of_f)
self.assertEqual(cast(Assignment, list(scope_of_f["x"])[0]).node, x)
self.assertEqual(cast(Assignment, list(scope_of_f["vararg"])[0]).node, vararg)
self.assertEqual(cast(Assignment, list(scope_of_f["y"])[0]).node, y)
self.assertEqual(cast(Assignment, list(scope_of_f["z"])[0]).node, z)
self.assertEqual(cast(Assignment, list(scope_of_f["kwarg"])[0]).node, kwarg)
def test_lambda_param_scope(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
lambda x=1, *vararg, y=2, z, **kwarg:x
"""
)
scope_of_module = scopes[m]
self.assertIsInstance(scope_of_module, GlobalScope)
f = ensure_type(
ensure_type(
ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.Expr
).value,
cst.Lambda,
)
scope_of_f = scopes[f.body]
self.assertIsInstance(scope_of_f, FunctionScope)
x = f.params.params[0]
one = ensure_type(x.default, cst.BaseExpression)
vararg = ensure_type(f.params.star_arg, cst.Param)
y = f.params.kwonly_params[0]
two = ensure_type(y.default, cst.BaseExpression)
z = f.params.kwonly_params[1]
kwarg = ensure_type(f.params.star_kwarg, cst.Param)
self.assertEqual(scopes[x], scope_of_f)
self.assertEqual(scopes[one], scope_of_module)
self.assertEqual(scopes[vararg], scope_of_f)
self.assertEqual(scopes[y], scope_of_f)
self.assertEqual(scopes[z], scope_of_f)
self.assertEqual(scopes[two], scope_of_module)
self.assertEqual(scopes[kwarg], scope_of_f)
self.assertTrue("x" not in scope_of_module)
self.assertTrue("x" in scope_of_f)
self.assertTrue("vararg" not in scope_of_module)
self.assertTrue("vararg" in scope_of_f)
self.assertTrue("y" not in scope_of_module)
self.assertTrue("y" in scope_of_f)
self.assertTrue("z" not in scope_of_module)
self.assertTrue("z" in scope_of_f)
self.assertTrue("kwarg" not in scope_of_module)
self.assertTrue("kwarg" in scope_of_f)
self.assertEqual(cast(Assignment, list(scope_of_f["x"])[0]).node, x)
self.assertEqual(cast(Assignment, list(scope_of_f["vararg"])[0]).node, vararg)
self.assertEqual(cast(Assignment, list(scope_of_f["y"])[0]).node, y)
self.assertEqual(cast(Assignment, list(scope_of_f["z"])[0]).node, z)
self.assertEqual(cast(Assignment, list(scope_of_f["kwarg"])[0]).node, kwarg)
def test_except_handler(self) -> None:
"""
The ``except as`` is a special case. The asname is only available in the excep body
block and it'll be removed when existing the block.
See https://docs.python.org/3.4/reference/compound_stmts.html#except
We don't create a new block for except body because we don't handle del in our Scope Analysis.
"""
m, scopes = get_scope_metadata_provider(
"""
try:
...
except Exception as ex:
...
"""
)
scope_of_module = scopes[m]
self.assertIsInstance(scope_of_module, GlobalScope)
self.assertTrue("ex" in scope_of_module)
self.assertEqual(
cast(Assignment, list(scope_of_module["ex"])[0]).node,
ensure_type(
ensure_type(m.body[0], cst.Try).handlers[0].name, cst.AsName
).name,
)
def test_with_asname(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
with open(file_name) as f:
...
"""
)
scope_of_module = scopes[m]
self.assertIsInstance(scope_of_module, GlobalScope)
self.assertTrue("f" in scope_of_module)
self.assertEqual(
cast(Assignment, list(scope_of_module["f"])[0]).node,
ensure_type(
ensure_type(m.body[0], cst.With).items[0].asname, cst.AsName
).name,
)
def test_get_qualified_names_for(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
from a.b import c
class Cls:
def f(self) -> "c":
c()
d = {}
d['key'] = 0
def g():
pass
g()
"""
)
cls = ensure_type(m.body[1], cst.ClassDef)
f = ensure_type(cls.body.body[0], cst.FunctionDef)
scope_of_module = scopes[m]
self.assertEqual(
scope_of_module.get_qualified_names_for(
ensure_type(f.returns, cst.Annotation).annotation
),
set(),
"Get qualified name given a SimpleString type annotation is not supported",
)
c_call = ensure_type(
ensure_type(f.body.body[0], cst.SimpleStatementLine).body[0], cst.Expr
).value
scope_of_f = scopes[c_call]
self.assertIsInstance(scope_of_f, FunctionScope)
self.assertEqual(
scope_of_f.get_qualified_names_for(c_call),
{QualifiedName("a.b.c", QualifiedNameSource.IMPORT)},
)
self.assertEqual(
scope_of_f.get_qualified_names_for(c_call),
{QualifiedName("a.b.c", QualifiedNameSource.IMPORT)},
)
g_call = ensure_type(
ensure_type(m.body[3], cst.SimpleStatementLine).body[0], cst.Expr
).value
self.assertIsInstance(scope_of_module, GlobalScope)
self.assertEqual(
scope_of_module.get_qualified_names_for(g_call),
{QualifiedName("g", QualifiedNameSource.LOCAL)},
)
d_name = (
ensure_type(
ensure_type(f.body.body[1], cst.SimpleStatementLine).body[0], cst.Assign
)
.targets[0]
.target
)
self.assertEqual(
scope_of_f.get_qualified_names_for(d_name),
{QualifiedName("Cls.f.<locals>.d", QualifiedNameSource.LOCAL)},
)
d_subscript = (
ensure_type(
ensure_type(f.body.body[2], cst.SimpleStatementLine).body[0], cst.Assign
)
.targets[0]
.target
)
self.assertEqual(
scope_of_f.get_qualified_names_for(d_subscript),
{QualifiedName("Cls.f.<locals>.d", QualifiedNameSource.LOCAL)},
)
for builtin in ["map", "int", "dict"]:
self.assertEqual(
scope_of_f.get_qualified_names_for(cst.Name(value=builtin)),
{QualifiedName(f"builtins.{builtin}", QualifiedNameSource.BUILTIN)},
f"Test builtin: {builtin}.",
)
self.assertEqual(
scope_of_module.get_qualified_names_for(cst.Name(value="d")),
set(),
"Test variable d in global scope.",
)
def test_get_qualified_names_for_nested_cases(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
class A:
def f1(self):
def f2():
pass
f2()
def f3(self):
class B():
...
B()
def f4():
def f5():
class C:
pass
C()
f5()
"""
)
cls_a = ensure_type(m.body[0], cst.ClassDef)
func_f1 = ensure_type(cls_a.body.body[0], cst.FunctionDef)
scope_of_cls_a = scopes[func_f1]
self.assertIsInstance(scope_of_cls_a, ClassScope)
self.assertEqual(
scope_of_cls_a.get_qualified_names_for(func_f1),
{QualifiedName("A.f1", QualifiedNameSource.LOCAL)},
)
func_f2_call = ensure_type(
ensure_type(func_f1.body.body[1], cst.SimpleStatementLine).body[0], cst.Expr
).value
scope_of_f1 = scopes[func_f2_call]
self.assertIsInstance(scope_of_f1, FunctionScope)
self.assertEqual(
scope_of_f1.get_qualified_names_for(func_f2_call),
{QualifiedName("A.f1.<locals>.f2", QualifiedNameSource.LOCAL)},
)
func_f3 = ensure_type(cls_a.body.body[1], cst.FunctionDef)
call_b = ensure_type(
ensure_type(func_f3.body.body[1], cst.SimpleStatementLine).body[0], cst.Expr
).value
scope_of_f3 = scopes[call_b]
self.assertIsInstance(scope_of_f3, FunctionScope)
self.assertEqual(
scope_of_f3.get_qualified_names_for(call_b),
{QualifiedName("A.f3.<locals>.B", QualifiedNameSource.LOCAL)},
)
func_f4 = ensure_type(m.body[1], cst.FunctionDef)
func_f5 = ensure_type(func_f4.body.body[0], cst.FunctionDef)
scope_of_f4 = scopes[func_f5]
self.assertIsInstance(scope_of_f4, FunctionScope)
self.assertEqual(
scope_of_f4.get_qualified_names_for(func_f5),
{QualifiedName("f4.<locals>.f5", QualifiedNameSource.LOCAL)},
)
cls_c = func_f5.body.body[0]
scope_of_f5 = scopes[cls_c]
self.assertIsInstance(scope_of_f5, FunctionScope)
self.assertEqual(
scope_of_f5.get_qualified_names_for(cls_c),
{QualifiedName("f4.<locals>.f5.<locals>.C", QualifiedNameSource.LOCAL)},
)
def test_get_qualified_names_for_the_same_prefix(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
from a import b, bc
bc()
"""
)
call = ensure_type(
ensure_type(
ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Expr
).value,
cst.Call,
)
module_scope = scopes[m]
self.assertEqual(
module_scope.get_qualified_names_for(call.func),
{QualifiedName("a.bc", QualifiedNameSource.IMPORT)},
)
def test_get_qualified_names_for_dotted_imports(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
import a.b.c
a(a.b.d)
"""
)
call = ensure_type(
ensure_type(
ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Expr
).value,
cst.Call,
)
module_scope = scopes[m]
self.assertEqual(
module_scope.get_qualified_names_for(call.func),
{QualifiedName("a", QualifiedNameSource.IMPORT)},
)
self.assertEqual(
module_scope.get_qualified_names_for(call.args[0].value),
{QualifiedName("a.b.d", QualifiedNameSource.IMPORT)},
)
import_stmt = ensure_type(
ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.Import
)
a_b_c = ensure_type(import_stmt.names[0].name, cst.Attribute)
a_b = ensure_type(a_b_c.value, cst.Attribute)
a = a_b.value
self.assertEqual(
module_scope.get_qualified_names_for(a_b_c),
{QualifiedName("a.b.c", QualifiedNameSource.IMPORT)},
)
self.assertEqual(
module_scope.get_qualified_names_for(a_b),
{QualifiedName("a.b", QualifiedNameSource.IMPORT)},
)
self.assertEqual(
module_scope.get_qualified_names_for(a),
{QualifiedName("a", QualifiedNameSource.IMPORT)},
)
def test_multiple_assignments(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
if 1:
from a import b as c
elif 2:
from d import e as c
c()
"""
)
call = ensure_type(
ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Expr
).value
scope = scopes[call]
self.assertIsInstance(scope, GlobalScope)
self.assertEqual(
scope.get_qualified_names_for(call),
{
QualifiedName(name="a.b", source=QualifiedNameSource.IMPORT),
QualifiedName(name="d.e", source=QualifiedNameSource.IMPORT),
},
)
self.assertEqual(
scope.get_qualified_names_for("c"),
{
QualifiedName(name="a.b", source=QualifiedNameSource.IMPORT),
QualifiedName(name="d.e", source=QualifiedNameSource.IMPORT),
},
)
def test_assignments_and_accesses(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
a = 1
def f():
a = 2
a, b
def g():
b = a
a
"""
)
a_outer_assign = (
ensure_type(
ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.Assign
)
.targets[0]
.target
)
a_outer_access = ensure_type(
ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.Expr
).value
scope_of_module = scopes[a_outer_assign]
a_outer_assignments = scope_of_module.assignments[a_outer_access]
self.assertEqual(len(a_outer_assignments), 1)
a_outer_assignment = list(a_outer_assignments)[0]
self.assertEqual(cast(Assignment, a_outer_assignment).node, a_outer_assign)
self.assertEqual(
{i.node for i in a_outer_assignment.references}, {a_outer_access}
)
a_outer_assesses = scope_of_module.accesses[a_outer_assign]
self.assertEqual(len(a_outer_assesses), 1)
self.assertEqual(list(a_outer_assesses)[0].node, a_outer_access)
self.assertEqual(
{cast(Assignment, i).node for i in list(a_outer_assesses)[0].referents},
{a_outer_assign},
)
self.assertTrue(a_outer_assign in scope_of_module.accesses)
self.assertTrue(a_outer_assign in scope_of_module.assignments)
self.assertTrue(a_outer_access in scope_of_module.accesses)
self.assertTrue(a_outer_access in scope_of_module.assignments)
f = ensure_type(m.body[1], cst.FunctionDef)
a_inner_assign = (
ensure_type(
ensure_type(
ensure_type(f.body, cst.IndentedBlock).body[0],
cst.SimpleStatementLine,
).body[0],
cst.Assign,
)
.targets[0]
.target
)
scope_of_f = scopes[a_inner_assign]
a_inner_assignments = scope_of_f.assignments["a"]
self.assertEqual(len(a_inner_assignments), 1)
self.assertEqual(
cast(Assignment, list(a_inner_assignments)[0]).node, a_inner_assign
)
tup = ensure_type(
ensure_type(
ensure_type(
ensure_type(f.body, cst.IndentedBlock).body[1],
cst.SimpleStatementLine,
).body[0],
cst.Expr,
).value,
cst.Tuple,
)
a_inner_access = tup.elements[0].value
b_inner_access = tup.elements[1].value
all_inner_accesses = [i for i in scope_of_f.accesses]
self.assertEqual(len(all_inner_accesses), 2)
self.assertEqual(
{i.node for i in all_inner_accesses}, {a_inner_access, b_inner_access}
)
g = ensure_type(ensure_type(f.body, cst.IndentedBlock).body[2], cst.FunctionDef)
inner_most_assign = ensure_type(
ensure_type(g.body.body[0], cst.SimpleStatementLine).body[0], cst.Assign
)
b_inner_most_assign = inner_most_assign.targets[0].target
a_inner_most_access = inner_most_assign.value
scope_of_g = scopes[b_inner_most_assign]
self.assertEqual({i.node for i in scope_of_g.accesses}, {a_inner_most_access})
self.assertEqual(
{cast(Assignment, i).node for i in scope_of_g.assignments},
{b_inner_most_assign},
)
self.assertEqual(len(set(scopes.values())), 3)
def test_annotation_access(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
from typing import Literal, NewType, Optional, TypeVar, Callable, cast
from a import A, B, C, D, D2, E, E2, F, G, G2, H, I, J, K, K2, L, M
def x(a: A):
pass
def y(b: "B"):
pass
def z(c: Literal["C"]):
pass
DType = TypeVar("D2", bound=D)
EType = TypeVar("E2", bound="E")
FType = TypeVar("F")
GType = NewType("G2", "Optional[G]")
HType = Optional["H"]
IType = Callable[..., I]
class Test(Generic[J]):
pass
castedK = cast("K", "K2")
castedL = cast("L", M)
"""
)
imp = ensure_type(
ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.ImportFrom
)
scope = scopes[imp]
assignment = list(scope["A"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertTrue(references[0].is_annotation)
assignment = list(scope["B"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertTrue(references[0].is_annotation)
reference_node = references[0].node
self.assertIsInstance(reference_node, cst.SimpleString)
if isinstance(reference_node, cst.SimpleString):
self.assertEqual(reference_node.evaluated_value, "B")
assignment = list(scope["C"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 0)
assignment = list(scope["D"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertFalse(references[0].is_annotation)
self.assertTrue(references[0].is_type_hint)
assignment = list(scope["D2"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 0)
assignment = list(scope["E"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertFalse(references[0].is_annotation)
self.assertTrue(references[0].is_type_hint)
reference_node = references[0].node
self.assertIsInstance(reference_node, cst.SimpleString)
if isinstance(reference_node, cst.SimpleString):
self.assertEqual(reference_node.evaluated_value, "E")
assignment = list(scope["E2"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 0)
assignment = list(scope["F"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 0)
assignment = list(scope["G"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertFalse(references[0].is_annotation)
self.assertTrue(references[0].is_type_hint)
reference_node = references[0].node
self.assertIsInstance(reference_node, cst.SimpleString)
if isinstance(reference_node, cst.SimpleString):
self.assertEqual(reference_node.evaluated_value, "Optional[G]")
assignment = list(scope["G2"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 0)
assignment = list(scope["H"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertFalse(references[0].is_annotation)
self.assertTrue(references[0].is_type_hint)
reference_node = references[0].node
self.assertIsInstance(reference_node, cst.SimpleString)
if isinstance(reference_node, cst.SimpleString):
self.assertEqual(reference_node.evaluated_value, "H")
assignment = list(scope["I"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertFalse(references[0].is_annotation)
assignment = list(scope["J"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertFalse(references[0].is_annotation)
assignment = list(scope["K"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertFalse(references[0].is_annotation)
reference_node = references[0].node
self.assertIsInstance(reference_node, cst.SimpleString)
if isinstance(reference_node, cst.SimpleString):
self.assertEqual(reference_node.evaluated_value, "K")
assignment = list(scope["K2"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 0)
assignment = list(scope["L"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
reference_node = references[0].node
self.assertIsInstance(reference_node, cst.SimpleString)
if isinstance(reference_node, cst.SimpleString):
self.assertEqual(reference_node.evaluated_value, "L")
assignment = list(scope["M"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
def test_insane_annotation_access(self) -> None:
m, scopes = get_scope_metadata_provider(
r"""
from typing import TypeVar, Optional
from a import G
TypeVar("G2", bound="Optional[\"G\"]")
"""
)
imp = ensure_type(
ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.ImportFrom
)
call = ensure_type(
ensure_type(
ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.Expr
).value,
cst.Call,
)
bound = call.args[1].value
scope = scopes[imp]
assignment = next(iter(scope["G"]))
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
self.assertEqual(list(assignment.references)[0].node, bound)
def test_dotted_annotation_access(self) -> None:
m, scopes = get_scope_metadata_provider(
r"""
from typing import TypeVar
import a.G
TypeVar("G2", bound="a.G")
"""
)
imp = ensure_type(
ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Import
)
call = ensure_type(
ensure_type(
ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.Expr
).value,
cst.Call,
)
bound = call.args[1].value
scope = scopes[imp]
assignment = next(iter(scope["a.G"]))
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
self.assertEqual(list(assignment.references)[0].node, bound)
def test_node_of_scopes(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
def f1():
target()
class C:
attr = target()
"""
)
f1 = ensure_type(m.body[0], cst.FunctionDef)
target_call = ensure_type(
ensure_type(f1.body.body[0], cst.SimpleStatementLine).body[0], cst.Expr
).value
f1_scope = scopes[target_call]
self.assertIsInstance(f1_scope, FunctionScope)
self.assertEqual(cast(FunctionScope, f1_scope).node, f1)
c = ensure_type(m.body[1], cst.ClassDef)
target_call_2 = ensure_type(
ensure_type(c.body.body[0], cst.SimpleStatementLine).body[0], cst.Assign
).value
c_scope = scopes[target_call_2]
self.assertIsInstance(c_scope, ClassScope)
self.assertEqual(cast(ClassScope, c_scope).node, c)
def test_with_statement(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
import unittest.mock
with unittest.mock.patch("something") as obj:
obj.f1()
unittest.mock
"""
)
import_ = ensure_type(m.body[0], cst.SimpleStatementLine).body[0]
assignments = scopes[import_]["unittest"]
self.assertEqual(len(assignments), 1)
self.assertEqual(cast(Assignment, list(assignments)[0]).node, import_)
with_ = ensure_type(m.body[1], cst.With)
fn_call = with_.items[0].item
self.assertEqual(
scopes[fn_call].get_qualified_names_for(fn_call),
{
QualifiedName(
name="unittest.mock.patch", source=QualifiedNameSource.IMPORT
)
},
)
mock = ensure_type(
ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.Expr
).value
self.assertEqual(
scopes[fn_call].get_qualified_names_for(mock),
{QualifiedName(name="unittest.mock", source=QualifiedNameSource.IMPORT)},
)
def test_del_context_names(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
import a
dic = {}
del dic
del dic["key"]
del a.b
"""
)
dic = ensure_type(
ensure_type(
ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Assign
).targets[0],
cst.AssignTarget,
).target
del_dic = ensure_type(
ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.Del
)
scope = scopes[del_dic]
assignments = list(scope["dic"])
self.assertEqual(len(assignments), 1)
dic_assign = assignments[0]
self.assertIsInstance(dic_assign, Assignment)
self.assertEqual(cast(Assignment, dic_assign).node, dic)
self.assertEqual(len(dic_assign.references), 2)
del_dic_subscript = ensure_type(
ensure_type(
ensure_type(m.body[3], cst.SimpleStatementLine).body[0], cst.Del
).target,
cst.Subscript,
)
self.assertSetEqual(
{i.node for i in dic_assign.references},
{del_dic.target, del_dic_subscript.value},
)
assignments = list(scope["a"])
self.assertEqual(len(assignments), 1)
a_assign = assignments[0]
self.assertIsInstance(a_assign, Assignment)
import_a = ensure_type(m.body[0], cst.SimpleStatementLine).body[0]
self.assertEqual(cast(Assignment, a_assign).node, import_a)
self.assertEqual(len(a_assign.references), 1)
del_a_b = ensure_type(
ensure_type(m.body[4], cst.SimpleStatementLine).body[0], cst.Del
)
self.assertEqual(
{i.node for i in a_assign.references},
{ensure_type(del_a_b.target, cst.Attribute).value},
)
self.assertEqual(scope["b"], set())
def test_keyword_arg_in_call(self) -> None:
m, scopes = get_scope_metadata_provider("call(arg=val)")
call = ensure_type(
ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.Expr
).value
scope = scopes[call]
self.assertIsInstance(scope, GlobalScope)
self.assertEqual(len(scope["arg"]), 0) # no assignment should exist
def test_global_contains_is_read_only(self) -> None:
gscope = GlobalScope()
before_assignments = list(gscope.assignments)
before_accesses = list(gscope.accesses)
self.assertFalse("doesnt_exist" in gscope)
self.assertEqual(list(gscope.accesses), before_accesses)
self.assertEqual(list(gscope.assignments), before_assignments)
def test_contains_is_read_only(self) -> None:
for s in [LocalScope, FunctionScope, ClassScope, ComprehensionScope]:
with self.subTest(scope=s):
gscope = GlobalScope()
scope = s(parent=gscope, node=cst.Name("lol"))
before_assignments = list(scope.assignments)
before_accesses = list(scope.accesses)
before_overwrites = list(scope._scope_overwrites.items())
before_parent_assignments = list(scope.parent.assignments)
before_parent_accesses = list(scope.parent.accesses)
self.assertFalse("doesnt_exist" in scope)
self.assertEqual(list(scope.accesses), before_accesses)
self.assertEqual(list(scope.assignments), before_assignments)
self.assertEqual(
list(scope._scope_overwrites.items()), before_overwrites
)
self.assertEqual(
list(scope.parent.assignments), before_parent_assignments
)
self.assertEqual(list(scope.parent.accesses), before_parent_accesses)
def test_attribute_of_function_call(self) -> None:
get_scope_metadata_provider("foo().bar")
def test_attribute_of_subscript_called(self) -> None:
m, scopes = get_scope_metadata_provider("foo[0].bar.baz()")
scope = scopes[m]
self.assertIn("foo", scope.accesses)
def test_self(self) -> None:
with open(__file__) as f:
get_scope_metadata_provider(f.read())
def test_get_qualified_names_for_is_read_only(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
import a
import b
"""
)
a = m.body[0]
scope = scopes[a]
assignments_before = list(scope.assignments)
accesses_before = list(scope.accesses)
scope.get_qualified_names_for("doesnt_exist")
self.assertEqual(list(scope.assignments), assignments_before)
self.assertEqual(list(scope.accesses), accesses_before)
def test_gen_dotted_names(self) -> None:
names = {name for name, node in _gen_dotted_names(cst.Name(value="a"))}
self.assertEqual(names, {"a"})
names = {
name
for name, node in _gen_dotted_names(
cst.Attribute(value=cst.Name(value="a"), attr=cst.Name(value="b"))
)
}
self.assertEqual(names, {"a.b", "a"})
names = {
name
for name, node in _gen_dotted_names(
cst.Attribute(
value=cst.Call(
func=cst.Attribute(
value=cst.Attribute(
value=cst.Name(value="a"), attr=cst.Name(value="b")
),
attr=cst.Name(value="c"),
),
args=[],
),
attr=cst.Name(value="d"),
)
)
}
self.assertEqual(names, {"a.b.c", "a.b", "a"})
def test_ordering(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
from a import b
class X:
x = b
b = b
y = b
"""
)
global_scope = scopes[m]
import_stmt = ensure_type(
ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.ImportFrom
)
first_assignment = list(global_scope.assignments)[0]
assert isinstance(first_assignment, cst.metadata.Assignment)
self.assertEqual(first_assignment.node, import_stmt)
global_refs = first_assignment.references
self.assertEqual(len(global_refs), 2)
global_refs_nodes = {x.node for x in global_refs}
class_def = ensure_type(m.body[1], cst.ClassDef)
x = ensure_type(
ensure_type(class_def.body.body[0], cst.SimpleStatementLine).body[0],
cst.Assign,
)
self.assertIn(x.value, global_refs_nodes)
class_b = ensure_type(
ensure_type(class_def.body.body[1], cst.SimpleStatementLine).body[0],
cst.Assign,
)
self.assertIn(class_b.value, global_refs_nodes)
class_accesses = list(scopes[x].accesses)
self.assertEqual(len(class_accesses), 3)
self.assertIn(
class_b.targets[0].target,
[
ref.node
for acc in class_accesses
for ref in acc.referents
if isinstance(ref, Assignment)
],
)
y = ensure_type(
ensure_type(class_def.body.body[2], cst.SimpleStatementLine).body[0],
cst.Assign,
)
self.assertIn(y.value, [access.node for access in class_accesses])
def test_ordering_between_scopes(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
def f(a):
print(a)
print(b)
a = 1
b = 1
"""
)
f = cst.ensure_type(m.body[0], cst.FunctionDef)
a_param = f.params.params[0].name
a_param_assignment = list(scopes[a_param]["a"])[0]
a_param_refs = list(a_param_assignment.references)
first_print = cst.ensure_type(
cst.ensure_type(
cst.ensure_type(f.body.body[0], cst.SimpleStatementLine).body[0],
cst.Expr,
).value,
cst.Call,
)
second_print = cst.ensure_type(
cst.ensure_type(
cst.ensure_type(f.body.body[1], cst.SimpleStatementLine).body[0],
cst.Expr,
).value,
cst.Call,
)
self.assertEqual(
first_print.args[0].value,
a_param_refs[0].node,
)
a_global = (
cst.ensure_type(
cst.ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Assign
)
.targets[0]
.target
)
a_global_assignment = list(scopes[a_global]["a"])[0]
a_global_refs = list(a_global_assignment.references)
self.assertEqual(a_global_refs, [])
b_global = (
cst.ensure_type(
cst.ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.Assign
)
.targets[0]
.target
)
b_global_assignment = list(scopes[b_global]["b"])[0]
b_global_refs = list(b_global_assignment.references)
self.assertEqual(len(b_global_refs), 1)
self.assertEqual(b_global_refs[0].node, second_print.args[0].value)
def test_ordering_comprehension(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
def f(a):
[a for a in [] for b in a]
[b for a in [] for b in a]
[a for a in [] for a in []]
a = 1
"""
)
f = cst.ensure_type(m.body[0], cst.FunctionDef)
a_param = f.params.params[0].name
a_param_assignment = list(scopes[a_param]["a"])[0]
a_param_refs = list(a_param_assignment.references)
self.assertEqual(a_param_refs, [])
first_comp = cst.ensure_type(
cst.ensure_type(
cst.ensure_type(f.body.body[0], cst.SimpleStatementLine).body[0],
cst.Expr,
).value,
cst.ListComp,
)
a_comp_assignment = list(scopes[first_comp.elt]["a"])[0]
self.assertEqual(len(a_comp_assignment.references), 2)
self.assertIn(
first_comp.elt, [ref.node for ref in a_comp_assignment.references]
)
second_comp = cst.ensure_type(
cst.ensure_type(
cst.ensure_type(f.body.body[1], cst.SimpleStatementLine).body[0],
cst.Expr,
).value,
cst.ListComp,
)
b_comp_assignment = list(scopes[second_comp.elt]["b"])[0]
self.assertEqual(len(b_comp_assignment.references), 1)
a_second_comp_assignment = list(scopes[second_comp.elt]["a"])[0]
self.assertEqual(len(a_second_comp_assignment.references), 1)
third_comp = cst.ensure_type(
cst.ensure_type(
cst.ensure_type(f.body.body[2], cst.SimpleStatementLine).body[0],
cst.Expr,
).value,
cst.ListComp,
)
a_third_comp_assignments = list(scopes[third_comp.elt]["a"])
self.assertEqual(len(a_third_comp_assignments), 2)
a_third_comp_access = list(scopes[third_comp.elt].accesses)[0]
self.assertEqual(a_third_comp_access.node, third_comp.elt)
# We record both assignments because it's impossible to know which one
# the access refers to without running the program
self.assertEqual(len(a_third_comp_access.referents), 2)
inner_for_in = third_comp.for_in.inner_for_in
self.assertIsNotNone(inner_for_in)
if inner_for_in:
self.assertIn(
inner_for_in.target,
{
ref.node
for ref in a_third_comp_access.referents
if isinstance(ref, Assignment)
},
)
a_global = (
cst.ensure_type(
cst.ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Assign
)
.targets[0]
.target
)
a_global_assignment = list(scopes[a_global]["a"])[0]
a_global_refs = list(a_global_assignment.references)
self.assertEqual(a_global_refs, [])
def test_ordering_comprehension_confusing(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
def f(a):
[a for a in a]
a = 1
"""
)
f = cst.ensure_type(m.body[0], cst.FunctionDef)
a_param = f.params.params[0].name
a_param_assignment = list(scopes[a_param]["a"])[0]
a_param_refs = list(a_param_assignment.references)
self.assertEqual(len(a_param_refs), 1)
comp = cst.ensure_type(
cst.ensure_type(
cst.ensure_type(f.body.body[0], cst.SimpleStatementLine).body[0],
cst.Expr,
).value,
cst.ListComp,
)
a_comp_assignment = list(scopes[comp.elt]["a"])[0]
self.assertEqual(list(a_param_refs)[0].node, comp.for_in.iter)
self.assertEqual(len(a_comp_assignment.references), 1)
self.assertEqual(list(a_comp_assignment.references)[0].node, comp.elt)
def test_for_scope_ordering(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
def f():
for x in []:
x
class X:
def f():
for x in []:
x
"""
)
for scope in scopes.values():
for acc in scope.accesses:
self.assertEqual(
len(acc.referents),
1,
msg=(
"Access for node has incorrect number of referents: "
+ f"{acc.node}"
),
)
def test_no_out_of_order_references_in_global_scope(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
x = y
y = 1
"""
)
for scope in scopes.values():
for acc in scope.accesses:
self.assertEqual(
len(acc.referents),
0,
msg=(
"Access for node has incorrect number of referents: "
+ f"{acc.node}"
),
)
def test_walrus_accesses(self) -> None:
if sys.version_info < (3, 8):
self.skipTest("This python version doesn't support :=")
m, scopes = get_scope_metadata_provider(
"""
if x := y:
y = 1
x
"""
)
for scope in scopes.values():
for acc in scope.accesses:
self.assertEqual(
len(acc.referents),
1 if getattr(acc.node, "value", None) == "x" else 0,
msg=(
"Access for node has incorrect number of referents: "
+ f"{acc.node}"
),
)
@data_provider(
{
"TypeVar": {
"code": """
from typing import TypeVar
TypeVar("Name", "int")
""",
"calls": [mock.call("int")],
},
"Dict": {
"code": """
from typing import Dict
Dict["str", "int"]
""",
"calls": [mock.call("str"), mock.call("int")],
},
"cast_no_annotation": {
"code": """
from typing import Dict, cast
cast(Dict[str, str], {})["3rr0r"]
""",
"calls": [],
},
"cast_second_arg": {
"code": """
from typing import cast
cast(str, "foo")
""",
"calls": [],
},
"cast_first_arg": {
"code": """
from typing import cast
cast("int", "foo")
""",
"calls": [
mock.call("int"),
],
},
"typevar_func": {
"code": """
from typing import TypeVar
TypeVar("Name", func("int"))
""",
"calls": [],
},
"literal": {
"code": """
from typing import Literal
Literal[\"G\"]
""",
"calls": [],
},
"nested_str": {
"code": r"""
from typing import TypeVar, Optional
from a import G
TypeVar("G2", bound="Optional[\"G\"]")
""",
"calls": [mock.call('Optional["G"]'), mock.call("G")],
},
"class_self_ref": {
"code": """
from typing import TypeVar
class HelperClass:
value: TypeVar("THelperClass", bound="HelperClass")
""",
"calls": [mock.call("HelperClass")],
},
}
)
def test_parse_string_annotations(
self, *, code: str, calls: Sequence[mock._Call]
) -> None:
parse = cst.parse_module
with mock.patch("libcst.parse_module") as parse_mock:
parse_mock.side_effect = parse
get_scope_metadata_provider(dedent(code))
calls = [mock.call(dedent(code))] + list(calls)
self.assertEqual(parse_mock.call_count, len(calls))
parse_mock.assert_has_calls(calls)
def test_builtin_scope(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
a = pow(1, 2)
def foo():
b = pow(2, 3)
"""
)
scope_of_module = scopes[m]
self.assertIsInstance(scope_of_module, GlobalScope)
self.assertEqual(len(scope_of_module["pow"]), 1)
builtin_pow_assignment = list(scope_of_module["pow"])[0]
self.assertIsInstance(builtin_pow_assignment, BuiltinAssignment)
self.assertIsInstance(builtin_pow_assignment.scope, BuiltinScope)
global_a_assignments = scope_of_module["a"]
self.assertEqual(len(global_a_assignments), 1)
a_assignment = list(global_a_assignments)[0]
self.assertIsInstance(a_assignment, Assignment)
func_body = ensure_type(m.body[1], cst.FunctionDef).body
func_statement = func_body.body[0]
scope_of_func_statement = scopes[func_statement]
self.assertIsInstance(scope_of_func_statement, FunctionScope)
func_b_assignments = scope_of_func_statement["b"]
self.assertEqual(len(func_b_assignments), 1)
b_assignment = list(func_b_assignments)[0]
self.assertIsInstance(b_assignment, Assignment)
builtin_pow_accesses = list(builtin_pow_assignment.references)
self.assertEqual(len(builtin_pow_accesses), 2)
def test_override_builtin_scope(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
def pow(x, y):
return x ** y
a = pow(1, 2)
def foo():
b = pow(2, 3)
"""
)
scope_of_module = scopes[m]
self.assertIsInstance(scope_of_module, GlobalScope)
self.assertEqual(len(scope_of_module["pow"]), 1)
global_pow_assignment = list(scope_of_module["pow"])[0]
self.assertIsInstance(global_pow_assignment, Assignment)
self.assertIsInstance(global_pow_assignment.scope, GlobalScope)
global_a_assignments = scope_of_module["a"]
self.assertEqual(len(global_a_assignments), 1)
a_assignment = list(global_a_assignments)[0]
self.assertIsInstance(a_assignment, Assignment)
func_body = ensure_type(m.body[2], cst.FunctionDef).body
func_statement = func_body.body[0]
scope_of_func_statement = scopes[func_statement]
self.assertIsInstance(scope_of_func_statement, FunctionScope)
func_b_assignments = scope_of_func_statement["b"]
self.assertEqual(len(func_b_assignments), 1)
b_assignment = list(func_b_assignments)[0]
self.assertIsInstance(b_assignment, Assignment)
global_pow_accesses = list(global_pow_assignment.references)
self.assertEqual(len(global_pow_accesses), 2)
def test_annotation_access_in_typevar_bound(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
from typing import TypeVar
class Test:
var: TypeVar("T", bound="Test")
"""
)
imp = ensure_type(
ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.ImportFrom
)
scope = scopes[imp]
assignment = list(scope["Test"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertTrue(references[0].is_annotation)
def test_prefix_match(self) -> None:
"""Verify that a name doesn't overmatch on prefix"""
m, scopes = get_scope_metadata_provider(
"""
def something():
...
"""
)
scope = scopes[m]
self.assertEqual(
scope.get_qualified_names_for(cst.Name("something")),
{QualifiedName(name="something", source=QualifiedNameSource.LOCAL)},
)
self.assertEqual(
scope.get_qualified_names_for(cst.Name("something_else")),
set(),
)