forked from mindspore-Ecosystem/mindspore
!26545 cast between fp16 int8 optimize
Merge pull request !26545 from zhaozhenlong/lite/issue/cast-fp16-int8
This commit is contained in:
commit
7fb6df8384
|
@ -93,6 +93,7 @@ mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/avx/Winog
|
|||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/deconv_winograd_fp32.c:PackDeConvWgDataFp32
|
||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/deconv_winograd_fp32.c:DeConvWgMerge
|
||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/avx/TiledC8MatMulFp32.c:TiledC8MatmulFp32
|
||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/quant_dtype_cast_fp16.c:Fp16ToInt8_arm64
|
||||
mindspore/mindspore/ccsrc/debug/dump_proto.cc:mindspore::ProtoExporter::SetNodeOutputType
|
||||
mindspore/mindspore/ccsrc/debug/dump_proto.cc:mindspore::ProtoExporter::SetValueToProto
|
||||
mindspore/mindspore/ccsrc/debug/dump_proto.cc:mindspore::ProtoExporter::SetScalarToProto
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -18,27 +18,227 @@
|
|||
#include "nnacl/fp16/quant_dtype_cast_fp16.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void Int8ToFp16_arm64(const int8_t *quant_values, float16_t *dst, float scale, int32_t zp, int size) {
|
||||
asm volatile(
|
||||
"mov w8, %w[size]\n"
|
||||
"cmp w8, #0\n"
|
||||
"beq 2f\n"
|
||||
|
||||
"dup v20.4s, %w[zp32]\n"
|
||||
"dup v21.4s, %w[scale]\n"
|
||||
|
||||
"cmp w8, #16\n"
|
||||
"blt 1f\n"
|
||||
|
||||
"0:\n"
|
||||
"subs w8, w8, #16\n"
|
||||
"ld1 {v7.16b}, [%[quant_values]], #16\n"
|
||||
|
||||
"sxtl v8.8h, v7.8b\n"
|
||||
"sxtl2 v9.8h, v7.16b\n"
|
||||
|
||||
"sxtl v0.4s, v8.4h\n"
|
||||
"sxtl2 v1.4s, v8.8h\n"
|
||||
"sxtl v2.4s, v9.4h\n"
|
||||
"sxtl2 v3.4s, v9.8h\n"
|
||||
"sub v0.4s, v0.4s, v20.4s\n"
|
||||
"sub v1.4s, v1.4s, v20.4s\n"
|
||||
"sub v2.4s, v2.4s, v20.4s\n"
|
||||
"sub v3.4s, v3.4s, v20.4s\n"
|
||||
"scvtf v4.4s, v0.4s\n"
|
||||
"scvtf v5.4s, v1.4s\n"
|
||||
"scvtf v6.4s, v2.4s\n"
|
||||
"scvtf v7.4s, v3.4s\n"
|
||||
|
||||
"fmul v0.4s, v4.4s, v21.4s\n"
|
||||
"fmul v1.4s, v5.4s, v21.4s\n"
|
||||
"fmul v2.4s, v6.4s, v21.4s\n"
|
||||
"fmul v3.4s, v7.4s, v21.4s\n"
|
||||
|
||||
"fcvtn v4.4h, v0.4s\n"
|
||||
"fcvtn2 v4.8h, v1.4s\n"
|
||||
"fcvtn v5.4h, v2.4s\n"
|
||||
"fcvtn2 v5.8h, v3.4s\n"
|
||||
|
||||
"st1 {v4.8h, v5.8h}, [%[dst]], #32\n"
|
||||
"beq 2f\n"
|
||||
"cmp w8, #16\n"
|
||||
"bge 0b\n"
|
||||
|
||||
"1:\n"
|
||||
"ldrsb w9, [%[quant_values]], #1\n"
|
||||
|
||||
"subs w8, w8, #1\n"
|
||||
"sub w9, w9, %w[zp32]\n"
|
||||
"scvtf s9, w9\n"
|
||||
|
||||
"fmul s9, s9, s21\n"
|
||||
"fcvtn v4.4h, v9.4s\n"
|
||||
"str h4, [%[dst]], #2\n"
|
||||
"bne 1b\n"
|
||||
|
||||
"2:\n"
|
||||
|
||||
:
|
||||
: [ quant_values ] "r"(quant_values), [ dst ] "r"(dst), [ scale ] "r"(scale), [ zp32 ] "r"(zp), [ size ] "r"(size)
|
||||
: "w8", "w9", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v20", "v21");
|
||||
}
|
||||
#endif
|
||||
|
||||
// Dequantize a buffer of int8 values to fp16:
//   real_values[i] = (quant_values[i] - zp) * scale
// Returns NNACL_PARAM_INVALID on a NULL buffer, NNACL_OK otherwise.
int DoDequantizeInt8ToFp16(const int8_t *quant_values, float16_t *real_values, float scale, int32_t zp, int size) {
  if (real_values == NULL || quant_values == NULL) {
    return NNACL_PARAM_INVALID;
  }

#ifdef ENABLE_ARM64
  // NEON-optimized path (16 values per iteration plus a scalar tail).
  Int8ToFp16_arm64(quant_values, real_values, scale, zp, size);
#else
  // Portable scalar fallback.
  int idx = 0;
  while (idx < size) {
    real_values[idx] = (quant_values[idx] - zp) * scale;
    ++idx;
  }
#endif
  return NNACL_OK;
}
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void Fp16ToInt8_arm64(const float16_t *real_values, int8_t *quant_values, float scale, int32_t zp, int size) {
|
||||
float ivs = 1.0f / scale;
|
||||
const int32_t min_value = -128;
|
||||
const int32_t max_value = 127;
|
||||
asm volatile(
|
||||
"mov w8, %w[size]\n"
|
||||
"cmp w8, wzr\n"
|
||||
"beq 3f\n"
|
||||
|
||||
"dup v28.4s, %w[ivs]\n"
|
||||
"dup v29.4s, %w[min_value]\n"
|
||||
"dup v30.4s, %w[max_value]\n"
|
||||
|
||||
"cmp w8, #32\n"
|
||||
"blt 2f\n"
|
||||
"1:\n" // loop 32
|
||||
"subs w8, w8, #32\n"
|
||||
"ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%[real_values]], #64\n"
|
||||
"fcvtl v8.4s, v0.4h\n"
|
||||
"fcvtl2 v9.4s, v0.8h\n"
|
||||
"fcvtl v10.4s, v1.4h\n"
|
||||
"fcvtl2 v11.4s, v1.8h\n"
|
||||
"fcvtl v12.4s, v2.4h\n"
|
||||
"fcvtl2 v13.4s, v2.8h\n"
|
||||
"fcvtl v14.4s, v3.4h\n"
|
||||
"fcvtl2 v15.4s, v3.8h\n"
|
||||
|
||||
"dup v16.4s, %w[zp]\n"
|
||||
"dup v17.4s, %w[zp]\n"
|
||||
"dup v18.4s, %w[zp]\n"
|
||||
"dup v19.4s, %w[zp]\n"
|
||||
"dup v20.4s, %w[zp]\n"
|
||||
"dup v21.4s, %w[zp]\n"
|
||||
"dup v22.4s, %w[zp]\n"
|
||||
"dup v23.4s, %w[zp]\n"
|
||||
"scvtf v16.4s, v16.4s\n"
|
||||
"scvtf v17.4s, v17.4s\n"
|
||||
"scvtf v18.4s, v18.4s\n"
|
||||
"scvtf v19.4s, v19.4s\n"
|
||||
"scvtf v20.4s, v20.4s\n"
|
||||
"scvtf v21.4s, v21.4s\n"
|
||||
"scvtf v22.4s, v22.4s\n"
|
||||
"scvtf v23.4s, v23.4s\n"
|
||||
|
||||
"fmla v16.4s, v8.4s, v28.4s\n"
|
||||
"fmla v17.4s, v9.4s, v28.4s\n"
|
||||
"fmla v18.4s, v10.4s, v28.4s\n"
|
||||
"fmla v19.4s, v11.4s, v28.4s\n"
|
||||
"fmla v20.4s, v12.4s, v28.4s\n"
|
||||
"fmla v21.4s, v13.4s, v28.4s\n"
|
||||
"fmla v22.4s, v14.4s, v28.4s\n"
|
||||
"fmla v23.4s, v15.4s, v28.4s\n"
|
||||
|
||||
"fcvtas v8.4s, v16.4s\n"
|
||||
"fcvtas v9.4s, v17.4s\n"
|
||||
"fcvtas v10.4s, v18.4s\n"
|
||||
"fcvtas v11.4s, v19.4s\n"
|
||||
"fcvtas v12.4s, v20.4s\n"
|
||||
"fcvtas v13.4s, v21.4s\n"
|
||||
"fcvtas v14.4s, v22.4s\n"
|
||||
"fcvtas v15.4s, v23.4s\n"
|
||||
|
||||
"smax v8.4s, v8.4s, v29.4s\n"
|
||||
"smax v9.4s, v9.4s, v29.4s\n"
|
||||
"smax v10.4s, v10.4s, v29.4s\n"
|
||||
"smax v11.4s, v11.4s, v29.4s\n"
|
||||
"smax v12.4s, v12.4s, v29.4s\n"
|
||||
"smax v13.4s, v13.4s, v29.4s\n"
|
||||
"smax v14.4s, v14.4s, v29.4s\n"
|
||||
"smax v15.4s, v15.4s, v29.4s\n"
|
||||
"smin v8.4s, v8.4s, v30.4s\n"
|
||||
"smin v9.4s, v9.4s, v30.4s\n"
|
||||
"smin v10.4s, v10.4s, v30.4s\n"
|
||||
"smin v11.4s, v11.4s, v30.4s\n"
|
||||
"smin v12.4s, v12.4s, v30.4s\n"
|
||||
"smin v13.4s, v13.4s, v30.4s\n"
|
||||
"smin v14.4s, v14.4s, v30.4s\n"
|
||||
"smin v15.4s, v15.4s, v30.4s\n"
|
||||
|
||||
"sqxtn v16.4h, v8.4s\n"
|
||||
"sqxtn2 v16.8h, v9.4s\n"
|
||||
"sqxtn v17.4h, v10.4s\n"
|
||||
"sqxtn2 v17.8h, v11.4s\n"
|
||||
"sqxtn v18.4h, v12.4s\n"
|
||||
"sqxtn2 v18.8h, v13.4s\n"
|
||||
"sqxtn v19.4h, v14.4s\n"
|
||||
"sqxtn2 v19.8h, v15.4s\n"
|
||||
"sqxtn v20.8b, v16.8h\n"
|
||||
"sqxtn2 v20.16b, v17.8h\n"
|
||||
"sqxtn v21.8b, v18.8h\n"
|
||||
"sqxtn2 v21.16b, v19.8h\n"
|
||||
|
||||
"st1 {v20.16b, v21.16b}, [%[quant_values]], #32\n"
|
||||
|
||||
"beq 3f\n"
|
||||
"cmp w8, #32\n"
|
||||
"bge 1b\n"
|
||||
|
||||
"2:\n" // 1 by 1
|
||||
"scvtf s10, %w[zp]\n"
|
||||
"subs w8, w8, #1\n"
|
||||
"ldr h0, [%[real_values]], #2\n"
|
||||
"fcvt s0, h0\n"
|
||||
"fmul s0, s0, s28\n"
|
||||
"fadd s0, s0, s10\n"
|
||||
"fcvtas s0, s0\n"
|
||||
"smax v0.4s, v0.4s, v29.4s\n"
|
||||
"smin v0.4s, v0.4s, v30.4s\n"
|
||||
"sqxtn v0.4h, v0.4s\n"
|
||||
"sqxtn v0.8b, v0.8h\n"
|
||||
"st1 {v0.b}[0], [%[quant_values]], #1\n"
|
||||
"bne 2b\n"
|
||||
|
||||
"3:\n"
|
||||
:
|
||||
: [ size ] "r"(size), [ ivs ] "r"(ivs), [ real_values ] "r"(real_values), [ quant_values ] "r"(quant_values),
|
||||
[ zp ] "r"(zp), [ min_value ] "r"(min_value), [ max_value ] "r"(max_value)
|
||||
: "w8", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
|
||||
"v17", "v18", "v19", "v20", "v21", "v22", "v23", "v28", "v29", "v30");
|
||||
}
|
||||
#endif
|
||||
|
||||
int DoQuantizeFp16ToInt8(const float16_t *real_values, int8_t *quant_values, float scale, int32_t zp, int size) {
|
||||
if (quant_values == NULL || real_values == NULL) {
|
||||
return NNACL_PARAM_INVALID;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
Fp16ToInt8_arm64(real_values, quant_values, scale, zp, size);
|
||||
#else
|
||||
for (int i = 0; i < size; ++i) {
|
||||
if (isinf(real_values[i])) {
|
||||
if (real_values[i] == INFINITY) {
|
||||
quant_values[i] = 127;
|
||||
continue;
|
||||
}
|
||||
if (real_values[i] == -INFINITY) {
|
||||
quant_values[i] = -128;
|
||||
continue;
|
||||
}
|
||||
float temp = round((float)real_values[i] / scale + zp);
|
||||
if (temp > 127) {
|
||||
quant_values[i] = 127;
|
||||
|
@ -48,6 +248,7 @@ int DoQuantizeFp16ToInt8(const float16_t *real_values, int8_t *quant_values, flo
|
|||
quant_values[i] = (int8_t)temp;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -17,6 +17,9 @@
|
|||
#include <cmath>
|
||||
#include "common/common_test.h"
|
||||
#include "nnacl/int8/quant_dtype_cast_int8.h"
|
||||
#ifdef ENABLE_ARM64
|
||||
#include "nnacl/fp16/quant_dtype_cast_fp16.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
|
||||
|
@ -57,6 +60,155 @@ void Fp32ToInt8Util(const float *real_values, int8_t *quant_values, float scale,
|
|||
return;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void Fp16ToInt8Util(const float16_t *real_values, int8_t *quant_values, float scale, int32_t zp, int size) {
|
||||
if (quant_values == NULL || real_values == NULL) {
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < size; ++i) {
|
||||
if (real_values[i] == INFINITY) {
|
||||
quant_values[i] = INT8_MAX;
|
||||
continue;
|
||||
}
|
||||
if (real_values[i] == -INFINITY) {
|
||||
quant_values[i] = INT8_MIN;
|
||||
continue;
|
||||
}
|
||||
float temp = round(static_cast<float>(real_values[i]) / scale + zp);
|
||||
if (temp > INT8_MAX) {
|
||||
quant_values[i] = INT8_MAX;
|
||||
} else if (temp < INT8_MIN) {
|
||||
quant_values[i] = INT8_MIN;
|
||||
} else {
|
||||
quant_values[i] = (int8_t)temp;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Fill real_values with kSize consecutive fp16 values centered on zero and
// compute the expected int8 output with the scalar reference.
void ConstructFp16Int8Data(float16_t *real_values, int8_t *benchmark_data, int kSize, int32_t zp, float scale) {
  constexpr int kDiv = 2;
  const int offset = kSize / kDiv;
  for (int idx = 0; idx < kSize; ++idx) {
    real_values[idx] = static_cast<float16_t>(idx - offset);
  }
  Fp16ToInt8Util(real_values, benchmark_data, scale, zp, kSize);
}
|
||||
|
||||
// fp16 -> int8 quantize with explicit +/-inf inputs; compares the optimized
// implementation against the scalar reference.
// NOTE(review): the test name says "Size3" but kSize is 8 — confirm intent.
TEST_F(QuantCastInt8Test, Fp16Int8Size3) {
  constexpr int kSize = 8;
  float16_t real_values[kSize] = {-INFINITY, INFINITY, -0.5, 0.0, 0.5, 1.0, 2.0, 3.5};
  const int32_t zp = -1;
  const float scale = 0.3f;

  // Expected values from the scalar reference.
  int8_t benchmark_data[kSize];
  Fp16ToInt8Util(real_values, benchmark_data, scale, zp, kSize);

  int8_t quant_values[kSize];
  (void)DoQuantizeFp16ToInt8(real_values, quant_values, scale, zp, kSize);

  for (int idx = 0; idx < kSize; ++idx) {
    ASSERT_EQ(quant_values[idx], benchmark_data[idx]);
  }
}
|
||||
|
||||
// fp16 -> int8 quantize with a size equal to one full 32-element vector
// block, so the optimized path takes no scalar tail.
TEST_F(QuantCastInt8Test, Fp16Int8Size32) {
  constexpr int kSize = 32;
  const int32_t zp = 0;
  const float scale = 0.3f;

  float16_t real_values[kSize];
  int8_t benchmark_data[kSize];
  ConstructFp16Int8Data(real_values, benchmark_data, kSize, zp, scale);

  int8_t quant_values[kSize];
  (void)DoQuantizeFp16ToInt8(real_values, quant_values, scale, zp, kSize);

  for (int idx = 0; idx < kSize; ++idx) {
    ASSERT_EQ(quant_values[idx], benchmark_data[idx]);
  }
}
|
||||
|
||||
// fp16 -> int8 quantize with size 33: one full 32-element vector block plus
// a single value through the scalar tail.
TEST_F(QuantCastInt8Test, Fp16Int8Size33) {
  constexpr int kSize = 33;
  const int32_t zp = 2;
  const float scale = 0.1f;

  float16_t real_values[kSize];
  int8_t benchmark_data[kSize];
  ConstructFp16Int8Data(real_values, benchmark_data, kSize, zp, scale);

  int8_t quant_values[kSize];
  (void)DoQuantizeFp16ToInt8(real_values, quant_values, scale, zp, kSize);

  for (int idx = 0; idx < kSize; ++idx) {
    ASSERT_EQ(quant_values[idx], benchmark_data[idx]);
  }
}
|
||||
|
||||
void Int8ToFp16Util(const int8_t *quant_values, float16_t *real_values, float scale, int32_t zp, int size) {
|
||||
if (quant_values == NULL || real_values == NULL) {
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < size; ++i) {
|
||||
real_values[i] = (quant_values[i] - zp) * scale;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Fill quant_values with kSize consecutive int8 values centered on zero and
// compute the expected fp16 output with the scalar reference.
void ConstructInt8ToFp16Data(int8_t *quant_values, float16_t *benchmark_data, int kSize, int32_t zp, float scale) {
  constexpr int kDiv = 2;
  const int half = kSize / kDiv;
  for (int idx = 0; idx < kSize; ++idx) {
    quant_values[idx] = (idx - half);
  }
  Int8ToFp16Util(quant_values, benchmark_data, scale, zp, kSize);
}
|
||||
|
||||
// int8 -> fp16 dequantize with a size below the 16-element vector width,
// exercising only the scalar tail of the optimized path.
TEST_F(QuantCastInt8Test, Int8Fp16Size6) {
  constexpr int kSize = 6;
  int8_t quant_values[kSize];
  int32_t zp = 2;
  float scale = 0.3f;
  float16_t benchmark_data[kSize];
  ConstructInt8ToFp16Data(quant_values, benchmark_data, kSize, zp, scale);

  float16_t real_values[kSize];
  DoDequantizeInt8ToFp16(quant_values, real_values, scale, zp, kSize);

  for (int i = 0; i < kSize; ++i) {
    // Fix: compare the ABSOLUTE difference. The original one-sided check
    // `real - benchmark <= 1e-5` also passes when real is far BELOW benchmark.
    ASSERT_LE(std::fabs(static_cast<float>(real_values[i]) - static_cast<float>(benchmark_data[i])), 1e-5f);
  }
}
|
||||
|
||||
// int8 -> fp16 dequantize with size exactly one 16-element vector block
// (no scalar tail on the optimized path).
TEST_F(QuantCastInt8Test, Int8Fp16Size16) {
  constexpr int kSize = 16;
  int8_t quant_values[kSize];
  int32_t zp = 2;
  float scale = 0.3f;
  float16_t benchmark_data[kSize];
  ConstructInt8ToFp16Data(quant_values, benchmark_data, kSize, zp, scale);

  float16_t real_values[kSize];
  DoDequantizeInt8ToFp16(quant_values, real_values, scale, zp, kSize);

  for (int i = 0; i < kSize; ++i) {
    // Fix: compare the ABSOLUTE difference. The original one-sided check
    // `real - benchmark <= 1e-5` also passes when real is far BELOW benchmark.
    ASSERT_LE(std::fabs(static_cast<float>(real_values[i]) - static_cast<float>(benchmark_data[i])), 1e-5f);
  }
}
|
||||
|
||||
// int8 -> fp16 dequantize with size 18: one 16-element vector block plus a
// two-value scalar tail.
TEST_F(QuantCastInt8Test, Int8Fp16Size18) {
  constexpr int kSize = 18;
  int8_t quant_values[kSize];
  int32_t zp = 2;
  float scale = 0.3f;
  float16_t benchmark_data[kSize];
  ConstructInt8ToFp16Data(quant_values, benchmark_data, kSize, zp, scale);

  float16_t real_values[kSize];
  DoDequantizeInt8ToFp16(quant_values, real_values, scale, zp, kSize);

  for (int i = 0; i < kSize; ++i) {
    // Fix: compare the ABSOLUTE difference. The original one-sided check
    // `real - benchmark <= 1e-5` also passes when real is far BELOW benchmark.
    ASSERT_LE(std::fabs(static_cast<float>(real_values[i]) - static_cast<float>(benchmark_data[i])), 1e-5f);
  }
}
|
||||
#endif
|
||||
|
||||
TEST_F(QuantCastInt8Test, Int8Fp32Size8) {
|
||||
constexpr int kSize = 8;
|
||||
int8_t quant_values[kSize] = {-128, 1, 1, 2, 3, 5, 8, 13};
|
||||
|
|
Loading…
Reference in New Issue