utils.text

The text utils module contains implementations of various decoding strategies, such as Greedy, Beam Search and Nucleus Sampling.

In your model’s config, you can specify the inference attribute to use one of these strategies as follows:

model_config:
    some_model:
        inference:
            type: greedy
            params: {}
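As a rough illustration of how the `type` key selects a strategy and `params` supplies its arguments, the sketch below treats the `inference` block as a plain Python mapping with `type` and `params` keys. The `DECODERS` table, `greedy_step` and `decoder_from_config` are hypothetical names; MMF resolves decoders through its own registry, which is not reproduced here.

```python
from typing import Callable, Dict, List


def greedy_step(probs: List[float]) -> int:
    """Greedy decoding: pick the single most probable token index."""
    return max(range(len(probs)), key=lambda i: probs[i])


# Hypothetical name -> step-function table standing in for MMF's registry.
DECODERS: Dict[str, Callable[..., int]] = {"greedy": greedy_step}


def decoder_from_config(model_config: dict, model_name: str) -> Callable[[List[float]], int]:
    """Look up the strategy named in model_config[model_name]['inference']."""
    inference = model_config[model_name]["inference"]
    step = DECODERS[inference["type"]]
    params = inference.get("params", {})
    # Bind params as keyword arguments (an empty dict for greedy).
    return lambda probs: step(probs, **params)


config = {"some_model": {"inference": {"type": "greedy", "params": {}}}}
decode = decoder_from_config(config, "some_model")
print(decode([0.1, 0.7, 0.2]))  # 1
```

Keeping `params` as a separate mapping lets each strategy take its own settings without changing the lookup code.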
class mmf.utils.text.BeamSearch(vocab, config)

class mmf.utils.text.NucleusSampling(vocab, config)

Nucleus Sampling is a text decoding strategy that avoids likelihood maximization. Instead, it samples from the smallest set of top tokens whose cumulative probability is at least a specified threshold.
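A minimal, framework-free sketch of a single Nucleus Sampling step over one probability distribution. It assumes the probabilities are already normalized and stops adding tokens as soon as the running sum reaches the threshold; `nucleus_candidates` and `nucleus_sample` are illustrative names, not part of mmf.utils.text.

```python
import random
from typing import List, Optional


def nucleus_candidates(probs: List[float], threshold: float) -> List[int]:
    """Smallest set of top tokens whose cumulative probability reaches threshold."""
    # Token indices ordered from most to least probable.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    nucleus: List[int] = []
    total = 0.0
    for i in order:
        nucleus.append(i)
        total += probs[i]
        if total >= threshold:
            break
    return nucleus


def nucleus_sample(probs: List[float], threshold: float,
                   rng: Optional[random.Random] = None) -> int:
    """Sample one token index from the nucleus, proportionally to its probability."""
    rng = rng or random.Random()
    nucleus = nucleus_candidates(probs, threshold)
    weights = [probs[i] for i in nucleus]
    # random.choices normalizes the weights, which renormalizes over the nucleus.
    return rng.choices(nucleus, weights=weights, k=1)[0]


probs = [0.05, 0.5, 0.3, 0.1, 0.05]
print(nucleus_candidates(probs, 0.75))  # [1, 2]
print(nucleus_sample(probs, 0.75, random.Random(0)))
```

Passing a seeded `random.Random` makes the draw reproducible; without one, each call samples afresh.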

Likelihood-maximizing decoding strategies such as beam search do not work well on open-ended generation tasks, even with strong language models like GPT-2. They tend to produce repetitive text, mainly because they maximize likelihood, whereas human-written text mixes high- and low-probability tokens.

Nucleus Sampling is a stochastic approach that resolves this issue. It also improves on other stochastic methods such as top-k sampling: instead of always sampling from a fixed number of tokens, it adapts the number of candidates to the probability distribution at each step. The result is better generated text from the same language model.
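The difference from top-k can be seen by counting candidates on two made-up distributions: top-k keeps a fixed k = 3 regardless of shape, while the nucleus (threshold 0.9 here) shrinks when the model is confident and grows when probability mass is spread out. The distributions and helper names are illustrative only.

```python
from typing import List


def top_k_size(probs: List[float], k: int) -> int:
    """Top-k always keeps exactly k candidates (or all of them, if fewer)."""
    return min(k, len(probs))


def nucleus_size(probs: List[float], threshold: float) -> int:
    """Number of top tokens needed for the cumulative probability to reach threshold."""
    total = 0.0
    for count, p in enumerate(sorted(probs, reverse=True), start=1):
        total += p
        if total >= threshold:
            return count
    return len(probs)


peaked = [0.92, 0.04, 0.02, 0.01, 0.01]  # model is confident
flat = [0.22, 0.2, 0.2, 0.19, 0.19]      # many plausible continuations

for name, dist in (("peaked", peaked), ("flat", flat)):
    print(name, "top-k:", top_k_size(dist, 3), "nucleus:", nucleus_size(dist, 0.9))
# peaked top-k: 3 nucleus: 1
# flat top-k: 3 nucleus: 5
```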

Link to the paper introducing Nucleus Sampling (Section 6) - https://arxiv.org/pdf/1904.09751.pdf

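Parameters
  • vocab (list) – Collection of all words in vocabulary.

  • config – Model config containing the inference settings, including sum_threshold.

  • sum_threshold (float) – Cumulative probability threshold: tokens are sampled from the smallest set of top tokens whose summed probability reaches this value.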

class mmf.utils.text.TextDecoder(vocab)

Base class inherited by all decoding strategies. It contains the implementations common to all strategies.

Parameters
  • vocab (list) – Collection of all words in vocabulary.
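The base-class pattern can be sketched as an abstract class that stores the vocabulary, provides shared helpers, and leaves token selection to subclasses. `TextDecoderSketch`, `ids_to_words`, `select` and `GreedySketch` are hypothetical and do not mirror the actual methods of mmf.utils.text.TextDecoder.

```python
from abc import ABC, abstractmethod
from typing import List


class TextDecoderSketch(ABC):
    """Hypothetical base class: holds the vocabulary and shared helpers."""

    def __init__(self, vocab: List[str]):
        self.vocab = vocab

    def ids_to_words(self, ids: List[int]) -> List[str]:
        # Shared by every strategy: map token indices back to vocabulary words.
        return [self.vocab[i] for i in ids]

    @abstractmethod
    def select(self, probs: List[float]) -> int:
        """Strategy-specific choice of the next token index."""


class GreedySketch(TextDecoderSketch):
    def select(self, probs: List[float]) -> int:
        # Greedy strategy: take the most probable index.
        return max(range(len(probs)), key=probs.__getitem__)


decoder = GreedySketch(["<s>", "a", "cat", "</s>"])
steps = [[0.1, 0.6, 0.2, 0.1], [0.05, 0.1, 0.8, 0.05]]
print(decoder.ids_to_words([decoder.select(p) for p in steps]))  # ['a', 'cat']
```

Because `select` is abstract, the base class cannot be instantiated directly; each strategy only has to supply its own selection rule.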

mmf.utils.text.generate_ngrams(tokens, n=1)

Generate ngrams for a particular ‘n’ from a list of tokens.

Parameters
  • tokens (List[str]) – List of tokens for which the ngrams are to be generated

  • n (int, optional) – n for which ngrams are to be generated. Defaults to 1.

Returns

List of ngrams generated.

Return type

List[str]
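A standalone sketch of the documented behaviour, assuming each n-gram is returned as its tokens joined by a single space (the separator is not stated above). It is an illustration, not MMF's source.

```python
from typing import List


def generate_ngrams(tokens: List[str], n: int = 1) -> List[str]:
    """Return space-joined n-grams of length n, in order of appearance."""
    # zip over n shifted copies of the list: tokens[0:], tokens[1:], ..., tokens[n-1:].
    shifted = [tokens[i:] for i in range(n)]
    return [" ".join(gram) for gram in zip(*shifted)]


tokens = ["a", "cat", "sat", "down"]
print(generate_ngrams(tokens, 2))  # ['a cat', 'cat sat', 'sat down']
```

zip stops at the shortest shifted copy, so a list of length L yields L - n + 1 n-grams and an empty list when n exceeds L.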

mmf.utils.text.generate_ngrams_range(tokens, ngram_range=(1, 3))

Generates and returns a list of ngrams for all n present in ngram_range.

Parameters
  • tokens (List[str]) – List of string tokens for which ngrams are to be generated

  • ngram_range (Tuple[int, int], optional) – Range of ‘n’ values, given as (start, end) with end excluded, for which ngrams are to be generated. For example, if ngram_range = (1, 4), it returns 1-grams, 2-grams and 3-grams. Defaults to (1, 3).

Returns

List of ngrams for each n in ngram_range

Return type

List[str]
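A sketch of the range variant, assuming ngram_range is a (start, end) pair with end excluded, which matches the (1, 4) example above, and that n-grams are space-joined strings. The single-n helper is redefined so the block runs on its own; neither function is MMF's source.

```python
from typing import List, Tuple


def generate_ngrams(tokens: List[str], n: int = 1) -> List[str]:
    """Space-joined n-grams of length n."""
    shifted = [tokens[i:] for i in range(n)]
    return [" ".join(gram) for gram in zip(*shifted)]


def generate_ngrams_range(tokens: List[str],
                          ngram_range: Tuple[int, int] = (1, 3)) -> List[str]:
    """Concatenate n-grams for every n in range(start, end); end is exclusive."""
    start, end = ngram_range
    result: List[str] = []
    for n in range(start, end):
        result.extend(generate_ngrams(tokens, n))
    return result


tokens = ["a", "cat", "sat"]
print(generate_ngrams_range(tokens))
# ['a', 'cat', 'sat', 'a cat', 'cat sat']
print(generate_ngrams_range(tokens, (1, 4)))
# ['a', 'cat', 'sat', 'a cat', 'cat sat', 'a cat sat']
```

With the default (1, 3), only unigrams and bigrams are produced; trigrams require an end value of 4.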