forked from mindspore-Ecosystem/mindspore
!19001 Support fp16-fp32 cast on devices without the fp16 feature
Merge pull request !19001 from hangq/wood
This commit is contained in:
commit
4f0a30a3cf
|
@ -34,6 +34,7 @@
|
|||
#include "src/runtime/infer_manager.h"
|
||||
#include "src/sub_graph_split.h"
|
||||
#include "src/weight_decoder.h"
|
||||
#include "nnacl/nnacl_common.h"
|
||||
#if GPU_OPENCL
|
||||
#include "src/runtime/kernel/opencl/opencl_subgraph.h"
|
||||
#include "src/runtime/gpu/opencl/opencl_runtime.h"
|
||||
|
@ -271,7 +272,6 @@ int Scheduler::InferSubGraphShape(size_t subgraph_index) {
|
|||
|
||||
namespace {
|
||||
int CastConstTensorData(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_origin_tensors, TypeId dst_data_type) {
|
||||
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
|
||||
MS_ASSERT(tensor != nullptr);
|
||||
MS_ASSERT(tensor->IsConst());
|
||||
MS_ASSERT(tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeFloat16);
|
||||
|
@ -294,9 +294,25 @@ int CastConstTensorData(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_o
|
|||
auto new_tensor_data = tensor->data_c();
|
||||
MS_ASSERT(new_tensor_data != nullptr);
|
||||
if (dst_data_type == kNumberTypeFloat32) {
|
||||
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
|
||||
Float16ToFloat32_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum());
|
||||
#else
|
||||
auto src_data = reinterpret_cast<uint16_t *>(origin_data);
|
||||
auto dst_data = reinterpret_cast<float *>(new_tensor_data);
|
||||
for (int i = 0; i < tensor->ElementsNum(); i++) {
|
||||
dst_data[i] = ShortToFloat32(src_data[i]);
|
||||
}
|
||||
#endif
|
||||
} else { // dst_data_type == kNumberTypeFloat16
|
||||
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
|
||||
Float32ToFloat16_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum());
|
||||
#else
|
||||
auto src_data = reinterpret_cast<float *>(origin_data);
|
||||
auto dst_data = reinterpret_cast<uint16_t *>(new_tensor_data);
|
||||
for (int i = 0; i < tensor->ElementsNum(); i++) {
|
||||
dst_data[i] = Float32ToShort(src_data[i]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
if (restored_origin_tensors->find(tensor) != restored_origin_tensors->end()) {
|
||||
MS_LOG(ERROR) << "Tensor " << tensor->tensor_name() << " is already be stored";
|
||||
|
@ -304,9 +320,6 @@ int CastConstTensorData(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_o
|
|||
}
|
||||
(*restored_origin_tensors)[tensor] = restore_tensor;
|
||||
return RET_OK;
|
||||
#else
|
||||
return RET_NOT_SUPPORT;
|
||||
#endif
|
||||
}
|
||||
|
||||
int CastConstTensorsData(const std::vector<Tensor *> &tensors, std::map<Tensor *, Tensor *> *restored_origin_tensors,
|
||||
|
|
Loading…
Reference in New Issue