!32491 fix dynamic_gru pass

Merge pull request !32491 from jiangzhenguang/fix_dynamicgru_pass
This commit is contained in:
i-robot 2022-04-12 06:45:54 +00:00 committed by Gitee
commit 4183a03188
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 6 additions and 5 deletions

View File

@ -37,7 +37,6 @@ size_t t_size = 0;
size_t batch_size = 0;
size_t hidden_size = 0;
size_t input_size = 0;
TypeId dh_dtype = kNumberTypeFloat32;
std::map<std::string, size_t> input_index = {
{"x", kIndex1}, {"weight_input", kIndex2}, {"weight_hidden", kIndex3},
@ -88,7 +87,7 @@ AnfNodePtr DynamicGRUV2GradFission::CreateGRUV2HiddenGradCellNode(const FuncGrap
}
(void)gru_v2_hidden_grad_cell_inputs.emplace_back(dynamic_gru_v2_grad_inputs[input_index["dy"]]);
auto input_dh = dynamic_gru_v2_grad_inputs[input_index["dh"]];
dh_dtype = common::AnfAlgo::GetOutputInferDataType(input_dh, 0);
auto dh_dtype = common::AnfAlgo::GetOutputInferDataType(input_dh, 0);
if (cur_t == 0) {
(void)gru_v2_hidden_grad_cell_inputs.emplace_back(input_dh);
} else {
@ -142,14 +141,15 @@ void DynamicGRUV2GradFission::AddTLoopNode(const FuncGraphPtr &func_graph, const
std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
weight_hidden};
auto reshape = NewCNode(reshape_inputs, func_graph);
auto weight_hidden_dtype = common::AnfAlgo::GetOutputInferDataType(weight_hidden, input_index["weight_hidden"]);
auto reshape_out_shape = {IntToSize(1), common::AnfAlgo::GetOutputInferShape(weight_hidden, 0)[0],
common::AnfAlgo::GetOutputInferShape(weight_hidden, 0)[1]};
common::AnfAlgo::SetOutputInferTypeAndShape({dh_dtype}, {reshape_out_shape}, reshape.get());
common::AnfAlgo::SetOutputInferTypeAndShape({weight_hidden_dtype}, {reshape_out_shape}, reshape.get());
(void)matmul_inputs.emplace_back(reshape);
auto matmul_node = NewCNode(matmul_inputs, func_graph);
MS_EXCEPTION_IF_NULL(matmul_node);
std::vector<size_t> out_shape = {1, batch_size, hidden_size};
common::AnfAlgo::SetOutputInferTypeAndShape({dh_dtype}, {out_shape}, matmul_node.get());
common::AnfAlgo::SetOutputInferTypeAndShape({weight_hidden_dtype}, {out_shape}, matmul_node.get());
common::AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(false), matmul_node);
common::AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(true), matmul_node);
@ -380,7 +380,8 @@ AnfNodePtr DynamicGRUV2GradFission::CreateDwxBatchMatMul(const FuncGraphPtr &gra
auto batch_matmul = NewCNode(matmul_inputs, graph);
MS_EXCEPTION_IF_NULL(batch_matmul);
std::vector<size_t> shape = {t_size, input_size, kGateNum * hidden_size};
common::AnfAlgo::SetOutputInferTypeAndShape({dh_dtype}, {shape}, batch_matmul.get());
auto x_dtype = common::AnfAlgo::GetOutputInferDataType(node1, input_index["x"]);
common::AnfAlgo::SetOutputInferTypeAndShape({x_dtype}, {shape}, batch_matmul.get());
common::AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(true), batch_matmul);
common::AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(false), batch_matmul);
common::AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), batch_matmul);