fix dynamic_rnn_grad_fission_v2 output shape

yuchaojie 2021-10-11 12:05:56 +08:00
parent bfb901e5cc
commit a9e8058e9d
1 changed file with 16 additions and 1 deletion

@@ -456,6 +456,20 @@ AnfNodePtr CreateDwReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dyn
   return reduce_sum;
 }
 
+AnfNodePtr CreateDwReshape(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
+                           const AnfNodePtr &batch_matmul) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  // Create node
+  std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
+                                            batch_matmul};
+  auto reshape = func_graph->NewCNode(reshape_inputs);
+  // Set infer data type and shape
+  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)},
+                                      {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 0)}, reshape.get());
+  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reshape);
+  return reshape;
+}
+
 AnfNodePtr CreateValueNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) {
   auto origin_input7 = dynamic_rnn_grad_cnode->input(kIndex8);
   auto origin_input7_shape = AnfAlgo::GetOutputInferShape(origin_input7, 0);
@@ -528,7 +542,8 @@ const AnfNodePtr DynamicRnnGradFissionV2::Process(const FuncGraphPtr &func_graph
     auto dw_reduce_sum = CreateDwReduceSum(func_graph, dynamic_rnn_grad_cnode, batch_matmul);
     make_tuple_inputs.emplace_back(dw_reduce_sum);
   } else {
-    make_tuple_inputs.emplace_back(batch_matmul);
+    auto dw_reshape = CreateDwReshape(func_graph, dynamic_rnn_grad_cnode, batch_matmul);
+    make_tuple_inputs.emplace_back(dw_reshape);
   }
   auto value_node = CreateValueNode(func_graph, dynamic_rnn_grad_cnode);
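For context, a minimal standalone C++ sketch of the shape mismatch this commit addresses. It is illustrative only: the concrete dimension values, the assumption that the else-branch corresponds to a time dimension of 1, and the helper names below are hypothetical and do not use MindSpore APIs. The point it shows is that the tuple element for dw previously exposed the BatchMatMul's own output shape, whereas the new CreateDwReshape node is annotated with the shape inferred for output 0 of the original DynamicRNNGrad node.

// Standalone sketch, not MindSpore code: shapes are hypothetical and the
// t_size == 1 condition is an assumption about when the else-branch runs.
#include <cstddef>
#include <iostream>
#include <vector>

using Shape = std::vector<size_t>;

// Stand-in for the shape produced by the pass's BatchMatMul:
// {t_size, input_size + hidden_size, 4 * hidden_size} (assumed layout).
Shape BatchMatMulShape(size_t t_size, size_t input_size, size_t hidden_size) {
  return {t_size, input_size + hidden_size, 4 * hidden_size};
}

// Stand-in for AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 0): the
// dw shape the rest of the graph expects.
Shape ExpectedDwShape(size_t input_size, size_t hidden_size) {
  return {input_size + hidden_size, 4 * hidden_size};
}

int main() {
  const size_t t_size = 1, input_size = 16, hidden_size = 32;
  const Shape produced = BatchMatMulShape(t_size, input_size, hidden_size);  // {1, 48, 128}
  const Shape expected = ExpectedDwShape(input_size, hidden_size);           // {48, 128}
  // Before the fix the else-branch forwarded the BatchMatMul node directly,
  // so the dw output kept the extra leading dimension; the inserted Reshape
  // is annotated with `expected`, matching DynamicRNNGrad's inferred output 0.
  std::cout << "produced rank = " << produced.size()
            << ", expected rank = " << expected.size() << '\n';
  return 0;
}

Whatever the actual dimension values are at runtime, the effect of the fix is visible in the diff itself: the dw element of the output tuple now carries the inferred type and shape copied from output 0 of dynamic_rnn_grad_cnode rather than the shape of the intermediate BatchMatMul.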