forked from mindspore-Ecosystem/mindspore
!5707 [MS][LITE][Develop]fix ScheduleNode and fill parser
Merge pull request !5707 from sunsuodong/fix_fill_parser
This commit is contained in:
commit
c2ff5e3fba
|
@ -231,7 +231,7 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<tensor::Tensor *>
|
||||||
const std::vector<tensor::Tensor *> &out_tensors,
|
const std::vector<tensor::Tensor *> &out_tensors,
|
||||||
const mindspore::lite::PrimitiveC *primitive, const schema::CNode *cnode) {
|
const mindspore::lite::PrimitiveC *primitive, const schema::CNode *cnode) {
|
||||||
MS_ASSERT(nullptr != primitive);
|
MS_ASSERT(nullptr != primitive);
|
||||||
auto data_type = in_tensors.front()->data_type();
|
TypeId data_type = GetFirstFp32Fp16OrInt8Type(in_tensors);
|
||||||
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, static_cast<schema::PrimitiveType>(primitive->Type())};
|
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, static_cast<schema::PrimitiveType>(primitive->Type())};
|
||||||
if (context_->device_ctx_.type == DT_GPU) {
|
if (context_->device_ctx_.type == DT_GPU) {
|
||||||
desc.arch = kernel::KERNEL_ARCH::kGPU;
|
desc.arch = kernel::KERNEL_ARCH::kGPU;
|
||||||
|
@ -271,6 +271,16 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<tensor::Tensor *>
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TypeId Scheduler::GetFirstFp32Fp16OrInt8Type(const std::vector<tensor::Tensor *> &in_tensors) {
|
||||||
|
for (const auto &tensor : in_tensors) {
|
||||||
|
auto dtype = tensor->data_type();
|
||||||
|
if (dtype == kNumberTypeFloat32 || dtype == kNumberTypeFloat16 || dtype == kNumberTypeInt8) {
|
||||||
|
return dtype;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return kNumberTypeFloat32;
|
||||||
|
}
|
||||||
|
|
||||||
void Scheduler::SetKernelTensorDataType(kernel::LiteKernel *kernel) {
|
void Scheduler::SetKernelTensorDataType(kernel::LiteKernel *kernel) {
|
||||||
if (kernel->desc().arch != kernel::KERNEL_ARCH::kCPU) {
|
if (kernel->desc().arch != kernel::KERNEL_ARCH::kCPU) {
|
||||||
return;
|
return;
|
||||||
|
|
|
@ -47,6 +47,7 @@ class Scheduler {
|
||||||
void ConstructSubgraphs(std::vector<kernel::LiteKernel *> *kernels);
|
void ConstructSubgraphs(std::vector<kernel::LiteKernel *> *kernels);
|
||||||
|
|
||||||
kernel::LiteKernel *CreateSubKernel(const std::vector<kernel::LiteKernel *> &kernels, kernel::KERNEL_ARCH arch);
|
kernel::LiteKernel *CreateSubKernel(const std::vector<kernel::LiteKernel *> &kernels, kernel::KERNEL_ARCH arch);
|
||||||
|
TypeId GetFirstFp32Fp16OrInt8Type(const std::vector<tensor::Tensor *> &in_tensors);
|
||||||
void SetKernelTensorDataType(kernel::LiteKernel *kernel);
|
void SetKernelTensorDataType(kernel::LiteKernel *kernel);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
|
Loading…
Reference in New Issue