From 6768d4f6c84778ba849fc68be1c099a21951c3f1 Mon Sep 17 00:00:00 2001
From: wandongdong
Date: Thu, 12 Nov 2020 01:38:58 -0800
Subject: [PATCH] add opencl argminmax

---
 .../src/runtime/kernel/opencl/cl/argminmax.cl |  72 +++++
 .../runtime/kernel/opencl/kernel/argminmax.cc | 164 ++++++++++
 .../runtime/kernel/opencl/kernel/argminmax.h  |  52 ++++
 .../lite/src/runtime/kernel/opencl/utils.cc   |  11 +
 .../lite/src/runtime/kernel/opencl/utils.h    |   2 +
 .../lite/src/runtime/opencl/opencl_runtime.cc |  21 +-
 .../runtime/kernel/opencl/argminmax_tests.cc  | 283 ++++++++++++++++++
 7 files changed, 601 insertions(+), 4 deletions(-)
 create mode 100644 mindspore/lite/src/runtime/kernel/opencl/cl/argminmax.cl
 create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.cc
 create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.h
 create mode 100644 mindspore/lite/test/ut/src/runtime/kernel/opencl/argminmax_tests.cc

diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/argminmax.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/argminmax.cl
new file mode 100644
index 00000000000..743eabe5bdb
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/opencl/cl/argminmax.cl
@@ -0,0 +1,72 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#define swap(a, b, c) \
+  c = a;              \
+  a = b;              \
+  b = c;
+#define swap_atomic(a, b, c) \
+  c = atomic_xchg(a, *(b));  \
+  c = atomic_xchg(b, c);
+#define UP_ROUND(a, b) (((a + b - 1) / b) * b)
+#define C4NUM 4
+__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
+__kernel void argminmax(__global FLT *src_data, __global FLT *dst_data, __global FLT *buf, __global int *ids,
+                        int4 shape, int4 src_size, int4 cus_size, int4 strides, int4 flags) {
+  int X = get_global_id(0);  // reduce len
+  int Y = get_global_id(1);  // upper axis accumulation
+  if (X >= src_size.x || Y >= src_size.y) {
+    return;
+  }
+  int offset = X + Y * src_size.z;
+  int align_c4 = (flags.z != 3) ? (X / shape.w) * (shape.x) : 0;
+  int align_in = 0;
+  int align_out = 0;
+  if (flags.z == 3) {
+    align_in = (Y / shape.z) * cus_size.z;
+    align_out = (Y / shape.z) * cus_size.w;
+  }
+  if (flags.z == 0) {
+    align_in = X / (shape.y) * cus_size.z;
+    align_out = align_in;
+  }
+  for (int k = 0; k < src_size.w; ++k) {  // gather values of the reduce axis into buf, their indices into ids
+    int idx0 = (X + k * strides.x) + Y * strides.y + (align_c4 + align_in);
+    int idx1 = offset + k * src_size.x;
+    ids[idx1] = k;
+    buf[idx1] = src_data[idx0];
+  }
+  for (unsigned int i = 2; i <= cus_size.x; i <<= 1) {  // bitonic sort over the power-of-two reduce length
+    for (unsigned int j = i >> 1; j > 0; j >>= 1) {
+      for (int tid = 0; tid < src_size.w; ++tid) {
+        unsigned int tid_comp = tid + j;
+        if (tid_comp < src_size.w) {
+          int lk = offset + tid * src_size.x;
+          int rk = offset + tid_comp * src_size.x;
+          if ((tid & i) == 0) {  // ascending
+            if (buf[lk] > buf[rk]) {
+              FLT tmpf;
+              swap(buf[lk], buf[rk], tmpf);
+              int tmpi;
+              swap(ids[lk], ids[rk], tmpi);
+            }
+          } else {  // descending
+            if (buf[lk] < buf[rk]) {
+              FLT tmpf;
+              swap(buf[lk], buf[rk], tmpf);
+              int tmpi;
+              swap(ids[lk], ids[rk], tmpi);
+            }
+          }
+        }
+      }
+    }
+  }
+  for (int k = 0; k < flags.w; ++k) {  // write back the top-k values or indices
+    int idx0 = (X + k * strides.z) + Y * strides.w + (align_c4 + align_out);
+    int idx1 = flags.y ? (offset + (src_size.w - k - 1) * src_size.x) : (offset + k * src_size.x);
+    if (flags.x) {
+      dst_data[idx0] = buf[idx1];
+    } else {
+      dst_data[idx0] = ids[idx1];
+    }
+  }
+}
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.cc
new file mode 100644
index 00000000000..ab53aeb1966
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.cc
@@ -0,0 +1,164 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <string>
+#include <set>
+#include <vector>
+#include <numeric>
+#include <functional>
+#include <algorithm>
+#include "src/kernel_registry.h"
+#include "src/runtime/kernel/opencl/utils.h"
+#include "src/runtime/kernel/opencl/kernel/argminmax.h"
+#include "src/runtime/kernel/opencl/cl/argminmax.cl.inc"
+
+using mindspore::kernel::KERNEL_ARCH::kGPU;
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_OK;
+using mindspore::schema::PrimitiveType_ArgMax;
+using mindspore::schema::PrimitiveType_ArgMin;
+
+namespace mindspore::kernel {
+
+int ArgMinMaxOpenCLKernel::CheckSpecs() {
+  if (in_tensors_[0]->data_type() != kNumberTypeFloat32 && in_tensors_[0]->data_type() != kNumberTypeFloat16) {
+    MS_LOG(ERROR) << "Unsupported data type " << in_tensors_[0]->data_type();
+    return RET_ERROR;
+  }
+  if (in_tensors_[0]->shape().size() > 4 || in_tensors_[0]->shape().size() == 0) {
+    MS_LOG(ERROR) << "input shape size must be (1-4), actual: " << in_tensors_[0]->shape().size() << ", "
+                  << out_tensors_[0]->shape().size();
+    return RET_ERROR;
+  }
+  auto *param = reinterpret_cast<ArgMinMaxParameter *>(this->op_parameter_);
+  param->dims_size_ = in_tensors_[0]->shape().size();
+  param->axis_ = (param->axis_ + param->dims_size_) % param->dims_size_;
+  if (param->axis_ < 0 || param->axis_ >= param->dims_size_) {
+    MS_LOG(ERROR) << "Invalid axis " << param->axis_;
+    return RET_ERROR;
+  }
+  param->get_max_ = (op_parameter_->type_ == PrimitiveType_ArgMax);
+  return RET_OK;
+}
+
+void ArgMinMaxOpenCLKernel::SetConstArgs() {
+  auto param = reinterpret_cast<ArgMinMaxParameter *>(op_parameter_);
+  cl_int4 in_shape{static_cast<int>(im_in_.N), static_cast<int>(im_in_.H), static_cast<int>(im_in_.W),
+                   static_cast<int>(im_in_.C)};
+  in_shape.s[0] = UP_ROUND(im_in_.C, C4NUM) - im_in_.C;
+  in_shape.s[1] = im_in_.W * im_in_.C;
+  cl_int4 flags = {param->out_value_, param->get_max_, param->axis_, param->topk_};
+  int arg_cnt = 2;
+  ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, buff_, lite::opencl::MemType::BUF);
+  ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, ids_, lite::opencl::MemType::BUF);
+  ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, in_shape);
+  ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, src_size_);
+  ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, cus_size_);
+  ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, strides_);
+  ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, flags);
+}
+
+void ArgMinMaxOpenCLKernel::SetGlobalLocal() {
+  auto param = reinterpret_cast<ArgMinMaxParameter *>(op_parameter_);
+  auto in_shape = in_tensors_[0]->shape();
+  auto in_shape_align = in_shape;
+  in_shape_align[3] = UP_ROUND(in_shape[3], C4NUM);
+  im_in_ = Image2DInfo(in_tensors_[0]);
+  auto out_shape_align = in_shape_align;
+  out_shape_align.at(param->axis_) = param->axis_ == 3 ? UP_ROUND(param->topk_, C4NUM) : param->topk_;
+  int reduce_len = GetUpPow2(in_shape.at(param->axis_));
+  cus_size_ = {reduce_len, static_cast<int>(im_in_.RowPitch() / C4NUM), 1, 1};
+  cus_size_.s[2] = UP_ROUND(im_in_.width * C4NUM, cus_size_.s[1]) - im_in_.width * C4NUM;
+  cus_size_.s[3] = im_in_.W * UP_ROUND(param->topk_, C4NUM);
+  cus_size_.s[3] = UP_ROUND(cus_size_.s[3], cus_size_.s[1]) - cus_size_.s[3];
+  src_size_ = {std::accumulate(in_shape.begin() + param->axis_ + 1, in_shape.end(), 1, std::multiplies<int>()),
+               std::accumulate(in_shape.begin(), in_shape.begin() + param->axis_, 1, std::multiplies<int>()),
+               std::accumulate(in_shape.begin() + param->axis_, in_shape.end(), 1, std::multiplies<int>()),
+               in_shape.at(param->axis_)};
+  strides_ = {
+    std::accumulate(in_shape_align.begin() + param->axis_ + 1, in_shape_align.end(), 1, std::multiplies<int>()),
+    std::accumulate(in_shape_align.begin() + param->axis_, in_shape_align.end(), 1, std::multiplies<int>()),
+    std::accumulate(out_shape_align.begin() + param->axis_ + 1, out_shape_align.end(), 1, std::multiplies<int>()),
+    std::accumulate(out_shape_align.begin() + param->axis_, out_shape_align.end(), 1, std::multiplies<int>()),
+  };
+  switch (param->axis_) {
+    case 0:
+      strides_.s[0] = UP_ROUND(strides_.s[0] / im_in_.H, cus_size_.s[1]) * im_in_.H;
+      strides_.s[1] = strides_.s[0] * im_in_.N;
+      strides_.s[2] = UP_ROUND(strides_.s[2] / im_in_.H, cus_size_.s[1]) * im_in_.H;
+      strides_.s[3] = strides_.s[2] * param->topk_;
+      break;
+    case 1:
+      strides_.s[0] = UP_ROUND(strides_.s[0], cus_size_.s[1]);
+      strides_.s[1] = UP_ROUND(strides_.s[1] / im_in_.H, cus_size_.s[1]) * im_in_.H;
+      strides_.s[2] = UP_ROUND(strides_.s[2], cus_size_.s[1]);
+      strides_.s[3] = UP_ROUND(strides_.s[3] / param->topk_, cus_size_.s[1]) * param->topk_;
+      break;
+    case 2:
+      strides_.s[1] = UP_ROUND(strides_.s[1], cus_size_.s[1]);
+      strides_.s[3] = UP_ROUND(strides_.s[3], cus_size_.s[1]);
+      break;
+    default:  // 3
+      break;
+  }
+  std::vector<size_t> local = {1, 1, 1};
+  std::vector<size_t> global = {static_cast<size_t>(strides_.s[0]), static_cast<size_t>(src_size_.s[1]), 1};
+  OpenCLKernel::AlignGlobalLocal(global, local);
+}
+
+int ArgMinMaxOpenCLKernel::InitWeights() {
+  auto allocator = ocl_runtime_->GetAllocator();
+  int dtype_size = ocl_runtime_->GetFp16Enable() ? sizeof(int16_t) : sizeof(float);
+  buff_ = allocator->Malloc(in_tensors_[0]->ElementsNum() * dtype_size);
+  ids_ = allocator->Malloc(in_tensors_[0]->ElementsNum() * sizeof(int32_t));
+  return RET_OK;
+}
+
+int ArgMinMaxOpenCLKernel::Prepare() {
+  std::string kernel_name = "argminmax";
+
+#ifdef PROGRAM_WITH_IL
+  kernel_ = ocl_runtime_->GetKernelFromBinary(kernel_name);
+#else
+
+  std::set<std::string> build_options;
+  std::string source = argminmax_source;
+  std::string program_name = "argminmax";
+  ocl_runtime_->LoadSource(program_name, source);
+  ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options);
+#endif
+
+  InitWeights();
+  SetGlobalLocal();
+  SetConstArgs();
+  MS_LOG(DEBUG) << kernel_name << " Init Done!";
+  return RET_OK;
+}
+
+int ArgMinMaxOpenCLKernel::Run() {
+  MS_LOG(DEBUG) << this->name() << " Running!";
"; + ocl_runtime_->SetKernelArg(kernel_, 0, in_tensors_[0]->data_c(), lite::opencl::MemType::BUF); + ocl_runtime_->SetKernelArg(kernel_, 1, out_tensors_[0]->data_c(), lite::opencl::MemType::BUF); + ocl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr); + + return RET_OK; +} + +REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_ArgMin, OpenCLKernelCreator); +REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_ArgMin, OpenCLKernelCreator); +REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_ArgMax, OpenCLKernelCreator); +REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_ArgMax, OpenCLKernelCreator); +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.h new file mode 100644 index 00000000000..80910d3c12a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.h @@ -0,0 +1,52 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_BATCH_TO_SPACE_ND_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_BATCH_TO_SPACE_ND_H_ + +#include +#include "src/runtime/kernel/opencl/opencl_kernel.h" +#include "nnacl/arg_min_max_parameter.h" + +namespace mindspore::kernel { + +class ArgMinMaxOpenCLKernel : public OpenCLKernel { + public: + ArgMinMaxOpenCLKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : OpenCLKernel(parameter, inputs, outputs) {} + + ~ArgMinMaxOpenCLKernel() override = default; + + int Run() override; + int Prepare() override; + + int CheckSpecs() override; + void SetConstArgs() override; + void SetGlobalLocal() override; + int InitWeights() override; + + private: + cl::Kernel kernel_; + void *buff_{nullptr}; + void *ids_{nullptr}; + Image2DInfo im_in_{Image2DInfo(nullptr)}; + cl_int4 src_size_; + cl_int4 cus_size_; + cl_int4 strides_; +}; +} // namespace mindspore::kernel +#endif diff --git a/mindspore/lite/src/runtime/kernel/opencl/utils.cc b/mindspore/lite/src/runtime/kernel/opencl/utils.cc index fbde8104426..bef9e336641 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/utils.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/utils.cc @@ -40,6 +40,17 @@ kernel::LiteKernel *GetOpenCLKernel(const std::vector &in_tensors, con namespace mindspore::kernel { +int GetUpPow2(int n) { + int i = 0; + int j = 0; + while (n > 0) { + j += n & 1; + n = n >> 1; + i++; + } + return 1 << (i - (j == 1)); +} + int GetMaxDivisor(int x, int divisor) { int i = divisor; while (i > 0) { diff --git a/mindspore/lite/src/runtime/kernel/opencl/utils.h b/mindspore/lite/src/runtime/kernel/opencl/utils.h index cec6316b314..228682358c3 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/utils.h +++ b/mindspore/lite/src/runtime/kernel/opencl/utils.h @@ -34,6 +34,8 @@ kernel::LiteKernel *GetOpenCLKernel(const std::vector &in_tensors, con namespace mindspore::kernel { +int GetUpPow2(int n); + int 
 
 int GetMaxDivisorStrategy0(int x, int divisor);
 
diff --git a/mindspore/lite/src/runtime/opencl/opencl_runtime.cc b/mindspore/lite/src/runtime/opencl/opencl_runtime.cc
index cb16581c667..70137d416d9 100644
--- a/mindspore/lite/src/runtime/opencl/opencl_runtime.cc
+++ b/mindspore/lite/src/runtime/opencl/opencl_runtime.cc
@@ -157,7 +157,19 @@ int OpenCLRuntime::Init() {
   }
 #else
   MS_LOG(INFO) << "Create common opencl context";
+#ifdef Debug
+  std::vector<cl_context_properties> ctx_properties = {CL_CONTEXT_PLATFORM,
+                                                       (cl_context_properties)platforms[0](),
+                                                       CL_PRINTF_CALLBACK_ARM,
+                                                       (cl_context_properties)printf_callback,
+                                                       CL_PRINTF_BUFFERSIZE_ARM,
+                                                       0x1000000,
+                                                       0};
+  context_ =
+    new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, ctx_properties.data(), nullptr, nullptr, &ret);
+#else
   context_ = new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, nullptr, nullptr, nullptr, &ret);
