Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
pytype / pytd / optimize.py
Size: Mime:
"""Functions for optimizing pytd syntax trees.

   pytd files come from various sources, and are typically redundant (duplicate
   functions, different signatures saying the same thing, overlong type
   disjunctions). The Visitors in this file remove various forms of these
   redundancies.
"""

import collections
import logging

from pytype import utils
from pytype.pytd import abc_hierarchy
from pytype.pytd import escape
from pytype.pytd import pytd
from pytype.pytd import pytd_utils
from pytype.pytd import visitors

# Module-level logger named after this module.
# NOTE(review): appears unused by the visitors in this module; confirm before
# removing.
log = logging.getLogger(__name__)


class RenameUnknowns(visitors.Visitor):
  """Give unknowns that map to the same set of concrete types the same name.

  All unknown names whose mapped classes hash equally share one canonical
  name: the last such name in `mapping`'s iteration order.

  Args:
    mapping: A dict from unknown class name to the (hashable) class it maps
      to.
  """

  def __init__(self, mapping):
    super().__init__()
    self.name_to_cls = {}
    for unknown_name, target in mapping.items():
      self.name_to_cls[unknown_name] = hash(target)
    self.cls_to_canonical_name = {}
    for unknown_name, target_hash in self.name_to_cls.items():
      # Later names overwrite earlier ones, so the last name wins.
      self.cls_to_canonical_name[target_hash] = unknown_name

  def VisitClassType(self, node):
    """Renames unknown ClassTypes to their canonical name.

    Raises:
      KeyError: If an unknown's name is not in the mapping.
    """
    if not escape.is_unknown(node.name):
      return node
    target_hash = self.name_to_cls[node.name]
    return pytd.ClassType(self.cls_to_canonical_name[target_hash], None)


class NormalizeGenericSelfTypes(visitors.Visitor):
  """Strips type parameters from a `self` annotation naming the enclosing class.

  For example, this transforms
    class Foo(Generic[T]):
      def f(self: Foo[int]): ...
  to
    class Foo(Generic[T]):
      def f(self: Foo): ...
  """

  def __init__(self):
    super().__init__()
    # Names of the classes we are currently inside, innermost last.
    self.class_stack = []

  def EnterClass(self, node):
    self.class_stack.append(node.name)

  def LeaveClass(self, node):
    self.class_stack.pop()

  def _StripGenericSelf(self, sig, class_name):
    """Returns `sig`, with `self: class_name[...]` replaced by its base type."""
    if not sig.params:
      return sig
    first = sig.params[0]
    is_generic_self = (
        first.name == "self" and
        isinstance(first.type, pytd.GenericType) and
        first.type.base_type.name == class_name)
    if not is_generic_self:
      return sig
    bare_self = first.Replace(type=first.type.base_type)
    return sig.Replace(params=(bare_self,) + sig.params[1:])

  def VisitFunction(self, node):
    # Module-level functions have no enclosing class to compare against.
    if not self.class_stack:
      return node
    current_class = self.class_stack[-1]
    return node.Replace(signatures=tuple(
        self._StripGenericSelf(sig, current_class)
        for sig in node.signatures))


class RemoveDuplicates(visitors.Visitor):
  """Drop exact duplicate signatures from every function.

  For example, this transforms
    def f(x: int) -> float
    def f(x: int) -> float
  to
    def f(x: int) -> float
  Only a signature exactly identical to an existing one is dropped; the
  remaining signatures keep their original relative order.
  """

  def VisitFunction(self, node):
    unique_signatures = pytd_utils.OrderedSet(node.signatures)
    return node.Replace(signatures=tuple(unique_signatures))


class SimplifyUnions(visitors.Visitor):
  """Remove duplicate or redundant entries in union types.

  For example, this transforms
    a: Union[int, int]
    b: Union[int, Any]
    c: Union[int, int, float]
  to
    a: int
    b: Any
    c: Union[int, float]
  """

  def VisitUnionType(self, union):
    # pytd_utils.JoinTypes does the actual collapsing of the members.
    members = union.type_list
    return pytd_utils.JoinTypes(members)


class _ReturnsAndExceptions:
  """Mutable class for collecting return types and exceptions of functions.

  The collecting is stable: Items are kept in the order in which they were
  encountered.

  Attributes:
    return_types: Return types seen so far.
    exceptions: Exceptions seen so far.
  """

  def __init__(self):
    self.return_types = []
    self.exceptions = []

  def Update(self, signature):
    """Add the return types / exceptions of a signature to this instance."""

    if signature.return_type not in self.return_types:
      self.return_types.append(signature.return_type)

    self.exceptions.extend(exception
                           for exception in signature.exceptions
                           if exception not in self.exceptions)


