ai-content-maker/.venv/Lib/site-packages/thinc/shims/torchscript.py

68 lines
2.4 KiB
Python

from io import BytesIO
from typing import Any, Optional
import srsly
from ..compat import torch
from ..util import get_torch_default_device
from .pytorch import PyTorchShim
from .pytorch_grad_scaler import PyTorchGradScaler
class TorchScriptShim(PyTorchShim):
    """A Thinc shim that wraps a TorchScript module.

    model:
        The TorchScript module. A value of `None` is also possible to
        construct a shim to deserialize into.
    mixed_precision:
        Enable mixed-precision. This changes whitelisted ops to run
        in half precision for better performance and lower memory use.
    grad_scaler:
        The gradient scaler to use for mixed-precision training. If this
        argument is set to "None" and mixed precision is enabled, a gradient
        scaler with the default configuration is used.
    device:
        The PyTorch device to run the model on. When this argument is
        set to "None", the default device for the currently active Thinc
        ops is used.
    """

    def __init__(
        self,
        model: Optional["torch.jit.ScriptModule"],
        config=None,
        optimizer: Any = None,
        mixed_precision: bool = False,
        grad_scaler: Optional[PyTorchGradScaler] = None,
        device: Optional["torch.device"] = None,
    ):
        """Validate `model` and forward all arguments to `PyTorchShim`.

        Raises:
            ValueError: If `model` is neither `None` nor a
                `torch.jit.ScriptModule`.
        """
        if model is not None and not isinstance(model, torch.jit.ScriptModule):
            raise ValueError(
                "TorchScriptShim must be initialized with ScriptModule "
                "or None (for deserialization)"
            )

        super().__init__(model, config, optimizer, mixed_precision, grad_scaler, device)

    def to_bytes(self) -> bytes:
        """Serialize the shim config and the TorchScript module.

        The module is written with `torch.jit.save` into an in-memory
        buffer; the result is a msgpack payload with the keys "config"
        and "model".
        """
        filelike = BytesIO()
        torch.jit.save(self._model, filelike)
        # getvalue() returns the whole buffer regardless of the current
        # stream position, so no rewind is needed first.
        msg = {"config": self.cfg, "model": filelike.getvalue()}
        return srsly.msgpack_dumps(msg)

    def from_bytes(self, bytes_data: bytes) -> "TorchScriptShim":
        """Restore the config and TorchScript module from `to_bytes` output.

        The loaded module and the gradient scaler are moved to the device
        returned by `get_torch_default_device()`. Returns `self` so calls
        can be chained.
        """
        device = get_torch_default_device()
        msg = srsly.msgpack_loads(bytes_data)
        self.cfg = msg["config"]
        # A BytesIO built from initial bytes already starts at position 0.
        filelike = BytesIO(msg["model"])

        # As of Torch 2.0.0, loading TorchScript models directly to
        # an MPS device is not supported, so load on the CPU first and
        # move the module afterwards.
        map_location = torch.device("cpu") if device.type == "mps" else device
        self._model = torch.jit.load(filelike, map_location=map_location)
        self._model.to(device)
        self._grad_scaler.to_(device)
        return self