ai-content-maker/.venv/Lib/site-packages/thinc/layers/torchscriptwrapper.py

94 lines
3.2 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
from typing import Any, Callable, Optional
from ..compat import torch
from ..model import Model
from ..shims import PyTorchGradScaler, PyTorchShim, TorchScriptShim
from .pytorchwrapper import (
convert_pytorch_default_inputs,
convert_pytorch_default_outputs,
forward,
)
def TorchScriptWrapper_v1(
    torchscript_model: Optional["torch.jit.ScriptModule"] = None,
    convert_inputs: Optional[Callable] = None,
    convert_outputs: Optional[Callable] = None,
    mixed_precision: bool = False,
    grad_scaler: Optional[PyTorchGradScaler] = None,
    device: Optional["torch.device"] = None,
) -> Model[Any, Any]:
    """Wrap a TorchScript model so that it exposes the same API as Thinc models.

    torchscript_model:
        The TorchScript module to wrap. Passing `None` is also allowed; it
        constructs a shim that a serialized model can be deserialized into.
    convert_inputs:
        Function that converts the inputs and gradients passed to the model
        into Torch tensors. When `None`, `convert_pytorch_default_inputs`
        is used.
    convert_outputs:
        Function that converts model outputs and gradients from Torch tensors
        to Thinc arrays. When `None`, `convert_pytorch_default_outputs` is
        used.
    mixed_precision:
        Enable mixed precision. This changes whitelisted ops to run in half
        precision for better performance and lower memory use.
    grad_scaler:
        The gradient scaler to use for mixed-precision training. If this is
        `None` and mixed precision is enabled, a gradient scaler with the
        default configuration is used.
    device:
        The PyTorch device to run the model on. If this is `None`, the
        default device for the currently active Thinc ops is used.

    RETURNS:
        A Thinc `Model` named "pytorch_script" that runs the PyTorch-wrapper
        `forward` through a single `TorchScriptShim`, with unset `nI`/`nO`
        dimensions.
    """
    # Fall back to the PyTorch wrapper's default converters when the caller
    # does not supply their own.
    in_converter = (
        convert_pytorch_default_inputs if convert_inputs is None else convert_inputs
    )
    out_converter = (
        convert_pytorch_default_outputs if convert_outputs is None else convert_outputs
    )

    # The shim owns the TorchScript module plus its precision/device settings.
    script_shim = TorchScriptShim(
        model=torchscript_model,
        mixed_precision=mixed_precision,
        grad_scaler=grad_scaler,
        device=device,
    )
    layer_attrs = {"convert_inputs": in_converter, "convert_outputs": out_converter}

    return Model(
        "pytorch_script",
        forward,
        attrs=layer_attrs,
        shims=[script_shim],
        dims={"nI": None, "nO": None},
    )
def pytorch_to_torchscript_wrapper(model: Model) -> Model[Any, Any]:
    """Convert a PyTorch wrapper to a TorchScript wrapper.

    The PyTorch `Module` held by the wrapper's first shim is compiled with
    `torch.jit.script`. The compiled `ScriptModule` is then wrapped with
    `TorchScriptWrapper_v1`, reusing the original wrapper's input/output
    converters, mixed-precision flag, gradient scaler and device.

    model:
        A Thinc model built by the PyTorch wrapper. Its first shim must be a
        `PyTorchShim` that wraps a `torch.nn.Module`. Its attrs must contain
        "convert_inputs" and "convert_outputs".

    RETURNS:
        A new Thinc model backed by a `TorchScriptShim`.

    RAISES:
        ValueError: if the model's first shim is not a `PyTorchShim`
            (including models that have no shims at all), or if that shim
            does not wrap a `torch.nn.Module`.
    """
    # A model without shims (e.g. a plain Thinc layer) is not a PyTorch
    # wrapper either. Report it with the same ValueError instead of letting
    # `model.shims[0]` fail with an opaque IndexError.
    shim = model.shims[0] if model.shims else None
    if not isinstance(shim, PyTorchShim):
        raise ValueError("Expected PyTorchShim when converting a PyTorch wrapper")

    convert_inputs = model.attrs["convert_inputs"]
    convert_outputs = model.attrs["convert_outputs"]

    pytorch_model = shim._model
    if not isinstance(pytorch_model, torch.nn.Module):
        raise ValueError("PyTorchShim does not wrap a PyTorch module")
    torchscript_model = torch.jit.script(pytorch_model)

    # Carry over the shim's runtime configuration so that the TorchScript
    # wrapper behaves like the PyTorch wrapper it replaces.
    return TorchScriptWrapper_v1(
        torchscript_model,
        convert_inputs=convert_inputs,
        convert_outputs=convert_outputs,
        mixed_precision=shim._mixed_precision,
        grad_scaler=shim._grad_scaler,
        device=shim.device,
    )