tod new ops, performance improvment and bug fix

This commit is contained in:
yoni 2021-02-02 17:39:09 +02:00
parent 3bfcf8947c
commit ae389ae1f9
40 changed files with 1713 additions and 430 deletions

View File

@ -50,3 +50,35 @@ void backwardAll(const float *restrict in, const float *restrict yt, const float
} }
} }
} }
void backwardP1(const float *restrict in, const float *restrict yt, const float *restrict mean,
const float *restrict invar, const float *restrict scale, int size, int ch, float *restrict dxhat_sum,
float *restrict dxhathat_sum, float *restrict dbias, float *restrict dscale) {
for (int i = 0; i < size; i++) {
for (int c = 0; c < ch; c++) {
int ix = i * ch + c;
dbias[c] += yt[ix];
// dscale
float x_hat = (in[ix] - mean[c]) * invar[c];
dscale[c] += (yt[ix] * x_hat);
// dx_1
float dx_hat = yt[ix] * scale[c];
dxhat_sum[c] += dx_hat;
dxhathat_sum[c] += dx_hat * x_hat;
}
}
}
void backwardP2(const float *restrict in, const float *restrict yt, const float *restrict mean,
const float *restrict invar, const float *restrict scale, int size, int total_size, int ch,
const float *dxhat_sum, const float *dxhathat_sum, float *restrict dx) {
float N = (float)total_size;
for (int i = 0; i < size; i++) {
for (int c = 0; c < ch; c++) {
// dx_2
int ix = i * ch + c;
float x_hat = (in[ix] - mean[c]) * invar[c];
float dx_hat = yt[ix] * scale[c];
dx[ix] = 1.0f / N * (invar[c]) * (N * dx_hat - dxhat_sum[c] - x_hat * dxhathat_sum[c]);
}
}
}

View File

@ -32,6 +32,10 @@ extern "C" {
void var2Invar(float *save_var, int size, float eps); void var2Invar(float *save_var, int size, float eps);
void backwardAll(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size, void backwardAll(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size,
int ch, float *dxhat_sum, float *dxhathat_sum, float *dbias, float *dscale, float *dx); int ch, float *dxhat_sum, float *dxhathat_sum, float *dbias, float *dscale, float *dx);
void backwardP1(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size,
int ch, float *dxhat_sum, float *dxhathat_sum, float *dbias, float *dscale);
void backwardP2(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size,
int total_size, int ch, const float *dxhat_sum, const float *dxhathat_sum, float *dx);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -0,0 +1,379 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp32_grad/convolution_grad_filter.h"
#ifdef ENABLE_ARM
#include <arm_neon.h>
#endif
#ifdef ENABLE_ARM
static int FilterGrad16Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw,
const ConvParameter *conv_param) {
int in_h = conv_param->input_h_;
int in_w = conv_param->input_w_;
int k_h = conv_param->kernel_h_;
int k_w = conv_param->kernel_w_;
int batch = conv_param->output_batch_;
int out_ch = conv_param->output_channel_;
int in_ch = conv_param->input_channel_;
int out_h = conv_param->output_h_;
int out_w = conv_param->output_w_;
int m = out_h * out_w;
int x_size = in_h * in_w * in_ch;
int y_size = out_ch * out_h * out_w;
int k_spatial = k_w * k_h;
int i_kh = k_idx / k_w;
int i_kw = k_idx % k_w;
for (; i_c < (out_ch & ~15); i_c += 16) {
float32x4_t sum_03_4 = vdupq_n_f32(0.0f);
float32x4_t sum_47_4 = vdupq_n_f32(0.0f);
float32x4_t sum_9x_4 = vdupq_n_f32(0.0f);
float32x4_t sum_12x_4 = vdupq_n_f32(0.0f);
for (int b = 0; b < batch; ++b) {
const float *x_addr = &x[b * x_size];
const float *dy_addr = &dy[b * y_size];
for (int i = 0; i < m; i++) {
int idx = i;
int input_h = idx / out_w * conv_param->stride_h_;
int input_w = idx % out_w * conv_param->stride_w_;
int input_row = -conv_param->pad_u_ + i_kh + input_h;
int input_col = -conv_param->pad_l_ + i_kw + input_w;
if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
int offset_x = (input_row * in_w + input_col) * out_ch + i_c;
int offset_dy = idx * out_ch + i_c;
float32x4_t x_03_4 = vld1q_f32(x_addr + offset_x);
float32x4_t dy_03_4 = vld1q_f32(dy_addr + offset_dy);
sum_03_4 = vmlaq_f32(sum_03_4, x_03_4, dy_03_4);
float32x4_t x_47_4 = vld1q_f32(x_addr + offset_x + 4);
float32x4_t dy_47_4 = vld1q_f32(dy_addr + offset_dy + 4);
sum_47_4 = vmlaq_f32(sum_47_4, x_47_4, dy_47_4);
float32x4_t x_9x_4 = vld1q_f32(x_addr + offset_x + 8);
float32x4_t dy_9x_4 = vld1q_f32(dy_addr + offset_dy + 8);
sum_9x_4 = vmlaq_f32(sum_9x_4, x_9x_4, dy_9x_4);
float32x4_t x_12x_4 = vld1q_f32(x_addr + offset_x + 12);
float32x4_t dy_12x_4 = vld1q_f32(dy_addr + offset_dy + 12);
sum_12x_4 = vmlaq_f32(sum_12x_4, x_12x_4, dy_12x_4);
}
}
}
dw[(i_c + 0) * k_spatial + k_idx] = sum_03_4[0];
dw[(i_c + 1) * k_spatial + k_idx] = sum_03_4[1];
dw[(i_c + 2) * k_spatial + k_idx] = sum_03_4[2];
dw[(i_c + 3) * k_spatial + k_idx] = sum_03_4[3];
dw[(i_c + 4) * k_spatial + k_idx] = sum_47_4[0];
dw[(i_c + 5) * k_spatial + k_idx] = sum_47_4[1];
dw[(i_c + 6) * k_spatial + k_idx] = sum_47_4[2];
dw[(i_c + 7) * k_spatial + k_idx] = sum_47_4[3];
dw[(i_c + 8) * k_spatial + k_idx] = sum_9x_4[0];
dw[(i_c + 9) * k_spatial + k_idx] = sum_9x_4[1];
dw[(i_c + 10) * k_spatial + k_idx] = sum_9x_4[2];
dw[(i_c + 11) * k_spatial + k_idx] = sum_9x_4[3];
dw[(i_c + 12) * k_spatial + k_idx] = sum_12x_4[0];
dw[(i_c + 13) * k_spatial + k_idx] = sum_12x_4[1];
dw[(i_c + 14) * k_spatial + k_idx] = sum_12x_4[2];
dw[(i_c + 15) * k_spatial + k_idx] = sum_12x_4[3];
}
return i_c;
}
static int FilterGrad12Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw,
const ConvParameter *conv_param) {
int in_h = conv_param->input_h_;
int in_w = conv_param->input_w_;
int k_h = conv_param->kernel_h_;
int k_w = conv_param->kernel_w_;
int batch = conv_param->output_batch_;
int out_ch = conv_param->output_channel_;
int in_ch = conv_param->input_channel_;
int out_h = conv_param->output_h_;
int out_w = conv_param->output_w_;
int m = out_h * out_w;
int x_size = in_h * in_w * in_ch;
int y_size = out_ch * out_h * out_w;
int k_spatial = k_w * k_h;
int i_kh = k_idx / k_w;
int i_kw = k_idx % k_w;
if ((out_ch - i_c) >= 12) {
float32x4_t sum_03_4 = vdupq_n_f32(0.0f);
float32x4_t sum_47_4 = vdupq_n_f32(0.0f);
float32x4_t sum_9x_4 = vdupq_n_f32(0.0f);
for (int b = 0; b < batch; ++b) {
const float *x_addr = &x[b * x_size];
const float *dy_addr = &dy[b * y_size];
for (int i = 0; i < m; i++) {
int idx = i;
int input_h = idx / out_w * conv_param->stride_h_;
int input_w = idx % out_w * conv_param->stride_w_;
int input_row = -conv_param->pad_u_ + i_kh + input_h;
int input_col = -conv_param->pad_l_ + i_kw + input_w;
if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
int offset_x = (input_row * in_w + input_col) * out_ch + i_c;
int offset_dy = idx * out_ch + i_c;
float32x4_t x_03_4 = vld1q_f32(x_addr + offset_x);
float32x4_t dy_03_4 = vld1q_f32(dy_addr + offset_dy);
sum_03_4 = vmlaq_f32(sum_03_4, x_03_4, dy_03_4);
float32x4_t x_47_4 = vld1q_f32(x_addr + offset_x + 4);
float32x4_t dy_47_4 = vld1q_f32(dy_addr + offset_dy + 4);
sum_47_4 = vmlaq_f32(sum_47_4, x_47_4, dy_47_4);
float32x4_t x_9x_4 = vld1q_f32(x_addr + offset_x + 8);
float32x4_t dy_9x_4 = vld1q_f32(dy_addr + offset_dy + 8);
sum_9x_4 = vmlaq_f32(sum_9x_4, x_9x_4, dy_9x_4);
}
}
}
dw[(i_c + 0) * k_spatial + k_idx] = sum_03_4[0];
dw[(i_c + 1) * k_spatial + k_idx] = sum_03_4[1];
dw[(i_c + 2) * k_spatial + k_idx] = sum_03_4[2];
dw[(i_c + 3) * k_spatial + k_idx] = sum_03_4[3];
dw[(i_c + 4) * k_spatial + k_idx] = sum_47_4[0];
dw[(i_c + 5) * k_spatial + k_idx] = sum_47_4[1];
dw[(i_c + 6) * k_spatial + k_idx] = sum_47_4[2];
dw[(i_c + 7) * k_spatial + k_idx] = sum_47_4[3];
dw[(i_c + 8) * k_spatial + k_idx] = sum_9x_4[0];
dw[(i_c + 9) * k_spatial + k_idx] = sum_9x_4[1];
dw[(i_c + 10) * k_spatial + k_idx] = sum_9x_4[2];
dw[(i_c + 11) * k_spatial + k_idx] = sum_9x_4[3];
i_c += 12;
}
return i_c;
}
static int FilterGrad8Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw,
const ConvParameter *conv_param) {
int in_h = conv_param->input_h_;
int in_w = conv_param->input_w_;
int k_h = conv_param->kernel_h_;
int k_w = conv_param->kernel_w_;
int batch = conv_param->output_batch_;
int out_ch = conv_param->output_channel_;
int in_ch = conv_param->input_channel_;
int out_h = conv_param->output_h_;
int out_w = conv_param->output_w_;
int m = out_h * out_w;
int x_size = in_h * in_w * in_ch;
int y_size = out_ch * out_h * out_w;
int k_spatial = k_w * k_h;
int i_kh = k_idx / k_w;
int i_kw = k_idx % k_w;
if ((out_ch - i_c) >= 8) {
float32x4_t sum_03_4 = vdupq_n_f32(0.0f);
float32x4_t sum_47_4 = vdupq_n_f32(0.0f);
for (int b = 0; b < batch; ++b) {
const float *x_addr = &x[b * x_size];
const float *dy_addr = &dy[b * y_size];
for (int i = 0; i < m; i++) {
int idx = i;
int input_h = idx / out_w * conv_param->stride_h_;
int input_w = idx % out_w * conv_param->stride_w_;
int input_row = -conv_param->pad_u_ + i_kh + input_h;
int input_col = -conv_param->pad_l_ + i_kw + input_w;
if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
int offset_x = (input_row * in_w + input_col) * out_ch + i_c;
int offset_dy = idx * out_ch + i_c;
float32x4_t x_03_4 = vld1q_f32(x_addr + offset_x);
float32x4_t dy_03_4 = vld1q_f32(dy_addr + offset_dy);
sum_03_4 = vmlaq_f32(sum_03_4, x_03_4, dy_03_4);
float32x4_t x_47_4 = vld1q_f32(x_addr + offset_x + 4);
float32x4_t dy_47_4 = vld1q_f32(dy_addr + offset_dy + 4);
sum_47_4 = vmlaq_f32(sum_47_4, x_47_4, dy_47_4);
}
}
}
dw[(i_c + 0) * k_spatial + k_idx] = sum_03_4[0];
dw[(i_c + 1) * k_spatial + k_idx] = sum_03_4[1];
dw[(i_c + 2) * k_spatial + k_idx] = sum_03_4[2];
dw[(i_c + 3) * k_spatial + k_idx] = sum_03_4[3];
dw[(i_c + 4) * k_spatial + k_idx] = sum_47_4[0];
dw[(i_c + 5) * k_spatial + k_idx] = sum_47_4[1];
dw[(i_c + 6) * k_spatial + k_idx] = sum_47_4[2];
dw[(i_c + 7) * k_spatial + k_idx] = sum_47_4[3];
i_c += 8;
}
return i_c;
}
static int FilterGrad4Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw,
const ConvParameter *conv_param) {
int in_h = conv_param->input_h_;
int in_w = conv_param->input_w_;
int k_h = conv_param->kernel_h_;
int k_w = conv_param->kernel_w_;
int batch = conv_param->output_batch_;
int out_ch = conv_param->output_channel_;
int in_ch = conv_param->input_channel_;
int out_h = conv_param->output_h_;
int out_w = conv_param->output_w_;
int m = out_h * out_w;
int x_size = in_h * in_w * in_ch;
int y_size = out_ch * out_h * out_w;
int k_spatial = k_w * k_h;
int i_kh = k_idx / k_w;
int i_kw = k_idx % k_w;
if ((out_ch - i_c) >= 4) {
float32x4_t sum_4 = vdupq_n_f32(0.0f);
for (int b = 0; b < batch; ++b) {
const float *x_addr = &x[b * x_size];
const float *dy_addr = &dy[b * y_size];
for (int i = 0; i < m; i++) {
int idx = i;
int input_h = idx / out_w * conv_param->stride_h_;
int input_w = idx % out_w * conv_param->stride_w_;
int input_row = -conv_param->pad_u_ + i_kh + input_h;
int input_col = -conv_param->pad_l_ + i_kw + input_w;
if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
int offset_x = (input_row * in_w + input_col) * out_ch + i_c;
int offset_dy = idx * out_ch + i_c;
float32x4_t x_4 = vld1q_f32(x_addr + offset_x);
float32x4_t dy_4 = vld1q_f32(dy_addr + offset_dy);
sum_4 = vmlaq_f32(sum_4, x_4, dy_4);
}
}
}
dw[(i_c + 0) * k_spatial + k_idx] = sum_4[0];
dw[(i_c + 1) * k_spatial + k_idx] = sum_4[1];
dw[(i_c + 2) * k_spatial + k_idx] = sum_4[2];
dw[(i_c + 3) * k_spatial + k_idx] = sum_4[3];
i_c += 4;
}
return i_c;
}
static int Filtergrad2Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw,
const ConvParameter *conv_param) {
int in_h = conv_param->input_h_;
int in_w = conv_param->input_w_;
int k_h = conv_param->kernel_h_;
int k_w = conv_param->kernel_w_;
int batch = conv_param->output_batch_;
int out_ch = conv_param->output_channel_;
int in_ch = conv_param->input_channel_;
int out_h = conv_param->output_h_;
int out_w = conv_param->output_w_;
int m = out_h * out_w;
int x_size = in_h * in_w * in_ch;
int y_size = out_ch * out_h * out_w;
int k_spatial = k_w * k_h;
int i_kh = k_idx / k_w;
int i_kw = k_idx % k_w;
if ((out_ch - i_c) >= 2) {
float32x2_t sum_2 = vdup_n_f32(0.0f);
for (int b = 0; b < batch; ++b) {
const float *x_addr = &x[b * x_size];
const float *dy_addr = &dy[b * y_size];
for (int i = 0; i < m; i++) {
int idx = i;
int input_h = idx / out_w * conv_param->stride_h_;
int input_w = idx % out_w * conv_param->stride_w_;
int input_row = -conv_param->pad_u_ + i_kh + input_h;
int input_col = -conv_param->pad_l_ + i_kw + input_w;
if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
int offset_x = (input_row * in_w + input_col) * out_ch + i_c;
int offset_dy = idx * out_ch + i_c;
float32x2_t x_4 = vld1_f32(x_addr + offset_x);
float32x2_t dy_4 = vld1_f32(dy_addr + offset_dy);
sum_2 = vmla_f32(sum_2, x_4, dy_4);
}
}
}
dw[(i_c + 0) * k_spatial + k_idx] = sum_2[0];
dw[(i_c + 1) * k_spatial + k_idx] = sum_2[1];
i_c += 2;
}
return i_c += 2;
}
#endif
int ConvDwFilterGrad(const float *x, const float *dy, float *dw, int start, int count,
const ConvParameter *conv_param) {
int in_h = conv_param->input_h_;
int in_w = conv_param->input_w_;
int k_h = conv_param->kernel_h_;
int k_w = conv_param->kernel_w_;
int batch = conv_param->output_batch_;
int out_ch = conv_param->output_channel_;
int in_ch = conv_param->input_channel_;
int out_h = conv_param->output_h_;
int out_w = conv_param->output_w_;
int m = out_h * out_w;
int x_size = in_h * in_w * in_ch;
int y_size = out_ch * out_h * out_w;
int k_spatial = k_w * k_h;
for (int i_k = 0; i_k < count; i_k++) {
int k_idx = start + i_k;
int i_kh = k_idx / k_w;
int i_kw = k_idx % k_w;
int i_c = 0;
#ifdef ENABLE_ARM
i_c = FilterGrad16Arm(x, dy, i_c, k_idx, dw, conv_param);
i_c = FilterGrad12Arm(x, dy, i_c, k_idx, dw, conv_param);
i_c = FilterGrad8Arm(x, dy, i_c, k_idx, dw, conv_param);
i_c = FilterGrad4Arm(x, dy, i_c, k_idx, dw, conv_param);
i_c = Filtergrad2Arm(x, dy, i_c, k_idx, dw, conv_param);
#endif
for (; i_c < out_ch; i_c++) {
float sum = 0;
for (int b = 0; b < batch; ++b) {
const float *x_addr = &x[b * x_size];
const float *dy_addr = &dy[b * y_size];
for (int i = 0; i < m; i++) {
int idx = i;
int input_h = idx / out_w * conv_param->stride_h_;
int input_w = idx % out_w * conv_param->stride_w_;
int input_row = -conv_param->pad_u_ + i_kh + input_h;
int input_col = -conv_param->pad_l_ + i_kw + input_w;
if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
int offset_x = (input_row * in_w + input_col) * out_ch + i_c;
int offset_dy = idx * out_ch + i_c;
sum += x_addr[offset_x] * dy_addr[offset_dy];
}
}
}
dw[i_c * k_spatial + k_idx] = sum;
}
}
return 0;
}

