ai-content-maker/.venv/Lib/site-packages/thinc/layers/siamese.py

54 lines
1.6 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
from typing import Callable, Optional, Tuple, TypeVar
from ..config import registry
from ..model import Model
from ..types import ArrayXd
from ..util import get_width
LayerT = TypeVar("LayerT")
SimT = TypeVar("SimT")
InT = Tuple[LayerT, LayerT]
OutT = TypeVar("OutT", bound=ArrayXd)
@registry.layers("siamese.v1")
def siamese(
layer: Model[LayerT, SimT], similarity: Model[Tuple[SimT, SimT], OutT]
) -> Model[InT, OutT]:
return Model(
f"siamese({layer.name}, {similarity.name})",
forward,
init=init,
layers=[layer, similarity],
dims={"nI": layer.get_dim("nI"), "nO": similarity.get_dim("nO")},
)
def forward(
model: Model[InT, OutT], X1_X2: InT, is_train: bool
) -> Tuple[OutT, Callable]:
X1, X2 = X1_X2
vec1, bp_vec1 = model.layers[0](X1, is_train)
vec2, bp_vec2 = model.layers[0](X2, is_train)
output, bp_output = model.layers[1]((vec1, vec2), is_train)
def finish_update(d_output: OutT) -> InT:
d_vec1, d_vec2 = bp_output(d_output)
d_input1 = bp_vec1(d_vec1)
d_input2 = bp_vec2(d_vec2)
return (d_input1, d_input2)
return output, finish_update
def init(
model: Model[InT, OutT], X: Optional[InT] = None, Y: Optional[OutT] = None
) -> None:
if X is not None:
model.layers[0].set_dim("nI", get_width(X[1]))
model.layers[0].initialize(X=X[0])
X = (model.layers[0].predict(X[0]), model.layers[0].predict(X[1]))
model.layers[1].initialize(X=X, Y=Y)
model.set_dim("nI", model.layers[0].get_dim("nI"))
model.set_dim("nO", model.layers[1].get_dim("nO"))