forked from mindspore-Ecosystem/mindspore
fix export lite model for gpu
This commit is contained in:
parent
f9120e6886
commit
8dde5b9ad9
|
@ -402,8 +402,7 @@ bool Kernel2Ms::SetGraphOpTensors(const KernelGraphPtr &kernel_graph_ptr, const
|
|||
auto c_name = AnfAlgo::GetCNodeName(kernel);
|
||||
auto fun = predict::convert::OpAttrFactory::GetInstance()->GetPackFun(c_name);
|
||||
if (fun == nullptr) {
|
||||
MS_LOG(ERROR) << "get node [" << kernel->fullname_with_scope() << "] attr failed.";
|
||||
return false;
|
||||
MS_LOG(WARNING) << "get node [" << kernel->fullname_with_scope() << "] attr failed.";
|
||||
} else if (!fun(kernel, ms_node.get())) {
|
||||
MS_LOG(ERROR) << "set node [" << kernel->fullname_with_scope() << "] attr failed.";
|
||||
return false;
|
||||
|
|
|
@ -72,8 +72,8 @@ OpAttrFactory::OpAttrFactory() {
|
|||
{"AddFold", AddFoldPacker},
|
||||
{"ArgMax", ArgMaxPacker},
|
||||
{"BatchNorm", BatchNormFoldPacker},
|
||||
{"FakeQuantWithMinMax", FakeQuantWithMinMaxPacker},
|
||||
{"FakeQuantWithMinMaxPerChannel", FakeQuantWithMinMaxPerChannelPacker},
|
||||
{"FakeQuantPerLayer", FakeQuantWithMinMaxPacker},
|
||||
{"FakeQuantPerChannel", FakeQuantWithMinMaxPerChannelPacker},
|
||||
{"Mul", MulPacker},
|
||||
{"MulFold", MulFoldPacker},
|
||||
{"Squeeze", SqueezePacker}};
|
||||
|
|
|
@ -28,6 +28,9 @@ void StepConvertGraph(const KernelGraphPtr &kernel_graph_ptr) {
|
|||
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
|
||||
bool save_ms_model = MsContext::GetInstance()->save_ms_model_flag();
|
||||
if (save_ms_model) {
|
||||
if (kernel_graph_ptr->inputs().empty()) {
|
||||
return;
|
||||
}
|
||||
// set convert_mode: convert cpu info or convert Davnici
|
||||
executor::Kernel2Ms::GetInstance().set_convert_mode(executor::kConvertCpuMode);
|
||||
// convert kernel_graph to sub_ms_graph
|
||||
|
@ -46,6 +49,9 @@ void StepConvertWeight(const std::vector<tensor::TensorPtr> &inputs) {
|
|||
bool save_ms_model = MsContext::GetInstance()->save_ms_model_flag();
|
||||
std::string save_path = MsContext::GetInstance()->save_ms_model_path();
|
||||
if (save_ms_model) {
|
||||
if (inputs.empty()) {
|
||||
return;
|
||||
}
|
||||
MS_LOG(INFO) << "save ms model is true to path " << save_path;
|
||||
if (!executor::Kernel2Ms::GetInstance().KernelInput2MS(inputs)) {
|
||||
MS_LOG(WARNING) << "convert mindspore kernel input failed";
|
||||
|
|
Loading…
Reference in New Issue