diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc index b0855feb490..ad6c58bc939 100755 --- a/mindspore/ccsrc/session/ascend_session.cc +++ b/mindspore/ccsrc/session/ascend_session.cc @@ -92,6 +92,51 @@ GraphId GetDistinctionLabel(const KernelGraphPtr &graph) { // else use first node of execution order as label return AnfAlgo::GetStreamDistinctionLabel(graph->execution_order()[0].get()); } + +std::vector GetRealArgs(const KernelGraphPtr graph, const VectorRef &args) { + MS_EXCEPTION_IF_NULL(graph); + std::vector graph_inputs = graph->inputs(); + auto valid_inputs = graph->ValidInputs(); + size_t real_args_size = 0; + std::vector real_args = {}; + for (size_t i = 0; i < args.size(); i++) { + if (utils::isa(args[i])) { + auto tmp_args = AnfAlgo::GetAllOutput(utils::cast(args[i]), {prim::kPrimTupleGetItem}); + for (auto &real_arg : tmp_args) { + auto anf_node = utils::cast(real_arg); + MS_EXCEPTION_IF_NULL(anf_node); + auto abstract = anf_node->abstract(); + MS_EXCEPTION_IF_NULL(abstract); + // create multiple parameters if is a tuple output real kernel + if (abstract->isa() && + !AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) { + auto tuple_abstract = abstract->cast(); + real_args_size += tuple_abstract->size(); + continue; + } + real_args_size += 1; + real_args.push_back(real_arg); + } + } else { + real_args_size += 1; + real_args.push_back(args[i]); + } + } + if (graph_inputs.size() != valid_inputs.size()) { + MS_LOG(EXCEPTION) << "graph_inputs.size(): " << graph_inputs.size() + << ", valid_inputs.size(): " << valid_inputs.size() << " not equal"; + } + if (real_args_size != graph_inputs.size()) { + for (size_t j = 0; j < valid_inputs.size(); j++) { + if (valid_inputs[j]) { + MS_LOG(INFO) << "index: " << j << ", nodes: " << graph_inputs[j]->DebugString(); + } + } + MS_LOG(WARNING) << "real_args_size: " << real_args_size << ", graph_inputs.size(): " << graph_inputs.size() + << " not equal"; + } + return real_args; +} } // namespace GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { @@ -763,38 +808,26 @@ void AscendSession::SetChildGraphInput(GraphId g, const VectorRef &args) { UpdateGraphOrder(g); std::vector graph_inputs = to_graph->inputs(); auto valid_inputs = to_graph->ValidInputs(); - size_t real_args_size = 0; - for (size_t i = 0; i < args.size(); i++) { - real_args_size += AnfAlgo::GetAllOutput(utils::cast(args[i]), {prim::kPrimTupleGetItem}).size(); - } - if (real_args_size != graph_inputs.size()) { - for (size_t j = 0; j < valid_inputs.size(); j++) { - if (valid_inputs[j]) { - MS_LOG(INFO) << "index: " << j << ", nodes: " << graph_inputs[j]->DebugString(); - } - } - MS_LOG(WARNING) << "real_args_size: " << real_args_size << ", graph_inputs.size(): " << graph_inputs.size() - << " not equal"; - } + auto real_args = GetRealArgs(to_graph, args); size_t input_index = 0; - if (graph_inputs.size() != valid_inputs.size()) { - MS_LOG(EXCEPTION) << "graph_inputs.size(): " << graph_inputs.size() - << ", valid_inputs.size(): " << valid_inputs.size() << " not equal"; - } - for (size_t i = 0; i < args.size(); i++) { + for (size_t i = 0; i < real_args.size(); i++) { if (input_index >= graph_inputs.size()) { MS_LOG(EXCEPTION) << "input_index " << input_index << " out of range size " << graph_inputs.size(); } - if (utils::isa(args[i])) { + if (utils::isa(real_args[i])) { // arg is a anf node - for (const auto &real_arg : AnfAlgo::GetAllOutput(utils::cast(args[i]), {prim::kPrimTupleGetItem})) { - if (!valid_inputs[input_index]) { - MS_LOG(DEBUG) << "Invalid input arg" << real_arg->DebugString(); - continue; - } - SetChildGraphParameter(real_arg, graph_inputs[input_index]); - input_index++; + auto real_arg = utils::cast(real_args[i]); + auto real_arg_output_num = AnfAlgo::GetOutputTensorNum(real_arg); + if (!AnfAlgo::CheckPrimitiveType(real_arg, prim::kPrimTupleGetItem) && real_arg_output_num > 1) { + input_index += real_arg_output_num; + continue; } + if (valid_inputs[input_index]) { + SetChildGraphParameter(real_arg, graph_inputs[input_index]); + } else { + MS_LOG(DEBUG) << "Invalid input arg" << real_arg->DebugString(); + } + input_index++; } else if (utils::isa(args[i])) { auto value = utils::cast(args[i]); MS_EXCEPTION_IF_NULL(value);