forked from OSSInnovation/mindspore
fix split output
This commit is contained in:
parent
2e071243b3
commit
343a489eaa
|
@ -17,6 +17,7 @@
|
|||
#include "src/runtime/kernel/arm/fp16/common_fp16.h"
|
||||
#include "src/runtime/kernel/arm/base/split_base.h"
|
||||
#include "nnacl/fp16/split_fp16.h"
|
||||
#include "nnacl/fp16/cast_fp16.h"
|
||||
#include "nnacl/split.h"
|
||||
#include "nnacl/split_parameter.h"
|
||||
#include "src/kernel_registry.h"
|
||||
|
@ -94,6 +95,13 @@ int SplitFp16CPUKernel::Run() {
|
|||
}
|
||||
}
|
||||
ret = ParallelLaunch(this->context_->thread_pool_, SplitRun, this, thread_n_num_);
|
||||
for (int i = 0; i < param->num_split_; i++) {
|
||||
if (out_tensors_.at(i)->data_type() == kNumberTypeFloat32) {
|
||||
Float16ToFloat32(output_ptr_[i], reinterpret_cast<float *>(out_tensors_.at(i)->MutableData()),
|
||||
out_tensors_.at(i)->ElementsNum());
|
||||
}
|
||||
}
|
||||
FreeInputAndOutput();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "split error error_code[" << ret << "]";
|
||||
}
|
||||
|
|
|
@ -267,7 +267,8 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<Tensor *> &in_tens
|
|||
kernel::KernelKey key{desc.arch, kNumberTypeFloat16, desc.type};
|
||||
kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, key);
|
||||
if (kernel != nullptr) {
|
||||
MS_LOG(DEBUG) << "Get fp16 op success.";
|
||||
MS_LOG(INFO) << "Get fp16 op success. type:"
|
||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(primitive->Type()));
|
||||
desc.data_type = kNumberTypeFloat16;
|
||||
kernel->set_desc(desc);
|
||||
return kernel;
|
||||
|
|
Loading…
Reference in New Issue