82 lines
2.0 KiB
Python
82 lines
2.0 KiB
Python
|
from __future__ import division
|
||
|
import pytest
|
||
|
|
||
|
from preshed.counter import PreshCounter
|
||
|
|
||
|
|
||
|
def test_count():
    """Exercise ``inc`` and item lookup on a fresh ``PreshCounter``.

    The asserts expect a never-incremented key to read back as 0 and
    repeated ``inc`` calls on the same key to add up.
    """
    tally = PreshCounter()
    assert tally[12] == 0

    tally.inc(12, 1)
    assert tally[12] == 1

    # Touch two other keys before bumping key 12 a second time, so the
    # final checks cover accumulation alongside unrelated entries.
    for key, amount in ((14, 10), (9, 10), (12, 4)):
        tally.inc(key, amount)

    for key, expected in ((12, 5), (14, 10), (9, 10)):
        assert tally[key] == expected
|
||
|
|
||
|
|
||
|
def test_unsmooth_prob():
    """Exercise ``prob`` on a counter that has not been smoothed.

    The asserts expect 0.0 for a key on an empty counter and, once counts
    exist, each key's count divided by the sum of all counts.
    """
    freqs = PreshCounter()
    assert freqs.prob(12) == 0.0

    # With a single key present it carries the whole probability mass.
    freqs.inc(12, 1)
    assert freqs.prob(12) == 1.0

    freqs.inc(14, 10)
    total = 1 + 10
    assert freqs.prob(14) == 10 / total
    assert freqs.prob(12) == 1.0 / total
|
||
|
|
||
|
def test_smooth_prob():
    """Exercise ``prob`` and ``smoother`` before and after ``smooth()``.

    The counter is filled with this frequency profile:

        freq  number of keys
           1              10
           2               6
           3               4
           5               2
           8               1
    """
    dist = PreshCounter()
    # Each row is (highest key, number of keys, frequency per key); the keys
    # of a row count downwards from its highest key.
    layout = (
        (100, 10, 1),
        (90, 6, 2),
        (80, 4, 3),
        (70, 2, 5),
        (60, 1, 8),
    )
    for top_key, n_keys, freq in layout:
        for offset in range(n_keys):
            dist.inc(top_key - offset, freq)

    assert dist.total == sum(n_keys * freq for _, n_keys, freq in layout)

    # Unsmoothed probabilities: seen keys get count / total, unseen keys 0.
    assert dist.prob(100) == 1.0 / dist.total
    assert dist.prob(200) == 0.0
    assert dist.prob(60) == 8.0 / dist.total

    dist.smooth()

    # After smoothing, the smoothed value must be below the raw count, and a
    # never-seen key must be less probable than a key seen once.
    assert dist.smoother(1) < 1.0
    assert dist.smoother(8) < 8.0
    assert dist.prob(1000) < dist.prob(100)

    # Iterating the counter yields (event, count) pairs; walk them from the
    # largest count to the smallest.
    ascending = sorted(dist, key=lambda pair: pair[1])
    for _, count in reversed(ascending):
        assert dist.smoother(count) < count
|
||
|
|
||
|
|
||
|
import os
|
||
|
def test_large_freqs():
    """Check smoothing on frequencies read from the file in ``TEST_FILE_LOC``.

    Does nothing and returns ``None`` when the ``TEST_FILE_LOC`` environment
    variable is not set. Otherwise every non-blank line of that file must
    start with an integer frequency, which is added to the counter under the
    key ``line_index + 1``.
    """
    loc = os.environ.get('TEST_FILE_LOC')
    if loc is None:
        return None

    counts = PreshCounter()
    # Read inside a ``with`` block so the file handle is closed even when a
    # line fails to parse or ``inc`` raises; the original never closed it.
    with open(loc) as lines:
        for i, line in enumerate(lines):
            line = line.strip()
            if not line:
                continue
            freq = int(line.split()[0])
            counts.inc(i+1, freq)

    # ``i`` is the index of the last line read, so ``i + 2`` lies beyond
    # every key used above. NOTE(review): an empty file leaves ``i`` unbound
    # and this line raises NameError.
    oov = i+2
    assert counts.prob(oov) == 0.0
    assert counts.prob(1) < 0.1

    counts.smooth()
    assert counts.prob(oov) > 0
    # NOTE(review): this compares against key ``i`` (the entry for the
    # second-to-last line), not ``i + 1`` -- confirm that is intended.
    assert counts.prob(oov) < counts.prob(i)
|