Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
pytype / pytd / visitors.py
Size: Mime:
"""Visitor(s) for walking ASTs."""

import collections
import itertools
import logging
import re
from typing import Set

from pytype import datatypes
from pytype import module_utils
from pytype import utils
from pytype.pytd import base_visitor
from pytype.pytd import escape
from pytype.pytd import mro
from pytype.pytd import pep484
from pytype.pytd import printer
from pytype.pytd import pytd
from pytype.pytd import pytd_utils
from pytype.pytd import pytd_visitors
from pytype.pytd.parse import parser_constants  # pylint: disable=g-importing-member


class ContainerError(Exception):
  """Error class exported by this module.

  NOTE(review): not raised in the visible part of this file; presumably it
  signals a problem with a container type — confirm at the raise sites.
  """


class SymbolLookupError(Exception):
  """Raised when a local symbol cannot be found or does not denote a type."""


# The public visitors of base_visitor, pytd_visitors and printer are aliased
# here so that we can maintain the conceptually simpler illusion of having a
# single visitors module.

# Core visitor machinery.
ALL_NODE_NAMES = base_visitor.ALL_NODE_NAMES
Visitor = base_visitor.Visitor

# Visitors that live in pytd_visitors.
CanonicalOrderingVisitor = pytd_visitors.CanonicalOrderingVisitor
ClassTypeToNamedType = pytd_visitors.ClassTypeToNamedType
CollectTypeParameters = pytd_visitors.CollectTypeParameters
ExtractSuperClasses = pytd_visitors.ExtractSuperClasses
RenameModuleVisitor = pytd_visitors.RenameModuleVisitor

# Visitor that lives in printer.
PrintVisitor = printer.PrintVisitor


class FillInLocalPointers(Visitor):
  """Fill in ClassType pointers using symbol tables.

  This is an in-place visitor! It modifies the original tree. This is
  necessary because we introduce loops.
  """

  def __init__(self, lookup_map, fallback=None):
    """Create this visitor.

    You're expected to then pass this instance to node.Visit().

    Args:
      lookup_map: A map from module names to symbol tables (i.e., objects that
        have a "Lookup" function). For undotted names the key "" is tried
        first. The map is not modified by this visitor.
      fallback: A symbol table to be tried if lookup otherwise fails. It takes
        the place of any "*" entry in lookup_map.
    """
    super().__init__()
    self._lookup_map = lookup_map
    # Kept separately instead of being written into lookup_map["*"], so the
    # caller's dict is not mutated and a later visitor created from the same
    # dict without a fallback does not silently inherit this one.
    self._fallback = fallback

  def _GetSymbolTable(self, module):
    """Returns the symbol table registered for `module`, or None."""
    if module == "*" and self._fallback is not None:
      return self._fallback
    return self._lookup_map.get(module)

  def _Lookup(self, node):
    """Look up a node by name.

    Yields:
      (prefix, item) pairs, in priority order, for every symbol table in which
      `prefix + node.name` can be found.
    """
    if "." in node.name:
      # Try every enclosing dotted prefix, longest first: "a.b.C" tries the
      # tables for "a.b" and then "a".
      modules_to_try = []
      module = node.name
      while "." in module:
        module, _, _ = module.rpartition(".")
        modules_to_try.append(("", module))
    else:
      modules_to_try = [("", ""),
                        ("", "builtins"),
                        ("builtins.", "builtins")]
    # The fallback table ("*") is always tried last.
    modules_to_try += [("", "*"), ("builtins.", "*")]
    for prefix, module in modules_to_try:
      mod_ast = self._GetSymbolTable(module)
      if mod_ast:
        name = prefix + node.name
        mod_prefix = f"{mod_ast.name}."
        try:
          if name.startswith(mod_prefix) and mod_prefix != "builtins.":
            # Names qualified by the table's own module are looked up
            # member-by-member so that nested items resolve.
            item = pytd.LookupItemRecursive(mod_ast, name[len(mod_prefix):])
          else:
            item = mod_ast.Lookup(name)
        except KeyError:
          pass
        else:
          yield prefix, item

  def EnterClassType(self, node):
    """Fills in a class type, in place.

    Args:
      node: A ClassType. This node will have a name, which we use for lookup.

    Sets node.cls to the first lookup result that is a pytd.Class; results
    that are not classes are logged and skipped. If nothing matches, node.cls
    is left as is. Call VerifyLookup() on your tree if you want to be sure that
    all of the cls pointers have been filled in.
    """
    for prefix, cls in self._Lookup(node):
      if isinstance(cls, pytd.Class):
        node.cls = cls
        return
      else:
        logging.warning("Couldn't resolve %s: Not a class: %s",
                        prefix + node.name, type(cls))


class _RemoveTypeParametersFromGenericAny(Visitor):
  """Adjusts GenericType nodes to handle base type changes.

  A GenericType whose base type is AnythingType or a pytd.Constant is replaced
  by that base type, dropping its type parameters. Every other GenericType is
  returned unchanged.
  """

  # VisitGenericType can return a node that is not a GenericType.
  unchecked_node_names = ("GenericType",)

  def VisitGenericType(self, node):
    base = node.base_type
    if not isinstance(base, (pytd.AnythingType, pytd.Constant)):
      return node
    # TODO(rechen): Raise an exception if the base type is a constant whose
    # type isn't Any.
    return base


class DefaceUnresolved(_RemoveTypeParametersFromGenericAny):
  """Replace all types not in a symbol table with AnythingType.

  An undotted name that no symbol table resolves to a pytd.Class becomes
  AnythingType (with a logged warning). Dotted names are always kept.
  """

  def __init__(self, lookup_list, do_not_log_prefix=None):
    """Create this visitor.

    Args:
      lookup_list: An iterable of symbol tables (i.e., objects that have a
        "Lookup" function).
      do_not_log_prefix: If given, don't log error messages for classes with
        this prefix.
    """
    super().__init__()
    self._lookup_list = lookup_list
    self._do_not_log_prefix = do_not_log_prefix

  def _ResolvesToClass(self, name):
    """Whether some symbol table in the lookup list maps `name` to a Class."""
    for table in self._lookup_list:
      try:
        found = table.Lookup(name)
      except KeyError:
        continue
      if isinstance(found, pytd.Class):
        return True
    return False

  def VisitNamedType(self, node):
    """Do replacement on a pytd.NamedType."""
    name = node.name
    if self._ResolvesToClass(name) or "." in name:
      return node
    prefix = self._do_not_log_prefix
    if prefix is None or not name.startswith(prefix):
      logging.warning("Setting %s to Any", name)
    return pytd.AnythingType()

  def VisitCallableType(self, node):
    """Treat a CallableType like any other GenericType."""
    return self.VisitGenericType(node)

  def VisitTupleType(self, node):
    """Treat a TupleType like any other GenericType."""
    return self.VisitGenericType(node)

  def VisitClassType(self, node):
    """Treat a ClassType like a NamedType."""
    return self.VisitNamedType(node)


class NamedTypeToClassType(Visitor):
  """Replaces every NamedType with a ClassType of the same name."""

  def VisitNamedType(self, node):
    """Converts a named type to a (not yet resolved) class type.

    Args:
      node: The NamedType. This type only has a name.

    Returns:
      A pytd.ClassType carrying only node's name; its cls pointer is filled in
      by a later pass (e.g. FillInLocalPointers, as LookupClasses does).
    """
    type_name = node.name
    return pytd.ClassType(type_name)


def LookupClasses(target, global_module=None, ignore_late_types=False):
  """Converts a PyTD object from one using NamedType to ClassType.

  Args:
    target: The PyTD object to process. If this is a TypeDeclUnit it will also
      be used for lookups.
    global_module: Global symbols, used as the fallback symbol table. Required
      if target is not a TypeDeclUnit.
    ignore_late_types: If False (the default), raise an error if we encounter
      a LateType. If True, LateTypes are tolerated and left unresolved.

  Returns:
    A new PyTD object that only uses ClassType. All ClassType instances will
    point to concrete classes.

  Raises:
    ValueError: If we can't find a class, a NamedType is left unreplaced, or
      (unless ignore_late_types is True) a LateType is encountered.
  """
  target = target.Visit(NamedTypeToClassType())
  module_map = {}
  if global_module is None:
    # The unit then doubles as the global (fallback) symbol table.
    assert isinstance(target, pytd.TypeDeclUnit)
    global_module = target
  elif isinstance(target, pytd.TypeDeclUnit):
    # Undotted names are looked up in the "" table before the fallback.
    module_map[""] = target
  # FillInLocalPointers modifies the tree in place, so Visit's return value is
  # not needed.
  target.Visit(FillInLocalPointers(module_map, fallback=global_module))
  target.Visit(VerifyLookup(ignore_late_types))
  return target


