fix export lite model for gpu

This commit is contained in:
yangjie159 2020-06-18 16:36:11 +08:00 committed by Gitee
parent f9120e6886
commit 8dde5b9ad9
3 changed files with 9 additions and 4 deletions

View File

@ -402,8 +402,7 @@ bool Kernel2Ms::SetGraphOpTensors(const KernelGraphPtr &kernel_graph_ptr, const
auto c_name = AnfAlgo::GetCNodeName(kernel);
auto fun = predict::convert::OpAttrFactory::GetInstance()->GetPackFun(c_name);
if (fun == nullptr) {
MS_LOG(ERROR) << "get node [" << kernel->fullname_with_scope() << "] attr failed.";
return false;
MS_LOG(WARNING) << "get node [" << kernel->fullname_with_scope() << "] attr failed.";
} else if (!fun(kernel, ms_node.get())) {
MS_LOG(ERROR) << "set node [" << kernel->fullname_with_scope() << "] attr failed.";
return false;

View File

@ -72,8 +72,8 @@ OpAttrFactory::OpAttrFactory() {
{"AddFold", AddFoldPacker},
{"ArgMax", ArgMaxPacker},
{"BatchNorm", BatchNormFoldPacker},
{"FakeQuantWithMinMax", FakeQuantWithMinMaxPacker},
{"FakeQuantWithMinMaxPerChannel", FakeQuantWithMinMaxPerChannelPacker},
{"FakeQuantPerLayer", FakeQuantWithMinMaxPacker},
{"FakeQuantPerChannel", FakeQuantWithMinMaxPerChannelPacker},
{"Mul", MulPacker},
{"MulFold", MulFoldPacker},
{"Squeeze", SqueezePacker}};

View File

@ -28,6 +28,9 @@ void StepConvertGraph(const KernelGraphPtr &kernel_graph_ptr) {
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
bool save_ms_model = MsContext::GetInstance()->save_ms_model_flag();
if (save_ms_model) {
if (kernel_graph_ptr->inputs().empty()) {
return;
}
// set convert_mode: convert cpu info or convert Davnici
executor::Kernel2Ms::GetInstance().set_convert_mode(executor::kConvertCpuMode);
// convert kernel_graph to sub_ms_graph
@ -46,6 +49,9 @@ void StepConvertWeight(const std::vector<tensor::TensorPtr> &inputs) {
bool save_ms_model = MsContext::GetInstance()->save_ms_model_flag();
std::string save_path = MsContext::GetInstance()->save_ms_model_path();
if (save_ms_model) {
if (inputs.empty()) {
return;
}
MS_LOG(INFO) << "save ms model is true to path " << save_path;
if (!executor::Kernel2Ms::GetInstance().KernelInput2MS(inputs)) {
MS_LOG(WARNING) << "convert mindspore kernel input failed";