import contextlib
import itertools
from io import BytesIO
from typing import Any, Callable, Dict, Optional, cast

import srsly

from ..backends import CupyOps, context_pools, get_current_ops, set_gpu_allocator
from ..compat import torch
from ..optimizers import Optimizer
from ..types import ArgsKwargs, FloatsXd
from ..util import (
    convert_recursive,
    get_torch_default_device,
    iterate_recursive,
    torch2xp,
    xp2torch,
)
from .pytorch_grad_scaler import PyTorchGradScaler
from .shim import Shim


class PyTorchShim(Shim):
    """Interface between a PyTorch model and a Thinc Model. This container is
    *not* a Thinc Model subclass itself.

    mixed_precision:
        Enable mixed-precision. This changes whitelisted ops to run in half
        precision for better performance and lower memory use.
    grad_scaler:
        The gradient scaler to use for mixed-precision training. If this
        argument is set to "None" and mixed precision is enabled, a gradient
        scaler with the default configuration is used.
    device:
        The PyTorch device to run the model on. When this argument is set to
        "None", the default device for the currently active Thinc ops is used.
    serialize_model:
        Callback that receives the wrapped PyTorch model as its argument and
        returns a "bytes" representation of the same. The representation should
        contain all the necessary information to fully deserialize the model.
    deserialize_model:
        Callback that receives the default PyTorch model (passed to the
        constructor), the serialized "bytes" representation and a PyTorch
        device. It should return a fully deserialized model on the target
        device as its result.
    """

    def __init__(
        self,
        model: Any,
        config=None,
        optimizer: Any = None,
        mixed_precision: bool = False,
        grad_scaler: Optional[PyTorchGradScaler] = None,
        device: Optional["torch.device"] = None,
        serialize_model: Optional[Callable[[Any], bytes]] = None,
        deserialize_model: Optional[Callable[[Any, bytes, "torch.device"], Any]] = None,
    ):
        super().__init__(model, config, optimizer)

        if device is None:
            device = get_torch_default_device()

        if model is not None:
            model.to(device)

        if grad_scaler is None:
            grad_scaler = PyTorchGradScaler(mixed_precision)

        grad_scaler.to_(device)

        self._grad_scaler = grad_scaler
        self._mixed_precision = mixed_precision

        self._serialize_model = (
            serialize_model
            if serialize_model is not None
            else default_serialize_torch_model
        )
        self._deserialize_model = (
            deserialize_model
            if deserialize_model is not None
            else default_deserialize_torch_model
        )

        if CupyOps.xp is not None and isinstance(get_current_ops(), CupyOps):
            pools = context_pools.get()
            if "pytorch" not in pools:
                from cupy import get_default_memory_pool

                set_gpu_allocator("pytorch")
                get_default_memory_pool().free_all_blocks()

    def __call__(self, inputs, is_train):
        if is_train:
            return self.begin_update(inputs)
        else:
            return self.predict(inputs), lambda a: ...

    @property
    def device(self):
        p = next(self._model.parameters(), None)
        if p is None:
            return get_torch_default_device()
        else:
            return p.device

    def predict(self, inputs: ArgsKwargs) -> Any:
        """Pass inputs through to the underlying PyTorch model, and return the
        output. No conversions are performed. The PyTorch model is set into
        evaluation mode.
        """
        self._model.eval()
        with torch.no_grad():
            with torch.cuda.amp.autocast(self._mixed_precision):
                outputs = self._model(*inputs.args, **inputs.kwargs)
        self._model.train()
        return outputs
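
    # Hedged usage sketch (the shapes, variable names, and direct construction
    # below are illustrative assumptions; in practice a higher-level Thinc
    # wrapper layer usually creates this shim and feeds it an ``ArgsKwargs``
    # of tensors):
    #
    #     import torch
    #     shim = PyTorchShim(torch.nn.Linear(4, 2))
    #     inputs = ArgsKwargs(args=(torch.zeros(3, 4),), kwargs={})
    #     Y = shim(inputs, is_train=False)[0]        # routes to predict()
    #     Y, backprop = shim(inputs, is_train=True)  # routes to begin_update()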

    def begin_update(self, inputs: ArgsKwargs):
        """Pass the inputs through to the underlying PyTorch model, keeping
        track of which items in the input are tensors requiring gradients.
        If the model returns a single value, it is converted into a one-element
        tuple. Return the outputs and a callback to backpropagate.
        """
        self._model.train()

        # Note: mixed-precision autocast must not be applied to backprop.
        with torch.cuda.amp.autocast(self._mixed_precision):
            output = self._model(*inputs.args, **inputs.kwargs)

        def backprop(grads):
            # Normally, gradient scaling is applied to the loss of a model.
            # However, since regular thinc layers do not use mixed-precision,
            # we perform scaling locally in this shim. Scaling the loss by a
            # factor scales the gradients by the same factor (see the chain
            # rule). Therefore, we scale the gradients backprop'ed through the
            # succeeding layer to get the same effect as loss scaling.
            grads.kwargs["grad_tensors"] = self._grad_scaler.scale(
                grads.kwargs["grad_tensors"], inplace=True
            )

            torch.autograd.backward(*grads.args, **grads.kwargs)

            # Unscale gradients and check for overflows during backprop.
            grad_tensors = []
            for torch_data in itertools.chain(
                self._model.parameters(),
                iterate_recursive(lambda x: hasattr(x, "grad"), inputs),
            ):
                if torch_data.grad is not None:
                    grad_tensors.append(torch_data.grad)
            found_inf = self._grad_scaler.unscale(grad_tensors)

            # If there was an over/underflow, return zeroed-out gradients.
            if found_inf:
                grad_get = lambda x: x.grad.zero_() if x.grad is not None else x.grad
            else:
                grad_get = lambda x: x.grad

            return convert_recursive(lambda x: hasattr(x, "grad"), grad_get, inputs)

        return output, backprop

    def finish_update(self, optimizer: Optimizer):
        for name, torch_data in self._model.named_parameters():
            if torch_data.grad is not None:
                if (
                    not self._grad_scaler.found_inf
                ):  # Skip weight update if any gradient overflowed.
                    param, grad = optimizer(
                        (self.id, name),
                        cast(FloatsXd, torch2xp(torch_data.data)),
                        cast(FloatsXd, torch2xp(torch_data.grad)),
                    )
                    torch_data.data = xp2torch(
                        param, requires_grad=True, device=torch_data.device
                    )
                torch_data.grad.zero_()

        self._grad_scaler.update()

    @contextlib.contextmanager
    def use_params(self, params):
        key_prefix = f"pytorch_{self.id}_"
        state_dict = {}
        for k, v in params.items():
            if hasattr(k, "startswith") and k.startswith(key_prefix):
                state_dict[k.replace(key_prefix, "")] = xp2torch(v, device=self.device)
        if state_dict:
            backup = {k: v.clone() for k, v in self._model.state_dict().items()}
            self._model.load_state_dict(state_dict)
            yield
            self._model.load_state_dict(backup)
        else:
            yield

    def to_device(self, device_type: str, device_id: int):  # pragma: no cover
        if device_type == "cpu":
            self._model.cpu()
        elif device_type == "gpu":
            self._model.cuda(device_id)
        else:
            msg = f"Invalid device_type: {device_type}. Try 'cpu' or 'gpu'"
            raise ValueError(msg)

    def to_bytes(self):
        model_bytes = self._serialize_model(self._model)
        msg = {"config": self.cfg, "state": model_bytes}
        return srsly.msgpack_dumps(msg)

    def from_bytes(self, bytes_data):
        device = get_torch_default_device()
        msg = srsly.msgpack_loads(bytes_data)
        self.cfg = msg["config"]
        self._model = self._deserialize_model(self._model, msg["state"], device)
        self._grad_scaler.to_(device)
        return self


def default_serialize_torch_model(model: Any) -> bytes:
    """Serializes the parameters of the wrapped PyTorch model to bytes.

    model:
        Wrapped PyTorch model.

    Returns:
        A `bytes` object that encapsulates the serialized model parameters.
    """
    filelike = BytesIO()
    torch.save(model.state_dict(), filelike)
    filelike.seek(0)
    return filelike.getvalue()


def default_deserialize_torch_model(
    model: Any, state_bytes: bytes, device: "torch.device"
) -> Any:
    """Deserializes the parameters of the wrapped PyTorch model and moves it
    to the specified device.

    model:
        Wrapped PyTorch model.
    state_bytes:
        Serialized parameters as a byte stream.
    device:
        PyTorch device to which the model is bound.

    Returns:
        The deserialized model.
    """
    filelike = BytesIO(state_bytes)
    filelike.seek(0)
    model.load_state_dict(torch.load(filelike, map_location=device))
    model.to(device)
    return model