forked from mindspore-Ecosystem/mindspore
!44049 unattr for unsorted segment sum & prod
Merge pull request !44049 from yangsijia/unsortedseg-unattr
This commit is contained in:
commit
7879137afd
|
@ -67,8 +67,6 @@ RER_GPU_STATIC_CONST_TO_ATTR(kSubscalarOpName, 1);
|
|||
RER_GPU_STATIC_CONST_TO_ATTR(kTensorCopySlicesOpName, 2, 3, 4);
|
||||
RER_GPU_STATIC_CONST_TO_ATTR(kTileOpName, 1);
|
||||
RER_GPU_STATIC_CONST_TO_ATTR(kTransposeOpName, 1);
|
||||
RER_GPU_STATIC_CONST_TO_ATTR(kUnsortedSegmentProdOpName, 2);
|
||||
RER_GPU_STATIC_CONST_TO_ATTR(kUnsortedSegmentSumOpName, 2);
|
||||
} // namespace mindspore::opt
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_GPU_OPTIMIZER_REG_GPU_CONST_INPUT_TO_ATTR_H_
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "ops/unsorted_segment_arithmetic.h"
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
@ -27,6 +28,48 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
int64_t GetNumSegmentsValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const std::string &op_name = primitive->name();
|
||||
int64_t num_segments_v = 0;
|
||||
if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>()) {
|
||||
auto n_value_ptr = input_args[kInputIndex2]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(n_value_ptr);
|
||||
if (n_value_ptr->isa<tensor::Tensor>()) {
|
||||
auto n_tensor_ptr = n_value_ptr->cast<tensor::TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(n_tensor_ptr);
|
||||
num_segments_v = n_tensor_ptr->data_type() == kNumberTypeInt32 ? *static_cast<int32_t *>(n_tensor_ptr->data_c())
|
||||
: *static_cast<int64_t *>(n_tensor_ptr->data_c());
|
||||
(void)CheckAndConvertUtils::CheckInteger("num_segments's value", num_segments_v, kGreaterThan, 0, op_name);
|
||||
return num_segments_v;
|
||||
} else {
|
||||
auto n_abstract_tensor = input_args[kInputIndex2]->cast<abstract::AbstractTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(n_abstract_tensor);
|
||||
return -1;
|
||||
}
|
||||
} else if (input_args[kInputIndex2]->isa<abstract::AbstractScalar>()) {
|
||||
auto num_segments_input_type = input_args[kInputIndex2]->BuildType();
|
||||
if (num_segments_input_type->type_id() == kNumberTypeInt64) {
|
||||
auto num_sample_ptr = input_args[kInputIndex2]->cast<abstract::AbstractScalarPtr>();
|
||||
MS_EXCEPTION_IF_NULL(num_sample_ptr);
|
||||
num_segments_v = GetValue<int64_t>(input_args[kInputIndex2]->BuildValue());
|
||||
} else if (num_segments_input_type->type_id() == kNumberTypeInt32) {
|
||||
auto num_sample_ptr = input_args[kInputIndex2]->cast<abstract::AbstractScalarPtr>();
|
||||
MS_EXCEPTION_IF_NULL(num_sample_ptr);
|
||||
num_segments_v = GetValue<int32_t>(input_args[kInputIndex2]->BuildValue());
|
||||
} else {
|
||||
MS_EXCEPTION(TypeError) << "For '" << op_name << "' the third input build type is invalid:"
|
||||
<< TypeIdToString(num_segments_input_type->type_id()) << ".";
|
||||
}
|
||||
(void)CheckAndConvertUtils::CheckInteger("num_segments's value", num_segments_v, kGreaterThan, 0, op_name);
|
||||
return num_segments_v;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "For '" << op_name
|
||||
<< "', the third input type should be tensor or scalar, but got invalid abstract type:"
|
||||
<< input_args[kInputIndex2]->type_name() << ".";
|
||||
}
|
||||
return num_segments_v;
|
||||
}
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
|
||||
const int64_t &num_segments_value) {
|
||||
|
@ -92,26 +135,11 @@ abstract::ShapePtr UnsortedSegmentArithmeticInferShape(const PrimitivePtr &primi
|
|||
auto num_segments_shape_ptr = input_args[kInputIndex2]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(num_segments_shape_ptr);
|
||||
|
||||
auto num_segments = input_args[kInputIndex2]->cast<abstract::AbstractScalarPtr>();
|
||||
int64_t num_segments_value = 0;
|
||||
|
||||
if (num_segments != nullptr && num_segments->BuildValue() != kAnyValue) {
|
||||
num_segments_value = GetValue<int64_t>(num_segments->BuildValue());
|
||||
(void)primitive->AddAttr(kNumSegments, MakeValue(num_segments_value));
|
||||
}
|
||||
|
||||
if (x_shape_ptr->IsDynamic() || segment_ids_shape_ptr->IsDynamic() || num_segments_shape_ptr->IsDynamic()) {
|
||||
return x_shape_ptr->cast<abstract::ShapePtr>();
|
||||
}
|
||||
|
||||
if (num_segments == nullptr || num_segments->BuildValue() == kAnyValue) {
|
||||
auto value_ptr = primitive->GetAttr(kNumSegments);
|
||||
if (value_ptr != nullptr) {
|
||||
num_segments_value = GetValue<int64_t>(value_ptr);
|
||||
} else {
|
||||
return x_shape_ptr->cast<abstract::ShapePtr>();
|
||||
}
|
||||
}
|
||||
int64_t num_segments_value = GetNumSegmentsValue(primitive, input_args);
|
||||
if (num_segments_value <= 0) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name
|
||||
<< "', num_segments value must be greater than 0, but got: " << num_segments_value << ".";
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -25,6 +25,8 @@ namespace ops {
|
|||
abstract::AbstractBasePtr UnsortedSegmentArithmeticInfer(const abstract::AnalysisEnginePtr &,
|
||||
const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
|
||||
int64_t GetNumSegmentsValue(const PrimitivePtr &primitive, const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -25,54 +25,11 @@
|
|||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "ops/unsorted_segment_arithmetic.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
void GetNumSegmentsValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
|
||||
ShapeVector *num_vec) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const std::string &op_name = primitive->name();
|
||||
int64_t num_segments_v;
|
||||
// num_segments is a Tensor when UnsortedSegmentSum is a dynamic shape operator
|
||||
if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>()) {
|
||||
auto n_value_ptr = input_args[kInputIndex2]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(n_value_ptr);
|
||||
if (n_value_ptr->isa<tensor::Tensor>()) {
|
||||
auto n_tensor_ptr = n_value_ptr->cast<tensor::TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(n_tensor_ptr);
|
||||
num_segments_v = n_tensor_ptr->data_type() == kNumberTypeInt32 ? *static_cast<int32_t *>(n_tensor_ptr->data_c())
|
||||
: *static_cast<int64_t *>(n_tensor_ptr->data_c());
|
||||
(void)CheckAndConvertUtils::CheckInteger("num_segments's value", num_segments_v, kGreaterThan, 0, op_name);
|
||||
num_vec->push_back(num_segments_v);
|
||||
} else {
|
||||
auto n_abstract_tensor = input_args[kInputIndex2]->cast<abstract::AbstractTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(n_abstract_tensor);
|
||||
num_vec->push_back(-1);
|
||||
}
|
||||
} else if (input_args[kInputIndex2]->isa<abstract::AbstractScalar>()) {
|
||||
auto num_segments_input_type = input_args[kInputIndex2]->BuildType();
|
||||
if (num_segments_input_type->type_id() == kNumberTypeInt64) {
|
||||
auto num_sample_ptr = input_args[kInputIndex2]->cast<abstract::AbstractScalarPtr>();
|
||||
MS_EXCEPTION_IF_NULL(num_sample_ptr);
|
||||
num_segments_v = GetValue<int64_t>(input_args[kInputIndex2]->BuildValue());
|
||||
} else if (num_segments_input_type->type_id() == kNumberTypeInt32) {
|
||||
auto num_sample_ptr = input_args[kInputIndex2]->cast<abstract::AbstractScalarPtr>();
|
||||
MS_EXCEPTION_IF_NULL(num_sample_ptr);
|
||||
num_segments_v = GetValue<int32_t>(input_args[kInputIndex2]->BuildValue());
|
||||
} else {
|
||||
MS_EXCEPTION(TypeError) << "For '" << op_name << "' the third input build type is invalid:"
|
||||
<< TypeIdToString(num_segments_input_type->type_id()) << ".";
|
||||
}
|
||||
(void)CheckAndConvertUtils::CheckInteger("num_segments's value", num_segments_v, kGreaterThan, 0, op_name);
|
||||
num_vec->push_back(num_segments_v);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "For '" << op_name
|
||||
<< "', the third input type should be tensor or scalar, but got invalid abstract type:"
|
||||
<< input_args[kInputIndex2]->type_name() << ".";
|
||||
}
|
||||
}
|
||||
|
||||
abstract::ShapePtr UnsortedSegmentSumInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
@ -108,7 +65,7 @@ abstract::ShapePtr UnsortedSegmentSumInferShape(const PrimitivePtr &primitive,
|
|||
abstract::CheckShapeAnyAndPositive(op_name + " segment_ids_shape", segment_ids_shape);
|
||||
|
||||
ShapeVector num_vec;
|
||||
GetNumSegmentsValue(primitive, input_args, &num_vec);
|
||||
num_vec.push_back(GetNumSegmentsValue(primitive, input_args));
|
||||
int64_t batch_rank = 0;
|
||||
if (primitive->HasAttr(kBatchRank)) {
|
||||
auto batch_rank_ptr = primitive->GetAttr(kBatchRank);
|
||||
|
|
Loading…
Reference in New Issue