!10795 Fix-bug-of-multiple-tuple-output-by-connect-cnode-to-return

From: @joylvliang
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-12-30 17:10:25 +08:00 committed by Gitee
commit 439ef332b8
3 changed files with 9 additions and 4 deletions

View File

@ -817,7 +817,7 @@ void AscendSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &g
}
// construct graph include one op
auto graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
auto graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask, true);
MS_EXCEPTION_IF_NULL(graph);
opt::RunOpAscendBackendIRFusionOptimization(graph);
// kernel select

View File

@ -1569,7 +1569,8 @@ void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr
std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask) {
const std::vector<int64_t> &tensors_mask,
bool is_ascend) {
auto graph = std::make_shared<KernelGraph>();
graph->set_graph_id(graph_sum_);
graph_sum_++;
@ -1612,7 +1613,11 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf
graph->set_execution_order(exe_order);
graph->UpdateGraphDynamicAttr();
// set output
CreateOutputNode(cnode, graph);
if (is_ascend) {
graph->set_output(cnode);
} else {
CreateOutputNode(cnode, graph);
}
graph->SetInputNodes();
auto manager = MakeManager({graph});
if (manager != nullptr) {

View File

@ -180,7 +180,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
// create a single run op graph
std::shared_ptr<KernelGraph> ConstructSingleOpGraph(const OpRunInfo &op_run_info,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask);
const std::vector<int64_t> &tensors_mask, bool is_ascend = false);
// create a new kernel graph and update the graph sum
KernelGraphPtr NewKernelGraph();
std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph);