import builtins
import importlib
import importlib.machinery
import inspect
import io
import linecache
import os.path
import types
from contextlib import contextmanager
from pathlib import Path
from typing import Any, BinaryIO, Callable, cast, Dict, Iterable, List, Optional, Union
from weakref import WeakValueDictionary
import torch
from torch.serialization import _get_restore_location, _maybe_decode_ascii
from ._directory_reader import DirectoryReader
from ._importlib import (
_calc___package__,
_normalize_line_endings,
_normalize_path,
_resolve_name,
_sanity_check,
)
from ._mangling import demangle, PackageMangler
from ._package_unpickler import PackageUnpickler
from .file_structure_representation import _create_directory_from_file_list, Directory
from .glob_group import GlobPattern
from .importer import Importer
__all__ = ["PackageImporter"]
# This is a list of imports that are implicitly allowed even if they haven't
# been marked as extern. This is to work around the fact that Torch implicitly
# depends on numpy and package can't track it.
# https://github.com/pytorch/MultiPy/issues/46
IMPLICIT_IMPORT_ALLOWLIST: Iterable[str] = [
"numpy",
"numpy.core",
"numpy.core._multiarray_umath",
# FX GraphModule might depend on the builtins module, and users usually
# don't extern builtins, so it is allowed here by default.
"builtins",
]
class PackageImporter(Importer):
"""Importers allow you to load code written to packages by :class:`PackageExporter`.
Code is loaded in a hermetic way, using files from the package
rather than the normal python import system. This allows
for the packaging of PyTorch model code and data so that it can be run
on a server or used in the future for transfer learning.
The importer for packages ensures that code in the module can only be loaded from
within the package, except for modules explicitly listed as external during export.
The file ``extern_modules`` in the zip archive lists all the modules that a package externally depends on.
This prevents "implicit" dependencies where the package runs locally because it is importing
a locally-installed package, but then fails when the package is copied to another machine.
"""
"""The dictionary of already loaded modules from this package, equivalent to ``sys.modules`` but
local to this importer.
"""
modules: Dict[str, types.ModuleType]
def __init__(
    self,
    file_or_buffer: Union[
        str, torch._C.PyTorchFileReader, Path, os.PathLike, BinaryIO
    ],
    module_allowed: Callable[[str], bool] = lambda module_name: True,
):
    """Open ``file_or_buffer`` for importing. This checks that the imported package only requires modules
    allowed by ``module_allowed``

    Args:
        file_or_buffer: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
            a string, or an ``os.PathLike`` object containing a filename. A path that names a directory is
            read through ``DirectoryReader``; any other path is opened with ``torch._C.PyTorchFileReader``.
            An already constructed ``torch._C.PyTorchFileReader`` is used as-is.
        module_allowed (Callable[[str], bool], optional): A method to determine if a externally provided module
            should be allowed. Can be used to ensure packages loaded do not depend on modules that the server
            does not support. Defaults to allowing anything.

    Raises:
        ImportError: If the package will use a disallowed module.
    """
    torch._C._log_api_usage_once("torch.package.PackageImporter")

    self.zip_reader: Any
    if isinstance(file_or_buffer, torch._C.PyTorchFileReader):
        self.filename = "<pytorch_file_reader>"
        self.zip_reader = file_or_buffer
    elif isinstance(file_or_buffer, (str, os.PathLike)):
        # Accept every path-like object (pathlib.Path included), not only
        # pathlib.Path, so that the documented ``os.PathLike`` contract holds.
        self.filename = os.fspath(file_or_buffer)
        if not os.path.isdir(self.filename):
            self.zip_reader = torch._C.PyTorchFileReader(self.filename)
        else:
            self.zip_reader = DirectoryReader(self.filename)
    else:
        self.filename = "<binary>"
        self.zip_reader = torch._C.PyTorchFileReader(file_or_buffer)

    # ``root`` and ``modules`` must exist before externs/files are registered.
    self.root = _PackageNode(None)
    self.modules = {}
    self.extern_modules = self._read_extern()

    for extern_module in self.extern_modules:
        if not module_allowed(extern_module):
            raise ImportError(
                f"package '{file_or_buffer}' needs the external module '{extern_module}' "
                f"but that module has been disallowed"
            )
        self._add_extern(extern_module)

    for fname in self.zip_reader.get_all_records():
        self._add_file(fname)

    # Copy of the builtins namespace whose ``__import__`` is this importer's.
    self.patched_builtins = builtins.__dict__.copy()
    self.patched_builtins["__import__"] = self.__import__
    # Allow packaged modules to reference their PackageImporter
    self.modules["torch_package_importer"] = self  # type: ignore[assignment]

    self._mangler = PackageMangler()

    # used for reduce deserialization
    self.storage_context: Any = None
    self.last_map_location = None

    # used for torch.serialization._load
    self.Unpickler = lambda *args, **kwargs: PackageUnpickler(self, *args, **kwargs)
