# cython: binding=True, infer_types=True, profile=False
import numpy

from libc.stdint cimport int64_t
from preshed.maps cimport PreshMap

from typing import Callable, Dict, Mapping, Optional, Tuple, Union, cast

from ..config import registry
from ..model import Model
from ..types import Ints1d, Ints2d
from ..util import to_numpy

InT = Union[Ints1d, Ints2d]
OutT = Ints2d


cdef lookup(PreshMap mapping, int64_t[:] keys, int64_t default):
    """
    Faster dict.get(keys, default) for the case when
    the "dict" is a Dict[int, int] converted to PreshMap
    and the "keys" is a numpy integer vector.

    Returns a new 1D int64 numpy array with one entry per key: the mapped
    value when the key is present in ``mapping``, otherwise ``default``.
    """
    # Fixed-width int64_t instead of C ``long``: ``long`` is only 32 bits on
    # Windows, which does not match numpy's 64-bit integer arrays there.
    cdef Py_ssize_t n_keys = keys.shape[0]
    cdef Py_ssize_t i
    result = numpy.empty(n_keys, dtype=numpy.int64)
    cdef int64_t[:] result_view = result
    for i in range(n_keys):
        # PreshMap returns None for keys that were never inserted.
        value = mapping[keys[i]]
        if value is None:
            result_view[i] = default
        else:
            result_view[i] = value
    return result


@registry.layers("premap_ids.v1")
def premap_ids(
    mapping_table: Mapping[int, int],
    default: int = 0,
    *,
    column: Optional[int] = None
):
    """Remap integer inputs to integers using a mapping table, usually as a
    preprocess before embeddings.

    mapping_table: Dict-like table of int keys to int values. It is copied
        into a PreshMap once, when the layer is built.
    default: Value emitted for inputs that are not keys of the table.
    column: If given, the inputs are treated as 2D and only this column
        is remapped; otherwise the inputs are used as-is.

    Raises ValueError if any key or value of ``mapping_table`` is not an int.
    """
    mapper = PreshMap(initial_size=len(mapping_table))
    for k, v in mapping_table.items():
        if not (isinstance(k, int) and isinstance(v, int)):
            raise ValueError(
                "mapping_table has to be of type Mapping[int, int], "
                f"but found {k}, {type(k)} and {v}, {type(v)}"
            )
        mapper[k] = v
    return Model(
        "premap_ids",
        forward,
        attrs={
            "mapping_table": mapper, "default": default, "column": column
        }
    )


def forward(
    model: Model, inputs: InT, is_train: bool
) -> Tuple[OutT, Callable]:
    """Look up every input id in the layer's table.

    Returns the remapped ids reshaped to a (-1, 1) column via
    ``model.ops.reshape2i``, together with a backprop callback that returns
    an uninitialised placeholder array shaped like its ``dY`` argument.

    Raises ValueError if the selected ids cannot be losslessly converted
    to int64 (e.g. uint64 or floating-point inputs).
    """
    table = model.attrs["mapping_table"]
    default = model.attrs["default"]
    column = model.attrs["column"]
    # Have to convert to numpy anyways, because
    # cupy ints don't work together with Python ints.
    if column is None:
        idx = to_numpy(inputs)
    else:
        idx = to_numpy(cast(Ints2d, inputs)[:, column])
    idx = numpy.asarray(idx)
    # ``lookup`` takes an int64 memoryview. A "safe" cast widens narrower
    # integer dtypes (int8/16/32, uint8/16/32) without loss, but refuses
    # lossy conversions such as uint64 -> int64 or float -> int64.
    try:
        keys = idx.astype(numpy.int64, casting="safe", copy=False)
    except TypeError as err:
        raise ValueError(
            f"premap_ids expects integer ids, but got dtype {idx.dtype}"
        ) from err
    result = lookup(table, keys, default)
    arr = model.ops.asarray2i(result)
    output = model.ops.reshape2i(arr, -1, 1)

    def backprop(dY: OutT) -> InT:
        # Integer ids carry no gradient; return a placeholder of dY's shape.
        return model.ops.xp.empty(dY.shape)  # type: ignore

    return output, backprop