
import collections
import importlib.machinery
import io
import linecache
import pickletools
import platform
import types
from collections import defaultdict, OrderedDict
from dataclasses import dataclass
from enum import Enum
from importlib.machinery import SourceFileLoader
from pathlib import Path
from typing import (
    Any,
    BinaryIO,
    Callable,
    cast,
    DefaultDict,
    Dict,
    List,
    Optional,
    Sequence,
    Set,
    Union,
)

import torch
from torch.serialization import location_tag, normalize_storage_type
from torch.types import Storage
from torch.utils.hooks import RemovableHandle

from ._digraph import DiGraph
from ._importlib import _normalize_path
from ._mangling import demangle, is_mangled
from ._package_pickler import create_pickler
from ._stdlib import is_stdlib_module
from .find_file_dependencies import find_files_source_depends_on
from .glob_group import GlobGroup, GlobPattern
from .importer import Importer, OrderedImporter, sys_importer

__all__ = [
    "PackagingErrorReason",
    "EmptyMatchError",
    "PackagingError",
    "PackageExporter",
]

_gate_torchscript_serialization = True

ActionHook = Callable[["PackageExporter", str], None]


class _ModuleProviderAction(Enum):
    """Represents one of the actions that :class:`PackageExporter` can take on a module.

    See :meth:`PackageExporter.extern` and friends for a description of what the actions do.
    """

    INTERN = 1
    EXTERN = 2
    MOCK = 3
    DENY = 4
    # Special case: when a module is mocked, PackageExporter writes out a
    # `_mock` module that implements our mocking stubs. If we re-package code,
    # we may encounter a `_mock` module from the original package. If we do,
    # just ignore it and write a `_mock` module once.
    REPACKAGED_MOCK_MODULE = 5
    # Special case: PackageImporter adds a fake module
    # (`torch_package_importer`) that allows packaged code to access it. Don't
    # re-export this.
    SKIP = 6


class PackagingErrorReason(Enum):
    """Listing of different reasons a dependency may fail to package.

    This enum is used to provide good error messages when
    :class:`PackagingError` is raised.
    """

    def __repr__(self):
        return "<%s.%s>" % (self.__class__.__name__, self.name)

    IS_EXTENSION_MODULE = (
        "Module is a C extension module. torch.package supports Python modules only."
    )
    NO_DUNDER_FILE = "Module had no __file__ defined."
    SOURCE_FILE_NOT_FOUND = (
        "Module had a __file__, but we could not find it in your filesystem."
    )
    DEPENDENCY_RESOLUTION_FAILED = "Dependency resolution failed."
    NO_ACTION = (
        "Module did not match against any action pattern. Extern, mock, or intern it."
    )
    DENIED = "Module was denied by a pattern."
    MOCKED_BUT_STILL_USED = (
        "Module was mocked out, but is still being used in the package. "
        "Please intern or extern the mocked modules if objects are supposed to be in "
        "the package."
    )


@dataclass
class _PatternInfo:
    """Holds :class:`PackageExporter`-specific info about how to execute matches against"""

    # What action to take on a module that matches this pattern.
    action: _ModuleProviderAction
    # The value of `allow_empty` the user gave when specifying the pattern.
    allow_empty: bool
    # Whether this pattern has been matched during packaging.
    was_matched: bool

    def __init__(self, action, allow_empty):
        self.action = action
        self.allow_empty = allow_empty
        self.was_matched = False


class EmptyMatchError(Exception):
    """This is an exception that is thrown when a mock or extern is marked as
    ``allow_empty=False``, and is not matched with any module during packaging.
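
    Example (a minimal sketch; the file name and pattern are hypothetical, and
    ``close()`` is the exporter's finalization step, where patterns are validated)::

        >>> exporter = PackageExporter("package.pt")
        >>> exporter.extern("nonexistent_module.**", allow_empty=False)
        >>> exporter.close()  # nothing matched the pattern: raises EmptyMatchError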
    """

    pass


class PackagingError(Exception):
    """This exception is raised when there is an issue with exporting a package.
    ``PackageExporter`` will attempt to gather up all the errors and present
    them to you at once.
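
    Example (a minimal sketch; ``obj`` stands in for an object whose
    dependencies cannot be resolved)::

        >>> try:
        ...     with PackageExporter("package.pt") as exporter:
        ...         exporter.save_pickle("res", "obj.pkl", obj)
        ... except PackagingError as e:
        ...     print(e)  # human-readable report of all errors, grouped by reason
        ...     graph = e.dependency_graph  # a DiGraph, for programmatic tooling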
    """

    def __init__(self, dependency_graph: DiGraph, debug=False):
        # Group errors by reason.
        broken: Dict[PackagingErrorReason, List[str]] = defaultdict(list)
        for module_name, attrs in dependency_graph.nodes.items():
            error = attrs.get("error")
            if error is None:
                continue
            if error == PackagingErrorReason.NO_ACTION:
                assert "action" not in attrs
            broken[error].append(module_name)

        message = io.StringIO()
        message.write("\n")

        for reason, module_names in broken.items():
            message.write(f"* {reason.value}\n")
            for module_name in module_names:
                message.write(f"    {module_name}\n")

                # Print additional context if it's provided.
                error_context = dependency_graph.nodes[module_name].get("error_context")
                if error_context is not None:
                    message.write(f"      Context: {error_context}\n")
                if module_name in _DISALLOWED_MODULES:
                    message.write(
                        (
                            "      Note: While we usually use modules in the python standard library "
                            f"from the local environment, `{module_name}` has a lot of system "
                            "level access and therefore can pose a security risk. We heavily "
                            f"recommend removing `{module_name}` from your packaged code. However, if that "
                            "is not possible, add it to the extern list by calling "
                            f'PackageExporter.extern("`{module_name}`")\n'
                        )
                    )
                if debug:
                    module_path = dependency_graph.first_path(module_name)
                    message.write(
                        f"      A path to {module_name}: {' -> '.join(module_path)}\n"
                    )
        if not debug:
            message.write("\n")
            message.write(
                (
                    "Set debug=True when invoking PackageExporter for a visualization of where "
                    "broken modules are coming from!\n"
                )
            )
        # Save the dependency graph so that tooling can get at it.
        self.dependency_graph = dependency_graph
        super().__init__(message.getvalue())


class PackageExporter:
    """Exporters allow you to write packages of code, pickled Python data, and
    arbitrary binary and text resources into a self-contained package.

    Importers can load this code in a hermetic way, such that code is loaded
    from the package rather than the normal Python import system. This allows
    for the packaging of PyTorch model code and data so that it can be run
    on a server or used in the future for transfer learning.

    The code contained in packages is copied file-by-file from the original
    source when it is created, and the file format is a specially organized
    zip file. Future users of the package can unzip the package and edit the code
    in order to perform custom modifications to it.

    The importer for packages ensures that code in the module can only be loaded from
    within the package, except for modules explicitly listed as external using :meth:`extern`.
    The file ``extern_modules`` in the zip archive lists all the modules that a package externally depends on.
    This prevents "implicit" dependencies where the package runs locally because it is importing
    a locally-installed package, but then fails when the package is copied to another machine.

    When source code is added to the package, the exporter can optionally scan it
    for further code dependencies (``dependencies=True``). It looks for import statements,
    resolves relative references to qualified module names, and performs an action specified by the user
    (See: :meth:`extern`, :meth:`mock`, and :meth:`intern`).
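
    Example (a minimal sketch; ``my_package`` and ``my_model`` are hypothetical
    stand-ins for your own code and objects)::

        >>> with PackageExporter("package.pt") as exporter:
        ...     exporter.extern("numpy.**")       # load from the host environment at import time
        ...     exporter.mock("matplotlib.**")    # replace with a stub that errors if used
        ...     exporter.intern("my_package.**")  # copy the module source into the package
        ...     exporter.save_pickle("model", "model.pkl", my_model)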
    """

    """A importer that will be searched in order to find the modules referenced by other modules or by
    pickled objects. The default module environment just uses sys_importer, which searches the Python environment.
    """
    importer: Importer

    def __init__(
        self,
        f: Union[str, Path, BinaryIO],
        importer: Union[Importer, Sequence[Importer]] = sys_importer,
        debug: bool = False,
    ):
        """
        Create an exporter.

        Args:
            f: The location to export to. Can be a ``string``/``Path`` object containing a filename
                or a binary I/O object.
            importer: If a single Importer is passed, use that to search for modules.
                If a sequence of importers are passed, an ``OrderedImporter`` will be constructed out of them.
            debug: If set to True, add the path of broken modules to PackagingErrors.
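
        Example (a minimal sketch of re-packaging; ``"old_package.pt"`` is a
        hypothetical, previously exported package)::

            >>> from torch.package import PackageImporter
            >>> old_importer = PackageImporter("old_package.pt")
            >>> # Modules are looked up in the old package first, then in the local environment.
            >>> exporter = PackageExporter(
            ...     "new_package.pt", importer=(old_importer, sys_importer)
            ... )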
        """
        torch._C._log_api_usage_once("torch.package.PackageExporter")
        self.debug = debug
        if isinstance(f, (Path, str)):
            f = str(f)
            self.buffer: Optional[BinaryIO] = None
        else:  # is a byte buffer
            self.buffer = f

        self.zip_file = torch._C.PyTorchFileWriter(f)
        self.zip_file.set_min_version(6)
        self._written_files: Set[str] = set()

        self.serialized_reduces: Dict[int, Any] = {}

        # A graph tracking all the modules and pickle objects added to this
        # package and the dependencies between them.
        # - Each node is a module name (or a pickle name that looks like '<foo.obj.pkl>')
        # - Each directed edge (u, v) means u depends on v.
        # - Nodes may contain metadata that describe how to write the thing to the zipfile.
        self.dependency_graph = DiGraph()
        self.script_module_serializer = torch._C.ScriptModuleSerializer(self.zip_file)
        self.storage_context = self.script_module_serializer.storage_context()

        # These are OrderedDicts for compatibility with RemovableHandle.
        # Generic OrderedDict type annotations are not present until 3.7.
        # The real type signature is OrderedDict[int, Callable[[PackageExporter, str], None]]
        self._extern_hooks: OrderedDict = OrderedDict()
        self._mock_hooks: OrderedDict = OrderedDict()
        self._intern_hooks: OrderedDict = OrderedDict()

        if isinstance(importer, Importer):
            self.importer = importer
        else:
            if not isinstance(importer, collections.abc.Sequence):
                raise TypeError(
                    "importer arg should be an Importer or a sequence of Importers, "
                    f"got {type(importer)} instead."
                )
            self.importer = OrderedImporter(*importer)

        self.patterns: Dict[GlobGroup, _PatternInfo] = {}
        self._unique_id = 0

    def save_source_file(
        self, module_name: str, file_or_directory: str, dependencies=True
    ):
        """Adds the local file system ``file_or_directory`` to the source package to provide the code
        for ``module_name``.

        Args:
            module_name (str): e.g. ``"my_package.my_subpackage"``; the source will be saved to provide code for this module.
            file_or_directory (str): the path to a file or directory of code. When a directory, all Python files in the directory
                are recursively copied using :meth:`save_source_file`. If a file is named ``__init__.py``, the code is treated
                as a package.
            dependencies (bool, optional): If ``True``, we scan the source for dependencies.
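
        Example (a minimal sketch; the module names and paths are hypothetical)::

            >>> with PackageExporter("package.pt") as exporter:
            ...     # A single file provides the module ``foo``.
            ...     exporter.save_source_file("foo", "foo.py")
            ...     # A directory is copied recursively; ``__init__.py`` marks packages.
            ...     exporter.save_source_file("my_package", "path/to/my_package")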
        """
        path = Path(file_or_directory)
        if path.is_dir():
            to_save = []  # list of tuples with arguments to save_source_string
            module_path = module_name.replace(".", "/")
            for filename in path.glob("**/*.py"):
                relative_path = filename.relative_to(path).as_posix()
                archivename = module_path + "/" + relative_path
                submodule_name = None
                if filename.name == "__init__.py":
                    submodule_name = archivename[: -len("/__init__.py")].replace(
                        "/", "."
                    )
                    is_package = True
                else:
                    submodule_name = archivename[: -len(".py")].replace("/", ".")
                    is_package = False

                # we delay the call to save_source_string so that we record all the source files
                # being provided by this directory structure _before_ attempting to resolve the dependencies
                # on the source. This makes sure we don't try to copy over modules that will just get
                # overwritten by this directory blob
                to_save.append(
                    (
                        submodule_name,
                        _read_file(str(filename)),
                        is_package,
                        dependencies,
                    )
                )

            for item in to_save:
                self.save_source_string(*item)
        else:
            is_package = path.name == "__init__.py"
            self.save_source_string(
                module_name,
                _read_file(file_or_directory),
                is_package,
                dependencies,
            )

    def get_unique_id(self) -> str:
        """Get an id. This id is guaranteed to only be handed out once for this package."""
        ret = str(self._unique_id)
        self._unique_id += 1
        return ret

    def _get_dependencies(
        self, src: str, module_name: str, is_package: bool
    ) -> List[str]:
        """Return all modules that this source code depends on.

        Dependencies are found by scanning the source code for import-like statements.

        Args:
            src: The Python source code to analyze for dependencies.
            module_name: The name of the module that ``src`` corresponds to.