From ccb88ae531b50f64d080eaba94eed2c322c22510 Mon Sep 17 00:00:00 2001
From: zhaizhiqiang
Date: Wed, 23 Mar 2022 09:35:34 +0800
Subject: [PATCH] refactor nnacl: abstract kernel base

---
 .jenkins/check/config/filter_cppcheck.txt | 1 +
 .jenkins/check/config/whitelizard.txt | 1 +
 .../device/cpu/kernel/nnacl/experiment/conv.c | 149 ++---------
 .../device/cpu/kernel/nnacl/experiment/conv.h | 27 ++
 .../nnacl/experiment/conv_fp32_nchwx_avx512.c | 239 ++++++++++++++++++
 .../nnacl/experiment/conv_fp32_nchwx_avx512.h | 26 ++
 .../plugin/device/cpu/kernel/nnacl/kernel.c | 13 +-
 .../plugin/device/cpu/kernel/nnacl/kernel.h | 39 +--
 .../experiment/kernel/convolution_fp32.cc | 31 +--
 .../lite/experiment/kernel/convolution_fp32.h | 8 +-
 mindspore/lite/experiment/src/tensor.cc | 7 +-
 11 files changed, 362 insertions(+), 179 deletions(-)
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experiment/conv.h
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experiment/conv_fp32_nchwx_avx512.c
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experiment/conv_fp32_nchwx_avx512.h

diff --git a/.jenkins/check/config/filter_cppcheck.txt b/.jenkins/check/config/filter_cppcheck.txt
index 42a5f0fc303..5b8fc68ffca 100644
--- a/.jenkins/check/config/filter_cppcheck.txt
+++ b/.jenkins/check/config/filter_cppcheck.txt
@@ -50,3 +50,4 @@
 "mindspore/mindspore/lite/src/runtime/kernel/opencl/kernel/" "unreadVariable"
 "mindspore/mindspore/lite/src/runtime/kernel/opencl/cl/" "unreadVariable"
 "mindspore/mindspore/lite/examples/quick_start_micro/" "syntaxError"
+"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experiment" "unreadVariable"
diff --git a/.jenkins/check/config/whitelizard.txt b/.jenkins/check/config/whitelizard.txt
index acefbeac65a..2c90db79d0a 100644
--- a/.jenkins/check/config/whitelizard.txt
+++ b/.jenkins/check/config/whitelizard.txt
@@ -198,3 +198,4 @@ mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/instance_norm_fp16
 mindspore/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc:mindspore::kernel::MatmulFp32BaseCPUKernel::init_global_variable
 mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/conv_winograd_fp32.c:ConvWinogardFp32
 mindspore/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc:mindspore::opt::MatchAdd5Pattern
+mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experiment/conv_fp32_nchwx_avx512.c:conv2d_compute_fp32_nchwx_avx512
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experiment/conv.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experiment/conv.c
index e0a1aa7cee7..6555cdd473c 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experiment/conv.c
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experiment/conv.c
@@ -13,131 +13,36 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
 */
-#include "mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/conv_parameter.h"
-#include "nnacl/tensor_c.h"
-#include "nnacl/op_base.h"
-#include "nnacl/kernel.h"
+#include "nnacl/experiment/conv.h"
+#include "nnacl/experiment/conv_fp32_nchwx_avx512.h"
 
-typedef struct ConvStru {
-  KernelStru base;
-  int inIm2colW;
-  int inIm2colH;
-} ConvStru;
-
-int conv_init_fp32_nc4hw4_armv8(struct KernelStru *self, KernelContext *ctx) {
-  ConvStru *conv = (ConvStru *)self;
-  self->ctx = ctx;
-  self->infershape(self->param, self->in, self->insize, self->out, self->outsize);
-
-  int outw = self->out[kOutputIndex]->shape_[kNCHW_W];
-  int outh = self->in[kWeightIndex]->shape_[kNCHW_H];
-  int inch = self->in[kInputIndex]->shape_[kNCHW_C];
-  int kw = self->in[kWeightIndex]->shape_[kNCHW_W];
-  int kh = self->in[kWeightIndex]->shape_[kNCHW_H];
-
-  // im2col buffer
-  conv->inIm2colW = inch * kw * kh;
-  conv->inIm2colH = outw * outh;
-  self->buf[0] = ctx->alloc(conv->inIm2colW * conv->inIm2colH);
-  self->buf[1] = ctx->alloc(conv->inIm2colW);
-
-  return 0;
-}
-
-int conv_release_fp32_nc4hw4_armv8(KernelStru *self) {
-  size_t sz = sizeof(self->buf) / sizeof(self->buf[0]);
-  for (size_t i = 0; i < sz; i++) {
-    free(self->buf[sz]);
-  }
-  return 0;
-}
-
-int conv_compute_fp32_nc4hw4_armv8(KernelStru *self) {
-  int outw = self->out[kOutputIndex]->shape_[kNCHW_W];
-  int outh = self->in[kWeightIndex]->shape_[kNCHW_H];
-  int outch = self->out[kOutputIndex]->shape_[kNCHW_C];
-  int inw = self->in[kInputIndex]->shape_[kNCHW_W];
-  int inh = self->in[kInputIndex]->shape_[kNCHW_H];
-  int inch = self->in[kInputIndex]->shape_[kNCHW_C];
-  int kw = self->in[kWeightIndex]->shape_[kNCHW_W];
-  int kh = self->in[kWeightIndex]->shape_[kNCHW_H];
-
-  int outPos = 0;
-  float *outPtr = (float *)self->out[kOutputIndex]->data_;
-
-  ConvParameter *param = (ConvParameter *)self->param;
-  for (size_t n = 0; n < self->out[kOutputIndex]->shape_[kNCHW_N]; n++) {
-    // im2col input
-    float *inIm2colBuf = (float *)self->buf[0];
-    int index = 0;
-
-    // along the input height direction
-    for (int y = 0; y < outh; y++) {
-      // along the input width direction
-      for (int x = 0; x < outw; x++) {
-        // per input channel
-        for (int ch = 0; ch < inch; ch++) {
-          float *fp = (float *)(self->in[kInputIndex] + inch * inw * inh);
-
-          // per sliding window
-          for (int rowStart = 0; rowStart < kh; rowStart++) {
-            for (int colStart = 0; colStart < kw; colStart++) {
-              int posx = x + colStart;
-              int posy = y + rowStart;
-
-              // the padding area
-              if (posx < inw || posx >= inw + param->pad_l_ || posy < inh || posy >= inh + param->pad_u_) {
-                inIm2colBuf[index++] = 0;
-                continue;
-              }
-
-              inIm2colBuf[index++] = *(fp + (posy - param->pad_u_) * inw + (posx - param->pad_l_));
-            }
-          }
-        }
-      }
+static KernelBase *CreateConv(OpParameter *param, TensorC *in[], size_t insize, TensorC *out[], size_t outsize) {
+  if (in[0]->format_ == Format_NHWC) {
+    return NULL;
+  } else if (in[0]->format_ == Format_NCHW) {
+    if (in[0]->data_format_ != Format_NC16HW16) {
+      return NULL;
     }
-
-    for (size_t co = 0; co < outch; co++) {  // along out channel direction
-      index = 0;
-      float *wtIm2colBuf = self->buf[1];
-      float *fp = (float *)(self->in[kWeightIndex] + co * inch * kh * kw);
-
-      // im2col weight
-      for (int ch = 0; ch < inch; ch++) {
-        for (int y = 0; y < kh; y++) {
-          for (int x = 0; x < kw; x++) {
-            wtIm2colBuf[index++] = *(fp + ch * kh * kw + y * kh + x);
-          }
-        }
-      }
-
-      for (int y = 0; y < outh * outw; y++) {  // along output height*width direction
-        float *rowBuf = inIm2colBuf + y * kw * kh;
-        float *colBuf = wtIm2colBuf;
-        float *outfp = outPtr + outPos;
-        *outfp = 0;
-        for (int l = 0; l < kh * kw; l++) {
-          *outfp += rowBuf[l] * colBuf[l];
-        }
-        outPos++;
-      }
+    KConv2d *conv = (KConv2d *)malloc(sizeof(KConv2d));
+    if (conv == NULL) {
+      return NULL;
+    }
+    conv->base.param = param;
+    conv->base.in = in;
+    conv->base.insize = insize;
+    conv->base.out = out;
+    conv->base.outsize = outsize;
+    conv->im2colBuf = NULL;  // resize() frees any previous buffer, so these must start out NULL
+    conv->packedWeight = NULL;
+
+    conv->base.prepare = conv2d_prepare_fp32_nchwx_avx512;
+    conv->base.compute = conv2d_compute_fp32_nchwx_avx512;
+    conv->base.release = conv2d_release_fp32_nchwx_avx512;
+    conv->base.resize = conv2d_resize_fp32_nchwx_avx512;
+    conv->base.inferShape = conv2d_infershape_fp32_nchwx_avx512;
+    return (KernelBase *)conv;
+  } else {
+    return NULL;
   }
-  return 0;
+  return NULL;
 }
 
-static KernelStru *CreateConv(OpParameter *param, TensorC *in[], size_t insize, TensorC *out[], size_t outsize) {
-  ConvStru *conv = (ConvStru *)malloc(sizeof(ConvStru));
-  conv->base.init = conv_init_fp32_nc4hw4_armv8;
-  conv->base.release = conv_release_fp32_nc4hw4_armv8;
-  conv->base.compute = conv_compute_fp32_nc4hw4_armv8;
-  conv->base.param = param;
-  conv->base.in = in;
-  conv->base.insize = insize;
-  conv->base.out = out;
-  conv->base.outsize = outsize;
-  return (KernelStru *)conv;
-}
-
-REG_KERNEL_CREATOR(Conv2D, PrimType_Conv2DFusion, kDataTypeFloat, NC4HW4, CreateConv);
+REG_KERNEL_CREATOR(Conv2D, PrimType_Conv2DFusion, kNumberTypeFloat32, CreateConv);
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experiment/conv.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experiment/conv.h
new file mode 100644
index 00000000000..b500f095739
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experiment/conv.h
@@ -0,0 +1,27 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_EXPERIMENT_CONV_H_
+#define MINDSPORE_NNACL_EXPERIMENT_CONV_H_
+#include "nnacl/conv_parameter.h"
+#include "nnacl/kernel.h"
+
+typedef struct KConv2d {
+  KernelBase base;
+  char *im2colBuf;
+  char *packedWeight;
+} KConv2d;
+
+#endif  // MINDSPORE_NNACL_EXPERIMENT_CONV_H_
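Taken together, conv.c and conv.h show the shape of the new abstraction: a creator only fills in a KernelBase function table, and callers drive the kernel exclusively through that table. A minimal sketch of such a caller follows; the error handling and env wiring are illustrative assumptions, not part of this patch.

#include <stdlib.h>
#include "nnacl/kernel.h"

// Sketch: run one op through the abstract KernelBase interface.
// Tensors and the ExecEnv are owned by the caller; creators allocate
// the kernel object with malloc, so free() pairs with release().
int RunKernelOnce(OpParameter *param, TensorC *in[], size_t insize, TensorC *out[], size_t outsize, ExecEnv *env) {
  KernelBase *kernel = CreateKernel(param, in, insize, out, outsize);
  if (kernel == NULL) {
    return -1;  // no creator registered for this (op type, data type)
  }
  int ret = kernel->resize(kernel, in, insize, out, outsize);  // infer shapes, (re)allocate workspace
  if (ret == 0) {
    ret = kernel->prepare(kernel, env);  // one-time work, e.g. weight packing
  }
  if (ret == 0) {
    ret = kernel->compute(kernel);
  }
  kernel->release(kernel);
  free(kernel);
  return ret;
}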
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experiment/conv_fp32_nchwx_avx512.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experiment/conv_fp32_nchwx_avx512.c
new file mode 100644
index 00000000000..ab20bd5f3f9
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experiment/conv_fp32_nchwx_avx512.c
@@ -0,0 +1,239 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl/experiment/conv_fp32_nchwx_avx512.h"
+#include "nnacl/experiment/conv.h"
+#include "nnacl/conv_parameter.h"
+#include "nnacl/infer/conv2d_infer.h"
+// #include "nnacl/intrinsics/ms_simd_avx512_instructions.h"
+
+static const int UNIT_LEN = 64;                       // ZMM register length in bytes (512 bits)
+static const int UNIT_NR = UNIT_LEN / sizeof(float);  // floats per register: 16
+static const int TILE_ROW = 16;                       // rows per output tile
+
+int conv2d_prepare_fp32_nchwx_avx512(struct KernelBase *self, ExecEnv *env) {
+  KConv2d *conv = (KConv2d *)self;
+  self->env = env;
+
+  int rowIndex = 0;
+  TensorC *weight = self->in[kWeightIndex];
+
+  int cout = weight->shape_[kNCHW_N];
+  int cin = weight->shape_[kNCHW_C];
+  int kh = weight->shape_[kNCHW_H];
+  int kw = weight->shape_[kNCHW_W];
+
+  size_t lineLen = cin * kw * kh;
+  // allocate packed weight buf: each band of UNIT_NR output channels stores lineLen * UNIT_NR floats
+  conv->packedWeight = malloc(lineLen * UP_ROUND(cout, UNIT_NR) * sizeof(float));
+  if (conv->packedWeight == NULL) {
+    return -1;
+  }
+
+  float *rpos[16] = {0};
+  float *data = (float *)weight->data_;
+  float *buf = (float *)conv->packedWeight;
+  int pos = 0;
+  int bufIndex = 0;
+  // transpose the weight matrix (width = cin*kw*kh, height = cout) from z order to z-N-Z order; each z tile is
+  // UNIT_NR rows high and one column wide, and rows beyond cout are padded with 0
+  while (rowIndex < cout) {
+#ifdef VECTORIZE_OPTIMIZE
+// use AVX2 instruction to optimize matrix transpose
+#else
+    for (int r = 0; r < 16 && (rowIndex + r) < cout; r++) {
+      rpos[r] = data + pos;
+      pos += lineLen;
+    }
+    for (size_t c = 0; c < lineLen; c++) {
+      for (int r = 0; r < 16; r++) {
+        if ((rowIndex + r) < cout) {
+          buf[bufIndex] = *(rpos[r] + c);
+        } else {
+          buf[bufIndex] = 0;
+        }
+        bufIndex++;
+      }
+    }
+#endif
+    rowIndex += 16;
+  }
+  return 0;
+}
+
+int conv2d_release_fp32_nchwx_avx512(struct KernelBase *self) {
+  KConv2d *conv = (KConv2d *)self;
+  free(conv->im2colBuf);
+  free(conv->packedWeight);
+  return 0;
+}
+
+// position map from z order to n-Z order; srcw: src width (in units), nh: n height
+int PosMapz2nZ(int srcOffset, int srcw, int nh) {
+  int groupSize = srcw * nh;
+  int remain = srcOffset % groupSize;
+
+  int dstX = remain % srcw;
+  int dstY = remain / srcw;
+
+  // group base (srcOffset - remain) plus the transposed offset within the group
+  int dstOffset = (srcOffset - remain) + dstX * nh + dstY;
+  return dstOffset;
+}
+
+int conv2d_compute_fp32_nchwx_avx512(struct KernelBase *self) {
+  KConv2d *conv = (KConv2d *)self;
+  ConvParameter *param = (ConvParameter *)self->param;
+  TensorC *in = self->in[kInputIndex];
+  TensorC *weight = self->in[kWeightIndex];
+  TensorC *out = self->out[kOutputIndex];
+
+  float *weightData = (float *)conv->packedWeight;  // packed into z-N-Z order by prepare()
+
+  // im2col & tiling & pack
+  float *buf = (float *)conv->im2colBuf;
+  float *data = (float *)in->data_;
+  int fmw = in->shape_[kNCHW_W] + param->pad_l_ + param->pad_r_;
+  int fmh = in->shape_[kNCHW_H] + param->pad_u_ + param->pad_d_;
+  int kh = weight->shape_[kNCHW_H];
+  int kw = weight->shape_[kNCHW_W];
+  int ci = UP_ROUND(in->shape_[kNCHW_C], 16);
+  int co = UP_ROUND(out->shape_[kNCHW_C], 16);
+
+  // tiling policy
+  // m, n, k are the left/right tile's shape
+  int m = TILE_ROW;
+  int n = UNIT_NR;
+  int k = UNIT_NR;
+
+  // im2col + pack to z-N-Z order (stride 1 assumed)
+  int unitOffset = 0;
+  int unitChNr = in->data_shape_[1];      // number of 16-channel blocks
+  int rowUnitNr = kw * kh * unitChNr;     // units per left-matrix row
+#ifdef VECTORIZE_OPTIMIZE
+// use AVX2 instruction to optimize matrix transpose
+#else
+  for (int wpos = 0; wpos + kw <= fmw; wpos++) {
+    for (int hpos = 0; hpos + kh <= fmh; hpos++) {
+      for (int x = 0; x < kw; x++) {
+        for (int y = 0; y < kh; y++) {
+          if ((wpos + x) < param->pad_l_ || (wpos + x) >= in->shape_[kNCHW_W] + param->pad_l_ ||
+              (hpos + y) < param->pad_u_ || (hpos + y) >= in->shape_[kNCHW_H] + param->pad_u_) {
+            // the padding area: emit one zero unit per channel block
+            for (int ch = 0; ch < unitChNr; ch++) {
+              memset(buf + PosMapz2nZ(unitOffset, rowUnitNr, m) * UNIT_NR, 0, UNIT_LEN);
+              unitOffset++;
+            }
+          } else {
+            int fmx = wpos + x - param->pad_l_;
+            int fmy = hpos + y - param->pad_u_;
+            // copy the whole channel for this feature map position, one 16-float unit per channel block
+            for (int ch = 0; ch < unitChNr; ch++) {
+              // transpose the feature map (width: cin*kw*kh, height: outh*outw) from z order to z-N-Z order; the z
+              // width of z-N-Z order is UNIT_NR, the z height is TILE_ROW; if not aligned, pad with 0.
+              memcpy(buf + PosMapz2nZ(unitOffset, rowUnitNr, m) * UNIT_NR,
+                     data + ((ch * in->shape_[kNCHW_H] + fmy) * in->shape_[kNCHW_W] + fmx) * UNIT_NR, UNIT_LEN);
+              unitOffset++;
+            }
+          }
+        }
+      }
+    }
+  }
+#endif
+  // gemm: the left matrix is the feature map in z-N-Z order, the right matrix is the packed kernel in z-N-Z order;
+  // compute the left matrix multiplied by the transpose of the right matrix
+  int lmw = ci * kw * kh;                                 // left matrix width
+  int lmh = out->shape_[kNCHW_H] * out->shape_[kNCHW_W];  // left matrix height
+  // int rmw = lmw;
+  int rmh = co;
+
+#ifdef VECTORIZE_OPTIMIZE
+// use AVX2 instruction to optimize gemm
+#else
+  int InputRegNr = 16;
+  int WeightRegNr = 16;
+  int OutputRegNr = 16;
+  float inputReg[InputRegNr][UNIT_NR];
+  float weightReg[WeightRegNr][UNIT_NR];
+  float outputReg[OutputRegNr][UNIT_NR];
+  int outOffset = 0;
+  for (int x = 0; x < rmh; x += k) {    // output channel
+    for (int y = 0; y < lmh; y += m) {  // output h * w
+      memset(outputReg, 0, sizeof(outputReg));  // fresh accumulators for this output tile
+      int lpos = (y / m) * lmw * m;  // this row band's start in the packed feature map
+      int rpos = (x / k) * lmw * k;  // this channel band's start in the packed weight
+      // tile computing
+      for (int tilePos = 0; tilePos < lmw; tilePos += n) {
+        // load left tile
+        int regNr = 0;
+        for (int c = 0; c < m; c++) {
+          memcpy(&inputReg[regNr++], buf + lpos, UNIT_LEN);
+          lpos += UNIT_NR;
+        }
+        // load right tile
+        regNr = 0;
+        for (int c = 0; c < k; c++) {
+          memcpy(&weightReg[regNr++], weightData + rpos, UNIT_LEN);
+          rpos += UNIT_NR;
+        }
+
+        // matrix multiplication: [m,n] * [n,k]
+        for (int i = 0; i < m; i++) {
+          for (int j = 0; j < k; j++) {
+            for (int p = 0; p < n; p++) {
+              outputReg[i][j] += inputReg[i][p] * weightReg[j][p];
+            }
+          }
+        }
+      }
+      // flush outputReg to output tensor memory
+      memcpy((char *)out->data_ + outOffset, outputReg, m * k * sizeof(float));
+      outOffset += m * k * sizeof(float);
+    }
+  }
+#endif
+  return 0;
+}
+
+int conv2d_infershape_fp32_nchwx_avx512(struct KernelBase *self) {
+  return Conv2dInferShape((const struct TensorC *const *)self->in, self->insize, self->out, self->outsize,
+                          self->param);
+}
+
+int conv2d_resize_fp32_nchwx_avx512(struct KernelBase *self, TensorC *inputs[], size_t insize, TensorC *outputs[],
+                                    size_t outsize) {
+  KConv2d *conv = (KConv2d *)self;
+  self->in = inputs;
+  self->insize = insize;
+  self->out = outputs;
+  self->outsize = outsize;
+
+  TensorC *in = self->in[kInputIndex];
+  TensorC *weight = self->in[kWeightIndex];
+  TensorC *out = self->out[kOutputIndex];
+  int kh = weight->shape_[kNCHW_H];
+  int kw = weight->shape_[kNCHW_W];
+
+  self->inferShape(self);
+  out->data_format_ = Format_NC16HW16;
+  out->data_shape_[0] = out->shape_[0];
+  out->data_shape_[1] = UP_ROUND_DIV(out->shape_[1], 16);
+  out->data_shape_[2] = out->shape_[2];
+  out->data_shape_[3] = out->shape_[3];
+  out->data_shape_[4] = 16;
+  out->data_shape_size_ = 5;
+
+  if (conv->im2colBuf) {
+    free(conv->im2colBuf);
+  }
+  int ci = in->data_shape_[1] * in->data_shape_[4];
+  int lmw = ci * kw * kh;                                 // left matrix width
+  int lmh = out->shape_[kNCHW_H] * out->shape_[kNCHW_W];  // left matrix height
+
+  // allocate im2col buf; rows are padded to a full tile
+  conv->im2colBuf = malloc(lmw * UP_ROUND(lmh, TILE_ROW) * sizeof(float));
+  return 0;
+}
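The z-to-n-Z index mapping above is the part most worth sanity-checking in isolation. A small standalone test of the same arithmetic; the srcw/nh values are arbitrary:

#include <assert.h>

// Same mapping as PosMapz2nZ above: rows of a srcw-wide (in units) matrix
// are regrouped into bands of nh rows, and each band is stored column-major.
static int PosMapz2nZRef(int srcOffset, int srcw, int nh) {
  int groupSize = srcw * nh;
  int remain = srcOffset % groupSize;
  int dstX = remain % srcw;
  int dstY = remain / srcw;
  return (srcOffset - remain) + dstX * nh + dstY;
}

int main(void) {
  // srcw = 3, nh = 2: the first band's z order 0,1,2,3,4,5 is stored as
  // 0,3,1,4,2,5, i.e. source offsets map to destinations 0,2,4,1,3,5.
  assert(PosMapz2nZRef(0, 3, 2) == 0);
  assert(PosMapz2nZRef(1, 3, 2) == 2);
  assert(PosMapz2nZRef(2, 3, 2) == 4);
  assert(PosMapz2nZRef(3, 3, 2) == 1);
  assert(PosMapz2nZRef(4, 3, 2) == 3);
  assert(PosMapz2nZRef(5, 3, 2) == 5);
  assert(PosMapz2nZRef(6, 3, 2) == 6);  // the next band starts at its own base
  return 0;
}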
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experiment/conv_fp32_nchwx_avx512.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experiment/conv_fp32_nchwx_avx512.h
new file mode 100644
index 00000000000..d01ce2f7c80
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experiment/conv_fp32_nchwx_avx512.h
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_EXPERIMENT_CONV_FP32_AVX512_H_
+#define MINDSPORE_NNACL_EXPERIMENT_CONV_FP32_AVX512_H_
+#include "nnacl/kernel.h"
+
+int conv2d_prepare_fp32_nchwx_avx512(struct KernelBase *self, ExecEnv *env);
+int conv2d_release_fp32_nchwx_avx512(struct KernelBase *self);
+int conv2d_compute_fp32_nchwx_avx512(struct KernelBase *self);
+int conv2d_infershape_fp32_nchwx_avx512(struct KernelBase *self);
+int conv2d_resize_fp32_nchwx_avx512(struct KernelBase *self, TensorC *in[], size_t insize, TensorC *out[],
+                                    size_t outsize);
+#endif  // MINDSPORE_NNACL_EXPERIMENT_CONV_FP32_AVX512_H_
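For reference, resize() fixes the packed output layout to NC16HW16, so element counts follow the rounded-up channel dimension. A hedged helper mirroring the data_shape_ fields it sets; the function name and example shape are invented for illustration:

#include <stddef.h>

// Element count of a packed NC16HW16 tensor, mirroring the data_shape_
// set in conv2d_resize_fp32_nchwx_avx512. Example: logical NCHW
// {1, 35, 17, 17} gives (35 + 15) / 16 = 3 channel blocks, so the packed
// buffer holds 1 * 3 * 17 * 17 * 16 = 13872 floats (lanes 35..47 are pad).
static size_t PackedElementCount(int n, int c, int h, int w) {
  int cBlocks = (c + 15) / 16;  // UP_ROUND_DIV(c, 16)
  return (size_t)n * cBlocks * h * w * 16;
}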
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel.c
index aef9900f6df..58b7416f2b5 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel.c
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel.c
@@ -14,16 +14,13 @@
  * limitations under the License.
  */
 #include "nnacl/kernel.h"
-#include "nnacl/tensor_c.h"
-#include "nnacl/op_base.h"
 
-static KernelCreator g_kernelCreatorRegistry[PrimType_MAX][kDataTypeMax][NUM_OF_FORMAT];
+static KernelCreator g_kernelCreatorRegistry[PrimType_MAX][16];
 
-void RegKernelCreator(int opType, LiteDataType dataType, TensorCFormat format, KernelCreator creator) {
-  g_kernelCreatorRegistry[opType][dataType][format] = creator;
+void RegKernelCreator(int opType, int dataType, KernelCreator creator) {
+  g_kernelCreatorRegistry[opType][dataType - kNumberTypeBegin - 1] = creator;
 }
 
-KernelStru *CreateKernel(OpParameter *param, TensorC *in[], size_t insize, TensorC *out[], size_t outsize) {
-  int format = in[kInputIndex]->format_;
+KernelBase *CreateKernel(OpParameter *param, TensorC *in[], size_t insize, TensorC *out[], size_t outsize) {
   int dtype = in[kInputIndex]->data_type_;
-  return g_kernelCreatorRegistry[param->type_][format][dtype](param, in, insize, out, outsize);
+  return g_kernelCreatorRegistry[param->type_][dtype - kNumberTypeBegin - 1](param, in, insize, out, outsize);
 }
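The registry is now keyed by (opType, dataType) alone, with formats resolved inside each creator, as CreateConv does above. A sketch of the second-dimension index arithmetic, including the bounds check that CreateKernel itself does not perform; the helper is hypothetical:

// dataType is expected to be a kNumberType* enum value; the registry has
// 16 slots, so only (kNumberTypeBegin, kNumberTypeBegin + 16] is addressable.
static int DataTypeSlot(int dataType) {
  int slot = dataType - kNumberTypeBegin - 1;  // same formula as RegKernelCreator/CreateKernel
  if (slot < 0 || slot >= 16) {
    return -1;  // out of range: treat as "no kernel registered"
  }
  return slot;
}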
+  TensorC **in;
   size_t insize;
   TensorC **out;
   size_t outsize;
-  KernelContext *ctx;
-  void *buf[4];
-} KernelStru;
+  ExecEnv *env;
+  bool inferShape_;
+} KernelBase;
 
-KernelStru *CreateKernel(OpParameter *param, TensorC *in[], size_t insize, TensorC *out[], size_t outsize);
-typedef KernelStru *(*KernelCreator)(OpParameter *param, TensorC *in[], size_t insize, TensorC *out[], size_t outsize);
-void RegKernelCreator(int opType, LiteDataType dataType, TensorCFormat format, KernelCreator func);
+KernelBase *CreateKernel(OpParameter *param, TensorC *in[], size_t insize, TensorC *out[], size_t outsize);
+typedef KernelBase *(*KernelCreator)(OpParameter *param, TensorC *in[], size_t insize, TensorC *out[], size_t outsize);
+void RegKernelCreator(int opType, int dataType, KernelCreator func);
 
 #ifdef _MSC_VER
-#define REG_KERNEL_CREATOR(op, op_type, data_type, format, func)
+#define REG_KERNEL_CREATOR(op, op_type, data_type, func)
 #else
-#define REG_KERNEL_CREATOR(op, op_type, data_type, format, func) \
-  __attribute__((constructor(102))) void Reg##op##Creator() { RegKernelCreator(op_type, data_type, format, func); }
+#define REG_KERNEL_CREATOR(op, op_type, data_type, func) \
+  __attribute__((constructor(102))) void Reg##op##Creator() { RegKernelCreator(op_type, data_type, func); }
 #endif
 #endif
diff --git a/mindspore/lite/experiment/kernel/convolution_fp32.cc b/mindspore/lite/experiment/kernel/convolution_fp32.cc
index 0f556a1502f..d591f122b42 100644
--- a/mindspore/lite/experiment/kernel/convolution_fp32.cc
+++ b/mindspore/lite/experiment/kernel/convolution_fp32.cc
@@ -13,27 +13,16 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 #include "nnacl/op_base.h"
-
 #include "experiment/kernel/convolution_fp32.h"
 namespace mindspore::kernel {
 ConvolutionCPUFp32::ConvolutionCPUFp32(OpParameter *parameter, std::vector<lite::Tensor *> in_tensors,
                                        std::vector<lite::Tensor *> out_tensors, const lite::Context *ctx)
     : InnerKernel(parameter, in_tensors, out_tensors, ctx) {
-  TensorC *in[C4NUM];
-  size_t insize = 0;
-  for (; insize < in_tensors.size() && insize < C4NUM; insize++) {
-    in[insize] = &in_tensors[insize]->TensorC();
-  }
-
-  TensorC *out[1];
-  size_t outsize = 0;
-  for (; outsize < out_tensors.size() && outsize < 1; outsize++) {
-    out[outsize] = &out_tensors[outsize]->TensorC();
-  }
-  kernel = CreateKernel(parameter, in, insize, out, outsize);
+  in[0] = &in_tensors[0]->TensorC();
+  in[1] = &in_tensors[1]->TensorC();
+  out[0] = &out_tensors[0]->TensorC();
 }
 
 ConvolutionCPUFp32::~ConvolutionCPUFp32() {
@@ -42,20 +31,18 @@ ConvolutionCPUFp32::~ConvolutionCPUFp32() {
 }
 
 int ConvolutionCPUFp32::Prepare() {
+  kernel = CreateKernel(parameter, in, 2, out, 1);
   if (kernel == nullptr) {
     return -1;
   }
-  kernel->init(kernel, &ctx_);  // init kernel, pack weight
-  return 0;
-}
 
-int ConvolutionCPUFp32::PreProcess() {
-  // allocate output tensor
+  int ret = kernel->resize(kernel, in, 2, out, 1);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
 
-  return 0;
+  return kernel->prepare(kernel, NULL);
 }
 
 int ConvolutionCPUFp32::Run() { return kernel->compute(kernel); }
-
-int ConvolutionCPUFp32::PostProcess() { return kernel->compute(kernel); }
+int ConvolutionCPUFp32::ReSize() { return kernel->resize(kernel, in, 2, out, 1); }
 }  // namespace mindspore::kernel
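Under this scheme, wiring up another op takes one creator plus one REG_KERNEL_CREATOR line. A hypothetical element-wise Add stub for illustration: KAdd, AddCompute and CreateAdd are invented names, only compute is filled in, and a real kernel must also provide prepare/resize/release/inferShape.

#include <stdlib.h>
#include "nnacl/kernel.h"

typedef struct KAdd {
  KernelBase base;
} KAdd;

static int AddCompute(struct KernelBase *self) {
  TensorC *a = self->in[0];
  TensorC *b = self->in[1];
  TensorC *out = self->out[0];
  size_t nr = 1;
  for (size_t i = 0; i < out->shape_size_; i++) {
    nr *= out->shape_[i];
  }
  float *pa = (float *)a->data_;
  float *pb = (float *)b->data_;
  float *po = (float *)out->data_;
  for (size_t i = 0; i < nr; i++) {
    po[i] = pa[i] + pb[i];
  }
  return 0;
}

static KernelBase *CreateAdd(OpParameter *param, TensorC *in[], size_t insize, TensorC *out[], size_t outsize) {
  KAdd *add = (KAdd *)calloc(1, sizeof(KAdd));  // zeroed, so unset hooks stay NULL
  if (add == NULL) {
    return NULL;
  }
  add->base.param = param;
  add->base.in = in;
  add->base.insize = insize;
  add->base.out = out;
  add->base.outsize = outsize;
  add->base.compute = AddCompute;
  return (KernelBase *)add;
}

REG_KERNEL_CREATOR(Add, PrimType_AddFusion, kNumberTypeFloat32, CreateAdd);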
diff --git a/mindspore/lite/experiment/kernel/convolution_fp32.h b/mindspore/lite/experiment/kernel/convolution_fp32.h
index bc7d2d67c47..046aff26f53 100644
--- a/mindspore/lite/experiment/kernel/convolution_fp32.h
+++ b/mindspore/lite/experiment/kernel/convolution_fp32.h
@@ -19,7 +19,6 @@
 #include <vector>
 #include "src/inner_kernel.h"
-#include "nnacl/op_base.h"
 #include "nnacl/kernel.h"
 
 namespace mindspore::kernel {
@@ -32,12 +31,11 @@ class ConvolutionCPUFp32 : public InnerKernel {
   int Run() override;
   int ReSize() override;
 
-  int PostProcess() override;  // invoke after running, e.g., free input tensor
-  int PreProcess() override;   // invoke before running, e.g., allocate output tensor, pack input
 
  private:
-  KernelStru *kernel;
-  KernelContext ctx_;
+  KernelBase *kernel;
+  TensorC *in[2];
+  TensorC *out[1];
 };
 }  // namespace mindspore::kernel
diff --git a/mindspore/lite/experiment/src/tensor.cc b/mindspore/lite/experiment/src/tensor.cc
index 326e6d128db..d5c2d8c117c 100644
--- a/mindspore/lite/experiment/src/tensor.cc
+++ b/mindspore/lite/experiment/src/tensor.cc
@@ -209,10 +209,9 @@ int32_t Tensor::Width() const {
 size_t Tensor::Size() const {
   size_t element_size = DataTypeSize(this->data_type_);
-  auto element_num = (format_ == mindspore::NC4HW4 || format_ == mindspore::NHWC4) ? ElementsC4Num() : ElementsNum();
-  if (element_num < 0) {
-    MS_LOG(INFO) << "Element number of tensor should large than 0 : " << element_num;
-    return 0;
+  size_t element_num = 1;
+  for (size_t i = 0; i < tensorc.data_shape_size_; i++) {
+    element_num *= tensorc.data_shape_[i];
   }
   return element_size * element_num;
 }
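One consequence of the new Tensor::Size() worth spelling out: the byte size now follows the packed data_shape_, padding included, rather than the logical shape. A worked example with illustrative numbers:

#include <stdio.h>

// Logical NCHW {1, 35, 17, 17} packed as NC16HW16 {1, 3, 17, 17, 16}:
//   logical elements = 1 * 35 * 17 * 17     = 10115 (40460 bytes as fp32)
//   packed elements  = 1 * 3 * 17 * 17 * 16 = 13872 (55488 bytes as fp32)
// Size() now reports the packed figure, which is what the kernel's
// buffers actually occupy.
int main(void) {
  size_t data_shape[] = {1, 3, 17, 17, 16};
  size_t element_num = 1;
  for (size_t i = 0; i < sizeof(data_shape) / sizeof(data_shape[0]); i++) {
    element_num *= data_shape[i];
  }
  printf("%zu elements, %zu bytes\n", element_num, element_num * sizeof(float));
  return 0;
}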