diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.cc
index 1707e1d03d4..238c2df02a8 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.cc
@@ -22,15 +22,35 @@ MS_REG_GPU_KERNEL_ONE(
   UnsortedSegmentMax,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
   UnsortedSegmentMaxGpuKernel, float)
-
 MS_REG_GPU_KERNEL_ONE(
   UnsortedSegmentMax,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
   UnsortedSegmentMaxGpuKernel, half)
-
 MS_REG_GPU_KERNEL_ONE(
   UnsortedSegmentMax,
   KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
   UnsortedSegmentMaxGpuKernel, int)
+// Dynamic Mode
+MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      UnsortedSegmentMaxGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      UnsortedSegmentMaxGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddOutputAttr(kNumberTypeInt32),
+                      UnsortedSegmentMaxGpuKernel, int)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.h
index 0b2fd088272..9350ebca81d 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.h
@@ -28,14 +28,7 @@ namespace kernel {
 template <typename T>
 class UnsortedSegmentMaxGpuKernel : public GpuKernel {
  public:
-  UnsortedSegmentMaxGpuKernel()
-      : num_segments_(1),
-        inner_size_(1),
-        outer_size_(1),
-        input_size_(1),
-        segment_ids_size_(1),
-        output_size_(1),
-        is_null_input_(false) {}
+  UnsortedSegmentMaxGpuKernel() { ResetResource(); }
   ~UnsortedSegmentMaxGpuKernel() override = default;
 
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -60,18 +53,24 @@ class UnsortedSegmentMaxGpuKernel : public GpuKernel {
   }
 
   bool Init(const CNodePtr &kernel_node) override {
-    auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    auto input_shapes = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
     is_null_input_ = CHECK_NULL_INPUT(input_shapes);
     if (is_null_input_) {
       MS_LOG(WARNING) << "UnsortedSegmentMax input is null";
       InitSizeLists();
       return true;
     }
-    auto segment_ids_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
-    auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0);  // we get that from computation
+    auto segment_ids_shapes = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1);
+    auto output_shapes = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
+
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num == 3) {
+      MS_LOG(INFO) << "UnsortedSegmentMax Kernel Input count is 3 - dynamic mode";
+    } else {
+      MS_LOG(INFO) << "UnsortedSegmentMax Kernel Input count is 2";
+    }
 
     num_segments_ = output_shapes[0];
-
     input_size_ = 1;
     for (size_t i = 0; i < input_shapes.size(); i++) {
       input_size_ *= input_shapes[i];
@@ -97,6 +96,19 @@ class UnsortedSegmentMaxGpuKernel : public GpuKernel {
     return true;
   }
 
+  void ResetResource() noexcept override {
+    num_segments_ = 1;
+    inner_size_ = 1;
+    outer_size_ = 1;
+    input_size_ = 1;
+    segment_ids_size_ = 1;
+    output_size_ = 1;
+    is_null_input_ = false;
+    input_size_list_.clear();
+    output_size_list_.clear();
+    workspace_size_list_.clear();
+  }
+
  protected:
   void InitSizeLists() override {
     input_size_list_.push_back(input_size_ * sizeof(T));
diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h
index 39db3ef9179..d2fda499e9c 100644
--- a/mindspore/core/abstract/infer_functions.h
+++ b/mindspore/core/abstract/infer_functions.h
@@ -109,6 +109,8 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr
                                   const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplUnsortedSegmentSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                             const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplUnsortedSegmentMax(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                            const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc
index 7a6143f4010..19c01753a74 100644
--- a/mindspore/core/abstract/prim_arrays.cc
+++ b/mindspore/core/abstract/prim_arrays.cc
@@ -253,6 +253,78 @@ AbstractBasePtr InferImplUnsortedSegmentSum(const AnalysisEnginePtr &, const Pri
   return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
 }
 
+// Infers the output abstract (element type and shape) of UnsortedSegmentMax(x, segment_ids, num_segments).
+// The output shape is [num_segments] followed by the dimensions of x that lie beyond the rank of segment_ids.
+// num_segments is accepted either as a tensor whose value can be built or as a scalar.
+AbstractBasePtr InferImplUnsortedSegmentMax(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                            const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 3);
+  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(x);
+  MS_EXCEPTION_IF_NULL(x->shape());
+  auto segment_ids = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  MS_EXCEPTION_IF_NULL(segment_ids);
+  MS_EXCEPTION_IF_NULL(segment_ids->shape());
+  auto segment_ids_shape = segment_ids->shape()->shape();
+  (void)CheckTensorDType(x, {kFloat16, kFloat32, kInt32}, "Input 0 (x) for UnsortedSegmentMax should be %s");
+  (void)CheckTensorDType(segment_ids, {kInt32}, "Input 1 (segment_ids) for UnsortedSegmentMax should be %s");
+  // The op is dynamic as soon as either input carries min/max shape information. A partially dynamic op must not
+  // take the static path below, where an unknown (SHP_ANY) dimension could spuriously fail the length check.
+  bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty());
+  bool ids_is_dyn = (!segment_ids->shape()->min_shape().empty() && !segment_ids->shape()->max_shape().empty());
+  bool op_is_dynamic = x_is_dyn || ids_is_dyn;
+  auto x_shape = x->shape()->shape();
+  ShapeVector shape;
+  int64_t num_segments_value = 0;
+  if (args_spec_list[2]->isa<AbstractTensor>()) {  // num_segments is Tensor
+    auto num_segments = args_spec_list[2]->cast<AbstractTensorPtr>();
+    MS_EXCEPTION_IF_NULL(num_segments);
+    auto num_segments_value_ptr = num_segments->BuildValue();
+    MS_EXCEPTION_IF_NULL(num_segments_value_ptr);
+    auto num_segments_tensor = num_segments_value_ptr->cast<tensor::TensorPtr>();
+    MS_EXCEPTION_IF_NULL(num_segments_tensor);
+    num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c());
+  } else if (args_spec_list[2]->isa<AbstractScalar>()) {  // num_segments is Scalar
+    auto num_segments = CheckArg<AbstractScalar>(op_name, args_spec_list, 2);
+    num_segments_value = GetValue<int64_t>(num_segments->BuildValue());
+  } else {
+    MS_LOG(EXCEPTION) << "num_segments incorrect type in UnsortedSegmentMax";
+  }
+  if (num_segments_value <= 0) {
+    MS_LOG(EXCEPTION) << "num_segments must be > 0 in UnsortedSegmentMax";
+  }
+  shape.emplace_back(num_segments_value);
+  shape.insert(shape.end(), x_shape.begin() + segment_ids_shape.size(), x_shape.end());
+  if (!op_is_dynamic) {
+    if (x_shape[0] != segment_ids_shape[0]) {
+      MS_LOG(EXCEPTION) << "Length of segment_ids must match first value of x shape UnsortedSegmentMax";
+    }
+    return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape));
+  }
+  // is dynamic
+  ShapeVector min_shape;
+  ShapeVector max_shape;
+  min_shape.emplace_back(num_segments_value);
+  max_shape.emplace_back(num_segments_value);
+  // only run validation if shape values are known
+  bool x_any_shape = std::any_of(x_shape.begin(), x_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
+  bool ids_any_shape =
+    std::any_of(segment_ids_shape.begin(), segment_ids_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
+  if (!x_any_shape && !ids_any_shape) {
+    if (x_shape[0] != segment_ids_shape[0]) {
+      MS_LOG(EXCEPTION) << "Length of segment_ids must match first value of x shape UnsortedSegmentMax";
+    }
+  }
+  ShapeVector x_shape_min;
+  ShapeVector x_shape_max;
+  x_shape_min = (x_is_dyn) ? x->shape()->min_shape() : x->shape()->shape();
+  x_shape_max = (x_is_dyn) ? x->shape()->max_shape() : x->shape()->shape();
+  min_shape.insert(min_shape.end(), x_shape_min.begin() + segment_ids_shape.size(), x_shape_min.end());
+  max_shape.insert(max_shape.end(), x_shape_max.begin() + segment_ids_shape.size(), x_shape_max.end());
+  return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
+}
+
 AbstractBasePtr InferImplScatterAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_spec_list) {
   const std::string op_name = primitive->name();
diff --git a/mindspore/core/abstract/prim_others.cc b/mindspore/core/abstract/prim_others.cc
index 7fe4be12d9d..6a2da572c10 100644
--- a/mindspore/core/abstract/prim_others.cc
+++ b/mindspore/core/abstract/prim_others.cc
@@ -535,7 +535,7 @@ AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, con
   ShapeVector input_shape = input->shape()->shape();
   int32_t input_rank = input_shape.size();
   ShapeVector inferred_shape(input_rank, Shape::SHP_ANY);
-  ShapeVector min_shape = {1};
+  ShapeVector min_shape(input_rank, 1);
   ShapeVector max_shape = input_shape;
 
   ShapePtr shape = std::make_shared<Shape>(inferred_shape, min_shape, max_shape);
diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc
index db7767ef259..7e599226b34 100644
--- a/mindspore/core/abstract/primitive_infer_map.cc
+++ b/mindspore/core/abstract/primitive_infer_map.cc
@@ -56,6 +56,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimSparseGatherV2, {InferImplGatherV2, true}},
     {prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}},
     {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}},
