!4446 fix bug in conv2d transpose

Merge pull request !4446 from chenzupeng/master-lite
This commit is contained in:
mindspore-ci-bot 2020-08-14 21:17:13 +08:00 committed by Gitee
commit bf90c73155
4 changed files with 90 additions and 82 deletions

View File

@ -21,7 +21,7 @@ __kernel void conv2d_transpose2x2(__read_only image2d_t src_data, __global FLT16
FLT4 r1 = (FLT4)(0.f);
FLT4 r2 = (FLT4)(0.f);
FLT4 r3 = (FLT4)(0.f);
int base_w = (co * 4 + kh + kw * 2) * src_size.z;
int base_w = (co * 4 + kh * 2 + kw) * src_size.z;
for (int ci = 0; ci < src_size.z; ++ci) {
FLT4 x0 = READ_IMAGE(src_data, smp_zero, (int2)(src_w * src_size.z + ci, src_h));
FLT4 x1 = READ_IMAGE(src_data, smp_zero, (int2)(src_w * src_size.z + ci, src_h + 1));

View File

@ -14,11 +14,11 @@
* limitations under the License.
*/
#include "src/runtime/kernel/opencl/kernel/conv2d_transpose.h"
#include <string>
#include <set>
#include "src/kernel_registry.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/kernel/opencl/kernel/conv2d_transpose.h"
#ifndef PROGRAM_WITH_IL
#include "src/runtime/kernel/opencl/cl/fp16/conv2d_transpose2x2.cl.inc"
#include "src/runtime/kernel/opencl/cl/fp32/conv2d_transpose2x2.cl.inc"
@ -34,11 +34,11 @@ int Conv2dTransposeOpenCLKernel::Init() {
ConvParameter *param = reinterpret_cast<ConvParameter *>(op_parameter_);
if (param->kernel_h_ != 2 || param->kernel_w_ != 2 || param->stride_h_ != 2 || param->stride_w_ != 2) {
MS_LOG(ERROR) << "only support kh=kw=2 and stride_h=stride_w=2.";
return 1;
return RET_ERROR;
}
if (param->pad_h_ >= 2 || param->pad_w_ >= 2) {
MS_LOG(ERROR) << "only support pad in {0,1}.";
return 1;
if (param->pad_h_ != 0 || param->pad_w_ != 0) {
MS_LOG(ERROR) << "only support pad =0.";
return RET_ERROR;
}
std::string kernel_name = "conv2d_transpose2x2";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
@ -55,45 +55,41 @@ int Conv2dTransposeOpenCLKernel::Init() {
ocl_runtime->LoadSource(program_name, source);
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
#endif
int ci = param->input_channel_;
int co = param->output_channel_;
int kh = param->kernel_h_;
int kw = param->kernel_w_;
int div_ci = UP_DIV(ci, 4);
int div_co = UP_DIV(co, 4);
auto allocator = ocl_runtime->GetAllocator();
padWeight_ = reinterpret_cast<FLOAT_T *>(allocator->Malloc(div_ci * div_co * 16 * kh * kw * sizeof(FLOAT_T)));
padWeight_ = reinterpret_cast<FLOAT_T *>(allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true));
PadWeight();
allocator->UnmapBuffer(padWeight_);
out_tensors_[0]->SetFormat(schema::Format_NHWC4);
MS_LOG(DEBUG) << kernel_name << " Init Done!";
return 0;
return RET_OK;
}
int Conv2dTransposeOpenCLKernel::ReSize() { return 0; }
void Conv2dTransposeOpenCLKernel::PadWeight() {
// OHWI to OHWI4(I)4(O)
ConvParameter *param = reinterpret_cast<ConvParameter *>(op_parameter_);
int ci = param->input_channel_;
int co = param->output_channel_;
int ci = in_tensors_[0]->Channel();
int co = out_tensors_[0]->Channel();
int kh = param->kernel_h_;
int kw = param->kernel_w_;
int div_ci = UP_DIV(ci, 4);
int div_co = UP_DIV(co, 4);
auto origin_weight = reinterpret_cast<FLOAT_T *>(in_tensors_.at(kWeightIndex)->Data());
int div_ci = UP_DIV(ci, C4NUM);
int div_co = UP_DIV(co, C4NUM);
auto allocator = lite::opencl::OpenCLRuntime::GetInstance()->GetAllocator();
// IHWO to OHWI4(I)4(O)(converter format is IHWO)
// init padWeight_(buffer mem)
padWeight_ =
reinterpret_cast<FLOAT_t *>(allocator->Malloc(div_ci * div_co * C4NUM * C4NUM * kh * kw * sizeof(FLOAT_t)));
padWeight_ = reinterpret_cast<FLOAT_t *>(allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true));
auto origin_weight = reinterpret_cast<FLOAT_t *>(in_tensors_.at(kWeightIndex)->Data());
int index = 0;
for (int co_i = 0; co_i < div_co; co_i++) {
for (int kw_i = 0; kw_i < kw; kw_i++) {
for (int kh_i = 0; kh_i < kh; kh_i++) {
for (int kh_i = 0; kh_i < kh; kh_i++) {
for (int kw_i = 0; kw_i < kw; kw_i++) {
for (int ci_i = 0; ci_i < div_ci; ci_i++) {
for (int ci4_i = 0; ci4_i < 4; ci4_i++) {
for (int co4_i = 0; co4_i < 4; co4_i++) {
int co_offset = co_i * 4 + co4_i;
int ci_offset = ci_i * 4 + ci4_i;
for (int ci4_i = 0; ci4_i < C4NUM; ci4_i++) {
for (int co4_i = 0; co4_i < C4NUM; co4_i++) {
int co_offset = co_i * C4NUM + co4_i;
int ci_offset = ci_i * C4NUM + ci4_i;
if (co_offset < co && ci_offset < ci) {
int ori_index = ((co_offset * kh + kh_i) * kw + kw_i) * ci + ci_offset;
int ori_index = ((ci_offset * kh + kh_i) * kw + kw_i) * ci + co_offset;
padWeight_[index++] = origin_weight[ori_index];
} else {
padWeight_[index++] = 0.;
@ -104,6 +100,40 @@ void Conv2dTransposeOpenCLKernel::PadWeight() {
}
}
}
allocator->UnmapBuffer(padWeight_);
// init bias_(image2d mem)
size_t im_dst_x, im_dst_y;
im_dst_x = div_co;
im_dst_y = 1;
#ifdef ENABLE_FP16
size_t img_dtype = CL_HALF_FLOAT;
#else
size_t img_dtype = CL_FLOAT;
#endif
std::vector<size_t> img_size{im_dst_x, im_dst_y, img_dtype};
bias_ = reinterpret_cast<FLOAT_t *>(allocator->Malloc(im_dst_x * im_dst_y * C4NUM * sizeof(FLOAT_t), img_size));
bias_ = reinterpret_cast<FLOAT_t *>(allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true));
memset(bias_, 0x00, div_co * C4NUM * sizeof(FLOAT_t));
if (in_tensors_.size() >= 3) {
memcpy(bias_, in_tensors_[2]->Data(), co * sizeof(FLOAT_t));
}
allocator->UnmapBuffer(bias_);
}
int Conv2dTransposeOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
size_t im_dst_x, im_dst_y;
im_dst_x = UP_DIV(out_tensors_[0]->Channel() * out_tensors_[0]->Width(), C4NUM);
im_dst_y = out_tensors_[0]->Height();
#ifdef ENABLE_FP16
size_t img_dtype = CL_HALF_FLOAT;
#else
size_t img_dtype = CL_FLOAT;
#endif
img_size->clear();
std::vector<size_t> vec{im_dst_x, im_dst_y, img_dtype};
*img_size = vec;
return RET_OK;
}
int Conv2dTransposeOpenCLKernel::Run() {
@ -111,37 +141,20 @@ int Conv2dTransposeOpenCLKernel::Run() {
std::vector<int> shapex = in_tensors_[0]->shape();
int n = shapex[0];
if (n > 1) {
MS_LOG(ERROR) << "Conv2dTranspose n > 1 not supported!";
return 1;
MS_LOG(ERROR) << " n > 1 not supported!";
return RET_ERROR;
}
ConvParameter *param = reinterpret_cast<ConvParameter *>(op_parameter_);
int ci = param->input_channel_;
int co = param->output_channel_;
int ci = in_tensors_[0]->Channel();
int co = out_tensors_[0]->Channel();
int kh = param->kernel_h_;
int kw = param->kernel_w_;
int pad = param->pad_h_;
int oh = out_tensors_[0]->shape()[1];
int ow = out_tensors_[0]->shape()[2];
int h = in_tensors_[0]->shape()[1];
int w = in_tensors_[0]->shape()[2];
int oh = out_tensors_[0]->Height();
int ow = out_tensors_[0]->Width();
int h = in_tensors_[0]->Height();
int w = in_tensors_[0]->Width();
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
cl::ImageFormat image_format;
{
image_format.image_channel_order = CL_RGBA;
#ifdef ENABLE_FP16
image_format.image_channel_data_type = CL_HALF_FLOAT;
#else
image_format.image_channel_data_type = CL_FLOAT;
#endif
}
cl_int in_error_code, in_error_code_weight, in_error_code_bias, out_error_code;
cl::Image2D img_x(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, image_format, w * ci / 4, h, 0,
in_tensors_[0]->Data(), &in_error_code);
cl::Image2D img_bias(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, image_format, co / 4, 1, 0,
in_tensors_[2]->Data(), &in_error_code_bias);
cl::Image2D out_mem(*ocl_runtime->Context(), CL_MEM_WRITE_ONLY, image_format, ow * co / 4, oh, 0, nullptr,
&out_error_code);
// local size should less than MAX_GROUP_SIZE
std::vector<size_t> local = {16, 1, 16};
std::vector<size_t> global = {UP_ROUND((size_t)UP_ROUND(oh / 2, 2), local[0]),
@ -150,23 +163,20 @@ int Conv2dTransposeOpenCLKernel::Run() {
cl_int2 kernel_size = {kh, kw};
cl_int2 stride = {2, 2};
cl_int2 padding = {pad, pad};
cl_int4 src_size = {h, w, UP_DIV(ci, 4), 1};
cl_int4 dst_size = {oh, ow, UP_DIV(co, 4), 1};
ocl_runtime->SetKernelArg(kernel_, 0, img_x);
ocl_runtime->SetKernelArg(kernel_, 1, padWeight_);
ocl_runtime->SetKernelArg(kernel_, 2, img_bias);
ocl_runtime->SetKernelArg(kernel_, 3, out_mem);
ocl_runtime->SetKernelArg(kernel_, 4, kernel_size);
ocl_runtime->SetKernelArg(kernel_, 5, stride);
ocl_runtime->SetKernelArg(kernel_, 6, padding);
ocl_runtime->SetKernelArg(kernel_, 7, src_size);
ocl_runtime->SetKernelArg(kernel_, 8, dst_size);
cl_int4 src_size = {h, w, UP_DIV(ci, C4NUM), 1};
cl_int4 dst_size = {oh, ow, UP_DIV(co, C4NUM), 1};
int arg_cnt = 0;
ocl_runtime->SetKernelArg(kernel_, arg_cnt++, in_tensors_[0]->Data());
ocl_runtime->SetKernelArg(kernel_, arg_cnt++, padWeight_);
ocl_runtime->SetKernelArg(kernel_, arg_cnt++, bias_);
ocl_runtime->SetKernelArg(kernel_, arg_cnt++, out_tensors_[0]->Data());
ocl_runtime->SetKernelArg(kernel_, arg_cnt++, kernel_size);
ocl_runtime->SetKernelArg(kernel_, arg_cnt++, stride);
ocl_runtime->SetKernelArg(kernel_, arg_cnt++, padding);
ocl_runtime->SetKernelArg(kernel_, arg_cnt++, src_size);
ocl_runtime->SetKernelArg(kernel_, arg_cnt++, dst_size);
ocl_runtime->RunKernel(kernel_, global, local, nullptr);
auto origin = cl::array<cl::size_type, 3U>{0, 0, 0};
auto region = cl::array<cl::size_type, 3U>{(size_t)(ow * co / 4), (size_t)(oh), 1};
ocl_runtime->GetDefaultCommandQueue()->enqueueReadImage(out_mem, CL_TRUE, origin, region, 0, 0,
out_tensors_[0]->Data());
return 0;
return RET_OK;
}
kernel::LiteKernel *OpenCLConv2dTransposeKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,

View File

@ -22,32 +22,28 @@
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/nnacl/conv_parameter.h"
#include "src/runtime/opencl/opencl_runtime.h"
#ifdef ENABLE_FP16
using FLOAT_T = float16_t;
#else
using FLOAT_T = float;
#endif
#include "src/runtime/kernel/opencl/opencl_kernel.h"
namespace mindspore::kernel {
class Conv2dTransposeOpenCLKernel : public LiteKernel {
class Conv2dTransposeOpenCLKernel : public OpenCLKernel {
public:
explicit Conv2dTransposeOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs)
: LiteKernel(parameter, inputs, outputs, nullptr, nullptr) {}
: OpenCLKernel(parameter, inputs, outputs) {}
~Conv2dTransposeOpenCLKernel() override{};
int Init() override;
int ReSize() override;
int Run() override;
void PadWeight();
int GetImageSize(size_t idx, std::vector<size_t> *img_size) override;
private:
ConvParameter *parameter_;
cl::Kernel kernel_;
FLOAT_T *padWeight_;
FLOAT_T *bias_;
FLOAT_t *padWeight_;
FLOAT_t *bias_;
};
} // namespace mindspore::kernel

View File

@ -33,6 +33,7 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) {
// setbuf(stdout, NULL);
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->Init();
auto allocator = ocl_runtime->GetAllocator();
int pad = 0;
int n = 1;
int h = 240;
@ -57,7 +58,6 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) {
auto bias_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(bias_path.c_str(), &bias_size));
lite::tensor::Tensor *tensor_x = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, h, w, ci});
tensor_x->SetData(input_data);
lite::tensor::Tensor *tensor_w = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {co, kh, kw, ci});
tensor_w->SetData(weight_data);
@ -81,9 +81,11 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) {
new kernel::Conv2dTransposeOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
arith_kernel->Init();
inputs[0]->MallocData(allocator);
std::vector<kernel::LiteKernel *> kernels{arith_kernel};
auto *pGraph = new kernel::SubGraphOpenCLKernel({tensor_x}, outputs, kernels, kernels, kernels);
pGraph->Init();
memcpy(inputs[0]->Data(), input_data, input_size);
pGraph->Run();
printf("==================output data=================\n");