ai-content-maker/.venv/Lib/site-packages/torch/package/_package_pickler.py

119 lines
4.5 KiB
Python

"""isort:skip_file"""
from pickle import ( # type: ignore[attr-defined]
_compat_pickle,
_extension_registry,
_getattribute,
_Pickler,
EXT1,
EXT2,
EXT4,
GLOBAL,
Pickler,
PicklingError,
STACK_GLOBAL,
)
from struct import pack
from types import FunctionType
from .importer import Importer, ObjMismatchError, ObjNotFoundError, sys_importer
class PackagePickler(_Pickler):
    """Package-aware pickler.

    This behaves the same as a normal pickler, except it uses an `Importer`
    to find objects and modules to save.
    """

    def __init__(self, importer: Importer, *args, **kwargs):
        """Create a pickler that resolves globals through ``importer``.

        Args:
            importer: Used by ``save_global`` to look up the module name and
                qualified name of an object and to import that module.
            *args: Forwarded unchanged to ``_Pickler.__init__``.
            **kwargs: Forwarded unchanged to ``_Pickler.__init__``.
        """
        self.importer = importer
        super().__init__(*args, **kwargs)

        # Make sure the dispatch table copied from _Pickler is up-to-date.
        # Previous issues have been encountered where a library (e.g. dill)
        # mutates _Pickler.dispatch, PackagePickler makes a copy when that
        # library is imported, then the offending library removes its dispatch
        # entries, leaving PackagePickler with a stale dispatch table that may
        # cause unwanted behavior. Copying per instance avoids that.
        self.dispatch = _Pickler.dispatch.copy()  # type: ignore[misc]
        # Route plain functions through the importer-aware save_global below.
        self.dispatch[FunctionType] = PackagePickler.save_global  # type: ignore[assignment]

    def save_global(self, obj, name=None):
        """Pickle ``obj`` as a reference to a global found via ``self.importer``.

        Args:
            obj: The object to write as a global reference.
            name: Optional name passed on to ``self.importer.get_name``.

        Raises:
            PicklingError: If the importer reports ``ObjNotFoundError`` or
                ``ObjMismatchError`` for ``obj``, or if the identifier cannot
                be encoded as ASCII when pickling with a protocol below 3.
        """
        # unfortunately the pickler code is factored in a way that
        # forces us to copy/paste this function. The only change is marked
        # CHANGED below.
        write = self.write  # type: ignore[attr-defined]

        # CHANGED: import module from module environment instead of __import__
        try:
            module_name, name = self.importer.get_name(obj, name)
        except (ObjNotFoundError, ObjMismatchError) as err:
            raise PicklingError(f"Can't pickle {obj}: {err}") from None

        module = self.importer.import_module(module_name)
        _, parent = _getattribute(module, name)
        # END CHANGED

        if self.proto >= 2:  # type: ignore[attr-defined]
            # A registered extension code replaces the (module, name) pair
            # with a compact EXT opcode sized to fit the code.
            code = _extension_registry.get((module_name, name))
            if code:
                assert code > 0
                if code <= 0xFF:
                    write(EXT1 + pack("<B", code))
                elif code <= 0xFFFF:
                    write(EXT2 + pack("<H", code))
                else:
                    write(EXT4 + pack("<i", code))
                return

        lastname = name.rpartition(".")[2]
        if parent is module:
            name = lastname

        # Non-ASCII identifiers are supported only with protocols >= 3.
        if self.proto >= 4:  # type: ignore[attr-defined]
            self.save(module_name)  # type: ignore[attr-defined]
            self.save(name)  # type: ignore[attr-defined]
            write(STACK_GLOBAL)
        elif parent is not module:
            # Nested attribute: emit getattr(parent, lastname) as a reduce call.
            self.save_reduce(getattr, (parent, lastname))  # type: ignore[attr-defined]
        elif self.proto >= 3:  # type: ignore[attr-defined]
            write(
                GLOBAL
                + bytes(module_name, "utf-8")
                + b"\n"
                + bytes(name, "utf-8")
                + b"\n"
            )
        else:
            if self.fix_imports:  # type: ignore[attr-defined]
                # Apply the reverse name/module mappings from _compat_pickle.
                r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING
                r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING
                if (module_name, name) in r_name_mapping:
                    module_name, name = r_name_mapping[(module_name, name)]
                elif module_name in r_import_mapping:
                    module_name = r_import_mapping[module_name]
            try:
                write(
                    GLOBAL
                    + bytes(module_name, "ascii")
                    + b"\n"
                    + bytes(name, "ascii")
                    + b"\n"
                )
            except UnicodeEncodeError:
                # Report the module *name*; formatting the module object here
                # would embed its repr ("<module ...>") in the identifier.
                raise PicklingError(
                    "can't pickle global identifier '%s.%s' using "
                    "pickle protocol %i" % (module_name, name, self.proto)  # type: ignore[attr-defined]
                ) from None

        self.memoize(obj)  # type: ignore[attr-defined]
def create_pickler(data_buf, importer, protocol=4):
    """Return the pickler that should write to ``data_buf``.

    When ``importer`` is ``sys_importer`` the C implementation ``Pickler`` is
    returned; any other importer gets a ``PackagePickler`` bound to it. The
    ``protocol`` is passed through to whichever pickler is created.
    """
    if importer is not sys_importer:
        # A non-default importer needs the package-aware pickler so that
        # globals are looked up through it.
        return PackagePickler(importer, data_buf, protocol=protocol)

    # The normal import library system is in use, so the faster C
    # implementation of pickle can be used directly.
    return Pickler(data_buf, protocol=protocol)