# cython: profile=False
# Test helpers and test functions for thinc.extra.search.Beam.  The cdef
# callbacks below operate on a small C struct (TestState) that the beam
# hands around as untyped ``void*`` pointers.
from cymem.cymem cimport Pool

from thinc.extra.search cimport Beam
from thinc.typedefs cimport class_t, weight_t


cdef struct TestState:
    int length          # set from the ``n`` argument of ``initialize``
    int x               # starts at 1; ``transition`` adds the class id to it
    Py_UNICODE* string  # borrowed pointer: 'default' or the caller's extra_args


cdef int transition(void* dest, void* src, class_t clas, void* extra_args) except -1:
    """Write into ``dest`` the state reached from ``src`` by taking ``clas``.

    ``length`` is copied unchanged, ``x`` becomes ``src.x + clas``, and
    ``string`` is taken from ``extra_args`` when it is non-NULL, otherwise
    from ``src``.  Both ``dest`` and ``src`` must point at ``TestState``.
    """
    # The callback signature only carries void*, so cast back to the struct.
    cdef TestState* dest_state = <TestState*>dest
    cdef TestState* src_state = <TestState*>src
    dest_state.length = src_state.length
    dest_state.x = src_state.x
    dest_state.x += clas
    if extra_args != NULL:
        dest_state.string = <Py_UNICODE*>extra_args
    else:
        dest_state.string = src_state.string
    return 0


cdef void* initialize(Pool mem, int n, void* extra_args) except NULL:
    """Allocate one ``TestState`` from ``mem`` and return it as ``void*``.

    The new state has ``length == n`` and ``x == 1``.  ``string`` points at
    the literal 'default' when ``extra_args`` is NULL, otherwise at
    ``extra_args`` reinterpreted as ``Py_UNICODE*``.
    """
    cdef TestState* state = <TestState*>mem.alloc(1, sizeof(TestState))
    state.length = n
    state.x = 1
    if extra_args == NULL:
        state.string = 'default'
    else:
        state.string = <Py_UNICODE*>extra_args
    return state


cdef int destroy(Pool mem, void* state, void* extra_args) except -1:
    """Release ``state`` through ``mem.free``; ``extra_args`` is unused."""
    mem.free(state)
    return 0


def test_init(nr_class, beam_width):
    """A freshly constructed Beam reports size 1 and the requested shape."""
    cdef Beam b = Beam(nr_class, beam_width)
    assert b.size == 1
    assert b.width == beam_width
    assert b.nr_class == nr_class


def test_initialize(nr_class, beam_width, length):
    """With no extra args, every beam slot holds a default TestState."""
    cdef TestState* s
    cdef Beam b = Beam(nr_class, beam_width)
    b.initialize(initialize, destroy, length, NULL)
    for i in range(b.width):
        s = <TestState*>b.at(i)
        assert s.length == length, s.length
        assert s.string == 'default'


def test_initialize_extra(nr_class, beam_width, length, unicode extra):
    """Initialising with an extra-args pointer still sets ``length``."""
    cdef TestState* s
    cdef Beam b = Beam(nr_class, beam_width)
    # Pass the character buffer, not the Python object pointer: ``initialize``
    # reinterprets extra_args as Py_UNICODE*.  ``extra`` stays referenced for
    # the whole call, so the borrowed buffer remains valid here.
    b.initialize(initialize, destroy, length, <void*><Py_UNICODE*>extra)
    for i in range(b.width):
        s = <TestState*>b.at(i)
        assert s.length == length


def test_transition(nr_class=3, beam_width=6, length=3):
    """Advance the beam twice and check the resulting scores and states."""
    cdef TestState* s
    cdef Beam b = Beam(nr_class, beam_width)
    b.initialize(initialize, destroy, length, NULL)

    # NOTE(review): set_cell arguments are presumably
    # (state index, class, score, is_valid, cost) -- confirm against Beam.
    b.set_cell(0, 2, 30, True, 0)
    b.set_cell(0, 1, 42, False, 0)
    b.advance(transition, NULL, NULL)
    assert b.size == 1, b.size
    assert b.score == 30, b.score
    s = <TestState*>b.at(0)
    # x starts at 1 and transition adds the class id, so 3 means class 2.
    assert s.x == 3
    assert b._states[0].score == 30, b._states[0].score

    b.set_cell(0, 1, 10, True, 0)
    b.set_cell(0, 2, 20, True, 0)
    b.advance(transition, NULL, NULL)
    assert b._states[0].score == 50, b._states[0].score
    assert b._states[1].score == 40
    s = <TestState*>b.at(0)
    # Best state after the second step: 3 + class 2.
    assert s.x == 5