From a9e8058e9d2b36533e9ecdce3f568a87e1d5c969 Mon Sep 17 00:00:00 2001
From: yuchaojie
Date: Mon, 11 Oct 2021 12:05:56 +0800
Subject: [PATCH] fix dynamic_rnn_grad_fission_v2 output shape

---
 .../ir_fission/dynamic_rnn_grad_fission_v2.cc | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.cc
index 4d76e40d037..0e6909d43a3 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.cc
@@ -456,6 +456,20 @@ AnfNodePtr CreateDwReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dyn
   return reduce_sum;
 }
 
+AnfNodePtr CreateDwReshape(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
+                           const AnfNodePtr &batch_matmul) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  // Create node
+  std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
+                                            batch_matmul};
+  auto reshape = func_graph->NewCNode(reshape_inputs);
+  // Set infer data type and shape
+  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)},
+                                      {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 0)}, reshape.get());
+  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reshape);
+  return reshape;
+}
+
 AnfNodePtr CreateValueNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) {
   auto origin_input7 = dynamic_rnn_grad_cnode->input(kIndex8);
   auto origin_input7_shape = AnfAlgo::GetOutputInferShape(origin_input7, 0);
@@ -528,7 +542,8 @@ const AnfNodePtr DynamicRnnGradFissionV2::Process(const FuncGraphPtr &func_graph
     auto dw_reduce_sum = CreateDwReduceSum(func_graph, dynamic_rnn_grad_cnode, batch_matmul);
     make_tuple_inputs.emplace_back(dw_reduce_sum);
   } else {
-    make_tuple_inputs.emplace_back(batch_matmul);
+    auto dw_reshape = CreateDwReshape(func_graph, dynamic_rnn_grad_cnode, batch_matmul);
+    make_tuple_inputs.emplace_back(dw_reshape);
   }
 
   auto value_node = CreateValueNode(func_graph, dynamic_rnn_grad_cnode);