forked from mindspore-Ecosystem/mindspore
!22296 resolve issue:pynative use graph loaded from mindir
Merge pull request !22296 from lanzhineng/infer_optv3
This commit is contained in:
commit
af6d16ec14
|
@ -196,6 +196,81 @@ FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
|
|||
return ret;
|
||||
}
|
||||
|
||||
const FuncGraphPtr GetLoadedGraph(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
auto manager = res->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
FuncGraphPtr loaded_graph = nullptr;
|
||||
size_t loaded_graph_num = 0;
|
||||
auto all_graphs = manager->func_graphs();
|
||||
for (auto &graph : all_graphs) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
if (graph->has_attr("is_load")) {
|
||||
loaded_graph = graph;
|
||||
loaded_graph_num += 1;
|
||||
res->set_is_load(true);
|
||||
}
|
||||
}
|
||||
if (loaded_graph_num == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
if (loaded_graph_num == 1) {
|
||||
return loaded_graph;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "The loaded sub graph currently should less than 2, but got " << loaded_graph_num;
|
||||
}
|
||||
|
||||
void CheckRootInputShapeAndType(const ResourcePtr &res, const FuncGraphPtr &loaded_graph) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
auto manager = res->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
FuncGraphPtr root_graph = *(manager->roots().begin());
|
||||
auto root_inputs = root_graph->get_inputs();
|
||||
auto loaded_inputs = loaded_graph->get_inputs();
|
||||
MS_LOG(DEBUG) << "root_graph: " << root_graph->ToString();
|
||||
MS_LOG(DEBUG) << "loaded_graph: " << loaded_graph->ToString();
|
||||
size_t root_inputs_num = root_inputs.size();
|
||||
size_t loaded_inputs_num = loaded_inputs.size();
|
||||
if (root_inputs_num != loaded_inputs_num) {
|
||||
MS_LOG(EXCEPTION) << "The inputs number " << root_inputs_num << " not equal to the inputs number of loaded graph "
|
||||
<< loaded_inputs_num;
|
||||
}
|
||||
for (size_t index = 0; index < root_inputs_num; index++) {
|
||||
auto root_input = root_inputs[index];
|
||||
auto loaded_input = loaded_inputs[index];
|
||||
|
||||
MS_LOG(DEBUG) << "root_input[" << index << "]: " << root_input->DebugString(1);
|
||||
MS_LOG(DEBUG) << "loaded_input[" << index << "]: " << loaded_input->DebugString(1);
|
||||
MS_LOG(DEBUG) << "root_input abstract[" << index
|
||||
<< "]: " << (root_input->abstract() ? root_input->abstract()->ToString() : "NULL");
|
||||
MS_LOG(DEBUG) << "loaded_input abstract [" << index
|
||||
<< "]: " << (loaded_input->abstract() ? loaded_input->abstract()->ToString() : "NULL");
|
||||
|
||||
auto root_shape = root_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(root_input->Shape());
|
||||
auto loaded_shape = loaded_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(loaded_input->Shape());
|
||||
auto root_type = root_input->Type() == nullptr ? nullptr : dyn_cast<Type>(root_input->Type());
|
||||
auto loaded_type = loaded_input->Type() == nullptr ? nullptr : dyn_cast<Type>(loaded_input->Type());
|
||||
|
||||
MS_EXCEPTION_IF_NULL(root_shape);
|
||||
MS_EXCEPTION_IF_NULL(loaded_shape);
|
||||
MS_EXCEPTION_IF_NULL(root_type);
|
||||
MS_EXCEPTION_IF_NULL(loaded_type);
|
||||
|
||||
auto shapeEqu = (root_shape->shape() == loaded_shape->shape()) ||
|
||||
(root_shape->shape().size() <= 1 && loaded_shape->shape().size() <= 1);
|
||||
if (!shapeEqu) {
|
||||
MS_EXCEPTION(ValueError) << "The " << index
|
||||
<< " th input shape differ from loaded graph. Input shape: " << root_shape->ToString()
|
||||
<< ", input shape of loaded graph: " << loaded_shape->ToString();
|
||||
}
|
||||
if (root_type->type_id() != loaded_type->type_id()) {
|
||||
MS_EXCEPTION(TypeError) << "The " << std::to_string(index)
|
||||
<< " th input type differ from loaded graph. Input type: " << root_type->ToString()
|
||||
<< ", input type of loaded graph: " << loaded_type->ToString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool ParseAction(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
if (!res->input()) {
|
||||
|
@ -378,6 +453,8 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
|
|||
MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
|
||||
context->ParallelParameterContextInitShape(func_graph);
|
||||
|
||||
// Get original loaded graph to check inputs later
|
||||
auto loaded_graph_ptr = GetLoadedGraph(res);
|
||||
// suppose that there is not KeywordArgument for the top graph
|
||||
// get the hyper parameter
|
||||
for (const auto ¶m : func_graph->parameters()) {
|
||||
|
@ -414,6 +491,10 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
|
|||
}
|
||||
}
|
||||
}
|
||||
// Check input after abstract when there is a loaded graph
|
||||
if (loaded_graph_ptr != nullptr) {
|
||||
CheckRootInputShapeAndType(res, loaded_graph_ptr);
|
||||
}
|
||||
MS_LOG(DEBUG) << "End graph: " << new_fg->ToString() << ", return: " << new_fg->get_return()->DebugString(true);
|
||||
return true;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue