From a9028b24e1a530497eb2026d60c99cc1fdca3a15 Mon Sep 17 00:00:00 2001
From: yanghaoran
Date: Sat, 30 Jul 2022 10:40:21 +0000
Subject: [PATCH] =?UTF-8?q?=E5=9B=9E=E9=80=80=20'Pull=20Request=20!39176?=
 =?UTF-8?q?=20:=20clean=20code=20of=20Argmin'?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../device/cpu/kernel/argmin_cpu_kernel.cc    | 28 +++++++++++--------
 mindspore/ccsrc/pybind_api/ir/primitive_py.cc |  2 ++
 mindspore/core/ops/arg_min.cc                 |  3 +-
 mindspore/core/ops/arg_min_v2.cc              |  7 ++---
 mindspore/core/ops/arg_min_v2.h               |  1 +
 5 files changed, 23 insertions(+), 18 deletions(-)

diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/argmin_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/argmin_cpu_kernel.cc
index 5e6a1e34463..778672bf8de 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/argmin_cpu_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/argmin_cpu_kernel.cc
@@ -16,6 +16,7 @@
 #include "plugin/device/cpu/kernel/argmin_cpu_kernel.h"
 #include
 #include
+#include
 #include "mindspore/core/ops/arg_min.h"
 #include "plugin/device/cpu/hal/device/cpu_device_address.h"
 
