Adding a custom loss
Custom Losses

This is a tutorial on how to add a new loss function to MMF. MMF is agnostic to the kind of losses that can be added to it. Adding a loss takes two steps: adding a loss class, and adding your new loss to your model's config YAML.
For example, the ConcatBERT model uses the `cross_entropy` loss when training on the Hateful Memes dataset. The loss class is `CrossEntropyLoss`, defined in `mmf/modules/losses.py`, and the loss key `cross_entropy` is added to the list of losses in the config YAML at `mmf/projects/hateful_memes/configs/concat_bert/defaults.yaml`.
Loss Class

Add your loss class to `losses.py` (`mmf/modules/losses.py`). It should be a subclass of `nn.Module`. Losses should implement a `forward` method with the signature `forward(self, sample_list, model_output)`, where `sample_list` (a `SampleList`) is the current batch and `model_output` is the dict returned by your model for the current `sample_list`.
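A minimal sketch of this contract, assuming the model returns its logits under `model_output["scores"]` and the batch holds integer labels under `sample_list["targets"]` (both key names are assumptions here; check what your model and dataset actually produce). `MyCrossEntropyLoss` is a hypothetical name, and a plain dict stands in for `SampleList` so the snippet runs without MMF installed. The sketch also leaves out how MMF ties a config key to the class; look at how `CrossEntropyLoss` is declared in `mmf/modules/losses.py` (in the versions we have seen, it is registered under its key through MMF's registry) and follow the same pattern.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyCrossEntropyLoss(nn.Module):
    """Hypothetical custom loss following the forward(sample_list, model_output) contract."""

    def __init__(self, weight=None):
        super().__init__()
        # Optional per-class weight tensor; None means every class counts equally.
        self.weight = weight

    def forward(self, sample_list, model_output):
        # Assumed keys: logits produced by the model, integer labels from the batch.
        scores = model_output["scores"]
        targets = sample_list["targets"]
        # Returns a single scalar tensor (mean cross-entropy over the batch).
        return F.cross_entropy(scores, targets, weight=self.weight)


# Usage: a plain dict stands in for MMF's SampleList in this sketch.
sample_list = {"targets": torch.tensor([0, 1])}
model_output = {"scores": torch.tensor([[2.0, 0.5], [0.1, 1.5]])}
loss = MyCrossEntropyLoss()(sample_list, model_output)
print(loss.item())
```

Taking options such as `weight` as constructor arguments keeps `forward` on the fixed `(sample_list, model_output)` signature while still letting the loss be configured.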
Losses Config

Add the key of your new loss (for example, `cross_entropy` for `CrossEntropyLoss`) to your model config. Multiple losses can be specified as a YAML list. Losses that take parameters can have those parameters set in the config as well.
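The sketch below illustrates the two shapes a losses entry can plausibly take, written as the Python structure the YAML list would parse into: a bare key, or a mapping with `type` and `params`. The `type`/`params` layout is an assumption based on common MMF configs, so compare it with an existing config such as the ConcatBERT `defaults.yaml` before relying on it. `build_losses`, `LOSS_CLASSES`, and the two stub classes are hypothetical stand-ins for MMF's own lookup and loss classes.

```python
# Assumed YAML equivalent of `losses_config` below:
#   losses:
#   - cross_entropy
#   - type: my_weighted_loss
#     params:
#       scale: 0.5


class CrossEntropyStub:
    """Hypothetical stand-in for a loss class that takes no parameters."""

    def __init__(self, **params):
        self.params = params


class WeightedStub:
    """Hypothetical stand-in for a loss class with a `scale` parameter."""

    def __init__(self, scale=1.0):
        self.scale = scale


# Hypothetical mapping from config key to loss class (MMF uses its registry instead).
LOSS_CLASSES = {"cross_entropy": CrossEntropyStub, "my_weighted_loss": WeightedStub}


def build_losses(loss_configs):
    """Instantiate one loss per config entry, passing any `params` to the constructor."""
    losses = {}
    for entry in loss_configs:
        # A bare string is shorthand for an entry with no params.
        if isinstance(entry, str):
            entry = {"type": entry}
        key = entry["type"]
        params = entry.get("params") or {}
        losses[key] = LOSS_CLASSES[key](**params)
    return losses


losses_config = ["cross_entropy", {"type": "my_weighted_loss", "params": {"scale": 0.5}}]
losses = build_losses(losses_config)
print(sorted(losses))  # ['cross_entropy', 'my_weighted_loss']
```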
Multi-Loss Classes

If a loss class is responsible for calculating multiple losses (for example, because they share intermediate calculations), its `forward` can return a dictionary of tensors. The loss that is optimized is the sum of all losses configured for the model. For an example, take a look at the `BCEAndKLLoss` class in `mmf/modules/losses.py`.
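How MMF combines these outputs happens inside the framework; the sketch below only illustrates the idea under stated assumptions. A hypothetical `SharedScoreLosses` computes the log-softmax once and reuses it for two illustrative terms, returning a dict of tensors, and a hypothetical `total_loss` sums every tensor returned by every configured loss, mirroring the rule that the optimized loss is the sum of all configured losses. The keys `scores` and `targets` are assumed, and a plain dict stands in for `SampleList`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SharedScoreLosses(nn.Module):
    """Hypothetical loss class returning two losses that share one computation."""

    def forward(self, sample_list, model_output):
        scores = model_output["scores"]
        targets = sample_list["targets"]
        # Shared computation: log-probabilities feed both terms below.
        log_probs = F.log_softmax(scores, dim=-1)
        # Negative log-likelihood (equivalent to cross-entropy on the raw scores).
        nll = F.nll_loss(log_probs, targets)
        # Mean entropy of the predicted distribution, as an illustrative second term.
        entropy = -(log_probs.exp() * log_probs).sum(dim=-1).mean()
        return {"nll": nll, "entropy": entropy}


def total_loss(loss_outputs):
    """Sum all tensors from a list of loss outputs (each a tensor or a dict of tensors)."""
    total = torch.zeros(())
    for out in loss_outputs:
        values = out.values() if isinstance(out, dict) else [out]
        for value in values:
            total = total + value
    return total


sample_list = {"targets": torch.tensor([0, 1])}
model_output = {"scores": torch.tensor([[2.0, 0.5], [0.1, 1.5]])}
parts = SharedScoreLosses()(sample_list, model_output)
loss = total_loss([parts])
print({name: value.item() for name, value in parts.items()}, loss.item())
```

Returning a dict keeps each term visible for logging while the shared log-softmax is computed only once.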