!1685 [Auto parallel] Fix the bugs in Embeddinglookup forward operator

Merge pull request !1685 from Xiaoda/fix-the-embeddinglookup-bug
This commit is contained in:
mindspore-ci-bot 2020-06-01 19:31:50 +08:00 committed by Gitee
commit b9ba99bb13
2 changed files with 21 additions and 11 deletions

View File

@ -576,19 +576,21 @@ class EmbeddingLookup(PrimitiveWithInfer):
"""
Returns a slice of input tensor based on the specified indices and axis. This Primitive has the similar
functionality as GatherV2, but has three more inputs: `offset`, `reduce_scatter_flag` and `split_num`.
This primitive runs on the host instead of devices.
Inputs:
- **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
The Tensor slice, instead of the entire Tensor.
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
Specifies the indices of elements of the original Tensor. Must be in the range
`[0, input_param.shape()[axis])`.
Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`,
and the exceeding part will be filled with 0 in the output.
- **axis** (int) - Specifies the dimension index to gather indices.
- **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices
are equal to `input_indices` minus `offset`.
- **reduce_scatter_flag** (bool) - Specifies whether perform reduce_scatter on host or not.
Only constant value is allowed.
- **split_num** (int) - Specifies the number of partitions of the reduce_scatter produces. This variable
is used only if `reduce_scatter_flag` is True.
is used only if `reduce_scatter_flag` is True. Only constant value is allowed.
Outputs:
@ -627,12 +629,20 @@ class EmbeddingLookup(PrimitiveWithInfer):
if axis_v < 0:
axis_v += rank
out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
if reduce_scatter_flag:
# partition the tensor along the dimension 0.
if out_shape[0] % split_num['value'] != 0:
raise ValueError("The dimension 0 of the shape: %d, is not divisible by split_num: %d." %
(out_shape[0], split_num['value']))
out_shape[0] = out_shape[0] // split_num['value']
if reduce_scatter_flag is None:
raise ValueError("The value of 'reduce_scatter_flag' is None.")
reduce_scatter_flag_value = reduce_scatter_flag['value']
if split_num is None:
raise ValueError("The value of 'split_num_value' is None.")
split_num_value = split_num['value']
if reduce_scatter_flag_value is True:
# Partition the tensor along the dimension 0. The shape size of dimension 0 should be divisible by
# (split_num * 8)
if out_shape[0] % (split_num_value * 8) != 0:
raise ValueError("The dimension 0 of the shape: %d, is not divisible by: %d." %
(out_shape[0], (split_num_value * 8)))
# After 'Concat' on host, the shape size of dimension 0 is: out_shape[0] // 8
out_shape[0] = out_shape[0] // 8
out = {'shape': out_shape,
'dtype': params['dtype'],
'value': None}

View File

@ -64,7 +64,7 @@ def test_embeddinglookup_reducescatter_false():
def test_embeddinglookup_reducescatter_true():
shape = [8, 8]
shape = [64, 8]
axis = 0
offset = 8
reduce_scatter_flag = True
@ -73,5 +73,5 @@ def test_embeddinglookup_reducescatter_true():
net.set_auto_parallel()
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([1, 32, 8]), dtype=ms.float32)
y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
_executor.compile(net, x, y)