From 47667c7a9ac8cb23859ec0c9781ee9bb3bec4736 Mon Sep 17 00:00:00 2001 From: ckey_Dou Date: Tue, 28 Feb 2023 10:23:36 +0800 Subject: [PATCH] fix infer bug --- mindspore/core/ops/slice.cc | 52 +------------------------------------ 1 file changed, 1 insertion(+), 51 deletions(-) diff --git a/mindspore/core/ops/slice.cc b/mindspore/core/ops/slice.cc index fee90718ec4..9035aefe837 100644 --- a/mindspore/core/ops/slice.cc +++ b/mindspore/core/ops/slice.cc @@ -66,57 +66,16 @@ std::vector InferImplSliceFuncCalInputValue(const PrimitivePtr &primiti return tmp_input; } -ShapeVector GetOutputShape(const ShapeVector &input_size_shape, const ShapeVector &input_begin_shape, - const ShapeVector &input_x_shape, bool is_inputx_dyn) { - ShapeVector out_shape = {}; - bool is_size_dyn_rank = IsDynamicRank(input_size_shape); - if (is_size_dyn_rank) { - out_shape.push_back(kDynamicOutValue); - return out_shape; - } - - bool is_size_dyn_shape = IsDynamic(input_size_shape); - if (is_size_dyn_shape) { - if (!is_inputx_dyn) { - for (uint32_t i = 0; i < input_x_shape.size(); i++) { - out_shape.push_back(-1); - } - return out_shape; - } - - bool is_begin_dynamic_rank = IsDynamicRank(input_begin_shape); - if (is_begin_dynamic_rank) { - out_shape.push_back(kDynamicOutValue); - return out_shape; - } - - bool is_begin_dyn_shape = IsDynamic(input_begin_shape); - if (is_begin_dyn_shape) { - for (uint32_t i = 0; i < input_begin_shape.size(); i++) { - out_shape.push_back(-1); - } - return out_shape; - } - } - - for (int64_t i = 0; i < input_size_shape[0]; i++) { - out_shape.push_back(-1); - } - return out_shape; -} - abstract::ShapePtr SliceInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); MS_EXCEPTION_IF_CHECK_FAIL(input_args.size() == kSliceInputNum, "Slice inputs num error"); auto input_x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape()); auto input_begin_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape()); - auto input_size_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape()); auto input_x_shape = input_x_shape_map[kShape]; auto input_begin_value_ptr = input_args[kInputIndex1]->BuildValue(); auto input_size_value_ptr = input_args[kInputIndex2]->BuildValue(); auto input_begin_shape = input_begin_shape_map[kShape]; - auto input_size_shape = input_size_shape_map[kShape]; (void)CheckAndConvertUtils::CheckInteger("rank of input_x", SizeToLong(input_x_shape.size()), kGreaterThan, 0, prim_name); @@ -135,16 +94,7 @@ abstract::ShapePtr SliceInferShape(const PrimitivePtr &primitive, const std::vec } if (!IsValueKnown(input_size_value_ptr)) { - if (input_size_shape.size() == 0) { - out_shape.push_back(kDynamicOutValue); - return std::make_shared(out_shape); - } - - if (input_size_shape[0] < kDynamicOutValue) { - MS_EXCEPTION(ValueError) << "For Slice, check input_size_shape failed."; - } - - out_shape = GetOutputShape(input_size_shape, input_begin_shape, input_x_shape, is_inputx_dyn); + out_shape = GetShapeValue(primitive, input_args[kInputIndex2]); return std::make_shared(out_shape); }