!25206 Q888_CV_model_age_gender.pb bugfix, add gate config, ulfgf gaze_corrector

Merge pull request !25206 from Greatpan/master
This commit is contained in:
i-robot 2021-10-25 11:23:23 +00:00 committed by Gitee
commit 3f2d75f31b
7 changed files with 162 additions and 31 deletions

View File

@ -64,10 +64,8 @@ __kernel void Tanh(__read_only image2d_t input, __write_only image2d_t output, c
int Y = get_global_id(1);
if (X >= img_shape.x || Y >= img_shape.y) return;
FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X, Y));
FLT4 exp0 = exp(in_c4);
FLT4 exp1 = exp(-in_c4);
in_c4 = (exp0 - exp1) / (exp0 + exp1);
WRITE_IMAGE(output, (int2)(X, Y), in_c4);
in_c4 = clamp(in_c4, -10.0f, 10.0f);
WRITE_IMAGE(output, (int2)(X, Y), tanh(in_c4));
}
__kernel void Swish(__read_only image2d_t input, __write_only image2d_t output, const int2 img_shape) {

View File

@ -311,17 +311,99 @@ __kernel void GlobalCMean(__read_only image2d_t src_data, __write_only image2d_t
WRITE_IMAGE(dst_data, (int2)(0, X), result2); \
}
// HWC
__kernel void GlobalHWCMean(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size) {
float4 value = (float4)0.f;
for (int h = 0; h < size.x; h++) {
for (int w = 0; w < size.y; w++) {
for (int c4 = 0; c4 < size.z; c4++) {
value += convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(w * size.z + c4, h)));
}
}
}
float4 result = (float4)0.f;
result.x = dot((float4)(1.0f), value) / (size.x * size.y * size.w);
WRITE_IMAGE(dst_data, (int2)(0, 0), TO_FLT4(result));
}
#define DoHWCSum(a, B) ((a) = dot((float4)(1.0f), (B)))
#define DoHWCMax(a, B) ((a) = max((B).x, max((B).y, max((B).z, (B).w))))
#define DoHWCMin(a, B) ((a) = min((B).x, min((B).y, min((B).z, (B).w))))
#define DoHWCProd(a, B) ((a) = (B).x * (B).y * (B).z * (B).w)
#define GlobalHWC(Method) \
__kernel void GlobalHWC##Method(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size) { \
float4 value = (float4)0.f; \
for (int h = 0; h < size.x; h++) { \
for (int w = 0; w < size.y; w++) { \
for (int c4 = 0; c4 < size.z; c4++) { \
Do##Method(value, convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(w * size.z + c4, h)))); \
} \
} \
} \
float4 result = (float4)0.f; \
DoHWC##Method(result.x, value); \
WRITE_IMAGE(dst_data, (int2)(0, 0), TO_FLT4(result)); \
}
// H
__kernel void GlobalHMean(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size) {
int w = get_global_id(0);
int c4 = get_global_id(1);
float4 result = (float4)0.f;
for (int h = 0; h < size.x; h++) {
result += convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(w * size.z + c4, h)));
}
result /= size.x;
WRITE_IMAGE(dst_data, (int2)(w * size.z + c4, 0), TO_FLT4(result));
}
#define GlobalH(Method) \
__kernel void GlobalH##Method(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size) { \
int w = get_global_id(0); \
int c4 = get_global_id(1); \
float4 result = (float4)0.f; \
for (int h = 0; h < size.x; h++) { \
Do##Method(result, convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(w * size.z + c4, h)))); \
} \
WRITE_IMAGE(dst_data, (int2)(w * size.z + c4, 0), TO_FLT4(result)); \
}
// W
__kernel void GlobalWMean(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size) {
int h = get_global_id(0);
int c4 = get_global_id(1);
float4 result = (float4)0.f;
for (int w = 0; w < size.y; w++) {
result += convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(w * size.z + c4, h)));
}
result /= size.y;
WRITE_IMAGE(dst_data, (int2)(c4, h), TO_FLT4(result));
}
#define GlobalW(Method) \
__kernel void GlobalW##Method(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size) { \
int h = get_global_id(0); \
int c4 = get_global_id(1); \
float4 result = (float4)0.f; \
for (int w = 0; w < size.y; w++) { \
Do##Method(result, convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(w * size.z + c4, h)))); \
} \
WRITE_IMAGE(dst_data, (int2)(c4, h), TO_FLT4(result)); \
}
#define DoSum(A, B) A += B
#define InitSum 0.f
GlobalHW(Sum) GlobalWC(Sum) LocalHW(Sum) LocalWC(Sum)
GlobalHWC(Sum) GlobalHW(Sum) GlobalWC(Sum) GlobalH(Sum) GlobalW(Sum) LocalHW(Sum) LocalWC(Sum)
#define DoMin(A, B) A = min(A, B)
#define InitMin 10000.f
GlobalHW(Min) GlobalWC(Min) LocalHW(Min) LocalWC(Min)
GlobalHWC(Min) GlobalHW(Min) GlobalWC(Min) GlobalH(Min) GlobalW(Min) LocalHW(Min) LocalWC(Min)
#define DoMax(A, B) A = max(A, B)
#define InitMax -10000.f
GlobalHW(Max) GlobalWC(Max) LocalHW(Max) LocalWC(Max)
GlobalHWC(Max) GlobalHW(Max) GlobalWC(Max) GlobalH(Max) GlobalW(Max) LocalHW(Max) LocalWC(Max)
#define DoProd(A, B) A *= B
#define InitProd 1.f
GlobalHW(Prod) GlobalWC(Prod) LocalHW(Prod) LocalWC(Prod)
GlobalHWC(Prod) GlobalHW(Prod) GlobalWC(Prod) GlobalH(Prod) GlobalW(Prod) LocalHW(Prod) LocalWC(Prod)