class CombineReturnsAndExceptions(visitors.Visitor):
  """Group function signatures that only differ in exceptions or return values.

  For example, this transforms
    def f(x: int) -> float:
      raise OverflowError()
    def f(x: int) -> int:
      raise IndexError()
  to
    def f(x: int) -> Union[float, int]:
      raise OverflowError()
      raise IndexError()
  """

  def _GroupByArguments(self, signatures):
    """Buckets signatures by everything except return type and exceptions.

    Arguments:
      signatures: An iterable of function signatures (Signature instances).

    Returns:
      An OrderedDict, in first-seen order, mapping each signature with its
      return_type and exceptions set to None onto a _ReturnsAndExceptions that
      collects the return types and exceptions of that bucket.
    """
    buckets = collections.OrderedDict()
    for sig in signatures:
      key = sig.Replace(return_type=None, exceptions=None)
      if key not in buckets:
        buckets[key] = _ReturnsAndExceptions()
      buckets[key].Update(sig)
    return buckets

  def VisitFunction(self, f):
    """Merge the signatures of a function that share the same arguments.

    Arguments:
      f: A pytd.Function instance.

    Returns:
      A copy of f with one signature per distinct argument list; its return
      type is pytd_utils.JoinTypes of the grouped return types and its
      exceptions are the grouped exceptions.
    """
    merged = [
        stripped.Replace(
            return_type=pytd_utils.JoinTypes(collected.return_types),
            exceptions=tuple(collected.exceptions))
        for stripped, collected in self._GroupByArguments(f.signatures).items()
    ]
    return f.Replace(signatures=tuple(merged))


class CombineContainers(visitors.Visitor):
  """Change unions of containers to containers of unions.

  For example, this transforms
    Union[list[int], list[float]]
  to
    list[Union[int, float]]
  .
  Tuples and callables are first rewritten into a single-parameter generic
  form when needed (see _should_merge), so that they can be combined too.
  """

  # For each special container type: the names under which its homogeneous
  # (plain GenericType) flavor may appear in a union.
  _CONTAINER_NAMES = {
      pytd.TupleType: ("builtins.tuple", "typing.Tuple"),
      pytd.CallableType: ("typing.Callable",),
  }

  def _key(self, t):
    """Returns the grouping key for a generic type `t`.

    Tuples and callables only group with others that have the same base type
    AND the same number of parameters; other generics group by base type.
    """
    if isinstance(t, (pytd.CallableType, pytd.TupleType)):
      return (t.base_type, len(t.parameters))
    else:
      return t.base_type

  def _should_merge(self, pytd_type, union):
    """Determine whether pytd_type values in the union should be merged.

    If the union contains the homogeneous flavor of pytd_type (e.g.,
    GenericType(base_type=tuple) when pytd_type is TupleType), or pytd_type
    values of different lengths, we want to turn all of the pytd_type values
    into homogeneous ones so that they can be merged into a single container.

    Args:
      pytd_type: The pytd type, either TupleType or CallableType.
      union: a pytd.UnionType

    Returns:
      True if the pytd_type values should be merged, False otherwise.
    """
    names = self._CONTAINER_NAMES[pytd_type]
    length = None
    for t in union.type_list:
      if isinstance(t, pytd_type):
        # Two instances with different parameter counts -> merge.
        if length is None:
          length = len(t.parameters)
        elif length != len(t.parameters):
          return True
      elif isinstance(t, pytd.GenericType) and t.name in names:
        # The homogeneous flavor is already present -> merge.
        return True
    return False

  def VisitUnionType(self, union):
    """Push unions down into containers.

    This collects similar container types in unions and merges them into
    single instances with the union type pushed down to the element_type level.

    Arguments:
      union: A pytd.UnionType instance. Might appear in a parameter, a return
        type, a constant type, etc.

    Returns:
      Either `union` (possibly flattened) when no two members share a key, or
      the pytd_utils.JoinTypes of the rebuilt members, which need not be a
      UnionType.
    """
    if not any(isinstance(t, pytd.GenericType) for t in union.type_list):
      # Optimization: If we're not going to change anything, return original.
      return union
    union = pytd_utils.JoinTypes(union.type_list)  # flatten
    if not isinstance(union, pytd.UnionType):
      # Flattening left a single type; re-wrap it so the code below can treat
      # it uniformly as a union.
      union = pytd.UnionType((union,))
    merge_tuples = self._should_merge(pytd.TupleType, union)
    merge_callables = self._should_merge(pytd.CallableType, union)
    if merge_tuples or merge_callables:
      # Rewrite tuples / callables into one plain GenericType shape so that
      # the grouping step below can combine them.
      type_list = []
      for t in union.type_list:
        if merge_tuples and isinstance(t, pytd.TupleType):
          # All element types are joined into one single parameter.
          t = pytd.GenericType(base_type=t.base_type,
                               parameters=(pytd_utils.JoinTypes(t.parameters),))
        elif merge_callables and isinstance(t, pytd.CallableType):
          # Argument types are replaced by Any; only `t.ret` (presumably the
          # return type) is kept.
          t = pytd.GenericType(base_type=t.base_type,
                               parameters=(pytd.AnythingType(), t.ret))
        type_list.append(t)
      union = union.Replace(type_list=tuple(type_list))
    # First pass: for generics sharing a key, join their parameters pairwise
    # by position.
    collect = {}
    has_redundant_base_types = False
    for t in union.type_list:
      if isinstance(t, pytd.GenericType):
        key = self._key(t)
        if key in collect:
          has_redundant_base_types = True
          # Note: zip() stops at the shorter of the two parameter lists.
          collect[key] = tuple(
              pytd_utils.JoinTypes([p1, p2])
              for p1, p2 in zip(collect[key], t.parameters))
        else:
          collect[key] = t.parameters
    if not has_redundant_base_types:
      return union
    # Second pass: rebuild the union. Each key is emitted once, at the
    # position of its first occurrence, with the joined parameters; other
    # members are kept as they are.
    result = pytd.NothingType()
    done = set()
    for t in union.type_list:
      if isinstance(t, pytd.GenericType):
        key = self._key(t)
        if key in done:
          continue  # already added
        parameters = collect[key]
        # Joining may have produced new unions of containers inside the
        # parameters, so run this optimization on them as well.
        add = t.Replace(parameters=tuple(p.Visit(CombineContainers())
                                         for p in parameters))
        done.add(key)
      else:
        add = t
      result = pytd_utils.JoinTypes([result, add])
    return result


