!22721 clean codecheck for ops r1.3

Merge pull request !22721 from wangnan39/codecheck_clean_r1.3
This commit is contained in:
i-robot 2021-09-02 09:19:21 +00:00 committed by Gitee
commit 493aed5f16
77 changed files with 448 additions and 335 deletions

View File

@ -28,19 +28,19 @@ abstract::AbstractBasePtr AdamInfer(const PrimitivePtr &primitive, const std::ve
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, prim_name);
// infer shape
auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
auto m_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack())[kShape];
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->GetShapeTrack())[kShape];
auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[9]->GetShapeTrack())[kShape];
auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShapeTrack())[kShape];
auto m_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShapeTrack())[kShape];
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShapeTrack())[kShape];
auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex9]->GetShapeTrack())[kShape];
CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "m_shape", m_shape, prim_name);
CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "v_shape", v_shape, prim_name);
CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "grad_shape", grad_shape, prim_name);
// infer type
auto var_type = input_args[0]->BuildType();
auto m_type = input_args[1]->BuildType();
auto v_type = input_args[2]->BuildType();
auto grad_type = input_args[9]->BuildType();
auto var_type = input_args[kInputIndex0]->BuildType();
auto m_type = input_args[kInputIndex1]->BuildType();
auto v_type = input_args[kInputIndex2]->BuildType();
auto grad_type = input_args[kInputIndex9]->BuildType();
auto infer_var_type = CheckAndConvertUtils::CheckTensorTypeValid("var_type", var_type, common_valid_types, prim_name);
auto infer_m_type = CheckAndConvertUtils::CheckTensorTypeValid("m_type", m_type, common_valid_types, prim_name);
auto infer_v_type = CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_type, common_valid_types, prim_name);

View File

@ -24,33 +24,22 @@
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 2, prim_name);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
return BroadCastInferShape(prim_name, input_args);
}
auto output_shape = BroadCastInferShape(prim_name, input_args);
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto op_name = prim->name();
(void)CheckAndConvertUtils::CheckInteger("Add infer", SizeToLong(input_args.size()), kGreaterEqual, 2, op_name);
std::map<std::string, TypePtr> types;
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("y", input_args[1]->BuildType());
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
}
} // namespace
AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
auto output_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim_name);
return abstract::MakeAbstract(output_shape, output_type);
}
REGISTER_PRIMITIVE_C(kNameAdd, Add);
} // namespace ops

View File

@ -61,7 +61,8 @@ AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const Pr
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("apply_momentum_infer", SizeToLong(input_args.size()), kEqual, 5, prim_name);
const int64_t input_num = 5;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
@ -70,11 +71,11 @@ AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const Pr
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
// Infer type
auto v_tensor_type = input_args[0]->BuildType();
auto a_tensor_type = input_args[1]->BuildType();
auto l_type = input_args[2]->BuildType();
auto g_type = input_args[3]->BuildType();
auto m_type = input_args[4]->BuildType();
auto v_tensor_type = input_args[kInputIndex0]->BuildType();
auto a_tensor_type = input_args[kInputIndex1]->BuildType();
auto l_type = input_args[kInputIndex2]->BuildType();
auto g_type = input_args[kInputIndex3]->BuildType();
auto m_type = input_args[kInputIndex4]->BuildType();
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
(void)CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_tensor_type, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("a_type", a_tensor_type, valid_types, prim_name);

View File

@ -30,8 +30,9 @@ AbstractBasePtr InferImplAssign(const abstract::AnalysisEnginePtr &, const Primi
const AbstractBasePtrList &args_spec_list) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger(
"Assign infer", SizeToLong(CheckAndConvertUtils::GetRemoveMonadAbsNum(args_spec_list)), kEqual, 2, prim_name);
"infer", SizeToLong(CheckAndConvertUtils::GetRemoveMonadAbsNum(args_spec_list)), kEqual, input_num, prim_name);
auto check_types = common_valid_types;
(void)check_types.emplace(kBool);
auto variable_type = args_spec_list[0]->BuildType();

View File

@ -83,7 +83,9 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
auto op_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, 4, op_name);
const int64_t x_size = 4;
const int64_t attr_size = 4;
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, x_size, op_name);
if (format == NHWC) {
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
}
@ -94,8 +96,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
auto in_h = in_shape[2];
auto in_w = in_shape[3];
auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides));
(void)CheckAndConvertUtils::CheckInteger("kernel size", SizeToLong(kernel_size.size()), kEqual, 4, op_name);
(void)CheckAndConvertUtils::CheckInteger("strides size", SizeToLong(strides.size()), kEqual, 4, op_name);
(void)CheckAndConvertUtils::CheckInteger("kernel size", SizeToLong(kernel_size.size()), kEqual, attr_size, op_name);
(void)CheckAndConvertUtils::CheckInteger("strides size", SizeToLong(strides.size()), kEqual, attr_size, op_name);
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) { return stride <= 0; })) {
MS_LOG(EXCEPTION) << "Strides is not valid, strides must be positive.";
}

View File

@ -38,12 +38,15 @@ void GetAttrs(const PrimitivePtr &primitive, std::vector<int64_t> *kernel_size,
// attr kernel size
*kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
if (kernel_size->size() != kKernelDims) {
MS_LOG(EXCEPTION) << "kernel_size of AvgPool3D must be 3.";
MS_LOG(EXCEPTION) << "`kernel_size` of AvgPool3D must be 3.";
}
// attr strides
*strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides));
if (strides->size() != kStridesDims) {
MS_LOG(EXCEPTION) << "strides of AvgPool3D must be 3.";
MS_LOG(EXCEPTION) << "`strides` of AvgPool3D must be 3.";
}
if (std::any_of(strides->begin(), strides->end(), [](int64_t stride) { return stride <= 0; })) {
MS_EXCEPTION(ValueError) << "Invalid strides, strides must be all positive.";
}
// sttr pad_list
*pad_list = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPadList));
@ -135,6 +138,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
auto kernel_d = kernel_size[0];
auto kernel_h = kernel_size[1];
auto kernel_w = kernel_size[2];
auto stride_d = strides[0];
auto stride_h = strides[1];
auto stride_w = strides[2];
@ -148,7 +152,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
std::vector<int64_t> out_shape =
GetOutputShape(in_shape, kernel_d, kernel_h, kernel_w, stride_d, stride_h, stride_w, new_pad_list, ceil_mode);
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t shp_v) { return shp_v <= 0; })) {
MS_LOG(EXCEPTION) << "output size is not valid.";
}
return std::make_shared<abstract::Shape>(out_shape);

View File

@ -144,7 +144,9 @@ bool BatchMatmul::get_transpose_b() const {
AbstractBasePtr BatchMatmulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
(void)CheckAndConvertUtils::CheckInteger("BatchMatmul infer", input_args.size(), kGreaterEqual, 2, primitive->name());
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("BatchMatmul infer", SizeToLong(input_args.size()), kGreaterEqual, input_num,
primitive->name());
return abstract::MakeAbstract(BatchMatmulInferShape(primitive, input_args),
BatchMatmulInferType(primitive, input_args));
}

View File

