!429 Fix bug of parameter nums don't match with input args when set child graph input
Merge pull request !429 from chenfei/expand-tuple-output-of-node-when-set-child-graph-input
This commit is contained in:
commit
329ddbeb32
|
@ -92,6 +92,51 @@ GraphId GetDistinctionLabel(const KernelGraphPtr &graph) {
|
|||
// else use first node of execution order as label
|
||||
return AnfAlgo::GetStreamDistinctionLabel(graph->execution_order()[0].get());
|
||||
}
|
||||
|
||||
std::vector<BaseRef> GetRealArgs(const KernelGraphPtr graph, const VectorRef &args) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
std::vector<AnfNodePtr> graph_inputs = graph->inputs();
|
||||
auto valid_inputs = graph->ValidInputs();
|
||||
size_t real_args_size = 0;
|
||||
std::vector<BaseRef> real_args = {};
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
if (utils::isa<AnfNodePtr>(args[i])) {
|
||||
auto tmp_args = AnfAlgo::GetAllOutput(utils::cast<AnfNodePtr>(args[i]), {prim::kPrimTupleGetItem});
|
||||
for (auto &real_arg : tmp_args) {
|
||||
auto anf_node = utils::cast<AnfNodePtr>(real_arg);
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
auto abstract = anf_node->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
// create multiple parameters if is a tuple output real kernel
|
||||
if (abstract->isa<abstract::AbstractTuple>() &&
|
||||
!AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
|
||||
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
|
||||
real_args_size += tuple_abstract->size();
|
||||
continue;
|
||||
}
|
||||
real_args_size += 1;
|
||||
real_args.push_back(real_arg);
|
||||
}
|
||||
} else {
|
||||
real_args_size += 1;
|
||||
real_args.push_back(args[i]);
|
||||
}
|
||||
}
|
||||
if (graph_inputs.size() != valid_inputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "graph_inputs.size(): " << graph_inputs.size()
|
||||
<< ", valid_inputs.size(): " << valid_inputs.size() << " not equal";
|
||||
}
|
||||
if (real_args_size != graph_inputs.size()) {
|
||||
for (size_t j = 0; j < valid_inputs.size(); j++) {
|
||||
if (valid_inputs[j]) {
|
||||
MS_LOG(INFO) << "index: " << j << ", nodes: " << graph_inputs[j]->DebugString();
|
||||
}
|
||||
}
|
||||
MS_LOG(WARNING) << "real_args_size: " << real_args_size << ", graph_inputs.size(): " << graph_inputs.size()
|
||||
<< " not equal";
|
||||
}
|
||||
return real_args;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
|
||||
|
@ -763,38 +808,26 @@ void AscendSession::SetChildGraphInput(GraphId g, const VectorRef &args) {
|
|||
UpdateGraphOrder(g);
|
||||
std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
|
||||
auto valid_inputs = to_graph->ValidInputs();
|
||||
size_t real_args_size = 0;
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
real_args_size += AnfAlgo::GetAllOutput(utils::cast<AnfNodePtr>(args[i]), {prim::kPrimTupleGetItem}).size();
|
||||
}
|
||||
if (real_args_size != graph_inputs.size()) {
|
||||
for (size_t j = 0; j < valid_inputs.size(); j++) {
|
||||
if (valid_inputs[j]) {
|
||||
MS_LOG(INFO) << "index: " << j << ", nodes: " << graph_inputs[j]->DebugString();
|
||||
}
|
||||
}
|
||||
MS_LOG(WARNING) << "real_args_size: " << real_args_size << ", graph_inputs.size(): " << graph_inputs.size()
|
||||
<< " not equal";
|
||||
}
|
||||
auto real_args = GetRealArgs(to_graph, args);
|
||||
size_t input_index = 0;
|
||||
if (graph_inputs.size() != valid_inputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "graph_inputs.size(): " << graph_inputs.size()
|
||||
<< ", valid_inputs.size(): " << valid_inputs.size() << " not equal";
|
||||
}
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
for (size_t i = 0; i < real_args.size(); i++) {
|
||||
if (input_index >= graph_inputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "input_index " << input_index << " out of range size " << graph_inputs.size();
|
||||
}
|
||||
if (utils::isa<AnfNodePtr>(args[i])) {
|
||||
if (utils::isa<AnfNodePtr>(real_args[i])) {
|
||||
// arg is a anf node
|
||||
for (const auto &real_arg : AnfAlgo::GetAllOutput(utils::cast<AnfNodePtr>(args[i]), {prim::kPrimTupleGetItem})) {
|
||||
if (!valid_inputs[input_index]) {
|
||||
MS_LOG(DEBUG) << "Invalid input arg" << real_arg->DebugString();
|
||||
continue;
|
||||
}
|
||||
SetChildGraphParameter(real_arg, graph_inputs[input_index]);
|
||||
input_index++;
|
||||
auto real_arg = utils::cast<AnfNodePtr>(real_args[i]);
|
||||
auto real_arg_output_num = AnfAlgo::GetOutputTensorNum(real_arg);
|
||||
if (!AnfAlgo::CheckPrimitiveType(real_arg, prim::kPrimTupleGetItem) && real_arg_output_num > 1) {
|
||||
input_index += real_arg_output_num;
|
||||
continue;
|
||||
}
|
||||
if (valid_inputs[input_index]) {
|
||||
SetChildGraphParameter(real_arg, graph_inputs[input_index]);
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Invalid input arg" << real_arg->DebugString();
|
||||
}
|
||||
input_index++;
|
||||
} else if (utils::isa<ValuePtr>(args[i])) {
|
||||
auto value = utils::cast<ValuePtr>(args[i]);
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
|
|
Loading…
Reference in New Issue