From 058519c9f9e57defc7dc1d4b2c85812d1d15cb04 Mon Sep 17 00:00:00 2001 From: sunsuodong Date: Thu, 24 Sep 2020 11:40:31 +0800 Subject: [PATCH] fix_split_fp16 --- .../src/runtime/kernel/arm/fp16/split_fp16.cc | 50 ++++++++----------- .../src/runtime/kernel/arm/fp16/split_fp16.h | 3 +- 2 files changed, 22 insertions(+), 31 deletions(-) diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/split_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/split_fp16.cc index a2d8dc34d82..ccbd265ad01 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/split_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/split_fp16.cc @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "nnacl/fp16/cast_fp16.h" -#include "nnacl/fp16/split_fp16.h" #include "src/runtime/kernel/arm/fp16/split_fp16.h" +#include "src/runtime/kernel/arm/fp16/common_fp16.h" #include "src/runtime/kernel/arm/base/split_base.h" +#include "nnacl/fp16/split_fp16.h" #include "nnacl/split.h" #include "nnacl/split_parameter.h" #include "src/kernel_registry.h" @@ -36,9 +36,10 @@ int SplitFp16CPUKernel::Init() { if (ret != RET_OK) { return ret; } - output_ptr_.resize(param->num_split_); - + for (size_t i = 0; i < output_ptr_.size(); i++) { + output_ptr_[i] = nullptr; + } if (!InferShapeDone()) { return RET_OK; } @@ -79,48 +80,37 @@ int SplitFp16CPUKernel::Run() { MS_LOG(ERROR) << "Prepare failed."; return RET_ERROR; } - auto in_tensor = in_tensors_.front(); - if (in_tensor->data_type() == kNumberTypeFloat32) { - input_ptr_ = - reinterpret_cast(context_->allocator->Malloc(in_tensor->ElementsNum() * sizeof(float16_t))); - if (input_ptr_ == nullptr) { - MS_LOG(ERROR) << "malloc input_ptr_ failed."; - return RET_ERROR; - } - Float32ToFloat16(reinterpret_cast(in_tensor->MutableData()), input_ptr_, in_tensor->ElementsNum()); - } else { - input_ptr_ = reinterpret_cast(in_tensor->MutableData()); + input_ptr_ = ConvertInputFp32toFp16(in_tensors_.at(0), context_); + if (input_ptr_ == nullptr) { + MS_LOG(ERROR) << "input or output is nullptr"; + return RET_ERROR; } for (int i = 0; i < param->num_split_; i++) { - if (in_tensor->data_type() == kNumberTypeFloat32) { - output_ptr_[i] = reinterpret_cast( - context_->allocator->Malloc(out_tensors_.at(i)->ElementsNum() * sizeof(float16_t))); - if (output_ptr_[i] == nullptr) { - MS_LOG(ERROR) << "malloc output_ptr_[" << i << "]" << " failed."; - return RET_ERROR; - } - Float32ToFloat16(reinterpret_cast(out_tensors_.at(i)->MutableData()), output_ptr_[i], - out_tensors_.at(i)->ElementsNum()); - } else { - output_ptr_[i] = reinterpret_cast(out_tensors_.at(i)->MutableData()); + output_ptr_[i] = MallocOutputFp16(out_tensors_.at(i), context_); + if (output_ptr_[i] == nullptr) { + FreeInputAndOutput(); + MS_LOG(ERROR) << "input or output is nullptr"; + return RET_ERROR; } } ret = ParallelLaunch(this->context_->thread_pool_, SplitRun, this, thread_n_num_); if (ret != RET_OK) { MS_LOG(ERROR) << "split error error_code[" << ret << "]"; - return RET_ERROR; } - if (in_tensor->data_type() == kNumberTypeFloat32) { + return ret; +} + +void SplitFp16CPUKernel::FreeInputAndOutput() { + if (in_tensors_.at(0)->data_type() == kNumberTypeFloat32) { context_->allocator->Free(input_ptr_); input_ptr_ = nullptr; } for (int i = 0; i < param->num_split_; i++) { - if (in_tensor->data_type() == kNumberTypeFloat32) { + if (out_tensors_.at(i)->data_type() == kNumberTypeFloat32) { context_->allocator->Free(output_ptr_[i]); output_ptr_[i] = nullptr; } } - return RET_OK; } kernel::LiteKernel *CpuSplitFp16KernelCreator(const std::vector &inputs, diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/split_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/split_fp16.h index 1bf3f4a57cb..e10bbcea604 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/split_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/split_fp16.h @@ -37,8 +37,9 @@ class SplitFp16CPUKernel : public SplitBaseCPUKernel { int Split(int task_id); private: - float16_t *input_ptr_; + float16_t *input_ptr_ = nullptr; std::vector output_ptr_; + void FreeInputAndOutput(); }; } // namespace mindspore::kernel