@ -73,17 +73,18 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit
// Infer shape
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("batch_norm_infer", SizeToLong(input_args.size()), kEqual, 5, prim_name);
const int64_t input_num = 5;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
auto input_x = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) {
input_x = {input_x[0], input_x[3], input_x[1], input_x[2]};
}
auto scale = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto bias = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
auto mean = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape];
auto variance = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[4]->BuildShape())[kShape];
auto scale = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto bias = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
auto mean = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
auto variance = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape];
std::vector<int64_t> input_shape_norm;
if (format == NCHW) {
@ -106,19 +107,19 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit
}
// Infer type
auto scale_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element();
auto bias_type = input_args[2]->BuildType()->cast<TensorTypePtr>()->element();
auto scale_type = input_args[kInputIndex1]->BuildType()->cast<TensorTypePtr>()->element();
auto bias_type = input_args[kInputIndex2]->BuildType()->cast<TensorTypePtr>()->element();
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
auto input_x_type =
CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name);
CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[kInputIndex0]->BuildType(), valid_types, prim_name);
std::map<std::string, TypePtr> args;
args.emplace("scale", input_args[1]->BuildType());
args.emplace("bias", input_args[2]->BuildType());
args.emplace("scale", input_args[kInputIndex1]->BuildType());
args.emplace("bias", input_args[kInputIndex2]->BuildType());
(void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
std::map<std::string, TypePtr> args_moving;
args_moving.emplace("scale", input_args[2]->BuildType());
args_moving.emplace("bias", input_args[3]->BuildType());
args_moving.emplace("scale", input_args[kInputIndex2]->BuildType());
args_moving.emplace("bias", input_args[kInputIndex3]->BuildType());
(void)CheckAndConvertUtils::CheckTensorTypeSame(args_moving, valid_types, prim_name);
auto output0 = std::make_shared<abstract::AbstractTensor>(input_x_type, input_x);

View File

@ -59,15 +59,18 @@ AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri
auto block_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize));
auto crops = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kCrops));
auto out_shape = x_shape;
(void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kEqual, 4, prim_name);
(void)CheckAndConvertUtils::CheckInteger("block_size size", SizeToLong(block_size.size()), kEqual, 4, prim_name);
(void)CheckAndConvertUtils::CheckInteger("crops size", SizeToLong(crops.size()), kEqual, 4, prim_name);
(void)CheckAndConvertUtils::CheckInteger("crops[0] size", SizeToLong(crops[0].size()), kEqual, 4, prim_name);
(void)CheckAndConvertUtils::CheckInteger("crops[1] size", SizeToLong(crops[1].size()), kEqual, 4, prim_name);
const int64_t attr_size = 4;
const int64_t x_rank = 4;
(void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kEqual, x_rank, prim_name);
(void)CheckAndConvertUtils::CheckInteger("block_size size", SizeToLong(block_size.size()), kEqual, attr_size,
prim_name);
(void)CheckAndConvertUtils::CheckInteger("crops size", SizeToLong(crops.size()), kEqual, attr_size, prim_name);
(void)CheckAndConvertUtils::CheckInteger("crops[0] size", SizeToLong(crops[0].size()), kEqual, attr_size, prim_name);
(void)CheckAndConvertUtils::CheckInteger("crops[1] size", SizeToLong(crops[1].size()), kEqual, attr_size, prim_name);
for (size_t i = 0; i < 2; ++i) {
auto x_block_prod = out_shape[i + 2] * block_size[i];
auto crops_sum = crops[i][0] + crops[i][1];
CheckAndConvertUtils::Check("x block shape prod", x_block_prod, kGreaterThan, "crops sum", 4, prim_name);
CheckAndConvertUtils::Check("x block shape prod", x_block_prod, kGreaterThan, "crops sum", attr_size, prim_name);
out_shape[i + 2] = x_block_prod - crops_sum;
}
(void)CheckAndConvertUtils::CheckInteger("x_shape[0] % (block_size[0]*block_size[1])",

View File

@ -30,7 +30,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("input_x rank", SizeToLong(x_shape.size()), kEqual, 4, prim_name);
const int64_t x_rank = 4;
(void)CheckAndConvertUtils::CheckInteger("input_x rank", SizeToLong(x_shape.size()), kEqual, x_rank, prim_name);
auto out_shape = x_shape;
int64_t block_shape_prod = 1;
size_t offset = 2;
@ -62,7 +63,8 @@ TypePtr InferType(const std::vector<AbstractBasePtr> &input_args) {
} // namespace
void BatchToSpaceND::set_crops(std::vector<std::vector<int64_t>> crops) {
(void)CheckAndConvertUtils::CheckInteger(kCrops, SizeToLong(crops.size()), kEqual, 2, this->name());
const int64_t crop_size = 2;
(void)CheckAndConvertUtils::CheckInteger(kCrops, SizeToLong(crops.size()), kEqual, crop_size, this->name());
size_t h = crops.size();
size_t w = crops[0].size();
std::vector<size_t> temp_w = {2, 2};
@ -80,7 +82,9 @@ std::vector<std::vector<int64_t>> BatchToSpaceND::get_crops() const {
return GetValue<std::vector<std::vector<int64_t>>>(value_ptr);
}
void BatchToSpaceND::set_block_shape(std::vector<int64_t> block_shape) {
(void)CheckAndConvertUtils::CheckInteger(kBlockShape, SizeToLong(block_shape.size()), kEqual, 2, this->name());
const int64_t block_size = 2;
(void)CheckAndConvertUtils::CheckInteger(kBlockShape, SizeToLong(block_shape.size()), kEqual, block_size,
this->name());
for (size_t i = 0; i < block_shape.size(); i++) {
(void)CheckAndConvertUtils::CheckInteger(kBlockShape, block_shape[i], kGreaterEqual, 1, this->name());
}

View File

@ -35,15 +35,20 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
auto bias = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 1);
MS_EXCEPTION_IF_NULL(x);
MS_EXCEPTION_IF_NULL(bias);
(void)CheckAndConvertUtils::CheckInteger("arg size", SizeToLong(input_args.size()), kEqual, 2, prim_name);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("arg size", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
auto input_shape = shape_map[kShape];
auto min_shape = shape_map[kMinShape];
auto max_shape = shape_map[kMaxShape];
CheckAndConvertUtils::CheckInRange("bias_add_infer", input_shape.size(), kIncludeBoth, {2, 5}, prim_name);
const int64_t x_min_rank = 2;
const int64_t x_max_rank = 5;
CheckAndConvertUtils::CheckInRange("bias_add_infer", input_shape.size(), kIncludeBoth, {x_min_rank, x_max_rank},
prim_name);
auto bias_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("bias rank", SizeToLong(bias_shape.size()), kEqual, 1, prim_name);
(void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(input_shape.size()), kGreaterEqual, 2, prim_name);
const int64_t x_size = 2;
(void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(input_shape.size()), kGreaterEqual, x_size, prim_name);
auto data_format_ptr = primitive->GetAttr("format");
int64_t data_format = Format::NCHW;
if (data_format_ptr != nullptr) {
@ -71,7 +76,9 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
(void)CheckAndConvertUtils::CheckInteger("biasadd_infer", SizeToLong(input_args.size()), kEqual, 2, prim_name);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("biasadd_infer", SizeToLong(input_args.size()), kEqual, input_num,
prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}

View File

@ -34,9 +34,9 @@ abstract::ShapePtr BinaryCrossEntroyInferShape(const PrimitivePtr &primitive,
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInRange("binary_cross_entropy_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name);
std::vector<int64_t> infer_shape;
if (weight_shape.size() < 1) {
@ -50,19 +50,20 @@ abstract::ShapePtr BinaryCrossEntroyInferShape(const PrimitivePtr &primitive,
}
TypePtr BinaryCrossEntroyInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
(void)CheckAndConvertUtils::CheckInteger("binary_cross_entropy_infer", SizeToLong(input_args.size()), kEqual, 3,
prim->name());
const int64_t input_num = 3;
(void)CheckAndConvertUtils::CheckInteger("binary_cross_entropy_infer", SizeToLong(input_args.size()), kEqual,
input_num, prim->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
std::map<std::string, TypePtr> types;
(void)types.emplace("x_shape", input_args[0]->BuildType());
(void)types.emplace("y_shape", input_args[1]->BuildType());
(void)types.emplace("x_shape", input_args[kInputIndex0]->BuildType());
(void)types.emplace("y_shape", input_args[kInputIndex1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
if (input_args[3]->BuildType() != nullptr) {
(void)types.emplace("x_shape", input_args[0]->BuildType());
(void)types.emplace("weight_shape", input_args[2]->BuildType());
if (input_args[kInputIndex3]->BuildType() != nullptr) {
(void)types.emplace("x_shape", input_args[kInputIndex0]->BuildType());
(void)types.emplace("weight_shape", input_args[kInputIndex2]->BuildType());
infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
return infer_type;

View File

@ -109,8 +109,9 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
auto w_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
auto x_shape = x_shape_map[kShape];
auto w_shape = w_shape_map[kShape];
(void)CheckAndConvertUtils::CheckInteger("x shape size", SizeToLong(x_shape.size()), kEqual, 4, primitive->name());
(void)CheckAndConvertUtils::CheckInteger("w shape size", SizeToLong(w_shape.size()), kEqual, 4, primitive->name());
const int64_t shape_size = 4;
(void)CheckAndConvertUtils::CheckInteger("x shape size", SizeToLong(x_shape.size()), kEqual, shape_size, prim_name);
(void)CheckAndConvertUtils::CheckInteger("w shape size", SizeToLong(w_shape.size()), kEqual, shape_size, prim_name);
auto x_min_shape = x_shape_map[kMinShape];
auto x_max_shape = x_shape_map[kMaxShape];
auto w_min_shape = w_shape_map[kMinShape];
@ -252,7 +253,8 @@ void Conv2D::set_pad_mode(const PadMode &pad_mode) {
}
void Conv2D::set_pad(const std::vector<int64_t> &pad) {
(void)CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name());
const int64_t pad_size = 4;
(void)CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, pad_size, name());
(void)AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
}
@ -316,7 +318,8 @@ Format Conv2D::get_format() const {
AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
(void)CheckAndConvertUtils::CheckInteger("Conv2d infer", SizeToLong(input_args.size()), kGreaterEqual, 2,
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("Conv2d infer", SizeToLong(input_args.size()), kGreaterEqual, input_num,
primitive->name());
const std::set<TypePtr> valid_types = {kInt8, kInt32, kInt64, kFloat16, kFloat32};
std::map<std::string, TypePtr> types;

View File

@ -53,7 +53,8 @@ void Conv2DTranspose::set_out_channel(int64_t out_channel) {
}
void Conv2DTranspose::set_kernel_size(const std::vector<int64_t> &kernel_size) {
(void)CheckAndConvertUtils::CheckInteger(kKernelSize, SizeToLong(kernel_size.size()), kEqual, 2, name());
const int64_t kernel_len = 2;
(void)CheckAndConvertUtils::CheckInteger(kKernelSize, SizeToLong(kernel_size.size()), kEqual, kernel_len, name());
for (int64_t item : kernel_size) {
(void)CheckAndConvertUtils::CheckInteger(kKernelSize, item, kGreaterEqual, 1, name());
}
@ -61,7 +62,8 @@ void Conv2DTranspose::set_kernel_size(const std::vector<int64_t> &kernel_size) {
}
void Conv2DTranspose::set_stride(const std::vector<int64_t> &stride) {
(void)CheckAndConvertUtils::CheckInteger(kStride, SizeToLong(stride.size()), kEqual, 2, name());
const int64_t stride_size = 2;
(void)CheckAndConvertUtils::CheckInteger(kStride, SizeToLong(stride.size()), kEqual, stride_size, name());
for (int64_t item : stride) {
(void)CheckAndConvertUtils::CheckInteger(kStride, item, kGreaterEqual, 1, name());
}
@ -69,7 +71,9 @@ void Conv2DTranspose::set_stride(const std::vector<int64_t> &stride) {
}
void Conv2DTranspose::set_dilation(const std::vector<int64_t> &dilation) {
(void)CheckAndConvertUtils::CheckInteger(kDilation, SizeToLong(dilation.size()), kGreaterEqual, 2, name());
const int64_t dilation_size = 2;
(void)CheckAndConvertUtils::CheckInteger(kDilation, SizeToLong(dilation.size()), kGreaterEqual, dilation_size,
name());
(void)AddAttr(kDilation, MakeValue(dilation));
}
@ -87,7 +91,8 @@ void Conv2DTranspose::set_pad_mode(const PadMode &pad_mode) {
}
void Conv2DTranspose::set_pad(const std::vector<int64_t> &pad) {
(void)CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, 4, name());
const int64_t pad_size = 4;
(void)CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, pad_size, name());
(void)AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
}
@ -105,7 +110,8 @@ void Conv2DTranspose::set_format(const Format &format) {
}
void Conv2DTranspose::set_pad_list(const std::vector<int64_t> &pad_list) {
(void)CheckAndConvertUtils::CheckInteger(kPadList, SizeToLong(pad_list.size()), kEqual, 4, name());
const int64_t pad_size = 4;
(void)CheckAndConvertUtils::CheckInteger(kPadList, SizeToLong(pad_list.size()), kEqual, pad_size, name());
(void)this->AddAttr(kPadList, MakeValue(pad_list));
}

View File

@ -44,7 +44,8 @@ AbstractBasePtr CropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim_name);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}

View File

@ -29,7 +29,9 @@ namespace mindspore {
namespace ops {
namespace {
void CheckCTCLossInputs(const std::vector<AbstractBasePtr> &input_args, const std::string &op_name) {
(void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kGreaterEqual, 4, op_name);
const int64_t input_num = 4;
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kGreaterEqual, input_num,
op_name);
auto inputs = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 0);
auto labels_indices = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 1);
auto labels_values = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 2);
@ -40,11 +42,18 @@ void CheckCTCLossInputs(const std::vector<AbstractBasePtr> &input_args, const st
auto labels_values_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(labels_values->BuildShape())[kShape];
auto sequence_length_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(sequence_length->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("inputs rank", inputs_shape.size(), kEqual, 3, op_name);
(void)CheckAndConvertUtils::CheckInteger("label_indices rank", labels_indices_shape.size(), kEqual, 2, op_name);
(void)CheckAndConvertUtils::CheckInteger("label_indices second dim", labels_indices_shape[1], kEqual, 2, op_name);
(void)CheckAndConvertUtils::CheckInteger("label_values rank", labels_values_shape.size(), kEqual, 1, op_name);
(void)CheckAndConvertUtils::CheckInteger("sequence_length rank", sequence_length_shape.size(), kEqual, 1, op_name);
const int64_t input_size = 3;
const int64_t label_indice_size = 2;
const int64_t label_indice_last_dim = 2;
(void)CheckAndConvertUtils::CheckInteger("inputs rank", SizeToLong(inputs_shape.size()), kEqual, input_size, op_name);
(void)CheckAndConvertUtils::CheckInteger("label_indices rank", SizeToLong(labels_indices_shape.size()), kEqual,
label_indice_size, op_name);
(void)CheckAndConvertUtils::CheckInteger("label_indices second dim", labels_indices_shape[1], kEqual,
label_indice_last_dim, op_name);
(void)CheckAndConvertUtils::CheckInteger("label_values rank", int64_t(labels_values_shape.size()), kEqual, 1,
op_name);
(void)CheckAndConvertUtils::CheckInteger("sequence_length rank", int64_t(sequence_length_shape.size()), kEqual, 1,
op_name);
if (labels_indices_shape[0] != labels_values_shape[0]) {
MS_EXCEPTION(ValueError) << "For CTCLoss first dim of label_indices and label_value must be same, but got "
@ -82,11 +91,15 @@ abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vec
TuplePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto op_name = primitive->name();
(void)CheckAndConvertUtils::CheckTensorTypeValid("labels_indices", input_args[1]->BuildType(), {kInt64}, op_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("labels_values", input_args[2]->BuildType(), {kInt32}, op_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("sequence_length", input_args[3]->BuildType(), {kInt32}, op_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("labels_indices", input_args[kInputIndex1]->BuildType(), {kInt64},
op_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("labels_values", input_args[kInputIndex2]->BuildType(), {kInt32},
op_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("sequence_length", input_args[kInputIndex3]->BuildType(), {kInt32},
op_name);
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
auto type = CheckAndConvertUtils::CheckTensorTypeValid("inputs", input_args[0]->BuildType(), valid_types, op_name);
auto type =
CheckAndConvertUtils::CheckTensorTypeValid("inputs", input_args[kInputIndex0]->BuildType(), valid_types, op_name);
return std::make_shared<Tuple>(std::vector<TypePtr>{type, type});
}
} // namespace

View File

@ -44,7 +44,8 @@ AbstractBasePtr CumSumInfer(const abstract::AnalysisEnginePtr &, const Primitive
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, prim_name);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}

View File

@ -59,7 +59,8 @@ AbstractBasePtr DepthToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri
if (format == NHWC) {
x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};
}
(void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kEqual, 4, prim_name);
const int64_t x_rank = 4;
(void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kEqual, x_rank, prim_name);
int64_t block_size = GetValue<int64_t>(primitive->GetAttr(kBlockSize));
(void)CheckAndConvertUtils::CheckInteger("x_shape[1] % (block_size*block_size)",
x_shape[1] % (block_size * block_size), kEqual, 0, prim_name);

View File

@ -117,14 +117,14 @@ AbstractBasePtr DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &, c
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("detection_post_process_infer", SizeToLong(input_args.size()), kEqual, 3,
prim_name);
MS_EXCEPTION_IF_NULL(input_args[0]);
MS_EXCEPTION_IF_NULL(input_args[1]);
MS_EXCEPTION_IF_NULL(input_args[2]);
auto boxes = input_args[0];
auto scores = input_args[1];
auto anchors = input_args[2];
const int64_t input_num = 3;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]);
MS_EXCEPTION_IF_NULL(input_args[kInputIndex2]);
auto boxes = input_args[kInputIndex0];
auto scores = input_args[kInputIndex1];
auto anchors = input_args[kInputIndex2];
auto boxes_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(boxes->BuildShape())[kShape];
auto scores_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(scores->BuildShape())[kShape];
auto anchors_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(anchors->BuildShape())[kShape];

View File

@ -60,7 +60,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
MS_EXCEPTION(ValueError) << "DropoutDoMask input mask do not match input, input_x shape: " << x_shape->ToString()
<< ", mask shape: " << mask_shape->ToString();
}
auto keep_prop = input_args[2];
auto keep_prop = input_args[kInputIndex2];
if (keep_prop->isa<abstract::AbstractTensor>()) {
auto keep_prop_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(keep_prop->BuildShape())[kShape];
if (!keep_prop_shape.empty()) {
@ -72,7 +72,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto op_name = primitive->name();
auto keep_prop = input_args[2];
auto keep_prop = input_args[kInputIndex2];
MS_EXCEPTION_IF_NULL(keep_prop);
auto keep_prop_value = keep_prop->BuildValue();
MS_EXCEPTION_IF_NULL(keep_prop_value);
@ -116,7 +116,9 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
AbstractBasePtr DropoutDoMaskInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
(void)CheckAndConvertUtils::CheckInteger("infer shape", input_args.size(), kGreaterEqual, 3, primitive->name());
const int64_t input_num = 3;
(void)CheckAndConvertUtils::CheckInteger("infer shape", SizeToLong(input_args.size()), kGreaterEqual, input_num,
primitive->name());
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(DropoutDoMask, prim::kPrimDropoutDoMask, DropoutDoMaskInfer, nullptr, true);

View File

@ -94,7 +94,9 @@ ShapeVector CalOutputShape(const AbstractBasePtrList shape_list) {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto op_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("infer shape", input_args.size(), kGreaterEqual, 2, op_name);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("infer shape", SizeToLong(input_args.size()), kGreaterEqual, input_num,
op_name);
AbstractBasePtr shape_args = input_args[0];
MS_EXCEPTION_IF_NULL(shape_args);

View File

@ -55,7 +55,8 @@ size_t CheckInputsAndGetShape(const AbstractBasePtr &input_arg, const string &pr
abstract::TupleShapePtr Infer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim_name);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
auto x_shape = CheckInputsAndGetShape(input_args[0], prim_name);
auto y_shape = CheckInputsAndGetShape(input_args[1], prim_name);

View File

@ -37,7 +37,8 @@ AbstractBasePtr EmbeddingLookupInfer(const abstract::AnalysisEnginePtr &, const
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 3, prim_name);
const int64_t input_num = 3;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
@ -47,7 +48,8 @@ AbstractBasePtr EmbeddingLookupInfer(const abstract::AnalysisEnginePtr &, const
MS_EXCEPTION_IF_NULL(indices);
const std::set<TypePtr> int_valid_types = {kInt8, kInt16, kInt32, kInt64};
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indices->BuildType(), int_valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("offset", input_args[2]->BuildType(), int_valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("offset", input_args[kInputIndex2]->BuildType(), int_valid_types,
prim_name);
MS_EXCEPTION_IF_NULL(params->shape());
auto params_shp = params->shape()->shape();
MS_EXCEPTION_IF_NULL(indices->shape());

View File

@ -26,21 +26,13 @@
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
AbstractBasePtr EqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, op_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
return BroadCastInferShape(op_name, input_args);
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto op_name = prim->name();
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, 2, op_name);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, input_num,
op_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
@ -50,14 +42,9 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
std::map<std::string, TypePtr> types;
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("y", input_args[1]->BuildType());
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, op_name);
}
} // namespace
AbstractBasePtr EqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
(void)InferType(primitive, input_args);
return abstract::MakeAbstract(InferShape(primitive, input_args), kBool);
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, op_name);
auto out_shape = BroadCastInferShape(op_name, input_args);
return abstract::MakeAbstract(out_shape, kBool);
}
REGISTER_PRIMITIVE_C(kNameEqual, Equal);
} // namespace ops

View File

@ -31,7 +31,8 @@ AbstractBasePtr ExpandDimsInfer(const abstract::AnalysisEnginePtr &, const Primi
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 2, prim_name);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}

View File

@ -29,9 +29,9 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kGreaterEqual, 1, prim_name);
CheckAndConvertUtils::Check("min_shape", min_shape, kEqual, "max_shape", max_shape, prim_name);
(void)CheckAndConvertUtils::CheckInteger("min_shape", SizeToLong(min_shape.size()), kEqual, 1, prim_name);
@ -51,9 +51,9 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
MS_LOG(EXCEPTION) << "nullptr";
}
std::map<std::string, TypePtr> types;
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("min", input_args[1]->BuildType());
(void)types.emplace("max", input_args[2]->BuildType());
(void)types.emplace("x", input_args[kInputIndex0]->BuildType());
(void)types.emplace("min", input_args[kInputIndex1]->BuildType());
(void)types.emplace("max", input_args[kInputIndex2]->BuildType());
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
} // namespace

View File

@ -44,17 +44,17 @@ AbstractBasePtr FakeQuantWithMinMaxVarsPerChannelInfer(const abstract::AnalysisE
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("x rank", (int64_t)x_shape.size(), kGreaterThan, 1, op_name);
CheckAndConvertUtils::Check("min shape", min_shape, kEqual, "max shape", max_shape, op_name);
(void)CheckAndConvertUtils::CheckInteger("min shape", (int64_t)min_shape.size(), kEqual, 1, op_name);
CheckAndConvertUtils::Check("min shape", min_shape[0], kEqual, "x shape", x_shape[x_shape.size() - 1], op_name);
auto x_type = input_args[0]->BuildType();
auto min_type = input_args[1]->BuildType();
auto max_type = input_args[2]->BuildType();
auto x_type = input_args[kInputIndex0]->BuildType();
auto min_type = input_args[kInputIndex1]->BuildType();
auto max_type = input_args[kInputIndex2]->BuildType();
std::vector<std::string> type_name = {"x", "min", "max"};
std::vector<TypePtr> type = {x_type, min_type, max_type};
for (size_t i = 0; i < 3; i++) {

View File

@ -26,11 +26,12 @@ AbstractBasePtr FillInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 3, prim_name);
const int64_t input_num = 3;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto input_dtype = input_args[0]->cast<abstract::AbstractTypePtr>();
auto input_dtype = input_args[kInputIndex0]->cast<abstract::AbstractTypePtr>();
MS_EXCEPTION_IF_NULL(input_dtype);
auto dtype_value = input_dtype->BuildValue();
MS_EXCEPTION_IF_NULL(dtype_value);
@ -39,10 +40,10 @@ AbstractBasePtr FillInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
auto valid_types = common_valid_types;
valid_types.insert(kBool);
(void)CheckAndConvertUtils::CheckTypeValid("output datatype", dtype, valid_types, prim_name);
auto out_shape = GetValue<std::vector<int64_t>>(input_args[1]->BuildValue());
auto x_type = input_args[2]->BuildType();
auto out_shape = GetValue<std::vector<int64_t>>(input_args[kInputIndex1]->BuildValue());
auto x_type = input_args[kInputIndex2]->BuildType();
auto x_type_id = x_type->type_id();
auto x_value = input_args[2]->BuildValue();
auto x_value = input_args[kInputIndex2]->BuildValue();
auto abs = std::make_shared<abstract::AbstractTensor>(dtype, std::make_shared<abstract::Shape>(out_shape));
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(x_type_id, out_shape);
MS_EXCEPTION_IF_NULL(tensor);
@ -54,7 +55,7 @@ AbstractBasePtr FillInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
auto float_value = GetValue<float>(x_value);
SetTensorData(tensor->data_c(), float_value, mem_size);
} else {
MS_LOG(ERROR) << " Fill not supported to flod the constant type " << input_args[2]->ToString();
MS_LOG(ERROR) << " Fill not supported to flod the constant type " << input_args[kInputIndex2]->ToString();
}
abs->set_value(tensor);
return abs;

View File

@ -47,8 +47,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
MS_LOG(EXCEPTION) << "nullptr";
}
std::map<std::string, TypePtr> types;
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("y", input_args[1]->BuildType());
(void)types.emplace("x", input_args[kInputIndex0]->BuildType());
(void)types.emplace("y", input_args[kInputIndex1]->BuildType());
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
}
} // namespace

View File

@ -62,7 +62,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
if (format == NHWC) {
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
}
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, 4, op_name);
const int64_t x_rank = 4;
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, x_rank, op_name);
auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
auto batch = in_shape[0];

