full connection vec matmul opt

zhaozhenlong 2021-06-29 16:56:19 +08:00 committed by zhaozhenlong
parent 03ba1d08b3
commit 23f16b8778
11 changed files with 289 additions and 39 deletions

View File

@ -393,6 +393,91 @@ void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const floa
return;
}
#ifdef ENABLE_ARM64
// 8 x 16 block: depth unrolled by 8, 16 output columns per outer step
void VecMatmulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, int depth,
int col) {
int align_col = UP_ROUND(col, C16NUM);
int ci = 0;
for (; ci < align_col - C16NUM + 1; ci += C16NUM) {
float16x8_t acc_0 = vdupq_n_f16((float16_t)0.0);
float16x8_t acc_1 = vdupq_n_f16((float16_t)0.0);
if (bias != NULL) {
acc_0 = vld1q_f16(bias + ci);
acc_1 = vld1q_f16(bias + ci + C8NUM);
}
const float16_t *bv_base = b + ci * depth;
int di = 0;
for (; di < depth - C8NUM + 1; di += C8NUM) {
float16x8_t av = vld1q_f16(a + di);
float16x8_t bv_0[C8NUM];
float16x8_t bv_1[C8NUM];
for (int i = 0; i < C8NUM; ++i) {
bv_0[i] = vld1q_f16(bv_base);
bv_1[i] = vld1q_f16(bv_base + C8NUM);
bv_base += C16NUM;
}
for (int i = 0; i < C8NUM; ++i) {
acc_0 = vfmaq_n_f16(acc_0, bv_0[i], av[i]);
acc_1 = vfmaq_n_f16(acc_1, bv_1[i], av[i]);
}
}
if (di < depth) {
for (; di < depth; ++di) {
float16_t ai = a[di];
float16x8_t bv0 = vld1q_f16(bv_base);
float16x8_t bv1 = vld1q_f16(bv_base + C8NUM);
acc_0 = vfmaq_n_f16(acc_0, bv0, ai);
acc_1 = vfmaq_n_f16(acc_1, bv1, ai);
bv_base += C16NUM;
}
}
// only store the actual number of valid output columns
if (ci + C8NUM > col) {
int c_remain = col - ci;
for (int i = 0; i < c_remain; ++i) {
if (act_type == ActType_Relu) {
c[i] = MSMAX(acc_0[i], (float16_t)0.0);
} else if (act_type == ActType_Relu6) {
c[i] = MSMIN(MSMAX(acc_0[i], (float16_t)0.0), (float16_t)6.0);
} else {
c[i] = acc_0[i];
}
}
return;
}
if (act_type == ActType_Relu) {
acc_0 = vmaxq_f16(acc_0, vdupq_n_f16((float16_t)0.0));
}
if (act_type == ActType_Relu6) {
acc_0 = vminq_f16(vmaxq_f16(acc_0, vdupq_n_f16((float16_t)0.0)), vdupq_n_f16((float16_t)6.0));
}
vst1q_f16(c, acc_0);
if (ci + C16NUM > col) {
int c_remain = col - ci;
for (int i = 0; i < c_remain; ++i) {
if (act_type == ActType_Relu) {
c[C8NUM + i] = MSMAX(acc_1[i], (float16_t)0.0);
} else if (act_type == ActType_Relu6) {
c[C8NUM + i] = MSMIN(MSMAX(acc_1[i], (float16_t)0.0), (float16_t)6.0);
} else {
c[C8NUM + i] = acc_1[i];
}
}
return;
}
if (act_type == ActType_Relu) {
acc_1 = vmaxq_f16(acc_1, vdupq_n_f16((float16_t)0.0));
}
if (act_type == ActType_Relu6) {
acc_1 = vminq_f16(vmaxq_f16(acc_1, vdupq_n_f16((float16_t)0.0)), vdupq_n_f16((float16_t)6.0));
}
vst1q_f16(c + C8NUM, acc_1);
c += C16NUM;
}
}
#endif
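For readers not fluent in NEON intrinsics, here is a minimal scalar sketch of what the kernel above computes, assuming B has been packed by RowMajor2Row16MajorFp16Opt / RowMajor2Col16MajorFp16Opt (16-column blocks, depth-major within each block, padded to UP_ROUND(col, C16NUM)). The helper name VecMatmulFp16Ref is hypothetical and not part of the commit; it reuses the nnacl macros (C16NUM, MSMAX, MSMIN) from the surrounding file.

// Illustrative scalar reference (not part of the commit); mirrors VecMatmulFp16 above.
// B layout: element (d, c) of the depth x col matrix is expected at
//   b[(c / C16NUM) * C16NUM * depth + d * C16NUM + (c % C16NUM)].
static void VecMatmulFp16Ref(const float16_t *a, const float16_t *b, float16_t *c,
                             const float16_t *bias, int act_type, int depth, int col) {
  for (int ci = 0; ci < col; ++ci) {
    float16_t acc = (bias != NULL) ? bias[ci] : (float16_t)0.0;
    const float16_t *bv = b + (ci / C16NUM) * C16NUM * depth + (ci % C16NUM);
    for (int di = 0; di < depth; ++di) {
      acc += a[di] * bv[di * C16NUM];  // accumulate in fp16, like the NEON kernel
    }
    if (act_type == ActType_Relu) acc = MSMAX(acc, (float16_t)0.0);
    if (act_type == ActType_Relu6) acc = MSMIN(MSMAX(acc, (float16_t)0.0), (float16_t)6.0);
    c[ci] = acc;
  }
}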
#ifdef ENABLE_ARM82_A32
void MatVecMulA32Fp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
int depth, int col) {
@ -675,6 +760,23 @@ void RowMajor2Row16MajorFp16(const void *src, float16_t *dst, int row, int col,
}
}
void RowMajor2Row16MajorFp16Opt(const float16_t *src, float16_t *dst, int row, int col) {
int col_align = UP_ROUND(col, C16NUM);
for (int r = 0; r < row; r++) {
int c = 0;
for (; c < col; c++) {
int c_div16 = c / C16NUM;
int c_mod16 = c % C16NUM;
dst[c_div16 * C16NUM * row + r * C16NUM + c_mod16] = src[r * col + c];
}
for (; c < col_align; c++) {
int c_div16 = c / C16NUM;
int c_mod16 = c % C16NUM;
dst[c_div16 * C16NUM * row + r * C16NUM + c_mod16] = (float16_t)0.0;
}
}
}
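For a concrete picture of the layout this routine produces (illustrative numbers only): with row = 2 and col = 3, col_align is 16 and the output holds 1 * 16 * 2 = 32 values; src element (r = 1, c = 2) maps to dst[(2 / 16) * 16 * 2 + 1 * 16 + (2 % 16)] = dst[18], and the entries for c = 3 .. 15 in each row block are zero-filled padding.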
void RowMajor2Row12MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) {
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {

View File

@ -56,6 +56,8 @@ void MatmulFp16Neon64Opt(const float16_t *a, const float16_t *b, float16_t *c, c
void MatVecMulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
int depth, int col);
void VecMatmulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, int depth,
int col);
#elif ENABLE_ARM82_A32
void MatMul12x8A32Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
int deep, int row, int col, int stride, int write_mode);
@ -86,6 +88,8 @@ void RowMajor2Col16MajorFp16(const void *src, float16_t *dst, int row, int col,
void RowMajor2Col12MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src);
void RowMajor2Row16MajorFp16Opt(const float16_t *src, float16_t *dst, int row, int col);
void RowMajor2Row16MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src);
void RowMajor2Row12MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src);

View File

@ -18,6 +18,9 @@
#ifdef ENABLE_SSE
#include <x86intrin.h>
#endif
#ifdef ENABLE_ARM64
#include <arm_neon.h>
#endif
void RowMajor2ColMajor(const float *src_ptr, float *dst_ptr, int row, int col) {
for (int r = 0; r < row; ++r) {
for (int c = 0; c < col; ++c) {
@ -881,6 +884,100 @@ void MatVecMulFp32(const float *a, const float *b, float *c, const float *bias,
}
}
#endif
#ifdef ENABLE_ARM64
// 4 x 8 block: depth unrolled by 4, 8 output columns per outer step
void MatVecMulFp32Neon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col,
int align_col) {
int ci = 0;
for (; ci < align_col - C8NUM + 1; ci += C8NUM) {
float32x4_t acc_0;
float32x4_t acc_1;
if (bias != NULL) {
acc_0 = vld1q_f32(bias + ci);
acc_1 = vld1q_f32(bias + ci + C4NUM);
} else {
acc_0 = vdupq_n_f32(0.0f);
acc_1 = vdupq_n_f32(0.0f);
}
const float *bv_base = b + ci * depth;
int di = 0;
for (; di < depth - C4NUM + 1; di += C4NUM) {
float32x4_t av = vld1q_f32(a + di);
float32x4_t bv_00 = vld1q_f32(bv_base);
float32x4_t bv_10 = vld1q_f32(bv_base + C4NUM);
bv_base += C8NUM;
float32x4_t bv_01 = vld1q_f32(bv_base);
float32x4_t bv_11 = vld1q_f32(bv_base + C4NUM);
bv_base += C8NUM;
float32x4_t bv_02 = vld1q_f32(bv_base);
float32x4_t bv_12 = vld1q_f32(bv_base + C4NUM);
bv_base += C8NUM;
float32x4_t bv_03 = vld1q_f32(bv_base);
float32x4_t bv_13 = vld1q_f32(bv_base + C4NUM);
bv_base += C8NUM;
acc_0 = vmlaq_n_f32(acc_0, bv_00, av[0]);
acc_1 = vmlaq_n_f32(acc_1, bv_10, av[0]);
acc_0 = vmlaq_n_f32(acc_0, bv_01, av[1]);
acc_1 = vmlaq_n_f32(acc_1, bv_11, av[1]);
acc_0 = vmlaq_n_f32(acc_0, bv_02, av[2]);
acc_1 = vmlaq_n_f32(acc_1, bv_12, av[2]);
acc_0 = vmlaq_n_f32(acc_0, bv_03, av[3]);
acc_1 = vmlaq_n_f32(acc_1, bv_13, av[3]);
}
if (di < depth) {
for (; di < depth; ++di) {
float ai = a[di];
float32x4_t bv0 = vld1q_f32(bv_base);
float32x4_t bv1 = vld1q_f32(bv_base + C4NUM);
acc_0 = vmlaq_n_f32(acc_0, bv0, ai);
acc_1 = vmlaq_n_f32(acc_1, bv1, ai);
bv_base += C8NUM;
}
}
// only store the actual number of valid output columns
if (ci + C4NUM - 1 >= col) {
int c_remain = col - ci;
for (int i = 0; i < c_remain; ++i) {
if (act_type == ActType_Relu) {
c[i] = MSMAX(acc_0[i], 0.0f);
} else if (act_type == ActType_Relu6) {
c[i] = MSMIN(MSMAX(acc_0[i], 0.0f), 6.0f);
} else {
c[i] = acc_0[i];
}
}
return;
}
if (act_type == ActType_Relu) {
acc_0 = vmaxq_f32(acc_0, vdupq_n_f32(0.0f));
} else if (act_type == ActType_Relu6) {
acc_0 = vminq_f32(vmaxq_f32(acc_0, vdupq_n_f32(0.0f)), vdupq_n_f32(6.0f));
}
vst1q_f32(c, acc_0);
if (ci + C8NUM - 1 >= col) {
int c_remain = col - ci;
for (int i = 0; i < c_remain; ++i) {
if (act_type == ActType_Relu) {
c[C4NUM + i] = MSMAX(acc_1[i], 0.0f);
} else if (act_type == ActType_Relu6) {
c[C4NUM + i] = MSMIN(MSMAX(acc_1[i], 0.0f), 6.0f);
} else {
c[C4NUM + i] = acc_1[i];
}
}
return;
}
if (act_type == ActType_Relu) {
acc_1 = vmaxq_f32(acc_1, vdupq_n_f32(0.0f));
} else if (act_type == ActType_Relu6) {
acc_1 = vminq_f32(vmaxq_f32(acc_1, vdupq_n_f32(0.0f)), vdupq_n_f32(6.0f));
}
vst1q_f32(c + C4NUM, acc_1);
c += C8NUM;
}
}
#endif
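The fp32 kernel follows the same contract with 8-column blocks: the intrinsics above read element (d, c) of the depth x col matrix at b[(c / C8NUM) * C8NUM * depth + d * C8NUM + (c % C8NUM)], with columns padded up to align_col. A minimal scalar sketch under that assumption; the name MatVecMulFp32Neon64Ref is hypothetical and not part of the commit.

// Illustrative scalar reference (not part of the commit); mirrors MatVecMulFp32Neon64 above.
static void MatVecMulFp32Neon64Ref(const float *a, const float *b, float *c, const float *bias,
                                   int act_type, int depth, int col, int align_col) {
  (void)align_col;  // the padding only affects the size of the packed b buffer
  for (int ci = 0; ci < col; ++ci) {
    float acc = (bias != NULL) ? bias[ci] : 0.0f;
    const float *bv = b + (ci / C8NUM) * C8NUM * depth + (ci % C8NUM);
    for (int di = 0; di < depth; ++di) {
      acc += a[di] * bv[di * C8NUM];
    }
    if (act_type == ActType_Relu) acc = MSMAX(acc, 0.0f);
    if (act_type == ActType_Relu6) acc = MSMIN(MSMAX(acc, 0.0f), 6.0f);
    c[ci] = acc;
  }
}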
void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
int col, int stride, int out_type) {
if (out_type == OutType_Nhwc) {

View File

@ -65,6 +65,8 @@ void MatmulFloatNeon64OptRow4(const float *a, const float *b, float *c, const fl
int row, int col, size_t stride, size_t write_mode);
void MatmulFloatNeon64OptRow12(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
int row, int col, size_t stride, size_t write_mode);
void MatVecMulFp32Neon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col,
int align_col);
#elif ENABLE_ARM32
void MatmulFloatNeon32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, int stride, size_t writeNhwc, size_t WriteWino);

View File

@ -50,12 +50,11 @@ int FullconnectionFP16CPUKernel::Init() {
params_->a_transpose_ = false;
params_->b_transpose_ = true;
MatmulBaseFP16CPUKernel::InitParameter();
params_->a_const_ = (in_tensors_[0]->data_c() != nullptr);
params_->b_const_ = (in_tensors_[1]->data_c() != nullptr);
if (params_->a_const_ == true) {
InitAShape();
}
if (params_->b_const_ == true) {
InitBShape();
}

View File

@ -100,9 +100,18 @@ int MatmulBaseFP16CPUKernel::ReSize() {
free(src_b_);
src_b_ = nullptr;
}
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(params_->col_, C8NUM));
thread_stride_ = UP_DIV(UP_DIV(params_->col_, C8NUM), thread_count_) * C8NUM;
if (vec_matmul_) {
#ifdef ENABLE_ARM64
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(params_->col_, C16NUM));
thread_stride_ = UP_DIV(UP_DIV(params_->col_, C16NUM), thread_count_) * C16NUM;
#else
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(params_->col_, C8NUM));
thread_stride_ = UP_DIV(UP_DIV(params_->col_, C8NUM), thread_count_) * C8NUM;
#endif
} else {
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(params_->col_, C8NUM));
thread_stride_ = UP_DIV(UP_DIV(params_->col_, C8NUM), thread_count_) * C8NUM;
}
return RET_OK;
}
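To make the ARM64 split above concrete (numbers chosen for illustration only): with params_->col_ = 100 and op_parameter_->thread_num_ = 4, UP_DIV(100, C16NUM) = 7 blocks of 16 columns, so thread_count_ = MSMIN(4, 7) = 4 and thread_stride_ = UP_DIV(7, 4) * 16 = 32, i.e. each thread handles up to 32 output columns. The non-ARM64 branch does the same arithmetic with C8NUM.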
@ -113,7 +122,11 @@ void MatmulBaseFP16CPUKernel::ResizeParameter() {
if (vec_matmul_) {
params_->row_align_ = 1;
#ifdef ENABLE_ARM64
params_->col_align_ = UP_ROUND(params_->col_, C16NUM);
#else
params_->col_align_ = params_->col_;
#endif
} else {
params_->row_align_ = UP_ROUND(params_->row_, row_tile_);
params_->col_align_ = UP_ROUND(params_->col_, C8NUM);
@ -188,13 +201,27 @@ void MatmulBaseFP16CPUKernel::InitMatrixB(void *src_ptr, TypeId src_data_type) {
Float32ToFloat16(reinterpret_cast<float *>(src_ptr), b_pack_ptr_,
params_->batch * params_->col_ * params_->deep_);
} else {
#ifdef ENABLE_ARM64
for (auto i = 0; i < params_->batch; ++i) {
const auto *b_src = reinterpret_cast<float16_t *>(src_ptr) + i * params_->col_align_ * params_->deep_;
auto *dst = b_pack_ptr_ + i * params_->col_align_ * params_->deep_;
RowMajor2Col16MajorFp16Opt(b_src, dst, params_->col_align_, params_->deep_);
}
#else
memcpy(b_pack_ptr_, src_ptr, params_->batch * params_->col_ * params_->deep_ * sizeof(float16_t));
#endif
}
} else {
for (int i = 0; i < params_->batch; i++) {
#ifdef ENABLE_ARM64
const auto *b_src = reinterpret_cast<float16_t *>(src_ptr) + i * params_->col_align_ * params_->deep_;
auto *dst = b_pack_ptr_ + i * params_->col_align_ * params_->deep_;
RowMajor2Row16MajorFp16Opt(b_src, dst, params_->deep_, params_->col_);
#else
const int8_t *batch_src = int8_src + i * params_->deep_ * params_->col_ * lite::DataTypeSize(src_data_type);
float16_t *dst = b_pack_ptr_ + i * params_->deep_ * params_->col_;
RowMajor2ColMajorFp16(batch_src, dst, params_->deep_, params_->col_, src_data_type == kNumberTypeFloat32);
#endif
}
}
return;
@ -210,7 +237,7 @@ void MatmulBaseFP16CPUKernel::InitMatrixB(void *src_ptr, TypeId src_data_type) {
}
}
return;
}
} // namespace mindspore::kernel
int MatmulBaseFP16CPUKernel::Init() {
ResizeParameter();
@ -259,7 +286,11 @@ int MatmulBaseFP16CPUKernel::RunImpl(int task_id) {
auto c = batch_c_ptr_ + task_id * thread_stride_;
if (vec_matmul_) {
#ifdef ENABLE_ARM64
VecMatmulFp16(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc);
#else
MatVecMulFp16(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc);
#endif
} else {
MatMulFp16(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, params_->row_, cur_oc, params_->col_,
OutType_Nhwc);
@ -288,7 +319,11 @@ int MatmulBaseFP16CPUKernel::Run() {
for (int i = 0; i < params_->batch; ++i) {
if (vec_matmul_) {
batch_a_ptr_ = a_pack_ptr_ + i * params_->deep_;
#ifdef ENABLE_ARM64
batch_b_ptr_ = b_pack_ptr_ + i * params_->deep_ * params_->col_align_;
#else
batch_b_ptr_ = b_pack_ptr_ + i * params_->deep_ * params_->col_;
#endif
batch_c_ptr_ = c_ptr + i * params_->row_ * params_->col_;
} else {
batch_a_ptr_ = a_pack_ptr_ + i * params_->row_align_ * params_->deep_;

View File

@ -53,6 +53,8 @@ void MatmulFp32BaseCPUKernel::ResizeParameter() {
#ifdef ENABLE_AVX
// vector matmul col is aligned to C8NUM in avx
col_tile_ = C8NUM;
#elif defined(ENABLE_ARM64)
col_tile_ = C8NUM;
#endif
row_tile_ = 1;
}
@ -60,6 +62,9 @@ void MatmulFp32BaseCPUKernel::ResizeParameter() {
#ifdef ENABLE_AVX
// avx is aligned to col_tile_
params_->col_align_ = UP_ROUND(params_->col_, col_tile_);
#elif defined(ENABLE_ARM64)
// on arm64, align col_ to col_tile_ whether vec_matmul_ is used or not
params_->col_align_ = UP_ROUND(params_->col_, col_tile_);
#else
params_->col_align_ = vec_matmul_ ? params_->col_ : UP_ROUND(params_->col_, col_tile_);
#endif
@ -170,12 +175,16 @@ int MatmulFp32BaseCPUKernel::InitMatrixB(const float *src_ptr) {
if (params_->b_transpose_) {
#ifdef ENABLE_AVX
RowMajor2Col32Major(src_data, dst, params_->deep_, params_->col_);
#elif defined(ENABLE_ARM64)
RowMajor2Col8Major(src_data, dst, params_->col_, params_->deep_);
#else
memcpy(dst, src_data, params_->col_ * params_->deep_ * sizeof(float));
#endif
} else {
#ifdef ENABLE_AVX
RowMajor2Row32Major(src_data, dst, params_->col_, params_->deep_);
#elif defined(ENABLE_ARM64)
RowMajor2Row8Major(src_data, dst, params_->deep_, params_->col_);
#else
RowMajor2ColMajor(src_data, dst, params_->deep_, params_->col_);
#endif
@ -248,6 +257,8 @@ int MatmulFp32BaseCPUKernel::FloatRun(int task_id) {
if (vec_matmul_) {
#ifdef ENABLE_AVX
MatVecMulAvxFp32(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc, params_->col_align_);
#elif defined(ENABLE_ARM64)
MatVecMulFp32Neon64(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc, params_->col_align_);
#else
MatVecMulFp32(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc);
#endif

View File

@ -6,7 +6,7 @@ beard 2
emotion 60
gender_res_large_deploy 0.1
glasses 4
hat 1
hat 2.5
isface 1
ml_bank_detect_0312_tmp 20
ml_face_div_parsing 8
@ -24,8 +24,8 @@ mtk_detect-deeper-halfdeeper-mbv1-lastearlySSD-shortcut-400-400_nopostprocess_si
# mtk_detect-deeper-halfdeeper-mbv1-shortcut-400-400_nopostprocess_simplified: precision is 5%
detect-deeper-halfdeeper-mbv1-shortcut-400-400_nopostprocess_simplified 5.5
hiai_face_detect_rfb 4
hiai_face_isface 0.1
hiai_face_landmark 0.2
hiai_face_isface 0.2
hiai_face_landmark 0.3
hiai_face_pose_tuku 1.3
ml_hand_detection 8
ml_ocr_cn 6
@ -45,17 +45,17 @@ model_hebing_3branch 40
hiai_cv_focusShootOCRModel_07 3
hiai_cv_focusShootOCRModel_03 60
hiai_cv_focusShootOCRModel_01 14
hiai_face_hat1 1
hiai_face_hat1 1.7
hiai_cv_focusShootOCRModel_04 8
hiai_cv_focusShootOCRModel_06 13
hiai_cpu_face_hat 0.3
hiai_cpu_face_hat 1.7
hiai_video_seg 1
hiai_semantic_seg 3
hiai_human_seg 28
hiai_face_recognition_1 10
hiai_cpu_face_detect 4.5
hiai_cpu_face_attr 12
hiai_face_attr1 12
hiai_cpu_face_attr 82.3 # divided by a small number causes a big bias
hiai_face_attr1 82.3 # divided by a small number causes a big bias
# mtk_detect-mbv1-shortcut-400-400_nopostprocess_simplified: precision is 5%
mtk_detect-mbv1-shortcut-400-400_nopostprocess_simplified 5.5
mtk_detect_mbv1_640_480_nopostprocess_simplified 5
@ -79,7 +79,7 @@ hdc_contour_pose_128 0.5
hdc_emotion 0.5
hdc_fivembnet 1
hdc_isface 0.5
hdc_mobilenetface 8.5
hdc_mobilenetface 11.5 # small output causes big bias
hdc_retinaface 14
hdc_resnet 7
ml_video_edit_detect 2.5
@ -94,13 +94,13 @@ ml_video_edit_video_segment_gauss_adaptis_part1 5
# When the input range is [-1,1], the precision is poor, and the output value is very small (10e-5). If the input range is adjusted to [0,255], the precision will decrease to 15.5415%, and the rest is cumulative error.
ml_handpose 175
hdc_Face_Aesthetic_MTI_Aesthetic 0.5
ml_face_compare 5.5
ml_face_compare 8.7
ml_face_tracking 2.5
ml_face_beard 0.5
ml_face_age 3.5
ml_face_pose 1
ml_face_isface 0.5
ml_face_glasses 2.5
ml_face_glasses 3.4
# ml_segmentation_matting 26 # output value unstable
ml_segmentation_atlanta_10 5
# ml_bodymask: The difference of output node divided by a very small value leads to a large error
@ -108,13 +108,13 @@ ml_bodymask 16
ml_Hand_deploy 4
# ml_hand_3d_detection: The difference of output node divided by a very small value leads to a large error
ml_hand_3d_detection 12
ml_hand_3d_regression 3
ml_hand_3d_regression 5.4
# ml_ARengine23_bodypose: The difference of output node divided by a very small value leads to a large error
ml_ARengine23_bodypose 56
ml_ocr_bank_card_detection_inception_tmp 20
ml_ocr_bank_card_recognition_fcny 0.5
hiai_cv_aestheticsEngineModel_osp 1.6
ml_face_hat 0.5
ml_face_hat 2.2
bank_card_recognition_fcny 17
bank_card_detection_inception_tmp 12
ml_ocr_identify_card_fcny 0.5
@ -122,8 +122,8 @@ ml_ocr_identify_card_detect_tmp 2
identify_card_detect_tmp 0.5
ml_2012_ocr_detection_caffe_tmp 1
ml_2012_ocr_rec_caffe 0.5
ml_lable_model_hebing_device 2
ml_face_sex 0.5
ml_lable_model_hebing_device 3
ml_face_sex 0.6
# ml_face_mnet: The precision problem caused by cumulative error.
ml_face_mnet 12
ml_segmentation_atlanta_1 0.5

View File

@ -61,8 +61,8 @@ ml_ei_facedetection.onnx 2
#ml_video_edit_art_generate.onnx #mul operator overflows, not suitable for fp16
#ml_voice_detect.onnx #conv operator overflows, not suitable for fp16
#ml_location_lane_counter.onnx has very small values during op computation (<1e-6), which causes the precision variation
ml_location_lane_counter.onnx 7
ml_location_lane_counter0.onnx 0.5
ml_location_lane_counter.onnx 7.5
ml_location_lane_counter0.onnx 1.0
#The encoder and decoder models are used in the ml_asr scene; both have value overflow. Not suitable for fp16.
#But added for guarding process.
encoder.onnx;1;1,32,83 1262
@ -75,19 +75,19 @@ mtk_detect-mbv1-shortcut-400-400_nopostprocess_simplified_onnx.onnx 6.5
mtk_detect-deeper-halfdeeper-mbv1-lastearlySSD-shortcut-400-400_nopostprocess_simplified_onnx.onnx 2.5
mtk_detect_mbv1_640_480_nopostprocess_simplified_onnx.onnx;1;1,480,640,3 2
mtk_face_features_v2.onnx;1;1,256,192,3 0.5
mtk_face_recognition_v3.onnx 0.5
mtk_face_recognition_v3.onnx 1.1
mtk_face_recognition_v2.onnx 2.5
ml_2012_ocr_detection_tmp.onnx 0.5
Harmony_Voiceprint_resnet18.onnx;1;1,150,40,1 4.5
Harmony_Voiceprint_resnet18.onnx;1;1,150,40,1 5.5
bloom_hongmo_detection_tmp.onnx 0.5
Q_face_recognition.onnx 3
Q_face_recognition.onnx 3.2
ml_video_edit_enhance_update_tmp.onnx 0.5
Q888_face_recognition.onnx 3.5
Q888_iris_detect.onnx 0.5
ssd_mobilenet_v1_10.onnx;1;1,383,640,3 0.5
# The output from a conv in the later part contains many negative values, and the following leakyRelu pushes them very
# close to 0 (-e^-4). fp16 loses a lot of precision in this case, which affects the following computation.
Harmony_Voiceprint.onnx;1;1,200,40,1 5.5
Harmony_Voiceprint.onnx;1;1,200,40,1 21.5 # small output causes big bias
# A matmul op in the later part produces overflowed output values (>65504).
#ml_video_edit_art_generate_20210513.onnx nan
ml_asr_encoder_int8_202103.onnx;;;4 2.1

View File

@ -33,7 +33,7 @@ mtk_face_features_v1.pb 26
model_normalize_object_scene_ps_20200519.pb;1;1,224,224,3 10
hiai_AADB_HADB_MBV2_model.pb;1;1,224,224,3 6
hiai_frozen_inference_graph.pb 8
hiai_lm_inference_graph.pb 0.6
hiai_lm_inference_graph.pb 1.2
hiai_ghostnet.pb 0.9
hiai_face_model_npu.pb 0.5
hiai_cv_focusShootOCRModel_02.pb 10.5
@ -60,9 +60,9 @@ bolt_segment.pb 2
siteAI_wireless_depress_w.pb;1;1,36 0.5
siteAI_wireless_restore_w.pb;1;1,36 0.5
siteAI_trans_nonlinear.pb;1;1,137 0.5
siteAI_trans_nonlinear40g.pb;1;1,271 0.5
siteAI_trans_nonlinear40g.pb;1;1,271 0.6
siteAI_trans_nonlinear134g.pb;1;1,137 0.5
siteAI_trans_nonlinear134g_nrz.pb;1;1,182 0.5
siteAI_trans_nonlinear134g_nrz.pb;1;1,182 0.6
ml_vision_guide_detection2.pb;1;1,320,320,1 1
# ml_tts_encoder.pb has a round op, which will cause round-off error when the decimal of input value is near 0.5
ml_tts_encoder.pb;4;1:1,44:1:1 9
@ -85,4 +85,4 @@ ml_tts_decoder_control_flow.pb;5 1
ml_tts_decoder.pb;5 2.5
ml_tts_vocoder.pb;66 53
hiai_transformer_encoder.pb;15 4
decoder_step_nocumsum_v5.pb;13;1:1,512:1,1429,2:1,127:1,127:1,127:1,127,320:1,80:1,512:1,512:1,512:1,512:1,512 0.5
decoder_step_nocumsum_v5.pb;13;1:1,512:1,1429,2:1,127:1,127:1,127:1,127,320:1,80:1,512:1,512:1,512:1,512:1,512 1.2

View File

@ -75,8 +75,8 @@ mtk_model_emotions_0725_fp16.tflite 3
mtk_face_features_v1_fp16.tflite 20
siteAI_digcom_AI_ECN.tflite 0.1
siteAI_digcom_g2v_keras.tflite 5
siteAI_trans_nonlinear.tflite 0.1
siteAI_trans_tcpclassify.tflite 5
siteAI_trans_nonlinear.tflite 0.2
siteAI_trans_tcpclassify.tflite 5.3
siteAI_wireless_depress_w.tflite 8
siteAI_wireless_restore_w.tflite 0.1
magenta_arbitrary-image-stylization-v1-256_fp16_prediction_1.tflite 5
@ -123,7 +123,7 @@ lite-model_cartoongan_fp16_1.tflite 3
lite-model_arbitrary-image-stylization-inceptionv3_fp16_predict_1.tflite 6
gts_detect_5k_tf115.tflite 9.5
mtk_isface.tflite 0.2
mtk_landmark.tflite 0.1
mtk_landmark.tflite 0.3
mtk_new_detect.tflite 3
mtk_pose.tflite 2
mtk_model_emotions_0727_nosoftmax.tflite 2
@ -132,7 +132,7 @@ mtk_276landmark_0913.tflite 16
mtk_face_recognition.tflite 8
mtk_convert_model.tflite 5
smartreply.tflite 0.1
mindspore_text_classification_tflite.tflite 4
mindspore_text_classification_tflite.tflite 9.2 # small output causes big bias
#ml_location.tflite 0.1
ml_text_correction.tflite 1
# ml_pic_shopping.tflite involves subtract two close numbers.
@ -147,7 +147,7 @@ ml_ocr_jk_pb2tflite.tflite 0.5
ml_ocr_latin_pb2tflite.tflite 11.5
scan_hms_angle_pb2tflite.tflite 2.5
scan_hms_detect_pb2tflite.tflite 1.5
ml_location.tflite 0.5
ml_location.tflite 0.6
ml_face_openclose_tflite.tflite 0.5
ml_object_detect_pb2tflite.tflite 1.5
# lite-model_on_device_vision_classifier_landmarks_classifier* models' bias are caused by error accumulation and small
@ -189,7 +189,7 @@ Q_landmark.tflite 0.5
Q_new_detect.tflite 3.5
# the input of Q_object_scene model is between 0-255
Q_object_scene.tflite 3
Q_pose.tflite 1.5
Q_pose.tflite 4.1
Q_detect_fpn_add_inception-1448650.tflite 1
bloom_landmark.tflite 0.5
# input data: 0~255
@ -197,7 +197,7 @@ Q888_age_gender_orderd.tflite 1.5
Q888_face_dress_mv3y.tflite 0.5
Q888_HADB_AADB_MBV2_model_fp32.tflite 2.5
Q888_landmark.tflite 0.5
Q888_pose.tflite 5
Q888_pose.tflite 6.1
# the output contains value less than e-7
Q888_lapa158_unet_0924.tflite 19
Q888_isface.tflite 1.0
@ -219,4 +219,4 @@ hdc_tb_cn_neg.tflite;3 295
# The input of hiai_cv_labelDetectorModel_v3.tflite is between 0-255.
hiai_cv_labelDetectorModel_v3.tflite;2 2
ml_headpose_pb2tflite.tflite;3;16:1,64,64,3:16 1
ml_ei_headpose_pb2tflite.tflite;3;16:1,64,64,3:16 0.5
ml_ei_headpose_pb2tflite.tflite;3;16:1,64,64,3:16 0.6