Source code for mmf.models.base_model

# Copyright (c) Facebook, Inc. and its affiliates.

"""
Models built in MMF need to inherit the ``BaseModel`` class and adhere to
a fixed format. To create a model for MMF, follow this quick cheatsheet.

1. Inherit ``BaseModel`` class, make sure to call ``super().__init__()`` in your
   class's ``__init__`` function.
2. Implement a `build` function for your model. If you build everything in ``__init__``,
   you can just return in this function.
3. Write a `forward` function which takes in a ``SampleList`` as an argument and
   returns a dict.
4. Register using ``@registry.register_model("key")`` decorator on top of the
   class.

If you are doing logits based predictions, the dict you return from your model
should contain a `scores` field. Losses are automatically calculated by the
``BaseModel`` class and added to this dict if not present.

Example::

    import torch

    from mmf.common.registry import registry
    from mmf.models.base_model import BaseModel


    @registry.register_model("pythia")
    class Pythia(BaseModel):
        # config is model_config from global config
        def __init__(self, config):
            super().__init__(config)

        def build(self):
            ...

        def forward(self, sample_list):
            scores = torch.rand(sample_list.get_batch_size(), 3127)
            return {"scores": scores}
"""


import collections
import logging
import warnings
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import pytorch_lightning as pl
from mmf.common.registry import registry
from mmf.common.report import Report
from mmf.common.sample import SampleList, to_device
from mmf.modules.losses import LossConfig, Losses
from mmf.utils.checkpoint import load_pretrained_model
from mmf.utils.checkpoint_updater import MMFToPLCheckpointUpdater
from mmf.utils.download import download_pretrained_model
from mmf.utils.file_io import PathManager
from mmf.utils.general import get_current_device
from mmf.utils.logger import log_class_usage
from omegaconf import MISSING, DictConfig, OmegaConf


logger = logging.getLogger(__name__)


class BaseModel(pl.LightningModule):
    """For integration with MMF's trainer, datasets and other features,
    models need to inherit this class, call ``super``, write a build function,
    write a forward function taking a ``SampleList`` as input and returning a
    dict as output and finally, register it using ``@registry.register_model``.

    Args:
        config (DictConfig): ``model_config`` configuration from global config.

    """

    @dataclass
    class Config:
        # Name of the model that is used in registry
        model: str = MISSING
        losses: Optional[List[LossConfig]] = MISSING

    def __init__(self, config: Union[DictConfig, Config]):
        super().__init__()
        if not isinstance(config, DictConfig) and isinstance(config, self.Config):
            config = OmegaConf.structured(config)

        self.config = config

        self._logged_warning = {"losses_present": False}
        self._is_pretrained = False
        self._is_pl_enabled = False
        self.checkpoint_updater = None

        log_class_usage("Model", self.__class__)

    @classmethod
    def from_params(cls, **kwargs):
        return cls(OmegaConf.structured(cls.Config(**kwargs)))

    @property
    def is_pretrained(self):
        return self._is_pretrained

    @property
    def is_pl_enabled(self):
        return self._is_pl_enabled

    @is_pretrained.setter
    def is_pretrained(self, x: bool):
        self._is_pretrained = x

    @is_pl_enabled.setter
    def is_pl_enabled(self, x: bool):
        self._is_pl_enabled = x

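    # Illustration only (not part of the original module): ``from_params``
    # builds a model directly from keyword arguments instead of a DictConfig.
    # ``MyModel`` below is a hypothetical child class:
    #
    #     model = MyModel.from_params(model="my_model", losses=[])
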
    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """
        This is called by the pl.LightningModule before the model's checkpoint
        is loaded.
        """
        self.build()
        if self.checkpoint_updater is None:
            self.checkpoint_updater = MMFToPLCheckpointUpdater()
        self.checkpoint_updater.update_checkpoint(checkpoint, self)

    def _run_format_state_key(self, state_dict: Dict[str, Any]) -> None:
        """Function to rewrite the checkpoint in place"""
        tmp_state_dict = dict(state_dict)
        for attr in tmp_state_dict:
            new_attr = self.format_state_key(attr)
            if attr != new_attr:
                value = state_dict.pop(attr)
                state_dict[new_attr] = value

    def build(self):
        """Function to be implemented by the child class, in case they need to
        build their model separately from ``__init__``. All model related
        downloads should also happen here.
        """
        raise NotImplementedError(
            "Build method not implemented in the child model class."
        )

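    # A minimal sketch (hypothetical child class, not part of this module) of
    # how ``build`` is typically overridden -- layers are constructed here so
    # that MMF can call ``build`` after ``__init__``:
    #
    #     @registry.register_model("my_model")
    #     class MyModel(BaseModel):
    #         def build(self):
    #             # ``num_labels`` is an assumed entry in this model's config
    #             self.classifier = torch.nn.Linear(2048, self.config.num_labels)
    #
    #         def forward(self, sample_list):
    #             # ``image_feature_0`` is only an illustrative SampleList field
    #             return {"scores": self.classifier(sample_list.image_feature_0)}
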
    def build_meters(self, run_type):
        """Function only used in the lightning setting."""
        from mmf.utils.build import build_meters

        self.train_meter, self.val_meter, self.test_meter = build_meters(run_type)

    def init_losses(self):
        """Initializes loss for the model based on the ``losses`` key. Automatically
        called by MMF internally after building the model.
        """
        losses = self.config.get("losses", [])
        if len(losses) == 0 and not self.is_pretrained:
            warnings.warn(
                "No losses are defined in model configuration. You are expected "
                "to return loss in your return dict from forward."
            )

        self.losses = Losses(losses)

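    # For reference, a hedged sketch of what the ``losses`` key may look like
    # in a model's YAML config (model name and loss type below are illustrative):
    #
    #     model_config:
    #       my_model:
    #         losses:
    #           - type: cross_entropy
    #
    # With such a config, ``init_losses`` wraps the listed losses in a
    # ``Losses`` module that ``__call__`` applies to the forward output.
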
    @classmethod
    def config_path(cls):
        return None

    @classmethod
    def format_state_key(cls, key):
        """Can be implemented if something special needs to be done to the
        key when a pretrained model is being loaded. This will adapt and return
        keys according to that. Useful for backwards compatibility. See
        updated ``load_state_dict`` below. For an example, see VisualBERT
        model's code.

        Args:
            key (string): key to be formatted

        Returns:
            string: formatted key
        """
        return key

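    # Illustration only (hypothetical override, in the spirit of what models
    # such as VisualBERT do): remap legacy checkpoint keys to current module
    # names so old checkpoints keep loading.
    #
    #     @classmethod
    #     def format_state_key(cls, key):
    #         return key.replace("bert.", "model.bert.")
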
    def load_state_dict(self, state_dict, *args, **kwargs):
        copied_state_dict = deepcopy(state_dict)
        for key in list(copied_state_dict.keys()):
            formatted_key = self.format_state_key(key)
            copied_state_dict[formatted_key] = copied_state_dict.pop(key)

        return super().load_state_dict(copied_state_dict, *args, **kwargs)

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        super().on_save_checkpoint(checkpoint)
        config = registry.get("config")
        config_dict = OmegaConf.to_container(config, resolve=True)
        checkpoint["config"] = config_dict
        # TODO: add git features, for example:
        # 'git/branch', 'git/commit_hash', 'git/commit_author',
        # 'git/commit_message', 'git/diff'

    def forward(self, sample_list, *args, **kwargs):
        """To be implemented by child class. Takes in a ``SampleList`` and
        returns back a dict.

        Args:
            sample_list (SampleList): SampleList returned by the DataLoader for
                current iteration

        Returns:
            Dict: Dict containing scores object.

        """
        raise NotImplementedError(
            "Forward of the child model class needs to be implemented."
        )

    def training_step(self, batch: SampleList, batch_idx: int, *args, **kwargs):
        """Member function of PL modules. Used only when PL is enabled.
        Takes in a ``SampleList`` and batch_idx and returns back a dict.

        Args:
            batch (SampleList): SampleList returned by the DataLoader for
                current iteration

        Returns:
            Dict: Dict containing loss.
        """
        output = self._forward_lightning_step(batch, batch_idx)
        report = Report(batch, output).detach()
        self.train_meter.update_from_report(report)
        return output

    def validation_step(self, batch: SampleList, batch_idx: int, *args, **kwargs):
        """Member function of PL modules. Used only when PL is enabled.
        Takes in a ``SampleList`` and batch_idx and returns back a dict.

        Args:
            batch (SampleList): SampleList returned by the DataLoader for
                current iteration

        Returns:
            Dict
        """
        output = self._forward_lightning_step(batch, batch_idx)
        report = Report(batch, output).detach()
        self.val_meter.update_from_report(report, should_update_loss=False)
        report.metrics = self.metrics(report, report)
        self.log_dict(report.metrics)
        return output

    def test_step(self, batch: SampleList, batch_idx: int, *args, **kwargs):
        """Member function of PL modules. Used only when PL is enabled.
        Takes in a ``SampleList`` and batch_idx and returns back a dict.

        Args:
            batch (SampleList): SampleList returned by the DataLoader for
                current iteration

        Returns:
            Dict
        """
        return self._forward_lightning_step(batch, batch_idx)

    def _forward_lightning_step(self, batch, batch_idx):
        batch = self._ensure_sample_list(batch)
        output = self(batch)
        loss_dict = output["losses"]
        output["loss"] = sum(loss.mean() for loss in loss_dict.values())
        self._detach_forward_output(output)
        return output

    def _detach_forward_output(self, output):
        keys_to_detach = [key for key in output.keys() if key != "loss"]
        for key in keys_to_detach:
            if hasattr(output[key], "detach"):
                output[key] = output[key].detach()

    def configure_optimizers(self):
        """Member function of PL modules. Used only when PL is enabled."""
        assert self._is_pl_enabled, (
            "configure_optimizers should be only used as a member "
            "function of LightningModule when pytorch lightning is enabled."
        )

        from mmf.utils.build import build_lightning_optimizers

        config = registry.get("config")
        return build_lightning_optimizers(self, config)

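    # A hedged sketch of the optimizer section that
    # ``build_lightning_optimizers`` reads from the global config; the exact
    # type and parameter values here are illustrative, not prescriptive:
    #
    #     optimizer:
    #       type: adam_w
    #       params:
    #         lr: 5.0e-05
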
    def _ensure_sample_list(self, batch):
        if not isinstance(batch, SampleList):
            # Try converting to SampleList
            batch = SampleList(batch)
        return batch

    def __call__(self, sample_list, *args, **kwargs):
        if not self._is_pl_enabled:
            # Move to proper device, i.e. same as the model, before passing
            sample_list = to_device(sample_list, get_current_device())

        model_output = super().__call__(sample_list, *args, **kwargs)

        # Don't do anything fancy to output if it is pretrained
        if self.is_pretrained:
            return model_output

        # Make sure that the output from the model is a Mapping
        assert isinstance(
            model_output, collections.abc.Mapping
        ), "A dict must be returned from the forward of the model."

        if "losses" in model_output:
            if not self._logged_warning["losses_present"]:
                warnings.warn(
                    "'losses' already present in model output. "
                    "No calculation will be done in base model."
                )
                self._logged_warning["losses_present"] = True

            assert isinstance(
                model_output["losses"], collections.abc.Mapping
            ), "'losses' must be a dict."
        elif hasattr(self, "losses"):
            model_output["losses"] = self.losses(sample_list, model_output)
        else:
            model_output["losses"] = {}

        return model_output

    def load_requirements(self, config, *args, **kwargs):
        requirements = config.get("zoo_requirements", [])
        if isinstance(requirements, str):
            requirements = [requirements]
        for item in requirements:
            download_pretrained_model(item, *args, **kwargs)

    def format_for_prediction(self, results, report):
        """Implement this method in models if they need to modify prediction
        results using report fields. Note that the required fields should
        already be gathered in the report.
        """
        return results

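    # Illustration only (hypothetical override): attach a report field to each
    # prediction entry before results are written out. ``question_id`` is an
    # assumed report field, not something this base class guarantees.
    #
    #     def format_for_prediction(self, results, report):
    #         for idx, entry in enumerate(results):
    #             entry["question_id"] = report.question_id[idx].item()
    #         return results
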
    @classmethod
    def from_pretrained(cls, model_name_or_path, *args, **kwargs):
        # Check if the path exists; if not, it is a pretrained model key,
        # otherwise we will try to load the checkpoint from the path
        if not PathManager.exists(model_name_or_path):
            model_key = model_name_or_path.split(".")[0]
            model_cls = registry.get_model_class(model_key)
            assert model_cls == cls, (
                f"Incorrect pretrained model key {model_name_or_path} "
                f"for class {cls.__name__}"
            )
        output = load_pretrained_model(model_name_or_path, *args, **kwargs)
        config, checkpoint, full_config = (
            output["config"],
            output["checkpoint"],
            output["full_config"],
        )

        # Save original config for state reset later
        config_temp_holder = registry.get("config")
        # Register full config from checkpoint when loading the model
        registry.register("config", full_config)

        # Some models need registry updates to be able to load a pretrained model
        # If they have this method, call it so they can update accordingly
        if hasattr(cls, "update_registry_for_pretrained"):
            cls.update_registry_for_pretrained(config, checkpoint, output)

        instance = cls(config)
        instance.is_pretrained = True
        instance.build()
        incompatible_keys = instance.load_state_dict(checkpoint, strict=False)

        # The model has loaded, reset the state
        registry.register("config", config_temp_holder)

        if len(incompatible_keys.missing_keys) != 0:
            logger.warning(
                f"Missing keys {incompatible_keys.missing_keys} in the"
                + " checkpoint.\n"
                + "If this is not your checkpoint, please open up an "
                + "issue on MMF GitHub. \n"
                + f"Unexpected keys if any: {incompatible_keys.unexpected_keys}"
            )

        if len(incompatible_keys.unexpected_keys) != 0:
            logger.warning(
                "Unexpected keys in state dict: "
                + f"{incompatible_keys.unexpected_keys} \n"
                + "This is usually not a problem with pretrained models, but "
                + "if this is your own model, please double check. \n"
                + "If you think this is an issue, please open up a "
                + "bug at MMF GitHub."
            )

        instance.eval()
        return instance
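
    # Typical usage of ``from_pretrained`` (sketch; the zoo key below is
    # illustrative and may not exist in your installation):
    #
    #     model_cls = registry.get_model_class("visual_bert")
    #     model = model_cls.from_pretrained("visual_bert.pretrained.coco")
    #     output = model(sample_list)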