class Factorize(visitors.Visitor):
  """Opposite of ExpandSignatures. Factorizes cartesian products of functions.

  For example, this transforms
    def f(x: int, y: int)
    def f(x: int, y: float)
    def f(x: float, y: int)
    def f(x: float, y: float)
  to
    def f(x: Union[int, float], y: Union[int, float])
  """

  def _GroupByOmittedArg(self, signatures, i):
    """Group signatures that are identical if you ignore one argument's type.

    Arguments:
      signatures: A sequence of function signatures.
      i: The index of the argument whose type is ignored during comparison.

    Returns:
      An iterable of (signature, types) pairs in first-seen order. For a
      group, "signature" has the type of argument i set to None, and "types"
      is the list of types argument i was found to have. Signatures that
      don't have argument i, or whose argument i is mutated, are represented
      as (original, None).
    """
    groups = collections.OrderedDict()
    for sig in signatures:
      if i >= len(sig.params):
        # We can't omit argument i, because this signature has too few
        # arguments. Represent this signature as (original, None).
        groups[sig] = None
        continue
      if sig.params[i].mutated_type is not None:
        # We can't group mutable parameters. Leave this signature alone.
        groups[sig] = None
        continue

      # Set type of parameter i to None, so it doesn't affect the grouping.
      params = list(sig.params)
      param_i = params[i]
      params[i] = param_i.Replace(type=None)

      stripped_signature = sig.Replace(params=tuple(params))
      existing = groups.get(stripped_signature)
      if existing:
        existing.append(param_i.type)
      else:
        groups[stripped_signature] = [param_i.type]
    return groups.items()

  def VisitFunction(self, f):
    """Shrink a function, by factorizing cartesian products of arguments.

    Greedily groups signatures, looking at the arguments from left to right.
    This algorithm is *not* optimal. But it does the right thing for the
    typical cases.

    Arguments:
      f: An instance of pytd.Function. If this function has more
          than one signature, we will try to combine some of these signatures by
          introducing union types.

    Returns:
      A new, potentially optimized, instance of pytd.Function.
    """
    # default=0: a function without signatures has nothing to factorize, and
    # a bare max() would raise ValueError on the empty sequence.
    max_argument_count = max((len(s.params) for s in f.signatures), default=0)
    signatures = f.signatures

    for i in range(max_argument_count):
      new_sigs = []
      for sig, types in self._GroupByOmittedArg(signatures, i):
        if types:
          # One or more options for argument <i>: put back their union.
          new_params = list(sig.params)
          new_params[i] = sig.params[i].Replace(
              type=pytd_utils.JoinTypes(types))
          sig = sig.Replace(params=tuple(new_params))
        # Otherwise the signature could not be grouped on argument <i> (too
        # few arguments, or a mutated parameter) and is kept as is.
        new_sigs.append(sig)
      signatures = new_sigs

    return f.Replace(signatures=tuple(signatures))


