Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
pytype / pytd / printer.py
Size: Mime:
"""Printer to output pytd trees in pyi format."""

import collections
import copy
import logging
import re

from pytype import utils
from pytype.pytd import base_visitor
from pytype.pytd import pep484
from pytype.pytd import pytd
from pytype.pytd.parse import parser_constants


class PrintVisitor(base_visitor.Visitor):
  """Visitor for converting ASTs back to pytd source code."""
  # Visitor configuration; presumably read by base_visitor.Visitor, which is
  # defined outside this file -- confirm the exact semantics there.
  visits_all_node_types = True
  unchecked_node_names = base_visitor.ALL_NODE_NAMES

  # One level of indentation in the emitted source (four spaces).
  INDENT = " " * 4
  # parser_constants.RESERVED combined with parser_constants.RESERVED_PYTHON.
  _RESERVED = frozenset(parser_constants.RESERVED +
                        parser_constants.RESERVED_PYTHON)

  def __init__(self, multiline_args=False):
    super().__init__()
    self.class_names = []  # allow nested classes
    self.imports = collections.defaultdict(set)
    self.in_alias = False
    self.in_parameter = False
    self.in_literal = False
    self.in_constant = False
    self.in_signature = False
    self.multiline_args = multiline_args

    self._unit = None
    self._local_names = {}
    self._class_members = set()
    self._typing_import_counts = collections.defaultdict(int)
    self._module_aliases = {}

  def Print(self, node):
    return node.Visit(copy.deepcopy(self))

  def _IsEmptyTuple(self, t):
    """Return True if `t` is a pytd.TupleType that has no parameters."""
    assert isinstance(t, pytd.GenericType)
    if not isinstance(t, pytd.TupleType):
      return False
    return not t.parameters

  def _NeedsTupleEllipsis(self, t):
    """Decide whether `t` must be printed as Tuple[x, ...] rather than Tuple[x].

    A TupleType is always heterogeneous and never needs the ellipsis; any other
    generic type needs it when its base type is "tuple".
    """
    assert isinstance(t, pytd.GenericType)
    return (not isinstance(t, pytd.TupleType)) and t.base_type == "tuple"

  def _NeedsCallableEllipsis(self, t):
    """Return whether the generic type `t` is named typing.Callable."""
    assert isinstance(t, pytd.GenericType)
    is_callable = t.name == "typing.Callable"
    return is_callable

  def _RequireImport(self, module, name=None):
    """Register that we're using name from module.

    Args:
      module: string identifier.
      name: if None, means we want 'import module'. Otherwise string identifier
       that we want to import.
    """
    self.imports[module].add(name)

  def _GenerateImportStrings(self):
    """Generate import statements needed by the nodes we've visited so far.

    Returns:
      List of strings.
    """
    ret = []
    for module in sorted(self.imports):
      names = set(self.imports[module])
      if module == "typing":
        need_typing = False
        for (name, count) in self._typing_import_counts.items():
          if count:
            need_typing = True
          else:
            names.discard(name)
        if not need_typing:
          names.discard(None)
      if None in names:
        ret.append(f"import {module}")
        names.remove(None)

      if names:
        name_str = ", ".join(sorted(names))
        ret.append(f"from {module} import {name_str}")

    return ret

  def _IsBuiltin(self, module):
    return module == "builtins"

  def _FormatTypeParams(self, type_params):
    formatted_type_params = []
    for t in type_params:
      args = [f"'{t.name}'"]
      args += [self.Print(c) for c in t.constraints]
      if t.bound:
        args.append(f"bound={self.Print(t.bound)}")
      formatted_type_params.append(f"{t.name} = TypeVar({', '.join(args)})")
    return sorted(formatted_type_params)

  def _NameCollision(self, name):

    def name_in(members):
      return name in members or (
          self._unit and f"{self._unit.name}.{name}" in members)

    return name_in(self._class_members) or name_in(self._local_names)

  def _FromTyping(self, name):
    self._typing_import_counts[name] += 1
    if self._NameCollision(name):
      self._RequireImport("typing")
      return f"typing.{name}"
    else:
      self._RequireImport("typing", name)
      return name

  def _ImportTypingExtension(self, name):
    # `name` is a typing construct that is not supported in all Python versions.
    if self._local_names.get(name) == "alias":
      # A typing_extensions import is parsed as Alias(X, typing_extensions.X).
      # If we see an alias to `name`, assume it's been explicitly imported from
      # typing_extensions due to the current Python version not supporting it.
      return name
    else:
      return self._FromTyping(name)

  def _StripUnitPrefix(self, name):
    if self._unit:
      return utils.strip_prefix(name, f"{self._unit.name}.")
    else:
      return name

  def EnterTypeDeclUnit(self, unit):
    self._unit = unit
    for definitions, label in [(unit.classes, "class"),
                               (unit.functions, "function"),
                               (unit.constants, "constant"),
                               (unit.type_params, "type_param"),
                               (unit.aliases, "alias")]:
      for defn in definitions:
        self._local_names[defn.name] = label
    for alias in unit.aliases:
      # Modules are represented as NamedTypes in partially resolved asts and
      # sometimes as LateTypes in asts modified for pickling.
      if isinstance(alias.type, pytd.Module):
        module_name = alias.type.module_name
      elif isinstance(alias.type, (pytd.NamedType, pytd.LateType)):
        module_name = alias.type.name
      else:
        continue
      name = self._StripUnitPrefix(alias.name)
      self._module_aliases[module_name] = name

  def LeaveTypeDeclUnit(self, _):
    self._unit = None
    self._local_names = {}

  def VisitTypeDeclUnit(self, node):
    """Convert the AST for an entire module back to a string."""
    if node.type_params:
      self._FromTyping("TypeVar")

    aliases = []
    imports = set(self._GenerateImportStrings())
    for alias in node.aliases:
      if alias.startswith(("from ", "import ")):
        imports.add(alias)
      else:
        aliases.append(alias)

    # Sort import lines lexicographically and ensure import statements come
    # before from-import statements.
    imports = sorted(imports, key=lambda s: (s.startswith("from "), s))

    sections = [
        imports,
        aliases,
        node.constants,
        self._FormatTypeParams(self.old_node.type_params),
        node.classes,
        node.functions,
    ]

    # We put one blank line after every class,so we need to strip the blank line
    # after the last class.
    sections_as_string = ("\n".join(section_suite).rstrip()
                          for section_suite in sections
                          if section_suite)
    return "\n\n".join(sections_as_string)

  def EnterConstant(self, node):
    self.in_constant = True

  def LeaveConstant(self, node):
    self.in_constant = False

  def VisitConstant(self, node):
    """Convert a class-level or module-level constant to a string."""
    if self.in_literal:
      module, _, name = node.name.partition(".")
      assert module == "builtins", module
      assert name in ("True", "False"), name
      return name
    return f"{node.name}: {node.type}"

  def EnterAlias(self, _):
    self.old_imports = self.imports.copy()

  def VisitAlias(self, node):
    """Convert an import or alias to a string.

    Outside of constants and signatures, an alias to a dotted named type is
    printed as 'from module import name [as alias]' and the import bookkeeping
    is reset to the snapshot taken on entering the alias. Aliases to constants
    or functions are printed as that definition under the alias name, module
    aliases use the already printed module import, and everything else is
    printed as 'name = type'.
    """
    target = self.old_node.type
    in_value_context = self.in_constant or self.in_signature
    if not in_value_context and isinstance(
        target, (pytd.NamedType, pytd.ClassType, pytd.LateType)):
      module, _, name = target.name.rpartition(".")
      if module:
        alias_name = self._StripUnitPrefix(self.old_node.name)
        rename = "" if name in ("*", alias_name) else f" as {alias_name}"
        self.imports = self.old_imports  # undo unnecessary imports change
        return f"from {module} import {name}{rename}"
    elif isinstance(target, (pytd.Constant, pytd.Function)):
      return self.Print(target.Replace(name=node.name))
    elif isinstance(target, pytd.Module):
      return node.type
    return f"{node.name} = {node.type}"

  def EnterClass(self, node):
    """Entering a class - record class name for children's use."""
    n = node.name
    if node.template:
      n += "[{}]".format(", ".join(self.Print(t) for t in node.template))
    for member in node.methods + node.constants:
      self._class_members.add(member.name)
    self.class_names.append(n)
    # Class decorators are resolved to their underlying functions, but all we
    # output is '@{decorator.name}', so we do not want to visit the Function()
    # node and collect types etc. (In particular, we would add a spurious import
    # of 'Any' when generating a decorator for an InterpreterClass.)
    return {"decorators"}

  def LeaveClass(self, unused_node):
    self._class_members.clear()
    self.class_names.pop()

  def VisitClass(self, node):
    """Visit a class, producing a multi-line, properly indented string."""
    bases = node.bases
    # If object is the only base, we don't need to list any bases.
    if bases == ("object",):
      bases = ()
    if node.metaclass is not None:
      bases += ("metaclass=" + node.metaclass,)
    bases_str = f"({', '.join(bases)})" if bases else ""
    header = [f"class {node.name}{bases_str}:"]
    if node.slots is not None:
      slots_str = ", ".join(f"\"{s}\"" for s in node.slots)
      slots = [self.INDENT + f"__slots__ = [{slots_str}]"]
    else:
      slots = []
    decorators = ["@" + self.VisitNamedType(d)
                  for d in self.old_node.decorators]
    # Our handling of class decorators is a bit hacky (see output.py); this
    # makes sure that typing classes read in directly from a pyi file and then
    # reemitted (e.g. in assertTypesMatchPytd) have their required module
    # imports handled correctly.
    for d in self.old_node.decorators:
      if d.type.name.startswith("typing."):
        self.VisitNamedType(d.type)
    if node.classes or node.methods or node.constants or slots:
      # We have multiple methods, and every method has multiple signatures
      # (i.e., the method string will have multiple lines). Combine this into
      # an array that contains all the lines, then indent the result.
      class_lines = sum((m.splitlines() for m in node.classes), [])
      classes = [self.INDENT + m for m in class_lines]
      constants = [self.INDENT + m for m in node.constants]
      method_lines = sum((m.splitlines() for m in node.methods), [])
      methods = [self.INDENT + m for m in method_lines]
    else:
      header[-1] += " ..."
      constants = []
      classes = []
      methods = []
    lines = decorators + header + slots + classes + constants + methods
    return "\n".join(lines) + "\n"

  def VisitFunction(self, node):
    """Visit function, producing multi-line string (one for each signature).

    Every signature is printed as its own 'def' preceded by the same set of
    decorator lines. __new__ is not marked @staticmethod and __init_subclass__
    is not marked @classmethod.
    """
    name = node.name
    kind = node.kind
    decorator_lines = []
    if node.is_final:
      decorator_lines.append("@" + self._FromTyping("final"))
    if kind == pytd.MethodTypes.STATICMETHOD and name != "__new__":
      decorator_lines.append("@staticmethod")
    elif kind == pytd.MethodTypes.CLASSMETHOD and name != "__init_subclass__":
      decorator_lines.append("@classmethod")
    elif kind == pytd.MethodTypes.PROPERTY:
      decorator_lines.append("@property")
    if node.is_abstract:
      decorator_lines.append("@abstractmethod")
    if node.is_coroutine:
      decorator_lines.append("@coroutine")
    if len(node.signatures) > 1:
      decorator_lines.append("@" + self._FromTyping("overload"))
    decorators = "".join(line + "\n" for line in decorator_lines)
    return "\n".join(f"{decorators}def {name}{sig}" for sig in node.signatures)

  def _FormatContainerContents(self, node):
    """Print out the last type parameter of a container. Used for *args/**kw.

    Args:
      node: the pytd.Parameter for *args or **kwargs.

    Returns:
      The parameter printed as non-optional, typed with the container's last
      type parameter, or with Any if its type is not a generic type.
    """
    assert isinstance(node, pytd.Parameter)
    param_type = node.type
    if not isinstance(param_type, pytd.GenericType):
      return self.Print(node.Replace(type=pytd.AnythingType(), optional=False))
    container_name = param_type.name.rpartition(".")[2]
    assert container_name in ("tuple", "dict")
    counts = self._typing_import_counts
    # The container itself is not printed, so take back its typing use.
    counts[container_name.capitalize()] -= 1
    element_type = param_type.parameters[-1]
    # If the type is "Any", e.g. `**kwargs: Any`, decrement Any to avoid an
    # extraneous import of typing.Any. Any was visited before this function
    # transformed **kwargs, so it was incremented at least once already.
    if isinstance(element_type, pytd.AnythingType):
      counts["Any"] -= 1
    return self.Print(node.Replace(type=element_type, optional=False))

  def EnterSignature(self, node):
    self.in_signature = True

  def LeaveSignature(self, node):
    self.in_signature = False

  def VisitSignature(self, node):
    """Visit a signature, producing a string."""
    if node.return_type == "nothing":
      return_type = "NoReturn"  # a prettier alias for nothing
      self._FromTyping(return_type)
    else:
      return_type = node.return_type
    ret = f" -> {return_type}"

    # Put parameters in the right order:
    # (arg1, arg2, *args, kwonly1, kwonly2, **kwargs)
    if self.old_node.starargs is not None:
      starargs = self._FormatContainerContents(self.old_node.starargs)
    else:
      # We don't have explicit *args, but we might need to print "*", for
      # kwonly params.
      starargs = ""
    params = node.params
    for i, p in enumerate(params):
      if self.old_node.params[i].kwonly:
        assert all(p.kwonly for p in self.old_node.params[i:])
        params = params[0:i] + ("*"+starargs,) + params[i:]
        break
    else:
      if starargs:
        params += (f"*{starargs}",)
    if self.old_node.starstarargs is not None:
      starstarargs = self._FormatContainerContents(self.old_node.starstarargs)
      params += (f"**{starstarargs}",)

    body = []
    # Handle Mutable parameters
    # pylint: disable=no-member
    # (old_node is set in parse/node.py)
    mutable_params = [(p.name, p.mutated_type) for p in self.old_node.params
                      if p.mutated_type is not None]
    # pylint: enable=no-member
    for name, new_type in mutable_params:
      body.append("\n{indent}{name} = {new_type}".format(
          indent=self.INDENT, name=name, new_type=self.Print(new_type)))
    for exc in node.exceptions:
      body.append("\n{indent}raise {exc}()".format(indent=self.INDENT, exc=exc))
    if not body:
      body.append(" ...")

    if self.multiline_args:
      indent = "\n" + self.INDENT
      params = ",".join([indent + p for p in params])
      return "({params}\n){ret}:{body}".format(
          params=params, ret=ret, body="".join(body))
    else:
      params = ", ".join(params)
      return "({params}){ret}:{body}".format(
          params=params, ret=ret, body="".join(body))

  def EnterParameter(self, unused_node):
    assert not self.in_parameter
    self.in_parameter = True

  def LeaveParameter(self, unused_node):
    assert self.in_parameter
    self.in_parameter = False

  def VisitParameter(self, node):
    """Convert a function parameter to a string."""
    suffix = " = ..." if node.optional else ""
    if node.type == "Any":
      # Abbreviated form. "Any" is the default.
      self._typing_import_counts["Any"] -= 1
      return node.name + suffix
    # For parameterized class, for example: ClsName[T, V].
    # Its name is `ClsName` before `[`.
    elif node.name == "self" and self.class_names and (
        self.class_names[-1].split("[")[0] == node.type.split("[")[0]):
      if "[" in node.type:
        elided = node.type.split("[", 1)[-1]
        for k in self._typing_import_counts:
          if re.search(r"(^|\W)%s($|\W)" % k, elided):
            self._typing_import_counts[k] -= 1
      return node.name + suffix
    elif node.name == "cls" and self.class_names and (
        node.type == "Type[%s]" % self.class_names[-1]):
      self._typing_import_counts["Type"] -= 1
      return node.name + suffix
    elif node.type is None:
      logging.warning("node.type is None")
      return node.name
    else:
      return node.name + ": " + node.type + suffix

  def VisitTemplateItem(self, node):
    """Convert a template to a string."""
    return node.type_param

  def _UseExistingModuleAlias(self, name):
    prefix, suffix = name.rsplit(".", 1)
    while prefix:
      if prefix in self._module_aliases:
        return f"{self._module_aliases[prefix]}.{suffix}"
      prefix, _, remainder = prefix.rpartition(".")
      suffix = f"{remainder}.{suffix}"
    return None

  def _GuessModule(self, maybe_module):
    """Guess which part of the given name is the module prefix."""
    if "." not in maybe_module:
      return maybe_module, ""
    prefix, suffix = maybe_module.rsplit(".", 1)
    # Heuristic: modules are typically lowercase, classes uppercase.
    if suffix[0].islower():
      return maybe_module, ""
    else:
      module, rest = self._GuessModule(prefix)
      return module, f"{rest}.{suffix}" if rest else suffix

  def VisitNamedType(self, node):
    """Convert a type to a string."""
    prefix, _, suffix = node.name.rpartition(".")
    if self._IsBuiltin(prefix) and not self._NameCollision(suffix):
      node_name = suffix
    elif prefix == "typing":
      node_name = self._FromTyping(suffix)
    elif "." not in node.name:
      node_name = node.name
    else:
      if self._unit:
        try:
          pytd.LookupItemRecursive(self._unit, self._StripUnitPrefix(node.name))
        except KeyError:
          aliased_name = self._UseExistingModuleAlias(node.name)
          if aliased_name:
            node_name = aliased_name
          else:
            module, rest = self._GuessModule(prefix)
            module_alias = module
            while self._NameCollision(module_alias):
              module_alias = f"_{module_alias}"
            if module_alias == module:
              self._RequireImport(module)
              node_name = node.name
            else:
              self._RequireImport(f"{module} as {module_alias}")
              node_name = ".".join(filter(bool, (module_alias, rest, suffix)))
        else:
          node_name = node.name
      else:
        node_name = node.name
    if node_name == "NoneType":
      # PEP 484 allows this special abbreviation.
      return "None"
    else:
      return node_name

  def VisitLateType(self, node):
    return self.VisitNamedType(node)

  def VisitClassType(self, node):
    return self.VisitNamedType(node)

  def VisitStrictType(self, node):
    # 'StrictType' is defined, and internally used, by booleq.py. We allow it
    # here so that booleq.py can use pytd_utils.Print().
    return self.VisitNamedType(node)

  def VisitAnythingType(self, unused_node):
    """Convert an anything type to a string."""
    return self._FromTyping("Any")

  def VisitNothingType(self, unused_node):
    """Convert the nothing type to a string."""
    return "nothing"

  def VisitTypeParameter(self, node):
    return node.name

  def VisitModule(self, node):
    if self.in_constant or self.in_signature:
      return "module"
    elif not node.is_aliased:
      return f"import {node.module_name}"
    elif "." in node.module_name:
      # `import x.y as z` and `from x import y as z` are equivalent, but the
      # latter is a bit prettier.
      prefix, suffix = node.module_name.rsplit(".", 1)
      imp = f"from {prefix} import {suffix}"
      if node.name != suffix:
        imp += f" as {node.name}"
      return imp
    else:
      return f"import {node.module_name} as {node.name}"

  def MaybeCapitalize(self, name):
    """Capitalize a generic type, if necessary.

    If pep484.PEP484_MaybeCapitalize yields a capitalized name, that name is
    registered as a typing import and returned; otherwise `name` is returned
    unchanged.
    """
    typing_name = pep484.PEP484_MaybeCapitalize(name)
    if not typing_name:
      return name
    return self._FromTyping(typing_name)

  def VisitGenericType(self, node):
    """Convert a generic type to a string of the form 'Base[p1, p2, ...]'.

    Empty tuples get '()' as their only parameter, homogeneous tuples get a
    trailing '...', and the first parameter of a typing.Callable generic is
    replaced by '...'.
    """
    params = node.parameters
    if self._IsEmptyTuple(node):
      params = ("()",)
    elif self._NeedsTupleEllipsis(node):
      params = params + ("...",)
    elif self._NeedsCallableEllipsis(self.old_node):
      # Callable[Any, X] is rewritten to Callable[..., X].
      self._typing_import_counts["Any"] -= 1
      params = ("...",) + params[1:]
    base = self.MaybeCapitalize(node.base_type)
    return f"{base}[{', '.join(params)}]"

  def VisitCallableType(self, node):
    """Convert a callable type to 'Base[[arg1, arg2, ...], ret]'."""
    callable_name = self.MaybeCapitalize(node.base_type)
    arg_list = ", ".join(node.args)
    return "{}[[{}], {}]".format(callable_name, arg_list, node.ret)

  def VisitTupleType(self, node):
    """Convert a tuple type to a string, exactly like a generic type."""
    printed = self.VisitGenericType(node)
    return printed

  def VisitUnionType(self, node):
    """Convert a union type ("x or y") to a string."""
    type_list = self._FormSetTypeList(node)
    return self._BuildUnion(type_list)

  def VisitIntersectionType(self, node):
    """Convert a intersection type ("x and y") to a string."""
    type_list = self._FormSetTypeList(node)
    return self._BuildIntersection(type_list)

  def _FormSetTypeList(self, node):
    """Form list of types within a set type."""
    type_list = collections.OrderedDict.fromkeys(node.type_list)
    if self.in_parameter:
      # Parameter's set types are merged after as a follow up to the
      # ExpandCompatibleBuiltins visitor.
      for compat, name in pep484.COMPAT_ITEMS:
        # name can replace compat.
        if compat in type_list and name in type_list:
          del type_list[compat]
    return type_list

  def _BuildUnion(self, type_list):
    """Builds a union of the types in type_list.

    Args:
      type_list: A list of strings representing types.

    Returns:
      A string representing the union of the types in type_list. Simplifies
      Union[X] to X and Union[X, None] to Optional[X].
    """
    # Collect all literals, so we can print them using the Literal[x1, ..., xn]
    # syntactic sugar.
    literals = []
    new_type_list = []
    for t in type_list:
      match = re.fullmatch(r"Literal\[(?P<content>.*)\]", t)
      if match:
        literals.append(match.group("content"))
      else:
        new_type_list.append(t)
    if literals:
      new_type_list.append("Literal[%s]" % ", ".join(literals))
    if len(new_type_list) == 1:
      return new_type_list[0]
    elif "None" in new_type_list:
      return (self._FromTyping("Optional") + "[" +
              self._BuildUnion(t for t in new_type_list if t != "None") + "]")
    else:
      return self._FromTyping("Union") + "[" + ", ".join(new_type_list) + "]"

  def _BuildIntersection(self, type_list):
    """Builds a intersection of the types in type_list.

    Args:
      type_list: A list of strings representing types.

    Returns:
      A string representing the intersection of the types in type_list.
      Simplifies Intersection[X] to X and Intersection[X, None] to Optional[X].
    """
    type_list = tuple(type_list)
    if len(type_list) == 1:
      return type_list[0]
    else:
      return " and ".join(type_list)

  def EnterLiteral(self, _):
    assert not self.in_literal
    self.in_literal = True

  def LeaveLiteral(self, _):
    assert self.in_literal
    self.in_literal = False

  def VisitLiteral(self, node):
    base = self._ImportTypingExtension("Literal")
    return f"{base}[{node.value}]"

  def VisitAnnotated(self, node):
    base = self._ImportTypingExtension("Annotated")
    annotations = ", ".join(node.annotations)
    return f"{base}[{node.base_type}, {annotations}]"