class VerifyLookup(Visitor):
  """Utility class for testing visitors.LookupClasses.

  Raises ValueError when it encounters a NamedType, a ClassType whose cls
  pointer is None, or (unless ignore_late_types is set) a LateType.
  """

  def __init__(self, ignore_late_types=False):
    super().__init__()
    self.ignore_late_types = ignore_late_types

  def EnterLateType(self, node):
    if self.ignore_late_types:
      return
    raise ValueError("Unresolved LateType: %r" % node.name)

  def EnterNamedType(self, node):
    # Every NamedType should have been converted to a ClassType by now.
    raise ValueError("Unreplaced NamedType: %r" % node.name)

  def EnterClassType(self, node):
    if node.cls is not None:
      return
    raise ValueError("Unresolved class: %r" % node.name)


class _ToTypeVisitor(Visitor):
  """Visitor for calling pytd.ToType().

  pytd.ToType() normally rejects constants and functions because they are not
  types. Aliases can point to either, and typing.Literal can be parameterized
  by constants, so this visitor tracks whether it is inside an alias and/or a
  literal. Its to_type() method then relaxes pytd.ToType():
    - constants are allowed inside an alias or a literal;
    - functions are allowed inside an alias only.
  """

  def __init__(self, allow_singletons):
    super().__init__()
    self.allow_singletons = allow_singletons
    # Aliases never nest (EnterAlias asserts this), so a flag is enough.
    self._in_alias = False
    # A counter rather than a flag, so nested Literal nodes are tracked.
    self._in_literal = 0

  def EnterAlias(self, _):
    assert not self._in_alias
    self._in_alias = True

  def LeaveAlias(self, _):
    assert self._in_alias
    self._in_alias = False

  def EnterLiteral(self, _):
    self._in_literal = self._in_literal + 1

  def LeaveLiteral(self, _):
    self._in_literal = self._in_literal - 1

  def to_type(self, t):
    """Calls pytd.ToType(t) with permissions matching the current context."""
    in_alias = self._in_alias
    return pytd.ToType(
        t,
        allow_constants=in_alias or self._in_literal,
        allow_functions=in_alias,
        allow_singletons=self.allow_singletons)


class LookupBuiltins(_ToTypeVisitor):
  """Look up built-in NamedTypes and give them fully-qualified names."""

  def __init__(self, builtins, full_names=True, allow_singletons=False):
    """Create this visitor.

    Args:
      builtins: The builtins module.
      full_names: Whether to use fully qualified names for lookup.
      allow_singletons: Whether to allow singleton types like Ellipsis.
    """
    super().__init__(allow_singletons)
    self._builtins = builtins
    self._full_names = full_names

  def EnterTypeDeclUnit(self, unit):
    self._current_unit = unit
    if self._full_names:
      self._prefix = unit.name + "."
    else:
      self._prefix = ""

  def LeaveTypeDeclUnit(self, _):
    del self._current_unit
    del self._prefix

  def VisitNamedType(self, t):
    """Resolves an unqualified, non-local NamedType against the builtins.

    Returns t unchanged if its name is dotted, is defined in the current unit,
    is not a builtin, or names a builtin that cannot be converted to a type.
    """
    name = t.name
    if "." in name or (self._prefix + name) in self._current_unit:
      return t
    # Neither local nor fully qualified: try the builtins module.
    try:
      item = self._builtins.Lookup(self._builtins.name + "." + name)
    except KeyError:
      return t
    try:
      return self.to_type(item)
    except NotImplementedError:
      # This can happen if a builtin is redefined.
      return t


def MaybeSubstituteParameters(base_type, parameters=None):
  """Substitutes parameters into base_type if the latter has a template.

  This handles "type macros": generic types whose type parameters are
  supplied at the point of use.

  Args:
    base_type: A pytd type that may contain type parameters.
    parameters: The replacement types, in template order, or None.

  Returns:
    base_type with its type parameters replaced, or None if base_type has no
    type parameters or `parameters` is None.
    NOTE(review): because of the latter, a call without `parameters` always
    returns None.

  Raises:
    ValueError: If the number of parameters does not match the template.
  """
  template = pytd_utils.GetTypeParameters(base_type)
  if parameters is None or not template:
    return None
  expected, actual = len(template), len(parameters)
  if expected != actual:
    raise ValueError("%s expected %d parameters, got %s" % (
        pytd_utils.Print(base_type), expected, actual))
  substitution = dict(zip(template, parameters))
  return base_type.Visit(ReplaceTypeParameters(substitution))


