models.base_model
Models built in MMF need to inherit the BaseModel class and adhere to a
fixed format. To create a model for MMF, follow this quick cheatsheet:

1. Inherit the BaseModel class and make sure to call super().__init__() in
your class's __init__ function.
2. Implement a build function for your model. If you build everything in
__init__, you can simply return from this function.
3. Write a forward function which takes in a SampleList as an argument and
returns a dict.
4. Register the class using the @registry.register_model("key") decorator.
If you are doing logits-based predictions, the dict you return from your model
should contain a scores field. Losses are automatically calculated by the
BaseModel class and added to this dict if not present.
Example:
import torch
from mmf.common.registry import registry
from mmf.models.base_model import BaseModel


@registry.register_model("pythia")
class Pythia(BaseModel):
    # config is model_config from global config
    def __init__(self, config):
        super().__init__(config)

    def build(self):
        ...

    def forward(self, sample_list):
        scores = torch.rand(sample_list.get_batch_size(), 3127)
        return {"scores": scores}
- class mmf.models.base_model.BaseModel(config: Union[omegaconf.dictconfig.DictConfig, mmf.models.base_model.BaseModel.Config])
For integration with MMF's trainer, datasets and other features, models need
to inherit this class, call super, write a build function, write a forward
function taking a SampleList as input and returning a dict as output, and
finally register it using @registry.register_model.
- Parameters
config (DictConfig) – model_config configuration from global config.
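For a rough illustration of this parameter, the sketch below builds a minimal
model_config node with OmegaConf and feeds it to the Pythia example class from
above; the keys shown are illustrative, not MMF's full model_config schema.

from omegaconf import OmegaConf

# Hypothetical model_config node; in a real run MMF extracts this from the
# global config rather than building it by hand.
model_config = OmegaConf.create({"model": "pythia"})

model = Pythia(model_config)  # Pythia as registered in the example above
model.build()                 # build is invoked separately from __init__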
- class Config(model: str = '???', losses: Union[List[mmf.modules.losses.LossConfig], NoneType] = '???')
- build()
Function to be implemented by the child class, in case it needs to build its
model separately from __init__. All model-related downloads should also
happen here.
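A minimal sketch of such a build override, assuming a model that constructs
its layers here instead of in __init__ (the layer shapes and attribute names
are made up for illustration):

import torch

from mmf.models.base_model import BaseModel


class MyModel(BaseModel):
    def build(self):
        # Hypothetical layers; any model-related pretrained-weight
        # downloads would also be triggered here rather than in __init__.
        self.encoder = torch.nn.Linear(2048, 512)
        self.classifier = torch.nn.Linear(512, 3129)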
- format_for_prediction(results, report)
Implement this method in models that need to modify prediction results using
report fields. Note that the required fields should already be gathered in
the report.
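As a hedged sketch of an override, assuming results is a list of per-sample
prediction dicts, that the report carries a gathered question_id tensor, and
that the modified list is returned (none of which is guaranteed by this
docstring):

def format_for_prediction(self, results, report):
    # Hypothetical: attach each sample's question_id from the report to
    # its prediction dict. The field names are assumptions.
    for idx, result in enumerate(results):
        result["question_id"] = report.question_id[idx].item()
    return results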
- classmethod format_state_key(key)
Can be implemented if something special needs to be done to the key when a
pretrained model is being loaded. This will adapt and return keys
accordingly. Useful for backwards compatibility. See the updated
load_state_dict below. For an example, see the VisualBERT model's code.
- Parameters
key (string) – key to be formatted
- Returns
formatted key
- Return type
string
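A short sketch of such an override, assuming a module prefix was renamed
between releases (both prefixes are made up for illustration):

@classmethod
def format_state_key(cls, key):
    # Hypothetical rename: weights saved under "old_encoder." should now
    # load into a module called "encoder.".
    return key.replace("old_encoder.", "encoder.")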
- forward(sample_list, *args, **kwargs)
To be implemented by the child class. Takes in a SampleList and returns a
dict.
- Parameters
sample_list (SampleList) – SampleList returned by the DataLoader for the
current iteration
- Returns
Dict containing scores object.
- Return type
Dict
- init_losses()
Initializes losses for the model based on the losses key. Automatically
called by MMF internally after building the model.
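For illustration, the losses key sits under the model's config node, matching
the Config dataclass above; a minimal sketch with OmegaConf (the loss type
string is an assumption, consult MMF's loss registry for real names):

from omegaconf import OmegaConf

# Hypothetical model_config carrying a losses list; init_losses reads this
# key after the model is built.
model_config = OmegaConf.create(
    {"model": "pythia", "losses": [{"type": "logit_bce"}]}
)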
- load_state_dict(state_dict, *args, **kwargs)
Copies parameters and buffers from state_dict into this module and its
descendants. If strict is True, then the keys of state_dict must exactly
match the keys returned by this module's state_dict() function.
- Parameters
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in
state_dict match the keys returned by this module's state_dict() function.
Default: True
- Returns
missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
- Return type
NamedTuple with missing_keys and unexpected_keys fields
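A brief usage sketch, assuming the file on disk holds a bare state dict (real
MMF checkpoints may nest it under other keys) and passing strict=False, which
is where format_state_key above becomes useful:

import torch

# Hypothetical checkpoint path; model is an already-built BaseModel subclass.
state_dict = torch.load("checkpoint.pth", map_location="cpu")
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
print(missing_keys, unexpected_keys)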
- on_load_checkpoint(checkpoint: Dict[str, Any]) → None
This is called by the pl.LightningModule before the model’s checkpoint is loaded.
- on_save_checkpoint(checkpoint: Dict[str, Any]) → None
Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.
- Parameters
checkpoint – The full checkpoint dictionary before it gets dumped to a file. Implementations of this hook can insert additional data into this dictionary.
Example:
def on_save_checkpoint(self, checkpoint):
    # 99% of use cases you don't need to implement this method
    checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object
Note
Lightning saves all aspects of training (epoch, global step, etc…) including amp scaling. There is no need for you to store anything about training.
- test_step(batch: mmf.common.sample.SampleList, batch_idx: int, *args, **kwargs)
Member function of PL modules. Used only when PL is enabled. To be
implemented by the child class. Takes in a SampleList and batch_idx and
returns a dict.
- Parameters
sample_list (SampleList) – SampleList returned by the DataLoader for the
current iteration
- Returns
Dict
- training_step(batch: mmf.common.sample.SampleList, batch_idx: int, *args, **kwargs)
Member function of PL modules. Used only when PL is enabled. To be
implemented by the child class. Takes in a SampleList and batch_idx and
returns a dict.
- Parameters
sample_list (SampleList) – SampleList returned by the DataLoader for the
current iteration
- Returns
Dict containing loss.
- Return type
Dict
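A hedged sketch of a training_step override, assuming the dict returned by
forward carries a losses mapping of named loss tensors attached by BaseModel
(an assumption about the output shape, not a documented contract):

def training_step(self, batch, batch_idx, *args, **kwargs):
    # Hypothetical: reuse forward and reduce the attached losses to a
    # single scalar for Lightning's optimizer loop.
    output = self(batch)
    loss = sum(loss.mean() for loss in output["losses"].values())
    return {"loss": loss}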
- validation_step(batch: mmf.common.sample.SampleList, batch_idx: int, *args, **kwargs)
Member function of PL modules. Used only when PL is enabled. To be
implemented by the child class. Takes in a SampleList and batch_idx and
returns a dict.
- Parameters
sample_list (SampleList) – SampleList returned by the DataLoader for the
current iteration
- Returns
Dict