forked from mindspore-Ecosystem/mindspore
!29757 dynamic matmul optimize with sdot
Merge pull request !29757 from yeyunpeng2020/dynamic_quant_success
commit 11372f2a2a

@@ -0,0 +1,487 @@
#ifdef ENABLE_ARM64
#include "nnacl/assembly_global.h"

.text
.align 5

// void DynamicMatmulSdot4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep, float *multi_scales,
//                                  float *bias, size_t row, size_t col, size_t stride);
// x0: a(left matrix ptr)
// x1: b(right matrix ptr)
// x2: out ptr
// x3: deep
// x4: multi_scales
// x5: bias
// x6: row
// x7: col
// x8: stride

asm_function DynamicMatmulSdot4x4x16AIWI
    ldr x8, [sp]

    dup v16.4s, wzr  // dup: duplicate general-purpose register to vector
    dup v17.4s, wzr
    dup v18.4s, wzr
    dup v19.4s, wzr
    dup v20.4s, wzr
    dup v21.4s, wzr
    dup v22.4s, wzr
    dup v23.4s, wzr
    dup v24.4s, wzr
    dup v25.4s, wzr
    dup v26.4s, wzr
    dup v27.4s, wzr
    dup v28.4s, wzr
    dup v29.4s, wzr
    dup v30.4s, wzr
    dup v31.4s, wzr

    mov x18, x1  // reload rhs ptr
    mov x17, x0  // reload lhs ptr
    mov x16, x3  // reload depth

    cmp x7, #4
    ble LoopDepthQuarter
    cmp x7, #8
    ble LoopDepthHalf

LoopDepth:
    ld1 {v0.16b}, [x17], #16
    ld1 {v1.16b, v2.16b, v3.16b, v4.16b}, [x18], #64

    sdot v16.4s, v1.16b, v0.4b[0]
    sdot v17.4s, v2.16b, v0.4b[0]
    sdot v18.4s, v3.16b, v0.4b[0]
    sdot v19.4s, v4.16b, v0.4b[0]
    sdot v20.4s, v1.16b, v0.4b[1]
    sdot v21.4s, v2.16b, v0.4b[1]
    sdot v22.4s, v3.16b, v0.4b[1]
    sdot v23.4s, v4.16b, v0.4b[1]
    sdot v24.4s, v1.16b, v0.4b[2]
    sdot v25.4s, v2.16b, v0.4b[2]
    sdot v26.4s, v3.16b, v0.4b[2]
    sdot v27.4s, v4.16b, v0.4b[2]
    sdot v28.4s, v1.16b, v0.4b[3]
    sdot v29.4s, v2.16b, v0.4b[3]
    sdot v30.4s, v3.16b, v0.4b[3]
    sdot v31.4s, v4.16b, v0.4b[3]

    subs x16, x16, #4
    bgt LoopDepth
    b Convert2Float

LoopDepthHalf:
    ld1 {v0.16b}, [x17], #16
    ld1 {v1.16b, v2.16b}, [x18]
    add x18, x18, #64
    sdot v16.4s, v1.16b, v0.4b[0]
    sdot v17.4s, v2.16b, v0.4b[0]
    sdot v20.4s, v1.16b, v0.4b[1]
    sdot v21.4s, v2.16b, v0.4b[1]
    sdot v24.4s, v1.16b, v0.4b[2]
    sdot v25.4s, v2.16b, v0.4b[2]
    sdot v28.4s, v1.16b, v0.4b[3]
    sdot v29.4s, v2.16b, v0.4b[3]

    subs x16, x16, #4
    bgt LoopDepthHalf
    b Convert2Float

LoopDepthQuarter:
    ld1 {v0.16b}, [x17], #16
    ld1 {v1.16b}, [x18]
    add x18, x18, #64
    sdot v16.4s, v1.16b, v0.4b[0]
    sdot v20.4s, v1.16b, v0.4b[1]
    sdot v24.4s, v1.16b, v0.4b[2]
    sdot v28.4s, v1.16b, v0.4b[3]

    subs x16, x16, #4
    bgt LoopDepthQuarter
    b Convert2Float

Convert2Float:
    scvtf v16.4s, v16.4s
    scvtf v17.4s, v17.4s
    scvtf v18.4s, v18.4s
    scvtf v19.4s, v19.4s
    scvtf v20.4s, v20.4s
    scvtf v21.4s, v21.4s
    scvtf v22.4s, v22.4s
    scvtf v23.4s, v23.4s
    scvtf v24.4s, v24.4s
    scvtf v25.4s, v25.4s
    scvtf v26.4s, v26.4s
    scvtf v27.4s, v27.4s
    scvtf v28.4s, v28.4s
    scvtf v29.4s, v29.4s
    scvtf v30.4s, v30.4s
    scvtf v31.4s, v31.4s

MultiplyScale:
    // multi_scale * input_matrix
    ld1 {v1.4s, v2.4s, v3.4s, v4.4s}, [x4]

    fmul v16.4s, v16.4s, v1.4s
    fmul v17.4s, v17.4s, v2.4s
    fmul v18.4s, v18.4s, v3.4s
    fmul v19.4s, v19.4s, v4.4s

    fmul v20.4s, v20.4s, v1.4s
    fmul v21.4s, v21.4s, v2.4s
    fmul v22.4s, v22.4s, v3.4s
    fmul v23.4s, v23.4s, v4.4s

    fmul v24.4s, v24.4s, v1.4s
    fmul v25.4s, v25.4s, v2.4s
    fmul v26.4s, v26.4s, v3.4s
    fmul v27.4s, v27.4s, v4.4s

    fmul v28.4s, v28.4s, v1.4s
    fmul v29.4s, v29.4s, v2.4s
    fmul v30.4s, v30.4s, v3.4s
    fmul v31.4s, v31.4s, v4.4s

AddBias:
    // + bias
    cbz x5, StoreData
    ld1 {v1.4s, v2.4s, v3.4s, v4.4s}, [x5]

    fadd v16.4s, v16.4s, v1.4s
    fadd v17.4s, v17.4s, v2.4s
    fadd v18.4s, v18.4s, v3.4s
    fadd v19.4s, v19.4s, v4.4s

    fadd v20.4s, v20.4s, v1.4s
    fadd v21.4s, v21.4s, v2.4s
    fadd v22.4s, v22.4s, v3.4s
    fadd v23.4s, v23.4s, v4.4s

    fadd v24.4s, v24.4s, v1.4s
    fadd v25.4s, v25.4s, v2.4s
    fadd v26.4s, v26.4s, v3.4s
    fadd v27.4s, v27.4s, v4.4s

    fadd v28.4s, v28.4s, v1.4s
    fadd v29.4s, v29.4s, v2.4s
    fadd v30.4s, v30.4s, v3.4s
    fadd v31.4s, v31.4s, v4.4s

StoreData:
    cmp x7, #16
    beq Write16

    mov x15, x2  // reload out ptr
    add x14, x15, x8
    add x13, x14, x8
    add x12, x13, x8

    cmp x7, #15
    beq Write15
    cmp x7, #14
    beq Write14
    cmp x7, #13
    beq Write13
    cmp x7, #12
    beq Write12
    cmp x7, #11
    beq Write11
    cmp x7, #10
    beq Write10
    cmp x7, #9
    beq Write9
    cmp x7, #8
    beq Write8
    cmp x7, #7
    beq Write7
    cmp x7, #6
    beq Write6
    cmp x7, #5
    beq Write5
    cmp x7, #4
    beq Write4
    cmp x7, #3
    beq Write3
    cmp x7, #2
    beq Write2
    cmp x7, #1
    beq Write1
    b StoreDataEnd

Write16:
    cmp x6, #4
    beq Write16Row4
    cmp x6, #3
    beq Write16Row3
    cmp x6, #2
    beq Write16Row2
    cmp x6, #1
    beq Write16Row1

Write16Row4:
    st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], x8
    st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x2], x8
    st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x2], x8
    st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2]
    b StoreDataEnd
Write16Row3:
    st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], x8
    st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x2], x8
    st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x2]
    b StoreDataEnd
Write16Row2:
    st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], x8
    st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x2]
    b StoreDataEnd
Write16Row1:
    st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2]
    b StoreDataEnd

Write15:
    st1 {v16.4s, v17.4s, v18.4s}, [x15], #48
    st1 {v19.1d}, [x15], #8
    st1 {v19.s}[2], [x15]
    cmp x6, #1
    beq StoreDataEnd
    st1 {v20.4s, v21.4s, v22.4s}, [x14], #48
    st1 {v23.1d}, [x14], #8
    st1 {v23.s}[2], [x14]
    cmp x6, #2
    beq StoreDataEnd
    st1 {v24.4s, v25.4s, v26.4s}, [x13], #48
    st1 {v27.1d}, [x13], #8
    st1 {v27.s}[2], [x13]
    cmp x6, #3
    beq StoreDataEnd
    st1 {v28.4s, v29.4s, v30.4s}, [x12], #48
    st1 {v31.1d}, [x12], #8
    st1 {v31.s}[2], [x12]
    b StoreDataEnd

Write14:
    st1 {v16.4s, v17.4s, v18.4s}, [x15], #48
    st1 {v19.1d}, [x15]
    cmp x6, #1
    beq StoreDataEnd
    st1 {v20.4s, v21.4s, v22.4s}, [x14], #48
    st1 {v23.1d}, [x14]
    cmp x6, #2
    beq StoreDataEnd
    st1 {v24.4s, v25.4s, v26.4s}, [x13], #48
    st1 {v27.1d}, [x13]
    cmp x6, #3
    beq StoreDataEnd
    st1 {v28.4s, v29.4s, v30.4s}, [x12], #48
    st1 {v31.1d}, [x12]
    b StoreDataEnd

Write13:
    st1 {v16.4s, v17.4s, v18.4s}, [x15], #48
    st1 {v19.s}[0], [x15]
    cmp x6, #1
    beq StoreDataEnd
    st1 {v20.4s, v21.4s, v22.4s}, [x14], #48
    st1 {v23.s}[0], [x14]
    cmp x6, #2
    beq StoreDataEnd
    st1 {v24.4s, v25.4s, v26.4s}, [x13], #48
    st1 {v27.s}[0], [x13]
    cmp x6, #3
    beq StoreDataEnd
    st1 {v28.4s, v29.4s, v30.4s}, [x12], #48
    st1 {v31.s}[0], [x12]
    b StoreDataEnd

Write12:
    st1 {v16.4s, v17.4s, v18.4s}, [x15], #48
    cmp x6, #1
    beq StoreDataEnd
    st1 {v20.4s, v21.4s, v22.4s}, [x14], #48
    cmp x6, #2
    beq StoreDataEnd
    st1 {v24.4s, v25.4s, v26.4s}, [x13], #48
    cmp x6, #3
    beq StoreDataEnd
    st1 {v28.4s, v29.4s, v30.4s}, [x12], #48
    b StoreDataEnd

Write11:
    st1 {v16.4s, v17.4s}, [x15], #32
    st1 {v18.1d}, [x15], #8
    st1 {v18.s}[2], [x15]
    cmp x6, #1
    beq StoreDataEnd
    st1 {v20.4s, v21.4s}, [x14], #32
    st1 {v22.1d}, [x14], #8
    st1 {v22.s}[2], [x14]
    cmp x6, #2
    beq StoreDataEnd
    st1 {v24.4s, v25.4s}, [x13], #32
    st1 {v26.1d}, [x13], #8
    st1 {v26.s}[2], [x13]
    cmp x6, #3
    beq StoreDataEnd
    st1 {v28.4s, v29.4s}, [x12], #32
    st1 {v30.1d}, [x12], #8
    st1 {v30.s}[2], [x12]
    b StoreDataEnd

Write10:
    st1 {v16.4s, v17.4s}, [x15], #32
    st1 {v18.1d}, [x15]
    cmp x6, #1
    beq StoreDataEnd
    st1 {v20.4s, v21.4s}, [x14], #32
    st1 {v22.1d}, [x14]
    cmp x6, #2
    beq StoreDataEnd
    st1 {v24.4s, v25.4s}, [x13], #32
    st1 {v26.1d}, [x13]
    cmp x6, #3
    beq StoreDataEnd
    st1 {v28.4s, v29.4s}, [x12], #32
    st1 {v30.1d}, [x12]
    b StoreDataEnd

Write9:
    st1 {v16.4s, v17.4s}, [x15], #32
    st1 {v18.s}[0], [x15]
    cmp x6, #1
    beq StoreDataEnd
    st1 {v20.4s, v21.4s}, [x14], #32
    st1 {v22.s}[0], [x14]
    cmp x6, #2
    beq StoreDataEnd
    st1 {v24.4s, v25.4s}, [x13], #32
    st1 {v26.s}[0], [x13]
    cmp x6, #3
    beq StoreDataEnd
    st1 {v28.4s, v29.4s}, [x12], #32
    st1 {v30.s}[0], [x12]
    b StoreDataEnd

Write8:
    st1 {v16.4s, v17.4s}, [x15], #32
    cmp x6, #1
    beq StoreDataEnd
    st1 {v20.4s, v21.4s}, [x14], #32
    cmp x6, #2
    beq StoreDataEnd
    st1 {v24.4s, v25.4s}, [x13], #32
    cmp x6, #3
    beq StoreDataEnd
    st1 {v28.4s, v29.4s}, [x12], #32
    b StoreDataEnd

Write7:
    st1 {v16.4s}, [x15], #16
    st1 {v17.1d}, [x15], #8
    st1 {v17.s}[2], [x15]
    cmp x6, #1
    beq StoreDataEnd
    st1 {v20.4s}, [x14], #16
    st1 {v21.1d}, [x14], #8
    st1 {v21.s}[2], [x14]
    cmp x6, #2
    beq StoreDataEnd
    st1 {v24.4s}, [x13], #16
    st1 {v25.1d}, [x13], #8
    st1 {v25.s}[2], [x13]
    cmp x6, #3
    beq StoreDataEnd
    st1 {v28.4s}, [x12], #16
    st1 {v29.1d}, [x12], #8
    st1 {v29.s}[2], [x12]
    b StoreDataEnd

Write6:
    st1 {v16.4s}, [x15], #16
    st1 {v17.1d}, [x15]
    cmp x6, #1
    beq StoreDataEnd
    st1 {v20.4s}, [x14], #16
    st1 {v21.1d}, [x14]
    cmp x6, #2
    beq StoreDataEnd
    st1 {v24.4s}, [x13], #16
    st1 {v25.1d}, [x13]
    cmp x6, #3
    beq StoreDataEnd
    st1 {v28.4s}, [x12], #16
    st1 {v29.1d}, [x12]
    b StoreDataEnd

Write5:
    st1 {v16.4s}, [x15], #16
    st1 {v17.s}[0], [x15]
    cmp x6, #1
    beq StoreDataEnd
    st1 {v20.4s}, [x14], #16
    st1 {v21.s}[0], [x14]
    cmp x6, #2
    beq StoreDataEnd
    st1 {v24.4s}, [x13], #16
    st1 {v25.s}[0], [x13]
    cmp x6, #3
    beq StoreDataEnd
    st1 {v28.4s}, [x12], #16
    st1 {v29.s}[0], [x12]
    b StoreDataEnd

Write4:
    st1 {v16.4s}, [x15]
    cmp x6, #1
    beq StoreDataEnd
    st1 {v20.4s}, [x14]
    cmp x6, #2
    beq StoreDataEnd
    st1 {v24.4s}, [x13]
    cmp x6, #3
    beq StoreDataEnd
    st1 {v28.4s}, [x12]
    b StoreDataEnd

Write3:
    st1 {v16.1d}, [x15], #8
    st1 {v16.s}[2], [x15]
    cmp x6, #1
    beq StoreDataEnd
    st1 {v20.1d}, [x14], #8
    st1 {v20.s}[2], [x14]
    cmp x6, #2
    beq StoreDataEnd
    st1 {v24.1d}, [x13], #8
    st1 {v24.s}[2], [x13]
    cmp x6, #3
    beq StoreDataEnd
    st1 {v28.1d}, [x12], #8
    st1 {v28.s}[2], [x12]
    b StoreDataEnd

Write2:
    st1 {v16.1d}, [x15]
    cmp x6, #1
    beq StoreDataEnd
    st1 {v20.1d}, [x14]
    cmp x6, #2
    beq StoreDataEnd
    st1 {v24.1d}, [x13]
    cmp x6, #3
    beq StoreDataEnd
    st1 {v28.1d}, [x12]
    b StoreDataEnd

Write1:
    st1 {v16.s}[0], [x15]
    cmp x6, #1
    beq StoreDataEnd
    st1 {v20.s}[0], [x14]
    cmp x6, #2
    beq StoreDataEnd
    st1 {v24.s}[0], [x13]
    cmp x6, #3
    beq StoreDataEnd
    st1 {v28.s}[0], [x12]
    b StoreDataEnd
StoreDataEnd:
    ret
#endif
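
For reference, a plain-C sketch of the tile this kernel produces (an illustrative reading of the assembly above, not part of the patch: it assumes a is packed as [deep/4][4 rows][4 bytes], b as [deep/4][16 cols][4 bytes], deep a multiple of 4, row <= 4, col <= 16, and stride a byte stride between output rows, matching the post-indexed stores; the helper name is made up):

#include <stddef.h>
#include <stdint.h>

static void DynamicMatmulSdot4x4x16Ref(const int8_t *a, const int8_t *b, float *out, size_t deep,
                                       const float *multi_scales, const float *bias, size_t row, size_t col,
                                       size_t stride) {
  int32_t acc[4][16] = {{0}};
  for (size_t d = 0; d < deep; d += 4) {
    for (size_t r = 0; r < 4; ++r) {
      for (size_t c = 0; c < 16; ++c) {
        for (size_t k = 0; k < 4; ++k) {  /* one sdot lane: four int8 products into one int32 */
          acc[r][c] += (int32_t)a[d * 4 + r * 4 + k] * (int32_t)b[d * 16 + c * 4 + k];
        }
      }
    }
  }
  for (size_t r = 0; r < row; ++r) {
    float *out_row = (float *)((char *)out + r * stride);  /* byte stride, as in the st1 ..., x8 stores */
    for (size_t c = 0; c < col; ++c) {
      float value = (float)acc[r][c] * multi_scales[c];
      out_row[c] = (bias != NULL) ? value + bias[c] : value;
    }
  }
}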

@@ -0,0 +1,313 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/int8/dynamic_matmul_int8.h"
#include "nnacl/int8/fixed_point.h"

void DynamicMatmul4x4x16AIWI(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col,
                             int deep4, size_t stride, float input_scale, const float *filter_scale,
                             bool filter_per_channel) {
  /* *
   * row4x4-major * row4x16-major => (int8)row-major
   * support activation per-layer symmetric && weight per-layer/per-channel symmetric
   * */
  for (int r = 0; r < row; r++) {
    for (int c = 0; c < col; c++) {
      int r4div = r / C4NUM, r4mod = r % C4NUM;
      int c16div = c / C16NUM, c16mod = c % C16NUM;
      int32_t value = 0;
      for (int d = 0; d < deep4; d++) {
        int d4div = d / C4NUM, d4mod = d % C4NUM;
        size_t ai = r4div * deep4 * C4NUM + d4div * C4NUM * C4NUM + r4mod * C4NUM + d4mod;
        size_t bi = c16div * deep4 * C16NUM + d4div * C4NUM * C16NUM + c16mod * C4NUM + d4mod;
        value += a[ai] * b[bi];
      }
      int filter_quant_index = filter_per_channel ? c : 0;
      double multi_scale = input_scale * filter_scale[filter_quant_index];
      size_t ci = r * stride + c;
      dst[ci] = multi_scale * value;
      if (bias != NULL) {
        dst[ci] += bias[c];
      }
    }
  }
  return;
}
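
// A worked check of the packed-index arithmetic above (illustrative numbers,
// not part of the patch): for r = 5, d = 7, deep4 = 8 we get r4div = 1,
// r4mod = 1, d4div = 1, d4mod = 3, so
// ai = 1 * 8 * 4 + 1 * 4 * 4 + 1 * 4 + 3 = 55, i.e. element (5, 7) of the
// left matrix lives at offset 55 of the row4x4-packed buffer.
#ifdef DYNAMIC_MATMUL_INDEX_EXAMPLE  // hypothetical guard; example only
static size_t PackedIndexA4x4(int r, int d, int deep4) {
  return (size_t)(r / C4NUM) * deep4 * C4NUM + (size_t)(d / C4NUM) * C4NUM * C4NUM + (size_t)(r % C4NUM) * C4NUM +
         (size_t)(d % C4NUM);
}
#endif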

void DynamicMatmul4x16x4AIWI(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col,
                             int deep, int deep16, size_t stride, int input_zp, float input_scale,
                             const float *filter_scale, const int filter_zp, bool filter_per_channel) {
  /* *
   * row4x16-major * row16x4-major => (int8)row-major
   * support activation per-layer symmetric && weight per-layer/per-channel symmetric
   * */
  for (int r = 0; r < row; r++) {
    for (int c = 0; c < col; c++) {
      int r4div = r / C4NUM, r4mod = r % C4NUM;
      int c4div = c / C4NUM, c4mod = c % C4NUM;
      int32_t value = 0;
      for (int d = 0; d < deep; d++) {
        int d16div = d / C16NUM, d16mod = d % C16NUM;
        size_t ai = r4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod;
        size_t bi = c4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod;
        value += (a[ai] - input_zp) * (b[bi] - filter_zp);
      }
      int filter_quant_index = filter_per_channel ? c : 0;
      double multi_scale = input_scale * filter_scale[filter_quant_index];
      size_t ci = r * stride + c;
      dst[ci] = multi_scale * value;
      if (bias != NULL) {
        dst[ci] += bias[c];
      }
    }
  }
  return;
}

#ifdef ENABLE_ARM64
void PackInput4x4Asm(const int8_t *src_ic, int8_t *pack_ic, size_t ic_4div, size_t input_channel) {
  size_t src_stride = input_channel;
  size_t ic_4res = input_channel - ic_4div;
  asm volatile(
    "dup v2.4s, wzr\n"

    "mov x10, %[src_ic]\n"
    "mov x11, %[pack_ic]\n"

    "mov x15, #0\n"
    "1:\n"
    "cmp x15, %[ic_4div]\n"
    "add x15, x15, #4\n"
    "mov x12, x10\n"
    "add x10, x10, #4\n"
    "blt 2f\n"
    "cmp %[ic_4res], #0\n"
    "beq 6f\n"
    "cmp %[ic_4res], #1\n"
    "beq 3f\n"
    "cmp %[ic_4res], #2\n"
    "beq 4f\n"
    "cmp %[ic_4res], #3\n"
    "beq 5f\n"

    "2:\n"
    "ld1 {v0.s}[0], [x12], %[src_stride]\n"
    "ld1 {v0.s}[1], [x12], %[src_stride]\n"
    "ld1 {v0.s}[2], [x12], %[src_stride]\n"
    "ld1 {v0.s}[3], [x12], %[src_stride]\n"

    "st1 {v0.16b}, [x11], #16\n"

    "b 1b\n"

    "3:\n" /* ic res 1 */
    "dup v0.4s, wzr\n"

    "ld1 {v0.b}[0], [x12], %[src_stride]\n"
    "ld1 {v0.b}[4], [x12], %[src_stride]\n"
    "ld1 {v0.b}[8], [x12], %[src_stride]\n"
    "ld1 {v0.b}[12], [x12], %[src_stride]\n"

    "st1 {v0.16b}, [x11], #16\n"

    "b 6f\n"

    "4:\n" /* ic res 2 */
    "dup v0.4s, wzr\n"

    "ld1 {v0.h}[0], [x12], %[src_stride]\n"
    "ld1 {v0.h}[2], [x12], %[src_stride]\n"
    "ld1 {v0.h}[4], [x12], %[src_stride]\n"
    "ld1 {v0.h}[6], [x12], %[src_stride]\n"

    "st1 {v0.16b}, [x11], #16\n"

    "b 6f\n"

    "5:\n" /* ic res 3 */
    "dup v0.4s, wzr\n"
    "add x13, x12, #2\n"

    "ld1 {v0.h}[0], [x12], %[src_stride]\n"
    "ld1 {v0.b}[2], [x13], %[src_stride]\n"
    "ld1 {v0.h}[2], [x12], %[src_stride]\n"
    "ld1 {v0.b}[6], [x13], %[src_stride]\n"
    "ld1 {v0.h}[4], [x12], %[src_stride]\n"
    "ld1 {v0.b}[10], [x13], %[src_stride]\n"
    "ld1 {v0.h}[6], [x12], %[src_stride]\n"
    "ld1 {v0.b}[14], [x13], %[src_stride]\n"

    "st1 {v0.16b}, [x11], #16\n"

    "b 6f\n"

    "6:\n"

    :
    : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div),
      [ ic_4res ] "r"(ic_4res)
    : "x10", "x11", "x12", "x13", "x14", "x15", "v0", "v1", "v2", "v3");
}
#endif

void PackInput4x4(const int8_t *src_input, int8_t *packed_input, size_t input_channel, size_t plane_size) {
  int ic4 = UP_ROUND(input_channel, C4NUM);
  size_t hw_4div = plane_size / C4NUM * C4NUM;
  size_t ic_4div = input_channel / C4NUM * C4NUM;

  const int8_t *src_r = src_input;
  int8_t *pack_r = packed_input;
  /* per layer */
  for (int hwi = 0; hwi < hw_4div; hwi += C4NUM) {
    const int8_t *src_ic = src_r;
    int8_t *pack_ic = pack_r;
#ifdef ENABLE_ARM64
    PackInput4x4Asm(src_ic, pack_ic, ic_4div, input_channel);
#else
    for (int ici = 0; ici < ic_4div; ici += C4NUM) {
      for (size_t i = 0; i < C4NUM; i++) {
        pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel];
        pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel];
        pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel];
        pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel];
      }
      src_ic += C4NUM;
      pack_ic += C4NUM * C4NUM;
    }
    for (int ici = ic_4div; ici < input_channel; ici += 1) {
      for (int i = 0; i < C4NUM; i++) {
        pack_ic[i * C4NUM] = src_ic[i * input_channel];
      }
      src_ic += 1;
      pack_ic += 1;
    }

    for (int ici = input_channel; ici < ic4; ici += 1) {
      for (int i = 0; i < C4NUM; i++) {
        pack_ic[i * C4NUM] = 0;
      }
      pack_ic += 1;
    }
#endif
    src_r += input_channel * C4NUM;
    pack_r += ic4 * C4NUM;
  }

  if (hw_4div != plane_size) {
    memset(pack_r, 0, C4NUM * ic4);
    for (int hwi = hw_4div; hwi < plane_size; hwi += 1) {
      const int8_t *src_ic = src_r;
      int8_t *pack_ic = pack_r;
      for (int ici = 0; ici < ic_4div; ici += C4NUM) {
        pack_ic[0] = src_ic[0];
        pack_ic[1] = src_ic[1];
        pack_ic[2] = src_ic[2];
        pack_ic[3] = src_ic[3];
        src_ic += C4NUM;
        pack_ic += C4NUM * C4NUM;
      }
      src_r += input_channel;
      pack_r += C4NUM;
    }
  }
  return;
}

// For matmul input a transpose case
void PackInput2Col4x4(const int8_t *src_input, int8_t *packed_input, int row, int col, int row_stride) {
  const int row_tile = C4NUM;
  int row_align = UP_ROUND(row, row_tile);
  int row_div = row / row_tile * row_tile;
  const int row_res = row - row_div;

  const int col_tile = C4NUM;
  int col_div = col / col_tile * col_tile;
  const int col_res = col - col_div;

  const int8_t *src_ic = NULL;
  int8_t *packed_ic = NULL;
  for (int c = 0; c < col_div; c += C4NUM) {
    int r = 0;
    src_ic = src_input + c;
    packed_ic = packed_input + c * row_align;
#ifdef ENABLE_ARM64
    size_t row_stride_int64 = row_stride;
    asm volatile(
      "mov w10, %w[row]\n"
      "mov x11, %[src_ic]\n"
      "mov x12, %[packed_ic]\n"
      "cmp w10, wzr\n"
      "beq 1f\n"
      "2:\n"
      "subs w10, w10, #4\n"
      "ld1 {v0.s}[0], [x11], %[row_stride]\n"
      "ld1 {v1.s}[0], [x11], %[row_stride]\n"
      "ld1 {v0.s}[1], [x11], %[row_stride]\n"
      "ld1 {v1.s}[1], [x11], %[row_stride]\n"
      "zip1 v2.8b, v0.8b, v1.8b\n"
      "zip2 v3.8b, v0.8b, v1.8b\n"
      "zip1 v4.4h, v2.4h, v3.4h\n"
      "zip2 v5.4h, v2.4h, v3.4h\n"
      "st1 {v4.4h, v5.4h}, [x12], #16\n"

      "bgt 2b\n"
      "1:\n"

      :
      : [ src_ic ] "r"(src_ic), [ packed_ic ] "r"(packed_ic), [ row ] "r"(row_div), [ row_stride ] "r"(row_stride_int64)
      : "memory", "w10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12");
    packed_ic += C4NUM * row_div;
    src_ic += row_div * row_stride;
#else
    for (; r < row_div; r += C4NUM) {
      for (int i = 0; i < row_tile; i++) {
        packed_ic[0 * row_tile + i] = src_ic[i * row_stride + 0];
        packed_ic[1 * row_tile + i] = src_ic[i * row_stride + 1];
        packed_ic[2 * row_tile + i] = src_ic[i * row_stride + 2];
        packed_ic[3 * row_tile + i] = src_ic[i * row_stride + 3];
      }
      packed_ic += C16NUM;
      src_ic += row_tile * row_stride;
    }
#endif
    for (r = 0; r < row_res; ++r) {
      for (int i = 0; i < C4NUM; ++i) {
        packed_ic[i * row_tile + r] = src_ic[r * row_stride + i];
      }
    }
  }
  if (col_res == 0) {
    return;
  }
  src_ic = src_input + col_div;
  packed_ic = packed_input + row_align * col_div;
  for (int r = 0; r < row_div; r += row_tile) {
    for (int i = 0; i < col_res; ++i) {
      packed_ic[i * row_tile + 0] = src_ic[0 * row_stride + i];
      packed_ic[i * row_tile + 1] = src_ic[1 * row_stride + i];
      packed_ic[i * row_tile + 2] = src_ic[2 * row_stride + i];
      packed_ic[i * row_tile + 3] = src_ic[3 * row_stride + i];
    }
    src_ic += row_tile * row_stride;
    packed_ic += row_tile * col_tile;
  }

  for (int r = 0; r < row_res; ++r) {
    for (int c = 0; c < col_res; ++c) {
      packed_ic[c * row_tile + r] = src_ic[r * row_stride + c];
    }
  }
}

@@ -0,0 +1,44 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_NNACL_INT8_DYNAMIC_MATMUL_H_
#define MINDSPORE_NNACL_INT8_DYNAMIC_MATMUL_H_

#include <string.h>
#include "nnacl/op_base.h"
#include "nnacl/matmul_parameter.h"

#ifdef __cplusplus
extern "C" {
#endif
void PackInput2Col4x4(const int8_t *src_input, int8_t *packed_input, int row, int col, int row_stride);
void PackInput4x4(const int8_t *src_input, int8_t *packed_input, size_t input_channel, size_t plane_size);
void DynamicMatmul4x16x4AIWI(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col,
                             int deep, int deep16, size_t stride, int input_zp, float input_scale,
                             const float *filter_scale, const int filter_zp, bool filter_per_channel);
#ifdef ENABLE_ARM64
void DynamicMatmulSdot4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4, float *multi_scales,
                                 float *bias, size_t row, size_t col, size_t stride);
#else
void DynamicMatmul4x4x16AIWI(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col,
                             int deep4, size_t stride, float input_scale, const float *filter_scale,
                             bool filter_per_channel);
#endif
#ifdef __cplusplus
}
#endif

#endif // MINDSPORE_NNACL_INT8_DYNAMIC_MATMUL_H_
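
How these declarations are meant to fit together, as a sketch (illustrative, not part of the patch; the helper name, the deep4 = UP_ROUND(deep, 4) assumption, and the byte interpretation of the output stride are assumptions drawn from the kernels above): pack a 4-row strip of the int8 activations with PackInput4x4 (or PackInput2Col4x4 when A is transposed), then emit 4x16 tiles across the output columns.

#ifdef ENABLE_ARM64
// Illustrative driver for one packed 4-row strip of A (at most 4 valid rows).
static void MatmulStripExample(const int8_t *src_a, int8_t *pack_a, const int8_t *pack_b, float *out, size_t deep,
                               size_t deep4, float *scales, float *bias, size_t valid_row, size_t col,
                               size_t out_byte_stride) {
  PackInput4x4(src_a, pack_a, deep, valid_row);  // rows of A become 4x4 depth blocks
  for (size_t c = 0; c < col; c += 16) {
    size_t valid_col = (col - c) < 16 ? (col - c) : 16;
    DynamicMatmulSdot4x4x16AIWI(pack_a, pack_b + c * deep4, out + c, deep4, scales + c,
                                bias != NULL ? bias + c : NULL, valid_row, valid_col, out_byte_stride);
  }
}
#endif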

@@ -15,8 +15,8 @@
 */

#include "nnacl/int8/dynamic_quant_int8.h"

void CalculateMinMaxFp32(const float *data, int count, float *real_min, float *real_max) {
#ifndef PLATFORM_ARM64
  for (int i = 0; i < count; ++i) {
    if (data[i] < *real_min) {
      *real_min = data[i];

@@ -25,4 +25,40 @@ void CalculateMinMaxFp32(const float *data, int count, float *real_min, float *r
      *real_max = data[i];
    }
  }
#else
  int count_4 = DOWN_ROUND(count, C4NUM);
  asm volatile(
    "mov x4, %[data]\n"          // reload data
    "mov w5, %w[count_4]\n"      // reload count
    "ld1 {v31.4s}, [x4]\n"       // min
    "ld1 {v30.4s}, [x4], #16\n"  // max
    "subs w5, w5, #4\n"
    "ble MinMax\n"

    "LoopCount:\n"
    "ld1 {v0.4s}, [x4], #16\n"
    "fmin v31.4s, v31.4s, v0.4s\n"
    "fmax v30.4s, v30.4s, v0.4s\n"
    "subs w5, w5, #4\n"
    "bgt LoopCount\n"

    "MinMax:\n"
    "fminv s6, v31.4s\n"
    "fmaxv s7, v30.4s\n"

    "str s6, [%[real_min]]\n"
    "str s7, [%[real_max]]\n"

    :
    : [ data ] "r"(data), [ count_4 ] "r"(count_4), [ real_min ] "r"(real_min), [ real_max ] "r"(real_max)
    : "x4", "w5", "s6", "s7", "v0", "v30", "v31");
  for (int i = count_4; i < count; ++i) {
    if (data[i] < *real_min) {
      *real_min = data[i];
    }
    if (data[i] > *real_max) {
      *real_max = data[i];
    }
  }
#endif
}
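
// Usage sketch (illustrative, not part of the patch): the scalar path only
// tightens the values passed in, so callers seed *real_min / *real_max first;
// the ARM64 path seeds the running min/max from data[0..3], so it assumes
// count >= 4.
#ifdef DYNAMIC_QUANT_MINMAX_EXAMPLE  // hypothetical guard; example only
#include <float.h>
static void MinMaxExample(const float *data, int count) {
  float real_min = FLT_MAX;
  float real_max = -FLT_MAX;
  CalculateMinMaxFp32(data, count, &real_min, &real_max);
}
#endif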

@@ -330,38 +330,6 @@ void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int c
  return;
}
#endif

void DynamicMatmulInt8AIWI(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col,
                           int deep16, float input_scale, const float *filter_scale, size_t stride,
                           bool filter_per_channel) {
  /* *
   * row4x16-major * row16x4-major => (int8)row-major
   * support activation per-layer symmetric && weight per-layer/per-channel symmetric
   * */
  for (int r = 0; r < row; r++) {
    for (int c = 0; c < col; c++) {
      int r4div = r / C4NUM, r4mod = r % C4NUM;
      int c4div = c / C4NUM, c4mod = c % C4NUM;
      int filter_quant_index = filter_per_channel ? c : 0;
      double multi_scale = input_scale * filter_scale[filter_quant_index];
      double value = 0;
      for (int d = 0; d < deep16; d++) {
        int d16div = d / C16NUM, d16mod = d % C16NUM;
        size_t ai = r4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod;
        size_t bi = c4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod;
        int32_t value_1 = a[ai] * b[bi];
        value += multi_scale * value_1;
      }
      if (bias != NULL) {
        value += bias[c];
      }
      size_t ci = r * stride + c;
      dst[ci] = value;
    }
  }
  return;
}

void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
                      size_t stride, const int32_t *input_sum, const int32_t *bias, const int32_t *left_shift,
                      const int32_t *right_shift, const int32_t *multiplier, int32_t output_zp, int32_t mini,

@@ -45,11 +45,6 @@ void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int c
                   const int *bias, int act_min, int act_max, int out_zp, const int32_t *multiplier,
                   const int32_t *left_shift, const int32_t *right_shift, size_t stride, size_t filter_peroc,
                   const int32_t *filter_zp);

void DynamicMatmulInt8AIWI(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col,
                           int deep16, float input_scale, const float *filter_scale, size_t stride,
                           bool filter_per_channel);

/* 8x4 4x8 -> 8x8 */
/* optimize conv */
void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);

@@ -87,7 +87,7 @@ int DynamicQuantCPUKernel::CalculateMinMax(int task_id) {
  return RET_OK;
}

int CalculateMinMaxRun(void *cdata, int task_id, float, float) {
int CalculateMinMaxRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
  CHECK_NULL_RETURN(cdata);
  auto g_kernel = reinterpret_cast<DynamicQuantCPUKernel *>(cdata);
  auto ret = g_kernel->CalculateMinMax(task_id);

@@ -112,10 +112,16 @@ void DynamicQuantCPUKernel::ReduceMinMaxFp32() {

void DynamicQuantCPUKernel::CalculateScaleZp() {
  lite::LiteQuantParam quant_parm;
  double scale = (real_max_ - real_min_) / (INT8_MAX - INT8_MIN);
  double scale;
  int zp = 0;
  constexpr int kQAsymmetricRange = 255;
  constexpr int kQSymmetricRange = 254;
  if (!symmetric_) {
    scale = (real_max_ - real_min_) / kQAsymmetricRange;  // -128 ~ 127
    zp = static_cast<int>(std::round(INT8_MIN - real_min_ / scale));
  } else {
    auto max = std::max(std::fabs(real_max_), std::fabs(real_min_));
    scale = 2 * max / kQSymmetricRange;  // -127 ~ 127
  }
  quant_parm.scale = scale;
quant_parm.zeroPoint = zp;
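  // Worked example (illustrative numbers): with real_min_ = -2.0f and
  // real_max_ = 6.0f, the asymmetric branch gives scale = 8.0 / 255, about
  // 0.03137, and zp = round(-128 - (-2.0) / 0.03137) = round(-64.25) = -64,
  // so -2.0 quantizes to about -128 and 6.0 to about 127; the symmetric
  // branch would instead use scale = 2 * 6.0 / 254 and leave zp at 0.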

@@ -0,0 +1,240 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/runtime/kernel/arm/int8/matmul_dynamic_base_int8.h"

using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_MEMORY_FAILED;
using mindspore::lite::RET_OK;

namespace mindspore::kernel {
namespace {
constexpr int kHasBiasSize = 3;
constexpr int kMinInputSize = 2;
constexpr int kOutputSize = 1;
constexpr int kSize1 = 1;
constexpr int kSize2 = 2;
}  // namespace

MatmulDynamicBaseInt8CPUKernel::~MatmulDynamicBaseInt8CPUKernel() {
  FreeQuantParam();
  FreeTmpBuffer();
}

void MatmulDynamicBaseInt8CPUKernel::FreeQuantParam() {
  if (quant_param_ != nullptr) {
    if (quant_param_->filter_scale_ != nullptr) {
      free(quant_param_->filter_scale_);
      quant_param_->filter_scale_ = nullptr;
    }
    if (quant_param_->filter_zp_ != nullptr) {
      free(quant_param_->filter_zp_);
      quant_param_->filter_zp_ = nullptr;
    }
    free(quant_param_);
    quant_param_ = nullptr;
  }
}

int MatmulDynamicBaseInt8CPUKernel::MallocQuantParam() {
  quant_param_ = reinterpret_cast<MatmulDynamicQuantParameter *>(malloc(sizeof(MatmulQuantParameter)));
  if (quant_param_ == nullptr) {
    MS_LOG(ERROR) << "Malloc MatmulDynamicQuantParameter for Matmul int8 op failed!";
    return RET_ERROR;
  }
  memset(quant_param_, 0, sizeof(MatmulQuantParameter));
  return RET_OK;
}

int MatmulDynamicBaseInt8CPUKernel::InitFilterQuantParam() {
  if (quant_param_->filter_scale_ != nullptr) {
    free(quant_param_->filter_scale_);
    quant_param_->filter_scale_ = nullptr;
  }
  if (quant_param_->filter_zp_ != nullptr) {
    free(quant_param_->filter_zp_);
    quant_param_->filter_zp_ = nullptr;
  }

  auto weight_tensor = in_tensors_.at(kWeightIndex);
  auto weight_quant_params = weight_tensor->quant_params();
  auto w_shape = weight_tensor->shape();
  if (w_shape.size() < DIMENSION_2D) {
    MS_LOG(ERROR) << weight_tensor->tensor_name() << " dims < 2.";
    return RET_ERROR;
  }
  int col = param_->b_transpose_ ? w_shape[w_shape.size() - kSize2] : w_shape[w_shape.size() - kSize1];
  filter_per_channel_ = (weight_quant_params.size() > 1);
  channel_num_ = filter_per_channel_ ? col : 1;
  if (static_cast<int>(weight_quant_params.size()) != channel_num_) {
    MS_LOG(ERROR) << weight_tensor->tensor_name() << " quant params size:" << weight_quant_params.size()
                  << " != channel_num_:" << channel_num_;
    return RET_ERROR;
  }
  quant_param_->filter_scale_ = reinterpret_cast<float *>(malloc(channel_num_ * sizeof(float)));
  CHECK_NULL_RETURN(quant_param_->filter_scale_);
  memset(quant_param_->filter_scale_, 0, channel_num_ * sizeof(float));  // zero the whole array
  quant_param_->filter_zp_ = reinterpret_cast<int32_t *>(malloc(channel_num_ * sizeof(int32_t)));
  CHECK_NULL_RETURN(quant_param_->filter_zp_);
  memset(quant_param_->filter_zp_, 0, channel_num_ * sizeof(int32_t));

  for (int i = 0; i < channel_num_; i++) {
    quant_param_->filter_scale_[i] = static_cast<float>(weight_quant_params[i].scale);
    quant_param_->filter_zp_[i] = weight_quant_params[i].zeroPoint;
  }
  return RET_OK;
}

void MatmulDynamicBaseInt8CPUKernel::ResizeParameter() {
  param_->row_align_ = UP_ROUND(param_->row_, row_tile_);
  param_->col_align_ = UP_ROUND(param_->col_, col_tile_);
  param_->deep_align_ = UP_ROUND(param_->deep_, deep_tile_);

  thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(param_->col_align_, col_tile_));
  thread_stride_ = UP_DIV(UP_DIV(param_->col_align_, col_tile_), thread_count_);
  return;
}
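
// Worked example of the split above (illustrative numbers): with col_ = 37 and
// col_tile_ = 4, col_align_ = 40 and there are UP_DIV(40, 4) = 10 tile
// columns; with thread_num_ = 4 this gives thread_count_ = 4 and
// thread_stride_ = UP_DIV(10, 4) = 3, i.e. each task handles up to
// 3 * 4 = 12 output columns and the last task takes whatever remains.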

void MatmulDynamicBaseInt8CPUKernel::FreeTmpBuffer() {
  if (pack_a_ptr_ != nullptr) {
    free(pack_a_ptr_);
    pack_a_ptr_ = nullptr;
  }
  if (pack_b_ptr_ != nullptr) {
    free(pack_b_ptr_);
    pack_b_ptr_ = nullptr;
  }
  return;
}

int MatmulDynamicBaseInt8CPUKernel::InitInputQuantParam() {
  auto in_quant_params = in_tensors_.at(kInputIndex)->quant_params();
  if (in_quant_params.empty()) {
    MS_LOG(ERROR) << "invalid in quant param";
    return RET_ERROR;
  }
  quant_param_->input_zp_ = in_quant_params.front().zeroPoint;
  quant_param_->input_scale_ = static_cast<float>(in_quant_params.front().scale);
  return RET_OK;
}

int MatmulDynamicBaseInt8CPUKernel::TransferB() {
  auto weight_data = reinterpret_cast<int8_t *>(in_tensors_.at(kWeightIndex)->data());
  CHECK_NULL_RETURN(weight_data);
  memset(pack_b_ptr_, quant_param_->filter_zp_[0],
         param_->batch * param_->col_align_ * param_->deep_align_ * sizeof(int8_t));
  for (int i = 0; i < param_->batch; i++) {
    auto current_weight = weight_data + i * param_->deep_ * param_->col_;
    auto current_b_pack = pack_b_ptr_ + i * param_->col_align_ * param_->deep_align_;
    CHECK_NULL_RETURN(b_pack_func_);
    if (param_->b_transpose_) {
      b_pack_func_(current_weight, current_b_pack, param_->col_, param_->deep_);
    } else {
      b_pack_func_(current_weight, current_b_pack, param_->deep_, param_->col_);
    }
  }
  return RET_OK;
}

int MatmulDynamicBaseInt8CPUKernel::InitTmpBuffer() {
  pack_a_ptr_ = reinterpret_cast<int8_t *>(malloc(param_->row_align_ * param_->deep_align_ * sizeof(int8_t)));
  if (pack_a_ptr_ == nullptr) {
    FreeTmpBuffer();
    return RET_ERROR;
  }
  pack_b_ptr_ =
    reinterpret_cast<int8_t *>(malloc(param_->batch * param_->col_align_ * param_->deep_align_ * sizeof(int8_t)));
  if (pack_b_ptr_ == nullptr) {
    FreeTmpBuffer();
    return RET_ERROR;
  }
  memset(pack_a_ptr_, 0, param_->row_align_ * param_->deep_align_ * sizeof(int8_t));
  memset(pack_b_ptr_, 0, param_->batch * param_->col_align_ * param_->deep_align_ * sizeof(int8_t));
  return RET_OK;
}

int MatmulDynamicBaseInt8CPUKernel::CopyBias() {
  if (in_tensors_.size() == kHasBiasSize) {
    auto bias_tensor = in_tensors_[kBiasIndex];
    fp32_bias_ptr_ = reinterpret_cast<float *>(bias_tensor->data());  // use the bias tensor's fp32 buffer directly
    if (fp32_bias_ptr_ == nullptr) {
      MS_LOG(ERROR) << "bias data is nullptr";
      FreeTmpBuffer();
      return RET_MEMORY_FAILED;
    }
  } else {
    fp32_bias_ptr_ = nullptr;
  }
  return RET_OK;
}

int MatmulDynamicBaseInt8CPUKernel::Prepare() {
  CHECK_LESS_RETURN(in_tensors_.size(), kMinInputSize);
  CHECK_LESS_RETURN(out_tensors_.size(), kOutputSize);
  InitParameter();
  auto ret = MallocQuantParam();
  if (ret != RET_OK) {
    FreeQuantParam();
    return ret;
  }
  if (param_->b_const_) {
    ret = InitFilterQuantParam();
    if (ret != RET_OK) {
      FreeQuantParam();
      return ret;
    }
  }
  ret = CopyBias();
  if (ret != RET_OK) {
    FreeQuantParam();
    return ret;
  }
  if (!InferShapeDone()) {
    return RET_OK;
  }
  return ReSize();
}

int MatmulDynamicBaseInt8CPUKernel::ReSize() {
  int batch = 1;
  auto x_shape = in_tensors_.at(0)->shape();
  auto o_shape = out_tensors_.at(0)->shape();
  MS_ASSERT(x_shape.size() >= kSize2);
  for (size_t i = 0; i < x_shape.size() - kSize2; ++i) {
    batch *= x_shape[i];
  }
  param_->batch = batch;
  MS_ASSERT(o_shape.size() >= kSize2);
  param_->row_ = o_shape[o_shape.size() - kSize2];
  param_->col_ = o_shape[o_shape.size() - kSize1];
  param_->deep_ = param_->a_transpose_ ? x_shape[x_shape.size() - kSize2] : x_shape[x_shape.size() - kSize1];

  FreeTmpBuffer();

  ResizeParameter();

  auto ret = InitTmpBuffer();
  if (ret != RET_OK) {
    FreeQuantParam();
    return ret;
  }
  if (param_->b_const_ == true) {
    TransferB();
  }
  return RET_OK;
}
}  // namespace mindspore::kernel

@@ -0,0 +1,79 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MATMUL_DYNAMIC_BASE_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MATMUL_DYNAMIC_BASE_INT8_H_

#include <vector>
#include "include/errorcode.h"
#include "include/context.h"
#include "src/inner_kernel.h"
#include "nnacl/matmul_parameter.h"
#include "nnacl/common_func.h"
#include "nnacl/int8/quantize.h"
#include "nnacl/int8/common_func_int8.h"

namespace mindspore::kernel {
class MatmulDynamicBaseInt8CPUKernel : public InnerKernel {
 public:
  MatmulDynamicBaseInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                                 const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
      : InnerKernel(parameter, inputs, outputs, ctx) {
    param_ = reinterpret_cast<MatMulParameter *>(op_parameter_);
  }
  ~MatmulDynamicBaseInt8CPUKernel() override;
  int Prepare() override;
  int ReSize() override;

 private:
  void ResizeParameter();
  int CopyBias();
  int InitTmpBuffer();

  int MallocQuantParam();

 protected:
  typedef void (*PackFunc)(const int8_t *src, int8_t *dst, int row, int col);
  virtual void InitParameter() = 0;
  int TransferA();
  int InitInputQuantParam();
  int InitFilterQuantParam();
  int TransferB();
  void FreeTmpBuffer();
  void FreeQuantParam();

 protected:
  MatMulParameter *param_ = nullptr;
  MatmulDynamicQuantParameter *quant_param_ = nullptr;
  int thread_count_ = 1;
  int thread_stride_ = 0;
  int8_t *pack_a_ptr_ = nullptr;
  int8_t *pack_b_ptr_ = nullptr;
  float *fp32_bias_ptr_ = nullptr;
  bool filter_per_channel_ = true;
  int8_t *batch_input_ptr_ = nullptr;
  int8_t *batch_weight_ptr_ = nullptr;
  int8_t *batch_b_ptr_ = nullptr;
  float *batch_c_ptr_ = nullptr;
  int row_tile_ = C4NUM;
  int col_tile_ = C4NUM;
  int deep_tile_ = C16NUM;
  int channel_num_ = 0;
  PackFunc b_pack_func_{nullptr};
};
}  // namespace mindspore::kernel

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MATMUL_DYNAMIC_BASE_INT8_H_

@@ -16,7 +16,9 @@

#include "src/runtime/kernel/arm/int8/matmul_dynamic_int8.h"
#include "src/runtime/kernel/arm/int8/opt_op_handler.h"
#include "src/common/file_utils.h"
#include "nnacl/int8/matmul_int8.h"
#include "nnacl/int8/dynamic_matmul_int8.h"

using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_MEMORY_FAILED;

@@ -24,13 +26,6 @@ using mindspore::lite::RET_OK;

namespace mindspore::kernel {
namespace {
constexpr int kHasBiasSize = 3;
constexpr int kMinInputSize = 2;
constexpr int kOutputSize = 1;
constexpr int kSize1 = 1;
constexpr int kSize2 = 2;
}  // namespace

int MatmulDynamicInt8Run(void *cdata, int task_id, float, float) {
  CHECK_NULL_RETURN(cdata);
  auto op = reinterpret_cast<MatmulDynamicInt8CPUKernel *>(cdata);

@@ -41,6 +36,7 @@ int MatmulDynamicInt8Run(void *cdata, int task_id, float, float) {
  }
  return RET_OK;
}
}  // namespace

int MatmulDynamicInt8CPUKernel::RunImpl(int task_id) {
  int stride = thread_stride_ * col_tile_;

@@ -55,92 +51,14 @@ int MatmulDynamicInt8CPUKernel::RunImpl(int task_id) {
    bias_ptr += cur_stride;
  }
  float *filter_scale = quant_param_->filter_scale_;
  int32_t filter_zp = quant_param_->filter_zp_[0];
  if (filter_per_channel_) {
    filter_scale += cur_stride;
  }
  DynamicMatmulInt8AIWI(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, bias_ptr,
                        batch_c_ptr_ + cur_stride, param_->row_, cur_oc, param_->deep_align_,
                        quant_param_->input_scale_, filter_scale, param_->col_, filter_per_channel_);
  return RET_OK;
}

MatmulDynamicInt8CPUKernel::~MatmulDynamicInt8CPUKernel() {
  FreeQuantParam();
  FreeTmpBuffer();
}

void MatmulDynamicInt8CPUKernel::FreeQuantParam() {
  if (quant_param_ != nullptr) {
    if (quant_param_->filter_scale_ != nullptr) {
      free(quant_param_->filter_scale_);
      quant_param_->filter_scale_ = nullptr;
    }
    if (quant_param_->filter_zp_ != nullptr) {
      free(quant_param_->filter_zp_);
      quant_param_->filter_zp_ = nullptr;
    }
    free(quant_param_);
    quant_param_ = nullptr;
  }
}

int MatmulDynamicInt8CPUKernel::MallocQuantParam() {
  quant_param_ = reinterpret_cast<MatmulDynamicQuantParameter *>(malloc(sizeof(MatmulQuantParameter)));
  if (quant_param_ == nullptr) {
    MS_LOG(ERROR) << "Malloc MatmulDynamicQuantParameter for Matmul int8 op failed!";
    return RET_ERROR;
  }
  memset(quant_param_, 0, sizeof(MatmulQuantParameter));
  return RET_OK;
}

int MatmulDynamicInt8CPUKernel::InitFilterQuantParam() {
  if (quant_param_->filter_scale_ != nullptr) {
    free(quant_param_->filter_scale_);
    quant_param_->filter_scale_ = nullptr;
  }
  if (quant_param_->filter_zp_ != nullptr) {
    free(quant_param_->filter_zp_);
    quant_param_->filter_zp_ = nullptr;
  }

  auto weight_tensor = in_tensors_.at(kWeightIndex);
  auto weight_quant_params = weight_tensor->quant_params();
  auto w_shape = weight_tensor->shape();
  if (w_shape.size() < DIMENSION_2D) {
    MS_LOG(ERROR) << weight_tensor->tensor_name() << " dims < 2.";
    return RET_ERROR;
  }
  int col = param_->b_transpose_ ? w_shape[w_shape.size() - kSize2] : w_shape[w_shape.size() - kSize1];
  filter_per_channel_ = (weight_quant_params.size() > 1);
  channel_num_ = filter_per_channel_ ? col : 1;
  if (static_cast<int>(weight_quant_params.size()) != channel_num_) {
    MS_LOG(ERROR) << weight_tensor->tensor_name() << " quant params size:" << weight_quant_params.size()
                  << " != channel_num_:" << channel_num_;
    return RET_ERROR;
  }
  quant_param_->filter_scale_ = reinterpret_cast<float *>(malloc(channel_num_ * sizeof(float)));
  CHECK_NULL_RETURN(quant_param_->filter_scale_);
  memset(quant_param_->filter_scale_, 0, sizeof(channel_num_));
  quant_param_->filter_zp_ = reinterpret_cast<int32_t *>(malloc(channel_num_ * sizeof(int32_t)));
  CHECK_NULL_RETURN(quant_param_->filter_zp_);
  memset(quant_param_->filter_zp_, 0, sizeof(channel_num_));

  for (int i = 0; i < channel_num_; i++) {
    quant_param_->filter_scale_[i] = static_cast<float>(weight_quant_params[i].scale);
    quant_param_->filter_zp_[i] = weight_quant_params[i].zeroPoint;
  }
  return RET_OK;
}

int MatmulDynamicInt8CPUKernel::InitInputQuantParam() {
  auto in_quant_params = in_tensors_.at(kInputIndex)->quant_params();
  if (in_quant_params.empty()) {
    MS_LOG(ERROR) << "invalid in quant param";
    return RET_ERROR;
  }
  quant_param_->input_zp_ = in_quant_params.front().zeroPoint;
  quant_param_->input_scale_ = static_cast<float>(in_quant_params.front().scale);
  DynamicMatmul4x16x4AIWI(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, bias_ptr,
                          batch_c_ptr_ + cur_stride, param_->row_, cur_oc, param_->deep_, param_->deep_align_,
                          param_->col_, quant_param_->input_zp_, quant_param_->input_scale_, filter_scale, filter_zp,
                          filter_per_channel_);
  return RET_OK;
}

@@ -163,133 +81,6 @@ void MatmulDynamicInt8CPUKernel::InitParameter() {
  return;
}

void MatmulDynamicInt8CPUKernel::ResizeParameter() {
  param_->row_align_ = UP_ROUND(param_->row_, row_tile_);
  param_->col_align_ = UP_ROUND(param_->col_, col_tile_);
  param_->deep_align_ = UP_ROUND(param_->deep_, deep_tile_);

  thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(param_->col_align_, col_tile_));
  thread_stride_ = UP_DIV(UP_DIV(param_->col_align_, col_tile_), thread_count_);
  return;
}

void MatmulDynamicInt8CPUKernel::FreeTmpBuffer() {
  if (pack_a_ptr_ != nullptr) {
    free(pack_a_ptr_);
    pack_a_ptr_ = nullptr;
  }
  if (pack_b_ptr_ != nullptr) {
    free(pack_b_ptr_);
    pack_b_ptr_ = nullptr;
  }
  return;
}

int MatmulDynamicInt8CPUKernel::TransferB() {
  auto weight_data = reinterpret_cast<int8_t *>(in_tensors_.at(kWeightIndex)->data());
  CHECK_NULL_RETURN(weight_data);
  for (int i = 0; i < param_->batch; i++) {
    auto current_weight = weight_data + i * param_->deep_ * param_->col_;
    auto current_b_pack = pack_b_ptr_ + i * param_->col_align_ * param_->deep_align_;
    CHECK_NULL_RETURN(b_pack_func_);
    if (param_->b_transpose_) {
      b_pack_func_(current_weight, current_b_pack, param_->col_, param_->deep_);
    } else {
      b_pack_func_(current_weight, current_b_pack, param_->deep_, param_->col_);
    }
  }
  return RET_OK;
}

int MatmulDynamicInt8CPUKernel::InitTmpBuffer() {
  pack_a_ptr_ = reinterpret_cast<int8_t *>(malloc(param_->row_align_ * param_->deep_align_ * sizeof(int8_t)));
  if (pack_a_ptr_ == nullptr) {
    FreeTmpBuffer();
    return RET_ERROR;
  }
  pack_b_ptr_ =
    reinterpret_cast<int8_t *>(malloc(param_->batch * param_->col_align_ * param_->deep_align_ * sizeof(int8_t)));
  if (pack_b_ptr_ == nullptr) {
    FreeTmpBuffer();
    return RET_ERROR;
  }
  memset(pack_a_ptr_, 0, param_->row_align_ * param_->deep_align_ * sizeof(int8_t));
  memset(pack_b_ptr_, 0, param_->batch * param_->col_align_ * param_->deep_align_ * sizeof(int8_t));
  return RET_OK;
}

int MatmulDynamicInt8CPUKernel::CopyBias() {
  if (in_tensors_.size() == kHasBiasSize) {
    auto bias_tensor = in_tensors_[kBiasIndex];
    fp32_bias_ptr_ = reinterpret_cast<float *>(bias_tensor->data());
    if (fp32_bias_ptr_ == nullptr) {
      MS_LOG(ERROR) << "Memory allocation failed";
      FreeTmpBuffer();
      return RET_MEMORY_FAILED;
    }
    memcpy(fp32_bias_ptr_, bias_tensor->data(), bias_tensor->ElementsNum() * sizeof(float));
  } else {
    fp32_bias_ptr_ = nullptr;
  }
  return RET_OK;
}

int MatmulDynamicInt8CPUKernel::Prepare() {
  CHECK_LESS_RETURN(in_tensors_.size(), kMinInputSize);
  CHECK_LESS_RETURN(out_tensors_.size(), kOutputSize);
  InitParameter();
  auto ret = MallocQuantParam();
  if (ret != RET_OK) {
    FreeQuantParam();
    return ret;
  }
  if (param_->b_const_) {
    ret = InitFilterQuantParam();
    if (ret != RET_OK) {
      FreeQuantParam();
      return ret;
    }
  }
  ret = CopyBias();
  if (ret != RET_OK) {
    FreeQuantParam();
    return ret;
  }
  if (!InferShapeDone()) {
    return RET_OK;
  }
  return ReSize();
}

int MatmulDynamicInt8CPUKernel::ReSize() {
  int batch = 1;
  auto x_shape = in_tensors_.at(0)->shape();
  auto o_shape = out_tensors_.at(0)->shape();
  MS_ASSERT(x_shape.size() >= kSize2);
  for (size_t i = 0; i < x_shape.size() - kSize2; ++i) {
    batch *= x_shape[i];
  }
  param_->batch = batch;
  MS_ASSERT(o_shape.size() >= kSize2);
  param_->row_ = o_shape[o_shape.size() - kSize2];
  param_->col_ = o_shape[o_shape.size() - kSize1];
  param_->deep_ = param_->a_transpose_ ? x_shape[x_shape.size() - kSize2] : x_shape[x_shape.size() - kSize1];

  FreeTmpBuffer();

  ResizeParameter();

  auto ret = InitTmpBuffer();
  if (ret != RET_OK) {
    FreeQuantParam();
    return ret;
  }
  if (param_->b_const_ == true) {
    TransferB();
  }
  return RET_OK;
}

int MatmulDynamicInt8CPUKernel::Run() {
  auto ret = InitInputQuantParam();
  if (ret != RET_OK) {

@@ -314,6 +105,7 @@ int MatmulDynamicInt8CPUKernel::Run() {
  CHECK_NULL_RETURN(a_ptr);
  CHECK_NULL_RETURN(c_ptr);
  for (int i = 0; i < param_->batch; i++) {
    memset(pack_a_ptr_, quant_param_->input_zp_, param_->row_align_ * param_->deep_align_ * sizeof(int8_t));
    auto current_src_a = a_ptr + i * param_->row_ * param_->deep_;
    if (param_->a_transpose_) {
      MS_CHECK_TRUE_RET(a_pack_func_ != nullptr, RET_ERROR);

@@ -18,72 +18,25 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MATMUL_DYNAMIC_INT8_H_

#include <vector>
#include "include/errorcode.h"
#include "include/context.h"
#include "src/inner_kernel.h"
#include "nnacl/matmul_parameter.h"
#include "nnacl/common_func.h"
#include "nnacl/int8/quantize.h"
#include "nnacl/int8/common_func_int8.h"
#include "nnacl/int8/matmul_int8.h"
#include "src/runtime/kernel/arm/int8/matmul_dynamic_base_int8.h"

namespace mindspore::kernel {
class MatmulDynamicInt8CPUKernel : public InnerKernel {
typedef void (*PackFunc)(const int8_t *src, int8_t *dst, int row, int col);

class MatmulDynamicInt8CPUKernel : public MatmulDynamicBaseInt8CPUKernel {
 public:
  MatmulDynamicInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                             const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
      : InnerKernel(parameter, inputs, outputs, ctx) {
    param_ = reinterpret_cast<MatMulParameter *>(op_parameter_);
  }
  ~MatmulDynamicInt8CPUKernel() override;
  int Prepare() override;
  int ReSize() override;
      : MatmulDynamicBaseInt8CPUKernel(parameter, inputs, outputs, ctx) {}
  ~MatmulDynamicInt8CPUKernel() override = default;
  int Run() override;

 public:
  int RunImpl(int task_id);
#if defined(ENABLE_ARM64) && !defined(SUPPORT_NNIE) && (!defined(MACHINE_LINUX_ARM64))
  int RunArm64Sdot();
  int Arm64SdotImpl(int task_id);
  int Arm64SdotPre(int task_id);
#endif

 private:
  void InitParameter();
  void ResizeParameter();
  int CopyBias();
  int InitTmpBuffer();
  void FreeTmpBuffer();
  int TransferA();
  int TransferB();

  int MallocQuantParam();
  int InitInputQuantParam();
  int InitFilterQuantParam();
  void FreeQuantParam();
  void InitParameter() override;

 private:
  MatMulParameter *param_ = nullptr;
  MatmulDynamicQuantParameter *quant_param_ = nullptr;
  int thread_count_ = 1;
  int thread_stride_ = 0;
  int8_t *pack_a_ptr_ = nullptr;
  int8_t *pack_b_ptr_ = nullptr;
  float *fp32_bias_ptr_ = nullptr;
  bool filter_per_channel_ = true;
  int8_t *batch_input_ptr_ = nullptr;
  int8_t *batch_weight_ptr_ = nullptr;
  int8_t *batch_b_ptr_ = nullptr;
  float *batch_c_ptr_ = nullptr;
  int row_tile_ = C4NUM;
  int col_tile_ = C4NUM;
  int deep_tile_ = C16NUM;
  int channel_num_ = 0;
  bool support_sdot_ = false;
  PackFunc a_pack_func_{nullptr};
  PackFunc b_pack_func_{nullptr};
};
}  // namespace mindspore::kernel

@@ -0,0 +1,184 @@
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/arm/int8/matmul_dynamic_sdot_int8.h"
|
||||
#include <vector>
|
||||
#include "src/common/file_utils.h"
|
||||
#include "nnacl/int8/dynamic_matmul_int8.h"
|
||||
#include "nnacl/int8/matmul_int8.h"
|
||||
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_MEMORY_FAILED;
|
||||
using mindspore::lite::RET_OK;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
namespace {
|
||||
int Arm64SdotPreRun(void *cdata, int task_id, float, float) {
|
||||
CHECK_NULL_RETURN(cdata);
|
||||
auto op = reinterpret_cast<MatMulDynamicSdotInt8Kernel *>(cdata);
|
||||
auto ret = op->MatMulDynamicArm64SdotPre(task_id);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "MatmulInt8Run error task_id[" << task_id << "] error_code[" << ret << "]";
|
||||
return ret;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int Arm64SdotRun(void *cdata, int task_id, float, float) {
|
||||
CHECK_NULL_RETURN(cdata);
|
||||
auto op = reinterpret_cast<MatMulDynamicSdotInt8Kernel *>(cdata);
|
||||
auto ret = op->MatMulDynamicArm64SdotImpl(task_id);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "MatmulInt8Run error task_id[" << task_id << "] error_code[" << ret << "]";
|
||||
return ret;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
int MatMulDynamicSdotInt8Kernel::MatMulDynamicArm64SdotPre(int task_id) {
|
||||
int row_thread_count = MSMIN(op_parameter_->thread_num_, UP_DIV(param_->row_align_, row_tile_));
|
||||
int row_stride = UP_DIV(UP_DIV(param_->row_align_, row_tile_), row_thread_count) * row_tile_;
|
||||
|
||||
int row_current_stride = task_id * row_stride;
|
||||
int row_res_stride = param_->row_ - row_current_stride;
|
||||
int cur_r = MSMIN(row_res_stride, row_stride);
|
||||
if (cur_r <= 0) {
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
auto current_a_pack = pack_a_ptr_ + row_current_stride * param_->deep_align_;
|
||||
if (param_->a_transpose_) {
|
||||
auto current_src_a = batch_input_ptr_ + row_current_stride;
|
||||
PackInput2Col4x4(current_src_a, current_a_pack, param_->deep_, cur_r, param_->row_);
|
||||
} else {
|
||||
auto current_src_a = batch_input_ptr_ + row_current_stride * param_->deep_;
|
||||
PackInput4x4(current_src_a, current_a_pack, param_->deep_, cur_r);
|
||||
}
|
||||
return RET_OK;
|
||||
}
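As a reading aid, here is a minimal standalone sketch (not part of the patch; values and names are illustrative) of the row-splitting arithmetic above, with the UP_DIV and MSMIN macros expanded inline: rows are carved into row_tile_-aligned slices, and any task whose slice starts past the real row count simply returns.

// Minimal sketch (not from the patch) of per-task row partitioning.
#include <algorithm>
#include <cstdio>

int main() {
  const int row = 10, row_align = 12, row_tile = 4, thread_num = 4;
  const int tiles = (row_align + row_tile - 1) / row_tile;      // UP_DIV(row_align_, row_tile_)
  const int threads = std::min(thread_num, tiles);              // MSMIN(thread_num_, tiles)
  const int row_stride = (tiles + threads - 1) / threads * row_tile;
  for (int task_id = 0; task_id < thread_num; ++task_id) {
    const int start = task_id * row_stride;
    const int cur_r = std::min(row - start, row_stride);        // clamp to the real row count
    if (cur_r <= 0) continue;                                   // surplus task: nothing to pack
    std::printf("task %d packs rows [%d, %d)\n", task_id, start, start + cur_r);
  }
  return 0;  // prints [0,4), [4,8), [8,10); task 3 is idle
}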

int MatMulDynamicSdotInt8Kernel::MatMulDynamicArm64SdotImpl(int task_id) {
#if defined(ENABLE_ARM64) && !defined(SUPPORT_NNIE) && (!defined(MACHINE_LINUX_ARM64))
  // Multi-thread split by col.
  int stride = thread_stride_ * col_tile_;
  int cur_stride = task_id * stride;
  int res_stride = param_->col_ - cur_stride;
  int cur_oc = MSMIN(stride, res_stride);
  if (cur_oc <= 0) {
    return RET_OK;
  }
  if (!param_->b_const_) {
    auto current_b_pack = batch_b_ptr_ + cur_stride * param_->deep_align_;
    if (param_->b_transpose_) {
      auto current_weight = batch_weight_ptr_ + cur_stride * param_->deep_;
      RowMajor2Row4x16MajorInt8(current_weight, current_b_pack, cur_oc, param_->deep_);
    } else {
      auto current_weight = batch_weight_ptr_ + cur_stride;
      RowMajor2Col4x16MajorPartInt8(current_weight, current_b_pack, param_->deep_, param_->col_, cur_oc);
    }
  }

  std::vector<float> multi_scale(cur_oc);
  for (int i = 0; i < cur_oc; ++i) {
    if (!param_->b_const_) {
      multi_scale[i] = quant_param_->input_scale_ * quant_param_->filter_scale_[0];
    } else {
      multi_scale[i] = quant_param_->input_scale_ * quant_param_->filter_scale_[cur_stride + i];
    }
  }
  for (int r = 0; r < param_->row_; r += C4NUM) {
    size_t row = MSMIN(C4NUM, param_->row_ - r);
    auto a_ptr = pack_a_ptr_ + r * param_->deep_align_;
    for (int c = 0; c < cur_oc; c += C16NUM) {
      size_t col = MSMIN(C16NUM, cur_oc - c);
      auto col_offset = cur_stride + c;
      auto b_ptr = batch_b_ptr_ + col_offset * param_->deep_align_;
      auto out_ptr = batch_c_ptr_ + r * param_->col_ + col_offset;
      auto bias = fp32_bias_ptr_;
      if (bias != nullptr) {
        bias += col_offset;
      }
      DynamicMatmulSdot4x4x16AIWI(a_ptr, b_ptr, out_ptr, param_->deep_align_, multi_scale.data() + c, bias, row, col,
                                  param_->col_ * sizeof(float));
    }
  }
#endif
  return RET_OK;
}
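For reference, a minimal standalone sketch (not part of the patch; the function name is illustrative) of the multi_scale values the loop above hands to DynamicMatmulSdot4x4x16AIWI: each int32 accumulator column is rescaled by input_scale * filter_scale, per channel when the const weight tensor carries per-channel scales, otherwise with the single shared scale at index 0.

// Minimal sketch (not from the patch) of the combined dequantize factor.
#include <cstdio>
#include <vector>

std::vector<float> BuildMultiScale(float input_scale, const std::vector<float> &filter_scale,
                                   bool b_const, int col_start, int col_count) {
  std::vector<float> out(col_count);
  for (int i = 0; i < col_count; ++i) {
    out[i] = b_const ? input_scale * filter_scale[col_start + i]  // per-channel weight scales
                     : input_scale * filter_scale[0];             // single per-tensor scale
  }
  return out;
}

int main() {
  auto scales = BuildMultiScale(0.02f, {0.1f, 0.2f, 0.3f, 0.4f}, true, 1, 2);
  std::printf("%f %f\n", scales[0], scales[1]);  // 0.004000 0.006000
  return 0;
}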

void MatMulDynamicSdotInt8Kernel::InitParameter() {
  param_->a_const_ = (in_tensors_[0]->data() != nullptr);
  param_->b_const_ = (in_tensors_[1]->data() != nullptr);

  row_tile_ = C4NUM;
  col_tile_ = C16NUM;
  deep_tile_ = C4NUM;

  if (param_->b_transpose_) {
    b_pack_func_ = RowMajor2Row4x16MajorInt8;
  } else {
    b_pack_func_ = RowMajor2Col4x16MajorInt8;
  }
  return;
}

int MatMulDynamicSdotInt8Kernel::MatMulDynamicRunArm64Sdot() {
  int8_t *a_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(0)->data());
  int8_t *b_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(1)->data());
  float *c_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->data());
  CHECK_NULL_RETURN(a_ptr);
  CHECK_NULL_RETURN(b_ptr);
  CHECK_NULL_RETURN(c_ptr);

  for (int i = 0; i < param_->batch; i++) {
    batch_input_ptr_ = a_ptr + i * param_->row_ * param_->deep_;
    auto ret = ParallelLaunch(this->ms_context_, Arm64SdotPreRun, this, op_parameter_->thread_num_);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Arm64SdotPreRun error: [" << ret << "]";
      return ret;
    }

    batch_weight_ptr_ = b_ptr + i * param_->col_ * param_->deep_;
    batch_b_ptr_ = pack_b_ptr_ + i * param_->col_align_ * param_->deep_align_;
    batch_c_ptr_ = c_ptr + i * param_->row_ * param_->col_;

    ret = ParallelLaunch(this->ms_context_, Arm64SdotRun, this, thread_count_);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Arm64SdotRun error: [" << ret << "]";
      return ret;
    }
  }
  return RET_OK;
}

int MatMulDynamicSdotInt8Kernel::Run() {
  auto ret = InitInputQuantParam();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Init input quant param failed.";
    return ret;
  }
  if (!param_->b_const_) {
    ret = InitFilterQuantParam();
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Init filter quant param failed.";
      FreeQuantParam();
      return ret;
    }
  }
  return MatMulDynamicRunArm64Sdot();
}
}  // namespace mindspore::kernel

@ -0,0 +1,42 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MATMUL_DYNAMIC_SDOT_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MATMUL_DYNAMIC_SDOT_INT8_H_

#include <vector>
#include "src/runtime/kernel/arm/int8/matmul_dynamic_base_int8.h"

namespace mindspore::kernel {
class MatMulDynamicSdotInt8Kernel : public MatmulDynamicBaseInt8CPUKernel {
 public:
  MatMulDynamicSdotInt8Kernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                              const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
      : MatmulDynamicBaseInt8CPUKernel(parameter, inputs, outputs, ctx) {}
  ~MatMulDynamicSdotInt8Kernel() override = default;
  int Run() override;

 public:
  int MatMulDynamicRunArm64Sdot();
  int MatMulDynamicArm64SdotPre(int task_id);
  int MatMulDynamicArm64SdotImpl(int task_id);

 private:
  void InitParameter() override;
};
}  // namespace mindspore::kernel

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MATMUL_DYNAMIC_SDOT_INT8_H_

@ -86,17 +86,19 @@ int PreprocessParser::ParsePreprocess(const DataPreProcessString &data_pre_proce
      preprocess::ConvertColorConversionCodes(data_pre_process->image_pre_process.image_to_format);
    }
  }
  ret = ParseImagePreProcess(data_pre_process_str, &data_pre_process->image_pre_process);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "image preprocess parse failed.";
    return ret;
  }
  if (!data_pre_process_str.calibrate_path.empty() && !data_pre_process_str.calibrate_size.empty()) {
    ret = ParseImagePreProcess(data_pre_process_str, &data_pre_process->image_pre_process);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "image preprocess parse failed.";
      return ret;
    }

  ret = CollectCalibInputs(data_pre_process->calibrate_path, data_pre_process->calibrate_size,
                           &data_pre_process->calibrate_path_vector);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "collect calibrate inputs failed.";
    return ret;
    ret = CollectCalibInputs(data_pre_process->calibrate_path, data_pre_process->calibrate_size,
                             &data_pre_process->calibrate_path_vector);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "collect calibrate inputs failed.";
      return ret;
    }
  }
  return RET_OK;
}

@ -202,8 +204,13 @@ int PreprocessParser::CollectCalibInputs(const std::map<std::string, std::string
      MS_LOG(ERROR) << " close dir failed.";
      return RET_ERROR;
    }
    auto &cur_inputs = inputs->at(image_path.first);
    std::sort(cur_inputs.begin(), cur_inputs.end());
    if (inputs->find(image_path.first) != inputs->end()) {
      auto &cur_inputs = inputs->at(image_path.first);
      std::sort(cur_inputs.begin(), cur_inputs.end());
    } else {
      MS_LOG(ERROR) << "can't find " << image_path.first << " at input maps.";
      return RET_ERROR;
    }
    if (count != limited_count) {
      MS_LOG(ERROR) << " data path: " << image_path.second << " data count:" << count
                    << " < limited_count:" << limited_count;
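The hunk above replaces an unchecked std::map::at() with a find()-first lookup, so a missing input name becomes a reported error instead of a std::out_of_range throw. A minimal standalone sketch of the same pattern (not part of the patch; names are illustrative):

// Minimal sketch (not from the patch) of lookup-before-use on a std::map.
#include <algorithm>
#include <cstdio>
#include <map>
#include <string>
#include <vector>

bool SortCalibInputs(std::map<std::string, std::vector<std::string>> *inputs, const std::string &name) {
  auto it = inputs->find(name);
  if (it == inputs->end()) {
    std::fprintf(stderr, "can't find %s at input maps.\n", name.c_str());
    return false;  // report instead of throwing via at()
  }
  std::sort(it->second.begin(), it->second.end());
  return true;
}

int main() {
  std::map<std::string, std::vector<std::string>> inputs{{"in0", {"b.bin", "a.bin"}}};
  return SortCalibInputs(&inputs, "in0") && !SortCalibInputs(&inputs, "missing") ? 0 : 1;
}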

@ -177,9 +177,9 @@ int DebugInfoManager::SetOriginStaticInfo(QuantDebugInfo *quant_debug_info, cons
  }
  quant_debug_info->clip = 0;

  CHECK_NULL_RETURN(tensor.data());
  quant_debug_info->tensor_data.data = malloc(tensor.Size());
  CHECK_MALLOC_RES(quant_debug_info->tensor_data.data, RET_NULL_PTR);
  CHECK_NULL_RETURN(tensor.data());
  auto ret = memcpy_s(quant_debug_info->tensor_data.data, tensor.Size(), tensor.data(), tensor.Size());
  if (ret != EOK) {
    MS_LOG(ERROR) << "memcpy memory failed.";

@ -29,7 +29,6 @@
#include "schema/inner/model_generated.h"
#include "src/lite_session.h"
#include "tools/converter/quantizer/quantizer.h"
#include "tools/converter/converter.h"
#include "include/ms_tensor.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/converter/quantizer/quant_params.h"

@ -82,26 +82,22 @@ int DoWeightQuant(const FuncGraphPtr &old_graph, const converter::Flags *config)
    auto quantizer = std::make_unique<WeightQuantizer>(*config);
    if (quantizer == nullptr) {
      MS_LOG(ERROR) << "New WeightQuantizer failed";
      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
      return RET_ERROR;
    }
    status = static_cast<WeightQuantizer *>(quantizer.get())->DoQuantize(old_graph, init_scale);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "DoQuantization failed " << status;
      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
      return RET_ERROR;
    }
  } else {
    auto quantizer = std::make_unique<WeightQuantizer>(*config);
    if (quantizer == nullptr) {
      MS_LOG(ERROR) << "New WeightQuantizer failed";
      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
      return RET_ERROR;
    }
    auto status = quantizer->DoQuantize(old_graph);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "DoQuantization failed " << status;
      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
      return RET_ERROR;
    }
  }

@ -112,13 +108,11 @@ int DoDynamicQuant(const FuncGraphPtr &old_graph, const converter::Flags *config
  auto quantizer = std::make_unique<DynamicQuantizer>(*config);
  if (quantizer == nullptr) {
    MS_LOG(ERROR) << "New DynamicQuantizer failed";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
    return RET_ERROR;
  }
  auto status = quantizer->DoQuantize(old_graph);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "DoQuantization failed " << status;
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
    return RET_ERROR;
  }
  return RET_OK;

@ -119,14 +119,15 @@ int WeightQuantizer::DoCNodeWeightQuant(const FuncGraphPtr &func_graph, const CN
    }
    auto status = RET_ERROR;
    if (is_mixed_bit_) {
      status = MixedBitQuantFilter(parameter, tensor_info, primitive, QuantType_QUANT_WEIGHT,
      status = MixedBitQuantFilter(parameter, tensor_info, primitive, flags_.commonQuantParam.quant_type,
                                   WeightQuantType::MIXED_BIT_PER_LAYER, type_id_, mixed_bit_init_scale_, idx - 1);
    } else if (type_id_ == kNumberTypeInt8) {
      status = FixedBitQuantFilter<int8_t>(parameter, tensor_info, primitive, QuantType_QUANT_WEIGHT, q_max, q_min,
                                           bit_num_, tmp_weight_quant_type, type_id_, idx - 1, symmetric);
      status = FixedBitQuantFilter<int8_t>(parameter, tensor_info, primitive, flags_.commonQuantParam.quant_type, q_max,
                                           q_min, bit_num_, tmp_weight_quant_type, type_id_, idx - 1, symmetric);
    } else if (type_id_ == kNumberTypeInt16) {
      status = FixedBitQuantFilter<int16_t>(parameter, tensor_info, primitive, QuantType_QUANT_WEIGHT, q_max, q_min,
                                            bit_num_, tmp_weight_quant_type, type_id_, idx - 1, symmetric);
      status =
        FixedBitQuantFilter<int16_t>(parameter, tensor_info, primitive, flags_.commonQuantParam.quant_type, q_max,
                                     q_min, bit_num_, tmp_weight_quant_type, type_id_, idx - 1, symmetric);
    }
    if (status == RET_NO_CHANGE) {
      continue;