+    {prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, true}},
     {prim::kPrimScatterAdd, {InferImplScatterAdd, true}},
     {prim::kPrimScatterUpdate, {InferImplScatterUpdate, true}},
     {prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, true}},
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index 4a871f6fc7f..4d964c00f17 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -1805,7 +1805,7 @@ class UnsortedSegmentMin(PrimitiveWithInfer):
         return out
 
 
-class UnsortedSegmentMax(PrimitiveWithInfer):
+class UnsortedSegmentMax(PrimitiveWithCheck):
     """
     Computes the maximum along segments of a tensor.
 
@@ -1838,27 +1838,21 @@ class UnsortedSegmentMax(PrimitiveWithInfer):
     def __init__(self):
         """Initialize UnsortedSegmentMax"""
         self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])
+        self.add_prim_attr("dynamic_shape_depends", [2])
 
-    def __infer__(self, x, segment_ids, num_segments):
-        x_type = x['dtype']
-        x_shape = x['shape']
+    def __check__(self, x, segment_ids, num_segments):
         segment_ids_shape = segment_ids['shape']
         valid_type = [mstype.float16, mstype.float32, mstype.int32]
         validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name)
         validator.check_tensors_dtypes_same_and_valid({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
         validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
-        validator.check(f'first shape of input_x', x_shape[0],
-                        'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
-        num_segments_v = num_segments['value']
-        validator.check_value_type('num_segments', num_segments_v, [int], self.name)
-        validator.check_positive_int(num_segments_v, "num_segments", self.name)
-        segment_ids_shape_len = len(segment_ids_shape)
-        out_shape = [num_segments_v]
-        out_shape += x_shape[segment_ids_shape_len:]
-        out = {'shape': out_shape,
-               'dtype': x_type,
-               'value': None}
-        return out
+        num_segments_type = num_segments['dtype']
+        validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name)
+        if isinstance(num_segments_type, type(mstype.tensor)):
+            validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int64],
+                                               self.name)
+        else:
+            validator.check_value_type('num_segments', num_segments['value'], [int], self.name)
 
 
 class UnsortedSegmentProd(PrimitiveWithInfer):
diff --git a/tests/st/ops/gpu/test_unsorted_segment_max.py b/tests/st/ops/gpu/test_unsorted_segment_max.py
index 90fb7bcc713..ffef5b2657c 100644
--- a/tests/st/ops/gpu/test_unsorted_segment_max.py
+++ b/tests/st/ops/gpu/test_unsorted_segment_max.py
@@ -20,6 +20,7 @@ import mindspore
 import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
+from mindspore.ops.operations import _inner_ops as inner
 from mindspore.common import dtype as mstype
 from mindspore.ops import operations as P
 
@@ -139,7 +140,7 @@ def test_3d_float32():
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 def test_3d_single_init():
-    context.set_context(device_target='GPU')
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
     input_x = Tensor(np.arange(
         4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32)
     segment_ids = Tensor([3, 0, 1, -1], mstype.int32)
@@ -202,3 +203,113 @@ def test_3d_single_init():
                         [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                         [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]]]).astype(np.float32)
     np.testing.assert_array_almost_equal(output, expect)
+
+# For testing Dynamic Shape operation
+class UnsortedSegmentMaxDynNet(nn.Cell):
+    def __init__(self, num_segments):
+        super(UnsortedSegmentMaxDynNet, self).__init__()
+        self.unsorted_segment_max = P.UnsortedSegmentMax()
+        self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
+        self.num_segments = num_segments
+
+    def construct(self, data, ids):
+        dyn_data = self.gpu_convert_to_dynamic_shape(data)
+        dyn_ids = self.gpu_convert_to_dynamic_shape(ids)
+        return self.unsorted_segment_max(dyn_data, dyn_ids, self.num_segments)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_3d_float32_dyn():
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    input_x = Tensor(np.arange(
+        4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32)
+    segment_ids = Tensor([2, 1, 1, -1], mstype.int32)
+    num_segments = 3
+    net = UnsortedSegmentMaxDynNet(num_segments)
+    output = net(input_x, segment_ids).asnumpy()
+    expect = np.array([[[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
+                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
+                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
+                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
+                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]],
+                       [[3.0000000e+01, 3.1000000e+01, 3.2000000e+01],
+                        [3.3000000e+01, 3.4000000e+01, 3.5000000e+01],
+                        [3.6000000e+01, 3.7000000e+01, 3.8000000e+01],
+                        [3.9000000e+01, 4.0000000e+01, 4.1000000e+01],
+                        [4.2000000e+01, 4.3000000e+01, 4.4000000e+01]],
+                       [[0.0000000e+00, 1.0000000e+00, 2.0000000e+00],
+                        [3.0000000e+00, 4.0000000e+00, 5.0000000e+00],
+                        [6.0000000e+00, 7.0000000e+00, 8.0000000e+00],
+                        [9.0000000e+00, 1.0000000e+01, 1.1000000e+01],
+                        [1.2000000e+01, 1.3000000e+01, 1.4000000e+01]]]).astype(np.float32)
+    np.testing.assert_array_almost_equal(output, expect)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_3d_single_init_dyn():
+    # Set graph mode explicitly, matching test_3d_float32_dyn, instead of relying on the global context state.
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    input_x = Tensor(np.arange(
+        4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32)
+    segment_ids = Tensor([3, 0, 1, -1], mstype.int32)
+    num_segments = 4
+    net = UnsortedSegmentMaxDynNet(num_segments)
+    output = net(input_x, segment_ids).asnumpy()
+    expect = np.array([[[1.5000000e+01, 1.6000000e+01, 1.7000000e+01],
+                        [1.8000000e+01, 1.9000000e+01, 2.0000000e+01],
+                        [2.1000000e+01, 2.2000000e+01, 2.3000000e+01],
+                        [2.4000000e+01, 2.5000000e+01, 2.6000000e+01],
+                        [2.7000000e+01, 2.8000000e+01, 2.9000000e+01]],
+                       [[3.0000000e+01, 3.1000000e+01, 3.2000000e+01],
+                        [3.3000000e+01, 3.4000000e+01, 3.5000000e+01],
+                        [3.6000000e+01, 3.7000000e+01, 3.8000000e+01],
+                        [3.9000000e+01, 4.0000000e+01, 4.1000000e+01],
+                        [4.2000000e+01, 4.3000000e+01, 4.4000000e+01]],
+                       [[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
+                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
+                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
+                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
+                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]],
+                       [[0.0000000e+00, 1.0000000e+00, 2.0000000e+00],
+                        [3.0000000e+00, 4.0000000e+00, 5.0000000e+00],
+                        [6.0000000e+00, 7.0000000e+00, 8.0000000e+00],
+                        [9.0000000e+00, 1.0000000e+01, 1.1000000e+01],
+                        [1.2000000e+01, 1.3000000e+01, 1.4000000e+01]]]).astype(np.float32)
+    np.testing.assert_array_almost_equal(output, expect)
+
+    num_segments = 6
+    net = UnsortedSegmentMaxDynNet(num_segments)
+    output = net(input_x, segment_ids).asnumpy()
+    expect = np.array([[[1.5000000e+01, 1.6000000e+01, 1.7000000e+01],
+                        [1.8000000e+01, 1.9000000e+01, 2.0000000e+01],
+                        [2.1000000e+01, 2.2000000e+01, 2.3000000e+01],
+                        [2.4000000e+01, 2.5000000e+01, 2.6000000e+01],
+                        [2.7000000e+01, 2.8000000e+01, 2.9000000e+01]],
+                       [[3.0000000e+01, 3.1000000e+01, 3.2000000e+01],
+                        [3.3000000e+01, 3.4000000e+01, 3.5000000e+01],
+                        [3.6000000e+01, 3.7000000e+01, 3.8000000e+01],
+                        [3.9000000e+01, 4.0000000e+01, 4.1000000e+01],
+                        [4.2000000e+01, 4.3000000e+01, 4.4000000e+01]],
+                       [[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
+                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
+                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
+                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
+                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]],
+                       [[0.0000000e+00, 1.0000000e+00, 2.0000000e+00],
+                        [3.0000000e+00, 4.0000000e+00, 5.0000000e+00],
+                        [6.0000000e+00, 7.0000000e+00, 8.0000000e+00],
+                        [9.0000000e+00, 1.0000000e+01, 1.1000000e+01],
+                        [1.2000000e+01, 1.3000000e+01, 1.4000000e+01]],
+                       [[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
+                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
+                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
+                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
+                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]],
+                       [[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
+                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
+                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
+                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
+                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]]]).astype(np.float32)
+    np.testing.assert_array_almost_equal(output, expect)