diff --git a/mindspore/core/ops/batch_to_space_nd.cc b/mindspore/core/ops/batch_to_space_nd.cc index 30f2088177a..f2a06fba7d5 100644 --- a/mindspore/core/ops/batch_to_space_nd.cc +++ b/mindspore/core/ops/batch_to_space_nd.cc @@ -23,6 +23,7 @@ #include "utils/check_convert_utils.h" #include "abstract/ops/primitive_infer_map.h" #include "mindapi/src/helper.h" +#include "utils/ms_context.h" namespace mindspore { namespace ops { @@ -35,12 +36,17 @@ abstract::ShapePtr BatchToSpaceNDInferShape(const PrimitivePtr &primitive, if (IsDynamicRank(x_shape)) { return std::make_shared(std::vector{abstract::Shape::kShapeRankAny}); } + constexpr int64_t len = 4; + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + bool is_ascend = (context->get_param(MS_CTX_DEVICE_TARGET) == kAscendDevice); + (void)CheckAndConvertUtils::CheckInteger("input_x rank", SizeToLong(x_shape.size()), + (is_ascend ? kEqual : kGreaterEqual), len, prim_name); if (IsDynamicShape(x_shape)) { std::vector res(x_shape.size(), abstract::Shape::kShapeDimAny); return std::make_shared(res); } - constexpr int64_t len = 4; - (void)CheckAndConvertUtils::CheckInteger("input_x rank", SizeToLong(x_shape.size()), kGreaterEqual, len, prim_name); + auto out_shape = x_shape; int64_t block_shape_prod = 1;