!19039 fix bug while checking device and package support fp16
Merge pull request !19039 from hangq/wood
commit e14c86cdb5

@@ -17,6 +17,8 @@
#define MINDSPORE_NNACL_CAST_FP16_H_
#include "nnacl/op_base.h"
+ #if defined(ENABLE_ARM) && defined(ENABLE_FP16)
#include <arm_neon.h>

#ifdef __cplusplus
extern "C" {

@@ -88,4 +90,5 @@ void Float16ToFloat32(const float16_t *input, float *output, int number);
#ifdef __cplusplus
}
#endif
+ #endif
#endif  // MINDSPORE_NNACL_CAST_FP16_H_

@@ -103,6 +103,9 @@ void CpuInfo::GetArmProcCpuInfo(AndroidCpuInfo *android_cpu_info) {
}

bool CpuInfo::ArmIsSupportFp16() {
+ #ifdef MS_COMPILE_IOS
+ return false;
+ #else
#ifdef ENABLE_ARM32
GetArmProcCpuInfo(&android_cpu_info_);
midr_ = MidrSetPart(android_cpu_info_.cpu_part);

@@ -142,6 +145,7 @@ bool CpuInfo::ArmIsSupportFp16() {
}
#endif
return fp16_flag_;
+ #endif
}
}  // namespace mindspore::lite
#endif
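
The iOS early-return added above exists because the /proc/cpuinfo and MIDR probing behind GetArmProcCpuInfo is Linux/Android specific. As a point of comparison only, not the code path used in this patch, the same device capability can be probed on Linux arm64 through the hwcap auxiliary vector:

#if defined(__aarch64__) && defined(__linux__)
#include <sys/auxv.h>   // getauxval
#include <asm/hwcap.h>  // HWCAP_FPHP, HWCAP_ASIMDHP
#endif

// Hedged sketch, Linux/arm64 only: true when the CPU advertises both scalar
// (FPHP) and vector (ASIMDHP) half-precision arithmetic.
static bool LinuxArm64SupportsFp16() {
#if defined(__aarch64__) && defined(__linux__)
  unsigned long hwcap = getauxval(AT_HWCAP);
  return (hwcap & HWCAP_FPHP) != 0 && (hwcap & HWCAP_ASIMDHP) != 0;
#else
  return false;
#endif
}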

@@ -31,11 +31,11 @@ InnerContext::InnerContext(const Context *context) {
this->thread_num_ = context->thread_num_;
this->enable_parallel_ = context->enable_parallel_;
SetContextDevice(context);
- #ifdef ENABLE_ARM
- #ifndef MS_COMPILE_IOS
- cpu_info_ = new CpuInfo;
- fp16_flag_ = cpu_info_->ArmIsSupportFp16();
- #endif
+ #if defined(ENABLE_ARM) && defined(ENABLE_FP16)
+ CpuInfo cpu_info;
+ device_and_pkg_support_fp16_ = cpu_info.ArmIsSupportFp16();
+ #else
+ device_and_pkg_support_fp16_ = false;
+ #endif
}
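
The constructor change above is the core of the fix: "package support" is a compile-time property (the build defines ENABLE_ARM and ENABLE_FP16), while "device support" is a runtime property (CpuInfo::ArmIsSupportFp16()), and both are folded into one cached flag. A minimal, self-contained sketch of the same idea, with hypothetical names outside the MindSpore code base:

#include <iostream>

// Hypothetical stand-in for CpuInfo::ArmIsSupportFp16(); the real check probes
// the CPU at runtime and, per this patch, returns false on iOS builds.
static bool DeviceSupportsFp16() { return false; }

// Combine compile-time package support with the runtime device probe once and
// cache it, the way InnerContext caches device_and_pkg_support_fp16_.
static bool DeviceAndPkgSupportFp16() {
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
  return DeviceSupportsFp16();  // package ships fp16 kernels; ask the hardware
#else
  return false;                 // package built without fp16 kernels
#endif
}

int main() {
  std::cout << "fp16 usable: " << DeviceAndPkgSupportFp16() << std::endl;
  return 0;
}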

@@ -114,11 +114,6 @@ InnerContext::~InnerContext() {
delete thread_pool_;
this->thread_pool_ = nullptr;
}
- #ifdef ENABLE_ARM
- #ifndef MS_COMPILE_IOS
- delete cpu_info_;
- #endif
- #endif
}

int InnerContext::IsValid() const {

@@ -168,7 +163,7 @@ bool InnerContext::IsCpuFloat16Enabled() const {
if (!IsCpuEnabled()) {
return false;
}
- if (!IsSupportFloat16()) {
+ if (!device_and_pkg_support_fp16_) {
return false;
}
return GetCpuInfo().enable_float16_;

@@ -286,11 +281,10 @@ NpuDeviceInfo InnerContext::GetNpuInfo() const {
}
}

- // Support CPU backend to judge whether it supports Float16.
- bool InnerContext::IsSupportFloat16() const { return fp16_flag_; }

ActorThreadPool *InnerContext::thread_pool() const { return thread_pool_; }

+ bool InnerContext::device_and_pkg_support_fp16() const { return this->device_and_pkg_support_fp16_; }

int ParallelLaunch(const Context *context, const Func &func, Content content, int task_num) {
ActorThreadPool *pool = static_cast<const lite::InnerContext *>(context)->thread_pool();
if (pool == nullptr) {

@@ -60,6 +60,8 @@ struct InnerContext : public Context {

virtual ~InnerContext();

+ bool device_and_pkg_support_fp16() const;

private:
bool IsAllDeviceTypeValid() const;

@@ -71,19 +73,11 @@ struct InnerContext : public Context {

bool IsUserSetNpu() const;

- bool IsSupportFloat16() const;

void SetContextDevice(const Context *context);

- bool fp16_flag_ = false;
+ bool device_and_pkg_support_fp16_ = false;

ActorThreadPool *thread_pool_{nullptr};

- #ifdef ENABLE_ARM
- #ifndef MS_COMPILE_IOS
- CpuInfo *cpu_info_ = nullptr;
- #endif
- #endif
};

int ParallelLaunch(const Context *context, const Func &func, Content content, int task_num);

@@ -248,7 +248,7 @@ void LiteOpActor::CopyInputData(Tensor *dst_tensor, Tensor *src_tensor) {
}

int LiteOpActor::CastTensorData(Tensor *dst, Tensor *src) {
- #ifdef ENABLE_FP16
+ #if defined(ENABLE_ARM) && defined(ENABLE_FP16)
if (dst->shape() != src->shape()) {
MS_LOG(ERROR) << "dst tensor: " << dst->tensor_name() << " shape: " << dst->shape() << " vs "
<< "src tensor: " << src->tensor_name() << " shape: " << src->shape();

@@ -261,9 +261,9 @@ int LiteOpActor::CastTensorData(Tensor *dst, Tensor *src) {
auto src_data_type = static_cast<int>(src->data_type());

if (dst_data_type == kNumberTypeFloat32 && src_data_type == kNumberTypeFloat16) {
- Float16ToFloat32_fp16_handler(src_data, dst_data, src_nums_size);
+ Float16ToFloat32_fp16_handler(src_data, dst_data, src_nums_size, support_fp16_);
} else if (dst_data_type == kNumberTypeFloat16 && src_data_type == kNumberTypeFloat32) {
- Float32ToFloat16_fp16_handler(src_data, dst_data, src_nums_size);
+ Float32ToFloat16_fp16_handler(src_data, dst_data, src_nums_size, support_fp16_);
} else {
MS_LOG(ERROR) << "not support dst_data_type: " << dst_data_type << " src_data_type: " << src_data_type;
return RET_NOT_SUPPORT;

@@ -26,6 +26,7 @@
#include "async/uuid_base.h"
#include "async/future.h"
#include "src/sub_graph_kernel.h"
+ #include "src/cpu_info.h"

namespace mindspore::lite {

@@ -39,6 +40,10 @@ class LiteOpActor : public OpActor<lite::Tensor> {
public:
explicit LiteOpActor(kernel::LiteKernel *kernel) : OpActor<lite::Tensor>(kernel->name()), kernel_(kernel) {
inputs_data_.resize(kernel_->in_tensors().size());
+ #if defined(ENABLE_ARM) && defined(ENABLE_FP16)
+ CpuInfo cpu_info;
+ support_fp16_ = cpu_info.ArmIsSupportFp16();
+ #endif
}
~LiteOpActor() override {
for (auto map : isolate_input_map_) {

@@ -89,6 +94,7 @@ class LiteOpActor : public OpActor<lite::Tensor> {
kernel::LiteKernel *partial_node_ = nullptr;
kernel::LiteKernel *call_node_ = nullptr;
std::unordered_map<Tensor *, Tensor *> isolate_input_map_; /* <calculate-tensor, src-input-tensor> */
+ bool support_fp16_ = false;
};

class LiteSwitchOpActor : public LiteOpActor {

@@ -600,7 +600,7 @@ int LiteSession::RunGraph(const KernelCallBack &before, const KernelCallBack &af
MS_LOG(ERROR) << "CheckInputs failed.";
return ret;
}
- MS_ASSERT(this->context_);
+ MS_ASSERT(this->context_ != nullptr);
if (before == nullptr && after == nullptr) {
ret = executor_->Run(this->inputs_, this->outputs_, this->kernels_, this->context_->allocator.get());
} else {

@@ -17,15 +17,41 @@
#include <arm_neon.h>
#endif
#include "nnacl/fp16/cast_fp16.h"
+ #include "nnacl/nnacl_common.h"

#ifdef __cplusplus
extern "C" {
#endif
- static inline void Float32ToFloat16_fp16_handler(const void *input, void *output, int number) {
- Float32ToFloat16(reinterpret_cast<const float *>(input), reinterpret_cast<float16_t *>(output), number);
+ static inline void Float32ToFloat16_fp16_handler(const void *input, void *output, int number, bool support_fp16) {
+ #if defined(ENABLE_ARM) && defined(ENABLE_FP16)
+ if (support_fp16) {
+ Float32ToFloat16(reinterpret_cast<const float *>(input), reinterpret_cast<float16_t *>(output), number);
+ } else {
+ #endif
+ auto src_data = reinterpret_cast<const float *>(input);
+ auto dst_data = reinterpret_cast<uint16_t *>(output);
+ for (int i = 0; i < number; i++) {
+ dst_data[i] = Float32ToShort(src_data[i]);
+ }
+ #if defined(ENABLE_ARM) && defined(ENABLE_FP16)
+ }
+ #endif
}
- static inline void Float16ToFloat32_fp16_handler(const void *input, void *output, int number) {
- Float16ToFloat32(reinterpret_cast<const float16_t *>(input), reinterpret_cast<float *>(output), number);

+ static inline void Float16ToFloat32_fp16_handler(const void *input, void *output, int number, bool support_fp16) {
+ #if defined(ENABLE_ARM) && defined(ENABLE_FP16)
+ if (support_fp16) {
+ Float16ToFloat32(reinterpret_cast<const float16_t *>(input), reinterpret_cast<float *>(output), number);
+ } else {
+ #endif
+ auto src_data = reinterpret_cast<const uint16_t *>(input);
+ auto dst_data = reinterpret_cast<float *>(output);
+ for (int i = 0; i < number; i++) {
+ dst_data[i] = ShortToFloat32(src_data[i]);
+ }
+ #if defined(ENABLE_ARM) && defined(ENABLE_FP16)
+ }
+ #endif
}

#ifdef __cplusplus
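
When support_fp16 is false, or the package was built without the NEON fp16 path, the handlers above fall back to Float32ToShort / ShortToFloat32 from nnacl_common.h, i.e. a software bit-level conversion. As an illustration only, not the library's implementation, a simplified scalar fp32 <-> fp16 conversion (truncating rounding, denormals dropped) looks roughly like this:

#include <cstdint>
#include <cstring>

// Simplified float32 -> float16 bit conversion (sketch; production code also
// handles denormals and round-to-nearest-even).
static uint16_t FloatToHalfBits(float f) {
  uint32_t x;
  std::memcpy(&x, &f, sizeof(x));
  uint32_t sign = (x >> 16) & 0x8000u;
  int32_t exp = static_cast<int32_t>((x >> 23) & 0xFFu) - 127 + 15;
  uint32_t mant = x & 0x7FFFFFu;
  if (exp >= 31) return static_cast<uint16_t>(sign | 0x7C00u | (mant ? 0x200u : 0u));  // Inf/NaN/overflow
  if (exp <= 0) return static_cast<uint16_t>(sign);                                    // underflow -> signed zero
  return static_cast<uint16_t>(sign | (static_cast<uint32_t>(exp) << 10) | (mant >> 13));
}

// Simplified float16 bits -> float32 conversion (denormal halves mapped to zero).
static float HalfBitsToFloat(uint16_t h) {
  uint32_t sign = static_cast<uint32_t>(h & 0x8000u) << 16;
  uint32_t exp = (h >> 10) & 0x1Fu;
  uint32_t mant = h & 0x3FFu;
  uint32_t bits;
  if (exp == 0) {
    bits = sign;                               // zero (and denormals, dropped here)
  } else if (exp == 31) {
    bits = sign | 0x7F800000u | (mant << 13);  // Inf/NaN
  } else {
    bits = sign | ((exp + 112u) << 23) | (mant << 13);  // rebias 15 -> 127
  }
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

The vector path (Float32ToFloat16 / Float16ToFloat32 from cast_fp16.h) exists only when the package is built with ENABLE_ARM and ENABLE_FP16, which is why the handlers take the runtime flag in addition to the compile-time guard.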

@@ -34,14 +34,12 @@
#include "src/runtime/infer_manager.h"
#include "src/sub_graph_split.h"
#include "src/weight_decoder.h"
- #include "src/runtime/kernel/arm/fp16/fp16_op_handler.h"
- #include "nnacl/nnacl_common.h"
#if GPU_OPENCL
#include "src/runtime/kernel/opencl/opencl_subgraph.h"
#include "src/runtime/gpu/opencl/opencl_runtime.h"
#endif
+ #if defined(ENABLE_ARM) && defined(ENABLE_FP16)
+ #include "src/runtime/kernel/arm/fp16/fp16_op_handler.h"
+ #endif
#include "include/registry/kernel_interface.h"

namespace mindspore::lite {

@@ -271,7 +269,9 @@ int Scheduler::InferSubGraphShape(size_t subgraph_index) {
}

namespace {
- int CastConstTensorData(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_origin_tensors, TypeId dst_data_type) {
+ // support_fp16: current device and package support float16
+ int CastConstTensorData(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_origin_tensors, TypeId dst_data_type,
+ bool support_fp16) {
MS_ASSERT(tensor != nullptr);
MS_ASSERT(tensor->IsConst());
MS_ASSERT(tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeFloat16);

@@ -294,25 +294,9 @@ int CastConstTensorData(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_o
auto new_tensor_data = tensor->data_c();
MS_ASSERT(new_tensor_data != nullptr);
if (dst_data_type == kNumberTypeFloat32) {
- #if defined(ENABLE_ARM) && defined(ENABLE_FP16)
- Float16ToFloat32_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum());
- #else
- auto src_data = reinterpret_cast<uint16_t *>(origin_data);
- auto dst_data = reinterpret_cast<float *>(new_tensor_data);
- for (int i = 0; i < tensor->ElementsNum(); i++) {
- dst_data[i] = ShortToFloat32(src_data[i]);
- }
- #endif
+ Float16ToFloat32_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum(), support_fp16);
} else {  // dst_data_type == kNumberTypeFloat16
- #if defined(ENABLE_ARM) && defined(ENABLE_FP16)
- Float32ToFloat16_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum());
- #else
- auto src_data = reinterpret_cast<float *>(origin_data);
- auto dst_data = reinterpret_cast<uint16_t *>(new_tensor_data);
- for (int i = 0; i < tensor->ElementsNum(); i++) {
- dst_data[i] = Float32ToShort(src_data[i]);
- }
- #endif
+ Float32ToFloat16_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum(), support_fp16);
}
if (restored_origin_tensors->find(tensor) != restored_origin_tensors->end()) {
MS_LOG(ERROR) << "Tensor " << tensor->tensor_name() << " is already be stored";
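
With the handlers doing their own fallback, the scheduler no longer open-codes the ShortToFloat32 / Float32ToShort loops; it just forwards the capability flag it gets from the context. A hedged sketch of the resulting call pattern, with names taken from the diff and the surrounding variables assumed to exist:

// The flag is computed once in InnerContext and threaded through; the handler
// decides between the NEON path and the scalar software fallback.
bool support_fp16 = context_->device_and_pkg_support_fp16();
if (dst_data_type == kNumberTypeFloat32) {
  Float16ToFloat32_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum(), support_fp16);
} else {  // kNumberTypeFloat16
  Float32ToFloat16_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum(), support_fp16);
}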

@@ -322,8 +306,9 @@ int CastConstTensorData(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_o
return RET_OK;
}

+ // support_fp16: current device and package support float16
int CastConstTensorsData(const std::vector<Tensor *> &tensors, std::map<Tensor *, Tensor *> *restored_origin_tensors,
- TypeId dst_data_type) {
+ TypeId dst_data_type, bool support_fp16) {
MS_ASSERT(restored_origin_tensors != nullptr);
if (dst_data_type != kNumberTypeFloat32 && dst_data_type != kNumberTypeFloat16) {
MS_LOG(ERROR) << "Only support fp32 or fp16 as dst_data_type.";

@@ -341,13 +326,13 @@ int CastConstTensorsData(const std::vector<Tensor *> &tensors, std::map<Tensor *
continue;
}
if (tensor->data_type() == kNumberTypeFloat32 && dst_data_type == kNumberTypeFloat16) {
- auto ret = CastConstTensorData(tensor, restored_origin_tensors, kNumberTypeFloat16);
+ auto ret = CastConstTensorData(tensor, restored_origin_tensors, kNumberTypeFloat16, support_fp16);
if (ret != RET_OK) {
MS_LOG(DEBUG) << "Cast const tensor from fp32 to fp16 failed, tensor name : " << tensor->tensor_name();
return ret;
}
} else if (tensor->data_type() == kNumberTypeFloat16 && dst_data_type == kNumberTypeFloat32) {
- auto ret = CastConstTensorData(tensor, restored_origin_tensors, kNumberTypeFloat32);
+ auto ret = CastConstTensorData(tensor, restored_origin_tensors, kNumberTypeFloat32, support_fp16);
if (ret != RET_OK) {
MS_LOG(DEBUG) << "Cast const tensor from fp16 to fp32 failed, tensor name : " << tensor->tensor_name();
return ret;

@@ -437,7 +422,8 @@ int Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_tensors, const std:
}
std::map<Tensor *, Tensor *> restored_origin_tensors;

- ret = CastConstTensorsData(in_tensors, &restored_origin_tensors, kernel_data_type);
+ ret = CastConstTensorsData(in_tensors, &restored_origin_tensors, kernel_data_type,
+ context_->device_and_pkg_support_fp16());
if (ret != RET_OK) {
MS_LOG(DEBUG) << "CastConstTensorsData failed: " << ret;
return RET_NOT_SUPPORT;

@@ -226,7 +226,7 @@ int CpuSubGraph::Execute(const KernelCallBack &before, const KernelCallBack &aft
}
return RET_OK;
}
- #ifdef ENABLE_FP16
+ #if defined(ENABLE_ARM) && defined(ENABLE_FP16)
void CpuFp16SubGraph::FreeOriginInputData() {
for (auto &iter : this->origin_input_data_) {
auto *data_store = iter.second;

@@ -269,7 +269,7 @@ int CpuFp16SubGraph::Float32TensorToFloat16Tensor(lite::Tensor *tensor) {
return RET_ERROR;
}
MS_ASSERT(tensor->data_c() != nullptr);
- Float32ToFloat16_fp16_handler(float32_data, tensor->data_c(), tensor->ElementsNum());
+ Float32ToFloat16_fp16_handler(float32_data, tensor->data_c(), tensor->ElementsNum(), support_fp16_);
if (tensor->allocator() != nullptr) {
tensor->allocator()->SetRefCount(tensor->data_c(), tensor->allocator()->RefCount(float32_data));
}

@@ -302,7 +302,7 @@ int CpuFp16SubGraph::Float16TensorToFloat32Tensor(lite::Tensor *tensor) {
return RET_ERROR;
}
MS_ASSERT(tensor->data_c() != nullptr);
- Float16ToFloat32_fp16_handler(float16_data, tensor->data_c(), tensor->ElementsNum());
+ Float16ToFloat32_fp16_handler(float16_data, tensor->data_c(), tensor->ElementsNum(), support_fp16_);
if (tensor->allocator() != nullptr) {
tensor->allocator()->SetRefCount(tensor->data_c(), tensor->allocator()->RefCount(float16_data));
tensor->allocator()->Free(float16_data);

@@ -26,6 +26,7 @@
#include "src/lite_kernel.h"
#include "src/executor.h"
#include "src/common/log_adapter.h"
+ #include "src/cpu_info.h"
#ifdef ENABLE_ARM64
#include "src/common/utils.h"
#endif

@@ -152,7 +153,7 @@ class CpuFp32SubGraph : public CpuSubGraph {
~CpuFp32SubGraph() override = default;
};

- #ifdef ENABLE_FP16
+ #if defined(ENABLE_ARM) && defined(ENABLE_FP16)
class CpuFp16SubGraph : public CpuSubGraph {
public:
CpuFp16SubGraph(std::vector<LiteKernel *> in_kernels, std::vector<LiteKernel *> out_kernels,

@@ -162,6 +163,9 @@ class CpuFp16SubGraph : public CpuSubGraph {
static std::atomic_int index = 0;
this->set_name("CpuFP16SubGraph" + std::to_string(index++));
desc_.data_type = kNumberTypeFloat16;
+ const auto *context = this->Context();
+ MS_ASSERT(context != nullptr);
+ support_fp16_ = context->device_and_pkg_support_fp16();
}

~CpuFp16SubGraph() override = default;

@@ -227,6 +231,7 @@ class CpuFp16SubGraph : public CpuSubGraph {

private:
std::map<lite::Tensor *, DataStore *> origin_input_data_;
+ bool support_fp16_ = false;
};
#endif

@@ -276,7 +276,7 @@ int TrainSession::MixPrecisionPreProcess(kernel::LiteKernel *kernel, float scale
}
// adjust tensor data type
if (tensor->data_type() != kernel_type) {
- auto restore_tensor = CastTensor(tensor, kernel_type);
+ auto restore_tensor = CastTensor(tensor, kernel_type, this->context_->device_and_pkg_support_fp16());
if (restore_tensor != nullptr) {
restored_origin_tensors_[tensor] = restore_tensor;
}

@@ -345,7 +345,7 @@ int TrainSession::MixPrecisionExecKernels(const KernelCallBack &before, const Ke
if (train_mode_ == false) {
for (auto t : this->outputs_) {
if (t->data_type() == kNumberTypeFloat16) {
- auto restore = CastTensor(t, kNumberTypeFloat32);
+ auto restore = CastTensor(t, kNumberTypeFloat32, this->context_->device_and_pkg_support_fp16());
delete restore;
}
}

@@ -133,7 +133,7 @@ float CalculateOneHotClassification(tensor::MSTensor *input, tensor::MSTensor *o
return acc;
}

- Tensor *CastTensor(Tensor *tensor, TypeId dst_data_type) {
+ Tensor *CastTensor(Tensor *tensor, TypeId dst_data_type, bool support_fp16) {
#ifdef ENABLE_FP16
MS_ASSERT(tensor != nullptr);
std::vector<TypeId> valid_type = {kNumberTypeFloat32, kNumberTypeFloat16, kNumberTypeFloat};

@@ -167,7 +167,7 @@ Tensor *CastTensor(Tensor *tensor, TypeId dst_data_type) {
return nullptr;
}
MS_LOG(DEBUG) << "Convert tensor to fp16 " << tensor->tensor_name();
- Float32ToFloat16_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum());
+ Float32ToFloat16_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum(), support_fp16);
} else {
tensor->set_data(nullptr);
tensor->set_data_type(kNumberTypeFloat32);

@@ -180,7 +180,7 @@ Tensor *CastTensor(Tensor *tensor, TypeId dst_data_type) {
auto new_tensor_data = tensor->data_c();
MS_ASSERT(new_tensor_data != nullptr);
MS_LOG(DEBUG) << "Convert tensor to fp32 " << tensor->tensor_name();
- Float16ToFloat32_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum());
+ Float16ToFloat32_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum(), support_fp16);
}
return restore_tensor;
#else

@@ -35,7 +35,7 @@ kernel::LiteKernel *TSFindKernel(const std::vector<kernel::LiteKernel *> &where,
size_t TSFindTensor(const std::vector<lite::Tensor *> &where, const lite::Tensor *searchParameter);
float CalculateSparseClassification(tensor::MSTensor *input, tensor::MSTensor *output);
float CalculateOneHotClassification(tensor::MSTensor *input, tensor::MSTensor *output);
- Tensor *CastTensor(Tensor *tensor, TypeId dst_data_type);
+ Tensor *CastTensor(Tensor *tensor, TypeId dst_data_type, bool support_fp16);
int ScaleTensor(Tensor *tensor, float scale);
}  // namespace lite
}  // namespace mindspore
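
Callers outside the scheduler obtain the new CastTensor parameter the same way. A small hedged usage sketch, with context_, tensor and restored_origin_tensors_ assumed to exist as in the train-session hunks above:

// Cast a tensor to fp32, remembering the original so it can be swapped back
// after execution; the capability flag comes from the inner context.
lite::Tensor *restore = CastTensor(tensor, kNumberTypeFloat32, context_->device_and_pkg_support_fp16());
if (restore != nullptr) {
  restored_origin_tensors_[tensor] = restore;
}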