forked from mindspore-Ecosystem/mindspore
回退 'Pull Request !39176 : clean code of Argmin'
This commit is contained in:
parent
cef8acb08b
commit
a9028b24e1
|
@ -16,6 +16,7 @@
|
|||
#include "plugin/device/cpu/kernel/argmin_cpu_kernel.h"
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include "mindspore/core/ops/arg_min.h"
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
|
||||
|
@ -28,8 +29,8 @@ constexpr char kKernelName[] = "Argmin";
|
|||
|
||||
int64_t get_element_num(const std::vector<int64_t> &shape) {
|
||||
int64_t size = 1;
|
||||
for (size_t i = 0; i < shape.size(); i++) {
|
||||
size *= UlongToLong(shape[i]);
|
||||
for (int64_t i = 0; i < static_cast<int64_t>(shape.size()); i++) {
|
||||
size *= shape[i];
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
@ -40,11 +41,11 @@ bool check_validation(const std::vector<int64_t> &shape, const int64_t num_befor
|
|||
const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kArgMinInputsNum, kKernelName);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kArgMinOutputsNum, kKernelName);
|
||||
int64_t inputs_size = get_element_num(shape) * static_cast<int64_t>(sizeof(T));
|
||||
int64_t input_size = get_element_num(shape) * static_cast<int64_t>(sizeof(T));
|
||||
int64_t output_num = num_before_axis * num_after_axis;
|
||||
int64_t output_size = output_num * static_cast<int64_t>(sizeof(S));
|
||||
if (static_cast<int64_t>(inputs[0]->size) != inputs_size) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kKernelName << "', the memory size of 'input_x' must be equal to " << inputs_size
|
||||
if (static_cast<int64_t>(inputs[0]->size) != input_size) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kKernelName << "', the memory size of 'input_x' must be equal to " << input_size
|
||||
<< ", but got the memory size is " << inputs[0]->size;
|
||||
}
|
||||
if (static_cast<int64_t>(outputs[0]->size) != output_size) {
|
||||
|
@ -58,7 +59,9 @@ template <typename T, typename S>
|
|||
bool ArgminCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
(void)check_validation<T, S>(shape_, num_before_axis_, num_after_axis_, inputs, outputs);
|
||||
if (!check_validation<T, S>(shape_, num_before_axis_, num_after_axis_, inputs, outputs)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto *input = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto ids_addr = reinterpret_cast<size_t *>(workspace[0]->addr);
|
||||
|
@ -81,7 +84,7 @@ bool ArgminCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inp
|
|||
auto min_ops = std::min_element(idx, idx + axis_size, comparator);
|
||||
auto min_index = iter.RevertPos(*min_ops);
|
||||
|
||||
output[index] = static_cast<S>(min_index);
|
||||
output[index] = min_index;
|
||||
}
|
||||
};
|
||||
ParallelLaunchAutoSearch(task, axisIterator_.OuterSize() * axisIterator_.InnerSize(), this, ¶llel_search_info_);
|
||||
|
@ -106,6 +109,7 @@ bool ArgminCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::v
|
|||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
return true;
|
||||
|
@ -139,11 +143,11 @@ int ArgminCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::
|
|||
axis_ = axis_ % SizeToLong(shape_len);
|
||||
num_before_axis_ = 1;
|
||||
num_after_axis_ = 1;
|
||||
for (size_t index = 0; index < shape_len; index++) {
|
||||
if (SizeToLong(index) < axis_) {
|
||||
num_before_axis_ *= shape_[index];
|
||||
} else if (SizeToLong(index) > axis_) {
|
||||
num_after_axis_ *= shape_[index];
|
||||
for (size_t i = 0; i < shape_len; i++) {
|
||||
if (SizeToLong(i) < axis_) {
|
||||
num_before_axis_ *= shape_[i];
|
||||
} else if (SizeToLong(i) > axis_) {
|
||||
num_after_axis_ *= shape_[i];
|
||||
}
|
||||
}
|
||||
dim_axis_ = shape_[LongToSize(axis_)];
|
||||
|
|
|
@ -175,6 +175,8 @@ py::function PrimitivePy::GetBpropFunction() {
|
|||
py::function fn = python_obj_.attr(get_bprop_func_name)().cast<py::function>();
|
||||
return fn;
|
||||
}
|
||||
auto fn = GetBpropFunctionByObj(python_obj_);
|
||||
return fn;
|
||||
}
|
||||
|
||||
py::function PrimitivePy::GetTaylorRuleFunction() {
|
||||
|
|
|
@ -101,8 +101,7 @@ abstract::AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const
|
|||
const std::vector<abstract::AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 1;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kEqual, input_num,
|
||||
primitive->name());
|
||||
CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kEqual, input_num, primitive->name());
|
||||
return abstract::MakeAbstract(ArgMinInferShape(primitive, input_args), ArgMinInferType(primitive, input_args));
|
||||
}
|
||||
|
||||
|
|
|
@ -53,9 +53,9 @@ void InferImplReduceFuncCalShape(const PrimitivePtr &primitive, ShapeVector *sha
|
|||
MS_EXCEPTION_IF_NULL(axis_ptr_list);
|
||||
axis_ptr_value_list = axis_ptr_list->value();
|
||||
}
|
||||
if (axis_ptr_value_list.size() < 1) {
|
||||
if (!axis_ptr_value_list.size()) {
|
||||
MS_LOG(EXCEPTION) << "For '" << primitive->name()
|
||||
<< "', element of 'axis' must not be none if it is one of these types: [tuple/list].";
|
||||
<< "', element of 'axis' must not be noe if it is one of these types: [tuple/list].";
|
||||
} else {
|
||||
(void)shape->insert(shape->end(), x_shape.begin(), x_shape.end());
|
||||
ValuePtrList axis_items = axis_ptr_value_list;
|
||||
|
@ -152,8 +152,7 @@ abstract::AbstractBasePtr ArgminV2Infer(const abstract::AnalysisEnginePtr &, con
|
|||
const std::vector<abstract::AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kEqual, input_num,
|
||||
primitive->name());
|
||||
CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kEqual, input_num, primitive->name());
|
||||
return abstract::MakeAbstract(ArgminV2InferShape(primitive, input_args), ArgminV2InferType(primitive, input_args));
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
#include "mindapi/base/type_id.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
|
Loading…
Reference in New Issue