!32491 fix dynamic_gru pass
Merge pull request !32491 from jiangzhenguang/fix_dynamicgru_pass
Commit: 4183a03188
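The root cause is easiest to see outside the pass: below is a minimal, self-contained C++ sketch of the before/after pattern. The Node and DType types, Name, FissionPassCached, and RunLocal are all invented for illustration; only the idea (cache a dtype in shared state versus read it from the producing node) comes from the diff.

// A minimal, self-contained sketch of the pattern this PR moves away from.
// Node, DType, and the names below are invented for illustration; they are
// not MindSpore APIs.
#include <cstdio>
#include <string>

enum class DType { kFloat16, kFloat32 };

const char *Name(DType d) { return d == DType::kFloat16 ? "float16" : "float32"; }

struct Node {
  std::string name;
  DType dtype;
};

// Buggy shape: the pass keeps one dtype in shared state, fills it from dh,
// and stamps it onto every node it creates -- including nodes whose data
// actually comes from x or weight_hidden.
struct FissionPassCached {
  DType dh_dtype = DType::kFloat32;  // shared, default-initialized
  void Run(const Node &dh, const Node &x) {
    dh_dtype = dh.dtype;  // set once from dh
    std::printf("cached: dwx output tagged %s, data is %s\n", Name(dh_dtype), Name(x.dtype));
  }
};

// Fixed shape: each created node derives its dtype from the input that feeds
// it, at the point of creation (what GetOutputInferDataType does in the PR).
void RunLocal(const Node &dh, const Node &x) {
  (void)dh;
  DType dwx_dtype = x.dtype;  // follows the real producer
  std::printf("local:  dwx output tagged %s, data is %s\n", Name(dwx_dtype), Name(x.dtype));
}

int main() {
  Node dh{"dh", DType::kFloat32};
  Node x{"x", DType::kFloat16};
  FissionPassCached{}.Run(dh, x);  // tags float32 although the data is float16
  RunLocal(dh, x);                 // tags float16, matching the data
  return 0;
}

Run on a graph where x is float16 but dh is float32, the cached variant mislabels the x-derived output; the local variant follows the real producer, which is what the rewritten call sites below do via GetOutputInferDataType.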
@@ -37,7 +37,6 @@ size_t t_size = 0;
 size_t batch_size = 0;
 size_t hidden_size = 0;
 size_t input_size = 0;
-TypeId dh_dtype = kNumberTypeFloat32;
 
 std::map<std::string, size_t> input_index = {
   {"x", kIndex1}, {"weight_input", kIndex2}, {"weight_hidden", kIndex3},
@@ -88,7 +87,7 @@ AnfNodePtr DynamicGRUV2GradFission::CreateGRUV2HiddenGradCellNode(const FuncGrap
   }
   (void)gru_v2_hidden_grad_cell_inputs.emplace_back(dynamic_gru_v2_grad_inputs[input_index["dy"]]);
   auto input_dh = dynamic_gru_v2_grad_inputs[input_index["dh"]];
-  dh_dtype = common::AnfAlgo::GetOutputInferDataType(input_dh, 0);
+  auto dh_dtype = common::AnfAlgo::GetOutputInferDataType(input_dh, 0);
   if (cur_t == 0) {
     (void)gru_v2_hidden_grad_cell_inputs.emplace_back(input_dh);
   } else {
@@ -142,14 +141,15 @@ void DynamicGRUV2GradFission::AddTLoopNode(const FuncGraphPtr &func_graph, const
   std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                             weight_hidden};
   auto reshape = NewCNode(reshape_inputs, func_graph);
+  auto weight_hidden_dtype = common::AnfAlgo::GetOutputInferDataType(weight_hidden, input_index["weight_hidden"]);
   auto reshape_out_shape = {IntToSize(1), common::AnfAlgo::GetOutputInferShape(weight_hidden, 0)[0],
                            common::AnfAlgo::GetOutputInferShape(weight_hidden, 0)[1]};
-  common::AnfAlgo::SetOutputInferTypeAndShape({dh_dtype}, {reshape_out_shape}, reshape.get());
+  common::AnfAlgo::SetOutputInferTypeAndShape({weight_hidden_dtype}, {reshape_out_shape}, reshape.get());
   (void)matmul_inputs.emplace_back(reshape);
   auto matmul_node = NewCNode(matmul_inputs, func_graph);
   MS_EXCEPTION_IF_NULL(matmul_node);
   std::vector<size_t> out_shape = {1, batch_size, hidden_size};
-  common::AnfAlgo::SetOutputInferTypeAndShape({dh_dtype}, {out_shape}, matmul_node.get());
+  common::AnfAlgo::SetOutputInferTypeAndShape({weight_hidden_dtype}, {out_shape}, matmul_node.get());
   common::AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(false), matmul_node);
   common::AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(true), matmul_node);
 
@@ -380,7 +380,8 @@ AnfNodePtr DynamicGRUV2GradFission::CreateDwxBatchMatMul(const FuncGraphPtr &gra
   auto batch_matmul = NewCNode(matmul_inputs, graph);
   MS_EXCEPTION_IF_NULL(batch_matmul);
   std::vector<size_t> shape = {t_size, input_size, kGateNum * hidden_size};
-  common::AnfAlgo::SetOutputInferTypeAndShape({dh_dtype}, {shape}, batch_matmul.get());
+  auto x_dtype = common::AnfAlgo::GetOutputInferDataType(node1, input_index["x"]);
+  common::AnfAlgo::SetOutputInferTypeAndShape({x_dtype}, {shape}, batch_matmul.get());
   common::AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(true), batch_matmul);
   common::AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(false), batch_matmul);
   common::AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), batch_matmul);
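Taken together: the first hunk deletes the now-unused shared dh_dtype declaration; the second turns the assignment into a local (auto dh_dtype = ...), scoped to the hidden-grad cell that actually consumes dh; the third tags the weight_hidden reshape and the per-step matmul with weight_hidden's dtype; and the fourth tags the dwx batch-matmul with x's dtype, so every node the fission inserts now carries the dtype of the data that flows through it.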