fix dynamic_rnn_grad_fission_v2 output shape
This commit is contained in:
parent
bfb901e5cc
commit
a9e8058e9d
|
@ -456,6 +456,20 @@ AnfNodePtr CreateDwReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dyn
|
|||
return reduce_sum;
|
||||
}
|
||||
|
||||
AnfNodePtr CreateDwReshape(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
|
||||
const AnfNodePtr &batch_matmul) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
// Create node
|
||||
std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
|
||||
batch_matmul};
|
||||
auto reshape = func_graph->NewCNode(reshape_inputs);
|
||||
// Set infer data type and shape
|
||||
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)},
|
||||
{AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 0)}, reshape.get());
|
||||
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reshape);
|
||||
return reshape;
|
||||
}
|
||||
|
||||
AnfNodePtr CreateValueNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) {
|
||||
auto origin_input7 = dynamic_rnn_grad_cnode->input(kIndex8);
|
||||
auto origin_input7_shape = AnfAlgo::GetOutputInferShape(origin_input7, 0);
|
||||
|
@ -528,7 +542,8 @@ const AnfNodePtr DynamicRnnGradFissionV2::Process(const FuncGraphPtr &func_graph
|
|||
auto dw_reduce_sum = CreateDwReduceSum(func_graph, dynamic_rnn_grad_cnode, batch_matmul);
|
||||
make_tuple_inputs.emplace_back(dw_reduce_sum);
|
||||
} else {
|
||||
make_tuple_inputs.emplace_back(batch_matmul);
|
||||
auto dw_reshape = CreateDwReshape(func_graph, dynamic_rnn_grad_cnode, batch_matmul);
|
||||
make_tuple_inputs.emplace_back(dw_reshape);
|
||||
}
|
||||
|
||||
auto value_node = CreateValueNode(func_graph, dynamic_rnn_grad_cnode);
|
||||
|
|
Loading…
Reference in New Issue