!12453 fix NPU fp16 tensor type being set incorrectly
From: @zhaozhenlong Reviewed-by: @zhanghaibo5,@HilbertDavid Signed-off-by: @zhanghaibo5
This commit is contained in:
commit
9c5089f1f0
|
@ -204,6 +204,11 @@ int NPUFusionPass::FormatFusion(kernel::LiteKernel *kernel) {
|
|||
if (trans_kernel->out_kernels().empty()) {
|
||||
// kernel is a trans kernel; its input-kernel count and input-tensor count must both be 1
|
||||
kernel->in_kernels()[0]->set_out_tensors({trans_kernel->out_tensors()[0]});
|
||||
// in fp16 mode, the tensor's fp16 data type needs to be changed back to fp32.
|
||||
auto tensor = kernel->in_kernels()[0]->out_tensors()[0];
|
||||
if (tensor->data_type() == kNumberTypeFloat16) {
|
||||
tensor->set_data_type(kNumberTypeFloat32);
|
||||
}
|
||||
}
|
||||
for (const auto &post_kernel : trans_kernel->out_kernels()) {
|
||||
// update tensor
|
||||
|
|
|
@ -212,6 +212,11 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
|
|||
if (desc.data_type == kNumberTypeFloat16) {
|
||||
desc.data_type = kNumberTypeFloat32;
|
||||
}
|
||||
for (auto tensor : in_tensors) {
|
||||
if (tensor->data_type() == kNumberTypeFloat16) {
|
||||
tensor->set_data_type(kNumberTypeFloat32);
|
||||
}
|
||||
}
|
||||
kernel::KernelKey npu_desc{kNPU, desc.data_type, desc.type};
|
||||
auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, npu_desc);
|
||||
if (kernel != nullptr) {
|
||||
|
|
Loading…
Reference in New Issue