slice infer

This commit is contained in:
Henry Shi 2022-07-29 00:16:09 +08:00
parent 6d87fa6288
commit 68c4a9c00d
1 changed files with 1 additions and 10 deletions

View File

@ -100,12 +100,7 @@ abstract::ShapePtr SliceInferShape(const PrimitivePtr &primitive, const std::vec
MS_EXCEPTION(ValueError) << "For Slice, the shape of input|begin|size must be equal."; MS_EXCEPTION(ValueError) << "For Slice, the shape of input|begin|size must be equal.";
} }
(void)CheckAndConvertUtils::CheckPositiveVector("input_begin", input_begin_value, prim_name); (void)CheckAndConvertUtils::CheckPositiveVector("input_begin", input_begin_value, prim_name);
bool is_dynamic = false;
for (size_t i = 0; i < rank; ++i) { for (size_t i = 0; i < rank; ++i) {
if (input_x_shape[i] < 0) {
is_dynamic = true;
continue;
}
if (input_begin_value[i] + input_size_value[i] > input_x_shape[i]) { if (input_begin_value[i] + input_size_value[i] > input_x_shape[i]) {
MS_EXCEPTION(ValueError) << "For Slice, the sum of begin_shape[" << i << "] and size_shape[" << i MS_EXCEPTION(ValueError) << "For Slice, the sum of begin_shape[" << i << "] and size_shape[" << i
<< "] must be no greater than input_x_shape[" << i << "]."; << "] must be no greater than input_x_shape[" << i << "].";
@ -116,11 +111,7 @@ abstract::ShapePtr SliceInferShape(const PrimitivePtr &primitive, const std::vec
out_shape_min[i] = input_size_value[i]; out_shape_min[i] = input_size_value[i];
out_shape_max[i] = input_size_value[i]; out_shape_max[i] = input_size_value[i];
} }
if (!is_dynamic) { return std::make_shared<abstract::Shape>(input_size_value);
return std::make_shared<abstract::Shape>(input_size_value);
} else {
return std::make_shared<abstract::Shape>(input_size_value, out_shape_min, out_shape_max);
}
} }
TypePtr SliceInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { TypePtr SliceInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {