# mypy: ignore-errors
from typing import Any, Callable, Dict, Optional, Tuple, Type, TypeVar

import srsly

from ..compat import tensorflow as tf
from ..model import Model
from ..shims import TensorFlowShim, keras_model_fns, maybe_handshake_model
from ..types import ArgsKwargs, ArrayXd
from ..util import (
    assert_tensorflow_installed,
    convert_recursive,
    is_tensorflow_array,
    is_xp_array,
    tensorflow2xp,
    xp2tensorflow,
)

InT = TypeVar("InT")
OutT = TypeVar("OutT")
InFunc = TypeVar("InFunc")
XType = TypeVar("XType", bound=ArrayXd)
YType = TypeVar("YType", bound=ArrayXd)


def keras_subclass(
    name: str,
    X: XType,
    Y: YType,
    input_shape: Tuple[int, ...],
    compile_args: Optional[Dict[str, Any]] = None,
) -> Callable[[InFunc], InFunc]:
    """Decorate a custom keras subclassed model with enough information to
    serialize and deserialize it reliably in the face of the many restrictions
    on keras subclassed models.

    name (str): The unique namespace string to use to represent this model class.
    X (Any): A sample X input for performing a forward pass on the network.
    Y (Any): A sample Y input for performing a backward pass on the network.
    input_shape (Tuple[int, ...]): A set of input shapes for building the network.
    compile_args (Dict[str, Any]): Arguments to pass directly to the keras
        `model.compile` call.

    RETURNS (Callable): The decorated class.
    """
    compile_defaults = {"optimizer": "adam", "loss": "mse"}
    if compile_args is None:
        compile_args = compile_defaults
    else:
        compile_args = {**compile_defaults, **compile_args}

    def call_fn(clazz):
        # Expose the decorator arguments as read-only properties on the class.
        clazz.catalogue_name = property(lambda inst: name)
        clazz.eg_shape = property(lambda inst: input_shape)
        clazz.eg_compile = property(lambda inst: compile_args)
        clazz.eg_x = property(lambda inst: X)
        clazz.eg_y = property(lambda inst: Y)

        # Register a factory under `name` so the class can be recreated on load.
        @keras_model_fns(name)
        def create_component(*call_args, **call_kwargs):
            return clazz(*call_args, **call_kwargs)

        # Capture construction args and store them on the instance
        wrapped_init = clazz.__init__

        def __init__(self, *args, **kwargs):
            wrapped_init(self, *args, **kwargs)
            try:
                srsly.json_dumps(args)
                srsly.json_dumps(kwargs)
            except BaseException as _err:
                raise ValueError(
                    "In order to serialize Keras Subclass models, the constructor "
                    "arguments must be serializable. This allows thinc to recreate "
                    "the code-based model with the same configuration.\n"
                    f"The encountered error is: {_err}"
                )
            self.eg_args = ArgsKwargs(args, kwargs)

        clazz.__init__ = __init__
        return clazz

    return call_fn


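# A minimal usage sketch (illustrative only; the namespace string, layer sizes,
# and sample arrays below are assumptions for the example). The decorator
# records the sample data and input shape so the subclassed model can be
# rebuilt when it is deserialized:
#
#     import numpy
#     import tensorflow as tf
#
#     @keras_subclass(
#         "MyKerasSubclass.v1",
#         X=numpy.zeros((1, 4), dtype="f"),
#         Y=numpy.zeros((1, 2), dtype="f"),
#         input_shape=(4,),
#     )
#     class MyKerasSubclass(tf.keras.Model):
#         def __init__(self, width: int = 8):
#             super().__init__()
#             self.hidden = tf.keras.layers.Dense(width, activation="relu")
#             self.out = tf.keras.layers.Dense(2)
#
#         def call(self, inputs, training=None):
#             return self.out(self.hidden(inputs))
#
#     wrapped = TensorFlowWrapper(MyKerasSubclass())

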
def TensorFlowWrapper(
    tensorflow_model: Any,
    convert_inputs: Optional[Callable] = None,
    convert_outputs: Optional[Callable] = None,
    optimizer: Optional[Any] = None,
    model_class: Type[Model] = Model,
    model_name: str = "tensorflow",
) -> Model[InT, OutT]:
    """Wrap a TensorFlow model, so that it has the same API as Thinc models.
    To optimize the model, you'll need to create a TensorFlow optimizer and call
    optimizer.apply_gradients after each batch.
    """
    assert_tensorflow_installed()
    if not isinstance(tensorflow_model, tf.keras.models.Model):
        err = f"Expected tf.keras.models.Model, got: {type(tensorflow_model)}"
        raise ValueError(err)
    tensorflow_model = maybe_handshake_model(tensorflow_model)
    if convert_inputs is None:
        convert_inputs = _convert_inputs
    if convert_outputs is None:
        convert_outputs = _convert_outputs
    return model_class(
        model_name,
        forward,
        shims=[TensorFlowShim(tensorflow_model, optimizer=optimizer)],
        attrs={"convert_inputs": convert_inputs, "convert_outputs": convert_outputs},
    )


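# A minimal usage sketch (illustrative only; the layer sizes and optimizer
# choice are assumptions for the example): wrap a Keras model so it exposes
# the Thinc Model API. Passing a TensorFlow optimizer is optional.
#
#     import tensorflow as tf
#
#     keras_model = tf.keras.Sequential(
#         [
#             tf.keras.layers.Dense(12, activation="relu", input_shape=(784,)),
#             tf.keras.layers.Dense(10, activation="softmax"),
#         ]
#     )
#     model = TensorFlowWrapper(
#         keras_model, optimizer=tf.keras.optimizers.Adam(learning_rate=0.001)
#     )

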
def forward(model: Model[InT, OutT], X: InT, is_train: bool) -> Tuple[OutT, Callable]:
    """Return the output of the wrapped TensorFlow model for the given input,
    along with a callback to handle the backward pass.
    """
    convert_inputs = model.attrs["convert_inputs"]
    convert_outputs = model.attrs["convert_outputs"]
    tensorflow_model = model.shims[0]
    X_tensorflow, get_dX = convert_inputs(model, X, is_train)
    if is_train:
        Y_tensorflow, tensorflow_backprop = tensorflow_model(X_tensorflow, is_train)
    else:
        Y_tensorflow = tensorflow_model(X_tensorflow, is_train)
    Y, get_dY_tensorflow = convert_outputs(model, Y_tensorflow, is_train)

    def backprop(dY: OutT) -> InT:
        dY_tensorflow = get_dY_tensorflow(dY)
        dX_tensorflow = tensorflow_backprop(dY_tensorflow)
        return get_dX(dX_tensorflow)

    return Y, backprop


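# Sketch of the round trip this forward pass implements (assumed sample data;
# `wrapped` is a model produced by TensorFlowWrapper above):
#
#     Y, backprop = wrapped.begin_update(X)  # forward pass through Keras
#     dX = backprop(dY)                      # backward pass via the callback

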
# Default conversion functions
#
# These are pretty much the same as the PyTorch ones, but we leave the
# duplication: the abstraction could get messy and might need to be undone,
# as there can always be framework-specific differences.


def _convert_inputs(model, X, is_train):
    # Convert xp arrays in X to TensorFlow tensors, and return a callback that
    # maps TensorFlow gradients back into the same container type as X.
    xp2tensorflow_ = lambda x: xp2tensorflow(x, requires_grad=is_train)
    converted = convert_recursive(is_xp_array, xp2tensorflow_, X)
    if isinstance(converted, ArgsKwargs):

        def reverse_conversion(dXtf):
            return convert_recursive(is_tensorflow_array, tensorflow2xp, dXtf)

        return converted, reverse_conversion
    elif isinstance(converted, dict):

        def reverse_conversion(dXtf):
            dX = convert_recursive(is_tensorflow_array, tensorflow2xp, dXtf)
            return dX.kwargs

        return ArgsKwargs(args=tuple(), kwargs=converted), reverse_conversion
    elif isinstance(converted, (tuple, list)):

        def reverse_conversion(dXtf):
            dX = convert_recursive(is_tensorflow_array, tensorflow2xp, dXtf)
            return dX.args

        return ArgsKwargs(args=converted, kwargs={}), reverse_conversion
    else:

        def reverse_conversion(dXtf):
            dX = convert_recursive(is_tensorflow_array, tensorflow2xp, dXtf)
            return dX.args[0]

        return ArgsKwargs(args=(converted,), kwargs={}), reverse_conversion


def _convert_outputs(model, Ytf, is_train):
    # Convert TensorFlow output tensors to xp arrays, and return a callback
    # that converts incoming xp gradients back to TensorFlow tensors.
    Y = convert_recursive(is_tensorflow_array, tensorflow2xp, Ytf)

    def reverse_conversion(dY):
        return convert_recursive(is_xp_array, xp2tensorflow, dY)

    return Y, reverse_conversion