View File

@ -40,7 +40,8 @@ void Conv2dTransposeFusion::Init(int64_t in_channel, int64_t out_channel, const
}
void Conv2dTransposeFusion::set_kernel_size(const std::vector<int64_t> &kernel_size) {
(void)CheckAndConvertUtils::CheckInteger(kKernelSize, SizeToLong(kernel_size.size()), kEqual, 2, name());
const size_t kernel_len = 2;
(void)CheckAndConvertUtils::CheckInteger(kKernelSize, SizeToLong(kernel_size.size()), kEqual, kernel_len, name());
for (int64_t item : kernel_size) {
(void)CheckAndConvertUtils::CheckInteger(kKernelSize, item, kGreaterEqual, 1, name());
}
@ -48,7 +49,8 @@ void Conv2dTransposeFusion::set_kernel_size(const std::vector<int64_t> &kernel_s
}
void Conv2dTransposeFusion::set_dilation(const std::vector<int64_t> &dilation) {
(void)CheckAndConvertUtils::CheckInteger(kDilation, SizeToLong(dilation.size()), kEqual, 2, name());
const size_t dilation_size = 2;
(void)CheckAndConvertUtils::CheckInteger(kDilation, SizeToLong(dilation.size()), kEqual, dilation_size, name());
for (int64_t item : dilation) {
(void)CheckAndConvertUtils::CheckInteger(kDilation, item, kGreaterEqual, 1, name());
}

View File

@ -62,8 +62,8 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
MS_EXCEPTION_IF_NULL(input_args[0]);
MS_EXCEPTION_IF_NULL(input_args[1]);
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]);
auto input0 = input_args[0];
auto input1 = input_args[1];
auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input0->BuildShape())[kShape];
@ -93,7 +93,7 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P
new_k = input1_shape[1];
}
if (has_bias) {
auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
if (input2_shape[0] != input1_shape[0]) {
MS_EXCEPTION(ValueError) << "Bias size is invalid";
}

View File

@ -36,9 +36,9 @@ AbstractBasePtr SliceFusionInfer(const abstract::AnalysisEnginePtr &, const Prim
auto op_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto x_shape_len = x_shape.size();
auto begin_v = input_args[1]->BuildValue();
auto size_v = input_args[2]->BuildValue();
auto x_type = input_args[0]->BuildType();
auto begin_v = input_args[kInputIndex1]->BuildValue();
auto size_v = input_args[kInputIndex2]->BuildValue();
auto x_type = input_args[kInputIndex0]->BuildType();
MS_EXCEPTION_IF_NULL(x_type);
MS_EXCEPTION_IF_NULL(begin_v);
MS_EXCEPTION_IF_NULL(size_v);

View File

@ -14,10 +14,12 @@
* limitations under the License.
*/
#include "ops/gather.h"
#include <set>
#include <memory>
#include <algorithm>
#include "ops/gather.h"
#include "ops/op_utils.h"
namespace mindspore {
namespace ops {
@ -33,28 +35,30 @@ AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const Primitive
CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 1);
// check
std::set<TypePtr> valid_params_types = {kTensorType};
(void)CheckAndConvertUtils::CheckSubClass("params_type", input_args[0]->BuildType(), valid_params_types, op_name);
(void)CheckAndConvertUtils::CheckSubClass("params_type", input_args[kInputIndex0]->BuildType(), valid_params_types,
op_name);
std::set<TypePtr> int_types = {kInt8, kInt16, kInt32, kInt64};
(void)CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[1]->BuildType(), int_types, op_name);
(void)CheckAndConvertUtils::CheckTypeValid("axis_type", input_args[2]->BuildType(), int_types, op_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[kInputIndex1]->BuildType(), int_types,
op_name);
(void)CheckAndConvertUtils::CheckTypeValid("axis_type", input_args[kInputIndex2]->BuildType(), int_types, op_name);
bool ind_dyn = (!indices->shape()->min_shape().empty() && !indices->shape()->max_shape().empty());
bool param_dyn = (!params->shape()->min_shape().empty() && !params->shape()->max_shape().empty());
int64_t axis_val = 0;
// 3rd input is a Tensor when Gather is a dynamic shape operator
if (input_args[2]->isa<abstract::AbstractTensor>()) {
auto axis = input_args[2]->cast<abstract::AbstractTensorPtr>();
if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>()) {
auto axis = input_args[kInputIndex2]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(axis);
auto axis_value_ptr = axis->BuildValue();
MS_EXCEPTION_IF_NULL(axis_value_ptr);
auto axis_tensor = axis_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(axis_tensor);
axis_val = *static_cast<int64_t *>(axis_tensor->data_c());
} else if (input_args[2]->isa<abstract::AbstractScalar>()) {
auto axis = input_args[2]->cast<abstract::AbstractScalarPtr>();
} else if (input_args[kInputIndex2]->isa<abstract::AbstractScalar>()) {
auto axis = input_args[kInputIndex2]->cast<abstract::AbstractScalarPtr>();
axis_val = GetValue<int64_t>(axis->BuildValue());
} else {
MS_LOG(EXCEPTION) << "Invalid abstract type:" << input_args[2]->type_name();
MS_LOG(EXCEPTION) << "Invalid abstract type:" << input_args[kInputIndex2]->type_name();
}
auto params_shp = params->shape()->shape();
auto indices_shp = indices->shape()->shape();

View File

@ -29,11 +29,11 @@ abstract::ShapePtr GatherDInferShape(const PrimitivePtr &primitive, const std::v
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
// check
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
int64_t x_rank = x_shape.size();
CheckAndConvertUtils::Check("x_rank", x_rank, kEqual, "index_rank", index_shape.size(), prim_name);
auto value_ptr = input_args[1]->BuildValue();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
int64_t x_rank = SizeToLong(x_shape.size());
CheckAndConvertUtils::Check("x_rank", x_rank, kEqual, "index_rank", SizeToLong(index_shape.size()), prim_name);
auto value_ptr = input_args[kInputIndex1]->BuildValue();
MS_EXCEPTION_IF_NULL(value_ptr);
auto dim_v = GetValue<int64_t>(value_ptr);
CheckAndConvertUtils::Check("dim value", dim_v, kGreaterEqual, "negative index_rank", -x_rank, prim_name);
@ -66,8 +66,9 @@ AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const Primitiv
auto prim_name = primitive->name();
// check
std::set<TypePtr> valid_types = {kInt32, kInt64};
(void)CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[2]->BuildType(), valid_types, prim_name);
(void)CheckAndConvertUtils::CheckSubClass("dim_type", input_args[1]->BuildType(), valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[kInputIndex2]->BuildType(), valid_types,
prim_name);
(void)CheckAndConvertUtils::CheckSubClass("dim_type", input_args[kInputIndex1]->BuildType(), valid_types, prim_name);
return abstract::MakeAbstract(GatherDInferShape(primitive, input_args), GatherDInferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, true);

View File

@ -28,7 +28,8 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim_name);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}

