From a9c058a118a17ef9bca27e083c106b54e546420f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=8E=8B=E5=8D=97?=
Date: Tue, 31 Aug 2021 21:33:11 +0800
Subject: [PATCH] codecheck clean

---
 mindspore/core/ops/adam.cc | 16 ++++-----
 mindspore/core/ops/add.cc | 25 ++++----------
 mindspore/core/ops/apply_momentum.cc | 13 ++++----
 mindspore/core/ops/assign.cc | 3 +-
 mindspore/core/ops/avg_pool.cc | 8 +++--
 mindspore/core/ops/avg_pool_3d.cc | 10 ++++--
 mindspore/core/ops/batch_matmul.cc | 4 ++-
 mindspore/core/ops/batch_norm.cc | 25 +++++++-------
 mindspore/core/ops/batch_to_space.cc | 15 +++++----
 mindspore/core/ops/batch_to_space_nd.cc | 10 ++++--
 mindspore/core/ops/bias_add.cc | 15 ++++++---
 mindspore/core/ops/binary_cross_entropy.cc | 21 ++++++------
 mindspore/core/ops/conv2d.cc | 11 ++++---
 mindspore/core/ops/conv2d_transpose.cc | 16 ++++++---
 mindspore/core/ops/crop.cc | 3 +-
 mindspore/core/ops/ctcloss.cc | 33 +++++++++++++------
 mindspore/core/ops/cumsum.cc | 3 +-
 mindspore/core/ops/depth_to_space.cc | 3 +-
 mindspore/core/ops/detection_post_process.cc | 16 ++++-----
 mindspore/core/ops/dropout_do_mask.cc | 8 +++--
 mindspore/core/ops/dropout_gen_mask.cc | 4 ++-
 .../ops/dynamic_broadcast_gradient_args.cc | 3 +-
 mindspore/core/ops/embedding_lookup.cc | 6 ++--
 mindspore/core/ops/equal.cc | 29 +++++-----------
 mindspore/core/ops/expand_dims.cc | 3 +-
 .../core/ops/fake_quant_with_min_max_vars.cc | 12 +++----
 ...ake_quant_with_min_max_vars_per_channel.cc | 12 +++----
 mindspore/core/ops/fill.cc | 13 ++++----
 mindspore/core/ops/fusion/add_fusion.cc | 4 +--
 mindspore/core/ops/fusion/avg_pool_fusion.cc | 3 +-
 .../ops/fusion/conv2d_transpose_fusion.cc | 6 ++--
 mindspore/core/ops/fusion/full_connection.cc | 6 ++--
 mindspore/core/ops/fusion/slice_fusion.cc | 6 ++--
 mindspore/core/ops/gather.cc | 22 ++++++++-----
 mindspore/core/ops/gather_d.cc | 15 +++++----
 mindspore/core/ops/gather_nd.cc | 3 +-
 mindspore/core/ops/grad/avg_pool_3d_grad.cc | 6 ++--
 mindspore/core/ops/grad/batch_norm_grad.cc | 8 ++---
 .../ops/grad/binary_cross_entropy_grad.cc | 19 ++++++-----
 .../core/ops/grad/conv2d_backprop_filter.cc | 7 ++--
 .../core/ops/grad/conv2d_backprop_input.cc | 9 +++--
 .../core/ops/grad/group_conv2d_grad_input.cc | 5 +--
 mindspore/core/ops/grad/layer_norm_grad.cc | 9 ++---
 mindspore/core/ops/grad/relu_grad.cc | 6 ++--
 mindspore/core/ops/grad/relu_grad_v2.cc | 20 +++++------
 .../sigmoid_cross_entropy_with_logits_grad.cc | 19 ++++++-----
 .../core/ops/grad/smooth_l1_loss_grad.cc | 17 +++++-----
 mindspore/core/ops/hashtable_lookup.cc | 7 ++--
 mindspore/core/ops/lsh_projection.cc | 12 ++++---
 mindspore/core/ops/lstm.cc | 33 ++++++++++---------
 mindspore/core/ops/mat_mul.cc | 3 +-
 mindspore/core/ops/max_pool.cc | 3 +-
 mindspore/core/ops/mfcc.cc | 8 +++--
 mindspore/core/ops/minimum.cc | 6 ++--
 mindspore/core/ops/mul.cc | 7 ++--
 mindspore/core/ops/not_equal.cc | 24 +++++---------
 mindspore/core/ops/one_hot.cc | 11 ++++---
 mindspore/core/ops/op_utils.h | 21 ++++++++++++
 mindspore/core/ops/prelu.cc | 7 ++--
 mindspore/core/ops/range.cc | 15 +++++----
 mindspore/core/ops/real_div.cc | 6 ++--
 mindspore/core/ops/reshape.cc | 3 +-
 mindspore/core/ops/resize_bilinear.cc | 3 +-
 mindspore/core/ops/reverse_sequence.cc | 4 ++-
 mindspore/core/ops/roi_pooling.cc | 3 +-
 mindspore/core/ops/scatter_nd.cc | 6 ++--
 .../ops/sigmoid_cross_entropy_with_logits.cc | 3 +-
 mindspore/core/ops/space_to_batch.cc | 3 +-
 mindspore/core/ops/space_to_batch_nd.cc | 10 ++++--
 ...parse_softmax_cross_entropy_with_logits.cc | 4 ++-
 mindspore/core/ops/sparse_to_dense.cc | 5 +--
 mindspore/core/ops/strided_slice.cc | 6 ++--
 mindspore/core/ops/sub.cc | 6 ++--
 mindspore/core/ops/tile.cc | 3 +-
 mindspore/core/ops/topk.cc | 3 +-
 mindspore/core/ops/unsorted_segment_sum.cc | 10 +++---
 mindspore/core/ops/where.cc | 18 +++++----
 77 files changed, 448 insertions(+), 335 deletions(-)

diff --git a/mindspore/core/ops/adam.cc b/mindspore/core/ops/adam.cc
index cd7e43a6868..4ca2ba48d36 100644
--- a/mindspore/core/ops/adam.cc
+++ b/mindspore/core/ops/adam.cc
@@ -28,19 +28,19 @@ abstract::AbstractBasePtr AdamInfer(const PrimitivePtr &primitive, const std::ve
   CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, prim_name);
 
   // infer shape
-  auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
-  auto m_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack())[kShape];
-  auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->GetShapeTrack())[kShape];
-  auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[9]->GetShapeTrack())[kShape];
+  auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShapeTrack())[kShape];
+  auto m_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShapeTrack())[kShape];
+  auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShapeTrack())[kShape];
+  auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex9]->GetShapeTrack())[kShape];
   CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "m_shape", m_shape, prim_name);
   CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "v_shape", v_shape, prim_name);
   CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "grad_shape", grad_shape, prim_name);
 
   // infer type
-  auto var_type = input_args[0]->BuildType();
-  auto m_type = input_args[1]->BuildType();
-  auto v_type = input_args[2]->BuildType();
-  auto grad_type = input_args[9]->BuildType();
+  auto var_type = input_args[kInputIndex0]->BuildType();
+  auto m_type = input_args[kInputIndex1]->BuildType();
+  auto v_type = input_args[kInputIndex2]->BuildType();
+  auto grad_type = input_args[kInputIndex9]->BuildType();
   auto infer_var_type = CheckAndConvertUtils::CheckTensorTypeValid("var_type", var_type, common_valid_types, prim_name);
   auto infer_m_type = CheckAndConvertUtils::CheckTensorTypeValid("m_type", m_type, common_valid_types, prim_name);
   auto infer_v_type = CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_type, common_valid_types, prim_name);
diff --git a/mindspore/core/ops/add.cc b/mindspore/core/ops/add.cc
index 5dca848dae7..fd5a1f216eb 100644
--- a/mindspore/core/ops/add.cc
+++ b/mindspore/core/ops/add.cc
@@ -24,33 +24,22 @@
 
 namespace mindspore {
 namespace ops {
-namespace {
-abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                         const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 2, prim_name);
+  const int64_t input_num = 2;
+  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
-  return BroadCastInferShape(prim_name, input_args);
-}
+  auto output_shape = BroadCastInferShape(prim_name, input_args);
 
-TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
-  for (const auto &item : input_args) {
-    MS_EXCEPTION_IF_NULL(item);
-  }
-  auto op_name = prim->name();
-  (void)CheckAndConvertUtils::CheckInteger("Add infer", SizeToLong(input_args.size()), kGreaterEqual, 2, op_name);
   std::map<std::string, TypePtr> types;
   (void)types.emplace("x", input_args[0]->BuildType());
   (void)types.emplace("y", input_args[1]->BuildType());
-  return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
-}
-}  // namespace
-
-AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                         const std::vector<AbstractBasePtr> &input_args) {
-  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
+  auto output_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim_name);
+  return abstract::MakeAbstract(output_shape, output_type);
 }
 REGISTER_PRIMITIVE_C(kNameAdd, Add);
 }  // namespace ops
diff --git a/mindspore/core/ops/apply_momentum.cc b/mindspore/core/ops/apply_momentum.cc
index cfe7933fd4c..6acd1e022b8 100644
--- a/mindspore/core/ops/apply_momentum.cc
+++ b/mindspore/core/ops/apply_momentum.cc
@@ -61,7 +61,8 @@ AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const Pr
                                    const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  (void)CheckAndConvertUtils::CheckInteger("apply_momentum_infer", SizeToLong(input_args.size()), kEqual, 5, prim_name);
+  const int64_t input_num = 5;
+  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
@@ -70,11 +71,11 @@ AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const Pr
   auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
 
   // Infer type
-  auto v_tensor_type = input_args[0]->BuildType();
-  auto a_tensor_type = input_args[1]->BuildType();
-  auto l_type = input_args[2]->BuildType();
-  auto g_type = input_args[3]->BuildType();
-  auto m_type = input_args[4]->BuildType();
+  auto v_tensor_type = input_args[kInputIndex0]->BuildType();
+  auto a_tensor_type = input_args[kInputIndex1]->BuildType();
+  auto l_type = input_args[kInputIndex2]->BuildType();
+  auto g_type = input_args[kInputIndex3]->BuildType();
+  auto m_type = input_args[kInputIndex4]->BuildType();
   const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
   (void)CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_tensor_type, valid_types, prim_name);
   (void)CheckAndConvertUtils::CheckTensorTypeValid("a_type", a_tensor_type, valid_types, prim_name);
diff --git a/mindspore/core/ops/assign.cc b/mindspore/core/ops/assign.cc
index afbbddcdf49..16625490e95 100644
--- a/mindspore/core/ops/assign.cc
+++ b/mindspore/core/ops/assign.cc
@@ -30,8 +30,9 @@ AbstractBasePtr InferImplAssign(const abstract::AnalysisEnginePtr &, const Primi
                                 const AbstractBasePtrList &args_spec_list) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
+  const int64_t input_num = 2;
   (void)CheckAndConvertUtils::CheckInteger(
-    "Assign infer", SizeToLong(CheckAndConvertUtils::GetRemoveMonadAbsNum(args_spec_list)), kEqual, 2, prim_name);
+    "infer", SizeToLong(CheckAndConvertUtils::GetRemoveMonadAbsNum(args_spec_list)), kEqual, input_num, prim_name);
   auto check_types = common_valid_types;
   (void)check_types.emplace(kBool);
   auto variable_type = args_spec_list[0]->BuildType();
diff --git a/mindspore/core/ops/avg_pool.cc b/mindspore/core/ops/avg_pool.cc
index d6af8a467b4..f3426fdbdaf 100644
--- a/mindspore/core/ops/avg_pool.cc
+++ b/mindspore/core/ops/avg_pool.cc
@@ -83,7 +83,9 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
   auto op_name = primitive->name();
   auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
   auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
-  (void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, 4, op_name);
+  const int64_t x_size = 4;
+  const int64_t attr_size = 4;
+  (void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, x_size, op_name);
   if (format == NHWC) {
     in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
   }
@@ -94,8 +96,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
   auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides));
-  (void)CheckAndConvertUtils::CheckInteger("kernel size", SizeToLong(kernel_size.size()), kEqual, 4, op_name);
-  (void)CheckAndConvertUtils::CheckInteger("strides size", SizeToLong(strides.size()), kEqual, 4, op_name);
+  (void)CheckAndConvertUtils::CheckInteger("kernel size", SizeToLong(kernel_size.size()), kEqual, attr_size, op_name);
+  (void)CheckAndConvertUtils::CheckInteger("strides size", SizeToLong(strides.size()), kEqual, attr_size, op_name);
   if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) { return stride <= 0; })) {
     MS_LOG(EXCEPTION) << "Strides is not valid, strides must be positive.";
   }
diff --git a/mindspore/core/ops/avg_pool_3d.cc b/mindspore/core/ops/avg_pool_3d.cc
index a2ea0e2435e..b850c0baec9 100644
--- a/mindspore/core/ops/avg_pool_3d.cc
+++ b/mindspore/core/ops/avg_pool_3d.cc
@@ -38,12 +38,15 @@ void GetAttrs(const PrimitivePtr &primitive, std::vector<int64_t> *kernel_size,
   // attr kernel size
   *kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
   if (kernel_size->size() != kKernelDims) {
-    MS_LOG(EXCEPTION) << "kernel_size of AvgPool3D must be 3.";
+    MS_LOG(EXCEPTION) << "`kernel_size` of AvgPool3D must be 3.";
   }
   // attr strides
   *strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides));
   if (strides->size() != kStridesDims) {
-    MS_LOG(EXCEPTION) << "strides of AvgPool3D must be 3.";
+    MS_LOG(EXCEPTION) << "`strides` of AvgPool3D must be 3.";
+  }
+  if (std::any_of(strides->begin(), strides->end(), [](int64_t stride) { return stride <= 0; })) {
+    MS_EXCEPTION(ValueError) << "Invalid strides, strides must be all positive.";
   }
   // sttr pad_list
   *pad_list = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPadList));
@@ -135,6 +138,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
   std::vector<int64_t> out_shape =
     GetOutputShape(in_shape, kernel_d, kernel_h, kernel_w, stride_d, stride_h, stride_w, new_pad_list, ceil_mode);
-  if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
+  if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t shp_v) { return shp_v <= 0; })) {
     MS_LOG(EXCEPTION) << "output size is not valid.";
   }
   return std::make_shared<abstract::Shape>(out_shape);
diff --git a/mindspore/core/ops/batch_matmul.cc b/mindspore/core/ops/batch_matmul.cc
index 597af3f0a93..a7a57200f0b 100644
--- a/mindspore/core/ops/batch_matmul.cc
+++ b/mindspore/core/ops/batch_matmul.cc
@@ -144,7 +144,9 @@ bool BatchMatmul::get_transpose_b() const {
 AbstractBasePtr BatchMatmulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  (void)CheckAndConvertUtils::CheckInteger("BatchMatmul infer", input_args.size(), kGreaterEqual, 2, primitive->name());
+  const int64_t input_num = 2;
+  (void)CheckAndConvertUtils::CheckInteger("BatchMatmul infer", SizeToLong(input_args.size()), kGreaterEqual, input_num,
+                                           primitive->name());
   return abstract::MakeAbstract(BatchMatmulInferShape(primitive, input_args),
                                 BatchMatmulInferType(primitive, input_args));
 }
diff --git a/mindspore/core/ops/batch_norm.cc b/mindspore/core/ops/batch_norm.cc
index 38508570b0c..c8ef1908e43 100644
--- a/mindspore/core/ops/batch_norm.cc
+++ b/mindspore/core/ops/batch_norm.cc
@@ -73,17 +73,18 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit
   // Infer shape
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  (void)CheckAndConvertUtils::CheckInteger("batch_norm_infer", SizeToLong(input_args.size()), kEqual, 5, prim_name);
+  const int64_t input_num = 5;
+  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
   auto input_x = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
   auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
   if (format == NHWC) {
     input_x = {input_x[0], input_x[3], input_x[1], input_x[2]};
   }
-  auto scale = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
-  auto bias = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
-  auto mean = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape];
-  auto variance = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[4]->BuildShape())[kShape];
+  auto scale = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
+  auto bias = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
+  auto mean = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
+  auto variance = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape];
 
   std::vector<int64_t> input_shape_norm;
   if (format == NCHW) {
@@ -106,19 +107,19 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit
   }
 
   // Infer type
-  auto scale_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element();
-  auto bias_type = input_args[2]->BuildType()->cast<TensorTypePtr>()->element();
+  auto scale_type = input_args[kInputIndex1]->BuildType()->cast<TensorTypePtr>()->element();
+  auto bias_type = input_args[kInputIndex2]->BuildType()->cast<TensorTypePtr>()->element();
   const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
   auto input_x_type =
-    CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name);
+    CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[kInputIndex0]->BuildType(), valid_types, prim_name);
   std::map<std::string, TypePtr> args;
-  args.emplace("scale", input_args[1]->BuildType());
-  args.emplace("bias", input_args[2]->BuildType());
+  args.emplace("scale", input_args[kInputIndex1]->BuildType());
+  args.emplace("bias", input_args[kInputIndex2]->BuildType());
   (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
   std::map<std::string, TypePtr> args_moving;
-  args_moving.emplace("scale", input_args[2]->BuildType());
-  args_moving.emplace("bias", input_args[3]->BuildType());
+  args_moving.emplace("scale", input_args[kInputIndex2]->BuildType());
+  args_moving.emplace("bias", input_args[kInputIndex3]->BuildType());
   (void)CheckAndConvertUtils::CheckTensorTypeSame(args_moving, valid_types, prim_name);
 
   auto output0 = std::make_shared<abstract::AbstractTensor>(input_x_type, input_x);
diff --git a/mindspore/core/ops/batch_to_space.cc b/mindspore/core/ops/batch_to_space.cc
index 5130606555a..601c3e3aa89 100644
--- a/mindspore/core/ops/batch_to_space.cc
+++ b/mindspore/core/ops/batch_to_space.cc
@@ -59,15 +59,18 @@ AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri
   auto block_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize));
   auto crops = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kCrops));
   auto out_shape = x_shape;
-  (void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kEqual, 4, prim_name);
-  (void)CheckAndConvertUtils::CheckInteger("block_size size", SizeToLong(block_size.size()), kEqual, 4, prim_name);
-  (void)CheckAndConvertUtils::CheckInteger("crops size", SizeToLong(crops.size()), kEqual, 4, prim_name);
-  (void)CheckAndConvertUtils::CheckInteger("crops[0] size", SizeToLong(crops[0].size()), kEqual, 4, prim_name);
-  (void)CheckAndConvertUtils::CheckInteger("crops[1] size", SizeToLong(crops[1].size()), kEqual, 4, prim_name);
+  const int64_t attr_size = 4;
+  const int64_t x_rank = 4;
+  (void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kEqual, x_rank, prim_name);
+  (void)CheckAndConvertUtils::CheckInteger("block_size size", SizeToLong(block_size.size()), kEqual, attr_size,
+                                           prim_name);
+  (void)CheckAndConvertUtils::CheckInteger("crops size", SizeToLong(crops.size()), kEqual, attr_size, prim_name);
+  (void)CheckAndConvertUtils::CheckInteger("crops[0] size", SizeToLong(crops[0].size()), kEqual, attr_size, prim_name);
+  (void)CheckAndConvertUtils::CheckInteger("crops[1] size", SizeToLong(crops[1].size()), kEqual, attr_size, prim_name);
   for (size_t i = 0; i < 2; ++i) {
     auto x_block_prod = out_shape[i + 2] * block_size[i];
     auto crops_sum = crops[i][0] + crops[i][1];
-    CheckAndConvertUtils::Check("x block shape prod", x_block_prod, kGreaterThan, "crops sum", 4, prim_name);
+    CheckAndConvertUtils::Check("x block shape prod", x_block_prod, kGreaterThan, "crops sum", attr_size, prim_name);
     out_shape[i + 2] = x_block_prod - crops_sum;
   }
   (void)CheckAndConvertUtils::CheckInteger("x_shape[0] % (block_size[0]*block_size[1])",
diff --git a/mindspore/core/ops/batch_to_space_nd.cc b/mindspore/core/ops/batch_to_space_nd.cc
index ffb6e66e6a6..a7320237489 100644
--- a/mindspore/core/ops/batch_to_space_nd.cc
+++ b/mindspore/core/ops/batch_to_space_nd.cc
@@ -30,7 +30,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
   auto prim_name = primitive->name();
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
-  (void)CheckAndConvertUtils::CheckInteger("input_x rank", SizeToLong(x_shape.size()), kEqual, 4, prim_name);
+  const int64_t x_rank = 4;
+  (void)CheckAndConvertUtils::CheckInteger("input_x rank", SizeToLong(x_shape.size()), kEqual, x_rank, prim_name);
   auto out_shape = x_shape;
   int64_t block_shape_prod = 1;
   size_t offset = 2;
@@ -62,7 +63,8 @@ TypePtr InferType(const std::vector<AbstractBasePtr> &input_args) {
 }  // namespace
 
 void BatchToSpaceND::set_crops(std::vector<std::vector<int64_t>> crops) {
-  (void)CheckAndConvertUtils::CheckInteger(kCrops, SizeToLong(crops.size()), kEqual, 2, this->name());
+  const int64_t crop_size = 2;
+  (void)CheckAndConvertUtils::CheckInteger(kCrops, SizeToLong(crops.size()), kEqual, crop_size, this->name());
   size_t h = crops.size();
   size_t w = crops[0].size();
   std::vector<int64_t> temp_w = {2, 2};
@@ -80,7 +82,9 @@ std::vector<std::vector<int64_t>> BatchToSpaceND::get_crops() const {
   return GetValue<std::vector<std::vector<int64_t>>>(value_ptr);
 }
 void BatchToSpaceND::set_block_shape(std::vector<int64_t> block_shape) {
-  (void)CheckAndConvertUtils::CheckInteger(kBlockShape, SizeToLong(block_shape.size()), kEqual, 2, this->name());
+  const int64_t block_size = 2;
+  (void)CheckAndConvertUtils::CheckInteger(kBlockShape, SizeToLong(block_shape.size()), kEqual, block_size,
+                                           this->name());
   for (size_t i = 0; i < block_shape.size(); i++) {
     (void)CheckAndConvertUtils::CheckInteger(kBlockShape, block_shape[i], kGreaterEqual, 1, this->name());
   }
diff --git a/mindspore/core/ops/bias_add.cc b/mindspore/core/ops/bias_add.cc
index 0883eee9519..f8c991f80fd 100644
--- a/mindspore/core/ops/bias_add.cc
+++ b/mindspore/core/ops/bias_add.cc
@@ -35,15 +35,20 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
   auto bias = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 1);
   MS_EXCEPTION_IF_NULL(x);
   MS_EXCEPTION_IF_NULL(bias);
-  (void)CheckAndConvertUtils::CheckInteger("arg size", SizeToLong(input_args.size()), kEqual, 2, prim_name);
+  const int64_t input_num = 2;
+  (void)CheckAndConvertUtils::CheckInteger("arg size", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
   auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
   auto input_shape = shape_map[kShape];
   auto min_shape = shape_map[kMinShape];
   auto max_shape = shape_map[kMaxShape];
-  CheckAndConvertUtils::CheckInRange("bias_add_infer", input_shape.size(), kIncludeBoth, {2, 5}, prim_name);
+  const int64_t x_min_rank = 2;
+  const int64_t x_max_rank = 5;
+  CheckAndConvertUtils::CheckInRange("bias_add_infer", input_shape.size(), kIncludeBoth, {x_min_rank, x_max_rank},
+                                     prim_name);
   auto bias_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
   (void)CheckAndConvertUtils::CheckInteger("bias rank", SizeToLong(bias_shape.size()), kEqual, 1, prim_name);
-  (void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(input_shape.size()), kGreaterEqual, 2, prim_name);
+  const int64_t x_size = 2;
+  (void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(input_shape.size()), kGreaterEqual, x_size, prim_name);
   auto data_format_ptr = primitive->GetAttr("format");
   int64_t data_format = Format::NCHW;
   if (data_format_ptr != nullptr) {
@@ -71,7 +76,9 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
 TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(prim);
   auto prim_name = prim->name();
-  (void)CheckAndConvertUtils::CheckInteger("biasadd_infer", SizeToLong(input_args.size()), kEqual, 2, prim_name);
+  const int64_t input_num = 2;
+  (void)CheckAndConvertUtils::CheckInteger("biasadd_infer", SizeToLong(input_args.size()), kEqual, input_num,
+                                           prim_name);
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
diff --git a/mindspore/core/ops/binary_cross_entropy.cc b/mindspore/core/ops/binary_cross_entropy.cc
index cf1debefa48..4bf0d20a812 100644
--- a/mindspore/core/ops/binary_cross_entropy.cc
+++ b/mindspore/core/ops/binary_cross_entropy.cc
@@ -34,9 +34,9 @@ abstract::ShapePtr BinaryCrossEntroyInferShape(const PrimitivePtr &primitive,
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
   CheckAndConvertUtils::CheckInRange("binary_cross_entropy_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
-  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
-  auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
-  auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
+  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
+  auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
+  auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
   CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name);
   std::vector<int64_t> infer_shape;
   if (weight_shape.size() < 1) {
@@ -50,19 +50,20 @@ abstract::ShapePtr BinaryCrossEntroyInferShape(const PrimitivePtr &primitive,
 }
 
 TypePtr BinaryCrossEntroyInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
-  (void)CheckAndConvertUtils::CheckInteger("binary_cross_entropy_infer", SizeToLong(input_args.size()), kEqual, 3,
-                                           prim->name());
+  const int64_t input_num = 3;
+  (void)CheckAndConvertUtils::CheckInteger("binary_cross_entropy_infer", SizeToLong(input_args.size()), kEqual,
+                                           input_num, prim->name());
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
   const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
   std::map<std::string, TypePtr> types;
-  (void)types.emplace("x_shape", input_args[0]->BuildType());
-  (void)types.emplace("y_shape", input_args[1]->BuildType());
+  (void)types.emplace("x_shape", input_args[kInputIndex0]->BuildType());
+  (void)types.emplace("y_shape", input_args[kInputIndex1]->BuildType());
   auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
-  if (input_args[3]->BuildType() != nullptr) {
-    (void)types.emplace("x_shape", input_args[0]->BuildType());
-    (void)types.emplace("weight_shape", input_args[2]->BuildType());
+  if (input_args[kInputIndex3]->BuildType() != nullptr) {
+    (void)types.emplace("x_shape", input_args[kInputIndex0]->BuildType());
+    (void)types.emplace("weight_shape", input_args[kInputIndex2]->BuildType());
     infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
   }
   return infer_type;
diff --git a/mindspore/core/ops/conv2d.cc b/mindspore/core/ops/conv2d.cc
index b09688e485c..c40e0cbbdec 100644
--- a/mindspore/core/ops/conv2d.cc
+++ b/mindspore/core/ops/conv2d.cc
@@ -109,8 +109,9 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
   auto w_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
   auto x_shape = x_shape_map[kShape];
   auto w_shape = w_shape_map[kShape];
-  (void)CheckAndConvertUtils::CheckInteger("x shape size", SizeToLong(x_shape.size()), kEqual, 4, primitive->name());
-  (void)CheckAndConvertUtils::CheckInteger("w shape size", SizeToLong(w_shape.size()), kEqual, 4, primitive->name());
+  const int64_t shape_size = 4;
+  (void)CheckAndConvertUtils::CheckInteger("x shape size", SizeToLong(x_shape.size()), kEqual, shape_size, prim_name);
+  (void)CheckAndConvertUtils::CheckInteger("w shape size", SizeToLong(w_shape.size()), kEqual, shape_size, prim_name);
   auto x_min_shape = x_shape_map[kMinShape];
   auto x_max_shape = x_shape_map[kMaxShape];
   auto w_min_shape = w_shape_map[kMinShape];
@@ -252,7 +253,8 @@ void Conv2D::set_pad_mode(const PadMode &pad_mode) {
 }
 
 void Conv2D::set_pad(const std::vector<int64_t> &pad) {
-  (void)CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name());
+  const int64_t pad_size = 4;
+  (void)CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, pad_size, name());
   (void)AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
 }
@@ -316,7 +318,8 @@ Format Conv2D::get_format() const {
 AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const std::vector<AbstractBasePtr> &input_args) {
-  (void)CheckAndConvertUtils::CheckInteger("Conv2d infer", SizeToLong(input_args.size()), kGreaterEqual, 2,
+  const int64_t input_num = 2;
+  (void)CheckAndConvertUtils::CheckInteger("Conv2d infer", SizeToLong(input_args.size()), kGreaterEqual, input_num,
                                            primitive->name());
   const std::set<TypePtr> valid_types = {kInt8, kInt32, kInt64, kFloat16, kFloat32};
   std::map<std::string, TypePtr> types;
diff --git a/mindspore/core/ops/conv2d_transpose.cc b/mindspore/core/ops/conv2d_transpose.cc
index 29e8330b1cd..a357315334e 100644
--- a/mindspore/core/ops/conv2d_transpose.cc
+++ b/mindspore/core/ops/conv2d_transpose.cc
@@ -53,7 +53,8 @@ void Conv2DTranspose::set_out_channel(int64_t out_channel) {
 }
 
 void Conv2DTranspose::set_kernel_size(const std::vector<int64_t> &kernel_size) {
-  (void)CheckAndConvertUtils::CheckInteger(kKernelSize, SizeToLong(kernel_size.size()), kEqual, 2, name());
+  const int64_t kernel_len = 2;
+  (void)CheckAndConvertUtils::CheckInteger(kKernelSize, SizeToLong(kernel_size.size()), kEqual, kernel_len, name());
   for (int64_t item : kernel_size) {
     (void)CheckAndConvertUtils::CheckInteger(kKernelSize, item, kGreaterEqual, 1, name());
   }
@@ -61,7 +62,8 @@ void Conv2DTranspose::set_kernel_size(const std::vector<int64_t> &kernel_size) {
 }
 
 void Conv2DTranspose::set_stride(const std::vector<int64_t> &stride) {
-  (void)CheckAndConvertUtils::CheckInteger(kStride, SizeToLong(stride.size()), kEqual, 2, name());
+  const int64_t stride_size = 2;
+  (void)CheckAndConvertUtils::CheckInteger(kStride, SizeToLong(stride.size()), kEqual, stride_size, name());
   for (int64_t item : stride) {
     (void)CheckAndConvertUtils::CheckInteger(kStride, item, kGreaterEqual, 1, name());
   }
@@ -69,7 +71,9 @@ void Conv2DTranspose::set_stride(const std::vector<int64_t> &stride) {
 }
 
 void Conv2DTranspose::set_dilation(const std::vector<int64_t> &dilation) {
-  (void)CheckAndConvertUtils::CheckInteger(kDilation, SizeToLong(dilation.size()), kGreaterEqual, 2, name());
+  const int64_t dilation_size = 2;
+  (void)CheckAndConvertUtils::CheckInteger(kDilation, SizeToLong(dilation.size()), kGreaterEqual, dilation_size,
+                                           name());
   (void)AddAttr(kDilation, MakeValue(dilation));
 }
@@ -87,7 +91,8 @@ void Conv2DTranspose::set_pad_mode(const PadMode &pad_mode) {
 }
 
 void Conv2DTranspose::set_pad(const std::vector<int64_t> &pad) {
-  (void)CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, 4, name());
+  const int64_t pad_size = 4;
+  (void)CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, pad_size, name());
   (void)AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
 }
@@ -105,7 +110,8 @@ void Conv2DTranspose::set_format(const Format &format) {
 }
 
 void Conv2DTranspose::set_pad_list(const std::vector<int64_t> &pad_list) {
-  (void)CheckAndConvertUtils::CheckInteger(kPadList, SizeToLong(pad_list.size()), kEqual, 4, name());
+  const int64_t pad_size = 4;
+  (void)CheckAndConvertUtils::CheckInteger(kPadList, SizeToLong(pad_list.size()), kEqual, pad_size, name());
   (void)this->AddAttr(kPadList, MakeValue(pad_list));
 }
diff --git a/mindspore/core/ops/crop.cc b/mindspore/core/ops/crop.cc
index 8f6dd23864e..b0a9d426672 100644
--- a/mindspore/core/ops/crop.cc
+++ b/mindspore/core/ops/crop.cc
@@ -44,7 +44,8 @@ AbstractBasePtr CropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
                           const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim_name);
+  const int64_t input_num = 2;
+  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
diff --git a/mindspore/core/ops/ctcloss.cc b/mindspore/core/ops/ctcloss.cc
index 5086fb0b4e1..315c016ad79 100644
--- a/mindspore/core/ops/ctcloss.cc
+++ b/mindspore/core/ops/ctcloss.cc
@@ -29,7 +29,9 @@ namespace mindspore {
 namespace ops {
 namespace {
 void CheckCTCLossInputs(const std::vector<AbstractBasePtr> &input_args, const std::string &op_name) {
-  (void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kGreaterEqual, 4, op_name);
+  const int64_t input_num = 4;
+  (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kGreaterEqual, input_num,
+                                           op_name);
   auto inputs = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 0);
   auto labels_indices = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 1);
   auto labels_values = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 2);
@@ -40,11 +42,18 @@ void CheckCTCLossInputs(const std::vector<AbstractBasePtr> &input_args, const st
   auto labels_values_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(labels_values->BuildShape())[kShape];
   auto sequence_length_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(sequence_length->BuildShape())[kShape];
 
-  (void)CheckAndConvertUtils::CheckInteger("inputs rank", inputs_shape.size(), kEqual, 3, op_name);
-  (void)CheckAndConvertUtils::CheckInteger("label_indices rank", labels_indices_shape.size(), kEqual, 2, op_name);
-  (void)CheckAndConvertUtils::CheckInteger("label_indices second dim", labels_indices_shape[1], kEqual, 2, op_name);
-  (void)CheckAndConvertUtils::CheckInteger("label_values rank", labels_values_shape.size(), kEqual, 1, op_name);
-  (void)CheckAndConvertUtils::CheckInteger("sequence_length rank", sequence_length_shape.size(), kEqual, 1, op_name);
+  const int64_t input_size = 3;
+  const int64_t label_indice_size = 2;
+  const int64_t label_indice_last_dim = 2;
+  (void)CheckAndConvertUtils::CheckInteger("inputs rank", SizeToLong(inputs_shape.size()), kEqual, input_size, op_name);
+  (void)CheckAndConvertUtils::CheckInteger("label_indices rank", SizeToLong(labels_indices_shape.size()), kEqual,
+                                           label_indice_size, op_name);
+  (void)CheckAndConvertUtils::CheckInteger("label_indices second dim", labels_indices_shape[1], kEqual,
+                                           label_indice_last_dim, op_name);
+  (void)CheckAndConvertUtils::CheckInteger("label_values rank", int64_t(labels_values_shape.size()), kEqual, 1,
+                                           op_name);
+  (void)CheckAndConvertUtils::CheckInteger("sequence_length rank", int64_t(sequence_length_shape.size()), kEqual, 1,
+                                           op_name);
 
   if (labels_indices_shape[0] != labels_values_shape[0]) {
     MS_EXCEPTION(ValueError) << "For CTCLoss first dim of label_indices and label_value must be same, but got "
@@ -82,11 +91,15 @@ abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vec
 
 TuplePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   auto op_name = primitive->name();
-  (void)CheckAndConvertUtils::CheckTensorTypeValid("labels_indices", input_args[1]->BuildType(), {kInt64}, op_name);
-  (void)CheckAndConvertUtils::CheckTensorTypeValid("labels_values", input_args[2]->BuildType(), {kInt32}, op_name);
-  (void)CheckAndConvertUtils::CheckTensorTypeValid("sequence_length", input_args[3]->BuildType(), {kInt32}, op_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("labels_indices", input_args[kInputIndex1]->BuildType(), {kInt64},
+                                                   op_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("labels_values", input_args[kInputIndex2]->BuildType(), {kInt32},
+                                                   op_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("sequence_length", input_args[kInputIndex3]->BuildType(), {kInt32},
+                                                   op_name);
   const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
-  auto type = CheckAndConvertUtils::CheckTensorTypeValid("inputs", input_args[0]->BuildType(), valid_types, op_name);
+  auto type =
+    CheckAndConvertUtils::CheckTensorTypeValid("inputs", input_args[kInputIndex0]->BuildType(), valid_types, op_name);
   return std::make_shared<Tuple>(std::vector<TypePtr>{type, type});
 }
 }  // namespace
diff --git a/mindspore/core/ops/cumsum.cc b/mindspore/core/ops/cumsum.cc
index 4c0b48e92ce..709abf839ac 100644
--- a/mindspore/core/ops/cumsum.cc
+++ b/mindspore/core/ops/cumsum.cc
@@ -44,7 +44,8 @@ AbstractBasePtr CumSumInfer(const abstract::AnalysisEnginePtr &, const Primitive
                             const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  (void)CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, prim_name);
+  const int64_t input_num = 2;
+  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
diff --git a/mindspore/core/ops/depth_to_space.cc b/mindspore/core/ops/depth_to_space.cc
index ca4adf1fa92..0d7a97835eb 100644
--- a/mindspore/core/ops/depth_to_space.cc
+++ b/mindspore/core/ops/depth_to_space.cc
@@ -59,7 +59,8 @@ AbstractBasePtr DepthToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri
   if (format == NHWC) {
     x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};
   }
-  (void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kEqual, 4, prim_name);
+  const int64_t x_rank = 4;
+  (void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kEqual, x_rank, prim_name);
   int64_t block_size = GetValue<int64_t>(primitive->GetAttr(kBlockSize));
   (void)CheckAndConvertUtils::CheckInteger("x_shape[1] % (block_size*block_size)",
                                            x_shape[1] % (block_size * block_size), kEqual, 0, prim_name);
diff --git a/mindspore/core/ops/detection_post_process.cc b/mindspore/core/ops/detection_post_process.cc
index 3c90483f838..0ab46b4fcbc 100644
--- a/mindspore/core/ops/detection_post_process.cc
+++ b/mindspore/core/ops/detection_post_process.cc
@@ -117,14 +117,14 @@ AbstractBasePtr DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &, c
                                           const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  (void)CheckAndConvertUtils::CheckInteger("detection_post_process_infer", SizeToLong(input_args.size()), kEqual, 3,
-                                           prim_name);
-  MS_EXCEPTION_IF_NULL(input_args[0]);
-  MS_EXCEPTION_IF_NULL(input_args[1]);
-  MS_EXCEPTION_IF_NULL(input_args[2]);
-  auto boxes = input_args[0];
-  auto scores = input_args[1];
-  auto anchors = input_args[2];
+  const int64_t input_num = 3;
+  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
+  MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
+  MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]);
+  MS_EXCEPTION_IF_NULL(input_args[kInputIndex2]);
+  auto boxes = input_args[kInputIndex0];
+  auto scores = input_args[kInputIndex1];
+  auto anchors = input_args[kInputIndex2];
   auto boxes_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(boxes->BuildShape())[kShape];
   auto scores_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(scores->BuildShape())[kShape];
   auto anchors_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(anchors->BuildShape())[kShape];
diff --git a/mindspore/core/ops/dropout_do_mask.cc b/mindspore/core/ops/dropout_do_mask.cc
index ddcf04b5cbd..093f0ca561d 100644
--- a/mindspore/core/ops/dropout_do_mask.cc
+++ b/mindspore/core/ops/dropout_do_mask.cc
@@ -60,7 +60,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
                              << x_shape->ToString() << ", mask shape: " << mask_shape->ToString();
   }
-  auto keep_prop = input_args[2];
+  auto keep_prop = input_args[kInputIndex2];
   if (keep_prop->isa<abstract::AbstractTensor>()) {
     auto keep_prop_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(keep_prop->BuildShape())[kShape];
     if (!keep_prop_shape.empty()) {
@@ -72,7 +72,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
 TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   auto op_name = primitive->name();
-  auto keep_prop = input_args[2];
+  auto keep_prop = input_args[kInputIndex2];
   MS_EXCEPTION_IF_NULL(keep_prop);
   auto keep_prop_value = keep_prop->BuildValue();
   MS_EXCEPTION_IF_NULL(keep_prop_value);
@@ -116,7 +116,9 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
 AbstractBasePtr DropoutDoMaskInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  (void)CheckAndConvertUtils::CheckInteger("infer shape", input_args.size(), kGreaterEqual, 3, primitive->name());
+  const int64_t input_num = 3;
+  (void)CheckAndConvertUtils::CheckInteger("infer shape", SizeToLong(input_args.size()), kGreaterEqual, input_num,
+                                           primitive->name());
   return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(DropoutDoMask, prim::kPrimDropoutDoMask, DropoutDoMaskInfer, nullptr, true);
diff --git a/mindspore/core/ops/dropout_gen_mask.cc b/mindspore/core/ops/dropout_gen_mask.cc
index 63affa9f477..a49382bf1f8 100644
--- a/mindspore/core/ops/dropout_gen_mask.cc
+++ b/mindspore/core/ops/dropout_gen_mask.cc
@@ -94,7 +94,9 @@ ShapeVector CalOutputShape(const AbstractBasePtrList shape_list) {
 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   auto op_name = primitive->name();
-  (void)CheckAndConvertUtils::CheckInteger("infer shape", input_args.size(), kGreaterEqual, 2, op_name);
+  const int64_t input_num = 2;
+  (void)CheckAndConvertUtils::CheckInteger("infer shape", SizeToLong(input_args.size()), kGreaterEqual, input_num,
+                                           op_name);
   AbstractBasePtr shape_args = input_args[0];
   MS_EXCEPTION_IF_NULL(shape_args);
diff --git a/mindspore/core/ops/dynamic_broadcast_gradient_args.cc b/mindspore/core/ops/dynamic_broadcast_gradient_args.cc
index d325f34d93b..d716ca0dcaa 100644
--- a/mindspore/core/ops/dynamic_broadcast_gradient_args.cc
+++ b/mindspore/core/ops/dynamic_broadcast_gradient_args.cc
@@ -55,7 +55,8 @@ size_t CheckInputsAndGetShape(const AbstractBasePtr &input_arg, const string &pr
 abstract::TupleShapePtr Infer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim_name);
+  const int64_t input_num = 2;
+  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
   auto x_shape = CheckInputsAndGetShape(input_args[0], prim_name);
   auto y_shape = CheckInputsAndGetShape(input_args[1], prim_name);
diff --git a/mindspore/core/ops/embedding_lookup.cc b/mindspore/core/ops/embedding_lookup.cc
index 06d4c337d72..bd5172de906 100644
--- a/mindspore/core/ops/embedding_lookup.cc
+++ b/mindspore/core/ops/embedding_lookup.cc
@@ -37,7 +37,8 @@ AbstractBasePtr EmbeddingLookupInfer(const abstract::AnalysisEnginePtr &, const
                                      const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 3, prim_name);
+  const int64_t input_num = 3;
+  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
@@ -47,7 +48,8 @@ AbstractBasePtr EmbeddingLookupInfer(const abstract::AnalysisEnginePtr &, const
   MS_EXCEPTION_IF_NULL(indices);
   const std::set<TypePtr> int_valid_types = {kInt8, kInt16, kInt32, kInt64};
   (void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indices->BuildType(), int_valid_types, prim_name);
-  (void)CheckAndConvertUtils::CheckTensorTypeValid("offset", input_args[2]->BuildType(), int_valid_types, prim_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("offset", input_args[kInputIndex2]->BuildType(), int_valid_types,
+                                                   prim_name);
   MS_EXCEPTION_IF_NULL(params->shape());
   auto params_shp = params->shape()->shape();
   MS_EXCEPTION_IF_NULL(indices->shape());
diff --git a/mindspore/core/ops/equal.cc b/mindspore/core/ops/equal.cc
index a26c6c0bd19..d9716473eb7 100644
--- a/mindspore/core/ops/equal.cc
+++ b/mindspore/core/ops/equal.cc
@@ -26,21 +26,13 @@
 
 namespace mindspore {
 namespace ops {
-namespace {
-abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+AbstractBasePtr EqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                           const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto op_name = primitive->name();
-  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, op_name);
-  for (const auto &item : input_args) {
-    MS_EXCEPTION_IF_NULL(item);
-  }
-  return BroadCastInferShape(op_name, input_args);
-}
-
-TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
-  MS_EXCEPTION_IF_NULL(prim);
-  auto op_name = prim->name();
-  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, 2, op_name);
+  const int64_t input_num = 2;
+  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, input_num,
+                                           op_name);
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
@@ -50,14 +42,9 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   std::map<std::string, TypePtr> types;
   (void)types.emplace("x", input_args[0]->BuildType());
   (void)types.emplace("y", input_args[1]->BuildType());
-  return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, op_name);
-}
-}  // namespace
-
-AbstractBasePtr EqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                           const std::vector<AbstractBasePtr> &input_args) {
-  (void)InferType(primitive, input_args);
-  return abstract::MakeAbstract(InferShape(primitive, input_args), kBool);
+  (void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, op_name);
+  auto out_shape = BroadCastInferShape(op_name, input_args);
+  return abstract::MakeAbstract(out_shape, kBool);
 }
 REGISTER_PRIMITIVE_C(kNameEqual, Equal);
 }  // namespace ops
diff --git a/mindspore/core/ops/expand_dims.cc b/mindspore/core/ops/expand_dims.cc
index fd558abe9f3..8e7b4c13492 100644
--- a/mindspore/core/ops/expand_dims.cc
+++ b/mindspore/core/ops/expand_dims.cc
@@ -31,7 +31,8 @@ AbstractBasePtr ExpandDimsInfer(const abstract::AnalysisEnginePtr &, const Primi
                                 const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 2, prim_name);
+  const int64_t input_num = 2;
+  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
diff --git a/mindspore/core/ops/fake_quant_with_min_max_vars.cc b/mindspore/core/ops/fake_quant_with_min_max_vars.cc
index 6f614749ef2..6d788fda280 100644
--- a/mindspore/core/ops/fake_quant_with_min_max_vars.cc
+++ b/mindspore/core/ops/fake_quant_with_min_max_vars.cc
@@ -29,9 +29,9 @@ namespace {
 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
-  auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
-  auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
+  auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
+  auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
+  auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
   (void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kGreaterEqual, 1, prim_name);
   CheckAndConvertUtils::Check("min_shape", min_shape, kEqual, "max_shape", max_shape, prim_name);
   (void)CheckAndConvertUtils::CheckInteger("min_shape", SizeToLong(min_shape.size()), kEqual, 1, prim_name);
@@ -51,9 +51,9 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
     MS_LOG(EXCEPTION) << "nullptr";
   }
   std::map<std::string, TypePtr> types;
-  (void)types.emplace("x", input_args[0]->BuildType());
-  (void)types.emplace("min", input_args[1]->BuildType());
-  (void)types.emplace("max", input_args[2]->BuildType());
+  (void)types.emplace("x", input_args[kInputIndex0]->BuildType());
+  (void)types.emplace("min", input_args[kInputIndex1]->BuildType());
+  (void)types.emplace("max", input_args[kInputIndex2]->BuildType());
   return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/fake_quant_with_min_max_vars_per_channel.cc b/mindspore/core/ops/fake_quant_with_min_max_vars_per_channel.cc
index cd79263c852..87164771d0c 100644
--- a/mindspore/core/ops/fake_quant_with_min_max_vars_per_channel.cc
+++ b/mindspore/core/ops/fake_quant_with_min_max_vars_per_channel.cc
@@ -44,17 +44,17 @@ AbstractBasePtr FakeQuantWithMinMaxVarsPerChannelInfer(const abstract::AnalysisE
                                                        const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto op_name = primitive->name();
-  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
-  auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
-  auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
+  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
+  auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
+  auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
   (void)CheckAndConvertUtils::CheckInteger("x rank", (int64_t)x_shape.size(), kGreaterThan, 1, op_name);
   CheckAndConvertUtils::Check("min shape", min_shape, kEqual, "max shape", max_shape, op_name);
   (void)CheckAndConvertUtils::CheckInteger("min shape", (int64_t)min_shape.size(), kEqual, 1, op_name);
   CheckAndConvertUtils::Check("min shape", min_shape[0], kEqual, "x shape", x_shape[x_shape.size() - 1], op_name);
 
-  auto x_type = input_args[0]->BuildType();
-  auto min_type = input_args[1]->BuildType();
-  auto max_type = input_args[2]->BuildType();
+  auto x_type = input_args[kInputIndex0]->BuildType();
+  auto min_type = input_args[kInputIndex1]->BuildType();
+  auto max_type = input_args[kInputIndex2]->BuildType();
   std::vector<std::string> type_name = {"x", "min", "max"};
   std::vector<TypePtr> type = {x_type, min_type, max_type};
   for (size_t i = 0; i < 3; i++) {
diff --git a/mindspore/core/ops/fill.cc b/mindspore/core/ops/fill.cc
index 5272c217754..c5237ea9b9b 100644
--- a/mindspore/core/ops/fill.cc
+++ b/mindspore/core/ops/fill.cc
@@ -26,11 +26,12 @@ AbstractBasePtr FillInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
                           const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 3, prim_name);
+  const int64_t input_num = 3;
+  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
-  auto input_dtype = input_args[0]->cast<abstract::AbstractTypePtr>();
+  auto input_dtype = input_args[kInputIndex0]->cast<abstract::AbstractTypePtr>();
   MS_EXCEPTION_IF_NULL(input_dtype);
   auto dtype_value = input_dtype->BuildValue();
   MS_EXCEPTION_IF_NULL(dtype_value);
@@ -39,10 +40,10 @@ AbstractBasePtr FillInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
   auto valid_types = common_valid_types;
   valid_types.insert(kBool);
   (void)CheckAndConvertUtils::CheckTypeValid("output datatype", dtype, valid_types, prim_name);
-  auto out_shape = GetValue<std::vector<int64_t>>(input_args[1]->BuildValue());
-  auto x_type = input_args[2]->BuildType();
+  auto out_shape = GetValue<std::vector<int64_t>>(input_args[kInputIndex1]->BuildValue());
+  auto x_type = input_args[kInputIndex2]->BuildType();
   auto x_type_id = x_type->type_id();
-  auto x_value = input_args[2]->BuildValue();
+  auto x_value = input_args[kInputIndex2]->BuildValue();
   auto abs = std::make_shared<abstract::AbstractTensor>(dtype, std::make_shared<abstract::Shape>(out_shape));
   tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(x_type_id, out_shape);
   MS_EXCEPTION_IF_NULL(tensor);
@@ -54,7 +55,7 @@ AbstractBasePtr FillInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
     auto float_value = GetValue<float>(x_value);
     SetTensorData(tensor->data_c(), float_value, mem_size);
   } else {
-    MS_LOG(ERROR) << " Fill not supported to flod the constant type " << input_args[2]->ToString();
+    MS_LOG(ERROR) << " Fill not supported to flod the constant type " << input_args[kInputIndex2]->ToString();
   }
   abs->set_value(tensor);
   return abs;
diff --git a/mindspore/core/ops/fusion/add_fusion.cc b/mindspore/core/ops/fusion/add_fusion.cc
index ccfa3e38097..1a81553c7a5 100644
--- a/mindspore/core/ops/fusion/add_fusion.cc
+++ b/mindspore/core/ops/fusion/add_fusion.cc
@@ -47,8 +47,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
     MS_LOG(EXCEPTION) << "nullptr";
   }
   std::map<std::string, TypePtr> types;
-  (void)types.emplace("x", input_args[0]->BuildType());
-  (void)types.emplace("y", input_args[1]->BuildType());
+  (void)types.emplace("x", input_args[kInputIndex0]->BuildType());
+  (void)types.emplace("y", input_args[kInputIndex1]->BuildType());
   return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/fusion/avg_pool_fusion.cc b/mindspore/core/ops/fusion/avg_pool_fusion.cc
index 8ac8e5f290f..3e5497f6da5 100644
--- a/mindspore/core/ops/fusion/avg_pool_fusion.cc
+++ b/mindspore/core/ops/fusion/avg_pool_fusion.cc
@@ -62,7 +62,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
   auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
   auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
   auto batch = in_shape[0];
diff --git a/mindspore/core/ops/fusion/conv2d_transpose_fusion.cc b/mindspore/core/ops/fusion/conv2d_transpose_fusion.cc
index b199c6a3457..5d3d4ef9d13 100644
--- a/mindspore/core/ops/fusion/conv2d_transpose_fusion.cc
+++ b/mindspore/core/ops/fusion/conv2d_transpose_fusion.cc
@@ -40,7 +40,8 @@ void Conv2dTransposeFusion::Init(int64_t in_channel, int64_t out_channel, const
 }
 
 void Conv2dTransposeFusion::set_kernel_size(const std::vector<int64_t> &kernel_size) {
-  (void)CheckAndConvertUtils::CheckInteger(kKernelSize, SizeToLong(kernel_size.size()), kEqual, 2, name());
+  const size_t kernel_len = 2;
+  (void)CheckAndConvertUtils::CheckInteger(kKernelSize, SizeToLong(kernel_size.size()), kEqual, kernel_len, name());
   for (int64_t item : kernel_size) {
     (void)CheckAndConvertUtils::CheckInteger(kKernelSize, item, kGreaterEqual, 1, name());
   }
@@ -48,7 +49,8 @@ void Conv2dTransposeFusion::set_kernel_size(const std::vector<int64_t> &kernel_s
 }
 
 void Conv2dTransposeFusion::set_dilation(const std::vector<int64_t> &dilation) {
-  (void)CheckAndConvertUtils::CheckInteger(kDilation, SizeToLong(dilation.size()), kEqual, 2, name());
+  const size_t dilation_size = 2;
+  (void)CheckAndConvertUtils::CheckInteger(kDilation, SizeToLong(dilation.size()), kEqual, dilation_size, name());
   for (int64_t item : dilation) {
     (void)CheckAndConvertUtils::CheckInteger(kDilation, item, kGreaterEqual, 1, name());
   }
diff --git a/mindspore/core/ops/fusion/full_connection.cc b/mindspore/core/ops/fusion/full_connection.cc
index f9dbbaabedd..29c12c25db4 100644
--- a/mindspore/core/ops/fusion/full_connection.cc
+++ b/mindspore/core/ops/fusion/full_connection.cc
@@ -62,8 +62,8 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P
                                     const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  MS_EXCEPTION_IF_NULL(input_args[0]);
-  MS_EXCEPTION_IF_NULL(input_args[1]);
+  MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
+  MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]);
   auto input0 = input_args[0];
   auto input1 = input_args[1];
   auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input0->BuildShape())[kShape];
@@ -93,7 +93,7 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P
     new_k = input1_shape[1];
   }
   if (has_bias) {
-    auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
+    auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
     if (input2_shape[0] != input1_shape[0]) {
       MS_EXCEPTION(ValueError) << "Bias size is invalid";
     }
diff --git a/mindspore/core/ops/fusion/slice_fusion.cc b/mindspore/core/ops/fusion/slice_fusion.cc
index f1fcc7914dd..b346304ec83 100644
--- a/mindspore/core/ops/fusion/slice_fusion.cc
+++ b/mindspore/core/ops/fusion/slice_fusion.cc
@@ -36,9 +36,9 @@ AbstractBasePtr SliceFusionInfer(const abstract::AnalysisEnginePtr &, const Prim
   auto op_name = primitive->name();
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
   auto x_shape_len = x_shape.size();
-  auto begin_v = input_args[1]->BuildValue();
-  auto size_v = input_args[2]->BuildValue();
-  auto x_type = input_args[0]->BuildType();
+  auto begin_v = input_args[kInputIndex1]->BuildValue();
+  auto size_v = input_args[kInputIndex2]->BuildValue();
+  auto x_type = input_args[kInputIndex0]->BuildType();
   MS_EXCEPTION_IF_NULL(x_type);
   MS_EXCEPTION_IF_NULL(begin_v);
   MS_EXCEPTION_IF_NULL(size_v);
diff --git a/mindspore/core/ops/gather.cc b/mindspore/core/ops/gather.cc
index a2b7599f2ac..0bf22b67b8b 100644
--- a/mindspore/core/ops/gather.cc
+++ b/mindspore/core/ops/gather.cc
@@ -14,10 +14,12 @@
  * limitations under the License.
  */
 
+#include "ops/gather.h"
+
 #include <set>
 #include <string>
 #include <vector>
-#include "ops/gather.h"
+#include "ops/op_utils.h"
 
 namespace mindspore {
 namespace ops {
@@ -33,28 +35,30 @@ AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const Primitive
   CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 1);
   // check
   std::set<TypePtr> valid_params_types = {kTensorType};
-  (void)CheckAndConvertUtils::CheckSubClass("params_type", input_args[0]->BuildType(), valid_params_types, op_name);
+  (void)CheckAndConvertUtils::CheckSubClass("params_type", input_args[kInputIndex0]->BuildType(), valid_params_types,
+                                            op_name);
   std::set<TypePtr> int_types = {kInt8, kInt16, kInt32, kInt64};
-  (void)CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[1]->BuildType(), int_types, op_name);
-  (void)CheckAndConvertUtils::CheckTypeValid("axis_type", input_args[2]->BuildType(), int_types, op_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[kInputIndex1]->BuildType(), int_types,
+                                                   op_name);
+  (void)CheckAndConvertUtils::CheckTypeValid("axis_type", input_args[kInputIndex2]->BuildType(), int_types, op_name);
   bool ind_dyn = (!indices->shape()->min_shape().empty() && !indices->shape()->max_shape().empty());
   bool param_dyn = (!params->shape()->min_shape().empty() && !params->shape()->max_shape().empty());
   int64_t axis_val = 0;
   // 3rd input is a Tensor when Gather is a dynamic shape operator
-  if (input_args[2]->isa<abstract::AbstractTensor>()) {
-    auto axis = input_args[2]->cast<abstract::AbstractTensorPtr>();
+  if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>()) {
+    auto axis = input_args[kInputIndex2]->cast<abstract::AbstractTensorPtr>();
     MS_EXCEPTION_IF_NULL(axis);
     auto axis_value_ptr = axis->BuildValue();
     MS_EXCEPTION_IF_NULL(axis_value_ptr);
     auto axis_tensor = axis_value_ptr->cast<tensor::TensorPtr>();
     MS_EXCEPTION_IF_NULL(axis_tensor);
     axis_val = *static_cast<int64_t *>(axis_tensor->data_c());
-  } else if (input_args[2]->isa<abstract::AbstractScalar>()) {
-    auto axis = input_args[2]->cast<abstract::AbstractScalarPtr>();
+  } else if (input_args[kInputIndex2]->isa<abstract::AbstractScalar>()) {
+    auto axis = input_args[kInputIndex2]->cast<abstract::AbstractScalarPtr>();
     axis_val = GetValue<int64_t>(axis->BuildValue());
   } else {
-    MS_LOG(EXCEPTION) << "Invalid abstract type:" << input_args[2]->type_name();
+    MS_LOG(EXCEPTION) << "Invalid abstract type:" << input_args[kInputIndex2]->type_name();
   }
   auto params_shp = params->shape()->shape();
   auto indices_shp = indices->shape()->shape();
diff --git a/mindspore/core/ops/gather_d.cc b/mindspore/core/ops/gather_d.cc
index f0c5bff8b2c..bc83d5633ce 100644
--- a/mindspore/core/ops/gather_d.cc
+++ b/mindspore/core/ops/gather_d.cc
@@ -29,11 +29,11 @@ abstract::ShapePtr GatherDInferShape(const PrimitivePtr &primitive, const std::v
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
   // check
-  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
-  auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
-  int64_t x_rank = x_shape.size();
-  CheckAndConvertUtils::Check("x_rank", x_rank, kEqual, "index_rank", index_shape.size(), prim_name);
-  auto value_ptr = input_args[1]->BuildValue();
+  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
+  auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
+  int64_t x_rank = SizeToLong(x_shape.size());
+  CheckAndConvertUtils::Check("x_rank", x_rank, kEqual, "index_rank", SizeToLong(index_shape.size()), prim_name);
+  auto value_ptr = input_args[kInputIndex1]->BuildValue();
   MS_EXCEPTION_IF_NULL(value_ptr);
   auto dim_v = GetValue<int64_t>(value_ptr);
   CheckAndConvertUtils::Check("dim value", dim_v, kGreaterEqual, "negative index_rank", -x_rank, prim_name);
@@ -66,8 +66,9 @@ AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const Primitiv
   auto prim_name = primitive->name();
   // check
   std::set<TypePtr> valid_types = {kInt32, kInt64};
-  (void)CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[2]->BuildType(), valid_types, prim_name);
-  (void)CheckAndConvertUtils::CheckSubClass("dim_type", input_args[1]->BuildType(), valid_types, prim_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[kInputIndex2]->BuildType(), valid_types,
+                                                   prim_name);
+  (void)CheckAndConvertUtils::CheckSubClass("dim_type", input_args[kInputIndex1]->BuildType(), valid_types, prim_name);
   return abstract::MakeAbstract(GatherDInferShape(primitive, input_args), GatherDInferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, true);
diff --git a/mindspore/core/ops/gather_nd.cc b/mindspore/core/ops/gather_nd.cc
index 89b9d0a2b49..ddb9f9d0532 100644
--- a/mindspore/core/ops/gather_nd.cc
+++ b/mindspore/core/ops/gather_nd.cc
@@ -28,7 +28,8 @@ namespace {
 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim_name);
+  const int64_t input_num = 2;
+  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
diff --git a/mindspore/core/ops/grad/avg_pool_3d_grad.cc b/mindspore/core/ops/grad/avg_pool_3d_grad.cc
index 15f99c51be5..b1f9faf036a 100644
--- a/mindspore/core/ops/grad/avg_pool_3d_grad.cc
+++ b/mindspore/core/ops/grad/avg_pool_3d_grad.cc
@@ -30,7 +30,8 @@ constexpr size_t k5DInputDims = 5;
 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto op_name = primitive->name();
-  (void)CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kEqual, 2, op_name);
+  const int64_t input_num = 2;
+  (void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kEqual, input_num, op_name);
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
@@ -48,7 +49,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
 TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto op_name = primitive->name();
-  (void)CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kEqual, 2, op_name);
+  const int64_t input_num = 2;
+  (void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kEqual, input_num, op_name);
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
diff --git a/mindspore/core/ops/grad/batch_norm_grad.cc b/mindspore/core/ops/grad/batch_norm_grad.cc
index d3f1afcd262..91da624cac2 100644
--- a/mindspore/core/ops/grad/batch_norm_grad.cc
+++ b/mindspore/core/ops/grad/batch_norm_grad.cc
@@ -58,10 +58,10 @@ AbstractBasePtr BatchNormGradInfer(const abstract::AnalysisEnginePtr &, const Pr
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
   CheckAndConvertUtils::Check("BatchNorm y_backprop_shape", y_backprop_shape, kEqual, "BatchNorm x_shape", x_shape);
 
-  auto dx = input_args[1]->Broaden();
-  auto dscale = input_args[2]->Broaden();
-  auto reserve_1 = input_args[3]->Broaden();
-  auto reserve_2 = input_args[4]->Broaden();
+  auto dx = input_args[kInputIndex1]->Broaden();
+  auto dscale = input_args[kInputIndex2]->Broaden();
+  auto reserve_1 = input_args[kInputIndex3]->Broaden();
+  auto reserve_2 = input_args[kInputIndex4]->Broaden();
 
   AbstractBasePtrList rets = {dx, dscale, dscale, reserve_1, reserve_2};
   return std::make_shared<abstract::AbstractTuple>(rets);
diff --git a/mindspore/core/ops/grad/binary_cross_entropy_grad.cc b/mindspore/core/ops/grad/binary_cross_entropy_grad.cc
index 64824f1aa72..8132f6cfa60 100644
--- a/mindspore/core/ops/grad/binary_cross_entropy_grad.cc
+++ b/mindspore/core/ops/grad/binary_cross_entropy_grad.cc
@@ -27,10 +27,11 @@ abstract::ShapePtr BinaryCrossEntroyGradInferShape(const PrimitivePtr &primitive
                                                    const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
-  auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
-  auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
-  CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name);
+
+  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
+  auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
+  auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
+  (void)CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name);
   if (weight_shape.size() < 1) {
     CheckAndConvertUtils::Check("y shape", y_shape, kEqual, "weight shape", weight_shape, prim_name);
   }
@@ -40,12 +41,12 @@ abstract::ShapePtr BinaryCrossEntroyGradInferShape(const PrimitivePtr &primitive
 TypePtr
BinaryCrossEntroyGradInferType(const PrimitivePtr &prim, const std::vector &input_args) { const std::set valid_types = {kFloat16, kFloat32}; std::map types; - (void)types.emplace("x_shape", input_args[0]->BuildType()); - (void)types.emplace("y_shape", input_args[1]->BuildType()); + (void)types.emplace("x_shape", input_args[kInputIndex0]->BuildType()); + (void)types.emplace("y_shape", input_args[kInputIndex1]->BuildType()); auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); - if (input_args[3]->BuildType() != nullptr) { - (void)types.emplace("x_shape", input_args[0]->BuildType()); - (void)types.emplace("weight_shape", input_args[2]->BuildType()); + if (input_args[kInputIndex3]->BuildType() != nullptr) { + (void)types.emplace("x_shape", input_args[kInputIndex0]->BuildType()); + (void)types.emplace("weight_shape", input_args[kInputIndex2]->BuildType()); infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); } return infer_type; diff --git a/mindspore/core/ops/grad/conv2d_backprop_filter.cc b/mindspore/core/ops/grad/conv2d_backprop_filter.cc index 87a6c38a584..1d19080fef2 100644 --- a/mindspore/core/ops/grad/conv2d_backprop_filter.cc +++ b/mindspore/core/ops/grad/conv2d_backprop_filter.cc @@ -29,7 +29,7 @@ abstract::ShapePtr Conv2DBackpropFilterInferShape(const PrimitivePtr &primitive, MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); // check - auto w_size_v = input_args[2]->BuildValue(); + auto w_size_v = input_args[kInputIndex2]->BuildValue(); auto ret_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("w_size", w_size_v, prim_name); return std::make_shared(ret_shape); } @@ -157,8 +157,9 @@ AbstractBasePtr Conv2DBackpropFilterInfer(const abstract::AnalysisEnginePtr &, c const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - // check - (void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kGreaterEqual, 3, prim_name); + const int64_t input_num = 3; + (void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kGreaterEqual, input_num, + prim_name); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } diff --git a/mindspore/core/ops/grad/conv2d_backprop_input.cc b/mindspore/core/ops/grad/conv2d_backprop_input.cc index 13c43ba56f0..a857be8b29b 100644 --- a/mindspore/core/ops/grad/conv2d_backprop_input.cc +++ b/mindspore/core/ops/grad/conv2d_backprop_input.cc @@ -67,7 +67,7 @@ abstract::ShapePtr Conv2DBackpropInputInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - auto x_size_v = input_args[2]->BuildValue(); + auto x_size_v = input_args[kInputIndex2]->BuildValue(); auto ret_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("x_size", x_size_v, prim_name); auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto format = CheckAndConvertUtils::GetAndCheckFormat(primitive->GetAttr(kFormat)); @@ -93,7 +93,9 @@ AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, co MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); // check - (void)CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kGreaterEqual, 3, prim_name); + const int64_t input_num = 3; + (void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kGreaterEqual, input_num, + prim_name); for 
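The BinaryCrossEntropyGrad type check above gathers named inputs into a map and requires them to share one element type drawn from a valid set; when the optional weight is present it is folded into the same check. Reduced to a standalone form, with dtypes shrunk to an enum (the real CheckTensorTypeSame works on TypePtr values):

    #include <map>
    #include <set>
    #include <stdexcept>
    #include <string>

    enum class DType { kFloat16, kFloat32 };

    DType CheckTensorTypeSame(const std::map<std::string, DType> &types,
                              const std::set<DType> &valid,
                              const std::string &prim_name) {
      if (types.empty()) {
        throw std::invalid_argument(prim_name + ": no inputs to check");
      }
      const DType first = types.begin()->second;
      for (const auto &kv : types) {
        if (valid.count(kv.second) == 0) {
          throw std::invalid_argument(prim_name + ": '" + kv.first +
                                      "' has a dtype outside the valid set");
        }
        if (kv.second != first) {
          throw std::invalid_argument(prim_name + ": '" + kv.first +
                                      "' differs in dtype from the other inputs");
        }
      }
      return first;
    }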
(const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } @@ -149,7 +151,8 @@ void Conv2DBackpropInput::set_pad_mode(const PadMode &pad_mode) { } void Conv2DBackpropInput::set_pad(const std::vector &pad) { - (void)CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, 4, name()); + const int64_t pad_size = 4; + (void)CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, pad_size, name()); (void)AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name()))); } diff --git a/mindspore/core/ops/grad/group_conv2d_grad_input.cc b/mindspore/core/ops/grad/group_conv2d_grad_input.cc index 92cdcd2aeb4..42231f066dd 100644 --- a/mindspore/core/ops/grad/group_conv2d_grad_input.cc +++ b/mindspore/core/ops/grad/group_conv2d_grad_input.cc @@ -162,8 +162,9 @@ AbstractBasePtr GroupConv2DGradInputInfer(const abstract::AnalysisEnginePtr &, c const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - (void)CheckAndConvertUtils::CheckInteger("group_conv_2D_infer", SizeToLong(input_args.size()), kGreaterEqual, 2, - prim_name); + const int64_t input_num = 2; + (void)CheckAndConvertUtils::CheckInteger("group_conv_2D_infer", SizeToLong(input_args.size()), kGreaterEqual, + input_num, prim_name); MS_EXCEPTION_IF_NULL(input_args[0]); // Infer shape diff --git a/mindspore/core/ops/grad/layer_norm_grad.cc b/mindspore/core/ops/grad/layer_norm_grad.cc index bff7fe08256..99b1110dfee 100644 --- a/mindspore/core/ops/grad/layer_norm_grad.cc +++ b/mindspore/core/ops/grad/layer_norm_grad.cc @@ -26,10 +26,11 @@ AbstractBasePtr LayerNormGradInfer(const abstract::AnalysisEnginePtr &, const Pr // Outputs: x_backprob, gamma_backprob, beta_backprob MS_EXCEPTION_IF_NULL(primitive); auto op_name = primitive->name(); - (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 5, op_name); - auto x_backprob = input_args[0]->Broaden(); - auto gamma_backprob = input_args[4]->Broaden(); - auto beta_backprob = input_args[4]->Broaden(); + const int64_t input_num = 5; + (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, input_num, op_name); + auto x_backprob = input_args[kInputIndex0]->Broaden(); + auto gamma_backprob = input_args[kInputIndex4]->Broaden(); + auto beta_backprob = input_args[kInputIndex4]->Broaden(); auto shapes = std::make_shared(std::vector{ x_backprob->BuildShape(), gamma_backprob->BuildShape(), beta_backprob->BuildShape()}); auto types = std::make_shared( diff --git a/mindspore/core/ops/grad/relu_grad.cc b/mindspore/core/ops/grad/relu_grad.cc index e16f3dd250e..4e5e0a44933 100644 --- a/mindspore/core/ops/grad/relu_grad.cc +++ b/mindspore/core/ops/grad/relu_grad.cc @@ -30,7 +30,8 @@ namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - (void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name); + const int64_t input_num = 2; + (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } @@ -47,7 +48,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(prim); auto prim_name = prim->name(); - (void)CheckAndConvertUtils::CheckInteger("ReLUGrad infer", 
SizeToLong(input_args.size()), kEqual, 2, prim_name); + const int64_t input_num = 2; + (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name); MS_EXCEPTION_IF_NULL(input_args[0]); auto dout = CheckAndConvertUtils::CheckArgs(prim_name, input_args, 0); auto out = CheckAndConvertUtils::CheckArgs(prim_name, input_args, 1); diff --git a/mindspore/core/ops/grad/relu_grad_v2.cc b/mindspore/core/ops/grad/relu_grad_v2.cc index 395c56fe635..215e54aab7f 100644 --- a/mindspore/core/ops/grad/relu_grad_v2.cc +++ b/mindspore/core/ops/grad/relu_grad_v2.cc @@ -27,13 +27,7 @@ namespace mindspore { namespace ops { namespace { -abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(primitive); - auto prim_name = primitive->name(); - (void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name); - for (const auto &item : input_args) { - MS_EXCEPTION_IF_NULL(item); - } +abstract::ShapePtr InferShape(const std::vector &input_args) { auto x = input_args[0]->BuildShape(); MS_EXCEPTION_IF_NULL(x); auto shape_element = x->cast(); @@ -41,9 +35,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(prim); auto prim_name = prim->name(); - (void)CheckAndConvertUtils::CheckInteger("ReLUGradV2 infer", input_args.size(), kEqual, 2, prim_name); MS_EXCEPTION_IF_NULL(input_args[0]); auto x_type_map = input_args[0]->BuildType(); MS_EXCEPTION_IF_NULL(x_type_map); @@ -55,7 +47,15 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector & } // namespace AbstractBasePtr ReLUGradV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { - return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args)); + MS_EXCEPTION_IF_NULL(primitive); + auto prim_name = primitive->name(); + const int64_t input_num = 2; + (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, input_num, + prim_name); + for (const auto &item : input_args) { + MS_EXCEPTION_IF_NULL(item); + } + return abstract::MakeAbstract(InferShape(input_args), InferType(primitive, input_args)); } REGISTER_PRIMITIVE_EVAL_IMPL(ReLUGradV2, prim::kPrimReluGradV2, ReLUGradV2Infer, nullptr, true); } // namespace ops diff --git a/mindspore/core/ops/grad/sigmoid_cross_entropy_with_logits_grad.cc b/mindspore/core/ops/grad/sigmoid_cross_entropy_with_logits_grad.cc index 20e30e5e982..76dba169e53 100644 --- a/mindspore/core/ops/grad/sigmoid_cross_entropy_with_logits_grad.cc +++ b/mindspore/core/ops/grad/sigmoid_cross_entropy_with_logits_grad.cc @@ -31,23 +31,24 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsGradInfer(const abstract::AnalysisE const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); + const int64_t input_num = 3; (void)CheckAndConvertUtils::CheckInteger("sigmoid_cross_entropy_with_logits_grad_infer", - SizeToLong(input_args.size()), kEqual, 3, prim_name); + SizeToLong(input_args.size()), kEqual, input_num, prim_name); // Infer Shape - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; - auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; - auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; - 
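The ReLUGradV2 rewrite above is the one structural refactor in this patch: the arity and null checks move out of the private InferShape/InferType helpers into the single public entry point, so they run exactly once. The shape of that restructuring, in miniature (Arg, ShapeVec and the function names are placeholders for the framework types):

    #include <cstddef>
    #include <cstdint>
    #include <stdexcept>
    #include <utility>
    #include <vector>

    using ShapeVec = std::vector<int64_t>;
    struct Arg { ShapeVec shape; int type_id; };

    // After the refactor the helpers may assume validated inputs.
    static ShapeVec InferShape(const std::vector<const Arg *> &args) {
      return args[0]->shape;  // ReLU grad: output shape follows the input
    }
    static int InferType(const std::vector<const Arg *> &args) {
      return args[0]->type_id;
    }

    std::pair<ShapeVec, int> ReLUGradV2Infer(const std::vector<const Arg *> &args) {
      // Validate once at the single public entry point.
      const std::size_t input_num = 2;
      if (args.size() != input_num) {
        throw std::invalid_argument("ReLUGradV2: expected 2 inputs");
      }
      for (const auto *a : args) {
        if (a == nullptr) {
          throw std::invalid_argument("ReLUGradV2: input is null");
        }
      }
      return {InferShape(args), InferType(args)};
    }

Deduplicating the checks also removes the mismatch the old code had, where the two helpers validated the same vector independently with different message strings.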
CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "y_shape", y_shape, prim_name, TypeError); - CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "dout_shape", dout_shape, prim_name, TypeError); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; + auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape]; + auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape]; + (void)CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "y_shape", y_shape, prim_name, TypeError); + (void)CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "dout_shape", dout_shape, prim_name, TypeError); // Infer type const std::set valid_types = {kBool, kInt, kInt8, kInt16, kInt32, kInt64, kUInt, kUInt8, kUInt16, kUInt32, kUInt64, kFloat, kFloat16, kFloat32, kFloat64, kComplex64}; std::map args; - args.emplace("x_type", input_args[0]->BuildType()); - args.emplace("y_type", input_args[1]->BuildType()); - args.emplace("dout_type", input_args[2]->BuildType()); + args.emplace("x_type", input_args[kInputIndex0]->BuildType()); + args.emplace("y_type", input_args[kInputIndex1]->BuildType()); + args.emplace("dout_type", input_args[kInputIndex2]->BuildType()); auto dout_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name); return std::make_shared(dout_type, x_shape); } diff --git a/mindspore/core/ops/grad/smooth_l1_loss_grad.cc b/mindspore/core/ops/grad/smooth_l1_loss_grad.cc index 4e23c8fffa6..bad6d3743b0 100644 --- a/mindspore/core/ops/grad/smooth_l1_loss_grad.cc +++ b/mindspore/core/ops/grad/smooth_l1_loss_grad.cc @@ -38,13 +38,14 @@ AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - (void)CheckAndConvertUtils::CheckInteger("smooth_l1_loss_grad_infer", SizeToLong(input_args.size()), kEqual, 3, - prim_name); + const int64_t input_num = 3; + (void)CheckAndConvertUtils::CheckInteger("smooth_l1_loss_grad_infer", SizeToLong(input_args.size()), kEqual, + input_num, prim_name); // Infer shape - auto prediction = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; - auto target = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; - auto dloss = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; + auto prediction = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; + auto target = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape]; + auto dloss = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape]; CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "target shape", target, prim_name, TypeError); CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "dloss", dloss, prim_name, TypeError); @@ -52,9 +53,9 @@ AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const const std::set valid_types = {kBool, kInt, kInt8, kInt16, kInt32, kInt64, kUInt, kUInt8, kUInt16, kUInt32, kUInt64, kFloat, kFloat16, kFloat32, kFloat64, kComplex64}; std::map args; - args.emplace("prediction", input_args[0]->BuildType()); - args.emplace("target", input_args[1]->BuildType()); - args.emplace("dloss", input_args[2]->BuildType()); + 
args.emplace("prediction", input_args[kInputIndex0]->BuildType()); + args.emplace("target", input_args[kInputIndex1]->BuildType()); + args.emplace("dloss", input_args[kInputIndex2]->BuildType()); auto dloss_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name); return std::make_shared(dloss_type, prediction); diff --git a/mindspore/core/ops/hashtable_lookup.cc b/mindspore/core/ops/hashtable_lookup.cc index e9231d284b0..e9640e93b59 100644 --- a/mindspore/core/ops/hashtable_lookup.cc +++ b/mindspore/core/ops/hashtable_lookup.cc @@ -14,10 +14,11 @@ * limitations under the License. */ -#include - #include "ops/hashtable_lookup.h" + +#include #include "utils/check_convert_utils.h" +#include "ops/op_utils.h" namespace mindspore { namespace ops { @@ -32,7 +33,7 @@ AbstractBasePtr HashtableLookupInfer(const abstract::AnalysisEnginePtr &, const (void)CheckAndConvertUtils::CheckInteger("logits size", SizeToLong(input.size()), kGreaterEqual, 1, op_name); hits_shape.push_back(input[0]); - auto value_type = input_args[2]->BuildType(); + auto value_type = input_args[kInputIndex2]->BuildType(); MS_EXCEPTION_IF_NULL(value_type); auto tensor_type = value_type->cast(); MS_EXCEPTION_IF_NULL(tensor_type); diff --git a/mindspore/core/ops/lsh_projection.cc b/mindspore/core/ops/lsh_projection.cc index 04a0270f958..a60183d855d 100644 --- a/mindspore/core/ops/lsh_projection.cc +++ b/mindspore/core/ops/lsh_projection.cc @@ -33,15 +33,17 @@ AbstractBasePtr LshProjectionInfer(const abstract::AnalysisEnginePtr &, const Pr MS_EXCEPTION_IF_NULL(primitive); auto op_name = primitive->name(); const int64_t input_num = 2; + const int64_t input0_size = 2; + const int64_t input0_last_dim = 32; CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, op_name); - auto input0 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; - auto input1 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; - (void)CheckAndConvertUtils::CheckInteger("input0 rank", SizeToLong(input0.size()), kEqual, 2, op_name); - (void)CheckAndConvertUtils::CheckInteger("input0_shape_dimen_1", input0[1], kLessEqual, 32, op_name); + auto input0 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; + auto input1 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape]; + (void)CheckAndConvertUtils::CheckInteger("input0 rank", SizeToLong(input0.size()), kEqual, input0_size, op_name); + (void)CheckAndConvertUtils::CheckInteger("input0_shape_dimen_1", input0[1], kLessEqual, input0_last_dim, op_name); (void)CheckAndConvertUtils::CheckInteger("input1 rank", SizeToLong(input1.size()), kGreaterEqual, 1, op_name); if (input_args.size() == 3) { - auto input2 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; + auto input2 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape]; (void)CheckAndConvertUtils::CheckInteger("input2 rank", SizeToLong(input2.size()), kEqual, 1, op_name); (void)CheckAndConvertUtils::CheckInteger("input2_shape_dimen_0", input2[0], kEqual, input1[0], op_name); } diff --git a/mindspore/core/ops/lstm.cc b/mindspore/core/ops/lstm.cc index 667b095f596..75bfc14ba26 100644 --- a/mindspore/core/ops/lstm.cc +++ b/mindspore/core/ops/lstm.cc @@ -32,16 +32,23 @@ AbstractBasePtr LstmInfer(const PrimitivePtr &primitive, const std::vectorname(); - 
(void)CheckAndConvertUtils::CheckInteger("lstm_prim_infer", SizeToLong(input_args.size()), kEqual, 4, prim_name); - auto x_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; - auto h_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; - auto c_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; + const int64_t input_num = 4; + for (const auto &item : input_args) { + MS_EXCEPTION_IF_NULL(item); + } + (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name); + auto x_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; + auto h_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape]; + auto c_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape]; int64_t input_x_size = GetValue(primitive->GetAttr(kInput_size)); - (void)CheckAndConvertUtils::CheckInteger("x_shape.size()", SizeToLong(x_input_shape.size()), kEqual, 3, prim_name); + const int64_t shape_size = 3; + (void)CheckAndConvertUtils::CheckInteger("x_shape.size()", SizeToLong(x_input_shape.size()), kEqual, shape_size, + prim_name); (void)CheckAndConvertUtils::CheckInteger("x_shape[2]", x_input_shape[2], kEqual, input_x_size, prim_name); - (void)CheckAndConvertUtils::CheckInteger("h_shape.size()", SizeToLong(h_input_shape.size()), kEqual, 3, prim_name); + (void)CheckAndConvertUtils::CheckInteger("h_shape.size()", SizeToLong(h_input_shape.size()), kEqual, shape_size, + prim_name); CheckAndConvertUtils::Check("h_shape", h_input_shape, kEqual, "c_shape", c_input_shape, prim_name); int64_t num_layers = GetValue(primitive->GetAttr(kNumLayers)); @@ -81,15 +88,11 @@ AbstractBasePtr LstmInfer(const PrimitivePtr &primitive, const std::vector state_shape = {1, 1}; // infer type - (void)CheckAndConvertUtils::CheckInteger("lstm_prim_infer", SizeToLong(input_args.size()), kEqual, 4, prim_name); - for (const auto &item : input_args) { - MS_EXCEPTION_IF_NULL(item); - } - auto infer_type0 = input_args[0]->BuildType()->cast()->element(); - auto infer_type1 = input_args[1]->BuildType()->cast()->element(); - auto infer_type2 = input_args[2]->BuildType()->cast()->element(); - auto infer_type3 = input_args[3]->BuildType()->cast()->element(); - auto infer_type4 = input_args[4]->BuildType()->cast()->element(); + auto infer_type0 = input_args[kInputIndex0]->BuildType()->cast()->element(); + auto infer_type1 = input_args[kInputIndex1]->BuildType()->cast()->element(); + auto infer_type2 = input_args[kInputIndex2]->BuildType()->cast()->element(); + auto infer_type3 = input_args[kInputIndex3]->BuildType()->cast()->element(); + auto infer_type4 = input_args[kInputIndex4]->BuildType()->cast()->element(); auto output0 = std::make_shared(infer_type0, x_shape); auto output1 = std::make_shared(infer_type1, y_shape); auto output2 = std::make_shared(infer_type2, c_shape); diff --git a/mindspore/core/ops/mat_mul.cc b/mindspore/core/ops/mat_mul.cc index 5b64ccdf54b..f46aeefde02 100644 --- a/mindspore/core/ops/mat_mul.cc +++ b/mindspore/core/ops/mat_mul.cc @@ -125,7 +125,8 @@ bool MatMul::get_transpose_b() const { AbstractBasePtr MatMulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { - (void)CheckAndConvertUtils::CheckInteger("MatMul infer", 
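The LSTM hunk above hoists the shared rank literal into shape_size and validates every input pointer before any BuildShape call. The shape contract it enforces, extracted into plain C++ (the function name and exception type are illustrative):

    #include <cstddef>
    #include <cstdint>
    #include <stdexcept>
    #include <vector>

    void CheckLstmInputShapes(const std::vector<int64_t> &x_shape,
                              const std::vector<int64_t> &h_shape,
                              const std::vector<int64_t> &c_shape,
                              int64_t input_size) {
      const std::size_t shape_size = 3;  // one named constant, two rank checks
      if (x_shape.size() != shape_size || h_shape.size() != shape_size) {
        throw std::invalid_argument("LSTM: x and h must be rank-3");
      }
      if (x_shape[2] != input_size) {
        throw std::invalid_argument("LSTM: x_shape[2] must equal input_size");
      }
      if (h_shape != c_shape) {
        throw std::invalid_argument("LSTM: h_shape must equal c_shape");
      }
    }

The hunk also deletes the second arity check that used to sit in the type-inference half, since the hoisted check at the top already covers it.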
SizeToLong(input_args.size()), kGreaterEqual, 2, + const int64_t input_num = 2; + (void)CheckAndConvertUtils::CheckInteger("MatMul infer", SizeToLong(input_args.size()), kGreaterEqual, input_num, primitive->name()); return abstract::MakeAbstract(MatMulInferShape(primitive, input_args), MatMulInferType(primitive, input_args)); } diff --git a/mindspore/core/ops/max_pool.cc b/mindspore/core/ops/max_pool.cc index 4583fe0a196..c96b84c87fd 100644 --- a/mindspore/core/ops/max_pool.cc +++ b/mindspore/core/ops/max_pool.cc @@ -88,7 +88,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector>(primitive->GetAttr(kKernelSize)); auto pad_mode_value = (primitive->GetAttr(kPadMode)); diff --git a/mindspore/core/ops/mfcc.cc b/mindspore/core/ops/mfcc.cc index efc006989a6..49b94581ae3 100644 --- a/mindspore/core/ops/mfcc.cc +++ b/mindspore/core/ops/mfcc.cc @@ -25,10 +25,14 @@ namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); + const int64_t input0_size = 3; + const int64_t input1_size = 1; auto first_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto second_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; - (void)CheckAndConvertUtils::CheckInteger("input 0 rank", SizeToLong(first_input_shape.size()), kEqual, 3, prim_name); - (void)CheckAndConvertUtils::CheckInteger("input 1 rank", SizeToLong(second_input_shape.size()), kEqual, 1, prim_name); + (void)CheckAndConvertUtils::CheckInteger("input 0 rank", SizeToLong(first_input_shape.size()), kEqual, input0_size, + prim_name); + (void)CheckAndConvertUtils::CheckInteger("input 1 rank", SizeToLong(second_input_shape.size()), kEqual, input1_size, + prim_name); std::vector out_shape = {first_input_shape[0], first_input_shape[1], GetValue(primitive->GetAttr(kDctCoeffNum))}; return std::make_shared(out_shape); diff --git a/mindspore/core/ops/minimum.cc b/mindspore/core/ops/minimum.cc index 622d6f3b368..06a9f4d2bae 100644 --- a/mindspore/core/ops/minimum.cc +++ b/mindspore/core/ops/minimum.cc @@ -35,14 +35,16 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(prim); - (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim->name()); + auto op_name = prim->name(); + const int64_t input_num = 2; + (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, op_name); if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) { MS_LOG(EXCEPTION) << "nullptr"; } std::map types; (void)types.emplace("x", input_args[0]->BuildType()); (void)types.emplace("y", input_args[1]->BuildType()); - return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); + return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, op_name); } } // namespace diff --git a/mindspore/core/ops/mul.cc b/mindspore/core/ops/mul.cc index 5da93558a38..43ac4f6e91c 100644 --- a/mindspore/core/ops/mul.cc +++ b/mindspore/core/ops/mul.cc @@ -30,7 +30,8 @@ namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - (void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), 
kEqual, 2, prim_name); + const int64_t input_num = 2; + (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } @@ -42,7 +43,9 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector & MS_EXCEPTION_IF_NULL(item); } auto op_name = prim->name(); - (void)CheckAndConvertUtils::CheckInteger("Mul infer", input_args.size(), kGreaterEqual, 2, op_name); + const int64_t input_num = 2; + (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, input_num, + op_name); std::map types; (void)types.emplace("x", input_args[0]->BuildType()); (void)types.emplace("y", input_args[1]->BuildType()); diff --git a/mindspore/core/ops/not_equal.cc b/mindspore/core/ops/not_equal.cc index b5b4d33ee10..60e7f4ce10b 100644 --- a/mindspore/core/ops/not_equal.cc +++ b/mindspore/core/ops/not_equal.cc @@ -27,23 +27,8 @@ namespace mindspore { namespace ops { namespace { -abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(primitive); - auto op_name = primitive->name(); - (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, op_name); - for (const auto &item : input_args) { - MS_EXCEPTION_IF_NULL(item); - } - return BroadCastInferShape(op_name, input_args); -} - TypePtr InferType(const PrimitivePtr &prim, const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(prim); auto op_name = prim->name(); - (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, 2, op_name); - for (const auto &item : input_args) { - MS_EXCEPTION_IF_NULL(item); - } std::map types; (void)types.emplace("x", input_args[0]->BuildType()); (void)types.emplace("y", input_args[1]->BuildType()); @@ -53,8 +38,15 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector & AbstractBasePtr NotEqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto op_name = primitive->name(); + const int64_t input_num = 2; + (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, op_name); + for (const auto &item : input_args) { + MS_EXCEPTION_IF_NULL(item); + } (void)InferType(primitive, input_args); - return abstract::MakeAbstract(InferShape(primitive, input_args), kBool); + return abstract::MakeAbstract(BroadCastInferShape(op_name, input_args), kBool); } REGISTER_PRIMITIVE_C(kNameNotEqual, NotEqual); } // namespace ops diff --git a/mindspore/core/ops/one_hot.cc b/mindspore/core/ops/one_hot.cc index e8f609677bc..7b5a7c501d2 100644 --- a/mindspore/core/ops/one_hot.cc +++ b/mindspore/core/ops/one_hot.cc @@ -59,11 +59,12 @@ abstract::ShapePtr OneHotInferShape(const PrimitivePtr &primitive, const std::ve TypePtr OneHotInferType(const PrimitivePtr &prim, const std::vector &input_args) { auto op_name = prim->name(); - (void)CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[0]->BuildType(), {kInt32, kInt64}, op_name); - (void)CheckAndConvertUtils::CheckTypeValid("depth", input_args[1]->BuildType(), {kInt8, kInt16, kInt32, kInt64}, - op_name); - std::map args = {{"on_value", input_args[2]->BuildType()}, - {"off_dtype", input_args[3]->BuildType()}}; + (void)CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[kInputIndex0]->BuildType(), {kInt32, kInt64}, + op_name); + 
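NotEqual above now validates arity once in NotEqualInfer and hands shape inference to BroadCastInferShape. The broadcasting rule that name refers to is the usual right-aligned one; a compact sketch of the general rule, not MindSpore's exact implementation:

    #include <cstddef>
    #include <cstdint>
    #include <stdexcept>
    #include <utility>
    #include <vector>

    // Right-aligned broadcast: trailing dims must match or one must be 1.
    std::vector<int64_t> BroadcastShape(std::vector<int64_t> a,
                                        std::vector<int64_t> b) {
      if (a.size() < b.size()) std::swap(a, b);  // make a the longer shape
      std::vector<int64_t> out = a;
      const std::size_t offset = a.size() - b.size();
      for (std::size_t i = 0; i < b.size(); ++i) {
        const int64_t da = a[offset + i];
        const int64_t db = b[i];
        if (da == db || db == 1) {
          out[offset + i] = da;
        } else if (da == 1) {
          out[offset + i] = db;
        } else {
          throw std::invalid_argument("shapes are not broadcast-compatible");
        }
      }
      return out;
    }

For example, {4, 3, 1} and {3, 5} broadcast to {4, 3, 5}.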
(void)CheckAndConvertUtils::CheckTypeValid("depth", input_args[kInputIndex1]->BuildType(), + {kInt8, kInt16, kInt32, kInt64}, op_name); + std::map args = {{"on_value", input_args[kInputIndex2]->BuildType()}, + {"off_dtype", input_args[kInputIndex3]->BuildType()}}; return CheckAndConvertUtils::CheckTensorTypeSame(args, {kFloat16, kFloat32}, op_name); } } // namespace diff --git a/mindspore/core/ops/op_utils.h b/mindspore/core/ops/op_utils.h index 7150131feaf..3f10725a54f 100644 --- a/mindspore/core/ops/op_utils.h +++ b/mindspore/core/ops/op_utils.h @@ -252,6 +252,27 @@ constexpr auto kSplitDim = "split_dim"; constexpr auto kPadTop = "pad_top"; constexpr auto kTransFormat = "trans_format"; constexpr auto kApproximate = "approximate"; + +enum Index : size_t { + kInputIndex0 = 0, + kInputIndex1, + kInputIndex2, + kInputIndex3, + kInputIndex4, + kInputIndex5, + kInputIndex6, + kInputIndex7, + kInputIndex8, + kInputIndex9, + kInputIndex10, + kInputIndex11, + kInputIndex12, + kInputIndex13, + kInputIndex14, + kInputIndex15, + kInputIndex16, +}; + const std::set common_valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32, kUInt64, kFloat16, kFloat32, kFloat64}; diff --git a/mindspore/core/ops/prelu.cc b/mindspore/core/ops/prelu.cc index 9ba562bf558..48264a771e6 100644 --- a/mindspore/core/ops/prelu.cc +++ b/mindspore/core/ops/prelu.cc @@ -26,9 +26,10 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vectorBuildShape(); auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x)[kShape]; auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(w)[kShape]; - - (void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kGreaterEqual, 2, prim_name); - (void)CheckAndConvertUtils::CheckInteger("weight rank", SizeToLong(w_shape.size()), kEqual, 1, prim_name); + const int64_t x_rank = 2; + const int64_t w_rank = 1; + (void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kGreaterEqual, x_rank, prim_name); + (void)CheckAndConvertUtils::CheckInteger("weight rank", SizeToLong(w_shape.size()), kEqual, w_rank, prim_name); if (w_shape[0] != x_shape[1] && w_shape[0] != 1) { MS_LOG(EXCEPTION) << "For " << prim_name << ", channel of input_x and weight must be matched, " << "while channel of input_x is " << x_shape[1] << ", weight_shape[0] is " << w_shape[0]; diff --git a/mindspore/core/ops/range.cc b/mindspore/core/ops/range.cc index 642760a0b83..6601375ca73 100644 --- a/mindspore/core/ops/range.cc +++ b/mindspore/core/ops/range.cc @@ -61,13 +61,14 @@ AbstractBasePtr RangeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); int64_t shape_size = 0; - if (input_args.size() == 3) { - MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue()); - MS_EXCEPTION_IF_NULL(input_args[1]->BuildValue()); - MS_EXCEPTION_IF_NULL(input_args[2]->BuildValue()); - auto start_tensor = input_args[0]->BuildValue()->cast(); - auto limit_tensor = input_args[1]->BuildValue()->cast(); - auto delta_tensor = input_args[2]->BuildValue()->cast(); + const size_t max_input_num = 3; + if (input_args.size() == max_input_num) { + MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]->BuildValue()); + MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]->BuildValue()); + MS_EXCEPTION_IF_NULL(input_args[kInputIndex2]->BuildValue()); + auto start_tensor = input_args[kInputIndex0]->BuildValue()->cast(); + auto limit_tensor = input_args[kInputIndex1]->BuildValue()->cast(); + auto delta_tensor = 
input_args[kInputIndex2]->BuildValue()->cast(); auto dtype = start_tensor->data_type(); switch (dtype) { case kNumberTypeInt: diff --git a/mindspore/core/ops/real_div.cc b/mindspore/core/ops/real_div.cc index 0fd2534fdba..f70df08affb 100644 --- a/mindspore/core/ops/real_div.cc +++ b/mindspore/core/ops/real_div.cc @@ -29,7 +29,8 @@ namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 2, prim_name); + const int64_t input_num = 2; + (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } @@ -42,7 +43,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector & MS_EXCEPTION_IF_NULL(item); } auto op_name = prim->name(); - (void)CheckAndConvertUtils::CheckInteger("RealDiv infer", SizeToLong(input_args.size()), kGreaterEqual, 2, op_name); + const int64_t input_num = 2; + (void)CheckAndConvertUtils::CheckInteger("infer", SizeToLong(input_args.size()), kGreaterEqual, input_num, op_name); std::map types; (void)types.emplace("x", input_args[0]->BuildType()); (void)types.emplace("y", input_args[1]->BuildType()); diff --git a/mindspore/core/ops/reshape.cc b/mindspore/core/ops/reshape.cc index bc709f8a36f..7c2799a0d5a 100644 --- a/mindspore/core/ops/reshape.cc +++ b/mindspore/core/ops/reshape.cc @@ -31,7 +31,8 @@ AbstractBasePtr ReshapeInfer(const abstract::AnalysisEnginePtr &, const Primitiv const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim_name); + const int64_t input_num = 2; + (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } diff --git a/mindspore/core/ops/resize_bilinear.cc b/mindspore/core/ops/resize_bilinear.cc index fa67597dd82..d0a181590ac 100644 --- a/mindspore/core/ops/resize_bilinear.cc +++ b/mindspore/core/ops/resize_bilinear.cc @@ -50,7 +50,8 @@ AbstractBasePtr ResizeBilinearInfer(const abstract::AnalysisEnginePtr &, const P // Infer shape auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; - (void)CheckAndConvertUtils::CheckInteger("input_shape_rank", SizeToLong(input_shape.size()), kEqual, 4, prim_name); + const int64_t shape_size = 4; + (void)CheckAndConvertUtils::CheckInteger("input rank", SizeToLong(input_shape.size()), kEqual, shape_size, prim_name); std::vector out_shape = {input_shape[0], input_shape[1]}; auto size = GetValue>(primitive->GetAttr(kSize)); (void)out_shape.insert(out_shape.end(), size.begin(), size.end()); diff --git a/mindspore/core/ops/reverse_sequence.cc b/mindspore/core/ops/reverse_sequence.cc index 1bfd5342c1c..05fe1302dcf 100644 --- a/mindspore/core/ops/reverse_sequence.cc +++ b/mindspore/core/ops/reverse_sequence.cc @@ -39,7 +39,9 @@ AbstractBasePtr ReverseSequenceInfer(const abstract::AnalysisEnginePtr &, const const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 2, prim_name); + const int64_t input_num = 2; + 
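Every other file in this patch leans on the Index enum added to op_utils.h above: a size_t-typed enumeration whose consecutive enumerators name the positional inputs, so input_args[2] becomes input_args[kInputIndex2] with no cast and no magic number. Its behavior is ordinary C++; a self-contained illustration:

    #include <cstddef>
    #include <iostream>
    #include <string>
    #include <vector>

    // Same shape as the op_utils.h addition: first enumerator pinned to 0,
    // the rest counting up, underlying type size_t.
    enum Index : std::size_t { kInputIndex0 = 0, kInputIndex1, kInputIndex2 };

    int main() {
      const std::vector<std::string> input_args = {"var", "accum", "axis"};
      // The enumerator converts implicitly to size_t, so subscripting stays
      // warning-free while documenting which positional input is meant.
      std::cout << input_args[kInputIndex2] << '\n';  // prints "axis"
      return 0;
    }

Fixing the underlying type to size_t is what lets the enumerators subscript std::vector directly; a plain unscoped enum would work too, but could promote to int and reintroduce sign-conversion warnings.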
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, input_num, + prim_name); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } diff --git a/mindspore/core/ops/roi_pooling.cc b/mindspore/core/ops/roi_pooling.cc index 34aefa8a506..8964896fe80 100644 --- a/mindspore/core/ops/roi_pooling.cc +++ b/mindspore/core/ops/roi_pooling.cc @@ -52,7 +52,8 @@ AbstractBasePtr ROIPoolingInfer(const abstract::AnalysisEnginePtr &, const Primi const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - (void)CheckAndConvertUtils::CheckInteger("roi_pooling_infer", SizeToLong(input_args.size()), kEqual, 2, prim_name); + const int64_t input_num = 2; + (void)CheckAndConvertUtils::CheckInteger("infer", SizeToLong(input_args.size()), kEqual, input_num, prim_name); MS_EXCEPTION_IF_NULL(input_args[0]); MS_EXCEPTION_IF_NULL(input_args[1]); diff --git a/mindspore/core/ops/scatter_nd.cc b/mindspore/core/ops/scatter_nd.cc index b1341649b95..7f75f8da63c 100644 --- a/mindspore/core/ops/scatter_nd.cc +++ b/mindspore/core/ops/scatter_nd.cc @@ -24,13 +24,13 @@ namespace mindspore { namespace ops { namespace { abstract::ShapePtr InferShape(const std::vector &input_args) { - auto shape_value = input_args[2]->BuildValue(); + auto shape_value = input_args[kInputIndex2]->BuildValue(); auto shape_value_element = GetValue>(shape_value); for (const auto &shape : shape_value_element) { (void)CheckAndConvertUtils::CheckInteger("shape value", shape, kGreaterThan, 0, "ScatterNd"); } - auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; - auto update_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; + auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; + auto update_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape]; (void)CheckAndConvertUtils::CheckInteger("indices_shape[0] and update_shape[0]", indices_shape[0], kEqual, update_shape[0], "ScatterNd"); return std::make_shared(shape_value_element); diff --git a/mindspore/core/ops/sigmoid_cross_entropy_with_logits.cc b/mindspore/core/ops/sigmoid_cross_entropy_with_logits.cc index 747a077c2dd..6158f64a53a 100644 --- a/mindspore/core/ops/sigmoid_cross_entropy_with_logits.cc +++ b/mindspore/core/ops/sigmoid_cross_entropy_with_logits.cc @@ -30,8 +30,9 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsInfer(const abstract::AnalysisEngin const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); + const int64_t input_num = 2; (void)CheckAndConvertUtils::CheckInteger("sigmoid_cross_extropy_with_logits_infer", SizeToLong(input_args.size()), - kEqual, 2, prim_name); + kEqual, input_num, prim_name); // Infer shape auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; diff --git a/mindspore/core/ops/space_to_batch.cc b/mindspore/core/ops/space_to_batch.cc index 7f3c5ebb3ab..1f43ed55326 100644 --- a/mindspore/core/ops/space_to_batch.cc +++ b/mindspore/core/ops/space_to_batch.cc @@ -30,7 +30,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vectorname(); auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; - (void)CheckAndConvertUtils::CheckInteger("input shape", SizeToLong(input_shape.size()), kEqual, 4, prim_name); + 
const int64_t x_rank = 4; + (void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(input_shape.size()), kEqual, x_rank, prim_name); std::vector output_shape(input_shape.size()); auto block_shape_vector = GetValue>(primitive->GetAttr(kBlockSize)); auto paddings = GetValue>>(primitive->GetAttr(kPaddings)); diff --git a/mindspore/core/ops/space_to_batch_nd.cc b/mindspore/core/ops/space_to_batch_nd.cc index 91d27235405..7d17c5cc6ac 100644 --- a/mindspore/core/ops/space_to_batch_nd.cc +++ b/mindspore/core/ops/space_to_batch_nd.cc @@ -30,7 +30,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vectorname(); auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; - (void)CheckAndConvertUtils::CheckInteger("input_x rank", SizeToLong(x_shape.size()), kEqual, 4, prim_name); + const int64_t shape_size = 4; + (void)CheckAndConvertUtils::CheckInteger("input_x rank", SizeToLong(x_shape.size()), kEqual, shape_size, prim_name); auto out_shape = x_shape; int64_t block_shape_prod = 1; const size_t offset = 2; @@ -60,7 +61,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector & } // namespace void SpaceToBatchND::set_paddings(std::vector> paddings) { - (void)CheckAndConvertUtils::CheckInteger(kPaddings, SizeToLong(paddings.size()), kEqual, 2, this->name()); + const int64_t pad_size = 2; + (void)CheckAndConvertUtils::CheckInteger(kPaddings, SizeToLong(paddings.size()), kEqual, pad_size, this->name()); size_t h = paddings.size(); size_t w = paddings[0].size(); std::vector temp_w = {2, 2}; @@ -78,7 +80,9 @@ std::vector> SpaceToBatchND::get_paddings() const { return GetValue>>(value_ptr); } void SpaceToBatchND::set_block_shape(std::vector block_shape) { - (void)CheckAndConvertUtils::CheckInteger(kBlockShape, SizeToLong(block_shape.size()), kEqual, 2, this->name()); + const int64_t block_size = 2; + (void)CheckAndConvertUtils::CheckInteger(kBlockShape, SizeToLong(block_shape.size()), kEqual, block_size, + this->name()); for (size_t i = 0; i < block_shape.size(); i++) { (void)CheckAndConvertUtils::CheckInteger(kBlockShape, SizeToLong(block_shape[i]), kGreaterEqual, 1, this->name()); } diff --git a/mindspore/core/ops/sparse_softmax_cross_entropy_with_logits.cc b/mindspore/core/ops/sparse_softmax_cross_entropy_with_logits.cc index 413067a73d2..eaa3d3eb35a 100644 --- a/mindspore/core/ops/sparse_softmax_cross_entropy_with_logits.cc +++ b/mindspore/core/ops/sparse_softmax_cross_entropy_with_logits.cc @@ -38,7 +38,9 @@ AbstractBasePtr SparseSoftmaxCrossEntropyWithLogitsInfer(const abstract::Analysi const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - (void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name); + const int64_t input_num = 2; + (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, input_num, + prim_name); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } diff --git a/mindspore/core/ops/sparse_to_dense.cc b/mindspore/core/ops/sparse_to_dense.cc index c9533f2f7e5..4e9e87b4c30 100644 --- a/mindspore/core/ops/sparse_to_dense.cc +++ b/mindspore/core/ops/sparse_to_dense.cc @@ -28,12 +28,13 @@ AbstractBasePtr SparseToDenseInfer(const abstract::AnalysisEnginePtr &, const Pr const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 3, 
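set_paddings above pins both dimensions of the paddings table to a named constant: two rows, one per spatial dimension, each holding a (pad before, pad after) pair. A minimal validation in the same spirit, written as a free function with an exception in place of the logging macros:

    #include <cstddef>
    #include <cstdint>
    #include <stdexcept>
    #include <vector>

    void CheckPaddings(const std::vector<std::vector<int64_t>> &paddings) {
      const std::size_t pad_rows = 2;  // H and W
      const std::size_t pad_cols = 2;  // (pad_before, pad_after)
      if (paddings.size() != pad_rows) {
        throw std::invalid_argument("paddings must have exactly 2 rows");
      }
      for (const auto &row : paddings) {
        if (row.size() != pad_cols) {
          throw std::invalid_argument("each paddings row needs 2 entries");
        }
      }
    }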
prim_name); + const int64_t input_num = 3; + (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } // infer shape - auto dense_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape]; + auto dense_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape]; // infer type auto values_type = input_args[1]->BuildType()->cast()->element(); return std::make_shared(values_type, dense_shape); diff --git a/mindspore/core/ops/strided_slice.cc b/mindspore/core/ops/strided_slice.cc index bc7d459e41e..c41a1f58840 100644 --- a/mindspore/core/ops/strided_slice.cc +++ b/mindspore/core/ops/strided_slice.cc @@ -173,18 +173,18 @@ abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive, MS_EXCEPTION_IF_NULL(primitive); auto strided_slice_prim = primitive->cast(); MS_EXCEPTION_IF_NULL(strided_slice_prim); - auto tuple_begin_v = input_args[1]->cast(); + auto tuple_begin_v = input_args[kInputIndex1]->cast(); MS_EXCEPTION_IF_NULL(tuple_begin_v); auto temp_begin_v = tuple_begin_v->BuildValue(); MS_EXCEPTION_IF_NULL(temp_begin_v); auto begin_v = GetValue>(temp_begin_v); - auto tuple_end_v = input_args[2]->cast(); + auto tuple_end_v = input_args[kInputIndex2]->cast(); MS_EXCEPTION_IF_NULL(tuple_end_v); auto temp_end_v = tuple_end_v->BuildValue(); MS_EXCEPTION_IF_NULL(temp_end_v); auto end_v = GetValue>(temp_end_v); - auto strides_v = CheckAndGetValidStrides(input_args[3]); + auto strides_v = CheckAndGetValidStrides(input_args[kInputIndex3]); auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMinShape]; diff --git a/mindspore/core/ops/sub.cc b/mindspore/core/ops/sub.cc index 78733654b16..02cecaaac28 100644 --- a/mindspore/core/ops/sub.cc +++ b/mindspore/core/ops/sub.cc @@ -30,7 +30,8 @@ namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 2, prim_name); + const int64_t input_num = 2; + (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } @@ -42,7 +43,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector & MS_EXCEPTION_IF_NULL(item); } auto op_name = prim->name(); - (void)CheckAndConvertUtils::CheckInteger("Sub infer", SizeToLong(input_args.size()), kGreaterEqual, 2, op_name); + const int64_t input_num = 2; + (void)CheckAndConvertUtils::CheckInteger("infer", SizeToLong(input_args.size()), kGreaterEqual, input_num, op_name); std::map types; (void)types.emplace("x", input_args[0]->BuildType()); (void)types.emplace("y", input_args[1]->BuildType()); diff --git a/mindspore/core/ops/tile.cc b/mindspore/core/ops/tile.cc index bf024e28e80..9bc2d9c4f77 100644 --- a/mindspore/core/ops/tile.cc +++ b/mindspore/core/ops/tile.cc @@ -77,7 +77,8 @@ abstract::ShapePtr TileInferShape(const PrimitivePtr &primitive, const std::vect TypePtr TileInferType(const PrimitivePtr &prim, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(prim); auto prim_name = prim->name(); - 
(void)CheckAndConvertUtils::CheckInteger("tile_prim_infer", input_args.size(), kEqual, 2, prim_name); + const int64_t input_num = 2; + (void)CheckAndConvertUtils::CheckInteger("infer", SizeToLong(input_args.size()), kEqual, input_num, prim_name); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } diff --git a/mindspore/core/ops/topk.cc b/mindspore/core/ops/topk.cc index c1fa50e0c62..7093b3a322d 100644 --- a/mindspore/core/ops/topk.cc +++ b/mindspore/core/ops/topk.cc @@ -32,7 +32,8 @@ AbstractBasePtr TopKInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - (void)CheckAndConvertUtils::CheckInteger("top_k_infer", SizeToLong(input_args.size()), kEqual, 2, prim_name); + const int64_t input_num = 2; + (void)CheckAndConvertUtils::CheckInteger("top_k_infer", SizeToLong(input_args.size()), kEqual, input_num, prim_name); // Infer dtype for (const auto &item : input_args) { diff --git a/mindspore/core/ops/unsorted_segment_sum.cc b/mindspore/core/ops/unsorted_segment_sum.cc index 8b84c30759b..017e552a3c0 100644 --- a/mindspore/core/ops/unsorted_segment_sum.cc +++ b/mindspore/core/ops/unsorted_segment_sum.cc @@ -55,11 +55,11 @@ AbstractBasePtr UnsortedSegmentSumInfer(const abstract::AnalysisEnginePtr &, con } const std::set valid_num_segments_types = {kInt32, kInt64}; - (void)CheckAndConvertUtils::CheckTensorTypeValid("num_segments", input_args[2]->BuildType(), valid_num_segments_types, - prim_name); - int64_t size_segment_ids_shp = segment_ids_shape.size(); - int64_t size_x_shpe = x_shape.size(); - for (int64_t i = size_segment_ids_shp; i < size_x_shpe; ++i) { + (void)CheckAndConvertUtils::CheckTensorTypeValid("num_segments", input_args[kInputIndex2]->BuildType(), + valid_num_segments_types, prim_name); + auto size_segment_ids_shp = segment_ids_shape.size(); + auto size_x_shpe = x_shape.size(); + for (size_t i = size_segment_ids_shp; i < size_x_shpe; ++i) { (void)shp.emplace_back(x_shape[i]); } diff --git a/mindspore/core/ops/where.cc b/mindspore/core/ops/where.cc index 53996def932..c7d36302174 100644 --- a/mindspore/core/ops/where.cc +++ b/mindspore/core/ops/where.cc @@ -29,16 +29,18 @@ AbstractBasePtr WhereInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP MS_EXCEPTION_IF_NULL(input); } auto op_name = primitive->name(); - (void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kGreaterEqual, 3, op_name); - auto input0_type_ = input_args[0]->BuildType()->cast(); + const int64_t input_num = 3; + (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kGreaterEqual, input_num, + op_name); + auto input0_type_ = input_args[kInputIndex0]->BuildType()->cast(); MS_EXCEPTION_IF_NULL(input0_type_); auto input0_type = input0_type_->element(); - auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; - auto num = input_args[0]->BuildValue()->cast()->ElementsNum(); - auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; - auto num1 = input_args[1]->BuildValue()->cast()->ElementsNum(); - auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; - auto num2 = input_args[2]->BuildValue()->cast()->ElementsNum(); + auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; + auto num = 
input_args[kInputIndex0]->BuildValue()->cast<tensor::TensorPtr>()->ElementsNum();
+  auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
+  auto num1 = input_args[kInputIndex1]->BuildValue()->cast<tensor::TensorPtr>()->ElementsNum();
+  auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
+  auto num2 = input_args[kInputIndex2]->BuildValue()->cast<tensor::TensorPtr>()->ElementsNum();
   auto nummax = num > num1 ? num : (num1 > num2 ? num1 : num2);
   size_t axisout = 0;
   size_t temp = 0;
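One observation on the where.cc context code rather than on the diff itself: the nested ternary that computes nummax returns num whenever num > num1 and never consults num2 on that branch, so for element counts (5, 3, 9) it yields 5 rather than 9. If a three-way maximum is what is intended, the standard library spells it directly:

    #include <algorithm>
    #include <cstdint>
    #include <initializer_list>

    // True maximum of the three element counts compared in WhereInfer.
    int64_t MaxElementCount(int64_t num, int64_t num1, int64_t num2) {
      return std::max({num, num1, num2});
    }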