class SuperClassHierarchy:
  """Utility class for optimizations working with superclasses.

  Args:
    superclasses: A dict mapping a class name to the names of its (direct)
      superclasses. Names missing from the dict are treated as having no
      superclasses / subclasses.
  """

  def __init__(self, superclasses):
    self._superclasses = superclasses
    # Inverse mapping: class name -> names of its direct subclasses.
    self._subclasses = utils.invert_dict(self._superclasses)

  def GetSuperClasses(self):
    """Returns the superclass mapping this hierarchy was built from."""
    return self._superclasses

  def _CollectSuperclasses(self, type_name, collect):
    """Collect a type and all its (transitive) super classes.

    Arguments:
      type_name: A string, the type's name.
      collect: A set() of strings, modified to contain type_name and all its
        superclasses.
    """
    pending = [type_name]
    while pending:
      name = pending.pop()
      if name in collect:
        # Already expanded: a shared ancestor (diamond) or a cycle. Skipping
        # avoids repeated work and guarantees termination.
        continue
      collect.add(name)
      pending.extend(self._superclasses.get(name, ()))

  def ExpandSuperClasses(self, t):
    """Generate a set of all (known) superclasses for a type.

    Arguments:
      t: A type name. E.g. "int".

    Returns:
      A set of types. This set includes t as well as all its superclasses. For
      example, this will return "bool", "int" and "object" for "bool".
    """
    superclasses = set()
    self._CollectSuperclasses(t, superclasses)
    return superclasses

  def ExpandSubClasses(self, t):
    """Generate a set of all (known) subclasses for a type.

    Arguments:
      t: A type name. E.g. "int".

    Returns:
      A set of types. This set includes t as well as all its subclasses. For
      example, this will return "int" and "bool" for "int".
    """
    queue = [t]
    seen = set()
    while queue:
      item = queue.pop()
      if item not in seen:
        seen.add(item)
        # .get() so that unknown names neither raise nor (should the mapping
        # be a defaultdict) insert empty entries as a side effect of a query.
        queue.extend(self._subclasses.get(item, ()))
    return seen

  def HasSubClassInSet(self, cls, known):
    """Queries whether a direct subclass of a type is present in a given set."""
    return any(sub in known
               for sub in self._subclasses.get(cls, ()))

  def HasSuperClassInSet(self, cls, known):
    """Queries whether a direct superclass of a type is present in a set.

    Unknown classes have no superclasses, consistent with
    ExpandSuperClasses, so this returns False instead of raising KeyError.
    """
    return any(sup in known
               for sup in self._superclasses.get(cls, ()))


class SimplifyUnionsWithSuperclasses(visitors.Visitor):
  """Simplify Unions with superclasses.

  E.g., this changes
    Union[int, bool]
  to
    int
  since bool is a subclass of int.

  (Interpreting types as "sets of values", this simplification is sound since
   A union B = A, if B is a subset of A.)
  """

  def __init__(self, hierarchy):
    super().__init__()
    self.hierarchy = hierarchy

  def VisitUnionType(self, union):
    # For every class name: how many distinct union members it equals or is a
    # subclass of.
    coverage = collections.Counter()
    for member in set(union.type_list):
      if isinstance(member, pytd.GENERIC_BASE_TYPE):
        coverage.update(self.hierarchy.ExpandSubClasses(str(member)))
    # A member counted more than once is also covered by another member and
    # can go. Counter yields 0 for names it never saw, which is usually the
    # case for members that aren't GENERIC_BASE_TYPE (e.g. container types),
    # so those are kept.
    survivors = [member for member in union.type_list
                 if coverage[str(member)] <= 1]
    return pytd_utils.JoinTypes(survivors)


