forked from mindspore-Ecosystem/mindspore
!9210 [MSLITE][Develop] optimize arm cpu fp32 conv depthwise: add indirect buffer
From: @yangruoqi713 Reviewed-by: Signed-off-by:
commit deb8e5f9fe
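In outline, the indirect-buffer path added here avoids im2col-style repacking: for every output pixel an array of input row pointers is prepared once (one pointer per kernel tap, with out-of-image taps pointing at a shared zero vector), and the convolution kernels then just walk that pointer array. The following is a minimal C sketch of the per-row computation, simplified from the ConvDwFp32IndirectRow fallback added below; the C4NUM channel blocking and the packed weight layout of the committed code are dropped, and the [kernel][channels] weight layout and the function name are illustrative only.

#include <string.h>

/* Sketch only: one output row of a depthwise conv through an indirection buffer.
 * input[k] points at the pixel read by kernel tap k for the current output pixel;
 * padded taps point at a zeroed vector, so no bounds checks remain in the inner loop. */
static void ConvDwIndirectRowSketch(float *output, float **input, const float *weights,
                                    const float *bias, int channels, int output_width,
                                    int input_stride, int kernel) {
  for (int ow = 0; ow < output_width; ow++) {
    memcpy(output, bias, (size_t)channels * sizeof(float));  /* start from the bias */
    for (int k = 0; k < kernel; k++) {
      for (int c = 0; c < channels; c++) {
        output[c] += input[k][c] * weights[k * channels + c];  /* simplified weight layout */
      }
    }
    output += channels;
    input += input_stride;  /* advance to the next pixel's group of tap pointers */
  }
}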
@@ -0,0 +1,146 @@
#ifdef __aarch64__

.text
.align 5
.global ConvDwFp32Indirect3x3
#ifndef __APPLE__
.type ConvDwFp32Indirect3x3, %function
#endif

// void ConvDwFp32Indirect3x3(float *output, float **input, const float *weights, const float *bias, int channels, int output_width,
//                            size_t input_stride, size_t relu, size_t relu6)
// x0: output, x1: input, x2: weights, x3: bias, x4: channels, x5: output_width, x6: input_stride, x7: relu, x8: relu6

ConvDwFp32Indirect3x3:
    sub sp, sp, #16
    stp x19, x20, [sp], #16

    movi v31.4s, #6
    scvtf v31.4s, v31.4s
    dup v30.4s, wzr

    ldr x8, [sp]
    cmp x5, #0
    beq End

LoopPixel:
    ldp x12, x13, [x1]
    ldp x14, x15, [x1, #16]
    ldp x16, x17, [x1, #32]
    ldp x18, x19, [x1, #48]
    ldr x20, [x1, #64]
    mov x9, x2
    mov x10, x3
    mov x11, x4

    ld1 {v0.4s}, [x12], #16
    ld1 {v1.4s}, [x13], #16
    ld1 {v2.4s}, [x14], #16

    ld1 {v17.4s}, [x9], #16
    ld1 {v18.4s}, [x9], #16
    ld1 {v19.4s}, [x9], #16

    ld1 {v29.4s}, [x10], #16
    cmp x11, #4
    ble LeftLoop
LoopC4:
    fmla v29.4s, v0.4s, v17.4s
    ld1 {v3.4s}, [x15], #16
    ld1 {v20.4s}, [x9], #16
    fmla v29.4s, v1.4s, v18.4s
    ld1 {v4.4s}, [x16], #16
    ld1 {v21.4s}, [x9], #16
    fmla v29.4s, v2.4s, v19.4s
    ld1 {v5.4s}, [x17], #16
    ld1 {v22.4s}, [x9], #16
    fmla v29.4s, v3.4s, v20.4s
    ld1 {v6.4s}, [x18], #16
    ld1 {v23.4s}, [x9], #16
    fmla v29.4s, v4.4s, v21.4s
    ld1 {v7.4s}, [x19], #16
    ld1 {v24.4s}, [x9], #16
    fmla v29.4s, v5.4s, v22.4s
    ld1 {v16.4s}, [x20], #16
    ld1 {v25.4s}, [x9], #16
    fmla v29.4s, v6.4s, v23.4s
    ld1 {v0.4s}, [x12], #16
    ld1 {v17.4s}, [x9], #16
    fmla v29.4s, v7.4s, v24.4s
    ld1 {v1.4s}, [x13], #16
    ld1 {v18.4s}, [x9], #16
    fmla v29.4s, v16.4s, v25.4s
    ld1 {v2.4s}, [x14], #16
    ld1 {v19.4s}, [x9], #16

    cbnz x8, Relu6
    cbnz x7, Relu
    b Write
Relu6:
    fmin v29.4s, v29.4s, v31.4s
Relu:
    fmax v29.4s, v29.4s, v30.4s
Write:
    st1 {v29.4s}, [x0], #16

    ld1 {v29.4s}, [x10], #16
    sub x11, x11, #4
    cmp x11, #4
    bgt LoopC4

LeftLoop:
    fmla v29.4s, v0.4s, v17.4s
    ld1 {v3.4s}, [x15], #16
    ld1 {v20.4s}, [x9], #16
    fmla v29.4s, v1.4s, v18.4s
    ld1 {v4.4s}, [x16], #16
    ld1 {v21.4s}, [x9], #16
    fmla v29.4s, v2.4s, v19.4s
    ld1 {v5.4s}, [x17], #16
    ld1 {v22.4s}, [x9], #16
    fmla v29.4s, v3.4s, v20.4s
    ld1 {v6.4s}, [x18], #16
    ld1 {v23.4s}, [x9], #16
    fmla v29.4s, v4.4s, v21.4s
    ld1 {v7.4s}, [x19], #16
    ld1 {v24.4s}, [x9], #16
    fmla v29.4s, v5.4s, v22.4s
    ld1 {v16.4s}, [x20], #16
    ld1 {v25.4s}, [x9], #16
    fmla v29.4s, v6.4s, v23.4s
    fmla v29.4s, v7.4s, v24.4s
    fmla v29.4s, v16.4s, v25.4s

    cbnz x8, LeftRelu6
    cbnz x7, LeftRelu
    b LeftWrite
LeftRelu6:
    fmin v29.4s, v29.4s, v31.4s
LeftRelu:
    fmax v29.4s, v29.4s, v30.4s
LeftWrite:
    cmp x11, #4
    bne Write3
    st1 {v29.4s}, [x0], #16
    b NextPixel
Write3:
    sxtw x11, w11
    tbnz w11, #1, Write2
    tbnz w11, #0, Write1
Write2:
    str d29, [x0], #8
    ext v29.16b, v29.16b, v29.16b, #8
    tbz w11, #0, NextPixel
Write1:
    str s29, [x0], #4

NextPixel:
    add x1, x1, x6
    sub x5, x5, #1
    cmp x5, #0
    bgt LoopPixel
End:
    sub sp, sp, #16
    ldp x19, x20, [sp], #16
    ret
#endif
@@ -0,0 +1,283 @@
#ifdef __aarch64__

.text
.align 5
.global ConvDwFp32Indirect5x5
#ifndef __APPLE__
.type ConvDwFp32Indirect5x5, %function
#endif

// void ConvDwFp32Indirect5x5(float *output, float **input, const float *weights, const float *bias, int channels, int output_width,
//                            size_t input_stride, size_t relu, size_t relu6)
// x0: output, x1: input, x2: weights, x3: bias, x4: channels, x5: output_width, x6: input_stride, x7: relu, x8: relu6

ConvDwFp32Indirect5x5:
    sub sp, sp, #160
    stp x19, x20, [sp, #64]
    stp x21, x22, [sp, #80]
    stp x23, x24, [sp, #96]
    stp x25, x26, [sp, #112]
    stp x27, x28, [sp, #128]
    stp x29, x30, [sp, #144]
    ldrb w8, [sp, #160]
    stp x2, x3, [sp]
    stp x4, x6, [sp, #16]
    stp x7, x8, [sp, #32]

    movi v31.4s, #6
    scvtf v31.4s, v31.4s
    dup v30.4s, wzr

    mov x3, x5
    cmp x3, #0
    beq End

LoopPixel:
    ldp x5, x4, [sp] // weight, bias
    ld1 {v29.4s}, [x4], #16
    ldr x2, [sp, #16] // channel

    ldp x6, x7, [x1]
    ldp x8, x9, [x1, #16]
    ldp x10, x11, [x1, #32]
    ldp x12, x13, [x1, #48]
    ldp x14, x15, [x1, #64]
    ldp x16, x17, [x1, #80]
    ldp x18, x19, [x1, #96]
    ldp x20, x21, [x1, #112]
    ldp x22, x23, [x1, #128]
    ldp x24, x25, [x1, #144]
    ldp x26, x27, [x1, #160]
    ldp x28, x29, [x1, #176]
    ldr x30, [x1, #192]

    ld1 {v0.4s}, [x6], #16
    ld1 {v1.4s}, [x7], #16
    ld1 {v2.4s}, [x8], #16
    ld1 {v3.4s}, [x9], #16
    ld1 {v4.4s}, [x10], #16

    ld1 {v18.4s}, [x5], #16
    ld1 {v19.4s}, [x5], #16
    ld1 {v20.4s}, [x5], #16
    ld1 {v21.4s}, [x5], #16
    ld1 {v22.4s}, [x5], #16
    stp x5, x4, [sp, #48]

    cmp x2, #4
    ble LeftLoop
LoopC4:
    ldr x5, [sp, #48]
    // column 0
    fmla v29.4s, v0.4s, v18.4s
    ld1 {v5.4s}, [x11], #16
    ld1 {v23.4s}, [x5], #16
    fmla v29.4s, v1.4s, v19.4s
    ld1 {v6.4s}, [x12], #16
    ld1 {v24.4s}, [x5], #16
    fmla v29.4s, v2.4s, v20.4s
    ld1 {v7.4s}, [x13], #16
    ld1 {v25.4s}, [x5], #16
    fmla v29.4s, v3.4s, v21.4s
    ld1 {v16.4s}, [x14], #16
    ld1 {v26.4s}, [x5], #16
    fmla v29.4s, v4.4s, v22.4s
    ld1 {v17.4s}, [x15], #16
    ld1 {v27.4s}, [x5], #16
    // column 1
    fmla v29.4s, v5.4s, v23.4s
    ld1 {v0.4s}, [x16], #16
    ld1 {v18.4s}, [x5], #16
    fmla v29.4s, v6.4s, v24.4s
    ld1 {v1.4s}, [x17], #16
    ld1 {v19.4s}, [x5], #16
    fmla v29.4s, v7.4s, v25.4s
    ld1 {v2.4s}, [x18], #16
    ld1 {v20.4s}, [x5], #16
    fmla v29.4s, v16.4s, v26.4s
    ld1 {v3.4s}, [x19], #16
    ld1 {v21.4s}, [x5], #16
    fmla v29.4s, v17.4s, v27.4s
    ld1 {v4.4s}, [x20], #16
    ld1 {v22.4s}, [x5], #16
    // column 2
    fmla v29.4s, v0.4s, v18.4s
    ld1 {v5.4s}, [x21], #16
    ld1 {v23.4s}, [x5], #16
    fmla v29.4s, v1.4s, v19.4s
    ld1 {v6.4s}, [x22], #16
    ld1 {v24.4s}, [x5], #16
    fmla v29.4s, v2.4s, v20.4s
    ld1 {v7.4s}, [x23], #16
    ld1 {v25.4s}, [x5], #16
    fmla v29.4s, v3.4s, v21.4s
    ld1 {v16.4s}, [x24], #16
    ld1 {v26.4s}, [x5], #16
    fmla v29.4s, v4.4s, v22.4s
    ld1 {v17.4s}, [x25], #16
    ld1 {v27.4s}, [x5], #16
    // column 3
    fmla v29.4s, v5.4s, v23.4s
    ld1 {v0.4s}, [x26], #16
    ld1 {v18.4s}, [x5], #16
    fmla v29.4s, v6.4s, v24.4s
    ld1 {v1.4s}, [x27], #16
    ld1 {v19.4s}, [x5], #16
    fmla v29.4s, v7.4s, v25.4s
    ld1 {v2.4s}, [x28], #16
    ld1 {v20.4s}, [x5], #16
    fmla v29.4s, v16.4s, v26.4s
    ld1 {v3.4s}, [x29], #16
    ld1 {v21.4s}, [x5], #16
    fmla v29.4s, v17.4s, v27.4s
    ld1 {v4.4s}, [x30], #16
    ld1 {v22.4s}, [x5], #16
    // column 4
    fmla v29.4s, v0.4s, v18.4s
    fmla v29.4s, v1.4s, v19.4s
    ld1 {v0.4s}, [x6], #16
    ld1 {v18.4s}, [x5], #16
    fmla v29.4s, v2.4s, v20.4s
    ld1 {v1.4s}, [x7], #16
    ld1 {v19.4s}, [x5], #16
    fmla v29.4s, v3.4s, v21.4s
    ld1 {v2.4s}, [x8], #16
    ld1 {v20.4s}, [x5], #16
    fmla v29.4s, v4.4s, v22.4s
    ld1 {v3.4s}, [x9], #16
    ld1 {v21.4s}, [x5], #16
    ld1 {v4.4s}, [x10], #16
    ld1 {v22.4s}, [x5], #16
    str x5, [sp, #48]

    ldp x4, x5, [sp, #32]
    cbnz x5, RELU6
    cbnz x4, RELU
    b WRITE
RELU6:
    fmin v29.4s, v29.4s, v31.4s
RELU:
    fmax v29.4s, v29.4s, v30.4s
WRITE:
    st1 {v29.4s}, [x0], #16

    ldr x4, [sp, #56]
    ld1 {v29.4s}, [x4], #16
    str x4, [sp, #56]
    sub x2, x2, #4
    cmp x2, #4
    bgt LoopC4

LeftLoop:
    // column 0
    ldr x5, [sp, #48]
    fmla v29.4s, v0.4s, v18.4s
    ld1 {v5.4s}, [x11], #16
    ld1 {v23.4s}, [x5], #16
    fmla v29.4s, v1.4s, v19.4s
    ld1 {v6.4s}, [x12], #16
    ld1 {v24.4s}, [x5], #16
    fmla v29.4s, v2.4s, v20.4s
    ld1 {v7.4s}, [x13], #16
    ld1 {v25.4s}, [x5], #16
    fmla v29.4s, v3.4s, v21.4s
    ld1 {v16.4s}, [x14], #16
    ld1 {v26.4s}, [x5], #16
    fmla v29.4s, v4.4s, v22.4s
    ld1 {v17.4s}, [x15], #16
    ld1 {v27.4s}, [x5], #16
    // column 1
    fmla v29.4s, v5.4s, v23.4s
    ld1 {v0.4s}, [x16], #16
    ld1 {v18.4s}, [x5], #16
    fmla v29.4s, v6.4s, v24.4s
    ld1 {v1.4s}, [x17], #16
    ld1 {v19.4s}, [x5], #16
    fmla v29.4s, v7.4s, v25.4s
    ld1 {v2.4s}, [x18], #16
    ld1 {v20.4s}, [x5], #16
    fmla v29.4s, v16.4s, v26.4s
    ld1 {v3.4s}, [x19], #16
    ld1 {v21.4s}, [x5], #16
    fmla v29.4s, v17.4s, v27.4s
    ld1 {v4.4s}, [x20], #16
    ld1 {v22.4s}, [x5], #16
    // column 2
    fmla v29.4s, v0.4s, v18.4s
    ld1 {v5.4s}, [x21], #16
    ld1 {v23.4s}, [x5], #16
    fmla v29.4s, v1.4s, v19.4s
    ld1 {v6.4s}, [x22], #16
    ld1 {v24.4s}, [x5], #16
    fmla v29.4s, v2.4s, v20.4s
    ld1 {v7.4s}, [x23], #16
    ld1 {v25.4s}, [x5], #16
    fmla v29.4s, v3.4s, v21.4s
    ld1 {v16.4s}, [x24], #16
    ld1 {v26.4s}, [x5], #16
    fmla v29.4s, v4.4s, v22.4s
    ld1 {v17.4s}, [x25], #16
    ld1 {v27.4s}, [x5], #16
    // column 3
    fmla v29.4s, v5.4s, v23.4s
    ld1 {v0.4s}, [x26], #16
    ld1 {v18.4s}, [x5], #16
    fmla v29.4s, v6.4s, v24.4s
    ld1 {v1.4s}, [x27], #16
    ld1 {v19.4s}, [x5], #16
    fmla v29.4s, v7.4s, v25.4s
    ld1 {v2.4s}, [x28], #16
    ld1 {v20.4s}, [x5], #16
    fmla v29.4s, v16.4s, v26.4s
    ld1 {v3.4s}, [x29], #16
    ld1 {v21.4s}, [x5], #16
    fmla v29.4s, v17.4s, v27.4s
    ld1 {v4.4s}, [x30], #16
    ld1 {v22.4s}, [x5], #16
    // column 4
    fmla v29.4s, v0.4s, v18.4s
    fmla v29.4s, v1.4s, v19.4s
    fmla v29.4s, v2.4s, v20.4s
    fmla v29.4s, v3.4s, v21.4s
    fmla v29.4s, v4.4s, v22.4s

    ldp x4, x5, [sp, #32]
    cbnz x5, LeftRelu6
    cbnz x4, LeftRelu
    b LeftWrite
LeftRelu6:
    fmin v29.4s, v29.4s, v31.4s
LeftRelu:
    fmax v29.4s, v29.4s, v30.4s
LeftWrite:
    cmp x2, #4
    bne Write3
    st1 {v29.4s}, [x0], #16
    b NextPixel
Write3:
    sxtw x2, w2
    tbnz w2, #1, Write2
    tbnz w2, #0, Write1
Write2:
    str d29, [x0], #8
    ext v29.16b, v29.16b, v29.16b, #8
    tbz w2, #0, NextPixel
Write1:
    str s29, [x0], #4

NextPixel:
    ldr x2, [sp, #24]
    add x1, x1, x2
    sub x3, x3, #1
    cmp x3, #0
    bgt LoopPixel
End:
    ldp x19, x20, [sp, #64]
    ldp x21, x22, [sp, #80]
    ldp x23, x24, [sp, #96]
    ldp x25, x26, [sp, #112]
    ldp x27, x28, [sp, #128]
    ldp x29, x30, [sp, #144]
    add sp, sp, #160
    ret
#endif
@@ -577,7 +577,134 @@ void ConvDw3x3(float *output_data, float *buffer, const float *input_data, const
    }
  }
}
/*conv depthwise 3x3 fp32 end*/

/*conv depthwise indirect buffer fp32 begin*/
bool CheckConvDwUseIndirectBuffer(const ConvParameter *conv_param) {
  bool use_indirect = (conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3) ||
                      (conv_param->kernel_h_ == 5 && conv_param->kernel_w_ == 5);
  return use_indirect;
}

void ConvDwInitIndirection(float **indirect_buffer, float *src, float *zero_ptr, const ConvParameter *conv_param,
                           int step_h, int step_w) {
  int ic_4 = UP_DIV(conv_param->input_channel_, C4NUM) * C4NUM;
  for (int b = 0; b < conv_param->output_batch_; b++) {
    float **indirect = indirect_buffer + b * conv_param->output_h_ * step_h;
    float *input = src + b * conv_param->input_h_ * conv_param->input_w_ * ic_4;
    for (int oh = 0; oh < conv_param->output_h_; oh++) {
      for (int kh = 0; kh < conv_param->kernel_h_; kh++) {
        int ih = oh * conv_param->stride_h_ + kh * conv_param->dilation_h_ - conv_param->pad_u_;
        if (ih < conv_param->input_h_ && ih >= 0) {
          for (int ow = 0; ow < conv_param->output_w_; ow++) {
            for (int kw = 0; kw < conv_param->kernel_w_; kw++) {
              int iw = ow * conv_param->stride_w_ + kw * conv_param->dilation_w_ - conv_param->pad_l_;
              int index = oh * step_h + ow * step_w * conv_param->kernel_h_ + kw * conv_param->kernel_h_ + kh;
              if (iw < conv_param->input_w_ && iw >= 0) {
                indirect[index] = input + (ih * conv_param->input_w_ + iw) * ic_4;
              } else {
                indirect[index] = zero_ptr;
              }
            }
          }
        } else {
          for (int ow = 0; ow < conv_param->output_w_; ow++) {
            for (int kw = 0; kw < conv_param->kernel_w_; kw++) {
              int index = oh * step_h + ow * step_w * conv_param->kernel_h_ + kw * conv_param->kernel_h_ + kh;
              indirect[index] = zero_ptr;
            }
          }
        }
      }
    }
  }
}

#ifndef ENABLE_ARM64
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
                           int output_width, int input_stride, bool relu, bool relu6, int kernel) {
  do {
    float *in[kernel];
    for (int k = 0; k < kernel; k++) {
      in[k] = input[k];
    }
    input = input + input_stride;

    size_t c = channels;
    const float *w = weights;
    float *out = output;
    memcpy(out, bias, channels * sizeof(float));
    for (; c >= C4NUM; c -= C4NUM) {
      for (int i = 0; i < C4NUM; i++) {
        for (int k = 0; k < kernel; k++) {
          out[i] += in[k][i] * w[i + k * C4NUM];
        }
      }
      w += kernel * C4NUM;
      out += C4NUM;
      for (int k = 0; k < kernel; k++) {
        in[k] += C4NUM;
      }
    }
    for (int i = 0; i < c; i++) {
      for (int k = 0; k < kernel; k++) {
        out[i] += in[k][i] * w[i + k * C4NUM];
      }
    }
    if (relu) {
      ReluFp32(output, output, channels);
    }
    if (relu6) {
      Relu6Fp32(output, output, channels);
    }
    output += channels;
  } while (--output_width != 0);
}
#endif

#ifdef ENABLE_ARM64
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
                           int output_width, int input_stride, bool relu, bool relu6, int kernel) {
  if (kernel == 9) {
    ConvDwFp32Indirect3x3(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu,
                          relu6);
  } else if (kernel == 25) {
    ConvDwFp32Indirect5x5(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu,
                          relu6);
  }
}
#endif

void ConvDwIndirection(float *output_data, float **indirect_buffer, const float *weight_data, const float *bias_data,
                       float *zero_ptr, const ConvParameter *conv_param, int task_id) {
  int step_w = conv_param->dilation_w_ == 1 ? conv_param->stride_w_ : conv_param->kernel_w_;
  int step_h =
    (conv_param->kernel_h_ * conv_param->kernel_w_) + (conv_param->output_w_ - 1) * step_w * conv_param->kernel_h_;
  int input_stride = conv_param->kernel_h_ * step_w;

  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;

  int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
  int h_start = h_step * task_id;
  int h_end = MSMIN(h_start + h_step, conv_param->output_h_);

  for (int b = 0; b < conv_param->output_batch_; b++) {
    float **indirect_b = indirect_buffer + b * conv_param->output_h_ * step_h;
    float *outout_b = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
    for (int oh = h_start; oh < h_end; oh++) {
      float **indirect = indirect_b + oh * step_h;
      float *output_h = outout_b + oh * conv_param->output_w_ * conv_param->output_channel_;
      if (conv_param->kernel_w_ == 3) {
        ConvDwFp32IndirectRow(output_h, indirect, weight_data, bias_data, conv_param->output_channel_,
                              conv_param->output_w_, input_stride, relu, relu6, 9);
      } else if (conv_param->kernel_w_ == 5) {
        ConvDwFp32IndirectRow(output_h, indirect, weight_data, bias_data, conv_param->output_channel_,
                              conv_param->output_w_, input_stride, relu, relu6, 25);
      }
    }
  }
}
/*conv depthwise indirect buffer fp32 end*/

/*deconv depthwise fp32 begin*/
void DeconvDwBorderPixel(float *dst, const float *src, const float *weight, int height, int width, int in_kh_step,
@@ -53,6 +53,25 @@ void ConvDw3x3Pad(float *output_data, const float *input_data, const float *weig
void ConvDw3x3(float *output_data, float *buffer, const float *input_data, const float *weight_data,
               const float *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id);

bool CheckConvDwUseIndirectBuffer(const ConvParameter *conv_param);

void ConvDwInitIndirection(float **indirect_buffer, float *src, float *zero_ptr, const ConvParameter *conv_param,
                           int step_h, int step_w);

#ifdef ENABLE_ARM64
void ConvDwFp32Indirect3x3(float *output, float **input, const float *weights, const float *bias, int channels,
                           int output_width, size_t input_stride, size_t relu, size_t relu6);

void ConvDwFp32Indirect5x5(float *output, float **input, const float *weights, const float *bias, int channels,
                           int output_width, size_t input_stride, size_t relu, size_t relu6);
#endif

void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
                           int output_width, int input_stride, bool relu, bool relu6, int kernel);

void ConvDwIndirection(float *output_data, float **indirect_buffer, const float *weight_data, const float *bias_data,
                       float *zero_ptr, const ConvParameter *conv_param, int task_id);

void DeconvDwSWFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
                    const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id);
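Taken together, the entry points declared above are used in a fixed order: pack the depthwise weights into the C4 layout, build the indirection buffer from the (NHWC4-packed) input, then run the per-row kernels. A rough sketch of that sequence, following ConvolutionDepthwiseIndirectCPUKernel further down in this commit; the wrapper name and its parameters are illustrative only, buffer allocation and error handling are omitted, and the relevant nnacl headers are assumed to be included.

/* Illustrative call order only; mirrors InitWeightBias()/ReSize()/Run() of the kernel below. */
void RunDepthwiseIndirectSketch(const ConvParameter *conv, const float *weight_ohwi, const float *bias,
                                float *packed_input, float *packed_weight, float **indirect_buffer,
                                float *zero_ptr, float *output) {
  /* 1. Pack weights: kw-major within each block of C4NUM channels. */
  PackDepthwiseIndirectWeightC4Fp32(weight_ohwi, packed_weight, conv->kernel_h_, conv->kernel_w_,
                                    conv->output_channel_);
  /* 2. Fill the indirection buffer with one input pointer per kernel tap and output pixel. */
  int step_w = conv->dilation_w_ == 1 ? conv->stride_w_ : conv->kernel_w_;
  int step_h = conv->kernel_h_ * conv->kernel_w_ + (conv->output_w_ - 1) * step_w * conv->kernel_h_;
  ConvDwInitIndirection(indirect_buffer, packed_input, zero_ptr, conv, step_h, step_w);
  /* 3. Compute; task_id 0 covers the whole output height when thread_num_ is 1. */
  ConvDwIndirection(output, indirect_buffer, packed_weight, bias, zero_ptr, conv, 0);
}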
@@ -583,6 +583,23 @@ void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int
  return;
}

void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel) {
  int c4 = UP_DIV(channel, C4NUM);
  for (int c = 0; c < c4; c++) {
    int dst_off_c = c * C4NUM * height * width;
    for (int i = 0; i < C4NUM; i++) {
      int src_off_c = (c * C4NUM + i) * height * width;
      for (int kh = 0; kh < height; kh++) {
        int src_off_kh = src_off_c + kh * width;
        for (int kw = 0; kw < width; kw++) {
          int dst_off = dst_off_c + kw * height * C4NUM + kh * C4NUM + i;
          ((float *)dst)[dst_off] = ((float *)src)[src_off_kh + kw];
        }
      }
    }
  }
}

void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel) {
  int c4 = UP_DIV(channel, C4NUM);
  int c4_channel = c4 * C4NUM;
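Reading off the dst_off expression in PackDepthwiseIndirectWeightC4Fp32 above, the packed layout within each block of C4NUM = 4 channels is kw-major, kh next, channel lane i fastest: dst_off = kw * kernel_h * 4 + kh * 4 + i relative to the block. For a 3x3 kernel, lane 0 of the first block therefore lands at offsets 0, 4, 8 for (kh = 0, 1, 2) of the first kernel column, then 12, 16, 20 for the second column, and so on. This matches the per-tap index kw * kernel_h_ + kh used when ConvDwInitIndirection fills the indirection buffer, so the kernels can consume one 4-lane weight vector per tap in order.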
@@ -78,6 +78,8 @@ void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int

void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel);

void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel);

void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel);

void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
@@ -20,6 +20,7 @@
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"
#include "src/ops/pooling.h"

using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
@@ -95,6 +96,11 @@ int PoolingBaseCPUKernel::ReSize() {
  auto out_tensor = this->out_tensors_.front();
  MS_ASSERT(in_tensor != nullptr);
  MS_ASSERT(out_tensor != nullptr);
  auto pooling_lite_primitive = (lite::Pooling *)primitive_;
  pooling_param_->pad_u_ = pooling_lite_primitive->PadUp();
  pooling_param_->pad_d_ = pooling_lite_primitive->PadDown();
  pooling_param_->pad_l_ = pooling_lite_primitive->PadLeft();
  pooling_param_->pad_r_ = pooling_lite_primitive->PadRight();
  pooling_param_->input_batch_ = in_tensor->Batch();
  pooling_param_->input_channel_ = in_tensor->Channel();
  pooling_param_->input_h_ = in_tensor->Height();
@@ -74,6 +74,9 @@ int ArithmeticCPUKernel::PreProcess() {
}

int ArithmeticCPUKernel::ReSize() {
  auto arithmetic_lite_primitive = (lite::Arithmetic *)primitive_;
  arithmeticParameter_->broadcasting_ = arithmetic_lite_primitive->Broadcasting();
  arithmeticParameter_->ndim_ = arithmetic_lite_primitive->NDims();
  if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat16) {
    data_type_ = kDataTypeFloat;
  } else {
@@ -15,8 +15,8 @@
 */

#include "src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.h"
#include "src/runtime/kernel/arm/fp32/convolution_depthwise_3x3_fp32.h"
#include "src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_fp32.h"
#include "src/runtime/kernel/arm/fp32/convolution_depthwise_indirect_fp32.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
@@ -147,12 +147,12 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
  conv_param->input_channel_ = inputs[kInputIndex]->Channel();
  conv_param->output_h_ = outputs[kOutputIndex]->Height();
  conv_param->output_w_ = outputs[kOutputIndex]->Width();
  if (CheckConvDwUse3X3(conv_param) && conv_param->input_channel_ % C4NUM == 0) {
#ifdef ENABLE_ARM64
  if (CheckConvDwUseIndirectBuffer(conv_param)) {
    kernel =
      new (std::nothrow) kernel::ConvolutionDepthwise3x3CPUKernel(opParameter, inputs, outputs, ctx, primitive);
#endif
      new (std::nothrow) kernel::ConvolutionDepthwiseIndirectCPUKernel(opParameter, inputs, outputs, ctx, primitive);
  }
#endif
  if (kernel == nullptr && conv_param->input_channel_ < 32) {
    kernel = new (std::nothrow) kernel::ConvolutionDepthwiseSWCPUKernel(opParameter, inputs, outputs, ctx, primitive);
  }
@@ -0,0 +1,182 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/runtime/kernel/arm/fp32/convolution_depthwise_indirect_fp32.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_INFER_INVALID;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_DepthwiseConv2D;

namespace mindspore::kernel {
ConvolutionDepthwiseIndirectCPUKernel::~ConvolutionDepthwiseIndirectCPUKernel() {
  if (packed_weight_ != nullptr) {
    free(packed_weight_);
    packed_weight_ = nullptr;
  }
  if (zero_ptr_ != nullptr) {
    free(zero_ptr_);
    zero_ptr_ = nullptr;
  }
  if (indirect_buffer_ != nullptr) {
    free(indirect_buffer_);
    indirect_buffer_ = nullptr;
  }
}

int ConvolutionDepthwiseIndirectCPUKernel::InitWeightBias() {
  // init weight: o, h, w, i; o == group, i == 1
  auto weight_tensor = in_tensors_[kWeightIndex];
  auto origin_weight = reinterpret_cast<float *>(weight_tensor->MutableData());
  int C4 = UP_DIV(weight_tensor->Batch(), C4NUM);
  int pack_weight_size = C4NUM * C4 * weight_tensor->Height() * weight_tensor->Width();

  packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float)));
  if (packed_weight_ == nullptr) {
    MS_LOG(ERROR) << "Malloc buffer failed.";
    return RET_ERROR;
  }
  PackDepthwiseIndirectWeightC4Fp32(origin_weight, packed_weight_, weight_tensor->Height(), weight_tensor->Width(),
                                    weight_tensor->Batch());

  auto bias_tensor = in_tensors_[kBiasIndex];
  bias_data_ = reinterpret_cast<float *>(malloc(C4NUM * C4 * sizeof(float)));
  if (bias_data_ == nullptr) {
    MS_LOG(ERROR) << "Malloc buffer failed.";
    return RET_ERROR;
  }

  memset(bias_data_, 0, C4NUM * C4 * sizeof(float));
  if (in_tensors_.size() == kInputSize2) {
    auto ori_bias = reinterpret_cast<float *>(bias_tensor->MutableData());
    memcpy(bias_data_, ori_bias, bias_tensor->ElementsNum() * sizeof(float));
  }

  // malloc zero ptr
  zero_ptr_ = reinterpret_cast<float *>(malloc(C4NUM * C4 * sizeof(float)));
  if (zero_ptr_ == nullptr) {
    MS_LOG(ERROR) << "Malloc buffer failed.";
    return RET_ERROR;
  }
  memset(zero_ptr_, 0, C4NUM * C4 * sizeof(float));
  return RET_OK;
}

int ConvolutionDepthwiseIndirectCPUKernel::Init() {
  auto ret = InitWeightBias();
  if (ret != 0) {
    MS_LOG(ERROR) << "Convolution depthwise Indirect fp32 InitWeightBias failed.";
    return RET_ERROR;
  }
  if (!InferShapeDone()) {
    return RET_OK;
  }
  return ReSize();
}

int ConvolutionDepthwiseIndirectCPUKernel::MallocIndirectBuffer() {
  // malloc indirect buffer
  step_w = conv_param_->dilation_w_ == 1 ? conv_param_->stride_w_ : conv_param_->kernel_w_;
  step_h =
    (conv_param_->kernel_h_ * conv_param_->kernel_w_) + (conv_param_->output_w_ - 1) * step_w * conv_param_->kernel_h_;
  int buffer_size = conv_param_->output_batch_ * conv_param_->output_h_ * step_h;
  indirect_buffer_ = reinterpret_cast<float **>(malloc(buffer_size * sizeof(float *)));
  if (indirect_buffer_ == nullptr) {
    MS_LOG(ERROR) << "Malloc buffer failed.";
    return RET_ERROR;
  }
  return RET_OK;
}

int ConvolutionDepthwiseIndirectCPUKernel::ReSize() {
  if (indirect_buffer_ != nullptr) {
    free(indirect_buffer_);
    indirect_buffer_ = nullptr;
  }
  ConvolutionBaseCPUKernel::Init();
  auto ret = MallocIndirectBuffer();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ConvolutionDepthwiseIndirect MallocIndirectBuffer failed";
    return RET_ERROR;
  }
  conv_param_->thread_num_ = MSMIN(thread_count_, conv_param_->output_h_);
  return RET_OK;
}

int ConvolutionDepthwiseIndirectCPUKernel::Execute(int task_id) {
  ConvDwIndirection(output_ptr_, indirect_buffer_, packed_weight_, reinterpret_cast<float *>(bias_data_), zero_ptr_,
                    conv_param_, task_id);
  return RET_OK;
}

int ConvDwIndirectRun(void *cdata, int task_id) {
  auto conv_dw = reinterpret_cast<ConvolutionDepthwiseIndirectCPUKernel *>(cdata);
  auto ret = conv_dw->Execute(task_id);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ConvolutionDepthwiseIndirectRun error task_id[" << task_id << "] error_code[" << ret << "]";
    return RET_ERROR;
  }
  return RET_OK;
}

int ConvolutionDepthwiseIndirectCPUKernel::MallocPackedInput() {
  int IC4 = UP_DIV(conv_param_->input_channel_, C4NUM);
  int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C4NUM * IC4;
  packed_input_ = reinterpret_cast<float *>(context_->allocator->Malloc(pack_input_size * sizeof(float)));
  if (packed_input_ == nullptr) {
    MS_LOG(ERROR) << "Malloc buffer failed.";
    return RET_ERROR;
  }
  return RET_OK;
}

int ConvolutionDepthwiseIndirectCPUKernel::Run() {
  auto input_tensor = in_tensors_.at(kInputIndex);
  auto input_ptr = reinterpret_cast<float *>(input_tensor->data_c());
  if (conv_param_->input_channel_ % C4NUM != 0) {
    auto ret = MallocPackedInput();
    if (ret != 0) {
      MS_LOG(ERROR) << "Convolution depthwise fp32 indirect buffer MallocPackedInput failed.";
      return RET_ERROR;
    }
    PackNHWCToNHWC4Fp32(input_ptr, packed_input_, conv_param_->input_batch_,
                        conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_);
  } else {
    packed_input_ = input_ptr;
  }

  auto output_tensor = out_tensors_.at(kOutputIndex);
  output_ptr_ = reinterpret_cast<float *>(output_tensor->data_c());

  ConvDwInitIndirection(indirect_buffer_, packed_input_, zero_ptr_, conv_param_, step_h, step_w);

  auto ret = ParallelLaunch(this->context_->thread_pool_, ConvDwIndirectRun, this, conv_param_->thread_num_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ConvDwIndirectRun error: error_code[" << ret << "]";
    return RET_ERROR;
  }
  if (conv_param_->input_channel_ % C4NUM != 0) {
    context_->allocator->Free(packed_input_);
  }
  return RET_OK;
}
} // namespace mindspore::kernel
@@ -0,0 +1,54 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_INDIRECT_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_INDIRECT_H_

#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/convolution_base.h"
#include "nnacl/fp32/conv_depthwise_fp32.h"

namespace mindspore::kernel {
class ConvolutionDepthwiseIndirectCPUKernel : public ConvolutionBaseCPUKernel {
 public:
  ConvolutionDepthwiseIndirectCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                                        const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                                        const mindspore::lite::PrimitiveC *primitive)
      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
  ~ConvolutionDepthwiseIndirectCPUKernel() override;

  int Init() override;
  int ReSize() override;
  int Run() override;

  int InitWeightBias();
  int Execute(int task_id);

 private:
  int MallocIndirectBuffer();
  int MallocPackedInput();
  int step_w = 0;
  int step_h = 0;
  float **indirect_buffer_ = nullptr;
  float *zero_ptr_ = nullptr;
  float *packed_weight_ = nullptr;
  float *output_ptr_ = nullptr;
  float *packed_input_ = nullptr;
};
} // namespace mindspore::kernel

#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_INDIRECT_H_
@@ -111,8 +111,15 @@ int Scheduler::InferShape(const lite::Model *model, std::vector<Tensor *> *tenso
      MS_LOG(ERROR) << "Op " << node->name_ << " should exist in model!";
      return RET_ERROR;
    }
    primitive->set_infer_flag(!infer_shape_interrupt);
    auto ret = primitive->InferShape(inputs, outputs);
    STATUS ret = RET_INFER_INVALID;
    bool infer_valid = std::all_of(inputs.begin(), inputs.end(), [](Tensor *tensor) {
      auto shape = tensor->shape();
      return std::all_of(shape.begin(), shape.end(), [](int dim) { return dim != -1; });
    });
    if (infer_valid) {
      primitive->set_infer_flag(!infer_shape_interrupt);
      ret = primitive->InferShape(inputs, outputs);
    }
    if (ret == RET_INFER_INVALID) {
      MS_LOG(INFO) << "InferShape shouldn't be done before runtime, name: " << node->name_
                   << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(primitive->Type()))