models.base_model

Models built in MMF need to inherit the BaseModel class and adhere to a fixed format. To create a model for MMF, follow this quick cheatsheet.

  1. Inherit the BaseModel class and call super().__init__() in your class’s __init__ function.

  2. Implement a build function for your model. If you build everything in __init__, build can simply return.

  3. Write a forward function which takes in a SampleList as an argument and returns a dict.

  4. Register the model using the @registry.register_model("key") decorator on top of the class.

If you are doing logit-based predictions, the dict you return from your model should contain a scores field. Losses are automatically calculated by the BaseModel class and added to this dict if not present.

Example:

import torch

from mmf.common.registry import registry
from mmf.models.base_model import BaseModel


@registry.register_model("pythia")
class Pythia(BaseModel):
    # config is model_config from global config
    def __init__(self, config):
        super().__init__(config)

    def build(self):
        ...

    def forward(self, sample_list):
        scores = torch.rand(sample_list.get_batch_size(), 3127)
        return {"scores": scores}
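
The loss handling described in the cheatsheet can be sketched in plain Python. This is a minimal illustration, not MMF's implementation: SketchBaseModel, compute_losses and the "losses" key name are hypothetical stand-ins chosen for the example, and the loss value is a placeholder.

```python
class SketchBaseModel:
    """Hypothetical stand-in for BaseModel; not MMF's actual code."""

    def forward(self, sample_list):
        raise NotImplementedError

    def compute_losses(self, sample_list, output):
        # Placeholder: a real model would evaluate its configured losses here.
        return {"sketch_loss": 0.0}

    def __call__(self, sample_list):
        output = self.forward(sample_list)
        if not isinstance(output, dict):
            raise TypeError("forward() must return a dict")
        # Fill in losses only when the model did not provide them itself.
        if "losses" not in output:
            output["losses"] = self.compute_losses(sample_list, output)
        return output


class AutoLossModel(SketchBaseModel):
    def forward(self, sample_list):
        return {"scores": [0.1, 0.9]}


class CustomLossModel(SketchBaseModel):
    def forward(self, sample_list):
        return {"scores": [0.1, 0.9], "losses": {"my_loss": 1.5}}


auto_output = AutoLossModel()({"batch": []})
custom_output = CustomLossModel()({"batch": []})
```

In this sketch a model that already returns its own losses keeps them, while a model that returns only scores gets the computed ones.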

class mmf.models.base_model.BaseModel(config: Union[omegaconf.dictconfig.DictConfig, mmf.models.base_model.BaseModel.Config])[source]

For integration with MMF’s trainer, datasets and other features, a model needs to inherit this class, call super, write a build function, write a forward function that takes a SampleList as input and returns a dict as output, and finally register itself using @registry.register_model.

Parameters

config (DictConfig) – model_config configuration from global config.

class Config(model: str = '???', losses: Union[List[mmf.modules.losses.LossConfig], NoneType] = '???')[source]

build()[source]

Function to be implemented by the child class if it needs to build its model separately from __init__. All model-related downloads should also happen here.

configure_optimizers()[source]

Member function of PL modules. Used only when PL enabled.

format_for_prediction(results, report)[source]

Implement this method in a model if it needs to modify prediction results using report fields. Note that the required fields should already have been gathered in report.
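
As a rough sketch of what such an override might do, the function below copies one report field into each prediction. The data shapes are assumptions made for illustration only: results is treated as a list of per-sample dicts, report as a plain mapping with one value per sample, and the image_id, question_id and answer field names are hypothetical. Returning the modified results is also an assumption.

```python
def format_for_prediction(results, report):
    # Assumed shapes (hypothetical): `results` is a list of per-sample dicts,
    # `report` maps a field name to a list with one value per sample.
    image_ids = report["image_id"]
    if len(image_ids) != len(results):
        raise ValueError("report field and results differ in length")
    for result, image_id in zip(results, image_ids):
        # Attach the gathered report field to the matching prediction.
        result["image_id"] = image_id
    return results


results = [
    {"question_id": 1, "answer": "yes"},
    {"question_id": 2, "answer": "no"},
]
report = {"image_id": [101, 102]}
formatted = format_for_prediction(results, report)
```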

classmethod format_state_key(key)[source]

Can be implemented if something special needs to be done to a key when a pretrained model is being loaded. The method adapts the key and returns it. Useful for backwards compatibility. See updated load_state_dict below. For an example, see VisualBERT model’s code.

Parameters

key (string) – key to be formatted

Returns

formatted key

Return type

string
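
A minimal sketch of such a key adapter, assuming a checkpoint whose parameter prefix was renamed between versions. SketchModel and the prefixes in RENAMED_PREFIXES are hypothetical and are not taken from VisualBERT or any other MMF model.

```python
class SketchModel:
    # Hypothetical mapping from old checkpoint prefixes to current ones.
    RENAMED_PREFIXES = {"bert.encoder.": "model.encoder."}

    @classmethod
    def format_state_key(cls, key):
        # Rewrite the first matching old prefix; leave other keys untouched.
        for old, new in cls.RENAMED_PREFIXES.items():
            if key.startswith(old):
                return new + key[len(old):]
        return key


old_state = {"bert.encoder.weight": 1, "classifier.bias": 2}
new_state = {SketchModel.format_state_key(k): v for k, v in old_state.items()}
```

Applying the adapter to every key before loading lets an older checkpoint line up with the current parameter names.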

forward(sample_list, *args, **kwargs)[source]

To be implemented by child class. Takes in a SampleList and returns a dict.

Parameters

sample_list (SampleList) – SampleList returned by the DataLoader for the current iteration.

Returns

Dict containing scores object.

Return type

Dict

init_losses()[source]

Initializes the losses for the model based on the losses key of its config. Automatically called by MMF internally after building the model.

load_state_dict(state_dict, *args, **kwargs)[source]

Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

Parameters
  • state_dict (dict) – a dict containing parameters and persistent buffers.

  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

Returns

  • missing_keys is a list of str containing the missing keys

  • unexpected_keys is a list of str containing the unexpected keys

Return type

NamedTuple with missing_keys and unexpected_keys fields
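
The meaning of strict and of the two returned lists can be illustrated without PyTorch. The sketch below only compares key names; compare_keys and KeyReport are hypothetical helpers written for this example, not part of torch or MMF, and no tensors are copied.

```python
from collections import namedtuple

# Hypothetical result type mirroring the two documented fields.
KeyReport = namedtuple("KeyReport", ["missing_keys", "unexpected_keys"])


def compare_keys(expected_keys, state_dict, strict=True):
    # Keys the module expects but the checkpoint lacks.
    missing = [k for k in expected_keys if k not in state_dict]
    # Keys the checkpoint has but the module does not know.
    unexpected = [k for k in state_dict if k not in expected_keys]
    if strict and (missing or unexpected):
        raise RuntimeError(
            f"missing keys: {missing}, unexpected keys: {unexpected}"
        )
    return KeyReport(missing, unexpected)


expected = ["encoder.weight", "classifier.bias"]
checkpoint = {"classifier.bias": 0.0, "decoder.weight": 1.0}

lenient = compare_keys(expected, checkpoint, strict=False)

try:
    compare_keys(expected, checkpoint, strict=True)
    strict_raised = False
except RuntimeError:
    strict_raised = True
```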

on_load_checkpoint(checkpoint: Dict[str, Any]) → None[source]

This is called by the pl.LightningModule before the model’s checkpoint is loaded.

on_save_checkpoint(checkpoint: Dict[str, Any]) → None[source]

Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.

Parameters

checkpoint – The full checkpoint dictionary before it gets dumped to a file. Implementations of this hook can insert additional data into this dictionary.

Example:

def on_save_checkpoint(self, checkpoint):
    # 99% of use cases you don't need to implement this method
    checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object

Note

Lightning saves all aspects of training (epoch, global step, etc…) including amp scaling. There is no need for you to store anything about training.
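
The two checkpoint hooks are typically used as a pair: whatever one stores, the other restores. The sketch below shows that round trip with a plain dict standing in for the checkpoint. SketchModule does not subclass pl.LightningModule, and the sketch_vocab key is a hypothetical name chosen for the example.

```python
class SketchModule:
    """Hypothetical stand-in; it does not subclass pl.LightningModule."""

    def __init__(self):
        self.vocab = ["yes", "no"]

    def on_save_checkpoint(self, checkpoint):
        # Store extra, picklable data next to the regular checkpoint contents.
        checkpoint["sketch_vocab"] = list(self.vocab)

    def on_load_checkpoint(self, checkpoint):
        # Restore the extra data if present; otherwise keep the current value.
        self.vocab = list(checkpoint.get("sketch_vocab", self.vocab))


saver = SketchModule()
saver.vocab = ["red", "green", "blue"]
checkpoint = {"state_dict": {}}
saver.on_save_checkpoint(checkpoint)

loader = SketchModule()
loader.on_load_checkpoint(checkpoint)
```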

test_step(batch: mmf.common.sample.SampleList, batch_idx: int, *args, **kwargs)[source]

Member function of PL modules. Used only when PL enabled. To be implemented by child class. Takes in a SampleList and a batch_idx and returns a dict.

Parameters

batch (SampleList) – SampleList returned by the DataLoader for the current iteration.

Returns

Dict

training_step(batch: mmf.common.sample.SampleList, batch_idx: int, *args, **kwargs)[source]

Member function of PL modules. Used only when PL enabled. To be implemented by child class. Takes in a SampleList and a batch_idx and returns a dict.

Parameters

batch (SampleList) – SampleList returned by the DataLoader for the current iteration.

Returns

Dict containing loss.

Return type

Dict
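
One way a training step could produce a dict containing a loss is to run forward and reduce the per-loss values to a single scalar. The sketch below assumes exactly that and uses plain floats instead of tensors. SketchTrainable, the "losses" and "loss" key names, and the loss names are hypothetical and not confirmed by this page.

```python
class SketchTrainable:
    """Hypothetical stand-in for a PL-enabled model; uses floats, not tensors."""

    def forward(self, sample_list):
        # Pretend forward already produced individual loss values.
        return {
            "scores": [0.2, 0.8],
            "losses": {"sketch/ce": 0.75, "sketch/aux": 0.25},
        }

    def training_step(self, batch, batch_idx, *args, **kwargs):
        output = self.forward(batch)
        # Reduce all individual losses to the single value to optimize.
        output["loss"] = sum(output["losses"].values())
        return output


step_output = SketchTrainable().training_step({"targets": [1]}, batch_idx=0)
```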

validation_step(batch: mmf.common.sample.SampleList, batch_idx: int, *args, **kwargs)[source]

Member function of PL modules. Used only when PL enabled. To be implemented by child class. Takes in a SampleList and a batch_idx and returns a dict.

Parameters

batch (SampleList) – SampleList returned by the DataLoader for the current iteration.

Returns

Dict