class LookupExternalTypes(_RemoveTypeParametersFromGenericAny, _ToTypeVisitor):
  """Look up NamedType pointers using a symbol table.

  NamedTypes that refer to other modules are resolved by looking them up in
  `module_map`; names in the current module (`self_name`) are left alone.
  Star imports ("module.*") are recorded while visiting and expanded into
  aliases in VisitTypeDeclUnit. Lookup failures raise KeyError.
  """

  def __init__(self, module_map, self_name=None, module_alias_map=None):
    """Create this visitor.

    Args:
      module_map: A dictionary mapping module names to symbol tables.
      self_name: The name of the current module. If provided, then the visitor
        will ignore nodes with this module name.
      module_alias_map: A dictionary mapping module aliases to real module
        names. If the source contains "import X as Y", module_alias_map should
        contain an entry mapping "Y": "X".
    """
    super().__init__(allow_singletons=False)
    self._module_map = module_map
    self._module_alias_map = module_alias_map or {}
    self.name = self_name
    # Name of the Alias currently being visited, or None.
    self._alias_name = None
    # Nesting depth of the GenericType nodes currently being visited.
    self._in_generic_type = 0
    # Modules referenced as "module.*"; expanded in VisitTypeDeclUnit.
    self._star_imports = set()
    # The TypeDeclUnit being visited; set by EnterTypeDeclUnit.
    self._unit = None

  def _ResolveUsingGetattr(self, module_name, module):
    """Try to resolve an identifier using the top level __getattr__ function.

    Args:
      module_name: The name of `module`.
      module: The symbol table to search.

    Returns:
      The return type of the module's __getattr__ (asserted to have exactly
      one signature), or None if the module has no __getattr__.
    """
    try:
      g = module.Lookup(module_name + ".__getattr__")
    except KeyError:
      return None
    assert len(g.signatures) == 1
    return g.signatures[0].return_type

  def _ResolveUsingStarImport(self, module, name):
    """Try to use any star imports in 'module' to resolve 'name'.

    Returns:
      The alias that _ImportAll would create for 'name' (under this visitor's
      module prefix) from some module that 'module' star-imports, or None.
    """
    wanted_name = self._ModulePrefix() + name
    for alias in module.aliases:
      type_name = alias.type.name
      # Star imports appear as aliases whose type is named "<module>.*".
      if not type_name or not type_name.endswith(".*"):
        continue
      imported_module = type_name[:-2]
      # 'module' contains 'from imported_module import *'. If we can find an AST
      # for imported_module, check whether any of the imported names match the
      # one we want to resolve.
      if imported_module not in self._module_map:
        continue
      imported_aliases, _ = self._ImportAll(imported_module)
      for imported_alias in imported_aliases:
        if imported_alias.name == wanted_name:
          return imported_alias
    return None

  def EnterAlias(self, t):
    super().EnterAlias(t)
    assert not self._alias_name
    self._alias_name = t.name

  def LeaveAlias(self, t):
    super().LeaveAlias(t)
    assert self._alias_name
    self._alias_name = None

  def EnterGenericType(self, _):
    self._in_generic_type += 1

  def LeaveGenericType(self, _):
    self._in_generic_type -= 1

  def _LookupModuleRecursive(self, name):
    """Finds the known module that `name` belongs to.

    Strips trailing dotted components off `name` until the remainder is a key
    of module_map.

    Returns:
      A tuple (module, cls_prefix): the module's symbol table, and the
      stripped components, each followed by "." ("" if nothing was stripped).

    Raises:
      KeyError: If no prefix of `name` is a known module.
    """
    module_name, cls_prefix = name, ""
    while module_name not in self._module_map and "." in module_name:
      module_name, class_name = module_name.rsplit(".", 1)
      cls_prefix = class_name + "." + cls_prefix
    if module_name in self._module_map:
      return self._module_map[module_name], cls_prefix
    else:
      raise KeyError("Unknown module %s" % name)

  def VisitNamedType(self, t):
    """Try to look up a NamedType.

    Args:
      t: An instance of pytd.NamedType
    Returns:
      The resolved type. t itself is returned when it is not a reference into
      another module (undotted, in the current module, or a dotted local
      reference) or when it is a star import; a pytd.Module is returned when
      t names a module inside a dotted alias.
    Raises:
      KeyError: If we can't find a module, or an identifier in a module.
        Errors from converting the found item via to_type() propagate
        unchanged.
    """
    if t.name in self._module_map:
      if self._alias_name and "." in self._alias_name:
        # Module aliases appear only in asts that use fully-qualified names.
        return pytd.Module(name=self._alias_name, module_name=t.name)
      else:
        # We have a class with the same name as a module.
        return t
    module_name, dot, name = t.name.rpartition(".")
    if not dot or module_name == self.name:
      # Nothing to do here. This visitor will only look up nodes in other
      # modules.
      return t
    if module_name in self._module_alias_map:
      module_name = self._module_alias_map[module_name]
    try:
      module, cls_prefix = self._LookupModuleRecursive(module_name)
    except KeyError:
      if self._unit and f"{self.name}.{module_name}" in self._unit:
        # Nothing to do here.This is a dotted local reference.
        return t
      raise
    module_name = module.name
    if module_name == self.name:  # dotted local reference
      return t
    name = cls_prefix + name
    try:
      if name == "*":
        self._star_imports.add(module_name)
        item = t  # VisitTypeDeclUnit will remove this unneeded item.
      else:
        item = pytd.LookupItemRecursive(module, name)
    except KeyError as e:
      # Fallbacks: the module's __getattr__, then its star imports.
      item = self._ResolveUsingGetattr(module_name, module)
      if item is None:
        # If 'module' is involved in a circular dependency, it may contain a
        # star import that has not yet been resolved via the usual mechanism, so
        # we need to manually resolve it here.
        item = self._ResolveUsingStarImport(module, name)
        if item is None:
          raise KeyError("No %s in module %s" % (name, module_name)) from e
    if not self._in_generic_type and isinstance(item, pytd.Alias):
      # If `item` contains type parameters and is not inside a GenericType, then
      # we replace the parameters with Any.
      # NOTE(review): MaybeSubstituteParameters returns None whenever it is
      # called without `parameters`, so as written this line always keeps
      # `item` and no Any substitution happens. Confirm whether the call or
      # the comment above is out of date.
      item = MaybeSubstituteParameters(item.type) or item
    return self.to_type(item)

  def VisitClassType(self, t):
    # Resolve like a NamedType. If the result is still a ClassType, copy its
    # cls pointer onto the original node (in place) and keep that node.
    new_type = self.VisitNamedType(t)
    if isinstance(new_type, pytd.ClassType):
      t.cls = new_type.cls
      return t
    else:
      return new_type

  def VisitGenericType(self, node):
    """Expands type macros: a GenericType applied to a generic or union base.

    NOTE(review): this overrides
    _RemoveTypeParametersFromGenericAny.VisitGenericType without calling it,
    so GenericTypes whose base is Any or a Constant are not collapsed by this
    visitor — confirm that is intended.

    Raises:
      KeyError: If the number of parameters does not match the base type's
        template (converted from MaybeSubstituteParameters' ValueError).
    """
    if isinstance(node.base_type, (pytd.GenericType, pytd.UnionType)):
      try:
        node = MaybeSubstituteParameters(
            node.base_type, node.parameters) or node
      except ValueError as e:
        raise KeyError(str(e)) from e
    return node

  def _ModulePrefix(self):
    # "<self_name>." or "" when no module name was given.
    return self.name + "." if self.name else ""

  def _ImportAll(self, module):
    """Get the new members that would result from a star import of the module.

    Args:
      module: The module name.

    Returns:
      A tuple of:
      - a list of new aliases,
      - a set of new __getattr__ functions.
    """
    aliases = []
    getattrs = set()
    ast = self._module_map[module]
    type_param_names = set()
    for member in sum((ast.constants, ast.type_params, ast.classes,
                       ast.functions, ast.aliases), ()):
      _, _, member_name = member.name.rpartition(".")
      new_name = self._ModulePrefix() + member_name
      if isinstance(member, pytd.Function) and member_name == "__getattr__":
        # def __getattr__(name) -> Any needs to be imported directly rather
        # than aliased.
        getattrs.add(member.Replace(name=new_name))
      else:
        # Imported type parameters produce both a type parameter definition and
        # an alias. Keep the definition and discard the alias.
        if isinstance(member, pytd.TypeParameter):
          type_param_names.add(new_name)
        elif new_name in type_param_names:
          continue
        t = pytd.ToType(member, allow_constants=True, allow_functions=True)
        aliases.append(pytd.Alias(new_name, t))
    return aliases, getattrs

  def _DiscardExistingNames(self, node, potential_members):
    """Returns the members of potential_members whose names node lacks."""
    new_members = []
    for m in potential_members:
      if m.name not in node:
        new_members.append(m)
    return new_members

  def _HandleDuplicates(self, new_aliases):
    """Handle duplicate module-level aliases.

    Aliases pointing to qualified names could be the result of importing the
    same entity through multiple import paths, which should not count as an
    error; instead we just deduplicate them.

    Args:
      new_aliases: The list of new aliases to deduplicate

    Returns:
      A deduplicated list of aliases.

    Raises:
      KeyError: If there is a name clash.
    """
    def SameModuleName(a, b):
      return (
          isinstance(a.type, pytd.Module) and
          isinstance(b.type, pytd.Module) and
          a.type.module_name == b.type.module_name
      )

    name_to_alias = {}
    out = []
    for a in new_aliases:
      if a.name not in name_to_alias:
        name_to_alias[a.name] = a
        out.append(a)
        continue
      existing = name_to_alias[a.name]
      if existing == a or SameModuleName(existing, a):
        continue
      raise KeyError("Duplicate top level items: %r, %r" % (
          existing.type.name, a.type.name))
    return out

  def EnterTypeDeclUnit(self, node):
    # Remembered so VisitNamedType can recognize dotted local references.
    self._unit = node

  def VisitTypeDeclUnit(self, node):
    """Add star imports to the ast.

    Args:
      node: A pytd.TypeDeclUnit instance.

    Returns:
      The pytd.TypeDeclUnit instance, with star imports added.

    Raises:
      KeyError: If a duplicate member is found during import.
    """
    if not self._star_imports:
      return node
    # Discard the 'importing_mod.imported_mod.* = imported_mod.*' aliases.
    star_import_names = set()
    p = self._ModulePrefix()
    for x in self._star_imports:
      # Allow for the case of foo/__init__ importing foo.bar
      if x.startswith(p):
        star_import_names.add(x + ".*")
      star_import_names.add(p + x + ".*")
    new_aliases = []
    new_getattrs = set()
    for module in self._star_imports:
      aliases, getattrs = self._ImportAll(module)
      new_aliases.extend(aliases)
      new_getattrs.update(getattrs)
    # Allow local definitions to shadow imported definitions.
    new_aliases = self._DiscardExistingNames(node, new_aliases)
    new_getattrs = self._DiscardExistingNames(node, new_getattrs)
    # Don't allow imported definitions to conflict with one another.
    new_aliases = self._HandleDuplicates(new_aliases)
    if len(new_getattrs) > 1:
      raise KeyError("Multiple __getattr__ definitions")
    return node.Replace(
        functions=node.functions + tuple(new_getattrs),
        aliases=(
            tuple(a for a in node.aliases if a.name not in star_import_names) +
            tuple(new_aliases)))


