# Natural Language Toolkit
#
# Copyright (C) 2001-2023 NLTK Project
# Author: Ilia Kurenkov <ilia.kurenkov@gmail.com>
# URL: <https://www.nltk.org/>
# For license information, see LICENSE.TXT
"""Language Model Vocabulary"""
|
|
|
|
import sys
|
|
from collections import Counter
|
|
from collections.abc import Iterable
|
|
from functools import singledispatch
|
|
from itertools import chain
|
|
|
|
|
|
@singledispatch
def _dispatched_lookup(words, vocab):
    """Fallback implementation of the single-dispatch vocabulary lookup.

    Handlers registered on this generic function deal with the supported
    input types; any other type of `words` ends up here and is rejected.

    :param words: Object that was passed in for lookup.
    :param vocab: Vocabulary to look up in (not used by the fallback).
    :raises TypeError: Always, naming the unsupported type of `words`.
    """
    message = f"Unsupported type for looking up in vocabulary: {type(words)}"
    raise TypeError(message)
|
|
|
|
|
|
@_dispatched_lookup.register(Iterable)
def _(words, vocab):
    """Look up every item of an iterable of words in the vocabulary.

    Each item is dispatched again, so nested iterables are handled
    recursively and come back as nested tuples.

    :param words: Iterable of words (or of nested iterables of words).
    :param vocab: Vocabulary to look the words up in.
    :return: Tuple with the looked up words, in input order.
    """
    looked_up = [_dispatched_lookup(word, vocab) for word in words]
    return tuple(looked_up)
|
|
|
|
|
|
@_dispatched_lookup.register(str)
def _string_lookup(word, vocab):
    """Look up a single word in the vocabulary.

    :param str word: Word to look up.
    :param vocab: Vocabulary to check membership against.
    :return: `word` itself if it is in `vocab`, otherwise `vocab.unk_label`.
    """
    if word in vocab:
        return word
    return vocab.unk_label
