import contextlib
import itertools
from io import BytesIO
from typing import Any, Callable, Dict, Optional, cast

import srsly

from ..backends import CupyOps, context_pools, get_current_ops, set_gpu_allocator
from ..compat import torch
from ..optimizers import Optimizer
from ..types import ArgsKwargs, FloatsXd
from ..util import (
    convert_recursive,
    get_torch_default_device,
    iterate_recursive,
    torch2xp,
    xp2torch,
)
from .pytorch_grad_scaler import PyTorchGradScaler
from .shim import Shim


class PyTorchShim(Shim):
    """Interface between a PyTorch model and a Thinc Model. This container is
    *not* a Thinc Model subclass itself.

    mixed_precision:
        Enable mixed-precision. This changes whitelisted ops to run
        in half precision for better performance and lower memory use.
    grad_scaler:
        The gradient scaler to use for mixed-precision training. If this
        argument is set to "None" and mixed precision is enabled, a gradient
        scaler with the default configuration is used.
    device:
        The PyTorch device to run the model on. When this argument is
        set to "None", the default device for the currently active Thinc
        ops is used.
    serialize_model:
        Callback that receives the wrapped PyTorch model as its argument and
        returns a "bytes" representation of it. The representation should
        contain all the necessary information to fully deserialize the model.
    deserialize_model:
        Callback that receives the default PyTorch model (passed to the
        constructor), the serialized "bytes" representation and a PyTorch
        device. It should return a fully deserialized model on the target
        device as its result.
    """

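    # A minimal usage sketch (hypothetical tensors and shapes, for illustration
    # only; in practice the shim is usually created by Thinc's PyTorch wrapper
    # layer rather than instantiated directly):
    #
    #     shim = PyTorchShim(torch.nn.Linear(4, 2))
    #     Y = shim.predict(ArgsKwargs(args=(torch.zeros(8, 4),), kwargs={}))
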
    def __init__(
        self,
        model: Any,
        config=None,
        optimizer: Any = None,
        mixed_precision: bool = False,
        grad_scaler: Optional[PyTorchGradScaler] = None,
        device: Optional["torch.device"] = None,
        serialize_model: Optional[Callable[[Any], bytes]] = None,
        deserialize_model: Optional[Callable[[Any, bytes, "torch.device"], Any]] = None,
    ):
        super().__init__(model, config, optimizer)

        if device is None:
            device = get_torch_default_device()
        if model is not None:
            model.to(device)

        if grad_scaler is None:
            grad_scaler = PyTorchGradScaler(mixed_precision)

        grad_scaler.to_(device)

        self._grad_scaler = grad_scaler
        self._mixed_precision = mixed_precision

        self._serialize_model = (
            serialize_model
            if serialize_model is not None
            else default_serialize_torch_model
        )
        self._deserialize_model = (
            deserialize_model
            if deserialize_model is not None
            else default_deserialize_torch_model
        )

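        # If Thinc is running on GPU through CuPy, route allocations through
        # PyTorch's allocator so both libraries draw from a single memory pool,
        # and release any blocks CuPy has already cached.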
        if CupyOps.xp is not None and isinstance(get_current_ops(), CupyOps):
            pools = context_pools.get()
            if "pytorch" not in pools:
                from cupy import get_default_memory_pool

                set_gpu_allocator("pytorch")
                get_default_memory_pool().free_all_blocks()

    def __call__(self, inputs, is_train):
        if is_train:
            return self.begin_update(inputs)
        else:
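            # For inference, pair the prediction with a no-op backprop callback.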
            return self.predict(inputs), lambda a: ...

    @property
    def device(self):
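        # Infer the device from the model's first parameter; parameterless
        # modules fall back to the default device of the active Thinc ops.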
        p = next(self._model.parameters(), None)
        if p is None:
            return get_torch_default_device()
        else:
            return p.device

    def predict(self, inputs: ArgsKwargs) -> Any:
        """Pass inputs through to the underlying PyTorch model, and return the
        output. No conversions are performed. The PyTorch model is set into
        evaluation mode.
        """
        self._model.eval()
        with torch.no_grad():
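            # autocast changes op precision only when mixed precision is
            # enabled; otherwise it is effectively a no-op.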
            with torch.cuda.amp.autocast(self._mixed_precision):
                outputs = self._model(*inputs.args, **inputs.kwargs)
        self._model.train()
        return outputs

    def begin_update(self, inputs: ArgsKwargs):
        """Pass the inputs through to the underlying PyTorch model, keeping
        track of which items in the input are tensors requiring gradients.
        If the model returns a single value, it is converted into a one-element tuple.
        Return the outputs and a callback to backpropagate.
        """
        self._model.train()

        # Note: mixed-precision autocast must not be applied to backprop.
        with torch.cuda.amp.autocast(self._mixed_precision):
            output = self._model(*inputs.args, **inputs.kwargs)

        def backprop(grads):
            # Normally, gradient scaling is applied to the loss of a model. However,
            # since regular thinc layers do not use mixed-precision, we perform scaling
            # locally in this shim. Scaling the loss by a factor scales the gradients
            # by the same factor (see the chain rule). Therefore, we scale the gradients
            # backprop'ed through the succeeding layer to get the same effect as loss
            # scaling.
            grads.kwargs["grad_tensors"] = self._grad_scaler.scale(
                grads.kwargs["grad_tensors"], inplace=True
            )

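            # grads.args holds the output tensors and grads.kwargs["grad_tensors"]
            # their incoming gradients, matching torch.autograd.backward's
            # signature.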
            torch.autograd.backward(*grads.args, **grads.kwargs)

            # Unscale gradients and check for overflows during backprop.
            grad_tensors = []
            for torch_data in itertools.chain(
                self._model.parameters(),
                iterate_recursive(lambda x: hasattr(x, "grad"), inputs),
            ):
                if torch_data.grad is not None:
                    grad_tensors.append(torch_data.grad)
            found_inf = self._grad_scaler.unscale(grad_tensors)

            # If there was an over/underflow, return zeroed-out gradients.
            if found_inf:
                grad_get = lambda x: x.grad.zero_() if x.grad is not None else x.grad
            else:
                grad_get = lambda x: x.grad

            return convert_recursive(lambda x: hasattr(x, "grad"), grad_get, inputs)

        return output, backprop

    def finish_update(self, optimizer: Optimizer):
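        # Hand each parameter and its gradient to the Thinc optimizer as arrays,
        # then copy the updated values back into the PyTorch parameter tensors.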
        for name, torch_data in self._model.named_parameters():
            if torch_data.grad is not None:
                if (
                    not self._grad_scaler.found_inf
                ):  # Skip weight update if any gradient overflowed.
                    param, grad = optimizer(
                        (self.id, name),
                        cast(FloatsXd, torch2xp(torch_data.data)),
                        cast(FloatsXd, torch2xp(torch_data.grad)),
                    )
                    torch_data.data = xp2torch(
                        param, requires_grad=True, device=torch_data.device
                    )
                torch_data.grad.zero_()

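        # Let the gradient scaler adjust its scale factor for the next step,
        # e.g. backing off after an overflow was found.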
        self._grad_scaler.update()

    @contextlib.contextmanager
    def use_params(self, params):
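        # Temporarily swap in matching parameters (e.g. parameter averages) and
        # restore the original state_dict when the context exits.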
        key_prefix = f"pytorch_{self.id}_"
        state_dict = {}
        for k, v in params.items():
            if hasattr(k, "startswith") and k.startswith(key_prefix):
                state_dict[k.replace(key_prefix, "")] = xp2torch(v, device=self.device)
        if state_dict:
            backup = {k: v.clone() for k, v in self._model.state_dict().items()}
            self._model.load_state_dict(state_dict)
            yield
            self._model.load_state_dict(backup)
        else:
            yield

    def to_device(self, device_type: str, device_id: int):  # pragma: no cover
        if device_type == "cpu":
            self._model.cpu()
        elif device_type == "gpu":
            self._model.cuda(device_id)
        else:
            msg = f"Invalid device_type: {device_type}. Try 'cpu' or 'gpu'"
            raise ValueError(msg)

    def to_bytes(self):
        model_bytes = self._serialize_model(self._model)
        msg = {"config": self.cfg, "state": model_bytes}
        return srsly.msgpack_dumps(msg)

    def from_bytes(self, bytes_data):
        device = get_torch_default_device()
        msg = srsly.msgpack_loads(bytes_data)
        self.cfg = msg["config"]
        self._model = self._deserialize_model(self._model, msg["state"], device)
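        # Keep the gradient scaler's state on the same device as the
        # deserialized model.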
        self._grad_scaler.to_(device)
        return self


def default_serialize_torch_model(model: Any) -> bytes:
    """Serializes the parameters of the wrapped PyTorch model to bytes.

    model:
        Wrapped PyTorch model.

    Returns:
        A `bytes` object that encapsulates the serialized model parameters.
    """
    filelike = BytesIO()
    torch.save(model.state_dict(), filelike)
    filelike.seek(0)
    return filelike.getvalue()


def default_deserialize_torch_model(
    model: Any, state_bytes: bytes, device: "torch.device"
) -> Any:
    """Deserializes the parameters of the wrapped PyTorch model and
    moves it to the specified device.

    model:
        Wrapped PyTorch model.
    state_bytes:
        Serialized parameters as a byte stream.
    device:
        PyTorch device to which the model is bound.

    Returns:
        The deserialized model.
    """
    filelike = BytesIO(state_bytes)
    filelike.seek(0)
    model.load_state_dict(torch.load(filelike, map_location=device))
    model.to(device)
    return model