!44994 add check dim for ascend
Merge pull request !44994 from r1chardf1d0/bts6
This commit is contained in:
commit
73898d60e1
|
@ -23,6 +23,7 @@
|
|||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -35,12 +36,17 @@ abstract::ShapePtr BatchToSpaceNDInferShape(const PrimitivePtr &primitive,
|
|||
if (IsDynamicRank(x_shape)) {
|
||||
return std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeRankAny});
|
||||
}
|
||||
constexpr int64_t len = 4;
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool is_ascend = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input_x rank", SizeToLong(x_shape.size()),
|
||||
(is_ascend ? kEqual : kGreaterEqual), len, prim_name);
|
||||
if (IsDynamicShape(x_shape)) {
|
||||
std::vector<int64_t> res(x_shape.size(), abstract::Shape::kShapeDimAny);
|
||||
return std::make_shared<abstract::Shape>(res);
|
||||
}
|
||||
constexpr int64_t len = 4;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input_x rank", SizeToLong(x_shape.size()), kGreaterEqual, len, prim_name);
|
||||
|
||||
auto out_shape = x_shape;
|
||||
|
||||
int64_t block_shape_prod = 1;
|
||||
|
|
Loading…
Reference in New Issue