View File

@ -30,7 +30,8 @@ constexpr size_t k5DInputDims = 5;
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kEqual, 2, op_name);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kEqual, input_num, op_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
@ -48,7 +49,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kEqual, 2, op_name);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kEqual, input_num, op_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}

View File

@ -58,10 +58,10 @@ AbstractBasePtr BatchNormGradInfer(const abstract::AnalysisEnginePtr &, const Pr
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
CheckAndConvertUtils::Check("BatchNorm y_backprop_shape", y_backprop_shape, kEqual, "BatchNorm x_shape", x_shape);
auto dx = input_args[1]->Broaden();
auto dscale = input_args[2]->Broaden();
auto reserve_1 = input_args[3]->Broaden();
auto reserve_2 = input_args[4]->Broaden();
auto dx = input_args[kInputIndex1]->Broaden();
auto dscale = input_args[kInputIndex2]->Broaden();
auto reserve_1 = input_args[kInputIndex3]->Broaden();
auto reserve_2 = input_args[kInputIndex4]->Broaden();
AbstractBasePtrList rets = {dx, dscale, dscale, reserve_1, reserve_2};
return std::make_shared<abstract::AbstractTuple>(rets);

View File

@ -27,10 +27,11 @@ abstract::ShapePtr BinaryCrossEntroyGradInferShape(const PrimitivePtr &primitive
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
(void)CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name);
if (weight_shape.size() < 1) {
CheckAndConvertUtils::Check("y shape", y_shape, kEqual, "weight shape", weight_shape, prim_name);
}
@ -40,12 +41,12 @@ abstract::ShapePtr BinaryCrossEntroyGradInferShape(const PrimitivePtr &primitive
TypePtr BinaryCrossEntroyGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
std::map<std::string, TypePtr> types;
(void)types.emplace("x_shape", input_args[0]->BuildType());
(void)types.emplace("y_shape", input_args[1]->BuildType());
(void)types.emplace("x_shape", input_args[kInputIndex0]->BuildType());
(void)types.emplace("y_shape", input_args[kInputIndex1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
if (input_args[3]->BuildType() != nullptr) {
(void)types.emplace("x_shape", input_args[0]->BuildType());
(void)types.emplace("weight_shape", input_args[2]->BuildType());
if (input_args[kInputIndex3]->BuildType() != nullptr) {
(void)types.emplace("x_shape", input_args[kInputIndex0]->BuildType());
(void)types.emplace("weight_shape", input_args[kInputIndex2]->BuildType());
infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
return infer_type;

View File

@ -29,7 +29,7 @@ abstract::ShapePtr Conv2DBackpropFilterInferShape(const PrimitivePtr &primitive,
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
// check
auto w_size_v = input_args[2]->BuildValue();
auto w_size_v = input_args[kInputIndex2]->BuildValue();
auto ret_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("w_size", w_size_v, prim_name);
return std::make_shared<abstract::Shape>(ret_shape);
}
@ -157,8 +157,9 @@ AbstractBasePtr Conv2DBackpropFilterInfer(const abstract::AnalysisEnginePtr &, c
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
// check
(void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kGreaterEqual, 3, prim_name);
const int64_t input_num = 3;
(void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kGreaterEqual, input_num,
prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}

View File

@ -67,7 +67,7 @@ abstract::ShapePtr Conv2DBackpropInputInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto x_size_v = input_args[2]->BuildValue();
auto x_size_v = input_args[kInputIndex2]->BuildValue();
auto ret_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("x_size", x_size_v, prim_name);
auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto format = CheckAndConvertUtils::GetAndCheckFormat(primitive->GetAttr(kFormat));
@ -93,7 +93,9 @@ AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, co
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
// check
(void)CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kGreaterEqual, 3, prim_name);
const int64_t input_num = 3;
(void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kGreaterEqual, input_num,
prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
@ -149,7 +151,8 @@ void Conv2DBackpropInput::set_pad_mode(const PadMode &pad_mode) {
}
void Conv2DBackpropInput::set_pad(const std::vector<int64_t> &pad) {
(void)CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, 4, name());
const int64_t pad_size = 4;
(void)CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, pad_size, name());
(void)AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
}

View File

@ -162,8 +162,9 @@ AbstractBasePtr GroupConv2DGradInputInfer(const abstract::AnalysisEnginePtr &, c
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("group_conv_2D_infer", SizeToLong(input_args.size()), kGreaterEqual, 2,
prim_name);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("group_conv_2D_infer", SizeToLong(input_args.size()), kGreaterEqual,
input_num, prim_name);
MS_EXCEPTION_IF_NULL(input_args[0]);
// Infer shape

View File

@ -26,10 +26,11 @@ AbstractBasePtr LayerNormGradInfer(const abstract::AnalysisEnginePtr &, const Pr
// Outputs: x_backprob, gamma_backprob, beta_backprob
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 5, op_name);
auto x_backprob = input_args[0]->Broaden();
auto gamma_backprob = input_args[4]->Broaden();
auto beta_backprob = input_args[4]->Broaden();
const int64_t input_num = 5;
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, input_num, op_name);
auto x_backprob = input_args[kInputIndex0]->Broaden();
auto gamma_backprob = input_args[kInputIndex4]->Broaden();
auto beta_backprob = input_args[kInputIndex4]->Broaden();
auto shapes = std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{
x_backprob->BuildShape(), gamma_backprob->BuildShape(), beta_backprob->BuildShape()});
auto types = std::make_shared<Tuple>(

View File

@ -30,7 +30,8 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
@ -47,7 +48,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
(void)CheckAndConvertUtils::CheckInteger("ReLUGrad infer", SizeToLong(input_args.size()), kEqual, 2, prim_name);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
MS_EXCEPTION_IF_NULL(input_args[0]);
auto dout = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
auto out = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 1);

View File

@ -27,13 +27,7 @@
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
abstract::ShapePtr InferShape(const std::vector<AbstractBasePtr> &input_args) {
auto x = input_args[0]->BuildShape();
MS_EXCEPTION_IF_NULL(x);
auto shape_element = x->cast<abstract::ShapePtr>();
@ -41,9 +35,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
return shape_element;
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
(void)CheckAndConvertUtils::CheckInteger("ReLUGradV2 infer", input_args.size(), kEqual, 2, prim_name);
MS_EXCEPTION_IF_NULL(input_args[0]);
auto x_type_map = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(x_type_map);
@ -55,7 +47,15 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
} // namespace
AbstractBasePtr ReLUGradV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, input_num,
prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
return abstract::MakeAbstract(InferShape(input_args), InferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(ReLUGradV2, prim::kPrimReluGradV2, ReLUGradV2Infer, nullptr, true);
} // namespace ops

View File

@ -31,23 +31,24 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsGradInfer(const abstract::AnalysisE
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
const int64_t input_num = 3;
(void)CheckAndConvertUtils::CheckInteger("sigmoid_cross_entropy_with_logits_grad_infer",
SizeToLong(input_args.size()), kEqual, 3, prim_name);
SizeToLong(input_args.size()), kEqual, input_num, prim_name);
// Infer Shape
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "y_shape", y_shape, prim_name, TypeError);
CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "dout_shape", dout_shape, prim_name, TypeError);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
(void)CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "y_shape", y_shape, prim_name, TypeError);
(void)CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "dout_shape", dout_shape, prim_name, TypeError);
// Infer type
const std::set<TypePtr> valid_types = {kBool, kInt, kInt8, kInt16, kInt32, kInt64, kUInt, kUInt8,
kUInt16, kUInt32, kUInt64, kFloat, kFloat16, kFloat32, kFloat64, kComplex64};
std::map<std::string, TypePtr> args;
args.emplace("x_type", input_args[0]->BuildType());
args.emplace("y_type", input_args[1]->BuildType());
args.emplace("dout_type", input_args[2]->BuildType());
args.emplace("x_type", input_args[kInputIndex0]->BuildType());
args.emplace("y_type", input_args[kInputIndex1]->BuildType());
args.emplace("dout_type", input_args[kInputIndex2]->BuildType());
auto dout_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
return std::make_shared<abstract::AbstractTensor>(dout_type, x_shape);
}

