!44994 add check dim for ascend

Merge pull request !44994 from r1chardf1d0/bts6
This commit is contained in:
i-robot 2022-11-03 14:05:58 +00:00 committed by Gitee
commit 73898d60e1
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 8 additions and 2 deletions

View File

@ -23,6 +23,7 @@
#include "utils/check_convert_utils.h" #include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h" #include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h" #include "mindapi/src/helper.h"
#include "utils/ms_context.h"
namespace mindspore { namespace mindspore {
namespace ops { namespace ops {
@ -35,12 +36,17 @@ abstract::ShapePtr BatchToSpaceNDInferShape(const PrimitivePtr &primitive,
if (IsDynamicRank(x_shape)) { if (IsDynamicRank(x_shape)) {
return std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeRankAny}); return std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeRankAny});
} }
constexpr int64_t len = 4;
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool is_ascend = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice);
(void)CheckAndConvertUtils::CheckInteger("input_x rank", SizeToLong(x_shape.size()),
(is_ascend ? kEqual : kGreaterEqual), len, prim_name);
if (IsDynamicShape(x_shape)) { if (IsDynamicShape(x_shape)) {
std::vector<int64_t> res(x_shape.size(), abstract::Shape::kShapeDimAny); std::vector<int64_t> res(x_shape.size(), abstract::Shape::kShapeDimAny);
return std::make_shared<abstract::Shape>(res); return std::make_shared<abstract::Shape>(res);
} }
constexpr int64_t len = 4;
(void)CheckAndConvertUtils::CheckInteger("input_x rank", SizeToLong(x_shape.size()), kGreaterEqual, len, prim_name);
auto out_shape = x_shape; auto out_shape = x_shape;
int64_t block_shape_prod = 1; int64_t block_shape_prod = 1;