forked from mindspore-Ecosystem/mindspore
performance improvements
parent 1baebef648
commit 758e5e7ca8
@@ -20,18 +20,156 @@
 #include <arm_neon.h>
 #endif
 
-int ConvDwInputGradFp16(const float16_t *dy, const float16_t *w, float16_t *dx, int start, int count,
-                        const ConvParameter *conv_param) {
+static int ConvDwInputGrad16(const float16_t *dy, const float16_t *w, float16_t *dx, int start, int end,
+                             const ConvParameter *conv_param) {
   int in_h = conv_param->input_h_;
   int in_w = conv_param->input_w_;
   int out_w = conv_param->output_w_;
   int out_h = conv_param->output_h_;
   int out_ch = conv_param->output_channel_;
   int in_ch = conv_param->input_channel_;
   int out_spatial = conv_param->output_h_ * conv_param->output_w_;
   int k_h = conv_param->kernel_h_;
   int k_w = conv_param->kernel_w_;
   int k_spatial = k_h * k_w;
-  int end = start + count;
+  int batch = conv_param->input_batch_;
+  int in_size = in_h * in_w * in_ch;
+  int out_size = out_h * out_w * out_ch;
 
   int j = start;
+  for (; j <= (end - C16NUM); j += C16NUM) {
+    float16_t *c = dx + j;
+    const float16_t *mat_b[C16NUM];
+    for (int j_i = 0; j_i < C16NUM; j_i++) {
+      mat_b[j_i] = w + (j + j_i) * k_spatial;
+    }
+    for (int si = 0; si < out_spatial; si++) {
+      const float16_t *a = dy + j + si * out_ch;
+      int output_row = (si) / out_w;
+      int output_col = (si) % out_w;
+      int row_stride_offset = -conv_param->pad_u_ + output_row * conv_param->stride_h_;
+      int col_stride_offset = -conv_param->pad_l_ + output_col * conv_param->stride_w_;
+      for (int k = 0; k < k_spatial; k++) {
+        int kernel_row = k / k_w;
+        int kernel_col = k % k_w;
+        int input_row = kernel_row * conv_param->dilation_h_ + row_stride_offset;
+        int input_col = kernel_col * conv_param->dilation_w_ + col_stride_offset;
+        if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
+          int offset = (input_row * in_w + input_col) * in_ch;
+#ifdef ENABLE_ARM
+          float16x8_t mat_b0 = {mat_b[0][k], mat_b[1][k], mat_b[2][k], mat_b[3][k],
+                                mat_b[4][k], mat_b[5][k], mat_b[6][k], mat_b[7][k]};
+          float16x8_t mat_b1 = {mat_b[8][k],  mat_b[9][k],  mat_b[10][k], mat_b[11][k],
+                                mat_b[12][k], mat_b[13][k], mat_b[14][k], mat_b[15][k]};
+
+          for (int b = 0; b < batch; b++) {
+            int dx_offset = b * in_size + offset;
+            int dy_offset = b * out_size;
+            float16x8_t mat_c0 = vld1q_f16(c + dx_offset);
+            float16x8_t mat_a0 = vld1q_f16(a + dy_offset);
+            mat_c0 = vfmaq_f16(mat_c0, mat_b0, mat_a0);
+            vst1q_f16(c + dx_offset, mat_c0);
+
+            float16x8_t mat_c1 = vld1q_f16(c + dx_offset + 8);
+            float16x8_t mat_a1 = vld1q_f16(a + dy_offset + 8);
+            mat_c1 = vfmaq_f16(mat_c1, mat_b1, mat_a1);
+            vst1q_f16(c + dx_offset + 8, mat_c1);
+          }
+#else
+          for (int b = 0; b < batch; b++) {
+            int dx_offset = b * in_size + offset;
+            int dy_offset = b * out_size;
+            for (int j_i = 0; j_i < C16NUM; j_i++) {
+              c[dx_offset + j_i] += a[dy_offset + j_i] * mat_b[j_i][k];
+            }
+          }
+#endif
+        }
+      }
+    }
+  }
+  return j;
+}
+
+static int ConvDwInputGrad8(const float16_t *dy, const float16_t *w, float16_t *dx, int start, int end,
+                            const ConvParameter *conv_param) {
+  int in_h = conv_param->input_h_;
+  int in_w = conv_param->input_w_;
+  int out_w = conv_param->output_w_;
+  int out_h = conv_param->output_h_;
+  int out_ch = conv_param->output_channel_;
+  int in_ch = conv_param->input_channel_;
+  int out_spatial = conv_param->output_h_ * conv_param->output_w_;
+  int k_h = conv_param->kernel_h_;
+  int k_w = conv_param->kernel_w_;
+  int k_spatial = k_h * k_w;
+  int batch = conv_param->input_batch_;
+  int in_size = in_h * in_w * in_ch;
+  int out_size = out_h * out_w * out_ch;
+
+  int j = start;
+  for (; j <= (end - C8NUM); j += C8NUM) {
+    float16_t *c = dx + j;
+    const float16_t *mat_b[C8NUM];
+    for (int j_i = 0; j_i < C8NUM; j_i++) {
+      mat_b[j_i] = w + (j + j_i) * k_spatial;
+    }
+
+    for (int si = 0; si < out_spatial; si++) {
+      const float16_t *a = dy + j + si * out_ch;
+      int output_row = (si) / out_w;
+      int output_col = (si) % out_w;
+      int row_stride_offset = -conv_param->pad_u_ + output_row * conv_param->stride_h_;
+      int col_stride_offset = -conv_param->pad_l_ + output_col * conv_param->stride_w_;
+      for (int k = 0; k < k_spatial; k++) {
+        int kernel_row = k / k_w;
+        int kernel_col = k % k_w;
+        int input_row = kernel_row * conv_param->dilation_h_ + row_stride_offset;
+        int input_col = kernel_col * conv_param->dilation_w_ + col_stride_offset;
+        if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
+          int offset = (input_row * in_w + input_col) * in_ch;
+#ifdef ENABLE_ARM
+          float16x8_t mat_b0 = {mat_b[0][k], mat_b[1][k], mat_b[2][k], mat_b[3][k],
+                                mat_b[4][k], mat_b[5][k], mat_b[6][k], mat_b[7][k]};
+          for (int b = 0; b < batch; b++) {
+            int dx_offset = b * in_size + offset;
+            int dy_offset = b * out_size;
+            float16x8_t mat_c0 = vld1q_f16(c + dx_offset);
+            float16x8_t mat_a0 = vld1q_f16(a + dy_offset);
+            mat_c0 = vfmaq_f16(mat_c0, mat_b0, mat_a0);
+            vst1q_f16(c + dx_offset, mat_c0);
+          }
+#else
+          for (int b = 0; b < batch; b++) {
+            int dx_offset = b * in_size + offset;
+            int dy_offset = b * out_size;
+            for (int j_i = 0; j_i < C8NUM; j_i++) {
+              c[dx_offset + j_i] += a[dy_offset + j_i] * mat_b[j_i][k];
+            }
+          }
+#endif
+        }
+      }
+    }
+  }
+  return j;
+}
+
+static int ConvDwInputGrad4(const float16_t *dy, const float16_t *w, float16_t *dx, int start, int end,
+                            const ConvParameter *conv_param) {
+  int in_h = conv_param->input_h_;
+  int in_w = conv_param->input_w_;
+  int out_w = conv_param->output_w_;
+  int out_h = conv_param->output_h_;
+  int out_ch = conv_param->output_channel_;
+  int in_ch = conv_param->input_channel_;
+  int out_spatial = conv_param->output_h_ * conv_param->output_w_;
+  int k_h = conv_param->kernel_h_;
+  int k_w = conv_param->kernel_w_;
+  int k_spatial = k_h * k_w;
+  int batch = conv_param->input_batch_;
+  int in_size = in_h * in_w * in_ch;
+  int out_size = out_h * out_w * out_ch;
+
+  int j = start;
   for (; j <= (end - C4NUM); j += C4NUM) {
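
Editor's note on the bounds test used throughout these kernels: `(unsigned)(input_row) < (unsigned)(in_h)` collapses the two checks `input_row >= 0 && input_row < in_h` into one comparison, because casting a negative int to unsigned yields a value far larger than any valid height or width. A standalone sketch of the idiom (names are illustrative, not from this patch):

#include <stdio.h>

/* Returns 1 when 0 <= v < limit, using a single unsigned comparison. */
static int in_range(int v, int limit) { return (unsigned)v < (unsigned)limit; }

int main(void) {
  printf("%d %d %d\n", in_range(-1, 5), in_range(3, 5), in_range(5, 5)); /* prints: 0 1 0 */
  return 0;
}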
@@ -43,38 +181,65 @@ int ConvDwInputGradFp16(const float16_t *dy, const float16_t *w, float16_t *dx,
 
     for (int si = 0; si < out_spatial; si++) {
       const float16_t *a = dy + j + si * out_ch;
-#ifdef ENABLE_ARM
-      float16x4_t mat_a = vld1_f16(a);
-#else
-      float16_t mat_a[4] = {a[0], a[1], a[2], a[3]};
-#endif
       int output_row = (si) / out_w;
       int output_col = (si) % out_w;
+      int row_stride_offset = -conv_param->pad_u_ + output_row * conv_param->stride_h_;
+      int col_stride_offset = -conv_param->pad_l_ + output_col * conv_param->stride_w_;
       for (int k = 0; k < k_spatial; k++) {
-        int row_stride_offset = output_row * conv_param->stride_h_;
-        int col_stride_offset = output_col * conv_param->stride_w_;
         int kernel_row = k / k_w;
         int kernel_col = k % k_w;
-        int input_row = -conv_param->pad_u_ + kernel_row * conv_param->dilation_h_ + row_stride_offset;
-        int input_col = -conv_param->pad_l_ + kernel_col * conv_param->dilation_w_ + col_stride_offset;
+        int input_row = kernel_row * conv_param->dilation_h_ + row_stride_offset;
+        int input_col = kernel_col * conv_param->dilation_w_ + col_stride_offset;
         if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
           int offset = (input_row * in_w + input_col) * in_ch;
 #ifdef ENABLE_ARM
           float16x4_t mat_b = {mat_b_0[k], mat_b_1[k], mat_b_2[k], mat_b_3[k]};
-          float16x4_t mat_c = vld1_f16(c + offset);
-          mat_c = vfma_f16(mat_c, mat_b, mat_a);
-          vst1_f16(c + offset, mat_c);
+          for (int b = 0; b < batch; b++) {
+            int dx_offset = b * in_size + offset;
+            int dy_offset = b * out_size;
+            float16x4_t mat_c = vld1_f16(c + dx_offset);
+            float16x4_t mat_a = vld1_f16(a + dy_offset);
+            mat_c = vfma_f16(mat_c, mat_b, mat_a);
+            vst1_f16(c + dx_offset, mat_c);
+          }
 #else
-          c[offset + 0] += mat_a[0] * mat_b_0[k];
-          c[offset + 1] += mat_a[1] * mat_b_1[k];
-          c[offset + 2] += mat_a[2] * mat_b_2[k];
-          c[offset + 3] += mat_a[3] * mat_b_3[k];
+          for (int b = 0; b < batch; b++) {
+            int dx_offset = b * in_size + offset;
+            int dy_offset = b * out_size;
+            c[dx_offset + 0] += a[dy_offset + 0] * mat_b_0[k];
+            c[dx_offset + 1] += a[dy_offset + 1] * mat_b_1[k];
+            c[dx_offset + 2] += a[dy_offset + 2] * mat_b_2[k];
+            c[dx_offset + 3] += a[dy_offset + 3] * mat_b_3[k];
+          }
 #endif
         }
       }
     }
   }
+  return j;
+}
 
+int ConvDwInputGradFp16(const float16_t *dy, const float16_t *w, float16_t *dx, int start, int count,
+                        const ConvParameter *conv_param) {
+  int in_h = conv_param->input_h_;
+  int in_w = conv_param->input_w_;
+  int out_w = conv_param->output_w_;
+  int out_h = conv_param->output_h_;
+  int out_ch = conv_param->output_channel_;
+  int in_ch = conv_param->input_channel_;
+  int out_spatial = conv_param->output_h_ * conv_param->output_w_;
+  int k_h = conv_param->kernel_h_;
+  int k_w = conv_param->kernel_w_;
+  int k_spatial = k_h * k_w;
+  int end = start + count;
+  int batch = conv_param->input_batch_;
+  int in_size = in_h * in_w * in_ch;
+  int out_size = out_h * out_w * out_ch;
+
+  int j = start;
+  j = ConvDwInputGrad16(dy, w, dx, j, end, conv_param);
+  j = ConvDwInputGrad8(dy, w, dx, j, end, conv_param);
+  j = ConvDwInputGrad4(dy, w, dx, j, end, conv_param);
   for (; j < end; j++) {
     float16_t *c = dx + j;
     const float16_t *b = w + j * k_spatial;
 
@@ -82,16 +247,18 @@ int ConvDwInputGradFp16(const float16_t *dy, const float16_t *w, float16_t *dx,
       const float16_t *a = dy + j + si * out_ch;
       int output_row = si / out_w;
       int output_col = si % out_w;
-      int row_stride_offset = output_row * conv_param->stride_h_;
-      int col_stride_offset = output_col * conv_param->stride_w_;
+      int row_stride_offset = -conv_param->pad_u_ + output_row * conv_param->stride_h_;
+      int col_stride_offset = -conv_param->pad_l_ + output_col * conv_param->stride_w_;
       for (int k = 0; k < k_spatial; k++) {
         int kernel_row = k / k_w;
         int kernel_col = k % k_w;
-        int input_row = -conv_param->pad_u_ + kernel_row * conv_param->dilation_h_ + row_stride_offset;
-        int input_col = -conv_param->pad_l_ + kernel_col * conv_param->dilation_w_ + col_stride_offset;
+        int input_row = kernel_row * conv_param->dilation_h_ + row_stride_offset;
+        int input_col = kernel_col * conv_param->dilation_w_ + col_stride_offset;
         if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
           int offset = (input_row * in_w + input_col) * in_ch;
-          c[offset] += a[0] * b[k];
+          for (int bi = 0; bi < batch; bi++) {
+            c[bi * in_size + offset + 0] += a[0 + bi * out_size] * b[k];
+          }
         }
       }
     }
 
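Editor's note: the helpers and wrapper above implement a classic tail-splitting dispatch, consuming the channel range in blocks of 16, then 8, then 4, then one at a time, with each helper returning the first index it could not process. A minimal sketch of that control flow on a toy axpy kernel (hypothetical names, not part of the patch):

#include <stdio.h>

/* Consumes full blocks of four elements and returns the next unprocessed index. */
static int axpy_block4(const float *x, float *y, float a, int start, int end) {
  int i = start;
  for (; i <= end - 4; i += 4) {
    y[i + 0] += a * x[i + 0];
    y[i + 1] += a * x[i + 1];
    y[i + 2] += a * x[i + 2];
    y[i + 3] += a * x[i + 3];
  }
  return i;
}

static void axpy(const float *x, float *y, float a, int n) {
  int i = axpy_block4(x, y, a, 0, n); /* wide tiles first... */
  for (; i < n; i++) {                /* ...then the scalar tail */
    y[i] += a * x[i];
  }
}

int main(void) {
  float x[6] = {1, 1, 1, 1, 1, 1}, y[6] = {0};
  axpy(x, y, 2.0f, 6);
  for (int i = 0; i < 6; i++) printf("%g ", y[i]); /* prints: 2 2 2 2 2 2 */
  printf("\n");
  return 0;
}
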
@@ -135,8 +135,20 @@ static void Row2Col16Block16(const float16_t *src_ptr, float16_t *dst_ptr, size_
 void AddMatrixFp16(const float16_t *restrict v1, float16_t *restrict v2, float16_t beta, int row, int col, int stride) {
   const float16_t *src_ptr = v1;
   float16_t *dst_ptr = v2;
+#ifdef ENABLE_NEON
+  float16x8_t beta_0 = vdupq_n_f16(beta);
+#endif
   for (int r = 0; r < row; r++) {
-    for (int c = 0; c < col; c++) {
+    int c = 0;
+#ifdef ENABLE_NEON
+    for (; c <= (col - C8NUM); c += C8NUM) {
+      float16x8_t dst_0 = vld1q_f16(dst_ptr + c);
+      float16x8_t src_0 = vld1q_f16(src_ptr + c);
+      float16x8_t sum_0 = vfmaq_f16(dst_0, beta_0, src_0);
+      vst1q_f16(dst_ptr + c, sum_0);
+    }
+#endif
+    for (; c < col; c++) {
       dst_ptr[c] += beta * src_ptr[c];
     }
     src_ptr += stride;
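
Editor's note: AddMatrixFp16 now splits each row into a NEON main loop (eight fp16 lanes per fused multiply-add) and a scalar tail for the last col % 8 elements. A float32 analogue of the same split, guarded so it also builds without NEON (a sketch only; the vector path assumes an AArch64 target where vfmaq_f32 is available):

#include <stdio.h>
#if defined(__ARM_NEON)
#include <arm_neon.h>
#endif

/* dst[i] += beta * src[i]: vector main loop, then a scalar tail. */
static void add_scaled(const float *src, float *dst, float beta, int n) {
  int i = 0;
#if defined(__aarch64__)
  float32x4_t beta_v = vdupq_n_f32(beta);
  for (; i <= n - 4; i += 4) {
    float32x4_t d = vld1q_f32(dst + i);
    float32x4_t s = vld1q_f32(src + i);
    vst1q_f32(dst + i, vfmaq_f32(d, beta_v, s)); /* d + beta * s */
  }
#endif
  for (; i < n; i++) {
    dst[i] += beta * src[i];
  }
}

int main(void) {
  float src[5] = {1, 2, 3, 4, 5}, dst[5] = {0};
  add_scaled(src, dst, 0.5f, 5);
  for (int i = 0; i < 5; i++) printf("%g ", dst[i]); /* prints: 0.5 1 1.5 2 2.5 */
  printf("\n");
  return 0;
}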
@@ -252,7 +252,7 @@ int CheckTensorsInvalid(const std::vector<Tensor *> &tensors) {
       MS_LOG(ERROR) << "Graph input tensor data is nullptr " << tensor->tensor_name();
       return RET_ERROR;
     }
-    auto shape = tensor->shape();
+    const auto &shape = tensor->shape();
     bool valid = all_of(shape.begin(), shape.end(), [](int i) { return i >= 0; });
     if (!valid) {
       MS_LOG(ERROR) << "The shape of tensor contains negative dimension,"
@@ -260,7 +260,7 @@ int CheckTensorsInvalid(const std::vector<Tensor *> &tensors) {
       return RET_ERROR;
     }
     if (tensor->format() != mindspore::NHWC) {
-      MS_LOG(ERROR) << "model input's format mey be changed, which should keep default value NHWC";
+      MS_LOG(ERROR) << "model input's format may be changed, which should keep default value NHWC";
       return RET_FORMAT_ERR;
     }
     if (tensor->data_c() == nullptr) {
@@ -43,7 +43,7 @@ class ConvolutionGradFilterCPUKernelFp16 : public InnerKernel {
 #ifdef ENABLE_ARM32
   const int chunk_ = C4NUM * 2;
 #else
-  const int chunk_ = C12NUM * 2;
+  const int chunk_ = C32NUM;
 #endif
 };
 }  // namespace mindspore::kernel
@@ -114,10 +114,7 @@ int ConvolutionGradInputCPUKernelFp16::Execute(int task_id) {
   count = MSMIN(stride, groups - stride * task_id);
   count = (count < 0) ? 0 : count;
   start = stride * task_id;
-  for (i = 0; i < batch; ++i) {
-    ConvDwInputGradFp16(dy_addr + (i * groups) * m * k, w_addr, dx_addr + (i * groups) * in_h * in_w, start, count,
-                        conv_param);
-  }
+  ConvDwInputGradFp16(dy_addr, w_addr, dx_addr, start, count, conv_param);
   return RET_OK;
 }
 
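Editor's note: the deleted per-batch loop is not lost; per the first hunk, ConvDwInputGradFp16 now iterates over batches internally, stepping dx and dy by the per-batch strides in_size and out_size. A minimal sketch of that refactoring shape on a toy accumulate kernel (hypothetical names, not the library's API):

#include <stdio.h>

/* Before: the caller loops, and the kernel sees one batch slice:
 *   for (int b = 0; b < batch; b++) kernel(dy + b * out_size, dx + b * in_size);
 * After: one call, and the kernel walks the batches by fixed strides. */
static void kernel_batched(const float *dy, float *dx, int batch, int out_size, int in_size) {
  for (int b = 0; b < batch; b++) {
    dx[b * in_size] += dy[b * out_size]; /* one representative accumulate per batch */
  }
}

int main(void) {
  float dy[4] = {1, 2, 3, 4}, dx[6] = {0};
  kernel_batched(dy, dx, 2, 2, 3); /* dx[0] += dy[0]; dx[3] += dy[2]; */
  printf("%g %g\n", dx[0], dx[3]); /* prints: 1 3 */
  return 0;
}
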
@@ -38,7 +38,7 @@ class ConvolutionGradInputCPUKernelFp16 : public InnerKernel {
   size_t mat_alloc_ = 0;
   bool do_img2col_ = true;
   bool do_dw_ = false;
-  const int chunk_ = C12NUM;
+  const int chunk_ = C16NUM;
 };
 }  // namespace mindspore::kernel
 
@@ -126,22 +126,16 @@ int TrainSession::InitCallBack() {
     if (node_type == schema::PrimitiveType_Cast) {
       return false;
     }
-    TensorPtrVector inputs;
     auto in_size = node->input_indices_.size();
-    inputs.reserve(in_size);
-    for (size_t k = 0; k < in_size; ++k) {
-      inputs.emplace_back(model_->all_tensors_.at(node->input_indices_[k]));
-    }
-    bool force_fp16 = (inputs.size() > 0 && std::any_of(inputs.begin(), inputs.end(),
-                                                        [&](schema::Tensor *tensor) {
-                                                          return ((tensor->dataType() == kNumberTypeFloat16) &&
-                                                                  (tensor->nodeType() == NodeType_ValueNode));
-                                                        }))
-                        ? true
-                        : false;
-    inputs.clear();
-    auto node_name = node->name_;
-
+    bool force_fp16 = false;
+    for (std::size_t k = 0; k < in_size; k++) {
+      schema::Tensor *tensor = model_->all_tensors_.at(node->input_indices_[k]);
+      if ((tensor->dataType() == kNumberTypeFloat16) && (tensor->nodeType() == NodeType_ValueNode)) {
+        force_fp16 = true;
+        break;
+      }
+    }
+    const auto &node_name = node->name_;
     bool is_fp16 = true;
     if (!force_fp16) {
       // optimizer runs in fp32
@@ -256,18 +250,16 @@ bool TrainSession::IsLossTensor(Tensor *tensor) {
 
 bool TrainSession::AllInputsNeedScale(kernel::LiteKernel *kernel) {
   auto type = kernel->type();
-  auto is_scale = false;
-
-  for (auto &tensor : kernel->in_tensors()) {
-    is_scale |= tensor->IsScale();
-  }
-
+  bool is_scale = false;
   switch (type) {
     case schema::PrimitiveType_AbsGrad:
     case schema::PrimitiveType_AddFusion:
     case schema::PrimitiveType_SubFusion:
     case schema::PrimitiveType_AddN:
-      return (true && is_scale);
+      for (auto &tensor : kernel->in_tensors()) {
+        is_scale = is_scale || tensor->IsScale();
+      }
+      return (is_scale);
     default:
       return false;
   }
@@ -791,7 +783,6 @@ session::LiteSession *session::LiteSession::CreateTrainSession(const std::string
     MS_LOG(ERROR) << "Could not switch to Train Modei " << train_mode;
     return nullptr;
   }
-
   return session.release();
 }
 
@@ -37,7 +37,7 @@ if [[ $backend == "all" || $backend == "arm64_cpu" || $backend == "arm64_fp32" |
         exit 1
     fi
     # run train
-    sh $cur_path/scripts/run_net_train.sh -r $release_path -m $models_path -d $device_id -e $backend
+    sh $cur_path/scripts/run_net_train.sh -r $release_path -m ${models_path}/../../models_train -d $device_id -e $backend
    arm64_status=$?
    if [[ $arm64_status -ne 0 ]]; then
        echo "Run arm64 train failed"
@@ -54,7 +54,7 @@ if [[ $backend == "all" || $backend == "arm32_cpu" || $backend == "arm32_fp32" |
        exit 1
    fi
    # run train
-    sh $cur_path/scripts/run_net_train.sh -r $release_path -m $models_path -d $device_id -e $backend
+    sh $cur_path/scripts/run_net_train.sh -r $release_path -m ${models_path}/../../models_train -d $device_id -e $backend
    arm32_status=$?
    if [[ $arm32_status -ne 0 ]]; then
        echo "Run arm32 train failed"
@@ -89,7 +89,7 @@ if [[ $backend == "all" || $backend == "x86-all" || $backend == "x86" || $backen
        exit 1
    fi
    # run train
-    sh $cur_path/scripts/run_net_train.sh -r $release_path -m $models_path -e $backend
+    sh $cur_path/scripts/run_net_train.sh -r $release_path -m ${models_path}/../../models_train -e $backend
    x86_status=$?
    if [[ $x86_status -ne 0 ]]; then
        echo "Run x86 train failed"
@@ -289,7 +289,7 @@ while getopts "r:c:m:d:i:e:v:t:q:D:" opt; do
            echo "release_path is ${OPTARG}"
            ;;
        m)
-            models_path=${OPTARG}"/../../models_train"
+            models_path=${OPTARG}
            echo "models_path is ${OPTARG}"
            ;;
        c)