From ae389ae1f94c232a840631c8fc7dbbd3a08e6f3f Mon Sep 17 00:00:00 2001 From: yoni Date: Tue, 2 Feb 2021 17:39:09 +0200 Subject: [PATCH] tod new ops, performance improvement and bug fix --- mindspore/lite/nnacl/fp32_grad/batch_norm.c | 32 ++ mindspore/lite/nnacl/fp32_grad/batch_norm.h | 4 + .../nnacl/fp32_grad/convolution_grad_filter.c | 379 ++++++++++++++++++ .../nnacl/fp32_grad/convolution_grad_filter.h | 32 ++ mindspore/lite/nnacl/fp32_grad/pack_ext.c | 129 +++--- mindspore/lite/nnacl/fp32_grad/pack_ext.h | 3 + mindspore/lite/nnacl/fp32_grad/pooling_grad.c | 147 +++++-- mindspore/lite/nnacl/fp32_grad/pooling_grad.h | 6 +- .../lite/nnacl/fp32_grad/strided_slice_grad.c | 61 +++ .../lite/nnacl/fp32_grad/strided_slice_grad.h | 30 ++ mindspore/lite/nnacl/op_base.h | 1 + mindspore/lite/schema/model.fbs | 1 + mindspore/lite/schema/ops.fbs | 12 + mindspore/lite/src/ops/flatten_grad.cc | 13 +- mindspore/lite/src/ops/pooling_grad.cc | 2 + mindspore/lite/src/ops/primitive_c.cc | 5 + mindspore/lite/src/ops/strided_slice_grad.cc | 266 ++++++++++++ mindspore/lite/src/ops/strided_slice_grad.h | 64 +++ .../arm/fp32/convolution_winograd_fp32.cc | 8 +- .../kernel/arm/fp32/fused_batchnorm_fp32.cc | 4 - .../kernel/arm/fp32_grad/apply_momentum.cc | 5 +- .../runtime/kernel/arm/fp32_grad/bn_grad.cc | 81 +++- .../runtime/kernel/arm/fp32_grad/bn_grad.h | 5 + .../kernel/arm/fp32_grad/convolution.cc | 33 +- .../kernel/arm/fp32_grad/convolution.h | 1 + .../arm/fp32_grad/convolution_grad_filter.cc | 56 ++- .../arm/fp32_grad/convolution_grad_filter.h | 1 + .../kernel/arm/fp32_grad/pooling_grad.cc | 16 +- .../kernel/arm/fp32_grad/pooling_grad.h | 1 + .../arm/fp32_grad/strided_slice_grad.cc | 150 +++++++ .../kernel/arm/fp32_grad/strided_slice_grad.h | 50 +++ .../src/train/train_populate_parameter.cc | 4 + mindspore/lite/src/train/transfer_session.h | 72 ++++ mindspore/lite/test/models_ms_train.cfg | 12 +- mindspore/lite/test/run_net_export.sh | 82 ++++ mindspore/lite/test/run_net_train.sh | 35 
+- .../arm/fp32_grad/pooling_grad_fp32_tests.cc | 16 +- .../lite/tools/benchmark_train/net_train.cc | 1 - .../lite/tools/benchmark_train/net_train.h | 2 +- mindspore/lite/tools/common/node_util.cc | 321 +++++---------- 40 files changed, 1713 insertions(+), 430 deletions(-) create mode 100644 mindspore/lite/nnacl/fp32_grad/convolution_grad_filter.c create mode 100644 mindspore/lite/nnacl/fp32_grad/convolution_grad_filter.h create mode 100644 mindspore/lite/nnacl/fp32_grad/strided_slice_grad.c create mode 100644 mindspore/lite/nnacl/fp32_grad/strided_slice_grad.h create mode 100644 mindspore/lite/src/ops/strided_slice_grad.cc create mode 100644 mindspore/lite/src/ops/strided_slice_grad.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32_grad/strided_slice_grad.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32_grad/strided_slice_grad.h create mode 100644 mindspore/lite/src/train/transfer_session.h create mode 100755 mindspore/lite/test/run_net_export.sh diff --git a/mindspore/lite/nnacl/fp32_grad/batch_norm.c b/mindspore/lite/nnacl/fp32_grad/batch_norm.c index add506cfb7e..69ff1b03235 100644 --- a/mindspore/lite/nnacl/fp32_grad/batch_norm.c +++ b/mindspore/lite/nnacl/fp32_grad/batch_norm.c @@ -50,3 +50,35 @@ void backwardAll(const float *restrict in, const float *restrict yt, const float } } } +void backwardP1(const float *restrict in, const float *restrict yt, const float *restrict mean, + const float *restrict invar, const float *restrict scale, int size, int ch, float *restrict dxhat_sum, + float *restrict dxhathat_sum, float *restrict dbias, float *restrict dscale) { + for (int i = 0; i < size; i++) { + for (int c = 0; c < ch; c++) { + int ix = i * ch + c; + dbias[c] += yt[ix]; + // dscale + float x_hat = (in[ix] - mean[c]) * invar[c]; + dscale[c] += (yt[ix] * x_hat); + // dx_1 + float dx_hat = yt[ix] * scale[c]; + dxhat_sum[c] += dx_hat; + dxhathat_sum[c] += dx_hat * x_hat; + } + } +} + +void backwardP2(const float *restrict in, const 
float *restrict yt, const float *restrict mean, + const float *restrict invar, const float *restrict scale, int size, int total_size, int ch, + const float *dxhat_sum, const float *dxhathat_sum, float *restrict dx) { + float N = (float)total_size; + for (int i = 0; i < size; i++) { + for (int c = 0; c < ch; c++) { + // dx_2 + int ix = i * ch + c; + float x_hat = (in[ix] - mean[c]) * invar[c]; + float dx_hat = yt[ix] * scale[c]; + dx[ix] = 1.0f / N * (invar[c]) * (N * dx_hat - dxhat_sum[c] - x_hat * dxhathat_sum[c]); + } + } +} diff --git a/mindspore/lite/nnacl/fp32_grad/batch_norm.h b/mindspore/lite/nnacl/fp32_grad/batch_norm.h index b3728d6d755..ca64c6a5e63 100644 --- a/mindspore/lite/nnacl/fp32_grad/batch_norm.h +++ b/mindspore/lite/nnacl/fp32_grad/batch_norm.h @@ -32,6 +32,10 @@ extern "C" { void var2Invar(float *save_var, int size, float eps); void backwardAll(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size, int ch, float *dxhat_sum, float *dxhathat_sum, float *dbias, float *dscale, float *dx); +void backwardP1(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size, + int ch, float *dxhat_sum, float *dxhathat_sum, float *dbias, float *dscale); +void backwardP2(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size, + int total_size, int ch, const float *dxhat_sum, const float *dxhathat_sum, float *dx); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/nnacl/fp32_grad/convolution_grad_filter.c b/mindspore/lite/nnacl/fp32_grad/convolution_grad_filter.c new file mode 100644 index 00000000000..3fdf8570a4b --- /dev/null +++ b/mindspore/lite/nnacl/fp32_grad/convolution_grad_filter.c @@ -0,0 +1,379 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32_grad/convolution_grad_filter.h" +#ifdef ENABLE_ARM +#include +#endif + +#ifdef ENABLE_ARM +static int FilterGrad16Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + for (; i_c < (out_ch & ~15); i_c += 16) { + float32x4_t sum_03_4 = vdupq_n_f32(0.0f); + float32x4_t sum_47_4 = vdupq_n_f32(0.0f); + float32x4_t sum_9x_4 = vdupq_n_f32(0.0f); + float32x4_t sum_12x_4 = vdupq_n_f32(0.0f); + for (int b = 0; b < batch; ++b) { + const float *x_addr = &x[b * x_size]; + const float *dy_addr = &dy[b * y_size]; + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + 
i_c; + + float32x4_t x_03_4 = vld1q_f32(x_addr + offset_x); + float32x4_t dy_03_4 = vld1q_f32(dy_addr + offset_dy); + sum_03_4 = vmlaq_f32(sum_03_4, x_03_4, dy_03_4); + + float32x4_t x_47_4 = vld1q_f32(x_addr + offset_x + 4); + float32x4_t dy_47_4 = vld1q_f32(dy_addr + offset_dy + 4); + sum_47_4 = vmlaq_f32(sum_47_4, x_47_4, dy_47_4); + + float32x4_t x_9x_4 = vld1q_f32(x_addr + offset_x + 8); + float32x4_t dy_9x_4 = vld1q_f32(dy_addr + offset_dy + 8); + sum_9x_4 = vmlaq_f32(sum_9x_4, x_9x_4, dy_9x_4); + + float32x4_t x_12x_4 = vld1q_f32(x_addr + offset_x + 12); + float32x4_t dy_12x_4 = vld1q_f32(dy_addr + offset_dy + 12); + sum_12x_4 = vmlaq_f32(sum_12x_4, x_12x_4, dy_12x_4); + } + } + } + dw[(i_c + 0) * k_spatial + k_idx] = sum_03_4[0]; + dw[(i_c + 1) * k_spatial + k_idx] = sum_03_4[1]; + dw[(i_c + 2) * k_spatial + k_idx] = sum_03_4[2]; + dw[(i_c + 3) * k_spatial + k_idx] = sum_03_4[3]; + + dw[(i_c + 4) * k_spatial + k_idx] = sum_47_4[0]; + dw[(i_c + 5) * k_spatial + k_idx] = sum_47_4[1]; + dw[(i_c + 6) * k_spatial + k_idx] = sum_47_4[2]; + dw[(i_c + 7) * k_spatial + k_idx] = sum_47_4[3]; + + dw[(i_c + 8) * k_spatial + k_idx] = sum_9x_4[0]; + dw[(i_c + 9) * k_spatial + k_idx] = sum_9x_4[1]; + dw[(i_c + 10) * k_spatial + k_idx] = sum_9x_4[2]; + dw[(i_c + 11) * k_spatial + k_idx] = sum_9x_4[3]; + + dw[(i_c + 12) * k_spatial + k_idx] = sum_12x_4[0]; + dw[(i_c + 13) * k_spatial + k_idx] = sum_12x_4[1]; + dw[(i_c + 14) * k_spatial + k_idx] = sum_12x_4[2]; + dw[(i_c + 15) * k_spatial + k_idx] = sum_12x_4[3]; + } + return i_c; +} + +static int FilterGrad12Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = 
conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + if ((out_ch - i_c) >= 12) { + float32x4_t sum_03_4 = vdupq_n_f32(0.0f); + float32x4_t sum_47_4 = vdupq_n_f32(0.0f); + float32x4_t sum_9x_4 = vdupq_n_f32(0.0f); + for (int b = 0; b < batch; ++b) { + const float *x_addr = &x[b * x_size]; + const float *dy_addr = &dy[b * y_size]; + + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + + float32x4_t x_03_4 = vld1q_f32(x_addr + offset_x); + float32x4_t dy_03_4 = vld1q_f32(dy_addr + offset_dy); + sum_03_4 = vmlaq_f32(sum_03_4, x_03_4, dy_03_4); + + float32x4_t x_47_4 = vld1q_f32(x_addr + offset_x + 4); + float32x4_t dy_47_4 = vld1q_f32(dy_addr + offset_dy + 4); + sum_47_4 = vmlaq_f32(sum_47_4, x_47_4, dy_47_4); + + float32x4_t x_9x_4 = vld1q_f32(x_addr + offset_x + 8); + float32x4_t dy_9x_4 = vld1q_f32(dy_addr + offset_dy + 8); + sum_9x_4 = vmlaq_f32(sum_9x_4, x_9x_4, dy_9x_4); + } + } + } + dw[(i_c + 0) * k_spatial + k_idx] = sum_03_4[0]; + dw[(i_c + 1) * k_spatial + k_idx] = sum_03_4[1]; + dw[(i_c + 2) * k_spatial + k_idx] = sum_03_4[2]; + dw[(i_c + 3) * k_spatial + k_idx] = sum_03_4[3]; + + dw[(i_c + 4) * k_spatial + k_idx] = sum_47_4[0]; + dw[(i_c + 5) * k_spatial + k_idx] = sum_47_4[1]; + dw[(i_c + 6) * k_spatial + k_idx] = sum_47_4[2]; + dw[(i_c + 7) * k_spatial + k_idx] = sum_47_4[3]; + + dw[(i_c + 8) * k_spatial + k_idx] = sum_9x_4[0]; + dw[(i_c + 9) * k_spatial + k_idx] = sum_9x_4[1]; + dw[(i_c + 10) * 
k_spatial + k_idx] = sum_9x_4[2]; + dw[(i_c + 11) * k_spatial + k_idx] = sum_9x_4[3]; + + i_c += 12; + } + return i_c; +} + +static int FilterGrad8Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + + if ((out_ch - i_c) >= 8) { + float32x4_t sum_03_4 = vdupq_n_f32(0.0f); + float32x4_t sum_47_4 = vdupq_n_f32(0.0f); + for (int b = 0; b < batch; ++b) { + const float *x_addr = &x[b * x_size]; + const float *dy_addr = &dy[b * y_size]; + + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + + float32x4_t x_03_4 = vld1q_f32(x_addr + offset_x); + float32x4_t dy_03_4 = vld1q_f32(dy_addr + offset_dy); + sum_03_4 = vmlaq_f32(sum_03_4, x_03_4, dy_03_4); + + float32x4_t x_47_4 = vld1q_f32(x_addr + offset_x + 4); + float32x4_t dy_47_4 = vld1q_f32(dy_addr + offset_dy + 4); + sum_47_4 = vmlaq_f32(sum_47_4, x_47_4, dy_47_4); + } + } + } + dw[(i_c + 0) * k_spatial + k_idx] = sum_03_4[0]; + dw[(i_c + 1) * k_spatial + k_idx] = sum_03_4[1]; + dw[(i_c + 2) * k_spatial + k_idx] = sum_03_4[2]; + dw[(i_c + 3) * 
k_spatial + k_idx] = sum_03_4[3]; + + dw[(i_c + 4) * k_spatial + k_idx] = sum_47_4[0]; + dw[(i_c + 5) * k_spatial + k_idx] = sum_47_4[1]; + dw[(i_c + 6) * k_spatial + k_idx] = sum_47_4[2]; + dw[(i_c + 7) * k_spatial + k_idx] = sum_47_4[3]; + i_c += 8; + } + return i_c; +} +static int FilterGrad4Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + if ((out_ch - i_c) >= 4) { + float32x4_t sum_4 = vdupq_n_f32(0.0f); + + for (int b = 0; b < batch; ++b) { + const float *x_addr = &x[b * x_size]; + const float *dy_addr = &dy[b * y_size]; + + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + + float32x4_t x_4 = vld1q_f32(x_addr + offset_x); + float32x4_t dy_4 = vld1q_f32(dy_addr + offset_dy); + sum_4 = vmlaq_f32(sum_4, x_4, dy_4); + } + } + } + dw[(i_c + 0) * k_spatial + k_idx] = sum_4[0]; + dw[(i_c + 1) * k_spatial + k_idx] = sum_4[1]; + dw[(i_c + 2) * k_spatial + k_idx] = sum_4[2]; + dw[(i_c + 3) * k_spatial + k_idx] = sum_4[3]; + i_c += 4; + } + return i_c; +} + +static int Filtergrad2Arm(const 
float *x, const float *dy, int i_c, int k_idx, float *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + + if ((out_ch - i_c) >= 2) { + float32x2_t sum_2 = vdup_n_f32(0.0f); + for (int b = 0; b < batch; ++b) { + const float *x_addr = &x[b * x_size]; + const float *dy_addr = &dy[b * y_size]; + + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + + float32x2_t x_4 = vld1_f32(x_addr + offset_x); + float32x2_t dy_4 = vld1_f32(dy_addr + offset_dy); + sum_2 = vmla_f32(sum_2, x_4, dy_4); + } + } + } + dw[(i_c + 0) * k_spatial + k_idx] = sum_2[0]; + dw[(i_c + 1) * k_spatial + k_idx] = sum_2[1]; + i_c += 2; + } + return i_c; +} +#endif +int ConvDwFilterGrad(const float *x, const float *dy, float *dw, int start, int count, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; 
+ int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + + for (int i_k = 0; i_k < count; i_k++) { + int k_idx = start + i_k; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + int i_c = 0; +#ifdef ENABLE_ARM + i_c = FilterGrad16Arm(x, dy, i_c, k_idx, dw, conv_param); + i_c = FilterGrad12Arm(x, dy, i_c, k_idx, dw, conv_param); + i_c = FilterGrad8Arm(x, dy, i_c, k_idx, dw, conv_param); + i_c = FilterGrad4Arm(x, dy, i_c, k_idx, dw, conv_param); + i_c = Filtergrad2Arm(x, dy, i_c, k_idx, dw, conv_param); +#endif + for (; i_c < out_ch; i_c++) { + float sum = 0; + for (int b = 0; b < batch; ++b) { + const float *x_addr = &x[b * x_size]; + const float *dy_addr = &dy[b * y_size]; + + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + sum += x_addr[offset_x] * dy_addr[offset_dy]; + } + } + } + dw[i_c * k_spatial + k_idx] = sum; + } + } + return 0; +} diff --git a/mindspore/lite/nnacl/fp32_grad/convolution_grad_filter.h b/mindspore/lite/nnacl/fp32_grad/convolution_grad_filter.h new file mode 100644 index 00000000000..0209f942855 --- /dev/null +++ b/mindspore/lite/nnacl/fp32_grad/convolution_grad_filter.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_ +#define MINDSPORE_LITE_NNACL_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_ + +#include +#include "nnacl/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ConvDwFilterGrad(const float *x, const float *dy, float *dw, int start, int count, const ConvParameter *conv_param); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_LITE_NNACL_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_ diff --git a/mindspore/lite/nnacl/fp32_grad/pack_ext.c b/mindspore/lite/nnacl/fp32_grad/pack_ext.c index 660484aecc5..0ac265340ab 100644 --- a/mindspore/lite/nnacl/fp32_grad/pack_ext.c +++ b/mindspore/lite/nnacl/fp32_grad/pack_ext.c @@ -18,6 +18,56 @@ #include "nnacl/fp32_grad/pack_ext.h" #include "nnacl/pack.h" +void RollingIm2ColPackDwUnitFp32(const float *in_data, const ConvParameter *conv_param, float *data_col_orig, + int real_cal_num, int start) { + const int pad_left = conv_param->pad_l_; + const int pad_up = conv_param->pad_u_; + + const int stride_h = conv_param->stride_h_; + const int stride_w = conv_param->stride_w_; + + const int dilation_h = conv_param->dilation_h_; + const int dilation_w = conv_param->dilation_w_; + + const int kernel_h = conv_param->kernel_h_; + const int kernel_w = conv_param->kernel_w_; + + const int in_height = conv_param->input_h_; + const int in_width = conv_param->input_w_; + + const int output_w = conv_param->output_w_; + + const int channels = conv_param->input_channel_; + const int stride = kernel_h * kernel_w; + + int kernel_row, kernel_col; + + for (int i = 
0; i < real_cal_num; i++) { + int block_start = start + i; + int input_h = block_start / output_w * stride_h; + int input_w = block_start % output_w * stride_w; + float *data_col = data_col_orig + i * channels * stride; + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + input_h; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + input_w; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + const int offset = (input_row * in_width + input_col) * channels; + for (int c = 0; c < channels; c++) { + data_col[c * stride] = in_data[offset + c]; + } + data_col++; + } else { + for (int c = 0; c < channels; c++) { + data_col[c * stride] = 0; + } + data_col++; + } + } + } + } +} + void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParameter *conv_param, int real_cal_num, int start) { const int pad_left = conv_param->pad_l_; @@ -90,85 +140,6 @@ void RollingIm2ColPackUnitFp32(const float *input_data, const ConvParameter *con rolling_im2col_hwc(input_data, packed_input, conv_param, real_cal_num, block_index); } -void im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, bool transpose) { - const int pad_left = conv_param->pad_l_; - const int pad_up = conv_param->pad_u_; - - const int stride_h = conv_param->stride_h_; - const int stride_w = conv_param->stride_w_; - - const int dilation_h = conv_param->dilation_h_; - const int dilation_w = conv_param->dilation_w_; - - const int kernel_h = conv_param->kernel_h_; - const int kernel_w = conv_param->kernel_w_; - - const int in_height = (transpose) ? conv_param->output_h_ : conv_param->input_h_; - const int in_width = (transpose) ? conv_param->output_w_ : conv_param->input_w_; - - const int output_h = (transpose) ? conv_param->input_h_ : conv_param->output_h_; - const int output_w = (transpose) ? 
conv_param->input_w_ : conv_param->output_w_; - - const int tot_channels = (transpose) ? conv_param->output_channel_ : conv_param->input_channel_; - const int channels = tot_channels / conv_param->group_; - int channel, kernel_row, kernel_col, output_rows, output_col; - - if (transpose) { - for (channel = 0; channel < channels; channel++) { - for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { - for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { - int input_row = -pad_up + kernel_row * dilation_h; - for (output_rows = output_h; output_rows; output_rows--) { - if (!((unsigned)(input_row) < (unsigned)(in_height))) { - for (output_col = output_w; output_col; output_col--) { - *(data_row++) = 0; - } - } else { - int input_col = -pad_left + kernel_col * dilation_w; - for (output_col = output_w; output_col; output_col--) { - if (((unsigned)(input_col) < (unsigned)(in_width))) { - const int offset = (input_row * in_width + input_col) * tot_channels + channel; - *(data_row++) = in_data[offset]; - } else { - *(data_row++) = 0; - } - input_col += stride_w; - } - } - input_row += stride_h; - } - } - } - } - } else { - for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { - for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { - for (channel = 0; channel < channels; channel++) { - int input_row = -pad_up + kernel_row * dilation_h; - for (output_rows = output_h; output_rows; output_rows--) { - if (!((unsigned)(input_row) < (unsigned)(in_height))) { - for (output_col = output_w; output_col; output_col--) { - *(data_row++) = 0; - } - } else { - int input_col = -pad_left + kernel_col * dilation_w; - for (output_col = output_w; output_col; output_col--) { - if (((unsigned)(input_col) < (unsigned)(in_width))) { - const int offset = (input_row * in_width + input_col) * tot_channels + channel; - *(data_row++) = in_data[offset]; - } else { - *(data_row++) = 0; - } - input_col += stride_w; - } - } - input_row += stride_h; - } - } - } - } - } -} void 
rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, int rows, int start) { const int pad_left = conv_param->pad_l_; const int pad_up = conv_param->pad_u_; diff --git a/mindspore/lite/nnacl/fp32_grad/pack_ext.h b/mindspore/lite/nnacl/fp32_grad/pack_ext.h index c2095a75035..ca8b67336d4 100644 --- a/mindspore/lite/nnacl/fp32_grad/pack_ext.h +++ b/mindspore/lite/nnacl/fp32_grad/pack_ext.h @@ -26,6 +26,9 @@ extern "C" { void RollingIm2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, int real_cal_num, int block_index); +void RollingIm2ColPackDwUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, + int real_cal_num, int block_index); + void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParameter *conv_param, int rows, int start); void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, int rows, int start); void rolling_col2im_hwc(const float *data_col, float *data_im, const ConvParameter *conv_param, int rows, int start); diff --git a/mindspore/lite/nnacl/fp32_grad/pooling_grad.c b/mindspore/lite/nnacl/fp32_grad/pooling_grad.c index 9ecd19d7a99..78ad6b5eda0 100644 --- a/mindspore/lite/nnacl/fp32_grad/pooling_grad.c +++ b/mindspore/lite/nnacl/fp32_grad/pooling_grad.c @@ -18,7 +18,7 @@ #include #include "nnacl/fp32_grad/pooling_grad.h" -void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id) { +void AvgPoolingGrad(const float *input_ptr, float *output_ptr, int count, PoolingParameter *pooling_param) { int stride_w = pooling_param->stride_w_; int stride_h = pooling_param->stride_h_; int pad_w = pooling_param->pad_l_; @@ -30,29 +30,58 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter int in_h = pooling_param->input_h_; int output_w = pooling_param->output_w_; int output_h = pooling_param->output_h_; - int 
output_batch = pooling_param->output_batch_; - memset(output_ptr, 0, in_h * in_w * channel * output_batch * sizeof(float)); - float kk = (float)(win_h * win_w); - for (int ib = 0; ib < output_batch; ib++) { + const float kk = 1.0f / (float)(win_h * win_w); +#ifdef ENABLE_ARM + const float32x4_t factor = vdupq_n_f32(kk); +#endif + for (int ib = 0; ib < count; ib++) { float *out = &output_ptr[(ib * in_h * in_w * channel)]; const float *inPtr = &input_ptr[(ib * output_h * output_w * channel)]; // iterate over yt for (int yh = 0; yh < output_h; yh++) { + int over_h = pad_h - yh * stride_h; + int kh_s = MSMAX(0, over_h); + int kh_e = MSMIN(win_h, in_h + over_h); for (int yw = 0; yw < output_w; yw++) { - for (int ic = 0; ic < channel; ic++) { + int over_w = pad_w - yw * stride_w; + int kw_s = MSMAX(0, over_w); + int kw_e = MSMIN(win_w, in_w + over_w); + int ic = 0; + for (; ic < channel - 4; ic += 4) { int idx = (yw + yh * output_w) * channel + ic; - float delta = inPtr[idx] / kk; - for (int kh = 0; kh < win_h; kh++) { +#ifdef ENABLE_ARM + float32x4_t in = vld1q_f32(inPtr + idx); + float32x4_t delta = vmulq_f32(in, factor); +#else + float delta[4] = {inPtr[idx], inPtr[idx + 1], inPtr[idx + 2], inPtr[idx + 3]}; + for (int i = 0; i < 4; i++) delta[i] *= kk; +#endif + for (int kh = kh_s; kh < kh_e; kh++) { int xh = yh * stride_h + kh - pad_h; - if ((xh < 0) || (xh >= in_h)) { - continue; - } - for (int kw = 0; kw < win_w; kw++) { + for (int kw = kw_s; kw < kw_e; kw++) { int xw = yw * stride_w + kw - pad_w; - if ((xw < 0) || (xw >= in_w)) { - continue; +#ifdef ENABLE_ARM + float *out_vec = out + (xw + in_w * xh) * channel + ic; + float32x4_t outr = vld1q_f32(out + (xw + in_w * xh) * channel + ic); + float32x4_t outs = vaddq_f32(outr, delta); + vst1q_f32(out_vec, outs); +#else + + for (int i = 0; i < 4; i++) { + out[(xw + in_w * xh) * channel + ic + i] += ((float *)&delta)[i]; } +#endif + } + } + } + for (; ic < channel; ic++) { + int idx = (yw + yh * output_w) * channel + ic; + 
float delta = inPtr[idx] * kk; + for (int kh = kh_s; kh < kh_e; kh++) { + int xh = yh * stride_h + kh - pad_h; + for (int kw = kw_s; kw < kw_e; kw++) { + int xw = yw * stride_w + kw - pad_w; out[(xw + in_w * xh) * channel + ic] += delta; } } @@ -62,8 +91,17 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter } } -void MaxPoolingGrad(const float *input_ptr, const float *dx_ptr, const float *dy_ptr, float *output_ptr, - PoolingParameter *pooling_param, int task_id) { +#ifdef ENABLE_ARM +static uint32x4_t MaxIndex(float32x4_t in, float32x4_t *max, uint32x4_t index, uint32x4_t prev_index) { + uint32x4_t res = vcgtq_f32(in, *max); + uint32x4_t m_index = vbslq_u32(res, index, prev_index); + *max = vbslq_f32(res, in, *max); + return m_index; +} +#endif + +void MaxPoolingGrad(const float *input_ptr, const float *dy_ptr, float *output_ptr, int output_batch, + PoolingParameter *pooling_param) { int stride_w = pooling_param->stride_w_; int stride_h = pooling_param->stride_h_; int pad_w = pooling_param->pad_l_; @@ -75,36 +113,71 @@ void MaxPoolingGrad(const float *input_ptr, const float *dx_ptr, const float *dy int in_h = pooling_param->input_h_; int in_w = pooling_param->input_w_; int output_w = pooling_param->output_w_; int output_h = pooling_param->output_h_; - int output_batch = pooling_param->output_batch_; - memset(output_ptr, 0, in_h * in_w * channel * output_batch * sizeof(float)); for (int ib = 0; ib < output_batch; ib++) { float *out = &output_ptr[(ib * in_h * in_w * channel)]; - const float *inPtr = (const float *)(&input_ptr[(ib * in_h * in_w * channel)]); - const float *dyPtr = (const float *)(&dy_ptr[(ib * output_h * output_w * channel)]); - + const float *inPtr = &input_ptr[(ib * in_h * in_w * channel)]; + const float *dyPtr = &dy_ptr[(ib * output_h * output_w * channel)]; for (int yh = 0; yh < output_h; yh++) { + int over_h = pad_h - yh * stride_h; + int kh_s = MSMAX(0, over_h); + int kh_e = MSMIN(win_h, in_h + over_h); for (int yw = 0; yw < output_w; yw++) { - for 
(int ic = 0; ic < channel; ic++) { + int over_w = pad_w - yw * stride_w; + int kw_s = MSMAX(0, over_w); + int kw_e = MSMIN(win_w, in_w + over_w); + int ic = 0; + for (; ic < channel - 4; ic += 4) { int idx = (yw + yh * output_w) * channel + ic; - - float delta = dyPtr[idx]; +#ifdef ENABLE_ARM + uint32x4_t max_idx = vdupq_n_u32(0); + float32x4_t max_val = vdupq_n_f32(-FLT_MAX); + float32x4_t delta = vld1q_f32(dyPtr + idx); +#else + float delta[4] = {dyPtr[idx], dyPtr[idx + 1], dyPtr[idx + 2], dyPtr[idx + 3]}; + float max_val[4] = {-FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX}; + int max_idx[4] = {0}; +#endif + for (int kh = kh_s; kh < kh_e; kh++) { + int xh = yh * stride_h + kh - pad_h; + for (int kw = kw_s; kw < kw_e; kw++) { + int xw = yw * stride_w + kw - pad_w; + int val_idx = (xw + in_w * xh) * channel + ic; +#ifdef ENABLE_ARM + unsigned int val_idx_vec[] = {val_idx, val_idx + 1, val_idx + 2, val_idx + 3}; + uint32x4_t index = vld1q_u32(val_idx_vec); + float32x4_t in = vld1q_f32(inPtr + val_idx); + max_idx = MaxIndex(in, &max_val, index, max_idx); +#else + float val[4] = {inPtr[val_idx], inPtr[val_idx + 1], inPtr[val_idx + 2], inPtr[val_idx + 3]}; + for (int i = 0; i < 4; i++) { + if (val[i] > max_val[i]) { + max_val[i] = val[i]; + max_idx[i] = val_idx + i; + } + } +#endif + } + } + for (int i = 0; i < 4; i++) { + out[((int *)&max_idx)[i]] += ((float *)&delta)[i]; + } + } + for (; ic < channel; ic++) { float max_val = -FLT_MAX; int max_idx = 0; - for (int kh = 0; kh < win_h; kh++) { + int idx = (yw + yh * output_w) * channel + ic; + float delta = dyPtr[idx]; + for (int kh = kh_s; kh < kh_e; kh++) { int xh = yh * stride_h + kh - pad_h; - if ((xh < 0) || (xh >= in_h)) { - continue; - } - for (int kw = 0; kw < win_w; kw++) { - int xw = yw * stride_w + kw - pad_w; - if ((xw < 0) || (xw >= in_w)) { - continue; - } - - if (inPtr[(xw + in_w * xh) * channel + ic] > max_val) { - max_val = inPtr[(xw + in_w * xh) * channel + ic]; - max_idx = (xw + in_w * xh) * channel + ic; + 
int loop = kw_e - kw_s; + for (int kw = 0; kw < loop; kw++) { + int xw = yw * stride_w + kw + kw_s - pad_w; + int val_idx = (xw + in_w * xh) * channel + ic; + float val = inPtr[val_idx]; + if (val > max_val) { + max_val = val; + max_idx = val_idx; } } } diff --git a/mindspore/lite/nnacl/fp32_grad/pooling_grad.h b/mindspore/lite/nnacl/fp32_grad/pooling_grad.h index 0af884b4492..1bfd684ad40 100644 --- a/mindspore/lite/nnacl/fp32_grad/pooling_grad.h +++ b/mindspore/lite/nnacl/fp32_grad/pooling_grad.h @@ -22,9 +22,9 @@ #ifdef __cplusplus extern "C" { #endif -void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id); -void MaxPoolingGrad(const float *input_ptr, const float *dx_ptr, const float *dy_ptr, float *output_ptr, - PoolingParameter *pooling_param, int task_id); +void AvgPoolingGrad(const float *input_ptr, float *output_ptr, int count, PoolingParameter *pooling_param); +void MaxPoolingGrad(const float *input_ptr, const float *dy_ptr, float *output_ptr, int output_batch, + PoolingParameter *pooling_param); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/nnacl/fp32_grad/strided_slice_grad.c b/mindspore/lite/nnacl/fp32_grad/strided_slice_grad.c new file mode 100644 index 00000000000..02d7ea8c8a5 --- /dev/null +++ b/mindspore/lite/nnacl/fp32_grad/strided_slice_grad.c @@ -0,0 +1,61 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl/fp32_grad/strided_slice_grad.h" +#include "nnacl/errorcode.h" + +static size_t CalcIndex(const int *shape, size_t size, int i, size_t pos) { + size_t res = 1; + for (size_t j = 0; j < size; j++) { + res *= shape[(i + 1) + j]; + } + return (pos / res % shape[i]); +} + +int DoStridedSliceGrad(const float *inputs, float *output, const int *dx_shape, StridedSliceParameter *param) { + if (inputs == NULL || output == NULL || param == NULL) { + return NNACL_NULL_PTR; + } + if (param->num_axes_ > DIMENSION_7D) { + return NNACL_PARAM_INVALID; + } + + size_t size = 1; + int *s = param->strides_; + int *b = param->begins_; + for (int i = 0; i < DIMENSION_7D; i++) { + size *= param->in_shape_[i]; + } + + for (size_t pos = 0; pos < size; pos++) { + size_t i = CalcIndex(param->in_shape_, 6, 0, pos); + size_t j = CalcIndex(param->in_shape_, 5, 1, pos); + size_t k = CalcIndex(param->in_shape_, 4, 2, pos); + size_t l = CalcIndex(param->in_shape_, 3, 3, pos); + size_t m = CalcIndex(param->in_shape_, 2, 4, pos); + size_t n = CalcIndex(param->in_shape_, 1, 5, pos); + size_t o = CalcIndex(param->in_shape_, 0, 6, pos); + + size_t input_idx = + (i * s[0] + b[0]) * dx_shape[1] * dx_shape[2] * dx_shape[3] * dx_shape[4] * dx_shape[5] * dx_shape[6] + + (j * s[1] + b[1]) * dx_shape[2] * dx_shape[3] * dx_shape[4] * dx_shape[5] * dx_shape[6] + + (k * s[2] + b[2]) * dx_shape[3] * dx_shape[4] * dx_shape[5] * dx_shape[6] + + (l * s[3] + b[3]) * dx_shape[4] * dx_shape[5] * dx_shape[6] + (m * s[4] + b[4]) * dx_shape[5] * dx_shape[6] + + (n * s[5] + b[5]) * dx_shape[6] + (o * s[6] + b[6]); + output[input_idx] = inputs[pos]; + } + return NNACL_OK; +} diff --git a/mindspore/lite/nnacl/fp32_grad/strided_slice_grad.h b/mindspore/lite/nnacl/fp32_grad/strided_slice_grad.h new file mode 100644 index 00000000000..5ed2a68d76c --- /dev/null +++ b/mindspore/lite/nnacl/fp32_grad/strided_slice_grad.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed 
under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_STRIDED_SLICE_GRAD_H_ +#define MINDSPORE_LITE_NNACL_FP32_GRAD_STRIDED_SLICE_GRAD_H_ + +#include "nnacl/op_base.h" +#include "nnacl/strided_slice_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int DoStridedSliceGrad(const float *inputs, float *output, const int *dx_shape, StridedSliceParameter *param); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_LITE_NNACL_FP32_GRAD_STRIDED_SLICE_GRAD_H_ diff --git a/mindspore/lite/nnacl/op_base.h b/mindspore/lite/nnacl/op_base.h index d83743778e2..1ea2d711b06 100644 --- a/mindspore/lite/nnacl/op_base.h +++ b/mindspore/lite/nnacl/op_base.h @@ -53,6 +53,7 @@ #define DIMENSION_4D 4 #define DIMENSION_6D 6 +#define DIMENSION_7D 7 #define kInputIndex 0 #define kWeightIndex 1 #define kBiasIndex 2 diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index 3b5e082d463..61552ea9dcc 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -273,6 +273,7 @@ union PrimitiveType { RandomStandardNormal, CropAndResize, Erf, + StridedSliceGrad } enum QuantType: int { diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 60dd7d66039..837f49d18b5 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -1259,6 +1259,18 @@ table RandomStandardNormal { table CropAndResize { method : ResizeMethod; extrapolation_value : float; +} + +table 
StridedSliceGrad { + beginMask: int; + endMask: int; + ellipsisMask: int; + newAxisMask: int; + shrinkAxisMask: int; + begin: [int]; + end: [int]; + stride: [int]; + isScale: [int]; } table Erf { diff --git a/mindspore/lite/src/ops/flatten_grad.cc b/mindspore/lite/src/ops/flatten_grad.cc index bb768b05c0b..f0e52562f37 100644 --- a/mindspore/lite/src/ops/flatten_grad.cc +++ b/mindspore/lite/src/ops/flatten_grad.cc @@ -31,7 +31,7 @@ int FlattenGrad::InferShape(std::vector inputs_, std::vector MS_LOG(ERROR) << "FlattenGrad input or output is null!"; return RET_ERROR; } - if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) { + if (inputs_.size() != kDoubleNum || outputs_.size() != kSingleNum) { MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size(); return RET_INPUT_TENSOR_ERROR; } @@ -42,16 +42,15 @@ int FlattenGrad::InferShape(std::vector inputs_, std::vector return RET_INFER_INVALID; } - auto input_shape = input->shape(); - std::vector output_shape(2); - output_shape.at(0) = input_shape.at(0); - output_shape.at(1) = 1; - for (size_t i = 1; i < input_shape.size(); i++) { - output_shape.at(1) *= input_shape.at(i); + auto output_size = inputs_.at(1)->shape().at(0); + std::vector output_shape(output_size); + for (int i = 0; i < output_size; i++) { + output_shape.at(i) = static_cast(inputs_.at(1)->data_c())[i]; } output->set_shape(output_shape); return RET_OK; } + #ifdef PRIMITIVE_WRITEABLE int FlattenGrad::UnPackAttr(const Primitive &prim, const std::vector &inputs) { if (this->primitive_ == nullptr) { diff --git a/mindspore/lite/src/ops/pooling_grad.cc b/mindspore/lite/src/ops/pooling_grad.cc index da24f23cfce..47825e97090 100644 --- a/mindspore/lite/src/ops/pooling_grad.cc +++ b/mindspore/lite/src/ops/pooling_grad.cc @@ -91,6 +91,8 @@ int PoolingGrad::UnPackAttr(const Primitive &prim, const std::vector attr->poolingMode = schema::PoolMode_MEAN_POOLING; } else if (prim.instance_name() == "AvgPoolGradGpu") { 
attr->poolingMode = schema::PoolMode_MEAN_POOLING; + } else if (prim.instance_name() == "AvgPoolGradCpu") { + attr->poolingMode = schema::PoolMode_MEAN_POOLING; } else { attr->poolingMode = schema::PoolMode_MAX_POOLING; } diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index bc3c491f7de..8fa9fde0399 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -202,6 +202,7 @@ #include "src/ops/smooth_l1_loss_grad.h" #include "src/ops/sigmoid_cross_entropy_with_logits.h" #include "src/ops/sigmoid_cross_entropy_with_logits_grad.h" +#include "src/ops/strided_slice_grad.h" #endif #endif namespace mindspore { @@ -724,6 +725,8 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "Pad") { return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "StridedSliceGrad") { + return NewPrimitiveC(prim, inputs, quantType); #else } else if (op_type == "Conv2DBackpropInput") { return NewPrimitiveC(prim, inputs, quantType); @@ -1102,6 +1105,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new (std::nothrow) SigmoidCrossEntropyWithLogits(primitive); case schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad: return new (std::nothrow) SigmoidCrossEntropyWithLogitsGrad(primitive); + case schema::PrimitiveType_StridedSliceGrad: + return new (std::nothrow) StridedSliceGrad(primitive); #endif default: MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type); diff --git a/mindspore/lite/src/ops/strided_slice_grad.cc b/mindspore/lite/src/ops/strided_slice_grad.cc new file mode 100644 index 00000000000..aaeda6c46ee --- /dev/null +++ b/mindspore/lite/src/ops/strided_slice_grad.cc @@ -0,0 +1,266 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you 
may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/strided_slice_grad.h" + +#ifndef PRIMITIVE_WRITEABLE +#include "src/ops/ops_register.h" +#endif + +namespace mindspore { +namespace lite { + +#ifdef PRIMITIVE_WRITEABLE +int StridedSliceGrad::GetBeginMask() const { return this->primitive_->value.AsStridedSliceGrad()->beginMask; } +int StridedSliceGrad::GetEndMask() const { return this->primitive_->value.AsStridedSliceGrad()->endMask; } +int StridedSliceGrad::GetEllipsisMask() const { return this->primitive_->value.AsStridedSliceGrad()->ellipsisMask; } +int StridedSliceGrad::GetNewAxisMask() const { return this->primitive_->value.AsStridedSliceGrad()->newAxisMask; } +int StridedSliceGrad::GetShrinkAxisMask() const { return this->primitive_->value.AsStridedSliceGrad()->shrinkAxisMask; } +std::vector StridedSliceGrad::GetBegin() const { return this->primitive_->value.AsStridedSliceGrad()->begin; } +std::vector StridedSliceGrad::GetEnd() const { return this->primitive_->value.AsStridedSliceGrad()->end; } +std::vector StridedSliceGrad::GetStride() const { return this->primitive_->value.AsStridedSliceGrad()->stride; } +std::vector StridedSliceGrad::GetIsScale() const { return this->primitive_->value.AsStridedSliceGrad()->isScale; } + +void StridedSliceGrad::SetBeginMask(int begin_mask) { + this->primitive_->value.AsStridedSliceGrad()->beginMask = begin_mask; +} +void StridedSliceGrad::SetEndMask(int end_mask) { this->primitive_->value.AsStridedSliceGrad()->endMask = end_mask; } +void 
StridedSliceGrad::SetEllipsisMask(int ellipsis_mask) { + this->primitive_->value.AsStridedSliceGrad()->ellipsisMask = ellipsis_mask; +} +void StridedSliceGrad::SetNewAxisMask(int new_axis_mask) { + this->primitive_->value.AsStridedSliceGrad()->newAxisMask = new_axis_mask; +} +void StridedSliceGrad::SetShrinkAxisMask(int shrink_axis_mask) { + this->primitive_->value.AsStridedSliceGrad()->shrinkAxisMask = shrink_axis_mask; +} +void StridedSliceGrad::SetBegin(const std::vector &begin) { + this->primitive_->value.AsStridedSliceGrad()->begin = begin; +} +void StridedSliceGrad::SetEnd(const std::vector &end) { this->primitive_->value.AsStridedSliceGrad()->end = end; } +void StridedSliceGrad::SetStride(const std::vector &stride) { + this->primitive_->value.AsStridedSliceGrad()->stride = stride; +} +void StridedSliceGrad::SetIsScale(const std::vector &is_scale) { + this->primitive_->value.AsStridedSliceGrad()->isScale = is_scale; +} + +int StridedSliceGrad::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_StridedSliceGrad; + } + if (this->primitive_->value.type != schema::PrimitiveType_StridedSliceGrad) { + MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::StridedSliceGradT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new StridedSliceGrad failed"; + return RET_ERROR; + } + attr->beginMask = CastToInt(prim.GetAttr("begin_mask")).front(); + attr->endMask = CastToInt(prim.GetAttr("end_mask")).front(); + attr->ellipsisMask = CastToInt(prim.GetAttr("ellipsis_mask")).front(); + attr->newAxisMask = CastToInt(prim.GetAttr("new_axis_mask")).front(); + attr->shrinkAxisMask 
= CastToInt(prim.GetAttr("shrink_axis_mask")).front(); + auto inputNodeFirst = inputs[kAnfPopulaterInputNumOne]; + std::vector beginVec; + GetAttrDataFromInput(inputNodeFirst, &beginVec); + attr->begin = beginVec; + + auto inputNodeSecond = inputs[kAnfPopulaterInputNumTwo]; + std::vector endVec; + GetAttrDataFromInput(inputNodeSecond, &endVec); + attr->end = endVec; + + auto inputNodeThird = inputs[kAnfPopulaterInputNumThree]; + std::vector strideVec; + GetAttrDataFromInput(inputNodeThird, &strideVec); + attr->stride = strideVec; + this->primitive_->value.value = attr; + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + } + return RET_OK; +} + +#else + +int StridedSliceGrad::GetBeginMask() const { return this->primitive_->value_as_StridedSliceGrad()->beginMask(); } +int StridedSliceGrad::GetEndMask() const { return this->primitive_->value_as_StridedSliceGrad()->endMask(); } +int StridedSliceGrad::GetEllipsisMask() const { return this->primitive_->value_as_StridedSliceGrad()->ellipsisMask(); } +int StridedSliceGrad::GetNewAxisMask() const { return this->primitive_->value_as_StridedSliceGrad()->newAxisMask(); } +int StridedSliceGrad::GetShrinkAxisMask() const { + return this->primitive_->value_as_StridedSliceGrad()->shrinkAxisMask(); +} +std::vector StridedSliceGrad::GetBegin() const { + auto fb_vector = this->primitive_->value_as_StridedSliceGrad()->begin(); + return std::vector(fb_vector->begin(), fb_vector->end()); +} +std::vector StridedSliceGrad::GetEnd() const { + auto fb_vector = this->primitive_->value_as_StridedSliceGrad()->end(); + return std::vector(fb_vector->begin(), fb_vector->end()); +} +std::vector StridedSliceGrad::GetStride() const { + auto fb_vector = this->primitive_->value_as_StridedSliceGrad()->stride(); + return std::vector(fb_vector->begin(), fb_vector->end()); +} +std::vector StridedSliceGrad::GetIsScale() const { + auto fb_vector = 
this->primitive_->value_as_StridedSliceGrad()->isScale(); + return std::vector(fb_vector->begin(), fb_vector->end()); +} +int StridedSliceGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_StridedSliceGrad(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_StridedSliceGrad return nullptr"; + return RET_ERROR; + } + std::vector begin; + if (attr->begin() != nullptr) { + for (int i = 0; i < static_cast(attr->begin()->size()); i++) { + begin.push_back(attr->begin()->data()[i]); + } + } + std::vector end; + if (attr->end() != nullptr) { + for (int i = 0; i < static_cast(attr->end()->size()); i++) { + end.push_back(attr->end()->data()[i]); + } + } + std::vector stride; + if (attr->stride() != nullptr) { + for (int i = 0; i < static_cast(attr->stride()->size()); i++) { + stride.push_back(attr->stride()->data()[i]); + } + } + std::vector isScale; + if (attr->isScale() != nullptr) { + for (int i = 0; i < static_cast(attr->isScale()->size()); i++) { + isScale.push_back(attr->isScale()->data()[i]); + } + } + auto val_offset = + schema::CreateStridedSliceGradDirect(*fbb, attr->beginMask(), attr->endMask(), attr->ellipsisMask(), + attr->newAxisMask(), attr->shrinkAxisMask(), &begin, &end, &stride, &isScale); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_StridedSliceGrad, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + +PrimitiveC *StridedSliceGradCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry StridedSliceGradRegistry(schema::PrimitiveType_StridedSliceGrad, StridedSliceGradCreator); +#endif + +namespace { +constexpr size_t kStridedSliceGradOutputNum = 1; +constexpr size_t kStridedSliceGradMultiInputNumMax = 5; +} // namespace + +int StridedSliceGrad::InferShape(std::vector inputs, std::vector outputs) { + 
MS_ASSERT(this->primitive_ != nullptr); + if (outputs.size() != kStridedSliceGradOutputNum) { + MS_LOG(ERROR) << "Invalid output size:" << outputs.size(); + return RET_PARAM_INVALID; + } + if (inputs.size() != kStridedSliceGradMultiInputNumMax) { + MS_LOG(ERROR) << "Invalid input size " << inputs.size(); + return RET_PARAM_INVALID; + } + auto input = inputs.at(0); + outputs.front()->set_data_type(input->data_type()); + outputs.at(0)->set_format(input->format()); + MS_ASSERT(input != nullptr); + auto input_shape = input->shape(); + auto inferflag = infer_flag(); + + in_shape_.clear(); + if (inferflag) { + in_shape_.assign(input_shape.begin(), input_shape.end()); + } + begins_.clear(); + ends_.clear(); + strides_.clear(); + + if (!CheckInputs(inputs)) { + MS_LOG(DEBUG) << "Do infer shape in runtime."; + return RET_INFER_INVALID; + } + + // input order: dy, shapex, begins, ends, strides. + auto begin_tensor = inputs.at(2); + int *begin_data = reinterpret_cast(begin_tensor->MutableData()); + auto end_tensor = inputs.at(3); + int *end_data = reinterpret_cast(end_tensor->MutableData()); + auto stride_tensor = inputs.at(4); + int *stride_data = reinterpret_cast(stride_tensor->MutableData()); + if (begin_data == nullptr || end_data == nullptr || stride_data == nullptr) { + return RET_INFER_ERR; + } + ndim_ = begin_tensor->ElementsNum(); + for (size_t i = 0; i < ndim_; ++i) { + begins_.emplace_back(begin_data[i]); + ends_.emplace_back(end_data[i]); + strides_.emplace_back(stride_data[i]); + } + + // set all mask to original input shape + begins_mask_.resize(ndim_); + ends_mask_.resize(ndim_); + ellipsis_mask_.resize(ndim_); + new_axis_mask_.resize(ndim_); + shrink_axis_mask_.resize(ndim_); + + for (size_t i = 0; i < ndim_; i++) { + begins_mask_.at(i) = static_cast(GetBeginMask()) & (1 << i); + ends_mask_.at(i) = static_cast(GetEndMask()) & (1 << i); + ellipsis_mask_.at(i) = static_cast(GetEllipsisMask()) & (1 << i); + new_axis_mask_.at(i) = static_cast(GetNewAxisMask()) & 
(1 << i); + shrink_axis_mask_.at(i) = static_cast(GetShrinkAxisMask()) & (1 << i); + } + + ApplyNewAxisMask(); + ApplyBeginMask(); + ApplyEndMask(); + ApplyEllipsisMask(); + + if (!inferflag) { + return RET_OK; + } + + auto output_size = inputs.at(1)->shape().at(0); + std::vector output_shape; + MS_ASSERT(inputs.at(1)->MutableData() != nullptr); + for (int i = 0; i < output_size; i++) { + output_shape.push_back(static_cast(inputs.at(1)->MutableData())[i]); + } + outputs.front()->set_shape(output_shape); + + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/strided_slice_grad.h b/mindspore/lite/src/ops/strided_slice_grad.h new file mode 100644 index 00000000000..f0951ccf205 --- /dev/null +++ b/mindspore/lite/src/ops/strided_slice_grad.h @@ -0,0 +1,64 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_LITE_SRC_OPS_STRIDED_SLICE_GRAD_H_ +#define MINDSPORE_LITE_SRC_OPS_STRIDED_SLICE_GRAD_H_ + +#include +#include +#include +#include + +#include "src/ops/strided_slice.h" + +namespace mindspore { +namespace lite { +class StridedSliceGrad : public StridedSlice { + public: + StridedSliceGrad() = default; + ~StridedSliceGrad() = default; +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(StridedSliceGrad, StridedSlice); + explicit StridedSliceGrad(schema::PrimitiveT *primitive) : StridedSlice(primitive) {} + void SetBeginMask(int begin_mask); + void SetEndMask(int end_mask); + void SetEllipsisMask(int ellipsis_mask); + void SetNewAxisMask(int new_axis_mask); + void SetShrinkAxisMask(int shrink_axis_mask); + void SetBegin(const std::vector &begin); + void SetEnd(const std::vector &end); + void SetStride(const std::vector &stride); + void SetIsScale(const std::vector &is_scale); + int UnPackAttr(const Primitive &prim, const std::vector &inputs); +#else + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; +#endif + int InferShape(std::vector inputs_, std::vector outputs_) override; + // bool CheckInputs(std::vector inputs_); + int GetBeginMask() const; + int GetEndMask() const; + int GetEllipsisMask() const; + int GetNewAxisMask() const; + int GetShrinkAxisMask() const; + std::vector GetBegin() const; + std::vector GetEnd() const; + std::vector GetStride() const; + std::vector GetIsScale() const; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_OPS_STRIDED_SLICE_GRAD_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.cc index ed491047b24..7d1f37ad83b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.cc @@ -91,10 +91,12 @@ int 
ConvolutionWinogradCPUKernel::InitWeightBias() { // init bias size_t new_bias_size = oc4 * C4NUM * sizeof(float); - bias_data_ = reinterpret_cast(malloc(new_bias_size)); if (bias_data_ == nullptr) { - MS_LOG(ERROR) << "malloc bias_data_ failed."; - return RET_MEMORY_FAILED; + bias_data_ = reinterpret_cast(malloc(new_bias_size)); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "malloc bias_data_ failed."; + return RET_MEMORY_FAILED; + } } memset(bias_data_, 0, new_bias_size); if (in_tensors_.size() == kInputSize2) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm_fp32.cc index 6f5c505af3d..1df0ee7af1b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm_fp32.cc @@ -91,10 +91,6 @@ int FusedBatchnormCPUKernel::Run() { memcpy(scale_, scale, in_tensors_[1]->Size()); memcpy(offset_, offset, in_tensors_[2]->Size()); - // save for next iteration - memcpy(in_tensors_[3]->MutableData(), save_mean, in_tensors_[3]->Size()); - memcpy(in_tensors_[4]->MutableData(), save_variance, in_tensors_[4]->Size()); - trained_ = true; // trained at least once } auto ret = ParallelLaunch(this->context_->thread_pool_, BatchNormRun, this, op_parameter_->thread_num_); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.cc index 8c6ac3046b0..7ee2f2d3dc6 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.cc @@ -40,17 +40,16 @@ int ApplyMomentumCPUKernel::Execute(int task_id) { size_t stride = UP_DIV(length, thread_count_); size_t count = MSMIN(stride, length - stride * task_id); - size_t start = stride * task_id; size_t end = start + count; if (apply_momentum_param_->use_nesterov_) { - for (size_t i = start; i < end; ++i) { + 
for (size_t i = start; i < end; i++) { accumulate[i] = accumulate[i] * moment + gradient[i]; weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate; } } else { - for (size_t i = start; i < end; ++i) { + for (size_t i = start; i < end; i++) { accumulate[i] = accumulate[i] * moment + gradient[i]; weight[i] -= accumulate[i] * learning_rate; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc index 90a3c8d72e3..6385dcccbbb 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc @@ -18,6 +18,10 @@ #include #include #include + +#include +#include + #include "schema/model_generated.h" #include "src/kernel_registry.h" #include "nnacl/fp32_grad/batch_norm.h" @@ -34,7 +38,8 @@ namespace mindspore::kernel { int BNGradCPUKernel::ReSize() { auto *input_x = in_tensors_.at(1); int channels = input_x->shape().at(kNHWC_C); - set_workspace_size(2 * channels * sizeof(float)); + ws_size_ = 2 * channels; + set_workspace_size(ws_size_ * sizeof(float)); return RET_OK; } @@ -46,7 +51,9 @@ int BNGradCPUKernel::Execute(int task_id) { auto *input_scale = in_tensors_.at(2); auto *input_mean = in_tensors_.at(3); auto *input_var = in_tensors_.at(4); - + auto bn_param = reinterpret_cast(op_parameter_); + int stage = stage_; + int thread_num = thread_num_; float *save_mean = reinterpret_cast(input_mean->MutableData()); float *save_var = reinterpret_cast(input_var->MutableData()); @@ -58,26 +65,57 @@ int BNGradCPUKernel::Execute(int task_id) { int32_t spatial = input_x->Height() * input_x->Width(); float *workspace_temp = static_cast(workspace()); - std::fill(workspace_temp, workspace_temp + workspace_size() / sizeof(*workspace_temp), 0.f); float *dxhat_sum = workspace_temp; float *dxhathat_sum = dxhat_sum + channels; - float *x = reinterpret_cast(input_x->MutableData()); float *yt = 
reinterpret_cast(input_yt->MutableData()); float *scale = reinterpret_cast(input_scale->MutableData()); float *dx = reinterpret_cast(output_dx->MutableData()); float *dbias = reinterpret_cast(output_bias->MutableData()); float *dscale = reinterpret_cast(output_scale->MutableData()); - std::fill(dbias, dbias + channels, 0.f); - std::fill(dscale, dscale + channels, 0.f); - backwardAll(x, yt, save_mean, save_var, scale, batch * spatial, channels, dxhat_sum, dxhathat_sum, dbias, dscale, dx); + int total = spatial * batch; + int stride = UP_DIV(total, thread_num); + int count = MSMIN(stride, total - stride * task_id); + switch (stage) { + case 0: { + for (int job = task_id; job < 4; job += thread_num) { + switch (job) { + case 0: + var2Invar(save_var, input_var->ElementsNum(), bn_param->epsilon_); + break; + case 1: + std::fill(workspace_temp, workspace_temp + ws_size_, 0.f); + break; + case 2: + std::fill(dbias, dbias + channels, 0.f); + break; + case 3: + std::fill(dscale, dscale + channels, 0.f); + break; + } + } + if (thread_num == 1) { + backwardAll(x, yt, save_mean, save_var, scale, total, channels, dxhat_sum, dxhathat_sum, dbias, dscale, dx); + } + break; + } + case 1: { + backwardP1(x, yt, save_mean, save_var, scale, total, channels, dxhat_sum, dxhathat_sum, dbias, dscale); + break; + } + case 2: { + backwardP2(x + task_id * stride * channels, yt + task_id * stride * channels, save_mean, save_var, scale, count, + total, channels, dxhat_sum, dxhathat_sum, dx + task_id * stride * channels); + break; + } + } + return RET_OK; } int BNGradRun(void *cdata, int task_id) { MS_ASSERT(cdata != nullptr); auto bn_kernel = reinterpret_cast(cdata); - auto error_code = bn_kernel->Execute(task_id); if (error_code != RET_OK) { MS_LOG(ERROR) << "BNGradRun error task_id[" << task_id << "] error_code[" << error_code << "]"; @@ -87,15 +125,24 @@ int BNGradRun(void *cdata, int task_id) { } int BNGradCPUKernel::Run() { - auto *input_var = in_tensors_.at(4); - float *save_var = 
reinterpret_cast(input_var->MutableData()); - auto bn_param = reinterpret_cast(op_parameter_); - float eps = bn_param->epsilon_; - var2Invar(save_var, input_var->ElementsNum(), eps); - int error_code = ParallelLaunch(this->context_->thread_pool_, BNGradRun, this, 1); - if (error_code != RET_OK) { - MS_LOG(ERROR) << "BN function error error_code[" << error_code << "]"; - return RET_ERROR; + stage_ = 0; + thread_num_ = context_->thread_num_; + if (thread_num_ == 1) { + int error_code = ParallelLaunch(this->context_->thread_pool_, BNGradRun, this, thread_num_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "BN function error error_code[" << error_code << "]"; + return RET_ERROR; + } + } else { + const std::vector threads = {thread_num_, 1, thread_num_}; + for (size_t stage = 0; stage < threads.size(); stage++) { + stage_ = static_cast(stage); + int error_code = ParallelLaunch(this->context_->thread_pool_, BNGradRun, this, threads.at(stage)); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "BN function error error_code[" << error_code << "]"; + return RET_ERROR; + } + } } return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.h index cc2b57b8cc9..d0e85384a83 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.h @@ -33,6 +33,11 @@ class BNGradCPUKernel : public LiteKernel { int ReSize() override; int Run() override; int Execute(int task_id); + + private: + int thread_num_ = 1; + int stage_ = 0; + size_t ws_size_ = 0; }; } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BN_GRAD_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.cc index 24beda510fd..24082f283ae 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.cc +++ 
b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.cc @@ -54,9 +54,6 @@ int ConvolutionTrainCPUKernel::ReSize() { conv_param_->group_ = (conv_param_->group_ == 0) ? conv_param_->input_channel_ : conv_param_->group_; const int n = conv_param_->output_channel_ * conv_param_->group_; const int k = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ / conv_param_->group_; - ws_size_ = chunk_ * k; - int mat_alloc = MatSizeTotal(chunk_, n, k, 0); - set_workspace_size((ws_size_ + mat_alloc) * sizeof(float)); do_img2col_ = (conv_param_->kernel_h_ == 1) && (conv_param_->kernel_w_ == 1) && (conv_param_->pad_d_ == 0) && (conv_param_->pad_u_ == 0) && (conv_param_->pad_l_ == 0) && (conv_param_->pad_r_ == 0) && @@ -64,6 +61,16 @@ int ConvolutionTrainCPUKernel::ReSize() { (conv_param_->stride_h_ == 1) && (conv_param_->stride_w_ == 1) && (conv_param_->group_ == 1) ? false : true; + do_dw_ = (conv_param_->output_channel_ == conv_param_->group_) && + (conv_param_->input_channel_ == conv_param_->output_channel_) && (conv_param_->dilation_h_ == 1) && + (conv_param_->dilation_w_ == 1) + ? true + : false; + + ws_size_ = chunk_ * conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_; + ws_size_ = do_dw_ ? 
ws_size_ : ws_size_ / conv_param_->group_; + int mat_alloc = MatSizeTotal(chunk_, n, k, 0); + set_workspace_size((ws_size_ + mat_alloc) * sizeof(float)); return RET_OK; } @@ -97,7 +104,25 @@ int ConvolutionTrainCPUKernel::Execute(int task_id) { float *workspace_temp = static_cast(workspace()); float *mat_workspace = workspace_temp + ws_size_; - if (do_img2col_) { + if (do_dw_) { + const int kernel_spatial = k_h * k_w; + for (int i = 0; i < batch; ++i) { + for (int ci = 0; ci < m; ci += chunk_) { + int real_chunk = MSMIN(m - ci, chunk_); + float *mat_a = workspace_temp; + float *im = x_addr + (i * in_ch * in_h * in_w); + RollingIm2ColPackDwUnitFp32(im, conv_param_, mat_a, real_chunk, ci); + for (int j = 0; j < groups; ++j) { + const float *mat_b = w_addr + j * nweights / groups; + float *mat_c = y_addr + (i * groups) * n * m + j * (out_ch / groups) + ci * out_ch; + // float *im = x_addr + i * in_ch * in_h * in_w + j * (in_ch / groups); + // RollingIm2ColPackUnitFp32(im, conv_param_, mat_a, real_chunk, ci); + GemmMatmul(0, 1, real_chunk, n, k, 1, mat_a + (j * kernel_spatial), k * groups, mat_b, k, 0, mat_c, out_ch, + mat_workspace); + } + } + } + } else if (do_img2col_) { for (int i = 0; i < batch; ++i) { for (int j = 0; j < groups; ++j) { for (int ci = 0; ci < m; ci += chunk_) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.h index bfaf9d25ea0..ccb634ea1a2 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.h @@ -37,6 +37,7 @@ class ConvolutionTrainCPUKernel : public LiteKernel { private: int ws_size_ = 0; bool do_img2col_ = true; + bool do_dw_ = false; #ifdef ENABLE_ARM32 const int chunk_ = C4NUM * 2; #else diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.cc index 
d872cceb9fe..8defa03cb93 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.cc @@ -17,6 +17,7 @@ #include "src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h" #include "src/kernel_registry.h" #include "nnacl/pack.h" +#include "nnacl/fp32_grad/convolution_grad_filter.h" #include "nnacl/fp32_grad/pack_ext.h" #include "nnacl/fp32_grad/gemm.h" #include "include/errorcode.h" @@ -51,20 +52,25 @@ int ConvolutionGradFilterCPUKernel::ReSize() { conv_param->output_h_ = dy_tensor->shape()[kNHWC_H]; conv_param->output_w_ = dy_tensor->shape()[kNHWC_W]; - ws_size_ = chunk_ * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_; - - int n = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_; - int k = conv_param->output_channel_ / conv_param->group_; - int thread_num = context_->thread_num_; - mat_alloc_ = MatSizeTotal(k, n, chunk_, 0); - set_workspace_size((ws_size_ + mat_alloc_ + (k * n)) * thread_num * sizeof(float)); - do_img2col_ = (conv_param->kernel_h_ == 1) && (conv_param->kernel_w_ == 1) && (conv_param->pad_d_ == 0) && (conv_param->pad_u_ == 0) && (conv_param->pad_l_ == 0) && (conv_param->pad_r_ == 0) && (conv_param->dilation_h_ == 1) && (conv_param->dilation_w_ == 1) && (conv_param->stride_h_ == 1) && (conv_param->stride_w_ == 1) && (conv_param->group_ == 1) ? false : true; + do_dw_ = (conv_param->output_channel_ == conv_param->group_) && + (conv_param->input_channel_ == conv_param->output_channel_) && (conv_param->dilation_h_ == 1) && + (conv_param->dilation_w_ == 1) + ? true + : false; + + ws_size_ = chunk_ * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; + ws_size_ = do_dw_ ? 
ws_size_ : ws_size_ / conv_param->group_; + int n = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_; + int k = conv_param->output_channel_ / conv_param->group_; + int thread_num = context_->thread_num_; + mat_alloc_ = MatSizeTotal(k, n, chunk_, 0); + set_workspace_size((ws_size_ + mat_alloc_ + (k * n)) * thread_num * sizeof(float)); return RET_OK; } @@ -105,10 +111,38 @@ int ConvolutionGradFilterCPUKernel::Execute(int task_id) { int start = stride * task_id; int end = start + count; - if (do_img2col_) { + if (do_dw_) { +#ifdef ENABLE_ARM + stride = UP_DIV(k_h * k_w, thread_num); + count = MSMIN(stride, k_h * k_w - stride * task_id); + start = stride * task_id; + ConvDwFilterGrad(x_addr, dy_addr, dw_addr, start, count, conv_param); +#else + stride = UP_DIV(groups, thread_num); + count = MSMIN(stride, groups - stride * task_id); + start = stride * task_id; + end = start + count; + + const int kernel_spatial = k_h * k_w; + for (int i = 0; i < batch; ++i) { + for (int ci = 0; ci < m; ci += chunk_) { + int real_chunk = MSMIN(m - ci, chunk_); + float *mat_b = workspace_temp + task_id * ws_size_; + float *im = x_addr + (i * in_ch * in_h * in_w); + RollingIm2ColPackDwUnitFp32(im, conv_param, mat_b, real_chunk, ci); + for (int j = start; j < end; ++j) { + float *mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups) + ci * out_ch; + float *mat_c = dw_addr + j * nweights / groups; + GemmMatmul(1, 0, k, n, real_chunk, 1, mat_a, out_ch, mat_b + (j * kernel_spatial), n * groups, 1, mat_c, n, + mat_workspace); + } + } + } +#endif + } else if (do_img2col_) { for (int i = start; i < end; ++i) { - for (int j = 0; j < groups; ++j) { - for (int ci = 0; ci < m; ci += chunk_) { + for (int ci = 0; ci < m; ci += chunk_) { + for (int j = 0; j < groups; ++j) { int real_chunk = MSMIN(m - ci, chunk_); float *mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups) + ci * out_ch; float *mat_b = workspace_temp + task_id * ws_size_; 
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h index bf1696ab655..b7cd2f90949 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h @@ -38,6 +38,7 @@ class ConvolutionGradFilterCPUKernel : public LiteKernel { private: size_t ws_size_ = 0; bool do_img2col_ = true; + bool do_dw_ = false; std::mutex lock_; size_t mat_alloc_ = 0; #ifdef ENABLE_ARM32 diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc index 13cfbfd6d59..8011384e554 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc @@ -66,13 +66,20 @@ int PoolingGradCPUKernel::Execute(int task_id) { auto input_ptr = reinterpret_cast(in_tensors_.at(0)->MutableData()); auto output_ptr = reinterpret_cast(out_tensors_.at(0)->MutableData()); + int stride = UP_DIV(pool_param->output_batch_, thread_num_); + int count = MSMIN(stride, pool_param->output_batch_ - stride * task_id); + int in_batch_size = pool_param->input_h_ * pool_param->input_w_ * pool_param->input_channel_; + int out_batch_size = pool_param->output_h_ * pool_param->output_w_ * pool_param->input_channel_; + std::fill(output_ptr + task_id * stride * in_batch_size, output_ptr + ((task_id * stride) + count) * in_batch_size, + 0.f); if (pool_param->pool_mode_ == PoolMode_MaxPool) { - auto dx_ptr = reinterpret_cast(in_tensors_.at(1)->MutableData()); auto dy_ptr = reinterpret_cast(in_tensors_.at(2)->MutableData()); - MaxPoolingGrad(input_ptr, dx_ptr, dy_ptr, output_ptr, pool_param, task_id); + MaxPoolingGrad(input_ptr + task_id * stride * in_batch_size, dy_ptr + task_id * stride * out_batch_size, + output_ptr + task_id * stride * in_batch_size, count, 
pool_param); } else { input_ptr = reinterpret_cast(in_tensors_.at(2)->MutableData()); - AvgPoolingGrad(input_ptr, output_ptr, pool_param, task_id); + AvgPoolingGrad(input_ptr + task_id * stride * out_batch_size, output_ptr + task_id * stride * in_batch_size, count, + pool_param); } return RET_OK; } @@ -89,7 +96,8 @@ int PoolingGradImpl(void *cdata, int task_id) { } int PoolingGradCPUKernel::Run() { - int error_code = ParallelLaunch(this->context_->thread_pool_, PoolingGradImpl, this, 1); + thread_num_ = context_->thread_num_; + int error_code = ParallelLaunch(this->context_->thread_pool_, PoolingGradImpl, this, thread_num_); if (error_code != RET_OK) { MS_LOG(ERROR) << "pooling error error_code[" << error_code << "]"; return RET_ERROR; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.h index 43f6ad79ec4..43edc88c399 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.h @@ -40,6 +40,7 @@ class PoolingGradCPUKernel : public LiteKernel { int Execute(int task_id); private: + int thread_num_ = 1; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/strided_slice_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/strided_slice_grad.cc new file mode 100644 index 00000000000..6da2c485e40 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/strided_slice_grad.cc @@ -0,0 +1,150 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32_grad/strided_slice_grad.h" +#include +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "nnacl/fp32_grad/strided_slice_grad.h" +#include "src/ops/populate/strided_slice_populate.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_StridedSliceGrad; + +namespace mindspore::kernel { + +int StridedSliceGradCPUKernel::Init() { + if (!InferShapeDone()) { + return RET_OK; + } + param_ = reinterpret_cast(op_parameter_); + auto input = in_tensors_.at(0); + MS_ASSERT(input); + switch (input->data_type()) { + case kNumberTypeFloat32: + param_->data_type = kDataTypeFloat; + break; + default: + MS_LOG(ERROR) << "Not supported data type: " << input->data_type(); + return RET_ERROR; + } + FillEmptyDims(); + FillOutputDim(); + return ReSize(); +} + +void StridedSliceGradCPUKernel::FillEmptyDims() { + int32_t begins[DIMENSION_7D]; + int32_t ends[DIMENSION_7D]; + int32_t strides[DIMENSION_7D]; + int32_t input_shape[DIMENSION_7D]; + int32_t i; + for (i = 0; i < param_->num_axes_; ++i) { + begins[i] = param_->begins_[i]; + ends[i] = MSMIN(param_->ends_[i], param_->in_shape_[i]); + strides[i] = param_->strides_[i]; + input_shape[i] = param_->in_shape_[i]; + } + for (i = param_->num_axes_; i < param_->in_shape_length_; ++i) { + input_shape[i] = param_->in_shape_[i]; + 
begins[i] = 0; + ends[i] = param_->in_shape_[i]; + strides[i] = 1; + } + + int32_t real_index = param_->in_shape_length_ - 1; + for (i = DIMENSION_7D - 1; i >= 0; --i) { + if (real_index >= 0) { + param_->begins_[i] = begins[real_index]; + param_->ends_[i] = ends[real_index]; + param_->strides_[i] = strides[real_index]; + param_->in_shape_[i] = input_shape[real_index--]; + } else { + param_->begins_[i] = 0; + param_->ends_[i] = 1; + param_->strides_[i] = 1; + param_->in_shape_[i] = 1; + } + } + param_->num_axes_ = DIMENSION_7D; + param_->in_shape_length_ = DIMENSION_7D; + + for (i = 0; i < DIMENSION_7D; ++i) { + if (param_->begins_[i] < 0) { + param_->begins_[i] += param_->in_shape_[i]; + } + if (param_->ends_[i] < 0) { + param_->ends_[i] += param_->in_shape_[i]; + } + } +} + +void StridedSliceGradCPUKernel::FillOutputDim() { + auto output = out_tensors_.at(0); + size_t out_size = output->shape().size(); + for (size_t i = 0; i < DIMENSION_7D; i++) { + if (i < out_size) { + output_shape_.push_back(output->shape()[i]); + } else { + output_shape_.insert(output_shape_.begin(), 1); + } + } +} + +int StridedSliceGradCPUKernel::ReSize() { return RET_OK; } + +int StridedSliceGradImpl(void *cdata, int task_id) { + MS_ASSERT(cdata != nullptr); + auto slice = reinterpret_cast(cdata); + auto error_code = slice->Execute(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "StridedSliceGrad Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int StridedSliceGradCPUKernel::Run() { + int error_code = ParallelLaunch(this->context_->thread_pool_, StridedSliceGradImpl, this, 1); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Strided slice error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int StridedSliceGradCPUKernel::Execute(int task_id) { + auto input = in_tensors_.at(0); + auto output = out_tensors_.at(0); + MS_ASSERT(output); + int *po = output_shape_.data(); + auto 
ret = DoStridedSliceGrad(reinterpret_cast(input->MutableData()), + reinterpret_cast(output->MutableData()), po, param_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "StridedSliceGrad error error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_StridedSliceGrad, LiteKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/strided_slice_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/strided_slice_grad.h new file mode 100644 index 00000000000..b6641695c1c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/strided_slice_grad.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_STRIDED_SLICE_GRAD_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_STRIDED_SLICE_GRAD_H_ + +#include +#include "nnacl/fp32_grad/strided_slice_grad.h" +#include "src/lite_kernel.h" + +namespace mindspore::kernel { +class StridedSliceGradCPUKernel : public LiteKernel { + public: + StridedSliceGradCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) { + param_ = reinterpret_cast(parameter); + } + ~StridedSliceGradCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int Execute(int task_id); + + private: + void FillEmptyDims(); + void FillOutputDim(); + void ParseMasks(); + + StridedSliceParameter *param_; + std::vector output_shape_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_STRIDED_SLICE_GRAD_H_ diff --git a/mindspore/lite/src/train/train_populate_parameter.cc b/mindspore/lite/src/train/train_populate_parameter.cc index 333e76a8dd2..36216fd6905 100644 --- a/mindspore/lite/src/train/train_populate_parameter.cc +++ b/mindspore/lite/src/train/train_populate_parameter.cc @@ -49,6 +49,7 @@ #include "src/ops/smooth_l1_loss_grad.h" #include "nnacl/fp32_grad/smooth_l1_loss.h" #include "src/ops/arithmetic_grad.h" +#include "src/ops/populate/strided_slice_populate.h" namespace mindspore::kernel { OpParameter *DefaultPopulateParameter(const mindspore::lite::PrimitiveC *primitive) { @@ -569,6 +570,9 @@ void PopulateTrainParameters() { DefaultPopulateParameter); lite::Registry SigmoidCrossEntropyWithLogitsGradRegistry(schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad, DefaultPopulateParameter); + lite::Registry FlattenGradParameterRegistry(schema::PrimitiveType_FlattenGrad, DefaultPopulateParameter); + lite::Registry 
StridedSliceGradParameterRegistry(schema::PrimitiveType_StridedSliceGrad, + mindspore::lite::PopulateStridedSliceParameter); } } // namespace mindspore::kernel diff --git a/mindspore/lite/src/train/transfer_session.h b/mindspore/lite/src/train/transfer_session.h new file mode 100644 index 00000000000..004eba421e2 --- /dev/null +++ b/mindspore/lite/src/train/transfer_session.h @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_LITE_SRC_TRAIN_TRANSFER_SESSION_H_ +#define MINDSPORE_LITE_SRC_TRAIN_TRANSFER_SESSION_H_ +#include +#include +#include +#include +#include "src/ops/primitive_c.h" +#include "include/train_session.h" +#include "src/train/train_model.h" +#include "src/lite_session.h" +#include "src/train/train_session.h" + +/* + Inheritance Diagram + + +-------------------------------+ + | session::LiteSession | + +--------+------------+---------+ + / \ + +-----------------+-----+ +-------+------------+ + | session::TrainSession | | lite::LiteSession | + +-----------------+-----+ +-------+------------+ + \ / + +--------+------------+---------+ + | lite::TrainSession | + +-------------------------------+ + | + +--------+------------+---------+ + | lite::TrasferSession | + +-------------------------------+ +*/ + +namespace mindspore { +namespace lite { + +class TransferSession : public lite::TrainSession { + public: + TransferSession(); + explicit TransferSession(lite::LiteSession *backend_session); + ~TransferSession(); + + int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override; + + void BindThread(bool if_bind) override; + std::vector GetInputs() const override { return lite::LiteSession::GetInputs(); } + mindspore::tensor::MSTensor *GetInputsByTensorName(const std::string &tensor_name) const override { + return lite::LiteSession::GetInputsByTensorName(tensor_name); + } + + protected: + lite::LiteSession *backend_session_; + + private: +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_TRAIN_TRANSFER_SESSION_H_ diff --git a/mindspore/lite/test/models_ms_train.cfg b/mindspore/lite/test/models_ms_train.cfg index b17284a7d75..329bc23d922 100644 --- a/mindspore/lite/test/models_ms_train.cfg +++ b/mindspore/lite/test/models_ms_train.cfg @@ -1,13 +1,15 @@ mini_alexnet -# mobilenetv1 +mobilenetv1 mobilenetv2 mobilenetv3 lenet effnet -# effnet_tune -# lenetv1 -# resnet -# googlenet 
+effnet_tune +resnet +googlenet # densenet +# shufflenetv2 +# nin # one_net +# lenetv1 #LAST \ No newline at end of file diff --git a/mindspore/lite/test/run_net_export.sh b/mindspore/lite/test/run_net_export.sh new file mode 100755 index 00000000000..1ff7997cee7 --- /dev/null +++ b/mindspore/lite/test/run_net_export.sh @@ -0,0 +1,82 @@ +#!/bin/bash + +# Print start msg after run testcase +function MS_PRINT_TESTCASE_END_MSG() { + echo -e "-----------------------------------------------------------------------------------------------------------------------------------" +} + +function Print_Result() { + MS_PRINT_TESTCASE_END_MSG + while read line; do + arr=("${line}") + printf "%-15s %-20s %-90s %-7s\n" ${arr[0]} ${arr[1]} ${arr[2]} ${arr[3]} + done < $1 + MS_PRINT_TESTCASE_END_MSG +} + +basepath=$(pwd) +echo ${basepath} + +# Example:run_net_export.sh -m /home/emir/Work/TestingEnv/train_models +epoch_num=1 +while getopts "m:t:" opt; do + case ${opt} in + m) + + models_path=${OPTARG}"/models_train" + echo "models_path is ${OPTARG}" + ;; + t) + epoch_num=${OPTARG} + echo "train epoch num is ${OPTARG}" + ;; + ?) + echo "unknown para" + exit 1;; + esac +done + + +# Set models config filepath +models_mindspore_train_config=${basepath}/models_ms_train.cfg + +logs_path=${basepath}/logs_train +rm -rf ${logs_path} +mkdir -p ${logs_path} + +docker_image=mindspore/mindspore-gpu:1.1.0 +# Export models +echo "Start Exporting models ..." 
+# Set log files +export_log_file=${logs_path}/export_log.txt +echo ' ' > ${export_log_file} + +export_result_file=${logs_path}/export_result.txt +echo ' ' > ${export_result_file} + +# Run export according to config file +cd $models_path || exit 1 +if [[ -z "${CLOUD_MODEL_ZOO}" ]]; then + echo "CLOUD_MODEL_ZOO is not defined - exiting export models" + exit 1 +fi + +# Export mindspore train models: +while read line; do + model_name=${line} + if [[ $model_name == \#* ]]; then + continue + fi + echo ${model_name}'_train_export.py' >> "${export_log_file}" + echo 'exporting' ${model_name} + echo 'docker run --user '"$(id -u):$(id -g)"' --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true '${docker_image}' python '${models_path}'/'${model_name}'_train_export.py' >> "${export_log_file}" + docker run --user "$(id -u):$(id -g)" --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true "${docker_image}" python ${models_path}'/'${model_name}_train_export.py "${epoch_num}" + if [ $? 
= 0 ]; then + export_result='export mindspore '${model_name}'_train_export pass';echo ${export_result} >> ${export_result_file} + else + export_result='export mindspore '${model_name}'_train_export failed';echo ${export_result} >> ${export_result_file} + fi +done < ${models_mindspore_train_config} + +Print_Result ${export_result_file} + diff --git a/mindspore/lite/test/run_net_train.sh b/mindspore/lite/test/run_net_train.sh index c161801a5e1..5d443bf365d 100755 --- a/mindspore/lite/test/run_net_train.sh +++ b/mindspore/lite/test/run_net_train.sh @@ -1,7 +1,7 @@ #!/bin/bash # Run Export on x86 platform and create output test files: -docker_image=mindspore_dev:8 +docker_image= function Run_Export(){ cd $models_path || exit 1 if [[ -z "${CLOUD_MODEL_ZOO}" ]]; then @@ -16,8 +16,13 @@ function Run_Export(){ fi echo ${model_name}'_train_export.py' >> "${export_log_file}" echo 'exporting' ${model_name} - echo 'docker run --user '"$(id -u):$(id -g)"' --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true '${docker_image}' python '${models_path}'/'${model_name}'_train_export.py' >> "${export_log_file}" - docker run --user "$(id -u):$(id -g)" --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true "${docker_image}" python ${models_path}'/'${model_name}_train_export.py "${epoch_num}" + if [ -n "$docker_image" ]; then + echo 'docker run --user '"$(id -u):$(id -g)"' --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true '${docker_image}' python '${models_path}'/'${model_name}'_train_export.py' >> "${export_log_file}" + docker run --user "$(id -u):$(id -g)" --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true "${docker_image}" python 
${models_path}'/'${model_name}_train_export.py "${epoch_num}" + else + echo 'CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} python '${models_path}'/'${model_name}'_train_export.py' >> "${export_log_file}" + CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} python ${models_path}'/'${model_name}_train_export.py "${epoch_num}" + fi if [ $? = 0 ]; then export_result='export mindspore '${model_name}'_train_export pass';echo ${export_result} >> ${export_result_file} else @@ -28,7 +33,7 @@ function Run_Export(){ # Run converter on x86 platform: function Run_Converter() { - # Unzip x86 runtime and convertor + # Unzip x86 runtime and converter cd ${x86_path} || exit 1 tar -zxf mindspore-lite-${version}-train-linux-x64.tar.gz || exit 1 @@ -189,7 +194,7 @@ ENDM if [ $? = 0 ]; then run_result=$1': '${model_name}'_train pass'; echo ${run_result} >> ${run_benchmark_train_result_file} else - run_result=$1': '${model_name}'_train failed'; echo ${run_result} >> ${run_benchmark_train_result_file}; return 1 + run_result=$1': '${model_name}'_train failed'; echo ${run_result} >> ${run_benchmark_train_result_file}; fi done < ${models_mindspore_train_config} } @@ -222,16 +227,15 @@ echo ${basepath} # Example:run_benchmark_train.sh -r /home/emir/Work/TestingEnv/release -m /home/emir/Work/TestingEnv/train_models -i /home/emir/Work/TestingEnv/train_io -d "8KE5T19620002408" # For running on arm64, use -t to set platform tools path (for using adb commands) epoch_num=1 -threads=1 +threads=2 train_io_path="" -while getopts "r:m:d:i:e:vt:q:" opt; do +while getopts "r:m:d:i:e:vt:q:D" opt; do case ${opt} in r) release_path=${OPTARG} echo "release_path is ${OPTARG}" ;; m) - models_path=${OPTARG}"/models_train" echo "models_path is ${OPTARG}" ;; @@ -244,8 +248,9 @@ while getopts "r:m:d:i:e:vt:q:" opt; do echo "device_id is ${OPTARG}" ;; e) - enable_export=${OPTARG} - echo "enable_export = ${OPTARG}" + enable_export=1 + docker_image=${OPTARG} + echo "enable_export = 1, docker_image = ${OPTARG}" ;; v) run_valgrind="valgrind 
--log-file=valgrind.log " @@ -404,27 +409,27 @@ function Print_Benchmark_Result() { done < ${run_benchmark_train_result_file} MS_PRINT_TESTCASE_END_MSG } - +result=0 # Check benchmark_train result and return value if [[ ${Run_x86_status} != 0 ]];then echo "Run_x86 failed" cat ${run_x86_log_file} - exit 1 + result=1 fi if [[ ${Run_arm64_status} != 0 ]];then echo "Run_arm64 failed" cat ${run_arm64_log_file} - exit 1 + result=1 fi if [[ ${Run_arm32_status} != 0 ]];then echo "Run_arm32 failed" cat ${run_arm32_log_file} - exit 1 + result=1 fi echo "Test ended - Results:" Print_Benchmark_Result echo "Test run Time:" $DIFF -exit 0 +exit ${result} diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc index 9d939c71b6a..0af6e5a66c9 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc @@ -79,15 +79,18 @@ TEST_F(TestPoolingGradFp32, AvgPoolingGradFp32) { auto output_data = new float[output_data_size]; ASSERT_NE(output_data, nullptr); + // warm up loop for (int i = 0; i < 3; i++) { - AvgPoolingGrad(input_data, output_data, pooling_param, 1); + std::fill(output_data, output_data + output_data_size, 0.f); + AvgPoolingGrad(input_data, output_data, pooling_param->output_batch_, pooling_param); } int loop_count = 100; auto time_start = mindspore::lite::GetTimeUs(); for (int i = 0; i < loop_count; i++) { - AvgPoolingGrad(input_data, output_data, pooling_param, 1); + std::fill(output_data, output_data + output_data_size, 0.f); + AvgPoolingGrad(input_data, output_data, pooling_param->output_batch_, pooling_param); } auto time_end = mindspore::lite::GetTimeUs(); auto cost = time_end - time_start; @@ -407,18 +410,21 @@ TEST_F(TestPoolingGradFp32, MaxPoolingGradFp32) { std::string dx_path = 
"./test_data/pooling/maxpoolgradfp32_1_dx_1_28_28_3.bin"; auto dx_data = reinterpret_cast(mindspore::lite::ReadFile(dx_path.c_str(), &input_size)); ASSERT_NE(dx_data, nullptr); - + int in_batch_size = + pooling_param->input_h_ * pooling_param->input_w_ * pooling_param->input_channel_ * pooling_param->input_batch_; auto output_data = new float[output_data_size]; ASSERT_NE(output_data, nullptr); // warm up loop for (int i = 0; i < 3; i++) { - MaxPoolingGrad(in_data, dx_data, dy_data, output_data, pooling_param, 1); + std::fill(output_data, output_data + in_batch_size, 0.f); + MaxPoolingGrad(in_data, dy_data, output_data, pooling_param->output_batch_, pooling_param); } int loop_count = 100; auto time_start = mindspore::lite::GetTimeUs(); for (int i = 0; i < loop_count; i++) { - MaxPoolingGrad(in_data, dx_data, dy_data, output_data, pooling_param, 1); + std::fill(output_data, output_data + in_batch_size, 0.f); + MaxPoolingGrad(in_data, dy_data, output_data, pooling_param->output_batch_, pooling_param); } auto time_end = mindspore::lite::GetTimeUs(); auto cost = time_end - time_start; diff --git a/mindspore/lite/tools/benchmark_train/net_train.cc b/mindspore/lite/tools/benchmark_train/net_train.cc index 2a70bb37186..b7e037e5942 100644 --- a/mindspore/lite/tools/benchmark_train/net_train.cc +++ b/mindspore/lite/tools/benchmark_train/net_train.cc @@ -135,7 +135,6 @@ int NetTrain::ReadCalibData() { MS_LOG(INFO) << "Start reading calibData file"; std::string tensor_name; - while (!in_file.eof()) { getline(in_file, line); std::stringstream string_line1(line); diff --git a/mindspore/lite/tools/benchmark_train/net_train.h b/mindspore/lite/tools/benchmark_train/net_train.h index f724fb87dbe..49c0f4daf3e 100644 --- a/mindspore/lite/tools/benchmark_train/net_train.h +++ b/mindspore/lite/tools/benchmark_train/net_train.h @@ -79,7 +79,7 @@ class MS_API NetTrainFlags : public virtual FlagParser { std::vector input_data_list_; DataType in_data_type_; std::string in_data_type_in_ = 
"bin"; - int cpu_bind_mode_ = 0; + int cpu_bind_mode_ = 1; // MarkPerformance int num_threads_ = 1; int warm_up_loop_count_ = 0; diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc index 26432bc8c2a..4739a821734 100644 --- a/mindspore/lite/tools/common/node_util.cc +++ b/mindspore/lite/tools/common/node_util.cc @@ -32,7 +32,6 @@ static const std::vector nhwcOpList = { schema::PrimitiveType_PoolingGrad, schema::PrimitiveType_BiasGrad, schema::PrimitiveType_BNGrad, - schema::PrimitiveType_ActivationGrad, schema::PrimitiveType_ApplyMomentum, schema::PrimitiveType_Sgd, schema::PrimitiveType_Adam, @@ -219,6 +218,26 @@ STATUS NodeUtils::ConvertDims(mindspore::schema::Format src_format, const std::v return RET_OK; } +static bool IsKCHWSource(kTransFilterType type) { + return (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC || type == kKCHW2CKHW); +} + +static bool IsCKHWSource(kTransFilterType type) { + return (type == kCKHW2HWCK || type == kCKHW2HWKC || type == kCKHW2KHWC); +} + +static bool IsHWCKSource(kTransFilterType type) { return (type == kHWCK2KCHW || type == kHWCK2CKHW); } + +static bool IsHWKCSource(kTransFilterType type) { return (type == kHWKC2KCHW || type == kHWKC2CKHW); } + +static bool IsNHWCSource(kTransFilterType type) { + return (type == kNHWC2KCHW || type == kNHWC2HWCK || type == kNHWC2CKHW); +} + +static bool IsCHWKSource(kTransFilterType type) { return (type == kCHWK2HWCK || type == kCHWK2KHWC); } + +static bool IsKHWCSource(kTransFilterType type) { return (type == kKHWC2HWCK || type == kKHWC2CHWK); } + STATUS GetFilterDim(const std::vector &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC, int32_t *filterH, int32_t *filterW) { if (filterK == nullptr || filterC == nullptr || filterH == nullptr || filterW == nullptr) { @@ -226,37 +245,37 @@ STATUS GetFilterDim(const std::vector &oriDims, kTransFilterType type, return RET_NULL_PTR; } MS_ASSERT(oriDims.size() == 4); - if (type 
== kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC || type == kKCHW2CKHW) { + if (IsKCHWSource(type)) { *filterK = oriDims.at(KCHW_K); *filterC = oriDims.at(KCHW_C); *filterH = oriDims.at(KCHW_H); *filterW = oriDims.at(KCHW_W); - } else if (type == kCKHW2HWCK || type == kCKHW2HWKC || type == kCKHW2KHWC) { + } else if (IsCKHWSource(type)) { *filterC = oriDims.at(CKHW_C); *filterK = oriDims.at(CKHW_K); *filterH = oriDims.at(CKHW_H); *filterW = oriDims.at(CKHW_W); - } else if (type == kHWCK2KCHW || type == kHWCK2CKHW) { + } else if (IsHWCKSource(type)) { *filterH = oriDims.at(HWCK_H); *filterW = oriDims.at(HWCK_W); *filterC = oriDims.at(HWCK_C); *filterK = oriDims.at(HWCK_K); - } else if (type == kHWKC2KCHW || type == kHWKC2CKHW) { + } else if (IsHWKCSource(type)) { *filterH = oriDims.at(HWKC_H); *filterW = oriDims.at(HWKC_W); *filterK = oriDims.at(HWKC_K); *filterC = oriDims.at(HWKC_C); - } else if (type == kNHWC2KCHW || type == kNHWC2HWCK || type == kNHWC2CKHW) { + } else if (IsNHWCSource(type)) { *filterK = oriDims.at(NHWC_N); *filterH = oriDims.at(NHWC_H); *filterW = oriDims.at(NHWC_W); *filterC = oriDims.at(NHWC_C); - } else if (type == kCHWK2HWCK || type == kCHWK2KHWC) { + } else if (IsCHWKSource(type)) { *filterC = oriDims.at(CHWK_C); *filterH = oriDims.at(CHWK_H); *filterW = oriDims.at(CHWK_W); *filterK = oriDims.at(CHWK_K); - } else if (type == kKHWC2HWCK || type == kKHWC2CHWK) { + } else if (IsKHWCSource(type)) { *filterK = oriDims.at(KHWC_K); *filterH = oriDims.at(KHWC_H); *filterW = oriDims.at(KHWC_W); @@ -290,6 +309,37 @@ STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filt return RET_OK; } +static int Convert2KHWC(int srcFormat) { + if (srcFormat == schema::Format::Format_KCHW) return kKCHW2KHWC; + if (srcFormat == schema::Format::Format_CKHW) return kCKHW2KHWC; + if (srcFormat == schema::Format::Format_CHWK) return kCHWK2KHWC; + return -1; +} + +static int Convert2HWCK(int srcFormat) { + if (srcFormat == 
schema::Format::Format_KCHW) return kKCHW2HWCK; + if (srcFormat == schema::Format::Format_KHWC) return kKHWC2HWCK; + if (srcFormat == schema::Format::Format_CKHW) return kCKHW2HWCK; + if (srcFormat == schema::Format::Format_CHWK) return kCHWK2HWCK; + return -1; +} + +static int Convert2KCHW(int srcFormat) { + if (srcFormat == schema::Format::Format_HWCK) return kHWCK2KCHW; + if (srcFormat == schema::Format::Format_HWKC) return kHWKC2KCHW; + if (srcFormat == schema::Format::Format_KHWC) return kKHWC2KCHW; + if (srcFormat == schema::Format::Format_CKHW) return kCKHW2KCHW; + if (srcFormat == schema::Format::Format_CHWK) return kCHWK2KCHW; + return -1; +} + +static int Convert2CKHW(int srcFormat) { + if (srcFormat == schema::Format::Format_HWCK) return kHWCK2CKHW; + if (srcFormat == schema::Format::Format_HWKC) return kHWKC2CKHW; + if (srcFormat == schema::Format::Format_KCHW) return kKCHW2CKHW; + return -1; +} + STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) { if (tensor == nullptr) { MS_LOG(ERROR) << "tensor is null"; @@ -303,231 +353,40 @@ STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) { auto srcFormat = tensor->format; auto dataType = tensor->dataType; STATUS status; + int convert = -1; + + if (dstFormat == srcFormat) return RET_OK; + switch (dstFormat) { - case schema::Format::Format_KHWC: { - switch (srcFormat) { - case schema::Format::Format_KCHW: - if (dataType == kNumberTypeFloat32) { - status = TransFilterFormat(tensor, kKCHW2KHWC); - } else if (dataType == kNumberTypeUInt8) { - status = TransFilterFormat(tensor, kKCHW2KHWC); - } else if (dataType == kNumberTypeInt8) { - status = TransFilterFormat(tensor, kKCHW2KHWC); - } else { - MS_LOG(ERROR) << "Unsupported dataType: " << dataType; - return RET_ERROR; - } - break; - case schema::Format::Format_CKHW: - if (dataType == kNumberTypeFloat32) { - status = TransFilterFormat(tensor, kCKHW2KHWC); - } else if (dataType == kNumberTypeUInt8) { - status = 
TransFilterFormat(tensor, kCKHW2KHWC); - } else if (dataType == kNumberTypeInt8) { - status = TransFilterFormat(tensor, kCKHW2KHWC); - } else { - MS_LOG(ERROR) << "Unsupported dataType: " << dataType; - return RET_ERROR; - } - break; - case schema::Format::Format_CHWK: - if (dataType == kNumberTypeFloat32) { - status = TransFilterFormat(tensor, kCHWK2KHWC); - } else if (dataType == kNumberTypeUInt8) { - status = TransFilterFormat(tensor, kCHWK2KHWC); - } else if (dataType == kNumberTypeInt8) { - status = TransFilterFormat(tensor, kCHWK2KHWC); - } else { - MS_LOG(ERROR) << "Unsupported dataType: " << dataType; - return RET_ERROR; - } - break; - case schema::Format::Format_KHWC: - return RET_OK; - default: - MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to " - << EnumNameFormat(dstFormat); - return RET_ERROR; - } - } break; - case schema::Format::Format_HWCK: { - switch (srcFormat) { - case schema::Format::Format_KCHW: - if (dataType == kNumberTypeFloat32) { - status = TransFilterFormat(tensor, kKCHW2HWCK); - } else if (dataType == kNumberTypeUInt8) { - status = TransFilterFormat(tensor, kKCHW2HWCK); - } else if (dataType == kNumberTypeInt8) { - status = TransFilterFormat(tensor, kKCHW2HWCK); - } else { - MS_LOG(ERROR) << "Unsupported dataType: " << dataType; - return RET_ERROR; - } - break; - case schema::Format::Format_KHWC: - if (dataType == kNumberTypeFloat32) { - status = TransFilterFormat(tensor, kKHWC2HWCK); - } else if (dataType == kNumberTypeUInt8) { - status = TransFilterFormat(tensor, kKHWC2HWCK); - } else if (dataType == kNumberTypeInt8) { - status = TransFilterFormat(tensor, kKHWC2HWCK); - } else { - MS_LOG(ERROR) << "Unsupported dataType: " << dataType; - return RET_ERROR; - } - break; - case schema::Format::Format_CKHW: - if (dataType == kNumberTypeFloat32) { - status = TransFilterFormat(tensor, kCKHW2HWCK); - } else if (dataType == kNumberTypeUInt8) { - status = TransFilterFormat(tensor, kCKHW2HWCK); - } else if 
(dataType == kNumberTypeInt8) { - status = TransFilterFormat(tensor, kCKHW2HWCK); - } else { - MS_LOG(ERROR) << "Unsupported dataType: " << dataType; - return RET_ERROR; - } - break; - case schema::Format::Format_CHWK: - if (dataType == kNumberTypeFloat32) { - status = TransFilterFormat(tensor, kCHWK2HWCK); - } else if (dataType == kNumberTypeUInt8) { - status = TransFilterFormat(tensor, kCHWK2HWCK); - } else if (dataType == kNumberTypeInt8) { - status = TransFilterFormat(tensor, kCHWK2HWCK); - } else { - MS_LOG(ERROR) << "Unsupported dataType: " << dataType; - return RET_ERROR; - } - break; - case schema::Format::Format_HWCK: - return RET_OK; - default: - MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to " - << EnumNameFormat(dstFormat); - return RET_ERROR; - } - } break; - case schema::Format::Format_KCHW: { - switch (srcFormat) { - case schema::Format::Format_KCHW: - return RET_OK; - case schema::Format::Format_HWCK: - if (dataType == kNumberTypeFloat32) { - status = TransFilterFormat(tensor, kHWCK2KCHW); - } else if (dataType == kNumberTypeUInt8) { - status = TransFilterFormat(tensor, kHWCK2KCHW); - } else if (dataType == kNumberTypeInt8) { - status = TransFilterFormat(tensor, kHWCK2KCHW); - } else { - MS_LOG(ERROR) << "Unsupported dataType: " << dataType; - return RET_ERROR; - } - break; - case schema::Format::Format_HWKC: - if (dataType == kNumberTypeFloat32) { - status = TransFilterFormat(tensor, kHWKC2KCHW); - } else if (dataType == kNumberTypeUInt8) { - status = TransFilterFormat(tensor, kHWKC2KCHW); - } else if (dataType == kNumberTypeInt8) { - status = TransFilterFormat(tensor, kHWKC2KCHW); - } else { - MS_LOG(ERROR) << "Unsupported dataType: " << dataType; - return RET_ERROR; - } - break; - case schema::Format::Format_KHWC: - if (dataType == kNumberTypeFloat32) { - status = TransFilterFormat(tensor, kKHWC2KCHW); - } else if (dataType == kNumberTypeUInt8) { - status = TransFilterFormat(tensor, kKHWC2KCHW); - } else if 
(dataType == kNumberTypeInt8) { - status = TransFilterFormat(tensor, kKHWC2KCHW); - } else { - MS_LOG(ERROR) << "Unsupported dataType: " << dataType; - return RET_ERROR; - } - break; - case schema::Format::Format_CKHW: - if (dataType == kNumberTypeFloat32) { - status = TransFilterFormat(tensor, kCKHW2KCHW); - } else if (dataType == kNumberTypeUInt8) { - status = TransFilterFormat(tensor, kCKHW2KCHW); - } else if (dataType == kNumberTypeInt8) { - status = TransFilterFormat(tensor, kCKHW2KCHW); - } else { - MS_LOG(ERROR) << "Unsupported dataType: " << dataType; - return RET_ERROR; - } - break; - case schema::Format::Format_CHWK: - if (dataType == kNumberTypeFloat32) { - status = TransFilterFormat(tensor, kCHWK2KCHW); - } else if (dataType == kNumberTypeUInt8) { - status = TransFilterFormat(tensor, kCHWK2KCHW); - } else if (dataType == kNumberTypeInt8) { - status = TransFilterFormat(tensor, kCHWK2KCHW); - } else { - MS_LOG(ERROR) << "Unsupported dataType: " << dataType; - return RET_ERROR; - } - break; - default: - MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to " - << EnumNameFormat(dstFormat); - return RET_ERROR; - } - } break; - case schema::Format::Format_CKHW: { - switch (srcFormat) { - case schema::Format::Format_HWCK: - if (dataType == kNumberTypeFloat32) { - status = TransFilterFormat(tensor, kHWCK2CKHW); - } else if (dataType == kNumberTypeUInt8) { - status = TransFilterFormat(tensor, kHWCK2CKHW); - } else if (dataType == kNumberTypeInt8) { - status = TransFilterFormat(tensor, kHWCK2CKHW); - } else { - MS_LOG(ERROR) << "Unsupported dataType: " << dataType; - return RET_ERROR; - } - break; - case schema::Format::Format_HWKC: - if (dataType == kNumberTypeFloat32) { - status = TransFilterFormat(tensor, kHWKC2CKHW); - } else if (dataType == kNumberTypeUInt8) { - status = TransFilterFormat(tensor, kHWKC2CKHW); - } else if (dataType == kNumberTypeInt8) { - status = TransFilterFormat(tensor, kHWKC2CKHW); - } else { - MS_LOG(ERROR) 
<< "Unsupported dataType: " << dataType; - return RET_ERROR; - } - break; - case schema::Format::Format_KCHW: - if (dataType == kNumberTypeFloat32) { - status = TransFilterFormat(tensor, kKCHW2CKHW); - } else if (dataType == kNumberTypeUInt8) { - status = TransFilterFormat(tensor, kKCHW2CKHW); - } else if (dataType == kNumberTypeInt8) { - status = TransFilterFormat(tensor, kKCHW2CKHW); - } else { - MS_LOG(ERROR) << "Unsupported dataType: " << dataType; - return RET_ERROR; - } - break; - case schema::Format::Format_CKHW: - return RET_OK; - default: - MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to " - << EnumNameFormat(dstFormat); - return RET_ERROR; - } - } break; + case schema::Format::Format_KHWC: + convert = Convert2KHWC(srcFormat); + break; + case schema::Format::Format_HWCK: + convert = Convert2HWCK(srcFormat); + break; + case schema::Format::Format_KCHW: + convert = Convert2KCHW(srcFormat); + break; + case schema::Format::Format_CKHW: + convert = Convert2CKHW(srcFormat); + break; default: - MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to " - << EnumNameFormat(dstFormat); - return RET_ERROR; + convert = -1; + } + if (convert == -1) { + MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to " << EnumNameFormat(dstFormat); + return RET_ERROR; + } + + if (dataType == kNumberTypeFloat32) { + status = TransFilterFormat(tensor, static_cast(convert)); + } else if (dataType == kNumberTypeUInt8) { + status = TransFilterFormat(tensor, static_cast(convert)); + } else if (dataType == kNumberTypeInt8) { + status = TransFilterFormat(tensor, static_cast(convert)); + } else { + MS_LOG(ERROR) << "Unsupported dataType: " << dataType; + return RET_ERROR; } if (status != RET_OK) { MS_LOG(ERROR) << "TransFilterData failed: " << status;