# Natural Language Toolkit: Language Model Unit Tests # # Copyright (C) 2001-2023 NLTK Project # Author: Ilia Kurenkov # URL: # For license information, see LICENSE.TXT import unittest import pytest from nltk import FreqDist from nltk.lm import NgramCounter from nltk.util import everygrams class TestNgramCounter: """Tests for NgramCounter that only involve lookup, no modification.""" @classmethod def setup_class(self): text = [list("abcd"), list("egdbe")] self.trigram_counter = NgramCounter( everygrams(sent, max_len=3) for sent in text ) self.bigram_counter = NgramCounter(everygrams(sent, max_len=2) for sent in text) self.case = unittest.TestCase() def test_N(self): assert self.bigram_counter.N() == 16 assert self.trigram_counter.N() == 21 def test_counter_len_changes_with_lookup(self): assert len(self.bigram_counter) == 2 self.bigram_counter[50] assert len(self.bigram_counter) == 3 def test_ngram_order_access_unigrams(self): assert self.bigram_counter[1] == self.bigram_counter.unigrams def test_ngram_conditional_freqdist(self): case = unittest.TestCase() expected_trigram_contexts = [ ("a", "b"), ("b", "c"), ("e", "g"), ("g", "d"), ("d", "b"), ] expected_bigram_contexts = [("a",), ("b",), ("d",), ("e",), ("c",), ("g",)] bigrams = self.trigram_counter[2] trigrams = self.trigram_counter[3] self.case.assertCountEqual(expected_bigram_contexts, bigrams.conditions()) self.case.assertCountEqual(expected_trigram_contexts, trigrams.conditions()) def test_bigram_counts_seen_ngrams(self): assert self.bigram_counter[["a"]]["b"] == 1 assert self.bigram_counter[["b"]]["c"] == 1 def test_bigram_counts_unseen_ngrams(self): assert self.bigram_counter[["b"]]["z"] == 0 def test_unigram_counts_seen_words(self): assert self.bigram_counter["b"] == 2 def test_unigram_counts_completely_unseen_words(self): assert self.bigram_counter["z"] == 0 class TestNgramCounterTraining: @classmethod def setup_class(self): self.counter = NgramCounter() self.case = unittest.TestCase() @pytest.mark.parametrize("case", ["", [], None]) def test_empty_inputs(self, case): test = NgramCounter(case) assert 2 not in test assert test[1] == FreqDist() def test_train_on_unigrams(self): words = list("abcd") counter = NgramCounter([[(w,) for w in words]]) assert not counter[3] assert not counter[2] self.case.assertCountEqual(words, counter[1].keys()) def test_train_on_illegal_sentences(self): str_sent = ["Check", "this", "out", "!"] list_sent = [["Check", "this"], ["this", "out"], ["out", "!"]] with pytest.raises(TypeError): NgramCounter([str_sent]) with pytest.raises(TypeError): NgramCounter([list_sent]) def test_train_on_bigrams(self): bigram_sent = [("a", "b"), ("c", "d")] counter = NgramCounter([bigram_sent]) assert not bool(counter[3]) def test_train_on_mix(self): mixed_sent = [("a", "b"), ("c", "d"), ("e", "f", "g"), ("h",)] counter = NgramCounter([mixed_sent]) unigrams = ["h"] bigram_contexts = [("a",), ("c",)] trigram_contexts = [("e", "f")] self.case.assertCountEqual(unigrams, counter[1].keys()) self.case.assertCountEqual(bigram_contexts, counter[2].keys()) self.case.assertCountEqual(trigram_contexts, counter[3].keys())