View File

@ -38,13 +38,14 @@ AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("smooth_l1_loss_grad_infer", SizeToLong(input_args.size()), kEqual, 3,
prim_name);
const int64_t input_num = 3;
(void)CheckAndConvertUtils::CheckInteger("smooth_l1_loss_grad_infer", SizeToLong(input_args.size()), kEqual,
input_num, prim_name);
// Infer shape
auto prediction = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto target = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto dloss = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
auto prediction = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto target = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto dloss = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "target shape", target, prim_name, TypeError);
CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "dloss", dloss, prim_name, TypeError);
@ -52,9 +53,9 @@ AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const
const std::set<TypePtr> valid_types = {kBool, kInt, kInt8, kInt16, kInt32, kInt64, kUInt, kUInt8,
kUInt16, kUInt32, kUInt64, kFloat, kFloat16, kFloat32, kFloat64, kComplex64};
std::map<std::string, TypePtr> args;
args.emplace("prediction", input_args[0]->BuildType());
args.emplace("target", input_args[1]->BuildType());
args.emplace("dloss", input_args[2]->BuildType());
args.emplace("prediction", input_args[kInputIndex0]->BuildType());
args.emplace("target", input_args[kInputIndex1]->BuildType());
args.emplace("dloss", input_args[kInputIndex2]->BuildType());
auto dloss_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
return std::make_shared<abstract::AbstractTensor>(dloss_type, prediction);

