slice infer

This commit is contained in:
Henry Shi 2022-07-29 00:16:09 +08:00
parent 6d87fa6288
commit 68c4a9c00d
1 changed files with 1 additions and 10 deletions

View File

@ -100,12 +100,7 @@ abstract::ShapePtr SliceInferShape(const PrimitivePtr &primitive, const std::vec
MS_EXCEPTION(ValueError) << "For Slice, the shape of input|begin|size must be equal.";
}
(void)CheckAndConvertUtils::CheckPositiveVector("input_begin", input_begin_value, prim_name);
bool is_dynamic = false;
for (size_t i = 0; i < rank; ++i) {
if (input_x_shape[i] < 0) {
is_dynamic = true;
continue;
}
if (input_begin_value[i] + input_size_value[i] > input_x_shape[i]) {
MS_EXCEPTION(ValueError) << "For Slice, the sum of begin_shape[" << i << "] and size_shape[" << i
<< "] must be no greater than input_x_shape[" << i << "].";
@ -116,11 +111,7 @@ abstract::ShapePtr SliceInferShape(const PrimitivePtr &primitive, const std::vec
out_shape_min[i] = input_size_value[i];
out_shape_max[i] = input_size_value[i];
}
if (!is_dynamic) {
return std::make_shared<abstract::Shape>(input_size_value);
} else {
return std::make_shared<abstract::Shape>(input_size_value, out_shape_min, out_shape_max);
}
}
TypePtr SliceInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {