forked from mindspore-Ecosystem/mindspore
[MS][LITE] optimize arm cpu fp16 op: conv depthwise fp16
This commit is contained in:
parent 124b35dba1
commit b7fa2cd6a9

src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc

@@ -16,6 +16,7 @@
 #include "src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.h"
 #include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h"
+#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h"
 #include "schema/model_generated.h"
 #include "src/kernel_registry.h"
 #include "include/errorcode.h"

@@ -177,10 +178,22 @@ int ConvolutionDepthwiseFp16CPUKernel::Run() {
   }

   auto input_tensor = in_tensors_.at(kInputIndex);
-  auto input_addr = reinterpret_cast<float *>(input_tensor->Data());
+  float16_t *input_addr;
+  if (input_tensor->data_type() == kNumberTypeFloat32) {
+    input_addr =
+      reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
+    if (input_addr == nullptr) {
+      MS_LOG(ERROR) << "Malloc buffer failed.";
+      return RET_ERROR;
+    }
+    Float32ToFloat16(reinterpret_cast<float *>(input_tensor->Data()), input_addr, input_tensor->ElementsNum());
+  } else {
+    input_addr = reinterpret_cast<float16_t *>(input_tensor->Data());
+  }

   // pack input: to nhwc8
-  PackNHWCFp32ToNHWC8Fp16(input_addr, packed_input_, conv_param_->input_batch_,
-                          conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_);
+  PackNHWCToNHWC8Fp16(input_addr, packed_input_, conv_param_->input_batch_,
+                      conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_);

   ret = LiteBackendParallelLaunch(ConvDwFp16Run, this, conv_param_->thread_num_);
   if (ret != RET_OK) {
@@ -188,10 +201,13 @@ int ConvolutionDepthwiseFp16CPUKernel::Run() {
     return RET_ERROR;
   }

-  auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->Data());
-  PackNHWC8Fp16ToNHWCFp32(packed_output_, output_addr, conv_param_->output_batch_,
-                          conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
+  auto output_addr = reinterpret_cast<float16_t *>(out_tensors_.at(kOutputIndex)->Data());
+  PackNHWC8ToNHWCFp16(packed_output_, output_addr, conv_param_->output_batch_,
+                      conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
+
+  if (input_tensor->data_type() == kNumberTypeFloat32) {
+    context_->allocator->Free(input_addr);
+  }
   return RET_OK;
 }
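
Note on the Run() hunks above: the old path assumed an fp32 input and fused the fp32-to-fp16 cast into the packing step. After this change the cast happens at most once, up front: if the input tensor is kNumberTypeFloat32 the data is converted into a scratch buffer from context_->allocator (freed after unpacking), otherwise the tensor memory is used as fp16 directly, and both packing and unpacking now move pure fp16 data. Float32ToFloat16 comes from the newly included cast_fp16.h and its body is not part of this diff; the following is only a minimal scalar sketch of its semantics (the real nnacl routine may use NEON vcvt), assuming an AArch64 toolchain where arm_fp16.h defines float16_t:

// Sketch only: element-wise narrowing from fp32 to fp16, as Run() relies on.
#include <arm_fp16.h>  // assumption: AArch64 with fp16 support

void Float32ToFloat16Sketch(const float *input, float16_t *output, int number) {
  for (int i = 0; i < number; ++i) {
    output[i] = (float16_t)input[i];  // round each fp32 value to half precision
  }
}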

src/runtime/kernel/arm/nnacl/fp16/pack_fp16.c

@@ -334,31 +334,57 @@ void PackNCHWFp32ToNC8HW8Fp16(float *src, float16_t *dst, int batch, int plane,
 }

 void PackNHWCFp32ToNHWC8Fp16(float *src, float16_t *dst, int batch, int plane, int channel) {
-  int c8 = UP_DIV(channel, C8NUM);
-  int nhwc8_batch_unit_offset = c8 * C8NUM * plane;
-  int nhwc8_batch_offset = 0;
+  int c8_channel = UP_DIV(channel, C8NUM) * C8NUM;
   for (int b = 0; b < batch; b++) {
-    int batch_offset = b * channel * plane;
+    float16_t *dst_batch = dst + b * plane * c8_channel;
+    float *src_batch = src + b * plane * channel;
     for (int i = 0; i < plane; i++) {
+      float16_t *dst_plane = dst_batch + i * c8_channel;
+      float *src_plane = src_batch + i * channel;
       for (int c = 0; c < channel; c++) {
-        (dst + nhwc8_batch_offset + i * c8 * C8NUM)[c] = (float16_t)(src + batch_offset + i * channel)[c];
+        dst_plane[c] = (float16_t)(src_plane[c]);
       }
     }
-    nhwc8_batch_offset += nhwc8_batch_unit_offset;
   }
 }

 void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, int channel) {
-  int c8 = UP_DIV(channel, C8NUM);
-  int nhwc_batch_unit_offset = channel * plane;
-  int nhwc_batch_offset = 0;
+  int c8_channel = UP_DIV(channel, C8NUM) * C8NUM;
   for (int b = 0; b < batch; b++) {
-    int batch_offset = b * c8 * C8NUM * plane;
+    float16_t *src_batch = src + b * plane * c8_channel;
+    float *dst_batch = dst + b * plane * channel;
     for (int i = 0; i < plane; i++) {
+      float16_t *src_plane = src_batch + i * c8_channel;
+      float *dst_plane = dst_batch + i * channel;
       for (int c = 0; c < channel; c++) {
-        (dst + nhwc_batch_offset + i * channel)[c] = (float)(src + batch_offset + i * c8 * C8NUM)[c];
+        dst_plane[c] = (float)(src_plane[c]);
       }
     }
-    nhwc_batch_offset += nhwc_batch_unit_offset;
   }
 }

+void PackNHWCToNHWC8Fp16(float16_t *src, float16_t *dst, int batch, int plane, int channel) {
+  int c8_channel = UP_DIV(channel, C8NUM) * C8NUM;
+  for (int b = 0; b < batch; b++) {
+    float16_t *dst_batch = dst + b * plane * c8_channel;
+    float16_t *src_batch = src + b * plane * channel;
+    for (int i = 0; i < plane; i++) {
+      float16_t *dst_plane = dst_batch + i * c8_channel;
+      float16_t *src_plane = src_batch + i * channel;
+      memcpy(dst_plane, src_plane, channel * sizeof(float16_t));
+    }
+  }
+}
+
+void PackNHWC8ToNHWCFp16(float16_t *src, float16_t *dst, int batch, int plane, int channel) {
+  int c8_channel = UP_DIV(channel, C8NUM) * C8NUM;
+  for (int b = 0; b < batch; b++) {
+    float16_t *src_batch = src + b * plane * c8_channel;
+    float16_t *dst_batch = dst + b * plane * channel;
+    for (int i = 0; i < plane; i++) {
+      float16_t *src_plane = src_batch + i * c8_channel;
+      float16_t *dst_plane = dst_batch + i * channel;
+      memcpy(dst_plane, src_plane, channel * sizeof(float16_t));
+    }
+  }
+}
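
Note on the new pack routines: NHWC8 pads each pixel's channel vector up to a multiple of C8NUM = 8 so the fp16 NEON kernels can always load full 8-lane vectors; the pad lanes are left untouched, so callers that care must zero the destination first. (Two small fixes are folded in above: the memcpy source must be src_plane, not src_batch, and PackNHWC8Fp16ToNHWCFp32 writes float, so its cast is (float).) A hypothetical standalone check, with sizes and names chosen purely for illustration, showing the layout PackNHWCToNHWC8Fp16 produces for channel = 3:

// Illustrative check of the NHWC -> NHWC8 layout; not part of the commit.
#include <arm_fp16.h>
#include <cstdio>
#include <cstring>

#define C8NUM 8
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))

// declared in pack_fp16.h inside an extern "C" block
extern "C" void PackNHWCToNHWC8Fp16(float16_t *src, float16_t *dst, int batch, int plane, int channel);

int main() {
  const int batch = 1, plane = 2, channel = 3;
  float16_t src[6] = {1, 2, 3, 4, 5, 6};            // two pixels, three channels each
  float16_t dst[2 * UP_DIV(3, C8NUM) * C8NUM];      // each pixel padded to 8 lanes
  std::memset(dst, 0, sizeof(dst));                 // the pack routine does not clear pad lanes
  PackNHWCToNHWC8Fp16(src, dst, batch, plane, channel);
  // prints: 1 2 3 0 0 0 0 0 4 5 6 0 0 0 0 0
  for (int i = 0; i < 16; ++i) std::printf("%g ", (float)dst[i]);
  std::printf("\n");
  return 0;
}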

src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h

@@ -58,6 +58,10 @@ void PackNCHWFp32ToNC8HW8Fp16(float *src, float16_t *dst, int batch, int plane,
 void PackNHWCFp32ToNHWC8Fp16(float *src, float16_t *dst, int batch, int plane, int channel);

 void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, int channel);

+void PackNHWCToNHWC8Fp16(float16_t *src, float16_t *dst, int batch, int plane, int channel);
+
+void PackNHWC8ToNHWCFp16(float16_t *src, float16_t *dst, int batch, int plane, int channel);
+
 #ifdef __cplusplus
 }
 #endif
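
The two added declarations are the fp16-only counterparts of the existing fp32<->fp16 pack pair; because they sit inside the header's extern "C" block shown above, the C++ kernel can call them directly. Continuing the illustrative example from the previous note, a round-trip through both new routines should reproduce the original NHWC data, since PackNHWC8ToNHWCFp16 simply skips the pad lanes:

// Round-trip sketch (same illustrative buffers as above): unpack undoes pack.
float16_t back[6] = {0};
PackNHWC8ToNHWCFp16(dst, back, /*batch=*/1, /*plane=*/2, /*channel=*/3);
// back now holds 1 2 3 4 5 6, identical to src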