!8589 add opencl argmin and argmax

From: @ddwsky
Reviewed-by: @zhanghaibo5, @HilbertDavid
Signed-off-by: @HilbertDavid
mindspore-ci-bot 2020-11-16 18:31:17 +08:00 committed by Gitee
commit f1e0a673ac
7 changed files with 601 additions and 4 deletions

View File: mindspore/lite/src/runtime/kernel/opencl/cl/argminmax.cl

@@ -0,0 +1,72 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define swap(a, b, c) \
c = a; \
a = b; \
b = c;
#define swap_atomic(a, b, c) \
c = atomic_xchg(a, *(b)); \
c = atomic_xchg(b, c);
#define UP_ROUND(a, b) ((((a) + (b) - 1) / (b)) * (b))
#define C4NUM 4
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__kernel void argminmax(__global FLT *src_data, __global FLT *dst_data, __global FLT *buf, __global int *ids,
int4 shape, int4 src_size, int4 cus_size, int4 strides, int4 flags) {
int X = get_global_id(0); // reduce len
int Y = get_global_id(1); // upper axis accumulation
if (X >= src_size.x || Y >= src_size.y) {
return;
}
int offset = X + Y * src_size.z;
int align_c4 = (flags.z != 3) ? (X / shape.w) * (shape.x) : 0;
int align_in = 0;
int align_out = 0;
if (flags.z == 3) {
align_in = (Y / shape.z) * cus_size.z;
align_out = (Y / shape.z) * cus_size.w;
}
if (flags.z == 0) {
align_in = X / (shape.y) * cus_size.z;
align_out = align_in;
}
for (int k = 0; k < src_size.w; ++k) {
int idx0 = (X + k * strides.x) + Y * strides.y + (align_c4 + align_in);
int idx1 = offset + k * src_size.x;
ids[idx1] = k;
buf[idx1] = src_data[idx0];
}
for (unsigned int i = 2; i <= cus_size.x; i <<= 1) {
for (unsigned int j = i >> 1; j > 0; j >>= 1) {
for (int tid = 0; tid < src_size.w; ++tid) {
unsigned int tid_comp = tid + j;
if (tid_comp < src_size.w) {
int lk = offset + tid * src_size.x;
int rk = offset + tid_comp * src_size.x;
if ((tid & i) == 0) { // ascending
if (buf[lk] > buf[rk]) {
FLT tmpf;
swap(buf[lk], buf[rk], tmpf);
int tmpi;
swap(ids[lk], ids[rk], tmpi);
}
} else { // descending
if (buf[lk] < buf[rk]) {
FLT tmpf;
swap(buf[lk], buf[rk], tmpf);
int tmpi;
swap(ids[lk], ids[rk], tmpi);
}
}
}
}
}
}
for (int k = 0; k < flags.w; ++k) {
int idx0 = (X + k * strides.z) + Y * strides.w + (align_c4 + align_out);
int idx1 = flags.y ? (offset + (src_size.w - k - 1) * src_size.x) : (offset + k * src_size.x);
if (flags.x) {
dst_data[idx0] = buf[idx1];
} else {
dst_data[idx0] = ids[idx1];
}
}
}

View File: mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.cc

