|
|
|
@ -27,6 +27,13 @@ namespace {
|
|
|
|
|
constexpr size_t kDynamicRNNGradInputNum = 16;
|
|
|
|
|
constexpr size_t kSplitVOutputNum = 2;
|
|
|
|
|
constexpr size_t kBasicCellOutputNum = 2;
|
|
|
|
|
constexpr size_t kBasicLstmCStateGradOutput0DimNum = 3;
|
|
|
|
|
constexpr int64_t kAttrNValue = 2;
|
|
|
|
|
constexpr int64_t kAttrDynInputSizesValue = 2;
|
|
|
|
|
constexpr int64_t kAttrAxis2Value = 2;
|
|
|
|
|
constexpr int64_t kAttrNumSplitValue = 2;
|
|
|
|
|
constexpr int64_t kAttrSplitDimValue = 2;
|
|
|
|
|
constexpr size_t kDimMultiNum = 4;
|
|
|
|
|
|
|
|
|
|
void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
|
|
|
|
|
std::vector<std::vector<AnfNodePtr>> *result_nodes) {
|
|
|
|
@ -47,8 +54,9 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
|
|
|
|
|
NewValueNode(std::make_shared<Primitive>(kBasicLSTMCellCStateGradV2OpName))};
|
|
|
|
|
auto basic_lstm_cell_c_state_grad = func_graph->NewCNode(basic_lstm_cell_c_state_grad_inputs);
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> output0_dims{origin_input9_shape[kDim0],
|
|
|
|
|
4 * (((origin_input9_shape[kDim1] + kCubeSize - 1) / kCubeSize) * kCubeSize)};
|
|
|
|
|
std::vector<size_t> output0_dims{
|
|
|
|
|
origin_input9_shape[kDim0],
|
|
|
|
|
kDimMultiNum * (((origin_input9_shape[kDim1] + kCubeSize - 1) / kCubeSize) * kCubeSize)};
|
|
|
|
|
std::vector<size_t> output1_dims{input_i_shape[kDim1], input_i_shape[kDim2]};
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16, kNumberTypeFloat32}, {output0_dims, output1_dims},
|
|
|
|
|
basic_lstm_cell_c_state_grad.get());
|
|
|
|
@ -79,8 +87,8 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
|
|
|
|
|
SizeToLong((origin_output2_shape[kDim2] + kCubeSize - 1) / kCubeSize * kCubeSize),
|
|
|
|
|
SizeToLong((origin_output3_shape[kDim1] + kCubeSize - 1) / kCubeSize * kCubeSize)}),
|
|
|
|
|
split_v);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(2)), split_v);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(static_cast<int64_t>(2)), split_v);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(kAttrSplitDimValue)), split_v);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(static_cast<int64_t>(kAttrNumSplitValue)), split_v);
|
|
|
|
|
|
|
|
|
|
basic_lstm_cell_c_state_grad_nodes.emplace_back(basic_lstm_cell_c_state_grad);
|
|
|
|
|
matmul_nodes.emplace_back(matmul);
|
|
|
|
@ -242,7 +250,6 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
|
|
|
|
|
auto basic_lstm_cell_c_state_grad_outputs_0_shape =
|
|
|
|
|
AnfAlgo::GetOutputInferShape(basic_lstm_cell_c_state_grad_outputs[0], 0);
|
|
|
|
|
std::vector<size_t> temp_shape;
|
|
|
|
|
constexpr size_t kBasicLstmCStateGradOutput0DimNum = 3;
|
|
|
|
|
if (basic_lstm_cell_c_state_grad_outputs_0_shape.size() == kBasicLstmCStateGradOutput0DimNum) {
|
|
|
|
|
temp_shape = basic_lstm_cell_c_state_grad_outputs_0_shape;
|
|
|
|
|
} else {
|
|
|
|
@ -269,7 +276,8 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
|
|
|
|
|
auto lstm_gage_concat = func_graph->NewCNode(lstm_gage_concat_input);
|
|
|
|
|
auto origin_input7_shape = AnfAlgo::GetOutputInferShape(origin_input7, 0);
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(
|
|
|
|
|
{kNumberTypeFloat16}, {{origin_input7_shape[kDim0], origin_input7_shape[kDim1], 4 * origin_input7_shape[kDim2]}},
|
|
|
|
|
{kNumberTypeFloat16},
|
|
|
|
|
{{origin_input7_shape[kDim0], origin_input7_shape[kDim1], kDimMultiNum * origin_input7_shape[kDim2]}},
|
|
|
|
|
lstm_gage_concat.get());
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(num_split_x)), lstm_gage_concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{SizeToLong(num_split_x)}), lstm_gage_concat);
|
|
|
|
@ -298,7 +306,7 @@ AnfNodePtr CreateSplitV(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_v.get());
|
|
|
|
|
// Set attr
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(SizeToLong(0)), split_v);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(SizeToLong(2)), split_v);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(SizeToLong(kAttrNumSplitValue)), split_v);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(std::vector<int64_t>{SizeToLong(origin_input6_shape[0] - 1), 1}),
|
|
|
|
|
split_v);
|
|
|
|
|
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), split_v);
|
|
|
|
@ -321,8 +329,7 @@ AnfNodePtr CreateHConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic
|
|
|
|
|
auto origin_input4_shape = AnfAlgo::GetOutputInferShape(origin_input4, 0);
|
|
|
|
|
// Create reshape to change shape
|
|
|
|
|
std::vector<size_t> shape_tmp;
|
|
|
|
|
constexpr size_t kInput4DimNum = 3;
|
|
|
|
|
if (origin_input4_shape.size() == kInput4DimNum) {
|
|
|
|
|
if (origin_input4_shape.size() == kShape4dDims) {
|
|
|
|
|
shape_tmp = origin_input4_shape;
|
|
|
|
|
} else {
|
|
|
|
|
shape_tmp = {1, origin_input4_shape[0], origin_input4_shape[1]};
|
|
|
|
@ -339,8 +346,8 @@ AnfNodePtr CreateHConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic
|
|
|
|
|
std::vector<size_t> shape = {splitv_output0_shape[0] + 1, origin_input4_shape[0], origin_input4_shape[1]};
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape}, concat.get());
|
|
|
|
|
// Set attr
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(2)), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{2}), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(kAttrNValue)), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{kAttrDynInputSizesValue}), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(0)), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat);
|
|
|
|
|
return concat;
|
|
|
|
@ -362,9 +369,9 @@ AnfNodePtr CreateConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_
|
|
|
|
|
origin_output0_shape[kDim2] + h_concat_output_shape[kDim2]};
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input0, 0)}, {shape}, concat.get());
|
|
|
|
|
// Set attr
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(2)), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{2}), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(2)), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(kAttrNValue)), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{kAttrDynInputSizesValue}), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(kAttrAxis2Value)), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat);
|
|
|
|
|
return concat;
|
|
|
|
|
}
|
|
|
|
@ -378,8 +385,7 @@ AnfNodePtr CreateConcatNodeT1(const FuncGraphPtr &func_graph, const CNodePtr &dy
|
|
|
|
|
auto origin_input4_shape = AnfAlgo::GetOutputInferShape(origin_input4, 0);
|
|
|
|
|
// Create reshape to change shape
|
|
|
|
|
std::vector<size_t> shape_tmp;
|
|
|
|
|
constexpr size_t kShapeDimNum = 3;
|
|
|
|
|
if (origin_input4_shape.size() == kShapeDimNum) {
|
|
|
|
|
if (origin_input4_shape.size() == kShape3dDims) {
|
|
|
|
|
shape_tmp = origin_input4_shape;
|
|
|
|
|
} else {
|
|
|
|
|
shape_tmp = {1, origin_input4_shape[0], origin_input4_shape[1]};
|
|
|
|
@ -398,9 +404,9 @@ AnfNodePtr CreateConcatNodeT1(const FuncGraphPtr &func_graph, const CNodePtr &dy
|
|
|
|
|
origin_input0_shape[kDim2] + shape_tmp[kDim2]};
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input0, 0)}, {shape}, concat.get());
|
|
|
|
|
// Set attr
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(2)), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{2}), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(2)), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(kAttrNValue)), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{kAttrDynInputSizesValue}), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(kAttrAxis2Value)), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat);
|
|
|
|
|
return concat;
|
|
|
|
|
}
|
|
|
|
@ -432,8 +438,8 @@ AnfNodePtr CreateBatchMatMul2(const FuncGraphPtr &func_graph, const AnfNodePtr &
|
|
|
|
|
node, lstm_input_grad};
|
|
|
|
|
auto batch_matmul = func_graph->NewCNode(matmul_inputs);
|
|
|
|
|
// Set infer data type and shape
|
|
|
|
|
auto out_shape = {AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[0], IntToSize(1),
|
|
|
|
|
AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[2]};
|
|
|
|
|
auto out_shape = {AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[kIndex0], IntToSize(1),
|
|
|
|
|
AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[kIndex2]};
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {out_shape}, batch_matmul.get());
|
|
|
|
|
// Set attr
|
|
|
|
|
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), batch_matmul);
|
|
|
|
|