forked from mindspore-Ecosystem/mindspore
commit 3d6d820612
!8154 Add nn.MultiFieldEmbedding for the embedding lookup operations
From: @huangxinjing
@@ -239,6 +239,7 @@ def create_group(group, rank_ids):
        ValueError: If `rank_ids` size is not larger than 1, or `rank_ids` has duplicate data, or backend is invalid.
        RuntimeError: If hccl/nccl is not available or the backend does not support the operation.

    Examples:
        >>> init()
        >>> group = "0-1"
        >>> rank_ids = [0, 1]
        >>> create_group(group, rank_ids)
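For context, a minimal sketch of how such a group is typically consumed; the AllReduce usage below is illustrative and not part of this diff:

    # Hypothetical usage sketch: build a 2-rank group, then restrict a
    # collective to it. Assumes a distributed launch with hccl/nccl available.
    from mindspore.communication.management import init, create_group
    from mindspore.ops import operations as P

    init()                                 # initialize the hccl/nccl backend
    group = "0-1"
    create_group(group, [0, 1])            # ValueError on <2 or duplicate ranks
    all_reduce = P.AllReduce(group=group)  # collective limited to ranks 0 and 1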
@@ -26,9 +26,10 @@ from mindspore._checkparam import Rel
 from mindspore._checkparam import Validator as validator
 from mindspore.ops.primitive import constexpr
 from .basic import ClipByNorm
+from .math import Range
 from ..cell import Cell

-__all__ = ['Embedding', 'EmbeddingLookup']
+__all__ = ['Embedding', 'EmbeddingLookup', 'MultiFieldEmbeddingLookup']


 class Embedding(Cell):
@@ -268,3 +269,190 @@ class EmbeddingLookup(Cell):
            clip_by_norm = ClipByNorm(axis)
            out = clip_by_norm(out, self.max_norm)
        return out


class MultiFieldEmbeddingLookup(EmbeddingLookup):
    r"""
    Returns a slice of the input tensor based on the specified indices and field ids. This operation
    supports looking up embeddings over multi-hot and one-hot fields simultaneously.

    Note:
        When 'target' is set to 'CPU', this module will use
        P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU'), which
        specifies 'offset = 0' to look up the table.
        When 'target' is set to 'DEVICE', this module will use P.GatherV2(), which
        specifies 'axis = 0' to look up the table.
        The vectors with the same field_ids will be combined by the `operator`, i.e. `SUM`, `MAX` or
        `MEAN`. Ensure that the input_values of the padded ids are zero so that they can be ignored.
        The final output will be zeros if the sum of the absolute weights of a field is zero. This class
        only supports ['table_row_slice', 'batch_slice', 'table_column_slice'].

    Args:
        vocab_size (int): The size of the dictionary of embeddings.
        embedding_size (int): The size of each embedding vector.
        field_size (int): The field size of the final outputs.
        param_init (str): The initialization method for the embedding table. Default: 'normal'.
        target (str): Specifies the target where the op is executed. The value must be in
            ['DEVICE', 'CPU']. Default: 'CPU'.
        slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must be obtained
            through nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE.
        feature_num_list (tuple): The accompaniment array in field slice mode. Default: None.
        max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
            or None. Default: None.
        sparse (bool): Use sparse mode. When 'target' is set to 'CPU', 'sparse' must be True. Default: True.
        operator (str): The pooling method for the features in one field. Supports 'SUM', 'MEAN' and 'MAX'.
            Default: 'SUM'.

    Inputs:
        - **input_indices** (Tensor) - The shape of the tensor is :math:`(batch_size, seq_length)`.
          Specifies the indices of elements of the original Tensor. Values can be out of the range of the
          embedding_table, and the exceeding part will be filled with 0 in the output. input_indices must
          be a 2-D tensor in this interface. Type is Int16, Int32 or Int64.
        - **input_values** (Tensor) - The shape of the tensor is :math:`(batch_size, seq_length)`.
          Specifies the weights of elements of the input_indices. The looked-up vectors will be multiplied
          by the input_values. Type is Float32.
        - **field_ids** (Tensor) - The shape of the tensor is :math:`(batch_size, seq_length)`.
          Specifies the field id of elements of the input_indices. Type is Int16, Int32 or Int64.

    Outputs:
        Tensor, the shape of the tensor is :math:`(batch_size, field_size, embedding_size)`. Type is Float32.

    Examples:
        >>> input_indices = Tensor([[2, 4, 6, 0, 0], [1, 3, 5, 0, 0]], mindspore.int32)
        >>> input_values = Tensor([[1, 1, 1, 0, 0], [1, 1, 1, 0, 0]], mindspore.float32)
        >>> field_ids = Tensor([[0, 1, 1, 0, 0], [0, 0, 1, 0, 0]], mindspore.int32)
        >>> net = nn.MultiFieldEmbeddingLookup(10, 2, field_size=2, operator='SUM')
        >>> out = net(input_indices, input_values, field_ids)
        >>> print(out)
        [[[-0.00478983 -0.00772568]
          [-0.00968955 -0.00064902]]
         [[-0.01251151 -0.01251151]
          [-0.00196387 -0.00196387]]]
    """
    OPERATOR_SUM = 'SUM'
    OPERATOR_MEAN = 'MEAN'
    OPERATOR_MAX = 'MAX'

    def __init__(self, vocab_size, embedding_size, field_size, param_init='normal', target='CPU',
                 slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM'):
        super(MultiFieldEmbeddingLookup, self).__init__(vocab_size, embedding_size, param_init, target,
                                                        slice_mode, feature_num_list, max_norm, sparse)
        self.field_size = field_size
        self.operator = operator

        self.mul = P.Mul()
        self.inf_mask_mul = P.Mul()
        self.bias_add = P.TensorAdd()
        self.inf_add = P.TensorAdd()
        self.merge_op = None
        self.count_op = P.UnsortedSegmentSum()
        self.abs = P.Abs()
        self.equal = P.Equal()
        self.add = P.TensorAdd()
        self.cast = P.Cast()
        self.div_no_nan = P.DivNoNan()
        self.expand = P.ExpandDims()
        self.max_mask_mul = P.Mul()
        self.max_no_equal = P.NotEqual()

        if operator == MultiFieldEmbeddingLookup.OPERATOR_SUM:
            self.merge_op = P.UnsortedSegmentSum()
        elif operator == MultiFieldEmbeddingLookup.OPERATOR_MAX:
            self.merge_op = P.UnsortedSegmentMax()
        elif operator == MultiFieldEmbeddingLookup.OPERATOR_MEAN:
            self.merge_op = P.UnsortedSegmentSum()
        else:
            raise ValueError("The operator supports ['SUM', 'MAX', 'MEAN'], but found: " + str(operator))

        parallel_mode = _get_parallel_mode()
        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
        if slice_mode in ["table_row_slice", "batch_slice"] and is_auto_parallel:
            self.merge_op.shard(((get_group_size(), 1, 1), (get_group_size(), 1)))
            self.expand.shard(((get_group_size(),),))
            self.bias_add.shard(((1, 1), (1, 1)))
            self.mul.shard(((get_group_size(), 1, 1), (get_group_size(), 1, 1)))
            self.count_op.shard(((get_group_size(), 1), (get_group_size(), 1)))
            self.add.shard(((get_group_size(),), (get_group_size(),)))
            self.div_no_nan.shard(((get_group_size(), 1), (get_group_size(), 1)))
            self.max_mask_mul.shard(((get_group_size(), 1), (get_group_size(), 1)))
            self.max_no_equal.shard(((1,), ()))
            if operator == MultiFieldEmbeddingLookup.OPERATOR_MAX:
                self.equal.shard(((get_group_size(), 1, 1), ()))
                self.inf_mask_mul.shard(((get_group_size(), 1, 1), ()))
                self.merge_op.shard(((get_group_size(), 1), (get_group_size(),)))
                self.count_op.shard(((get_group_size(),), (get_group_size(),)))
                self.inf_add.shard(((get_group_size(), 1, 1), (get_group_size(), 1, 1)))
        elif slice_mode == "table_column_slice" and is_auto_parallel:
            self.merge_op.shard(((1, 1, get_group_size()), (1, 1)))
            self.div_no_nan.shard(((1, get_group_size()), (1, 1)))
            self.bias_add.shard(((1, 1), (1, 1)))
            self.mul.shard(((1, 1, 1), (1, 1, get_group_size())))
            self.count_op.shard(((1, 1), (1, 1)))
            self.add.shard(((1,), (1,)))
            self.max_mask_mul.shard(((1, get_group_size()), (1, 1)))
            self.expand.shard(((1,),))
            self.max_no_equal.shard(((1,), ()))
            if operator == MultiFieldEmbeddingLookup.OPERATOR_MAX:
                self.equal.shard(((1, 1, 1), ()))
                self.inf_mask_mul.shard(((1, 1, 1), ()))
                self.merge_op.shard(((1, get_group_size()), (1,)))
                self.count_op.shard(((1,), (1,)))
                self.inf_add.shard(((1, 1, get_group_size()), (1, 1, 1)))
        else:
            if is_auto_parallel:
                raise ValueError("slice_mode should be in ['table_row_slice', 'batch_slice', "
                                 "'table_column_slice'], but got " + str(slice_mode))

        # Minimum finite value for float32; used to mask padded positions for MAX pooling.
        self.negative_inf_value = -3.402823466E+38

    def construct(self, input_indices, input_values, field_ids):

        batch_size = self.shape(input_indices)[0]
        num_segments = batch_size * self.field_size
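
        # Give every sample its own block of `field_size` segment ids: adding a
        # per-sample offset (multiples of field_size) to field_ids makes the ids
        # unique across the batch, so one unsorted-segment op can pool all rows.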
        bias = Range(0, num_segments, self.field_size)()
        bias = self.reshape(bias, (self.field_size, -1))
        field_ids = self.bias_add(field_ids, bias)

        if self.target == "CPU":
            out = self.embeddinglookup(self.embedding_table, input_indices, 0)
        else:
            if self.forward_unique:
                shp = self.shape(input_indices) + (self.embedding_size,)
                indices_flatten = self.reshape(input_indices, (-1,))
                unique_id, unique_idx = self.unique(indices_flatten)
                weight_unique = self.gatherv2(self.embedding_table, unique_id, 0)
                weight_flatten = self.gather_revert(weight_unique, unique_idx, 0)
                out = self.reshape(weight_flatten, shp)
            else:
                out = self.gatherv2(self.embedding_table, input_indices, 0)
        if self.max_norm is not None:
            axis = _make_axis_range(F.rank(input_indices), F.rank(out))
            clip_by_norm = ClipByNorm(axis)
            out = clip_by_norm(out, self.max_norm)

        weights = self.reshape(input_values, (batch_size, self.shape(input_indices)[1], 1))
        embedding = self.mul(weights, out)

        if self.operator == 'MAX':
            # Fill the padded positions with -inf so they cannot win the MAX reduction.
            negative_inf_mask = self.cast(self.equal(weights, 0), mstype.float32)
            inf_mask = self.inf_mask_mul(negative_inf_mask, self.negative_inf_value)
            embedding = self.inf_add(embedding, inf_mask)
        embedding = self.reshape(embedding, (-1, self.embedding_size))
        field_ids = self.reshape(field_ids, (-1,))

        merged_vectors = self.merge_op(embedding, field_ids, num_segments)
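
        # MAX only: a field whose weights are all zero was filled with -inf above;
        # mask it back to zero using an indicator that the summed |weights| of the
        # field differ from zero.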
        if self.operator == 'MAX':
            value_count = self.count_op(self.abs(self.reshape(input_values, (-1,))), field_ids, num_segments)
            value_zeros = self.cast(self.max_no_equal(value_count, 0.0), mstype.float32)
            count = self.expand(value_zeros, -1)
            merged_vectors = self.max_mask_mul(merged_vectors, count)
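
        # MEAN: divide the segment sum by the summed |weights| of each field;
        # DivNoNan keeps all-padded fields at zero instead of producing NaN.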
        if self.operator == 'MEAN':
            value_count = self.count_op(self.abs(input_values), field_ids, num_segments)
            value_count = self.expand(value_count, -1)
            merged_vectors = self.div_no_nan(merged_vectors, value_count)

        merged_vectors = self.reshape(merged_vectors, (batch_size, self.field_size, -1))
        return merged_vectors
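
To make the pooling semantics concrete, here is a small NumPy sketch (an editorial illustration, not part of the diff) that mirrors the SUM path above: look up each row, weight it, and accumulate it into its field bucket; padded positions carry zero weight and so contribute nothing.

    # Reference semantics for operator='SUM', assuming a plain NumPy table.
    import numpy as np

    def multi_field_sum(table, indices, values, field_ids, field_size):
        batch, seq = indices.shape
        out = np.zeros((batch, field_size, table.shape[1]), dtype=np.float32)
        for b in range(batch):
            for s in range(seq):
                # weight the looked-up row, then add it to its field bucket
                out[b, field_ids[b, s]] += values[b, s] * table[indices[b, s]]
        return out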

@@ -0,0 +1,137 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore.common.api import _executor
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore import Tensor, context
from mindspore.nn import TrainOneStepCell, Adam
from tests.ut.python.ops.test_math_ops import VirtualLoss


grad_all = C.GradOperation(get_all=True)
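# get_all=True makes grad_all return gradients for every input of the wrapped network.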


class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x, y, z):
        return grad_all(self.network)(x, y, z)


class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y, z):
        predict = self.network(x, y, z)
        return self.loss(predict)


class Net(nn.Cell):
    def __init__(self, shape, slice_mode=nn.EmbeddingLookup.BATCH_SLICE, target="DEVICE", operator='SUM'):
        super().__init__()
        self.embedding = nn.MultiFieldEmbeddingLookup(vocab_size=32, embedding_size=64, target=target,
                                                      field_size=shape[1], slice_mode=slice_mode,
                                                      operator=operator)
        self.reshape = P.Reshape().shard(((8, 1, 1),))
        self.batch_size = shape[0]

    def construct(self, x, y, z):
        out = self.embedding(x, y, z)
        out = self.reshape(out, (self.batch_size, -1))
        return out


def compile_net(net, shape):
    context.set_context(enable_sparse=True)
    x = Tensor(np.ones(shape), dtype=ms.int32)
    y = Tensor(np.ones(shape), dtype=ms.float32)
    z = Tensor(np.ones(shape), dtype=ms.int32)
    optimizer = Adam(net.trainable_params(), learning_rate=0.1)
    train_net = TrainOneStepCell(net, optimizer)
    train_net.set_auto_parallel()
    train_net.set_train()
    _executor.compile(train_net, x, y, z)
    context.reset_auto_parallel_context()
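

# Note: these UTs only verify that the parallel graph compiles for each
# slice_mode/operator combination; they do not validate numerical outputs.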


def test_embeddinglookup_batch_parallel_sum():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    shape = [64, 64]
    net = NetWithLoss(Net(shape, target='DEVICE'))
    compile_net(net, shape)


def test_embeddinglookup_row_parallel_sum():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    shape = [64, 64]
    net = NetWithLoss(Net(shape, slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE, target='DEVICE'))
    compile_net(net, shape)


def test_embeddinglookup_column_parallel_sum():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    shape = [64, 64]
    net = NetWithLoss(Net(shape, slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE, target='DEVICE'))
    compile_net(net, shape)


def test_embeddinglookup_batch_parallel_mean():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    shape = [64, 64]
    net = NetWithLoss(Net(shape, target='DEVICE', operator='MEAN'))
    compile_net(net, shape)


def test_embeddinglookup_column_parallel_mean():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    shape = [64, 64]
    net = NetWithLoss(Net(shape, target='DEVICE', slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE, operator='MEAN'))
    compile_net(net, shape)


def test_embeddinglookup_row_parallel_mean():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    shape = [64, 64]
    net = NetWithLoss(Net(shape, target='DEVICE', slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE, operator='MEAN'))
    compile_net(net, shape)


def test_embeddinglookup_batch_parallel_max():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    shape = [64, 64]
    net = NetWithLoss(Net(shape, target='DEVICE', operator='MAX'))
    compile_net(net, shape)


def test_embeddinglookup_column_parallel_max():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    shape = [64, 64]
    net = NetWithLoss(Net(shape, target='DEVICE', slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE, operator='MAX'))
    compile_net(net, shape)


def test_embeddinglookup_row_parallel_max():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    shape = [64, 64]
    net = NetWithLoss(Net(shape, target='DEVICE', slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE, operator='MAX'))
    compile_net(net, shape)