utils.text
The text utils module contains implementations of various decoding strategies, such as Greedy, Beam Search and Nucleus Sampling.
In your model’s config, you can specify the inference attribute to use these strategies in the following way:
```yaml
model_config:
  some_model:
    inference:
      type: greedy
      params: {}
```
- class mmf.utils.text.NucleusSampling(vocab, config)
Nucleus Sampling is a new text decoding strategy that avoids likelihood maximization. Instead, it samples from the smallest set of top tokens whose cumulative probability exceeds a specified threshold.
Existing text decoding strategies such as beam search do not work well on open-ended generation tasks, even with strong language models like GPT-2. They tend to repeat text, mainly because they maximize likelihood; human-generated text, in contrast, mixes high- and low-probability tokens.
Nucleus Sampling is a stochastic approach that addresses this issue. It also improves on other stochastic methods such as top-k sampling: instead of always sampling from a fixed number k of tokens, it adapts the number of candidate tokens at each step to the shape of the probability distribution. The result is better text generation from the same language model.
Link to the paper introducing Nucleus Sampling (Section 6) - https://arxiv.org/pdf/1904.09751.pdf
- Parameters
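vocab (list) – Collection of all words in the vocabulary.
sum_threshold (float) – Threshold on the cumulative probability of the tokens to sample from: sampling uses the smallest set of top tokens whose probabilities sum to at least this value.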
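The selection rule described above can be illustrated with a minimal, framework-free sketch. It is not MMF's implementation: it works on a plain Python list of probabilities for a single decoding step rather than on model score tensors, and the function names nucleus_candidates and nucleus_sample are hypothetical.

```python
import random
from typing import List, Sequence, Tuple


def nucleus_candidates(
    probs: Sequence[float], sum_threshold: float
) -> List[Tuple[int, float]]:
    """Hypothetical helper: return the smallest set of top tokens whose
    cumulative probability reaches sum_threshold, as
    (token_index, renormalized_probability) pairs."""
    # Sort token indices by probability, highest first.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    nucleus: List[int] = []
    cumulative = 0.0
    for idx in order:
        nucleus.append(idx)
        cumulative += probs[idx]
        if cumulative >= sum_threshold:
            break
    # Renormalize so the probabilities inside the nucleus sum to 1.
    return [(idx, probs[idx] / cumulative) for idx in nucleus]


def nucleus_sample(probs: Sequence[float], sum_threshold: float, rng=random) -> int:
    """Hypothetical helper: draw one token index from the nucleus."""
    candidates = nucleus_candidates(probs, sum_threshold)
    indices = [idx for idx, _ in candidates]
    weights = [p for _, p in candidates]
    return rng.choices(indices, weights=weights, k=1)[0]
```

For example, with the distribution [0.5, 0.3, 0.15, 0.05] and sum_threshold = 0.75, the nucleus is tokens 0 and 1 (cumulative probability 0.8), renormalized to 0.625 and 0.375; tokens 2 and 3 can never be sampled at that step. Lowering the threshold to 0.5 shrinks the nucleus to token 0 alone, which makes the step behave like greedy decoding.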
- class mmf.utils.text.TextDecoder(vocab)
Base class to be inherited by all decoding strategies. Contains implementations common to all strategies.
- Parameters
vocab (list) – Collection of all words in the vocabulary.
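The base-class pattern can be sketched as follows. This is a hypothetical illustration of the design, not MMF's TextDecoder API: the class names SketchTextDecoder and SketchGreedyDecoder and the methods indices_to_words and decode_step are invented for the example, and scores are plain Python lists.

```python
from abc import ABC, abstractmethod
from typing import List, Sequence


class SketchTextDecoder(ABC):
    """Hypothetical base class: holds the vocabulary and the logic shared
    by every strategy (here, mapping token indices back to words)."""

    def __init__(self, vocab: List[str]):
        self.vocab = vocab

    def indices_to_words(self, indices: Sequence[int]) -> List[str]:
        # Shared helper that every strategy can reuse unchanged.
        return [self.vocab[i] for i in indices]

    @abstractmethod
    def decode_step(self, scores: Sequence[float]) -> int:
        """Pick the next token index from one step's scores."""


class SketchGreedyDecoder(SketchTextDecoder):
    """Greedy strategy: always take the highest-scoring token."""

    def decode_step(self, scores: Sequence[float]) -> int:
        return max(range(len(scores)), key=lambda i: scores[i])
```

Each strategy only overrides the step that differs (how the next token is chosen), while vocabulary handling stays in the base class. For instance, SketchGreedyDecoder(["<s>", "a", "cat", "</s>"]).decode_step([0.1, 0.6, 0.2, 0.1]) returns 1, the index of "a".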
- mmf.utils.text.generate_ngrams(tokens, n=1)
Generates ngrams for a particular ‘n’ from a list of tokens.
- Parameters
tokens (List[str]) – List of tokens for which the ngrams are to be generated
n (int, optional) – n for which ngrams are to be generated. Defaults to 1.
- Returns
List of ngrams generated.
- Return type
List[str]
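A minimal sketch of this behaviour, assuming each ngram is returned as its tokens joined by a single space (an assumption; MMF's own function may differ in details such as returning a generator rather than a list):

```python
from typing import List


def generate_ngrams(tokens: List[str], n: int = 1) -> List[str]:
    # Build n shifted copies of the token list; zip stops at the shortest,
    # so each tuple is one contiguous window of n tokens.
    shifted = [tokens[i:] for i in range(n)]
    # Join each window with a space (assumed output format).
    return [" ".join(gram) for gram in zip(*shifted)]
```

For example, generate_ngrams(["the", "cat", "sat"], 2) gives ["the cat", "cat sat"]. When n exceeds the number of tokens, the result is an empty list.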
- mmf.utils.text.generate_ngrams_range(tokens, ngram_range=(1, 3))
Generates and returns a list of ngrams for every n in the half-open range given by ngram_range.
- Parameters
tokens (List[str]) – List of string tokens for which the ngrams are to be generated
ngram_range (List[int], optional) – Start (inclusive) and end (exclusive) values of ‘n’ for which ngrams are to be generated. For example, if ngram_range = (1, 4), it returns 1-grams, 2-grams and 3-grams. Defaults to (1, 3).
- Returns
List of ngrams for each n in ngram_range
- Return type
List[str]
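A minimal sketch of the range variant, under the same assumption that ngrams are space-joined strings. It repeats a generate_ngrams helper so the block runs on its own, and concatenates the results for each n in range(start, end):

```python
from itertools import chain
from typing import List, Tuple


def generate_ngrams(tokens: List[str], n: int = 1) -> List[str]:
    # Sliding windows of n tokens, joined with a space (assumed format).
    return [" ".join(gram) for gram in zip(*(tokens[i:] for i in range(n)))]


def generate_ngrams_range(
    tokens: List[str], ngram_range: Tuple[int, int] = (1, 3)
) -> List[str]:
    start, end = ngram_range  # end is exclusive, as with range()
    # Concatenate all 1-grams, then all 2-grams, and so on.
    return list(
        chain.from_iterable(generate_ngrams(tokens, n) for n in range(start, end))
    )
```

With the default range, generate_ngrams_range(["the", "cat", "sat"]) returns ["the", "cat", "sat", "the cat", "cat sat"]; passing (1, 4) additionally appends the trigram "the cat sat".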