!44049 unattr for unsorted segment sum & prod

Merge pull request !44049 from yangsijia/unsortedseg-unattr
This commit is contained in:
i-robot 2022-11-01 01:47:51 +00:00 committed by Gitee
commit 7879137afd
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 49 additions and 64 deletions

View File

@ -67,8 +67,6 @@ RER_GPU_STATIC_CONST_TO_ATTR(kSubscalarOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kTensorCopySlicesOpName, 2, 3, 4);
RER_GPU_STATIC_CONST_TO_ATTR(kTileOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kTransposeOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kUnsortedSegmentProdOpName, 2);
RER_GPU_STATIC_CONST_TO_ATTR(kUnsortedSegmentSumOpName, 2);
} // namespace mindspore::opt
#endif // MINDSPORE_CCSRC_PLUGIN_GPU_OPTIMIZER_REG_GPU_CONST_INPUT_TO_ATTR_H_

View File

@ -17,6 +17,7 @@
#include "ops/unsorted_segment_arithmetic.h"
#include <memory>
#include <set>
#include <string>
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_utils.h"
@ -27,6 +28,48 @@
namespace mindspore {
namespace ops {
int64_t GetNumSegmentsValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const std::string &op_name = primitive->name();
int64_t num_segments_v = 0;
if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>()) {
auto n_value_ptr = input_args[kInputIndex2]->BuildValue();
MS_EXCEPTION_IF_NULL(n_value_ptr);
if (n_value_ptr->isa<tensor::Tensor>()) {
auto n_tensor_ptr = n_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(n_tensor_ptr);
num_segments_v = n_tensor_ptr->data_type() == kNumberTypeInt32 ? *static_cast<int32_t *>(n_tensor_ptr->data_c())
: *static_cast<int64_t *>(n_tensor_ptr->data_c());
(void)CheckAndConvertUtils::CheckInteger("num_segments's value", num_segments_v, kGreaterThan, 0, op_name);
return num_segments_v;
} else {
auto n_abstract_tensor = input_args[kInputIndex2]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(n_abstract_tensor);
return -1;
}
} else if (input_args[kInputIndex2]->isa<abstract::AbstractScalar>()) {
auto num_segments_input_type = input_args[kInputIndex2]->BuildType();
if (num_segments_input_type->type_id() == kNumberTypeInt64) {
auto num_sample_ptr = input_args[kInputIndex2]->cast<abstract::AbstractScalarPtr>();
MS_EXCEPTION_IF_NULL(num_sample_ptr);
num_segments_v = GetValue<int64_t>(input_args[kInputIndex2]->BuildValue());
} else if (num_segments_input_type->type_id() == kNumberTypeInt32) {
auto num_sample_ptr = input_args[kInputIndex2]->cast<abstract::AbstractScalarPtr>();
MS_EXCEPTION_IF_NULL(num_sample_ptr);
num_segments_v = GetValue<int32_t>(input_args[kInputIndex2]->BuildValue());
} else {
MS_EXCEPTION(TypeError) << "For '" << op_name << "' the third input build type is invalid:"
<< TypeIdToString(num_segments_input_type->type_id()) << ".";
}
(void)CheckAndConvertUtils::CheckInteger("num_segments's value", num_segments_v, kGreaterThan, 0, op_name);
return num_segments_v;
} else {
MS_LOG(EXCEPTION) << "For '" << op_name
<< "', the third input type should be tensor or scalar, but got invalid abstract type:"
<< input_args[kInputIndex2]->type_name() << ".";
}
return num_segments_v;
}
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
const int64_t &num_segments_value) {
@ -92,26 +135,11 @@ abstract::ShapePtr UnsortedSegmentArithmeticInferShape(const PrimitivePtr &primi
auto num_segments_shape_ptr = input_args[kInputIndex2]->BuildShape();
MS_EXCEPTION_IF_NULL(num_segments_shape_ptr);
auto num_segments = input_args[kInputIndex2]->cast<abstract::AbstractScalarPtr>();
int64_t num_segments_value = 0;
if (num_segments != nullptr && num_segments->BuildValue() != kAnyValue) {
num_segments_value = GetValue<int64_t>(num_segments->BuildValue());
(void)primitive->AddAttr(kNumSegments, MakeValue(num_segments_value));
}
if (x_shape_ptr->IsDynamic() || segment_ids_shape_ptr->IsDynamic() || num_segments_shape_ptr->IsDynamic()) {
return x_shape_ptr->cast<abstract::ShapePtr>();
}
if (num_segments == nullptr || num_segments->BuildValue() == kAnyValue) {
auto value_ptr = primitive->GetAttr(kNumSegments);
if (value_ptr != nullptr) {
num_segments_value = GetValue<int64_t>(value_ptr);
} else {
return x_shape_ptr->cast<abstract::ShapePtr>();
}
}
int64_t num_segments_value = GetNumSegmentsValue(primitive, input_args);
if (num_segments_value <= 0) {
MS_EXCEPTION(ValueError) << "For '" << prim_name
<< "', num_segments value must be greater than 0, but got: " << num_segments_value << ".";

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -25,6 +25,8 @@ namespace ops {
abstract::AbstractBasePtr UnsortedSegmentArithmeticInfer(const abstract::AnalysisEnginePtr &,
const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
int64_t GetNumSegmentsValue(const PrimitivePtr &primitive, const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore

View File

@ -25,54 +25,11 @@
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "ops/unsorted_segment_arithmetic.h"
namespace mindspore {
namespace ops {
namespace {
void GetNumSegmentsValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
ShapeVector *num_vec) {
MS_EXCEPTION_IF_NULL(primitive);
const std::string &op_name = primitive->name();
int64_t num_segments_v;
// num_segments is a Tensor when UnsortedSegmentSum is a dynamic shape operator
if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>()) {
auto n_value_ptr = input_args[kInputIndex2]->BuildValue();
MS_EXCEPTION_IF_NULL(n_value_ptr);
if (n_value_ptr->isa<tensor::Tensor>()) {
auto n_tensor_ptr = n_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(n_tensor_ptr);
num_segments_v = n_tensor_ptr->data_type() == kNumberTypeInt32 ? *static_cast<int32_t *>(n_tensor_ptr->data_c())
: *static_cast<int64_t *>(n_tensor_ptr->data_c());
(void)CheckAndConvertUtils::CheckInteger("num_segments's value", num_segments_v, kGreaterThan, 0, op_name);
num_vec->push_back(num_segments_v);
} else {
auto n_abstract_tensor = input_args[kInputIndex2]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(n_abstract_tensor);
num_vec->push_back(-1);
}
} else if (input_args[kInputIndex2]->isa<abstract::AbstractScalar>()) {
auto num_segments_input_type = input_args[kInputIndex2]->BuildType();
if (num_segments_input_type->type_id() == kNumberTypeInt64) {
auto num_sample_ptr = input_args[kInputIndex2]->cast<abstract::AbstractScalarPtr>();
MS_EXCEPTION_IF_NULL(num_sample_ptr);
num_segments_v = GetValue<int64_t>(input_args[kInputIndex2]->BuildValue());
} else if (num_segments_input_type->type_id() == kNumberTypeInt32) {
auto num_sample_ptr = input_args[kInputIndex2]->cast<abstract::AbstractScalarPtr>();
MS_EXCEPTION_IF_NULL(num_sample_ptr);
num_segments_v = GetValue<int32_t>(input_args[kInputIndex2]->BuildValue());
} else {
MS_EXCEPTION(TypeError) << "For '" << op_name << "' the third input build type is invalid:"
<< TypeIdToString(num_segments_input_type->type_id()) << ".";
}
(void)CheckAndConvertUtils::CheckInteger("num_segments's value", num_segments_v, kGreaterThan, 0, op_name);
num_vec->push_back(num_segments_v);
} else {
MS_LOG(EXCEPTION) << "For '" << op_name
<< "', the third input type should be tensor or scalar, but got invalid abstract type:"
<< input_args[kInputIndex2]->type_name() << ".";
}
}
abstract::ShapePtr UnsortedSegmentSumInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
@ -108,7 +65,7 @@ abstract::ShapePtr UnsortedSegmentSumInferShape(const PrimitivePtr &primitive,
abstract::CheckShapeAnyAndPositive(op_name + " segment_ids_shape", segment_ids_shape);
ShapeVector num_vec;
GetNumSegmentsValue(primitive, input_args, &num_vec);
num_vec.push_back(GetNumSegmentsValue(primitive, input_args));
int64_t batch_rank = 0;
if (primitive->HasAttr(kBatchRank)) {
auto batch_rank_ptr = primitive->GetAttr(kBatchRank);