class FindCommonSuperClasses(visitors.Visitor):
  """Find common super classes. Optionally also uses abstract base classes.

  E.g., this changes
    def f(x: Union[list, tuple], y: Union[frozenset, set]) -> Union[int, float]
  to
    def f(x: Sequence, y: Set) -> Real
  """

  def __init__(self, hierarchy):
    super().__init__()
    self.hierarchy = hierarchy

  def VisitUnionType(self, union):
    """Given a union type, try to find a simplification by using superclasses.

    This is a lossy optimization that tries to map a list of types to a common
    base type. For example, int and bool are both subclasses of int, so it
    would convert "Union[int, bool]" to "int".

    Arguments:
      union: A union type.

    Returns:
      A simplified type, or `union` itself if its members share no known
      superclass.
    """
    expand = self.hierarchy.ExpandSuperClasses
    common = expand(str(union.type_list[0]))
    for other in union.type_list[1:]:
      common.intersection_update(expand(str(other)))

    # Keep only the most specific shared classes: anything with a subclass
    # also in `common` (e.g. "object") is redundant.
    leaves = tuple(pytd.NamedType(name) for name in common
                   if not self.hierarchy.HasSubClassInSet(name, common))
    if leaves:
      return pytd_utils.JoinTypes(leaves)
    return union  # The types don't intersect; leave them alone.


class CollapseLongUnions(visitors.Visitor):
  """Shortens long unions to Any.

  Poor man's version of FindCommonSuperClasses. Shortens types like
  "str or unicode or int or float or list" to just Any.

  Additionally, if the union already contains Any, it is re-joined with
  pytd_utils.JoinTypes, which may collapse the entire union to Any.

  Attributes:
    max_length: The maximum number of types to allow in a union. If there are
      more types than this, it is shortened.
  """

  def __init__(self, max_length=7):
    assert isinstance(max_length, int)
    super().__init__()
    # The type that over-long unions are replaced with.
    self.generic_type = pytd.AnythingType()
    self.max_length = max_length

  def VisitUnionType(self, union):
    members = union.type_list
    if len(members) > self.max_length:
      return self.generic_type
    if self.generic_type not in members:
      return union
    return pytd_utils.JoinTypes(members)


class AdjustGenericType(visitors.Visitor):
  """Replaces every ClassType equal to "builtins.object" with Any."""

  def __init__(self):
    super().__init__()
    self.old_generic_type = pytd.ClassType("builtins.object")
    self.new_generic_type = pytd.AnythingType()

  def VisitClassType(self, t):
    return self.new_generic_type if t == self.old_generic_type else t


class AdjustReturnAndConstantGenericType(visitors.Visitor):
  """Changes "object" to "Any" in return and constant types.

  Only a signature's return_type and a constant's type are rewritten; this
  visitor does not touch parameter types.
  """

  def VisitSignature(self, sig):
    adjusted = sig.return_type.Visit(AdjustGenericType())
    return sig.Replace(return_type=adjusted)

  def VisitConstant(self, c):
    adjusted = c.type.Visit(AdjustGenericType())
    return c.Replace(type=adjusted)


class AddInheritedMethods(visitors.Visitor):
  """Copy methods and constants from base classes into their derived classes.

  E.g. this changes
      class Bar:
        [methods and constants of Bar]
      class Foo(Bar):
        [methods and constants of Foo]
  to
      class Bar:
        [methods and constants of Bar]
      class Foo(Bar):
        [methods and constants of Foo]
        [methods and constants of Bar]
  .
  This is not an optimization by itself, but it can help with other
  optimizations (like signature merging), and is also useful as preprocessor
  for type matching.
  """

  def VisitLateType(self, _):
    raise NotImplementedError("Can't use AddInheritedMethods with LateType.")

  def VisitClass(self, cls):
    """Add superclass methods and constants to this Class.

    Members defined in `cls` win over inherited ones. When several direct
    bases define a member with the same name, the one from the earliest base
    in `cls.bases` wins. This is a shallow approximation of the MRO: only the
    member lists of the direct bases' resolved classes are consulted.

    Raises:
      AssertionError: If any base is still a pytd.NamedType (unresolved AST).
    """
    if any(isinstance(base, pytd.NamedType) for base in cls.bases):
      raise AssertionError("AddInheritedMethods needs a resolved AST")
    # Filter out only the types we can reason about.
    bases = [base.cls
             for base in cls.bases
             if isinstance(base, pytd.ClassType)]
    # Names already taken. Local methods and constants override base class
    # members, and earlier bases override later ones.
    taken = {m.name for m in cls.methods} | {c.name for c in cls.constants}
    adjust_self = visitors.AdjustSelf(force=True)
    adjust_self.class_types.append(visitors.ClassAsType(cls))
    new_methods = list(cls.methods)
    new_constants = list(cls.constants)
    for base in bases:
      # Names contributed by this base only shadow *later* bases, so every
      # non-shadowed member of a single base is still copied.
      added = set()
      for m in base.methods:
        if m.name not in taken:
          new_methods.append(m.Visit(adjust_self))
          added.add(m.name)
      for c in base.constants:
        if c.name not in taken:
          new_constants.append(c)
          added.add(c.name)
      taken |= added
    return cls.Replace(methods=tuple(new_methods),
                       constants=tuple(new_constants))


