oneflow.one_embedding

Embedding is an important component of recommender systems, and it is now widely used in many other fields as well. Every framework provides a basic Embedding operator, for example flow.nn.Embedding in OneFlow:

import numpy as np
import oneflow as flow
indices = flow.tensor([[1, 2, 4, 5], [4, 3, 2, 9]], dtype=flow.int)
embedding = flow.nn.Embedding(10, 3)
y = embedding(indices)

OneEmbedding is the large-scale Embedding solution that OneFlow provides for large-scale deep recommender systems. Compared with ordinary operators, OneEmbedding has the following advantages:

  • With flexible hierarchical storage, OneEmbedding can place the Embedding table in GPU memory, in CPU memory, or on SSD, and it can use faster devices as caches for slower ones to achieve both speed and capacity.

  • OneEmbedding supports dynamic expansion.
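
The caching idea behind hierarchical storage can be pictured with a small, framework-free sketch. The class below is hypothetical and is not how OneEmbedding is implemented. It keeps a bounded LRU cache (standing in for GPU memory) in front of a larger dictionary (standing in for CPU memory or SSD), and it creates a row the first time an unseen feature ID is looked up, which imitates dynamic expansion.

```python
import random
from collections import OrderedDict


class TieredEmbeddingStore:
    """Toy two-tier store: a small LRU cache in front of a large backing dict.

    Hypothetical illustration only; the names and behavior are assumptions,
    not OneEmbedding's actual implementation.
    """

    def __init__(self, embedding_dim, cache_rows, low=-0.1, high=0.1, seed=0):
        self.embedding_dim = embedding_dim
        self.cache_rows = cache_rows
        self.low, self.high = low, high
        self.rng = random.Random(seed)
        self.cache = OrderedDict()  # fast tier (stands in for GPU memory)
        self.backing = {}           # slow tier (stands in for host memory or SSD)
        self.hits = 0
        self.misses = 0

    def lookup(self, feature_id):
        if feature_id in self.cache:
            self.hits += 1
            self.cache.move_to_end(feature_id)  # mark as most recently used
            return self.cache[feature_id]
        self.misses += 1
        if feature_id not in self.backing:
            # Dynamic expansion: an unseen ID gets a freshly initialized row.
            self.backing[feature_id] = [
                self.rng.uniform(self.low, self.high)
                for _ in range(self.embedding_dim)
            ]
        row = self.backing[feature_id]
        self.cache[feature_id] = row
        if len(self.cache) > self.cache_rows:
            self.cache.popitem(last=False)  # evict the least recently used row
        return row


store = TieredEmbeddingStore(embedding_dim=3, cache_rows=2)
for fid in [1, 2, 1, 3, 2]:
    store.lookup(fid)
print(store.hits, store.misses, len(store.backing))
```

In this toy run the cache holds only two rows, so the second lookup of ID 2 misses the cache but still finds the row in the backing tier instead of re-initializing it.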

Note

Please refer to Large-Scale Embedding Solution: OneEmbedding for a brief introduction to all features related to OneEmbedding.

Configure Embedding Table

OneEmbedding supports creating multiple Embedding tables at once. The following code configures three Embedding tables.

import oneflow as flow
import oneflow.nn as nn
import numpy as np

tables = [
    flow.one_embedding.make_table_options(
        flow.one_embedding.make_uniform_initializer(low=-0.1, high=0.1)
    ),
    flow.one_embedding.make_table_options(
        flow.one_embedding.make_uniform_initializer(low=-0.05, high=0.05)
    ),
    flow.one_embedding.make_table_options(
        flow.one_embedding.make_uniform_initializer(low=-0.15, high=0.15)
    ),
]

When configuring an Embedding table, you need to specify its initialization method. The Embedding tables above all use uniform initialization. The resulting configuration is stored in the tables variable.
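
As a rough illustration of what a uniform initializer with given low and high bounds means, the sketch below draws one row for each of three tables using only the Python standard library. The helper init_row and the row width of 4 are assumptions made for the example; OneEmbedding performs the actual initialization internally.

```python
import random

# (low, high) bounds matching the three tables configured above.
table_bounds = [(-0.1, 0.1), (-0.05, 0.05), (-0.15, 0.15)]


def init_row(low, high, dim, rng):
    """Draw one embedding row uniformly from [low, high] (hypothetical helper)."""
    return [rng.uniform(low, high) for _ in range(dim)]


rng = random.Random(0)
rows = [init_row(low, high, dim=4, rng=rng) for low, high in table_bounds]
for (low, high), row in zip(table_bounds, rows):
    print(f"bounds=({low}, {high}) row={[round(v, 4) for v in row]}")
```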

oneflow.one_embedding.make_table_options(param)

Makes the table param for an Embedding table.

Parameters

param (dict or list) – an initializer or a list of column options. An initializer can be made by make_uniform_initializer, make_normal_initializer, or make_constant_initializer; column options can be made by make_column_options.

Returns

table param of Embedding tables

Return type

dict

For example:

>>> import oneflow as flow
>>> scale = 0.1  # example value; pick a range that suits your table
>>> initializer = flow.one_embedding.make_uniform_initializer(low=-scale, high=scale)
>>> table1 = flow.one_embedding.make_table_options(initializer)
>>> table2 = flow.one_embedding.make_table_options(initializer)
>>> tables = [table1, table2]
>>> # pass the tables to the "tables" param of flow.one_embedding.MultiTableEmbedding or flow.one_embedding.MultiTableMultiColumnEmbedding
>>> # ...

oneflow.one_embedding.make_table(param)

alias of oneflow.one_embedding.make_table_options

See also oneflow.one_embedding.make_table_options()

Initialization Method

make_uniform_initializer

Makes the uniform initializer param for make_table_options.

make_normal_initializer

Makes the normal initializer param for make_table_options.

Configure the Storage Attribute of the Embedding Table

Then run the following code to configure the storage attributes of the Embedding table:

store_options = flow.one_embedding.make_cached_ssd_store_options(
    cache_budget_mb=8142,
    persistent_path="/your_path_to_ssd",
    capacity=40000000,
    size_factor=1,
    physical_block_size=4096
)

Storage Method

make_device_mem_store_options

Makes the store_options param of MultiTableEmbedding for GPU-only storage.

make_cached_ssd_store_options

Makes the store_options param of MultiTableEmbedding for SSD storage that uses GPU and host memory as cache.

make_cached_host_mem_store_options

Makes the store_options param of MultiTableEmbedding for host memory storage that uses GPU memory as cache.

Note

Please refer to Large-Scale Embedding Solution: OneEmbedding to learn how to choose the proper storage configuration.

Instantiate Embedding

After the configuration above is complete, use MultiTableEmbedding to instantiate the Embedding layer.

embedding_size = 128
embedding = flow.one_embedding.MultiTableEmbedding(
    name="my_embedding",
    embedding_dim=embedding_size,
    dtype=flow.float,
    key_type=flow.int64,
    tables=tables,
    store_options=store_options,
)

embedding.to("cuda")
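
A quick back-of-the-envelope calculation shows how the cache budget compares with the full table size for the settings used in this walkthrough. The sketch assumes 4 bytes per value for flow.float, treats cache_budget_mb as mebibytes, and ignores any optimizer state that may be stored alongside each row, so the real numbers may differ.

```python
capacity = 40_000_000   # rows, from the store options above
embedding_size = 128    # values per row
bytes_per_value = 4     # assumption: flow.float is a 32-bit float
cache_budget_mb = 8142  # from the store options above

table_bytes = capacity * embedding_size * bytes_per_value
table_mib = table_bytes / (1024 * 1024)
cached_fraction = cache_budget_mb / table_mib

print(f"full table: {table_mib:.2f} MiB")
print(f"share that fits in the cache: {cached_fraction:.1%}")
```

Under these assumptions only part of the table fits in the cache, and the remainder would be served from the slower tier.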

Note

Please refer to Large-Scale Embedding Solution: OneEmbedding to learn about Feature ID and Multi-Table Query.

MultiTableEmbedding

oneflow.one_embedding.MultiTableEmbedding(name, embedding_dim, dtype, key_type, tables, store_options, default_initializer=None, padding_idx=None, seed=0)

MultiTableEmbedding represents multiple Embedding tables with the same embedding_dim, dtype, and key_type.

Parameters
  • name (str) – The name of Embedding

  • embedding_dim (int) – the size of each embedding vector

  • dtype (flow.dtype) – the data type of embeddings

  • key_type (flow.dtype) – the data type of feature ids

  • tables (list) – list of table param which can be made by flow.one_embedding.make_table_options

  • store_options (dict) – store option of Embedding

  • default_initializer (dict, optional) – if tables param is None, use default_initializer to initialize table. Defaults to None.

  • padding_idx (int, optional) – If specified, the entries at padding_idx do not contribute to the gradient, so the embedding vector at padding_idx is not updated during training. The embedding vector at padding_idx defaults to all zeros.
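
The effect of padding_idx can be pictured with a plain-Python sketch. The two functions below are hypothetical and only imitate the behavior described above: a lookup at padding_idx returns zeros, and the update step skips that ID.

```python
def lookup(table, ids, padding_idx, dim):
    """Return one row per ID; the padding ID maps to an all-zero vector."""
    return [[0.0] * dim if i == padding_idx else table[i] for i in ids]


def sgd_update(table, ids, grads, padding_idx, lr):
    """Apply a plain SGD step, skipping entries that belong to padding_idx."""
    for i, grad in zip(ids, grads):
        if i == padding_idx:
            continue  # padding entries do not contribute to the gradient
        table[i] = [w - lr * g for w, g in zip(table[i], grad)]


dim, padding_idx = 2, 0
table = {1: [0.5, -0.5], 2: [1.0, 1.0]}
ids = [1, 0, 2]

out = lookup(table, ids, padding_idx, dim)
sgd_update(table, ids, grads=[[1.0, 1.0]] * 3, padding_idx=padding_idx, lr=0.5)
print(out)
print(table)
```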

For example:

>>> import oneflow as flow
>>> import numpy as np
>>> import oneflow.nn as nn
>>> # a simple example with 3 tables
>>> table_size_array = [39884407, 39043, 17289]
>>> vocab_size = sum(table_size_array)
>>> num_tables = len(table_size_array)
>>> embedding_size = 128
>>> scales = np.sqrt(1 / np.array(table_size_array))
>>> tables = [
>>>     flow.one_embedding.make_table_options(
>>>         flow.one_embedding.make_uniform_initializer(low=-scale, high=scale)
>>>     )
>>>     for scale in scales
>>> ]
>>> store_options = flow.one_embedding.make_cached_ssd_store_options(
>>>     cache_budget_mb=8192, persistent_path="/your_path_to_ssd", capacity=vocab_size,
>>> )
>>> embedding = flow.one_embedding.MultiTableEmbedding(
>>>     name="my_embedding",
>>>     embedding_dim=embedding_size,
>>>     dtype=flow.float,
>>>     key_type=flow.int64,
>>>     tables=tables,
>>>     store_options=store_options,
>>> )
>>> embedding.to("cuda")
>>> mlp = flow.nn.FusedMLP(
>>>     in_features=embedding_size * num_tables,
>>>     hidden_features=[512, 256, 128],
>>>     out_features=1,
>>>     skip_final_activation=True,
>>> )
>>> mlp.to("cuda")
>>>
>>> class TrainGraph(flow.nn.Graph):
>>>     def __init__(self,):
>>>         super().__init__()
>>>         self.embedding_lookup = embedding
>>>         self.mlp = mlp
>>>         self.add_optimizer(
>>>             flow.optim.SGD(self.embedding_lookup.parameters(), lr=0.1, momentum=0.0)
>>>         )
>>>         self.add_optimizer(
>>>             flow.optim.SGD(self.mlp.parameters(), lr=0.1, momentum=0.0)
>>>         )
>>>     def build(self, ids):
>>>         embedding = self.embedding_lookup(ids)
>>>         loss = self.mlp(flow.reshape(embedding, (-1, num_tables * embedding_size)))
>>>         loss = loss.sum()
>>>         loss.backward()
>>>         return loss
>>> ids = np.random.randint(0, 1000, (100, num_tables), dtype=np.int64)
>>> ids_tensor = flow.tensor(ids, requires_grad=False).to("cuda")
>>> graph = TrainGraph()
>>> loss = graph(ids_tensor)
>>> print(loss)
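
The reshape in build relies on the lookup returning one vector per table for every sample. The sketch below mimics that layout with nested lists. Treating column t of ids as a query into table t is an assumption made for illustration, and toy_row and the tiny sizes are hypothetical.

```python
num_tables = 3
embedding_size = 4  # kept small here; the example above uses 128


def toy_row(table_index, feature_id):
    """Hypothetical stand-in for a table lookup: a constant vector per (table, ID)."""
    return [table_index + feature_id / 10] * embedding_size


# Shape (batch_size, num_tables): one feature ID per table for every sample.
ids = [[1, 2, 3],
       [4, 5, 6]]

# Shape (batch_size, num_tables, embedding_size).
embedded = [
    [toy_row(t, sample[t]) for t in range(num_tables)]
    for sample in ids
]

# Flatten the last two axes, which is the role of
# flow.reshape(embedding, (-1, num_tables * embedding_size)) in build.
flat = [[v for vec in sample for v in vec] for sample in embedded]

print(len(flat), len(flat[0]))
```

Each flattened row has num_tables * embedding_size values, which is why the MLP in the example is created with in_features=embedding_size * num_tables.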

forward

Embedding lookup operation

save_snapshot

save snapshot

load_snapshot

load snapshot

MultiTableMultiColumnEmbedding

oneflow.one_embedding.MultiTableMultiColumnEmbedding(name, embedding_dim, dtype, key_type, tables, store_options, default_initializer=None, padding_idx=None, seed=0)

MultiTableMultiColumnEmbedding represents multiple Embedding tables that have multiple columns, each with its own embedding_dim, and that share the same dtype and key_type.

Parameters
  • name (str) – The name of Embedding

  • embedding_dim (list) – list of the size of each embedding vector

  • dtype (flow.dtype) – the data type of embeddings

  • key_type (flow.dtype) – the data type of feature ids

  • tables (list) – list of table param which can be made by flow.one_embedding.make_table_options

  • store_options (dict) – store option of Embedding

  • default_initializer (dict, optional) – if tables param is None, use default_initializer to initialize table. Defaults to None.

  • padding_idx (int, optional) – If specified, the entries at padding_idx do not contribute to the gradient, so the embedding vector at padding_idx is not updated during training. The embedding vector at padding_idx defaults to all zeros.

For example:

>>> import oneflow as flow
>>> import numpy as np
>>> import oneflow.nn as nn
>>> # a simple example with 3 tables; every table has two columns, the first with embedding size 10 and the second with embedding size 1.
>>> # the first column of every table is initialized with uniform(-1/sqrt(table_size), 1/sqrt(table_size)), the second with normal(0, 1/sqrt(table_size))
>>> table_size_array = [39884407, 39043, 17289]
>>> vocab_size = sum(table_size_array)
>>> num_tables = len(table_size_array)
>>> embedding_size_list = [10, 1]
>>> scales = np.sqrt(1 / np.array(table_size_array))
>>> tables = [
>>>     flow.one_embedding.make_table_options(
>>>       [flow.one_embedding.make_column_options(
>>>         flow.one_embedding.make_uniform_initializer(low=-scale, high=scale)),
>>>        flow.one_embedding.make_column_options(
>>>         flow.one_embedding.make_normal_initializer(mean=0, std=scale))]
>>>     )
>>>     for scale in scales
>>> ]
>>> store_options = flow.one_embedding.make_cached_ssd_store_options(
>>>     cache_budget_mb=8192, persistent_path="/your_path_to_ssd", capacity=vocab_size,
>>> )
>>> embedding = flow.one_embedding.MultiTableMultiColumnEmbedding(
>>>     name="my_embedding",
>>>     embedding_dim=embedding_size_list,
>>>     dtype=flow.float,
>>>     key_type=flow.int64,
>>>     tables=tables,
>>>     store_options=store_options,
>>> )
>>> embedding.to("cuda")
>>> mlp = flow.nn.FusedMLP(
>>>     in_features=sum(embedding_size_list) * num_tables,
>>>     hidden_features=[512, 256, 128],
>>>     out_features=1,
>>>     skip_final_activation=True,
>>> )
>>> mlp.to("cuda")
>>>
>>> class TrainGraph(flow.nn.Graph):
>>>     def __init__(self,):
>>>         super().__init__()
>>>         self.embedding_lookup = embedding
>>>         self.mlp = mlp
>>>         self.add_optimizer(
>>>             flow.optim.SGD(self.embedding_lookup.parameters(), lr=0.1, momentum=0.0)
>>>         )
>>>         self.add_optimizer(
>>>             flow.optim.SGD(self.mlp.parameters(), lr=0.1, momentum=0.0)
>>>         )
>>>     def build(self, ids):
>>>         embedding = self.embedding_lookup(ids)
>>>         loss = self.mlp(flow.reshape(embedding, (-1, num_tables * sum(embedding_size_list))))
>>>         loss = loss.sum()
>>>         loss.backward()
>>>         return loss
>>> ids = np.random.randint(0, 1000, (100, num_tables), dtype=np.int64)
>>> ids_tensor = flow.tensor(ids, requires_grad=False).to("cuda")
>>> graph = TrainGraph()
>>> loss = graph(ids_tensor)
>>> print(loss)
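
With several columns per table, the example above sizes the MLP input as sum(embedding_size_list) * num_tables, which suggests that each looked-up row is the concatenation of its columns. The sketch below checks that arithmetic and builds one such row per table with the standard library. The helper make_row is hypothetical, and its initialization only imitates the uniform and normal settings from the example.

```python
import math
import random

table_size_array = [39884407, 39043, 17289]
embedding_size_list = [10, 1]
num_tables = len(table_size_array)

row_width = sum(embedding_size_list)  # values per feature ID across all columns
in_features = row_width * num_tables  # width of the MLP input for one sample


def make_row(table_size, rng):
    """Build one row: a uniform column of width 10, then a normal column of width 1."""
    scale = math.sqrt(1 / table_size)
    first = [rng.uniform(-scale, scale) for _ in range(embedding_size_list[0])]
    second = [rng.gauss(0.0, scale) for _ in range(embedding_size_list[1])]
    return first + second


rng = random.Random(0)
rows = [make_row(size, rng) for size in table_size_array]
print(row_width, in_features, [len(r) for r in rows])
```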

forward

Embedding lookup operation

save_snapshot

save snapshot

load_snapshot

load snapshot

Construct Graph for Training

OneEmbedding is only supported in Graph mode.

num_tables = 3
mlp = flow.nn.FusedMLP(
    in_features=embedding_size * num_tables,
    hidden_features=[512, 256, 128],
    out_features=1,
    skip_final_activation=True,
)
mlp.to("cuda")

class TrainGraph(flow.nn.Graph):
    def __init__(self,):
        super().__init__()
        self.embedding_lookup = embedding
        self.mlp = mlp
        self.add_optimizer(
            flow.optim.SGD(self.embedding_lookup.parameters(), lr=0.1, momentum=0.0)
        )
        self.add_optimizer(
            flow.optim.SGD(self.mlp.parameters(), lr=0.1, momentum=0.0)
        )
    def build(self, ids):
        embedding = self.embedding_lookup(ids)
        loss = self.mlp(flow.reshape(embedding, (-1, num_tables * embedding_size)))
        loss = loss.sum()
        loss.backward()
        return loss

Note

Please refer to Distributed Training: OneEmbedding to learn about Graph For Training.

Persistent Read & Write

make_persistent_table_reader

Creates a reader for reading persistent table.

make_persistent_table_writer

Creates a writer for writing persistent table.

class oneflow.one_embedding.Ftrl(params: Union[Iterator[oneflow.nn.Parameter], List[Dict]], lr: float = 0.001, weight_decay: float = 0.0, lr_power: float = -0.5, initial_accumulator_value: float = 0.1, lambda1: float = 0.0, lambda2: float = 0.0, beta: float = 0.0)

FTRL Optimizer.

The formula is:

\[\begin{aligned}
\text{accumulator}_{i+1} &= \text{accumulator}_{i} + \text{grad} * \text{grad} \\
\sigma &= \left(\text{accumulator}_{i+1}^{\text{lr\_power}} - \text{accumulator}_{i}^{\text{lr\_power}}\right) / \text{learning\_rate} \\
z_{i+1} &= z_{i} + \text{grad} - \sigma * \text{param}_{i} \\
\text{param}_{i+1} &= \begin{cases} 0 & \text{if } |z_{i+1}| < \lambda_1 \\ -\left(\frac{\beta + \text{accumulator}_{i+1}^{\text{lr\_power}}}{\text{learning\_rate}} + \lambda_2\right) * \left(z_{i+1} - \operatorname{sign}(z_{i+1}) * \lambda_1\right) & \text{otherwise} \end{cases}
\end{aligned}\]

Example 1:

# Assume net is a custom model.
ftrl = flow.one_embedding.Ftrl(net.parameters(), lr=1e-3)

for epoch in range(epochs):
    # Read data, compute the loss, and so on.
    # ...
    loss.backward()
    ftrl.step()
    ftrl.zero_grad()

Parameters
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

  • lr (float, optional) – learning rate. Defaults to 1e-3.

  • weight_decay (float, optional) – weight decay (L2 penalty). Defaults to 0.0.

  • lr_power (float, optional) – learning rate decrease factor. Defaults to -0.5.

  • initial_accumulator_value (float, optional) – The initial value of the accumulator. Defaults to 0.1.

  • lambda1 (float, optional) – L1 regularization strength. Defaults to 0.0.

  • lambda2 (float, optional) – L2 regularization strength. Defaults to 0.0.

  • beta (float, optional) – The value of beta. Defaults to 0.0.

step(closure: Optional[Callable] = None)

Performs a single optimization step.

Parameters

closure (callable, optional) – A closure that reevaluates the model and returns the loss.

property support_sparse

Whether the optimizer supports sparse updates.