class LookupLocalTypes(_RemoveTypeParametersFromGenericAny, _ToTypeVisitor):
  """Look up local identifiers. Must be called on a TypeDeclUnit."""

  def __init__(self, allow_singletons=False, toplevel=True):
    """Create this visitor.

    Args:
      allow_singletons: Whether to allow singleton types like Ellipsis.
      toplevel: Whether this is the outermost visitor. Only a top-level
        visitor follows alias chains to detect recursive types.
    """
    super().__init__(allow_singletons)
    self._toplevel = toplevel
    # Every undotted local name this visitor has looked up.
    self.local_names = set()
    # Names of the enclosing classes, innermost last.
    self.class_names = []

  def EnterTypeDeclUnit(self, unit):
    self.unit = unit

  def LeaveTypeDeclUnit(self, _):
    del self.unit

  def _LookupItemRecursive(self, name):
    return pytd.LookupItemRecursive(self.unit, name)

  def EnterClass(self, node):
    self.class_names.append(node.name)

  def LeaveClass(self, unused_node):
    self.class_names.pop()

  def _LookupScopedName(self, name):
    """Look up a name in the chain of nested class scopes.

    Scopes are tried innermost first, ending with the module itself.

    Returns:
      The item found, or None if no scope defines the name.
    """
    scopes = [self.unit.name] + self.class_names
    prefix = f"{self.unit.name}."
    item = None
    while item is None and scopes:
      inner = scopes.pop()
      # Strip the module prefix; class names appear to be module-qualified.
      lookup_name = f"{inner}.{name}"[len(prefix):]
      try:
        item = self._LookupItemRecursive(lookup_name)
      except KeyError:
        pass
    return item

  def _LookupLocalName(self, node):
    """Resolves an undotted name in the current module.

    Returns:
      The item the name refers to, or `node` itself for a singleton name when
      singletons are allowed (left for the builtins resolver).

    Raises:
      SymbolLookupError: If the name cannot be found.
    """
    assert "." not in node.name
    self.local_names.add(node.name)
    item = self._LookupScopedName(node.name)
    if item is None:
      # Node names are not prefixed by the unit name when infer calls
      # load_pytd.resolve_ast() for the final pyi.
      try:
        item = self.unit.Lookup(node.name)
      except KeyError:
        pass
    if item is None:
      if (self.allow_singletons and node.name in pytd.SINGLETON_TYPES):
        # Let the builtins resolver handle it
        return node
      msg = f"Couldn't find {node.name} in {self.unit.name}"
      raise SymbolLookupError(msg)
    return item

  def _LookupLocalTypes(self, node):
    """Resolves local types inside `node` with a non-top-level visitor.

    Returns:
      A tuple (resolved node, set of local names looked up while resolving).
    """
    visitor = LookupLocalTypes(self.allow_singletons, toplevel=False)
    visitor.unit = self.unit
    return node.Visit(visitor), visitor.local_names

  def VisitNamedType(self, node):
    """Do lookup on a pytd.NamedType.

    Dotted names are resolved as members of the current module, with or
    without a module-name prefix; names that cannot be resolved that way are
    returned unchanged and left for later passes. Undotted names must resolve
    locally.

    Raises:
      SymbolLookupError: If an undotted name is not found or is not a type.
      KeyError, NotImplementedError: If a reference "<module>.<name>" to the
        current module names a missing member or one that is not a type.
    """
    # TODO(rechen): This method and FillInLocalPointers._Lookup do very similar
    # things; is there any common code we can extract out?
    if "." in node.name:
      resolved_node = None
      module_name, name = node.name, ""
      while "." in module_name:
        module_name, _, prefix = module_name.rpartition(".")
        name = f"{prefix}.{name}" if name else prefix
        if module_name == self.unit.name:
          # Fully qualified reference to a member of the current module. May
          # contain nested items that need to be recursively resolved.
          try:
            resolved_node = self.to_type(self._LookupItemRecursive(name))
          except (KeyError, NotImplementedError):
            if "." in name:
              # This might be a dotted local reference without a module_name
              # prefix, so we'll do another lookup attempt below.
              pass
            else:
              raise
          break
      if resolved_node is None:
        # Possibly a reference to a member of the current module that does not
        # have a module_name prefix.
        try:
          resolved_node = self.to_type(self._LookupItemRecursive(node.name))
        except KeyError:
          resolved_node = node  # lookup failures are handled later
    else:  # simple reference to a member of the current module
      item = self._LookupLocalName(node)
      if self._toplevel:
        # Check if the definition of this name refers back to itself.
        while isinstance(item, pytd.Alias):
          new_item, new_item_names = self._LookupLocalTypes(item)
          if node.name in new_item_names:
            # We've found a self-reference. This is a recursive type, so delay
            # resolution by representing it as a LateType.
            if item.name.startswith(f"{self.unit.name}."):
              late_name = f"{self.unit.name}.{node.name}"
            else:
              late_name = node.name
            item = pytd.LateType(late_name, recursive=True)
          elif new_item == item:
            break
          else:
            item = new_item
      try:
        resolved_node = self.to_type(item)
      except NotImplementedError as e:
        raise SymbolLookupError("%s is not a type" % item) from e
    if isinstance(resolved_node, (pytd.Constant, pytd.Function)):
      # Constants and functions can themselves reference local types, so
      # resolve those too. (A dead, never-used visitor that was created here
      # has been removed; _LookupLocalTypes builds its own.)
      return self._LookupLocalTypes(resolved_node)[0]
    return resolved_node


class ReplaceTypes(Visitor):
  """Visitor for replacing types in a tree.

  Both NamedType and ClassType nodes whose name is a key of the mapping are
  replaced by the mapped value; the two node kinds are treated identically.
  """

  def __init__(self, mapping, record=None):
    """Initialize this visitor.

    Args:
      mapping: A dictionary, mapping strings to node instances. Any NamedType
        or ClassType with a name in this dictionary will be replaced with
        the corresponding value.
      record: Optional. A set. If given, the names of all mapping entries that
        were used are added to it.
    """
    super().__init__()
    self.mapping = mapping
    self.record = record

  def VisitNamedType(self, node):
    name = node.name
    if name not in self.mapping:
      return node
    if self.record is not None:
      self.record.add(name)
    return self.mapping[name]

  def VisitClassType(self, node):
    return self.VisitNamedType(node)

  # There is deliberately no 'def VisitClass': it would replace a class
  # definition with the mapped value, which is almost certainly not wanted,
  # since running pytd_utils.Print on the result prints just a list of class
  # names with no contents.


class ExtractSuperClassesByName(ExtractSuperClasses):
  """Visitor for extracting all superclasses (i.e., the class hierarchy).

  This returns a mapping keyed by class name, e.g. {
    "bool": ["int"],
    "int": ["object"],
    ...
  }.
  """

  def _Key(self, node):
    # Named nodes are keyed by their name; anything else has no key (None).
    named_kinds = (pytd.GenericType, pytd.GENERIC_BASE_TYPE, pytd.Class)
    return node.name if isinstance(node, named_kinds) else None


class ReplaceTypeParameters(Visitor):
  """Visitor for replacing type parameters with actual types.

  Each TypeParameter node is replaced by `mapping[node]`. A parameter missing
  from the mapping propagates the mapping's lookup error (KeyError for a
  dict).
  """

  def __init__(self, mapping):
    super().__init__()
    self.mapping = mapping

  def VisitTypeParameter(self, p):
    replacement = self.mapping[p]
    return replacement


def ClassAsType(cls):
  """Converts a pytd.Class to an instance of pytd.Type.

  Returns a plain NamedType for a non-generic class, and a GenericType over
  the class's own template parameters otherwise.
  """
  type_params = tuple(item.type_param for item in cls.template)
  class_type = pytd.NamedType(cls.name)
  if type_params:
    return pytd.GenericType(class_type, type_params)
  return class_type


class AdjustSelf(Visitor):
  """Visitor for setting the correct type on self.

  So
    class A:
      def f(self: Any)
  becomes
    class A:
      def f(self: A)
  .
  (Notice the latter won't be printed like this, as printing simplifies the
   first argument to just "self")

  The new type is ClassAsType() of the innermost enclosing class, so generic
  classes are parameterized by their own template. With force=True, "self"
  is retyped whatever its original type.
  """

  def __init__(self, force=False):
    """Create this visitor.

    Args:
      force: If True, replace the type of every "self" parameter inside a
        class. Otherwise only replace it when it is Any.
    """
    super().__init__()
    self.class_types = []  # allow nested classes
    self.force = force

  def EnterClass(self, cls):
    self.class_types.append(ClassAsType(cls))

  def LeaveClass(self, unused_node):
    self.class_types.pop()

  def VisitClass(self, node):
    return node

  def VisitParameter(self, p):
    """Adjust all parameters called "self" to have their class's type.

    This happens only if their original type is Any, or unconditionally when
    the visitor was created with force=True. Parameters outside of any class
    are left alone.

    Args:
      p: pytd.Parameter instance.

    Returns:
      Adjusted pytd.Parameter instance.
    """
    if not self.class_types:
      # We're not within a class, so this is not a parameter of a method.
      return p
    if p.name == "self" and (
        self.force or isinstance(p.type, pytd.AnythingType)):
      return p.Replace(type=self.class_types[-1])
    else:
      return p


class RemoveUnknownClasses(Visitor):
  """Visitor for converting ClassTypes called ~unknown* to just AnythingType.

  References to unknown classes become AnythingType, and the unknown class
  definitions are dropped from the unit. For example, this will change
    def f(x: ~unknown1) -> ~unknown2
    class ~unknown1:
      ...
    class ~unknown2:
      ...
  to
    def f(x) -> Any
  """

  def __init__(self):
    super().__init__()
    # The Parameter currently being visited, or None.
    self.parameter = None

  def EnterParameter(self, p):
    self.parameter = p

  def LeaveParameter(self, p):
    assert self.parameter is p
    self.parameter = None

  def _AnyIfUnknown(self, t):
    """Returns AnythingType for an unknown type name, otherwise t itself."""
    return pytd.AnythingType() if escape.is_unknown(t.name) else t

  def VisitClassType(self, t):
    return self._AnyIfUnknown(t)

  def VisitNamedType(self, t):
    return self._AnyIfUnknown(t)

  def VisitTypeDeclUnit(self, u):
    known = tuple(c for c in u.classes if not escape.is_unknown(c.name))
    return u.Replace(classes=known)


class _CountUnknowns(Visitor):
  """Visitor for counting how often given unknowns occur in a type.

  After visiting, `counter` maps each unknown's suffix to its number of
  occurrences, and `position` maps it to the order of first appearance.
  """

  def __init__(self):
    super().__init__()
    self.counter = collections.Counter()
    self.position = {}

  def EnterNamedType(self, t):
    _, marker, suffix = t.name.partition(escape.UNKNOWN)
    if not marker:
      return
    if suffix not in self.counter:
      # First sighting: remember the order in which ~unknowns appear.
      self.position[suffix] = len(self.position)
    self.counter[suffix] += 1

  def EnterClassType(self, t):
    return self.EnterNamedType(t)


