diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.cc index 2d535ca31af..15bfc303bfb 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.cc @@ -87,7 +87,7 @@ Status UnsortedSegmentOpInfo::CheckStrategy(const StrategyPtr &strategy) { Shape input_b_shape = inputs_shape_.at(1); // The size of the input b must be equal or smaller than input a for (size_t i = 0; i < input_b_shape.size(); ++i) { - if ((sub_a_strategy[i] != sub_b_strategy[i]) && (input_a_shape[i] != input_b_shape[i])) { + if ((sub_a_strategy[i] != sub_b_strategy[i]) || (input_a_shape[i] != input_b_shape[i])) { MS_LOG(ERROR) << name_ << " : Invalid strategy. The shape and the strategy of the input0 and input1 " "should be same before the front size of the input[1]"; diff --git a/tests/ut/python/parallel/test_unsortedsegmentsum.py b/tests/ut/python/parallel/test_unsortedsegmentsum.py index 2f7b4648705..3085be87d62 100644 --- a/tests/ut/python/parallel/test_unsortedsegmentsum.py +++ b/tests/ut/python/parallel/test_unsortedsegmentsum.py @@ -14,6 +14,7 @@ # ============================================================================ import numpy as np +import pytest import mindspore as ms import mindspore.nn as nn @@ -112,6 +113,17 @@ def test_unsortedsegmentsum_model_parallel_index_slice_3d(): compile_graph(x, y, num_segments, strategy1, strategy2) +def test_unsortedsegmentsum_model_parallel_index_slice_diff_inputs(): + context.set_auto_parallel_context(device_num=4, global_rank=0) + x = Tensor(np.ones((4, 4, 8)), ms.float32) + y = Tensor(np.ones((4, 4)), ms.int32) + num_segments = 16 + strategy1 = (2, 2, 1) + strategy2 = (2, 4) + with pytest.raises(RuntimeError): + compile_graph(x, y, num_segments, strategy1, strategy2) + + def test_unsortedsegmentsum_model_parallel_vector_slice_2d(): context.set_auto_parallel_context(device_num=4, global_rank=0) x = Tensor(np.ones((4, 8)), ms.float32)