datasets.base_dataset_builder
In MMF, adding a new dataset requires adding a dataset builder for it. A new dataset builder must inherit from the BaseDatasetBuilder class and implement the load and build functions.
build is used to build a dataset when it is not available, e.g. downloading the ImDBs for a dataset. In the future, we plan to add a build step that makes it easier to add dataset builders and set up MMF.
load is used to load a dataset from a specific path. load needs to return an instance of a subclass of mmf.datasets.base_dataset.BaseDataset.
See the complete example for VQA2DatasetBuilder here.
Example:
from torch.utils.data import Dataset
from mmf.datasets.base_dataset_builder import BaseDatasetBuilder
from mmf.common.registry import registry


@registry.register_builder("my")
class MyBuilder(BaseDatasetBuilder):
    def __init__(self):
        super().__init__("my")

    def load(self, config, dataset_type, *args, **kwargs):
        ...
        return Dataset()

    def build(self, config, dataset_type, *args, **kwargs):
        ...
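Once registered under the key "my", the builder can also be looked up again through the registry. A minimal sketch, assuming registry.get_builder_class is available in your MMF version and that the base class exposes the dataset_name passed to its constructor:
from mmf.common.registry import registry

# Look up the builder class registered above and instantiate it.
builder_cls = registry.get_builder_class("my")
builder = builder_cls()
print(builder.dataset_name)  # expected to print "my"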
- class mmf.datasets.base_dataset_builder.BaseDatasetBuilder(*args: Any, **kwargs: Any)
Base class for implementing dataset builders. See more information above. A child class needs to implement build and load.
- Parameters
dataset_name (str) – Name of the dataset passed from child.
- build(config, dataset_type='train', *args, **kwargs)
This is used to build a dataset the first time. Implement this method in your child dataset builder class.
- Parameters
config (DictConfig) – Configuration of this dataset loaded from config.
dataset_type (str) – Type of dataset, train|val|test
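As an illustration, a build method on the MyBuilder class above might download the ImDB files the first time the dataset is used. This is only a sketch: config.data_dir, the directory layout, the file names, and the URL below are hypothetical placeholders, not part of the MMF API.
import os
import urllib.request


def build(self, config, dataset_type="train", *args, **kwargs):
    # Hypothetical location where this builder keeps its ImDB files.
    data_dir = os.path.join(config.data_dir, "my")
    imdb_path = os.path.join(data_dir, f"imdb_{dataset_type}.npy")

    if os.path.exists(imdb_path):
        return  # already built, nothing to do

    os.makedirs(data_dir, exist_ok=True)
    # Placeholder URL; replace with wherever your ImDB files actually live.
    url = f"https://example.com/my_dataset/imdb_{dataset_type}.npy"
    urllib.request.urlretrieve(url, imdb_path)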
- build_dataset(config, dataset_type='train', *args, **kwargs)
Similar to the load function, this is used by MMF to build a dataset for the first time when it is not available. It internally calls the build function; override that function in your child class.
NOTE: The caller should only call this on the main process in a distributed setting, so that downloads and builds happen only on the main process and other processes can simply load the result. Make sure to call synchronize afterwards to bring all processes back in sync.
- Parameters
config (DictConfig) – Configuration of this dataset loaded from config.
dataset_type (str) – Type of dataset, train|val|test
Warning
DO NOT OVERRIDE in child class. Instead, override build.
- load(config, dataset_type='train', *args, **kwargs)
This is used to prepare the dataset and load it from a path. Override this method in your child dataset builder class.
- Parameters
config (DictConfig) – Configuration of this dataset loaded from config.
dataset_type (str) – Type of dataset, train|val|test
- Returns
Dataset containing data to be trained on
- Return type
dataset (BaseDataset)
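A matching load constructs and returns the dataset object. A minimal sketch, assuming a hypothetical MyDataset subclass of BaseDataset; the (name, config, dataset_type) constructor arguments follow the common MMF pattern but should be checked against your MMF version:
from mmf.datasets.base_dataset import BaseDataset
from mmf.datasets.base_dataset_builder import BaseDatasetBuilder


class MyDataset(BaseDataset):
    # Hypothetical dataset class used only for this sketch.
    def __init__(self, config, dataset_type, *args, **kwargs):
        super().__init__("my", config, dataset_type)


class MyBuilder(BaseDatasetBuilder):
    def __init__(self):
        super().__init__("my")

    def load(self, config, dataset_type="train", *args, **kwargs):
        # Build the dataset object; MMF's load_dataset will later call
        # init_processors and try_fast_read on whatever is returned here.
        self.dataset = MyDataset(config, dataset_type)
        return self.dataset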
- load_dataset(config, dataset_type='train', *args, **kwargs)
Main load function used by MMF. This will internally call the load function. Calls init_processors and try_fast_read on the dataset returned from load.
- Parameters
config (DictConfig) – Configuration of this dataset loaded from config.
dataset_type (str) – Type of dataset, train|val|test
- Returns
Dataset containing data to be trained on
- Return type
dataset (BaseDataset)
Warning
DO NOT OVERRIDE in child class. Instead, override load.
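For intuition only, load_dataset behaves roughly like the simplified sketch below. This is not the exact MMF source, and it should not be overridden; override load instead.
def load_dataset(self, config, dataset_type="train", *args, **kwargs):
    # Delegate to the child's load(), then finalize the returned dataset.
    dataset = self.load(config, dataset_type, *args, **kwargs)
    if dataset is not None:
        dataset.init_processors()
        dataset.try_fast_read()
    return dataset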
- prepare_data(config, *args, **kwargs)
NOTE: The caller should only call this on the main process in a distributed setting, so that downloads and builds happen only on the main process and other processes can simply load the result. Make sure to call synchronize afterwards to bring all processes back in sync.
Lightning automatically wraps the datamodule so that this is only called on the main node, but since Lightning can introduce bugs, as an extra precaution we should always call this on the main process with additional checks on our side as well.
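The main-process pattern described above can be made explicit with plain torch.distributed primitives. A generic sketch of the "build on the main process, then synchronize" idea, not MMF's internal implementation:
import torch.distributed as dist


def prepare_data_safely(builder, config):
    is_distributed = dist.is_available() and dist.is_initialized()
    is_main = (not is_distributed) or dist.get_rank() == 0

    if is_main:
        # Downloads and one-time preprocessing happen only on the main process.
        builder.build_dataset(config)

    if is_distributed:
        # Bring all ranks back in sync before anyone tries to load the data.
        dist.barrier()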
- setup(stage: Optional[str] = None, config: Optional[omegaconf.dictconfig.DictConfig] = None)
Called at the beginning of fit (train + validate), validate, test, and predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Parameters
stage – either 'fit', 'validate', 'test', or 'predict'
Example:
class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this (do not assign state in prepare_data)
        self.something = ...

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
- teardown(*args, **kwargs) → None
Called at the end of fit (train + validate), validate, test, predict, or tune.
- Parameters
stage – either 'fit', 'validate', 'test', or 'predict'
- test_dataloader(*args, **kwargs)
Implement one or multiple PyTorch DataLoaders for testing.
The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
do not assign state in prepare_data
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Returns
A torch.utils.data.DataLoader or a sequence of them specifying testing samples.
Example:
def test_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )
    return loader

# can also return multiple dataloaders
def test_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]
Note
If you don’t need a test dataset and a test_step(), you don’t need to implement this method.
Note
In the case where you return multiple test dataloaders, the test_step() will have an argument dataloader_idx which matches the order here.
- train_dataloader(*args, **kwargs)
Implement one or more PyTorch DataLoaders for training.
- Returns
A collection of torch.utils.data.DataLoader specifying training samples. In the case of multiple dataloaders, please see this page.
The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
do not assign state in prepare_data
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Example:
# single dataloader
def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=True
    )
    return loader

# multiple dataloaders, return as list
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a list of tensors: [batch_mnist, batch_cifar]
    return [mnist_loader, cifar_loader]

# multiple dataloaders, return as dict
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
    return {'mnist': mnist_loader, 'cifar': cifar_loader}
- val_dataloader(*args, **kwargs)
Implement one or multiple PyTorch DataLoaders for validation.
The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Returns
A torch.utils.data.DataLoader or a sequence of them specifying validation samples.
Examples:
def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )
    return loader

# can also return multiple dataloaders
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]
Note
If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.
Note
In the case where you return multiple validation dataloaders, the validation_step() will have an argument dataloader_idx which matches the order here.