!22993 [MSLITE] mix data type bug

Merge pull request !22993 from ling/sr
This commit is contained in:
i-robot 2021-09-08 03:42:43 +00:00 committed by Gitee
commit 214d5fb518
2 changed files with 18 additions and 15 deletions

View File

@@ -174,19 +174,19 @@ Status Model::LoadConfig(const std::string &config_path) {
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
if (impl_ != nullptr) {
MS_LOG(ERROR) << "impl_ illegal in LoadConfig.";
return kLiteFileError;
return Status(kLiteFileError, "Illegal operation.");
}
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null.";
return kLiteFileError;
return Status(kLiteFileError, "Fail to load config file.");
}
auto ret = impl_->LoadConfig(config_path);
if (ret != kSuccess) {
MS_LOG(ERROR) << "impl_ LoadConfig failed,";
return kLiteFileError;
return Status(kLiteFileError, "Invalid config file.");
}
return kSuccess;
}

View File

@@ -544,6 +544,7 @@ int Scheduler::InferNodeShape(const lite::Model::Node *node) {
FreeOpParameters();
return RET_ERROR;
}
parameter->quant_type_ = node->quant_type_;
parameter->thread_num_ = context_->thread_num_;
@@ -1060,19 +1061,17 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op.";
desc.data_type = kNumberTypeFloat32;
}
if (prefer_data_type == kNumberTypeFloat32 || prefer_data_type == kTypeUnknown) {
status = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat32, &kernel);
if (status == RET_OK) {
return kernel;
} else if (status == RET_ERROR) {
op_parameters_.erase(node->output_indices_.at(0));
auto ret = InferNodeShape(node);
if (!(ret == RET_INFER_INVALID || ret == RET_OK)) {
MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
}
} else if (status == RET_NOT_SUPPORT) {
free(op_parameter);
status = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat32, &kernel);
if (status == RET_OK) {
return kernel;
} else if (status == RET_ERROR) {
op_parameters_.erase(node->output_indices_.at(0));
auto ret = InferNodeShape(node);
if (!(ret == RET_INFER_INVALID || ret == RET_OK)) {
MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
}
} else if (status == RET_NOT_SUPPORT) {
free(op_parameter);
}
return nullptr;
}
@@ -1325,6 +1324,10 @@ kernel::LiteKernel *Scheduler::ScheduleNodeToKernel(const lite::Model::Node *src
<< ", type: " << GetPrimitiveTypeName(src_node->primitive_, schema_version_);
return nullptr;
}
MS_LOG(DEBUG) << "kernel: [" << src_node->name_ << "] get TypeId(" << kernel->desc().data_type
<< ") op success. op_type: " << PrimitiveCurVersionTypeName(kernel->desc().type)
<< ", arch: " << kernel->desc().arch;
SetKernelTensorDataType(kernel);
kernel->set_name(src_node->name_);
return kernel;