!5794 rewrite fp16 to fp32

Merge pull request !5794 from cjh9368/static_check_r0.7
This commit is contained in:
mindspore-ci-bot 2020-09-05 13:56:50 +08:00 committed by Gitee
commit 5de9578abb
2 changed files with 77 additions and 50 deletions

View File

@ -129,68 +129,95 @@ union float32_bits {
}; };
typedef union float32_bits float32_bits; typedef union float32_bits float32_bits;
float ShortToFloat32(uint16_t srcValue) { float ShortToFloat32(uint16_t src_value) {
const float32_bits magic = {113 << 23}; const float32_bits magic = {113 << 23};
const unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift const unsigned int shifted_exp = 0x7c00 << 13;
float32_bits o; float32_bits o;
o.u = (srcValue & 0x7fff) << 13; // exponent/mantissa bits o.u = (src_value & 0x7fff) << 13;
unsigned int exp = shifted_exp & o.u; // just the exponent unsigned int exp = shifted_exp & o.u;
o.u += (127 - 15) << 23; // exponent adjust o.u += (127 - 15) << 23;
// handle exponent special cases if (exp == shifted_exp) {
if (exp == shifted_exp) { // Inf/NaN? o.u += (128 - 16) << 23;
o.u += (128 - 16) << 23; // extra exp adjust } else if (exp == 0) {
} else if (exp == 0) { // Zero/Denormal? o.u += 1 << 23;
o.u += 1 << 23; // extra exp adjust o.f -= magic.f;
o.f -= magic.f; // renormalize
} }
o.u |= (srcValue & 0x8000) << 16; // sign bit o.u |= (src_value & 0x8000) << 16;
return o.f; return o.f;
} }
uint16_t Float32ToShort(float srcValue) { static const unsigned int FP32_BIT_SIZE = 32;
float32_bits f; static const unsigned int FP32_EXPONENT_BIAS = 127;
f.f = srcValue; static const unsigned int FP32_SIGNIFICAND = 23;
const float32_bits f32infty = {255 << 23}; static const unsigned int FP32_EXPONENT_MAX = 255;
const float32_bits f16max = {(127 + 16) << 23};
const float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23};
unsigned int sign_mask = 0x80000000u;
uint16_t o;
unsigned int sign = f.u & sign_mask; static const unsigned int FP16_BIT_SIZE = 16;
f.u ^= sign; static const unsigned int FP16_EXPONENT_BIAS = 15;
static const unsigned int FP16_SIGNIFICAND = 10;
// NOTE all the integer compares in this function can be safely static const int FP16_EXPONENT_MAX = 30;
// compiled into signed compares since all operands are below static const int FP16_EXPONENT_MIN = -10;
// 0x80000000. Important if you want fast straight SSE2 code
// (since there's no unsigned PCMPGTD).
if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set) uint16_t Float32ToShort(float src_value) {
o = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf float *psrcValue = NULL;
} else { // (De)normalized number or zero psrcValue = &src_value;
if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero unsigned int srcValueBit = (unsigned int)(*psrcValue);
// use a magic value to align our 10 mantissa bits at the bottom of unsigned int sign = srcValueBit >> (FP32_BIT_SIZE - 1);
// the float. as long as FP addition is round-to-nearest-even this unsigned int mantissa = srcValueBit & 0x007FFFFF;
// just works. // exponent
f.f += denorm_magic.f; int exp = ((srcValueBit & 0x7F800000) >> FP32_SIGNIFICAND) + FP16_EXPONENT_BIAS - FP32_EXPONENT_BIAS;
uint16_t res;
// and one integer subtract of the bias later, we have our final float! if (exp > 0 && exp < FP16_EXPONENT_MAX) {
o = (uint16_t)(f.u - denorm_magic.u); // use rte rounding mode, round the significand, combine sign, exponent and significand into a short.
res = (sign << (FP16_BIT_SIZE - 1)) | (exp << FP16_SIGNIFICAND) |
((mantissa + 0x00001000) >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND));
} else if (srcValueBit == 0) {
res = 0;
} else { } else {
unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd if (exp <= 0) {
if (exp < FP16_EXPONENT_MIN) {
// update exponent, rounding bias part 1 // value is less than min half float point
f.u += ((unsigned int)(15 - 127) << 23) + 0xfff; res = 0;
// rounding bias part 2 } else {
f.u += mant_odd; // normalized single, magnitude is less than min normal half float point.
// take the bits! mantissa = (mantissa | 0x00800000) >> (1 - exp);
o = (uint16_t)(f.u >> 13); // round to nearest
if ((mantissa & 0x00001000) > 0) {
mantissa = mantissa + 0x00002000;
}
// combine sign & mantissa (exp is zero to get denormalized number)
res = (sign << FP16_EXPONENT_BIAS) | (mantissa >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND));
}
} else if (exp == (FP32_EXPONENT_MAX - FP32_EXPONENT_BIAS + FP16_EXPONENT_BIAS)) {
if (mantissa == 0) {
// input float is infinity, return infinity half
res = (sign << FP16_EXPONENT_BIAS) | 0x7C00;
} else {
// input float is NaN, return half NaN
res = (sign << FP16_EXPONENT_BIAS) | 0x7C00 | (mantissa >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND));
}
} else {
// exp > 0, normalized single, round to nearest
if ((mantissa & 0x00001000) > 0) {
mantissa = mantissa + 0x00002000;
if ((mantissa & 0x00800000) > 0) {
mantissa = 0;
exp = exp + 1;
} }
} }
if (exp > FP16_EXPONENT_MAX) {
o |= (uint16_t)(sign >> 16); // exponent overflow - return infinity half
return o; res = (sign << FP16_EXPONENT_BIAS) | 0x7C00;
} else {
// combine sign, exp and mantissa into normalized half
res = (sign << FP16_EXPONENT_BIAS) | (exp << FP16_SIGNIFICAND) |
(mantissa >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND));
}
}
}
return res;
} }

View File

@ -37,9 +37,9 @@ void MatrixSub(const float *a_ptr, const float *b_ptr, float *dst, size_t a_stri
size_t row, size_t col); size_t row, size_t col);
void MatrixMultiAdd(float *c11, float *c12, float *c21, float *c22, float *x_ptr, size_t row, size_t col, void MatrixMultiAdd(float *c11, float *c12, float *c21, float *c22, float *x_ptr, size_t row, size_t col,
size_t c_stride, size_t x_stride); size_t c_stride, size_t x_stride);
float ShortToFloat32(uint16_t srcValue); float ShortToFloat32(uint16_t src_value);
uint16_t Float32ToShort(float srcValue); uint16_t Float32ToShort(float src_value);
#ifdef ENABLE_ARM #ifdef ENABLE_ARM
void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,