View File

@ -191,7 +191,6 @@ __kernel void Softmax1x1_NHWC4(__read_only image2d_t input, __write_only image2d
__kernel void Softmax1x1_32_NHWC4(__read_only image2d_t input, __write_only image2d_t output, const float4 mask,
const int4 input_shape) {
const int MAX_C4_NUM = 8;
int n = get_global_id(1);
if (n >= input_shape.x) return;
@ -210,7 +209,7 @@ __kernel void Softmax1x1_32_NHWC4(__read_only image2d_t input, __write_only imag
float4 input_max_f4 = (float4)(input_max, input_max, input_max, input_max);
// Calc input sum value
float4 element_vec4[MAX_C4_NUM];
float4 element_vec4[8]; // 8 : MAX_C4_NUM
float4 sum_vec4 = convert_float4(mask);
sum_vec4 *= exp(convert_float4(READ_IMAGE(input, smp_zero, (int2)(C4 - 1, n))) - input_max_f4);
element_vec4[C4 - 1] = sum_vec4;

View File

@ -66,15 +66,27 @@ cl_float4 ReduceOpenCLKernel::GenC4Mask() {
return mask;
}
bool IsHWReduce(const bool *reduce_axes_) {
bool ReduceOpenCLKernel::IsHWCReduce() {
return !reduce_axes_[0] && reduce_axes_[1] && reduce_axes_[2] && reduce_axes_[3];
}
bool ReduceOpenCLKernel::IsHWReduce() {
return !reduce_axes_[0] && reduce_axes_[1] && reduce_axes_[2] && !reduce_axes_[3];
}
bool IsWCReduce(const bool *reduce_axes_) {
bool ReduceOpenCLKernel::IsWCReduce() {
return !reduce_axes_[0] && !reduce_axes_[1] && reduce_axes_[2] && reduce_axes_[3];
}
bool IsCReduce(const bool *reduce_axes_) {
bool ReduceOpenCLKernel::IsHReduce() {
return !reduce_axes_[0] && reduce_axes_[1] && !reduce_axes_[2] && !reduce_axes_[3];
}
bool ReduceOpenCLKernel::IsWReduce() {
return !reduce_axes_[0] && !reduce_axes_[1] && reduce_axes_[2] && !reduce_axes_[3];
}
bool ReduceOpenCLKernel::IsCReduce() {
return !reduce_axes_[0] && !reduce_axes_[1] && !reduce_axes_[2] && reduce_axes_[3];
}
@ -83,8 +95,26 @@ int ReduceOpenCLKernel::SetAxes() {
// get num_axes
int num_axes = 0;
auto *axes_tensor = in_tensors_.at(1);
if (axes_tensor->shape().size() == 0) {
CHECK_NULL_RETURN(axes_tensor->data());
auto reduction_indices = reinterpret_cast<int *>(axes_tensor->data())[0];
if (reduction_indices == -1) {
reduce_axes_[1] = true;
reduce_axes_[2] = true;
reduce_axes_[3] = true;
} else if (reduction_indices == 1 || reduction_indices == 2 || reduction_indices == 3) {
reduce_axes_[reduction_indices] = true;
} else {
MS_LOG(ERROR) << "in Reduce: axes tensor's reduction_indices should be -1, 1, 2, 3";
return RET_ERROR;
}
return RET_OK;
}
if (axes_tensor->shape().size() != 1) {
MS_LOG(ERROR) << "in Reduce: axes tensor's ndim should be 1.";
MS_LOG(ERROR) << "in Reduce: axes tensor's ndim should be 0 or 1.";
return RET_ERROR;
} else {
num_axes = axes_tensor->shape().front();
@ -146,14 +176,12 @@ int ReduceOpenCLKernel::CheckSpecs() {
if (ret != RET_OK) {
return ret;
}
hw_reduce_ = IsHWReduce(reduce_axes_);
wc_reduce_ = IsWCReduce(reduce_axes_);
c_reduce_ = IsCReduce(reduce_axes_);
if (!hw_reduce_ && !wc_reduce_ && !c_reduce_) {
if (!IsHWReduce() && !IsWCReduce() && !IsHReduce() && !IsWReduce() && !IsCReduce() && !IsHWCReduce()) {
MS_LOG(WARNING) << "Unsupported reduce axes";
return RET_PARAM_INVALID;
}
if ((c_reduce_ || wc_reduce_) && !reduce_param->keep_dims_) {
if ((IsCReduce() || IsWCReduce() || IsWReduce()) && !reduce_param->keep_dims_) {
MS_LOG(WARNING) << "reduce axis (2,3) should keep dims";
return RET_PARAM_INVALID;
}
@ -169,19 +197,26 @@ int ReduceOpenCLKernel::Prepare() {
std::string kernel_name;
use_local_ = false;
kernel_name = "Global";
if (wc_reduce_ && (inShape.W >= LOCAL_CACHE_THREAD || inShape.C >= LOCAL_CACHE_THREAD)) {
if (IsWCReduce() && (inShape.W >= LOCAL_CACHE_THREAD || inShape.C >= LOCAL_CACHE_THREAD)) {
use_local_ = true;
kernel_name = "Local";
}
if (hw_reduce_ && (inShape.W >= LOCAL_CACHE_THREAD || inShape.H >= LOCAL_CACHE_THREAD)) {
if (IsHWReduce() && (inShape.W >= LOCAL_CACHE_THREAD || inShape.H >= LOCAL_CACHE_THREAD)) {
use_local_ = true;
kernel_name = "Local";
}
if (wc_reduce_) {
if (IsHWCReduce()) {
kernel_name += "HWC";
} else if (IsWCReduce()) {
kernel_name += "WC";
} else if (hw_reduce_) {
} else if (IsHWReduce()) {
kernel_name += "HW";
} else if (c_reduce_) {
} else if (IsHReduce()) {
kernel_name += "H";
} else if (IsWReduce()) {
kernel_name += "W";
} else if (IsCReduce()) {
kernel_name += "C";
}
kernel_name += GetReduceTypeStr(reduce_param->mode_);
@ -217,7 +252,7 @@ int ReduceOpenCLKernel::SetConstArgs() {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}
if (wc_reduce_ || c_reduce_) {
if (IsWCReduce() || IsCReduce()) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_idx++, GenC4Mask()) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
@ -234,13 +269,22 @@ void ReduceOpenCLKernel::SetGlobalLocal() {
if (use_local_) {
local_size_ = {1, LOCAL_CACHE_THREAD, LOCAL_CACHE_THREAD};
}
if (hw_reduce_) {
if (IsHWCReduce()) {
global_size_ = {1, 1, 1};
} else if (IsHWReduce()) {
global_size_ = {static_cast<size_t>(c4), 1, 1};
} else if (wc_reduce_) {
} else if (IsWCReduce()) {
global_size_ = {static_cast<size_t>(h), 1, 1};
} else if (c_reduce_ && !use_local_) {
} else if (IsHReduce()) {
global_size_ = {static_cast<size_t>(w), static_cast<size_t>(c4)};
} else if (IsWReduce()) {
global_size_ = {static_cast<size_t>(h), static_cast<size_t>(c4)};
} else if (IsCReduce() && !use_local_) {
global_size_ = {static_cast<size_t>(h), static_cast<size_t>(w)};
} else {
global_size_ = {1, 1, 1};
}
AlignGlobalLocal(global_size_, local_size_);
}

View File

@ -41,11 +41,15 @@ class ReduceOpenCLKernel : public OpenCLKernel {
cl_float4 GenC4Mask();
static std::string GetReduceTypeStr(int type);
bool IsHWCReduce();
bool IsHWReduce();
bool IsWCReduce();
bool IsHReduce();
bool IsWReduce();
bool IsCReduce();
GpuTensorInfo inShape;
bool use_local_{false};
bool wc_reduce_{false};
bool hw_reduce_{false};
bool c_reduce_{false};
bool reduce_axes_[4]{false};
static const size_t LOCAL_CACHE_THREAD{16};
int axes_[MAX_SHAPE_SIZE];

View File

@ -114,3 +114,5 @@ ml_motion_capture_nanodet_m_0.5x_people_0928_sim.onnx;1:input.1
ml_motion_capture_smpl_0916.onnx;3:beta,body_pose,global_orient
ml_motion_capture_spin_mobile_mv3_v3_57mm_sim.onnx;5:input,bbox,init_pose,init_shape,init_cam
ml_video_edit_dimming_tech_model_345000_color.onnx;2:input.18,1
Ireland_ulfgf.onnx;1:input;1,240,320,3
Ireland_gaze_corrector.onnx;3:image,target_angle,strength 1

View File

@ -112,3 +112,5 @@ ml_motion_capture_nanodet_m_0.5x_people_0928_sim.onnx 8
ml_motion_capture_smpl_0916.onnx;3
ml_motion_capture_spin_mobile_mv3_v3_57mm_sim.onnx;5 18
ml_video_edit_dimming_tech_model_345000_color.onnx;2 2
Ireland_ulfgf.onnx;1;1,240,320,3
Ireland_gaze_corrector.onnx;3 15