forked from mindspore-Ecosystem/mindspore
add opencl leaky relu kernel
This commit is contained in:
parent
290c93a9a7
commit
2eb55946f6
|
@ -0,0 +1,27 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_LEAKYRELU_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_LEAKYRELU_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
typedef struct LeakyReluParameter {
|
||||
OpParameter op_parameter_;
|
||||
float alpha;
|
||||
} LeakyReluParameter;
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_LEAKYRELU_H_
|
|
@ -16,7 +16,6 @@ __kernel void LeakyRelu(__read_only image2d_t input, __write_only image2d_t outp
|
|||
|
||||
int Y = get_global_id(0); // height id
|
||||
int X = get_global_id(1); // weight id
|
||||
|
||||
for (int num = 0; num < UP_DIV(C, SLICES); ++num) {
|
||||
FLT4 in_c4 = READ_FLT4(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC
|
||||
FLT4 tmp;
|
||||
|
|
|
@ -175,6 +175,10 @@ kernel::LiteKernel *OpenCLConv2dTransposeKernelCreator(const std::vector<lite::t
|
|||
const kernel::KernelKey &desc,
|
||||
const lite::Primitive *primitive) {
|
||||
auto *kernel = new Conv2dTransposeOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (0 != ret) {
|
||||
// MS_LOG(ERROR) << "Init kernel failed, name: " << opDef.name()->str()
|
||||
|
|
|
@ -194,6 +194,10 @@ kernel::LiteKernel *OpenCLDepthwiseConv2dKernelCreator(const std::vector<lite::t
|
|||
const kernel::KernelKey &desc,
|
||||
const lite::Primitive *primitive) {
|
||||
auto *kernel = new DepthwiseConv2dOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (0 != ret) {
|
||||
MS_LOG(ERROR) << "Init DepthwiseConv2dOpenCLKernel failed!";
|
||||
|
|
|
@ -18,97 +18,105 @@
|
|||
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/runtime/kernel/opencl/kernel/leaky_relu.h"
|
||||
#include "src/runtime/opencl/opencl_runtime.h"
|
||||
#include "src/runtime/kernel/opencl/cl/fp32/leaky_relu.cl.inc"
|
||||
#include "src/runtime/kernel/arm/nnacl/leaky_relu_parameter.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kGPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_LeakyReLU;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
int LeakyReluOpenCLKernel::Init() {
|
||||
if (inputs_[0]->shape().size() != 4) {
|
||||
MS_LOG(ERROR) << "leaky_relu only support dim=4, but your dim=" << inputs_[0]->shape().size();
|
||||
}
|
||||
std::set<std::string> build_options;
|
||||
std::string source = leaky_relu_source_fp32;
|
||||
std::string program_name = "LeakyRelu";
|
||||
std::string kernel_name = "LeakyRelu";
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
ocl_runtime->LoadSource(program_name, source);
|
||||
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
|
||||
|
||||
MS_LOG(DEBUG) << kernel_name << " Init Done!";
|
||||
return RET_OK;
|
||||
int LeakyReluOpenCLKernel::Init() {
|
||||
if (in_tensors_[0]->shape().size() != 4) {
|
||||
MS_LOG(ERROR) << "leaky_relu only support dim=4, but your dim=" << in_tensors_[0]->shape().size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::set<std::string> build_options;
|
||||
std::string source = leaky_relu_source_fp32;
|
||||
std::string program_name = "LeakyRelu";
|
||||
std::string kernel_name = "LeakyRelu";
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
ocl_runtime->LoadSource(program_name, source);
|
||||
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
|
||||
|
||||
MS_LOG(DEBUG) << kernel_name << " Init Done!";
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LeakyReluOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
|
||||
int H = inputs_[0]->shape()[1];
|
||||
int W = inputs_[0]->shape()[2];
|
||||
int C = inputs_[0]->shape()[3];
|
||||
int LeakyReluOpenCLKernel::Run() {
|
||||
auto param = reinterpret_cast<LeakyReluParameter *>(op_parameter_);
|
||||
MS_LOG(DEBUG) << " Running!";
|
||||
int N = in_tensors_[0]->shape()[0];
|
||||
int H = in_tensors_[0]->shape()[1];
|
||||
int W = in_tensors_[0]->shape()[2];
|
||||
int C = in_tensors_[0]->shape()[3];
|
||||
cl_int4 input_shape = {N, H, W, C};
|
||||
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
int arg_idx = 0;
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data());
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data());
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_shape);
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, param->alpha);
|
||||
|
||||
std::vector<size_t> local = {1, 1};
|
||||
std::vector<size_t> global = {static_cast<size_t>(H), static_cast<size_t>(W)};
|
||||
ocl_runtime->RunKernel(kernel_, global, local, nullptr);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LeakyReluOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
|
||||
int H = in_tensors_[0]->shape()[1];
|
||||
int W = in_tensors_[0]->shape()[2];
|
||||
int C = in_tensors_[0]->shape()[3];
|
||||
|
||||
#ifdef ENABLE_FP16
|
||||
size_t img_dtype = CL_HALF_FLOAT;
|
||||
size_t img_dtype = CL_HALF_FLOAT;
|
||||
#else
|
||||
size_t img_dtype = CL_FLOAT;
|
||||
size_t img_dtype = CL_FLOAT;
|
||||
#endif
|
||||
|
||||
img_size->clear();
|
||||
img_size->push_back(W * UP_DIV(C, C4NUM));
|
||||
img_size->push_back(H);
|
||||
img_size->push_back(img_dtype);
|
||||
return RET_OK;
|
||||
}
|
||||
img_size->clear();
|
||||
img_size->push_back(W * UP_DIV(C, C4NUM));
|
||||
img_size->push_back(H);
|
||||
img_size->push_back(img_dtype);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LeakyReluOpenCLKernel::Run() {
|
||||
auto param = reinterpret_cast<LeakyReluParameter *>(this->opParameter);
|
||||
MS_LOG(DEBUG) << this->Name() << " Running!";
|
||||
int N = inputs_[0]->shape()[0];
|
||||
int H = inputs_[0]->shape()[1];
|
||||
int W = inputs_[0]->shape()[2];
|
||||
int C = inputs_[0]->shape()[3];
|
||||
cl_int4 input_shape = {N, H, W, C};
|
||||
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
int arg_idx = 0;
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, inputs_[0]->Data());
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, outputs_[0]->Data());
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_shape);
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, param->alpha);
|
||||
|
||||
std::vector<size_t> local = {1, 1};
|
||||
std::vector<size_t> global = {static_cast<size_t>(H), static_cast<size_t>(W)};
|
||||
ocl_runtime->RunKernel(kernel_, global, local, nullptr);
|
||||
return 0;
|
||||
kernel::LiteKernel *OpenCLLeakyReluKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs,
|
||||
OpParameter *opParameter, const lite::Context *ctx,
|
||||
const kernel::KernelKey &desc, const lite::Primitive *primitive) {
|
||||
if (inputs.size() == 0) {
|
||||
MS_LOG(ERROR) << "Input data size must be greater than 0, but your size is " << inputs.size();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
kernel::LiteKernel *OpenCLLeakyReluKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs,
|
||||
OpParameter *opParameter, const lite::Context *ctx,
|
||||
const kernel::KernelKey &desc, const lite::Primitive *primitive) {
|
||||
auto *kernel = new LeakyReluOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
|
||||
if (inputs.size() == 0) {
|
||||
MS_LOG(ERROR) << "Input data size must must be greater than 0, but your size is " << inputs.size();
|
||||
}
|
||||
if (inputs[0]->shape()[0] > 1) {
|
||||
MS_LOG(ERROR) << "Init `leaky relu` kernel failed: Unsupported multi-batch.";
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (0 != ret) {
|
||||
MS_LOG(ERROR) << "Init `Leaky Relu` kernel failed!";
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
return kernel;
|
||||
if (inputs[0]->shape()[0] > 1) {
|
||||
MS_LOG(ERROR) << "Init `leaky relu` kernel failed: Unsupported multi-batch.";
|
||||
return nullptr;
|
||||
}
|
||||
auto *kernel = new LeakyReluOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init `Leaky Relu` kernel failed!";
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_LeakyReLU, OpenCLLeakyReluKernelCreator)
|
||||
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_LeakyReLU, OpenCLLeakyReluKernelCreator)
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
|
|
|
@ -14,18 +14,15 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_LEAKYRELU_H_
|
||||
#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_LEAKYRELU_H_
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_LEAKYRELU_H
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_LEAKYRELU_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "src/runtime/opencl/opencl_runtime.h"
|
||||
#include <string>
|
||||
#include "src/ir/tensor.h"
|
||||
#include "src/runtime/kernel/opencl/opencl_kernel.h"
|
||||
|
||||
struct LeakyReluParameter {
|
||||
OpParameter op_parameter_;
|
||||
cl_float alpha;
|
||||
};
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/runtime/opencl/opencl_runtime.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
|
@ -46,4 +43,4 @@ class LeakyReluOpenCLKernel : public OpenCLKernel {
|
|||
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_LEAKYRELU_H_
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_LEAKYRELU_H_
|
||||
|
|
|
@ -161,6 +161,10 @@ kernel::LiteKernel *OpenCLMatMulKernelCreator(const std::vector<lite::tensor::Te
|
|||
hasBias = (reinterpret_cast<MatMulParameter *>(opParameter))->has_bias_;
|
||||
}
|
||||
auto *kernel = new MatMulOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs, hasBias);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (0 != ret) {
|
||||
// MS_LOG(ERROR) << "Init kernel failed, name: " << opDef.name()->str()
|
||||
|
|
|
@ -87,6 +87,10 @@ kernel::LiteKernel *OpenCLSoftMaxKernelCreator(const std::vector<lite::tensor::T
|
|||
OpParameter *opParameter, const lite::Context *ctx,
|
||||
const kernel::KernelKey &desc, const lite::Primitive *primitive) {
|
||||
auto *kernel = new SoftmaxOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
if (inputs[0]->shape()[0] > 1) {
|
||||
MS_LOG(ERROR) << "Init `Softmax` kernel failed: Unsupported multi-batch.";
|
||||
}
|
||||
|
|
|
@ -110,6 +110,10 @@ kernel::LiteKernel *OpenCLTransposeKernelCreator(const std::vector<lite::tensor:
|
|||
OpParameter *opParameter, const lite::Context *ctx,
|
||||
const kernel::KernelKey &desc, const lite::Primitive *primitive) {
|
||||
auto *kernel = new TransposeOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (0 != ret) {
|
||||
delete kernel;
|
||||
|
|
|
@ -142,7 +142,7 @@ if (SUPPORT_GPU)
|
|||
${LITE_DIR}/src/runtime/kernel/opencl/kernel/matmul.cc
|
||||
${LITE_DIR}/src/runtime/kernel/opencl/kernel/softmax.cc
|
||||
${LITE_DIR}/src/runtime/kernel/opencl/kernel/concat.cc
|
||||
# ${LITE_DIR}/src/runtime/kernel/opencl/kernel/leaky_relu.cc
|
||||
${LITE_DIR}/src/runtime/kernel/opencl/kernel/leaky_relu.cc
|
||||
${LITE_DIR}/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc
|
||||
${LITE_DIR}/src/runtime/kernel/opencl/kernel/transpose.cc
|
||||
)
|
||||
|
@ -320,7 +320,7 @@ if (SUPPORT_GPU)
|
|||
${TEST_DIR}/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/opencl/transpose_tests.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/opencl/convolution_tests.cc
|
||||
# ${TEST_DIR}/ut/src/runtime/kernel/opencl/leakyrelu_tests.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/opencl/leakyrelu_tests.cc
|
||||
)
|
||||
endif()
|
||||
|
||||
|
|
|
@ -21,12 +21,14 @@
|
|||
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
|
||||
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
|
||||
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.h"
|
||||
#include "mindspore/lite/src/runtime/kernel/arm/nnacl/leaky_relu_parameter.h"
|
||||
|
||||
using mindspore::kernel::LeakyReluOpenCLKernel;
|
||||
using mindspore::kernel::LiteKernel;
|
||||
using mindspore::kernel::SubGraphOpenCLKernel;
|
||||
|
||||
namespace mindspore {
|
||||
class TestLeakyReluOpenCL : public mindspore::Common {
|
||||
public:
|
||||
TestLeakyReluOpenCL() {}
|
||||
};
|
||||
class TestLeakyReluOpenCL : public mindspore::CommonTest {};
|
||||
|
||||
void LoadDataLeakyRelu(void *dst, size_t dst_size, const std::string &file_path) {
|
||||
if (file_path.empty()) {
|
||||
|
@ -99,7 +101,6 @@ TEST_F(TestLeakyReluOpenCL, LeakyReluFp32_dim4) {
|
|||
LoadDataLeakyRelu(input_tensor->Data(), input_tensor->Size(), in_file);
|
||||
MS_LOG(INFO) << "==================input data================";
|
||||
printf_tensor(inputs[0]);
|
||||
|
||||
sub_graph->Run();
|
||||
|
||||
MS_LOG(INFO) << "==================output data================";
|
||||
|
|
Loading…
Reference in New Issue