!29757 dynamic matmul optimize with sdot

Merge pull request !29757 from yeyunpeng2020/dynamic_quant_success
i-robot 2022-02-08 11:25:39 +00:00 committed by Gitee
commit 11372f2a2a
18 changed files with 1474 additions and 334 deletions

View File

@ -0,0 +1,487 @@
#ifdef ENABLE_ARM64
#include "nnacl/assembly_global.h"
.text
.align 5
// void DynamicMatmulSdot4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep, float *multi_scales,
// float *bias, size_t row, size_t col, size_t stride);
// x0: a(left matrix ptr)
// x1: b(right matrix ptr)
// x2: out ptr
// x3: deep
// x4: multi_scales
// x5: bias
// x6: row
// x7: col
// x8: stride
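// Computes one 4x16 fp32 output tile: out[r][c] = multi_scales[c] * sum_d(a[r][d] * b[d][c]) (+ bias[c];
// bias may be NULL). a is packed row4x4-major, b is packed row4x16-major, and deep is a multiple of 4.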
asm_function DynamicMatmulSdot4x4x16AIWI
ldr x8, [sp]
dup v16.4s, wzr // dup: duplicate general-purpose register to vector.
dup v17.4s, wzr
dup v18.4s, wzr
dup v19.4s, wzr
dup v20.4s, wzr
dup v21.4s, wzr
dup v22.4s, wzr
dup v23.4s, wzr
dup v24.4s, wzr
dup v25.4s, wzr
dup v26.4s, wzr
dup v27.4s, wzr
dup v28.4s, wzr
dup v29.4s, wzr
dup v30.4s, wzr
dup v31.4s, wzr
mov x18, x1 // reload rhs ptr
mov x17, x0 // reload lhs ptr
mov x16, x3 // reload depth
cmp x7, #4
ble LoopDepthQuarter
cmp x7, #8
ble LoopDepthHalf
LoopDepth:
ld1 {v0.16b}, [x17], #16
ld1 {v1.16b, v2.16b, v3.16b, v4.16b}, [x18], #64
sdot v16.4s, v1.16b, v0.4b[0]
sdot v17.4s, v2.16b, v0.4b[0]
sdot v18.4s, v3.16b, v0.4b[0]
sdot v19.4s, v4.16b, v0.4b[0]
sdot v20.4s, v1.16b, v0.4b[1]
sdot v21.4s, v2.16b, v0.4b[1]
sdot v22.4s, v3.16b, v0.4b[1]
sdot v23.4s, v4.16b, v0.4b[1]
sdot v24.4s, v1.16b, v0.4b[2]
sdot v25.4s, v2.16b, v0.4b[2]
sdot v26.4s, v3.16b, v0.4b[2]
sdot v27.4s, v4.16b, v0.4b[2]
sdot v28.4s, v1.16b, v0.4b[3]
sdot v29.4s, v2.16b, v0.4b[3]
sdot v30.4s, v3.16b, v0.4b[3]
sdot v31.4s, v4.16b, v0.4b[3]
subs x16, x16, #4
bgt LoopDepth
b Convert2Float
LoopDepthHalf:
ld1 {v0.16b}, [x17], #16
ld1 {v1.16b, v2.16b}, [x18]
add x18, x18, #64
sdot v16.4s, v1.16b, v0.4b[0]
sdot v17.4s, v2.16b, v0.4b[0]
sdot v20.4s, v1.16b, v0.4b[1]
sdot v21.4s, v2.16b, v0.4b[1]
sdot v24.4s, v1.16b, v0.4b[2]
sdot v25.4s, v2.16b, v0.4b[2]
sdot v28.4s, v1.16b, v0.4b[3]
sdot v29.4s, v2.16b, v0.4b[3]
subs x16, x16, #4
bgt LoopDepthHalf
b Convert2Float
LoopDepthQuarter:
ld1 {v0.16b}, [x17], #16
ld1 {v1.16b}, [x18]
add x18, x18, #64
sdot v16.4s, v1.16b, v0.4b[0]
sdot v20.4s, v1.16b, v0.4b[1]
sdot v24.4s, v1.16b, v0.4b[2]
sdot v28.4s, v1.16b, v0.4b[3]
subs x16, x16, #4
bgt LoopDepthQuarter
b Convert2Float
Convert2Float:
scvtf v16.4s, v16.4s
scvtf v17.4s, v17.4s
scvtf v18.4s, v18.4s
scvtf v19.4s, v19.4s
scvtf v20.4s, v20.4s
scvtf v21.4s, v21.4s
scvtf v22.4s, v22.4s
scvtf v23.4s, v23.4s
scvtf v24.4s, v24.4s
scvtf v25.4s, v25.4s
scvtf v26.4s, v26.4s
scvtf v27.4s, v27.4s
scvtf v28.4s, v28.4s
scvtf v29.4s, v29.4s
scvtf v30.4s, v30.4s
scvtf v31.4s, v31.4s
MultiplyScale:
// multi_scale * input_matrix
ld1 {v1.4s, v2.4s, v3.4s, v4.4s}, [x4]
fmul v16.4s,v16.4s,v1.4s
fmul v17.4s,v17.4s,v2.4s
fmul v18.4s,v18.4s,v3.4s
fmul v19.4s,v19.4s,v4.4s
fmul v20.4s,v20.4s,v1.4s
fmul v21.4s,v21.4s,v2.4s
fmul v22.4s,v22.4s,v3.4s
fmul v23.4s,v23.4s,v4.4s
fmul v24.4s,v24.4s,v1.4s
fmul v25.4s,v25.4s,v2.4s
fmul v26.4s,v26.4s,v3.4s
fmul v27.4s,v27.4s,v4.4s
fmul v28.4s,v28.4s,v1.4s
fmul v29.4s,v29.4s,v2.4s
fmul v30.4s,v30.4s,v3.4s
fmul v31.4s,v31.4s,v4.4s
AddBias:
// +bias
cbz x5, StoreData
ld1 {v1.4s, v2.4s, v3.4s, v4.4s}, [x5]
fadd v16.4s,v16.4s,v1.4s
fadd v17.4s,v17.4s,v2.4s
fadd v18.4s,v18.4s,v3.4s
fadd v19.4s,v19.4s,v4.4s
fadd v20.4s,v20.4s,v1.4s
fadd v21.4s,v21.4s,v2.4s
fadd v22.4s,v22.4s,v3.4s
fadd v23.4s,v23.4s,v4.4s
fadd v24.4s,v24.4s,v1.4s
fadd v25.4s,v25.4s,v2.4s
fadd v26.4s,v26.4s,v3.4s
fadd v27.4s,v27.4s,v4.4s
fadd v28.4s,v28.4s,v1.4s
fadd v29.4s,v29.4s,v2.4s
fadd v30.4s,v30.4s,v3.4s
fadd v31.4s,v31.4s,v4.4s
StoreData:
cmp x7, #16
beq Write16
mov x15, x2 // reload out ptr
add x14, x15, x8
add x13, x14, x8
add x12, x13, x8
cmp x7, #15
beq Write15
cmp x7, #14
beq Write14
cmp x7, #13
beq Write13
cmp x7, #12
beq Write12
cmp x7, #11
beq Write11
cmp x7, #10
beq Write10
cmp x7, #9
beq Write9
cmp x7, #8
beq Write8
cmp x7, #7
beq Write7
cmp x7, #6
beq Write6
cmp x7, #5
beq Write5
cmp x7, #4
beq Write4
cmp x7, #3
beq Write3
cmp x7, #2
beq Write2
cmp x7, #1
beq Write1
b StoreDataEnd
Write16:
cmp x6, #4
beq Write16Row4
cmp x6, #3
beq Write16Row3
cmp x6, #2
beq Write16Row2
cmp x6, #1
beq Write16Row1
Write16Row4:
st1 {v16.4s,v17.4s,v18.4s,v19.4s}, [x2], x8
st1 {v20.4s,v21.4s,v22.4s,v23.4s}, [x2], x8
st1 {v24.4s,v25.4s,v26.4s,v27.4s}, [x2], x8
st1 {v28.4s,v29.4s,v30.4s,v31.4s}, [x2]
b StoreDataEnd
Write16Row3:
st1 {v16.4s,v17.4s,v18.4s,v19.4s}, [x2], x8
st1 {v20.4s,v21.4s,v22.4s,v23.4s}, [x2], x8
st1 {v24.4s,v25.4s,v26.4s,v27.4s}, [x2]
b StoreDataEnd
Write16Row2:
st1 {v16.4s,v17.4s,v18.4s,v19.4s}, [x2], x8
st1 {v20.4s,v21.4s,v22.4s,v23.4s}, [x2]
b StoreDataEnd
Write16Row1:
st1 {v16.4s,v17.4s,v18.4s,v19.4s}, [x2]
b StoreDataEnd
Write15:
st1 {v16.4s,v17.4s,v18.4s}, [x15], #48
st1 {v19.1d}, [x15], #8
st1 {v19.s}[2], [x15]
cmp x6, #1
beq StoreDataEnd
st1 {v20.4s,v21.4s,v22.4s}, [x14], #48
st1 {v23.1d}, [x14], #8
st1 {v23.s}[2], [x14]
cmp x6, #2
beq StoreDataEnd
st1 {v24.4s,v25.4s,v26.4s}, [x13], #48
st1 {v27.1d}, [x13], #8
st1 {v27.s}[2], [x13]
cmp x6, #3
beq StoreDataEnd
st1 {v28.4s,v29.4s,v30.4s}, [x12], #48
st1 {v31.1d}, [x12], #8
st1 {v31.s}[2], [x12]
b StoreDataEnd
Write14:
st1 {v16.4s,v17.4s,v18.4s}, [x15], #48
st1 {v19.1d}, [x15]
cmp x6, #1
beq StoreDataEnd
st1 {v20.4s,v21.4s,v22.4s}, [x14], #48
st1 {v23.1d}, [x14]
cmp x6, #2
beq StoreDataEnd
st1 {v24.4s,v25.4s,v26.4s}, [x13], #48
st1 {v27.1d}, [x13]
cmp x6, #3
beq StoreDataEnd
st1 {v28.4s,v29.4s,v30.4s}, [x12], #48
st1 {v31.1d}, [x12]
b StoreDataEnd
Write13:
st1 {v16.4s,v17.4s,v18.4s}, [x15], #48
st1 {v19.s}[0], [x15]
cmp x6, #1
beq StoreDataEnd
st1 {v20.4s,v21.4s,v22.4s}, [x14], #48
st1 {v23.s}[0], [x14]
cmp x6, #2
beq StoreDataEnd
st1 {v24.4s,v25.4s,v26.4s}, [x13], #48
st1 {v27.s}[0], [x13]
cmp x6, #3
beq StoreDataEnd
st1 {v28.4s,v29.4s,v30.4s}, [x12], #48
st1 {v31.s}[0], [x12]
b StoreDataEnd
Write12:
st1 {v16.4s,v17.4s,v18.4s}, [x15], #48
cmp x6, #1
beq StoreDataEnd
st1 {v20.4s,v21.4s,v22.4s}, [x14], #48
cmp x6, #2
beq StoreDataEnd
st1 {v24.4s,v25.4s,v26.4s}, [x13], #48
cmp x6, #3
beq StoreDataEnd
st1 {v28.4s,v29.4s,v30.4s}, [x12], #48
b StoreDataEnd
Write11:
st1 {v16.4s,v17.4s}, [x15], #32
st1 {v18.1d}, [x15], #8
st1 {v18.s}[2], [x15]
cmp x6, #1
beq StoreDataEnd
st1 {v20.4s,v21.4s}, [x14], #32
st1 {v22.1d}, [x14], #8
st1 {v22.s}[2], [x14]
cmp x6, #2
beq StoreDataEnd
st1 {v24.4s,v25.4s}, [x13], #32
st1 {v26.1d}, [x13], #8
st1 {v26.s}[2], [x13]
cmp x6, #3
beq StoreDataEnd
st1 {v28.4s,v29.4s}, [x12], #32
st1 {v30.1d}, [x12], #8
st1 {v30.s}[2], [x12]
b StoreDataEnd
Write10:
st1 {v16.4s,v17.4s}, [x15], #32
st1 {v18.1d}, [x15]
cmp x6, #1
beq StoreDataEnd
st1 {v20.4s,v21.4s}, [x14], #32
st1 {v22.1d}, [x14]
cmp x6, #2
beq StoreDataEnd
st1 {v24.4s,v25.4s}, [x13], #32
st1 {v26.1d}, [x13]
cmp x6, #3
beq StoreDataEnd
st1 {v28.4s,v29.4s}, [x12], #32
st1 {v30.1d}, [x12]
b StoreDataEnd
Write9:
st1 {v16.4s,v17.4s}, [x15], #32
st1 {v18.s}[0], [x15]
cmp x6, #1
beq StoreDataEnd
st1 {v20.4s,v21.4s}, [x14], #32
st1 {v22.s}[0], [x14]
cmp x6, #2
beq StoreDataEnd
st1 {v24.4s,v25.4s}, [x13], #32
st1 {v26.s}[0], [x13]
cmp x6, #3
beq StoreDataEnd
st1 {v28.4s,v29.4s}, [x12], #32
st1 {v30.s}[0], [x12]
b StoreDataEnd
Write8:
st1 {v16.4s,v17.4s}, [x15], #32
cmp x6, #1
beq StoreDataEnd
st1 {v20.4s,v21.4s}, [x14], #32
cmp x6, #2
beq StoreDataEnd
st1 {v24.4s,v25.4s}, [x13], #32
cmp x6, #3
beq StoreDataEnd
st1 {v28.4s,v29.4s}, [x12], #32
b StoreDataEnd
Write7:
st1 {v16.4s}, [x15], #16
st1 {v17.1d}, [x15], #8
st1 {v17.s}[2], [x15]
cmp x6, #1
beq StoreDataEnd
st1 {v20.4s}, [x14], #16
st1 {v21.1d}, [x14], #8
st1 {v21.s}[2], [x14]
cmp x6, #2
beq StoreDataEnd
st1 {v24.4s}, [x13], #16
st1 {v25.1d}, [x13], #8
st1 {v25.s}[2], [x13]
cmp x6, #3
beq StoreDataEnd
st1 {v28.4s}, [x12], #16
st1 {v29.1d}, [x12], #8
st1 {v29.s}[2], [x12]
b StoreDataEnd
Write6:
st1 {v16.4s}, [x15], #16
st1 {v17.1d}, [x15]
cmp x6, #1
beq StoreDataEnd
st1 {v20.4s}, [x14], #16
st1 {v21.1d}, [x14]
cmp x6, #2
beq StoreDataEnd
st1 {v24.4s}, [x13], #16
st1 {v25.1d}, [x13]
cmp x6, #3
beq StoreDataEnd
st1 {v28.4s}, [x12], #16
st1 {v29.1d}, [x12]
b StoreDataEnd
Write5:
st1 {v16.4s}, [x15], #16
st1 {v17.s}[0], [x15]
cmp x6, #1
beq StoreDataEnd
st1 {v20.4s}, [x14], #16
st1 {v21.s}[0], [x14]
cmp x6, #2
beq StoreDataEnd
st1 {v24.4s}, [x13], #16
st1 {v25.s}[0], [x13]
cmp x6, #3
beq StoreDataEnd
st1 {v28.4s}, [x12], #16
st1 {v29.s}[0], [x12]
b StoreDataEnd
Write4:
st1 {v16.4s}, [x15]
cmp x6, #1
beq StoreDataEnd
st1 {v20.4s}, [x14]
cmp x6, #2
beq StoreDataEnd
st1 {v24.4s}, [x13]
cmp x6, #3
beq StoreDataEnd
st1 {v28.4s}, [x12]
b StoreDataEnd
Write3:
st1 {v16.1d}, [x15]
st1 {v16.s}[2], [x15]
cmp x6, #1
beq StoreDataEnd
st1 {v20.1d}, [x14]
st1 {v20.s}[2], [x14]
cmp x6, #2
beq StoreDataEnd
st1 {v24.1d}, [x13]
st1 {v24.s}[2], [x13]
cmp x6, #3
beq StoreDataEnd
st1 {v28.1d}, [x12]
st1 {v28.s}[2], [x12]
b StoreDataEnd
Write2:
st1 {v16.1d}, [x15]
cmp x6, #1
beq StoreDataEnd
st1 {v20.1d}, [x14]
cmp x6, #2
beq StoreDataEnd
st1 {v24.1d}, [x13]
cmp x6, #3
beq StoreDataEnd
st1 {v28.1d}, [x12]
b StoreDataEnd
Write1:
st1 {v16.s}[0], [x15]
cmp x6, #1
beq StoreDataEnd
st1 {v20.s}[0], [x14]
cmp x6, #2
beq StoreDataEnd
st1 {v24.s}[0], [x13]
cmp x6, #3
beq StoreDataEnd
st1 {v28.s}[0], [x12]
b StoreDataEnd
StoreDataEnd:
ret
#endif
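
For reference, a minimal wrapper showing how the kernel above is invoked per tile, modeled on the call site in matmul_dynamic_sdot_int8.cc further down; the wrapper name and buffers are illustrative:

#ifdef ENABLE_ARM64
// a_tile: 4 x deep4 lhs, packed row4x4-major; b_tile: deep4 x 16 rhs, packed row4x16-major.
// scales: 16 per-output-column multipliers (input_scale * filter_scale); bias may be NULL.
// row/col give the valid extent of this tile (<= 4 / <= 16); the last argument is the output row pitch in bytes.
static void RunOneTile(const int8_t *a_tile, const int8_t *b_tile, float *out_tile, size_t deep4, float *scales,
float *bias, size_t row, size_t col, size_t total_cols) {
DynamicMatmulSdot4x4x16AIWI(a_tile, b_tile, out_tile, deep4, scales, bias, row, col, total_cols * sizeof(float));
}
#endif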

View File

@ -0,0 +1,313 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/int8/dynamic_matmul_int8.h"
#include "nnacl/int8/fixed_point.h"
void DynamicMatmul4x4x16AIWI(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col,
int deep4, size_t stride, float input_scale, const float *filter_scale,
bool filter_per_channel) {
/* *
* row4x4-major * row4x16-major => (float)row-major
* support activation per-layer symmetric && weight per-layer/per-channel symmetric
* */
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int r4div = r / C4NUM, r4mod = r % C4NUM;
int c16div = c / C16NUM, c16mod = c % C16NUM;
int32_t value = 0;
for (int d = 0; d < deep4; d++) {
int d4div = d / C4NUM, d4mod = d % C4NUM;
size_t ai = r4div * deep4 * C4NUM + d4div * C4NUM * C4NUM + r4mod * C4NUM + d4mod;
size_t bi = c16div * deep4 * C16NUM + d4div * C4NUM * C16NUM + c16mod * C4NUM + d4mod;
value += a[ai] * b[bi];
}
int filter_quant_index = filter_per_channel ? c : 0;
double multi_scale = input_scale * filter_scale[filter_quant_index];
size_t ci = r * stride + c;
dst[ci] = multi_scale * value;
if (bias != NULL) {
dst[ci] += bias[c];
}
}
}
return;
}
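/* Index spot-check for the packing above (illustrative numbers): with deep4 = 8, lhs element (r = 5, d = 6)
* has r4div = 1, r4mod = 1, d4div = 1, d4mod = 2, so ai = 1*8*4 + 1*4*4 + 1*4 + 2 = 54; rhs element
* (c = 17, d = 6) has c16div = 1, c16mod = 1, so bi = 1*8*16 + 1*4*16 + 1*4 + 2 = 198. */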
void DynamicMatmul4x16x4AIWI(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col,
int deep, int deep16, size_t stride, int input_zp, float input_scale,
const float *filter_scale, const int filter_zp, bool filter_per_channel) {
/* *
* row4x16-major * row16x4-major => (float)row-major
* support activation per-layer symmetric && weight per-layer/per-channel symmetric
* */
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int r4div = r / C4NUM, r4mod = r % C4NUM;
int c4div = c / C4NUM, c4mod = c % C4NUM;
int32_t value = 0;
for (int d = 0; d < deep; d++) {
int d16div = d / C16NUM, d16mod = d % C16NUM;
size_t ai = r4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod;
size_t bi = c4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod;
value += (a[ai] - input_zp) * (b[bi] - filter_zp);
}
int filter_quant_index = filter_per_channel ? c : 0;
double multi_scale = input_scale * filter_scale[filter_quant_index];
size_t ci = r * stride + c;
dst[ci] = multi_scale * value;
if (bias != NULL) {
dst[ci] += bias[c];
}
}
}
return;
}
#ifdef ENABLE_ARM64
void PackInput4x4Asm(const int8_t *src_ic, int8_t *pack_ic, size_t ic_4div, size_t input_channel) {
size_t src_stride = input_channel;
size_t ic_4res = input_channel - ic_4div;
asm volatile(
"dup v2.4s, wzr \n"
"mov x10, %[src_ic] \n"
"mov x11, %[pack_ic] \n"
"mov x15, #0 \n"
"1: \n"
"cmp x15, %[ic_4div] \n"
"add x15, x15, #4\n"
"mov x12, x10 \n"
"add x10, x10, #4\n"
"blt 2f \n"
"cmp %[ic_4res], #0\n"
"beq 6f \n"
"cmp %[ic_4res], #1\n"
"beq 3f \n"
"cmp %[ic_4res], #2\n"
"beq 4f \n"
"cmp %[ic_4res], #3\n"
"beq 5f \n"
"2: \n"
"ld1 {v0.s}[0], [x12], %[src_stride]\n"
"ld1 {v0.s}[1], [x12], %[src_stride]\n"
"ld1 {v0.s}[2], [x12], %[src_stride]\n"
"ld1 {v0.s}[3], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"b 1b \n"
"3: \n" /* ic res 1 */
"dup v0.4s, wzr \n"
"ld1 {v0.b}[0], [x12], %[src_stride]\n"
"ld1 {v0.b}[4], [x12], %[src_stride]\n"
"ld1 {v0.b}[8], [x12], %[src_stride]\n"
"ld1 {v0.b}[12], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"b 6f \n"
"4: \n" /* ic res 2 */
"dup v0.4s, wzr \n"
"ld1 {v0.h}[0], [x12], %[src_stride]\n"
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"b 6f \n"
"5: \n" /* ic res 3 */
"dup v0.4s, wzr \n"
"add x13, x12, #2 \n"
"ld1 {v0.h}[0], [x12], %[src_stride]\n"
"ld1 {v0.b}[2], [x13], %[src_stride]\n"
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
"ld1 {v0.b}[6], [x13], %[src_stride]\n"
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
"ld1 {v0.b}[10], [x13], %[src_stride]\n"
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
"ld1 {v0.b}[14], [x13], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"b 6f \n"
"6: \n"
:
: [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div),
[ ic_4res ] "r"(ic_4res)
: "x10", "x11", "x12", "x13", "x14", "x15", "v0", "v1", "v2", "v3");
}
#endif
void PackInput4x4(const int8_t *src_input, int8_t *packed_input, size_t input_channel, size_t plane_size) {
int ic4 = UP_ROUND(input_channel, C4NUM);
size_t hw_4div = plane_size / C4NUM * C4NUM;
size_t ic_4div = input_channel / C4NUM * C4NUM;
const int8_t *src_r = src_input;
int8_t *pack_r = packed_input;
/* per layer */
for (int hwi = 0; hwi < hw_4div; hwi += C4NUM) {
const int8_t *src_ic = src_r;
int8_t *pack_ic = pack_r;
#ifdef ENABLE_ARM64
PackInput4x4Asm(src_ic, pack_ic, ic_4div, input_channel);
#else
for (int ici = 0; ici < ic_4div; ici += C4NUM) {
for (size_t i = 0; i < C4NUM; i++) {
pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel];
pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel];
pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel];
pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel];
}
src_ic += C4NUM;
pack_ic += C4NUM * C4NUM;
}
for (int ici = ic_4div; ici < input_channel; ici += 1) {
for (int i = 0; i < C4NUM; i++) {
pack_ic[i * C4NUM] = src_ic[i * input_channel];
}
src_ic += 1;
pack_ic += 1;
}
for (int ici = input_channel; ici < ic4; ici += 1) {
for (int i = 0; i < C4NUM; i++) {
pack_ic[i * C4NUM] = 0;
}
pack_ic += 1;
}
#endif
src_r += input_channel * C4NUM;
pack_r += ic4 * C4NUM;
}
if (hw_4div != plane_size) {
memset(pack_r, 0, C4NUM * ic4);
for (int hwi = hw_4div; hwi < plane_size; hwi += 1) {
const int8_t *src_ic = src_r;
int8_t *pack_ic = pack_r;
for (int ici = 0; ici < ic_4div; ici += C4NUM) {
pack_ic[0] = src_ic[0];
pack_ic[1] = src_ic[1];
pack_ic[2] = src_ic[2];
pack_ic[3] = src_ic[3];
src_ic += C4NUM;
pack_ic += C4NUM * C4NUM;
}
src_r += input_channel;
pack_r += C4NUM;
}
}
return;
}
// For matmul input a transpose case
void PackInput2Col4x4(const int8_t *src_input, int8_t *packed_input, int row, int col, int row_stride) {
const int row_tile = C4NUM;
int row_align = UP_ROUND(row, row_tile);
int row_div = row / row_tile * row_tile;
const int row_res = row - row_div;
const int col_tile = C4NUM;
int col_div = col / col_tile * col_tile;
const int col_res = col - col_div;
const int8_t *src_ic = NULL;
int8_t *packed_ic = NULL;
for (int c = 0; c < col_div; c += C4NUM) {
int r = 0;
src_ic = src_input + c;
packed_ic = packed_input + c * row_align;
#ifdef ENABLE_ARM64
size_t row_stride_int64 = row_stride;
asm volatile(
"mov w10, %w[row]\n"
"mov x11, %[src_ic]\n"
"mov x12, %[packed_ic]\n"
"cmp w10, wzr\n"
"beq 1f\n"
"2:\n"
"subs w10, w10, #4\n"
"ld1 {v0.s}[0], [x11], %[row_stride]\n"
"ld1 {v1.s}[0], [x11], %[row_stride]\n"
"ld1 {v0.s}[1], [x11], %[row_stride]\n"
"ld1 {v1.s}[1], [x11], %[row_stride]\n"
"zip1 v2.8b, v0.8b, v1.8b\n"
"zip2 v3.8b, v0.8b, v1.8b\n"
"zip1 v4.4h, v2.4h, v3.4h\n"
"zip2 v5.4h, v2.4h, v3.4h\n"
"st1 {v4.4h, v5.4h}, [x12], #16\n"
"bgt 2b\n"
"1:\n"
:
: [ src_ic ] "r"(src_ic), [ packed_ic ] "r"(packed_ic), [ row ] "r"(row_div), [ row_stride ] "r"(row_stride_int64)
: "memory", "w10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12");
packed_ic += C4NUM * row_div;
src_ic += row_div * row_stride;
#else
for (; r < row_div; r += C4NUM) {
for (int i = 0; i < row_tile; i++) {
packed_ic[0 * row_tile + i] = src_ic[i * row_stride + 0];
packed_ic[1 * row_tile + i] = src_ic[i * row_stride + 1];
packed_ic[2 * row_tile + i] = src_ic[i * row_stride + 2];
packed_ic[3 * row_tile + i] = src_ic[i * row_stride + 3];
}
packed_ic += C16NUM;
src_ic += row_tile * row_stride;
}
#endif
for (r = 0; r < row_res; ++r) {
for (int i = 0; i < C4NUM; ++i) {
packed_ic[i * row_tile + r] = src_ic[r * row_stride + i];
}
}
}
if (col_res == 0) {
return;
}
src_ic = src_input + col_div;
packed_ic = packed_input + row_align * col_div;
for (int r = 0; r < row_div; r += row_tile) {
for (int i = 0; i < col_res; ++i) {
packed_ic[i * row_tile + 0] = src_ic[0 * row_stride + i];
packed_ic[i * row_tile + 1] = src_ic[1 * row_stride + i];
packed_ic[i * row_tile + 2] = src_ic[2 * row_stride + i];
packed_ic[i * row_tile + 3] = src_ic[3 * row_stride + i];
}
src_ic += row_tile * row_stride;
packed_ic += row_tile * col_tile;
}
for (int r = 0; r < row_res; ++r) {
for (int c = 0; c < col_res; ++c) {
packed_ic[c * row_tile + r] = src_ic[r * row_stride + c];
}
}
}
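
A small harness sketch for PackInput4x4 above, with illustrative sizes (plane_size = 4, input_channel = 6, so ic4 = 8 and the last two packed channels are zero padding):

#include <stdint.h>
#include <stdio.h>
#include "nnacl/int8/dynamic_matmul_int8.h"
int main(void) {
int8_t src[4 * 6]; // 4 plane rows x 6 channels, row-major
for (int i = 0; i < 4 * 6; ++i) {
src[i] = (int8_t)i;
}
int8_t packed[4 * 8] = {0}; // 4 rows x UP_ROUND(6, 4) channels
PackInput4x4(src, packed, 6, 4);
for (int i = 0; i < 4 * 8; ++i) { // prints 4x4 channel blocks: channels 0-3, then 4-5 plus zeros
printf("%d%c", packed[i], (i % 4 == 3) ? '\n' : ' ');
}
return 0;
}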

View File

@ -0,0 +1,44 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_NNACL_INT8_DYNAMIC_MATMUL_H_
#define MINDSPORE_NNACL_INT8_DYNAMIC_MATMUL_H_
#include <string.h>
#include "nnacl/op_base.h"
#include "nnacl/matmul_parameter.h"
#ifdef __cplusplus
extern "C" {
#endif
void PackInput2Col4x4(const int8_t *src_input, int8_t *packed_input, int row, int col, int row_stride);
void PackInput4x4(const int8_t *src_input, int8_t *packed_input, size_t input_channel, size_t plane_size);
void DynamicMatmul4x16x4AIWI(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col,
int deep, int deep16, size_t stride, int input_zp, float input_scale,
const float *filter_scale, const int filter_zp, bool filter_per_channel);
#ifdef ENABLE_ARM64
void DynamicMatmulSdot4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4, float *multi_scales,
float *bias, size_t row, size_t col, size_t stride);
#else
void DynamicMatmul4x4x16AIWI(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col,
int deep4, size_t stride, float input_scale, const float *filter_scale,
bool filter_per_channel);
#endif
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_NNACL_INT8_DYNAMIC_MATMUL_H_

View File

@ -15,8 +15,8 @@
*/
#include "nnacl/int8/dynamic_quant_int8.h"
void CalculateMinMaxFp32(const float *data, int count, float *real_min, float *real_max) {
#ifndef PLATFORM_ARM64
for (int i = 0; i < count; ++i) {
if (data[i] < *real_min) {
*real_min = data[i];
@ -25,4 +25,40 @@ void CalculateMinMaxFp32(const float *data, int count, float *real_min, float *r
*real_max = data[i];
}
}
#else
int count_4 = DOWN_ROUND(count, C4NUM);
asm volatile(
"mov x4, %[data]\n" // reload data
"mov w5, %w[count_4]\n" // reload count
"ld1 {v31.4s}, [x4]\n" // min
"ld1 {v30.4s}, [x4], #16\n" // max
"subs w5, w5, #4\n"
"ble MinMax\n"
"LoopCount:\n"
"ld1 {v0.4s}, [x4], #16\n"
"fmin v31.4s, v31.4s, v0.4s\n"
"fmax v30.4s, v30.4s, v0.4s\\nn"
"subs w5, w5, #4\n"
"bgt LoopCount\n"
"MinMax:\n"
"fminv s6, v31.4s\n"
"fmaxv s7, v30.4s\n"
"str s6, %[real_min]\n"
"str s7, %[real_max]\n"
:
: [ data ] "r"(data), [ count_4 ] "r"(count_4), [ real_min ] "r"(real_min), [ real_max ] "r"(real_max)
: "x4", "w5", "s6", "s7", "v0", "v30", "v31");
for (int i = count_4; i < count; ++i) {
if (data[i] < *real_min) {
*real_min = data[i];
}
if (data[i] > *real_max) {
*real_max = data[i];
}
}
#endif
}
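
The assembly above keeps four running lanes of min/max and reduces them with fminv/fmaxv; an equivalent NEON-intrinsics sketch of the same reduction (illustrative, and assuming count >= 4 like the asm path):

#include <arm_neon.h>
void MinMaxFp32Neon(const float *data, int count, float *real_min, float *real_max) {
float32x4_t vmin = vld1q_f32(data); // both lanes start from the first 4 elements
float32x4_t vmax = vmin;
int count_4 = count / 4 * 4;
for (int i = 4; i < count_4; i += 4) {
float32x4_t v = vld1q_f32(data + i);
vmin = vminq_f32(vmin, v);
vmax = vmaxq_f32(vmax, v);
}
float mn = vminvq_f32(vmin); // horizontal reductions, like fminv/fmaxv
float mx = vmaxvq_f32(vmax);
for (int i = count_4; i < count; ++i) { // scalar tail
mn = data[i] < mn ? data[i] : mn;
mx = data[i] > mx ? data[i] : mx;
}
*real_min = mn;
*real_max = mx;
}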

View File

@ -330,38 +330,6 @@ void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int c
return;
}
#endif
void DynamicMatmulInt8AIWI(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col,
int deep16, float input_scale, const float *filter_scale, size_t stride,
bool filter_per_channel) {
/* *
* row4x16-major * row16x4-major => (int8)row-major
* support activation per-layer symmetric && weight per-layer/per-channel symmetric
* */
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int r4div = r / C4NUM, r4mod = r % C4NUM;
int c4div = c / C4NUM, c4mod = c % C4NUM;
int filter_quant_index = filter_per_channel ? c : 0;
double multi_scale = input_scale * filter_scale[filter_quant_index];
double value = 0;
for (int d = 0; d < deep16; d++) {
int d16div = d / C16NUM, d16mod = d % C16NUM;
size_t ai = r4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod;
size_t bi = c4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod;
int32_t value_1 = a[ai] * b[bi];
value += multi_scale * value_1;
}
if (bias != NULL) {
value += bias[c];
}
size_t ci = r * stride + c;
dst[ci] = value;
}
}
return;
}
void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
size_t stride, const int32_t *input_sum, const int32_t *bias, const int32_t *left_shift,
const int32_t *right_shift, const int32_t *multiplier, int32_t output_zp, int32_t mini,

View File

@ -45,11 +45,6 @@ void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int c
const int *bias, int act_min, int act_max, int out_zp, const int32_t *multiplier,
const int32_t *left_shift, const int32_t *right_shift, size_t stride, size_t filter_peroc,
const int32_t *filter_zp);
void DynamicMatmulInt8AIWI(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col,
int deep16, float input_scale, const float *filter_scale, size_t stride,
bool filter_per_channel);
/* 8x4 4x8 -> 8x8 */
/* optimize conv */
void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);

View File

@ -87,7 +87,7 @@ int DynamicQuantCPUKernel::CalculateMinMax(int task_id) {
return RET_OK;
}
int CalculateMinMaxRun(void *cdata, int task_id, float, float) {
int CalculateMinMaxRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
CHECK_NULL_RETURN(cdata);
auto g_kernel = reinterpret_cast<DynamicQuantCPUKernel *>(cdata);
auto ret = g_kernel->CalculateMinMax(task_id);
@ -112,10 +112,16 @@ void DynamicQuantCPUKernel::ReduceMinMaxFp32() {
void DynamicQuantCPUKernel::CalculateScaleZp() {
lite::LiteQuantParam quant_parm;
double scale = (real_max_ - real_min_) / (INT8_MAX - INT8_MIN);
double scale;
int zp = 0;
constexpr int kQAsymmetricRange = 255;
constexpr int kQSymmetricRange = 254;
if (!symmetric_) {
scale = (real_max_ - real_min_) / kQAsymmetricRange; // -128 ~ 127
zp = static_cast<int>(std::round(INT8_MIN - real_min_ / scale));
} else {
auto max = std::max(std::abs(real_max_), std::abs(real_min_));
scale = 2 * max / kQSymmetricRange; // -127 ~ 127
}
quant_parm.scale = scale;
quant_parm.zeroPoint = zp;
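
Worked numbers for the two branches (illustrative): with real_min_ = -1.0 and real_max_ = 3.0, the asymmetric path gives scale = 4.0 / 255 ≈ 0.0157 and zp = round(-128 + 1.0 / 0.0157) = -64, while the symmetric path gives scale = 2 * 3.0 / 254 ≈ 0.0236 and leaves zp at 0.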

View File

@ -0,0 +1,240 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/int8/matmul_dynamic_base_int8.h"
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_MEMORY_FAILED;
using mindspore::lite::RET_OK;
namespace mindspore::kernel {
namespace {
constexpr int kHasBiasSize = 3;
constexpr int kMinInputSize = 2;
constexpr int kOutputSize = 1;
constexpr int kSize1 = 1;
constexpr int kSize2 = 2;
} // namespace
MatmulDynamicBaseInt8CPUKernel::~MatmulDynamicBaseInt8CPUKernel() {
FreeQuantParam();
FreeTmpBuffer();
}
void MatmulDynamicBaseInt8CPUKernel::FreeQuantParam() {
if (quant_param_ != nullptr) {
if (quant_param_->filter_scale_ != nullptr) {
free(quant_param_->filter_scale_);
quant_param_->filter_scale_ = nullptr;
}
if (quant_param_->filter_zp_ != nullptr) {
free(quant_param_->filter_zp_);
quant_param_->filter_zp_ = nullptr;
}
free(quant_param_);
quant_param_ = nullptr;
}
}
int MatmulDynamicBaseInt8CPUKernel::MallocQuantParam() {
quant_param_ = reinterpret_cast<MatmulDynamicQuantParameter *>(malloc(sizeof(MatmulQuantParameter)));
if (quant_param_ == nullptr) {
MS_LOG(ERROR) << "Malloc MatmulDynamicQuantParameter for Matmul int8 op failed!";
return RET_ERROR;
}
memset(quant_param_, 0, sizeof(MatmulQuantParameter));
return RET_OK;
}
int MatmulDynamicBaseInt8CPUKernel::InitFilterQuantParam() {
if (quant_param_->filter_scale_ != nullptr) {
free(quant_param_->filter_scale_);
quant_param_->filter_scale_ = nullptr;
}
if (quant_param_->filter_zp_ != nullptr) {
free(quant_param_->filter_zp_);
quant_param_->filter_zp_ = nullptr;
}
auto weight_tensor = in_tensors_.at(kWeightIndex);
auto weight_quant_params = weight_tensor->quant_params();
auto w_shape = weight_tensor->shape();
if (w_shape.size() < DIMENSION_2D) {
MS_LOG(ERROR) << weight_tensor->tensor_name() << " dims < 2.";
return RET_ERROR;
}
int col = param_->b_transpose_ ? w_shape[w_shape.size() - kSize2] : w_shape[w_shape.size() - kSize1];
filter_per_channel_ = (weight_quant_params.size() > 1);
channel_num_ = filter_per_channel_ ? col : 1;
if (static_cast<int>(weight_quant_params.size()) != channel_num_) {
MS_LOG(ERROR) << weight_tensor->tensor_name() << " quant params size:" << weight_quant_params.size()
<< " != channel_num_:" << channel_num_;
return RET_ERROR;
}
quant_param_->filter_scale_ = reinterpret_cast<float *>(malloc(channel_num_ * sizeof(float)));
CHECK_NULL_RETURN(quant_param_->filter_scale_);
memset(quant_param_->filter_scale_, 0, sizeof(channel_num_));
quant_param_->filter_zp_ = reinterpret_cast<int32_t *>(malloc(channel_num_ * sizeof(int32_t)));
CHECK_NULL_RETURN(quant_param_->filter_zp_);
memset(quant_param_->filter_zp_, 0, sizeof(channel_num_));
for (int i = 0; i < channel_num_; i++) {
quant_param_->filter_scale_[i] = static_cast<float>(weight_quant_params[i].scale);
quant_param_->filter_zp_[i] = weight_quant_params[i].zeroPoint;
}
return RET_OK;
}
void MatmulDynamicBaseInt8CPUKernel::ResizeParameter() {
param_->row_align_ = UP_ROUND(param_->row_, row_tile_);
param_->col_align_ = UP_ROUND(param_->col_, col_tile_);
param_->deep_align_ = UP_ROUND(param_->deep_, deep_tile_);
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(param_->col_align_, col_tile_));
thread_stride_ = UP_DIV(UP_DIV(param_->col_align_, col_tile_), thread_count_);
return;
}
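// Illustrative sizing: with col_ = 50 and col_tile_ = 16 (the sdot kernel), col_align_ = 64, giving
// UP_DIV(64, 16) = 4 column blocks; thread_count_ then clamps to 4 and thread_stride_ = 1 block per thread.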
void MatmulDynamicBaseInt8CPUKernel::FreeTmpBuffer() {
if (pack_a_ptr_ != nullptr) {
free(pack_a_ptr_);
pack_a_ptr_ = nullptr;
}
if (pack_b_ptr_ != nullptr) {
free(pack_b_ptr_);
pack_b_ptr_ = nullptr;
}
return;
}
int MatmulDynamicBaseInt8CPUKernel::InitInputQuantParam() {
auto in_quant_params = in_tensors_.at(kInputIndex)->quant_params();
if (in_quant_params.empty()) {
MS_LOG(ERROR) << "invalid in quant param";
return RET_ERROR;
}
quant_param_->input_zp_ = in_quant_params.front().zeroPoint;
quant_param_->input_scale_ = static_cast<float>(in_quant_params.front().scale);
return RET_OK;
}
int MatmulDynamicBaseInt8CPUKernel::TransferB() {
auto weight_data = reinterpret_cast<int8_t *>(in_tensors_.at(kWeightIndex)->data());
CHECK_NULL_RETURN(weight_data);
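// Pad the packed weights with the per-layer filter zero-point so padded lanes cancel after (b[bi] - filter_zp).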
memset(pack_b_ptr_, quant_param_->filter_zp_[0],
param_->batch * param_->col_align_ * param_->deep_align_ * sizeof(int8_t));
for (int i = 0; i < param_->batch; i++) {
auto current_weight = weight_data + i * param_->deep_ * param_->col_;
auto current_b_pack = pack_b_ptr_ + i * param_->col_align_ * param_->deep_align_;
CHECK_NULL_RETURN(b_pack_func_);
if (param_->b_transpose_) {
b_pack_func_(current_weight, current_b_pack, param_->col_, param_->deep_);
} else {
b_pack_func_(current_weight, current_b_pack, param_->deep_, param_->col_);
}
}
return RET_OK;
}
int MatmulDynamicBaseInt8CPUKernel::InitTmpBuffer() {
pack_a_ptr_ = reinterpret_cast<int8_t *>(malloc(param_->row_align_ * param_->deep_align_ * sizeof(int8_t)));
if (pack_a_ptr_ == nullptr) {
FreeTmpBuffer();
return RET_ERROR;
}
pack_b_ptr_ =
reinterpret_cast<int8_t *>(malloc(param_->batch * param_->col_align_ * param_->deep_align_ * sizeof(int8_t)));
if (pack_b_ptr_ == nullptr) {
FreeTmpBuffer();
return RET_ERROR;
}
memset(pack_a_ptr_, 0, param_->row_align_ * param_->deep_align_ * sizeof(int8_t));
memset(pack_b_ptr_, 0, param_->batch * param_->col_align_ * param_->deep_align_ * sizeof(int8_t));
return RET_OK;
}
int MatmulDynamicBaseInt8CPUKernel::CopyBias() {
if (in_tensors_.size() == kHasBiasSize) {
auto bias_tensor = in_tensors_[kBiasIndex];
// bias tensor is fp32 and constant, so it is used in place; no copy is needed
fp32_bias_ptr_ = reinterpret_cast<float *>(bias_tensor->data());
if (fp32_bias_ptr_ == nullptr) {
MS_LOG(ERROR) << "bias tensor data is nullptr";
FreeTmpBuffer();
return RET_MEMORY_FAILED;
}
} else {
fp32_bias_ptr_ = nullptr;
}
return RET_OK;
}
int MatmulDynamicBaseInt8CPUKernel::Prepare() {
CHECK_LESS_RETURN(in_tensors_.size(), kMinInputSize);
CHECK_LESS_RETURN(out_tensors_.size(), kOutputSize);
InitParameter();
auto ret = MallocQuantParam();
if (ret != RET_OK) {
FreeQuantParam();
return ret;
}
if (param_->b_const_) {
ret = InitFilterQuantParam();
if (ret != RET_OK) {
FreeQuantParam();
return ret;
}
}
ret = CopyBias();
if (ret != RET_OK) {
FreeQuantParam();
return ret;
}
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int MatmulDynamicBaseInt8CPUKernel::ReSize() {
int batch = 1;
auto x_shape = in_tensors_.at(0)->shape();
auto o_shape = out_tensors_.at(0)->shape();
MS_ASSERT(x_shape.size() >= kSize2);
for (size_t i = 0; i < x_shape.size() - kSize2; ++i) {
batch *= x_shape[i];
}
param_->batch = batch;
MS_ASSERT(o_shape.size() >= kSize2);
param_->row_ = o_shape[o_shape.size() - kSize2];
param_->col_ = o_shape[o_shape.size() - kSize1];
param_->deep_ = param_->a_transpose_ ? x_shape[x_shape.size() - kSize2] : x_shape[x_shape.size() - kSize1];
FreeTmpBuffer();
ResizeParameter();
auto ret = InitTmpBuffer();
if (ret != RET_OK) {
FreeQuantParam();
return ret;
}
if (param_->b_const_ == true) {
TransferB();
}
return RET_OK;
}
} // namespace mindspore::kernel

View File

@ -0,0 +1,79 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MATMUL_DYNAMIC_BASE_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MATMUL_DYNAMIC_BASE_INT8_H_
#include <vector>
#include "include/errorcode.h"
#include "include/context.h"
#include "src/inner_kernel.h"
#include "nnacl/matmul_parameter.h"
#include "nnacl/common_func.h"
#include "nnacl/int8/quantize.h"
#include "nnacl/int8/common_func_int8.h"
namespace mindspore::kernel {
class MatmulDynamicBaseInt8CPUKernel : public InnerKernel {
public:
MatmulDynamicBaseInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: InnerKernel(parameter, inputs, outputs, ctx) {
param_ = reinterpret_cast<MatMulParameter *>(op_parameter_);
}
~MatmulDynamicBaseInt8CPUKernel() override;
int Prepare() override;
int ReSize() override;
private:
void ResizeParameter();
int CopyBias();
int InitTmpBuffer();
int MallocQuantParam();
protected:
typedef void (*PackFunc)(const int8_t *src, int8_t *dst, int row, int col);
virtual void InitParameter() = 0;
int TransferA();
int InitInputQuantParam();
int InitFilterQuantParam();
int TransferB();
void FreeTmpBuffer();
void FreeQuantParam();
protected:
MatMulParameter *param_ = nullptr;
MatmulDynamicQuantParameter *quant_param_ = nullptr;
int thread_count_ = 1;
int thread_stride_ = 0;
int8_t *pack_a_ptr_ = nullptr;
int8_t *pack_b_ptr_ = nullptr;
float *fp32_bias_ptr_ = nullptr;
bool filter_per_channel_ = true;
int8_t *batch_input_ptr_ = nullptr;
int8_t *batch_weight_ptr_ = nullptr;
int8_t *batch_b_ptr_ = nullptr;
float *batch_c_ptr_ = nullptr;
int row_tile_ = C4NUM;
int col_tile_ = C4NUM;
int deep_tile_ = C16NUM;
int channel_num_ = 0;
PackFunc b_pack_func_{nullptr};
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MATMUL_DYNAMIC_BASE_INT8_H_

View File

@ -16,7 +16,9 @@
#include "src/runtime/kernel/arm/int8/matmul_dynamic_int8.h"
#include "src/runtime/kernel/arm/int8/opt_op_handler.h"
#include "src/common/file_utils.h"
#include "nnacl/int8/matmul_int8.h"
#include "nnacl/int8/dynamic_matmul_int8.h"
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_MEMORY_FAILED;
@ -24,13 +26,6 @@ using mindspore::lite::RET_OK;
namespace mindspore::kernel {
namespace {
constexpr int kHasBiasSize = 3;
constexpr int kMinInputSize = 2;
constexpr int kOutputSize = 1;
constexpr int kSize1 = 1;
constexpr int kSize2 = 2;
} // namespace
int MatmulDynamicInt8Run(void *cdata, int task_id, float, float) {
CHECK_NULL_RETURN(cdata);
auto op = reinterpret_cast<MatmulDynamicInt8CPUKernel *>(cdata);
@ -41,6 +36,7 @@ int MatmulDynamicInt8Run(void *cdata, int task_id, float, float) {
}
return RET_OK;
}
} // namespace
int MatmulDynamicInt8CPUKernel::RunImpl(int task_id) {
int stride = thread_stride_ * col_tile_;
@ -55,92 +51,14 @@ int MatmulDynamicInt8CPUKernel::RunImpl(int task_id) {
bias_ptr += cur_stride;
}
float *filter_scale = quant_param_->filter_scale_;
int32_t filter_zp = quant_param_->filter_zp_[0];
if (filter_per_channel_) {
filter_scale += cur_stride;
}
DynamicMatmulInt8AIWI(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, bias_ptr,
batch_c_ptr_ + cur_stride, param_->row_, cur_oc, param_->deep_align_,
quant_param_->input_scale_, filter_scale, param_->col_, filter_per_channel_);
return RET_OK;
}
MatmulDynamicInt8CPUKernel::~MatmulDynamicInt8CPUKernel() {
FreeQuantParam();
FreeTmpBuffer();
}
void MatmulDynamicInt8CPUKernel::FreeQuantParam() {
if (quant_param_ != nullptr) {
if (quant_param_->filter_scale_ != nullptr) {
free(quant_param_->filter_scale_);
quant_param_->filter_scale_ = nullptr;
}
if (quant_param_->filter_zp_ != nullptr) {
free(quant_param_->filter_zp_);
quant_param_->filter_zp_ = nullptr;
}
free(quant_param_);
quant_param_ = nullptr;
}
}
int MatmulDynamicInt8CPUKernel::MallocQuantParam() {
quant_param_ = reinterpret_cast<MatmulDynamicQuantParameter *>(malloc(sizeof(MatmulQuantParameter)));
if (quant_param_ == nullptr) {
MS_LOG(ERROR) << "Malloc MatmulDynamicQuantParameter for Matmul int8 op failed!";
return RET_ERROR;
}
memset(quant_param_, 0, sizeof(MatmulQuantParameter));
return RET_OK;
}
int MatmulDynamicInt8CPUKernel::InitFilterQuantParam() {
if (quant_param_->filter_scale_ != nullptr) {
free(quant_param_->filter_scale_);
quant_param_->filter_scale_ = nullptr;
}
if (quant_param_->filter_zp_ != nullptr) {
free(quant_param_->filter_zp_);
quant_param_->filter_zp_ = nullptr;
}
auto weight_tensor = in_tensors_.at(kWeightIndex);
auto weight_quant_params = weight_tensor->quant_params();
auto w_shape = weight_tensor->shape();
if (w_shape.size() < DIMENSION_2D) {
MS_LOG(ERROR) << weight_tensor->tensor_name() << " dims < 2.";
return RET_ERROR;
}
int col = param_->b_transpose_ ? w_shape[w_shape.size() - kSize2] : w_shape[w_shape.size() - kSize1];
filter_per_channel_ = (weight_quant_params.size() > 1);
channel_num_ = filter_per_channel_ ? col : 1;
if (static_cast<int>(weight_quant_params.size()) != channel_num_) {
MS_LOG(ERROR) << weight_tensor->tensor_name() << " quant params size:" << weight_quant_params.size()
<< " != channel_num_:" << channel_num_;
return RET_ERROR;
}
quant_param_->filter_scale_ = reinterpret_cast<float *>(malloc(channel_num_ * sizeof(float)));
CHECK_NULL_RETURN(quant_param_->filter_scale_);
memset(quant_param_->filter_scale_, 0, sizeof(channel_num_));
quant_param_->filter_zp_ = reinterpret_cast<int32_t *>(malloc(channel_num_ * sizeof(int32_t)));
CHECK_NULL_RETURN(quant_param_->filter_zp_);
memset(quant_param_->filter_zp_, 0, sizeof(channel_num_));
for (int i = 0; i < channel_num_; i++) {
quant_param_->filter_scale_[i] = static_cast<float>(weight_quant_params[i].scale);
quant_param_->filter_zp_[i] = weight_quant_params[i].zeroPoint;
}
return RET_OK;
}
int MatmulDynamicInt8CPUKernel::InitInputQuantParam() {
auto in_quant_params = in_tensors_.at(kInputIndex)->quant_params();
if (in_quant_params.empty()) {
MS_LOG(ERROR) << "invalid in quant param";
return RET_ERROR;
}
quant_param_->input_zp_ = in_quant_params.front().zeroPoint;
quant_param_->input_scale_ = static_cast<float>(in_quant_params.front().scale);
DynamicMatmul4x16x4AIWI(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, bias_ptr,
batch_c_ptr_ + cur_stride, param_->row_, cur_oc, param_->deep_, param_->deep_align_,
param_->col_, quant_param_->input_zp_, quant_param_->input_scale_, filter_scale, filter_zp,
filter_per_channel_);
return RET_OK;
}
@ -163,133 +81,6 @@ void MatmulDynamicInt8CPUKernel::InitParameter() {
return;
}
void MatmulDynamicInt8CPUKernel::ResizeParameter() {
param_->row_align_ = UP_ROUND(param_->row_, row_tile_);
param_->col_align_ = UP_ROUND(param_->col_, col_tile_);
param_->deep_align_ = UP_ROUND(param_->deep_, deep_tile_);
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(param_->col_align_, col_tile_));
thread_stride_ = UP_DIV(UP_DIV(param_->col_align_, col_tile_), thread_count_);
return;
}
void MatmulDynamicInt8CPUKernel::FreeTmpBuffer() {
if (pack_a_ptr_ != nullptr) {
free(pack_a_ptr_);
pack_a_ptr_ = nullptr;
}
if (pack_b_ptr_ != nullptr) {
free(pack_b_ptr_);
pack_b_ptr_ = nullptr;
}
return;
}
int MatmulDynamicInt8CPUKernel::TransferB() {
auto weight_data = reinterpret_cast<int8_t *>(in_tensors_.at(kWeightIndex)->data());
CHECK_NULL_RETURN(weight_data);
for (int i = 0; i < param_->batch; i++) {
auto current_weight = weight_data + i * param_->deep_ * param_->col_;
auto current_b_pack = pack_b_ptr_ + i * param_->col_align_ * param_->deep_align_;
CHECK_NULL_RETURN(b_pack_func_);
if (param_->b_transpose_) {
b_pack_func_(current_weight, current_b_pack, param_->col_, param_->deep_);
} else {
b_pack_func_(current_weight, current_b_pack, param_->deep_, param_->col_);
}
}
return RET_OK;
}
int MatmulDynamicInt8CPUKernel::InitTmpBuffer() {
pack_a_ptr_ = reinterpret_cast<int8_t *>(malloc(param_->row_align_ * param_->deep_align_ * sizeof(int8_t)));
if (pack_a_ptr_ == nullptr) {
FreeTmpBuffer();
return RET_ERROR;
}
pack_b_ptr_ =
reinterpret_cast<int8_t *>(malloc(param_->batch * param_->col_align_ * param_->deep_align_ * sizeof(int8_t)));
if (pack_b_ptr_ == nullptr) {
FreeTmpBuffer();
return RET_ERROR;
}
memset(pack_a_ptr_, 0, param_->row_align_ * param_->deep_align_ * sizeof(int8_t));
memset(pack_b_ptr_, 0, param_->batch * param_->col_align_ * param_->deep_align_ * sizeof(int8_t));
return RET_OK;
}
int MatmulDynamicInt8CPUKernel::CopyBias() {
if (in_tensors_.size() == kHasBiasSize) {
auto bias_tensor = in_tensors_[kBiasIndex];
fp32_bias_ptr_ = reinterpret_cast<float *>(bias_tensor->data());
if (fp32_bias_ptr_ == nullptr) {
MS_LOG(ERROR) << "Memory allocation failed";
FreeTmpBuffer();
return RET_MEMORY_FAILED;
}
memcpy(fp32_bias_ptr_, bias_tensor->data(), bias_tensor->ElementsNum() * sizeof(float));
} else {
fp32_bias_ptr_ = nullptr;
}
return RET_OK;
}
int MatmulDynamicInt8CPUKernel::Prepare() {
CHECK_LESS_RETURN(in_tensors_.size(), kMinInputSize);
CHECK_LESS_RETURN(out_tensors_.size(), kOutputSize);
InitParameter();
auto ret = MallocQuantParam();
if (ret != RET_OK) {
FreeQuantParam();
return ret;
}
if (param_->b_const_) {
ret = InitFilterQuantParam();
if (ret != RET_OK) {
FreeQuantParam();
return ret;
}
}
ret = CopyBias();
if (ret != RET_OK) {
FreeQuantParam();
return ret;
}
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int MatmulDynamicInt8CPUKernel::ReSize() {
int batch = 1;
auto x_shape = in_tensors_.at(0)->shape();
auto o_shape = out_tensors_.at(0)->shape();
MS_ASSERT(x_shape.size() >= kSize2);
for (size_t i = 0; i < x_shape.size() - kSize2; ++i) {
batch *= x_shape[i];
}
param_->batch = batch;
MS_ASSERT(o_shape.size() >= kSize2);
param_->row_ = o_shape[o_shape.size() - kSize2];
param_->col_ = o_shape[o_shape.size() - kSize1];
param_->deep_ = param_->a_transpose_ ? x_shape[x_shape.size() - kSize2] : x_shape[x_shape.size() - kSize1];
FreeTmpBuffer();
ResizeParameter();
auto ret = InitTmpBuffer();
if (ret != RET_OK) {
FreeQuantParam();
return ret;
}
if (param_->b_const_ == true) {
TransferB();
}
return RET_OK;
}
int MatmulDynamicInt8CPUKernel::Run() {
auto ret = InitInputQuantParam();
if (ret != RET_OK) {
@ -314,6 +105,7 @@ int MatmulDynamicInt8CPUKernel::Run() {
CHECK_NULL_RETURN(a_ptr);
CHECK_NULL_RETURN(c_ptr);
for (int i = 0; i < param_->batch; i++) {
memset(pack_a_ptr_, quant_param_->input_zp_, param_->row_align_ * param_->deep_align_ * sizeof(int8_t));
auto current_src_a = a_ptr + i * param_->row_ * param_->deep_;
if (param_->a_transpose_) {
MS_CHECK_TRUE_RET(a_pack_func_ != nullptr, RET_ERROR);

View File

@ -18,72 +18,25 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MATMUL_DYNAMIC_INT8_H_
#include <vector>
#include "include/errorcode.h"
#include "include/context.h"
#include "src/inner_kernel.h"
#include "nnacl/matmul_parameter.h"
#include "nnacl/common_func.h"
#include "nnacl/int8/quantize.h"
#include "nnacl/int8/common_func_int8.h"
#include "nnacl/int8/matmul_int8.h"
#include "src/runtime/kernel/arm/int8/matmul_dynamic_base_int8.h"
namespace mindspore::kernel {
class MatmulDynamicInt8CPUKernel : public InnerKernel {
typedef void (*PackFunc)(const int8_t *src, int8_t *dst, int row, int col);
class MatmulDynamicInt8CPUKernel : public MatmulDynamicBaseInt8CPUKernel {
public:
MatmulDynamicInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: InnerKernel(parameter, inputs, outputs, ctx) {
param_ = reinterpret_cast<MatMulParameter *>(op_parameter_);
}
~MatmulDynamicInt8CPUKernel() override;
int Prepare() override;
int ReSize() override;
: MatmulDynamicBaseInt8CPUKernel(parameter, inputs, outputs, ctx) {}
~MatmulDynamicInt8CPUKernel() override = default;
int Run() override;
public:
int RunImpl(int task_id);
#if defined(ENABLE_ARM64) && !defined(SUPPORT_NNIE) && (!defined(MACHINE_LINUX_ARM64))
int RunArm64Sdot();
int Arm64SdotImpl(int task_id);
int Arm64SdotPre(int task_id);
#endif
private:
void InitParameter();
void ResizeParameter();
int CopyBias();
int InitTmpBuffer();
void FreeTmpBuffer();
int TransferA();
int TransferB();
int MallocQuantParam();
int InitInputQuantParam();
int InitFilterQuantParam();
void FreeQuantParam();
void InitParameter() override;
private:
MatMulParameter *param_ = nullptr;
MatmulDynamicQuantParameter *quant_param_ = nullptr;
int thread_count_ = 1;
int thread_stride_ = 0;
int8_t *pack_a_ptr_ = nullptr;
int8_t *pack_b_ptr_ = nullptr;
float *fp32_bias_ptr_ = nullptr;
bool filter_per_channel_ = true;
int8_t *batch_input_ptr_ = nullptr;
int8_t *batch_weight_ptr_ = nullptr;
int8_t *batch_b_ptr_ = nullptr;
float *batch_c_ptr_ = nullptr;
int row_tile_ = C4NUM;
int col_tile_ = C4NUM;
int deep_tile_ = C16NUM;
int channel_num_ = 0;
bool support_sdot_ = false;
PackFunc a_pack_func_{nullptr};
PackFunc b_pack_func_{nullptr};
};
} // namespace mindspore::kernel

View File

@ -0,0 +1,184 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/int8/matmul_dynamic_sdot_int8.h"
#include <vector>
#include "src/common/file_utils.h"
#include "nnacl/int8/dynamic_matmul_int8.h"
#include "nnacl/int8/matmul_int8.h"
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_MEMORY_FAILED;
using mindspore::lite::RET_OK;
namespace mindspore::kernel {
namespace {
int Arm64SdotPreRun(void *cdata, int task_id, float, float) {
CHECK_NULL_RETURN(cdata);
auto op = reinterpret_cast<MatMulDynamicSdotInt8Kernel *>(cdata);
auto ret = op->MatMulDynamicArm64SdotPre(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "MatmulInt8Run error task_id[" << task_id << "] error_code[" << ret << "]";
return ret;
}
return RET_OK;
}
int Arm64SdotRun(void *cdata, int task_id, float, float) {
CHECK_NULL_RETURN(cdata);
auto op = reinterpret_cast<MatMulDynamicSdotInt8Kernel *>(cdata);
auto ret = op->MatMulDynamicArm64SdotImpl(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "MatmulInt8Run error task_id[" << task_id << "] error_code[" << ret << "]";
return ret;
}
return RET_OK;
}
} // namespace
int MatMulDynamicSdotInt8Kernel::MatMulDynamicArm64SdotPre(int task_id) {
int row_thread_count = MSMIN(op_parameter_->thread_num_, UP_DIV(param_->row_align_, row_tile_));
int row_stride = UP_DIV(UP_DIV(param_->row_align_, row_tile_), row_thread_count) * row_tile_;
int row_current_stride = task_id * row_stride;
int row_res_stride = param_->row_ - row_current_stride;
int cur_r = MSMIN(row_res_stride, row_stride);
if (cur_r <= 0) {
return RET_OK;
}
auto current_a_pack = pack_a_ptr_ + row_current_stride * param_->deep_align_;
if (param_->a_transpose_) {
auto current_src_a = batch_input_ptr_ + row_current_stride;
PackInput2Col4x4(current_src_a, current_a_pack, param_->deep_, cur_r, param_->row_);
} else {
auto current_src_a = batch_input_ptr_ + row_current_stride * param_->deep_;
PackInput4x4(current_src_a, current_a_pack, param_->deep_, cur_r);
}
return RET_OK;
}
int MatMulDynamicSdotInt8Kernel::MatMulDynamicArm64SdotImpl(int task_id) {
#if defined(ENABLE_ARM64) && !defined(SUPPORT_NNIE) && (!defined(MACHINE_LINUX_ARM64))
// Multi-thread split by col.
int stride = thread_stride_ * col_tile_;
int cur_stride = task_id * stride;
int res_stride = param_->col_ - cur_stride;
int cur_oc = MSMIN(stride, res_stride);
if (cur_oc <= 0) {
return RET_OK;
}
if (!param_->b_const_) {
auto current_b_pack = batch_b_ptr_ + cur_stride * param_->deep_align_;
if (param_->b_transpose_) {
auto current_weight = batch_weight_ptr_ + cur_stride * param_->deep_;
RowMajor2Row4x16MajorInt8(current_weight, current_b_pack, cur_oc, param_->deep_);
} else {
auto current_weight = batch_weight_ptr_ + cur_stride;
RowMajor2Col4x16MajorPartInt8(current_weight, current_b_pack, param_->deep_, param_->col_, cur_oc);
}
}
std::vector<float> multi_scale(cur_oc);
for (int i = 0; i < cur_oc; ++i) {
if (!param_->b_const_) {
multi_scale[i] = quant_param_->input_scale_ * quant_param_->filter_scale_[0];
} else {
multi_scale[i] = quant_param_->input_scale_ * quant_param_->filter_scale_[cur_stride + i];
}
}
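// Tile walk, e.g. row_ = 10 and cur_oc = 20 gives row tiles {4, 4, 2} x col tiles {16, 4}; the asm kernel
// masks partial tiles via its row/col arguments.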
for (int r = 0; r < param_->row_; r += C4NUM) {
size_t row = MSMIN(C4NUM, param_->row_ - r);
auto a_ptr = pack_a_ptr_ + r * param_->deep_align_;
for (int c = 0; c < cur_oc; c += C16NUM) {
size_t col = MSMIN(C16NUM, cur_oc - c);
auto col_offset = cur_stride + c;
auto b_ptr = batch_b_ptr_ + col_offset * param_->deep_align_;
auto out_ptr = batch_c_ptr_ + r * param_->col_ + col_offset;
auto bias = fp32_bias_ptr_;
if (bias != nullptr) {
bias += col_offset;
}
DynamicMatmulSdot4x4x16AIWI(a_ptr, b_ptr, out_ptr, param_->deep_align_, multi_scale.data() + c, bias, row, col,
param_->col_ * sizeof(float));
}
}
#endif
return RET_OK;
}
void MatMulDynamicSdotInt8Kernel::InitParameter() {
param_->a_const_ = (in_tensors_[0]->data() != nullptr);
param_->b_const_ = (in_tensors_[1]->data() != nullptr);
row_tile_ = C4NUM;
col_tile_ = C16NUM;
deep_tile_ = C4NUM;
if (param_->b_transpose_) {
b_pack_func_ = RowMajor2Row4x16MajorInt8;
} else {
b_pack_func_ = RowMajor2Col4x16MajorInt8;
}
return;
}
int MatMulDynamicSdotInt8Kernel::MatMulDynamicRunArm64Sdot() {
int8_t *a_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(0)->data());
int8_t *b_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(1)->data());
float *c_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->data());
CHECK_NULL_RETURN(a_ptr);
CHECK_NULL_RETURN(b_ptr);
CHECK_NULL_RETURN(c_ptr);
for (int i = 0; i < param_->batch; i++) {
batch_input_ptr_ = a_ptr + i * param_->row_ * param_->deep_;
auto ret = ParallelLaunch(this->ms_context_, Arm64SdotPreRun, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Arm64SdotPreRun error: [" << ret << "]";
return ret;
}
batch_weight_ptr_ = b_ptr + i * param_->col_ * param_->deep_;
batch_b_ptr_ = pack_b_ptr_ + i * param_->col_align_ * param_->deep_align_;
batch_c_ptr_ = c_ptr + i * param_->row_ * param_->col_;
ret = ParallelLaunch(this->ms_context_, Arm64SdotRun, this, thread_count_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Arm64SdotRun error: [" << ret << "]";
return ret;
}
}
return RET_OK;
}
int MatMulDynamicSdotInt8Kernel::Run() {
auto ret = InitInputQuantParam();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init input quant param failed.";
return ret;
}
if (!param_->b_const_) {
ret = InitFilterQuantParam();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init filter quant param failed.";
FreeQuantParam();
return ret;
}
}
return MatMulDynamicRunArm64Sdot();
}
} // namespace mindspore::kernel

View File

@ -0,0 +1,42 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MATMUL_DYNAMIC_SDOT_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MATMUL_DYNAMIC_SDOT_INT8_H_
#include <vector>
#include "src/runtime/kernel/arm/int8/matmul_dynamic_base_int8.h"
namespace mindspore::kernel {
class MatMulDynamicSdotInt8Kernel : public MatmulDynamicBaseInt8CPUKernel {
public:
MatMulDynamicSdotInt8Kernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: MatmulDynamicBaseInt8CPUKernel(parameter, inputs, outputs, ctx) {}
~MatMulDynamicSdotInt8Kernel() override = default;
int Run() override;
public:
int MatMulDynamicRunArm64Sdot();
int MatMulDynamicArm64SdotPre(int task_id);
int MatMulDynamicArm64SdotImpl(int task_id);
private:
void InitParameter() override;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MATMUL_DYNAMIC_SDOT_INT8_H_

View File

@ -86,17 +86,19 @@ int PreprocessParser::ParsePreprocess(const DataPreProcessString &data_pre_proce
preprocess::ConvertColorConversionCodes(data_pre_process->image_pre_process.image_to_format);
}
}
ret = ParseImagePreProcess(data_pre_process_str, &data_pre_process->image_pre_process);
if (ret != RET_OK) {
MS_LOG(ERROR) << "image preprocess parse failed.";
return ret;
}
if (!data_pre_process_str.calibrate_path.empty() && !data_pre_process_str.calibrate_size.empty()) {
ret = ParseImagePreProcess(data_pre_process_str, &data_pre_process->image_pre_process);
if (ret != RET_OK) {
MS_LOG(ERROR) << "image preprocess parse failed.";
return ret;
}
ret = CollectCalibInputs(data_pre_process->calibrate_path, data_pre_process->calibrate_size,
&data_pre_process->calibrate_path_vector);
if (ret != RET_OK) {
MS_LOG(ERROR) << "collect calibrate inputs failed.";
return ret;
ret = CollectCalibInputs(data_pre_process->calibrate_path, data_pre_process->calibrate_size,
&data_pre_process->calibrate_path_vector);
if (ret != RET_OK) {
MS_LOG(ERROR) << "collect calibrate inputs failed.";
return ret;
}
}
return RET_OK;
}
@ -202,8 +204,13 @@ int PreprocessParser::CollectCalibInputs(const std::map<std::string, std::string
MS_LOG(ERROR) << " close dir failed.";
return RET_ERROR;
}
auto &cur_inputs = inputs->at(image_path.first);
std::sort(cur_inputs.begin(), cur_inputs.end());
if (inputs->find(image_path.first) != inputs->end()) {
auto &cur_inputs = inputs->at(image_path.first);
std::sort(cur_inputs.begin(), cur_inputs.end());
} else {
MS_LOG(ERROR) << "cant find " << image_path.first << " at input maps.";
return RET_ERROR;
}
if (count != limited_count) {
MS_LOG(ERROR) << " data path: " << image_path.second << " data count:" << count
<< " < limited_count:" << limited_count;

View File

@ -177,9 +177,9 @@ int DebugInfoManager::SetOriginStaticInfo(QuantDebugInfo *quant_debug_info, cons
}
quant_debug_info->clip = 0;
CHECK_NULL_RETURN(tensor.data());
quant_debug_info->tensor_data.data = malloc(tensor.Size());
CHECK_MALLOC_RES(quant_debug_info->tensor_data.data, RET_NULL_PTR);
CHECK_NULL_RETURN(tensor.data());
auto ret = memcpy_s(quant_debug_info->tensor_data.data, tensor.Size(), tensor.data(), tensor.Size());
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy memory failed.";

View File

@ -29,7 +29,6 @@
#include "schema/inner/model_generated.h"
#include "src/lite_session.h"
#include "tools/converter/quantizer/quantizer.h"
#include "tools/converter/converter.h"
#include "include/ms_tensor.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/converter/quantizer/quant_params.h"

View File

@ -82,26 +82,22 @@ int DoWeightQuant(const FuncGraphPtr &old_graph, const converter::Flags *config)
auto quantizer = std::make_unique<WeightQuantizer>(*config);
if (quantizer == nullptr) {
MS_LOG(ERROR) << "New WeightQuantizer failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
return RET_ERROR;
}
status = static_cast<WeightQuantizer *>(quantizer.get())->DoQuantize(old_graph, init_scale);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoQuantization failed " << status;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return RET_ERROR;
}
} else {
auto quantizer = std::make_unique<WeightQuantizer>(*config);
if (quantizer == nullptr) {
MS_LOG(ERROR) << "New WeightQuantizer failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
return RET_ERROR;
}
auto status = quantizer->DoQuantize(old_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoQuantization failed " << status;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return RET_ERROR;
}
}
@ -112,13 +108,11 @@ int DoDynamicQuant(const FuncGraphPtr &old_graph, const converter::Flags *config
auto quantizer = std::make_unique<DynamicQuantizer>(*config);
if (quantizer == nullptr) {
MS_LOG(ERROR) << "New DynamicQuantizer failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
return RET_ERROR;
}
auto status = quantizer->DoQuantize(old_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoQuantization failed " << status;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return RET_ERROR;
}
return RET_OK;

View File

@ -119,14 +119,15 @@ int WeightQuantizer::DoCNodeWeightQuant(const FuncGraphPtr &func_graph, const CN
}
auto status = RET_ERROR;
if (is_mixed_bit_) {
status = MixedBitQuantFilter(parameter, tensor_info, primitive, QuantType_QUANT_WEIGHT,
status = MixedBitQuantFilter(parameter, tensor_info, primitive, flags_.commonQuantParam.quant_type,
WeightQuantType::MIXED_BIT_PER_LAYER, type_id_, mixed_bit_init_scale_, idx - 1);
} else if (type_id_ == kNumberTypeInt8) {
status = FixedBitQuantFilter<int8_t>(parameter, tensor_info, primitive, QuantType_QUANT_WEIGHT, q_max, q_min,
bit_num_, tmp_weight_quant_type, type_id_, idx - 1, symmetric);
status = FixedBitQuantFilter<int8_t>(parameter, tensor_info, primitive, flags_.commonQuantParam.quant_type, q_max,
q_min, bit_num_, tmp_weight_quant_type, type_id_, idx - 1, symmetric);
} else if (type_id_ == kNumberTypeInt16) {
status = FixedBitQuantFilter<int16_t>(parameter, tensor_info, primitive, QuantType_QUANT_WEIGHT, q_max, q_min,
bit_num_, tmp_weight_quant_type, type_id_, idx - 1, symmetric);
status =
FixedBitQuantFilter<int16_t>(parameter, tensor_info, primitive, flags_.commonQuantParam.quant_type, q_max,
q_min, bit_num_, tmp_weight_quant_type, type_id_, idx - 1, symmetric);
}
if (status == RET_NO_CHANGE) {
continue;