fullconnection support 3d

chenzupeng 2020-12-01 20:27:35 +08:00
parent b273a46c53
commit 3cfcddbd26
5 changed files with 203 additions and 70 deletions

View File

@@ -4,23 +4,17 @@
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__kernel void FullConnection(__read_only image2d_t input, __write_only image2d_t output, __global FLT16 *weight,
__read_only image2d_t bias, int4 in_shape, int2 out_shape, int act_type) {
__read_only image2d_t bias, int N, int CI4, int CO4, int2 in_img_shape, int act_type) {
int gidx = get_global_id(0); // CO4
int gidz = get_global_id(2); // N
int lidx = get_local_id(0);
int lidy = get_local_id(1);
int ci4 = UP_DIV(in_shape.w, C4NUM);
int hwci4 = ci4 * in_shape.y * in_shape.z;
int wci4 = ci4 * in_shape.z;
int co4 = UP_DIV(out_shape.y, C4NUM);
int n = out_shape.x;
bool inside = gidx < co4 && gidz < n;
bool inside = gidx < CO4 && gidz < N;
FLT4 result = (FLT4)(0.0f);
for (uint i = lidy; i < hwci4 && inside; i += 4) {
int index_h = i / wci4;
int index_wci4 = i % wci4;
FLT4 v = READ_IMAGE(input, smp_zero, (int2)(index_wci4, gidz * in_shape.y + index_h));
FLT16 w = weight[i * co4 + gidx];
for (uint i = lidy; i < CI4 && inside; i += 4) {
int index = gidz * CI4 + i;
FLT4 v = READ_IMAGE(input, smp_zero, (int2)(index % in_img_shape.y, index / in_img_shape.y));
FLT16 w = weight[i * CO4 + gidx];
result.x += dot(v, w.s0123);
result.y += dot(v, w.s4567);
result.z += dot(v, w.s89ab);
@@ -46,3 +40,45 @@ __kernel void FullConnection(__read_only image2d_t input, __write_only image2d_t
WRITE_IMAGE(output, (int2)(gidx, gidz), result);
}
}
__kernel void FullConnectionWeightVar(__read_only image2d_t input, __write_only image2d_t output,
__read_only image2d_t weight, __read_only image2d_t bias, int N, int CI4, int CO4,
int2 in_img_shape, int act_type) {
int gidx = get_global_id(0); // CO4
int gidz = get_global_id(2); // N
int lidx = get_local_id(0);
int lidy = get_local_id(1);
bool inside = gidx < CO4 && gidz < N;
FLT4 result = (FLT4)(0.0f);
for (uint i = lidy; i < CI4 && inside; i += 4) {
int index = gidz * CI4 + i;
FLT4 v = READ_IMAGE(input, smp_zero, (int2)(index % in_img_shape.y, index / in_img_shape.y));
FLT4 weight0 = READ_IMAGE(weight, smp_zero, (int2)(i, gidx * 4));
result.x += dot(v, weight0);
FLT4 weight1 = READ_IMAGE(weight, smp_zero, (int2)(i, gidx * 4 + 1));
result.y += dot(v, weight1);
FLT4 weight2 = READ_IMAGE(weight, smp_zero, (int2)(i, gidx * 4 + 2));
result.z += dot(v, weight2);
FLT4 weight3 = READ_IMAGE(weight, smp_zero, (int2)(i, gidx * 4 + 3));
result.w += dot(v, weight3);
}
__local FLT4 temp[32][4];
temp[lidx][lidy] = result;
barrier(CLK_LOCAL_MEM_FENCE);
if (lidy == 0 && inside) {
result += temp[lidx][1];
result += temp[lidx][2];
result += temp[lidx][3];
result += READ_IMAGE(bias, smp_zero, (int2)(gidx, 0));
if (act_type == ActivationType_RELU) {
result = max(result, (FLT4)(0.0f));
} else if (act_type == ActivationType_RELU6) {
result = clamp(result, (FLT4)(0.0f), (FLT4)(6.0f));
} else if (act_type == ActivationType_TANH) {
FLT4 exp0 = exp(result);
FLT4 exp1 = exp(-result);
result = (exp0 - exp1) / (exp0 + exp1);
}
WRITE_IMAGE(output, (int2)(gidx, gidz), result);
}
}
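Both kernels now address the input image purely by flattened CI4 slices per output row. A minimal host-side sketch of that addressing, not part of the commit (InputTexel is a hypothetical helper; img_width corresponds to in_img_shape.y above):

#include <utility>

// Returns the (x, y) texel that FullConnection / FullConnectionWeightVar read
// for batch row gidz and slice i: the slices of one row are laid out
// consecutively and wrap at the input image width.
std::pair<int, int> InputTexel(int gidz, int i, int ci4, int img_width) {
  int index = gidz * ci4 + i;  // flat slice index, as in the kernels
  return {index % img_width, index / img_width};
}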

View File

@@ -42,15 +42,39 @@ int FullConnectionOpenCLKernel::CheckSpecs() {
MS_LOG(ERROR) << "fullconnection only support a_transpose_=false yet.";
return RET_ERROR;
}
if ((in_tensors_[0]->shape().size() != 4 && in_tensors_[0]->shape().size() != 2) ||
(out_tensors_[0]->shape().size() != 4 && out_tensors_[0]->shape().size() != 2)) {
MS_LOG(ERROR) << "fullconnection only support input output shape size = 2 or 4";
auto out_gpu_info = GpuTensorInfo(out_tensors_[0]);
if (out_gpu_info.H != 1 || out_gpu_info.W != 1) {
MS_LOG(ERROR) << "fullconnection only support 2d output shape or 4d output but H=W=1";
return RET_ERROR;
}
if (param->act_type_ != ActType_No && param->act_type_ != ActType_Relu && param->act_type_ != ActType_Relu6) {
MS_LOG(ERROR) << "Unsupported activation type " << param->act_type_;
return RET_ERROR;
}
N_ = out_gpu_info.N;
CO_ = out_gpu_info.C;
auto intensor_shape = GpuTensorInfo(in_tensors_[0]);
int input_nhw = intensor_shape.N * intensor_shape.H * intensor_shape.W;
if (input_nhw < N_) {
MS_LOG(ERROR) << "Unsupported fullconnection shape: input N*H*W is smaller than output N";
return RET_ERROR;
}
if (!in_tensors_.at(kWeightIndex)->IsConst()) {
weight_var_ = true;
if (!param->b_transpose_) {
MS_LOG(ERROR) << "If fullconnection input weight is not constant, b_transpose_ should be true.";
return RET_ERROR;
}
if (in_tensors_.at(kWeightIndex)->shape().size() != 2) {
MS_LOG(ERROR) << "If fullconnection input weight is not constant, it should be 2d.";
return RET_ERROR;
}
if (intensor_shape.C != in_tensors_.at(kWeightIndex)->shape()[1]) {
MS_LOG(ERROR) << "If fullconnection input weight is not constant, input channel should equal weight in_channel.";
return RET_ERROR;
}
}
CI_remainder_ = input_nhw / N_;
return RET_OK;
}
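As a worked example matching the 3D test added below: a 3D input {1, 4, 3} with output {2, 4} gives N_ = 2 and CO_ = 4; assuming GpuTensorInfo maps a 3D tensor to N=1, H=1, W=4, C=3, input_nhw = 1 * 1 * 4 = 4, so CI_remainder_ = 4 / 2 = 2, i.e. each output row consumes two input positions of 3 channels each.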
@@ -61,8 +85,9 @@ int FullConnectionOpenCLKernel::Prepare() {
enable_fp16_ = ocl_runtime_->GetFp16Enable();
std::string kernel_name = "FullConnection";
inShape = GpuTensorInfo(in_tensors_[0]);
outShape = GpuTensorInfo(out_tensors_[0]);
if (weight_var_) {
kernel_name = "FullConnectionWeightVar";
}
#ifdef PROGRAM_WITH_IL
kernel_ = ocl_runtime_->GetKernelFromBinary(kernel_name);
#else
@@ -82,23 +107,26 @@ int FullConnectionOpenCLKernel::Prepare() {
}
int FullConnectionOpenCLKernel::InitWeights() {
if (!in_tensors_.at(kWeightIndex)->IsConst()) {
MS_LOG(ERROR) << "FullConnection don't support non-constant filter yet.";
return RET_ERROR;
if (!weight_var_) {
auto ret = InitFilter();
if (ret != RET_OK) {
return ret;
}
}
return InitBias();
}
int FullConnectionOpenCLKernel::InitFilter() {
auto allocator = ocl_runtime_->GetAllocator();
int ci = inShape.C;
int ci4 = UP_DIV(ci, C4NUM);
int co = outShape.C;
int co4 = UP_DIV(co, C4NUM);
int h = inShape.H;
int w = inShape.W;
auto intensor_shape = GpuTensorInfo(in_tensors_[0]);
int co4 = UP_DIV(CO_, C4NUM);
int nhw_remainder = intensor_shape.N * intensor_shape.H * intensor_shape.W / N_;
size_t dtype_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float);
padWeight_ = allocator->Malloc(h * w * ci4 * co4 * C4NUM * C4NUM * dtype_size);
padWeight_ = allocator->Malloc(nhw_remainder * intensor_shape.Slice * co4 * C4NUM * C4NUM * dtype_size);
padWeight_ = allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true);
auto padWeightFp32 = reinterpret_cast<float *>(padWeight_);
auto padWeightFp16 = reinterpret_cast<float16_t *>(padWeight_);
memset(padWeight_, 0x00, h * w * ci4 * co4 * C4NUM * C4NUM * dtype_size);
memset(padWeight_, 0x00, nhw_remainder * intensor_shape.Slice * co4 * C4NUM * C4NUM * dtype_size);
auto originWeightFp32 = reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->data_c());
auto originWeightFp16 = reinterpret_cast<float16_t *>(in_tensors_.at(kWeightIndex)->data_c());
bool isModelFp16 = in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat16;
@@ -107,36 +135,33 @@ int FullConnectionOpenCLKernel::InitWeights() {
// HWCICO -> (HWCI4)(CO4)(4 from CO)(4 from CI)
// if transposeB, COHWCI -> (HWCI4)(CO4)(4 from CO)(4 from CI)
int index = 0;
for (int hh = 0; hh < h; hh++) {
for (int ww = 0; ww < w; ww++) {
int baseHW = hh * w + ww;
for (int i = 0; i < ci4; ++i) {
for (int j = 0; j < co4; ++j) {
for (int k = 0; k < C4NUM; ++k) {
for (int l = 0; l < C4NUM; ++l) {
int src_ci = i * C4NUM + l;
int src_co = j * C4NUM + k;
if (src_ci < ci && src_co < co) {
int originId = baseHW * ci * co + src_ci * co + src_co;
if (transposeB) {
originId = src_co * ci * h * w + baseHW * ci + src_ci;
}
if (enable_fp16_) {
if (!isModelFp16) {
padWeightFp16[index++] = originWeightFp32[originId];
} else {
padWeightFp16[index++] = originWeightFp16[originId];
}
for (int nhw = 0; nhw < nhw_remainder; nhw++) {
for (int i = 0; i < intensor_shape.Slice; ++i) {
for (int j = 0; j < co4; ++j) {
for (int k = 0; k < C4NUM; ++k) {
for (int l = 0; l < C4NUM; ++l) {
int src_ci = i * C4NUM + l;
int src_co = j * C4NUM + k;
if (src_ci < intensor_shape.C && src_co < CO_) {
int originId = (nhw * intensor_shape.C + src_ci) * CO_ + src_co;
if (transposeB) {
originId = src_co * intensor_shape.C * nhw_remainder + nhw * intensor_shape.C + src_ci;
}
if (enable_fp16_) {
if (!isModelFp16) {
padWeightFp16[index++] = originWeightFp32[originId];
} else {
if (!isModelFp16) {
padWeightFp32[index++] = originWeightFp32[originId];
} else {
padWeightFp32[index++] = originWeightFp16[originId];
}
padWeightFp16[index++] = originWeightFp16[originId];
}
} else {
index++;
if (!isModelFp16) {
padWeightFp32[index++] = originWeightFp32[originId];
} else {
padWeightFp32[index++] = originWeightFp16[originId];
}
}
} else {
index++;
}
}
}
@@ -144,8 +169,14 @@ int FullConnectionOpenCLKernel::InitWeights() {
}
}
allocator->UnmapBuffer(padWeight_);
return RET_OK;
}
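For reference, a standalone sketch of the packing InitFilter performs (not part of this commit; float/transposeB path only): a weight of shape [CO][NHW*CI] is rewritten as (NHW*CI4) x CO4 blocks of 4x4 values, zero-padded where CI or CO is not a multiple of 4, so the kernel can load one FLT16 per (slice, CO4) pair.

#include <vector>

std::vector<float> PackFcWeight(const std::vector<float> &src, int nhw, int ci, int co) {
  const int c4 = 4;
  const int ci4 = (ci + c4 - 1) / c4;
  const int co4 = (co + c4 - 1) / c4;
  std::vector<float> dst(nhw * ci4 * co4 * c4 * c4, 0.0f);
  int index = 0;
  for (int n = 0; n < nhw; ++n) {
    for (int i = 0; i < ci4; ++i) {
      for (int j = 0; j < co4; ++j) {
        for (int k = 0; k < c4; ++k) {    // lane taken from CO
          for (int l = 0; l < c4; ++l) {  // lane taken from CI
            int src_ci = i * c4 + l;
            int src_co = j * c4 + k;
            if (src_ci < ci && src_co < co) {
              // transposeB source layout: [CO][NHW * CI]
              dst[index] = src[src_co * nhw * ci + n * ci + src_ci];
            }
            index++;  // out-of-range lanes stay zero-padded
          }
        }
      }
    }
  }
  return dst;
}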
int FullConnectionOpenCLKernel::InitBias() {
// pad FC Bias
auto allocator = ocl_runtime_->GetAllocator();
int co4 = UP_DIV(CO_, C4NUM);
size_t dtype_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float);
size_t im_dst_x, im_dst_y;
im_dst_x = co4;
im_dst_y = 1;
@@ -163,15 +194,15 @@ int FullConnectionOpenCLKernel::InitWeights() {
return RET_ERROR;
}
if (in_tensors_[2]->data_type() == kNumberTypeFloat32 && enable_fp16_) {
for (int i = 0; i < co; i++) {
for (int i = 0; i < CO_; i++) {
reinterpret_cast<float16_t *>(bias_)[i] = reinterpret_cast<float *>(in_tensors_[2]->data_c())[i];
}
} else if (in_tensors_[2]->data_type() == kNumberTypeFloat16 && !enable_fp16_) {
for (int i = 0; i < co; i++) {
for (int i = 0; i < CO_; i++) {
reinterpret_cast<float *>(bias_)[i] = reinterpret_cast<float16_t *>(in_tensors_[2]->data_c())[i];
}
} else {
memcpy(bias_, in_tensors_[2]->data_c(), co * dtype_size);
memcpy(bias_, in_tensors_[2]->data_c(), CO_ * dtype_size);
}
}
allocator->UnmapBuffer(bias_);
@@ -180,20 +211,27 @@ int FullConnectionOpenCLKernel::InitWeights() {
void FullConnectionOpenCLKernel::SetGlobalLocal() {
local_size_ = {32, 4, 1};
global_size_ = {UP_DIV(outShape.C, C4NUM), 4, outShape.N};
size_t CO = CO_;
size_t N = N_;
global_size_ = {UP_DIV(CO, C4NUM), 4, N};
AlignGlobalLocal(global_size_, local_size_);
}
void FullConnectionOpenCLKernel::SetConstArgs() {
int arg_count = 2;
cl_int4 in_shape = {static_cast<int>(inShape.N), static_cast<int>(inShape.H), static_cast<int>(inShape.W),
static_cast<int>(inShape.C)};
cl_int2 out_shape = {static_cast<int>(outShape.N), static_cast<int>(outShape.C)};
ocl_runtime_->SetKernelArg(kernel_, arg_count++, padWeight_, lite::opencl::MemType::BUF);
auto *param = reinterpret_cast<MatMulParameter *>(op_parameter_);
if (!weight_var_) {
ocl_runtime_->SetKernelArg(kernel_, 2, padWeight_, lite::opencl::MemType::BUF);
}
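// Args 0 and 1 (input/output images) are set per run in Run(); for FullConnectionWeightVar the
// weight image is also set there at index 2, so the constant args below start at index 3.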
int arg_count = 3;
ocl_runtime_->SetKernelArg(kernel_, arg_count++, bias_);
ocl_runtime_->SetKernelArg(kernel_, arg_count++, in_shape);
ocl_runtime_->SetKernelArg(kernel_, arg_count++, out_shape);
ocl_runtime_->SetKernelArg(kernel_, arg_count++, N_);
auto intensor_shape = GpuTensorInfo(in_tensors_[0]);
int CI4 = CI_remainder_ * intensor_shape.Slice;
ocl_runtime_->SetKernelArg(kernel_, arg_count++, CI4);
ocl_runtime_->SetKernelArg(kernel_, arg_count++, UP_DIV(CO_, C4NUM));
auto in_shape_info = GpuTensorInfo(in_tensors_[0]);
cl_int2 in_img_shape = {static_cast<int>(in_shape_info.height), static_cast<int>(in_shape_info.width)};
ocl_runtime_->SetKernelArg(kernel_, arg_count++, in_img_shape);
auto *param = reinterpret_cast<MatMulParameter *>(op_parameter_);
ocl_runtime_->SetKernelArg(kernel_, arg_count, static_cast<cl_int>(param->act_type_));
}
@@ -202,6 +240,9 @@ int FullConnectionOpenCLKernel::Run() {
int arg_count = 0;
ocl_runtime_->SetKernelArg(kernel_, arg_count++, in_tensors_[0]->data_c());
ocl_runtime_->SetKernelArg(kernel_, arg_count++, out_tensors_[0]->data_c());
if (weight_var_) {
ocl_runtime_->SetKernelArg(kernel_, arg_count++, in_tensors_[1]->data_c());
}
ocl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr, &event_);
return RET_OK;
}

View File

@@ -40,13 +40,17 @@ class FullConnectionOpenCLKernel : public OpenCLKernel {
int Tune() override { return lite::RET_OK; }
private:
int InitFilter();
int InitBias();
void *padWeight_{nullptr};
void *bias_{nullptr};
bool enable_fp16_{false};
bool transposeA{false};
bool transposeB{true};
GpuTensorInfo inShape = GpuTensorInfo(nullptr);
GpuTensorInfo outShape = GpuTensorInfo(nullptr);
bool weight_var_{false};
int N_{1};
int CI_remainder_{1};
int CO_{1};
};
} // namespace mindspore::kernel

View File

@@ -68,7 +68,7 @@ cl_float4 ReduceOpenCLKernel::GenC4Mask() {
int ReduceOpenCLKernel::CheckSpecs() {
if (in_tensors_[0]->shape()[0] > 1) {
MS_LOG(ERROR) << "reduce op only support n=2";
MS_LOG(ERROR) << "reduce op only support n = 1";
return RET_PARAM_INVALID;
}
auto reduce_param = reinterpret_cast<ReduceParameter *>(op_parameter_);
@@ -76,6 +76,10 @@ int ReduceOpenCLKernel::CheckSpecs() {
MS_LOG(ERROR) << "not supported reduce type:" << reduce_param->mode_;
return RET_PARAM_INVALID;
}
if (reduce_param->num_axes_ == 1 && reduce_param->axes_[0] == 3 && in_tensors_[0]->shape()[2] == 1) {
reduce_param->num_axes_ = 2;
reduce_param->axes_[1] = 2;
}
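// A reduce over only the channel axis (axis 3) when W == 1 is rewritten as a two-axis reduce
// over (C, W) so it falls into the existing num_axes_ == 2 path checked below.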
if (reduce_param->num_axes_ != 2) {
MS_LOG(ERROR) << "reduce op only support axes=2";
return RET_PARAM_INVALID;

View File

@@ -24,7 +24,7 @@ namespace {
// PrimitiveType_FullConnection: src/ops/populate/full_connection_populate.cc
OpParameter *CreateParameter(std::vector<int> *input_shape, std::vector<int> *weight_shape,
std::vector<int> *bias_shape, std::vector<int> *output_shape, int ndim, int ci, int co,
int n = 1, int h = 1, int w = 1) {
int n = 1, int h = 1, int w = 1, int in_n = 1) {
auto *param = test::CreateParameter<MatMulParameter>(schema::PrimitiveType_FullConnection);
param->a_transpose_ = false;
param->b_transpose_ = true;
@@ -41,6 +41,11 @@ OpParameter *CreateParameter(std::vector<int> *input_shape, std::vector<int> *we
*output_shape = {n, co};
*weight_shape = {co, h * w * ci};
*bias_shape = {co};
} else if (ndim == 3) {
*input_shape = {in_n, w, ci};
*output_shape = {n, co};
*weight_shape = {co, in_n * w * ci / n};
*bias_shape = {co};
}
return reinterpret_cast<OpParameter *>(param);
}
@@ -87,4 +92,47 @@ TEST_F(TestOpenCL_FullConnection, 4D) {
}
}
TEST_F(TestOpenCL_FullConnection, 3D) {
int ndim = 3;
int ci = 3;
int co = 4;
int n = 2;
int h = 1;
int w = 4;
int in_n = 1;
float input_data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
float weight_data[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
float bias_data[] = {1, 1, 1, 1};
float output_data[] = {16, 16, 16, 16, 52, 52, 52, 52};
for (auto fp16_enable : {false, true}) {
std::vector<int> input_shape, weight_shape, bias_shape, output_shape;
auto *param = CreateParameter(&input_shape, &weight_shape, &bias_shape, &output_shape, ndim, ci, co, n, h, w, in_n);
TestMain({{input_shape, input_data, VAR},
{weight_shape, weight_data, CONST_TENSOR},
{bias_shape, bias_data, CONST_TENSOR}},
{output_shape, output_data}, param, fp16_enable);
}
}
TEST_F(TestOpenCL_FullConnection, 3DWeightVar) {
int ndim = 3;
int ci = 6;
int co = 4;
int n = 2;
int h = 1;
int w = 2;
int in_n = 1;
float input_data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
float weight_data[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
float bias_data[] = {1, 1, 1, 1};
float output_data[] = {16, 16, 16, 16, 52, 52, 52, 52};
for (auto fp16_enable : {false, true}) {
std::vector<int> input_shape, weight_shape, bias_shape, output_shape;
auto *param = CreateParameter(&input_shape, &weight_shape, &bias_shape, &output_shape, ndim, ci, co, n, h, w, in_n);
TestMain({{input_shape, input_data, VAR}, {weight_shape, weight_data, VAR}, {bias_shape, bias_data, CONST_TENSOR}},
{output_shape, output_data}, param, fp16_enable);
}
}
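In both 3D tests the weights are all ones and each of the n = 2 output rows reduces six inputs, so the expected values follow directly: (0+1+2+3+4+5) + 1 = 16 for the first row and (6+7+8+9+10+11) + 1 = 52 for the second, repeated across the co = 4 output channels.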
} // namespace mindspore::lite::opencl::test