# Cython declarations (.pxd file).
from cymem.cymem cimport Pool
|
||
|
from libc.stdint cimport uint32_t, uint64_t
|
||
|
from libcpp.pair cimport pair
|
||
|
from libcpp.queue cimport priority_queue
|
||
|
from libcpp.vector cimport vector
|
||
|
|
||
|
# Scalar aliases shared by the declarations in this file.
ctypedef uint64_t hash_t    # value produced by a hash_func_t
ctypedef uint64_t class_t   # class id: used for nr_class, _State.hist, trans_func_t's `clas`
ctypedef float weight_t     # scores, losses, costs and densities

# A (weight, index) pair and a priority queue of them.
# std::priority_queue uses std::less by default, so it is a max-heap: top() is
# the entry with the largest weight. std::pair compares lexicographically, so
# ties on weight are broken in favour of the larger index.
ctypedef pair[weight_t, size_t] Entry
ctypedef priority_queue[Entry] Queue


# Callback signatures supplied by client code. Each `except` clause names the
# sentinel return value that Cython treats as "an exception was raised".

# Transition callback: receives a destination state, a source state, a class id
# and caller-supplied data `x`. NOTE(review): presumably writes into `dest` the
# result of applying class `clas` to `src` -- confirm against Beam.advance.
ctypedef int (*trans_func_t)(void* dest, void* src, class_t clas, void* x) except -1

# State constructor: given the Pool `mem`, an int `n` and extra data, returns a
# new opaque state pointer. Returning NULL signals an error.
ctypedef void* (*init_func_t)(Pool mem, int n, void* extra_args) except NULL

# State destructor. NOTE(review): presumably releases a state produced by an
# init_func_t -- confirm in the .pyx.
ctypedef int (*del_func_t)(Pool mem, void* state, void* extra_args) except -1

# Completion test for a state. NOTE(review): presumably returns non-zero when the
# state is finished -- confirm against Beam.check_done.
ctypedef int (*finish_func_t)(void* state, void* extra_args) except -1

# State hash. This uses `except 0`, not `except? 0`, so a return value of 0 is
# ALWAYS treated as a raised exception. Implementations must never return 0 for
# a valid state.
ctypedef hash_t (*hash_func_t)(void* state, void* x) except 0


# One beam entry: an opaque client state plus the bookkeeping stored alongside it.
cdef struct _State:
    void* content     # client-owned state pointer; returned by Beam.at()
    class_t* hist     # NOTE(review): presumably the class ids chosen so far -- confirm
    weight_t score
    weight_t loss
    int i
    int t             # NOTE(review): presumably the step counter -- confirm in .pyx
    bint is_done


# Container of candidate states (each a _State wrapping an opaque client
# pointer) plus per-candidate score / validity / cost tables.
# Method bodies other than the inline ones live in the matching .pyx.
cdef class Beam:
    cdef Pool mem                      # cymem allocator
    cdef class_t nr_class              # NOTE(review): presumably the column count of the tables
    cdef class_t width                 # NOTE(review): presumably the maximum number of entries
    cdef class_t size                  # NOTE(review): presumably the number of live entries
    cdef public weight_t min_density   # readable and writable from Python
    cdef int t
    cdef readonly bint is_done         # readable (not writable) from Python
    cdef list histories
    cdef list _parent_histories

    # Per-candidate tables, indexed [i][j]; written cell-by-cell by set_cell().
    cdef weight_t** scores
    cdef int** is_valid
    cdef weight_t** costs

    # NOTE(review): presumably the previous-step and current-step entries.
    cdef _State* _parents
    cdef _State* _states
    cdef del_func_t del_func           # destructor callback stored by initialize()

    cdef int _fill(self, Queue* q, weight_t** scores, int** is_valid) except -1

    # Return the client state pointer of entry `i`.
    # No bounds check: `i` must be a valid index into self._states.
    cdef inline void* at(self, int i) nogil:
        return self._states[i].content

    cdef int initialize(self, init_func_t init_func, del_func_t del_func, int n,
                        void* extra_args) except -1
    cdef int advance(self, trans_func_t transition_func, hash_func_t hash_func,
                     void* extra_args) except -1
    cdef int check_done(self, finish_func_t finish_func, void* extra_args) except -1

    # Store one cell of the scores / is_valid / costs tables at row i, column j.
    # No bounds check: both indices must already be within the allocated tables.
    cdef inline void set_cell(self, int i, int j, weight_t score, int is_valid,
                              weight_t cost) nogil:
        self.scores[i][j] = score
        self.is_valid[i][j] = is_valid
        self.costs[i][j] = cost

    cdef int set_row(self, int i, const weight_t* scores, const int* is_valid,
                     const weight_t* costs) except -1
    cdef int set_table(self, weight_t** scores, int** is_valid,
                       weight_t** costs) except -1


# Compares a predicted beam with a gold beam via check() / check_crf().
# NOTE(review): going by the name, presumably records the point of maximum
# violation between the two beams -- confirm in the .pyx.
cdef class MaxViolation:
    cdef Pool mem
    cdef weight_t cost
    cdef weight_t delta
    # p_* / g_* fields: presumably taken from the `pred` / `gold` beams respectively.
    cdef readonly weight_t p_score
    cdef readonly weight_t g_score
    cdef readonly double Z
    cdef readonly double gZ
    cdef class_t n
    cdef readonly list p_hist
    cdef readonly list g_hist
    cdef readonly list p_probs
    cdef readonly list g_probs

    # cpdef: callable from both Cython and Python; -1 signals an exception.
    cpdef int check(self, Beam pred, Beam gold) except -1
    cpdef int check_crf(self, Beam pred, Beam gold) except -1