def import_module(self, name: str, package=None):
"""Load a module from the package if it hasn't already been loaded, and then return
the module. Modules are loaded locally
to the importer and will appear in ``self.modules`` rather than ``sys.modules``.
Args:
name (str): Fully qualified name of the module to load.
package ([type], optional): Unused, but present to match the signature of importlib.import_module. Defaults to ``None``.
Returns:
types.ModuleType: The (possibly already) loaded module.
"""
# We should always be able to support importing modules from this package.
# This is to support something like:
# obj = importer.load_pickle(...)
# importer.import_module(obj.__module__) <- this string will be mangled
#
# Note that _mangler.demangle will not demangle any module names
# produced by a different PackageImporter instance.
name = self._mangler.demangle(name)
return self._gcd_import(name)
def load_binary(self, package: str, resource: str) -> bytes:
"""Load raw bytes.
Args:
package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
resource (str): The unique name for the resource.
Returns:
bytes: The loaded data.
"""
path = self._zipfile_path(package, resource)
return self.zip_reader.get_record(path)
def load_text(
self,
package: str,
resource: str,
encoding: str = "utf-8",
errors: str = "strict",
) -> str:
"""Load a string.
Args:
package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
resource (str): The unique name for the resource.
encoding (str, optional): Passed to ``decode``. Defaults to ``'utf-8'``.
errors (str, optional): Passed to ``decode``. Defaults to ``'strict'``.
Returns:
str: The loaded text.
"""
data = self.load_binary(package, resource)
return data.decode(encoding, errors)
def load_pickle(self, package: str, resource: str, map_location=None) -> Any:
"""Unpickles the resource from the package, loading any modules that are needed to construct the objects
using :meth:`import_module`.
Args:
package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
resource (str): The unique name for the resource.
map_location: Passed to `torch.load` to determine how tensors are mapped to devices. Defaults to ``None``.
Returns:
Any: The unpickled object.
"""
pickle_file = self._zipfile_path(package, resource)
restore_location = _get_restore_location(map_location)
loaded_storages = {}
loaded_reduces = {}
storage_context = torch._C.DeserializationStorageContext()
def load_tensor(dtype, size, key, location, restore_location):
name = f"{key}.storage"
if storage_context.has_storage(name):
storage = storage_context.get_storage(name, dtype)._typed_storage()
else:
tensor = self.zip_reader.get_storage_from_record(
".data/" + name, size, dtype
)
if isinstance(self.zip_reader, torch._C.PyTorchFileReader):
storage_context.add_storage(name, tensor)
storage = tensor._typed_storage()
loaded_storages[key] = restore_location(storage, location)
def persistent_load(saved_id):
assert isinstance(saved_id, tuple)
typename = _maybe_decode_ascii(saved_id[0])
data = saved_id[1:]
if typename == "storage":
storage_type, key, location, size = data
dtype = storage_type.dtype
if key not in loaded_storages:
load_tensor(
dtype,
size,
key,
_maybe_decode_ascii(location),
restore_location,
)
storage = loaded_storages[key]
# TODO: Once we decide to break serialization FC, we can
# stop wrapping with TypedStorage
return torch.storage.TypedStorage(
wrap_storage=storage._untyped_storage, dtype=dtype, _internal=True
)
elif typename == "reduce_package":
# to fix BC breaking change, objects on this load path
# will be loaded multiple times erroneously
if len(data) == 2:
func, args = data
return func(self, *args)
reduce_id, func, args = data
if reduce_id not in loaded_reduces:
loaded_reduces[reduce_id] = func(self, *args)
return loaded_reduces[reduce_id]
else:
f"Unknown typename for persistent_load, expected 'storage' or 'reduce_package' but got '{typename}'"
# Load the data (which may in turn use `persistent_load` to load tensors)
data_file = io.BytesIO(self.zip_reader.get_record(pickle_file))
unpickler = self.Unpickler(data_file)
unpickler.persistent_load = persistent_load # type: ignore[assignment]
@contextmanager
def set_deserialization_context():
# to let reduce_package access deserializaiton context
self.storage_context = storage_context
self.last_map_location = map_location
try:
yield
finally:
self.storage_context = None
self.last_map_location = None
with set_deserialization_context():
result = unpickler.load()
# TODO from zdevito:
# This stateful weird function will need to be removed in our efforts
# to unify the format. It has a race condition if multiple python
# threads try to read independent files
torch._utils._validate_loaded_sparse_tensors()
return result
def id(self):
"""
Returns internal identifier that torch.package uses to distinguish :class:`PackageImporter` instances.
Looks like::
<torch_package_0>
"""
return self._mangler.parent_name()
def file_structure(
    self, *, include: "GlobPattern" = "**", exclude: "GlobPattern" = ()
) -> Directory:
    """Build a :class:`Directory` representation of the records in this package.

    Args:
        include (Union[List[str], str]): An optional string e.g. ``"my_package.my_subpackage"``, or optional list of strings
            for the names of the files to be included in the zipfile representation. This can also be
            a glob-style pattern, as described in :meth:`PackageExporter.mock`

        exclude (Union[List[str], str]): An optional pattern that excludes files whose name match the pattern.

    Returns:
        :class:`Directory`
    """
    archive_name = self.filename
    record_names = self.zip_reader.get_all_records()
    return _create_directory_from_file_list(
        archive_name, record_names, include, exclude
    )
def python_version(self):
"""Returns the version of python that was used to create this package.
Note: this function is experimental and not Forward Compatible. The plan is to move this into a lock
file later on.
Returns:
:class:`Optional[str]` a python version e.g. 3.8.9 or None if no version was stored with this package
"""
python_version_path = ".data/python_version"
return (
self.zip_reader.get_record(python_version_path).decode("utf-8").strip()
if self.zip_reader.has_record(python_version_path)
else None
)
def _read_extern(self):
return (
self.zip_reader.get_record(".data/extern_modules")
.decode("utf-8")
.splitlines(keepends=False)
)
def _make_module(
self, name: str, filename: Optional[str], is_package: bool, parent: str
):
mangled_filename = self._mangler.mangle(filename) if filename else None
spec = importlib.machinery.ModuleSpec(
name,
self, # type: ignore[arg-type]
origin="<package_importer>",
is_package=is_package,
)
module = importlib.util.module_from_spec(spec)
Loading ...