forked from mindspore-Ecosystem/mindspore
!22993 [MSLITE] mix data type bug
Merge pull request !22993 from ling/sr
This commit is contained in:
commit
214d5fb518
|
@ -174,19 +174,19 @@ Status Model::LoadConfig(const std::string &config_path) {
|
|||
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
|
||||
if (impl_ != nullptr) {
|
||||
MS_LOG(ERROR) << "impl_ illegal in LoadConfig.";
|
||||
return kLiteFileError;
|
||||
return Status(kLiteFileError, "Illegal operation.");
|
||||
}
|
||||
|
||||
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
|
||||
if (impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Model implement is null.";
|
||||
return kLiteFileError;
|
||||
return Status(kLiteFileError, "Fail to load config file.");
|
||||
}
|
||||
|
||||
auto ret = impl_->LoadConfig(config_path);
|
||||
if (ret != kSuccess) {
|
||||
MS_LOG(ERROR) << "impl_ LoadConfig failed,";
|
||||
return kLiteFileError;
|
||||
return Status(kLiteFileError, "Invalid config file.");
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
|
|
|
@ -544,6 +544,7 @@ int Scheduler::InferNodeShape(const lite::Model::Node *node) {
|
|||
FreeOpParameters();
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
parameter->quant_type_ = node->quant_type_;
|
||||
parameter->thread_num_ = context_->thread_num_;
|
||||
|
||||
|
@ -1060,19 +1061,17 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
|
|||
MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op.";
|
||||
desc.data_type = kNumberTypeFloat32;
|
||||
}
|
||||
if (prefer_data_type == kNumberTypeFloat32 || prefer_data_type == kTypeUnknown) {
|
||||
status = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat32, &kernel);
|
||||
if (status == RET_OK) {
|
||||
return kernel;
|
||||
} else if (status == RET_ERROR) {
|
||||
op_parameters_.erase(node->output_indices_.at(0));
|
||||
auto ret = InferNodeShape(node);
|
||||
if (!(ret == RET_INFER_INVALID || ret == RET_OK)) {
|
||||
MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
|
||||
}
|
||||
} else if (status == RET_NOT_SUPPORT) {
|
||||
free(op_parameter);
|
||||
status = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat32, &kernel);
|
||||
if (status == RET_OK) {
|
||||
return kernel;
|
||||
} else if (status == RET_ERROR) {
|
||||
op_parameters_.erase(node->output_indices_.at(0));
|
||||
auto ret = InferNodeShape(node);
|
||||
if (!(ret == RET_INFER_INVALID || ret == RET_OK)) {
|
||||
MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
|
||||
}
|
||||
} else if (status == RET_NOT_SUPPORT) {
|
||||
free(op_parameter);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -1325,6 +1324,10 @@ kernel::LiteKernel *Scheduler::ScheduleNodeToKernel(const lite::Model::Node *src
|
|||
<< ", type: " << GetPrimitiveTypeName(src_node->primitive_, schema_version_);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "kernel: [" << src_node->name_ << "] get TypeId(" << kernel->desc().data_type
|
||||
<< ") op success. op_type: " << PrimitiveCurVersionTypeName(kernel->desc().type)
|
||||
<< ", arch: " << kernel->desc().arch;
|
||||
SetKernelTensorDataType(kernel);
|
||||
kernel->set_name(src_node->name_);
|
||||
return kernel;
|
||||
|
|
Loading…
Reference in New Issue