!22296 resolve issue:pynative use graph loaded from mindir

Merge pull request !22296 from lanzhineng/infer_optv3
This commit is contained in:
i-robot 2021-08-24 08:09:25 +00:00 committed by Gitee
commit af6d16ec14
1 changed files with 81 additions and 0 deletions

View File

@ -196,6 +196,81 @@ FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
return ret;
}
const FuncGraphPtr GetLoadedGraph(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
auto manager = res->manager();
MS_EXCEPTION_IF_NULL(manager);
FuncGraphPtr loaded_graph = nullptr;
size_t loaded_graph_num = 0;
auto all_graphs = manager->func_graphs();
for (auto &graph : all_graphs) {
MS_EXCEPTION_IF_NULL(graph);
if (graph->has_attr("is_load")) {
loaded_graph = graph;
loaded_graph_num += 1;
res->set_is_load(true);
}
}
if (loaded_graph_num == 0) {
return nullptr;
}
if (loaded_graph_num == 1) {
return loaded_graph;
}
MS_LOG(EXCEPTION) << "The loaded sub graph currently should less than 2, but got " << loaded_graph_num;
}
void CheckRootInputShapeAndType(const ResourcePtr &res, const FuncGraphPtr &loaded_graph) {
MS_EXCEPTION_IF_NULL(res);
auto manager = res->manager();
MS_EXCEPTION_IF_NULL(manager);
FuncGraphPtr root_graph = *(manager->roots().begin());
auto root_inputs = root_graph->get_inputs();
auto loaded_inputs = loaded_graph->get_inputs();
MS_LOG(DEBUG) << "root_graph: " << root_graph->ToString();
MS_LOG(DEBUG) << "loaded_graph: " << loaded_graph->ToString();
size_t root_inputs_num = root_inputs.size();
size_t loaded_inputs_num = loaded_inputs.size();
if (root_inputs_num != loaded_inputs_num) {
MS_LOG(EXCEPTION) << "The inputs number " << root_inputs_num << " not equal to the inputs number of loaded graph "
<< loaded_inputs_num;
}
for (size_t index = 0; index < root_inputs_num; index++) {
auto root_input = root_inputs[index];
auto loaded_input = loaded_inputs[index];
MS_LOG(DEBUG) << "root_input[" << index << "]: " << root_input->DebugString(1);
MS_LOG(DEBUG) << "loaded_input[" << index << "]: " << loaded_input->DebugString(1);
MS_LOG(DEBUG) << "root_input abstract[" << index
<< "]: " << (root_input->abstract() ? root_input->abstract()->ToString() : "NULL");
MS_LOG(DEBUG) << "loaded_input abstract [" << index
<< "]: " << (loaded_input->abstract() ? loaded_input->abstract()->ToString() : "NULL");
auto root_shape = root_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(root_input->Shape());
auto loaded_shape = loaded_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(loaded_input->Shape());
auto root_type = root_input->Type() == nullptr ? nullptr : dyn_cast<Type>(root_input->Type());
auto loaded_type = loaded_input->Type() == nullptr ? nullptr : dyn_cast<Type>(loaded_input->Type());
MS_EXCEPTION_IF_NULL(root_shape);
MS_EXCEPTION_IF_NULL(loaded_shape);
MS_EXCEPTION_IF_NULL(root_type);
MS_EXCEPTION_IF_NULL(loaded_type);
auto shapeEqu = (root_shape->shape() == loaded_shape->shape()) ||
(root_shape->shape().size() <= 1 && loaded_shape->shape().size() <= 1);
if (!shapeEqu) {
MS_EXCEPTION(ValueError) << "The " << index
<< " th input shape differ from loaded graph. Input shape: " << root_shape->ToString()
<< ", input shape of loaded graph: " << loaded_shape->ToString();
}
if (root_type->type_id() != loaded_type->type_id()) {
MS_EXCEPTION(TypeError) << "The " << std::to_string(index)
<< " th input type differ from loaded graph. Input type: " << root_type->ToString()
<< ", input type of loaded graph: " << loaded_type->ToString();
}
}
}
bool ParseAction(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
if (!res->input()) {
@ -378,6 +453,8 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
context->ParallelParameterContextInitShape(func_graph);
// Get original loaded graph to check inputs later
auto loaded_graph_ptr = GetLoadedGraph(res);
// suppose that there is not KeywordArgument for the top graph
// get the hyper parameter
for (const auto &param : func_graph->parameters()) {
@ -414,6 +491,10 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
}
}
}
// Check input after abstract when there is a loaded graph
if (loaded_graph_ptr != nullptr) {
CheckRootInputShapeAndType(res, loaded_graph_ptr);
}
MS_LOG(DEBUG) << "End graph: " << new_fg->ToString() << ", return: " << new_fg->get_return()->DebugString(true);
return true;
}