diff --git a/mindspore/lite/nnacl/assembly/arm64/AdderFp32.S b/mindspore/lite/nnacl/assembly/arm64/AdderFp32.S new file mode 100644 index 00000000000..13fb0ace0d9 --- /dev/null +++ b/mindspore/lite/nnacl/assembly/arm64/AdderFp32.S @@ -0,0 +1,633 @@ +#ifdef __aarch64__ + .text + .align 5 + .global AdderFloatNeon64 +#ifndef __APPLE__ + .type AdderFloatNeon64, %function +#endif + +// void AdderFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type +// x5: depth +// x6: row +// x7: col +// x8: stride +// x9: writeMode + +AdderFloatNeon64: + sub sp, sp, #144 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + stp x19, x20, [sp], #16 + + ldr x8, [sp] + + mov x18, #48 // sizeof(float) * 12 + mul x17, x5, x18 // block stride of lhs/rhs: sizeof(float) * 12 * depth + + mov x18, #4 + mul x8, x8, x18 + +LoopRowStart: + cmp x6, #4 + ble LoopRow4 + cmp x6, #8 + blt LoopRow8 + +LoopRow: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol: + mov x11, x2 + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + LoopDepthStart: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s}, [x14], #16 + + dup v8.4s, v0.s[0] + fabd v9.4s, v3.4s, v8.4s + dup v10.4s, v0.s[1] + fabd v11.4s, v3.4s, v10.4s + dup v12.4s, v0.s[2] + fabd v13.4s, v3.4s, v12.4s + dup v14.4s, v0.s[3] + fabd v15.4s, v3.4s, v14.4s + + dup v16.4s, v1.s[0] + fabd v17.4s, v3.4s, v16.4s + dup v18.4s, v1.s[1] + fabd v19.4s, v3.4s, v18.4s + dup v20.4s, v1.s[2] + fabd v21.4s, v3.4s, v20.4s + dup v22.4s, v1.s[3] + fabd v23.4s, v3.4s, v22.4s + + dup v24.4s, v2.s[0] + fabd v25.4s, v3.4s, v24.4s + dup v26.4s, v2.s[1] + fabd v27.4s, v3.4s, v26.4s + dup v28.4s, v2.s[2] + fabd v29.4s, v3.4s, v28.4s + dup v30.4s, v2.s[3] + fabd v31.4s, v3.4s, v30.4s + + subs x19, x19, #1 + beq Bias + + LoopDepth: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s}, [x14], #16 + + dup v8.4s, v0.s[0] + fabd v8.4s, v3.4s, v8.4s + fadd v9.4s, v9.4s, v8.4s + dup v10.4s, v0.s[1] + fabd v10.4s, v3.4s, v10.4s + fadd v11.4s, v11.4s, v10.4s + dup v12.4s, v0.s[2] + fabd v12.4s, v3.4s, v12.4s + fadd v13.4s, v13.4s, v12.4s + dup v14.4s, v0.s[3] + fabd v14.4s, v3.4s, v14.4s + fadd v15.4s, v15.4s, v14.4s + + dup v16.4s, v1.s[0] + fabd v16.4s, v3.4s, v16.4s + fadd v17.4s, v17.4s, v16.4s + dup v18.4s, v1.s[1] + fabd v18.4s, v3.4s, v18.4s + fadd v19.4s, v19.4s, v18.4s + dup v20.4s, v1.s[2] + fabd v20.4s, v3.4s, v20.4s + fadd v21.4s, v21.4s, v20.4s + dup v22.4s, v1.s[3] + fabd v22.4s, v3.4s, v22.4s + fadd v23.4s, v23.4s, v22.4s + + dup v24.4s, v2.s[0] + fabd v24.4s, v3.4s, v24.4s + fadd v25.4s, v25.4s, v24.4s + dup v26.4s, v2.s[1] + fabd v26.4s, v3.4s, v26.4s + fadd v27.4s, v27.4s, v26.4s + dup v28.4s, v2.s[2] + fabd v28.4s, v3.4s, v28.4s + fadd v29.4s, v29.4s, v28.4s + dup v30.4s, v2.s[3] + fabd v30.4s, v3.4s, v30.4s + fadd v31.4s, v31.4s, v30.4s + + subs x19, x19, #1 + bgt LoopDepth + + Bias: + fneg v9.4s, v9.4s + fneg v11.4s, v11.4s + fneg v13.4s, v13.4s + fneg v15.4s, v15.4s + fneg v17.4s, v17.4s + fneg v19.4s, v19.4s + fneg v21.4s, v21.4s + fneg v23.4s, v23.4s + fneg v25.4s, v25.4s + fneg v27.4s, v27.4s + fneg v29.4s, v29.4s + fneg v31.4s, v31.4s + cbz x3, Activation + ld1 {v0.4s}, [x12], #16 + fadd v9.4s, v9.4s, v0.4s + fadd v11.4s, v11.4s, v0.4s + fadd v13.4s, v13.4s, v0.4s + fadd v15.4s, v15.4s, v0.4s + fadd v17.4s, v17.4s, v0.4s + fadd 
v19.4s, v19.4s, v0.4s + fadd v21.4s, v21.4s, v0.4s + fadd v23.4s, v23.4s, v0.4s + fadd v25.4s, v25.4s, v0.4s + fadd v27.4s, v27.4s, v0.4s + fadd v29.4s, v29.4s, v0.4s + fadd v31.4s, v31.4s, v0.4s + + Activation: + cmp x4, #3 + beq Relu6 + cmp x4, #1 + beq Relu + b Write + + Relu6: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + fmin v25.4s, v25.4s, v2.4s + fmin v27.4s, v27.4s, v2.4s + fmin v29.4s, v29.4s, v2.4s + fmin v31.4s, v31.4s, v2.4s + + Relu: + dup v3.4s, wzr + fmax v9.4s, v9.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + fmax v25.4s, v25.4s, v3.4s + fmax v27.4s, v27.4s, v3.4s + fmax v29.4s, v29.4s, v3.4s + fmax v31.4s, v31.4s, v3.4s + b Write + +LoopRow8: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol8: + mov x11, x2 + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + LoopDepthStart8: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s}, [x14], #16 + + dup v8.4s, v0.s[0] + fabd v9.4s, v3.4s, v8.4s + dup v10.4s, v0.s[1] + fabd v11.4s, v3.4s, v10.4s + dup v12.4s, v0.s[2] + fabd v13.4s, v3.4s, v12.4s + dup v14.4s, v0.s[3] + fabd v15.4s, v3.4s, v14.4s + + dup v16.4s, v1.s[0] + fabd v17.4s, v3.4s, v16.4s + dup v18.4s, v1.s[1] + fabd v19.4s, v3.4s, v18.4s + dup v20.4s, v1.s[2] + fabd v21.4s, v3.4s, v20.4s + dup v22.4s, v1.s[3] + fabd v23.4s, v3.4s, v22.4s + + subs x19, x19, #1 + beq Bias8 + + LoopDepth8: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s}, [x14], #16 + dup v8.4s, v0.s[0] + fabd v8.4s, v3.4s, v8.4s + fadd v9.4s, v9.4s, v8.4s + dup v10.4s, v0.s[1] + fabd v10.4s, v3.4s, v10.4s + fadd v11.4s, v11.4s, v10.4s + dup v12.4s, v0.s[2] + fabd v12.4s, v3.4s, v12.4s + fadd v13.4s, v13.4s, v12.4s + dup v14.4s, v0.s[3] + fabd v14.4s, v3.4s, v14.4s + fadd v15.4s, v15.4s, v14.4s + + dup v16.4s, v1.s[0] + fabd v16.4s, v3.4s, v16.4s + fadd v17.4s, v17.4s, v16.4s + dup v18.4s, v1.s[1] + fabd v18.4s, v3.4s, v18.4s + fadd v19.4s, v19.4s, v18.4s + dup v20.4s, v1.s[2] + fabd v20.4s, v3.4s, v20.4s + fadd v21.4s, v21.4s, v20.4s + dup v22.4s, v1.s[3] + fabd v22.4s, v3.4s, v22.4s + fadd v23.4s, v23.4s, v22.4s + + subs x19, x19, #1 + bgt LoopDepth8 + + Bias8: + fneg v9.4s, v9.4s + fneg v11.4s, v11.4s + fneg v13.4s, v13.4s + fneg v15.4s, v15.4s + fneg v17.4s, v17.4s + fneg v19.4s, v19.4s + fneg v21.4s, v21.4s + fneg v23.4s, v23.4s + cbz x3, Activation8 + ld1 {v0.4s}, [x12], #16 + fadd v9.4s, v9.4s, v0.4s + fadd v11.4s, v11.4s, v0.4s + fadd v13.4s, v13.4s, v0.4s + fadd v15.4s, v15.4s, v0.4s + fadd v17.4s, v17.4s, v0.4s + fadd v19.4s, v19.4s, v0.4s + fadd v21.4s, v21.4s, v0.4s + fadd v23.4s, v23.4s, v0.4s + + Activation8: + cmp x4, #3 + beq Relu68 + cmp x4, #1 + beq Relu8 + b Write + + Relu68: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + Relu8: + dup v3.4s, wzr + fmax v9.4s, v9.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v17.4s, v17.4s, 
v3.4s
+        fmax v19.4s, v19.4s, v3.4s
+        fmax v21.4s, v21.4s, v3.4s
+        fmax v23.4s, v23.4s, v3.4s
+        b Write
+
+LoopRow4:
+    mov x14, x1 // reload rhs ptr
+    mov x13, x7 // reload rhs col
+    mov x12, x3 // reload bias
+
+    LoopCol4:
+        mov x11, x2
+        mov x10, x0 // reload lhs ptr
+        mov x19, x5 // reload depth
+
+    LoopDepthStart4:
+        ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48
+        ld1 {v3.4s}, [x14], #16
+        dup v8.4s, v0.s[0]
+        fabd v9.4s, v3.4s, v8.4s
+        dup v10.4s, v0.s[1]
+        fabd v11.4s, v3.4s, v10.4s
+        dup v12.4s, v0.s[2]
+        fabd v13.4s, v3.4s, v12.4s
+        dup v14.4s, v0.s[3]
+        fabd v15.4s, v3.4s, v14.4s
+
+        subs x19, x19, #1
+        beq Bias4
+
+    LoopDepth4:
+        ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48
+        ld1 {v3.4s}, [x14], #16
+        dup v8.4s, v0.s[0]
+        fabd v8.4s, v3.4s, v8.4s
+        fadd v9.4s, v9.4s, v8.4s
+        dup v10.4s, v0.s[1]
+        fabd v10.4s, v3.4s, v10.4s
+        fadd v11.4s, v11.4s, v10.4s
+        dup v12.4s, v0.s[2]
+        fabd v12.4s, v3.4s, v12.4s
+        fadd v13.4s, v13.4s, v12.4s
+        dup v14.4s, v0.s[3]
+        fabd v14.4s, v3.4s, v14.4s
+        fadd v15.4s, v15.4s, v14.4s
+
+        subs x19, x19, #1
+        bgt LoopDepth4
+
+    Bias4:
+        fneg v9.4s, v9.4s
+        fneg v11.4s, v11.4s
+        fneg v13.4s, v13.4s
+        fneg v15.4s, v15.4s
+        cbz x3, Activation4
+        ld1 {v0.4s}, [x12], #16
+
+        fadd v9.4s, v9.4s, v0.4s
+        fadd v11.4s, v11.4s, v0.4s
+        fadd v13.4s, v13.4s, v0.4s
+        fadd v15.4s, v15.4s, v0.4s
+
+    Activation4:
+        cmp x4, #3
+        beq Relu64
+        cmp x4, #1
+        beq Relu4
+        b Write
+
+    Relu64:
+        mov w19, #6
+        dup v2.4s, w19
+        scvtf v2.4s, v2.4s
+        fmin v9.4s, v9.4s, v2.4s
+        fmin v11.4s, v11.4s, v2.4s
+        fmin v13.4s, v13.4s, v2.4s
+        fmin v15.4s, v15.4s, v2.4s
+
+    Relu4:
+        dup v3.4s, wzr
+        fmax v9.4s, v9.4s, v3.4s
+        fmax v11.4s, v11.4s, v3.4s
+        fmax v13.4s, v13.4s, v3.4s
+        fmax v15.4s, v15.4s, v3.4s
+        b Write
+
+    Write:
+        cmp x13, #1
+        beq Write1
+        cmp x13, #2
+        beq Write2
+        cmp x13, #3
+        beq Write3
+        b Write4
+
+    Write1:
+        add x2, x2, #4
+        str s9, [x11]
+        cmp x6, #1
+        beq WriteEnd
+        add x11, x11, x8
+        str s11, [x11]
+        cmp x6, #2
+        beq WriteEnd
+        add x11, x11, x8
+        str s13, [x11]
+        cmp x6, #3
+        beq WriteEnd
+        add x11, x11, x8
+        str s15, [x11]
+        cmp x6, #4
+        beq WriteEnd
+        add x11, x11, x8
+        str s17, [x11]
+        cmp x6, #5
+        beq WriteEnd
+        add x11, x11, x8
+        str s19, [x11]
+        cmp x6, #6
+        beq WriteEnd
+        add x11, x11, x8
+        str s21, [x11]
+        cmp x6, #7
+        beq WriteEnd
+        add x11, x11, x8
+        str s23, [x11]
+        cmp x6, #8
+        beq WriteEnd
+        add x11, x11, x8
+        str s25, [x11]
+        cmp x6, #9
+        beq WriteEnd
+        add x11, x11, x8
+        str s27, [x11]
+        cmp x6, #10
+        beq WriteEnd
+        add x11, x11, x8
+        str s29, [x11]
+        cmp x6, #11
+        beq WriteEnd
+        add x11, x11, x8
+        str s31, [x11]
+        add x11, x11, x8
+        add x11, x11, #4
+        b WriteEnd
+    Write2:
+        add x2, x2, #8
+        str d9, [x11]
+        cmp x6, #1
+        beq WriteEnd
+        add x11, x11, x8
+        str d11, [x11]
+        cmp x6, #2
+        beq WriteEnd
+        add x11, x11, x8
+        str d13, [x11]
+        cmp x6, #3
+        beq WriteEnd
+        add x11, x11, x8
+        str d15, [x11]
+        cmp x6, #4
+        beq WriteEnd
+        add x11, x11, x8
+        str d17, [x11]
+        cmp x6, #5
+        beq WriteEnd
+        add x11, x11, x8
+        str d19, [x11]
+        cmp x6, #6
+        beq WriteEnd
+        add x11, x11, x8
+        str d21, [x11]
+        cmp x6, #7
+        beq WriteEnd
+        add x11, x11, x8
+        str d23, [x11]
+        cmp x6, #8
+        beq WriteEnd
+        add x11, x11, x8
+        str d25, [x11]
+        cmp x6, #9
+        beq WriteEnd
+        add x11, x11, x8
+        str d27, [x11]
+        cmp x6, #10
+        beq WriteEnd
+        add x11, x11, x8
+        str d29, [x11]
+        cmp x6, #11
+        beq WriteEnd
+        add x11, x11, x8
+        str d31, [x11]
+        add x11, x11, x8
+        add x11, x11, #8
+        b WriteEnd
+    Write3:
+        add x2, x2, #12
+        add x19, x11, #8
+        str d9, [x11]
+        st1 {v9.s}[2], [x19], x8
+        cmp x6, #1
+        beq WriteEnd
add x11, x11, x8 + str d11, [x11] + st1 {v11.s}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + add x11, x11, x8 + str d13, [x11] + st1 {v13.s}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + add x11, x11, x8 + str d15, [x11] + st1 {v15.s}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + add x11, x11, x8 + str d17, [x11] + st1 {v17.s}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + add x11, x11, x8 + str d19, [x11] + st1 {v19.s}[2], [x19], x8 + cmp x6, #6 + beq WriteEnd + add x11, x11, x8 + str d21, [x11] + st1 {v21.s}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + add x11, x11, x8 + str d23, [x11] + st1 {v23.s}[2], [x19], x8 + cmp x6, #8 + beq WriteEnd + add x11, x11, x8 + str d25, [x11] + st1 {v25.s}[2], [x19], x8 + cmp x6, #9 + beq WriteEnd + add x11, x11, x8 + str d27, [x11] + st1 {v27.s}[2], [x19], x8 + cmp x6, #10 + beq WriteEnd + add x11, x11, x8 + str d29, [x11] + st1 {v29.s}[2], [x19], x8 + cmp x6, #11 + beq WriteEnd + add x11, x11, x8 + str d31, [x11] + st1 {v31.s}[2], [x19] + add x11, x11, x8 + add x11, x11, #12 + b WriteEnd + Write4: + add x2, x2, #16 + st1 {v9.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v11.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v13.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v15.4s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v17.4s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v19.4s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v21.4s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.4s}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v25.4s}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v27.4s}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v29.4s}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v31.4s}, [x11], x8 + add x11, x11, #16 + b WriteEnd + + WriteEnd: + subs x13, x13, #4 // rhs col - 4 + ble LoopColEnd + cmp x6, #4 + ble LoopCol4 + cmp x6, #8 + ble LoopCol8 + b LoopCol + +LoopColEnd: + add x0, x0, x17 + mov x18, #4 + mul x18, x18, x7 + sub x11, x11, x18 + mov x2, x11 + subs x6, x6, #12 + bgt LoopRowStart + + sub sp, sp, #144 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ret +#endif diff --git a/mindspore/lite/nnacl/fp32/adder_fp32.c b/mindspore/lite/nnacl/fp32/adder_fp32.c new file mode 100644 index 00000000000..61efa9b2be0 --- /dev/null +++ b/mindspore/lite/nnacl/fp32/adder_fp32.c @@ -0,0 +1,90 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl/fp32/adder_fp32.h"
+#include <string.h>
+#include <math.h>
+#include "nnacl/fp32/common_func_fp32.h"
+#include "nnacl/fp32/matmul_fp32.h"
+
+void Adder12x4(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
+               int col, int stride) {
+  for (int r = 0; r < row; r++) {
+    for (int c = 0; c < col; c++) {
+      int r12div = r / 12, r12mod = r % 12;
+      int c4div = c / 4, c4mod = c % 4;
+      size_t ci = r * stride + c;
+      float value = 0;
+      for (int d = 0; d < deep; d++) {
+        size_t ai = r12div * deep * 12 + d * 12 + r12mod;
+        size_t bi = c4div * deep * 4 + d * 4 + c4mod;
+        value += fabsf(a[ai] - b[bi]);
+      }
+      value = -value;
+      if (bias != NULL) value += bias[c];
+      if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
+      if (act_type != ActType_No) value = MSMAX(0.0f, value);
+      dst[ci] = value;
+    }
+  }
+}
+
+void AdderOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col,
+              size_t stride) {
+#ifdef ENABLE_ARM64
+  AdderFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride);
+#else
+  Adder12x4(a, b, c, bias, act_type, deep, row, col, stride);
+#endif
+}
+
+void AdderFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data,
+               float *col_major_input, float *output_data, int task_id, ConvParameter *conv_param) {
+  int out_channel = conv_param->output_channel_;
+  int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
+  int output_count = conv_param->output_h_ * conv_param->output_w_;
+#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
+  const int cal_num = C4NUM;
+#else
+  const int cal_num = C12NUM;
+#endif
+  int output_tile_count = UP_DIV(output_count, cal_num);
+
+  for (int b = 0; b < conv_param->input_batch_; b++) {
+    int in_batch_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
+    int out_batch_offset = b * out_channel * output_count;
+    for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
+      int start_index = thread_id * cal_num;
+      int real_cal_num = (output_count - start_index) < cal_num ? (output_count - start_index) : cal_num;
+      float *gemm_input = packed_input + task_id * deep * cal_num;
+      float *col_major_gemm_input = col_major_input + task_id * deep * cal_num;
+      size_t packed_input_size = deep * cal_num * sizeof(float);
+      memset(gemm_input, 0, packed_input_size);
+      memset(col_major_gemm_input, 0, packed_input_size);
+      Im2ColPackUnitFp32(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index);
+
+      int out_offset = thread_id * cal_num * out_channel + out_batch_offset;
+      float *gemm_output = output_data + out_offset;
+#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
+      RowMajor2Col4Major(gemm_input, col_major_gemm_input, cal_num, deep);
+#else
+      RowMajor2Col12Major(gemm_input, col_major_gemm_input, cal_num, deep);
+#endif
+      AdderOpt(col_major_gemm_input, packed_weight, gemm_output, bias_data, conv_param->act_type_, deep, real_cal_num,
+               out_channel, out_channel);
+    }
+  }
+}
diff --git a/mindspore/lite/nnacl/fp32/adder_fp32.h b/mindspore/lite/nnacl/fp32/adder_fp32.h
new file mode 100644
index 00000000000..fd3956ac085
--- /dev/null
+++ b/mindspore/lite/nnacl/fp32/adder_fp32.h
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_NNACL_FP32_ADDER_H_ +#define MINDSPORE_LITE_NNACL_FP32_ADDER_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "nnacl/pack.h" +#include "nnacl/op_base.h" +#include "nnacl/common_func.h" +#include "nnacl/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#ifdef ENABLE_ARM64 +void AdderFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, + int col, size_t stride); +#endif + +void AdderOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col, + size_t stride); + +void AdderFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data, + float *col_major_input, float *output_data, int task_id, ConvParameter *conv_param); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_LITE_NNACL_FP32_ADDER_H_ diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index fe37c527807..5c32d7642a9 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -207,6 +207,26 @@ table Conv2D { activationType: ActivationType = 0; } +table Adder { + format: Format = 0; + group: int; + channelIn: int; + channelOut: int; + kernelW: int; + kernelH: int; + strideW: int; + strideH: int; + padMode: PadMode; + padUp: int; + padDown: int; + padLeft: int; + padRight: int; + dilateW: int; + dilateH: int; + hasBias: bool = false; + activationType: ActivationType = 0; +} + table Conv2DGradFilter { format: Format = 0; group: int; @@ -1176,6 +1196,3 @@ table All { table Assert { summarize : int; } - -table Adder { -} diff --git a/mindspore/lite/src/ops/adder.cc b/mindspore/lite/src/ops/adder.cc index 1a4d1c1a2a2..720b75762d5 100644 --- a/mindspore/lite/src/ops/adder.cc +++ b/mindspore/lite/src/ops/adder.cc @@ -15,6 +15,14 @@ */ #include "src/ops/adder.h" +#include +#include + +#include "include/errorcode.h" +#include "src/common/log_adapter.h" +#ifdef PRIMITIVE_WRITEABLE +#include "tools/converter/quantizer/quantize_util.h" +#endif #ifndef PRIMITIVE_WRITEABLE #include "src/ops/ops_register.h" @@ -23,6 +31,118 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE +int Adder::GetFormat() const { return this->primitive_->value.AsAdder()->format; } +int Adder::GetGroup() const { return this->primitive_->value.AsAdder()->group; } +int Adder::GetChannelIn() const { return this->primitive_->value.AsAdder()->channelIn; } +int Adder::GetChannelOut() const { return this->primitive_->value.AsAdder()->channelOut; } +int Adder::GetKernelW() const { return this->primitive_->value.AsAdder()->kernelW; } +int Adder::GetKernelH() const { return this->primitive_->value.AsAdder()->kernelH; } +int Adder::GetStrideW() const { return this->primitive_->value.AsAdder()->strideW; } +int Adder::GetStrideH() const { return this->primitive_->value.AsAdder()->strideH; } +int Adder::GetPadMode() const { return this->primitive_->value.AsAdder()->padMode; } +int Adder::GetPadUp() const { return this->primitive_->value.AsAdder()->padUp; } +int Adder::GetPadDown() const { return 
this->primitive_->value.AsAdder()->padDown; }
+int Adder::GetPadLeft() const { return this->primitive_->value.AsAdder()->padLeft; }
+int Adder::GetPadRight() const { return this->primitive_->value.AsAdder()->padRight; }
+int Adder::GetDilateW() const { return this->primitive_->value.AsAdder()->dilateW; }
+int Adder::GetDilateH() const { return this->primitive_->value.AsAdder()->dilateH; }
+bool Adder::GetHasBias() const { return this->primitive_->value.AsAdder()->hasBias; }
+int Adder::GetActivationType() const { return this->primitive_->value.AsAdder()->activationType; }
+
+void Adder::SetFormat(int format) { this->primitive_->value.AsAdder()->format = (schema::Format)format; }
+void Adder::SetGroup(int group) { this->primitive_->value.AsAdder()->group = group; }
+void Adder::SetChannelIn(int channel_in) { this->primitive_->value.AsAdder()->channelIn = channel_in; }
+void Adder::SetChannelOut(int channel_out) { this->primitive_->value.AsAdder()->channelOut = channel_out; }
+void Adder::SetKernelW(int kernel_w) { this->primitive_->value.AsAdder()->kernelW = kernel_w; }
+void Adder::SetKernelH(int kernel_h) { this->primitive_->value.AsAdder()->kernelH = kernel_h; }
+void Adder::SetStrideW(int stride_w) { this->primitive_->value.AsAdder()->strideW = stride_w; }
+void Adder::SetStrideH(int stride_h) { this->primitive_->value.AsAdder()->strideH = stride_h; }
+void Adder::SetPadMode(int pad_mode) { this->primitive_->value.AsAdder()->padMode = (schema::PadMode)pad_mode; }
+void Adder::SetPadUp(int pad_up) { this->primitive_->value.AsAdder()->padUp = pad_up; }
+void Adder::SetPadDown(int pad_down) { this->primitive_->value.AsAdder()->padDown = pad_down; }
+void Adder::SetPadLeft(int pad_left) { this->primitive_->value.AsAdder()->padLeft = pad_left; }
+void Adder::SetPadRight(int pad_right) { this->primitive_->value.AsAdder()->padRight = pad_right; }
+void Adder::SetDilateW(int dilate_w) { this->primitive_->value.AsAdder()->dilateW = dilate_w; }
+void Adder::SetDilateH(int dilate_h) { this->primitive_->value.AsAdder()->dilateH = dilate_h; }
+void Adder::SetHasBias(bool has_bias) { this->primitive_->value.AsAdder()->hasBias = has_bias; }
+void Adder::SetActivationType(int activation_type) {
+  this->primitive_->value.AsAdder()->activationType = (schema::ActivationType)activation_type;
+}
+
+void Adder::PopulaterAdderSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group) {
+  auto attr = std::make_unique<schema::AdderT>();
+  attr->group = group;
+  auto format = GetValue<std::string>(prim.GetAttr("data_format"));
+  if (format == "NCHW") {
+    attr->format = schema::Format::Format_NCHW;
+  } else if (format == "NHWC") {
+    attr->format = schema::Format::Format_NHWC;
+  } else {
+    attr->format = schema::Format::Format_NUM_OF_FORMAT;
+  }
+  auto pad_list = CastToInt(prim.GetAttr("pad_list"));
+  attr->padUp = pad_list[0];
+  attr->padDown = pad_list[1];
+  attr->padLeft = pad_list[2];
+  attr->padRight = pad_list[3];
+
+  auto dilation = CastToInt(prim.GetAttr("dilation"));
+  attr->dilateH = dilation[2];
+  attr->dilateW = dilation[3];
+
+  auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
+  attr->kernelH = kernel_size[0];
+  attr->kernelW = kernel_size[1];
+
+  auto stride = CastToInt(prim.GetAttr("stride"));
+  attr->strideH = stride[2];
+  attr->strideW = stride[3];
+
+  attr->channelOut = CastToInt(prim.GetAttr("out_channel")).front();
+
+  auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
+  if (pad_mode == "valid") {
+    attr->padMode = schema::PadMode_VALID;
+  } else if (pad_mode == "same") {
+    attr->padMode
= schema::PadMode_SAME_UPPER; + } else { + attr->padMode = schema::PadMode_NOTSET; + } + + if (prim.GetAttr("activation_name") != nullptr) { + auto activate_name = GetValue(prim.GetAttr("activation_name")); + attr->activationType = kActivationTypeMap[activate_name]; + } else { + attr->activationType = schema::ActivationType_NO_ACTIVATION; + } + + primitive->value.type = schema::PrimitiveType_Adder; + primitive->value.value = attr.release(); +} + +int Adder::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_Adder; + } + if (this->primitive_->value.type != schema::PrimitiveType_Adder) { + MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type; + return RET_ERROR; + } + auto groupAttr = prim.GetAttr("group"); + if (groupAttr == nullptr) { + MS_LOG(ERROR) << "conv2d op has no group attr,please check pb model"; + return RET_NULL_PTR; + } + int group = CastToInt(groupAttr).front(); + PopulaterAdderSingleGroup(prim, this->primitive_, group); + PopulaterQuantParam(prim, inputs); + return RET_OK; +} #else int Adder::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { @@ -33,39 +153,36 @@ int Adder::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:: MS_LOG(ERROR) << "value_as_Adder return nullptr"; return RET_ERROR; } - auto val_offset = schema::CreateAdder(*fbb); + + auto val_offset = schema::CreateAdder( + *fbb, attr->format(), attr->group(), attr->channelIn(), attr->channelOut(), attr->kernelW(), attr->kernelH(), + attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(), + attr->padRight(), attr->dilateW(), attr->dilateH(), attr->hasBias(), attr->activationType()); auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Adder, val_offset.o); fbb->Finish(prim_offset); return RET_OK; } +int Adder::GetFormat() const { return this->primitive_->value_as_Adder()->format(); } +int Adder::GetGroup() const { return this->primitive_->value_as_Adder()->group(); } +int Adder::GetChannelIn() const { return this->primitive_->value_as_Adder()->channelIn(); } +int Adder::GetChannelOut() const { return this->primitive_->value_as_Adder()->channelOut(); } +int Adder::GetKernelW() const { return this->primitive_->value_as_Adder()->kernelW(); } +int Adder::GetKernelH() const { return this->primitive_->value_as_Adder()->kernelH(); } +int Adder::GetStrideW() const { return this->primitive_->value_as_Adder()->strideW(); } +int Adder::GetStrideH() const { return this->primitive_->value_as_Adder()->strideH(); } +int Adder::GetPadMode() const { return this->primitive_->value_as_Adder()->padMode(); } +int Adder::GetPadUp() const { return this->primitive_->value_as_Adder()->padUp(); } +int Adder::GetPadDown() const { return this->primitive_->value_as_Adder()->padDown(); } +int Adder::GetPadLeft() const { return this->primitive_->value_as_Adder()->padLeft(); } +int Adder::GetPadRight() const { return this->primitive_->value_as_Adder()->padRight(); } +int Adder::GetDilateW() const { return this->primitive_->value_as_Adder()->dilateW(); } +int Adder::GetDilateH() const { return this->primitive_->value_as_Adder()->dilateH(); } +bool Adder::GetHasBias() const { return this->primitive_->value_as_Adder()->hasBias(); } +int 
Adder::GetActivationType() const { return this->primitive_->value_as_Adder()->activationType(); } + PrimitiveC *AdderCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } Registry AdderRegistry(schema::PrimitiveType_Adder, AdderCreator); #endif - -int Adder::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive_ != nullptr); - MS_ASSERT(inputs_.size() == 2); - auto input0 = inputs_.front(); - MS_ASSERT(input0 != nullptr); - MS_ASSERT(input0->shape().size() == 2); - auto input1 = inputs_.at(1); - MS_ASSERT(input1 != nullptr); - MS_ASSERT(input1->shape().size() == 2); - auto output = outputs_.front(); - MS_ASSERT(output != nullptr); - - output->set_data_type(input0->data_type()); - output->set_format(input0->format()); - if (!infer_flag()) { - return RET_OK; - } - std::vector in_shape; - in_shape.push_back(input0->shape().at(0)); - in_shape.push_back(input1->shape().at(1)); - output->set_shape(in_shape); - - return RET_OK; -} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/adder.h b/mindspore/lite/src/ops/adder.h index 486572e5f2c..cab34abef31 100644 --- a/mindspore/lite/src/ops/adder.h +++ b/mindspore/lite/src/ops/adder.h @@ -20,22 +20,62 @@ #include #include #include -#include "src/ops/primitive_c.h" +#include +#include "src/ops/conv2d.h" namespace mindspore { namespace lite { -class Adder : public PrimitiveC { +class Adder : public Conv2D { public: Adder() = default; ~Adder() = default; #ifdef PRIMITIVE_WRITEABLE - MS_DECLARE_PARENT(Adder, PrimitiveC); - explicit Adder(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + MS_DECLARE_PARENT(Adder, Conv2D); + explicit Adder(schema::PrimitiveT *primitive) : Conv2D(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; + void SetFormat(int format); + void SetGroup(int group); + void SetChannelIn(int channel_in); + void SetChannelOut(int channel_out); + void SetKernelW(int kernel_w); + void SetKernelH(int kernel_h); + void SetStrideW(int stride_w); + void SetStrideH(int stride_h); + void SetPadMode(int pad_mode); + void SetPadUp(int pad_up); + void SetPadDown(int pad_down); + void SetPadLeft(int pad_left); + void SetPadRight(int pad_right); + void SetDilateW(int dilate_w); + void SetDilateH(int dilate_h); + void SetHasBias(bool has_bias); + void SetActivationType(int activation_type); + + private: + void PopulaterAdderSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group); #else - int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb); #endif - int InferShape(std::vector inputs_, std::vector outputs_) override; + + public: + int GetFormat() const; + int GetGroup() const; + int GetChannelIn() const; + int GetChannelOut() const; + int GetKernelW() const; + int GetKernelH() const; + int GetStrideW() const; + int GetStrideH() const; + int GetPadMode() const; + int GetPadUp() const; + int GetPadDown() const; + int GetPadLeft() const; + int GetPadRight() const; + int GetDilateW() const; + int GetDilateH() const; + bool GetHasBias() const; + int GetActivationType() const; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/conv2d.h b/mindspore/lite/src/ops/conv2d.h index 5f424b40b25..cceda85b45a 100644 --- a/mindspore/lite/src/ops/conv2d.h +++ b/mindspore/lite/src/ops/conv2d.h @@ -34,23 +34,23 @@ 
class Conv2D : public PrimitiveC { explicit Conv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; - void SetFormat(int format); - void SetGroup(int group); - void SetChannelIn(int channel_in); - void SetChannelOut(int channel_out); - void SetKernelW(int kernel_w); - void SetKernelH(int kernel_h); - void SetStrideW(int stride_w); - void SetStrideH(int stride_h); - void SetPadMode(int pad_mode); - void SetPadUp(int pad_up); - void SetPadDown(int pad_down); - void SetPadLeft(int pad_left); - void SetPadRight(int pad_right); - void SetDilateW(int dilate_w); - void SetDilateH(int dilate_h); - void SetHasBias(bool has_bias); - void SetActivationType(int activation_type); + virtual void SetFormat(int format); + virtual void SetGroup(int group); + virtual void SetChannelIn(int channel_in); + virtual void SetChannelOut(int channel_out); + virtual void SetKernelW(int kernel_w); + virtual void SetKernelH(int kernel_h); + virtual void SetStrideW(int stride_w); + virtual void SetStrideH(int stride_h); + virtual void SetPadMode(int pad_mode); + virtual void SetPadUp(int pad_up); + virtual void SetPadDown(int pad_down); + virtual void SetPadLeft(int pad_left); + virtual void SetPadRight(int pad_right); + virtual void SetDilateW(int dilate_w); + virtual void SetDilateH(int dilate_h); + virtual void SetHasBias(bool has_bias); + virtual void SetActivationType(int activation_type); private: void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group, @@ -67,23 +67,23 @@ class Conv2D : public PrimitiveC { int PadLeft() const; int PadRight() const; - int GetFormat() const; - int GetGroup() const; - int GetChannelIn() const; - int GetChannelOut() const; - int GetKernelW() const; - int GetKernelH() const; - int GetStrideW() const; - int GetStrideH() const; - int GetPadMode() const; - int GetPadUp() const; - int GetPadDown() const; - int GetPadLeft() const; - int GetPadRight() const; - int GetDilateW() const; - int GetDilateH() const; - bool GetHasBias() const; - int GetActivationType() const; + virtual int GetFormat() const; + virtual int GetGroup() const; + virtual int GetChannelIn() const; + virtual int GetChannelOut() const; + virtual int GetKernelW() const; + virtual int GetKernelH() const; + virtual int GetStrideW() const; + virtual int GetStrideH() const; + virtual int GetPadMode() const; + virtual int GetPadUp() const; + virtual int GetPadDown() const; + virtual int GetPadLeft() const; + virtual int GetPadRight() const; + virtual int GetDilateW() const; + virtual int GetDilateH() const; + virtual bool GetHasBias() const; + virtual int GetActivationType() const; protected: void ConvInferShape(int input_h, int input_w, int *output_h, int *output_w); diff --git a/mindspore/lite/src/ops/populate/adder_populate.cc b/mindspore/lite/src/ops/populate/adder_populate.cc index aceca8ce11d..59ab043381d 100644 --- a/mindspore/lite/src/ops/populate/adder_populate.cc +++ b/mindspore/lite/src/ops/populate/adder_populate.cc @@ -15,24 +15,53 @@ */ #include "src/ops/adder.h" +#include "src/common/log_adapter.h" +#include "nnacl/conv_parameter.h" #include "src/ops/primitive_c.h" #include "src/ops/populate/populate_register.h" -#include "nnacl/adder.h" namespace mindspore { namespace lite { - OpParameter *PopulateAdderParameter(const mindspore::lite::PrimitiveC *primitive) { - auto *adder_param = reinterpret_cast(malloc(sizeof(AdderParameter))); - if (adder_param == nullptr) { - MS_LOG(ERROR) << 
"malloc AdderParameter failed."; + ConvParameter *conv_param = reinterpret_cast(malloc(sizeof(ConvParameter))); + if (conv_param == nullptr) { + MS_LOG(ERROR) << "malloc ConvParameter failed."; return nullptr; } - memset(adder_param, 0, sizeof(AdderParameter)); - adder_param->op_parameter_.type_ = primitive->Type(); - return reinterpret_cast(adder_param); + memset(conv_param, 0, sizeof(ConvParameter)); + conv_param->op_parameter_.type_ = primitive->Type(); + auto adder_primitive = + reinterpret_cast(const_cast(primitive)); + conv_param->kernel_h_ = adder_primitive->GetKernelH(); + conv_param->kernel_w_ = adder_primitive->GetKernelW(); + conv_param->group_ = adder_primitive->GetGroup(); + conv_param->stride_h_ = adder_primitive->GetStrideH(); + conv_param->stride_w_ = adder_primitive->GetStrideW(); + + auto adder_lite_primitive = (lite::Adder *)primitive; + conv_param->pad_u_ = adder_lite_primitive->PadUp(); + conv_param->pad_d_ = adder_lite_primitive->PadDown(); + conv_param->pad_l_ = adder_lite_primitive->PadLeft(); + conv_param->pad_r_ = adder_lite_primitive->PadRight(); + conv_param->dilation_h_ = adder_primitive->GetDilateH(); + conv_param->dilation_w_ = adder_primitive->GetDilateW(); + conv_param->input_channel_ = adder_primitive->GetChannelIn(); + conv_param->output_channel_ = adder_primitive->GetChannelOut(); + conv_param->group_ = adder_primitive->GetGroup(); + auto act_type = adder_primitive->GetActivationType(); + switch (act_type) { + case schema::ActivationType_RELU: + conv_param->act_type_ = ActType_Relu; + break; + case schema::ActivationType_RELU6: + conv_param->act_type_ = ActType_Relu6; + break; + default: + conv_param->act_type_ = ActType_No; + break; + } + return reinterpret_cast(conv_param); } Registry AdderParameterRegistry(schema::PrimitiveType_Adder, PopulateAdderParameter); - } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/adder_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/adder_fp32.cc new file mode 100644 index 00000000000..15e90fc16ac --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/adder_fp32.cc @@ -0,0 +1,133 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "src/runtime/kernel/arm/fp32/adder_fp32.h"
+#include "src/kernel_registry.h"
+#include "src/runtime/runtime_api.h"
+#include "include/errorcode.h"
+#include "schema/model_generated.h"
+#include "nnacl/fp32/adder_fp32.h"
+#include "nnacl/fp32/matmul_fp32.h"
+
+using mindspore::kernel::KERNEL_ARCH::kCPU;
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_INFER_INVALID;
+using mindspore::lite::RET_OK;
+using mindspore::schema::PrimitiveType_Adder;
+using mindspore::schema::Format::Format_NHWC;
+
+namespace mindspore::kernel {
+int AdderCPUKernel::InitWeightBias() {
+  auto filter_tensor = in_tensors_.at(kWeightIndex);
+  int kernel_h = filter_tensor->Height();
+  int kernel_w = filter_tensor->Width();
+  int in_channel = filter_tensor->Channel();
+  int out_channel = filter_tensor->Batch();
+  conv_param_->input_channel_ = in_channel;
+  conv_param_->output_channel_ = out_channel;
+  int kernel_plane = kernel_h * kernel_w;
+  const int oc_block = C4NUM;
+  int oc_block_num = UP_DIV(out_channel, C4NUM);
+  int pack_weight_size = oc_block_num * oc_block * in_channel * kernel_plane;
+
+  auto origin_weight = reinterpret_cast<float *>(filter_tensor->MutableData());
+  packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float)));
+  if (packed_weight_ == nullptr) {
+    MS_LOG(ERROR) << "malloc packed weight failed.";
+    return RET_ERROR;
+  }
+  memset(packed_weight_, 0, pack_weight_size * sizeof(float));
+  RowMajor2Col4Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane);
+
+  bias_data_ = reinterpret_cast<float *>(malloc(oc_block_num * oc_block * sizeof(float)));
+  if (bias_data_ == nullptr) {
+    MS_LOG(ERROR) << "malloc bias failed.";
+    return RET_ERROR;
+  }
+  memset(bias_data_, 0, oc_block_num * oc_block * sizeof(float));
+
+  if (in_tensors_.size() == kInputSize2) {
+    auto ori_bias = reinterpret_cast<float *>(in_tensors_.at(kBiasIndex)->MutableData());
+    memcpy(bias_data_, ori_bias, out_channel * sizeof(float));
+  } else {
+    MS_ASSERT(in_tensors_.size() == kInputSize1);
+  }
+  return RET_OK;
+}
+
+int AdderCPUKernel::RunImpl(int task_id) {
+  auto input_tensor = in_tensors_.at(kInputIndex);
+  auto ori_input_data = reinterpret_cast<float *>(input_tensor->data_c());
+  auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->data_c());
+  AdderFp32(ori_input_data, packed_input_, packed_weight_, reinterpret_cast<float *>(bias_data_), col_major_input_,
+            output_addr, task_id, conv_param_);
+  return RET_OK;
+}
+
+int AdderImpl(void *cdata, int task_id) {
+  auto adder = reinterpret_cast<AdderCPUKernel *>(cdata);
+  auto error_code = adder->RunImpl(task_id);
+  if (error_code != RET_OK) {
+    MS_LOG(ERROR) << "Adder Run error task_id[" << task_id << "] error_code[" << error_code << "]";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+int AdderCPUKernel::Run() {
+  auto ret = InitTmpBuffer();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Init tmp buffer failed.";
+    return RET_ERROR;
+  }
+
+  int error_code = ParallelLaunch(this->context_->thread_pool_, AdderImpl, this, thread_count_);
+  if (error_code != RET_OK) {
+    MS_LOG(ERROR) << "adder error error_code[" << error_code << "]";
+    FreeTmpBuffer();
+    return RET_ERROR;
+  }
+  FreeTmpBuffer();
+  return RET_OK;
+}
+
+kernel::LiteKernel *CpuAdderFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
+                                              const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
+                                              const InnerContext *ctx, const kernel::KernelKey &desc,
+                                              const mindspore::lite::PrimitiveC *primitive) {
+  MS_ASSERT(op_parameter != nullptr);
+  MS_ASSERT(desc.type ==
schema::PrimitiveType_Adder); + MS_ASSERT(desc.data_type == kNumberTypeFloat32); + kernel::LiteKernel *kernel = new (std::nothrow) kernel::AdderCPUKernel(op_parameter, inputs, outputs, ctx, primitive); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + free(op_parameter); + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK && ret != RET_INFER_INVALID) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(op_parameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Adder, CpuAdderFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/adder_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/adder_fp32.h new file mode 100644 index 00000000000..7f2b8c4363b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/adder_fp32.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ADDER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ADDER_H_ + +#include +#include "src/lite_kernel.h" +#include "nnacl/op_base.h" +#include "src/runtime/kernel/arm/fp32/convolution_fp32.h" +#include "nnacl/fp32/conv_fp32.h" + +namespace mindspore::kernel { +class AdderCPUKernel : public ConvolutionCPUKernel { + public: + AdderCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : ConvolutionCPUKernel(parameter, inputs, outputs, ctx, primitive) {} + ~AdderCPUKernel() override = default; + + int InitWeightBias() override; + int Run() override; + int RunImpl(int task_id) override; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ADDER_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.h index f4879c11d00..7dc6ead1594 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.h @@ -38,13 +38,13 @@ class ConvolutionCPUKernel : public ConvolutionBaseCPUKernel { } int Init() override; + virtual int InitWeightBias(); + int InitTmpBuffer(); int ReSize() override; int Run() override; - int RunImpl(int task_id); - int InitWeightBias(); - int InitTmpBuffer(); + virtual int RunImpl(int task_id); - private: + protected: void FreeTmpBuffer() { if (packed_input_ != nullptr) { ctx_->allocator->Free(packed_input_); @@ -55,8 +55,10 @@ class ConvolutionCPUKernel : public ConvolutionBaseCPUKernel { col_major_input_ = nullptr; } } - float *packed_input_ = nullptr; + + protected: float *packed_weight_ = nullptr; + float *packed_input_ = nullptr; float *col_major_input_ = nullptr; }; } // namespace mindspore::kernel
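
Note on the math the new kernels implement: unlike a convolution's inner product, the adder kernels accumulate a negated L1 distance, out[r][c] = -(sum over d of |a[r][d] - b[d][c]|), then add the bias and apply ReLU/ReLU6. The standalone sketch below is not part of this patch; the names, sizes, and packing loops are illustrative assumptions. It cross-checks that the r12div/r12mod and c4div/c4mod indexing used by Adder12x4 over 12-row / 4-column packed blocks reproduces a plain row-major reference.

/* Standalone sketch (not part of the patch): checks the negated-L1 "adder"
 * semantics and the 12x4 packed indexing used by Adder12x4 above.
 * Assumption: lhs is packed in blocks of 12 rows and rhs in blocks of 4
 * columns, depth-major inside each block, matching the C kernel's indexing. */
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#define ROW_TILE 12
#define COL_TILE 4

/* Plain reference: a is row-major [row][deep], b is row-major [deep][col]. */
static void adder_ref(const float *a, const float *b, float *dst, int deep, int row, int col) {
  for (int r = 0; r < row; ++r) {
    for (int c = 0; c < col; ++c) {
      float acc = 0.0f;
      for (int d = 0; d < deep; ++d) acc += fabsf(a[r * deep + d] - b[d * col + c]);
      dst[r * col + c] = -acc; /* AdderNet: negated L1 distance; bias/activation omitted here */
    }
  }
}

/* Packed variant using the same index arithmetic as Adder12x4. */
static void adder_packed(const float *pa, const float *pb, float *dst, int deep, int row, int col) {
  for (int r = 0; r < row; ++r) {
    for (int c = 0; c < col; ++c) {
      int r12div = r / ROW_TILE, r12mod = r % ROW_TILE;
      int c4div = c / COL_TILE, c4mod = c % COL_TILE;
      float acc = 0.0f;
      for (int d = 0; d < deep; ++d) {
        size_t ai = (size_t)r12div * deep * ROW_TILE + d * ROW_TILE + r12mod;
        size_t bi = (size_t)c4div * deep * COL_TILE + d * COL_TILE + c4mod;
        acc += fabsf(pa[ai] - pb[bi]);
      }
      dst[r * col + c] = -acc;
    }
  }
}

int main(void) {
  const int row = 24, col = 8, deep = 5; /* multiples of 12 and 4 keep the packing exact */
  float *a = malloc(sizeof(float) * row * deep);
  float *b = malloc(sizeof(float) * deep * col);
  float *pa = malloc(sizeof(float) * row * deep);
  float *pb = malloc(sizeof(float) * deep * col);
  float *ref = malloc(sizeof(float) * row * col);
  float *out = malloc(sizeof(float) * row * col);
  for (int i = 0; i < row * deep; ++i) a[i] = (float)(i % 7) - 3.0f;
  for (int i = 0; i < deep * col; ++i) b[i] = (float)(i % 5) - 2.0f;
  /* Pack lhs: 12-row blocks, depth-major inside each block. */
  for (int r = 0; r < row; ++r)
    for (int d = 0; d < deep; ++d)
      pa[(r / ROW_TILE) * deep * ROW_TILE + d * ROW_TILE + (r % ROW_TILE)] = a[r * deep + d];
  /* Pack rhs: 4-column blocks, depth-major inside each block. */
  for (int d = 0; d < deep; ++d)
    for (int c = 0; c < col; ++c)
      pb[(c / COL_TILE) * deep * COL_TILE + d * COL_TILE + (c % COL_TILE)] = b[d * col + c];
  adder_ref(a, b, ref, deep, row, col);
  adder_packed(pa, pb, out, deep, row, col);
  float max_diff = 0.0f;
  for (int i = 0; i < row * col; ++i) max_diff = fmaxf(max_diff, fabsf(ref[i] - out[i]));
  printf("max |ref - packed| = %g\n", max_diff);
  free(a); free(b); free(pa); free(pb); free(ref); free(out);
  return 0;
}

Compiled with something like `gcc -O2 adder_sketch.c -lm`, the sketch should report a zero maximum difference for tile-aligned shapes, which is the property the assembly tiles (12x4, 8x4, 4x4) rely on.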