View File

@ -0,0 +1,32 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_
#define MINDSPORE_LITE_NNACL_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_
#include <stddef.h>
#include "nnacl/conv_parameter.h"
#ifdef __cplusplus
extern "C" {
#endif
int ConvDwFilterGrad(const float *x, const float *dy, float *dw, int start, int count, const ConvParameter *conv_param);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_

View File

@ -18,6 +18,56 @@
#include "nnacl/fp32_grad/pack_ext.h" #include "nnacl/fp32_grad/pack_ext.h"
#include "nnacl/pack.h" #include "nnacl/pack.h"
void RollingIm2ColPackDwUnitFp32(const float *in_data, const ConvParameter *conv_param, float *data_col_orig,
int real_cal_num, int start) {
const int pad_left = conv_param->pad_l_;
const int pad_up = conv_param->pad_u_;
const int stride_h = conv_param->stride_h_;
const int stride_w = conv_param->stride_w_;
const int dilation_h = conv_param->dilation_h_;
const int dilation_w = conv_param->dilation_w_;
const int kernel_h = conv_param->kernel_h_;
const int kernel_w = conv_param->kernel_w_;
const int in_height = conv_param->input_h_;
const int in_width = conv_param->input_w_;
const int output_w = conv_param->output_w_;
const int channels = conv_param->input_channel_;
const int stride = kernel_h * kernel_w;
int kernel_row, kernel_col;
for (int i = 0; i < real_cal_num; i++) {
int block_start = start + i;
int input_h = block_start / output_w * stride_h;
int input_w = block_start % output_w * stride_w;
float *data_col = data_col_orig + i * channels * stride;
for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
int input_row = -pad_up + kernel_row * dilation_h + input_h;
for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
int input_col = -pad_left + kernel_col * dilation_w + input_w;
if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) {
const int offset = (input_row * in_width + input_col) * channels;
for (int c = 0; c < channels; c++) {
data_col[c * stride] = in_data[offset + c];
}
data_col++;
} else {
for (int c = 0; c < channels; c++) {
data_col[c * stride] = 0;
}
data_col++;
}
}
}
}
}
void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParameter *conv_param, int real_cal_num, void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParameter *conv_param, int real_cal_num,
int start) { int start) {
const int pad_left = conv_param->pad_l_; const int pad_left = conv_param->pad_l_;
@ -90,85 +140,6 @@ void RollingIm2ColPackUnitFp32(const float *input_data, const ConvParameter *con
rolling_im2col_hwc(input_data, packed_input, conv_param, real_cal_num, block_index); rolling_im2col_hwc(input_data, packed_input, conv_param, real_cal_num, block_index);
} }
void im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, bool transpose) {
const int pad_left = conv_param->pad_l_;
const int pad_up = conv_param->pad_u_;
const int stride_h = conv_param->stride_h_;
const int stride_w = conv_param->stride_w_;
const int dilation_h = conv_param->dilation_h_;
const int dilation_w = conv_param->dilation_w_;
const int kernel_h = conv_param->kernel_h_;
const int kernel_w = conv_param->kernel_w_;
const int in_height = (transpose) ? conv_param->output_h_ : conv_param->input_h_;
const int in_width = (transpose) ? conv_param->output_w_ : conv_param->input_w_;
const int output_h = (transpose) ? conv_param->input_h_ : conv_param->output_h_;
const int output_w = (transpose) ? conv_param->input_w_ : conv_param->output_w_;
const int tot_channels = (transpose) ? conv_param->output_channel_ : conv_param->input_channel_;
const int channels = tot_channels / conv_param->group_;
int channel, kernel_row, kernel_col, output_rows, output_col;
if (transpose) {
for (channel = 0; channel < channels; channel++) {
for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
int input_row = -pad_up + kernel_row * dilation_h;
for (output_rows = output_h; output_rows; output_rows--) {
if (!((unsigned)(input_row) < (unsigned)(in_height))) {
for (output_col = output_w; output_col; output_col--) {
*(data_row++) = 0;
}
} else {
int input_col = -pad_left + kernel_col * dilation_w;
for (output_col = output_w; output_col; output_col--) {
if (((unsigned)(input_col) < (unsigned)(in_width))) {
const int offset = (input_row * in_width + input_col) * tot_channels + channel;
*(data_row++) = in_data[offset];
} else {
*(data_row++) = 0;
}
input_col += stride_w;
}
}
input_row += stride_h;
}
}
}
}
} else {
for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
for (channel = 0; channel < channels; channel++) {
int input_row = -pad_up + kernel_row * dilation_h;
for (output_rows = output_h; output_rows; output_rows--) {
if (!((unsigned)(input_row) < (unsigned)(in_height))) {
for (output_col = output_w; output_col; output_col--) {
*(data_row++) = 0;
}
} else {
int input_col = -pad_left + kernel_col * dilation_w;
for (output_col = output_w; output_col; output_col--) {
if (((unsigned)(input_col) < (unsigned)(in_width))) {
const int offset = (input_row * in_width + input_col) * tot_channels + channel;
*(data_row++) = in_data[offset];
} else {
*(data_row++) = 0;
}
input_col += stride_w;
}
}
input_row += stride_h;
}
}
}
}
}
}
void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, int rows, int start) { void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, int rows, int start) {
const int pad_left = conv_param->pad_l_; const int pad_left = conv_param->pad_l_;
const int pad_up = conv_param->pad_u_; const int pad_up = conv_param->pad_u_;

View File

@ -26,6 +26,9 @@ extern "C" {
void RollingIm2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, void RollingIm2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input,
int real_cal_num, int block_index); int real_cal_num, int block_index);
void RollingIm2ColPackDwUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input,
int real_cal_num, int block_index);
void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParameter *conv_param, int rows, int start); void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParameter *conv_param, int rows, int start);
void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, int rows, int start); void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, int rows, int start);
void rolling_col2im_hwc(const float *data_col, float *data_im, const ConvParameter *conv_param, int rows, int start); void rolling_col2im_hwc(const float *data_col, float *data_im, const ConvParameter *conv_param, int rows, int start);

View File

