!15503 [MS][LITE][CPU] arm32 fp16 operator optimization

From: @lzkcode
Reviewed-by: @zhang_xue_tong,@zhanghaibo5
Signed-off-by: @zhang_xue_tong
This commit is contained in:
mindspore-ci-bot 2021-04-22 17:31:48 +08:00 committed by Gitee
commit 190ba8fd85
24 changed files with 324 additions and 192 deletions
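Most of the kernel changes in this commit widen the compile guard on the NEON fp16 fast paths from ENABLE_ARM64 to ENABLE_NEON, so armv8.2 aarch32 builds take the vectorized path instead of the scalar fallback. A minimal sketch of the pattern, assuming ENABLE_NEON is defined for both arm64 and arm82_a32 builds (AddBiasFp16 is an illustrative name, not a function from this commit):

#include <arm_neon.h>

void AddBiasFp16(float16_t *dst, const float16_t *bias, int num) {
  int i = 0;
#ifdef ENABLE_NEON /* was: #ifdef ENABLE_ARM64 */
  for (; i <= num - 8; i += 8) { /* 8 fp16 lanes per 128-bit register */
    vst1q_f16(dst + i, vaddq_f16(vld1q_f16(dst + i), vld1q_f16(bias + i)));
  }
#endif
  for (; i < num; i++) { /* scalar tail */
    dst[i] += bias[i];
  }
}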

View File

@ -322,7 +322,7 @@ void DeconvDepthwiseCenterFp16(float16_t *dst, const float16_t *src, const float
float16_t *dst_kw = dst_kh;
const float16_t *weight_kw = weight_kh;
for (int kw = 0; kw < kernel_w; kw++) {
-#ifdef ENABLE_ARM64
+#ifdef ENABLE_NEON
float16x8_t src_8 = vld1q_f16(src_w);
float16x8_t weight_8 = vld1q_f16(weight_kw);
float16x8_t dst_8 = vld1q_f16(dst_kw);

View File

@ -19,108 +19,6 @@
#include "nnacl/fp16/winograd_transform_fp16.h"
#include "nnacl/fp16/matmul_fp16.h"
-#ifdef __cplusplus
-extern "C" {
-#endif
-#ifdef ENABLE_ARM64
-void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
-size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC4, size_t relu,
-size_t relu6);
-#endif
-#ifdef __cplusplus
-}
-#endif
-#ifndef ENABLE_ARM64
-void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
-size_t ic4, size_t out_channel, size_t offset, size_t mode, size_t writeC8, size_t relu,
-size_t relu6) {
-if (!(mode && writeC8)) {
-IndirectGemmFp16_16x8_common(output, input, weight, bias, step, ic4, out_channel, offset, relu, relu6);
-} else {
-IndirectGemmFp16_16x8_c8(output, input, weight, bias, step, ic4, out_channel, offset, mode, writeC8, relu, relu6);
-}
-}
-void IndirectGemmFp16_16x8_common(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
-size_t ic4, size_t out_channel, size_t offset, size_t relu, size_t relu6) {
-const int tile_n = 16;
-for (int i = 0; i < out_channel; i++) {
-int oc8_block = i / C8NUM;
-int oc8_res = i % C8NUM;
-int weight_oc_offset = oc8_block * step * ic4 * C4NUM * C8NUM + oc8_res;
-for (int k = 0; k < tile_n; k++) {
-int input_tile_offset = k * C4NUM;
-int out_tile_offset = i + k * out_channel;
-float16_t tmp_out = 0;
-for (int n = 0; n < step; n++) {
-int input_kw_offset = input_tile_offset + n * tile_n * ic4 * C4NUM;
-int weight_kw_offset = weight_oc_offset + n * ic4 * C4NUM * C8NUM;
-for (int j = 0; j < ic4; j++) {
-int input_ic4_offset = input_kw_offset + j * tile_n * C4NUM;
-int weight_ic4_offset = weight_kw_offset + j * C4NUM * C8NUM;
-for (int m = 0; m < C4NUM; m++) {
-int input_c4_offset = input_ic4_offset + m;
-int weight_c4_offset = weight_ic4_offset + m * C8NUM;
-tmp_out += (input + input_c4_offset)[0] * (weight + weight_c4_offset)[0];
-}
-}
-}
-float16_t *tmp = output + out_tile_offset;
-if (bias != NULL) {
-tmp[0] = tmp_out + bias[i];
-}
-if (relu) {
-tmp[0] = tmp[0] < 0 ? 0 : tmp[0];
-} else if (relu6) {
-tmp[0] = tmp[0] < 0 ? 0 : tmp[0];
-tmp[0] = tmp[0] > 6 ? 6 : tmp[0];
-}
-}
-}
-}
-void IndirectGemmFp16_16x8_c8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
-size_t ic4, size_t output_channel, size_t offset, size_t mode, size_t writeC8,
-size_t relu, size_t relu6) {
-const int tile_num = 16;
-if (mode && writeC8) {
-for (int i = 0; i < tile_num; i++) {
-int input_tile_offset = i * C4NUM;
-int output_tile_offset = i * output_channel * step;
-for (int j = 0; j < output_channel; j++) {
-int oc8_block = j / C8NUM;
-int oc8_res = j % C8NUM;
-int weight_oc_offset = oc8_block * step * ic4 * C4NUM * C8NUM + oc8_res;
-int out_oc_offset = output_tile_offset + oc8_block * step * C8NUM + oc8_res;
-for (int n = 0; n < step; n++) {
-int input_kw_offset = input_tile_offset + n * ic4 * C4NUM * tile_num;
-int weight_kw_offset = weight_oc_offset + n * ic4 * C4NUM * C8NUM;
-int output_kw_offset = out_oc_offset + n * C8NUM;
-float16_t acc = 0;
-for (int k = 0; k < ic4; k++) {
-int input_ic4_offset = input_kw_offset + k * tile_num * C4NUM;
-int weight_ic4_offset = weight_kw_offset + k * C4NUM * C8NUM;
-for (int m = 0; m < C4NUM; m++) {
-int input_ic_offset = input_ic4_offset + m;
-int weight_ic_offset = weight_ic4_offset + m * C8NUM;
-acc += (weight + weight_ic_offset)[0] * (input + input_ic_offset)[0];
-}
-}
-(output + output_kw_offset)[0] = acc;
-}
-}
-}
-} else {
-}
-}
-#endif
// fp16 convolution common (im2col+gemm)
void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data,
float16_t *col_major_input, float16_t *output_data, int task_id, ConvParameter *conv_param) {
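For reference, the fallback deleted above walks weights packed as [oc8 block][kernel step][ic4 block][input lane][output lane]; its offset arithmetic, restated in one expression (i, n, j, m are the loop indices from the removed code):

/* offset of the weight for output channel i, kernel step n, input group j, lane m */
int oc8_block = i / C8NUM, oc8_res = i % C8NUM;
int w_off = oc8_block * step * ic4 * C4NUM * C8NUM  /* oc8 block   */
          + n * ic4 * C4NUM * C8NUM                 /* kernel step */
          + j * C4NUM * C8NUM                       /* ic4 block   */
          + m * C8NUM                               /* input lane  */
          + oc8_res;                                /* output lane */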

View File

@ -24,19 +24,6 @@
typedef float16_t *TmpBufferAddressFp16;
typedef float16_t *MatricesFp16;
-#ifndef ENABLE_ARM64
-void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
-size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC8, size_t relu,
-size_t relu6);
-void IndirectGemmFp16_16x8_common(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
-size_t ic4, size_t oc8, size_t offset, size_t relu, size_t relu6);
-void IndirectGemmFp16_16x8_c8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
-size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC8, size_t relu,
-size_t relu6);
-#endif
#ifdef __cplusplus
extern "C" {
#endif

View File

@ -21,7 +21,7 @@
void ExpFp16(const float16_t *src, float16_t *dst, int num) {
int i = 0;
-#ifdef ENABLE_ARM64
+#ifdef ENABLE_NEON
int count = (num / C8NUM) * C8NUM;
for (; i < count; i += C8NUM) {
simd_exp_fp16(vld1q_f16(src + i), dst + i);

View File

@ -25,7 +25,7 @@ extern "C" {
#endif
void ExpFp16(const float16_t *src, float16_t *dst, int num);
-#if defined(ENABLE_ARM64)
+#if defined(ENABLE_NEON)
static inline float32x4_t exp_fp32(float32x4_t input) {
static float32x4_t param[] = {{0.693147f, 0.693147f, 0.693147f, 0.693147f},
{1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120},
@ -49,7 +49,7 @@ static inline void simd_exp_fp16(float16x8_t input, float16_t *dst) {
input = vmaxq_f16(minv, vminq_f16(input, maxv));
float32x4_t input_low = vcvt_f32_f16(vget_low_f16(input));
-float32x4_t input_high = vcvt_high_f32_f16(input);
+float32x4_t input_high = vcvt_f32_f16(vget_high_f16(input));
vst1q_f16(dst, vcombine_f16(vcvt_f16_f32(exp_fp32(input_low)), vcvt_f16_f32(exp_fp32(input_high))));
}
#endif
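The conversion change above is needed because vcvt_high_f32_f16 is an AArch64-only intrinsic; the replacement composes two intrinsics available on both ISAs. A hedged equivalence sketch (WidenHighFp16 is an illustrative name):

#include <arm_neon.h>

/* Widen the upper four fp16 lanes to fp32. Equivalent to the AArch64-only
 * vcvt_high_f32_f16(v), but also compiles for armv8.2 aarch32. */
static inline float32x4_t WidenHighFp16(float16x8_t v) {
  return vcvt_f32_f16(vget_high_f16(v));
}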

View File

@ -43,11 +43,11 @@ int InstanceNormFp16(const float16_t *src_data, float16_t *dst_data, const float
float16x4_t sum2 = vadd_f16(vget_low_f16(srcv), vget_high_f16(srcv));
float32x4_t sum_f32 = vcvt_f32_f16(sum2);
-mean += vaddvq_f32(sum_f32);
+mean += MS_ADDVQ_F32(sum_f32);
float16x4_t square2 = vadd_f16(vget_low_f16(squarev), vget_high_f16(squarev));
float32x4_t square_f32 = vcvt_f32_f16(square2);
-square_mean += vaddvq_f32(square_f32);
+square_mean += MS_ADDVQ_F32(square_f32);
}
for (; index < param->inner_size_; index++) {
mean += src[index];

View File

@ -290,7 +290,7 @@ void MatMul16x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, cons
void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
int deep, int row, int col, int stride, int write_mode) {
-if (write_mode == OutType_Nhwc) {
+if (write_mode == OutType_Nhwc) { // common conv and matmul
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int r12div = r / 12, r12mod = r % 12;
@ -308,7 +308,7 @@ void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, cons
dst[ci] = value;
}
}
-} else if (write_mode == OutType_C8) {
+} else if (write_mode == OutType_C8) { // common deconv
int col_8 = UP_ROUND(col, C8NUM);
int row_12 = UP_ROUND(row, C12NUM);
for (int r = 0; r < row_12; r++) {
@ -328,7 +328,7 @@ void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, cons
dst[ci] = value;
}
}
-} else {
+} else { // winograd conv
for (int i = 0; i < row; ++i) {
int src_r_offset = i;
int dst_r_offset = i * col * stride;
@ -353,12 +353,14 @@ void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, cons
void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type,
int depth, int row, int col, int stride, int out_type) {
if (out_type == OutType_C8) {
+// common deconv
#ifdef ENABLE_ARM64
MatmulFp16Neon64(a, b, c, bias, (int)act_type, depth, row, col, stride, false);
#else
-MatMul12x8A32Fp16(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type);
+MatMul12x8Fp16(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type);
#endif
} else {
+// winograd conv (OutType_TileC8), common conv (OutType_Nhwc) and matmul (OutType_Nhwc)
#ifdef ENABLE_ARM64
MatmulFp16Neon64Opt(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type);
#else

View File

@ -17,7 +17,7 @@
#include "nnacl/fp16/power_fp16.h"
#include "nnacl/errorcode.h"
-#if defined(ENABLE_ARM64)
+#if defined(ENABLE_NEON)
float16x8_t OptimizedPowerSimdFp16(float16x8_t x, const void *exponent) {
int tmp = (int)(*(float16_t *)exponent);
int exp = abs(tmp);
@ -53,23 +53,23 @@ float16_t OptimizedPowerScalarFp16(float16_t x, const void *exponent) {
void PowerBroadCastFp16(const float16_t *input, const float16_t *exponent, float16_t *output, int len, float scale,
float shift) {
PowerScalarFunFp16 PowerScalarFunFp16_ = NULL;
-#if defined(ENABLE_ARM64)
+#if defined(ENABLE_NEON)
PowerSimdFunFp16 PowerSimdFunFp16_ = NULL;
#endif
if (CheckInteger(*exponent)) {
-#if defined(ENABLE_ARM64)
+#if defined(ENABLE_NEON)
PowerSimdFunFp16_ = OptimizedPowerSimdFp16;
#endif
PowerScalarFunFp16_ = OptimizedPowerScalarFp16;
} else {
-#if defined(ENABLE_ARM64)
+#if defined(ENABLE_NEON)
PowerSimdFunFp16_ = StdPowerSimdFp16;
#endif
PowerScalarFunFp16_ = StdPowerScalarFp16;
}
int i = 0;
-#ifdef ENABLE_ARM64
+#ifdef ENABLE_NEON
int len_c8 = UP_ROUND(len, C8NUM);
float16x8_t scale_8 = vmovq_n_f16(scale);
float16x8_t shift_8 = vmovq_n_f16(shift);
@ -87,7 +87,7 @@ void PowerSingleFp16(const float16_t *input, const float16_t *exponent, float16_
float shift) {
int i = 0;
PowerScalarFunFp16 PowerScalarFunFp16_ = NULL;
-#ifdef ENABLE_ARM64
+#ifdef ENABLE_NEON
int len_c8 = UP_ROUND(len, C8NUM);
float16x8_t scale_8 = vmovq_n_f16(scale);
float16x8_t shift_8 = vmovq_n_f16(shift);

View File

@ -22,7 +22,7 @@
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
#include "nnacl/power_parameter.h"
-#if defined(ENABLE_ARM64)
+#if defined(ENABLE_NEON)
typedef float16x8_t (*PowerSimdFunFp16)(float16x8_t x, const void *exponent);
#endif
typedef float16_t (*PowerScalarFunFp16)(float16_t x, const void *exponent);
@ -37,7 +37,7 @@ static inline float16_t StdPowerScalarFp16(float16_t x, const void *exponent) {
return powf(x, *(float16_t *)exponent);
}
-#if defined(ENABLE_ARM64)
+#if defined(ENABLE_NEON)
static inline float16x8_t StdPowerSimdFp16(float16x8_t x, const void *exponent) {
float16x8_t result;
result[0] = powf(x[0], *(float16_t *)exponent);

View File

@ -23,7 +23,7 @@ void Fp16ScaleInner(float16_t *in_data, float16_t *out_data, float16_t *scale, f
for (int i = 0; i < axis_size; i++) {
int axis_offset = out_offset + i * inner_size;
int in_index = 0;
-#ifdef ENABLE_ARM64
+#ifdef ENABLE_NEON
for (; in_index < inner_size - 8; in_index += 8) {
int in_offset = axis_offset + in_index;
float16x8_t data = vld1q_f16(in_data + in_offset);
@ -47,7 +47,7 @@ void Fp16ScaleAxis(float16_t *in_data, float16_t *out_data, float16_t *scale, fl
for (int out = outer_start; out < outer_end; out++) {
int out_offset = out * axis_size;
int index = 0;
-#ifdef ENABLE_ARM64
+#ifdef ENABLE_NEON
for (; index < axis_size - 8; index += 8) {
int in_offset = out_offset + index;
float16x8_t data = vld1q_f16(in_data + in_offset);
@ -80,7 +80,7 @@ void DoScaleFp16(float16_t *in_data, float16_t *out_data, float16_t *scale, floa
void Fp16ScaleInnerRelu(float16_t *in_data, float16_t *out_data, float16_t *scale, float16_t *offset, int outer_start,
int outer_end, int axis_size, int inner_size) {
-#ifdef ENABLE_ARM64
+#ifdef ENABLE_NEON
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
#endif
for (int out = outer_start; out < outer_end; out++) {
@ -88,7 +88,7 @@ void Fp16ScaleInnerRelu(float16_t *in_data, float16_t *out_data, float16_t *scal
for (int i = 0; i < axis_size; i++) {
int axis_offset = out_offset + i * inner_size;
int in_index = 0;
-#ifdef ENABLE_ARM64
+#ifdef ENABLE_NEON
for (; in_index < inner_size - 8; in_index += 8) {
int in_offset = axis_offset + in_index;
float16x8_t data = vld1q_f16(in_data + in_offset);
@ -110,13 +110,13 @@ void Fp16ScaleInnerRelu(float16_t *in_data, float16_t *out_data, float16_t *scal
void Fp16ScaleAxisRelu(float16_t *in_data, float16_t *out_data, float16_t *scale, float16_t *offset, int outer_start,
int outer_end, int axis_size) {
-#ifdef ENABLE_ARM64
+#ifdef ENABLE_NEON
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
#endif
for (int out = outer_start; out < outer_end; out++) {
int out_offset = out * axis_size;
int index = 0;
-#ifdef ENABLE_ARM64
+#ifdef ENABLE_NEON
for (; index < axis_size - 8; index += 8) {
int in_offset = out_offset + index;
float16x8_t data = vld1q_f16(in_data + in_offset);
@ -151,7 +151,7 @@ void Fp16DoScaleRelu(float16_t *in_data, float16_t *out_data, float16_t *scale,
void Fp16ScaleInnerRelu6(float16_t *in_data, float16_t *out_data, float16_t *scale, float16_t *offset, int outer_start,
int outer_end, int axis_size, int inner_size) {
-#ifdef ENABLE_ARM64
+#ifdef ENABLE_NEON
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6};
#endif
@ -160,7 +160,7 @@ void Fp16ScaleInnerRelu6(float16_t *in_data, float16_t *out_data, float16_t *sca
for (int i = 0; i < axis_size; i++) {
int axis_offset = out_offset + i * inner_size;
int in_index = 0;
-#ifdef ENABLE_ARM64
+#ifdef ENABLE_NEON
for (; in_index < inner_size - 8; in_index += 8) {
int in_offset = axis_offset + in_index;
float16x8_t data = vld1q_f16(in_data + in_offset);
@ -182,14 +182,14 @@ void Fp16ScaleInnerRelu6(float16_t *in_data, float16_t *out_data, float16_t *sca
void Fp16ScaleAxisRelu6(float16_t *in_data, float16_t *out_data, float16_t *scale, float16_t *offset, int outer_start,
int outer_end, int axis_size) {
-#ifdef ENABLE_ARM64
+#ifdef ENABLE_NEON
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6};
#endif
for (int out = outer_start; out < outer_end; out++) {
int out_offset = out * axis_size;
int index = 0;
-#ifdef ENABLE_ARM64
+#ifdef ENABLE_NEON
for (; index < axis_size - 8; index += 8) {
int in_offset = out_offset + index;
float16x8_t data = vld1q_f16(in_data + in_offset);

View File

@ -22,14 +22,14 @@ void SoftmaxNormFp16(const float16_t *src, float16_t *dst, int batch, int channe
int cur_batch_offset = 0;
for (int i = 0; i < batch; i++, cur_batch_offset += channel) {
int j = 0;
-#ifdef ENABLE_ARM64
+#ifdef ENABLE_NEON
float16x8_t max_8 = vdupq_n_f16(-FLT16_MAX);
int count = (channel / C8NUM) * C8NUM;
for (; j < count; j += C8NUM) {
float16x8_t input_8 = vld1q_f16(src + cur_batch_offset + j);
max_8 = vmaxq_f16(max_8, input_8);
}
-float16_t max = vmaxvq_f16(max_8);
+float16_t max = MS_MAXVQ_F16(max_8);
#else
float16_t max = -FLT_MAX;
#endif
@ -40,7 +40,7 @@ void SoftmaxNormFp16(const float16_t *src, float16_t *dst, int batch, int channe
}
}
int k = 0;
-#ifdef ENABLE_ARM64
+#ifdef ENABLE_NEON
int count2 = (channel / C8NUM) * C8NUM;
for (; k < count2; k += C8NUM) {
float16x8_t input_8 = vld1q_f16(src + cur_batch_offset + k);
@ -60,7 +60,7 @@ void SumAndDivFp16(const float16_t *src, float16_t *dst, int batch, int channel)
for (int i = 0; i < batch; i++, cur_batch_offset += channel) {
float16_t sum = 0.0f;
int j = 0;
-#ifdef ENABLE_ARM64
+#ifdef ENABLE_NEON
float16x8_t sum8 = vdupq_n_f16(0);
int count = (channel / C8NUM) * C8NUM;
for (; j < count; j += C8NUM) {
@ -72,7 +72,7 @@ void SumAndDivFp16(const float16_t *src, float16_t *dst, int batch, int channel)
sum += src[cur_batch_offset + j];
}
int k = 0;
-#ifdef ENABLE_ARM64
+#ifdef ENABLE_NEON
const float16_t div = 1.0f / sum;
for (; k < count; k += C8NUM) {
vst1q_f16(dst + cur_batch_offset + k, vmulq_n_f16(vld1q_f16(src + cur_batch_offset + k), div));

View File

@ -113,7 +113,7 @@ void PRelu(const float *input, float *output, float *slope, int start, int end,
const float *cur_in = input + i * channel;
float *cur_out = output + i * channel;
int j = 0;
-#if defined(ENABLE_ARM)
+#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
for (; j < channel - 3; j += 4) {
MS_FLOAT32X4 in = MS_LDQ_F32(cur_in + j);
MS_FLOAT32X4 s = MS_LDQ_F32(slope + j);

View File

@ -19,7 +19,7 @@
#include "nnacl/intrinsics/ms_simd_instructions.h"
#if defined(ENABLE_ARM82_A32)
-static inline float16x8_t divq_f16(float16x8_t in1, float16x8_t in2) {
+static inline float16x8_t ms_vdivq_f16(float16x8_t in1, float16x8_t in2) {
float16x8_t dst;
asm volatile(
"vrecpe.f16 q14, %3\n"
@ -34,7 +34,7 @@ static inline float16x8_t divq_f16(float16x8_t in1, float16x8_t in2) {
return dst;
}
-static inline float16x4_t div_f16(float16x4_t in1, float16x4_t in2) {
+static inline float16x4_t ms_vdiv_f16(float16x4_t in1, float16x4_t in2) {
float16x4_t dst;
asm volatile(
"vrecpe.f16 d14, %3\n"
@ -49,33 +49,47 @@ static inline float16x4_t div_f16(float16x4_t in1, float16x4_t in2) {
return dst;
}
-static inline float vaddvq_f32(float32x4_t in) { // is not support in arm82 aarch32
+static inline float ms_vaddvq_f32(float32x4_t in) {
+// vaddvq_f32 is not supported on arm82 aarch32 and there is no single instruction that reduces all the lanes
return in[0] + in[1] + in[2] + in[3];
}
-static inline float32x4_t cvt_f32_f16(float16x4_t in) {
+static inline float16_t ms_vmaxvq_f16(float16x8_t in) {
+// vmaxvq_f16 is not supported on arm82 aarch32 and there is no single instruction that reduces all the lanes
+float16_t dst = in[0];
+for (int i = 1; i < 8; ++i) {
+dst = dst > in[i] ? dst : in[i];
+}
+return dst;
+}
+static inline float32x4_t ms_vcvt_f32_f16(float16x4_t in) {
float32x4_t dst;
asm volatile("vcvt.f32.f16 %0, %2\n" : "=w"(dst) : "0"(dst), "w"(in) :);
return dst;
}
-static inline float16x4_t cvt_f16_f32(float32x4_t in) {
+static inline float16x4_t ms_vcvt_f16_f32(float32x4_t in) {
float16x4_t dst;
asm volatile("vcvt.f16.f32 %0, %2\n" : "=w"(dst) : "0"(dst), "w"(in) :);
return dst;
}
-#define MS_CVT_F32_F16(src) cvt_f32_f16(src)
-#define MS_CVT_F16_F32(src) cvt_f16_f32(src)
-#define MS_DIV_F16(src1, src2) div_f16(src1, src2)
-#define MS_DIVQ_F16(src1, src2) divq_f16(src1, src2)
+#define MS_CVT_F32_F16(src) ms_vcvt_f32_f16(src)
+#define MS_CVT_F16_F32(src) ms_vcvt_f16_f32(src)
+#define MS_DIV_F16(src1, src2) ms_vdiv_f16(src1, src2)
+#define MS_DIVQ_F16(src1, src2) ms_vdivq_f16(src1, src2)
+#define MS_FMAQ_N_F16(src1, src2, src3) vfmaq_f16(src1, src2, vdupq_n_f16(src3))
+#define MS_MAXVQ_F16(src) ms_vmaxvq_f16(src)
+#define MS_ADDVQ_F32(src) ms_vaddvq_f32(src)
#else
#define MS_CVT_F32_F16(src) vcvt_f32_f16(src)
#define MS_CVT_F16_F32(src) vcvt_f16_f32(src)
#define MS_DIV_F16(src1, src2) vdiv_f16(src1, src2)
#define MS_DIVQ_F16(src1, src2) vdivq_f16(src1, src2)
+#define MS_FMAQ_N_F16(src1, src2, src3) vfmaq_n_f16(src1, src2, src3)
+#define MS_MAXVQ_F16(src) vmaxvq_f16(src)
+#define MS_ADDVQ_F32(src) vaddvq_f32(src)
#endif
static inline float16x8_t MS_TANHX8_F16(float16x8_t src) {
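With these wrappers, a kernel can be written once against MS_MAXVQ_F16 / MS_ADDVQ_F32 and compile on both arm64 (native vmaxvq_f16 / vaddvq_f32) and arm82 aarch32 (the emulations defined above). A usage sketch (MaxFp16 is an illustrative helper, not part of this commit):

static float16_t MaxFp16(const float16_t *src, int num) {
  float16x8_t max8 = vdupq_n_f16(-65504.0f); /* most negative finite fp16 */
  int i = 0;
  for (; i <= num - 8; i += 8) {
    max8 = vmaxq_f16(max8, vld1q_f16(src + i));
  }
  float16_t max = MS_MAXVQ_F16(max8); /* horizontal max over 8 lanes */
  for (; i < num; i++) {
    max = max > src[i] ? max : src[i];
  }
  return max;
}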

View File

@ -7,7 +7,7 @@ endif()
if(PLATFORM_ARM32 AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
set(ENABLE_FP16 "off")
-message(WARNING "If you want to build fp16 in arm82_a32, \
+message(STATUS "If you want to build fp16 in arm82_a32, \
your Clang version:[${CMAKE_CXX_COMPILER_VERSION}] must not be less than 9.0 and please use android ndk r21e!")
endif()

View File

@ -42,6 +42,11 @@ int FullconnectionFP16CPUKernel::ReSize() {
}
int FullconnectionFP16CPUKernel::Init() {
+#ifdef ENABLE_ARM64
+row_tile_ = C16NUM;
+#else
+row_tile_ = C12NUM;
+#endif
params_->batch = 1;
params_->a_transpose_ = false;
params_->b_transpose_ = true;

View File

@ -114,12 +114,7 @@ void MatmulBaseFP16CPUKernel::ResizeParameter() {
params_->row_align_ = 1;
params_->col_align_ = params_->col_;
} else {
-#ifdef ENABLE_ARM64
-int row_tile = C16NUM;
-#else
-int row_tile = C12NUM;
-#endif
-params_->row_align_ = UP_ROUND(params_->row_, row_tile);
+params_->row_align_ = UP_ROUND(params_->row_, row_tile_);
params_->col_align_ = UP_ROUND(params_->col_, C8NUM);
}
return;
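ResizeParameter now reads the tile size from the row_tile_ member, which each kernel's Init() sets per platform (see the fullconnection and matmul diffs), instead of recomputing it locally. The resulting alignment logic, as a sketch:

/* arm64 assembly matmul packs 16 rows per tile; the arm32/C path packs 12. */
#ifdef ENABLE_ARM64
row_tile_ = C16NUM;
#else
row_tile_ = C12NUM;
#endif
params_->row_align_ = UP_ROUND(params_->row_, row_tile_); /* round rows up to whole tiles */
params_->col_align_ = UP_ROUND(params_->col_, C8NUM);     /* columns stay 8-aligned */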

View File

@ -55,6 +55,7 @@ class MatmulBaseFP16CPUKernel : public LiteKernel {
protected:
MatMulParameter *params_ = nullptr;
+int row_tile_ = 0;
private:
int thread_stride_ = 0;

View File

@ -36,7 +36,7 @@ void MatmulFP16CPUKernel::InitAShape() {
params_->batch = batch;
params_->row_ = params_->a_transpose_ ? a_shape[a_shape.size() - 1] : a_shape[a_shape.size() - 2];
params_->deep_ = params_->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1];
-params_->row_16_ = UP_ROUND(params_->row_, C16NUM);
+params_->row_16_ = UP_ROUND(params_->row_, row_tile_);
}
void MatmulFP16CPUKernel::InitBShape() {
@ -55,6 +55,11 @@ void MatmulFP16CPUKernel::InitBShape() {
}
int MatmulFP16CPUKernel::Init() {
+#ifdef ENABLE_ARM64
+row_tile_ = C16NUM;
+#else
+row_tile_ = C12NUM;
+#endif
MatmulBaseFP16CPUKernel::InitParameter();
if (params_->a_const_) {

View File

@ -1,3 +1,8 @@
+# [first column]: model_name. If you need an input shape, append it after the model name with ';'.
+# [second column]: accuracy limit on arm64
+# [third column]: accuracy limit on armv82_a32
+# Columns are separated by spaces, and comments must be on their own line!
+# A missing third column means the model does not need to be maintained on armv82_a32.
age_medium 6
beard 2
emotion 60
@ -68,7 +73,8 @@ PoseNet_dla_17_x512_tmp 5
ml_location_scene_division 8
ml_tabel_recog 0.1
ml_text_division 12
-ml_video_edit_Mnet 11 # Further analysis in the future
+# Further analysis in the future for model ml_video_edit_Mnet
+ml_video_edit_Mnet 11
ml_video_edit_hairSeg_have_imageProcessLayer_interpTo145 0.5
hdc_age_medium 6
hdc_contour_pose_128 0.5
@ -100,13 +106,13 @@ ml_face_glasses 2.5
# ml_segmentation_matting 26 # output value unstable
ml_segmentation_atlanta_10 5
# ml_bodymask: The difference of output node divided by a very small value leads to a large error
-ml_bodymask 14
-ml_Hand_deploy 4
+ml_bodymask 14 13
+ml_Hand_deploy 4 4
# ml_hand_3d_detection: The difference of output node divided by a very small value leads to a large error
-ml_hand_3d_detection 12
-ml_hand_3d_regression 3
+ml_hand_3d_detection 12 10
+ml_hand_3d_regression 3 4
# ml_ARengine23_bodypose: The difference of output node divided by a very small value leads to a large error
-ml_ARengine23_bodypose 56
+ml_ARengine23_bodypose 56 58
ml_ocr_bank_card_detection_inception_tmp 20
ml_ocr_bank_card_recognition_fcny 0.5
hiai_cv_aestheticsEngineModel_osp 1.5

View File

@ -1,3 +1,8 @@
+# [first column]: model_name. If you need an input shape, append it after the model name with ';'.
+# [second column]: accuracy limit on arm64
+# [third column]: accuracy limit on armv82_a32
+# Columns are separated by spaces, and comments must be on their own line!
+# A missing third column means the model does not need to be maintained on armv82_a32.
mtk_detect-mbv2-shortcut-400-400-simplified.onnx 4
mtk_face_features_v3.onnx 20
emotion-ferplus-8.onnx 1
@ -37,7 +42,8 @@ residual_distill_res34_cifar10_bs_1_update.onnx 2
residual_distill_res50_cifar10_bs_1_update.onnx 2
#ml_voice_detect.onnx #out of float16 range because power op
hdc_ocr_attention.onnx 1.6
-hdc_ocr_detect_tmp.onnx 30 #one of the output has small values
+# one of the outputs has small values in model hdc_ocr_detect_tmp.onnx
+hdc_ocr_detect_tmp.onnx 30
ml_edu_kit_hand_detection.onnx 2
ml_edu_kit_hand_key_position.onnx 2
ml_video_edit_judge.onnx 12

View File

@ -1,3 +1,8 @@
+# [first column]: model_name. If you need an input shape, append it after the model name with ';'.
+# [second column]: accuracy limit on arm64
+# [third column]: accuracy limit on armv82_a32
+# Columns are separated by spaces, and comments must be on their own line!
+# A missing third column means the model does not need to be maintained on armv82_a32.
ml_vision_guide_detection1.pb 0.5
ml_vision_guide_detection3.pb 0.5
ml_video_edit_generate_filter.pb 2

View File

@ -1,3 +1,8 @@
+# [first column]: model_name. If you need an input shape, append it after the model name with ';'.
+# [second column]: accuracy limit on arm64
+# [third column]: accuracy limit on armv82_a32
+# Columns are separated by spaces, and comments must be on their own line!
+# A missing third column means the model does not need to be maintained on armv82_a32.
hiai_model_0909_kd_rot_ps_softmax.tflite 10
hiai_chinese_english_recognize_model_float32.tflite 13
hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite.tflite 10

View File

@ -1,3 +1,8 @@
+# [first column]: model_name;input_bin_number;input_shape (input_bin_number and input_shape may be omitted.)
+# [second column]: accuracy limit on arm64
+# [third column]: accuracy limit on armv82_a32
+# Columns are separated by spaces.
+# A missing third column means the model does not need to be maintained on armv82_a32.
ml_video_edit_video_segment_gauss_adaptis_part2_pb2tflite.tflite;2 11
ml_video_edit_video_segment_gauss_adaptis_part2.pb;2 11
ml_video_edit_img_segment_adaptise.pb;2 40
@ -19,5 +24,5 @@ ml_tts_decoder.pb;5 2.5
hiai_cv_labelDetectorModel_v3.tflite;2 2
ml_tts_vocoder.pb;66 53
# The outputs of two Heatmap_depth models have small value
-ml_Heatmap_depth_240180;2 102
-ml_Heatmap_depth_180240;2 101
+ml_Heatmap_depth_240180;2 10 16
+ml_Heatmap_depth_180240;2 7 7
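For example, the entry "ml_Heatmap_depth_240180;2 10 16" names a model with two input bins, an arm64 accuracy limit of 10, and an armv82_a32 accuracy limit of 16.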

mindspore/lite/test/run_benchmark_nets.sh Executable file → Normal file
View File

@ -1900,6 +1900,181 @@ function Run_arm64_fp16() {
fi
done < ${models_multiple_inputs_fp16_config}
}
+# Run on armv8.2-a32-fp16 platform:
+function Run_armv82_a32_fp16() {
+cd ${armv82_path} || exit 1
+tar -zxf mindspore-lite-${version}-inference-android-aarch32.tar.gz || exit 1
+# If built with minddata, copy the minddata-related libs
+cd ${benchmark_test_path} || exit 1
+if [ -f ${armv82_path}/mindspore-lite-${version}-inference-android-aarch32/inference/minddata/lib/libminddata-lite.so ]; then
+cp -a ${armv82_path}/mindspore-lite-${version}-inference-android-aarch32/inference/minddata/lib/libminddata-lite.so ${benchmark_test_path}/libminddata-lite.so || exit 1
+fi
+cp -a ${armv82_path}/mindspore-lite-${version}-inference-android-aarch32/inference/lib/libmindspore-lite.so ${benchmark_test_path}/libmindspore-lite.so || exit 1
+cp -a ${armv82_path}/mindspore-lite-${version}-inference-android-aarch32/tools/benchmark/benchmark ${benchmark_test_path}/benchmark || exit 1
+# adb push all needed files to the phone
+adb -s ${device_id} push ${benchmark_test_path} /data/local/tmp/ > adb_push_log.txt
+# run adb, run session, check the result:
+echo 'cd /data/local/tmp/benchmark_test' > adb_cmd.txt
+echo 'cp /data/local/tmp/arm32/libc++_shared.so ./' >> adb_cmd.txt
+echo 'chmod 777 benchmark' >> adb_cmd.txt
+adb -s ${device_id} shell < adb_cmd.txt
+# Run fp16 converted models:
+while read line; do
+fp16_line_info=${line}
+column_num=`echo ${fp16_line_info} | awk -F ' ' '{print NF}'`
+if [[ ${fp16_line_info} == \#* || ${column_num} -lt 3 ]]; then
+continue
+fi
+model_info=`echo ${fp16_line_info}|awk -F ' ' '{print $1}'`
+accuracy_limit=`echo ${fp16_line_info}|awk -F ' ' '{print $3}'`
+model_name=${model_info%%;*}
+length=${#model_name}
+input_shapes=${model_info:length+1}
+echo "---------------------------------------------------------" >> "${run_armv82_a32_fp16_log_file}"
+echo "fp16 run: ${model_name}, accuracy limit:${accuracy_limit}" >> "${run_armv82_a32_fp16_log_file}"
+echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt
+echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test' >> adb_run_cmd.txt
+if [[ $accuracy_limit == "-1" ]]; then
+echo './benchmark --modelFile='${model_name}'.fp16.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --enableFp16=true --inputShapes='${input_shapes} >> adb_run_cmd.txt
+else
+echo './benchmark --modelFile='${model_name}'.fp16.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out --enableFp16=true --accuracyThreshold='${accuracy_limit} ' --inputShapes='${input_shapes} >> adb_run_cmd.txt
+fi
+cat adb_run_cmd.txt >> "${run_armv82_a32_fp16_log_file}"
+adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_armv82_a32_fp16_log_file}"
+if [ $? = 0 ]; then
+run_result='armv82_a32_fp16: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file}
+else
+run_result='armv82_a32_fp16: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1
+fi
+done < ${models_onnx_fp16_config}
+while read line; do
+fp16_line_info=${line}
+column_num=`echo ${fp16_line_info} | awk -F ' ' '{print NF}'`
+if [[ ${fp16_line_info} == \#* || ${column_num} -lt 3 ]]; then
+continue
+fi
+model_name=`echo ${fp16_line_info}|awk -F ' ' '{print $1}'`
+accuracy_limit=`echo ${fp16_line_info}|awk -F ' ' '{print $3}'`
+echo "---------------------------------------------------------" >> "${run_armv82_a32_fp16_log_file}"
+echo "fp16 run: ${model_name}, accuracy limit:${accuracy_limit}" >> "${run_armv82_a32_fp16_log_file}"
+echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt
+echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test' >> adb_run_cmd.txt
+echo './benchmark --modelFile='${model_name}'.fp16.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out --enableFp16=true --accuracyThreshold='${accuracy_limit} >> adb_run_cmd.txt
+cat adb_run_cmd.txt >> "${run_armv82_a32_fp16_log_file}"
+adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_armv82_a32_fp16_log_file}"
+if [ $? = 0 ]; then
+run_result='armv82_a32_fp16: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file}
+else
+run_result='armv82_a32_fp16: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1
+fi
+done < ${models_caffe_fp16_config}
+while read line; do
+fp16_line_info=${line}
+column_num=`echo ${fp16_line_info} | awk -F ' ' '{print NF}'`
+if [[ ${fp16_line_info} == \#* || ${column_num} -lt 3 ]]; then
+continue
+fi
+model_name=`echo ${fp16_line_info}|awk -F ' ' '{print $1}'`
+accuracy_limit=`echo ${fp16_line_info}|awk -F ' ' '{print $3}'`
+echo "---------------------------------------------------------" >> "${run_armv82_a32_fp16_log_file}"
+echo "fp16 run: ${model_name}, accuracy limit:${accuracy_limit}" >> "${run_armv82_a32_fp16_log_file}"
+echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt
+echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test' >> adb_run_cmd.txt
+echo './benchmark --modelFile='${model_name}'.fp16.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out --enableFp16=true --accuracyThreshold='${accuracy_limit} >> adb_run_cmd.txt
+cat adb_run_cmd.txt >> "${run_armv82_a32_fp16_log_file}"
+adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_armv82_a32_fp16_log_file}"
+if [ $? = 0 ]; then
+run_result='armv82_a32_fp16: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file}
+else
+run_result='armv82_a32_fp16: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1
+fi
+done < ${models_tflite_fp16_config}
+# Run fp16 converted models:
+while read line; do
+fp16_line_info=${line}
+column_num=`echo ${fp16_line_info} | awk -F ' ' '{print NF}'`
+if [[ ${fp16_line_info} == \#* || ${column_num} -lt 3 ]]; then
+continue
+fi
+model_info=`echo ${fp16_line_info}|awk -F ' ' '{print $1}'`
+accuracy_limit=`echo ${fp16_line_info}|awk -F ' ' '{print $3}'`
+model_name=${model_info%%;*}
+length=${#model_name}
+input_shapes=${model_info:length+1}
+echo "---------------------------------------------------------" >> "${run_armv82_a32_fp16_log_file}"
+echo "fp16 run: ${model_name}, accuracy limit:${accuracy_limit}" >> "${run_armv82_a32_fp16_log_file}"
+echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt
+echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test' >> adb_run_cmd.txt
+if [[ $accuracy_limit == "-1" ]]; then
+echo './benchmark --modelFile='${model_name}'.fp16.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --enableFp16=true --inputShapes='${input_shapes} >> adb_run_cmd.txt
+else
+echo './benchmark --modelFile='${model_name}'.fp16.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out --enableFp16=true --accuracyThreshold='${accuracy_limit} ' --inputShapes='${input_shapes} >> adb_run_cmd.txt
+fi
+cat adb_run_cmd.txt >> "${run_armv82_a32_fp16_log_file}"
+adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_armv82_a32_fp16_log_file}"
+if [ $? = 0 ]; then
+run_result='armv82_a32_fp16: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file}
+else
+run_result='armv82_a32_fp16: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1
+fi
+done < ${models_tf_fp16_config}
+# Run converted models which have multiple inputs in fp16 mode:
+while read line; do
+fp16_line_info=${line}
+column_num=`echo ${fp16_line_info} | awk -F ' ' '{print NF}'`
+if [[ ${fp16_line_info} == \#* || ${column_num} -lt 3 ]]; then
+continue
+fi
+model_info=`echo ${fp16_line_info}|awk -F ' ' '{print $1}'`
+accuracy_limit=`echo ${fp16_line_info}|awk -F ' ' '{print $3}'`
+model_name=`echo ${model_info}|awk -F ';' '{print $1}'`
+input_num=`echo ${model_info} | awk -F ';' '{print $2}'`
+input_shapes=`echo ${model_info} | awk -F ';' '{print $3}'`
+input_files=''
+output_file=''
+data_path="/data/local/tmp/input_output/"
+for i in $(seq 1 $input_num)
+do
+input_files=$input_files${data_path}'input/'$model_name'.ms.bin_'$i','
+done
+output_file=${data_path}'output/'${model_name}'.ms.out'
+if [[ ${model_name##*.} == "caffemodel" ]]; then
+model_name=${model_name%.*}
+fi
+echo "---------------------------------------------------------" >> "${run_armv82_a32_fp16_log_file}"
+echo "fp16 run: ${model_name}, accuracy limit:${accuracy_limit}" >> "${run_armv82_a32_fp16_log_file}"
+echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt
+echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test' >> adb_run_cmd.txt
+echo './benchmark --modelFile='${model_name}'.fp16.ms --inDataFile='${input_files}' --inputShapes='${input_shapes}' --benchmarkDataFile='${output_file} '--enableFp16=true --accuracyThreshold='${accuracy_limit} >> adb_run_cmd.txt
+cat adb_run_cmd.txt >> "${run_armv82_a32_fp16_log_file}"
+adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_armv82_a32_fp16_log_file}"
+if [ $? = 0 ]; then
+run_result='armv82_a32_fp16: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file}
+else
+run_result='armv82_a32_fp16: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1
+fi
+done < ${models_multiple_inputs_fp16_config}
+}
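The new function mirrors Run_arm64_fp16: it unpacks the aarch32 release package, pushes the benchmark binary and libraries to the device over adb, then walks the onnx, caffe, tflite, tf, and multiple-input fp16 model lists, skipping comment lines and any entry without a third (armv82_a32) accuracy column.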
# Run on gpu platform:
function Run_gpu() {
cd ${arm64_path} || exit 1
@ -2249,7 +2424,7 @@ fi
# Write benchmark result to temp file
run_benchmark_result_file=${basepath}/run_benchmark_result.txt
echo ' ' > ${run_benchmark_result_file}
-run_x86_log_file
+run_x86_log_file=${basepath}/run_x86_log.txt
echo 'run x86 logs: ' > ${run_x86_log_file}
@ -2271,6 +2446,9 @@ echo 'run arm64_fp32 logs: ' > ${run_arm64_fp32_log_file}
run_arm64_fp16_log_file=${basepath}/run_arm64_fp16_log.txt
echo 'run arm64_fp16 logs: ' > ${run_arm64_fp16_log_file}
+run_armv82_a32_fp16_log_file=${basepath}/run_armv82_a32_fp16_log.txt
+echo 'run arm82_a32_fp16 logs: ' > ${run_armv82_a32_fp16_log_file}
run_arm32_log_file=${basepath}/run_arm32_log.txt
echo 'run arm32 logs: ' > ${run_arm32_log_file}
@ -2331,6 +2509,33 @@ if [[ $backend == "all" || $backend == "x86-all" || $backend == "x86-codegen" ]]
sleep 1
fi
+if [[ $backend == "all" || $backend == "arm_cpu" || $backend == "arm32_fp16" ]]; then
+# Run on armv82-a32-fp16
+armv82_path=${release_path}/android_aarch32
+file_name=$(ls ${armv82_path}/*inference-android-aarch32.tar.gz)
+IFS="-" read -r -a file_name_array <<< "$file_name"
+version=${file_name_array[2]}
+echo "start Run armv82-a32-fp16 ..."
+Run_armv82_a32_fp16
+Run_armv82_a32_fp16_status=$?
+sleep 1
+fi
+if [[ $backend == "all" || $backend == "arm_cpu" || $backend == "arm32_fp32" ]]; then
+# Run on arm32
+arm32_path=${release_path}/android_aarch32
+# mv ${arm32_path}/*train-android-aarch32* ./train
+file_name=$(ls ${arm32_path}/*inference-android-aarch32.tar.gz)
+IFS="-" read -r -a file_name_array <<< "$file_name"
+version=${file_name_array[2]}
+echo "start Run arm32 ..."
+Run_arm32
+Run_arm32_status=$?
+sleep 1
+fi
if [[ $backend == "all" || $backend == "arm_cpu" || $backend == "arm64_fp32" ]]; then
# Run on arm64
arm64_path=${release_path}/android_aarch64
@ -2359,20 +2564,6 @@ if [[ $backend == "all" || $backend == "arm_cpu" || $backend == "arm64_fp16" ]];
sleep 1
fi
-if [[ $backend == "all" || $backend == "arm_cpu" || $backend == "arm32" ]]; then
-# Run on arm32
-arm32_path=${release_path}/android_aarch32
-# mv ${arm32_path}/*train-android-aarch32* ./train
-file_name=$(ls ${arm32_path}/*inference-android-aarch32.tar.gz)
-IFS="-" read -r -a file_name_array <<< "$file_name"
-version=${file_name_array[2]}
-echo "start Run arm32 ..."
-Run_arm32
-Run_arm32_status=$?
-sleep 1
-fi
if [[ $backend == "all" || $backend == "gpu_npu" || $backend == "gpu" ]]; then
# Run on gpu
arm64_path=${release_path}/android_aarch64
@ -2468,7 +2659,14 @@ if [[ $backend == "all" || $backend == "arm_cpu" || $backend == "arm64_fp16" ]];
isFailed=1
fi
fi
-if [[ $backend == "all" || $backend == "arm_cpu" || $backend == "arm32" ]]; then
+if [[ $backend == "all" || $backend == "arm_cpu" || $backend == "arm32_fp16" ]]; then
+if [[ ${Run_armv82_a32_fp16_status} != 0 ]];then
+echo "Run_armv82_a32_fp16 failed"
+cat ${run_armv82_a32_fp16_log_file}
+isFailed=1
+fi
+fi
+if [[ $backend == "all" || $backend == "arm_cpu" || $backend == "arm32_fp32" ]]; then
if [[ ${Run_arm32_status} != 0 ]];then
echo "Run_arm32 failed"
cat ${run_arm32_log_file}
@ -2490,7 +2688,7 @@ if [[ $backend == "all" || $backend == "gpu_npu" || $backend == "npu" ]]; then
fi
fi
-echo "Run_x86 and Run_x86_sse and Run_arm64_fp32 and Run_arm64_fp16 and Run_arm32 and Run_gpu and Run_npu is ended"
+echo "Run_x86 and Run_x86_sse and Run_x86_avx and Run_arm64_fp32 and Run_arm64_fp16 and Run_arm32_fp32 and Run_armv82_a32_fp16 and Run_gpu and Run_npu are ended"
Print_Benchmark_Result
if [[ $isFailed == 1 ]]; then
exit 1