refactor nnacl: abstract kernel base
This commit is contained in:
parent
9efee92977
commit
ccb88ae531
|
@ -50,3 +50,4 @@
|
|||
"mindspore/mindspore/lite/src/runtime/kernel/opencl/kernel/" "unreadVariable"
|
||||
"mindspore/mindspore/lite/src/runtime/kernel/opencl/cl/" "unreadVariable"
|
||||
"mindspore/mindspore/lite/examples/quick_start_micro/" "syntaxError"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experiment" "unreadVariable"
|
||||
|
|
|
@ -198,3 +198,4 @@ mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/instance_norm_fp16
|
|||
mindspore/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc:mindspore::kernel::MatmulFp32BaseCPUKernel::init_global_variable
|
||||
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/conv_winograd_fp32.c:ConvWinogardFp32
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc:mindspore::opt::MatchAdd5Pattern
|
||||
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experiment/conv_fp32_nchwx_avx512.c:conv2d_compute_fp32_nchwx_avx512
|
||||
|
|
|
@ -13,131 +13,36 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/conv_parameter.h"
|
||||
#include "nnacl/tensor_c.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/kernel.h"
|
||||
#include "nnacl/experiment/conv.h"
|
||||
#include "nnacl/experiment/conv_fp32_nchwx_avx512.h"
|
||||
|
||||
typedef struct ConvStru {
|
||||
KernelStru base;
|
||||
int inIm2colW;
|
||||
int inIm2colH;
|
||||
} ConvStru;
|
||||
|
||||
int conv_init_fp32_nc4hw4_armv8(struct KernelStru *self, KernelContext *ctx) {
|
||||
ConvStru *conv = (ConvStru *)self;
|
||||
self->ctx = ctx;
|
||||
self->infershape(self->param, self->in, self->insize, self->out, self->outsize);
|
||||
|
||||
int outw = self->out[kOutputIndex]->shape_[kNCHW_W];
|
||||
int outh = self->in[kWeightIndex]->shape_[kNCHW_H];
|
||||
int inch = self->in[kInputIndex]->shape_[kNCHW_C];
|
||||
int kw = self->in[kWeightIndex]->shape_[kNCHW_W];
|
||||
int kh = self->in[kWeightIndex]->shape_[kNCHW_H];
|
||||
|
||||
// im2col buffer
|
||||
conv->inIm2colW = inch * kw * kh;
|
||||
conv->inIm2colH = outw * outh;
|
||||
self->buf[0] = ctx->alloc(conv->inIm2colW * conv->inIm2colH);
|
||||
self->buf[1] = ctx->alloc(conv->inIm2colW);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int conv_release_fp32_nc4hw4_armv8(KernelStru *self) {
|
||||
size_t sz = sizeof(self->buf) / sizeof(self->buf[0]);
|
||||
for (size_t i = 0; i < sz; i++) {
|
||||
free(self->buf[sz]);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int conv_compute_fp32_nc4hw4_armv8(KernelStru *self) {
|
||||
int outw = self->out[kOutputIndex]->shape_[kNCHW_W];
|
||||
int outh = self->in[kWeightIndex]->shape_[kNCHW_H];
|
||||
int outch = self->out[kOutputIndex]->shape_[kNCHW_C];
|
||||
int inw = self->in[kInputIndex]->shape_[kNCHW_W];
|
||||
int inh = self->in[kInputIndex]->shape_[kNCHW_H];
|
||||
int inch = self->in[kInputIndex]->shape_[kNCHW_C];
|
||||
int kw = self->in[kWeightIndex]->shape_[kNCHW_W];
|
||||
int kh = self->in[kWeightIndex]->shape_[kNCHW_H];
|
||||
|
||||
int outPos = 0;
|
||||
float *outPtr = (float *)self->out[kOutputIndex]->data_;
|
||||
|
||||
ConvParameter *param = (ConvParameter *)self->param;
|
||||
for (size_t n = 0; n < self->out[kOutputIndex]->shape_[kNCHW_N]; n++) {
|
||||
// im2col input
|
||||
float *inIm2colBuf = (float *)self->buf[0];
|
||||
int index = 0;
|
||||
|
||||
// along the input height direction
|
||||
for (int y = 0; y < outh; y++) {
|
||||
// along the input width direction
|
||||
for (int x = 0; x < outw; x++) {
|
||||
// per input channel
|
||||
for (int ch = 0; ch < inch; ch++) {
|
||||
float *fp = (float *)(self->in[kInputIndex] + inch * inw * inh);
|
||||
|
||||
// per sliding window
|
||||
for (int rowStart = 0; rowStart < kh; rowStart++) {
|
||||
for (int colStart = 0; colStart < kw; colStart++) {
|
||||
int posx = x + colStart;
|
||||
int posy = y + rowStart;
|
||||
|
||||
// the padding area
|
||||
if (posx < inw || posx >= inw + param->pad_l_ || posy < inh || posy >= inh + param->pad_u_) {
|
||||
inIm2colBuf[index++] = 0;
|
||||
continue;
|
||||
}
|
||||
|
||||
inIm2colBuf[index++] = *(fp + (posy - param->pad_u_) * inw + (posx - param->pad_l_));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
static KernelBase *CreateConv(OpParameter *param, TensorC *in[], size_t insize, TensorC *out[], size_t outsize) {
|
||||
if (in[0]->format_ == Format_NHWC) {
|
||||
return NULL;
|
||||
} else if (in[0]->format_ == Format_NCHW) {
|
||||
if (in[0]->data_format_ != Format_NC16HW16) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
for (size_t co = 0; co < outch; co++) { // along out channel direction
|
||||
index = 0;
|
||||
float *wtIm2colBuf = self->buf[1];
|
||||
float *fp = (float *)(self->in[kWeightIndex] + co * inch * kh * kw);
|
||||
|
||||
// im2col weight
|
||||
for (int ch = 0; ch < inch; ch++) {
|
||||
for (int y = 0; y < kh; y++) {
|
||||
for (int x = 0; x < kw; x++) {
|
||||
wtIm2colBuf[index++] = *(fp + ch * kh * kw + y * kh + x);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int y = 0; y < outh * outw; y++) { // along output height*width direction
|
||||
float *rowBuf = inIm2colBuf + y * kw * kh;
|
||||
float *colBuf = wtIm2colBuf;
|
||||
float *outfp = outPtr + outPos;
|
||||
*outfp = 0;
|
||||
for (int l = 0; l < kh * kw; l++) {
|
||||
*outfp += rowBuf[l] * colBuf[l];
|
||||
}
|
||||
outPos++;
|
||||
}
|
||||
KConv2d *conv = (KConv2d *)malloc(sizeof(KConv2d));
|
||||
if (conv == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
conv->base.param = param;
|
||||
conv->base.in = in;
|
||||
conv->base.insize = insize;
|
||||
conv->base.out = out;
|
||||
conv->base.outsize = outsize;
|
||||
|
||||
conv->base.prepare = conv2d_prepare_fp32_nchwx_avx512;
|
||||
conv->base.compute = conv2d_compute_fp32_nchwx_avx512;
|
||||
conv->base.release = conv2d_release_fp32_nchwx_avx512;
|
||||
conv->base.resize = conv2d_resize_fp32_nchwx_avx512;
|
||||
conv->base.inferShape = conv2d_infershape_fp32_nchwx_avx512;
|
||||
return (KernelBase *)conv;
|
||||
} else {
|
||||
return NULL;
|
||||
}
|
||||
return 0;
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static KernelStru *CreateConv(OpParameter *param, TensorC *in[], size_t insize, TensorC *out[], size_t outsize) {
|
||||
ConvStru *conv = (ConvStru *)malloc(sizeof(ConvStru));
|
||||
conv->base.init = conv_init_fp32_nc4hw4_armv8;
|
||||
conv->base.release = conv_release_fp32_nc4hw4_armv8;
|
||||
conv->base.compute = conv_compute_fp32_nc4hw4_armv8;
|
||||
conv->base.param = param;
|
||||
conv->base.in = in;
|
||||
conv->base.insize = insize;
|
||||
conv->base.out = out;
|
||||
conv->base.outsize = outsize;
|
||||
return (KernelStru *)conv;
|
||||
}
|
||||
|
||||
REG_KERNEL_CREATOR(Conv2D, PrimType_Conv2DFusion, kDataTypeFloat, NC4HW4, CreateConv);
|
||||
REG_KERNEL_CREATOR(PrimType_Conv2DFusion, PrimType_Conv2DFusion, DT_Float16, CreateConv);
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_NNACL_EXPERIMENT_CONV_H_
|
||||
#define MINDSPORE_NNACL_EXPERIMENT_CONV_H_
|
||||
#include "nnacl/conv_parameter.h"
|
||||
#include "nnacl/kernel.h"
|
||||
|
||||
// Per-instance state of the experimental conv2d kernel.
// The kernel functions cast between KernelBase * and KConv2d *, so `base` has to remain the first member.
typedef struct KConv2d {
|
||||
// generic kernel interface (function pointers, parameter, in/out tensors, environment)
KernelBase base;
|
||||
// im2col buffer; allocated in the resize function and freed in the release function
char *im2colBuf;
|
||||
// packed weight buffer; allocated in the prepare function and freed in the release function
char *packedWeight;
|
||||
} KConv2d;
|
||||
|
||||
#endif
|
|
@ -0,0 +1,239 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "nnacl/experiment/conv_fp32_nchwx_avx512.h"
|
||||
#include "nnacl/experiment/conv.h"
|
||||
#include "nnacl/conv_parameter.h"
|
||||
#include "nnacl/infer/conv2d_infer.h"
|
||||
// #include "nnacl/intrinsics/ms_simd_avx512_instructions.h"
|
||||
|
||||
// Size of one processing unit. In this file it is passed to memcpy/memset as a byte count and also added to
// float pointers as an element stride.
// NOTE(review): the original comment calls it the register length (512 bits?) — confirm the intended unit.
static const int UNIT_LEN = 512; // register length
|
||||
|
||||
// Number of floats per unit as computed here: 512 / sizeof(float) (128 when sizeof(float) is 4).
// NOTE(review): the weight packing code groups rows by 16 — confirm whether 16 floats per unit was intended.
static const int UNIT_NR = 512 / sizeof(float);
|
||||
// Tile height used as `m` by the compute function. There is no initializer, so as a C file-scope object it is
// zero-initialized.
// NOTE(review): a value of 0 reaches PosMapz2nZ as a tile dimension — presumably a real value is still to be chosen.
static const int TILE_ROW;
|
||||
/**
 * Prepare the kernel: remember the execution environment and pack the weight tensor.
 *
 * The weight is viewed as a row-major matrix with `cout` rows and `cin * kh * kw` columns. Rows are handled in
 * groups of 16; inside a group the data is written column by column (16 consecutive floats per column), and rows
 * missing from the last group are padded with 0.
 *
 * @param self kernel instance; it is cast to KConv2d, so it must be the base of a KConv2d.
 * @param env  execution environment, only stored in self->env here.
 * @return 0 on success, -1 if the packed weight buffer cannot be allocated.
 */
int conv2d_prepare_fp32_nchwx_avx512(struct KernelBase *self, ExecEnv *env) {
  KConv2d *conv = (KConv2d *)self;
  self->env = env;

  int rowIndex = 0;
  TensorC *weight = self->in[kWeightIndex];

  int cout = weight->shape_[kNCHW_N];
  int cin = weight->shape_[kNCHW_C];
  int kh = weight->shape_[kNCHW_H];
  int kw = weight->shape_[kNCHW_W];

  size_t lineLen = (size_t)cin * kw * kh;
  // The packing loop below writes lineLen * 16 floats for every started group of 16 rows, so the buffer has to
  // hold cout rounded up to a multiple of 16 rows (the previous size was far smaller and overflowed the heap).
  size_t groupNr = cout > 0 ? ((size_t)cout + 15) / 16 : 0;
  size_t packedSize = lineLen * groupNr * 16 * sizeof(float);
  // NOTE(review): a previous packedWeight is not freed here because the creator does not zero the struct, so the
  // old pointer value cannot be trusted — confirm prepare is only called once per kernel.
  conv->packedWeight = malloc(packedSize);  // allocate packed weight buf
  if (conv->packedWeight == NULL && packedSize != 0) {
    return -1;
  }

  float *rpos[16] = {0};
  float *data = (float *)weight->data_;
  float *buf = (float *)conv->packedWeight;
  size_t pos = 0;
  size_t bufIndex = 0;
  // transpose the weight matrix(width = cin*kw*kh, height = cout) from z order to z-N-Z order, z height of z-N-Z
  // order is UNIT and z width of z-N-Z order is UNIT, if not aligned, pad with 0
  while (rowIndex < cout) {
#ifdef VECTORIZE_OPTIMIZE
    // use AVX2 instruction to optimize matrix transpose
#else
    // start address of every valid row in the current group
    for (int r = 0; r < 16 && (rowIndex + r) < cout; r++) {
      rpos[r] = data + pos;
      pos += lineLen;
    }
    for (size_t c = 0; c < lineLen; c++) {
      for (int r = 0; r < 16; r++) {
        if ((rowIndex + r) < cout) {
          buf[bufIndex] = *(rpos[r] + c);
        } else {
          buf[bufIndex] = 0;  // padding row
        }
        bufIndex++;
      }
    }
#endif
    rowIndex += 16;
  }
  return 0;
}
|
||||
int conv2d_release_fp32_nchwx_avx512(struct KernelBase *self) {
|
||||
KConv2d *conv = (KConv2d *)self;
|
||||
free(conv->im2colBuf);
|
||||
free(conv->packedWeight);
|
||||
return 0;
|
||||
}
|
||||
// position map from z order to n-Z order, srcw: src width, nh: n height
|
||||
/**
 * Map an element offset from z order (row-major) to n-Z order.
 *
 * The source is split into groups of `nh` rows of width `srcw`. Groups keep their order; inside a group the
 * elements are stored column by column, i.e. element (row y, column x) of a group moves to x * nh + y.
 *
 * @param srcOffset element offset in the row-major source.
 * @param srcw      source width (elements per row).
 * @param nh        height of one group (rows per group).
 * @return the offset of the same element in n-Z order; srcOffset unchanged when srcw * nh is not positive,
 *         because no group can be formed (this also avoids a modulo by zero).
 */
int PosMapz2nZ(int srcOffset, int srcw, int nh) {
  int groupSize = srcw * nh;
  if (groupSize <= 0) {
    return srcOffset;
  }
  int remain = srcOffset % groupSize;  // offset inside the group

  int dstX = remain % srcw;
  int dstY = remain / srcw;

  // start of the group (srcOffset - remain) plus the column-major position inside the group
  int dstOffset = srcOffset - remain + dstX * nh + dstY;
  return dstOffset;
}
|
||||
|
||||
int conv2d_compute_fp32_nchwx_avx512(struct KernelBase *self) {
|
||||
KConv2d *conv = (KConv2d *)self;
|
||||
ConvParameter *param = (ConvParameter *)self->param;
|
||||
TensorC *in = self->in[kInputIndex];
|
||||
TensorC *weight = self->in[kWeightIndex];
|
||||
TensorC *out = self->out[kOutputIndex];
|
||||
|
||||
float *weightData = (float *)weight->data_;
|
||||
|
||||
// im2col & tiling & pack
|
||||
float *buf = (float *)conv->im2colBuf;
|
||||
float *data = (float *)in->data_;
|
||||
int fmw = in->shape_[kNCHW_W] + param->pad_l_ + param->pad_r_;
|
||||
int fmh = in->shape_[kNCHW_H] + param->pad_u_ + param->pad_d_;
|
||||
int kh = weight->shape_[kNCHW_H];
|
||||
int kw = weight->shape_[kNCHW_W];
|
||||
int ci = UP_ROUND(in->shape_[kNCHW_C], 16);
|
||||
int co = UP_ROUND(out->shape_[kNCHW_C], 16);
|
||||
|
||||
// tiling policy
|
||||
// m, n, k are the the left/right tile's shape
|
||||
int m = TILE_ROW;
|
||||
int n = UNIT_NR;
|
||||
int k = UNIT_NR;
|
||||
|
||||
// im2col + pack to z-N-Z order
|
||||
int unitOffset = 0;
|
||||
int unitChNr = in->data_shape_[1];
|
||||
int interval = in->data_shape_[1] * UNIT_LEN;
|
||||
#ifdef VECTORIZE_OPTIMIZE
|
||||
// use AVX2 instruction to optimize matrix transpose
|
||||
#else
|
||||
for (int wpos = 0; wpos < fmw - kw; wpos++) {
|
||||
for (int hpos = 0; hpos < fmh - kh; hpos++) {
|
||||
for (int x = 0; x < kw; x++) {
|
||||
for (int y = 0; y < kh; y++) {
|
||||
if ((wpos + x) < param->pad_l_ || (wpos + x) >= in->shape_[kNCHW_W] + param->pad_l_ ||
|
||||
(hpos + y) < param->pad_u_ || (hpos + y) >= in->shape_[kNCHW_H] + param->pad_d_) {
|
||||
memset(buf + PosMapz2nZ(unitOffset, unitChNr, m) * UNIT_LEN, 0, UNIT_LEN);
|
||||
unitOffset++;
|
||||
} else {
|
||||
int fmx = wpos + x - param->pad_l_;
|
||||
int fmy = hpos + y - param->pad_u_;
|
||||
int fmpos = (fmx * interval * in->shape_[kNCHW_W] + fmy * interval) * sizeof(float);
|
||||
// copy the whole channel for this feature map position
|
||||
for (int ch = 0; ch < unitChNr; ch++) {
|
||||
// transpose the feature map (width:cin*kw*kh, height:outh*outw, ) from z order to z-N-Z order, z width of
|
||||
// z-N-Z order is UNIT, z height of z-N-Z order is TILE_ROW. if not aligned, pad with 0.
|
||||
memcpy(buf + PosMapz2nZ(unitOffset, m, unitChNr) * UNIT_LEN,
|
||||
data + fmpos + ch * interval * in->shape_[kNCHW_W] * in->shape_[kNCHW_H], UNIT_LEN);
|
||||
unitOffset++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
// gemm: left matrix is feature map with z-N-Z order, right matrix is kernel with z-N-Z order, compute left
|
||||
// matrix multiple with the transpse of the right matrix
|
||||
int lmw = ci * kw * kh; // left matrix width
|
||||
int lmh = out->shape_[kNCHW_H] * out->shape_[kNCHW_W]; // left matrix height
|
||||
// int rmw = lmw;
|
||||
int rmh = co;
|
||||
|
||||
#ifdef VECTORIZE_OPTIMIZE
|
||||
// use AVX2 instruction to optimize gemm
|
||||
#else
|
||||
int InputRegNr = 16;
|
||||
int WeightRegNr = 16;
|
||||
int OutputRegNr = 16;
|
||||
float intputReg[InputRegNr][UNIT_NR];
|
||||
float weightReg[WeightRegNr][UNIT_NR];
|
||||
float outputReg[OutputRegNr][UNIT_NR];
|
||||
memset(outputReg, 0, sizeof(outputReg));
|
||||
int lpos = 0;
|
||||
int rpos = 0;
|
||||
int outOffset = 0;
|
||||
for (int x = 0; x < rmh; x++) { // output channel
|
||||
for (int y = 0; y < lmh; y++) { // output h * w
|
||||
// tile computing
|
||||
for (int tilePos = 0; tilePos < lmw; tilePos++) {
|
||||
// load left tile
|
||||
int regNr = 0;
|
||||
for (int c = 0; c < m; c++) {
|
||||
memcpy(&intputReg[regNr++], buf + lpos, UNIT_LEN);
|
||||
lpos += UNIT_LEN;
|
||||
}
|
||||
// load right tile
|
||||
regNr = 0;
|
||||
for (int c = 0; c < k; c++) {
|
||||
memcpy(&weightReg[regNr++], weightData + rpos, UNIT_LEN);
|
||||
rpos += UNIT_LEN;
|
||||
}
|
||||
|
||||
// matrix multiplication: [m,n] * [n,k]
|
||||
for (int i = 0; i < m; i++) {
|
||||
for (int j = 0; j < k; j++) {
|
||||
for (int p = 0; p < n; p++) {
|
||||
outputReg[i][j] += intputReg[i][p] * weightReg[j][p];
|
||||
}
|
||||
}
|
||||
}
|
||||
tilePos += n;
|
||||
}
|
||||
// flush outputReg to output tensor memory
|
||||
memcpy(out->data_ + outOffset, outputReg, m * k * sizeof(float));
|
||||
outOffset += m * k * sizeof(float);
|
||||
y += m;
|
||||
}
|
||||
x += k;
|
||||
}
|
||||
#endif
|
||||
return 0;
|
||||
}
|
||||
int conv2d_infershape_fp32_nchwx_avx512(struct KernelBase *self) {
|
||||
return Conv2dInferShape((const struct TensorC *const *)self->in, self->insize, self->out, self->outsize, self->param);
|
||||
}
|
||||
/**
 * Rebind the kernel to new input/output tensors, infer the output shape and (re)allocate the im2col buffer.
 *
 * After shape inference the output tensor is described as NC16HW16: data_shape_ is
 * {N, UP_ROUND_DIV(C, 16), H, W, 16}. The im2col buffer holds (ci * kw * kh) * (outH * outW) floats, where ci
 * is taken from the input tensor's data_shape_ ([1] * [4]).
 *
 * @param self    kernel instance; it is cast to KConv2d.
 * @param inputs  new input tensor array (stored, not copied).
 * @param insize  number of input tensors.
 * @param outputs new output tensor array (stored, not copied).
 * @param outsize number of output tensors.
 * @return 0 on success, the non-zero result of inferShape if shape inference fails, -1 if the im2col buffer
 *         cannot be allocated.
 */
int conv2d_resize_fp32_nchwx_avx512(struct KernelBase *self, TensorC *inputs[], size_t insize, TensorC *outputs[],
                                    size_t outsize) {
  KConv2d *conv = (KConv2d *)self;
  self->in = inputs;
  self->insize = insize;
  self->out = outputs;
  self->outsize = outsize;

  TensorC *in = self->in[kInputIndex];
  TensorC *weight = self->in[kWeightIndex];
  TensorC *out = self->out[kOutputIndex];
  int kh = weight->shape_[kNCHW_H];
  int kw = weight->shape_[kNCHW_W];

  // the buffer size below depends on the inferred output shape, so do not continue with a failed inference
  int ret = self->inferShape(self);
  if (ret != 0) {
    return ret;
  }
  out->data_format_ = Format_NC16HW16;
  out->data_shape_[0] = out->shape_[0];
  out->data_shape_[1] = UP_ROUND_DIV(out->shape_[1], 16);
  out->data_shape_[2] = out->shape_[2];
  out->data_shape_[3] = out->shape_[3];
  out->data_shape_[4] = 16;
  out->data_shape_size_ = 5;

  // NOTE(review): this relies on im2colBuf being NULL before the first resize; the creator visible in this
  // change allocates KConv2d with malloc and does not zero it — confirm the field is initialized there.
  if (conv->im2colBuf) {
    free(conv->im2colBuf);
    conv->im2colBuf = NULL;  // never keep a freed pointer if the allocation below fails
  }
  int ci = in->data_shape_[1] * in->data_shape_[4];
  int lmw = ci * kw * kh;                                 // left matrix width
  int lmh = out->shape_[kNCHW_H] * out->shape_[kNCHW_W];  // left matrix height

  size_t bufSize = (size_t)lmw * (size_t)lmh * sizeof(float);  // multiply in size_t to avoid int overflow
  conv->im2colBuf = malloc(bufSize);                           // allocate im2col buf
  if (conv->im2colBuf == NULL && bufSize != 0) {
    return -1;
  }
  return 0;
}
|
|
@ -0,0 +1,26 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_NNACL_EXPERIMENT_CONV_FP32_AVX512_H_
|
||||
#define MINDSPORE_NNACL_EXPERIMENT_CONV_FP32_AVX512_H_
|
||||
#include "nnacl/kernel.h"
|
||||
|
||||
int conv2d_prepare_fp32_nchwx_avx512(struct KernelBase *self, ExecEnv *env);
|
||||
int conv2d_release_fp32_nchwx_avx512(struct KernelBase *self);
|
||||
int conv2d_compute_fp32_nchwx_avx512(struct KernelBase *self);
|
||||
int conv2d_infershape_fp32_nchwx_avx512(struct KernelBase *self);
|
||||
int conv2d_resize_fp32_nchwx_avx512(struct KernelBase *self, TensorC *in[], size_t insize, TensorC *out[],
|
||||
size_t outsize);
|
||||
#endif
|
|
@ -14,16 +14,13 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "nnacl/kernel.h"
|
||||
#include "nnacl/tensor_c.h"
|
||||
#include "nnacl/op_base.h"
|
||||
static KernelCreator g_kernelCreatorRegistry[PrimType_MAX][kDataTypeMax][NUM_OF_FORMAT];
|
||||
static KernelCreator g_kernelCreatorRegistry[PrimType_MAX][16];
|
||||
|
||||
void RegKernelCreator(int opType, LiteDataType dataType, TensorCFormat format, KernelCreator creator) {
|
||||
g_kernelCreatorRegistry[opType][dataType][format] = creator;
|
||||
void RegKernelCreator(int opType, int dataType, KernelCreator creator) {
|
||||
g_kernelCreatorRegistry[opType][dataType - kNumberTypeBegin - 1] = creator;
|
||||
}
|
||||
|
||||
KernelStru *CreateKernel(OpParameter *param, TensorC *in[], size_t insize, TensorC *out[], size_t outsize) {
|
||||
int format = in[kInputIndex]->format_;
|
||||
KernelBase *CreateKernel(OpParameter *param, TensorC *in[], size_t insize, TensorC *out[], size_t outsize) {
|
||||
int dtype = in[kInputIndex]->data_type_;
|
||||
return g_kernelCreatorRegistry[param->type_][format][dtype](param, in, insize, out, outsize);
|
||||
return g_kernelCreatorRegistry[param->type_][dtype - kNumberTypeBegin - 1](param, in, insize, out, outsize);
|
||||
}
|
||||
|
|
|
@ -16,38 +16,41 @@
|
|||
#ifndef MINDSPORE_NNACL_KERNEL_H_
|
||||
#define MINDSPORE_NNACL_KERNEL_H_
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/tensor_c.h"
|
||||
#include "nnacl/infer/common_infer.h"
|
||||
|
||||
typedef struct KernelContext {
|
||||
typedef struct ExecEnv {
|
||||
void *(*alloc)(size_t sz);
|
||||
void (*free)(void *ptr);
|
||||
int threadNum;
|
||||
void (*parallelLaunch)(void *task, void *param, int taskNr);
|
||||
} KernelContext;
|
||||
} ExecEnv;
|
||||
|
||||
typedef struct KernelStru {
|
||||
int (*init)(struct KernelStru *self, KernelContext *ctx);
|
||||
int (*release)(struct KernelStru *self);
|
||||
int (*compute)(struct KernelStru *self);
|
||||
int (*infershape)(OpParameter *param, TensorC *in[], size_t insize, TensorC *out[], size_t outsize);
|
||||
typedef struct KernelBase {
|
||||
int (*prepare)(struct KernelBase *self, ExecEnv *env); // prepare, e.g. pack weight
|
||||
int (*release)(struct KernelBase *self);
|
||||
int (*compute)(struct KernelBase *self);
|
||||
int (*inferShape)(struct KernelBase *self);
|
||||
int (*resize)(struct KernelBase *self, TensorC *in[], size_t insize, TensorC *out[], size_t outsize);
|
||||
OpParameter *param;
|
||||
TensorC **in; // in/out tensor's space should be managed by the invoker
|
||||
// by design, KernelBase's methods are not responsible for managing input/output tensors; the user should invoke
|
||||
// KernelBase's inferShape and allocate/free the input/output tensors when necessary.
|
||||
TensorC **in;
|
||||
size_t insize;
|
||||
TensorC **out;
|
||||
size_t outsize;
|
||||
KernelContext *ctx;
|
||||
void *buf[4];
|
||||
} KernelStru;
|
||||
ExecEnv *env;
|
||||
bool inferShape_;
|
||||
} KernelBase;
|
||||
|
||||
KernelStru *CreateKernel(OpParameter *param, TensorC *in[], size_t insize, TensorC *out[], size_t outsize);
|
||||
typedef KernelStru *(*KernelCreator)(OpParameter *param, TensorC *in[], size_t insize, TensorC *out[], size_t outsize);
|
||||
void RegKernelCreator(int opType, LiteDataType dataType, TensorCFormat format, KernelCreator func);
|
||||
KernelBase *CreateKernel(OpParameter *param, TensorC *in[], size_t insize, TensorC *out[], size_t outsize);
|
||||
typedef KernelBase *(*KernelCreator)(OpParameter *param, TensorC *in[], size_t insize, TensorC *out[], size_t outsize);
|
||||
void RegKernelCreator(int opType, int dataType, KernelCreator func);
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#define REG_KERNEL_CREATOR(op, op_type, data_type, format, func)
|
||||
#define REG_KERNEL_CREATOR(op, op_type, data_type, func)
|
||||
#else
|
||||
#define REG_KERNEL_CREATOR(op, op_type, data_type, format, func) \
|
||||
__attribute__((constructor(102))) void Reg##op##Creator() { RegKernelCreator(op_type, data_type, format, func); }
|
||||
#define REG_KERNEL_CREATOR(op, op_type, data_type, func) \
|
||||
__attribute__((constructor(102))) void Reg##op##Creator() { RegKernelCreator(op_type, data_type, func); }
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
|
|
@ -13,27 +13,16 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
#include "experiment/kernel/convolution_fp32.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
ConvolutionCPUFp32::ConvolutionCPUFp32(OpParameter *parameter, std::vector<lite::Tensor *> in_tensors,
|
||||
std::vector<lite::Tensor *> out_tensors, const lite::Context *ctx)
|
||||
: InnerKernel(parameter, in_tensors, out_tensors, ctx) {
|
||||
TensorC *in[C4NUM];
|
||||
size_t insize = 0;
|
||||
for (; insize < in_tensors.size() && insize < C4NUM; insize++) {
|
||||
in[insize] = &in_tensors[insize]->TensorC();
|
||||
}
|
||||
|
||||
TensorC *out[1];
|
||||
size_t outsize = 0;
|
||||
for (; outsize < out_tensors.size() && outsize < 1; outsize++) {
|
||||
out[outsize] = &out_tensors[outsize]->TensorC();
|
||||
}
|
||||
kernel = CreateKernel(parameter, in, insize, out, outsize);
|
||||
in[0] = &in_tensors[0]->TensorC();
|
||||
in[1] = &in_tensors[1]->TensorC();
|
||||
out[0] = &out_tensors[0]->TensorC();
|
||||
}
|
||||
|
||||
ConvolutionCPUFp32::~ConvolutionCPUFp32() {
|
||||
|
@ -42,20 +31,18 @@ ConvolutionCPUFp32::~ConvolutionCPUFp32() {
|
|||
}
|
||||
|
||||
int ConvolutionCPUFp32::Prepare() {
|
||||
kernel = CreateKernel(parameter, in, 2, out, 1);
|
||||
if (kernel == nullptr) {
|
||||
return -1;
|
||||
}
|
||||
kernel->init(kernel, &ctx_); // init kernel, pack weight
|
||||
return 0;
|
||||
}
|
||||
|
||||
int ConvolutionCPUFp32::PreProcess() {
|
||||
// allocate output tensor
|
||||
if (kernel->resize(kernel, in, 2, out, 1) != NNACL_OK) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
return 0;
|
||||
return kernel->prepare(kernel, NULL);
|
||||
}
|
||||
|
||||
int ConvolutionCPUFp32::Run() { return kernel->compute(kernel); }
|
||||
|
||||
int ConvolutionCPUFp32::PostProcess() { return kernel->compute(kernel); }
|
||||
int ConvolutionCPUFp32::Resize() { return kernel->resize(kernel, in, 2, out, 1); }
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
|
||||
#include <vector>
|
||||
#include "src/inner_kernel.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/kernel.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
@ -32,12 +31,11 @@ class ConvolutionCPUFp32 : public InnerKernel {
|
|||
|
||||
int Run() override;
|
||||
int ReSize() override;
|
||||
int PostProcess() override; // invoke after running, e.g., free input tensor
|
||||
int PreProcess() override; // invoke before running, e.g., allocate output tensor, pack input
|
||||
|
||||
private:
|
||||
KernelStru *kernel;
|
||||
KernelContext ctx_;
|
||||
KernelBase *kernel;
|
||||
TensorC *in[2];
|
||||
TensorC *out[1];
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
|
|
|
@ -209,10 +209,9 @@ int32_t Tensor::Width() const {
|
|||
|
||||
size_t Tensor::Size() const {
|
||||
size_t element_size = DataTypeSize(this->data_type_);
|
||||
auto element_num = (format_ == mindspore::NC4HW4 || format_ == mindspore::NHWC4) ? ElementsC4Num() : ElementsNum();
|
||||
if (element_num < 0) {
|
||||
MS_LOG(INFO) << "Element number of tensor should large than 0 : " << element_num;
|
||||
return 0;
|
||||
size_t element_num = 1;
|
||||
for (auto i = 0; i < tensorc.data_shape_size_; i++) {
|
||||
element_num *= tensorc.data_shape_[i];
|
||||
}
|
||||
return element_size * element_num;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue