# Source code for mmf.datasets.base_dataset

# Copyright (c) Facebook, Inc. and its affiliates.
from mmf.common.registry import registry
from mmf.common.sample import SampleList
from mmf.utils.general import get_current_device
from torch.utils.data.dataset import Dataset


class BaseDataset(Dataset):
    """Common base for MMF datasets.

    Extends ``torch.utils.data.Dataset`` with MMF-specific conveniences:
    processors declared in the dataset's configuration are built and
    attached as attributes, and batches can be moved to the active device
    via :meth:`prepare_batch`.

    Args:
        dataset_name (str): Name of your dataset to be used a representative
            in text strings
        dataset_type (str): Type of your dataset. Normally, train|val|test
        config (DictConfig): Configuration for the current dataset
    """

    def __init__(self, dataset_name, config, dataset_type="train", *args, **kwargs):
        super().__init__()
        # Fall back to an empty mapping so attribute/membership checks
        # later (e.g. in init_processors) don't explode on None.
        self.config = {} if config is None else config
        self._dataset_name = dataset_name
        self._dataset_type = dataset_type
        self._global_config = registry.get("config")
        self._device = get_current_device()
        self.use_cuda = "cuda" in str(self._device)

    def load_item(self, idx):
        """Hook for subclasses that want to load and cache a sample
        separately from ``__getitem__``.

        Args:
            idx (int): Index of the sample to be loaded.
        """
        return

    def __getitem__(self, idx):
        """Return the sample at ``idx``; must be provided by subclasses.

        Args:
            idx (int): Index of the sample to be loaded.
        """
        raise NotImplementedError

    def init_processors(self):
        """Build the processors named in ``config.processors`` (if any),
        attach each one as an attribute on this dataset, and register it
        globally under ``"<dataset_name>_<processor_key>"``.
        """
        if "processors" not in self.config:
            return

        from mmf.utils.build import build_processors

        registry_key_template = f"{self._dataset_name}_{{}}"
        built = build_processors(
            self.config.processors,
            registry_key_template,
            **{"data_dir": self.config.data_dir},
        )
        for key, instance in built.items():
            setattr(self, key, instance)
            registry.register(registry_key_template.format(key), instance)

    def prepare_batch(self, batch):
        """Move a loaded batch onto the dataset's device.

        Can be possibly overridden in your child class. Not supported with
        the Lightning trainer. Whatever is returned from here is passed
        directly to the model's forward function.

        Args:
            batch (SampleList): sample list containing the currently
                loaded batch

        Returns:
            SampleList: the batch, coerced to a SampleList and moved to
            the current device.
        """
        if not isinstance(batch, SampleList):
            # Coerce plain collated batches into a SampleList first.
            batch = SampleList(batch)
        return batch.to(self._device)

    @property
    def dataset_type(self):
        # One of train|val|test (by convention).
        return self._dataset_type

    @dataset_type.setter
    def dataset_type(self, dataset_type):
        self._dataset_type = dataset_type

    @property
    def dataset_name(self):
        return self._dataset_name

    @dataset_name.setter
    def dataset_name(self, name):
        self._dataset_name = name

    @property
    def name(self):
        # Alias for dataset_name (read-only).
        return self._dataset_name

    def format_for_prediction(self, report):
        """Hook for converting a report into prediction entries; the base
        implementation contributes nothing."""
        return []

    def verbose_dump(self, *args, **kwargs):
        """Hook for verbose debugging dumps; no-op in the base class."""
        return

    def visualize(self, num_samples=1, *args, **kwargs):
        """Visualization hook; subclasses that support it must override."""
        raise NotImplementedError(
            f"{self.dataset_name} doesn't implement visualize function"
        )