class CreateTypeParametersForSignatures(Visitor):
  """Visitor for inserting type parameters into signatures.

  This visitor replaces re-occurring ~unknowns and class types (when necessary)
  with type parameters.

  For example, this will change
  1.
    class ~unknown1:
      ...
    def f(x: ~unknown1) -> ~unknown1
  to
    _T1 = TypeVar("_T1")
    def f(x: _T1) -> _T1
  2.
    class Foo:
      def __new__(cls: Type[Foo]) -> Foo
  to
    _TFoo = TypeVar("_TFoo", bound=Foo)
    class Foo:
      def __new__(cls: Type[_TFoo]) -> _TFoo
  3.
    class Foo:
      def __enter__(self) -> Foo
  to
    _TFoo = TypeVar("_TFoo", bound=Foo)
    class Foo:
      def __enter__(self: _TFoo) -> _TFoo
  """

  PREFIX = "_T"  # Prefix for new type params

  def __init__(self):
    super().__init__()
    self.parameter = None
    # Name of the innermost enclosing class, or None outside classes.
    self.class_name = None
    self.function_name = None
    # Enclosing class names, innermost last, so that leaving a nested class
    # restores the outer class instead of forgetting it.
    self._class_names = []
    # Reset per unit in EnterTypeDeclUnit; initialized here so the attribute
    # always exists.
    self.added_new_type_params = False

  def EnterClass(self, node):
    self._class_names.append(node.name)
    self.class_name = node.name

  def LeaveClass(self, _):
    self._class_names.pop()
    self.class_name = self._class_names[-1] if self._class_names else None

  def EnterFunction(self, node):
    self.function_name = node.name

  def LeaveFunction(self, _):
    self.function_name = None

  def _NeedsClassParam(self, sig):
    """Whether the signature needs a bounded type param for the class.

    We detect the signatures
      (cls: Type[X][, ...]) -> X
    and
      (self: X[, ...]) -> X
    so that we can replace X with a bounded TypeVar. This heuristic
    isn't perfect; for example, in this naive copy method:
      class X:
        def copy(self):
          return X()
    we should have left X alone. But it prevents a number of false
    positives by enabling us to infer correct types for common
    implementations of __new__ and __enter__.

    Args:
      sig: A pytd.Signature.

    Returns:
      True if the signature needs a class param, False otherwise.
    """
    if self.class_name and self.function_name and sig.params:
      # Printing the class name escapes illegal characters.
      safe_class_name = pytd_utils.Print(pytd.NamedType(self.class_name))
      return (pytd_utils.Print(sig.return_type) == safe_class_name and
              pytd_utils.Print(sig.params[0].type) in (
                  "Type[%s]" % safe_class_name, safe_class_name))
    return False

  def VisitSignature(self, sig):
    """Potentially replace ~unknowns with type parameters, in a signature.

    Unknowns occurring more than once become type parameters named
    PREFIX + <order of first appearance>; a class returning itself (see
    _NeedsClassParam) gets a type parameter bounded by the class.
    """
    if (escape.is_partial(self.class_name) or
        escape.is_partial(self.function_name)):
      # Leave unknown classes and call traces as-is, they'll never be part of
      # the output.
      return sig
    counter = _CountUnknowns()
    sig.Visit(counter)
    replacements = {}
    for suffix, count in counter.counter.items():
      if count > 1:
        # We don't care whether it actually occurs in different parameters. That
        # way, e.g. "def f(Dict[T, T])" works, too.
        type_param = pytd.TypeParameter(
            self.PREFIX + str(counter.position[suffix]))
        key = escape.UNKNOWN + suffix
        replacements[key] = type_param
    if self._NeedsClassParam(sig):
      type_param = pytd.TypeParameter(
          self.PREFIX + self.class_name, bound=pytd.NamedType(self.class_name))
      replacements[self.class_name] = type_param
    if replacements:
      self.added_new_type_params = True
      sig = sig.Visit(ReplaceTypes(replacements))
    return sig

  def EnterTypeDeclUnit(self, _):
    self.added_new_type_params = False

  def VisitTypeDeclUnit(self, unit):
    # New type parameters must be declared at the unit level.
    if self.added_new_type_params:
      return unit.Visit(AdjustTypeParameters())
    else:
      return unit


class VerifyVisitor(Visitor):
  """Sanity-checks a pytd AST, raising AssertionError on violations. For tests."""

  _all_templates: Set[pytd.Node]

  def __init__(self):
    super().__init__()
    # Parameter names must look like plain Python identifiers.
    self._valid_param_name = re.compile(r"[a-zA-Z_]\w*$")

  def _AssertNoDuplicates(self, node, attrs):
    """Asserts that no name appears in two of the given fields of `node`.

    Args:
      node: The node whose fields are checked.
      attrs: Names of fields that hold sequences of named entries.

    Raises:
      AssertionError: naming the duplicates and the two fields sharing them.
    """
    names_by_attr = {}
    for attr in attrs:
      names_by_attr[attr] = {entry.name for entry in getattr(node, attr)}
    name_sets = list(names_by_attr.values())
    distinct_count = len(set.union(*name_sets))
    total_count = sum(len(names) for names in name_sets)
    if distinct_count == total_count:
      # Cheap check passed: no name is shared between fields.
      return
    for first, second in itertools.combinations(attrs, 2):
      overlap = names_by_attr[first] & names_by_attr[second]
      if overlap:
        raise AssertionError("Duplicate name(s) %s in both %s and %s" % (
            list(overlap), first, second))

  def EnterTypeDeclUnit(self, node):
    self._AssertNoDuplicates(
        node,
        ["constants", "type_params", "classes", "functions", "aliases"])
    self._all_templates = set()

  def LeaveTypeDeclUnit(self, node):
    declared = {type_param.name for type_param in node.type_params}
    undeclared = next(
        (item for item in self._all_templates if item.name not in declared),
        None)
    if undeclared is not None:
      raise AssertionError("Type parameter %r used, but not declared. "
                           "Did you call AdjustTypeParameters?" %
                           undeclared.name)

  def EnterClass(self, node):
    self._AssertNoDuplicates(node, ["methods", "constants"])

  def EnterFunction(self, node):
    # Every function needs at least one signature.
    assert node.signatures, node

  def EnterSignature(self, node):
    assert isinstance(node.has_optional, bool), node

  def EnterTemplateItem(self, node):
    # Collected so LeaveTypeDeclUnit can check each one is declared.
    self._all_templates.add(node)

  def EnterParameter(self, node):
    assert self._valid_param_name.match(node.name), node.name

  def EnterCallableType(self, node):
    # Callables are checked like any other generic type.
    self.EnterGenericType(node)

  def EnterGenericType(self, node):
    # A generic type must carry at least one parameter.
    assert node.parameters, node


class RemoveMethods(Visitor):
  """Visitor that strips methods with the given names from every class.

  By default it drops __getattribute__ and __getattr__.
  """

  def __init__(self, names=("__getattribute__", "__getattr__")):
    super().__init__()
    # Method names to drop; tested with `in`, so any container works.
    self.names = names

  def VisitClass(self, node):
    kept_methods = tuple(
        method for method in node.methods if method.name not in self.names)
    return node.Replace(methods=kept_methods)


class StripExternalNamePrefix(Visitor):
  """Removes the marker the parser puts on external type names.

  AddNamePrefix relies on the marker to leave external types alone, so this
  visitor must run after it.
  """

  def VisitNamedType(self, node):
    marker = parser_constants.EXTERNAL_NAME_PREFIX
    return node.Replace(name=utils.strip_prefix(node.name, marker))


class AddNamePrefix(Visitor):
  """Visitor for making names fully qualified.

  Using the module name "baz" as the prefix, this turns
    class Foo:
      pass
    def bar(x: Foo) -> Foo
  into
    class baz.Foo:
      pass
    def bar(x: baz.Foo) -> baz.Foo
  """

  def __init__(self):
    super().__init__()
    self.cls_stack = []  # Classes enclosing the current node, outermost first.
    self.classes = None  # Names of the unit's top-level classes.
    self.prefix = None  # Module name followed by ".".
    self.name = None  # Module name.

  def _ClassStackString(self):
    """Returns the dotted path of the enclosing classes."""
    return ".".join([cls.name for cls in self.cls_stack])

  def EnterTypeDeclUnit(self, node):
    self.name = node.name
    self.prefix = node.name + "."
    self.classes = {cls.name for cls in node.classes}

  def EnterClass(self, cls):
    self.cls_stack.append(cls)

  def LeaveClass(self, cls):
    assert self.cls_stack[-1] is cls
    self.cls_stack.pop()

  def VisitClassType(self, node):
    if node.cls is None:
      return self.VisitNamedType(node)
    raise ValueError("AddNamePrefix visitor called after resolving")

  def VisitNamedType(self, node):
    """Prefix a pytd.NamedType."""
    name = node.name
    if name.startswith(parser_constants.EXTERNAL_NAME_PREFIX):
      # External types keep their marker here; StripExternalNamePrefix removes
      # it afterwards.
      return node
    if self.cls_stack:
      innermost = self.cls_stack[-1].name
      if name == innermost:
        # A class referring to itself.
        return node.Replace(name=self.prefix + self._ClassStackString())
      if "." in name:
        outer, nested = name.rsplit(".", 1)
        if outer == innermost:
          # The parser writes aliases to nested classes as
          # ImmediateOuter.Nested, so expand that to the whole class stack.
          return node.Replace(
              name=self.prefix + self._ClassStackString() + "." + nested)
    head = name.split(".")[0]
    if head in self.classes:
      # Only the first component is checked, so names nested under a
      # top-level class (e.g. a class constant Foo.BAR) are prefixed as well.
      return node.Replace(name=self.prefix + name)
    return node

  def VisitClass(self, node):
    return node.Replace(name=self.prefix + self._ClassStackString())

  def VisitTypeParameter(self, node):
    if node.scope is None:
      # Default to the module itself as the scope. Module-level type
      # parameters keep it; AdjustTypeParameters assigns narrower scopes to
      # the others.
      return node.Replace(scope=self.name)
    return node.Replace(scope=self.prefix + node.scope)

  def _VisitNamedNode(self, node):
    """Qualifies module-level definitions; class attributes are left as-is."""
    if not self.cls_stack:
      # Module-level definition. get_absolute_name also resolves names that
      # start with "." (relative module names).
      return node.Replace(
          name=module_utils.get_absolute_name(self.name, node.name))
    return node

  def VisitFunction(self, node):
    return self._VisitNamedNode(node)

  def VisitConstant(self, node):
    return self._VisitNamedNode(node)

  def VisitAlias(self, node):
    return self._VisitNamedNode(node)

  def VisitModule(self, node):
    return self._VisitNamedNode(node)


