diff --git a/mindspore/core/ops/batch_matmul.cc b/mindspore/core/ops/batch_matmul.cc index 9db244c0e84..eb96fc8da9c 100644 --- a/mindspore/core/ops/batch_matmul.cc +++ b/mindspore/core/ops/batch_matmul.cc @@ -36,8 +36,8 @@ abstract::ShapePtr BatchMatmulInferShape(const PrimitivePtr &primitive, auto y_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape()); auto x_shp = x_shape_map[kShape]; auto y_shp = y_shape_map[kShape]; - constexpr size_t x_dim_limit = 2; - constexpr size_t y_dim_limit = 3; + constexpr size_t x_dim_limit = 3; + constexpr size_t y_dim_limit = 2; if (x_shp.size() < x_dim_limit || y_shp.size() < y_dim_limit) { MS_EXCEPTION(ValueError) << "For BatchMatMul, input x should be greater or equal to 3, input y should be greater " "or equal to 2 while x size = " diff --git a/mindspore/core/ops/zeros.cc b/mindspore/core/ops/zeros.cc index 4f763ff9c67..b7bd0cbc854 100644 --- a/mindspore/core/ops/zeros.cc +++ b/mindspore/core/ops/zeros.cc @@ -57,6 +57,8 @@ AbstractBasePtr ZerosInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(prim); + const int64_t input_num = 2; + CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, prim->name()); auto abs = ZerosInfer(nullptr, prim, input_args); // check auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(abs->BuildShape())[kShape];