!5042 rm caffe_prelu
Merge pull request !5042 from liuzhongkai/delete_caffe_prelu
This commit is contained in:
commit
33fdc43f18
|
@ -1,23 +0,0 @@
|
|||
#pragma OPENCL EXTENSION cl_arm_printf : enable
|
||||
|
||||
#define SLICES 4
|
||||
#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
|
||||
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
|
||||
|
||||
__kernel void CaffePRelu(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape,
|
||||
__global float *alpha) {
|
||||
int C = input_shape.w; // channel size
|
||||
|
||||
int Y = get_global_id(0); // height id
|
||||
int X = get_global_id(1); // weight id
|
||||
for (int num = 0; num < UP_DIV(C, SLICES); ++num) {
|
||||
FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC
|
||||
FLT4 tmp;
|
||||
int index = num * 4;
|
||||
tmp.x = in_c4.x * alpha[index];
|
||||
tmp.y = in_c4.y * alpha[index + 1];
|
||||
tmp.z = in_c4.z * alpha[index + 2];
|
||||
tmp.w = in_c4.w * alpha[index + 3];
|
||||
WRITE_IMAGE(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC
|
||||
}
|
||||
}
|
|
@ -0,0 +1,30 @@
|
|||
#pragma OPENCL EXTENSION cl_arm_printf : enable
|
||||
|
||||
#define SLICES 4
|
||||
#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
|
||||
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
|
||||
|
||||
__kernel void PRelu(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape,
|
||||
__global float *alpha, const int dim) {
|
||||
int C = input_shape.w; // channel size
|
||||
|
||||
int Y = get_global_id(0); // height id
|
||||
int X = get_global_id(1); // weight id
|
||||
for (int num = 0; num < UP_DIV(C, SLICES); ++num) {
|
||||
FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC
|
||||
FLT4 tmp;
|
||||
if (dim == 1) {
|
||||
tmp.x = in_c4.x >= 0 ? in_c4.x : in_c4.x * (*alpha);
|
||||
tmp.y = in_c4.y >= 0 ? in_c4.y : in_c4.y * (*alpha);
|
||||
tmp.z = in_c4.z >= 0 ? in_c4.z : in_c4.z * (*alpha);
|
||||
tmp.w = in_c4.w >= 0 ? in_c4.w : in_c4.w * (*alpha);
|
||||
} else {
|
||||
int index = num * 4;
|
||||
tmp.x = in_c4.x >= 0 ? in_c4.x : in_c4.x * alpha[index];
|
||||
tmp.y = in_c4.y >= 0 ? in_c4.y : in_c4.y * alpha[index + 1];
|
||||
tmp.z = in_c4.z >= 0 ? in_c4.z : in_c4.z * alpha[index + 2];
|
||||
tmp.w = in_c4.w >= 0 ? in_c4.w : in_c4.w * alpha[index + 3];
|
||||
}
|
||||
WRITE_IMAGE(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC
|
||||
}
|
||||
}
|
|
@ -1,152 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
*
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/runtime/kernel/opencl/kernel/caffe_prelu.h"
|
||||
#include "src/runtime/opencl/opencl_runtime.h"
|
||||
#include "src/runtime/kernel/opencl/cl/caffe_prelu.cl.inc"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kGPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_CaffePReLU;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
void CaffePReluOpenCLKernel::CaffeWeight() {
|
||||
int C = in_tensors_[1]->shape()[0];
|
||||
int div_ci = UP_DIV(C, C4NUM);
|
||||
std::cout << div_ci << std::endl;
|
||||
auto allocator = lite::opencl::OpenCLRuntime::GetInstance()->GetAllocator();
|
||||
CaffeWeight_ = reinterpret_cast<FLOAT_t *>(allocator->Malloc(div_ci * C4NUM * sizeof(FLOAT_t)));
|
||||
CaffeWeight_ = reinterpret_cast<FLOAT_t *>(allocator->MapBuffer(CaffeWeight_, CL_MAP_WRITE, nullptr, true));
|
||||
memset(CaffeWeight_, 0x00, div_ci * C4NUM * sizeof(FLOAT_t));
|
||||
auto origin_weight = reinterpret_cast<FLOAT_t *>(in_tensors_[1]->Data());
|
||||
for (int i = 0; i < in_tensors_[1]->ElementsNum(); ++i) {
|
||||
CaffeWeight_[i] = origin_weight[i];
|
||||
}
|
||||
allocator->UnmapBuffer(CaffeWeight_);
|
||||
}
|
||||
|
||||
int CaffePReluOpenCLKernel::Init() {
|
||||
if (in_tensors_[0]->shape().size() != 4) {
|
||||
MS_LOG(ERROR) << "Caffe PRelu only support dim=4, but your dim=" << in_tensors_[0]->shape().size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
CaffeWeight();
|
||||
std::set<std::string> build_options;
|
||||
std::string source = caffe_prelu_source;
|
||||
std::string program_name = "CaffePRelu";
|
||||
std::string kernel_name = "CaffePRelu";
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
ocl_runtime->LoadSource(program_name, source);
|
||||
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
|
||||
in_ori_format_ = in_tensors_[0]->GetFormat();
|
||||
in_tensors_[0]->SetFormat(schema::Format_NHWC4);
|
||||
out_ori_format_ = out_tensors_[0]->GetFormat();
|
||||
out_tensors_[0]->SetFormat(schema::Format_NHWC4);
|
||||
MS_LOG(DEBUG) << program_name << " Init Done!";
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int CaffePReluOpenCLKernel::Run() {
|
||||
int N = in_tensors_[0]->shape()[0];
|
||||
int H = in_tensors_[0]->shape()[1];
|
||||
int W = in_tensors_[0]->shape()[2];
|
||||
int C = in_tensors_[0]->shape()[3];
|
||||
|
||||
cl_int4 input_shape = {N, H, W, C};
|
||||
int C_Weight = in_tensors_[1]->shape()[0];
|
||||
if (UP_DIV(C_Weight, C4NUM) != UP_DIV(C, C4NUM)) {
|
||||
MS_LOG(ERROR) << "CaffePRelu weight channel size:" << C_Weight
|
||||
<< " must be equal with in_teneors channel size:" << C;
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << op_parameter_->name_ << " Running!";
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
int arg_idx = 0;
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data());
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data());
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_shape);
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, CaffeWeight_);
|
||||
std::vector<size_t> local = {1, 1};
|
||||
std::vector<size_t> global = {static_cast<size_t>(H), static_cast<size_t>(W)};
|
||||
auto ret = ocl_runtime->RunKernel(kernel_, global, local, nullptr);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Run kernel " << op_parameter_->name_ << " error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int CaffePReluOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
|
||||
int H = in_tensors_[0]->shape()[1];
|
||||
int W = in_tensors_[0]->shape()[2];
|
||||
int C = in_tensors_[0]->shape()[3];
|
||||
|
||||
#ifdef ENABLE_FP16
|
||||
size_t img_dtype = CL_HALF_FLOAT;
|
||||
#else
|
||||
size_t img_dtype = CL_FLOAT;
|
||||
#endif
|
||||
|
||||
img_size->clear();
|
||||
img_size->push_back(W * UP_DIV(C, C4NUM));
|
||||
img_size->push_back(H);
|
||||
img_size->push_back(img_dtype);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
kernel::LiteKernel *OpenCLCaffePReluKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs,
|
||||
OpParameter *opParameter, const lite::Context *ctx,
|
||||
const kernel::KernelKey &desc,
|
||||
const mindspore::lite::PrimitiveC *primitive) {
|
||||
if (inputs.size() == 0) {
|
||||
MS_LOG(ERROR) << "Input data size must be greater than 0, but your size is " << inputs.size();
|
||||
return nullptr;
|
||||
}
|
||||
if (inputs[0]->shape()[0] > 1) {
|
||||
MS_LOG(ERROR) << "Init CaffePRelu kernel failed: Unsupported multi-batch.";
|
||||
return nullptr;
|
||||
}
|
||||
auto *kernel =
|
||||
new (std::nothrow) CaffePReluOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "Kernel " << opParameter->name_ << "is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto ret = kernel->Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init CaffePRelu kernel failed!";
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_CaffePReLU, OpenCLCaffePReluKernelCreator)
|
||||
} // namespace mindspore::kernel
|
|
@ -1,49 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_CAFFEPRELU_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_CAFFEPRELU_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "src/ir/tensor.h"
|
||||
#include "src/runtime/kernel/opencl/opencl_kernel.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/runtime/opencl/opencl_runtime.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
class CaffePReluOpenCLKernel : public OpenCLKernel {
|
||||
public:
|
||||
explicit CaffePReluOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs)
|
||||
: OpenCLKernel(parameter, inputs, outputs) {}
|
||||
~CaffePReluOpenCLKernel() override{};
|
||||
|
||||
int Init() override;
|
||||
int Run() override;
|
||||
int GetImageSize(size_t idx, std::vector<size_t> *img_size) override;
|
||||
void CaffeWeight();
|
||||
|
||||
private:
|
||||
cl::Kernel kernel_;
|
||||
FLOAT_t *CaffeWeight_;
|
||||
};
|
||||
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_CAFFEPRELU_H_
|
|
@ -23,8 +23,7 @@
|
|||
#include "include/errorcode.h"
|
||||
#include "src/runtime/kernel/opencl/kernel/prelu.h"
|
||||
#include "src/runtime/opencl/opencl_runtime.h"
|
||||
#include "src/runtime/kernel/opencl/cl/activation.cl.inc"
|
||||
#include "nnacl/prelu_parameter.h"
|
||||
#include "src/runtime/kernel/opencl/cl/prelu.cl.inc"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kGPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
|
@ -34,15 +33,41 @@ using mindspore::schema::PrimitiveType_Prelu;
|
|||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
void PReluOpenCLKernel::InitBuffer() {
|
||||
int C = in_tensors_[1]->shape()[0];
|
||||
int div_ci = UP_DIV(C, C4NUM);
|
||||
std::cout << div_ci << std::endl;
|
||||
auto allocator = lite::opencl::OpenCLRuntime::GetInstance()->GetAllocator();
|
||||
PReluWeight_ = reinterpret_cast<FLOAT_t *>(allocator->Malloc(div_ci * C4NUM * sizeof(FLOAT_t)));
|
||||
PReluWeight_ = reinterpret_cast<FLOAT_t *>(allocator->MapBuffer(PReluWeight_, CL_MAP_WRITE, nullptr, true));
|
||||
memset(PReluWeight_, 0x00, div_ci * C4NUM * sizeof(FLOAT_t));
|
||||
auto origin_weight = reinterpret_cast<FLOAT_t *>(in_tensors_[1]->Data());
|
||||
for (int i = 0; i < in_tensors_[1]->ElementsNum(); ++i) {
|
||||
PReluWeight_[i] = origin_weight[i];
|
||||
}
|
||||
allocator->UnmapBuffer(PReluWeight_);
|
||||
}
|
||||
|
||||
int PReluOpenCLKernel::Init() {
|
||||
if (in_tensors_[0]->shape().size() != 4) {
|
||||
MS_LOG(ERROR) << "PRelu only support dim=4, but your dim=" << in_tensors_[0]->shape().size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
int C_Weight = in_tensors_[1]->shape()[0];
|
||||
int C = in_tensors_[0]->shape()[3];
|
||||
if (C_Weight != 1 && UP_DIV(C_Weight, C4NUM) != UP_DIV(C, C4NUM)) {
|
||||
MS_LOG(ERROR)
|
||||
<< "PRelu weight channel size must be 1 or must be equal with in_teneors channel size, but your weight size is "
|
||||
<< C_Weight << " and your input channel size is " << C;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (C_Weight != 1) {
|
||||
InitBuffer();
|
||||
}
|
||||
std::set<std::string> build_options;
|
||||
std::string source = activation_source;
|
||||
std::string source = prelu_source;
|
||||
std::string program_name = "PRelu";
|
||||
std::string kernel_name = "ReluScalar";
|
||||
std::string kernel_name = "PRelu";
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
ocl_runtime->LoadSource(program_name, source);
|
||||
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
|
||||
|
@ -61,17 +86,18 @@ int PReluOpenCLKernel::Run() {
|
|||
int W = in_tensors_[0]->shape()[2];
|
||||
int C = in_tensors_[0]->shape()[3];
|
||||
cl_int4 input_shape = {N, H, W, C};
|
||||
if (in_tensors_[1]->ElementsNum() < 1) {
|
||||
MS_LOG(ERROR) << "PRelu weight size must be greater than 1! But your weight size is "
|
||||
<< in_tensors_[1]->ElementsNum();
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
int arg_idx = 0;
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data());
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data());
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_shape);
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, reinterpret_cast<float *>(in_tensors_[1]->Data())[0]);
|
||||
if (in_tensors_[1]->shape()[0] == 1) {
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, reinterpret_cast<float *>(in_tensors_[1]->Data()));
|
||||
} else {
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, PReluWeight_);
|
||||
}
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, reinterpret_cast<int>(in_tensors_[1]->shape()[0]));
|
||||
|
||||
std::vector<size_t> local = {1, 1};
|
||||
std::vector<size_t> global = {static_cast<size_t>(H), static_cast<size_t>(W)};
|
||||
|
|
|
@ -36,9 +36,11 @@ class PReluOpenCLKernel : public OpenCLKernel {
|
|||
int Init() override;
|
||||
int Run() override;
|
||||
int GetImageSize(size_t idx, std::vector<size_t> *img_size) override;
|
||||
void InitBuffer();
|
||||
|
||||
private:
|
||||
cl::Kernel kernel_;
|
||||
FLOAT_t *PReluWeight_;
|
||||
};
|
||||
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -155,7 +155,6 @@ if (SUPPORT_GPU)
|
|||
${LITE_DIR}/src/runtime/kernel/opencl/kernel/transpose.cc
|
||||
${LITE_DIR}/src/runtime/kernel/opencl/kernel/reshape.cc
|
||||
${LITE_DIR}/src/runtime/kernel/opencl/kernel/to_format.cc
|
||||
${LITE_DIR}/src/runtime/kernel/opencl/kernel/caffe_prelu.cc
|
||||
${LITE_DIR}/src/runtime/kernel/opencl/kernel/prelu.cc
|
||||
${LITE_DIR}/src/runtime/kernel/opencl/kernel/to_format.cc
|
||||
${LITE_DIR}/src/runtime/kernel/opencl/kernel/biasadd.cc
|
||||
|
@ -336,7 +335,6 @@ if (SUPPORT_GPU)
|
|||
${TEST_DIR}/ut/src/runtime/kernel/opencl/convolution_tests.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/opencl/activation_tests.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/opencl/to_format_tests.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/opencl/caffe_prelu_tests.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/opencl/prelu_tests.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/opencl/reshape_tests.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/opencl/biasadd_tests.cc
|
||||
|
|
|
@ -1,203 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <iostream>
|
||||
|
||||
#include "utils/log_adapter.h"
|
||||
#include "common/common_test.h"
|
||||
#include "mindspore/lite/src/common/file_utils.h"
|
||||
#include "nnacl/pack.h"
|
||||
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
|
||||
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
|
||||
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/caffe_prelu.h"
|
||||
#include "mindspore/lite/nnacl/prelu_parameter.h"
|
||||
|
||||
using mindspore::kernel::CaffePReluOpenCLKernel;
|
||||
using mindspore::kernel::LiteKernel;
|
||||
using mindspore::kernel::SubGraphOpenCLKernel;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
|
||||
namespace mindspore {
|
||||
class TestCaffePReluOpenCL : public mindspore::CommonTest {};
|
||||
|
||||
void LoadDataCaffePRelu(void *dst, size_t dst_size, const std::string &file_path) {
|
||||
if (file_path.empty()) {
|
||||
memset(dst, 0x00, dst_size);
|
||||
} else {
|
||||
auto src_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &dst_size));
|
||||
memcpy(dst, src_data, dst_size);
|
||||
}
|
||||
}
|
||||
|
||||
void CompareOutCaffePRelu(lite::tensor::Tensor *output_tensor, const std::string &standard_answer_file) {
|
||||
auto *output_data = reinterpret_cast<float *>(output_tensor->Data());
|
||||
size_t output_size = output_tensor->ElementsC4Num();
|
||||
auto expect_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(standard_answer_file.c_str(), &output_size));
|
||||
constexpr float atol = 0.0002;
|
||||
for (int i = 0; i < output_tensor->ElementsC4Num(); ++i) {
|
||||
if (std::fabs(output_data[i] - expect_data[i]) > atol) {
|
||||
printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]);
|
||||
printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]);
|
||||
printf("error at idx[%d] expect=%.3f output=%.3f\n\n\n", i, expect_data[i], output_data[i]);
|
||||
return;
|
||||
}
|
||||
}
|
||||
printf("compare success!\n");
|
||||
printf("compare success!\n");
|
||||
printf("compare success!\n\n\n");
|
||||
}
|
||||
|
||||
void printf_tensor_caffeprelu(mindspore::lite::tensor::Tensor *in_data, int size) {
|
||||
auto input_data = reinterpret_cast<float *>(in_data->Data());
|
||||
for (int i = 0; i < size; ++i) {
|
||||
printf("%f ", input_data[i]);
|
||||
}
|
||||
printf("\n");
|
||||
MS_LOG(INFO) << "Print tensor done";
|
||||
}
|
||||
|
||||
void printf_float(float *data, int num = 0) {
|
||||
float *temp = data;
|
||||
for (int i = 0; i < num; ++i) {
|
||||
std::cout << *temp << " ";
|
||||
temp++;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
TEST_F(TestCaffePReluOpenCL, CaffePReluFp32_dim4) {
|
||||
std::string in_file = "/data/local/tmp/in_data.bin";
|
||||
std::string weight_file = "/data/local/tmp/weight_data.bin";
|
||||
std::string standard_answer_file = "/data/local/tmp/caffeprelu.bin";
|
||||
MS_LOG(INFO) << "CaffePRelu Begin test:";
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
ocl_runtime->Init();
|
||||
auto allocator = ocl_runtime->GetAllocator();
|
||||
|
||||
MS_LOG(INFO) << "CaffePRelu init tensors.";
|
||||
|
||||
std::vector<int> input_shape = {1, 4, 3, 9};
|
||||
std::vector<int> output_shape = {1, 4, 3, 9};
|
||||
auto data_type = kNumberTypeFloat32;
|
||||
auto tensor_type = schema::NodeType_ValueNode;
|
||||
auto *input_tensor =
|
||||
new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC, tensor_type);
|
||||
if (input_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "new input tensor error";
|
||||
return;
|
||||
}
|
||||
auto *output_tensor = new lite::tensor::Tensor(data_type, output_shape, schema::Format_NHWC4, tensor_type);
|
||||
if (output_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "new output_tensor error";
|
||||
delete input_tensor;
|
||||
return;
|
||||
}
|
||||
auto *weight_tensor = new (std::nothrow)
|
||||
lite::tensor::Tensor(data_type, std::vector<int>{input_shape[3]}, schema::Format_NHWC, tensor_type);
|
||||
if (weight_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "new weight_tensor error";
|
||||
delete input_tensor;
|
||||
delete output_tensor;
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<lite::tensor::Tensor *> inputs{input_tensor, weight_tensor};
|
||||
std::vector<lite::tensor::Tensor *> outputs{output_tensor};
|
||||
inputs[0]->MallocData(allocator);
|
||||
inputs[1]->MallocData(allocator);
|
||||
std::cout << input_tensor->Size() << std::endl;
|
||||
LoadDataCaffePRelu(input_tensor->Data(), input_tensor->Size(), in_file);
|
||||
MS_LOG(INFO) << "CaffePRelu==================input data================";
|
||||
printf_tensor_caffeprelu(inputs[0], input_tensor->ElementsNum());
|
||||
|
||||
LoadDataCaffePRelu(weight_tensor->Data(), weight_tensor->Size(), weight_file);
|
||||
MS_LOG(INFO) << "CaffePRelu==================weight data================";
|
||||
printf_tensor_caffeprelu(inputs[1], weight_tensor->ElementsNum());
|
||||
|
||||
auto param = new (std::nothrow) PReluParameter();
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "new param error!";
|
||||
delete input_tensor;
|
||||
delete output_tensor;
|
||||
delete weight_tensor;
|
||||
return;
|
||||
}
|
||||
param->channel_num_ = input_shape[3];
|
||||
auto *caffeprelu_kernel =
|
||||
new (std::nothrow) kernel::CaffePReluOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
|
||||
if (caffeprelu_kernel == nullptr) {
|
||||
delete param;
|
||||
delete input_tensor;
|
||||
delete output_tensor;
|
||||
delete weight_tensor;
|
||||
MS_LOG(ERROR) << "Create caffe prelu kernel error.";
|
||||
return;
|
||||
}
|
||||
|
||||
auto ret = caffeprelu_kernel->Init();
|
||||
if (ret != RET_OK) {
|
||||
delete param;
|
||||
delete input_tensor;
|
||||
delete output_tensor;
|
||||
delete weight_tensor;
|
||||
delete caffeprelu_kernel;
|
||||
MS_LOG(ERROR) << "caffeprelu_kernel init error.";
|
||||
return;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "initialize sub_graph";
|
||||
std::vector<kernel::LiteKernel *> kernels{caffeprelu_kernel};
|
||||
auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel({input_tensor}, outputs, kernels, kernels, kernels);
|
||||
if (sub_graph == nullptr) {
|
||||
delete param;
|
||||
delete input_tensor;
|
||||
delete output_tensor;
|
||||
delete weight_tensor;
|
||||
delete caffeprelu_kernel;
|
||||
MS_LOG(ERROR) << "Create sub_graph kernel error.";
|
||||
return;
|
||||
}
|
||||
ret = sub_graph->Init();
|
||||
if (ret != RET_OK) {
|
||||
delete param;
|
||||
delete input_tensor;
|
||||
delete output_tensor;
|
||||
delete weight_tensor;
|
||||
delete caffeprelu_kernel;
|
||||
delete sub_graph;
|
||||
MS_LOG(ERROR) << "sub_graph init error.";
|
||||
return;
|
||||
}
|
||||
MS_LOG(INFO) << "Sub graph begin running!";
|
||||
ret = sub_graph->Run();
|
||||
if (ret != RET_OK) {
|
||||
delete input_tensor;
|
||||
delete output_tensor;
|
||||
delete weight_tensor;
|
||||
delete sub_graph;
|
||||
MS_LOG(ERROR) << "sub_graph run error.";
|
||||
return;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "CaffePRelu==================output data================";
|
||||
printf_tensor_caffeprelu(outputs[0], output_tensor->ElementsC4Num());
|
||||
CompareOutCaffePRelu(output_tensor, standard_answer_file);
|
||||
delete input_tensor;
|
||||
delete output_tensor;
|
||||
delete weight_tensor;
|
||||
delete sub_graph;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -62,26 +62,27 @@ void CompareOutPRelu(lite::tensor::Tensor *output_tensor, const std::string &sta
|
|||
|
||||
TEST_F(TestPReluOpenCL, PReluFp32_dim4) {
|
||||
std::string in_file = "/data/local/tmp/in_data.bin";
|
||||
std::string standard_answer_file = "/data/local/tmp/leaky_relu.bin";
|
||||
std::string weight_file = "/data/local/tmp/weight_data.bin";
|
||||
std::string standard_answer_file = "/data/local/tmp/caffe_prelu.bin";
|
||||
MS_LOG(INFO) << "-------------------->> Begin test PRelu!";
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
ocl_runtime->Init();
|
||||
auto allocator = ocl_runtime->GetAllocator();
|
||||
|
||||
MS_LOG(INFO) << "Init tensors.";
|
||||
std::vector<int> input_shape = {1, 4, 3, 8};
|
||||
std::vector<int> input_shape = {1, 4, 3, 9};
|
||||
|
||||
auto data_type = kNumberTypeFloat32;
|
||||
auto tensor_type = schema::NodeType_ValueNode;
|
||||
auto input_tensor =
|
||||
new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC4, tensor_type);
|
||||
new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC, tensor_type);
|
||||
if (input_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "new input_tensor error!";
|
||||
return;
|
||||
}
|
||||
|
||||
auto output_tensor =
|
||||
new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC4, tensor_type);
|
||||
new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC, tensor_type);
|
||||
if (output_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "new output_tensor error";
|
||||
delete input_tensor;
|
||||
|
@ -89,7 +90,7 @@ TEST_F(TestPReluOpenCL, PReluFp32_dim4) {
|
|||
}
|
||||
|
||||
auto weight_tensor =
|
||||
new (std::nothrow) lite::tensor::Tensor(data_type, std::vector<int>{1}, schema::Format_NHWC, tensor_type);
|
||||
new (std::nothrow) lite::tensor::Tensor(data_type, std::vector<int>{9}, schema::Format_NHWC, tensor_type);
|
||||
if (weight_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "new weight_tensor error";
|
||||
delete input_tensor;
|
||||
|
@ -105,11 +106,13 @@ TEST_F(TestPReluOpenCL, PReluFp32_dim4) {
|
|||
|
||||
MS_LOG(INFO) << "initialize input data";
|
||||
LoadDataPRelu(input_tensor->Data(), input_tensor->Size(), in_file);
|
||||
LoadDataPRelu(weight_tensor->Data(), weight_tensor->Size(), weight_file);
|
||||
auto weight_data = reinterpret_cast<float *>(weight_tensor->Data());
|
||||
weight_data[0] = 0.3;
|
||||
PrintData("Weight data", weight_data, inputs[1]->ElementsNum());
|
||||
auto *input_data = reinterpret_cast<float *>(inputs[0]->Data());
|
||||
PrintData("PRelu input data", input_data, inputs[0]->ElementsC4Num());
|
||||
|
||||
PrintData("PRelu input data", input_data, inputs[0]->ElementsNum());
|
||||
std::cout << inputs[0]->ElementsNum() << std::endl;
|
||||
std::cout << "--------------------------------------------" << std::endl;
|
||||
auto param = new (std::nothrow) PReluParameter();
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "new PreluParameter error";
|
||||
|
|
Loading…
Reference in New Issue