diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index 19828d38717..69ca492aac7 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -576,19 +576,21 @@ class EmbeddingLookup(PrimitiveWithInfer):
     """
     Returns a slice of input tensor based on the specified indices and axis. This
     Primitive has the similar functionality as GatherV2, but has three more inputs:
     `offset`, `reduce_scatter_flag` and `split_num`.
+    This primitive runs on the host instead of devices.
 
     Inputs:
         - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
           The Tensor slice, instead of the entire Tensor.
         - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
-          Specifies the indices of elements of the original Tensor. Must be in the range
-          `[0, input_param.shape()[axis])`.
+          Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`,
+          and the exceeding part will be filled with 0 in the output.
         - **axis** (int) - Specifies the dimension index to gather indices.
         - **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices
           are equal to `input_indices` minus `offset`.
         - **reduce_scatter_flag** (bool) - Specifies whether perform reduce_scatter on host or not.
+          Only constant value is allowed.
         - **split_num** (int) - Specifies the number of partitions of the reduce_scatter produces. This variable
-          is used only if `reduce_scatter_flag` is True.
+          is used only if `reduce_scatter_flag` is True. Only constant value is allowed.
 
     Outputs:
@@ -627,12 +629,20 @@ class EmbeddingLookup(PrimitiveWithInfer):
         if axis_v < 0:
             axis_v += rank
         out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
-        if reduce_scatter_flag:
-            # partition the tensor along the dimension 0.
-            if out_shape[0] % split_num['value'] != 0:
-                raise ValueError("The dimension 0 of the shape: %d, is not divisible by split_num: %d." %
-                                 (out_shape[0], split_num['value']))
-            out_shape[0] = out_shape[0] // split_num['value']
+        if reduce_scatter_flag is None:
+            raise ValueError("The value of 'reduce_scatter_flag' is None.")
+        reduce_scatter_flag_value = reduce_scatter_flag['value']
+        if split_num is None:
+            raise ValueError("The value of 'split_num' is None.")
+        split_num_value = split_num['value']
+        if reduce_scatter_flag_value is True:
+            # Partition the tensor along the dimension 0. The shape size of dimension 0 should be divisible by
+            # (split_num * 8)
+            if out_shape[0] % (split_num_value * 8) != 0:
+                raise ValueError("The dimension 0 of the shape: %d, is not divisible by: %d."
+                                 % (out_shape[0], (split_num_value * 8)))
+            # After 'Concat' on host, the shape size of dimension 0 is: out_shape[0] // 8
+            out_shape[0] = out_shape[0] // 8
         out = {'shape': out_shape,
                'dtype': params['dtype'],
                'value': None}
diff --git a/tests/ut/python/parallel/test_embeddinglookup.py b/tests/ut/python/parallel/test_embeddinglookup.py
index b934028a480..953b59ecbc4 100644
--- a/tests/ut/python/parallel/test_embeddinglookup.py
+++ b/tests/ut/python/parallel/test_embeddinglookup.py
@@ -64,7 +64,7 @@ def test_embeddinglookup_reducescatter_false():
 
 
 def test_embeddinglookup_reducescatter_true():
-    shape = [8, 8]
+    shape = [64, 8]
     axis = 0
     offset = 8
     reduce_scatter_flag = True
@@ -73,5 +73,5 @@ def test_embeddinglookup_reducescatter_true():
     net.set_auto_parallel()
 
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
-    y = Tensor(np.ones([1, 32, 8]), dtype=ms.float32)
+    y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
     _executor.compile(net, x, y)
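For context, the shape-inference change above amounts to this: when `reduce_scatter_flag` is true, dimension 0 of the gathered output must be divisible by `split_num * 8`, and after the host-side 'Concat' it is divided by 8 (previously it was divided by `split_num`). Below is a minimal, self-contained sketch of that arithmetic, not part of the patch; the helper name `infer_out_shape` is hypothetical, and the example shapes and the split_num value of 8 are assumptions chosen to echo the updated test, whose network wiring is outside this diff.

def infer_out_shape(params_shape, indices_shape, axis, reduce_scatter_flag, split_num):
    # Same formula as the operator: params_shp[:axis] + indices['shape'] + params_shp[axis + 1:]
    out_shape = params_shape[:axis] + indices_shape + params_shape[axis + 1:]
    if reduce_scatter_flag:
        # New check: dimension 0 must be divisible by (split_num * 8).
        if out_shape[0] % (split_num * 8) != 0:
            raise ValueError("The dimension 0 of the shape: %d, is not divisible by: %d."
                             % (out_shape[0], split_num * 8))
        # After 'Concat' on host, dimension 0 shrinks by a fixed factor of 8.
        out_shape[0] = out_shape[0] // 8
    return out_shape

# Assumed roles echoing the updated test: indices [64, 8], params slice [64, 32], axis 0, split_num 8.
assert infer_out_shape([64, 32], [64, 8], 0, True, 8) == [8, 8, 32]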