!8154 Add nn.MultiFieldEmbedding for the embedding lookup opearations

From: @huangxinjing
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-11-24 09:49:39 +08:00 committed by Gitee
commit 3d6d820612
3 changed files with 327 additions and 1 deletions

View File

@ -239,6 +239,7 @@ def create_group(group, rank_ids):
ValueError: If `rank_ids` size is not larger than 1, or `rank_ids` has duplicate data, or backend is invalid.
RuntimeError: If hccl/nccl is not available or nccl not supports.
Examples:
>>> init()
>>> group = "0-1"
>>> rank_ids = [0,1]
>>> create_group(group, rank_ids)

View File

@ -26,9 +26,10 @@ from mindspore._checkparam import Rel
from mindspore._checkparam import Validator as validator
from mindspore.ops.primitive import constexpr
from .basic import ClipByNorm
from .math import Range
from ..cell import Cell
__all__ = ['Embedding', 'EmbeddingLookup']
__all__ = ['Embedding', 'EmbeddingLookup', 'MultiFieldEmbeddingLookup']
class Embedding(Cell):
@ -268,3 +269,190 @@ class EmbeddingLookup(Cell):
clip_by_norm = ClipByNorm(axis)
out = clip_by_norm(out, self.max_norm)
return out
class MultiFieldEmbeddingLookup(EmbeddingLookup):
r"""
Returns a slice of input tensor based on the specified indices based on the filed ids. This operation
supports looking up embeddings within multi hot and one hot fields simultaneously.
Note:
When 'target' is set to 'CPU', this module will use
P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which
specified 'offset = 0' to lookup table.
When 'target' is set to 'DEVICE', this module will use P.GatherV2() which
specified 'axis = 0' to lookup table.
The vectors with the same field_ids will be combined by the `operator`, such as `SUM`, `MAX` and
`MEAN`. Ensure the input_values of the padded id is zero, so that they can be ignored. The final
output will be zeros if the summed of absolute weight of the field is zero. This class only
supports ['table_row_slice', 'batch_slice' and 'table_column_slice']
Args:
vocab_size (int): Size of the dictionary of embeddings.
embedding_size (int): The size of each embedding vector.
field_size (int): The field size of the final outputs.
param_init (str): The initialize way of embedding table. Default: 'normal'.
target (str): Specifies the target where the op is executed. The value must in
['DEVICE', 'CPU']. Default: 'CPU'.
slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must get through
nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE.
feature_num_list (tuple): The accompaniment array in field slice mode.
max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
or None. Default: None
sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True.
operator (string): The pooling method for the features in one field. Support `SUM`, `MEAN` and 'MAX'
Inputs:
- **input_indices** (Tensor) - The shape of tensor is :math:`(batch_size, seq_length)`.
Specifies the indices of elements of the original Tensor. Values can be out of range of embedding_table,
and the exceeding part will be filled with 0 in the output. Input_indices must be a 2d tensor in
this interface. Type is Int16, Int32, Int64.
- **input_values** (Tensor) - The shape of tensor is :math:`(batch_size, seq_length)`.
Specifies the weights of elements of the input_indices. The lookout vector will multiply with
the input_values. Type is Float32.
- **field_ids** (Tensor) - The shape of tensor is :math:`(batch_size, seq_length)`.
Specifics the field id of elements of the input_indices. Type is Type is Int16, Int32, Int64.
Outputs:
Tensor, the shape of tensor is :math:`(batch_size, field_size, embedding_size)`. Type is Float32.
Examples:
>>> input_indices = Tensor([[2, 4, 6, 0, 0], [1, 3, 5, 0, 0]], mindspore.int32)
>>> input_values = Tensor([[1, 1, 1, 0, 0], [1, 1, 1, 0, 0]], mindspore.float32)
>>> field_ids = Tensor([[0, 1, 1, 0, 0], [0, 0, 1, 0, 0]], mindspore.int32)
>>> net = nn.MultiFieldEmbeddingLookup(10, 2, field_size=2, operator='SUM')
>>> out = net(input_indices, input_values, field_ids)
>>> print(result)
[[[-0.00478983 -0.00772568]
[-0.00968955 -0.00064902]]
[[-0.01251151 -0.01251151]
[-0.00196387 -0.00196387]
"""
OPERATOR_SUM = 'SUM'
OPERATOR_MEAN = 'MEAN'
OPERATOR_MAX = 'MAX'
def __init__(self, vocab_size, embedding_size, field_size, param_init='normal', target='CPU',
slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM'):
super(MultiFieldEmbeddingLookup, self).__init__(vocab_size, embedding_size, param_init, target,
slice_mode, feature_num_list, max_norm, sparse)
self.field_size = field_size
self.operator = operator
self.mul = P.Mul()
self.inf_mask_mul = P.Mul()
self.bias_add = P.TensorAdd()
self.inf_add = P.TensorAdd()
self.merge_op = None
self.count_op = P.UnsortedSegmentSum()
self.abs = P.Abs()
self.equal = P.Equal()
self.add = P.TensorAdd()
self.cast = P.Cast()
self.div_no_nan = P.DivNoNan()
self.expand = P.ExpandDims()
self.max_mask_mul = P.Mul()
self.max_no_equal = P.NotEqual()
if operator == MultiFieldEmbeddingLookup.OPERATOR_SUM:
self.merge_op = P.UnsortedSegmentSum()
elif operator == MultiFieldEmbeddingLookup.OPERATOR_MAX:
self.merge_op = P.UnsortedSegmentMax()
elif operator == MultiFieldEmbeddingLookup.OPERATOR_MEAN:
self.merge_op = P.UnsortedSegmentSum()
else:
raise ValueError("The operator supports ['SUM', 'MAX', 'MEAN'], but found: "+str(operator))
parallel_mode = _get_parallel_mode()
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
if slice_mode in ["table_row_slice", "batch_slice"] and is_auto_parallel:
self.merge_op.shard(((get_group_size(), 1, 1), (get_group_size(), 1)))
self.expand.shard(((get_group_size(),),))
self.bias_add.shard(((1, 1), (1, 1)))
self.mul.shard(((get_group_size(), 1, 1), (get_group_size(), 1, 1)))
self.count_op.shard(((get_group_size(), 1), (get_group_size(), 1)))
self.add.shard(((get_group_size(),), (get_group_size(),)))
self.div_no_nan.shard(((get_group_size(), 1), (get_group_size(), 1)))
self.max_mask_mul.shard(((get_group_size(), 1), (get_group_size(), 1)))
self.max_no_equal.shard(((1,), ()))
if operator == MultiFieldEmbeddingLookup.OPERATOR_MAX:
self.equal.shard(((get_group_size(), 1, 1), ()))
self.inf_mask_mul.shard(((get_group_size(), 1, 1), ()))
self.merge_op.shard(((get_group_size(), 1), (get_group_size(),)))
self.count_op.shard(((get_group_size(),), (get_group_size(),)))
self.inf_add.shard(((get_group_size(), 1, 1), (get_group_size(), 1, 1)))
elif slice_mode == "table_column_slice" and is_auto_parallel:
self.merge_op.shard(((1, 1, get_group_size()), (1, 1)))
self.div_no_nan.shard(((1, get_group_size()), (1, 1)))
self.bias_add.shard(((1, 1), (1, 1)))
self.mul.shard(((1, 1, 1), (1, 1, get_group_size())))
self.count_op.shard(((1, 1), (1, 1)))
self.add.shard(((1,), (1,)))
self.max_mask_mul.shard(((1, get_group_size()), (1, 1)))
self.expand.shard(((1,),))
self.max_no_equal.shard(((1,), ()))
if operator == MultiFieldEmbeddingLookup.OPERATOR_MAX:
self.equal.shard(((1, 1, 1), ()))
self.inf_mask_mul.shard(((1, 1, 1), ()))
self.merge_op.shard(((1, get_group_size()), (1,)))
self.count_op.shard(((1,), (1,)))
self.inf_add.shard(((1, 1, get_group_size()), (1, 1, 1)))
else:
if is_auto_parallel:
raise ValueError("slice_mode should be ['table_row_slice', 'batch_slice' and \
'table_column_slice'], but get " + str(slice_mode))
# Min value for fp32
self.negative_inf_value = -3.402823466E+38
def construct(self, input_indices, input_values, field_ids):
batch_size = self.shape(input_indices)[0]
num_segments = batch_size * self.field_size
bias = Range(0, num_segments, self.field_size)()
bias = self.reshape(bias, (self.field_size, -1))
field_ids = self.bias_add(field_ids, bias)
if self.target == "CPU":
out = self.embeddinglookup(self.embedding_table, input_indices, 0)
else:
if self.forward_unique:
shp = self.shape(input_indices) + (self.embedding_size,)
indices_flatten = self.reshape(input_indices, (-1,))
unique_id, unique_idx = self.unique(indices_flatten)
weight_unique = self.gatherv2(self.embedding_table, unique_id, 0)
weight_flatten = self.gather_revert(weight_unique, unique_idx, 0)
out = self.reshape(weight_flatten, shp)
else:
out = self.gatherv2(self.embedding_table, input_indices, 0)
if self.max_norm is not None:
axis = _make_axis_range(F.rank(input_indices), F.rank(out))
clip_by_norm = ClipByNorm(axis)
out = clip_by_norm(out, self.max_norm)
weights = self.reshape(input_values, (batch_size, self.shape(input_indices)[1], 1))
embedding = self.mul(weights, out)
if self.operator == 'MAX':
# Fill the padding value to -inf, so the padded value will not influence the results
negatvie_inf_mask = self.cast(self.equal(weights, 0), mstype.float32)
inf_mask = self.inf_mask_mul(negatvie_inf_mask, self.negative_inf_value)
embedding = self.inf_add(embedding, inf_mask)
embedding = self.reshape(embedding, (-1, self.embedding_size))
field_ids = self.reshape(field_ids, (-1,))
merged_vectors = self.merge_op(embedding, field_ids, num_segments)
if self.operator == 'MAX':
value_count = self.count_op(self.abs(self.reshape(input_values, (-1,))), field_ids, num_segments)
value_zeros = self.cast(self.max_no_equal(value_count, 0.0), mstype.float32)
count = self.expand(value_zeros, -1)
merged_vectors = self.max_mask_mul(merged_vectors, count)
if self.operator == 'MEAN':
value_count = self.count_op(self.abs(input_values), field_ids, num_segments)
value_count = self.expand(value_count, -1)
merged_vectors = self.div_no_nan(merged_vectors, value_count)
merged_vectors = self.reshape(merged_vectors, (batch_size, self.field_size, -1))
return merged_vectors

View File

@ -0,0 +1,137 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore.common.api import _executor
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore import Tensor, context
from mindspore.nn import TrainOneStepCell, Adam
from tests.ut.python.ops.test_math_ops import VirtualLoss
grad_all = C.GradOperation(get_all=True)
class GradWrap(nn.Cell):
def __init__(self, network):
super(GradWrap, self).__init__()
self.network = network
def construct(self, x, y, z):
return grad_all(self.network)(x, y, z)
class NetWithLoss(nn.Cell):
def __init__(self, network):
super(NetWithLoss, self).__init__()
self.loss = VirtualLoss()
self.network = network
def construct(self, x, y, z):
predict = self.network(x, y, z)
return self.loss(predict)
class Net(nn.Cell):
def __init__(self, shape, slice_mode=nn.EmbeddingLookup.BATCH_SLICE, target="Device", operator='SUM'):
super().__init__()
self.embedding = nn.MultiFieldEmbeddingLookup(vocab_size=32, embedding_size=64, target=target,
field_size=shape[1], slice_mode=slice_mode, operator=operator)
self.reshape = P.Reshape().shard(((8, 1, 1),))
self.batch_size = shape[0]
def construct(self, x, y, z):
out = self.embedding(x, y, z)
out = self.reshape(out, (self.batch_size, -1))
return out
def compile_net(net, shape):
context.set_context(enable_sparse=True)
x = Tensor(np.ones(shape), dtype=ms.int32)
y = Tensor(np.ones(shape), dtype=ms.float32)
z = Tensor(np.ones(shape), dtype=ms.int32)
optimizer = Adam(net.trainable_params(), learning_rate=0.1)
train_net = TrainOneStepCell(net, optimizer)
train_net.set_auto_parallel()
train_net.set_train()
_executor.compile(train_net, x, y, z)
context.reset_auto_parallel_context()
def test_embeddinglookup_batch_parallel_sum():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
shape = [64, 64]
net = NetWithLoss(Net(shape, target='DEVICE'))
compile_net(net, shape)
def test_embeddinglookup_row_parallel_sum():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
shape = [64, 64]
net = NetWithLoss(Net(shape, slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE, target='DEVICE'))
compile_net(net, shape)
def test_embeddinglookup_column_parallel_sum():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
shape = [64, 64]
net = NetWithLoss(Net(shape, slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE, target='DEVICE'))
compile_net(net, shape)
def test_embeddinglookup_batch_parallel_mean():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
shape = [64, 64]
net = NetWithLoss(Net(shape, target='DEVICE', operator='MEAN'))
compile_net(net, shape)
def test_embeddinglookup_column_parallel_mean():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
shape = [64, 64]
net = NetWithLoss(Net(shape, target='DEVICE', slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE, operator='MEAN'))
compile_net(net, shape)
def test_embeddinglookup_row_parallel_mean():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
shape = [64, 64]
net = NetWithLoss(Net(shape, target='DEVICE', slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE, operator='MEAN'))
compile_net(net, shape)
def test_embeddinglookup_batch_parallel_max():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
shape = [64, 64]
net = NetWithLoss(Net(shape, target='DEVICE', operator='MAX'))
compile_net(net, shape)
def test_embeddinglookup_column_parallel_max():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
shape = [64, 64]
net = NetWithLoss(Net(shape, target='DEVICE', slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE, operator='MAX'))
compile_net(net, shape)
def test_embeddinglookup_row_parallel_max():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
shape = [64, 64]
net = NetWithLoss(Net(shape, target='DEVICE', slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE, operator='MAX'))
compile_net(net, shape)