From 542bfaa49fc0583af809ad2c5078d183a119c2bc Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Mon, 28 Jun 2021 15:43:59 +0800 Subject: [PATCH] support fp16 cast in non-fp16-feature-device --- mindspore/lite/src/scheduler.cc | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index 7e7a1cbb759..8efabf691b1 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -34,6 +34,7 @@ #include "src/runtime/infer_manager.h" #include "src/sub_graph_split.h" #include "src/weight_decoder.h" +#include "nnacl/nnacl_common.h" #if GPU_OPENCL #include "src/runtime/kernel/opencl/opencl_subgraph.h" #include "src/runtime/gpu/opencl/opencl_runtime.h" @@ -271,7 +272,6 @@ int Scheduler::InferSubGraphShape(size_t subgraph_index) { namespace { int CastConstTensorData(Tensor *tensor, std::map *restored_origin_tensors, TypeId dst_data_type) { -#if defined(ENABLE_ARM) && defined(ENABLE_FP16) MS_ASSERT(tensor != nullptr); MS_ASSERT(tensor->IsConst()); MS_ASSERT(tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeFloat16); @@ -294,9 +294,25 @@ int CastConstTensorData(Tensor *tensor, std::map *restored_o auto new_tensor_data = tensor->data_c(); MS_ASSERT(new_tensor_data != nullptr); if (dst_data_type == kNumberTypeFloat32) { +#if defined(ENABLE_ARM) && defined(ENABLE_FP16) Float16ToFloat32_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum()); +#else + auto src_data = reinterpret_cast(origin_data); + auto dst_data = reinterpret_cast(new_tensor_data); + for (int i = 0; i < tensor->ElementsNum(); i++) { + dst_data[i] = ShortToFloat32(src_data[i]); + } +#endif } else { // dst_data_type == kNumberTypeFloat16 +#if defined(ENABLE_ARM) && defined(ENABLE_FP16) Float32ToFloat16_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum()); +#else + auto src_data = reinterpret_cast(origin_data); + auto dst_data = reinterpret_cast(new_tensor_data); + for (int i = 0; i < tensor->ElementsNum(); i++) { + dst_data[i] = Float32ToShort(src_data[i]); + } +#endif } if (restored_origin_tensors->find(tensor) != restored_origin_tensors->end()) { MS_LOG(ERROR) << "Tensor " << tensor->tensor_name() << " is already be stored"; @@ -304,9 +320,6 @@ int CastConstTensorData(Tensor *tensor, std::map *restored_o } (*restored_origin_tensors)[tensor] = restore_tensor; return RET_OK; -#else - return RET_NOT_SUPPORT; -#endif } int CastConstTensorsData(const std::vector &tensors, std::map *restored_origin_tensors,