From f978beb1f2c4e2d8078e7dcd20f5f944ba3008c9 Mon Sep 17 00:00:00 2001 From: sunsuodong Date: Thu, 3 Sep 2020 16:42:42 +0800 Subject: [PATCH] fix ScheduleNode and fill parser --- mindspore/lite/src/scheduler.cc | 12 +++++++++++- mindspore/lite/src/scheduler.h | 1 + 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index 250ab9b610d..925dd52c723 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -231,7 +231,7 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector const std::vector &out_tensors, const mindspore::lite::PrimitiveC *primitive, const schema::CNode *cnode) { MS_ASSERT(nullptr != primitive); - auto data_type = in_tensors.front()->data_type(); + TypeId data_type = GetFirstFp32Fp16OrInt8Type(in_tensors); kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, static_cast(primitive->Type())}; if (context_->device_ctx_.type == DT_GPU) { desc.arch = kernel::KERNEL_ARCH::kGPU; @@ -271,6 +271,16 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector return nullptr; } +TypeId Scheduler::GetFirstFp32Fp16OrInt8Type(const std::vector &in_tensors) { + for (const auto &tensor : in_tensors) { + auto dtype = tensor->data_type(); + if (dtype == kNumberTypeFloat32 || dtype == kNumberTypeFloat16 || dtype == kNumberTypeInt8) { + return dtype; + } + } + return kNumberTypeFloat32; +} + void Scheduler::SetKernelTensorDataType(kernel::LiteKernel *kernel) { if (kernel->desc().arch != kernel::KERNEL_ARCH::kCPU) { return; diff --git a/mindspore/lite/src/scheduler.h b/mindspore/lite/src/scheduler.h index aa39383c07b..bd560c52172 100644 --- a/mindspore/lite/src/scheduler.h +++ b/mindspore/lite/src/scheduler.h @@ -47,6 +47,7 @@ class Scheduler { void ConstructSubgraphs(std::vector *kernels); kernel::LiteKernel *CreateSubKernel(const std::vector &kernels, kernel::KERNEL_ARCH arch); + TypeId GetFirstFp32Fp16OrInt8Type(const std::vector &in_tensors); void SetKernelTensorDataType(kernel::LiteKernel *kernel); protected: