common.sample

Sample and SampleList are data structures for arbitrary data returned from a dataset. To work with MMF, a dataset must at minimum return an object of the Sample class, and a model must accept an object of type SampleList as an argument.

Sample represents an arbitrary sample from a dataset, while SampleList is a list of Sample objects combined efficiently for use by the model. In simple terms, a SampleList is a batch of Sample objects that allows easy access to the attributes of the underlying samples while taking care of batching them properly.
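The relationship between the two classes can be sketched without MMF or PyTorch installed. In the minimal illustration below, `DictSample` and `collate` are hypothetical stand-ins rather than MMF's implementation, and plain Python values take the place of tensors.

```python
class DictSample(dict):
    """Hypothetical stand-in for Sample: a dict whose keys are also attributes."""

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        # Store every attribute as a dictionary entry.
        self[name] = value

    def fields(self):
        return list(self.keys())


def collate(samples):
    """Merge a list of samples into one batch holding one list per field."""
    batch = DictSample()
    for sample in samples:
        for field, value in sample.items():
            batch.setdefault(field, []).append(value)
    return batch


samples = [DictSample({"text": 2}), DictSample({"text": 3})]
samples[0].context = 4  # attributes can be added after construction

batch = collate(samples)
print(batch.text)      # [2, 3]
print(batch.fields())  # ['text', 'context']
```

The real SampleList combines tensor fields into a single batched tensor rather than a Python list; the list here only stands in for that step.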

class mmf.common.sample.Sample(init_dict=None)[source]

Sample represents some arbitrary data. All datasets in MMF must return an object of type Sample.

Parameters

init_dict (Dict) – Dictionary to init Sample class with.

Usage:

>>> import torch
>>> from mmf.common.sample import Sample
>>> sample = Sample({"text": torch.tensor(2)})
>>> sample.text.zero_()
tensor(0)
>>> # Custom attributes can be added to Sample after initialization
>>> sample.context = torch.tensor(4)

fields()[source]

Get current attributes/fields registered under the sample.

Returns

Attributes registered under the Sample.

Return type

List[str]

class mmf.common.sample.SampleList(samples=None)[source]

SampleList is used to collate a list of Sample into a batch during batch preparation. It can be thought of as a merger of list of Dicts into a single Dict.

If Sample contains an attribute ‘text’ of size (2) and there are 10 samples in list, the returned SampleList will have an attribute ‘text’ which is a tensor of size (10, 2).
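The shape arithmetic above can be checked with nested lists standing in for tensors. The `shape` helper is hypothetical and exists only for this illustration. The point is that stacking the per-sample values is what adds the leading batch dimension.

```python
def shape(nested):
    """Return the dimensions of a regular nested list, outermost first."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)


# Ten samples, each with a 'text' field holding two values.
samples = [{"text": [i, i + 1]} for i in range(10)]

# Collecting the per-sample values adds a leading batch dimension.
batched_text = [sample["text"] for sample in samples]

print(shape(samples[0]["text"]))  # (2,)
print(shape(batched_text))        # (10, 2)
```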

Parameters

samples (List[Sample]) – List of Sample objects from which the SampleList will be created.

Usage:

>>> import torch
>>> from mmf.common.sample import Sample, SampleList
>>> sample_list = SampleList([
...     Sample({"text": torch.tensor(2)}),
...     Sample({"text": torch.tensor(2)})
... ])
>>> sample_list.text
tensor([2, 2])

add_field(field, data)[source]

Add an attribute field with value data to the SampleList

Parameters
  • field (str) – Key under which the data will be added.

  • data (object) – Data to be added, can be a torch.Tensor, list or Sample

copy()[source]

Get a copy of the current SampleList

Returns

Copy of current SampleList.

Return type

SampleList

fields()[source]

Get current attributes/fields registered under the SampleList.

Returns

list of attributes of the SampleList.

Return type

List[str]

get_batch_size()[source]

Get batch size of the current SampleList. There must be a tensor present inside the sample list to use this function.

Returns

Size of the batch in SampleList.

Return type

int

get_field(field)[source]

Get value of a particular attribute

Parameters

field (str) – Attribute whose value is to be returned.

get_fields(fields)[source]

Get a new SampleList generated from the current SampleList that contains only the attributes passed in the fields argument.

Parameters

fields (List[str]) – Attributes to include in the new SampleList.

Returns

SampleList containing only the attribute values of the fields which were passed.

Return type

SampleList
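Conceptually, selecting fields is a projection of the batch onto a subset of its keys. The sketch below uses a plain dict as a stand-in for SampleList. `select_fields` is a hypothetical helper, and raising an error for an unknown field is an assumption of this sketch, not documented MMF behavior.

```python
def select_fields(batch, fields):
    """Return a new dict holding only the requested fields of `batch`."""
    missing = [field for field in fields if field not in batch]
    if missing:
        # Assumption of this sketch: unknown fields are an error.
        raise KeyError(f"fields not present in batch: {missing}")
    return {field: batch[field] for field in fields}


batch = {
    "text": [[1, 2], [3, 4]],
    "targets": [0, 1],
    "image_id": ["a", "b"],
}

subset = select_fields(batch, ["text", "targets"])
print(list(subset))  # ['text', 'targets']
```

The values in the result are the same objects as in the original batch, so the selection itself copies no data in this sketch.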

get_item_list(key)[source]

Get SampleList of only one particular attribute that is present in the SampleList.

Parameters

key (str) – Attribute to include in the new SampleList.

Returns

SampleList containing only the attribute value of the key which was passed.

Return type

SampleList

pin_memory()[source]

For a custom batch object, a pin_memory function must be defined so that PyTorch can apply pinning. This function individually pins each of the tensor fields.

to(device, non_blocking=True)[source]

Similar to .to function on a torch.Tensor. Moves all of the tensors present inside the SampleList to a particular device. If an attribute’s value is not a tensor, it is ignored and kept as it is.

Parameters
  • device (str|torch.device) – Device to which the SampleList should be moved.

  • non_blocking (bool) – Whether the move should be non_blocking. Default: True

Returns

a SampleList moved to the device.

Return type

SampleList
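The rule of moving tensors and leaving everything else alone can be sketched with a stand-in tensor type, so the example runs without PyTorch. `FakeTensor` and `move_batch` are hypothetical names, and returning a new dict instead of updating the batch in place is a choice of this sketch, not a statement about MMF's implementation.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor that only tracks its device."""

    data: tuple
    device: str = "cpu"

    def to(self, device, non_blocking=False):
        # A real tensor would transfer its storage; here only the label changes.
        return replace(self, device=device)


def move_batch(batch, device, non_blocking=True):
    """Return a new dict with every FakeTensor moved to `device`."""
    moved = {}
    for field, value in batch.items():
        if isinstance(value, FakeTensor):
            moved[field] = value.to(device, non_blocking=non_blocking)
        else:
            moved[field] = value  # non-tensor values are kept as they are
    return moved


batch = {"text": FakeTensor((1, 2)), "image_id": ["a", "b"]}
on_gpu = move_batch(batch, "cuda:0")

print(on_gpu["text"].device)  # cuda:0
print(on_gpu["image_id"])     # ['a', 'b']
```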

to_dict() → Dict[str, Any][source]

Converts a sample list to a dict. This is useful for TorchScript and for other internal API unification efforts.

Returns

A dict representation of current sample list

Return type

Dict[str, Any]

mmf.common.sample.detach_tensor(tensor: Any) → Any[source]

Detaches any element passed that has a .detach function defined. Currently in MMF, this can be a SampleList, a Report or a tensor.

Parameters

tensor (Any) – Item to be detached

Returns

Detached element

Return type

Any
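The duck-typing check described for detach_tensor amounts to testing whether the item defines a detach method. The sketch below illustrates that idea with hypothetical names (`detach_if_possible`, `Tracked`). It is not copied from MMF's source, and returning items without a detach method unchanged is an assumption of this sketch.

```python
def detach_if_possible(item):
    """Return item.detach() when the item defines detach, else the item itself."""
    if hasattr(item, "detach"):
        return item.detach()
    return item


class Tracked:
    """Hypothetical object with a detach method, standing in for a tensor."""

    def __init__(self, value, requires_grad=True):
        self.value = value
        self.requires_grad = requires_grad

    def detach(self):
        # Return a copy that no longer tracks gradients.
        return Tracked(self.value, requires_grad=False)


detached = detach_if_possible(Tracked(3.0))
print(detached.requires_grad)          # False
print(detach_if_possible("metadata"))  # metadata
```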