!8079 support GNMT net fix dynamic rnn grad fission pass
Merge pull request !8079 from liubuyu/op_support
commit a5b0d13141
@@ -19,7 +19,7 @@
#include <memory>
#include <string>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission.h"
#include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.h"
#include "backend/optimizer/ascend/ir_fission/bn_split.h"
#include "backend/optimizer/ascend/ir_fission/bn_grad_split.h"
#include "backend/optimizer/ascend/ir_fission/batch_norm_grad_split.h"
@@ -61,6 +61,7 @@
#include "backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h"
#include "backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h"
#include "backend/optimizer/ascend/format_type/insert_trans_op.h"
#include "backend/optimizer/ascend/format_type/dynamic_rnn_grad_reformat.h"
#include "backend/optimizer/ascend/format_type/insert_transpose_for_basiclstm_op.h"
#include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h"
#include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h"
@@ -215,6 +216,7 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph)
  auto optimizer = std::make_shared<GraphOptimizer>();
  auto data_layout_pm = std::make_shared<PassManager>("transop_pm");
  data_layout_pm->AddPass(std::make_shared<RectifyDoMaskKernelInfo>());
  data_layout_pm->AddPass(std::make_shared<DynamicRNNGradReformat>());
  data_layout_pm->AddPass(std::make_shared<InsertTransOp>());
  data_layout_pm->AddPass(std::make_shared<GetitemTuple>());
  data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
@@ -276,7 +278,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph
  ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
  ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
  ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicRNN>());
  ir_fusion_pm->AddPass(std::make_shared<DynamicRNNGradFission>());
  ir_fusion_pm->AddPass(std::make_shared<DynamicRnnGradFissionV2>());
  AddAscendIRFusionRulesPass(ir_fusion_pm.get());
  AddAscendIRFusionPass(ir_fusion_pm.get());
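Note: a PassManager runs its passes in registration order, so InsertPlaceholderForDynamicRNN fires before the fission passes here, and DynamicRNNGradReformat runs before InsertTransOp in the data-layout manager above. A toy model of that AddPass/Run idiom; Graph, Pass, and PassManager below are illustrative stand-ins, not MindSpore's real classes:

#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Graph {};  // stand-in for session::KernelGraph

class Pass {
 public:
  explicit Pass(std::string name) : name_(std::move(name)) {}
  virtual ~Pass() = default;
  virtual bool Run(Graph *g) = 0;  // returns true if the graph was changed
  const std::string &name() const { return name_; }

 private:
  std::string name_;
};

class PassManager {
 public:
  void AddPass(const std::shared_ptr<Pass> &pass) { passes_.push_back(pass); }
  void Run(Graph *g) {
    // Strictly in registration order, which is why ordering matters above.
    for (auto &p : passes_) {
      bool changed = p->Run(g);
      std::cout << p->name() << (changed ? " changed the graph\n" : " made no change\n");
    }
  }

 private:
  std::vector<std::shared_ptr<Pass>> passes_;
};

class DummyFission : public Pass {
 public:
  DummyFission() : Pass("dynamic_rnn_grad_fission_v2") {}
  bool Run(Graph *) override { return true; }
};

int main() {
  Graph g;
  PassManager pm;
  pm.AddPass(std::make_shared<DummyFission>());
  pm.Run(&g);
  return 0;
}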
@@ -0,0 +1,80 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/ascend/format_type/dynamic_rnn_grad_reformat.h"
#include <memory>
#include "backend/optimizer/ascend/ascend_helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/utils.h"
#include "base/core_ops.h"

namespace mindspore {
namespace opt {
const BaseRef DynamicRNNGradReformat::DefinePattern() const {
  VarPtr Xs = std::make_shared<Var>();
  VarPtr Xs2 = std::make_shared<Var>();
  MS_EXCEPTION_IF_NULL(Xs);
  MS_EXCEPTION_IF_NULL(Xs2);
  const auto split = std::make_shared<Primitive>(prim::kPrimSplitV->name());
  return VectorRef({split, VectorRef({std::make_shared<Primitive>(prim::kPrimMatMul->name()), Xs, Xs2})});
}

const AnfNodePtr DynamicRNNGradReformat::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                 const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto split_v = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(split_v);
  auto matmul = CheckAnfNodeIfCNodeAndInputSize(split_v->input(1), 3);
  MS_EXCEPTION_IF_NULL(matmul);
  auto input_node_with_idx = AnfAlgo::GetPrevNodeOutput(matmul, 0);
  auto input_node = input_node_with_idx.first;
  MS_EXCEPTION_IF_NULL(input_node);
  if (!(input_node->isa<CNode>() &&
        AnfAlgo::GetCNodeName(input_node->cast<CNodePtr>()) == kBasicLSTMCellCStateGradV2OpName)) {
    return nullptr;
  }

  // reformat matmul
  auto matmul_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(matmul);
  MS_EXCEPTION_IF_NULL(matmul_kernel_build_info);
  auto matmul_new_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  matmul_new_builder->SetInputsFormat({kOpFormat_FRAC_NZ, kOpFormat_FRAC_NZ});
  matmul_new_builder->SetOutputsFormat({kOpFormat_FRAC_NZ});
  matmul_new_builder->SetInputsDeviceType({kNumberTypeFloat16, kNumberTypeFloat16});
  matmul_new_builder->SetOutputsDeviceType({kNumberTypeFloat});
  matmul_new_builder->SetKernelType(matmul_kernel_build_info->kernel_type());
  matmul_new_builder->SetFusionType(matmul_kernel_build_info->fusion_type());
  matmul_new_builder->SetProcessor(matmul_kernel_build_info->processor());
  AnfAlgo::SetSelectKernelBuildInfo(matmul_new_builder->Build(), matmul.get());
  AnfAlgo::SetNodeAttr("insert_backend", MakeValue(true), matmul);

  // reformat split_v
  auto split_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(split_v);
  MS_EXCEPTION_IF_NULL(split_kernel_build_info);
  auto split_new_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  split_new_builder->SetInputsFormat({kOpFormat_FRAC_NZ});
  split_new_builder->SetOutputsFormat({kOpFormat_FRAC_NZ, kOpFormat_FRAC_NZ});
  split_new_builder->SetInputsDeviceType(split_kernel_build_info->GetAllInputDeviceTypes());
  split_new_builder->SetOutputsDeviceType(split_kernel_build_info->GetAllOutputDeviceTypes());
  split_new_builder->SetKernelType(split_kernel_build_info->kernel_type());
  split_new_builder->SetFusionType(split_kernel_build_info->fusion_type());
  split_new_builder->SetProcessor(split_kernel_build_info->processor());
  AnfAlgo::SetSelectKernelBuildInfo(split_new_builder->Build(), split_v.get());
  AnfAlgo::SetNodeAttr("insert_backend", MakeValue(true), split_v);
  return split_v;
}
}  // namespace opt
}  // namespace mindspore
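Note: when the matched MatMul is fed by BasicLSTMCellCStateGradV2, the pass rebuilds the selected kernel build infos so every port of the MatMul and SplitV stays in kOpFormat_FRAC_NZ, apparently to keep the tiled layout flowing through without format conversions in between. A rough standalone sketch of the builder idiom, using stand-in types rather than the real kernel::KernelBuildInfo API:

#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct KernelInfo {
  std::vector<std::string> input_formats;
  std::vector<std::string> output_formats;
};

class KernelInfoBuilder {
 public:
  void SetInputsFormat(std::vector<std::string> f) { info_.input_formats = std::move(f); }
  void SetOutputsFormat(std::vector<std::string> f) { info_.output_formats = std::move(f); }
  std::shared_ptr<KernelInfo> Build() { return std::make_shared<KernelInfo>(info_); }

 private:
  KernelInfo info_;
};

int main() {
  // Mirror of the pass: pin every port of the matched MatMul to FRACTAL_NZ.
  KernelInfoBuilder builder;
  builder.SetInputsFormat({"FRACTAL_NZ", "FRACTAL_NZ"});
  builder.SetOutputsFormat({"FRACTAL_NZ"});
  auto info = builder.Build();
  std::cout << "matmul has " << info->input_formats.size() << " inputs in "
            << info->input_formats[0] << '\n';
  return 0;
}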
@@ -0,0 +1,41 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DYNAMIC_RNN_GRAD_REFORMAT_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DYNAMIC_RNN_GRAD_REFORMAT_H_
#include <vector>
#include <string>
#include <utility>
#include <memory>
#include "ir/anf.h"
#include "backend/optimizer/common/pattern_engine.h"
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/common/optimizer.h"

namespace mindspore {
namespace opt {
class DynamicRNNGradReformat : public PatternProcessPass {
 public:
  explicit DynamicRNNGradReformat(bool multigraph = true)
      : PatternProcessPass("dynamic_rnn_grad_reformat", multigraph) {}
  ~DynamicRNNGradReformat() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
}  // namespace opt
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DYNAMIC_RNN_GRAD_REFORMAT_H_
@@ -1,250 +0,0 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission.h"
#include <vector>
#include <memory>
#include <algorithm>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/helper.h"

namespace mindspore {
namespace opt {
constexpr size_t kDynamicRNNGradInputNum = 16;
constexpr size_t kLSTMInputGradOutputNum = 4;
const BaseRef DynamicRNNGradFission::DefinePattern() const {
  VarPtr Xs = std::make_shared<SeqVar>();
  return VectorRef({prim::kPrimDynamicRNNGrad, Xs});
}

AnfNodePtr CreateSplitVD(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  // SplitV
  std::vector<AnfNodePtr> splitvd_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())), node};
  auto split_vd = graph->NewCNode(splitvd_input);
  MS_EXCEPTION_IF_NULL(split_vd);
  auto dtypes = {AnfAlgo::GetOutputInferDataType(node, 0), AnfAlgo::GetOutputInferDataType(node, 0)};
  std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node, 0)[0] - 1, AnfAlgo::GetOutputInferShape(node, 0)[1],
                               AnfAlgo::GetOutputInferShape(node, 0)[2]};
  auto shape2 = {IntToSize(1), AnfAlgo::GetOutputInferShape(node, 0)[1], AnfAlgo::GetOutputInferShape(node, 0)[2]};
  std::vector<std::vector<size_t>> shapes = {shape, shape2};
  AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_vd.get());
  AnfAlgo::SetNodeAttr("split_dim", MakeValue(0), split_vd);
  AnfAlgo::SetNodeAttr("num_split", MakeValue(2), split_vd);
  int tmp = SizeToInt(AnfAlgo::GetOutputInferShape(node, 0)[0]) - 1;
  AnfAlgo::SetNodeAttr("size_splits", MakeValue(std::vector<int>{tmp, 1}), split_vd);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), split_vd);
  return split_vd;
}

AnfNodePtr CreateLSTMInputGrad(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  const auto &dynamic_rnn_grad_inputs = cnode->inputs();
  std::vector<AnfNodePtr> lstm_input_grad_inputs = {NewValueNode(std::make_shared<Primitive>(kLSTMInputGradOpName)),
                                                    dynamic_rnn_grad_inputs[2],
                                                    dynamic_rnn_grad_inputs[6],
                                                    dynamic_rnn_grad_inputs[8],
                                                    dynamic_rnn_grad_inputs[9],
                                                    dynamic_rnn_grad_inputs[10],
                                                    dynamic_rnn_grad_inputs[11],
                                                    dynamic_rnn_grad_inputs[12],
                                                    dynamic_rnn_grad_inputs[13],
                                                    dynamic_rnn_grad_inputs[14],
                                                    dynamic_rnn_grad_inputs[15],
                                                    dynamic_rnn_grad_inputs[16]};
  std::vector<AnfNodePtr> ori_outputs;
  CreateMultipleOutputsOfAnfNode(graph, node, 5, &ori_outputs);
  auto lstm_op = graph->NewCNode(lstm_input_grad_inputs);
  MS_EXCEPTION_IF_NULL(lstm_op);
  auto ori_type = AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_inputs[8], 0);
  auto types = {AnfAlgo::GetOutputInferDataType(ori_outputs[2], 0), AnfAlgo::GetOutputInferDataType(ori_outputs[3], 0),
                AnfAlgo::GetOutputInferDataType(ori_outputs[4], 0), ori_type};
  std::vector<size_t> ori_shape = {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_inputs[8], 0)[0],
                                   AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_inputs[8], 0)[1],
                                   4 * AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_inputs[8], 0)[2]};
  auto shapes = {AnfAlgo::GetOutputInferShape(ori_outputs[2], 0), AnfAlgo::GetOutputInferShape(ori_outputs[3], 0),
                 AnfAlgo::GetOutputInferShape(ori_outputs[4], 0), ori_shape};
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, lstm_op.get());
  return lstm_op;
}

AnfNodePtr CreateBatchMatMul(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node1);
  MS_EXCEPTION_IF_NULL(node2);
  // BatchMatMul
  std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name())),
                                           node2, node1};
  auto batch_matmul = graph->NewCNode(matmul_inputs);
  MS_EXCEPTION_IF_NULL(batch_matmul);
  auto types = {AnfAlgo::GetOutputInferDataType(node1, 0)};
  std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node2, 0)[0], AnfAlgo::GetOutputInferShape(node2, 0)[2],
                               AnfAlgo::GetOutputInferShape(node1, 0)[2]};
  auto shapes = {shape};
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), batch_matmul);
  AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(true), batch_matmul);
  AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(false), batch_matmul);
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, batch_matmul.get());
  return batch_matmul;
}

AnfNodePtr AddHConcatD(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node1);
  MS_EXCEPTION_IF_NULL(node2);
  std::vector<AnfNodePtr> ori_outputs;
  CreateMultipleOutputsOfAnfNode(graph, node2, 2, &ori_outputs);
  auto ori_shape = AnfAlgo::GetOutputInferShape(node1, 0);
  std::vector<std::vector<size_t>> shape_tmp;
  if (ori_shape.size() == 3) {
    shape_tmp = {ori_shape};
  } else {
    shape_tmp = {{IntToSize(1), ori_shape[0], ori_shape[1]}};
  }
  auto ori_dtype = {AnfAlgo::GetOutputInferDataType(node1, 0)};
  // reshape
  std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                           node1};
  auto reshape = graph->NewCNode(reshape_input);
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape);
  AnfAlgo::SetOutputInferTypeAndShape(ori_dtype, shape_tmp, reshape.get());

  // concatd --> concat
  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
                                           reshape, ori_outputs[0]};
  auto concat_op = graph->NewCNode(concat_inputs);
  MS_EXCEPTION_IF_NULL(concat_op);
  std::vector<size_t> input = {AnfAlgo::GetOutputInferShape(node2, 0)[0] + 1, AnfAlgo::GetOutputInferShape(node2, 0)[1],
                               AnfAlgo::GetOutputInferShape(node2, 0)[2]};
  auto types = {AnfAlgo::GetOutputInferDataType(node1, 0)};
  auto shapes = {input};
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, concat_op.get());
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat_op);
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{2}), concat_op);
  AnfAlgo::SetNodeAttr("axis", MakeValue(0), concat_op);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat_op);
  return concat_op;
}

AnfNodePtr AddConcatD(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node1);
  MS_EXCEPTION_IF_NULL(node2);
  // concatd --> concat
  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())), node1,
                                           node2};
  auto concat_op = graph->NewCNode(concat_inputs);
  MS_EXCEPTION_IF_NULL(concat_op);
  std::vector<size_t> input = {AnfAlgo::GetOutputInferShape(node1, 0)[0], AnfAlgo::GetOutputInferShape(node1, 0)[1],
                               AnfAlgo::GetOutputInferShape(node1, 0)[2] + AnfAlgo::GetOutputInferShape(node2, 0)[2]};
  auto types = {AnfAlgo::GetOutputInferDataType(node1, 0)};
  auto shapes = {input};
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, concat_op.get());
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat_op);
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{2}), concat_op);
  AnfAlgo::SetNodeAttr("axis", MakeValue(2), concat_op);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat_op);
  return concat_op;
}

AnfNodePtr AddDwReduceSum(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
  // node1 : dynamic output
  // node2 : matmul
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node1);
  MS_EXCEPTION_IF_NULL(node2);
  std::vector<AnfNodePtr> ori_outputs;
  CreateMultipleOutputsOfAnfNode(graph, node1, 5, &ori_outputs);
  // ReduceSumd
  std::vector<AnfNodePtr> reducesum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
                                              node2};
  auto reduce_sumd = graph->NewCNode(reducesum_inputs);
  MS_EXCEPTION_IF_NULL(reduce_sumd);
  auto types = {AnfAlgo::GetOutputInferDataType(ori_outputs[0], 0)};
  auto shapes = {AnfAlgo::GetOutputInferShape(ori_outputs[0], 0)};
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, reduce_sumd.get());
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int>{0}), reduce_sumd);
  AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_sumd);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sumd);
  return reduce_sumd;
}

AnfNodePtr AddDbReduceSum(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
  // node1 lstm output
  // node2 // dynamic output
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node1);
  MS_EXCEPTION_IF_NULL(node2);
  std::vector<AnfNodePtr> ori_outputs;
  CreateMultipleOutputsOfAnfNode(graph, node2, 5, &ori_outputs);
  // ReduceSumd --> ReduceSum
  std::vector<AnfNodePtr> reducerum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
                                              node1};
  auto reduce_sumd = graph->NewCNode(reducerum_inputs);
  MS_EXCEPTION_IF_NULL(reduce_sumd);
  auto types = {AnfAlgo::GetOutputInferDataType(ori_outputs[1], 0)};
  auto shapes = {AnfAlgo::GetOutputInferShape(ori_outputs[1], 0)};
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, reduce_sumd.get());
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int>{0, 1}), reduce_sumd);
  AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_sumd);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sumd);
  return reduce_sumd;
}

const AnfNodePtr DynamicRNNGradFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (cnode->size() < kDynamicRNNGradInputNum + 1) {
    MS_LOG(INFO) << "The input num of DynamicRNNGrad less than" << kDynamicRNNGradInputNum
                 << ". The node should not be changed";
    return nullptr;
  }
  // input_list of dynamic_rnn_grad
  const auto &ori_inputs = cnode->inputs();
  // create split_vd
  auto split_vd = CreateSplitVD(func_graph, ori_inputs[7]);
  // create concat_1
  auto h_concat = AddHConcatD(func_graph, ori_inputs[5], split_vd);
  // create concat_2
  auto concat = AddConcatD(func_graph, ori_inputs[1], h_concat);
  // create lsym_input_grad
  auto lstm_input_grad = CreateLSTMInputGrad(func_graph, cnode);
  std::vector<AnfNodePtr> lstm_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_input_grad, kLSTMInputGradOutputNum, &lstm_outputs);
  // create matmul
  auto batch_matmul = CreateBatchMatMul(func_graph, lstm_outputs[3], concat);
  // create reduce_sum_1
  auto dw_reduce_sum = AddDwReduceSum(func_graph, node, batch_matmul);
  // create reduce_sum_2
  auto db_reduce_sum = AddDbReduceSum(func_graph, lstm_outputs[3], node);
  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple),
                                               dw_reduce_sum,
                                               db_reduce_sum,
                                               lstm_outputs[0],
                                               lstm_outputs[1],
                                               lstm_outputs[2]};
  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
  MS_EXCEPTION_IF_NULL(make_tuple);
  return make_tuple;
}
}  // namespace opt
}  // namespace mindspore
@@ -0,0 +1,483 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.h"
#include <vector>
#include <memory>
#include "backend/session/anf_runtime_algorithm.h"

namespace mindspore {
namespace opt {
namespace {
constexpr size_t kDynamicRNNGradInputNum = 16;
constexpr size_t kSplitVOutputNum = 2;
constexpr size_t kLSTMInputGradOutputNum = 4;

void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                     std::vector<std::vector<AnfNodePtr>> *result_nodes) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
  MS_EXCEPTION_IF_NULL(result_nodes);
  std::vector<AnfNodePtr> basic_lstm_cell_c_state_grad_nodes;
  std::vector<AnfNodePtr> matmul_nodes;
  std::vector<AnfNodePtr> split_nodes;
  // Get the size of t
  auto origin_input9_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(10), 0);
  size_t t_size = origin_input9_shape[0];
  auto input_i_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(12), 0);

  for (size_t i = 0; i < t_size; ++i) {
    // Create basic_lstm_cell_c_state_grad
    std::vector<AnfNodePtr> basic_lstm_cell_c_state_grad_inputs = {
      NewValueNode(std::make_shared<Primitive>(kBasicLSTMCellCStateGradV2OpName))};
    auto basic_lstm_cell_c_state_grad = func_graph->NewCNode(basic_lstm_cell_c_state_grad_inputs);

    std::vector<size_t> output0_dims{origin_input9_shape[0], 4 * (((origin_input9_shape[1] + 15) / 16) * 16)};
    std::vector<size_t> output1_dims{input_i_shape[1], input_i_shape[2]};
    AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32}, {output0_dims, output1_dims},
                                        basic_lstm_cell_c_state_grad.get());
    AnfAlgo::SetNodeAttr("forget_bias", MakeValue(1.0f), basic_lstm_cell_c_state_grad);
    AnfAlgo::SetNodeAttr("activation", MakeValue("Tanh"), basic_lstm_cell_c_state_grad);

    // Create matmul
    auto origin_input1_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(2), 0);
    std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name()))};
    auto matmul = func_graph->NewCNode(matmul_inputs);
    AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {{output0_dims[0], origin_input1_shape[0]}},
                                        matmul.get());
    AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(false), matmul);
    AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(true), matmul);

    // Create split
    std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
    auto split_v = func_graph->NewCNode(splitv_input);
    auto origin_output2_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 2);
    auto origin_output3_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 3);
    std::vector<size_t> split_v_output0_shape{origin_output2_shape[1], origin_output2_shape[2]};
    std::vector<size_t> split_v_output1_shape{origin_output3_shape[0], origin_output3_shape[1]};
    AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32},
                                        {split_v_output0_shape, split_v_output1_shape}, split_v.get());

    AnfAlgo::SetNodeAttr(kAttrSizeSplits,
                         MakeValue(std::vector<int>{SizeToInt((origin_output2_shape[2] + 15) / 16),
                                                    SizeToInt((origin_output3_shape[1] + 15) / 16)}),
                         split_v);
    AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(0), split_v);
    AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(2), split_v);

    basic_lstm_cell_c_state_grad_nodes.emplace_back(basic_lstm_cell_c_state_grad);
    matmul_nodes.emplace_back(matmul);
    split_nodes.emplace_back(split_v);
  }
  result_nodes->emplace_back(basic_lstm_cell_c_state_grad_nodes);
  result_nodes->emplace_back(matmul_nodes);
  result_nodes->emplace_back(split_nodes);
}
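Note: CreateTLoopNode sizes the per-step dgate output as 4 * (((hidden + 15) / 16) * 16): the four LSTM gates are packed along one axis and each gate's width is rounded up to the 16-element tile that the FRACTAL_NZ layout works in. A standalone check of that arithmetic:

#include <cstddef>
#include <iostream>

// Round d up to the next multiple of 16, as in the shape computation above.
size_t AlignTo16(size_t d) { return ((d + 15) / 16) * 16; }

int main() {
  for (size_t hidden : {30, 32, 100}) {
    size_t dgate_dim = 4 * AlignTo16(hidden);  // 4 gates, each padded
    std::cout << "hidden=" << hidden << " -> dgate dim=" << dgate_dim << '\n';
  }
  // Prints 128, 128, 448: padding only changes sizes that are not multiples of 16.
  return 0;
}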
AnfNodePtr CreateLSTMSPlitV(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
                            const std::vector<std::vector<size_t>> &split_shapes,
                            const std::vector<TypeId> &split_types, const std::vector<int> &size_split,
                            size_t num_split_x) {
  std::vector<AnfNodePtr> lstm_split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
                                              input};
  auto lstm_split = func_graph->NewCNode(lstm_split_input);
  AnfAlgo::SetOutputInferTypeAndShape(split_types, split_shapes, lstm_split.get());
  AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(size_split), lstm_split);
  AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(0), lstm_split);
  AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(SizeToInt(num_split_x)), lstm_split);
  return lstm_split;
}

AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                                std::vector<AnfNodePtr> *outputs) {
  std::vector<std::vector<AnfNodePtr>> result_nodes;
  CreateTLoopNode(func_graph, dynamic_rnn_grad_cnode, &result_nodes);

  auto origin_input5_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(6), 0);
  std::vector<size_t> split_c_dims{1, origin_input5_shape[0], origin_input5_shape[1]};

  auto origin_input7 = dynamic_rnn_grad_cnode->input(8);
  size_t num_split_x = AnfAlgo::GetOutputInferShape(origin_input7, 0)[0];
  std::vector<std::vector<size_t>> split_shapes;
  std::vector<TypeId> split_types;
  std::vector<int> size_split;
  for (size_t i = 0; i < num_split_x; ++i) {
    split_shapes.emplace_back(split_c_dims);
    split_types.emplace_back(kNumberTypeFloat32);
    size_split.emplace_back(1);
  }
  // Create lstm_split_c
  auto lstm_split_c = CreateLSTMSPlitV(func_graph, origin_input7, split_shapes, split_types, size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_c_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_c, num_split_x, &lstm_split_c_outputs);

  // Create lstm_split_dy
  auto lstm_split_dy =
    CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(9), split_shapes, split_types, size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_dy_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_dy, num_split_x, &lstm_split_dy_outputs);

  // Create lstm_split_i
  auto lstm_split_i =
    CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(12), split_shapes, split_types, size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_i_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_i, num_split_x, &lstm_split_i_outputs);

  // Create lstm_split_j
  auto lstm_split_j =
    CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(13), split_shapes, split_types, size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_j_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_j, num_split_x, &lstm_split_j_outputs);

  // Create lstm_split_f
  auto lstm_split_f =
    CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(14), split_shapes, split_types, size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_f_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_f, num_split_x, &lstm_split_f_outputs);

  // Create lstm_split_o
  auto lstm_split_o =
    CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(15), split_shapes, split_types, size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_o_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_o, num_split_x, &lstm_split_o_outputs);

  // Create lstm_split_tanh
  auto lstm_split_tanh =
    CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(16), split_shapes, split_types, size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_tanh_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_tanh, num_split_x, &lstm_split_tanh_outputs);

  // Add edges
  std::vector<AnfNodePtr> pre_basic_lstm_cell_c_state_grad_outputs;
  std::vector<AnfNodePtr> pre_split_outputs;
  auto basic_lstm_cell_c_state_grad_nodes = result_nodes[0];
  auto matmul_nodes = result_nodes[1];
  auto split_nodes = result_nodes[2];
  std::vector<AnfNodePtr> lstm_x_concat_input(num_split_x + 1);
  lstm_x_concat_input[0] = NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()));
  std::vector<AnfNodePtr> lstm_gage_concat_input(num_split_x + 1);
  lstm_gage_concat_input[0] = NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()));

  for (size_t i = 0; i < num_split_x; ++i) {
    size_t idx = num_split_x - i - 1;
    // Create basic_lstm_cell_c_state_grad
    std::vector<AnfNodePtr> basic_lstm_cell_c_state_grad_inputs = {
      NewValueNode(std::make_shared<Primitive>(kBasicLSTMCellCStateGradV2OpName))};
    if (i == num_split_x - 1) {
      std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                                dynamic_rnn_grad_cnode->input(6)};
      auto reshape = func_graph->NewCNode(reshape_inputs);
      auto reshape_out_shape = {IntToSize(1), AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(6), 0)[0],
                                AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(6), 0)[1]};
      AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {reshape_out_shape}, reshape.get());
      basic_lstm_cell_c_state_grad_inputs.emplace_back(reshape);
    } else {
      basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_c_outputs[idx - 1]);
    }
    basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_dy_outputs[idx]);
    if (i == 0) {
      basic_lstm_cell_c_state_grad_inputs.emplace_back(dynamic_rnn_grad_cnode->input(10));
      basic_lstm_cell_c_state_grad_inputs.emplace_back(dynamic_rnn_grad_cnode->input(11));
    } else {
      basic_lstm_cell_c_state_grad_inputs.emplace_back(pre_split_outputs[1]);
      basic_lstm_cell_c_state_grad_inputs.emplace_back(pre_basic_lstm_cell_c_state_grad_outputs[1]);
    }
    basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_i_outputs[idx]);
    basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_j_outputs[idx]);
    basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_f_outputs[idx]);
    basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_o_outputs[idx]);
    basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_tanh_outputs[idx]);
    auto basic_lstm_cell_c_state_grad = func_graph->NewCNode(basic_lstm_cell_c_state_grad_inputs);
    MS_EXCEPTION_IF_NULL(basic_lstm_cell_c_state_grad);
    basic_lstm_cell_c_state_grad->set_abstract(basic_lstm_cell_c_state_grad_nodes[i]->abstract());
    AnfAlgo::CopyNodeAttrs(basic_lstm_cell_c_state_grad_nodes[i], basic_lstm_cell_c_state_grad);
    // Create outputs for current basic_lstm_cell_c_state_grad node
    std::vector<AnfNodePtr> basic_lstm_cell_c_state_grad_outputs;
    CreateMultipleOutputsOfAnfNode(func_graph, basic_lstm_cell_c_state_grad, 2, &basic_lstm_cell_c_state_grad_outputs);
    pre_basic_lstm_cell_c_state_grad_outputs = basic_lstm_cell_c_state_grad_outputs;

    // Create MatMul
    std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name()))};
    matmul_inputs.emplace_back(basic_lstm_cell_c_state_grad_outputs[0]);
    matmul_inputs.emplace_back(dynamic_rnn_grad_cnode->input(2));
    auto matmul = func_graph->NewCNode(matmul_inputs);
    MS_EXCEPTION_IF_NULL(matmul);
    matmul->set_abstract(matmul_nodes[i]->abstract());
    AnfAlgo::CopyNodeAttrs(matmul_nodes[i], matmul);

    // Create splitv
    std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
                                            matmul};
    auto split_v = func_graph->NewCNode(splitv_input);
    MS_EXCEPTION_IF_NULL(split_v);
    split_v->set_abstract(split_nodes[i]->abstract());
    AnfAlgo::CopyNodeAttrs(split_nodes[i], split_v);

    // Create outputs for current split node
    std::vector<AnfNodePtr> split_outputs;
    CreateMultipleOutputsOfAnfNode(func_graph, split_v, 2, &split_outputs);
    pre_split_outputs = split_outputs;

    lstm_x_concat_input[idx + 1] = split_outputs[0];
    lstm_gage_concat_input[idx + 1] = basic_lstm_cell_c_state_grad_outputs[0];
  }

  // Create lstm_x_concat
  auto lstm_x_concat = func_graph->NewCNode(lstm_x_concat_input);
  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 2)},
                                      lstm_x_concat.get());
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(num_split_x)), lstm_x_concat);
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{SizeToInt(num_split_x)}), lstm_x_concat);
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(0), lstm_x_concat);

  // Create lstm_gage_concat
  auto lstm_gage_concat = func_graph->NewCNode(lstm_gage_concat_input);
  auto origin_input7_shape = AnfAlgo::GetOutputInferShape(origin_input7, 0);
  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32},
                                      {{origin_input7_shape[0], origin_input7_shape[1], 4 * origin_input7_shape[2]}},
                                      lstm_gage_concat.get());
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(num_split_x)), lstm_gage_concat);
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{SizeToInt(num_split_x)}), lstm_gage_concat);
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(0), lstm_gage_concat);

  outputs->emplace_back(lstm_x_concat);
  outputs->emplace_back(pre_split_outputs[1]);
  outputs->emplace_back(pre_basic_lstm_cell_c_state_grad_outputs[1]);
  return lstm_gage_concat;
}
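Note: the loop above counts i forward but indexes timesteps with idx = num_split_x - i - 1, i.e. from the last step back to the first, and each step consumes the dh/dc produced by the previously processed (later-in-time) step via pre_split_outputs and pre_basic_lstm_cell_c_state_grad_outputs. A scalar stand-in for that reverse chaining (plain doubles in place of the dh/dc tensors):

#include <iostream>
#include <vector>

int main() {
  const size_t t_size = 4;
  double dh = 0.5, dc = 0.25;  // grads fed in at the last timestep (i == 0 case)
  std::vector<double> dgate(t_size);
  for (size_t i = 0; i < t_size; ++i) {
    size_t idx = t_size - i - 1;  // same index trick as the pass
    // stand-in for BasicLSTMCellCStateGradV2: emit a per-step grad and
    // update the carried dh/dc for the next (earlier) step
    dgate[idx] = dh + dc;
    dh *= 0.9;
    dc *= 0.9;
  }
  for (size_t t = 0; t < t_size; ++t) std::cout << "t=" << t << " dgate=" << dgate[t] << '\n';
  return 0;
}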
AnfNodePtr CreateSplitV(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
  // Create node
  auto origin_input6 = dynamic_rnn_grad_cnode->input(7);
  std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
                                          origin_input6};
  auto split_v = func_graph->NewCNode(splitv_input);
  // Set infer data type and shape
  auto dtypes = {AnfAlgo::GetOutputInferDataType(origin_input6, 0), AnfAlgo::GetOutputInferDataType(origin_input6, 0)};
  auto origin_input6_shape = AnfAlgo::GetOutputInferShape(origin_input6, 0);
  std::vector<size_t> shape1 = {origin_input6_shape[0] - 1, origin_input6_shape[1], origin_input6_shape[2]};
  std::vector<size_t> shape2 = {1, origin_input6_shape[1], origin_input6_shape[2]};
  std::vector<std::vector<size_t>> shapes = {shape1, shape2};
  AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_v.get());
  // Set attr
  AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(0), split_v);
  AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(2), split_v);
  AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(std::vector<int>{SizeToInt(origin_input6_shape[0] - 1), 1}), split_v);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), split_v);
  return split_v;
}
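Note: CreateSplitV cuts a [T, batch, hidden] tensor along axis 0 with size_splits {T-1, 1}, producing the first T-1 steps (to be stacked behind the initial state) and the last step separately, exactly as shape1/shape2 encode. A shape-only model of that split, with made-up sizes:

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::vector<size_t> h_shape = {8, 16, 32};  // T, batch, hidden (illustrative)
  std::vector<int> size_splits = {static_cast<int>(h_shape[0] - 1), 1};
  std::vector<std::vector<size_t>> outputs;
  for (int s : size_splits) {
    // each output keeps the trailing dims; only axis 0 changes
    outputs.push_back({static_cast<size_t>(s), h_shape[1], h_shape[2]});
  }
  for (const auto &o : outputs) {
    std::cout << '[' << o[0] << ", " << o[1] << ", " << o[2] << "]\n";  // [7,16,32] and [1,16,32]
  }
  return 0;
}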
AnfNodePtr CreateHConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                         const AnfNodePtr &splitv) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
  MS_EXCEPTION_IF_NULL(splitv);
  // Create node
  std::vector<AnfNodePtr> splitv_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, splitv, kSplitVOutputNum, &splitv_outputs);
  if (splitv_outputs.size() != kSplitVOutputNum) {
    MS_LOG(EXCEPTION) << "Create outputs of node " << splitv->DebugString() << " failed";
  }
  auto origin_input4 = dynamic_rnn_grad_cnode->input(5);
  auto origin_input4_shape = AnfAlgo::GetOutputInferShape(origin_input4, 0);
  // Create reshape to change shape
  std::vector<size_t> shape_tmp;
  if (origin_input4_shape.size() == 3) {
    shape_tmp = origin_input4_shape;
  } else {
    shape_tmp = {1, origin_input4_shape[0], origin_input4_shape[1]};
  }
  std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                           origin_input4};
  auto reshape = func_graph->NewCNode(reshape_input);
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape);
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape_tmp}, reshape.get());

  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
                                           reshape, splitv_outputs[0]};
  auto concat = func_graph->NewCNode(concat_inputs);
  // Set infer data type and shape
  auto splitv_output0_shape = AnfAlgo::GetOutputInferShape(splitv, 0);
  std::vector<size_t> shape = {splitv_output0_shape[0] + 1, origin_input4_shape[0], origin_input4_shape[1]};
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape}, concat.get());
  // Set attr
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat);
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{2}), concat);
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(0), concat);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat);
  return concat;
}

AnfNodePtr CreateConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                        const AnfNodePtr &h_concat) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
  // Create node
  auto origin_input0 = dynamic_rnn_grad_cnode->input(1);
  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
                                           origin_input0, h_concat};
  auto concat = func_graph->NewCNode(concat_inputs);
  // Set infer data type and shape
  auto origin_output0_shape = AnfAlgo::GetOutputInferShape(origin_input0, 0);
  auto h_concat_output_shape = AnfAlgo::GetOutputInferShape(h_concat, 0);
  std::vector<size_t> shape = {origin_output0_shape[0], origin_output0_shape[1],
                               origin_output0_shape[2] + h_concat_output_shape[2]};
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input0, 0)}, {shape}, concat.get());
  // Set attr
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat);
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{2}), concat);
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(2), concat);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat);
  return concat;
}
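Note: CreateHConcat stacks the (rank-3-padded) initial h in front of h[0:T-1] on axis 0, and CreateConcat then joins x with that stacked h on axis 2, yielding the [T, batch, input+hidden] operand of the BatchMatMul below. A shape-only model of the two concats (sizes are illustrative):

#include <cstddef>
#include <iostream>
#include <vector>

using Shape = std::vector<size_t>;

// Concatenation only grows the chosen axis; all other dims must already match.
Shape ConcatShape(const Shape &a, const Shape &b, size_t axis) {
  Shape out = a;
  out[axis] += b[axis];
  return out;
}

int main() {
  Shape x = {8, 16, 24};      // T, batch, input size
  Shape init_h = {1, 16, 32}; // reshaped to rank 3, as the pass does
  Shape h_head = {7, 16, 32}; // first SplitV output: h[0:T-1]
  Shape h_concat = ConcatShape(init_h, h_head, 0);  // [8, 16, 32]
  Shape concat = ConcatShape(x, h_concat, 2);       // [8, 16, 56]
  std::cout << "concat: [" << concat[0] << ", " << concat[1] << ", " << concat[2] << "]\n";
  return 0;
}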
AnfNodePtr CreateConcatNodeT1(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
  // Create node
  auto origin_input0 = dynamic_rnn_grad_cnode->input(1);
  auto origin_input4 = dynamic_rnn_grad_cnode->input(5);
  auto origin_input4_shape = AnfAlgo::GetOutputInferShape(origin_input4, 0);
  // Create reshape to change shape
  std::vector<size_t> shape_tmp;
  if (origin_input4_shape.size() == 3) {
    shape_tmp = origin_input4_shape;
  } else {
    shape_tmp = {1, origin_input4_shape[0], origin_input4_shape[1]};
  }
  std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                           origin_input4};
  auto reshape = func_graph->NewCNode(reshape_input);
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape);
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape_tmp}, reshape.get());

  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
                                           origin_input0, reshape};
  auto concat = func_graph->NewCNode(concat_inputs);
  // Set infer data type and shape
  auto origin_input0_shape = AnfAlgo::GetOutputInferShape(origin_input0, 0);
  std::vector<size_t> shape = {origin_input0_shape[0], origin_input0_shape[1], origin_input0_shape[2] + shape_tmp[2]};
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input0, 0)}, {shape}, concat.get());
  // Set attr
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat);
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{2}), concat);
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(2), concat);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat);
  return concat;
}

AnfNodePtr CreateBatchMatMul(const FuncGraphPtr &func_graph, const AnfNodePtr &lstm_input_grad,
                             const AnfNodePtr &concat) {
  MS_EXCEPTION_IF_NULL(func_graph);
  // Create node
  std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name())),
                                           concat, lstm_input_grad};
  auto batch_matmul = func_graph->NewCNode(matmul_inputs);
  // Set infer data type and shape
  auto concat_shape = AnfAlgo::GetOutputInferShape(concat, 0);
  auto lstm_input_grad_shape = AnfAlgo::GetOutputInferShape(lstm_input_grad, 0);
  std::vector<size_t> shape = {concat_shape[0], concat_shape[2], lstm_input_grad_shape[2]};
  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {shape}, batch_matmul.get());
  // Set attr
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), batch_matmul);
  AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(true), batch_matmul);
  AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(false), batch_matmul);
  return batch_matmul;
}
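Note: with transpose_x1 set, the BatchMatMul multiplies, per timestep, the transposed [batch, input+hidden] slice of concat by the [batch, 4*hidden] dgate slice, which is why the inferred shape is {concat[0], concat[2], dgate[2]}. A shape-only check (made-up sizes):

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::vector<size_t> concat_shape = {8, 16, 56};  // T, batch, input+hidden
  std::vector<size_t> dgate_shape = {8, 16, 128};  // T, batch, 4*hidden (padded)
  // transpose_x1: (b x k)^T * (b x n) = k x n per timestep
  std::vector<size_t> out = {concat_shape[0], concat_shape[2], dgate_shape[2]};
  std::cout << "dw slices: [" << out[0] << ", " << out[1] << ", " << out[2] << "]\n";
  return 0;
}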
AnfNodePtr CreateDwReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                             const AnfNodePtr &batch_matmul) {
  MS_EXCEPTION_IF_NULL(func_graph);
  // Create node
  std::vector<AnfNodePtr> reduce_sum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
                                               batch_matmul};
  auto reduce_sum = func_graph->NewCNode(reduce_sum_inputs);
  // Set infer data type and shape
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)},
                                      {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 0)}, reduce_sum.get());
  // Set attr
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int>{0}), reduce_sum);
  AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(false), reduce_sum);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sum);
  return reduce_sum;
}

AnfNodePtr CreateDbReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                             const AnfNodePtr &lstm_input_grad) {
  MS_EXCEPTION_IF_NULL(func_graph);
  // Create node
  std::vector<AnfNodePtr> reduce_sum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
                                               lstm_input_grad};
  auto reduce_sum = func_graph->NewCNode(reduce_sum_inputs);
  // Set infer data type and shape
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 1)},
                                      {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 1)}, reduce_sum.get());
  // Set attr
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int>{0, 1}), reduce_sum);
  AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(false), reduce_sum);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sum);
  return reduce_sum;
}
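Note: the two ReduceSums differ only in their axes: CreateDwReduceSum collapses axis 0 (time) of the per-timestep weight-gradient slices, while CreateDbReduceSum collapses axes {0, 1} (time and batch) of dgate so that one bias gradient per gate column remains. A tiny numeric check of the db reduction:

#include <iostream>
#include <vector>

int main() {
  const size_t T = 3, B = 2, N = 4;
  // dgate[t][b][n], filled with 1.0 so the sums are easy to verify
  std::vector<std::vector<std::vector<double>>> dgate(
      T, std::vector<std::vector<double>>(B, std::vector<double>(N, 1.0)));
  std::vector<double> db(N, 0.0);
  for (size_t t = 0; t < T; ++t)
    for (size_t b = 0; b < B; ++b)
      for (size_t n = 0; n < N; ++n) db[n] += dgate[t][b][n];  // reduce over axes {0, 1}
  std::cout << "db[0] = " << db[0] << " (expect T*B = " << T * B << ")\n";
  return 0;
}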
}  // namespace

const BaseRef DynamicRnnGradFissionV2::DefinePattern() const {
  VarPtr Xs = std::make_shared<SeqVar>();
  return VectorRef({prim::kPrimDynamicRNNGrad, Xs});
}

const AnfNodePtr DynamicRnnGradFissionV2::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                  const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto dynamic_rnn_grad_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
  if (dynamic_rnn_grad_cnode->inputs().size() < kDynamicRNNGradInputNum + 1) {
    MS_LOG(INFO) << "The node " << dynamic_rnn_grad_cnode->DebugString() << " has less than "
                 << kDynamicRNNGradInputNum + 1 << " inputs";
    return nullptr;
  }
  std::vector<AnfNodePtr> new_outputs;
  auto lstm_input_grad = AddLSTMInputGradNode(func_graph, dynamic_rnn_grad_cnode, &new_outputs);

  size_t t_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(7), 0)[0];
  AnfNodePtr concat = nullptr;
  if (t_size != 1) {
    auto splitv = CreateSplitV(func_graph, dynamic_rnn_grad_cnode);
    auto h_concat = CreateHConcat(func_graph, dynamic_rnn_grad_cnode, splitv);
    concat = CreateConcat(func_graph, dynamic_rnn_grad_cnode, h_concat);
  } else {
    concat = CreateConcatNodeT1(func_graph, dynamic_rnn_grad_cnode);
  }

  auto batch_matmul = CreateBatchMatMul(func_graph, lstm_input_grad, concat);
  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
  if (t_size != 1) {
    auto dw_reduce_sum = CreateDwReduceSum(func_graph, dynamic_rnn_grad_cnode, batch_matmul);
    make_tuple_inputs.emplace_back(dw_reduce_sum);
  } else {
    make_tuple_inputs.emplace_back(batch_matmul);
  }

  // create reduce_sum_2
  auto db_reduce_sum = CreateDbReduceSum(func_graph, dynamic_rnn_grad_cnode, lstm_input_grad);
  make_tuple_inputs.emplace_back(db_reduce_sum);
  make_tuple_inputs.insert(make_tuple_inputs.end(), new_outputs.begin(), new_outputs.end());
  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
  return make_tuple;
}
}  // namespace opt
}  // namespace mindspore
@@ -13,21 +13,22 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_H_

#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_V2_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_V2_H_

#include "backend/optimizer/common/optimizer.h"

namespace mindspore {
namespace opt {
class DynamicRNNGradFission : public PatternProcessPass {
class DynamicRnnGradFissionV2 : public PatternProcessPass {
 public:
  explicit DynamicRNNGradFission(bool multigraph = true) : PatternProcessPass("dynamic_rnn_grad_fission", multigraph) {}
  ~DynamicRNNGradFission() override = default;
  explicit DynamicRnnGradFissionV2(bool multigraph = true)
      : PatternProcessPass("dynamic_rnn_grad_fission_v2", multigraph) {}
  ~DynamicRnnGradFissionV2() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
}  // namespace opt
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_H_
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_V2_H_
@@ -234,6 +234,9 @@ constexpr auto kSparseApplyFtrlName = "SparseApplyFtrl";
constexpr auto kSparseApplyFtrlV2Name = "SparseApplyFtrlV2";
constexpr auto kSGDName = "SGD";
constexpr auto kLARSUpdateName = "LARSUpdate";
constexpr auto kBasicLSTMCellCStateGradOpName = "BasicLSTMCellCStateGrad";
constexpr auto kBasicLSTMCellCStateGradV2OpName = "BasicLSTMCellCStateGradV2";
constexpr auto kMatMulV2OpName = "MatMulV2";

// Hcom Op Type
constexpr auto kHcomOpTypeAllReduce = "HcomAllReduce";
@@ -282,6 +282,7 @@ from .inv import _inv_tbe
from .inv_grad import _inv_grad_tbe
from .invert import _invert_tbe
from .basic_lstm_cell import _basic_lstm_cell_tbe
from .basic_lstm_cell_c_state_grad_v2 import _basic_lstm_cell_c_state_grad_tbe_v2
from .basic_lstm_cell_c_state_grad import _basic_lstm_cell_c_state_grad_tbe
from .basic_lstm_cell_weight_grad import _basic_lstm_cell_weight_grad_tbe
from .basic_lstm_cell_input_grad import _basic_lstm_cell_input_grad_tbe
@@ -0,0 +1,51 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""BasicLSTMCellCStateGradV2 op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

basic_lstm_cell_c_state_grad_op_info_v2 = TBERegOp("BasicLSTMCellCStateGradV2") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("basic_lstm_cell_c_state_grad.so") \
    .compute_cost(10) \
    .kernel_name("basic_lstm_cell_c_state_grad_v2") \
    .attr("forget_bias", "optional", "float", "all") \
    .attr("activation", "optional", "str", "all") \
    .partial_flag(True) \
    .input(0, "c", False, "required", "all") \
    .input(1, "dy", False, "required", "all") \
    .input(2, "dht", False, "required", "all") \
    .input(3, "dct", False, "required", "all") \
    .input(4, "it", False, "required", "all") \
    .input(5, "jt", False, "required", "all") \
    .input(6, "ft", False, "required", "all") \
    .input(7, "ot", False, "required", "all") \
    .input(8, "tanhct", False, "required", "all") \
    .output(0, "dgate", False, "required", "all") \
    .output(1, "dct_1", False, "required", "all") \
    .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
                  DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
                  DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
                  DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
                  DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
    .get_op_info()


@op_info_register(basic_lstm_cell_c_state_grad_op_info_v2)
def _basic_lstm_cell_c_state_grad_tbe_v2():
    """BasicLSTMCellCStateGradV2 TBE register"""
    return