forked from mindspore-Ecosystem/mindspore
move in matmul, transpose and bias_add's opt
This commit is contained in:
parent
980f3769c3
commit
399276790e
@@ -95,6 +95,7 @@ mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/deconv_winograd_fp32
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/deconv_winograd_fp32.c:DeConvWgMerge
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/avx/TiledC8MatMulFp32.c:TiledC8MatmulFp32
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/quant_dtype_cast_fp16.c:Fp16ToInt8_arm64
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.c:MatMul4x1Kernel
mindspore/mindspore/ccsrc/backend/session/gpu_session.cc:mindspore::session::gpu::GPUSession::LoadInputData
mindspore/mindspore/ccsrc/debug/dump_proto.cc:mindspore::ProtoExporter::SetNodeOutputType
mindspore/mindspore/ccsrc/debug/dump_proto.cc:mindspore::ProtoExporter::SetValueToProto
@@ -0,0 +1,142 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/fp32/bias_add.h"
#include "nnacl/op_base.h"

void BiasAddByInnerCore(const float *input, const float *bias, float *output, int64_t num) {
  int64_t index = 0;
#if defined(ENABLE_SSE) || defined(ENABLE_ARM)
  for (; index <= num - C4NUM; index += C4NUM) {
    MS_FLOAT32X4 input_data = MS_LDQ_F32(input + index);
    MS_FLOAT32X4 bias_data = MS_LDQ_F32(bias + index);
    MS_STQ_F32(output + index, MS_ADD128_F32(input_data, bias_data));
  }
#endif

  for (; index < num; ++index) {
    output[index] = input[index] + bias[index];
  }
}

void BiasAddByBatchCore(const float *input, const float *bias, float *output, int64_t num) {
  float *output1 = output;
  float *output2 = output + num;
  float *output3 = output + num * 2;
  float *output4 = output + num * 3;
  int64_t index = 0;
#if defined(ENABLE_SSE) || defined(ENABLE_ARM)
  for (; index <= num - C4NUM; index += C4NUM) {
    MS_LOAD128X4_F32(input_data, input + index, num);
    MS_FLOAT32X4 bias_data = MS_LDQ_F32(bias + index);
    MS_STQ_F32(output1 + index, MS_ADD128_F32(input_data1, bias_data));
    MS_STQ_F32(output2 + index, MS_ADD128_F32(input_data2, bias_data));
    MS_STQ_F32(output3 + index, MS_ADD128_F32(input_data3, bias_data));
    MS_STQ_F32(output4 + index, MS_ADD128_F32(input_data4, bias_data));
  }
#endif
  const float *input_data1 = input;
  const float *input_data2 = input + num;
  const float *input_data3 = input + num * 2;
  const float *input_data4 = input + num * 3;
  for (; index < num; ++index) {
    output1[index] = input_data1[index] + bias[index];
    output2[index] = input_data2[index] + bias[index];
    output3[index] = input_data3[index] + bias[index];
    output4[index] = input_data4[index] + bias[index];
  }
}

void DoBiasAddByBatch(const float *input, const float *bias, float *output, int64_t start, int64_t end,
                      int64_t inner_num) {
  if (inner_num == 0) {
    return;
  }
  int64_t start_outer = start / inner_num;
  int64_t start_inner = start % inner_num;
  int64_t end_outer = end / inner_num;
  int64_t end_inner = end % inner_num;
  const float *cur_input = input + start;
  const float *cur_bias = bias + start_inner;
  float *cur_output = output + start;
  if (start_outer == end_outer) {
    BiasAddByInnerCore(cur_input, cur_bias, cur_output, end_inner - start_inner);
    return;
  }
  if (start_inner != 0) {
    BiasAddByInnerCore(cur_input, cur_bias, cur_output, inner_num - start_inner);
    start_outer += 1;
    cur_input += inner_num - start_inner;
    cur_bias = bias;
    cur_output += inner_num - start_inner;
  }
  int64_t step = C4NUM * inner_num;
  for (; start_outer <= end_outer - C4NUM; start_outer += C4NUM) {
    BiasAddByBatchCore(cur_input, cur_bias, cur_output, inner_num);
    cur_input += step;
    cur_output += step;
  }
  for (; start_outer < end_outer; ++start_outer) {
    BiasAddByInnerCore(cur_input, cur_bias, cur_output, inner_num);
    cur_input += inner_num;
    cur_output += inner_num;
  }
  BiasAddByInnerCore(cur_input, cur_bias, cur_output, end_inner);
}

void DoBiasAddByInner(const float *input, const float *bias, float *output, int64_t start, int64_t end,
                      int64_t inner_num) {
  if (inner_num == 0) {
    return;
  }
  int64_t start_outer = start / inner_num;
  int64_t start_inner = start % inner_num;
  int64_t end_outer = end / inner_num;
  int64_t end_inner = end % inner_num;
  const float *cur_input = input + start;
  const float *cur_bias = bias + start_inner;
  float *cur_output = output + start;
  if (start_outer == end_outer) {
    BiasAddByInnerCore(cur_input, cur_bias, cur_output, end_inner - start_inner);
    return;
  } else {
    BiasAddByInnerCore(cur_input, cur_bias, cur_output, inner_num - start_inner);
    start_outer += 1;
    cur_input += inner_num - start_inner;
    cur_bias = bias;
    cur_output += inner_num - start_inner;
  }
  if (start_outer == end_outer) {
    BiasAddByInnerCore(cur_input, cur_bias, cur_output, end_inner);
    return;
  } else {
    for (; start_outer < end_outer; ++start_outer) {
      BiasAddByInnerCore(cur_input, cur_bias, cur_output, inner_num);
      cur_input += inner_num;
      cur_output += inner_num;
    }
  }
  BiasAddByInnerCore(cur_input, cur_bias, cur_output, end_inner);
}

void BiasAddOpt(const float *input, const float *bias, float *output, int64_t start, int64_t end, int64_t inner_num,
                bool batch_priority) {
  if (batch_priority) {
    DoBiasAddByBatch(input, bias, output, start, end, inner_num);
  } else {
    DoBiasAddByInner(input, bias, output, start, end, inner_num);
  }
}
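The new routines treat the input as an [outer_num, inner_num] block with a bias of inner_num elements; BiasAddOpt is handed one thread's flat index range [start, end), and both Do* variants split that range into whole inner rows, with the batch-priority path pushing four (C4NUM) rows per iteration through BiasAddByBatchCore. Below is a minimal standalone sketch of that contract in plain C++ with no nnacl dependencies; BiasAddReference is a hypothetical name, and the scalar loop merely stands in for the SIMD paths above.

#include <cstdint>
#include <cstdio>
#include <vector>

// Reference behaviour of BiasAddOpt on the flat range [start, end):
// output[o * inner + i] = input[o * inner + i] + bias[i].
static void BiasAddReference(const float *input, const float *bias, float *output,
                             int64_t start, int64_t end, int64_t inner) {
  for (int64_t idx = start; idx < end; ++idx) {
    output[idx] = input[idx] + bias[idx % inner];
  }
}

int main() {
  const int64_t outer = 3, inner = 5;
  std::vector<float> input(outer * inner), bias(inner), output(outer * inner);
  for (int64_t i = 0; i < outer * inner; ++i) input[i] = static_cast<float>(i);
  for (int64_t i = 0; i < inner; ++i) bias[i] = 0.5f * static_cast<float>(i);
  // Two "threads", each owning part of the flat range, as the kernel above expects.
  BiasAddReference(input.data(), bias.data(), output.data(), 0, 8, inner);
  BiasAddReference(input.data(), bias.data(), output.data(), 8, outer * inner, inner);
  printf("output[7] = %.1f, output[14] = %.1f\n", output[7], output[14]);  // 8.0, 16.0
  return 0;
}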
@@ -0,0 +1,34 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_NNACL_FP32_BIAS_ADD_H_
#define MINDSPORE_NNACL_FP32_BIAS_ADD_H_

#include <stdint.h>
#include <stdbool.h>

#ifdef __cplusplus
extern "C" {
#endif

void BiasAddOpt(const float *input, const float *bias, float *output, int64_t start, int64_t end, int64_t inner_num,
                bool batch_priority);

#ifdef __cplusplus
};
#endif

#endif  // MINDSPORE_NNACL_FP32_BIAS_ADD_H_
@@ -726,7 +726,7 @@ void RowMajor2Col64Major(const float *src_ptr, float *dst_ptr, int row, int col)
for (int i = 0; i < all_block_num; i += cur_block) {
cur_block = MSMIN(C4NUM, all_block_num - i); // max_tile = 4
int dst_stride = cur_block * C16NUM;
int row_num = MSMIN(dst_stride, row - i * C8NUM);
int row_num = MSMIN(dst_stride, row - i * C16NUM);
const float *src = src_ptr + i * C16NUM * col;
float *dst = dst_ptr + i * C16NUM * col;
int r = 0;
@@ -2268,3 +2268,331 @@ void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float
}
}
}

#ifdef ENABLE_ARM64
void MatMul4x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep) {
  // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute
  // 9: WriteBack
  asm volatile(
    "mov x8, %[input]\n"
    "mov x9, %[weight]\n"
    "mov x10, %[deep]\n"
    "add x5, %[input], %[deep], LSL #2\n"
    "add x6, %[input], %[deep], LSL #3\n"
    "add x7, x5, %[deep], LSL #3\n"
    "dup v0.2d, xzr\n"
    "dup v1.2d, xzr\n"
    "dup v2.2d, xzr\n"
    "dup v3.2d, xzr\n"
    "subs x10, x10, #16\n"
    "blt 2f\n"
    "1:\n"  // LoopD16
    "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n"
    "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x5], #64\n"
    "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x6], #64\n"
    "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x7], #64\n"
    "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "fmla v2.4s, v20.4s, v28.4s\n"
    "fmla v3.4s, v24.4s, v28.4s\n"
    "fmla v0.4s, v5.4s, v29.4s\n"
    "fmla v1.4s, v17.4s, v29.4s\n"
    "fmla v2.4s, v21.4s, v29.4s\n"
    "fmla v3.4s, v25.4s, v29.4s\n"
    "fmla v0.4s, v6.4s, v30.4s\n"
    "fmla v1.4s, v18.4s, v30.4s\n"
    "fmla v2.4s, v22.4s, v30.4s\n"
    "fmla v3.4s, v26.4s, v30.4s\n"
    "fmla v0.4s, v7.4s, v31.4s\n"
    "fmla v1.4s, v19.4s, v31.4s\n"
    "fmla v2.4s, v23.4s, v31.4s\n"
    "fmla v3.4s, v27.4s, v31.4s\n"
    "subs x10, x10, #16\n"
    "bge 1b\n"
    "2:\n"  // LoopD12
    "adds x10, x10, #16\n"
    "cbz x10, 6f\n"
    "cmp x10, #12\n"
    "blt 3f\n"
    "ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n"
    "ld1 {v16.4s, v17.4s, v18.4s}, [x5], #48\n"
    "ld1 {v20.4s, v21.4s, v22.4s}, [x6], #48\n"
    "ld1 {v24.4s, v25.4s, v26.4s}, [x7], #48\n"
    "ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "fmla v2.4s, v20.4s, v28.4s\n"
    "fmla v3.4s, v24.4s, v28.4s\n"
    "fmla v0.4s, v5.4s, v29.4s\n"
    "fmla v1.4s, v17.4s, v29.4s\n"
    "fmla v2.4s, v21.4s, v29.4s\n"
    "fmla v3.4s, v25.4s, v29.4s\n"
    "fmla v0.4s, v6.4s, v30.4s\n"
    "fmla v1.4s, v18.4s, v30.4s\n"
    "fmla v2.4s, v22.4s, v30.4s\n"
    "fmla v3.4s, v26.4s, v30.4s\n"
    "sub x10, x10, #12\n"
    "b 7f\n"
    "3:\n"  // LoopD8
    "cmp x10, #8\n"
    "blt 4f\n"
    "ld1 {v4.4s, v5.4s}, [x8], #32\n"
    "ld1 {v16.4s, v17.4s}, [x5], #32\n"
    "ld1 {v20.4s, v21.4s}, [x6], #32\n"
    "ld1 {v24.4s, v25.4s}, [x7], #32\n"
    "ld1 {v28.4s, v29.4s}, [x9], #32\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "fmla v2.4s, v20.4s, v28.4s\n"
    "fmla v3.4s, v24.4s, v28.4s\n"
    "fmla v0.4s, v5.4s, v29.4s\n"
    "fmla v1.4s, v17.4s, v29.4s\n"
    "fmla v2.4s, v21.4s, v29.4s\n"
    "fmla v3.4s, v25.4s, v29.4s\n"
    "sub x10, x10, #8\n"
    "b 7f\n"
    "4:\n"  // LoopD4
    "cmp x10, #4\n"
    "blt 7f\n"
    "ld1 {v4.4s}, [x8], #16\n"
    "ld1 {v16.4s}, [x5], #16\n"
    "ld1 {v20.4s}, [x6], #16\n"
    "ld1 {v24.4s}, [x7], #16\n"
    "ld1 {v28.4s}, [x9], #16\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "fmla v2.4s, v20.4s, v28.4s\n"
    "fmla v3.4s, v24.4s, v28.4s\n"
    "sub x10, x10, #4\n"
    "7:\n"
    "cbz x10, 6f\n"
    "dup v4.2d, xzr\n"
    "dup v16.2d, xzr\n"
    "dup v20.2d, xzr\n"
    "dup v24.2d, xzr\n"
    "dup v28.2d, xzr\n"
    "subs x10, x10, #2\n"
    "blt 5f\n"
    "ld1 {v4.d}[0], [x8], #8\n"  // LoopD2
    "ld1 {v16.d}[0], [x5], #8\n"
    "ld1 {v20.d}[0], [x6], #8\n"
    "ld1 {v24.d}[0], [x7], #8\n"
    "ld1 {v28.d}[0], [x9], #8\n"
    "cbz x10, 8f\n"
    "5:\n"  // LoopD1
    "ld1 {v4.s}[2], [x8]\n"
    "ld1 {v16.s}[2], [x5]\n"
    "ld1 {v20.s}[2], [x6]\n"
    "ld1 {v24.s}[2], [x7]\n"
    "ld1 {v28.s}[2], [x9]\n"
    "8:\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "fmla v2.4s, v20.4s, v28.4s\n"
    "fmla v3.4s, v24.4s, v28.4s\n"
    "6:\n"
    "faddp v4.4s, v0.4s, v1.4s\n"
    "faddp v5.4s, v2.4s, v3.4s\n"
    "faddp v0.4s, v4.4s, v5.4s\n"
    "cbz %[bias], 9f\n"
    "ld1r {v1.4s}, [%[bias]]\n"
    "fadd v0.4s, v0.4s, v1.4s\n"
    "9:\n"
    "st1 {v0.4s}, [%[output]]\n"

    :
    : [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep)
    : "cc", "x5", "x6", "x7", "x8", "x9", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18",
      "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
}

void MatMul2x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep) {
  // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute
  // 9: WriteBack
  asm volatile(
    "mov x8, %[input]\n"
    "mov x9, %[weight]\n"
    "mov x10, %[deep]\n"
    "add x5, %[input], %[deep], LSL #2\n"
    "dup v0.2d, xzr\n"
    "dup v1.2d, xzr\n"
    "subs x10, x10, #16\n"
    "blt 2f\n"
    "1:\n"  // LoopD16
    "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n"
    "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x5], #64\n"
    "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "fmla v0.4s, v5.4s, v29.4s\n"
    "fmla v1.4s, v17.4s, v29.4s\n"
    "fmla v0.4s, v6.4s, v30.4s\n"
    "fmla v1.4s, v18.4s, v30.4s\n"
    "fmla v0.4s, v7.4s, v31.4s\n"
    "fmla v1.4s, v19.4s, v31.4s\n"
    "subs x10, x10, #16\n"
    "bge 1b\n"
    "2:\n"  // LoopD12
    "adds x10, x10, #16\n"
    "cbz x10, 6f\n"
    "cmp x10, #12\n"
    "blt 3f\n"
    "ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n"
    "ld1 {v16.4s, v17.4s, v18.4s}, [x5], #48\n"
    "ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "fmla v0.4s, v5.4s, v29.4s\n"
    "fmla v1.4s, v17.4s, v29.4s\n"
    "fmla v0.4s, v6.4s, v30.4s\n"
    "fmla v1.4s, v18.4s, v30.4s\n"
    "sub x10, x10, #12\n"
    "b 7f\n"
    "3:\n"  // LoopD8
    "cmp x10, #8\n"
    "blt 4f\n"
    "ld1 {v4.4s, v5.4s}, [x8], #32\n"
    "ld1 {v16.4s, v17.4s}, [x5], #32\n"
    "ld1 {v28.4s, v29.4s}, [x9], #32\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "fmla v0.4s, v5.4s, v29.4s\n"
    "fmla v1.4s, v17.4s, v29.4s\n"
    "sub x10, x10, #8\n"
    "b 7f\n"
    "4:\n"  // LoopD4
    "cmp x10, #4\n"
    "blt 7f\n"
    "ld1 {v4.4s}, [x8], #16\n"
    "ld1 {v16.4s}, [x5], #16\n"
    "ld1 {v28.4s}, [x9], #16\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "sub x10, x10, #4\n"
    "7:\n"
    "cbz x10, 6f\n"
    "dup v4.2d, xzr\n"
    "dup v16.2d, xzr\n"
    "subs x10, x10, #2\n"
    "blt 5f\n"
    "ld1 {v4.d}[0], [x8], #8\n"  // LoopD2
    "ld1 {v16.d}[0], [x5], #8\n"
    "ld1 {v28.d}[0], [x9], #8\n"
    "cbz x10, 8f\n"
    "5:\n"  // LoopD1
    "ld1 {v4.s}[2], [x8]\n"
    "ld1 {v16.s}[2], [x5]\n"
    "ld1 {v28.s}[2], [x9]\n"
    "8:\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "6:\n"
    "faddp v4.4s, v0.4s, v1.4s\n"
    "faddp v0.4s, v4.4s, v4.4s\n"
    "cbz %[bias], 9f\n"
    "ld1r {v1.4s}, [%[bias]]\n"
    "fadd v0.2s, v0.2s, v1.2s\n"
    "9:\n"
    "st1 {v0.2s}, [%[output]]\n"

    :
    : [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep)
    : "cc", "x5", "x8", "x9", "x10", "v0", "v1", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v28", "v29",
      "v30", "v31", "memory");
}

void MatMul1x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep) {
  // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute
  // 9: WriteBack
  asm volatile(
    "mov x8, %[input]\n"
    "mov x9, %[weight]\n"
    "mov x10, %[deep]\n"
    "dup v0.2d, xzr\n"
    "subs x10, x10, #16\n"
    "blt 2f\n"
    "1:\n"  // LoopD16
    "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n"
    "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v0.4s, v5.4s, v29.4s\n"
    "fmla v0.4s, v6.4s, v30.4s\n"
    "fmla v0.4s, v7.4s, v31.4s\n"
    "subs x10, x10, #16\n"
    "bge 1b\n"
    "2:\n"  // LoopD12
    "adds x10, x10, #16\n"
    "cbz x10, 6f\n"
    "cmp x10, #12\n"
    "blt 3f\n"
    "ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n"
    "ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v0.4s, v5.4s, v29.4s\n"
    "fmla v0.4s, v6.4s, v30.4s\n"
    "sub x10, x10, #12\n"
    "b 7f\n"
    "3:\n"  // LoopD8
    "cmp x10, #8\n"
    "blt 4f\n"
    "ld1 {v4.4s, v5.4s}, [x8], #32\n"
    "ld1 {v28.4s, v29.4s}, [x9], #32\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v0.4s, v5.4s, v29.4s\n"
    "sub x10, x10, #8\n"
    "b 7f\n"
    "4:\n"  // LoopD4
    "cmp x10, #4\n"
    "blt 7f\n"
    "ld1 {v4.4s}, [x8], #16\n"
    "ld1 {v28.4s}, [x9], #16\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "sub x10, x10, #4\n"
    "7:\n"
    "cbz x10, 6f\n"
    "dup v4.2d, xzr\n"
    "subs x10, x10, #2\n"
    "blt 5f\n"
    "ld1 {v4.d}[0], [x8], #8\n"  // LoopD2
    "ld1 {v28.d}[0], [x9], #8\n"
    "cbz x10, 8f\n"
    "5:\n"  // LoopD1
    "ld1 {v4.s}[3], [x8]\n"
    "ld1 {v28.s}[3], [x9]\n"
    "8:\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "6:\n"
    "faddp v4.4s, v0.4s, v0.4s\n"
    "faddp v0.4s, v4.4s, v4.4s\n"
    "cbz %[bias], 9f\n"
    "ld1 {v1.s}[0], [%[bias]]\n"
    "fadd s0, s0, s1\n"
    "9:\n"
    "st1 {v0.s}[0], [%[output]]\n"

    :
    : [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep)
    : "cc", "x8", "x9", "x10", "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v28", "v29", "v30", "v31");
}

void GemmIsNotPackByRow(const float *a, const float *b, float *c, const float *bias, int start_row, int end_row,
                        int deep) {
  const float *input = a + start_row * deep;
  float *output = c + start_row;
  const int step = C4NUM * deep;
  for (; start_row <= end_row - C4NUM; start_row += C4NUM) {
    MatMul4x1Kernel(input, b, output, bias, deep);
    input += step;
    output += C4NUM;
  }
  for (; start_row <= end_row - C2NUM; start_row += C2NUM) {
    MatMul2x1Kernel(input, b, output, bias, deep);
    input += C2NUM * deep;
    output += C2NUM;
  }
  if (start_row == end_row - 1) {
    MatMul1x1Kernel(input, b, output, bias, deep);
  }
}
#endif
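The three assembly kernels above compute the same thing at different row widths: for each row r handed to them by GemmIsNotPackByRow, output[r] is the dot product of input row r (deep elements) with the single weight column, plus an optional scalar bias, and rows are consumed four, then two, then one at a time. Below is a hedged scalar reference of that contract in standalone C++; GemmByRowReference is a hypothetical name, and C4NUM/C2NUM are assumed to be 4 and 2 as elsewhere in nnacl.

#include <cstdio>
#include <vector>

// Scalar equivalent of GemmIsNotPackByRow: rows [start_row, end_row) of a
// row-major matrix a (deep columns) times the single column b, plus bias[0].
static void GemmByRowReference(const float *a, const float *b, float *c, const float *bias,
                               int start_row, int end_row, int deep) {
  for (int r = start_row; r < end_row; ++r) {
    float acc = (bias != nullptr) ? bias[0] : 0.0f;
    for (int d = 0; d < deep; ++d) {
      acc += a[r * deep + d] * b[d];
    }
    c[r] = acc;
  }
}

int main() {
  const int rows = 7, deep = 3;
  std::vector<float> a(rows * deep, 1.0f), b(deep, 2.0f), c(rows, 0.0f);
  const float bias = 0.5f;
  GemmByRowReference(a.data(), b.data(), c.data(), &bias, 0, rows, deep);
  printf("c[0] = %.1f (expected 6.5)\n", c[0]);
  return 0;
}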
@@ -129,6 +129,10 @@ void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias,

void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k);

#ifdef ENABLE_ARM64
void GemmIsNotPackByRow(const float *a, const float *b, float *c, const float *bias, int start_row, int end_row,
                        int deep);
#endif
#ifdef __cplusplus
}
#endif
@@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

@@ -29,9 +29,7 @@ using mindspore::lite::RET_OP_EXECUTE_FAILURE;
using mindspore::schema::PrimitiveType_Transpose;

namespace mindspore::kernel {
void TransposeFp16CPUKernel::GetNchwToNhwcFunc(TypeId dtype) { NHNCTransposeFunc_ = PackNCHWToNHWCFp16; }

void TransposeFp16CPUKernel::GetNhwcToNchwFunc(TypeId dtype) { NHNCTransposeFunc_ = PackNHWCToNCHWFp16; }
void TransposeFp16CPUKernel::SetOptTransposeFunc() { optTransposeFunc_ = PackNHWCToNCHWFp16; }

int TransposeFp16CPUKernel::TransposeDim2to6() {
return DoTransposeFp16(static_cast<const float16_t *>(in_data_), static_cast<float16_t *>(out_data_), out_shape_,
@@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

@@ -32,8 +32,7 @@ class TransposeFp16CPUKernel : public TransposeCPUKernel {
~TransposeFp16CPUKernel() = default;

private:
void GetNchwToNhwcFunc(TypeId dtype) override;
void GetNhwcToNchwFunc(TypeId dtype) override;
void SetOptTransposeFunc() override;
int TransposeDim2to6() override;
int TransposeDimGreaterThan6(int task_id) override;
};
@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

@@ -16,6 +16,7 @@

#include "src/runtime/kernel/arm/fp32/bias_fp32.h"
#include <vector>
#include "nnacl/fp32/bias_add.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
|
|||
using mindspore::schema::PrimitiveType_BiasAdd;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
int BiasCPUKernel::ReSize() {
|
||||
auto dims = in_tensors_.at(0)->shape();
|
||||
bias_param_->ndim_ = dims.size();
|
||||
if (bias_param_->ndim_ < 1 || bias_param_->ndim_ > 5) {
|
||||
MS_LOG(ERROR) << "input shape is invalid";
|
||||
return RET_ERROR;
|
||||
int BiasAddRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
|
||||
CHECK_NULL_RETURN(cdata);
|
||||
auto kernel = reinterpret_cast<BiasCPUKernel *>(cdata);
|
||||
auto ret = kernel->DoExecute(task_id);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "BatchnormRun error task_id[" << task_id << "] error_code[" << ret << "]";
|
||||
}
for (size_t i = 0; i < bias_param_->ndim_; i++) {
bias_param_->in_shape0_[i] = dims[i];
bias_param_->in_shape1_[i] = 1;
bias_param_->out_shape_[i] = dims[i];
}
bias_param_->in_shape1_[bias_param_->ndim_ - 1] = dims[bias_param_->ndim_ - 1];
return RET_OK;
}

int BiasCPUKernel::Run() {
auto in = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
auto bias = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
auto out = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
size_t data_size = static_cast<size_t>(in_tensors_.at(0)->ElementsNum());
CHECK_NULL_RETURN(ms_context_->allocator);
float *tile_in = reinterpret_cast<float *>(ms_context_->allocator->Malloc(data_size * sizeof(float)));
float *tile_bias = reinterpret_cast<float *>(ms_context_->allocator->Malloc(data_size * sizeof(float)));
if (tile_in == nullptr || tile_bias == nullptr) {
MS_LOG(ERROR) << "Memory allocation failed";
ms_context_->allocator->Free(tile_in);
ms_context_->allocator->Free(tile_bias);
return RET_ERROR;
}
auto ret = BroadcastAdd(in, bias, tile_in, tile_bias, out, static_cast<int>(data_size), bias_param_);
ms_context_->allocator->Free(tile_in);
ms_context_->allocator->Free(tile_bias);
return ret;
}

@@ -73,5 +48,79 @@ int BiasCPUKernel::Prepare() {
return ReSize();
}

int BiasCPUKernel::ReSize() {
auto in_dims = in_tensors_.at(0)->shape();
auto bias_dims = in_tensors_.at(1)->shape();
if (bias_dims.empty() || in_dims.empty() || in_dims.size() < bias_dims.size()) {
MS_LOG(ERROR) << "inTensors' shape are invalid.";
return RET_ERROR;
}
size_t dim_offset = in_dims.size() - bias_dims.size();
inner_num_ = 1;
for (size_t i = 0; i < bias_dims.size(); ++i) {
if (in_dims[i + dim_offset] != bias_dims[i]) {
MS_LOG(ERROR) << "inTensors' shape cannot match.";
return RET_ERROR;
}
MS_CHECK_FALSE_MSG(INT_MUL_OVERFLOW(bias_dims[i], inner_num_), RET_ERROR, "mul overflow.");
inner_num_ *= bias_dims[i];
}
outer_num_ = 1;
for (size_t i = 0; i < dim_offset; ++i) {
MS_CHECK_FALSE_MSG(INT_MUL_OVERFLOW(in_dims[i], outer_num_), RET_ERROR, "mul overflow.");
outer_num_ *= in_dims[i];
}
MS_CHECK_FALSE_MSG(INT_MUL_OVERFLOW(inner_num_, outer_num_), RET_ERROR, "mul overflow.");
total_num_ = inner_num_ * outer_num_;
GetThreadSegmentInfos();
return RET_OK;
}

void BiasCPUKernel::GetThreadSegmentInfos() {
split_start_points_ = std::vector<int64_t>(op_parameter_->thread_num_, 0);
split_end_points_ = std::vector<int64_t>(op_parameter_->thread_num_, 0);
int64_t step = MSMAX(total_num_ / op_parameter_->thread_num_, C128NUM);
int64_t remain_data = MSMAX(total_num_ - step * op_parameter_->thread_num_, 0);
for (int i = 0; i < op_parameter_->thread_num_; ++i) {
if (i == 0) {
split_end_points_[i] = MSMIN(step, total_num_) + (i < remain_data ? 1 : 0);
continue;
}
split_start_points_[i] = split_end_points_[i - 1];
if (split_start_points_[i] >= total_num_) {
split_start_points_[i] = 0;
break;
}
split_end_points_[i] =
split_start_points_[i] + MSMIN(step, total_num_ - split_start_points_[i]) + (i < remain_data ? 1 : 0);
}
MS_ASSERT(inner_num_ != 0);
if (inner_num_ >= C64NUM && step / inner_num_ >= C6NUM) {
batch_priority_ = true;
} else {
batch_priority_ = false;
}
}

int BiasCPUKernel::Run() {
auto ret = ParallelLaunch(this->ms_context_, BiasAddRun, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "BiasAddRun error error_code[" << ret << "]";
}
return ret;
}

int BiasCPUKernel::DoExecute(int task_id) {
auto input = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
auto bias = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
auto output = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
if (split_start_points_[task_id] == split_end_points_[task_id]) {
return lite::RET_OK;
}
BiasAddOpt(input, bias, output, split_start_points_[task_id], split_end_points_[task_id], inner_num_,
batch_priority_);
return lite::RET_OK;
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BiasAdd, LiteKernelCreator<BiasCPUKernel>)
} // namespace mindspore::kernel
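As a worked example of GetThreadSegmentInfos above, assuming C128NUM is 128: with total_num_ = 1000 and four threads, step = max(250, 128) = 250 and the ranges come out as [0, 250), [250, 500), [500, 750), [750, 1000); with total_num_ = 200 the step is clamped to 128, so thread 0 gets [0, 128), thread 1 gets [128, 200), and threads 2 and 3 get empty ranges that DoExecute skips. The same arithmetic in a standalone illustrative sketch, outside the kernel class:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors the splitting logic of BiasCPUKernel::GetThreadSegmentInfos, with C128NUM taken as 128.
int main() {
  const int64_t total_num = 200;
  const int thread_num = 4;
  std::vector<int64_t> starts(thread_num, 0), ends(thread_num, 0);
  const int64_t step = std::max<int64_t>(total_num / thread_num, 128);
  const int64_t remain = std::max<int64_t>(total_num - step * thread_num, 0);
  for (int i = 0; i < thread_num; ++i) {
    if (i == 0) {
      ends[i] = std::min(step, total_num) + (i < remain ? 1 : 0);
      continue;
    }
    starts[i] = ends[i - 1];
    if (starts[i] >= total_num) {
      starts[i] = 0;  // later threads get an empty range
      break;
    }
    ends[i] = starts[i] + std::min(step, total_num - starts[i]) + (i < remain ? 1 : 0);
  }
  for (int i = 0; i < thread_num; ++i) {
    printf("thread %d: [%lld, %lld)\n", i, (long long)starts[i], (long long)ends[i]);
  }
  return 0;
}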
@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

@@ -33,9 +33,17 @@ class BiasCPUKernel : public InnerKernel {
int Prepare() override;
int ReSize() override;
int Run() override;
int DoExecute(int task_id);

private:
void GetThreadSegmentInfos();
ArithmeticParameter *bias_param_;
bool batch_priority_{false};
int64_t inner_num_{0};
int64_t outer_num_{0};
int64_t total_num_{0};
std::vector<int64_t> split_start_points_;
std::vector<int64_t> split_end_points_;
};
} // namespace mindspore::kernel
@@ -39,7 +39,7 @@ int MatmulRun(const void *cdata, int task_id, float, float) {
MatmulFp32BaseCPUKernel::~MatmulFp32BaseCPUKernel() {
FreeResizeBufA();
FreeResizeBufB();
if (is_pack_ && out_need_aligned_ && oc_res_ != 0 && output_data_ != nullptr) {
if (out_need_aligned_ && output_data_ != nullptr) {
free(output_data_);
output_data_ = nullptr;
}
@@ -287,6 +287,34 @@ int MatmulFp32BaseCPUKernel::ParallelRunByBatch(int task_id) const {
return RET_OK;
}

#if defined(ENABLE_AVX) || defined(ENABLE_AVX512) || defined(ENABLE_ARM64)
int MatmulFp32BaseCPUKernel::ParallelRunByRow(int task_id) const {
int start_row = row_split_points_[task_id];
int end_row = row_num_;
if (task_id < (thread_count_ - 1)) {
end_row = row_split_points_[task_id + 1];
}
int row_num = end_row - start_row;
if (row_num <= 0) {
return RET_OK;
}
#if defined(ENABLE_AVX512)
const float *input = a_pack_ptr_ + start_row * params_->deep_;
float *output = output_data_ + start_row * params_->col_align_;
MatMulAvx512Fp32(input, b_pack_ptr_, output, bias_ptr_, params_->act_type_, params_->deep_, params_->col_align_,
params_->col_align_, row_num);
#elif defined(ENABLE_AVX)
const float *input = a_pack_ptr_ + start_row * params_->deep_;
float *output = output_data_ + start_row * params_->col_align_;
MatMulAvxFp32(input, b_pack_ptr_, output, bias_ptr_, params_->act_type_, params_->deep_, params_->col_align_,
params_->col_align_, row_num);
#elif defined(ENABLE_ARM64)
GemmIsNotPackByRow(a_pack_ptr_, b_pack_ptr_, output_data_, bias_ptr_, start_row, end_row, params_->deep_);
#endif
return RET_OK;
}
#endif

int MatmulFp32BaseCPUKernel::ParallelRunIsNotPackByBatch(int task_id) const {
int start_batch = task_id * batch_stride_;
int end_batch = MSMIN(params_->batch, start_batch + batch_stride_);
@@ -295,8 +323,8 @@ int MatmulFp32BaseCPUKernel::ParallelRunIsNotPackByBatch(int task_id) const {
bias = bias_ptr_[0];
}
for (int index = start_batch; index < end_batch; ++index) {
const float *a = a_pack_ptr_ + index * params_->row_ * params_->deep_;
const float *b = b_pack_ptr_ + index * params_->deep_ * params_->col_;
const float *a = a_pack_ptr_ + a_offset_[index] * params_->row_ * params_->deep_;
const float *b = b_pack_ptr_ + b_offset_[index] * params_->deep_ * params_->col_;
float *c = output_data_ + index * params_->row_ * params_->col_;
gemmIsNotPackFun(a, b, c, &bias, params_->row_, params_->deep_);
}
@@ -375,28 +403,28 @@ int MatmulFp32BaseCPUKernel::init_global_variable() {
row_tile_ = C12NUM;
col_tile_ = C8NUM;
#endif
if (params_->col_ == 1 && !params_->a_const_) {
is_pack_ = false;
out_need_aligned_ = false;
row_tile_ = 1;
col_tile_ = 1;
matrix_a_pack_fun_ = params_->a_transpose_ ? RowMajor2ColMajor : RowMajor2RowMajor;
matrix_b_pack_fun_ = params_->b_transpose_ ? RowMajor2ColMajor : RowMajor2RowMajor;
}
params_->row_align_ = UP_ROUND(params_->row_, row_tile_);
params_->col_align_ = UP_ROUND(params_->col_, col_tile_);
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_, params_->row_align_, RET_ERROR);
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_ * params_->row_align_, params_->deep_, RET_ERROR);
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_, params_->col_align_, RET_ERROR);
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_ * params_->col_align_, params_->deep_, RET_ERROR);
matrix_a_pack_size_ = a_batch_ * params_->row_align_ * params_->deep_;
matrix_b_pack_size_ = b_batch_ * params_->col_align_ * params_->deep_;
#if defined(ENABLE_AVX) || defined(ENABLE_AVX512)
col_step_ = params_->col_align_;
#else
// need not aligned
col_step_ = params_->col_;
#endif
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_, params_->row_align_, RET_ERROR);
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_ * params_->row_align_, params_->deep_, RET_ERROR);
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_, params_->col_align_, RET_ERROR);
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_ * params_->col_align_, params_->deep_, RET_ERROR);
if (params_->col_ == 1 && params_->b_const_) {
is_pack_ = false;
matrix_a_pack_size_ = a_batch_ * params_->row_ * params_->deep_;
matrix_b_pack_size_ = b_batch_ * params_->col_ * params_->deep_;
matrix_a_pack_fun_ = params_->a_transpose_ ? RowMajor2ColMajor : RowMajor2RowMajor;
matrix_b_pack_fun_ = params_->b_transpose_ ? RowMajor2ColMajor : RowMajor2RowMajor;
} else {
matrix_a_pack_size_ = a_batch_ * params_->row_align_ * params_->deep_;
matrix_b_pack_size_ = b_batch_ * params_->col_align_ * params_->deep_;
}
return RET_OK;
}

@@ -455,6 +483,8 @@ int MatmulFp32BaseCPUKernel::Prepare() {

int MatmulFp32BaseCPUKernel::ReSize() {
ResizeParameter();
MS_CHECK_FALSE(INT_MUL_OVERFLOW(a_batch_, params_->row_), RET_ERROR);
row_num_ = a_batch_ * params_->row_;
matrix_a_pack_size_ = a_batch_ * params_->row_align_ * params_->deep_;
matrix_b_pack_size_ = b_batch_ * params_->col_align_ * params_->deep_;
if (matrix_a_pack_size_ < 0 || matrix_b_pack_size_ < 0) {
@@ -465,8 +495,12 @@ int MatmulFp32BaseCPUKernel::ReSize() {
if (op_parameter_->is_train_session_) {
set_workspace_size((matrix_a_pack_size_ + matrix_b_pack_size_) * static_cast<int>(sizeof(float)));
}
GetThreadCuttingPolicy();
auto ret = InitTmpOutBuffer();
auto ret = GetThreadCuttingPolicy();
if (ret != RET_OK) {
MS_LOG(ERROR) << "ThreadCuttingPolicy error!";
return ret;
}
ret = InitTmpOutBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "InitTmpOutBuffer error!";
return ret;
@@ -483,11 +517,11 @@ void MatmulFp32BaseCPUKernel::ResizeParameter() {
vec_matmul_ = false;
}
params_->row_align_ = UP_ROUND(params_->row_, row_tile_);
oc_res_ = params_->col_ % col_tile_;
out_need_aligned_ = (out_need_aligned_ && ((params_->col_ % col_tile_) != 0));
}

int MatmulFp32BaseCPUKernel::InitTmpOutBuffer() {
if (is_pack_ && out_need_aligned_ && oc_res_ != 0) {
if (out_need_aligned_) {
if (output_data_ != nullptr) {
free(output_data_);
}
@@ -505,12 +539,22 @@ int MatmulFp32BaseCPUKernel::InitTmpOutBuffer() {
return RET_OK;
}

void MatmulFp32BaseCPUKernel::GetThreadCuttingPolicy() {
if (params_->batch >= op_parameter_->thread_num_ || (params_->col_ == 1 && params_->b_const_)) {
int MatmulFp32BaseCPUKernel::GetThreadCuttingPolicy() {
if (params_->batch >= op_parameter_->thread_num_ || (params_->col_ == 1 && !params_->a_const_)) {
thread_count_ = op_parameter_->thread_num_;
batch_stride_ = UP_DIV(params_->batch, thread_count_);
batch_split_ = true;
parallel_fun_ = &MatmulFp32BaseCPUKernel::ParallelRunByBatch;
} else if (CheckThreadCuttingByRow()) {
#if defined(ENABLE_AVX) || defined(ENABLE_AVX512)
is_pack_ = !params_->b_const_;
batch_split_ = true;
parallel_fun_ = &MatmulFp32BaseCPUKernel::ParallelRunByRow;
GetThreadCuttingInfoByRow();
#else
MS_LOG(ERROR) << "current branch only support avx.";
|
||||
return RET_ERROR;
#endif
} else {
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(params_->col_align_, col_tile_));
#if defined(ENABLE_AVX) || defined(ENABLE_AVX512)  // thread tile by col_tile * C4NUM
@@ -521,21 +565,57 @@ void MatmulFp32BaseCPUKernel::GetThreadCuttingPolicy() {
batch_split_ = false;
parallel_fun_ = &MatmulFp32BaseCPUKernel::ParallelRunByOC;
}
if (params_->col_ == 1 && params_->b_const_) {
if (params_->col_ == 1 && !params_->a_const_) {
is_pack_ = false;
batch_split_ = true;
parallel_fun_ = &MatmulFp32BaseCPUKernel::ParallelRunIsNotPackByBatch;
if (params_->deep_ == 1) {
gemmIsNotPackFun = GemmIsNotPack;
} else {
gemmIsNotPackFun = GemmIsNotPackOptimize;
#ifdef ENABLE_ARM64
if (b_batch_ == 1) {
parallel_fun_ = &MatmulFp32BaseCPUKernel::ParallelRunByRow;
GetThreadCuttingInfoByRow();
}
#endif
}
}
return RET_OK;
}

bool MatmulFp32BaseCPUKernel::CheckThreadCuttingByRow() {
if (b_batch_ != C1NUM) {
return false;
}
#if defined(ENABLE_AVX) || defined(ENABLE_AVX512)
if (row_num_ >= op_parameter_->thread_num_) {
return true;
}
#endif
return false;
}

void MatmulFp32BaseCPUKernel::GetThreadCuttingInfoByRow() {
int row_step = MSMAX(row_num_ / op_parameter_->thread_num_, C64NUM);
int row_remaining = MSMAX(row_num_ - row_step * op_parameter_->thread_num_, 0);
row_split_points_.resize(op_parameter_->thread_num_);
for (size_t i = 0; i < row_split_points_.size(); ++i) {
if (i == 0) {
row_split_points_[i] = 0;
continue;
}
row_split_points_[i] =
MSMIN(row_split_points_[i - 1] + row_step + (static_cast<int>(i) < row_remaining ? 1 : 0), row_num_);
}
int unused_thread_num = std::count(row_split_points_.begin(), row_split_points_.end(), row_num_);
thread_count_ = op_parameter_->thread_num_ - unused_thread_num;
}

int MatmulFp32BaseCPUKernel::Run() {
auto out_data = reinterpret_cast<float *>(out_tensors_.front()->data());
CHECK_NULL_RETURN(out_data);
if (!is_pack_ || !out_need_aligned_ || oc_res_ == 0) {
if (!out_need_aligned_) {
output_data_ = out_data;
}
if (!params_->b_const_) {
@@ -557,9 +637,7 @@ int MatmulFp32BaseCPUKernel::Run() {
if (!params_->a_const_) {
auto a_ptr = reinterpret_cast<float *>(in_tensors_[0]->data());
CHECK_NULL_RETURN(a_ptr);
if (!is_pack_) {
a_pack_ptr_ = a_ptr;
} else {
if (is_pack_ || (params_->a_transpose_ && params_->deep_ != 1)) {
if (InitBufferA() != RET_OK) {
return RET_ERROR;
}
@@ -568,10 +646,12 @@ int MatmulFp32BaseCPUKernel::Run() {
MS_LOG(ERROR) << "InitMatrixA failed!";
return ret;
}
} else {
a_pack_ptr_ = a_ptr;
}
}

if (batch_split_ || !is_pack_) {
if (batch_split_) {
auto ret = ParallelLaunch(this->ms_context_, MatmulRun, this, thread_count_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "MatmulRun failed in split by batch";
@@ -590,11 +670,13 @@ int MatmulFp32BaseCPUKernel::Run() {
}
}

if (oc_res_ != 0 && out_need_aligned_ && is_pack_) {
if (out_need_aligned_) {
PackNHWCXToNHWCFp32(output_data_, out_data, params_->batch, params_->row_, params_->col_, col_tile_);
} else {
output_data_ = nullptr;
}
if (!params_->a_const_) {
if (is_pack_) {
if (is_pack_ || (params_->a_transpose_ && params_->deep_ != 1)) {
FreeResizeBufA();
} else {
a_pack_ptr_ = nullptr;
@@ -51,12 +51,14 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
int ReSize() override;
int Run() override;

#if defined(ENABLE_AVX) || defined(ENABLE_AVX512) || defined(ENABLE_ARM64)
int ParallelRunByRow(int task_id) const;
#endif
int ParallelRunByOC(int task_id) const;
int ParallelRunByBatch(int task_id) const;
int ParallelRunIsNotPackByBatch(int task_id) const;
using ParallelRun = int (MatmulFp32BaseCPUKernel::*)(int task_id) const;
ParallelRun parallel_fun_ = nullptr;
bool is_pack_ = true;

protected:
int InitBufferA();

@@ -74,7 +76,9 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
void FreeResizeBufB();
int CalBroadCastBiasDataElements();
int InitTmpOutBuffer();
void GetThreadCuttingPolicy();
int GetThreadCuttingPolicy();
bool CheckThreadCuttingByRow();
void GetThreadCuttingInfoByRow();

protected:
MatMulParameter *params_ = nullptr;

@@ -92,7 +96,6 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
private:
int col_tile_ = 0;
int row_tile_ = 0;
int oc_res_ = 0;
int batch_stride_ = 0;
int oc_stride_ = 0;
int thread_count_ = 0;

@@ -107,6 +110,7 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
MatrixPackFun matrix_a_pack_fun_ = nullptr;
MatrixPackFun matrix_b_pack_fun_ = nullptr;
bool batch_split_ = false;
bool is_pack_ = true;
bool out_need_aligned_ = false;
int col_step_ = 0;
#if defined(ENABLE_AVX) || defined(ENABLE_AVX512)

@@ -114,6 +118,8 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
GemvFun gemvCalFun = nullptr;
#endif
GemmIsNotPackFun gemmIsNotPackFun = nullptr;
int row_num_;
std::vector<int> row_split_points_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_MATMUL_FP32_BASE_H_
@@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -36,21 +36,31 @@ int TransposeCPUKernel::Prepare() {
}

int TransposeCPUKernel::ReSize() {
auto &inTensor = in_tensors_.front();
auto in_shape = inTensor->shape();
if (in_tensors_.size() == 2) {
param_->num_axes_ = in_tensors_.at(1)->ElementsNum();
}
int trans3d[3] = {0, 2, 1};
if (in_shape.size() > MAX_TRANSPOSE_DIM_SIZE) {
MS_LOG(ERROR) << "input shape out of range.";
return RET_ERROR;
}
int transNd[MAX_TRANSPOSE_DIM_SIZE] = {0, 2, 1};
int *perm_data = nullptr;
auto input_tensor = in_tensors_.at(kInputIndex);
if (input_tensor->shape().size() != static_cast<size_t>(param_->num_axes_)) {
if (input_tensor->shape().size() == 3 && param_->num_axes_ == 4) {
param_->num_axes_ = 3;
perm_data = trans3d;
} else {
return RET_OK;
perm_data = transNd;
if (input_tensor->shape().size() == C3NUM && param_->num_axes_ == C4NUM) {
param_->num_axes_ = C3NUM;
}
if (param_->num_axes_ == 0) {
for (int i = 0; i < static_cast<int>(in_shape.size()); ++i) {
transNd[i] = static_cast<int>(in_shape.size()) - 1 - i;
}
param_->num_axes_ = static_cast<int>(in_shape.size());
}
} else {
MS_ASSERT(in_tensors_.size() == 2);
MS_ASSERT(in_tensors_.size() == C2NUM);
auto perm_tensor = in_tensors_.at(1);
if (perm_tensor->data_type() != kNumberTypeInt32) {
MS_LOG(ERROR) << "Unsupported type id: " << perm_tensor->data_type() << " of perm tensor.";
@@ -59,30 +69,31 @@ int TransposeCPUKernel::ReSize() {
perm_data = reinterpret_cast<int *>(perm_tensor->data());
MSLITE_CHECK_PTR(perm_data);
}
if (param_->num_axes_ > MAX_TRANSPOSE_DIM_SIZE || param_->num_axes_ < 0) {
MS_LOG(ERROR) << "num_axes_ " << param_->num_axes_ << "is invalid.";
return RET_ERROR;
}
MS_CHECK_TRUE_MSG(param_->num_axes_ <= MAX_TRANSPOSE_DIM_SIZE, RET_ERROR, "transpose's perm is invalid.");
for (int i = 0; i < param_->num_axes_; ++i) {
param_->perm_[i] = perm_data[i];
}

for (int i = 0; i < param_->num_axes_; i++) {
if (param_->perm_[i] < 0 || param_->perm_[i] >= param_->num_axes_) {
MS_LOG(ERROR) << "Check perm failed.";
return RET_ERROR;
}
if (GetOptParameters() != RET_OK) {
MS_LOG(ERROR) << "cannot compute optimizer parameters.";
return RET_ERROR;
}
DecideIfOnlyCopy();
if (only_copy_) {
return RET_OK;
}
GetOptTransposeFunc();
if (optTransposeFunc_ != nullptr) {
return RET_OK;
}

auto &inTensor = in_tensors_.front();
auto &outTensor = out_tensors_.front();
auto in_shape = inTensor->shape();
auto out_shape = outTensor->shape();
param_->strides_[param_->num_axes_ - 1] = 1;
param_->out_strides_[param_->num_axes_ - 1] = 1;
param_->data_num_ = inTensor->ElementsNum();
MS_CHECK_LE(static_cast<size_t>(param_->num_axes_), in_shape.size(), RET_ERROR);
MS_CHECK_LE(static_cast<size_t>(param_->num_axes_), out_shape.size(), RET_ERROR);
MS_CHECK_TRUE_RET(static_cast<size_t>(param_->num_axes_) == in_shape.size(), RET_ERROR);
MS_CHECK_TRUE_RET(static_cast<size_t>(param_->num_axes_) == out_shape.size(), RET_ERROR);
for (int i = param_->num_axes_ - 2; i >= 0; i--) {
param_->strides_[i] = in_shape.at(i + 1) * param_->strides_[i + 1];
param_->out_strides_[i] = out_shape.at(i + 1) * param_->out_strides_[i + 1];
@@ -102,24 +113,104 @@ int TransposeCPUKernel::ReSize() {
return RET_OK;
}

int TransposeCPUKernel::GetOptParameters() {
auto in_shape = in_tensors_[0]->shape();
if (in_shape.size() != static_cast<size_t>(param_->num_axes_)) {
return RET_OK;
}
for (int i = 0; i < param_->num_axes_; i++) {
if (param_->perm_[i] < 0 || param_->perm_[i] >= param_->num_axes_) {
MS_LOG(ERROR) << "Check perm failed.";
return RET_ERROR;
}
}
std::vector<std::vector<int>> segments;
for (int i = 0; i < param_->num_axes_;) {
std::vector<int> segment{param_->perm_[i]};
++i;
for (; i < param_->num_axes_; ++i) {
if (param_->perm_[i] - 1 != param_->perm_[i - 1]) {
break;
}
segment.push_back(param_->perm_[i]);
}
segments.push_back(segment);
}
in_shape_opt_ = std::vector<int>(segments.size(), 1);
perm_opt_ = std::vector<int>(segments.size(), 0);
for (size_t i = 0; i < segments.size(); ++i) {
for (size_t j = 0; j < segments.size(); ++j) {
perm_opt_[i] += (segments[j].front() < segments[i].front() ? 1 : 0);
}
for (auto index : segments[i]) {
MS_CHECK_FALSE(INT_MUL_OVERFLOW(in_shape_opt_[perm_opt_[i]], in_shape[index]), RET_ERROR);
in_shape_opt_[perm_opt_[i]] *= in_shape[index];
}
}
return RET_OK;
}
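To illustrate GetOptParameters above: runs of consecutive perm values are grouped into segments and each segment's input axes are collapsed into one merged axis, so the transpose can be evaluated on fewer, larger dimensions. For example, perm {0, 2, 3, 1} on input shape {2, 3, 4, 5} produces segments {0}, {2, 3}, {1}, a merged perm_opt_ of {0, 2, 1} and a merged in_shape_opt_ of {2, 3, 20}. A standalone restatement of that grouping follows (plain C++, hypothetical example values, independent of the kernel class):

#include <cstdio>
#include <vector>

// Standalone restatement of the axis merging in GetOptParameters: consecutive
// perm values form one segment, and each segment's input axes are collapsed
// into a single merged axis.
int main() {
  const std::vector<int> perm = {0, 2, 3, 1};      // NCHW -> NHWC
  const std::vector<int> in_shape = {2, 3, 4, 5};  // example input shape
  std::vector<std::vector<int>> segments;
  for (size_t i = 0; i < perm.size();) {
    std::vector<int> segment{perm[i]};
    for (++i; i < perm.size() && perm[i] - 1 == perm[i - 1]; ++i) {
      segment.push_back(perm[i]);
    }
    segments.push_back(segment);
  }
  std::vector<int> shape_opt(segments.size(), 1), perm_opt(segments.size(), 0);
  for (size_t i = 0; i < segments.size(); ++i) {
    for (size_t j = 0; j < segments.size(); ++j) {
      perm_opt[i] += (segments[j].front() < segments[i].front()) ? 1 : 0;
    }
    for (int axis : segments[i]) {
      shape_opt[perm_opt[i]] *= in_shape[axis];
    }
  }
  for (size_t i = 0; i < perm_opt.size(); ++i) {
    printf("perm_opt[%zu] = %d, in_shape_opt[%zu] = %d\n", i, perm_opt[i], i, shape_opt[i]);
  }
  return 0;  // prints perm_opt {0, 2, 1}, in_shape_opt {2, 3, 20}
}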

void TransposeCPUKernel::DecideIfOnlyCopy() {
auto in_shape = in_tensors_[0]->shape();
int dim = 0;
if (in_shape.size() != static_cast<size_t>(param_->num_axes_) || perm_opt_.size() == 1) {
only_copy_ = true;
return;
}
dim = 0;
std::vector<int> need_trans_dims;
std::for_each(perm_opt_.begin(), perm_opt_.end(), [&dim, &need_trans_dims](int val) {
if (val != dim) {
need_trans_dims.push_back(dim);
}
++dim;
});
if (need_trans_dims.size() == C2NUM && need_trans_dims.back() - need_trans_dims.front() == C1NUM) {
if (in_shape_opt_[need_trans_dims.front()] == 1 || in_shape_opt_[need_trans_dims.back()] == 1) {
only_copy_ = true;
return;
}
}
only_copy_ = false;
}

void TransposeCPUKernel::SetOptTransposeFunc() { optTransposeFunc_ = PackNHWCToNCHWFp32; }

int TransposeCPUKernel::GetOptTransposeFunc() {
if (in_tensors_[0]->data_type() != kNumberTypeFloat32 || perm_opt_.size() > C3NUM || perm_opt_.size() < C2NUM) {
optTransposeFunc_ = nullptr;
return RET_OK;
}
bool trans_last_two_dim{true};
for (size_t i = 0; i < perm_opt_.size() - C2NUM; ++i) {
if (perm_opt_[i] != static_cast<int>(i)) {
trans_last_two_dim = false;
break;
}
}
if (!trans_last_two_dim) {
optTransposeFunc_ = nullptr;
return RET_OK;
}
SetOptTransposeFunc();
if (perm_opt_.size() == C2NUM) {
nhnc_param_[FIRST_INPUT] = 1;
nhnc_param_[SECOND_INPUT] = in_shape_opt_.front();
nhnc_param_[THIRD_INPUT] = in_shape_opt_.back();
} else {
nhnc_param_[FIRST_INPUT] = in_shape_opt_.front();
nhnc_param_[SECOND_INPUT] = in_shape_opt_[SECOND_INPUT];
nhnc_param_[THIRD_INPUT] = in_shape_opt_.back();
}
return RET_OK;
}

TransposeCPUKernel::~TransposeCPUKernel() {
if (this->out_shape_ != nullptr) {
free(this->out_shape_);
}
}

void TransposeCPUKernel::GetNchwToNhwcFunc(TypeId dtype) {
if (dtype == kNumberTypeFloat32) {
NHNCTransposeFunc_ = PackNCHWToNHWCFp32;
}
}

void TransposeCPUKernel::GetNhwcToNchwFunc(TypeId dtype) {
if (dtype == kNumberTypeFloat32) {
NHNCTransposeFunc_ = PackNHWCToNCHWFp32;
}
}

int TransposeCPUKernel::TransposeDim2to6() {
return DoTransposeFp32(static_cast<const float *>(in_data_), static_cast<float *>(out_data_), out_shape_, param_);
}
@@ -130,34 +221,35 @@ int TransposeCPUKernel::TransposeDimGreaterThan6(int task_id) {
return RET_OK;
}

int TransposeCPUKernel::GetNHNCTransposeFunc(const lite::Tensor *in_tensor, const lite::Tensor *out_tensor) {
if (in_tensor->shape().size() != 4) {
int TransposeCPUKernel::CopyInputToOutput() {
auto in_tensor = in_tensors().front();
CHECK_NULL_RETURN(in_tensor);
auto out_tensor = out_tensors().front();
CHECK_NULL_RETURN(out_tensor);
if (in_tensor->allocator() == nullptr || in_tensor->allocator() != out_tensor->allocator() ||
in_tensor->allocator() != ms_context_->allocator || op_parameter_->is_train_session_ ||
((in_tensor->IsGraphInput() || in_tensor->IsGraphOutput()) && out_tensor->IsGraphOutput())) {
CHECK_NULL_RETURN(out_tensor->data());
CHECK_NULL_RETURN(in_tensor->data());
MS_CHECK_FALSE(in_tensor->Size() == 0, RET_ERROR);
if (in_tensor->data() != out_tensor->data()) {
memcpy(out_tensor->data(), in_tensor->data(), in_tensor->Size());
}
return RET_OK;
}
auto out_shape = out_tensor->shape();
if (param_->perm_[FIRST_INPUT] == FIRST_INPUT && param_->perm_[SECOND_INPUT] == THIRD_INPUT &&
param_->perm_[THIRD_INPUT] == FOURTH_INPUT && param_->perm_[FOURTH_INPUT] == SECOND_INPUT) {
nhnc_param_[FIRST_INPUT] = out_shape[FIRST_INPUT];
MS_CHECK_FALSE(INT_MUL_OVERFLOW(out_shape[SECOND_INPUT], out_shape[THIRD_INPUT]), RET_ERROR);
nhnc_param_[SECOND_INPUT] = out_shape[SECOND_INPUT] * out_shape[THIRD_INPUT];
nhnc_param_[THIRD_INPUT] = out_shape[FOURTH_INPUT];
GetNchwToNhwcFunc(in_tensor->data_type());
}
if (param_->perm_[FIRST_INPUT] == FIRST_INPUT && param_->perm_[SECOND_INPUT] == FOURTH_INPUT &&
param_->perm_[THIRD_INPUT] == SECOND_INPUT && param_->perm_[FOURTH_INPUT] == THIRD_INPUT) {
nhnc_param_[FIRST_INPUT] = out_shape[FIRST_INPUT];
MS_CHECK_FALSE(INT_MUL_OVERFLOW(out_shape[THIRD_INPUT], out_shape[FOURTH_INPUT]), RET_ERROR);
nhnc_param_[SECOND_INPUT] = out_shape[THIRD_INPUT] * out_shape[FOURTH_INPUT];
nhnc_param_[THIRD_INPUT] = out_shape[SECOND_INPUT];
GetNhwcToNchwFunc(in_tensor->data_type());
}

out_tensor->FreeData();
out_tensor->ResetRefCount();
in_tensor->allocator()->IncRefCount(in_tensor->data(), out_tensor->ref_count());
out_tensor->set_data(in_tensor->data());
out_tensor->set_own_data(in_tensor->own_data());
return RET_OK;
}

int TransposeCPUKernel::RunImpl(int task_id) {
if (NHNCTransposeFunc_ != nullptr) {
NHNCTransposeFunc_(in_data_, out_data_, nhnc_param_[FIRST_INPUT], nhnc_param_[SECOND_INPUT],
nhnc_param_[THIRD_INPUT], task_id, op_parameter_->thread_num_);
if (optTransposeFunc_ != nullptr) {
optTransposeFunc_(in_data_, out_data_, nhnc_param_[FIRST_INPUT], nhnc_param_[SECOND_INPUT],
nhnc_param_[THIRD_INPUT], task_id, op_parameter_->thread_num_);
} else {
return TransposeDimGreaterThan6(task_id);
}
@@ -176,6 +268,9 @@ int TransposeImpl(void *kernel, int task_id, float lhs_scale, float rhs_scale) {
int TransposeCPUKernel::Run() {
MS_ASSERT(in_tensors_.size() == 1 || in_tensors_.size() == 2);
MS_ASSERT(out_tensors_.size() == 1);
if (only_copy_) {
return CopyInputToOutput();
}
auto &in_tensor = in_tensors_.front();
auto &out_tensor = out_tensors_.front();
if (in_tensor == nullptr || out_tensor == nullptr) {

@@ -186,16 +281,7 @@ int TransposeCPUKernel::Run() {
out_data_ = out_tensor->data();
CHECK_NULL_RETURN(in_data_);
CHECK_NULL_RETURN(out_data_);

if (in_tensor->shape().size() != static_cast<size_t>(param_->num_axes_)) {
memcpy(out_data_, in_data_, in_tensor->Size());
return RET_OK;
}
if (GetNHNCTransposeFunc(in_tensor, out_tensor) != RET_OK) {
MS_LOG(ERROR) << "Get NHWC tranpose func fail!";
return RET_ERROR;
}
if (NHNCTransposeFunc_ != nullptr) {
if (optTransposeFunc_ != nullptr) {
return ParallelLaunch(this->ms_context_, TransposeImpl, this, op_parameter_->thread_num_);
}
if (out_tensor->shape().size() <= DIMENSION_6D) {
@@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

@@ -43,18 +43,28 @@ class TransposeCPUKernel : public InnerKernel {
int RunImpl(int task_id);

protected:
virtual void GetNchwToNhwcFunc(TypeId dtype);
virtual void GetNhwcToNchwFunc(TypeId dtype);
virtual void SetOptTransposeFunc();
virtual int TransposeDim2to6();
virtual int TransposeDimGreaterThan6(int task_id);

int GetNHNCTransposeFunc(const lite::Tensor *in_tensor, const lite::Tensor *out_tensor);
private:
int GetOptParameters();
void DecideIfOnlyCopy();
int GetOptTransposeFunc();
int CopyInputToOutput();

protected:
void *in_data_ = nullptr;
void *out_data_ = nullptr;
int *out_shape_ = nullptr;
TransposeParameter *param_ = nullptr;
TransposeFunc NHNCTransposeFunc_ = nullptr;
TransposeFunc optTransposeFunc_ = nullptr;

private:
int nhnc_param_[3] = {0};
bool only_copy_{false};
std::vector<int> in_shape_opt_;
std::vector<int> perm_opt_;
};
} // namespace mindspore::kernel