ai-content-maker/.venv/Lib/site-packages/thinc/layers/with_debug.py

45 lines
1.3 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
from typing import Any, Callable, Optional, Tuple, TypeVar
from ..model import Model
_ModelT = TypeVar("_ModelT", bound=Model)
do_nothing = lambda *args, **kwargs: None
def with_debug(
layer: _ModelT,
name: Optional[str] = None,
*,
on_init: Callable[[Model, Any, Any], None] = do_nothing,
on_forward: Callable[[Model, Any, bool], None] = do_nothing,
on_backprop: Callable[[Any], None] = do_nothing,
) -> _ModelT:
"""Debugging layer that wraps any layer and allows executing callbacks
during the forward pass, backward pass and initialization. The callbacks
will receive the same arguments as the functions they're called in.
"""
name = layer.name if name is None else name
orig_forward = layer._func
orig_init = layer.init
def forward(model: Model, X: Any, is_train: bool) -> Tuple[Any, Callable]:
on_forward(model, X, is_train)
layer_Y, layer_callback = orig_forward(layer, X, is_train=is_train)
def backprop(dY: Any) -> Any:
on_backprop(dY)
return layer_callback(dY)
return layer_Y, backprop
def init(model: Model, X: Any, Y: Any) -> None:
on_init(model, X, Y)
if orig_init is not None:
orig_init(layer, X, Y)
layer.replace_callbacks(forward, init=init)
return layer