@@ -28,8 +29,8 @@ constexpr char kKernelName[] = "Argmin";
 
 int64_t get_element_num(const std::vector &shape) {
   int64_t size = 1;
-  for (size_t i = 0; i < shape.size(); i++) {
-    size *= UlongToLong(shape[i]);
+  for (int64_t i = 0; i < static_cast(shape.size()); i++) {
+    size *= shape[i];
   }
   return size;
 }
@@ -40,11 +41,11 @@ bool check_validation(const std::vector &shape, const int64_t num_befor
                       const std::vector &inputs, const std::vector &outputs) {
   CHECK_KERNEL_INPUTS_NUM(inputs.size(), kArgMinInputsNum, kKernelName);
   CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kArgMinOutputsNum, kKernelName);
-  int64_t inputs_size = get_element_num(shape) * static_cast(sizeof(T));
+  int64_t input_size = get_element_num(shape) * static_cast(sizeof(T));
   int64_t output_num = num_before_axis * num_after_axis;
   int64_t output_size = output_num * static_cast(sizeof(S));
-  if (static_cast(inputs[0]->size) != inputs_size) {
-    MS_LOG(EXCEPTION) << "For '" << kKernelName << "', the memory size of 'input_x' must be equal to " << inputs_size
+  if (static_cast(inputs[0]->size) != input_size) {
+    MS_LOG(EXCEPTION) << "For '" << kKernelName << "', the memory size of 'input_x' must be equal to " << input_size
                       << ", but got the memory size is " << inputs[0]->size;
   }
   if (static_cast(outputs[0]->size) != output_size) {
@@ -58,7 +59,9 @@
 template
 bool ArgminCpuKernelMod::LaunchKernel(const std::vector &inputs, const std::vector &workspace,
                                       const std::vector &outputs) {
-  (void)check_validation(shape_, num_before_axis_, num_after_axis_, inputs, outputs);
+  if (!check_validation(shape_, num_before_axis_, num_after_axis_, inputs, outputs)) {
+    return false;
+  }
 
   const auto *input = reinterpret_cast(inputs[0]->addr);
   auto ids_addr = reinterpret_cast(workspace[0]->addr);
@@ -81,7 +84,7 @@ bool ArgminCpuKernelMod::LaunchKernel(const std::vector &inp
 
       auto min_ops = std::min_element(idx, idx + axis_size, comparator);
       auto min_index = iter.RevertPos(*min_ops);
-      output[index] = static_cast(min_index);
+      output[index] = min_index;
     }
   };
   ParallelLaunchAutoSearch(task, axisIterator_.OuterSize() * axisIterator_.InnerSize(), this, &parallel_search_info_);
@@ -106,6 +109,7 @@ bool ArgminCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::v
   auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
   if (!is_match) {
     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
+    return false;
   }
   kernel_func_ = func_list_[index].second;
   return true;
@@ -139,11 +143,11 @@ int ArgminCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::
   axis_ = axis_ % SizeToLong(shape_len);
   num_before_axis_ = 1;
   num_after_axis_ = 1;
-  for (size_t index = 0; index < shape_len; index++) {
-    if (SizeToLong(index) < axis_) {
-      num_before_axis_ *= shape_[index];
-    } else if (SizeToLong(index) > axis_) {
-      num_after_axis_ *= shape_[index];
+  for (size_t i = 0; i < shape_len; i++) {
+    if (SizeToLong(i) < axis_) {
+      num_before_axis_ *= shape_[i];
+    } else if (SizeToLong(i) > axis_) {
+      num_after_axis_ *= shape_[i];
     }
   }
   dim_axis_ = shape_[LongToSize(axis_)];
diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc
index 3a1880a783c..c2e88f7eb6a 100644
--- a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc
+++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc
@@ -175,6 +175,8 @@ py::function PrimitivePy::GetBpropFunction() {
     py::function fn = python_obj_.attr(get_bprop_func_name)().cast();
     return fn;
   }
+  auto fn = GetBpropFunctionByObj(python_obj_);
+  return fn;
 }
 
 py::function PrimitivePy::GetTaylorRuleFunction() {
diff --git a/mindspore/core/ops/arg_min.cc b/mindspore/core/ops/arg_min.cc
index 9ab2ae8051a..72e3cf46e22 100644
--- a/mindspore/core/ops/arg_min.cc
+++ b/mindspore/core/ops/arg_min.cc
@@ -101,8 +101,7 @@ abstract::AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const
                                       const std::vector &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   const int64_t input_num = 1;
-  (void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kEqual, input_num,
-                                           primitive->name());
+  CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kEqual, input_num, primitive->name());
   return abstract::MakeAbstract(ArgMinInferShape(primitive, input_args), ArgMinInferType(primitive, input_args));
 }
 
diff --git a/mindspore/core/ops/arg_min_v2.cc b/mindspore/core/ops/arg_min_v2.cc
index f0a8c6538ee..79c6c2f45cc 100644
--- a/mindspore/core/ops/arg_min_v2.cc
+++ b/mindspore/core/ops/arg_min_v2.cc
@@ -53,9 +53,9 @@ void InferImplReduceFuncCalShape(const PrimitivePtr &primitive, ShapeVector *sha
     MS_EXCEPTION_IF_NULL(axis_ptr_list);
     axis_ptr_value_list = axis_ptr_list->value();
   }
-  if (axis_ptr_value_list.size() < 1) {
+  if (!axis_ptr_value_list.size()) {
     MS_LOG(EXCEPTION) << "For '" << primitive->name()
-                      << "', element of 'axis' must not be none if it is one of these types: [tuple/list].";
+                      << "', element of 'axis' must not be noe if it is one of these types: [tuple/list].";
   } else {
     (void)shape->insert(shape->end(), x_shape.begin(), x_shape.end());
     ValuePtrList axis_items = axis_ptr_value_list;
@@ -152,8 +152,7 @@ abstract::AbstractBasePtr ArgminV2Infer(const abstract::AnalysisEnginePtr &, con
                                         const std::vector &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   const int64_t input_num = 2;
-  (void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kEqual, input_num,
-                                           primitive->name());
+  CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kEqual, input_num, primitive->name());
   return abstract::MakeAbstract(ArgminV2InferShape(primitive, input_args), ArgminV2InferType(primitive, input_args));
 }
 
diff --git a/mindspore/core/ops/arg_min_v2.h b/mindspore/core/ops/arg_min_v2.h
index 0cac7cf0014..33a96a82f9a 100644
--- a/mindspore/core/ops/arg_min_v2.h
+++ b/mindspore/core/ops/arg_min_v2.h
@@ -22,6 +22,7 @@
 
 #include "ops/base_operator.h"
 #include "mindapi/base/types.h"
+#include "mindapi/base/type_id.h"
 
 namespace mindspore {
 namespace ops {