119 lines
4.5 KiB
Python
119 lines
4.5 KiB
Python
|
"""isort:skip_file"""
|
||
|
from pickle import ( # type: ignore[attr-defined]
|
||
|
_compat_pickle,
|
||
|
_extension_registry,
|
||
|
_getattribute,
|
||
|
_Pickler,
|
||
|
EXT1,
|
||
|
EXT2,
|
||
|
EXT4,
|
||
|
GLOBAL,
|
||
|
Pickler,
|
||
|
PicklingError,
|
||
|
STACK_GLOBAL,
|
||
|
)
|
||
|
from struct import pack
|
||
|
from types import FunctionType
|
||
|
|
||
|
from .importer import Importer, ObjMismatchError, ObjNotFoundError, sys_importer
|
||
|
|
||
|
|
||
|
class PackagePickler(_Pickler):
    """Package-aware pickler.

    This behaves the same as a normal pickler, except it uses an `Importer`
    to find objects and modules to save.
    """

    def __init__(self, importer: Importer, *args, **kwargs):
        """Create a pickler that resolves globals through ``importer``.

        Args:
            importer: The `Importer` whose ``get_name`` / ``import_module``
                methods are used by ``save_global`` to locate objects.
            *args, **kwargs: Forwarded unchanged to ``pickle._Pickler``
                (e.g. the output file and ``protocol``).
        """
        self.importer = importer
        super().__init__(*args, **kwargs)

        # Make sure the dispatch table copied from _Pickler is up-to-date.
        # Previous issues have been encountered where a library (e.g. dill)
        # mutates _Pickler.dispatch: PackagePickler made a copy when this
        # module was imported, then the offending library removed its
        # dispatch entries, leaving PackagePickler with a stale dispatch table
        # that could cause unwanted behavior. Copying here, per instance,
        # picks up the table as it is at construction time.
        self.dispatch = _Pickler.dispatch.copy()  # type: ignore[misc]
        # Save plain functions through the importer-aware save_global below.
        self.dispatch[FunctionType] = PackagePickler.save_global  # type: ignore[assignment]

    def save_global(self, obj, name=None):
        """Save ``obj`` by reference, as a (module name, qualified name) pair.

        Unfortunately the pickler code is factored in a way that forces us to
        copy/paste this function from ``pickle._Pickler.save_global``. The
        only intentional change is marked CHANGED below: the module and name
        are resolved through ``self.importer`` instead of ``__import__``.

        Args:
            obj: The object to save by reference.
            name: Optional name hint, passed through to
                ``self.importer.get_name``.

        Raises:
            PicklingError: if ``self.importer.get_name`` raises
                ``ObjNotFoundError`` or ``ObjMismatchError``, or if the
                module/name cannot be encoded as ASCII for protocols < 3.
        """
        write = self.write  # type: ignore[attr-defined]

        # CHANGED: import module from module environment instead of __import__
        try:
            module_name, name = self.importer.get_name(obj, name)
        except (ObjNotFoundError, ObjMismatchError) as err:
            raise PicklingError(f"Can't pickle {obj}: {str(err)}") from None

        module = self.importer.import_module(module_name)
        _, parent = _getattribute(module, name)
        # END CHANGED

        if self.proto >= 2:  # type: ignore[attr-defined]
            # Globals present in the extension registry are written as a
            # compact integer code (1, 2 or 4 bytes) instead of their names.
            code = _extension_registry.get((module_name, name))
            if code:
                assert code > 0
                if code <= 0xFF:
                    write(EXT1 + pack("<B", code))
                elif code <= 0xFFFF:
                    write(EXT2 + pack("<H", code))
                else:
                    write(EXT4 + pack("<i", code))
                return

        lastname = name.rpartition(".")[2]
        if parent is module:
            name = lastname
        # Non-ASCII identifiers are supported only with protocols >= 3.
        if self.proto >= 4:  # type: ignore[attr-defined]
            self.save(module_name)  # type: ignore[attr-defined]
            self.save(name)  # type: ignore[attr-defined]
            write(STACK_GLOBAL)
        elif parent is not module:
            # ``obj`` lives inside another object (dotted qualified name);
            # older protocols cannot name it directly, so rebuild it on load
            # as getattr(parent, lastname).
            self.save_reduce(getattr, (parent, lastname))  # type: ignore[attr-defined]
        elif self.proto >= 3:  # type: ignore[attr-defined]
            write(
                GLOBAL
                + bytes(module_name, "utf-8")
                + b"\n"
                + bytes(name, "utf-8")
                + b"\n"
            )
        else:
            if self.fix_imports:  # type: ignore[attr-defined]
                # Apply the reverse (Python 3 -> Python 2) name/module
                # mappings from _compat_pickle, as pickle does for fix_imports.
                r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING
                r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING
                if (module_name, name) in r_name_mapping:
                    module_name, name = r_name_mapping[(module_name, name)]
                elif module_name in r_import_mapping:
                    module_name = r_import_mapping[module_name]
            try:
                write(
                    GLOBAL
                    + bytes(module_name, "ascii")
                    + b"\n"
                    + bytes(name, "ascii")
                    + b"\n"
                )
            except UnicodeEncodeError:
                # Report the module *name* being encoded, not the module
                # object, so the message reads "pkg.mod.name" rather than
                # embedding a module repr.
                raise PicklingError(
                    "can't pickle global identifier '%s.%s' using "
                    "pickle protocol %i" % (module_name, name, self.proto)  # type: ignore[attr-defined]
                ) from None

        self.memoize(obj)  # type: ignore[attr-defined]
|
||
|
|
||
|
|
||
|
def create_pickler(data_buf, importer, protocol=4):
    """Build a pickler that writes to ``data_buf`` and resolves globals via ``importer``.

    Args:
        data_buf: Destination passed as the pickler's ``file`` argument
            (presumably a writable binary file-like object).
        importer: Import environment used to look up globals. If it is
            ``sys_importer`` the stock ``pickle.Pickler`` is returned;
            otherwise a ``PackagePickler`` bound to ``importer``.
        protocol: Pickle protocol version, 4 by default.

    Returns:
        A ``PackagePickler`` for custom importers, else a ``Pickler``.
    """
    if importer is not sys_importer:
        # A custom import environment needs the Python-level subclass, whose
        # save_global resolves names through the supplied importer.
        return PackagePickler(importer, data_buf, protocol=protocol)

    # With the normal import library system we can use the C implementation
    # of pickle, which is faster.
    return Pickler(data_buf, protocol=protocol)
|