View File

@ -14,10 +14,11 @@
* limitations under the License.
*/
#include <vector>
#include "ops/hashtable_lookup.h"
#include <vector>
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"
namespace mindspore {
namespace ops {
@ -32,7 +33,7 @@ AbstractBasePtr HashtableLookupInfer(const abstract::AnalysisEnginePtr &, const
(void)CheckAndConvertUtils::CheckInteger("logits size", SizeToLong(input.size()), kGreaterEqual, 1, op_name);
hits_shape.push_back(input[0]);
auto value_type = input_args[2]->BuildType();
auto value_type = input_args[kInputIndex2]->BuildType();
MS_EXCEPTION_IF_NULL(value_type);
auto tensor_type = value_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);

View File

@ -33,15 +33,17 @@ AbstractBasePtr LshProjectionInfer(const abstract::AnalysisEnginePtr &, const Pr
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
const int64_t input_num = 2;
const int64_t input0_size = 2;
const int64_t input0_last_dim = 32;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, op_name);
auto input0 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto input1 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("input0 rank", SizeToLong(input0.size()), kEqual, 2, op_name);
(void)CheckAndConvertUtils::CheckInteger("input0_shape_dimen_1", input0[1], kLessEqual, 32, op_name);
auto input0 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto input1 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("input0 rank", SizeToLong(input0.size()), kEqual, input0_size, op_name);
(void)CheckAndConvertUtils::CheckInteger("input0_shape_dimen_1", input0[1], kLessEqual, input0_last_dim, op_name);
(void)CheckAndConvertUtils::CheckInteger("input1 rank", SizeToLong(input1.size()), kGreaterEqual, 1, op_name);
if (input_args.size() == 3) {
auto input2 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
auto input2 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("input2 rank", SizeToLong(input2.size()), kEqual, 1, op_name);
(void)CheckAndConvertUtils::CheckInteger("input2_shape_dimen_0", input2[0], kEqual, input1[0], op_name);
}

View File

@ -32,16 +32,23 @@ AbstractBasePtr LstmInfer(const PrimitivePtr &primitive, const std::vector<Abstr
// infer shape
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("lstm_prim_infer", SizeToLong(input_args.size()), kEqual, 4, prim_name);
auto x_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto h_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto c_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
const int64_t input_num = 4;
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
auto x_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto h_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto c_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
int64_t input_x_size = GetValue<int64_t>(primitive->GetAttr(kInput_size));
(void)CheckAndConvertUtils::CheckInteger("x_shape.size()", SizeToLong(x_input_shape.size()), kEqual, 3, prim_name);
const int64_t shape_size = 3;
(void)CheckAndConvertUtils::CheckInteger("x_shape.size()", SizeToLong(x_input_shape.size()), kEqual, shape_size,
prim_name);
(void)CheckAndConvertUtils::CheckInteger("x_shape[2]", x_input_shape[2], kEqual, input_x_size, prim_name);
(void)CheckAndConvertUtils::CheckInteger("h_shape.size()", SizeToLong(h_input_shape.size()), kEqual, 3, prim_name);
(void)CheckAndConvertUtils::CheckInteger("h_shape.size()", SizeToLong(h_input_shape.size()), kEqual, shape_size,
prim_name);
CheckAndConvertUtils::Check("h_shape", h_input_shape, kEqual, "c_shape", c_input_shape, prim_name);
int64_t num_layers = GetValue<int64_t>(primitive->GetAttr(kNumLayers));
@ -81,15 +88,11 @@ AbstractBasePtr LstmInfer(const PrimitivePtr &primitive, const std::vector<Abstr
std::vector<int64_t> state_shape = {1, 1};
// infer type
(void)CheckAndConvertUtils::CheckInteger("lstm_prim_infer", SizeToLong(input_args.size()), kEqual, 4, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto infer_type0 = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
auto infer_type1 = input_args[1]->BuildType()->cast<TensorTypePtr>()->element();
auto infer_type2 = input_args[2]->BuildType()->cast<TensorTypePtr>()->element();
auto infer_type3 = input_args[3]->BuildType()->cast<TensorTypePtr>()->element();
auto infer_type4 = input_args[4]->BuildType()->cast<TensorTypePtr>()->element();
auto infer_type0 = input_args[kInputIndex0]->BuildType()->cast<TensorTypePtr>()->element();
auto infer_type1 = input_args[kInputIndex1]->BuildType()->cast<TensorTypePtr>()->element();
auto infer_type2 = input_args[kInputIndex2]->BuildType()->cast<TensorTypePtr>()->element();
auto infer_type3 = input_args[kInputIndex3]->BuildType()->cast<TensorTypePtr>()->element();
auto infer_type4 = input_args[kInputIndex4]->BuildType()->cast<TensorTypePtr>()->element();
auto output0 = std::make_shared<abstract::AbstractTensor>(infer_type0, x_shape);
auto output1 = std::make_shared<abstract::AbstractTensor>(infer_type1, y_shape);
auto output2 = std::make_shared<abstract::AbstractTensor>(infer_type2, c_shape);

View File

@ -125,7 +125,8 @@ bool MatMul::get_transpose_b() const {
AbstractBasePtr MatMulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
(void)CheckAndConvertUtils::CheckInteger("MatMul infer", SizeToLong(input_args.size()), kGreaterEqual, 2,
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("MatMul infer", SizeToLong(input_args.size()), kGreaterEqual, input_num,
primitive->name());
return abstract::MakeAbstract(MatMulInferShape(primitive, input_args), MatMulInferType(primitive, input_args));
}

View File

@ -88,7 +88,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
if (format == NHWC) {
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
}
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, 4, op_name);
const int64_t x_rank = 4;
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, x_rank, op_name);
auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
auto pad_mode_value = (primitive->GetAttr(kPadMode));

View File

@ -25,10 +25,14 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
const int64_t input0_size = 3;
const int64_t input1_size = 1;
auto first_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto second_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("input 0 rank", SizeToLong(first_input_shape.size()), kEqual, 3, prim_name);
(void)CheckAndConvertUtils::CheckInteger("input 1 rank", SizeToLong(second_input_shape.size()), kEqual, 1, prim_name);
(void)CheckAndConvertUtils::CheckInteger("input 0 rank", SizeToLong(first_input_shape.size()), kEqual, input0_size,
prim_name);
(void)CheckAndConvertUtils::CheckInteger("input 1 rank", SizeToLong(second_input_shape.size()), kEqual, input1_size,
prim_name);
std::vector<int64_t> out_shape = {first_input_shape[0], first_input_shape[1],
GetValue<int64_t>(primitive->GetAttr(kDctCoeffNum))};
return std::make_shared<abstract::Shape>(out_shape);

View File

@ -35,14 +35,16 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim->name());
auto op_name = prim->name();
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, op_name);
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
MS_LOG(EXCEPTION) << "nullptr";
}
std::map<std::string, TypePtr> types;
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("y", input_args[1]->BuildType());
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, op_name);
}
} // namespace

