fix infer bug
This commit is contained in:
parent
e61513cc86
commit
47667c7a9a
|
@ -66,57 +66,16 @@ std::vector<int64_t> InferImplSliceFuncCalInputValue(const PrimitivePtr &primiti
|
|||
return tmp_input;
|
||||
}
|
||||
|
||||
ShapeVector GetOutputShape(const ShapeVector &input_size_shape, const ShapeVector &input_begin_shape,
|
||||
const ShapeVector &input_x_shape, bool is_inputx_dyn) {
|
||||
ShapeVector out_shape = {};
|
||||
bool is_size_dyn_rank = IsDynamicRank(input_size_shape);
|
||||
if (is_size_dyn_rank) {
|
||||
out_shape.push_back(kDynamicOutValue);
|
||||
return out_shape;
|
||||
}
|
||||
|
||||
bool is_size_dyn_shape = IsDynamic(input_size_shape);
|
||||
if (is_size_dyn_shape) {
|
||||
if (!is_inputx_dyn) {
|
||||
for (uint32_t i = 0; i < input_x_shape.size(); i++) {
|
||||
out_shape.push_back(-1);
|
||||
}
|
||||
return out_shape;
|
||||
}
|
||||
|
||||
bool is_begin_dynamic_rank = IsDynamicRank(input_begin_shape);
|
||||
if (is_begin_dynamic_rank) {
|
||||
out_shape.push_back(kDynamicOutValue);
|
||||
return out_shape;
|
||||
}
|
||||
|
||||
bool is_begin_dyn_shape = IsDynamic(input_begin_shape);
|
||||
if (is_begin_dyn_shape) {
|
||||
for (uint32_t i = 0; i < input_begin_shape.size(); i++) {
|
||||
out_shape.push_back(-1);
|
||||
}
|
||||
return out_shape;
|
||||
}
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < input_size_shape[0]; i++) {
|
||||
out_shape.push_back(-1);
|
||||
}
|
||||
return out_shape;
|
||||
}
|
||||
|
||||
abstract::ShapePtr SliceInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(input_args.size() == kSliceInputNum, "Slice inputs num error");
|
||||
auto input_x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||
auto input_begin_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape());
|
||||
auto input_size_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape());
|
||||
auto input_x_shape = input_x_shape_map[kShape];
|
||||
auto input_begin_value_ptr = input_args[kInputIndex1]->BuildValue();
|
||||
auto input_size_value_ptr = input_args[kInputIndex2]->BuildValue();
|
||||
auto input_begin_shape = input_begin_shape_map[kShape];
|
||||
auto input_size_shape = input_size_shape_map[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger("rank of input_x", SizeToLong(input_x_shape.size()), kGreaterThan, 0,
|
||||
prim_name);
|
||||
|
||||
|
@ -135,16 +94,7 @@ abstract::ShapePtr SliceInferShape(const PrimitivePtr &primitive, const std::vec
|
|||
}
|
||||
|
||||
if (!IsValueKnown(input_size_value_ptr)) {
|
||||
if (input_size_shape.size() == 0) {
|
||||
out_shape.push_back(kDynamicOutValue);
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
if (input_size_shape[0] < kDynamicOutValue) {
|
||||
MS_EXCEPTION(ValueError) << "For Slice, check input_size_shape failed.";
|
||||
}
|
||||
|
||||
out_shape = GetOutputShape(input_size_shape, input_begin_shape, input_x_shape, is_inputx_dyn);
|
||||
out_shape = GetShapeValue(primitive, input_args[kInputIndex2]);
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue