# mypy: ignore-errors
import copy
from typing import Any, Optional, cast

import srsly

from ..compat import mxnet as mx
from ..optimizers import Optimizer
from ..types import ArgsKwargs, FloatsXd
from ..util import (
    convert_recursive,
    get_array_module,
    make_tempfile,
    mxnet2xp,
    xp2mxnet,
)
from .shim import Shim


class MXNetShim(Shim):
    """Interface between an MXNet model and a Thinc Model. This container is
    *not* a Thinc Model subclass itself.
    """

    def __call__(self, inputs, is_train):
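        """Call the wrapped MXNet model. In training mode, delegate to
        begin_update; otherwise return the prediction together with a dummy
        backprop callback.
        """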
        if is_train:
            return self.begin_update(inputs)
        else:
            return self.predict(inputs), lambda a: ...

    def predict(self, inputs: ArgsKwargs) -> Any:
        """Pass inputs through to the underlying MXNet model, and return the
        output. No conversions are performed. The MXNet model is set into
        evaluation mode.
        """
        mx.autograd.set_training(train_mode=False)
        with mx.autograd.pause():
            outputs = self._model(*inputs.args, **inputs.kwargs)
        mx.autograd.set_training(train_mode=True)
        return outputs

    def begin_update(self, inputs: ArgsKwargs):
        """Pass the inputs through to the underlying MXNet model, keeping
        track of which items in the input are tensors requiring gradients.
        If the model returns a single value, it is converted into a one-element
        tuple. Return the outputs and a callback to backpropagate.
        """
        mx.autograd.set_training(train_mode=True)
        mx.autograd.set_recording(True)
        output = self._model(*inputs.args, **inputs.kwargs)

        def backprop(grads):
            mx.autograd.set_recording(False)
            mx.autograd.backward(*grads.args, **grads.kwargs)
            # Collect the gradients MXNet attached to the input tensors.
            return convert_recursive(
                lambda x: hasattr(x, "grad"), lambda x: x.grad, inputs
            )

        return output, backprop

    def finish_update(self, optimizer: Optimizer):
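        """Flatten all parameters and gradients of the wrapped model into two
        contiguous arrays, apply the Thinc optimizer to them in a single call,
        then write the updated values back into the MXNet parameters and clear
        the accumulated gradients.
        """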
        params = []
        grads = []
        shapes = []
        ctx = mx.current_context()
        for key, value in self._model.collect_params().items():
            grad = cast(FloatsXd, mxnet2xp(value.grad(ctx)))
            param = cast(FloatsXd, mxnet2xp(value.data(ctx)))
            params.append(param.ravel())
            grads.append(grad.ravel())
            shapes.append((param.size, param.shape))
        if not params:
            return
        xp = get_array_module(params[0])
        flat_params, flat_grads = optimizer(
            (self.id, "mxnet-shim"), xp.concatenate(params), xp.concatenate(grads)
        )
        start = 0
        for key, value in self._model.collect_params().items():
            size, shape = shapes.pop(0)
            param = flat_params[start : start + size].reshape(shape)
            value.set_data(xp2mxnet(param))
            value.zero_grad()
            start += size

    def copy(self, ctx: Optional["mx.context.Context"] = None):
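        """Create a deep copy of this shim. The copied model is initialised on
        the given MXNet context (defaulting to the current context) and the
        current weights are restored into it.
        """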
        if ctx is None:
            ctx = mx.current_context()
        model_bytes = self.to_bytes()
        copied = copy.deepcopy(self)
        copied._model.initialize(ctx=ctx)
        copied.from_bytes(model_bytes)
        return copied

    def to_device(self, device_type: str, device_id: int):
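        """Move the wrapped model to CPU or GPU by copying the shim onto the
        requested MXNet context. Note that device_id is currently unused.
        """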
        if device_type == "cpu":
            # copy() returns a new shim on the target context; keep its model.
            self._model = self.copy(mx.cpu())._model
        elif device_type == "gpu":
            self._model = self.copy(mx.gpu())._model
        else:
            msg = f"Unexpected device_type: {device_type}. Try 'cpu' or 'gpu'."
            raise ValueError(msg)

    def to_bytes(self):
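        """Serialize the shim's config and the model's parameters to a
        msgpack-encoded byte string.
        """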
        # MXNet doesn't implement save/load without a filename
        with make_tempfile("w+b") as temp:
            self._model.save_parameters(temp.name)
            temp.seek(0)
            weights_bytes = temp.read()
        msg = {"config": self.cfg, "state": weights_bytes}
        return srsly.msgpack_dumps(msg)

    def from_bytes(self, bytes_data):
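        """Restore the shim's config and the model's parameters from a byte
        string produced by to_bytes, and return the shim.
        """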
        msg = srsly.msgpack_loads(bytes_data)
        self.cfg = msg["config"]
        self._load_params(msg["state"])
        return self

    def _load_params(self, params):
        # MXNet doesn't implement save/load without a filename :(
        with make_tempfile("w+b") as temp:
            temp.write(params)
            self._model.load_parameters(temp.name, ctx=mx.current_context())
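

# Usage sketch: in practice this shim is created by thinc's MXNet wrapper
# layer rather than instantiated directly. The wrapper name below
# (thinc.api.MXNetWrapper) and the exact call pattern are assumptions about
# the public API, kept as a comment so they stay illustrative only.
#
#     from thinc.api import MXNetWrapper
#     from mxnet.gluon import nn
#
#     block = nn.Dense(10)          # any Gluon block
#     block.initialize()
#     model = MXNetWrapper(block)   # wraps the block via an MXNetShim
#     Y, backprop = model(X, is_train=True)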