diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index b77878dec0a..ce70fbc697c 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -196,6 +196,81 @@ FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph, return ret; } +const FuncGraphPtr GetLoadedGraph(const ResourcePtr &res) { + MS_EXCEPTION_IF_NULL(res); + auto manager = res->manager(); + MS_EXCEPTION_IF_NULL(manager); + FuncGraphPtr loaded_graph = nullptr; + size_t loaded_graph_num = 0; + auto all_graphs = manager->func_graphs(); + for (auto &graph : all_graphs) { + MS_EXCEPTION_IF_NULL(graph); + if (graph->has_attr("is_load")) { + loaded_graph = graph; + loaded_graph_num += 1; + res->set_is_load(true); + } + } + if (loaded_graph_num == 0) { + return nullptr; + } + if (loaded_graph_num == 1) { + return loaded_graph; + } + MS_LOG(EXCEPTION) << "The loaded sub graph currently should less than 2, but got " << loaded_graph_num; +} + +void CheckRootInputShapeAndType(const ResourcePtr &res, const FuncGraphPtr &loaded_graph) { + MS_EXCEPTION_IF_NULL(res); + auto manager = res->manager(); + MS_EXCEPTION_IF_NULL(manager); + FuncGraphPtr root_graph = *(manager->roots().begin()); + auto root_inputs = root_graph->get_inputs(); + auto loaded_inputs = loaded_graph->get_inputs(); + MS_LOG(DEBUG) << "root_graph: " << root_graph->ToString(); + MS_LOG(DEBUG) << "loaded_graph: " << loaded_graph->ToString(); + size_t root_inputs_num = root_inputs.size(); + size_t loaded_inputs_num = loaded_inputs.size(); + if (root_inputs_num != loaded_inputs_num) { + MS_LOG(EXCEPTION) << "The inputs number " << root_inputs_num << " not equal to the inputs number of loaded graph " + << loaded_inputs_num; + } + for (size_t index = 0; index < root_inputs_num; index++) { + auto root_input = root_inputs[index]; + auto loaded_input = loaded_inputs[index]; + + MS_LOG(DEBUG) << "root_input[" << index << "]: " << root_input->DebugString(1); + MS_LOG(DEBUG) << "loaded_input[" << index << "]: " << loaded_input->DebugString(1); + MS_LOG(DEBUG) << "root_input abstract[" << index + << "]: " << (root_input->abstract() ? root_input->abstract()->ToString() : "NULL"); + MS_LOG(DEBUG) << "loaded_input abstract [" << index + << "]: " << (loaded_input->abstract() ? loaded_input->abstract()->ToString() : "NULL"); + + auto root_shape = root_input->Shape() == nullptr ? nullptr : dyn_cast(root_input->Shape()); + auto loaded_shape = loaded_input->Shape() == nullptr ? nullptr : dyn_cast(loaded_input->Shape()); + auto root_type = root_input->Type() == nullptr ? nullptr : dyn_cast(root_input->Type()); + auto loaded_type = loaded_input->Type() == nullptr ? nullptr : dyn_cast(loaded_input->Type()); + + MS_EXCEPTION_IF_NULL(root_shape); + MS_EXCEPTION_IF_NULL(loaded_shape); + MS_EXCEPTION_IF_NULL(root_type); + MS_EXCEPTION_IF_NULL(loaded_type); + + auto shapeEqu = (root_shape->shape() == loaded_shape->shape()) || + (root_shape->shape().size() <= 1 && loaded_shape->shape().size() <= 1); + if (!shapeEqu) { + MS_EXCEPTION(ValueError) << "The " << index + << " th input shape differ from loaded graph. Input shape: " << root_shape->ToString() + << ", input shape of loaded graph: " << loaded_shape->ToString(); + } + if (root_type->type_id() != loaded_type->type_id()) { + MS_EXCEPTION(TypeError) << "The " << std::to_string(index) + << " th input type differ from loaded graph. Input type: " << root_type->ToString() + << ", input type of loaded graph: " << loaded_type->ToString(); + } + } +} + bool ParseAction(const ResourcePtr &res) { MS_EXCEPTION_IF_NULL(res); if (!res->input()) { @@ -378,6 +453,8 @@ bool AbstractSpecializeAction(const ResourcePtr &res) { MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance()); context->ParallelParameterContextInitShape(func_graph); + // Get original loaded graph to check inputs later + auto loaded_graph_ptr = GetLoadedGraph(res); // suppose that there is not KeywordArgument for the top graph // get the hyper parameter for (const auto ¶m : func_graph->parameters()) { @@ -414,6 +491,10 @@ bool AbstractSpecializeAction(const ResourcePtr &res) { } } } + // Check input after abstract when there is a loaded graph + if (loaded_graph_ptr != nullptr) { + CheckRootInputShapeAndType(res, loaded_graph_ptr); + } MS_LOG(DEBUG) << "End graph: " << new_fg->ToString() << ", return: " << new_fg->get_return()->DebugString(true); return true; }