class PullInMethodClasses(visitors.Visitor):
  """Simplifies classes with only a __call__ function to just a method.

  This transforms
    class Foo:
      m: Bar
    class Bar:
      def __call__(self: Foo, ...)
  to
    class Foo:
      def m(self, ...)
  .
  A class like Bar is dropped from the module when every reference to it was
  a constant that got converted into a method this way.
  """

  def __init__(self):
    super().__init__()
    # Outermost TypeDeclUnit; used to resolve NamedTypes by name.
    self._module = None
    # Class name -> number of ClassType/NamedType references to it seen.
    self._total_count = collections.defaultdict(int)
    # Class name -> number of constants of that type converted to methods.
    self._processed_count = collections.defaultdict(int)

  def _MaybeLookup(self, t):
    """Returns the definition `t` refers to, or None if it can't be found.

    Arguments:
      t: A pytd type.
    Returns:
      For a NamedType, the result of looking its name up in the module (None
      if no module has been entered yet or the name is unknown); for a
      ClassType, its `cls` attribute; otherwise None.
    """
    if isinstance(t, pytd.NamedType):
      if self._module is None:
        # Visiting outside any TypeDeclUnit: nothing to look the name up in.
        return None
      try:
        return self._module.Lookup(t.name)
      except KeyError:
        return None
    elif isinstance(t, pytd.ClassType):
      return t.cls
    else:
      return None

  def _HasSelf(self, sig):
    """True if a signature has a self parameter.

    This only checks for the name, since the type can be too many different
    things (type of the method, type of the base class, object, unknown etc.)
    and doesn't carry over to the simplified version, anyway.

    Arguments:
      sig: Function signature (instance of pytd.Signature)
    Returns:
      True if the signature has "self", False otherwise.
    """
    # bool() so an empty parameter tuple yields False rather than ().
    return bool(sig.params) and sig.params[0].name == "self"

  def _LookupIfSimpleCall(self, t):
    """Looks up the type if it is a class whose only method is "__call__".

    Returns:
      The pytd.Class, if every signature of its __call__ has a "self"
      parameter; None otherwise.
    """
    if not isinstance(t, (pytd.NamedType, pytd.ClassType)):
      # We only do this for simple types.
      return None
    cls = self._MaybeLookup(t)
    if not cls or not isinstance(cls, pytd.Class):
      # This is not a class or it doesn't exist, so assume it's not a method.
      return None
    if [f.name for f in cls.methods] != ["__call__"]:
      return None
    method, = cls.methods
    return cls if all(self._HasSelf(sig) for sig in method.signatures) else None

  def _CanDelete(self, cls):
    """Checks whether this class can be deleted.

    Returns whether all occurrences of this class as a type were due to
    constants we removed.

    Arguments:
      cls: A pytd.Class.
    Returns:
      True if we can delete this class.
    """
    if not self._processed_count[cls.name]:
      # Leave standalone classes alone. E.g. the pytd files in
      # stubs/builtins/ defines classes not used by anything else.
      return False
    return self._processed_count[cls.name] == self._total_count[cls.name]

  def EnterTypeDeclUnit(self, module):
    # Since modules are hierarchical, we enter TypeDeclUnits multiple times,
    # but we only want to record the top-level one.
    if self._module is None:
      self._module = module

  def VisitTypeDeclUnit(self, unit):
    return unit.Replace(classes=tuple(c for c in unit.classes
                                      if not self._CanDelete(c)))

  def VisitClassType(self, t):
    self._total_count[t.name] += 1
    return t

  def VisitNamedType(self, t):
    self._total_count[t.name] += 1
    return t

  def VisitClass(self, cls):
    """Visit a class, and change constants to methods where possible."""
    new_constants = []
    new_methods = list(cls.methods)
    adjust_self = visitors.AdjustSelf(force=True)
    adjust_self.class_types.append(visitors.ClassAsType(cls))
    for const in cls.constants:
      c = self._LookupIfSimpleCall(const.type)
      if c:
        # The constant's type is a __call__-only class: replace the constant
        # with a method carrying that __call__'s signatures, visited with
        # AdjustSelf for this class.
        signatures = c.methods[0].signatures
        self._processed_count[c.name] += 1
        new_method = pytd.Function(const.name, signatures, c.methods[0].kind)
        new_methods.append(new_method.Visit(adjust_self))
      else:
        new_constants.append(const)  # keep
    return cls.Replace(constants=tuple(new_constants),
                       methods=tuple(new_methods))