View File

@ -30,7 +30,8 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
@ -42,7 +43,9 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
MS_EXCEPTION_IF_NULL(item);
}
auto op_name = prim->name();
(void)CheckAndConvertUtils::CheckInteger("Mul infer", input_args.size(), kGreaterEqual, 2, op_name);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, input_num,
op_name);
std::map<std::string, TypePtr> types;
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("y", input_args[1]->BuildType());

View File

@ -27,23 +27,8 @@
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, op_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
return BroadCastInferShape(op_name, input_args);
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto op_name = prim->name();
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, 2, op_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
std::map<std::string, TypePtr> types;
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("y", input_args[1]->BuildType());
@ -53,8 +38,15 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
AbstractBasePtr NotEqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, op_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
(void)InferType(primitive, input_args);
return abstract::MakeAbstract(InferShape(primitive, input_args), kBool);
return abstract::MakeAbstract(BroadCastInferShape(op_name, input_args), kBool);
}
REGISTER_PRIMITIVE_C(kNameNotEqual, NotEqual);
} // namespace ops

View File

@ -59,11 +59,12 @@ abstract::ShapePtr OneHotInferShape(const PrimitivePtr &primitive, const std::ve
TypePtr OneHotInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
auto op_name = prim->name();
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[0]->BuildType(), {kInt32, kInt64}, op_name);
(void)CheckAndConvertUtils::CheckTypeValid("depth", input_args[1]->BuildType(), {kInt8, kInt16, kInt32, kInt64},
op_name);
std::map<std::string, TypePtr> args = {{"on_value", input_args[2]->BuildType()},
{"off_dtype", input_args[3]->BuildType()}};
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[kInputIndex0]->BuildType(), {kInt32, kInt64},
op_name);
(void)CheckAndConvertUtils::CheckTypeValid("depth", input_args[kInputIndex1]->BuildType(),
{kInt8, kInt16, kInt32, kInt64}, op_name);
std::map<std::string, TypePtr> args = {{"on_value", input_args[kInputIndex2]->BuildType()},
{"off_dtype", input_args[kInputIndex3]->BuildType()}};
return CheckAndConvertUtils::CheckTensorTypeSame(args, {kFloat16, kFloat32}, op_name);
}
} // namespace

View File

@ -252,6 +252,27 @@ constexpr auto kSplitDim = "split_dim";
constexpr auto kPadTop = "pad_top";
constexpr auto kTransFormat = "trans_format";
constexpr auto kApproximate = "approximate";
enum Index : size_t {
kInputIndex0 = 0,
kInputIndex1,
kInputIndex2,
kInputIndex3,
kInputIndex4,
kInputIndex5,
kInputIndex6,
kInputIndex7,
kInputIndex8,
kInputIndex9,
kInputIndex10,
kInputIndex11,
kInputIndex12,
kInputIndex13,
kInputIndex14,
kInputIndex15,
kInputIndex16,
};
const std::set<TypePtr> common_valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16,
kUInt32, kUInt64, kFloat16, kFloat32, kFloat64};

View File

@ -26,9 +26,10 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
auto w = input_args[1]->BuildShape();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x)[kShape];
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(w)[kShape];
(void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kGreaterEqual, 2, prim_name);
(void)CheckAndConvertUtils::CheckInteger("weight rank", SizeToLong(w_shape.size()), kEqual, 1, prim_name);
const int64_t x_rank = 2;
const int64_t w_rank = 1;
(void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kGreaterEqual, x_rank, prim_name);
(void)CheckAndConvertUtils::CheckInteger("weight rank", SizeToLong(w_shape.size()), kEqual, w_rank, prim_name);
if (w_shape[0] != x_shape[1] && w_shape[0] != 1) {
MS_LOG(EXCEPTION) << "For " << prim_name << ", channel of input_x and weight must be matched, "
<< "while channel of input_x is " << x_shape[1] << ", weight_shape[0] is " << w_shape[0];

View File

@ -61,13 +61,14 @@ AbstractBasePtr RangeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
int64_t shape_size = 0;
if (input_args.size() == 3) {
MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue());
MS_EXCEPTION_IF_NULL(input_args[1]->BuildValue());
MS_EXCEPTION_IF_NULL(input_args[2]->BuildValue());
auto start_tensor = input_args[0]->BuildValue()->cast<tensor::TensorPtr>();
auto limit_tensor = input_args[1]->BuildValue()->cast<tensor::TensorPtr>();
auto delta_tensor = input_args[2]->BuildValue()->cast<tensor::TensorPtr>();
const size_t max_input_num = 3;
if (input_args.size() == max_input_num) {
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]->BuildValue());
MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]->BuildValue());
MS_EXCEPTION_IF_NULL(input_args[kInputIndex2]->BuildValue());
auto start_tensor = input_args[kInputIndex0]->BuildValue()->cast<tensor::TensorPtr>();
auto limit_tensor = input_args[kInputIndex1]->BuildValue()->cast<tensor::TensorPtr>();
auto delta_tensor = input_args[kInputIndex2]->BuildValue()->cast<tensor::TensorPtr>();
auto dtype = start_tensor->data_type();
switch (dtype) {
case kNumberTypeInt:

View File

@ -29,7 +29,8 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 2, prim_name);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
@ -42,7 +43,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
MS_EXCEPTION_IF_NULL(item);
}
auto op_name = prim->name();
(void)CheckAndConvertUtils::CheckInteger("RealDiv infer", SizeToLong(input_args.size()), kGreaterEqual, 2, op_name);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("infer", SizeToLong(input_args.size()), kGreaterEqual, input_num, op_name);
std::map<std::string, TypePtr> types;
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("y", input_args[1]->BuildType());

View File

@ -31,7 +31,8 @@ AbstractBasePtr ReshapeInfer(const abstract::AnalysisEnginePtr &, const Primitiv
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim_name);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}

View File

@ -50,7 +50,8 @@ AbstractBasePtr ResizeBilinearInfer(const abstract::AnalysisEnginePtr &, const P
// Infer shape
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("input_shape_rank", SizeToLong(input_shape.size()), kEqual, 4, prim_name);
const int64_t shape_size = 4;
(void)CheckAndConvertUtils::CheckInteger("input rank", SizeToLong(input_shape.size()), kEqual, shape_size, prim_name);
std::vector<int64_t> out_shape = {input_shape[0], input_shape[1]};
auto size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kSize));
(void)out_shape.insert(out_shape.end(), size.begin(), size.end());

