!40256 delete unsorted segment sum max and min shape

Merge pull request !40256 from zhangdong/static_check
This commit is contained in:
i-robot 2022-08-12 06:47:54 +00:00 committed by Gitee
commit 877581a07e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 3 additions and 29 deletions

View File

@ -30,7 +30,7 @@ namespace mindspore {
namespace ops {
namespace {
void GetNumSegmentsValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
ShapeVector *num_vec, ShapeVector *num_min_vec, ShapeVector *num_max_vec) {
ShapeVector *num_vec) {
MS_EXCEPTION_IF_NULL(primitive);
const std::string &op_name = primitive->name();
int64_t num_segments_v;
@ -45,18 +45,10 @@ void GetNumSegmentsValue(const PrimitivePtr &primitive, const std::vector<Abstra
: *static_cast<int64_t *>(n_tensor_ptr->data_c());
(void)CheckAndConvertUtils::CheckInteger("num_segments's value", num_segments_v, kGreaterThan, 0, op_name);
num_vec->push_back(num_segments_v);
num_min_vec->push_back(num_segments_v);
num_max_vec->push_back(num_segments_v);
} else {
auto n_abstract_tensor = input_args[kInputIndex2]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(n_abstract_tensor);
num_vec->push_back(-1);
auto num_min_value = n_abstract_tensor->get_min_value();
auto num_max_value = n_abstract_tensor->get_max_value();
if (num_min_value != nullptr && num_max_value != nullptr) {
*num_min_vec = GetValue<ShapeVector>(num_min_value);
*num_max_vec = GetValue<ShapeVector>(num_max_value);
}
}
} else if (input_args[kInputIndex2]->isa<abstract::AbstractScalar>()) {
auto num_segments_input_type = input_args[kInputIndex2]->BuildType();
@ -74,8 +66,6 @@ void GetNumSegmentsValue(const PrimitivePtr &primitive, const std::vector<Abstra
}
(void)CheckAndConvertUtils::CheckInteger("num_segments's value", num_segments_v, kGreaterThan, 0, op_name);
num_vec->push_back(num_segments_v);
num_min_vec->push_back(num_segments_v);
num_max_vec->push_back(num_segments_v);
} else {
MS_LOG(EXCEPTION) << "For '" << op_name
<< "', the third input type should be tensor or scalar, but got invalid abstract type:"
@ -88,8 +78,6 @@ abstract::ShapePtr UnsortedSegmentSumInferShape(const PrimitivePtr &primitive,
MS_EXCEPTION_IF_NULL(primitive);
const std::string &op_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto x_min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMinShape];
auto x_max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMaxShape];
auto x_shape_rank = SizeToLong(x_shape.size());
(void)CheckAndConvertUtils::CheckInteger("input_x size", x_shape_rank, kGreaterThan, 0, op_name);
auto segment_ids_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
@ -123,9 +111,7 @@ abstract::ShapePtr UnsortedSegmentSumInferShape(const PrimitivePtr &primitive,
abstract::CheckShapeAnyAndPositive(op_name + " segment_ids_shape", segment_ids_shape);
ShapeVector num_vec;
ShapeVector num_min_vec;
ShapeVector num_max_vec;
GetNumSegmentsValue(primitive, input_args, &num_vec, &num_min_vec, &num_max_vec);
GetNumSegmentsValue(primitive, input_args, &num_vec);
int64_t batch_rank = 0;
if (primitive->HasAttr(kBatchRank)) {
auto batch_rank_ptr = primitive->GetAttr(kBatchRank);
@ -140,21 +126,9 @@ abstract::ShapePtr UnsortedSegmentSumInferShape(const PrimitivePtr &primitive,
(void)copy(x_shape.begin() + segment_ids_shape_rank, x_shape.end(), std::back_inserter(out_vec));
return out_vec;
};
ShapeVector out_min_shape;
ShapeVector out_max_shape;
auto out_vec = calc_shape(num_vec, x_shape);
(void)copy(out_vec.begin(), out_vec.end(), std::back_inserter(output_shape));
auto output_shape_rank = SizeToLong(output_shape.size());
auto out_max_shape_rank = SizeToLong(calc_shape(num_max_vec, x_max_shape).size());
bool x_min_any_shape =
std::any_of(x_min_shape.begin(), x_min_shape.end(), [](int64_t dim) { return dim == abstract::Shape::SHP_ANY; });
bool x_max_any_shape =
std::any_of(x_max_shape.begin(), x_max_shape.end(), [](int64_t dim) { return dim == abstract::Shape::SHP_ANY; });
if ((out_max_shape_rank == output_shape_rank) && (!x_min_any_shape && !x_max_any_shape)) {
out_min_shape = calc_shape(num_min_vec, x_min_shape);
out_max_shape = calc_shape(num_max_vec, x_max_shape);
return std::make_shared<abstract::Shape>(output_shape, out_min_shape, out_max_shape);
}
return std::make_shared<abstract::Shape>(output_shape);
}