forked from mindspore-Ecosystem/mindspore

!25744 IR fusion adapts dump flag

Merge pull request !25744 from yuchaojie/ir_fusion

This commit is contained in: commit 10b63dffc0
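The files below all repeat the same two mechanical changes, so the pattern is worth stating once. First, passes stop creating nodes directly on the graph and instead go through the pass's own NewCNode helper; judging by the commit title, that helper is what attaches the dump flag to the nodes a pass inserts (its body is not part of this diff, so the exact propagation is an assumption here, not something the diff shows):

    // Before: the new node carries no record of the pass that created it.
    CNodePtr new_addn = func_graph->NewCNode(new_addn_inputs);

    // After: construction goes through PatternProcessPass::NewCNode, which is
    // assumed to copy pass-level metadata such as the dump flag onto the node.
    CNodePtr new_addn = NewCNode(new_addn_inputs, func_graph);

Second, the free helper functions that used to sit in anonymous namespaces become const member functions of their passes, since only a member (or a function that is handed the pass explicitly) can reach that helper.
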
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -20,16 +20,15 @@

 namespace mindspore {
 namespace opt {
-namespace {
-AnfNodePtr CreateNewAddn(const FuncGraphPtr &func_graph, const CNodePtr &origin_addn_cnode, size_t begin_index,
-                         size_t offset) {
+AnfNodePtr AddnFission::CreateNewAddn(const FuncGraphPtr &func_graph, const CNodePtr &origin_addn_cnode,
+                                      size_t begin_index, size_t offset) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(origin_addn_cnode);
   std::vector<AnfNodePtr> new_addn_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimAddN->name()))};
   for (size_t i = begin_index; i < begin_index + offset; ++i) {
     new_addn_inputs.emplace_back(origin_addn_cnode->input(i));
   }
-  CNodePtr new_addn = func_graph->NewCNode(new_addn_inputs);
+  CNodePtr new_addn = NewCNode(new_addn_inputs, func_graph);
   MS_EXCEPTION_IF_NULL(new_addn);
   new_addn->set_scope(origin_addn_cnode->scope());
   new_addn->set_abstract(origin_addn_cnode->abstract());
@@ -38,7 +37,6 @@ AnfNodePtr CreateNewAddn(const FuncGraphPtr &func_graph, const CNodePtr &origin_
   AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_addn);
   return new_addn;
 }
-} // namespace

 const BaseRef AddnFission::DefinePattern() const {
   VarPtr Xs = std::make_shared<SeqVar>();
@@ -68,7 +66,7 @@ const AnfNodePtr AddnFission::Process(const FuncGraphPtr &func_graph, const AnfN
   for (size_t i = cur_input_index; i <= origin_input_size; i++) {
     base_addn_inputs.emplace_back(new_cnode->input(i));
   }
-  CNodePtr base_addn = func_graph->NewCNode(base_addn_inputs);
+  CNodePtr base_addn = NewCNode(base_addn_inputs, func_graph);
   MS_EXCEPTION_IF_NULL(base_addn);
   base_addn->set_scope(new_cnode->scope());
   base_addn->set_abstract(new_cnode->abstract());

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -30,6 +30,8 @@ class AddnFission : public PatternProcessPass {
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

  private:
+  AnfNodePtr CreateNewAddn(const FuncGraphPtr &func_graph, const CNodePtr &origin_addn_cnode, size_t begin_index,
+                           size_t offset) const;
   size_t inputs_divisor_;
 };
 } // namespace opt

@@ -58,8 +58,9 @@ bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, s
   }
   return output_num == kBatchNormRealOutputNum;
 }
+} // namespace

-AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) {
+AnfNodePtr BatchNormBertFission::CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(bn);
   auto bn_cnode = bn->cast<CNodePtr>();
@@ -70,7 +71,7 @@ AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodeP
   }
   std::vector<AnfNodePtr> bn_training_reduce_inputs = {
     NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)), bn_cnode->input(kIndex1)};
-  auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs);
+  auto bn_training_reduce = NewCNode(bn_training_reduce_inputs, func_graph);
   MS_EXCEPTION_IF_NULL(bn_training_reduce);
   auto bn_input1 = bn_cnode->input(kIndex2);
   MS_EXCEPTION_IF_NULL(bn_input1);
@@ -84,8 +85,9 @@ AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodeP
   return bn_training_reduce;
 }

-AnfNodePtr CreateBNTrainingUpdateV2(const FuncGraphPtr &func_graph, const AnfNodePtr &bn,
-                                    const std::vector<AnfNodePtr> &bn_training_reduce_outputs) {
+AnfNodePtr BatchNormBertFission::CreateBNTrainingUpdateV2(
+  const FuncGraphPtr &func_graph, const AnfNodePtr &bn,
+  const std::vector<AnfNodePtr> &bn_training_reduce_outputs) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(bn);
   auto bn_cnode = bn->cast<CNodePtr>();
@@ -106,7 +108,7 @@ AnfNodePtr CreateBNTrainingUpdateV2(const FuncGraphPtr &func_graph, const AnfNod
     bn_training_reduce_outputs[kIndex1],
     bn_cnode->input(kIndex2),
     bn_cnode->input(kIndex3)};
-  auto bn_training_update_v2 = func_graph->NewCNode(bn_training_update_v2_inputs);
+  auto bn_training_update_v2 = NewCNode(bn_training_update_v2_inputs, func_graph);
   MS_EXCEPTION_IF_NULL(bn_training_update_v2);

   auto bn_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn->abstract());
@@ -124,7 +126,6 @@ AnfNodePtr CreateBNTrainingUpdateV2(const FuncGraphPtr &func_graph, const AnfNod
   AnfAlgo::CopyNodeAttrs(bn, bn_training_update_v2);
   return bn_training_update_v2;
 }
-} // namespace

 const BaseRef BatchNormBertFission::DefinePattern() const {
   VarPtr Xs = std::make_shared<SeqVar>();

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BATCH_NORM_BERT_FISSION_H_
 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BATCH_NORM_BERT_FISSION_H_

+#include <vector>
 #include "backend/optimizer/common/optimizer.h"

 namespace mindspore {
@@ -26,6 +27,11 @@ class BatchNormBertFission : public PatternProcessPass {
   ~BatchNormBertFission() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) const;
+  AnfNodePtr CreateBNTrainingUpdateV2(const FuncGraphPtr &func_graph, const AnfNodePtr &bn,
+                                      const std::vector<AnfNodePtr> &bn_training_reduce_outputs) const;
 };
 } // namespace opt
 } // namespace mindspore

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -79,7 +79,7 @@ AnfNodePtr BatchNormGradInferFission::CreateBNInferGrad(const FuncGraphPtr &func
   std::vector<AnfNodePtr> bn_infer_grad_inputs = {
     NewValueNode(std::make_shared<Primitive>(kBNInferGradOpName)), utils::cast<AnfNodePtr>(iter_input0->second),
     utils::cast<AnfNodePtr>(iter_input2->second), utils::cast<AnfNodePtr>(iter_input4->second)};
-  auto bn_infer_grad = func_graph->NewCNode(bn_infer_grad_inputs);
+  auto bn_infer_grad = NewCNode(bn_infer_grad_inputs, func_graph);
   MS_EXCEPTION_IF_NULL(bn_infer_grad);
   // Set abstract, the output of new node is taking the place of the 0th output of bn_grad.
   auto bn_grad_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn_grad->abstract());
@@ -125,7 +125,7 @@ AnfNodePtr BatchNormGradInferFission::CreateBNTrainingUpdateGrad(const FuncGraph
     NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateGradOpName)),
     utils::cast<AnfNodePtr>(iter_input0->second), utils::cast<AnfNodePtr>(iter_input1->second),
     utils::cast<AnfNodePtr>(iter_input3->second), utils::cast<AnfNodePtr>(iter_input4->second)};
-  auto bn_training_update_grad = func_graph->NewCNode(bn_training_update_grad_inputs);
+  auto bn_training_update_grad = NewCNode(bn_training_update_grad_inputs, func_graph);
   MS_EXCEPTION_IF_NULL(bn_training_update_grad);
   // Set abstract, the outputs of new node are taking the place of the 1st and 2nd outputs of bn_grad.
   auto bn_grad_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn_grad->abstract());

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -27,9 +27,8 @@

 namespace mindspore {
 namespace opt {
-namespace {
-void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
-                               std::vector<AnfNodePtr> *bn_update_grad_outputs) {
+void BatchNormGradSplit::CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
+                                                   std::vector<AnfNodePtr> *bn_update_grad_outputs) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(bn_grad_node);
   const auto &bn_grad_inputs = bn_grad_node->inputs();
@@ -37,7 +36,7 @@ void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
   std::vector<AnfNodePtr> bn_update_grad_inputs = {
     NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateGradOpName)), bn_grad_inputs[kIndex1],
     bn_grad_inputs[kIndex2], bn_grad_inputs[kIndex4], bn_grad_inputs[kIndex5]};
-  auto bn_update_grad = graph->NewCNode(bn_update_grad_inputs);
+  auto bn_update_grad = NewCNode(bn_update_grad_inputs, graph);
   MS_EXCEPTION_IF_NULL(bn_update_grad);
   bn_update_grad->set_kernel_info(std::make_shared<device::KernelInfo>());
   bn_update_grad->set_scope(bn_grad_node->scope());
@@ -50,9 +49,9 @@ void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
   CreateMultipleOutputsOfAnfNode(graph, bn_update_grad, kBNTrainingUpdateGradOutputNum, bn_update_grad_outputs);
 }

-void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
-                               const std::vector<AnfNodePtr> &bn_update_grad_outputs,
-                               std::vector<AnfNodePtr> *bn_reduce_grad_outputs) {
+void BatchNormGradSplit::CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
+                                                   const std::vector<AnfNodePtr> &bn_update_grad_outputs,
+                                                   std::vector<AnfNodePtr> *bn_reduce_grad_outputs) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(bn_grad_node);
   MS_EXCEPTION_IF_NULL(bn_reduce_grad_outputs);
@@ -71,7 +70,7 @@ void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
     bn_grad_inputs[kIndex3],
     bn_grad_inputs[kIndex4],
     bn_grad_inputs[kIndex5]};
-  auto bn_reduce_grad = graph->NewCNode(bn_reduce_grad_inputs);
+  auto bn_reduce_grad = NewCNode(bn_reduce_grad_inputs, graph);
   MS_EXCEPTION_IF_NULL(bn_reduce_grad);
   bn_reduce_grad->set_kernel_info(std::make_shared<device::KernelInfo>());
   bn_reduce_grad->set_scope(bn_grad_node->scope());
@@ -83,7 +82,7 @@ void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
   AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_reduce_grad);
   (*bn_reduce_grad_outputs).push_back(bn_reduce_grad);
 }
-} // namespace

 const BaseRef BatchNormGradSplit::DefinePattern() const {
   VarPtr Xs = std::make_shared<SeqVar>();
   auto prim = std::make_shared<Primitive>(kBatchNormGradOpName);

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BATCH_NORM_GRAD_SPLIT_H_
 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BATCH_NORM_GRAD_SPLIT_H_

+#include <vector>
 #include "backend/optimizer/common/optimizer.h"
 #include "backend/optimizer/common/helper.h"

@@ -27,6 +28,13 @@ class BatchNormGradSplit : public PatternProcessPass {
   ~BatchNormGradSplit() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
+                                 std::vector<AnfNodePtr> *bn_update_grad_outputs) const;
+  void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
+                                 const std::vector<AnfNodePtr> &bn_update_grad_outputs,
+                                 std::vector<AnfNodePtr> *bn_reduce_grad_outputs) const;
 };
 } // namespace opt
 } // namespace mindspore

@@ -26,8 +26,7 @@

 namespace mindspore {
 namespace opt {
-namespace {
-AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
+AnfNodePtr BCEWithLogitsLossFission::AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(node);
   auto cnode = node->cast<CNodePtr>();
@@ -36,7 +35,7 @@ AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node)
   std::vector<AnfNodePtr> new_simoid_inputs = {
     NewValueNode(std::make_shared<Primitive>(prim::kPrimBCEWithLogitsLoss->name()))};
   new_simoid_inputs.insert(new_simoid_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
-  CNodePtr new_cnode = func_graph->NewCNode(new_simoid_inputs);
+  CNodePtr new_cnode = NewCNode(new_simoid_inputs, func_graph);
   MS_EXCEPTION_IF_NULL(new_cnode);
   auto predict_input = cnode->inputs()[kIndex1];
   auto new_node_dtype = {AnfAlgo::GetOutputInferDataType(predict_input, 0)};
@@ -55,7 +54,7 @@ AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node)
     MS_LOG(INFO) << "Reduction attr is not mean or sum, can not do fission.";
     return nullptr;
   }
-  auto reduce_node = func_graph->NewCNode(reduce_inputs);
+  auto reduce_node = NewCNode(reduce_inputs, func_graph);
   MS_EXCEPTION_IF_NULL(reduce_node);
   auto type = AnfAlgo::GetOutputInferDataType(node, 0);
   if (type == kNumberTypeFloat16) {
@@ -69,7 +68,6 @@ AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node)
   reduce_node->set_scope(cnode->scope());
   return reduce_node;
 }
-} // namespace

 const BaseRef BCEWithLogitsLossFission::DefinePattern() const {
   VarPtr Xs = std::make_shared<SeqVar>();

@@ -28,6 +28,9 @@ class BCEWithLogitsLossFission : public PatternProcessPass {
   ~BCEWithLogitsLossFission() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const;
 };
 } // namespace opt
 } // namespace mindspore

@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -28,9 +28,8 @@

 namespace mindspore {
 namespace opt {
-namespace {
-void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
-                               std::vector<AnfNodePtr> *bn_update_grad_outputs) {
+void BnGradSplit::CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
+                                            std::vector<AnfNodePtr> *bn_update_grad_outputs) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(bn_grad_node);
   auto bn_grad_inputs = bn_grad_node->inputs();
@@ -38,7 +37,7 @@ void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
   std::vector<AnfNodePtr> bn_update_grad_inputs = {
     NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateGradOpName)), bn_grad_inputs[kIndex1],
     bn_grad_inputs[kIndex2], bn_grad_inputs[kIndex4], bn_grad_inputs[kIndex5]};
-  auto bn_update_grad = graph->NewCNode(bn_update_grad_inputs);
+  auto bn_update_grad = NewCNode(bn_update_grad_inputs, graph);
   MS_EXCEPTION_IF_NULL(bn_update_grad);
   bn_update_grad->set_kernel_info(std::make_shared<device::KernelInfo>());
   bn_update_grad->set_scope(bn_grad_node->scope());
@@ -51,9 +50,9 @@ void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
   CreateMultipleOutputsOfAnfNode(graph, bn_update_grad, kBNTrainingUpdateGradOutputNum, bn_update_grad_outputs);
 }

-void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
-                               const std::vector<AnfNodePtr> &bn_update_grad_outputs,
-                               std::vector<AnfNodePtr> *bn_reduce_grad_outputs) {
+void BnGradSplit::CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
+                                            const std::vector<AnfNodePtr> &bn_update_grad_outputs,
+                                            std::vector<AnfNodePtr> *bn_reduce_grad_outputs) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(bn_grad_node);
   auto bn_grad_inputs = bn_grad_node->inputs();
@@ -70,7 +69,7 @@ void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
     bn_grad_inputs[kIndex3],
     bn_grad_inputs[kIndex4],
     bn_grad_inputs[kIndex5]};
-  auto bn_reduce_grad = graph->NewCNode(bn_reduce_grad_inputs);
+  auto bn_reduce_grad = NewCNode(bn_reduce_grad_inputs, graph);
   MS_EXCEPTION_IF_NULL(bn_reduce_grad);
   bn_reduce_grad->set_kernel_info(std::make_shared<device::KernelInfo>());
   bn_reduce_grad->set_scope(bn_grad_node->scope());
@@ -83,7 +82,7 @@ void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
   (*bn_reduce_grad_outputs).push_back(bn_reduce_grad);
 }

-CNodePtr BNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
+CNodePtr BnGradSplit::BNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   std::vector<AnfNodePtr> bn_update_grad_outputs;
   CreateOutputsOfUpdateGrad(func_graph, cnode, &bn_update_grad_outputs);
@@ -106,7 +105,7 @@ CNodePtr BNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &cnode
   return make_tuple;
 }

-CNodePtr SyncBNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
+CNodePtr SyncBnGradSplit::SyncBNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(cnode);
   std::vector<AnfNodePtr> bn_update_grad_outputs;
@@ -119,7 +118,7 @@ CNodePtr SyncBNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &c

   std::vector<AnfNodePtr> allreduce_mul_outputs;
   for (size_t i = 0; i < bn_update_grad_outputs.size(); ++i) {
-    auto allreduce_mul_output = CreateAllReduceAndMul(func_graph, bn_update_grad_outputs[i], cnode);
+    auto allreduce_mul_output = CreateAllReduceAndMul(func_graph, bn_update_grad_outputs[i], cnode, *this);
     allreduce_mul_outputs.emplace_back(allreduce_mul_output);
   }

@@ -136,7 +135,6 @@ CNodePtr SyncBNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &c
   MS_EXCEPTION_IF_NULL(make_tuple);
   return make_tuple;
 }
-} // namespace

 const BaseRef BnGradSplit::DefinePattern() const {
   VarPtr Xs = std::make_shared<SeqVar>();

@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,6 +16,8 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BN_GRAD_SPLIT_H_
 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BN_GRAD_SPLIT_H_

+#include <string>
+#include <vector>
 #include "backend/optimizer/common/optimizer.h"
 #include "backend/optimizer/common/helper.h"

@@ -23,18 +25,31 @@ namespace mindspore {
 namespace opt {
 class BnGradSplit : public PatternProcessPass {
  public:
-  explicit BnGradSplit(bool multigraph = true) : PatternProcessPass("bn_grad_split", multigraph) {}
+  explicit BnGradSplit(string name = "bn_grad_split", bool multigraph = true) : PatternProcessPass(name, multigraph) {}
   ~BnGradSplit() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ protected:
+  void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
+                                 std::vector<AnfNodePtr> *bn_update_grad_outputs) const;
+  void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
+                                 const std::vector<AnfNodePtr> &bn_update_grad_outputs,
+                                 std::vector<AnfNodePtr> *bn_reduce_grad_outputs) const;
+
+ private:
+  CNodePtr BNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const;
 };

-class SyncBnGradSplit : public PatternProcessPass {
+class SyncBnGradSplit : public BnGradSplit {
  public:
-  explicit SyncBnGradSplit(bool multigraph = true) : PatternProcessPass("sync_bn_grad_split", multigraph) {}
+  explicit SyncBnGradSplit(bool multigraph = true) : BnGradSplit("sync_bn_grad_split", multigraph) {}
   ~SyncBnGradSplit() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  CNodePtr SyncBNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const;
 };
 } // namespace opt
 } // namespace mindspore

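The header above (BN_GRAD_SPLIT_H_, per its include guard) shows the other recurring refactor in this commit: SyncBnGradSplit now derives from BnGradSplit rather than from PatternProcessPass directly, and the base constructor takes the pass name as a parameter, so the derived pass reuses the protected CreateOutputsOfUpdateGrad/CreateOutputsOfReduceGrad helpers while still registering under its own name. A minimal sketch of the shape, with bodies elided:

    class BnGradSplit : public PatternProcessPass {
     public:
      // The name parameter lets a subclass forward its own pass name.
      explicit BnGradSplit(string name = "bn_grad_split", bool multigraph = true)
          : PatternProcessPass(name, multigraph) {}
     protected:
      // helpers shared with SyncBnGradSplit (declared in the diff above)
    };

    class SyncBnGradSplit : public BnGradSplit {
     public:
      explicit SyncBnGradSplit(bool multigraph = true) : BnGradSplit("sync_bn_grad_split", multigraph) {}
    };

The same base-plus-subclass arrangement is applied to BnSplit and SyncBnSplit in the BN split header further down.
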
@@ -34,9 +34,10 @@ constexpr auto kReduceOpSum = "sum";
 constexpr auto kDeviceNum = "device_num";
 constexpr size_t kPositionOffset = 3;
 constexpr int64_t kFusionNumThreshold = 2;
+} // namespace

-bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
-                                     std::vector<AnfNodePtr> *bn_training_reduce_outputs) {
+bool BnSplit::CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
+                                              std::vector<AnfNodePtr> *bn_training_reduce_outputs) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(bn_cnode);
   if (AnfAlgo::GetInputTensorNum(bn_cnode) != kBnInputTensorNum) {
@@ -46,7 +47,7 @@ bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &
   std::vector<AnfNodePtr> bn_training_reduce_inputs = {
     NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName))};
   bn_training_reduce_inputs.push_back(bn_cnode->input(kIndex1));
-  auto bn_training_reduce = graph->NewCNode(bn_training_reduce_inputs);
+  auto bn_training_reduce = NewCNode(bn_training_reduce_inputs, graph);
   MS_EXCEPTION_IF_NULL(bn_training_reduce);
   auto kernel_info = std::make_shared<device::KernelInfo>();
   MS_EXCEPTION_IF_NULL(kernel_info);
@@ -67,8 +68,8 @@ bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &
   return true;
 }

-AnfNodePtr CreateOutputsOfBNTrainingUpdate(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
-                                           const std::vector<AnfNodePtr> &bn_training_reduce_outputs) {
+AnfNodePtr BnSplit::CreateOutputsOfBNTrainingUpdate(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
+                                                    const std::vector<AnfNodePtr> &bn_training_reduce_outputs) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(bn_cnode);
   CheckCNodeInputSize(bn_cnode, kBnInputTensorNum);
@@ -86,7 +87,7 @@ AnfNodePtr CreateOutputsOfBNTrainingUpdate(const FuncGraphPtr &graph, const CNod
   bn_training_update_inputs.push_back(bn_cnode->input(kIndex3));
   bn_training_update_inputs.push_back(bn_cnode->input(kIndex4));
   bn_training_update_inputs.push_back(bn_cnode->input(kIndex5));
-  auto bn_training_update = graph->NewCNode(bn_training_update_inputs);
+  auto bn_training_update = NewCNode(bn_training_update_inputs, graph);
   MS_EXCEPTION_IF_NULL(bn_training_update);
   auto kernel_info = std::make_shared<device::KernelInfo>();
   MS_EXCEPTION_IF_NULL(kernel_info);
@@ -100,7 +101,7 @@ AnfNodePtr CreateOutputsOfBNTrainingUpdate(const FuncGraphPtr &graph, const CNod
   return bn_training_update;
 }

-AnfNodePtr SplitBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
+AnfNodePtr BnSplit::SplitBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(node);

@@ -125,7 +126,7 @@ AnfNodePtr SplitBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr
   return CreateOutputsOfBNTrainingUpdate(func_graph, cnode, bn_training_reduce_outputs);
 }

-AnfNodePtr SyncBNSplitForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
+AnfNodePtr SyncBnSplit::SyncBNSplitForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(node);

@@ -148,14 +149,13 @@ AnfNodePtr SyncBNSplitForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr &n

   std::vector<AnfNodePtr> allreduce_mul_outputs;
   for (size_t i = 0; i < bn_training_reduce_outputs.size(); ++i) {
-    auto allreduce_mul_output = CreateAllReduceAndMul(func_graph, bn_training_reduce_outputs[i], cnode);
+    auto allreduce_mul_output = CreateAllReduceAndMul(func_graph, bn_training_reduce_outputs[i], cnode, *this);
     allreduce_mul_outputs.emplace_back(allreduce_mul_output);
   }

   // Create BNTrainingUpdate node
   return CreateOutputsOfBNTrainingUpdate(func_graph, cnode, allreduce_mul_outputs);
 }
-} // namespace

 AnfNodePtr CreateValueNodeOfDeviceNumReciprocal(const FuncGraphPtr &graph, const CNodePtr &sync_bn_cnode) {
   MS_EXCEPTION_IF_NULL(graph);
@@ -201,7 +201,7 @@ AnfNodePtr InsertCast(const FuncGraphPtr &graph, const AnfNodePtr &input, const
 }

 AnfNodePtr CreateAllReduceAndMul(const FuncGraphPtr &graph, const AnfNodePtr &allreduce_input,
-                                 const CNodePtr &sync_bn_cnode) {
+                                 const CNodePtr &sync_bn_cnode, const PatternProcessPass &pass) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(allreduce_input);
   MS_EXCEPTION_IF_NULL(sync_bn_cnode);
@@ -214,7 +214,7 @@ AnfNodePtr CreateAllReduceAndMul(const FuncGraphPtr &graph, const AnfNodePtr &al

   // create AllReduce
   std::vector<AnfNodePtr> allreduce_inputs = {NewValueNode(std::make_shared<Primitive>(kAllReduceOpName)), input_node};
-  auto allreduce = graph->NewCNode(allreduce_inputs);
+  auto allreduce = pass.NewCNode(allreduce_inputs, graph);
   MS_EXCEPTION_IF_NULL(allreduce);
   allreduce->set_abstract(input_node->abstract());
   allreduce->set_scope(allreduce_input->scope());
@@ -238,7 +238,7 @@ AnfNodePtr CreateAllReduceAndMul(const FuncGraphPtr &graph, const AnfNodePtr &al
   auto device_num_reciprocal_vnode = CreateValueNodeOfDeviceNumReciprocal(graph, sync_bn_cnode);
   std::vector<AnfNodePtr> mul_inputs = {NewValueNode(std::make_shared<Primitive>(kMulOpName)), allreduce,
                                         device_num_reciprocal_vnode};
-  auto mul = graph->NewCNode(mul_inputs);
+  auto mul = pass.NewCNode(mul_inputs, graph);
   MS_EXCEPTION_IF_NULL(mul);
   mul->set_abstract(input_node->abstract());
   mul->set_scope(allreduce_input->scope());

@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,6 +16,8 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BN_SPLIT_H_
 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BN_SPLIT_H_

+#include <string>
+#include <vector>
 #include "backend/optimizer/common/optimizer.h"
 #include "backend/optimizer/common/helper.h"

@@ -23,24 +25,36 @@ namespace mindspore {
 namespace opt {
 class BnSplit : public PatternProcessPass {
  public:
-  explicit BnSplit(bool multigraph = true) : PatternProcessPass("bn_split", multigraph) {}
+  explicit BnSplit(string name = "bn_split", bool multigraph = true) : PatternProcessPass(name, multigraph) {}
   ~BnSplit() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ protected:
+  bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
+                                       std::vector<AnfNodePtr> *bn_training_reduce_outputs) const;
+  AnfNodePtr CreateOutputsOfBNTrainingUpdate(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
+                                             const std::vector<AnfNodePtr> &bn_training_reduce_outputs) const;
+
+ private:
+  AnfNodePtr SplitBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const;
 };

-class SyncBnSplit : public PatternProcessPass {
+class SyncBnSplit : public BnSplit {
  public:
-  explicit SyncBnSplit(bool multigraph = true) : PatternProcessPass("sync_bn_split", multigraph) {}
+  explicit SyncBnSplit(bool multigraph = true) : BnSplit("sync_bn_split", multigraph) {}
   ~SyncBnSplit() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  AnfNodePtr SyncBNSplitForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const;
 };

 AnfNodePtr CreateValueNodeOfDeviceNumReciprocal(const FuncGraphPtr &graph, const CNodePtr &sync_bn_cnode);

 AnfNodePtr CreateAllReduceAndMul(const FuncGraphPtr &graph, const AnfNodePtr &allreduce_input,
-                                 const CNodePtr &sync_bn_cnode);
+                                 const CNodePtr &sync_bn_cnode, const PatternProcessPass &pass);
 } // namespace opt
 } // namespace mindspore
 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BN_SPLIT_H_

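CreateAllReduceAndMul stays a free function because it is declared in this header and shared across the split passes, so it cannot call a member NewCNode on its own. Instead the calling pass now hands itself in as the final argument, and the helper builds its nodes through that pass:

    // Call site inside a const member of the pass (see SyncBNSplitForTBE above):
    auto out = CreateAllReduceAndMul(func_graph, bn_training_reduce_outputs[i], cnode, *this);

    // Inside the helper, node creation switches from
    //   auto allreduce = graph->NewCNode(allreduce_inputs);
    // to
    //   auto allreduce = pass.NewCNode(allreduce_inputs, graph);
    // so helper-created AllReduce/Mul nodes get the same dump-flag handling as
    // nodes built directly inside a pass. The Cdist fission file below applies
    // the identical change to its AddBroadCastToNode helper.
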
@@ -55,13 +55,13 @@ std::vector<size_t> CalCdistBroadCastShape(std::vector<size_t> x_shape, std::vec
 }

 AnfNodePtr AddBroadCastToNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, int64_t dim,
-                              const std::vector<size_t> &need_shape) {
+                              const std::vector<size_t> &need_shape, const PatternProcessPass &pass) {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(input_node);
   // Add ExpandDims Node
   std::vector<AnfNodePtr> expand_dims_inputs = {
     NewValueNode(std::make_shared<Primitive>(prim::kPrimExpandDims->name())), input_node};
-  auto expand_dims = func_graph->NewCNode(expand_dims_inputs);
+  auto expand_dims = pass.NewCNode(expand_dims_inputs, func_graph);
   auto dtype = AnfAlgo::GetOutputInferDataType(input_node, 0);
   auto expand_shape = AnfAlgo::GetOutputInferShape(input_node, 0);
   (void)expand_shape.insert(expand_shape.end() + dim, 1);
@@ -71,7 +71,7 @@ AnfNodePtr AddBroadCastToNode(const FuncGraphPtr &func_graph, const AnfNodePtr &
   // Add BroadCastTo Node
   std::vector<AnfNodePtr> broadcast_to_inputs = {
     NewValueNode(std::make_shared<Primitive>(prim::kPrimBroadcastTo->name())), expand_dims};
-  auto broadcast_to = func_graph->NewCNode(broadcast_to_inputs);
+  auto broadcast_to = pass.NewCNode(broadcast_to_inputs, func_graph);
   AnfAlgo::SetOutputInferTypeAndShape({dtype}, {need_shape}, broadcast_to.get());
   std::vector<int64_t> shape;
   (void)std::transform(need_shape.begin(), need_shape.end(), std::back_inserter(shape), LongToSize);
@@ -110,11 +110,11 @@ const AnfNodePtr CdistFission::Process(const FuncGraphPtr &graph, const AnfNodeP
   auto x_shape = AnfAlgo::GetOutputInferShape(cdist_inputs[kDim1], 0);
   auto y_shape = AnfAlgo::GetOutputInferShape(cdist_inputs[kDim2], 0);
   auto broadcast_to_shape = CalCdistBroadCastShape(x_shape, y_shape);
-  auto broadcast_input_x = AddBroadCastToNode(graph, cdist_inputs[kDim1], kInputXDimP, broadcast_to_shape);
-  auto broadcast_input_y = AddBroadCastToNode(graph, cdist_inputs[kDim2], kInputYDimR, broadcast_to_shape);
+  auto broadcast_input_x = AddBroadCastToNode(graph, cdist_inputs[kDim1], kInputXDimP, broadcast_to_shape, *this);
+  auto broadcast_input_y = AddBroadCastToNode(graph, cdist_inputs[kDim2], kInputYDimR, broadcast_to_shape, *this);
   std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimCdist->name())),
                                      broadcast_input_x, broadcast_input_y};
-  CNodePtr new_cnode = graph->NewCNode(new_inputs);
+  CNodePtr new_cnode = NewCNode(new_inputs, graph);
   MS_EXCEPTION_IF_NULL(new_cnode);
   new_cnode->set_abstract(cdist_cnode->abstract());
   new_cnode->set_scope(cdist_cnode->scope());
@@ -140,13 +140,13 @@ const AnfNodePtr CdistGradFission::Process(const FuncGraphPtr &graph, const AnfN
   auto x_shape = AnfAlgo::GetOutputInferShape(cdist_grad_inputs[kDim2], 0);
   auto y_shape = AnfAlgo::GetOutputInferShape(cdist_grad_inputs[kDim3], 0);
   auto broadcast_to_shape = CalCdistBroadCastShape(x_shape, y_shape);
-  auto broadcast_grad = AddBroadCastToNode(graph, cdist_grad_inputs[kDim1], 0, broadcast_to_shape);
-  auto broadcast_input_x = AddBroadCastToNode(graph, cdist_grad_inputs[kDim2], kInputXDimP, broadcast_to_shape);
-  auto broadcast_input_y = AddBroadCastToNode(graph, cdist_grad_inputs[kDim3], kInputYDimR, broadcast_to_shape);
-  auto broadcast_out = AddBroadCastToNode(graph, cdist_grad_inputs[kDim4], 0, broadcast_to_shape);
+  auto broadcast_grad = AddBroadCastToNode(graph, cdist_grad_inputs[kDim1], 0, broadcast_to_shape, *this);
+  auto broadcast_input_x = AddBroadCastToNode(graph, cdist_grad_inputs[kDim2], kInputXDimP, broadcast_to_shape, *this);
+  auto broadcast_input_y = AddBroadCastToNode(graph, cdist_grad_inputs[kDim3], kInputYDimR, broadcast_to_shape, *this);
+  auto broadcast_out = AddBroadCastToNode(graph, cdist_grad_inputs[kDim4], 0, broadcast_to_shape, *this);
   std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimCdistGrad->name())),
                                      broadcast_grad, broadcast_input_x, broadcast_input_y, broadcast_out};
-  CNodePtr new_cnode = graph->NewCNode(new_inputs);
+  CNodePtr new_cnode = NewCNode(new_inputs, graph);
   MS_EXCEPTION_IF_NULL(new_cnode);
   new_cnode->set_abstract(cdist_grad_cnode->abstract());
   new_cnode->set_scope(cdist_grad_cnode->scope());

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -21,16 +21,15 @@

 namespace mindspore {
 namespace opt {
-namespace {
-AnfNodePtr CreateNewConcat(const FuncGraphPtr &func_graph, const CNodePtr &origin_concat_cnode, size_t begin_index,
-                           size_t offset) {
+AnfNodePtr ConcatFission::CreateNewConcat(const FuncGraphPtr &func_graph, const CNodePtr &origin_concat_cnode,
+                                          size_t begin_index, size_t offset) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(origin_concat_cnode);
   std::vector<AnfNodePtr> new_concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
   for (size_t i = begin_index; i < begin_index + offset; ++i) {
     new_concat_inputs.emplace_back(origin_concat_cnode->input(i));
   }
-  CNodePtr new_concat = func_graph->NewCNode(new_concat_inputs);
+  CNodePtr new_concat = NewCNode(new_concat_inputs, func_graph);
   MS_EXCEPTION_IF_NULL(new_concat);
   new_concat->set_scope(origin_concat_cnode->scope());
   // Set attrs
@@ -66,7 +65,6 @@ AnfNodePtr CreateNewConcat(const FuncGraphPtr &func_graph, const CNodePtr &origi
                                       new_concat.get());
   return new_concat;
 }
-} // namespace

 const BaseRef ConcatFission::DefinePattern() const {
   VarPtr Xs = std::make_shared<SeqVar>();
@@ -100,7 +98,7 @@ const AnfNodePtr ConcatFission::Process(const FuncGraphPtr &func_graph, const An
   for (size_t i = cur_input_index; i <= origin_input_size; i++) {
     base_concat_inputs.emplace_back(new_cnode->input(i));
   }
-  CNodePtr base_concat = func_graph->NewCNode(base_concat_inputs);
+  CNodePtr base_concat = NewCNode(base_concat_inputs, func_graph);
   MS_EXCEPTION_IF_NULL(base_concat);
   base_concat->set_scope(new_cnode->scope());
   base_concat->set_abstract(new_cnode->abstract());

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -30,6 +30,8 @@ class ConcatFission : public PatternProcessPass {
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

  private:
+  AnfNodePtr CreateNewConcat(const FuncGraphPtr &func_graph, const CNodePtr &origin_concat_cnode, size_t begin_index,
+                             size_t offset) const;
   size_t inputs_divisor_;
 };
 } // namespace opt

@@ -96,7 +96,7 @@ const AnfNodePtr DiagFission::Process(const FuncGraphPtr &graph, const AnfNodePt
   auto assist_const = CreateAssistNode(graph, diag_cnode, input_shape);
   (void)new_inputs.insert(new_inputs.end(), diag_cnode->inputs().begin() + 1, diag_cnode->inputs().end());
   new_inputs.push_back(assist_const);
-  CNodePtr new_cnode = graph->NewCNode(new_inputs);
+  CNodePtr new_cnode = NewCNode(new_inputs, graph);
   MS_EXCEPTION_IF_NULL(new_cnode);
   new_cnode->set_abstract(diag_cnode->abstract());
   new_cnode->set_scope(diag_cnode->scope());

@@ -48,7 +48,7 @@ const AnfNodePtr DiagPartFission::Process(const FuncGraphPtr &func_graph, const
   (void)new_node_inputs.insert(new_node_inputs.end(), diag_part_cnode->inputs().begin() + 1,
                                diag_part_cnode->inputs().end());
   new_node_inputs.push_back(assist_node);
-  CNodePtr new_cnode = func_graph->NewCNode(new_node_inputs);
+  CNodePtr new_cnode = NewCNode(new_node_inputs, func_graph);
   MS_EXCEPTION_IF_NULL(new_cnode);
   new_cnode->set_abstract(diag_part_cnode->abstract());
   new_cnode->set_scope(diag_part_cnode->scope());

@@ -54,11 +54,14 @@ std::map<std::string, size_t> hidden_grad_input_index = {

 std::map<std::string, size_t> hidden_grad_output_index = {
   {"dh_prev", kIndex0}, {"dgate_h", kIndex1}, {"dnt_x", kIndex2}};
+} // namespace

-AnfNodePtr CreateGRUV2HiddenGradCellNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_gru_v2_grad_cnode,
-                                         const AnfNodePtr &last_gru_hidden_grad_node,
-                                         const AnfNodePtr &last_matmul_node, const std::string &gate_order,
-                                         const size_t cur_t) {
+AnfNodePtr DynamicGRUV2GradFission::CreateGRUV2HiddenGradCellNode(const FuncGraphPtr &func_graph,
+                                                                  const CNodePtr &dynamic_gru_v2_grad_cnode,
+                                                                  const AnfNodePtr &last_gru_hidden_grad_node,
+                                                                  const AnfNodePtr &last_matmul_node,
+                                                                  const std::string &gate_order,
+                                                                  const size_t cur_t) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(dynamic_gru_v2_grad_cnode);
   const auto &dynamic_gru_v2_grad_inputs = dynamic_gru_v2_grad_cnode->inputs();
@@ -95,7 +98,7 @@ AnfNodePtr CreateGRUV2HiddenGradCellNode(const FuncGraphPtr &func_graph, const C
   (void)gru_v2_hidden_grad_cell_inputs.emplace_back(dynamic_gru_v2_grad_inputs[input_index["reset"]]);
   (void)gru_v2_hidden_grad_cell_inputs.emplace_back(dynamic_gru_v2_grad_inputs[input_index["new"]]);
   (void)gru_v2_hidden_grad_cell_inputs.emplace_back(dynamic_gru_v2_grad_inputs[input_index["hidden_new"]]);
-  auto gru_v2_hidden_grad_cell_op = func_graph->NewCNode(gru_v2_hidden_grad_cell_inputs);
+  auto gru_v2_hidden_grad_cell_op = NewCNode(gru_v2_hidden_grad_cell_inputs, func_graph);

   std::vector<size_t> dh_prev_shape =
     AnfAlgo::GetOutputInferShape(dynamic_gru_grad_outputs[output_index["dh_prev"]], 0);
@@ -108,8 +111,8 @@ AnfNodePtr CreateGRUV2HiddenGradCellNode(const FuncGraphPtr &func_graph, const C
   return gru_v2_hidden_grad_cell_op;
 }

-void AddTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_gru_v2_grad_cnode,
-                  std::vector<std::vector<AnfNodePtr>> *result_nodes) {
+void DynamicGRUV2GradFission::AddTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_gru_v2_grad_cnode,
+                                           std::vector<std::vector<AnfNodePtr>> *result_nodes) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(dynamic_gru_v2_grad_cnode);
   MS_EXCEPTION_IF_NULL(result_nodes);
@@ -137,12 +140,12 @@ void AddTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_gru_v2
   auto weight_hidden = dynamic_gru_v2_grad_inputs[input_index["weight_hidden"]];
   std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                             weight_hidden};
-  auto reshape = func_graph->NewCNode(reshape_inputs);
+  auto reshape = NewCNode(reshape_inputs, func_graph);
   auto reshape_out_shape = {IntToSize(1), AnfAlgo::GetOutputInferShape(weight_hidden, 0)[0],
                             AnfAlgo::GetOutputInferShape(weight_hidden, 0)[1]};
   AnfAlgo::SetOutputInferTypeAndShape({dh_dtype}, {reshape_out_shape}, reshape.get());
   (void)matmul_inputs.emplace_back(reshape);
-  auto matmul_node = func_graph->NewCNode(matmul_inputs);
+  auto matmul_node = NewCNode(matmul_inputs, func_graph);
   MS_EXCEPTION_IF_NULL(matmul_node);
   std::vector<size_t> out_shape = {1, batch_size, hidden_size};
   AnfAlgo::SetOutputInferTypeAndShape({dh_dtype}, {out_shape}, matmul_node.get());
@@ -162,8 +165,9 @@ void AddTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_gru_v2
   (void)result_nodes->emplace_back(matmul_nodes);
 }

-AnfNodePtr AddTConcatNode(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &gru_hidden_grad_nodes,
-                          size_t concat_output_index) {
+AnfNodePtr DynamicGRUV2GradFission::AddTConcatNode(const FuncGraphPtr &func_graph,
+                                                   const std::vector<AnfNodePtr> &gru_hidden_grad_nodes,
+                                                   size_t concat_output_index) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
   for (size_t i = 0; i < t_size; i++) {
@@ -174,7 +178,7 @@ AnfNodePtr AddTConcatNode(const FuncGraphPtr &func_graph, const std::vector<AnfN
                                   &gru_hidden_grad_node_outputs);
     (void)concat_inputs.emplace_back(gru_hidden_grad_node_outputs[concat_output_index]);
   }
-  auto concat_t_node = func_graph->NewCNode(concat_inputs);
+  auto concat_t_node = NewCNode(concat_inputs, func_graph);
   auto out_dims = AnfAlgo::GetOutputInferShape(gru_hidden_grad_nodes[kIndex0], concat_output_index);
   std::vector<size_t> concat_output_shape = {t_size, out_dims[kDim1], out_dims[kDim2]};
   auto out_type = AnfAlgo::GetOutputInferDataType(gru_hidden_grad_nodes[kIndex0], concat_output_index);
@@ -185,8 +189,8 @@ AnfNodePtr AddTConcatNode(const FuncGraphPtr &func_graph, const std::vector<AnfN
   return concat_t_node;
 }

-std::vector<AnfNodePtr> AddGRUHiddenGradNode(const FuncGraphPtr &func_graph,
-                                             const CNodePtr &dynamic_gru_v2_grad_cnode) {
+std::vector<AnfNodePtr> DynamicGRUV2GradFission::AddGRUHiddenGradNode(const FuncGraphPtr &func_graph,
+                                                                      const CNodePtr &dynamic_gru_v2_grad_cnode) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(dynamic_gru_v2_grad_cnode);
   std::vector<AnfNodePtr> result;
@@ -213,13 +217,14 @@ std::vector<AnfNodePtr> AddGRUHiddenGradNode(const FuncGraphPtr &func_graph,
   return result;
 }

-AnfNodePtr AddHSplitNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_gru_v2_grad_cnode) {
+AnfNodePtr DynamicGRUV2GradFission::AddHSplitNode(const FuncGraphPtr &func_graph,
+                                                  const CNodePtr &dynamic_gru_v2_grad_cnode) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(dynamic_gru_v2_grad_cnode);
   auto input_h = dynamic_gru_v2_grad_cnode->input(input_index["h"]);
   std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
                                           input_h};
-  auto split_v = func_graph->NewCNode(splitv_input);
+  auto split_v = NewCNode(splitv_input, func_graph);
   // Set infer data type and shape
   auto dtypes = {AnfAlgo::GetOutputInferDataType(input_h, 0), AnfAlgo::GetOutputInferDataType(input_h, 0)};
   std::vector<size_t> output1_shape = {t_size - 1, batch_size, hidden_size};
@@ -235,7 +240,7 @@ AnfNodePtr AddHSplitNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic
   return split_v;
 }

-AnfNodePtr CreateHReshape(const FuncGraphPtr &graph, const AnfNodePtr &node) {
+AnfNodePtr DynamicGRUV2GradFission::CreateHReshape(const FuncGraphPtr &graph, const AnfNodePtr &node) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(node);
   auto ori_shape = AnfAlgo::GetOutputInferShape(node, 0);
@@ -248,14 +253,15 @@ AnfNodePtr CreateHReshape(const FuncGraphPtr &graph, const AnfNodePtr &node) {
   auto ori_dtype = {AnfAlgo::GetOutputInferDataType(node, 0)};
   // reshape
   std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())), node};
-  auto reshape = graph->NewCNode(reshape_input);
+  auto reshape = NewCNode(reshape_input, graph);
   AnfAlgo::SetOutputInferTypeAndShape(ori_dtype, shape_tmp, reshape.get());
   AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reshape);
   return reshape;
 }

-AnfNodePtr AddHConcatNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_gru_v2_grad_cnode,
-                          const AnfNodePtr &splitv) {
+AnfNodePtr DynamicGRUV2GradFission::AddHConcatNode(const FuncGraphPtr &func_graph,
+                                                   const CNodePtr &dynamic_gru_v2_grad_cnode,
+                                                   const AnfNodePtr &splitv) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(dynamic_gru_v2_grad_cnode);
   MS_EXCEPTION_IF_NULL(splitv);
@@ -270,7 +276,7 @@ AnfNodePtr AddHConcatNode(const FuncGraphPtr &func_graph, const CNodePtr &dynami
   auto init_h_reshape = CreateHReshape(func_graph, dynamic_gru_v2_grad_cnode->input(input_index["init_h"]));
   (void)concat_inputs.emplace_back(init_h_reshape);
   (void)concat_inputs.emplace_back(splitv_outputs[kIndex0]);
-  auto concat = func_graph->NewCNode(concat_inputs);
+  auto concat = NewCNode(concat_inputs, func_graph);
   // Set infer data type and shape
   std::vector<size_t> output_shape = {t_size, batch_size, hidden_size};
   AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(init_h_reshape, 0)}, {output_shape},
@@ -283,7 +289,8 @@ AnfNodePtr AddHConcatNode(const FuncGraphPtr &func_graph, const CNodePtr &dynami
   return concat;
 }

-AnfNodePtr AddDwhMatmulNode(const FuncGraphPtr &func_graph, const AnfNodePtr &dgate_h, const AnfNodePtr &node) {
+AnfNodePtr DynamicGRUV2GradFission::AddDwhMatmulNode(const FuncGraphPtr &func_graph, const AnfNodePtr &dgate_h,
+                                                     const AnfNodePtr &node) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(dgate_h);
   MS_EXCEPTION_IF_NULL(node);
@@ -297,7 +304,7 @@ AnfNodePtr AddDwhMatmulNode(const FuncGraphPtr &func_graph, const AnfNodePtr &dg
   } else {
     (void)matmul_inputs.emplace_back(dgate_h);
   }
-  auto batch_matmul = func_graph->NewCNode(matmul_inputs);
+  auto batch_matmul = NewCNode(matmul_inputs, func_graph);
   std::vector<size_t> shape = {t_size, hidden_size, kGateNum * hidden_size};
   AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {shape}, batch_matmul.get());
   AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(true), batch_matmul);
@@ -306,7 +313,8 @@ AnfNodePtr AddDwhMatmulNode(const FuncGraphPtr &func_graph, const AnfNodePtr &dg
   return batch_matmul;
 }

-AnfNodePtr CreateDgateHSplitVDNode(const FuncGraphPtr &func_graph, const AnfNodePtr &dgate_h) {
+AnfNodePtr DynamicGRUV2GradFission::CreateDgateHSplitVDNode(const FuncGraphPtr &func_graph,
+                                                            const AnfNodePtr &dgate_h) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(dgate_h);
   std::vector<AnfNodePtr> splitvd_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
@@ -317,7 +325,7 @@ AnfNodePtr CreateDgateHSplitVDNode(const FuncGraphPtr &func_graph, const AnfNode
   } else {
     (void)splitvd_input.emplace_back(dgate_h);
   }
-  auto split_vd = func_graph->NewCNode(splitvd_input);
+  auto split_vd = NewCNode(splitvd_input, func_graph);
   auto dtypes = {AnfAlgo::GetOutputInferDataType(dgate_h, 0), AnfAlgo::GetOutputInferDataType(dgate_h, 0)};
   std::vector<size_t> shape = {t_size, batch_size, hidden_size << 1};
   std::vector<size_t> shape2 = {t_size, batch_size, hidden_size};
@@ -331,7 +339,8 @@ AnfNodePtr CreateDgateHSplitVDNode(const FuncGraphPtr &func_graph, const AnfNode
   return split_vd;
 }

-AnfNodePtr CreateDgateXConcatDNode(const FuncGraphPtr &func_graph, const AnfNodePtr &split, const AnfNodePtr &dnt_x) {
+AnfNodePtr DynamicGRUV2GradFission::CreateDgateXConcatDNode(const FuncGraphPtr &func_graph, const AnfNodePtr &split,
+                                                            const AnfNodePtr &dnt_x) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(split);
   MS_EXCEPTION_IF_NULL(dnt_x);
@@ -346,7 +355,7 @@ AnfNodePtr CreateDgateXConcatDNode(const FuncGraphPtr &func_graph, const AnfNode
   } else {
     (void)concat_inputs.emplace_back(dnt_x);
   }
-  auto concat_op = func_graph->NewCNode(concat_inputs);
+  auto concat_op = NewCNode(concat_inputs, func_graph);
   std::vector<size_t> shape = {t_size, batch_size, kGateNum * hidden_size};
   auto types = {AnfAlgo::GetOutputInferDataType(dnt_x, 0)};
   AnfAlgo::SetOutputInferTypeAndShape(types, {shape}, concat_op.get());
@@ -357,14 +366,15 @@ AnfNodePtr CreateDgateXConcatDNode(const FuncGraphPtr &func_graph, const AnfNode
   return concat_op;
 }

-AnfNodePtr CreateDwxBatchMatMul(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
+AnfNodePtr DynamicGRUV2GradFission::CreateDwxBatchMatMul(const FuncGraphPtr &graph, const AnfNodePtr &node1,
+                                                         const AnfNodePtr &node2) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(node1);
   MS_EXCEPTION_IF_NULL(node2);
   // BatchMatMul
   std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name())),
                                            node1, node2};
-  auto batch_matmul = graph->NewCNode(matmul_inputs);
+  auto batch_matmul = NewCNode(matmul_inputs, graph);
   MS_EXCEPTION_IF_NULL(batch_matmul);
   std::vector<size_t> shape = {t_size, input_size, kGateNum * hidden_size};
   AnfAlgo::SetOutputInferTypeAndShape({dh_dtype}, {shape}, batch_matmul.get());
@@ -374,15 +384,15 @@ AnfNodePtr CreateDwxBatchMatMul(const FuncGraphPtr &graph, const AnfNodePtr &nod
   return batch_matmul;
 }

-AnfNodePtr CreateDxtBatchMatMul(const FuncGraphPtr &func_graph, const AnfNodePtr &dgate_concat,
-                                const AnfNodePtr &weight_input, const AnfNodePtr &dx) {
+AnfNodePtr DynamicGRUV2GradFission::CreateDxtBatchMatMul(const FuncGraphPtr &func_graph, const AnfNodePtr &dgate_concat,
+                                                         const AnfNodePtr &weight_input, const AnfNodePtr &dx) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(dgate_concat);
   MS_EXCEPTION_IF_NULL(weight_input);
   MS_EXCEPTION_IF_NULL(dx);
   std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name())),
                                            dgate_concat, weight_input};
-  auto batch_matmul = func_graph->NewCNode(matmul_inputs);
+  auto batch_matmul = NewCNode(matmul_inputs, func_graph);
   MS_EXCEPTION_IF_NULL(batch_matmul);
   AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dx, 0)}, {AnfAlgo::GetOutputInferShape(dx, 0)},
                                       batch_matmul.get());
@@ -392,12 +402,12 @@ AnfNodePtr CreateDxtBatchMatMul(const FuncGraphPtr &func_graph, const AnfNodePtr
   return batch_matmul;
 }

-AnfNodePtr CreateWBroadcastToDNode(const FuncGraphPtr &graph, const AnfNodePtr &node) {
+AnfNodePtr DynamicGRUV2GradFission::CreateWBroadcastToDNode(const FuncGraphPtr &graph, const AnfNodePtr &node) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(node);
   // BroadcastTo
   std::vector<AnfNodePtr> braodcast_to_input = {NewValueNode(std::make_shared<Primitive>(kBroadcastToOpName)), node};
-  auto broadcast_to_d = graph->NewCNode(braodcast_to_input);
+  auto broadcast_to_d = NewCNode(braodcast_to_input, graph);
   std::vector<size_t> shape = {t_size, input_size, kGateNum * hidden_size};
   auto type = {AnfAlgo::GetOutputInferDataType(node, 0)};
   AnfAlgo::SetOutputInferTypeAndShape(type, {shape}, broadcast_to_d.get());
@@ -407,14 +417,15 @@ AnfNodePtr CreateWBroadcastToDNode(const FuncGraphPtr &graph, const AnfNodePtr &
   return broadcast_to_d;
 }

-AnfNodePtr CreateDwReduceSumDNode(const FuncGraphPtr &graph, const AnfNodePtr &matmul, const AnfNodePtr &gru_grad) {
+AnfNodePtr DynamicGRUV2GradFission::CreateDwReduceSumDNode(const FuncGraphPtr &graph, const AnfNodePtr &matmul,
+                                                           const AnfNodePtr &gru_grad) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(matmul);
   MS_EXCEPTION_IF_NULL(gru_grad);
   // ReduceSumD for dw_x and dw_h
   std::vector<AnfNodePtr> reducesum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
                                               matmul};
-  auto reduce_sumd = graph->NewCNode(reducesum_inputs);
+  auto reduce_sumd = NewCNode(reducesum_inputs, graph);
   auto types = {AnfAlgo::GetOutputInferDataType(gru_grad, 0)};
   auto shapes = {AnfAlgo::GetOutputInferShape(gru_grad, 0)};
   AnfAlgo::SetOutputInferTypeAndShape(types, shapes, reduce_sumd.get());
@@ -424,14 +435,15 @@ AnfNodePtr CreateDwReduceSumDNode(const FuncGraphPtr &graph, const AnfNodePtr &m
   return reduce_sumd;
 }

-AnfNodePtr CreateDbReduceSumDNode(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &node2) {
+AnfNodePtr DynamicGRUV2GradFission::CreateDbReduceSumDNode(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                                           const AnfNodePtr &node2) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(node2);
   // ReduceSumD for db_x and db_h
   std::vector<AnfNodePtr> reducesum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
                                               node};
-  auto reduce_sumd = graph->NewCNode(reducesum_inputs);
+  auto reduce_sumd = NewCNode(reducesum_inputs, graph);
   MS_EXCEPTION_IF_NULL(reduce_sumd);
   std::vector<size_t> shape = {kGateNum * hidden_size};
   auto types = {AnfAlgo::GetOutputInferDataType(node2, 0)};
@@ -441,7 +453,6 @@ AnfNodePtr CreateDbReduceSumDNode(const FuncGraphPtr &graph, const AnfNodePtr &n
   AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sumd);
   return reduce_sumd;
 }
-} // namespace

 const BaseRef DynamicGRUV2GradFission::DefinePattern() const {
   VarPtr Xs = std::make_shared<SeqVar>();

@@ -17,6 +17,7 @@
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_GRU_V2_GRAD_FISSION_H_

#include <vector>
#include <string>
#include "backend/optimizer/common/optimizer.h"

namespace mindspore {

@@ -28,6 +29,33 @@ class DynamicGRUV2GradFission : public PatternProcessPass {
  ~DynamicGRUV2GradFission() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  AnfNodePtr CreateGRUV2HiddenGradCellNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_gru_v2_grad_cnode,
                                           const AnfNodePtr &last_gru_hidden_grad_node,
                                           const AnfNodePtr &last_matmul_node, const std::string &gate_order,
                                           const size_t cur_t) const;
  void AddTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_gru_v2_grad_cnode,
                    std::vector<std::vector<AnfNodePtr>> *result_nodes) const;
  AnfNodePtr AddTConcatNode(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &gru_hidden_grad_nodes,
                            size_t concat_output_index) const;
  std::vector<AnfNodePtr> AddGRUHiddenGradNode(const FuncGraphPtr &func_graph,
                                               const CNodePtr &dynamic_gru_v2_grad_cnode) const;
  AnfNodePtr AddHSplitNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_gru_v2_grad_cnode) const;
  AnfNodePtr CreateHReshape(const FuncGraphPtr &graph, const AnfNodePtr &node) const;
  AnfNodePtr AddHConcatNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_gru_v2_grad_cnode,
                            const AnfNodePtr &splitv) const;
  AnfNodePtr AddDwhMatmulNode(const FuncGraphPtr &func_graph, const AnfNodePtr &dgate_h, const AnfNodePtr &node) const;
  AnfNodePtr CreateDgateHSplitVDNode(const FuncGraphPtr &func_graph, const AnfNodePtr &dgate_h) const;
  AnfNodePtr CreateDgateXConcatDNode(const FuncGraphPtr &func_graph, const AnfNodePtr &split,
                                     const AnfNodePtr &dnt_x) const;
  AnfNodePtr CreateDwxBatchMatMul(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) const;
  AnfNodePtr CreateDxtBatchMatMul(const FuncGraphPtr &func_graph, const AnfNodePtr &dgate_concat,
                                  const AnfNodePtr &weight_input, const AnfNodePtr &dx) const;
  AnfNodePtr CreateWBroadcastToDNode(const FuncGraphPtr &graph, const AnfNodePtr &node) const;
  AnfNodePtr CreateDwReduceSumDNode(const FuncGraphPtr &graph, const AnfNodePtr &matmul,
                                    const AnfNodePtr &gru_grad) const;
  AnfNodePtr CreateDbReduceSumDNode(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &node2) const;
};
} // namespace opt
} // namespace mindspore
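Note: this header shows the shape of the refactor repeated across every file in this change: the creation helpers leave the anonymous namespace and become private const members of the pass, and every call site switches from func_graph->NewCNode(inputs) to the pass-level NewCNode(inputs, func_graph). Below is a minimal standalone sketch of why the member helper matters; Pass, Node and the dump flag are illustrative assumptions based on the PR title ("IR fusion adapts dump flag"), not MindSpore's real API.

#include <memory>
#include <string>

// Hypothetical stand-ins; the real types are FuncGraph/CNode/PatternProcessPass.
struct Node {
  std::string name;
  bool dump_flag = false;  // per-node flag the pass wants to stamp (assumption)
};

class Pass {
 public:
  explicit Pass(bool dump) : dump_flag_(dump) {}
  // A member helper can consult pass state when building a node; a free
  // function in an anonymous namespace could only call graph->NewCNode().
  std::shared_ptr<Node> NewCNode(const std::string &name) const {
    auto node = std::make_shared<Node>();
    node->name = name;
    node->dump_flag = dump_flag_;  // propagate the flag to every created node
    return node;
  }

 private:
  bool dump_flag_;
};

int main() {
  Pass pass(/*dump=*/true);
  return pass.NewCNode("AddN")->dump_flag ? 0 : 1;
}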
@@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -34,9 +34,10 @@ constexpr int64_t kAttrAxis2Value = 2;
constexpr int64_t kAttrNumSplitValue = 2;
constexpr int64_t kAttrSplitDimValue = 2;
constexpr size_t kDimMultiNum = 4;
} // namespace

void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                     std::vector<std::vector<AnfNodePtr>> *result_nodes) {
void DynamicRnnGradFissionV2::CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                                              std::vector<std::vector<AnfNodePtr>> *result_nodes) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
  MS_EXCEPTION_IF_NULL(result_nodes);

@@ -52,7 +53,7 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
    // Create basic_lstm_cell_c_state_grad
    std::vector<AnfNodePtr> basic_lstm_cell_c_state_grad_inputs = {
      NewValueNode(std::make_shared<Primitive>(kBasicLSTMCellCStateGradV2OpName))};
    auto basic_lstm_cell_c_state_grad = func_graph->NewCNode(basic_lstm_cell_c_state_grad_inputs);
    auto basic_lstm_cell_c_state_grad = NewCNode(basic_lstm_cell_c_state_grad_inputs, func_graph);

    std::vector<size_t> output0_dims{
      origin_input9_shape[kDim0],

@@ -66,7 +67,7 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
    // Create matmul
    auto origin_input1_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex2), 0);
    std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name()))};
    auto matmul = func_graph->NewCNode(matmul_inputs);
    auto matmul = NewCNode(matmul_inputs, func_graph);
    AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {{IntToSize(1), output0_dims[0], origin_input1_shape[0]}},
                                        matmul.get());
    AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(false), matmul);

@@ -74,7 +75,7 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn

    // Create split
    std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
    auto split_v = func_graph->NewCNode(splitv_input);
    auto split_v = NewCNode(splitv_input, func_graph);
    auto origin_output2_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, kIndex2);
    auto origin_output3_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, kIndex3);
    std::vector<size_t> split_v_output0_shape{IntToSize(1), origin_output2_shape[kDim1], origin_output2_shape[kDim2]};

@@ -99,13 +100,13 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
  result_nodes->emplace_back(split_nodes);
}

AnfNodePtr CreateLSTMSPlitV(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
                            const std::vector<std::vector<size_t>> &split_shapes,
                            const std::vector<TypeId> &split_types, const std::vector<int64_t> &size_split,
                            size_t num_split_x) {
AnfNodePtr DynamicRnnGradFissionV2::CreateLSTMSPlitV(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
                                                     const std::vector<std::vector<size_t>> &split_shapes,
                                                     const std::vector<TypeId> &split_types,
                                                     const std::vector<int64_t> &size_split, size_t num_split_x) const {
  std::vector<AnfNodePtr> lstm_split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
                                              input};
  auto lstm_split = func_graph->NewCNode(lstm_split_input);
  auto lstm_split = NewCNode(lstm_split_input, func_graph);
  AnfAlgo::SetOutputInferTypeAndShape(split_types, split_shapes, lstm_split.get());
  AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(size_split), lstm_split);
  AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(0)), lstm_split);

@@ -113,76 +114,27 @@ AnfNodePtr CreateLSTMSPlitV(const FuncGraphPtr &func_graph, const AnfNodePtr &in
  return lstm_split;
}

AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                                std::vector<AnfNodePtr> *outputs) {
  std::vector<std::vector<AnfNodePtr>> result_nodes;
  CreateTLoopNode(func_graph, dynamic_rnn_grad_cnode, &result_nodes);

  auto origin_input5_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex6), 0);
  std::vector<size_t> split_c_dims{IntToSize(1), origin_input5_shape[0], origin_input5_shape[1]};

  auto origin_input7 = dynamic_rnn_grad_cnode->input(kIndex8);
  size_t num_split_x = AnfAlgo::GetOutputInferShape(origin_input7, 0)[0];
  std::vector<std::vector<size_t>> split_shapes;
  std::vector<TypeId> split_types;
  std::vector<int64_t> size_split;
  for (size_t i = 0; i < num_split_x; ++i) {
    split_shapes.emplace_back(split_c_dims);
    split_types.emplace_back(kNumberTypeFloat32);
    size_split.emplace_back(1);
  }
  // Create lstm_split_c
  auto lstm_split_c = CreateLSTMSPlitV(func_graph, origin_input7, split_shapes, split_types, size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_c_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_c, num_split_x, &lstm_split_c_outputs);

  // Create lstm_split_dy
  auto lstm_split_dy = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex9), split_shapes, split_types,
                                        size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_dy_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_dy, num_split_x, &lstm_split_dy_outputs);

  // Create lstm_split_i
  auto lstm_split_i = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex12), split_shapes, split_types,
                                       size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_i_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_i, num_split_x, &lstm_split_i_outputs);

  // Create lstm_split_j
  auto lstm_split_j = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex13), split_shapes, split_types,
                                       size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_j_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_j, num_split_x, &lstm_split_j_outputs);

  // Create lstm_split_f
  auto lstm_split_f = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex14), split_shapes, split_types,
                                       size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_f_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_f, num_split_x, &lstm_split_f_outputs);

  // Create lstm_split_o
  auto lstm_split_o = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex15), split_shapes, split_types,
                                       size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_o_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_o, num_split_x, &lstm_split_o_outputs);

  // Create lstm_split_tanh
  auto lstm_split_tanh = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex16), split_shapes,
                                          split_types, size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_tanh_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_tanh, num_split_x, &lstm_split_tanh_outputs);

  // Add edges
void DynamicRnnGradFissionV2::CreateTLoopNodeWithEdge(const FuncGraphPtr &func_graph,
                                                      const CNodePtr &dynamic_rnn_grad_cnode,
                                                      const std::vector<std::vector<AnfNodePtr>> &result_nodes,
                                                      size_t num_split_x,
                                                      std::vector<std::vector<AnfNodePtr>> *loop_node_outputs) const {
  auto &basic_lstm_cell_c_state_grad_nodes = result_nodes[kIndex0];
  auto &matmul_nodes = result_nodes[kIndex1];
  auto &split_nodes = result_nodes[kIndex2];
  auto &lstm_split_c_outputs = result_nodes[kIndex3];
  auto &lstm_split_dy_outputs = result_nodes[kIndex4];
  auto &lstm_split_i_outputs = result_nodes[kIndex5];
  auto &lstm_split_j_outputs = result_nodes[kIndex6];
  auto &lstm_split_f_outputs = result_nodes[kIndex7];
  auto &lstm_split_o_outputs = result_nodes[kIndex8];
  auto &lstm_split_tanh_outputs = result_nodes[kIndex9];
  std::vector<AnfNodePtr> pre_basic_lstm_cell_c_state_grad_outputs;
  std::vector<AnfNodePtr> pre_split_outputs;
  auto basic_lstm_cell_c_state_grad_nodes = result_nodes[kIndex0];
  auto matmul_nodes = result_nodes[kIndex1];
  auto split_nodes = result_nodes[kIndex2];
  std::vector<AnfNodePtr> lstm_x_concat_input(num_split_x + 1);
  lstm_x_concat_input[0] = NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()));
  std::vector<AnfNodePtr> lstm_gage_concat_input(num_split_x + 1);
  lstm_gage_concat_input[0] = NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()));

  for (size_t i = 0; i < num_split_x; ++i) {
    size_t idx = num_split_x - i - 1;
    // Create basic_lstm_cell_c_state_grad

@@ -191,7 +143,7 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
    if (i == num_split_x - 1) {
      std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                                dynamic_rnn_grad_cnode->input(6)};
      auto reshape = func_graph->NewCNode(reshape_inputs);
      auto reshape = NewCNode(reshape_inputs, func_graph);
      auto reshape_out_shape = {IntToSize(1),
                                AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex6), 0)[0],
                                AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex6), 0)[1]};

@@ -213,7 +165,7 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
    (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_f_outputs[idx]);
    (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_o_outputs[idx]);
    (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_tanh_outputs[idx]);
    auto basic_lstm_cell_c_state_grad = func_graph->NewCNode(basic_lstm_cell_c_state_grad_inputs);
    auto basic_lstm_cell_c_state_grad = NewCNode(basic_lstm_cell_c_state_grad_inputs, func_graph);
    MS_EXCEPTION_IF_NULL(basic_lstm_cell_c_state_grad);
    basic_lstm_cell_c_state_grad->set_abstract(basic_lstm_cell_c_state_grad_nodes[i]->abstract());
    AnfAlgo::CopyNodeAttrs(basic_lstm_cell_c_state_grad_nodes[i], basic_lstm_cell_c_state_grad);

@@ -227,7 +179,7 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
    std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name()))};
    (void)matmul_inputs.emplace_back(basic_lstm_cell_c_state_grad_outputs[0]);
    (void)matmul_inputs.emplace_back(dynamic_rnn_grad_cnode->input(kIndex2));
    auto matmul = func_graph->NewCNode(matmul_inputs);
    auto matmul = NewCNode(matmul_inputs, func_graph);
    MS_EXCEPTION_IF_NULL(matmul);
    matmul->set_abstract(matmul_nodes[i]->abstract());
    AnfAlgo::CopyNodeAttrs(matmul_nodes[i], matmul);

@@ -235,7 +187,7 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
    // Create splitv
    std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
                                            matmul};
    auto split_v = func_graph->NewCNode(splitv_input);
    auto split_v = NewCNode(splitv_input, func_graph);
    MS_EXCEPTION_IF_NULL(split_v);
    split_v->set_abstract(split_nodes[i]->abstract());
    AnfAlgo::CopyNodeAttrs(split_nodes[i], split_v);

@@ -258,14 +210,94 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
    }
    std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                             basic_lstm_cell_c_state_grad_outputs[0]};
    auto reshape = func_graph->NewCNode(reshape_input);
    auto reshape = NewCNode(reshape_input, func_graph);
    AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(basic_lstm_cell_c_state_grad_outputs[0], 0)},
                                        {temp_shape}, reshape.get());
    lstm_gage_concat_input[idx + 1] = reshape;
  }
  loop_node_outputs->push_back(pre_basic_lstm_cell_c_state_grad_outputs);
  loop_node_outputs->push_back(pre_split_outputs);
  loop_node_outputs->push_back(lstm_x_concat_input);
  loop_node_outputs->push_back(lstm_gage_concat_input);
}

AnfNodePtr DynamicRnnGradFissionV2::AddLSTMInputGradNode(const FuncGraphPtr &func_graph,
                                                         const CNodePtr &dynamic_rnn_grad_cnode,
                                                         std::vector<AnfNodePtr> *outputs) const {
  std::vector<std::vector<AnfNodePtr>> result_nodes;
  CreateTLoopNode(func_graph, dynamic_rnn_grad_cnode, &result_nodes);

  auto origin_input5_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex6), 0);
  std::vector<size_t> split_c_dims{IntToSize(1), origin_input5_shape[0], origin_input5_shape[1]};

  auto origin_input7 = dynamic_rnn_grad_cnode->input(kIndex8);
  size_t num_split_x = AnfAlgo::GetOutputInferShape(origin_input7, 0)[0];
  std::vector<std::vector<size_t>> split_shapes;
  std::vector<TypeId> split_types;
  std::vector<int64_t> size_split;
  for (size_t i = 0; i < num_split_x; ++i) {
    split_shapes.emplace_back(split_c_dims);
    split_types.emplace_back(kNumberTypeFloat32);
    size_split.emplace_back(1);
  }
  // Create lstm_split_c
  auto lstm_split_c = CreateLSTMSPlitV(func_graph, origin_input7, split_shapes, split_types, size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_c_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_c, num_split_x, &lstm_split_c_outputs);
  result_nodes.push_back(lstm_split_c_outputs);

  // Create lstm_split_dy
  auto lstm_split_dy = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex9), split_shapes, split_types,
                                        size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_dy_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_dy, num_split_x, &lstm_split_dy_outputs);
  result_nodes.push_back(lstm_split_dy_outputs);

  // Create lstm_split_i
  auto lstm_split_i = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex12), split_shapes, split_types,
                                       size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_i_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_i, num_split_x, &lstm_split_i_outputs);
  result_nodes.push_back(lstm_split_i_outputs);

  // Create lstm_split_j
  auto lstm_split_j = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex13), split_shapes, split_types,
                                       size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_j_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_j, num_split_x, &lstm_split_j_outputs);
  result_nodes.push_back(lstm_split_j_outputs);

  // Create lstm_split_f
  auto lstm_split_f = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex14), split_shapes, split_types,
                                       size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_f_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_f, num_split_x, &lstm_split_f_outputs);
  result_nodes.push_back(lstm_split_f_outputs);

  // Create lstm_split_o
  auto lstm_split_o = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex15), split_shapes, split_types,
                                       size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_o_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_o, num_split_x, &lstm_split_o_outputs);
  result_nodes.push_back(lstm_split_o_outputs);

  // Create lstm_split_tanh
  auto lstm_split_tanh = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex16), split_shapes,
                                          split_types, size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_tanh_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_tanh, num_split_x, &lstm_split_tanh_outputs);
  result_nodes.push_back(lstm_split_tanh_outputs);

  // Add edges
  std::vector<std::vector<AnfNodePtr>> loop_node_outputs;
  CreateTLoopNodeWithEdge(func_graph, dynamic_rnn_grad_cnode, result_nodes, num_split_x, &loop_node_outputs);
  auto &pre_basic_lstm_cell_c_state_grad_outputs = loop_node_outputs[kIndex0];
  auto &pre_split_outputs = loop_node_outputs[kIndex1];
  auto &lstm_x_concat_input = loop_node_outputs[kIndex2];
  auto &lstm_gage_concat_input = loop_node_outputs[kIndex3];

  // Create lstm_x_concat
  auto lstm_x_concat = func_graph->NewCNode(lstm_x_concat_input);
  auto lstm_x_concat = NewCNode(lstm_x_concat_input, func_graph);
  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 2)},
                                      lstm_x_concat.get());
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(num_split_x)), lstm_x_concat);

@@ -273,7 +305,7 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(static_cast<int64_t>(0)), lstm_x_concat);

  // Create lstm_gage_concat
  auto lstm_gage_concat = func_graph->NewCNode(lstm_gage_concat_input);
  auto lstm_gage_concat = NewCNode(lstm_gage_concat_input, func_graph);
  auto origin_input7_shape = AnfAlgo::GetOutputInferShape(origin_input7, 0);
  AnfAlgo::SetOutputInferTypeAndShape(
    {kNumberTypeFloat16},

@@ -289,14 +321,15 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
  return lstm_gage_concat;
}

AnfNodePtr CreateSplitV(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) {
AnfNodePtr DynamicRnnGradFissionV2::CreateSplitV(const FuncGraphPtr &func_graph,
                                                 const CNodePtr &dynamic_rnn_grad_cnode) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
  // Create node
  auto origin_input6 = dynamic_rnn_grad_cnode->input(kIndex7);
  std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
                                          origin_input6};
  auto split_v = func_graph->NewCNode(splitv_input);
  auto split_v = NewCNode(splitv_input, func_graph);
  // Set infer data type and shape
  auto dtypes = {AnfAlgo::GetOutputInferDataType(origin_input6, 0), AnfAlgo::GetOutputInferDataType(origin_input6, 0)};
  auto origin_input6_shape = AnfAlgo::GetOutputInferShape(origin_input6, 0);

@@ -313,8 +346,9 @@ AnfNodePtr CreateSplitV(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_
  return split_v;
}

AnfNodePtr CreateHConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                         const AnfNodePtr &splitv) {
AnfNodePtr DynamicRnnGradFissionV2::CreateHConcat(const FuncGraphPtr &func_graph,
                                                  const CNodePtr &dynamic_rnn_grad_cnode,
                                                  const AnfNodePtr &splitv) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
  MS_EXCEPTION_IF_NULL(splitv);

@@ -336,11 +370,11 @@ AnfNodePtr CreateHConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic
  }
  std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                           origin_input4};
  auto reshape = func_graph->NewCNode(reshape_input);
  auto reshape = NewCNode(reshape_input, func_graph);
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape_tmp}, reshape.get());
  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
                                           reshape, splitv_outputs[0]};
  auto concat = func_graph->NewCNode(concat_inputs);
  auto concat = NewCNode(concat_inputs, func_graph);
  // Set infer data type and shape
  auto splitv_output0_shape = AnfAlgo::GetOutputInferShape(splitv, 0);
  std::vector<size_t> shape = {splitv_output0_shape[0] + 1, origin_input4_shape[0], origin_input4_shape[1]};

@@ -353,15 +387,15 @@ AnfNodePtr CreateHConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic
  return concat;
}

AnfNodePtr CreateConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                        const AnfNodePtr &h_concat) {
AnfNodePtr DynamicRnnGradFissionV2::CreateConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                                                 const AnfNodePtr &h_concat) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
  // Create node
  auto origin_input0 = dynamic_rnn_grad_cnode->input(1);
  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
                                           origin_input0, h_concat};
  auto concat = func_graph->NewCNode(concat_inputs);
  auto concat = NewCNode(concat_inputs, func_graph);
  // Set infer data type and shape
  auto origin_output0_shape = AnfAlgo::GetOutputInferShape(origin_input0, 0);
  auto h_concat_output_shape = AnfAlgo::GetOutputInferShape(h_concat, 0);

@@ -376,7 +410,8 @@ AnfNodePtr CreateConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_
  return concat;
}

AnfNodePtr CreateConcatNodeT1(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) {
AnfNodePtr DynamicRnnGradFissionV2::CreateConcatNodeT1(const FuncGraphPtr &func_graph,
                                                       const CNodePtr &dynamic_rnn_grad_cnode) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
  // Create node

@@ -392,12 +427,12 @@ AnfNodePtr CreateConcatNodeT1(const FuncGraphPtr &func_graph, const CNodePtr &dy
  }
  std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                           origin_input4};
  auto reshape = func_graph->NewCNode(reshape_input);
  auto reshape = NewCNode(reshape_input, func_graph);
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape_tmp}, reshape.get());

  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
                                           origin_input0, reshape};
  auto concat = func_graph->NewCNode(concat_inputs);
  auto concat = NewCNode(concat_inputs, func_graph);
  // Set infer data type and shape
  auto origin_input0_shape = AnfAlgo::GetOutputInferShape(origin_input0, 0);
  std::vector<size_t> shape = {origin_input0_shape[kDim0], origin_input0_shape[kDim1],

@@ -411,13 +446,13 @@ AnfNodePtr CreateConcatNodeT1(const FuncGraphPtr &func_graph, const CNodePtr &dy
  return concat;
}

AnfNodePtr CreateBatchMatMul(const FuncGraphPtr &func_graph, const AnfNodePtr &lstm_input_grad,
                             const AnfNodePtr &concat) {
AnfNodePtr DynamicRnnGradFissionV2::CreateBatchMatMul(const FuncGraphPtr &func_graph, const AnfNodePtr &lstm_input_grad,
                                                      const AnfNodePtr &concat) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  // Create node
  std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name())),
                                           concat, lstm_input_grad};
  auto batch_matmul = func_graph->NewCNode(matmul_inputs);
  auto batch_matmul = NewCNode(matmul_inputs, func_graph);
  // Set infer data type and shape
  auto concat_shape = AnfAlgo::GetOutputInferShape(concat, 0);
  auto lstm_input_grad_shape = AnfAlgo::GetOutputInferShape(lstm_input_grad, 0);

@@ -430,13 +465,14 @@ AnfNodePtr CreateBatchMatMul(const FuncGraphPtr &func_graph, const AnfNodePtr &l
  return batch_matmul;
}

AnfNodePtr CreateBatchMatMul2(const FuncGraphPtr &func_graph, const AnfNodePtr &lstm_input_grad,
                              const AnfNodePtr &node) {
AnfNodePtr DynamicRnnGradFissionV2::CreateBatchMatMul2(const FuncGraphPtr &func_graph,
                                                       const AnfNodePtr &lstm_input_grad,
                                                       const AnfNodePtr &node) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  // Create node
  std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name())),
                                           node, lstm_input_grad};
  auto batch_matmul = func_graph->NewCNode(matmul_inputs);
  auto batch_matmul = NewCNode(matmul_inputs, func_graph);
  // Set infer data type and shape
  auto out_shape = {AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[kIndex0], IntToSize(1),
                    AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[kIndex2]};

@@ -448,13 +484,14 @@ AnfNodePtr CreateBatchMatMul2(const FuncGraphPtr &func_graph, const AnfNodePtr &
  return batch_matmul;
}

AnfNodePtr CreateDwReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                             const AnfNodePtr &batch_matmul) {
AnfNodePtr DynamicRnnGradFissionV2::CreateDwReduceSum(const FuncGraphPtr &func_graph,
                                                      const CNodePtr &dynamic_rnn_grad_cnode,
                                                      const AnfNodePtr &batch_matmul) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  // Create node
  std::vector<AnfNodePtr> reduce_sum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
                                               batch_matmul};
  auto reduce_sum = func_graph->NewCNode(reduce_sum_inputs);
  auto reduce_sum = NewCNode(reduce_sum_inputs, func_graph);
  // Set infer data type and shape
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)},
                                      {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 0)}, reduce_sum.get());

@@ -465,13 +502,14 @@ AnfNodePtr CreateDwReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dyn
  return reduce_sum;
}

AnfNodePtr CreateDwReshape(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                           const AnfNodePtr &batch_matmul) {
AnfNodePtr DynamicRnnGradFissionV2::CreateDwReshape(const FuncGraphPtr &func_graph,
                                                    const CNodePtr &dynamic_rnn_grad_cnode,
                                                    const AnfNodePtr &batch_matmul) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  // Create node
  std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                            batch_matmul};
  auto reshape = func_graph->NewCNode(reshape_inputs);
  auto reshape = NewCNode(reshape_inputs, func_graph);
  // Set infer data type and shape
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)},
                                      {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 0)}, reshape.get());

@@ -479,7 +517,8 @@ AnfNodePtr CreateDwReshape(const FuncGraphPtr &func_graph, const CNodePtr &dynam
  return reshape;
}

AnfNodePtr CreateValueNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) {
AnfNodePtr DynamicRnnGradFissionV2::CreateValueNode(const FuncGraphPtr &func_graph,
                                                    const CNodePtr &dynamic_rnn_grad_cnode) const {
  auto origin_input7 = dynamic_rnn_grad_cnode->input(kIndex8);
  auto origin_input7_shape = AnfAlgo::GetOutputInferShape(origin_input7, 0);
  auto t_size = origin_input7_shape[0];

@@ -497,14 +536,15 @@ AnfNodePtr CreateValueNode(const FuncGraphPtr &func_graph, const CNodePtr &dynam
  return value_node;
}

AnfNodePtr CreateDbReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &, const AnfNodePtr &lstm_input_grad,
                             const AnfNodePtr &value_node) {
AnfNodePtr DynamicRnnGradFissionV2::CreateDbReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &,
                                                      const AnfNodePtr &lstm_input_grad,
                                                      const AnfNodePtr &value_node) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  // Create node
  auto batch_matmul = CreateBatchMatMul2(func_graph, lstm_input_grad, value_node);
  std::vector<AnfNodePtr> reduce_sum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
                                               batch_matmul};
  auto reduce_sum = func_graph->NewCNode(reduce_sum_inputs);
  auto reduce_sum = NewCNode(reduce_sum_inputs, func_graph);
  // Set infer data type and shape
  auto out_shape = {AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[kDim2]};
  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {out_shape}, reduce_sum.get());

@@ -514,7 +554,6 @@ AnfNodePtr CreateDbReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &, c
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sum);
  return reduce_sum;
}
} // namespace

const BaseRef DynamicRnnGradFissionV2::DefinePattern() const {
  VarPtr Xs = std::make_shared<SeqVar>();
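Note: the old AddLSTMInputGradNode did everything inline; it is now split so that CreateTLoopNodeWithEdge receives all the split outputs packed into result_nodes and hands its own results back through a vector-of-vectors out-parameter that the caller unpacks by index (kIndex0 through kIndex3). A standalone sketch of that out-parameter convention, with illustrative names:

#include <cassert>
#include <string>
#include <vector>

using Node = std::string;

// Append one vector per logical result; the caller unpacks them by position.
void BuildLoopNodes(size_t num_steps, std::vector<std::vector<Node>> *groups) {
  std::vector<Node> x_concat_inputs{"Concat"};
  std::vector<Node> gate_concat_inputs{"Concat"};
  for (size_t t = 0; t < num_steps; ++t) {
    x_concat_inputs.push_back("split_v_out_" + std::to_string(t));
    gate_concat_inputs.push_back("reshape_" + std::to_string(t));
  }
  groups->push_back(x_concat_inputs);     // slot 0, like kIndex0
  groups->push_back(gate_concat_inputs);  // slot 1, like kIndex1
}

int main() {
  std::vector<std::vector<Node>> groups;
  BuildLoopNodes(3, &groups);
  const auto &x_concat = groups[0];
  const auto &gate_concat = groups[1];
  assert(x_concat.size() == 4 && gate_concat.size() == 4);
  return 0;
}

The index-based contract is compact but positional; both sides must agree on the slot order, which is why the caller immediately names each slot with a reference.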
@@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_V2_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_V2_H_

#include <vector>
#include "backend/optimizer/common/optimizer.h"

namespace mindspore {

@@ -28,6 +29,36 @@ class DynamicRnnGradFissionV2 : public PatternProcessPass {
  ~DynamicRnnGradFissionV2() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                       std::vector<std::vector<AnfNodePtr>> *result_nodes) const;
  AnfNodePtr CreateLSTMSPlitV(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
                              const std::vector<std::vector<size_t>> &split_shapes,
                              const std::vector<TypeId> &split_types, const std::vector<int64_t> &size_split,
                              size_t num_split_x) const;
  void CreateTLoopNodeWithEdge(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                               const std::vector<std::vector<AnfNodePtr>> &result_nodes, size_t num_split_x,
                               std::vector<std::vector<AnfNodePtr>> *loop_node_outputs) const;
  AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                                  std::vector<AnfNodePtr> *outputs) const;
  AnfNodePtr CreateSplitV(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) const;
  AnfNodePtr CreateHConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                           const AnfNodePtr &splitv) const;
  AnfNodePtr CreateConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                          const AnfNodePtr &h_concat) const;
  AnfNodePtr CreateConcatNodeT1(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) const;
  AnfNodePtr CreateBatchMatMul(const FuncGraphPtr &func_graph, const AnfNodePtr &lstm_input_grad,
                               const AnfNodePtr &concat) const;
  AnfNodePtr CreateBatchMatMul2(const FuncGraphPtr &func_graph, const AnfNodePtr &lstm_input_grad,
                                const AnfNodePtr &node) const;
  AnfNodePtr CreateDwReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                               const AnfNodePtr &batch_matmul) const;
  AnfNodePtr CreateDwReshape(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                             const AnfNodePtr &batch_matmul) const;
  AnfNodePtr CreateValueNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) const;
  AnfNodePtr CreateDbReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &, const AnfNodePtr &lstm_input_grad,
                               const AnfNodePtr &value_node) const;
};
} // namespace opt
} // namespace mindspore
@@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -28,12 +28,36 @@ constexpr size_t kOriginPaddingSize = 2;
constexpr size_t kGatherInputNum = 4;
constexpr size_t kGatherInputIndicesIndex = 2;
constexpr size_t kGatherInputAxisIndex = 3;

bool CheckInputs(const CNodePtr &origin_node) {
  MS_EXCEPTION_IF_NULL(origin_node);
  if (AnfAlgo::GetInputTensorNum(origin_node) != kGatherV2DynInputTensorNum) {
    MS_LOG(DEBUG) << "GatherV2 in dynamic shape has wrong inputs num, not equal " << kGatherV2DynInputTensorNum
                  << ". CNode= " << origin_node->DebugString();
    return false;
  }
  auto param_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0);
  auto indice_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 1);
  // this optimizer only support embedding_table has dynamic shape
  if (param_shape.empty() || indice_shape.empty() || AnfAlgo::IsDynamicShape(origin_node->input(kDim2))) {
    return false;
  }
  if (param_shape[param_shape.size() - 1] != 1) {
    MS_LOG(DEBUG) << "GatherV2 in dynamic shape is not need fission. The last value of input0's shape is "
                  << param_shape[param_shape.size() - 1];
    return false;
  }
  return true;
}
} // namespace

// only pad operator can run in dynamic shape.
CNodePtr CreatePad(const FuncGraphPtr &graph, const CNodePtr &origin_node, const size_t &pad_dim_size) {
CNodePtr GatherV2DsFission::CreatePad(const FuncGraphPtr &graph, const CNodePtr &origin_node,
                                      const size_t &pad_dim_size) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(origin_node);
  std::vector<AnfNodePtr> pad_inputs = {NewValueNode(std::make_shared<Primitive>(kPadOpName)), origin_node->input(1)};
  auto pad = graph->NewCNode(pad_inputs);
  auto pad = NewCNode(pad_inputs, graph);
  MS_EXCEPTION_IF_NULL(pad);
  pad->set_scope(origin_node->scope());

@@ -83,8 +107,8 @@ CNodePtr CreatePad(const FuncGraphPtr &graph, const CNodePtr &origin_node, const
  return pad;
}

CNodePtr CreateGatherV2Ds(const FuncGraphPtr &graph, const CNodePtr &origin_node, const CNodePtr &pad,
                          const size_t &pad_dim_size) {
CNodePtr GatherV2DsFission::CreateGatherV2Ds(const FuncGraphPtr &graph, const CNodePtr &origin_node,
                                             const CNodePtr &pad, const size_t &pad_dim_size) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(origin_node);
  MS_EXCEPTION_IF_NULL(pad);

@@ -94,7 +118,7 @@ CNodePtr CreateGatherV2Ds(const FuncGraphPtr &graph, const CNodePtr &origin_node
  std::vector<AnfNodePtr> gatherv2_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimGather->name())), pad,
                                             origin_node->input(kGatherInputIndicesIndex),
                                             origin_node->input(kGatherInputAxisIndex)};
  auto gather_v2 = graph->NewCNode(gatherv2_inputs);
  auto gather_v2 = NewCNode(gatherv2_inputs, graph);
  MS_EXCEPTION_IF_NULL(gather_v2);
  gather_v2->set_scope(origin_node->scope());

@@ -110,12 +134,13 @@ CNodePtr CreateGatherV2Ds(const FuncGraphPtr &graph, const CNodePtr &origin_node
  return gather_v2;
}

CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &gather_v2, const CNodePtr &gather_v2_padding_8) {
CNodePtr GatherV2DsFission::CreateSlice(const FuncGraphPtr &graph, const CNodePtr &gather_v2,
                                        const CNodePtr &gather_v2_padding_8) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(gather_v2);
  MS_EXCEPTION_IF_NULL(gather_v2_padding_8);
  std::vector<AnfNodePtr> slice_inputs = {NewValueNode(std::make_shared<Primitive>(kSliceOpName)), gather_v2_padding_8};
  auto slice = graph->NewCNode(slice_inputs);
  auto slice = NewCNode(slice_inputs, graph);
  MS_EXCEPTION_IF_NULL(slice);
  slice->set_scope(gather_v2->scope());
  slice->set_abstract(gather_v2->abstract());

@@ -126,28 +151,6 @@ CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &gather_v2, const
  return slice;
}

bool CheckInputs(const CNodePtr &origin_node) {
  MS_EXCEPTION_IF_NULL(origin_node);
  if (AnfAlgo::GetInputTensorNum(origin_node) != kGatherV2DynInputTensorNum) {
    MS_LOG(DEBUG) << "GatherV2 in dynamic shape has wrong inputs num, not equal " << kGatherV2DynInputTensorNum
                  << ". CNode= " << origin_node->DebugString();
    return false;
  }
  auto param_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0);
  auto indice_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 1);
  // this optimizer only support embedding_table has dynamic shape
  if (param_shape.empty() || indice_shape.empty() || AnfAlgo::IsDynamicShape(origin_node->input(kDim2))) {
    return false;
  }
  if (param_shape[param_shape.size() - 1] != 1) {
    MS_LOG(DEBUG) << "GatherV2 in dynamic shape is not need fission. The last value of input0's shape is "
                  << param_shape[param_shape.size() - 1];
    return false;
  }
  return true;
}
} // namespace

const BaseRef GatherV2DsFission::DefinePattern() const {
  VarPtr Xs = std::make_shared<SeqVar>();
  VectorRef pattern({prim::kPrimGather, Xs});
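Note: since only the Pad operator can run in dynamic shape (per the comment above), the pass rewrites Gather as Pad -> Gather -> Slice: CreatePad widens the table's size-1 last axis, CreateGatherV2Ds gathers from the padded table, and CreateSlice trims the result back. A plain C++ sketch of the shape bookkeeping; the pad width 8 is an assumption read off the gather_v2_padding_8 parameter name:

#include <cassert>
#include <vector>

// CheckInputs above guarantees the last dim is 1 before the rewrite fires.
std::vector<size_t> PaddedShape(std::vector<size_t> shape, size_t pad_dim_size) {
  shape.back() = pad_dim_size;  // Pad widens the size-1 last axis
  return shape;
}

std::vector<size_t> SlicedBack(std::vector<size_t> shape) {
  shape.back() = 1;  // Slice restores the original width
  return shape;
}

int main() {
  std::vector<size_t> table{1024, 1};
  auto padded = PaddedShape(table, 8);  // assumed pad_dim_size
  assert(padded.back() == 8);
  assert(SlicedBack(padded) == table);
  return 0;
}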
@@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -30,6 +30,12 @@ class GatherV2DsFission : public PatternProcessPass {
  ~GatherV2DsFission() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  CNodePtr CreatePad(const FuncGraphPtr &graph, const CNodePtr &origin_node, const size_t &pad_dim_size) const;
  CNodePtr CreateGatherV2Ds(const FuncGraphPtr &graph, const CNodePtr &origin_node, const CNodePtr &pad,
                            const size_t &pad_dim_size) const;
  CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &gather_v2, const CNodePtr &gather_v2_padding_8) const;
};
} // namespace opt
} // namespace mindspore
@@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -15,7 +15,6 @@
 */
#include "backend/optimizer/ascend/ir_fission/lars_v2_fission.h"
#include <memory>
#include <vector>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/helper.h"
#include "utils/utils.h"

@@ -29,14 +28,16 @@ constexpr size_t kLarsV2WIndex = 1;
constexpr size_t kLarsV2GIndex = 2;
constexpr size_t kLarsV2WeightDecayIndex = 3;
constexpr size_t kLarsV2LearningRatIndex = 4;
void CreateOutputsOfSquareSumAll(const FuncGraphPtr &graph, const CNodePtr &lars_v2,
                                 std::vector<AnfNodePtr> *square_sum_all_outputs) {
} // namespace

void LarsV2Fission::CreateOutputsOfSquareSumAll(const FuncGraphPtr &graph, const CNodePtr &lars_v2,
                                                std::vector<AnfNodePtr> *square_sum_all_outputs) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(lars_v2);
  CheckCNodeInputSize(lars_v2, kLarsV2InputTensorNum);
  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kSquareSumAllOpName)), lars_v2->input(1),
                                    lars_v2->input(2)};
  auto square_sum_all = graph->NewCNode(inputs);
  auto square_sum_all = NewCNode(inputs, graph);
  MS_EXCEPTION_IF_NULL(square_sum_all);
  square_sum_all->set_scope(lars_v2->scope());

@@ -47,8 +48,8 @@ void CreateOutputsOfSquareSumAll(const FuncGraphPtr &graph, const CNodePtr &lars
  CreateMultipleOutputsOfAnfNode(graph, square_sum_all, kSquareSumOutputNum, square_sum_all_outputs);
}

CNodePtr CreateLarsV2Update(const FuncGraphPtr &graph, const CNodePtr &lars_v2,
                            const std::vector<AnfNodePtr> &square_sum_all_outputs) {
CNodePtr LarsV2Fission::CreateLarsV2Update(const FuncGraphPtr &graph, const CNodePtr &lars_v2,
                                           const std::vector<AnfNodePtr> &square_sum_all_outputs) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(lars_v2);
  if (square_sum_all_outputs.size() != kSquareSumOutputNum) {

@@ -63,13 +64,12 @@ CNodePtr CreateLarsV2Update(const FuncGraphPtr &graph, const CNodePtr &lars_v2,
                                    square_sum_all_outputs[1],
                                    lars_v2->input(kLarsV2WeightDecayIndex),
                                    lars_v2->input(kLarsV2LearningRatIndex)};
  auto lars_v2_update = graph->NewCNode(inputs);
  auto lars_v2_update = NewCNode(inputs, graph);
  MS_EXCEPTION_IF_NULL(lars_v2_update);
  lars_v2_update->set_scope(lars_v2->scope());
  lars_v2_update->set_abstract(lars_v2->abstract());
  return lars_v2_update;
}
} // namespace

const BaseRef LarsV2Fission::DefinePattern() const {
  VarPtr Xs = std::make_shared<SeqVar>();
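Note: the fission splits one LarsV2 node into SquareSumAll (the two squared norms of weights and gradients) feeding LarsV2Update together with the original weight_decay and learning-rate inputs. A numeric sketch of the data flow; the local-rate formula is the published LARS trust ratio, an assumption about the kernel's internals rather than its verified contract:

#include <cmath>
#include <cstdio>
#include <vector>

double SquareSum(const std::vector<double> &v) {
  double s = 0.0;
  for (double x : v) s += x * x;
  return s;
}

int main() {
  std::vector<double> w{0.5, 0.5}, g{0.1, 0.2};
  // Stage 1: SquareSumAll produces both squared norms in one pass.
  double w2 = SquareSum(w), g2 = SquareSum(g);
  // Stage 2: LarsV2Update consumes them plus weight_decay and lr (assumed
  // LARS trust-ratio math).
  double weight_decay = 1e-4, eta = 0.001, eps = 1e-5;
  double local_lr = eta * std::sqrt(w2) /
                    (std::sqrt(g2) + weight_decay * std::sqrt(w2) + eps);
  std::printf("local_lr = %f\n", local_lr);
  return 0;
}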
@@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -16,6 +16,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_LARS_V2_FISSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_LARS_V2_FISSION_H_

#include <vector>
#include "backend/optimizer/common/optimizer.h"

namespace mindspore {

@@ -26,6 +27,12 @@ class LarsV2Fission : public PatternProcessPass {
  ~LarsV2Fission() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  void CreateOutputsOfSquareSumAll(const FuncGraphPtr &graph, const CNodePtr &lars_v2,
                                   std::vector<AnfNodePtr> *square_sum_all_outputs) const;
  CNodePtr CreateLarsV2Update(const FuncGraphPtr &graph, const CNodePtr &lars_v2,
                              const std::vector<AnfNodePtr> &square_sum_all_outputs) const;
};
} // namespace opt
} // namespace mindspore
@@ -43,7 +43,7 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormXBackpropV2(const FuncGraphPtr
  for (size_t i = 1; i < layer_norm_grad->inputs().size(); ++i) {
    layer_norm_x_backprop_inputs.push_back(layer_norm_grad->input(i));
  }
  auto layer_norm_x_backprop = graph->NewCNode(layer_norm_x_backprop_inputs);
  auto layer_norm_x_backprop = NewCNode(layer_norm_x_backprop_inputs, graph);
  MS_EXCEPTION_IF_NULL(layer_norm_x_backprop);
  layer_norm_x_backprop->set_scope(layer_norm_grad->scope());
  auto types = {AnfAlgo::GetOutputInferDataType(layer_norm_grad, 0), kNumberTypeFloat32};

@@ -68,7 +68,7 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormBetaGammaBackpropV2(
  auto prim = std::make_shared<Primitive>(kLayerNormBetaGammaBackpropV2OpName);
  std::vector<AnfNodePtr> layer_norm_beta_gamma_backprop_inputs = {NewValueNode(prim), layer_norm_grad->input(kIndex2),
                                                                   res_for_gamma};
  auto layer_norm_beta_gamma_backprop = graph->NewCNode(layer_norm_beta_gamma_backprop_inputs);
  auto layer_norm_beta_gamma_backprop = NewCNode(layer_norm_beta_gamma_backprop_inputs, graph);
  MS_EXCEPTION_IF_NULL(layer_norm_beta_gamma_backprop);
  auto kernel_info = std::make_shared<device::KernelInfo>();
  layer_norm_beta_gamma_backprop->set_kernel_info(kernel_info);
@@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -109,7 +109,7 @@ const AnfNodePtr LinSpaceFission::Process(const FuncGraphPtr &graph, const AnfNo
  auto assist_const = CreateValueNode(cnode);
  new_inputs.push_back(assist_const);
  new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
  CNodePtr new_cnode = graph->NewCNode(new_inputs);
  CNodePtr new_cnode = NewCNode(new_inputs, graph);
  MS_EXCEPTION_IF_NULL(new_cnode);
  new_cnode->set_abstract(cnode->abstract());
  new_cnode->set_scope(cnode->scope());
@@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -112,7 +112,7 @@ const AnfNodePtr MaxPool3DGradGradFission::Process(const FuncGraphPtr &graph, co
  auto assist_const = CreateValueNode(cnode);
  (void)new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
  (void)new_inputs.emplace_back(assist_const);
  CNodePtr new_cnode = graph->NewCNode(new_inputs);
  CNodePtr new_cnode = NewCNode(new_inputs, graph);
  MS_EXCEPTION_IF_NULL(new_cnode);
  new_cnode->set_abstract(cnode->abstract());
  new_cnode->set_scope(cnode->scope());
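Note: the two hunks above apply the same NewCNode change but also show the two assist-input conventions side by side: LinSpaceFission pushes assist_const ahead of the original inputs, while MaxPool3DGradGradFission appends it after them. A standalone sketch of the two orderings, with plain strings standing in for AnfNodePtr inputs:

#include <cassert>
#include <string>
#include <vector>

int main() {
  const std::vector<std::string> origin{"prim", "x", "y"};
  const std::string assist = "assist_const";

  // LinSpace style: primitive, assist, then the original operands.
  std::vector<std::string> lin_space{origin[0]};
  lin_space.push_back(assist);
  lin_space.insert(lin_space.end(), origin.begin() + 1, origin.end());

  // MaxPool3DGradGrad style: primitive, original operands, then assist.
  std::vector<std::string> max_pool{origin[0]};
  max_pool.insert(max_pool.end(), origin.begin() + 1, origin.end());
  max_pool.push_back(assist);

  assert(lin_space[1] == assist && max_pool.back() == assist);
  return 0;
}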
@@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -21,16 +21,15 @@

namespace mindspore {
namespace opt {
namespace {
AnfNodePtr CreateNewPack(const FuncGraphPtr &func_graph, const CNodePtr &origin_pack_cnode, size_t begin_index,
                         size_t offset) {
AnfNodePtr PackFission::CreateNewPack(const FuncGraphPtr &func_graph, const CNodePtr &origin_pack_cnode,
                                      size_t begin_index, size_t offset) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(origin_pack_cnode);
  std::vector<AnfNodePtr> new_pack_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimStack->name()))};
  for (size_t i = begin_index; i < begin_index + offset; ++i) {
    new_pack_inputs.push_back(origin_pack_cnode->input(i));
  }
  CNodePtr new_pack = func_graph->NewCNode(new_pack_inputs);
  CNodePtr new_pack = NewCNode(new_pack_inputs, func_graph);
  MS_EXCEPTION_IF_NULL(new_pack);
  new_pack->set_scope(origin_pack_cnode->scope());
  new_pack->set_abstract(origin_pack_cnode->abstract());

@@ -58,7 +57,6 @@ AnfNodePtr CreateNewPack(const FuncGraphPtr &func_graph, const CNodePtr &origin_
                                      new_pack.get());
  return new_pack;
}
} // namespace

const BaseRef PackFission::DefinePattern() const {
  VarPtr Xs = std::make_shared<SeqVar>();

@@ -90,7 +88,7 @@ const AnfNodePtr PackFission::Process(const FuncGraphPtr &func_graph, const AnfN
      CreateNewPack(func_graph, cnode, cur_input_index, origin_input_size - cur_input_index + 1));
  }

  CNodePtr base_concat = func_graph->NewCNode(base_concat_inputs);
  CNodePtr base_concat = NewCNode(base_concat_inputs, func_graph);
  MS_EXCEPTION_IF_NULL(base_concat);
  base_concat->set_scope(cnode->scope());
  base_concat->set_abstract(cnode->abstract());
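Note: PackFission, like AddnFission earlier in this change, caps a node's fan-in by consuming the original inputs in groups of inputs_divisor_, emitting one sub-pack per group plus a final partial group, and concatenating the results. A standalone sketch of the chunking arithmetic:

#include <algorithm>
#include <cassert>
#include <vector>

// Cut `inputs` into consecutive groups of at most `divisor` elements.
std::vector<std::vector<int>> Chunk(const std::vector<int> &inputs, size_t divisor) {
  std::vector<std::vector<int>> chunks;
  for (size_t begin = 0; begin < inputs.size(); begin += divisor) {
    size_t end = std::min(begin + divisor, inputs.size());
    chunks.emplace_back(inputs.begin() + begin, inputs.begin() + end);
  }
  return chunks;
}

int main() {
  std::vector<int> inputs(10, 0);
  auto chunks = Chunk(inputs, 4);  // groups of 4 + 4 + a trailing 2
  assert(chunks.size() == 3 && chunks.back().size() == 2);
  return 0;
}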
@@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -30,6 +30,8 @@ class PackFission : public PatternProcessPass {
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  AnfNodePtr CreateNewPack(const FuncGraphPtr &func_graph, const CNodePtr &origin_pack_cnode, size_t begin_index,
                           size_t offset) const;
  size_t inputs_divisor_;
};
} // namespace opt
@@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -21,18 +21,6 @@
namespace mindspore {
namespace opt {
namespace {
CNodePtr CreateReduceMin(const FuncGraphPtr &graph, const AnfNodePtr &input, const CNodePtr &old_node) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(input);
  MS_EXCEPTION_IF_NULL(old_node);
  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceMin->name())), input};
  CNodePtr reduce_min = graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(reduce_min);
  reduce_min->set_scope(old_node->scope());
  AnfAlgo::CopyNodeAttr(kAttrKeepDims, old_node, reduce_min);
  return reduce_min;
}

bool NeedOptimize(const TypeId &dtype, const std::vector<size_t> &shape, const std::vector<int64_t> &axis) {
  if (dtype != kNumberTypeFloat32) {
    MS_LOG(INFO) << "ReduceMin's input Dtype is not float32, no need to optimize!";

@@ -97,6 +85,19 @@ std::vector<size_t> GetInferShape(const std::vector<size_t> &shape, const std::v
}
} // namespace

CNodePtr ReduceMinFission::CreateReduceMin(const FuncGraphPtr &graph, const AnfNodePtr &input,
                                           const CNodePtr &old_node) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(input);
  MS_EXCEPTION_IF_NULL(old_node);
  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceMin->name())), input};
  CNodePtr reduce_min = NewCNode(inputs, graph);
  MS_EXCEPTION_IF_NULL(reduce_min);
  reduce_min->set_scope(old_node->scope());
  AnfAlgo::CopyNodeAttr(kAttrKeepDims, old_node, reduce_min);
  return reduce_min;
}

const BaseRef ReduceMinFission::DefinePattern() const {
  VarPtr X = std::make_shared<Var>();
  return VectorRef({prim::kPrimReduceMin, X});
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -27,6 +27,9 @@ class ReduceMinFission : public PatternProcessPass {
   ~ReduceMinFission() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  CNodePtr CreateReduceMin(const FuncGraphPtr &graph, const AnfNodePtr &input, const CNodePtr &old_node) const;
 };
 }  // namespace opt
 }  // namespace mindspore
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -24,8 +24,9 @@ namespace mindspore {
 namespace opt {
 namespace {
 constexpr size_t kBatchNormRealInputNum = 3;
+}  // namespace
 
-AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) {
+AnfNodePtr SingleBatchNormFission::CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(bn);
   auto bn_cnode = bn->cast<CNodePtr>();
@@ -36,7 +37,7 @@ AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodeP
   }
   std::vector<AnfNodePtr> bn_training_reduce_inputs = {
     NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)), bn_cnode->input(1)};
-  auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs);
+  auto bn_training_reduce = NewCNode(bn_training_reduce_inputs, func_graph);
   MS_EXCEPTION_IF_NULL(bn_training_reduce);
 
   // set abstract
@@ -49,8 +50,9 @@ AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodeP
   return bn_training_reduce;
 }
 
-AnfNodePtr CreateBNTrainingUpdateV3(const FuncGraphPtr &func_graph, const AnfNodePtr &bn,
-                                    const std::vector<AnfNodePtr> &bn_training_reduce_outputs) {
+AnfNodePtr SingleBatchNormFission::CreateBNTrainingUpdateV3(
+  const FuncGraphPtr &func_graph, const AnfNodePtr &bn,
+  const std::vector<AnfNodePtr> &bn_training_reduce_outputs) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(bn);
   auto bn_cnode = bn->cast<CNodePtr>();
@@ -71,7 +73,7 @@ AnfNodePtr CreateBNTrainingUpdateV3(const FuncGraphPtr &func_graph, const AnfNod
                                                          bn_training_reduce_outputs[kIndex1],
                                                          bn_cnode->input(kIndex2),
                                                          bn_cnode->input(kIndex3)};
-  auto bn_training_update_v3 = func_graph->NewCNode(bn_training_update_v3_inputs);
+  auto bn_training_update_v3 = NewCNode(bn_training_update_v3_inputs, func_graph);
   MS_EXCEPTION_IF_NULL(bn_training_update_v3);
 
   auto bn_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn->abstract());
@@ -85,7 +87,6 @@ AnfNodePtr CreateBNTrainingUpdateV3(const FuncGraphPtr &func_graph, const AnfNod
   AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_cnode, bn_training_update_v3);
   return bn_training_update_v3;
 }
-}  // namespace
 
 const BaseRef SingleBatchNormFission::DefinePattern() const {
   VarPtr Xs = std::make_shared<SeqVar>();
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_SINGLE_BATCH_NORM_FISSION_H_
 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_SINGLE_BATCH_NORM_FISSION_H_
 
+#include <vector>
 #include "backend/optimizer/common/optimizer.h"
 
 namespace mindspore {
@@ -27,6 +28,11 @@ class SingleBatchNormFission : public PatternProcessPass {
   ~SingleBatchNormFission() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) const;
+  AnfNodePtr CreateBNTrainingUpdateV3(const FuncGraphPtr &func_graph, const AnfNodePtr &bn,
+                                      const std::vector<AnfNodePtr> &bn_training_reduce_outputs) const;
 };
 }  // namespace opt
 }  // namespace mindspore
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -116,7 +116,7 @@ const AnfNodePtr SpaceToDepthSplit::Process(const FuncGraphPtr &graph, const Anf
   auto last_input_value = CreateValueNode(cnode);
   (void)new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
   (void)new_inputs.emplace_back(last_input_value);
-  CNodePtr new_cnode = graph->NewCNode(new_inputs);
+  CNodePtr new_cnode = NewCNode(new_inputs, graph);
   MS_EXCEPTION_IF_NULL(new_cnode);
   new_cnode->set_abstract(cnode->abstract());
   new_cnode->set_scope(cnode->scope());
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -22,22 +22,6 @@
 namespace mindspore {
 namespace opt {
 namespace {
-CNodePtr CreateSplitVNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node) {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  MS_EXCEPTION_IF_NULL(input_node);
-  std::vector<AnfNodePtr> splitv_inputs{NewValueNode(std::make_shared<Primitive>(kSplitVOpName)), input_node};
-  CNodePtr splitv = func_graph->NewCNode(splitv_inputs);
-  MS_EXCEPTION_IF_NULL(splitv);
-  splitv->set_scope(input_node->scope());
-  return splitv;
-}
-
-CNodePtr CreateBaseSplitVNode(const FuncGraphPtr &func_graph, const CNodePtr &origin_cnode) {
-  MS_EXCEPTION_IF_NULL(origin_cnode);
-  CheckCNodeInputSize(origin_cnode, kSplitInputTensorNum);
-  return CreateSplitVNode(func_graph, origin_cnode->input(1));
-}
-
 void SetAttrForSplitVNode(const AnfNodePtr &splitv, const std::vector<int64_t> &size_splits, int64_t split_dim,
                           int64_t num_split) {
   AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(size_splits), splitv);
@@ -121,6 +105,22 @@ void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePt
 }
 }  // namespace
 
+CNodePtr SplitFission::CreateSplitVNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(input_node);
+  std::vector<AnfNodePtr> splitv_inputs{NewValueNode(std::make_shared<Primitive>(kSplitVOpName)), input_node};
+  CNodePtr splitv = NewCNode(splitv_inputs, func_graph);
+  MS_EXCEPTION_IF_NULL(splitv);
+  splitv->set_scope(input_node->scope());
+  return splitv;
+}
+
+CNodePtr SplitFission::CreateBaseSplitVNode(const FuncGraphPtr &func_graph, const CNodePtr &origin_cnode) const {
+  MS_EXCEPTION_IF_NULL(origin_cnode);
+  CheckCNodeInputSize(origin_cnode, kSplitInputTensorNum);
+  return CreateSplitVNode(func_graph, origin_cnode->input(1));
+}
+
 AnfNodePtr SplitFission::DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int64_t num_split,
                                    int64_t divisor, int64_t split_dim) const {
   MS_EXCEPTION_IF_NULL(func_graph);
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -34,6 +34,8 @@ class SplitFission : public PatternProcessPass {
  protected:
   AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int64_t num_split, int64_t divisor,
                        int64_t split_dim) const;
+  CNodePtr CreateSplitVNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node) const;
+  CNodePtr CreateBaseSplitVNode(const FuncGraphPtr &func_graph, const CNodePtr &origin_cnode) const;
   int64_t outputs_divisor_;
 };
 }  // namespace opt
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -21,13 +21,13 @@
 
 namespace mindspore {
 namespace opt {
-namespace {
-CNodePtr CreateTensorMove(const FuncGraphPtr &graph, const CNodePtr &tensor_scatter_update) {
+CNodePtr TensorScatterUpdateFission::CreateTensorMove(const FuncGraphPtr &graph,
+                                                      const CNodePtr &tensor_scatter_update) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(tensor_scatter_update);
   std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kTensorMoveOpName)),
                                     tensor_scatter_update->input(1)};
-  auto tensor_move = graph->NewCNode(inputs);
+  auto tensor_move = NewCNode(inputs, graph);
   MS_EXCEPTION_IF_NULL(tensor_move);
   tensor_move->set_scope(tensor_scatter_update->scope());
   tensor_move->set_abstract(tensor_scatter_update->abstract());
@@ -35,20 +35,20 @@ CNodePtr CreateTensorMove(const FuncGraphPtr &graph, const CNodePtr &tensor_scat
   return tensor_move;
 }
 
-CNodePtr CreateScatterNdUpdate(const FuncGraphPtr &graph, const CNodePtr &tensor_scatter_update,
-                               const CNodePtr &tensor_move) {
+CNodePtr TensorScatterUpdateFission::CreateScatterNdUpdate(const FuncGraphPtr &graph,
+                                                           const CNodePtr &tensor_scatter_update,
+                                                           const CNodePtr &tensor_move) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(tensor_scatter_update);
   MS_EXCEPTION_IF_NULL(tensor_move);
   std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kScatterNdUpdateOpName)), tensor_move,
                                     tensor_scatter_update->input(2), tensor_scatter_update->input(3)};
-  auto scatter_nd_update = graph->NewCNode(inputs);
+  auto scatter_nd_update = NewCNode(inputs, graph);
   MS_EXCEPTION_IF_NULL(scatter_nd_update);
   scatter_nd_update->set_scope(tensor_scatter_update->scope());
   scatter_nd_update->set_abstract(tensor_scatter_update->abstract());
   return scatter_nd_update;
 }
-}  // namespace
 
 const BaseRef TensorScatterUpdateFission::DefinePattern() const {
   VarPtr Xs = std::make_shared<SeqVar>();
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -27,6 +27,11 @@ class TensorScatterUpdateFission : public PatternProcessPass {
   ~TensorScatterUpdateFission() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  CNodePtr CreateTensorMove(const FuncGraphPtr &graph, const CNodePtr &tensor_scatter_update) const;
+  CNodePtr CreateScatterNdUpdate(const FuncGraphPtr &graph, const CNodePtr &tensor_scatter_update,
+                                 const CNodePtr &tensor_move) const;
 };
 }  // namespace opt
 }  // namespace mindspore
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -27,9 +27,10 @@
 #include "utils/ms_context.h"
 
 namespace mindspore::opt {
-namespace {
 constexpr size_t kFloat16Len = 2;  // size of float16;
 constexpr size_t kTopkIndexK = 1;
+namespace {
 
 tensor::TensorPtr CreateTensor() {
   // 1 create tensor
   const size_t last_dim = 4096;
@@ -142,7 +143,7 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod
   // Copy a new node to check supported.
   std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(kTopKOpName))};
   new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
-  CNodePtr new_cnode = func_graph->NewCNode(new_inputs);
+  CNodePtr new_cnode = NewCNode(new_inputs, func_graph);
   MS_EXCEPTION_IF_NULL(new_cnode);
   new_cnode->set_abstract(cnode->abstract());
   new_cnode->set_scope(cnode->scope());
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -23,59 +23,6 @@
 namespace mindspore {
 namespace opt {
 namespace {
-CNodePtr CreatePadding(const FuncGraphPtr &graph, const CNodePtr &origin_node, const size_t &pad_dim_size) {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(origin_node);
-  std::vector<AnfNodePtr> padding_inputs = {NewValueNode(std::make_shared<Primitive>(kPaddingOpName)),
-                                            origin_node->input(kIndex1)};
-  auto padding = graph->NewCNode(padding_inputs);
-  MS_EXCEPTION_IF_NULL(padding);
-  padding->set_scope(origin_node->scope());
-  auto shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0);
-  shape[shape.size() - 1] = pad_dim_size;
-  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetPrevNodeOutputInferDataType(origin_node, 0)}, {shape},
-                                      padding.get());
-  AnfAlgo::SetNodeAttr(kAttrPadDimSize, MakeValue(SizeToLong(pad_dim_size)), padding);
-  return padding;
-}
-
-CNodePtr CreateUnsortedSegmentSum(const FuncGraphPtr &graph, const CNodePtr &origin_node, const CNodePtr &padding,
-                                  const size_t &pad_dim_size) {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(origin_node);
-  MS_EXCEPTION_IF_NULL(padding);
-  std::vector<AnfNodePtr> unsorted_segment_sum8_inputs = {
-    NewValueNode(std::make_shared<Primitive>(prim::kPrimUnsortedSegmentSum->name())), padding,
-    origin_node->input(kIndex2)};
-  auto unsorted_segment_sum = graph->NewCNode(unsorted_segment_sum8_inputs);
-  MS_EXCEPTION_IF_NULL(unsorted_segment_sum);
-  unsorted_segment_sum->set_scope(origin_node->scope());
-  auto shape = AnfAlgo::GetOutputInferShape(origin_node, 0);
-  shape[shape.size() - 1] = pad_dim_size;
-  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_node, 0)}, {shape},
-                                      unsorted_segment_sum.get());
-  AnfAlgo::SetNodeAttr(kAttrNumSegments, MakeValue(SizeToLong(shape[0])), unsorted_segment_sum);
-  return unsorted_segment_sum;
-}
-
-CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &unsort_segment_sum,
-                     const CNodePtr &unsorted_segment_sum8) {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(unsort_segment_sum);
-  MS_EXCEPTION_IF_NULL(unsorted_segment_sum8);
-  std::vector<AnfNodePtr> slice_inputs = {NewValueNode(std::make_shared<Primitive>(kSliceOpName)),
-                                          unsorted_segment_sum8};
-  auto slice = graph->NewCNode(slice_inputs);
-  MS_EXCEPTION_IF_NULL(slice);
-  slice->set_scope(unsort_segment_sum->scope());
-  slice->set_abstract(unsort_segment_sum->abstract());
-  auto unsort_segment_sum_shape = AnfAlgo::GetOutputInferShape(unsort_segment_sum, 0);
-  std::vector<size_t> offsets(unsort_segment_sum_shape.size(), 0);
-  AnfAlgo::SetNodeAttr(kAttrBegin, MakeValue(Convert2Long(offsets)), slice);
-  AnfAlgo::SetNodeAttr(kAttrSize, MakeValue(Convert2Long(unsort_segment_sum_shape)), slice);
-  return slice;
-}
-
 bool CheckInputs(const CNodePtr &origin_node) {
   MS_EXCEPTION_IF_NULL(origin_node);
   if (AnfAlgo::GetInputTensorNum(origin_node) != kUnsortedSegmentSumInputTensorNum) {
@@ -97,6 +44,60 @@ bool CheckInputs(const CNodePtr &origin_node) {
 }
 }  // namespace
 
+CNodePtr UnsortSegmentSumFission::CreatePadding(const FuncGraphPtr &graph, const CNodePtr &origin_node,
+                                                const size_t &pad_dim_size) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(origin_node);
+  std::vector<AnfNodePtr> padding_inputs = {NewValueNode(std::make_shared<Primitive>(kPaddingOpName)),
+                                            origin_node->input(kIndex1)};
+  auto padding = NewCNode(padding_inputs, graph);
+  MS_EXCEPTION_IF_NULL(padding);
+  padding->set_scope(origin_node->scope());
+  auto shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0);
+  shape[shape.size() - 1] = pad_dim_size;
+  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetPrevNodeOutputInferDataType(origin_node, 0)}, {shape},
+                                      padding.get());
+  AnfAlgo::SetNodeAttr(kAttrPadDimSize, MakeValue(SizeToLong(pad_dim_size)), padding);
+  return padding;
+}
+
+CNodePtr UnsortSegmentSumFission::CreateUnsortedSegmentSum(const FuncGraphPtr &graph, const CNodePtr &origin_node,
+                                                           const CNodePtr &padding, const size_t &pad_dim_size) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(origin_node);
+  MS_EXCEPTION_IF_NULL(padding);
+  std::vector<AnfNodePtr> unsorted_segment_sum8_inputs = {
+    NewValueNode(std::make_shared<Primitive>(prim::kPrimUnsortedSegmentSum->name())), padding,
+    origin_node->input(kIndex2)};
+  auto unsorted_segment_sum = NewCNode(unsorted_segment_sum8_inputs, graph);
+  MS_EXCEPTION_IF_NULL(unsorted_segment_sum);
+  unsorted_segment_sum->set_scope(origin_node->scope());
+  auto shape = AnfAlgo::GetOutputInferShape(origin_node, 0);
+  shape[shape.size() - 1] = pad_dim_size;
+  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_node, 0)}, {shape},
+                                      unsorted_segment_sum.get());
+  AnfAlgo::SetNodeAttr(kAttrNumSegments, MakeValue(SizeToLong(shape[0])), unsorted_segment_sum);
+  return unsorted_segment_sum;
+}
+
+CNodePtr UnsortSegmentSumFission::CreateSlice(const FuncGraphPtr &graph, const CNodePtr &unsort_segment_sum,
+                                              const CNodePtr &unsorted_segment_sum8) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(unsort_segment_sum);
+  MS_EXCEPTION_IF_NULL(unsorted_segment_sum8);
+  std::vector<AnfNodePtr> slice_inputs = {NewValueNode(std::make_shared<Primitive>(kSliceOpName)),
+                                          unsorted_segment_sum8};
+  auto slice = NewCNode(slice_inputs, graph);
+  MS_EXCEPTION_IF_NULL(slice);
+  slice->set_scope(unsort_segment_sum->scope());
+  slice->set_abstract(unsort_segment_sum->abstract());
+  auto unsort_segment_sum_shape = AnfAlgo::GetOutputInferShape(unsort_segment_sum, 0);
+  std::vector<size_t> offsets(unsort_segment_sum_shape.size(), 0);
+  AnfAlgo::SetNodeAttr(kAttrBegin, MakeValue(Convert2Long(offsets)), slice);
+  AnfAlgo::SetNodeAttr(kAttrSize, MakeValue(Convert2Long(unsort_segment_sum_shape)), slice);
+  return slice;
+}
+
 const BaseRef UnsortSegmentSumFission::DefinePattern() const {
   VarPtr Xs = std::make_shared<SeqVar>();
   VectorRef pattern({prim::kPrimUnsortedSegmentSum, Xs});
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -31,6 +31,13 @@ class UnsortSegmentSumFission : public PatternProcessPass {
   ~UnsortSegmentSumFission() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  CNodePtr CreatePadding(const FuncGraphPtr &graph, const CNodePtr &origin_node, const size_t &pad_dim_size) const;
+  CNodePtr CreateUnsortedSegmentSum(const FuncGraphPtr &graph, const CNodePtr &origin_node, const CNodePtr &padding,
+                                    const size_t &pad_dim_size) const;
+  CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &unsort_segment_sum,
+                       const CNodePtr &unsorted_segment_sum8) const;
 };
 }  // namespace opt
 }  // namespace mindspore
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -228,7 +228,7 @@ AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_g
   auto add2_y_node = utils::cast<AnfNodePtr>((*equiv)[add2_y_]);
   MS_EXCEPTION_IF_NULL(add2_y_node);
   new_node_inputs.push_back(add2_y_node);
-  auto new_node = func_graph->NewCNode(new_node_inputs);
+  auto new_node = NewCNode(new_node_inputs, func_graph);
   return new_node;
 }
 
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -288,7 +288,7 @@ const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, c
     return nullptr;
   }
   std::vector<AnfNodePtr> inputs = GetFusionNodeInputs(equiv, node);
-  auto fusion_node = graph->NewCNode(inputs);
+  auto fusion_node = NewCNode(inputs, graph);
   MS_EXCEPTION_IF_NULL(fusion_node);
   fusion_node->set_scope(sub0->scope());
 
@@ -304,7 +304,7 @@ const AnfNodePtr AvgPool3DFusion::Process(const FuncGraphPtr &func_graph, const
                                          pad_list, count_include_pad);
     new_inputs.push_back(multiplier);
   }
-  auto new_3d = func_graph->NewCNode(new_inputs);
+  auto new_3d = NewCNode(new_inputs, func_graph);
   MS_EXCEPTION_IF_NULL(new_3d);
   new_3d->set_scope(avg_pool_3d_node->scope());
   new_3d->set_abstract(avg_pool_3d_node->abstract());
@@ -235,7 +235,7 @@ const AnfNodePtr AvgPool3DGradFusion::Process(const FuncGraphPtr &func_graph, co
       ConstructMultiplier(func_graph, dims_in, origin_input_shape, kernel_size, strides, pad_list, count_include_pad);
     new_inputs.push_back(multiplier);
   }
-  auto new_3d_grad = func_graph->NewCNode(new_inputs);
+  auto new_3d_grad = NewCNode(new_inputs, func_graph);
   MS_EXCEPTION_IF_NULL(new_3d_grad);
   new_3d_grad->set_scope(avg_pool_3d_grad_node->scope());
   new_3d_grad->set_abstract(avg_pool_3d_grad_node->abstract());
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -26,24 +26,6 @@
 namespace mindspore {
 namespace opt {
 namespace {
-CNodePtr CreateBNInfer(const FuncGraphPtr &graph, const CNodePtr &batchnorm, const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(batchnorm);
-  MS_EXCEPTION_IF_NULL(node);
-  auto prim = std::make_shared<Primitive>(kBNInferOpName);
-  std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
-  for (size_t i = 1; i < batchnorm->size(); ++i) {
-    inputs.push_back(batchnorm->input(i));
-  }
-  auto new_node = graph->NewCNode(inputs);
-  MS_EXCEPTION_IF_NULL(new_node);
-  new_node->set_scope(batchnorm->scope());
-  new_node->set_abstract(node->abstract());
-  AnfAlgo::CopyNodeAttr(kAttrIsTraining, batchnorm, new_node);
-  AnfAlgo::CopyNodeAttr(kAttrEpsilon, batchnorm, new_node);
-  return new_node;
-}
-
 bool CheckIndex(const AnfNodePtr &index_node) {
   MS_EXCEPTION_IF_NULL(index_node);
   if (!IsValueNode<Int64Imm>(index_node)) {
@@ -103,6 +85,25 @@ bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *bat
 }
 }  // namespace
 
+CNodePtr BatchNorm2BNInfer::CreateBNInfer(const FuncGraphPtr &graph, const CNodePtr &batchnorm,
+                                          const AnfNodePtr &node) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(batchnorm);
+  MS_EXCEPTION_IF_NULL(node);
+  auto prim = std::make_shared<Primitive>(kBNInferOpName);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
+  for (size_t i = 1; i < batchnorm->size(); ++i) {
+    inputs.push_back(batchnorm->input(i));
+  }
+  auto new_node = NewCNode(inputs, graph);
+  MS_EXCEPTION_IF_NULL(new_node);
+  new_node->set_scope(batchnorm->scope());
+  new_node->set_abstract(node->abstract());
+  AnfAlgo::CopyNodeAttr(kAttrIsTraining, batchnorm, new_node);
+  AnfAlgo::CopyNodeAttr(kAttrEpsilon, batchnorm, new_node);
+  return new_node;
+}
+
 const BaseRef BatchNorm2BNInfer::DefinePattern() const {
   VarPtr Xs = std::make_shared<SeqVar>();
   VarPtr Y = std::make_shared<Var>();
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -27,6 +27,9 @@ class BatchNorm2BNInfer : public PatternProcessPass {
   ~BatchNorm2BNInfer() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  CNodePtr CreateBNInfer(const FuncGraphPtr &graph, const CNodePtr &batchnorm, const AnfNodePtr &node) const;
 };
 }  // namespace opt
 }  // namespace mindspore
@@ -85,7 +85,7 @@ const AnfNodePtr BNReduceGradConv2dBackpropFilterFusion::Process(const FuncGraph
   for (size_t i = 1; i <= kBNTrainingReduceGradInputNum; ++i) {
     fused_dbn_dw_inputs.push_back(bnreduce_grad->input(i));
   }
-  auto fused_dbn_dw = graph->NewCNode(fused_dbn_dw_inputs);
+  auto fused_dbn_dw = NewCNode(fused_dbn_dw_inputs, graph);
   MS_EXCEPTION_IF_NULL(fused_dbn_dw);
   auto types = {AnfAlgo::GetOutputInferDataType(bnreduce_grad, 0),
                 AnfAlgo::GetOutputInferDataType(conv_back_filter, 0)};
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -61,7 +61,7 @@ const AnfNodePtr ClipByNormNoDivSquareSumFusion::Process(const FuncGraphPtr &gra
   auto prim = std::make_shared<Primitive>(kClipByNormNoDivSumOpName);
   MS_EXCEPTION_IF_NULL(prim);
   std::vector<AnfNodePtr> inputs = {NewValueNode(prim), input, constant_greater, constant_select, constant_maximum};
-  auto fusion_node = graph->NewCNode(inputs);
+  auto fusion_node = NewCNode(inputs, graph);
   MS_EXCEPTION_IF_NULL(fusion_node);
   auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
   auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -87,7 +87,7 @@ const AnfNodePtr ClipByValueFusion::Process(const FuncGraphPtr &graph, const Anf
   MS_EXCEPTION_IF_NULL(prim);
   std::vector<AnfNodePtr> inputs = {NewValueNode(prim), minimum->input(kIndex1),
                                     is_first_input ? maximum_input1 : maximum_input0, minimum->input(kIndex2)};
-  auto clip_by_value = graph->NewCNode(inputs);
+  auto clip_by_value = NewCNode(inputs, graph);
   MS_EXCEPTION_IF_NULL(clip_by_value);
   auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
   auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -30,28 +30,6 @@ namespace opt {
 namespace {
 const size_t kConfusionMulGradOutputNum = 2;
 
-CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &reduce_sum, const AnfNodePtr &mul0_anf,
-                          const AnfNodePtr &input3) {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(reduce_sum);
-  MS_EXCEPTION_IF_NULL(mul0_anf);
-  MS_EXCEPTION_IF_NULL(input3);
-  auto mul0 = mul0_anf->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(mul0);
-
-  auto prim = std::make_shared<Primitive>(kConfusionMulGradOpName);
-  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), mul0->input(kIndex1), mul0->input(kIndex2), input3};
-  auto fusion_node = graph->NewCNode(inputs);
-  MS_EXCEPTION_IF_NULL(fusion_node);
-  fusion_node->set_scope(reduce_sum->scope());
-  AnfAlgo::CopyNodeAttr(kAttrAxis, reduce_sum, fusion_node);
-  AnfAlgo::CopyNodeAttr(kAttrKeepDims, reduce_sum, fusion_node);
-  auto types = {AnfAlgo::GetOutputInferDataType(mul0, 0), AnfAlgo::GetOutputInferDataType(reduce_sum, 0)};
-  auto shapes = {AnfAlgo::GetOutputInferShape(mul0, 0), AnfAlgo::GetOutputInferShape(reduce_sum, 0)};
-  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fusion_node.get());
-  return fusion_node;
-}
-
 AnfNodePtr GetMul0(const FuncGraphPtr &graph, const AnfNodePtr &input2, const AnfNodePtr &mul1) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(input2);
@@ -117,6 +95,28 @@ bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const Anf
 }
 }  // namespace
 
+CNodePtr ConfusionMulGradFusion::CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &reduce_sum,
+                                                  const AnfNodePtr &mul0_anf, const AnfNodePtr &input3) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(reduce_sum);
+  MS_EXCEPTION_IF_NULL(mul0_anf);
+  MS_EXCEPTION_IF_NULL(input3);
+  auto mul0 = mul0_anf->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(mul0);
+
+  auto prim = std::make_shared<Primitive>(kConfusionMulGradOpName);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), mul0->input(kIndex1), mul0->input(kIndex2), input3};
+  auto fusion_node = NewCNode(inputs, graph);
+  MS_EXCEPTION_IF_NULL(fusion_node);
+  fusion_node->set_scope(reduce_sum->scope());
+  AnfAlgo::CopyNodeAttr(kAttrAxis, reduce_sum, fusion_node);
+  AnfAlgo::CopyNodeAttr(kAttrKeepDims, reduce_sum, fusion_node);
+  auto types = {AnfAlgo::GetOutputInferDataType(mul0, 0), AnfAlgo::GetOutputInferDataType(reduce_sum, 0)};
+  auto shapes = {AnfAlgo::GetOutputInferShape(mul0, 0), AnfAlgo::GetOutputInferShape(reduce_sum, 0)};
+  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fusion_node.get());
+  return fusion_node;
+}
+
 const BaseRef ConfusionMulGradFusion::DefinePattern() const {
   VectorRef mul1({prim::kPrimMul, input3_, input2_});
   VectorRef reduce_sum({prim::kPrimReduceSum, mul1});
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -33,6 +33,8 @@ class ConfusionMulGradFusion : public PatternProcessPass {
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 
  private:
+  CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &reduce_sum, const AnfNodePtr &mul0_anf,
+                            const AnfNodePtr &input3) const;
   VarPtr input2_;
   VarPtr input3_;
 };
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -35,15 +35,16 @@ CNodePtr GetRelu(const CNodePtr &relu_grad) {
   MS_EXCEPTION_IF_NULL(relu_anf);
   return relu_anf->cast<CNodePtr>();
 }
+}  // namespace
 
-CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) {
+CNodePtr DereluFusion::CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(relu);
   CheckCNodeInputSize(relu, kReluInputTensorNum);
   constexpr auto kMaskShapeSize = 4;
   auto prim = std::make_shared<Primitive>(kReluV2OpName);
   std::vector<AnfNodePtr> inputs = {NewValueNode(prim), relu->input(kIndex1)};
-  auto new_node = graph->NewCNode(inputs);
+  auto new_node = NewCNode(inputs, graph);
   MS_EXCEPTION_IF_NULL(new_node);
   new_node->set_scope(relu->scope());
 
@@ -79,20 +80,20 @@ CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) {
   return new_node;
 }
 
-CNodePtr CreateReluGradV2(const FuncGraphPtr &graph, const CNodePtr &relu_grad, const AnfNodePtr &second_input) {
+CNodePtr DereluFusion::CreateReluGradV2(const FuncGraphPtr &graph, const CNodePtr &relu_grad,
+                                        const AnfNodePtr &second_input) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(relu_grad);
   MS_EXCEPTION_IF_NULL(second_input);
 
   auto prim = std::make_shared<Primitive>(kReluGradV2OpName);
   std::vector<AnfNodePtr> inputs = {NewValueNode(prim), relu_grad->input(1), second_input};
-  auto new_node = graph->NewCNode(inputs);
+  auto new_node = NewCNode(inputs, graph);
   MS_EXCEPTION_IF_NULL(new_node);
   new_node->set_scope(relu_grad->scope());
   new_node->set_abstract(relu_grad->abstract());
   return new_node;
 }
-}  // namespace
 
 const BaseRef DereluFusion::DefinePattern() const {
   VarPtr i0 = std::make_shared<Var>();
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -27,6 +27,10 @@ class DereluFusion : public PatternProcessPass {
   ~DereluFusion() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) const;
+  CNodePtr CreateReluGradV2(const FuncGraphPtr &graph, const CNodePtr &relu_grad, const AnfNodePtr &second_input) const;
 };
 }  // namespace opt
 }  // namespace mindspore
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -92,7 +92,7 @@ AnfNodePtr FusedBatchNormFusion::CreateBNTrainingReduce(const FuncGraphPtr &func
   // Set input to create node
   std::vector<AnfNodePtr> bn_training_reduce_inputs = {
     NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)), GetAnfNodeByVar(equiv, data_input0_var_)};
-  auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs);
+  auto bn_training_reduce = NewCNode(bn_training_reduce_inputs, func_graph);
   MS_EXCEPTION_IF_NULL(bn_training_reduce);
   bn_training_reduce->set_scope(node->scope());
   // Set abstract
@@ -151,7 +151,7 @@ AnfNodePtr FusedBatchNormFusion::CreateBNTrainingUpdate(
   // Set input
   std::vector<AnfNodePtr> bn_training_update_inputs;
   GetBNTrainingUpdateInputs(equiv, bn_training_reduce_outputs, &bn_training_update_inputs);
-  auto bn_training_update = func_graph->NewCNode(bn_training_update_inputs);
+  auto bn_training_update = NewCNode(bn_training_update_inputs, func_graph);
   MS_EXCEPTION_IF_NULL(bn_training_update);
   // Set abstract
   AnfNodePtr bn = GetAnfNodeByVar(equiv, batch_norm_var_);
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -74,7 +74,7 @@ AnfNodePtr LambNextMVRule::CreateLambNextMVNode(const FuncGraphPtr &func_graph,
   lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul3_sub1_]));
   lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul4_x_]));
   lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[add2_y_]));
-  auto lamb_next_mv_rule = func_graph->NewCNode(lamb_next_mv_rule_inputs);
+  auto lamb_next_mv_rule = NewCNode(lamb_next_mv_rule_inputs, func_graph);
   MS_EXCEPTION_IF_NULL(lamb_next_mv_rule);
 
   // Set abstract of new node
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -74,7 +74,7 @@ AnfNodePtr LambNextMVWithDecayRule::CreateLambNextMVWithDecayNode(const FuncGrap
   auto constant_add2_y_node = utils::cast<AnfNodePtr>((*equiv)[constant_add2_y_]);
   MS_EXCEPTION_IF_NULL(constant_add2_y_node);
   new_node_inputs.push_back(constant_add2_y_node);
-  auto new_node = func_graph->NewCNode(new_node_inputs);
+  auto new_node = NewCNode(new_node_inputs, func_graph);
   return GetLambNextMVWithDecayOutput(func_graph, new_node, add3, add5, equiv);
 }
 
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -44,7 +44,7 @@ AnfNodePtr LambNextRightRule::CreateLambNextRightNode(const FuncGraphPtr &func_g
   auto add2_y = utils::cast<AnfNodePtr>((*equiv)[add2_y_]);
   MS_EXCEPTION_IF_NULL(add2_y);
   new_node_inputs.push_back(add2_y);
-  auto new_node = func_graph->NewCNode(new_node_inputs);
+  auto new_node = NewCNode(new_node_inputs, func_graph);
   return new_node;
 }
 
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -65,7 +65,7 @@ const AnfNodePtr LambUpdateWithLRRuleFusion::Process(const FuncGraphPtr &graph,
   MS_EXCEPTION_IF_NULL(prim);
   std::vector<AnfNodePtr> inputs = {
     NewValueNode(prim), input0, input1, input2, input3, input4, input5, input6, input7, input8};
-  auto lamb_update_with_lr = graph->NewCNode(inputs);
+  auto lamb_update_with_lr = NewCNode(inputs, graph);
   MS_EXCEPTION_IF_NULL(lamb_update_with_lr);
 
   auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -47,7 +47,7 @@ const AnfNodePtr LambUpdateWithLrV2::Process(const FuncGraphPtr &func_graph, con
   std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
   (void)std::transform(input_varptr_.begin(), input_varptr_.end(), std::back_inserter(inputs),
                        [&equiv](const VarPtr &in) { return utils::cast<AnfNodePtr>((*equiv)[in]); });
-  auto lamb_update_with_lr_v2 = func_graph->NewCNode(inputs);
+  auto lamb_update_with_lr_v2 = NewCNode(inputs, func_graph);
   MS_EXCEPTION_IF_NULL(lamb_update_with_lr_v2);
   lamb_update_with_lr_v2->set_abstract(node->abstract());
 
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -48,7 +48,7 @@ const AnfNodePtr MatmulBiasaddFusion::Process(const FuncGraphPtr &graph, const A
   inputs.emplace_back(GetAnfNodeByVar(equiv, x0_));
   inputs.emplace_back(GetAnfNodeByVar(equiv, x1_));
   inputs.emplace_back(GetAnfNodeByVar(equiv, x2_));
-  auto new_node = graph->NewCNode(inputs);
+  auto new_node = NewCNode(inputs, graph);
   MS_EXCEPTION_IF_NULL(new_node);
   new_node->set_scope(node->scope());
   new_node->set_abstract(node->abstract());
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -85,7 +85,7 @@ const AnfNodePtr MomentumLossscaleFusion::Process(const FuncGraphPtr &func_graph
                                            mul_cnode->input(kMulInputTensorNum + 1 - value_node_index),
                                            depend,
                                            mul_cnode->input(value_node_index)};
-  auto new_node = func_graph->NewCNode(new_node_inputs);
+  auto new_node = NewCNode(new_node_inputs, func_graph);
   MS_EXCEPTION_IF_NULL(new_node);
   AnfAlgo::CopyNodeAttrs(node, new_node);
   auto input_names_value = AnfAlgo::GetNodeAttr<std::vector<std::string>>(new_node, kAttrInputNames);
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -90,7 +90,7 @@ const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodeP
     return nullptr;
   }
   inputs.push_back(another_input_node);
-  auto fusion_node = graph->NewCNode(inputs);
+  auto fusion_node = NewCNode(inputs, graph);
   fusion_node->set_scope(add->scope());
   fusion_node->set_abstract(add->abstract());
   return fusion_node;
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -22,7 +22,6 @@
 
 namespace mindspore {
 namespace opt {
-namespace {
 CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &mul, const CNodePtr &addn,
                           const size_t &lossscale_input_index) {
   MS_EXCEPTION_IF_NULL(graph);
@@ -34,13 +33,12 @@ CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &mul, const
   inputs.push_back(addn->input(kIndex2));
   // scalar input should be 3rd input
   inputs.push_back(mul->input(lossscale_input_index));
-  auto fusion_node = graph->NewCNode(inputs);
+  auto fusion_node = NewCNode(inputs, graph);
   MS_EXCEPTION_IF_NULL(fusion_node);
   fusion_node->set_scope(addn->scope());
   fusion_node->set_abstract(addn->abstract());
   return fusion_node;
 }
-}  // namespace
 
 const BaseRef MulAddNFusion::DefinePattern() const {
   VarPtr X = std::make_shared<Var>();
@@ -49,7 +49,7 @@ const AnfNodePtr PReluFusion::Process(const FuncGraphPtr &graph, const AnfNodePt
   auto prim = std::make_shared<Primitive>(kPReluOpName);
   MS_EXCEPTION_IF_NULL(prim);
   std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x, weight};
-  auto fusion_node = graph->NewCNode(inputs);
+  auto fusion_node = NewCNode(inputs, graph);
   MS_EXCEPTION_IF_NULL(fusion_node);
   fusion_node->set_abstract(node->abstract());
   fusion_node->set_scope(node->scope());
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -69,7 +69,7 @@ const AnfNodePtr ReshapeTransposeFusion::Process(const FuncGraphPtr &func_graph,
 
   auto prim = std::make_shared<Primitive>(kConfusionTransposeDOpName);
   std::vector<AnfNodePtr> inputs = {NewValueNode(prim), utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
-  auto new_node = func_graph->NewCNode(inputs);
+  auto new_node = NewCNode(inputs, func_graph);
   MS_EXCEPTION_IF_NULL(new_node);
   new_node->set_abstract(node->abstract());
 
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -65,7 +65,7 @@ const AnfNodePtr SoftmaxGradExtFusion::Process(const FuncGraphPtr &graph, const
   }
 
   auto prim = std::make_shared<Primitive>(kSoftmaxGradExtOpName);
-  auto fusion_node = graph->NewCNode({NewValueNode(prim), input0, input1, input2});
+  auto fusion_node = NewCNode({NewValueNode(prim), input0, input1, input2}, graph);
   MS_EXCEPTION_IF_NULL(fusion_node);
   fusion_node->set_scope(node->scope());
   fusion_node->set_abstract(node->abstract());
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -30,7 +30,22 @@
 namespace mindspore {
 namespace opt {
 namespace {
-CNodePtr GenerateSquareSumV1(const FuncGraphPtr &graph, const CNodePtr &square, const CNodePtr &sum) {
+std::tuple<CNodePtr, AnfNodePtr, CNodePtr> GetPrevNodes(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto sum = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(sum);
+  CheckCNodeInputSize(sum, kSumNodeInputTensorNum);
+  auto square_anf = sum->input(1);
+  MS_EXCEPTION_IF_NULL(square_anf);
+  auto square = square_anf->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(square);
+
+  return std::make_tuple(sum, square_anf, square);
+}
+}  // namespace
+
+CNodePtr SquareSumFusion::GenerateSquareSumV1(const FuncGraphPtr &graph, const CNodePtr &square,
+                                              const CNodePtr &sum) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(square);
   MS_EXCEPTION_IF_NULL(sum);
@@ -38,7 +53,7 @@ CNodePtr GenerateSquareSumV1(const FuncGraphPtr &graph, const CNodePtr &square,
   auto prim = std::make_shared<Primitive>(kSquareSumV1OpName);
   MS_EXCEPTION_IF_NULL(prim);
   std::vector<AnfNodePtr> square_sumv1_inputs = {NewValueNode(prim), square->input(1)};
-  auto square_sumv1 = graph->NewCNode(square_sumv1_inputs);
+  auto square_sumv1 = NewCNode(square_sumv1_inputs, graph);
   MS_EXCEPTION_IF_NULL(square_sumv1);
   auto kernel_info = std::make_shared<device::KernelInfo>();
   MS_EXCEPTION_IF_NULL(kernel_info);
@@ -54,7 +69,8 @@ CNodePtr GenerateSquareSumV1(const FuncGraphPtr &graph, const CNodePtr &square,
   return square_sumv1;
 }
 
-CNodePtr GenerateSquareSumV2(const FuncGraphPtr &graph, const CNodePtr &square, const CNodePtr &sum) {
+CNodePtr SquareSumFusion::GenerateSquareSumV2(const FuncGraphPtr &graph, const CNodePtr &square,
+                                              const CNodePtr &sum) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(square);
   MS_EXCEPTION_IF_NULL(sum);
@@ -62,7 +78,7 @@ CNodePtr GenerateSquareSumV2(const FuncGraphPtr &graph, const CNodePtr &square,
   auto prim = std::make_shared<Primitive>(kSquareSumV2OpName);
   MS_EXCEPTION_IF_NULL(prim);
   std::vector<AnfNodePtr> square_sumv2_inputs = {NewValueNode(prim), square->input(1)};
-  auto square_sumv2 = graph->NewCNode(square_sumv2_inputs);
+  auto square_sumv2 = NewCNode(square_sumv2_inputs, graph);
   MS_EXCEPTION_IF_NULL(square_sumv2);
   auto types = {AnfAlgo::GetOutputInferDataType(sum, 0), AnfAlgo::GetOutputInferDataType(square, 0)};
   auto shapes = {AnfAlgo::GetOutputInferShape(sum, 0), AnfAlgo::GetOutputInferShape(square, 0)};
@@ -75,20 +91,6 @@ CNodePtr GenerateSquareSumV2(const FuncGraphPtr &graph, const CNodePtr &square,
   return square_sumv2;
 }
 
-std::tuple<CNodePtr, AnfNodePtr, CNodePtr> GetPrevNodes(const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto sum = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(sum);
-  CheckCNodeInputSize(sum, kSumNodeInputTensorNum);
-  auto square_anf = sum->input(1);
-  MS_EXCEPTION_IF_NULL(square_anf);
-  auto square = square_anf->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(square);
-
-  return std::make_tuple(sum, square_anf, square);
-}
-}  // namespace
 
 const BaseRef SquareSumFusion::DefinePattern() const {
   VarPtr X = std::make_shared<Var>();
   MS_EXCEPTION_IF_NULL(X);
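square_sum_fusion.cc also marks the limit of the refactor: `GetPrevNodes` only inspects existing nodes and allocates nothing, so it stays in the anonymous namespace (merely hoisted above its new call sites), while the two `GenerateSquareSum*` factories become const members so their node creation goes through the pass. Below is a hedged, self-contained C++ sketch of that division of labor; `Node`, `SquareSumLikePass`, and `GetSumAndSquare` are invented stand-ins, not MindSpore types.

#include <memory>
#include <string>
#include <tuple>
#include <utility>

struct Node {
  std::string op;
  bool dump_flag = false;
  std::shared_ptr<Node> input;  // one-input chain is enough for the sketch
};
using NodePtr = std::shared_ptr<Node>;

namespace {
// Analysis-only helper: walks existing IR and allocates nothing, so it needs
// no pass state and can stay file-local, exactly like GetPrevNodes above.
std::tuple<NodePtr, NodePtr> GetSumAndSquare(const NodePtr &sum) {
  return std::make_tuple(sum, sum->input);
}
}  // namespace

class SquareSumLikePass {  // invented name modelling SquareSumFusion's split
 public:
  explicit SquareSumLikePass(bool dump_flag) : dump_flag_(dump_flag) {}
  NodePtr Fuse(const NodePtr &sum) const {
    NodePtr square = std::get<1>(GetSumAndSquare(sum));
    return NewCNode("SquareSumV1", square);  // creation goes through the pass
  }

 private:
  NodePtr NewCNode(const std::string &op, NodePtr input) const {
    auto node = std::make_shared<Node>();
    node->op = op;
    node->input = std::move(input);
    node->dump_flag = dump_flag_;  // the reason the factories became members
    return node;
  }
  bool dump_flag_;
};

int main() {
  auto square = std::make_shared<Node>();
  square->op = "Square";
  auto sum = std::make_shared<Node>();
  sum->op = "ReduceSumD";
  sum->input = square;
  SquareSumLikePass pass(true);
  return pass.Fuse(sum)->dump_flag ? 0 : 1;  // fused node carries the flag
}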
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -28,6 +28,10 @@ class SquareSumFusion : public PatternProcessPassWithSwitch {
   ~SquareSumFusion() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  CNodePtr GenerateSquareSumV1(const FuncGraphPtr &graph, const CNodePtr &square, const CNodePtr &sum) const;
+  CNodePtr GenerateSquareSumV2(const FuncGraphPtr &graph, const CNodePtr &square, const CNodePtr &sum) const;
 };
 }  // namespace opt
 }  // namespace mindspore
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -68,7 +68,7 @@ const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph,
   }
   auto prim = std::make_shared<Primitive>(kConfusionTransposeDOpName);
   std::vector<AnfNodePtr> inputs = {NewValueNode(prim), utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
-  auto new_node = func_graph->NewCNode(inputs);
+  auto new_node = NewCNode(inputs, func_graph);
   MS_EXCEPTION_IF_NULL(new_node);
 
   new_node->set_abstract(node->abstract());
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -57,7 +57,7 @@ const AnfNodePtr TransposeTransDataFusion::Process(const FuncGraphPtr &func_grap
   if (supported_checker_->CheckAICoreSupported(transdata_cnode, new_transdata_builder->Build())) {
     std::vector<AnfNodePtr> inputs = {NewValueNode(new_fusion_transdata),
                                       utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
-    auto new_node = func_graph->NewCNode(inputs);
+    auto new_node = NewCNode(inputs, func_graph);
     MS_EXCEPTION_IF_NULL(new_node);
     new_node->set_abstract(node->abstract());
     AnfAlgo::CopyNodeAttrs(transdata_cnode, new_node);
@@ -79,7 +79,7 @@ const AnfNodePtr TransposedUpdateFusion::Process(const FuncGraphPtr &func_graph,
   auto perm_vnode = CreatePermValueNode(transposed);
   std::vector<AnfNodePtr> transpose_inputs = {NewValueNode(std::make_shared<Primitive>(kTransposeNODOpName)),
                                               transposed->input(1), perm_vnode};
-  auto transpose = kernel_graph->NewCNode(transpose_inputs);
+  auto transpose = NewCNode(transpose_inputs, kernel_graph);
   transpose->set_scope(transposed->scope());
   transpose->set_abstract(transposed->abstract());
 
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -53,8 +53,9 @@ uint32_t GetRankSize(const std::string &group) {
   }
   return rank_size;
 }
+}  // namespace
 
-CNodePtr CreateSplitNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all) {
+CNodePtr AllToAllUnifyMindIR::CreateSplitNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(all_to_all);
   int64_t split_count = AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrSplitCount);
@@ -66,7 +67,7 @@ CNodePtr CreateSplitNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all)
   auto all_to_all_input = all_to_all->input(kAllToAllInputIdx);
   std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
                                          all_to_all_input};
-  auto split_v = graph->NewCNode(split_input);
+  auto split_v = NewCNode(split_input, graph);
   MS_EXCEPTION_IF_NULL(split_v);
   auto dtype = AnfAlgo::GetOutputInferDataType(all_to_all_input, 0);
   auto shape = AnfAlgo::GetOutputInferShape(all_to_all_input, 0);
@@ -90,7 +91,8 @@ CNodePtr CreateSplitNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all)
   return split_v;
 }
 
-CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all, const CNodePtr &split) {
+CNodePtr AllToAllUnifyMindIR::CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all,
+                                                  const CNodePtr &split) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(all_to_all);
   MS_EXCEPTION_IF_NULL(split);
@@ -103,7 +105,7 @@ CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &all_to_a
   }
   std::vector<AnfNodePtr> all_to_all_v_input = {NewValueNode(std::make_shared<Primitive>(kAllToAllVOpName))};
   (void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs.begin(), split_outputs.end());
-  auto all_to_all_v = graph->NewCNode(all_to_all_v_input);
+  auto all_to_all_v = NewCNode(all_to_all_v_input, graph);
   MS_EXCEPTION_IF_NULL(all_to_all_v);
   auto single_shape = AnfAlgo::GetOutputInferShape(split_outputs[0], 0);
   auto single_type = AnfAlgo::GetOutputInferDataType(split_outputs[0], 0);
@@ -123,7 +125,8 @@ CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &all_to_a
   return all_to_all_v;
 }
 
-CNodePtr CreateConcatNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all, const CNodePtr &all_to_all_v) {
+CNodePtr AllToAllUnifyMindIR::CreateConcatNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all,
+                                               const CNodePtr &all_to_all_v) const {
  MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(all_to_all);
   MS_EXCEPTION_IF_NULL(all_to_all_v);
@@ -136,7 +139,7 @@ CNodePtr CreateConcatNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all,
   }
   std::vector<AnfNodePtr> concat_input = {NewValueNode(std::make_shared<Primitive>(kConcatOpName))};
   (void)concat_input.insert(concat_input.end(), all_to_all_v_outputs.begin(), all_to_all_v_outputs.end());
-  auto concat = graph->NewCNode(concat_input);
+  auto concat = NewCNode(concat_input, graph);
   MS_EXCEPTION_IF_NULL(concat);
   auto single_shape = AnfAlgo::GetOutputInferShape(all_to_all_v_outputs[0], 0);
   concat_dim = NormalizeDim(single_shape, concat_dim);
@@ -152,7 +155,6 @@ CNodePtr CreateConcatNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all,
   AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat);
   return concat;
 }
-}  // namespace
 
 const BaseRef NeighborExchangeUnifyMindIR::DefinePattern() const {
   return VectorRef({prim::kPrimNeighborExchange, std::make_shared<SeqVar>()});
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -36,6 +36,11 @@ class AllToAllUnifyMindIR : public PatternProcessPass {
|
|||
~AllToAllUnifyMindIR() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
CNodePtr CreateSplitNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all) const;
|
||||
CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all, const CNodePtr &split) const;
|
||||
CNodePtr CreateConcatNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all, const CNodePtr &all_to_all_v) const;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
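
The header-side hunks all follow the template visible just above: creation helpers that used to be free functions in the .cc file's anonymous namespace are re-declared as private const member functions of the pass, which is what gives their bodies access to the inherited NewCNode helper. A hedged sketch of the resulting class shape (the SomePassUnifyMindIR name and CreateSomeNode method are hypothetical):

class SomePassUnifyMindIR : public PatternProcessPass {
 public:
  ~SomePassUnifyMindIR() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  // Formerly a free function in an unnamed namespace; now a const member so it
  // can call NewCNode(inputs, graph) on the pass.
  CNodePtr CreateSomeNode(const FuncGraphPtr &graph, const CNodePtr &origin) const;
};
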
@@ -203,7 +203,7 @@ const AnfNodePtr AvgPoolGradUnifyMindIR::Process(const FuncGraphPtr &graph, cons
std::vector<AnfNodePtr> avgpool_grad_vm_inputs = {NewValueNode(std::make_shared<Primitive>(kAvgPoolGradVmOpName)),
x_shape_vnode, avgpool_grad->input(3), mean_matrix_vnode,
kernel_matrix_vnode};
auto avgpool_grad_vm = graph->NewCNode(avgpool_grad_vm_inputs);
auto avgpool_grad_vm = NewCNode(avgpool_grad_vm_inputs, graph);
MS_EXCEPTION_IF_NULL(avgpool_grad_vm);
avgpool_grad_vm->set_scope(avgpool_grad->scope());
avgpool_grad_vm->set_abstract(avgpool_grad->abstract());

@@ -28,8 +28,10 @@ namespace mindspore {
namespace opt {
namespace {
constexpr auto kAttrUnifyIRPassed = "unifyir_passed";
} // namespace

AnfNodePtr CreateNewBatchNormGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node) {
AnfNodePtr BatchNormGradUnifyMindIR::CreateNewBatchNormGrad(const FuncGraphPtr &graph,
const CNodePtr &bn_grad_node) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(bn_grad_node);
size_t kBNGradInputNum = 6;
@@ -41,7 +43,7 @@ AnfNodePtr CreateNewBatchNormGrad(const FuncGraphPtr &graph, const CNodePtr &bn_
bn_grad_node_inputs[kDim3],
bn_grad_node_inputs[kDim4],
bn_grad_node_inputs[kDim5]};
auto new_bn_grad = graph->NewCNode(bn_grad_inputs);
auto new_bn_grad = NewCNode(bn_grad_inputs, graph);
MS_EXCEPTION_IF_NULL(new_bn_grad);
new_bn_grad->set_scope(bn_grad_node->scope());
auto types = {AnfAlgo::GetOutputInferDataType(bn_grad_node, 0), AnfAlgo::GetOutputInferDataType(bn_grad_node, 1),
@@ -56,7 +58,6 @@ AnfNodePtr CreateNewBatchNormGrad(const FuncGraphPtr &graph, const CNodePtr &bn_
AnfAlgo::SetNodeAttr(kAttrUnifyIRPassed, MakeValue(true), new_bn_grad);
return new_bn_grad;
}
} // namespace

const BaseRef BatchNormGradUnifyMindIR::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();

@@ -27,6 +27,9 @@ class BatchNormGradUnifyMindIR : public PatternProcessPass {
~BatchNormGradUnifyMindIR() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

private:
AnfNodePtr CreateNewBatchNormGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node) const;
};
} // namespace opt
} // namespace mindspore

@@ -91,7 +91,7 @@ ValueNodePtr CreatePermValueNode(const FuncGraphPtr &func_graph, const std::vect
}

CNodePtr CreateTranspose(const FuncGraphPtr &graph, const CNodePtr &conv2d, const AnfNodePtr &input_node,
bool need_trans_output) {
bool need_trans_output, const PatternProcessPass &pass) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(conv2d);
MS_EXCEPTION_IF_NULL(input_node);
@@ -105,7 +105,7 @@ CNodePtr CreateTranspose(const FuncGraphPtr &graph, const CNodePtr &conv2d, cons
transpose_inputs = {NewValueNode(std::make_shared<Primitive>(kTransposeOpName)), input_node,
CreatePermValueNode(graph, perm)};
}
auto transpose = graph->NewCNode(transpose_inputs);
auto transpose = pass.NewCNode(transpose_inputs, graph);
MS_EXCEPTION_IF_NULL(transpose);
transpose->set_scope(conv2d->scope());
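
Helpers shared by several passes, such as CreateTranspose above, stay free functions; instead of becoming members they take the invoking pass as a const PatternProcessPass & parameter, and the call sites inside Process forward *this. A sketch of that variant, assuming a hypothetical CreateHelperNode and "HelperOp" primitive:

CNodePtr CreateHelperNode(const FuncGraphPtr &graph, const CNodePtr &origin,
                          const PatternProcessPass &pass) {
  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("HelperOp")),
                                    origin->input(1)};
  // Was graph->NewCNode(inputs); the pass's public helper is used instead.
  auto node = pass.NewCNode(inputs, graph);
  MS_EXCEPTION_IF_NULL(node);
  node->set_scope(origin->scope());
  node->set_abstract(origin->abstract());
  return node;
}
// Inside a pass: auto helper = CreateHelperNode(graph, cnode, *this);
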
@@ -133,82 +133,6 @@ CNodePtr CreateTranspose(const FuncGraphPtr &graph, const CNodePtr &conv2d, cons
return transpose;
}

CNodePtr CreateDepthwiseConv2D(const FuncGraphPtr &graph, const CNodePtr &conv2d, const CNodePtr &transpose) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(conv2d);
CheckCNodeInputSize(conv2d, kConvInputTensorNum);
std::vector<AnfNodePtr> depth_conv_inputs = {NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeOpName)),
conv2d->input(kIndex1), transpose};
auto depth_conv = graph->NewCNode(depth_conv_inputs);
MS_EXCEPTION_IF_NULL(depth_conv);
depth_conv->set_abstract(conv2d->abstract());
depth_conv->set_scope(conv2d->scope());
return depth_conv;
}

CNodePtr CreateDepthwiseConv2DBackpropInput(const FuncGraphPtr &graph, const CNodePtr &conv2d_backin,
const CNodePtr &transpose) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(conv2d_backin);

CNodePtr depth_conv_backin = nullptr;
if (conv2d_backin->inputs().size() == kConv2DBackpropInputNum) {
std::vector<AnfNodePtr> depth_conv_backin_inputs = {
NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropInputOpName)),
conv2d_backin->input(kIndex3), transpose, conv2d_backin->input(kIndex1)};
depth_conv_backin = graph->NewCNode(depth_conv_backin_inputs);
} else {
// In nn.Conv2DTranspose, Conv2DBackpropInput is a forward op and the input_sizes input will be convert to attr
// in pynative mode.
std::vector<AnfNodePtr> depth_conv_backin_inputs = {
NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropInputOpName)), transpose,
conv2d_backin->input(kIndex1)};
depth_conv_backin = graph->NewCNode(depth_conv_backin_inputs);
AnfAlgo::CopyNodeAttr(kAttrInputSizes, kAttrInputSize, conv2d_backin, depth_conv_backin);
}
MS_EXCEPTION_IF_NULL(depth_conv_backin);
depth_conv_backin->set_abstract(conv2d_backin->abstract());
depth_conv_backin->set_scope(conv2d_backin->scope());
return depth_conv_backin;
}

CNodePtr CreateDepthwiseConv2DBackpropFilter(const FuncGraphPtr &graph, const CNodePtr &conv2d_backfil) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(conv2d_backfil);
if (conv2d_backfil->inputs().size() != kConv2DBackpropInputNum) {
MS_LOG(EXCEPTION) << "Conv2DBackpropFilter's input number should be " << (kConv2DBackpropInputNum - 1)
<< ", but got " << (conv2d_backfil->inputs().size() - 1);
}
auto filter_size_node = conv2d_backfil->input(kIndex3);
MS_EXCEPTION_IF_NULL(filter_size_node);
auto filter_size_vnode = filter_size_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(filter_size_vnode);
auto filter_size = GetValue<std::vector<int64_t>>(filter_size_vnode->value());
// swap axis 0 and 1 of filter shape, but don't swap twice since some node share same filter_size valuenode
// when the filter_size value is same.
if (filter_size[0] != 1) {
std::swap(filter_size[0], filter_size[1]);
conv2d_backfil->input(kIndex3)->cast<ValueNodePtr>()->set_value(MakeValue(filter_size));
}
std::vector<AnfNodePtr> depth_conv_backfil_inputs = {
NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropFilterOpName)),
conv2d_backfil->input(kIndex2), conv2d_backfil->input(kIndex3), conv2d_backfil->input(kIndex1)};
auto depth_conv_backfil = graph->NewCNode(depth_conv_backfil_inputs);
MS_EXCEPTION_IF_NULL(depth_conv_backfil);
depth_conv_backfil->set_scope(conv2d_backfil->scope());

auto types = {AnfAlgo::GetOutputInferDataType(conv2d_backfil, 0)};
std::vector<size_t> out_shape = AnfAlgo::GetOutputInferShape(conv2d_backfil, 0);
if (out_shape.size() != kConv2DAxisNum) {
MS_LOG(EXCEPTION) << "Conv2DBackpropFilter's output axis number should be " << kConv2DAxisNum << ", but got "
<< out_shape.size();
}
std::swap(out_shape[0], out_shape[1]);
auto shapes = {out_shape};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, depth_conv_backfil.get());
return depth_conv_backfil;
}

void SetCommonAttrs(const CNodePtr &conv2d, const CNodePtr &depth_conv) {
AnfAlgo::CopyNodeAttr(kAttrKernelSize, conv2d, depth_conv);
AnfAlgo::CopyNodeAttr(kAttrDilation, conv2d, depth_conv);
@@ -256,6 +180,20 @@ void SetConv2DBackpropFilterAttrs(const CNodePtr &conv2d_backfil, const CNodePtr
}
} // namespace

CNodePtr Conv2DUnifyMindIR::CreateDepthwiseConv2D(const FuncGraphPtr &graph, const CNodePtr &conv2d,
const CNodePtr &transpose) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(conv2d);
CheckCNodeInputSize(conv2d, kConvInputTensorNum);
std::vector<AnfNodePtr> depth_conv_inputs = {NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeOpName)),
conv2d->input(kIndex1), transpose};
auto depth_conv = NewCNode(depth_conv_inputs, graph);
MS_EXCEPTION_IF_NULL(depth_conv);
depth_conv->set_abstract(conv2d->abstract());
depth_conv->set_scope(conv2d->scope());
return depth_conv;
}

const BaseRef Conv2DUnifyMindIR::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
VarPtr W = std::make_shared<Var>();
@@ -275,12 +213,39 @@ const AnfNodePtr Conv2DUnifyMindIR::Process(const FuncGraphPtr &graph, const Anf
return nullptr;
}
CheckCNodeInputSize(conv2d, kConvInputTensorNum);
auto transpose = CreateTranspose(graph, conv2d, conv2d->input(kIndex2), true);
auto transpose = CreateTranspose(graph, conv2d, conv2d->input(kIndex2), true, *this);
auto depth_conv = CreateDepthwiseConv2D(graph, conv2d, transpose);
SetConv2DAttrs(conv2d, depth_conv);
return depth_conv;
}

CNodePtr Conv2DBackpropInputUnifyMindIR::CreateDepthwiseConv2DBackpropInput(const FuncGraphPtr &graph,
const CNodePtr &conv2d_backin,
const CNodePtr &transpose) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(conv2d_backin);

CNodePtr depth_conv_backin = nullptr;
if (conv2d_backin->inputs().size() == kConv2DBackpropInputNum) {
std::vector<AnfNodePtr> depth_conv_backin_inputs = {
NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropInputOpName)),
conv2d_backin->input(kIndex3), transpose, conv2d_backin->input(kIndex1)};
depth_conv_backin = NewCNode(depth_conv_backin_inputs, graph);
} else {
// In nn.Conv2DTranspose, Conv2DBackpropInput is a forward op and the input_sizes input will be convert to attr
// in pynative mode.
std::vector<AnfNodePtr> depth_conv_backin_inputs = {
NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropInputOpName)), transpose,
conv2d_backin->input(kIndex1)};
depth_conv_backin = NewCNode(depth_conv_backin_inputs, graph);
AnfAlgo::CopyNodeAttr(kAttrInputSizes, kAttrInputSize, conv2d_backin, depth_conv_backin);
}
MS_EXCEPTION_IF_NULL(depth_conv_backin);
depth_conv_backin->set_abstract(conv2d_backin->abstract());
depth_conv_backin->set_scope(conv2d_backin->scope());
return depth_conv_backin;
}

const BaseRef Conv2DBackpropInputUnifyMindIR::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
VectorRef pattern({prim::kPrimConv2DBackpropInput, Xs});
@@ -306,12 +271,50 @@ const AnfNodePtr Conv2DBackpropInputUnifyMindIR::Process(const FuncGraphPtr &gra
MS_LOG(EXCEPTION) << "Conv2DBackpropInput's input number should be " << (kConv2DBackpropInputNum - 1) << " or "
<< (kConv2DBackpropInputNum - 2) << ", but got " << (input_size - 1);
}
auto transpose = CreateTranspose(graph, conv2d_backin, conv2d_backin->input(kIndex2), true);
auto transpose = CreateTranspose(graph, conv2d_backin, conv2d_backin->input(kIndex2), true, *this);
auto depth_conv_backin = CreateDepthwiseConv2DBackpropInput(graph, conv2d_backin, transpose);
SetConv2DBackpropInputAttrs(conv2d_backin, depth_conv_backin);
return depth_conv_backin;
}

CNodePtr Conv2DBackpropFilterUnifyMindIR::CreateDepthwiseConv2DBackpropFilter(const FuncGraphPtr &graph,
const CNodePtr &conv2d_backfil) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(conv2d_backfil);
if (conv2d_backfil->inputs().size() != kConv2DBackpropInputNum) {
MS_LOG(EXCEPTION) << "Conv2DBackpropFilter's input number should be " << (kConv2DBackpropInputNum - 1)
<< ", but got " << (conv2d_backfil->inputs().size() - 1);
}
auto filter_size_node = conv2d_backfil->input(kIndex3);
MS_EXCEPTION_IF_NULL(filter_size_node);
auto filter_size_vnode = filter_size_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(filter_size_vnode);
auto filter_size = GetValue<std::vector<int64_t>>(filter_size_vnode->value());
// swap axis 0 and 1 of filter shape, but don't swap twice since some node share same filter_size valuenode
// when the filter_size value is same.
if (filter_size[0] != 1) {
std::swap(filter_size[0], filter_size[1]);
conv2d_backfil->input(kIndex3)->cast<ValueNodePtr>()->set_value(MakeValue(filter_size));
}
std::vector<AnfNodePtr> depth_conv_backfil_inputs = {
NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropFilterOpName)),
conv2d_backfil->input(kIndex2), conv2d_backfil->input(kIndex3), conv2d_backfil->input(kIndex1)};
auto depth_conv_backfil = NewCNode(depth_conv_backfil_inputs, graph);
MS_EXCEPTION_IF_NULL(depth_conv_backfil);
depth_conv_backfil->set_scope(conv2d_backfil->scope());

auto types = {AnfAlgo::GetOutputInferDataType(conv2d_backfil, 0)};
std::vector<size_t> out_shape = AnfAlgo::GetOutputInferShape(conv2d_backfil, 0);
if (out_shape.size() != kConv2DAxisNum) {
MS_LOG(EXCEPTION) << "Conv2DBackpropFilter's output axis number should be " << kConv2DAxisNum << ", but got "
<< out_shape.size();
}
std::swap(out_shape[0], out_shape[1]);
auto shapes = {out_shape};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, depth_conv_backfil.get());
return depth_conv_backfil;
}

const BaseRef Conv2DBackpropFilterUnifyMindIR::DefinePattern() const {
VarPtr dout = std::make_shared<Var>();
VarPtr input = std::make_shared<Var>();
@@ -335,7 +338,7 @@ const AnfNodePtr Conv2DBackpropFilterUnifyMindIR::Process(const FuncGraphPtr &gr

auto depth_conv_backfil = CreateDepthwiseConv2DBackpropFilter(graph, conv2d_backfil);
SetConv2DBackpropFilterAttrs(conv2d_backfil, depth_conv_backfil);
auto transpose = CreateTranspose(graph, conv2d_backfil, depth_conv_backfil, false);
auto transpose = CreateTranspose(graph, conv2d_backfil, depth_conv_backfil, false, *this);

auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

@@ -27,6 +27,9 @@ class Conv2DUnifyMindIR : public PatternProcessPass {
~Conv2DUnifyMindIR() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

private:
CNodePtr CreateDepthwiseConv2D(const FuncGraphPtr &graph, const CNodePtr &conv2d, const CNodePtr &transpose) const;
};

class Conv2DBackpropInputUnifyMindIR : public PatternProcessPass {
@@ -36,6 +39,10 @@ class Conv2DBackpropInputUnifyMindIR : public PatternProcessPass {
~Conv2DBackpropInputUnifyMindIR() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

private:
CNodePtr CreateDepthwiseConv2DBackpropInput(const FuncGraphPtr &graph, const CNodePtr &conv2d_backin,
const CNodePtr &transpose) const;
};

class Conv2DBackpropFilterUnifyMindIR : public PatternProcessPass {
@@ -45,6 +52,9 @@ class Conv2DBackpropFilterUnifyMindIR : public PatternProcessPass
~Conv2DBackpropFilterUnifyMindIR() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

private:
CNodePtr CreateDepthwiseConv2DBackpropFilter(const FuncGraphPtr &graph, const CNodePtr &conv2d_backfil) const;
};
} // namespace opt
} // namespace mindspore
@@ -20,6 +20,7 @@
#include <memory>
#include <numeric>
#include <functional>
#include <algorithm>
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/log_adapter.h"

@@ -140,15 +141,15 @@ CNodePtr CreateDynamicShapeCNode(const FuncGraphPtr &func_graph, const AnfNodePt
return dynamic_shape;
}

CNodePtr CreateDropoutGenMaskCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &dropout,
const ValueNodePtr &keep_prob_value, const AnfNodePtr &dropout_input,
const abstract::ShapePtr &input_shape) {
CNodePtr CreateDropoutGenMaskCNode(const FuncGraphPtr &func_graph, const CNodePtr &dropout,
const ValueNodePtr &keep_prob_value, const abstract::ShapePtr &input_shape,
const PatternProcessPass &pass) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(dropout);
MS_EXCEPTION_IF_NULL(input_shape);
std::vector<AnfNodePtr> dropout_gen_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName))};
if (input_shape->IsDynamic()) {
CNodePtr dynamic_shape = CreateDynamicShapeCNode(func_graph, dropout_input, input_shape);
CNodePtr dynamic_shape = CreateDynamicShapeCNode(func_graph, dropout->input(kIndex1), input_shape);
dynamic_shape->set_scope(dropout->scope());
dropout_gen_mask_inputs.push_back(dynamic_shape);
dropout_gen_mask_inputs.push_back(keep_prob_value);
@@ -157,7 +158,7 @@ CNodePtr CreateDropoutGenMaskCNode(const FuncGraphPtr &func_graph, const AnfNode
dropout_gen_mask_inputs.push_back(shape_value);
dropout_gen_mask_inputs.push_back(keep_prob_value);
}
CNodePtr dropout_gen_mask = func_graph->NewCNode(dropout_gen_mask_inputs);
CNodePtr dropout_gen_mask = NewCNode(dropout_gen_mask_inputs, func_graph);
MS_EXCEPTION_IF_NULL(dropout_gen_mask);

std::shared_ptr<abstract::AbstractTensor> gen_mask_abstract;
@@ -223,8 +224,7 @@ const AnfNodePtr DropoutAndDropoutGradUnifyMindIR::Process(const FuncGraphPtr &f
auto dropout_input = dropout_cnode->input(kIndex1);
auto input_shape = GetDropoutInputShape(dropout_input);
// CreateDropoutGenMask
auto dropout_gen_mask =
CreateDropoutGenMaskCNode(func_graph, dropout_node, keep_prob_value, dropout_input, input_shape);
auto dropout_gen_mask = CreateDropoutGenMaskCNode(func_graph, dropout_cnode, keep_prob_value, input_shape, *this);
// CreateDropoutDoMask-forward
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
@@ -242,7 +242,7 @@ const AnfNodePtr DropoutAndDropoutGradUnifyMindIR::Process(const FuncGraphPtr &f
std::vector<AnfNodePtr> dropout_do_mask1_inputs{
NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)), dropout_input, dropout_gen_mask,
keep_prob_value};
dropout_do_mask1 = func_graph->NewCNode(dropout_do_mask1_inputs);
dropout_do_mask1 = NewCNode(dropout_do_mask1_inputs, func_graph);
MS_EXCEPTION_IF_NULL(dropout_do_mask1);
auto do_mask_abstract1 =
std::make_shared<abstract::AbstractTensor>(TypeIdToType(inputx_type_id), input_shape);
@@ -262,7 +262,7 @@ const AnfNodePtr DropoutAndDropoutGradUnifyMindIR::Process(const FuncGraphPtr &f
auto dropout_grad_input = utils::cast<AnfNodePtr>((*equiv)[grad_input_]);
std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
dropout_grad_input, dropout_gen_mask, keep_prob_value};
auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs);
auto dropout_do_mask = NewCNode(dropout_do_mask_inputs, func_graph);
MS_EXCEPTION_IF_NULL(dropout_do_mask);
auto do_mask_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(inputx_type_id), input_shape);
dropout_do_mask->set_abstract(do_mask_abstract);
@@ -300,12 +300,11 @@ const AnfNodePtr DropoutUnifyMindIR0::Process(const FuncGraphPtr &func_graph, co
auto input_shape = GetDropoutInputShape(dropout_input);

// CreateDropoutGenMask
auto dropout_gen_mask =
CreateDropoutGenMaskCNode(func_graph, dropout_node, keep_prob_value, dropout_input, input_shape);
auto dropout_gen_mask = CreateDropoutGenMaskCNode(func_graph, dropout_cnode, keep_prob_value, input_shape, *this);
// CreateDropoutDoMask
std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
dropout_input, dropout_gen_mask, keep_prob_value};
auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs);
auto dropout_do_mask = NewCNode(dropout_do_mask_inputs, func_graph);
MS_EXCEPTION_IF_NULL(dropout_do_mask);
auto do_mask_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(inputx_type_id), input_shape);
dropout_do_mask->set_abstract(do_mask_abstract);
@@ -341,12 +340,11 @@ const AnfNodePtr DropoutUnifyMindIR1::Process(const FuncGraphPtr &func_graph, co
auto dropout_input = dropout_node->input(kIndex1);
auto input_shape = GetDropoutInputShape(dropout_input);
// CreateDropoutGenMask
auto dropout_gen_mask =
CreateDropoutGenMaskCNode(func_graph, dropout_node, keep_prob_value, dropout_input, input_shape);
auto dropout_gen_mask = CreateDropoutGenMaskCNode(func_graph, dropout_node, keep_prob_value, input_shape, *this);
// CreateDropoutDoMask
std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
dropout_input, dropout_gen_mask, keep_prob_value};
auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs);
auto dropout_do_mask = NewCNode(dropout_do_mask_inputs, func_graph);
MS_EXCEPTION_IF_NULL(dropout_do_mask);
auto do_mask_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(inputx_type_id), input_shape);
dropout_do_mask->set_abstract(do_mask_abstract);
@@ -396,7 +394,7 @@ const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph,
auto grad_input = dropout_grad_cnode->input(kIndex1);
std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
grad_input, mask_input, keep_prob_value};
auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs);
auto dropout_do_mask = NewCNode(dropout_do_mask_inputs, func_graph);
MS_EXCEPTION_IF_NULL(dropout_do_mask);
auto do_mask_abstract =
std::make_shared<abstract::AbstractTensor>(TypeIdToType(grad_input_type_id), grad_input_shape);
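
Besides routing creation through the pass, the dropout hunks above also narrow the CreateDropoutGenMaskCNode signature: the former dropout_input parameter is gone, the dropout node is passed as a CNodePtr, and the data input is recovered from it when the shape is dynamic. The resulting call shape, restated from the diff:

// Caller side (all three dropout passes now share this form):
// auto dropout_gen_mask = CreateDropoutGenMaskCNode(func_graph, dropout_cnode, keep_prob_value,
//                                                   input_shape, *this);
// Callee side, dynamic-shape branch: the input no longer travels as a separate parameter.
// CNodePtr dynamic_shape = CreateDynamicShapeCNode(func_graph, dropout->input(kIndex1), input_shape);
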
@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

@@ -27,9 +27,9 @@

namespace mindspore {
namespace opt {
namespace {
void CreateOutputsOfLSQPerLayerGradD(const FuncGraphPtr &graph, const CNodePtr &lsq_perlayer_grad_node,
std::vector<AnfNodePtr> *const lsq_perlayer_grad_d_outputs) {
void FakeLearnedScaleQuantPerLayerGradUnifyMindIR::CreateOutputsOfLSQPerLayerGradD(
const FuncGraphPtr &graph, const CNodePtr &lsq_perlayer_grad_node,
std::vector<AnfNodePtr> *const lsq_perlayer_grad_d_outputs) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(lsq_perlayer_grad_node);
const auto &lsq_perlayer_grad_inputs = lsq_perlayer_grad_node->inputs();
@@ -41,7 +41,7 @@ void CreateOutputsOfLSQPerLayerGradD(const FuncGraphPtr &graph, const CNodePtr &
NewValueNode(std::make_shared<Primitive>(kFakeLearnedScaleQuantPerLayerGradDOpName)),
lsq_perlayer_grad_inputs[kIndex1], lsq_perlayer_grad_inputs[kIndex2], lsq_perlayer_grad_inputs[kIndex3],
lsq_perlayer_grad_inputs[kIndex4]};
auto lsq_perlayer_grad_d = graph->NewCNode(lsq_perlayer_grad_d_inputs);
auto lsq_perlayer_grad_d = NewCNode(lsq_perlayer_grad_d_inputs, graph);
MS_EXCEPTION_IF_NULL(lsq_perlayer_grad_d);
lsq_perlayer_grad_d->set_scope(lsq_perlayer_grad_node->scope());

@@ -56,9 +56,10 @@ void CreateOutputsOfLSQPerLayerGradD(const FuncGraphPtr &graph, const CNodePtr &
lsq_perlayer_grad_d_outputs);
}

void CreateOutputsOfLSQPerLayerReduceGrad(const FuncGraphPtr &graph, const CNodePtr &lsq_perlayer_grad_node,
const std::vector<AnfNodePtr> &lsq_perlayer_grad_d_outputs,
std::vector<AnfNodePtr> *const lsq_perlayer_reduce_grad_outputs) {
void FakeLearnedScaleQuantPerLayerGradUnifyMindIR::CreateOutputsOfLSQPerLayerReduceGrad(
const FuncGraphPtr &graph, const CNodePtr &lsq_perlayer_grad_node,
const std::vector<AnfNodePtr> &lsq_perlayer_grad_d_outputs,
std::vector<AnfNodePtr> *const lsq_perlayer_reduce_grad_outputs) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(lsq_perlayer_grad_node);
MS_EXCEPTION_IF_NULL(lsq_perlayer_reduce_grad_outputs);
@@ -74,7 +75,7 @@ void CreateOutputsOfLSQPerLayerReduceGrad(const FuncGraphPtr &graph, const CNode
std::vector<AnfNodePtr> lsq_perlayer_reduce_grad_inputs = {
NewValueNode(std::make_shared<Primitive>(kFakeLearnedScaleQuantPerLayerGradDReduceOpName)),
lsq_perlayer_grad_d_outputs[kIndex1]};
auto lsq_perlayer_reduce_grad = graph->NewCNode(lsq_perlayer_reduce_grad_inputs);
auto lsq_perlayer_reduce_grad = NewCNode(lsq_perlayer_reduce_grad_inputs, graph);
MS_EXCEPTION_IF_NULL(lsq_perlayer_reduce_grad);
lsq_perlayer_reduce_grad->set_scope(lsq_perlayer_grad_node->scope());

@@ -85,8 +86,9 @@ void CreateOutputsOfLSQPerLayerReduceGrad(const FuncGraphPtr &graph, const CNode
(*lsq_perlayer_reduce_grad_outputs).push_back(lsq_perlayer_reduce_grad);
}

void CreateOutputsOfLSQPerChannelGradD(const FuncGraphPtr &graph, const CNodePtr &lsq_perchannel_grad_node,
std::vector<AnfNodePtr> *const lsq_perchannel_grad_d_outputs) {
void FakeLearnedScaleQuantPerChannelGradUnifyMindIR::CreateOutputsOfLSQPerChannelGradD(
const FuncGraphPtr &graph, const CNodePtr &lsq_perchannel_grad_node,
std::vector<AnfNodePtr> *const lsq_perchannel_grad_d_outputs) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(lsq_perchannel_grad_node);
const auto &lsq_perchannel_grad_inputs = lsq_perchannel_grad_node->inputs();
@@ -98,7 +100,7 @@ void CreateOutputsOfLSQPerChannelGradD(const FuncGraphPtr &graph, const CNodePtr
NewValueNode(std::make_shared<Primitive>(kFakeLearnedScaleQuantPerChannelGradDOpName)),
lsq_perchannel_grad_inputs[1], lsq_perchannel_grad_inputs[2], lsq_perchannel_grad_inputs[3],
lsq_perchannel_grad_inputs[4]};
auto lsq_perchannel_grad_d = graph->NewCNode(lsq_perchannel_grad_d_inputs);
auto lsq_perchannel_grad_d = NewCNode(lsq_perchannel_grad_d_inputs, graph);
MS_EXCEPTION_IF_NULL(lsq_perchannel_grad_d);
lsq_perchannel_grad_d->set_scope(lsq_perchannel_grad_node->scope());

@@ -114,9 +116,10 @@ void CreateOutputsOfLSQPerChannelGradD(const FuncGraphPtr &graph, const CNodePtr
lsq_perchannel_grad_d_outputs);
}

void CreateOutputsOfLSQPerChannelReduceGrad(const FuncGraphPtr &graph, const CNodePtr &lsq_perchannel_grad_node,
const std::vector<AnfNodePtr> &lsq_perchannel_grad_d_outputs,
std::vector<AnfNodePtr> *const lsq_perchannel_reduce_grad_outputs) {
void FakeLearnedScaleQuantPerChannelGradUnifyMindIR::CreateOutputsOfLSQPerChannelReduceGrad(
const FuncGraphPtr &graph, const CNodePtr &lsq_perchannel_grad_node,
const std::vector<AnfNodePtr> &lsq_perchannel_grad_d_outputs,
std::vector<AnfNodePtr> *const lsq_perchannel_reduce_grad_outputs) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(lsq_perchannel_grad_node);
MS_EXCEPTION_IF_NULL(lsq_perchannel_reduce_grad_outputs);
@@ -132,7 +135,7 @@ void CreateOutputsOfLSQPerChannelReduceGrad(const FuncGraphPtr &graph, const CNo
std::vector<AnfNodePtr> lsq_perchannel_reduce_grad_inputs = {
NewValueNode(std::make_shared<Primitive>(kFakeLearnedScaleQuantPerChannelGradDReduceOpName)),
lsq_perchannel_grad_d_outputs[kIndex1]};
auto lsq_perchannel_reduce_grad = graph->NewCNode(lsq_perchannel_reduce_grad_inputs);
auto lsq_perchannel_reduce_grad = NewCNode(lsq_perchannel_reduce_grad_inputs, graph);
MS_EXCEPTION_IF_NULL(lsq_perchannel_reduce_grad);
lsq_perchannel_reduce_grad->set_scope(lsq_perchannel_grad_node->scope());

@@ -142,7 +145,7 @@ void CreateOutputsOfLSQPerChannelReduceGrad(const FuncGraphPtr &graph, const CNo
AnfAlgo::CopyNodeAttr(kAttrChannelAxis, lsq_perchannel_grad_node, lsq_perchannel_reduce_grad);
(*lsq_perchannel_reduce_grad_outputs).push_back(lsq_perchannel_reduce_grad);
}
} // namespace

const BaseRef FakeLearnedScaleQuantPerLayerGradUnifyMindIR::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
auto prim = std::make_shared<Primitive>(kFakeLearnedScaleQuantPerLayerGradOpName);

@@ -16,6 +16,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_FAKE_LEARNED_SCALE_QUANT_GRAD_UNIFY_MINDIR_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_FAKE_LEARNED_SCALE_QUANT_GRAD_UNIFY_MINDIR_H_

#include <vector>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/common/helper.h"

@@ -41,6 +42,13 @@ class FakeLearnedScaleQuantPerLayerGradUnifyMindIR : public PatternProcessPass {
~FakeLearnedScaleQuantPerLayerGradUnifyMindIR() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

private:
void CreateOutputsOfLSQPerLayerGradD(const FuncGraphPtr &graph, const CNodePtr &lsq_perlayer_grad_node,
std::vector<AnfNodePtr> *const lsq_perlayer_grad_d_outputs) const;
void CreateOutputsOfLSQPerLayerReduceGrad(const FuncGraphPtr &graph, const CNodePtr &lsq_perlayer_grad_node,
const std::vector<AnfNodePtr> &lsq_perlayer_grad_d_outputs,
std::vector<AnfNodePtr> *const lsq_perlayer_reduce_grad_outputs) const;
};

class FakeLearnedScaleQuantPerChannelGradUnifyMindIR : public PatternProcessPass {
@@ -50,6 +58,13 @@ class FakeLearnedScaleQuantPerChannelGradUnifyMindIR : public PatternProcessPass
~FakeLearnedScaleQuantPerChannelGradUnifyMindIR() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

private:
void CreateOutputsOfLSQPerChannelGradD(const FuncGraphPtr &graph, const CNodePtr &lsq_perchannel_grad_node,
std::vector<AnfNodePtr> *const lsq_perchannel_grad_d_outputs) const;
void CreateOutputsOfLSQPerChannelReduceGrad(const FuncGraphPtr &graph, const CNodePtr &lsq_perchannel_grad_node,
const std::vector<AnfNodePtr> &lsq_perchannel_grad_d_outputs,
std::vector<AnfNodePtr> *const lsq_perchannel_reduce_grad_outputs) const;
};

} // namespace opt
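
The four learned-scale-quant helpers above keep their out-parameter convention while becoming members: each CreateOutputsOf* builds one multi-output D or Reduce node through the pass and then expands it into per-output AnfNodes. A minimal sketch of that convention, assuming a hypothetical pass, "SomeGradD" op name, and kSomeGradDOutputNum count (CreateMultipleOutputsOfAnfNode itself is used this way in the diff):

void SomeGradUnifyMindIR::CreateOutputsOfGradD(const FuncGraphPtr &graph, const CNodePtr &origin,
                                               std::vector<AnfNodePtr> *const outputs) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(origin);
  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("SomeGradD")),
                                    origin->input(kIndex1), origin->input(kIndex2)};
  auto grad_d = NewCNode(inputs, graph);  // created through the pass, as everywhere in this commit
  MS_EXCEPTION_IF_NULL(grad_d);
  grad_d->set_scope(origin->scope());
  // Expand the multi-output node into individual outputs for the caller.
  CreateMultipleOutputsOfAnfNode(graph, grad_d, kSomeGradDOutputNum, outputs);
}
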
@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

@@ -43,8 +43,9 @@ CNodePtr GetMaxPool(const CNodePtr &maxpool_grad) {
MS_EXCEPTION_IF_NULL(maxpool_anf);
return maxpool_anf->cast<CNodePtr>();
}
} // namespace

CNodePtr CreateMaxPoolWithArgmax(const FuncGraphPtr &graph, const CNodePtr &maxpool) {
CNodePtr MaxPool2MaxPoolWithArgmax::CreateMaxPoolWithArgmax(const FuncGraphPtr &graph, const CNodePtr &maxpool) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(maxpool);
if (maxpool->inputs().size() != kMaxPoolInputNum) {
@@ -53,7 +54,7 @@ CNodePtr CreateMaxPoolWithArgmax(const FuncGraphPtr &graph, const CNodePtr &maxp
}
std::vector<AnfNodePtr> maxpool_argmax_inputs = {NewValueNode(std::make_shared<Primitive>(kMaxPoolWithArgmaxOpName)),
maxpool->input(kIndex1)};
auto maxpool_argmax = graph->NewCNode(maxpool_argmax_inputs);
auto maxpool_argmax = NewCNode(maxpool_argmax_inputs, graph);
MS_EXCEPTION_IF_NULL(maxpool_argmax);
maxpool_argmax->set_scope(maxpool->scope());

@@ -66,8 +67,9 @@ CNodePtr CreateMaxPoolWithArgmax(const FuncGraphPtr &graph, const CNodePtr &maxp
return maxpool_argmax;
}

CNodePtr CreateMaxPoolGradWithArgmax(const FuncGraphPtr &graph, const CNodePtr &maxpool_grad,
const std::vector<AnfNodePtr> &maxpool_argmax_outputs) {
CNodePtr MaxPool2MaxPoolWithArgmax::CreateMaxPoolGradWithArgmax(
const FuncGraphPtr &graph, const CNodePtr &maxpool_grad,
const std::vector<AnfNodePtr> &maxpool_argmax_outputs) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(maxpool_grad);
if (maxpool_grad->inputs().size() != kMaxPoolGradInputNum) {
@@ -79,15 +81,16 @@ CNodePtr CreateMaxPoolGradWithArgmax(const FuncGraphPtr &graph, const CNodePtr &
std::vector<AnfNodePtr> maxpool_grad_argmax_inputs = {
NewValueNode(std::make_shared<Primitive>(kMaxPoolGradWithArgmaxOpName)), maxpool_grad->input(kIndex1),
maxpool_grad->input(kIndex3), maxpool_argmax_outputs[kIndex1]};
auto maxpool_grad_argmax = graph->NewCNode(maxpool_grad_argmax_inputs);
auto maxpool_grad_argmax = NewCNode(maxpool_grad_argmax_inputs, graph);
MS_EXCEPTION_IF_NULL(maxpool_grad_argmax);
maxpool_grad_argmax->set_scope(maxpool_grad->scope());
maxpool_grad_argmax->set_abstract(maxpool_grad->abstract());
return maxpool_grad_argmax;
}

void SetNodeAttrs(const CNodePtr &maxpool, const CNodePtr &maxpool_grad, const CNodePtr &maxpool_argmax,
const CNodePtr &maxpool_grad_argmax) {
void MaxPool2MaxPoolWithArgmax::SetNodeAttrs(const CNodePtr &maxpool, const CNodePtr &maxpool_grad,
const CNodePtr &maxpool_argmax,
const CNodePtr &maxpool_grad_argmax) const {
auto strides = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(maxpool, kAttrStrides);
auto ksize = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(maxpool, kAttrKernelSize);
if (strides.size() != kMaxPoolAttrAxisNum) {
@@ -114,7 +117,6 @@ void SetNodeAttrs(const CNodePtr &maxpool, const CNodePtr &maxpool_grad, const C
AnfAlgo::SetNodeAttr(kAttrKernelSize, MakeValue(ksize), maxpool_argmax);
AnfAlgo::SetNodeAttr(kAttrKernelSize, MakeValue(ksize), maxpool_grad_argmax);
}
} // namespace

const BaseRef MaxPool2MaxPoolWithArgmax::DefinePattern() const {
VarPtr X = std::make_shared<Var>();

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

@@ -16,6 +16,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_MAXPOOL_TO_MAXPOOL_WITH_ARGMAX_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_MAXPOOL_TO_MAXPOOL_WITH_ARGMAX_H_

#include <vector>
#include <memory>
#include "backend/optimizer/common/optimizer.h"

@@ -28,6 +29,13 @@ class MaxPool2MaxPoolWithArgmax : public PatternProcessPass {
~MaxPool2MaxPoolWithArgmax() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

private:
CNodePtr CreateMaxPoolWithArgmax(const FuncGraphPtr &graph, const CNodePtr &maxpool) const;
CNodePtr CreateMaxPoolGradWithArgmax(const FuncGraphPtr &graph, const CNodePtr &maxpool_grad,
const std::vector<AnfNodePtr> &maxpool_argmax_outputs) const;
void SetNodeAttrs(const CNodePtr &maxpool, const CNodePtr &maxpool_grad, const CNodePtr &maxpool_argmax,
const CNodePtr &maxpool_grad_argmax) const;
};
} // namespace opt
} // namespace mindspore

@@ -45,6 +45,16 @@ constexpr int64_t kRankIdSeven = 7;
constexpr size_t kSizeFour = 4;
constexpr int64_t kInvalidId = -1;

bool IsTop(const std::vector<int64_t> &send_rank_ids) {
return send_rank_ids[kRankIdZero] != kInvalidId || send_rank_ids[kRankIdOne] != kInvalidId ||
send_rank_ids[kRankIdSeven] != kInvalidId;
}

bool IsBottom(const std::vector<int64_t> &send_rank_ids) {
return send_rank_ids[kRankIdThree] != kInvalidId || send_rank_ids[kRankIdFour] != kInvalidId ||
send_rank_ids[kRankIdFive] != kInvalidId;
}

// cal split attrs size_splits, shapes and num_split
int64_t CalSplitAttrs(const std::vector<size_t> &base_shape, const bool is_first, const bool is_last,
const int64_t split_dim, const std::vector<int64_t> &send_lens, std::vector<int64_t> *size_splits,
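
For orientation in the NeighborExchangeV2 hunks that follow: the IsTop/IsBottom predicates introduced above, together with the left/right checks on kRankIdSix and kRankIdTwo, suggest that send_rank_ids indexes the eight neighbors clockwise starting from the top. This reading is an inference from the constants, not stated in the source:

// Inferred layout of send_rank_ids around the local rank (clockwise from top):
//   7 | 0 | 1
//  ---+---+---
//   6 | . | 2
//  ---+---+---
//   5 | 4 | 3
// A slot holding kInvalidId (-1) means no exchange with that neighbor.
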
@@ -61,8 +71,8 @@ int64_t CalSplitAttrs(const std::vector<size_t> &base_shape, const bool is_first
int64_t split_middle_size = base_shape[split_dim];
std::vector<size_t> shape_tmp(base_shape);
// [top, bottom, left, right]
int64_t first_size = split_dim == kWDim ? send_lens[2] : send_lens[0];
int64_t last_size = split_dim == kWDim ? send_lens[3] : send_lens[1];
int64_t first_size = split_dim == kWDim ? send_lens[kDim2] : send_lens[0];
int64_t last_size = split_dim == kWDim ? send_lens[kDim3] : send_lens[1];

if (is_first) {
// first
@@ -95,14 +105,15 @@ int64_t CalSplitAttrs(const std::vector<size_t> &base_shape, const bool is_first

CNodePtr CreateSplitNode(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &split_input,
const std::vector<size_t> &base_shape, bool is_first, bool is_last, int64_t split_dim,
const std::vector<int64_t> &send_lens, TypeId input_dtype, int64_t *num_split) {
const std::vector<int64_t> &send_lens, TypeId input_dtype, int64_t *num_split,
const PatternProcessPass &pass) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(num_split);
if (split_input.empty()) {
MS_LOG(EXCEPTION) << "The input is empty, can not create splitv node.";
return nullptr;
}
auto split_v = graph->NewCNode(split_input);
auto split_v = pass.NewCNode(split_input, graph);
MS_EXCEPTION_IF_NULL(split_v);
std::vector<int64_t> size_splits = {};
std::vector<std::vector<size_t>> shapes = {};
@@ -147,127 +158,6 @@ std::vector<std::vector<size_t>> CalAllToAllvOutputShape(const std::vector<size_
return shapes;
}

// returns {top_bottom, left_right, top_corner, bottom_corner}, if no split, set it nullptr
std::vector<CNodePtr> CreateSplitNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
std::vector<int64_t> *split_num) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2);
MS_EXCEPTION_IF_NULL(split_num);
std::vector<int64_t> send_rank_ids =
AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2, kAttrSendRankIds);
std::vector<int64_t> send_lens = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2, kAttrSendLens);

if (neighbor_exchange_v2->size() <= kNeighborExchangeV2InputIdx) {
MS_LOG(EXCEPTION) << "Invalid cnode " << neighbor_exchange_v2->DebugString() << " input size "
<< neighbor_exchange_v2->size();
}
std::vector<CNodePtr> split_nodes = {};

auto neighbor_exchange_v2_input = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx);

bool is_top = ((send_rank_ids[kRankIdZero] != kInvalidId) || (send_rank_ids[kRankIdOne] != kInvalidId) ||
(send_rank_ids[kRankIdSeven] != kInvalidId));
bool is_bottom = ((send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFour] != kInvalidId) ||
(send_rank_ids[kRankIdFive] != kInvalidId));
bool is_left = (send_rank_ids[kRankIdSix] != kInvalidId);
bool is_right = (send_rank_ids[kRankIdTwo] != kInvalidId);

auto dtype = AnfAlgo::GetOutputInferDataType(neighbor_exchange_v2_input, 0);
auto shape = AnfAlgo::GetOutputInferShape(neighbor_exchange_v2_input, 0);
if (SizeToLong(shape.size()) != kShapeSize) { // only support NCHW now
MS_LOG(EXCEPTION) << "Invalid shape size " << shape.size() << ", only support NCHW input now!";
}

// splitv for top & bottom
int64_t num_split_h = 0;
CNodePtr split_v_top_bottom = nullptr;
if (is_top || is_bottom) {
std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
neighbor_exchange_v2_input};

split_v_top_bottom =
CreateSplitNode(graph, split_input, shape, is_top, is_bottom, kHDim, send_lens, dtype, &num_split_h);
}
split_nodes.emplace_back(split_v_top_bottom);
split_num->push_back(num_split_h);

// splitv for left & right
int64_t num_split_w = 0;
CNodePtr split_v_left_right = nullptr;
if (is_left || is_right) {
std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
neighbor_exchange_v2_input};
split_v_left_right =
CreateSplitNode(graph, split_input, shape, is_left, is_right, kWDim, send_lens, dtype, &num_split_w);
}
split_nodes.emplace_back(split_v_left_right);
split_num->push_back(num_split_w);

// splitv for corner
if ((send_rank_ids[kRankIdOne] != kInvalidId) || (send_rank_ids[kRankIdSeven] != kInvalidId) ||
(send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFive] != kInvalidId)) {
// top_bottom_split outputs
std::vector<AnfNodePtr> split_outputs_top_bottom;
CreateMultipleOutputsOfAnfNode(graph, split_nodes[0], static_cast<size_t>((*split_num)[0]),
&split_outputs_top_bottom);
if (split_outputs_top_bottom.empty()) {
MS_LOG(EXCEPTION) << "The node " << split_nodes[0]->DebugString()
<< " should have at least one output, but got 0.";
}

// for top corner
if ((send_rank_ids[kRankIdOne] != kInvalidId) || (send_rank_ids[kRankIdSeven] != kInvalidId)) {
auto shape_tmp(shape);
shape_tmp[kHDim] = send_lens[0];
bool is_first = (send_rank_ids[kRankIdSeven] != kInvalidId);
bool is_last = (send_rank_ids[kRankIdOne] != kInvalidId);

std::vector<AnfNodePtr> split_v_corner_top_input = {
NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
split_v_corner_top_input.insert(split_v_corner_top_input.end(), split_outputs_top_bottom.begin(),
split_outputs_top_bottom.begin() + 1);
int64_t num_split_top_corner = 0;
CNodePtr split_v_corner_top = CreateSplitNode(graph, split_v_corner_top_input, shape_tmp, is_first, is_last,
kWDim, send_lens, dtype, &num_split_top_corner);

split_nodes.emplace_back(split_v_corner_top);
split_num->push_back(num_split_top_corner);
} else {
split_nodes.emplace_back(nullptr);
split_num->push_back(0);
}

// for bottom corner
if ((send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFive] != kInvalidId)) {
auto shape_tmp(shape);
shape_tmp[kHDim] = send_lens[1];
bool is_first = (send_rank_ids[kRankIdFive] != kInvalidId);
bool is_last = (send_rank_ids[kRankIdThree] != kInvalidId);

std::vector<AnfNodePtr> split_v_corner_bottom_input = {
NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
split_v_corner_bottom_input.insert(split_v_corner_bottom_input.end(), split_outputs_top_bottom.end() - 1,
split_outputs_top_bottom.end());

int64_t num_split_bottom_corner = 0;
CNodePtr split_v_corner_bottom = CreateSplitNode(graph, split_v_corner_bottom_input, shape_tmp, is_first, is_last,
kWDim, send_lens, dtype, &num_split_bottom_corner);
split_nodes.emplace_back(split_v_corner_bottom);
split_num->push_back(num_split_bottom_corner);
} else {
split_nodes.emplace_back(nullptr);
split_num->push_back(0);
}
} else {
split_nodes.emplace_back(nullptr);
split_num->push_back(0);
split_nodes.emplace_back(nullptr);
split_num->push_back(0);
}

return split_nodes;
}

std::vector<AnfNodePtr> CreateAllToAllvInput(const std::vector<std::vector<AnfNodePtr>> &split_outputs,
const std::vector<int64_t> &send_rank_ids) {
std::vector<AnfNodePtr> all_to_all_v_input = {NewValueNode(std::make_shared<Primitive>(kAllToAllVOpName))};
@@ -309,7 +199,7 @@ AnfNodePtr GetCenter(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchang
return neighbor_exchange_v2_grad->input(kNeighborExchangeV2InputIdx);
}
} else {
CreateMultipleOutputsOfAnfNode(graph, split_nodes[2], static_cast<size_t>(split_num[2]), &output);
CreateMultipleOutputsOfAnfNode(graph, split_nodes[kDim2], static_cast<size_t>(split_num[kDim2]), &output);
if (output.size() < 2) {
MS_LOG(EXCEPTION) << "Wrong split output size: " << output.size() << ", except size >= 2.";
}
@@ -355,16 +245,17 @@ std::vector<AnfNodePtr> CreateAllToAllvInputForGrad(const std::vector<int64_t> &
}
}
// 2
if (split_nodes[2] != nullptr && send_rank_ids[kRankIdTwo] != kInvalidId) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[2].end() - 1, split_outputs[2].end());
if (split_nodes[kIndex2] != nullptr && send_rank_ids[kRankIdTwo] != kInvalidId) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[kIndex2].end() - 1, split_outputs[kIndex2].end());
}
// 3, 4, 5
if (split_nodes[3] != nullptr) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[3].rbegin(), split_outputs[3].rend());
if (split_nodes[kIndex3] != nullptr) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[kIndex3].rbegin(), split_outputs[kIndex3].rend());
}
// 6
if (split_nodes[2] != nullptr && send_rank_ids[kRankIdSix] != kInvalidId) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[2].begin(), split_outputs[2].begin() + 1);
if (split_nodes[kIndex2] != nullptr && send_rank_ids[kRankIdSix] != kInvalidId) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[kIndex2].begin(),
split_outputs[kIndex2].begin() + 1);
}
// 7
if (split_nodes[1] != nullptr && send_rank_ids[kRankIdSeven] != kInvalidId) {
@@ -373,10 +264,11 @@ std::vector<AnfNodePtr> CreateAllToAllvInputForGrad(const std::vector<int64_t> &

return all_to_all_v_input;
}

// alltoallv for forward & grad
CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_or_grad,
const std::vector<CNodePtr> &split_nodes, const std::vector<int64_t> &split_num,
bool is_grad) {
bool is_grad, const PatternProcessPass &pass) {
MS_LOG(DEBUG) << "Start to create alltoallv node.";
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2_or_grad);
@@ -434,7 +326,7 @@ CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &neighbor
std::vector<TypeId> dtypes(real_recv_rank_ids.size(), base_dtype);

// create alltoallv node
auto all_to_all_v = graph->NewCNode(all_to_all_v_input);
auto all_to_all_v = pass.NewCNode(all_to_all_v_input, graph);
MS_EXCEPTION_IF_NULL(all_to_all_v);
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, all_to_all_v.get());
@@ -456,12 +348,135 @@ int64_t AllToAllRealIds(int64_t ids, const std::vector<int64_t> &recv_rank_ids)
   }
   return real_ids;
 }
 }  // namespace
 
-CNodePtr CreateConcatNode(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &concat_input,
-                          const std::vector<std::vector<size_t>> &output_shape, const std::vector<TypeId> &output_dtype,
-                          int64_t axis, int64_t input_nums) {
-  MS_EXCEPTION_IF_NULL(graph);
-  auto concat = graph->NewCNode(concat_input);
+// returns {top_bottom, left_right, top_corner, bottom_corner}, if no split, set it nullptr
+std::vector<CNodePtr> NeighborExchangeV2UnifyMindIR::CreateSplitNodes(const FuncGraphPtr &graph,
+                                                                      const CNodePtr &neighbor_exchange_v2,
+                                                                      std::vector<int64_t> *split_num) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(neighbor_exchange_v2);
+  MS_EXCEPTION_IF_NULL(split_num);
+  std::vector<int64_t> send_rank_ids =
+    AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2, kAttrSendRankIds);
+  std::vector<int64_t> send_lens = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2, kAttrSendLens);
+
+  if (neighbor_exchange_v2->size() <= kNeighborExchangeV2InputIdx) {
+    MS_LOG(EXCEPTION) << "Invalid cnode " << neighbor_exchange_v2->DebugString() << " input size "
+                      << neighbor_exchange_v2->size();
+  }
+  std::vector<CNodePtr> split_nodes = {};
+
+  auto neighbor_exchange_v2_input = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx);
+
+  bool is_top = IsTop(send_rank_ids);
+  bool is_bottom = IsBottom(send_rank_ids);
+  bool is_left = (send_rank_ids[kRankIdSix] != kInvalidId);
+  bool is_right = (send_rank_ids[kRankIdTwo] != kInvalidId);
+
+  auto dtype = AnfAlgo::GetOutputInferDataType(neighbor_exchange_v2_input, 0);
+  auto shape = AnfAlgo::GetOutputInferShape(neighbor_exchange_v2_input, 0);
+  if (SizeToLong(shape.size()) != kShapeSize) {  // only support NCHW now
+    MS_LOG(EXCEPTION) << "Invalid shape size " << shape.size() << ", only support NCHW input now!";
+  }
+
+  // splitv for top & bottom
+  int64_t num_split_h = 0;
+  CNodePtr split_v_top_bottom = nullptr;
+  if (is_top || is_bottom) {
+    std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
+                                           neighbor_exchange_v2_input};
+
+    split_v_top_bottom =
+      CreateSplitNode(graph, split_input, shape, is_top, is_bottom, kHDim, send_lens, dtype, &num_split_h, *this);
+  }
+  split_nodes.emplace_back(split_v_top_bottom);
+  split_num->push_back(num_split_h);
+
+  // splitv for left & right
+  int64_t num_split_w = 0;
+  CNodePtr split_v_left_right = nullptr;
+  if (is_left || is_right) {
+    std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
+                                           neighbor_exchange_v2_input};
+    split_v_left_right =
+      CreateSplitNode(graph, split_input, shape, is_left, is_right, kWDim, send_lens, dtype, &num_split_w, *this);
+  }
+  split_nodes.emplace_back(split_v_left_right);
+  split_num->push_back(num_split_w);
+
+  // splitv for corner
+  if ((send_rank_ids[kRankIdOne] != kInvalidId) || (send_rank_ids[kRankIdSeven] != kInvalidId) ||
+      (send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFive] != kInvalidId)) {
+    // top_bottom_split outputs
+    std::vector<AnfNodePtr> split_outputs_top_bottom;
+    CreateMultipleOutputsOfAnfNode(graph, split_nodes[0], static_cast<size_t>((*split_num)[0]),
+                                   &split_outputs_top_bottom);
+    if (split_outputs_top_bottom.empty()) {
+      MS_LOG(EXCEPTION) << "The node " << split_nodes[0]->DebugString()
+                        << " should have at least one output, but got 0.";
+    }
+
+    // for top corner
+    if ((send_rank_ids[kRankIdOne] != kInvalidId) || (send_rank_ids[kRankIdSeven] != kInvalidId)) {
+      auto shape_tmp(shape);
+      shape_tmp[kHDim] = send_lens[0];
+      bool is_first = (send_rank_ids[kRankIdSeven] != kInvalidId);
+      bool is_last = (send_rank_ids[kRankIdOne] != kInvalidId);
+
+      std::vector<AnfNodePtr> split_v_corner_top_input = {
+        NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
+      split_v_corner_top_input.insert(split_v_corner_top_input.end(), split_outputs_top_bottom.begin(),
+                                      split_outputs_top_bottom.begin() + 1);
+      int64_t num_split_top_corner = 0;
+      CNodePtr split_v_corner_top = CreateSplitNode(graph, split_v_corner_top_input, shape_tmp, is_first, is_last,
+                                                    kWDim, send_lens, dtype, &num_split_top_corner, *this);
+
+      split_nodes.emplace_back(split_v_corner_top);
+      split_num->push_back(num_split_top_corner);
+    } else {
+      split_nodes.emplace_back(nullptr);
+      split_num->push_back(0);
+    }
+
+    // for bottom corner
+    if ((send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFive] != kInvalidId)) {
+      auto shape_tmp(shape);
+      shape_tmp[kHDim] = send_lens[1];
+      bool is_first = (send_rank_ids[kRankIdFive] != kInvalidId);
+      bool is_last = (send_rank_ids[kRankIdThree] != kInvalidId);
+
+      std::vector<AnfNodePtr> split_v_corner_bottom_input = {
+        NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
+      split_v_corner_bottom_input.insert(split_v_corner_bottom_input.end(), split_outputs_top_bottom.end() - 1,
+                                         split_outputs_top_bottom.end());
+
+      int64_t num_split_bottom_corner = 0;
+      CNodePtr split_v_corner_bottom = CreateSplitNode(graph, split_v_corner_bottom_input, shape_tmp, is_first, is_last,
+                                                       kWDim, send_lens, dtype, &num_split_bottom_corner, *this);
+      split_nodes.emplace_back(split_v_corner_bottom);
+      split_num->push_back(num_split_bottom_corner);
+    } else {
+      split_nodes.emplace_back(nullptr);
+      split_num->push_back(0);
+    }
+  } else {
+    split_nodes.emplace_back(nullptr);
+    split_num->push_back(0);
+    split_nodes.emplace_back(nullptr);
+    split_num->push_back(0);
+  }
+
+  return split_nodes;
+}
+
+CNodePtr NeighborExchangeV2UnifyMindIR::CreateConcatNode(const FuncGraphPtr &graph,
+                                                         const std::vector<AnfNodePtr> &concat_input,
+                                                         const std::vector<std::vector<size_t>> &output_shape,
+                                                         const std::vector<TypeId> &output_dtype, int64_t axis,
+                                                         int64_t input_nums) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  auto concat = NewCNode(concat_input, graph);
   MS_EXCEPTION_IF_NULL(concat);
   AnfAlgo::SetOutputInferTypeAndShape(output_dtype, output_shape, concat.get());
   AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue<int64_t>(axis), concat);
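For reference while reading CreateSplitNodes: send_rank_ids appears to hold eight neighbor ranks for the 2-D halo exchange, indexed clockwise from the top (0 = top, 1 = top-right, 2 = right, 3 = bottom-right, 4 = bottom, 5 = bottom-left, 6 = left, 7 = top-left), with kInvalidId marking an absent neighbor; that reading matches is_left testing index six and is_right testing index two above. A toy check of which SplitV groups the pass would emit follows; the index meanings are inferred from this diff, not confirmed against the op's documentation.

#include <array>
#include <cstdint>
#include <iostream>

constexpr int64_t kInvalidId = -1;

// Inferred layout: indices run clockwise from the top neighbor.
bool IsTop(const std::array<int64_t, 8> &ids) {
  return ids[7] != kInvalidId || ids[0] != kInvalidId || ids[1] != kInvalidId;
}
bool IsBottom(const std::array<int64_t, 8> &ids) {
  return ids[3] != kInvalidId || ids[4] != kInvalidId || ids[5] != kInvalidId;
}

int main() {
  // Example: this rank sends halos upward (to rank 2) and to the right (to rank 3) only.
  std::array<int64_t, 8> send_rank_ids = {2, -1, 3, -1, -1, -1, -1, -1};
  bool split_h = IsTop(send_rank_ids) || IsBottom(send_rank_ids);                    // height split needed
  bool split_w = send_rank_ids[6] != kInvalidId || send_rank_ids[2] != kInvalidId;   // width split needed
  std::cout << "split H: " << split_h << ", split W: " << split_w << "\n";           // prints: split H: 1, split W: 1
}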
@@ -471,9 +486,11 @@ CNodePtr CreateConcatNode(const FuncGraphPtr &graph, const std::vector<AnfNodePt
   return concat;
 }
 
-CNodePtr CreateLeftRightConcat(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &all_to_all_v_outputs,
-                               const std::vector<int64_t> &recv_rank_ids, const std::vector<int64_t> &recv_lens,
-                               bool is_left) {
+CNodePtr NeighborExchangeV2UnifyMindIR::CreateLeftRightConcat(const FuncGraphPtr &graph,
+                                                              const std::vector<AnfNodePtr> &all_to_all_v_outputs,
+                                                              const std::vector<int64_t> &recv_rank_ids,
+                                                              const std::vector<int64_t> &recv_lens,
+                                                              bool is_left) const {
   MS_EXCEPTION_IF_NULL(graph);
 
   std::vector<AnfNodePtr> concat_input = {NewValueNode(std::make_shared<Primitive>(kConcatOpName))};

@@ -486,11 +503,11 @@ CNodePtr CreateLeftRightConcat(const FuncGraphPtr &graph, const std::vector<AnfN
 
   if (recv_rank_ids[first_ids] != kInvalidId) {
     ++input_num;
-    single_shape[2] += static_cast<size_t>(recv_lens[0]);  // H in NCHW
+    single_shape[kDim2] += static_cast<size_t>(recv_lens[0]);  // H in NCHW
   }
   if (recv_rank_ids[last_ids] != kInvalidId) {
     ++input_num;
-    single_shape[2] += static_cast<size_t>(recv_lens[1]);  // H in NCHW
+    single_shape[kDim2] += static_cast<size_t>(recv_lens[1]);  // H in NCHW
   }
   if (is_left) {
     concat_input.insert(concat_input.end(), all_to_all_v_outputs.rbegin(), all_to_all_v_outputs.rbegin() + input_num);

@@ -506,18 +523,17 @@ CNodePtr CreateLeftRightConcat(const FuncGraphPtr &graph, const std::vector<AnfN
   return concat;
 }
 
-CNodePtr CreateMiddleConcat(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
-                            const std::vector<AnfNodePtr> &all_to_all_v_outputs,
-                            const std::vector<int64_t> &recv_rank_ids, const std::vector<int64_t> &recv_lens,
-                            int64_t concat_dim) {
+CNodePtr NeighborExchangeV2UnifyMindIR::CreateMiddleConcat(
+  const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2, const std::vector<AnfNodePtr> &all_to_all_v_outputs,
+  const std::vector<int64_t> &recv_rank_ids, const std::vector<int64_t> &recv_lens, int64_t concat_dim) const {
   std::vector<AnfNodePtr> concat_input_all = {NewValueNode(std::make_shared<Primitive>(kConcatOpName))};
   int64_t input_num_all = 0;
   auto neighbor_exchange_v2_input = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx);
   auto single_shape = AnfAlgo::GetOutputInferShape(neighbor_exchange_v2_input, 0);
   size_t first_idx = concat_dim == kWDim ? 6 : 0;
   size_t last_idx = concat_dim == kWDim ? 2 : 4;
-  size_t first_len = concat_dim == kWDim ? static_cast<size_t>(recv_lens[2]) : static_cast<size_t>(recv_lens[0]);
-  size_t last_len = concat_dim == kWDim ? static_cast<size_t>(recv_lens[3]) : static_cast<size_t>(recv_lens[1]);
+  size_t first_len = concat_dim == kWDim ? static_cast<size_t>(recv_lens[kDim2]) : static_cast<size_t>(recv_lens[0]);
+  size_t last_len = concat_dim == kWDim ? static_cast<size_t>(recv_lens[kDim3]) : static_cast<size_t>(recv_lens[1]);
 
   // left
   if (recv_rank_ids[first_idx] != kInvalidId) {

@@ -553,8 +569,9 @@ CNodePtr CreateMiddleConcat(const FuncGraphPtr &graph, const CNodePtr &neighbor_
   return concat_all;
 }
 
-CNodePtr AllToAllvRecvEmpty(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
-                            const CNodePtr &all_to_all_v) {
+CNodePtr NeighborExchangeV2UnifyMindIR::AllToAllvRecvEmpty(const FuncGraphPtr &graph,
+                                                           const CNodePtr &neighbor_exchange_v2,
+                                                           const CNodePtr &all_to_all_v) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(neighbor_exchange_v2);
   MS_EXCEPTION_IF_NULL(all_to_all_v);

@@ -568,8 +585,9 @@ CNodePtr AllToAllvRecvEmpty(const FuncGraphPtr &graph, const CNodePtr &neighbor_
   return depend;
 }
 
-CNodePtr CreateConcatNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
-                           const CNodePtr &all_to_all_v) {
+CNodePtr NeighborExchangeV2UnifyMindIR::CreateConcatNodes(const FuncGraphPtr &graph,
+                                                          const CNodePtr &neighbor_exchange_v2,
+                                                          const CNodePtr &all_to_all_v) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(neighbor_exchange_v2);
   MS_EXCEPTION_IF_NULL(all_to_all_v);

@@ -611,10 +629,10 @@ CNodePtr CreateConcatNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_e
   std::vector<AnfNodePtr> concat_input_all = {NewValueNode(std::make_shared<Primitive>(kConcatOpName))};
   auto neighbor_exchange_v2_input = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx);
   std::vector<size_t> shape_all = AnfAlgo::GetOutputInferShape(neighbor_exchange_v2_input, 0);
-  shape_all[2] =
-    recv_rank_ids[kRankIdZero] != kInvalidId ? shape_all[2] + static_cast<size_t>(recv_lens[0]) : shape_all[2];
-  shape_all[2] =
-    recv_rank_ids[kRankIdFour] != kInvalidId ? shape_all[2] + static_cast<size_t>(recv_lens[1]) : shape_all[2];
+  shape_all[kDim2] =
+    recv_rank_ids[kRankIdZero] != kInvalidId ? shape_all[kDim2] + static_cast<size_t>(recv_lens[0]) : shape_all[kDim2];
+  shape_all[kDim2] =
+    recv_rank_ids[kRankIdFour] != kInvalidId ? shape_all[kDim2] + static_cast<size_t>(recv_lens[1]) : shape_all[kDim2];
   int64_t input_nums_all = 0;
   // left concat
   if (is_left) {

@@ -628,7 +646,7 @@ CNodePtr CreateConcatNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_e
     }
     concat_input_all.insert(concat_input_all.end(), concat_left_outputs.begin(), concat_left_outputs.end());
     ++input_nums_all;
-    shape_all[3] += recv_lens[2];
+    shape_all[kDim3] += recv_lens[kDim2];
   }
 
   // middle concat connect to concat_all

@@ -651,7 +669,7 @@ CNodePtr CreateConcatNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_e
     }
     concat_input_all.insert(concat_input_all.end(), concat_right_outputs.begin(), concat_right_outputs.end());
     ++input_nums_all;
-    shape_all[3] += recv_lens[3];
+    shape_all[kDim3] += recv_lens[kDim3];
   }
 
   std::vector<TypeId> concat_right_output_dtype = {AnfAlgo::GetOutputInferDataType(concat_input_all[1], 0)};

@@ -662,8 +680,8 @@ CNodePtr CreateConcatNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_e
 
 // grad
 // returns {top_bottom, left_right, top_corner, bottom_corner}, if no split, set it nullptr
-std::vector<CNodePtr> CreateSplitNodesForGrad(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_grad,
-                                              std::vector<int64_t> *split_num) {
+std::vector<CNodePtr> NeighborExchangeV2GradUnifyMindIR::CreateSplitNodesForGrad(
+  const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_grad, std::vector<int64_t> *split_num) const {
   MS_LOG(DEBUG) << "Start create splitv nodes.";
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(neighbor_exchange_v2_grad);

@@ -686,17 +704,15 @@ std::vector<CNodePtr> CreateSplitNodesForGrad(const FuncGraphPtr &graph, const C
 
   std::vector<CNodePtr> split_nodes = {};
   // splitv for top & bottom
-  bool is_top = ((send_rank_ids[kRankIdZero] != kInvalidId) || (send_rank_ids[kRankIdOne] != kInvalidId) ||
-                 (send_rank_ids[kRankIdSeven] != kInvalidId));
-  bool is_bottom = ((send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFour] != kInvalidId) ||
-                    (send_rank_ids[kRankIdFive] != kInvalidId));
+  bool is_top = IsTop(send_rank_ids);
+  bool is_bottom = IsBottom(send_rank_ids);
   CNodePtr split_v_top_bottom = nullptr;
   int64_t num_split_h = 0;
   if (is_top || is_bottom) {
     std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
                                            neighbor_exchange_v2_grad_input};
     split_v_top_bottom =
-      CreateSplitNode(graph, split_input, shape, is_top, is_bottom, kHDim, send_lens, dtype, &num_split_h);
+      CreateSplitNode(graph, split_input, shape, is_top, is_bottom, kHDim, send_lens, dtype, &num_split_h, *this);
   }
   split_nodes.emplace_back(split_v_top_bottom);
   split_num->push_back(num_split_h);

@@ -735,8 +751,8 @@ std::vector<CNodePtr> CreateSplitNodesForGrad(const FuncGraphPtr &graph, const C
     int64_t num_split_w = 0;
     std::vector<size_t> base_shape(shape);
     base_shape[kHDim] = size_split_h[i];
-    auto split_v_left_right =
-      CreateSplitNode(graph, split_input, base_shape, is_left, is_right, kWDim, send_lens, dtype, &num_split_w);
+    auto split_v_left_right = CreateSplitNode(graph, split_input, base_shape, is_left, is_right, kWDim, send_lens,
+                                              dtype, &num_split_w, *this);
     split_nodes.emplace_back(split_v_left_right);
     split_num->push_back(num_split_w);
   }

@@ -756,12 +772,14 @@ std::vector<CNodePtr> CreateSplitNodesForGrad(const FuncGraphPtr &graph, const C
   return split_nodes;
 }
 
-CNodePtr CreatePadNode(const FuncGraphPtr &graph, const AnfNodePtr &input, const std::vector<int64_t> &begin,
-                       const std::vector<int64_t> &size, const std::vector<size_t> &shape, TypeId dtype) {
+CNodePtr NeighborExchangeV2GradUnifyMindIR::CreatePadNode(const FuncGraphPtr &graph, const AnfNodePtr &input,
+                                                          const std::vector<int64_t> &begin,
+                                                          const std::vector<int64_t> &size,
+                                                          const std::vector<size_t> &shape, TypeId dtype) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(input);
   std::vector<AnfNodePtr> pad_inputs = {NewValueNode(std::make_shared<Primitive>(kPadOpName)), input};
-  auto pad = graph->NewCNode(pad_inputs);
+  auto pad = NewCNode(pad_inputs, graph);
   std::vector<std::vector<int64_t>> paddings;
   for (size_t i = 0; i < shape.size(); ++i) {
     paddings.emplace_back(std::vector<int64_t>{begin[i], static_cast<int64_t>(shape[i]) - begin[i] - size[i]});
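The paddings loop just above inverts a slice: each gradient block received for one neighbor covers size[i] elements starting at begin[i], and Pad re-embeds it into the full shape[i] extent with zeros elsewhere, so the per-dimension padding is {begin[i], shape[i] - begin[i] - size[i]}. A quick numeric check of that arithmetic, with made-up example values:

#include <cstdint>
#include <iostream>
#include <vector>

// Toy check of the paddings arithmetic used by CreatePadNode: a slice taken at
// `begin` with extent `size` is padded back out to the full `shape`, so the
// padding before is begin[i] and the padding after is shape[i] - begin[i] - size[i].
int main() {
  std::vector<int64_t> shape = {1, 3, 8, 8};  // full NCHW shape (example values)
  std::vector<int64_t> begin = {0, 0, 6, 0};  // slice start
  std::vector<int64_t> size = {1, 3, 2, 8};   // slice extent
  for (size_t i = 0; i < shape.size(); ++i) {
    int64_t before = begin[i];
    int64_t after = shape[i] - begin[i] - size[i];
    // before + size + after always reassembles the full dimension,
    // e.g. dim 2: 6 + 2 + 0 == 8.
    std::cout << "dim " << i << ": pad {" << before << ", " << after << "}\n";
  }
}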
@@ -772,9 +790,11 @@ CNodePtr CreatePadNode(const FuncGraphPtr &graph, const AnfNodePtr &input, const
   return pad;
 }
 
-CNodePtr CreateSplitGradNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_grad,
-                              const CNodePtr &all_to_all_v, const std::vector<CNodePtr> &split_nodes,
-                              const std::vector<int64_t> &split_num) {
+CNodePtr NeighborExchangeV2GradUnifyMindIR::CreateSplitGradNodes(const FuncGraphPtr &graph,
+                                                                 const CNodePtr &neighbor_exchange_v2_grad,
+                                                                 const CNodePtr &all_to_all_v,
+                                                                 const std::vector<CNodePtr> &split_nodes,
+                                                                 const std::vector<int64_t> &split_num) const {
   MS_LOG(DEBUG) << "Start create splitvs grad nodes.";
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(neighbor_exchange_v2_grad);

@@ -810,27 +830,27 @@ CNodePtr CreateSplitGradNodes(const FuncGraphPtr &graph, const CNodePtr &neighbo
   // create pad nodes
   // slice begin & size
   std::vector<std::vector<int64_t>> begins = {{0, 0, 0, 0},
-                                              {0, 0, 0, static_cast<int64_t>(centerx_shape[3]) - recv_lens[3]},
-                                              {0, 0, 0, static_cast<int64_t>(centerx_shape[3]) - recv_lens[3]},
-                                              {0, 0, static_cast<int64_t>(centerx_shape[2]) - recv_lens[1],
-                                               static_cast<int64_t>(centerx_shape[3]) - recv_lens[3]},
-                                              {0, 0, static_cast<int64_t>(centerx_shape[2]) - recv_lens[1], 0},
-                                              {0, 0, static_cast<int64_t>(centerx_shape[2]) - recv_lens[1], 0},
+                                              {0, 0, 0, static_cast<int64_t>(centerx_shape[kDim3]) - recv_lens[kDim3]},
+                                              {0, 0, 0, static_cast<int64_t>(centerx_shape[kDim3]) - recv_lens[kDim3]},
+                                              {0, 0, static_cast<int64_t>(centerx_shape[kDim2]) - recv_lens[kDim1],
+                                               static_cast<int64_t>(centerx_shape[kDim3]) - recv_lens[kDim3]},
+                                              {0, 0, static_cast<int64_t>(centerx_shape[kDim2]) - recv_lens[kDim1], 0},
+                                              {0, 0, static_cast<int64_t>(centerx_shape[kDim2]) - recv_lens[kDim1], 0},
                                               {0, 0, 0, 0},
                                               {0, 0, 0, 0}};
   std::vector<std::vector<int64_t>> sizes = {
     {static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[0],
-     static_cast<int64_t>(centerx_shape[3])},
-    {static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[0], recv_lens[3]},
+     static_cast<int64_t>(centerx_shape[kDim3])},
+    {static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[0], recv_lens[kDim3]},
     {static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]),
-     static_cast<int64_t>(centerx_shape[2]), recv_lens[3]},
-    {static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[1], recv_lens[3]},
+     static_cast<int64_t>(centerx_shape[kDim2]), recv_lens[kDim3]},
+    {static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[1], recv_lens[kDim3]},
     {static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[1],
-     static_cast<int64_t>(centerx_shape[3])},
-    {static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[1], recv_lens[2]},
+     static_cast<int64_t>(centerx_shape[kDim3])},
+    {static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[1], recv_lens[kDim2]},
     {static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]),
-     static_cast<int64_t>(centerx_shape[2]), recv_lens[2]},
-    {static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[0], recv_lens[2]}};
+     static_cast<int64_t>(centerx_shape[kDim2]), recv_lens[kDim2]},
+    {static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[0], recv_lens[kDim2]}};
   std::vector<CNodePtr> pad_nodes;
   size_t output_index = 0;
   for (size_t i = 0; i < recv_rank_ids.size(); ++i) {

@@ -854,7 +874,7 @@ CNodePtr CreateSplitGradNodes(const FuncGraphPtr &graph, const CNodePtr &neighbo
     addn_inputs.insert(addn_inputs.end(), pad_outputs.begin(), pad_outputs.end());
     ++pad_num;
   }
-  auto addn = graph->NewCNode(addn_inputs);
+  auto addn = NewCNode(addn_inputs, graph);
   MS_EXCEPTION_IF_NULL(addn);
   AnfAlgo::SetOutputInferTypeAndShape({centerx_dtype}, {centerx_shape}, addn.get());
   AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue<std::vector<int64_t>>({pad_num}), addn);

@@ -862,7 +882,6 @@ CNodePtr CreateSplitGradNodes(const FuncGraphPtr &graph, const CNodePtr &neighbo
   MS_LOG(DEBUG) << "Create splitvs grad nodes success.";
   return addn;
 }
-}  // namespace
 
 const BaseRef NeighborExchangeV2UnifyMindIR::DefinePattern() const {
   return VectorRef({prim::kPrimNeighborExchangeV2, std::make_shared<SeqVar>()});

@@ -876,7 +895,7 @@ const AnfNodePtr NeighborExchangeV2UnifyMindIR::Process(const FuncGraphPtr &grap
   MS_EXCEPTION_IF_NULL(neighbor_exchange_v2);
   std::vector<int64_t> split_num;
   auto split_nodes = CreateSplitNodes(graph, neighbor_exchange_v2, &split_num);
-  auto all_to_all_v = CreateAllToAllvNode(graph, neighbor_exchange_v2, split_nodes, split_num, false);
+  auto all_to_all_v = CreateAllToAllvNode(graph, neighbor_exchange_v2, split_nodes, split_num, false, *this);
   auto concat = CreateConcatNodes(graph, neighbor_exchange_v2, all_to_all_v);
   return concat;
 }

@@ -892,7 +911,7 @@ const AnfNodePtr NeighborExchangeV2GradUnifyMindIR::Process(const FuncGraphPtr &
   MS_EXCEPTION_IF_NULL(neighbor_exchange_v2_grad);
   std::vector<int64_t> split_num;
   auto split_nodes = CreateSplitNodesForGrad(graph, neighbor_exchange_v2_grad, &split_num);
-  auto all_to_all_v = CreateAllToAllvNode(graph, neighbor_exchange_v2_grad, split_nodes, split_num, true);
+  auto all_to_all_v = CreateAllToAllvNode(graph, neighbor_exchange_v2_grad, split_nodes, split_num, true, *this);
   auto add = CreateSplitGradNodes(graph, neighbor_exchange_v2_grad, all_to_all_v, split_nodes, split_num);
   return add;
 }
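Taken together, the two Process overrides rewrite one communication op into a small subgraph: the forward pass becomes SplitV nodes, an AllToAllv, and Concat nodes, while the grad pass becomes SplitV nodes, an AllToAllv, Pad nodes, and a final AddN. A toy trace of that ordering, purely as a reading aid; the stage names mirror the helper functions above, and nothing here is a MindSpore API:

#include <iostream>
#include <string>
#include <vector>

// Sketch of the lowering order implemented by the two Process() overrides.
std::vector<std::string> LowerNeighborExchangeV2(bool is_grad) {
  std::vector<std::string> stages = {"CreateSplitNodes", "CreateAllToAllvNode"};
  if (is_grad) {
    stages.push_back("CreateSplitGradNodes (Pad* + AddN)");
  } else {
    stages.push_back("CreateConcatNodes");
  }
  return stages;
}

int main() {
  for (const auto &stage : LowerNeighborExchangeV2(/*is_grad=*/true)) {
    std::cout << stage << "\n";
  }
}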
@@ -30,6 +30,24 @@ class NeighborExchangeV2UnifyMindIR : public PatternProcessPass {
   ~NeighborExchangeV2UnifyMindIR() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  std::vector<CNodePtr> CreateSplitNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
+                                         std::vector<int64_t> *split_num) const;
+  CNodePtr CreateConcatNode(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &concat_input,
+                            const std::vector<std::vector<size_t>> &output_shape,
+                            const std::vector<TypeId> &output_dtype, int64_t axis, int64_t input_nums) const;
+  CNodePtr CreateLeftRightConcat(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &all_to_all_v_outputs,
+                                 const std::vector<int64_t> &recv_rank_ids, const std::vector<int64_t> &recv_lens,
+                                 bool is_left) const;
+  CNodePtr CreateMiddleConcat(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
+                              const std::vector<AnfNodePtr> &all_to_all_v_outputs,
+                              const std::vector<int64_t> &recv_rank_ids, const std::vector<int64_t> &recv_lens,
+                              int64_t concat_dim) const;
+  CNodePtr AllToAllvRecvEmpty(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
+                              const CNodePtr &all_to_all_v) const;
+  CNodePtr CreateConcatNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
+                             const CNodePtr &all_to_all_v) const;
 };
 
 class NeighborExchangeV2GradUnifyMindIR : public PatternProcessPass {

@@ -39,6 +57,15 @@ class NeighborExchangeV2GradUnifyMindIR : public PatternProcessPass {
   ~NeighborExchangeV2GradUnifyMindIR() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  std::vector<CNodePtr> CreateSplitNodesForGrad(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_grad,
+                                                std::vector<int64_t> *split_num) const;
+  CNodePtr CreatePadNode(const FuncGraphPtr &graph, const AnfNodePtr &input, const std::vector<int64_t> &begin,
+                         const std::vector<int64_t> &size, const std::vector<size_t> &shape, TypeId dtype) const;
+  CNodePtr CreateSplitGradNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_grad,
+                                const CNodePtr &all_to_all_v, const std::vector<CNodePtr> &split_nodes,
+                                const std::vector<int64_t> &split_num) const;
 };
 
 }  // namespace opt
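The header change mirrors the .cc change: helpers that used to sit in an anonymous namespace become private const member functions, so they can reach the pass's own node-creation hook through this instead of having every call site thread a pass reference. A minimal sketch of the same move on toy types; BasePass here is hypothetical, not MindSpore's PatternProcessPass:

#include <iostream>

// Hypothetical base pass: owns the node-creation hook the helpers need.
class BasePass {
 public:
  virtual ~BasePass() = default;
  int NewNode() const { return next_id_++; }  // stand-in for NewCNode

 private:
  mutable int next_id_ = 0;
};

// Before: a free helper must be handed the pass explicitly.
int FreeHelper(const BasePass &pass) { return pass.NewNode(); }

// After: a private const member reaches the hook implicitly through `this`.
class MyPass : public BasePass {
 public:
  void Run() const { std::cout << "member helper made node " << Helper() << "\n"; }

 private:
  int Helper() const { return NewNode(); }
};

int main() {
  MyPass pass;
  std::cout << "free helper made node " << FreeHelper(pass) << "\n";
  pass.Run();
}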
@@ -30,7 +30,8 @@ constexpr size_t kMomentumOutputNum = 2;
 constexpr size_t kRMSPropOutputNum = 3;
 constexpr size_t kCenteredRMSPropOutputNum = 4;
 
-CNodePtr ProcessOutput(const FuncGraphPtr &graph, const AnfNodePtr &node, const size_t output_size) {
+CNodePtr ProcessOutput(const FuncGraphPtr &graph, const AnfNodePtr &node, const size_t output_size,
+                       const PatternProcessPass &pass) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(node);
 

@@ -53,7 +54,7 @@ CNodePtr ProcessOutput(const FuncGraphPtr &graph, const AnfNodePtr &node, const
   cnode_ptr->set_abstract(abstract_tuple);
 
   auto index = NewValueNode(static_cast<int64_t>(0));
-  auto get_item = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode_ptr, index});
+  auto get_item = pass.NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode_ptr, index}, graph);
   MS_EXCEPTION_IF_NULL(get_item);
 
   get_item->set_abstract(abstract->Clone());

@@ -76,7 +77,7 @@ const BaseRef FtrlUnifyOutput::DefinePattern() const {
 }
 
 const AnfNodePtr FtrlUnifyOutput::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
-  return ProcessOutput(graph, node, kFtrlOutputNum);
+  return ProcessOutput(graph, node, kFtrlOutputNum, *this);
 }
 
 const BaseRef MomentumUnifyOutput::DefinePattern() const {

@@ -92,7 +93,7 @@ const BaseRef MomentumUnifyOutput::DefinePattern() const {
 
 const AnfNodePtr MomentumUnifyOutput::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                               const EquivPtr &) const {
-  return ProcessOutput(graph, node, kMomentumOutputNum);
+  return ProcessOutput(graph, node, kMomentumOutputNum, *this);
 }
 
 const BaseRef RMSPropUnifyOutput::DefinePattern() const {

@@ -103,7 +104,7 @@ const BaseRef RMSPropUnifyOutput::DefinePattern() const {
 
 const AnfNodePtr RMSPropUnifyOutput::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                              const EquivPtr &) const {
-  return ProcessOutput(graph, node, kRMSPropOutputNum);
+  return ProcessOutput(graph, node, kRMSPropOutputNum, *this);
 }
 
 const BaseRef CenteredRMSPropUnifyOutput::DefinePattern() const {

@@ -123,7 +124,7 @@ const BaseRef CenteredRMSPropUnifyOutput::DefinePattern() const {
 
 const AnfNodePtr CenteredRMSPropUnifyOutput::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                      const EquivPtr &) const {
-  return ProcessOutput(graph, node, kCenteredRMSPropOutputNum);
+  return ProcessOutput(graph, node, kCenteredRMSPropOutputNum, *this);
 }
 }  // namespace opt
 }  // namespace mindspore
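For orientation: each *UnifyOutput::Process above rewrites an optimizer op that formally has several outputs (Momentum: 2, RMSProp: 3, CenteredRMSProp: 4, per the constants above) so that downstream code sees only output 0, by giving the node a tuple abstract and returning a TupleGetItem with index 0. A toy model of that unification on plain values; the types are stand-ins, not AnfNode APIs:

#include <cstddef>
#include <iostream>
#include <vector>

// Toy stand-in: a node whose value is a tuple of outputs.
struct TupleNode {
  std::vector<double> outputs;
};

// Mirror of ProcessOutput: the op declares how many outputs exist, then only
// element 0 is exposed through a get-item, as the pass does with kPrimTupleGetItem.
double TupleGetItem(const TupleNode &node, size_t index) { return node.outputs.at(index); }

int main() {
  TupleNode momentum{{0.9, 0.1}};  // e.g. kMomentumOutputNum == 2
  std::cout << "unified output: " << TupleGetItem(momentum, 0) << "\n";
}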
@@ -71,7 +71,7 @@ const AnfNodePtr SliceGradUnifyMindIR::Process(const FuncGraphPtr &graph, const
   }
   std::vector<AnfNodePtr> pad_inputs = {NewValueNode(std::make_shared<Primitive>(kPadOpName)),
                                         slice_grad->input(kIndex1)};
-  auto pad = graph->NewCNode(pad_inputs);
+  auto pad = NewCNode(pad_inputs, graph);
   MS_EXCEPTION_IF_NULL(pad);
   pad->set_scope(slice_grad->scope());
   pad->set_abstract(slice_grad->abstract());
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -51,7 +51,7 @@ ValueNodePtr CreateValueNode(const ValuePtr &value_ptr, TypeId output_type) {
   return new_node;
 }
 
-CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node,
+CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, const PatternProcessPass &pass,
                       bool is_convert_const_to_attr = false) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(sparse_softmax_node);

@@ -96,7 +96,7 @@ CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_
     one_hot_inputs = {NewValueNode(one_hot_primitive), sparse_softmax_node->input(kIndex2), depth_node, value_on_node,
                       value_off_node};
   }
-  auto one_hot_node = graph->NewCNode(one_hot_inputs);
+  auto one_hot_node = pass.NewCNode(one_hot_inputs, graph);
   MS_EXCEPTION_IF_NULL(one_hot_node);
   one_hot_node->set_scope(sparse_softmax_node->scope());
   std::vector<size_t> labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(sparse_softmax_node, 1);

@@ -109,14 +109,14 @@ CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_
 }
 
 CNodePtr CreateSoftmaxCrossEntropyWithLogits(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node,
-                                             const CNodePtr &one_hot_node) {
+                                             const CNodePtr &one_hot_node, const PatternProcessPass &pass) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(sparse_softmax_node);
   MS_EXCEPTION_IF_NULL(one_hot_node);
   CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
   std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kSoftmaxCrossEntropyWithLogitsOpName)),
                                     sparse_softmax_node->input(kIndex1), one_hot_node};
-  auto softmax_node = graph->NewCNode(inputs);
+  auto softmax_node = pass.NewCNode(inputs, graph);
   MS_EXCEPTION_IF_NULL(softmax_node);
   softmax_node->set_scope(sparse_softmax_node->scope());
 
@@ -157,7 +157,8 @@ ValueNodePtr GetAxisNode(const AnfNodePtr &node) {
 }
 
 CNodePtr CreateReduceMean(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node,
-                          const AnfNodePtr &softmax_output_node, bool is_pynative = false) {
+                          const AnfNodePtr &softmax_output_node, const PatternProcessPass &pass,
+                          bool is_pynative = false) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(sparse_softmax_node);
   MS_EXCEPTION_IF_NULL(softmax_output_node);

@@ -182,7 +183,7 @@ CNodePtr CreateReduceMean(const FuncGraphPtr &graph, const CNodePtr &sparse_soft
     kernel_graph->AddValueNodeToGraph(axis_node);
     inputs = {NewValueNode(reduce_primitive), softmax_output_node, axis_node};
   }
-  auto reduce_node = graph->NewCNode(inputs);
+  auto reduce_node = pass.NewCNode(inputs, graph);
   MS_EXCEPTION_IF_NULL(reduce_node);
   reduce_node->set_scope(sparse_softmax_node->scope());
   auto reduce_abstract = softmax_output_node->abstract();

@@ -194,7 +195,7 @@ CNodePtr CreateReduceMean(const FuncGraphPtr &graph, const CNodePtr &sparse_soft
   return reduce_node;
 }
 
-CNodePtr CreateExpandDims(const FuncGraphPtr &graph, const CNodePtr &real_div_node) {
+CNodePtr CreateExpandDims(const FuncGraphPtr &graph, const CNodePtr &real_div_node, const PatternProcessPass &pass) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(real_div_node);
   CheckCNodeInputSize(real_div_node, kRealDivInputTensorNum);

@@ -213,7 +214,7 @@ CNodePtr CreateExpandDims(const FuncGraphPtr &graph, const CNodePtr &real_div_no
   expand_dims_primitive->set_attr(kAttrInputNames, MakeValue(input_names));
   expand_dims_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));
   std::vector<AnfNodePtr> expand_dims_inputs = {NewValueNode(expand_dims_primitive), real_div_node, axis_node};
-  auto expand_dims_node = graph->NewCNode(expand_dims_inputs);
+  auto expand_dims_node = pass.NewCNode(expand_dims_inputs, graph);
   MS_EXCEPTION_IF_NULL(expand_dims_node);
 
   expand_dims_node->set_scope(real_div_node->scope());

@@ -224,7 +225,8 @@ CNodePtr CreateExpandDims(const FuncGraphPtr &graph, const CNodePtr &real_div_no
   return expand_dims_node;
 }
 
-CNodePtr CreateExpandDimsPynative(const FuncGraphPtr &graph, const CNodePtr &real_div_node) {
+CNodePtr CreateExpandDimsPynative(const FuncGraphPtr &graph, const CNodePtr &real_div_node,
+                                  const PatternProcessPass &pass) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(real_div_node);
   CheckCNodeInputSize(real_div_node, kRealDivInputTensorNum);

@@ -236,7 +238,7 @@ CNodePtr CreateExpandDimsPynative(const FuncGraphPtr &graph, const CNodePtr &rea
   expand_dims_primitive->set_attr(kAttrInputNames, MakeValue(input_names));
   expand_dims_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));
   std::vector<AnfNodePtr> expand_dims_inputs = {NewValueNode(expand_dims_primitive), real_div_node};
-  auto expand_dims_node = graph->NewCNode(expand_dims_inputs);
+  auto expand_dims_node = pass.NewCNode(expand_dims_inputs, graph);
   MS_EXCEPTION_IF_NULL(expand_dims_node);
 
   expand_dims_node->set_scope(real_div_node->scope());

@@ -249,7 +251,7 @@ CNodePtr CreateExpandDimsPynative(const FuncGraphPtr &graph, const CNodePtr &rea
 }
 
 CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, const CNodePtr &mul_node,
-                    bool is_convert_const_to_attr = false) {
+                    const PatternProcessPass &pass, bool is_convert_const_to_attr = false) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(sparse_softmax_node);
   MS_EXCEPTION_IF_NULL(mul_node);

@@ -282,7 +284,7 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no
     tile_inputs = {NewValueNode(tile_primitive), mul_node->input(2), multiples_node};
   }
 
-  auto tile_node = graph->NewCNode(tile_inputs);
+  auto tile_node = pass.NewCNode(tile_inputs, graph);
   MS_EXCEPTION_IF_NULL(tile_node);
   tile_node->set_scope(mul_node->scope());
   AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetPrevNodeOutputInferDataType(mul_node, 1)}, {labels_shape},

@@ -297,7 +299,8 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no
   return tile_node;
 }
 
-CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, const AnfNodePtr &tile_node) {
+CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, const AnfNodePtr &tile_node,
+                       const PatternProcessPass &pass) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(sparse_softmax_node);
   MS_EXCEPTION_IF_NULL(tile_node);

@@ -320,7 +323,7 @@ CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax
   real_div_primitive->set_attr(kAttrInputNames, MakeValue(input_names));
   real_div_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));
   std::vector<AnfNodePtr> real_div_inputs = {NewValueNode(real_div_primitive), tile_node, y_node};
-  auto real_div_node = graph->NewCNode(real_div_inputs);
+  auto real_div_node = pass.NewCNode(real_div_inputs, graph);
   MS_EXCEPTION_IF_NULL(real_div_node);
 
   real_div_node->set_scope(sparse_softmax_node->scope());

@@ -345,7 +348,7 @@ CNodePtr GetDependNode(const CNodePtr &mul_node) {
 }
 
 CNodePtr CreateMul(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node,
-                   const AnfNodePtr &softmax_output_node) {
+                   const AnfNodePtr &softmax_output_node, const PatternProcessPass &pass) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(sparse_softmax_node);
   MS_EXCEPTION_IF_NULL(softmax_output_node);

@@ -377,7 +380,7 @@ CNodePtr CreateMul(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_nod
   mul_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));
 
   std::vector<AnfNodePtr> mul_input = {NewValueNode(mul_primitive), softmax_output_node, y_node};
-  auto mul_node = graph->NewCNode(mul_input);
+  auto mul_node = pass.NewCNode(mul_input, graph);
   MS_EXCEPTION_IF_NULL(mul_node);
 
   mul_node->set_scope(sparse_softmax_node->scope());
@@ -407,12 +410,12 @@ const AnfNodePtr SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(const F
   }
 
   CNodePtr softmax_node;
-  auto one_hot_node = CreateOneHot(graph, sparse_softmax_node);
-  softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node, one_hot_node);
+  auto one_hot_node = CreateOneHot(graph, sparse_softmax_node, *this);
+  softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node, one_hot_node, *this);
 
   std::vector<AnfNodePtr> softmax_node_outputs;
   CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
-  auto reduce_node = CreateReduceMean(graph, sparse_softmax_node, softmax_node_outputs[0]);
+  auto reduce_node = CreateReduceMean(graph, sparse_softmax_node, softmax_node_outputs[0], *this);
   return reduce_node;
 }
 

@@ -444,23 +447,23 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con
   CheckCNodeInputSize(sparse_softmax_node_grad, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
 
   CNodePtr softmax_node;
-  auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad);
-  softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node_grad, one_hot_node);
+  auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad, *this);
+  softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node_grad, one_hot_node, *this);
 
   std::vector<AnfNodePtr> softmax_node_outputs;
   CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
-  auto reduce_node = CreateReduceMean(graph, sparse_softmax_node_grad, softmax_node_outputs[0]);
-  auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node);
+  auto reduce_node = CreateReduceMean(graph, sparse_softmax_node_grad, softmax_node_outputs[0], *this);
+  auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node, *this);
   CNodePtr real_div_node;
   if (tile_node == nullptr) {
-    real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, mul_node->input(kIndex2));
+    real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, mul_node->input(kIndex2), *this);
   } else {
-    real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node);
+    real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node, *this);
   }
-  auto expand_dims_node = CreateExpandDims(graph, real_div_node);
+  auto expand_dims_node = CreateExpandDims(graph, real_div_node, *this);
   std::vector<AnfNodePtr> new_mul_inputs = {NewValueNode(std::make_shared<Primitive>(kMulOpName)),
                                             softmax_node_outputs[1], expand_dims_node};
-  auto new_mul_node = graph->NewCNode(new_mul_inputs);
+  auto new_mul_node = NewCNode(new_mul_inputs, graph);
   MS_EXCEPTION_IF_NULL(new_mul_node);
   new_mul_node->set_scope(mul_node->scope());
   new_mul_node->set_abstract(mul_node->abstract());

@@ -496,13 +499,13 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::Process(c
   auto sparse_softmax_node = GetSparseNode(depend_node, kIndex2);
 
   CNodePtr softmax_node;
-  auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad);
-  softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node_grad, one_hot_node);
+  auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad, *this);
+  softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node_grad, one_hot_node, *this);
 
   std::vector<AnfNodePtr> softmax_node_outputs;
   CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
-  auto reduce_node = CreateReduceMean(graph, sparse_softmax_node_grad, softmax_node_outputs[0]);
-  auto mul_node = CreateMul(graph, sparse_softmax_node_grad, softmax_node_outputs[1]);
+  auto reduce_node = CreateReduceMean(graph, sparse_softmax_node_grad, softmax_node_outputs[0], *this);
+  auto mul_node = CreateMul(graph, sparse_softmax_node_grad, softmax_node_outputs[1], *this);
 
   auto manager = graph->manager();
   MS_EXCEPTION_IF_NULL(manager);

@@ -521,8 +524,8 @@ const AnfNodePtr PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process
   CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
 
   CNodePtr softmax_node;
-  auto one_hot_node = CreateOneHot(graph, sparse_softmax_node, true);
-  softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node, one_hot_node);
+  auto one_hot_node = CreateOneHot(graph, sparse_softmax_node, *this, true);
+  softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node, one_hot_node, *this);
 
   std::vector<AnfNodePtr> softmax_node_outputs;
   CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);

@@ -532,7 +535,7 @@ const AnfNodePtr PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process
       AnfAlgo::GetNodeAttr<bool>(sparse_softmax_node, kAttrIsGrad)) {
     return softmax_node_outputs[1];
   } else {
-    auto reduce_node = CreateReduceMean(graph, sparse_softmax_node, softmax_node_outputs[0], true);
+    auto reduce_node = CreateReduceMean(graph, sparse_softmax_node, softmax_node_outputs[0], *this, true);
     return reduce_node;
   }
 }

@@ -561,22 +564,22 @@ const AnfNodePtr PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Pro
   CheckCNodeInputSize(sparse_softmax_node_grad, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
 
   CNodePtr softmax_node;
-  auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad);
-  softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node_grad, one_hot_node);
+  auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad, *this);
+  softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node_grad, one_hot_node, *this);
 
   std::vector<AnfNodePtr> softmax_node_outputs;
   CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
-  auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node);
+  auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node, *this);
   CNodePtr real_div_node;
   if (tile_node == nullptr) {
-    real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, mul_node->input(kIndex2));
+    real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, mul_node->input(kIndex2), *this);
   } else {
-    real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node);
+    real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node, *this);
   }
-  auto expand_dims_node = CreateExpandDimsPynative(graph, real_div_node);
+  auto expand_dims_node = CreateExpandDimsPynative(graph, real_div_node, *this);
   std::vector<AnfNodePtr> new_mul_inputs = {NewValueNode(std::make_shared<Primitive>(kMulOpName)),
                                             softmax_node_outputs[1], expand_dims_node};
-  auto new_mul_node = graph->NewCNode(new_mul_inputs);
+  auto new_mul_node = NewCNode(new_mul_inputs, graph);
   MS_EXCEPTION_IF_NULL(new_mul_node);
   new_mul_node->set_scope(mul_node->scope());
   new_mul_node->set_abstract(mul_node->abstract());
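The decomposition in this file relies on a standard identity: sparse softmax cross entropy with an integer label equals ordinary softmax cross entropy against the one-hot encoding of that label, which is why CreateOneHot plus CreateSoftmaxCrossEntropyWithLogits (followed by a mean over the batch) can replace the fused op. A small numeric check of the identity for one sample:

#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// Numeric toy: -log softmax(logits)[label] equals -sum_i onehot[i] * log softmax(logits)[i].
int main() {
  std::vector<double> logits = {2.0, 1.0, 0.1};
  int label = 0;
  double denom = 0.0;
  for (double l : logits) denom += std::exp(l);
  // Sparse form: pick out the labeled log-probability directly.
  double sparse = -(logits[label] - std::log(denom));
  // One-hot form: the same quantity via an explicit one-hot inner product.
  double onehot_form = 0.0;
  for (size_t i = 0; i < logits.size(); ++i) {
    double onehot = (static_cast<int>(i) == label) ? 1.0 : 0.0;
    onehot_form += -onehot * (logits[i] - std::log(denom));
  }
  std::cout << sparse << " == " << onehot_form << "\n";  // both print the same loss
}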
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.