fix infer bug

This commit is contained in:
ckey_Dou 2023-02-28 10:23:36 +08:00
parent e61513cc86
commit 47667c7a9a
1 changed files with 1 additions and 51 deletions

View File

@ -66,57 +66,16 @@ std::vector<int64_t> InferImplSliceFuncCalInputValue(const PrimitivePtr &primiti
return tmp_input;
}
ShapeVector GetOutputShape(const ShapeVector &input_size_shape, const ShapeVector &input_begin_shape,
const ShapeVector &input_x_shape, bool is_inputx_dyn) {
ShapeVector out_shape = {};
bool is_size_dyn_rank = IsDynamicRank(input_size_shape);
if (is_size_dyn_rank) {
out_shape.push_back(kDynamicOutValue);
return out_shape;
}
bool is_size_dyn_shape = IsDynamic(input_size_shape);
if (is_size_dyn_shape) {
if (!is_inputx_dyn) {
for (uint32_t i = 0; i < input_x_shape.size(); i++) {
out_shape.push_back(-1);
}
return out_shape;
}
bool is_begin_dynamic_rank = IsDynamicRank(input_begin_shape);
if (is_begin_dynamic_rank) {
out_shape.push_back(kDynamicOutValue);
return out_shape;
}
bool is_begin_dyn_shape = IsDynamic(input_begin_shape);
if (is_begin_dyn_shape) {
for (uint32_t i = 0; i < input_begin_shape.size(); i++) {
out_shape.push_back(-1);
}
return out_shape;
}
}
for (int64_t i = 0; i < input_size_shape[0]; i++) {
out_shape.push_back(-1);
}
return out_shape;
}
abstract::ShapePtr SliceInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
MS_EXCEPTION_IF_CHECK_FAIL(input_args.size() == kSliceInputNum, "Slice inputs num error");
auto input_x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
auto input_begin_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape());
auto input_size_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape());
auto input_x_shape = input_x_shape_map[kShape];
auto input_begin_value_ptr = input_args[kInputIndex1]->BuildValue();
auto input_size_value_ptr = input_args[kInputIndex2]->BuildValue();
auto input_begin_shape = input_begin_shape_map[kShape];
auto input_size_shape = input_size_shape_map[kShape];
(void)CheckAndConvertUtils::CheckInteger("rank of input_x", SizeToLong(input_x_shape.size()), kGreaterThan, 0,
prim_name);
@ -135,16 +94,7 @@ abstract::ShapePtr SliceInferShape(const PrimitivePtr &primitive, const std::vec
}
if (!IsValueKnown(input_size_value_ptr)) {
if (input_size_shape.size() == 0) {
out_shape.push_back(kDynamicOutValue);
return std::make_shared<abstract::Shape>(out_shape);
}
if (input_size_shape[0] < kDynamicOutValue) {
MS_EXCEPTION(ValueError) << "For Slice, check input_size_shape failed.";
}
out_shape = GetOutputShape(input_size_shape, input_begin_shape, input_x_shape, is_inputx_dyn);
out_shape = GetShapeValue(primitive, input_args[kInputIndex2]);
return std::make_shared<abstract::Shape>(out_shape);
}