From 5864d6a0063e4db3ea19a7b687ddf25b2058ba1e Mon Sep 17 00:00:00 2001 From: lianliguang Date: Wed, 24 Nov 2021 21:56:03 +0800 Subject: [PATCH] modify some eroor log --- .../operator/ops_front_infer_function.cc | 2 +- mindspore/ccsrc/pipeline/jit/action.cc | 8 +- .../jit/static_analysis/async_eval_result.cc | 5 +- .../pipeline/jit/static_analysis/prim.cc | 30 +++++ mindspore/core/abstract/infer_functions.h | 2 - mindspore/core/abstract/param_validator.cc | 83 +++---------- mindspore/core/abstract/param_validator.h | 2 - mindspore/core/abstract/prim_arrays.cc | 7 +- mindspore/core/abstract/prim_nn.cc | 111 +----------------- .../core/abstract/primitive_infer_map.cc | 1 - mindspore/core/abstract/utils.cc | 77 ------------ mindspore/core/abstract/utils.h | 9 -- mindspore/core/ops/layer_norm.cc | 5 +- tests/ut/cpp/operator/composite_test.cc | 5 +- 14 files changed, 60 insertions(+), 287 deletions(-) diff --git a/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc b/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc index 821d5d43c2b..5704cc1db87 100644 --- a/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc +++ b/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc @@ -227,7 +227,7 @@ AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValueP } for (auto &elem : axis_data) { - int64_t e_value = CheckAxis(primitive->name(), elem, -SizeToLong(x_rank), SizeToLong(x_rank) - 1); + int64_t e_value = CheckAxis(primitive->name(), elem, -SizeToLong(x_rank), SizeToLong(x_rank)); (void)axis_set.insert(e_value); } MS_EXCEPTION_IF_NULL(x_shp_value->cast()); diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index ccbf5d15b4e..430e8a1e8f0 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -1019,10 +1019,10 @@ bool SetMindIRGraphAction(const ResourcePtr &res) { if (!AbstractBasePtrListDeepEqual(func_args, broaded_args)) { MS_LOG(EXCEPTION) << "The input arguments is not compatible with the function graph which has been exported before." << "Please check the args is same with export.\n" - << "The export input argument size : " << func_args.size() << "\n" - << "The load input argument size : " << broaded_args.size() << "\n" - << "Export input args info:" << abstract::ArgsToString(func_args) << "\n" - << "The input args info:" << abstract::ArgsToString(broaded_args); + << "The export input argument size: " << func_args.size() << "\n" + << "The load input argument size: " << broaded_args.size() << "\n" + << "Export input args info: " << abstract::ArgsToString(func_args) << "\n" + << "The input args info: " << abstract::ArgsToString(broaded_args); } // suppose that there is not KeywordArgument for the top graph diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.cc index abe54f668f9..2c450abdb71 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.cc @@ -326,11 +326,10 @@ void AnalysisResultCacheMgr::Todo() { std::string ArgsToString(const AbstractBasePtrList &args_spec_list) { std::ostringstream buffer; - buffer << "("; for (const auto &item : args_spec_list) { - buffer << item->ToString() << " # "; + buffer << item->BuildType()->ToString() << "," << item->BuildShape()->ToString() << " #" + << "\n"; } - buffer << " )"; return buffer.str(); } } // namespace abstract diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 1c252c02d2d..c0e2d0bac54 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -445,6 +445,36 @@ void ConvertAbstractFunctionToPython(const AbstractBasePtr &abs_base, py::dict * } } } + +bool CheckType(const TypePtr &expected_type, const TypePtr &x) { + // As x and predicate both are mindspore type statically, here we only to judge whether + // x is predicate or is a subclass of predicate. + return IsIdentidityOrSubclass(x, expected_type); +} + +// Join all types in args_type_list; +TypePtr TypeJoin(const TypePtrList &args_type_list) { + if (args_type_list.empty()) { + MS_LOG(EXCEPTION) << "args_type_list is empty"; + } + + TypePtr type_tmp = args_type_list[0]; + for (std::size_t i = 1; i < args_type_list.size(); i++) { + type_tmp = abstract::TypeJoin(type_tmp, args_type_list[i]); + } + return type_tmp; +} + +TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list) { + MS_EXCEPTION_IF_NULL(predicate); + for (const auto &arg_type : args_type_list) { + MS_EXCEPTION_IF_NULL(arg_type); + if (!CheckType(predicate, arg_type)) { + MS_LOG(EXCEPTION) << "The expected is " << predicate->ToString() << ", not " << arg_type->ToString(); + } + } + return TypeJoin(args_type_list); +} } // end anonymous namespace py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index 2e7d6e7e1ea..f55b78be573 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -49,8 +49,6 @@ AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const Primit const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplHSigmoid(const AnalysisEnginePtr &, const PrimitivePtr &primitive, diff --git a/mindspore/core/abstract/param_validator.cc b/mindspore/core/abstract/param_validator.cc index 0f402dcb04b..1805c14d35a 100644 --- a/mindspore/core/abstract/param_validator.cc +++ b/mindspore/core/abstract/param_validator.cc @@ -45,7 +45,7 @@ TypePtr CheckType(TypePtr type, const TypePtrList &accepts, const std::string &e if (ok) { return type; } else { - MS_LOG(EXCEPTION) << error_message_prefix << accepts << " but is " << type->ToString(); + MS_EXCEPTION(TypeError) << error_message_prefix << " should be " << accepts << ",but got " << type->ToString(); } } @@ -79,7 +79,8 @@ TypePtr CheckTensorsDTypeSame(const AbstractTensorPtrList &tensor_list, const Ty TypePtr sample_type = sample_elem->BuildType(); MS_EXCEPTION_IF_NULL(sample_type); std::ostringstream loginfoBuffer; - loginfoBuffer << "same type, got"; + loginfoBuffer << "[" << sample_tensor->BuildType()->ToString(); + bool error_flag = false; // Check if other elements have the same type with the first element. for (size_t index = 1; index < tensor_list.size(); ++index) { MS_EXCEPTION_IF_NULL(tensor_list[index]); @@ -87,12 +88,14 @@ TypePtr CheckTensorsDTypeSame(const AbstractTensorPtrList &tensor_list, const Ty MS_EXCEPTION_IF_NULL(elem); auto a_type = elem->BuildType(); MS_EXCEPTION_IF_NULL(a_type); - loginfoBuffer << " " << a_type->ToString(); + loginfoBuffer << "," << tensor_list[index]->BuildType()->ToString(); if (sample_type->type_id() != a_type->type_id()) { - MS_LOG(EXCEPTION) << "Expected type " << sample_type->ToString() << ", but got " << a_type->ToString() - << ", index " << index; + error_flag = true; } } + if (error_flag) { + MS_EXCEPTION(ValueError) << error_message_prefix << " must be same, but got " << loginfoBuffer.str() << "]"; + } MS_LOG(DEBUG) << error_message_prefix << loginfoBuffer.str(); return CheckTensorDType(sample_tensor, accepts, error_message_prefix); } @@ -167,15 +170,19 @@ int64_t CheckAxis(const std::string &op, const ValuePtr &axis, int64_t minimum, } int64_t axis_value = GetValue(axis); if (axis_value > max || axis_value < minimum) { - MS_LOG(EXCEPTION) << op << " evaluator axis value should be in the range [" << minimum << ", " << max - << "], but get " << axis_value; + MS_LOG(EXCEPTION) << "The primitive[" << op << "]'s axis value should be in the range [" << minimum << ", " << max + << "], but got " << axis_value; + } + if (axis_value < 0) { + axis_value = axis_value + SizeToLong(max); } return axis_value; } void CheckArgsSize(const std::string &op, const mindspore::abstract::AbstractBasePtrList &args_spec_list, size_t size_expect) { if (args_spec_list.size() != size_expect) { - MS_LOG(EXCEPTION) << op << " input args size should be " << size_expect << ", but got " << args_spec_list.size(); + MS_LOG(EXCEPTION) << op << " input arguments size should be " << size_expect << ", but got " + << args_spec_list.size(); } for (size_t i = 0; i < size_expect; i++) { @@ -200,65 +207,6 @@ void CheckShapeAnyAndPositive(const std::string &op, const ShapeVector &shape) { } } -int64_t CheckAttrPositiveInt64(const std::string &op, const ValuePtr &attr, const std::string &attr_name) { - MS_EXCEPTION_IF_NULL(attr); - auto int64_value = attr->cast(); - MS_EXCEPTION_IF_NULL(int64_value); - int64_t attr_val = int64_value->value(); - if (attr_val <= 0) { - MS_LOG(EXCEPTION) << op << " invalid " << attr_name << " value: " << attr_val << ", should be greater then 0"; - } - return attr_val; -} - -std::vector CheckAttrIntOrTuple(const std::string &op, const ValuePtr &attr, const size_t start_idx, - const size_t num_element) { - std::vector result; - MS_EXCEPTION_IF_NULL(attr); - if (attr->isa()) { - auto tuple_attr = attr->cast(); - MS_EXCEPTION_IF_NULL(tuple_attr); - std::vector attr_vec = tuple_attr->value(); - if (start_idx > attr_vec.size() || start_idx + num_element > attr_vec.size()) { - MS_EXCEPTION(IndexError) << op << " attr index is out of range, attr size is " << attr_vec.size() - << "but start idx got" << start_idx << " num element " << num_element; - } - auto it_start = attr_vec.begin() + start_idx; - (void)std::transform(it_start, it_start + num_element, std::back_inserter(result), - [](const ValuePtr &e) -> int64_t { return GetValue(e); }); - } else { - auto int64_imm = attr->cast(); - MS_EXCEPTION_IF_NULL(int64_imm); - int64_t attr_val = int64_imm->value(); - (void)result.insert(result.begin(), num_element, attr_val); - } - return result; -} - -std::string CheckAttrStringSet(const std::string &op, const ValuePtr &attr, const std::string &attr_name, - const std::set &val_set) { - MS_EXCEPTION_IF_NULL(attr); - auto string_attr = attr->cast(); - MS_EXCEPTION_IF_NULL(string_attr); - std::string attr_val = string_attr->value(); - if (val_set.find(attr_val) == val_set.end()) { - std::ostringstream buffer; - bool f_begin = true; - buffer << "{"; - for (auto &x : val_set) { - if (!f_begin) { - buffer << ", "; - } else { - f_begin = false; - } - buffer << x; - } - buffer << "}"; - MS_LOG(EXCEPTION) << op << "Unsupported " << attr_name << ": " << attr_val << ". use " << buffer.str(); - } - return attr_val; -} - void CheckRequiredArgsSize(const std::string &op, const mindspore::abstract::AbstractBasePtrList &args_spec_list, size_t size_expect) { if (args_spec_list.size() < size_expect) { @@ -268,6 +216,5 @@ void CheckRequiredArgsSize(const std::string &op, const mindspore::abstract::Abs MS_EXCEPTION_IF_NULL(args_spec_list[i]); } } - } // namespace abstract } // namespace mindspore diff --git a/mindspore/core/abstract/param_validator.h b/mindspore/core/abstract/param_validator.h index 754509f546e..84e6930f189 100644 --- a/mindspore/core/abstract/param_validator.h +++ b/mindspore/core/abstract/param_validator.h @@ -53,8 +53,6 @@ void CheckShapeAllPositive(const std::string &op, const ShapeVector &shape); void CheckShapeAnyAndPositive(const std::string &op, const ShapeVector &shape); -int64_t CheckAttrPositiveInt64(const std::string &op, const ValuePtr &attr, const std::string &attr_name); - std::vector CheckAttrIntOrTuple(const std::string &op, const ValuePtr &attr, const size_t start_idx, const size_t num_element); diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc index 5d8789416c8..0e0df1e27b2 100644 --- a/mindspore/core/abstract/prim_arrays.cc +++ b/mindspore/core/abstract/prim_arrays.cc @@ -149,8 +149,6 @@ AbstractBasePtr InferImplStack(const AnalysisEnginePtr &, const PrimitivePtr &pr ValuePtr axis = primitive->GetAttr("axis"); // Axis value should be in [-(rank_base + 1), rank_base). int64_t axis_value = CheckAxis(op_name, axis, -(rank_base + 1), rank_base); - // If axis is negative, add offset(rank_base + 1) to turn it to positive. - axis_value = GetPositiveAxis(axis_value, LongToSize(rank_base + 1)); for (size_t i = 1; i < tuple_len; ++i) { AbstractTensorPtr tensor = nullptr; @@ -950,8 +948,7 @@ AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &pr int64_t rank = SizeToLong(x_shape.size()); ValuePtr axis = primitive->GetAttr("axis"); - int64_t axis_value = CheckAxis(op_name, axis, -(rank + 1), rank); - uint64_t axis_value_pos = LongToUlong(GetPositiveAxis(axis_value, LongToSize(rank))); + int64_t axis_value_pos = CheckAxis(op_name, axis, -(rank + 1), rank); int64_t output_num_value = GetValue(primitive->GetAttr("output_num")); if ((x_shape[axis_value_pos] != Shape::SHP_ANY) && (x_shape[axis_value_pos] % output_num_value != 0)) { MS_LOG(EXCEPTION) << "x_shape[" << axis_value_pos << "] = " << x_shape[axis_value_pos] @@ -1097,8 +1094,6 @@ AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &p ValuePtr axis = primitive->GetAttr("axis"); // Axis value should be in [-(rank_base + 1), rank_base). int64_t axis_value = CheckAxis(op_name, axis, -(rank_base + 1), rank_base); - // If axis is negative, add offset(rank_base) to turn it to positive. - axis_value = GetPositiveAxis(axis_value, LongToSize(rank_base)); int64_t all_shp = shape_base[axis_value]; int64_t min_all_shp = min_shape_base[axis_value]; diff --git a/mindspore/core/abstract/prim_nn.cc b/mindspore/core/abstract/prim_nn.cc index a53bd4528ca..225ff64f586 100644 --- a/mindspore/core/abstract/prim_nn.cc +++ b/mindspore/core/abstract/prim_nn.cc @@ -139,14 +139,14 @@ AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape); auto input_tensor = CheckArg(op_name, args_spec_list, 0); - (void)CheckTensorDType(input_tensor, {kFloat16, kFloat32}, "param x of BatchNorm should be"); + (void)CheckTensorDType(input_tensor, {kFloat16, kFloat32}, "Input argument[x] of BatchNorm"); AbstractTensorPtrList tensorPtrList = std::vector(); for (size_t i = 1; i < args_spec_list.size(); ++i) { auto param = CheckArg(op_name, args_spec_list, i); tensorPtrList.push_back(param); } (void)CheckTensorsDTypeSame(tensorPtrList, {kFloat16, kFloat32}, - "param gamma, beta, mean, variance of Batchnorm should be"); + "Input arguments[gamma, beta, mean, variance] of BatchNorm"); auto data_format_ptr = primitive->GetAttr("format"); MS_EXCEPTION_IF_NULL(data_format_ptr); @@ -240,113 +240,6 @@ void CheckShape(const std::string &op_name, const ShapeVector &w_shape, const Ab CheckShapeAllPositive(op_name + " w_max_shape", w_max_shape); } -AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - constexpr auto kConv2DInputNum = 2; - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, kConv2DInputNum); - AbstractTensorPtr input_x = CheckArg(op_name, args_spec_list, 0); - MS_EXCEPTION_IF_NULL(input_x); - MS_EXCEPTION_IF_NULL(input_x->shape()); - ShapeVector x_shape = input_x->shape()->shape(); - ShapeVector x_min_shape = input_x->shape()->min_shape(); - ShapeVector x_max_shape = input_x->shape()->max_shape(); - CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape); - CheckShapeAnyAndPositive(op_name + " x_shape", x_shape); - CheckShapeAllPositive(op_name + " x_min_shape", x_min_shape); - CheckShapeAllPositive(op_name + " x_max_shape", x_max_shape); - AbstractTensorPtr input_w = CheckArg(op_name, args_spec_list, 1); - MS_EXCEPTION_IF_NULL(input_w); - MS_EXCEPTION_IF_NULL(input_w->shape()); - ShapeVector w_shape = input_w->shape()->shape(); - CheckShape(op_name, w_shape, input_w); - const uint64_t n_axis = 0; - uint64_t c_axis = 1; - uint64_t h_axis = 2; - uint64_t w_axis = 3; - - int64_t data_format = GetAndCheckFormat(primitive->GetAttr("format")); - if (data_format == Format::NHWC) { - c_axis = 3; - h_axis = 1; - w_axis = 2; - } - int64_t group = CheckAttrPositiveInt64(op_name, primitive->GetAttr("group"), "group"); - if ((x_shape[c_axis] != Shape::SHP_ANY) && (w_shape[c_axis] != Shape::SHP_ANY) && - ((x_shape[c_axis] / group) != w_shape[c_axis])) { - MS_LOG(EXCEPTION) << "x_shape[C_in] / group must be equal to w_shape[C_in]: " << w_shape[c_axis] << ", but got " - << (x_shape[c_axis] / group); - } - - int64_t out_channel = CheckAttrPositiveInt64(op_name, primitive->GetAttr("out_channel"), "out_channel"); - if ((w_shape[n_axis] != Shape::SHP_ANY) && (w_shape[n_axis] != out_channel)) { - MS_LOG(EXCEPTION) << "w_shape[" << n_axis << "] = " << w_shape[n_axis] << " must be equal to " << out_channel; - } - - const size_t kernel_size_num_element = 2; - std::vector kernel_size = - CheckAttrIntOrTuple(op_name, primitive->GetAttr("kernel_size"), 0, kernel_size_num_element); - if ((w_shape[h_axis] != Shape::SHP_ANY) && (w_shape[h_axis] != kernel_size[0])) { - MS_LOG(EXCEPTION) << "weight height: " << w_shape[h_axis] << " must be equal to " << kernel_size[0]; - } - if ((w_shape[w_axis] != Shape::SHP_ANY) && (w_shape[w_axis] != kernel_size[1])) { - MS_LOG(EXCEPTION) << "weight width: " << w_shape[w_axis] << " must be equal to " << kernel_size[1]; - } - - std::vector stride = - CheckAttrIntOrTuple(op_name, primitive->GetAttr("stride"), stride_start_idx, stride_num_element); - std::vector dilation = - CheckAttrIntOrTuple(op_name, primitive->GetAttr("dilation"), dilation_start_idx, dilation_num_element); - std::vector padding = - CheckAttrIntOrTuple(op_name, primitive->GetAttr("pad"), padding_start_idx, padding_num_element); - int64_t pad_mode; - CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr("pad_mode"), &pad_mode); - std::vector output_hw; - std::vector pad_list; - std::vector output_hw_min; - std::vector pad_list_min; - std::vector output_hw_max; - std::vector pad_list_max; - Conv2DPadFunction(&output_hw, &pad_list, x_shape[h_axis], x_shape[w_axis], kernel_size, stride, dilation, pad_mode, - padding); - if (x_shape[h_axis] == Shape::SHP_ANY) { - output_hw[0] = Shape::SHP_ANY; - } - if (x_shape[w_axis] == Shape::SHP_ANY) { - output_hw[1] = Shape::SHP_ANY; - } - Conv2DPadFunction(&output_hw_min, &pad_list_min, x_min_shape[h_axis], x_min_shape[w_axis], kernel_size, stride, - dilation, pad_mode, padding); - Conv2DPadFunction(&output_hw_max, &pad_list_max, x_max_shape[h_axis], x_max_shape[w_axis], kernel_size, stride, - dilation, pad_mode, padding); - std::vector pad_list_val = {MakeValue(pad_list[0]), MakeValue(pad_list[1]), MakeValue(pad_list[2]), - MakeValue(pad_list[3])}; - primitive->set_attr("pad_list", MakeValue(pad_list_val)); - - ShapeVector output_shape; - ShapeVector output_shape_min; - ShapeVector output_shape_max; - if (data_format == Format::NHWC) { - output_shape = {x_shape[n_axis], output_hw[0], output_hw[1], out_channel}; - output_shape_min = {x_min_shape[n_axis], output_hw_min[0], output_hw_min[1], out_channel}; - output_shape_max = {x_max_shape[n_axis], output_hw_max[0], output_hw_max[1], out_channel}; - } else { - output_shape = {x_shape[n_axis], out_channel, output_hw[0], output_hw[1]}; - output_shape_min = {x_min_shape[n_axis], out_channel, output_hw_min[0], output_hw_min[1]}; - output_shape_max = {x_max_shape[n_axis], out_channel, output_hw_max[0], output_hw_max[1]}; - } - CheckShapeAnyAndPositive(op_name + " output_shape", output_shape); - CheckShapeAllPositive(op_name + " output_shape_min", output_shape_min); - CheckShapeAllPositive(op_name + " output_shape_max", output_shape_max); - - TypePtr x_type = input_x->element()->GetTypeTrack(); - if (x_type->type_id() == TypeId::kNumberTypeInt8) { - x_type = kInt32; - } - ShapePtr output_shape_ptr = std::make_shared(output_shape, output_shape_min, output_shape_max); - return std::make_shared(x_type, output_shape_ptr); -} - AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { // Inputs: at least one tensor(y_backprop) diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index 804271cf1a5..f4e6ee8f4c0 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -186,7 +186,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimPooling, R{InferImplPooling, nullptr, true}}, {prim::kPrimPoolingGrad, R{InferImplPoolingGrad, nullptr, true}}, {prim::kPrimBatchNorm, R{InferImplBatchNorm, nullptr, true}}, - {prim::kPrimConv2D, R{InferImplConv2D, nullptr, true}}, {prim::kPrimBpropCut, R{InferImplBpropCut, nullptr, true}}, {prim::kPrimDropout, R{InferImplDropout, nullptr, true}}, {prim::kPrimSparseApplyFtrl, R{InferImplSparseApplyFtrl, nullptr, true}}, diff --git a/mindspore/core/abstract/utils.cc b/mindspore/core/abstract/utils.cc index 442b2c255eb..e9568f4307a 100644 --- a/mindspore/core/abstract/utils.cc +++ b/mindspore/core/abstract/utils.cc @@ -183,83 +183,6 @@ AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec) { return spec->Clone(); } -namespace { -// Join all types in args_type_list; -TypePtr TypeJoin(const TypePtrList &args_type_list) { - if (args_type_list.empty()) { - MS_LOG(EXCEPTION) << "args_type_list is empty"; - } - - TypePtr type_tmp = args_type_list[0]; - for (std::size_t i = 1; i < args_type_list.size(); i++) { - type_tmp = abstract::TypeJoin(type_tmp, args_type_list[i]); - } - return type_tmp; -} -} // namespace - -bool CheckType(const TypePtr &expected_type, const TypePtr &x) { - // As x and predicate both are mindspore type statically, here we only to judge whether - // x is predicate or is a subclass of predicate. - return IsIdentidityOrSubclass(x, expected_type); -} - -TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list) { - MS_EXCEPTION_IF_NULL(predicate); - for (const auto &arg_type : args_type_list) { - MS_EXCEPTION_IF_NULL(arg_type); - if (!CheckType(predicate, arg_type)) { - MS_LOG(EXCEPTION) << "The expected is " << predicate->ToString() << ", not " << arg_type->ToString(); - } - } - return TypeJoin(args_type_list); -} - -int64_t GetPositiveAxis(int64_t axis_value, size_t increment) { - if (axis_value < 0) { - axis_value = axis_value + SizeToLong(increment); - } - - if (axis_value < 0) { - MS_LOG(EXCEPTION) << "axis_value should not still <0"; - } - - return axis_value; -} - -// Return if two shapes can be broadcast. -// Broadcast shape is placed in broadcast_output_shape. -ShapeVector RealBroadcast(const std::string &op, ShapeVector x_shape, ShapeVector y_shape) { - std::reverse(x_shape.begin(), x_shape.end()); - std::reverse(y_shape.begin(), y_shape.end()); - // Fill a placeholder value 1 which will be replaced later. - size_t std_len = x_shape.size() > y_shape.size() ? x_shape.size() : y_shape.size(); - y_shape.resize(std_len, 1); - x_shape.resize(std_len, 1); - - ShapeVector broadcast_shape; - for (size_t i = 0; i < std_len; i++) { - int64_t x_i = x_shape[i]; // i-th dimension of x - int64_t y_i = y_shape[i]; // i-th dimension of y - int64_t output_i = 0; // i-th dimension of the output - if (x_i == y_i) { - output_i = x_i; - } else if (x_i == 1) { - output_i = y_i; - } else if (y_i == 1) { - output_i = x_i; - } else { - MS_LOG(EXCEPTION) - << op - << " evaluator the shape of first tensor and the shape of second tensor do not meet the broadcasting " - "requirements"; - } - broadcast_shape.push_back(output_i); - } - std::reverse(broadcast_shape.begin(), broadcast_shape.end()); - return broadcast_shape; -} - ShapeVector BroadcastShape(ShapeVector shpx, ShapeVector shpy) { int dlen = SizeToInt(shpx.size()) - SizeToInt(shpy.size()); if (dlen < 0) { diff --git a/mindspore/core/abstract/utils.h b/mindspore/core/abstract/utils.h index 02e9d43f74b..6d8d6546543 100644 --- a/mindspore/core/abstract/utils.h +++ b/mindspore/core/abstract/utils.h @@ -43,20 +43,11 @@ AbstractBasePtrList AbstractJoin(const AbstractBasePtrList &spec1, const Abstrac // else self.Clone; AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec); -TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list); - -bool CheckType(const TypePtr &expected_type, const TypePtr &x); - -int64_t GetPositiveAxis(int64_t axis_value, size_t increment); - ShapeVector BroadcastShape(ShapeVector shpx, ShapeVector shpy); MS_CORE_API size_t TypeIdSize(const TypeId data_type); size_t ShapeSize(const std::vector &shape); -// Get broadcasted shape for binary element-wise operation -ShapePtr GetBroadcastShape(const std::string &op, const AbstractTensorPtr &tensor_x, const AbstractTensorPtr &tensor_y); - // Check dynamic shape routine void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape); diff --git a/mindspore/core/ops/layer_norm.cc b/mindspore/core/ops/layer_norm.cc index 133874cef9b..ce2a1c4a75c 100644 --- a/mindspore/core/ops/layer_norm.cc +++ b/mindspore/core/ops/layer_norm.cc @@ -61,11 +61,10 @@ AbstractBasePtr LayerNormInfer(const abstract::AnalysisEnginePtr &, const Primit // begin_norm_axis and begin_params_axis should be smaller than the size of input_x and >= -1 ValuePtr bna_ptr = primitive->GetAttr("begin_norm_axis"); - int64_t begin_norm_axis = abstract::CheckAxis(op_name, bna_ptr, -1, SizeToLong(input_rank) - 1); + int64_t begin_norm_axis = abstract::CheckAxis(op_name, bna_ptr, -1, SizeToLong(input_rank)); ValuePtr bpa_ptr = primitive->GetAttr("begin_params_axis"); - int64_t begin_params_axis = abstract::CheckAxis(op_name, bpa_ptr, -1, SizeToLong(input_rank) - 1); - begin_params_axis = abstract::GetPositiveAxis(begin_params_axis, input_rank); + int64_t begin_params_axis = abstract::CheckAxis(op_name, bpa_ptr, -1, SizeToLong(input_rank)); // the beta and gama shape should be x_shape[begin_params_axis:] auto valid_types = {kFloat16, kFloat32}; diff --git a/tests/ut/cpp/operator/composite_test.cc b/tests/ut/cpp/operator/composite_test.cc index 499f19452f2..83699da16e0 100644 --- a/tests/ut/cpp/operator/composite_test.cc +++ b/tests/ut/cpp/operator/composite_test.cc @@ -104,7 +104,8 @@ TEST_F(TestComposite, test_TupleSlice_arg_two_numbers) { engine_->Run(tupleSliceGraphPtr, args_spec_list); FAIL() << "Excepted exception :Args type is wrong"; } catch (std::runtime_error const &err) { - ASSERT_TRUE(std::string(err.what()).find("TupleSlice input args size should be 2, but got 3") != std::string::npos); + ASSERT_TRUE(std::string(err.what()).find("TupleSlice input arguments size should be 2, but got 3") != + std::string::npos); } catch (...) { FAIL() << "Excepted exception :Args type is wrong"; } @@ -250,7 +251,7 @@ TEST_F(TestComposite, test_UnpackCall_3args) { MetaFuncGraphPtr unPackCallPtr = std::make_shared("UnPackCall"); FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 3); - auto fn_arg= std::make_shared(prim::kPrimMakeTuple); + auto fn_arg = std::make_shared(prim::kPrimMakeTuple); AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4}); AbstractBasePtrList eles; for (size_t i = 0; i < 6; i++) {