forked from mindspore-Ecosystem/mindspore
tod new ops, performance improvment and bug fix
This commit is contained in:
parent
3bfcf8947c
commit
ae389ae1f9
|
@ -50,3 +50,35 @@ void backwardAll(const float *restrict in, const float *restrict yt, const float
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
void backwardP1(const float *restrict in, const float *restrict yt, const float *restrict mean,
|
||||||
|
const float *restrict invar, const float *restrict scale, int size, int ch, float *restrict dxhat_sum,
|
||||||
|
float *restrict dxhathat_sum, float *restrict dbias, float *restrict dscale) {
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
for (int c = 0; c < ch; c++) {
|
||||||
|
int ix = i * ch + c;
|
||||||
|
dbias[c] += yt[ix];
|
||||||
|
// dscale
|
||||||
|
float x_hat = (in[ix] - mean[c]) * invar[c];
|
||||||
|
dscale[c] += (yt[ix] * x_hat);
|
||||||
|
// dx_1
|
||||||
|
float dx_hat = yt[ix] * scale[c];
|
||||||
|
dxhat_sum[c] += dx_hat;
|
||||||
|
dxhathat_sum[c] += dx_hat * x_hat;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void backwardP2(const float *restrict in, const float *restrict yt, const float *restrict mean,
|
||||||
|
const float *restrict invar, const float *restrict scale, int size, int total_size, int ch,
|
||||||
|
const float *dxhat_sum, const float *dxhathat_sum, float *restrict dx) {
|
||||||
|
float N = (float)total_size;
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
for (int c = 0; c < ch; c++) {
|
||||||
|
// dx_2
|
||||||
|
int ix = i * ch + c;
|
||||||
|
float x_hat = (in[ix] - mean[c]) * invar[c];
|
||||||
|
float dx_hat = yt[ix] * scale[c];
|
||||||
|
dx[ix] = 1.0f / N * (invar[c]) * (N * dx_hat - dxhat_sum[c] - x_hat * dxhathat_sum[c]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -32,6 +32,10 @@ extern "C" {
|
||||||
void var2Invar(float *save_var, int size, float eps);
|
void var2Invar(float *save_var, int size, float eps);
|
||||||
void backwardAll(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size,
|
void backwardAll(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size,
|
||||||
int ch, float *dxhat_sum, float *dxhathat_sum, float *dbias, float *dscale, float *dx);
|
int ch, float *dxhat_sum, float *dxhathat_sum, float *dbias, float *dscale, float *dx);
|
||||||
|
void backwardP1(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size,
|
||||||
|
int ch, float *dxhat_sum, float *dxhathat_sum, float *dbias, float *dscale);
|
||||||
|
void backwardP2(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size,
|
||||||
|
int total_size, int ch, const float *dxhat_sum, const float *dxhathat_sum, float *dx);
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -0,0 +1,379 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "nnacl/fp32_grad/convolution_grad_filter.h"
|
||||||
|
#ifdef ENABLE_ARM
|
||||||
|
#include <arm_neon.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef ENABLE_ARM
|
||||||
|
static int FilterGrad16Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw,
|
||||||
|
const ConvParameter *conv_param) {
|
||||||
|
int in_h = conv_param->input_h_;
|
||||||
|
int in_w = conv_param->input_w_;
|
||||||
|
int k_h = conv_param->kernel_h_;
|
||||||
|
int k_w = conv_param->kernel_w_;
|
||||||
|
int batch = conv_param->output_batch_;
|
||||||
|
int out_ch = conv_param->output_channel_;
|
||||||
|
int in_ch = conv_param->input_channel_;
|
||||||
|
int out_h = conv_param->output_h_;
|
||||||
|
int out_w = conv_param->output_w_;
|
||||||
|
|
||||||
|
int m = out_h * out_w;
|
||||||
|
int x_size = in_h * in_w * in_ch;
|
||||||
|
int y_size = out_ch * out_h * out_w;
|
||||||
|
int k_spatial = k_w * k_h;
|
||||||
|
int i_kh = k_idx / k_w;
|
||||||
|
int i_kw = k_idx % k_w;
|
||||||
|
for (; i_c < (out_ch & ~15); i_c += 16) {
|
||||||
|
float32x4_t sum_03_4 = vdupq_n_f32(0.0f);
|
||||||
|
float32x4_t sum_47_4 = vdupq_n_f32(0.0f);
|
||||||
|
float32x4_t sum_9x_4 = vdupq_n_f32(0.0f);
|
||||||
|
float32x4_t sum_12x_4 = vdupq_n_f32(0.0f);
|
||||||
|
for (int b = 0; b < batch; ++b) {
|
||||||
|
const float *x_addr = &x[b * x_size];
|
||||||
|
const float *dy_addr = &dy[b * y_size];
|
||||||
|
for (int i = 0; i < m; i++) {
|
||||||
|
int idx = i;
|
||||||
|
int input_h = idx / out_w * conv_param->stride_h_;
|
||||||
|
int input_w = idx % out_w * conv_param->stride_w_;
|
||||||
|
int input_row = -conv_param->pad_u_ + i_kh + input_h;
|
||||||
|
int input_col = -conv_param->pad_l_ + i_kw + input_w;
|
||||||
|
if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
|
||||||
|
int offset_x = (input_row * in_w + input_col) * out_ch + i_c;
|
||||||
|
int offset_dy = idx * out_ch + i_c;
|
||||||
|
|
||||||
|
float32x4_t x_03_4 = vld1q_f32(x_addr + offset_x);
|
||||||
|
float32x4_t dy_03_4 = vld1q_f32(dy_addr + offset_dy);
|
||||||
|
sum_03_4 = vmlaq_f32(sum_03_4, x_03_4, dy_03_4);
|
||||||
|
|
||||||
|
float32x4_t x_47_4 = vld1q_f32(x_addr + offset_x + 4);
|
||||||
|
float32x4_t dy_47_4 = vld1q_f32(dy_addr + offset_dy + 4);
|
||||||
|
sum_47_4 = vmlaq_f32(sum_47_4, x_47_4, dy_47_4);
|
||||||
|
|
||||||
|
float32x4_t x_9x_4 = vld1q_f32(x_addr + offset_x + 8);
|
||||||
|
float32x4_t dy_9x_4 = vld1q_f32(dy_addr + offset_dy + 8);
|
||||||
|
sum_9x_4 = vmlaq_f32(sum_9x_4, x_9x_4, dy_9x_4);
|
||||||
|
|
||||||
|
float32x4_t x_12x_4 = vld1q_f32(x_addr + offset_x + 12);
|
||||||
|
float32x4_t dy_12x_4 = vld1q_f32(dy_addr + offset_dy + 12);
|
||||||
|
sum_12x_4 = vmlaq_f32(sum_12x_4, x_12x_4, dy_12x_4);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dw[(i_c + 0) * k_spatial + k_idx] = sum_03_4[0];
|
||||||
|
dw[(i_c + 1) * k_spatial + k_idx] = sum_03_4[1];
|
||||||
|
dw[(i_c + 2) * k_spatial + k_idx] = sum_03_4[2];
|
||||||
|
dw[(i_c + 3) * k_spatial + k_idx] = sum_03_4[3];
|
||||||
|
|
||||||
|
dw[(i_c + 4) * k_spatial + k_idx] = sum_47_4[0];
|
||||||
|
dw[(i_c + 5) * k_spatial + k_idx] = sum_47_4[1];
|
||||||
|
dw[(i_c + 6) * k_spatial + k_idx] = sum_47_4[2];
|
||||||
|
dw[(i_c + 7) * k_spatial + k_idx] = sum_47_4[3];
|
||||||
|
|
||||||
|
dw[(i_c + 8) * k_spatial + k_idx] = sum_9x_4[0];
|
||||||
|
dw[(i_c + 9) * k_spatial + k_idx] = sum_9x_4[1];
|
||||||
|
dw[(i_c + 10) * k_spatial + k_idx] = sum_9x_4[2];
|
||||||
|
dw[(i_c + 11) * k_spatial + k_idx] = sum_9x_4[3];
|
||||||
|
|
||||||
|
dw[(i_c + 12) * k_spatial + k_idx] = sum_12x_4[0];
|
||||||
|
dw[(i_c + 13) * k_spatial + k_idx] = sum_12x_4[1];
|
||||||
|
dw[(i_c + 14) * k_spatial + k_idx] = sum_12x_4[2];
|
||||||
|
dw[(i_c + 15) * k_spatial + k_idx] = sum_12x_4[3];
|
||||||
|
}
|
||||||
|
return i_c;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int FilterGrad12Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw,
|
||||||
|
const ConvParameter *conv_param) {
|
||||||
|
int in_h = conv_param->input_h_;
|
||||||
|
int in_w = conv_param->input_w_;
|
||||||
|
int k_h = conv_param->kernel_h_;
|
||||||
|
int k_w = conv_param->kernel_w_;
|
||||||
|
int batch = conv_param->output_batch_;
|
||||||
|
int out_ch = conv_param->output_channel_;
|
||||||
|
int in_ch = conv_param->input_channel_;
|
||||||
|
int out_h = conv_param->output_h_;
|
||||||
|
int out_w = conv_param->output_w_;
|
||||||
|
|
||||||
|
int m = out_h * out_w;
|
||||||
|
int x_size = in_h * in_w * in_ch;
|
||||||
|
int y_size = out_ch * out_h * out_w;
|
||||||
|
int k_spatial = k_w * k_h;
|
||||||
|
int i_kh = k_idx / k_w;
|
||||||
|
int i_kw = k_idx % k_w;
|
||||||
|
if ((out_ch - i_c) >= 12) {
|
||||||
|
float32x4_t sum_03_4 = vdupq_n_f32(0.0f);
|
||||||
|
float32x4_t sum_47_4 = vdupq_n_f32(0.0f);
|
||||||
|
float32x4_t sum_9x_4 = vdupq_n_f32(0.0f);
|
||||||
|
for (int b = 0; b < batch; ++b) {
|
||||||
|
const float *x_addr = &x[b * x_size];
|
||||||
|
const float *dy_addr = &dy[b * y_size];
|
||||||
|
|
||||||
|
for (int i = 0; i < m; i++) {
|
||||||
|
int idx = i;
|
||||||
|
int input_h = idx / out_w * conv_param->stride_h_;
|
||||||
|
int input_w = idx % out_w * conv_param->stride_w_;
|
||||||
|
int input_row = -conv_param->pad_u_ + i_kh + input_h;
|
||||||
|
int input_col = -conv_param->pad_l_ + i_kw + input_w;
|
||||||
|
if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
|
||||||
|
int offset_x = (input_row * in_w + input_col) * out_ch + i_c;
|
||||||
|
int offset_dy = idx * out_ch + i_c;
|
||||||
|
|
||||||
|
float32x4_t x_03_4 = vld1q_f32(x_addr + offset_x);
|
||||||
|
float32x4_t dy_03_4 = vld1q_f32(dy_addr + offset_dy);
|
||||||
|
sum_03_4 = vmlaq_f32(sum_03_4, x_03_4, dy_03_4);
|
||||||
|
|
||||||
|
float32x4_t x_47_4 = vld1q_f32(x_addr + offset_x + 4);
|
||||||
|
float32x4_t dy_47_4 = vld1q_f32(dy_addr + offset_dy + 4);
|
||||||
|
sum_47_4 = vmlaq_f32(sum_47_4, x_47_4, dy_47_4);
|
||||||
|
|
||||||
|
float32x4_t x_9x_4 = vld1q_f32(x_addr + offset_x + 8);
|
||||||
|
float32x4_t dy_9x_4 = vld1q_f32(dy_addr + offset_dy + 8);
|
||||||
|
sum_9x_4 = vmlaq_f32(sum_9x_4, x_9x_4, dy_9x_4);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dw[(i_c + 0) * k_spatial + k_idx] = sum_03_4[0];
|
||||||
|
dw[(i_c + 1) * k_spatial + k_idx] = sum_03_4[1];
|
||||||
|
dw[(i_c + 2) * k_spatial + k_idx] = sum_03_4[2];
|
||||||
|
dw[(i_c + 3) * k_spatial + k_idx] = sum_03_4[3];
|
||||||
|
|
||||||
|
dw[(i_c + 4) * k_spatial + k_idx] = sum_47_4[0];
|
||||||
|
dw[(i_c + 5) * k_spatial + k_idx] = sum_47_4[1];
|
||||||
|
dw[(i_c + 6) * k_spatial + k_idx] = sum_47_4[2];
|
||||||
|
dw[(i_c + 7) * k_spatial + k_idx] = sum_47_4[3];
|
||||||
|
|
||||||
|
dw[(i_c + 8) * k_spatial + k_idx] = sum_9x_4[0];
|
||||||
|
dw[(i_c + 9) * k_spatial + k_idx] = sum_9x_4[1];
|
||||||
|
dw[(i_c + 10) * k_spatial + k_idx] = sum_9x_4[2];
|
||||||
|
dw[(i_c + 11) * k_spatial + k_idx] = sum_9x_4[3];
|
||||||
|
|
||||||
|
i_c += 12;
|
||||||
|
}
|
||||||
|
return i_c;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int FilterGrad8Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw,
|
||||||
|
const ConvParameter *conv_param) {
|
||||||
|
int in_h = conv_param->input_h_;
|
||||||
|
int in_w = conv_param->input_w_;
|
||||||
|
int k_h = conv_param->kernel_h_;
|
||||||
|
int k_w = conv_param->kernel_w_;
|
||||||
|
int batch = conv_param->output_batch_;
|
||||||
|
int out_ch = conv_param->output_channel_;
|
||||||
|
int in_ch = conv_param->input_channel_;
|
||||||
|
int out_h = conv_param->output_h_;
|
||||||
|
int out_w = conv_param->output_w_;
|
||||||
|
|
||||||
|
int m = out_h * out_w;
|
||||||
|
int x_size = in_h * in_w * in_ch;
|
||||||
|
int y_size = out_ch * out_h * out_w;
|
||||||
|
int k_spatial = k_w * k_h;
|
||||||
|
int i_kh = k_idx / k_w;
|
||||||
|
int i_kw = k_idx % k_w;
|
||||||
|
|
||||||
|
if ((out_ch - i_c) >= 8) {
|
||||||
|
float32x4_t sum_03_4 = vdupq_n_f32(0.0f);
|
||||||
|
float32x4_t sum_47_4 = vdupq_n_f32(0.0f);
|
||||||
|
for (int b = 0; b < batch; ++b) {
|
||||||
|
const float *x_addr = &x[b * x_size];
|
||||||
|
const float *dy_addr = &dy[b * y_size];
|
||||||
|
|
||||||
|
for (int i = 0; i < m; i++) {
|
||||||
|
int idx = i;
|
||||||
|
int input_h = idx / out_w * conv_param->stride_h_;
|
||||||
|
int input_w = idx % out_w * conv_param->stride_w_;
|
||||||
|
int input_row = -conv_param->pad_u_ + i_kh + input_h;
|
||||||
|
int input_col = -conv_param->pad_l_ + i_kw + input_w;
|
||||||
|
if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
|
||||||
|
int offset_x = (input_row * in_w + input_col) * out_ch + i_c;
|
||||||
|
int offset_dy = idx * out_ch + i_c;
|
||||||
|
|
||||||
|
float32x4_t x_03_4 = vld1q_f32(x_addr + offset_x);
|
||||||
|
float32x4_t dy_03_4 = vld1q_f32(dy_addr + offset_dy);
|
||||||
|
sum_03_4 = vmlaq_f32(sum_03_4, x_03_4, dy_03_4);
|
||||||
|
|
||||||
|
float32x4_t x_47_4 = vld1q_f32(x_addr + offset_x + 4);
|
||||||
|
float32x4_t dy_47_4 = vld1q_f32(dy_addr + offset_dy + 4);
|
||||||
|
sum_47_4 = vmlaq_f32(sum_47_4, x_47_4, dy_47_4);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dw[(i_c + 0) * k_spatial + k_idx] = sum_03_4[0];
|
||||||
|
dw[(i_c + 1) * k_spatial + k_idx] = sum_03_4[1];
|
||||||
|
dw[(i_c + 2) * k_spatial + k_idx] = sum_03_4[2];
|
||||||
|
dw[(i_c + 3) * k_spatial + k_idx] = sum_03_4[3];
|
||||||
|
|
||||||
|
dw[(i_c + 4) * k_spatial + k_idx] = sum_47_4[0];
|
||||||
|
dw[(i_c + 5) * k_spatial + k_idx] = sum_47_4[1];
|
||||||
|
dw[(i_c + 6) * k_spatial + k_idx] = sum_47_4[2];
|
||||||
|
dw[(i_c + 7) * k_spatial + k_idx] = sum_47_4[3];
|
||||||
|
i_c += 8;
|
||||||
|
}
|
||||||
|
return i_c;
|
||||||
|
}
|
||||||
|
static int FilterGrad4Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw,
|
||||||
|
const ConvParameter *conv_param) {
|
||||||
|
int in_h = conv_param->input_h_;
|
||||||
|
int in_w = conv_param->input_w_;
|
||||||
|
int k_h = conv_param->kernel_h_;
|
||||||
|
int k_w = conv_param->kernel_w_;
|
||||||
|
int batch = conv_param->output_batch_;
|
||||||
|
int out_ch = conv_param->output_channel_;
|
||||||
|
int in_ch = conv_param->input_channel_;
|
||||||
|
int out_h = conv_param->output_h_;
|
||||||
|
int out_w = conv_param->output_w_;
|
||||||
|
|
||||||
|
int m = out_h * out_w;
|
||||||
|
int x_size = in_h * in_w * in_ch;
|
||||||
|
int y_size = out_ch * out_h * out_w;
|
||||||
|
int k_spatial = k_w * k_h;
|
||||||
|
int i_kh = k_idx / k_w;
|
||||||
|
int i_kw = k_idx % k_w;
|
||||||
|
if ((out_ch - i_c) >= 4) {
|
||||||
|
float32x4_t sum_4 = vdupq_n_f32(0.0f);
|
||||||
|
|
||||||
|
for (int b = 0; b < batch; ++b) {
|
||||||
|
const float *x_addr = &x[b * x_size];
|
||||||
|
const float *dy_addr = &dy[b * y_size];
|
||||||
|
|
||||||
|
for (int i = 0; i < m; i++) {
|
||||||
|
int idx = i;
|
||||||
|
int input_h = idx / out_w * conv_param->stride_h_;
|
||||||
|
int input_w = idx % out_w * conv_param->stride_w_;
|
||||||
|
int input_row = -conv_param->pad_u_ + i_kh + input_h;
|
||||||
|
int input_col = -conv_param->pad_l_ + i_kw + input_w;
|
||||||
|
if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
|
||||||
|
int offset_x = (input_row * in_w + input_col) * out_ch + i_c;
|
||||||
|
int offset_dy = idx * out_ch + i_c;
|
||||||
|
|
||||||
|
float32x4_t x_4 = vld1q_f32(x_addr + offset_x);
|
||||||
|
float32x4_t dy_4 = vld1q_f32(dy_addr + offset_dy);
|
||||||
|
sum_4 = vmlaq_f32(sum_4, x_4, dy_4);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dw[(i_c + 0) * k_spatial + k_idx] = sum_4[0];
|
||||||
|
dw[(i_c + 1) * k_spatial + k_idx] = sum_4[1];
|
||||||
|
dw[(i_c + 2) * k_spatial + k_idx] = sum_4[2];
|
||||||
|
dw[(i_c + 3) * k_spatial + k_idx] = sum_4[3];
|
||||||
|
i_c += 4;
|
||||||
|
}
|
||||||
|
return i_c;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int Filtergrad2Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw,
|
||||||
|
const ConvParameter *conv_param) {
|
||||||
|
int in_h = conv_param->input_h_;
|
||||||
|
int in_w = conv_param->input_w_;
|
||||||
|
int k_h = conv_param->kernel_h_;
|
||||||
|
int k_w = conv_param->kernel_w_;
|
||||||
|
int batch = conv_param->output_batch_;
|
||||||
|
int out_ch = conv_param->output_channel_;
|
||||||
|
int in_ch = conv_param->input_channel_;
|
||||||
|
int out_h = conv_param->output_h_;
|
||||||
|
int out_w = conv_param->output_w_;
|
||||||
|
|
||||||
|
int m = out_h * out_w;
|
||||||
|
int x_size = in_h * in_w * in_ch;
|
||||||
|
int y_size = out_ch * out_h * out_w;
|
||||||
|
int k_spatial = k_w * k_h;
|
||||||
|
int i_kh = k_idx / k_w;
|
||||||
|
int i_kw = k_idx % k_w;
|
||||||
|
|
||||||
|
if ((out_ch - i_c) >= 2) {
|
||||||
|
float32x2_t sum_2 = vdup_n_f32(0.0f);
|
||||||
|
for (int b = 0; b < batch; ++b) {
|
||||||
|
const float *x_addr = &x[b * x_size];
|
||||||
|
const float *dy_addr = &dy[b * y_size];
|
||||||
|
|
||||||
|
for (int i = 0; i < m; i++) {
|
||||||
|
int idx = i;
|
||||||
|
int input_h = idx / out_w * conv_param->stride_h_;
|
||||||
|
int input_w = idx % out_w * conv_param->stride_w_;
|
||||||
|
int input_row = -conv_param->pad_u_ + i_kh + input_h;
|
||||||
|
int input_col = -conv_param->pad_l_ + i_kw + input_w;
|
||||||
|
if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
|
||||||
|
int offset_x = (input_row * in_w + input_col) * out_ch + i_c;
|
||||||
|
int offset_dy = idx * out_ch + i_c;
|
||||||
|
|
||||||
|
float32x2_t x_4 = vld1_f32(x_addr + offset_x);
|
||||||
|
float32x2_t dy_4 = vld1_f32(dy_addr + offset_dy);
|
||||||
|
sum_2 = vmla_f32(sum_2, x_4, dy_4);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dw[(i_c + 0) * k_spatial + k_idx] = sum_2[0];
|
||||||
|
dw[(i_c + 1) * k_spatial + k_idx] = sum_2[1];
|
||||||
|
i_c += 2;
|
||||||
|
}
|
||||||
|
return i_c += 2;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
int ConvDwFilterGrad(const float *x, const float *dy, float *dw, int start, int count,
|
||||||
|
const ConvParameter *conv_param) {
|
||||||
|
int in_h = conv_param->input_h_;
|
||||||
|
int in_w = conv_param->input_w_;
|
||||||
|
int k_h = conv_param->kernel_h_;
|
||||||
|
int k_w = conv_param->kernel_w_;
|
||||||
|
int batch = conv_param->output_batch_;
|
||||||
|
int out_ch = conv_param->output_channel_;
|
||||||
|
int in_ch = conv_param->input_channel_;
|
||||||
|
int out_h = conv_param->output_h_;
|
||||||
|
int out_w = conv_param->output_w_;
|
||||||
|
|
||||||
|
int m = out_h * out_w;
|
||||||
|
int x_size = in_h * in_w * in_ch;
|
||||||
|
int y_size = out_ch * out_h * out_w;
|
||||||
|
int k_spatial = k_w * k_h;
|
||||||
|
|
||||||
|
for (int i_k = 0; i_k < count; i_k++) {
|
||||||
|
int k_idx = start + i_k;
|
||||||
|
int i_kh = k_idx / k_w;
|
||||||
|
int i_kw = k_idx % k_w;
|
||||||
|
int i_c = 0;
|
||||||
|
#ifdef ENABLE_ARM
|
||||||
|
i_c = FilterGrad16Arm(x, dy, i_c, k_idx, dw, conv_param);
|
||||||
|
i_c = FilterGrad12Arm(x, dy, i_c, k_idx, dw, conv_param);
|
||||||
|
i_c = FilterGrad8Arm(x, dy, i_c, k_idx, dw, conv_param);
|
||||||
|
i_c = FilterGrad4Arm(x, dy, i_c, k_idx, dw, conv_param);
|
||||||
|
i_c = Filtergrad2Arm(x, dy, i_c, k_idx, dw, conv_param);
|
||||||
|
#endif
|
||||||
|
for (; i_c < out_ch; i_c++) {
|
||||||
|
float sum = 0;
|
||||||
|
for (int b = 0; b < batch; ++b) {
|
||||||
|
const float *x_addr = &x[b * x_size];
|
||||||
|
const float *dy_addr = &dy[b * y_size];
|
||||||
|
|
||||||
|
for (int i = 0; i < m; i++) {
|
||||||
|
int idx = i;
|
||||||
|
int input_h = idx / out_w * conv_param->stride_h_;
|
||||||
|
int input_w = idx % out_w * conv_param->stride_w_;
|
||||||
|
int input_row = -conv_param->pad_u_ + i_kh + input_h;
|
||||||
|
int input_col = -conv_param->pad_l_ + i_kw + input_w;
|
||||||
|
if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
|
||||||
|
int offset_x = (input_row * in_w + input_col) * out_ch + i_c;
|
||||||
|
int offset_dy = idx * out_ch + i_c;
|
||||||
|
sum += x_addr[offset_x] * dy_addr[offset_dy];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dw[i_c * k_spatial + k_idx] = sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
|
@ -0,0 +1,32 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_
|
||||||
|
#define MINDSPORE_LITE_NNACL_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
#include "nnacl/conv_parameter.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
int ConvDwFilterGrad(const float *x, const float *dy, float *dw, int start, int count, const ConvParameter *conv_param);
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif // MINDSPORE_LITE_NNACL_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_
|
|
@ -18,6 +18,56 @@
|
||||||
#include "nnacl/fp32_grad/pack_ext.h"
|
#include "nnacl/fp32_grad/pack_ext.h"
|
||||||
#include "nnacl/pack.h"
|
#include "nnacl/pack.h"
|
||||||
|
|
||||||
|
void RollingIm2ColPackDwUnitFp32(const float *in_data, const ConvParameter *conv_param, float *data_col_orig,
|
||||||
|
int real_cal_num, int start) {
|
||||||
|
const int pad_left = conv_param->pad_l_;
|
||||||
|
const int pad_up = conv_param->pad_u_;
|
||||||
|
|
||||||
|
const int stride_h = conv_param->stride_h_;
|
||||||
|
const int stride_w = conv_param->stride_w_;
|
||||||
|
|
||||||
|
const int dilation_h = conv_param->dilation_h_;
|
||||||
|
const int dilation_w = conv_param->dilation_w_;
|
||||||
|
|
||||||
|
const int kernel_h = conv_param->kernel_h_;
|
||||||
|
const int kernel_w = conv_param->kernel_w_;
|
||||||
|
|
||||||
|
const int in_height = conv_param->input_h_;
|
||||||
|
const int in_width = conv_param->input_w_;
|
||||||
|
|
||||||
|
const int output_w = conv_param->output_w_;
|
||||||
|
|
||||||
|
const int channels = conv_param->input_channel_;
|
||||||
|
const int stride = kernel_h * kernel_w;
|
||||||
|
|
||||||
|
int kernel_row, kernel_col;
|
||||||
|
|
||||||
|
for (int i = 0; i < real_cal_num; i++) {
|
||||||
|
int block_start = start + i;
|
||||||
|
int input_h = block_start / output_w * stride_h;
|
||||||
|
int input_w = block_start % output_w * stride_w;
|
||||||
|
float *data_col = data_col_orig + i * channels * stride;
|
||||||
|
for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
|
||||||
|
int input_row = -pad_up + kernel_row * dilation_h + input_h;
|
||||||
|
for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
|
||||||
|
int input_col = -pad_left + kernel_col * dilation_w + input_w;
|
||||||
|
if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) {
|
||||||
|
const int offset = (input_row * in_width + input_col) * channels;
|
||||||
|
for (int c = 0; c < channels; c++) {
|
||||||
|
data_col[c * stride] = in_data[offset + c];
|
||||||
|
}
|
||||||
|
data_col++;
|
||||||
|
} else {
|
||||||
|
for (int c = 0; c < channels; c++) {
|
||||||
|
data_col[c * stride] = 0;
|
||||||
|
}
|
||||||
|
data_col++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParameter *conv_param, int real_cal_num,
|
void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParameter *conv_param, int real_cal_num,
|
||||||
int start) {
|
int start) {
|
||||||
const int pad_left = conv_param->pad_l_;
|
const int pad_left = conv_param->pad_l_;
|
||||||
|
@ -90,85 +140,6 @@ void RollingIm2ColPackUnitFp32(const float *input_data, const ConvParameter *con
|
||||||
rolling_im2col_hwc(input_data, packed_input, conv_param, real_cal_num, block_index);
|
rolling_im2col_hwc(input_data, packed_input, conv_param, real_cal_num, block_index);
|
||||||
}
|
}
|
||||||
|
|
||||||
void im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, bool transpose) {
|
|
||||||
const int pad_left = conv_param->pad_l_;
|
|
||||||
const int pad_up = conv_param->pad_u_;
|
|
||||||
|
|
||||||
const int stride_h = conv_param->stride_h_;
|
|
||||||
const int stride_w = conv_param->stride_w_;
|
|
||||||
|
|
||||||
const int dilation_h = conv_param->dilation_h_;
|
|
||||||
const int dilation_w = conv_param->dilation_w_;
|
|
||||||
|
|
||||||
const int kernel_h = conv_param->kernel_h_;
|
|
||||||
const int kernel_w = conv_param->kernel_w_;
|
|
||||||
|
|
||||||
const int in_height = (transpose) ? conv_param->output_h_ : conv_param->input_h_;
|
|
||||||
const int in_width = (transpose) ? conv_param->output_w_ : conv_param->input_w_;
|
|
||||||
|
|
||||||
const int output_h = (transpose) ? conv_param->input_h_ : conv_param->output_h_;
|
|
||||||
const int output_w = (transpose) ? conv_param->input_w_ : conv_param->output_w_;
|
|
||||||
|
|
||||||
const int tot_channels = (transpose) ? conv_param->output_channel_ : conv_param->input_channel_;
|
|
||||||
const int channels = tot_channels / conv_param->group_;
|
|
||||||
int channel, kernel_row, kernel_col, output_rows, output_col;
|
|
||||||
|
|
||||||
if (transpose) {
|
|
||||||
for (channel = 0; channel < channels; channel++) {
|
|
||||||
for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
|
|
||||||
for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
|
|
||||||
int input_row = -pad_up + kernel_row * dilation_h;
|
|
||||||
for (output_rows = output_h; output_rows; output_rows--) {
|
|
||||||
if (!((unsigned)(input_row) < (unsigned)(in_height))) {
|
|
||||||
for (output_col = output_w; output_col; output_col--) {
|
|
||||||
*(data_row++) = 0;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
int input_col = -pad_left + kernel_col * dilation_w;
|
|
||||||
for (output_col = output_w; output_col; output_col--) {
|
|
||||||
if (((unsigned)(input_col) < (unsigned)(in_width))) {
|
|
||||||
const int offset = (input_row * in_width + input_col) * tot_channels + channel;
|
|
||||||
*(data_row++) = in_data[offset];
|
|
||||||
} else {
|
|
||||||
*(data_row++) = 0;
|
|
||||||
}
|
|
||||||
input_col += stride_w;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
input_row += stride_h;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
|
|
||||||
for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
|
|
||||||
for (channel = 0; channel < channels; channel++) {
|
|
||||||
int input_row = -pad_up + kernel_row * dilation_h;
|
|
||||||
for (output_rows = output_h; output_rows; output_rows--) {
|
|
||||||
if (!((unsigned)(input_row) < (unsigned)(in_height))) {
|
|
||||||
for (output_col = output_w; output_col; output_col--) {
|
|
||||||
*(data_row++) = 0;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
int input_col = -pad_left + kernel_col * dilation_w;
|
|
||||||
for (output_col = output_w; output_col; output_col--) {
|
|
||||||
if (((unsigned)(input_col) < (unsigned)(in_width))) {
|
|
||||||
const int offset = (input_row * in_width + input_col) * tot_channels + channel;
|
|
||||||
*(data_row++) = in_data[offset];
|
|
||||||
} else {
|
|
||||||
*(data_row++) = 0;
|
|
||||||
}
|
|
||||||
input_col += stride_w;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
input_row += stride_h;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, int rows, int start) {
|
void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, int rows, int start) {
|
||||||
const int pad_left = conv_param->pad_l_;
|
const int pad_left = conv_param->pad_l_;
|
||||||
const int pad_up = conv_param->pad_u_;
|
const int pad_up = conv_param->pad_u_;
|
||||||
|
|
|
@ -26,6 +26,9 @@ extern "C" {
|
||||||
|
|
||||||
void RollingIm2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input,
|
void RollingIm2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input,
|
||||||
int real_cal_num, int block_index);
|
int real_cal_num, int block_index);
|
||||||
|
void RollingIm2ColPackDwUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input,
|
||||||
|
int real_cal_num, int block_index);
|
||||||
|
|
||||||
void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParameter *conv_param, int rows, int start);
|
void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParameter *conv_param, int rows, int start);
|
||||||
void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, int rows, int start);
|
void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, int rows, int start);
|
||||||
void rolling_col2im_hwc(const float *data_col, float *data_im, const ConvParameter *conv_param, int rows, int start);
|
void rolling_col2im_hwc(const float *data_col, float *data_im, const ConvParameter *conv_param, int rows, int start);
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
#include <float.h>
|
#include <float.h>
|
||||||
#include "nnacl/fp32_grad/pooling_grad.h"
|
#include "nnacl/fp32_grad/pooling_grad.h"
|
||||||
|
|
||||||
void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id) {
|
void AvgPoolingGrad(const float *input_ptr, float *output_ptr, int count, PoolingParameter *pooling_param) {
|
||||||
int stride_w = pooling_param->stride_w_;
|
int stride_w = pooling_param->stride_w_;
|
||||||
int stride_h = pooling_param->stride_h_;
|
int stride_h = pooling_param->stride_h_;
|
||||||
int pad_w = pooling_param->pad_l_;
|
int pad_w = pooling_param->pad_l_;
|
||||||
|
@ -30,29 +30,58 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter
|
||||||
int in_h = pooling_param->input_h_;
|
int in_h = pooling_param->input_h_;
|
||||||
int output_w = pooling_param->output_w_;
|
int output_w = pooling_param->output_w_;
|
||||||
int output_h = pooling_param->output_h_;
|
int output_h = pooling_param->output_h_;
|
||||||
int output_batch = pooling_param->output_batch_;
|
|
||||||
|
|
||||||
memset(output_ptr, 0, in_h * in_w * channel * output_batch * sizeof(float));
|
const float kk = 1.0f / (float)(win_h * win_w);
|
||||||
float kk = (float)(win_h * win_w);
|
#if ENABLE_ARM
|
||||||
for (int ib = 0; ib < output_batch; ib++) {
|
const float32x4_t factor = vdupq_n_f32(kk);
|
||||||
|
#endif
|
||||||
|
for (int ib = 0; ib < count; ib++) {
|
||||||
float *out = &output_ptr[(ib * in_h * in_w * channel)];
|
float *out = &output_ptr[(ib * in_h * in_w * channel)];
|
||||||
const float *inPtr = &input_ptr[(ib * output_h * output_w * channel)];
|
const float *inPtr = &input_ptr[(ib * output_h * output_w * channel)];
|
||||||
// iterate over yt
|
// iterate over yt
|
||||||
for (int yh = 0; yh < output_h; yh++) {
|
for (int yh = 0; yh < output_h; yh++) {
|
||||||
|
int over_h = pad_h - yh * stride_h;
|
||||||
|
int kh_s = MSMAX(0, over_h);
|
||||||
|
int kh_e = MSMIN(win_h, in_h + over_h);
|
||||||
for (int yw = 0; yw < output_w; yw++) {
|
for (int yw = 0; yw < output_w; yw++) {
|
||||||
for (int ic = 0; ic < channel; ic++) {
|
int over_w = pad_w - yw * stride_w;
|
||||||
|
int kw_s = MSMAX(0, over_w);
|
||||||
|
int kw_e = MSMIN(win_w, in_w + over_w);
|
||||||
|
int ic = 0;
|
||||||
|
for (; ic < channel - 4; ic += 4) {
|
||||||
int idx = (yw + yh * output_w) * channel + ic;
|
int idx = (yw + yh * output_w) * channel + ic;
|
||||||
float delta = inPtr[idx] / kk;
|
#ifdef ENABLE_ARM
|
||||||
for (int kh = 0; kh < win_h; kh++) {
|
float32x4_t in = vld1q_f32(inPtr + idx);
|
||||||
|
float32x4_t delta = vmulq_f32(in, factor);
|
||||||
|
#else
|
||||||
|
float delta[4] = {inPtr[idx], inPtr[idx + 1], inPtr[idx + 2], inPtr[idx + 3]};
|
||||||
|
for (int i = 0; i < 4; i++) delta[i] *= kk;
|
||||||
|
#endif
|
||||||
|
for (int kh = kh_s; kh < kh_e; kh++) {
|
||||||
int xh = yh * stride_h + kh - pad_h;
|
int xh = yh * stride_h + kh - pad_h;
|
||||||
if ((xh < 0) || (xh >= in_h)) {
|
for (int kw = kw_s; kw < kw_e; kw++) {
|
||||||
continue;
|
|
||||||
}
|
|
||||||
for (int kw = 0; kw < win_w; kw++) {
|
|
||||||
int xw = yw * stride_w + kw - pad_w;
|
int xw = yw * stride_w + kw - pad_w;
|
||||||
if ((xw < 0) || (xw >= in_w)) {
|
#ifdef ENABLE_ARM
|
||||||
continue;
|
float *out_vec = out + (xw + in_w * xh) * channel + ic;
|
||||||
|
float32x4_t outr = vld1q_f32(out + (xw + in_w * xh) * channel + ic);
|
||||||
|
float32x4_t outs = vaddq_s32(outr, delta);
|
||||||
|
vst1q_f32(out_vec, outs);
|
||||||
|
#else
|
||||||
|
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
out[(xw + in_w * xh) * channel + ic + i] += ((float *)&delta)[i];
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (; ic < channel; ic++) {
|
||||||
|
int idx = (yw + yh * output_w) * channel + ic;
|
||||||
|
float delta = inPtr[idx] * kk;
|
||||||
|
for (int kh = kh_s; kh < kh_e; kh++) {
|
||||||
|
int xh = yh * stride_h + kh - pad_h;
|
||||||
|
for (int kw = kw_s; kw < kw_e; kw++) {
|
||||||
|
int xw = yw * stride_w + kw - pad_w;
|
||||||
out[(xw + in_w * xh) * channel + ic] += delta;
|
out[(xw + in_w * xh) * channel + ic] += delta;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -62,8 +91,17 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void MaxPoolingGrad(const float *input_ptr, const float *dx_ptr, const float *dy_ptr, float *output_ptr,
|
#ifdef ENABLE_ARM
|
||||||
PoolingParameter *pooling_param, int task_id) {
|
static int32x4_t MaxIndex(float32x4_t in, float32x4_t *max, int32x4_t index, int32x4_t prev_index) {
|
||||||
|
uint32x4_t res = vcgtq_f32(in, *max);
|
||||||
|
uint32x4_t m_index = vbslq_f32(res, index, prev_index);
|
||||||
|
*max = vbslq_f32(res, in, *max);
|
||||||
|
return m_index;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
void MaxPoolingGrad(const float *input_ptr, const float *dy_ptr, float *output_ptr, int output_batch,
|
||||||
|
PoolingParameter *pooling_param) {
|
||||||
int stride_w = pooling_param->stride_w_;
|
int stride_w = pooling_param->stride_w_;
|
||||||
int stride_h = pooling_param->stride_h_;
|
int stride_h = pooling_param->stride_h_;
|
||||||
int pad_w = pooling_param->pad_l_;
|
int pad_w = pooling_param->pad_l_;
|
||||||
|
@ -75,36 +113,71 @@ void MaxPoolingGrad(const float *input_ptr, const float *dx_ptr, const float *dy
|
||||||
int in_h = pooling_param->input_h_;
|
int in_h = pooling_param->input_h_;
|
||||||
int output_w = pooling_param->output_w_;
|
int output_w = pooling_param->output_w_;
|
||||||
int output_h = pooling_param->output_h_;
|
int output_h = pooling_param->output_h_;
|
||||||
int output_batch = pooling_param->output_batch_;
|
|
||||||
|
|
||||||
memset(output_ptr, 0, in_h * in_w * channel * output_batch * sizeof(float));
|
|
||||||
for (int ib = 0; ib < output_batch; ib++) {
|
for (int ib = 0; ib < output_batch; ib++) {
|
||||||
float *out = &output_ptr[(ib * in_h * in_w * channel)];
|
float *out = &output_ptr[(ib * in_h * in_w * channel)];
|
||||||
const float *inPtr = (const float *)(&input_ptr[(ib * in_h * in_w * channel)]);
|
const float *inPtr = &input_ptr[(ib * in_h * in_w * channel)];
|
||||||
const float *dyPtr = (const float *)(&dy_ptr[(ib * output_h * output_w * channel)]);
|
const float *dyPtr = &dy_ptr[(ib * output_h * output_w * channel)];
|
||||||
|
|
||||||
for (int yh = 0; yh < output_h; yh++) {
|
for (int yh = 0; yh < output_h; yh++) {
|
||||||
|
int over_h = pad_h - yh * stride_h;
|
||||||
|
int kh_s = MSMAX(0, over_h);
|
||||||
|
int kh_e = MSMIN(win_h, in_h + over_h);
|
||||||
for (int yw = 0; yw < output_w; yw++) {
|
for (int yw = 0; yw < output_w; yw++) {
|
||||||
for (int ic = 0; ic < channel; ic++) {
|
int over_w = pad_w - yw * stride_w;
|
||||||
|
int kw_s = MSMAX(0, over_w);
|
||||||
|
int kw_e = MSMIN(win_w, in_w + over_w);
|
||||||
|
int ic = 0;
|
||||||
|
for (; ic < channel - 4; ic += 4) {
|
||||||
int idx = (yw + yh * output_w) * channel + ic;
|
int idx = (yw + yh * output_w) * channel + ic;
|
||||||
|
#ifdef ENABLE_ARM
|
||||||
float delta = dyPtr[idx];
|
uint32x4_t max_idx = vdupq_n_u32(0);
|
||||||
|
float32x4_t max_val = vdupq_n_f32(-FLT_MAX);
|
||||||
|
float32x4_t delta = vld1q_f32(dyPtr + idx);
|
||||||
|
#else
|
||||||
|
float delta[4] = {dyPtr[idx], dyPtr[idx + 1], dyPtr[idx + 2], dyPtr[idx + 3]};
|
||||||
|
float max_val[4] = {-FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX};
|
||||||
|
int max_idx[4] = {0};
|
||||||
|
#endif
|
||||||
|
for (int kh = kh_s; kh < kh_e; kh++) {
|
||||||
|
int xh = yh * stride_h + kh - pad_h;
|
||||||
|
for (int kw = kw_s; kw < kw_e; kw++) {
|
||||||
|
int xw = yw * stride_w + kw - pad_w;
|
||||||
|
int val_idx = (xw + in_w * xh) * channel + ic;
|
||||||
|
#ifdef ENABLE_ARM
|
||||||
|
unsigned int val_idx_vec[] = {val_idx, val_idx + 1, val_idx + 2, val_idx + 3};
|
||||||
|
uint32x4_t index = vld1q_u32(val_idx_vec);
|
||||||
|
float32x4_t in = vld1q_f32(inPtr + val_idx);
|
||||||
|
max_idx = MaxIndex(in, &max_val, index, max_idx);
|
||||||
|
#else
|
||||||
|
float val[4] = {inPtr[val_idx], inPtr[val_idx + 1], inPtr[val_idx + 2], inPtr[val_idx + 3]};
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
if (val[i] > max_val[i]) {
|
||||||
|
max_val[i] = val[i];
|
||||||
|
max_idx[i] = val_idx + i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
out[((int *)&max_idx)[i]] += ((float *)&delta)[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (; ic < channel; ic++) {
|
||||||
float max_val = -FLT_MAX;
|
float max_val = -FLT_MAX;
|
||||||
int max_idx = 0;
|
int max_idx = 0;
|
||||||
for (int kh = 0; kh < win_h; kh++) {
|
int idx = (yw + yh * output_w) * channel + ic;
|
||||||
|
float delta = dyPtr[idx];
|
||||||
|
for (int kh = kh_s; kh < kh_e; kh++) {
|
||||||
int xh = yh * stride_h + kh - pad_h;
|
int xh = yh * stride_h + kh - pad_h;
|
||||||
if ((xh < 0) || (xh >= in_h)) {
|
int loop = kw_e - kw_s;
|
||||||
continue;
|
for (int kw = 0; kw < loop; kw++) {
|
||||||
}
|
int xw = yw * stride_w + kw + kw_s - pad_w;
|
||||||
for (int kw = 0; kw < win_w; kw++) {
|
int val_idx = (xw + in_w * xh) * channel + ic;
|
||||||
int xw = yw * stride_w + kw - pad_w;
|
float val = inPtr[val_idx];
|
||||||
if ((xw < 0) || (xw >= in_w)) {
|
if (val > max_val) {
|
||||||
continue;
|
max_val = val;
|
||||||
}
|
max_idx = val_idx;
|
||||||
|
|
||||||
if (inPtr[(xw + in_w * xh) * channel + ic] > max_val) {
|
|
||||||
max_val = inPtr[(xw + in_w * xh) * channel + ic];
|
|
||||||
max_idx = (xw + in_w * xh) * channel + ic;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,9 +22,9 @@
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id);
|
void AvgPoolingGrad(const float *input_ptr, float *output_ptr, int count, PoolingParameter *pooling_param);
|
||||||
void MaxPoolingGrad(const float *input_ptr, const float *dx_ptr, const float *dy_ptr, float *output_ptr,
|
void MaxPoolingGrad(const float *input_ptr, const float *dy_ptr, float *output_ptr, int output_batch,
|
||||||
PoolingParameter *pooling_param, int task_id);
|
PoolingParameter *pooling_param);
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -0,0 +1,61 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "nnacl/fp32_grad/strided_slice_grad.h"
|
||||||
|
#include "nnacl/errorcode.h"
|
||||||
|
|
||||||
|
static size_t CalcIndex(const int *shape, size_t size, int i, size_t pos) {
|
||||||
|
size_t res = 1;
|
||||||
|
for (size_t j = 0; j < size; j++) {
|
||||||
|
res *= shape[(i + 1) + j];
|
||||||
|
}
|
||||||
|
return (pos / res % shape[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
int DoStridedSliceGrad(const float *inputs, float *output, const int *dx_shape, StridedSliceParameter *param) {
|
||||||
|
if (inputs == NULL || output == NULL || param == NULL) {
|
||||||
|
return NNACL_NULL_PTR;
|
||||||
|
}
|
||||||
|
if (param->num_axes_ > DIMENSION_7D) {
|
||||||
|
return NNACL_PARAM_INVALID;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t size = 1;
|
||||||
|
int *s = param->strides_;
|
||||||
|
int *b = param->begins_;
|
||||||
|
for (int i = 0; i < DIMENSION_7D; i++) {
|
||||||
|
size *= param->in_shape_[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t pos = 0; pos < size; pos++) {
|
||||||
|
size_t i = CalcIndex(param->in_shape_, 6, 0, pos);
|
||||||
|
size_t j = CalcIndex(param->in_shape_, 5, 1, pos);
|
||||||
|
size_t k = CalcIndex(param->in_shape_, 4, 2, pos);
|
||||||
|
size_t l = CalcIndex(param->in_shape_, 3, 3, pos);
|
||||||
|
size_t m = CalcIndex(param->in_shape_, 2, 4, pos);
|
||||||
|
size_t n = CalcIndex(param->in_shape_, 1, 5, pos);
|
||||||
|
size_t o = CalcIndex(param->in_shape_, 0, 6, pos);
|
||||||
|
|
||||||
|
size_t input_idx =
|
||||||
|
(i * s[0] + b[0]) * dx_shape[1] * dx_shape[2] * dx_shape[3] * dx_shape[4] * dx_shape[5] * dx_shape[6] +
|
||||||
|
(j * s[1] + b[1]) * dx_shape[2] * dx_shape[3] * dx_shape[4] * dx_shape[5] * dx_shape[6] +
|
||||||
|
(k * s[2] + b[2]) * dx_shape[3] * dx_shape[4] * dx_shape[5] * dx_shape[6] +
|
||||||
|
(l * s[3] + b[3]) * dx_shape[4] * dx_shape[5] * dx_shape[6] + (m * s[4] + b[4]) * dx_shape[5] * dx_shape[6] +
|
||||||
|
(n * s[5] + b[5]) * dx_shape[6] + (o * s[6] + b[6]);
|
||||||
|
output[input_idx] = inputs[pos];
|
||||||
|
}
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
|
@ -0,0 +1,30 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_STRIDED_SLICE_GRAD_H_
|
||||||
|
#define MINDSPORE_LITE_NNACL_FP32_GRAD_STRIDED_SLICE_GRAD_H_
|
||||||
|
|
||||||
|
#include "nnacl/op_base.h"
|
||||||
|
#include "nnacl/strided_slice_parameter.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
int DoStridedSliceGrad(const float *inputs, float *output, const int *dx_shape, StridedSliceParameter *param);
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif // MINDSPORE_LITE_NNACL_FP32_GRAD_STRIDED_SLICE_GRAD_H_
|
|
@ -53,6 +53,7 @@
|
||||||
|
|
||||||
#define DIMENSION_4D 4
|
#define DIMENSION_4D 4
|
||||||
#define DIMENSION_6D 6
|
#define DIMENSION_6D 6
|
||||||
|
#define DIMENSION_7D 7
|
||||||
#define kInputIndex 0
|
#define kInputIndex 0
|
||||||
#define kWeightIndex 1
|
#define kWeightIndex 1
|
||||||
#define kBiasIndex 2
|
#define kBiasIndex 2
|
||||||
|
|
|
@ -273,6 +273,7 @@ union PrimitiveType {
|
||||||
RandomStandardNormal,
|
RandomStandardNormal,
|
||||||
CropAndResize,
|
CropAndResize,
|
||||||
Erf,
|
Erf,
|
||||||
|
StridedSliceGrad
|
||||||
}
|
}
|
||||||
|
|
||||||
enum QuantType: int {
|
enum QuantType: int {
|
||||||
|
|
|
@ -1259,6 +1259,18 @@ table RandomStandardNormal {
|
||||||
table CropAndResize {
|
table CropAndResize {
|
||||||
method : ResizeMethod;
|
method : ResizeMethod;
|
||||||
extrapolation_value : float;
|
extrapolation_value : float;
|
||||||
|
}
|
||||||
|
|
||||||
|
table StridedSliceGrad {
|
||||||
|
beginMask: int;
|
||||||
|
endMask: int;
|
||||||
|
ellipsisMask: int;
|
||||||
|
newAxisMask: int;
|
||||||
|
shrinkAxisMask: int;
|
||||||
|
begin: [int];
|
||||||
|
end: [int];
|
||||||
|
stride: [int];
|
||||||
|
isScale: [int];
|
||||||
}
|
}
|
||||||
|
|
||||||
table Erf {
|
table Erf {
|
||||||
|
|
|
@ -31,7 +31,7 @@ int FlattenGrad::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *>
|
||||||
MS_LOG(ERROR) << "FlattenGrad input or output is null!";
|
MS_LOG(ERROR) << "FlattenGrad input or output is null!";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) {
|
if (inputs_.size() != kDoubleNum || outputs_.size() != kSingleNum) {
|
||||||
MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size();
|
MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size();
|
||||||
return RET_INPUT_TENSOR_ERROR;
|
return RET_INPUT_TENSOR_ERROR;
|
||||||
}
|
}
|
||||||
|
@ -42,16 +42,15 @@ int FlattenGrad::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *>
|
||||||
return RET_INFER_INVALID;
|
return RET_INFER_INVALID;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto input_shape = input->shape();
|
auto output_size = inputs_.at(1)->shape().at(0);
|
||||||
std::vector<int> output_shape(2);
|
std::vector<int> output_shape(output_size);
|
||||||
output_shape.at(0) = input_shape.at(0);
|
for (int i = 0; i < output_size; i++) {
|
||||||
output_shape.at(1) = 1;
|
output_shape.at(i) = static_cast<int *>(inputs_.at(1)->data_c())[i];
|
||||||
for (size_t i = 1; i < input_shape.size(); i++) {
|
|
||||||
output_shape.at(1) *= input_shape.at(i);
|
|
||||||
}
|
}
|
||||||
output->set_shape(output_shape);
|
output->set_shape(output_shape);
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef PRIMITIVE_WRITEABLE
|
#ifdef PRIMITIVE_WRITEABLE
|
||||||
int FlattenGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
int FlattenGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||||
if (this->primitive_ == nullptr) {
|
if (this->primitive_ == nullptr) {
|
||||||
|
|
|
@ -91,6 +91,8 @@ int PoolingGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
|
||||||
attr->poolingMode = schema::PoolMode_MEAN_POOLING;
|
attr->poolingMode = schema::PoolMode_MEAN_POOLING;
|
||||||
} else if (prim.instance_name() == "AvgPoolGradGpu") {
|
} else if (prim.instance_name() == "AvgPoolGradGpu") {
|
||||||
attr->poolingMode = schema::PoolMode_MEAN_POOLING;
|
attr->poolingMode = schema::PoolMode_MEAN_POOLING;
|
||||||
|
} else if (prim.instance_name() == "AvgPoolGradCpu") {
|
||||||
|
attr->poolingMode = schema::PoolMode_MEAN_POOLING;
|
||||||
} else {
|
} else {
|
||||||
attr->poolingMode = schema::PoolMode_MAX_POOLING;
|
attr->poolingMode = schema::PoolMode_MAX_POOLING;
|
||||||
}
|
}
|
||||||
|
|
|
@ -202,6 +202,7 @@
|
||||||
#include "src/ops/smooth_l1_loss_grad.h"
|
#include "src/ops/smooth_l1_loss_grad.h"
|
||||||
#include "src/ops/sigmoid_cross_entropy_with_logits.h"
|
#include "src/ops/sigmoid_cross_entropy_with_logits.h"
|
||||||
#include "src/ops/sigmoid_cross_entropy_with_logits_grad.h"
|
#include "src/ops/sigmoid_cross_entropy_with_logits_grad.h"
|
||||||
|
#include "src/ops/strided_slice_grad.h"
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -724,6 +725,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
|
||||||
return NewPrimitiveC<SigmoidCrossEntropyWithLogitsGrad>(prim, inputs, quantType);
|
return NewPrimitiveC<SigmoidCrossEntropyWithLogitsGrad>(prim, inputs, quantType);
|
||||||
} else if (op_type == "Pad") {
|
} else if (op_type == "Pad") {
|
||||||
return NewPrimitiveC<Pad>(prim, inputs, quantType);
|
return NewPrimitiveC<Pad>(prim, inputs, quantType);
|
||||||
|
} else if (op_type == "StridedSliceGrad") {
|
||||||
|
return NewPrimitiveC<StridedSliceGrad>(prim, inputs, quantType);
|
||||||
#else
|
#else
|
||||||
} else if (op_type == "Conv2DBackpropInput") {
|
} else if (op_type == "Conv2DBackpropInput") {
|
||||||
return NewPrimitiveC<DeConv2D>(prim, inputs, quantType);
|
return NewPrimitiveC<DeConv2D>(prim, inputs, quantType);
|
||||||
|
@ -1102,6 +1105,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
|
||||||
return new (std::nothrow) SigmoidCrossEntropyWithLogits(primitive);
|
return new (std::nothrow) SigmoidCrossEntropyWithLogits(primitive);
|
||||||
case schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad:
|
case schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad:
|
||||||
return new (std::nothrow) SigmoidCrossEntropyWithLogitsGrad(primitive);
|
return new (std::nothrow) SigmoidCrossEntropyWithLogitsGrad(primitive);
|
||||||
|
case schema::PrimitiveType_StridedSliceGrad:
|
||||||
|
return new (std::nothrow) StridedSliceGrad(primitive);
|
||||||
#endif
|
#endif
|
||||||
default:
|
default:
|
||||||
MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type);
|
MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type);
|
||||||
|
|
|
@ -0,0 +1,266 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "src/ops/strided_slice_grad.h"
|
||||||
|
|
||||||
|
#ifndef PRIMITIVE_WRITEABLE
|
||||||
|
#include "src/ops/ops_register.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
|
||||||
|
#ifdef PRIMITIVE_WRITEABLE
|
||||||
|
int StridedSliceGrad::GetBeginMask() const { return this->primitive_->value.AsStridedSliceGrad()->beginMask; }
|
||||||
|
int StridedSliceGrad::GetEndMask() const { return this->primitive_->value.AsStridedSliceGrad()->endMask; }
|
||||||
|
int StridedSliceGrad::GetEllipsisMask() const { return this->primitive_->value.AsStridedSliceGrad()->ellipsisMask; }
|
||||||
|
int StridedSliceGrad::GetNewAxisMask() const { return this->primitive_->value.AsStridedSliceGrad()->newAxisMask; }
|
||||||
|
int StridedSliceGrad::GetShrinkAxisMask() const { return this->primitive_->value.AsStridedSliceGrad()->shrinkAxisMask; }
|
||||||
|
std::vector<int> StridedSliceGrad::GetBegin() const { return this->primitive_->value.AsStridedSliceGrad()->begin; }
|
||||||
|
std::vector<int> StridedSliceGrad::GetEnd() const { return this->primitive_->value.AsStridedSliceGrad()->end; }
|
||||||
|
std::vector<int> StridedSliceGrad::GetStride() const { return this->primitive_->value.AsStridedSliceGrad()->stride; }
|
||||||
|
std::vector<int> StridedSliceGrad::GetIsScale() const { return this->primitive_->value.AsStridedSliceGrad()->isScale; }
|
||||||
|
|
||||||
|
void StridedSliceGrad::SetBeginMask(int begin_mask) {
|
||||||
|
this->primitive_->value.AsStridedSliceGrad()->beginMask = begin_mask;
|
||||||
|
}
|
||||||
|
void StridedSliceGrad::SetEndMask(int end_mask) { this->primitive_->value.AsStridedSliceGrad()->endMask = end_mask; }
|
||||||
|
void StridedSliceGrad::SetEllipsisMask(int ellipsis_mask) {
|
||||||
|
this->primitive_->value.AsStridedSliceGrad()->ellipsisMask = ellipsis_mask;
|
||||||
|
}
|
||||||
|
void StridedSliceGrad::SetNewAxisMask(int new_axis_mask) {
|
||||||
|
this->primitive_->value.AsStridedSliceGrad()->newAxisMask = new_axis_mask;
|
||||||
|
}
|
||||||
|
void StridedSliceGrad::SetShrinkAxisMask(int shrink_axis_mask) {
|
||||||
|
this->primitive_->value.AsStridedSliceGrad()->shrinkAxisMask = shrink_axis_mask;
|
||||||
|
}
|
||||||
|
void StridedSliceGrad::SetBegin(const std::vector<int> &begin) {
|
||||||
|
this->primitive_->value.AsStridedSliceGrad()->begin = begin;
|
||||||
|
}
|
||||||
|
void StridedSliceGrad::SetEnd(const std::vector<int> &end) { this->primitive_->value.AsStridedSliceGrad()->end = end; }
|
||||||
|
void StridedSliceGrad::SetStride(const std::vector<int> &stride) {
|
||||||
|
this->primitive_->value.AsStridedSliceGrad()->stride = stride;
|
||||||
|
}
|
||||||
|
void StridedSliceGrad::SetIsScale(const std::vector<int> &is_scale) {
|
||||||
|
this->primitive_->value.AsStridedSliceGrad()->isScale = is_scale;
|
||||||
|
}
|
||||||
|
|
||||||
|
int StridedSliceGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||||
|
if (this->primitive_ == nullptr) {
|
||||||
|
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
||||||
|
if (this->primitive_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new primitiveT failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
this->primitive_->value.type = schema::PrimitiveType_StridedSliceGrad;
|
||||||
|
}
|
||||||
|
if (this->primitive_->value.type != schema::PrimitiveType_StridedSliceGrad) {
|
||||||
|
MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
if (this->primitive_->value.value == nullptr) {
|
||||||
|
auto attr = new (std::nothrow) schema::StridedSliceGradT();
|
||||||
|
if (attr == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new StridedSliceGrad failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
attr->beginMask = CastToInt(prim.GetAttr("begin_mask")).front();
|
||||||
|
attr->endMask = CastToInt(prim.GetAttr("end_mask")).front();
|
||||||
|
attr->ellipsisMask = CastToInt(prim.GetAttr("ellipsis_mask")).front();
|
||||||
|
attr->newAxisMask = CastToInt(prim.GetAttr("new_axis_mask")).front();
|
||||||
|
attr->shrinkAxisMask = CastToInt(prim.GetAttr("shrink_axis_mask")).front();
|
||||||
|
auto inputNodeFirst = inputs[kAnfPopulaterInputNumOne];
|
||||||
|
std::vector<int> beginVec;
|
||||||
|
GetAttrDataFromInput(inputNodeFirst, &beginVec);
|
||||||
|
attr->begin = beginVec;
|
||||||
|
|
||||||
|
auto inputNodeSecond = inputs[kAnfPopulaterInputNumTwo];
|
||||||
|
std::vector<int> endVec;
|
||||||
|
GetAttrDataFromInput(inputNodeSecond, &endVec);
|
||||||
|
attr->end = endVec;
|
||||||
|
|
||||||
|
auto inputNodeThird = inputs[kAnfPopulaterInputNumThree];
|
||||||
|
std::vector<int> strideVec;
|
||||||
|
GetAttrDataFromInput(inputNodeThird, &strideVec);
|
||||||
|
attr->stride = strideVec;
|
||||||
|
this->primitive_->value.value = attr;
|
||||||
|
if (this->primitive_->value.value == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new primitiveT value failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
int StridedSliceGrad::GetBeginMask() const { return this->primitive_->value_as_StridedSliceGrad()->beginMask(); }
|
||||||
|
int StridedSliceGrad::GetEndMask() const { return this->primitive_->value_as_StridedSliceGrad()->endMask(); }
|
||||||
|
int StridedSliceGrad::GetEllipsisMask() const { return this->primitive_->value_as_StridedSliceGrad()->ellipsisMask(); }
|
||||||
|
int StridedSliceGrad::GetNewAxisMask() const { return this->primitive_->value_as_StridedSliceGrad()->newAxisMask(); }
|
||||||
|
int StridedSliceGrad::GetShrinkAxisMask() const {
|
||||||
|
return this->primitive_->value_as_StridedSliceGrad()->shrinkAxisMask();
|
||||||
|
}
|
||||||
|
std::vector<int> StridedSliceGrad::GetBegin() const {
|
||||||
|
auto fb_vector = this->primitive_->value_as_StridedSliceGrad()->begin();
|
||||||
|
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||||
|
}
|
||||||
|
std::vector<int> StridedSliceGrad::GetEnd() const {
|
||||||
|
auto fb_vector = this->primitive_->value_as_StridedSliceGrad()->end();
|
||||||
|
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||||
|
}
|
||||||
|
std::vector<int> StridedSliceGrad::GetStride() const {
|
||||||
|
auto fb_vector = this->primitive_->value_as_StridedSliceGrad()->stride();
|
||||||
|
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||||
|
}
|
||||||
|
std::vector<int> StridedSliceGrad::GetIsScale() const {
|
||||||
|
auto fb_vector = this->primitive_->value_as_StridedSliceGrad()->isScale();
|
||||||
|
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||||
|
}
|
||||||
|
int StridedSliceGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||||
|
MS_ASSERT(nullptr != primitive);
|
||||||
|
MS_ASSERT(nullptr != fbb);
|
||||||
|
auto attr = primitive->value_as_StridedSliceGrad();
|
||||||
|
if (attr == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "value_as_StridedSliceGrad return nullptr";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
std::vector<int32_t> begin;
|
||||||
|
if (attr->begin() != nullptr) {
|
||||||
|
for (int i = 0; i < static_cast<int>(attr->begin()->size()); i++) {
|
||||||
|
begin.push_back(attr->begin()->data()[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::vector<int32_t> end;
|
||||||
|
if (attr->end() != nullptr) {
|
||||||
|
for (int i = 0; i < static_cast<int>(attr->end()->size()); i++) {
|
||||||
|
end.push_back(attr->end()->data()[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::vector<int32_t> stride;
|
||||||
|
if (attr->stride() != nullptr) {
|
||||||
|
for (int i = 0; i < static_cast<int>(attr->stride()->size()); i++) {
|
||||||
|
stride.push_back(attr->stride()->data()[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::vector<int32_t> isScale;
|
||||||
|
if (attr->isScale() != nullptr) {
|
||||||
|
for (int i = 0; i < static_cast<int>(attr->isScale()->size()); i++) {
|
||||||
|
isScale.push_back(attr->isScale()->data()[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto val_offset =
|
||||||
|
schema::CreateStridedSliceGradDirect(*fbb, attr->beginMask(), attr->endMask(), attr->ellipsisMask(),
|
||||||
|
attr->newAxisMask(), attr->shrinkAxisMask(), &begin, &end, &stride, &isScale);
|
||||||
|
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_StridedSliceGrad, val_offset.o);
|
||||||
|
fbb->Finish(prim_offset);
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
PrimitiveC *StridedSliceGradCreator(const schema::Primitive *primitive) {
|
||||||
|
return PrimitiveC::NewPrimitiveC<StridedSliceGrad>(primitive);
|
||||||
|
}
|
||||||
|
Registry StridedSliceGradRegistry(schema::PrimitiveType_StridedSliceGrad, StridedSliceGradCreator);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
constexpr size_t kStridedSliceGradOutputNum = 1;
|
||||||
|
constexpr size_t kStridedSliceGradMultiInputNumMax = 5;
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
int StridedSliceGrad::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
|
||||||
|
MS_ASSERT(this->primitive_ != nullptr);
|
||||||
|
if (outputs.size() != kStridedSliceGradOutputNum) {
|
||||||
|
MS_LOG(ERROR) << "Invalid output size:" << outputs.size();
|
||||||
|
return RET_PARAM_INVALID;
|
||||||
|
}
|
||||||
|
if (inputs.size() != kStridedSliceGradMultiInputNumMax) {
|
||||||
|
MS_LOG(ERROR) << "Invalid input size " << inputs.size();
|
||||||
|
return RET_PARAM_INVALID;
|
||||||
|
}
|
||||||
|
auto input = inputs.at(0);
|
||||||
|
outputs.front()->set_data_type(input->data_type());
|
||||||
|
outputs.at(0)->set_format(input->format());
|
||||||
|
MS_ASSERT(input != nullptr);
|
||||||
|
auto input_shape = input->shape();
|
||||||
|
auto inferflag = infer_flag();
|
||||||
|
|
||||||
|
in_shape_.clear();
|
||||||
|
if (inferflag) {
|
||||||
|
in_shape_.assign(input_shape.begin(), input_shape.end());
|
||||||
|
}
|
||||||
|
begins_.clear();
|
||||||
|
ends_.clear();
|
||||||
|
strides_.clear();
|
||||||
|
|
||||||
|
if (!CheckInputs(inputs)) {
|
||||||
|
MS_LOG(DEBUG) << "Do infer shape in runtime.";
|
||||||
|
return RET_INFER_INVALID;
|
||||||
|
}
|
||||||
|
|
||||||
|
// input order: dy, shapex, begins, ends, strides.
|
||||||
|
auto begin_tensor = inputs.at(2);
|
||||||
|
int *begin_data = reinterpret_cast<int *>(begin_tensor->MutableData());
|
||||||
|
auto end_tensor = inputs.at(3);
|
||||||
|
int *end_data = reinterpret_cast<int *>(end_tensor->MutableData());
|
||||||
|
auto stride_tensor = inputs.at(4);
|
||||||
|
int *stride_data = reinterpret_cast<int *>(stride_tensor->MutableData());
|
||||||
|
if (begin_data == nullptr || end_data == nullptr || stride_data == nullptr) {
|
||||||
|
return RET_INFER_ERR;
|
||||||
|
}
|
||||||
|
ndim_ = begin_tensor->ElementsNum();
|
||||||
|
for (size_t i = 0; i < ndim_; ++i) {
|
||||||
|
begins_.emplace_back(begin_data[i]);
|
||||||
|
ends_.emplace_back(end_data[i]);
|
||||||
|
strides_.emplace_back(stride_data[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// set all mask to original input shape
|
||||||
|
begins_mask_.resize(ndim_);
|
||||||
|
ends_mask_.resize(ndim_);
|
||||||
|
ellipsis_mask_.resize(ndim_);
|
||||||
|
new_axis_mask_.resize(ndim_);
|
||||||
|
shrink_axis_mask_.resize(ndim_);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < ndim_; i++) {
|
||||||
|
begins_mask_.at(i) = static_cast<uint32_t>(GetBeginMask()) & (1 << i);
|
||||||
|
ends_mask_.at(i) = static_cast<uint32_t>(GetEndMask()) & (1 << i);
|
||||||
|
ellipsis_mask_.at(i) = static_cast<uint32_t>(GetEllipsisMask()) & (1 << i);
|
||||||
|
new_axis_mask_.at(i) = static_cast<uint32_t>(GetNewAxisMask()) & (1 << i);
|
||||||
|
shrink_axis_mask_.at(i) = static_cast<uint32_t>(GetShrinkAxisMask()) & (1 << i);
|
||||||
|
}
|
||||||
|
|
||||||
|
ApplyNewAxisMask();
|
||||||
|
ApplyBeginMask();
|
||||||
|
ApplyEndMask();
|
||||||
|
ApplyEllipsisMask();
|
||||||
|
|
||||||
|
if (!inferflag) {
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto output_size = inputs.at(1)->shape().at(0);
|
||||||
|
std::vector<int> output_shape;
|
||||||
|
MS_ASSERT(inputs.at(1)->MutableData() != nullptr);
|
||||||
|
for (int i = 0; i < output_size; i++) {
|
||||||
|
output_shape.push_back(static_cast<int *>(inputs.at(1)->MutableData())[i]);
|
||||||
|
}
|
||||||
|
outputs.front()->set_shape(output_shape);
|
||||||
|
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,64 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_LITE_SRC_OPS_STRIDED_SLICE_GRAD_H_
|
||||||
|
#define MINDSPORE_LITE_SRC_OPS_STRIDED_SLICE_GRAD_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <set>
|
||||||
|
#include <cmath>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "src/ops/strided_slice.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
class StridedSliceGrad : public StridedSlice {
|
||||||
|
public:
|
||||||
|
StridedSliceGrad() = default;
|
||||||
|
~StridedSliceGrad() = default;
|
||||||
|
#ifdef PRIMITIVE_WRITEABLE
|
||||||
|
MS_DECLARE_PARENT(StridedSliceGrad, StridedSlice);
|
||||||
|
explicit StridedSliceGrad(schema::PrimitiveT *primitive) : StridedSlice(primitive) {}
|
||||||
|
void SetBeginMask(int begin_mask);
|
||||||
|
void SetEndMask(int end_mask);
|
||||||
|
void SetEllipsisMask(int ellipsis_mask);
|
||||||
|
void SetNewAxisMask(int new_axis_mask);
|
||||||
|
void SetShrinkAxisMask(int shrink_axis_mask);
|
||||||
|
void SetBegin(const std::vector<int> &begin);
|
||||||
|
void SetEnd(const std::vector<int> &end);
|
||||||
|
void SetStride(const std::vector<int> &stride);
|
||||||
|
void SetIsScale(const std::vector<int> &is_scale);
|
||||||
|
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||||
|
#else
|
||||||
|
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||||
|
#endif
|
||||||
|
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||||
|
// bool CheckInputs(std::vector<lite::Tensor *> inputs_);
|
||||||
|
int GetBeginMask() const;
|
||||||
|
int GetEndMask() const;
|
||||||
|
int GetEllipsisMask() const;
|
||||||
|
int GetNewAxisMask() const;
|
||||||
|
int GetShrinkAxisMask() const;
|
||||||
|
std::vector<int> GetBegin() const;
|
||||||
|
std::vector<int> GetEnd() const;
|
||||||
|
std::vector<int> GetStride() const;
|
||||||
|
std::vector<int> GetIsScale() const;
|
||||||
|
};
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_LITE_SRC_OPS_STRIDED_SLICE_GRAD_H_
|
|
@ -91,10 +91,12 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() {
|
||||||
|
|
||||||
// init bias
|
// init bias
|
||||||
size_t new_bias_size = oc4 * C4NUM * sizeof(float);
|
size_t new_bias_size = oc4 * C4NUM * sizeof(float);
|
||||||
bias_data_ = reinterpret_cast<float *>(malloc(new_bias_size));
|
|
||||||
if (bias_data_ == nullptr) {
|
if (bias_data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "malloc bias_data_ failed.";
|
bias_data_ = reinterpret_cast<float *>(malloc(new_bias_size));
|
||||||
return RET_MEMORY_FAILED;
|
if (bias_data_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "malloc bias_data_ failed.";
|
||||||
|
return RET_MEMORY_FAILED;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
memset(bias_data_, 0, new_bias_size);
|
memset(bias_data_, 0, new_bias_size);
|
||||||
if (in_tensors_.size() == kInputSize2) {
|
if (in_tensors_.size() == kInputSize2) {
|
||||||
|
|
|
@ -91,10 +91,6 @@ int FusedBatchnormCPUKernel::Run() {
|
||||||
memcpy(scale_, scale, in_tensors_[1]->Size());
|
memcpy(scale_, scale, in_tensors_[1]->Size());
|
||||||
memcpy(offset_, offset, in_tensors_[2]->Size());
|
memcpy(offset_, offset, in_tensors_[2]->Size());
|
||||||
|
|
||||||
// save for next iteration
|
|
||||||
memcpy(in_tensors_[3]->MutableData(), save_mean, in_tensors_[3]->Size());
|
|
||||||
memcpy(in_tensors_[4]->MutableData(), save_variance, in_tensors_[4]->Size());
|
|
||||||
|
|
||||||
trained_ = true; // trained at least once
|
trained_ = true; // trained at least once
|
||||||
}
|
}
|
||||||
auto ret = ParallelLaunch(this->context_->thread_pool_, BatchNormRun, this, op_parameter_->thread_num_);
|
auto ret = ParallelLaunch(this->context_->thread_pool_, BatchNormRun, this, op_parameter_->thread_num_);
|
||||||
|
|
|
@ -40,17 +40,16 @@ int ApplyMomentumCPUKernel::Execute(int task_id) {
|
||||||
|
|
||||||
size_t stride = UP_DIV(length, thread_count_);
|
size_t stride = UP_DIV(length, thread_count_);
|
||||||
size_t count = MSMIN(stride, length - stride * task_id);
|
size_t count = MSMIN(stride, length - stride * task_id);
|
||||||
|
|
||||||
size_t start = stride * task_id;
|
size_t start = stride * task_id;
|
||||||
size_t end = start + count;
|
size_t end = start + count;
|
||||||
|
|
||||||
if (apply_momentum_param_->use_nesterov_) {
|
if (apply_momentum_param_->use_nesterov_) {
|
||||||
for (size_t i = start; i < end; ++i) {
|
for (size_t i = start; i < end; i++) {
|
||||||
accumulate[i] = accumulate[i] * moment + gradient[i];
|
accumulate[i] = accumulate[i] * moment + gradient[i];
|
||||||
weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate;
|
weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (size_t i = start; i < end; ++i) {
|
for (size_t i = start; i < end; i++) {
|
||||||
accumulate[i] = accumulate[i] * moment + gradient[i];
|
accumulate[i] = accumulate[i] * moment + gradient[i];
|
||||||
weight[i] -= accumulate[i] * learning_rate;
|
weight[i] -= accumulate[i] * learning_rate;
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,10 @@
|
||||||
#include <math.h>
|
#include <math.h>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include <thread>
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
#include "schema/model_generated.h"
|
#include "schema/model_generated.h"
|
||||||
#include "src/kernel_registry.h"
|
#include "src/kernel_registry.h"
|
||||||
#include "nnacl/fp32_grad/batch_norm.h"
|
#include "nnacl/fp32_grad/batch_norm.h"
|
||||||
|
@ -34,7 +38,8 @@ namespace mindspore::kernel {
|
||||||
int BNGradCPUKernel::ReSize() {
|
int BNGradCPUKernel::ReSize() {
|
||||||
auto *input_x = in_tensors_.at(1);
|
auto *input_x = in_tensors_.at(1);
|
||||||
int channels = input_x->shape().at(kNHWC_C);
|
int channels = input_x->shape().at(kNHWC_C);
|
||||||
set_workspace_size(2 * channels * sizeof(float));
|
ws_size_ = 2 * channels;
|
||||||
|
set_workspace_size(ws_size_ * sizeof(float));
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -46,7 +51,9 @@ int BNGradCPUKernel::Execute(int task_id) {
|
||||||
auto *input_scale = in_tensors_.at(2);
|
auto *input_scale = in_tensors_.at(2);
|
||||||
auto *input_mean = in_tensors_.at(3);
|
auto *input_mean = in_tensors_.at(3);
|
||||||
auto *input_var = in_tensors_.at(4);
|
auto *input_var = in_tensors_.at(4);
|
||||||
|
auto bn_param = reinterpret_cast<BNGradParameter *>(op_parameter_);
|
||||||
|
int stage = stage_;
|
||||||
|
int thread_num = thread_num_;
|
||||||
float *save_mean = reinterpret_cast<float *>(input_mean->MutableData());
|
float *save_mean = reinterpret_cast<float *>(input_mean->MutableData());
|
||||||
float *save_var = reinterpret_cast<float *>(input_var->MutableData());
|
float *save_var = reinterpret_cast<float *>(input_var->MutableData());
|
||||||
|
|
||||||
|
@ -58,26 +65,57 @@ int BNGradCPUKernel::Execute(int task_id) {
|
||||||
int32_t spatial = input_x->Height() * input_x->Width();
|
int32_t spatial = input_x->Height() * input_x->Width();
|
||||||
|
|
||||||
float *workspace_temp = static_cast<float *>(workspace());
|
float *workspace_temp = static_cast<float *>(workspace());
|
||||||
std::fill(workspace_temp, workspace_temp + workspace_size() / sizeof(*workspace_temp), 0.f);
|
|
||||||
float *dxhat_sum = workspace_temp;
|
float *dxhat_sum = workspace_temp;
|
||||||
float *dxhathat_sum = dxhat_sum + channels;
|
float *dxhathat_sum = dxhat_sum + channels;
|
||||||
|
|
||||||
float *x = reinterpret_cast<float *>(input_x->MutableData());
|
float *x = reinterpret_cast<float *>(input_x->MutableData());
|
||||||
float *yt = reinterpret_cast<float *>(input_yt->MutableData());
|
float *yt = reinterpret_cast<float *>(input_yt->MutableData());
|
||||||
float *scale = reinterpret_cast<float *>(input_scale->MutableData());
|
float *scale = reinterpret_cast<float *>(input_scale->MutableData());
|
||||||
float *dx = reinterpret_cast<float *>(output_dx->MutableData());
|
float *dx = reinterpret_cast<float *>(output_dx->MutableData());
|
||||||
float *dbias = reinterpret_cast<float *>(output_bias->MutableData());
|
float *dbias = reinterpret_cast<float *>(output_bias->MutableData());
|
||||||
float *dscale = reinterpret_cast<float *>(output_scale->MutableData());
|
float *dscale = reinterpret_cast<float *>(output_scale->MutableData());
|
||||||
std::fill(dbias, dbias + channels, 0.f);
|
int total = spatial * batch;
|
||||||
std::fill(dscale, dscale + channels, 0.f);
|
int stride = UP_DIV(total, thread_num);
|
||||||
backwardAll(x, yt, save_mean, save_var, scale, batch * spatial, channels, dxhat_sum, dxhathat_sum, dbias, dscale, dx);
|
int count = MSMIN(stride, total - stride * task_id);
|
||||||
|
switch (stage) {
|
||||||
|
case 0: {
|
||||||
|
for (int job = task_id; job < 4; job += thread_num) {
|
||||||
|
switch (job) {
|
||||||
|
case 0:
|
||||||
|
var2Invar(save_var, input_var->ElementsNum(), bn_param->epsilon_);
|
||||||
|
break;
|
||||||
|
case 1:
|
||||||
|
std::fill(workspace_temp, workspace_temp + ws_size_, 0.f);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
std::fill(dbias, dbias + channels, 0.f);
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
std::fill(dscale, dscale + channels, 0.f);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (thread_num == 1) {
|
||||||
|
backwardAll(x, yt, save_mean, save_var, scale, total, channels, dxhat_sum, dxhathat_sum, dbias, dscale, dx);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 1: {
|
||||||
|
backwardP1(x, yt, save_mean, save_var, scale, total, channels, dxhat_sum, dxhathat_sum, dbias, dscale);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 2: {
|
||||||
|
backwardP2(x + task_id * stride * channels, yt + task_id * stride * channels, save_mean, save_var, scale, count,
|
||||||
|
total, channels, dxhat_sum, dxhathat_sum, dx + task_id * stride * channels);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int BNGradRun(void *cdata, int task_id) {
|
int BNGradRun(void *cdata, int task_id) {
|
||||||
MS_ASSERT(cdata != nullptr);
|
MS_ASSERT(cdata != nullptr);
|
||||||
auto bn_kernel = reinterpret_cast<BNGradCPUKernel *>(cdata);
|
auto bn_kernel = reinterpret_cast<BNGradCPUKernel *>(cdata);
|
||||||
|
|
||||||
auto error_code = bn_kernel->Execute(task_id);
|
auto error_code = bn_kernel->Execute(task_id);
|
||||||
if (error_code != RET_OK) {
|
if (error_code != RET_OK) {
|
||||||
MS_LOG(ERROR) << "BNGradRun error task_id[" << task_id << "] error_code[" << error_code << "]";
|
MS_LOG(ERROR) << "BNGradRun error task_id[" << task_id << "] error_code[" << error_code << "]";
|
||||||
|
@ -87,15 +125,24 @@ int BNGradRun(void *cdata, int task_id) {
|
||||||
}
|
}
|
||||||
|
|
||||||
int BNGradCPUKernel::Run() {
|
int BNGradCPUKernel::Run() {
|
||||||
auto *input_var = in_tensors_.at(4);
|
stage_ = 0;
|
||||||
float *save_var = reinterpret_cast<float *>(input_var->MutableData());
|
thread_num_ = context_->thread_num_;
|
||||||
auto bn_param = reinterpret_cast<BNGradParameter *>(op_parameter_);
|
if (thread_num_ == 1) {
|
||||||
float eps = bn_param->epsilon_;
|
int error_code = ParallelLaunch(this->context_->thread_pool_, BNGradRun, this, thread_num_);
|
||||||
var2Invar(save_var, input_var->ElementsNum(), eps);
|
if (error_code != RET_OK) {
|
||||||
int error_code = ParallelLaunch(this->context_->thread_pool_, BNGradRun, this, 1);
|
MS_LOG(ERROR) << "BN function error error_code[" << error_code << "]";
|
||||||
if (error_code != RET_OK) {
|
return RET_ERROR;
|
||||||
MS_LOG(ERROR) << "BN function error error_code[" << error_code << "]";
|
}
|
||||||
return RET_ERROR;
|
} else {
|
||||||
|
const std::vector<int> threads = {thread_num_, 1, thread_num_};
|
||||||
|
for (size_t stage = 0; stage < threads.size(); stage++) {
|
||||||
|
stage_ = static_cast<int>(stage);
|
||||||
|
int error_code = ParallelLaunch(this->context_->thread_pool_, BNGradRun, this, threads.at(stage));
|
||||||
|
if (error_code != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "BN function error error_code[" << error_code << "]";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,6 +33,11 @@ class BNGradCPUKernel : public LiteKernel {
|
||||||
int ReSize() override;
|
int ReSize() override;
|
||||||
int Run() override;
|
int Run() override;
|
||||||
int Execute(int task_id);
|
int Execute(int task_id);
|
||||||
|
|
||||||
|
private:
|
||||||
|
int thread_num_ = 1;
|
||||||
|
int stage_ = 0;
|
||||||
|
size_t ws_size_ = 0;
|
||||||
};
|
};
|
||||||
} // namespace mindspore::kernel
|
} // namespace mindspore::kernel
|
||||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BN_GRAD_H_
|
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BN_GRAD_H_
|
||||||
|
|
|
@ -54,9 +54,6 @@ int ConvolutionTrainCPUKernel::ReSize() {
|
||||||
conv_param_->group_ = (conv_param_->group_ == 0) ? conv_param_->input_channel_ : conv_param_->group_;
|
conv_param_->group_ = (conv_param_->group_ == 0) ? conv_param_->input_channel_ : conv_param_->group_;
|
||||||
const int n = conv_param_->output_channel_ * conv_param_->group_;
|
const int n = conv_param_->output_channel_ * conv_param_->group_;
|
||||||
const int k = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ / conv_param_->group_;
|
const int k = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ / conv_param_->group_;
|
||||||
ws_size_ = chunk_ * k;
|
|
||||||
int mat_alloc = MatSizeTotal(chunk_, n, k, 0);
|
|
||||||
set_workspace_size((ws_size_ + mat_alloc) * sizeof(float));
|
|
||||||
|
|
||||||
do_img2col_ = (conv_param_->kernel_h_ == 1) && (conv_param_->kernel_w_ == 1) && (conv_param_->pad_d_ == 0) &&
|
do_img2col_ = (conv_param_->kernel_h_ == 1) && (conv_param_->kernel_w_ == 1) && (conv_param_->pad_d_ == 0) &&
|
||||||
(conv_param_->pad_u_ == 0) && (conv_param_->pad_l_ == 0) && (conv_param_->pad_r_ == 0) &&
|
(conv_param_->pad_u_ == 0) && (conv_param_->pad_l_ == 0) && (conv_param_->pad_r_ == 0) &&
|
||||||
|
@ -64,6 +61,16 @@ int ConvolutionTrainCPUKernel::ReSize() {
|
||||||
(conv_param_->stride_h_ == 1) && (conv_param_->stride_w_ == 1) && (conv_param_->group_ == 1)
|
(conv_param_->stride_h_ == 1) && (conv_param_->stride_w_ == 1) && (conv_param_->group_ == 1)
|
||||||
? false
|
? false
|
||||||
: true;
|
: true;
|
||||||
|
do_dw_ = (conv_param_->output_channel_ == conv_param_->group_) &&
|
||||||
|
(conv_param_->input_channel_ == conv_param_->output_channel_) && (conv_param_->dilation_h_ == 1) &&
|
||||||
|
(conv_param_->dilation_w_ == 1)
|
||||||
|
? true
|
||||||
|
: false;
|
||||||
|
|
||||||
|
ws_size_ = chunk_ * conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_;
|
||||||
|
ws_size_ = do_dw_ ? ws_size_ : ws_size_ / conv_param_->group_;
|
||||||
|
int mat_alloc = MatSizeTotal(chunk_, n, k, 0);
|
||||||
|
set_workspace_size((ws_size_ + mat_alloc) * sizeof(float));
|
||||||
|
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
@ -97,7 +104,25 @@ int ConvolutionTrainCPUKernel::Execute(int task_id) {
|
||||||
float *workspace_temp = static_cast<float *>(workspace());
|
float *workspace_temp = static_cast<float *>(workspace());
|
||||||
float *mat_workspace = workspace_temp + ws_size_;
|
float *mat_workspace = workspace_temp + ws_size_;
|
||||||
|
|
||||||
if (do_img2col_) {
|
if (do_dw_) {
|
||||||
|
const int kernel_spatial = k_h * k_w;
|
||||||
|
for (int i = 0; i < batch; ++i) {
|
||||||
|
for (int ci = 0; ci < m; ci += chunk_) {
|
||||||
|
int real_chunk = MSMIN(m - ci, chunk_);
|
||||||
|
float *mat_a = workspace_temp;
|
||||||
|
float *im = x_addr + (i * in_ch * in_h * in_w);
|
||||||
|
RollingIm2ColPackDwUnitFp32(im, conv_param_, mat_a, real_chunk, ci);
|
||||||
|
for (int j = 0; j < groups; ++j) {
|
||||||
|
const float *mat_b = w_addr + j * nweights / groups;
|
||||||
|
float *mat_c = y_addr + (i * groups) * n * m + j * (out_ch / groups) + ci * out_ch;
|
||||||
|
// float *im = x_addr + i * in_ch * in_h * in_w + j * (in_ch / groups);
|
||||||
|
// RollingIm2ColPackUnitFp32(im, conv_param_, mat_a, real_chunk, ci);
|
||||||
|
GemmMatmul(0, 1, real_chunk, n, k, 1, mat_a + (j * kernel_spatial), k * groups, mat_b, k, 0, mat_c, out_ch,
|
||||||
|
mat_workspace);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (do_img2col_) {
|
||||||
for (int i = 0; i < batch; ++i) {
|
for (int i = 0; i < batch; ++i) {
|
||||||
for (int j = 0; j < groups; ++j) {
|
for (int j = 0; j < groups; ++j) {
|
||||||
for (int ci = 0; ci < m; ci += chunk_) {
|
for (int ci = 0; ci < m; ci += chunk_) {
|
||||||
|
|
|
@ -37,6 +37,7 @@ class ConvolutionTrainCPUKernel : public LiteKernel {
|
||||||
private:
|
private:
|
||||||
int ws_size_ = 0;
|
int ws_size_ = 0;
|
||||||
bool do_img2col_ = true;
|
bool do_img2col_ = true;
|
||||||
|
bool do_dw_ = false;
|
||||||
#ifdef ENABLE_ARM32
|
#ifdef ENABLE_ARM32
|
||||||
const int chunk_ = C4NUM * 2;
|
const int chunk_ = C4NUM * 2;
|
||||||
#else
|
#else
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
#include "src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h"
|
#include "src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h"
|
||||||
#include "src/kernel_registry.h"
|
#include "src/kernel_registry.h"
|
||||||
#include "nnacl/pack.h"
|
#include "nnacl/pack.h"
|
||||||
|
#include "nnacl/fp32_grad/convolution_grad_filter.h"
|
||||||
#include "nnacl/fp32_grad/pack_ext.h"
|
#include "nnacl/fp32_grad/pack_ext.h"
|
||||||
#include "nnacl/fp32_grad/gemm.h"
|
#include "nnacl/fp32_grad/gemm.h"
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
|
@ -51,20 +52,25 @@ int ConvolutionGradFilterCPUKernel::ReSize() {
|
||||||
conv_param->output_h_ = dy_tensor->shape()[kNHWC_H];
|
conv_param->output_h_ = dy_tensor->shape()[kNHWC_H];
|
||||||
conv_param->output_w_ = dy_tensor->shape()[kNHWC_W];
|
conv_param->output_w_ = dy_tensor->shape()[kNHWC_W];
|
||||||
|
|
||||||
ws_size_ = chunk_ * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_;
|
|
||||||
|
|
||||||
int n = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_;
|
|
||||||
int k = conv_param->output_channel_ / conv_param->group_;
|
|
||||||
int thread_num = context_->thread_num_;
|
|
||||||
mat_alloc_ = MatSizeTotal(k, n, chunk_, 0);
|
|
||||||
set_workspace_size((ws_size_ + mat_alloc_ + (k * n)) * thread_num * sizeof(float));
|
|
||||||
|
|
||||||
do_img2col_ = (conv_param->kernel_h_ == 1) && (conv_param->kernel_w_ == 1) && (conv_param->pad_d_ == 0) &&
|
do_img2col_ = (conv_param->kernel_h_ == 1) && (conv_param->kernel_w_ == 1) && (conv_param->pad_d_ == 0) &&
|
||||||
(conv_param->pad_u_ == 0) && (conv_param->pad_l_ == 0) && (conv_param->pad_r_ == 0) &&
|
(conv_param->pad_u_ == 0) && (conv_param->pad_l_ == 0) && (conv_param->pad_r_ == 0) &&
|
||||||
(conv_param->dilation_h_ == 1) && (conv_param->dilation_w_ == 1) && (conv_param->stride_h_ == 1) &&
|
(conv_param->dilation_h_ == 1) && (conv_param->dilation_w_ == 1) && (conv_param->stride_h_ == 1) &&
|
||||||
(conv_param->stride_w_ == 1) && (conv_param->group_ == 1)
|
(conv_param->stride_w_ == 1) && (conv_param->group_ == 1)
|
||||||
? false
|
? false
|
||||||
: true;
|
: true;
|
||||||
|
do_dw_ = (conv_param->output_channel_ == conv_param->group_) &&
|
||||||
|
(conv_param->input_channel_ == conv_param->output_channel_) && (conv_param->dilation_h_ == 1) &&
|
||||||
|
(conv_param->dilation_w_ == 1)
|
||||||
|
? true
|
||||||
|
: false;
|
||||||
|
|
||||||
|
ws_size_ = chunk_ * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
|
||||||
|
ws_size_ = do_dw_ ? ws_size_ : ws_size_ / conv_param->group_;
|
||||||
|
int n = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_;
|
||||||
|
int k = conv_param->output_channel_ / conv_param->group_;
|
||||||
|
int thread_num = context_->thread_num_;
|
||||||
|
mat_alloc_ = MatSizeTotal(k, n, chunk_, 0);
|
||||||
|
set_workspace_size((ws_size_ + mat_alloc_ + (k * n)) * thread_num * sizeof(float));
|
||||||
|
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
@ -105,10 +111,38 @@ int ConvolutionGradFilterCPUKernel::Execute(int task_id) {
|
||||||
int start = stride * task_id;
|
int start = stride * task_id;
|
||||||
int end = start + count;
|
int end = start + count;
|
||||||
|
|
||||||
if (do_img2col_) {
|
if (do_dw_) {
|
||||||
|
#ifdef ENABLE_ARM
|
||||||
|
stride = UP_DIV(k_h * k_w, thread_num);
|
||||||
|
count = MSMIN(stride, k_h * k_w - stride * task_id);
|
||||||
|
start = stride * task_id;
|
||||||
|
ConvDwFilterGrad(x_addr, dy_addr, dw_addr, start, count, conv_param);
|
||||||
|
#else
|
||||||
|
stride = UP_DIV(groups, thread_num);
|
||||||
|
count = MSMIN(stride, groups - stride * task_id);
|
||||||
|
start = stride * task_id;
|
||||||
|
end = start + count;
|
||||||
|
|
||||||
|
const int kernel_spatial = k_h * k_w;
|
||||||
|
for (int i = 0; i < batch; ++i) {
|
||||||
|
for (int ci = 0; ci < m; ci += chunk_) {
|
||||||
|
int real_chunk = MSMIN(m - ci, chunk_);
|
||||||
|
float *mat_b = workspace_temp + task_id * ws_size_;
|
||||||
|
float *im = x_addr + (i * in_ch * in_h * in_w);
|
||||||
|
RollingIm2ColPackDwUnitFp32(im, conv_param, mat_b, real_chunk, ci);
|
||||||
|
for (int j = start; j < end; ++j) {
|
||||||
|
float *mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups) + ci * out_ch;
|
||||||
|
float *mat_c = dw_addr + j * nweights / groups;
|
||||||
|
GemmMatmul(1, 0, k, n, real_chunk, 1, mat_a, out_ch, mat_b + (j * kernel_spatial), n * groups, 1, mat_c, n,
|
||||||
|
mat_workspace);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
} else if (do_img2col_) {
|
||||||
for (int i = start; i < end; ++i) {
|
for (int i = start; i < end; ++i) {
|
||||||
for (int j = 0; j < groups; ++j) {
|
for (int ci = 0; ci < m; ci += chunk_) {
|
||||||
for (int ci = 0; ci < m; ci += chunk_) {
|
for (int j = 0; j < groups; ++j) {
|
||||||
int real_chunk = MSMIN(m - ci, chunk_);
|
int real_chunk = MSMIN(m - ci, chunk_);
|
||||||
float *mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups) + ci * out_ch;
|
float *mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups) + ci * out_ch;
|
||||||
float *mat_b = workspace_temp + task_id * ws_size_;
|
float *mat_b = workspace_temp + task_id * ws_size_;
|
||||||
|
|
|
@ -38,6 +38,7 @@ class ConvolutionGradFilterCPUKernel : public LiteKernel {
|
||||||
private:
|
private:
|
||||||
size_t ws_size_ = 0;
|
size_t ws_size_ = 0;
|
||||||
bool do_img2col_ = true;
|
bool do_img2col_ = true;
|
||||||
|
bool do_dw_ = false;
|
||||||
std::mutex lock_;
|
std::mutex lock_;
|
||||||
size_t mat_alloc_ = 0;
|
size_t mat_alloc_ = 0;
|
||||||
#ifdef ENABLE_ARM32
|
#ifdef ENABLE_ARM32
|
||||||
|
|
|
@ -66,13 +66,20 @@ int PoolingGradCPUKernel::Execute(int task_id) {
|
||||||
auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
|
auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
|
||||||
auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
|
auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
|
||||||
|
|
||||||
|
int stride = UP_DIV(pool_param->output_batch_, thread_num_);
|
||||||
|
int count = MSMIN(stride, pool_param->output_batch_ - stride * task_id);
|
||||||
|
int in_batch_size = pool_param->input_h_ * pool_param->input_w_ * pool_param->input_channel_;
|
||||||
|
int out_batch_size = pool_param->output_h_ * pool_param->output_w_ * pool_param->input_channel_;
|
||||||
|
std::fill(output_ptr + task_id * stride * in_batch_size, output_ptr + ((task_id * stride) + count) * in_batch_size,
|
||||||
|
0.f);
|
||||||
if (pool_param->pool_mode_ == PoolMode_MaxPool) {
|
if (pool_param->pool_mode_ == PoolMode_MaxPool) {
|
||||||
auto dx_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
|
|
||||||
auto dy_ptr = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData());
|
auto dy_ptr = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData());
|
||||||
MaxPoolingGrad(input_ptr, dx_ptr, dy_ptr, output_ptr, pool_param, task_id);
|
MaxPoolingGrad(input_ptr + task_id * stride * in_batch_size, dy_ptr + task_id * stride * out_batch_size,
|
||||||
|
output_ptr + task_id * stride * in_batch_size, count, pool_param);
|
||||||
} else {
|
} else {
|
||||||
input_ptr = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData());
|
input_ptr = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData());
|
||||||
AvgPoolingGrad(input_ptr, output_ptr, pool_param, task_id);
|
AvgPoolingGrad(input_ptr + task_id * stride * out_batch_size, output_ptr + task_id * stride * in_batch_size, count,
|
||||||
|
pool_param);
|
||||||
}
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
@ -89,7 +96,8 @@ int PoolingGradImpl(void *cdata, int task_id) {
|
||||||
}
|
}
|
||||||
|
|
||||||
int PoolingGradCPUKernel::Run() {
|
int PoolingGradCPUKernel::Run() {
|
||||||
int error_code = ParallelLaunch(this->context_->thread_pool_, PoolingGradImpl, this, 1);
|
thread_num_ = context_->thread_num_;
|
||||||
|
int error_code = ParallelLaunch(this->context_->thread_pool_, PoolingGradImpl, this, thread_num_);
|
||||||
if (error_code != RET_OK) {
|
if (error_code != RET_OK) {
|
||||||
MS_LOG(ERROR) << "pooling error error_code[" << error_code << "]";
|
MS_LOG(ERROR) << "pooling error error_code[" << error_code << "]";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
|
|
@ -40,6 +40,7 @@ class PoolingGradCPUKernel : public LiteKernel {
|
||||||
int Execute(int task_id);
|
int Execute(int task_id);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
int thread_num_ = 1;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace mindspore::kernel
|
} // namespace mindspore::kernel
|
||||||
|
|
|
@ -0,0 +1,150 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "src/runtime/kernel/arm/fp32_grad/strided_slice_grad.h"
|
||||||
|
#include <vector>
|
||||||
|
#include <algorithm>
|
||||||
|
#include "schema/model_generated.h"
|
||||||
|
#include "src/kernel_registry.h"
|
||||||
|
#include "nnacl/fp32_grad/strided_slice_grad.h"
|
||||||
|
#include "src/ops/populate/strided_slice_populate.h"
|
||||||
|
#include "include/errorcode.h"
|
||||||
|
#include "src/runtime/runtime_api.h"
|
||||||
|
|
||||||
|
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||||
|
using mindspore::lite::KernelRegistrar;
|
||||||
|
using mindspore::lite::RET_ERROR;
|
||||||
|
using mindspore::lite::RET_OK;
|
||||||
|
using mindspore::schema::PrimitiveType_StridedSliceGrad;
|
||||||
|
|
||||||
|
namespace mindspore::kernel {
|
||||||
|
|
||||||
|
int StridedSliceGradCPUKernel::Init() {
|
||||||
|
if (!InferShapeDone()) {
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
param_ = reinterpret_cast<StridedSliceParameter *>(op_parameter_);
|
||||||
|
auto input = in_tensors_.at(0);
|
||||||
|
MS_ASSERT(input);
|
||||||
|
switch (input->data_type()) {
|
||||||
|
case kNumberTypeFloat32:
|
||||||
|
param_->data_type = kDataTypeFloat;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
MS_LOG(ERROR) << "Not supported data type: " << input->data_type();
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
FillEmptyDims();
|
||||||
|
FillOutputDim();
|
||||||
|
return ReSize();
|
||||||
|
}
|
||||||
|
|
||||||
|
void StridedSliceGradCPUKernel::FillEmptyDims() {
|
||||||
|
int32_t begins[DIMENSION_7D];
|
||||||
|
int32_t ends[DIMENSION_7D];
|
||||||
|
int32_t strides[DIMENSION_7D];
|
||||||
|
int32_t input_shape[DIMENSION_7D];
|
||||||
|
int32_t i;
|
||||||
|
for (i = 0; i < param_->num_axes_; ++i) {
|
||||||
|
begins[i] = param_->begins_[i];
|
||||||
|
ends[i] = MSMIN(param_->ends_[i], param_->in_shape_[i]);
|
||||||
|
strides[i] = param_->strides_[i];
|
||||||
|
input_shape[i] = param_->in_shape_[i];
|
||||||
|
}
|
||||||
|
for (i = param_->num_axes_; i < param_->in_shape_length_; ++i) {
|
||||||
|
input_shape[i] = param_->in_shape_[i];
|
||||||
|
begins[i] = 0;
|
||||||
|
ends[i] = param_->in_shape_[i];
|
||||||
|
strides[i] = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t real_index = param_->in_shape_length_ - 1;
|
||||||
|
for (i = DIMENSION_7D - 1; i >= 0; --i) {
|
||||||
|
if (real_index >= 0) {
|
||||||
|
param_->begins_[i] = begins[real_index];
|
||||||
|
param_->ends_[i] = ends[real_index];
|
||||||
|
param_->strides_[i] = strides[real_index];
|
||||||
|
param_->in_shape_[i] = input_shape[real_index--];
|
||||||
|
} else {
|
||||||
|
param_->begins_[i] = 0;
|
||||||
|
param_->ends_[i] = 1;
|
||||||
|
param_->strides_[i] = 1;
|
||||||
|
param_->in_shape_[i] = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
param_->num_axes_ = DIMENSION_7D;
|
||||||
|
param_->in_shape_length_ = DIMENSION_7D;
|
||||||
|
|
||||||
|
for (i = 0; i < DIMENSION_7D; ++i) {
|
||||||
|
if (param_->begins_[i] < 0) {
|
||||||
|
param_->begins_[i] += param_->in_shape_[i];
|
||||||
|
}
|
||||||
|
if (param_->ends_[i] < 0) {
|
||||||
|
param_->ends_[i] += param_->in_shape_[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void StridedSliceGradCPUKernel::FillOutputDim() {
|
||||||
|
auto output = out_tensors_.at(0);
|
||||||
|
size_t out_size = output->shape().size();
|
||||||
|
for (size_t i = 0; i < DIMENSION_7D; i++) {
|
||||||
|
if (i < out_size) {
|
||||||
|
output_shape_.push_back(output->shape()[i]);
|
||||||
|
} else {
|
||||||
|
output_shape_.insert(output_shape_.begin(), 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int StridedSliceGradCPUKernel::ReSize() { return RET_OK; }
|
||||||
|
|
||||||
|
int StridedSliceGradImpl(void *cdata, int task_id) {
|
||||||
|
MS_ASSERT(cdata != nullptr);
|
||||||
|
auto slice = reinterpret_cast<StridedSliceGradCPUKernel *>(cdata);
|
||||||
|
auto error_code = slice->Execute(task_id);
|
||||||
|
if (error_code != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "StridedSliceGrad Run error task_id[" << task_id << "] error_code[" << error_code << "]";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int StridedSliceGradCPUKernel::Run() {
|
||||||
|
int error_code = ParallelLaunch(this->context_->thread_pool_, StridedSliceGradImpl, this, 1);
|
||||||
|
if (error_code != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "Strided slice error error_code[" << error_code << "]";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int StridedSliceGradCPUKernel::Execute(int task_id) {
|
||||||
|
auto input = in_tensors_.at(0);
|
||||||
|
auto output = out_tensors_.at(0);
|
||||||
|
MS_ASSERT(output);
|
||||||
|
int *po = output_shape_.data();
|
||||||
|
auto ret = DoStridedSliceGrad(reinterpret_cast<float *>(input->MutableData()),
|
||||||
|
reinterpret_cast<float *>(output->MutableData()), po, param_);
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "StridedSliceGrad error error_code[" << ret << "]";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_StridedSliceGrad, LiteKernelCreator<StridedSliceGradCPUKernel>)
|
||||||
|
} // namespace mindspore::kernel
|
|
@ -0,0 +1,50 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_STRIDED_SLICE_GRAD_H_
|
||||||
|
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_STRIDED_SLICE_GRAD_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include "nnacl/fp32_grad/strided_slice_grad.h"
|
||||||
|
#include "src/lite_kernel.h"
|
||||||
|
|
||||||
|
namespace mindspore::kernel {
|
||||||
|
class StridedSliceGradCPUKernel : public LiteKernel {
|
||||||
|
public:
|
||||||
|
StridedSliceGradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||||
|
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||||
|
const mindspore::lite::PrimitiveC *primitive)
|
||||||
|
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
|
||||||
|
param_ = reinterpret_cast<StridedSliceParameter *>(parameter);
|
||||||
|
}
|
||||||
|
~StridedSliceGradCPUKernel() override = default;
|
||||||
|
|
||||||
|
int Init() override;
|
||||||
|
int ReSize() override;
|
||||||
|
int Run() override;
|
||||||
|
int Execute(int task_id);
|
||||||
|
|
||||||
|
private:
|
||||||
|
void FillEmptyDims();
|
||||||
|
void FillOutputDim();
|
||||||
|
void ParseMasks();
|
||||||
|
|
||||||
|
StridedSliceParameter *param_;
|
||||||
|
std::vector<int> output_shape_;
|
||||||
|
};
|
||||||
|
} // namespace mindspore::kernel
|
||||||
|
|
||||||
|
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_STRIDED_SLICE_GRAD_H_
|
|
@ -49,6 +49,7 @@
|
||||||
#include "src/ops/smooth_l1_loss_grad.h"
|
#include "src/ops/smooth_l1_loss_grad.h"
|
||||||
#include "nnacl/fp32_grad/smooth_l1_loss.h"
|
#include "nnacl/fp32_grad/smooth_l1_loss.h"
|
||||||
#include "src/ops/arithmetic_grad.h"
|
#include "src/ops/arithmetic_grad.h"
|
||||||
|
#include "src/ops/populate/strided_slice_populate.h"
|
||||||
namespace mindspore::kernel {
|
namespace mindspore::kernel {
|
||||||
|
|
||||||
OpParameter *DefaultPopulateParameter(const mindspore::lite::PrimitiveC *primitive) {
|
OpParameter *DefaultPopulateParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||||
|
@ -569,6 +570,9 @@ void PopulateTrainParameters() {
|
||||||
DefaultPopulateParameter);
|
DefaultPopulateParameter);
|
||||||
lite::Registry SigmoidCrossEntropyWithLogitsGradRegistry(schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad,
|
lite::Registry SigmoidCrossEntropyWithLogitsGradRegistry(schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad,
|
||||||
DefaultPopulateParameter);
|
DefaultPopulateParameter);
|
||||||
|
lite::Registry FlattenGradParameterRegistry(schema::PrimitiveType_FlattenGrad, DefaultPopulateParameter);
|
||||||
|
lite::Registry StridedSliceGradParameterRegistry(schema::PrimitiveType_StridedSliceGrad,
|
||||||
|
mindspore::lite::PopulateStridedSliceParameter);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mindspore::kernel
|
} // namespace mindspore::kernel
|
||||||
|
|
|
@ -0,0 +1,72 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_LITE_SRC_TRAIN_TRANSFER_SESSION_H_
|
||||||
|
#define MINDSPORE_LITE_SRC_TRAIN_TRANSFER_SESSION_H_
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <tuple>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include "src/ops/primitive_c.h"
|
||||||
|
#include "include/train_session.h"
|
||||||
|
#include "src/train/train_model.h"
|
||||||
|
#include "src/lite_session.h"
|
||||||
|
#include "src/train/train_session.h"
|
||||||
|
|
||||||
|
/*
|
||||||
|
Inheritance Diagram
|
||||||
|
|
||||||
|
+-------------------------------+
|
||||||
|
| session::LiteSession |
|
||||||
|
+--------+------------+---------+
|
||||||
|
/ \
|
||||||
|
+-----------------+-----+ +-------+------------+
|
||||||
|
| session::TrainSession | | lite::LiteSession |
|
||||||
|
+-----------------+-----+ +-------+------------+
|
||||||
|
\ /
|
||||||
|
+--------+------------+---------+
|
||||||
|
| lite::TrainSession |
|
||||||
|
+-------------------------------+
|
||||||
|
|
|
||||||
|
+--------+------------+---------+
|
||||||
|
| lite::TrasferSession |
|
||||||
|
+-------------------------------+
|
||||||
|
*/
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
|
||||||
|
class TransferSession : public lite::TrainSession {
|
||||||
|
public:
|
||||||
|
TransferSession();
|
||||||
|
explicit TransferSession(lite::LiteSession *backend_session);
|
||||||
|
~TransferSession();
|
||||||
|
|
||||||
|
int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override;
|
||||||
|
|
||||||
|
void BindThread(bool if_bind) override;
|
||||||
|
std::vector<tensor::MSTensor *> GetInputs() const override { return lite::LiteSession::GetInputs(); }
|
||||||
|
mindspore::tensor::MSTensor *GetInputsByTensorName(const std::string &tensor_name) const override {
|
||||||
|
return lite::LiteSession::GetInputsByTensorName(tensor_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
lite::LiteSession *backend_session_;
|
||||||
|
|
||||||
|
private:
|
||||||
|
};
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_LITE_SRC_TRAIN_TRANSFER_SESSION_H_
|
|
@ -1,13 +1,15 @@
|
||||||
mini_alexnet
|
mini_alexnet
|
||||||
# mobilenetv1
|
mobilenetv1
|
||||||
mobilenetv2
|
mobilenetv2
|
||||||
mobilenetv3
|
mobilenetv3
|
||||||
lenet
|
lenet
|
||||||
effnet
|
effnet
|
||||||
# effnet_tune
|
effnet_tune
|
||||||
# lenetv1
|
resnet
|
||||||
# resnet
|
googlenet
|
||||||
# googlenet
|
|
||||||
# densenet
|
# densenet
|
||||||
|
# shufflenetv2
|
||||||
|
# nin
|
||||||
# one_net
|
# one_net
|
||||||
|
# lenetv1
|
||||||
#LAST
|
#LAST
|
|
@ -0,0 +1,82 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Print start msg after run testcase
|
||||||
|
function MS_PRINT_TESTCASE_END_MSG() {
|
||||||
|
echo -e "-----------------------------------------------------------------------------------------------------------------------------------"
|
||||||
|
}
|
||||||
|
|
||||||
|
function Print_Result() {
|
||||||
|
MS_PRINT_TESTCASE_END_MSG
|
||||||
|
while read line; do
|
||||||
|
arr=("${line}")
|
||||||
|
printf "%-15s %-20s %-90s %-7s\n" ${arr[0]} ${arr[1]} ${arr[2]} ${arr[3]}
|
||||||
|
done < $1
|
||||||
|
MS_PRINT_TESTCASE_END_MSG
|
||||||
|
}
|
||||||
|
|
||||||
|
basepath=$(pwd)
|
||||||
|
echo ${basepath}
|
||||||
|
|
||||||
|
# Example:run_net_export.sh -m /home/emir/Work/TestingEnv/train_models
|
||||||
|
epoch_num=1
|
||||||
|
while getopts "m:t:" opt; do
|
||||||
|
case ${opt} in
|
||||||
|
m)
|
||||||
|
|
||||||
|
models_path=${OPTARG}"/models_train"
|
||||||
|
echo "models_path is ${OPTARG}"
|
||||||
|
;;
|
||||||
|
t)
|
||||||
|
epoch_num=${OPTARG}
|
||||||
|
echo "train epoch num is ${OPTARG}"
|
||||||
|
;;
|
||||||
|
?)
|
||||||
|
echo "unknown para"
|
||||||
|
exit 1;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
|
|
||||||
|
# Set models config filepath
|
||||||
|
models_mindspore_train_config=${basepath}/models_ms_train.cfg
|
||||||
|
|
||||||
|
logs_path=${basepath}/logs_train
|
||||||
|
rm -rf ${logs_path}
|
||||||
|
mkdir -p ${logs_path}
|
||||||
|
|
||||||
|
docker_image=mindspore/mindspore-gpu:1.1.0
|
||||||
|
# Export models
|
||||||
|
echo "Start Exporting models ..."
|
||||||
|
# Set log files
|
||||||
|
export_log_file=${logs_path}/export_log.txt
|
||||||
|
echo ' ' > ${export_log_file}
|
||||||
|
|
||||||
|
export_result_file=${logs_path}/export_result.txt
|
||||||
|
echo ' ' > ${export_result_file}
|
||||||
|
|
||||||
|
# Run export according to config file
|
||||||
|
cd $models_path || exit 1
|
||||||
|
if [[ -z "${CLOUD_MODEL_ZOO}" ]]; then
|
||||||
|
echo "CLOUD_MODEL_ZOO is not defined - exiting export models"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Export mindspore train models:
|
||||||
|
while read line; do
|
||||||
|
model_name=${line}
|
||||||
|
if [[ $model_name == \#* ]]; then
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
echo ${model_name}'_train_export.py' >> "${export_log_file}"
|
||||||
|
echo 'exporting' ${model_name}
|
||||||
|
echo 'docker run --user '"$(id -u):$(id -g)"' --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true '${docker_image}' python '${models_path}'/'${model_name}'_train_export.py' >> "${export_log_file}"
|
||||||
|
docker run --user "$(id -u):$(id -g)" --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true "${docker_image}" python ${models_path}'/'${model_name}_train_export.py "${epoch_num}"
|
||||||
|
if [ $? = 0 ]; then
|
||||||
|
export_result='export mindspore '${model_name}'_train_export pass';echo ${export_result} >> ${export_result_file}
|
||||||
|
else
|
||||||
|
export_result='export mindspore '${model_name}'_train_export failed';echo ${export_result} >> ${export_result_file}
|
||||||
|
fi
|
||||||
|
done < ${models_mindspore_train_config}
|
||||||
|
|
||||||
|
Print_Result ${export_result_file}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
# Run Export on x86 platform and create output test files:
|
# Run Export on x86 platform and create output test files:
|
||||||
docker_image=mindspore_dev:8
|
docker_image=
|
||||||
function Run_Export(){
|
function Run_Export(){
|
||||||
cd $models_path || exit 1
|
cd $models_path || exit 1
|
||||||
if [[ -z "${CLOUD_MODEL_ZOO}" ]]; then
|
if [[ -z "${CLOUD_MODEL_ZOO}" ]]; then
|
||||||
|
@ -16,8 +16,13 @@ function Run_Export(){
|
||||||
fi
|
fi
|
||||||
echo ${model_name}'_train_export.py' >> "${export_log_file}"
|
echo ${model_name}'_train_export.py' >> "${export_log_file}"
|
||||||
echo 'exporting' ${model_name}
|
echo 'exporting' ${model_name}
|
||||||
echo 'docker run --user '"$(id -u):$(id -g)"' --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true '${docker_image}' python '${models_path}'/'${model_name}'_train_export.py' >> "${export_log_file}"
|
if [ -n "$docker_image" ]; then
|
||||||
docker run --user "$(id -u):$(id -g)" --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true "${docker_image}" python ${models_path}'/'${model_name}_train_export.py "${epoch_num}"
|
echo 'docker run --user '"$(id -u):$(id -g)"' --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true '${docker_image}' python '${models_path}'/'${model_name}'_train_export.py' >> "${export_log_file}"
|
||||||
|
docker run --user "$(id -u):$(id -g)" --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true "${docker_image}" python ${models_path}'/'${model_name}_train_export.py "${epoch_num}"
|
||||||
|
else
|
||||||
|
echo 'CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} python '${models_path}'/'${model_name}'_train_export.py' >> "${export_log_file}"
|
||||||
|
CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} python ${models_path}'/'${model_name}_train_export.py "${epoch_num}"
|
||||||
|
fi
|
||||||
if [ $? = 0 ]; then
|
if [ $? = 0 ]; then
|
||||||
export_result='export mindspore '${model_name}'_train_export pass';echo ${export_result} >> ${export_result_file}
|
export_result='export mindspore '${model_name}'_train_export pass';echo ${export_result} >> ${export_result_file}
|
||||||
else
|
else
|
||||||
|
@ -28,7 +33,7 @@ function Run_Export(){
|
||||||
|
|
||||||
# Run converter on x86 platform:
|
# Run converter on x86 platform:
|
||||||
function Run_Converter() {
|
function Run_Converter() {
|
||||||
# Unzip x86 runtime and convertor
|
# Unzip x86 runtime and converter
|
||||||
cd ${x86_path} || exit 1
|
cd ${x86_path} || exit 1
|
||||||
tar -zxf mindspore-lite-${version}-train-linux-x64.tar.gz || exit 1
|
tar -zxf mindspore-lite-${version}-train-linux-x64.tar.gz || exit 1
|
||||||
|
|
||||||
|
@ -189,7 +194,7 @@ ENDM
|
||||||
if [ $? = 0 ]; then
|
if [ $? = 0 ]; then
|
||||||
run_result=$1': '${model_name}'_train pass'; echo ${run_result} >> ${run_benchmark_train_result_file}
|
run_result=$1': '${model_name}'_train pass'; echo ${run_result} >> ${run_benchmark_train_result_file}
|
||||||
else
|
else
|
||||||
run_result=$1': '${model_name}'_train failed'; echo ${run_result} >> ${run_benchmark_train_result_file}; return 1
|
run_result=$1': '${model_name}'_train failed'; echo ${run_result} >> ${run_benchmark_train_result_file};
|
||||||
fi
|
fi
|
||||||
done < ${models_mindspore_train_config}
|
done < ${models_mindspore_train_config}
|
||||||
}
|
}
|
||||||
|
@ -222,16 +227,15 @@ echo ${basepath}
|
||||||
# Example:run_benchmark_train.sh -r /home/emir/Work/TestingEnv/release -m /home/emir/Work/TestingEnv/train_models -i /home/emir/Work/TestingEnv/train_io -d "8KE5T19620002408"
|
# Example:run_benchmark_train.sh -r /home/emir/Work/TestingEnv/release -m /home/emir/Work/TestingEnv/train_models -i /home/emir/Work/TestingEnv/train_io -d "8KE5T19620002408"
|
||||||
# For running on arm64, use -t to set platform tools path (for using adb commands)
|
# For running on arm64, use -t to set platform tools path (for using adb commands)
|
||||||
epoch_num=1
|
epoch_num=1
|
||||||
threads=1
|
threads=2
|
||||||
train_io_path=""
|
train_io_path=""
|
||||||
while getopts "r:m:d:i:e:vt:q:" opt; do
|
while getopts "r:m:d:i:e:vt:q:D" opt; do
|
||||||
case ${opt} in
|
case ${opt} in
|
||||||
r)
|
r)
|
||||||
release_path=${OPTARG}
|
release_path=${OPTARG}
|
||||||
echo "release_path is ${OPTARG}"
|
echo "release_path is ${OPTARG}"
|
||||||
;;
|
;;
|
||||||
m)
|
m)
|
||||||
|
|
||||||
models_path=${OPTARG}"/models_train"
|
models_path=${OPTARG}"/models_train"
|
||||||
echo "models_path is ${OPTARG}"
|
echo "models_path is ${OPTARG}"
|
||||||
;;
|
;;
|
||||||
|
@ -244,8 +248,9 @@ while getopts "r:m:d:i:e:vt:q:" opt; do
|
||||||
echo "device_id is ${OPTARG}"
|
echo "device_id is ${OPTARG}"
|
||||||
;;
|
;;
|
||||||
e)
|
e)
|
||||||
enable_export=${OPTARG}
|
enable_export=1
|
||||||
echo "enable_export = ${OPTARG}"
|
docker_image=${OPTARG}
|
||||||
|
echo "enable_export = 1, docker_image = ${OPTARG}"
|
||||||
;;
|
;;
|
||||||
v)
|
v)
|
||||||
run_valgrind="valgrind --log-file=valgrind.log "
|
run_valgrind="valgrind --log-file=valgrind.log "
|
||||||
|
@ -404,27 +409,27 @@ function Print_Benchmark_Result() {
|
||||||
done < ${run_benchmark_train_result_file}
|
done < ${run_benchmark_train_result_file}
|
||||||
MS_PRINT_TESTCASE_END_MSG
|
MS_PRINT_TESTCASE_END_MSG
|
||||||
}
|
}
|
||||||
|
result=0
|
||||||
# Check benchmark_train result and return value
|
# Check benchmark_train result and return value
|
||||||
if [[ ${Run_x86_status} != 0 ]];then
|
if [[ ${Run_x86_status} != 0 ]];then
|
||||||
echo "Run_x86 failed"
|
echo "Run_x86 failed"
|
||||||
cat ${run_x86_log_file}
|
cat ${run_x86_log_file}
|
||||||
exit 1
|
result=1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [[ ${Run_arm64_status} != 0 ]];then
|
if [[ ${Run_arm64_status} != 0 ]];then
|
||||||
echo "Run_arm64 failed"
|
echo "Run_arm64 failed"
|
||||||
cat ${run_arm64_log_file}
|
cat ${run_arm64_log_file}
|
||||||
exit 1
|
result=1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [[ ${Run_arm32_status} != 0 ]];then
|
if [[ ${Run_arm32_status} != 0 ]];then
|
||||||
echo "Run_arm32 failed"
|
echo "Run_arm32 failed"
|
||||||
cat ${run_arm32_log_file}
|
cat ${run_arm32_log_file}
|
||||||
exit 1
|
result=1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
echo "Test ended - Results:"
|
echo "Test ended - Results:"
|
||||||
Print_Benchmark_Result
|
Print_Benchmark_Result
|
||||||
echo "Test run Time:" $DIFF
|
echo "Test run Time:" $DIFF
|
||||||
exit 0
|
exit ${result}
|
||||||
|
|
|
@ -79,15 +79,18 @@ TEST_F(TestPoolingGradFp32, AvgPoolingGradFp32) {
|
||||||
|
|
||||||
auto output_data = new float[output_data_size];
|
auto output_data = new float[output_data_size];
|
||||||
ASSERT_NE(output_data, nullptr);
|
ASSERT_NE(output_data, nullptr);
|
||||||
|
|
||||||
// warm up loop
|
// warm up loop
|
||||||
for (int i = 0; i < 3; i++) {
|
for (int i = 0; i < 3; i++) {
|
||||||
AvgPoolingGrad(input_data, output_data, pooling_param, 1);
|
std::fill(output_data, output_data + output_data_size, 0.f);
|
||||||
|
AvgPoolingGrad(input_data, output_data, pooling_param->output_batch_, pooling_param);
|
||||||
}
|
}
|
||||||
|
|
||||||
int loop_count = 100;
|
int loop_count = 100;
|
||||||
auto time_start = mindspore::lite::GetTimeUs();
|
auto time_start = mindspore::lite::GetTimeUs();
|
||||||
for (int i = 0; i < loop_count; i++) {
|
for (int i = 0; i < loop_count; i++) {
|
||||||
AvgPoolingGrad(input_data, output_data, pooling_param, 1);
|
std::fill(output_data, output_data + output_data_size, 0.f);
|
||||||
|
AvgPoolingGrad(input_data, output_data, pooling_param->output_batch_, pooling_param);
|
||||||
}
|
}
|
||||||
auto time_end = mindspore::lite::GetTimeUs();
|
auto time_end = mindspore::lite::GetTimeUs();
|
||||||
auto cost = time_end - time_start;
|
auto cost = time_end - time_start;
|
||||||
|
@ -407,18 +410,21 @@ TEST_F(TestPoolingGradFp32, MaxPoolingGradFp32) {
|
||||||
std::string dx_path = "./test_data/pooling/maxpoolgradfp32_1_dx_1_28_28_3.bin";
|
std::string dx_path = "./test_data/pooling/maxpoolgradfp32_1_dx_1_28_28_3.bin";
|
||||||
auto dx_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(dx_path.c_str(), &input_size));
|
auto dx_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(dx_path.c_str(), &input_size));
|
||||||
ASSERT_NE(dx_data, nullptr);
|
ASSERT_NE(dx_data, nullptr);
|
||||||
|
int in_batch_size =
|
||||||
|
pooling_param->input_h_ * pooling_param->input_w_ * pooling_param->input_channel_ * pooling_param->input_batch_;
|
||||||
auto output_data = new float[output_data_size];
|
auto output_data = new float[output_data_size];
|
||||||
ASSERT_NE(output_data, nullptr);
|
ASSERT_NE(output_data, nullptr);
|
||||||
// warm up loop
|
// warm up loop
|
||||||
for (int i = 0; i < 3; i++) {
|
for (int i = 0; i < 3; i++) {
|
||||||
MaxPoolingGrad(in_data, dx_data, dy_data, output_data, pooling_param, 1);
|
std::fill(output_data, output_data + in_batch_size, 0.f);
|
||||||
|
MaxPoolingGrad(in_data, dy_data, output_data, pooling_param->output_batch_, pooling_param);
|
||||||
}
|
}
|
||||||
|
|
||||||
int loop_count = 100;
|
int loop_count = 100;
|
||||||
auto time_start = mindspore::lite::GetTimeUs();
|
auto time_start = mindspore::lite::GetTimeUs();
|
||||||
for (int i = 0; i < loop_count; i++) {
|
for (int i = 0; i < loop_count; i++) {
|
||||||
MaxPoolingGrad(in_data, dx_data, dy_data, output_data, pooling_param, 1);
|
std::fill(output_data, output_data + in_batch_size, 0.f);
|
||||||
|
MaxPoolingGrad(in_data, dy_data, output_data, pooling_param->output_batch_, pooling_param);
|
||||||
}
|
}
|
||||||
auto time_end = mindspore::lite::GetTimeUs();
|
auto time_end = mindspore::lite::GetTimeUs();
|
||||||
auto cost = time_end - time_start;
|
auto cost = time_end - time_start;
|
||||||
|
|
|
@ -135,7 +135,6 @@ int NetTrain::ReadCalibData() {
|
||||||
|
|
||||||
MS_LOG(INFO) << "Start reading calibData file";
|
MS_LOG(INFO) << "Start reading calibData file";
|
||||||
std::string tensor_name;
|
std::string tensor_name;
|
||||||
|
|
||||||
while (!in_file.eof()) {
|
while (!in_file.eof()) {
|
||||||
getline(in_file, line);
|
getline(in_file, line);
|
||||||
std::stringstream string_line1(line);
|
std::stringstream string_line1(line);
|
||||||
|
|
|
@ -79,7 +79,7 @@ class MS_API NetTrainFlags : public virtual FlagParser {
|
||||||
std::vector<std::string> input_data_list_;
|
std::vector<std::string> input_data_list_;
|
||||||
DataType in_data_type_;
|
DataType in_data_type_;
|
||||||
std::string in_data_type_in_ = "bin";
|
std::string in_data_type_in_ = "bin";
|
||||||
int cpu_bind_mode_ = 0;
|
int cpu_bind_mode_ = 1;
|
||||||
// MarkPerformance
|
// MarkPerformance
|
||||||
int num_threads_ = 1;
|
int num_threads_ = 1;
|
||||||
int warm_up_loop_count_ = 0;
|
int warm_up_loop_count_ = 0;
|
||||||
|
|
|
@ -32,7 +32,6 @@ static const std::vector<schema::PrimitiveType> nhwcOpList = {
|
||||||
schema::PrimitiveType_PoolingGrad,
|
schema::PrimitiveType_PoolingGrad,
|
||||||
schema::PrimitiveType_BiasGrad,
|
schema::PrimitiveType_BiasGrad,
|
||||||
schema::PrimitiveType_BNGrad,
|
schema::PrimitiveType_BNGrad,
|
||||||
schema::PrimitiveType_ActivationGrad,
|
|
||||||
schema::PrimitiveType_ApplyMomentum,
|
schema::PrimitiveType_ApplyMomentum,
|
||||||
schema::PrimitiveType_Sgd,
|
schema::PrimitiveType_Sgd,
|
||||||
schema::PrimitiveType_Adam,
|
schema::PrimitiveType_Adam,
|
||||||
|
@ -219,6 +218,26 @@ STATUS NodeUtils::ConvertDims(mindspore::schema::Format src_format, const std::v
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool IsKCHWSource(kTransFilterType type) {
|
||||||
|
return (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC || type == kKCHW2CKHW);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool IsCKHWSource(kTransFilterType type) {
|
||||||
|
return (type == kCKHW2HWCK || type == kCKHW2HWKC || type == kCKHW2KHWC);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool IsHWCKSource(kTransFilterType type) { return (type == kHWCK2KCHW || type == kHWCK2CKHW); }
|
||||||
|
|
||||||
|
static bool IsHWKCSource(kTransFilterType type) { return (type == kHWKC2KCHW || type == kHWKC2CKHW); }
|
||||||
|
|
||||||
|
static bool IsNHWCSource(kTransFilterType type) {
|
||||||
|
return (type == kNHWC2KCHW || type == kNHWC2HWCK || type == kNHWC2CKHW);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool IsCHWKSource(kTransFilterType type) { return (type == kCHWK2HWCK || type == kCHWK2KHWC); }
|
||||||
|
|
||||||
|
static bool IsKHWCSource(kTransFilterType type) { return (type == kKHWC2HWCK || type == kKHWC2CHWK); }
|
||||||
|
|
||||||
STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC,
|
STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC,
|
||||||
int32_t *filterH, int32_t *filterW) {
|
int32_t *filterH, int32_t *filterW) {
|
||||||
if (filterK == nullptr || filterC == nullptr || filterH == nullptr || filterW == nullptr) {
|
if (filterK == nullptr || filterC == nullptr || filterH == nullptr || filterW == nullptr) {
|
||||||
|
@ -226,37 +245,37 @@ STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type,
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
MS_ASSERT(oriDims.size() == 4);
|
MS_ASSERT(oriDims.size() == 4);
|
||||||
if (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC || type == kKCHW2CKHW) {
|
if (IsKCHWSource(type)) {
|
||||||
*filterK = oriDims.at(KCHW_K);
|
*filterK = oriDims.at(KCHW_K);
|
||||||
*filterC = oriDims.at(KCHW_C);
|
*filterC = oriDims.at(KCHW_C);
|
||||||
*filterH = oriDims.at(KCHW_H);
|
*filterH = oriDims.at(KCHW_H);
|
||||||
*filterW = oriDims.at(KCHW_W);
|
*filterW = oriDims.at(KCHW_W);
|
||||||
} else if (type == kCKHW2HWCK || type == kCKHW2HWKC || type == kCKHW2KHWC) {
|
} else if (IsCKHWSource(type)) {
|
||||||
*filterC = oriDims.at(CKHW_C);
|
*filterC = oriDims.at(CKHW_C);
|
||||||
*filterK = oriDims.at(CKHW_K);
|
*filterK = oriDims.at(CKHW_K);
|
||||||
*filterH = oriDims.at(CKHW_H);
|
*filterH = oriDims.at(CKHW_H);
|
||||||
*filterW = oriDims.at(CKHW_W);
|
*filterW = oriDims.at(CKHW_W);
|
||||||
} else if (type == kHWCK2KCHW || type == kHWCK2CKHW) {
|
} else if (IsHWCKSource(type)) {
|
||||||
*filterH = oriDims.at(HWCK_H);
|
*filterH = oriDims.at(HWCK_H);
|
||||||
*filterW = oriDims.at(HWCK_W);
|
*filterW = oriDims.at(HWCK_W);
|
||||||
*filterC = oriDims.at(HWCK_C);
|
*filterC = oriDims.at(HWCK_C);
|
||||||
*filterK = oriDims.at(HWCK_K);
|
*filterK = oriDims.at(HWCK_K);
|
||||||
} else if (type == kHWKC2KCHW || type == kHWKC2CKHW) {
|
} else if (IsHWKCSource(type)) {
|
||||||
*filterH = oriDims.at(HWKC_H);
|
*filterH = oriDims.at(HWKC_H);
|
||||||
*filterW = oriDims.at(HWKC_W);
|
*filterW = oriDims.at(HWKC_W);
|
||||||
*filterK = oriDims.at(HWKC_K);
|
*filterK = oriDims.at(HWKC_K);
|
||||||
*filterC = oriDims.at(HWKC_C);
|
*filterC = oriDims.at(HWKC_C);
|
||||||
} else if (type == kNHWC2KCHW || type == kNHWC2HWCK || type == kNHWC2CKHW) {
|
} else if (IsNHWCSource(type)) {
|
||||||
*filterK = oriDims.at(NHWC_N);
|
*filterK = oriDims.at(NHWC_N);
|
||||||
*filterH = oriDims.at(NHWC_H);
|
*filterH = oriDims.at(NHWC_H);
|
||||||
*filterW = oriDims.at(NHWC_W);
|
*filterW = oriDims.at(NHWC_W);
|
||||||
*filterC = oriDims.at(NHWC_C);
|
*filterC = oriDims.at(NHWC_C);
|
||||||
} else if (type == kCHWK2HWCK || type == kCHWK2KHWC) {
|
} else if (IsCHWKSource(type)) {
|
||||||
*filterC = oriDims.at(CHWK_C);
|
*filterC = oriDims.at(CHWK_C);
|
||||||
*filterH = oriDims.at(CHWK_H);
|
*filterH = oriDims.at(CHWK_H);
|
||||||
*filterW = oriDims.at(CHWK_W);
|
*filterW = oriDims.at(CHWK_W);
|
||||||
*filterK = oriDims.at(CHWK_K);
|
*filterK = oriDims.at(CHWK_K);
|
||||||
} else if (type == kKHWC2HWCK || type == kKHWC2CHWK) {
|
} else if (IsKHWCSource(type)) {
|
||||||
*filterK = oriDims.at(KHWC_K);
|
*filterK = oriDims.at(KHWC_K);
|
||||||
*filterH = oriDims.at(KHWC_H);
|
*filterH = oriDims.at(KHWC_H);
|
||||||
*filterW = oriDims.at(KHWC_W);
|
*filterW = oriDims.at(KHWC_W);
|
||||||
|
@ -290,6 +309,37 @@ STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filt
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static int Convert2KHWC(int srcFormat) {
|
||||||
|
if (srcFormat == schema::Format::Format_KCHW) return kKCHW2KHWC;
|
||||||
|
if (srcFormat == schema::Format::Format_CKHW) return kCKHW2KHWC;
|
||||||
|
if (srcFormat == schema::Format::Format_CHWK) return kCHWK2KHWC;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int Convert2HWCK(int srcFormat) {
|
||||||
|
if (srcFormat == schema::Format::Format_KCHW) return kKCHW2HWCK;
|
||||||
|
if (srcFormat == schema::Format::Format_KHWC) return kKHWC2HWCK;
|
||||||
|
if (srcFormat == schema::Format::Format_CKHW) return kCKHW2HWCK;
|
||||||
|
if (srcFormat == schema::Format::Format_CHWK) return kCHWK2HWCK;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int Convert2KCHW(int srcFormat) {
|
||||||
|
if (srcFormat == schema::Format::Format_HWCK) return kHWCK2KCHW;
|
||||||
|
if (srcFormat == schema::Format::Format_HWKC) return kHWKC2KCHW;
|
||||||
|
if (srcFormat == schema::Format::Format_KHWC) return kKHWC2KCHW;
|
||||||
|
if (srcFormat == schema::Format::Format_CKHW) return kCKHW2KCHW;
|
||||||
|
if (srcFormat == schema::Format::Format_CHWK) return kCHWK2KCHW;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int Convert2CKHW(int srcFormat) {
|
||||||
|
if (srcFormat == schema::Format::Format_HWCK) return kHWCK2CKHW;
|
||||||
|
if (srcFormat == schema::Format::Format_HWKC) return kHWKC2CKHW;
|
||||||
|
if (srcFormat == schema::Format::Format_KCHW) return kKCHW2CKHW;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) {
|
STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) {
|
||||||
if (tensor == nullptr) {
|
if (tensor == nullptr) {
|
||||||
MS_LOG(ERROR) << "tensor is null";
|
MS_LOG(ERROR) << "tensor is null";
|
||||||
|
@ -303,231 +353,40 @@ STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) {
|
||||||
auto srcFormat = tensor->format;
|
auto srcFormat = tensor->format;
|
||||||
auto dataType = tensor->dataType;
|
auto dataType = tensor->dataType;
|
||||||
STATUS status;
|
STATUS status;
|
||||||
|
int convert = -1;
|
||||||
|
|
||||||
|
if (dstFormat == srcFormat) return RET_OK;
|
||||||
|
|
||||||
switch (dstFormat) {
|
switch (dstFormat) {
|
||||||
case schema::Format::Format_KHWC: {
|
case schema::Format::Format_KHWC:
|
||||||
switch (srcFormat) {
|
convert = Convert2KHWC(srcFormat);
|
||||||
case schema::Format::Format_KCHW:
|
break;
|
||||||
if (dataType == kNumberTypeFloat32) {
|
case schema::Format::Format_HWCK:
|
||||||
status = TransFilterFormat<float>(tensor, kKCHW2KHWC);
|
convert = Convert2HWCK(srcFormat);
|
||||||
} else if (dataType == kNumberTypeUInt8) {
|
break;
|
||||||
status = TransFilterFormat<uint8_t>(tensor, kKCHW2KHWC);
|
case schema::Format::Format_KCHW:
|
||||||
} else if (dataType == kNumberTypeInt8) {
|
convert = Convert2KCHW(srcFormat);
|
||||||
status = TransFilterFormat<int8_t>(tensor, kKCHW2KHWC);
|
break;
|
||||||
} else {
|
case schema::Format::Format_CKHW:
|
||||||
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
|
convert = Convert2CKHW(srcFormat);
|
||||||
return RET_ERROR;
|
break;
|
||||||
}
|
|
||||||
break;
|
|
||||||
case schema::Format::Format_CKHW:
|
|
||||||
if (dataType == kNumberTypeFloat32) {
|
|
||||||
status = TransFilterFormat<float>(tensor, kCKHW2KHWC);
|
|
||||||
} else if (dataType == kNumberTypeUInt8) {
|
|
||||||
status = TransFilterFormat<uint8_t>(tensor, kCKHW2KHWC);
|
|
||||||
} else if (dataType == kNumberTypeInt8) {
|
|
||||||
status = TransFilterFormat<int8_t>(tensor, kCKHW2KHWC);
|
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case schema::Format::Format_CHWK:
|
|
||||||
if (dataType == kNumberTypeFloat32) {
|
|
||||||
status = TransFilterFormat<float>(tensor, kCHWK2KHWC);
|
|
||||||
} else if (dataType == kNumberTypeUInt8) {
|
|
||||||
status = TransFilterFormat<uint8_t>(tensor, kCHWK2KHWC);
|
|
||||||
} else if (dataType == kNumberTypeInt8) {
|
|
||||||
status = TransFilterFormat<int8_t>(tensor, kCHWK2KHWC);
|
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case schema::Format::Format_KHWC:
|
|
||||||
return RET_OK;
|
|
||||||
default:
|
|
||||||
MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to "
|
|
||||||
<< EnumNameFormat(dstFormat);
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
} break;
|
|
||||||
case schema::Format::Format_HWCK: {
|
|
||||||
switch (srcFormat) {
|
|
||||||
case schema::Format::Format_KCHW:
|
|
||||||
if (dataType == kNumberTypeFloat32) {
|
|
||||||
status = TransFilterFormat<float>(tensor, kKCHW2HWCK);
|
|
||||||
} else if (dataType == kNumberTypeUInt8) {
|
|
||||||
status = TransFilterFormat<uint8_t>(tensor, kKCHW2HWCK);
|
|
||||||
} else if (dataType == kNumberTypeInt8) {
|
|
||||||
status = TransFilterFormat<int8_t>(tensor, kKCHW2HWCK);
|
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case schema::Format::Format_KHWC:
|
|
||||||
if (dataType == kNumberTypeFloat32) {
|
|
||||||
status = TransFilterFormat<float>(tensor, kKHWC2HWCK);
|
|
||||||
} else if (dataType == kNumberTypeUInt8) {
|
|
||||||
status = TransFilterFormat<uint8_t>(tensor, kKHWC2HWCK);
|
|
||||||
} else if (dataType == kNumberTypeInt8) {
|
|
||||||
status = TransFilterFormat<int8_t>(tensor, kKHWC2HWCK);
|
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case schema::Format::Format_CKHW:
|
|
||||||
if (dataType == kNumberTypeFloat32) {
|
|
||||||
status = TransFilterFormat<float>(tensor, kCKHW2HWCK);
|
|
||||||
} else if (dataType == kNumberTypeUInt8) {
|
|
||||||
status = TransFilterFormat<uint8_t>(tensor, kCKHW2HWCK);
|
|
||||||
} else if (dataType == kNumberTypeInt8) {
|
|
||||||
status = TransFilterFormat<int8_t>(tensor, kCKHW2HWCK);
|
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case schema::Format::Format_CHWK:
|
|
||||||
if (dataType == kNumberTypeFloat32) {
|
|
||||||
status = TransFilterFormat<float>(tensor, kCHWK2HWCK);
|
|
||||||
} else if (dataType == kNumberTypeUInt8) {
|
|
||||||
status = TransFilterFormat<uint8_t>(tensor, kCHWK2HWCK);
|
|
||||||
} else if (dataType == kNumberTypeInt8) {
|
|
||||||
status = TransFilterFormat<int8_t>(tensor, kCHWK2HWCK);
|
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case schema::Format::Format_HWCK:
|
|
||||||
return RET_OK;
|
|
||||||
default:
|
|
||||||
MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to "
|
|
||||||
<< EnumNameFormat(dstFormat);
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
} break;
|
|
||||||
case schema::Format::Format_KCHW: {
|
|
||||||
switch (srcFormat) {
|
|
||||||
case schema::Format::Format_KCHW:
|
|
||||||
return RET_OK;
|
|
||||||
case schema::Format::Format_HWCK:
|
|
||||||
if (dataType == kNumberTypeFloat32) {
|
|
||||||
status = TransFilterFormat<float>(tensor, kHWCK2KCHW);
|
|
||||||
} else if (dataType == kNumberTypeUInt8) {
|
|
||||||
status = TransFilterFormat<uint8_t>(tensor, kHWCK2KCHW);
|
|
||||||
} else if (dataType == kNumberTypeInt8) {
|
|
||||||
status = TransFilterFormat<int8_t>(tensor, kHWCK2KCHW);
|
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case schema::Format::Format_HWKC:
|
|
||||||
if (dataType == kNumberTypeFloat32) {
|
|
||||||
status = TransFilterFormat<float>(tensor, kHWKC2KCHW);
|
|
||||||
} else if (dataType == kNumberTypeUInt8) {
|
|
||||||
status = TransFilterFormat<uint8_t>(tensor, kHWKC2KCHW);
|
|
||||||
} else if (dataType == kNumberTypeInt8) {
|
|
||||||
status = TransFilterFormat<int8_t>(tensor, kHWKC2KCHW);
|
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case schema::Format::Format_KHWC:
|
|
||||||
if (dataType == kNumberTypeFloat32) {
|
|
||||||
status = TransFilterFormat<float>(tensor, kKHWC2KCHW);
|
|
||||||
} else if (dataType == kNumberTypeUInt8) {
|
|
||||||
status = TransFilterFormat<uint8_t>(tensor, kKHWC2KCHW);
|
|
||||||
} else if (dataType == kNumberTypeInt8) {
|
|
||||||
status = TransFilterFormat<int8_t>(tensor, kKHWC2KCHW);
|
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case schema::Format::Format_CKHW:
|
|
||||||
if (dataType == kNumberTypeFloat32) {
|
|
||||||
status = TransFilterFormat<float>(tensor, kCKHW2KCHW);
|
|
||||||
} else if (dataType == kNumberTypeUInt8) {
|
|
||||||
status = TransFilterFormat<uint8_t>(tensor, kCKHW2KCHW);
|
|
||||||
} else if (dataType == kNumberTypeInt8) {
|
|
||||||
status = TransFilterFormat<int8_t>(tensor, kCKHW2KCHW);
|
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case schema::Format::Format_CHWK:
|
|
||||||
if (dataType == kNumberTypeFloat32) {
|
|
||||||
status = TransFilterFormat<float>(tensor, kCHWK2KCHW);
|
|
||||||
} else if (dataType == kNumberTypeUInt8) {
|
|
||||||
status = TransFilterFormat<uint8_t>(tensor, kCHWK2KCHW);
|
|
||||||
} else if (dataType == kNumberTypeInt8) {
|
|
||||||
status = TransFilterFormat<int8_t>(tensor, kCHWK2KCHW);
|
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to "
|
|
||||||
<< EnumNameFormat(dstFormat);
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
} break;
|
|
||||||
case schema::Format::Format_CKHW: {
|
|
||||||
switch (srcFormat) {
|
|
||||||
case schema::Format::Format_HWCK:
|
|
||||||
if (dataType == kNumberTypeFloat32) {
|
|
||||||
status = TransFilterFormat<float>(tensor, kHWCK2CKHW);
|
|
||||||
} else if (dataType == kNumberTypeUInt8) {
|
|
||||||
status = TransFilterFormat<uint8_t>(tensor, kHWCK2CKHW);
|
|
||||||
} else if (dataType == kNumberTypeInt8) {
|
|
||||||
status = TransFilterFormat<int8_t>(tensor, kHWCK2CKHW);
|
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case schema::Format::Format_HWKC:
|
|
||||||
if (dataType == kNumberTypeFloat32) {
|
|
||||||
status = TransFilterFormat<float>(tensor, kHWKC2CKHW);
|
|
||||||
} else if (dataType == kNumberTypeUInt8) {
|
|
||||||
status = TransFilterFormat<uint8_t>(tensor, kHWKC2CKHW);
|
|
||||||
} else if (dataType == kNumberTypeInt8) {
|
|
||||||
status = TransFilterFormat<int8_t>(tensor, kHWKC2CKHW);
|
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case schema::Format::Format_KCHW:
|
|
||||||
if (dataType == kNumberTypeFloat32) {
|
|
||||||
status = TransFilterFormat<float>(tensor, kKCHW2CKHW);
|
|
||||||
} else if (dataType == kNumberTypeUInt8) {
|
|
||||||
status = TransFilterFormat<uint8_t>(tensor, kKCHW2CKHW);
|
|
||||||
} else if (dataType == kNumberTypeInt8) {
|
|
||||||
status = TransFilterFormat<int8_t>(tensor, kKCHW2CKHW);
|
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case schema::Format::Format_CKHW:
|
|
||||||
return RET_OK;
|
|
||||||
default:
|
|
||||||
MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to "
|
|
||||||
<< EnumNameFormat(dstFormat);
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
} break;
|
|
||||||
default:
|
default:
|
||||||
MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to "
|
convert = -1;
|
||||||
<< EnumNameFormat(dstFormat);
|
}
|
||||||
return RET_ERROR;
|
if (convert == -1) {
|
||||||
|
MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to " << EnumNameFormat(dstFormat);
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dataType == kNumberTypeFloat32) {
|
||||||
|
status = TransFilterFormat<float>(tensor, static_cast<kTransFilterType>(convert));
|
||||||
|
} else if (dataType == kNumberTypeUInt8) {
|
||||||
|
status = TransFilterFormat<uint8_t>(tensor, static_cast<kTransFilterType>(convert));
|
||||||
|
} else if (dataType == kNumberTypeInt8) {
|
||||||
|
status = TransFilterFormat<int8_t>(tensor, static_cast<kTransFilterType>(convert));
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
|
||||||
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "TransFilterData failed: " << status;
|
MS_LOG(ERROR) << "TransFilterData failed: " << status;
|
||||||
|
|
Loading…
Reference in New Issue