@@ -0,0 +1,164 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstring>
#include <string>
#include <algorithm>
#include <set>
#include <utility>
#include <functional>
#include <numeric>
#include "src/kernel_registry.h"
#include "src/runtime/kernel/opencl/utils.h"
#include "src/runtime/kernel/opencl/kernel/argminmax.h"
#include "src/runtime/kernel/opencl/cl/argminmax.cl.inc"
using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_ArgMax;
using mindspore::schema::PrimitiveType_ArgMin;
namespace mindspore::kernel {
int ArgMinMaxOpenCLKernel::CheckSpecs() {
if (in_tensors_[0]->data_type() != kNumberTypeFloat32 && in_tensors_[0]->data_type() != kNumberTypeFloat16) {
MS_LOG(ERROR) << "Unsupported data type " << in_tensors_[0]->data_type();
return RET_ERROR;
}
if (in_tensors_[0]->shape().size() > 4 || in_tensors_[0]->shape().size() == 0) {
MS_LOG(ERROR) << "input shape size must be (1-4), actual: " << in_tensors_[0]->shape().size() << ", "
<< out_tensors_[0]->shape().size();
return RET_ERROR;
}
auto *param = reinterpret_cast<ArgMinMaxParameter *>(this->op_parameter_);
param->dims_size_ = in_tensors_[0]->shape().size();
param->axis_ = (param->axis_ + param->dims_size_) % param->dims_size_;
if (param->axis_ < 0 || param->axis_ >= param->dims_size_) {
MS_LOG(ERROR) << "Invalid axis " << param->axis_;
return RET_ERROR;
}
param->get_max_ = (op_parameter_->type_ == PrimitiveType_ArgMax);
return RET_OK;
}
void ArgMinMaxOpenCLKernel::SetConstArgs() {
auto param = reinterpret_cast<ArgMinMaxParameter *>(op_parameter_);
cl_int4 in_shape{static_cast<int>(im_in_.N), static_cast<int>(im_in_.H), static_cast<int>(im_in_.W),
static_cast<int>(im_in_.C)};
in_shape.s[0] = UP_ROUND(im_in_.C, C4NUM) - im_in_.C;  // channel padding up to a multiple of C4NUM
in_shape.s[1] = im_in_.W * im_in_.C;                   // elements per image row (W * C)
cl_int4 flags = {param->out_value_, param->get_max_, param->axis_, param->topk_};
int arg_cnt = 2;
ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, buff_, lite::opencl::MemType::BUF);
ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, ids_, lite::opencl::MemType::BUF);
ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, in_shape);
ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, src_size_);
ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, cus_size_);
ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, strides_);
ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, flags);
}
void ArgMinMaxOpenCLKernel::SetGlobalLocal() {
auto param = reinterpret_cast<ArgMinMaxParameter *>(op_parameter_);
auto in_shape = in_tensors_[0]->shape();
auto in_shape_align = in_shape;
in_shape_align[3] = UP_ROUND(in_shape[3], C4NUM);
im_in_ = Image2DInfo(in_tensors_[0]);
auto out_shape_align = in_shape_align;
out_shape_align.at(param->axis_) = param->axis_ == 3 ? UP_ROUND(param->topk_, C4NUM) : param->topk_;
int reduce_len = GetUpPow2(in_shape.at(param->axis_));
cus_size_ = {reduce_len, static_cast<int>(im_in_.RowPitch() / C4NUM), 1, 1};
cus_size_.s[2] = UP_ROUND(im_in_.width * C4NUM, cus_size_.s[1]) - im_in_.width * C4NUM;
cus_size_.s[3] = im_in_.W * UP_ROUND(param->topk_, C4NUM);
cus_size_.s[3] = UP_ROUND(cus_size_.s[3], cus_size_.s[1]) - cus_size_.s[3];
src_size_ = {std::accumulate(in_shape.begin() + param->axis_ + 1, in_shape.end(), 1, std::multiplies<int>()),
std::accumulate(in_shape.begin(), in_shape.begin() + param->axis_, 1, std::multiplies<int>()),
std::accumulate(in_shape.begin() + param->axis_, in_shape.end(), 1, std::multiplies<int>()),
in_shape.at(param->axis_)};
strides_ = {
std::accumulate(in_shape_align.begin() + param->axis_ + 1, in_shape_align.end(), 1, std::multiplies<int>()),
std::accumulate(in_shape_align.begin() + param->axis_, in_shape_align.end(), 1, std::multiplies<int>()),
std::accumulate(out_shape_align.begin() + param->axis_ + 1, out_shape_align.end(), 1, std::multiplies<int>()),
std::accumulate(out_shape_align.begin() + param->axis_, out_shape_align.end(), 1, std::multiplies<int>()),
};
switch (param->axis_) {
case 0:
strides_.s[0] = UP_ROUND(strides_.s[0] / im_in_.H, cus_size_.s[1]) * im_in_.H;
strides_.s[1] = strides_.s[0] * im_in_.N;
strides_.s[2] = UP_ROUND(strides_.s[2] / im_in_.H, cus_size_.s[1]) * im_in_.H;
strides_.s[3] = strides_.s[2] * param->topk_;
break;
case 1:
strides_.s[0] = UP_ROUND(strides_.s[0], cus_size_.s[1]);
strides_.s[1] = UP_ROUND(strides_.s[1] / im_in_.H, cus_size_.s[1]) * im_in_.H;
strides_.s[2] = UP_ROUND(strides_.s[2], cus_size_.s[1]);
strides_.s[3] = UP_ROUND(strides_.s[3] / param->topk_, cus_size_.s[1]) * param->topk_;
break;
case 2:
strides_.s[1] = UP_ROUND(strides_.s[1], cus_size_.s[1]);
strides_.s[3] = UP_ROUND(strides_.s[3], cus_size_.s[1]);
break;
default: // 3
break;
}
std::vector<size_t> local = {1, 1, 1};
std::vector<size_t> global = {static_cast<size_t>(strides_.s[0]), static_cast<size_t>(src_size_.s[1]), 1};
OpenCLKernel::AlignGlobalLocal(global, local);
}
int ArgMinMaxOpenCLKernel::InitWeights() {
auto allocator = ocl_runtime_->GetAllocator();
int dtype_size = ocl_runtime_->GetFp16Enable() ? sizeof(int16_t) : sizeof(float);
buff_ = allocator->Malloc(in_tensors_[0]->ElementsNum() * dtype_size);
ids_ = allocator->Malloc(in_tensors_[0]->ElementsNum() * sizeof(int32_t));
return RET_OK;
}
int ArgMinMaxOpenCLKernel::Prepare() {
std::string kernel_name = "argminmax";
#ifdef PROGRAM_WITH_IL
kernel_ = ocl_runtime_->GetKernelFromBinary(kernel_name);
#else
std::set<std::string> build_options;
std::string source = argminmax_source;
std::string program_name = "argminmax";
ocl_runtime_->LoadSource(program_name, source);
ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options);
#endif
InitWeights();
SetGlobalLocal();
SetConstArgs();
MS_LOG(DEBUG) << kernel_name << " Init Done!";
return RET_OK;
}
int ArgMinMaxOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->name() << " Running! ";
ocl_runtime_->SetKernelArg(kernel_, 0, in_tensors_[0]->data_c(), lite::opencl::MemType::BUF);
ocl_runtime_->SetKernelArg(kernel_, 1, out_tensors_[0]->data_c(), lite::opencl::MemType::BUF);
ocl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr);
return RET_OK;
}
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_ArgMin, OpenCLKernelCreator<ArgMinMaxOpenCLKernel>);
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_ArgMin, OpenCLKernelCreator<ArgMinMaxOpenCLKernel>);
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_ArgMax, OpenCLKernelCreator<ArgMinMaxOpenCLKernel>);
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_ArgMax, OpenCLKernelCreator<ArgMinMaxOpenCLKernel>);
} // namespace mindspore::kernel

