oneflow.nn.ModuleDict
class oneflow.nn.ModuleDict(modules: Optional[Mapping[str, oneflow.nn.modules.module.Module]] = None)

Holds submodules in a dictionary.
The interface is consistent with PyTorch. The documentation is adapted from: https://pytorch.org/docs/1.10/generated/torch.nn.ModuleDict.html?#torch.nn.ModuleDict.
ModuleDict can be indexed like a regular Python dictionary, but the modules it contains are properly registered and are visible to all Module methods.

ModuleDict is an ordered dictionary that respects

- the order of insertion, and
- in update(), the order of the merged OrderedDict, dict (starting from Python 3.6), or another ModuleDict (the argument to update()).
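To make these ordering rules concrete, here is a minimal pure-Python sketch that runs without oneflow. `ToyModuleDict` is a hypothetical stand-in, not oneflow's implementation: it only mimics the insertion-order and update() behaviour described above, and placeholder strings stand in for real modules.

```python
from collections import OrderedDict
from collections.abc import Mapping


class ToyModuleDict:
    """Hypothetical stand-in for ModuleDict (NOT oneflow's code).

    Only models the ordering rules: keys are kept in insertion order.
    """

    def __init__(self, modules=None):
        self._modules = OrderedDict()  # insertion-ordered backing store
        if modules is not None:
            self.update(modules)

    def __setitem__(self, key, module):
        # Assigning to an existing key replaces the value but keeps its position.
        self._modules[key] = module

    def __getitem__(self, key):
        return self._modules[key]

    def keys(self):
        return list(self._modules.keys())

    def items(self):
        return list(self._modules.items())

    def update(self, modules):
        if isinstance(modules, (Mapping, ToyModuleDict)):
            # dict / OrderedDict / another ToyModuleDict: merged in its own iteration order.
            for key, module in modules.items():
                self[key] = module
        else:
            # Otherwise: an iterable of (key, module) pairs, merged in sequence.
            for key, module in modules:
                self[key] = module


# Strings stand in for real modules so the sketch runs without oneflow.
d = ToyModuleDict({"conv": "Conv2d", "pool": "MaxPool2d"})
d.update([("relu", "ReLU"), ("conv", "Conv2d-v2")])
d.update(ToyModuleDict([("drop", "Dropout")]))
print(d.keys())   # ['conv', 'pool', 'relu', 'drop']
print(d["conv"])  # Conv2d-v2
```

Re-assigning an existing key, as the second update() does for 'conv', replaces the value but keeps its original slot. This is standard dict behaviour; the sketch assumes ModuleDict inherits it, which matches PyTorch's design.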
Note that update() with other unordered mapping types (e.g., Python’s plain dict before Python 3.6) does not preserve the order of the merged mapping.

Parameters

modules (iterable, optional) – a mapping (dictionary) of (string: module) or an iterable of key-value pairs of type (string, module).
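The two accepted forms of `modules` can be sketched as a small normalising function. `to_pairs` is a hypothetical helper, not oneflow API. Its error checks are modelled on PyTorch's ModuleDict.update (TypeError for a non-iterable element, ValueError for an element that is not a pair), and oneflow's exact exception types and messages may differ.

```python
from collections.abc import Iterable, Mapping


def to_pairs(modules=None):
    """Hypothetical helper (not part of oneflow): normalise a `modules`
    argument into an ordered list of (name, module) pairs."""
    if modules is None:
        return []
    if isinstance(modules, Mapping):
        # Form 1: a mapping of name -> module, taken in its iteration order.
        return list(modules.items())
    if not isinstance(modules, Iterable):
        raise TypeError(
            f"expected a mapping or an iterable of pairs, got {type(modules).__name__}"
        )
    pairs = []
    for i, item in enumerate(modules):
        # Form 2: every element must itself be a 2-element (name, module) pair.
        if not isinstance(item, Iterable):
            raise TypeError(f"element #{i} should be iterable, got {type(item).__name__}")
        pair = tuple(item)
        if len(pair) != 2:
            raise ValueError(f"element #{i} has length {len(pair)}; 2 is required")
        pairs.append(pair)
    return pairs


# Placeholder strings stand in for modules.
print(to_pairs({"conv": "Conv2d", "pool": "MaxPool2d"}))
# [('conv', 'Conv2d'), ('pool', 'MaxPool2d')]
print(to_pairs([["lrelu", "LeakyReLU"], ["prelu", "PReLU"]]))
# [('lrelu', 'LeakyReLU'), ('prelu', 'PReLU')]
```

Both calls mirror the two constructor styles used in the example below: a dict literal for `choices` and a list of two-element lists for `activations`.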
>>> import oneflow.nn as nn
>>> class MyModule(nn.Module):
...     def __init__(self):
...         super(MyModule, self).__init__()
...         self.choices = nn.ModuleDict({
...                 'conv': nn.Conv2d(10, 10, 3),
...                 'pool': nn.MaxPool2d(3)
...         })
...         self.activations = nn.ModuleDict([
...                 ['lrelu', nn.LeakyReLU()],
...                 ['prelu', nn.PReLU()]
...         ])
...     def forward(self, x, choice, act):
...         x = self.choices[choice](x)
...         x = self.activations[act](x)
...         return x
>>> model = MyModule()
>>> model.choices
ModuleDict(
  (conv): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=(3, 3), stride=(3, 3), padding=(0, 0), dilation=(1, 1))
)