!8952 [MS][LITE][CPU] Add adder op

From: @fuzhiye
Reviewed-by: @hangangqiang, @zhang_xue_tong
Signed-off-by: @zhang_xue_tong
mindspore-ci-bot 2020-11-27 17:01:22 +08:00 committed by Gitee
commit 627e4d0cf3
11 changed files with 1232 additions and 83 deletions


@@ -0,0 +1,633 @@
#ifdef __aarch64__
.text
.align 5
.global AdderFloatNeon64
#ifndef __APPLE__
.type AdderFloatNeon64, %function
#endif
// void AdderFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
// int row, int col, size_t stride)
// x0: a
// x1: b
// x2: c
// x3: bias
// x4: act_type
// x5: depth
// x6: row
// x7: col
// x8: stride
// x9: writeMode
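// Adder kernel: out[r][c] = -(sum over depth of |lhs[r][d] - rhs[d][c]|),
// plus bias[c] when bias (x3) is non-null, then the optional Relu/Relu6 selected by act_type.
// lhs is packed in 12-row tiles, rhs in 4-column tiles (scalar reference: Adder12x4 in adder_fp32.c).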
AdderFloatNeon64:
sub sp, sp, #144
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
stp x19, x20, [sp], #16
ldr x8, [sp]
mov x18, #48 // sizeof(float) * 12
mul x17, x5, x18 // block stride of lhs/rhs: sizeof(float) * 12 * depth
mov x18, #4
mul x8, x8, x18
LoopRowStart:
cmp x6, #4
ble LoopRow4
cmp x6, #8
blt LoopRow8
LoopRow:
mov x14, x1 // reload rhs ptr
mov x13, x7 // reload rhs col
mov x12, x3 // reload bias
LoopCol:
mov x11, x2
mov x10, x0 // reload lhs ptr
mov x19, x5 // reload depth
LoopDepthStart:
ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48
ld1 {v3.4s}, [x14], #16
dup v8.4s, v0.s[0]
fabd v9.4s, v3.4s, v8.4s
dup v10.4s, v0.s[1]
fabd v11.4s, v3.4s, v10.4s
dup v12.4s, v0.s[2]
fabd v13.4s, v3.4s, v12.4s
dup v14.4s, v0.s[3]
fabd v15.4s, v3.4s, v14.4s
dup v16.4s, v1.s[0]
fabd v17.4s, v3.4s, v16.4s
dup v18.4s, v1.s[1]
fabd v19.4s, v3.4s, v18.4s
dup v20.4s, v1.s[2]
fabd v21.4s, v3.4s, v20.4s
dup v22.4s, v1.s[3]
fabd v23.4s, v3.4s, v22.4s
dup v24.4s, v2.s[0]
fabd v25.4s, v3.4s, v24.4s
dup v26.4s, v2.s[1]
fabd v27.4s, v3.4s, v26.4s
dup v28.4s, v2.s[2]
fabd v29.4s, v3.4s, v28.4s
dup v30.4s, v2.s[3]
fabd v31.4s, v3.4s, v30.4s
subs x19, x19, #1
beq Bias
LoopDepth:
ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48
ld1 {v3.4s}, [x14], #16
dup v8.4s, v0.s[0]
fabd v8.4s, v3.4s, v8.4s
fadd v9.4s, v9.4s, v8.4s
dup v10.4s, v0.s[1]
fabd v10.4s, v3.4s, v10.4s
fadd v11.4s, v11.4s, v10.4s
dup v12.4s, v0.s[2]
fabd v12.4s, v3.4s, v12.4s
fadd v13.4s, v13.4s, v12.4s
dup v14.4s, v0.s[3]
fabd v14.4s, v3.4s, v14.4s
fadd v15.4s, v15.4s, v14.4s
dup v16.4s, v1.s[0]
fabd v16.4s, v3.4s, v16.4s
fadd v17.4s, v17.4s, v16.4s
dup v18.4s, v1.s[1]
fabd v18.4s, v3.4s, v18.4s
fadd v19.4s, v19.4s, v18.4s
dup v20.4s, v1.s[2]
fabd v20.4s, v3.4s, v20.4s
fadd v21.4s, v21.4s, v20.4s
dup v22.4s, v1.s[3]
fabd v22.4s, v3.4s, v22.4s
fadd v23.4s, v23.4s, v22.4s
dup v24.4s, v2.s[0]
fabd v24.4s, v3.4s, v24.4s
fadd v25.4s, v25.4s, v24.4s
dup v26.4s, v2.s[1]
fabd v26.4s, v3.4s, v26.4s
fadd v27.4s, v27.4s, v26.4s
dup v28.4s, v2.s[2]
fabd v28.4s, v3.4s, v28.4s
fadd v29.4s, v29.4s, v28.4s
dup v30.4s, v2.s[3]
fabd v30.4s, v3.4s, v30.4s
fadd v31.4s, v31.4s, v30.4s
subs x19, x19, #1
bgt LoopDepth
Bias:
fneg v9.4s, v9.4s
fneg v11.4s, v11.4s
fneg v13.4s, v13.4s
fneg v15.4s, v15.4s
fneg v17.4s, v17.4s
fneg v19.4s, v19.4s
fneg v21.4s, v21.4s
fneg v23.4s, v23.4s
fneg v25.4s, v25.4s
fneg v27.4s, v27.4s
fneg v29.4s, v29.4s
fneg v31.4s, v31.4s
cbz x3, Activation
ld1 {v0.4s}, [x12], #16
fadd v9.4s, v9.4s, v0.4s
fadd v11.4s, v11.4s, v0.4s
fadd v13.4s, v13.4s, v0.4s
fadd v15.4s, v15.4s, v0.4s
fadd v17.4s, v17.4s, v0.4s
fadd v19.4s, v19.4s, v0.4s
fadd v21.4s, v21.4s, v0.4s
fadd v23.4s, v23.4s, v0.4s
fadd v25.4s, v25.4s, v0.4s
fadd v27.4s, v27.4s, v0.4s
fadd v29.4s, v29.4s, v0.4s
fadd v31.4s, v31.4s, v0.4s
Activation:
cmp x4, #3
beq Relu6
cmp x4, #1
beq Relu
b Write
Relu6:
mov w19, #6
dup v2.4s, w19
scvtf v2.4s, v2.4s
fmin v9.4s, v9.4s, v2.4s
fmin v11.4s, v11.4s, v2.4s
fmin v13.4s, v13.4s, v2.4s
fmin v15.4s, v15.4s, v2.4s
fmin v17.4s, v17.4s, v2.4s
fmin v19.4s, v19.4s, v2.4s
fmin v21.4s, v21.4s, v2.4s
fmin v23.4s, v23.4s, v2.4s
fmin v25.4s, v25.4s, v2.4s
fmin v27.4s, v27.4s, v2.4s
fmin v29.4s, v29.4s, v2.4s
fmin v31.4s, v31.4s, v2.4s
Relu:
dup v3.4s, wzr
fmax v9.4s, v9.4s, v3.4s
fmax v11.4s, v11.4s, v3.4s
fmax v13.4s, v13.4s, v3.4s
fmax v15.4s, v15.4s, v3.4s
fmax v17.4s, v17.4s, v3.4s
fmax v19.4s, v19.4s, v3.4s
fmax v21.4s, v21.4s, v3.4s
fmax v23.4s, v23.4s, v3.4s
fmax v25.4s, v25.4s, v3.4s
fmax v27.4s, v27.4s, v3.4s
fmax v29.4s, v29.4s, v3.4s
fmax v31.4s, v31.4s, v3.4s
b Write
LoopRow8:
mov x14, x1 // reload rhs ptr
mov x13, x7 // reload rhs col
mov x12, x3 // reload bias
LoopCol8:
mov x11, x2
mov x10, x0 // reload lhs ptr
mov x19, x5 // reload depth
LoopDepthStart8:
ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48
ld1 {v3.4s}, [x14], #16
dup v8.4s, v0.s[0]
fabd v9.4s, v3.4s, v8.4s
dup v10.4s, v0.s[1]
fabd v11.4s, v3.4s, v10.4s
dup v12.4s, v0.s[2]
fabd v13.4s, v3.4s, v12.4s
dup v14.4s, v0.s[3]
fabd v15.4s, v3.4s, v14.4s
dup v16.4s, v1.s[0]
fabd v17.4s, v3.4s, v16.4s
dup v18.4s, v1.s[1]
fabd v19.4s, v3.4s, v18.4s
dup v20.4s, v1.s[2]
fabd v21.4s, v3.4s, v20.4s
dup v22.4s, v1.s[3]
fabd v23.4s, v3.4s, v22.4s
subs x19, x19, #1
beq Bias8
LoopDepth8:
ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48
ld1 {v3.4s}, [x14], #16
dup v8.4s, v0.s[0]
fabd v8.4s, v3.4s, v8.4s
fadd v9.4s, v9.4s, v8.4s
dup v10.4s, v0.s[1]
fabd v10.4s, v3.4s, v10.4s
fadd v11.4s, v11.4s, v10.4s
dup v12.4s, v0.s[2]
fabd v12.4s, v3.4s, v12.4s
fadd v13.4s, v13.4s, v12.4s
dup v14.4s, v0.s[3]
fabd v14.4s, v3.4s, v14.4s
fadd v15.4s, v15.4s, v14.4s
dup v16.4s, v1.s[0]
fabd v16.4s, v3.4s, v16.4s
fadd v17.4s, v17.4s, v16.4s
dup v18.4s, v1.s[1]
fabd v18.4s, v3.4s, v18.4s
fadd v19.4s, v19.4s, v18.4s
dup v20.4s, v1.s[2]
fabd v20.4s, v3.4s, v20.4s
fadd v21.4s, v21.4s, v20.4s
dup v22.4s, v1.s[3]
fabd v22.4s, v3.4s, v22.4s
fadd v23.4s, v23.4s, v22.4s
subs x19, x19, #1
bgt LoopDepth8
Bias8:
fneg v9.4s, v9.4s
fneg v11.4s, v11.4s
fneg v13.4s, v13.4s
fneg v15.4s, v15.4s
fneg v17.4s, v17.4s
fneg v19.4s, v19.4s
fneg v21.4s, v21.4s
fneg v23.4s, v23.4s
cbz x3, Activation8
ld1 {v0.4s}, [x12], #16
fadd v9.4s, v9.4s, v0.4s
fadd v11.4s, v11.4s, v0.4s
fadd v13.4s, v13.4s, v0.4s
fadd v15.4s, v15.4s, v0.4s
fadd v17.4s, v17.4s, v0.4s
fadd v19.4s, v19.4s, v0.4s
fadd v21.4s, v21.4s, v0.4s
fadd v23.4s, v23.4s, v0.4s
Activation8:
cmp x4, #3
beq Relu68
cmp x4, #1
beq Relu8
b Write
Relu68:
mov w19, #6
dup v2.4s, w19
scvtf v2.4s, v2.4s
fmin v9.4s, v9.4s, v2.4s
fmin v11.4s, v11.4s, v2.4s
fmin v13.4s, v13.4s, v2.4s
fmin v15.4s, v15.4s, v2.4s
fmin v17.4s, v17.4s, v2.4s
fmin v19.4s, v19.4s, v2.4s
fmin v21.4s, v21.4s, v2.4s
fmin v23.4s, v23.4s, v2.4s
Relu8:
dup v3.4s, wzr
fmax v9.4s, v9.4s, v3.4s
fmax v11.4s, v11.4s, v3.4s
fmax v13.4s, v13.4s, v3.4s
fmax v15.4s, v15.4s, v3.4s
fmax v17.4s, v17.4s, v3.4s
fmax v19.4s, v19.4s, v3.4s
fmax v21.4s, v21.4s, v3.4s
fmax v23.4s, v23.4s, v3.4s
b Write
LoopRow4:
mov x14, x1 // reload rhs ptr
mov x13, x7 // reload rhs col
mov x12, x3 // reload bias
LoopCol4:
mov x11, x2
mov x10, x0 // reload lhs ptr
mov x19, x5 // reload depth
LoopDepthStart4:
ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48
ld1 {v3.4s}, [x14], #16
dup v8.4s, v0.s[0]
fabd v9.4s, v3.4s, v8.4s
dup v10.4s, v0.s[1]
fabd v11.4s, v3.4s, v10.4s
dup v12.4s, v0.s[2]
fabd v13.4s, v3.4s, v12.4s
dup v14.4s, v0.s[3]
fabd v15.4s, v3.4s, v14.4s
subs x19, x19, #1
beq Bias4
LoopDepth4:
ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48
ld1 {v3.4s}, [x14], #16
dup v8.4s, v0.s[0]
fabd v8.4s, v3.4s, v8.4s
fadd v9.4s, v9.4s, v8.4s
dup v10.4s, v0.s[1]
fabd v10.4s, v3.4s, v10.4s
fadd v11.4s, v11.4s, v10.4s
dup v12.4s, v0.s[2]
fabd v12.4s, v3.4s, v12.4s
fadd v13.4s, v13.4s, v12.4s
dup v14.4s, v0.s[3]
fabd v14.4s, v3.4s, v14.4s
fadd v15.4s, v15.4s, v14.4s
subs x19, x19, #1
bgt LoopDepth4
Bias4:
fneg v9.4s, v9.4s
fneg v11.4s, v11.4s
fneg v13.4s, v13.4s
fneg v15.4s, v15.4s
cbz x3, Activation4
ld1 {v0.4s}, [x12], #16
fadd v9.4s, v9.4s, v0.4s
fadd v11.4s, v11.4s, v0.4s
fadd v13.4s, v13.4s, v0.4s
fadd v15.4s, v15.4s, v0.4s
Activation4:
cmp x4, #3
beq Relu64
cmp x4, #1
beq Relu4
b Write
Relu64:
mov w19, #6
dup v2.4s, w19
scvtf v2.4s, v2.4s
fmin v9.4s, v9.4s, v2.4s
fmin v11.4s, v11.4s, v2.4s
fmin v13.4s, v13.4s, v2.4s
fmin v15.4s, v15.4s, v2.4s
Relu4:
dup v3.4s, wzr
fmax v9.4s, v9.4s, v3.4s
fmax v11.4s, v11.4s, v3.4s
fmax v13.4s, v13.4s, v3.4s
fmax v15.4s, v15.4s, v3.4s
b Write
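// Write back the valid part of the tile: x13 holds the remaining output columns
// (1..3 handled separately, otherwise 4), x11 steps rows by x8 (output stride in
// bytes) and x6 bounds how many rows are stored.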
Write:
cmp x13, #1
beq Write1
cmp x13, #2
beq Write2
cmp x13, #3
beq Write3
b Write4
Write1:
add x2, x2, #4
str s9, [x11]
cmp x6, #1
beq WriteEnd
add x11, x11, x8
str s11, [x11]
cmp x6, #2
beq WriteEnd
add x11, x11, x8
str s13, [x11]
cmp x6, #3
beq WriteEnd
add x11, x11, x8
str s15, [x11]
cmp x6, #4
beq WriteEnd
add x11, x11, x8
str s17, [x11]
cmp x6, #5
beq WriteEnd
add x11, x11, x8
str s19, [x11]
cmp x6, #6
beq WriteEnd
add x11, x11, x8
str s21, [x11]
cmp x6, #7
beq WriteEnd
add x11, x11, x8
str s23, [x11]
cmp x6, #8
beq WriteEnd
add x11, x11, x8
str s25, [x11]
cmp x6, #9
beq WriteEnd
add x11, x11, x8
str s27, [x11]
cmp x6, #10
beq WriteEnd
add x11, x11, x8
str s29, [x11]
cmp x6, #11
beq WriteEnd
add x11, x11, x8
str s31, [x11]
add x11, x11, x8
add x11, x11, #4
b WriteEnd
Write2:
add x2, x2, #8
str d9, [x11]
cmp x6, #1
beq WriteEnd
add x11, x11, x8
str d11, [x11]
cmp x6, #2
beq WriteEnd
add x11, x11, x8
str d13, [x11]
cmp x6, #3
beq WriteEnd
add x11, x11, x8
str d15, [x11]
cmp x6, #4
beq WriteEnd
add x11, x11, x8
str d17, [x11]
cmp x6, #5
beq WriteEnd
add x11, x11, x8
str d19, [x11]
cmp x6, #6
beq WriteEnd
add x11, x11, x8
str d21, [x11]
cmp x6, #7
beq WriteEnd
add x11, x11, x8
str d23, [x11]
cmp x6, #8
beq WriteEnd
add x11, x11, x8
str d25, [x11]
cmp x6, #9
beq WriteEnd
add x11, x11, x8
str d27, [x11]
cmp x6, #10
beq WriteEnd
add x11, x11, x8
str d29, [x11]
cmp x6, #11
beq WriteEnd
add x11, x11, x8
str d31, [x11]
add x11, x11, x8
add x11, x11, #8
b WriteEnd
Write3:
add x2, x2, #12
add x19, x11, #8
str d9, [x11]
st1 {v9.s}[2], [x19], x8
cmp x6, #1
beq WriteEnd
add x11, x11, x8
str d11, [x11]
st1 {v11.s}[2], [x19], x8
cmp x6, #2
beq WriteEnd
add x11, x11, x8
str d13, [x11]
st1 {v13.s}[2], [x19], x8
cmp x6, #3
beq WriteEnd
add x11, x11, x8
str d15, [x11]
st1 {v15.s}[2], [x19], x8
cmp x6, #4
beq WriteEnd
add x11, x11, x8
str d17, [x11]
st1 {v17.s}[2], [x19], x8
cmp x6, #5
beq WriteEnd
add x11, x11, x8
str d19, [x11]
st1 {v19.s}[2], [x19], x8
cmp x6, #6
beq WriteEnd
add x11, x11, x8
str d21, [x11]
st1 {v21.s}[2], [x19], x8
cmp x6, #7
beq WriteEnd
add x11, x11, x8
str d23, [x11]
st1 {v23.s}[2], [x19], x8
cmp x6, #8
beq WriteEnd
add x11, x11, x8
str d25, [x11]
st1 {v25.s}[2], [x19], x8
cmp x6, #9
beq WriteEnd
add x11, x11, x8
str d27, [x11]
st1 {v27.s}[2], [x19], x8
cmp x6, #10
beq WriteEnd
add x11, x11, x8
str d29, [x11]
st1 {v29.s}[2], [x19], x8
cmp x6, #11
beq WriteEnd
add x11, x11, x8
str d31, [x11]
st1 {v31.s}[2], [x19]
add x11, x11, x8
add x11, x11, #12
b WriteEnd
Write4:
add x2, x2, #16
st1 {v9.4s}, [x11], x8
cmp x6, #1
beq WriteEnd
st1 {v11.4s}, [x11], x8
cmp x6, #2
beq WriteEnd
st1 {v13.4s}, [x11], x8
cmp x6, #3
beq WriteEnd
st1 {v15.4s}, [x11], x8
cmp x6, #4
beq WriteEnd
st1 {v17.4s}, [x11], x8
cmp x6, #5
beq WriteEnd
st1 {v19.4s}, [x11], x8
cmp x6, #6
beq WriteEnd
st1 {v21.4s}, [x11], x8
cmp x6, #7
beq WriteEnd
st1 {v23.4s}, [x11], x8
cmp x6, #8
beq WriteEnd
st1 {v25.4s}, [x11], x8
cmp x6, #9
beq WriteEnd
st1 {v27.4s}, [x11], x8
cmp x6, #10
beq WriteEnd
st1 {v29.4s}, [x11], x8
cmp x6, #11
beq WriteEnd
st1 {v31.4s}, [x11], x8
add x11, x11, #16
b WriteEnd
WriteEnd:
subs x13, x13, #4 // rhs col - 4
ble LoopColEnd
cmp x6, #4
ble LoopCol4
cmp x6, #8
ble LoopCol8
b LoopCol
LoopColEnd:
add x0, x0, x17
mov x18, #4
mul x18, x18, x7
sub x11, x11, x18
mov x2, x11
subs x6, x6, #12
bgt LoopRowStart
sub sp, sp, #144
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ldp x19, x20, [sp], #16
ret
#endif


@@ -0,0 +1,90 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp32/adder_fp32.h"
#include <string.h>
#include <math.h>
#include "nnacl/fp32/common_func_fp32.h"
#include "nnacl/fp32/matmul_fp32.h"
void Adder12x4(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
int col, int stride) {
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int r12div = r / 12, r12mod = r % 12;
int c4div = c / 4, c4mod = c % 4;
size_t ci = r * stride + c;
float value = 0;
for (int d = 0; d < deep; d++) {
size_t ai = r12div * deep * 12 + d * 12 + r12mod;
size_t bi = c4div * deep * 4 + d * 4 + c4mod;
value += fabsf(a[ai] - b[bi]);
}
value = -value;
if (bias != NULL) value += bias[c];
if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
if (act_type != ActType_No) value = MSMAX(0.0f, value);
dst[ci] = value;
}
}
}
void AdderOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col,
size_t stride) {
#ifdef ENABLE_ARM64
AdderFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride);
#else
Adder12x4(a, b, c, bias, act_type, deep, row, col, stride);
#endif
}
void AdderFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data,
float *col_major_input, float *output_data, int task_id, ConvParameter *conv_param) {
int out_channel = conv_param->output_channel_;
int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
int output_count = conv_param->output_h_ * conv_param->output_w_;
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
const int cal_num = C4NUM;
#else
const int cal_num = C12NUM;
#endif
int output_tile_count = UP_DIV(output_count, cal_num);
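// For each tile of output pixels: im2col the input window, repack it into
// column-major tiles (4 or 12 rows depending on the target), then run the
// adder kernel against the packed weight for all output channels.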
for (int b = 0; b < conv_param->input_batch_; b++) {
int in_batch_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
int out_batch_offset = b * out_channel * output_count;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
int start_index = thread_id * cal_num;
int real_cal_num = (output_count - start_index) < cal_num ? (output_count - start_index) : cal_num;
float *gemm_input = packed_input + task_id * deep * cal_num;
float *col_major_gemm_input = col_major_input + task_id * deep * cal_num;
size_t packed_input_size = deep * cal_num * sizeof(float);
memset(gemm_input, 0, packed_input_size);
memset(col_major_gemm_input, 0, packed_input_size);
Im2ColPackUnitFp32(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index);
int out_offset = thread_id * cal_num * out_channel + out_batch_offset;
float *gemm_output = output_data + out_offset;
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
RowMajor2Col4Major(gemm_input, col_major_gemm_input, cal_num, deep);
#else
RowMajor2Col12Major(gemm_input, col_major_gemm_input, cal_num, deep);
#endif
AdderOpt(col_major_gemm_input, packed_weight, gemm_output, bias_data, conv_param->act_type_, deep, real_cal_num,
out_channel, out_channel);
}
}
}


@@ -0,0 +1,47 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP32_ADDER_H_
#define MINDSPORE_LITE_NNACL_FP32_ADDER_H_
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "nnacl/pack.h"
#include "nnacl/op_base.h"
#include "nnacl/common_func.h"
#include "nnacl/conv_parameter.h"
#ifdef __cplusplus
extern "C" {
#endif
#ifdef ENABLE_ARM64
void AdderFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, size_t stride);
#endif
void AdderOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col,
size_t stride);
void AdderFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data,
float *col_major_input, float *output_data, int task_id, ConvParameter *conv_param);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_FP32_ADDER_H_


@@ -207,6 +207,26 @@ table Conv2D {
activationType: ActivationType = 0;
}
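// Adder mirrors Conv2D's attribute layout so the converter and runtime can
// populate and read it the same way.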
table Adder {
format: Format = 0;
group: int;
channelIn: int;
channelOut: int;
kernelW: int;
kernelH: int;
strideW: int;
strideH: int;
padMode: PadMode;
padUp: int;
padDown: int;
padLeft: int;
padRight: int;
dilateW: int;
dilateH: int;
hasBias: bool = false;
activationType: ActivationType = 0;
}
table Conv2DGradFilter {
format: Format = 0;
group: int;
@@ -1177,6 +1197,3 @@ table All {
table Assert {
summarize : int;
}
table Adder {
}


@@ -15,6 +15,14 @@
*/
#include "src/ops/adder.h"
#include <memory>
#include <string>
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#ifdef PRIMITIVE_WRITEABLE
#include "tools/converter/quantizer/quantize_util.h"
#endif
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
@@ -23,6 +31,118 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Adder::GetFormat() const { return this->primitive_->value.AsAdder()->format; }
int Adder::GetGroup() const { return this->primitive_->value.AsAdder()->group; }
int Adder::GetChannelIn() const { return this->primitive_->value.AsAdder()->channelIn; }
int Adder::GetChannelOut() const { return this->primitive_->value.AsAdder()->channelOut; }
int Adder::GetKernelW() const { return this->primitive_->value.AsAdder()->kernelW; }
int Adder::GetKernelH() const { return this->primitive_->value.AsAdder()->kernelH; }
int Adder::GetStrideW() const { return this->primitive_->value.AsAdder()->strideW; }
int Adder::GetStrideH() const { return this->primitive_->value.AsAdder()->strideH; }
int Adder::GetPadMode() const { return this->primitive_->value.AsAdder()->padMode; }
int Adder::GetPadUp() const { return this->primitive_->value.AsAdder()->padUp; }
int Adder::GetPadDown() const { return this->primitive_->value.AsAdder()->padDown; }
int Adder::GetPadLeft() const { return this->primitive_->value.AsAdder()->padLeft; }
int Adder::GetPadRight() const { return this->primitive_->value.AsAdder()->padRight; }
int Adder::GetDilateW() const { return this->primitive_->value.AsAdder()->dilateW; }
int Adder::GetDilateH() const { return this->primitive_->value.AsAdder()->dilateH; }
bool Adder::GetHasBias() const { return this->primitive_->value.AsAdder()->hasBias; }
int Adder::GetActivationType() const { return this->primitive_->value.AsAdder()->activationType; }
void Adder::SetFormat(int format) { this->primitive_->value.AsAdder()->format = (schema::Format)format; }
void Adder::SetGroup(int group) { this->primitive_->value.AsAdder()->group = group; }
void Adder::SetChannelIn(int channel_in) { this->primitive_->value.AsAdder()->channelIn = channel_in; }
void Adder::SetChannelOut(int channel_out) { this->primitive_->value.AsAdder()->channelOut = channel_out; }
void Adder::SetKernelW(int kernel_w) { this->primitive_->value.AsAdder()->kernelW = kernel_w; }
void Adder::SetKernelH(int kernel_h) { this->primitive_->value.AsAdder()->kernelH = kernel_h; }
void Adder::SetStrideW(int stride_w) { this->primitive_->value.AsAdder()->strideW = stride_w; }
void Adder::SetStrideH(int stride_h) { this->primitive_->value.AsAdder()->strideH = stride_h; }
void Adder::SetPadMode(int pad_mode) { this->primitive_->value.AsAdder()->padMode = (schema::PadMode)pad_mode; }
void Adder::SetPadUp(int pad_up) { this->primitive_->value.AsAdder()->padUp = pad_up; }
void Adder::SetPadDown(int pad_down) { this->primitive_->value.AsAdder()->padDown = pad_down; }
void Adder::SetPadLeft(int pad_left) { this->primitive_->value.AsAdder()->padLeft = pad_left; }
void Adder::SetPadRight(int pad_right) { this->primitive_->value.AsAdder()->padRight = pad_right; }
void Adder::SetDilateW(int dilate_w) { this->primitive_->value.AsAdder()->dilateW = dilate_w; }
void Adder::SetDilateH(int dilate_h) { this->primitive_->value.AsAdder()->dilateH = dilate_h; }
void Adder::SetHasBias(bool has_bias) { this->primitive_->value.AsAdder()->hasBias = has_bias; }
void Adder::SetActivationType(int activation_type) {
this->primitive_->value.AsAdder()->activationType = (schema::ActivationType)activation_type;
}
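// Builds the schema::AdderT attributes from the anf Primitive: data_format, pad_list,
// dilation, kernel_size, stride, out_channel, pad_mode and the optional activation
// name are read from the corresponding primitive attrs.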
void Adder::PopulaterAdderSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group) {
auto attr = std::make_unique<schema::AdderT>();
attr->group = group;
auto format = GetValue<std::string>(prim.GetAttr("data_format"));
if (format == "NCHW") {
attr->format = schema::Format::Format_NCHW;
} else if (format == "NHWC") {
attr->format = schema::Format::Format_NHWC;
} else {
attr->format = schema::Format::Format_NUM_OF_FORMAT;
}
auto pad_list = CastToInt(prim.GetAttr("pad_list"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];
auto dilation = CastToInt(prim.GetAttr("dilation"));
attr->dilateH = dilation[2];
attr->dilateW = dilation[3];
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];
auto stride = CastToInt(prim.GetAttr("stride"));
attr->strideH = stride[2];
attr->strideW = stride[3];
attr->channelOut = CastToInt(prim.GetAttr("out_channel")).front();
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
if (pad_mode == "valid") {
attr->padMode = schema::PadMode_VALID;
} else if (pad_mode == "same") {
attr->padMode = schema::PadMode_SAME_UPPER;
} else {
attr->padMode = schema::PadMode_NOTSET;
}
if (prim.GetAttr("activation_name") != nullptr) {
auto activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
attr->activationType = kActivationTypeMap[activate_name];
} else {
attr->activationType = schema::ActivationType_NO_ACTIVATION;
}
primitive->value.type = schema::PrimitiveType_Adder;
primitive->value.value = attr.release();
}
int Adder::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Adder;
}
if (this->primitive_->value.type != schema::PrimitiveType_Adder) {
MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
return RET_ERROR;
}
auto groupAttr = prim.GetAttr("group");
if (groupAttr == nullptr) {
MS_LOG(ERROR) << "conv2d op has no group attr,please check pb model";
return RET_NULL_PTR;
}
int group = CastToInt(groupAttr).front();
PopulaterAdderSingleGroup(prim, this->primitive_, group);
PopulaterQuantParam(prim, inputs);
return RET_OK;
}
#else
int Adder::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
@@ -33,39 +153,36 @@ int Adder::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::
MS_LOG(ERROR) << "value_as_Adder return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateAdder(*fbb);
auto val_offset = schema::CreateAdder(
*fbb, attr->format(), attr->group(), attr->channelIn(), attr->channelOut(), attr->kernelW(), attr->kernelH(),
attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
attr->padRight(), attr->dilateW(), attr->dilateH(), attr->hasBias(), attr->activationType());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Adder, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int Adder::GetFormat() const { return this->primitive_->value_as_Adder()->format(); }
int Adder::GetGroup() const { return this->primitive_->value_as_Adder()->group(); }
int Adder::GetChannelIn() const { return this->primitive_->value_as_Adder()->channelIn(); }
int Adder::GetChannelOut() const { return this->primitive_->value_as_Adder()->channelOut(); }
int Adder::GetKernelW() const { return this->primitive_->value_as_Adder()->kernelW(); }
int Adder::GetKernelH() const { return this->primitive_->value_as_Adder()->kernelH(); }
int Adder::GetStrideW() const { return this->primitive_->value_as_Adder()->strideW(); }
int Adder::GetStrideH() const { return this->primitive_->value_as_Adder()->strideH(); }
int Adder::GetPadMode() const { return this->primitive_->value_as_Adder()->padMode(); }
int Adder::GetPadUp() const { return this->primitive_->value_as_Adder()->padUp(); }
int Adder::GetPadDown() const { return this->primitive_->value_as_Adder()->padDown(); }
int Adder::GetPadLeft() const { return this->primitive_->value_as_Adder()->padLeft(); }
int Adder::GetPadRight() const { return this->primitive_->value_as_Adder()->padRight(); }
int Adder::GetDilateW() const { return this->primitive_->value_as_Adder()->dilateW(); }
int Adder::GetDilateH() const { return this->primitive_->value_as_Adder()->dilateH(); }
bool Adder::GetHasBias() const { return this->primitive_->value_as_Adder()->hasBias(); }
int Adder::GetActivationType() const { return this->primitive_->value_as_Adder()->activationType(); }
PrimitiveC *AdderCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Adder>(primitive); }
Registry AdderRegistry(schema::PrimitiveType_Adder, AdderCreator);
#endif
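// Shape inference currently treats the op like a 2-D matmul: the output shape is
// [rows of input 0, columns of input 1], with input 0's data type and format.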
int Adder::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
MS_ASSERT(this->primitive_ != nullptr);
MS_ASSERT(inputs_.size() == 2);
auto input0 = inputs_.front();
MS_ASSERT(input0 != nullptr);
MS_ASSERT(input0->shape().size() == 2);
auto input1 = inputs_.at(1);
MS_ASSERT(input1 != nullptr);
MS_ASSERT(input1->shape().size() == 2);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_data_type(input0->data_type());
output->set_format(input0->format());
if (!infer_flag()) {
return RET_OK;
}
std::vector<int> in_shape;
in_shape.push_back(input0->shape().at(0));
in_shape.push_back(input1->shape().at(1));
output->set_shape(in_shape);
return RET_OK;
}
} // namespace lite
} // namespace mindspore


@@ -20,22 +20,62 @@
#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"
#include <memory>
#include "src/ops/conv2d.h"
namespace mindspore {
namespace lite {
class Adder : public PrimitiveC {
class Adder : public Conv2D {
public:
Adder() = default;
~Adder() = default;
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Adder, PrimitiveC);
explicit Adder(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
MS_DECLARE_PARENT(Adder, Conv2D);
explicit Adder(schema::PrimitiveT *primitive) : Conv2D(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
void SetFormat(int format);
void SetGroup(int group);
void SetChannelIn(int channel_in);
void SetChannelOut(int channel_out);
void SetKernelW(int kernel_w);
void SetKernelH(int kernel_h);
void SetStrideW(int stride_w);
void SetStrideH(int stride_h);
void SetPadMode(int pad_mode);
void SetPadUp(int pad_up);
void SetPadDown(int pad_down);
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetDilateW(int dilate_w);
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
private:
void PopulaterAdderSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group);
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb);
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
public:
int GetFormat() const;
int GetGroup() const;
int GetChannelIn() const;
int GetChannelOut() const;
int GetKernelW() const;
int GetKernelH() const;
int GetStrideW() const;
int GetStrideH() const;
int GetPadMode() const;
int GetPadUp() const;
int GetPadDown() const;
int GetPadLeft() const;
int GetPadRight() const;
int GetDilateW() const;
int GetDilateH() const;
bool GetHasBias() const;
int GetActivationType() const;
};
} // namespace lite
} // namespace mindspore


@@ -34,23 +34,23 @@ class Conv2D : public PrimitiveC {
explicit Conv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
void SetFormat(int format);
void SetGroup(int group);
void SetChannelIn(int channel_in);
void SetChannelOut(int channel_out);
void SetKernelW(int kernel_w);
void SetKernelH(int kernel_h);
void SetStrideW(int stride_w);
void SetStrideH(int stride_h);
void SetPadMode(int pad_mode);
void SetPadUp(int pad_up);
void SetPadDown(int pad_down);
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetDilateW(int dilate_w);
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
virtual void SetFormat(int format);
virtual void SetGroup(int group);
virtual void SetChannelIn(int channel_in);
virtual void SetChannelOut(int channel_out);
virtual void SetKernelW(int kernel_w);
virtual void SetKernelH(int kernel_h);
virtual void SetStrideW(int stride_w);
virtual void SetStrideH(int stride_h);
virtual void SetPadMode(int pad_mode);
virtual void SetPadUp(int pad_up);
virtual void SetPadDown(int pad_down);
virtual void SetPadLeft(int pad_left);
virtual void SetPadRight(int pad_right);
virtual void SetDilateW(int dilate_w);
virtual void SetDilateH(int dilate_h);
virtual void SetHasBias(bool has_bias);
virtual void SetActivationType(int activation_type);
private:
void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,
@@ -67,23 +67,23 @@ class Conv2D : public PrimitiveC {
int PadLeft() const;
int PadRight() const;
int GetFormat() const;
int GetGroup() const;
int GetChannelIn() const;
int GetChannelOut() const;
int GetKernelW() const;
int GetKernelH() const;
int GetStrideW() const;
int GetStrideH() const;
int GetPadMode() const;
int GetPadUp() const;
int GetPadDown() const;
int GetPadLeft() const;
int GetPadRight() const;
int GetDilateW() const;
int GetDilateH() const;
bool GetHasBias() const;
int GetActivationType() const;
virtual int GetFormat() const;
virtual int GetGroup() const;
virtual int GetChannelIn() const;
virtual int GetChannelOut() const;
virtual int GetKernelW() const;
virtual int GetKernelH() const;
virtual int GetStrideW() const;
virtual int GetStrideH() const;
virtual int GetPadMode() const;
virtual int GetPadUp() const;
virtual int GetPadDown() const;
virtual int GetPadLeft() const;
virtual int GetPadRight() const;
virtual int GetDilateW() const;
virtual int GetDilateH() const;
virtual bool GetHasBias() const;
virtual int GetActivationType() const;
protected:
void ConvInferShape(int input_h, int input_w, int *output_h, int *output_w);


@@ -15,24 +15,53 @@
*/
#include "src/ops/adder.h"
#include "src/common/log_adapter.h"
#include "nnacl/conv_parameter.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
#include "nnacl/adder.h"
namespace mindspore {
namespace lite {
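// The adder op is populated into a ConvParameter so the existing convolution
// scheduling (im2col, tiling, thread split) can drive the adder kernel unchanged.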
OpParameter *PopulateAdderParameter(const mindspore::lite::PrimitiveC *primitive) {
auto *adder_param = reinterpret_cast<AdderParameter *>(malloc(sizeof(AdderParameter)));
if (adder_param == nullptr) {
MS_LOG(ERROR) << "malloc AdderParameter failed.";
ConvParameter *conv_param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
if (conv_param == nullptr) {
MS_LOG(ERROR) << "malloc ConvParameter failed.";
return nullptr;
}
memset(adder_param, 0, sizeof(AdderParameter));
adder_param->op_parameter_.type_ = primitive->Type();
return reinterpret_cast<OpParameter *>(adder_param);
memset(conv_param, 0, sizeof(ConvParameter));
conv_param->op_parameter_.type_ = primitive->Type();
auto adder_primitive =
reinterpret_cast<mindspore::lite::Adder *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
conv_param->kernel_h_ = adder_primitive->GetKernelH();
conv_param->kernel_w_ = adder_primitive->GetKernelW();
conv_param->group_ = adder_primitive->GetGroup();
conv_param->stride_h_ = adder_primitive->GetStrideH();
conv_param->stride_w_ = adder_primitive->GetStrideW();
auto adder_lite_primitive = (lite::Adder *)primitive;
conv_param->pad_u_ = adder_lite_primitive->PadUp();
conv_param->pad_d_ = adder_lite_primitive->PadDown();
conv_param->pad_l_ = adder_lite_primitive->PadLeft();
conv_param->pad_r_ = adder_lite_primitive->PadRight();
conv_param->dilation_h_ = adder_primitive->GetDilateH();
conv_param->dilation_w_ = adder_primitive->GetDilateW();
conv_param->input_channel_ = adder_primitive->GetChannelIn();
conv_param->output_channel_ = adder_primitive->GetChannelOut();
conv_param->group_ = adder_primitive->GetGroup();
auto act_type = adder_primitive->GetActivationType();
switch (act_type) {
case schema::ActivationType_RELU:
conv_param->act_type_ = ActType_Relu;
break;
case schema::ActivationType_RELU6:
conv_param->act_type_ = ActType_Relu6;
break;
default:
conv_param->act_type_ = ActType_No;
break;
}
return reinterpret_cast<OpParameter *>(conv_param);
}
Registry AdderParameterRegistry(schema::PrimitiveType_Adder, PopulateAdderParameter);
} // namespace lite
} // namespace mindspore


@@ -0,0 +1,133 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32/adder_fp32.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
#include "include/errorcode.h"
#include "schema/model_generated.h"
#include "nnacl/fp32/adder_fp32.h"
#include "nnacl/fp32/matmul_fp32.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_INFER_INVALID;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Adder;
using mindspore::schema::Format::Format_NHWC;
namespace mindspore::kernel {
int AdderCPUKernel::InitWeightBias() {
auto filter_tensor = in_tensors_.at(kWeightIndex);
int kernel_h = filter_tensor->Height();
int kernel_w = filter_tensor->Width();
int in_channel = filter_tensor->Channel();
int out_channel = filter_tensor->Batch();
conv_param_->input_channel_ = in_channel;
conv_param_->output_channel_ = out_channel;
int kernel_plane = kernel_h * kernel_w;
const int oc_block = C4NUM;
int oc_block_num = UP_DIV(out_channel, C4NUM);
int pack_weight_size = oc_block_num * oc_block * in_channel * kernel_plane;
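// Repack the weight so each block of C4NUM output channels is contiguous along the
// in_channel * kernel_plane depth; AdderOpt consumes it as 4-column tiles.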
auto origin_weight = reinterpret_cast<float *>(filter_tensor->MutableData());
packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float)));
if (packed_weight_ == nullptr) {
MS_LOG(ERROR) << "malloc packed weight failed.";
return RET_ERROR;
}
memset(packed_weight_, 0, pack_weight_size * sizeof(float));
RowMajor2Col4Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane);
bias_data_ = reinterpret_cast<float *>(malloc(oc_block_num * oc_block * sizeof(float)));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "malloc bias failed.";
return RET_ERROR;
}
memset(bias_data_, 0, oc_block_num * oc_block * sizeof(float));
if (in_tensors_.size() == kInputSize2) {
auto ori_bias = reinterpret_cast<float *>(in_tensors_.at(kBiasIndex)->MutableData());
memcpy(bias_data_, ori_bias, out_channel * sizeof(float));
} else {
MS_ASSERT(in_tensors_.size() == kInputSize1);
}
return RET_OK;
}
int AdderCPUKernel::RunImpl(int task_id) {
auto input_tensor = in_tensors_.at(kInputIndex);
auto ori_input_data = reinterpret_cast<float *>(input_tensor->data_c());
auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->data_c());
AdderFp32(ori_input_data, packed_input_, packed_weight_, reinterpret_cast<float *>(bias_data_), col_major_input_,
output_addr, task_id, conv_param_);
return RET_OK;
}
int AdderImpl(void *cdata, int task_id) {
auto adder = reinterpret_cast<AdderCPUKernel *>(cdata);
auto error_code = adder->RunImpl(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Adder Run error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}
int AdderCPUKernel::Run() {
auto ret = InitTmpBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init tmp buffer failed.";
return RET_ERROR;
}
int error_code = ParallelLaunch(this->context_->thread_pool_, AdderImpl, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "adder error error_code[" << error_code << "]";
FreeTmpBuffer();
return RET_ERROR;
}
FreeTmpBuffer();
return RET_OK;
}
kernel::LiteKernel *CpuAdderFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
const InnerContext *ctx, const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(op_parameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Adder);
MS_ASSERT(desc.data_type == kNumberTypeFloat32);
kernel::LiteKernel *kernel = new (std::nothrow) kernel::AdderCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
free(op_parameter);
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK && ret != RET_INFER_INVALID) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
return nullptr;
}
return kernel;
}
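// Registers the creator so the scheduler can instantiate this kernel for
// PrimitiveType_Adder with fp32 inputs on CPU.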
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Adder, CpuAdderFp32KernelCreator)
} // namespace mindspore::kernel


@@ -0,0 +1,41 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ADDER_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ADDER_H_
#include <vector>
#include "src/lite_kernel.h"
#include "nnacl/op_base.h"
#include "src/runtime/kernel/arm/fp32/convolution_fp32.h"
#include "nnacl/fp32/conv_fp32.h"
namespace mindspore::kernel {
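// Reuses ConvolutionCPUKernel's buffers and thread scheduling; weight packing
// (InitWeightBias) and the compute path (Run/RunImpl) are overridden to call the
// adder kernel instead of the float convolution.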
class AdderCPUKernel : public ConvolutionCPUKernel {
public:
AdderCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: ConvolutionCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~AdderCPUKernel() override = default;
int InitWeightBias() override;
int Run() override;
int RunImpl(int task_id) override;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ADDER_H_


@@ -38,13 +38,13 @@ class ConvolutionCPUKernel : public ConvolutionBaseCPUKernel {
}
int Init() override;
virtual int InitWeightBias();
int InitTmpBuffer();
int ReSize() override;
int Run() override;
int RunImpl(int task_id);
int InitWeightBias();
int InitTmpBuffer();
virtual int RunImpl(int task_id);
private:
protected:
void FreeTmpBuffer() {
if (packed_input_ != nullptr) {
ctx_->allocator->Free(packed_input_);
@@ -55,8 +55,10 @@ class ConvolutionCPUKernel : public ConvolutionBaseCPUKernel {
col_major_input_ = nullptr;
}
}
float *packed_input_ = nullptr;
protected:
float *packed_weight_ = nullptr;
float *packed_input_ = nullptr;
float *col_major_input_ = nullptr;
};
} // namespace mindspore::kernel