import builtins
import importlib
import importlib.machinery
import importlib.util
import inspect
import io
import linecache
import os.path
import types
from contextlib import contextmanager
from pathlib import Path
from typing import Any, BinaryIO, Callable, cast, Dict, Iterable, List, Optional, Union
from weakref import WeakValueDictionary

import torch
from torch.serialization import _get_restore_location, _maybe_decode_ascii

from ._directory_reader import DirectoryReader
from ._importlib import (
    _calc___package__,
    _normalize_line_endings,
    _normalize_path,
    _resolve_name,
    _sanity_check,
)
from ._mangling import demangle, PackageMangler
from ._package_unpickler import PackageUnpickler
from .file_structure_representation import _create_directory_from_file_list, Directory
from .glob_group import GlobPattern
from .importer import Importer

__all__ = ["PackageImporter"]


# This is a list of imports that are implicitly allowed even if they haven't
# been marked as extern. This is to work around the fact that Torch implicitly
# depends on numpy and torch.package can't track it.
# https://github.com/pytorch/MultiPy/issues/46
IMPLICIT_IMPORT_ALLOWLIST: Iterable[str] = [
    "numpy",
    "numpy.core",
    "numpy.core._multiarray_umath",
    # FX GraphModule might depend on the builtins module, and users usually
    # don't extern builtins. Here we allow it implicitly by default.
    "builtins",
]


class PackageImporter(Importer):
    """Importers allow you to load code written to packages by :class:`PackageExporter`.
    Code is loaded in a hermetic way, using files from the package
    rather than the normal Python import system. This allows
    for the packaging of PyTorch model code and data so that it can be run
    on a server or used in the future for transfer learning.

    The importer for packages ensures that code in the module can only be loaded from
    within the package, except for modules explicitly listed as external during export.
    The file ``extern_modules`` in the zip archive lists all the modules that a package externally depends on.
    This prevents "implicit" dependencies where the package runs locally because it is importing
    a locally-installed package, but then fails when the package is copied to another machine.
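
    Example (a minimal, illustrative sketch; ``"package.pt"`` and the module/resource
    names are hypothetical and assume an archive created with :class:`PackageExporter`)::

        >>> importer = PackageImporter("package.pt")
        >>> module = importer.import_module("my_package.my_module")
        >>> obj = importer.load_pickle("my_package", "model.pkl")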
    """

    """The dictionary of already loaded modules from this package, equivalent to ``sys.modules`` but
    local to this importer.
    """

    modules: Dict[str, types.ModuleType]

    def __init__(
        self,
        file_or_buffer: Union[str, torch._C.PyTorchFileReader, Path, BinaryIO],
        module_allowed: Callable[[str], bool] = lambda module_name: True,
    ):
        """Open ``file_or_buffer`` for importing. This checks that the imported package only requires modules
        allowed by ``module_allowed``

        Args:
            file_or_buffer: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
                a string, or an ``os.PathLike`` object containing a filename.
            module_allowed (Callable[[str], bool], optional): A method to determine if an externally provided module
                should be allowed. Can be used to ensure packages loaded do not depend on modules that the server
                does not support. Defaults to allowing anything.

        Raises:
            ImportError: If the package will use a disallowed module.
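
        Example (a hypothetical sketch of restricting extern modules; the file name and
        allowlist contents are illustrative only)::

            >>> allowed_prefixes = {"torch", "numpy", "builtins"}
            >>> importer = PackageImporter(
            ...     "package.pt",
            ...     module_allowed=lambda name: name.split(".")[0] in allowed_prefixes,
            ... )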
        """
        torch._C._log_api_usage_once("torch.package.PackageImporter")

        self.zip_reader: Any
        if isinstance(file_or_buffer, torch._C.PyTorchFileReader):
            self.filename = "<pytorch_file_reader>"
            self.zip_reader = file_or_buffer
        elif isinstance(file_or_buffer, (Path, str)):
            self.filename = str(file_or_buffer)
            if not os.path.isdir(self.filename):
                self.zip_reader = torch._C.PyTorchFileReader(self.filename)
            else:
                self.zip_reader = DirectoryReader(self.filename)
        else:
            self.filename = "<binary>"
            self.zip_reader = torch._C.PyTorchFileReader(file_or_buffer)

        self.root = _PackageNode(None)
        self.modules = {}
        self.extern_modules = self._read_extern()

        for extern_module in self.extern_modules:
            if not module_allowed(extern_module):
                raise ImportError(
                    f"package '{file_or_buffer}' needs the external module '{extern_module}' "
                    f"but that module has been disallowed"
                )
            self._add_extern(extern_module)

        for fname in self.zip_reader.get_all_records():
            self._add_file(fname)

        self.patched_builtins = builtins.__dict__.copy()
        self.patched_builtins["__import__"] = self.__import__
        # Allow packaged modules to reference their PackageImporter
        self.modules["torch_package_importer"] = self  # type: ignore[assignment]

        self._mangler = PackageMangler()

        # used for reduce deserialization
        self.storage_context: Any = None
        self.last_map_location = None

        # used for torch.serialization._load
        self.Unpickler = lambda *args, **kwargs: PackageUnpickler(self, *args, **kwargs)

    def import_module(self, name: str, package=None):
        """Load a module from the package if it hasn't already been loaded, and then return
        the module. Modules are loaded locally
        to the importer and will appear in ``self.modules`` rather than ``sys.modules``.

        Args:
            name (str): Fully qualified name of the module to load.
            package ([type], optional): Unused, but present to match the signature of importlib.import_module. Defaults to ``None``.

        Returns:
            types.ModuleType: The (possibly already) loaded module.
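
        Example (illustrative; the file and module names are hypothetical)::

            >>> importer = PackageImporter("package.pt")
            >>> mod = importer.import_module("my_package.my_module")
            >>> # The module lives in ``importer.modules``, not ``sys.modules``.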
        """
        # We should always be able to support importing modules from this package.
        # This is to support something like:
        #   obj = importer.load_pickle(...)
        #   importer.import_module(obj.__module__)  <- this string will be mangled
        #
        # Note that _mangler.demangle will not demangle any module names
        # produced by a different PackageImporter instance.
        name = self._mangler.demangle(name)

        return self._gcd_import(name)

    def load_binary(self, package: str, resource: str) -> bytes:
        """Load raw bytes.

        Args:
            package (str): The name of the module package (e.g. ``"my_package.my_subpackage"``).
            resource (str): The unique name for the resource.

        Returns:
            bytes: The loaded data.
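
        Example (illustrative; the file and resource names are hypothetical)::

            >>> importer = PackageImporter("package.pt")
            >>> raw = importer.load_binary("my_package", "weights.bin")
            >>> assert isinstance(raw, bytes)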
        """

        path = self._zipfile_path(package, resource)
        return self.zip_reader.get_record(path)

    def load_text(
        self,
        package: str,
        resource: str,
        encoding: str = "utf-8",
        errors: str = "strict",
    ) -> str:
        """Load a string.

        Args:
            package (str): The name of the module package (e.g. ``"my_package.my_subpackage"``).
            resource (str): The unique name for the resource.
            encoding (str, optional): Passed to ``decode``. Defaults to ``'utf-8'``.
            errors (str, optional): Passed to ``decode``. Defaults to ``'strict'``.

        Returns:
            str: The loaded text.
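
        Example (illustrative; the file and resource names are hypothetical)::

            >>> importer = PackageImporter("package.pt")
            >>> text = importer.load_text("my_package", "notes.txt", encoding="utf-8")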
        """
        data = self.load_binary(package, resource)
        return data.decode(encoding, errors)

    def load_pickle(self, package: str, resource: str, map_location=None) -> Any:
        """Unpickles the resource from the package, loading any modules that are needed to construct the objects
        using :meth:`import_module`.

        Args:
            package (str): The name of the module package (e.g. ``"my_package.my_subpackage"``).
            resource (str): The unique name for the resource.
            map_location: Passed to ``torch.load`` to determine how tensors are mapped to devices. Defaults to ``None``.

        Returns:
            Any: The unpickled object.
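
        Example (illustrative; assumes the exporter saved a pickle named ``model.pkl``
        under the hypothetical package ``my_package``)::

            >>> importer = PackageImporter("package.pt")
            >>> model = importer.load_pickle("my_package", "model.pkl", map_location="cpu")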
        """
        pickle_file = self._zipfile_path(package, resource)
        restore_location = _get_restore_location(map_location)
        loaded_storages = {}
        loaded_reduces = {}
        storage_context = torch._C.DeserializationStorageContext()

        def load_tensor(dtype, size, key, location, restore_location):
            name = f"{key}.storage"

            if storage_context.has_storage(name):
                storage = storage_context.get_storage(name, dtype)._typed_storage()
            else:
                tensor = self.zip_reader.get_storage_from_record(
                    ".data/" + name, size, dtype
                )
                if isinstance(self.zip_reader, torch._C.PyTorchFileReader):
                    storage_context.add_storage(name, tensor)
                storage = tensor._typed_storage()
            loaded_storages[key] = restore_location(storage, location)

        def persistent_load(saved_id):
            assert isinstance(saved_id, tuple)
            typename = _maybe_decode_ascii(saved_id[0])
            data = saved_id[1:]

            if typename == "storage":
                storage_type, key, location, size = data
                dtype = storage_type.dtype

                if key not in loaded_storages:
                    load_tensor(
                        dtype,
                        size,
                        key,
                        _maybe_decode_ascii(location),
                        restore_location,
                    )
                storage = loaded_storages[key]
                # TODO: Once we decide to break serialization FC, we can
                # stop wrapping with TypedStorage
                return torch.storage.TypedStorage(
                    wrap_storage=storage._untyped_storage, dtype=dtype, _internal=True
                )
            elif typename == "reduce_package":
                # To avoid a BC-breaking change, the older two-element format is still
                # accepted; objects on that load path will erroneously be loaded multiple times.
                if len(data) == 2:
                    func, args = data
                    return func(self, *args)
                reduce_id, func, args = data
                if reduce_id not in loaded_reduces:
                    loaded_reduces[reduce_id] = func(self, *args)
                return loaded_reduces[reduce_id]
            else:
                raise RuntimeError(
                    f"Unknown typename for persistent_load, expected 'storage' or "
                    f"'reduce_package' but got '{typename}'"
                )

        # Load the data (which may in turn use `persistent_load` to load tensors)
        data_file = io.BytesIO(self.zip_reader.get_record(pickle_file))
        unpickler = self.Unpickler(data_file)
        unpickler.persistent_load = persistent_load  # type: ignore[assignment]

        @contextmanager
        def set_deserialization_context():
            # to let reduce_package access the deserialization context
            self.storage_context = storage_context
            self.last_map_location = map_location
            try:
                yield
            finally:
                self.storage_context = None
                self.last_map_location = None

        with set_deserialization_context():
            result = unpickler.load()

        # TODO from zdevito:
        #   This weird, stateful function will need to be removed in our efforts
        #   to unify the format. It has a race condition if multiple Python
        #   threads try to read independent files.
        torch._utils._validate_loaded_sparse_tensors()

        return result

    def id(self):
        """
        Returns the internal identifier that torch.package uses to distinguish :class:`PackageImporter` instances.
        Looks like::

            <torch_package_0>
        """
        return self._mangler.parent_name()

    def file_structure(
        self, *, include: "GlobPattern" = "**", exclude: "GlobPattern" = ()
    ) -> Directory:
        """Returns a file structure representation of package's zipfile.

        Args:
            include (Union[List[str], str]): An optional string (e.g. ``"my_package.my_subpackage"``) or list of strings
                for the names of the files to be included in the zipfile representation. This can also be
                a glob-style pattern, as described in :meth:`PackageExporter.mock`.

            exclude (Union[List[str], str]): An optional pattern that excludes files whose names match the pattern.

        Returns:
            :class:`Directory`
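
        Example (illustrative; the file name and glob patterns are hypothetical)::

            >>> importer = PackageImporter("package.pt")
            >>> structure = importer.file_structure(
            ...     include=["**/utils.py", "**/*.pkl"], exclude="**/*.storage"
            ... )
            >>> print(structure)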
        """
        return _create_directory_from_file_list(
            self.filename, self.zip_reader.get_all_records(), include, exclude
        )

    def python_version(self):
        """Returns the version of python that was used to create this package.

        Note: this function is experimental and not Forward Compatible. The plan is to move this into a lock
        file later on.

        Returns:
            :class:`Optional[str]` a python version e.g. 3.8.9 or None if no version was stored with this package
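
        Example (illustrative output)::

            >>> importer = PackageImporter("package.pt")  # hypothetical archive
            >>> importer.python_version()  # e.g. "3.8.9", or None if not recorded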
        """
        python_version_path = ".data/python_version"
        return (
            self.zip_reader.get_record(python_version_path).decode("utf-8").strip()
            if self.zip_reader.has_record(python_version_path)
            else None
        )

    def _read_extern(self):
        return (
            self.zip_reader.get_record(".data/extern_modules")
            .decode("utf-8")
            .splitlines(keepends=False)
        )

    def _make_module(
        self, name: str, filename: Optional[str], is_package: bool, parent: str
    ):
        mangled_filename = self._mangler.mangle(filename) if filename else None
        spec = importlib.machinery.ModuleSpec(
            name,
            self,  # type: ignore[arg-type]
            origin="<package_importer>",
            is_package=is_package,
        )
        module = importlib.util.module_from_spec(spec)