+#endif
 #endif
   if (ret != CL_SUCCESS) {
     delete device_;
@@ -201,9 +213,10 @@
   MS_LOG(INFO) << "Compute Unit: " << compute_units_;
   MS_LOG(INFO) << "Clock Frequency: " << max_freq_ << " MHz";
 
+#ifdef Debug
+  const cl_command_queue_properties properties = CL_QUEUE_PROFILING_ENABLE;
+#else
   const cl_command_queue_properties properties = 0;
-#if MS_OPENCL_PROFILE
-  properties |= CL_QUEUE_PROFILING_ENABLE;
 #endif
 
   default_command_queue_ = new (std::nothrow) cl::CommandQueue(*context_, *device_, properties, &ret);
@@ -412,7 +425,7 @@ int OpenCLRuntime::RunKernel(const cl::Kernel &kernel, const std::vector<size_t>
   }
   cnt++;
   MS_LOG(DEBUG) << "RunKernel success!";
-#if MS_OPENCL_PROFILE
+#ifdef Debug
   event.wait();
   cl_ulong time_start;
   cl_ulong time_end;
@@ -445,7 +458,7 @@ int OpenCLRuntime::RunKernel(const cl::Kernel &kernel, const cl::NDRange &global
   }
   cnt++;
   MS_LOG(DEBUG) << "RunKernel success!";
-#if MS_OPENCL_PROFILE
+#ifdef Debug
   event.wait();
   cl_ulong time_start;
   cl_ulong time_end;
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/argminmax_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/argminmax_tests.cc
new file mode 100644
index 00000000000..b2cf515ed5f
--- /dev/null
+++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/argminmax_tests.cc
@@ -0,0 +1,283 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <iostream>
+#include <memory>
+#include "src/common/log_adapter.h"
+#include "common/common_test.h"
+#include "src/runtime/kernel/opencl/utils.h"
+#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
+#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
+#include "mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.h"
+
+namespace mindspore {
+class TestArgMinMaxOpenCL : public mindspore::CommonTest {
+ public:
+  TestArgMinMaxOpenCL() {}
+};
+template <typename T>
+void test_main_argminmax(void *input_data, void *correct_data, const std::vector<int> &input_shape,
+                         const std::vector<int> &output_shape, ArgMinMaxParameter *param, TypeId data_type,
+                         schema::Format format) {
+  MS_LOG(INFO) << " begin test ";
+  auto ocl_runtime_wrap = lite::opencl::OpenCLRuntimeWrapper();
+  auto ocl_runtime = ocl_runtime_wrap.GetInstance();
+  ocl_runtime->Init();
+  auto allocator = ocl_runtime->GetAllocator();
+
+  auto tensor_a = lite::Tensor(TypeId(data_type), input_shape, format);
+  auto tensor_c = lite::Tensor(TypeId(data_type), output_shape, format);
+  std::vector<lite::Tensor *> inputs{&tensor_a};
+  std::vector<lite::Tensor *> outputs{&tensor_c};
+  size_t input_size = tensor_a.Size();
+
+  auto *pkernel =
+    new (std::nothrow) kernel::ArgMinMaxOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
+  if (pkernel == nullptr) {
+    MS_LOG(INFO) << "new ArgMinMaxOpenCLKernel failed ";
+    return;
+  }
+  pkernel->Init();
+
+  // allocate memory for input tensors
+  for (auto &input_tensor : inputs) {
+    input_tensor->MallocData(allocator);
+  }
+
+  MS_LOG(INFO) << " initialize sub_graph ";
+  std::vector<kernel::LiteKernel *> kernels{pkernel};
+  auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
+  if (sub_graph == nullptr) {
+    delete pkernel;
+    MS_LOG(INFO) << " new SubGraphOpenCLKernel failed ";
+    return;
+  }
+  sub_graph->Init();
+
+  MS_LOG(INFO) << " init tensors ";
+  T *input_ptr = reinterpret_cast<T *>(inputs[0]->MutableData());
+  memcpy(input_ptr, input_data, input_size);
+  std::cout << "==================input data================" << std::endl;
+  for (auto i = 0; i < inputs[0]->ElementsNum(); ++i) {
+    std::cout << input_ptr[i] << ", ";
+  }
+  std::cout << std::endl;
+
+  sub_graph->Run();
+
+  auto *output_data = reinterpret_cast<T *>(outputs[0]->MutableData());
+  std::cout << "==================output data================" << std::endl;
+  for (auto i = 0; i < outputs[0]->ElementsNum(); ++i) {
+    std::cout << output_data[i] << ", ";
+  }
+  std::cout << std::endl;
+  std::cout << "==================correct data================" << std::endl;
+  for (auto i = 0; i < outputs[0]->ElementsNum(); ++i) {
+    std::cout << static_cast<T *>(correct_data)[i] << ", ";
+  }
+  std::cout << std::endl;
+  CommonTest::CompareOutputData(output_data, static_cast<T *>(correct_data), outputs[0]->ElementsNum(), 0.0001);
+  delete sub_graph;
+}
+TEST_F(TestArgMinMaxOpenCL, axis0topk2index) {
+  ArgMinMaxParameter *param = std::make_unique<ArgMinMaxParameter>().release();
+  if (param == nullptr) {
+    return;
+  }
+  std::vector<float> in_data = {100, 2, 4, 50, 11, 12, 34, 35, 10, 20, 40, 5,
+                                7, 80, 10, 11, 55, 25, 5, 15, 18, 8, 15, 16};
+  std::vector<float> except_out = {0, 2, 1, 0, 2, 1, 0, 0, 2, 1, 2, 2, 0, 0, 2, 2};
+  param->dims_size_ = 4;
+  param->axis_ = 0;
+  param->topk_ = 2;
+  param->get_max_ = true;
+  param->out_value_ = false;
+  std::vector<int> in_shape = {3, 2, 2, 2};
+  std::vector<int> out_shape = {2, 2, 2, 2};
+
+  TypeId data_type = kNumberTypeFloat32;
+  schema::Format format = schema::Format_NHWC;
+  test_main_argminmax<float>(in_data.data(), except_out.data(), in_shape, out_shape, param, data_type, format);
+}
+TEST_F(TestArgMinMaxOpenCL, axis0topk2value) {
+  ArgMinMaxParameter *param = std::make_unique<ArgMinMaxParameter>().release();
+  if (param == nullptr) {
+    return;
+  }
+  std::vector<float> in_data = {100, 2, 4, 50, 11, 12, 34, 35, 10, 20, 40, 5,
+                                7, 80, 10, 11, 55, 25, 5, 15, 18, 8, 15, 16};
+  std::vector<float> except_out = {100, 25, 40, 50, 18, 80, 34, 35, 55, 20, 5, 15, 11, 12, 15, 16};
+  param->dims_size_ = 4;
+  param->axis_ = 0;
+  param->topk_ = 2;
+  param->get_max_ = true;
+  param->out_value_ = true;
+  std::vector<int> in_shape = {3, 2, 2, 2};
+  std::vector<int> out_shape = {2, 2, 2, 2};
+
+  TypeId data_type = kNumberTypeFloat32;
+  schema::Format format = schema::Format_NHWC;
+  test_main_argminmax<float>(in_data.data(), except_out.data(), in_shape, out_shape, param, data_type, format);
+}
+TEST_F(TestArgMinMaxOpenCL, axis1topk2index) {
+  ArgMinMaxParameter *param = std::make_unique<ArgMinMaxParameter>().release();
+  if (param == nullptr) {
+    return;
+  }
+  std::vector<float> in_data = {100, 2, 200, 4, 50, 6, 11, 12, 13, 34, 35, 36, 9, 6, 17, 10, 20, 30,
+                                10, 20, 30, 40, 5, 60, 7, 80, 90, 10, 11, 120, 18, 5, 16, 9, 22, 23};
+  std::vector<float> except_out = {0, 1, 0, 1, 0, 1, 1, 2, 2, 2, 1, 2, 2, 1, 1, 0, 2, 1, 0, 0, 0, 1, 1, 0};
+  param->dims_size_ = 4;
+  param->axis_ = 1;
+  param->topk_ = 2;
+  param->get_max_ = true;
+  param->out_value_ = false;
+  std::vector<int> in_shape = {2, 3, 2, 3};
+  std::vector<int> out_shape = {2, 2, 2, 3};
+
+  TypeId data_type = kNumberTypeFloat32;
+  schema::Format format = schema::Format_NHWC;
+  test_main_argminmax<float>(in_data.data(), except_out.data(), in_shape, out_shape, param, data_type, format);
+}
+TEST_F(TestArgMinMaxOpenCL, axis1topk2value) {
+  ArgMinMaxParameter *param = std::make_unique<ArgMinMaxParameter>().release();
+  if (param == nullptr) {
+    return;
+  }
+  std::vector<float> in_data = {100, 2, 200, 4, 50, 6, 11, 12, 13, 34, 35, 36, 9, 6, 17, 10, 20, 30,
+                                10, 20, 30, 40, 5, 60, 7, 80, 90, 10, 11, 120, 18, 5, 16, 9, 22, 23};
+  std::vector<float> except_out = {100, 12, 200, 34, 50, 36, 11, 6, 17, 10, 35, 30,
+                                   18, 80, 90, 40, 22, 120, 10, 20, 30, 10, 11, 60};
+  param->dims_size_ = 4;
+  param->axis_ = 1;
+  param->topk_ = 2;
+  param->get_max_ = true;
+  param->out_value_ = true;
+  std::vector<int> in_shape = {2, 3, 2, 3};
+  std::vector<int> out_shape = {2, 2, 2, 3};
+
+  TypeId data_type = kNumberTypeFloat32;
+  schema::Format format = schema::Format_NHWC;
+  test_main_argminmax<float>(in_data.data(), except_out.data(), in_shape, out_shape, param, data_type, format);
+}
+TEST_F(TestArgMinMaxOpenCL, axis2topk1index) {
+  ArgMinMaxParameter *param = std::make_unique<ArgMinMaxParameter>().release();
+  if (param == nullptr) {
+    return;
+  }
+  param->dims_size_ = 4;
+  param->axis_ = 2;
+  param->topk_ = 1;
+  param->get_max_ = true;
+  param->out_value_ = false;
+  std::vector<float> in_data = {10, 20, 30, 11, 15, 10, 5, 10, 12, 10, 20, 30, 11, 15, 10, 5, 10, 12,
+                                10, 20, 30, 11, 15, 10, 5, 10, 12, 10, 20, 30, 11, 15, 10, 5, 10, 12,
+                                10, 20, 30, 11, 15, 10, 5, 10, 12, 10, 20, 30, 11, 15, 10, 5, 10, 12};
+  std::vector<float> except_out = {1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0};
+  std::vector<int> in_shape = {2, 3, 3, 3};
+  std::vector<int> out_shape = {2, 3, 1, 3};
+
+  TypeId data_type = kNumberTypeFloat32;
+  schema::Format format = schema::Format_NHWC;
+  test_main_argminmax<float>(in_data.data(), except_out.data(), in_shape, out_shape, param, data_type, format);
+}
+TEST_F(TestArgMinMaxOpenCL, axis2topk2value) {
+  ArgMinMaxParameter *param = std::make_unique<ArgMinMaxParameter>().release();
+  if (param == nullptr) {
+    return;
+  }
+  std::vector<float> in_data = {10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90,
+                                20, 11, 15, 1, 50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50,
+                                30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, 30};
+  std::vector<float> except_out = {30, 45, 30, 50, 90, 20, 20, 25, 40, 50, 30, 45, 30, 50, 90, 20, 20, 25, 40, 50,
+                                   30, 45, 30, 50, 90, 20, 20, 25, 40, 50, 30, 45, 30, 50, 90, 20, 20, 25, 40, 50};
+  param->dims_size_ = 4;
+  param->axis_ = 2;
+  param->topk_ = 2;
+  param->get_max_ = true;
+  param->out_value_ = true;
+  std::vector<int> in_shape = {2, 2, 3, 5};
+  std::vector<int> out_shape = {2, 2, 2, 5};
+
+  TypeId data_type = kNumberTypeFloat32;
+  schema::Format format = schema::Format_NHWC;
+  test_main_argminmax<float>(in_data.data(), except_out.data(), in_shape, out_shape, param, data_type, format);
+}
+TEST_F(TestArgMinMaxOpenCL, axis2topk2index) {
+  ArgMinMaxParameter *param = std::make_unique<ArgMinMaxParameter>().release();
+  if (param == nullptr) {
+    return;
+  }
+  std::vector<float> in_data = {10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90,
+                                20, 11, 15, 1, 50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50,
+                                30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, 30};
+  std::vector<float> except_out = {2, 2, 0, 2, 0, 1, 0, 2, 0, 1, 2, 2, 0, 2, 0, 1, 0, 2, 0, 1,
+                                   2, 2, 0, 2, 0, 1, 0, 2, 0, 1, 2, 2, 0, 2, 0, 1, 0, 2, 0, 1};
+  param->dims_size_ = 4;
+  param->axis_ = 2;
+  param->topk_ = 2;
+  param->get_max_ = true;
+  param->out_value_ = false;
+  std::vector<int> in_shape = {2, 2, 3, 5};
+  std::vector<int> out_shape = {2, 2, 2, 5};
+
+  TypeId data_type = kNumberTypeFloat32;
+  schema::Format format = schema::Format_NHWC;
+  test_main_argminmax<float>(in_data.data(), except_out.data(), in_shape, out_shape, param, data_type, format);
+}
+TEST_F(TestArgMinMaxOpenCL, axis3topk2index) {
+  ArgMinMaxParameter *param = std::make_unique<ArgMinMaxParameter>().release();
+  if (param == nullptr) {
+    return;
+  }
+  std::vector<float> in_data = {10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90,
+                                20, 11, 15, 1, 50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50,
+                                30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, 30};
+  std::vector<float> except_out = {4, 3, 4, 0, 3, 1, 4, 3, 4, 0, 3, 1, 4, 3, 4, 0, 3, 1, 4, 3, 4, 0, 3, 1};
+  param->dims_size_ = 4;
+  param->axis_ = 3;
+  param->topk_ = 2;
+  param->get_max_ = true;
+  param->out_value_ = false;
+  std::vector<int> in_shape = {2, 2, 3, 5};
+  std::vector<int> out_shape = {2, 2, 3, 2};
+
+  TypeId data_type = kNumberTypeFloat32;
+  schema::Format format = schema::Format_NHWC;
+  test_main_argminmax<float>(in_data.data(), except_out.data(), in_shape, out_shape, param, data_type, format);
+}
+TEST_F(TestArgMinMaxOpenCL, axis3topk2value) {
+  ArgMinMaxParameter *param = std::make_unique<ArgMinMaxParameter>().release();
+  if (param == nullptr) {
+    return;
+  }
+  std::vector<float> in_data = {10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90,
+                                20, 11, 15, 1, 50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50,
+                                30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, 30};
+  std::vector<float> except_out = {90, 40, 50, 20, 50, 45, 90, 40, 50, 20, 50, 45,
+                                   90, 40, 50, 20, 50, 45, 90, 40, 50, 20, 50, 45};
+  param->dims_size_ = 4;
+  param->axis_ = 3;
+  param->topk_ = 2;
+  param->get_max_ = true;
+  param->out_value_ = true;
+  std::vector<int> in_shape = {2, 2, 3, 5};
+  std::vector<int> out_shape = {2, 2, 3, 2};
+
+  TypeId data_type = kNumberTypeFloat32;
+  schema::Format format = schema::Format_NHWC;
+  test_main_argminmax<float>(in_data.data(), except_out.data(), in_shape, out_shape, param, data_type, format);
+}
+}  // namespace mindspore