common.sample
Sample and SampleList are data structures for arbitrary data returned from a
dataset. To work with MMF, a dataset must at minimum return an object of the
Sample class, and a model must accept an object of type SampleList as an
argument.
Sample represents an arbitrary sample from a dataset, while SampleList is a
list of Samples combined efficiently for use by the model.
In simple terms, a SampleList is a batch of Samples that still allows easy
access to the attributes of Sample while taking care of batching them properly.
- class mmf.common.sample.Sample(init_dict=None)
Sample represents some arbitrary data. All datasets in MMF must return an object of type Sample.
- Parameters
init_dict (Dict) – Dictionary to initialize the Sample class with.
Usage:
>>> import torch
>>> from mmf.common.sample import Sample
>>> sample = Sample({"text": torch.tensor(2)})
>>> sample.text.zero_()
>>> # Custom attributes can be added to ``Sample`` after initialization
>>> sample.context = torch.tensor(4)
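The usage above depends on dictionary keys being reachable as attributes. The following is a minimal pure-Python sketch of that pattern, not MMF's actual implementation; the `AttrDict` class is a hypothetical stand-in introduced only for illustration, and it stores plain lists where a real Sample would usually hold tensors.

```python
class AttrDict(dict):
    """Hypothetical stand-in for Sample: a dict whose keys are also attributes."""

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails; fall back to dict keys.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        # Store every attribute assignment as a dict entry.
        self[name] = value


sample = AttrDict({"text": [1, 2]})
sample.context = 4  # added after initialization, as in the usage above
```

Keeping the data in a dict means the same object can be iterated over by key when samples are later merged into a batch.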
- class mmf.common.sample.SampleList(samples=None)
SampleList is used to collate a list of Sample into a batch during batch preparation. It can be thought of as a merger of a list of Dicts into a single Dict.
If Sample contains an attribute ‘text’ of size (2) and there are 10 samples in the list, the returned SampleList will have an attribute ‘text’ which is a tensor of size (10, 2).
- Parameters
samples (List[Sample]) – List of Sample from which the SampleList will be created.
Usage:
>>> import torch
>>> from mmf.common.sample import Sample, SampleList
>>> sample_list = SampleList([
...     Sample({"text": torch.tensor(2)}),
...     Sample({"text": torch.tensor(2)}),
... ])
>>> sample_list.text
tensor([2, 2])
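The "list of Dicts merged into a single Dict" idea can be sketched in plain Python. This is an illustration only: the `collate` helper is hypothetical, it uses nested lists where MMF uses tensors, and it assumes every sample carries the same keys.

```python
def collate(samples):
    """Merge a list of dicts into one dict of lists (hypothetical helper)."""
    if not samples:
        return {}
    batch = {}
    for key in samples[0]:
        # One entry per sample, so the leading dimension is the batch size.
        batch[key] = [sample[key] for sample in samples]
    return batch


# Ten samples, each with a 'text' attribute of size 2.
samples = [{"text": [i, i + 1]} for i in range(10)]
batch = collate(samples)
batch_size = len(batch["text"])        # number of samples in the batch
feature_size = len(batch["text"][0])   # size of one sample's 'text'
```

The result mirrors the (10, 2) shape described above: ten rows, one per sample, each of size 2.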
- add_field(field, data)
Add an attribute field with value data to the SampleList.
- Parameters
field (str) – Key under which the data will be added.
data (object) – Data to be added; can be a torch.Tensor, list, or Sample.
- copy()
Get a copy of the current SampleList.
- Returns
Copy of current SampleList.
- Return type
SampleList
- fields()
Get current attributes/fields registered under the SampleList.
- Returns
List of attributes of the SampleList.
- Return type
List[str]
- get_batch_size()
Get the batch size of the current SampleList. There must be a tensor present inside the sample list to use this function.
- Returns
Size of the batch in SampleList.
- Return type
int
- get_field(field)
Get the value of a particular attribute.
- Parameters
field (str) – Attribute whose value is to be returned.
- get_fields(fields)
Get a new SampleList generated from the current SampleList that contains only the attributes passed in the fields argument.
- Parameters
fields (List[str]) – Attributes from which the new SampleList will be made.
- Returns
SampleList containing only the attribute values of the fields which were passed.
- Return type
SampleList
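Selecting a subset of fields amounts to building a new mapping that holds only the requested keys. The sketch below works on a plain dict and is not MMF code; `select_fields` is a hypothetical helper, and raising `KeyError` for an unknown field is a choice made for this sketch, not a statement about how SampleList reports the error.

```python
def select_fields(batch, fields):
    """Return a new dict holding only the requested keys (hypothetical helper)."""
    missing = [f for f in fields if f not in batch]
    if missing:
        # The sketch fails loudly on unknown fields.
        raise KeyError(f"fields not in batch: {missing}")
    # Build a new mapping; the original batch is left unchanged.
    return {f: batch[f] for f in fields}


batch = {"text": [[1, 2], [3, 4]], "label": [0, 1], "id": ["a", "b"]}
subset = select_fields(batch, ["text", "label"])
```

Because a new mapping is returned, the original batch keeps all of its fields.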
- get_item_list(key)
Get a SampleList containing only one particular attribute that is present in the current SampleList.
- Parameters
key (str) – Attribute from which the new SampleList will be made.
- Returns
SampleList containing only the attribute value of the key which was passed.
- Return type
SampleList
- pin_memory()
A custom batch object must define a pin_memory function so that PyTorch can apply pinning. This function pins each tensor field individually.
- to(device, non_blocking=True)
Similar to the .to function on a torch.Tensor. Moves all of the tensors present inside the SampleList to a particular device. If an attribute’s value is not a tensor, it is ignored and kept as it is.
- Parameters
device (str|torch.device) – Device to which the SampleList should be moved.
non_blocking (bool) – Whether the move should be non-blocking. Default: True
- Returns
A SampleList moved to the device.
- Return type
SampleList
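The "move tensors, leave everything else alone" behaviour can be sketched without PyTorch or a GPU. Everything here is an assumption made for illustration: `FakeTensor` is a hypothetical stand-in that only records a device name, and `move_batch` is a hypothetical helper operating on a plain dict rather than a real SampleList.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor that only records its device."""

    def __init__(self, data, device="cpu"):
        self.data = data
        self.device = device

    def to(self, device, non_blocking=True):
        # Return a new object on the target device; the original is untouched.
        return FakeTensor(self.data, device)


def move_batch(batch, device, non_blocking=True):
    """Move every FakeTensor value to `device`; keep other values as they are."""
    moved = {}
    for key, value in batch.items():
        if isinstance(value, FakeTensor):
            moved[key] = value.to(device, non_blocking=non_blocking)
        else:
            # Non-tensor values (lists, strings, ...) are passed through unchanged.
            moved[key] = value
    return moved


batch = {"text": FakeTensor([2, 2]), "id": ["a", "b"]}
on_gpu = move_batch(batch, "cuda:0")
```

In this sketch the helper returns a new mapping, so the input batch still reports its tensors on the original device.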