diff --git a/mindspore/core/ops/slice.cc b/mindspore/core/ops/slice.cc index 9035aefe837..a27441bcb7d 100644 --- a/mindspore/core/ops/slice.cc +++ b/mindspore/core/ops/slice.cc @@ -49,7 +49,6 @@ namespace mindspore { namespace ops { namespace { constexpr size_t kSliceInputNum = 3; -constexpr int64_t kDynamicOutValue = -2; std::vector InferImplSliceFuncCalInputValue(const PrimitivePtr &primitive, const ValuePtr &input_value) { std::vector tmp_input; MS_EXCEPTION_IF_NULL(input_value); @@ -94,6 +93,14 @@ abstract::ShapePtr SliceInferShape(const PrimitivePtr &primitive, const std::vec } if (!IsValueKnown(input_size_value_ptr)) { + auto arg = input_args[kInputIndex2]; + if (arg->isa()) { + auto abs_tensor = arg->cast(); + auto tensor_shape = abs_tensor->shape()->shape(); + if (tensor_shape.size() != 1) { + MS_EXCEPTION(ValueError) << "For Slice, the shape of input|begin|size must be equal."; + } + } out_shape = GetShapeValue(primitive, input_args[kInputIndex2]); return std::make_shared(out_shape); } @@ -109,6 +116,11 @@ abstract::ShapePtr SliceInferShape(const PrimitivePtr &primitive, const std::vec if (input_x_shape[i] < 0) { continue; } + if (input_size_value[i] < -1) { + MS_EXCEPTION(RuntimeError) << "For Slice, the value in size should not be less than -1, but got " + << input_size_value[i]; + } + if (input_begin_value[i] + input_size_value[i] > input_x_shape[i]) { MS_EXCEPTION(ValueError) << "For Slice, the sum of begin_shape[" << i << "] and size_shape[" << i << "] must be no greater than input_x_shape[" << i << "].";