From 8de3427b84697fef3e58a1e2fe6759eb0eded28c Mon Sep 17 00:00:00 2001
From: jiangzhenguang
Date: Sat, 2 Apr 2022 15:01:47 +0800
Subject: [PATCH] fix dynamic_gru

Stop caching dh_dtype in file-scope state: derive the reshape/matmul
output dtype from weight_hidden, and the dwx batch_matmul output dtype
from x, instead of reusing the dtype of dh captured in a global.

---
 .../ir_fission/dynamic_gru_v2_grad_fission.cc | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/dynamic_gru_v2_grad_fission.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/dynamic_gru_v2_grad_fission.cc
index 7fb041b4f91..fdaa67f1713 100644
--- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/dynamic_gru_v2_grad_fission.cc
+++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/dynamic_gru_v2_grad_fission.cc
@@ -37,7 +37,6 @@
 size_t t_size = 0;
 size_t batch_size = 0;
 size_t hidden_size = 0;
 size_t input_size = 0;
-TypeId dh_dtype = kNumberTypeFloat32;
 std::map<std::string, size_t> input_index = {
   {"x", kIndex1}, {"weight_input", kIndex2}, {"weight_hidden", kIndex3},
@@ -88,7 +87,7 @@ AnfNodePtr DynamicGRUV2GradFission::CreateGRUV2HiddenGradCellNode(const FuncGrap
   }
   (void)gru_v2_hidden_grad_cell_inputs.emplace_back(dynamic_gru_v2_grad_inputs[input_index["dy"]]);
   auto input_dh = dynamic_gru_v2_grad_inputs[input_index["dh"]];
-  dh_dtype = common::AnfAlgo::GetOutputInferDataType(input_dh, 0);
+  auto dh_dtype = common::AnfAlgo::GetOutputInferDataType(input_dh, 0);
   if (cur_t == 0) {
     (void)gru_v2_hidden_grad_cell_inputs.emplace_back(input_dh);
   } else {
@@ -142,14 +141,15 @@ void DynamicGRUV2GradFission::AddTLoopNode(const FuncGraphPtr &func_graph, const
   std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                             weight_hidden};
   auto reshape = NewCNode(reshape_inputs, func_graph);
+  auto weight_hidden_dtype = common::AnfAlgo::GetOutputInferDataType(weight_hidden, input_index["weight_hidden"]);
   auto reshape_out_shape = {IntToSize(1), common::AnfAlgo::GetOutputInferShape(weight_hidden, 0)[0],
                            common::AnfAlgo::GetOutputInferShape(weight_hidden, 0)[1]};
-  common::AnfAlgo::SetOutputInferTypeAndShape({dh_dtype}, {reshape_out_shape}, reshape.get());
+  common::AnfAlgo::SetOutputInferTypeAndShape({weight_hidden_dtype}, {reshape_out_shape}, reshape.get());
   (void)matmul_inputs.emplace_back(reshape);
   auto matmul_node = NewCNode(matmul_inputs, func_graph);
   MS_EXCEPTION_IF_NULL(matmul_node);
   std::vector<size_t> out_shape = {1, batch_size, hidden_size};
-  common::AnfAlgo::SetOutputInferTypeAndShape({dh_dtype}, {out_shape}, matmul_node.get());
+  common::AnfAlgo::SetOutputInferTypeAndShape({weight_hidden_dtype}, {out_shape}, matmul_node.get());
   common::AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(false), matmul_node);
   common::AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(true), matmul_node);
@@ -380,7 +380,8 @@ AnfNodePtr DynamicGRUV2GradFission::CreateDwxBatchMatMul(const FuncGraphPtr &gra
   auto batch_matmul = NewCNode(matmul_inputs, graph);
   MS_EXCEPTION_IF_NULL(batch_matmul);
   std::vector<size_t> shape = {t_size, input_size, kGateNum * hidden_size};
-  common::AnfAlgo::SetOutputInferTypeAndShape({dh_dtype}, {shape}, batch_matmul.get());
+  auto x_dtype = common::AnfAlgo::GetOutputInferDataType(node1, input_index["x"]);
+  common::AnfAlgo::SetOutputInferTypeAndShape({x_dtype}, {shape}, batch_matmul.get());
   common::AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(true), batch_matmul);
   common::AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(false), batch_matmul);
   common::AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), batch_matmul);