forked from mindspore-Ecosystem/mindspore
!8589 add opencl argmin and argmax
From: @ddwsky
Reviewed-by: @zhanghaibo5, @HilbertDavid
Signed-off-by: @HilbertDavid
commit f1e0a673ac
@ -0,0 +1,72 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define swap(a, b, c) \
  c = a;              \
  a = b;              \
  b = c;
#define swap_atomic(a, b, c) \
  c = atomic_xchg(a, *(b));  \
  c = atomic_xchg(b, c);
#define UP_ROUND(a, b) (((a + b - 1) / b) * b)
#define C4NUM 4
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__kernel void argminmax(__global FLT *src_data, __global FLT *dst_data, __global FLT *buf, __global int *ids,
                        int4 shape, int4 src_size, int4 cus_size, int4 strides, int4 flags) {
  int X = get_global_id(0);  // reduce len
  int Y = get_global_id(1);  // upper axis accumulation
  if (X >= src_size.x || Y >= src_size.y) {
    return;
  }
  int offset = X + Y * src_size.z;
  int align_c4 = (flags.z != 3) ? (X / shape.w) * (shape.x) : 0;
  int align_in = 0;
  int align_out = 0;
  if (flags.z == 3) {
    align_in = (Y / shape.z) * cus_size.z;
    align_out = (Y / shape.z) * cus_size.w;
  }
  if (flags.z == 0) {
    align_in = X / (shape.y) * cus_size.z;
    align_out = align_in;
  }
  for (int k = 0; k < src_size.w; ++k) {
    int idx0 = (X + k * strides.x) + Y * strides.y + (align_c4 + align_in);
    int idx1 = offset + k * src_size.x;
    ids[idx1] = k;
    buf[idx1] = src_data[idx0];
  }
  for (unsigned int i = 2; i <= cus_size.x; i <<= 1) {
    for (unsigned int j = i >> 1; j > 0; j >>= 1) {
      for (int tid = 0; tid < src_size.w; ++tid) {
        unsigned int tid_comp = tid + j;
        if (tid_comp < src_size.w) {
          int lk = offset + tid * src_size.x;
          int rk = offset + tid_comp * src_size.x;
          if ((tid & i) == 0) {  // ascending
            if (buf[lk] > buf[rk]) {
              FLT tmpf;
              swap(buf[lk], buf[rk], tmpf);
              int tmpi;
              swap(ids[lk], ids[rk], tmpi);
            }
          } else {  // descending
            if (buf[lk] < buf[rk]) {
              FLT tmpf;
              swap(buf[lk], buf[rk], tmpf);
              int tmpi;
              swap(ids[lk], ids[rk], tmpi);
            }
          }
        }
      }
    }
  }
  for (int k = 0; k < flags.w; ++k) {
    int idx0 = (X + k * strides.z) + Y * strides.w + (align_c4 + align_out);
    int idx1 = flags.y ? (offset + (src_size.w - k - 1) * src_size.x) : (offset + k * src_size.x);
    if (flags.x) {
      dst_data[idx0] = buf[idx1];
    } else {
      dst_data[idx0] = ids[idx1];
    }
  }
}
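The middle loop nest above is a bitonic compare-exchange network applied to (value, index) pairs, so after the final ascending merge the top-k values and their original indices can be read from either end of the buffer. A minimal host-side C++ sketch of the same idea, in the classic XOR-partner formulation (the function name and demo data are illustrative, not from the kernel):

#include <cstdio>
#include <utility>
#include <vector>

// Bitonic sort of (value, index) pairs; n is assumed padded to a power of two,
// which is what GetUpPow2 provides on the host side.
void BitonicSortWithIndices(std::vector<float> &buf, std::vector<int> &ids) {
  const size_t n = buf.size();
  for (size_t i = 2; i <= n; i <<= 1) {        // merge stage size
    for (size_t j = i >> 1; j > 0; j >>= 1) {  // compare-exchange distance
      for (size_t tid = 0; tid < n; ++tid) {
        size_t partner = tid ^ j;
        if (partner <= tid) continue;          // visit each pair once
        bool ascending = (tid & i) == 0;       // block direction, as in the kernel
        if ((ascending && buf[tid] > buf[partner]) ||
            (!ascending && buf[tid] < buf[partner])) {
          std::swap(buf[tid], buf[partner]);
          std::swap(ids[tid], ids[partner]);
        }
      }
    }
  }
}

int main() {
  std::vector<float> buf = {7, 3, 9, 1, 5, 8, 2, 6};
  std::vector<int> ids = {0, 1, 2, 3, 4, 5, 6, 7};
  BitonicSortWithIndices(buf, ids);
  // ArgMax top-2: read from the descending end, as the kernel's flags.y branch does.
  for (int k = 0; k < 2; ++k) {
    printf("value %.0f at index %d\n", buf[buf.size() - 1 - k], ids[ids.size() - 1 - k]);
  }
  return 0;
}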
@ -0,0 +1,164 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <cstring>
#include <string>
#include <algorithm>
#include <set>
#include <utility>
#include <functional>
#include "src/kernel_registry.h"
#include "src/runtime/kernel/opencl/utils.h"
#include "src/runtime/kernel/opencl/kernel/argminmax.h"
#include "src/runtime/kernel/opencl/cl/argminmax.cl.inc"

using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_ArgMax;
using mindspore::schema::PrimitiveType_ArgMin;

namespace mindspore::kernel {

int ArgMinMaxOpenCLKernel::CheckSpecs() {
  if (in_tensors_[0]->data_type() != kNumberTypeFloat32 && in_tensors_[0]->data_type() != kNumberTypeFloat16) {
    MS_LOG(ERROR) << "Unsupported data type " << in_tensors_[0]->data_type();
    return RET_ERROR;
  }
  if (in_tensors_[0]->shape().size() > 4 || in_tensors_[0]->shape().size() == 0) {
    MS_LOG(ERROR) << "input shape size must be (1-4), actual: " << in_tensors_[0]->shape().size() << ", "
                  << out_tensors_[0]->shape().size();
    return RET_ERROR;
  }
  auto *param = reinterpret_cast<ArgMinMaxParameter *>(this->op_parameter_);
  param->dims_size_ = in_tensors_[0]->shape().size();
  param->axis_ = (param->axis_ + param->dims_size_) % param->dims_size_;
  if (param->axis_ < 0 || param->axis_ >= param->dims_size_) {
    MS_LOG(ERROR) << "Invalid axis " << param->axis_;
    return RET_ERROR;
  }
  param->get_max_ = (op_parameter_->type_ == PrimitiveType_ArgMax);
  return RET_OK;
}

void ArgMinMaxOpenCLKernel::SetConstArgs() {
  auto param = reinterpret_cast<ArgMinMaxParameter *>(op_parameter_);
  cl_int4 in_shape{static_cast<int>(im_in_.N), static_cast<int>(im_in_.H), static_cast<int>(im_in_.W),
                   static_cast<int>(im_in_.C)};
  in_shape.s[0] = UP_ROUND(im_in_.C, C4NUM) - im_in_.C;
  in_shape.s[1] = im_in_.W * im_in_.C;
  cl_int4 flags = {param->out_value_, param->get_max_, param->axis_, param->topk_};
  int arg_cnt = 2;
  ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, buff_, lite::opencl::MemType::BUF);
  ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, ids_, lite::opencl::MemType::BUF);
  ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, in_shape);
  ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, src_size_);
  ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, cus_size_);
  ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, strides_);
  ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, flags);
}

void ArgMinMaxOpenCLKernel::SetGlobalLocal() {
  auto param = reinterpret_cast<ArgMinMaxParameter *>(op_parameter_);
  auto in_shape = in_tensors_[0]->shape();
  auto in_shape_align = in_shape;
  in_shape_align[3] = UP_ROUND(in_shape[3], C4NUM);
  im_in_ = Image2DInfo(in_tensors_[0]);
  auto out_shape_align = in_shape_align;
  out_shape_align.at(param->axis_) = param->axis_ == 3 ? UP_ROUND(param->topk_, C4NUM) : param->topk_;
  int reduce_len = GetUpPow2(in_shape.at(param->axis_));
  cus_size_ = {reduce_len, static_cast<int>(im_in_.RowPitch() / C4NUM), 1, 1};
  cus_size_.s[2] = UP_ROUND(im_in_.width * C4NUM, cus_size_.s[1]) - im_in_.width * C4NUM;
  cus_size_.s[3] = im_in_.W * UP_ROUND(param->topk_, C4NUM);
  cus_size_.s[3] = UP_ROUND(cus_size_.s[3], cus_size_.s[1]) - cus_size_.s[3];
  src_size_ = {std::accumulate(in_shape.begin() + param->axis_ + 1, in_shape.end(), 1, std::multiplies<int>()),
               std::accumulate(in_shape.begin(), in_shape.begin() + param->axis_, 1, std::multiplies<int>()),
               std::accumulate(in_shape.begin() + param->axis_, in_shape.end(), 1, std::multiplies<int>()),
               in_shape.at(param->axis_)};
  strides_ = {
    std::accumulate(in_shape_align.begin() + param->axis_ + 1, in_shape_align.end(), 1, std::multiplies<int>()),
    std::accumulate(in_shape_align.begin() + param->axis_, in_shape_align.end(), 1, std::multiplies<int>()),
    std::accumulate(out_shape_align.begin() + param->axis_ + 1, out_shape_align.end(), 1, std::multiplies<int>()),
    std::accumulate(out_shape_align.begin() + param->axis_, out_shape_align.end(), 1, std::multiplies<int>()),
  };
  switch (param->axis_) {
    case 0:
      strides_.s[0] = UP_ROUND(strides_.s[0] / im_in_.H, cus_size_.s[1]) * im_in_.H;
      strides_.s[1] = strides_.s[0] * im_in_.N;
      strides_.s[2] = UP_ROUND(strides_.s[2] / im_in_.H, cus_size_.s[1]) * im_in_.H;
      strides_.s[3] = strides_.s[2] * param->topk_;
      break;
    case 1:
      strides_.s[0] = UP_ROUND(strides_.s[0], cus_size_.s[1]);
      strides_.s[1] = UP_ROUND(strides_.s[1] / im_in_.H, cus_size_.s[1]) * im_in_.H;
      strides_.s[2] = UP_ROUND(strides_.s[2], cus_size_.s[1]);
      strides_.s[3] = UP_ROUND(strides_.s[3] / param->topk_, cus_size_.s[1]) * param->topk_;
      break;
    case 2:
      strides_.s[1] = UP_ROUND(strides_.s[1], cus_size_.s[1]);
      strides_.s[3] = UP_ROUND(strides_.s[3], cus_size_.s[1]);
      break;
    default:  // 3
      break;
  }
  std::vector<size_t> local = {1, 1, 1};
  std::vector<size_t> global = {static_cast<size_t>(strides_.s[0]), static_cast<size_t>(src_size_.s[1]), 1};
  OpenCLKernel::AlignGlobalLocal(global, local);
}

int ArgMinMaxOpenCLKernel::InitWeights() {
  auto allocator = ocl_runtime_->GetAllocator();
  int dtype_size = ocl_runtime_->GetFp16Enable() ? sizeof(int16_t) : sizeof(float);
  buff_ = allocator->Malloc(in_tensors_[0]->ElementsNum() * dtype_size);
  ids_ = allocator->Malloc(in_tensors_[0]->ElementsNum() * sizeof(int32_t));
  return RET_OK;
}

int ArgMinMaxOpenCLKernel::Prepare() {
  std::string kernel_name = "argminmax";

#ifdef PROGRAM_WITH_IL
  kernel_ = ocl_runtime_->GetKernelFromBinary(kernel_name);
#else
  std::set<std::string> build_options;
  std::string source = argminmax_source;
  std::string program_name = "argminmax";
  ocl_runtime_->LoadSource(program_name, source);
  ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options);
#endif

  InitWeights();
  SetGlobalLocal();
  SetConstArgs();
  MS_LOG(DEBUG) << kernel_name << " Init Done!";
  return RET_OK;
}

int ArgMinMaxOpenCLKernel::Run() {
  MS_LOG(DEBUG) << this->name() << " Running! ";
  ocl_runtime_->SetKernelArg(kernel_, 0, in_tensors_[0]->data_c(), lite::opencl::MemType::BUF);
  ocl_runtime_->SetKernelArg(kernel_, 1, out_tensors_[0]->data_c(), lite::opencl::MemType::BUF);
  ocl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr);
  return RET_OK;
}

REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_ArgMin, OpenCLKernelCreator<ArgMinMaxOpenCLKernel>)
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_ArgMin, OpenCLKernelCreator<ArgMinMaxOpenCLKernel>)
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_ArgMax, OpenCLKernelCreator<ArgMinMaxOpenCLKernel>)
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_ArgMax, OpenCLKernelCreator<ArgMinMaxOpenCLKernel>)
} // namespace mindspore::kernel
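SetGlobalLocal packs the reduction geometry into src_size_ and strides_ as running products of shape dimensions via std::accumulate. A small standalone sketch of that arithmetic for one of the shapes used in the tests below (axis 1 of NHWC {2, 3, 2, 3}); the names inner/outer/slice are illustrative:

#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  std::vector<int> shape = {2, 3, 2, 3};  // NHWC, reducing along axis = 1
  int axis = 1;
  // Elements strictly after the axis: the stride between consecutive axis steps.
  int inner = std::accumulate(shape.begin() + axis + 1, shape.end(), 1, std::multiplies<int>());
  // Elements strictly before the axis: how many independent reductions run.
  int outer = std::accumulate(shape.begin(), shape.begin() + axis, 1, std::multiplies<int>());
  // Elements from the axis onward: the stride between two outer slices.
  int slice = std::accumulate(shape.begin() + axis, shape.end(), 1, std::multiplies<int>());
  printf("inner=%d outer=%d slice=%d axis_len=%d\n", inner, outer, slice, shape[axis]);
  // Prints inner=6 outer=2 slice=18 axis_len=3, i.e. src_size_ = {6, 2, 18, 3}.
  return 0;
}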
@ -0,0 +1,52 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_ARGMINMAX_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_ARGMINMAX_H_

#include <vector>
#include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "nnacl/arg_min_max_parameter.h"

namespace mindspore::kernel {

class ArgMinMaxOpenCLKernel : public OpenCLKernel {
 public:
  ArgMinMaxOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                        const std::vector<lite::Tensor *> &outputs)
      : OpenCLKernel(parameter, inputs, outputs) {}

  ~ArgMinMaxOpenCLKernel() override = default;

  int Run() override;
  int Prepare() override;

  int CheckSpecs() override;
  void SetConstArgs() override;
  void SetGlobalLocal() override;
  int InitWeights() override;

 private:
  cl::Kernel kernel_;
  void *buff_{nullptr};
  void *ids_{nullptr};
  Image2DInfo im_in_{Image2DInfo(nullptr)};
  cl_int4 src_size_;
  cl_int4 cus_size_;
  cl_int4 strides_;
};
} // namespace mindspore::kernel
#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_ARGMINMAX_H_
@ -40,6 +40,17 @@ kernel::LiteKernel *GetOpenCLKernel(const std::vector<Tensor *> &in_tensors, con

namespace mindspore::kernel {

int GetUpPow2(int n) {
  int i = 0;
  int j = 0;
  while (n > 0) {
    j += n & 1;
    n = n >> 1;
    i++;
  }
  return 1 << (i - (j == 1));
}
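GetUpPow2 counts the total bits i and the set bits j of n: if n is already a power of two (j == 1) it returns n unchanged, otherwise the next power of two above n, which the kernel uses as the padded bitonic-sort length. A quick standalone check (the function body is copied from above; positive n assumed):

#include <cassert>

int GetUpPow2(int n) {
  int i = 0, j = 0;
  while (n > 0) {
    j += n & 1;  // count set bits
    n >>= 1;     // count total bits in i
    i++;
  }
  return 1 << (i - (j == 1));
}

int main() {
  assert(GetUpPow2(1) == 1);  // already a power of two
  assert(GetUpPow2(3) == 4);  // rounded up
  assert(GetUpPow2(4) == 4);  // already a power of two
  assert(GetUpPow2(5) == 8);  // rounded up
  return 0;
}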

int GetMaxDivisor(int x, int divisor) {
  int i = divisor;
  while (i > 0) {
@ -34,6 +34,8 @@ kernel::LiteKernel *GetOpenCLKernel(const std::vector<Tensor *> &in_tensors, con

namespace mindspore::kernel {

int GetUpPow2(int n);

int GetMaxDivisor(int x, int divisor);

int GetMaxDivisorStrategy0(int x, int divisor);
@ -157,7 +157,19 @@ int OpenCLRuntime::Init() {
  }
#else
  MS_LOG(INFO) << "Create common opencl context";
#ifdef Debug
  std::vector<cl_context_properties> ctx_properties = {CL_CONTEXT_PLATFORM,
                                                       (cl_context_properties)platforms[0](),
                                                       CL_PRINTF_CALLBACK_ARM,
                                                       (cl_context_properties)printf_callback,
                                                       CL_PRINTF_BUFFERSIZE_ARM,
                                                       0x1000000,
                                                       0};
  context_ =
    new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, ctx_properties.data(), nullptr, nullptr, &ret);
#else
  context_ = new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, nullptr, nullptr, nullptr, &ret);
#endif
#endif
  if (ret != CL_SUCCESS) {
    delete device_;
@ -201,9 +213,10 @@ int OpenCLRuntime::Init() {
  MS_LOG(INFO) << "Compute Unit: " << compute_units_;
  MS_LOG(INFO) << "Clock Frequency: " << max_freq_ << " MHz";

+#ifdef Debug
+  const cl_command_queue_properties properties = CL_QUEUE_PROFILING_ENABLE;
+#else
  const cl_command_queue_properties properties = 0;
-#if MS_OPENCL_PROFILE
-  properties |= CL_QUEUE_PROFILING_ENABLE;
#endif

  default_command_queue_ = new (std::nothrow) cl::CommandQueue(*context_, *device_, properties, &ret);
@ -412,7 +425,7 @@ int OpenCLRuntime::RunKernel(const cl::Kernel &kernel, const std::vector<size_t>
  }
  cnt++;
  MS_LOG(DEBUG) << "RunKernel success!";
-#if MS_OPENCL_PROFILE
+#ifdef Debug
  event.wait();
  cl_ulong time_start;
  cl_ulong time_end;
@ -445,7 +458,7 @@ int OpenCLRuntime::RunKernel(const cl::Kernel &kernel, const cl::NDRange &global
  }
  cnt++;
  MS_LOG(DEBUG) << "RunKernel success!";
-#if MS_OPENCL_PROFILE
+#ifdef Debug
  event.wait();
  cl_ulong time_start;
  cl_ulong time_end;
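The profiling block behind this flag is truncated in the hunk; with CL_QUEUE_PROFILING_ENABLE set on the queue, the standard OpenCL C++ pattern for filling time_start/time_end from a kernel event looks roughly like this (a sketch, not the file's exact code):

#include <CL/cl2.hpp>

// Assumes 'event' came from enqueueNDRangeKernel on a queue created with
// CL_QUEUE_PROFILING_ENABLE, as in the hunk above; timestamps are nanoseconds.
double KernelMilliseconds(cl::Event &event) {
  event.wait();  // ensure the kernel has finished before reading timestamps
  cl_ulong time_start = event.getProfilingInfo<CL_PROFILING_COMMAND_START>();
  cl_ulong time_end = event.getProfilingInfo<CL_PROFILING_COMMAND_END>();
  return (time_end - time_start) * 1e-6;
}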
@ -0,0 +1,283 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <cstring>
#include <iostream>
#include <memory>
#include "src/common/log_adapter.h"
#include "common/common_test.h"
#include "src/runtime/kernel/opencl/utils.h"
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.h"

namespace mindspore {
class TestArgMinMaxOpenCL : public mindspore::CommonTest {
 public:
  TestArgMinMaxOpenCL() {}
};
template <typename T>
void test_main_argminmax(void *input_data, void *correct_data, const std::vector<int> &input_shape,
                         const std::vector<int> &output_shape, ArgMinMaxParameter *param, TypeId data_type,
                         schema::Format format) {
  MS_LOG(INFO) << " begin test ";
  auto ocl_runtime_wrap = lite::opencl::OpenCLRuntimeWrapper();
  auto ocl_runtime = ocl_runtime_wrap.GetInstance();
  ocl_runtime->Init();
  auto allocator = ocl_runtime->GetAllocator();

  auto tensor_a = lite::Tensor(TypeId(data_type), input_shape, format);
  auto tensor_c = lite::Tensor(TypeId(data_type), output_shape, format);
  std::vector<lite::Tensor *> inputs{&tensor_a};
  std::vector<lite::Tensor *> outputs{&tensor_c};
  size_t input_size = tensor_a.Size();

  auto *pkernel =
    new (std::nothrow) kernel::ArgMinMaxOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
  if (pkernel == nullptr) {
    MS_LOG(INFO) << "new ArgMinMaxOpenCLKernel failed ";
    return;
  }
  pkernel->Init();

  // allocate memory for the input tensors
  for (auto &input_tensor : inputs) {
    input_tensor->MallocData(allocator);
  }

  MS_LOG(INFO) << " initialize sub_graph ";
  std::vector<kernel::LiteKernel *> kernels{pkernel};
  auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
  if (sub_graph == nullptr) {
    delete pkernel;
    MS_LOG(INFO) << " new SubGraphOpenCLKernel failed ";
    return;
  }
  sub_graph->Init();

  MS_LOG(INFO) << " init tensors ";
  T *input_ptr = reinterpret_cast<T *>(inputs[0]->MutableData());
  memcpy(input_ptr, input_data, input_size);
  std::cout << "==================input data================" << std::endl;
  for (auto i = 0; i < inputs[0]->ElementsNum(); ++i) {
    std::cout << input_ptr[i] << ", ";
  }
  std::cout << std::endl;

  sub_graph->Run();

  auto *output_data = reinterpret_cast<T *>(outputs[0]->MutableData());
  std::cout << "==================output data================" << std::endl;
  for (auto i = 0; i < outputs[0]->ElementsNum(); ++i) {
    std::cout << output_data[i] << ", ";
  }
  std::cout << std::endl;
  std::cout << "==================correct data================" << std::endl;
  for (auto i = 0; i < outputs[0]->ElementsNum(); ++i) {
    std::cout << static_cast<T *>(correct_data)[i] << ", ";
  }
  std::cout << std::endl;
  CommonTest::CompareOutputData<T>(output_data, static_cast<T *>(correct_data), outputs[0]->ElementsNum(), 0.0001);
  delete sub_graph;
}
TEST_F(TestArgMinMaxOpenCL, axis0topk2index) {
  ArgMinMaxParameter *param = std::make_unique<ArgMinMaxParameter>().release();
  if (param == nullptr) {
    return;
  }
  std::vector<float> in_data = {100, 2, 4, 50, 11, 12, 34, 35, 10, 20, 40, 5,
                                7, 80, 10, 11, 55, 25, 5, 15, 18, 8, 15, 16};
  std::vector<float> expect_out = {0, 2, 1, 0, 2, 1, 0, 0, 2, 1, 2, 2, 0, 0, 2, 2};
  param->dims_size_ = 4;
  param->axis_ = 0;
  param->topk_ = 2;
  param->get_max_ = true;
  param->out_value_ = false;
  std::vector<int> in_shape = {3, 2, 2, 2};
  std::vector<int> out_shape = {2, 2, 2, 2};

  TypeId data_type = kNumberTypeFloat32;
  schema::Format format = schema::Format_NHWC;
  test_main_argminmax<float>(in_data.data(), expect_out.data(), in_shape, out_shape, param, data_type, format);
}
TEST_F(TestArgMinMaxOpenCL, axis0topk2value) {
  ArgMinMaxParameter *param = std::make_unique<ArgMinMaxParameter>().release();
  if (param == nullptr) {
    return;
  }
  std::vector<float> in_data = {100, 2, 4, 50, 11, 12, 34, 35, 10, 20, 40, 5,
                                7, 80, 10, 11, 55, 25, 5, 15, 18, 8, 15, 16};
  std::vector<float> expect_out = {100, 25, 40, 50, 18, 80, 34, 35, 55, 20, 5, 15, 11, 12, 15, 16};
  param->dims_size_ = 4;
  param->axis_ = 0;
  param->topk_ = 2;
  param->get_max_ = true;
  param->out_value_ = true;
  std::vector<int> in_shape = {3, 2, 2, 2};
  std::vector<int> out_shape = {2, 2, 2, 2};

  TypeId data_type = kNumberTypeFloat32;
  schema::Format format = schema::Format_NHWC;
  test_main_argminmax<float>(in_data.data(), expect_out.data(), in_shape, out_shape, param, data_type, format);
}
TEST_F(TestArgMinMaxOpenCL, axis1topk2index) {
  ArgMinMaxParameter *param = std::make_unique<ArgMinMaxParameter>().release();
  if (param == nullptr) {
    return;
  }
  std::vector<float> in_data = {100, 2, 200, 4, 50, 6, 11, 12, 13, 34, 35, 36, 9, 6, 17, 10, 20, 30,
                                10, 20, 30, 40, 5, 60, 7, 80, 90, 10, 11, 120, 18, 5, 16, 9, 22, 23};
  std::vector<float> expect_out = {0, 1, 0, 1, 0, 1, 1, 2, 2, 2, 1, 2, 2, 1, 1, 0, 2, 1, 0, 0, 0, 1, 1, 0};
  param->dims_size_ = 4;
  param->axis_ = 1;
  param->topk_ = 2;
  param->get_max_ = true;
  param->out_value_ = false;
  std::vector<int> in_shape = {2, 3, 2, 3};
  std::vector<int> out_shape = {2, 2, 2, 3};

  TypeId data_type = kNumberTypeFloat32;
  schema::Format format = schema::Format_NHWC;
  test_main_argminmax<float>(in_data.data(), expect_out.data(), in_shape, out_shape, param, data_type, format);
}
TEST_F(TestArgMinMaxOpenCL, axis1topk2value) {
  ArgMinMaxParameter *param = std::make_unique<ArgMinMaxParameter>().release();
  if (param == nullptr) {
    return;
  }
  std::vector<float> in_data = {100, 2, 200, 4, 50, 6, 11, 12, 13, 34, 35, 36, 9, 6, 17, 10, 20, 30,
                                10, 20, 30, 40, 5, 60, 7, 80, 90, 10, 11, 120, 18, 5, 16, 9, 22, 23};
  std::vector<float> expect_out = {100, 12, 200, 34, 50, 36, 11, 6, 17, 10, 35, 30,
                                   18, 80, 90, 40, 22, 120, 10, 20, 30, 10, 11, 60};
  param->dims_size_ = 4;
  param->axis_ = 1;
  param->topk_ = 2;
  param->get_max_ = true;
  param->out_value_ = true;
  std::vector<int> in_shape = {2, 3, 2, 3};
  std::vector<int> out_shape = {2, 2, 2, 3};

  TypeId data_type = kNumberTypeFloat32;
  schema::Format format = schema::Format_NHWC;
  test_main_argminmax<float>(in_data.data(), expect_out.data(), in_shape, out_shape, param, data_type, format);
}
TEST_F(TestArgMinMaxOpenCL, axis2topk1index) {
  ArgMinMaxParameter *param = std::make_unique<ArgMinMaxParameter>().release();
  if (param == nullptr) {
    return;
  }
  param->dims_size_ = 4;
  param->axis_ = 2;
  param->topk_ = 1;
  param->get_max_ = true;
  param->out_value_ = false;
  std::vector<float> in_data = {10, 20, 30, 11, 15, 10, 5, 10, 12, 10, 20, 30, 11, 15, 10, 5, 10, 12,
                                10, 20, 30, 11, 15, 10, 5, 10, 12, 10, 20, 30, 11, 15, 10, 5, 10, 12,
                                10, 20, 30, 11, 15, 10, 5, 10, 12, 10, 20, 30, 11, 15, 10, 5, 10, 12};
  std::vector<float> expect_out = {1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0};
  std::vector<int> in_shape = {2, 3, 3, 3};
  std::vector<int> out_shape = {2, 3, 1, 3};

  TypeId data_type = kNumberTypeFloat32;
  schema::Format format = schema::Format_NHWC;
  test_main_argminmax<float>(in_data.data(), expect_out.data(), in_shape, out_shape, param, data_type, format);
}
TEST_F(TestArgMinMaxOpenCL, axis2topk2value) {
  ArgMinMaxParameter *param = std::make_unique<ArgMinMaxParameter>().release();
  if (param == nullptr) {
    return;
  }
  std::vector<float> in_data = {10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90,
                                20, 11, 15, 1, 50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50,
                                30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, 30};
  std::vector<float> expect_out = {30, 45, 30, 50, 90, 20, 20, 25, 40, 50, 30, 45, 30, 50, 90, 20, 20, 25, 40, 50,
                                   30, 45, 30, 50, 90, 20, 20, 25, 40, 50, 30, 45, 30, 50, 90, 20, 20, 25, 40, 50};
  param->dims_size_ = 4;
  param->axis_ = 2;
  param->topk_ = 2;
  param->get_max_ = true;
  param->out_value_ = true;
  std::vector<int> in_shape = {2, 2, 3, 5};
  std::vector<int> out_shape = {2, 2, 2, 5};

  TypeId data_type = kNumberTypeFloat32;
  schema::Format format = schema::Format_NHWC;
  test_main_argminmax<float>(in_data.data(), expect_out.data(), in_shape, out_shape, param, data_type, format);
}
TEST_F(TestArgMinMaxOpenCL, axis2topk2index) {
  ArgMinMaxParameter *param = std::make_unique<ArgMinMaxParameter>().release();
  if (param == nullptr) {
    return;
  }
  std::vector<float> in_data = {10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90,
                                20, 11, 15, 1, 50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50,
                                30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, 30};
  std::vector<float> expect_out = {2, 2, 0, 2, 0, 1, 0, 2, 0, 1, 2, 2, 0, 2, 0, 1, 0, 2, 0, 1,
                                   2, 2, 0, 2, 0, 1, 0, 2, 0, 1, 2, 2, 0, 2, 0, 1, 0, 2, 0, 1};
  param->dims_size_ = 4;
  param->axis_ = 2;
  param->topk_ = 2;
  param->get_max_ = true;
  param->out_value_ = false;
  std::vector<int> in_shape = {2, 2, 3, 5};
  std::vector<int> out_shape = {2, 2, 2, 5};

  TypeId data_type = kNumberTypeFloat32;
  schema::Format format = schema::Format_NHWC;
  test_main_argminmax<float>(in_data.data(), expect_out.data(), in_shape, out_shape, param, data_type, format);
}
TEST_F(TestArgMinMaxOpenCL, axis3topk2index) {
  ArgMinMaxParameter *param = std::make_unique<ArgMinMaxParameter>().release();
  if (param == nullptr) {
    return;
  }
  std::vector<float> in_data = {10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90,
                                20, 11, 15, 1, 50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50,
                                30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, 30};
  std::vector<float> expect_out = {4, 3, 4, 0, 3, 1, 4, 3, 4, 0, 3, 1, 4, 3, 4, 0, 3, 1, 4, 3, 4, 0, 3, 1};
  param->dims_size_ = 4;
  param->axis_ = 3;
  param->topk_ = 2;
  param->get_max_ = true;
  param->out_value_ = false;
  std::vector<int> in_shape = {2, 2, 3, 5};
  std::vector<int> out_shape = {2, 2, 3, 2};

  TypeId data_type = kNumberTypeFloat32;
  schema::Format format = schema::Format_NHWC;
  test_main_argminmax<float>(in_data.data(), expect_out.data(), in_shape, out_shape, param, data_type, format);
}
TEST_F(TestArgMinMaxOpenCL, axis3topk2value) {
  ArgMinMaxParameter *param = std::make_unique<ArgMinMaxParameter>().release();
  if (param == nullptr) {
    return;
  }
  std::vector<float> in_data = {10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90,
                                20, 11, 15, 1, 50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50,
                                30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, 30};
  std::vector<float> expect_out = {90, 40, 50, 20, 50, 45, 90, 40, 50, 20, 50, 45,
                                   90, 40, 50, 20, 50, 45, 90, 40, 50, 20, 50, 45};
  param->dims_size_ = 4;
  param->axis_ = 3;
  param->topk_ = 2;
  param->get_max_ = true;
  param->out_value_ = true;
  std::vector<int> in_shape = {2, 2, 3, 5};
  std::vector<int> out_shape = {2, 2, 3, 2};

  TypeId data_type = kNumberTypeFloat32;
  schema::Format format = schema::Format_NHWC;
  test_main_argminmax<float>(in_data.data(), expect_out.data(), in_shape, out_shape, param, data_type, format);
}
} // namespace mindspore