fix split output

This commit is contained in:
sunsuodong 2020-09-24 14:48:12 +08:00
parent 2e071243b3
commit 343a489eaa
2 changed files with 10 additions and 1 deletions

View File

@ -17,6 +17,7 @@
#include "src/runtime/kernel/arm/fp16/common_fp16.h"
#include "src/runtime/kernel/arm/base/split_base.h"
#include "nnacl/fp16/split_fp16.h"
#include "nnacl/fp16/cast_fp16.h"
#include "nnacl/split.h"
#include "nnacl/split_parameter.h"
#include "src/kernel_registry.h"
@ -94,6 +95,13 @@ int SplitFp16CPUKernel::Run() {
}
}
ret = ParallelLaunch(this->context_->thread_pool_, SplitRun, this, thread_n_num_);
for (int i = 0; i < param->num_split_; i++) {
if (out_tensors_.at(i)->data_type() == kNumberTypeFloat32) {
Float16ToFloat32(output_ptr_[i], reinterpret_cast<float *>(out_tensors_.at(i)->MutableData()),
out_tensors_.at(i)->ElementsNum());
}
}
FreeInputAndOutput();
if (ret != RET_OK) {
MS_LOG(ERROR) << "split error error_code[" << ret << "]";
}

View File

@ -267,7 +267,8 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<Tensor *> &in_tens
kernel::KernelKey key{desc.arch, kNumberTypeFloat16, desc.type};
kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, key);
if (kernel != nullptr) {
MS_LOG(DEBUG) << "Get fp16 op success.";
MS_LOG(INFO) << "Get fp16 op success. type:"
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(primitive->Type()));
desc.data_type = kNumberTypeFloat16;
kernel->set_desc(desc);
return kernel;