models.base_model
Models built in MMF need to inherit the BaseModel class and adhere to
a fixed format. To create a model for MMF, follow this quick cheatsheet:
1. Inherit the BaseModel class, and make sure to call super().__init__() in your class's __init__ function.
2. Implement the build function for your model. If you build everything in __init__, you can simply return from this function.
3. Write a forward function that takes a SampleList as an argument and returns a dict.
4. Register the model using the @registry.register_model("key") decorator on top of the class.
If you are making logits-based predictions, the dict returned from your model
should contain a scores field. Losses are automatically calculated by the
BaseModel class and added to this dict if they are not already present.
Example:
```python
import torch

from mmf.common.registry import registry
from mmf.models.base_model import BaseModel


@registry.register_model("pythia")
class Pythia(BaseModel):
    # config is model_config from the global config
    def __init__(self, config):
        super().__init__(config)

    def build(self):
        # Build the model's layers here (or build them in __init__ and just return).
        pass

    def forward(self, sample_list):
        # Random scores with one row per sample and 3127 output columns.
        scores = torch.rand(sample_list.get_batch_size(), 3127)
        return {"scores": scores}
```
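The automatic loss handling described above (losses computed and added to the returned dict only when forward does not supply them) can be pictured with a small, dependency-free stand-in. ToyBaseModel, ToyModel and squared_error are hypothetical names, and this is a simplified sketch of the behaviour, not MMF's actual implementation, which works on tensors and registered loss classes.

```python
from typing import Callable, Dict


class ToyBaseModel:
    """Hypothetical stand-in for BaseModel; not MMF's actual implementation."""

    def __init__(self, losses: Dict[str, Callable[[dict, dict], float]]):
        # Maps a loss name to a callable(sample_list, model_output) -> value.
        self.losses = losses

    def forward(self, sample_list: dict) -> dict:
        raise NotImplementedError

    def __call__(self, sample_list: dict) -> dict:
        output = self.forward(sample_list)
        # Compute losses only if the child's forward did not return them.
        if "losses" not in output:
            output["losses"] = {
                name: fn(sample_list, output) for name, fn in self.losses.items()
            }
        return output


def squared_error(sample_list: dict, output: dict) -> float:
    # Mean squared error between predicted scores and targets.
    pairs = list(zip(output["scores"], sample_list["targets"]))
    return sum((s - t) ** 2 for s, t in pairs) / len(pairs)


class ToyModel(ToyBaseModel):
    def forward(self, sample_list: dict) -> dict:
        # Logits-based prediction: return only a "scores" field.
        return {"scores": [2.0 * x for x in sample_list["inputs"]]}


model = ToyModel({"squared_error": squared_error})
out = model({"inputs": [1.0, 2.0], "targets": [2.0, 3.0]})
print(out["losses"])  # {'squared_error': 0.5}
```

The child only returns scores; the base class fills in the losses field, which mirrors the division of labour the paragraph describes.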
- class mmf.models.base_model.BaseModel(config: Union[omegaconf.dictconfig.DictConfig, mmf.models.base_model.BaseModel.Config])
For integration with MMF's trainer, datasets and other features, models need to inherit this class, call super(), write a build function, write a forward function that takes a SampleList as input and returns a dict as output, and finally register the model using @registry.register_model.
- Parameters
config (DictConfig) – model_config configuration from the global config.
- class Config(model: str = '???', losses: Union[List[mmf.modules.losses.LossConfig], NoneType] = '???')
- build()
Function to be implemented by the child class, in case it needs to build its model separately from
__init__. All model-related downloads should also happen here.
- format_for_prediction(results, report)
Implement this method in a model if it needs to modify prediction results using report fields. Note that the required fields should already have been gathered in the report.
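A minimal sketch of such an override follows. It assumes that results is a list of per-sample dicts and that report supports dict-style lookup; both are assumptions, so check the shapes your MMF version actually passes in. The class name and the "confidence" field are hypothetical, and the class does not inherit BaseModel so the snippet runs without MMF installed.

```python
from typing import Any, Dict, List


class ModelWithPredictionFormatting:
    # Hypothetical stand-in for a BaseModel subclass.

    def format_for_prediction(
        self, results: List[Dict[str, Any]], report: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        # Assumes a hypothetical "confidence" field was already gathered in report,
        # holding one value per sample in the same order as results.
        for result, confidence in zip(results, report["confidence"]):
            result["confidence"] = confidence
        return results


model = ModelWithPredictionFormatting()
results = [{"id": 1, "answer": "yes"}, {"id": 2, "answer": "no"}]
report = {"confidence": [0.9, 0.6]}
formatted = model.format_for_prediction(results, report)
```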
- classmethod format_state_key(key)
Can be implemented if something special needs to be done to a key when a pretrained model is being loaded. The method adapts the key and returns the formatted version, which is useful for backwards compatibility. See the updated load_state_dict below. For an example, see the VisualBERT model's code.
- Parameters
key (string) – key to be formatted
- Returns
formatted key
- Return type
string
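As an illustration, the sketch below renames a legacy key prefix so that old checkpoints match a newer module layout. The prefixes "encoder.bert." and "model.bert." and the helper remap_state_keys are hypothetical, and the class does not inherit BaseModel so the snippet runs standalone; in a real model you would define the classmethod on your BaseModel subclass.

```python
from typing import Dict


class LegacyCompatibleModel:
    # Hypothetical stand-in for a BaseModel subclass.

    @classmethod
    def format_state_key(cls, key: str) -> str:
        # Hypothetical rename: older checkpoints stored the encoder under
        # "encoder.bert.", while the current code expects "model.bert.".
        return key.replace("encoder.bert.", "model.bert.")


def remap_state_keys(model_cls, state_dict: Dict[str, object]) -> Dict[str, object]:
    # Apply format_state_key to every key; values are left untouched.
    return {model_cls.format_state_key(k): v for k, v in state_dict.items()}


old = {"encoder.bert.layer0.weight": 1, "classifier.bias": 2}
new = remap_state_keys(LegacyCompatibleModel, old)
```

Keys that do not match the legacy prefix pass through unchanged, so the hook can be applied to every key without special-casing.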
- forward(sample_list, *args, **kwargs)
To be implemented by the child class. Takes in a SampleList and returns a dict.
- Parameters
sample_list (SampleList) – SampleList returned by the DataLoader for the current iteration.
- Returns
Dict containing the scores object.
- Return type
Dict
- init_losses()
Initializes the losses for the model based on the losses key in the model config. Automatically called by MMF internally after building the model.
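The order of the hooks described so far (constructor, then build, then init_losses) can be sketched with a hypothetical builder. LifecycleModel and build_model are illustrative names, not MMF APIs, and MMF's real construction path may include additional steps; the sketch only records the call order and reads the losses list from a plain dict config.

```python
class LifecycleModel:
    # Hypothetical stand-in that records the order in which hooks run.
    def __init__(self, config: dict):
        self.config = config
        self.calls = ["__init__"]

    def build(self):
        self.calls.append("build")
        self.layers = ["encoder", "classifier"]  # placeholder components

    def init_losses(self):
        self.calls.append("init_losses")
        # Read loss names from the "losses" key of the model config.
        self.loss_names = list(self.config.get("losses", []))


def build_model(model_cls, config: dict):
    # Hypothetical helper: construct, build, then initialize losses.
    model = model_cls(config)
    model.build()
    model.init_losses()
    return model


model = build_model(LifecycleModel, {"model": "toy", "losses": ["logit_bce"]})
```

Keeping layer construction in build rather than __init__ lets the framework decide when heavy work, such as downloads, happens.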
- load_state_dict(state_dict, *args, **kwargs)
Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module's state_dict() function.
- Parameters
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module's state_dict() function. Default: True
- Returns
missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
- Return type
NamedTuple with missing_keys and unexpected_keys fields
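The missing_keys / unexpected_keys bookkeeping can be mimicked with plain Python. This sketch only compares key names: it does not copy tensors or check shapes as PyTorch does, and IncompatibleKeys and check_state_dict_keys are hypothetical names. Per the format_state_key description, MMF's version also passes incoming keys through format_state_key before this comparison.

```python
from typing import Dict, List, NamedTuple


class IncompatibleKeys(NamedTuple):
    missing_keys: List[str]
    unexpected_keys: List[str]


def check_state_dict_keys(
    model_keys: List[str], state_dict: Dict[str, object], strict: bool = True
) -> IncompatibleKeys:
    # Keys the model expects but the checkpoint lacks.
    missing = [k for k in model_keys if k not in state_dict]
    # Keys present in the checkpoint that the model does not have.
    model_key_set = set(model_keys)
    unexpected = [k for k in state_dict if k not in model_key_set]
    if strict and (missing or unexpected):
        raise RuntimeError(f"missing={missing}, unexpected={unexpected}")
    return IncompatibleKeys(missing, unexpected)


result = check_state_dict_keys(
    ["fc.weight", "fc.bias"], {"fc.weight": 0, "head.bias": 1}, strict=False
)
```

With strict=False the mismatch is reported rather than raised, which is the usual way to load a checkpoint into a model with a different head.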
- on_load_checkpoint(checkpoint: Dict[str, Any]) -> None
Called by the pl.LightningModule before the model's checkpoint is loaded.
- on_save_checkpoint(checkpoint: Dict[str, Any]) -> None
Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.
- Parameters
checkpoint – The full checkpoint dictionary before it gets dumped to a file. Implementations of this hook can insert additional data into this dictionary.
Example:
```python
def on_save_checkpoint(self, checkpoint):
    # 99% of use cases you don't need to implement this method
    checkpoint['something_cool_i_want_to_save'] = my_cool_picklable_object
```
Note
Lightning saves all aspects of training (epoch, global step, etc…) including amp scaling. There is no need for you to store anything about training.
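The two checkpoint hooks are typically used as a pair: one writes extra data into the checkpoint dict, the other reads it back. In the sketch below, VocabAwareModel and the "vocab_size" key are hypothetical, and plain dicts stand in for Lightning's checkpoint so the snippet runs without Lightning installed.

```python
from typing import Any, Dict


class VocabAwareModel:
    # Hypothetical stand-in for a LightningModule / BaseModel subclass.
    def __init__(self, vocab_size: int):
        self.vocab_size = vocab_size

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Store extra, picklable metadata alongside the weights.
        checkpoint["vocab_size"] = self.vocab_size

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Read the metadata back; keep the current value if it is absent.
        self.vocab_size = checkpoint.get("vocab_size", self.vocab_size)


checkpoint: Dict[str, Any] = {"state_dict": {}}
VocabAwareModel(3127).on_save_checkpoint(checkpoint)
restored = VocabAwareModel(0)
restored.on_load_checkpoint(checkpoint)
```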
- test_step(batch: mmf.common.sample.SampleList, batch_idx: int, *args, **kwargs)
Member function of PL modules, used only when PL is enabled. To be implemented by the child class. Takes in a SampleList and batch_idx and returns a dict.
- Parameters
batch (SampleList) – SampleList returned by the DataLoader for the current iteration.
batch_idx (int) – index of the current batch.
- Returns
Dict
- training_step(batch: mmf.common.sample.SampleList, batch_idx: int, *args, **kwargs)
Member function of PL modules, used only when PL is enabled. To be implemented by the child class. Takes in a SampleList and batch_idx and returns a dict.
- Parameters
batch (SampleList) – SampleList returned by the DataLoader for the current iteration.
batch_idx (int) – index of the current batch.
- Returns
Dict containing the loss.
- Return type
Dict
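PyTorch Lightning reads the "loss" key of the dict returned by training_step for backpropagation. The sketch below shows one way a model whose forward returns a losses dict could expose that key by summing the individual losses. ToyLightningModel is hypothetical, plain floats stand in for tensors, and MMF's own training_step may aggregate losses differently.

```python
from typing import Dict, List


class ToyLightningModel:
    # Hypothetical stand-in; forward returns a "losses" dict like an MMF model output.
    def forward(self, batch: Dict[str, List[float]]) -> Dict[str, object]:
        scores = [2.0 * x for x in batch["inputs"]]
        errors = [(s - t) ** 2 for s, t in zip(scores, batch["targets"])]
        return {"scores": scores, "losses": {"mse": sum(errors) / len(errors)}}

    def training_step(self, batch: Dict[str, List[float]], batch_idx: int) -> Dict[str, object]:
        output = self.forward(batch)
        # Lightning uses the "loss" entry of the returned dict for backprop.
        output["loss"] = sum(output["losses"].values())
        return output


step = ToyLightningModel().training_step(
    {"inputs": [1.0, 3.0], "targets": [2.0, 5.0]}, batch_idx=0
)
```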
- validation_step(batch: mmf.common.sample.SampleList, batch_idx: int, *args, **kwargs)
Member function of PL modules, used only when PL is enabled. To be implemented by the child class. Takes in a SampleList and batch_idx and returns a dict.
- Parameters
batch (SampleList) – SampleList returned by the DataLoader for the current iteration.
batch_idx (int) – index of the current batch.
- Returns
Dict