class AbsorbMutableParameters(visitors.Visitor):
  """Converts mutable parameters to unions. This is lossy.

  For example, this will change
    def f(x: list[int]):
      x = list[Union[int, float]]
  to
    def f(x: Union[list[int], list[Union[int, float]]])
  .
  (Use optimize.CombineContainers to then change x to list[Union[int, float]].)

  This also works for methods - it will then potentially change the type of
  "self". The resulting AST is temporary and needs careful handling.
  """

  def VisitParameter(self, p):
    if p.mutated_type is not None:
      # Fold the mutated type into the declared one and drop the mutation.
      combined = pytd_utils.JoinTypes([p.type, p.mutated_type])
      return p.Replace(type=combined, mutated_type=None)
    return p


class SimplifyContainers(visitors.Visitor):
  """Simplifies containers whose type parameters are all Any.

  For example, this will change
    def f() -> List[Any]
  to
    def f() -> list
  Note that we don't simplify TupleType or CallableType, since they have
  variable-length parameters, and the parameter length is meaningful even when
  the parameters are all Any.
  """

  def _Simplify(self, t):
    """Returns t's base type if no parameter of t is concrete, else t."""
    has_concrete_param = any(
        not isinstance(param, pytd.AnythingType) for param in t.parameters)
    return t if has_concrete_param else t.base_type

  def VisitGenericType(self, t):
    return self._Simplify(t)


class TypeParameterScope(visitors.Visitor):
  """Common superclass for optimizations that track type parameters.

  Keeps a stack of scopes. Each scope maps every type parameter visible at
  that point to the pytd.Class or pytd.Signature whose template declares it;
  inner declarations shadow outer ones.
  """

  def __init__(self):
    super().__init__()
    self.type_params_stack = [{}]

  def _PushTypeParamScope(self, owner):
    """Pushes a copy of the current scope extended with owner's template."""
    scope = dict(self.type_params_stack[-1])
    for item in owner.template:
      scope[item.type_param] = owner
    self.type_params_stack.append(scope)

  def _TypeParamOwner(self, type_param):
    """Returns the Class/Signature declaring type_param, or None."""
    return self.type_params_stack[-1].get(type_param)

  def EnterClass(self, cls):
    self._PushTypeParamScope(cls)

  def EnterSignature(self, sig):
    self._PushTypeParamScope(sig)

  def IsClassTypeParameter(self, type_param):
    return isinstance(self._TypeParamOwner(type_param), pytd.Class)

  def IsFunctionTypeParameter(self, type_param):
    return isinstance(self._TypeParamOwner(type_param), pytd.Signature)

  def LeaveClass(self, _):
    self.type_params_stack.pop()

  def LeaveSignature(self, _):
    self.type_params_stack.pop()


