forked from mindspore-Ecosystem/mindspore
!2296 fix export lite model for gpu
Merge pull request !2296 from yangjie159/fix-export-lite-model-for-gpu
This commit is contained in:
commit
8b4fc72cbd
|
@ -402,8 +402,7 @@ bool Kernel2Ms::SetGraphOpTensors(const KernelGraphPtr &kernel_graph_ptr, const
|
||||||
auto c_name = AnfAlgo::GetCNodeName(kernel);
|
auto c_name = AnfAlgo::GetCNodeName(kernel);
|
||||||
auto fun = predict::convert::OpAttrFactory::GetInstance()->GetPackFun(c_name);
|
auto fun = predict::convert::OpAttrFactory::GetInstance()->GetPackFun(c_name);
|
||||||
if (fun == nullptr) {
|
if (fun == nullptr) {
|
||||||
MS_LOG(ERROR) << "get node [" << kernel->fullname_with_scope() << "] attr failed.";
|
MS_LOG(WARNING) << "get node [" << kernel->fullname_with_scope() << "] attr failed.";
|
||||||
return false;
|
|
||||||
} else if (!fun(kernel, ms_node.get())) {
|
} else if (!fun(kernel, ms_node.get())) {
|
||||||
MS_LOG(ERROR) << "set node [" << kernel->fullname_with_scope() << "] attr failed.";
|
MS_LOG(ERROR) << "set node [" << kernel->fullname_with_scope() << "] attr failed.";
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -72,8 +72,8 @@ OpAttrFactory::OpAttrFactory() {
|
||||||
{"AddFold", AddFoldPacker},
|
{"AddFold", AddFoldPacker},
|
||||||
{"ArgMax", ArgMaxPacker},
|
{"ArgMax", ArgMaxPacker},
|
||||||
{"BatchNorm", BatchNormFoldPacker},
|
{"BatchNorm", BatchNormFoldPacker},
|
||||||
{"FakeQuantWithMinMax", FakeQuantWithMinMaxPacker},
|
{"FakeQuantPerLayer", FakeQuantWithMinMaxPacker},
|
||||||
{"FakeQuantWithMinMaxPerChannel", FakeQuantWithMinMaxPerChannelPacker},
|
{"FakeQuantPerChannel", FakeQuantWithMinMaxPerChannelPacker},
|
||||||
{"Mul", MulPacker},
|
{"Mul", MulPacker},
|
||||||
{"MulFold", MulFoldPacker},
|
{"MulFold", MulFoldPacker},
|
||||||
{"Squeeze", SqueezePacker}};
|
{"Squeeze", SqueezePacker}};
|
||||||
|
|
|
@ -28,6 +28,9 @@ void StepConvertGraph(const KernelGraphPtr &kernel_graph_ptr) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
|
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
|
||||||
bool save_ms_model = MsContext::GetInstance()->save_ms_model_flag();
|
bool save_ms_model = MsContext::GetInstance()->save_ms_model_flag();
|
||||||
if (save_ms_model) {
|
if (save_ms_model) {
|
||||||
|
if (kernel_graph_ptr->inputs().empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
// set convert_mode: convert cpu info or convert Davnici
|
// set convert_mode: convert cpu info or convert Davnici
|
||||||
executor::Kernel2Ms::GetInstance().set_convert_mode(executor::kConvertCpuMode);
|
executor::Kernel2Ms::GetInstance().set_convert_mode(executor::kConvertCpuMode);
|
||||||
// convert kernel_graph to sub_ms_graph
|
// convert kernel_graph to sub_ms_graph
|
||||||
|
@ -46,6 +49,9 @@ void StepConvertWeight(const std::vector<tensor::TensorPtr> &inputs) {
|
||||||
bool save_ms_model = MsContext::GetInstance()->save_ms_model_flag();
|
bool save_ms_model = MsContext::GetInstance()->save_ms_model_flag();
|
||||||
std::string save_path = MsContext::GetInstance()->save_ms_model_path();
|
std::string save_path = MsContext::GetInstance()->save_ms_model_path();
|
||||||
if (save_ms_model) {
|
if (save_ms_model) {
|
||||||
|
if (inputs.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
MS_LOG(INFO) << "save ms model is true to path " << save_path;
|
MS_LOG(INFO) << "save ms model is true to path " << save_path;
|
||||||
if (!executor::Kernel2Ms::GetInstance().KernelInput2MS(inputs)) {
|
if (!executor::Kernel2Ms::GetInstance().KernelInput2MS(inputs)) {
|
||||||
MS_LOG(WARNING) << "convert mindspore kernel input failed";
|
MS_LOG(WARNING) << "convert mindspore kernel input failed";
|
||||||
|
|
Loading…
Reference in New Issue