From 94b91295f9443df1a2cb804a5aa758152d3d7b91 Mon Sep 17 00:00:00 2001 From: dabaiji Date: Wed, 26 Oct 2022 09:35:29 +0800 Subject: [PATCH] unattr for unsorted segment sum & prod --- .../optimizer/reg_gpu_const_input_to_attr.h | 2 - .../core/ops/unsorted_segment_arithmetic.cc | 60 ++++++++++++++----- .../core/ops/unsorted_segment_arithmetic.h | 4 +- mindspore/core/ops/unsorted_segment_sum.cc | 47 +-------------- 4 files changed, 49 insertions(+), 64 deletions(-) diff --git a/mindspore/ccsrc/plugin/device/gpu/optimizer/reg_gpu_const_input_to_attr.h b/mindspore/ccsrc/plugin/device/gpu/optimizer/reg_gpu_const_input_to_attr.h index b773e64d070..e507e91447c 100644 --- a/mindspore/ccsrc/plugin/device/gpu/optimizer/reg_gpu_const_input_to_attr.h +++ b/mindspore/ccsrc/plugin/device/gpu/optimizer/reg_gpu_const_input_to_attr.h @@ -67,8 +67,6 @@ RER_GPU_STATIC_CONST_TO_ATTR(kSubscalarOpName, 1); RER_GPU_STATIC_CONST_TO_ATTR(kTensorCopySlicesOpName, 2, 3, 4); RER_GPU_STATIC_CONST_TO_ATTR(kTileOpName, 1); RER_GPU_STATIC_CONST_TO_ATTR(kTransposeOpName, 1); -RER_GPU_STATIC_CONST_TO_ATTR(kUnsortedSegmentProdOpName, 2); -RER_GPU_STATIC_CONST_TO_ATTR(kUnsortedSegmentSumOpName, 2); } // namespace mindspore::opt #endif // MINDSPORE_CCSRC_PLUGIN_GPU_OPTIMIZER_REG_GPU_CONST_INPUT_TO_ATTR_H_ diff --git a/mindspore/core/ops/unsorted_segment_arithmetic.cc b/mindspore/core/ops/unsorted_segment_arithmetic.cc index b7f5cbf9d22..94426a7a8d5 100644 --- a/mindspore/core/ops/unsorted_segment_arithmetic.cc +++ b/mindspore/core/ops/unsorted_segment_arithmetic.cc @@ -17,6 +17,7 @@ #include "ops/unsorted_segment_arithmetic.h" #include #include +#include #include "utils/check_convert_utils.h" #include "abstract/ops/primitive_infer_map.h" #include "ops/op_utils.h" @@ -27,6 +28,48 @@ namespace mindspore { namespace ops { +int64_t GetNumSegmentsValue(const PrimitivePtr &primitive, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + const std::string &op_name = primitive->name(); + int64_t num_segments_v = 0; + if (input_args[kInputIndex2]->isa()) { + auto n_value_ptr = input_args[kInputIndex2]->BuildValue(); + MS_EXCEPTION_IF_NULL(n_value_ptr); + if (n_value_ptr->isa()) { + auto n_tensor_ptr = n_value_ptr->cast(); + MS_EXCEPTION_IF_NULL(n_tensor_ptr); + num_segments_v = n_tensor_ptr->data_type() == kNumberTypeInt32 ? *static_cast(n_tensor_ptr->data_c()) + : *static_cast(n_tensor_ptr->data_c()); + (void)CheckAndConvertUtils::CheckInteger("num_segments's value", num_segments_v, kGreaterThan, 0, op_name); + return num_segments_v; + } else { + auto n_abstract_tensor = input_args[kInputIndex2]->cast(); + MS_EXCEPTION_IF_NULL(n_abstract_tensor); + return -1; + } + } else if (input_args[kInputIndex2]->isa()) { + auto num_segments_input_type = input_args[kInputIndex2]->BuildType(); + if (num_segments_input_type->type_id() == kNumberTypeInt64) { + auto num_sample_ptr = input_args[kInputIndex2]->cast(); + MS_EXCEPTION_IF_NULL(num_sample_ptr); + num_segments_v = GetValue(input_args[kInputIndex2]->BuildValue()); + } else if (num_segments_input_type->type_id() == kNumberTypeInt32) { + auto num_sample_ptr = input_args[kInputIndex2]->cast(); + MS_EXCEPTION_IF_NULL(num_sample_ptr); + num_segments_v = GetValue(input_args[kInputIndex2]->BuildValue()); + } else { + MS_EXCEPTION(TypeError) << "For '" << op_name << "' the third input build type is invalid:" + << TypeIdToString(num_segments_input_type->type_id()) << "."; + } + (void)CheckAndConvertUtils::CheckInteger("num_segments's value", num_segments_v, kGreaterThan, 0, op_name); + return num_segments_v; + } else { + MS_LOG(EXCEPTION) << "For '" << op_name + << "', the third input type should be tensor or scalar, but got invalid abstract type:" + << input_args[kInputIndex2]->type_name() << "."; + } + return num_segments_v; +} namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args, const int64_t &num_segments_value) { @@ -92,26 +135,11 @@ abstract::ShapePtr UnsortedSegmentArithmeticInferShape(const PrimitivePtr &primi auto num_segments_shape_ptr = input_args[kInputIndex2]->BuildShape(); MS_EXCEPTION_IF_NULL(num_segments_shape_ptr); - auto num_segments = input_args[kInputIndex2]->cast(); - int64_t num_segments_value = 0; - - if (num_segments != nullptr && num_segments->BuildValue() != kAnyValue) { - num_segments_value = GetValue(num_segments->BuildValue()); - (void)primitive->AddAttr(kNumSegments, MakeValue(num_segments_value)); - } - if (x_shape_ptr->IsDynamic() || segment_ids_shape_ptr->IsDynamic() || num_segments_shape_ptr->IsDynamic()) { return x_shape_ptr->cast(); } - if (num_segments == nullptr || num_segments->BuildValue() == kAnyValue) { - auto value_ptr = primitive->GetAttr(kNumSegments); - if (value_ptr != nullptr) { - num_segments_value = GetValue(value_ptr); - } else { - return x_shape_ptr->cast(); - } - } + int64_t num_segments_value = GetNumSegmentsValue(primitive, input_args); if (num_segments_value <= 0) { MS_EXCEPTION(ValueError) << "For '" << prim_name << "', num_segments value must be greater than 0, but got: " << num_segments_value << "."; diff --git a/mindspore/core/ops/unsorted_segment_arithmetic.h b/mindspore/core/ops/unsorted_segment_arithmetic.h index b3b90b2aece..9db513481d1 100644 --- a/mindspore/core/ops/unsorted_segment_arithmetic.h +++ b/mindspore/core/ops/unsorted_segment_arithmetic.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,8 @@ namespace ops { abstract::AbstractBasePtr UnsortedSegmentArithmeticInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args); + +int64_t GetNumSegmentsValue(const PrimitivePtr &primitive, const std::vector &input_args); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/unsorted_segment_sum.cc b/mindspore/core/ops/unsorted_segment_sum.cc index 6ec541c99dd..4e87044c0ed 100644 --- a/mindspore/core/ops/unsorted_segment_sum.cc +++ b/mindspore/core/ops/unsorted_segment_sum.cc @@ -25,54 +25,11 @@ #include "utils/check_convert_utils.h" #include "abstract/ops/primitive_infer_map.h" #include "mindapi/src/helper.h" +#include "ops/unsorted_segment_arithmetic.h" namespace mindspore { namespace ops { namespace { -void GetNumSegmentsValue(const PrimitivePtr &primitive, const std::vector &input_args, - ShapeVector *num_vec) { - MS_EXCEPTION_IF_NULL(primitive); - const std::string &op_name = primitive->name(); - int64_t num_segments_v; - // num_segments is a Tensor when UnsortedSegmentSum is a dynamic shape operator - if (input_args[kInputIndex2]->isa()) { - auto n_value_ptr = input_args[kInputIndex2]->BuildValue(); - MS_EXCEPTION_IF_NULL(n_value_ptr); - if (n_value_ptr->isa()) { - auto n_tensor_ptr = n_value_ptr->cast(); - MS_EXCEPTION_IF_NULL(n_tensor_ptr); - num_segments_v = n_tensor_ptr->data_type() == kNumberTypeInt32 ? *static_cast(n_tensor_ptr->data_c()) - : *static_cast(n_tensor_ptr->data_c()); - (void)CheckAndConvertUtils::CheckInteger("num_segments's value", num_segments_v, kGreaterThan, 0, op_name); - num_vec->push_back(num_segments_v); - } else { - auto n_abstract_tensor = input_args[kInputIndex2]->cast(); - MS_EXCEPTION_IF_NULL(n_abstract_tensor); - num_vec->push_back(-1); - } - } else if (input_args[kInputIndex2]->isa()) { - auto num_segments_input_type = input_args[kInputIndex2]->BuildType(); - if (num_segments_input_type->type_id() == kNumberTypeInt64) { - auto num_sample_ptr = input_args[kInputIndex2]->cast(); - MS_EXCEPTION_IF_NULL(num_sample_ptr); - num_segments_v = GetValue(input_args[kInputIndex2]->BuildValue()); - } else if (num_segments_input_type->type_id() == kNumberTypeInt32) { - auto num_sample_ptr = input_args[kInputIndex2]->cast(); - MS_EXCEPTION_IF_NULL(num_sample_ptr); - num_segments_v = GetValue(input_args[kInputIndex2]->BuildValue()); - } else { - MS_EXCEPTION(TypeError) << "For '" << op_name << "' the third input build type is invalid:" - << TypeIdToString(num_segments_input_type->type_id()) << "."; - } - (void)CheckAndConvertUtils::CheckInteger("num_segments's value", num_segments_v, kGreaterThan, 0, op_name); - num_vec->push_back(num_segments_v); - } else { - MS_LOG(EXCEPTION) << "For '" << op_name - << "', the third input type should be tensor or scalar, but got invalid abstract type:" - << input_args[kInputIndex2]->type_name() << "."; - } -} - abstract::ShapePtr UnsortedSegmentSumInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); @@ -108,7 +65,7 @@ abstract::ShapePtr UnsortedSegmentSumInferShape(const PrimitivePtr &primitive, abstract::CheckShapeAnyAndPositive(op_name + " segment_ids_shape", segment_ids_shape); ShapeVector num_vec; - GetNumSegmentsValue(primitive, input_args, &num_vec); + num_vec.push_back(GetNumSegmentsValue(primitive, input_args)); int64_t batch_rank = 0; if (primitive->HasAttr(kBatchRank)) { auto batch_rank_ptr = primitive->GetAttr(kBatchRank);