forked from mindspore-Ecosystem/mindspore
fullconnection support 3d
This commit is contained in:
parent
b273a46c53
commit
3cfcddbd26
|
@ -4,23 +4,17 @@
|
|||
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
|
||||
|
||||
__kernel void FullConnection(__read_only image2d_t input, __write_only image2d_t output, __global FLT16 *weight,
|
||||
__read_only image2d_t bias, int4 in_shape, int2 out_shape, int act_type) {
|
||||
__read_only image2d_t bias, int N, int CI4, int CO4, int2 in_img_shape, int act_type) {
|
||||
int gidx = get_global_id(0); // CO4
|
||||
int gidz = get_global_id(2); // N
|
||||
int lidx = get_local_id(0);
|
||||
int lidy = get_local_id(1);
|
||||
int ci4 = UP_DIV(in_shape.w, C4NUM);
|
||||
int hwci4 = ci4 * in_shape.y * in_shape.z;
|
||||
int wci4 = ci4 * in_shape.z;
|
||||
int co4 = UP_DIV(out_shape.y, C4NUM);
|
||||
int n = out_shape.x;
|
||||
bool inside = gidx < co4 && gidz < n;
|
||||
bool inside = gidx < CO4 && gidz < N;
|
||||
FLT4 result = (FLT4)(0.0f);
|
||||
for (uint i = lidy; i < hwci4 && inside; i += 4) {
|
||||
int index_h = i / wci4;
|
||||
int index_wci4 = i % wci4;
|
||||
FLT4 v = READ_IMAGE(input, smp_zero, (int2)(index_wci4, gidz * in_shape.y + index_h));
|
||||
FLT16 w = weight[i * co4 + gidx];
|
||||
for (uint i = lidy; i < CI4 && inside; i += 4) {
|
||||
int index = gidz * CI4 + i;
|
||||
FLT4 v = READ_IMAGE(input, smp_zero, (int2)(index % in_img_shape.y, index / in_img_shape.y));
|
||||
FLT16 w = weight[i * CO4 + gidx];
|
||||
result.x += dot(v, w.s0123);
|
||||
result.y += dot(v, w.s4567);
|
||||
result.z += dot(v, w.s89ab);
|
||||
|
@ -46,3 +40,45 @@ __kernel void FullConnection(__read_only image2d_t input, __write_only image2d_t
|
|||
WRITE_IMAGE(output, (int2)(gidx, gidz), result);
|
||||
}
|
||||
}
|
||||
|
||||
__kernel void FullConnectionWeightVar(__read_only image2d_t input, __write_only image2d_t output,
|
||||
__read_only image2d_t weight, __read_only image2d_t bias, int N, int CI4, int CO4,
|
||||
int2 in_img_shape, int act_type) {
|
||||
int gidx = get_global_id(0); // CO4
|
||||
int gidz = get_global_id(2); // N
|
||||
int lidx = get_local_id(0);
|
||||
int lidy = get_local_id(1);
|
||||
bool inside = gidx < CO4 && gidz < N;
|
||||
FLT4 result = (FLT4)(0.0f);
|
||||
for (uint i = lidy; i < CI4 && inside; i += 4) {
|
||||
int index = gidz * CI4 + i;
|
||||
FLT4 v = READ_IMAGE(input, smp_zero, (int2)(index % in_img_shape.y, index / in_img_shape.y));
|
||||
FLT4 weight0 = READ_IMAGE(weight, smp_zero, (int2)(i, gidx * 4));
|
||||
result.x += dot(v, weight0);
|
||||
FLT4 weight1 = READ_IMAGE(weight, smp_zero, (int2)(i, gidx * 4 + 1));
|
||||
result.y += dot(v, weight1);
|
||||
FLT4 weight2 = READ_IMAGE(weight, smp_zero, (int2)(i, gidx * 4 + 2));
|
||||
result.z += dot(v, weight2);
|
||||
FLT4 weight3 = READ_IMAGE(weight, smp_zero, (int2)(i, gidx * 4 + 3));
|
||||
result.w += dot(v, weight3);
|
||||
}
|
||||
__local FLT4 temp[32][4];
|
||||
temp[lidx][lidy] = result;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
if (lidy == 0 && inside) {
|
||||
result += temp[lidx][1];
|
||||
result += temp[lidx][2];
|
||||
result += temp[lidx][3];
|
||||
result += READ_IMAGE(bias, smp_zero, (int2)(gidx, 0));
|
||||
if (act_type == ActivationType_RELU) {
|
||||
result = max(result, (FLT4)(0.0f));
|
||||
} else if (act_type == ActivationType_RELU6) {
|
||||
result = clamp(result, (FLT4)(0.0f), (FLT4)(6.0f));
|
||||
} else if (act_type == ActivationType_TANH) {
|
||||
FLT4 exp0 = exp(result);
|
||||
FLT4 exp1 = exp(-result);
|
||||
result = (exp0 - exp1) / (exp0 + exp1);
|
||||
}
|
||||
WRITE_IMAGE(output, (int2)(gidx, gidz), result);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -42,15 +42,39 @@ int FullConnectionOpenCLKernel::CheckSpecs() {
|
|||
MS_LOG(ERROR) << "fullconnection only support a_transpose_=false yet.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if ((in_tensors_[0]->shape().size() != 4 && in_tensors_[0]->shape().size() != 2) ||
|
||||
(out_tensors_[0]->shape().size() != 4 && out_tensors_[0]->shape().size() != 2)) {
|
||||
MS_LOG(ERROR) << "fullconnection only support input output shape size = 2 or 4";
|
||||
auto out_gpu_info = GpuTensorInfo(out_tensors_[0]);
|
||||
if (out_gpu_info.H != 1 || out_gpu_info.W != 1) {
|
||||
MS_LOG(ERROR) << "fullconnection only support 2d output shape or 4d output but H=W=1";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (param->act_type_ != ActType_No && param->act_type_ != ActType_Relu && param->act_type_ != ActType_Relu6) {
|
||||
MS_LOG(ERROR) << "Unsupported activation type " << param->act_type_;
|
||||
return RET_ERROR;
|
||||
}
|
||||
N_ = out_gpu_info.N;
|
||||
CO_ = out_gpu_info.C;
|
||||
auto intensor_shape = GpuTensorInfo(in_tensors_[0]);
|
||||
int input_nhw = intensor_shape.N * intensor_shape.H * intensor_shape.W;
|
||||
if (input_nhw < N_) {
|
||||
MS_LOG(ERROR) << "Unsupported fullconnection shape";
|
||||
}
|
||||
if (!in_tensors_.at(kWeightIndex)->IsConst()) {
|
||||
weight_var_ = true;
|
||||
if (!param->b_transpose_) {
|
||||
MS_LOG(ERROR) << "If fullconnection input weight is not constant, b_transpose_ should be true.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (in_tensors_.at(kWeightIndex)->shape().size() != 2) {
|
||||
MS_LOG(ERROR) << "If fullconnection input weight is not constant, it should be 2d.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (intensor_shape.C != in_tensors_.at(kWeightIndex)->shape()[1]) {
|
||||
MS_LOG(ERROR)
|
||||
<< "If fullconnection input weight is not constant, input channel should equal to weight in_channel.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
CI_remainder_ = input_nhw / N_;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -61,8 +85,9 @@ int FullConnectionOpenCLKernel::Prepare() {
|
|||
enable_fp16_ = ocl_runtime_->GetFp16Enable();
|
||||
|
||||
std::string kernel_name = "FullConnection";
|
||||
inShape = GpuTensorInfo(in_tensors_[0]);
|
||||
outShape = GpuTensorInfo(out_tensors_[0]);
|
||||
if (weight_var_) {
|
||||
kernel_name = "FullConnectionWeightVar";
|
||||
}
|
||||
#ifdef PROGRAM_WITH_IL
|
||||
kernel_ = ocl_runtime_->GetKernelFromBinary(kernel_name);
|
||||
#else
|
||||
|
@ -82,23 +107,26 @@ int FullConnectionOpenCLKernel::Prepare() {
|
|||
}
|
||||
|
||||
int FullConnectionOpenCLKernel::InitWeights() {
|
||||
if (!in_tensors_.at(kWeightIndex)->IsConst()) {
|
||||
MS_LOG(ERROR) << "FullConnection don't support non-constant filter yet.";
|
||||
return RET_ERROR;
|
||||
if (!weight_var_) {
|
||||
auto ret = InitFilter();
|
||||
if (ret != RET_OK) {
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
return InitBias();
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
int FullConnectionOpenCLKernel::InitFilter() {
|
||||
auto allocator = ocl_runtime_->GetAllocator();
|
||||
int ci = inShape.C;
|
||||
int ci4 = UP_DIV(ci, C4NUM);
|
||||
int co = outShape.C;
|
||||
int co4 = UP_DIV(co, C4NUM);
|
||||
int h = inShape.H;
|
||||
int w = inShape.W;
|
||||
auto intensor_shape = GpuTensorInfo(in_tensors_[0]);
|
||||
int co4 = UP_DIV(CO_, C4NUM);
|
||||
int nhw_remainder = intensor_shape.N * intensor_shape.H * intensor_shape.W / N_;
|
||||
size_t dtype_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float);
|
||||
padWeight_ = allocator->Malloc(h * w * ci4 * co4 * C4NUM * C4NUM * dtype_size);
|
||||
padWeight_ = allocator->Malloc(nhw_remainder * intensor_shape.Slice * co4 * C4NUM * C4NUM * dtype_size);
|
||||
padWeight_ = allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true);
|
||||
auto padWeightFp32 = reinterpret_cast<float *>(padWeight_);
|
||||
auto padWeightFp16 = reinterpret_cast<float16_t *>(padWeight_);
|
||||
memset(padWeight_, 0x00, h * w * ci4 * co4 * C4NUM * C4NUM * dtype_size);
|
||||
memset(padWeight_, 0x00, nhw_remainder * intensor_shape.Slice * co4 * C4NUM * C4NUM * dtype_size);
|
||||
auto originWeightFp32 = reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->data_c());
|
||||
auto originWeightFp16 = reinterpret_cast<float16_t *>(in_tensors_.at(kWeightIndex)->data_c());
|
||||
bool isModelFp16 = in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat16;
|
||||
|
@ -107,36 +135,33 @@ int FullConnectionOpenCLKernel::InitWeights() {
|
|||
// HWCICO -> (HWCI4)(CO4)(4 from CO)(4 from CI)
|
||||
// if tranposeB, COHWCI -> (HWCI4)(CO4)(4 from CO)(4 from CI)
|
||||
int index = 0;
|
||||
for (int hh = 0; hh < h; hh++) {
|
||||
for (int ww = 0; ww < w; ww++) {
|
||||
int baseHW = hh * w + ww;
|
||||
for (int i = 0; i < ci4; ++i) {
|
||||
for (int j = 0; j < co4; ++j) {
|
||||
for (int k = 0; k < C4NUM; ++k) {
|
||||
for (int l = 0; l < C4NUM; ++l) {
|
||||
int src_ci = i * C4NUM + l;
|
||||
int src_co = j * C4NUM + k;
|
||||
if (src_ci < ci && src_co < co) {
|
||||
int originId = baseHW * ci * co + src_ci * co + src_co;
|
||||
if (transposeB) {
|
||||
originId = src_co * ci * h * w + baseHW * ci + src_ci;
|
||||
}
|
||||
if (enable_fp16_) {
|
||||
if (!isModelFp16) {
|
||||
padWeightFp16[index++] = originWeightFp32[originId];
|
||||
} else {
|
||||
padWeightFp16[index++] = originWeightFp16[originId];
|
||||
}
|
||||
for (int nhw = 0; nhw < nhw_remainder; nhw++) {
|
||||
for (int i = 0; i < intensor_shape.Slice; ++i) {
|
||||
for (int j = 0; j < co4; ++j) {
|
||||
for (int k = 0; k < C4NUM; ++k) {
|
||||
for (int l = 0; l < C4NUM; ++l) {
|
||||
int src_ci = i * C4NUM + l;
|
||||
int src_co = j * C4NUM + k;
|
||||
if (src_ci < intensor_shape.C && src_co < CO_) {
|
||||
int originId = (nhw * intensor_shape.C + src_ci) * CO_ + src_co;
|
||||
if (transposeB) {
|
||||
originId = src_co * intensor_shape.C * nhw_remainder + nhw * intensor_shape.C + src_ci;
|
||||
}
|
||||
if (enable_fp16_) {
|
||||
if (!isModelFp16) {
|
||||
padWeightFp16[index++] = originWeightFp32[originId];
|
||||
} else {
|
||||
if (!isModelFp16) {
|
||||
padWeightFp32[index++] = originWeightFp32[originId];
|
||||
} else {
|
||||
padWeightFp32[index++] = originWeightFp16[originId];
|
||||
}
|
||||
padWeightFp16[index++] = originWeightFp16[originId];
|
||||
}
|
||||
} else {
|
||||
index++;
|
||||
if (!isModelFp16) {
|
||||
padWeightFp32[index++] = originWeightFp32[originId];
|
||||
} else {
|
||||
padWeightFp32[index++] = originWeightFp16[originId];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -144,8 +169,14 @@ int FullConnectionOpenCLKernel::InitWeights() {
|
|||
}
|
||||
}
|
||||
allocator->UnmapBuffer(padWeight_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int FullConnectionOpenCLKernel::InitBias() {
|
||||
// pad FC Bias
|
||||
auto allocator = ocl_runtime_->GetAllocator();
|
||||
int co4 = UP_DIV(CO_, C4NUM);
|
||||
size_t dtype_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float);
|
||||
size_t im_dst_x, im_dst_y;
|
||||
im_dst_x = co4;
|
||||
im_dst_y = 1;
|
||||
|
@ -163,15 +194,15 @@ int FullConnectionOpenCLKernel::InitWeights() {
|
|||
return RET_ERROR;
|
||||
}
|
||||
if (in_tensors_[2]->data_type() == kNumberTypeFloat32 && enable_fp16_) {
|
||||
for (int i = 0; i < co; i++) {
|
||||
for (int i = 0; i < CO_; i++) {
|
||||
reinterpret_cast<float16_t *>(bias_)[i] = reinterpret_cast<float *>(in_tensors_[2]->data_c())[i];
|
||||
}
|
||||
} else if (in_tensors_[2]->data_type() == kNumberTypeFloat16 && !enable_fp16_) {
|
||||
for (int i = 0; i < co; i++) {
|
||||
for (int i = 0; i < CO_; i++) {
|
||||
reinterpret_cast<float *>(bias_)[i] = reinterpret_cast<float16_t *>(in_tensors_[2]->data_c())[i];
|
||||
}
|
||||
} else {
|
||||
memcpy(bias_, in_tensors_[2]->data_c(), co * dtype_size);
|
||||
memcpy(bias_, in_tensors_[2]->data_c(), CO_ * dtype_size);
|
||||
}
|
||||
}
|
||||
allocator->UnmapBuffer(bias_);
|
||||
|
@ -180,20 +211,27 @@ int FullConnectionOpenCLKernel::InitWeights() {
|
|||
|
||||
void FullConnectionOpenCLKernel::SetGlobalLocal() {
|
||||
local_size_ = {32, 4, 1};
|
||||
global_size_ = {UP_DIV(outShape.C, C4NUM), 4, outShape.N};
|
||||
size_t CO = CO_;
|
||||
size_t N = N_;
|
||||
global_size_ = {UP_DIV(CO, C4NUM), 4, N};
|
||||
AlignGlobalLocal(global_size_, local_size_);
|
||||
}
|
||||
|
||||
void FullConnectionOpenCLKernel::SetConstArgs() {
|
||||
int arg_count = 2;
|
||||
cl_int4 in_shape = {static_cast<int>(inShape.N), static_cast<int>(inShape.H), static_cast<int>(inShape.W),
|
||||
static_cast<int>(inShape.C)};
|
||||
cl_int2 out_shape = {static_cast<int>(outShape.N), static_cast<int>(outShape.C)};
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_count++, padWeight_, lite::opencl::MemType::BUF);
|
||||
auto *param = reinterpret_cast<MatMulParameter *>(op_parameter_);
|
||||
if (!weight_var_) {
|
||||
ocl_runtime_->SetKernelArg(kernel_, 2, padWeight_, lite::opencl::MemType::BUF);
|
||||
}
|
||||
int arg_count = 3;
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_count++, bias_);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_count++, in_shape);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_count++, out_shape);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_count++, N_);
|
||||
auto intensor_shape = GpuTensorInfo(in_tensors_[0]);
|
||||
int CI4 = CI_remainder_ * intensor_shape.Slice;
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_count++, CI4);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_count++, UP_DIV(CO_, C4NUM));
|
||||
auto in_shape_info = GpuTensorInfo(in_tensors_[0]);
|
||||
cl_int2 in_img_shape = {static_cast<int>(in_shape_info.height), static_cast<int>(in_shape_info.width)};
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_count++, in_img_shape);
|
||||
auto *param = reinterpret_cast<MatMulParameter *>(op_parameter_);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_count, static_cast<cl_int>(param->act_type_));
|
||||
}
|
||||
|
||||
|
@ -202,6 +240,9 @@ int FullConnectionOpenCLKernel::Run() {
|
|||
int arg_count = 0;
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_count++, in_tensors_[0]->data_c());
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_count++, out_tensors_[0]->data_c());
|
||||
if (weight_var_) {
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_count++, in_tensors_[1]->data_c());
|
||||
}
|
||||
ocl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr, &event_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -40,13 +40,17 @@ class FullConnectionOpenCLKernel : public OpenCLKernel {
|
|||
int Tune() override { return lite::RET_OK; }
|
||||
|
||||
private:
|
||||
int InitFilter();
|
||||
int InitBias();
|
||||
void *padWeight_{nullptr};
|
||||
void *bias_{nullptr};
|
||||
bool enable_fp16_{false};
|
||||
bool transposeA{false};
|
||||
bool transposeB{true};
|
||||
GpuTensorInfo inShape = GpuTensorInfo(nullptr);
|
||||
GpuTensorInfo outShape = GpuTensorInfo(nullptr);
|
||||
bool weight_var_{false};
|
||||
int N_{1};
|
||||
int CI_remainder_{1};
|
||||
int CO_{1};
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
|
|
|
@ -68,7 +68,7 @@ cl_float4 ReduceOpenCLKernel::GenC4Mask() {
|
|||
|
||||
int ReduceOpenCLKernel::CheckSpecs() {
|
||||
if (in_tensors_[0]->shape()[0] > 1) {
|
||||
MS_LOG(ERROR) << "reduce op only support n=2";
|
||||
MS_LOG(ERROR) << "reduce op only support n = 1";
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
auto reduce_param = reinterpret_cast<ReduceParameter *>(op_parameter_);
|
||||
|
@ -76,6 +76,10 @@ int ReduceOpenCLKernel::CheckSpecs() {
|
|||
MS_LOG(ERROR) << "not supported reduce type:" << reduce_param->mode_;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
if (reduce_param->num_axes_ == 1 && reduce_param->axes_[0] == 3 && in_tensors_[0]->shape()[2] == 1) {
|
||||
reduce_param->num_axes_ = 2;
|
||||
reduce_param->axes_[1] = 2;
|
||||
}
|
||||
if (reduce_param->num_axes_ != 2) {
|
||||
MS_LOG(ERROR) << "reduce op only support axes=2";
|
||||
return RET_PARAM_INVALID;
|
||||
|
|
|
@ -24,7 +24,7 @@ namespace {
|
|||
// PrimitiveType_FullConnection: src/ops/populate/full_connection_populate.cc
|
||||
OpParameter *CreateParameter(std::vector<int> *input_shape, std::vector<int> *weight_shape,
|
||||
std::vector<int> *bias_shape, std::vector<int> *output_shape, int ndim, int ci, int co,
|
||||
int n = 1, int h = 1, int w = 1) {
|
||||
int n = 1, int h = 1, int w = 1, int in_n = 1) {
|
||||
auto *param = test::CreateParameter<MatMulParameter>(schema::PrimitiveType_FullConnection);
|
||||
param->a_transpose_ = false;
|
||||
param->b_transpose_ = true;
|
||||
|
@ -41,6 +41,11 @@ OpParameter *CreateParameter(std::vector<int> *input_shape, std::vector<int> *we
|
|||
*output_shape = {n, co};
|
||||
*weight_shape = {co, h * w * ci};
|
||||
*bias_shape = {co};
|
||||
} else if (ndim == 3) {
|
||||
*input_shape = {in_n, w, ci};
|
||||
*output_shape = {n, co};
|
||||
*weight_shape = {co, in_n * w * ci / n};
|
||||
*bias_shape = {co};
|
||||
}
|
||||
return reinterpret_cast<OpParameter *>(param);
|
||||
}
|
||||
|
@ -87,4 +92,47 @@ TEST_F(TestOpenCL_FullConnection, 4D) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(TestOpenCL_FullConnection, 3D) {
|
||||
int ndim = 3;
|
||||
int ci = 3;
|
||||
int co = 4;
|
||||
int n = 2;
|
||||
int h = 1;
|
||||
int w = 4;
|
||||
int in_n = 1;
|
||||
float input_data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
|
||||
float weight_data[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
|
||||
float bias_data[] = {1, 1, 1, 1};
|
||||
float output_data[] = {16, 16, 16, 16, 52, 52, 52, 52};
|
||||
|
||||
for (auto fp16_enable : {false, true}) {
|
||||
std::vector<int> input_shape, weight_shape, bias_shape, output_shape;
|
||||
auto *param = CreateParameter(&input_shape, &weight_shape, &bias_shape, &output_shape, ndim, ci, co, n, h, w, in_n);
|
||||
TestMain({{input_shape, input_data, VAR},
|
||||
{weight_shape, weight_data, CONST_TENSOR},
|
||||
{bias_shape, bias_data, CONST_TENSOR}},
|
||||
{output_shape, output_data}, param, fp16_enable);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TestOpenCL_FullConnection, 3DWeightVar) {
|
||||
int ndim = 3;
|
||||
int ci = 6;
|
||||
int co = 4;
|
||||
int n = 2;
|
||||
int h = 1;
|
||||
int w = 2;
|
||||
int in_n = 1;
|
||||
float input_data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
|
||||
float weight_data[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
|
||||
float bias_data[] = {1, 1, 1, 1};
|
||||
float output_data[] = {16, 16, 16, 16, 52, 52, 52, 52};
|
||||
|
||||
for (auto fp16_enable : {false, true}) {
|
||||
std::vector<int> input_shape, weight_shape, bias_shape, output_shape;
|
||||
auto *param = CreateParameter(&input_shape, &weight_shape, &bias_shape, &output_shape, ndim, ci, co, n, h, w, in_n);
|
||||
TestMain({{input_shape, input_data, VAR}, {weight_shape, weight_data, VAR}, {bias_shape, bias_data, CONST_TENSOR}},
|
||||
{output_shape, output_data}, param, fp16_enable);
|
||||
}
|
||||
}
|
||||
} // namespace mindspore::lite::opencl::test
|
||||
|
|
Loading…
Reference in New Issue