class MergeTypeParameters(TypeParameterScope):
  """Remove all function type parameters in a union with a class type param.

  For example, this will change
    class A(typing.Generic(T)):
      def append(self, Union[T, T2]) -> T2
  to
    class A(typing.Generic(T)):
      def append(self, T) -> T
  .
  Use this visitor after using AbsorbMutableParameters.

  As another example, the combination of AbsorbMutableParameters and
  MergeTypeParameters transforms
    class list(typing.Generic(T)):
      def append(self, v: T2) -> NoneType:
        self = Union[T, T2]
  to
    class list(typing.Generic(T')):
      def append(self, v: T') -> NoneType
  by creating a *new* template variable T' that propagates the
  mutations to the outermost level (in this example, T' = Union[T, T2])
  """

  def __init__(self):
    super().__init__()
    # While inside a signature: name of a function type parameter -> list of
    # type parameters it appeared in a union with. None outside signatures.
    self.type_param_union = None

  def _AppendNew(self, l1, l2):
    """Appends to l1 all items of l2 that are not already in l1 (by identity)."""
    # l1 and l2 are small (2-3 elements), so just use two loops.
    for e2 in l2:
      if not any(e1 is e2 for e1 in l1):
        l1.append(e2)

  def EnterSignature(self, sig):
    # Necessary because TypeParameterScope also defines this function
    super().EnterSignature(sig)
    assert self.type_param_union is None
    self.type_param_union = collections.defaultdict(list)

  def LeaveSignature(self, node):
    # Necessary because TypeParameterScope also defines this function
    super().LeaveSignature(node)
    self.type_param_union = None

  def VisitUnionType(self, u):
    # Record, for each function type parameter in this union, all the type
    # parameters it shares the union with.
    type_params = [t for t in u.type_list if isinstance(t, pytd.TypeParameter)]
    for t in type_params:
      if self.IsFunctionTypeParameter(t):
        self._AppendNew(self.type_param_union[t.name], type_params)
    return u

  def _AllContaining(self, type_param, seen=None):
    """Gets all type parameters that are in a union with the passed one.

    The relation is followed transitively; `seen` breaks cycles.

    Returns:
      A non-empty list whose first element is `type_param`.
    """
    if seen is None:
      seen = set()
    result = [type_param]
    for other in self.type_param_union[type_param.name]:
      if other in seen:
        continue  # break cycles
      seen.add(other)
      # The recursive result always starts with `other`, so `other` is added.
      self._AppendNew(result, self._AllContaining(other, seen))
    return result

  def _ReplaceByOuterIfNecessary(self, item, substitutions):
    """Potentially replace a function type param with a class type param.

    Args:
      item: A pytd.TemplateItem
      substitutions: A dictionary to update with what we replaced.
    Returns:
      Either [item] or [].
    """
    containing_union = self._AllContaining(item.type_param)
    class_type_parameters = [type_param
                             for type_param in containing_union
                             if self.IsClassTypeParameter(type_param)]
    if class_type_parameters:
      substitutions[item.type_param] = pytd_utils.JoinTypes(
          class_type_parameters)
      return []
    else:
      # It's a function type parameter that appears in a union with other
      # function type parameters.
      return [item]

  def VisitSignature(self, sig):
    new_template = []
    # Identity mapping for every type parameter in scope; entries to replace
    # are overwritten by _ReplaceByOuterIfNecessary.
    substitutions = {k: k for k in self.type_params_stack[-1]}
    for item in sig.template:
      new_template += self._ReplaceByOuterIfNecessary(item, substitutions)
    new_template = tuple(new_template)
    # Compare as tuples: a tuple never equals a list, which would make this
    # "nothing changed" shortcut unreachable.
    if tuple(sig.template) == new_template:
      return sig  # Nothing changed.
    return sig.Replace(template=new_template).Visit(
        visitors.ReplaceTypeParameters(substitutions)).Visit(SimplifyUnions())


def Optimize(node,
             builtins=None,
             lossy=False,
             use_abcs=False,
             max_union=7,
             remove_mutable=False,
             can_do_lookup=True):
  """Optimize a PYTD tree.

  Tries to shrink a PYTD tree by running a fixed pipeline of visitors.

  Arguments:
    node: A pytd node to be optimized. It won't be modified - this function
        will return a new node.
    builtins: Definitions of all of the external types in node. When given,
        enables the superclass-based simplifications and (together with
        can_do_lookup) a final class lookup against these definitions.
    lossy: Allow optimizations that change the meaning of the pytd.
    use_abcs: Use abstract base classes to represent unions like
        e.g. "Union[float, int]" as "Real".
    max_union: How many types we allow in a union before we simplify
        it to just "Any". A falsy value disables this step.
    remove_mutable: Whether to simplify mutable parameters to normal
        parameters.
    can_do_lookup: True: We're either allowed to try to resolve NamedType
        instances in the AST, or the AST is already resolved. False: Skip any
        optimizations that would require NamedTypes to be resolved.

  Returns:
    An optimized node.
  """
  # Clean-ups that need no outside information.
  for visitor in (NormalizeGenericSelfTypes(),
                  RemoveDuplicates(),
                  SimplifyUnions(),
                  CombineReturnsAndExceptions(),
                  Factorize(),
                  CombineContainers(),
                  SimplifyContainers()):
    node = node.Visit(visitor)

  if builtins:
    # Build a class hierarchy from both the external and the local classes.
    superclasses = builtins.Visit(visitors.ExtractSuperClassesByName())
    superclasses.update(node.Visit(visitors.ExtractSuperClassesByName()))
    if use_abcs:
      superclasses.update(abc_hierarchy.GetSuperClasses())
    hierarchy = SuperClassHierarchy(superclasses)
    node = node.Visit(SimplifyUnionsWithSuperclasses(hierarchy))
    if lossy:
      node = node.Visit(FindCommonSuperClasses(hierarchy))

  if max_union:
    node = node.Visit(CollapseLongUnions(max_union))
  node = node.Visit(AdjustReturnAndConstantGenericType())

  if remove_mutable:
    for visitor in (AbsorbMutableParameters(),
                    CombineContainers(),
                    MergeTypeParameters(),
                    visitors.AdjustSelf()):
      node = node.Visit(visitor)

  node = node.Visit(SimplifyContainers())
  if builtins and can_do_lookup:
    node = visitors.LookupClasses(node, builtins, ignore_late_types=True)
  return node