oneflow.nn.ParameterDict

class oneflow.nn.ParameterDict(parameters=None)

Holds parameters in a dictionary.

ParameterDict can be indexed like a regular Python dictionary, but the parameters it contains are properly registered and are visible to all Module methods.

ParameterDict is an ordered dictionary that respects

  • the order of insertion, and

  • in update(), the order of the merged OrderedDict or another ParameterDict (the argument to update()).

Note that update() with other unordered mapping types (e.g., Python’s plain dict) does not preserve the order of the merged mapping.
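The ordering rule above can be modelled in plain Python. The sketch below is not OneFlow code: merge_like_update is a hypothetical helper, and it assumes the rule used by the referenced PyTorch 1.10 implementation, where an ordered mapping is merged in its own order and any other mapping is merged in sorted-key order. Whether OneFlow orders plain-dict keys in exactly this way is an assumption.

```python
from collections import OrderedDict


def merge_like_update(target, other):
    """Hypothetical stand-in for the ordering behaviour of update().

    Assumption: ordered mappings keep their own order; any other
    mapping is merged in sorted-key order (as in PyTorch 1.10).
    """
    if isinstance(other, OrderedDict):
        items = other.items()
    else:
        items = sorted(other.items())
    for key, value in items:
        # Keys already present keep their original position.
        target[key] = value
    return target


# Merging an OrderedDict keeps the argument's order.
ordered = merge_like_update(OrderedDict(), OrderedDict([("w2", 2), ("w1", 1)]))
# Merging a plain dict does not: under the stated assumption, keys are sorted.
plain = merge_like_update(OrderedDict(), {"w2": 2, "w1": 1})

print(list(ordered))  # ['w2', 'w1']
print(list(plain))    # ['w1', 'w2']
```

When the key order matters (for example, when iterating over the parameters), passing an OrderedDict or another ParameterDict to update() avoids relying on this behaviour.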

The interface is consistent with PyTorch. This documentation is adapted from https://pytorch.org/docs/1.10/generated/torch.nn.ParameterDict.html?#torch.nn.ParameterDict.

Parameters

parameters (iterable, optional) – a mapping (dictionary) of (string : Parameter) or an iterable of key-value pairs of type (string, Parameter)

>>> import oneflow as flow
>>> import oneflow.nn as nn

>>> class MyModule(nn.Module):
...    def __init__(self):
...        super(MyModule, self).__init__()
...        self.params = nn.ParameterDict({
...                'left': nn.Parameter(flow.randn(5, 10)),
...                'right': nn.Parameter(flow.randn(5, 10))
...        })
...
...    def forward(self, x, choice):
...        x = self.params[choice].mm(x)
...        return x

>>> model = MyModule()
>>> model.params
ParameterDict(
    (left): Parameter containing: [<class 'oneflow.nn.Parameter'> of size 5x10]
    (right): Parameter containing: [<class 'oneflow.nn.Parameter'> of size 5x10]
)