From d9075240237ba1cf58efb22feeaa31b8984292b5 Mon Sep 17 00:00:00 2001 From: lianliguang Date: Fri, 24 Dec 2021 15:43:18 +0800 Subject: [PATCH] fix error log of check axis --- .../ccsrc/frontend/operator/ops_front_infer_function.cc | 2 +- mindspore/core/abstract/param_validator.cc | 7 ++++--- mindspore/core/abstract/param_validator.h | 3 ++- mindspore/core/abstract/prim_arrays.cc | 6 +++--- mindspore/core/ops/layer_norm.cc | 6 ++++-- 5 files changed, 14 insertions(+), 10 deletions(-) diff --git a/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc b/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc index 5d204a41b7e..056c7813e3f 100644 --- a/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc +++ b/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc @@ -240,7 +240,7 @@ AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValueP } for (auto &elem : axis_data) { - int64_t e_value = CheckAxis(primitive->name(), "axis", elem, -SizeToLong(x_rank), SizeToLong(x_rank)); + int64_t e_value = CheckAxis(primitive->name(), "axis", elem, -SizeToLong(x_rank), SizeToLong(x_rank), "input_x"); (void)axis_set.insert(e_value); } MS_EXCEPTION_IF_NULL(x_shp_value->cast()); diff --git a/mindspore/core/abstract/param_validator.cc b/mindspore/core/abstract/param_validator.cc index 2d4cc251641..af0f3fd2e95 100644 --- a/mindspore/core/abstract/param_validator.cc +++ b/mindspore/core/abstract/param_validator.cc @@ -162,7 +162,7 @@ TypePtr CheckDtypeSame(const std::string &op, const AbstractTensorPtr &tensor_ba } int64_t CheckAxis(const std::string &op, const std::string &args_name, const ValuePtr &axis, int64_t minimum, - int64_t max) { + int64_t max, const std::string &rank_name) { if (axis == nullptr) { MS_LOG(EXCEPTION) << op << " evaluator axis is null"; } @@ -171,8 +171,9 @@ int64_t CheckAxis(const std::string &op, const std::string &args_name, const Val } int64_t axis_value = GetValue(axis); if (axis_value >= max || axis_value < minimum) { - MS_LOG(EXCEPTION) << "The primitive[" << op << "]'s \'" << args_name << "\' value should be in the range [" - << minimum << ", " << max << "), but got " << axis_value; + MS_LOG(EXCEPTION) << "For primitive[" << op << "], " << rank_name << "'s rank is " << max << ", while the " + << "\'" << args_name << "\' value should be in the range [" << minimum << ", " << max + << "), but got " << axis_value; } if (axis_value < 0) { axis_value = axis_value + SizeToLong(max); diff --git a/mindspore/core/abstract/param_validator.h b/mindspore/core/abstract/param_validator.h index 2fadcd13791..169259fda21 100644 --- a/mindspore/core/abstract/param_validator.h +++ b/mindspore/core/abstract/param_validator.h @@ -45,7 +45,8 @@ void CheckShapeSame(const std::string &op, const AbstractTensorPtr &tensor_base, TypePtr CheckDtypeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor); -int64_t CheckAxis(const std::string &op, const std::string &arg_name, const ValuePtr &axis, int64_t min, int64_t max); +int64_t CheckAxis(const std::string &op, const std::string &arg_name, const ValuePtr &axis, int64_t min, int64_t max, + const std::string &rank_name); void CheckArgsSize(const std::string &op, const AbstractBasePtrList &args_spec_list, size_t size_expect); diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc index 76d8dc280f0..9294e2ec885 100644 --- a/mindspore/core/abstract/prim_arrays.cc +++ b/mindspore/core/abstract/prim_arrays.cc @@ -148,7 +148,7 @@ AbstractBasePtr InferImplStack(const AnalysisEnginePtr &, const PrimitivePtr &pr ValuePtr axis = primitive->GetAttr("axis"); // Axis value should be in [-(rank_base + 1), rank_base). - int64_t axis_value = CheckAxis(op_name, "axis", axis, -(rank_base + 1), rank_base); + int64_t axis_value = CheckAxis(op_name, "axis", axis, -(rank_base + 1), rank_base, "input_x"); for (size_t i = 1; i < tuple_len; ++i) { AbstractTensorPtr tensor = nullptr; @@ -948,7 +948,7 @@ AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &pr int64_t rank = SizeToLong(x_shape.size()); ValuePtr axis = primitive->GetAttr("axis"); - int64_t axis_value_pos = CheckAxis(op_name, "axis", axis, -(rank + 1), rank); + int64_t axis_value_pos = CheckAxis(op_name, "axis", axis, -(rank + 1), rank, "input_x"); int64_t output_num_value = GetValue(primitive->GetAttr("output_num")); if ((x_shape[axis_value_pos] != Shape::SHP_ANY) && (x_shape[axis_value_pos] % output_num_value != 0)) { MS_LOG(EXCEPTION) << "x_shape[" << axis_value_pos << "] = " << x_shape[axis_value_pos] @@ -1093,7 +1093,7 @@ AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &p ValuePtr axis = primitive->GetAttr("axis"); // Axis value should be in [-(rank_base + 1), rank_base). - int64_t axis_value = CheckAxis(op_name, "axis", axis, -(rank_base + 1), rank_base); + int64_t axis_value = CheckAxis(op_name, "axis", axis, -(rank_base + 1), rank_base, "input_x"); int64_t all_shp = shape_base[axis_value]; int64_t min_all_shp = min_shape_base[axis_value]; diff --git a/mindspore/core/ops/layer_norm.cc b/mindspore/core/ops/layer_norm.cc index 1ad1a4d886d..45abec3aa45 100644 --- a/mindspore/core/ops/layer_norm.cc +++ b/mindspore/core/ops/layer_norm.cc @@ -61,10 +61,12 @@ AbstractBasePtr LayerNormInfer(const abstract::AnalysisEnginePtr &, const Primit // begin_norm_axis and begin_params_axis should be smaller than the size of input_x and >= -1 ValuePtr bna_ptr = primitive->GetAttr("begin_norm_axis"); - int64_t begin_norm_axis = abstract::CheckAxis(op_name, "begin_norm_axis", bna_ptr, -1, SizeToLong(input_rank)); + int64_t begin_norm_axis = + abstract::CheckAxis(op_name, "begin_norm_axis", bna_ptr, -1, SizeToLong(input_rank), "input_x"); ValuePtr bpa_ptr = primitive->GetAttr("begin_params_axis"); - int64_t begin_params_axis = abstract::CheckAxis(op_name, "begin_params_axis", bpa_ptr, -1, SizeToLong(input_rank)); + int64_t begin_params_axis = + abstract::CheckAxis(op_name, "begin_params_axis", bpa_ptr, -1, SizeToLong(input_rank), "input_x"); // the beta and gama shape should be x_shape[begin_params_axis:] auto valid_types = {kFloat16, kFloat32};