common.sample
Sample and SampleList are data structures for arbitrary data returned from a dataset. To work with MMF, the minimum requirement is that datasets return an object of the Sample class and that models accept an object of type SampleList as an argument.
Sample represents an arbitrary sample from a dataset, while SampleList is a list of Sample objects combined efficiently for use by the model. In simple terms, a SampleList is a batch of Sample objects that allows easy access to the attributes of each Sample while taking care of batching them properly.
- class mmf.common.sample.Sample(init_dict=None)
Sample represents some arbitrary data. All datasets in MMF must return an object of type Sample.
- Parameters
init_dict (Dict) – Dictionary to initialize the Sample class with.
Usage:
>>> sample = Sample({"text": torch.tensor(2)})
>>> sample.text.zero_()
>>> # Custom attributes can be added to ``Sample`` after initialization
>>> sample.context = torch.tensor(4)
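The attribute-style access in the usage above can be pictured with a small dictionary subclass. The following is a minimal sketch, not MMF's implementation: `AttrDict` is a hypothetical name, and plain integers stand in for tensors so the snippet runs without PyTorch.

```python
class AttrDict(dict):
    """Hypothetical stand-in for Sample: a dict whose keys read as attributes."""

    def __getattr__(self, key):
        # Called only when normal attribute lookup fails, so dict methods still work.
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key) from None

    def __setattr__(self, key, value):
        # Store every assigned attribute as a dictionary entry.
        self[key] = value


sample = AttrDict({"text": 2})
sample.context = 4  # added after initialization, as in the usage above
print(sample.text, sample["context"], sorted(sample.keys()))
```

Keeping the data in a dictionary means the same object can be read by key or by attribute, which is the behaviour the Sample usage relies on.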
- class mmf.common.sample.SampleList(samples=None)
SampleList is used to collate a list of Sample objects into a batch during batch preparation. It can be thought of as merging a list of dicts into a single dict.
If each Sample contains an attribute 'text' of size (2) and there are 10 samples in the list, the returned SampleList will have an attribute 'text' that is a tensor of size (10, 2).
- Parameters
samples (type) – List of Sample objects from which the SampleList will be created.
Usage:
>>> sample_list = SampleList([
...     Sample({"text": torch.tensor(2)}),
...     Sample({"text": torch.tensor(2)}),
... ])
>>> sample_list.text
tensor([2, 2])
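The "list of dicts merged into a single dict" idea can be sketched in plain Python. This is an illustration under stated assumptions, not MMF's code: `collate` is a hypothetical helper, and nested lists stand in for tensors, so values are gathered into lists rather than stacked into a tensor.

```python
def collate(samples):
    """Merge a list of dicts into one dict whose values are per-sample lists."""
    batch = {}
    for sample in samples:
        for key, value in sample.items():
            # setdefault creates the list the first time a key is seen.
            batch.setdefault(key, []).append(value)
    return batch


# Ten samples, each with a "text" entry of length 2.
samples = [{"text": [i, i + 1]} for i in range(10)]
batch = collate(samples)
print(len(batch["text"]), len(batch["text"][0]))
```

The batched "text" entry has 10 rows of 2 values each, mirroring the (10, 2) shape described above.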
- add_field(field, data)
Add an attribute field with value data to the SampleList.
- Parameters
field (str) – Key under which the data will be added.
data (object) – Data to be added; can be a torch.Tensor, list or Sample.
- copy()
Get a copy of the current SampleList.
- Returns
Copy of current SampleList.
- Return type
SampleList
- fields()
Get current attributes/fields registered under the SampleList.
- Returns
List of attributes of the SampleList.
- Return type
List[str]
- get_batch_size()
Get the batch size of the current SampleList. There must be a tensor present inside the sample list to use this function.
- Returns
Size of the batch in SampleList.
- Return type
int
- get_field(field)
Get the value of a particular attribute.
- Parameters
field (str) – Attribute whose value is to be returned.
- get_fields(fields)
Get a new SampleList generated from the current SampleList that contains only the attributes passed in the fields argument.
- Parameters
fields (List[str]) – Attributes to include in the new SampleList.
- Returns
SampleList containing only the attribute values of the fields which were passed.
- Return type
SampleList
- get_item_list(key)
Get a SampleList containing only one particular attribute that is present in the current SampleList.
- Parameters
key (str) – Attribute to include in the new SampleList.
- Returns
SampleList containing only the attribute value of the key which was passed.
- Return type
SampleList
- pin_memory()
A custom batch object must define a pin_memory function so that PyTorch can apply pinning. This function individually pins each of the tensor fields.
- to(device, non_blocking=True)
Similar to the .to function on a torch.Tensor. Moves all of the tensors present inside the SampleList to a particular device. If an attribute's value is not a tensor, it is ignored and kept as it is.
- Parameters
device (str|torch.device) – Device to which the SampleList should be moved.
non_blocking (bool) – Whether the move should be non-blocking. Default: True
- Returns
A SampleList moved to the device.
- Return type
SampleList
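The field selection described for get_fields can be pictured as picking keys out of a dictionary. The sketch below is hypothetical and does not reproduce MMF's implementation: `select_fields` is an illustrative name, the batch is a plain dict of lists, and raising `KeyError` for unknown names is an assumption of this example.

```python
def select_fields(batch, fields):
    """Return a new dict holding only the requested keys of batch."""
    missing = [name for name in fields if name not in batch]
    if missing:
        # Failing loudly on unknown names is a choice made for this sketch.
        raise KeyError(f"unknown fields: {missing}")
    return {name: batch[name] for name in fields}


batch = {"text": [[1, 2], [3, 4]], "label": [0, 1], "id": ["a", "b"]}
subset = select_fields(batch, ["text", "label"])
print(list(subset))
```

The original batch is left unchanged. Only the new dictionary is restricted to the requested fields.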