class CollectDependencies(Visitor):
  """Visitor that gathers module names referenced by external types.

  Needs to be called on a TypeDeclUnit. Afterwards, `dependencies` and
  `late_dependencies` map each module name to the set of names used from it.
  """

  def __init__(self):
    super().__init__()
    self.dependencies = {}
    self.late_dependencies = {}

  def _ProcessName(self, name, dependencies):
    """Records a dotted `name` in `dependencies` under its module part."""
    module_name, dot, base_name = name.rpartition(".")
    if not dot:
      # Unqualified name, so there is no module to record.
      return
    if not module_name:
      # A relative import that was never qualified (usually because
      # package_name was empty). Leaving '' out of the dependencies produces a
      # clearer error later on.
      logging.warning("Empty package name: %s", name)
      return
    dependencies.setdefault(module_name, set()).add(base_name)

  def EnterClassType(self, node):
    self._ProcessName(node.name, self.dependencies)

  def EnterNamedType(self, node):
    self._ProcessName(node.name, self.dependencies)

  def EnterLateType(self, node):
    self._ProcessName(node.name, self.late_dependencies)


def ExpandSignature(sig):
  """Expand a single signature.

  Each parameter whose type is a union is split into one parameter per union
  member, and every combination is produced. The rightmost parameter varies
  fastest, so (a or b, c or d) yields (a, c), (a, d), (b, c), (b, d), in that
  order.

  Arguments:
    sig: A pytd.Signature instance.

  Returns:
    A list of pytd.Signature instances, one per combination.
  """
  choices_per_param = [
      [param.Replace(type=t) for t in param.type.type_list]
      if isinstance(param.type, pytd.UnionType) else [param]
      for param in sig.params
  ]
  return [sig.Replace(params=tuple(combination))
          for combination in itertools.product(*choices_per_param)]


class ExpandSignatures(Visitor):
  """Expand to Cartesian product of parameter types.

  For example, this transforms
    def f(x: Union[int, float], y: Union[int, float]) -> Union[str, unicode]
  to
    def f(x: int, y: int) -> Union[str, unicode]
    def f(x: int, y: float) -> Union[str, unicode]
    def f(x: float, y: int) -> Union[str, unicode]
    def f(x: float, y: float) -> Union[str, unicode]

  The expansion by this class is typically *not* an optimization. But it can be
  the precursor for optimizations that need the expanded signatures, and it can
  simplify code generation, e.g. when generating type declarations for a type
  inferencer.
  """

  def VisitFunction(self, f):
    """Replaces each signature of `f` with its ExpandSignature expansions.

    Arguments:
      f: A pytd.Function instance.

    Returns:
      The function with the expanded signatures, in the original order.
    """
    expanded = []
    for signature in f.signatures:
      expanded.extend(ExpandSignature(signature))
    return f.Replace(signatures=tuple(expanded))


