datasets.base_dataset_builder

In MMF, adding a new dataset requires adding a dataset builder for it. A new dataset builder must inherit from the BaseDatasetBuilder class and implement the load and build functions.

build is used to build a dataset when it is not available, for example by downloading the ImDBs for the dataset. In the future, we plan to extend build support in dataset builders to ease the setup of MMF.

load is used to load a dataset from a specific path. It must return an instance of a subclass of mmf.datasets.base_dataset.BaseDataset.

See the VQA2DatasetBuilder in the MMF source for a complete example.
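To make the division of labor between build and load concrete, here is a minimal, self-contained sketch in plain Python. ToyBuilder is a hypothetical stand-in, not MMF's BaseDatasetBuilder, and a small JSON file stands in for downloaded ImDBs: build only writes the data when it is missing, and load reads it back from a path.

```python
import json
import os
import tempfile


class ToyBuilder:
    """Hypothetical, simplified builder; not MMF's BaseDatasetBuilder."""

    def __init__(self, dataset_name):
        self.dataset_name = dataset_name

    def _path(self, data_dir):
        return os.path.join(data_dir, f"{self.dataset_name}.json")

    def build(self, data_dir):
        # Stand-in for downloading ImDBs: only runs when the data is missing
        path = self._path(data_dir)
        if not os.path.exists(path):
            os.makedirs(data_dir, exist_ok=True)
            with open(path, "w") as f:
                json.dump([{"question": "What color?", "answer": "red"}], f)

    def load(self, data_dir):
        # Load the already-built data from a specific path
        with open(self._path(data_dir)) as f:
            return json.load(f)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        builder = ToyBuilder("my")
        builder.build(d)  # writes my.json
        builder.build(d)  # no-op: the data is already available
        print(builder.load(d))
```

In real MMF code, these two responsibilities live in the build and load methods of a BaseDatasetBuilder subclass, as in the example that follows.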

Example:

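from torch.utils.data import Dataset

from mmf.datasets.base_dataset_builder import BaseDatasetBuilder
from mmf.common.registry import registry

@registry.register_builder("my")
class MyBuilder(BaseDatasetBuilder):
    def __init__(self):
        # The name passed here becomes the builder's dataset_name
        super().__init__("my")

    def load(self, config, dataset_type, *args, **kwargs):
        # Load the dataset from the path(s) in config. In practice, return an
        # instance of a subclass of mmf.datasets.base_dataset.BaseDataset
        ...
        return Dataset()

    def build(self, config, dataset_type, *args, **kwargs):
        # Download or otherwise prepare the data (e.g. ImDBs) if it is missing
        ...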

class mmf.datasets.base_dataset_builder.BaseDatasetBuilder(*args: Any, **kwargs: Any)

Base class for implementing dataset builders; see the module description above for more information. Child classes need to implement build and load.

Parameters

dataset_name (str) – Name of the dataset passed from child.

build(config, dataset_type='train', *args, **kwargs)

This is used to build a dataset for the first time. Implement this method in your child dataset builder class.

Parameters
  • config (DictConfig) – Configuration of this dataset loaded from config.

  • dataset_type (str) – Type of dataset, train|val|test

build_dataset(config, dataset_type='train', *args, **kwargs)

Similar to load_dataset, this is used by MMF to build a dataset for the first time when it is not available. It internally calls the build function; override that function in your child class.

NOTE: In a distributed setting, callers should invoke this function only on the main process, so that downloads and builds happen only there and the other processes can simply load the result. Make sure to call synchronize afterwards to bring all processes in sync.
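The main-process-only pattern in this note can be illustrated without MMF. In this sketch, threads stand in for distributed processes and threading.Barrier stands in for synchronize; run_rank and simulate are hypothetical names, and the code only models the ordering (rank 0 builds, every rank waits, every rank loads), not MMF's distributed utilities.

```python
import threading


def run_rank(rank, barrier, shared, log):
    # Only the main process (rank 0) builds, e.g. downloads files
    if rank == 0:
        shared["data"] = ["sample-0", "sample-1"]
        log.append(f"rank {rank}: built")
    # Stand-in for synchronize(): each rank blocks until all ranks arrive
    barrier.wait()
    # Past the barrier, every rank can safely load what rank 0 built
    log.append(f"rank {rank}: loaded {len(shared['data'])} items")


def simulate(world_size=4):
    barrier = threading.Barrier(world_size)
    shared, log = {}, []
    threads = [
        threading.Thread(target=run_rank, args=(r, barrier, shared, log))
        for r in range(world_size)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return log


if __name__ == "__main__":
    for line in simulate():
        print(line)
```

Without the barrier, a non-main rank could try to load before rank 0 has finished building, which is exactly what calling synchronize prevents.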

Parameters
  • config (DictConfig) – Configuration of this dataset loaded from config.

  • dataset_type (str) – Type of dataset, train|val|test

Warning

DO NOT OVERRIDE in a child class. Override build instead.

load(config, dataset_type='train', *args, **kwargs)

This is used to prepare the dataset and load it from a path. Override this method in your child dataset builder class.

Parameters
  • config (DictConfig) – Configuration of this dataset loaded from config.

  • dataset_type (str) – Type of dataset, train|val|test

Returns

Dataset containing data to be trained on

Return type

dataset (BaseDataset)

load_dataset(config, dataset_type='train', *args, **kwargs)

Main load function used by MMF. It internally calls the load function, then calls init_processors and try_fast_read on the dataset returned by load.
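The relationship between load_dataset and load is a template-method pattern: the public wrapper calls the overridable hook, then post-processes its result. A minimal sketch, assuming hypothetical ToyDataset, ToyBaseBuilder and MyToyBuilder classes that only record which methods were called; this is not MMF's source code.

```python
class ToyDataset:
    """Hypothetical dataset that records which setup methods ran."""

    def __init__(self, items):
        self.items = items
        self.calls = []

    def init_processors(self):
        self.calls.append("init_processors")

    def try_fast_read(self):
        self.calls.append("try_fast_read")


class ToyBaseBuilder:
    """Hypothetical sketch of the wrapper pattern; not MMF's implementation."""

    def load_dataset(self, config, dataset_type="train", *args, **kwargs):
        # Not meant to be overridden: children customize load() instead
        dataset = self.load(config, dataset_type, *args, **kwargs)
        dataset.init_processors()
        dataset.try_fast_read()
        return dataset

    def load(self, config, dataset_type="train", *args, **kwargs):
        raise NotImplementedError


class MyToyBuilder(ToyBaseBuilder):
    def load(self, config, dataset_type="train", *args, **kwargs):
        # Here config is a plain dict mapping dataset_type to items
        return ToyDataset(config.get(dataset_type, []))


if __name__ == "__main__":
    ds = MyToyBuilder().load_dataset({"train": [1, 2, 3]}, "train")
    print(ds.items, ds.calls)
```

Keeping the post-processing in the wrapper means every child builder gets it for free, which is why load_dataset should not be overridden.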

Parameters
  • config (DictConfig) – Configuration of this dataset loaded from config.

  • dataset_type (str) – Type of dataset, train|val|test

Returns

Dataset containing data to be trained on

Return type

dataset (BaseDataset)

Warning

DO NOT OVERRIDE in a child class. Override load instead.

prepare_data(config, *args, **kwargs)

NOTE: In a distributed setting, callers should invoke this function only on the main process, so that downloads and builds happen only there and the other processes can simply load the result. Make sure to call synchronize afterwards to bring all processes in sync.

Lightning automatically wraps the datamodule so that this is only called on a main node, but as an extra precaution, since Lightning can introduce bugs, we should always call this on the main process with our own checks as well.

setup(stage: Optional[str] = None, config: Optional[omegaconf.dictconfig.DictConfig] = None)

Called at the beginning of fit (train + validate), validate, test, and predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters

stage – either 'fit', 'validate', 'test', or 'predict'

Example:

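import torch.nn as nn
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = None

    def prepare_data(self):
        # download_data and tokenize are placeholders for your own helpers
        download_data()
        tokenize()

        # don't do this: state assigned here is not available on other processes
        # self.something = ...

    def setup(self, stage=None):
        # load_data is a placeholder; setup runs on every process
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)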

teardown(*args, **kwargs) -> None

Called at the end of fit (train + validate), validate, test, predict, or tune.

Parameters

stage – either 'fit', 'validate', 'test', or 'predict'

test_dataloader(*args, **kwargs)

Implement one or multiple PyTorch DataLoaders for testing.

The dataloader you return will not be reloaded unless you set the Trainer's reload_dataloaders_every_n_epochs to a positive integer.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, this pattern is only necessary for distributed processing.
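Since setup() is called on every process under DDP (as noted for setup above), any processing done there, such as a train/val split, must produce the same result on every process. A minimal sketch, assuming a hypothetical split_indices helper that uses a locally seeded random.Random, so the split does not depend on each process's global random state:

```python
import random


def split_indices(num_items, val_fraction=0.1, seed=42):
    """Hypothetical helper: deterministic train/val split of item indices.

    Every process calling this with the same arguments gets the same split.
    """
    indices = list(range(num_items))
    rng = random.Random(seed)  # local RNG, independent of global state
    rng.shuffle(indices)
    num_val = int(num_items * val_fraction)
    return indices[num_val:], indices[:num_val]


if __name__ == "__main__":
    train_idx, val_idx = split_indices(20, val_fraction=0.25, seed=0)
    print(len(train_idx), len(val_idx))
```

Using a dedicated random.Random instance rather than the module-level functions keeps the split reproducible even if other code has already consumed random numbers differently on each process.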

Warning

Do not assign state in prepare_data().

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Returns

A torch.utils.data.DataLoader or a sequence of them specifying testing samples.

Example:

import torch
from torchvision import transforms
from torchvision.datasets import MNIST

def test_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )

    return loader

# can also return multiple dataloaders
# (loader_a ... loader_n stand for your own DataLoader instances)
def test_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]

Note

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.

Note

In the case where you return multiple test dataloaders, the test_step() will have an argument dataloader_idx which matches the order here.
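To illustrate how dataloader_idx lines up with the order of the returned dataloaders, here is a simplified, pure-Python stand-in for the test loop. run_test_loop and example_test_step are hypothetical names, and Lightning's actual loop is considerably more involved.

```python
def run_test_loop(step_fn, dataloaders):
    """Simplified stand-in for a test loop; not Lightning's implementation.

    Each batch is passed to step_fn together with dataloader_idx, the
    position of its loader in the list returned by test_dataloader().
    """
    outputs = []
    for dataloader_idx, loader in enumerate(dataloaders):
        for batch_idx, batch in enumerate(loader):
            outputs.append(step_fn(batch, batch_idx, dataloader_idx))
    return outputs


def example_test_step(batch, batch_idx, dataloader_idx=0):
    # Branch on dataloader_idx to tell the test sets apart
    name = ["clean", "noisy"][dataloader_idx]
    return (name, batch_idx, sum(batch))


if __name__ == "__main__":
    clean_loader = [[1, 2], [3]]  # two "batches"
    noisy_loader = [[10]]         # one "batch"
    print(run_test_loop(example_test_step, [clean_loader, noisy_loader]))
```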

train_dataloader(*args, **kwargs)

Implement one or more PyTorch DataLoaders for training.

Returns

A collection of torch.utils.data.DataLoader specifying training samples. For the case of multiple dataloaders, see the PyTorch Lightning documentation on multiple datasets.

The dataloader you return will not be reloaded unless you set the Trainer's reload_dataloaders_every_n_epochs to a positive integer.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, this pattern is only necessary for distributed processing.

Warning

Do not assign state in prepare_data().

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Example:

import torch
from torchvision import transforms
from torchvision.datasets import CIFAR10, MNIST

# single dataloader
def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=True
    )
    return loader

# multiple dataloaders, return as list
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR10(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a list of tensors: [batch_mnist, batch_cifar]
    return [mnist_loader, cifar_loader]

# multiple dataloaders, return as dict
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR10(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
    return {'mnist': mnist_loader, 'cifar': cifar_loader}
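The comments in these examples describe how batches from several training dataloaders are grouped: a list of loaders yields a list of batches per step, and a dict yields a dict with the same keys. A minimal sketch of that grouping in plain Python, using hypothetical combine_list and combine_dict helpers built on zip. Note that zip stops at the shortest loader; Lightning's own handling of loaders with different lengths is not modeled here.

```python
def combine_list(loaders):
    # One step yields one batch from every loader: [batch_a, batch_b, ...]
    for batches in zip(*loaders):
        yield list(batches)


def combine_dict(loaders):
    # One step yields {"name": batch, ...}, keyed like the input dict
    names = list(loaders)
    for batches in zip(*(loaders[name] for name in names)):
        yield dict(zip(names, batches))


if __name__ == "__main__":
    mnist_batches = ["m0", "m1"]
    cifar_batches = ["c0", "c1"]
    print(list(combine_list([mnist_batches, cifar_batches])))
    print(list(combine_dict({"mnist": mnist_batches, "cifar": cifar_batches})))
```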

val_dataloader(*args, **kwargs)

Implement one or multiple PyTorch DataLoaders for validation.

The dataloader you return will not be reloaded unless you set the Trainer's reload_dataloaders_every_n_epochs to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data().

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Returns

A torch.utils.data.DataLoader or a sequence of them specifying validation samples.

Examples:

import torch
from torchvision import transforms
from torchvision.datasets import MNIST

def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )

    return loader

# can also return multiple dataloaders
# (loader_a ... loader_n stand for your own DataLoader instances)
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]

Note

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

Note

In the case where you return multiple validation dataloaders, the validation_step() will have an argument dataloader_idx which matches the order here.