forked from mindspore-Ecosystem/mindspore
!40256 delete unsorted segment sum max and min shape
Merge pull request !40256 from zhangdong/static_check
This commit is contained in:
commit
877581a07e
|
@ -30,7 +30,7 @@ namespace mindspore {
|
|||
namespace ops {
|
||||
namespace {
|
||||
void GetNumSegmentsValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
|
||||
ShapeVector *num_vec, ShapeVector *num_min_vec, ShapeVector *num_max_vec) {
|
||||
ShapeVector *num_vec) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const std::string &op_name = primitive->name();
|
||||
int64_t num_segments_v;
|
||||
|
@ -45,18 +45,10 @@ void GetNumSegmentsValue(const PrimitivePtr &primitive, const std::vector<Abstra
|
|||
: *static_cast<int64_t *>(n_tensor_ptr->data_c());
|
||||
(void)CheckAndConvertUtils::CheckInteger("num_segments's value", num_segments_v, kGreaterThan, 0, op_name);
|
||||
num_vec->push_back(num_segments_v);
|
||||
num_min_vec->push_back(num_segments_v);
|
||||
num_max_vec->push_back(num_segments_v);
|
||||
} else {
|
||||
auto n_abstract_tensor = input_args[kInputIndex2]->cast<abstract::AbstractTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(n_abstract_tensor);
|
||||
num_vec->push_back(-1);
|
||||
auto num_min_value = n_abstract_tensor->get_min_value();
|
||||
auto num_max_value = n_abstract_tensor->get_max_value();
|
||||
if (num_min_value != nullptr && num_max_value != nullptr) {
|
||||
*num_min_vec = GetValue<ShapeVector>(num_min_value);
|
||||
*num_max_vec = GetValue<ShapeVector>(num_max_value);
|
||||
}
|
||||
}
|
||||
} else if (input_args[kInputIndex2]->isa<abstract::AbstractScalar>()) {
|
||||
auto num_segments_input_type = input_args[kInputIndex2]->BuildType();
|
||||
|
@ -74,8 +66,6 @@ void GetNumSegmentsValue(const PrimitivePtr &primitive, const std::vector<Abstra
|
|||
}
|
||||
(void)CheckAndConvertUtils::CheckInteger("num_segments's value", num_segments_v, kGreaterThan, 0, op_name);
|
||||
num_vec->push_back(num_segments_v);
|
||||
num_min_vec->push_back(num_segments_v);
|
||||
num_max_vec->push_back(num_segments_v);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "For '" << op_name
|
||||
<< "', the third input type should be tensor or scalar, but got invalid abstract type:"
|
||||
|
@ -88,8 +78,6 @@ abstract::ShapePtr UnsortedSegmentSumInferShape(const PrimitivePtr &primitive,
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const std::string &op_name = primitive->name();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto x_min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMinShape];
|
||||
auto x_max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMaxShape];
|
||||
auto x_shape_rank = SizeToLong(x_shape.size());
|
||||
(void)CheckAndConvertUtils::CheckInteger("input_x size", x_shape_rank, kGreaterThan, 0, op_name);
|
||||
auto segment_ids_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
|
@ -123,9 +111,7 @@ abstract::ShapePtr UnsortedSegmentSumInferShape(const PrimitivePtr &primitive,
|
|||
abstract::CheckShapeAnyAndPositive(op_name + " segment_ids_shape", segment_ids_shape);
|
||||
|
||||
ShapeVector num_vec;
|
||||
ShapeVector num_min_vec;
|
||||
ShapeVector num_max_vec;
|
||||
GetNumSegmentsValue(primitive, input_args, &num_vec, &num_min_vec, &num_max_vec);
|
||||
GetNumSegmentsValue(primitive, input_args, &num_vec);
|
||||
int64_t batch_rank = 0;
|
||||
if (primitive->HasAttr(kBatchRank)) {
|
||||
auto batch_rank_ptr = primitive->GetAttr(kBatchRank);
|
||||
|
@ -140,21 +126,9 @@ abstract::ShapePtr UnsortedSegmentSumInferShape(const PrimitivePtr &primitive,
|
|||
(void)copy(x_shape.begin() + segment_ids_shape_rank, x_shape.end(), std::back_inserter(out_vec));
|
||||
return out_vec;
|
||||
};
|
||||
ShapeVector out_min_shape;
|
||||
ShapeVector out_max_shape;
|
||||
|
||||
auto out_vec = calc_shape(num_vec, x_shape);
|
||||
(void)copy(out_vec.begin(), out_vec.end(), std::back_inserter(output_shape));
|
||||
auto output_shape_rank = SizeToLong(output_shape.size());
|
||||
auto out_max_shape_rank = SizeToLong(calc_shape(num_max_vec, x_max_shape).size());
|
||||
bool x_min_any_shape =
|
||||
std::any_of(x_min_shape.begin(), x_min_shape.end(), [](int64_t dim) { return dim == abstract::Shape::SHP_ANY; });
|
||||
bool x_max_any_shape =
|
||||
std::any_of(x_max_shape.begin(), x_max_shape.end(), [](int64_t dim) { return dim == abstract::Shape::SHP_ANY; });
|
||||
if ((out_max_shape_rank == output_shape_rank) && (!x_min_any_shape && !x_max_any_shape)) {
|
||||
out_min_shape = calc_shape(num_min_vec, x_min_shape);
|
||||
out_max_shape = calc_shape(num_max_vec, x_max_shape);
|
||||
return std::make_shared<abstract::Shape>(output_shape, out_min_shape, out_max_shape);
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(output_shape);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue