diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/to_format.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/to_format.cl
index bc08a62c2b2..46ac2b24420 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/cl/to_format.cl
+++ b/mindspore/lite/src/runtime/kernel/opencl/cl/to_format.cl
@@ -27,6 +27,38 @@ __kernel void to_format_NHWC_to_NHWC4_IMG(__global FLT4 *src_data, __write_only
   }
   WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), data);
 }
+// Packs one 4-channel slice of an NHWC buffer pixel into a texel of an NC4HW4 image.
+// Work-item (X, Y, Z) reads up to four FLT values starting at (X * shape.z + Y) * shape.w + Z * 4,
+// leaves the lanes past shape.w at zero, and writes the texel at (Y, Z * size.x + X).
+// NOTE(review): presumably X/Y/Z index height/width/channel-slice and shape.z/shape.w are W/C - confirm on host.
+__kernel void to_format_NHWC_to_NC4HW4_IMG(__global FLT4 *src_data, __write_only image2d_t dst_data, int4 size,
+                                           int4 shape) {
+  int X = get_global_id(0);
+  int Y = get_global_id(1);
+  int Z = get_global_id(2);
+  if (X >= size.x || Y >= size.y || Z >= size.z) {
+    return;
+  }
+  int offset = (X * shape.z + Y) * shape.w + Z * 4;
+  __global FLT *src_addr = (__global FLT *)src_data;
+  src_addr += offset;
+  FLT4 data = (FLT4)(0.f);
+  if ((Z + 1) * 4 <= shape.w) {
+    // vload4 only needs FLT alignment; offset is not a multiple of 4 in general when shape.w % 4 != 0.
+    data = vload4(0, src_addr);
+  } else {
+    if ((shape.w - Z * 4) >= 1) {
+      data.x = src_addr[0];
+    }
+    if ((shape.w - Z * 4) >= 2) {
+      data.y = src_addr[1];
+    }
+    if ((shape.w - Z * 4) >= 3) {
+      data.z = src_addr[2];
+    }
+  }
+  WRITE_IMAGE(dst_data, (int2)(Y, Z * size.x + X), data);
+}
 __kernel void to_format_NHWC4_to_NHWC4_IMG(__global FLT4 *src_data, __write_only image2d_t dst_data, int4 size,
                                            int4 shape) {
   int X = get_global_id(0);
@@ -84,6 +116,37 @@ __kernel void to_format_NHWC4_to_NHWC_BUF(__read_only image2d_t src_data, __glob
     }
   }
 }
