Shortcuts

Source code for mmf.datasets.base_dataset_builder

# Copyright (c) Facebook, Inc. and its affiliates.
"""
In MMF, adding a new dataset requires adding a dataset builder for it.
A new dataset builder must inherit from the ``BaseDatasetBuilder`` class and
implement the ``load`` and ``build`` functions.

``build`` is used to build a dataset when it is not available, e.g. by
downloading the ImDBs for a dataset. In the future, we plan to add a ``build``
to add dataset builder to ease setup of MMF.

``load`` is used to load a dataset from a specific path. ``load`` needs to
return an instance of a subclass of ``mmf.datasets.base_dataset.BaseDataset``.

See complete example for ``VQA2DatasetBuilder`` here_.

Example::

    from torch.utils.data import Dataset

    from mmf.datasets.base_dataset_builder import BaseDatasetBuilder
    from mmf.common.registry import registry

    @registry.register_builder("my")
    class MyBuilder(BaseDatasetBuilder):
        def __init__(self):
            super().__init__("my")

        def load(self, config, dataset_type, *args, **kwargs):
            ...
            return Dataset()

        def build(self, config, dataset_type, *args, **kwargs):
            ...

.. _here: https://github.com/facebookresearch/mmf/blob/main/mmf/datasets/vqa/vqa2/builder.py
"""
import uuid
from typing import Optional

import pytorch_lightning as pl
from mmf.utils.build import build_dataloader_and_sampler
from mmf.utils.logger import log_class_usage
from omegaconf import DictConfig
from torch.utils.data import Dataset