class AdjustTypeParameters(Visitor):
  """Visitor for adjusting type parameters.

  * Inserts class templates.
  * Inserts signature templates.
  * Adds scopes to type parameters.

  The visitor is stateful. Its Enter*/Leave* hooks track the current class,
  function, constant and generic-nesting depth, which VisitTypeParameter uses
  to choose a scope. Every visited type parameter is recorded in
  `all_typeparams` so that VisitTypeDeclUnit can add module-level declarations
  for any that are not declared yet.
  """

  def __init__(self):
    super().__init__()
    # Names of the template type parameters of the class(es) being visited.
    self.class_typeparams = set()
    # TemplateItems collected for the current signature; None outside one.
    self.function_typeparams = None
    # Stack of merged templates, one entry per class currently being visited.
    self.class_template = []
    self.class_name = None
    self.function_name = None
    self.constant_name = None
    # Every TypeParameter seen so far, after scoping.
    self.all_typeparams = set()
    # Nesting depth of Generic/Callable/Tuple/Union types.
    self.generic_level = 0

  def _GetTemplateItems(self, param):
    """Get a list of template items from a parameter.

    Recurses into GenericType parameters and UnionType members. Every
    TypeParameter found becomes a TemplateItem; other types contribute
    nothing.
    """
    items = []
    if isinstance(param, pytd.GenericType):
      for p in param.parameters:
        items.extend(self._GetTemplateItems(p))
    elif isinstance(param, pytd.UnionType):
      for p in param.type_list:
        items.extend(self._GetTemplateItems(p))
    elif isinstance(param, pytd.TypeParameter):
      items.append(pytd.TemplateItem(param))
    return items

  def VisitTypeDeclUnit(self, node):
    """Adds module-level declarations for used but undeclared type params."""
    type_params_to_add = []
    declared_type_params = {n.name for n in node.type_params}
    # Sorting type params helps keep pickling deterministic.
    for t in sorted(self.all_typeparams):
      if t.name not in declared_type_params:
        logging.debug("Adding definition for type parameter %r", t.name)
        declared_type_params.add(t.name)
        # Module-level declarations carry no scope.
        type_params_to_add.append(t.Replace(scope=None))
    new_type_params = tuple(
        sorted(node.type_params + tuple(type_params_to_add)))
    return node.Replace(type_params=new_type_params)

  def _CheckDuplicateNames(self, params, class_name):
    """Raises ContainerError if two template items in `params` share a name."""
    seen = set()
    for x in params:
      if x.name in seen:
        raise ContainerError(
            "Duplicate type parameter %s in typing.Generic base of class %s" %
            (x.name, class_name))
      seen.add(x.name)

  def EnterClass(self, node):
    """Establish the template for the class.

    Collects the type parameters of each generic base. If the class inherits
    from Generic[...], that base alone defines the template, and every type
    parameter used by other bases must appear in it. Otherwise the bases'
    parameter lists are merged with mro.MergeSequences.

    Raises:
      ContainerError: on duplicate Generic parameters, multiple Generic bases,
        type variables missing from Generic, or an unmergeable order.
    """
    templates = []
    generic_template = None

    for base in node.bases:
      if isinstance(base, pytd.GenericType):
        params = sum((self._GetTemplateItems(param)
                      for param in base.parameters), [])
        if base.name in ["typing.Generic", "Generic"]:
          # TODO(mdemello): Do we need "Generic" in here or is it guaranteed
          # to be replaced by typing.Generic by the time this visitor is called?
          self._CheckDuplicateNames(params, node.name)
          # NOTE(review): this is a truthiness test, so a Generic base that
          # yielded no TemplateItems ([]) does not count as "seen" and a second
          # Generic base would not be rejected. Confirm whether intended.
          if generic_template:
            raise ContainerError(
                "Cannot inherit from Generic[...] multiple times in class %s"
                % node.name)
          else:
            generic_template = params
        else:
          templates.append(params)
    if generic_template:
      for params in templates:
        for param in params:
          if param not in generic_template:
            raise ContainerError(
                ("Some type variables (%s) are not listed in Generic of"
                 " class %s") % (param.type_param.name, node.name))
      templates = [generic_template]

    try:
      # Merges the per-base parameter orders into one sequence; a ValueError
      # means the orders are inconsistent.
      template = mro.MergeSequences(templates)
    except ValueError as e:
      raise ContainerError(
          "Illegal type parameter order in class %s" % node.name) from e

    self.class_template.append(template)

    for t in template:
      assert isinstance(t.type_param, pytd.TypeParameter)
      self.class_typeparams.add(t.name)

    self.class_name = node.name

  def LeaveClass(self, node):
    """Undoes the class state set up by EnterClass."""
    # NOTE(review): class_name is reset to None rather than to the enclosing
    # class's name, and names are removed from class_typeparams even if an
    # enclosing class also uses them. For nested classes this is only safe if
    # they are visited after everything else that relies on this state in the
    # outer class. TODO confirm the child visiting order.
    del node
    for t in self.class_template[-1]:
      if t.name in self.class_typeparams:
        self.class_typeparams.remove(t.name)
    self.class_name = None
    self.class_template.pop()

  def VisitClass(self, node):
    """Builds a template for the class from its GenericType bases."""
    # The template items will not have been properly scoped because they were
    # stored outside of the ast and not visited while processing the class
    # subtree.  They now need to be scoped similar to VisitTypeParameter,
    # except we happen to know they are all bound by the class.
    template = [pytd.TemplateItem(t.type_param.Replace(scope=node.name))
                for t in self.class_template[-1]]
    node = node.Replace(template=tuple(template))
    return node.Visit(AdjustSelf()).Visit(NamedTypeToClassType())

  def EnterSignature(self, unused_node):
    """Starts collecting the signature's own (non-class) type parameters."""
    assert self.function_typeparams is None, self.function_typeparams
    self.function_typeparams = set()

  def LeaveSignature(self, unused_node):
    self.function_typeparams = None

  def _MaybeMutateSelf(self, sig):
    """Returns `sig`, with a self mutation added for generic __init__s."""
    # If the given signature is an __init__ method for a generic class and the
    # class's type parameters all appear among the method's parameter
    # annotations, then we should add a mutation to the parameter values, e.g.:
    #   class Foo(Generic[T]):
    #      def __init__(self, x: T) -> None: ...
    # becomes:
    #   class Foo(Generic[T]):
    #     def __init__(self, x: T) -> None:
    #       self = Foo[T]
    if self.function_name != "__init__" or not self.class_name:
      return sig
    class_template = self.class_template[-1]
    if not class_template:
      return sig
    seen_params = {t.name: t for t in pytd_utils.GetTypeParameters(sig)}
    if any(t.name not in seen_params for t in class_template):
      return sig
    # Leave signatures alone if they have no self param or one that already
    # declares a mutation.
    if not sig.params or sig.params[0].mutated_type:
      return sig
    mutated_type = pytd.GenericType(
        base_type=pytd.ClassType(self.class_name),
        parameters=tuple(seen_params[t.name] for t in class_template))
    self_param = sig.params[0].Replace(mutated_type=mutated_type)
    return sig.Replace(params=(self_param,) + sig.params[1:])

  def VisitSignature(self, node):
    """Attaches the collected function type parameters as the template."""
    # Sorting the template in CanonicalOrderingVisitor is enough to guarantee
    # pyi determinism, but we need to sort here as well for pickle determinism.
    return self._MaybeMutateSelf(
        node.Replace(template=tuple(sorted(self.function_typeparams))))

  def EnterFunction(self, node):
    self.function_name = node.name

  def LeaveFunction(self, unused_node):
    self.function_name = None

  def EnterConstant(self, node):
    self.constant_name = node.name

  def LeaveConstant(self, unused_node):
    self.constant_name = None

  # Callable, tuple and union types count as generic nesting too, so a type
  # parameter inside any of them is considered bound (see _IsBoundTypeParam).
  def EnterGenericType(self, unused_node):
    self.generic_level += 1

  def LeaveGenericType(self, unused_node):
    self.generic_level -= 1

  def EnterCallableType(self, node):
    self.EnterGenericType(node)

  def LeaveCallableType(self, node):
    self.LeaveGenericType(node)

  def EnterTupleType(self, node):
    self.EnterGenericType(node)

  def LeaveTupleType(self, node):
    self.LeaveGenericType(node)

  def EnterUnionType(self, node):
    self.EnterGenericType(node)

  def LeaveUnionType(self, node):
    self.LeaveGenericType(node)

  def _GetFullName(self, name):
    """Joins the current class name and `name`, skipping empty parts."""
    return ".".join(n for n in [self.class_name, name] if n)

  def _GetScope(self, name):
    """Returns the scope for type parameter `name`.

    This is the class for class template parameters. Otherwise it is the
    (class-qualified) function name, which is "" outside classes and
    functions.
    """
    if name in self.class_typeparams:
      return self.class_name
    return self._GetFullName(self.function_name)

  def _IsBoundTypeParam(self, node):
    """Truthy if `node` belongs to the class template or sits in a generic."""
    # Note: the result may be an int (generic_level), not a bool.
    in_class = self.class_name and node.name in self.class_typeparams
    return in_class or self.generic_level

  def VisitTypeParameter(self, node):
    """Add scopes to type parameters, track unbound params."""
    # A bare TypeVar annotating a constant has nothing to bind it.
    if self.constant_name and not self._IsBoundTypeParam(node):
      raise ContainerError("Unbound type parameter %s in %s" % (
          node.name, self._GetFullName(self.constant_name)))
    scope = self._GetScope(node.name)
    if scope:
      node = node.Replace(scope=scope)
    else:
      # This is a top-level type parameter (TypeDeclUnit.type_params).
      # AddNamePrefix gave it the right scope, so leave it alone.
      pass

    # Inside a signature, parameters that are not class template parameters
    # belong to the signature's own template.
    if (self.function_typeparams is not None and
        node.name not in self.class_typeparams):
      self.function_typeparams.add(pytd.TemplateItem(node))
    self.all_typeparams.add(node)

    return node


