!19001 support fp16-fp32 cast on non-fp16-feature-device

Merge pull request !19001 from hangq/wood
This commit is contained in:
i-robot 2021-06-28 09:01:43 +00:00 committed by Gitee
commit 4f0a30a3cf
1 changed files with 17 additions and 4 deletions

View File

@ -34,6 +34,7 @@
#include "src/runtime/infer_manager.h"
#include "src/sub_graph_split.h"
#include "src/weight_decoder.h"
#include "nnacl/nnacl_common.h"
#if GPU_OPENCL
#include "src/runtime/kernel/opencl/opencl_subgraph.h"
#include "src/runtime/gpu/opencl/opencl_runtime.h"
@ -271,7 +272,6 @@ int Scheduler::InferSubGraphShape(size_t subgraph_index) {
namespace {
int CastConstTensorData(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_origin_tensors, TypeId dst_data_type) {
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
MS_ASSERT(tensor != nullptr);
MS_ASSERT(tensor->IsConst());
MS_ASSERT(tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeFloat16);
@ -294,9 +294,25 @@ int CastConstTensorData(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_o
auto new_tensor_data = tensor->data_c();
MS_ASSERT(new_tensor_data != nullptr);
if (dst_data_type == kNumberTypeFloat32) {
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
Float16ToFloat32_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum());
#else
auto src_data = reinterpret_cast<uint16_t *>(origin_data);
auto dst_data = reinterpret_cast<float *>(new_tensor_data);
for (int i = 0; i < tensor->ElementsNum(); i++) {
dst_data[i] = ShortToFloat32(src_data[i]);
}
#endif
} else { // dst_data_type == kNumberTypeFloat16
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
Float32ToFloat16_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum());
#else
auto src_data = reinterpret_cast<float *>(origin_data);
auto dst_data = reinterpret_cast<uint16_t *>(new_tensor_data);
for (int i = 0; i < tensor->ElementsNum(); i++) {
dst_data[i] = Float32ToShort(src_data[i]);
}
#endif
}
if (restored_origin_tensors->find(tensor) != restored_origin_tensors->end()) {
MS_LOG(ERROR) << "Tensor " << tensor->tensor_name() << " is already be stored";
@ -304,9 +320,6 @@ int CastConstTensorData(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_o
}
(*restored_origin_tensors)[tensor] = restore_tensor;
return RET_OK;
#else
return RET_NOT_SUPPORT;
#endif
}
int CastConstTensorsData(const std::vector<Tensor *> &tensors, std::map<Tensor *, Tensor *> *restored_origin_tensors,