Adding a custom loss
Custom Losses
This is a tutorial on how to add a new loss function to MMF.
MMF is agnostic to the kind of losses that can be added to it.
Adding a loss takes two steps: adding a loss class, and adding your new loss to your config yaml.
For example, the ConcatBERT model uses the cross_entropy loss when training on the hateful memes dataset.
The loss class is CrossEntropyLoss, defined in mmf/modules/losses.py.
The loss key cross_entropy is added to the list of losses in the config yaml at mmf/projects/hateful_memes/configs/concat_bert/defaults.yaml.
Loss Class
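Add your loss class to mmf/modules/losses.py. It should be a subclass of nn.Module and be registered under the key you will use in the config, the way the @registry.register_loss("cross_entropy") decorator (registry comes from mmf.common.registry) gives CrossEntropyLoss its cross_entropy key.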
Losses should implement a forward method with the signature forward(self, sample_list, model_output),
where sample_list (a SampleList) is the current batch and model_output is the dict returned by your model for that sample_list. In the usual case, forward returns the computed loss as a tensor.
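A minimal sketch of such a class follows. It assumes, as MMF's built-in CrossEntropyLoss does, that the model puts its logits under model_output["scores"] and that the batch carries class labels under sample_list["targets"]. The loss key scaled_cross_entropy, the class ScaledCrossEntropyLoss and its scale parameter are hypothetical. LOSS_REGISTRY and register_loss are stand-ins for MMF's registry so the snippet runs without MMF installed, and a plain dict takes the place of a SampleList.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical stand-in for MMF's loss registry, so this sketch runs without
# MMF installed. Inside MMF you would decorate with registry.register_loss.
LOSS_REGISTRY = {}


def register_loss(key):
    """Record a loss class under the key that configs refer to."""
    def wrap(cls):
        LOSS_REGISTRY[key] = cls
        return cls
    return wrap


@register_loss("scaled_cross_entropy")  # hypothetical loss key
class ScaledCrossEntropyLoss(nn.Module):
    """Hypothetical loss: cross entropy multiplied by a constant factor."""

    def __init__(self, scale=1.0):
        super().__init__()
        self.scale = scale

    def forward(self, sample_list, model_output):
        # Assumed keys: logits in model_output["scores"],
        # integer class labels in sample_list["targets"].
        scores = model_output["scores"]
        targets = sample_list["targets"]
        return self.scale * F.cross_entropy(scores, targets)


# Usage with a plain dict standing in for a SampleList.
sample_list = {"targets": torch.tensor([0, 1])}
model_output = {"scores": torch.tensor([[2.0, 0.0], [0.0, 2.0]])}
loss = ScaledCrossEntropyLoss(scale=2.0)(sample_list, model_output)
```

Calling the module instance (rather than forward directly) is the usual PyTorch idiom; nn.Module.__call__ dispatches to forward with the same two arguments.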
Losses Config
Add the key of your new loss (for example, cross_entropy for CrossEntropyLoss) to the list of losses in your model config. Multiple losses can be specified with a yaml array.
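For losses that take parameters, write the list entry as a mapping instead of a bare key: a type field holding the loss key and a params field holding the keyword arguments passed to the loss class's constructor.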
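To make the two kinds of list entry concrete, here is a simplified sketch of how a losses list could be turned into loss instances: a bare string names a loss built with default arguments, and a mapping supplies a type and optional params. The list is written as Python data, i.e. the parsed form of a yaml array like the one in the comment. The loss key scaled_cross_entropy, the stub classes, LOSS_CLASSES and build_losses are all hypothetical, and MMF's own handling in mmf/modules/losses.py differs in detail.

```python
# Parsed form of a hypothetical config section such as:
#
#   losses:
#   - cross_entropy
#   - type: scaled_cross_entropy
#     params:
#       scale: 0.5
losses_config = [
    "cross_entropy",
    {"type": "scaled_cross_entropy", "params": {"scale": 0.5}},
]


class CrossEntropyStub:
    """Placeholder for a loss class that takes no arguments."""


class ScaledCrossEntropyStub:
    """Placeholder for a loss class with a constructor parameter."""

    def __init__(self, scale=1.0):
        self.scale = scale


# Hypothetical key -> class mapping, standing in for MMF's registry.
LOSS_CLASSES = {
    "cross_entropy": CrossEntropyStub,
    "scaled_cross_entropy": ScaledCrossEntropyStub,
}


def build_losses(losses_config, loss_classes):
    """Instantiate every configured loss, keyed by its loss key."""
    built = {}
    for entry in losses_config:
        if isinstance(entry, str):
            # Bare key: construct with default arguments.
            key, params = entry, {}
        else:
            # Mapping: "type" is the key, "params" are constructor kwargs.
            key, params = entry["type"], entry.get("params", {})
        built[key] = loss_classes[key](**params)
    return built


losses = build_losses(losses_config, LOSS_CLASSES)
```

Keeping params as a separate mapping means the same loss class can be reused with different settings purely from the config, without code changes.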
Multi-Loss Classes
If a loss class computes several losses at once, for example because they share intermediate calculations, its forward can return a dictionary of tensors instead of a single tensor.
The loss that is optimized is the sum of all losses configured for the model, with each entry of such a dictionary counted as one of those losses.
For an example, take a look at the BCEAndKLLoss class in mmf/modules/losses.py.
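A minimal sketch of a multi-loss class, assuming the model returns logits under model_output["scores"] and the batch holds class labels under sample_list["targets"] (a plain dict stands in for SampleList). The class CrossEntropyAndEntropyLoss, its entropy_weight parameter and the dictionary keys are hypothetical; it is not BCEAndKLLoss. The log-softmax is computed once and shared by both terms. The final line sums the entries only to illustrate that the optimized loss is the sum of all losses; it is not MMF's trainer code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CrossEntropyAndEntropyLoss(nn.Module):
    """Hypothetical multi-loss: both terms reuse one log-softmax."""

    def __init__(self, entropy_weight=0.1):
        super().__init__()
        self.entropy_weight = entropy_weight

    def forward(self, sample_list, model_output):
        # Shared calculation: computed once, used by both losses.
        log_probs = F.log_softmax(model_output["scores"], dim=1)
        # nll_loss on log-softmax outputs is the cross-entropy loss.
        ce = F.nll_loss(log_probs, sample_list["targets"])
        # Mean prediction entropy; minimizing it favors confident predictions.
        entropy = -(log_probs.exp() * log_probs).sum(dim=1).mean()
        return {"ce": ce, "entropy": self.entropy_weight * entropy}


sample_list = {"targets": torch.tensor([0, 1])}
model_output = {"scores": torch.tensor([[2.0, 0.0], [0.0, 2.0]])}
loss_dict = CrossEntropyAndEntropyLoss()(sample_list, model_output)

# Illustration only: every loss contributes to a single summed objective.
total = sum(value.mean() for value in loss_dict.values())
```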