forked from mindspore-Ecosystem/mindspore
slice infer
This commit is contained in:
parent
6d87fa6288
commit
68c4a9c00d
|
@ -100,12 +100,7 @@ abstract::ShapePtr SliceInferShape(const PrimitivePtr &primitive, const std::vec
|
||||||
MS_EXCEPTION(ValueError) << "For Slice, the shape of input|begin|size must be equal.";
|
MS_EXCEPTION(ValueError) << "For Slice, the shape of input|begin|size must be equal.";
|
||||||
}
|
}
|
||||||
(void)CheckAndConvertUtils::CheckPositiveVector("input_begin", input_begin_value, prim_name);
|
(void)CheckAndConvertUtils::CheckPositiveVector("input_begin", input_begin_value, prim_name);
|
||||||
bool is_dynamic = false;
|
|
||||||
for (size_t i = 0; i < rank; ++i) {
|
for (size_t i = 0; i < rank; ++i) {
|
||||||
if (input_x_shape[i] < 0) {
|
|
||||||
is_dynamic = true;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (input_begin_value[i] + input_size_value[i] > input_x_shape[i]) {
|
if (input_begin_value[i] + input_size_value[i] > input_x_shape[i]) {
|
||||||
MS_EXCEPTION(ValueError) << "For Slice, the sum of begin_shape[" << i << "] and size_shape[" << i
|
MS_EXCEPTION(ValueError) << "For Slice, the sum of begin_shape[" << i << "] and size_shape[" << i
|
||||||
<< "] must be no greater than input_x_shape[" << i << "].";
|
<< "] must be no greater than input_x_shape[" << i << "].";
|
||||||
|
@ -116,11 +111,7 @@ abstract::ShapePtr SliceInferShape(const PrimitivePtr &primitive, const std::vec
|
||||||
out_shape_min[i] = input_size_value[i];
|
out_shape_min[i] = input_size_value[i];
|
||||||
out_shape_max[i] = input_size_value[i];
|
out_shape_max[i] = input_size_value[i];
|
||||||
}
|
}
|
||||||
if (!is_dynamic) {
|
return std::make_shared<abstract::Shape>(input_size_value);
|
||||||
return std::make_shared<abstract::Shape>(input_size_value);
|
|
||||||
} else {
|
|
||||||
return std::make_shared<abstract::Shape>(input_size_value, out_shape_min, out_shape_max);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TypePtr SliceInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
TypePtr SliceInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
|
Loading…
Reference in New Issue