support dynamic rnn and dynamic rnn grad op

This commit is contained in:
liubuyu 2020-09-28 15:25:31 +08:00
parent 31d1a7051b
commit a24c5b3231
13 changed files with 438 additions and 8 deletions

View File

@ -140,7 +140,8 @@ static std::map<string, string> tbe_func_adapter_map = {
{"inplace_update", "inplace_update_d"},
{"matrix_diag", "matrix_diag_d"},
{"matrix_diag_part", "matrix_diag_part_d"},
{"matrix_set_diag", "matrix_set_diag_d"}};
{"matrix_set_diag", "matrix_set_diag_d"},
{"l_stm_input_grad", "lstm_input_grad"}};
void TbeAdapter::NormalizeFuncName(std::string *func_name) {
if (func_name == nullptr) {

View File

@ -150,7 +150,13 @@ bool TbeKernelJsonCreator::GenInputDescJson(const std::shared_ptr<AnfNode> &anf_
MS_EXCEPTION_IF_NULL(input_ptr);
MS_EXCEPTION_IF_NULL(input_list);
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
if (input_ptr->name() == "input_indices" && op_name == kTopKOpName) {
if (op_name == kDynamicRNNOpName && input_ptr->name() == "seq_length") {
nlohmann::json input_desc_json;
auto in_name = input_ptr->name();
input_desc_json[kJName] = in_name + std::to_string(input_i);
input_desc_json[kJValid] = false;
input_list->emplace_back(input_desc_json);
} else if (input_ptr->name() == "input_indices" && op_name == kTopKOpName) {
TbeAdapter::GenTopKV2IndicesTensorInfo(anf_node, real_input_index, input_list, creater_type_);
} else {
auto dtype = GetDeviceInputType(anf_node, real_input_index);

View File

@ -19,6 +19,7 @@
#include <memory>
#include <string>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission.h"
#include "backend/optimizer/ascend/ir_fission/bn_split.h"
#include "backend/optimizer/ascend/ir_fission/bn_grad_split.h"
#include "backend/optimizer/ascend/ir_fission/batch_norm_grad_split.h"
@ -107,6 +108,7 @@
#include "backend/optimizer/ascend/ir_fission/concat_fission.h"
#include "backend/optimizer/ascend/ir_fission/pack_fission.h"
#include "backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h"
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_rnn.h"
#include "utils/ms_context.h"
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
@ -278,6 +280,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
}
ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicRNN>());
ir_fusion_pm->AddPass(std::make_shared<DynamicRNNGradFission>());
AddAscendIRFusionRulesPass(ir_fusion_pm.get());
AddAscendIRFusionPass(ir_fusion_pm.get());

View File

@ -0,0 +1,77 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_rnn.h"
#include <vector>
#include <memory>
#include "backend/optimizer/common/helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/utils.h"
#include "abstract/abstract_value.h"
#include "base/core_ops.h"
namespace mindspore {
namespace opt {
const BaseRef InsertPlaceholderForDynamicRNN::DefinePattern() const {
std::shared_ptr<Var> V = std::make_shared<CondVar>(UnVisited);
std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
return VectorRef({V, Xs});
}
const AnfNodePtr InsertPlaceholderForDynamicRNN::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto op_name = AnfAlgo::GetCNodeName(cnode);
if (op_name != kDynamicRNNOpName) {
return nullptr;
}
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
MS_EXCEPTION_IF_NULL(kernel_graph);
size_t input_num = AnfAlgo::GetInputTensorNum(node);
if (input_num == 0) {
return nullptr;
}
std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
for (size_t in_idx = 0; in_idx < input_num; in_idx++) {
auto input_node = AnfAlgo::GetInputNode(cnode, in_idx);
if (in_idx == 3) {
auto value = std::make_shared<None>();
auto value_node = NewValueNode(value);
value_node->set_abstract(std::make_shared<abstract::AbstractNone>());
auto new_node = kernel_graph->NewValueNode(value_node);
kernel_graph->AddValueNodeToGraph(new_node);
new_inputs.push_back(new_node);
}
new_inputs.push_back(input_node);
}
CNodePtr new_node = nullptr;
if (kernel_graph == nullptr) {
new_node = std::make_shared<CNode>(*cnode);
} else {
new_node = kernel_graph->NewCNode(cnode);
}
MS_EXCEPTION_IF_NULL(new_node);
new_node->set_inputs(new_inputs);
return new_node;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_ADD_PLACEHOLDER_FOR_DYNAMIC_RNN_H
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_ADD_PLACEHOLDER_FOR_DYNAMIC_RNN_H
#include <memory>
#include <vector>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ascend_helper.h"
namespace mindspore {
namespace opt {
class InsertPlaceholderForDynamicRNN : public PatternProcessPass {
public:
explicit InsertPlaceholderForDynamicRNN(bool multigraph = true)
: PatternProcessPass("add_placeholder_for_dynamic_rnn", multigraph) {}
~InsertPlaceholderForDynamicRNN() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_ADD_PLACEHOLDER_FOR_DYNAMIC_RNN_H

View File

@ -0,0 +1,250 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission.h"
#include <vector>
#include <memory>
#include <algorithm>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore {
namespace opt {
constexpr size_t kDynamicRNNGradInputNum = 16;
constexpr size_t kLSTMInputGradOutputNum = 4;
const BaseRef DynamicRNNGradFission::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
return VectorRef({prim::kPrimDynamicRNNGrad, Xs});
}
AnfNodePtr CreateSplitVD(const FuncGraphPtr &graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
// SplitV
std::vector<AnfNodePtr> splitvd_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())), node};
auto split_vd = graph->NewCNode(splitvd_input);
MS_EXCEPTION_IF_NULL(split_vd);
auto dtypes = {AnfAlgo::GetOutputInferDataType(node, 0), AnfAlgo::GetOutputInferDataType(node, 0)};
std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node, 0)[0] - 1, AnfAlgo::GetOutputInferShape(node, 0)[1],
AnfAlgo::GetOutputInferShape(node, 0)[2]};
auto shape2 = {IntToSize(1), AnfAlgo::GetOutputInferShape(node, 0)[1], AnfAlgo::GetOutputInferShape(node, 0)[2]};
std::vector<std::vector<size_t>> shapes = {shape, shape2};
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_vd.get());
AnfAlgo::SetNodeAttr("split_dim", MakeValue(0), split_vd);
AnfAlgo::SetNodeAttr("num_split", MakeValue(2), split_vd);
int tmp = SizeToInt(AnfAlgo::GetOutputInferShape(node, 0)[0]) - 1;
AnfAlgo::SetNodeAttr("size_splits", MakeValue(std::vector<int>{tmp, 1}), split_vd);
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), split_vd);
return split_vd;
}
AnfNodePtr CreateLSTMInputGrad(const FuncGraphPtr &graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const auto &dynamic_rnn_grad_inputs = cnode->inputs();
std::vector<AnfNodePtr> lstm_input_grad_inputs = {NewValueNode(std::make_shared<Primitive>(kLSTMInputGradOpName)),
dynamic_rnn_grad_inputs[2],
dynamic_rnn_grad_inputs[6],
dynamic_rnn_grad_inputs[8],
dynamic_rnn_grad_inputs[9],
dynamic_rnn_grad_inputs[10],
dynamic_rnn_grad_inputs[11],
dynamic_rnn_grad_inputs[12],
dynamic_rnn_grad_inputs[13],
dynamic_rnn_grad_inputs[14],
dynamic_rnn_grad_inputs[15],
dynamic_rnn_grad_inputs[16]};
std::vector<AnfNodePtr> ori_outputs;
CreateMultipleOutputsOfAnfNode(graph, node, 5, &ori_outputs);
auto lstm_op = graph->NewCNode(lstm_input_grad_inputs);
MS_EXCEPTION_IF_NULL(lstm_op);
auto ori_type = AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_inputs[8], 0);
auto types = {AnfAlgo::GetOutputInferDataType(ori_outputs[2], 0), AnfAlgo::GetOutputInferDataType(ori_outputs[3], 0),
AnfAlgo::GetOutputInferDataType(ori_outputs[4], 0), ori_type};
std::vector<size_t> ori_shape = {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_inputs[8], 0)[0],
AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_inputs[8], 0)[1],
4 * AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_inputs[8], 0)[2]};
auto shapes = {AnfAlgo::GetOutputInferShape(ori_outputs[2], 0), AnfAlgo::GetOutputInferShape(ori_outputs[3], 0),
AnfAlgo::GetOutputInferShape(ori_outputs[4], 0), ori_shape};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, lstm_op.get());
return lstm_op;
}
AnfNodePtr CreateBatchMatMul(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node1);
MS_EXCEPTION_IF_NULL(node2);
// BatchMatMul
std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name())),
node2, node1};
auto batch_matmul = graph->NewCNode(matmul_inputs);
MS_EXCEPTION_IF_NULL(batch_matmul);
auto types = {AnfAlgo::GetOutputInferDataType(node1, 0)};
std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node2, 0)[0], AnfAlgo::GetOutputInferShape(node2, 0)[2],
AnfAlgo::GetOutputInferShape(node1, 0)[2]};
auto shapes = {shape};
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), batch_matmul);
AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(true), batch_matmul);
AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(false), batch_matmul);
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, batch_matmul.get());
return batch_matmul;
}
AnfNodePtr AddHConcatD(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node1);
MS_EXCEPTION_IF_NULL(node2);
std::vector<AnfNodePtr> ori_outputs;
CreateMultipleOutputsOfAnfNode(graph, node2, 2, &ori_outputs);
auto ori_shape = AnfAlgo::GetOutputInferShape(node1, 0);
std::vector<std::vector<size_t>> shape_tmp;
if (ori_shape.size() == 3) {
shape_tmp = {ori_shape};
} else {
shape_tmp = {{IntToSize(1), ori_shape[0], ori_shape[1]}};
}
auto ori_dtype = {AnfAlgo::GetOutputInferDataType(node1, 0)};
// reshape
std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
node1};
auto reshape = graph->NewCNode(reshape_input);
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape);
AnfAlgo::SetOutputInferTypeAndShape(ori_dtype, shape_tmp, reshape.get());
// concatd --> concat
std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
reshape, ori_outputs[0]};
auto concat_op = graph->NewCNode(concat_inputs);
MS_EXCEPTION_IF_NULL(concat_op);
std::vector<size_t> input = {AnfAlgo::GetOutputInferShape(node2, 0)[0] + 1, AnfAlgo::GetOutputInferShape(node2, 0)[1],
AnfAlgo::GetOutputInferShape(node2, 0)[2]};
auto types = {AnfAlgo::GetOutputInferDataType(node1, 0)};
auto shapes = {input};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, concat_op.get());
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat_op);
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{2}), concat_op);
AnfAlgo::SetNodeAttr("axis", MakeValue(0), concat_op);
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat_op);
return concat_op;
}
AnfNodePtr AddConcatD(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node1);
MS_EXCEPTION_IF_NULL(node2);
// concatd --> concat
std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())), node1,
node2};
auto concat_op = graph->NewCNode(concat_inputs);
MS_EXCEPTION_IF_NULL(concat_op);
std::vector<size_t> input = {AnfAlgo::GetOutputInferShape(node1, 0)[0], AnfAlgo::GetOutputInferShape(node1, 0)[1],
AnfAlgo::GetOutputInferShape(node1, 0)[2] + AnfAlgo::GetOutputInferShape(node2, 0)[2]};
auto types = {AnfAlgo::GetOutputInferDataType(node1, 0)};
auto shapes = {input};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, concat_op.get());
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat_op);
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{2}), concat_op);
AnfAlgo::SetNodeAttr("axis", MakeValue(2), concat_op);
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat_op);
return concat_op;
}
AnfNodePtr AddDwReduceSum(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
// node1 : dynamic output
// node2 : matmul
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node1);
MS_EXCEPTION_IF_NULL(node2);
std::vector<AnfNodePtr> ori_outputs;
CreateMultipleOutputsOfAnfNode(graph, node1, 5, &ori_outputs);
// ReduceSumd
std::vector<AnfNodePtr> reducesum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
node2};
auto reduce_sumd = graph->NewCNode(reducesum_inputs);
MS_EXCEPTION_IF_NULL(reduce_sumd);
auto types = {AnfAlgo::GetOutputInferDataType(ori_outputs[0], 0)};
auto shapes = {AnfAlgo::GetOutputInferShape(ori_outputs[0], 0)};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, reduce_sumd.get());
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int>{0}), reduce_sumd);
AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_sumd);
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sumd);
return reduce_sumd;
}
AnfNodePtr AddDbReduceSum(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
// node1 lstm output
// node2 // dynamic output
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node1);
MS_EXCEPTION_IF_NULL(node2);
std::vector<AnfNodePtr> ori_outputs;
CreateMultipleOutputsOfAnfNode(graph, node2, 5, &ori_outputs);
// ReduceSumd --> ReduceSum
std::vector<AnfNodePtr> reducerum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
node1};
auto reduce_sumd = graph->NewCNode(reducerum_inputs);
MS_EXCEPTION_IF_NULL(reduce_sumd);
auto types = {AnfAlgo::GetOutputInferDataType(ori_outputs[1], 0)};
auto shapes = {AnfAlgo::GetOutputInferShape(ori_outputs[1], 0)};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, reduce_sumd.get());
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int>{0, 1}), reduce_sumd);
AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_sumd);
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sumd);
return reduce_sumd;
}
const AnfNodePtr DynamicRNNGradFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() < kDynamicRNNGradInputNum + 1) {
MS_LOG(INFO) << "The input num of DynamicRNNGrad less than" << kDynamicRNNGradInputNum
<< ". The node should not be changed";
return nullptr;
}
// input_list of dynamic_rnn_grad
const auto &ori_inputs = cnode->inputs();
// create split_vd
auto split_vd = CreateSplitVD(func_graph, ori_inputs[7]);
// create concat_1
auto h_concat = AddHConcatD(func_graph, ori_inputs[5], split_vd);
// create concat_2
auto concat = AddConcatD(func_graph, ori_inputs[1], h_concat);
// create lsym_input_grad
auto lstm_input_grad = CreateLSTMInputGrad(func_graph, cnode);
std::vector<AnfNodePtr> lstm_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, lstm_input_grad, kLSTMInputGradOutputNum, &lstm_outputs);
// create matmul
auto batch_matmul = CreateBatchMatMul(func_graph, lstm_outputs[3], concat);
// create reduce_sum_1
auto dw_reduce_sum = AddDwReduceSum(func_graph, node, batch_matmul);
// create reduce_sum_2
auto db_reduce_sum = AddDbReduceSum(func_graph, lstm_outputs[3], node);
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple),
dw_reduce_sum,
db_reduce_sum,
lstm_outputs[0],
lstm_outputs[1],
lstm_outputs[2]};
auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
MS_EXCEPTION_IF_NULL(make_tuple);
return make_tuple;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,33 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_H_
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class DynamicRNNGradFission : public PatternProcessPass {
public:
explicit DynamicRNNGradFission(bool multigraph = true) : PatternProcessPass("dynamic_rnn_grad_fission", multigraph) {}
~DynamicRNNGradFission() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_H_

View File

@ -21,11 +21,10 @@
namespace mindspore {
namespace opt {
const std::set<std::pair<string, string>> invalid_formats_pair = {{kOpFormat_C1HWNCoC0, kOpFormat_NCHW},
{kOpFormat_NCHW, kOpFormat_C1HWNCoC0},
{kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT},
{kOpFormat_DEFAULT, kOpFormat_FRACTAL_ZN_LSTM},
{kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}};
const std::set<std::pair<string, string>> invalid_formats_pair = {
{kOpFormat_C1HWNCoC0, kOpFormat_NCHW}, {kOpFormat_NCHW, kOpFormat_C1HWNCoC0},
{kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT}, {kOpFormat_DEFAULT, kOpFormat_FRACTAL_ZN_LSTM},
{kOpFormat_FRACTAL_ZN_LSTM, kOpFormat_DEFAULT}, {kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}};
bool TransDataSplit::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
@ -83,6 +82,9 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n
new_transpose_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_,
false, prim::kPrimTranspose->name());
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{2, 3, 1, 0}), new_transpose_node);
if (output_format == kOpFormat_FRACTAL_ZN_LSTM) {
AnfAlgo::SetNodeAttr("nop_op", MakeValue(true), new_transpose_node);
}
RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transpose_node);
// trans hwcn to output_format

View File

@ -404,7 +404,11 @@ bool IsNopNode(const AnfNodePtr &node) {
}
CNodePtr cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (nop_nodes.find(AnfAlgo::GetCNodeName(cnode)) == nop_nodes.end()) {
bool is_nop_node = false;
if (AnfAlgo::HasNodeAttr("nop_op", cnode)) {
is_nop_node = AnfAlgo::GetNodeAttr<bool>(cnode, "nop_op");
}
if (nop_nodes.find(AnfAlgo::GetCNodeName(cnode)) == nop_nodes.end() && !is_nop_node) {
return false;
}
return true;

View File

@ -52,8 +52,12 @@ const int kUnSupportMixedDataTypeIndex = -1;
bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) {
MS_EXCEPTION_IF_NULL(cnode);
// Check input data type
auto name = AnfAlgo::GetCNodeName(cnode);
for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) {
TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
if (name == kDynamicRNNOpName && input_origin_type == kMetaTypeNone) {
continue;
}
if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) {
return false;
}
@ -478,6 +482,9 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
continue;
}
if (selected_kernel_info.GetInputFormat(input_index) == kOpFormat_FRACTAL_ZN_LSTM) {
continue;
}
// we set special device info of a input tensor.
bool is_ref = false;
auto op_info = kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE);

View File

@ -127,8 +127,12 @@ bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_i
auto kernel_mod = AnfAlgo::GetKernelMod(anf_node_ptr);
MS_EXCEPTION_IF_NULL(kernel_mod);
kernel_mod->set_kernel_name(anf_node_ptr->fullname_with_scope());
auto op_name = AnfAlgo::GetCNodeName(anf_node_ptr);
if (AnfAlgo::GetCNodeName(anf_node_ptr) != kAtomicAddrCleanOpName) {
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node_ptr); ++i) {
if (op_name == kDynamicRNNOpName && i == 3) {
continue;
}
auto real_input_index = AnfAlgo::GetRealInputIndex(anf_node_ptr, i);
auto device_address = AnfAlgo::GetPrevNodeOutputAddr(anf_node_ptr, real_input_index);
AddressPtr input = std::make_shared<Address>();

View File

@ -219,6 +219,8 @@ constexpr auto kFusedScaleApplyMomentum = "FusedScaleApplyMomentum";
constexpr auto kBasicLSTMCellWeightGradOpName = "BasicLSTMCellWeightGrad";
constexpr auto kBasicLSTMCellInputGradOpName = "BasicLSTMCellInputGrad";
constexpr auto kBasicLSTMCellOpName = "BasicLSTMCell";
constexpr auto kDynamicRNNOpName = "DynamicRNN";
constexpr auto kLSTMInputGradOpName = "LSTMInputGrad";
// attr key name
constexpr auto kAttrInputNames = "input_names";

View File

@ -105,6 +105,8 @@ inline const PrimitivePtr kPrimArgMaxWithValue = std::make_shared<Primitive>("Ar
inline const PrimitivePtr kPrimUnique = std::make_shared<Primitive>("Unique");
inline const PrimitivePtr kPrimUniqueGrad = std::make_shared<Primitive>("UniqueGrad");
inline const PrimitivePtr kPrimExtractImagePatches = std::make_shared<Primitive>("ExtractImagePatches");
inline const PrimitivePtr kPrimDynamicRNN = std::make_shared<Primitive>("DynamicRNN");
inline const PrimitivePtr kPrimDynamicRNNGrad = std::make_shared<Primitive>("DynamicRNNGrad");
// NN
inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
@ -214,6 +216,7 @@ inline const PrimitivePtr kPrimRound = std::make_shared<Primitive>("Round");
inline const PrimitivePtr kPrimExp = std::make_shared<Primitive>("Exp");
inline const PrimitivePtr kPrimLog = std::make_shared<Primitive>("Log");
inline const PrimitivePtr kPrimRsqrt = std::make_shared<Primitive>("Rsqrt");
inline const PrimitivePtr kPrimSplitV = std::make_shared<Primitive>("SplitV");
// Statements
inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return");