View File

@ -39,7 +39,9 @@ AbstractBasePtr ReverseSequenceInfer(const abstract::AnalysisEnginePtr &, const
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 2, prim_name);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, input_num,
prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}

View File

@ -52,7 +52,8 @@ AbstractBasePtr ROIPoolingInfer(const abstract::AnalysisEnginePtr &, const Primi
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("roi_pooling_infer", SizeToLong(input_args.size()), kEqual, 2, prim_name);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("infer", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
MS_EXCEPTION_IF_NULL(input_args[0]);
MS_EXCEPTION_IF_NULL(input_args[1]);

View File

@ -24,13 +24,13 @@ namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const std::vector<AbstractBasePtr> &input_args) {
auto shape_value = input_args[2]->BuildValue();
auto shape_value = input_args[kInputIndex2]->BuildValue();
auto shape_value_element = GetValue<std::vector<int64_t>>(shape_value);
for (const auto &shape : shape_value_element) {
(void)CheckAndConvertUtils::CheckInteger("shape value", shape, kGreaterThan, 0, "ScatterNd");
}
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto update_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto update_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("indices_shape[0] and update_shape[0]", indices_shape[0], kEqual,
update_shape[0], "ScatterNd");
return std::make_shared<abstract::Shape>(shape_value_element);

View File

@ -30,8 +30,9 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsInfer(const abstract::AnalysisEngin
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("sigmoid_cross_extropy_with_logits_infer", SizeToLong(input_args.size()),
kEqual, 2, prim_name);
kEqual, input_num, prim_name);
// Infer shape
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];

View File

@ -30,7 +30,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("input shape", SizeToLong(input_shape.size()), kEqual, 4, prim_name);
const int64_t x_rank = 4;
(void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(input_shape.size()), kEqual, x_rank, prim_name);
std::vector<int64_t> output_shape(input_shape.size());
auto block_shape_vector = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize));
auto paddings = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kPaddings));

View File

@ -30,7 +30,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("input_x rank", SizeToLong(x_shape.size()), kEqual, 4, prim_name);
const int64_t shape_size = 4;
(void)CheckAndConvertUtils::CheckInteger("input_x rank", SizeToLong(x_shape.size()), kEqual, shape_size, prim_name);
auto out_shape = x_shape;
int64_t block_shape_prod = 1;
const size_t offset = 2;
@ -60,7 +61,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
} // namespace
void SpaceToBatchND::set_paddings(std::vector<std::vector<int64_t>> paddings) {
(void)CheckAndConvertUtils::CheckInteger(kPaddings, SizeToLong(paddings.size()), kEqual, 2, this->name());
const int64_t pad_size = 2;
(void)CheckAndConvertUtils::CheckInteger(kPaddings, SizeToLong(paddings.size()), kEqual, pad_size, this->name());
size_t h = paddings.size();
size_t w = paddings[0].size();
std::vector<size_t> temp_w = {2, 2};
@ -78,7 +80,9 @@ std::vector<std::vector<int64_t>> SpaceToBatchND::get_paddings() const {
return GetValue<std::vector<std::vector<int64_t>>>(value_ptr);
}
void SpaceToBatchND::set_block_shape(std::vector<int64_t> block_shape) {
(void)CheckAndConvertUtils::CheckInteger(kBlockShape, SizeToLong(block_shape.size()), kEqual, 2, this->name());
const int64_t block_size = 2;
(void)CheckAndConvertUtils::CheckInteger(kBlockShape, SizeToLong(block_shape.size()), kEqual, block_size,
this->name());
for (size_t i = 0; i < block_shape.size(); i++) {
(void)CheckAndConvertUtils::CheckInteger(kBlockShape, SizeToLong(block_shape[i]), kGreaterEqual, 1, this->name());
}

View File

@ -38,7 +38,9 @@ AbstractBasePtr SparseSoftmaxCrossEntropyWithLogitsInfer(const abstract::Analysi
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, input_num,
prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}

View File

@ -28,12 +28,13 @@ AbstractBasePtr SparseToDenseInfer(const abstract::AnalysisEnginePtr &, const Pr
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 3, prim_name);
const int64_t input_num = 3;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
// infer shape
auto dense_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape];
auto dense_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
// infer type
auto values_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element();
return std::make_shared<abstract::AbstractTensor>(values_type, dense_shape);

View File

@ -173,18 +173,18 @@ abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive,
MS_EXCEPTION_IF_NULL(primitive);
auto strided_slice_prim = primitive->cast<PrimStridedSlicePtr>();
MS_EXCEPTION_IF_NULL(strided_slice_prim);
auto tuple_begin_v = input_args[1]->cast<abstract::AbstractTuplePtr>();
auto tuple_begin_v = input_args[kInputIndex1]->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_begin_v);
auto temp_begin_v = tuple_begin_v->BuildValue();
MS_EXCEPTION_IF_NULL(temp_begin_v);
auto begin_v = GetValue<std::vector<int64_t>>(temp_begin_v);
auto tuple_end_v = input_args[2]->cast<abstract::AbstractTuplePtr>();
auto tuple_end_v = input_args[kInputIndex2]->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_end_v);
auto temp_end_v = tuple_end_v->BuildValue();
MS_EXCEPTION_IF_NULL(temp_end_v);
auto end_v = GetValue<std::vector<int64_t>>(temp_end_v);
auto strides_v = CheckAndGetValidStrides(input_args[3]);
auto strides_v = CheckAndGetValidStrides(input_args[kInputIndex3]);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMinShape];

View File

@ -30,7 +30,8 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 2, prim_name);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
@ -42,7 +43,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
MS_EXCEPTION_IF_NULL(item);
}
auto op_name = prim->name();
(void)CheckAndConvertUtils::CheckInteger("Sub infer", SizeToLong(input_args.size()), kGreaterEqual, 2, op_name);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("infer", SizeToLong(input_args.size()), kGreaterEqual, input_num, op_name);
std::map<std::string, TypePtr> types;
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("y", input_args[1]->BuildType());

View File

@ -77,7 +77,8 @@ abstract::ShapePtr TileInferShape(const PrimitivePtr &primitive, const std::vect
TypePtr TileInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
(void)CheckAndConvertUtils::CheckInteger("tile_prim_infer", input_args.size(), kEqual, 2, prim_name);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("infer", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}

View File

@ -32,7 +32,8 @@ AbstractBasePtr TopKInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("top_k_infer", SizeToLong(input_args.size()), kEqual, 2, prim_name);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("top_k_infer", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
// Infer dtype
for (const auto &item : input_args) {

View File

@ -55,11 +55,11 @@ AbstractBasePtr UnsortedSegmentSumInfer(const abstract::AnalysisEnginePtr &, con
}
const std::set<TypePtr> valid_num_segments_types = {kInt32, kInt64};
(void)CheckAndConvertUtils::CheckTensorTypeValid("num_segments", input_args[2]->BuildType(), valid_num_segments_types,
prim_name);
int64_t size_segment_ids_shp = segment_ids_shape.size();
int64_t size_x_shpe = x_shape.size();
for (int64_t i = size_segment_ids_shp; i < size_x_shpe; ++i) {
(void)CheckAndConvertUtils::CheckTensorTypeValid("num_segments", input_args[kInputIndex2]->BuildType(),
valid_num_segments_types, prim_name);
auto size_segment_ids_shp = segment_ids_shape.size();
auto size_x_shpe = x_shape.size();
for (size_t i = size_segment_ids_shp; i < size_x_shpe; ++i) {
(void)shp.emplace_back(x_shape[i]);
}

View File

@ -29,16 +29,18 @@ AbstractBasePtr WhereInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
MS_EXCEPTION_IF_NULL(input);
}
auto op_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kGreaterEqual, 3, op_name);
auto input0_type_ = input_args[0]->BuildType()->cast<TensorTypePtr>();
const int64_t input_num = 3;
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kGreaterEqual, input_num,
op_name);
auto input0_type_ = input_args[kInputIndex0]->BuildType()->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(input0_type_);
auto input0_type = input0_type_->element();
auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto num = input_args[0]->BuildValue()->cast<tensor::TensorPtr>()->ElementsNum();
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto num1 = input_args[1]->BuildValue()->cast<tensor::TensorPtr>()->ElementsNum();
auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
auto num2 = input_args[2]->BuildValue()->cast<tensor::TensorPtr>()->ElementsNum();
auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto num = input_args[kInputIndex0]->BuildValue()->cast<tensor::TensorPtr>()->ElementsNum();
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto num1 = input_args[kInputIndex1]->BuildValue()->cast<tensor::TensorPtr>()->ElementsNum();
auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
auto num2 = input_args[kInputIndex2]->BuildValue()->cast<tensor::TensorPtr>()->ElementsNum();
auto nummax = num > num1 ? num : (num1 > num2 ? num1 : num2);
size_t axisout = 0;
size_t temp = 0;