# Copyright (c) Facebook, Inc. and its affiliates.
"""
Models built in MMF need to inherit ``BaseModel`` class and adhere to
a fixed format. To create a model for MMF, follow this quick cheatsheet.
1. Inherit ``BaseModel`` class, make sure to call ``super().__init__()`` in your
class's ``__init__`` function.
2. Implement `build` function for your model. If you build everything in ``__init__``,
you can just return in this function.
3. Write a `forward` function which takes in a ``SampleList`` as an argument and
returns a dict.
4. Register using ``@registry.register_model("key")`` decorator on top of the
class.
If you are doing logits based predictions, the dict you return from your model
should contain a `scores` field. Losses are automatically calculated by the
``BaseModel`` class and added to this dict if not present.
Example::
import torch
from mmf.common.registry import registry
from mmf.models.base_model import BaseModel
@registry.register("pythia")
class Pythia(BaseModel):
# config is model_config from global config
def __init__(self, config):
super().__init__(config)
def build(self):
....
def forward(self, sample_list):
scores = torch.rand(sample_list.get_batch_size(), 3127)
return {"scores": scores}
"""
import collections
import logging
import warnings
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
import pytorch_lightning as pl
from mmf.common.registry import registry
from mmf.common.report import Report
from mmf.common.sample import SampleList, to_device
from mmf.modules.losses import LossConfig, Losses
from mmf.utils.checkpoint import load_pretrained_model
from mmf.utils.checkpoint_updater import MMFToPLCheckpointUpdater
from mmf.utils.download import download_pretrained_model
from mmf.utils.file_io import PathManager
from mmf.utils.general import get_current_device
from mmf.utils.logger import log_class_usage
from omegaconf import MISSING, DictConfig, OmegaConf
logger = logging.getLogger(__name__)
class BaseModel(pl.LightningModule):
"""For integration with MMF's trainer, datasets and other features,
models needs to inherit this class, call `super`, write a build function,
write a forward function taking a ``SampleList`` as input and returning a
dict as output and finally, register it using ``@registry.register_model``
Args:
config (DictConfig): ``model_config`` configuration from global config.
"""
    @dataclass
class Config:
# Name of the model that is used in registry
model: str = MISSING
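        # Losses to compute for the model; each entry's ``type`` should match a
        # registered loss (see ``mmf.modules.losses``)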
losses: Optional[List[LossConfig]] = MISSING
def __init__(self, config: Union[DictConfig, Config]):
super().__init__()
if not isinstance(config, DictConfig) and isinstance(config, self.Config):
config = OmegaConf.structured(config)
self.config = config
self._logged_warning = {"losses_present": False}
self._is_pretrained = False
self._is_pl_enabled = False
self.checkpoint_updater = None
log_class_usage("Model", self.__class__)
@classmethod
def from_params(cls, **kwargs):
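        """Convenience constructor that builds the model from keyword arguments
        by instantiating its structured ``Config``.

        A minimal usage sketch (``MyModel`` stands in for a registered child
        model; accepted fields depend on that model's ``Config`` dataclass)::

            model = MyModel.from_params(model="my_model")
        """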
return cls(OmegaConf.structured(cls.Config(**kwargs)))
@property
def is_pretrained(self):
return self._is_pretrained
@property
def is_pl_enabled(self):
return self._is_pl_enabled
@is_pretrained.setter
def is_pretrained(self, x: bool):
self._is_pretrained = x
@is_pl_enabled.setter
def is_pl_enabled(self, x: bool):
self._is_pl_enabled = x
    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
"""
This is called by the pl.LightningModule before the model's checkpoint
is loaded.
"""
self.build()
if self.checkpoint_updater is None:
self.checkpoint_updater = MMFToPLCheckpointUpdater()
self.checkpoint_updater.update_checkpoint(checkpoint, self)
def _run_format_state_key(self, state_dict: Dict[str, Any]) -> None:
"""Function to rewrtie the checkpoint in place"""
tmp_state_dict = dict(state_dict)
for attr in tmp_state_dict:
new_attr = self.format_state_key(attr)
if attr != new_attr:
value = state_dict.pop(attr)
state_dict[new_attr] = value
    def build(self):
"""Function to be implemented by the child class, in case they need to
build their model separately than ``__init__``. All model related
downloads should also happen here.
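
        A minimal sketch of a typical override (the module sizes and the
        ``num_labels`` config field are illustrative, not part of the base API)::

            def build(self):
                self.backbone = torch.nn.Linear(2048, 512)
                self.classifier = torch.nn.Linear(512, self.config.num_labels)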
"""
raise NotImplementedError(
"Build method not implemented in the child model class."
)
    def build_meters(self, run_type):
        """Builds train/val/test meters; only used in the lightning setting."""
        from mmf.utils.build import build_meters

        self.train_meter, self.val_meter, self.test_meter = build_meters(run_type)
    def init_losses(self):
"""Initializes loss for the model based ``losses`` key. Automatically called by
MMF internally after building the model.
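
        A minimal configuration sketch (YAML; the exact nesting follows your
        experiment config, and the loss name is illustrative and must match a
        registered loss)::

            model_config:
              my_model:
                losses:
                - type: cross_entropy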
"""
losses = self.config.get("losses", [])
if len(losses) == 0 and not self.is_pretrained:
warnings.warn(
"No losses are defined in model configuration. You are expected "
"to return loss in your return dict from forward."
)
self.losses = Losses(losses)
@classmethod
def config_path(cls):
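        # Default config of this model; child classes typically override this to
        # return the path to their yaml defaults. ``None`` means no default config.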
return None
    def load_state_dict(self, state_dict, *args, **kwargs):
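        # Remap checkpoint keys through ``format_state_key`` (for backwards
        # compatibility with older checkpoints) before the usual loading logic.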
copied_state_dict = deepcopy(state_dict)
for key in list(copied_state_dict.keys()):
formatted_key = self.format_state_key(key)
copied_state_dict[formatted_key] = copied_state_dict.pop(key)
return super().load_state_dict(copied_state_dict, *args, **kwargs)
    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
super().on_save_checkpoint(checkpoint)
config = registry.get("config")
config_dict = OmegaConf.to_container(config, resolve=True)
checkpoint["config"] = config_dict
# TODO: add git features, for example:
# 'git/branch', 'git/commit_hash', 'git/commit_author',
# 'git/commit_message', 'git/diff'
    def forward(self, sample_list, *args, **kwargs):
"""To be implemented by child class. Takes in a ``SampleList`` and
returns back a dict.
Args:
sample_list (SampleList): SampleList returned by the DataLoader for
current iteration
Returns:
Dict: Dict containing scores object.
"""
raise NotImplementedError(
"Forward of the child model class needs to be implemented."
)
    def training_step(self, batch: SampleList, batch_idx: int, *args, **kwargs):
"""Member function of PL modules. Used only when PL enabled.
To be implemented by child class. Takes in a ``SampleList``,
batch_idx and returns back a dict.
Args:
sample_list (SampleList): SampleList returned by the DataLoader for
current iteration
Returns:
Dict: Dict containing loss.
"""
output = self._forward_lightning_step(batch, batch_idx)
report = Report(batch, output).detach()
self.train_meter.update_from_report(report)
return output
    def validation_step(self, batch: SampleList, batch_idx: int, *args, **kwargs):
"""Member function of PL modules. Used only when PL enabled.
To be implemented by child class. Takes in a ``SampleList``,
batch_idx and returns back a dict.
Args:
sample_list (SampleList): SampleList returned by the DataLoader for
current iteration
Returns:
Dict
"""
output = self._forward_lightning_step(batch, batch_idx)
report = Report(batch, output).detach()
self.val_meter.update_from_report(report, should_update_loss=False)
report.metrics = self.metrics(report, report)
self.log_dict(report.metrics)
return output
    def test_step(self, batch: SampleList, batch_idx: int, *args, **kwargs):
"""Member function of PL modules. Used only when PL enabled.
To be implemented by child class. Takes in a ``SampleList``,
batch_idx and returns back a dict.
Args:
sample_list (SampleList): SampleList returned by the DataLoader for
current iteration
Returns:
Dict
"""
return self._forward_lightning_step(batch, batch_idx)
def _forward_lightning_step(self, batch, batch_idx):
batch = self._ensure_sample_list(batch)
output = self(batch)
loss_dict = output["losses"]
output["loss"] = sum(loss.mean() for loss in loss_dict.values())
self._detach_forward_output(output)
return output
def _detach_forward_output(self, output):
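        # Detach every tensor except the total "loss" so that downstream logging
        # and metric computation do not keep the autograd graph alive.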
keys_to_detach = [key for key in output.keys() if key != "loss"]
for key in keys_to_detach:
if hasattr(output[key], "detach"):
output[key] = output[key].detach()
def _ensure_sample_list(self, batch):
if not isinstance(batch, SampleList):
# Try converting to SampleList
batch = SampleList(batch)
return batch
def __call__(self, sample_list, *args, **kwargs):
if not self._is_pl_enabled:
# Move to proper device i.e. same as the model before passing
sample_list = to_device(sample_list, get_current_device())
model_output = super().__call__(sample_list, *args, **kwargs)
# Don't do anything fancy to output if it is pretrained
if self.is_pretrained:
return model_output
# Make sure that the output from the model is a Mapping
assert isinstance(
model_output, collections.abc.Mapping
), "A dict must be returned from the forward of the model."
if "losses" in model_output:
if not self._logged_warning["losses_present"]:
warnings.warn(
"'losses' already present in model output. "
"No calculation will be done in base model."
)
self._logged_warning["losses_present"] = True
assert isinstance(
model_output["losses"], collections.abc.Mapping
), "'losses' must be a dict."
elif hasattr(self, "losses"):
model_output["losses"] = self.losses(sample_list, model_output)
else:
model_output["losses"] = {}
return model_output
def load_requirements(self, config, *args, **kwargs):
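        """Downloads pretrained dependencies listed under ``zoo_requirements``
        in the config; accepts either a single zoo key or a list of keys.
        """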
requirements = config.get("zoo_requirements", [])
if isinstance(requirements, str):
requirements = [requirements]
for item in requirements:
download_pretrained_model(item, *args, **kwargs)
@classmethod
def from_pretrained(cls, model_name_or_path, *args, **kwargs):
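        """Instantiates a model from a pretrained checkpoint, given either a
        model zoo key or a path to a checkpoint on disk.

        A minimal usage sketch (the zoo key below is illustrative, not a
        guaranteed entry in the model zoo)::

            model = MyModel.from_pretrained("my_model.finetuned.some_dataset")
        """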
        # If the given path does not exist on disk, treat it as a pretrained
        # model zoo key; otherwise the checkpoint is loaded from that path
if not PathManager.exists(model_name_or_path):
model_key = model_name_or_path.split(".")[0]
model_cls = registry.get_model_class(model_key)
            assert model_cls == cls, (
                f"Incorrect pretrained model key {model_name_or_path} "
                f"for class {cls.__name__}"
            )
output = load_pretrained_model(model_name_or_path, *args, **kwargs)
config, checkpoint, full_config = (
output["config"],
output["checkpoint"],
output["full_config"],
)
# Save original config for state reset later
config_temp_holder = registry.get("config")
# Register full config from checkpoint when loading the model
registry.register("config", full_config)
        # Some models need registry updates before a pretrained model can be loaded.
        # If they have this method, call it so they can update accordingly
if hasattr(cls, "update_registry_for_pretrained"):
cls.update_registry_for_pretrained(config, checkpoint, output)
instance = cls(config)
instance.is_pretrained = True
instance.build()
incompatible_keys = instance.load_state_dict(checkpoint, strict=False)
# The model has loaded, reset the state
registry.register("config", config_temp_holder)
if len(incompatible_keys.missing_keys) != 0:
logger.warning(
f"Missing keys {incompatible_keys.missing_keys} in the"
+ " checkpoint.\n"
+ "If this is not your checkpoint, please open up an "
+ "issue on MMF GitHub. \n"
+ f"Unexpected keys if any: {incompatible_keys.unexpected_keys}"
)
if len(incompatible_keys.unexpected_keys) != 0:
logger.warning(
"Unexpected keys in state dict: "
+ f"{incompatible_keys.unexpected_keys} \n"
+ "This is usually not a problem with pretrained models, but "
+ "if this is your own model, please double check. \n"
+ "If you think this is an issue, please open up a "
+ "bug at MMF GitHub."
)
instance.eval()
return instance