|
|
|
|
|
|
class Vocabulary:
    """Stores language model vocabulary.

    Satisfies two common language modeling requirements for a vocabulary:

    - When checking membership and calculating its size, filters items
      by comparing their counts to a cutoff value.
    - Adds a special "unknown" token which unseen words are mapped to.

    >>> words = ['a', 'c', '-', 'd', 'c', 'a', 'b', 'r', 'a', 'c', 'd']
    >>> from nltk.lm import Vocabulary
    >>> vocab = Vocabulary(words, unk_cutoff=2)

    Tokens with counts greater than or equal to the cutoff value will
    be considered part of the vocabulary.

    >>> vocab['c']
    3
    >>> 'c' in vocab
    True
    >>> vocab['d']
    2
    >>> 'd' in vocab
    True

    Tokens with frequency counts less than the cutoff value will be considered not
    part of the vocabulary even though their entries in the count dictionary are
    preserved.

    >>> vocab['b']
    1
    >>> 'b' in vocab
    False
    >>> vocab['aliens']
    0
    >>> 'aliens' in vocab
    False

    Keeping the count entries for seen words allows us to change the cutoff value
    without having to recalculate the counts.

    >>> vocab2 = Vocabulary(vocab.counts, unk_cutoff=1)
    >>> "b" in vocab2
    True

    The cutoff value influences not only membership checking but also the result of
    getting the size of the vocabulary using the built-in `len`.
    Note that while the number of keys in the vocabulary's counter stays the same,
    the items in the vocabulary differ depending on the cutoff.
    We use `sorted` to demonstrate because it keeps the order consistent.

    >>> sorted(vocab2.counts)
    ['-', 'a', 'b', 'c', 'd', 'r']
    >>> sorted(vocab2)
    ['-', '<UNK>', 'a', 'b', 'c', 'd', 'r']
    >>> sorted(vocab.counts)
    ['-', 'a', 'b', 'c', 'd', 'r']
    >>> sorted(vocab)
    ['<UNK>', 'a', 'c', 'd']

    In addition to items it gets populated with, the vocabulary stores a special
    token that stands in for so-called "unknown" items. By default it's "<UNK>".

    >>> "<UNK>" in vocab
    True

    We can look up words in a vocabulary using its `lookup` method.
    "Unseen" words (with counts less than cutoff) are looked up as the unknown label.
    If given one word (a string) as an input, this method will return a string.

    >>> vocab.lookup("a")
    'a'
    >>> vocab.lookup("aliens")
    '<UNK>'

    If given a sequence, it will return a tuple of the looked up words.

    >>> vocab.lookup(["p", 'a', 'r', 'd', 'b', 'c'])
    ('<UNK>', 'a', '<UNK>', 'd', '<UNK>', 'c')

    It's possible to update the counts after the vocabulary has been created.
    In general, the interface is the same as that of `collections.Counter`.

    >>> vocab['b']
    1
    >>> vocab.update(["b", "b", "c"])
    >>> vocab['b']
    3
    """

    def __init__(self, counts=None, unk_cutoff=1, unk_label="<UNK>"):
        """Create a new Vocabulary.

        :param counts: Optional iterable or `collections.Counter` instance to
                       pre-seed the Vocabulary. In case it is iterable, counts
                       are calculated.
        :param int unk_cutoff: Words that occur less frequently than this value
                               are not considered part of the vocabulary.
        :param unk_label: Label for marking words not part of vocabulary.
        :raises ValueError: If `unk_cutoff` is less than 1.

        """
        self.unk_label = unk_label
        if unk_cutoff < 1:
            raise ValueError(f"Cutoff value cannot be less than 1. Got: {unk_cutoff}")
        self._cutoff = unk_cutoff

        self.counts = Counter()
        # Always go through `update` (even with nothing to add) so that the
        # cached length `self._len` gets initialised.
        self.update(counts if counts is not None else "")

    @property
    def cutoff(self):
        """Cutoff value.

        Items with count below this value are not considered part of vocabulary.

        """
        return self._cutoff

    def update(self, *counter_args, **counter_kwargs):
        """Update vocabulary counts.

        Wraps `collections.Counter.update` method and refreshes the cached
        vocabulary size afterwards.

        """
        self.counts.update(*counter_args, **counter_kwargs)
        # Changed counts can move items across the cutoff, so recount.
        self._len = sum(1 for _ in self)

    def lookup(self, words):
        """Look up one or more words in the vocabulary.

        If passed one word as a string will return that word or `self.unk_label`.
        Otherwise will assume it was passed a sequence of words, will try to look
        each of them up and return a tuple of the looked up words. Nested
        sequences are looked up recursively and returned as nested tuples.

        :param words: Word(s) to look up.
        :type words: Iterable(str) or str
        :rtype: tuple(str) or str
        :raises: TypeError for types other than strings or iterables

        >>> from nltk.lm import Vocabulary
        >>> vocab = Vocabulary(["a", "b", "c", "a", "b"], unk_cutoff=2)
        >>> vocab.lookup("a")
        'a'
        >>> vocab.lookup("aliens")
        '<UNK>'
        >>> vocab.lookup(["a", "b", "c", ["x", "b"]])
        ('a', 'b', '<UNK>', ('<UNK>', 'b'))

        """
        return _dispatched_lookup(words, self)

    def __getitem__(self, item):
        """Return the count stored for `item`.

        The unknown label is reported with a count equal to the cutoff, which
        makes it pass the membership check regardless of the stored counts.
        """
        return self._cutoff if item == self.unk_label else self.counts[item]

    def __contains__(self, item):
        """Only consider items with counts GE to cutoff as being in the
        vocabulary."""
        return self[item] >= self.cutoff

    def __iter__(self):
        """Building on membership check define how to iterate over
        vocabulary.

        Yields every counted item that passes the cutoff, followed by the
        unknown label (only when at least one count has been recorded).
        """
        return chain(
            # The unknown label always passes the membership check; skip it
            # here so it is not yielded twice when it is also a key of `counts`.
            (
                item
                for item in self.counts
                if item != self.unk_label and item in self
            ),
            [self.unk_label] if self.counts else [],
        )

    def __len__(self):
        """Computing size of vocabulary reflects the cutoff."""
        return self._len

    def __eq__(self, other):
        """Vocabularies are equal when label, cutoff and counts all match.

        Comparison with anything that is not a `Vocabulary` returns
        `NotImplemented`, so `==` evaluates to False instead of raising
        `AttributeError`.
        """
        if not isinstance(other, Vocabulary):
            return NotImplemented
        return (
            self.unk_label == other.unk_label
            and self.cutoff == other.cutoff
            and self.counts == other.counts
        )

    def __str__(self):
        """Return a summary with class name, cutoff, unknown label and size."""
        return (
            f"<{self.__class__.__name__} with cutoff={self.cutoff} "
            f"unk_label='{self.unk_label}' and {len(self)} items>"
        )
|