forked from mindspore-Ecosystem/mindspore
add cast to opencl to_format op
This commit is contained in:
parent
fd555f049b
commit
d922befbe0
|
@ -1,3 +1,4 @@
|
|||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#define divide_no_check(a, b) (a / b)
|
||||
__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;
|
||||
|
||||
|
@ -62,7 +63,7 @@ __kernel void BoardcastArith_IMG(__read_only image2d_t input_a, float weight, fl
|
|||
}
|
||||
|
||||
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));
|
||||
WRITE_IMAGE(output, (int2)(X, Y), weight * a + bias);
|
||||
WRITE_IMAGE(output, (int2)(X, Y), ((FLT)weight) * a + (FLT)bias);
|
||||
}
|
||||
|
||||
__kernel void ElementAdd_BUF(__global float *input_a, __global float *input_b, __global float *output,
|
||||
|
|
|
@ -1,34 +1,61 @@
|
|||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
|
||||
__kernel void to_format_NHWC_to_NHWC4_IMG(__global FLT4 *src_data, __write_only image2d_t dst_data, int4 size,
|
||||
int4 shape) {
|
||||
__kernel void to_format_NHWC_to_NHWC4_IMG_float(__global float4 *src_data, __write_only image2d_t dst_data, int4 size,
|
||||
int4 shape) {
|
||||
int X = get_global_id(0);
|
||||
int Y = get_global_id(1);
|
||||
int Z = get_global_id(2);
|
||||
if (X >= size.x || Y >= size.y || Z >= size.z) {
|
||||
return;
|
||||
}
|
||||
int offset = (X * shape.z + Y) * shape.w + Z * 4;
|
||||
__global FLT *src_addr = (__global FLT *)src_data;
|
||||
src_addr += offset;
|
||||
FLT4 data = (FLT4)(0.f);
|
||||
int offset = (X * shape.z + Y) * shape.w + Z * 4;
|
||||
__global float *src_addr = (__global float *)src_data;
|
||||
src_addr += offset;
|
||||
if ((Z + 1) * 4 <= shape.w) {
|
||||
data = ((__global FLT4 *)src_addr)[0];
|
||||
data = TO_FLT4(((__global float4 *)src_addr)[0]);
|
||||
} else {
|
||||
if ((shape.w - Z * 4) >= 1) {
|
||||
data.x = src_addr[0];
|
||||
data.x = (FLT)src_addr[0];
|
||||
}
|
||||
if ((shape.w - Z * 4) >= 2) {
|
||||
data.y = src_addr[1];
|
||||
data.y = (FLT)src_addr[1];
|
||||
}
|
||||
if ((shape.w - Z * 4) >= 3) {
|
||||
data.z = src_addr[2];
|
||||
data.z = (FLT)src_addr[2];
|
||||
}
|
||||
}
|
||||
WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), data);
|
||||
}
|
||||
__kernel void to_format_NHWC_to_NC4HW4_IMG(__global FLT4 *src_data, __write_only image2d_t dst_data, int4 size,
|
||||
int4 shape) {
|
||||
__kernel void to_format_NHWC_to_NHWC4_IMG_half(__global half4 *src_data, __write_only image2d_t dst_data, int4 size,
|
||||
int4 shape) {
|
||||
int X = get_global_id(0);
|
||||
int Y = get_global_id(1);
|
||||
int Z = get_global_id(2);
|
||||
if (X >= size.x || Y >= size.y || Z >= size.z) {
|
||||
return;
|
||||
}
|
||||
FLT4 data = (FLT4)(0.f);
|
||||
int offset = (X * shape.z + Y) * shape.w + Z * 4;
|
||||
__global half *src_addr = (__global half *)src_data;
|
||||
src_addr += offset;
|
||||
if ((Z + 1) * 4 <= shape.w) {
|
||||
data = TO_FLT4(((__global half4 *)src_addr)[0]);
|
||||
} else {
|
||||
if ((shape.w - Z * 4) >= 1) {
|
||||
data.x = (FLT)src_addr[0];
|
||||
}
|
||||
if ((shape.w - Z * 4) >= 2) {
|
||||
data.y = (FLT)src_addr[1];
|
||||
}
|
||||
if ((shape.w - Z * 4) >= 3) {
|
||||
data.z = (FLT)src_addr[2];
|
||||
}
|
||||
}
|
||||
WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), data);
|
||||
}
|
||||
__kernel void to_format_NHWC_to_NC4HW4_IMG_float(__global float4 *src_data, __write_only image2d_t dst_data, int4 size,
|
||||
int4 shape) {
|
||||
int X = get_global_id(0);
|
||||
int Y = get_global_id(1);
|
||||
int Z = get_global_id(2);
|
||||
|
@ -36,36 +63,73 @@ __kernel void to_format_NHWC_to_NC4HW4_IMG(__global FLT4 *src_data, __write_only
|
|||
return;
|
||||
}
|
||||
int offset = (X * shape.z + Y) * shape.w + Z * 4;
|
||||
__global FLT *src_addr = (__global FLT *)src_data;
|
||||
__global float *src_addr = (__global float *)src_data;
|
||||
src_addr += offset;
|
||||
FLT4 data = (FLT4)(0.f);
|
||||
if ((Z + 1) * 4 <= shape.w) {
|
||||
data = ((__global FLT4 *)src_addr)[0];
|
||||
data = TO_FLT4(((__global float4 *)src_addr)[0]);
|
||||
} else {
|
||||
if ((shape.w - Z * 4) >= 1) {
|
||||
data.x = src_addr[0];
|
||||
data.x = (FLT)src_addr[0];
|
||||
}
|
||||
if ((shape.w - Z * 4) >= 2) {
|
||||
data.y = src_addr[1];
|
||||
data.y = (FLT)src_addr[1];
|
||||
}
|
||||
if ((shape.w - Z * 4) >= 3) {
|
||||
data.z = src_addr[2];
|
||||
data.z = (FLT)src_addr[2];
|
||||
}
|
||||
}
|
||||
WRITE_IMAGE(dst_data, (int2)(Y, Z * size.x + X), data);
|
||||
}
|
||||
__kernel void to_format_NHWC4_to_NHWC4_IMG(__global FLT4 *src_data, __write_only image2d_t dst_data, int4 size,
|
||||
int4 shape) {
|
||||
__kernel void to_format_NHWC_to_NC4HW4_IMG_half(__global half4 *src_data, __write_only image2d_t dst_data, int4 size,
|
||||
int4 shape) {
|
||||
int X = get_global_id(0);
|
||||
int Y = get_global_id(1);
|
||||
int Z = get_global_id(2);
|
||||
if (X >= size.x || Y >= size.y || Z >= size.z) {
|
||||
return;
|
||||
}
|
||||
WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), src_data[(X * size.y + Y) * size.z + Z]);
|
||||
int offset = (X * shape.z + Y) * shape.w + Z * 4;
|
||||
__global half *src_addr = (__global half *)src_data;
|
||||
src_addr += offset;
|
||||
FLT4 data = (FLT4)(0.f);
|
||||
if ((Z + 1) * 4 <= shape.w) {
|
||||
data = TO_FLT4(((__global half4 *)src_addr)[0]);
|
||||
} else {
|
||||
if ((shape.w - Z * 4) >= 1) {
|
||||
data.x = (FLT)src_addr[0];
|
||||
}
|
||||
if ((shape.w - Z * 4) >= 2) {
|
||||
data.y = (FLT)src_addr[1];
|
||||
}
|
||||
if ((shape.w - Z * 4) >= 3) {
|
||||
data.z = (FLT)src_addr[2];
|
||||
}
|
||||
}
|
||||
WRITE_IMAGE(dst_data, (int2)(Y, Z * size.x + X), data);
|
||||
}
|
||||
__kernel void to_format_NC4HW4_to_NC4HW4_IMG(__global FLT4 *src_data, __write_only image2d_t dst_data, int4 size,
|
||||
int4 shape) {
|
||||
__kernel void to_format_NHWC4_to_NHWC4_IMG_float(__global float4 *src_data, __write_only image2d_t dst_data, int4 size,
|
||||
int4 shape) {
|
||||
int X = get_global_id(0);
|
||||
int Y = get_global_id(1);
|
||||
int Z = get_global_id(2);
|
||||
if (X >= size.x || Y >= size.y || Z >= size.z) {
|
||||
return;
|
||||
}
|
||||
WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), TO_FLT4(src_data[(X * size.y + Y) * size.z + Z]));
|
||||
}
|
||||
__kernel void to_format_NHWC4_to_NHWC4_IMG_half(__global half4 *src_data, __write_only image2d_t dst_data, int4 size,
|
||||
int4 shape) {
|
||||
int X = get_global_id(0);
|
||||
int Y = get_global_id(1);
|
||||
int Z = get_global_id(2);
|
||||
if (X >= size.x || Y >= size.y || Z >= size.z) {
|
||||
return;
|
||||
}
|
||||
WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), TO_FLT4(src_data[(X * size.y + Y) * size.z + Z]));
|
||||
}
|
||||
__kernel void to_format_NC4HW4_to_NC4HW4_IMG_float(__global float4 *src_data, __write_only image2d_t dst_data,
|
||||
int4 size, int4 shape) {
|
||||
// size(h, w, c4, 1), shape(n, c, h, w)
|
||||
int X = get_global_id(0); // h
|
||||
int Y = get_global_id(1); // w
|
||||
|
@ -73,32 +137,53 @@ __kernel void to_format_NC4HW4_to_NC4HW4_IMG(__global FLT4 *src_data, __write_on
|
|||
if (X >= size.x || Y >= size.y || Z >= size.z) {
|
||||
return;
|
||||
}
|
||||
WRITE_IMAGE(dst_data, (int2)(Y, Z * size.x + X), src_data[(Z * size.x + X) * size.y + Y]);
|
||||
WRITE_IMAGE(dst_data, (int2)(Y, Z * size.x + X), TO_FLT4(src_data[(Z * size.x + X) * size.y + Y]));
|
||||
}
|
||||
__kernel void to_format_NCHW_to_NCHW_BUF(__read_only image2d_t src_data, __global FLT4 *dst_data, int4 size,
|
||||
int4 shape) {
|
||||
__kernel void to_format_NC4HW4_to_NC4HW4_IMG_half(__global half4 *src_data, __write_only image2d_t dst_data, int4 size,
|
||||
int4 shape) {
|
||||
// size(h, w, c4, 1), shape(n, c, h, w)
|
||||
int X = get_global_id(0); // h
|
||||
int Y = get_global_id(1); // w
|
||||
int Z = get_global_id(2); // c4
|
||||
if (X >= size.x || Y >= size.y || Z >= size.z) {
|
||||
return;
|
||||
}
|
||||
WRITE_IMAGE(dst_data, (int2)(Y, Z * size.x + X), TO_FLT4(src_data[(Z * size.x + X) * size.y + Y]));
|
||||
}
|
||||
__kernel void to_format_NCHW_to_NCHW_BUF_float(__read_only image2d_t src_data, __global float4 *dst_data, int4 size,
|
||||
int4 shape) {
|
||||
int X = get_global_id(0);
|
||||
int Y = get_global_id(1);
|
||||
int Z = get_global_id(2);
|
||||
if (X >= size.x || Y >= size.y || Z >= size.z) {
|
||||
return;
|
||||
}
|
||||
dst_data[(Z * size.y + Y) * size.x + X] = READ_IMAGE(src_data, smp_zero, (int2)(Y * size.x + X, Z));
|
||||
dst_data[(Z * size.y + Y) * size.x + X] = convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(Y * size.x + X, Z)));
|
||||
}
|
||||
__kernel void to_format_NHWC4_to_NHWC_BUF(__read_only image2d_t src_data, __global FLT4 *dst_data, int4 size,
|
||||
int4 shape) {
|
||||
__kernel void to_format_NCHW_to_NCHW_BUF_half(__read_only image2d_t src_data, __global half4 *dst_data, int4 size,
|
||||
int4 shape) {
|
||||
int X = get_global_id(0);
|
||||
int Y = get_global_id(1);
|
||||
int Z = get_global_id(2);
|
||||
if (X >= size.x || Y >= size.y || Z >= size.z) {
|
||||
return;
|
||||
}
|
||||
FLT4 data = READ_IMAGE(src_data, smp_zero, (int2)(Y * size.z + Z, X));
|
||||
dst_data[(Z * size.y + Y) * size.x + X] = convert_half4(READ_IMAGE(src_data, smp_zero, (int2)(Y * size.x + X, Z)));
|
||||
}
|
||||
__kernel void to_format_NHWC4_to_NHWC_BUF_float(__read_only image2d_t src_data, __global float4 *dst_data, int4 size,
|
||||
int4 shape) {
|
||||
int X = get_global_id(0);
|
||||
int Y = get_global_id(1);
|
||||
int Z = get_global_id(2);
|
||||
if (X >= size.x || Y >= size.y || Z >= size.z) {
|
||||
return;
|
||||
}
|
||||
float4 data = convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(Y * size.z + Z, X)));
|
||||
int offset = (X * shape.z + Y) * shape.w + Z * 4;
|
||||
__global FLT *dst_addr = (__global FLT *)dst_data;
|
||||
__global float *dst_addr = (__global float *)dst_data;
|
||||
dst_addr += offset;
|
||||
if ((Z + 1) * 4 <= shape.w) {
|
||||
((__global FLT4 *)dst_addr)[0] = data;
|
||||
((__global float4 *)dst_addr)[0] = data;
|
||||
} else {
|
||||
if (shape.w - Z * 4 >= 1) {
|
||||
dst_addr[0] = data.x;
|
||||
|
@ -111,20 +196,20 @@ __kernel void to_format_NHWC4_to_NHWC_BUF(__read_only image2d_t src_data, __glob
|
|||
}
|
||||
}
|
||||
}
|
||||
__kernel void to_format_NC4HW4_to_NHWC_BUF(__read_only image2d_t src_data, __global FLT4 *dst_data, int4 size,
|
||||
int4 shape) {
|
||||
__kernel void to_format_NHWC4_to_NHWC_BUF_half(__read_only image2d_t src_data, __global half4 *dst_data, int4 size,
|
||||
int4 shape) {
|
||||
int X = get_global_id(0);
|
||||
int Y = get_global_id(1);
|
||||
int Z = get_global_id(2);
|
||||
if (X >= size.x || Y >= size.y || Z >= size.z) {
|
||||
return;
|
||||
}
|
||||
FLT4 data = READ_IMAGE(src_data, smp_zero, (int2)(Y, Z * size.x + X));
|
||||
half4 data = convert_half4(READ_IMAGE(src_data, smp_zero, (int2)(Y * size.z + Z, X)));
|
||||
int offset = (X * shape.z + Y) * shape.w + Z * 4;
|
||||
__global FLT *dst_addr = (__global FLT *)dst_data;
|
||||
__global half *dst_addr = (__global half *)dst_data;
|
||||
dst_addr += offset;
|
||||
if ((Z + 1) * 4 <= shape.w) {
|
||||
((__global FLT4 *)dst_addr)[0] = data;
|
||||
((__global half4 *)dst_addr)[0] = data;
|
||||
} else {
|
||||
if (shape.w - Z * 4 >= 1) {
|
||||
dst_addr[0] = data.x;
|
||||
|
@ -137,8 +222,60 @@ __kernel void to_format_NC4HW4_to_NHWC_BUF(__read_only image2d_t src_data, __glo
|
|||
}
|
||||
}
|
||||
}
|
||||
__kernel void to_format_NC4HW4_to_NC4HW4_BUF(__read_only image2d_t src_data, __global FLT4 *dst_data, int4 size,
|
||||
int4 shape) {
|
||||
__kernel void to_format_NC4HW4_to_NHWC_BUF_float(__read_only image2d_t src_data, __global float4 *dst_data, int4 size,
|
||||
int4 shape) {
|
||||
int X = get_global_id(0);
|
||||
int Y = get_global_id(1);
|
||||
int Z = get_global_id(2);
|
||||
if (X >= size.x || Y >= size.y || Z >= size.z) {
|
||||
return;
|
||||
}
|
||||
float4 data = convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(Y, Z * size.x + X)));
|
||||
int offset = (X * shape.z + Y) * shape.w + Z * 4;
|
||||
__global float *dst_addr = (__global float *)dst_data;
|
||||
dst_addr += offset;
|
||||
if ((Z + 1) * 4 <= shape.w) {
|
||||
((__global float4 *)dst_addr)[0] = data;
|
||||
} else {
|
||||
if (shape.w - Z * 4 >= 1) {
|
||||
dst_addr[0] = data.x;
|
||||
}
|
||||
if (shape.w - Z * 4 >= 2) {
|
||||
dst_addr[1] = data.y;
|
||||
}
|
||||
if (shape.w - Z * 4 >= 3) {
|
||||
dst_addr[2] = data.z;
|
||||
}
|
||||
}
|
||||
}
|
||||
__kernel void to_format_NC4HW4_to_NHWC_BUF_half(__read_only image2d_t src_data, __global half4 *dst_data, int4 size,
|
||||
int4 shape) {
|
||||
int X = get_global_id(0);
|
||||
int Y = get_global_id(1);
|
||||
int Z = get_global_id(2);
|
||||
if (X >= size.x || Y >= size.y || Z >= size.z) {
|
||||
return;
|
||||
}
|
||||
half4 data = convert_half4(READ_IMAGE(src_data, smp_zero, (int2)(Y, Z * size.x + X)));
|
||||
int offset = (X * shape.z + Y) * shape.w + Z * 4;
|
||||
__global half *dst_addr = (__global half *)dst_data;
|
||||
dst_addr += offset;
|
||||
if ((Z + 1) * 4 <= shape.w) {
|
||||
((__global half4 *)dst_addr)[0] = data;
|
||||
} else {
|
||||
if (shape.w - Z * 4 >= 1) {
|
||||
dst_addr[0] = data.x;
|
||||
}
|
||||
if (shape.w - Z * 4 >= 2) {
|
||||
dst_addr[1] = data.y;
|
||||
}
|
||||
if (shape.w - Z * 4 >= 3) {
|
||||
dst_addr[2] = data.z;
|
||||
}
|
||||
}
|
||||
}
|
||||
__kernel void to_format_NC4HW4_to_NC4HW4_BUF_float(__read_only image2d_t src_data, __global float4 *dst_data, int4 size,
|
||||
int4 shape) {
|
||||
// size(h, w, c, 1), shape(n, c, h, w)
|
||||
int X = get_global_id(0); // h
|
||||
int Y = get_global_id(1); // w
|
||||
|
@ -146,15 +283,26 @@ __kernel void to_format_NC4HW4_to_NC4HW4_BUF(__read_only image2d_t src_data, __g
|
|||
if (X >= size.x || Y >= size.y || Z >= size.z) {
|
||||
return;
|
||||
}
|
||||
dst_data[(Z * size.x + X) * size.y + Y] = READ_IMAGE(src_data, smp_zero, (int2)(Y, Z * size.x + X));
|
||||
dst_data[(Z * size.x + X) * size.y + Y] = convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(Y, Z * size.x + X)));
|
||||
}
|
||||
__kernel void to_format_NHWC4_to_NHWC4_BUF(__read_only image2d_t src_data, __global FLT4 *dst_data, int4 size,
|
||||
int4 shape) {
|
||||
__kernel void to_format_NC4HW4_to_NC4HW4_BUF_half(__read_only image2d_t src_data, __global half4 *dst_data, int4 size,
|
||||
int4 shape) {
|
||||
// size(h, w, c, 1), shape(n, c, h, w)
|
||||
int X = get_global_id(0); // h
|
||||
int Y = get_global_id(1); // w
|
||||
int Z = get_global_id(2); // c
|
||||
if (X >= size.x || Y >= size.y || Z >= size.z) {
|
||||
return;
|
||||
}
|
||||
dst_data[(Z * size.x + X) * size.y + Y] = convert_half4(READ_IMAGE(src_data, smp_zero, (int2)(Y, Z * size.x + X)));
|
||||
}
|
||||
__kernel void to_format_NHWC4_to_NHWC4_BUF_float(__read_only image2d_t src_data, __global float4 *dst_data, int4 size,
|
||||
int4 shape) {
|
||||
int X = get_global_id(0);
|
||||
int Y = get_global_id(1);
|
||||
int Z = get_global_id(2);
|
||||
if (X >= size.x || Y >= size.y || Z >= size.z) {
|
||||
return;
|
||||
}
|
||||
dst_data[(X * size.y + Y) * size.z + Z] = READ_IMAGE(src_data, smp_zero, (int2)(Y * size.z + Z, X));
|
||||
dst_data[(X * size.y + Y) * size.z + Z] = convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(Y * size.z + Z, X)));
|
||||
}
|
||||
|
|
|
@ -42,10 +42,12 @@ int ToFormatOpenCLKernel::Init() {
|
|||
{schema::Format_NC, "NHWC"}, {schema::Format_NHWC4, "NHWC4"}};
|
||||
std::string kernel_name =
|
||||
"to_format_" + format_str[in_tensors_[0]->GetFormat()] + "_to_" + format_str[out_tensors_[0]->GetFormat()];
|
||||
std::map<TypeId, std::string> dtype_str{
|
||||
{kNumberTypeFloat32, "float"}, {kNumberTypeFloat16, "half"}, {kNumberTypeInt8, "Int8"}};
|
||||
if (out_mem_type_ == OpenCLMemType::IMG) {
|
||||
kernel_name += "_IMG";
|
||||
kernel_name += "_IMG_" + dtype_str[in_tensors_[0]->data_type()];
|
||||
} else {
|
||||
kernel_name += "_BUF";
|
||||
kernel_name += "_BUF_" + dtype_str[out_tensors_[0]->data_type()];
|
||||
}
|
||||
|
||||
this->set_name(kernel_name);
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
|
||||
#include <set>
|
||||
#include "src/runtime/opencl/opencl_executor.h"
|
||||
#include "src/runtime/opencl/opencl_runtime.h"
|
||||
#include "src/runtime/kernel/opencl/utils.h"
|
||||
|
@ -181,11 +182,31 @@ int SubGraphOpenCLKernel::Init() {
|
|||
}
|
||||
nodes_.insert(nodes_.end(), out_convert_ops_.begin(), out_convert_ops_.end());
|
||||
|
||||
UpdateTensorDataType();
|
||||
|
||||
MallocTensorWithReuse();
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int SubGraphOpenCLKernel::UpdateTensorDataType() {
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
bool is_fp16 = ocl_runtime->GetFp16Enable();
|
||||
if (is_fp16 && (in_tensors_[0]->data_type() == kNumberTypeFloat32)) {
|
||||
std::set<lite::tensor::Tensor *> out_set;
|
||||
out_set.insert(in_tensors_.begin(), in_tensors_.end());
|
||||
out_set.insert(out_tensors_.begin(), out_tensors_.end());
|
||||
for (auto iv : nodes_) {
|
||||
auto cur_outs = iv->out_tensors();
|
||||
for (auto jv : cur_outs) {
|
||||
if (out_set.count(jv) == 0) {
|
||||
jv->set_data_type(kNumberTypeFloat16);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
int SubGraphOpenCLKernel::MallocTensorWithReuse() {
|
||||
kernel::LiteKernelUtil::InitTensorRefCount(nodes_);
|
||||
for (auto *kernel : nodes_) {
|
||||
|
|
|
@ -46,6 +46,7 @@ class SubGraphOpenCLKernel : public SubGraphKernel {
|
|||
int UnInit();
|
||||
|
||||
protected:
|
||||
int UpdateTensorDataType();
|
||||
int MallocTensorWithReuse();
|
||||
int GenToFormatOp(const std::vector<lite::tensor::Tensor *> &in_tensors,
|
||||
const std::vector<std::vector<kernel::LiteKernel *>> in_kernels,
|
||||
|
|
|
@ -301,12 +301,12 @@ int OpenCLRuntime::BuildKernel(cl::Kernel &kernel, const std::string &program_na
|
|||
// fp16 enable, kernel will use half and read_imageh and write_imageh.
|
||||
build_options_str =
|
||||
"-DFLT=half -DFLT4=half4 -DFLT16=half16 "
|
||||
"-DWRITE_IMAGE=write_imageh -DREAD_IMAGE=read_imageh -DTO_FLT4=convert_half4 ";
|
||||
"-DWRITE_IMAGE=write_imageh -DREAD_IMAGE=read_imageh -DTO_FLT=convert_half -DTO_FLT4=convert_half4 ";
|
||||
} else {
|
||||
// fp16 not enable, kernel will use float and read_imagef and write_imagef.
|
||||
build_options_str =
|
||||
"-DFLT=float -DFLT4=float4 -DFLT16=float16 "
|
||||
"-DWRITE_IMAGE=write_imagef -DREAD_IMAGE=read_imagef -DTO_FLT4=convert_float4 ";
|
||||
"-DWRITE_IMAGE=write_imagef -DREAD_IMAGE=read_imagef -DTO_FLT=convert_float -DTO_FLT4=convert_float4 ";
|
||||
}
|
||||
|
||||
auto build_options_ext = std::accumulate(
|
||||
|
|
Loading…
Reference in New Issue