forked from mindspore-Ecosystem/mindspore
!20399 [MS][LITE][develop] lite support vs build
Merge pull request !20399 from sunsuodong/vs_build_b
This commit is contained in:
commit
2a6caf97cc
|
@ -13,10 +13,10 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/fp32/conv_depthwise_fp32.h"
|
||||
#include "nnacl/common_func.h"
|
||||
#include "nnacl/fp32/common_func_fp32.h"
|
||||
#include "nnacl/intrinsics/ms_simd_instructions.h"
|
||||
|
||||
#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE)
|
||||
void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, int num_pixels,
|
||||
|
@ -622,8 +622,8 @@ void ConvDw3x3Line(float *dst, float **lines, const float *weight, const float *
|
|||
MS_STQ_F32(cur_dst + ori_channel, res1);
|
||||
} else {
|
||||
for (int i = 0; i < channel; i++) {
|
||||
cur_dst[i] = res0[i];
|
||||
cur_dst[ori_channel + i] = res1[i];
|
||||
cur_dst[i] = MS_F32X4_GETI(res0, i);
|
||||
cur_dst[ori_channel + i] = MS_F32X4_GETI(res1, i);
|
||||
}
|
||||
}
|
||||
cur_dst += 2 * ori_channel;
|
||||
|
@ -653,7 +653,7 @@ void ConvDw3x3Line(float *dst, float **lines, const float *weight, const float *
|
|||
MS_STQ_F32(cur_dst, res0);
|
||||
} else {
|
||||
for (int i = 0; i < channel; i++) {
|
||||
cur_dst[i] = res0[i];
|
||||
cur_dst[i] = MS_F32X4_GETI(res0, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -91,7 +91,7 @@ void PowerSingle(const float *input, const float *exponent, float *output, int l
|
|||
MS_FLOAT32X4 tmp_4 = MS_ADDQ_F32(MS_MULQ_F32(scale_4, MS_LDQ_F32(input + i)), shift_4);
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
PowerScalarFun_ = CheckInteger(exponent[i + j]) ? OptimizedPowerScalar : StdPowerScalar;
|
||||
output[i + j] = PowerScalarFun_(tmp_4[j], exponent + i + j);
|
||||
output[i + j] = PowerScalarFun_(MS_F32X4_GETI(tmp_4, j), exponent + i + j);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <math.h>
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/power_parameter.h"
|
||||
#include "nnacl/intrinsics/ms_simd_instructions.h"
|
||||
|
||||
#if defined(ENABLE_ARM) || defined(ENABLE_AVX) || defined(ENABLE_SSE)
|
||||
typedef MS_FLOAT32X4 (*PowerSimdFun)(MS_FLOAT32X4 x, const void *exponent);
|
||||
|
@ -38,7 +39,7 @@ static inline float StdPowerScalar(float x, const void *exponent) { return powf(
|
|||
// Generic (non-optimized) SIMD power: raises each of the 4 float lanes of `x` to the
// single float exponent that `exponent` points to, calling powf once per lane.
//   x        - 4-lane float vector holding the bases.
//   exponent - type-erased pointer; it is dereferenced as a float (the const is cast away
//              only for the read, the pointee is not written here).
// Returns a 4-lane vector with powf(x[lane], *exponent) in each lane.
static inline MS_FLOAT32X4 StdPowerSimd(MS_FLOAT32X4 x, const void *exponent) {
|
||||
MS_FLOAT32X4 result;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
// NOTE(review): this diff view shows two variants of the same statement. The first indexes
// the vector directly with []; the second goes through MS_F32X4_GETI for both the read and
// the write. Presumably only the MS_F32X4_GETI form remains in the real file — confirm.
result[i] = powf(x[i], *(float *)exponent);
|
||||
MS_F32X4_GETI(result, i) = powf(MS_F32X4_GETI(x, i), *(float *)exponent);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
|
|
@ -14,8 +14,12 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#ifdef ENABLE_SSE
|
||||
#ifdef SUPPORT_MSVC
|
||||
#include <immintrin.h>
|
||||
#else
|
||||
#include <x86intrin.h>
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_AVX
|
||||
#include <immintrin.h>
|
||||
|
|
|
@ -13,8 +13,8 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/fp32/winograd_utils.h"
|
||||
#include "nnacl/intrinsics/ms_simd_instructions.h"
|
||||
#include "nnacl/base/minimal_filtering_generator.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
|
@ -486,7 +486,7 @@ void OutputTransform4x2Unit(const float *src_data, float *dst_data, const float
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 2;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -551,7 +551,7 @@ void OutputTransform4x2ReluUnit(const float *src_data, float *dst_data, const fl
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 2;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -621,7 +621,7 @@ void OutputTransform4x2Relu6Unit(const float *src_data, float *dst_data, const f
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 2;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -690,7 +690,7 @@ void OutputTransform4x3Unit(const float *src_data, float *dst_data, const float
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 3;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -762,7 +762,7 @@ void OutputTransform4x3ReluUnit(const float *src_data, float *dst_data, const fl
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 3;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -840,7 +840,7 @@ void OutputTransform4x3Relu6Unit(const float *src_data, float *dst_data, const f
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 3;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -917,7 +917,7 @@ void OutputTransform6x2Unit(const float *src_data, float *dst_data, const float
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 2;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -992,7 +992,7 @@ void OutputTransform6x2ReluUnit(const float *src_data, float *dst_data, const fl
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 2;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1072,7 +1072,7 @@ void OutputTransform6x2Relu6Unit(const float *src_data, float *dst_data, const f
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 2;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1146,7 +1146,7 @@ void OutputTransform6x3Unit(const float *src_data, float *dst_data, const float
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 3;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1223,7 +1223,7 @@ void OutputTransform6x3ReluUnit(const float *src_data, float *dst_data, const fl
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 3;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1306,7 +1306,7 @@ void OutputTransform6x3Relu6Unit(const float *src_data, float *dst_data, const f
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 3;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1385,7 +1385,7 @@ void OutputTransform6x4Unit(const float *src_data, float *dst_data, const float
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 4;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1468,7 +1468,7 @@ void OutputTransform6x4ReluUnit(const float *src_data, float *dst_data, const fl
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 4;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1558,7 +1558,7 @@ void OutputTransform6x4Relu6Unit(const float *src_data, float *dst_data, const f
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 4;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1641,7 +1641,7 @@ void OutputTransform6x5Unit(const float *src_data, float *dst_data, const float
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 5;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1729,7 +1729,7 @@ void OutputTransform6x5ReluUnit(const float *src_data, float *dst_data, const fl
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 5;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1825,7 +1825,7 @@ void OutputTransform6x5Relu6Unit(const float *src_data, float *dst_data, const f
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 5;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1911,7 +1911,7 @@ void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 2;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1994,7 +1994,7 @@ void OutputTransform8x2ReluUnit(const float *src_data, float *dst_data, const fl
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 2;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2082,7 +2082,7 @@ void OutputTransform8x2Relu6Unit(const float *src_data, float *dst_data, const f
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 2;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2168,7 +2168,7 @@ void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 3;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2259,7 +2259,7 @@ void OutputTransform8x3ReluUnit(const float *src_data, float *dst_data, const fl
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 3;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2356,7 +2356,7 @@ void OutputTransform8x3Relu6Unit(const float *src_data, float *dst_data, const f
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 3;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2450,7 +2450,7 @@ void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 4;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2550,7 +2550,7 @@ void OutputTransform8x4ReluUnit(const float *src_data, float *dst_data, const fl
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 4;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2657,7 +2657,7 @@ void OutputTransform8x4Relu6Unit(const float *src_data, float *dst_data, const f
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 4;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2761,7 +2761,7 @@ void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 5;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2869,7 +2869,7 @@ void OutputTransform8x5ReluUnit(const float *src_data, float *dst_data, const fl
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 5;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2988,7 +2988,7 @@ void OutputTransform8x5Relu6Unit(const float *src_data, float *dst_data, const f
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 5;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -3108,7 +3108,7 @@ void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 6;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -3236,7 +3236,7 @@ void OutputTransform8x6ReluUnit(const float *src_data, float *dst_data, const fl
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 6;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -3373,7 +3373,7 @@ void OutputTransform8x6Relu6Unit(const float *src_data, float *dst_data, const f
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 6;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -3501,7 +3501,7 @@ void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 7;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -3638,7 +3638,7 @@ void OutputTransform8x7ReluUnit(const float *src_data, float *dst_data, const fl
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 7;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -3785,7 +3785,7 @@ void OutputTransform8x7Relu6Unit(const float *src_data, float *dst_data, const f
|
|||
int dst_k_offset = j * dst_step * out_c;
|
||||
int m_k_offset = j * 7;
|
||||
for (int k = 0; k < r_w; k++) {
|
||||
dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
|
||||
dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,10 +22,14 @@ void var2Invar(float *save_var, int size, float eps) {
|
|||
save_var[i] = 1.0f / sqrt(save_var[i] + eps);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef SUPPORT_MSVC
|
||||
void backwardAll(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size,
|
||||
int ch, float *dxhat_sum, float *dxhathat_sum, float *dbias, float *dscale, float *dx) {
|
||||
#else
|
||||
void backwardAll(const float *restrict in, const float *restrict yt, const float *restrict mean,
|
||||
const float *restrict invar, const float *restrict scale, int size, int ch, float *restrict dxhat_sum,
|
||||
float *restrict dxhathat_sum, float *restrict dbias, float *restrict dscale, float *restrict dx) {
|
||||
#endif
|
||||
float N = (float)size;
|
||||
for (int i = 0; i < size; i++) {
|
||||
for (int c = 0; c < ch; c++) {
|
||||
|
@ -50,9 +54,15 @@ void backwardAll(const float *restrict in, const float *restrict yt, const float
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef SUPPORT_MSVC
|
||||
void backwardP1(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size,
|
||||
int ch, float *dxhat_sum, float *dxhathat_sum, float *dbias, float *dscale) {
|
||||
#else
|
||||
void backwardP1(const float *restrict in, const float *restrict yt, const float *restrict mean,
|
||||
const float *restrict invar, const float *restrict scale, int size, int ch, float *restrict dxhat_sum,
|
||||
float *restrict dxhathat_sum, float *restrict dbias, float *restrict dscale) {
|
||||
#endif
|
||||
for (int i = 0; i < size; i++) {
|
||||
for (int c = 0; c < ch; c++) {
|
||||
int ix = i * ch + c;
|
||||
|
@ -68,9 +78,14 @@ void backwardP1(const float *restrict in, const float *restrict yt, const float
|
|||
}
|
||||
}
|
||||
|
||||
#ifdef SUPPORT_MSVC
|
||||
void backwardP2(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size,
|
||||
int total_size, int ch, const float *dxhat_sum, const float *dxhathat_sum, float *dx) {
|
||||
#else
|
||||
void backwardP2(const float *restrict in, const float *restrict yt, const float *restrict mean,
|
||||
const float *restrict invar, const float *restrict scale, int size, int total_size, int ch,
|
||||
const float *dxhat_sum, const float *dxhathat_sum, float *restrict dx) {
|
||||
#endif
|
||||
const float N = (float)total_size;
|
||||
for (int i = 0; i < size; i++) {
|
||||
for (int c = 0; c < ch; c++) {
|
||||
|
|
|
@ -21,7 +21,11 @@
|
|||
#endif
|
||||
#include "nnacl/fp32/matmul_fp32.h"
|
||||
|
||||
#ifdef SUPPORT_MSVC
|
||||
void AddMatrix(const float *__restrict v1, float *__restrict v2, float beta, int row, int col, int stride) {
|
||||
#else
|
||||
void AddMatrix(const float *restrict v1, float *restrict v2, float beta, int row, int col, int stride) {
|
||||
#endif
|
||||
const float *src_ptr = v1;
|
||||
float *dst_ptr = v2;
|
||||
for (int r = 0; r < row; r++) {
|
||||
|
@ -86,8 +90,7 @@ static void RowMajor2Row12MajorStride(const float *src_ptr, float *dst_ptr, int
|
|||
return;
|
||||
}
|
||||
|
||||
static void RowMajor2Col12MajorStride(const float *restrict src_ptr, float *restrict dst_ptr, size_t row, size_t col,
|
||||
int lead) {
|
||||
static void RowMajor2Col12MajorStride(const float *src_ptr, float *dst_ptr, size_t row, size_t col, int lead) {
|
||||
size_t row_up_12 = UP_ROUND(row, C12NUM);
|
||||
size_t row12 = row / C12NUM * C12NUM;
|
||||
size_t col4 = col / C4NUM * C4NUM;
|
||||
|
|
|
@ -15,7 +15,333 @@
|
|||
*/
|
||||
#include "nnacl/infer/infer_register.h"
|
||||
|
||||
__attribute__((init_priority(101))) InferShape g_infer_func[PrimType_MAX * sizeof(InferShape)];
|
||||
#ifdef SUPPORT_MSVC
|
||||
#include "nnacl/infer/adam_infer.h"
|
||||
#include "nnacl/infer/add_sub_grad_infer.h"
|
||||
#include "nnacl/infer/addn_infer.h"
|
||||
#include "nnacl/infer/apply_momentum_infer.h"
|
||||
#include "nnacl/infer/argmin_max_infer.h"
|
||||
#include "nnacl/infer/arithmetic_compare_infer.h"
|
||||
#include "nnacl/infer/arithmetic_grad_infer.h"
|
||||
#include "nnacl/infer/arithmetic_infer.h"
|
||||
#include "nnacl/infer/assert_op_infer.h"
|
||||
#include "nnacl/infer/assign_add_infer.h"
|
||||
#include "nnacl/infer/assign_infer.h"
|
||||
#include "nnacl/infer/audio_spectrogram_infer.h"
|
||||
#include "nnacl/infer/batch_to_space_infer.h"
|
||||
#include "nnacl/infer/bias_grad_infer.h"
|
||||
#include "nnacl/infer/binary_cross_entropy_infer.h"
|
||||
#include "nnacl/infer/bn_grad_infer.h"
|
||||
#include "nnacl/infer/broadcast_to_infer.h"
|
||||
#include "nnacl/infer/cast_infer.h"
|
||||
#include "nnacl/infer/common_infer.h"
|
||||
#include "nnacl/infer/concat_infer.h"
|
||||
#include "nnacl/infer/constant_of_shape_infer.h"
|
||||
#include "nnacl/infer/conv2d_grad_filter_infer.h"
|
||||
#include "nnacl/infer/conv2d_grad_input_infer.h"
|
||||
#include "nnacl/infer/conv2d_infer.h"
|
||||
#include "nnacl/infer/crop_and_resize_infer.h"
|
||||
#include "nnacl/infer/crop_infer.h"
|
||||
#include "nnacl/infer/cumsum_infer.h"
|
||||
#include "nnacl/infer/custom_extract_features_infer.h"
|
||||
#include "nnacl/infer/custom_normalize_infer.h"
|
||||
#include "nnacl/infer/custom_predict_infer.h"
|
||||
#include "nnacl/infer/deconv2d_infer.h"
|
||||
#include "nnacl/infer/dedepthwise_conv2d_infer.h"
|
||||
#include "nnacl/infer/depth_to_space_infer.h"
|
||||
#include "nnacl/infer/depthwise_conv2d_infer.h"
|
||||
#include "nnacl/infer/detection_post_process_infer.h"
|
||||
#include "nnacl/infer/dropout_grad_infer.h"
|
||||
#include "nnacl/infer/dropout_infer.h"
|
||||
#include "nnacl/infer/embedding_lookup_infer.h"
|
||||
#include "nnacl/infer/expand_dims_infer.h"
|
||||
#include "nnacl/infer/fft_imag_infer.h"
|
||||
#include "nnacl/infer/fft_real_infer.h"
|
||||
#include "nnacl/infer/fill_infer.h"
|
||||
#include "nnacl/infer/flatten_grad_infer.h"
|
||||
#include "nnacl/infer/flatten_infer.h"
|
||||
#include "nnacl/infer/full_connection_infer.h"
|
||||
#include "nnacl/infer/fused_batchnorm_infer.h"
|
||||
#include "nnacl/infer/gather_infer.h"
|
||||
#include "nnacl/infer/gather_nd_infer.h"
|
||||
#include "nnacl/infer/group_conv2d_grad_input_infer.h"
|
||||
#include "nnacl/infer/gru_infer.h"
|
||||
#include "nnacl/infer/hashtable_lookup_infer.h"
|
||||
#include "nnacl/infer/invert_permutation_infer.h"
|
||||
#include "nnacl/infer/layer_norm_grad_infer.h"
|
||||
#include "nnacl/infer/layer_norm_infer.h"
|
||||
#include "nnacl/infer/lin_space_infer.h"
|
||||
#include "nnacl/infer/log_softmax_infer.h"
|
||||
#include "nnacl/infer/lsh_projection_infer.h"
|
||||
#include "nnacl/infer/lstm_infer.h"
|
||||
#include "nnacl/infer/matmul_infer.h"
|
||||
#include "nnacl/infer/max_min_grad_infer.h"
|
||||
#include "nnacl/infer/mean_infer.h"
|
||||
#include "nnacl/infer/merge_infer.h"
|
||||
#include "nnacl/infer/mfcc_infer.h"
|
||||
#include "nnacl/infer/non_max_suppression_infer.h"
|
||||
#include "nnacl/infer/one_hot_infer.h"
|
||||
#include "nnacl/infer/pad_infer.h"
|
||||
#include "nnacl/infer/partial_infer.h"
|
||||
#include "nnacl/infer/pooling_grad_infer.h"
|
||||
#include "nnacl/infer/pooling_infer.h"
|
||||
#include "nnacl/infer/power_infer.h"
|
||||
#include "nnacl/infer/prior_box_infer.h"
|
||||
#include "nnacl/infer/quant_dtype_cast_infer.h"
|
||||
#include "nnacl/infer/random_standard_normal_infer.h"
|
||||
#include "nnacl/infer/range_infer.h"
|
||||
#include "nnacl/infer/rank_infer.h"
|
||||
#include "nnacl/infer/reduce_infer.h"
|
||||
#include "nnacl/infer/reshape_infer.h"
|
||||
#include "nnacl/infer/resize_grad_infer.h"
|
||||
#include "nnacl/infer/resize_infer.h"
|
||||
#include "nnacl/infer/rfft_infer.h"
|
||||
#include "nnacl/infer/roi_pooling_infer.h"
|
||||
#include "nnacl/infer/scatter_nd_infer.h"
|
||||
#include "nnacl/infer/select_infer.h"
|
||||
#include "nnacl/infer/sgd_infer.h"
|
||||
#include "nnacl/infer/shape_infer.h"
|
||||
#include "nnacl/infer/size_infer.h"
|
||||
#include "nnacl/infer/skip_gram_infer.h"
|
||||
#include "nnacl/infer/slice_infer.h"
|
||||
#include "nnacl/infer/softmax_cross_entropy_infer.h"
|
||||
#include "nnacl/infer/softmax_infer.h"
|
||||
#include "nnacl/infer/space_to_batch_infer.h"
|
||||
#include "nnacl/infer/space_to_batch_nd_infer.h"
|
||||
#include "nnacl/infer/space_to_depth_infer.h"
|
||||
#include "nnacl/infer/sparse_softmax_cross_entropy_with_logits_infer.h"
|
||||
#include "nnacl/infer/sparse_to_dense_infer.h"
|
||||
#include "nnacl/infer/splice_infer.h"
|
||||
#include "nnacl/infer/split_infer.h"
|
||||
#include "nnacl/infer/squeeze_infer.h"
|
||||
#include "nnacl/infer/stack_infer.h"
|
||||
#include "nnacl/infer/strided_slice_grad_infer.h"
|
||||
#include "nnacl/infer/strided_slice_infer.h"
|
||||
#include "nnacl/infer/switch_infer.h"
|
||||
#include "nnacl/infer/tensorlist_fromtensor_infer.h"
|
||||
#include "nnacl/infer/tensorlist_getitem_infer.h"
|
||||
#include "nnacl/infer/tensorlist_reserve_infer.h"
|
||||
#include "nnacl/infer/tensorlist_setitem_infer.h"
|
||||
#include "nnacl/infer/tensorlist_stack_infer.h"
|
||||
#include "nnacl/infer/tile_infer.h"
|
||||
#include "nnacl/infer/topk_infer.h"
|
||||
#include "nnacl/infer/transpose_infer.h"
|
||||
#include "nnacl/infer/uniform_real_infer.h"
|
||||
#include "nnacl/infer/unique_infer.h"
|
||||
#include "nnacl/infer/unsorted_segment_sum_infer.h"
|
||||
#include "nnacl/infer/unsqueeze_infer.h"
|
||||
#include "nnacl/infer/unstack_infer.h"
|
||||
#include "nnacl/infer/where_infer.h"
|
||||
#include "nnacl/infer/while_infer.h"
|
||||
|
||||
// Global table of shape-inference callbacks, indexed by PrimType_* id and zero-initialized
// so that any slot not assigned by the registration functions stays NULL.
// NOTE(review): the element count is PrimType_MAX * sizeof(InferShape), which looks larger
// than the PrimType_* indices used to fill it; it must keep matching the declaration in
// infer_register.h, so confirm there before shrinking it.
InferShape g_infer_func[PrimType_MAX * sizeof(InferShape)] = {0};
|
||||
// Explicitly fills the first part of the g_infer_func table (PrimType_NONE through
// PrimType_MinimumGrad as listed below), one assignment per primitive type.
// Entries set to NULL have no shape-inference callback registered by this function;
// several primitive types share the same callback (e.g. CommonInferShape, ArithmeticInferShape).
// Takes no arguments and returns nothing; its only effect is writing to g_infer_func.
void RegAllInferFunc1() {
|
||||
g_infer_func[PrimType_NONE] = NULL;
|
||||
g_infer_func[PrimType_Abs] = CommonInferShape;
|
||||
g_infer_func[PrimType_Activation] = CommonInferShape;
|
||||
g_infer_func[PrimType_ActivationGrad] = CommonInferShape;
|
||||
g_infer_func[PrimType_Adam] = AdamInferShape;
|
||||
g_infer_func[PrimType_AddFusion] = ArithmeticInferShape;
|
||||
g_infer_func[PrimType_AdderFusion] = Conv2dInferShape;
|
||||
g_infer_func[PrimType_AddGrad] = AddSubGradInferShape;
|
||||
g_infer_func[PrimType_AddN] = AddnInferShape;
|
||||
g_infer_func[PrimType_All] = NULL;
|
||||
g_infer_func[PrimType_ApplyMomentum] = ApplyMomentumInferShape;
|
||||
g_infer_func[PrimType_ArgMaxFusion] = ArgMinMaxInferShape;
|
||||
g_infer_func[PrimType_ArgMinFusion] = ArgMinMaxInferShape;
|
||||
g_infer_func[PrimType_Assert] = AssertOpInferShape;
|
||||
g_infer_func[PrimType_Assign] = AssignInferShape;
|
||||
g_infer_func[PrimType_AssignAdd] = AssignAddInferShape;
|
||||
g_infer_func[PrimType_AudioSpectrogram] = AudioSpectrogramInferShape;
|
||||
g_infer_func[PrimType_AvgPoolFusion] = PoolingInferShape;
|
||||
g_infer_func[PrimType_AvgPoolGrad] = PoolingGradInferShape;
|
||||
g_infer_func[PrimType_BatchNorm] = CommonInferShape;
|
||||
g_infer_func[PrimType_BatchNormGrad] = BnGradInferShape;
|
||||
g_infer_func[PrimType_BatchToSpace] = BatchToSpaceInferShape;
|
||||
g_infer_func[PrimType_BatchToSpaceND] = NULL;
|
||||
g_infer_func[PrimType_BiasAdd] = CommonInferShape;
|
||||
g_infer_func[PrimType_BinaryCrossEntropy] = BinaryCrossEntropyInferShape;
|
||||
g_infer_func[PrimType_BinaryCrossEntropyGrad] = CommonInferShape;
|
||||
g_infer_func[PrimType_BiasAddGrad] = BiasGradInferShape;
|
||||
g_infer_func[PrimType_BroadcastTo] = BroadcastToInferShape;
|
||||
g_infer_func[PrimType_Cast] = CastInferShape;
|
||||
g_infer_func[PrimType_Ceil] = CommonInferShape;
|
||||
g_infer_func[PrimType_Clip] = CommonInferShape;
|
||||
g_infer_func[PrimType_Concat] = ConcatInferShape;
|
||||
g_infer_func[PrimType_ControlDepend] = CommonInferShape;
|
||||
g_infer_func[PrimType_Conv2DBackpropFilterFusion] = Conv2dGradFilterInferShape;
|
||||
g_infer_func[PrimType_Conv2DBackpropInputFusion] = Conv2dGradInputInferShape;
|
||||
g_infer_func[PrimType_Conv2DFusion] = Conv2dInferShape;
|
||||
g_infer_func[PrimType_Conv2dTransposeFusion] = Deconv2dInferShape;
|
||||
g_infer_func[PrimType_Cos] = CommonInferShape;
|
||||
g_infer_func[PrimType_ConstantOfShape] = ConstantOfShapeInferShape;
|
||||
g_infer_func[PrimType_Crop] = CropInferShape;
|
||||
g_infer_func[PrimType_CustomExtractFeatures] = CustomExtractFeaturesInferShape;
|
||||
g_infer_func[PrimType_CustomNormalize] = CustomNormalizeInferShape;
|
||||
g_infer_func[PrimType_CustomPredict] = CustomPredictInferShape;
|
||||
g_infer_func[PrimType_DeConv2DGradFilter] = NULL;
|
||||
g_infer_func[PrimType_Depend] = CommonInferShape;
|
||||
g_infer_func[PrimType_DepthToSpace] = DepthToSpaceInferShape;
|
||||
g_infer_func[PrimType_DetectionPostProcess] = DetectionPostProcessInferShape;
|
||||
g_infer_func[PrimType_DivFusion] = ArithmeticInferShape;
|
||||
g_infer_func[PrimType_DivGrad] = ArithmeticGradInferShape;
|
||||
g_infer_func[PrimType_Dropout] = DropoutInferShape;
|
||||
g_infer_func[PrimType_DropoutGrad] = DropoutGradInferShape;
|
||||
g_infer_func[PrimType_Elu] = CommonInferShape;
|
||||
g_infer_func[PrimType_Eltwise] = ArithmeticInferShape;
|
||||
g_infer_func[PrimType_Equal] = ArithmeticCompareInferShape;
|
||||
g_infer_func[PrimType_EmbeddingLookupFusion] = EmbeddingLookupInferShape;
|
||||
g_infer_func[PrimType_ExpFusion] = CommonInferShape;
|
||||
g_infer_func[PrimType_ExpandDims] = ExpandDimsInferShape;
|
||||
g_infer_func[PrimType_FakeQuantWithMinMaxVars] = CommonInferShape;
|
||||
g_infer_func[PrimType_FakeQuantWithMinMaxVarsPerChannel] = NULL;
|
||||
g_infer_func[PrimType_FftReal] = FftRealInferShape;
|
||||
g_infer_func[PrimType_FftImag] = FftImagInferShape;
|
||||
g_infer_func[PrimType_Flatten] = FlattenInferShape;
|
||||
g_infer_func[PrimType_FlattenGrad] = FlattenGradInferShape;
|
||||
g_infer_func[PrimType_Floor] = CommonInferShape;
|
||||
g_infer_func[PrimType_FloorDiv] = ArithmeticInferShape;
|
||||
g_infer_func[PrimType_FloorMod] = ArithmeticInferShape;
|
||||
g_infer_func[PrimType_Fill] = FillInferShape;
|
||||
g_infer_func[PrimType_FullConnection] = FullConnectionInferShape;
|
||||
g_infer_func[PrimType_FusedBatchNorm] = FusedBatchNormInferShape;
|
||||
g_infer_func[PrimType_Gather] = GatherInferShape;
|
||||
g_infer_func[PrimType_GatherNd] = GatherNdInferShape;
|
||||
g_infer_func[PrimType_Greater] = ArithmeticCompareInferShape;
|
||||
g_infer_func[PrimType_GreaterEqual] = ArithmeticCompareInferShape;
|
||||
g_infer_func[PrimType_HashtableLookup] = HashtableLoopupInferShape;
|
||||
g_infer_func[PrimType_InstanceNorm] = CommonInferShape;
|
||||
// NOTE(review): the forward LayerNormFusion primitive is bound to the *Grad* infer function,
// while layer_norm_infer.h is also included by this file. This looks like a mix-up with a
// non-grad LayerNorm infer function — confirm against layer_norm_infer.h and the
// registration of PrimType_LayerNormGrad (not visible here) before changing.
g_infer_func[PrimType_LayerNormFusion] = LayerNormGradInferShape;
|
||||
g_infer_func[PrimType_LeakyRelu] = CommonInferShape;
|
||||
g_infer_func[PrimType_Less] = ArithmeticCompareInferShape;
|
||||
g_infer_func[PrimType_LessEqual] = ArithmeticCompareInferShape;
|
||||
g_infer_func[PrimType_Log] = CommonInferShape;
|
||||
g_infer_func[PrimType_LogGrad] = CommonInferShape;
|
||||
g_infer_func[PrimType_LogicalAnd] = ArithmeticInferShape;
|
||||
g_infer_func[PrimType_LogicalNot] = CommonInferShape;
|
||||
g_infer_func[PrimType_LogicalOr] = ArithmeticInferShape;
|
||||
g_infer_func[PrimType_LpNormalization] = NULL;
|
||||
g_infer_func[PrimType_LRN] = CommonInferShape;
|
||||
g_infer_func[PrimType_LshProjection] = LshProjectionInferShape;
|
||||
g_infer_func[PrimType_LSTM] = LstmInferShape;
|
||||
g_infer_func[PrimType_L2NormalizeFusion] = CommonInferShape;
|
||||
g_infer_func[PrimType_MatMul] = MatmulInferShape;
|
||||
g_infer_func[PrimType_Maximum] = ArithmeticInferShape;
|
||||
g_infer_func[PrimType_MaximumGrad] = MaxMinGradInferShape;
|
||||
g_infer_func[PrimType_MaxPoolFusion] = PoolingInferShape;
|
||||
g_infer_func[PrimType_MaxPoolGrad] = PoolingGradInferShape;
|
||||
g_infer_func[PrimType_Merge] = MergeInferShape;
|
||||
g_infer_func[PrimType_Mfcc] = MfccInferShape;
|
||||
g_infer_func[PrimType_Minimum] = ArithmeticInferShape;
|
||||
g_infer_func[PrimType_MinimumGrad] = MaxMinGradInferShape;
|
||||
}
|
||||
|
||||
void RegAllInferFunc2() {
|
||||
g_infer_func[PrimType_Mod] = ArithmeticInferShape;
|
||||
g_infer_func[PrimType_MulFusion] = ArithmeticInferShape;
|
||||
g_infer_func[PrimType_MulGrad] = ArithmeticGradInferShape;
|
||||
g_infer_func[PrimType_Neg] = CommonInferShape;
|
||||
g_infer_func[PrimType_NegGrad] = CommonInferShape;
|
||||
g_infer_func[PrimType_NotEqual] = ArithmeticCompareInferShape;
|
||||
g_infer_func[PrimType_NonMaxSuppression] = NonMaxSuppressionInferShape;
|
||||
g_infer_func[PrimType_OneHot] = OneHotInferShape;
|
||||
g_infer_func[PrimType_OnesLike] = NULL;
|
||||
g_infer_func[PrimType_PadFusion] = PadInferShape;
|
||||
g_infer_func[PrimType_PartialFusion] = PartialInferShape;
|
||||
g_infer_func[PrimType_PowerGrad] = CommonInferShape;
|
||||
g_infer_func[PrimType_PowFusion] = PowerInferShape;
|
||||
g_infer_func[PrimType_PriorBox] = PriorBoxInferShape;
|
||||
g_infer_func[PrimType_PReLUFusion] = CommonInferShape;
|
||||
g_infer_func[PrimType_QuantDTypeCast] = QuantDtypeCastInferShape;
|
||||
g_infer_func[PrimType_Rank] = RankInferShape;
|
||||
g_infer_func[PrimType_Range] = RangeInferShape;
|
||||
g_infer_func[PrimType_Reciprocal] = CommonInferShape;
|
||||
g_infer_func[PrimType_RealDiv] = ArithmeticInferShape;
|
||||
g_infer_func[PrimType_ReduceFusion] = ReduceInferShape;
|
||||
g_infer_func[PrimType_Reshape] = ReshapeInferShape;
|
||||
g_infer_func[PrimType_Resize] = ResizeInferShape;
|
||||
g_infer_func[PrimType_ReverseSequence] = CommonInferShape;
|
||||
g_infer_func[PrimType_ReverseV2] = CommonInferShape;
|
||||
g_infer_func[PrimType_Rfft] = RfftInferShape;
|
||||
g_infer_func[PrimType_ROIPooling] = ROIPoolingInferShape;
|
||||
g_infer_func[PrimType_Round] = CommonInferShape;
|
||||
g_infer_func[PrimType_Rsqrt] = CommonInferShape;
|
||||
g_infer_func[PrimType_ScaleFusion] = CommonInferShape;
|
||||
g_infer_func[PrimType_ScatterNd] = ScatterNdInferShape;
|
||||
g_infer_func[PrimType_SGD] = SgdInferShape;
|
||||
g_infer_func[PrimType_Shape] = ShapeInferShape;
|
||||
g_infer_func[PrimType_SigmoidCrossEntropyWithLogits] = CommonInferShape;
|
||||
g_infer_func[PrimType_SigmoidCrossEntropyWithLogitsGrad] = CommonInferShape;
|
||||
g_infer_func[PrimType_Sin] = CommonInferShape;
|
||||
g_infer_func[PrimType_SkipGram] = SkipGramInferShape;
|
||||
g_infer_func[PrimType_SliceFusion] = SliceInferShape;
|
||||
g_infer_func[PrimType_SmoothL1Loss] = CommonInferShape;
|
||||
g_infer_func[PrimType_SmoothL1LossGrad] = CommonInferShape;
|
||||
g_infer_func[PrimType_Softmax] = SoftMaxInferShape;
|
||||
g_infer_func[PrimType_SoftmaxCrossEntropyWithLogits] = SoftmaxCrossEntropyInferShape;
|
||||
g_infer_func[PrimType_SpaceToBatch] = SpaceToBatchInferShape;
|
||||
g_infer_func[PrimType_SpaceToBatchND] = SpaceToBatchNdInferShape;
|
||||
g_infer_func[PrimType_SpaceToDepth] = SpaceToDepthInferShape;
|
||||
g_infer_func[PrimType_SparseSoftmaxCrossEntropyWithLogits] = SparseSoftmaxCrossEntropyWithLogitsInferShape;
|
||||
g_infer_func[PrimType_SparseToDense] = SparseToDenseInferShape;
|
||||
g_infer_func[PrimType_Split] = SplitInferShape;
|
||||
g_infer_func[PrimType_Sqrt] = CommonInferShape;
|
||||
g_infer_func[PrimType_Squeeze] = SqueezeInferShape;
|
||||
g_infer_func[PrimType_Square] = CommonInferShape;
|
||||
g_infer_func[PrimType_SquaredDifference] = ArithmeticInferShape;
|
||||
g_infer_func[PrimType_Stack] = StackInferShape;
|
||||
g_infer_func[PrimType_StridedSlice] = StridedSliceInferShape;
|
||||
g_infer_func[PrimType_SubFusion] = ArithmeticInferShape;
|
||||
g_infer_func[PrimType_SubGrad] = AddSubGradInferShape;
|
||||
g_infer_func[PrimType_Switch] = SwitchInferShape;
|
||||
g_infer_func[PrimType_TensorListFromTensor] = TensorListFromTensorInferShape;
|
||||
g_infer_func[PrimType_TensorListGetItem] = TensorListGetItemInferShape;
|
||||
g_infer_func[PrimType_TensorListReserve] = TensorListReserveInferShape;
|
||||
g_infer_func[PrimType_TensorListSetItem] = TensorListSetItemInferShape;
|
||||
g_infer_func[PrimType_TensorListStack] = TensorListStackInferShape;
|
||||
g_infer_func[PrimType_TileFusion] = TileInferShape;
|
||||
g_infer_func[PrimType_TopKFusion] = TopKInferShape;
|
||||
g_infer_func[PrimType_Transpose] = TransposeInferShape;
|
||||
g_infer_func[PrimType_Unique] = UniqueInferShape;
|
||||
g_infer_func[PrimType_UnsortedSegmentSum] = UnsortedSegmentSumInferShape;
|
||||
g_infer_func[PrimType_Unsqueeze] = UnsqueezeInferShape;
|
||||
g_infer_func[PrimType_Unstack] = UnstackInferShape;
|
||||
g_infer_func[PrimType_While] = WhileInferShape;
|
||||
g_infer_func[PrimType_Where] = WhereInferShape;
|
||||
g_infer_func[PrimType_ZerosLike] = CommonInferShape;
|
||||
g_infer_func[PrimType_Select] = SelectInferShape;
|
||||
g_infer_func[PrimType_If] = CommonInferShape;
|
||||
g_infer_func[PrimType_GRU] = GruInferShape;
|
||||
g_infer_func[PrimType_NonZero] = NULL;
|
||||
g_infer_func[PrimType_InvertPermutation] = InvertPermutationInferShape;
|
||||
g_infer_func[PrimType_Size] = SizeInferShape;
|
||||
g_infer_func[PrimType_RandomStandardNormal] = RandomStandardNormalInferShape;
|
||||
g_infer_func[PrimType_CropAndResize] = CropAndResizeInferShape;
|
||||
g_infer_func[PrimType_Erf] = CommonInferShape;
|
||||
g_infer_func[PrimType_StridedSliceGrad] = StridedSliceGradInferShape;
|
||||
g_infer_func[PrimType_IsFinite] = CommonInferShape;
|
||||
g_infer_func[PrimType_LinSpace] = LinSpaceInferShape;
|
||||
g_infer_func[PrimType_UniformReal] = UniformRealInferShape;
|
||||
g_infer_func[PrimType_AbsGrad] = CommonInferShape;
|
||||
g_infer_func[PrimType_RsqrtGrad] = NULL;
|
||||
g_infer_func[PrimType_SqrtGrad] = NULL;
|
||||
g_infer_func[PrimType_LayerNormGrad] = LayerNormGradInferShape;
|
||||
g_infer_func[PrimType_ResizeGrad] = ResizeGradInferShape;
|
||||
g_infer_func[PrimType_Splice] = SpliceInferShape;
|
||||
g_infer_func[PrimType_LogSoftmax] = LogSoftmaxInferShape;
|
||||
g_infer_func[PrimType_Call] = NULL;
|
||||
g_infer_func[PrimType_Custom] = NULL;
|
||||
g_infer_func[PrimType_CumSum] = CumsumInferShape;
|
||||
}
|
||||
|
||||
typedef void RegFunc();

/*
 * Start-up hook for the MSVC build: runs both registration batches once.
 * Entries placed in the ".CRT$XI*" sections are called by the CRT as
 * `int (*)(void)` and a non-zero result is treated as an initialization
 * failure, so the table must hold a function that explicitly returns 0
 * instead of the void registration functions themselves.
 */
static int RegAllInferFuncOnStartup(void) {
  RegAllInferFunc1();
  RegAllInferFunc2();
  return 0;
}

typedef int RegInitFunc(void);

/* Place the hook in the user slot of the CRT C-initializer table. */
#pragma data_seg(".CRT$XIU")
static RegInitFunc *before[] = {RegAllInferFuncOnStartup};
#pragma data_seg()
|
||||
|
||||
#else
|
||||
/*
 * Table of shape-inference handlers, indexed by primitive type (non-MSVC build).
 * NOTE(review): the element count is PrimType_MAX * sizeof(InferShape), which looks like a byte size used as
 * an element count; GetInferFunc only accepts indices below PrimType_MAX — confirm whether PrimType_MAX
 * entries would suffice before shrinking it.
 * NOTE(review): init_priority(101) is presumably chosen to precede the constructor(102) functions generated
 * by REG_INFER — confirm.
 */
__attribute__((init_priority(101))) InferShape g_infer_func[PrimType_MAX * sizeof(InferShape)] = {0};
|
||||
#endif // SUPPORT_MSVC
|
||||
|
||||
InferShape GetInferFunc(int prim_type) {
|
||||
if (prim_type < PrimType_MAX) {
|
||||
|
|
|
@ -232,8 +232,13 @@ enum PrimType {
|
|||
|
||||
void RegInfer(int prim_type, InferShape func);
|
||||
|
||||
/*
 * REG_INFER(op, type, func)
 *   - Without SUPPORT_MSVC: defines a constructor function named Reg<op>Infer
 *     (constructor priority 102) whose body calls RegInfer(type, func).
 *   - With SUPPORT_MSVC: expands to nothing, so no registration is performed
 *     through this macro in that configuration.
 */
#ifdef SUPPORT_MSVC
|
||||
#define REG_INFER(op, type, func)
|
||||
#else
|
||||
#define REG_INFER(op, type, func) \
|
||||
__attribute__((constructor(102))) void Reg##op##Infer() { RegInfer(type, func); }
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -16,7 +16,11 @@
|
|||
#ifndef MINDSPORE_NNACL_X86_64_AVX_COMMON_UTILS_H_
|
||||
#define MINDSPORE_NNACL_X86_64_AVX_COMMON_UTILS_H_
|
||||
|
||||
#ifdef SUPPORT_MSVC
|
||||
#include <immintrin.h>
|
||||
#else
|
||||
#include <x86intrin.h>
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
|
|
@ -20,10 +20,17 @@
|
|||
|
||||
/*
 * MS_F32X4_GETI(src, i) names lane `i` of the 4 x float vector `src`.
 * It expands to an lvalue, so it can be read from and assigned to.
 * Both macro arguments are parenthesized so that expression arguments
 * (for example `*ptr` or `base + off`) keep their intended grouping.
 */
#ifdef ENABLE_ARM
#include <arm_neon.h>
#define MS_F32X4_GETI(src, i) (src)[(i)]
#endif

#if defined(ENABLE_SSE) || defined(ENABLE_AVX)
#ifdef SUPPORT_MSVC
#include <immintrin.h>
/* With SUPPORT_MSVC the lanes are reached through the m128_f32 member. */
#define MS_F32X4_GETI(src, i) (src).m128_f32[(i)]
#else
#include <x86intrin.h>
#define MS_F32X4_GETI(src, i) (src)[(i)]
#endif
#endif
|
||||
|
||||
#ifdef ENABLE_ARM
|
||||
|
@ -137,26 +144,7 @@ static inline float32x4_t vrecp(float32x4_t v) {
|
|||
#define MS_CAST_F32_S32(src) _mm_castsi128_ps(src)
|
||||
#endif
|
||||
|
||||
/*
 * LOAD256X8_F32(src, input_ptr, num): declares eight MS_FLOAT32X8 variables
 * named <src>1 .. <src>8 and loads them with MS_LD256_F32 from
 * input_ptr, input_ptr + num, ..., input_ptr + 7 * num.
 * The pointer and stride arguments are parenthesized so that expression
 * arguments such as `a + b` keep their intended grouping.
 */
#define LOAD256X8_F32(src, input_ptr, num)                     \
  MS_FLOAT32X8 src##1 = MS_LD256_F32((input_ptr) + 0 * (num)); \
  MS_FLOAT32X8 src##2 = MS_LD256_F32((input_ptr) + 1 * (num)); \
  MS_FLOAT32X8 src##3 = MS_LD256_F32((input_ptr) + 2 * (num)); \
  MS_FLOAT32X8 src##4 = MS_LD256_F32((input_ptr) + 3 * (num)); \
  MS_FLOAT32X8 src##5 = MS_LD256_F32((input_ptr) + 4 * (num)); \
  MS_FLOAT32X8 src##6 = MS_LD256_F32((input_ptr) + 5 * (num)); \
  MS_FLOAT32X8 src##7 = MS_LD256_F32((input_ptr) + 6 * (num)); \
  MS_FLOAT32X8 src##8 = MS_LD256_F32((input_ptr) + 7 * (num));

/*
 * STORE256X8_F32(output_ptr, num, dst): stores the eight variables
 * <dst>1 .. <dst>8 with MS_ST256_F32 to output_ptr, output_ptr + num, ...,
 * output_ptr + 7 * num. Arguments are parenthesized for the same reason as
 * in LOAD256X8_F32.
 */
#define STORE256X8_F32(output_ptr, num, dst)          \
  MS_ST256_F32((output_ptr) + 0 * (num), dst##1);     \
  MS_ST256_F32((output_ptr) + 1 * (num), dst##2);     \
  MS_ST256_F32((output_ptr) + 2 * (num), dst##3);     \
  MS_ST256_F32((output_ptr) + 3 * (num), dst##4);     \
  MS_ST256_F32((output_ptr) + 4 * (num), dst##5);     \
  MS_ST256_F32((output_ptr) + 5 * (num), dst##6);     \
  MS_ST256_F32((output_ptr) + 6 * (num), dst##7);     \
  MS_ST256_F32((output_ptr) + 7 * (num), dst##8);
|
||||
|
||||
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
|
||||
#define LOAD128X8_F32(src, input_ptr, num) \
|
||||
MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \
|
||||
MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \
|
||||
|
@ -195,7 +183,37 @@ static inline MS_FLOAT32X4 MS_TANHX4_F32(MS_FLOAT32X4 src) {
|
|||
return MS_MINQ_F32(MS_MAXQ_F32(MS_DIVQ_F32(a, b), neg), pos);
|
||||
}
|
||||
|
||||
/*
 * Lane-wise error function: returns a vector whose lane k holds
 * erff() of lane k of `src`, for k = 0..3. Lanes are accessed through
 * MS_F32X4_GETI.
 */
static inline MS_FLOAT32X4 MS_ERFX4_F32(MS_FLOAT32X4 src) {
  MS_FLOAT32X4 result;
  for (int lane = 0; lane < 4; ++lane) {
    MS_F32X4_GETI(result, lane) = erff(MS_F32X4_GETI(src, lane));
  }
  return result;
}
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_AVX
|
||||
/*
 * LOAD256X8_F32(src, input_ptr, num): declares eight MS_FLOAT32X8 variables
 * named <src>1 .. <src>8 and loads them with MS_LD256_F32 from
 * input_ptr, input_ptr + num, ..., input_ptr + 7 * num.
 * The pointer and stride arguments are parenthesized so that expression
 * arguments such as `a + b` keep their intended grouping.
 */
#define LOAD256X8_F32(src, input_ptr, num)                     \
  MS_FLOAT32X8 src##1 = MS_LD256_F32((input_ptr) + 0 * (num)); \
  MS_FLOAT32X8 src##2 = MS_LD256_F32((input_ptr) + 1 * (num)); \
  MS_FLOAT32X8 src##3 = MS_LD256_F32((input_ptr) + 2 * (num)); \
  MS_FLOAT32X8 src##4 = MS_LD256_F32((input_ptr) + 3 * (num)); \
  MS_FLOAT32X8 src##5 = MS_LD256_F32((input_ptr) + 4 * (num)); \
  MS_FLOAT32X8 src##6 = MS_LD256_F32((input_ptr) + 5 * (num)); \
  MS_FLOAT32X8 src##7 = MS_LD256_F32((input_ptr) + 6 * (num)); \
  MS_FLOAT32X8 src##8 = MS_LD256_F32((input_ptr) + 7 * (num));

/*
 * STORE256X8_F32(output_ptr, num, dst): stores the eight variables
 * <dst>1 .. <dst>8 with MS_ST256_F32 to output_ptr, output_ptr + num, ...,
 * output_ptr + 7 * num. Arguments are parenthesized for the same reason as
 * in LOAD256X8_F32.
 */
#define STORE256X8_F32(output_ptr, num, dst)          \
  MS_ST256_F32((output_ptr) + 0 * (num), dst##1);     \
  MS_ST256_F32((output_ptr) + 1 * (num), dst##2);     \
  MS_ST256_F32((output_ptr) + 2 * (num), dst##3);     \
  MS_ST256_F32((output_ptr) + 3 * (num), dst##4);     \
  MS_ST256_F32((output_ptr) + 4 * (num), dst##5);     \
  MS_ST256_F32((output_ptr) + 5 * (num), dst##6);     \
  MS_ST256_F32((output_ptr) + 6 * (num), dst##7);     \
  MS_ST256_F32((output_ptr) + 7 * (num), dst##8);
|
||||
|
||||
static inline MS_FLOAT32X8 MS_TANHX8_F32(MS_FLOAT32X8 src) {
|
||||
static const float data[] = {378.0f, 17325.0f, 135135.0f, 28.0f, 3150.0f, 62370.0f};
|
||||
static const MS_FLOAT32X8 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f};
|
||||
|
@ -207,13 +225,4 @@ static inline MS_FLOAT32X8 MS_TANHX8_F32(MS_FLOAT32X8 src) {
|
|||
}
|
||||
#endif
|
||||
|
||||
/*
 * Lane-wise error function: returns a vector whose lane k holds
 * erff() of lane k of `src`, for k = 0..3. Lanes are accessed by direct
 * subscripting of the vector value.
 */
static inline MS_FLOAT32X4 MS_ERFX4_F32(MS_FLOAT32X4 src) {
  MS_FLOAT32X4 result;
  for (int lane = 0; lane < 4; ++lane) {
    result[lane] = erff(src[lane]);
  }
  return result;
}
|
||||
|
||||
#endif // MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
*/
|
||||
|
||||
#if defined(ENABLE_SSE) && !defined(ENABLE_AVX)
|
||||
#include <x86intrin.h>
|
||||
#include "nnacl/intrinsics/ms_simd_instructions.h"
|
||||
#include "nnacl/fp32/common_func_fp32.h"
|
||||
|
||||
void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, size_t num_pixels,
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
*/
|
||||
|
||||
#ifdef ENABLE_SSE
|
||||
#include <x86intrin.h>
|
||||
#include "nnacl/intrinsics/ms_simd_instructions.h"
|
||||
#include "nnacl/fp32/conv_depthwise_fp32.h"
|
||||
#include "nnacl/intrinsics/sse/sse_common.h"
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
*/
|
||||
|
||||
#ifdef ENABLE_SSE
|
||||
#include <x86intrin.h>
|
||||
#include "nnacl/intrinsics/ms_simd_instructions.h"
|
||||
#include "nnacl/fp32/matmul_fp32.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/matmul_parameter.h"
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
*/
|
||||
|
||||
#ifdef ENABLE_SSE
|
||||
#include <x86intrin.h>
|
||||
#include "nnacl/intrinsics/ms_simd_instructions.h"
|
||||
#include "nnacl/fp32/common_func_fp32.h"
|
||||
#include "nnacl/intrinsics/sse/sse_common.h"
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
*/
|
||||
|
||||
#ifdef ENABLE_SSE
|
||||
#include <x86intrin.h>
|
||||
#include "nnacl/intrinsics/ms_simd_instructions.h"
|
||||
#include "nnacl/fp32/common_func_fp32.h"
|
||||
#include "nnacl/intrinsics/sse/sse_common.h"
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#if defined(ENABLE_SSE) && !defined(ENABLE_AVX)
|
||||
#include <x86intrin.h>
|
||||
#include "nnacl/intrinsics/ms_simd_instructions.h"
|
||||
#include "nnacl/fp32/common_func_fp32.h"
|
||||
|
||||
static inline void TiledC4MatmulFp32_Transfer(__m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4,
|
||||
|
@ -51,21 +51,23 @@ void TiledC4MatmulFp32(float *dst, const float *src, const float *weight, size_t
|
|||
weight_data[2] = _mm_loadu_ps(weight + 8);
|
||||
weight_data[3] = _mm_loadu_ps(weight + 12);
|
||||
weight += 16;
|
||||
__m128 dst1 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src1[0]));
|
||||
__m128 dst2 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src2[0]));
|
||||
__m128 dst3 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src3[0]));
|
||||
__m128 dst4 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src4[0]));
|
||||
__m128 dst1 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src1, 0)));
|
||||
__m128 dst2 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src2, 0)));
|
||||
__m128 dst3 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src3, 0)));
|
||||
__m128 dst4 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src4, 0)));
|
||||
for (int j = 1; j < 4; ++j) {
|
||||
TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[j], src1[j], src2[j], src3[j], src4[j]);
|
||||
TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[j], MS_F32X4_GETI(src1, j),
|
||||
MS_F32X4_GETI(src2, j), MS_F32X4_GETI(src3, j), MS_F32X4_GETI(src4, j));
|
||||
}
|
||||
TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src);
|
||||
src += 16;
|
||||
__m128 dst5 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src1[0]));
|
||||
__m128 dst6 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src2[0]));
|
||||
__m128 dst7 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src3[0]));
|
||||
__m128 dst8 = _mm_mul_ps(weight_data[0], _mm_set_ps1(src4[0]));
|
||||
__m128 dst5 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src1, 0)));
|
||||
__m128 dst6 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src2, 0)));
|
||||
__m128 dst7 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src3, 0)));
|
||||
__m128 dst8 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src4, 0)));
|
||||
for (int j = 1; j < 4; ++j) {
|
||||
TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[j], src1[j], src2[j], src3[j], src4[j]);
|
||||
TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[j], MS_F32X4_GETI(src1, j),
|
||||
MS_F32X4_GETI(src2, j), MS_F32X4_GETI(src3, j), MS_F32X4_GETI(src4, j));
|
||||
}
|
||||
if (ic4_tmp != 0) {
|
||||
ic4_tmp -= 1;
|
||||
|
@ -75,64 +77,74 @@ void TiledC4MatmulFp32(float *dst, const float *src, const float *weight, size_t
|
|||
weight_data[1] = _mm_loadu_ps(weight + 4);
|
||||
weight += 8;
|
||||
|
||||
dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[0], _mm_set_ps1(src1[0])));
|
||||
dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[0], _mm_set_ps1(src2[0])));
|
||||
dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src1, 0))));
|
||||
dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src2, 0))));
|
||||
for (; ic4_tmp != 0; ic4_tmp -= 1) {
|
||||
dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[0], _mm_set_ps1(src3[0])));
|
||||
dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[0], _mm_set_ps1(src4[0])));
|
||||
dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src3, 0))));
|
||||
dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src4, 0))));
|
||||
|
||||
TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[1], src1[1], src2[1], src3[1], src4[1]);
|
||||
TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[1], MS_F32X4_GETI(src1, 1),
|
||||
MS_F32X4_GETI(src2, 1), MS_F32X4_GETI(src3, 1), MS_F32X4_GETI(src4, 1));
|
||||
|
||||
weight_data[2] = _mm_loadu_ps(weight);
|
||||
weight_data[3] = _mm_loadu_ps(weight + 4);
|
||||
weight += 8;
|
||||
|
||||
TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[2], src1[2], src2[2], src3[2], src4[2]);
|
||||
TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[2], MS_F32X4_GETI(src1, 2),
|
||||
MS_F32X4_GETI(src2, 2), MS_F32X4_GETI(src3, 2), MS_F32X4_GETI(src4, 2));
|
||||
|
||||
dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[3], _mm_set_ps1(src1[3])));
|
||||
dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[3], _mm_set_ps1(src2[3])));
|
||||
dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[3], _mm_set_ps1(MS_F32X4_GETI(src1, 3))));
|
||||
dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[3], _mm_set_ps1(MS_F32X4_GETI(src2, 3))));
|
||||
src1 = _mm_loadu_ps(src);
|
||||
src2 = _mm_loadu_ps(src + 4);
|
||||
dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[3], _mm_set_ps1(src3[3])));
|
||||
dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[3], _mm_set_ps1(src4[3])));
|
||||
dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[3], _mm_set_ps1(MS_F32X4_GETI(src3, 3))));
|
||||
dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[3], _mm_set_ps1(MS_F32X4_GETI(src4, 3))));
|
||||
src3 = _mm_loadu_ps(src + 8);
|
||||
src4 = _mm_loadu_ps(src + 12);
|
||||
src += 16;
|
||||
|
||||
TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[0], src1[0], src2[0], src3[0], src4[0]);
|
||||
TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[0], MS_F32X4_GETI(src1, 0),
|
||||
MS_F32X4_GETI(src2, 0), MS_F32X4_GETI(src3, 0), MS_F32X4_GETI(src4, 0));
|
||||
|
||||
TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[1], src1[1], src2[1], src3[1], src4[1]);
|
||||
TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[1], MS_F32X4_GETI(src1, 1),
|
||||
MS_F32X4_GETI(src2, 1), MS_F32X4_GETI(src3, 1), MS_F32X4_GETI(src4, 1));
|
||||
|
||||
TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[2], src1[2], src2[2], src3[2], src4[2]);
|
||||
TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[2], MS_F32X4_GETI(src1, 2),
|
||||
MS_F32X4_GETI(src2, 2), MS_F32X4_GETI(src3, 2), MS_F32X4_GETI(src4, 2));
|
||||
|
||||
weight_data[0] = _mm_loadu_ps(weight);
|
||||
weight_data[1] = _mm_loadu_ps(weight + 4);
|
||||
weight += 8;
|
||||
|
||||
TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[3], src1[3], src2[3], src3[3], src4[3]);
|
||||
TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[3], MS_F32X4_GETI(src1, 3),
|
||||
MS_F32X4_GETI(src2, 3), MS_F32X4_GETI(src3, 3), MS_F32X4_GETI(src4, 3));
|
||||
TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src);
|
||||
src += 16;
|
||||
|
||||
dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[0], _mm_set_ps1(src1[0])));
|
||||
dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[0], _mm_set_ps1(src2[0])));
|
||||
dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src1, 0))));
|
||||
dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src2, 0))));
|
||||
}
|
||||
dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[0], _mm_set_ps1(src3[0])));
|
||||
dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[0], _mm_set_ps1(src4[0])));
|
||||
dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src3, 0))));
|
||||
dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src4, 0))));
|
||||
|
||||
TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[1], src1[1], src2[1], src3[1], src4[1]);
|
||||
TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[1], MS_F32X4_GETI(src1, 1),
|
||||
MS_F32X4_GETI(src2, 1), MS_F32X4_GETI(src3, 1), MS_F32X4_GETI(src4, 1));
|
||||
|
||||
weight_data[2] = _mm_loadu_ps(weight);
|
||||
weight_data[3] = _mm_loadu_ps(weight + 4);
|
||||
weight += 8;
|
||||
|
||||
TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[2], src1[2], src2[2], src3[2], src4[2]);
|
||||
TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[2], MS_F32X4_GETI(src1, 2),
|
||||
MS_F32X4_GETI(src2, 2), MS_F32X4_GETI(src3, 2), MS_F32X4_GETI(src4, 2));
|
||||
|
||||
TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[3], src1[3], src2[3], src3[3], src4[3]);
|
||||
TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[3], MS_F32X4_GETI(src1, 3),
|
||||
MS_F32X4_GETI(src2, 3), MS_F32X4_GETI(src3, 3), MS_F32X4_GETI(src4, 3));
|
||||
|
||||
TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src);
|
||||
src += 16;
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[j], src1[j], src2[j], src3[j], src4[j]);
|
||||
TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[j], MS_F32X4_GETI(src1, j),
|
||||
MS_F32X4_GETI(src2, j), MS_F32X4_GETI(src3, j), MS_F32X4_GETI(src4, j));
|
||||
}
|
||||
}
|
||||
_mm_storeu_ps(dst, dst1);
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#ifdef ENABLE_SSE
|
||||
#include <x86intrin.h>
|
||||
#include "nnacl/intrinsics/ms_simd_instructions.h"
|
||||
#include "nnacl/fp32/common_func_fp32.h"
|
||||
|
||||
void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) {
|
||||
|
|
|
@ -477,7 +477,11 @@ int Model::Export(Model *model, const char *filename) {
|
|||
ofs.seekp(0, std::ios::beg);
|
||||
ofs.write(liteModel->buf, liteModel->buf_size_);
|
||||
ofs.close();
|
||||
#ifdef SUPPORT_MSVC
|
||||
return RET_OK;
|
||||
#else
|
||||
return chmod(filename, S_IRUSR);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace mindspore::lite
|
||||
|
|
Loading…
Reference in New Issue