!28142 fix check axis error log

Merge pull request !28142 from lianliguang/fix-check-args-report-error
This commit is contained in:
i-robot 2021-12-25 08:19:24 +00:00 committed by Gitee
commit 2c893b54d4
5 changed files with 14 additions and 10 deletions

View File

@ -240,7 +240,7 @@ AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValueP
}
for (auto &elem : axis_data) {
int64_t e_value = CheckAxis(primitive->name(), "axis", elem, -SizeToLong(x_rank), SizeToLong(x_rank));
int64_t e_value = CheckAxis(primitive->name(), "axis", elem, -SizeToLong(x_rank), SizeToLong(x_rank), "input_x");
(void)axis_set.insert(e_value);
}
MS_EXCEPTION_IF_NULL(x_shp_value->cast<ValueTuplePtr>());

View File

@ -162,7 +162,7 @@ TypePtr CheckDtypeSame(const std::string &op, const AbstractTensorPtr &tensor_ba
}
int64_t CheckAxis(const std::string &op, const std::string &args_name, const ValuePtr &axis, int64_t minimum,
int64_t max) {
int64_t max, const std::string &rank_name) {
if (axis == nullptr) {
MS_LOG(EXCEPTION) << op << " evaluator axis is null";
}
@ -171,8 +171,9 @@ int64_t CheckAxis(const std::string &op, const std::string &args_name, const Val
}
int64_t axis_value = GetValue<int64_t>(axis);
if (axis_value >= max || axis_value < minimum) {
MS_LOG(EXCEPTION) << "The primitive[" << op << "]'s \'" << args_name << "\' value should be in the range ["
<< minimum << ", " << max << "), but got " << axis_value;
MS_LOG(EXCEPTION) << "For primitive[" << op << "], " << rank_name << "'s rank is " << max << ", while the "
<< "\'" << args_name << "\' value should be in the range [" << minimum << ", " << max
<< "), but got " << axis_value;
}
if (axis_value < 0) {
axis_value = axis_value + SizeToLong(max);

View File

@ -45,7 +45,8 @@ void CheckShapeSame(const std::string &op, const AbstractTensorPtr &tensor_base,
TypePtr CheckDtypeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor);
int64_t CheckAxis(const std::string &op, const std::string &arg_name, const ValuePtr &axis, int64_t min, int64_t max);
int64_t CheckAxis(const std::string &op, const std::string &arg_name, const ValuePtr &axis, int64_t min, int64_t max,
const std::string &rank_name);
void CheckArgsSize(const std::string &op, const AbstractBasePtrList &args_spec_list, size_t size_expect);

View File

@ -148,7 +148,7 @@ AbstractBasePtr InferImplStack(const AnalysisEnginePtr &, const PrimitivePtr &pr
ValuePtr axis = primitive->GetAttr("axis");
// Axis value should be in [-(rank_base + 1), rank_base).
int64_t axis_value = CheckAxis(op_name, "axis", axis, -(rank_base + 1), rank_base);
int64_t axis_value = CheckAxis(op_name, "axis", axis, -(rank_base + 1), rank_base, "input_x");
for (size_t i = 1; i < tuple_len; ++i) {
AbstractTensorPtr tensor = nullptr;
@ -948,7 +948,7 @@ AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &pr
int64_t rank = SizeToLong(x_shape.size());
ValuePtr axis = primitive->GetAttr("axis");
int64_t axis_value_pos = CheckAxis(op_name, "axis", axis, -(rank + 1), rank);
int64_t axis_value_pos = CheckAxis(op_name, "axis", axis, -(rank + 1), rank, "input_x");
int64_t output_num_value = GetValue<int64_t>(primitive->GetAttr("output_num"));
if ((x_shape[axis_value_pos] != Shape::SHP_ANY) && (x_shape[axis_value_pos] % output_num_value != 0)) {
MS_LOG(EXCEPTION) << "x_shape[" << axis_value_pos << "] = " << x_shape[axis_value_pos]
@ -1093,7 +1093,7 @@ AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &p
ValuePtr axis = primitive->GetAttr("axis");
// Axis value should be in [-(rank_base + 1), rank_base).
int64_t axis_value = CheckAxis(op_name, "axis", axis, -(rank_base + 1), rank_base);
int64_t axis_value = CheckAxis(op_name, "axis", axis, -(rank_base + 1), rank_base, "input_x");
int64_t all_shp = shape_base[axis_value];
int64_t min_all_shp = min_shape_base[axis_value];

View File

@ -61,10 +61,12 @@ AbstractBasePtr LayerNormInfer(const abstract::AnalysisEnginePtr &, const Primit
// begin_norm_axis and begin_params_axis should be smaller than the size of input_x and >= -1
ValuePtr bna_ptr = primitive->GetAttr("begin_norm_axis");
int64_t begin_norm_axis = abstract::CheckAxis(op_name, "begin_norm_axis", bna_ptr, -1, SizeToLong(input_rank));
int64_t begin_norm_axis =
abstract::CheckAxis(op_name, "begin_norm_axis", bna_ptr, -1, SizeToLong(input_rank), "input_x");
ValuePtr bpa_ptr = primitive->GetAttr("begin_params_axis");
int64_t begin_params_axis = abstract::CheckAxis(op_name, "begin_params_axis", bpa_ptr, -1, SizeToLong(input_rank));
int64_t begin_params_axis =
abstract::CheckAxis(op_name, "begin_params_axis", bpa_ptr, -1, SizeToLong(input_rank), "input_x");
// the beta and gama shape should be x_shape[begin_params_axis:]
auto valid_types = {kFloat16, kFloat32};