回退 'Pull Request !39176 : clean code of Argmin'

This commit is contained in:
yanghaoran 2022-07-30 10:40:21 +00:00 committed by Gitee
parent cef8acb08b
commit a9028b24e1
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 23 additions and 18 deletions

View File

@ -16,6 +16,7 @@
#include "plugin/device/cpu/kernel/argmin_cpu_kernel.h"
#include <string>
#include <algorithm>
#include <utility>
#include "mindspore/core/ops/arg_min.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
@ -28,8 +29,8 @@ constexpr char kKernelName[] = "Argmin";
int64_t get_element_num(const std::vector<int64_t> &shape) {
int64_t size = 1;
for (size_t i = 0; i < shape.size(); i++) {
size *= UlongToLong(shape[i]);
for (int64_t i = 0; i < static_cast<int64_t>(shape.size()); i++) {
size *= shape[i];
}
return size;
}
@ -40,11 +41,11 @@ bool check_validation(const std::vector<int64_t> &shape, const int64_t num_befor
const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kArgMinInputsNum, kKernelName);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kArgMinOutputsNum, kKernelName);
int64_t inputs_size = get_element_num(shape) * static_cast<int64_t>(sizeof(T));
int64_t input_size = get_element_num(shape) * static_cast<int64_t>(sizeof(T));
int64_t output_num = num_before_axis * num_after_axis;
int64_t output_size = output_num * static_cast<int64_t>(sizeof(S));
if (static_cast<int64_t>(inputs[0]->size) != inputs_size) {
MS_LOG(EXCEPTION) << "For '" << kKernelName << "', the memory size of 'input_x' must be equal to " << inputs_size
if (static_cast<int64_t>(inputs[0]->size) != input_size) {
MS_LOG(EXCEPTION) << "For '" << kKernelName << "', the memory size of 'input_x' must be equal to " << input_size
<< ", but got the memory size is " << inputs[0]->size;
}
if (static_cast<int64_t>(outputs[0]->size) != output_size) {
@ -58,7 +59,9 @@ template <typename T, typename S>
bool ArgminCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs) {
(void)check_validation<T, S>(shape_, num_before_axis_, num_after_axis_, inputs, outputs);
if (!check_validation<T, S>(shape_, num_before_axis_, num_after_axis_, inputs, outputs)) {
return false;
}
const auto *input = reinterpret_cast<T *>(inputs[0]->addr);
auto ids_addr = reinterpret_cast<size_t *>(workspace[0]->addr);
@ -81,7 +84,7 @@ bool ArgminCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inp
auto min_ops = std::min_element(idx, idx + axis_size, comparator);
auto min_index = iter.RevertPos(*min_ops);
output[index] = static_cast<S>(min_index);
output[index] = min_index;
}
};
ParallelLaunchAutoSearch(task, axisIterator_.OuterSize() * axisIterator_.InnerSize(), this, &parallel_search_info_);
@ -106,6 +109,7 @@ bool ArgminCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::v
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
return false;
}
kernel_func_ = func_list_[index].second;
return true;
@ -139,11 +143,11 @@ int ArgminCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::
axis_ = axis_ % SizeToLong(shape_len);
num_before_axis_ = 1;
num_after_axis_ = 1;
for (size_t index = 0; index < shape_len; index++) {
if (SizeToLong(index) < axis_) {
num_before_axis_ *= shape_[index];
} else if (SizeToLong(index) > axis_) {
num_after_axis_ *= shape_[index];
for (size_t i = 0; i < shape_len; i++) {
if (SizeToLong(i) < axis_) {
num_before_axis_ *= shape_[i];
} else if (SizeToLong(i) > axis_) {
num_after_axis_ *= shape_[i];
}
}
dim_axis_ = shape_[LongToSize(axis_)];

View File

@ -175,6 +175,8 @@ py::function PrimitivePy::GetBpropFunction() {
py::function fn = python_obj_.attr(get_bprop_func_name)().cast<py::function>();
return fn;
}
auto fn = GetBpropFunctionByObj(python_obj_);
return fn;
}
py::function PrimitivePy::GetTaylorRuleFunction() {

View File

@ -101,8 +101,7 @@ abstract::AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const
const std::vector<abstract::AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 1;
(void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kEqual, input_num,
primitive->name());
CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kEqual, input_num, primitive->name());
return abstract::MakeAbstract(ArgMinInferShape(primitive, input_args), ArgMinInferType(primitive, input_args));
}

View File

@ -53,9 +53,9 @@ void InferImplReduceFuncCalShape(const PrimitivePtr &primitive, ShapeVector *sha
MS_EXCEPTION_IF_NULL(axis_ptr_list);
axis_ptr_value_list = axis_ptr_list->value();
}
if (axis_ptr_value_list.size() < 1) {
if (!axis_ptr_value_list.size()) {
MS_LOG(EXCEPTION) << "For '" << primitive->name()
<< "', element of 'axis' must not be none if it is one of these types: [tuple/list].";
<< "', element of 'axis' must not be noe if it is one of these types: [tuple/list].";
} else {
(void)shape->insert(shape->end(), x_shape.begin(), x_shape.end());
ValuePtrList axis_items = axis_ptr_value_list;
@ -152,8 +152,7 @@ abstract::AbstractBasePtr ArgminV2Infer(const abstract::AnalysisEnginePtr &, con
const std::vector<abstract::AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kEqual, input_num,
primitive->name());
CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kEqual, input_num, primitive->name());
return abstract::MakeAbstract(ArgminV2InferShape(primitive, input_args), ArgminV2InferType(primitive, input_args));
}

View File

@ -22,6 +22,7 @@
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
#include "mindapi/base/type_id.h"
namespace mindspore {
namespace ops {