+// Unpacks one texel of an NC4HW4 image into an NHWC buffer pixel.
+// Work-item (X, Y, Z) reads the texel at (Y, Z * size.x + X) and stores up to four FLT values starting at
+// (X * shape.z + Y) * shape.w + Z * 4; lanes that would fall past shape.w are not written.
+// NOTE(review): presumably X/Y/Z index height/width/channel-slice and shape.z/shape.w are W/C - confirm on host.
+__kernel void to_format_NC4HW4_to_NHWC_BUF(__read_only image2d_t src_data, __global FLT4 *dst_data, int4 size,
+                                           int4 shape) {
+  int X = get_global_id(0);
+  int Y = get_global_id(1);
+  int Z = get_global_id(2);
+  if (X >= size.x || Y >= size.y || Z >= size.z) {
+    return;
+  }
+  FLT4 data = READ_IMAGE(src_data, smp_zero, (int2)(Y, Z * size.x + X));
+  int offset = (X * shape.z + Y) * shape.w + Z * 4;
+  __global FLT *dst_addr = (__global FLT *)dst_data;
+  dst_addr += offset;
+  if ((Z + 1) * 4 <= shape.w) {
+    // vstore4 only needs FLT alignment; offset is not a multiple of 4 in general when shape.w % 4 != 0.
+    vstore4(data, 0, dst_addr);
+  } else {
+    if (shape.w - Z * 4 >= 1) {
+      dst_addr[0] = data.x;
+    }
+    if (shape.w - Z * 4 >= 2) {
+      dst_addr[1] = data.y;
+    }
+    if (shape.w - Z * 4 >= 3) {
+      dst_addr[2] = data.z;
+    }
+  }
+}
 __kernel void to_format_NC4HW4_to_NC4HW4_BUF(__read_only image2d_t src_data, __global FLT4 *dst_data, int4 size,
                                              int4 shape) {
   // size(h, w, c, 1), shape(n, c, h, w)
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc
index 1077394db38..894b0af14af 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc
@@ -43,12 +43,9 @@ namespace mindspore::kernel {
 int DepthwiseConv2dOpenCLKernel::Init() {
   auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
   std::string kernel_name = "DepthwiseConv2d";
-  auto in_format = in_tensors_[0]->GetFormat();
+  auto in_format = op_format_;
   in_ori_format_ = in_tensors_[0]->GetFormat();
   out_ori_format_ = out_tensors_[0]->GetFormat();
-  in_format = (in_format == schema::Format_NHWC)
-                ? schema::Format_NHWC4
-                : ((in_format == schema::Format_NCHW) ? schema::Format_NC4HW4 : in_format);
   if (in_format != schema::Format_NHWC4 && in_format != schema::Format_NC4HW4) {
     MS_LOG(ERROR) << "input format(" << in_format << ") "
                   << "format not support!";
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc
index a086374d0b2..94dd6dc66e3 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc
@@ -65,13 +65,13 @@ int ToFormatOpenCLKernel::Init() {
 int ToFormatOpenCLKernel::InitNHWCShape() {
   std::vector shapex = out_tensors_[0]->shape();
   size_t n, h, w, c;
-  if (out_tensors_[0]->GetFormat() == schema::Format_NHWC4 || out_tensors_[0]->GetFormat() == schema::Format_NHWC) {
+  if (out_tensors_[0]->GetFormat() == schema::Format_NC4HW4 || out_tensors_[0]->GetFormat() == schema::Format_NHWC4 ||
+      out_tensors_[0]->GetFormat() == schema::Format_NHWC) {
     n = shapex[0];
     h = shapex[1];
     w = shapex[2];
     c = shapex[3];
-  } else if (out_tensors_[0]->GetFormat() == schema::Format_NC4HW4 ||
-             out_tensors_[0]->GetFormat() == schema::Format_NCHW) {
+  } else if (out_tensors_[0]->GetFormat() == schema::Format_NCHW) {
     n = shapex[0];
     h = shapex[2];
     w = shapex[3];
@@ -105,21 +105,20 @@ int ToFormatOpenCLKernel::GetLocalSize(size_t idx, const std::vector &gl
 
 int ToFormatOpenCLKernel::GetImageSize(size_t idx, std::vector *img_size) {
   size_t im_dst_x, im_dst_y;
-  std::vector shapex = out_tensors_[0]->shape();
   if (out_tensors_[0]->GetFormat() == schema::Format_NC4HW4) {
-    int c = shapex[1] * shapex[2];
-    int h = shapex[0];
-    int w = shapex[3];
-    im_dst_y = h * UP_DIV(c, C4NUM);
+    int c = nhwc_shape_[3];
+    int h = nhwc_shape_[1];
+    int w = nhwc_shape_[2];
+    im_dst_y = nhwc_shape_[0] * h * UP_DIV(c, C4NUM);
     im_dst_x = w;
   } else if (out_tensors_[0]->GetFormat() == schema::Format_NHWC4) {
-    int h = shapex[0] * shapex[1];
-    int w = shapex[2];
-    int c = shapex[3];
+    int h = nhwc_shape_[0] * nhwc_shape_[1];
+    int w = nhwc_shape_[2];
+    int c = nhwc_shape_[3];
     im_dst_x = w * UP_DIV(c, C4NUM);
     im_dst_y = h;
   } else if (out_tensors_[0]->GetFormat() == schema::Format_NC4) {
-    int c = shapex[1];
+    int c = nhwc_shape_[1];
     im_dst_x = UP_DIV(c, C4NUM);
     im_dst_y = 1;
   } else {
diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h
index b8b725cfb01..d33d760554f 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h
+++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h
@@ -58,7 +58,7 @@ class OpenCLKernel : public LiteKernel {
   OpenCLMemType out_mem_type_{OpenCLMemType::IMG};
   schema::Format in_ori_format_{schema::Format_NHWC};
   schema::Format out_ori_format_{schema::Format_NHWC4};
-  schema::Format op_format_{schema::Format_NC4HW4};
+  schema::Format op_format_{schema::Format_NHWC4};
 };
 }  // namespace mindspore::kernel
 
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc
index e1e6704a18e..41470d90ba8 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc
@@ -66,12 +66,12 @@ void DepthWiseTestMain(ConvParameter *conv_param, T2 *input_data, T1 *weight_dat
   std::vector shape_bias = {conv_param->output_channel_};
   std::vector shape_out;
   std::vector shape_in;
-  if (format == schema::Format_NHWC || format == schema::Format_NHWC4) {
+  if (format == schema::Format_NHWC || format == schema::Format_NHWC4 || format == schema::Format_NC4HW4) {
     shape_in = std::vector(
       {conv_param->input_batch_, conv_param->input_h_, conv_param->input_w_, conv_param->input_channel_});
     shape_out = std::vector(
       {conv_param->output_batch_, conv_param->output_h_, conv_param->output_w_, conv_param->output_channel_});
-  } else if (format == schema::Format_NCHW || format == schema::Format_NC4HW4) {
+  } else if (format == schema::Format_NCHW) {
     shape_in = std::vector(
       {conv_param->input_batch_, conv_param->input_channel_, conv_param->input_h_, conv_param->input_w_});
     shape_out = std::vector(
@@ -98,6 +98,7 @@ void DepthWiseTestMain(ConvParameter *conv_param, T2 *input_data, T1 *weight_dat
     delete[] packed_input;
     return;
   }
+  pKernel->SetFormatType(format);
   pKernel->Init();
 
   std::vector kernels{pKernel.get()};