support dynamic rnn and dynamic rnn grad op
parent 31d1a7051b
commit a24c5b3231

@@ -140,7 +140,8 @@ static std::map<string, string> tbe_func_adapter_map = {
   {"inplace_update", "inplace_update_d"},
   {"matrix_diag", "matrix_diag_d"},
   {"matrix_diag_part", "matrix_diag_part_d"},
-  {"matrix_set_diag", "matrix_set_diag_d"}};
+  {"matrix_set_diag", "matrix_set_diag_d"},
+  {"l_stm_input_grad", "lstm_input_grad"}};

 void TbeAdapter::NormalizeFuncName(std::string *func_name) {
   if (func_name == nullptr) {

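For illustration, a minimal standalone sketch of the lookup this map drives (a sketch, not the real TbeAdapter; it assumes the op name has already been snake-cased by NormalizeFuncName, e.g. "LSTMInputGrad" -> "l_stm_input_grad"):

#include <iostream>
#include <map>
#include <string>

// Toy adapter table: maps the snake-cased graph op name to the TBE kernel name.
static const std::map<std::string, std::string> adapter_map = {
    {"matrix_set_diag", "matrix_set_diag_d"}, {"l_stm_input_grad", "lstm_input_grad"}};

std::string AdaptFuncName(const std::string &snake_cased) {
  auto it = adapter_map.find(snake_cased);
  return it == adapter_map.end() ? snake_cased : it->second;
}

int main() { std::cout << AdaptFuncName("l_stm_input_grad") << "\n"; }  // prints: lstm_input_grad
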
@@ -150,7 +150,13 @@ bool TbeKernelJsonCreator::GenInputDescJson(const std::shared_ptr<AnfNode> &anf_
   MS_EXCEPTION_IF_NULL(input_ptr);
   MS_EXCEPTION_IF_NULL(input_list);
   std::string op_name = AnfAlgo::GetCNodeName(anf_node);
-  if (input_ptr->name() == "input_indices" && op_name == kTopKOpName) {
+  if (op_name == kDynamicRNNOpName && input_ptr->name() == "seq_length") {
+    nlohmann::json input_desc_json;
+    auto in_name = input_ptr->name();
+    input_desc_json[kJName] = in_name + std::to_string(input_i);
+    input_desc_json[kJValid] = false;
+    input_list->emplace_back(input_desc_json);
+  } else if (input_ptr->name() == "input_indices" && op_name == kTopKOpName) {
     TbeAdapter::GenTopKV2IndicesTensorInfo(anf_node, real_input_index, input_list, creater_type_);
   } else {
     auto dtype = GetDeviceInputType(anf_node, real_input_index);

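For illustration, the placeholder descriptor the new seq_length branch emits, sketched standalone with nlohmann::json (assuming kJName and kJValid expand to "name" and "valid"; input_i is the input position):

#include <iostream>
#include <nlohmann/json.hpp>

int main() {
  nlohmann::json input_desc_json;
  size_t input_i = 3;  // hypothetical position of seq_length
  input_desc_json["name"] = "seq_length" + std::to_string(input_i);
  input_desc_json["valid"] = false;  // marks the tensor as absent for TBE codegen
  std::cout << input_desc_json.dump(2) << "\n";
}
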
@@ -19,6 +19,7 @@
 #include <memory>
 #include <string>
 #include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission.h"
 #include "backend/optimizer/ascend/ir_fission/bn_split.h"
 #include "backend/optimizer/ascend/ir_fission/bn_grad_split.h"
 #include "backend/optimizer/ascend/ir_fission/batch_norm_grad_split.h"

@@ -107,6 +108,7 @@
 #include "backend/optimizer/ascend/ir_fission/concat_fission.h"
 #include "backend/optimizer/ascend/ir_fission/pack_fission.h"
 #include "backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h"
+#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_rnn.h"
 #include "utils/ms_context.h"
 #include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
 #include "backend/optimizer/graph_kernel/basic_ops_fusion.h"

@@ -278,6 +280,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
   }
   ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
   ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
+  ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicRNN>());
+  ir_fusion_pm->AddPass(std::make_shared<DynamicRNNGradFission>());
   AddAscendIRFusionRulesPass(ir_fusion_pm.get());
   AddAscendIRFusionPass(ir_fusion_pm.get());

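Registration order matters here: the placeholder pass has to rewrite DynamicRNN before later fusion passes and kernel selection see the node, and the grad fission runs in the same IR-fusion stage. A toy sketch of that contract (not MindSpore's PassManager):

#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct Pass {
  std::string name;
  std::function<void()> run;
};

int main() {
  std::vector<Pass> ir_fusion_pm;  // passes execute in registration order
  ir_fusion_pm.push_back({"add_placeholder_for_dynamic_rnn", [] { std::cout << "pad DynamicRNN inputs\n"; }});
  ir_fusion_pm.push_back({"dynamic_rnn_grad_fission", [] { std::cout << "split DynamicRNNGrad\n"; }});
  for (const auto &pass : ir_fusion_pm) pass.run();
}
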
@@ -0,0 +1,77 @@ backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_rnn.cc (new file)
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_rnn.h"
#include <vector>
#include <memory>
#include "backend/optimizer/common/helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/utils.h"
#include "abstract/abstract_value.h"
#include "base/core_ops.h"

namespace mindspore {
namespace opt {
const BaseRef InsertPlaceholderForDynamicRNN::DefinePattern() const {
  std::shared_ptr<Var> V = std::make_shared<CondVar>(UnVisited);
  std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
  return VectorRef({V, Xs});
}

const AnfNodePtr InsertPlaceholderForDynamicRNN::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                         const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto op_name = AnfAlgo::GetCNodeName(cnode);
  if (op_name != kDynamicRNNOpName) {
    return nullptr;
  }
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
  auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  size_t input_num = AnfAlgo::GetInputTensorNum(node);
  if (input_num == 0) {
    return nullptr;
  }

  std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
  for (size_t in_idx = 0; in_idx < input_num; in_idx++) {
    auto input_node = AnfAlgo::GetInputNode(cnode, in_idx);
    if (in_idx == 3) {
      // Splice a None value node in ahead of input 3 (the omitted seq_length),
      // so the remaining inputs keep the positions the TBE kernel expects.
      auto value = std::make_shared<None>();
      auto value_node = NewValueNode(value);
      value_node->set_abstract(std::make_shared<abstract::AbstractNone>());
      auto new_node = kernel_graph->NewValueNode(value_node);
      kernel_graph->AddValueNodeToGraph(new_node);
      new_inputs.push_back(new_node);
    }
    new_inputs.push_back(input_node);
  }

  CNodePtr new_node = nullptr;
  if (kernel_graph == nullptr) {
    new_node = std::make_shared<CNode>(*cnode);
  } else {
    new_node = kernel_graph->NewCNode(cnode);
  }
  MS_EXCEPTION_IF_NULL(new_node);
  new_node->set_inputs(new_inputs);
  return new_node;
}
}  // namespace opt
}  // namespace mindspore

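The heart of the pass is the rewiring loop above: a None value node is spliced in at position 3, the slot DynamicRNN reserves for the omitted seq_length input. A standalone sketch of its effect on the input list (input names are illustrative; DynamicRNN's inputs are x, w, b, seq_length, init_h, init_c):

#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> inputs = {"x", "w", "b", "init_h", "init_c"};  // seq_length omitted by the frontend
  std::vector<std::string> new_inputs;
  for (size_t in_idx = 0; in_idx < inputs.size(); ++in_idx) {
    if (in_idx == 3) new_inputs.push_back("None");  // placeholder keeps later inputs at their expected index
    new_inputs.push_back(inputs[in_idx]);
  }
  for (const auto &name : new_inputs) std::cout << name << " ";  // x w b None init_h init_c
  std::cout << "\n";
}
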
@@ -0,0 +1,37 @@ backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_rnn.h (new file)
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_ADD_PLACEHOLDER_FOR_DYNAMIC_RNN_H
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_ADD_PLACEHOLDER_FOR_DYNAMIC_RNN_H

#include <memory>
#include <vector>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ascend_helper.h"

namespace mindspore {
namespace opt {
class InsertPlaceholderForDynamicRNN : public PatternProcessPass {
 public:
  explicit InsertPlaceholderForDynamicRNN(bool multigraph = true)
      : PatternProcessPass("add_placeholder_for_dynamic_rnn", multigraph) {}
  ~InsertPlaceholderForDynamicRNN() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
}  // namespace opt
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_ADD_PLACEHOLDER_FOR_DYNAMIC_RNN_H

@@ -0,0 +1,250 @@ backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission.cc (new file)
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission.h"
#include <vector>
#include <memory>
#include <algorithm>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/helper.h"

namespace mindspore {
namespace opt {
constexpr size_t kDynamicRNNGradInputNum = 16;
constexpr size_t kLSTMInputGradOutputNum = 4;
const BaseRef DynamicRNNGradFission::DefinePattern() const {
  VarPtr Xs = std::make_shared<SeqVar>();
  return VectorRef({prim::kPrimDynamicRNNGrad, Xs});
}

AnfNodePtr CreateSplitVD(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  // SplitV: peel the last time step off axis 0 (size_splits {t - 1, 1})
  std::vector<AnfNodePtr> splitvd_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())), node};
  auto split_vd = graph->NewCNode(splitvd_input);
  MS_EXCEPTION_IF_NULL(split_vd);
  auto dtypes = {AnfAlgo::GetOutputInferDataType(node, 0), AnfAlgo::GetOutputInferDataType(node, 0)};
  std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node, 0)[0] - 1, AnfAlgo::GetOutputInferShape(node, 0)[1],
                               AnfAlgo::GetOutputInferShape(node, 0)[2]};
  auto shape2 = {IntToSize(1), AnfAlgo::GetOutputInferShape(node, 0)[1], AnfAlgo::GetOutputInferShape(node, 0)[2]};
  std::vector<std::vector<size_t>> shapes = {shape, shape2};
  AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_vd.get());
  AnfAlgo::SetNodeAttr("split_dim", MakeValue(0), split_vd);
  AnfAlgo::SetNodeAttr("num_split", MakeValue(2), split_vd);
  int tmp = SizeToInt(AnfAlgo::GetOutputInferShape(node, 0)[0]) - 1;
  AnfAlgo::SetNodeAttr("size_splits", MakeValue(std::vector<int>{tmp, 1}), split_vd);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), split_vd);
  return split_vd;
}

AnfNodePtr CreateLSTMInputGrad(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  const auto &dynamic_rnn_grad_inputs = cnode->inputs();
  std::vector<AnfNodePtr> lstm_input_grad_inputs = {NewValueNode(std::make_shared<Primitive>(kLSTMInputGradOpName)),
                                                    dynamic_rnn_grad_inputs[2],
                                                    dynamic_rnn_grad_inputs[6],
                                                    dynamic_rnn_grad_inputs[8],
                                                    dynamic_rnn_grad_inputs[9],
                                                    dynamic_rnn_grad_inputs[10],
                                                    dynamic_rnn_grad_inputs[11],
                                                    dynamic_rnn_grad_inputs[12],
                                                    dynamic_rnn_grad_inputs[13],
                                                    dynamic_rnn_grad_inputs[14],
                                                    dynamic_rnn_grad_inputs[15],
                                                    dynamic_rnn_grad_inputs[16]};
  std::vector<AnfNodePtr> ori_outputs;
  CreateMultipleOutputsOfAnfNode(graph, node, 5, &ori_outputs);
  auto lstm_op = graph->NewCNode(lstm_input_grad_inputs);
  MS_EXCEPTION_IF_NULL(lstm_op);
  auto ori_type = AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_inputs[8], 0);
  auto types = {AnfAlgo::GetOutputInferDataType(ori_outputs[2], 0), AnfAlgo::GetOutputInferDataType(ori_outputs[3], 0),
                AnfAlgo::GetOutputInferDataType(ori_outputs[4], 0), ori_type};
  std::vector<size_t> ori_shape = {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_inputs[8], 0)[0],
                                   AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_inputs[8], 0)[1],
                                   4 * AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_inputs[8], 0)[2]};
  auto shapes = {AnfAlgo::GetOutputInferShape(ori_outputs[2], 0), AnfAlgo::GetOutputInferShape(ori_outputs[3], 0),
                 AnfAlgo::GetOutputInferShape(ori_outputs[4], 0), ori_shape};
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, lstm_op.get());
  return lstm_op;
}

AnfNodePtr CreateBatchMatMul(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node1);
  MS_EXCEPTION_IF_NULL(node2);
  // BatchMatMul
  std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name())),
                                           node2, node1};
  auto batch_matmul = graph->NewCNode(matmul_inputs);
  MS_EXCEPTION_IF_NULL(batch_matmul);
  auto types = {AnfAlgo::GetOutputInferDataType(node1, 0)};
  std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node2, 0)[0], AnfAlgo::GetOutputInferShape(node2, 0)[2],
                               AnfAlgo::GetOutputInferShape(node1, 0)[2]};
  auto shapes = {shape};
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), batch_matmul);
  AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(true), batch_matmul);
  AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(false), batch_matmul);
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, batch_matmul.get());
  return batch_matmul;
}

AnfNodePtr AddHConcatD(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node1);
  MS_EXCEPTION_IF_NULL(node2);
  std::vector<AnfNodePtr> ori_outputs;
  CreateMultipleOutputsOfAnfNode(graph, node2, 2, &ori_outputs);
  auto ori_shape = AnfAlgo::GetOutputInferShape(node1, 0);
  std::vector<std::vector<size_t>> shape_tmp;
  if (ori_shape.size() == 3) {
    shape_tmp = {ori_shape};
  } else {
    shape_tmp = {{IntToSize(1), ori_shape[0], ori_shape[1]}};
  }
  auto ori_dtype = {AnfAlgo::GetOutputInferDataType(node1, 0)};
  // reshape
  std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                           node1};
  auto reshape = graph->NewCNode(reshape_input);
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape);
  AnfAlgo::SetOutputInferTypeAndShape(ori_dtype, shape_tmp, reshape.get());

  // concatd --> concat
  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
                                           reshape, ori_outputs[0]};
  auto concat_op = graph->NewCNode(concat_inputs);
  MS_EXCEPTION_IF_NULL(concat_op);
  std::vector<size_t> input = {AnfAlgo::GetOutputInferShape(node2, 0)[0] + 1, AnfAlgo::GetOutputInferShape(node2, 0)[1],
                               AnfAlgo::GetOutputInferShape(node2, 0)[2]};
  auto types = {AnfAlgo::GetOutputInferDataType(node1, 0)};
  auto shapes = {input};
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, concat_op.get());
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat_op);
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{2}), concat_op);
  AnfAlgo::SetNodeAttr("axis", MakeValue(0), concat_op);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat_op);
  return concat_op;
}

AnfNodePtr AddConcatD(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node1);
  MS_EXCEPTION_IF_NULL(node2);
  // concatd --> concat
  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())), node1,
                                           node2};
  auto concat_op = graph->NewCNode(concat_inputs);
  MS_EXCEPTION_IF_NULL(concat_op);
  std::vector<size_t> input = {AnfAlgo::GetOutputInferShape(node1, 0)[0], AnfAlgo::GetOutputInferShape(node1, 0)[1],
                               AnfAlgo::GetOutputInferShape(node1, 0)[2] + AnfAlgo::GetOutputInferShape(node2, 0)[2]};
  auto types = {AnfAlgo::GetOutputInferDataType(node1, 0)};
  auto shapes = {input};
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, concat_op.get());
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat_op);
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{2}), concat_op);
  AnfAlgo::SetNodeAttr("axis", MakeValue(2), concat_op);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat_op);
  return concat_op;
}

AnfNodePtr AddDwReduceSum(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
  // node1 : dynamic_rnn_grad output
  // node2 : matmul output
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node1);
  MS_EXCEPTION_IF_NULL(node2);
  std::vector<AnfNodePtr> ori_outputs;
  CreateMultipleOutputsOfAnfNode(graph, node1, 5, &ori_outputs);
  // ReduceSumd
  std::vector<AnfNodePtr> reducesum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
                                              node2};
  auto reduce_sumd = graph->NewCNode(reducesum_inputs);
  MS_EXCEPTION_IF_NULL(reduce_sumd);
  auto types = {AnfAlgo::GetOutputInferDataType(ori_outputs[0], 0)};
  auto shapes = {AnfAlgo::GetOutputInferShape(ori_outputs[0], 0)};
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, reduce_sumd.get());
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int>{0}), reduce_sumd);
  AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_sumd);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sumd);
  return reduce_sumd;
}

AnfNodePtr AddDbReduceSum(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
  // node1 : lstm_input_grad output
  // node2 : dynamic_rnn_grad output
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node1);
  MS_EXCEPTION_IF_NULL(node2);
  std::vector<AnfNodePtr> ori_outputs;
  CreateMultipleOutputsOfAnfNode(graph, node2, 5, &ori_outputs);
  // ReduceSumd --> ReduceSum
  std::vector<AnfNodePtr> reducesum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
                                              node1};
  auto reduce_sumd = graph->NewCNode(reducesum_inputs);
  MS_EXCEPTION_IF_NULL(reduce_sumd);
  auto types = {AnfAlgo::GetOutputInferDataType(ori_outputs[1], 0)};
  auto shapes = {AnfAlgo::GetOutputInferShape(ori_outputs[1], 0)};
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, reduce_sumd.get());
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int>{0, 1}), reduce_sumd);
  AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_sumd);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sumd);
  return reduce_sumd;
}

const AnfNodePtr DynamicRNNGradFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (cnode->size() < kDynamicRNNGradInputNum + 1) {
    MS_LOG(INFO) << "The input num of DynamicRNNGrad is less than " << kDynamicRNNGradInputNum
                 << ". The node should not be changed";
    return nullptr;
  }
  // input_list of dynamic_rnn_grad
  const auto &ori_inputs = cnode->inputs();
  // create split_vd
  auto split_vd = CreateSplitVD(func_graph, ori_inputs[7]);
  // create concat_1
  auto h_concat = AddHConcatD(func_graph, ori_inputs[5], split_vd);
  // create concat_2
  auto concat = AddConcatD(func_graph, ori_inputs[1], h_concat);
  // create lstm_input_grad
  auto lstm_input_grad = CreateLSTMInputGrad(func_graph, cnode);
  std::vector<AnfNodePtr> lstm_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_input_grad, kLSTMInputGradOutputNum, &lstm_outputs);
  // create matmul
  auto batch_matmul = CreateBatchMatMul(func_graph, lstm_outputs[3], concat);
  // create reduce_sum_1
  auto dw_reduce_sum = AddDwReduceSum(func_graph, node, batch_matmul);
  // create reduce_sum_2
  auto db_reduce_sum = AddDbReduceSum(func_graph, lstm_outputs[3], node);
  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple),
                                               dw_reduce_sum,
                                               db_reduce_sum,
                                               lstm_outputs[0],
                                               lstm_outputs[1],
                                               lstm_outputs[2]};
  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
  MS_EXCEPTION_IF_NULL(make_tuple);
  return make_tuple;
}
}  // namespace opt
}  // namespace mindspore

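Net effect: one DynamicRNNGrad node is split into SplitV + two ConcatD + LSTMInputGrad + BatchMatMul + two ReduceSum, with dw and db recovered from the matmul of concat(x, h) against the gate gradients. A standalone shape walk-through under hypothetical sizes (t=5, batch=8, input_size=16, hidden=32):

#include <cstddef>
#include <iostream>

int main() {
  std::size_t t = 5, batch = 8, input_size = 16, hidden = 32;
  // SplitV:        h [t, batch, hidden] -> [t-1, batch, hidden] and [1, batch, hidden]
  // HConcatD:      reshape(init_h) [1, batch, hidden] + split[0] -> h_concat [t, batch, hidden]
  // ConcatD:       x [t, batch, input_size] + h_concat on axis 2 -> [t, batch, input_size + hidden]
  // LSTMInputGrad: gate gradient dgate [t, batch, 4 * hidden]
  // BatchMatMul (transpose_x1): concat^T x dgate -> [t, input_size + hidden, 4 * hidden]
  // ReduceSum axis 0 -> dw; ReduceSum axes {0, 1} over dgate -> db
  std::cout << "dw: [" << input_size + hidden << ", " << 4 * hidden << "]\n";  // dw: [48, 128]
  std::cout << "db: [" << 4 * hidden << "]\n";                                 // db: [128]
}
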
@@ -0,0 +1,33 @@ backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission.h (new file)
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_H_

#include "backend/optimizer/common/optimizer.h"

namespace mindspore {
namespace opt {
class DynamicRNNGradFission : public PatternProcessPass {
 public:
  explicit DynamicRNNGradFission(bool multigraph = true) : PatternProcessPass("dynamic_rnn_grad_fission", multigraph) {}
  ~DynamicRNNGradFission() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
}  // namespace opt
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_H_

@@ -21,11 +21,10 @@

 namespace mindspore {
 namespace opt {
-const std::set<std::pair<string, string>> invalid_formats_pair = {{kOpFormat_C1HWNCoC0, kOpFormat_NCHW},
-                                                                  {kOpFormat_NCHW, kOpFormat_C1HWNCoC0},
-                                                                  {kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT},
-                                                                  {kOpFormat_DEFAULT, kOpFormat_FRACTAL_ZN_LSTM},
-                                                                  {kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}};
+const std::set<std::pair<string, string>> invalid_formats_pair = {
+  {kOpFormat_C1HWNCoC0, kOpFormat_NCHW},          {kOpFormat_NCHW, kOpFormat_C1HWNCoC0},
+  {kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT},       {kOpFormat_DEFAULT, kOpFormat_FRACTAL_ZN_LSTM},
+  {kOpFormat_FRACTAL_ZN_LSTM, kOpFormat_DEFAULT}, {kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}};

 bool TransDataSplit::Run(const FuncGraphPtr &func_graph) {
   MS_EXCEPTION_IF_NULL(func_graph);

@@ -83,6 +82,9 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n
   new_transpose_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_,
                                       false, prim::kPrimTranspose->name());
   AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{2, 3, 1, 0}), new_transpose_node);
+  if (output_format == kOpFormat_FRACTAL_ZN_LSTM) {
+    AnfAlgo::SetNodeAttr("nop_op", MakeValue(true), new_transpose_node);
+  }
   RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transpose_node);

   // trans hwcn to output_format

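A standalone sketch of the split decision (format strings assumed: kOpFormat_DEFAULT expands to "DefaultFormat" and kOpFormat_FRACTAL_ZN_LSTM to "FRACTAL_ZN_LSTM"): a (src, dst) pair in the invalid set cannot be handled by a single TransData, so it is routed through Transpose + TransData, and for the LSTM format the transpose is tagged "nop_op":

#include <iostream>
#include <set>
#include <string>
#include <utility>

int main() {
  const std::set<std::pair<std::string, std::string>> invalid_formats_pair = {
      {"DefaultFormat", "FRACTAL_ZN_LSTM"}, {"FRACTAL_ZN_LSTM", "DefaultFormat"}};  // the LSTM pairs
  auto need_split = [&](const std::string &src, const std::string &dst) {
    return invalid_formats_pair.count({src, dst}) != 0;
  };
  std::cout << std::boolalpha << need_split("DefaultFormat", "FRACTAL_ZN_LSTM") << "\n";  // true
}
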
@@ -404,7 +404,11 @@ bool IsNopNode(const AnfNodePtr &node) {
   }
   CNodePtr cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
-  if (nop_nodes.find(AnfAlgo::GetCNodeName(cnode)) == nop_nodes.end()) {
+  bool is_nop_node = false;
+  if (AnfAlgo::HasNodeAttr("nop_op", cnode)) {
+    is_nop_node = AnfAlgo::GetNodeAttr<bool>(cnode, "nop_op");
+  }
+  if (nop_nodes.find(AnfAlgo::GetCNodeName(cnode)) == nop_nodes.end() && !is_nop_node) {
     return false;
   }
   return true;

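Sketch of the widened nop test (standalone toy types, not the real AnfAlgo): a node now counts as a nop either by name or because a pass tagged it with the "nop_op" attribute, as the LSTM transpose above is:

#include <iostream>
#include <map>
#include <set>
#include <string>

bool IsNop(const std::string &name, const std::map<std::string, bool> &attrs,
           const std::set<std::string> &nop_names) {
  bool is_nop_node = false;
  auto it = attrs.find("nop_op");
  if (it != attrs.end()) is_nop_node = it->second;  // attribute overrides the name check
  return nop_names.count(name) != 0 || is_nop_node;
}

int main() {
  std::cout << std::boolalpha << IsNop("Transpose", {{"nop_op", true}}, {"Reshape"}) << "\n";  // true
}
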
@@ -52,8 +52,12 @@ const int kUnSupportMixedDataTypeIndex = -1;
 bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) {
   MS_EXCEPTION_IF_NULL(cnode);
   // Check input data type
+  auto name = AnfAlgo::GetCNodeName(cnode);
   for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) {
     TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
+    if (name == kDynamicRNNOpName && input_origin_type == kMetaTypeNone) {
+      continue;
+    }
     if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) {
       return false;
     }

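Toy version of the check (standalone enum, not MindSpore's TypeId): DynamicRNN's None placeholder input infers as kMetaTypeNone and must be exempted from the device-type comparison, or every kernel candidate would be rejected:

#include <iostream>
#include <vector>

enum TypeId { kMetaTypeNone, kFloat16, kFloat32 };

bool MatchInputTypes(bool is_dynamic_rnn, const std::vector<TypeId> &infer, const std::vector<TypeId> &device) {
  for (size_t i = 0; i < infer.size(); ++i) {
    if (is_dynamic_rnn && infer[i] == kMetaTypeNone) continue;  // skip the placeholder input
    if (infer[i] != device[i]) return false;
  }
  return true;
}

int main() {
  std::cout << std::boolalpha << MatchInputTypes(true, {kFloat16, kMetaTypeNone}, {kFloat16, kFloat32}) << "\n";
}
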
@@ -478,6 +482,9 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
     AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
     continue;
   }
+  if (selected_kernel_info.GetInputFormat(input_index) == kOpFormat_FRACTAL_ZN_LSTM) {
+    continue;
+  }
   // we set special device info of a input tensor.
   bool is_ref = false;
   auto op_info = kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE);

@@ -127,8 +127,12 @@ bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_i
   auto kernel_mod = AnfAlgo::GetKernelMod(anf_node_ptr);
   MS_EXCEPTION_IF_NULL(kernel_mod);
   kernel_mod->set_kernel_name(anf_node_ptr->fullname_with_scope());
+  auto op_name = AnfAlgo::GetCNodeName(anf_node_ptr);
   if (AnfAlgo::GetCNodeName(anf_node_ptr) != kAtomicAddrCleanOpName) {
     for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node_ptr); ++i) {
+      if (op_name == kDynamicRNNOpName && i == 3) {
+        continue;
+      }
       auto real_input_index = AnfAlgo::GetRealInputIndex(anf_node_ptr, i);
       auto device_address = AnfAlgo::GetPrevNodeOutputAddr(anf_node_ptr, real_input_index);
       AddressPtr input = std::make_shared<Address>();

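The same placeholder surfaces at launch time: input 3 of DynamicRNN owns no device memory, so it is skipped when kernel launch addresses are gathered. A standalone sketch with illustrative input names:

#include <iostream>
#include <string>
#include <vector>

int main() {
  std::string op_name = "DynamicRNN";
  std::vector<std::string> inputs = {"x", "w", "b", "<none>", "init_h", "init_c"};
  std::vector<std::string> kernel_inputs;
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (op_name == "DynamicRNN" && i == 3) continue;  // the placeholder has no device address
    kernel_inputs.push_back(inputs[i]);
  }
  std::cout << kernel_inputs.size() << " input addresses\n";  // 5
}
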
@@ -219,6 +219,8 @@ constexpr auto kFusedScaleApplyMomentum = "FusedScaleApplyMomentum";
 constexpr auto kBasicLSTMCellWeightGradOpName = "BasicLSTMCellWeightGrad";
 constexpr auto kBasicLSTMCellInputGradOpName = "BasicLSTMCellInputGrad";
 constexpr auto kBasicLSTMCellOpName = "BasicLSTMCell";
+constexpr auto kDynamicRNNOpName = "DynamicRNN";
+constexpr auto kLSTMInputGradOpName = "LSTMInputGrad";

 // attr key name
 constexpr auto kAttrInputNames = "input_names";

@@ -105,6 +105,8 @@ inline const PrimitivePtr kPrimArgMaxWithValue = std::make_shared<Primitive>("Ar
 inline const PrimitivePtr kPrimUnique = std::make_shared<Primitive>("Unique");
 inline const PrimitivePtr kPrimUniqueGrad = std::make_shared<Primitive>("UniqueGrad");
 inline const PrimitivePtr kPrimExtractImagePatches = std::make_shared<Primitive>("ExtractImagePatches");
+inline const PrimitivePtr kPrimDynamicRNN = std::make_shared<Primitive>("DynamicRNN");
+inline const PrimitivePtr kPrimDynamicRNNGrad = std::make_shared<Primitive>("DynamicRNNGrad");

 // NN
 inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");

@@ -214,6 +216,7 @@ inline const PrimitivePtr kPrimRound = std::make_shared<Primitive>("Round");
 inline const PrimitivePtr kPrimExp = std::make_shared<Primitive>("Exp");
 inline const PrimitivePtr kPrimLog = std::make_shared<Primitive>("Log");
 inline const PrimitivePtr kPrimRsqrt = std::make_shared<Primitive>("Rsqrt");
+inline const PrimitivePtr kPrimSplitV = std::make_shared<Primitive>("SplitV");

 // Statements
 inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return");