forked from mindspore-Ecosystem/mindspore
!1685 [Auto parallel] Fix the bugs in Embeddinglookup forward operator
Merge pull request !1685 from Xiaoda/fix-the-embeddinglookup-bug
This commit is contained in:
commit
b9ba99bb13
|
@ -576,19 +576,21 @@ class EmbeddingLookup(PrimitiveWithInfer):
|
|||
"""
|
||||
Returns a slice of input tensor based on the specified indices and axis. This Primitive has the similar
|
||||
functionality as GatherV2, but has three more inputs: `offset`, `reduce_scatter_flag` and `split_num`.
|
||||
This primitive runs on the host instead of devices.
|
||||
|
||||
Inputs:
|
||||
- **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
||||
The Tensor slice, instead of the entire Tensor.
|
||||
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
|
||||
Specifies the indices of elements of the original Tensor. Must be in the range
|
||||
`[0, input_param.shape()[axis])`.
|
||||
Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`,
|
||||
and the exceeding part will be filled with 0 in the output.
|
||||
- **axis** (int) - Specifies the dimension index to gather indices.
|
||||
- **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices
|
||||
are equal to `input_indices` minus `offset`.
|
||||
- **reduce_scatter_flag** (bool) - Specifies whether perform reduce_scatter on host or not.
|
||||
Only constant value is allowed.
|
||||
- **split_num** (int) - Specifies the number of partitions of the reduce_scatter produces. This variable
|
||||
is used only if `reduce_scatter_flag` is True.
|
||||
is used only if `reduce_scatter_flag` is True. Only constant value is allowed.
|
||||
|
||||
|
||||
Outputs:
|
||||
|
@ -627,12 +629,20 @@ class EmbeddingLookup(PrimitiveWithInfer):
|
|||
if axis_v < 0:
|
||||
axis_v += rank
|
||||
out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
|
||||
if reduce_scatter_flag:
|
||||
# partition the tensor along the dimension 0.
|
||||
if out_shape[0] % split_num['value'] != 0:
|
||||
raise ValueError("The dimension 0 of the shape: %d, is not divisible by split_num: %d." %
|
||||
(out_shape[0], split_num['value']))
|
||||
out_shape[0] = out_shape[0] // split_num['value']
|
||||
if reduce_scatter_flag is None:
|
||||
raise ValueError("The value of 'reduce_scatter_flag' is None.")
|
||||
reduce_scatter_flag_value = reduce_scatter_flag['value']
|
||||
if split_num is None:
|
||||
raise ValueError("The value of 'split_num_value' is None.")
|
||||
split_num_value = split_num['value']
|
||||
if reduce_scatter_flag_value is True:
|
||||
# Partition the tensor along the dimension 0. The shape size of dimension 0 should be divisible by
|
||||
# (split_num * 8)
|
||||
if out_shape[0] % (split_num_value * 8) != 0:
|
||||
raise ValueError("The dimension 0 of the shape: %d, is not divisible by: %d." %
|
||||
(out_shape[0], (split_num_value * 8)))
|
||||
# After 'Concat' on host, the shape size of dimension 0 is: out_shape[0] // 8
|
||||
out_shape[0] = out_shape[0] // 8
|
||||
out = {'shape': out_shape,
|
||||
'dtype': params['dtype'],
|
||||
'value': None}
|
||||
|
|
|
@ -64,7 +64,7 @@ def test_embeddinglookup_reducescatter_false():
|
|||
|
||||
|
||||
def test_embeddinglookup_reducescatter_true():
|
||||
shape = [8, 8]
|
||||
shape = [64, 8]
|
||||
axis = 0
|
||||
offset = 8
|
||||
reduce_scatter_flag = True
|
||||
|
@ -73,5 +73,5 @@ def test_embeddinglookup_reducescatter_true():
|
|||
net.set_auto_parallel()
|
||||
|
||||
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([1, 32, 8]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
|
||||
_executor.compile(net, x, y)
|
||||
|
|
Loading…
Reference in New Issue