@ -18,7 +18,7 @@
#include <float.h> #include <float.h>
#include "nnacl/fp32_grad/pooling_grad.h" #include "nnacl/fp32_grad/pooling_grad.h"
void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id) { void AvgPoolingGrad(const float *input_ptr, float *output_ptr, int count, PoolingParameter *pooling_param) {
int stride_w = pooling_param->stride_w_; int stride_w = pooling_param->stride_w_;
int stride_h = pooling_param->stride_h_; int stride_h = pooling_param->stride_h_;
int pad_w = pooling_param->pad_l_; int pad_w = pooling_param->pad_l_;
@ -30,29 +30,58 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter
int in_h = pooling_param->input_h_; int in_h = pooling_param->input_h_;
int output_w = pooling_param->output_w_; int output_w = pooling_param->output_w_;
int output_h = pooling_param->output_h_; int output_h = pooling_param->output_h_;
int output_batch = pooling_param->output_batch_;
memset(output_ptr, 0, in_h * in_w * channel * output_batch * sizeof(float)); const float kk = 1.0f / (float)(win_h * win_w);
float kk = (float)(win_h * win_w); #if ENABLE_ARM
for (int ib = 0; ib < output_batch; ib++) { const float32x4_t factor = vdupq_n_f32(kk);
#endif
for (int ib = 0; ib < count; ib++) {
float *out = &output_ptr[(ib * in_h * in_w * channel)]; float *out = &output_ptr[(ib * in_h * in_w * channel)];
const float *inPtr = &input_ptr[(ib * output_h * output_w * channel)]; const float *inPtr = &input_ptr[(ib * output_h * output_w * channel)];
// iterate over yt // iterate over yt
for (int yh = 0; yh < output_h; yh++) { for (int yh = 0; yh < output_h; yh++) {
int over_h = pad_h - yh * stride_h;
int kh_s = MSMAX(0, over_h);
int kh_e = MSMIN(win_h, in_h + over_h);
for (int yw = 0; yw < output_w; yw++) { for (int yw = 0; yw < output_w; yw++) {
for (int ic = 0; ic < channel; ic++) { int over_w = pad_w - yw * stride_w;
int kw_s = MSMAX(0, over_w);
int kw_e = MSMIN(win_w, in_w + over_w);
int ic = 0;
for (; ic < channel - 4; ic += 4) {
int idx = (yw + yh * output_w) * channel + ic; int idx = (yw + yh * output_w) * channel + ic;
float delta = inPtr[idx] / kk; #ifdef ENABLE_ARM
for (int kh = 0; kh < win_h; kh++) { float32x4_t in = vld1q_f32(inPtr + idx);
float32x4_t delta = vmulq_f32(in, factor);
#else
float delta[4] = {inPtr[idx], inPtr[idx + 1], inPtr[idx + 2], inPtr[idx + 3]};
for (int i = 0; i < 4; i++) delta[i] *= kk;
#endif
for (int kh = kh_s; kh < kh_e; kh++) {
int xh = yh * stride_h + kh - pad_h; int xh = yh * stride_h + kh - pad_h;
if ((xh < 0) || (xh >= in_h)) { for (int kw = kw_s; kw < kw_e; kw++) {
continue;
}
for (int kw = 0; kw < win_w; kw++) {
int xw = yw * stride_w + kw - pad_w; int xw = yw * stride_w + kw - pad_w;
if ((xw < 0) || (xw >= in_w)) { #ifdef ENABLE_ARM
continue; float *out_vec = out + (xw + in_w * xh) * channel + ic;
float32x4_t outr = vld1q_f32(out + (xw + in_w * xh) * channel + ic);
float32x4_t outs = vaddq_s32(outr, delta);
vst1q_f32(out_vec, outs);
#else
for (int i = 0; i < 4; i++) {
out[(xw + in_w * xh) * channel + ic + i] += ((float *)&delta)[i];
} }
#endif
}
}
}
for (; ic < channel; ic++) {
int idx = (yw + yh * output_w) * channel + ic;
float delta = inPtr[idx] * kk;
for (int kh = kh_s; kh < kh_e; kh++) {
int xh = yh * stride_h + kh - pad_h;
for (int kw = kw_s; kw < kw_e; kw++) {
int xw = yw * stride_w + kw - pad_w;
out[(xw + in_w * xh) * channel + ic] += delta; out[(xw + in_w * xh) * channel + ic] += delta;
} }
} }
@ -62,8 +91,17 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter
} }
} }
void MaxPoolingGrad(const float *input_ptr, const float *dx_ptr, const float *dy_ptr, float *output_ptr, #ifdef ENABLE_ARM
PoolingParameter *pooling_param, int task_id) { static int32x4_t MaxIndex(float32x4_t in, float32x4_t *max, int32x4_t index, int32x4_t prev_index) {
uint32x4_t res = vcgtq_f32(in, *max);
uint32x4_t m_index = vbslq_f32(res, index, prev_index);
*max = vbslq_f32(res, in, *max);
return m_index;
}
#endif
void MaxPoolingGrad(const float *input_ptr, const float *dy_ptr, float *output_ptr, int output_batch,
PoolingParameter *pooling_param) {
int stride_w = pooling_param->stride_w_; int stride_w = pooling_param->stride_w_;
int stride_h = pooling_param->stride_h_; int stride_h = pooling_param->stride_h_;
int pad_w = pooling_param->pad_l_; int pad_w = pooling_param->pad_l_;
@ -75,36 +113,71 @@ void MaxPoolingGrad(const float *input_ptr, const float *dx_ptr, const float *dy
int in_h = pooling_param->input_h_; int in_h = pooling_param->input_h_;
int output_w = pooling_param->output_w_; int output_w = pooling_param->output_w_;
int output_h = pooling_param->output_h_; int output_h = pooling_param->output_h_;
int output_batch = pooling_param->output_batch_;
memset(output_ptr, 0, in_h * in_w * channel * output_batch * sizeof(float));
for (int ib = 0; ib < output_batch; ib++) { for (int ib = 0; ib < output_batch; ib++) {
float *out = &output_ptr[(ib * in_h * in_w * channel)]; float *out = &output_ptr[(ib * in_h * in_w * channel)];
const float *inPtr = (const float *)(&input_ptr[(ib * in_h * in_w * channel)]); const float *inPtr = &input_ptr[(ib * in_h * in_w * channel)];
const float *dyPtr = (const float *)(&dy_ptr[(ib * output_h * output_w * channel)]); const float *dyPtr = &dy_ptr[(ib * output_h * output_w * channel)];
for (int yh = 0; yh < output_h; yh++) { for (int yh = 0; yh < output_h; yh++) {
int over_h = pad_h - yh * stride_h;
int kh_s = MSMAX(0, over_h);
int kh_e = MSMIN(win_h, in_h + over_h);
for (int yw = 0; yw < output_w; yw++) { for (int yw = 0; yw < output_w; yw++) {
for (int ic = 0; ic < channel; ic++) { int over_w = pad_w - yw * stride_w;
int kw_s = MSMAX(0, over_w);
int kw_e = MSMIN(win_w, in_w + over_w);
int ic = 0;
for (; ic < channel - 4; ic += 4) {
int idx = (yw + yh * output_w) * channel + ic; int idx = (yw + yh * output_w) * channel + ic;
#ifdef ENABLE_ARM
float delta = dyPtr[idx]; uint32x4_t max_idx = vdupq_n_u32(0);
float32x4_t max_val = vdupq_n_f32(-FLT_MAX);
float32x4_t delta = vld1q_f32(dyPtr + idx);
#else
float delta[4] = {dyPtr[idx], dyPtr[idx + 1], dyPtr[idx + 2], dyPtr[idx + 3]};
float max_val[4] = {-FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX};
int max_idx[4] = {0};
#endif
for (int kh = kh_s; kh < kh_e; kh++) {
int xh = yh * stride_h + kh - pad_h;
for (int kw = kw_s; kw < kw_e; kw++) {
int xw = yw * stride_w + kw - pad_w;
int val_idx = (xw + in_w * xh) * channel + ic;
#ifdef ENABLE_ARM
unsigned int val_idx_vec[] = {val_idx, val_idx + 1, val_idx + 2, val_idx + 3};
uint32x4_t index = vld1q_u32(val_idx_vec);
float32x4_t in = vld1q_f32(inPtr + val_idx);
max_idx = MaxIndex(in, &max_val, index, max_idx);
#else
float val[4] = {inPtr[val_idx], inPtr[val_idx + 1], inPtr[val_idx + 2], inPtr[val_idx + 3]};
for (int i = 0; i < 4; i++) {
if (val[i] > max_val[i]) {
max_val[i] = val[i];
max_idx[i] = val_idx + i;
}
}
#endif
}
}
for (int i = 0; i < 4; i++) {
out[((int *)&max_idx)[i]] += ((float *)&delta)[i];
}
}
for (; ic < channel; ic++) {
float max_val = -FLT_MAX; float max_val = -FLT_MAX;
int max_idx = 0; int max_idx = 0;
for (int kh = 0; kh < win_h; kh++) { int idx = (yw + yh * output_w) * channel + ic;
float delta = dyPtr[idx];
for (int kh = kh_s; kh < kh_e; kh++) {
int xh = yh * stride_h + kh - pad_h; int xh = yh * stride_h + kh - pad_h;
if ((xh < 0) || (xh >= in_h)) { int loop = kw_e - kw_s;
continue; for (int kw = 0; kw < loop; kw++) {
} int xw = yw * stride_w + kw + kw_s - pad_w;
for (int kw = 0; kw < win_w; kw++) { int val_idx = (xw + in_w * xh) * channel + ic;
int xw = yw * stride_w + kw - pad_w; float val = inPtr[val_idx];
if ((xw < 0) || (xw >= in_w)) { if (val > max_val) {
continue; max_val = val;
} max_idx = val_idx;
if (inPtr[(xw + in_w * xh) * channel + ic] > max_val) {
max_val = inPtr[(xw + in_w * xh) * channel + ic];
max_idx = (xw + in_w * xh) * channel + ic;
} }
} }
} }

View File

@ -22,9 +22,9 @@
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id); void AvgPoolingGrad(const float *input_ptr, float *output_ptr, int count, PoolingParameter *pooling_param);
void MaxPoolingGrad(const float *input_ptr, const float *dx_ptr, const float *dy_ptr, float *output_ptr, void MaxPoolingGrad(const float *input_ptr, const float *dy_ptr, float *output_ptr, int output_batch,
PoolingParameter *pooling_param, int task_id); PoolingParameter *pooling_param);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -0,0 +1,61 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp32_grad/strided_slice_grad.h"
#include "nnacl/errorcode.h"
static size_t CalcIndex(const int *shape, size_t size, int i, size_t pos) {
size_t res = 1;
for (size_t j = 0; j < size; j++) {
res *= shape[(i + 1) + j];
}
return (pos / res % shape[i]);
}
int DoStridedSliceGrad(const float *inputs, float *output, const int *dx_shape, StridedSliceParameter *param) {
if (inputs == NULL || output == NULL || param == NULL) {
return NNACL_NULL_PTR;
}
if (param->num_axes_ > DIMENSION_7D) {
return NNACL_PARAM_INVALID;
}
size_t size = 1;
int *s = param->strides_;
int *b = param->begins_;
for (int i = 0; i < DIMENSION_7D; i++) {
size *= param->in_shape_[i];
}
for (size_t pos = 0; pos < size; pos++) {
size_t i = CalcIndex(param->in_shape_, 6, 0, pos);
size_t j = CalcIndex(param->in_shape_, 5, 1, pos);
size_t k = CalcIndex(param->in_shape_, 4, 2, pos);
size_t l = CalcIndex(param->in_shape_, 3, 3, pos);
size_t m = CalcIndex(param->in_shape_, 2, 4, pos);
size_t n = CalcIndex(param->in_shape_, 1, 5, pos);
size_t o = CalcIndex(param->in_shape_, 0, 6, pos);
size_t input_idx =
(i * s[0] + b[0]) * dx_shape[1] * dx_shape[2] * dx_shape[3] * dx_shape[4] * dx_shape[5] * dx_shape[6] +
(j * s[1] + b[1]) * dx_shape[2] * dx_shape[3] * dx_shape[4] * dx_shape[5] * dx_shape[6] +
(k * s[2] + b[2]) * dx_shape[3] * dx_shape[4] * dx_shape[5] * dx_shape[6] +
(l * s[3] + b[3]) * dx_shape[4] * dx_shape[5] * dx_shape[6] + (m * s[4] + b[4]) * dx_shape[5] * dx_shape[6] +
(n * s[5] + b[5]) * dx_shape[6] + (o * s[6] + b[6]);
output[input_idx] = inputs[pos];
}
return NNACL_OK;
}

View File

@ -0,0 +1,30 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_STRIDED_SLICE_GRAD_H_
#define MINDSPORE_LITE_NNACL_FP32_GRAD_STRIDED_SLICE_GRAD_H_
#include "nnacl/op_base.h"
#include "nnacl/strided_slice_parameter.h"
#ifdef __cplusplus
extern "C" {
#endif
int DoStridedSliceGrad(const float *inputs, float *output, const int *dx_shape, StridedSliceParameter *param);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_FP32_GRAD_STRIDED_SLICE_GRAD_H_

View File

@ -53,6 +53,7 @@
#define DIMENSION_4D 4 #define DIMENSION_4D 4
#define DIMENSION_6D 6 #define DIMENSION_6D 6
#define DIMENSION_7D 7
#define kInputIndex 0 #define kInputIndex 0
#define kWeightIndex 1 #define kWeightIndex 1
#define kBiasIndex 2 #define kBiasIndex 2

View File

@ -273,6 +273,7 @@ union PrimitiveType {
RandomStandardNormal, RandomStandardNormal,
CropAndResize, CropAndResize,
Erf, Erf,
StridedSliceGrad
} }
enum QuantType: int { enum QuantType: int {

View File

@ -1259,6 +1259,18 @@ table RandomStandardNormal {
table CropAndResize { table CropAndResize {
method : ResizeMethod; method : ResizeMethod;
extrapolation_value : float; extrapolation_value : float;
}
table StridedSliceGrad {
beginMask: int;
endMask: int;
ellipsisMask: int;
newAxisMask: int;
shrinkAxisMask: int;
begin: [int];
end: [int];
stride: [int];
isScale: [int];
} }
table Erf { table Erf {

View File

@ -31,7 +31,7 @@ int FlattenGrad::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *>
MS_LOG(ERROR) << "FlattenGrad input or output is null!"; MS_LOG(ERROR) << "FlattenGrad input or output is null!";
return RET_ERROR; return RET_ERROR;
} }
if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) { if (inputs_.size() != kDoubleNum || outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size(); MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size();
return RET_INPUT_TENSOR_ERROR; return RET_INPUT_TENSOR_ERROR;
} }
@ -42,16 +42,15 @@ int FlattenGrad::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *>
return RET_INFER_INVALID; return RET_INFER_INVALID;
} }
auto input_shape = input->shape(); auto output_size = inputs_.at(1)->shape().at(0);
std::vector<int> output_shape(2); std::vector<int> output_shape(output_size);
output_shape.at(0) = input_shape.at(0); for (int i = 0; i < output_size; i++) {
output_shape.at(1) = 1; output_shape.at(i) = static_cast<int *>(inputs_.at(1)->data_c())[i];
for (size_t i = 1; i < input_shape.size(); i++) {
output_shape.at(1) *= input_shape.at(i);
} }
output->set_shape(output_shape); output->set_shape(output_shape);
return RET_OK; return RET_OK;
} }
#ifdef PRIMITIVE_WRITEABLE #ifdef PRIMITIVE_WRITEABLE
int FlattenGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { int FlattenGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) { if (this->primitive_ == nullptr) {

View File

@ -91,6 +91,8 @@ int PoolingGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
attr->poolingMode = schema::PoolMode_MEAN_POOLING; attr->poolingMode = schema::PoolMode_MEAN_POOLING;
} else if (prim.instance_name() == "AvgPoolGradGpu") { } else if (prim.instance_name() == "AvgPoolGradGpu") {
attr->poolingMode = schema::PoolMode_MEAN_POOLING; attr->poolingMode = schema::PoolMode_MEAN_POOLING;
} else if (prim.instance_name() == "AvgPoolGradCpu") {
attr->poolingMode = schema::PoolMode_MEAN_POOLING;
} else { } else {
attr->poolingMode = schema::PoolMode_MAX_POOLING; attr->poolingMode = schema::PoolMode_MAX_POOLING;
} }

View File

@ -202,6 +202,7 @@
#include "src/ops/smooth_l1_loss_grad.h" #include "src/ops/smooth_l1_loss_grad.h"
#include "src/ops/sigmoid_cross_entropy_with_logits.h" #include "src/ops/sigmoid_cross_entropy_with_logits.h"
#include "src/ops/sigmoid_cross_entropy_with_logits_grad.h" #include "src/ops/sigmoid_cross_entropy_with_logits_grad.h"
#include "src/ops/strided_slice_grad.h"
#endif #endif
#endif #endif
namespace mindspore { namespace mindspore {
@ -724,6 +725,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<SigmoidCrossEntropyWithLogitsGrad>(prim, inputs, quantType); return NewPrimitiveC<SigmoidCrossEntropyWithLogitsGrad>(prim, inputs, quantType);
} else if (op_type == "Pad") { } else if (op_type == "Pad") {
return NewPrimitiveC<Pad>(prim, inputs, quantType); return NewPrimitiveC<Pad>(prim, inputs, quantType);
} else if (op_type == "StridedSliceGrad") {
return NewPrimitiveC<StridedSliceGrad>(prim, inputs, quantType);
#else #else
} else if (op_type == "Conv2DBackpropInput") { } else if (op_type == "Conv2DBackpropInput") {
return NewPrimitiveC<DeConv2D>(prim, inputs, quantType); return NewPrimitiveC<DeConv2D>(prim, inputs, quantType);
@ -1102,6 +1105,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new (std::nothrow) SigmoidCrossEntropyWithLogits(primitive); return new (std::nothrow) SigmoidCrossEntropyWithLogits(primitive);
case schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad: case schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad:
return new (std::nothrow) SigmoidCrossEntropyWithLogitsGrad(primitive); return new (std::nothrow) SigmoidCrossEntropyWithLogitsGrad(primitive);
case schema::PrimitiveType_StridedSliceGrad:
return new (std::nothrow) StridedSliceGrad(primitive);
#endif #endif
default: default:
MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type); MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type);

View File

@ -0,0 +1,266 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/strided_slice_grad.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int StridedSliceGrad::GetBeginMask() const { return this->primitive_->value.AsStridedSliceGrad()->beginMask; }
int StridedSliceGrad::GetEndMask() const { return this->primitive_->value.AsStridedSliceGrad()->endMask; }
int StridedSliceGrad::GetEllipsisMask() const { return this->primitive_->value.AsStridedSliceGrad()->ellipsisMask; }
int StridedSliceGrad::GetNewAxisMask() const { return this->primitive_->value.AsStridedSliceGrad()->newAxisMask; }
int StridedSliceGrad::GetShrinkAxisMask() const { return this->primitive_->value.AsStridedSliceGrad()->shrinkAxisMask; }
std::vector<int> StridedSliceGrad::GetBegin() const { return this->primitive_->value.AsStridedSliceGrad()->begin; }
std::vector<int> StridedSliceGrad::GetEnd() const { return this->primitive_->value.AsStridedSliceGrad()->end; }
std::vector<int> StridedSliceGrad::GetStride() const { return this->primitive_->value.AsStridedSliceGrad()->stride; }
std::vector<int> StridedSliceGrad::GetIsScale() const { return this->primitive_->value.AsStridedSliceGrad()->isScale; }
void StridedSliceGrad::SetBeginMask(int begin_mask) {
this->primitive_->value.AsStridedSliceGrad()->beginMask = begin_mask;
}
void StridedSliceGrad::SetEndMask(int end_mask) { this->primitive_->value.AsStridedSliceGrad()->endMask = end_mask; }
void StridedSliceGrad::SetEllipsisMask(int ellipsis_mask) {
this->primitive_->value.AsStridedSliceGrad()->ellipsisMask = ellipsis_mask;
}
void StridedSliceGrad::SetNewAxisMask(int new_axis_mask) {
this->primitive_->value.AsStridedSliceGrad()->newAxisMask = new_axis_mask;
}
void StridedSliceGrad::SetShrinkAxisMask(int shrink_axis_mask) {
this->primitive_->value.AsStridedSliceGrad()->shrinkAxisMask = shrink_axis_mask;
}
void StridedSliceGrad::SetBegin(const std::vector<int> &begin) {
this->primitive_->value.AsStridedSliceGrad()->begin = begin;
}
void StridedSliceGrad::SetEnd(const std::vector<int> &end) { this->primitive_->value.AsStridedSliceGrad()->end = end; }
void StridedSliceGrad::SetStride(const std::vector<int> &stride) {
this->primitive_->value.AsStridedSliceGrad()->stride = stride;
}
void StridedSliceGrad::SetIsScale(const std::vector<int> &is_scale) {
this->primitive_->value.AsStridedSliceGrad()->isScale = is_scale;
}
int StridedSliceGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_StridedSliceGrad;
}
if (this->primitive_->value.type != schema::PrimitiveType_StridedSliceGrad) {
MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::StridedSliceGradT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new StridedSliceGrad failed";
return RET_ERROR;
}
attr->beginMask = CastToInt(prim.GetAttr("begin_mask")).front();
attr->endMask = CastToInt(prim.GetAttr("end_mask")).front();
attr->ellipsisMask = CastToInt(prim.GetAttr("ellipsis_mask")).front();
attr->newAxisMask = CastToInt(prim.GetAttr("new_axis_mask")).front();
attr->shrinkAxisMask = CastToInt(prim.GetAttr("shrink_axis_mask")).front();
auto inputNodeFirst = inputs[kAnfPopulaterInputNumOne];
std::vector<int> beginVec;
GetAttrDataFromInput(inputNodeFirst, &beginVec);
attr->begin = beginVec;
auto inputNodeSecond = inputs[kAnfPopulaterInputNumTwo];
std::vector<int> endVec;
GetAttrDataFromInput(inputNodeSecond, &endVec);
attr->end = endVec;
auto inputNodeThird = inputs[kAnfPopulaterInputNumThree];
std::vector<int> strideVec;
GetAttrDataFromInput(inputNodeThird, &strideVec);
attr->stride = strideVec;
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int StridedSliceGrad::GetBeginMask() const { return this->primitive_->value_as_StridedSliceGrad()->beginMask(); }
int StridedSliceGrad::GetEndMask() const { return this->primitive_->value_as_StridedSliceGrad()->endMask(); }
int StridedSliceGrad::GetEllipsisMask() const { return this->primitive_->value_as_StridedSliceGrad()->ellipsisMask(); }
int StridedSliceGrad::GetNewAxisMask() const { return this->primitive_->value_as_StridedSliceGrad()->newAxisMask(); }
int StridedSliceGrad::GetShrinkAxisMask() const {
return this->primitive_->value_as_StridedSliceGrad()->shrinkAxisMask();
}
std::vector<int> StridedSliceGrad::GetBegin() const {
auto fb_vector = this->primitive_->value_as_StridedSliceGrad()->begin();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
std::vector<int> StridedSliceGrad::GetEnd() const {
auto fb_vector = this->primitive_->value_as_StridedSliceGrad()->end();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
std::vector<int> StridedSliceGrad::GetStride() const {
auto fb_vector = this->primitive_->value_as_StridedSliceGrad()->stride();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
std::vector<int> StridedSliceGrad::GetIsScale() const {
auto fb_vector = this->primitive_->value_as_StridedSliceGrad()->isScale();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
int StridedSliceGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_StridedSliceGrad();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_StridedSliceGrad return nullptr";
return RET_ERROR;
}
std::vector<int32_t> begin;
if (attr->begin() != nullptr) {
for (int i = 0; i < static_cast<int>(attr->begin()->size()); i++) {
begin.push_back(attr->begin()->data()[i]);
}
}
std::vector<int32_t> end;
if (attr->end() != nullptr) {
for (int i = 0; i < static_cast<int>(attr->end()->size()); i++) {
end.push_back(attr->end()->data()[i]);
}
}
std::vector<int32_t> stride;
if (attr->stride() != nullptr) {
for (int i = 0; i < static_cast<int>(attr->stride()->size()); i++) {
stride.push_back(attr->stride()->data()[i]);
}
}
std::vector<int32_t> isScale;
if (attr->isScale() != nullptr) {
for (int i = 0; i < static_cast<int>(attr->isScale()->size()); i++) {
isScale.push_back(attr->isScale()->data()[i]);
}
}
auto val_offset =
schema::CreateStridedSliceGradDirect(*fbb, attr->beginMask(), attr->endMask(), attr->ellipsisMask(),
attr->newAxisMask(), attr->shrinkAxisMask(), &begin, &end, &stride, &isScale);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_StridedSliceGrad, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
PrimitiveC *StridedSliceGradCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<StridedSliceGrad>(primitive);
}
Registry StridedSliceGradRegistry(schema::PrimitiveType_StridedSliceGrad, StridedSliceGradCreator);
#endif
namespace {
constexpr size_t kStridedSliceGradOutputNum = 1;
constexpr size_t kStridedSliceGradMultiInputNumMax = 5;
} // namespace
int StridedSliceGrad::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
MS_ASSERT(this->primitive_ != nullptr);
if (outputs.size() != kStridedSliceGradOutputNum) {
MS_LOG(ERROR) << "Invalid output size:" << outputs.size();
return RET_PARAM_INVALID;
}
if (inputs.size() != kStridedSliceGradMultiInputNumMax) {
MS_LOG(ERROR) << "Invalid input size " << inputs.size();
return RET_PARAM_INVALID;
}
auto input = inputs.at(0);
outputs.front()->set_data_type(input->data_type());
outputs.at(0)->set_format(input->format());
MS_ASSERT(input != nullptr);
auto input_shape = input->shape();
auto inferflag = infer_flag();
in_shape_.clear();
if (inferflag) {
in_shape_.assign(input_shape.begin(), input_shape.end());
}
begins_.clear();
ends_.clear();
strides_.clear();
if (!CheckInputs(inputs)) {
MS_LOG(DEBUG) << "Do infer shape in runtime.";
return RET_INFER_INVALID;
}
// input order: dy, shapex, begins, ends, strides.
auto begin_tensor = inputs.at(2);
int *begin_data = reinterpret_cast<int *>(begin_tensor->MutableData());
auto end_tensor = inputs.at(3);
int *end_data = reinterpret_cast<int *>(end_tensor->MutableData());
auto stride_tensor = inputs.at(4);
int *stride_data = reinterpret_cast<int *>(stride_tensor->MutableData());
if (begin_data == nullptr || end_data == nullptr || stride_data == nullptr) {
return RET_INFER_ERR;
}
ndim_ = begin_tensor->ElementsNum();
for (size_t i = 0; i < ndim_; ++i) {
begins_.emplace_back(begin_data[i]);
ends_.emplace_back(end_data[i]);
strides_.emplace_back(stride_data[i]);
}
// set all mask to original input shape
begins_mask_.resize(ndim_);
ends_mask_.resize(ndim_);
ellipsis_mask_.resize(ndim_);
new_axis_mask_.resize(ndim_);
shrink_axis_mask_.resize(ndim_);
for (size_t i = 0; i < ndim_; i++) {
begins_mask_.at(i) = static_cast<uint32_t>(GetBeginMask()) & (1 << i);
ends_mask_.at(i) = static_cast<uint32_t>(GetEndMask()) & (1 << i);
ellipsis_mask_.at(i) = static_cast<uint32_t>(GetEllipsisMask()) & (1 << i);
new_axis_mask_.at(i) = static_cast<uint32_t>(GetNewAxisMask()) & (1 << i);
shrink_axis_mask_.at(i) = static_cast<uint32_t>(GetShrinkAxisMask()) & (1 << i);
}
ApplyNewAxisMask();
ApplyBeginMask();
ApplyEndMask();
ApplyEllipsisMask();
if (!inferflag) {
return RET_OK;
}
auto output_size = inputs.at(1)->shape().at(0);
std::vector<int> output_shape;
MS_ASSERT(inputs.at(1)->MutableData() != nullptr);
for (int i = 0; i < output_size; i++) {
output_shape.push_back(static_cast<int *>(inputs.at(1)->MutableData())[i]);
}
outputs.front()->set_shape(output_shape);
return RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,64 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_OPS_STRIDED_SLICE_GRAD_H_
#define MINDSPORE_LITE_SRC_OPS_STRIDED_SLICE_GRAD_H_
#include <vector>
#include <set>
#include <cmath>
#include <memory>
#include "src/ops/strided_slice.h"
namespace mindspore {
namespace lite {
class StridedSliceGrad : public StridedSlice {
public:
StridedSliceGrad() = default;
~StridedSliceGrad() = default;
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(StridedSliceGrad, StridedSlice);
explicit StridedSliceGrad(schema::PrimitiveT *primitive) : StridedSlice(primitive) {}
void SetBeginMask(int begin_mask);
void SetEndMask(int end_mask);
void SetEllipsisMask(int ellipsis_mask);
void SetNewAxisMask(int new_axis_mask);
void SetShrinkAxisMask(int shrink_axis_mask);
void SetBegin(const std::vector<int> &begin);
void SetEnd(const std::vector<int> &end);
void SetStride(const std::vector<int> &stride);
void SetIsScale(const std::vector<int> &is_scale);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
// bool CheckInputs(std::vector<lite::Tensor *> inputs_);
int GetBeginMask() const;
int GetEndMask() const;
int GetEllipsisMask() const;
int GetNewAxisMask() const;
int GetShrinkAxisMask() const;
std::vector<int> GetBegin() const;
std::vector<int> GetEnd() const;
std::vector<int> GetStride() const;
std::vector<int> GetIsScale() const;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_OPS_STRIDED_SLICE_GRAD_H_

View File

@ -91,10 +91,12 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() {
// init bias // init bias
size_t new_bias_size = oc4 * C4NUM * sizeof(float); size_t new_bias_size = oc4 * C4NUM * sizeof(float);
bias_data_ = reinterpret_cast<float *>(malloc(new_bias_size));
if (bias_data_ == nullptr) { if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "malloc bias_data_ failed."; bias_data_ = reinterpret_cast<float *>(malloc(new_bias_size));
return RET_MEMORY_FAILED; if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "malloc bias_data_ failed.";
return RET_MEMORY_FAILED;
}
} }
memset(bias_data_, 0, new_bias_size); memset(bias_data_, 0, new_bias_size);
if (in_tensors_.size() == kInputSize2) { if (in_tensors_.size() == kInputSize2) {

View File

@ -91,10 +91,6 @@ int FusedBatchnormCPUKernel::Run() {
memcpy(scale_, scale, in_tensors_[1]->Size()); memcpy(scale_, scale, in_tensors_[1]->Size());
memcpy(offset_, offset, in_tensors_[2]->Size()); memcpy(offset_, offset, in_tensors_[2]->Size());
// save for next iteration
memcpy(in_tensors_[3]->MutableData(), save_mean, in_tensors_[3]->Size());
memcpy(in_tensors_[4]->MutableData(), save_variance, in_tensors_[4]->Size());
trained_ = true; // trained at least once trained_ = true; // trained at least once
} }
auto ret = ParallelLaunch(this->context_->thread_pool_, BatchNormRun, this, op_parameter_->thread_num_); auto ret = ParallelLaunch(this->context_->thread_pool_, BatchNormRun, this, op_parameter_->thread_num_);

View File

@ -40,17 +40,16 @@ int ApplyMomentumCPUKernel::Execute(int task_id) {
size_t stride = UP_DIV(length, thread_count_); size_t stride = UP_DIV(length, thread_count_);
size_t count = MSMIN(stride, length - stride * task_id); size_t count = MSMIN(stride, length - stride * task_id);
size_t start = stride * task_id; size_t start = stride * task_id;
size_t end = start + count; size_t end = start + count;
if (apply_momentum_param_->use_nesterov_) { if (apply_momentum_param_->use_nesterov_) {
for (size_t i = start; i < end; ++i) { for (size_t i = start; i < end; i++) {
accumulate[i] = accumulate[i] * moment + gradient[i]; accumulate[i] = accumulate[i] * moment + gradient[i];
weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate; weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate;
} }
} else { } else {
for (size_t i = start; i < end; ++i) { for (size_t i = start; i < end; i++) {
accumulate[i] = accumulate[i] * moment + gradient[i]; accumulate[i] = accumulate[i] * moment + gradient[i];
weight[i] -= accumulate[i] * learning_rate; weight[i] -= accumulate[i] * learning_rate;
} }

View File

@ -18,6 +18,10 @@
#include <math.h> #include <math.h>
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include <thread>
#include <fstream>
#include "schema/model_generated.h" #include "schema/model_generated.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "nnacl/fp32_grad/batch_norm.h" #include "nnacl/fp32_grad/batch_norm.h"
@ -34,7 +38,8 @@ namespace mindspore::kernel {
int BNGradCPUKernel::ReSize() { int BNGradCPUKernel::ReSize() {
auto *input_x = in_tensors_.at(1); auto *input_x = in_tensors_.at(1);
int channels = input_x->shape().at(kNHWC_C); int channels = input_x->shape().at(kNHWC_C);
set_workspace_size(2 * channels * sizeof(float)); ws_size_ = 2 * channels;
set_workspace_size(ws_size_ * sizeof(float));
return RET_OK; return RET_OK;
} }
@ -46,7 +51,9 @@ int BNGradCPUKernel::Execute(int task_id) {
auto *input_scale = in_tensors_.at(2); auto *input_scale = in_tensors_.at(2);
auto *input_mean = in_tensors_.at(3); auto *input_mean = in_tensors_.at(3);
auto *input_var = in_tensors_.at(4); auto *input_var = in_tensors_.at(4);
auto bn_param = reinterpret_cast<BNGradParameter *>(op_parameter_);
int stage = stage_;
int thread_num = thread_num_;
float *save_mean = reinterpret_cast<float *>(input_mean->MutableData()); float *save_mean = reinterpret_cast<float *>(input_mean->MutableData());
float *save_var = reinterpret_cast<float *>(input_var->MutableData()); float *save_var = reinterpret_cast<float *>(input_var->MutableData());
@ -58,26 +65,57 @@ int BNGradCPUKernel::Execute(int task_id) {
int32_t spatial = input_x->Height() * input_x->Width(); int32_t spatial = input_x->Height() * input_x->Width();
float *workspace_temp = static_cast<float *>(workspace()); float *workspace_temp = static_cast<float *>(workspace());
std::fill(workspace_temp, workspace_temp + workspace_size() / sizeof(*workspace_temp), 0.f);
float *dxhat_sum = workspace_temp; float *dxhat_sum = workspace_temp;
float *dxhathat_sum = dxhat_sum + channels; float *dxhathat_sum = dxhat_sum + channels;
float *x = reinterpret_cast<float *>(input_x->MutableData()); float *x = reinterpret_cast<float *>(input_x->MutableData());
float *yt = reinterpret_cast<float *>(input_yt->MutableData()); float *yt = reinterpret_cast<float *>(input_yt->MutableData());
float *scale = reinterpret_cast<float *>(input_scale->MutableData()); float *scale = reinterpret_cast<float *>(input_scale->MutableData());
float *dx = reinterpret_cast<float *>(output_dx->MutableData()); float *dx = reinterpret_cast<float *>(output_dx->MutableData());
float *dbias = reinterpret_cast<float *>(output_bias->MutableData()); float *dbias = reinterpret_cast<float *>(output_bias->MutableData());
float *dscale = reinterpret_cast<float *>(output_scale->MutableData()); float *dscale = reinterpret_cast<float *>(output_scale->MutableData());
std::fill(dbias, dbias + channels, 0.f); int total = spatial * batch;
std::fill(dscale, dscale + channels, 0.f); int stride = UP_DIV(total, thread_num);
backwardAll(x, yt, save_mean, save_var, scale, batch * spatial, channels, dxhat_sum, dxhathat_sum, dbias, dscale, dx); int count = MSMIN(stride, total - stride * task_id);
switch (stage) {
case 0: {
for (int job = task_id; job < 4; job += thread_num) {
switch (job) {
case 0:
var2Invar(save_var, input_var->ElementsNum(), bn_param->epsilon_);
break;
case 1:
std::fill(workspace_temp, workspace_temp + ws_size_, 0.f);
break;
case 2:
std::fill(dbias, dbias + channels, 0.f);
break;
case 3:
std::fill(dscale, dscale + channels, 0.f);
break;
}
}
if (thread_num == 1) {
backwardAll(x, yt, save_mean, save_var, scale, total, channels, dxhat_sum, dxhathat_sum, dbias, dscale, dx);
}
break;
}
case 1: {
backwardP1(x, yt, save_mean, save_var, scale, total, channels, dxhat_sum, dxhathat_sum, dbias, dscale);
break;
}
case 2: {
backwardP2(x + task_id * stride * channels, yt + task_id * stride * channels, save_mean, save_var, scale, count,
total, channels, dxhat_sum, dxhathat_sum, dx + task_id * stride * channels);
break;
}
}
return RET_OK; return RET_OK;
} }
int BNGradRun(void *cdata, int task_id) { int BNGradRun(void *cdata, int task_id) {
MS_ASSERT(cdata != nullptr); MS_ASSERT(cdata != nullptr);
auto bn_kernel = reinterpret_cast<BNGradCPUKernel *>(cdata); auto bn_kernel = reinterpret_cast<BNGradCPUKernel *>(cdata);
auto error_code = bn_kernel->Execute(task_id); auto error_code = bn_kernel->Execute(task_id);
if (error_code != RET_OK) { if (error_code != RET_OK) {
MS_LOG(ERROR) << "BNGradRun error task_id[" << task_id << "] error_code[" << error_code << "]"; MS_LOG(ERROR) << "BNGradRun error task_id[" << task_id << "] error_code[" << error_code << "]";
@ -87,15 +125,24 @@ int BNGradRun(void *cdata, int task_id) {
} }
int BNGradCPUKernel::Run() { int BNGradCPUKernel::Run() {
auto *input_var = in_tensors_.at(4); stage_ = 0;
float *save_var = reinterpret_cast<float *>(input_var->MutableData()); thread_num_ = context_->thread_num_;
auto bn_param = reinterpret_cast<BNGradParameter *>(op_parameter_); if (thread_num_ == 1) {
float eps = bn_param->epsilon_; int error_code = ParallelLaunch(this->context_->thread_pool_, BNGradRun, this, thread_num_);
var2Invar(save_var, input_var->ElementsNum(), eps); if (error_code != RET_OK) {
int error_code = ParallelLaunch(this->context_->thread_pool_, BNGradRun, this, 1); MS_LOG(ERROR) << "BN function error error_code[" << error_code << "]";
if (error_code != RET_OK) { return RET_ERROR;
MS_LOG(ERROR) << "BN function error error_code[" << error_code << "]"; }
return RET_ERROR; } else {
const std::vector<int> threads = {thread_num_, 1, thread_num_};
for (size_t stage = 0; stage < threads.size(); stage++) {
stage_ = static_cast<int>(stage);
int error_code = ParallelLaunch(this->context_->thread_pool_, BNGradRun, this, threads.at(stage));
if (error_code != RET_OK) {
MS_LOG(ERROR) << "BN function error error_code[" << error_code << "]";
return RET_ERROR;
}
}
} }
return RET_OK; return RET_OK;
} }

View File

@ -33,6 +33,11 @@ class BNGradCPUKernel : public LiteKernel {
int ReSize() override; int ReSize() override;
int Run() override; int Run() override;
int Execute(int task_id); int Execute(int task_id);
private:
int thread_num_ = 1;
int stage_ = 0;
size_t ws_size_ = 0;
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BN_GRAD_H_ #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BN_GRAD_H_

View File

@ -54,9 +54,6 @@ int ConvolutionTrainCPUKernel::ReSize() {
conv_param_->group_ = (conv_param_->group_ == 0) ? conv_param_->input_channel_ : conv_param_->group_; conv_param_->group_ = (conv_param_->group_ == 0) ? conv_param_->input_channel_ : conv_param_->group_;
const int n = conv_param_->output_channel_ * conv_param_->group_; const int n = conv_param_->output_channel_ * conv_param_->group_;
const int k = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ / conv_param_->group_; const int k = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ / conv_param_->group_;
ws_size_ = chunk_ * k;
int mat_alloc = MatSizeTotal(chunk_, n, k, 0);
set_workspace_size((ws_size_ + mat_alloc) * sizeof(float));
do_img2col_ = (conv_param_->kernel_h_ == 1) && (conv_param_->kernel_w_ == 1) && (conv_param_->pad_d_ == 0) && do_img2col_ = (conv_param_->kernel_h_ == 1) && (conv_param_->kernel_w_ == 1) && (conv_param_->pad_d_ == 0) &&
(conv_param_->pad_u_ == 0) && (conv_param_->pad_l_ == 0) && (conv_param_->pad_r_ == 0) && (conv_param_->pad_u_ == 0) && (conv_param_->pad_l_ == 0) && (conv_param_->pad_r_ == 0) &&
@ -64,6 +61,16 @@ int ConvolutionTrainCPUKernel::ReSize() {
(conv_param_->stride_h_ == 1) && (conv_param_->stride_w_ == 1) && (conv_param_->group_ == 1) (conv_param_->stride_h_ == 1) && (conv_param_->stride_w_ == 1) && (conv_param_->group_ == 1)
? false ? false
: true; : true;
do_dw_ = (conv_param_->output_channel_ == conv_param_->group_) &&
(conv_param_->input_channel_ == conv_param_->output_channel_) && (conv_param_->dilation_h_ == 1) &&
(conv_param_->dilation_w_ == 1)
? true
: false;
ws_size_ = chunk_ * conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_;
ws_size_ = do_dw_ ? ws_size_ : ws_size_ / conv_param_->group_;
int mat_alloc = MatSizeTotal(chunk_, n, k, 0);
set_workspace_size((ws_size_ + mat_alloc) * sizeof(float));
return RET_OK; return RET_OK;
} }
@ -97,7 +104,25 @@ int ConvolutionTrainCPUKernel::Execute(int task_id) {
float *workspace_temp = static_cast<float *>(workspace()); float *workspace_temp = static_cast<float *>(workspace());
float *mat_workspace = workspace_temp + ws_size_; float *mat_workspace = workspace_temp + ws_size_;
if (do_img2col_) { if (do_dw_) {
const int kernel_spatial = k_h * k_w;
for (int i = 0; i < batch; ++i) {
for (int ci = 0; ci < m; ci += chunk_) {
int real_chunk = MSMIN(m - ci, chunk_);
float *mat_a = workspace_temp;
float *im = x_addr + (i * in_ch * in_h * in_w);
RollingIm2ColPackDwUnitFp32(im, conv_param_, mat_a, real_chunk, ci);
for (int j = 0; j < groups; ++j) {
const float *mat_b = w_addr + j * nweights / groups;
float *mat_c = y_addr + (i * groups) * n * m + j * (out_ch / groups) + ci * out_ch;
// float *im = x_addr + i * in_ch * in_h * in_w + j * (in_ch / groups);
// RollingIm2ColPackUnitFp32(im, conv_param_, mat_a, real_chunk, ci);
GemmMatmul(0, 1, real_chunk, n, k, 1, mat_a + (j * kernel_spatial), k * groups, mat_b, k, 0, mat_c, out_ch,
mat_workspace);
}
}
}
} else if (do_img2col_) {
for (int i = 0; i < batch; ++i) { for (int i = 0; i < batch; ++i) {
for (int j = 0; j < groups; ++j) { for (int j = 0; j < groups; ++j) {
for (int ci = 0; ci < m; ci += chunk_) { for (int ci = 0; ci < m; ci += chunk_) {

View File

@ -37,6 +37,7 @@ class ConvolutionTrainCPUKernel : public LiteKernel {
private: private:
int ws_size_ = 0; int ws_size_ = 0;
bool do_img2col_ = true; bool do_img2col_ = true;
bool do_dw_ = false;
#ifdef ENABLE_ARM32 #ifdef ENABLE_ARM32
const int chunk_ = C4NUM * 2; const int chunk_ = C4NUM * 2;
#else #else

View File

@ -17,6 +17,7 @@
#include "src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h" #include "src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "nnacl/pack.h" #include "nnacl/pack.h"
#include "nnacl/fp32_grad/convolution_grad_filter.h"
#include "nnacl/fp32_grad/pack_ext.h" #include "nnacl/fp32_grad/pack_ext.h"
#include "nnacl/fp32_grad/gemm.h" #include "nnacl/fp32_grad/gemm.h"
#include "include/errorcode.h" #include "include/errorcode.h"
@ -51,20 +52,25 @@ int ConvolutionGradFilterCPUKernel::ReSize() {
conv_param->output_h_ = dy_tensor->shape()[kNHWC_H]; conv_param->output_h_ = dy_tensor->shape()[kNHWC_H];
conv_param->output_w_ = dy_tensor->shape()[kNHWC_W]; conv_param->output_w_ = dy_tensor->shape()[kNHWC_W];
ws_size_ = chunk_ * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_;
int n = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_;
int k = conv_param->output_channel_ / conv_param->group_;
int thread_num = context_->thread_num_;
mat_alloc_ = MatSizeTotal(k, n, chunk_, 0);
set_workspace_size((ws_size_ + mat_alloc_ + (k * n)) * thread_num * sizeof(float));
do_img2col_ = (conv_param->kernel_h_ == 1) && (conv_param->kernel_w_ == 1) && (conv_param->pad_d_ == 0) && do_img2col_ = (conv_param->kernel_h_ == 1) && (conv_param->kernel_w_ == 1) && (conv_param->pad_d_ == 0) &&
(conv_param->pad_u_ == 0) && (conv_param->pad_l_ == 0) && (conv_param->pad_r_ == 0) && (conv_param->pad_u_ == 0) && (conv_param->pad_l_ == 0) && (conv_param->pad_r_ == 0) &&
(conv_param->dilation_h_ == 1) && (conv_param->dilation_w_ == 1) && (conv_param->stride_h_ == 1) && (conv_param->dilation_h_ == 1) && (conv_param->dilation_w_ == 1) && (conv_param->stride_h_ == 1) &&
(conv_param->stride_w_ == 1) && (conv_param->group_ == 1) (conv_param->stride_w_ == 1) && (conv_param->group_ == 1)
? false ? false
: true; : true;
do_dw_ = (conv_param->output_channel_ == conv_param->group_) &&
(conv_param->input_channel_ == conv_param->output_channel_) && (conv_param->dilation_h_ == 1) &&
(conv_param->dilation_w_ == 1)
? true
: false;
ws_size_ = chunk_ * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
ws_size_ = do_dw_ ? ws_size_ : ws_size_ / conv_param->group_;
int n = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_;
int k = conv_param->output_channel_ / conv_param->group_;
int thread_num = context_->thread_num_;
mat_alloc_ = MatSizeTotal(k, n, chunk_, 0);
set_workspace_size((ws_size_ + mat_alloc_ + (k * n)) * thread_num * sizeof(float));
return RET_OK; return RET_OK;
} }
@ -105,10 +111,38 @@ int ConvolutionGradFilterCPUKernel::Execute(int task_id) {
int start = stride * task_id; int start = stride * task_id;
int end = start + count; int end = start + count;
if (do_img2col_) { if (do_dw_) {
#ifdef ENABLE_ARM
stride = UP_DIV(k_h * k_w, thread_num);
count = MSMIN(stride, k_h * k_w - stride * task_id);
start = stride * task_id;
ConvDwFilterGrad(x_addr, dy_addr, dw_addr, start, count, conv_param);
#else
stride = UP_DIV(groups, thread_num);
count = MSMIN(stride, groups - stride * task_id);
start = stride * task_id;
end = start + count;
const int kernel_spatial = k_h * k_w;
for (int i = 0; i < batch; ++i) {
for (int ci = 0; ci < m; ci += chunk_) {
int real_chunk = MSMIN(m - ci, chunk_);
float *mat_b = workspace_temp + task_id * ws_size_;
float *im = x_addr + (i * in_ch * in_h * in_w);
RollingIm2ColPackDwUnitFp32(im, conv_param, mat_b, real_chunk, ci);
for (int j = start; j < end; ++j) {
float *mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups) + ci * out_ch;
float *mat_c = dw_addr + j * nweights / groups;
GemmMatmul(1, 0, k, n, real_chunk, 1, mat_a, out_ch, mat_b + (j * kernel_spatial), n * groups, 1, mat_c, n,
mat_workspace);
}
}
}
#endif
} else if (do_img2col_) {
for (int i = start; i < end; ++i) { for (int i = start; i < end; ++i) {
for (int j = 0; j < groups; ++j) { for (int ci = 0; ci < m; ci += chunk_) {
for (int ci = 0; ci < m; ci += chunk_) { for (int j = 0; j < groups; ++j) {
int real_chunk = MSMIN(m - ci, chunk_); int real_chunk = MSMIN(m - ci, chunk_);
float *mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups) + ci * out_ch; float *mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups) + ci * out_ch;
float *mat_b = workspace_temp + task_id * ws_size_; float *mat_b = workspace_temp + task_id * ws_size_;

View File

@ -38,6 +38,7 @@ class ConvolutionGradFilterCPUKernel : public LiteKernel {
private: private:
size_t ws_size_ = 0; size_t ws_size_ = 0;
bool do_img2col_ = true; bool do_img2col_ = true;
bool do_dw_ = false;
std::mutex lock_; std::mutex lock_;
size_t mat_alloc_ = 0; size_t mat_alloc_ = 0;
#ifdef ENABLE_ARM32 #ifdef ENABLE_ARM32

View File

@ -66,13 +66,20 @@ int PoolingGradCPUKernel::Execute(int task_id) {
auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
int stride = UP_DIV(pool_param->output_batch_, thread_num_);
int count = MSMIN(stride, pool_param->output_batch_ - stride * task_id);
int in_batch_size = pool_param->input_h_ * pool_param->input_w_ * pool_param->input_channel_;
int out_batch_size = pool_param->output_h_ * pool_param->output_w_ * pool_param->input_channel_;
std::fill(output_ptr + task_id * stride * in_batch_size, output_ptr + ((task_id * stride) + count) * in_batch_size,
0.f);
if (pool_param->pool_mode_ == PoolMode_MaxPool) { if (pool_param->pool_mode_ == PoolMode_MaxPool) {
auto dx_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
auto dy_ptr = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData()); auto dy_ptr = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData());
MaxPoolingGrad(input_ptr, dx_ptr, dy_ptr, output_ptr, pool_param, task_id); MaxPoolingGrad(input_ptr + task_id * stride * in_batch_size, dy_ptr + task_id * stride * out_batch_size,
output_ptr + task_id * stride * in_batch_size, count, pool_param);
} else { } else {
input_ptr = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData()); input_ptr = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData());
AvgPoolingGrad(input_ptr, output_ptr, pool_param, task_id); AvgPoolingGrad(input_ptr + task_id * stride * out_batch_size, output_ptr + task_id * stride * in_batch_size, count,
pool_param);
} }
return RET_OK; return RET_OK;
} }
@ -89,7 +96,8 @@ int PoolingGradImpl(void *cdata, int task_id) {
} }
int PoolingGradCPUKernel::Run() { int PoolingGradCPUKernel::Run() {
int error_code = ParallelLaunch(this->context_->thread_pool_, PoolingGradImpl, this, 1); thread_num_ = context_->thread_num_;
int error_code = ParallelLaunch(this->context_->thread_pool_, PoolingGradImpl, this, thread_num_);
if (error_code != RET_OK) { if (error_code != RET_OK) {
MS_LOG(ERROR) << "pooling error error_code[" << error_code << "]"; MS_LOG(ERROR) << "pooling error error_code[" << error_code << "]";
return RET_ERROR; return RET_ERROR;

View File

@ -40,6 +40,7 @@ class PoolingGradCPUKernel : public LiteKernel {
int Execute(int task_id); int Execute(int task_id);
private: private:
int thread_num_ = 1;
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel

View File

@ -0,0 +1,150 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32_grad/strided_slice_grad.h"
#include <vector>
#include <algorithm>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "nnacl/fp32_grad/strided_slice_grad.h"
#include "src/ops/populate/strided_slice_populate.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_StridedSliceGrad;
namespace mindspore::kernel {
int StridedSliceGradCPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;
}
param_ = reinterpret_cast<StridedSliceParameter *>(op_parameter_);
auto input = in_tensors_.at(0);
MS_ASSERT(input);
switch (input->data_type()) {
case kNumberTypeFloat32:
param_->data_type = kDataTypeFloat;
break;
default:
MS_LOG(ERROR) << "Not supported data type: " << input->data_type();
return RET_ERROR;
}
FillEmptyDims();
FillOutputDim();
return ReSize();
}
void StridedSliceGradCPUKernel::FillEmptyDims() {
int32_t begins[DIMENSION_7D];
int32_t ends[DIMENSION_7D];
int32_t strides[DIMENSION_7D];
int32_t input_shape[DIMENSION_7D];
int32_t i;
for (i = 0; i < param_->num_axes_; ++i) {
begins[i] = param_->begins_[i];
ends[i] = MSMIN(param_->ends_[i], param_->in_shape_[i]);
strides[i] = param_->strides_[i];
input_shape[i] = param_->in_shape_[i];
}
for (i = param_->num_axes_; i < param_->in_shape_length_; ++i) {
input_shape[i] = param_->in_shape_[i];
begins[i] = 0;
ends[i] = param_->in_shape_[i];
strides[i] = 1;
}
int32_t real_index = param_->in_shape_length_ - 1;
for (i = DIMENSION_7D - 1; i >= 0; --i) {
if (real_index >= 0) {
param_->begins_[i] = begins[real_index];
param_->ends_[i] = ends[real_index];
param_->strides_[i] = strides[real_index];
param_->in_shape_[i] = input_shape[real_index--];
} else {
param_->begins_[i] = 0;
param_->ends_[i] = 1;
param_->strides_[i] = 1;
param_->in_shape_[i] = 1;
}
}
param_->num_axes_ = DIMENSION_7D;
param_->in_shape_length_ = DIMENSION_7D;
for (i = 0; i < DIMENSION_7D; ++i) {
if (param_->begins_[i] < 0) {
param_->begins_[i] += param_->in_shape_[i];
}
if (param_->ends_[i] < 0) {
param_->ends_[i] += param_->in_shape_[i];
}
}
}
void StridedSliceGradCPUKernel::FillOutputDim() {
auto output = out_tensors_.at(0);
size_t out_size = output->shape().size();
for (size_t i = 0; i < DIMENSION_7D; i++) {
if (i < out_size) {
output_shape_.push_back(output->shape()[i]);
} else {
output_shape_.insert(output_shape_.begin(), 1);
}
}
}
int StridedSliceGradCPUKernel::ReSize() { return RET_OK; }
int StridedSliceGradImpl(void *cdata, int task_id) {
MS_ASSERT(cdata != nullptr);
auto slice = reinterpret_cast<StridedSliceGradCPUKernel *>(cdata);
auto error_code = slice->Execute(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "StridedSliceGrad Run error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}
int StridedSliceGradCPUKernel::Run() {
int error_code = ParallelLaunch(this->context_->thread_pool_, StridedSliceGradImpl, this, 1);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Strided slice error error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}
int StridedSliceGradCPUKernel::Execute(int task_id) {
auto input = in_tensors_.at(0);
auto output = out_tensors_.at(0);
MS_ASSERT(output);
int *po = output_shape_.data();
auto ret = DoStridedSliceGrad(reinterpret_cast<float *>(input->MutableData()),
reinterpret_cast<float *>(output->MutableData()), po, param_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "StridedSliceGrad error error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_StridedSliceGrad, LiteKernelCreator<StridedSliceGradCPUKernel>)
} // namespace mindspore::kernel

View File

@ -0,0 +1,50 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_STRIDED_SLICE_GRAD_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_STRIDED_SLICE_GRAD_H_
#include <vector>
#include "nnacl/fp32_grad/strided_slice_grad.h"
#include "src/lite_kernel.h"
namespace mindspore::kernel {
class StridedSliceGradCPUKernel : public LiteKernel {
public:
StridedSliceGradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
param_ = reinterpret_cast<StridedSliceParameter *>(parameter);
}
~StridedSliceGradCPUKernel() override = default;
int Init() override;
int ReSize() override;
int Run() override;
int Execute(int task_id);
private:
void FillEmptyDims();
void FillOutputDim();
void ParseMasks();
StridedSliceParameter *param_;
std::vector<int> output_shape_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_STRIDED_SLICE_GRAD_H_

View File

@ -49,6 +49,7 @@
#include "src/ops/smooth_l1_loss_grad.h" #include "src/ops/smooth_l1_loss_grad.h"
#include "nnacl/fp32_grad/smooth_l1_loss.h" #include "nnacl/fp32_grad/smooth_l1_loss.h"
#include "src/ops/arithmetic_grad.h" #include "src/ops/arithmetic_grad.h"
#include "src/ops/populate/strided_slice_populate.h"
namespace mindspore::kernel { namespace mindspore::kernel {
OpParameter *DefaultPopulateParameter(const mindspore::lite::PrimitiveC *primitive) { OpParameter *DefaultPopulateParameter(const mindspore::lite::PrimitiveC *primitive) {
@ -569,6 +570,9 @@ void PopulateTrainParameters() {
DefaultPopulateParameter); DefaultPopulateParameter);
lite::Registry SigmoidCrossEntropyWithLogitsGradRegistry(schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad, lite::Registry SigmoidCrossEntropyWithLogitsGradRegistry(schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad,
DefaultPopulateParameter); DefaultPopulateParameter);
lite::Registry FlattenGradParameterRegistry(schema::PrimitiveType_FlattenGrad, DefaultPopulateParameter);
lite::Registry StridedSliceGradParameterRegistry(schema::PrimitiveType_StridedSliceGrad,
mindspore::lite::PopulateStridedSliceParameter);
} }
} // namespace mindspore::kernel } // namespace mindspore::kernel

View File

@ -0,0 +1,72 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_TRAIN_TRANSFER_SESSION_H_
#define MINDSPORE_LITE_SRC_TRAIN_TRANSFER_SESSION_H_
#include <vector>
#include <string>
#include <tuple>
#include <unordered_map>
#include "src/ops/primitive_c.h"
#include "include/train_session.h"
#include "src/train/train_model.h"
#include "src/lite_session.h"
#include "src/train/train_session.h"
/*
Inheritance Diagram
+-------------------------------+
| session::LiteSession |
+--------+------------+---------+
/ \
+-----------------+-----+ +-------+------------+
| session::TrainSession | | lite::LiteSession |
+-----------------+-----+ +-------+------------+
\ /
+--------+------------+---------+
| lite::TrainSession |
+-------------------------------+
|
+--------+------------+---------+
| lite::TrasferSession |
+-------------------------------+
*/
namespace mindspore {
namespace lite {
class TransferSession : public lite::TrainSession {
public:
TransferSession();
explicit TransferSession(lite::LiteSession *backend_session);
~TransferSession();
int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override;
void BindThread(bool if_bind) override;
std::vector<tensor::MSTensor *> GetInputs() const override { return lite::LiteSession::GetInputs(); }
mindspore::tensor::MSTensor *GetInputsByTensorName(const std::string &tensor_name) const override {
return lite::LiteSession::GetInputsByTensorName(tensor_name);
}
protected:
lite::LiteSession *backend_session_;
private:
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_TRAIN_TRANSFER_SESSION_H_

View File

@ -1,13 +1,15 @@
mini_alexnet mini_alexnet
# mobilenetv1 mobilenetv1
mobilenetv2 mobilenetv2
mobilenetv3 mobilenetv3
lenet lenet
effnet effnet
# effnet_tune effnet_tune
# lenetv1 resnet
# resnet googlenet
# googlenet
# densenet # densenet
# shufflenetv2
# nin
# one_net # one_net
# lenetv1
#LAST #LAST

View File

@ -0,0 +1,82 @@
#!/bin/bash
# Print start msg after run testcase
function MS_PRINT_TESTCASE_END_MSG() {
echo -e "-----------------------------------------------------------------------------------------------------------------------------------"
}
function Print_Result() {
MS_PRINT_TESTCASE_END_MSG
while read line; do
arr=("${line}")
printf "%-15s %-20s %-90s %-7s\n" ${arr[0]} ${arr[1]} ${arr[2]} ${arr[3]}
done < $1
MS_PRINT_TESTCASE_END_MSG
}
basepath=$(pwd)
echo ${basepath}
# Example:run_net_export.sh -m /home/emir/Work/TestingEnv/train_models
epoch_num=1
while getopts "m:t:" opt; do
case ${opt} in
m)
models_path=${OPTARG}"/models_train"
echo "models_path is ${OPTARG}"
;;
t)
epoch_num=${OPTARG}
echo "train epoch num is ${OPTARG}"
;;
?)
echo "unknown para"
exit 1;;
esac
done
# Set models config filepath
models_mindspore_train_config=${basepath}/models_ms_train.cfg
logs_path=${basepath}/logs_train
rm -rf ${logs_path}
mkdir -p ${logs_path}
docker_image=mindspore/mindspore-gpu:1.1.0
# Export models
echo "Start Exporting models ..."
# Set log files
export_log_file=${logs_path}/export_log.txt
echo ' ' > ${export_log_file}
export_result_file=${logs_path}/export_result.txt
echo ' ' > ${export_result_file}
# Run export according to config file
cd $models_path || exit 1
if [[ -z "${CLOUD_MODEL_ZOO}" ]]; then
echo "CLOUD_MODEL_ZOO is not defined - exiting export models"
exit 1
fi
# Export mindspore train models:
while read line; do
model_name=${line}
if [[ $model_name == \#* ]]; then
continue
fi
echo ${model_name}'_train_export.py' >> "${export_log_file}"
echo 'exporting' ${model_name}
echo 'docker run --user '"$(id -u):$(id -g)"' --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true '${docker_image}' python '${models_path}'/'${model_name}'_train_export.py' >> "${export_log_file}"
docker run --user "$(id -u):$(id -g)" --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true "${docker_image}" python ${models_path}'/'${model_name}_train_export.py "${epoch_num}"
if [ $? = 0 ]; then
export_result='export mindspore '${model_name}'_train_export pass';echo ${export_result} >> ${export_result_file}
else
export_result='export mindspore '${model_name}'_train_export failed';echo ${export_result} >> ${export_result_file}
fi
done < ${models_mindspore_train_config}
Print_Result ${export_result_file}

View File

@ -1,7 +1,7 @@
#!/bin/bash #!/bin/bash
# Run Export on x86 platform and create output test files: # Run Export on x86 platform and create output test files:
docker_image=mindspore_dev:8 docker_image=
function Run_Export(){ function Run_Export(){
cd $models_path || exit 1 cd $models_path || exit 1
if [[ -z "${CLOUD_MODEL_ZOO}" ]]; then if [[ -z "${CLOUD_MODEL_ZOO}" ]]; then
@ -16,8 +16,13 @@ function Run_Export(){
fi fi
echo ${model_name}'_train_export.py' >> "${export_log_file}" echo ${model_name}'_train_export.py' >> "${export_log_file}"
echo 'exporting' ${model_name} echo 'exporting' ${model_name}
echo 'docker run --user '"$(id -u):$(id -g)"' --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true '${docker_image}' python '${models_path}'/'${model_name}'_train_export.py' >> "${export_log_file}" if [ -n "$docker_image" ]; then
docker run --user "$(id -u):$(id -g)" --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true "${docker_image}" python ${models_path}'/'${model_name}_train_export.py "${epoch_num}" echo 'docker run --user '"$(id -u):$(id -g)"' --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true '${docker_image}' python '${models_path}'/'${model_name}'_train_export.py' >> "${export_log_file}"
docker run --user "$(id -u):$(id -g)" --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true "${docker_image}" python ${models_path}'/'${model_name}_train_export.py "${epoch_num}"
else
echo 'CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} python '${models_path}'/'${model_name}'_train_export.py' >> "${export_log_file}"
CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} python ${models_path}'/'${model_name}_train_export.py "${epoch_num}"
fi
if [ $? = 0 ]; then if [ $? = 0 ]; then
export_result='export mindspore '${model_name}'_train_export pass';echo ${export_result} >> ${export_result_file} export_result='export mindspore '${model_name}'_train_export pass';echo ${export_result} >> ${export_result_file}
else else
@ -28,7 +33,7 @@ function Run_Export(){
# Run converter on x86 platform: # Run converter on x86 platform:
function Run_Converter() { function Run_Converter() {
# Unzip x86 runtime and convertor # Unzip x86 runtime and converter
cd ${x86_path} || exit 1 cd ${x86_path} || exit 1
tar -zxf mindspore-lite-${version}-train-linux-x64.tar.gz || exit 1 tar -zxf mindspore-lite-${version}-train-linux-x64.tar.gz || exit 1
@ -189,7 +194,7 @@ ENDM
if [ $? = 0 ]; then if [ $? = 0 ]; then
run_result=$1': '${model_name}'_train pass'; echo ${run_result} >> ${run_benchmark_train_result_file} run_result=$1': '${model_name}'_train pass'; echo ${run_result} >> ${run_benchmark_train_result_file}
else else
run_result=$1': '${model_name}'_train failed'; echo ${run_result} >> ${run_benchmark_train_result_file}; return 1 run_result=$1': '${model_name}'_train failed'; echo ${run_result} >> ${run_benchmark_train_result_file};
fi fi
done < ${models_mindspore_train_config} done < ${models_mindspore_train_config}
} }
@ -222,16 +227,15 @@ echo ${basepath}
# Example:run_benchmark_train.sh -r /home/emir/Work/TestingEnv/release -m /home/emir/Work/TestingEnv/train_models -i /home/emir/Work/TestingEnv/train_io -d "8KE5T19620002408" # Example:run_benchmark_train.sh -r /home/emir/Work/TestingEnv/release -m /home/emir/Work/TestingEnv/train_models -i /home/emir/Work/TestingEnv/train_io -d "8KE5T19620002408"
# For running on arm64, use -t to set platform tools path (for using adb commands) # For running on arm64, use -t to set platform tools path (for using adb commands)
epoch_num=1 epoch_num=1
threads=1 threads=2
train_io_path="" train_io_path=""
while getopts "r:m:d:i:e:vt:q:" opt; do while getopts "r:m:d:i:e:vt:q:D" opt; do
case ${opt} in case ${opt} in
r) r)
release_path=${OPTARG} release_path=${OPTARG}
echo "release_path is ${OPTARG}" echo "release_path is ${OPTARG}"
;; ;;
m) m)
models_path=${OPTARG}"/models_train" models_path=${OPTARG}"/models_train"
echo "models_path is ${OPTARG}" echo "models_path is ${OPTARG}"
;; ;;
@ -244,8 +248,9 @@ while getopts "r:m:d:i:e:vt:q:" opt; do
echo "device_id is ${OPTARG}" echo "device_id is ${OPTARG}"
;; ;;
e) e)
enable_export=${OPTARG} enable_export=1
echo "enable_export = ${OPTARG}" docker_image=${OPTARG}
echo "enable_export = 1, docker_image = ${OPTARG}"
;; ;;
v) v)
run_valgrind="valgrind --log-file=valgrind.log " run_valgrind="valgrind --log-file=valgrind.log "
@ -404,27 +409,27 @@ function Print_Benchmark_Result() {
done < ${run_benchmark_train_result_file} done < ${run_benchmark_train_result_file}
MS_PRINT_TESTCASE_END_MSG MS_PRINT_TESTCASE_END_MSG
} }
result=0
# Check benchmark_train result and return value # Check benchmark_train result and return value
if [[ ${Run_x86_status} != 0 ]];then if [[ ${Run_x86_status} != 0 ]];then
echo "Run_x86 failed" echo "Run_x86 failed"
cat ${run_x86_log_file} cat ${run_x86_log_file}
exit 1 result=1
fi fi
if [[ ${Run_arm64_status} != 0 ]];then if [[ ${Run_arm64_status} != 0 ]];then
echo "Run_arm64 failed" echo "Run_arm64 failed"
cat ${run_arm64_log_file} cat ${run_arm64_log_file}
exit 1 result=1
fi fi
if [[ ${Run_arm32_status} != 0 ]];then if [[ ${Run_arm32_status} != 0 ]];then
echo "Run_arm32 failed" echo "Run_arm32 failed"
cat ${run_arm32_log_file} cat ${run_arm32_log_file}
exit 1 result=1
fi fi
echo "Test ended - Results:" echo "Test ended - Results:"
Print_Benchmark_Result Print_Benchmark_Result
echo "Test run Time:" $DIFF echo "Test run Time:" $DIFF
exit 0 exit ${result}

View File

@ -79,15 +79,18 @@ TEST_F(TestPoolingGradFp32, AvgPoolingGradFp32) {
auto output_data = new float[output_data_size]; auto output_data = new float[output_data_size];
ASSERT_NE(output_data, nullptr); ASSERT_NE(output_data, nullptr);
// warm up loop // warm up loop
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
AvgPoolingGrad(input_data, output_data, pooling_param, 1); std::fill(output_data, output_data + output_data_size, 0.f);
AvgPoolingGrad(input_data, output_data, pooling_param->output_batch_, pooling_param);
} }
int loop_count = 100; int loop_count = 100;
auto time_start = mindspore::lite::GetTimeUs(); auto time_start = mindspore::lite::GetTimeUs();
for (int i = 0; i < loop_count; i++) { for (int i = 0; i < loop_count; i++) {
AvgPoolingGrad(input_data, output_data, pooling_param, 1); std::fill(output_data, output_data + output_data_size, 0.f);
AvgPoolingGrad(input_data, output_data, pooling_param->output_batch_, pooling_param);
} }
auto time_end = mindspore::lite::GetTimeUs(); auto time_end = mindspore::lite::GetTimeUs();
auto cost = time_end - time_start; auto cost = time_end - time_start;
@ -407,18 +410,21 @@ TEST_F(TestPoolingGradFp32, MaxPoolingGradFp32) {
std::string dx_path = "./test_data/pooling/maxpoolgradfp32_1_dx_1_28_28_3.bin"; std::string dx_path = "./test_data/pooling/maxpoolgradfp32_1_dx_1_28_28_3.bin";
auto dx_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(dx_path.c_str(), &input_size)); auto dx_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(dx_path.c_str(), &input_size));
ASSERT_NE(dx_data, nullptr); ASSERT_NE(dx_data, nullptr);
int in_batch_size =
pooling_param->input_h_ * pooling_param->input_w_ * pooling_param->input_channel_ * pooling_param->input_batch_;
auto output_data = new float[output_data_size]; auto output_data = new float[output_data_size];
ASSERT_NE(output_data, nullptr); ASSERT_NE(output_data, nullptr);
// warm up loop // warm up loop
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
MaxPoolingGrad(in_data, dx_data, dy_data, output_data, pooling_param, 1); std::fill(output_data, output_data + in_batch_size, 0.f);
MaxPoolingGrad(in_data, dy_data, output_data, pooling_param->output_batch_, pooling_param);
} }
int loop_count = 100; int loop_count = 100;
auto time_start = mindspore::lite::GetTimeUs(); auto time_start = mindspore::lite::GetTimeUs();
for (int i = 0; i < loop_count; i++) { for (int i = 0; i < loop_count; i++) {
MaxPoolingGrad(in_data, dx_data, dy_data, output_data, pooling_param, 1); std::fill(output_data, output_data + in_batch_size, 0.f);
MaxPoolingGrad(in_data, dy_data, output_data, pooling_param->output_batch_, pooling_param);
} }
auto time_end = mindspore::lite::GetTimeUs(); auto time_end = mindspore::lite::GetTimeUs();
auto cost = time_end - time_start; auto cost = time_end - time_start;

View File

@ -135,7 +135,6 @@ int NetTrain::ReadCalibData() {
MS_LOG(INFO) << "Start reading calibData file"; MS_LOG(INFO) << "Start reading calibData file";
std::string tensor_name; std::string tensor_name;
while (!in_file.eof()) { while (!in_file.eof()) {
getline(in_file, line); getline(in_file, line);
std::stringstream string_line1(line); std::stringstream string_line1(line);

View File

@ -79,7 +79,7 @@ class MS_API NetTrainFlags : public virtual FlagParser {
std::vector<std::string> input_data_list_; std::vector<std::string> input_data_list_;
DataType in_data_type_; DataType in_data_type_;
std::string in_data_type_in_ = "bin"; std::string in_data_type_in_ = "bin";
int cpu_bind_mode_ = 0; int cpu_bind_mode_ = 1;
// MarkPerformance // MarkPerformance
int num_threads_ = 1; int num_threads_ = 1;
int warm_up_loop_count_ = 0; int warm_up_loop_count_ = 0;

View File

@ -32,7 +32,6 @@ static const std::vector<schema::PrimitiveType> nhwcOpList = {
schema::PrimitiveType_PoolingGrad, schema::PrimitiveType_PoolingGrad,
schema::PrimitiveType_BiasGrad, schema::PrimitiveType_BiasGrad,
schema::PrimitiveType_BNGrad, schema::PrimitiveType_BNGrad,
schema::PrimitiveType_ActivationGrad,
schema::PrimitiveType_ApplyMomentum, schema::PrimitiveType_ApplyMomentum,
schema::PrimitiveType_Sgd, schema::PrimitiveType_Sgd,
schema::PrimitiveType_Adam, schema::PrimitiveType_Adam,
@ -219,6 +218,26 @@ STATUS NodeUtils::ConvertDims(mindspore::schema::Format src_format, const std::v
return RET_OK; return RET_OK;
} }
static bool IsKCHWSource(kTransFilterType type) {
return (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC || type == kKCHW2CKHW);
}
static bool IsCKHWSource(kTransFilterType type) {
return (type == kCKHW2HWCK || type == kCKHW2HWKC || type == kCKHW2KHWC);
}
static bool IsHWCKSource(kTransFilterType type) { return (type == kHWCK2KCHW || type == kHWCK2CKHW); }
static bool IsHWKCSource(kTransFilterType type) { return (type == kHWKC2KCHW || type == kHWKC2CKHW); }
static bool IsNHWCSource(kTransFilterType type) {
return (type == kNHWC2KCHW || type == kNHWC2HWCK || type == kNHWC2CKHW);
}
static bool IsCHWKSource(kTransFilterType type) { return (type == kCHWK2HWCK || type == kCHWK2KHWC); }
static bool IsKHWCSource(kTransFilterType type) { return (type == kKHWC2HWCK || type == kKHWC2CHWK); }
STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC, STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC,
int32_t *filterH, int32_t *filterW) { int32_t *filterH, int32_t *filterW) {
if (filterK == nullptr || filterC == nullptr || filterH == nullptr || filterW == nullptr) { if (filterK == nullptr || filterC == nullptr || filterH == nullptr || filterW == nullptr) {
@ -226,37 +245,37 @@ STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type,
return RET_NULL_PTR; return RET_NULL_PTR;
} }
MS_ASSERT(oriDims.size() == 4); MS_ASSERT(oriDims.size() == 4);
if (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC || type == kKCHW2CKHW) { if (IsKCHWSource(type)) {
*filterK = oriDims.at(KCHW_K); *filterK = oriDims.at(KCHW_K);
*filterC = oriDims.at(KCHW_C); *filterC = oriDims.at(KCHW_C);
*filterH = oriDims.at(KCHW_H); *filterH = oriDims.at(KCHW_H);
*filterW = oriDims.at(KCHW_W); *filterW = oriDims.at(KCHW_W);
} else if (type == kCKHW2HWCK || type == kCKHW2HWKC || type == kCKHW2KHWC) { } else if (IsCKHWSource(type)) {
*filterC = oriDims.at(CKHW_C); *filterC = oriDims.at(CKHW_C);
*filterK = oriDims.at(CKHW_K); *filterK = oriDims.at(CKHW_K);
*filterH = oriDims.at(CKHW_H); *filterH = oriDims.at(CKHW_H);
*filterW = oriDims.at(CKHW_W); *filterW = oriDims.at(CKHW_W);
} else if (type == kHWCK2KCHW || type == kHWCK2CKHW) { } else if (IsHWCKSource(type)) {
*filterH = oriDims.at(HWCK_H); *filterH = oriDims.at(HWCK_H);
*filterW = oriDims.at(HWCK_W); *filterW = oriDims.at(HWCK_W);
*filterC = oriDims.at(HWCK_C); *filterC = oriDims.at(HWCK_C);
*filterK = oriDims.at(HWCK_K); *filterK = oriDims.at(HWCK_K);
} else if (type == kHWKC2KCHW || type == kHWKC2CKHW) { } else if (IsHWKCSource(type)) {
*filterH = oriDims.at(HWKC_H); *filterH = oriDims.at(HWKC_H);
*filterW = oriDims.at(HWKC_W); *filterW = oriDims.at(HWKC_W);
*filterK = oriDims.at(HWKC_K); *filterK = oriDims.at(HWKC_K);
*filterC = oriDims.at(HWKC_C); *filterC = oriDims.at(HWKC_C);
} else if (type == kNHWC2KCHW || type == kNHWC2HWCK || type == kNHWC2CKHW) { } else if (IsNHWCSource(type)) {
*filterK = oriDims.at(NHWC_N); *filterK = oriDims.at(NHWC_N);
*filterH = oriDims.at(NHWC_H); *filterH = oriDims.at(NHWC_H);
*filterW = oriDims.at(NHWC_W); *filterW = oriDims.at(NHWC_W);
*filterC = oriDims.at(NHWC_C); *filterC = oriDims.at(NHWC_C);
} else if (type == kCHWK2HWCK || type == kCHWK2KHWC) { } else if (IsCHWKSource(type)) {
*filterC = oriDims.at(CHWK_C); *filterC = oriDims.at(CHWK_C);
*filterH = oriDims.at(CHWK_H); *filterH = oriDims.at(CHWK_H);
*filterW = oriDims.at(CHWK_W); *filterW = oriDims.at(CHWK_W);
*filterK = oriDims.at(CHWK_K); *filterK = oriDims.at(CHWK_K);
} else if (type == kKHWC2HWCK || type == kKHWC2CHWK) { } else if (IsKHWCSource(type)) {
*filterK = oriDims.at(KHWC_K); *filterK = oriDims.at(KHWC_K);
*filterH = oriDims.at(KHWC_H); *filterH = oriDims.at(KHWC_H);
*filterW = oriDims.at(KHWC_W); *filterW = oriDims.at(KHWC_W);
@ -290,6 +309,37 @@ STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filt
return RET_OK; return RET_OK;
} }
static int Convert2KHWC(int srcFormat) {
if (srcFormat == schema::Format::Format_KCHW) return kKCHW2KHWC;
if (srcFormat == schema::Format::Format_CKHW) return kCKHW2KHWC;
if (srcFormat == schema::Format::Format_CHWK) return kCHWK2KHWC;
return -1;
}
static int Convert2HWCK(int srcFormat) {
if (srcFormat == schema::Format::Format_KCHW) return kKCHW2HWCK;
if (srcFormat == schema::Format::Format_KHWC) return kKHWC2HWCK;
if (srcFormat == schema::Format::Format_CKHW) return kCKHW2HWCK;
if (srcFormat == schema::Format::Format_CHWK) return kCHWK2HWCK;
return -1;
}
static int Convert2KCHW(int srcFormat) {
if (srcFormat == schema::Format::Format_HWCK) return kHWCK2KCHW;
if (srcFormat == schema::Format::Format_HWKC) return kHWKC2KCHW;
if (srcFormat == schema::Format::Format_KHWC) return kKHWC2KCHW;
if (srcFormat == schema::Format::Format_CKHW) return kCKHW2KCHW;
if (srcFormat == schema::Format::Format_CHWK) return kCHWK2KCHW;
return -1;
}
static int Convert2CKHW(int srcFormat) {
if (srcFormat == schema::Format::Format_HWCK) return kHWCK2CKHW;
if (srcFormat == schema::Format::Format_HWKC) return kHWKC2CKHW;
if (srcFormat == schema::Format::Format_KCHW) return kKCHW2CKHW;
return -1;
}
STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) { STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) {
if (tensor == nullptr) { if (tensor == nullptr) {
MS_LOG(ERROR) << "tensor is null"; MS_LOG(ERROR) << "tensor is null";
@ -303,231 +353,40 @@ STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) {
auto srcFormat = tensor->format; auto srcFormat = tensor->format;
auto dataType = tensor->dataType; auto dataType = tensor->dataType;
STATUS status; STATUS status;
int convert = -1;
if (dstFormat == srcFormat) return RET_OK;
switch (dstFormat) { switch (dstFormat) {
case schema::Format::Format_KHWC: { case schema::Format::Format_KHWC:
switch (srcFormat) { convert = Convert2KHWC(srcFormat);
case schema::Format::Format_KCHW: break;
if (dataType == kNumberTypeFloat32) { case schema::Format::Format_HWCK:
status = TransFilterFormat<float>(tensor, kKCHW2KHWC); convert = Convert2HWCK(srcFormat);
} else if (dataType == kNumberTypeUInt8) { break;
status = TransFilterFormat<uint8_t>(tensor, kKCHW2KHWC); case schema::Format::Format_KCHW:
} else if (dataType == kNumberTypeInt8) { convert = Convert2KCHW(srcFormat);
status = TransFilterFormat<int8_t>(tensor, kKCHW2KHWC); break;
} else { case schema::Format::Format_CKHW:
MS_LOG(ERROR) << "Unsupported dataType: " << dataType; convert = Convert2CKHW(srcFormat);
return RET_ERROR; break;
}
break;
case schema::Format::Format_CKHW:
if (dataType == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kCKHW2KHWC);
} else if (dataType == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kCKHW2KHWC);
} else if (dataType == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kCKHW2KHWC);
} else {
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
return RET_ERROR;
}
break;
case schema::Format::Format_CHWK:
if (dataType == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kCHWK2KHWC);
} else if (dataType == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kCHWK2KHWC);
} else if (dataType == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kCHWK2KHWC);
} else {
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
return RET_ERROR;
}
break;
case schema::Format::Format_KHWC:
return RET_OK;
default:
MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to "
<< EnumNameFormat(dstFormat);
return RET_ERROR;
}
} break;
case schema::Format::Format_HWCK: {
switch (srcFormat) {
case schema::Format::Format_KCHW:
if (dataType == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kKCHW2HWCK);
} else if (dataType == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kKCHW2HWCK);
} else if (dataType == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kKCHW2HWCK);
} else {
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
return RET_ERROR;
}
break;
case schema::Format::Format_KHWC:
if (dataType == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kKHWC2HWCK);
} else if (dataType == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kKHWC2HWCK);
} else if (dataType == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kKHWC2HWCK);
} else {
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
return RET_ERROR;
}
break;
case schema::Format::Format_CKHW:
if (dataType == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kCKHW2HWCK);
} else if (dataType == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kCKHW2HWCK);
} else if (dataType == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kCKHW2HWCK);
} else {
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
return RET_ERROR;
}
break;
case schema::Format::Format_CHWK:
if (dataType == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kCHWK2HWCK);
} else if (dataType == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kCHWK2HWCK);
} else if (dataType == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kCHWK2HWCK);
} else {
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
return RET_ERROR;
}
break;
case schema::Format::Format_HWCK:
return RET_OK;
default:
MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to "
<< EnumNameFormat(dstFormat);
return RET_ERROR;
}
} break;
case schema::Format::Format_KCHW: {
switch (srcFormat) {
case schema::Format::Format_KCHW:
return RET_OK;
case schema::Format::Format_HWCK:
if (dataType == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kHWCK2KCHW);
} else if (dataType == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kHWCK2KCHW);
} else if (dataType == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kHWCK2KCHW);
} else {
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
return RET_ERROR;
}
break;
case schema::Format::Format_HWKC:
if (dataType == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kHWKC2KCHW);
} else if (dataType == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kHWKC2KCHW);
} else if (dataType == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kHWKC2KCHW);
} else {
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
return RET_ERROR;
}
break;
case schema::Format::Format_KHWC:
if (dataType == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kKHWC2KCHW);
} else if (dataType == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kKHWC2KCHW);
} else if (dataType == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kKHWC2KCHW);
} else {
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
return RET_ERROR;
}
break;
case schema::Format::Format_CKHW:
if (dataType == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kCKHW2KCHW);
} else if (dataType == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kCKHW2KCHW);
} else if (dataType == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kCKHW2KCHW);
} else {
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
return RET_ERROR;
}
break;
case schema::Format::Format_CHWK:
if (dataType == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kCHWK2KCHW);
} else if (dataType == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kCHWK2KCHW);
} else if (dataType == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kCHWK2KCHW);
} else {
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
return RET_ERROR;
}
break;
default:
MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to "
<< EnumNameFormat(dstFormat);
return RET_ERROR;
}
} break;
case schema::Format::Format_CKHW: {
switch (srcFormat) {
case schema::Format::Format_HWCK:
if (dataType == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kHWCK2CKHW);
} else if (dataType == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kHWCK2CKHW);
} else if (dataType == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kHWCK2CKHW);
} else {
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
return RET_ERROR;
}
break;
case schema::Format::Format_HWKC:
if (dataType == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kHWKC2CKHW);
} else if (dataType == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kHWKC2CKHW);
} else if (dataType == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kHWKC2CKHW);
} else {
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
return RET_ERROR;
}
break;
case schema::Format::Format_KCHW:
if (dataType == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kKCHW2CKHW);
} else if (dataType == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kKCHW2CKHW);
} else if (dataType == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kKCHW2CKHW);
} else {
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
return RET_ERROR;
}
break;
case schema::Format::Format_CKHW:
return RET_OK;
default:
MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to "
<< EnumNameFormat(dstFormat);
return RET_ERROR;
}
} break;
default: default:
MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to " convert = -1;
<< EnumNameFormat(dstFormat); }
return RET_ERROR; if (convert == -1) {
MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to " << EnumNameFormat(dstFormat);
return RET_ERROR;
}
if (dataType == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, static_cast<kTransFilterType>(convert));
} else if (dataType == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, static_cast<kTransFilterType>(convert));
} else if (dataType == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, static_cast<kTransFilterType>(convert));
} else {
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
return RET_ERROR;
} }
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "TransFilterData failed: " << status; MS_LOG(ERROR) << "TransFilterData failed: " << status;