27 lines
802 B
Python
27 lines
802 B
Python
|
from mock import MagicMock
|
||
|
|
||
|
from thinc.api import Linear, with_debug
|
||
|
|
||
|
|
||
|
def test_with_debug():
    """Check that each ``with_debug`` hook is invoked at its own stage only.

    Three mocks are attached to a ``Linear`` layer as the ``on_init``,
    ``on_forward`` and ``on_backprop`` hooks. The test then walks through
    wrapping, initialization, a training-mode call and the backprop callback,
    asserting after every step which mocks have (and have not) been called
    and with which arguments.
    """
    init_hook = MagicMock()
    forward_hook = MagicMock()
    backprop_hook = MagicMock()
    model = with_debug(
        Linear(),
        on_init=init_hook,
        on_forward=forward_hook,
        on_backprop=backprop_hook,
    )

    # Merely wrapping the layer must not trigger any hook.
    for hook in (init_hook, forward_hook, backprop_hook):
        hook.assert_not_called()

    inputs = model.ops.alloc2f(1, 1)
    targets = model.ops.alloc2f(1, 1)

    # Initialization reports to on_init and leaves the other two hooks alone.
    model.initialize(X=inputs, Y=targets)
    init_hook.assert_called_once_with(model, inputs, targets)
    for hook in (forward_hook, backprop_hook):
        hook.assert_not_called()

    # Calling the model reports to on_forward; the returned callback has not
    # been run yet, so on_backprop must still be untouched.
    _, run_backprop = model(inputs, is_train=True)
    forward_hook.assert_called_once_with(model, inputs, True)
    backprop_hook.assert_not_called()

    # Running the callback reports to on_backprop with the value passed in.
    run_backprop(targets)
    backprop_hook.assert_called_once_with(targets)
|