diff --git a/mindspore/core/ops/unsorted_segment_sum.cc b/mindspore/core/ops/unsorted_segment_sum.cc index 3d5b39ebef3..145dce8fb96 100644 --- a/mindspore/core/ops/unsorted_segment_sum.cc +++ b/mindspore/core/ops/unsorted_segment_sum.cc @@ -30,7 +30,7 @@ namespace mindspore { namespace ops { namespace { void GetNumSegmentsValue(const PrimitivePtr &primitive, const std::vector &input_args, - ShapeVector *num_vec, ShapeVector *num_min_vec, ShapeVector *num_max_vec) { + ShapeVector *num_vec) { MS_EXCEPTION_IF_NULL(primitive); const std::string &op_name = primitive->name(); int64_t num_segments_v; @@ -45,18 +45,10 @@ void GetNumSegmentsValue(const PrimitivePtr &primitive, const std::vector(n_tensor_ptr->data_c()); (void)CheckAndConvertUtils::CheckInteger("num_segments's value", num_segments_v, kGreaterThan, 0, op_name); num_vec->push_back(num_segments_v); - num_min_vec->push_back(num_segments_v); - num_max_vec->push_back(num_segments_v); } else { auto n_abstract_tensor = input_args[kInputIndex2]->cast(); MS_EXCEPTION_IF_NULL(n_abstract_tensor); num_vec->push_back(-1); - auto num_min_value = n_abstract_tensor->get_min_value(); - auto num_max_value = n_abstract_tensor->get_max_value(); - if (num_min_value != nullptr && num_max_value != nullptr) { - *num_min_vec = GetValue(num_min_value); - *num_max_vec = GetValue(num_max_value); - } } } else if (input_args[kInputIndex2]->isa()) { auto num_segments_input_type = input_args[kInputIndex2]->BuildType(); @@ -74,8 +66,6 @@ void GetNumSegmentsValue(const PrimitivePtr &primitive, const std::vectorpush_back(num_segments_v); - num_min_vec->push_back(num_segments_v); - num_max_vec->push_back(num_segments_v); } else { MS_LOG(EXCEPTION) << "For '" << op_name << "', the third input type should be tensor or scalar, but got invalid abstract type:" @@ -88,8 +78,6 @@ abstract::ShapePtr UnsortedSegmentSumInferShape(const PrimitivePtr &primitive, MS_EXCEPTION_IF_NULL(primitive); const std::string &op_name = primitive->name(); auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; - auto x_min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMinShape]; - auto x_max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMaxShape]; auto x_shape_rank = SizeToLong(x_shape.size()); (void)CheckAndConvertUtils::CheckInteger("input_x size", x_shape_rank, kGreaterThan, 0, op_name); auto segment_ids_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; @@ -123,9 +111,7 @@ abstract::ShapePtr UnsortedSegmentSumInferShape(const PrimitivePtr &primitive, abstract::CheckShapeAnyAndPositive(op_name + " segment_ids_shape", segment_ids_shape); ShapeVector num_vec; - ShapeVector num_min_vec; - ShapeVector num_max_vec; - GetNumSegmentsValue(primitive, input_args, &num_vec, &num_min_vec, &num_max_vec); + GetNumSegmentsValue(primitive, input_args, &num_vec); int64_t batch_rank = 0; if (primitive->HasAttr(kBatchRank)) { auto batch_rank_ptr = primitive->GetAttr(kBatchRank); @@ -140,21 +126,9 @@ abstract::ShapePtr UnsortedSegmentSumInferShape(const PrimitivePtr &primitive, (void)copy(x_shape.begin() + segment_ids_shape_rank, x_shape.end(), std::back_inserter(out_vec)); return out_vec; }; - ShapeVector out_min_shape; - ShapeVector out_max_shape; + auto out_vec = calc_shape(num_vec, x_shape); (void)copy(out_vec.begin(), out_vec.end(), std::back_inserter(output_shape)); - auto output_shape_rank = SizeToLong(output_shape.size()); - auto out_max_shape_rank = SizeToLong(calc_shape(num_max_vec, x_max_shape).size()); - bool x_min_any_shape = - std::any_of(x_min_shape.begin(), x_min_shape.end(), [](int64_t dim) { return dim == abstract::Shape::SHP_ANY; }); - bool x_max_any_shape = - std::any_of(x_max_shape.begin(), x_max_shape.end(), [](int64_t dim) { return dim == abstract::Shape::SHP_ANY; }); - if ((out_max_shape_rank == output_shape_rank) && (!x_min_any_shape && !x_max_any_shape)) { - out_min_shape = calc_shape(num_min_vec, x_min_shape); - out_max_shape = calc_shape(num_max_vec, x_max_shape); - return std::make_shared(output_shape, out_min_shape, out_max_shape); - } return std::make_shared(output_shape); }