forked from mindspore-Ecosystem/mindspore
!8952 [MS][LITE][CPU] Add adder op
From: @fuzhiye
Reviewed-by: @hangangqiang, @zhang_xue_tong
Signed-off-by: @zhang_xue_tong
Commit: 627e4d0cf3
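The adder op replaces the convolution's multiply-accumulate with a sum of absolute differences: each output element is the negated L1 distance between an im2col'd input patch and a filter column, plus bias, followed by the usual ReLU/ReLU6 clamp. A minimal sketch of that per-element semantics (a hypothetical helper, not part of this commit; the kernels below compute the same quantity over packed 12x4 tiles):

#include <math.h>

/* Hypothetical reference helper: one adder output element,
 * out[r][c] = bias[c] - sum_d |a[r][d] - b[d][c]|, activation applied afterwards. */
static float AdderRefElement(const float *a_row, const float *b_col, float bias, int deep) {
  float acc = 0.0f;
  for (int d = 0; d < deep; ++d) {
    acc += fabsf(a_row[d] - b_col[d]);
  }
  return bias - acc;
}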
@@ -0,0 +1,633 @@
#ifdef __aarch64__
.text
.align 5
.global AdderFloatNeon64
#ifndef __APPLE__
.type AdderFloatNeon64, %function
#endif

// void AdderFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
//                       int row, int col, size_t stride)
// x0: a
// x1: b
// x2: c
// x3: bias
// x4: act_type
// x5: depth
// x6: row
// x7: col
// x8: stride
// x9: writeMode

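// Tiling sketch (derived from the code below): each pass computes a block of up to
// 12 rows x 4 cols of the output. v0-v2 hold 12 packed lhs values for one depth step,
// v3 holds 4 packed rhs values; the accumulators collect sum(|a - b|) per output lane,
// which is then negated, biased, and clamped by the requested activation before the
// Write* paths store 1, 2, 3 or 4 columns per row.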
AdderFloatNeon64:
    sub sp, sp, #144
    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
    stp x19, x20, [sp], #16

    ldr x8, [sp]

    mov x18, #48                 // sizeof(float) * 12
    mul x17, x5, x18             // block stride of lhs/rhs: sizeof(float) * 12 * depth

    mov x18, #4
    mul x8, x8, x18

LoopRowStart:
    cmp x6, #4
    ble LoopRow4
    cmp x6, #8
    blt LoopRow8

LoopRow:
    mov x14, x1                  // reload rhs ptr
    mov x13, x7                  // reload rhs col
    mov x12, x3                  // reload bias

LoopCol:
    mov x11, x2
    mov x10, x0                  // reload lhs ptr
    mov x19, x5                  // reload depth

LoopDepthStart:
    ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48
    ld1 {v3.4s}, [x14], #16

    dup v8.4s, v0.s[0]
    fabd v9.4s, v3.4s, v8.4s
    dup v10.4s, v0.s[1]
    fabd v11.4s, v3.4s, v10.4s
    dup v12.4s, v0.s[2]
    fabd v13.4s, v3.4s, v12.4s
    dup v14.4s, v0.s[3]
    fabd v15.4s, v3.4s, v14.4s

    dup v16.4s, v1.s[0]
    fabd v17.4s, v3.4s, v16.4s
    dup v18.4s, v1.s[1]
    fabd v19.4s, v3.4s, v18.4s
    dup v20.4s, v1.s[2]
    fabd v21.4s, v3.4s, v20.4s
    dup v22.4s, v1.s[3]
    fabd v23.4s, v3.4s, v22.4s

    dup v24.4s, v2.s[0]
    fabd v25.4s, v3.4s, v24.4s
    dup v26.4s, v2.s[1]
    fabd v27.4s, v3.4s, v26.4s
    dup v28.4s, v2.s[2]
    fabd v29.4s, v3.4s, v28.4s
    dup v30.4s, v2.s[3]
    fabd v31.4s, v3.4s, v30.4s

    subs x19, x19, #1
    beq Bias

LoopDepth:
    ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48
    ld1 {v3.4s}, [x14], #16

    dup v8.4s, v0.s[0]
    fabd v8.4s, v3.4s, v8.4s
    fadd v9.4s, v9.4s, v8.4s
    dup v10.4s, v0.s[1]
    fabd v10.4s, v3.4s, v10.4s
    fadd v11.4s, v11.4s, v10.4s
    dup v12.4s, v0.s[2]
    fabd v12.4s, v3.4s, v12.4s
    fadd v13.4s, v13.4s, v12.4s
    dup v14.4s, v0.s[3]
    fabd v14.4s, v3.4s, v14.4s
    fadd v15.4s, v15.4s, v14.4s

    dup v16.4s, v1.s[0]
    fabd v16.4s, v3.4s, v16.4s
    fadd v17.4s, v17.4s, v16.4s
    dup v18.4s, v1.s[1]
    fabd v18.4s, v3.4s, v18.4s
    fadd v19.4s, v19.4s, v18.4s
    dup v20.4s, v1.s[2]
    fabd v20.4s, v3.4s, v20.4s
    fadd v21.4s, v21.4s, v20.4s
    dup v22.4s, v1.s[3]
    fabd v22.4s, v3.4s, v22.4s
    fadd v23.4s, v23.4s, v22.4s

    dup v24.4s, v2.s[0]
    fabd v24.4s, v3.4s, v24.4s
    fadd v25.4s, v25.4s, v24.4s
    dup v26.4s, v2.s[1]
    fabd v26.4s, v3.4s, v26.4s
    fadd v27.4s, v27.4s, v26.4s
    dup v28.4s, v2.s[2]
    fabd v28.4s, v3.4s, v28.4s
    fadd v29.4s, v29.4s, v28.4s
    dup v30.4s, v2.s[3]
    fabd v30.4s, v3.4s, v30.4s
    fadd v31.4s, v31.4s, v30.4s

    subs x19, x19, #1
    bgt LoopDepth

Bias:
    fneg v9.4s, v9.4s
    fneg v11.4s, v11.4s
    fneg v13.4s, v13.4s
    fneg v15.4s, v15.4s
    fneg v17.4s, v17.4s
    fneg v19.4s, v19.4s
    fneg v21.4s, v21.4s
    fneg v23.4s, v23.4s
    fneg v25.4s, v25.4s
    fneg v27.4s, v27.4s
    fneg v29.4s, v29.4s
    fneg v31.4s, v31.4s
    cbz x3, Activation
    ld1 {v0.4s}, [x12], #16
    fadd v9.4s, v9.4s, v0.4s
    fadd v11.4s, v11.4s, v0.4s
    fadd v13.4s, v13.4s, v0.4s
    fadd v15.4s, v15.4s, v0.4s
    fadd v17.4s, v17.4s, v0.4s
    fadd v19.4s, v19.4s, v0.4s
    fadd v21.4s, v21.4s, v0.4s
    fadd v23.4s, v23.4s, v0.4s
    fadd v25.4s, v25.4s, v0.4s
    fadd v27.4s, v27.4s, v0.4s
    fadd v29.4s, v29.4s, v0.4s
    fadd v31.4s, v31.4s, v0.4s

Activation:
    cmp x4, #3
    beq Relu6
    cmp x4, #1
    beq Relu
    b Write

Relu6:
    mov w19, #6
    dup v2.4s, w19
    scvtf v2.4s, v2.4s
    fmin v9.4s, v9.4s, v2.4s
    fmin v11.4s, v11.4s, v2.4s
    fmin v13.4s, v13.4s, v2.4s
    fmin v15.4s, v15.4s, v2.4s
    fmin v17.4s, v17.4s, v2.4s
    fmin v19.4s, v19.4s, v2.4s
    fmin v21.4s, v21.4s, v2.4s
    fmin v23.4s, v23.4s, v2.4s
    fmin v25.4s, v25.4s, v2.4s
    fmin v27.4s, v27.4s, v2.4s
    fmin v29.4s, v29.4s, v2.4s
    fmin v31.4s, v31.4s, v2.4s

Relu:
    dup v3.4s, wzr
    fmax v9.4s, v9.4s, v3.4s
    fmax v11.4s, v11.4s, v3.4s
    fmax v13.4s, v13.4s, v3.4s
    fmax v15.4s, v15.4s, v3.4s
    fmax v17.4s, v17.4s, v3.4s
    fmax v19.4s, v19.4s, v3.4s
    fmax v21.4s, v21.4s, v3.4s
    fmax v23.4s, v23.4s, v3.4s
    fmax v25.4s, v25.4s, v3.4s
    fmax v27.4s, v27.4s, v3.4s
    fmax v29.4s, v29.4s, v3.4s
    fmax v31.4s, v31.4s, v3.4s
    b Write

LoopRow8:
    mov x14, x1                  // reload rhs ptr
    mov x13, x7                  // reload rhs col
    mov x12, x3                  // reload bias

LoopCol8:
    mov x11, x2
    mov x10, x0                  // reload lhs ptr
    mov x19, x5                  // reload depth

LoopDepthStart8:
    ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48
    ld1 {v3.4s}, [x14], #16

    dup v8.4s, v0.s[0]
    fabd v9.4s, v3.4s, v8.4s
    dup v10.4s, v0.s[1]
    fabd v11.4s, v3.4s, v10.4s
    dup v12.4s, v0.s[2]
    fabd v13.4s, v3.4s, v12.4s
    dup v14.4s, v0.s[3]
    fabd v15.4s, v3.4s, v14.4s

    dup v16.4s, v1.s[0]
    fabd v17.4s, v3.4s, v16.4s
    dup v18.4s, v1.s[1]
    fabd v19.4s, v3.4s, v18.4s
    dup v20.4s, v1.s[2]
    fabd v21.4s, v3.4s, v20.4s
    dup v22.4s, v1.s[3]
    fabd v23.4s, v3.4s, v22.4s

    subs x19, x19, #1
    beq Bias8

LoopDepth8:
    ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48
    ld1 {v3.4s}, [x14], #16
    dup v8.4s, v0.s[0]
    fabd v8.4s, v3.4s, v8.4s
    fadd v9.4s, v9.4s, v8.4s
    dup v10.4s, v0.s[1]
    fabd v10.4s, v3.4s, v10.4s
    fadd v11.4s, v11.4s, v10.4s
    dup v12.4s, v0.s[2]
    fabd v12.4s, v3.4s, v12.4s
    fadd v13.4s, v13.4s, v12.4s
    dup v14.4s, v0.s[3]
    fabd v14.4s, v3.4s, v14.4s
    fadd v15.4s, v15.4s, v14.4s

    dup v16.4s, v1.s[0]
    fabd v16.4s, v3.4s, v16.4s
    fadd v17.4s, v17.4s, v16.4s
    dup v18.4s, v1.s[1]
    fabd v18.4s, v3.4s, v18.4s
    fadd v19.4s, v19.4s, v18.4s
    dup v20.4s, v1.s[2]
    fabd v20.4s, v3.4s, v20.4s
    fadd v21.4s, v21.4s, v20.4s
    dup v22.4s, v1.s[3]
    fabd v22.4s, v3.4s, v22.4s
    fadd v23.4s, v23.4s, v22.4s

    subs x19, x19, #1
    bgt LoopDepth8

Bias8:
    fneg v9.4s, v9.4s
    fneg v11.4s, v11.4s
    fneg v13.4s, v13.4s
    fneg v15.4s, v15.4s
    fneg v17.4s, v17.4s
    fneg v19.4s, v19.4s
    fneg v21.4s, v21.4s
    fneg v23.4s, v23.4s
    cbz x3, Activation8
    ld1 {v0.4s}, [x12], #16
    fadd v9.4s, v9.4s, v0.4s
    fadd v11.4s, v11.4s, v0.4s
    fadd v13.4s, v13.4s, v0.4s
    fadd v15.4s, v15.4s, v0.4s
    fadd v17.4s, v17.4s, v0.4s
    fadd v19.4s, v19.4s, v0.4s
    fadd v21.4s, v21.4s, v0.4s
    fadd v23.4s, v23.4s, v0.4s

Activation8:
    cmp x4, #3
    beq Relu68
    cmp x4, #1
    beq Relu8
    b Write

Relu68:
    mov w19, #6
    dup v2.4s, w19
    scvtf v2.4s, v2.4s
    fmin v9.4s, v9.4s, v2.4s
    fmin v11.4s, v11.4s, v2.4s
    fmin v13.4s, v13.4s, v2.4s
    fmin v15.4s, v15.4s, v2.4s
    fmin v17.4s, v17.4s, v2.4s
    fmin v19.4s, v19.4s, v2.4s
    fmin v21.4s, v21.4s, v2.4s
    fmin v23.4s, v23.4s, v2.4s
Relu8:
    dup v3.4s, wzr
    fmax v9.4s, v9.4s, v3.4s
    fmax v11.4s, v11.4s, v3.4s
    fmax v13.4s, v13.4s, v3.4s
    fmax v15.4s, v15.4s, v3.4s
    fmax v17.4s, v17.4s, v3.4s
    fmax v19.4s, v19.4s, v3.4s
    fmax v21.4s, v21.4s, v3.4s
    fmax v23.4s, v23.4s, v3.4s
    b Write

LoopRow4:
    mov x14, x1                  // reload rhs ptr
    mov x13, x7                  // reload rhs col
    mov x12, x3                  // reload bias

LoopCol4:
    mov x11, x2
    mov x10, x0                  // reload lhs ptr
    mov x19, x5                  // reload depth

LoopDepthStart4:
    ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48
    ld1 {v3.4s}, [x14], #16
    dup v8.4s, v0.s[0]
    fabd v9.4s, v3.4s, v8.4s
    dup v10.4s, v0.s[1]
    fabd v11.4s, v3.4s, v10.4s
    dup v12.4s, v0.s[2]
    fabd v13.4s, v3.4s, v12.4s
    dup v14.4s, v0.s[3]
    fabd v15.4s, v3.4s, v14.4s

    subs x19, x19, #1
    beq Bias4

LoopDepth4:
    ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48
    ld1 {v3.4s}, [x14], #16
    dup v8.4s, v0.s[0]
    fabd v8.4s, v3.4s, v8.4s
    fadd v9.4s, v9.4s, v8.4s
    dup v10.4s, v0.s[1]
    fabd v10.4s, v3.4s, v10.4s
    fadd v11.4s, v11.4s, v10.4s
    dup v12.4s, v0.s[2]
    fabd v12.4s, v3.4s, v12.4s
    fadd v13.4s, v13.4s, v12.4s
    dup v14.4s, v0.s[3]
    fabd v14.4s, v3.4s, v14.4s
    fadd v15.4s, v15.4s, v14.4s

    subs x19, x19, #1
    bgt LoopDepth4

Bias4:
    fneg v9.4s, v9.4s
    fneg v11.4s, v11.4s
    fneg v13.4s, v13.4s
    fneg v15.4s, v15.4s
    cbz x3, Activation4
    ld1 {v0.4s}, [x12], #16

    fadd v9.4s, v9.4s, v0.4s
    fadd v11.4s, v11.4s, v0.4s
    fadd v13.4s, v13.4s, v0.4s
    fadd v15.4s, v15.4s, v0.4s

Activation4:
    cmp x4, #3
    beq Relu64
    cmp x4, #1
    beq Relu4
    b Write

Relu64:
    mov w19, #6
    dup v2.4s, w19
    scvtf v2.4s, v2.4s
    fmin v9.4s, v9.4s, v2.4s
    fmin v11.4s, v11.4s, v2.4s
    fmin v13.4s, v13.4s, v2.4s
    fmin v15.4s, v15.4s, v2.4s

Relu4:
    dup v3.4s, wzr
    fmax v9.4s, v9.4s, v3.4s
    fmax v11.4s, v11.4s, v3.4s
    fmax v13.4s, v13.4s, v3.4s
    fmax v15.4s, v15.4s, v3.4s
    b Write

Write:
    cmp x13, #1
    beq Write1
    cmp x13, #2
    beq Write2
    cmp x13, #3
    beq Write3
    b Write4

Write1:
    add x2, x2, #4
    str s9, [x11]
    cmp x6, #1
    beq WriteEnd
    add x11, x11, x8
    str s11, [x11]
    cmp x6, #2
    beq WriteEnd
    add x11, x11, x8
    str s13, [x11]
    cmp x6, #3
    beq WriteEnd
    add x11, x11, x8
    str s15, [x11]
    cmp x6, #4
    beq WriteEnd
    add x11, x11, x8
    str s17, [x11]
    cmp x6, #5
    beq WriteEnd
    add x11, x11, x8
    str s19, [x11]
    cmp x6, #6
    beq WriteEnd
    add x11, x11, x8
    str s21, [x11]
    cmp x6, #7
    beq WriteEnd
    add x11, x11, x8
    str s23, [x11]
    cmp x6, #8
    beq WriteEnd
    add x11, x11, x8
    str s25, [x11]
    cmp x6, #9
    beq WriteEnd
    add x11, x11, x8
    str s27, [x11]
    cmp x6, #10
    beq WriteEnd
    add x11, x11, x8
    str s29, [x11]
    cmp x6, #11
    beq WriteEnd
    add x11, x11, x8
    str s31, [x11]
    add x11, x11, x8
    add x11, x11, #4
    b WriteEnd
Write2:
    add x2, x2, #8
    str d9, [x11]
    cmp x6, #1
    beq WriteEnd
    add x11, x11, x8
    str d11, [x11]
    cmp x6, #2
    beq WriteEnd
    add x11, x11, x8
    str d13, [x11]
    cmp x6, #3
    beq WriteEnd
    add x11, x11, x8
    str d15, [x11]
    cmp x6, #4
    beq WriteEnd
    add x11, x11, x8
    str d17, [x11]
    cmp x6, #5
    beq WriteEnd
    add x11, x11, x8
    str d19, [x11]
    cmp x6, #6
    beq WriteEnd
    add x11, x11, x8
    str d21, [x11]
    cmp x6, #7
    beq WriteEnd
    add x11, x11, x8
    str d23, [x11]
    cmp x6, #8
    beq WriteEnd
    add x11, x11, x8
    str d25, [x11]
    cmp x6, #9
    beq WriteEnd
    add x11, x11, x8
    str d27, [x11]
    cmp x6, #10
    beq WriteEnd
    add x11, x11, x8
    str d29, [x11]
    cmp x6, #11
    beq WriteEnd
    add x11, x11, x8
    str d31, [x11]
    add x11, x11, x8
    add x11, x11, #8
    b WriteEnd
Write3:
    add x2, x2, #12
    add x19, x11, #8
    str d9, [x11]
    st1 {v9.s}[2], [x19], x8
    cmp x6, #1
    beq WriteEnd
    add x11, x11, x8
    str d11, [x11]
    st1 {v11.s}[2], [x19], x8
    cmp x6, #2
    beq WriteEnd
    add x11, x11, x8
    str d13, [x11]
    st1 {v13.s}[2], [x19], x8
    cmp x6, #3
    beq WriteEnd
    add x11, x11, x8
    str d15, [x11]
    st1 {v15.s}[2], [x19], x8
    cmp x6, #4
    beq WriteEnd
    add x11, x11, x8
    str d17, [x11]
    st1 {v17.s}[2], [x19], x8
    cmp x6, #5
    beq WriteEnd
    add x11, x11, x8
    str d19, [x11]
    st1 {v19.s}[2], [x19], x8
    cmp x6, #6
    beq WriteEnd
    add x11, x11, x8
    str d21, [x11]
    st1 {v21.s}[2], [x19], x8
    cmp x6, #7
    beq WriteEnd
    add x11, x11, x8
    str d23, [x11]
    st1 {v23.s}[2], [x19], x8
    cmp x6, #8
    beq WriteEnd
    add x11, x11, x8
    str d25, [x11]
    st1 {v25.s}[2], [x19], x8
    cmp x6, #9
    beq WriteEnd
    add x11, x11, x8
    str d27, [x11]
    st1 {v27.s}[2], [x19], x8
    cmp x6, #10
    beq WriteEnd
    add x11, x11, x8
    str d29, [x11]
    st1 {v29.s}[2], [x19], x8
    cmp x6, #11
    beq WriteEnd
    add x11, x11, x8
    str d31, [x11]
    st1 {v31.s}[2], [x19]
    add x11, x11, x8
    add x11, x11, #12
    b WriteEnd
Write4:
    add x2, x2, #16
    st1 {v9.4s}, [x11], x8
    cmp x6, #1
    beq WriteEnd
    st1 {v11.4s}, [x11], x8
    cmp x6, #2
    beq WriteEnd
    st1 {v13.4s}, [x11], x8
    cmp x6, #3
    beq WriteEnd
    st1 {v15.4s}, [x11], x8
    cmp x6, #4
    beq WriteEnd
    st1 {v17.4s}, [x11], x8
    cmp x6, #5
    beq WriteEnd
    st1 {v19.4s}, [x11], x8
    cmp x6, #6
    beq WriteEnd
    st1 {v21.4s}, [x11], x8
    cmp x6, #7
    beq WriteEnd
    st1 {v23.4s}, [x11], x8
    cmp x6, #8
    beq WriteEnd
    st1 {v25.4s}, [x11], x8
    cmp x6, #9
    beq WriteEnd
    st1 {v27.4s}, [x11], x8
    cmp x6, #10
    beq WriteEnd
    st1 {v29.4s}, [x11], x8
    cmp x6, #11
    beq WriteEnd
    st1 {v31.4s}, [x11], x8
    add x11, x11, #16
    b WriteEnd

WriteEnd:
    subs x13, x13, #4            // rhs col - 4
    ble LoopColEnd
    cmp x6, #4
    ble LoopCol4
    cmp x6, #8
    ble LoopCol8
    b LoopCol

LoopColEnd:
    add x0, x0, x17
    mov x18, #4
    mul x18, x18, x7
    sub x11, x11, x18
    mov x2, x11
    subs x6, x6, #12
    bgt LoopRowStart

    sub sp, sp, #144
    ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
    ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
    ldp x19, x20, [sp], #16
    ret
#endif
@@ -0,0 +1,90 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/fp32/adder_fp32.h"
#include <string.h>
#include <math.h>
#include "nnacl/fp32/common_func_fp32.h"
#include "nnacl/fp32/matmul_fp32.h"

void Adder12x4(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
               int col, int stride) {
  for (int r = 0; r < row; r++) {
    for (int c = 0; c < col; c++) {
      int r12div = r / 12, r12mod = r % 12;
      int c4div = c / 4, c4mod = c % 4;
      size_t ci = r * stride + c;
      float value = 0;
      for (int d = 0; d < deep; d++) {
        size_t ai = r12div * deep * 12 + d * 12 + r12mod;
        size_t bi = c4div * deep * 4 + d * 4 + c4mod;
        value += fabsf(a[ai] - b[bi]);
      }
      value = -value;
      if (bias != NULL) value += bias[c];
      if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
      if (act_type != ActType_No) value = MSMAX(0.0f, value);
      dst[ci] = value;
    }
  }
}

void AdderOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col,
              size_t stride) {
#ifdef ENABLE_ARM64
  AdderFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride);
#else
  Adder12x4(a, b, c, bias, act_type, deep, row, col, stride);
#endif
}

void AdderFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data,
               float *col_major_input, float *output_data, int task_id, ConvParameter *conv_param) {
  int out_channel = conv_param->output_channel_;
  int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
  int output_count = conv_param->output_h_ * conv_param->output_w_;
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
  const int cal_num = C4NUM;
#else
  const int cal_num = C12NUM;
#endif
  int output_tile_count = UP_DIV(output_count, cal_num);

  for (int b = 0; b < conv_param->input_batch_; b++) {
    int in_batch_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
    int out_batch_offset = b * out_channel * output_count;
    for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
      int start_index = thread_id * cal_num;
      int real_cal_num = (output_count - start_index) < cal_num ? (output_count - start_index) : cal_num;
      float *gemm_input = packed_input + task_id * deep * cal_num;
      float *col_major_gemm_input = col_major_input + task_id * deep * cal_num;
      size_t packed_input_size = deep * cal_num * sizeof(float);
      memset(gemm_input, 0, packed_input_size);
      memset(col_major_gemm_input, 0, packed_input_size);
      Im2ColPackUnitFp32(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index);

      int out_offset = thread_id * cal_num * out_channel + out_batch_offset;
      float *gemm_output = output_data + out_offset;
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
      RowMajor2Col4Major(gemm_input, col_major_gemm_input, cal_num, deep);
#else
      RowMajor2Col12Major(gemm_input, col_major_gemm_input, cal_num, deep);
#endif
      AdderOpt(col_major_gemm_input, packed_weight, gemm_output, bias_data, conv_param->act_type_, deep, real_cal_num,
               out_channel, out_channel);
    }
  }
}
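A minimal usage sketch of AdderOpt for a single output element (assumptions: deep = 2, row = col = 1, no activation; the packed layouts used above mean even a 1x1 output needs deep * 12 lhs floats and deep * 4 rhs floats):

#include <stdio.h>
#include "nnacl/fp32/adder_fp32.h"

int main(void) {
  float a[24] = {0};   /* lhs packed in 12-row blocks: a[d * 12 + 0] is row 0, depth d */
  float b[8] = {0};    /* rhs packed in 4-col blocks:  b[d * 4 + 0] is col 0, depth d */
  float bias[1] = {1.0f};
  float out[1] = {0};
  a[0] = 2.0f;  a[12] = -1.0f;
  b[0] = 0.5f;  b[4] = 0.5f;
  AdderOpt(a, b, out, bias, ActType_No, 2, 1, 1, 1);
  printf("%f\n", out[0]);  /* expected: 1.0 - (|2.0 - 0.5| + |-1.0 - 0.5|) = -2.0 */
  return 0;
}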
@@ -0,0 +1,47 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_NNACL_FP32_ADDER_H_
#define MINDSPORE_LITE_NNACL_FP32_ADDER_H_

#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "nnacl/pack.h"
#include "nnacl/op_base.h"
#include "nnacl/common_func.h"
#include "nnacl/conv_parameter.h"

#ifdef __cplusplus
extern "C" {
#endif

#ifdef ENABLE_ARM64
void AdderFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
                      int col, size_t stride);
#endif

void AdderOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col,
              size_t stride);

void AdderFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data,
               float *col_major_input, float *output_data, int task_id, ConvParameter *conv_param);

#ifdef __cplusplus
}
#endif

#endif  // MINDSPORE_LITE_NNACL_FP32_ADDER_H_
@@ -207,6 +207,26 @@ table Conv2D {
    activationType: ActivationType = 0;
}

table Adder {
    format: Format = 0;
    group: int;
    channelIn: int;
    channelOut: int;
    kernelW: int;
    kernelH: int;
    strideW: int;
    strideH: int;
    padMode: PadMode;
    padUp: int;
    padDown: int;
    padLeft: int;
    padRight: int;
    dilateW: int;
    dilateH: int;
    hasBias: bool = false;
    activationType: ActivationType = 0;
}

table Conv2DGradFilter {
    format: Format = 0;
    group: int;

@@ -1177,6 +1197,3 @@ table All {
table Assert {
    summarize : int;
}

table Adder {
}

@@ -15,6 +15,14 @@
 */

#include "src/ops/adder.h"
#include <memory>
#include <string>

#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#ifdef PRIMITIVE_WRITEABLE
#include "tools/converter/quantizer/quantize_util.h"
#endif

#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"

@@ -23,6 +31,118 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Adder::GetFormat() const { return this->primitive_->value.AsAdder()->format; }
int Adder::GetGroup() const { return this->primitive_->value.AsAdder()->group; }
int Adder::GetChannelIn() const { return this->primitive_->value.AsAdder()->channelIn; }
int Adder::GetChannelOut() const { return this->primitive_->value.AsAdder()->channelOut; }
int Adder::GetKernelW() const { return this->primitive_->value.AsAdder()->kernelW; }
int Adder::GetKernelH() const { return this->primitive_->value.AsAdder()->kernelH; }
int Adder::GetStrideW() const { return this->primitive_->value.AsAdder()->strideW; }
int Adder::GetStrideH() const { return this->primitive_->value.AsAdder()->strideH; }
int Adder::GetPadMode() const { return this->primitive_->value.AsAdder()->padMode; }
int Adder::GetPadUp() const { return this->primitive_->value.AsAdder()->padUp; }
int Adder::GetPadDown() const { return this->primitive_->value.AsAdder()->padDown; }
int Adder::GetPadLeft() const { return this->primitive_->value.AsAdder()->padLeft; }
int Adder::GetPadRight() const { return this->primitive_->value.AsAdder()->padRight; }
int Adder::GetDilateW() const { return this->primitive_->value.AsAdder()->dilateW; }
int Adder::GetDilateH() const { return this->primitive_->value.AsAdder()->dilateH; }
bool Adder::GetHasBias() const { return this->primitive_->value.AsAdder()->hasBias; }
int Adder::GetActivationType() const { return this->primitive_->value.AsAdder()->activationType; }

void Adder::SetFormat(int format) { this->primitive_->value.AsAdder()->format = (schema::Format)format; }
void Adder::SetGroup(int group) { this->primitive_->value.AsAdder()->group = group; }
void Adder::SetChannelIn(int channel_in) { this->primitive_->value.AsAdder()->channelIn = channel_in; }
void Adder::SetChannelOut(int channel_out) { this->primitive_->value.AsAdder()->channelOut = channel_out; }
void Adder::SetKernelW(int kernel_w) { this->primitive_->value.AsAdder()->kernelW = kernel_w; }
void Adder::SetKernelH(int kernel_h) { this->primitive_->value.AsAdder()->kernelH = kernel_h; }
void Adder::SetStrideW(int stride_w) { this->primitive_->value.AsAdder()->strideW = stride_w; }
void Adder::SetStrideH(int stride_h) { this->primitive_->value.AsAdder()->strideH = stride_h; }
void Adder::SetPadMode(int pad_mode) { this->primitive_->value.AsAdder()->padMode = (schema::PadMode)pad_mode; }
void Adder::SetPadUp(int pad_up) { this->primitive_->value.AsAdder()->padUp = pad_up; }
void Adder::SetPadDown(int pad_down) { this->primitive_->value.AsAdder()->padDown = pad_down; }
void Adder::SetPadLeft(int pad_left) { this->primitive_->value.AsAdder()->padLeft = pad_left; }
void Adder::SetPadRight(int pad_right) { this->primitive_->value.AsAdder()->padRight = pad_right; }
void Adder::SetDilateW(int dilate_w) { this->primitive_->value.AsAdder()->dilateW = dilate_w; }
void Adder::SetDilateH(int dilate_h) { this->primitive_->value.AsAdder()->dilateH = dilate_h; }
void Adder::SetHasBias(bool has_bias) { this->primitive_->value.AsAdder()->hasBias = has_bias; }
void Adder::SetActivationType(int activation_type) {
  this->primitive_->value.AsAdder()->activationType = (schema::ActivationType)activation_type;
}

void Adder::PopulaterAdderSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group) {
  auto attr = std::make_unique<schema::AdderT>();
  attr->group = group;
  auto format = GetValue<std::string>(prim.GetAttr("data_format"));
  if (format == "NCHW") {
    attr->format = schema::Format::Format_NCHW;
  } else if (format == "NHWC") {
    attr->format = schema::Format::Format_NHWC;
  } else {
    attr->format = schema::Format::Format_NUM_OF_FORMAT;
  }
  auto pad_list = CastToInt(prim.GetAttr("pad_list"));
  attr->padUp = pad_list[0];
  attr->padDown = pad_list[1];
  attr->padLeft = pad_list[2];
  attr->padRight = pad_list[3];

  auto dilation = CastToInt(prim.GetAttr("dilation"));
  attr->dilateH = dilation[2];
  attr->dilateW = dilation[3];

  auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
  attr->kernelH = kernel_size[0];
  attr->kernelW = kernel_size[1];

  auto stride = CastToInt(prim.GetAttr("stride"));
  attr->strideH = stride[2];
  attr->strideW = stride[3];

  attr->channelOut = CastToInt(prim.GetAttr("out_channel")).front();

  auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
  if (pad_mode == "valid") {
    attr->padMode = schema::PadMode_VALID;
  } else if (pad_mode == "same") {
    attr->padMode = schema::PadMode_SAME_UPPER;
  } else {
    attr->padMode = schema::PadMode_NOTSET;
  }

  if (prim.GetAttr("activation_name") != nullptr) {
    auto activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
    attr->activationType = kActivationTypeMap[activate_name];
  } else {
    attr->activationType = schema::ActivationType_NO_ACTIVATION;
  }

  primitive->value.type = schema::PrimitiveType_Adder;
  primitive->value.value = attr.release();
}

int Adder::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_Adder;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_Adder) {
    MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
    return RET_ERROR;
  }
  auto groupAttr = prim.GetAttr("group");
  if (groupAttr == nullptr) {
    MS_LOG(ERROR) << "conv2d op has no group attr,please check pb model";
    return RET_NULL_PTR;
  }
  int group = CastToInt(groupAttr).front();
  PopulaterAdderSingleGroup(prim, this->primitive_, group);
  PopulaterQuantParam(prim, inputs);
  return RET_OK;
}

#else
int Adder::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {

@@ -33,39 +153,36 @@ int Adder::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::
    MS_LOG(ERROR) << "value_as_Adder return nullptr";
    return RET_ERROR;
  }
  auto val_offset = schema::CreateAdder(*fbb);

  auto val_offset = schema::CreateAdder(
    *fbb, attr->format(), attr->group(), attr->channelIn(), attr->channelOut(), attr->kernelW(), attr->kernelH(),
    attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
    attr->padRight(), attr->dilateW(), attr->dilateH(), attr->hasBias(), attr->activationType());
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Adder, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}

int Adder::GetFormat() const { return this->primitive_->value_as_Adder()->format(); }
int Adder::GetGroup() const { return this->primitive_->value_as_Adder()->group(); }
int Adder::GetChannelIn() const { return this->primitive_->value_as_Adder()->channelIn(); }
int Adder::GetChannelOut() const { return this->primitive_->value_as_Adder()->channelOut(); }
int Adder::GetKernelW() const { return this->primitive_->value_as_Adder()->kernelW(); }
int Adder::GetKernelH() const { return this->primitive_->value_as_Adder()->kernelH(); }
int Adder::GetStrideW() const { return this->primitive_->value_as_Adder()->strideW(); }
int Adder::GetStrideH() const { return this->primitive_->value_as_Adder()->strideH(); }
int Adder::GetPadMode() const { return this->primitive_->value_as_Adder()->padMode(); }
int Adder::GetPadUp() const { return this->primitive_->value_as_Adder()->padUp(); }
int Adder::GetPadDown() const { return this->primitive_->value_as_Adder()->padDown(); }
int Adder::GetPadLeft() const { return this->primitive_->value_as_Adder()->padLeft(); }
int Adder::GetPadRight() const { return this->primitive_->value_as_Adder()->padRight(); }
int Adder::GetDilateW() const { return this->primitive_->value_as_Adder()->dilateW(); }
int Adder::GetDilateH() const { return this->primitive_->value_as_Adder()->dilateH(); }
bool Adder::GetHasBias() const { return this->primitive_->value_as_Adder()->hasBias(); }
int Adder::GetActivationType() const { return this->primitive_->value_as_Adder()->activationType(); }

PrimitiveC *AdderCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Adder>(primitive); }
Registry AdderRegistry(schema::PrimitiveType_Adder, AdderCreator);
#endif

int Adder::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  MS_ASSERT(this->primitive_ != nullptr);
  MS_ASSERT(inputs_.size() == 2);
  auto input0 = inputs_.front();
  MS_ASSERT(input0 != nullptr);
  MS_ASSERT(input0->shape().size() == 2);
  auto input1 = inputs_.at(1);
  MS_ASSERT(input1 != nullptr);
  MS_ASSERT(input1->shape().size() == 2);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);

  output->set_data_type(input0->data_type());
  output->set_format(input0->format());
  if (!infer_flag()) {
    return RET_OK;
  }
  std::vector<int> in_shape;
  in_shape.push_back(input0->shape().at(0));
  in_shape.push_back(input1->shape().at(1));
  output->set_shape(in_shape);

  return RET_OK;
}
}  // namespace lite
}  // namespace mindspore

@@ -20,22 +20,62 @@
#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"
#include <memory>
#include "src/ops/conv2d.h"

namespace mindspore {
namespace lite {
class Adder : public PrimitiveC {
class Adder : public Conv2D {
 public:
  Adder() = default;
  ~Adder() = default;
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(Adder, PrimitiveC);
  explicit Adder(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  MS_DECLARE_PARENT(Adder, Conv2D);
  explicit Adder(schema::PrimitiveT *primitive) : Conv2D(primitive) {}

  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
  void SetFormat(int format);
  void SetGroup(int group);
  void SetChannelIn(int channel_in);
  void SetChannelOut(int channel_out);
  void SetKernelW(int kernel_w);
  void SetKernelH(int kernel_h);
  void SetStrideW(int stride_w);
  void SetStrideH(int stride_h);
  void SetPadMode(int pad_mode);
  void SetPadUp(int pad_up);
  void SetPadDown(int pad_down);
  void SetPadLeft(int pad_left);
  void SetPadRight(int pad_right);
  void SetDilateW(int dilate_w);
  void SetDilateH(int dilate_h);
  void SetHasBias(bool has_bias);
  void SetActivationType(int activation_type);

 private:
  void PopulaterAdderSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group);
#else
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb);
#endif
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;

 public:
  int GetFormat() const;
  int GetGroup() const;
  int GetChannelIn() const;
  int GetChannelOut() const;
  int GetKernelW() const;
  int GetKernelH() const;
  int GetStrideW() const;
  int GetStrideH() const;
  int GetPadMode() const;
  int GetPadUp() const;
  int GetPadDown() const;
  int GetPadLeft() const;
  int GetPadRight() const;
  int GetDilateW() const;
  int GetDilateH() const;
  bool GetHasBias() const;
  int GetActivationType() const;
};
}  // namespace lite
}  // namespace mindspore

@@ -34,23 +34,23 @@ class Conv2D : public PrimitiveC {
  explicit Conv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}

  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
  void SetFormat(int format);
  void SetGroup(int group);
  void SetChannelIn(int channel_in);
  void SetChannelOut(int channel_out);
  void SetKernelW(int kernel_w);
  void SetKernelH(int kernel_h);
  void SetStrideW(int stride_w);
  void SetStrideH(int stride_h);
  void SetPadMode(int pad_mode);
  void SetPadUp(int pad_up);
  void SetPadDown(int pad_down);
  void SetPadLeft(int pad_left);
  void SetPadRight(int pad_right);
  void SetDilateW(int dilate_w);
  void SetDilateH(int dilate_h);
  void SetHasBias(bool has_bias);
  void SetActivationType(int activation_type);
  virtual void SetFormat(int format);
  virtual void SetGroup(int group);
  virtual void SetChannelIn(int channel_in);
  virtual void SetChannelOut(int channel_out);
  virtual void SetKernelW(int kernel_w);
  virtual void SetKernelH(int kernel_h);
  virtual void SetStrideW(int stride_w);
  virtual void SetStrideH(int stride_h);
  virtual void SetPadMode(int pad_mode);
  virtual void SetPadUp(int pad_up);
  virtual void SetPadDown(int pad_down);
  virtual void SetPadLeft(int pad_left);
  virtual void SetPadRight(int pad_right);
  virtual void SetDilateW(int dilate_w);
  virtual void SetDilateH(int dilate_h);
  virtual void SetHasBias(bool has_bias);
  virtual void SetActivationType(int activation_type);

 private:
  void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,

@@ -67,23 +67,23 @@ class Conv2D : public PrimitiveC {
  int PadLeft() const;
  int PadRight() const;

  int GetFormat() const;
  int GetGroup() const;
  int GetChannelIn() const;
  int GetChannelOut() const;
  int GetKernelW() const;
  int GetKernelH() const;
  int GetStrideW() const;
  int GetStrideH() const;
  int GetPadMode() const;
  int GetPadUp() const;
  int GetPadDown() const;
  int GetPadLeft() const;
  int GetPadRight() const;
  int GetDilateW() const;
  int GetDilateH() const;
  bool GetHasBias() const;
  int GetActivationType() const;
  virtual int GetFormat() const;
  virtual int GetGroup() const;
  virtual int GetChannelIn() const;
  virtual int GetChannelOut() const;
  virtual int GetKernelW() const;
  virtual int GetKernelH() const;
  virtual int GetStrideW() const;
  virtual int GetStrideH() const;
  virtual int GetPadMode() const;
  virtual int GetPadUp() const;
  virtual int GetPadDown() const;
  virtual int GetPadLeft() const;
  virtual int GetPadRight() const;
  virtual int GetDilateW() const;
  virtual int GetDilateH() const;
  virtual bool GetHasBias() const;
  virtual int GetActivationType() const;

 protected:
  void ConvInferShape(int input_h, int input_w, int *output_h, int *output_w);

@@ -15,24 +15,53 @@
 */

#include "src/ops/adder.h"
#include "src/common/log_adapter.h"
#include "nnacl/conv_parameter.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
#include "nnacl/adder.h"

namespace mindspore {
namespace lite {

OpParameter *PopulateAdderParameter(const mindspore::lite::PrimitiveC *primitive) {
  auto *adder_param = reinterpret_cast<AdderParameter *>(malloc(sizeof(AdderParameter)));
  if (adder_param == nullptr) {
    MS_LOG(ERROR) << "malloc AdderParameter failed.";
  ConvParameter *conv_param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
  if (conv_param == nullptr) {
    MS_LOG(ERROR) << "malloc ConvParameter failed.";
    return nullptr;
  }
  memset(adder_param, 0, sizeof(AdderParameter));
  adder_param->op_parameter_.type_ = primitive->Type();
  return reinterpret_cast<OpParameter *>(adder_param);
  memset(conv_param, 0, sizeof(ConvParameter));
  conv_param->op_parameter_.type_ = primitive->Type();
  auto adder_primitive =
    reinterpret_cast<mindspore::lite::Adder *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
  conv_param->kernel_h_ = adder_primitive->GetKernelH();
  conv_param->kernel_w_ = adder_primitive->GetKernelW();
  conv_param->group_ = adder_primitive->GetGroup();
  conv_param->stride_h_ = adder_primitive->GetStrideH();
  conv_param->stride_w_ = adder_primitive->GetStrideW();

  auto adder_lite_primitive = (lite::Adder *)primitive;
  conv_param->pad_u_ = adder_lite_primitive->PadUp();
  conv_param->pad_d_ = adder_lite_primitive->PadDown();
  conv_param->pad_l_ = adder_lite_primitive->PadLeft();
  conv_param->pad_r_ = adder_lite_primitive->PadRight();
  conv_param->dilation_h_ = adder_primitive->GetDilateH();
  conv_param->dilation_w_ = adder_primitive->GetDilateW();
  conv_param->input_channel_ = adder_primitive->GetChannelIn();
  conv_param->output_channel_ = adder_primitive->GetChannelOut();
  conv_param->group_ = adder_primitive->GetGroup();
  auto act_type = adder_primitive->GetActivationType();
  switch (act_type) {
    case schema::ActivationType_RELU:
      conv_param->act_type_ = ActType_Relu;
      break;
    case schema::ActivationType_RELU6:
      conv_param->act_type_ = ActType_Relu6;
      break;
    default:
      conv_param->act_type_ = ActType_No;
      break;
  }
  return reinterpret_cast<OpParameter *>(conv_param);
}
Registry AdderParameterRegistry(schema::PrimitiveType_Adder, PopulateAdderParameter);

}  // namespace lite
}  // namespace mindspore

@@ -0,0 +1,133 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/runtime/kernel/arm/fp32/adder_fp32.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
#include "include/errorcode.h"
#include "schema/model_generated.h"
#include "nnacl/fp32/adder_fp32.h"
#include "nnacl/fp32/matmul_fp32.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_INFER_INVALID;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Adder;
using mindspore::schema::Format::Format_NHWC;

namespace mindspore::kernel {
int AdderCPUKernel::InitWeightBias() {
  auto filter_tensor = in_tensors_.at(kWeightIndex);
  int kernel_h = filter_tensor->Height();
  int kernel_w = filter_tensor->Width();
  int in_channel = filter_tensor->Channel();
  int out_channel = filter_tensor->Batch();
  conv_param_->input_channel_ = in_channel;
  conv_param_->output_channel_ = out_channel;
  int kernel_plane = kernel_h * kernel_w;
  const int oc_block = C4NUM;
  int oc_block_num = UP_DIV(out_channel, C4NUM);
  int pack_weight_size = oc_block_num * oc_block * in_channel * kernel_plane;

  auto origin_weight = reinterpret_cast<float *>(filter_tensor->MutableData());
  packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float)));
  if (packed_weight_ == nullptr) {
    MS_LOG(ERROR) << "malloc packed weight failed.";
    return RET_ERROR;
  }
  memset(packed_weight_, 0, pack_weight_size * sizeof(float));
  RowMajor2Col4Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane);

  bias_data_ = reinterpret_cast<float *>(malloc(oc_block_num * oc_block * sizeof(float)));
  if (bias_data_ == nullptr) {
    MS_LOG(ERROR) << "malloc bias failed.";
    return RET_ERROR;
  }
  memset(bias_data_, 0, oc_block_num * oc_block * sizeof(float));

  if (in_tensors_.size() == kInputSize2) {
    auto ori_bias = reinterpret_cast<float *>(in_tensors_.at(kBiasIndex)->MutableData());
    memcpy(bias_data_, ori_bias, out_channel * sizeof(float));
  } else {
    MS_ASSERT(in_tensors_.size() == kInputSize1);
  }
  return RET_OK;
}

int AdderCPUKernel::RunImpl(int task_id) {
  auto input_tensor = in_tensors_.at(kInputIndex);
  auto ori_input_data = reinterpret_cast<float *>(input_tensor->data_c());
  auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->data_c());
  AdderFp32(ori_input_data, packed_input_, packed_weight_, reinterpret_cast<float *>(bias_data_), col_major_input_,
            output_addr, task_id, conv_param_);
  return RET_OK;
}

int AdderImpl(void *cdata, int task_id) {
  auto adder = reinterpret_cast<AdderCPUKernel *>(cdata);
  auto error_code = adder->RunImpl(task_id);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "Adder Run error task_id[" << task_id << "] error_code[" << error_code << "]";
    return RET_ERROR;
  }
  return RET_OK;
}

int AdderCPUKernel::Run() {
  auto ret = InitTmpBuffer();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Init tmp buffer failed.";
    return RET_ERROR;
  }

  int error_code = ParallelLaunch(this->context_->thread_pool_, AdderImpl, this, thread_count_);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "adder error error_code[" << error_code << "]";
    FreeTmpBuffer();
    return RET_ERROR;
  }
  FreeTmpBuffer();
  return RET_OK;
}

kernel::LiteKernel *CpuAdderFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
                                              const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
                                              const InnerContext *ctx, const kernel::KernelKey &desc,
                                              const mindspore::lite::PrimitiveC *primitive) {
  MS_ASSERT(op_parameter != nullptr);
  MS_ASSERT(desc.type == schema::PrimitiveType_Adder);
  MS_ASSERT(desc.data_type == kNumberTypeFloat32);
  kernel::LiteKernel *kernel = new (std::nothrow) kernel::AdderCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
  if (kernel == nullptr) {
    MS_LOG(ERROR) << "kernel is nullptr.";
    free(op_parameter);
    return nullptr;
  }

  auto ret = kernel->Init();
  if (ret != RET_OK && ret != RET_INFER_INVALID) {
    delete kernel;
    MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: "
                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
    return nullptr;
  }
  return kernel;
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Adder, CpuAdderFp32KernelCreator)
}  // namespace mindspore::kernel

@@ -0,0 +1,41 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ADDER_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ADDER_H_

#include <vector>
#include "src/lite_kernel.h"
#include "nnacl/op_base.h"
#include "src/runtime/kernel/arm/fp32/convolution_fp32.h"
#include "nnacl/fp32/conv_fp32.h"

namespace mindspore::kernel {
class AdderCPUKernel : public ConvolutionCPUKernel {
 public:
  AdderCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                 const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                 const mindspore::lite::PrimitiveC *primitive)
      : ConvolutionCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
  ~AdderCPUKernel() override = default;

  int InitWeightBias() override;
  int Run() override;
  int RunImpl(int task_id) override;
};
}  // namespace mindspore::kernel

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ADDER_H_

@@ -38,13 +38,13 @@ class ConvolutionCPUKernel : public ConvolutionBaseCPUKernel {
  }

  int Init() override;
  virtual int InitWeightBias();
  int InitTmpBuffer();
  int ReSize() override;
  int Run() override;
  int RunImpl(int task_id);
  int InitWeightBias();
  int InitTmpBuffer();
  virtual int RunImpl(int task_id);

 private:
 protected:
  void FreeTmpBuffer() {
    if (packed_input_ != nullptr) {
      ctx_->allocator->Free(packed_input_);

@@ -55,8 +55,10 @@ class ConvolutionCPUKernel : public ConvolutionBaseCPUKernel {
      col_major_input_ = nullptr;
    }
  }
  float *packed_input_ = nullptr;

 protected:
  float *packed_weight_ = nullptr;
  float *packed_input_ = nullptr;
  float *col_major_input_ = nullptr;
};
}  // namespace mindspore::kernel