# TODO(asg): Deprecate BaseDatasetBuilder after version release
class BaseDatasetBuilder(pl.LightningDataModule):
    """Base class for implementing dataset builders.

    Child classes need to implement ``build`` and ``load``. MMF drives a
    builder through ``build_dataset`` / ``load_dataset`` (or the Lightning
    hooks ``prepare_data`` / ``setup``), which delegate to those two methods.

    Args:
        dataset_name (str): Name of the dataset passed from child. When it is
            omitted, a name of the form ``dataset_<6 hex chars>`` is generated.
    """

    def __init__(self, dataset_name: Optional[str] = None, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if dataset_name is None:
            # In case user doesn't pass it
            dataset_name = f"dataset_{uuid.uuid4().hex[:6]}"

        self.dataset_name = dataset_name
        self._train_dataset = None
        self._val_dataset = None
        self._test_dataset = None

        log_class_usage("DatasetBuilder", self.__class__)

    @property
    def dataset_name(self):
        """Name of the dataset handled by this builder."""
        return self._dataset_name

    @dataset_name.setter
    def dataset_name(self, dataset_name):
        self._dataset_name = dataset_name

    def prepare_data(self, config, *args, **kwargs):
        """Store ``config`` on the builder and build the dataset.

        NOTE: The caller to this function should only call this on main process
        in a distributed settings so that downloads and build only happen
        on main process and others can just load it. Make sure to call
        synchronize afterwards to bring all processes in sync.

        Lightning automatically wraps datamodule in a way that it is only
        called on a main node, but for extra precaution as lightning
        can introduce bugs, we should always call this under main process
        with extra checks on our sides as well.

        Args:
            config (DictConfig): Configuration of this dataset loaded from
                config.
        """
        self.config = config
        self.build_dataset(config)

    def setup(self, stage: Optional[str] = None, config: Optional[DictConfig] = None):
        """Load the train, val and test datasets.

        Args:
            stage (Optional[str]): Accepted for compatibility with the
                Lightning hook signature; not used by this implementation.
            config (Optional[DictConfig]): Configuration of this dataset. When
                omitted, the config stored by a previous ``prepare_data`` or
                ``setup`` call is used.

        Raises:
            AttributeError: If ``config`` is omitted and no config has been
                stored on this builder yet.
        """
        if config is None:
            # ``self.config`` only exists once ``prepare_data``/``setup`` has
            # stored one on this process; surface a descriptive error instead
            # of the bare "object has no attribute 'config'".
            try:
                config = self.config
            except AttributeError as err:
                raise AttributeError(
                    f"{type(self).__name__}.setup() was called without a config "
                    "and no config was stored by a previous prepare_data() or "
                    "setup() call on this process"
                ) from err

        self.config = config
        self.train_dataset = self.load_dataset(config, "train")
        self.val_dataset = self.load_dataset(config, "val")
        self.test_dataset = self.load_dataset(config, "test")

    @property
    def train_dataset(self) -> Optional[Dataset]:
        """Dataset used for training, or ``None`` if it hasn't been set."""
        return self._train_dataset

    @train_dataset.setter
    def train_dataset(self, dataset: Optional[Dataset]):
        self._train_dataset = dataset

    @property
    def val_dataset(self) -> Optional[Dataset]:
        """Dataset used for validation, or ``None`` if it hasn't been set."""
        return self._val_dataset

    @val_dataset.setter
    def val_dataset(self, dataset: Optional[Dataset]):
        self._val_dataset = dataset

    @property
    def test_dataset(self) -> Optional[Dataset]:
        """Dataset used for testing, or ``None`` if it hasn't been set."""
        return self._test_dataset

    @test_dataset.setter
    def test_dataset(self, dataset: Optional[Dataset]):
        self._test_dataset = dataset

    def build_dataset(self, config, dataset_type="train", *args, **kwargs):
        """
        Similar to load function, used by MMF to build a dataset for first
        time when it is not available. This internally calls 'build' function.
        Override that function in your child class.

        NOTE: The caller to this function should only call this on main process
        in a distributed settings so that downloads and build only happen
        on main process and others can just load it. Make sure to call
        synchronize afterwards to bring all processes in sync.

        Args:
            config (DictConfig): Configuration of this dataset loaded from
                                 config.
            dataset_type (str): Type of dataset, train|val|test

        .. warning::

            DO NOT OVERRIDE in child class. Instead override ``build``.
        """
        self.build(config, dataset_type, *args, **kwargs)

    def load_dataset(self, config, dataset_type="train", *args, **kwargs):
        """Main load function use by MMF. This will internally call ``load``
        function. Calls ``init_processors`` on the dataset returned from
        ``load`` when the dataset defines that method.

        Args:
            config (DictConfig): Configuration of this dataset loaded from
                                 config.
            dataset_type (str): Type of dataset, train|val|test

        Returns:
            dataset (BaseDataset): Dataset containing data to be trained on

        .. warning::

            DO NOT OVERRIDE in child class. Instead override ``load``.
        """
        dataset = self.load(config, dataset_type, *args, **kwargs)
        if dataset is not None and hasattr(dataset, "init_processors"):
            # Checking for init_processors allows us to load some datasets
            # which don't have processors and don't inherit from BaseDataset
            dataset.init_processors()
        return dataset

    def load(self, config, dataset_type="train", *args, **kwargs):
        """
        This is used to prepare the dataset and load it from a path.
        Override this method in your child dataset builder class.

        Args:
            config (DictConfig): Configuration of this dataset loaded from
                                 config.
            dataset_type (str): Type of dataset, train|val|test

        Returns:
            dataset (BaseDataset): Dataset containing data to be trained on

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(
            "This dataset builder doesn't implement a load method"
        )

    @classmethod
    def config_path(cls):
        """Return ``None`` in this base implementation.

        NOTE(review): presumably child classes override this to return the
        path of their default config file -- confirm against callers.
        """
        return None

    def build(self, config, dataset_type="train", *args, **kwargs):
        """
        This is used to build a dataset first time.
        Implement this method in your child dataset builder class.

        Args:
            config (DictConfig): Configuration of this dataset loaded from
                                 config.
            dataset_type (str): Type of dataset, train|val|test

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(
            "This dataset builder doesn't implement a build method"
        )

    def build_dataloader(
        self, dataset_instance: Optional[Dataset], dataset_type: str, *args, **kwargs
    ):
        """Build a dataloader for ``dataset_instance`` using ``self.config``.

        The dataset is tagged with ``dataset_type`` before the dataloader is
        created. Extra positional and keyword arguments are accepted but not
        used by this implementation.

        Args:
            dataset_instance (Optional[Dataset]): Dataset to wrap.
            dataset_type (str): Type of dataset, train|val|test

        Returns:
            The first element of the pair returned by
            ``build_dataloader_and_sampler``; the second element is discarded.

        Raises:
            TypeError: If ``dataset_instance`` is ``None``.
        """
        if dataset_instance is None:
            raise TypeError(
                f"dataset instance for {dataset_type} hasn't been set and is None"
            )
        dataset_instance.dataset_type = dataset_type
        dataloader, _ = build_dataloader_and_sampler(dataset_instance, self.config)
        return dataloader

    def train_dataloader(self, *args, **kwargs):
        """Return a dataloader built from ``self.train_dataset``."""
        return self.build_dataloader(self.train_dataset, "train")

    def val_dataloader(self, *args, **kwargs):
        """Return a dataloader built from ``self.val_dataset``."""
        return self.build_dataloader(self.val_dataset, "val")

    def test_dataloader(self, *args, **kwargs):
        """Return a dataloader built from ``self.test_dataset``."""
        return self.build_dataloader(self.test_dataset, "test")

    def teardown(self, *args, **kwargs) -> None:
        """Do nothing; the base builder has nothing to clean up."""
        pass