forked from mindspore-Ecosystem/mindspore
!4263 [MS][LITE][Develop]argmax,argmin support keepdim
Merge pull request !4263 from chenjianping/lite_dev
This commit is contained in:
commit
756b834616
|
@ -38,9 +38,9 @@ int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
MS_LOG(ERROR) << "Invalid axis " << argmax_prim->axis() << ", input shape size: " << input_shape_size;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
if (argmax_prim->topK() == 1) {
|
||||
if (argmax_prim->topK() == 1 && !argmax_prim->keepDims()) {
|
||||
output_shape.erase(output_shape.begin() + axis);
|
||||
} else if (argmax_prim->axisType() == 1) {
|
||||
} else {
|
||||
output_shape[axis] = argmax_prim->topK();
|
||||
}
|
||||
|
||||
|
|
|
@ -37,9 +37,9 @@ int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
return RET_PARAM_INVALID;
|
||||
}
|
||||
std::vector<int> output_shape(input->shape());
|
||||
if (argmin_prim->topK() == 1) {
|
||||
if (argmin_prim->topK() == 1 && !argmin_prim->keepDims()) {
|
||||
output_shape.erase(output_shape.begin() + axis);
|
||||
} else if (argmin_prim->axisType() == 1) {
|
||||
} else {
|
||||
output_shape[axis] = argmin_prim->topK();
|
||||
}
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "src/runtime/kernel/arm/nnacl/arg_min_max.h"
|
||||
#include "src/runtime/kernel/arm/fp32/argminmax.h"
|
||||
#include "src/runtime/kernel/arm/int8/argminmax_int8.h"
|
||||
#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/kernel_factory.h"
|
||||
#include "include/errorcode.h"
|
||||
|
@ -60,7 +61,7 @@ int ArgMinMaxBaseCPUKernel::ReSize() {
|
|||
return RET_PARAM_INVALID;
|
||||
}
|
||||
param->topk_ = MSMIN(param->topk_, in_shape[axis]);
|
||||
if (param->topk_ > 1) {
|
||||
if (param->topk_ > 1 || param->keep_dims_) {
|
||||
if (context_ != nullptr && context_->allocator != nullptr) {
|
||||
param->arg_elements_ =
|
||||
reinterpret_cast<ArgElement *>(context_->allocator->Malloc(sizeof(ArgElement) * in_shape[axis]));
|
||||
|
@ -73,6 +74,9 @@ int ArgMinMaxBaseCPUKernel::ReSize() {
|
|||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
ComputeStrides(in_shape.data(), param->in_strides_, in_shape.size());
|
||||
auto out_shape = outputs_.at(0)->shape();
|
||||
ComputeStrides(out_shape.data(), param->out_strides_, out_shape.size());
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -89,7 +89,7 @@ void ArgMinMaxTopknFp32(const float *input, float *output, const int *in_shape,
|
|||
}
|
||||
|
||||
void ArgMinMax(const void *input, void *output, const int *in_shape, ArgMinMaxParameter *param) {
|
||||
if (param->topk_ == 1) {
|
||||
if (param->topk_ == 1 && !param->keep_dims_) {
|
||||
ArgMinMaxTopk1(input, output, in_shape, param);
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -40,6 +40,34 @@ TEST_F(TestArgMinMaxTestFp32, ArgMaxTest1) {
|
|||
param.data_type_ = 43;
|
||||
param.dims_size_ = 2;
|
||||
param.get_max_ = true;
|
||||
param.keep_dims_ = false;
|
||||
ArgMinMax(in.data(), out, shape.data(), ¶m);
|
||||
for (size_t i = 0; i < except_out.size(); ++i) {
|
||||
std::cout << out[i] << " ";
|
||||
}
|
||||
std::cout << "\n";
|
||||
CompareOutputData(out, except_out.data(), except_out.size(), 0.000001);
|
||||
}
|
||||
|
||||
TEST_F(TestArgMinMaxTestFp32, ArgMaxTest1_keep_dim) {
|
||||
std::vector<float> in = {10, 20, 30, 40, 90,
|
||||
20, 11, 15, 1, 50,
|
||||
30, 45, 25, 50, 30};
|
||||
std::vector<float> except_out = {2, 2, 0, 2, 0};
|
||||
std::vector<int> shape = {3, 5};
|
||||
float out[5];
|
||||
ArgMinMaxParameter param;
|
||||
param.topk_ = 1;
|
||||
param.out_value_ = false;
|
||||
param.axis_ = 0;
|
||||
param.data_type_ = 43;
|
||||
param.dims_size_ = 2;
|
||||
param.get_max_ = true;
|
||||
param.keep_dims_ = true;
|
||||
param.arg_elements_ = reinterpret_cast<ArgElement *>(malloc(shape[param.axis_] * sizeof(ArgElement)));
|
||||
std::vector<int> out_shape = {1, 5};
|
||||
ComputeStrides(shape.data(), param.in_strides_, shape.size());
|
||||
ComputeStrides(out_shape.data(), param.out_strides_, out_shape.size());
|
||||
ArgMinMax(in.data(), out, shape.data(), ¶m);
|
||||
for (size_t i = 0; i < except_out.size(); ++i) {
|
||||
std::cout << out[i] << " ";
|
||||
|
@ -62,6 +90,7 @@ TEST_F(TestArgMinMaxTestFp32, ArgMaxTest2) {
|
|||
param.data_type_ = 43;
|
||||
param.dims_size_ = 2;
|
||||
param.get_max_ = true;
|
||||
param.keep_dims_ = false;
|
||||
ArgMinMax(in.data(), out, shape.data(), ¶m);
|
||||
CompareOutputData(out, except_out.data(), except_out.size(), 0.000001);
|
||||
}
|
||||
|
@ -80,6 +109,7 @@ TEST_F(TestArgMinMaxTestFp32, ArgMinTest2) {
|
|||
param.data_type_ = 43;
|
||||
param.dims_size_ = 2;
|
||||
param.get_max_ = false;
|
||||
param.keep_dims_ = false;
|
||||
ArgMinMax(in.data(), out, shape.data(), ¶m);
|
||||
CompareOutputData(out, except_out.data(), except_out.size(), 0.000001);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue