# Source code for mmf.datasets.base_dataset
# Copyright (c) Facebook, Inc. and its affiliates.
from mmf.common.registry import registry
from mmf.common.sample import SampleList
from mmf.utils.general import get_current_device
from torch.utils.data.dataset import Dataset


class BaseDataset(Dataset):
    """Base class for implementing a dataset. Inherits from PyTorch's Dataset
    class but adds some custom functionality on top. Processors mentioned in
    the configuration are automatically initialized for the end user.

    Args:
        dataset_name (str): Name of your dataset, to be used as a
            representative in text strings
        config (DictConfig): Configuration for the current dataset
        dataset_type (str): Type of your dataset. Normally one of
            train|val|test
    """

    def __init__(self, dataset_name, config, dataset_type="train", *args, **kwargs):
        super().__init__()
        if config is None:
            config = {}
        self.config = config
        self._dataset_name = dataset_name
        self._dataset_type = dataset_type
        self._global_config = registry.get("config")
        self._device = get_current_device()
        self.use_cuda = "cuda" in str(self._device)

    def load_item(self, idx):
        """
        Implement if you need to separately load the item and cache it.

        Args:
            idx (int): Index of the sample to be loaded.
        """
        return

    def __getitem__(self, idx):
        """
        The ``__getitem__`` of a torch dataset. Must be implemented by the
        child class.

        Args:
            idx (int): Index of the sample to be loaded.
        """
        raise NotImplementedError

    def init_processors(self):
        if "processors" not in self.config:
            return

        from mmf.utils.build import build_processors

        extra_params = {"data_dir": self.config.data_dir}
        # Processors are registered under "<dataset_name>_<processor_key>"
        reg_key = f"{self._dataset_name}_{{}}"
        processor_dict = build_processors(
            self.config.processors, reg_key, **extra_params
        )
        for processor_key, processor_instance in processor_dict.items():
            # Expose each processor as an attribute, e.g. self.text_processor,
            # and register it globally under its full key
            setattr(self, processor_key, processor_instance)
            full_key = reg_key.format(processor_key)
            registry.register(full_key, processor_instance)
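
    # A sketch of the configuration that drives ``init_processors``; the
    # dataset name, processor name and processor type below are illustrative,
    # not taken from a real MMF config:
    #
    #   dataset_config:
    #     my_dataset:
    #       data_dir: ${env.data_dir}
    #       processors:
    #         text_processor:
    #           type: vocab
    #           params: {...}
    #
    # Afterwards the processor is reachable both as ``self.text_processor``
    # and via ``registry.get("my_dataset_text_processor")``.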

    def prepare_batch(self, batch):
        """
        Can possibly be overridden in your child class. Not supported with the
        Lightning trainer.

        Prepare batch for passing to model. Whatever is returned from here will
        be directly passed to the model's forward function. Currently moves the
        batch to the proper device.

        Args:
            batch (SampleList): sample list containing the currently loaded batch

        Returns:
            sample_list (SampleList): the sample list representing the
                currently loaded batch
        """
        # Should be a SampleList
        if not isinstance(batch, SampleList):
            # Try converting to SampleList
            batch = SampleList(batch)
        batch = batch.to(self._device)
        return batch
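
    # Hypothetical call site; ``dataset``, ``dataloader`` and ``model`` are
    # assumed to exist. A sketch of the flow, not MMF's actual trainer code:
    #
    #   for batch in dataloader:
    #       sample_list = dataset.prepare_batch(batch)
    #       output = model(sample_list)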

    @property
    def dataset_type(self):
        return self._dataset_type

    @dataset_type.setter
    def dataset_type(self, dataset_type):
        self._dataset_type = dataset_type

    @property
    def name(self):
        return self._dataset_name

    @property
    def dataset_name(self):
        return self._dataset_name

    @dataset_name.setter
    def dataset_name(self, name):
        self._dataset_name = name

    def format_for_prediction(self, report):
        return []

    def verbose_dump(self, *args, **kwargs):
        return

    def visualize(self, num_samples=1, *args, **kwargs):
        raise NotImplementedError(
            f"{self.dataset_name} doesn't implement visualize function"
        )
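

# A minimal sketch of a concrete subclass, assuming samples come from an
# in-memory list of annotation dicts. ``MyDataset`` and ``annotations`` are
# illustrative names, not part of MMF.
from mmf.common.sample import Sample


class MyDataset(BaseDataset):
    def __init__(self, config, dataset_type="train", *args, **kwargs):
        super().__init__("my_dataset", config, dataset_type, *args, **kwargs)
        # In practice these would be loaded from files under config.data_dir
        self.annotations = [{"text": "example question"}]

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        annotation = self.annotations[idx]
        sample = Sample()
        # Processors built by init_processors() are available as attributes
        # (e.g. self.text_processor) if the config defines them; here we just
        # attach the raw text.
        sample.raw_text = annotation["text"]
        return sample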