import collections
import importlib.machinery
import io
import linecache
import pickletools
import platform
import types
from collections import defaultdict, OrderedDict
from dataclasses import dataclass
from enum import Enum
from importlib.machinery import SourceFileLoader
from pathlib import Path
from typing import (
    Any,
    BinaryIO,
    Callable,
    cast,
    DefaultDict,
    Dict,
    List,
    Optional,
    Sequence,
    Set,
    Union,
)

import torch
from torch.serialization import location_tag, normalize_storage_type
from torch.types import Storage
from torch.utils.hooks import RemovableHandle

from ._digraph import DiGraph
from ._importlib import _normalize_path
from ._mangling import demangle, is_mangled
from ._package_pickler import create_pickler
from ._stdlib import is_stdlib_module
from .find_file_dependencies import find_files_source_depends_on
from .glob_group import GlobGroup, GlobPattern
from .importer import Importer, OrderedImporter, sys_importer

__all__ = [
    "PackagingErrorReason",
    "EmptyMatchError",
    "PackagingError",
    "PackageExporter",
]

_gate_torchscript_serialization = True

ActionHook = Callable[["PackageExporter", str], None]
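
# A minimal sketch of an ActionHook (hypothetical example, not part of the
# public API; any callable with this signature can be registered via the
# register_*_hook methods below):
#
#   def log_extern(exporter: "PackageExporter", module_name: str) -> None:
#       print(f"externed module: {module_name}")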


class _ModuleProviderAction(Enum):
    """Represents one of the actions that :class:`PackageExporter` can take on a module.

    See :meth:`PackageExporter.extern` and friends for a description of what the actions do.
    """

    INTERN = 1
    EXTERN = 2
    MOCK = 3
    DENY = 4
    # Special case: when a module is mocked, PackageExporter writes out a
    # `_mock` module that implements our mocking stubs. If we re-package code,
    # we may encounter a `_mock` module from the original package. If we do,
    # just ignore it and write a `_mock` module once.
    REPACKAGED_MOCK_MODULE = 5
    # Special case: PackageImporter adds a fake module
    # (`torch_package_importer`) that allows packaged code to access it. Don't
    # re-export this.
    SKIP = 6


class PackagingErrorReason(Enum):
    """Listing of different reasons a dependency may fail to package.

    This enum is used to provide good error messages when
    :class:`PackagingError` is raised.
    """

    def __repr__(self):
        return f"<{self.__class__.__name__}.{self.name}>"

    IS_EXTENSION_MODULE = (
        "Module is a C extension module. torch.package supports Python modules only."
    )
    NO_DUNDER_FILE = "Module had no __file__ defined."
    SOURCE_FILE_NOT_FOUND = (
        "Module had a __file__, but we could not find it in your filesystem."
    )
    DEPENDENCY_RESOLUTION_FAILED = "Dependency resolution failed."
    NO_ACTION = (
        "Module did not match against any action pattern. Extern, mock, or intern it."
    )
    DENIED = "Module was denied by a pattern."
    MOCKED_BUT_STILL_USED = (
        "Module was mocked out, but is still being used in the package. "
        "Please intern or extern the mocked modules if objects are supposed to be in "
        "the package."
    )


@dataclass
class _PatternInfo:
    """Holds :class:`PackageExporter`-specific info about how to execute matches against a glob pattern."""

    # What action to take on a module that matches this pattern.
    action: _ModuleProviderAction
    # The value of `allow_empty` the user gave when specifying the pattern.
    allow_empty: bool
    # Whether this pattern has been matched during packaging.
    was_matched: bool

    def __init__(self, action, allow_empty):
        self.action = action
        self.allow_empty = allow_empty
        self.was_matched = False


class EmptyMatchError(Exception):
    """This is an exception that is thrown when a mock or extern is marked as
    ``allow_empty=False``, and is not matched with any module during packaging.
    """

    pass


class PackagingError(Exception):
    """This exception is raised when there is an issue with exporting a package.
    ``PackageExporter`` will attempt to gather up all the errors and present
    them to you at once.
    """

    def __init__(self, dependency_graph: DiGraph, debug=False):
        # Group errors by reason.
        broken: Dict[PackagingErrorReason, List[str]] = defaultdict(list)
        for module_name, attrs in dependency_graph.nodes.items():
            error = attrs.get("error")
            if error is None:
                continue
            if error == PackagingErrorReason.NO_ACTION:
                assert "action" not in attrs
            broken[error].append(module_name)

        message = io.StringIO()
        message.write("\n")

        for reason, module_names in broken.items():
            message.write(f"* {reason.value}\n")
            for module_name in module_names:
                message.write(f"    {module_name}\n")

                # Print additional context if it's provided.
                error_context = dependency_graph.nodes[module_name].get("error_context")
                if error_context is not None:
                    message.write(f"      Context: {error_context}\n")
                if module_name in _DISALLOWED_MODULES:
                    message.write(
                        "      Note: While we usually use modules in the python standard library "
                        f"from the local environment, `{module_name}` has a lot of system "
                        "level access and therefore can pose a security risk. We heavily "
                        f"recommend removing `{module_name}` from your packaged code. However, if that "
                        "is not possible, add it to the extern list by calling "
                        f'PackageExporter.extern("`{module_name}`")\n'
                    )
                if debug:
                    module_path = dependency_graph.first_path(module_name)
                    message.write(
                        f"      A path to {module_name}: {' -> '.join(module_path)}"
                    )
        if not debug:
            message.write("\n")
            message.write(
                "Set debug=True when invoking PackageExporter for a visualization of where "
                "broken modules are coming from!\n"
            )
        # Save the dependency graph so that tooling can get at it.
        self.dependency_graph = dependency_graph
        super().__init__(message.getvalue())


class PackageExporter:
    """Exporters allow you to write packages of code, pickled Python data, and
    arbitrary binary and text resources into a self-contained package.

    Imports can load this code in a hermetic way, such that code is loaded
    from the package rather than the normal Python import system. This allows
    for the packaging of PyTorch model code and data so that it can be run
    on a server or used in the future for transfer learning.

    The code contained in packages is copied file-by-file from the original
    source when it is created, and the file format is a specially organized
    zip file. Future users of the package can unzip the package, and edit the code
    in order to perform custom modifications to it.

    The importer for packages ensures that code in the module can only be loaded from
    within the package, except for modules explicitly listed as external using :meth:`extern`.
    The file ``extern_modules`` in the zip archive lists all the modules that a package externally depends on.
    This prevents "implicit" dependencies where the package runs locally because it is importing
    a locally-installed package, but then fails when the package is copied to another machine.

    When source code is added to the package, the exporter can optionally scan it
    for further code dependencies (``dependencies=True``). It looks for import statements,
    resolves relative references to qualified module names, and performs an action specified by the user
    (See: :meth:`extern`, :meth:`mock`, and :meth:`intern`).
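
    Example (a minimal sketch; ``my_package`` and ``obj`` are illustrative
    names, not definitions from this module)::

        with PackageExporter("package.zip") as exporter:
            exporter.intern("my_package.**")
            exporter.extern("numpy.**")
            exporter.save_pickle("res", "obj.pkl", obj)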
"""
|
|
|
|
"""A importer that will be searched in order to find the modules referenced by other modules or by
|
|
pickled objects. The default module environment just uses sys_importer, which searches the Python environment.
|
|
"""
|
|
importer: Importer
|
|
|
|
    def __init__(
        self,
        f: Union[str, Path, BinaryIO],
        importer: Union[Importer, Sequence[Importer]] = sys_importer,
        debug: bool = False,
    ):
        """
        Create an exporter.

        Args:
            f: The location to export to. Can be a ``string``/``Path`` object containing a filename
                or a binary I/O object.
            importer: If a single Importer is passed, use that to search for modules.
                If a sequence of importers is passed, an ``OrderedImporter`` will be constructed out of them.
            debug: If set to ``True``, add the path of broken modules to the PackagingError.
        """
        torch._C._log_api_usage_once("torch.package.PackageExporter")
        self.debug = debug
        if isinstance(f, (Path, str)):
            f = str(f)
            self.buffer: Optional[BinaryIO] = None
        else:  # is a byte buffer
            self.buffer = f

        self.zip_file = torch._C.PyTorchFileWriter(f)
        self.zip_file.set_min_version(6)
        self._written_files: Set[str] = set()

        self.serialized_reduces: Dict[int, Any] = {}

        # A graph tracking all the modules and pickle objects added to this
        # package and the dependencies between them.
        # - Each node is a module name (or a pickle name that looks like '<foo.obj.pkl>')
        # - Each directed edge (u, v) means u depends on v.
        # - Nodes may contain metadata that describe how to write the thing to the zipfile.
        self.dependency_graph = DiGraph()
        self.script_module_serializer = torch._C.ScriptModuleSerializer(self.zip_file)
        self.storage_context = self.script_module_serializer.storage_context()

        # These are OrderedDicts for compatibility with RemovableHandle.
        # Generic OrderedDict type annotations are not present until 3.7.
        # The real type signature is OrderedDict[int, Callable[[PackageExporter, str], None]]
        self._extern_hooks: OrderedDict = OrderedDict()
        self._mock_hooks: OrderedDict = OrderedDict()
        self._intern_hooks: OrderedDict = OrderedDict()

        if isinstance(importer, Importer):
            self.importer = importer
        else:
            if not isinstance(importer, collections.abc.Sequence):
                raise TypeError(
                    "importer arg should be an Importer or a sequence of Importers, "
                    f"got {type(importer)} instead."
                )
            self.importer = OrderedImporter(*importer)

        self.patterns: Dict[GlobGroup, _PatternInfo] = {}
        self._unique_id = 0

    def save_source_file(
        self, module_name: str, file_or_directory: str, dependencies=True
    ):
        """Adds the local file system ``file_or_directory`` to the source package to provide the code
        for ``module_name``.

        Args:
            module_name (str): e.g. ``"my_package.my_subpackage"``, code will be saved to provide code for this package.
            file_or_directory (str): the path to a file or directory of code. When a directory, all Python files in the directory
                are recursively copied using :meth:`save_source_file`. If a file is named ``__init__.py`` the code is treated
                as a package.
            dependencies (bool, optional): If ``True``, we scan the source for dependencies.
        """
        path = Path(file_or_directory)
        if path.is_dir():
            to_save = []  # list of tuples with arguments to save_source_string
            module_path = module_name.replace(".", "/")
            for filename in path.glob("**/*.py"):
                relative_path = filename.relative_to(path).as_posix()
                archivename = module_path + "/" + relative_path
                submodule_name = None
                if filename.name == "__init__.py":
                    submodule_name = archivename[: -len("/__init__.py")].replace(
                        "/", "."
                    )
                    is_package = True
                else:
                    submodule_name = archivename[: -len(".py")].replace("/", ".")
                    is_package = False

                # we delay the call to save_source_string so that we record all the source files
                # being provided by this directory structure _before_ attempting to resolve the dependencies
                # on the source. This makes sure we don't try to copy over modules that will just get
                # overwritten by this directory blob
                to_save.append(
                    (
                        submodule_name,
                        _read_file(str(filename)),
                        is_package,
                        dependencies,
                    )
                )

            for item in to_save:
                self.save_source_string(*item)
        else:
            is_package = path.name == "__init__.py"
            self.save_source_string(
                module_name,
                _read_file(file_or_directory),
                is_package,
                dependencies,
            )

    def get_unique_id(self) -> str:
        """Get an id. This id is guaranteed to only be handed out once for this package."""
        ret = str(self._unique_id)
        self._unique_id += 1
        return ret

    def _get_dependencies(
        self, src: str, module_name: str, is_package: bool
    ) -> List[str]:
        """Return all modules that this source code depends on.

        Dependencies are found by scanning the source code for import-like statements.

        Arguments:
            src: The Python source code to analyze for dependencies.
            module_name: The name of the module that ``src`` corresponds to.
            is_package: Whether this module should be treated as a package.
                See :py:meth:`save_source_string` for more info.

        Returns:
            A list containing modules detected as direct dependencies in
            ``src``. The items in the list are guaranteed to be unique.
        """
        package_name = (
            module_name if is_package else module_name.rsplit(".", maxsplit=1)[0]
        )
        try:
            dep_pairs = find_files_source_depends_on(src, package_name)
        except Exception as e:
            self.dependency_graph.add_node(
                module_name,
                error=PackagingErrorReason.DEPENDENCY_RESOLUTION_FAILED,
                error_context=str(e),
            )
            return []

        # Use a dict to get uniquing but also deterministic order
        dependencies = {}
        for dep_module_name, dep_module_obj in dep_pairs:
            # handle the case where someone did something like `from pack import sub`
            # where `sub` is a submodule. In this case we don't have to save pack, just sub.
            # this ensures we don't pick up additional dependencies on pack.
            # However, in the case where `sub` is not a submodule but an object, then we do have
            # to save pack.
            if dep_module_obj is not None:
                possible_submodule = f"{dep_module_name}.{dep_module_obj}"
                if self._module_exists(possible_submodule):
                    dependencies[possible_submodule] = True
                    # we don't need to save `pack`
                    continue
            if self._module_exists(dep_module_name):
                dependencies[dep_module_name] = True

        return list(dependencies.keys())

    def save_source_string(
        self,
        module_name: str,
        src: str,
        is_package: bool = False,
        dependencies: bool = True,
    ):
        """Adds ``src`` as the source code for ``module_name`` in the exported package.

        Args:
            module_name (str): e.g. ``my_package.my_subpackage``, code will be saved to provide code for this package.
            src (str): The Python source code to save for this package.
            is_package (bool, optional): If ``True``, this module is treated as a package. Packages are allowed to have submodules
                (e.g. ``my_package.my_subpackage.my_subsubpackage``), and resources can be saved inside them. Defaults to ``False``.
            dependencies (bool, optional): If ``True``, we scan the source for dependencies.
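
        Example (sketch; the module name and source text are illustrative)::

            exporter.save_source_string(
                "my_module",
                "def foo():\n    return 1\n",
            )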
"""
|
|
self.dependency_graph.add_node(
|
|
module_name,
|
|
source=src,
|
|
is_package=is_package,
|
|
provided=True,
|
|
action=_ModuleProviderAction.INTERN,
|
|
)
|
|
|
|
if dependencies:
|
|
deps = self._get_dependencies(src, module_name, is_package)
|
|
|
|
for dep in deps:
|
|
self.dependency_graph.add_edge(module_name, dep)
|
|
self.add_dependency(dep)
|
|
|
|
    def _write_source_string(
        self,
        module_name: str,
        src: str,
        is_package: bool = False,
    ):
        """Write ``src`` as the source code for ``module_name`` in the zip archive.

        Arguments are otherwise the same as for :meth:`save_source_string`.
        """
        extension = "/__init__.py" if is_package else ".py"
        filename = module_name.replace(".", "/") + extension

        self._write(filename, src)

    def _import_module(self, module_name: str):
        try:
            return self.importer.import_module(module_name)
        except ModuleNotFoundError as e:
            if not is_mangled(module_name):
                raise
            msg = (
                f"Module not found: '{module_name}'. Make sure the PackageImporter that "
                "created this module is present in `self.importer`"
            )
            raise ModuleNotFoundError(msg) from None

    def _module_exists(self, module_name: str) -> bool:
        try:
            self._import_module(module_name)
            return True
        except Exception:
            return False

    def _get_source_of_module(self, module: types.ModuleType) -> Optional[str]:
        filename = None
        spec = getattr(module, "__spec__", None)
        if spec is not None:
            loader = getattr(spec, "loader", None)
            if loader is not None and isinstance(loader, SourceFileLoader):
                try:
                    filename = loader.get_filename(module.__name__)
                except ImportError:
                    pass
        if filename is None:
            filename = getattr(module, "__file__", None)
        if isinstance(filename, str) and filename.endswith(".py"):
            return "".join(linecache.getlines(filename, module.__dict__))
        return None

    def add_dependency(self, module_name: str, dependencies=True):
        """Given a module, add it to the dependency graph according to patterns
        specified by the user.
        """
        if (
            module_name in self.dependency_graph
            and self.dependency_graph.nodes[module_name].get("provided") is True
        ):
            return

        # Special case: PackageImporter provides a special module called
        # `torch_package_importer` that allows packaged modules to reference
        # their PackageImporter. We don't want to re-export this.
        if module_name == "torch_package_importer":
            self.dependency_graph.add_node(
                module_name,
                action=_ModuleProviderAction.SKIP,
                provided=True,
            )
            return

        if module_name == "_mock":
            self.dependency_graph.add_node(
                module_name,
                action=_ModuleProviderAction.REPACKAGED_MOCK_MODULE,
                provided=True,
            )
            return

        if self._can_implicitly_extern(module_name):
            self.dependency_graph.add_node(
                module_name, action=_ModuleProviderAction.EXTERN, provided=True
            )
            return

        for pattern, pattern_info in self.patterns.items():
            if pattern.matches(module_name):
                pattern_info.was_matched = True
                self.dependency_graph.add_node(
                    module_name, action=pattern_info.action, provided=True
                )

                if pattern_info.action == _ModuleProviderAction.DENY:
                    # Requiring a denied module just adds an error to the graph.
                    self.dependency_graph.add_node(
                        module_name, error=PackagingErrorReason.DENIED
                    )

                # If we are interning this module, we need to retrieve its
                # dependencies and package those as well.
                if pattern_info.action == _ModuleProviderAction.INTERN:
                    self._intern_module(module_name, dependencies)
                return

        # No patterns have matched. Explicitly add this as an error.
        self.dependency_graph.add_node(
            module_name, error=PackagingErrorReason.NO_ACTION
        )

    def save_module(self, module_name: str, dependencies=True):
        """Save the code for ``module`` into the package. Code for the module is resolved using the ``importers`` path to find the
        module object, and then using its ``__file__`` attribute to find the source code.

        Args:
            module_name (str): e.g. ``my_package.my_subpackage``, code will be saved to provide code
                for this package.
            dependencies (bool, optional): If ``True``, we scan the source for dependencies.
        """
        if not isinstance(module_name, str):
            raise TypeError(
                "save_module() expects a string input, did you perhaps mean to pass `__name__`?"
            )

        self._intern_module(module_name, dependencies)

    def _intern_module(
        self,
        module_name: str,
        dependencies: bool,
    ):
        """Adds the module to the dependency graph as an interned module,
        along with any metadata needed to write it out to the zipfile at serialization time.
        """
        module_obj = self._import_module(module_name)
        # Subtle: if the import above succeeded, either:
        #   1. The module name is not mangled, and this was just a regular import, or
        #   2. The module name is mangled, but one of the importers was able to
        #      recognize the mangling and import it.
        # Either way, it is now safe to demangle this name so that we don't
        # serialize the mangled version to the package.
        module_name = demangle(module_name)

        # Find dependencies of this module and require them as well.
        is_package = hasattr(module_obj, "__path__")
        source = self._get_source_of_module(module_obj)
        if source is None:
            # Couldn't find a source! Add it to our dependency graph as broken
            # and continue.
            filename = getattr(module_obj, "__file__", None)
            error_context = None
            if filename is None:
                packaging_error = PackagingErrorReason.NO_DUNDER_FILE
            elif filename.endswith(tuple(importlib.machinery.EXTENSION_SUFFIXES)):
                packaging_error = PackagingErrorReason.IS_EXTENSION_MODULE
            else:
                packaging_error = PackagingErrorReason.SOURCE_FILE_NOT_FOUND
                error_context = f"filename: {filename}"
            self.dependency_graph.add_node(
                module_name,
                action=_ModuleProviderAction.INTERN,
                is_package=is_package,
                error=packaging_error,
                error_context=error_context,
                provided=True,
            )
            return

        self.dependency_graph.add_node(
            module_name,
            action=_ModuleProviderAction.INTERN,
            is_package=is_package,
            source=source,
            provided=True,
        )

        if dependencies:
            deps = self._get_dependencies(source, module_name, is_package)
            for dep in deps:
                self.dependency_graph.add_edge(module_name, dep)
                self.add_dependency(dep)

    def save_pickle(
        self,
        package: str,
        resource: str,
        obj: Any,
        dependencies: bool = True,
        pickle_protocol: int = 3,
    ):
        """Save a python object to the archive using pickle. Equivalent to :func:`torch.save` but saving into
        the archive rather than a stand-alone file. Standard pickle does not save the code, only the objects.
        If ``dependencies`` is true, this method will also scan the pickled objects for which modules are required
        to reconstruct them and save the relevant code.

        To be able to save an object where ``type(obj).__name__`` is ``my_module.MyObject``,
        ``my_module.MyObject`` must resolve to the class of the object according to the ``importer`` order. When saving objects that
        have previously been packaged, the importer's ``import_module`` method will need to be present in the ``importer`` list
        for this to work.

        Args:
            package (str): The name of module package this resource should go in (e.g. ``"my_package.my_subpackage"``).
            resource (str): A unique name for the resource, used to identify it to load.
            obj (Any): The object to save, must be picklable.
            dependencies (bool, optional): If ``True``, we scan the source for dependencies.
            pickle_protocol (int, optional): The pickle protocol to use. Only protocols 3 and 4
                are supported. Defaults to ``3``.
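
        Example (sketch; ``my_obj`` stands in for any picklable object)::

            exporter.save_pickle("my_package", "my_obj.pkl", my_obj)
            # A later PackageImporter can recover the object with
            # importer.load_pickle("my_package", "my_obj.pkl").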
"""
|
|
|
|
assert (pickle_protocol == 4) or (
|
|
pickle_protocol == 3
|
|
), "torch.package only supports pickle protocols 3 and 4"
|
|
|
|
filename = self._filename(package, resource)
|
|
# Write the pickle data for `obj`
|
|
data_buf = io.BytesIO()
|
|
pickler = create_pickler(data_buf, self.importer, protocol=pickle_protocol)
|
|
pickler.persistent_id = self._persistent_id
|
|
pickler.dump(obj)
|
|
data_value = data_buf.getvalue()
|
|
mocked_modules = defaultdict(list)
|
|
name_in_dependency_graph = f"<{package}.{resource}>"
|
|
self.dependency_graph.add_node(
|
|
name_in_dependency_graph,
|
|
action=_ModuleProviderAction.INTERN,
|
|
provided=True,
|
|
is_pickle=True,
|
|
)
|
|
|
|
def _check_mocked_error(module: Optional[str], field: Optional[str]):
|
|
"""
|
|
checks if an object (field) comes from a mocked module and then adds
|
|
the pair to mocked_modules which contains mocked modules paired with their
|
|
list of mocked objects present in the pickle.
|
|
|
|
We also hold the invariant that the first user defined rule that applies
|
|
to the module is the one we use.
|
|
"""
|
|
|
|
assert isinstance(module, str)
|
|
assert isinstance(field, str)
|
|
if self._can_implicitly_extern(module):
|
|
return
|
|
for pattern, pattern_info in self.patterns.items():
|
|
if pattern.matches(module):
|
|
if pattern_info.action == _ModuleProviderAction.MOCK:
|
|
mocked_modules[module].append(field)
|
|
return
|
|
|
|
if dependencies:
|
|
all_dependencies = []
|
|
module = None
|
|
field = None
|
|
memo: DefaultDict[int, str] = defaultdict(None)
|
|
memo_count = 0
|
|
# pickletools.dis(data_value)
|
|
for opcode, arg, pos in pickletools.genops(data_value):
|
|
if pickle_protocol == 4:
|
|
if (
|
|
opcode.name == "SHORT_BINUNICODE"
|
|
or opcode.name == "BINUNICODE"
|
|
or opcode.name == "BINUNICODE8"
|
|
):
|
|
assert isinstance(arg, str)
|
|
module = field
|
|
field = arg
|
|
memo[memo_count] = arg
|
|
elif (
|
|
opcode.name == "LONG_BINGET"
|
|
or opcode.name == "BINGET"
|
|
or opcode.name == "GET"
|
|
):
|
|
assert isinstance(arg, int)
|
|
module = field
|
|
field = memo.get(arg, None)
|
|
elif opcode.name == "MEMOIZE":
|
|
memo_count += 1
|
|
elif opcode.name == "STACK_GLOBAL":
|
|
if module is None:
|
|
# If not module was passed on in the entries preceeding this one, continue.
|
|
continue
|
|
assert isinstance(module, str)
|
|
if module not in all_dependencies:
|
|
all_dependencies.append(module)
|
|
_check_mocked_error(module, field)
|
|
elif (
|
|
pickle_protocol == 3 and opcode.name == "GLOBAL"
|
|
): # a global reference
|
|
assert isinstance(arg, str)
|
|
module, field = arg.split(" ")
|
|
if module not in all_dependencies:
|
|
all_dependencies.append(module)
|
|
_check_mocked_error(module, field)
|
|
for module_name in all_dependencies:
|
|
self.dependency_graph.add_edge(name_in_dependency_graph, module_name)
|
|
|
|
""" If an object happens to come from a mocked module, then we collect these errors and spit them
|
|
out with the other errors found by package exporter.
|
|
"""
|
|
if module in mocked_modules:
|
|
assert isinstance(module, str)
|
|
fields = mocked_modules[module]
|
|
self.dependency_graph.add_node(
|
|
module_name,
|
|
action=_ModuleProviderAction.MOCK,
|
|
error=PackagingErrorReason.MOCKED_BUT_STILL_USED,
|
|
error_context=f"Object(s) '{fields}' from module `{module_name}` was mocked out during packaging "
|
|
f"but is being used in resource - `{resource}` in package `{package}`. ",
|
|
provided=True,
|
|
)
|
|
else:
|
|
self.add_dependency(module_name)
|
|
|
|
self._write(filename, data_value)
|
|
|
|
    def save_text(self, package: str, resource: str, text: str):
        """Save text data to the package.

        Args:
            package (str): The name of module package this resource should go in (e.g. ``"my_package.my_subpackage"``).
            resource (str): A unique name for the resource, used to identify it to load.
            text (str): The contents to save.
        """
        return self.save_binary(package, resource, text.encode("utf-8"))

    def save_binary(self, package, resource, binary: bytes):
        """Save raw bytes to the package.

        Args:
            package (str): The name of module package this resource should go in (e.g. ``"my_package.my_subpackage"``).
            resource (str): A unique name for the resource, used to identify it to load.
            binary (bytes): The data to save.
        """
        filename = self._filename(package, resource)
        self._write(filename, binary)

    def register_extern_hook(self, hook: ActionHook) -> RemovableHandle:
        """Registers an extern hook on the exporter.

        The hook will be called each time a module matches against an :meth:`extern` pattern.
        It should have the following signature::

            hook(exporter: PackageExporter, module_name: str) -> None

        Hooks will be called in order of registration.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                A handle that can be used to remove the added hook by calling
                ``handle.remove()``.
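
        Example (sketch; the hook body is illustrative)::

            handle = exporter.register_extern_hook(
                lambda exporter, module_name: print(f"externed {module_name}")
            )
            # ... export some modules ...
            handle.remove()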
"""
|
|
handle = RemovableHandle(self._extern_hooks)
|
|
self._extern_hooks[handle.id] = hook
|
|
return handle
|
|
|
|
    def register_mock_hook(self, hook: ActionHook) -> RemovableHandle:
        """Registers a mock hook on the exporter.

        The hook will be called each time a module matches against a :meth:`mock` pattern.
        It should have the following signature::

            hook(exporter: PackageExporter, module_name: str) -> None

        Hooks will be called in order of registration.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                A handle that can be used to remove the added hook by calling
                ``handle.remove()``.
        """
        handle = RemovableHandle(self._mock_hooks)
        self._mock_hooks[handle.id] = hook
        return handle

    def register_intern_hook(self, hook: ActionHook) -> RemovableHandle:
        """Registers an intern hook on the exporter.

        The hook will be called each time a module matches against an :meth:`intern` pattern.
        It should have the following signature::

            hook(exporter: PackageExporter, module_name: str) -> None

        Hooks will be called in order of registration.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                A handle that can be used to remove the added hook by calling
                ``handle.remove()``.
        """
        handle = RemovableHandle(self._intern_hooks)
        self._intern_hooks[handle.id] = hook
        return handle

    def intern(
        self,
        include: "GlobPattern",
        *,
        exclude: "GlobPattern" = (),
        allow_empty: bool = True,
    ):
        """Specify modules that should be packaged. A module must match some ``intern`` pattern in order to be
        included in the package and have its dependencies processed recursively.

        Args:
            include (Union[List[str], str]): A string e.g. ``"my_package.my_subpackage"``, or list of strings
                for the names of the modules to be interned. This can also be a glob-style pattern, as described in :meth:`mock`.

            exclude (Union[List[str], str]): An optional pattern that excludes some patterns that match the include string.

            allow_empty (bool): An optional flag that specifies whether the intern modules specified by this call
                to the ``intern`` method must be matched to some module during packaging. If an ``intern`` module glob
                pattern is added with ``allow_empty=False``, and :meth:`close` is called (either explicitly or via ``__exit__``)
                before any modules match that pattern, an exception is thrown. If ``allow_empty=True``, no such exception is thrown.
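
        Example (sketch; the patterns are illustrative)::

            exporter.intern("my_package.**", exclude=["my_package.tests.**"])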
"""
|
|
self.patterns[GlobGroup(include, exclude=exclude)] = _PatternInfo(
|
|
_ModuleProviderAction.INTERN, allow_empty
|
|
)
|
|
|
|
    def mock(
        self,
        include: "GlobPattern",
        *,
        exclude: "GlobPattern" = (),
        allow_empty: bool = True,
    ):
        """Replace some required modules with a mock implementation. Mocked modules will return a fake
        object for any attribute accessed from them. Because we copy file-by-file, the dependency resolution will sometimes
        find files that are imported by model files but whose functionality is never used
        (e.g. custom serialization code or training helpers).
        Use this function to mock this functionality out without having to modify the original code.

        Args:
            include (Union[List[str], str]): A string e.g. ``"my_package.my_subpackage"``, or list of strings
                for the names of the modules to be mocked out. Strings can also be a glob-style pattern
                string that may match multiple modules. Any required dependencies that match this pattern
                string will be mocked out automatically.

                Examples:
                    ``'torch.**'`` -- matches ``torch`` and all submodules of torch, e.g. ``'torch.nn'``
                    and ``'torch.nn.functional'``

                    ``'torch.*'`` -- matches ``'torch.nn'`` or ``'torch.functional'``, but not
                    ``'torch.nn.functional'``

            exclude (Union[List[str], str]): An optional pattern that excludes some patterns that match the include string.
                e.g. ``include='torch.**', exclude='torch.foo'`` will mock all torch packages except ``'torch.foo'``.
                Default: ``[]``.

            allow_empty (bool): An optional flag that specifies whether the mock implementation(s) specified by this call
                to the :meth:`mock` method must be matched to some module during packaging. If a mock is added with
                ``allow_empty=False``, and :meth:`close` is called (either explicitly or via ``__exit__``) and the mock has
                not been matched to a module used by the package being exported, an exception is thrown.
                If ``allow_empty=True``, no such exception is thrown.

        """
        self.patterns[GlobGroup(include, exclude=exclude)] = _PatternInfo(
            _ModuleProviderAction.MOCK, allow_empty
        )

    def extern(
        self,
        include: "GlobPattern",
        *,
        exclude: "GlobPattern" = (),
        allow_empty: bool = True,
    ):
        """Include ``module`` in the list of external modules the package can import.
        This will prevent dependency discovery from saving
        it in the package. The importer will load an external module directly from the standard import system.
        Code for extern modules must also exist in the process loading the package.

        Args:
            include (Union[List[str], str]): A string e.g. ``"my_package.my_subpackage"``, or list of strings
                for the names of the modules to be externed. This can also be a glob-style pattern, as
                described in :meth:`mock`.

            exclude (Union[List[str], str]): An optional pattern that excludes some patterns that match the
                include string.

            allow_empty (bool): An optional flag that specifies whether the extern modules specified by this call
                to the ``extern`` method must be matched to some module during packaging. If an extern module glob
                pattern is added with ``allow_empty=False``, and :meth:`close` is called (either explicitly or via
                ``__exit__``) before any modules match that pattern, an exception is thrown. If ``allow_empty=True``,
                no such exception is thrown.

        """
        self.patterns[GlobGroup(include, exclude=exclude)] = _PatternInfo(
            _ModuleProviderAction.EXTERN, allow_empty
        )

def deny(self, include: "GlobPattern", *, exclude: "GlobPattern" = ()):
|
|
"""Blocklist modules who names match the given glob patterns from the list of modules the package can import.
|
|
If a dependency on any matching packages is found, a :class:`PackagingError` is raised.
|
|
|
|
Args:
|
|
include (Union[List[str], str]): A string e.g. ``"my_package.my_subpackage"``, or list of strings
|
|
for the names of the modules to be externed. This can also be a glob-style pattern, as described in :meth:`mock`.
|
|
|
|
exclude (Union[List[str], str]): An optional pattern that excludes some patterns that match the include string.
|
|
"""
|
|
self.patterns[GlobGroup(include, exclude=exclude)] = _PatternInfo(
|
|
_ModuleProviderAction.DENY, allow_empty=True
|
|
)
|
|
|
|
    def _persistent_id(self, obj):
        if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
            storage: Storage
            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, we can
                # remove this case
                untyped_storage = obj._untyped_storage
                storage_type_str = obj.pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                storage = cast(Storage, untyped_storage)
                storage_numel = obj.size()

            elif isinstance(obj, torch.UntypedStorage):
                untyped_storage = obj
                storage = cast(Storage, untyped_storage)
                storage_type = normalize_storage_type(type(storage))
                storage_numel = storage.nbytes()
            else:
                raise RuntimeError(f"storage type not recognized: {type(obj)}")

            location = location_tag(storage)

            # serialize storage if not already written
            storage_present = self.storage_context.has_storage(storage)
            storage_id = self.storage_context.get_or_add_storage(storage)
            if not storage_present:
                if storage.device.type != "cpu":
                    storage = storage.cpu()
                num_bytes = storage.nbytes()
                self.zip_file.write_record(
                    f".data/{storage_id}.storage", storage, num_bytes
                )
            return ("storage", storage_type, storage_id, location, storage_numel)

        if hasattr(obj, "__reduce_package__"):
            if _gate_torchscript_serialization and isinstance(
                obj, torch.jit.RecursiveScriptModule
            ):
                raise Exception(
                    "Serializing ScriptModules directly into a package is a beta feature. "
                    "To use, set global "
                    "`torch.package.package_exporter._gate_torchscript_serialization` to `False`."
                )
            if self.serialized_reduces.get(id(obj)) is None:
                self.serialized_reduces[id(obj)] = (
                    "reduce_package",
                    id(obj),
                    *obj.__reduce_package__(self),
                )

            return self.serialized_reduces[id(obj)]

        return None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # If __exit__ was called because an exception was raised, we do not
        # attempt to finalize the package. Instead, control is returned to the
        # caller to continue raising the exception.
        if exc_type is not None:
            # Do the bare minimum to leave the open buffer in a valid state.
            self._finalize_zip()
            return

        self.close()

    def _write(self, filename, str_or_bytes):
        if filename in self._written_files:
            raise AssertionError(
                f"Tried to write file '{filename}', but it already exists in this archive. "
                "Please file a bug."
            )
        self._written_files.add(filename)

        if is_mangled(filename):
            raise AssertionError(
                f"Tried to save a torch.package'd module as '{filename}'. "
                "Directly saving torch.package'd modules is not allowed."
            )
        if isinstance(str_or_bytes, str):
            str_or_bytes = str_or_bytes.encode("utf-8")
        self.zip_file.write_record(filename, str_or_bytes, len(str_or_bytes))

    def _validate_dependency_graph(self):
        # 1. Check the graph for any errors inserted during dependency analysis.
        for attrs in self.dependency_graph.nodes.values():
            if "error" in attrs:
                raise PackagingError(self.dependency_graph, debug=self.debug)

        # 2. Check that all patterns for which allow_empty=False have been matched at least once.
        for pattern, pattern_info in self.patterns.items():
            if not pattern_info.allow_empty and not pattern_info.was_matched:
                raise EmptyMatchError(
                    f"Exporter did not match any modules to {pattern}, which was marked as allow_empty=False"
                )

    def _write_mock_file(self):
        if "_mock.py" not in self._written_files:
            mock_file = str(Path(__file__).parent / "_mock.py")
            self._write_source_string("_mock", _read_file(mock_file), is_package=False)

    def _execute_dependency_graph(self):
        """Takes a finalized dependency graph describing how to package all
        modules and executes it, writing to the ZIP archive.
        """
        self._validate_dependency_graph()

        extern_modules = []
        for module_name, attrs in self.dependency_graph.nodes.items():
            action = attrs["action"]

            if action == _ModuleProviderAction.EXTERN:
                for hook in self._extern_hooks.values():
                    hook(self, module_name)

                extern_modules.append(module_name)

            elif action == _ModuleProviderAction.MOCK:
                for hook in self._mock_hooks.values():
                    hook(self, module_name)

                self._write_mock_file()

                is_package = hasattr(self._import_module(module_name), "__path__")
                self._write_source_string(module_name, _MOCK_IMPL, is_package)

            elif action == _ModuleProviderAction.INTERN:
                for hook in self._intern_hooks.values():
                    hook(self, module_name)

                # The node in the dependency graph contains metadata that tells us
                # how to intern the module.
                if "provided" not in attrs:
                    raise AssertionError(
                        f"Module was marked `intern` but not provided: {module_name}"
                    )

                if attrs.get("is_pickle") is True:
                    # This node came from save_pickle, we don't need to write any source for it.
                    continue

                is_package = attrs["is_package"]
                source = attrs["source"]
                self._write_source_string(module_name, source, is_package)

            elif action == _ModuleProviderAction.REPACKAGED_MOCK_MODULE:
                self._write_mock_file()
            elif action == _ModuleProviderAction.SKIP:
                continue
            else:
                raise AssertionError(
                    f"Invalid action: {module_name}, {action}. Please report a bug to PyTorch."
                )

        extern_file_contents = "\n".join(extern_modules) + "\n"
        self._write(".data/extern_modules", extern_file_contents)

    def _write_python_version(self):
        """Writes the Python version that the package was created with to .data/python_version."""
        self._write(".data/python_version", platform.python_version())

    def close(self):
        """Write the package to the filesystem. Any calls after :meth:`close` are now invalid.
        It is preferable to use resource guard syntax instead::

            with PackageExporter("file.zip") as e:
                ...
        """
        self._execute_dependency_graph()
        self._write_python_version()

        self.script_module_serializer.write_files()
        self._finalize_zip()

    def _finalize_zip(self):
        """Called at the very end of packaging to leave the zipfile in a closed but valid state."""
        del self.zip_file
        if self.buffer:
            self.buffer.flush()

    def _filename(self, package, resource):
        package_path = package.replace(".", "/")
        resource = _normalize_path(resource)
        return f"{package_path}/{resource}"

    def _can_implicitly_extern(self, module_name: str):
        top_level_package_name = module_name.partition(".")[0]
        return top_level_package_name == "torch" or (
            top_level_package_name not in _DISALLOWED_MODULES
            and is_stdlib_module(top_level_package_name)
        )

    def dependency_graph_string(self) -> str:
        """Returns a digraph string representation of dependencies in the package.

        Returns:
            A string representation of dependencies in the package.
        """
        return self.dependency_graph.to_dot()

    def _nodes_with_action_type(
        self, action: Optional[_ModuleProviderAction]
    ) -> List[str]:
        result = []
        for name, node_dict in self.dependency_graph.nodes.items():
            node_action = node_dict.get("action", None)
            if node_action == action and "is_pickle" not in node_dict:
                result.append(name)
        result.sort()
        return result

    def externed_modules(self) -> List[str]:
        """Return all modules that are currently externed.

        Returns:
            A list containing the names of modules which will be
            externed in this package.
        """
        return self._nodes_with_action_type(_ModuleProviderAction.EXTERN)

    def interned_modules(self) -> List[str]:
        """Return all modules that are currently interned.

        Returns:
            A list containing the names of modules which will be
            interned in this package.
        """
        return self._nodes_with_action_type(_ModuleProviderAction.INTERN)

    def mocked_modules(self) -> List[str]:
        """Return all modules that are currently mocked.

        Returns:
            A list containing the names of modules which will be
            mocked in this package.
        """
        return self._nodes_with_action_type(_ModuleProviderAction.MOCK)

    def denied_modules(self) -> List[str]:
        """Return all modules that are currently denied.

        Returns:
            A list containing the names of modules which will be
            denied in this package.
        """
        return self._nodes_with_action_type(_ModuleProviderAction.DENY)

    def get_rdeps(self, module_name: str) -> List[str]:
        """Return a list of all modules which depend on the module ``module_name``.

        Returns:
            A list containing the names of modules which depend on ``module_name``.
        """
        if module_name in self.dependency_graph._pred.keys():
            return list(self.dependency_graph._pred[module_name].keys())
        else:
            return []

    def all_paths(self, src: str, dst: str) -> str:
        """Return a dot representation of the subgraph
        that has all paths from src to dst.

        Returns:
            A dot representation containing all paths from src to dst.
            (https://graphviz.org/doc/info/lang.html)
        """
        return self.dependency_graph.all_paths(src, dst)


# even though these are in the standard library, we do not allow them to be
# automatically externed since they offer a lot of system level access
_DISALLOWED_MODULES = ["sys", "io"]

_MOCK_IMPL = """\
from _mock import MockedObject
def __getattr__(attr: str):
    return MockedObject(__name__ + '.' + attr, _suppress_err=True)
"""


def _read_file(filename: str) -> str:
    with open(filename, "rb") as f:
        b = f.read()
        return b.decode("utf-8")