class VerifyContainers(Visitor):
  """Visitor for verifying containers.

  Every container (except typing.Generic) must inherit from typing.Generic and
  have an explicitly parameterized base that is also a container. The
  parameters on typing.Generic must all be TypeVar instances. A container must
  have at most as many parameters as specified in its template.

  Raises:
    ContainerError: If a problematic container definition is encountered.
  """

  def EnterGenericType(self, node):
    """Verify a pytd.GenericType.

    Also called for CallableType and TupleType nodes, which skip the parameter
    count check.

    Args:
      node: A pytd.GenericType (or CallableType / TupleType).

    Raises:
      ContainerError: If the base is not a container, typing.Generic or
        typing.Protocol is parameterized with a non-TypeVar, or too many
        parameters are given.
    """
    base_type = node.base_type
    if isinstance(base_type, pytd.LateType):
      return  # We can't verify this yet
    if not pytd.IsContainer(base_type.cls):
      raise ContainerError("Class %s is not a container" % base_type.name)
    elif base_type.name in ("typing.Generic", "typing.Protocol"):
      for t in node.parameters:
        if not isinstance(t, pytd.TypeParameter):
          raise ContainerError("Name %s must be defined as a TypeVar" % t.name)
    elif not isinstance(node, (pytd.CallableType, pytd.TupleType)):
      actual_param_count = len(node.parameters)
      if actual_param_count and not base_type.cls.template:
        # This AdjustTypeParameters() call is needed because we validate nodes
        # before their type parameters have been adjusted in some circular
        # import cases. The result of this adjustment is not saved because it
        # may not be accurate if the container is only partially resolved, but
        # it's good enough to avoid some spurious container validation errors.
        cls = base_type.cls.Visit(AdjustTypeParameters())
      else:
        cls = base_type.cls
      max_param_count = len(cls.template)
      if actual_param_count > max_param_count:
        raise ContainerError(
            "Too many parameters on %s: expected %s, got %s" % (
                base_type.name, max_param_count, actual_param_count))

  def EnterCallableType(self, node):
    self.EnterGenericType(node)

  def EnterTupleType(self, node):
    self.EnterGenericType(node)

  def _GetGenericBasesLookupMap(self, node):
    """Get a lookup map for the generic bases of a class.

    Gets a map from a pytd.ClassType to the list of pytd.GenericType bases of
    the node that have that class as their base. This method does depth-first
    traversal of the bases, which ensures that the order of elements in each
    list is consistent with the node's MRO.

    Args:
      node: A pytd.Class node.

    Returns:
      A pytd.ClassType -> List[pytd.GenericType] map.
    """
    mapping = collections.defaultdict(list)
    seen_bases = set()
    # Reversed so that pop() yields the bases in declaration order.
    bases = list(reversed(node.bases))
    while bases:
      base = bases.pop()
      if base in seen_bases:
        continue
      seen_bases.add(base)
      if (isinstance(base, pytd.GenericType) and
          isinstance(base.base_type, pytd.ClassType)):
        mapping[base.base_type].append(base)
        bases.extend(reversed(base.base_type.cls.bases))
      elif isinstance(base, pytd.ClassType):
        bases.extend(reversed(base.cls.bases))
    return mapping

  def _UpdateParamToValuesMapping(self, mapping, param, value):
    """Update the given mapping of parameter names to values.

    Args:
      mapping: A datatypes.AliasingDict from TypeVar full names to value sets.
      param: The base class's TemplateItem being assigned.
      value: The type assigned to it by the subclass.

    Raises:
      datatypes.AliasingDictConflictError: If the alias cannot be registered.
    """
    param_name = param.type_param.full_name
    if isinstance(value, pytd.TypeParameter):
      value_name = value.full_name
      assert param_name != value_name
      # A TypeVar has been aliased, e.g.,
      #   class MyList(List[U]): ...
      #   class List(Sequence[T]): ...
      # Register the alias. May raise AliasingDictConflictError.
      mapping.add_alias(param_name, value_name, set.union)
    else:
      # A TypeVar has been given a concrete value, e.g.,
      #   class MyList(List[str]): ...
      # Register the value.
      if param_name not in mapping:
        mapping[param_name] = set()
      mapping[param_name].add(value)

  def _TypeCompatibilityCheck(self, type_params):
    """Check if the types are compatible.

    It is used to handle the case:
      class A(Sequence[A]): pass
      class B(A, Sequence[B]): pass
      class C(B, Sequence[C]): pass
    In class `C`, the type parameter `_T` of Sequence could be `A`, `B` or `C`.
    Next we will check they have a linear inheritance relationship:
    `A` -> `B` -> `C`.

    Args:
      type_params: The class type params.

    Returns:
      True if all the types are compatible.
    """
    # Any is compatible with everything, so it is ignored.
    type_params = {t for t in type_params
                   if not isinstance(t, pytd.AnythingType)}
    if not all(isinstance(t, pytd.ClassType) for t in type_params):
      return False
    mro_list = [set(mro.GetBasesInMRO(t.cls)) for t in type_params]
    mro_list.sort(key=len)
    prev = set()
    for cur in mro_list:
      if not cur.issuperset(prev):
        return False
      prev = cur
    return True

  def EnterClass(self, node):
    """Check for conflicting type parameter values in the class's bases.

    Args:
      node: A pytd.Class.

    Raises:
      ContainerError: If a TypeVar of a base receives incompatible values, or
        one of this class's own template TypeVars is also given a value.
    """
    # Get the bases in MRO, since we need to know the order in which type
    # parameters are aliased or assigned values.
    try:
      classes = mro.GetBasesInMRO(node)
    except mro.MROError:
      # TODO(rechen): We should report this, but VerifyContainers() isn't the
      # right place to check for mro errors.
      return
    # GetBasesInMRO gave us the pytd.ClassType for each base. Map class types
    # to generic types so that we can iterate through the latter in MRO.
    cls_to_bases = self._GetGenericBasesLookupMap(node)
    param_to_values = datatypes.AliasingDict()
    ambiguous_aliases = set()
    # chain.from_iterable avoids the quadratic list concatenation of sum().
    for base in itertools.chain.from_iterable(
        cls_to_bases[cls] for cls in classes):
      for param, value in zip(base.base_type.cls.template, base.parameters):
        try:
          self._UpdateParamToValuesMapping(param_to_values, param, value)
        except datatypes.AliasingDictConflictError:
          ambiguous_aliases.add(param.type_param.full_name)
    for param_name, values in param_to_values.items():
      if any(param_to_values[alias] is values for alias in ambiguous_aliases):
        # Any conflict detected for this type parameter might be a false
        # positive, since a conflicting value assigned through an ambiguous
        # alias could have been meant for a different type parameter.
        continue
      elif len(values) > 1 and not self._TypeCompatibilityCheck(values):
        raise ContainerError(
            "Conflicting values for TypeVar %s: %s" % (
                param_name, ", ".join(str(v) for v in values)))
    for t in node.template:
      param_name = t.type_param.full_name
      if param_name in param_to_values:
        # The set may hold several (compatible or ambiguously aliased) values.
        # The old `value, = ...` unpacking crashed with ValueError in that
        # case, so all values are listed instead.
        values = param_to_values[param_name]
        raise ContainerError(
            "Conflicting value %s for TypeVar %s" % (
                ", ".join(str(v) for v in values), param_name))


class ExpandCompatibleBuiltins(Visitor):
  """Ad-hoc inheritance.

  In parameters, replaces
    ClassType('builtins.float')
  with
    Union[ClassType('builtins.float'), ClassType('builtins.int')]

  And similarly for the other pairs in pep484.COMPAT_ITEMS (e.g.
  unicode->(unicode, str, bytes) and bool->(bool, None)).

  Used to allow a function requiring a float to accept an int without making
  int inherit from float.

  NOTE: We do not do this for type parameter constraints.

  See https://www.python.org/dev/peps/pep-0484/#the-numeric-tower
  """

  def __init__(self, builtins):
    super().__init__()
    self.in_parameter = False
    # Also set inside Literal[...], whose contents must not be expanded.
    self.in_type_parameter = False
    self.replacements = self._BuildReplacementMap(builtins)

  @staticmethod
  def _BuildReplacementMap(builtins):
    """Builds a Dict[str, pytd.UnionType] of parameter type replacements.

    Args:
      builtins: The builtins module AST; must provide `name` and `Lookup`.

    Returns:
      A map from the full name of each more general builtin (e.g.
      "builtins.float") to the union of itself and its compatible builtins.
    """
    # The prefix is loop-invariant, so compute it once.
    prefix = builtins.name + "."
    rmap = collections.defaultdict(list)

    # compat_list :: [(compat, name)], where name is the more generalized
    # type and compat is the less generalized one. (eg: name = float, compat =
    # int). Each general type is first paired with itself so that it leads its
    # own union. dict.fromkeys dedupes those self-pairs like a set would, but
    # keeps a deterministic order.
    self_pairs = dict.fromkeys((v, v) for _, v in pep484.COMPAT_ITEMS)
    compat_list = itertools.chain(self_pairs, pep484.COMPAT_ITEMS)

    for compat, name in compat_list:
      full_name = prefix + compat
      t = builtins.Lookup(full_name)
      if isinstance(t, pytd.Class):
        # Depending on python version, bytes can be an Alias, if so don't
        # want it in our union
        rmap[prefix + name].append(pytd.ClassType(full_name, t))

    return {k: pytd.UnionType(tuple(v)) for k, v in rmap.items()}

  def EnterParameter(self, _):
    assert not self.in_parameter
    self.in_parameter = True

  def LeaveParameter(self, _):
    assert self.in_parameter
    self.in_parameter = False

  def EnterTypeParameter(self, _):
    assert not self.in_type_parameter
    self.in_type_parameter = True

  def LeaveTypeParameter(self, _):
    assert self.in_type_parameter
    self.in_type_parameter = False

  def EnterLiteral(self, _):
    assert not self.in_type_parameter
    self.in_type_parameter = True

  def LeaveLiteral(self, _):
    assert self.in_type_parameter
    self.in_type_parameter = False

  def VisitClassType(self, node):
    """Widens a builtin class in a parameter type to its compatible union."""
    if self.in_parameter and not self.in_type_parameter:
      return self.replacements.get(node.name, node)
    else:
      return node


class ClearClassPointers(Visitor):
  """Resets, in place, the .cls pointer of every ClassType to None."""

  def EnterClassType(self, node):
    # Mutates the node directly instead of building a replacement.
    node.cls = None


class ReplaceModulesWithAny(_RemoveTypeParametersFromGenericAny):
  """Replaces types whose name starts with a listed module with Any."""

  def __init__(self, module_list):
    super().__init__()
    assert isinstance(module_list, list)
    # Matched as plain string prefixes of the type name.
    # NOTE(review): "foo" also matches "foobar.X"; confirm callers pass names
    # that make this unambiguous.
    self._any_modules = module_list

  def VisitNamedType(self, n):
    for module in self._any_modules:
      if n.name.startswith(module):
        return pytd.AnythingType()
    return n

  def VisitLateType(self, n):
    return self.VisitNamedType(n)

  def VisitClassType(self, n):
    return self.VisitNamedType(n)


class ReplaceUnionsWithAny(Visitor):
  """Replaces every union type with Any."""

  def VisitUnionType(self, _):
    any_type = pytd.AnythingType()
    return any_type


class ClassTypeToLateType(Visitor):
  """Convert ClassType to LateType."""

  def __init__(self, ignore):
    """Initialize the visitor.

    Args:
      ignore: A list of prefixes to ignore. Typically, this list includes
        something like "builtins.", since we don't want to convert builtin
        types to late types. (More generally, types from modules that pytype
        always loads don't need to be late types.) A class is only ignored if
        nothing after the prefix contains another ".".
    """
    super().__init__()
    self._ignore = ignore

  def VisitClassType(self, n):
    name = n.name
    directly_under_ignored = any(
        name.startswith(prefix) and "." not in name[len(prefix):]
        for prefix in self._ignore)
    return n if directly_under_ignored else pytd.LateType(name)


class LateTypeToClassType(Visitor):
  """Turns each LateType into a ClassType with no class pointer yet."""

  def VisitLateType(self, t):
    unresolved_cls = None
    return pytd.ClassType(t.name, unresolved_cls)


class DropMutableParameters(Visitor):
  """Clears the mutated_type of every parameter.

  Unlike transforms.RemoveMutableParameters, which absorbs mutable parameters
  into the signature, this visitor simply discards them.
  """

  def VisitParameter(self, p):
    no_mutation = None
    return p.Replace(mutated_type=no_mutation)