From 894622865a9ec6abda6bbed7b972ee866598352c Mon Sep 17 00:00:00 2001
From: simson
Date: Mon, 12 Apr 2021 10:22:03 +0800
Subject: [PATCH] remove ConvertShapePtrToShape function, use ConvertShapePtrToShapeMap instead

---
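Note: every hunk below applies one mechanical rewrite (plus removal of prim_name/op_name locals that become unused). A call of the form

  CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", shape_ptr, prim_name)

becomes

  CheckAndConvertUtils::ConvertShapePtrToShapeMap(shape_ptr)[kShape]

where shape_ptr stands for whatever BuildShape() or GetShapeTrack() returns at the call site, so the argument-name and primitive-name strings no longer have to be threaded through. The following is a minimal, self-contained sketch of that call-site migration; ShapeVector, MockShape, MockShapePtr, the local kShape constant and both helper bodies are simplified stand-ins, not the real declarations from mindspore/core/utils/check_convert_utils.h.

// Minimal, self-contained sketch of the migration (mock types, not MindSpore code).
#include <cassert>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>

using ShapeVector = std::vector<int64_t>;

// Stand-in for the shape object the real helpers receive.
struct MockShape {
  ShapeVector shape;
};
using MockShapePtr = std::shared_ptr<MockShape>;

// Stand-in for the kShape key defined in check_convert_utils.h.
static const char kShape[] = "shape";

// Old helper (removed by this patch): takes the argument name and primitive name
// only so it can build error messages, and returns just the shape vector.
ShapeVector ConvertShapePtrToShape(const std::string &arg_name, const MockShapePtr &shape,
                                   const std::string &prim_name) {
  (void)arg_name;
  (void)prim_name;
  assert(shape != nullptr);
  return shape->shape;
}

// Remaining helper: returns a map keyed by kShape (the real one can expose further
// entries, e.g. min/max shapes for dynamic-shape inputs, through the same map).
std::map<std::string, ShapeVector> ConvertShapePtrToShapeMap(const MockShapePtr &shape) {
  assert(shape != nullptr);
  return {{kShape, shape->shape}};
}

int main() {
  auto x = std::make_shared<MockShape>(MockShape{{2, 3, 4}});

  // Call form deleted throughout this patch:
  ShapeVector before = ConvertShapePtrToShape("x_shape", x, "Abs");
  // Call form every site is migrated to:
  ShapeVector after = ConvertShapePtrToShapeMap(x)[kShape];

  assert(before == after);  // both forms yield the same static shape
  return 0;
}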
 mindspore/core/ops/abs.cc | 3 +--
 mindspore/core/ops/adam.cc | 9 ++++-----
 mindspore/core/ops/addn.cc | 6 ++----
 mindspore/core/ops/apply_momentum.cc | 2 +-
 mindspore/core/ops/arg_max.cc | 2 +-
 mindspore/core/ops/arg_min.cc | 2 +-
 mindspore/core/ops/asin.cc | 2 +-
 mindspore/core/ops/assert.cc | 3 +--
 mindspore/core/ops/assign_add.cc | 4 +---
 mindspore/core/ops/atan.cc | 2 +-
 mindspore/core/ops/audio_spectrogram.cc | 4 +---
 mindspore/core/ops/avg_pool.cc | 2 +-
 mindspore/core/ops/batch_norm.cc | 13 ++++++-------
 mindspore/core/ops/batch_norm_fold.cc | 10 ++++------
 mindspore/core/ops/batch_to_space.cc | 2 +-
 mindspore/core/ops/batch_to_space_nd.cc | 2 +-
 mindspore/core/ops/bias_add.cc | 4 ++--
 mindspore/core/ops/binary_cross_entropy.cc | 7 +++----
 mindspore/core/ops/broadcast.cc | 2 +-
 mindspore/core/ops/broadcast_to.cc | 2 +-
 mindspore/core/ops/ceil.cc | 2 +-
 mindspore/core/ops/concat.cc | 6 ++----
 mindspore/core/ops/constant_of_shape.cc | 3 +--
 mindspore/core/ops/conv2d.cc | 4 ++--
 mindspore/core/ops/conv2d_transpose.cc | 3 +--
 mindspore/core/ops/cos.cc | 3 +--
 mindspore/core/ops/crop.cc | 2 +-
 mindspore/core/ops/custom_extract_features.cc | 6 +-----
 mindspore/core/ops/depth_to_space.cc | 2 +-
 mindspore/core/ops/depthwise_conv2d.cc | 4 ++--
 mindspore/core/ops/detection_post_process.cc | 6 +++---
 mindspore/core/ops/dropout.cc | 2 +-
 mindspore/core/ops/elu.cc | 3 +--
 mindspore/core/ops/expand_dims.cc | 2 +-
 mindspore/core/ops/fake_quant_with_min_max_vars.cc | 6 +++---
 .../ops/fake_quant_with_min_max_vars_per_channel.cc | 6 +++---
 mindspore/core/ops/fft_imag.cc | 3 +--
 mindspore/core/ops/fft_real.cc | 2 +-
 mindspore/core/ops/flatten.cc | 2 +-
 mindspore/core/ops/floor.cc | 3 +--
 mindspore/core/ops/fusion/avg_pool_fusion.cc | 2 +-
 mindspore/core/ops/fusion/full_connection.cc | 7 +++----
 mindspore/core/ops/fusion/max_pool_fusion.cc | 2 +-
 mindspore/core/ops/fusion/slice_fusion.cc | 2 +-
 mindspore/core/ops/gather_d.cc | 4 ++--
 mindspore/core/ops/gather_nd.cc | 5 ++---
 mindspore/core/ops/gelu.cc | 3 +--
 mindspore/core/ops/grad/batch_norm_grad.cc | 6 ++----
 mindspore/core/ops/grad/bias_add_grad.cc | 2 +-
 .../core/ops/grad/binary_cross_entropy_grad.cc | 7 +++----
 mindspore/core/ops/grad/dropout_grad.cc | 3 +--
 mindspore/core/ops/grad/max_pool_grad.cc | 4 +---
 .../grad/sigmoid_cross_entropy_with_logits_grad.cc | 6 +++---
 mindspore/core/ops/grad/smooth_l1_loss_grad.cc | 6 +++---
 mindspore/core/ops/hashtable_lookup.cc | 3 +--
 mindspore/core/ops/l2_normalize.cc | 2 +-
 mindspore/core/ops/log.cc | 2 +-
 mindspore/core/ops/logical_not.cc | 3 +--
 mindspore/core/ops/lrn.cc | 2 +-
 mindspore/core/ops/lsh_projection.cc | 6 +++---
 mindspore/core/ops/lstm.cc | 6 +++---
 mindspore/core/ops/mat_mul.cc | 4 ++--
 mindspore/core/ops/matrix_diag.cc | 5 ++---
 mindspore/core/ops/max_pool.cc | 2 +-
 mindspore/core/ops/mfcc.cc | 6 ++----
 mindspore/core/ops/one_hot.cc | 2 +-
 mindspore/core/ops/ones_like.cc | 4 +---
 mindspore/core/ops/op_utils.cc | 4 ++--
 mindspore/core/ops/pack.cc | 4 ++--
 mindspore/core/ops/pad.cc | 2 +-
 mindspore/core/ops/prelu.cc | 4 ++--
 mindspore/core/ops/prior_box.cc | 3 +--
 mindspore/core/ops/quant_dtype_cast.cc | 3 +--
 mindspore/core/ops/reciprocal.cc | 3 +--
 mindspore/core/ops/reduce.cc | 3 +--
 mindspore/core/ops/resize_bilinear.cc | 3 +--
 mindspore/core/ops/reverse_sequence.cc | 6 ++----
 mindspore/core/ops/reverse_v2.cc | 3 +--
 mindspore/core/ops/rfft.cc | 4 +---
 mindspore/core/ops/roi_pooling.cc | 5 ++---
 mindspore/core/ops/round.cc | 2 +-
 mindspore/core/ops/rsqrt.cc | 2 +-
 mindspore/core/ops/scalar_summary.cc | 2 +-
 mindspore/core/ops/scatter_nd.cc | 6 ++----
 .../core/ops/sigmoid_cross_entropy_with_logits.cc | 4 ++--
 mindspore/core/ops/sin.cc | 2 +-
 mindspore/core/ops/skip_gram.cc | 3 +--
 mindspore/core/ops/smooth_l1_loss.cc | 4 ++--
 .../core/ops/softmax_cross_entropy_with_logits.cc | 6 ++----
 mindspore/core/ops/space_to_batch.cc | 3 +--
 mindspore/core/ops/space_to_batch_nd.cc | 2 +-
 .../ops/sparse_softmax_cross_entropy_with_logits.cc | 3 +--
 mindspore/core/ops/sparse_to_dense.cc | 3 +--
 mindspore/core/ops/squeeze.cc | 2 +-
 mindspore/core/ops/stack.cc | 7 ++-----
 mindspore/core/ops/strided_slice.cc | 2 +-
 mindspore/core/ops/tan.cc | 2 +-
 mindspore/core/ops/tensor_list_from_tensor.cc | 7 ++-----
 mindspore/core/ops/tensor_list_stack.cc | 7 ++-----
 mindspore/core/ops/tensor_summary.cc | 2 +-
 mindspore/core/ops/tile.cc | 3 +--
 mindspore/core/ops/topk.cc | 2 +-
 mindspore/core/ops/unpack.cc | 2 +-
 mindspore/core/ops/unsorted_segment_sum.cc | 5 ++---
 mindspore/core/ops/unsqueeze.cc | 2 +-
 mindspore/core/ops/unstack.cc | 2 +-
 mindspore/core/ops/where.cc | 9 +++------
 mindspore/core/ops/zeros.cc | 3 +--
 mindspore/core/ops/zeros_like.cc | 2 +-
 mindspore/core/utils/check_convert_utils.cc | 12 ------------
 mindspore/core/utils/check_convert_utils.h | 3 ---
 mindspore/lite/tools/converter/ops/while.cc | 4 +---
 112 files changed, 167 insertions(+), 259 deletions(-)

diff --git a/mindspore/core/ops/abs.cc b/mindspore/core/ops/abs.cc
index 984175d324b..f30467168b9 100644
--- a/mindspore/core/ops/abs.cc
+++ b/mindspore/core/ops/abs.cc
@@ -30,11 +30,10 @@ namespace ops {
 namespace {
 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto prim_name = primitive->name();
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
-  auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name);
+  auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
   return std::make_shared<abstract::Shape>(in_shape);
 }
diff --git a/mindspore/core/ops/adam.cc b/mindspore/core/ops/adam.cc
index 05dea617bf5..fc5e0ab799e 100644
--- a/mindspore/core/ops/adam.cc
+++ b/mindspore/core/ops/adam.cc
@@ -26,11 +26,10 @@ abstract::AbstractBasePtr AdamInfer(const PrimitivePtr &primitive, const std::ve
   auto prim_name = primitive->name();
   // infer shape
-  auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShape("var_shape", input_args[0]->GetShapeTrack(), prim_name);
-  auto m_shape = CheckAndConvertUtils::ConvertShapePtrToShape("m_shape", input_args[1]->GetShapeTrack(), prim_name);
-  auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[2]->GetShapeTrack(), prim_name);
-  auto grad_shape =
-    CheckAndConvertUtils::ConvertShapePtrToShape("grad_shape", input_args[9]->GetShapeTrack(), prim_name);
+  auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
+  auto m_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack())[kShape];
+  auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->GetShapeTrack())[kShape];
+  auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[9]->GetShapeTrack())[kShape];
CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "m_shape", m_shape, prim_name); CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "v_shape", v_shape, prim_name); CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "grad_shape", grad_shape, prim_name); diff --git a/mindspore/core/ops/addn.cc b/mindspore/core/ops/addn.cc index de1b4e95d00..3b04bf4b304 100644 --- a/mindspore/core/ops/addn.cc +++ b/mindspore/core/ops/addn.cc @@ -38,15 +38,13 @@ AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt CheckAndConvertUtils::CheckInteger("concat element num", elements.size(), kGreaterEqual, 1, prim_name); auto element0 = elements[0]->cast(); MS_EXCEPTION_IF_NULL(element0); - auto element0_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("element0 shape", element0->BuildShape(), prim_name); + auto element0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kShape]; std::map types; types.emplace("element0", element0->BuildType()); for (size_t i = 1; i < elements.size(); ++i) { std::string elementi = "element" + std::to_string(i); - auto elementi_shape = - CheckAndConvertUtils::ConvertShapePtrToShape(elementi + " shape", elements[i]->BuildShape(), prim_name); + auto elementi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kShape]; CheckAndConvertUtils::CheckInteger(elementi + " shape rank", elementi_shape.size(), kEqual, element0_shape.size(), prim_name); for (size_t j = 0; j < element0_shape.size(); ++j) { diff --git a/mindspore/core/ops/apply_momentum.cc b/mindspore/core/ops/apply_momentum.cc index ebc3962a791..3d061ecb8ad 100644 --- a/mindspore/core/ops/apply_momentum.cc +++ b/mindspore/core/ops/apply_momentum.cc @@ -60,7 +60,7 @@ AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const Pr CheckAndConvertUtils::CheckInteger("apply_momentum_infer", input_args.size(), kEqual, 5, prim_name); // Infer shape - auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[0]->BuildShape(), prim_name); + auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; // Infer type auto v_tensor_type = input_args[0]->BuildType(); diff --git a/mindspore/core/ops/arg_max.cc b/mindspore/core/ops/arg_max.cc index 06e96df3f8b..98fd21e3992 100644 --- a/mindspore/core/ops/arg_max.cc +++ b/mindspore/core/ops/arg_max.cc @@ -23,7 +23,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vectorname(); auto axis = GetValue(primitive->GetAttr(kAxis)); - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto x_rank = SizeToLong(x_shape.size()); CheckAndConvertUtils::CheckInRange("argmax axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name); axis = axis < 0 ? 
axis + x_rank : axis; diff --git a/mindspore/core/ops/arg_min.cc b/mindspore/core/ops/arg_min.cc index def1e63c109..1bf95849668 100644 --- a/mindspore/core/ops/arg_min.cc +++ b/mindspore/core/ops/arg_min.cc @@ -42,7 +42,7 @@ AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const Primitive // Infer shape auto axis = GetValue(primitive->GetAttr(kAxis)); - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto x_rank = SizeToLong(x_shape.size()); CheckAndConvertUtils::CheckInRange("axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name); if (axis < 0) { diff --git a/mindspore/core/ops/asin.cc b/mindspore/core/ops/asin.cc index ab75a30450a..099bb693e98 100644 --- a/mindspore/core/ops/asin.cc +++ b/mindspore/core/ops/asin.cc @@ -29,7 +29,7 @@ AbstractBasePtr AsinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt CheckAndConvertUtils::CheckInteger("Asin_infer", input_args.size(), kEqual, 1, prim_name); // Infer Shape - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto infer_shape = std::make_shared(x_shape); // Infer Type diff --git a/mindspore/core/ops/assert.cc b/mindspore/core/ops/assert.cc index 09a0a5028ff..7bb4f621a15 100644 --- a/mindspore/core/ops/assert.cc +++ b/mindspore/core/ops/assert.cc @@ -47,8 +47,7 @@ AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const Primitive } condition = TypeIdToType(kNumberTypeBool); } else { - auto condition_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); + auto condition_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; CheckAndConvertUtils::CheckInteger("condition's rank", condition_shape[0], kLessEqual, 1, op_name); if (condition_shape[0] == 1) { auto condition_value = reinterpret_cast(input_args[0]->BuildValue()->cast()->data_c()); diff --git a/mindspore/core/ops/assign_add.cc b/mindspore/core/ops/assign_add.cc index 8d87c74eea2..8c6c48904a5 100644 --- a/mindspore/core/ops/assign_add.cc +++ b/mindspore/core/ops/assign_add.cc @@ -25,9 +25,7 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto prim_name = primitive->name(); - auto value_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("value_shape", input_args[1]->BuildShape(), prim_name); + auto value_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; return std::make_shared(value_shape); } diff --git a/mindspore/core/ops/atan.cc b/mindspore/core/ops/atan.cc index 1335c6476b4..cc3a09040de 100644 --- a/mindspore/core/ops/atan.cc +++ b/mindspore/core/ops/atan.cc @@ -27,7 +27,7 @@ AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt CheckAndConvertUtils::CheckInteger("Atan_infer", input_args.size(), kEqual, 1, prim_name); // Infer Shape - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto infer_shape = std::make_shared(x_shape); // Infer Type diff --git 
a/mindspore/core/ops/audio_spectrogram.cc b/mindspore/core/ops/audio_spectrogram.cc index b8e023b57fe..c49a9c81cb9 100644 --- a/mindspore/core/ops/audio_spectrogram.cc +++ b/mindspore/core/ops/audio_spectrogram.cc @@ -30,9 +30,7 @@ namespace { abstract::ShapePtr AudioSpectrogramInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto prim_name = primitive->name(); - auto input_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); + auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; if (input_shape.size() != 2) { MS_LOG(ERROR) << "input shape is error, which need to be 2 dimensions"; } diff --git a/mindspore/core/ops/avg_pool.cc b/mindspore/core/ops/avg_pool.cc index 9c6abfad5f7..46c3efc50d9 100644 --- a/mindspore/core/ops/avg_pool.cc +++ b/mindspore/core/ops/avg_pool.cc @@ -82,7 +82,7 @@ namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto op_name = primitive->name(); - auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name); + auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; auto format = Format(GetValue(primitive->GetAttr(kFormat))); if (format == NHWC) { in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; diff --git a/mindspore/core/ops/batch_norm.cc b/mindspore/core/ops/batch_norm.cc index 57c3d8e284a..b60c3b71c66 100644 --- a/mindspore/core/ops/batch_norm.cc +++ b/mindspore/core/ops/batch_norm.cc @@ -75,20 +75,19 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("batch_norm_infer", input_args.size(), kEqual, 5, prim_name); - auto input_x = CheckAndConvertUtils::ConvertShapePtrToShape("input_x", input_args[0]->BuildShape(), prim_name); + auto input_x = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto format = Format(GetValue(primitive->GetAttr(kFormat))); if (format == NHWC) { input_x = {input_x[0], input_x[3], input_x[1], input_x[2]}; } - auto scale = CheckAndConvertUtils::ConvertShapePtrToShape("scale", input_args[1]->BuildShape(), prim_name); - auto bias = CheckAndConvertUtils::ConvertShapePtrToShape("bias", input_args[2]->BuildShape(), prim_name); - auto mean = CheckAndConvertUtils::ConvertShapePtrToShape("mean", input_args[3]->BuildShape(), prim_name); - auto variance = CheckAndConvertUtils::ConvertShapePtrToShape("variance", input_args[4]->BuildShape(), prim_name); + auto scale = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; + auto bias = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; + auto mean = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape]; + auto variance = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[4]->BuildShape())[kShape]; std::vector input_shape_norm; if (format == NCHW) { - input_shape_norm = - CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name); + input_shape_norm = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; } else { input_shape_norm.push_back(input_x[0]); input_shape_norm.push_back(input_x[3]); diff --git 
a/mindspore/core/ops/batch_norm_fold.cc b/mindspore/core/ops/batch_norm_fold.cc index 359e0fd279d..5484d54c695 100644 --- a/mindspore/core/ops/batch_norm_fold.cc +++ b/mindspore/core/ops/batch_norm_fold.cc @@ -68,12 +68,10 @@ AbstractBasePtr BatchNormFoldInfer(const abstract::AnalysisEnginePtr &, const Pr const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto op_name = primitive->name(); - auto mean_shape = CheckAndConvertUtils::ConvertShapePtrToShape("mean_shape", input_args[1]->BuildShape(), op_name); - auto variance_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("variance_shape", input_args[2]->BuildShape(), op_name); - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name); - auto global_step_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("global_step_shape", input_args[3]->BuildShape(), op_name); + auto mean_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; + auto variance_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + auto global_step_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape]; CheckAndConvertUtils::Check("mean_shape", mean_shape, kEqual, "gamma_shape", variance_shape, op_name); CheckAndConvertUtils::Check("mean_shape[0]", mean_shape[0], kEqual, "input channel", x_shape[1], op_name); CheckAndConvertUtils::CheckInteger("global step shape len", global_step_shape.size(), kEqual, 1, op_name); diff --git a/mindspore/core/ops/batch_to_space.cc b/mindspore/core/ops/batch_to_space.cc index 800928bb826..f9ba50535d4 100644 --- a/mindspore/core/ops/batch_to_space.cc +++ b/mindspore/core/ops/batch_to_space.cc @@ -55,7 +55,7 @@ AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), common_valid_types, prim_name); - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); auto block_size = GetValue>(primitive->GetAttr(kBlockSize)); auto crops = GetValue>>(primitive->GetAttr(kCrops)); diff --git a/mindspore/core/ops/batch_to_space_nd.cc b/mindspore/core/ops/batch_to_space_nd.cc index 06a7000ef46..73d4e5ceb1d 100644 --- a/mindspore/core/ops/batch_to_space_nd.cc +++ b/mindspore/core/ops/batch_to_space_nd.cc @@ -29,7 +29,7 @@ namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name); auto out_shape = x_shape; int64_t block_shape_prod = 1; diff --git a/mindspore/core/ops/bias_add.cc b/mindspore/core/ops/bias_add.cc index d1bb713b274..bcb1a802481 100644 --- a/mindspore/core/ops/bias_add.cc +++ b/mindspore/core/ops/bias_add.cc @@ -30,8 +30,8 @@ abstract::ShapePtr BiasAddInferShape(const PrimitivePtr &primitive, 
const std::v auto prim_name = primitive->name(); // check CheckAndConvertUtils::CheckInteger("arg size", input_args.size(), kEqual, 2, prim_name); - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); - auto b_shape = CheckAndConvertUtils::ConvertShapePtrToShape("b_shape", input_args[1]->BuildShape(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + auto b_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kGreaterEqual, 2, prim_name); CheckAndConvertUtils::CheckInteger("bias rank", b_shape.size(), kEqual, 1, prim_name); auto format = Format(GetValue(primitive->GetAttr(kFormat))); diff --git a/mindspore/core/ops/binary_cross_entropy.cc b/mindspore/core/ops/binary_cross_entropy.cc index 4eef0762a55..91926b06633 100644 --- a/mindspore/core/ops/binary_cross_entropy.cc +++ b/mindspore/core/ops/binary_cross_entropy.cc @@ -34,10 +34,9 @@ abstract::ShapePtr BinaryCrossEntroyInferShape(const PrimitivePtr &primitive, MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInRange("binary_cross_entropy_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); - auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name); - auto weight_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("weight_shape", input_args[2]->BuildShape(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; + auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name); std::vector infer_shape; if (weight_shape.size() < 1) { diff --git a/mindspore/core/ops/broadcast.cc b/mindspore/core/ops/broadcast.cc index 67d6539015b..e40d33df12a 100644 --- a/mindspore/core/ops/broadcast.cc +++ b/mindspore/core/ops/broadcast.cc @@ -50,7 +50,7 @@ AbstractBasePtr BroadcastInfer(const abstract::AnalysisEnginePtr &, const Primit MS_EXCEPTION_IF_NULL(item); } // infer shape - auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; // infer type auto x_type = input_args[0]->BuildType()->cast()->element(); std::vector output_types; diff --git a/mindspore/core/ops/broadcast_to.cc b/mindspore/core/ops/broadcast_to.cc index 44c01bd5235..03d71a7ac3a 100644 --- a/mindspore/core/ops/broadcast_to.cc +++ b/mindspore/core/ops/broadcast_to.cc @@ -24,7 +24,7 @@ abstract::ShapePtr BroadcastToInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto value_ptr = primitive->GetAttr(kShape); auto input_x = GetValue>(value_ptr); int64_t outer_dim_offset = input_x.size() - 
x_shape.size(); diff --git a/mindspore/core/ops/ceil.cc b/mindspore/core/ops/ceil.cc index e9262b7698b..a55077e3f16 100644 --- a/mindspore/core/ops/ceil.cc +++ b/mindspore/core/ops/ceil.cc @@ -31,7 +31,7 @@ AbstractBasePtr CeilInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Ceil"); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; const std::set valid_types = {kFloat16, kFloat32}; auto infer_type = input_args[0]->BuildType(); auto data_type = CheckAndConvertUtils::CheckTensorTypeValid("x type", infer_type, valid_types, primitive->name()); diff --git a/mindspore/core/ops/concat.cc b/mindspore/core/ops/concat.cc index e74d8fdfcf8..fd97cba7a60 100644 --- a/mindspore/core/ops/concat.cc +++ b/mindspore/core/ops/concat.cc @@ -43,8 +43,7 @@ AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const Primitive CheckAndConvertUtils::CheckInteger("concat element num", elements.size(), kGreaterEqual, 1, prim_name); auto element0 = elements[0]->cast(); MS_EXCEPTION_IF_NULL(element0); - auto element0_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("element0 shape", element0->BuildShape(), prim_name); + auto element0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kShape]; auto element0_rank = SizeToLong(element0_shape.size()); auto axis = GetValue(primitive->GetAttr(kAxis)); CheckAndConvertUtils::CheckInRange("Concat axis", axis, kIncludeBoth, {-element0_rank - 1, element0_rank}, @@ -56,8 +55,7 @@ AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const Primitive int64_t all_shp = element0_shape[axis]; for (size_t i = 1; i < elements.size(); ++i) { std::string elementi = "element" + std::to_string(i); - auto elementi_shape = - CheckAndConvertUtils::ConvertShapePtrToShape(elementi + " shape", elements[i]->BuildShape(), prim_name); + auto elementi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kShape]; CheckAndConvertUtils::CheckInteger(elementi + " shape rank", elementi_shape.size(), kEqual, element0_shape.size(), prim_name); for (int64_t j = 0; j < element0_rank; ++j) { diff --git a/mindspore/core/ops/constant_of_shape.cc b/mindspore/core/ops/constant_of_shape.cc index fb3d711fa77..8458f1ce619 100644 --- a/mindspore/core/ops/constant_of_shape.cc +++ b/mindspore/core/ops/constant_of_shape.cc @@ -24,8 +24,7 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { CheckAndConvertUtils::CheckInteger("input args size", input_args.size(), kEqual, 1, "ConstantOfShape"); - auto input_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "ConstantOfShape"); + auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; return std::make_shared(input_shape); } diff --git a/mindspore/core/ops/conv2d.cc b/mindspore/core/ops/conv2d.cc index 14a22629dff..ddac0efd882 100644 --- a/mindspore/core/ops/conv2d.cc +++ b/mindspore/core/ops/conv2d.cc @@ -79,8 +79,8 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInRange("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); - auto x_shape = 
CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); - auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; auto format = Format(GetValue(primitive->GetAttr(kFormat))); if (format == NHWC) { x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; diff --git a/mindspore/core/ops/conv2d_transpose.cc b/mindspore/core/ops/conv2d_transpose.cc index 8463fcc0d05..42ab771e543 100644 --- a/mindspore/core/ops/conv2d_transpose.cc +++ b/mindspore/core/ops/conv2d_transpose.cc @@ -28,8 +28,7 @@ namespace { abstract::ShapePtr Conv2dTransposeInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto prim_name = primitive->name(); - auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[3]->BuildShape(), prim_name); + auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape]; return std::make_shared(input_shape); } diff --git a/mindspore/core/ops/cos.cc b/mindspore/core/ops/cos.cc index 0d77b214fc2..700580ab40f 100644 --- a/mindspore/core/ops/cos.cc +++ b/mindspore/core/ops/cos.cc @@ -24,11 +24,10 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto prim_name = primitive->name(); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } - auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; return std::make_shared(in_shape); } diff --git a/mindspore/core/ops/crop.cc b/mindspore/core/ops/crop.cc index f96a2705c75..8d811acc88e 100644 --- a/mindspore/core/ops/crop.cc +++ b/mindspore/core/ops/crop.cc @@ -49,7 +49,7 @@ AbstractBasePtr CropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt MS_EXCEPTION_IF_NULL(item); } // infer shape - auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[1]->BuildShape(), prim_name); + auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; // infer type auto x_type = input_args[0]->BuildType()->cast()->element(); return std::make_shared(x_type, out_shape); diff --git a/mindspore/core/ops/custom_extract_features.cc b/mindspore/core/ops/custom_extract_features.cc index 007b4ca6c8f..43e88ea5bdc 100644 --- a/mindspore/core/ops/custom_extract_features.cc +++ b/mindspore/core/ops/custom_extract_features.cc @@ -24,18 +24,14 @@ namespace ops { AbstractBasePtr CustomExtractFeaturesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto prim_name = primitive->name(); MS_EXCEPTION_IF_NULL(input_args[0]); - // auto input = input_args[0]; - // Infer type auto output0_type = kInt32; auto output1_type = kFloat32; // Infer shape std::vector out_shape; - auto input_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); + auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto string_num = 
input_shape[0]; if (string_num == 0) { out_shape.push_back(1); diff --git a/mindspore/core/ops/depth_to_space.cc b/mindspore/core/ops/depth_to_space.cc index 229df9c394e..d4449a3c710 100644 --- a/mindspore/core/ops/depth_to_space.cc +++ b/mindspore/core/ops/depth_to_space.cc @@ -54,7 +54,7 @@ AbstractBasePtr DepthToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri auto input_x = input_args[0]->cast(); MS_EXCEPTION_IF_NULL(input_x); - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto format = Format(GetValue(primitive->GetAttr(kFormat))); if (format == NHWC) { x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; diff --git a/mindspore/core/ops/depthwise_conv2d.cc b/mindspore/core/ops/depthwise_conv2d.cc index cf600fbb99d..05fd02c5fa0 100644 --- a/mindspore/core/ops/depthwise_conv2d.cc +++ b/mindspore/core/ops/depthwise_conv2d.cc @@ -119,8 +119,8 @@ abstract::ShapePtr DepthWiseConv2DInferShape(const PrimitivePtr &primitive, MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInRange("conv2d_Infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name); - auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->GetShapeTrack(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; + auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack())[kShape]; auto format = Format(GetValue(primitive->GetAttr(kFormat))); if (format == NHWC) { x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; diff --git a/mindspore/core/ops/detection_post_process.cc b/mindspore/core/ops/detection_post_process.cc index 37cb6669873..295700e21ce 100644 --- a/mindspore/core/ops/detection_post_process.cc +++ b/mindspore/core/ops/detection_post_process.cc @@ -120,9 +120,9 @@ AbstractBasePtr DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &, c auto boxes = input_args[0]; auto scores = input_args[1]; auto anchors = input_args[2]; - auto boxes_shape = CheckAndConvertUtils::ConvertShapePtrToShape("boxes_shape", boxes->BuildShape(), prim_name); - auto scores_shape = CheckAndConvertUtils::ConvertShapePtrToShape("scores_shape", scores->BuildShape(), prim_name); - auto anchors_shape = CheckAndConvertUtils::ConvertShapePtrToShape("anchors_shape", anchors->BuildShape(), prim_name); + auto boxes_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(boxes->BuildShape())[kShape]; + auto scores_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(scores->BuildShape())[kShape]; + auto anchors_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(anchors->BuildShape())[kShape]; auto format = Format(GetValue(primitive->GetAttr(kFormat))); if (format == NHWC) { boxes_shape = {boxes_shape[0], boxes_shape[3], boxes_shape[1], boxes_shape[2]}; diff --git a/mindspore/core/ops/dropout.cc b/mindspore/core/ops/dropout.cc index f6671e687a5..c4c433031b3 100644 --- a/mindspore/core/ops/dropout.cc +++ b/mindspore/core/ops/dropout.cc @@ -43,7 +43,7 @@ AbstractBasePtr DropoutInfer(const abstract::AnalysisEnginePtr &, const Primitiv CheckAndConvertUtils::CheckInteger("dropout_infer", input_args.size(), kEqual, 1, prim_name); // Infer shape - auto x_shape = 
CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; CheckAndConvertUtils::CheckInteger("x_shape", x_shape.size(), kGreaterEqual, 1, prim_name); std::vector out_shape; out_shape.insert(out_shape.end(), x_shape.begin(), x_shape.end()); diff --git a/mindspore/core/ops/elu.cc b/mindspore/core/ops/elu.cc index 59a3c3ad878..888cbb5854e 100644 --- a/mindspore/core/ops/elu.cc +++ b/mindspore/core/ops/elu.cc @@ -31,11 +31,10 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto op_name = primitive->name(); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } - auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), op_name); + auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; return std::make_shared(in_shape); } diff --git a/mindspore/core/ops/expand_dims.cc b/mindspore/core/ops/expand_dims.cc index 8e1999ea200..f347c9e7d8f 100644 --- a/mindspore/core/ops/expand_dims.cc +++ b/mindspore/core/ops/expand_dims.cc @@ -36,7 +36,7 @@ AbstractBasePtr ExpandDimsInfer(const abstract::AnalysisEnginePtr &, const Primi MS_EXCEPTION_IF_NULL(item); } // Infer shape - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto dim_val = GetValue(input_args[1]->BuildValue()); auto rank = x_shape.size(); CheckAndConvertUtils::CheckInRange("axis", dim_val, kIncludeBoth, {-rank - 1, rank}, prim_name); diff --git a/mindspore/core/ops/fake_quant_with_min_max_vars.cc b/mindspore/core/ops/fake_quant_with_min_max_vars.cc index 495a185f7c8..2214c9b8d06 100644 --- a/mindspore/core/ops/fake_quant_with_min_max_vars.cc +++ b/mindspore/core/ops/fake_quant_with_min_max_vars.cc @@ -29,9 +29,9 @@ namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); - auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShape("min_shape", input_args[1]->BuildShape(), prim_name); - auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShape("max_shape", input_args[2]->BuildShape(), prim_name); + auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; + auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kGreaterEqual, 1, prim_name); CheckAndConvertUtils::Check("min_shape", min_shape, kEqual, "max_shape", max_shape, prim_name); CheckAndConvertUtils::CheckInteger("min_shape", min_shape.size(), kEqual, 1, prim_name); diff --git a/mindspore/core/ops/fake_quant_with_min_max_vars_per_channel.cc b/mindspore/core/ops/fake_quant_with_min_max_vars_per_channel.cc index da25c340927..e7d8666c216 100644 --- a/mindspore/core/ops/fake_quant_with_min_max_vars_per_channel.cc +++ 
b/mindspore/core/ops/fake_quant_with_min_max_vars_per_channel.cc @@ -44,9 +44,9 @@ AbstractBasePtr FakeQuantWithMinMaxVarsPerChannelInfer(const abstract::AnalysisE const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto op_name = primitive->name(); - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name); - auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShape("min_shape", input_args[1]->BuildShape(), op_name); - auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShape("max_shape", input_args[2]->BuildShape(), op_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; + auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; CheckAndConvertUtils::CheckInteger("x rank", (int64_t)x_shape.size(), kGreaterThan, 1, op_name); CheckAndConvertUtils::Check("min shape", min_shape, kEqual, "max shape", max_shape, op_name); CheckAndConvertUtils::CheckInteger("min shape", (int64_t)min_shape.size(), kEqual, 1, op_name); diff --git a/mindspore/core/ops/fft_imag.cc b/mindspore/core/ops/fft_imag.cc index 762dd459145..250be1312be 100644 --- a/mindspore/core/ops/fft_imag.cc +++ b/mindspore/core/ops/fft_imag.cc @@ -24,8 +24,7 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto prim_name = primitive->name(); - auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("in_shape", input_args[0]->BuildShape(), prim_name); + auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; in_shape.pop_back(); return std::make_shared(in_shape); } diff --git a/mindspore/core/ops/fft_real.cc b/mindspore/core/ops/fft_real.cc index bd6998b2f46..294059c12e6 100644 --- a/mindspore/core/ops/fft_real.cc +++ b/mindspore/core/ops/fft_real.cc @@ -33,7 +33,7 @@ AbstractBasePtr FftRealInfer(const abstract::AnalysisEnginePtr &, const Primitiv MS_EXCEPTION_IF_NULL(item); } auto out_dtype = kFloat32; - auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; out_shape.pop_back(); return std::make_shared(out_dtype, std::make_shared(out_shape)); } diff --git a/mindspore/core/ops/flatten.cc b/mindspore/core/ops/flatten.cc index c18f39f8dda..a6c421e7e76 100644 --- a/mindspore/core/ops/flatten.cc +++ b/mindspore/core/ops/flatten.cc @@ -25,7 +25,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vectorname(); CheckAndConvertUtils::CheckInteger("input args size", input_args.size(), kGreaterEqual, 1, prim_name); - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto prod = 1; int64_t size = x_shape.size(); for (int64_t i = 1; i < size; i++) { diff --git a/mindspore/core/ops/floor.cc b/mindspore/core/ops/floor.cc index 52ac1e0b3d3..3608e3e9822 100644 --- a/mindspore/core/ops/floor.cc +++ b/mindspore/core/ops/floor.cc @@ -28,11 +28,10 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { 
MS_EXCEPTION_IF_NULL(primitive); - auto prim_name = primitive->name(); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } - auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name); + auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; return std::make_shared(in_shape); } diff --git a/mindspore/core/ops/fusion/avg_pool_fusion.cc b/mindspore/core/ops/fusion/avg_pool_fusion.cc index 2b333a25770..74eb4143eff 100644 --- a/mindspore/core/ops/fusion/avg_pool_fusion.cc +++ b/mindspore/core/ops/fusion/avg_pool_fusion.cc @@ -53,7 +53,7 @@ namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto op_name = primitive->name(); - auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name); + auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; auto format = Format(GetValue(primitive->GetAttr(kFormat))); if (format == NHWC) { in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; diff --git a/mindspore/core/ops/fusion/full_connection.cc b/mindspore/core/ops/fusion/full_connection.cc index 602d26e7af7..651a10a2419 100644 --- a/mindspore/core/ops/fusion/full_connection.cc +++ b/mindspore/core/ops/fusion/full_connection.cc @@ -53,8 +53,8 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P MS_EXCEPTION_IF_NULL(input_args[1]); auto input0 = input_args[0]; auto input1 = input_args[1]; - auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input0_shape", input0->BuildShape(), prim_name); - auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input1_shape", input1->BuildShape(), prim_name); + auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input0->BuildShape())[kShape]; + auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input1->BuildShape())[kShape]; auto prim_axis = GetValue(primitive->GetAttr(kAxis)); auto has_bias = GetValue(primitive->GetAttr(kHasBias)); if (has_bias) { @@ -78,8 +78,7 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P new_k = input1_shape[1]; } if (has_bias) { - auto input2_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("input2_shape", input_args[2]->BuildShape(), prim_name); + auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; if (input2_shape[0] != input1_shape[0]) { MS_EXCEPTION(ValueError) << "Bias size invalid"; } diff --git a/mindspore/core/ops/fusion/max_pool_fusion.cc b/mindspore/core/ops/fusion/max_pool_fusion.cc index c0bdc079a75..d928d75ba91 100644 --- a/mindspore/core/ops/fusion/max_pool_fusion.cc +++ b/mindspore/core/ops/fusion/max_pool_fusion.cc @@ -53,7 +53,7 @@ namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto op_name = primitive->name(); - auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name); + auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; auto format = Format(GetValue(primitive->GetAttr(kFormat))); if (format == NHWC) { in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; diff --git a/mindspore/core/ops/fusion/slice_fusion.cc 
b/mindspore/core/ops/fusion/slice_fusion.cc index dd1af5c8120..879a4f7e22f 100644 --- a/mindspore/core/ops/fusion/slice_fusion.cc +++ b/mindspore/core/ops/fusion/slice_fusion.cc @@ -33,7 +33,7 @@ AbstractBasePtr SliceFusionInfer(const abstract::AnalysisEnginePtr &, const Prim const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto op_name = primitive->name(); - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto x_shape_len = (int64_t)x_shape.size(); auto begin_v = input_args[1]->BuildValue(); auto size_v = input_args[2]->BuildValue(); diff --git a/mindspore/core/ops/gather_d.cc b/mindspore/core/ops/gather_d.cc index 4d14e3d66ed..0274cc8eef5 100644 --- a/mindspore/core/ops/gather_d.cc +++ b/mindspore/core/ops/gather_d.cc @@ -29,8 +29,8 @@ abstract::ShapePtr GatherDInferShape(const PrimitivePtr &primitive, const std::v MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); // check - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); - auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShape("dim_shape", input_args[2]->BuildShape(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; int64_t x_rank = x_shape.size(); CheckAndConvertUtils::Check("x_rank", x_rank, kEqual, "index_rank", index_shape.size(), prim_name); auto dim_v = GetValue(input_args[1]->BuildValue()); diff --git a/mindspore/core/ops/gather_nd.cc b/mindspore/core/ops/gather_nd.cc index 53bd1fa51bd..374e4fed522 100644 --- a/mindspore/core/ops/gather_nd.cc +++ b/mindspore/core/ops/gather_nd.cc @@ -32,9 +32,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vectorBuildShape(), prim_name); - auto indices_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("indices_shape", input_args[1]->BuildShape(), prim_name); + auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; auto input_rank = input_shape.size(); auto indices_rank = indices_shape.size(); CheckAndConvertUtils::CheckInteger("Input of indices data", input_rank, kGreaterEqual, diff --git a/mindspore/core/ops/gelu.cc b/mindspore/core/ops/gelu.cc index f48a1f32526..8f3282b4c8d 100644 --- a/mindspore/core/ops/gelu.cc +++ b/mindspore/core/ops/gelu.cc @@ -28,8 +28,7 @@ namespace ops { namespace { abstract::ShapePtr GeLUInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto prim_name = primitive->name(); - auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_x", input_args[0]->BuildShape(), prim_name); + auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; return std::make_shared(input_shape); } diff --git a/mindspore/core/ops/grad/batch_norm_grad.cc b/mindspore/core/ops/grad/batch_norm_grad.cc index 407dca7db1f..b2db2ef5bba 100644 --- a/mindspore/core/ops/grad/batch_norm_grad.cc +++ b/mindspore/core/ops/grad/batch_norm_grad.cc @@ -47,13 +47,11 @@ bool BatchNormGrad::get_is_training() const { AbstractBasePtr BatchNormGradInfer(const 
abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto op_name = primitive->name(); MS_EXCEPTION_IF_NULL(input_args[1]); MS_EXCEPTION_IF_NULL(input_args[2]); MS_EXCEPTION_IF_NULL(input_args[3]); - auto y_backprop_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("y_backprop_shape", input_args[0]->BuildShape(), op_name); - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[1]->BuildShape(), op_name); + auto y_backprop_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; CheckAndConvertUtils::Check("BatchNorm y_backprop_shape", y_backprop_shape, kEqual, "BatchNorm x_shape", x_shape); auto dx = input_args[1]->Broaden(); diff --git a/mindspore/core/ops/grad/bias_add_grad.cc b/mindspore/core/ops/grad/bias_add_grad.cc index 0612806c9fe..9fec587f180 100644 --- a/mindspore/core/ops/grad/bias_add_grad.cc +++ b/mindspore/core/ops/grad/bias_add_grad.cc @@ -46,7 +46,7 @@ AbstractBasePtr BiasAddGradInfer(const abstract::AnalysisEnginePtr &, const Prim MS_EXCEPTION_IF_NULL(input_args[0]); // Infer shape - auto inshape = CheckAndConvertUtils::ConvertShapePtrToShape("inshape", input_args[0]->BuildShape(), prim_name); + auto inshape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; for (size_t i = 0; i < inshape.size() - 1; i++) { inshape[i] = 1; } diff --git a/mindspore/core/ops/grad/binary_cross_entropy_grad.cc b/mindspore/core/ops/grad/binary_cross_entropy_grad.cc index a3d17aa74f3..5ff9c8b95d9 100644 --- a/mindspore/core/ops/grad/binary_cross_entropy_grad.cc +++ b/mindspore/core/ops/grad/binary_cross_entropy_grad.cc @@ -27,10 +27,9 @@ abstract::ShapePtr BinaryCrossEntroyGradInferShape(const PrimitivePtr &primitive const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); - auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name); - auto weight_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("weight_shape", input_args[2]->BuildShape(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; + auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name); if (weight_shape.size() < 1) { CheckAndConvertUtils::Check("y shape", y_shape, kEqual, "weight shape", weight_shape, prim_name); diff --git a/mindspore/core/ops/grad/dropout_grad.cc b/mindspore/core/ops/grad/dropout_grad.cc index 03fadb08373..df2307d7ecc 100644 --- a/mindspore/core/ops/grad/dropout_grad.cc +++ b/mindspore/core/ops/grad/dropout_grad.cc @@ -35,8 +35,7 @@ namespace { abstract::ShapePtr DropoutGradInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto op_name = primitive->name(); - auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); + auto in_shape = 
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; return std::make_shared(in_shape); } diff --git a/mindspore/core/ops/grad/max_pool_grad.cc b/mindspore/core/ops/grad/max_pool_grad.cc index 64120ab95ff..42df7b25f5d 100644 --- a/mindspore/core/ops/grad/max_pool_grad.cc +++ b/mindspore/core/ops/grad/max_pool_grad.cc @@ -21,10 +21,8 @@ namespace mindspore { namespace ops { AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(primitive); - auto op_name = primitive->name(); MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue()); - auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x1_shape", input_args[0]->BuildShape(), op_name); + auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto tensor_type = input_args[0]->BuildType()->cast(); MS_EXCEPTION_IF_NULL(tensor_type); auto element = tensor_type->element(); diff --git a/mindspore/core/ops/grad/sigmoid_cross_entropy_with_logits_grad.cc b/mindspore/core/ops/grad/sigmoid_cross_entropy_with_logits_grad.cc index c08730a285a..a85e8c68d7a 100644 --- a/mindspore/core/ops/grad/sigmoid_cross_entropy_with_logits_grad.cc +++ b/mindspore/core/ops/grad/sigmoid_cross_entropy_with_logits_grad.cc @@ -35,9 +35,9 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsGradInfer(const abstract::AnalysisE prim_name); // Infer Shape - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); - auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name); - auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShape("dout_shape", input_args[2]->BuildShape(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; + auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "y_shape", y_shape, prim_name, TypeError); CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "dout_shape", dout_shape, prim_name, TypeError); diff --git a/mindspore/core/ops/grad/smooth_l1_loss_grad.cc b/mindspore/core/ops/grad/smooth_l1_loss_grad.cc index 59215e60417..bed8506b7ff 100644 --- a/mindspore/core/ops/grad/smooth_l1_loss_grad.cc +++ b/mindspore/core/ops/grad/smooth_l1_loss_grad.cc @@ -40,9 +40,9 @@ AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const CheckAndConvertUtils::CheckInteger("smooth_l1_loss_grad_infer", input_args.size(), kEqual, 3, prim_name); // Infer shape - auto prediction = CheckAndConvertUtils::ConvertShapePtrToShape("prediction", input_args[0]->BuildShape(), prim_name); - auto target = CheckAndConvertUtils::ConvertShapePtrToShape("target", input_args[1]->BuildShape(), prim_name); - auto dloss = CheckAndConvertUtils::ConvertShapePtrToShape("dloss", input_args[2]->BuildShape(), prim_name); + auto prediction = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + auto target = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; + auto dloss = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "target shape", target, 
prim_name, TypeError); CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "dloss", dloss, prim_name, TypeError); diff --git a/mindspore/core/ops/hashtable_lookup.cc b/mindspore/core/ops/hashtable_lookup.cc index 1f039b58147..883067296f5 100644 --- a/mindspore/core/ops/hashtable_lookup.cc +++ b/mindspore/core/ops/hashtable_lookup.cc @@ -27,9 +27,8 @@ AbstractBasePtr HashtableLookupInfer(const abstract::AnalysisEnginePtr &, const for (auto input : input_args) { MS_EXCEPTION_IF_NULL(input); } - auto op_name = primitive->name(); std::vector hits_shape; - auto input = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); + auto input = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; hits_shape.push_back(input[0]); auto value_type = input_args[2]->BuildType(); diff --git a/mindspore/core/ops/l2_normalize.cc b/mindspore/core/ops/l2_normalize.cc index b5e7ef0fe7c..86cd5e9405b 100644 --- a/mindspore/core/ops/l2_normalize.cc +++ b/mindspore/core/ops/l2_normalize.cc @@ -46,7 +46,7 @@ AbstractBasePtr L2NormalizeInfer(const abstract::AnalysisEnginePtr &, const Prim } const std::set valid_types = {kFloat16, kFloat32}; (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name); - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto x_rank = SizeToLong(x_shape.size()); auto axiss = GetValue>(primitive->GetAttr(kAxis)); for (auto &axis : axiss) { diff --git a/mindspore/core/ops/log.cc b/mindspore/core/ops/log.cc index f64804c2d0a..8916394f2e9 100644 --- a/mindspore/core/ops/log.cc +++ b/mindspore/core/ops/log.cc @@ -24,7 +24,7 @@ namespace mindspore { namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Log"); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; return std::make_shared(x_shape); } diff --git a/mindspore/core/ops/logical_not.cc b/mindspore/core/ops/logical_not.cc index b21c6175d98..b64e262e275 100644 --- a/mindspore/core/ops/logical_not.cc +++ b/mindspore/core/ops/logical_not.cc @@ -24,8 +24,7 @@ namespace ops { namespace { abstract::ShapePtr LogicalNotInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto op_name = primitive->name(); - auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); + auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; return std::make_shared(in_shape); } diff --git a/mindspore/core/ops/lrn.cc b/mindspore/core/ops/lrn.cc index 4201bc865dc..f1d1e1a3db1 100644 --- a/mindspore/core/ops/lrn.cc +++ b/mindspore/core/ops/lrn.cc @@ -78,7 +78,7 @@ namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; CheckAndConvertUtils::CheckInteger("input 
shape", in_shape.size(), kEqual, 4, prim_name); return std::make_shared(in_shape); } diff --git a/mindspore/core/ops/lsh_projection.cc b/mindspore/core/ops/lsh_projection.cc index 9125f36dac4..2bdfb5ad881 100644 --- a/mindspore/core/ops/lsh_projection.cc +++ b/mindspore/core/ops/lsh_projection.cc @@ -32,14 +32,14 @@ AbstractBasePtr LshProjectionInfer(const abstract::AnalysisEnginePtr &, const Pr const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto op_name = primitive->name(); - auto input0 = CheckAndConvertUtils::ConvertShapePtrToShape("input0_shape", input_args[0]->BuildShape(), op_name); - auto input1 = CheckAndConvertUtils::ConvertShapePtrToShape("input1_shape", input_args[1]->BuildShape(), op_name); + auto input0 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + auto input1 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; CheckAndConvertUtils::CheckInteger("input0_shape", input0.size(), kEqual, 2, op_name); CheckAndConvertUtils::CheckInteger("input0_shape_dimen_1", input0[1], kLessEqual, 32, op_name); CheckAndConvertUtils::CheckInteger("input1_shape", input1.size(), kGreaterEqual, 1, op_name); if (input_args.size() == 3) { - auto input2 = CheckAndConvertUtils::ConvertShapePtrToShape("input2_shape", input_args[2]->BuildShape(), op_name); + auto input2 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; CheckAndConvertUtils::CheckInteger("input2_shape", input2.size(), kEqual, 1, op_name); CheckAndConvertUtils::CheckInteger("input2_shape_dimen_0", input2[0], kEqual, input1[0], op_name); } diff --git a/mindspore/core/ops/lstm.cc b/mindspore/core/ops/lstm.cc index 6b1985da9e9..2f903efe696 100644 --- a/mindspore/core/ops/lstm.cc +++ b/mindspore/core/ops/lstm.cc @@ -32,9 +32,9 @@ AbstractBasePtr LstmInfer(const PrimitivePtr &primitive, const std::vectorname(); CheckAndConvertUtils::CheckInteger("lstm_prim_infer", input_args.size(), kEqual, 4, prim_name); - auto x_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); - auto h_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("h_shape", input_args[1]->BuildShape(), prim_name); - auto c_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("c_shape", input_args[2]->BuildShape(), prim_name); + auto x_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + auto h_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; + auto c_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; int64_t input_x_size = GetValue(primitive->GetAttr(kInput_size)); CheckAndConvertUtils::CheckInteger("x_shape.size()", x_input_shape.size(), kEqual, 3, prim_name); diff --git a/mindspore/core/ops/mat_mul.cc b/mindspore/core/ops/mat_mul.cc index 53749eb77e5..17f29738e27 100644 --- a/mindspore/core/ops/mat_mul.cc +++ b/mindspore/core/ops/mat_mul.cc @@ -26,8 +26,8 @@ abstract::ShapePtr MatMulInferShape(const PrimitivePtr &primitive, const std::ve MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("matmul_infer_input", input_args.size(), kEqual, 2, prim_name); - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); - auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name); + auto 
x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; auto trans_a = GetValue(primitive->GetAttr(kTransposeA)); auto trans_b = GetValue(primitive->GetAttr(kTransposeB)); diff --git a/mindspore/core/ops/matrix_diag.cc b/mindspore/core/ops/matrix_diag.cc index c8ed95fd2ee..e4cf74002af 100644 --- a/mindspore/core/ops/matrix_diag.cc +++ b/mindspore/core/ops/matrix_diag.cc @@ -30,9 +30,8 @@ namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); - auto assist_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("assist_shape", input_args[1]->BuildShape(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + auto assist_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; CheckAndConvertUtils::CheckInteger("assist rank", (int64_t)assist_shape.size(), kGreaterEqual, 2, prim_name); CheckAndConvertUtils::Check("x_shape rank", (int64_t)x_shape.size() + 1, kLessEqual, "assist rank", diff --git a/mindspore/core/ops/max_pool.cc b/mindspore/core/ops/max_pool.cc index 22201079b05..e9da8d3c49b 100644 --- a/mindspore/core/ops/max_pool.cc +++ b/mindspore/core/ops/max_pool.cc @@ -82,7 +82,7 @@ namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto op_name = primitive->name(); - auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name); + auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; auto format = Format(GetValue(primitive->GetAttr(kFormat))); if (format == NHWC) { in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; diff --git a/mindspore/core/ops/mfcc.cc b/mindspore/core/ops/mfcc.cc index c52bec1a105..40658c9b905 100644 --- a/mindspore/core/ops/mfcc.cc +++ b/mindspore/core/ops/mfcc.cc @@ -25,10 +25,8 @@ namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - auto first_input_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("first_input_shape", input_args[0]->BuildShape(), prim_name); - auto second_input_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("second_input_shape", input_args[1]->BuildShape(), prim_name); + auto first_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + auto second_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; CheckAndConvertUtils::CheckInteger("first input rank", first_input_shape.size(), kEqual, 3, prim_name); CheckAndConvertUtils::CheckInteger("second input rank", second_input_shape.size(), kEqual, 1, prim_name); std::vector out_shape = {first_input_shape[0], first_input_shape[1], diff --git a/mindspore/core/ops/one_hot.cc b/mindspore/core/ops/one_hot.cc index 824b0810a14..f559dce9698 100644 --- a/mindspore/core/ops/one_hot.cc +++ b/mindspore/core/ops/one_hot.cc @@ -31,7 +31,7 @@ abstract::ShapePtr OneHotInferShape(const PrimitivePtr &primitive, const std::ve 
MS_EXCEPTION_IF_NULL(primitive); auto op_name = primitive->name(); int64_t axis = GetValue(primitive->GetAttr(kAxis)); - auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); + auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; CheckAndConvertUtils::CheckInRange("axis", axis, kIncludeBoth, {-1, SizeToLong(in_shape.size())}, op_name); auto depth_val = GetValue(input_args[1]->BuildValue()); CheckAndConvertUtils::CheckInteger("depth", depth_val, kGreaterEqual, 0, op_name); diff --git a/mindspore/core/ops/ones_like.cc b/mindspore/core/ops/ones_like.cc index 1a02b5e6a16..1aaf4fb2773 100644 --- a/mindspore/core/ops/ones_like.cc +++ b/mindspore/core/ops/ones_like.cc @@ -28,9 +28,7 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto prim_name = primitive->name(); - auto input_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); + auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; return std::make_shared(input_shape); } diff --git a/mindspore/core/ops/op_utils.cc b/mindspore/core/ops/op_utils.cc index 3ca50dd84ec..64e61e0a8b5 100644 --- a/mindspore/core/ops/op_utils.cc +++ b/mindspore/core/ops/op_utils.cc @@ -27,8 +27,8 @@ namespace mindspore { namespace ops { abstract::ShapePtr BroadCastInferShape(const std::string &op_name, const std::vector &input_args) { MS_LOG(INFO) << "Do infer shape for op " << op_name; - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name); - auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->GetShapeTrack(), op_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; + auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack())[kShape]; if (x_shape == y_shape) { return std::make_shared(x_shape); } diff --git a/mindspore/core/ops/pack.cc b/mindspore/core/ops/pack.cc index ec31ed0dd7f..ac52ad775fb 100644 --- a/mindspore/core/ops/pack.cc +++ b/mindspore/core/ops/pack.cc @@ -23,7 +23,7 @@ std::vector _get_pack_shape(std::vector x_shapes, std::ve std::string name) { CheckAndConvertUtils::CheckInteger("len of input_x", (int64_t)x_shapes.size(), kGreaterEqual, 1, name); CheckAndConvertUtils::CheckSubClass("input_x[0]", x_types[0], {TypeIdToType(kObjectTypeTensorType)}, name); - auto output_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape[0]", x_shapes[0], name); + auto output_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shapes[0])[kShape]; int64_t rank_base = output_shape.size(); int64_t N = x_shapes.size(); // CheckAndConvertUtils::CheckInRange("axis", axis, kIncludeBoth, {-rank_base-1, rank_base}, name); @@ -37,7 +37,7 @@ std::vector _get_pack_shape(std::vector x_shapes, std::ve MS_EXCEPTION_IF_NULL(type0); CheckAndConvertUtils::Check("x_type[" + std::to_string(i) + "]", type->type_id(), kEqual, "base", type0->type_id(), name); - auto shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape" + std::to_string(i), x_shapes[i], name); + auto shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shapes[i])[kShape]; if (shape != output_shape) { MS_EXCEPTION(ValueError) << "For '" + name + "' element " + std::to_string(i) + "shape in 
input can't pack with first element."; diff --git a/mindspore/core/ops/pad.cc b/mindspore/core/ops/pad.cc index 95a5be8fd1a..db15de7c61d 100644 --- a/mindspore/core/ops/pad.cc +++ b/mindspore/core/ops/pad.cc @@ -25,7 +25,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vectorname(); auto paddings_attr = GetValue>>(primitive->GetAttr(kPaddings)); - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Pad"); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; CheckAndConvertUtils::CheckInteger("paddings_size", paddings_attr.size(), kEqual, int64_t(2 * x_shape.size()), prim_name); int64_t size = paddings_attr.size(); diff --git a/mindspore/core/ops/prelu.cc b/mindspore/core/ops/prelu.cc index 9cd8cc98a8b..b6b8b6e5069 100644 --- a/mindspore/core/ops/prelu.cc +++ b/mindspore/core/ops/prelu.cc @@ -25,8 +25,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vectorname(); auto x = input_args[0]->BuildShape(); auto w = input_args[1]->BuildShape(); - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", x, prim_name); - auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", w, prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x)[kShape]; + auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(w)[kShape]; CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kNotEqual, 1, prim_name); CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 1, prim_name); diff --git a/mindspore/core/ops/prior_box.cc b/mindspore/core/ops/prior_box.cc index fad2b279a7d..2ec1f078f1d 100644 --- a/mindspore/core/ops/prior_box.cc +++ b/mindspore/core/ops/prior_box.cc @@ -112,7 +112,6 @@ void PriorBox::Init(const std::vector &min_sizes, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto op_name = primitive->name(); MS_EXCEPTION_IF_NULL(input_args[0]); std::vector different_aspect_ratios{1.0f}; auto aspect_ratios = GetValue>(primitive->GetAttr(kAspectRatios)); @@ -129,7 +128,7 @@ AbstractBasePtr PriorBoxInfer(const abstract::AnalysisEnginePtr &, const Primiti } auto min_sizes = GetValue>(primitive->GetAttr(kMinSizes)); int64_t num_priors_box = min_sizes.size() * different_aspect_ratios.size() + min_sizes.size(); - auto input = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); + auto input = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; int64_t h = input[0] * input[1] * num_priors_box * 4; std::vector output_shape{1, h, 1, 2}; return std::make_shared(kFloat32, output_shape); diff --git a/mindspore/core/ops/quant_dtype_cast.cc b/mindspore/core/ops/quant_dtype_cast.cc index aee60bd0a90..9a0e8bd2414 100644 --- a/mindspore/core/ops/quant_dtype_cast.cc +++ b/mindspore/core/ops/quant_dtype_cast.cc @@ -32,13 +32,12 @@ void QuantDTypeCast::Init(const int64_t src_t, const int64_t dst_t) { AbstractBasePtr QuantDTypeCastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto op_name = primitive->name(); MS_EXCEPTION_IF_NULL(input_args[0]); auto input_type = input_args[0]->BuildType()->cast(); MS_EXCEPTION_IF_NULL(input_type); auto dst_type = GetValue(primitive->GetAttr(kDstT)); MS_ASSERT(input_type->element() == TypeIdToType(TypeId(dst_type))); - auto input_shape = 
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); + auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; return std::make_shared(TypeIdToType(TypeId(dst_type)), input_shape); } REGISTER_PRIMITIVE_C(kNameQuantDTypeCast, QuantDTypeCast); diff --git a/mindspore/core/ops/reciprocal.cc b/mindspore/core/ops/reciprocal.cc index bca8df361c8..7d3bdd996f0 100644 --- a/mindspore/core/ops/reciprocal.cc +++ b/mindspore/core/ops/reciprocal.cc @@ -34,8 +34,7 @@ AbstractBasePtr ReciprocalInfer(const abstract::AnalysisEnginePtr &, const Primi MS_EXCEPTION_IF_NULL(item); } // infer shape - auto in_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), prim_name); + auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; // infer type std::set valid_x_type = {kTensorType}; auto x_type = CheckAndConvertUtils::CheckTypeValid("x_type", input_args[0]->BuildType(), valid_x_type, prim_name); diff --git a/mindspore/core/ops/reduce.cc b/mindspore/core/ops/reduce.cc index e037a21f9bc..7c791009d9c 100644 --- a/mindspore/core/ops/reduce.cc +++ b/mindspore/core/ops/reduce.cc @@ -71,8 +71,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vectorname(); - auto input_x_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("input_x_shape", input_args[0]->BuildShape(), prim_name); + auto input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto keep_dims = GetValue(primitive->GetAttr(kKeepDims)); auto out_shape = infer_shape_reduce(input_x_shape, axis_value, keep_dims, prim_name); diff --git a/mindspore/core/ops/resize_bilinear.cc b/mindspore/core/ops/resize_bilinear.cc index 5aca24e120c..e7868fbd61c 100644 --- a/mindspore/core/ops/resize_bilinear.cc +++ b/mindspore/core/ops/resize_bilinear.cc @@ -49,8 +49,7 @@ AbstractBasePtr ResizeBilinearInfer(const abstract::AnalysisEnginePtr &, const P CheckAndConvertUtils::CheckInteger("resize_bilinear_infer", input_args.size(), kEqual, 1, prim_name); // Infer shape - auto input_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); + auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; CheckAndConvertUtils::CheckInteger("input_shape_rank", input_shape.size(), kEqual, 4, prim_name); std::vector out_shape = {input_shape[0], input_shape[1]}; auto size = GetValue>(primitive->GetAttr(kSize)); diff --git a/mindspore/core/ops/reverse_sequence.cc b/mindspore/core/ops/reverse_sequence.cc index ead5d7c835e..6e70736302b 100644 --- a/mindspore/core/ops/reverse_sequence.cc +++ b/mindspore/core/ops/reverse_sequence.cc @@ -44,10 +44,8 @@ AbstractBasePtr ReverseSequenceInfer(const abstract::AnalysisEnginePtr &, const MS_EXCEPTION_IF_NULL(item); } // infer shape - auto input_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); - auto seq_lengths = - CheckAndConvertUtils::ConvertShapePtrToShape("seq_lengths", input_args[1]->BuildShape(), prim_name); + auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + auto seq_lengths = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; auto seq_dim = GetValue(primitive->GetAttr(kSeqDim)); auto batch_dim = GetValue(primitive->GetAttr(kBatchDim)); 
CheckAndConvertUtils::CheckInteger("seq_dim", seq_dim, kLessEqual, input_shape.size(), prim_name); diff --git a/mindspore/core/ops/reverse_v2.cc b/mindspore/core/ops/reverse_v2.cc index a247172d954..24d3394cbe4 100644 --- a/mindspore/core/ops/reverse_v2.cc +++ b/mindspore/core/ops/reverse_v2.cc @@ -24,8 +24,7 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto prim_name = primitive->name(); - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; return std::make_shared(x_shape); } diff --git a/mindspore/core/ops/rfft.cc b/mindspore/core/ops/rfft.cc index 2e38650c8ae..15059f52e0b 100644 --- a/mindspore/core/ops/rfft.cc +++ b/mindspore/core/ops/rfft.cc @@ -24,9 +24,7 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto prim_name = primitive->name(); - auto first_input_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("first_input_shape", input_args[0]->BuildShape(), prim_name); + auto first_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto out_shape = first_input_shape; out_shape[out_shape.size() - 1] = GetValue(primitive->GetAttr(kFftLength)) / 2 + 1; out_shape.push_back(2); diff --git a/mindspore/core/ops/roi_pooling.cc b/mindspore/core/ops/roi_pooling.cc index d0fa2cc11b4..956b98fea5c 100644 --- a/mindspore/core/ops/roi_pooling.cc +++ b/mindspore/core/ops/roi_pooling.cc @@ -62,9 +62,8 @@ AbstractBasePtr ROIPoolingInfer(const abstract::AnalysisEnginePtr &, const Primi // Infer shape auto new_h = GetValue(primitive->GetAttr(kPooledH)); auto new_w = GetValue(primitive->GetAttr(kPooledW)); - auto input_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); - auto roi_shape = CheckAndConvertUtils::ConvertShapePtrToShape("roi_shape", input_args[1]->BuildShape(), prim_name); + auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + auto roi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; std::vector output_shape; output_shape.push_back(roi_shape[0]); output_shape.push_back(new_h); diff --git a/mindspore/core/ops/round.cc b/mindspore/core/ops/round.cc index 41a5920894b..fb1d345a2e3 100644 --- a/mindspore/core/ops/round.cc +++ b/mindspore/core/ops/round.cc @@ -23,7 +23,7 @@ namespace mindspore { namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "round"); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; return std::make_shared(x_shape); } diff --git a/mindspore/core/ops/rsqrt.cc b/mindspore/core/ops/rsqrt.cc index b9bab657071..8bf7b371602 100644 --- a/mindspore/core/ops/rsqrt.cc +++ b/mindspore/core/ops/rsqrt.cc @@ -30,7 +30,7 @@ namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("in_shape", 
input_args[0]->GetShapeTrack(), prim_name); + auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; CheckAndConvertUtils::CheckInteger("input shape", in_shape.size(), kEqual, 1, prim_name); return std::make_shared(in_shape); } diff --git a/mindspore/core/ops/scalar_summary.cc b/mindspore/core/ops/scalar_summary.cc index 641ae573181..43f14f83a9d 100644 --- a/mindspore/core/ops/scalar_summary.cc +++ b/mindspore/core/ops/scalar_summary.cc @@ -29,7 +29,7 @@ abstract::ShapePtr ScalarSummaryInferShape(const PrimitivePtr &primitive, MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); // check - auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[1]->BuildShape(), prim_name); + auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kLessEqual, 1, prim_name); return std::make_shared(ShapeVector(1)); } diff --git a/mindspore/core/ops/scatter_nd.cc b/mindspore/core/ops/scatter_nd.cc index 3957f4722cf..ca79a6b7054 100644 --- a/mindspore/core/ops/scatter_nd.cc +++ b/mindspore/core/ops/scatter_nd.cc @@ -29,10 +29,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vectorBuildShape(), "ScatterNd"); - auto update_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("update_shape", input_args[1]->BuildShape(), "ScatterNd"); + auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + auto update_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; CheckAndConvertUtils::CheckInteger("indices_shape[0] and update_shape[0]", indices_shape[0], kEqual, update_shape[0], "ScatterNd"); return std::make_shared(shape_value_element); diff --git a/mindspore/core/ops/sigmoid_cross_entropy_with_logits.cc b/mindspore/core/ops/sigmoid_cross_entropy_with_logits.cc index d316d0c5217..b4536b9ed3b 100644 --- a/mindspore/core/ops/sigmoid_cross_entropy_with_logits.cc +++ b/mindspore/core/ops/sigmoid_cross_entropy_with_logits.cc @@ -34,8 +34,8 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsInfer(const abstract::AnalysisEngin prim_name); // Infer shape - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); - auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "y_shape", y_shape, prim_name, TypeError); // Infer type diff --git a/mindspore/core/ops/sin.cc b/mindspore/core/ops/sin.cc index 6bf39ea2e64..ddb05f25369 100644 --- a/mindspore/core/ops/sin.cc +++ b/mindspore/core/ops/sin.cc @@ -31,7 +31,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vectorBuildShape(), "Sin"); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; return std::make_shared(x_shape); } diff --git a/mindspore/core/ops/skip_gram.cc b/mindspore/core/ops/skip_gram.cc index 59c3737cafd..1b5a1c39ef1 100644 --- a/mindspore/core/ops/skip_gram.cc +++ b/mindspore/core/ops/skip_gram.cc @@ -23,7 +23,6 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const 
std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto prim_name = primitive->name(); if (input_args.size() != 1) { MS_LOG(ERROR) << "Skip Gram should have one input"; } @@ -31,7 +30,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vectorBuildShape(), prim_name); + auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; return std::make_shared(in_shape); } diff --git a/mindspore/core/ops/smooth_l1_loss.cc b/mindspore/core/ops/smooth_l1_loss.cc index ee908b7ea88..d6ba52cd9a6 100644 --- a/mindspore/core/ops/smooth_l1_loss.cc +++ b/mindspore/core/ops/smooth_l1_loss.cc @@ -40,8 +40,8 @@ AbstractBasePtr SmoothL1LossInfer(const abstract::AnalysisEnginePtr &, const Pri CheckAndConvertUtils::CheckInteger("smooth_l1_loss_infer", input_args.size(), kEqual, 2, prim_name); // Infer shape - auto prediction = CheckAndConvertUtils::ConvertShapePtrToShape("prediction", input_args[0]->BuildShape(), prim_name); - auto target = CheckAndConvertUtils::ConvertShapePtrToShape("target", input_args[0]->BuildShape(), prim_name); + auto prediction = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + auto target = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "target shape", target, prim_name, TypeError); // Infer type diff --git a/mindspore/core/ops/softmax_cross_entropy_with_logits.cc b/mindspore/core/ops/softmax_cross_entropy_with_logits.cc index 2c0dadaeb00..e139605df1d 100644 --- a/mindspore/core/ops/softmax_cross_entropy_with_logits.cc +++ b/mindspore/core/ops/softmax_cross_entropy_with_logits.cc @@ -34,10 +34,8 @@ AbstractBasePtr SoftmaxCrossEntropyWithLogitsInfer(const abstract::AnalysisEngin prim_name); // Infer shape - auto logits_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("logits_shape", input_args[0]->BuildShape(), prim_name); - auto labels_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("labels_shape", input_args[1]->BuildShape(), prim_name); + auto logits_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + auto labels_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; CheckAndConvertUtils::Check("logits shape", logits_shape, kEqual, "labels shape", labels_shape, prim_name, TypeError); std::vector loss_shape = {logits_shape[0]}; auto dlogits_shape = logits_shape; diff --git a/mindspore/core/ops/space_to_batch.cc b/mindspore/core/ops/space_to_batch.cc index 93aaa99041f..b8300ec8bdd 100644 --- a/mindspore/core/ops/space_to_batch.cc +++ b/mindspore/core/ops/space_to_batch.cc @@ -29,8 +29,7 @@ namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - auto input_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); + auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; CheckAndConvertUtils::CheckInteger("input shape", input_shape.size(), kEqual, 4, prim_name); std::vector output_shape(input_shape.size()); auto block_shape_vector = GetValue>(primitive->GetAttr(kBlockSize)); diff --git a/mindspore/core/ops/space_to_batch_nd.cc b/mindspore/core/ops/space_to_batch_nd.cc index 2075ba7085c..804c61ba871 100644 --- a/mindspore/core/ops/space_to_batch_nd.cc +++ 
b/mindspore/core/ops/space_to_batch_nd.cc @@ -29,7 +29,7 @@ namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name); auto out_shape = x_shape; int64_t block_shape_prod = 1; diff --git a/mindspore/core/ops/sparse_softmax_cross_entropy_with_logits.cc b/mindspore/core/ops/sparse_softmax_cross_entropy_with_logits.cc index 2a61a73fffb..3294517ae40 100644 --- a/mindspore/core/ops/sparse_softmax_cross_entropy_with_logits.cc +++ b/mindspore/core/ops/sparse_softmax_cross_entropy_with_logits.cc @@ -43,8 +43,7 @@ AbstractBasePtr SparseSoftmaxCrossEntropyWithLogitsInfer(const abstract::Analysi MS_EXCEPTION_IF_NULL(item); } // infer shape - auto input_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); + auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; std::vector output_shape; if (GetValue(primitive->GetAttr(kIsGrad)) != 0) { output_shape = input_shape; diff --git a/mindspore/core/ops/sparse_to_dense.cc b/mindspore/core/ops/sparse_to_dense.cc index e7f68e36486..5668d8ac1a8 100644 --- a/mindspore/core/ops/sparse_to_dense.cc +++ b/mindspore/core/ops/sparse_to_dense.cc @@ -33,8 +33,7 @@ AbstractBasePtr SparseToDenseInfer(const abstract::AnalysisEnginePtr &, const Pr MS_EXCEPTION_IF_NULL(item); } // infer shape - auto dense_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("dense_shape", input_args[3]->BuildShape(), prim_name); + auto dense_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape]; // infer type auto values_type = input_args[1]->BuildType()->cast()->element(); return std::make_shared(values_type, dense_shape); diff --git a/mindspore/core/ops/squeeze.cc b/mindspore/core/ops/squeeze.cc index 4ca7b5c3c3f..4fa07111f87 100644 --- a/mindspore/core/ops/squeeze.cc +++ b/mindspore/core/ops/squeeze.cc @@ -29,7 +29,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector>(primitive->GetAttr(kAxis)); std::vector infer_shape; - auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), op_name); + auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; auto len = SizeToLong(in_shape.size()); if (axis.empty()) { std::copy_if(in_shape.begin(), in_shape.end(), std::back_inserter(infer_shape), diff --git a/mindspore/core/ops/stack.cc b/mindspore/core/ops/stack.cc index 81c2870c6e7..2af5406758d 100644 --- a/mindspore/core/ops/stack.cc +++ b/mindspore/core/ops/stack.cc @@ -21,7 +21,6 @@ namespace ops { namespace { abstract::AbstractBasePtr StackInfer(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto prim_name = primitive->name(); if (input_args.size() != 1) { MS_LOG(ERROR) << "Invalid output size:" << input_args.size(); @@ -29,11 +28,9 @@ abstract::AbstractBasePtr StackInfer(const PrimitivePtr &primitive, const std::v if (input_args.size() < 1) { MS_LOG(ERROR) << "Invalid input size " << input_args.size(); } - auto input_shape = - 
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); + auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; for (int64_t i = 1; i < (int64_t)input_args.size(); ++i) { - auto input_shape_tmp = - CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[i]->BuildShape(), prim_name); + auto input_shape_tmp = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[i]->BuildShape())[kShape]; if (input_shape_tmp.size() != input_shape.size()) { MS_LOG(ERROR) << "All input shape size should be the same!"; } diff --git a/mindspore/core/ops/strided_slice.cc b/mindspore/core/ops/strided_slice.cc index b0346a2b447..510b8decc3e 100644 --- a/mindspore/core/ops/strided_slice.cc +++ b/mindspore/core/ops/strided_slice.cc @@ -108,7 +108,7 @@ abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive, auto temp_strides_v = input_args[3]->cast()->BuildValue(); auto strides_v = GetValue>(temp_strides_v); - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; int64_t x_rank = x_shape.size(); int64_t slice_len = begin_v.size(); std::vector begin_pos = TenToTwo(GetValue(primitive->GetAttr(kBeginMask))); diff --git a/mindspore/core/ops/tan.cc b/mindspore/core/ops/tan.cc index 72ba1b4d806..d42a517e151 100644 --- a/mindspore/core/ops/tan.cc +++ b/mindspore/core/ops/tan.cc @@ -33,7 +33,7 @@ AbstractBasePtr TanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr CheckAndConvertUtils::CheckInteger("tan_infer", input_args.size(), kEqual, 1, prim_name); // Infer Shape - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto infer_shape = std::make_shared(x_shape); // Infer Type diff --git a/mindspore/core/ops/tensor_list_from_tensor.cc b/mindspore/core/ops/tensor_list_from_tensor.cc index 6b82a140bb5..3d961720ea5 100644 --- a/mindspore/core/ops/tensor_list_from_tensor.cc +++ b/mindspore/core/ops/tensor_list_from_tensor.cc @@ -24,11 +24,8 @@ namespace { abstract::ShapePtr TensorListFromTensorInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto prim_name = primitive->name(); - auto input0_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("input0 shape", input_args[0]->BuildShape(), prim_name); - auto input1_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("input1 shape", input_args[1]->BuildShape(), prim_name); + auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; if (input0_shape.size() < 1) { MS_LOG(ERROR) << "input0_shape.size():" << input0_shape.size() << " must be greater than 0!"; } diff --git a/mindspore/core/ops/tensor_list_stack.cc b/mindspore/core/ops/tensor_list_stack.cc index a7013816522..f22182b3315 100644 --- a/mindspore/core/ops/tensor_list_stack.cc +++ b/mindspore/core/ops/tensor_list_stack.cc @@ -52,9 +52,7 @@ AbstractBasePtr TensorListStackInfer(const abstract::AnalysisEnginePtr &, const for (const auto &input : input_args) { MS_EXCEPTION_IF_NULL(input); } - auto op_name = primitive->name(); - auto input0_shape = - 
CheckAndConvertUtils::ConvertShapePtrToShape("input0_shape", input_args[0]->BuildShape(), op_name); + auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; int64_t num = std::accumulate(input0_shape.begin(), input0_shape.end(), 1LL, std::multiplies()); if (num == 0) { MS_LOG(ERROR) << "Try to stack a empty tensorlist!"; @@ -62,8 +60,7 @@ AbstractBasePtr TensorListStackInfer(const abstract::AnalysisEnginePtr &, const if (input_args[1]->BuildShape() == nullptr) { MS_LOG(ERROR) << "ele_shape->data_c() is nullptr"; } - auto input1_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("input1_shape", input_args[1]->BuildShape(), op_name); + auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; input1_shape.insert(input1_shape.begin(), 1); return std::make_shared(input_args[0]->BuildType(), input1_shape); } diff --git a/mindspore/core/ops/tensor_summary.cc b/mindspore/core/ops/tensor_summary.cc index 73266a960c5..e7b0469cb49 100644 --- a/mindspore/core/ops/tensor_summary.cc +++ b/mindspore/core/ops/tensor_summary.cc @@ -29,7 +29,7 @@ abstract::ShapePtr TensorSummaryInferShape(const PrimitivePtr &primitive, MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); // check - auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[1]->BuildShape(), prim_name); + auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kGreaterEqual, 1, prim_name); return std::make_shared(ShapeVector(1)); } diff --git a/mindspore/core/ops/tile.cc b/mindspore/core/ops/tile.cc index fa3580d24e9..05954576b77 100644 --- a/mindspore/core/ops/tile.cc +++ b/mindspore/core/ops/tile.cc @@ -25,8 +25,7 @@ namespace ops { namespace { abstract::ShapePtr TileInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto prim_name = primitive->name(); - auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x shape", input_args[0]->BuildShape(), prim_name); + auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto multiples_v = GetValue>(input_args[1]->cast()->BuildValue()); int len_sub = input_shape.size() - multiples_v.size(); std::vector infer_shape = input_shape; diff --git a/mindspore/core/ops/topk.cc b/mindspore/core/ops/topk.cc index c87a761b2d1..0f9a9e04476 100644 --- a/mindspore/core/ops/topk.cc +++ b/mindspore/core/ops/topk.cc @@ -41,7 +41,7 @@ AbstractBasePtr TopKInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name); // Infer shape - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto k_v = GetValue(input_args[1]->BuildValue()); auto ndims = x_shape.size() - 1; x_shape[ndims] = k_v; diff --git a/mindspore/core/ops/unpack.cc b/mindspore/core/ops/unpack.cc index 2c9ba02b5f4..c4655f99d31 100644 --- a/mindspore/core/ops/unpack.cc +++ b/mindspore/core/ops/unpack.cc @@ -28,7 +28,7 @@ AbstractBasePtr UnpackInfer(const abstract::AnalysisEnginePtr &, const Primitive auto prim_name = primitive->name(); CheckAndConvertUtils::CheckSubClass("x", input_args[0]->BuildType(), 
{TypeIdToType(kObjectTypeTensorType)}, prim_name); - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; int64_t dim = x_shape.size(); int64_t axis = GetValue(primitive->GetAttr(kAxis)); // CheckAndConvertUtils::CheckInRange("axis value", axis, kIncludeLeft, {-dim, dim}, prim_name); diff --git a/mindspore/core/ops/unsorted_segment_sum.cc b/mindspore/core/ops/unsorted_segment_sum.cc index 5a1916c2e93..da9128d47c6 100644 --- a/mindspore/core/ops/unsorted_segment_sum.cc +++ b/mindspore/core/ops/unsorted_segment_sum.cc @@ -33,11 +33,10 @@ AbstractBasePtr UnsortedSegmentSumInfer(const abstract::AnalysisEnginePtr &, con // Infer type auto x_type = input_args[0]->BuildType()->cast()->element(); // Infer shape - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; CheckAndConvertUtils::CheckInteger("x_shape", x_shape.size(), kGreaterThan, 0, prim_name); auto shp = x_shape; - auto segment_ids_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[1]->BuildShape(), prim_name); + auto segment_ids_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; CheckAndConvertUtils::CheckInteger("segment_ids_shape", segment_ids_shape.size(), kGreaterThan, 0, prim_name); CheckAndConvertUtils::Check("input_x", x_shape.size(), kGreaterEqual, "segment_ids_shape", segment_ids_shape.size(), prim_name); diff --git a/mindspore/core/ops/unsqueeze.cc b/mindspore/core/ops/unsqueeze.cc index 1737bf672a6..8d7eed6d259 100644 --- a/mindspore/core/ops/unsqueeze.cc +++ b/mindspore/core/ops/unsqueeze.cc @@ -39,7 +39,7 @@ AbstractBasePtr UnsqueezeInfer(const abstract::AnalysisEnginePtr &, const Primit // Infer shape auto dims = GetValue>(primitive->GetAttr(kAxis)); - auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input", input->BuildShape(), prim_name); + auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input->BuildShape())[kShape]; auto input_rank = input_shape.size(); auto dim_rank = dims.size(); std::vector out_shape; diff --git a/mindspore/core/ops/unstack.cc b/mindspore/core/ops/unstack.cc index 3186aabde3c..66528a9870f 100644 --- a/mindspore/core/ops/unstack.cc +++ b/mindspore/core/ops/unstack.cc @@ -26,7 +26,7 @@ AbstractBasePtr UnstackInfer(const abstract::AnalysisEnginePtr &, const Primitiv const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; int64_t dim = x_shape.size(); int64_t axis = GetValue(primitive->GetAttr(kAxis)); // CheckAndConvertUtils::CheckInRange("axis value", axis, kIncludeLeft, {-dim, dim}, prim_name); diff --git a/mindspore/core/ops/where.cc b/mindspore/core/ops/where.cc index 27c11c6d3ec..b34d2d529c0 100644 --- a/mindspore/core/ops/where.cc +++ b/mindspore/core/ops/where.cc @@ -33,14 +33,11 @@ AbstractBasePtr WhereInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP auto input0_type_ = input_args[0]->BuildType()->cast(); MS_EXCEPTION_IF_NULL(input0_type_); auto input0_type = input0_type_->element(); - auto input0_shape = 
- CheckAndConvertUtils::ConvertShapePtrToShape("input0_shape", input_args[0]->BuildShape(), op_name); + auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto num = input_args[0]->BuildValue()->cast()->ElementsNum(); - auto input1_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("input1_shape", input_args[1]->BuildShape(), op_name); + auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; auto num1 = input_args[1]->BuildValue()->cast()->ElementsNum(); - auto input2_shape = - CheckAndConvertUtils::ConvertShapePtrToShape("input2_shape", input_args[2]->BuildShape(), op_name); + auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; auto num2 = input_args[2]->BuildValue()->cast()->ElementsNum(); int64_t nummax = num > num1 ? num : (num1 > num2 ? num1 : num2); int64_t axisout = 0; diff --git a/mindspore/core/ops/zeros.cc b/mindspore/core/ops/zeros.cc index 398e40996a0..e72c4ce646d 100644 --- a/mindspore/core/ops/zeros.cc +++ b/mindspore/core/ops/zeros.cc @@ -52,9 +52,8 @@ TypePtr ZerosInferType(const PrimitivePtr &prim, const std::vector &input_args, const abstract::AbstractBasePtr &abs) { MS_EXCEPTION_IF_NULL(prim); - auto prim_name = prim->name(); // check - auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShape("output shape", abs->BuildShape(), prim_name); + auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(abs->BuildShape())[kShape]; auto out_type = abs->BuildType(); MS_EXCEPTION_IF_NULL(out_type); return TensorConstructUtils::CreateZerosTensor(out_type, out_shape); diff --git a/mindspore/core/ops/zeros_like.cc b/mindspore/core/ops/zeros_like.cc index 63adb63f490..58784283040 100644 --- a/mindspore/core/ops/zeros_like.cc +++ b/mindspore/core/ops/zeros_like.cc @@ -35,7 +35,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vectorBuildShape(), prim_name); + auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; return std::make_shared(in_shape); } diff --git a/mindspore/core/utils/check_convert_utils.cc b/mindspore/core/utils/check_convert_utils.cc index 8e2594165e2..3f4175e1aaf 100644 --- a/mindspore/core/utils/check_convert_utils.cc +++ b/mindspore/core/utils/check_convert_utils.cc @@ -384,18 +384,6 @@ int64_t CheckAndConvertUtils::CheckInteger(const std::string &arg_name, int64_t MS_EXCEPTION(ValueError) << buffer.str(); } -std::vector CheckAndConvertUtils::ConvertShapePtrToShape(const std::string &arg_name, - const BaseShapePtr &shape, - const std::string &prim_name) { - MS_EXCEPTION_IF_NULL(shape); - if (!shape->isa()) { - return std::vector(); - } - auto shape_element = shape->cast(); - MS_EXCEPTION_IF_NULL(shape_element); - return shape_element->shape(); -} - ShapeMap CheckAndConvertUtils::ConvertShapePtrToShapeMap(const BaseShapePtr &shape) { MS_EXCEPTION_IF_NULL(shape); if (!shape->isa()) { diff --git a/mindspore/core/utils/check_convert_utils.h b/mindspore/core/utils/check_convert_utils.h index c3e76d13ee9..9472460e029 100644 --- a/mindspore/core/utils/check_convert_utils.h +++ b/mindspore/core/utils/check_convert_utils.h @@ -236,9 +236,6 @@ class CheckAndConvertUtils { MS_EXCEPTION(ValueError) << buffer.str(); } - static std::vector ConvertShapePtrToShape(const std::string &arg_name, const BaseShapePtr &shape, - const std::string &prim_name); - static ShapeMap ConvertShapePtrToShapeMap(const BaseShapePtr &shape); static void 
Check(const std::string &arg_name, int64_t arg_value, CompareEnum compare_type, diff --git a/mindspore/lite/tools/converter/ops/while.cc b/mindspore/lite/tools/converter/ops/while.cc index 1c666efdadf..b8cc501143b 100644 --- a/mindspore/lite/tools/converter/ops/while.cc +++ b/mindspore/lite/tools/converter/ops/while.cc @@ -52,11 +52,9 @@ AbstractBasePtr WhileInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP MS_EXCEPTION_IF_NULL(primitive); auto While_prim = primitive->cast(); MS_EXCEPTION_IF_NULL(While_prim); - auto op_name = While_prim->name(); AbstractBasePtrList output; for (int64_t i = 0; i < (int64_t)input_args.size(); i++) { - auto shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape" + std::to_string(i), - input_args[i]->BuildShape(), op_name); + auto shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[i]->BuildShape())[kShape]; output.push_back(std::make_shared(input_args[i]->BuildType(), shape)); } return std::make_shared(output);
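For reference, every hunk in this patch applies the same mechanical rewrite, so one call site summarizes the whole change. The sketch below is illustrative only and is not taken from the patch: the op name ExampleInferShape is hypothetical, the include path is an assumption, and the template arguments (std::vector<AbstractBasePtr>, std::make_shared<abstract::Shape>) are restored by assumption because the plain-text rendering of this diff strips angle-bracketed text. It is meant to sit inside the MindSpore ops source tree, not to build standalone.

#include <memory>
#include <vector>

#include "utils/check_convert_utils.h"  // assumed include path for CheckAndConvertUtils and the kShape key

namespace mindspore {
namespace ops {
// Hypothetical op, used only to illustrate the migration applied throughout this patch.
abstract::ShapePtr ExampleInferShape(const PrimitivePtr &primitive,
                                     const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  // Old API (removed by this patch): the argument name and the primitive name
  // were passed only for error reporting.
  //   auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape(
  //       "x_shape", input_args[0]->BuildShape(), primitive->name());
  // New API: convert the BaseShapePtr once into a ShapeMap and read its kShape entry.
  auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
  return std::make_shared<abstract::Shape>(in_shape);
}
}  // namespace ops
}  // namespace mindspore

Because the new call no longer takes a primitive name, many call sites also drop a now-unused op_name/prim_name local, which is where most of the extra single-line deletions in the hunks above come from.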