forked from mindspore-Ecosystem/mindspore
embeddinglookup wrap
This commit is contained in:
parent
5adba834d0
commit
e4de26d5bc
|
@ -18,10 +18,14 @@ from mindspore.common.tensor import Tensor
|
|||
from mindspore.ops import operations as P
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore._checkparam import Validator
|
||||
from mindspore.communication.management import get_group_size
|
||||
from mindspore.train.parallel_utils import ParallelMode
|
||||
from mindspore.parallel._utils import _get_parallel_mode
|
||||
from ..cell import Cell
|
||||
from ..._checkparam import Validator as validator
|
||||
from ..._checkparam import Validator as validator, Rel
|
||||
|
||||
__all__ = ['Embedding', 'EmbeddingLookup']
|
||||
__all__ = ['Embedding', 'EmbeddingLookup', 'EmbeddingLookUpSplitMode']
|
||||
|
||||
class Embedding(Cell):
|
||||
r"""
|
||||
|
@ -114,29 +118,36 @@ class EmbeddingLookup(Cell):
|
|||
When 'target' is set to 'CPU', this module will use
|
||||
P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which
|
||||
specified 'offset = 0' to lookup table.
|
||||
when 'target' is set to 'DEVICE', this module will use P.GatherV2() which
|
||||
When 'target' is set to 'DEVICE', this module will use P.GatherV2() which
|
||||
specified 'axis = 0' to lookup table.
|
||||
In field slice mode, the manual_shapes should be given. It is a tuple ,where
|
||||
the element is (vocab[i], offset[i]), vocab[i] is the row numbers for i-th
|
||||
part and offset[i] is the feature id offset for i-th part. The feature id in
|
||||
i-th part will be subtracted by offset[i] to ensure the id start from 0.
|
||||
|
||||
Args:
|
||||
vocab_size (int): Size of the dictionary of embeddings.
|
||||
embedding_size (int): The size of each embedding vector.
|
||||
param_init (str): The initialize way of embedding table. Default: 'normal'.
|
||||
target (str): Specify the target where the op is executed. Default: 'CPU'.
|
||||
slice_mode (str): The slicing way in semi auto parallel/auto parallel. Default: 'batch_slice'.
|
||||
manual_shapes (tuple): The accompaniment array in field slice mode.
|
||||
|
||||
Inputs:
|
||||
- **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
||||
The Tensor slice, instead of the entire Tensor.
|
||||
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
|
||||
Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`,
|
||||
and the exceeding part will be filled with 0 in the output.
|
||||
Specifies the indices of elements of the original Tensor. Values can be out of range of embedding_table,
|
||||
and the exceeding part will be filled with 0 in the output. Input_indices should only be a 2d tensor in
|
||||
this interface.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
|
||||
|
||||
Examples:
|
||||
>>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
|
||||
>>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
|
||||
>>> out = nn.EmbeddingLookup()(input_params, input_indices)
|
||||
[[[10, 11], [8 ,9]], [[14, 15], [12, 13]]]
|
||||
>>> out = nn.EmbeddingLookup(4,2)(input_indices)
|
||||
"""
|
||||
def __init__(self, target='CPU'):
|
||||
def __init__(self, vocab_size, embedding_size, param_init='normal',
|
||||
target='CPU', slice_mode='batch_slice', manual_shapes=None):
|
||||
super(EmbeddingLookup, self).__init__()
|
||||
self.target = target
|
||||
if target not in ('CPU', 'DEVICE'):
|
||||
|
@ -144,10 +155,60 @@ class EmbeddingLookup(Cell):
|
|||
+ str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
|
||||
self.gatherv2 = P.GatherV2()
|
||||
self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
|
||||
|
||||
def construct(self, params, indices):
|
||||
if self.target == "CPU":
|
||||
out = self.embeddinglookup(params, indices, 0)
|
||||
self.embedding_table = Parameter(initializer(param_init, [vocab_size, embedding_size]),
|
||||
name='embedding_table')
|
||||
parallel_mode = _get_parallel_mode()
|
||||
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
|
||||
if slice_mode == EmbeddingLookUpSplitMode.FIELD_SLICE and is_auto_parallel:
|
||||
if not manual_shapes:
|
||||
raise ValueError("in slice field mode, the manual_shapes should not be none")
|
||||
if not isinstance(manual_shapes, tuple):
|
||||
raise TypeError("manual_shapes type must be tuple(int) cannot be {}!".format(type(manual_shapes)))
|
||||
for dim in manual_shapes:
|
||||
Validator.check_integer('manul shape dim', dim, 0, Rel.GT, self.cls_name)
|
||||
self.gatherv2.add_prim_attr("manual_split", manual_shapes)
|
||||
self.embeddinglookup.add_prim_attr("manual_split", manual_shapes)
|
||||
self.gatherv2.set_strategy(((get_group_size(), 1), (1, get_group_size())))
|
||||
self.embeddinglookup.set_strategy(((get_group_size(), 1), (1, get_group_size())))
|
||||
elif slice_mode == EmbeddingLookUpSplitMode.TABLE_ROW_SLICE and is_auto_parallel:
|
||||
self.gatherv2.set_strategy(((get_group_size(), 1), (1, 1)))
|
||||
self.embeddinglookup.set_strategy(((get_group_size(), 1), (1, 1)))
|
||||
elif slice_mode == EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE and is_auto_parallel:
|
||||
self.gatherv2.set_strategy(((1, get_group_size()), (1, 1)))
|
||||
self.embeddinglookup.set_strategy(((1, get_group_size()), (1, 1)))
|
||||
elif slice_mode == EmbeddingLookUpSplitMode.BATCH_SLICE and is_auto_parallel:
|
||||
self.gatherv2.set_strategy(((1, 1), (get_group_size(), 1)))
|
||||
self.embeddinglookup.set_strategy(((1, 1), (get_group_size(), 1)))
|
||||
else:
|
||||
out = self.gatherv2(params, indices, 0)
|
||||
if is_auto_parallel:
|
||||
raise ValueError("slice_mode should support mode in nn.EmbeddingLookUpSplitMode, but get "
|
||||
+ str(slice_mode))
|
||||
|
||||
def construct(self, indices):
|
||||
if self.target == "CPU":
|
||||
out = self.embeddinglookup(self.embedding_table, indices, 0)
|
||||
else:
|
||||
out = self.gatherv2(self.embedding_table, indices, 0)
|
||||
return out
|
||||
|
||||
|
||||
class EmbeddingLookUpSplitMode:
|
||||
"""
|
||||
EmbeddingLookUp slice options in auto parallel and semi auto parallel mode.
|
||||
|
||||
There are five kinds of slice options, "BATCH_SLICE", "FIELD_SLICE",
|
||||
"TABLE_ROW_SLICE" and "TABLE_COLUMN_SLICE". Default: "BATCH_SLICE".
|
||||
|
||||
- BATCH_SLICE: Slicing batch dimensions of indices.
|
||||
- FIELD_SLICE: Slicing field dimensions of indices.
|
||||
- TABLE_ROW_SLICE: Slicing row of table.
|
||||
- TABLE_COLUMN_SLICE: Slicing column of table.
|
||||
|
||||
MODE_LIST: The list for all supported parallel modes.
|
||||
"""
|
||||
|
||||
BATCH_SLICE = "batch_slice"
|
||||
FIELD_SLICE = "field_slice"
|
||||
TABLE_ROW_SLICE = "table_row_slice"
|
||||
TABLE_COLUMN_SLICE = "table_column_slice"
|
||||
MODE_LIST = [BATCH_SLICE, FIELD_SLICE, TABLE_ROW_SLICE, TABLE_COLUMN_SLICE]
|
||||
|
|
|
@ -209,19 +209,22 @@ class WideDeepModel(nn.Cell):
|
|||
if is_auto_parallel and host_device_mix:
|
||||
self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),))
|
||||
self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1)))
|
||||
self.deep_embeddinglookup = nn.EmbeddingLookup()
|
||||
self.deep_embeddinglookup.embeddinglookup.set_strategy(((1, get_group_size()), (1, 1)))
|
||||
self.wide_embeddinglookup = nn.EmbeddingLookup()
|
||||
self.wide_embeddinglookup.embeddinglookup.set_strategy(((get_group_size(), 1), (1, 1)))
|
||||
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim,
|
||||
slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE)
|
||||
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1,
|
||||
slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_ROW_SLICE)
|
||||
self.deep_mul.set_strategy(((1, 1, get_group_size()), (1, 1, 1)))
|
||||
self.deep_reshape.add_prim_attr("skip_redistribution", True)
|
||||
self.reduce_sum.add_prim_attr("cross_batch", True)
|
||||
self.embedding_table = self.deep_embeddinglookup.embedding_table
|
||||
elif parameter_server:
|
||||
self.deep_embeddinglookup = nn.EmbeddingLookup()
|
||||
self.wide_embeddinglookup = nn.EmbeddingLookup()
|
||||
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
|
||||
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1)
|
||||
self.embedding_table = self.deep_embeddinglookup.embedding_table
|
||||
else:
|
||||
self.deep_embeddinglookup = nn.EmbeddingLookup(target='DEVICE')
|
||||
self.wide_embeddinglookup = nn.EmbeddingLookup(target='DEVICE')
|
||||
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target='DEVICE')
|
||||
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target='DEVICE')
|
||||
self.embedding_table = self.deep_embeddinglookup.embedding_table
|
||||
|
||||
def construct(self, id_hldr, wt_hldr):
|
||||
"""
|
||||
|
@ -231,11 +234,11 @@ class WideDeepModel(nn.Cell):
|
|||
"""
|
||||
mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1))
|
||||
# Wide layer
|
||||
wide_id_weight = self.wide_embeddinglookup(self.wide_w, id_hldr)
|
||||
wide_id_weight = self.wide_embeddinglookup(id_hldr)
|
||||
wx = self.wide_mul(wide_id_weight, mask)
|
||||
wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1))
|
||||
# Deep layer
|
||||
deep_id_embs = self.deep_embeddinglookup(self.embedding_table, id_hldr)
|
||||
deep_id_embs = self.deep_embeddinglookup(id_hldr)
|
||||
vx = self.deep_mul(deep_id_embs, mask)
|
||||
deep_in = self.deep_reshape(vx, (-1, self.field_size * self.emb_dim))
|
||||
deep_in = self.dense_layer_1(deep_in)
|
||||
|
|
|
@ -24,8 +24,7 @@ from mindspore.common import dtype as mstype
|
|||
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||
from mindspore.nn.optim import Adam
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.initializer import TruncatedNormal, initializer
|
||||
from mindspore import Parameter
|
||||
from mindspore.common.initializer import TruncatedNormal
|
||||
|
||||
parser = argparse.ArgumentParser(description="test_sparse_embedding")
|
||||
parser.add_argument("--device_target", type=str, default="Ascend")
|
||||
|
@ -53,16 +52,13 @@ class LeNet5(nn.Cell):
|
|||
super(LeNet5, self).__init__()
|
||||
self.cast = P.Cast()
|
||||
self.flatten = nn.Flatten()
|
||||
self.embedding_table = Parameter(
|
||||
initializer("normal", (16, 4), mstype.float32), name="embedding_table"
|
||||
)
|
||||
self.embedding = nn.EmbeddingLookup()
|
||||
self.embedding = nn.EmbeddingLookup(16, 4)
|
||||
self.relu = nn.ReLU()
|
||||
self.fc = fc_with_initialize(12, num_class)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.cast(x, mstype.int32)
|
||||
x = self.embedding(self.embedding_table, x)
|
||||
x = self.embedding(x)
|
||||
x = self.flatten(x)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
@ -72,7 +68,7 @@ def do_sparse_embedding(ps=False):
|
|||
epoch = 10
|
||||
net = LeNet5(10)
|
||||
if ps:
|
||||
net.embedding_table.set_param_ps()
|
||||
net.embedding.embedding_table.set_param_ps()
|
||||
|
||||
optimizer = Adam(filter(lambda x: x.requires_grad, net.get_parameters()))
|
||||
optimizer.sparse_opt.add_prim_attr("primitive_target", "CPU")
|
||||
|
|
|
@ -421,17 +421,16 @@ def test_row_tensor_with_control_flow_if():
|
|||
|
||||
|
||||
class EmbeddingLookUpBnNet(nn.Cell):
|
||||
def __init__(self, param_np, target='CPU'):
|
||||
def __init__(self, vocab_size, embedding_size, target='CPU'):
|
||||
super().__init__()
|
||||
self.param = Parameter(Tensor(param_np), name="w1")
|
||||
self.embedding_lookup = nn.EmbeddingLookup(target=target)
|
||||
self.embedding_lookup = nn.EmbeddingLookup(vocab_size, embedding_size, param_init='ones', target=target)
|
||||
self.bn = nn.BatchNorm2d(num_features=3)
|
||||
self.mul = P.Mul()
|
||||
self.reshape = P.Reshape()
|
||||
self.relu = nn.PReLU()
|
||||
|
||||
def construct(self, indices):
|
||||
x = self.embedding_lookup(self.param, indices)
|
||||
x = self.embedding_lookup(indices)
|
||||
x = self.reshape(x, (2, 3, 2, 2))
|
||||
x = self.relu(x)
|
||||
x = self.bn(x)
|
||||
|
@ -439,10 +438,9 @@ class EmbeddingLookUpBnNet(nn.Cell):
|
|||
|
||||
|
||||
def test_embedding_lookup_with_mix_precision():
|
||||
param_np = np.ones([8, 8]).astype(np.float32)
|
||||
data = Tensor(np.array([0, 1, 2]).astype(np.int32))
|
||||
label = Tensor(np.random.randn(*(2, 3, 2, 2)).astype(np.float32))
|
||||
net = EmbeddingLookUpBnNet(param_np, target='CPU')
|
||||
net = EmbeddingLookUpBnNet(8, 8, target='CPU')
|
||||
|
||||
criterion = nn.SoftmaxCrossEntropyWithLogits(reduction='mean')
|
||||
optimizer = nn.Adam(params=net.trainable_params(), learning_rate=0.1)
|
||||
|
|
|
@ -69,14 +69,12 @@ def test_bprop_with_sparse_feature_mirror():
|
|||
super(Net, self).__init__()
|
||||
if shape is None:
|
||||
shape = [8, 8]
|
||||
weight = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
self.weight = Parameter(weight, "w")
|
||||
self.index = Tensor(np.ones(shape), dtype=ms.int32)
|
||||
self.embeddinglookup = nn.EmbeddingLookup()
|
||||
self.embeddinglookup = nn.EmbeddingLookup(64, 64, param_init='ones')
|
||||
self.embeddinglookup.embeddinglookup.set_strategy(((1, 1), (8, 1)))
|
||||
|
||||
def construct(self, x, b):
|
||||
out = self.embeddinglookup(self.weight, self.index)
|
||||
out = self.embeddinglookup(self.index)
|
||||
|
||||
return out
|
||||
|
||||
|
|
Loading…
Reference in New Issue