ai-content-maker/.venv/Lib/site-packages/thinc/tests/layers/test_with_flatten.py

32 lines
871 B
Python
Raw Permalink Normal View History

2024-05-03 04:18:51 +03:00
from typing import List
from thinc.api import Model, with_flatten_v2
# Nested batch fed to the with_flatten wrapper. The third item is empty, so
# the expected outputs below also require an empty item in that position.
INPUT = [
    [1, 2, 3],
    [4, 5],
    [],
    [6, 7, 8],
]
# INPUT concatenated into a single list: what the wrapped layer must receive.
INPUT_FLAT = [1, 2, 3, 4, 5, 6, 7, 8]
# Forward result: every element of INPUT plus 1, re-nested to INPUT's shape.
OUTPUT = [
    [2, 3, 4],
    [5, 6],
    [],
    [7, 8, 9],
]
# Backward result when INPUT itself is passed as the gradient: every element
# plus 2, re-nested to INPUT's shape.
BACKPROP_OUTPUT = [
    [3, 4, 5],
    [6, 7],
    [],
    [8, 9, 10],
]
def _memoize_input() -> Model[List[int], List[int]]:
    """Build the small layer that with_flatten_v2 wraps in the test below.

    The layer is named ``"memoize_input"`` and uses
    ``_memoize_input_forward`` as its forward function. That function
    records the input it receives in ``attrs["last_input"]``.
    """
    memo_layer = Model(
        forward=_memoize_input_forward,
        name="memoize_input",
    )
    return memo_layer
def _memoize_input_forward(
    model: Model[List[int], List[int]], X: List[int], is_train: bool
):
    """Forward pass that remembers its input and adds 1 to each element.

    The input list object itself (not a copy) is stored in
    ``model.attrs["last_input"]``, so the test can check what the layer saw.
    ``is_train`` is accepted to satisfy the forward signature but is unused.

    Returns a tuple of the incremented list and a backward callback. The
    callback adds 2 to each element of the gradient it is given.
    """
    model.attrs["last_input"] = X
    outputs = [item + 1 for item in X]

    def finish_update(dY: List[int]):
        # Gradient rule for this toy layer: shift every value by 2.
        return [grad + 2 for grad in dY]

    return outputs, finish_update
def test_with_flatten():
    """Check that with_flatten_v2 flattens, runs the inner layer, and re-nests.

    The test verifies three things:
    - the nested forward output;
    - the flat list the inner layer actually received;
    - the backward pass.

    The backward pass is checked twice. First, INPUT is used as the
    gradient. Second, a gradient whose values differ from the forward input
    is used, because a wrapper that flattened the cached forward input
    instead of the incoming gradient would still pass the first check.
    """
    model = with_flatten_v2(_memoize_input())
    Y, backprop = model(INPUT, is_train=True)
    assert Y == OUTPUT
    # The inner layer must have been called with one flat list.
    assert model.layers[0].attrs["last_input"] == INPUT_FLAT
    assert backprop(INPUT) == BACKPROP_OUTPUT
    # This gradient has the same nesting as INPUT but different values, so
    # the result can only match if the incoming dY is actually propagated.
    d_outputs = [[10, 20, 30], [40, 50], [], [60, 70, 80]]
    assert backprop(d_outputs) == [[12, 22, 32], [42, 52], [], [62, 72, 82]]