oneflow.nn.ModuleDict
class oneflow.nn.ModuleDict(modules: Optional[Mapping[str, oneflow.nn.modules.module.Module]] = None)
Holds submodules in a dictionary.
The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.ModuleDict.html?#torch.nn.ModuleDict.
ModuleDict can be indexed like a regular Python dictionary, but the modules it contains are properly registered and are visible to all Module methods.

ModuleDict is an ordered dictionary that respects the order of insertion and, in update(), the order of the merged OrderedDict, dict (starting from Python 3.6), or another ModuleDict (the argument to update()).
Note that update() with other unordered mapping types (e.g., Python's plain dict before Python 3.6) does not preserve the order of the merged mapping.

Parameters
modules (iterable, optional) – a mapping (dictionary) of (string: module) or an iterable of key-value pairs of type (string, module)
>>> import oneflow.nn as nn
>>> class MyModule(nn.Module):
...     def __init__(self):
...         super(MyModule, self).__init__()
...         self.choices = nn.ModuleDict({
...             'conv': nn.Conv2d(10, 10, 3),
...             'pool': nn.MaxPool2d(3)
...         })
...         self.activations = nn.ModuleDict([
...             ['lrelu', nn.LeakyReLU()],
...             ['prelu', nn.PReLU()]
...         ])
...     def forward(self, x, choice, act):
...         x = self.choices[choice](x)
...         x = self.activations[act](x)
...         return x
>>> model = MyModule()
>>> model.choices
ModuleDict(
  (conv): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=(3, 3), stride=(3, 3), padding=(0, 0), dilation=(1, 1))
)