embeddinglookup wrap

This commit is contained in:
yao_yf 2020-08-04 15:49:25 +08:00
parent 5adba834d0
commit e4de26d5bc
5 changed files with 100 additions and 44 deletions

View File

@ -18,10 +18,14 @@ from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore._checkparam import Validator
from mindspore.communication.management import get_group_size
from mindspore.train.parallel_utils import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode
from ..cell import Cell from ..cell import Cell
from ..._checkparam import Validator as validator from ..._checkparam import Validator as validator, Rel
__all__ = ['Embedding', 'EmbeddingLookup'] __all__ = ['Embedding', 'EmbeddingLookup', 'EmbeddingLookUpSplitMode']
class Embedding(Cell): class Embedding(Cell):
r""" r"""
@ -114,29 +118,36 @@ class EmbeddingLookup(Cell):
When 'target' is set to 'CPU', this module will use When 'target' is set to 'CPU', this module will use
P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which
specified 'offset = 0' to lookup table. specified 'offset = 0' to lookup table.
when 'target' is set to 'DEVICE', this module will use P.GatherV2() which When 'target' is set to 'DEVICE', this module will use P.GatherV2() which
specified 'axis = 0' to lookup table. specified 'axis = 0' to lookup table.
In field slice mode, manual_shapes must be given. It is a tuple, where
each element is (vocab[i], offset[i]): vocab[i] is the number of rows for the i-th
part and offset[i] is the feature id offset for the i-th part. The feature ids in the
i-th part are reduced by offset[i] so that the ids start from 0.
Args: Args:
vocab_size (int): Size of the dictionary of embeddings.
embedding_size (int): The size of each embedding vector.
param_init (str): The initialization method of the embedding table. Default: 'normal'.
target (str): Specify the target where the op is executed. Default: 'CPU'. target (str): Specify the target where the op is executed. Default: 'CPU'.
slice_mode (str): The slicing mode used in semi auto parallel/auto parallel mode. Default: 'batch_slice'.
manual_shapes (tuple): The accompaniment array in field slice mode. Default: None.
Inputs: Inputs:
- **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
The Tensor slice, instead of the entire Tensor.
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`, Specifies the indices of elements of the original Tensor. Values can be out of range of embedding_table,
and the exceeding part will be filled with 0 in the output. and the exceeding part will be filled with 0 in the output. Input_indices should only be a 2d tensor in
this interface.
Outputs: Outputs:
Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
Examples: Examples:
>>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
>>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32) >>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
>>> out = nn.EmbeddingLookup()(input_params, input_indices) >>> out = nn.EmbeddingLookup(4,2)(input_indices)
[[[10, 11], [8 ,9]], [[14, 15], [12, 13]]]
""" """
def __init__(self, target='CPU'): def __init__(self, vocab_size, embedding_size, param_init='normal',
target='CPU', slice_mode='batch_slice', manual_shapes=None):
super(EmbeddingLookup, self).__init__() super(EmbeddingLookup, self).__init__()
self.target = target self.target = target
if target not in ('CPU', 'DEVICE'): if target not in ('CPU', 'DEVICE'):
@ -144,10 +155,60 @@ class EmbeddingLookup(Cell):
+ str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.') + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
self.gatherv2 = P.GatherV2() self.gatherv2 = P.GatherV2()
self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
self.embedding_table = Parameter(initializer(param_init, [vocab_size, embedding_size]),
def construct(self, params, indices): name='embedding_table')
if self.target == "CPU": parallel_mode = _get_parallel_mode()
out = self.embeddinglookup(params, indices, 0) is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
if slice_mode == EmbeddingLookUpSplitMode.FIELD_SLICE and is_auto_parallel:
if not manual_shapes:
raise ValueError("in slice field mode, the manual_shapes should not be none")
if not isinstance(manual_shapes, tuple):
raise TypeError("manual_shapes type must be tuple(int) cannot be {}!".format(type(manual_shapes)))
for dim in manual_shapes:
Validator.check_integer('manul shape dim', dim, 0, Rel.GT, self.cls_name)
self.gatherv2.add_prim_attr("manual_split", manual_shapes)
self.embeddinglookup.add_prim_attr("manual_split", manual_shapes)
self.gatherv2.set_strategy(((get_group_size(), 1), (1, get_group_size())))
self.embeddinglookup.set_strategy(((get_group_size(), 1), (1, get_group_size())))
elif slice_mode == EmbeddingLookUpSplitMode.TABLE_ROW_SLICE and is_auto_parallel:
self.gatherv2.set_strategy(((get_group_size(), 1), (1, 1)))
self.embeddinglookup.set_strategy(((get_group_size(), 1), (1, 1)))
elif slice_mode == EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE and is_auto_parallel:
self.gatherv2.set_strategy(((1, get_group_size()), (1, 1)))
self.embeddinglookup.set_strategy(((1, get_group_size()), (1, 1)))
elif slice_mode == EmbeddingLookUpSplitMode.BATCH_SLICE and is_auto_parallel:
self.gatherv2.set_strategy(((1, 1), (get_group_size(), 1)))
self.embeddinglookup.set_strategy(((1, 1), (get_group_size(), 1)))
else: else:
out = self.gatherv2(params, indices, 0) if is_auto_parallel:
raise ValueError("slice_mode should support mode in nn.EmbeddingLookUpSplitMode, but get "
+ str(slice_mode))
def construct(self, indices):
if self.target == "CPU":
out = self.embeddinglookup(self.embedding_table, indices, 0)
else:
out = self.gatherv2(self.embedding_table, indices, 0)
return out return out
class EmbeddingLookUpSplitMode:
    """
    EmbeddingLookUp slice options in auto parallel and semi auto parallel mode.

    There are four kinds of slice options, "BATCH_SLICE", "FIELD_SLICE",
    "TABLE_ROW_SLICE" and "TABLE_COLUMN_SLICE". Default: "BATCH_SLICE".

        - BATCH_SLICE: Slicing batch dimensions of indices.
        - FIELD_SLICE: Slicing field dimensions of indices.
        - TABLE_ROW_SLICE: Slicing row of table.
        - TABLE_COLUMN_SLICE: Slicing column of table.

    MODE_LIST: The list of all supported slice modes.
    """
    # The options are plain strings (not an Enum) so that a caller may pass
    # either the constant or its literal value, e.g. 'batch_slice'.
    BATCH_SLICE = "batch_slice"
    FIELD_SLICE = "field_slice"
    TABLE_ROW_SLICE = "table_row_slice"
    TABLE_COLUMN_SLICE = "table_column_slice"
    # Every option defined above, in declaration order.
    MODE_LIST = [BATCH_SLICE, FIELD_SLICE, TABLE_ROW_SLICE, TABLE_COLUMN_SLICE]

View File

@ -209,19 +209,22 @@ class WideDeepModel(nn.Cell):
if is_auto_parallel and host_device_mix: if is_auto_parallel and host_device_mix:
self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),)) self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),))
self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1))) self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1)))
self.deep_embeddinglookup = nn.EmbeddingLookup() self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim,
self.deep_embeddinglookup.embeddinglookup.set_strategy(((1, get_group_size()), (1, 1))) slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE)
self.wide_embeddinglookup = nn.EmbeddingLookup() self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1,
self.wide_embeddinglookup.embeddinglookup.set_strategy(((get_group_size(), 1), (1, 1))) slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_ROW_SLICE)
self.deep_mul.set_strategy(((1, 1, get_group_size()), (1, 1, 1))) self.deep_mul.set_strategy(((1, 1, get_group_size()), (1, 1, 1)))
self.deep_reshape.add_prim_attr("skip_redistribution", True) self.deep_reshape.add_prim_attr("skip_redistribution", True)
self.reduce_sum.add_prim_attr("cross_batch", True) self.reduce_sum.add_prim_attr("cross_batch", True)
self.embedding_table = self.deep_embeddinglookup.embedding_table
elif parameter_server: elif parameter_server:
self.deep_embeddinglookup = nn.EmbeddingLookup() self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
self.wide_embeddinglookup = nn.EmbeddingLookup() self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1)
self.embedding_table = self.deep_embeddinglookup.embedding_table
else: else:
self.deep_embeddinglookup = nn.EmbeddingLookup(target='DEVICE') self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target='DEVICE')
self.wide_embeddinglookup = nn.EmbeddingLookup(target='DEVICE') self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target='DEVICE')
self.embedding_table = self.deep_embeddinglookup.embedding_table
def construct(self, id_hldr, wt_hldr): def construct(self, id_hldr, wt_hldr):
""" """
@ -231,11 +234,11 @@ class WideDeepModel(nn.Cell):
""" """
mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1)) mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1))
# Wide layer # Wide layer
wide_id_weight = self.wide_embeddinglookup(self.wide_w, id_hldr) wide_id_weight = self.wide_embeddinglookup(id_hldr)
wx = self.wide_mul(wide_id_weight, mask) wx = self.wide_mul(wide_id_weight, mask)
wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1)) wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1))
# Deep layer # Deep layer
deep_id_embs = self.deep_embeddinglookup(self.embedding_table, id_hldr) deep_id_embs = self.deep_embeddinglookup(id_hldr)
vx = self.deep_mul(deep_id_embs, mask) vx = self.deep_mul(deep_id_embs, mask)
deep_in = self.deep_reshape(vx, (-1, self.field_size * self.emb_dim)) deep_in = self.deep_reshape(vx, (-1, self.field_size * self.emb_dim))
deep_in = self.dense_layer_1(deep_in) deep_in = self.dense_layer_1(deep_in)

View File

@ -24,8 +24,7 @@ from mindspore.common import dtype as mstype
from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Adam from mindspore.nn.optim import Adam
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal, initializer from mindspore.common.initializer import TruncatedNormal
from mindspore import Parameter
parser = argparse.ArgumentParser(description="test_sparse_embedding") parser = argparse.ArgumentParser(description="test_sparse_embedding")
parser.add_argument("--device_target", type=str, default="Ascend") parser.add_argument("--device_target", type=str, default="Ascend")
@ -53,16 +52,13 @@ class LeNet5(nn.Cell):
super(LeNet5, self).__init__() super(LeNet5, self).__init__()
self.cast = P.Cast() self.cast = P.Cast()
self.flatten = nn.Flatten() self.flatten = nn.Flatten()
self.embedding_table = Parameter( self.embedding = nn.EmbeddingLookup(16, 4)
initializer("normal", (16, 4), mstype.float32), name="embedding_table"
)
self.embedding = nn.EmbeddingLookup()
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.fc = fc_with_initialize(12, num_class) self.fc = fc_with_initialize(12, num_class)
def construct(self, x): def construct(self, x):
x = self.cast(x, mstype.int32) x = self.cast(x, mstype.int32)
x = self.embedding(self.embedding_table, x) x = self.embedding(x)
x = self.flatten(x) x = self.flatten(x)
x = self.fc(x) x = self.fc(x)
return x return x
@ -72,7 +68,7 @@ def do_sparse_embedding(ps=False):
epoch = 10 epoch = 10
net = LeNet5(10) net = LeNet5(10)
if ps: if ps:
net.embedding_table.set_param_ps() net.embedding.embedding_table.set_param_ps()
optimizer = Adam(filter(lambda x: x.requires_grad, net.get_parameters())) optimizer = Adam(filter(lambda x: x.requires_grad, net.get_parameters()))
optimizer.sparse_opt.add_prim_attr("primitive_target", "CPU") optimizer.sparse_opt.add_prim_attr("primitive_target", "CPU")

View File

@ -421,17 +421,16 @@ def test_row_tensor_with_control_flow_if():
class EmbeddingLookUpBnNet(nn.Cell): class EmbeddingLookUpBnNet(nn.Cell):
def __init__(self, param_np, target='CPU'): def __init__(self, vocab_size, embedding_size, target='CPU'):
super().__init__() super().__init__()
self.param = Parameter(Tensor(param_np), name="w1") self.embedding_lookup = nn.EmbeddingLookup(vocab_size, embedding_size, param_init='ones', target=target)
self.embedding_lookup = nn.EmbeddingLookup(target=target)
self.bn = nn.BatchNorm2d(num_features=3) self.bn = nn.BatchNorm2d(num_features=3)
self.mul = P.Mul() self.mul = P.Mul()
self.reshape = P.Reshape() self.reshape = P.Reshape()
self.relu = nn.PReLU() self.relu = nn.PReLU()
def construct(self, indices): def construct(self, indices):
x = self.embedding_lookup(self.param, indices) x = self.embedding_lookup(indices)
x = self.reshape(x, (2, 3, 2, 2)) x = self.reshape(x, (2, 3, 2, 2))
x = self.relu(x) x = self.relu(x)
x = self.bn(x) x = self.bn(x)
@ -439,10 +438,9 @@ class EmbeddingLookUpBnNet(nn.Cell):
def test_embedding_lookup_with_mix_precision(): def test_embedding_lookup_with_mix_precision():
param_np = np.ones([8, 8]).astype(np.float32)
data = Tensor(np.array([0, 1, 2]).astype(np.int32)) data = Tensor(np.array([0, 1, 2]).astype(np.int32))
label = Tensor(np.random.randn(*(2, 3, 2, 2)).astype(np.float32)) label = Tensor(np.random.randn(*(2, 3, 2, 2)).astype(np.float32))
net = EmbeddingLookUpBnNet(param_np, target='CPU') net = EmbeddingLookUpBnNet(8, 8, target='CPU')
criterion = nn.SoftmaxCrossEntropyWithLogits(reduction='mean') criterion = nn.SoftmaxCrossEntropyWithLogits(reduction='mean')
optimizer = nn.Adam(params=net.trainable_params(), learning_rate=0.1) optimizer = nn.Adam(params=net.trainable_params(), learning_rate=0.1)

View File

@ -69,14 +69,12 @@ def test_bprop_with_sparse_feature_mirror():
super(Net, self).__init__() super(Net, self).__init__()
if shape is None: if shape is None:
shape = [8, 8] shape = [8, 8]
weight = Tensor(np.ones([64, 64]), dtype=ms.float32)
self.weight = Parameter(weight, "w")
self.index = Tensor(np.ones(shape), dtype=ms.int32) self.index = Tensor(np.ones(shape), dtype=ms.int32)
self.embeddinglookup = nn.EmbeddingLookup() self.embeddinglookup = nn.EmbeddingLookup(64, 64, param_init='ones')
self.embeddinglookup.embeddinglookup.set_strategy(((1, 1), (8, 1))) self.embeddinglookup.embeddinglookup.set_strategy(((1, 1), (8, 1)))
def construct(self, x, b): def construct(self, x, b):
out = self.embeddinglookup(self.weight, self.index) out = self.embeddinglookup(self.index)
return out return out