View File: mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.h

@@ -0,0 +1,52 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_ARGMINMAX_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_ARGMINMAX_H_
#include <vector>
#include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "nnacl/arg_min_max_parameter.h"
namespace mindspore::kernel {
class ArgMinMaxOpenCLKernel : public OpenCLKernel {
public:
ArgMinMaxOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs)
: OpenCLKernel(parameter, inputs, outputs) {}
~ArgMinMaxOpenCLKernel() override = default;
int Run() override;
int Prepare() override;
int CheckSpecs() override;
void SetConstArgs() override;
void SetGlobalLocal() override;
int InitWeights() override;
private:
cl::Kernel kernel_;
void *buff_{nullptr};
void *ids_{nullptr};
Image2DInfo im_in_{Image2DInfo(nullptr)};
cl_int4 src_size_;
cl_int4 cus_size_;
cl_int4 strides_;
};
} // namespace mindspore::kernel
#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_ARGMINMAX_H_

View File: mindspore/lite/src/runtime/kernel/opencl/utils.cc

@@ -40,6 +40,17 @@ kernel::LiteKernel *GetOpenCLKernel(const std::vector<Tensor *> &in_tensors, con
namespace mindspore::kernel {
int GetUpPow2(int n) {
int i = 0;  // bit length of n
int j = 0;  // number of set bits in n
while (n > 0) {
j += n & 1;
n = n >> 1;
i++;
}
return 1 << (i - (j == 1));  // n itself if n is a power of two, else the next power of two
}
int GetMaxDivisor(int x, int divisor) {
int i = divisor;
while (i > 0) {

View File: mindspore/lite/src/runtime/kernel/opencl/utils.h

@@ -34,6 +34,8 @@ kernel::LiteKernel *GetOpenCLKernel(const std::vector<Tensor *> &in_tensors, con
namespace mindspore::kernel {
int GetUpPow2(int n);
int GetMaxDivisor(int x, int divisor);
int GetMaxDivisorStrategy0(int x, int divisor);

View File: mindspore/lite/src/runtime/opencl/opencl_runtime.cc

@@ -157,7 +157,19 @@ int OpenCLRuntime::Init() {
}
#else
MS_LOG(INFO) << "Create common opencl context";
#ifdef Debug
std::vector<cl_context_properties> ctx_properties = {CL_CONTEXT_PLATFORM,
(cl_context_properties)platforms[0](),
CL_PRINTF_CALLBACK_ARM,
(cl_context_properties)printf_callback,
CL_PRINTF_BUFFERSIZE_ARM,
0x1000000,
0};
context_ =
new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, ctx_properties.data(), nullptr, nullptr, &ret);
#else
context_ = new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, nullptr, nullptr, nullptr, &ret);
#endif
#endif
if (ret != CL_SUCCESS) {
delete device_;
@@ -201,9 +213,10 @@ int OpenCLRuntime::Init() {
MS_LOG(INFO) << "Compute Unit: " << compute_units_;
MS_LOG(INFO) << "Clock Frequency: " << max_freq_ << " MHz";
#ifdef Debug
const cl_command_queue_properties properties = CL_QUEUE_PROFILING_ENABLE;
#else
const cl_command_queue_properties properties = 0;
#if MS_OPENCL_PROFILE
properties |= CL_QUEUE_PROFILING_ENABLE;
#endif
default_command_queue_ = new (std::nothrow) cl::CommandQueue(*context_, *device_, properties, &ret);
@@ -412,7 +425,7 @@ int OpenCLRuntime::RunKernel(const cl::Kernel &kernel, const std::vector<size_t>
}
cnt++;
MS_LOG(DEBUG) << "RunKernel success!";
#if MS_OPENCL_PROFILE
#ifdef Debug
event.wait();
cl_ulong time_start;
cl_ulong time_end;
@@ -445,7 +458,7 @@ int OpenCLRuntime::RunKernel(const cl::Kernel &kernel, const cl::NDRange &global
}
cnt++;
MS_LOG(DEBUG) << "RunKernel success!";
#if MS_OPENCL_PROFILE
#ifdef Debug
event.wait();
cl_ulong time_start;
cl_ulong time_end;

View File: mindspore/lite/test/ut/src/runtime/kernel/opencl/argminmax_tests.cc

@@ -0,0 +1,283 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include "src/common/log_adapter.h"
#include "common/common_test.h"
#include "src/runtime/kernel/opencl/utils.h"
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.h"
namespace mindspore {
class TestArgMinMaxOpenCL : public mindspore::CommonTest {
public:
TestArgMinMaxOpenCL() {}
};
template <typename T>
void test_main_argminmax(void *input_data, void *correct_data, const std::vector<int> &input_shape,
const std::vector<int> &output_shape, ArgMinMaxParameter *param, TypeId data_type,
schema::Format format) {
MS_LOG(INFO) << " begin test ";
auto ocl_runtime_wrap = lite::opencl::OpenCLRuntimeWrapper();
auto ocl_runtime = ocl_runtime_wrap.GetInstance();
ocl_runtime->Init();
auto allocator = ocl_runtime->GetAllocator();
auto tensor_a = lite::Tensor(TypeId(data_type), input_shape, format);
auto tensor_c = lite::Tensor(TypeId(data_type), output_shape, format);
std::vector<lite::Tensor *> inputs{&tensor_a};
std::vector<lite::Tensor *> outputs{&tensor_c};
size_t input_size = tensor_a.Size();
auto *pkernel =
new (std::nothrow) kernel::ArgMinMaxOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
if (pkernel == nullptr) {
MS_LOG(INFO) << "new SpaceToBatchNDOpenCLKernel failed ";
return;
}
pkernel->Init();
// allocate memory for inputs and outputs
for (auto &input_tensor : inputs) {
input_tensor->MallocData(allocator);
}
MS_LOG(INFO) << " initialize sub_graph ";
std::vector<kernel::LiteKernel *> kernels{pkernel};
auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
if (sub_graph == nullptr) {
delete pkernel;
MS_LOG(INFO) << " new SubGraphOpenCLKernel failed ";
return;
}
sub_graph->Init();
MS_LOG(INFO) << " init tensors ";
T *input_ptr = reinterpret_cast<T *>(inputs[0]->MutableData());
memcpy(input_ptr, input_data, input_size);
std::cout << "==================input data================" << std::endl;
for (auto i = 0; i < inputs[0]->ElementsNum(); ++i) {
std::cout << input_ptr[i] << ", ";
}
std::cout << std::endl;
sub_graph->Run();
auto *output_data = reinterpret_cast<T *>(outputs[0]->MutableData());
std::cout << "==================output data================" << std::endl;
for (auto i = 0; i < outputs[0]->ElementsNum(); ++i) {
std::cout << output_data[i] << ", ";
}
std::cout << std::endl;
std::cout << "==================correct data================" << std::endl;
for (auto i = 0; i < outputs[0]->ElementsNum(); ++i) {
std::cout << static_cast<T *>(correct_data)[i] << ", ";
}
std::cout << std::endl;
CommonTest::CompareOutputData<T>(output_data, static_cast<T *>(correct_data), outputs[0]->ElementsNum(), 0.0001);
delete sub_graph;
}
TEST_F(TestArgMinMaxOpenCL, axis0topk2index) {
ArgMinMaxParameter *param = std::make_unique<ArgMinMaxParameter>().release();
if (param == nullptr) {
return;
}
std::vector<float> in_data = {100, 2, 4, 50, 11, 12, 34, 35, 10, 20, 40, 5,
7, 80, 10, 11, 55, 25, 5, 15, 18, 8, 15, 16};
std::vector<float> except_out = {0, 2, 1, 0, 2, 1, 0, 0, 2, 1, 2, 2, 0, 0, 2, 2};
param->dims_size_ = 4;
param->axis_ = 0;
param->topk_ = 2;
param->get_max_ = true;
param->out_value_ = false;
std::vector<int> in_shape = {3, 2, 2, 2};
std::vector<int> out_shape = {2, 2, 2, 2};
TypeId data_type = kNumberTypeFloat32;
schema::Format format = schema::Format_NHWC;
test_main_argminmax<float>(in_data.data(), except_out.data(), in_shape, out_shape, param, data_type, format);
}
TEST_F(TestArgMinMaxOpenCL, axis0topk2value) {
ArgMinMaxParameter *param = std::make_unique<ArgMinMaxParameter>().release();
if (param == nullptr) {
return;
}
std::vector<float> in_data = {100, 2, 4, 50, 11, 12, 34, 35, 10, 20, 40, 5,
7, 80, 10, 11, 55, 25, 5, 15, 18, 8, 15, 16};
std::vector<float> except_out = {100, 25, 40, 50, 18, 80, 34, 35, 55, 20, 5, 15, 11, 12, 15, 16};
param->dims_size_ = 4;
param->axis_ = 0;
param->topk_ = 2;
param->get_max_ = true;
param->out_value_ = true;
std::vector<int> in_shape = {3, 2, 2, 2};
std::vector<int> out_shape = {2, 2, 2, 2};
TypeId data_type = kNumberTypeFloat32;
schema::Format format = schema::Format_NHWC;
test_main_argminmax<float>(in_data.data(), except_out.data(), in_shape, out_shape, param, data_type, format);
}
TEST_F(TestArgMinMaxOpenCL, axis1topk2index) {
ArgMinMaxParameter *param = std::make_unique<ArgMinMaxParameter>().release();
if (param == nullptr) {
return;
}
std::vector<float> in_data = {100, 2, 200, 4, 50, 6, 11, 12, 13, 34, 35, 36, 9, 6, 17, 10, 20, 30,
10, 20, 30, 40, 5, 60, 7, 80, 90, 10, 11, 120, 18, 5, 16, 9, 22, 23};
std::vector<float> except_out = {0, 1, 0, 1, 0, 1, 1, 2, 2, 2, 1, 2, 2, 1, 1, 0, 2, 1, 0, 0, 0, 1, 1, 0};
param->dims_size_ = 4;
param->axis_ = 1;
param->topk_ = 2;
param->get_max_ = true;
param->out_value_ = false;
std::vector<int> in_shape = {2, 3, 2, 3};
std::vector<int> out_shape = {2, 2, 2, 3};
TypeId data_type = kNumberTypeFloat32;
schema::Format format = schema::Format_NHWC;
test_main_argminmax<float>(in_data.data(), except_out.data(), in_shape, out_shape, param, data_type, format);
}
TEST_F(TestArgMinMaxOpenCL, axis1topk2value) {
ArgMinMaxParameter *param = std::make_unique<ArgMinMaxParameter>().release();
if (param == nullptr) {
return;
}
std::vector<float> in_data = {100, 2, 200, 4, 50, 6, 11, 12, 13, 34, 35, 36, 9, 6, 17, 10, 20, 30,
10, 20, 30, 40, 5, 60, 7, 80, 90, 10, 11, 120, 18, 5, 16, 9, 22, 23};
std::vector<float> except_out = {100, 12, 200, 34, 50, 36, 11, 6, 17, 10, 35, 30,
18, 80, 90, 40, 22, 120, 10, 20, 30, 10, 11, 60};
param->dims_size_ = 4;
param->axis_ = 1;
param->topk_ = 2;
param->get_max_ = true;
param->out_value_ = true;
std::vector<int> in_shape = {2, 3, 2, 3};
std::vector<int> out_shape = {2, 2, 2, 3};
TypeId data_type = kNumberTypeFloat32;
schema::Format format = schema::Format_NHWC;
test_main_argminmax<float>(in_data.data(), except_out.data(), in_shape, out_shape, param, data_type, format);
}
TEST_F(TestArgMinMaxOpenCL, axis2topk1index) {
ArgMinMaxParameter *param = std::make_unique<ArgMinMaxParameter>().release();
if (param == nullptr) {
return;
}
param->dims_size_ = 4;
param->axis_ = 2;
param->topk_ = 1;
param->get_max_ = true;
param->out_value_ = false;
std::vector<float> in_data = {10, 20, 30, 11, 15, 10, 5, 10, 12, 10, 20, 30, 11, 15, 10, 5, 10, 12,
10, 20, 30, 11, 15, 10, 5, 10, 12, 10, 20, 30, 11, 15, 10, 5, 10, 12,
10, 20, 30, 11, 15, 10, 5, 10, 12, 10, 20, 30, 11, 15, 10, 5, 10, 12};
std::vector<float> except_out = {1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0};
std::vector<int> in_shape = {2, 3, 3, 3};
std::vector<int> out_shape = {2, 3, 1, 3};
TypeId data_type = kNumberTypeFloat32;
schema::Format format = schema::Format_NHWC;
test_main_argminmax<float>(in_data.data(), except_out.data(), in_shape, out_shape, param, data_type, format);
}
TEST_F(TestArgMinMaxOpenCL, axis2topk2value) {
ArgMinMaxParameter *param = std::make_unique<ArgMinMaxParameter>().release();
if (param == nullptr) {
return;
}
std::vector<float> in_data = {10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90,
20, 11, 15, 1, 50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50,
30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, 30};
std::vector<float> except_out = {30, 45, 30, 50, 90, 20, 20, 25, 40, 50, 30, 45, 30, 50, 90, 20, 20, 25, 40, 50,
30, 45, 30, 50, 90, 20, 20, 25, 40, 50, 30, 45, 30, 50, 90, 20, 20, 25, 40, 50};
param->dims_size_ = 4;
param->axis_ = 2;
param->topk_ = 2;
param->get_max_ = true;
param->out_value_ = true;
std::vector<int> in_shape = {2, 2, 3, 5};
std::vector<int> out_shape = {2, 2, 2, 5};
TypeId data_type = kNumberTypeFloat32;
schema::Format format = schema::Format_NHWC;
test_main_argminmax<float>(in_data.data(), except_out.data(), in_shape, out_shape, param, data_type, format);
}
TEST_F(TestArgMinMaxOpenCL, axis2topk2index) {
ArgMinMaxParameter *param = std::make_unique<ArgMinMaxParameter>().release();
if (param == nullptr) {
return;
}
std::vector<float> in_data = {10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90,
20, 11, 15, 1, 50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50,
30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, 30};
std::vector<float> except_out = {2, 2, 0, 2, 0, 1, 0, 2, 0, 1, 2, 2, 0, 2, 0, 1, 0, 2, 0, 1,
2, 2, 0, 2, 0, 1, 0, 2, 0, 1, 2, 2, 0, 2, 0, 1, 0, 2, 0, 1};
param->dims_size_ = 4;
param->axis_ = 2;
param->topk_ = 2;
param->get_max_ = true;
param->out_value_ = false;
std::vector<int> in_shape = {2, 2, 3, 5};
std::vector<int> out_shape = {2, 2, 2, 5};
TypeId data_type = kNumberTypeFloat32;
schema::Format format = schema::Format_NHWC;
test_main_argminmax<float>(in_data.data(), except_out.data(), in_shape, out_shape, param, data_type, format);
}
TEST_F(TestArgMinMaxOpenCL, axis3topk2index) {
ArgMinMaxParameter *param = std::make_unique<ArgMinMaxParameter>().release();
if (param == nullptr) {
return;
}
std::vector<float> in_data = {10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90,
20, 11, 15, 1, 50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50,
30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, 30};
std::vector<float> except_out = {4, 3, 4, 0, 3, 1, 4, 3, 4, 0, 3, 1, 4, 3, 4, 0, 3, 1, 4, 3, 4, 0, 3, 1};
param->dims_size_ = 4;
param->axis_ = 3;
param->topk_ = 2;
param->get_max_ = true;
param->out_value_ = false;
std::vector<int> in_shape = {2, 2, 3, 5};
std::vector<int> out_shape = {2, 2, 3, 2};
TypeId data_type = kNumberTypeFloat32;
schema::Format format = schema::Format_NHWC;
test_main_argminmax<float>(in_data.data(), except_out.data(), in_shape, out_shape, param, data_type, format);
}
TEST_F(TestArgMinMaxOpenCL, axis3topk2value) {
ArgMinMaxParameter *param = std::make_unique<ArgMinMaxParameter>().release();
if (param == nullptr) {
return;
}
std::vector<float> in_data = {10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90,
20, 11, 15, 1, 50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50,
30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, 30};
std::vector<float> except_out = {90, 40, 50, 20, 50, 45, 90, 40, 50, 20, 50, 45,
90, 40, 50, 20, 50, 45, 90, 40, 50, 20, 50, 45};
param->dims_size_ = 4;
param->axis_ = 3;
param->topk_ = 2;
param->get_max_ = true;
param->out_value_ = true;
std::vector<int> in_shape = {2, 2, 3, 5};
std::vector<int> out_shape = {2, 2, 3, 2};
TypeId data_type = kNumberTypeFloat32;
schema::Format format = schema::Format_NHWC;
test_main_argminmax<float>(in_data.data(), except_out.data(), in_shape, out_shape, param, data_type, format);
}
} // namespace mindspore