!25744 IR fusion adapts dump flag

Merge pull request !25744 from yuchaojie/ir_fusion
Commit 10b63dffc0 by i-robot on 2021-11-12 01:14:20 +00:00, committed via Gitee
93 changed files with 1209 additions and 900 deletions
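
Every change in this diff follows the same recipe: helper functions that build replacement CNodes move out of file-local anonymous namespaces into the pass classes (or receive the calling pass as a parameter), and each direct func_graph->NewCNode(inputs) call becomes NewCNode(inputs, func_graph), a method inherited from the pass base class. Routing node creation through the pass lets it stamp pass-level metadata, per the commit title the dump flag, onto every node a fusion or fission pass creates. A minimal standalone sketch of the idea follows; Node, Graph, Pass, and the "dump" attribute key are illustrative stand-ins, not the MindSpore API:

#include <map>
#include <memory>
#include <string>
#include <vector>

struct Node {
  std::vector<std::shared_ptr<Node>> inputs;
  std::map<std::string, std::string> attrs;  // stand-in for node attributes
};
using NodePtr = std::shared_ptr<Node>;

struct Graph {
  // The old call sites used the graph factory directly, so new nodes
  // carried no information about which pass created them.
  NodePtr NewCNode(const std::vector<NodePtr> &inputs) {
    auto node = std::make_shared<Node>();
    node->inputs = inputs;
    return node;
  }
};

class Pass {
 public:
  explicit Pass(std::string name) : name_(std::move(name)) {}
  // The new call sites go through the pass, which can decorate every node
  // it creates, e.g. with a dump flag recording the creating pass.
  NodePtr NewCNode(const std::vector<NodePtr> &inputs, Graph *graph) const {
    NodePtr node = graph->NewCNode(inputs);
    node->attrs["dump"] = name_;  // hypothetical attribute key
    return node;
  }

 private:
  std::string name_;
};

This is why the helpers below cannot stay free functions in anonymous namespaces: they need access to the pass instance, either as this or as an explicit const PatternProcessPass & parameter, to reach its NewCNode.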

View File: mindspore/ccsrc/backend/optimizer/ascend/ir_fission/addn_fission.cc

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,16 +20,15 @@
namespace mindspore {
namespace opt {
namespace {
AnfNodePtr CreateNewAddn(const FuncGraphPtr &func_graph, const CNodePtr &origin_addn_cnode, size_t begin_index,
size_t offset) {
AnfNodePtr AddnFission::CreateNewAddn(const FuncGraphPtr &func_graph, const CNodePtr &origin_addn_cnode,
size_t begin_index, size_t offset) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(origin_addn_cnode);
std::vector<AnfNodePtr> new_addn_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimAddN->name()))};
for (size_t i = begin_index; i < begin_index + offset; ++i) {
new_addn_inputs.emplace_back(origin_addn_cnode->input(i));
}
CNodePtr new_addn = func_graph->NewCNode(new_addn_inputs);
CNodePtr new_addn = NewCNode(new_addn_inputs, func_graph);
MS_EXCEPTION_IF_NULL(new_addn);
new_addn->set_scope(origin_addn_cnode->scope());
new_addn->set_abstract(origin_addn_cnode->abstract());
@ -38,7 +37,6 @@ AnfNodePtr CreateNewAddn(const FuncGraphPtr &func_graph, const CNodePtr &origin_
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_addn);
return new_addn;
}
} // namespace
const BaseRef AddnFission::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
@ -68,7 +66,7 @@ const AnfNodePtr AddnFission::Process(const FuncGraphPtr &func_graph, const AnfN
for (size_t i = cur_input_index; i <= origin_input_size; i++) {
base_addn_inputs.emplace_back(new_cnode->input(i));
}
CNodePtr base_addn = func_graph->NewCNode(base_addn_inputs);
CNodePtr base_addn = NewCNode(base_addn_inputs, func_graph);
MS_EXCEPTION_IF_NULL(base_addn);
base_addn->set_scope(new_cnode->scope());
base_addn->set_abstract(new_cnode->abstract());
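
For context, this pass splits an AddN with many inputs into smaller AddN nodes of at most inputs_divisor_ operands each: CreateNewAddn builds one chunk from begin_index/offset, and Process folds the chunk results plus any leftover operands into a base node. A rough standalone sketch of that chunking logic, with plain integers standing in for graph operands:

#include <cstddef>
#include <numeric>
#include <vector>

// Stand-in: "adding" a chunk just sums it; the real pass emits an AddN CNode.
int MakeAddnChunk(const std::vector<int> &inputs, size_t begin, size_t offset) {
  return std::accumulate(inputs.begin() + begin, inputs.begin() + begin + offset, 0);
}

// Split a wide AddN into chunks of at most `divisor` operands, then fold the
// chunk results (plus any remainder) into one final base node, as Process does.
int SplitAddn(const std::vector<int> &inputs, size_t divisor) {
  std::vector<int> base_inputs;
  size_t i = 0;
  for (; i + divisor <= inputs.size(); i += divisor) {
    base_inputs.push_back(MakeAddnChunk(inputs, i, divisor));
  }
  for (; i < inputs.size(); ++i) {
    base_inputs.push_back(inputs[i]);  // leftover operands join the base node
  }
  return std::accumulate(base_inputs.begin(), base_inputs.end(), 0);
}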

View File: mindspore/ccsrc/backend/optimizer/ascend/ir_fission/addn_fission.h

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -30,6 +30,8 @@ class AddnFission : public PatternProcessPass {
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
AnfNodePtr CreateNewAddn(const FuncGraphPtr &func_graph, const CNodePtr &origin_addn_cnode, size_t begin_index,
size_t offset) const;
size_t inputs_divisor_;
};
} // namespace opt
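
The header half of the same move: CreateNewAddn is declared as a private const member instead of a file-local function. The underlying access rule, sketched with hypothetical names, is that a free helper in an anonymous namespace cannot see a protected member of the pass base class, while a member function can:

class PassBase {
 protected:
  int NewNode(int x) const { return x + 1; }  // stand-in for the node factory
};

// Before: a file-local helper has no access to PassBase::NewNode, so it had
// to call the graph factory directly.
// namespace {
// int CreateHelper(int x) { return x; }
// }  // namespace

// After: the helper is a private const member and can use the inherited
// factory, so every node it creates is routed through the pass.
class AddnFissionLike : public PassBase {
 public:
  int Run(int x) const { return CreateHelper(x); }

 private:
  int CreateHelper(int x) const { return NewNode(x); }
};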

View File: mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.cc

@ -58,8 +58,9 @@ bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, s
}
return output_num == kBatchNormRealOutputNum;
}
} // namespace
AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) {
AnfNodePtr BatchNormBertFission::CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(bn);
auto bn_cnode = bn->cast<CNodePtr>();
@ -70,7 +71,7 @@ AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodeP
}
std::vector<AnfNodePtr> bn_training_reduce_inputs = {
NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)), bn_cnode->input(kIndex1)};
auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs);
auto bn_training_reduce = NewCNode(bn_training_reduce_inputs, func_graph);
MS_EXCEPTION_IF_NULL(bn_training_reduce);
auto bn_input1 = bn_cnode->input(kIndex2);
MS_EXCEPTION_IF_NULL(bn_input1);
@ -84,8 +85,9 @@ AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodeP
return bn_training_reduce;
}
AnfNodePtr CreateBNTrainingUpdateV2(const FuncGraphPtr &func_graph, const AnfNodePtr &bn,
const std::vector<AnfNodePtr> &bn_training_reduce_outputs) {
AnfNodePtr BatchNormBertFission::CreateBNTrainingUpdateV2(
const FuncGraphPtr &func_graph, const AnfNodePtr &bn,
const std::vector<AnfNodePtr> &bn_training_reduce_outputs) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(bn);
auto bn_cnode = bn->cast<CNodePtr>();
@ -106,7 +108,7 @@ AnfNodePtr CreateBNTrainingUpdateV2(const FuncGraphPtr &func_graph, const AnfNod
bn_training_reduce_outputs[kIndex1],
bn_cnode->input(kIndex2),
bn_cnode->input(kIndex3)};
auto bn_training_update_v2 = func_graph->NewCNode(bn_training_update_v2_inputs);
auto bn_training_update_v2 = NewCNode(bn_training_update_v2_inputs, func_graph);
MS_EXCEPTION_IF_NULL(bn_training_update_v2);
auto bn_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn->abstract());
@ -124,7 +126,6 @@ AnfNodePtr CreateBNTrainingUpdateV2(const FuncGraphPtr &func_graph, const AnfNod
AnfAlgo::CopyNodeAttrs(bn, bn_training_update_v2);
return bn_training_update_v2;
}
} // namespace
const BaseRef BatchNormBertFission::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
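
As background for this file: the pass decomposes a training BatchNorm into a BNTrainingReduce node (stage 1, the per-channel reduction) whose two outputs feed a BNTrainingUpdateV2 node (stage 2) together with the BatchNorm's own scale and offset inputs, matching the input wiring visible above. A compact sketch of that two-stage wiring, using plain structs instead of AnfNodePtrs (output names are assumptions based on the Ascend operator names):

#include <utility>
#include <vector>

struct Tensor { /* placeholder payload */ };

// Stage 1: reduce over the input feature map (two outputs: sum, square_sum).
std::pair<Tensor, Tensor> BNTrainingReduce(const Tensor &x) { return {Tensor{}, Tensor{}}; }

// Stage 2: consumes the stage-1 outputs plus the BatchNorm's scale/offset.
std::vector<Tensor> BNTrainingUpdateV2(const Tensor &x, const Tensor &sum, const Tensor &square_sum,
                                       const Tensor &scale, const Tensor &offset) {
  return {Tensor{}, Tensor{}, Tensor{}};  // y, batch_mean, batch_variance
}

std::vector<Tensor> SplitBatchNorm(const Tensor &x, const Tensor &scale, const Tensor &offset) {
  auto reduce_outputs = BNTrainingReduce(x);
  return BNTrainingUpdateV2(x, reduce_outputs.first, reduce_outputs.second, scale, offset);
}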

View File: mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.h

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,6 +16,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BATCH_NORM_BERT_FISSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BATCH_NORM_BERT_FISSION_H_
#include <vector>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
@ -26,6 +27,11 @@ class BatchNormBertFission : public PatternProcessPass {
~BatchNormBertFission() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) const;
AnfNodePtr CreateBNTrainingUpdateV2(const FuncGraphPtr &func_graph, const AnfNodePtr &bn,
const std::vector<AnfNodePtr> &bn_training_reduce_outputs) const;
};
} // namespace opt
} // namespace mindspore

View File: mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.cc

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -79,7 +79,7 @@ AnfNodePtr BatchNormGradInferFission::CreateBNInferGrad(const FuncGraphPtr &func
std::vector<AnfNodePtr> bn_infer_grad_inputs = {
NewValueNode(std::make_shared<Primitive>(kBNInferGradOpName)), utils::cast<AnfNodePtr>(iter_input0->second),
utils::cast<AnfNodePtr>(iter_input2->second), utils::cast<AnfNodePtr>(iter_input4->second)};
auto bn_infer_grad = func_graph->NewCNode(bn_infer_grad_inputs);
auto bn_infer_grad = NewCNode(bn_infer_grad_inputs, func_graph);
MS_EXCEPTION_IF_NULL(bn_infer_grad);
// Set abstract, the output of new node is taking the place of the 0th output of bn_grad.
auto bn_grad_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn_grad->abstract());
@ -125,7 +125,7 @@ AnfNodePtr BatchNormGradInferFission::CreateBNTrainingUpdateGrad(const FuncGraph
NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateGradOpName)),
utils::cast<AnfNodePtr>(iter_input0->second), utils::cast<AnfNodePtr>(iter_input1->second),
utils::cast<AnfNodePtr>(iter_input3->second), utils::cast<AnfNodePtr>(iter_input4->second)};
auto bn_training_update_grad = func_graph->NewCNode(bn_training_update_grad_inputs);
auto bn_training_update_grad = NewCNode(bn_training_update_grad_inputs, func_graph);
MS_EXCEPTION_IF_NULL(bn_training_update_grad);
// Set abstract, the outputs of new node are taking the place of the 1st and 2nd outputs of bn_grad.
auto bn_grad_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn_grad->abstract());

View File: mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_split.cc

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -27,9 +27,8 @@
namespace mindspore {
namespace opt {
namespace {
void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
std::vector<AnfNodePtr> *bn_update_grad_outputs) {
void BatchNormGradSplit::CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
std::vector<AnfNodePtr> *bn_update_grad_outputs) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(bn_grad_node);
const auto &bn_grad_inputs = bn_grad_node->inputs();
@ -37,7 +36,7 @@ void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
std::vector<AnfNodePtr> bn_update_grad_inputs = {
NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateGradOpName)), bn_grad_inputs[kIndex1],
bn_grad_inputs[kIndex2], bn_grad_inputs[kIndex4], bn_grad_inputs[kIndex5]};
auto bn_update_grad = graph->NewCNode(bn_update_grad_inputs);
auto bn_update_grad = NewCNode(bn_update_grad_inputs, graph);
MS_EXCEPTION_IF_NULL(bn_update_grad);
bn_update_grad->set_kernel_info(std::make_shared<device::KernelInfo>());
bn_update_grad->set_scope(bn_grad_node->scope());
@ -50,9 +49,9 @@ void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
CreateMultipleOutputsOfAnfNode(graph, bn_update_grad, kBNTrainingUpdateGradOutputNum, bn_update_grad_outputs);
}
void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
const std::vector<AnfNodePtr> &bn_update_grad_outputs,
std::vector<AnfNodePtr> *bn_reduce_grad_outputs) {
void BatchNormGradSplit::CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
const std::vector<AnfNodePtr> &bn_update_grad_outputs,
std::vector<AnfNodePtr> *bn_reduce_grad_outputs) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(bn_grad_node);
MS_EXCEPTION_IF_NULL(bn_reduce_grad_outputs);
@ -71,7 +70,7 @@ void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
bn_grad_inputs[kIndex3],
bn_grad_inputs[kIndex4],
bn_grad_inputs[kIndex5]};
auto bn_reduce_grad = graph->NewCNode(bn_reduce_grad_inputs);
auto bn_reduce_grad = NewCNode(bn_reduce_grad_inputs, graph);
MS_EXCEPTION_IF_NULL(bn_reduce_grad);
bn_reduce_grad->set_kernel_info(std::make_shared<device::KernelInfo>());
bn_reduce_grad->set_scope(bn_grad_node->scope());
@ -83,7 +82,7 @@ void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_reduce_grad);
(*bn_reduce_grad_outputs).push_back(bn_reduce_grad);
}
} // namespace
const BaseRef BatchNormGradSplit::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
auto prim = std::make_shared<Primitive>(kBatchNormGradOpName);
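
Both helpers above end by calling CreateMultipleOutputsOfAnfNode(...), which expands a multi-output CNode into per-output accessor nodes so each result can be consumed individually downstream. A standalone sketch of that expansion with stand-in types (the real helper builds TupleGetItem CNodes and also infers their types and shapes):

#include <cstddef>
#include <memory>
#include <vector>

struct Node {
  std::shared_ptr<Node> source;  // multi-output node this accessor reads from
  size_t index = 0;              // which output it selects
};
using NodePtr = std::shared_ptr<Node>;

// Wrap output `i` of `node` in an accessor, like a TupleGetItem CNode.
NodePtr TupleGetItem(const NodePtr &node, size_t i) {
  auto item = std::make_shared<Node>();
  item->source = node;
  item->index = i;
  return item;
}

// Expand `output_num` outputs of one node into individual accessor nodes.
void CreateMultipleOutputs(const NodePtr &node, size_t output_num, std::vector<NodePtr> *outputs) {
  for (size_t i = 0; i < output_num; ++i) {
    outputs->push_back(TupleGetItem(node, i));
  }
}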

View File: mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_split.h

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,6 +16,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BATCH_NORM_GRAD_SPLIT_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BATCH_NORM_GRAD_SPLIT_H_
#include <vector>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/common/helper.h"
@ -27,6 +28,13 @@ class BatchNormGradSplit : public PatternProcessPass {
~BatchNormGradSplit() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
std::vector<AnfNodePtr> *bn_update_grad_outputs) const;
void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
const std::vector<AnfNodePtr> &bn_update_grad_outputs,
std::vector<AnfNodePtr> *bn_reduce_grad_outputs) const;
};
} // namespace opt
} // namespace mindspore

View File: mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bce_with_logits_loss_fission.cc

@ -26,8 +26,7 @@
namespace mindspore {
namespace opt {
namespace {
AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
AnfNodePtr BCEWithLogitsLossFission::AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
@ -36,7 +35,7 @@ AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node)
std::vector<AnfNodePtr> new_simoid_inputs = {
NewValueNode(std::make_shared<Primitive>(prim::kPrimBCEWithLogitsLoss->name()))};
new_simoid_inputs.insert(new_simoid_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
CNodePtr new_cnode = func_graph->NewCNode(new_simoid_inputs);
CNodePtr new_cnode = NewCNode(new_simoid_inputs, func_graph);
MS_EXCEPTION_IF_NULL(new_cnode);
auto predict_input = cnode->inputs()[kIndex1];
auto new_node_dtype = {AnfAlgo::GetOutputInferDataType(predict_input, 0)};
@ -55,7 +54,7 @@ AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node)
MS_LOG(INFO) << "Reduction attr is not mean or sum, can not do fission.";
return nullptr;
}
auto reduce_node = func_graph->NewCNode(reduce_inputs);
auto reduce_node = NewCNode(reduce_inputs, func_graph);
MS_EXCEPTION_IF_NULL(reduce_node);
auto type = AnfAlgo::GetOutputInferDataType(node, 0);
if (type == kNumberTypeFloat16) {
@ -69,7 +68,6 @@ AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node)
reduce_node->set_scope(cnode->scope());
return reduce_node;
}
} // namespace
const BaseRef BCEWithLogitsLossFission::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
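
For orientation in this file: a BCEWithLogitsLoss whose reduction attribute is "mean" or "sum" is split into a reduction-free loss node followed by an explicit reduce node; any other reduction value aborts the fission (the MS_LOG(INFO) branch above). A short sketch of that dispatch, with an illustrative enum in place of the real reduce primitives:

#include <iostream>
#include <optional>
#include <string>

enum class ReduceOp { kMean, kSum };

// Map the node's `reduction` attribute onto the follow-up reduce op;
// anything else means the pattern cannot be split.
std::optional<ReduceOp> SelectReduceOp(const std::string &reduction) {
  if (reduction == "mean") return ReduceOp::kMean;
  if (reduction == "sum") return ReduceOp::kSum;
  std::cout << "Reduction attr is not mean or sum, can not do fission.\n";
  return std::nullopt;
}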

View File: mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bce_with_logits_loss_fission.h

@ -28,6 +28,9 @@ class BCEWithLogitsLossFission : public PatternProcessPass {
~BCEWithLogitsLossFission() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const;
};
} // namespace opt
} // namespace mindspore

View File: mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.cc

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -28,9 +28,8 @@
namespace mindspore {
namespace opt {
namespace {
void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
std::vector<AnfNodePtr> *bn_update_grad_outputs) {
void BnGradSplit::CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
std::vector<AnfNodePtr> *bn_update_grad_outputs) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(bn_grad_node);
auto bn_grad_inputs = bn_grad_node->inputs();
@ -38,7 +37,7 @@ void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
std::vector<AnfNodePtr> bn_update_grad_inputs = {
NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateGradOpName)), bn_grad_inputs[kIndex1],
bn_grad_inputs[kIndex2], bn_grad_inputs[kIndex4], bn_grad_inputs[kIndex5]};
auto bn_update_grad = graph->NewCNode(bn_update_grad_inputs);
auto bn_update_grad = NewCNode(bn_update_grad_inputs, graph);
MS_EXCEPTION_IF_NULL(bn_update_grad);
bn_update_grad->set_kernel_info(std::make_shared<device::KernelInfo>());
bn_update_grad->set_scope(bn_grad_node->scope());
@ -51,9 +50,9 @@ void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
CreateMultipleOutputsOfAnfNode(graph, bn_update_grad, kBNTrainingUpdateGradOutputNum, bn_update_grad_outputs);
}
void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
const std::vector<AnfNodePtr> &bn_update_grad_outputs,
std::vector<AnfNodePtr> *bn_reduce_grad_outputs) {
void BnGradSplit::CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
const std::vector<AnfNodePtr> &bn_update_grad_outputs,
std::vector<AnfNodePtr> *bn_reduce_grad_outputs) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(bn_grad_node);
auto bn_grad_inputs = bn_grad_node->inputs();
@ -70,7 +69,7 @@ void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
bn_grad_inputs[kIndex3],
bn_grad_inputs[kIndex4],
bn_grad_inputs[kIndex5]};
auto bn_reduce_grad = graph->NewCNode(bn_reduce_grad_inputs);
auto bn_reduce_grad = NewCNode(bn_reduce_grad_inputs, graph);
MS_EXCEPTION_IF_NULL(bn_reduce_grad);
bn_reduce_grad->set_kernel_info(std::make_shared<device::KernelInfo>());
bn_reduce_grad->set_scope(bn_grad_node->scope());
@ -83,7 +82,7 @@ void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
(*bn_reduce_grad_outputs).push_back(bn_reduce_grad);
}
CNodePtr BNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
CNodePtr BnGradSplit::BNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const {
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> bn_update_grad_outputs;
CreateOutputsOfUpdateGrad(func_graph, cnode, &bn_update_grad_outputs);
@ -106,7 +105,7 @@ CNodePtr BNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &cnode
return make_tuple;
}
CNodePtr SyncBNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
CNodePtr SyncBnGradSplit::SyncBNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(cnode);
std::vector<AnfNodePtr> bn_update_grad_outputs;
@ -119,7 +118,7 @@ CNodePtr SyncBNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &c
std::vector<AnfNodePtr> allreduce_mul_outputs;
for (size_t i = 0; i < bn_update_grad_outputs.size(); ++i) {
auto allreduce_mul_output = CreateAllReduceAndMul(func_graph, bn_update_grad_outputs[i], cnode);
auto allreduce_mul_output = CreateAllReduceAndMul(func_graph, bn_update_grad_outputs[i], cnode, *this);
allreduce_mul_outputs.emplace_back(allreduce_mul_output);
}
@ -136,7 +135,6 @@ CNodePtr SyncBNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &c
MS_EXCEPTION_IF_NULL(make_tuple);
return make_tuple;
}
} // namespace
const BaseRef BnGradSplit::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();

View File: mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.h

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,6 +16,8 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BN_GRAD_SPLIT_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BN_GRAD_SPLIT_H_
#include <string>
#include <vector>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/common/helper.h"
@ -23,18 +25,31 @@ namespace mindspore {
namespace opt {
class BnGradSplit : public PatternProcessPass {
public:
explicit BnGradSplit(bool multigraph = true) : PatternProcessPass("bn_grad_split", multigraph) {}
explicit BnGradSplit(string name = "bn_grad_split", bool multigraph = true) : PatternProcessPass(name, multigraph) {}
~BnGradSplit() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
protected:
void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
std::vector<AnfNodePtr> *bn_update_grad_outputs) const;
void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
const std::vector<AnfNodePtr> &bn_update_grad_outputs,
std::vector<AnfNodePtr> *bn_reduce_grad_outputs) const;
private:
CNodePtr BNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const;
};
class SyncBnGradSplit : public PatternProcessPass {
class SyncBnGradSplit : public BnGradSplit {
public:
explicit SyncBnGradSplit(bool multigraph = true) : PatternProcessPass("sync_bn_grad_split", multigraph) {}
explicit SyncBnGradSplit(bool multigraph = true) : BnGradSplit("sync_bn_grad_split", multigraph) {}
~SyncBnGradSplit() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
CNodePtr SyncBNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const;
};
} // namespace opt
} // namespace mindspore
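
The header change above is the enabling trick for sharing the protected helpers: SyncBnGradSplit now derives from BnGradSplit instead of directly from PatternProcessPass, and the base constructor takes the registered pass name as a parameter so each derived pass keeps its own identity. The pattern, distilled into a standalone sketch:

#include <iostream>
#include <string>
#include <utility>

class PatternPass {
 public:
  explicit PatternPass(std::string name) : name_(std::move(name)) {}
  const std::string &name() const { return name_; }

 private:
  std::string name_;
};

class BnGradSplitLike : public PatternPass {
 public:
  // Default name keeps existing no-argument call sites working.
  explicit BnGradSplitLike(std::string name = "bn_grad_split") : PatternPass(std::move(name)) {}

 protected:
  void SharedHelper() const { std::cout << "helper reused by " << name() << "\n"; }
};

class SyncBnGradSplitLike : public BnGradSplitLike {
 public:
  // Forward a different registered name; inherit the protected helpers.
  SyncBnGradSplitLike() : BnGradSplitLike("sync_bn_grad_split") {}
};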

View File: mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.cc

@ -34,9 +34,10 @@ constexpr auto kReduceOpSum = "sum";
constexpr auto kDeviceNum = "device_num";
constexpr size_t kPositionOffset = 3;
constexpr int64_t kFusionNumThreshold = 2;
} // namespace
bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
std::vector<AnfNodePtr> *bn_training_reduce_outputs) {
bool BnSplit::CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
std::vector<AnfNodePtr> *bn_training_reduce_outputs) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(bn_cnode);
if (AnfAlgo::GetInputTensorNum(bn_cnode) != kBnInputTensorNum) {
@ -46,7 +47,7 @@ bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &
std::vector<AnfNodePtr> bn_training_reduce_inputs = {
NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName))};
bn_training_reduce_inputs.push_back(bn_cnode->input(kIndex1));
auto bn_training_reduce = graph->NewCNode(bn_training_reduce_inputs);
auto bn_training_reduce = NewCNode(bn_training_reduce_inputs, graph);
MS_EXCEPTION_IF_NULL(bn_training_reduce);
auto kernel_info = std::make_shared<device::KernelInfo>();
MS_EXCEPTION_IF_NULL(kernel_info);
@ -67,8 +68,8 @@ bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &
return true;
}
AnfNodePtr CreateOutputsOfBNTrainingUpdate(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
const std::vector<AnfNodePtr> &bn_training_reduce_outputs) {
AnfNodePtr BnSplit::CreateOutputsOfBNTrainingUpdate(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
const std::vector<AnfNodePtr> &bn_training_reduce_outputs) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(bn_cnode);
CheckCNodeInputSize(bn_cnode, kBnInputTensorNum);
@ -86,7 +87,7 @@ AnfNodePtr CreateOutputsOfBNTrainingUpdate(const FuncGraphPtr &graph, const CNod
bn_training_update_inputs.push_back(bn_cnode->input(kIndex3));
bn_training_update_inputs.push_back(bn_cnode->input(kIndex4));
bn_training_update_inputs.push_back(bn_cnode->input(kIndex5));
auto bn_training_update = graph->NewCNode(bn_training_update_inputs);
auto bn_training_update = NewCNode(bn_training_update_inputs, graph);
MS_EXCEPTION_IF_NULL(bn_training_update);
auto kernel_info = std::make_shared<device::KernelInfo>();
MS_EXCEPTION_IF_NULL(kernel_info);
@ -100,7 +101,7 @@ AnfNodePtr CreateOutputsOfBNTrainingUpdate(const FuncGraphPtr &graph, const CNod
return bn_training_update;
}
AnfNodePtr SplitBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
AnfNodePtr BnSplit::SplitBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
@ -125,7 +126,7 @@ AnfNodePtr SplitBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr
return CreateOutputsOfBNTrainingUpdate(func_graph, cnode, bn_training_reduce_outputs);
}
AnfNodePtr SyncBNSplitForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
AnfNodePtr SyncBnSplit::SyncBNSplitForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
@ -148,14 +149,13 @@ AnfNodePtr SyncBNSplitForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr &n
std::vector<AnfNodePtr> allreduce_mul_outputs;
for (size_t i = 0; i < bn_training_reduce_outputs.size(); ++i) {
auto allreduce_mul_output = CreateAllReduceAndMul(func_graph, bn_training_reduce_outputs[i], cnode);
auto allreduce_mul_output = CreateAllReduceAndMul(func_graph, bn_training_reduce_outputs[i], cnode, *this);
allreduce_mul_outputs.emplace_back(allreduce_mul_output);
}
// Create BNTrainingUpdate node
return CreateOutputsOfBNTrainingUpdate(func_graph, cnode, allreduce_mul_outputs);
}
} // namespace
AnfNodePtr CreateValueNodeOfDeviceNumReciprocal(const FuncGraphPtr &graph, const CNodePtr &sync_bn_cnode) {
MS_EXCEPTION_IF_NULL(graph);
@ -201,7 +201,7 @@ AnfNodePtr InsertCast(const FuncGraphPtr &graph, const AnfNodePtr &input, const
}
AnfNodePtr CreateAllReduceAndMul(const FuncGraphPtr &graph, const AnfNodePtr &allreduce_input,
const CNodePtr &sync_bn_cnode) {
const CNodePtr &sync_bn_cnode, const PatternProcessPass &pass) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(allreduce_input);
MS_EXCEPTION_IF_NULL(sync_bn_cnode);
@ -214,7 +214,7 @@ AnfNodePtr CreateAllReduceAndMul(const FuncGraphPtr &graph, const AnfNodePtr &al
// create AllReduce
std::vector<AnfNodePtr> allreduce_inputs = {NewValueNode(std::make_shared<Primitive>(kAllReduceOpName)), input_node};
auto allreduce = graph->NewCNode(allreduce_inputs);
auto allreduce = pass.NewCNode(allreduce_inputs, graph);
MS_EXCEPTION_IF_NULL(allreduce);
allreduce->set_abstract(input_node->abstract());
allreduce->set_scope(allreduce_input->scope());
@ -238,7 +238,7 @@ AnfNodePtr CreateAllReduceAndMul(const FuncGraphPtr &graph, const AnfNodePtr &al
auto device_num_reciprocal_vnode = CreateValueNodeOfDeviceNumReciprocal(graph, sync_bn_cnode);
std::vector<AnfNodePtr> mul_inputs = {NewValueNode(std::make_shared<Primitive>(kMulOpName)), allreduce,
device_num_reciprocal_vnode};
auto mul = graph->NewCNode(mul_inputs);
auto mul = pass.NewCNode(mul_inputs, graph);
MS_EXCEPTION_IF_NULL(mul);
mul->set_abstract(input_node->abstract());
mul->set_scope(allreduce_input->scope());
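
CreateAllReduceAndMul (like CreateValueNodeOfDeviceNumReciprocal) stays a free function because it is shared between bn_split.cc and bn_grad_split.cc; instead of becoming a member of one pass, it now receives the calling pass as const PatternProcessPass & and the call sites hand in *this. The shape of that dependency injection, sketched with stand-in types:

#include <string>
#include <utility>

struct Node { std::string created_by; };

class Pass {
 public:
  explicit Pass(std::string name) : name_(std::move(name)) {}
  // Stand-in for PatternProcessPass::NewCNode: tags nodes with the pass name.
  Node NewNode() const { return Node{name_}; }

 private:
  std::string name_;
};

// Shared free helper: cannot be a member of one pass, so the caller hands
// itself in and all node creation still flows through that pass's factory.
Node CreateAllReduceAndMulLike(const Pass &pass) {
  Node allreduce = pass.NewNode();
  Node mul = pass.NewNode();
  (void)allreduce;
  return mul;
}

// A call site inside a pass method reads: CreateAllReduceAndMulLike(*this);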

View File: mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.h

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,6 +16,8 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BN_SPLIT_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BN_SPLIT_H_
#include <string>
#include <vector>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/common/helper.h"
@ -23,24 +25,36 @@ namespace mindspore {
namespace opt {
class BnSplit : public PatternProcessPass {
public:
explicit BnSplit(bool multigraph = true) : PatternProcessPass("bn_split", multigraph) {}
explicit BnSplit(string name = "bn_split", bool multigraph = true) : PatternProcessPass(name, multigraph) {}
~BnSplit() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
protected:
bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
std::vector<AnfNodePtr> *bn_training_reduce_outputs) const;
AnfNodePtr CreateOutputsOfBNTrainingUpdate(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
const std::vector<AnfNodePtr> &bn_training_reduce_outputs) const;
private:
AnfNodePtr SplitBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const;
};
class SyncBnSplit : public PatternProcessPass {
class SyncBnSplit : public BnSplit {
public:
explicit SyncBnSplit(bool multigraph = true) : PatternProcessPass("sync_bn_split", multigraph) {}
explicit SyncBnSplit(bool multigraph = true) : BnSplit("sync_bn_split", multigraph) {}
~SyncBnSplit() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
AnfNodePtr SyncBNSplitForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const;
};
AnfNodePtr CreateValueNodeOfDeviceNumReciprocal(const FuncGraphPtr &graph, const CNodePtr &sync_bn_cnode);
AnfNodePtr CreateAllReduceAndMul(const FuncGraphPtr &graph, const AnfNodePtr &allreduce_input,
const CNodePtr &sync_bn_cnode);
const CNodePtr &sync_bn_cnode, const PatternProcessPass &pass);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BN_SPLIT_H_

View File: mindspore/ccsrc/backend/optimizer/ascend/ir_fission/cdist_fission.cc

@ -55,13 +55,13 @@ std::vector<size_t> CalCdistBroadCastShape(std::vector<size_t> x_shape, std::vec
}
AnfNodePtr AddBroadCastToNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, int64_t dim,
const std::vector<size_t> &need_shape) {
const std::vector<size_t> &need_shape, const PatternProcessPass &pass) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(input_node);
// Add ExpandDims Node
std::vector<AnfNodePtr> expand_dims_inputs = {
NewValueNode(std::make_shared<Primitive>(prim::kPrimExpandDims->name())), input_node};
auto expand_dims = func_graph->NewCNode(expand_dims_inputs);
auto expand_dims = pass.NewCNode(expand_dims_inputs, func_graph);
auto dtype = AnfAlgo::GetOutputInferDataType(input_node, 0);
auto expand_shape = AnfAlgo::GetOutputInferShape(input_node, 0);
(void)expand_shape.insert(expand_shape.end() + dim, 1);
@ -71,7 +71,7 @@ AnfNodePtr AddBroadCastToNode(const FuncGraphPtr &func_graph, const AnfNodePtr &
// Add BroadCastTo Node
std::vector<AnfNodePtr> broadcast_to_inputs = {
NewValueNode(std::make_shared<Primitive>(prim::kPrimBroadcastTo->name())), expand_dims};
auto broadcast_to = func_graph->NewCNode(broadcast_to_inputs);
auto broadcast_to = pass.NewCNode(broadcast_to_inputs, func_graph);
AnfAlgo::SetOutputInferTypeAndShape({dtype}, {need_shape}, broadcast_to.get());
std::vector<int64_t> shape;
(void)std::transform(need_shape.begin(), need_shape.end(), std::back_inserter(shape), LongToSize);
@ -110,11 +110,11 @@ const AnfNodePtr CdistFission::Process(const FuncGraphPtr &graph, const AnfNodeP
auto x_shape = AnfAlgo::GetOutputInferShape(cdist_inputs[kDim1], 0);
auto y_shape = AnfAlgo::GetOutputInferShape(cdist_inputs[kDim2], 0);
auto broadcast_to_shape = CalCdistBroadCastShape(x_shape, y_shape);
auto broadcast_input_x = AddBroadCastToNode(graph, cdist_inputs[kDim1], kInputXDimP, broadcast_to_shape);
auto broadcast_input_y = AddBroadCastToNode(graph, cdist_inputs[kDim2], kInputYDimR, broadcast_to_shape);
auto broadcast_input_x = AddBroadCastToNode(graph, cdist_inputs[kDim1], kInputXDimP, broadcast_to_shape, *this);
auto broadcast_input_y = AddBroadCastToNode(graph, cdist_inputs[kDim2], kInputYDimR, broadcast_to_shape, *this);
std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimCdist->name())),
broadcast_input_x, broadcast_input_y};
CNodePtr new_cnode = graph->NewCNode(new_inputs);
CNodePtr new_cnode = NewCNode(new_inputs, graph);
MS_EXCEPTION_IF_NULL(new_cnode);
new_cnode->set_abstract(cdist_cnode->abstract());
new_cnode->set_scope(cdist_cnode->scope());
@ -140,13 +140,13 @@ const AnfNodePtr CdistGradFission::Process(const FuncGraphPtr &graph, const AnfN
auto x_shape = AnfAlgo::GetOutputInferShape(cdist_grad_inputs[kDim2], 0);
auto y_shape = AnfAlgo::GetOutputInferShape(cdist_grad_inputs[kDim3], 0);
auto broadcast_to_shape = CalCdistBroadCastShape(x_shape, y_shape);
auto broadcast_grad = AddBroadCastToNode(graph, cdist_grad_inputs[kDim1], 0, broadcast_to_shape);
auto broadcast_input_x = AddBroadCastToNode(graph, cdist_grad_inputs[kDim2], kInputXDimP, broadcast_to_shape);
auto broadcast_input_y = AddBroadCastToNode(graph, cdist_grad_inputs[kDim3], kInputYDimR, broadcast_to_shape);
auto broadcast_out = AddBroadCastToNode(graph, cdist_grad_inputs[kDim4], 0, broadcast_to_shape);
auto broadcast_grad = AddBroadCastToNode(graph, cdist_grad_inputs[kDim1], 0, broadcast_to_shape, *this);
auto broadcast_input_x = AddBroadCastToNode(graph, cdist_grad_inputs[kDim2], kInputXDimP, broadcast_to_shape, *this);
auto broadcast_input_y = AddBroadCastToNode(graph, cdist_grad_inputs[kDim3], kInputYDimR, broadcast_to_shape, *this);
auto broadcast_out = AddBroadCastToNode(graph, cdist_grad_inputs[kDim4], 0, broadcast_to_shape, *this);
std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimCdistGrad->name())),
broadcast_grad, broadcast_input_x, broadcast_input_y, broadcast_out};
CNodePtr new_cnode = graph->NewCNode(new_inputs);
CNodePtr new_cnode = NewCNode(new_inputs, graph);
MS_EXCEPTION_IF_NULL(new_cnode);
new_cnode->set_abstract(cdist_grad_cnode->abstract());
new_cnode->set_scope(cdist_grad_cnode->scope());
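
The Cdist fission relies on CalCdistBroadCastShape to compute a common shape for the two operands before inserting ExpandDims + BroadcastTo pairs. A sketch of the usual right-aligned broadcast rule such a helper implements, under the assumption that the real helper additionally special-cases the two trailing Cdist dimensions:

#include <algorithm>
#include <cstddef>
#include <vector>

// Right-aligned broadcast of two shapes: pad the shorter one with leading 1s,
// then take the larger extent per axis (sizes must match or one must be 1).
std::vector<size_t> BroadcastShape(std::vector<size_t> x, std::vector<size_t> y) {
  if (x.size() < y.size()) std::swap(x, y);
  y.insert(y.begin(), x.size() - y.size(), 1);
  std::vector<size_t> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    out[i] = std::max(x[i], y[i]);  // assumes shapes are broadcast-compatible
  }
  return out;
}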

View File: mindspore/ccsrc/backend/optimizer/ascend/ir_fission/concat_fission.cc

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -21,16 +21,15 @@
namespace mindspore {
namespace opt {
namespace {
AnfNodePtr CreateNewConcat(const FuncGraphPtr &func_graph, const CNodePtr &origin_concat_cnode, size_t begin_index,
size_t offset) {
AnfNodePtr ConcatFission::CreateNewConcat(const FuncGraphPtr &func_graph, const CNodePtr &origin_concat_cnode,
size_t begin_index, size_t offset) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(origin_concat_cnode);
std::vector<AnfNodePtr> new_concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
for (size_t i = begin_index; i < begin_index + offset; ++i) {
new_concat_inputs.emplace_back(origin_concat_cnode->input(i));
}
CNodePtr new_concat = func_graph->NewCNode(new_concat_inputs);
CNodePtr new_concat = NewCNode(new_concat_inputs, func_graph);
MS_EXCEPTION_IF_NULL(new_concat);
new_concat->set_scope(origin_concat_cnode->scope());
// Set attrs
@ -66,7 +65,6 @@ AnfNodePtr CreateNewConcat(const FuncGraphPtr &func_graph, const CNodePtr &origi
new_concat.get());
return new_concat;
}
} // namespace
const BaseRef ConcatFission::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
@ -100,7 +98,7 @@ const AnfNodePtr ConcatFission::Process(const FuncGraphPtr &func_graph, const An
for (size_t i = cur_input_index; i <= origin_input_size; i++) {
base_concat_inputs.emplace_back(new_cnode->input(i));
}
CNodePtr base_concat = func_graph->NewCNode(base_concat_inputs);
CNodePtr base_concat = NewCNode(base_concat_inputs, func_graph);
MS_EXCEPTION_IF_NULL(base_concat);
base_concat->set_scope(new_cnode->scope());
base_concat->set_abstract(new_cnode->abstract());

View File: mindspore/ccsrc/backend/optimizer/ascend/ir_fission/concat_fission.h

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -30,6 +30,8 @@ class ConcatFission : public PatternProcessPass {
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
AnfNodePtr CreateNewConcat(const FuncGraphPtr &func_graph, const CNodePtr &origin_concat_cnode, size_t begin_index,
size_t offset) const;
size_t inputs_divisor_;
};
} // namespace opt

View File: mindspore/ccsrc/backend/optimizer/ascend/ir_fission/diag_fission.cc

@ -96,7 +96,7 @@ const AnfNodePtr DiagFission::Process(const FuncGraphPtr &graph, const AnfNodePt
auto assist_const = CreateAssistNode(graph, diag_cnode, input_shape);
(void)new_inputs.insert(new_inputs.end(), diag_cnode->inputs().begin() + 1, diag_cnode->inputs().end());
new_inputs.push_back(assist_const);
CNodePtr new_cnode = graph->NewCNode(new_inputs);
CNodePtr new_cnode = NewCNode(new_inputs, graph);
MS_EXCEPTION_IF_NULL(new_cnode);
new_cnode->set_abstract(diag_cnode->abstract());
new_cnode->set_scope(diag_cnode->scope());

View File: mindspore/ccsrc/backend/optimizer/ascend/ir_fission/diag_part_fission.cc

@ -48,7 +48,7 @@ const AnfNodePtr DiagPartFission::Process(const FuncGraphPtr &func_graph, const
(void)new_node_inputs.insert(new_node_inputs.end(), diag_part_cnode->inputs().begin() + 1,
diag_part_cnode->inputs().end());
new_node_inputs.push_back(assist_node);
CNodePtr new_cnode = func_graph->NewCNode(new_node_inputs);
CNodePtr new_cnode = NewCNode(new_node_inputs, func_graph);
MS_EXCEPTION_IF_NULL(new_cnode);
new_cnode->set_abstract(diag_part_cnode->abstract());
new_cnode->set_scope(diag_part_cnode->scope());

View File: mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_gru_v2_grad_fission.cc

@ -54,11 +54,14 @@ std::map<std::string, size_t> hidden_grad_input_index = {
std::map<std::string, size_t> hidden_grad_output_index = {
{"dh_prev", kIndex0}, {"dgate_h", kIndex1}, {"dnt_x", kIndex2}};
} // namespace
AnfNodePtr CreateGRUV2HiddenGradCellNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_gru_v2_grad_cnode,
const AnfNodePtr &last_gru_hidden_grad_node,
const AnfNodePtr &last_matmul_node, const std::string &gate_order,
const size_t cur_t) {
AnfNodePtr DynamicGRUV2GradFission::CreateGRUV2HiddenGradCellNode(const FuncGraphPtr &func_graph,
const CNodePtr &dynamic_gru_v2_grad_cnode,
const AnfNodePtr &last_gru_hidden_grad_node,
const AnfNodePtr &last_matmul_node,
const std::string &gate_order,
const size_t cur_t) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(dynamic_gru_v2_grad_cnode);
const auto &dynamic_gru_v2_grad_inputs = dynamic_gru_v2_grad_cnode->inputs();
@ -95,7 +98,7 @@ AnfNodePtr CreateGRUV2HiddenGradCellNode(const FuncGraphPtr &func_graph, const C
(void)gru_v2_hidden_grad_cell_inputs.emplace_back(dynamic_gru_v2_grad_inputs[input_index["reset"]]);
(void)gru_v2_hidden_grad_cell_inputs.emplace_back(dynamic_gru_v2_grad_inputs[input_index["new"]]);
(void)gru_v2_hidden_grad_cell_inputs.emplace_back(dynamic_gru_v2_grad_inputs[input_index["hidden_new"]]);
auto gru_v2_hidden_grad_cell_op = func_graph->NewCNode(gru_v2_hidden_grad_cell_inputs);
auto gru_v2_hidden_grad_cell_op = NewCNode(gru_v2_hidden_grad_cell_inputs, func_graph);
std::vector<size_t> dh_prev_shape =
AnfAlgo::GetOutputInferShape(dynamic_gru_grad_outputs[output_index["dh_prev"]], 0);
@ -108,8 +111,8 @@ AnfNodePtr CreateGRUV2HiddenGradCellNode(const FuncGraphPtr &func_graph, const C
return gru_v2_hidden_grad_cell_op;
}
void AddTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_gru_v2_grad_cnode,
std::vector<std::vector<AnfNodePtr>> *result_nodes) {
void DynamicGRUV2GradFission::AddTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_gru_v2_grad_cnode,
std::vector<std::vector<AnfNodePtr>> *result_nodes) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(dynamic_gru_v2_grad_cnode);
MS_EXCEPTION_IF_NULL(result_nodes);
@ -137,12 +140,12 @@ void AddTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_gru_v2
auto weight_hidden = dynamic_gru_v2_grad_inputs[input_index["weight_hidden"]];
std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
weight_hidden};
auto reshape = func_graph->NewCNode(reshape_inputs);
auto reshape = NewCNode(reshape_inputs, func_graph);
auto reshape_out_shape = {IntToSize(1), AnfAlgo::GetOutputInferShape(weight_hidden, 0)[0],
AnfAlgo::GetOutputInferShape(weight_hidden, 0)[1]};
AnfAlgo::SetOutputInferTypeAndShape({dh_dtype}, {reshape_out_shape}, reshape.get());
(void)matmul_inputs.emplace_back(reshape);
auto matmul_node = func_graph->NewCNode(matmul_inputs);
auto matmul_node = NewCNode(matmul_inputs, func_graph);
MS_EXCEPTION_IF_NULL(matmul_node);
std::vector<size_t> out_shape = {1, batch_size, hidden_size};
AnfAlgo::SetOutputInferTypeAndShape({dh_dtype}, {out_shape}, matmul_node.get());
@ -162,8 +165,9 @@ void AddTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_gru_v2
(void)result_nodes->emplace_back(matmul_nodes);
}
AnfNodePtr AddTConcatNode(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &gru_hidden_grad_nodes,
size_t concat_output_index) {
AnfNodePtr DynamicGRUV2GradFission::AddTConcatNode(const FuncGraphPtr &func_graph,
const std::vector<AnfNodePtr> &gru_hidden_grad_nodes,
size_t concat_output_index) const {
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
for (size_t i = 0; i < t_size; i++) {
@ -174,7 +178,7 @@ AnfNodePtr AddTConcatNode(const FuncGraphPtr &func_graph, const std::vector<AnfN
&gru_hidden_grad_node_outputs);
(void)concat_inputs.emplace_back(gru_hidden_grad_node_outputs[concat_output_index]);
}
auto concat_t_node = func_graph->NewCNode(concat_inputs);
auto concat_t_node = NewCNode(concat_inputs, func_graph);
auto out_dims = AnfAlgo::GetOutputInferShape(gru_hidden_grad_nodes[kIndex0], concat_output_index);
std::vector<size_t> concat_output_shape = {t_size, out_dims[kDim1], out_dims[kDim2]};
auto out_type = AnfAlgo::GetOutputInferDataType(gru_hidden_grad_nodes[kIndex0], concat_output_index);
@ -185,8 +189,8 @@ AnfNodePtr AddTConcatNode(const FuncGraphPtr &func_graph, const std::vector<AnfN
return concat_t_node;
}
std::vector<AnfNodePtr> AddGRUHiddenGradNode(const FuncGraphPtr &func_graph,
const CNodePtr &dynamic_gru_v2_grad_cnode) {
std::vector<AnfNodePtr> DynamicGRUV2GradFission::AddGRUHiddenGradNode(const FuncGraphPtr &func_graph,
const CNodePtr &dynamic_gru_v2_grad_cnode) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(dynamic_gru_v2_grad_cnode);
std::vector<AnfNodePtr> result;
@ -213,13 +217,14 @@ std::vector<AnfNodePtr> AddGRUHiddenGradNode(const FuncGraphPtr &func_graph,
return result;
}
AnfNodePtr AddHSplitNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_gru_v2_grad_cnode) {
AnfNodePtr DynamicGRUV2GradFission::AddHSplitNode(const FuncGraphPtr &func_graph,
const CNodePtr &dynamic_gru_v2_grad_cnode) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(dynamic_gru_v2_grad_cnode);
auto input_h = dynamic_gru_v2_grad_cnode->input(input_index["h"]);
std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
input_h};
auto split_v = func_graph->NewCNode(splitv_input);
auto split_v = NewCNode(splitv_input, func_graph);
// Set infer data type and shape
auto dtypes = {AnfAlgo::GetOutputInferDataType(input_h, 0), AnfAlgo::GetOutputInferDataType(input_h, 0)};
std::vector<size_t> output1_shape = {t_size - 1, batch_size, hidden_size};
@ -235,7 +240,7 @@ AnfNodePtr AddHSplitNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic
return split_v;
}
AnfNodePtr CreateHReshape(const FuncGraphPtr &graph, const AnfNodePtr &node) {
AnfNodePtr DynamicGRUV2GradFission::CreateHReshape(const FuncGraphPtr &graph, const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto ori_shape = AnfAlgo::GetOutputInferShape(node, 0);
@ -248,14 +253,15 @@ AnfNodePtr CreateHReshape(const FuncGraphPtr &graph, const AnfNodePtr &node) {
auto ori_dtype = {AnfAlgo::GetOutputInferDataType(node, 0)};
// reshape
std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())), node};
auto reshape = graph->NewCNode(reshape_input);
auto reshape = NewCNode(reshape_input, graph);
AnfAlgo::SetOutputInferTypeAndShape(ori_dtype, shape_tmp, reshape.get());
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reshape);
return reshape;
}
AnfNodePtr AddHConcatNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_gru_v2_grad_cnode,
const AnfNodePtr &splitv) {
AnfNodePtr DynamicGRUV2GradFission::AddHConcatNode(const FuncGraphPtr &func_graph,
const CNodePtr &dynamic_gru_v2_grad_cnode,
const AnfNodePtr &splitv) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(dynamic_gru_v2_grad_cnode);
MS_EXCEPTION_IF_NULL(splitv);
@ -270,7 +276,7 @@ AnfNodePtr AddHConcatNode(const FuncGraphPtr &func_graph, const CNodePtr &dynami
auto init_h_reshape = CreateHReshape(func_graph, dynamic_gru_v2_grad_cnode->input(input_index["init_h"]));
(void)concat_inputs.emplace_back(init_h_reshape);
(void)concat_inputs.emplace_back(splitv_outputs[kIndex0]);
auto concat = func_graph->NewCNode(concat_inputs);
auto concat = NewCNode(concat_inputs, func_graph);
// Set infer data type and shape
std::vector<size_t> output_shape = {t_size, batch_size, hidden_size};
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(init_h_reshape, 0)}, {output_shape},
@ -283,7 +289,8 @@ AnfNodePtr AddHConcatNode(const FuncGraphPtr &func_graph, const CNodePtr &dynami
return concat;
}
AnfNodePtr AddDwhMatmulNode(const FuncGraphPtr &func_graph, const AnfNodePtr &dgate_h, const AnfNodePtr &node) {
AnfNodePtr DynamicGRUV2GradFission::AddDwhMatmulNode(const FuncGraphPtr &func_graph, const AnfNodePtr &dgate_h,
const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(dgate_h);
MS_EXCEPTION_IF_NULL(node);
@ -297,7 +304,7 @@ AnfNodePtr AddDwhMatmulNode(const FuncGraphPtr &func_graph, const AnfNodePtr &dg
} else {
(void)matmul_inputs.emplace_back(dgate_h);
}
auto batch_matmul = func_graph->NewCNode(matmul_inputs);
auto batch_matmul = NewCNode(matmul_inputs, func_graph);
std::vector<size_t> shape = {t_size, hidden_size, kGateNum * hidden_size};
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {shape}, batch_matmul.get());
AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(true), batch_matmul);
@ -306,7 +313,8 @@ AnfNodePtr AddDwhMatmulNode(const FuncGraphPtr &func_graph, const AnfNodePtr &dg
return batch_matmul;
}
AnfNodePtr CreateDgateHSplitVDNode(const FuncGraphPtr &func_graph, const AnfNodePtr &dgate_h) {
AnfNodePtr DynamicGRUV2GradFission::CreateDgateHSplitVDNode(const FuncGraphPtr &func_graph,
const AnfNodePtr &dgate_h) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(dgate_h);
std::vector<AnfNodePtr> splitvd_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
@ -317,7 +325,7 @@ AnfNodePtr CreateDgateHSplitVDNode(const FuncGraphPtr &func_graph, const AnfNode
} else {
(void)splitvd_input.emplace_back(dgate_h);
}
auto split_vd = func_graph->NewCNode(splitvd_input);
auto split_vd = NewCNode(splitvd_input, func_graph);
auto dtypes = {AnfAlgo::GetOutputInferDataType(dgate_h, 0), AnfAlgo::GetOutputInferDataType(dgate_h, 0)};
std::vector<size_t> shape = {t_size, batch_size, hidden_size << 1};
std::vector<size_t> shape2 = {t_size, batch_size, hidden_size};
@ -331,7 +339,8 @@ AnfNodePtr CreateDgateHSplitVDNode(const FuncGraphPtr &func_graph, const AnfNode
return split_vd;
}
AnfNodePtr CreateDgateXConcatDNode(const FuncGraphPtr &func_graph, const AnfNodePtr &split, const AnfNodePtr &dnt_x) {
AnfNodePtr DynamicGRUV2GradFission::CreateDgateXConcatDNode(const FuncGraphPtr &func_graph, const AnfNodePtr &split,
const AnfNodePtr &dnt_x) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(split);
MS_EXCEPTION_IF_NULL(dnt_x);
@ -346,7 +355,7 @@ AnfNodePtr CreateDgateXConcatDNode(const FuncGraphPtr &func_graph, const AnfNode
} else {
(void)concat_inputs.emplace_back(dnt_x);
}
auto concat_op = func_graph->NewCNode(concat_inputs);
auto concat_op = NewCNode(concat_inputs, func_graph);
std::vector<size_t> shape = {t_size, batch_size, kGateNum * hidden_size};
auto types = {AnfAlgo::GetOutputInferDataType(dnt_x, 0)};
AnfAlgo::SetOutputInferTypeAndShape(types, {shape}, concat_op.get());
@ -357,14 +366,15 @@ AnfNodePtr CreateDgateXConcatDNode(const FuncGraphPtr &func_graph, const AnfNode
return concat_op;
}
AnfNodePtr CreateDwxBatchMatMul(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
AnfNodePtr DynamicGRUV2GradFission::CreateDwxBatchMatMul(const FuncGraphPtr &graph, const AnfNodePtr &node1,
const AnfNodePtr &node2) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node1);
MS_EXCEPTION_IF_NULL(node2);
// BatchMatMul
std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name())),
node1, node2};
auto batch_matmul = graph->NewCNode(matmul_inputs);
auto batch_matmul = NewCNode(matmul_inputs, graph);
MS_EXCEPTION_IF_NULL(batch_matmul);
std::vector<size_t> shape = {t_size, input_size, kGateNum * hidden_size};
AnfAlgo::SetOutputInferTypeAndShape({dh_dtype}, {shape}, batch_matmul.get());
@ -374,15 +384,15 @@ AnfNodePtr CreateDwxBatchMatMul(const FuncGraphPtr &graph, const AnfNodePtr &nod
return batch_matmul;
}
AnfNodePtr CreateDxtBatchMatMul(const FuncGraphPtr &func_graph, const AnfNodePtr &dgate_concat,
const AnfNodePtr &weight_input, const AnfNodePtr &dx) {
AnfNodePtr DynamicGRUV2GradFission::CreateDxtBatchMatMul(const FuncGraphPtr &func_graph, const AnfNodePtr &dgate_concat,
const AnfNodePtr &weight_input, const AnfNodePtr &dx) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(dgate_concat);
MS_EXCEPTION_IF_NULL(weight_input);
MS_EXCEPTION_IF_NULL(dx);
std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name())),
dgate_concat, weight_input};
auto batch_matmul = func_graph->NewCNode(matmul_inputs);
auto batch_matmul = NewCNode(matmul_inputs, func_graph);
MS_EXCEPTION_IF_NULL(batch_matmul);
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dx, 0)}, {AnfAlgo::GetOutputInferShape(dx, 0)},
batch_matmul.get());
@ -392,12 +402,12 @@ AnfNodePtr CreateDxtBatchMatMul(const FuncGraphPtr &func_graph, const AnfNodePtr
return batch_matmul;
}
AnfNodePtr CreateWBroadcastToDNode(const FuncGraphPtr &graph, const AnfNodePtr &node) {
AnfNodePtr DynamicGRUV2GradFission::CreateWBroadcastToDNode(const FuncGraphPtr &graph, const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
// BroadcastTo
std::vector<AnfNodePtr> braodcast_to_input = {NewValueNode(std::make_shared<Primitive>(kBroadcastToOpName)), node};
auto broadcast_to_d = graph->NewCNode(braodcast_to_input);
auto broadcast_to_d = NewCNode(braodcast_to_input, graph);
std::vector<size_t> shape = {t_size, input_size, kGateNum * hidden_size};
auto type = {AnfAlgo::GetOutputInferDataType(node, 0)};
AnfAlgo::SetOutputInferTypeAndShape(type, {shape}, broadcast_to_d.get());
@ -407,14 +417,15 @@ AnfNodePtr CreateWBroadcastToDNode(const FuncGraphPtr &graph, const AnfNodePtr &
return broadcast_to_d;
}
AnfNodePtr CreateDwReduceSumDNode(const FuncGraphPtr &graph, const AnfNodePtr &matmul, const AnfNodePtr &gru_grad) {
AnfNodePtr DynamicGRUV2GradFission::CreateDwReduceSumDNode(const FuncGraphPtr &graph, const AnfNodePtr &matmul,
const AnfNodePtr &gru_grad) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(matmul);
MS_EXCEPTION_IF_NULL(gru_grad);
// ReduceSumD for dw_x and dw_h
std::vector<AnfNodePtr> reducesum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
matmul};
auto reduce_sumd = graph->NewCNode(reducesum_inputs);
auto reduce_sumd = NewCNode(reducesum_inputs, graph);
auto types = {AnfAlgo::GetOutputInferDataType(gru_grad, 0)};
auto shapes = {AnfAlgo::GetOutputInferShape(gru_grad, 0)};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, reduce_sumd.get());
@ -424,14 +435,15 @@ AnfNodePtr CreateDwReduceSumDNode(const FuncGraphPtr &graph, const AnfNodePtr &m
return reduce_sumd;
}
AnfNodePtr CreateDbReduceSumDNode(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &node2) {
AnfNodePtr DynamicGRUV2GradFission::CreateDbReduceSumDNode(const FuncGraphPtr &graph, const AnfNodePtr &node,
const AnfNodePtr &node2) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node2);
// ReduceSumD for db_x and db_h
std::vector<AnfNodePtr> reducesum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
node};
auto reduce_sumd = graph->NewCNode(reducesum_inputs);
auto reduce_sumd = NewCNode(reducesum_inputs, graph);
MS_EXCEPTION_IF_NULL(reduce_sumd);
std::vector<size_t> shape = {kGateNum * hidden_size};
auto types = {AnfAlgo::GetOutputInferDataType(node2, 0)};
@ -441,7 +453,6 @@ AnfNodePtr CreateDbReduceSumDNode(const FuncGraphPtr &graph, const AnfNodePtr &n
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sumd);
return reduce_sumd;
}
} // namespace
const BaseRef DynamicGRUV2GradFission::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
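
The structure worth seeing in this file: AddTLoopNode unrolls the recurrence over t_size steps, where each step's hidden-grad cell consumes the previous step's cell and matmul outputs (last_gru_hidden_grad_node / last_matmul_node above), and AddTConcatNode then concatenates one chosen output across all steps. The control flow, reduced to a standalone sketch with stand-in types:

#include <cstddef>
#include <vector>

struct Step { int cell = 0; int matmul = 0; };  // stand-ins for the two CNodes

Step MakeStep(const Step *prev, size_t t) {
  // Each unrolled step reads the previous step's outputs (none at t == 0),
  // mirroring last_gru_hidden_grad_node / last_matmul_node in the pass.
  int base = prev ? prev->cell + prev->matmul : 0;
  return Step{base + static_cast<int>(t), base - static_cast<int>(t)};
}

std::vector<int> UnrollAndConcat(size_t t_size) {
  std::vector<Step> steps;
  for (size_t t = 0; t < t_size; ++t) {
    const Step *prev = steps.empty() ? nullptr : &steps.back();
    steps.push_back(MakeStep(prev, t));
  }
  std::vector<int> concat;  // AddTConcatNode: gather one output across all t
  for (const Step &s : steps) concat.push_back(s.cell);
  return concat;
}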

View File: mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_gru_v2_grad_fission.h

@ -17,6 +17,7 @@
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_GRU_V2_GRAD_FISSION_H_
#include <vector>
#include <string>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
@ -28,6 +29,33 @@ class DynamicGRUV2GradFission : public PatternProcessPass {
~DynamicGRUV2GradFission() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
AnfNodePtr CreateGRUV2HiddenGradCellNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_gru_v2_grad_cnode,
const AnfNodePtr &last_gru_hidden_grad_node,
const AnfNodePtr &last_matmul_node, const std::string &gate_order,
const size_t cur_t) const;
void AddTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_gru_v2_grad_cnode,
std::vector<std::vector<AnfNodePtr>> *result_nodes) const;
AnfNodePtr AddTConcatNode(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &gru_hidden_grad_nodes,
size_t concat_output_index) const;
std::vector<AnfNodePtr> AddGRUHiddenGradNode(const FuncGraphPtr &func_graph,
const CNodePtr &dynamic_gru_v2_grad_cnode) const;
AnfNodePtr AddHSplitNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_gru_v2_grad_cnode) const;
AnfNodePtr CreateHReshape(const FuncGraphPtr &graph, const AnfNodePtr &node) const;
AnfNodePtr AddHConcatNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_gru_v2_grad_cnode,
const AnfNodePtr &splitv) const;
AnfNodePtr AddDwhMatmulNode(const FuncGraphPtr &func_graph, const AnfNodePtr &dgate_h, const AnfNodePtr &node) const;
AnfNodePtr CreateDgateHSplitVDNode(const FuncGraphPtr &func_graph, const AnfNodePtr &dgate_h) const;
AnfNodePtr CreateDgateXConcatDNode(const FuncGraphPtr &func_graph, const AnfNodePtr &split,
const AnfNodePtr &dnt_x) const;
AnfNodePtr CreateDwxBatchMatMul(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) const;
AnfNodePtr CreateDxtBatchMatMul(const FuncGraphPtr &func_graph, const AnfNodePtr &dgate_concat,
const AnfNodePtr &weight_input, const AnfNodePtr &dx) const;
AnfNodePtr CreateWBroadcastToDNode(const FuncGraphPtr &graph, const AnfNodePtr &node) const;
AnfNodePtr CreateDwReduceSumDNode(const FuncGraphPtr &graph, const AnfNodePtr &matmul,
const AnfNodePtr &gru_grad) const;
AnfNodePtr CreateDbReduceSumDNode(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &node2) const;
};
} // namespace opt
} // namespace mindspore

View File: mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.cc

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -34,9 +34,10 @@ constexpr int64_t kAttrAxis2Value = 2;
constexpr int64_t kAttrNumSplitValue = 2;
constexpr int64_t kAttrSplitDimValue = 2;
constexpr size_t kDimMultiNum = 4;
} // namespace
void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
std::vector<std::vector<AnfNodePtr>> *result_nodes) {
void DynamicRnnGradFissionV2::CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
std::vector<std::vector<AnfNodePtr>> *result_nodes) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
MS_EXCEPTION_IF_NULL(result_nodes);
@ -52,7 +53,7 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
// Create basic_lstm_cell_c_state_grad
std::vector<AnfNodePtr> basic_lstm_cell_c_state_grad_inputs = {
NewValueNode(std::make_shared<Primitive>(kBasicLSTMCellCStateGradV2OpName))};
auto basic_lstm_cell_c_state_grad = func_graph->NewCNode(basic_lstm_cell_c_state_grad_inputs);
auto basic_lstm_cell_c_state_grad = NewCNode(basic_lstm_cell_c_state_grad_inputs, func_graph);
std::vector<size_t> output0_dims{
origin_input9_shape[kDim0],
@ -66,7 +67,7 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
// Create matmul
auto origin_input1_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex2), 0);
std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name()))};
auto matmul = func_graph->NewCNode(matmul_inputs);
auto matmul = NewCNode(matmul_inputs, func_graph);
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {{IntToSize(1), output0_dims[0], origin_input1_shape[0]}},
matmul.get());
AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(false), matmul);
@ -74,7 +75,7 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
// Create split
std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
auto split_v = func_graph->NewCNode(splitv_input);
auto split_v = NewCNode(splitv_input, func_graph);
auto origin_output2_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, kIndex2);
auto origin_output3_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, kIndex3);
std::vector<size_t> split_v_output0_shape{IntToSize(1), origin_output2_shape[kDim1], origin_output2_shape[kDim2]};
@ -99,13 +100,13 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
result_nodes->emplace_back(split_nodes);
}
AnfNodePtr CreateLSTMSPlitV(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
const std::vector<std::vector<size_t>> &split_shapes,
const std::vector<TypeId> &split_types, const std::vector<int64_t> &size_split,
size_t num_split_x) {
AnfNodePtr DynamicRnnGradFissionV2::CreateLSTMSPlitV(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
const std::vector<std::vector<size_t>> &split_shapes,
const std::vector<TypeId> &split_types,
const std::vector<int64_t> &size_split, size_t num_split_x) const {
std::vector<AnfNodePtr> lstm_split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
input};
auto lstm_split = func_graph->NewCNode(lstm_split_input);
auto lstm_split = NewCNode(lstm_split_input, func_graph);
AnfAlgo::SetOutputInferTypeAndShape(split_types, split_shapes, lstm_split.get());
AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(size_split), lstm_split);
AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(0)), lstm_split);
@ -113,76 +114,27 @@ AnfNodePtr CreateLSTMSPlitV(const FuncGraphPtr &func_graph, const AnfNodePtr &in
return lstm_split;
}
AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
std::vector<AnfNodePtr> *outputs) {
std::vector<std::vector<AnfNodePtr>> result_nodes;
CreateTLoopNode(func_graph, dynamic_rnn_grad_cnode, &result_nodes);
auto origin_input5_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex6), 0);
std::vector<size_t> split_c_dims{IntToSize(1), origin_input5_shape[0], origin_input5_shape[1]};
auto origin_input7 = dynamic_rnn_grad_cnode->input(kIndex8);
size_t num_split_x = AnfAlgo::GetOutputInferShape(origin_input7, 0)[0];
std::vector<std::vector<size_t>> split_shapes;
std::vector<TypeId> split_types;
std::vector<int64_t> size_split;
for (size_t i = 0; i < num_split_x; ++i) {
split_shapes.emplace_back(split_c_dims);
split_types.emplace_back(kNumberTypeFloat32);
size_split.emplace_back(1);
}
// Create lstm_split_c
auto lstm_split_c = CreateLSTMSPlitV(func_graph, origin_input7, split_shapes, split_types, size_split, num_split_x);
std::vector<AnfNodePtr> lstm_split_c_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_c, num_split_x, &lstm_split_c_outputs);
// Create lstm_split_dy
auto lstm_split_dy = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex9), split_shapes, split_types,
size_split, num_split_x);
std::vector<AnfNodePtr> lstm_split_dy_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_dy, num_split_x, &lstm_split_dy_outputs);
// Create lstm_split_i
auto lstm_split_i = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex12), split_shapes, split_types,
size_split, num_split_x);
std::vector<AnfNodePtr> lstm_split_i_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_i, num_split_x, &lstm_split_i_outputs);
// Create lstm_split_j
auto lstm_split_j = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex13), split_shapes, split_types,
size_split, num_split_x);
std::vector<AnfNodePtr> lstm_split_j_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_j, num_split_x, &lstm_split_j_outputs);
// Create lstm_split_f
auto lstm_split_f = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex14), split_shapes, split_types,
size_split, num_split_x);
std::vector<AnfNodePtr> lstm_split_f_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_f, num_split_x, &lstm_split_f_outputs);
// Create lstm_split_o
auto lstm_split_o = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex15), split_shapes, split_types,
size_split, num_split_x);
std::vector<AnfNodePtr> lstm_split_o_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_o, num_split_x, &lstm_split_o_outputs);
// Create lstm_split_tanh
auto lstm_split_tanh = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex16), split_shapes,
split_types, size_split, num_split_x);
std::vector<AnfNodePtr> lstm_split_tanh_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_tanh, num_split_x, &lstm_split_tanh_outputs);
// Add edges
void DynamicRnnGradFissionV2::CreateTLoopNodeWithEdge(const FuncGraphPtr &func_graph,
const CNodePtr &dynamic_rnn_grad_cnode,
const std::vector<std::vector<AnfNodePtr>> &result_nodes,
size_t num_split_x,
std::vector<std::vector<AnfNodePtr>> *loop_node_outputs) const {
auto &basic_lstm_cell_c_state_grad_nodes = result_nodes[kIndex0];
auto &matmul_nodes = result_nodes[kIndex1];
auto &split_nodes = result_nodes[kIndex2];
auto &lstm_split_c_outputs = result_nodes[kIndex3];
auto &lstm_split_dy_outputs = result_nodes[kIndex4];
auto &lstm_split_i_outputs = result_nodes[kIndex5];
auto &lstm_split_j_outputs = result_nodes[kIndex6];
auto &lstm_split_f_outputs = result_nodes[kIndex7];
auto &lstm_split_o_outputs = result_nodes[kIndex8];
auto &lstm_split_tanh_outputs = result_nodes[kIndex9];
std::vector<AnfNodePtr> pre_basic_lstm_cell_c_state_grad_outputs;
std::vector<AnfNodePtr> pre_split_outputs;
auto basic_lstm_cell_c_state_grad_nodes = result_nodes[kIndex0];
auto matmul_nodes = result_nodes[kIndex1];
auto split_nodes = result_nodes[kIndex2];
std::vector<AnfNodePtr> lstm_x_concat_input(num_split_x + 1);
lstm_x_concat_input[0] = NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()));
std::vector<AnfNodePtr> lstm_gage_concat_input(num_split_x + 1);
lstm_gage_concat_input[0] = NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()));
for (size_t i = 0; i < num_split_x; ++i) {
size_t idx = num_split_x - i - 1;
// Create basic_lstm_cell_c_state_grad
@ -191,7 +143,7 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
if (i == num_split_x - 1) {
std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
dynamic_rnn_grad_cnode->input(6)};
auto reshape = func_graph->NewCNode(reshape_inputs);
auto reshape = NewCNode(reshape_inputs, func_graph);
auto reshape_out_shape = {IntToSize(1),
AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex6), 0)[0],
AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex6), 0)[1]};
@ -213,7 +165,7 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
(void)basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_f_outputs[idx]);
(void)basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_o_outputs[idx]);
(void)basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_tanh_outputs[idx]);
auto basic_lstm_cell_c_state_grad = func_graph->NewCNode(basic_lstm_cell_c_state_grad_inputs);
auto basic_lstm_cell_c_state_grad = NewCNode(basic_lstm_cell_c_state_grad_inputs, func_graph);
MS_EXCEPTION_IF_NULL(basic_lstm_cell_c_state_grad);
basic_lstm_cell_c_state_grad->set_abstract(basic_lstm_cell_c_state_grad_nodes[i]->abstract());
AnfAlgo::CopyNodeAttrs(basic_lstm_cell_c_state_grad_nodes[i], basic_lstm_cell_c_state_grad);
@ -227,7 +179,7 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name()))};
(void)matmul_inputs.emplace_back(basic_lstm_cell_c_state_grad_outputs[0]);
(void)matmul_inputs.emplace_back(dynamic_rnn_grad_cnode->input(kIndex2));
auto matmul = func_graph->NewCNode(matmul_inputs);
auto matmul = NewCNode(matmul_inputs, func_graph);
MS_EXCEPTION_IF_NULL(matmul);
matmul->set_abstract(matmul_nodes[i]->abstract());
AnfAlgo::CopyNodeAttrs(matmul_nodes[i], matmul);
@ -235,7 +187,7 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
// Create splitv
std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
matmul};
auto split_v = func_graph->NewCNode(splitv_input);
auto split_v = NewCNode(splitv_input, func_graph);
MS_EXCEPTION_IF_NULL(split_v);
split_v->set_abstract(split_nodes[i]->abstract());
AnfAlgo::CopyNodeAttrs(split_nodes[i], split_v);
@ -258,14 +210,94 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
}
std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
basic_lstm_cell_c_state_grad_outputs[0]};
auto reshape = func_graph->NewCNode(reshape_input);
auto reshape = NewCNode(reshape_input, func_graph);
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(basic_lstm_cell_c_state_grad_outputs[0], 0)},
{temp_shape}, reshape.get());
lstm_gage_concat_input[idx + 1] = reshape;
}
loop_node_outputs->push_back(pre_basic_lstm_cell_c_state_grad_outputs);
loop_node_outputs->push_back(pre_split_outputs);
loop_node_outputs->push_back(lstm_x_concat_input);
loop_node_outputs->push_back(lstm_gage_concat_input);
}
AnfNodePtr DynamicRnnGradFissionV2::AddLSTMInputGradNode(const FuncGraphPtr &func_graph,
const CNodePtr &dynamic_rnn_grad_cnode,
std::vector<AnfNodePtr> *outputs) const {
std::vector<std::vector<AnfNodePtr>> result_nodes;
CreateTLoopNode(func_graph, dynamic_rnn_grad_cnode, &result_nodes);
auto origin_input5_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex6), 0);
std::vector<size_t> split_c_dims{IntToSize(1), origin_input5_shape[0], origin_input5_shape[1]};
auto origin_input7 = dynamic_rnn_grad_cnode->input(kIndex8);
size_t num_split_x = AnfAlgo::GetOutputInferShape(origin_input7, 0)[0];
std::vector<std::vector<size_t>> split_shapes;
std::vector<TypeId> split_types;
std::vector<int64_t> size_split;
for (size_t i = 0; i < num_split_x; ++i) {
split_shapes.emplace_back(split_c_dims);
split_types.emplace_back(kNumberTypeFloat32);
size_split.emplace_back(1);
}
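  // Each of the seven recurrent inputs (c, dy, i, j, f, o, tanh) is split along the time
  // axis into num_split_x single-timestep slices below; slice idx of every split feeds one
  // iteration of the backward loop in CreateTLoopNodeWithEdge.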
// Create lstm_split_c
auto lstm_split_c = CreateLSTMSPlitV(func_graph, origin_input7, split_shapes, split_types, size_split, num_split_x);
std::vector<AnfNodePtr> lstm_split_c_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_c, num_split_x, &lstm_split_c_outputs);
result_nodes.push_back(lstm_split_c_outputs);
// Create lstm_split_dy
auto lstm_split_dy = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex9), split_shapes, split_types,
size_split, num_split_x);
std::vector<AnfNodePtr> lstm_split_dy_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_dy, num_split_x, &lstm_split_dy_outputs);
result_nodes.push_back(lstm_split_dy_outputs);
// Create lstm_split_i
auto lstm_split_i = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex12), split_shapes, split_types,
size_split, num_split_x);
std::vector<AnfNodePtr> lstm_split_i_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_i, num_split_x, &lstm_split_i_outputs);
result_nodes.push_back(lstm_split_i_outputs);
// Create lstm_split_j
auto lstm_split_j = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex13), split_shapes, split_types,
size_split, num_split_x);
std::vector<AnfNodePtr> lstm_split_j_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_j, num_split_x, &lstm_split_j_outputs);
result_nodes.push_back(lstm_split_j_outputs);
// Create lstm_split_f
auto lstm_split_f = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex14), split_shapes, split_types,
size_split, num_split_x);
std::vector<AnfNodePtr> lstm_split_f_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_f, num_split_x, &lstm_split_f_outputs);
result_nodes.push_back(lstm_split_f_outputs);
// Create lstm_split_o
auto lstm_split_o = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex15), split_shapes, split_types,
size_split, num_split_x);
std::vector<AnfNodePtr> lstm_split_o_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_o, num_split_x, &lstm_split_o_outputs);
result_nodes.push_back(lstm_split_o_outputs);
// Create lstm_split_tanh
auto lstm_split_tanh = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex16), split_shapes,
split_types, size_split, num_split_x);
std::vector<AnfNodePtr> lstm_split_tanh_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_tanh, num_split_x, &lstm_split_tanh_outputs);
result_nodes.push_back(lstm_split_tanh_outputs);
// Add edges
std::vector<std::vector<AnfNodePtr>> loop_node_outputs;
CreateTLoopNodeWithEdge(func_graph, dynamic_rnn_grad_cnode, result_nodes, num_split_x, &loop_node_outputs);
auto &pre_basic_lstm_cell_c_state_grad_outputs = loop_node_outputs[kIndex0];
auto &pre_split_outputs = loop_node_outputs[kIndex1];
auto &lstm_x_concat_input = loop_node_outputs[kIndex2];
auto &lstm_gage_concat_input = loop_node_outputs[kIndex3];
// Create lstm_x_concat
auto lstm_x_concat = func_graph->NewCNode(lstm_x_concat_input);
auto lstm_x_concat = NewCNode(lstm_x_concat_input, func_graph);
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 2)},
lstm_x_concat.get());
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(num_split_x)), lstm_x_concat);
@ -273,7 +305,7 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(static_cast<int64_t>(0)), lstm_x_concat);
// Create lstm_gage_concat
auto lstm_gage_concat = func_graph->NewCNode(lstm_gage_concat_input);
auto lstm_gage_concat = NewCNode(lstm_gage_concat_input, func_graph);
auto origin_input7_shape = AnfAlgo::GetOutputInferShape(origin_input7, 0);
AnfAlgo::SetOutputInferTypeAndShape(
{kNumberTypeFloat16},
@ -289,14 +321,15 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
return lstm_gage_concat;
}
AnfNodePtr CreateSplitV(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) {
AnfNodePtr DynamicRnnGradFissionV2::CreateSplitV(const FuncGraphPtr &func_graph,
const CNodePtr &dynamic_rnn_grad_cnode) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
// Create node
auto origin_input6 = dynamic_rnn_grad_cnode->input(kIndex7);
std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
origin_input6};
auto split_v = func_graph->NewCNode(splitv_input);
auto split_v = NewCNode(splitv_input, func_graph);
// Set infer data type and shape
auto dtypes = {AnfAlgo::GetOutputInferDataType(origin_input6, 0), AnfAlgo::GetOutputInferDataType(origin_input6, 0)};
auto origin_input6_shape = AnfAlgo::GetOutputInferShape(origin_input6, 0);
@ -313,8 +346,9 @@ AnfNodePtr CreateSplitV(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_
return split_v;
}
AnfNodePtr CreateHConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
const AnfNodePtr &splitv) {
AnfNodePtr DynamicRnnGradFissionV2::CreateHConcat(const FuncGraphPtr &func_graph,
const CNodePtr &dynamic_rnn_grad_cnode,
const AnfNodePtr &splitv) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
MS_EXCEPTION_IF_NULL(splitv);
@ -336,11 +370,11 @@ AnfNodePtr CreateHConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic
}
std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
origin_input4};
auto reshape = func_graph->NewCNode(reshape_input);
auto reshape = NewCNode(reshape_input, func_graph);
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape_tmp}, reshape.get());
std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
reshape, splitv_outputs[0]};
auto concat = func_graph->NewCNode(concat_inputs);
auto concat = NewCNode(concat_inputs, func_graph);
// Set infer data type and shape
auto splitv_output0_shape = AnfAlgo::GetOutputInferShape(splitv, 0);
std::vector<size_t> shape = {splitv_output0_shape[0] + 1, origin_input4_shape[0], origin_input4_shape[1]};
@ -353,15 +387,15 @@ AnfNodePtr CreateHConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic
return concat;
}
AnfNodePtr CreateConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
const AnfNodePtr &h_concat) {
AnfNodePtr DynamicRnnGradFissionV2::CreateConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
const AnfNodePtr &h_concat) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
// Create node
auto origin_input0 = dynamic_rnn_grad_cnode->input(1);
std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
origin_input0, h_concat};
auto concat = func_graph->NewCNode(concat_inputs);
auto concat = NewCNode(concat_inputs, func_graph);
// Set infer data type and shape
auto origin_output0_shape = AnfAlgo::GetOutputInferShape(origin_input0, 0);
auto h_concat_output_shape = AnfAlgo::GetOutputInferShape(h_concat, 0);
@ -376,7 +410,8 @@ AnfNodePtr CreateConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_
return concat;
}
AnfNodePtr CreateConcatNodeT1(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) {
AnfNodePtr DynamicRnnGradFissionV2::CreateConcatNodeT1(const FuncGraphPtr &func_graph,
const CNodePtr &dynamic_rnn_grad_cnode) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
// Create node
@ -392,12 +427,12 @@ AnfNodePtr CreateConcatNodeT1(const FuncGraphPtr &func_graph, const CNodePtr &dy
}
std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
origin_input4};
auto reshape = func_graph->NewCNode(reshape_input);
auto reshape = NewCNode(reshape_input, func_graph);
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape_tmp}, reshape.get());
std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
origin_input0, reshape};
auto concat = func_graph->NewCNode(concat_inputs);
auto concat = NewCNode(concat_inputs, func_graph);
// Set infer data type and shape
auto origin_input0_shape = AnfAlgo::GetOutputInferShape(origin_input0, 0);
std::vector<size_t> shape = {origin_input0_shape[kDim0], origin_input0_shape[kDim1],
@ -411,13 +446,13 @@ AnfNodePtr CreateConcatNodeT1(const FuncGraphPtr &func_graph, const CNodePtr &dy
return concat;
}
AnfNodePtr CreateBatchMatMul(const FuncGraphPtr &func_graph, const AnfNodePtr &lstm_input_grad,
const AnfNodePtr &concat) {
AnfNodePtr DynamicRnnGradFissionV2::CreateBatchMatMul(const FuncGraphPtr &func_graph, const AnfNodePtr &lstm_input_grad,
const AnfNodePtr &concat) const {
MS_EXCEPTION_IF_NULL(func_graph);
// Create node
std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name())),
concat, lstm_input_grad};
auto batch_matmul = func_graph->NewCNode(matmul_inputs);
auto batch_matmul = NewCNode(matmul_inputs, func_graph);
// Set infer data type and shape
auto concat_shape = AnfAlgo::GetOutputInferShape(concat, 0);
auto lstm_input_grad_shape = AnfAlgo::GetOutputInferShape(lstm_input_grad, 0);
@ -430,13 +465,14 @@ AnfNodePtr CreateBatchMatMul(const FuncGraphPtr &func_graph, const AnfNodePtr &l
return batch_matmul;
}
AnfNodePtr CreateBatchMatMul2(const FuncGraphPtr &func_graph, const AnfNodePtr &lstm_input_grad,
const AnfNodePtr &node) {
AnfNodePtr DynamicRnnGradFissionV2::CreateBatchMatMul2(const FuncGraphPtr &func_graph,
const AnfNodePtr &lstm_input_grad,
const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(func_graph);
// Create node
std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name())),
node, lstm_input_grad};
auto batch_matmul = func_graph->NewCNode(matmul_inputs);
auto batch_matmul = NewCNode(matmul_inputs, func_graph);
// Set infer data type and shape
auto out_shape = {AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[kIndex0], IntToSize(1),
AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[kIndex2]};
@ -448,13 +484,14 @@ AnfNodePtr CreateBatchMatMul2(const FuncGraphPtr &func_graph, const AnfNodePtr &
return batch_matmul;
}
AnfNodePtr CreateDwReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
const AnfNodePtr &batch_matmul) {
AnfNodePtr DynamicRnnGradFissionV2::CreateDwReduceSum(const FuncGraphPtr &func_graph,
const CNodePtr &dynamic_rnn_grad_cnode,
const AnfNodePtr &batch_matmul) const {
MS_EXCEPTION_IF_NULL(func_graph);
// Create node
std::vector<AnfNodePtr> reduce_sum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
batch_matmul};
auto reduce_sum = func_graph->NewCNode(reduce_sum_inputs);
auto reduce_sum = NewCNode(reduce_sum_inputs, func_graph);
// Set infer data type and shape
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)},
{AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 0)}, reduce_sum.get());
@ -465,13 +502,14 @@ AnfNodePtr CreateDwReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dyn
return reduce_sum;
}
AnfNodePtr CreateDwReshape(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
const AnfNodePtr &batch_matmul) {
AnfNodePtr DynamicRnnGradFissionV2::CreateDwReshape(const FuncGraphPtr &func_graph,
const CNodePtr &dynamic_rnn_grad_cnode,
const AnfNodePtr &batch_matmul) const {
MS_EXCEPTION_IF_NULL(func_graph);
// Create node
std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
batch_matmul};
auto reshape = func_graph->NewCNode(reshape_inputs);
auto reshape = NewCNode(reshape_inputs, func_graph);
// Set infer data type and shape
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)},
{AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 0)}, reshape.get());
@ -479,7 +517,8 @@ AnfNodePtr CreateDwReshape(const FuncGraphPtr &func_graph, const CNodePtr &dynam
return reshape;
}
AnfNodePtr CreateValueNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) {
AnfNodePtr DynamicRnnGradFissionV2::CreateValueNode(const FuncGraphPtr &func_graph,
const CNodePtr &dynamic_rnn_grad_cnode) const {
auto origin_input7 = dynamic_rnn_grad_cnode->input(kIndex8);
auto origin_input7_shape = AnfAlgo::GetOutputInferShape(origin_input7, 0);
auto t_size = origin_input7_shape[0];
@ -497,14 +536,15 @@ AnfNodePtr CreateValueNode(const FuncGraphPtr &func_graph, const CNodePtr &dynam
return value_node;
}
AnfNodePtr CreateDbReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &, const AnfNodePtr &lstm_input_grad,
const AnfNodePtr &value_node) {
AnfNodePtr DynamicRnnGradFissionV2::CreateDbReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &,
const AnfNodePtr &lstm_input_grad,
const AnfNodePtr &value_node) const {
MS_EXCEPTION_IF_NULL(func_graph);
// Create node
auto batch_matmul = CreateBatchMatMul2(func_graph, lstm_input_grad, value_node);
std::vector<AnfNodePtr> reduce_sum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
batch_matmul};
auto reduce_sum = func_graph->NewCNode(reduce_sum_inputs);
auto reduce_sum = NewCNode(reduce_sum_inputs, func_graph);
// Set infer data type and shape
auto out_shape = {AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[kDim2]};
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {out_shape}, reduce_sum.get());
@ -514,7 +554,6 @@ AnfNodePtr CreateDbReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &, c
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sum);
return reduce_sum;
}
} // namespace
const BaseRef DynamicRnnGradFissionV2::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_V2_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_RNN_GRAD_FISSION_V2_H_
#include <vector>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
@ -28,6 +29,36 @@ class DynamicRnnGradFissionV2 : public PatternProcessPass {
~DynamicRnnGradFissionV2() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
std::vector<std::vector<AnfNodePtr>> *result_nodes) const;
AnfNodePtr CreateLSTMSPlitV(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
const std::vector<std::vector<size_t>> &split_shapes,
const std::vector<TypeId> &split_types, const std::vector<int64_t> &size_split,
size_t num_split_x) const;
void CreateTLoopNodeWithEdge(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
const std::vector<std::vector<AnfNodePtr>> &result_nodes, size_t num_split_x,
std::vector<std::vector<AnfNodePtr>> *loop_node_outputs) const;
AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
std::vector<AnfNodePtr> *outputs) const;
AnfNodePtr CreateSplitV(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) const;
AnfNodePtr CreateHConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
const AnfNodePtr &splitv) const;
AnfNodePtr CreateConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
const AnfNodePtr &h_concat) const;
AnfNodePtr CreateConcatNodeT1(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) const;
AnfNodePtr CreateBatchMatMul(const FuncGraphPtr &func_graph, const AnfNodePtr &lstm_input_grad,
const AnfNodePtr &concat) const;
AnfNodePtr CreateBatchMatMul2(const FuncGraphPtr &func_graph, const AnfNodePtr &lstm_input_grad,
const AnfNodePtr &node) const;
AnfNodePtr CreateDwReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
const AnfNodePtr &batch_matmul) const;
AnfNodePtr CreateDwReshape(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
const AnfNodePtr &batch_matmul) const;
AnfNodePtr CreateValueNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) const;
AnfNodePtr CreateDbReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &, const AnfNodePtr &lstm_input_grad,
const AnfNodePtr &value_node) const;
};
} // namespace opt
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -28,12 +28,36 @@ constexpr size_t kOriginPaddingSize = 2;
constexpr size_t kGatherInputNum = 4;
constexpr size_t kGatherInputIndicesIndex = 2;
constexpr size_t kGatherInputAxisIndex = 3;
bool CheckInputs(const CNodePtr &origin_node) {
MS_EXCEPTION_IF_NULL(origin_node);
if (AnfAlgo::GetInputTensorNum(origin_node) != kGatherV2DynInputTensorNum) {
MS_LOG(DEBUG) << "GatherV2 in dynamic shape has wrong inputs num, not equal " << kGatherV2DynInputTensorNum
<< ". CNode= " << origin_node->DebugString();
return false;
}
auto param_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0);
auto indice_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 1);
// this optimizer only supports the case where embedding_table has a dynamic shape
if (param_shape.empty() || indice_shape.empty() || AnfAlgo::IsDynamicShape(origin_node->input(kDim2))) {
return false;
}
if (param_shape[param_shape.size() - 1] != 1) {
MS_LOG(DEBUG) << "GatherV2 in dynamic shape is not need fission. The last value of input0's shape is "
<< param_shape[param_shape.size() - 1];
return false;
}
return true;
}
} // namespace
// Only the Pad operator can run in dynamic shape.
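// The rewrite is therefore Pad -> Gather -> Slice: the last axis of the parameter
// (required to be 1 by CheckInputs above) is padded up to pad_dim_size, Gather runs on the
// padded tensor, and a trailing Slice cuts the result back to the original output shape.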
CNodePtr CreatePad(const FuncGraphPtr &graph, const CNodePtr &origin_node, const size_t &pad_dim_size) {
CNodePtr GatherV2DsFission::CreatePad(const FuncGraphPtr &graph, const CNodePtr &origin_node,
const size_t &pad_dim_size) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(origin_node);
std::vector<AnfNodePtr> pad_inputs = {NewValueNode(std::make_shared<Primitive>(kPadOpName)), origin_node->input(1)};
auto pad = graph->NewCNode(pad_inputs);
auto pad = NewCNode(pad_inputs, graph);
MS_EXCEPTION_IF_NULL(pad);
pad->set_scope(origin_node->scope());
@ -83,8 +107,8 @@ CNodePtr CreatePad(const FuncGraphPtr &graph, const CNodePtr &origin_node, const
return pad;
}
CNodePtr CreateGatherV2Ds(const FuncGraphPtr &graph, const CNodePtr &origin_node, const CNodePtr &pad,
const size_t &pad_dim_size) {
CNodePtr GatherV2DsFission::CreateGatherV2Ds(const FuncGraphPtr &graph, const CNodePtr &origin_node,
const CNodePtr &pad, const size_t &pad_dim_size) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(origin_node);
MS_EXCEPTION_IF_NULL(pad);
@ -94,7 +118,7 @@ CNodePtr CreateGatherV2Ds(const FuncGraphPtr &graph, const CNodePtr &origin_node
std::vector<AnfNodePtr> gatherv2_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimGather->name())), pad,
origin_node->input(kGatherInputIndicesIndex),
origin_node->input(kGatherInputAxisIndex)};
auto gather_v2 = graph->NewCNode(gatherv2_inputs);
auto gather_v2 = NewCNode(gatherv2_inputs, graph);
MS_EXCEPTION_IF_NULL(gather_v2);
gather_v2->set_scope(origin_node->scope());
@ -110,12 +134,13 @@ CNodePtr CreateGatherV2Ds(const FuncGraphPtr &graph, const CNodePtr &origin_node
return gather_v2;
}
CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &gather_v2, const CNodePtr &gather_v2_padding_8) {
CNodePtr GatherV2DsFission::CreateSlice(const FuncGraphPtr &graph, const CNodePtr &gather_v2,
const CNodePtr &gather_v2_padding_8) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(gather_v2);
MS_EXCEPTION_IF_NULL(gather_v2_padding_8);
std::vector<AnfNodePtr> slice_inputs = {NewValueNode(std::make_shared<Primitive>(kSliceOpName)), gather_v2_padding_8};
auto slice = graph->NewCNode(slice_inputs);
auto slice = NewCNode(slice_inputs, graph);
MS_EXCEPTION_IF_NULL(slice);
slice->set_scope(gather_v2->scope());
slice->set_abstract(gather_v2->abstract());
@ -126,28 +151,6 @@ CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &gather_v2, const
return slice;
}
bool CheckInputs(const CNodePtr &origin_node) {
MS_EXCEPTION_IF_NULL(origin_node);
if (AnfAlgo::GetInputTensorNum(origin_node) != kGatherV2DynInputTensorNum) {
MS_LOG(DEBUG) << "GatherV2 in dynamic shape has wrong inputs num, not equal " << kGatherV2DynInputTensorNum
<< ". CNode= " << origin_node->DebugString();
return false;
}
auto param_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0);
auto indice_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 1);
// this optimizer only supports the case where embedding_table has a dynamic shape
if (param_shape.empty() || indice_shape.empty() || AnfAlgo::IsDynamicShape(origin_node->input(kDim2))) {
return false;
}
if (param_shape[param_shape.size() - 1] != 1) {
MS_LOG(DEBUG) << "GatherV2 in dynamic shape is not need fission. The last value of input0's shape is "
<< param_shape[param_shape.size() - 1];
return false;
}
return true;
}
} // namespace
const BaseRef GatherV2DsFission::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
VectorRef pattern({prim::kPrimGather, Xs});

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -30,6 +30,12 @@ class GatherV2DsFission : public PatternProcessPass {
~GatherV2DsFission() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
CNodePtr CreatePad(const FuncGraphPtr &graph, const CNodePtr &origin_node, const size_t &pad_dim_size) const;
CNodePtr CreateGatherV2Ds(const FuncGraphPtr &graph, const CNodePtr &origin_node, const CNodePtr &pad,
const size_t &pad_dim_size) const;
CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &gather_v2, const CNodePtr &gather_v2_padding_8) const;
};
} // namespace opt
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -15,7 +15,6 @@
*/
#include "backend/optimizer/ascend/ir_fission/lars_v2_fission.h"
#include <memory>
#include <vector>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/helper.h"
#include "utils/utils.h"
@ -29,14 +28,16 @@ constexpr size_t kLarsV2WIndex = 1;
constexpr size_t kLarsV2GIndex = 2;
constexpr size_t kLarsV2WeightDecayIndex = 3;
constexpr size_t kLarsV2LearningRatIndex = 4;
void CreateOutputsOfSquareSumAll(const FuncGraphPtr &graph, const CNodePtr &lars_v2,
std::vector<AnfNodePtr> *square_sum_all_outputs) {
} // namespace
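// LarsV2 is split into SquareSumAll + LarsV2Update: SquareSumAll computes the squared sums
// of the weight and the gradient in a single op, and LarsV2Update combines those two sums
// with the original weight, gradient, weight_decay and learning_rate inputs.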
void LarsV2Fission::CreateOutputsOfSquareSumAll(const FuncGraphPtr &graph, const CNodePtr &lars_v2,
std::vector<AnfNodePtr> *square_sum_all_outputs) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(lars_v2);
CheckCNodeInputSize(lars_v2, kLarsV2InputTensorNum);
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kSquareSumAllOpName)), lars_v2->input(1),
lars_v2->input(2)};
auto square_sum_all = graph->NewCNode(inputs);
auto square_sum_all = NewCNode(inputs, graph);
MS_EXCEPTION_IF_NULL(square_sum_all);
square_sum_all->set_scope(lars_v2->scope());
@ -47,8 +48,8 @@ void CreateOutputsOfSquareSumAll(const FuncGraphPtr &graph, const CNodePtr &lars
CreateMultipleOutputsOfAnfNode(graph, square_sum_all, kSquareSumOutputNum, square_sum_all_outputs);
}
CNodePtr CreateLarsV2Update(const FuncGraphPtr &graph, const CNodePtr &lars_v2,
const std::vector<AnfNodePtr> &square_sum_all_outputs) {
CNodePtr LarsV2Fission::CreateLarsV2Update(const FuncGraphPtr &graph, const CNodePtr &lars_v2,
const std::vector<AnfNodePtr> &square_sum_all_outputs) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(lars_v2);
if (square_sum_all_outputs.size() != kSquareSumOutputNum) {
@ -63,13 +64,12 @@ CNodePtr CreateLarsV2Update(const FuncGraphPtr &graph, const CNodePtr &lars_v2,
square_sum_all_outputs[1],
lars_v2->input(kLarsV2WeightDecayIndex),
lars_v2->input(kLarsV2LearningRatIndex)};
auto lars_v2_update = graph->NewCNode(inputs);
auto lars_v2_update = NewCNode(inputs, graph);
MS_EXCEPTION_IF_NULL(lars_v2_update);
lars_v2_update->set_scope(lars_v2->scope());
lars_v2_update->set_abstract(lars_v2->abstract());
return lars_v2_update;
}
} // namespace
const BaseRef LarsV2Fission::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,6 +16,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_LARS_V2_FISSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_LARS_V2_FISSION_H_
#include <vector>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
@ -26,6 +27,12 @@ class LarsV2Fission : public PatternProcessPass {
~LarsV2Fission() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
void CreateOutputsOfSquareSumAll(const FuncGraphPtr &graph, const CNodePtr &lars_v2,
std::vector<AnfNodePtr> *square_sum_all_outputs) const;
CNodePtr CreateLarsV2Update(const FuncGraphPtr &graph, const CNodePtr &lars_v2,
const std::vector<AnfNodePtr> &square_sum_all_outputs) const;
};
} // namespace opt
} // namespace mindspore

View File

@ -43,7 +43,7 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormXBackpropV2(const FuncGraphPtr
for (size_t i = 1; i < layer_norm_grad->inputs().size(); ++i) {
layer_norm_x_backprop_inputs.push_back(layer_norm_grad->input(i));
}
auto layer_norm_x_backprop = graph->NewCNode(layer_norm_x_backprop_inputs);
auto layer_norm_x_backprop = NewCNode(layer_norm_x_backprop_inputs, graph);
MS_EXCEPTION_IF_NULL(layer_norm_x_backprop);
layer_norm_x_backprop->set_scope(layer_norm_grad->scope());
auto types = {AnfAlgo::GetOutputInferDataType(layer_norm_grad, 0), kNumberTypeFloat32};
@ -68,7 +68,7 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormBetaGammaBackpropV2(
auto prim = std::make_shared<Primitive>(kLayerNormBetaGammaBackpropV2OpName);
std::vector<AnfNodePtr> layer_norm_beta_gamma_backprop_inputs = {NewValueNode(prim), layer_norm_grad->input(kIndex2),
res_for_gamma};
auto layer_norm_beta_gamma_backprop = graph->NewCNode(layer_norm_beta_gamma_backprop_inputs);
auto layer_norm_beta_gamma_backprop = NewCNode(layer_norm_beta_gamma_backprop_inputs, graph);
MS_EXCEPTION_IF_NULL(layer_norm_beta_gamma_backprop);
auto kernel_info = std::make_shared<device::KernelInfo>();
layer_norm_beta_gamma_backprop->set_kernel_info(kernel_info);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -109,7 +109,7 @@ const AnfNodePtr LinSpaceFission::Process(const FuncGraphPtr &graph, const AnfNo
auto assist_const = CreateValueNode(cnode);
new_inputs.push_back(assist_const);
new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
CNodePtr new_cnode = graph->NewCNode(new_inputs);
CNodePtr new_cnode = NewCNode(new_inputs, graph);
MS_EXCEPTION_IF_NULL(new_cnode);
new_cnode->set_abstract(cnode->abstract());
new_cnode->set_scope(cnode->scope());

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -112,7 +112,7 @@ const AnfNodePtr MaxPool3DGradGradFission::Process(const FuncGraphPtr &graph, co
auto assist_const = CreateValueNode(cnode);
(void)new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
(void)new_inputs.emplace_back(assist_const);
CNodePtr new_cnode = graph->NewCNode(new_inputs);
CNodePtr new_cnode = NewCNode(new_inputs, graph);
MS_EXCEPTION_IF_NULL(new_cnode);
new_cnode->set_abstract(cnode->abstract());
new_cnode->set_scope(cnode->scope());

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -21,16 +21,15 @@
namespace mindspore {
namespace opt {
namespace {
AnfNodePtr CreateNewPack(const FuncGraphPtr &func_graph, const CNodePtr &origin_pack_cnode, size_t begin_index,
size_t offset) {
AnfNodePtr PackFission::CreateNewPack(const FuncGraphPtr &func_graph, const CNodePtr &origin_pack_cnode,
size_t begin_index, size_t offset) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(origin_pack_cnode);
std::vector<AnfNodePtr> new_pack_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimStack->name()))};
for (size_t i = begin_index; i < begin_index + offset; ++i) {
new_pack_inputs.push_back(origin_pack_cnode->input(i));
}
CNodePtr new_pack = func_graph->NewCNode(new_pack_inputs);
CNodePtr new_pack = NewCNode(new_pack_inputs, func_graph);
MS_EXCEPTION_IF_NULL(new_pack);
new_pack->set_scope(origin_pack_cnode->scope());
new_pack->set_abstract(origin_pack_cnode->abstract());
@ -58,7 +57,6 @@ AnfNodePtr CreateNewPack(const FuncGraphPtr &func_graph, const CNodePtr &origin_
new_pack.get());
return new_pack;
}
} // namespace
const BaseRef PackFission::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
@ -90,7 +88,7 @@ const AnfNodePtr PackFission::Process(const FuncGraphPtr &func_graph, const AnfN
CreateNewPack(func_graph, cnode, cur_input_index, origin_input_size - cur_input_index + 1));
}
CNodePtr base_concat = func_graph->NewCNode(base_concat_inputs);
CNodePtr base_concat = NewCNode(base_concat_inputs, func_graph);
MS_EXCEPTION_IF_NULL(base_concat);
base_concat->set_scope(cnode->scope());
base_concat->set_abstract(cnode->abstract());

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -30,6 +30,8 @@ class PackFission : public PatternProcessPass {
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
AnfNodePtr CreateNewPack(const FuncGraphPtr &func_graph, const CNodePtr &origin_pack_cnode, size_t begin_index,
size_t offset) const;
size_t inputs_divisor_;
};
} // namespace opt

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -21,18 +21,6 @@
namespace mindspore {
namespace opt {
namespace {
CNodePtr CreateReduceMin(const FuncGraphPtr &graph, const AnfNodePtr &input, const CNodePtr &old_node) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(input);
MS_EXCEPTION_IF_NULL(old_node);
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceMin->name())), input};
CNodePtr reduce_min = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(reduce_min);
reduce_min->set_scope(old_node->scope());
AnfAlgo::CopyNodeAttr(kAttrKeepDims, old_node, reduce_min);
return reduce_min;
}
bool NeedOptimize(const TypeId &dtype, const std::vector<size_t> &shape, const std::vector<int64_t> &axis) {
if (dtype != kNumberTypeFloat32) {
MS_LOG(INFO) << "ReduceMin's input Dtype is not float32, no need to optimize!";
@ -97,6 +85,19 @@ std::vector<size_t> GetInferShape(const std::vector<size_t> &shape, const std::v
}
} // namespace
CNodePtr ReduceMinFission::CreateReduceMin(const FuncGraphPtr &graph, const AnfNodePtr &input,
const CNodePtr &old_node) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(input);
MS_EXCEPTION_IF_NULL(old_node);
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceMin->name())), input};
CNodePtr reduce_min = NewCNode(inputs, graph);
MS_EXCEPTION_IF_NULL(reduce_min);
reduce_min->set_scope(old_node->scope());
AnfAlgo::CopyNodeAttr(kAttrKeepDims, old_node, reduce_min);
return reduce_min;
}
const BaseRef ReduceMinFission::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
return VectorRef({prim::kPrimReduceMin, X});

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -27,6 +27,9 @@ class ReduceMinFission : public PatternProcessPass {
~ReduceMinFission() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
CNodePtr CreateReduceMin(const FuncGraphPtr &graph, const AnfNodePtr &input, const CNodePtr &old_node) const;
};
} // namespace opt
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -24,8 +24,9 @@ namespace mindspore {
namespace opt {
namespace {
constexpr size_t kBatchNormRealInputNum = 3;
} // namespace
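// A standalone BatchNorm is split into BNTrainingReduce + BNTrainingUpdateV3:
// BNTrainingReduce computes the per-channel sum and square sum of the input, and
// BNTrainingUpdateV3 consumes those two outputs together with the original scale and
// offset inputs to produce the normalized result.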
AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) {
AnfNodePtr SingleBatchNormFission::CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(bn);
auto bn_cnode = bn->cast<CNodePtr>();
@ -36,7 +37,7 @@ AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodeP
}
std::vector<AnfNodePtr> bn_training_reduce_inputs = {
NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)), bn_cnode->input(1)};
auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs);
auto bn_training_reduce = NewCNode(bn_training_reduce_inputs, func_graph);
MS_EXCEPTION_IF_NULL(bn_training_reduce);
// set abstract
@ -49,8 +50,9 @@ AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodeP
return bn_training_reduce;
}
AnfNodePtr CreateBNTrainingUpdateV3(const FuncGraphPtr &func_graph, const AnfNodePtr &bn,
const std::vector<AnfNodePtr> &bn_training_reduce_outputs) {
AnfNodePtr SingleBatchNormFission::CreateBNTrainingUpdateV3(
const FuncGraphPtr &func_graph, const AnfNodePtr &bn,
const std::vector<AnfNodePtr> &bn_training_reduce_outputs) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(bn);
auto bn_cnode = bn->cast<CNodePtr>();
@ -71,7 +73,7 @@ AnfNodePtr CreateBNTrainingUpdateV3(const FuncGraphPtr &func_graph, const AnfNod
bn_training_reduce_outputs[kIndex1],
bn_cnode->input(kIndex2),
bn_cnode->input(kIndex3)};
auto bn_training_update_v3 = func_graph->NewCNode(bn_training_update_v3_inputs);
auto bn_training_update_v3 = NewCNode(bn_training_update_v3_inputs, func_graph);
MS_EXCEPTION_IF_NULL(bn_training_update_v3);
auto bn_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn->abstract());
@ -85,7 +87,6 @@ AnfNodePtr CreateBNTrainingUpdateV3(const FuncGraphPtr &func_graph, const AnfNod
AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_cnode, bn_training_update_v3);
return bn_training_update_v3;
}
} // namespace
const BaseRef SingleBatchNormFission::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,6 +16,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_SINGLE_BATCH_NORM_FISSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_SINGLE_BATCH_NORM_FISSION_H_
#include <vector>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
@ -27,6 +28,11 @@ class SingleBatchNormFission : public PatternProcessPass {
~SingleBatchNormFission() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) const;
AnfNodePtr CreateBNTrainingUpdateV3(const FuncGraphPtr &func_graph, const AnfNodePtr &bn,
const std::vector<AnfNodePtr> &bn_training_reduce_outputs) const;
};
} // namespace opt
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -116,7 +116,7 @@ const AnfNodePtr SpaceToDepthSplit::Process(const FuncGraphPtr &graph, const Anf
auto last_input_value = CreateValueNode(cnode);
(void)new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
(void)new_inputs.emplace_back(last_input_value);
CNodePtr new_cnode = graph->NewCNode(new_inputs);
CNodePtr new_cnode = NewCNode(new_inputs, graph);
MS_EXCEPTION_IF_NULL(new_cnode);
new_cnode->set_abstract(cnode->abstract());
new_cnode->set_scope(cnode->scope());

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -22,22 +22,6 @@
namespace mindspore {
namespace opt {
namespace {
CNodePtr CreateSplitVNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(input_node);
std::vector<AnfNodePtr> splitv_inputs{NewValueNode(std::make_shared<Primitive>(kSplitVOpName)), input_node};
CNodePtr splitv = func_graph->NewCNode(splitv_inputs);
MS_EXCEPTION_IF_NULL(splitv);
splitv->set_scope(input_node->scope());
return splitv;
}
CNodePtr CreateBaseSplitVNode(const FuncGraphPtr &func_graph, const CNodePtr &origin_cnode) {
MS_EXCEPTION_IF_NULL(origin_cnode);
CheckCNodeInputSize(origin_cnode, kSplitInputTensorNum);
return CreateSplitVNode(func_graph, origin_cnode->input(1));
}
void SetAttrForSplitVNode(const AnfNodePtr &splitv, const std::vector<int64_t> &size_splits, int64_t split_dim,
int64_t num_split) {
AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(size_splits), splitv);
@ -121,6 +105,22 @@ void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePt
}
} // namespace
CNodePtr SplitFission::CreateSplitVNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(input_node);
std::vector<AnfNodePtr> splitv_inputs{NewValueNode(std::make_shared<Primitive>(kSplitVOpName)), input_node};
CNodePtr splitv = NewCNode(splitv_inputs, func_graph);
MS_EXCEPTION_IF_NULL(splitv);
splitv->set_scope(input_node->scope());
return splitv;
}
CNodePtr SplitFission::CreateBaseSplitVNode(const FuncGraphPtr &func_graph, const CNodePtr &origin_cnode) const {
MS_EXCEPTION_IF_NULL(origin_cnode);
CheckCNodeInputSize(origin_cnode, kSplitInputTensorNum);
return CreateSplitVNode(func_graph, origin_cnode->input(1));
}
AnfNodePtr SplitFission::DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int64_t num_split,
int64_t divisor, int64_t split_dim) const {
MS_EXCEPTION_IF_NULL(func_graph);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -34,6 +34,8 @@ class SplitFission : public PatternProcessPass {
protected:
AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int64_t num_split, int64_t divisor,
int64_t split_dim) const;
CNodePtr CreateSplitVNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node) const;
CNodePtr CreateBaseSplitVNode(const FuncGraphPtr &func_graph, const CNodePtr &origin_cnode) const;
int64_t outputs_divisor_;
};
} // namespace opt

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -21,13 +21,13 @@
namespace mindspore {
namespace opt {
namespace {
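// TensorScatterUpdate is functional (it returns a new tensor), while ScatterNdUpdate
// updates its first input in place. The fission therefore copies the input with TensorMove
// first, then applies ScatterNdUpdate to the copy so the original tensor is left untouched.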
CNodePtr CreateTensorMove(const FuncGraphPtr &graph, const CNodePtr &tensor_scatter_update) {
CNodePtr TensorScatterUpdateFission::CreateTensorMove(const FuncGraphPtr &graph,
const CNodePtr &tensor_scatter_update) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(tensor_scatter_update);
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kTensorMoveOpName)),
tensor_scatter_update->input(1)};
auto tensor_move = graph->NewCNode(inputs);
auto tensor_move = NewCNode(inputs, graph);
MS_EXCEPTION_IF_NULL(tensor_move);
tensor_move->set_scope(tensor_scatter_update->scope());
tensor_move->set_abstract(tensor_scatter_update->abstract());
@ -35,20 +35,20 @@ CNodePtr CreateTensorMove(const FuncGraphPtr &graph, const CNodePtr &tensor_scat
return tensor_move;
}
CNodePtr CreateScatterNdUpdate(const FuncGraphPtr &graph, const CNodePtr &tensor_scatter_update,
const CNodePtr &tensor_move) {
CNodePtr TensorScatterUpdateFission::CreateScatterNdUpdate(const FuncGraphPtr &graph,
const CNodePtr &tensor_scatter_update,
const CNodePtr &tensor_move) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(tensor_scatter_update);
MS_EXCEPTION_IF_NULL(tensor_move);
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kScatterNdUpdateOpName)), tensor_move,
tensor_scatter_update->input(2), tensor_scatter_update->input(3)};
auto scatter_nd_update = graph->NewCNode(inputs);
auto scatter_nd_update = NewCNode(inputs, graph);
MS_EXCEPTION_IF_NULL(scatter_nd_update);
scatter_nd_update->set_scope(tensor_scatter_update->scope());
scatter_nd_update->set_abstract(tensor_scatter_update->abstract());
return scatter_nd_update;
}
} // namespace
const BaseRef TensorScatterUpdateFission::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -27,6 +27,11 @@ class TensorScatterUpdateFission : public PatternProcessPass {
~TensorScatterUpdateFission() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
CNodePtr CreateTensorMove(const FuncGraphPtr &graph, const CNodePtr &tensor_scatter_update) const;
CNodePtr CreateScatterNdUpdate(const FuncGraphPtr &graph, const CNodePtr &tensor_scatter_update,
const CNodePtr &tensor_move) const;
};
} // namespace opt
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -27,9 +27,10 @@
#include "utils/ms_context.h"
namespace mindspore::opt {
namespace {
constexpr size_t kFloat16Len = 2; // size of float16;
constexpr size_t kTopkIndexK = 1;
namespace {
tensor::TensorPtr CreateTensor() {
// 1 create tensor
const size_t last_dim = 4096;
@ -142,7 +143,7 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod
// Copy a new node to check supported.
std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(kTopKOpName))};
new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
CNodePtr new_cnode = func_graph->NewCNode(new_inputs);
CNodePtr new_cnode = NewCNode(new_inputs, func_graph);
MS_EXCEPTION_IF_NULL(new_cnode);
new_cnode->set_abstract(cnode->abstract());
new_cnode->set_scope(cnode->scope());

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -23,59 +23,6 @@
namespace mindspore {
namespace opt {
namespace {
CNodePtr CreatePadding(const FuncGraphPtr &graph, const CNodePtr &origin_node, const size_t &pad_dim_size) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(origin_node);
std::vector<AnfNodePtr> padding_inputs = {NewValueNode(std::make_shared<Primitive>(kPaddingOpName)),
origin_node->input(kIndex1)};
auto padding = graph->NewCNode(padding_inputs);
MS_EXCEPTION_IF_NULL(padding);
padding->set_scope(origin_node->scope());
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0);
shape[shape.size() - 1] = pad_dim_size;
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetPrevNodeOutputInferDataType(origin_node, 0)}, {shape},
padding.get());
AnfAlgo::SetNodeAttr(kAttrPadDimSize, MakeValue(SizeToLong(pad_dim_size)), padding);
return padding;
}
CNodePtr CreateUnsortedSegmentSum(const FuncGraphPtr &graph, const CNodePtr &origin_node, const CNodePtr &padding,
const size_t &pad_dim_size) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(origin_node);
MS_EXCEPTION_IF_NULL(padding);
std::vector<AnfNodePtr> unsorted_segment_sum8_inputs = {
NewValueNode(std::make_shared<Primitive>(prim::kPrimUnsortedSegmentSum->name())), padding,
origin_node->input(kIndex2)};
auto unsorted_segment_sum = graph->NewCNode(unsorted_segment_sum8_inputs);
MS_EXCEPTION_IF_NULL(unsorted_segment_sum);
unsorted_segment_sum->set_scope(origin_node->scope());
auto shape = AnfAlgo::GetOutputInferShape(origin_node, 0);
shape[shape.size() - 1] = pad_dim_size;
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_node, 0)}, {shape},
unsorted_segment_sum.get());
AnfAlgo::SetNodeAttr(kAttrNumSegments, MakeValue(SizeToLong(shape[0])), unsorted_segment_sum);
return unsorted_segment_sum;
}
CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &unsort_segment_sum,
const CNodePtr &unsorted_segment_sum8) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(unsort_segment_sum);
MS_EXCEPTION_IF_NULL(unsorted_segment_sum8);
std::vector<AnfNodePtr> slice_inputs = {NewValueNode(std::make_shared<Primitive>(kSliceOpName)),
unsorted_segment_sum8};
auto slice = graph->NewCNode(slice_inputs);
MS_EXCEPTION_IF_NULL(slice);
slice->set_scope(unsort_segment_sum->scope());
slice->set_abstract(unsort_segment_sum->abstract());
auto unsort_segment_sum_shape = AnfAlgo::GetOutputInferShape(unsort_segment_sum, 0);
std::vector<size_t> offsets(unsort_segment_sum_shape.size(), 0);
AnfAlgo::SetNodeAttr(kAttrBegin, MakeValue(Convert2Long(offsets)), slice);
AnfAlgo::SetNodeAttr(kAttrSize, MakeValue(Convert2Long(unsort_segment_sum_shape)), slice);
return slice;
}
bool CheckInputs(const CNodePtr &origin_node) {
MS_EXCEPTION_IF_NULL(origin_node);
if (AnfAlgo::GetInputTensorNum(origin_node) != kUnsortedSegmentSumInputTensorNum) {
@ -97,6 +44,60 @@ bool CheckInputs(const CNodePtr &origin_node) {
}
} // namespace
CNodePtr UnsortSegmentSumFission::CreatePadding(const FuncGraphPtr &graph, const CNodePtr &origin_node,
const size_t &pad_dim_size) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(origin_node);
std::vector<AnfNodePtr> padding_inputs = {NewValueNode(std::make_shared<Primitive>(kPaddingOpName)),
origin_node->input(kIndex1)};
auto padding = NewCNode(padding_inputs, graph);
MS_EXCEPTION_IF_NULL(padding);
padding->set_scope(origin_node->scope());
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0);
shape[shape.size() - 1] = pad_dim_size;
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetPrevNodeOutputInferDataType(origin_node, 0)}, {shape},
padding.get());
AnfAlgo::SetNodeAttr(kAttrPadDimSize, MakeValue(SizeToLong(pad_dim_size)), padding);
return padding;
}
CNodePtr UnsortSegmentSumFission::CreateUnsortedSegmentSum(const FuncGraphPtr &graph, const CNodePtr &origin_node,
const CNodePtr &padding, const size_t &pad_dim_size) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(origin_node);
MS_EXCEPTION_IF_NULL(padding);
std::vector<AnfNodePtr> unsorted_segment_sum8_inputs = {
NewValueNode(std::make_shared<Primitive>(prim::kPrimUnsortedSegmentSum->name())), padding,
origin_node->input(kIndex2)};
auto unsorted_segment_sum = NewCNode(unsorted_segment_sum8_inputs, graph);
MS_EXCEPTION_IF_NULL(unsorted_segment_sum);
unsorted_segment_sum->set_scope(origin_node->scope());
auto shape = AnfAlgo::GetOutputInferShape(origin_node, 0);
shape[shape.size() - 1] = pad_dim_size;
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_node, 0)}, {shape},
unsorted_segment_sum.get());
AnfAlgo::SetNodeAttr(kAttrNumSegments, MakeValue(SizeToLong(shape[0])), unsorted_segment_sum);
return unsorted_segment_sum;
}
CNodePtr UnsortSegmentSumFission::CreateSlice(const FuncGraphPtr &graph, const CNodePtr &unsort_segment_sum,
const CNodePtr &unsorted_segment_sum8) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(unsort_segment_sum);
MS_EXCEPTION_IF_NULL(unsorted_segment_sum8);
std::vector<AnfNodePtr> slice_inputs = {NewValueNode(std::make_shared<Primitive>(kSliceOpName)),
unsorted_segment_sum8};
auto slice = NewCNode(slice_inputs, graph);
MS_EXCEPTION_IF_NULL(slice);
slice->set_scope(unsort_segment_sum->scope());
slice->set_abstract(unsort_segment_sum->abstract());
auto unsort_segment_sum_shape = AnfAlgo::GetOutputInferShape(unsort_segment_sum, 0);
std::vector<size_t> offsets(unsort_segment_sum_shape.size(), 0);
AnfAlgo::SetNodeAttr(kAttrBegin, MakeValue(Convert2Long(offsets)), slice);
AnfAlgo::SetNodeAttr(kAttrSize, MakeValue(Convert2Long(unsort_segment_sum_shape)), slice);
return slice;
}
const BaseRef UnsortSegmentSumFission::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
VectorRef pattern({prim::kPrimUnsortedSegmentSum, Xs});
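For orientation, the three methods above implement this pass's fission pipeline: pad the last input dimension up to pad_dim_size, run the padded UnsortedSegmentSum, then slice the result back to the op's original output shape (the offsets are all zero and kAttrSize is the original inferred shape). With a hypothetical pad_dim_size of 8 and a trailing dimension of 1, which the truncated CheckInputs appears to require, the shapes flow as:

// Hypothetical shapes, assuming pad_dim_size = 8:
//   x: [N, 1]          --Padding-->            [N, 8]
//   [N, 8]             --UnsortedSegmentSum--> [num_segments, 8]
//   [num_segments, 8]  --Slice(begin=0)-->     [num_segments, 1]  (original output shape)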


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -31,6 +31,13 @@ class UnsortSegmentSumFission : public PatternProcessPass {
~UnsortSegmentSumFission() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
CNodePtr CreatePadding(const FuncGraphPtr &graph, const CNodePtr &origin_node, const size_t &pad_dim_size) const;
CNodePtr CreateUnsortedSegmentSum(const FuncGraphPtr &graph, const CNodePtr &origin_node, const CNodePtr &padding,
const size_t &pad_dim_size) const;
CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &unsort_segment_sum,
const CNodePtr &unsorted_segment_sum8) const;
};
} // namespace opt
} // namespace mindspore


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -228,7 +228,7 @@ AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_g
auto add2_y_node = utils::cast<AnfNodePtr>((*equiv)[add2_y_]);
MS_EXCEPTION_IF_NULL(add2_y_node);
new_node_inputs.push_back(add2_y_node);
auto new_node = func_graph->NewCNode(new_node_inputs);
auto new_node = NewCNode(new_node_inputs, func_graph);
return new_node;
}


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -288,7 +288,7 @@ const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, c
return nullptr;
}
std::vector<AnfNodePtr> inputs = GetFusionNodeInputs(equiv, node);
auto fusion_node = graph->NewCNode(inputs);
auto fusion_node = NewCNode(inputs, graph);
MS_EXCEPTION_IF_NULL(fusion_node);
fusion_node->set_scope(sub0->scope());


@ -304,7 +304,7 @@ const AnfNodePtr AvgPool3DFusion::Process(const FuncGraphPtr &func_graph, const
pad_list, count_include_pad);
new_inputs.push_back(multiplier);
}
auto new_3d = func_graph->NewCNode(new_inputs);
auto new_3d = NewCNode(new_inputs, func_graph);
MS_EXCEPTION_IF_NULL(new_3d);
new_3d->set_scope(avg_pool_3d_node->scope());
new_3d->set_abstract(avg_pool_3d_node->abstract());


@ -235,7 +235,7 @@ const AnfNodePtr AvgPool3DGradFusion::Process(const FuncGraphPtr &func_graph, co
ConstructMultiplier(func_graph, dims_in, origin_input_shape, kernel_size, strides, pad_list, count_include_pad);
new_inputs.push_back(multiplier);
}
auto new_3d_grad = func_graph->NewCNode(new_inputs);
auto new_3d_grad = NewCNode(new_inputs, func_graph);
MS_EXCEPTION_IF_NULL(new_3d_grad);
new_3d_grad->set_scope(avg_pool_3d_grad_node->scope());
new_3d_grad->set_abstract(avg_pool_3d_grad_node->abstract());


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -26,24 +26,6 @@
namespace mindspore {
namespace opt {
namespace {
CNodePtr CreateBNInfer(const FuncGraphPtr &graph, const CNodePtr &batchnorm, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(batchnorm);
MS_EXCEPTION_IF_NULL(node);
auto prim = std::make_shared<Primitive>(kBNInferOpName);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
for (size_t i = 1; i < batchnorm->size(); ++i) {
inputs.push_back(batchnorm->input(i));
}
auto new_node = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(new_node);
new_node->set_scope(batchnorm->scope());
new_node->set_abstract(node->abstract());
AnfAlgo::CopyNodeAttr(kAttrIsTraining, batchnorm, new_node);
AnfAlgo::CopyNodeAttr(kAttrEpsilon, batchnorm, new_node);
return new_node;
}
bool CheckIndex(const AnfNodePtr &index_node) {
MS_EXCEPTION_IF_NULL(index_node);
if (!IsValueNode<Int64Imm>(index_node)) {
@ -103,6 +85,25 @@ bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *bat
}
} // namespace
CNodePtr BatchNorm2BNInfer::CreateBNInfer(const FuncGraphPtr &graph, const CNodePtr &batchnorm,
const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(batchnorm);
MS_EXCEPTION_IF_NULL(node);
auto prim = std::make_shared<Primitive>(kBNInferOpName);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
for (size_t i = 1; i < batchnorm->size(); ++i) {
inputs.push_back(batchnorm->input(i));
}
auto new_node = NewCNode(inputs, graph);
MS_EXCEPTION_IF_NULL(new_node);
new_node->set_scope(batchnorm->scope());
new_node->set_abstract(node->abstract());
AnfAlgo::CopyNodeAttr(kAttrIsTraining, batchnorm, new_node);
AnfAlgo::CopyNodeAttr(kAttrEpsilon, batchnorm, new_node);
return new_node;
}
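Context for CreateBNInfer above: NeedFusion (truncated in this view) gates the rewrite, and the new node reuses all of BatchNorm's inputs while copying only the is_training and epsilon attributes. Roughly, assuming the standard five-input BatchNorm signature:

// (y, mean, var, save_mean, save_inv_var) = BatchNorm(x, scale, bias, moving_mean, moving_var)
//   becomes, when only y is consumed at inference time,
// y = BNInfer(x, scale, bias, moving_mean, moving_var)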
const BaseRef BatchNorm2BNInfer::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
VarPtr Y = std::make_shared<Var>();


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -27,6 +27,9 @@ class BatchNorm2BNInfer : public PatternProcessPass {
~BatchNorm2BNInfer() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
CNodePtr CreateBNInfer(const FuncGraphPtr &graph, const CNodePtr &batchnorm, const AnfNodePtr &node) const;
};
} // namespace opt
} // namespace mindspore


@ -85,7 +85,7 @@ const AnfNodePtr BNReduceGradConv2dBackpropFilterFusion::Process(const FuncGraph
for (size_t i = 1; i <= kBNTrainingReduceGradInputNum; ++i) {
fused_dbn_dw_inputs.push_back(bnreduce_grad->input(i));
}
auto fused_dbn_dw = graph->NewCNode(fused_dbn_dw_inputs);
auto fused_dbn_dw = NewCNode(fused_dbn_dw_inputs, graph);
MS_EXCEPTION_IF_NULL(fused_dbn_dw);
auto types = {AnfAlgo::GetOutputInferDataType(bnreduce_grad, 0),
AnfAlgo::GetOutputInferDataType(conv_back_filter, 0)};


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -61,7 +61,7 @@ const AnfNodePtr ClipByNormNoDivSquareSumFusion::Process(const FuncGraphPtr &gra
auto prim = std::make_shared<Primitive>(kClipByNormNoDivSumOpName);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), input, constant_greater, constant_select, constant_maximum};
auto fusion_node = graph->NewCNode(inputs);
auto fusion_node = NewCNode(inputs, graph);
MS_EXCEPTION_IF_NULL(fusion_node);
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -87,7 +87,7 @@ const AnfNodePtr ClipByValueFusion::Process(const FuncGraphPtr &graph, const Anf
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), minimum->input(kIndex1),
is_first_input ? maximum_input1 : maximum_input0, minimum->input(kIndex2)};
auto clip_by_value = graph->NewCNode(inputs);
auto clip_by_value = NewCNode(inputs, graph);
MS_EXCEPTION_IF_NULL(clip_by_value);
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -30,28 +30,6 @@ namespace opt {
namespace {
const size_t kConfusionMulGradOutputNum = 2;
CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &reduce_sum, const AnfNodePtr &mul0_anf,
const AnfNodePtr &input3) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(reduce_sum);
MS_EXCEPTION_IF_NULL(mul0_anf);
MS_EXCEPTION_IF_NULL(input3);
auto mul0 = mul0_anf->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(mul0);
auto prim = std::make_shared<Primitive>(kConfusionMulGradOpName);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), mul0->input(kIndex1), mul0->input(kIndex2), input3};
auto fusion_node = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(fusion_node);
fusion_node->set_scope(reduce_sum->scope());
AnfAlgo::CopyNodeAttr(kAttrAxis, reduce_sum, fusion_node);
AnfAlgo::CopyNodeAttr(kAttrKeepDims, reduce_sum, fusion_node);
auto types = {AnfAlgo::GetOutputInferDataType(mul0, 0), AnfAlgo::GetOutputInferDataType(reduce_sum, 0)};
auto shapes = {AnfAlgo::GetOutputInferShape(mul0, 0), AnfAlgo::GetOutputInferShape(reduce_sum, 0)};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fusion_node.get());
return fusion_node;
}
AnfNodePtr GetMul0(const FuncGraphPtr &graph, const AnfNodePtr &input2, const AnfNodePtr &mul1) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(input2);
@ -117,6 +95,28 @@ bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const Anf
}
} // namespace
CNodePtr ConfusionMulGradFusion::CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &reduce_sum,
const AnfNodePtr &mul0_anf, const AnfNodePtr &input3) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(reduce_sum);
MS_EXCEPTION_IF_NULL(mul0_anf);
MS_EXCEPTION_IF_NULL(input3);
auto mul0 = mul0_anf->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(mul0);
auto prim = std::make_shared<Primitive>(kConfusionMulGradOpName);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), mul0->input(kIndex1), mul0->input(kIndex2), input3};
auto fusion_node = NewCNode(inputs, graph);
MS_EXCEPTION_IF_NULL(fusion_node);
fusion_node->set_scope(reduce_sum->scope());
AnfAlgo::CopyNodeAttr(kAttrAxis, reduce_sum, fusion_node);
AnfAlgo::CopyNodeAttr(kAttrKeepDims, reduce_sum, fusion_node);
auto types = {AnfAlgo::GetOutputInferDataType(mul0, 0), AnfAlgo::GetOutputInferDataType(reduce_sum, 0)};
auto shapes = {AnfAlgo::GetOutputInferShape(mul0, 0), AnfAlgo::GetOutputInferShape(reduce_sum, 0)};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fusion_node.get());
return fusion_node;
}
const BaseRef ConfusionMulGradFusion::DefinePattern() const {
VectorRef mul1({prim::kPrimMul, input3_, input2_});
VectorRef reduce_sum({prim::kPrimReduceSum, mul1});
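Reading the pattern together with CreateFusionNode above: mul1 multiplies input3_ and input2_, ReduceSum reduces it, and a separately located mul0 (found via GetMul0) contributes the other operand pair; the fused kernel then produces both results at once. Sketched, under the assumption that the matched nodes carry the shapes used for the output abstracts:

// output0 = mul0_input1 * mul0_input2                    (replaces mul0)
// output1 = reduce_sum(input3 * input2, axis, keep_dims) (replaces mul1 + ReduceSum;
//                                                         input2 is one of mul0's inputs, per GetMul0)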


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -33,6 +33,8 @@ class ConfusionMulGradFusion : public PatternProcessPass {
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &reduce_sum, const AnfNodePtr &mul0_anf,
const AnfNodePtr &input3) const;
VarPtr input2_;
VarPtr input3_;
};


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -35,15 +35,16 @@ CNodePtr GetRelu(const CNodePtr &relu_grad) {
MS_EXCEPTION_IF_NULL(relu_anf);
return relu_anf->cast<CNodePtr>();
}
} // namespace
CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) {
CNodePtr DereluFusion::CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(relu);
CheckCNodeInputSize(relu, kReluInputTensorNum);
constexpr auto kMaskShapeSize = 4;
auto prim = std::make_shared<Primitive>(kReluV2OpName);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), relu->input(kIndex1)};
auto new_node = graph->NewCNode(inputs);
auto new_node = NewCNode(inputs, graph);
MS_EXCEPTION_IF_NULL(new_node);
new_node->set_scope(relu->scope());
@ -79,20 +80,20 @@ CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) {
return new_node;
}
CNodePtr CreateReluGradV2(const FuncGraphPtr &graph, const CNodePtr &relu_grad, const AnfNodePtr &second_input) {
CNodePtr DereluFusion::CreateReluGradV2(const FuncGraphPtr &graph, const CNodePtr &relu_grad,
const AnfNodePtr &second_input) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(relu_grad);
MS_EXCEPTION_IF_NULL(second_input);
auto prim = std::make_shared<Primitive>(kReluGradV2OpName);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), relu_grad->input(1), second_input};
auto new_node = graph->NewCNode(inputs);
auto new_node = NewCNode(inputs, graph);
MS_EXCEPTION_IF_NULL(new_node);
new_node->set_scope(relu_grad->scope());
new_node->set_abstract(relu_grad->abstract());
return new_node;
}
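Why the two creators come as a pair: ReluV2 emits a mask output alongside the activation, and ReluGradV2 consumes that mask as its second input instead of re-deriving it from the forward output. Assuming those standard semantics, the rewrite is roughly:

// before: y = Relu(x);            dx = ReluGrad(dy, y)
// after:  (y, mask) = ReluV2(x);  dx = ReluGradV2(dy, mask)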
} // namespace
const BaseRef DereluFusion::DefinePattern() const {
VarPtr i0 = std::make_shared<Var>();


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -27,6 +27,10 @@ class DereluFusion : public PatternProcessPass {
~DereluFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) const;
CNodePtr CreateReluGradV2(const FuncGraphPtr &graph, const CNodePtr &relu_grad, const AnfNodePtr &second_input) const;
};
} // namespace opt
} // namespace mindspore


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -92,7 +92,7 @@ AnfNodePtr FusedBatchNormFusion::CreateBNTrainingReduce(const FuncGraphPtr &func
// Set inputs to create the node
std::vector<AnfNodePtr> bn_training_reduce_inputs = {
NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)), GetAnfNodeByVar(equiv, data_input0_var_)};
auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs);
auto bn_training_reduce = NewCNode(bn_training_reduce_inputs, func_graph);
MS_EXCEPTION_IF_NULL(bn_training_reduce);
bn_training_reduce->set_scope(node->scope());
// Set abstract
@ -151,7 +151,7 @@ AnfNodePtr FusedBatchNormFusion::CreateBNTrainingUpdate(
// Set input
std::vector<AnfNodePtr> bn_training_update_inputs;
GetBNTrainingUpdateInputs(equiv, bn_training_reduce_outputs, &bn_training_update_inputs);
auto bn_training_update = func_graph->NewCNode(bn_training_update_inputs);
auto bn_training_update = NewCNode(bn_training_update_inputs, func_graph);
MS_EXCEPTION_IF_NULL(bn_training_update);
// Set abstract
AnfNodePtr bn = GetAnfNodeByVar(equiv, batch_norm_var_);


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -74,7 +74,7 @@ AnfNodePtr LambNextMVRule::CreateLambNextMVNode(const FuncGraphPtr &func_graph,
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul3_sub1_]));
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul4_x_]));
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[add2_y_]));
auto lamb_next_mv_rule = func_graph->NewCNode(lamb_next_mv_rule_inputs);
auto lamb_next_mv_rule = NewCNode(lamb_next_mv_rule_inputs, func_graph);
MS_EXCEPTION_IF_NULL(lamb_next_mv_rule);
// Set abstract of new node


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -74,7 +74,7 @@ AnfNodePtr LambNextMVWithDecayRule::CreateLambNextMVWithDecayNode(const FuncGrap
auto constant_add2_y_node = utils::cast<AnfNodePtr>((*equiv)[constant_add2_y_]);
MS_EXCEPTION_IF_NULL(constant_add2_y_node);
new_node_inputs.push_back(constant_add2_y_node);
auto new_node = func_graph->NewCNode(new_node_inputs);
auto new_node = NewCNode(new_node_inputs, func_graph);
return GetLambNextMVWithDecayOutput(func_graph, new_node, add3, add5, equiv);
}


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -44,7 +44,7 @@ AnfNodePtr LambNextRightRule::CreateLambNextRightNode(const FuncGraphPtr &func_g
auto add2_y = utils::cast<AnfNodePtr>((*equiv)[add2_y_]);
MS_EXCEPTION_IF_NULL(add2_y);
new_node_inputs.push_back(add2_y);
auto new_node = func_graph->NewCNode(new_node_inputs);
auto new_node = NewCNode(new_node_inputs, func_graph);
return new_node;
}


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -65,7 +65,7 @@ const AnfNodePtr LambUpdateWithLRRuleFusion::Process(const FuncGraphPtr &graph,
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {
NewValueNode(prim), input0, input1, input2, input3, input4, input5, input6, input7, input8};
auto lamb_update_with_lr = graph->NewCNode(inputs);
auto lamb_update_with_lr = NewCNode(inputs, graph);
MS_EXCEPTION_IF_NULL(lamb_update_with_lr);
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -47,7 +47,7 @@ const AnfNodePtr LambUpdateWithLrV2::Process(const FuncGraphPtr &func_graph, con
std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
(void)std::transform(input_varptr_.begin(), input_varptr_.end(), std::back_inserter(inputs),
[&equiv](const VarPtr &in) { return utils::cast<AnfNodePtr>((*equiv)[in]); });
auto lamb_update_with_lr_v2 = func_graph->NewCNode(inputs);
auto lamb_update_with_lr_v2 = NewCNode(inputs, func_graph);
MS_EXCEPTION_IF_NULL(lamb_update_with_lr_v2);
lamb_update_with_lr_v2->set_abstract(node->abstract());


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -48,7 +48,7 @@ const AnfNodePtr MatmulBiasaddFusion::Process(const FuncGraphPtr &graph, const A
inputs.emplace_back(GetAnfNodeByVar(equiv, x0_));
inputs.emplace_back(GetAnfNodeByVar(equiv, x1_));
inputs.emplace_back(GetAnfNodeByVar(equiv, x2_));
auto new_node = graph->NewCNode(inputs);
auto new_node = NewCNode(inputs, graph);
MS_EXCEPTION_IF_NULL(new_node);
new_node->set_scope(node->scope());
new_node->set_abstract(node->abstract());
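The inputs gathered here suggest the usual MatMul+BiasAdd folding, where the bias becomes a third MatMul input; hedged as a sketch, since the primitive of the new node sits outside this hunk:

// BiasAdd(MatMul(x0, x1), x2)  ==>  MatMul(x0, x1, x2)   // bias folded into the fused op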


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -85,7 +85,7 @@ const AnfNodePtr MomentumLossscaleFusion::Process(const FuncGraphPtr &func_graph
mul_cnode->input(kMulInputTensorNum + 1 - value_node_index),
depend,
mul_cnode->input(value_node_index)};
auto new_node = func_graph->NewCNode(new_node_inputs);
auto new_node = NewCNode(new_node_inputs, func_graph);
MS_EXCEPTION_IF_NULL(new_node);
AnfAlgo::CopyNodeAttrs(node, new_node);
auto input_names_value = AnfAlgo::GetNodeAttr<std::vector<std::string>>(new_node, kAttrInputNames);


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -90,7 +90,7 @@ const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodeP
return nullptr;
}
inputs.push_back(another_input_node);
auto fusion_node = graph->NewCNode(inputs);
auto fusion_node = NewCNode(inputs, graph);
fusion_node->set_scope(add->scope());
fusion_node->set_abstract(add->abstract());
return fusion_node;


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -22,7 +22,6 @@
namespace mindspore {
namespace opt {
namespace {
CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &mul, const CNodePtr &addn,
const size_t &lossscale_input_index) {
MS_EXCEPTION_IF_NULL(graph);
@ -34,13 +33,12 @@ CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &mul, const
inputs.push_back(addn->input(kIndex2));
// the scalar input should be the 3rd input
inputs.push_back(mul->input(lossscale_input_index));
auto fusion_node = graph->NewCNode(inputs);
auto fusion_node = NewCNode(inputs, graph);
MS_EXCEPTION_IF_NULL(fusion_node);
fusion_node->set_scope(addn->scope());
fusion_node->set_abstract(addn->abstract());
return fusion_node;
}
} // namespace
const BaseRef MulAddNFusion::DefinePattern() const {
VarPtr X = std::make_shared<Var>();


@ -49,7 +49,7 @@ const AnfNodePtr PReluFusion::Process(const FuncGraphPtr &graph, const AnfNodePt
auto prim = std::make_shared<Primitive>(kPReluOpName);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x, weight};
auto fusion_node = graph->NewCNode(inputs);
auto fusion_node = NewCNode(inputs, graph);
MS_EXCEPTION_IF_NULL(fusion_node);
fusion_node->set_abstract(node->abstract());
fusion_node->set_scope(node->scope());


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -69,7 +69,7 @@ const AnfNodePtr ReshapeTransposeFusion::Process(const FuncGraphPtr &func_graph,
auto prim = std::make_shared<Primitive>(kConfusionTransposeDOpName);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
auto new_node = func_graph->NewCNode(inputs);
auto new_node = NewCNode(inputs, func_graph);
MS_EXCEPTION_IF_NULL(new_node);
new_node->set_abstract(node->abstract());


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -65,7 +65,7 @@ const AnfNodePtr SoftmaxGradExtFusion::Process(const FuncGraphPtr &graph, const
}
auto prim = std::make_shared<Primitive>(kSoftmaxGradExtOpName);
auto fusion_node = graph->NewCNode({NewValueNode(prim), input0, input1, input2});
auto fusion_node = NewCNode({NewValueNode(prim), input0, input1, input2}, graph);
MS_EXCEPTION_IF_NULL(fusion_node);
fusion_node->set_scope(node->scope());
fusion_node->set_abstract(node->abstract());


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -30,7 +30,22 @@
namespace mindspore {
namespace opt {
namespace {
CNodePtr GenerateSquareSumV1(const FuncGraphPtr &graph, const CNodePtr &square, const CNodePtr &sum) {
std::tuple<CNodePtr, AnfNodePtr, CNodePtr> GetPrevNodes(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto sum = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(sum);
CheckCNodeInputSize(sum, kSumNodeInputTensorNum);
auto square_anf = sum->input(1);
MS_EXCEPTION_IF_NULL(square_anf);
auto square = square_anf->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(square);
return std::make_tuple(sum, square_anf, square);
}
} // namespace
CNodePtr SquareSumFusion::GenerateSquareSumV1(const FuncGraphPtr &graph, const CNodePtr &square,
const CNodePtr &sum) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(square);
MS_EXCEPTION_IF_NULL(sum);
@ -38,7 +53,7 @@ CNodePtr GenerateSquareSumV1(const FuncGraphPtr &graph, const CNodePtr &square,
auto prim = std::make_shared<Primitive>(kSquareSumV1OpName);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> square_sumv1_inputs = {NewValueNode(prim), square->input(1)};
auto square_sumv1 = graph->NewCNode(square_sumv1_inputs);
auto square_sumv1 = NewCNode(square_sumv1_inputs, graph);
MS_EXCEPTION_IF_NULL(square_sumv1);
auto kernel_info = std::make_shared<device::KernelInfo>();
MS_EXCEPTION_IF_NULL(kernel_info);
@ -54,7 +69,8 @@ CNodePtr GenerateSquareSumV1(const FuncGraphPtr &graph, const CNodePtr &square,
return square_sumv1;
}
CNodePtr GenerateSquareSumV2(const FuncGraphPtr &graph, const CNodePtr &square, const CNodePtr &sum) {
CNodePtr SquareSumFusion::GenerateSquareSumV2(const FuncGraphPtr &graph, const CNodePtr &square,
const CNodePtr &sum) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(square);
MS_EXCEPTION_IF_NULL(sum);
@ -62,7 +78,7 @@ CNodePtr GenerateSquareSumV2(const FuncGraphPtr &graph, const CNodePtr &square,
auto prim = std::make_shared<Primitive>(kSquareSumV2OpName);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> square_sumv2_inputs = {NewValueNode(prim), square->input(1)};
auto square_sumv2 = graph->NewCNode(square_sumv2_inputs);
auto square_sumv2 = NewCNode(square_sumv2_inputs, graph);
MS_EXCEPTION_IF_NULL(square_sumv2);
auto types = {AnfAlgo::GetOutputInferDataType(sum, 0), AnfAlgo::GetOutputInferDataType(square, 0)};
auto shapes = {AnfAlgo::GetOutputInferShape(sum, 0), AnfAlgo::GetOutputInferShape(square, 0)};
@ -75,20 +91,6 @@ CNodePtr GenerateSquareSumV2(const FuncGraphPtr &graph, const CNodePtr &square,
return square_sumv2;
}
std::tuple<CNodePtr, AnfNodePtr, CNodePtr> GetPrevNodes(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto sum = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(sum);
CheckCNodeInputSize(sum, kSumNodeInputTensorNum);
auto square_anf = sum->input(1);
MS_EXCEPTION_IF_NULL(square_anf);
auto square = square_anf->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(square);
return std::make_tuple(sum, square_anf, square);
}
} // namespace
const BaseRef SquareSumFusion::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
MS_EXCEPTION_IF_NULL(X);
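A note on the two generators above: judging by their output type/shape lists, SquareSumV1 is chosen when only the reduced sum is consumed, while SquareSumV2 additionally re-exposes the Square result for its other users. In effect:

// V1: s = SquareSumV1(x)        ~ reduce_sum(square(x))
// V2: (s, q) = SquareSumV2(x)   ~ (reduce_sum(square(x)), square(x))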


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -28,6 +28,10 @@ class SquareSumFusion : public PatternProcessPassWithSwitch {
~SquareSumFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
CNodePtr GenerateSquareSumV1(const FuncGraphPtr &graph, const CNodePtr &square, const CNodePtr &sum) const;
CNodePtr GenerateSquareSumV2(const FuncGraphPtr &graph, const CNodePtr &square, const CNodePtr &sum) const;
};
} // namespace opt
} // namespace mindspore


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -68,7 +68,7 @@ const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph,
}
auto prim = std::make_shared<Primitive>(kConfusionTransposeDOpName);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
auto new_node = func_graph->NewCNode(inputs);
auto new_node = NewCNode(inputs, func_graph);
MS_EXCEPTION_IF_NULL(new_node);
new_node->set_abstract(node->abstract());


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -57,7 +57,7 @@ const AnfNodePtr TransposeTransDataFusion::Process(const FuncGraphPtr &func_grap
if (supported_checker_->CheckAICoreSupported(transdata_cnode, new_transdata_builder->Build())) {
std::vector<AnfNodePtr> inputs = {NewValueNode(new_fusion_transdata),
utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
auto new_node = func_graph->NewCNode(inputs);
auto new_node = NewCNode(inputs, func_graph);
MS_EXCEPTION_IF_NULL(new_node);
new_node->set_abstract(node->abstract());
AnfAlgo::CopyNodeAttrs(transdata_cnode, new_node);


@ -79,7 +79,7 @@ const AnfNodePtr TransposedUpdateFusion::Process(const FuncGraphPtr &func_graph,
auto perm_vnode = CreatePermValueNode(transposed);
std::vector<AnfNodePtr> transpose_inputs = {NewValueNode(std::make_shared<Primitive>(kTransposeNODOpName)),
transposed->input(1), perm_vnode};
auto transpose = kernel_graph->NewCNode(transpose_inputs);
auto transpose = NewCNode(transpose_inputs, kernel_graph);
transpose->set_scope(transposed->scope());
transpose->set_abstract(transposed->abstract());


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -53,8 +53,9 @@ uint32_t GetRankSize(const std::string &group) {
}
return rank_size;
}
} // namespace
CNodePtr CreateSplitNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all) {
CNodePtr AllToAllUnifyMindIR::CreateSplitNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(all_to_all);
int64_t split_count = AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrSplitCount);
@ -66,7 +67,7 @@ CNodePtr CreateSplitNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all)
auto all_to_all_input = all_to_all->input(kAllToAllInputIdx);
std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
all_to_all_input};
auto split_v = graph->NewCNode(split_input);
auto split_v = NewCNode(split_input, graph);
MS_EXCEPTION_IF_NULL(split_v);
auto dtype = AnfAlgo::GetOutputInferDataType(all_to_all_input, 0);
auto shape = AnfAlgo::GetOutputInferShape(all_to_all_input, 0);
@ -90,7 +91,8 @@ CNodePtr CreateSplitNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all)
return split_v;
}
CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all, const CNodePtr &split) {
CNodePtr AllToAllUnifyMindIR::CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all,
const CNodePtr &split) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(all_to_all);
MS_EXCEPTION_IF_NULL(split);
@ -103,7 +105,7 @@ CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &all_to_a
}
std::vector<AnfNodePtr> all_to_all_v_input = {NewValueNode(std::make_shared<Primitive>(kAllToAllVOpName))};
(void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs.begin(), split_outputs.end());
auto all_to_all_v = graph->NewCNode(all_to_all_v_input);
auto all_to_all_v = NewCNode(all_to_all_v_input, graph);
MS_EXCEPTION_IF_NULL(all_to_all_v);
auto single_shape = AnfAlgo::GetOutputInferShape(split_outputs[0], 0);
auto single_type = AnfAlgo::GetOutputInferDataType(split_outputs[0], 0);
@ -123,7 +125,8 @@ CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &all_to_a
return all_to_all_v;
}
CNodePtr CreateConcatNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all, const CNodePtr &all_to_all_v) {
CNodePtr AllToAllUnifyMindIR::CreateConcatNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all,
const CNodePtr &all_to_all_v) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(all_to_all);
MS_EXCEPTION_IF_NULL(all_to_all_v);
@ -136,7 +139,7 @@ CNodePtr CreateConcatNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all,
}
std::vector<AnfNodePtr> concat_input = {NewValueNode(std::make_shared<Primitive>(kConcatOpName))};
(void)concat_input.insert(concat_input.end(), all_to_all_v_outputs.begin(), all_to_all_v_outputs.end());
auto concat = graph->NewCNode(concat_input);
auto concat = NewCNode(concat_input, graph);
MS_EXCEPTION_IF_NULL(concat);
auto single_shape = AnfAlgo::GetOutputInferShape(all_to_all_v_outputs[0], 0);
concat_dim = NormalizeDim(single_shape, concat_dim);
@ -152,7 +155,6 @@ CNodePtr CreateConcatNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all,
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat);
return concat;
}
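Taken together, CreateSplitNode, CreateAllToAllvNode, and CreateConcatNode lower one AllToAll into primitive ops: SplitV cuts the input into split_count slices, AllToAllV exchanges one slice with each rank, and Concat stitches the received slices back together along concat_dim. For a hypothetical split_count of 4 with split_dim = concat_dim = 0:

// x: [4*S, H] --SplitV--> 4 x [S, H] --AllToAllV--> 4 x [S, H] (one slice from each rank)
//                                    --Concat(dim 0)--> [4*S, H], now rearranged across ranks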
} // namespace
const BaseRef NeighborExchangeUnifyMindIR::DefinePattern() const {
return VectorRef({prim::kPrimNeighborExchange, std::make_shared<SeqVar>()});


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -36,6 +36,11 @@ class AllToAllUnifyMindIR : public PatternProcessPass {
~AllToAllUnifyMindIR() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
CNodePtr CreateSplitNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all) const;
CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all, const CNodePtr &split) const;
CNodePtr CreateConcatNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all, const CNodePtr &all_to_all_v) const;
};
} // namespace opt
} // namespace mindspore


@ -203,7 +203,7 @@ const AnfNodePtr AvgPoolGradUnifyMindIR::Process(const FuncGraphPtr &graph, cons
std::vector<AnfNodePtr> avgpool_grad_vm_inputs = {NewValueNode(std::make_shared<Primitive>(kAvgPoolGradVmOpName)),
x_shape_vnode, avgpool_grad->input(3), mean_matrix_vnode,
kernel_matrix_vnode};
auto avgpool_grad_vm = graph->NewCNode(avgpool_grad_vm_inputs);
auto avgpool_grad_vm = NewCNode(avgpool_grad_vm_inputs, graph);
MS_EXCEPTION_IF_NULL(avgpool_grad_vm);
avgpool_grad_vm->set_scope(avgpool_grad->scope());
avgpool_grad_vm->set_abstract(avgpool_grad->abstract());


@ -28,8 +28,10 @@ namespace mindspore {
namespace opt {
namespace {
constexpr auto kAttrUnifyIRPassed = "unifyir_passed";
} // namespace
AnfNodePtr CreateNewBatchNormGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node) {
AnfNodePtr BatchNormGradUnifyMindIR::CreateNewBatchNormGrad(const FuncGraphPtr &graph,
const CNodePtr &bn_grad_node) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(bn_grad_node);
size_t kBNGradInputNum = 6;
@ -41,7 +43,7 @@ AnfNodePtr CreateNewBatchNormGrad(const FuncGraphPtr &graph, const CNodePtr &bn_
bn_grad_node_inputs[kDim3],
bn_grad_node_inputs[kDim4],
bn_grad_node_inputs[kDim5]};
auto new_bn_grad = graph->NewCNode(bn_grad_inputs);
auto new_bn_grad = NewCNode(bn_grad_inputs, graph);
MS_EXCEPTION_IF_NULL(new_bn_grad);
new_bn_grad->set_scope(bn_grad_node->scope());
auto types = {AnfAlgo::GetOutputInferDataType(bn_grad_node, 0), AnfAlgo::GetOutputInferDataType(bn_grad_node, 1),
@ -56,7 +58,6 @@ AnfNodePtr CreateNewBatchNormGrad(const FuncGraphPtr &graph, const CNodePtr &bn_
AnfAlgo::SetNodeAttr(kAttrUnifyIRPassed, MakeValue(true), new_bn_grad);
return new_bn_grad;
}
} // namespace
const BaseRef BatchNormGradUnifyMindIR::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();


@ -27,6 +27,9 @@ class BatchNormGradUnifyMindIR : public PatternProcessPass {
~BatchNormGradUnifyMindIR() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
AnfNodePtr CreateNewBatchNormGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node) const;
};
} // namespace opt
} // namespace mindspore


@ -91,7 +91,7 @@ ValueNodePtr CreatePermValueNode(const FuncGraphPtr &func_graph, const std::vect
}
CNodePtr CreateTranspose(const FuncGraphPtr &graph, const CNodePtr &conv2d, const AnfNodePtr &input_node,
bool need_trans_output) {
bool need_trans_output, const PatternProcessPass &pass) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(conv2d);
MS_EXCEPTION_IF_NULL(input_node);
@ -105,7 +105,7 @@ CNodePtr CreateTranspose(const FuncGraphPtr &graph, const CNodePtr &conv2d, cons
transpose_inputs = {NewValueNode(std::make_shared<Primitive>(kTransposeOpName)), input_node,
CreatePermValueNode(graph, perm)};
}
auto transpose = graph->NewCNode(transpose_inputs);
auto transpose = pass.NewCNode(transpose_inputs, graph);
MS_EXCEPTION_IF_NULL(transpose);
transpose->set_scope(conv2d->scope());
@ -133,82 +133,6 @@ CNodePtr CreateTranspose(const FuncGraphPtr &graph, const CNodePtr &conv2d, cons
return transpose;
}
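Note the second mechanism this commit uses: CreateTranspose stays a free function shared by several passes, so rather than becoming a member it now takes the pass by const reference and calls pass.NewCNode. CreateDropoutGenMaskCNode later in this change is threaded the same way. The shape of the pattern, with hypothetical names:

// Free helper shared by several passes: thread the pass through explicitly
// so node creation still goes through PatternProcessPass::NewCNode.
CNodePtr MakeHelperNode(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &inputs,
                        const PatternProcessPass &pass) {
  return pass.NewCNode(inputs, graph);
}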
CNodePtr CreateDepthwiseConv2D(const FuncGraphPtr &graph, const CNodePtr &conv2d, const CNodePtr &transpose) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(conv2d);
CheckCNodeInputSize(conv2d, kConvInputTensorNum);
std::vector<AnfNodePtr> depth_conv_inputs = {NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeOpName)),
conv2d->input(kIndex1), transpose};
auto depth_conv = graph->NewCNode(depth_conv_inputs);
MS_EXCEPTION_IF_NULL(depth_conv);
depth_conv->set_abstract(conv2d->abstract());
depth_conv->set_scope(conv2d->scope());
return depth_conv;
}
CNodePtr CreateDepthwiseConv2DBackpropInput(const FuncGraphPtr &graph, const CNodePtr &conv2d_backin,
const CNodePtr &transpose) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(conv2d_backin);
CNodePtr depth_conv_backin = nullptr;
if (conv2d_backin->inputs().size() == kConv2DBackpropInputNum) {
std::vector<AnfNodePtr> depth_conv_backin_inputs = {
NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropInputOpName)),
conv2d_backin->input(kIndex3), transpose, conv2d_backin->input(kIndex1)};
depth_conv_backin = graph->NewCNode(depth_conv_backin_inputs);
} else {
// In nn.Conv2DTranspose, Conv2DBackpropInput is a forward op and its input_sizes input will be converted to an
// attr in pynative mode.
std::vector<AnfNodePtr> depth_conv_backin_inputs = {
NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropInputOpName)), transpose,
conv2d_backin->input(kIndex1)};
depth_conv_backin = graph->NewCNode(depth_conv_backin_inputs);
AnfAlgo::CopyNodeAttr(kAttrInputSizes, kAttrInputSize, conv2d_backin, depth_conv_backin);
}
MS_EXCEPTION_IF_NULL(depth_conv_backin);
depth_conv_backin->set_abstract(conv2d_backin->abstract());
depth_conv_backin->set_scope(conv2d_backin->scope());
return depth_conv_backin;
}
CNodePtr CreateDepthwiseConv2DBackpropFilter(const FuncGraphPtr &graph, const CNodePtr &conv2d_backfil) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(conv2d_backfil);
if (conv2d_backfil->inputs().size() != kConv2DBackpropInputNum) {
MS_LOG(EXCEPTION) << "Conv2DBackpropFilter's input number should be " << (kConv2DBackpropInputNum - 1)
<< ", but got " << (conv2d_backfil->inputs().size() - 1);
}
auto filter_size_node = conv2d_backfil->input(kIndex3);
MS_EXCEPTION_IF_NULL(filter_size_node);
auto filter_size_vnode = filter_size_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(filter_size_vnode);
auto filter_size = GetValue<std::vector<int64_t>>(filter_size_vnode->value());
// Swap axes 0 and 1 of the filter shape, but don't swap twice, since some nodes share the same filter_size
// ValueNode when the filter_size values are equal.
if (filter_size[0] != 1) {
std::swap(filter_size[0], filter_size[1]);
conv2d_backfil->input(kIndex3)->cast<ValueNodePtr>()->set_value(MakeValue(filter_size));
}
std::vector<AnfNodePtr> depth_conv_backfil_inputs = {
NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropFilterOpName)),
conv2d_backfil->input(kIndex2), conv2d_backfil->input(kIndex3), conv2d_backfil->input(kIndex1)};
auto depth_conv_backfil = graph->NewCNode(depth_conv_backfil_inputs);
MS_EXCEPTION_IF_NULL(depth_conv_backfil);
depth_conv_backfil->set_scope(conv2d_backfil->scope());
auto types = {AnfAlgo::GetOutputInferDataType(conv2d_backfil, 0)};
std::vector<size_t> out_shape = AnfAlgo::GetOutputInferShape(conv2d_backfil, 0);
if (out_shape.size() != kConv2DAxisNum) {
MS_LOG(EXCEPTION) << "Conv2DBackpropFilter's output axis number should be " << kConv2DAxisNum << ", but got "
<< out_shape.size();
}
std::swap(out_shape[0], out_shape[1]);
auto shapes = {out_shape};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, depth_conv_backfil.get());
return depth_conv_backfil;
}
void SetCommonAttrs(const CNodePtr &conv2d, const CNodePtr &depth_conv) {
AnfAlgo::CopyNodeAttr(kAttrKernelSize, conv2d, depth_conv);
AnfAlgo::CopyNodeAttr(kAttrDilation, conv2d, depth_conv);
@ -256,6 +180,20 @@ void SetConv2DBackpropFilterAttrs(const CNodePtr &conv2d_backfil, const CNodePtr
}
} // namespace
CNodePtr Conv2DUnifyMindIR::CreateDepthwiseConv2D(const FuncGraphPtr &graph, const CNodePtr &conv2d,
const CNodePtr &transpose) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(conv2d);
CheckCNodeInputSize(conv2d, kConvInputTensorNum);
std::vector<AnfNodePtr> depth_conv_inputs = {NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeOpName)),
conv2d->input(kIndex1), transpose};
auto depth_conv = NewCNode(depth_conv_inputs, graph);
MS_EXCEPTION_IF_NULL(depth_conv);
depth_conv->set_abstract(conv2d->abstract());
depth_conv->set_scope(conv2d->scope());
return depth_conv;
}
const BaseRef Conv2DUnifyMindIR::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
VarPtr W = std::make_shared<Var>();
@ -275,12 +213,39 @@ const AnfNodePtr Conv2DUnifyMindIR::Process(const FuncGraphPtr &graph, const Anf
return nullptr;
}
CheckCNodeInputSize(conv2d, kConvInputTensorNum);
auto transpose = CreateTranspose(graph, conv2d, conv2d->input(kIndex2), true);
auto transpose = CreateTranspose(graph, conv2d, conv2d->input(kIndex2), true, *this);
auto depth_conv = CreateDepthwiseConv2D(graph, conv2d, transpose);
SetConv2DAttrs(conv2d, depth_conv);
return depth_conv;
}
CNodePtr Conv2DBackpropInputUnifyMindIR::CreateDepthwiseConv2DBackpropInput(const FuncGraphPtr &graph,
const CNodePtr &conv2d_backin,
const CNodePtr &transpose) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(conv2d_backin);
CNodePtr depth_conv_backin = nullptr;
if (conv2d_backin->inputs().size() == kConv2DBackpropInputNum) {
std::vector<AnfNodePtr> depth_conv_backin_inputs = {
NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropInputOpName)),
conv2d_backin->input(kIndex3), transpose, conv2d_backin->input(kIndex1)};
depth_conv_backin = NewCNode(depth_conv_backin_inputs, graph);
} else {
// In nn.Conv2DTranspose, Conv2DBackpropInput is a forward op and its input_sizes input will be converted to an
// attr in pynative mode.
std::vector<AnfNodePtr> depth_conv_backin_inputs = {
NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropInputOpName)), transpose,
conv2d_backin->input(kIndex1)};
depth_conv_backin = NewCNode(depth_conv_backin_inputs, graph);
AnfAlgo::CopyNodeAttr(kAttrInputSizes, kAttrInputSize, conv2d_backin, depth_conv_backin);
}
MS_EXCEPTION_IF_NULL(depth_conv_backin);
depth_conv_backin->set_abstract(conv2d_backin->abstract());
depth_conv_backin->set_scope(conv2d_backin->scope());
return depth_conv_backin;
}
const BaseRef Conv2DBackpropInputUnifyMindIR::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
VectorRef pattern({prim::kPrimConv2DBackpropInput, Xs});
@ -306,12 +271,50 @@ const AnfNodePtr Conv2DBackpropInputUnifyMindIR::Process(const FuncGraphPtr &gra
MS_LOG(EXCEPTION) << "Conv2DBackpropInput's input number should be " << (kConv2DBackpropInputNum - 1) << " or "
<< (kConv2DBackpropInputNum - 2) << ", but got " << (input_size - 1);
}
auto transpose = CreateTranspose(graph, conv2d_backin, conv2d_backin->input(kIndex2), true);
auto transpose = CreateTranspose(graph, conv2d_backin, conv2d_backin->input(kIndex2), true, *this);
auto depth_conv_backin = CreateDepthwiseConv2DBackpropInput(graph, conv2d_backin, transpose);
SetConv2DBackpropInputAttrs(conv2d_backin, depth_conv_backin);
return depth_conv_backin;
}
CNodePtr Conv2DBackpropFilterUnifyMindIR::CreateDepthwiseConv2DBackpropFilter(const FuncGraphPtr &graph,
const CNodePtr &conv2d_backfil) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(conv2d_backfil);
if (conv2d_backfil->inputs().size() != kConv2DBackpropInputNum) {
MS_LOG(EXCEPTION) << "Conv2DBackpropFilter's input number should be " << (kConv2DBackpropInputNum - 1)
<< ", but got " << (conv2d_backfil->inputs().size() - 1);
}
auto filter_size_node = conv2d_backfil->input(kIndex3);
MS_EXCEPTION_IF_NULL(filter_size_node);
auto filter_size_vnode = filter_size_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(filter_size_vnode);
auto filter_size = GetValue<std::vector<int64_t>>(filter_size_vnode->value());
// Swap axes 0 and 1 of the filter shape, but don't swap twice, since some nodes share the same filter_size
// ValueNode when the filter_size values are equal.
if (filter_size[0] != 1) {
std::swap(filter_size[0], filter_size[1]);
conv2d_backfil->input(kIndex3)->cast<ValueNodePtr>()->set_value(MakeValue(filter_size));
}
std::vector<AnfNodePtr> depth_conv_backfil_inputs = {
NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropFilterOpName)),
conv2d_backfil->input(kIndex2), conv2d_backfil->input(kIndex3), conv2d_backfil->input(kIndex1)};
auto depth_conv_backfil = NewCNode(depth_conv_backfil_inputs, graph);
MS_EXCEPTION_IF_NULL(depth_conv_backfil);
depth_conv_backfil->set_scope(conv2d_backfil->scope());
auto types = {AnfAlgo::GetOutputInferDataType(conv2d_backfil, 0)};
std::vector<size_t> out_shape = AnfAlgo::GetOutputInferShape(conv2d_backfil, 0);
if (out_shape.size() != kConv2DAxisNum) {
MS_LOG(EXCEPTION) << "Conv2DBackpropFilter's output axis number should be " << kConv2DAxisNum << ", but got "
<< out_shape.size();
}
std::swap(out_shape[0], out_shape[1]);
auto shapes = {out_shape};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, depth_conv_backfil.get());
return depth_conv_backfil;
}
const BaseRef Conv2DBackpropFilterUnifyMindIR::DefinePattern() const {
VarPtr dout = std::make_shared<Var>();
VarPtr input = std::make_shared<Var>();
@ -335,7 +338,7 @@ const AnfNodePtr Conv2DBackpropFilterUnifyMindIR::Process(const FuncGraphPtr &gr
auto depth_conv_backfil = CreateDepthwiseConv2DBackpropFilter(graph, conv2d_backfil);
SetConv2DBackpropFilterAttrs(conv2d_backfil, depth_conv_backfil);
auto transpose = CreateTranspose(graph, conv2d_backfil, depth_conv_backfil, false);
auto transpose = CreateTranspose(graph, conv2d_backfil, depth_conv_backfil, false, *this);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);


@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -27,6 +27,9 @@ class Conv2DUnifyMindIR : public PatternProcessPass {
~Conv2DUnifyMindIR() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
CNodePtr CreateDepthwiseConv2D(const FuncGraphPtr &graph, const CNodePtr &conv2d, const CNodePtr &transpose) const;
};
class Conv2DBackpropInputUnifyMindIR : public PatternProcessPass {
@ -36,6 +39,10 @@ class Conv2DBackpropInputUnifyMindIR : public PatternProcessPass {
~Conv2DBackpropInputUnifyMindIR() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
CNodePtr CreateDepthwiseConv2DBackpropInput(const FuncGraphPtr &graph, const CNodePtr &conv2d_backin,
const CNodePtr &transpose) const;
};
class Conv2DBackpropFilterUnifyMindIR : public PatternProcessPass {
@ -45,6 +52,9 @@ class Conv2DBackpropFilterUnifyMindIR : public PatternProcessPass {
~Conv2DBackpropFilterUnifyMindIR() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
CNodePtr CreateDepthwiseConv2DBackpropFilter(const FuncGraphPtr &graph, const CNodePtr &conv2d_backfil) const;
};
} // namespace opt
} // namespace mindspore


@ -20,6 +20,7 @@
#include <memory>
#include <numeric>
#include <functional>
#include <algorithm>
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/log_adapter.h"
@ -140,15 +141,15 @@ CNodePtr CreateDynamicShapeCNode(const FuncGraphPtr &func_graph, const AnfNodePt
return dynamic_shape;
}
CNodePtr CreateDropoutGenMaskCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &dropout,
const ValueNodePtr &keep_prob_value, const AnfNodePtr &dropout_input,
const abstract::ShapePtr &input_shape) {
CNodePtr CreateDropoutGenMaskCNode(const FuncGraphPtr &func_graph, const CNodePtr &dropout,
const ValueNodePtr &keep_prob_value, const abstract::ShapePtr &input_shape,
const PatternProcessPass &pass) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(dropout);
MS_EXCEPTION_IF_NULL(input_shape);
std::vector<AnfNodePtr> dropout_gen_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName))};
if (input_shape->IsDynamic()) {
CNodePtr dynamic_shape = CreateDynamicShapeCNode(func_graph, dropout_input, input_shape);
CNodePtr dynamic_shape = CreateDynamicShapeCNode(func_graph, dropout->input(kIndex1), input_shape);
dynamic_shape->set_scope(dropout->scope());
dropout_gen_mask_inputs.push_back(dynamic_shape);
dropout_gen_mask_inputs.push_back(keep_prob_value);
@ -157,7 +158,7 @@ CNodePtr CreateDropoutGenMaskCNode(const FuncGraphPtr &func_graph, const AnfNode
dropout_gen_mask_inputs.push_back(shape_value);
dropout_gen_mask_inputs.push_back(keep_prob_value);
}
CNodePtr dropout_gen_mask = func_graph->NewCNode(dropout_gen_mask_inputs);
CNodePtr dropout_gen_mask = NewCNode(dropout_gen_mask_inputs, func_graph);
MS_EXCEPTION_IF_NULL(dropout_gen_mask);
std::shared_ptr<abstract::AbstractTensor> gen_mask_abstract;
@ -223,8 +224,7 @@ const AnfNodePtr DropoutAndDropoutGradUnifyMindIR::Process(const FuncGraphPtr &f
auto dropout_input = dropout_cnode->input(kIndex1);
auto input_shape = GetDropoutInputShape(dropout_input);
// CreateDropoutGenMask
auto dropout_gen_mask =
CreateDropoutGenMaskCNode(func_graph, dropout_node, keep_prob_value, dropout_input, input_shape);
auto dropout_gen_mask = CreateDropoutGenMaskCNode(func_graph, dropout_cnode, keep_prob_value, input_shape, *this);
// CreateDropoutDoMask-forward
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
@ -242,7 +242,7 @@ const AnfNodePtr DropoutAndDropoutGradUnifyMindIR::Process(const FuncGraphPtr &f
std::vector<AnfNodePtr> dropout_do_mask1_inputs{
NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)), dropout_input, dropout_gen_mask,
keep_prob_value};
dropout_do_mask1 = func_graph->NewCNode(dropout_do_mask1_inputs);
dropout_do_mask1 = NewCNode(dropout_do_mask1_inputs, func_graph);
MS_EXCEPTION_IF_NULL(dropout_do_mask1);
auto do_mask_abstract1 =
std::make_shared<abstract::AbstractTensor>(TypeIdToType(inputx_type_id), input_shape);
@ -262,7 +262,7 @@ const AnfNodePtr DropoutAndDropoutGradUnifyMindIR::Process(const FuncGraphPtr &f
auto dropout_grad_input = utils::cast<AnfNodePtr>((*equiv)[grad_input_]);
std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
dropout_grad_input, dropout_gen_mask, keep_prob_value};
auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs);
auto dropout_do_mask = NewCNode(dropout_do_mask_inputs, func_graph);
MS_EXCEPTION_IF_NULL(dropout_do_mask);
auto do_mask_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(inputx_type_id), input_shape);
dropout_do_mask->set_abstract(do_mask_abstract);
@ -300,12 +300,11 @@ const AnfNodePtr DropoutUnifyMindIR0::Process(const FuncGraphPtr &func_graph, co
auto input_shape = GetDropoutInputShape(dropout_input);
// CreateDropoutGenMask
auto dropout_gen_mask =
CreateDropoutGenMaskCNode(func_graph, dropout_node, keep_prob_value, dropout_input, input_shape);
auto dropout_gen_mask = CreateDropoutGenMaskCNode(func_graph, dropout_cnode, keep_prob_value, input_shape, *this);
// CreateDropoutDoMask
std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
dropout_input, dropout_gen_mask, keep_prob_value};
auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs);
auto dropout_do_mask = NewCNode(dropout_do_mask_inputs, func_graph);
MS_EXCEPTION_IF_NULL(dropout_do_mask);
auto do_mask_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(inputx_type_id), input_shape);
dropout_do_mask->set_abstract(do_mask_abstract);
@ -341,12 +340,11 @@ const AnfNodePtr DropoutUnifyMindIR1::Process(const FuncGraphPtr &func_graph, co
auto dropout_input = dropout_node->input(kIndex1);
auto input_shape = GetDropoutInputShape(dropout_input);
// CreateDropoutGenMask
auto dropout_gen_mask =
CreateDropoutGenMaskCNode(func_graph, dropout_node, keep_prob_value, dropout_input, input_shape);
auto dropout_gen_mask = CreateDropoutGenMaskCNode(func_graph, dropout_node, keep_prob_value, input_shape, *this);
// CreateDropoutDoMask
std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
dropout_input, dropout_gen_mask, keep_prob_value};
auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs);
auto dropout_do_mask = NewCNode(dropout_do_mask_inputs, func_graph);
MS_EXCEPTION_IF_NULL(dropout_do_mask);
auto do_mask_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(inputx_type_id), input_shape);
dropout_do_mask->set_abstract(do_mask_abstract);
@ -396,7 +394,7 @@ const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph,
auto grad_input = dropout_grad_cnode->input(kIndex1);
std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
grad_input, mask_input, keep_prob_value};
auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs);
auto dropout_do_mask = NewCNode(dropout_do_mask_inputs, func_graph);
MS_EXCEPTION_IF_NULL(dropout_do_mask);
auto do_mask_abstract =
std::make_shared<abstract::AbstractTensor>(TypeIdToType(grad_input_type_id), grad_input_shape);
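
For context on what these rewrites construct: the Dropout passes above lower the single front-end op into a DropoutGenMask/DropoutDoMask pair. A toy numeric sketch (plain C++, not the Ascend kernels) of that two-op split, assuming the usual inverted-dropout scaling:

#include <iostream>
#include <random>
#include <vector>

int main() {
  const double keep_prob = 0.9;
  std::vector<double> x{1.0, 2.0, 3.0, 4.0};
  // "DropoutGenMask": draw one Bernoulli(keep_prob) mask bit per element.
  std::mt19937 gen(42);
  std::bernoulli_distribution keep(keep_prob);
  std::vector<int> mask(x.size());
  for (auto &m : mask) m = keep(gen) ? 1 : 0;
  // "DropoutDoMask": zero the dropped elements and rescale the kept ones
  // by 1/keep_prob so the output matches the input in expectation.
  for (size_t i = 0; i < x.size(); ++i) {
    x[i] = mask[i] ? x[i] / keep_prob : 0.0;
  }
  for (double v : x) std::cout << v << " ";
  std::cout << "\n";
}
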

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -27,9 +27,9 @@
namespace mindspore {
namespace opt {
namespace {
void CreateOutputsOfLSQPerLayerGradD(const FuncGraphPtr &graph, const CNodePtr &lsq_perlayer_grad_node,
std::vector<AnfNodePtr> *const lsq_perlayer_grad_d_outputs) {
void FakeLearnedScaleQuantPerLayerGradUnifyMindIR::CreateOutputsOfLSQPerLayerGradD(
const FuncGraphPtr &graph, const CNodePtr &lsq_perlayer_grad_node,
std::vector<AnfNodePtr> *const lsq_perlayer_grad_d_outputs) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(lsq_perlayer_grad_node);
const auto &lsq_perlayer_grad_inputs = lsq_perlayer_grad_node->inputs();
@ -41,7 +41,7 @@ void CreateOutputsOfLSQPerLayerGradD(const FuncGraphPtr &graph, const CNodePtr &
NewValueNode(std::make_shared<Primitive>(kFakeLearnedScaleQuantPerLayerGradDOpName)),
lsq_perlayer_grad_inputs[kIndex1], lsq_perlayer_grad_inputs[kIndex2], lsq_perlayer_grad_inputs[kIndex3],
lsq_perlayer_grad_inputs[kIndex4]};
auto lsq_perlayer_grad_d = graph->NewCNode(lsq_perlayer_grad_d_inputs);
auto lsq_perlayer_grad_d = NewCNode(lsq_perlayer_grad_d_inputs, graph);
MS_EXCEPTION_IF_NULL(lsq_perlayer_grad_d);
lsq_perlayer_grad_d->set_scope(lsq_perlayer_grad_node->scope());
@ -56,9 +56,10 @@ void CreateOutputsOfLSQPerLayerGradD(const FuncGraphPtr &graph, const CNodePtr &
lsq_perlayer_grad_d_outputs);
}
void CreateOutputsOfLSQPerLayerReduceGrad(const FuncGraphPtr &graph, const CNodePtr &lsq_perlayer_grad_node,
const std::vector<AnfNodePtr> &lsq_perlayer_grad_d_outputs,
std::vector<AnfNodePtr> *const lsq_perlayer_reduce_grad_outputs) {
void FakeLearnedScaleQuantPerLayerGradUnifyMindIR::CreateOutputsOfLSQPerLayerReduceGrad(
const FuncGraphPtr &graph, const CNodePtr &lsq_perlayer_grad_node,
const std::vector<AnfNodePtr> &lsq_perlayer_grad_d_outputs,
std::vector<AnfNodePtr> *const lsq_perlayer_reduce_grad_outputs) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(lsq_perlayer_grad_node);
MS_EXCEPTION_IF_NULL(lsq_perlayer_reduce_grad_outputs);
@ -74,7 +75,7 @@ void CreateOutputsOfLSQPerLayerReduceGrad(const FuncGraphPtr &graph, const CNode
std::vector<AnfNodePtr> lsq_perlayer_reduce_grad_inputs = {
NewValueNode(std::make_shared<Primitive>(kFakeLearnedScaleQuantPerLayerGradDReduceOpName)),
lsq_perlayer_grad_d_outputs[kIndex1]};
auto lsq_perlayer_reduce_grad = graph->NewCNode(lsq_perlayer_reduce_grad_inputs);
auto lsq_perlayer_reduce_grad = NewCNode(lsq_perlayer_reduce_grad_inputs, graph);
MS_EXCEPTION_IF_NULL(lsq_perlayer_reduce_grad);
lsq_perlayer_reduce_grad->set_scope(lsq_perlayer_grad_node->scope());
@ -85,8 +86,9 @@ void CreateOutputsOfLSQPerLayerReduceGrad(const FuncGraphPtr &graph, const CNode
(*lsq_perlayer_reduce_grad_outputs).push_back(lsq_perlayer_reduce_grad);
}
void CreateOutputsOfLSQPerChannelGradD(const FuncGraphPtr &graph, const CNodePtr &lsq_perchannel_grad_node,
std::vector<AnfNodePtr> *const lsq_perchannel_grad_d_outputs) {
void FakeLearnedScaleQuantPerChannelGradUnifyMindIR::CreateOutputsOfLSQPerChannelGradD(
const FuncGraphPtr &graph, const CNodePtr &lsq_perchannel_grad_node,
std::vector<AnfNodePtr> *const lsq_perchannel_grad_d_outputs) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(lsq_perchannel_grad_node);
const auto &lsq_perchannel_grad_inputs = lsq_perchannel_grad_node->inputs();
@ -98,7 +100,7 @@ void CreateOutputsOfLSQPerChannelGradD(const FuncGraphPtr &graph, const CNodePtr
NewValueNode(std::make_shared<Primitive>(kFakeLearnedScaleQuantPerChannelGradDOpName)),
lsq_perchannel_grad_inputs[1], lsq_perchannel_grad_inputs[2], lsq_perchannel_grad_inputs[3],
lsq_perchannel_grad_inputs[4]};
auto lsq_perchannel_grad_d = graph->NewCNode(lsq_perchannel_grad_d_inputs);
auto lsq_perchannel_grad_d = NewCNode(lsq_perchannel_grad_d_inputs, graph);
MS_EXCEPTION_IF_NULL(lsq_perchannel_grad_d);
lsq_perchannel_grad_d->set_scope(lsq_perchannel_grad_node->scope());
@ -114,9 +116,10 @@ void CreateOutputsOfLSQPerChannelGradD(const FuncGraphPtr &graph, const CNodePtr
lsq_perchannel_grad_d_outputs);
}
void CreateOutputsOfLSQPerChannelReduceGrad(const FuncGraphPtr &graph, const CNodePtr &lsq_perchannel_grad_node,
const std::vector<AnfNodePtr> &lsq_perchannel_grad_d_outputs,
std::vector<AnfNodePtr> *const lsq_perchannel_reduce_grad_outputs) {
void FakeLearnedScaleQuantPerChannelGradUnifyMindIR::CreateOutputsOfLSQPerChannelReduceGrad(
const FuncGraphPtr &graph, const CNodePtr &lsq_perchannel_grad_node,
const std::vector<AnfNodePtr> &lsq_perchannel_grad_d_outputs,
std::vector<AnfNodePtr> *const lsq_perchannel_reduce_grad_outputs) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(lsq_perchannel_grad_node);
MS_EXCEPTION_IF_NULL(lsq_perchannel_reduce_grad_outputs);
@ -132,7 +135,7 @@ void CreateOutputsOfLSQPerChannelReduceGrad(const FuncGraphPtr &graph, const CNo
std::vector<AnfNodePtr> lsq_perchannel_reduce_grad_inputs = {
NewValueNode(std::make_shared<Primitive>(kFakeLearnedScaleQuantPerChannelGradDReduceOpName)),
lsq_perchannel_grad_d_outputs[kIndex1]};
auto lsq_perchannel_reduce_grad = graph->NewCNode(lsq_perchannel_reduce_grad_inputs);
auto lsq_perchannel_reduce_grad = NewCNode(lsq_perchannel_reduce_grad_inputs, graph);
MS_EXCEPTION_IF_NULL(lsq_perchannel_reduce_grad);
lsq_perchannel_reduce_grad->set_scope(lsq_perchannel_grad_node->scope());
@ -142,7 +145,7 @@ void CreateOutputsOfLSQPerChannelReduceGrad(const FuncGraphPtr &graph, const CNo
AnfAlgo::CopyNodeAttr(kAttrChannelAxis, lsq_perchannel_grad_node, lsq_perchannel_reduce_grad);
(*lsq_perchannel_reduce_grad_outputs).push_back(lsq_perchannel_reduce_grad);
}
} // namespace
const BaseRef FakeLearnedScaleQuantPerLayerGradUnifyMindIR::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
auto prim = std::make_shared<Primitive>(kFakeLearnedScaleQuantPerLayerGradOpName);

View File

@ -16,6 +16,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_FAKE_LEARNED_SCALE_QUANT_GRAD_UNIFY_MINDIR_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_FAKE_LEARNED_SCALE_QUANT_GRAD_UNIFY_MINDIR_H_
#include <vector>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/common/helper.h"
@ -41,6 +42,13 @@ class FakeLearnedScaleQuantPerLayerGradUnifyMindIR : public PatternProcessPass {
~FakeLearnedScaleQuantPerLayerGradUnifyMindIR() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
void CreateOutputsOfLSQPerLayerGradD(const FuncGraphPtr &graph, const CNodePtr &lsq_perlayer_grad_node,
std::vector<AnfNodePtr> *const lsq_perlayer_grad_d_outputs) const;
void CreateOutputsOfLSQPerLayerReduceGrad(const FuncGraphPtr &graph, const CNodePtr &lsq_perlayer_grad_node,
const std::vector<AnfNodePtr> &lsq_perlayer_grad_d_outputs,
std::vector<AnfNodePtr> *const lsq_perlayer_reduce_grad_outputs) const;
};
class FakeLearnedScaleQuantPerChannelGradUnifyMindIR : public PatternProcessPass {
@ -50,6 +58,13 @@ class FakeLearnedScaleQuantPerChannelGradUnifyMindIR : public PatternProcessPass
~FakeLearnedScaleQuantPerChannelGradUnifyMindIR() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
void CreateOutputsOfLSQPerChannelGradD(const FuncGraphPtr &graph, const CNodePtr &lsq_perchannel_grad_node,
std::vector<AnfNodePtr> *const lsq_perchannel_grad_d_outputs) const;
void CreateOutputsOfLSQPerChannelReduceGrad(const FuncGraphPtr &graph, const CNodePtr &lsq_perchannel_grad_node,
const std::vector<AnfNodePtr> &lsq_perchannel_grad_d_outputs,
std::vector<AnfNodePtr> *const lsq_perchannel_reduce_grad_outputs) const;
};
} // namespace opt

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -43,8 +43,9 @@ CNodePtr GetMaxPool(const CNodePtr &maxpool_grad) {
MS_EXCEPTION_IF_NULL(maxpool_anf);
return maxpool_anf->cast<CNodePtr>();
}
} // namespace
CNodePtr CreateMaxPoolWithArgmax(const FuncGraphPtr &graph, const CNodePtr &maxpool) {
CNodePtr MaxPool2MaxPoolWithArgmax::CreateMaxPoolWithArgmax(const FuncGraphPtr &graph, const CNodePtr &maxpool) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(maxpool);
if (maxpool->inputs().size() != kMaxPoolInputNum) {
@ -53,7 +54,7 @@ CNodePtr CreateMaxPoolWithArgmax(const FuncGraphPtr &graph, const CNodePtr &maxp
}
std::vector<AnfNodePtr> maxpool_argmax_inputs = {NewValueNode(std::make_shared<Primitive>(kMaxPoolWithArgmaxOpName)),
maxpool->input(kIndex1)};
auto maxpool_argmax = graph->NewCNode(maxpool_argmax_inputs);
auto maxpool_argmax = NewCNode(maxpool_argmax_inputs, graph);
MS_EXCEPTION_IF_NULL(maxpool_argmax);
maxpool_argmax->set_scope(maxpool->scope());
@ -66,8 +67,9 @@ CNodePtr CreateMaxPoolWithArgmax(const FuncGraphPtr &graph, const CNodePtr &maxp
return maxpool_argmax;
}
CNodePtr CreateMaxPoolGradWithArgmax(const FuncGraphPtr &graph, const CNodePtr &maxpool_grad,
const std::vector<AnfNodePtr> &maxpool_argmax_outputs) {
CNodePtr MaxPool2MaxPoolWithArgmax::CreateMaxPoolGradWithArgmax(
const FuncGraphPtr &graph, const CNodePtr &maxpool_grad,
const std::vector<AnfNodePtr> &maxpool_argmax_outputs) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(maxpool_grad);
if (maxpool_grad->inputs().size() != kMaxPoolGradInputNum) {
@ -79,15 +81,16 @@ CNodePtr CreateMaxPoolGradWithArgmax(const FuncGraphPtr &graph, const CNodePtr &
std::vector<AnfNodePtr> maxpool_grad_argmax_inputs = {
NewValueNode(std::make_shared<Primitive>(kMaxPoolGradWithArgmaxOpName)), maxpool_grad->input(kIndex1),
maxpool_grad->input(kIndex3), maxpool_argmax_outputs[kIndex1]};
auto maxpool_grad_argmax = graph->NewCNode(maxpool_grad_argmax_inputs);
auto maxpool_grad_argmax = NewCNode(maxpool_grad_argmax_inputs, graph);
MS_EXCEPTION_IF_NULL(maxpool_grad_argmax);
maxpool_grad_argmax->set_scope(maxpool_grad->scope());
maxpool_grad_argmax->set_abstract(maxpool_grad->abstract());
return maxpool_grad_argmax;
}
void SetNodeAttrs(const CNodePtr &maxpool, const CNodePtr &maxpool_grad, const CNodePtr &maxpool_argmax,
const CNodePtr &maxpool_grad_argmax) {
void MaxPool2MaxPoolWithArgmax::SetNodeAttrs(const CNodePtr &maxpool, const CNodePtr &maxpool_grad,
const CNodePtr &maxpool_argmax,
const CNodePtr &maxpool_grad_argmax) const {
auto strides = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(maxpool, kAttrStrides);
auto ksize = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(maxpool, kAttrKernelSize);
if (strides.size() != kMaxPoolAttrAxisNum) {
@ -114,7 +117,6 @@ void SetNodeAttrs(const CNodePtr &maxpool, const CNodePtr &maxpool_grad, const C
AnfAlgo::SetNodeAttr(kAttrKernelSize, MakeValue(ksize), maxpool_argmax);
AnfAlgo::SetNodeAttr(kAttrKernelSize, MakeValue(ksize), maxpool_grad_argmax);
}
} // namespace
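
A 1-D illustration (plain C++, one hypothetical window covering the whole vector) of why the pass pairs MaxPool with an argmax output: keeping the winning index lets MaxPoolGradWithArgmax scatter the incoming gradient straight to the max position instead of re-running the comparison against the original input.

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::vector<double> x{1.0, 5.0, 2.0, 4.0};
  // "MaxPoolWithArgmax": track the index of the maximum alongside its value.
  size_t argmax = 0;
  for (size_t i = 1; i < x.size(); ++i) {
    if (x[i] > x[argmax]) argmax = i;
  }
  double pooled = x[argmax];
  // "MaxPoolGradWithArgmax": route the output gradient using the stored index.
  double dout = 1.0;
  std::vector<double> dx(x.size(), 0.0);
  dx[argmax] = dout;
  std::cout << "pooled=" << pooled << " argmax=" << argmax
            << " dx[1]=" << dx[1] << "\n";
}
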
const BaseRef MaxPool2MaxPoolWithArgmax::DefinePattern() const {
VarPtr X = std::make_shared<Var>();

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,6 +16,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_MAXPOOL_TO_MAXPOOL_WITH_ARGMAX_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_MAXPOOL_TO_MAXPOOL_WITH_ARGMAX_H_
#include <vector>
#include <memory>
#include "backend/optimizer/common/optimizer.h"
@ -28,6 +29,13 @@ class MaxPool2MaxPoolWithArgmax : public PatternProcessPass {
~MaxPool2MaxPoolWithArgmax() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
CNodePtr CreateMaxPoolWithArgmax(const FuncGraphPtr &graph, const CNodePtr &maxpool) const;
CNodePtr CreateMaxPoolGradWithArgmax(const FuncGraphPtr &graph, const CNodePtr &maxpool_grad,
const std::vector<AnfNodePtr> &maxpool_argmax_outputs) const;
void SetNodeAttrs(const CNodePtr &maxpool, const CNodePtr &maxpool_grad, const CNodePtr &maxpool_argmax,
const CNodePtr &maxpool_grad_argmax) const;
};
} // namespace opt
} // namespace mindspore

View File

@ -45,6 +45,16 @@ constexpr int64_t kRankIdSeven = 7;
constexpr size_t kSizeFour = 4;
constexpr int64_t kInvalidId = -1;
bool IsTop(const std::vector<int64_t> &send_rank_ids) {
return send_rank_ids[kRankIdZero] != kInvalidId || send_rank_ids[kRankIdOne] != kInvalidId ||
send_rank_ids[kRankIdSeven] != kInvalidId;
}
bool IsBottom(const std::vector<int64_t> &send_rank_ids) {
return send_rank_ids[kRankIdThree] != kInvalidId || send_rank_ids[kRankIdFour] != kInvalidId ||
send_rank_ids[kRankIdFive] != kInvalidId;
}
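
These helpers encode the neighbour layout assumed throughout this file: eight rank ids, apparently ordered clockwise from the top (0=top, 2=right, 4=bottom, 6=left, odd indices the corners), with kInvalidId marking directions that do not exchange. A small standalone usage sketch (the concrete rank numbers are made up):

#include <cstdint>
#include <iostream>
#include <vector>

constexpr int64_t kInvalidId = -1;

bool IsTop(const std::vector<int64_t> &ids) {
  return ids[0] != kInvalidId || ids[1] != kInvalidId || ids[7] != kInvalidId;
}
bool IsBottom(const std::vector<int64_t> &ids) {
  return ids[3] != kInvalidId || ids[4] != kInvalidId || ids[5] != kInvalidId;
}

int main() {
  // Exchange only with the rank above (id 8) and the rank to the right (id 3).
  std::vector<int64_t> send_rank_ids{8, -1, 3, -1, -1, -1, -1, -1};
  std::cout << std::boolalpha << "top: " << IsTop(send_rank_ids)
            << ", bottom: " << IsBottom(send_rank_ids) << "\n";  // top: true, bottom: false
}
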
// calculate split attrs: size_splits, shapes and num_split
int64_t CalSplitAttrs(const std::vector<size_t> &base_shape, const bool is_first, const bool is_last,
const int64_t split_dim, const std::vector<int64_t> &send_lens, std::vector<int64_t> *size_splits,
@ -61,8 +71,8 @@ int64_t CalSplitAttrs(const std::vector<size_t> &base_shape, const bool is_first
int64_t split_middle_size = base_shape[split_dim];
std::vector<size_t> shape_tmp(base_shape);
// [top, bottom, left, right]
int64_t first_size = split_dim == kWDim ? send_lens[2] : send_lens[0];
int64_t last_size = split_dim == kWDim ? send_lens[3] : send_lens[1];
int64_t first_size = split_dim == kWDim ? send_lens[kDim2] : send_lens[0];
int64_t last_size = split_dim == kWDim ? send_lens[kDim3] : send_lens[1];
if (is_first) {
// first
@ -95,14 +105,15 @@ int64_t CalSplitAttrs(const std::vector<size_t> &base_shape, const bool is_first
CNodePtr CreateSplitNode(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &split_input,
const std::vector<size_t> &base_shape, bool is_first, bool is_last, int64_t split_dim,
const std::vector<int64_t> &send_lens, TypeId input_dtype, int64_t *num_split) {
const std::vector<int64_t> &send_lens, TypeId input_dtype, int64_t *num_split,
const PatternProcessPass &pass) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(num_split);
if (split_input.empty()) {
MS_LOG(EXCEPTION) << "The input is empty, can not create splitv node.";
return nullptr;
}
auto split_v = graph->NewCNode(split_input);
auto split_v = pass.NewCNode(split_input, graph);
MS_EXCEPTION_IF_NULL(split_v);
std::vector<int64_t> size_splits = {};
std::vector<std::vector<size_t>> shapes = {};
@ -147,127 +158,6 @@ std::vector<std::vector<size_t>> CalAllToAllvOutputShape(const std::vector<size_
return shapes;
}
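
A worked example of the split-size arithmetic that CalSplitAttrs performs for the CreateSplitNodes variants below. This is my reading of the partially shown code, so treat it as a sketch: for a split along H with send_lens ordered [top, bottom, left, right], the tensor is cut into an optional top halo, the remaining middle block, and an optional bottom halo.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  int64_t h = 32;                              // base_shape[kHDim]
  std::vector<int64_t> send_lens{2, 3, 4, 5};  // [top, bottom, left, right]
  bool is_first = true;                        // sends to some top neighbour
  bool is_last = true;                         // sends to some bottom neighbour
  std::vector<int64_t> size_splits;
  int64_t middle = h;
  if (is_first) {
    size_splits.push_back(send_lens[0]);  // 2 rows for the top halo
    middle -= send_lens[0];
  }
  if (is_last) middle -= send_lens[1];
  size_splits.push_back(middle);  // 27 rows stay in the middle
  if (is_last) size_splits.push_back(send_lens[1]);  // 3 rows for the bottom halo
  for (int64_t s : size_splits) std::cout << s << " ";  // prints: 2 27 3
  std::cout << "\n";
}
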
// returns {top_bottom, left_right, top_corner, bottom_corner}, if no split, set it nullptr
std::vector<CNodePtr> CreateSplitNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
std::vector<int64_t> *split_num) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2);
MS_EXCEPTION_IF_NULL(split_num);
std::vector<int64_t> send_rank_ids =
AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2, kAttrSendRankIds);
std::vector<int64_t> send_lens = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2, kAttrSendLens);
if (neighbor_exchange_v2->size() <= kNeighborExchangeV2InputIdx) {
MS_LOG(EXCEPTION) << "Invalid cnode " << neighbor_exchange_v2->DebugString() << " input size "
<< neighbor_exchange_v2->size();
}
std::vector<CNodePtr> split_nodes = {};
auto neighbor_exchange_v2_input = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx);
bool is_top = ((send_rank_ids[kRankIdZero] != kInvalidId) || (send_rank_ids[kRankIdOne] != kInvalidId) ||
(send_rank_ids[kRankIdSeven] != kInvalidId));
bool is_bottom = ((send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFour] != kInvalidId) ||
(send_rank_ids[kRankIdFive] != kInvalidId));
bool is_left = (send_rank_ids[kRankIdSix] != kInvalidId);
bool is_right = (send_rank_ids[kRankIdTwo] != kInvalidId);
auto dtype = AnfAlgo::GetOutputInferDataType(neighbor_exchange_v2_input, 0);
auto shape = AnfAlgo::GetOutputInferShape(neighbor_exchange_v2_input, 0);
if (SizeToLong(shape.size()) != kShapeSize) { // only support NCHW now
MS_LOG(EXCEPTION) << "Invalid shape size " << shape.size() << ", only support NCHW input now!";
}
// splitv for top & bottom
int64_t num_split_h = 0;
CNodePtr split_v_top_bottom = nullptr;
if (is_top || is_bottom) {
std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
neighbor_exchange_v2_input};
split_v_top_bottom =
CreateSplitNode(graph, split_input, shape, is_top, is_bottom, kHDim, send_lens, dtype, &num_split_h);
}
split_nodes.emplace_back(split_v_top_bottom);
split_num->push_back(num_split_h);
// splitv for left & right
int64_t num_split_w = 0;
CNodePtr split_v_left_right = nullptr;
if (is_left || is_right) {
std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
neighbor_exchange_v2_input};
split_v_left_right =
CreateSplitNode(graph, split_input, shape, is_left, is_right, kWDim, send_lens, dtype, &num_split_w);
}
split_nodes.emplace_back(split_v_left_right);
split_num->push_back(num_split_w);
// splitv for corner
if ((send_rank_ids[kRankIdOne] != kInvalidId) || (send_rank_ids[kRankIdSeven] != kInvalidId) ||
(send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFive] != kInvalidId)) {
// top_bottom_split outputs
std::vector<AnfNodePtr> split_outputs_top_bottom;
CreateMultipleOutputsOfAnfNode(graph, split_nodes[0], static_cast<size_t>((*split_num)[0]),
&split_outputs_top_bottom);
if (split_outputs_top_bottom.empty()) {
MS_LOG(EXCEPTION) << "The node " << split_nodes[0]->DebugString()
<< " should have at least one output, but got 0.";
}
// for top corner
if ((send_rank_ids[kRankIdOne] != kInvalidId) || (send_rank_ids[kRankIdSeven] != kInvalidId)) {
auto shape_tmp(shape);
shape_tmp[kHDim] = send_lens[0];
bool is_first = (send_rank_ids[kRankIdSeven] != kInvalidId);
bool is_last = (send_rank_ids[kRankIdOne] != kInvalidId);
std::vector<AnfNodePtr> split_v_corner_top_input = {
NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
split_v_corner_top_input.insert(split_v_corner_top_input.end(), split_outputs_top_bottom.begin(),
split_outputs_top_bottom.begin() + 1);
int64_t num_split_top_corner = 0;
CNodePtr split_v_corner_top = CreateSplitNode(graph, split_v_corner_top_input, shape_tmp, is_first, is_last,
kWDim, send_lens, dtype, &num_split_top_corner);
split_nodes.emplace_back(split_v_corner_top);
split_num->push_back(num_split_top_corner);
} else {
split_nodes.emplace_back(nullptr);
split_num->push_back(0);
}
// for bottom corner
if ((send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFive] != kInvalidId)) {
auto shape_tmp(shape);
shape_tmp[kHDim] = send_lens[1];
bool is_first = (send_rank_ids[kRankIdFive] != kInvalidId);
bool is_last = (send_rank_ids[kRankIdThree] != kInvalidId);
std::vector<AnfNodePtr> split_v_corner_bottom_input = {
NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
split_v_corner_bottom_input.insert(split_v_corner_bottom_input.end(), split_outputs_top_bottom.end() - 1,
split_outputs_top_bottom.end());
int64_t num_split_bottom_corner = 0;
CNodePtr split_v_corner_bottom = CreateSplitNode(graph, split_v_corner_bottom_input, shape_tmp, is_first, is_last,
kWDim, send_lens, dtype, &num_split_bottom_corner);
split_nodes.emplace_back(split_v_corner_bottom);
split_num->push_back(num_split_bottom_corner);
} else {
split_nodes.emplace_back(nullptr);
split_num->push_back(0);
}
} else {
split_nodes.emplace_back(nullptr);
split_num->push_back(0);
split_nodes.emplace_back(nullptr);
split_num->push_back(0);
}
return split_nodes;
}
std::vector<AnfNodePtr> CreateAllToAllvInput(const std::vector<std::vector<AnfNodePtr>> &split_outputs,
const std::vector<int64_t> &send_rank_ids) {
std::vector<AnfNodePtr> all_to_all_v_input = {NewValueNode(std::make_shared<Primitive>(kAllToAllVOpName))};
@ -309,7 +199,7 @@ AnfNodePtr GetCenter(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchang
return neighbor_exchange_v2_grad->input(kNeighborExchangeV2InputIdx);
}
} else {
CreateMultipleOutputsOfAnfNode(graph, split_nodes[2], static_cast<size_t>(split_num[2]), &output);
CreateMultipleOutputsOfAnfNode(graph, split_nodes[kDim2], static_cast<size_t>(split_num[kDim2]), &output);
if (output.size() < 2) {
MS_LOG(EXCEPTION) << "Wrong split output size: " << output.size() << ", except size >= 2.";
}
@ -355,16 +245,17 @@ std::vector<AnfNodePtr> CreateAllToAllvInputForGrad(const std::vector<int64_t> &
}
}
// 2
if (split_nodes[2] != nullptr && send_rank_ids[kRankIdTwo] != kInvalidId) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[2].end() - 1, split_outputs[2].end());
if (split_nodes[kIndex2] != nullptr && send_rank_ids[kRankIdTwo] != kInvalidId) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[kIndex2].end() - 1, split_outputs[kIndex2].end());
}
// 3, 4, 5
if (split_nodes[3] != nullptr) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[3].rbegin(), split_outputs[3].rend());
if (split_nodes[kIndex3] != nullptr) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[kIndex3].rbegin(), split_outputs[kIndex3].rend());
}
// 6
if (split_nodes[2] != nullptr && send_rank_ids[kRankIdSix] != kInvalidId) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[2].begin(), split_outputs[2].begin() + 1);
if (split_nodes[kIndex2] != nullptr && send_rank_ids[kRankIdSix] != kInvalidId) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[kIndex2].begin(),
split_outputs[kIndex2].begin() + 1);
}
// 7
if (split_nodes[1] != nullptr && send_rank_ids[kRankIdSeven] != kInvalidId) {
@ -373,10 +264,11 @@ std::vector<AnfNodePtr> CreateAllToAllvInputForGrad(const std::vector<int64_t> &
return all_to_all_v_input;
}
// alltoallv for forward & grad
CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_or_grad,
const std::vector<CNodePtr> &split_nodes, const std::vector<int64_t> &split_num,
bool is_grad) {
bool is_grad, const PatternProcessPass &pass) {
MS_LOG(DEBUG) << "Start to create alltoallv node.";
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2_or_grad);
@ -434,7 +326,7 @@ CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &neighbor
std::vector<TypeId> dtypes(real_recv_rank_ids.size(), base_dtype);
// create alltoallv node
auto all_to_all_v = graph->NewCNode(all_to_all_v_input);
auto all_to_all_v = pass.NewCNode(all_to_all_v_input, graph);
MS_EXCEPTION_IF_NULL(all_to_all_v);
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, all_to_all_v.get());
@ -456,12 +348,135 @@ int64_t AllToAllRealIds(int64_t ids, const std::vector<int64_t> &recv_rank_ids)
}
return real_ids;
}
} // namespace
CNodePtr CreateConcatNode(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &concat_input,
const std::vector<std::vector<size_t>> &output_shape, const std::vector<TypeId> &output_dtype,
int64_t axis, int64_t input_nums) {
// Returns {top_bottom, left_right, top_corner, bottom_corner}; entries with no split are set to nullptr.
std::vector<CNodePtr> NeighborExchangeV2UnifyMindIR::CreateSplitNodes(const FuncGraphPtr &graph,
const CNodePtr &neighbor_exchange_v2,
std::vector<int64_t> *split_num) const {
MS_EXCEPTION_IF_NULL(graph);
auto concat = graph->NewCNode(concat_input);
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2);
MS_EXCEPTION_IF_NULL(split_num);
std::vector<int64_t> send_rank_ids =
AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2, kAttrSendRankIds);
std::vector<int64_t> send_lens = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2, kAttrSendLens);
if (neighbor_exchange_v2->size() <= kNeighborExchangeV2InputIdx) {
MS_LOG(EXCEPTION) << "Invalid cnode " << neighbor_exchange_v2->DebugString() << " input size "
<< neighbor_exchange_v2->size();
}
std::vector<CNodePtr> split_nodes = {};
auto neighbor_exchange_v2_input = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx);
bool is_top = IsTop(send_rank_ids);
bool is_bottom = IsBottom(send_rank_ids);
bool is_left = (send_rank_ids[kRankIdSix] != kInvalidId);
bool is_right = (send_rank_ids[kRankIdTwo] != kInvalidId);
auto dtype = AnfAlgo::GetOutputInferDataType(neighbor_exchange_v2_input, 0);
auto shape = AnfAlgo::GetOutputInferShape(neighbor_exchange_v2_input, 0);
if (SizeToLong(shape.size()) != kShapeSize) { // only support NCHW now
MS_LOG(EXCEPTION) << "Invalid shape size " << shape.size() << ", only support NCHW input now!";
}
// splitv for top & bottom
int64_t num_split_h = 0;
CNodePtr split_v_top_bottom = nullptr;
if (is_top || is_bottom) {
std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
neighbor_exchange_v2_input};
split_v_top_bottom =
CreateSplitNode(graph, split_input, shape, is_top, is_bottom, kHDim, send_lens, dtype, &num_split_h, *this);
}
split_nodes.emplace_back(split_v_top_bottom);
split_num->push_back(num_split_h);
// splitv for left & right
int64_t num_split_w = 0;
CNodePtr split_v_left_right = nullptr;
if (is_left || is_right) {
std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
neighbor_exchange_v2_input};
split_v_left_right =
CreateSplitNode(graph, split_input, shape, is_left, is_right, kWDim, send_lens, dtype, &num_split_w, *this);
}
split_nodes.emplace_back(split_v_left_right);
split_num->push_back(num_split_w);
// splitv for corner
if ((send_rank_ids[kRankIdOne] != kInvalidId) || (send_rank_ids[kRankIdSeven] != kInvalidId) ||
(send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFive] != kInvalidId)) {
// top_bottom_split outputs
std::vector<AnfNodePtr> split_outputs_top_bottom;
CreateMultipleOutputsOfAnfNode(graph, split_nodes[0], static_cast<size_t>((*split_num)[0]),
&split_outputs_top_bottom);
if (split_outputs_top_bottom.empty()) {
MS_LOG(EXCEPTION) << "The node " << split_nodes[0]->DebugString()
<< " should have at least one output, but got 0.";
}
// for top corner
if ((send_rank_ids[kRankIdOne] != kInvalidId) || (send_rank_ids[kRankIdSeven] != kInvalidId)) {
auto shape_tmp(shape);
shape_tmp[kHDim] = send_lens[0];
bool is_first = (send_rank_ids[kRankIdSeven] != kInvalidId);
bool is_last = (send_rank_ids[kRankIdOne] != kInvalidId);
std::vector<AnfNodePtr> split_v_corner_top_input = {
NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
split_v_corner_top_input.insert(split_v_corner_top_input.end(), split_outputs_top_bottom.begin(),
split_outputs_top_bottom.begin() + 1);
int64_t num_split_top_corner = 0;
CNodePtr split_v_corner_top = CreateSplitNode(graph, split_v_corner_top_input, shape_tmp, is_first, is_last,
kWDim, send_lens, dtype, &num_split_top_corner, *this);
split_nodes.emplace_back(split_v_corner_top);
split_num->push_back(num_split_top_corner);
} else {
split_nodes.emplace_back(nullptr);
split_num->push_back(0);
}
// for bottom corner
if ((send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFive] != kInvalidId)) {
auto shape_tmp(shape);
shape_tmp[kHDim] = send_lens[1];
bool is_first = (send_rank_ids[kRankIdFive] != kInvalidId);
bool is_last = (send_rank_ids[kRankIdThree] != kInvalidId);
std::vector<AnfNodePtr> split_v_corner_bottom_input = {
NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
split_v_corner_bottom_input.insert(split_v_corner_bottom_input.end(), split_outputs_top_bottom.end() - 1,
split_outputs_top_bottom.end());
int64_t num_split_bottom_corner = 0;
CNodePtr split_v_corner_bottom = CreateSplitNode(graph, split_v_corner_bottom_input, shape_tmp, is_first, is_last,
kWDim, send_lens, dtype, &num_split_bottom_corner, *this);
split_nodes.emplace_back(split_v_corner_bottom);
split_num->push_back(num_split_bottom_corner);
} else {
split_nodes.emplace_back(nullptr);
split_num->push_back(0);
}
} else {
split_nodes.emplace_back(nullptr);
split_num->push_back(0);
split_nodes.emplace_back(nullptr);
split_num->push_back(0);
}
return split_nodes;
}
CNodePtr NeighborExchangeV2UnifyMindIR::CreateConcatNode(const FuncGraphPtr &graph,
const std::vector<AnfNodePtr> &concat_input,
const std::vector<std::vector<size_t>> &output_shape,
const std::vector<TypeId> &output_dtype, int64_t axis,
int64_t input_nums) const {
MS_EXCEPTION_IF_NULL(graph);
auto concat = NewCNode(concat_input, graph);
MS_EXCEPTION_IF_NULL(concat);
AnfAlgo::SetOutputInferTypeAndShape(output_dtype, output_shape, concat.get());
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue<int64_t>(axis), concat);
@ -471,9 +486,11 @@ CNodePtr CreateConcatNode(const FuncGraphPtr &graph, const std::vector<AnfNodePt
return concat;
}
CNodePtr CreateLeftRightConcat(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &all_to_all_v_outputs,
const std::vector<int64_t> &recv_rank_ids, const std::vector<int64_t> &recv_lens,
bool is_left) {
CNodePtr NeighborExchangeV2UnifyMindIR::CreateLeftRightConcat(const FuncGraphPtr &graph,
const std::vector<AnfNodePtr> &all_to_all_v_outputs,
const std::vector<int64_t> &recv_rank_ids,
const std::vector<int64_t> &recv_lens,
bool is_left) const {
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> concat_input = {NewValueNode(std::make_shared<Primitive>(kConcatOpName))};
@ -486,11 +503,11 @@ CNodePtr CreateLeftRightConcat(const FuncGraphPtr &graph, const std::vector<AnfN
if (recv_rank_ids[first_ids] != kInvalidId) {
++input_num;
single_shape[2] += static_cast<size_t>(recv_lens[0]); // H in NCHW
single_shape[kDim2] += static_cast<size_t>(recv_lens[0]); // H in NCHW
}
if (recv_rank_ids[last_ids] != kInvalidId) {
++input_num;
single_shape[2] += static_cast<size_t>(recv_lens[1]); // H in NCHW
single_shape[kDim2] += static_cast<size_t>(recv_lens[1]); // H in NCHW
}
if (is_left) {
concat_input.insert(concat_input.end(), all_to_all_v_outputs.rbegin(), all_to_all_v_outputs.rbegin() + input_num);
@ -506,18 +523,17 @@ CNodePtr CreateLeftRightConcat(const FuncGraphPtr &graph, const std::vector<AnfN
return concat;
}
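
The kDim changes in these hunks replace magic indices with named constants: in the NCHW layout used here, index 2 addresses H and index 3 addresses W, and recv_lens follows the same [top, bottom, left, right] order documented for send_lens. A compact sketch (the constant definitions are stand-ins for what the real header presumably declares):

#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical stand-ins for the named index constants used in this file;
// in NCHW layout index 2 is the height axis and index 3 the width axis.
constexpr size_t kDim2 = 2;
constexpr size_t kDim3 = 3;

int main() {
  std::vector<size_t> nchw_shape{8, 16, 32, 64};  // N, C, H, W
  // recv_lens is ordered [top, bottom, left, right]: top/bottom halos
  // grow H, left/right halos grow W.
  std::vector<size_t> recv_lens{2, 2, 4, 4};
  nchw_shape[kDim2] += recv_lens[0] + recv_lens[1];          // H: 32 -> 36
  nchw_shape[kDim3] += recv_lens[kDim2] + recv_lens[kDim3];  // W: 64 -> 72
  std::cout << "padded H=" << nchw_shape[kDim2] << " W=" << nchw_shape[kDim3] << "\n";
}
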
CNodePtr CreateMiddleConcat(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
const std::vector<AnfNodePtr> &all_to_all_v_outputs,
const std::vector<int64_t> &recv_rank_ids, const std::vector<int64_t> &recv_lens,
int64_t concat_dim) {
CNodePtr NeighborExchangeV2UnifyMindIR::CreateMiddleConcat(
const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2, const std::vector<AnfNodePtr> &all_to_all_v_outputs,
const std::vector<int64_t> &recv_rank_ids, const std::vector<int64_t> &recv_lens, int64_t concat_dim) const {
std::vector<AnfNodePtr> concat_input_all = {NewValueNode(std::make_shared<Primitive>(kConcatOpName))};
int64_t input_num_all = 0;
auto neighbor_exchange_v2_input = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx);
auto single_shape = AnfAlgo::GetOutputInferShape(neighbor_exchange_v2_input, 0);
size_t first_idx = concat_dim == kWDim ? 6 : 0;
size_t last_idx = concat_dim == kWDim ? 2 : 4;
size_t first_len = concat_dim == kWDim ? static_cast<size_t>(recv_lens[2]) : static_cast<size_t>(recv_lens[0]);
size_t last_len = concat_dim == kWDim ? static_cast<size_t>(recv_lens[3]) : static_cast<size_t>(recv_lens[1]);
size_t first_len = concat_dim == kWDim ? static_cast<size_t>(recv_lens[kDim2]) : static_cast<size_t>(recv_lens[0]);
size_t last_len = concat_dim == kWDim ? static_cast<size_t>(recv_lens[kDim3]) : static_cast<size_t>(recv_lens[1]);
// left
if (recv_rank_ids[first_idx] != kInvalidId) {
@ -553,8 +569,9 @@ CNodePtr CreateMiddleConcat(const FuncGraphPtr &graph, const CNodePtr &neighbor_
return concat_all;
}
CNodePtr AllToAllvRecvEmpty(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
const CNodePtr &all_to_all_v) {
CNodePtr NeighborExchangeV2UnifyMindIR::AllToAllvRecvEmpty(const FuncGraphPtr &graph,
const CNodePtr &neighbor_exchange_v2,
const CNodePtr &all_to_all_v) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2);
MS_EXCEPTION_IF_NULL(all_to_all_v);
@ -568,8 +585,9 @@ CNodePtr AllToAllvRecvEmpty(const FuncGraphPtr &graph, const CNodePtr &neighbor_
return depend;
}
CNodePtr CreateConcatNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
const CNodePtr &all_to_all_v) {
CNodePtr NeighborExchangeV2UnifyMindIR::CreateConcatNodes(const FuncGraphPtr &graph,
const CNodePtr &neighbor_exchange_v2,
const CNodePtr &all_to_all_v) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2);
MS_EXCEPTION_IF_NULL(all_to_all_v);
@ -611,10 +629,10 @@ CNodePtr CreateConcatNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_e
std::vector<AnfNodePtr> concat_input_all = {NewValueNode(std::make_shared<Primitive>(kConcatOpName))};
auto neighbor_exchange_v2_input = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx);
std::vector<size_t> shape_all = AnfAlgo::GetOutputInferShape(neighbor_exchange_v2_input, 0);
shape_all[2] =
recv_rank_ids[kRankIdZero] != kInvalidId ? shape_all[2] + static_cast<size_t>(recv_lens[0]) : shape_all[2];
shape_all[2] =
recv_rank_ids[kRankIdFour] != kInvalidId ? shape_all[2] + static_cast<size_t>(recv_lens[1]) : shape_all[2];
shape_all[kDim2] =
recv_rank_ids[kRankIdZero] != kInvalidId ? shape_all[kDim2] + static_cast<size_t>(recv_lens[0]) : shape_all[kDim2];
shape_all[kDim2] =
recv_rank_ids[kRankIdFour] != kInvalidId ? shape_all[kDim2] + static_cast<size_t>(recv_lens[1]) : shape_all[kDim2];
int64_t input_nums_all = 0;
// left concat
if (is_left) {
@ -628,7 +646,7 @@ CNodePtr CreateConcatNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_e
}
concat_input_all.insert(concat_input_all.end(), concat_left_outputs.begin(), concat_left_outputs.end());
++input_nums_all;
shape_all[3] += recv_lens[2];
shape_all[kDim3] += recv_lens[kDim2];
}
// middle concat connect to concat_all
@ -651,7 +669,7 @@ CNodePtr CreateConcatNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_e
}
concat_input_all.insert(concat_input_all.end(), concat_right_outputs.begin(), concat_right_outputs.end());
++input_nums_all;
shape_all[3] += recv_lens[3];
shape_all[kDim3] += recv_lens[kDim3];
}
std::vector<TypeId> concat_right_output_dtype = {AnfAlgo::GetOutputInferDataType(concat_input_all[1], 0)};
@ -662,8 +680,8 @@ CNodePtr CreateConcatNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_e
// grad
// Returns {top_bottom, left_right, top_corner, bottom_corner}; entries with no split are set to nullptr.
std::vector<CNodePtr> CreateSplitNodesForGrad(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_grad,
std::vector<int64_t> *split_num) {
std::vector<CNodePtr> NeighborExchangeV2GradUnifyMindIR::CreateSplitNodesForGrad(
const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_grad, std::vector<int64_t> *split_num) const {
MS_LOG(DEBUG) << "Start create splitv nodes.";
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2_grad);
@ -686,17 +704,15 @@ std::vector<CNodePtr> CreateSplitNodesForGrad(const FuncGraphPtr &graph, const C
std::vector<CNodePtr> split_nodes = {};
// splitv for top & bottom
bool is_top = ((send_rank_ids[kRankIdZero] != kInvalidId) || (send_rank_ids[kRankIdOne] != kInvalidId) ||
(send_rank_ids[kRankIdSeven] != kInvalidId));
bool is_bottom = ((send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFour] != kInvalidId) ||
(send_rank_ids[kRankIdFive] != kInvalidId));
bool is_top = IsTop(send_rank_ids);
bool is_bottom = IsBottom(send_rank_ids);
CNodePtr split_v_top_bottom = nullptr;
int64_t num_split_h = 0;
if (is_top || is_bottom) {
std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
neighbor_exchange_v2_grad_input};
split_v_top_bottom =
CreateSplitNode(graph, split_input, shape, is_top, is_bottom, kHDim, send_lens, dtype, &num_split_h);
CreateSplitNode(graph, split_input, shape, is_top, is_bottom, kHDim, send_lens, dtype, &num_split_h, *this);
}
split_nodes.emplace_back(split_v_top_bottom);
split_num->push_back(num_split_h);
@ -735,8 +751,8 @@ std::vector<CNodePtr> CreateSplitNodesForGrad(const FuncGraphPtr &graph, const C
int64_t num_split_w = 0;
std::vector<size_t> base_shape(shape);
base_shape[kHDim] = size_split_h[i];
auto split_v_left_right =
CreateSplitNode(graph, split_input, base_shape, is_left, is_right, kWDim, send_lens, dtype, &num_split_w);
auto split_v_left_right = CreateSplitNode(graph, split_input, base_shape, is_left, is_right, kWDim, send_lens,
dtype, &num_split_w, *this);
split_nodes.emplace_back(split_v_left_right);
split_num->push_back(num_split_w);
}
@ -756,12 +772,14 @@ std::vector<CNodePtr> CreateSplitNodesForGrad(const FuncGraphPtr &graph, const C
return split_nodes;
}
CNodePtr CreatePadNode(const FuncGraphPtr &graph, const AnfNodePtr &input, const std::vector<int64_t> &begin,
const std::vector<int64_t> &size, const std::vector<size_t> &shape, TypeId dtype) {
CNodePtr NeighborExchangeV2GradUnifyMindIR::CreatePadNode(const FuncGraphPtr &graph, const AnfNodePtr &input,
const std::vector<int64_t> &begin,
const std::vector<int64_t> &size,
const std::vector<size_t> &shape, TypeId dtype) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(input);
std::vector<AnfNodePtr> pad_inputs = {NewValueNode(std::make_shared<Primitive>(kPadOpName)), input};
auto pad = graph->NewCNode(pad_inputs);
auto pad = NewCNode(pad_inputs, graph);
std::vector<std::vector<int64_t>> paddings;
for (size_t i = 0; i < shape.size(); ++i) {
paddings.emplace_back(std::vector<int64_t>{begin[i], static_cast<int64_t>(shape[i]) - begin[i] - size[i]});
@ -772,9 +790,11 @@ CNodePtr CreatePadNode(const FuncGraphPtr &graph, const AnfNodePtr &input, const
return pad;
}
CNodePtr CreateSplitGradNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_grad,
const CNodePtr &all_to_all_v, const std::vector<CNodePtr> &split_nodes,
const std::vector<int64_t> &split_num) {
CNodePtr NeighborExchangeV2GradUnifyMindIR::CreateSplitGradNodes(const FuncGraphPtr &graph,
const CNodePtr &neighbor_exchange_v2_grad,
const CNodePtr &all_to_all_v,
const std::vector<CNodePtr> &split_nodes,
const std::vector<int64_t> &split_num) const {
MS_LOG(DEBUG) << "Start create splitvs grad nodes.";
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2_grad);
@ -810,27 +830,27 @@ CNodePtr CreateSplitGradNodes(const FuncGraphPtr &graph, const CNodePtr &neighbo
// create pad nodes
// slice begin & size
std::vector<std::vector<int64_t>> begins = {{0, 0, 0, 0},
{0, 0, 0, static_cast<int64_t>(centerx_shape[3]) - recv_lens[3]},
{0, 0, 0, static_cast<int64_t>(centerx_shape[3]) - recv_lens[3]},
{0, 0, static_cast<int64_t>(centerx_shape[2]) - recv_lens[1],
static_cast<int64_t>(centerx_shape[3]) - recv_lens[3]},
{0, 0, static_cast<int64_t>(centerx_shape[2]) - recv_lens[1], 0},
{0, 0, static_cast<int64_t>(centerx_shape[2]) - recv_lens[1], 0},
{0, 0, 0, static_cast<int64_t>(centerx_shape[kDim3]) - recv_lens[kDim3]},
{0, 0, 0, static_cast<int64_t>(centerx_shape[kDim3]) - recv_lens[kDim3]},
{0, 0, static_cast<int64_t>(centerx_shape[kDim2]) - recv_lens[kDim1],
static_cast<int64_t>(centerx_shape[kDim3]) - recv_lens[kDim3]},
{0, 0, static_cast<int64_t>(centerx_shape[kDim2]) - recv_lens[kDim1], 0},
{0, 0, static_cast<int64_t>(centerx_shape[kDim2]) - recv_lens[kDim1], 0},
{0, 0, 0, 0},
{0, 0, 0, 0}};
std::vector<std::vector<int64_t>> sizes = {
{static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[0],
static_cast<int64_t>(centerx_shape[3])},
{static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[0], recv_lens[3]},
static_cast<int64_t>(centerx_shape[kDim3])},
{static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[0], recv_lens[kDim3]},
{static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]),
static_cast<int64_t>(centerx_shape[2]), recv_lens[3]},
{static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[1], recv_lens[3]},
static_cast<int64_t>(centerx_shape[kDim2]), recv_lens[kDim3]},
{static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[1], recv_lens[kDim3]},
{static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[1],
static_cast<int64_t>(centerx_shape[3])},
{static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[1], recv_lens[2]},
static_cast<int64_t>(centerx_shape[kDim3])},
{static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[1], recv_lens[kDim2]},
{static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]),
static_cast<int64_t>(centerx_shape[2]), recv_lens[2]},
{static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[0], recv_lens[2]}};
static_cast<int64_t>(centerx_shape[kDim2]), recv_lens[kDim2]},
{static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[0], recv_lens[kDim2]}};
std::vector<CNodePtr> pad_nodes;
size_t output_index = 0;
for (size_t i = 0; i < recv_rank_ids.size(); ++i) {
@ -854,7 +874,7 @@ CNodePtr CreateSplitGradNodes(const FuncGraphPtr &graph, const CNodePtr &neighbo
addn_inputs.insert(addn_inputs.end(), pad_outputs.begin(), pad_outputs.end());
++pad_num;
}
auto addn = graph->NewCNode(addn_inputs);
auto addn = NewCNode(addn_inputs, graph);
MS_EXCEPTION_IF_NULL(addn);
AnfAlgo::SetOutputInferTypeAndShape({centerx_dtype}, {centerx_shape}, addn.get());
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue<std::vector<int64_t>>({pad_num}), addn);
@ -862,7 +882,6 @@ CNodePtr CreateSplitGradNodes(const FuncGraphPtr &graph, const CNodePtr &neighbo
MS_LOG(DEBUG) << "Create splitvs grad nodes success.";
return addn;
}
} // namespace
const BaseRef NeighborExchangeV2UnifyMindIR::DefinePattern() const {
return VectorRef({prim::kPrimNeighborExchangeV2, std::make_shared<SeqVar>()});
@ -876,7 +895,7 @@ const AnfNodePtr NeighborExchangeV2UnifyMindIR::Process(const FuncGraphPtr &grap
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2);
std::vector<int64_t> split_num;
auto split_nodes = CreateSplitNodes(graph, neighbor_exchange_v2, &split_num);
auto all_to_all_v = CreateAllToAllvNode(graph, neighbor_exchange_v2, split_nodes, split_num, false);
auto all_to_all_v = CreateAllToAllvNode(graph, neighbor_exchange_v2, split_nodes, split_num, false, *this);
auto concat = CreateConcatNodes(graph, neighbor_exchange_v2, all_to_all_v);
return concat;
}
@ -892,7 +911,7 @@ const AnfNodePtr NeighborExchangeV2GradUnifyMindIR::Process(const FuncGraphPtr &
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2_grad);
std::vector<int64_t> split_num;
auto split_nodes = CreateSplitNodesForGrad(graph, neighbor_exchange_v2_grad, &split_num);
auto all_to_all_v = CreateAllToAllvNode(graph, neighbor_exchange_v2_grad, split_nodes, split_num, true);
auto all_to_all_v = CreateAllToAllvNode(graph, neighbor_exchange_v2_grad, split_nodes, split_num, true, *this);
auto add = CreateSplitGradNodes(graph, neighbor_exchange_v2_grad, all_to_all_v, split_nodes, split_num);
return add;
}
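
Taken together, the two Process methods rewrite NeighborExchangeV2 into SplitV -> AllToAllv -> Concat, and its grad into SplitV -> AllToAllv -> Pad -> AddN. A standalone 1-D analogue (plain C++, two hypothetical ranks, halo width 1) of what the forward pipeline computes:

#include <iostream>
#include <vector>

int main() {
  // Two "ranks", each owning 4 rows; send/recv lengths of 1 row each way.
  std::vector<int> rank0{0, 1, 2, 3}, rank1{4, 5, 6, 7};
  int send_len = 1, recv_len = 1;
  // SplitV: rank0 splits off its bottom halo, rank1 its top halo.
  std::vector<int> r0_bottom(rank0.end() - send_len, rank0.end());
  std::vector<int> r1_top(rank1.begin(), rank1.begin() + send_len);
  // AllToAllv: the halos swap owners.
  std::vector<int> r0_recv = r1_top, r1_recv = r0_bottom;
  // Concat: each rank attaches what it received on the matching side.
  std::vector<int> r0_out(rank0);
  r0_out.insert(r0_out.end(), r0_recv.begin(), r0_recv.begin() + recv_len);
  std::vector<int> r1_out(r1_recv);
  r1_out.insert(r1_out.end(), rank1.begin(), rank1.end());
  for (int v : r0_out) std::cout << v << " ";  // 0 1 2 3 4
  std::cout << "| ";
  for (int v : r1_out) std::cout << v << " ";  // 3 4 5 6 7
  std::cout << "\n";
}
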

View File

@ -30,6 +30,24 @@ class NeighborExchangeV2UnifyMindIR : public PatternProcessPass {
~NeighborExchangeV2UnifyMindIR() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
std::vector<CNodePtr> CreateSplitNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
std::vector<int64_t> *split_num) const;
CNodePtr CreateConcatNode(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &concat_input,
const std::vector<std::vector<size_t>> &output_shape,
const std::vector<TypeId> &output_dtype, int64_t axis, int64_t input_nums) const;
CNodePtr CreateLeftRightConcat(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &all_to_all_v_outputs,
const std::vector<int64_t> &recv_rank_ids, const std::vector<int64_t> &recv_lens,
bool is_left) const;
CNodePtr CreateMiddleConcat(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
const std::vector<AnfNodePtr> &all_to_all_v_outputs,
const std::vector<int64_t> &recv_rank_ids, const std::vector<int64_t> &recv_lens,
int64_t concat_dim) const;
CNodePtr AllToAllvRecvEmpty(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
const CNodePtr &all_to_all_v) const;
CNodePtr CreateConcatNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
const CNodePtr &all_to_all_v) const;
};
class NeighborExchangeV2GradUnifyMindIR : public PatternProcessPass {
@ -39,6 +57,15 @@ class NeighborExchangeV2GradUnifyMindIR : public PatternProcessPass {
~NeighborExchangeV2GradUnifyMindIR() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
std::vector<CNodePtr> CreateSplitNodesForGrad(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_grad,
std::vector<int64_t> *split_num) const;
CNodePtr CreatePadNode(const FuncGraphPtr &graph, const AnfNodePtr &input, const std::vector<int64_t> &begin,
const std::vector<int64_t> &size, const std::vector<size_t> &shape, TypeId dtype) const;
CNodePtr CreateSplitGradNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_grad,
const CNodePtr &all_to_all_v, const std::vector<CNodePtr> &split_nodes,
const std::vector<int64_t> &split_num) const;
};
} // namespace opt

View File

@ -30,7 +30,8 @@ constexpr size_t kMomentumOutputNum = 2;
constexpr size_t kRMSPropOutputNum = 3;
constexpr size_t kCenteredRMSPropOutputNum = 4;
CNodePtr ProcessOutput(const FuncGraphPtr &graph, const AnfNodePtr &node, const size_t output_size) {
CNodePtr ProcessOutput(const FuncGraphPtr &graph, const AnfNodePtr &node, const size_t output_size,
const PatternProcessPass &pass) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
@ -53,7 +54,7 @@ CNodePtr ProcessOutput(const FuncGraphPtr &graph, const AnfNodePtr &node, const
cnode_ptr->set_abstract(abstract_tuple);
auto index = NewValueNode(static_cast<int64_t>(0));
auto get_item = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode_ptr, index});
auto get_item = pass.NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode_ptr, index}, graph);
MS_EXCEPTION_IF_NULL(get_item);
get_item->set_abstract(abstract->Clone());
@ -76,7 +77,7 @@ const BaseRef FtrlUnifyOutput::DefinePattern() const {
}
const AnfNodePtr FtrlUnifyOutput::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
return ProcessOutput(graph, node, kFtrlOutputNum);
return ProcessOutput(graph, node, kFtrlOutputNum, *this);
}
const BaseRef MomentumUnifyOutput::DefinePattern() const {
@ -92,7 +93,7 @@ const BaseRef MomentumUnifyOutput::DefinePattern() const {
const AnfNodePtr MomentumUnifyOutput::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const {
return ProcessOutput(graph, node, kMomentumOutputNum);
return ProcessOutput(graph, node, kMomentumOutputNum, *this);
}
const BaseRef RMSPropUnifyOutput::DefinePattern() const {
@ -103,7 +104,7 @@ const BaseRef RMSPropUnifyOutput::DefinePattern() const {
const AnfNodePtr RMSPropUnifyOutput::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const {
return ProcessOutput(graph, node, kRMSPropOutputNum);
return ProcessOutput(graph, node, kRMSPropOutputNum, *this);
}
const BaseRef CenteredRMSPropUnifyOutput::DefinePattern() const {
@ -123,7 +124,7 @@ const BaseRef CenteredRMSPropUnifyOutput::DefinePattern() const {
const AnfNodePtr CenteredRMSPropUnifyOutput::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const {
return ProcessOutput(graph, node, kCenteredRMSPropOutputNum);
return ProcessOutput(graph, node, kCenteredRMSPropOutputNum, *this);
}
} // namespace opt
} // namespace mindspore
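
The rewrite above is the whole of this file: optimizer ops whose Ascend kernels yield several outputs (two for Momentum, three for RMSProp, and so on) are given a tuple abstract, and a TupleGetItem on index 0 is substituted for the original single-output node. A miniature stand-in model (hypothetical Node type, not the MindSpore API) of that ProcessOutput shape:

#include <cstddef>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Node {
  std::string op;
  std::vector<std::shared_ptr<Node>> inputs;
  size_t output_count = 1;
};
using NodePtr = std::shared_ptr<Node>;

// Wrap a node that really has `output_size` outputs so that its users keep
// seeing a single value: declare the tuple, then take element 0.
NodePtr ProcessOutput(const NodePtr &node, size_t output_size) {
  node->output_count = output_size;          // abstract becomes a tuple
  auto get_item = std::make_shared<Node>();  // TupleGetItem(node, 0)
  get_item->op = "TupleGetItem";
  get_item->inputs = {node};
  return get_item;  // users of the old node are redirected here
}

int main() {
  auto momentum = std::make_shared<Node>();
  momentum->op = "ApplyMomentum";
  auto unified = ProcessOutput(momentum, /*kMomentumOutputNum=*/2);
  std::cout << unified->op << " over " << unified->inputs[0]->op << " ("
            << unified->inputs[0]->output_count << " outputs)\n";
}
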

View File

@ -71,7 +71,7 @@ const AnfNodePtr SliceGradUnifyMindIR::Process(const FuncGraphPtr &graph, const
}
std::vector<AnfNodePtr> pad_inputs = {NewValueNode(std::make_shared<Primitive>(kPadOpName)),
slice_grad->input(kIndex1)};
auto pad = graph->NewCNode(pad_inputs);
auto pad = NewCNode(pad_inputs, graph);
MS_EXCEPTION_IF_NULL(pad);
pad->set_scope(slice_grad->scope());
pad->set_abstract(slice_grad->abstract());

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -51,7 +51,7 @@ ValueNodePtr CreateValueNode(const ValuePtr &value_ptr, TypeId output_type) {
return new_node;
}
CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node,
CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, const PatternProcessPass &pass,
bool is_convert_const_to_attr = false) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(sparse_softmax_node);
@ -96,7 +96,7 @@ CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_
one_hot_inputs = {NewValueNode(one_hot_primitive), sparse_softmax_node->input(kIndex2), depth_node, value_on_node,
value_off_node};
}
auto one_hot_node = graph->NewCNode(one_hot_inputs);
auto one_hot_node = pass.NewCNode(one_hot_inputs, graph);
MS_EXCEPTION_IF_NULL(one_hot_node);
one_hot_node->set_scope(sparse_softmax_node->scope());
std::vector<size_t> labels_shape = AnfAlgo ::GetPrevNodeOutputInferShape(sparse_softmax_node, 1);
@ -109,14 +109,14 @@ CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_
}
CNodePtr CreateSoftmaxCrossEntropyWithLogits(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node,
const CNodePtr &one_hot_node) {
const CNodePtr &one_hot_node, const PatternProcessPass &pass) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(sparse_softmax_node);
MS_EXCEPTION_IF_NULL(one_hot_node);
CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kSoftmaxCrossEntropyWithLogitsOpName)),
sparse_softmax_node->input(kIndex1), one_hot_node};
auto softmax_node = graph->NewCNode(inputs);
auto softmax_node = pass.NewCNode(inputs, graph);
MS_EXCEPTION_IF_NULL(softmax_node);
softmax_node->set_scope(sparse_softmax_node->scope());
@ -157,7 +157,8 @@ ValueNodePtr GetAxisNode(const AnfNodePtr &node) {
}
CNodePtr CreateReduceMean(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node,
const AnfNodePtr &softmax_output_node, bool is_pynative = false) {
const AnfNodePtr &softmax_output_node, const PatternProcessPass &pass,
bool is_pynative = false) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(sparse_softmax_node);
MS_EXCEPTION_IF_NULL(softmax_output_node);
@ -182,7 +183,7 @@ CNodePtr CreateReduceMean(const FuncGraphPtr &graph, const CNodePtr &sparse_soft
kernel_graph->AddValueNodeToGraph(axis_node);
inputs = {NewValueNode(reduce_primitive), softmax_output_node, axis_node};
}
auto reduce_node = graph->NewCNode(inputs);
auto reduce_node = pass.NewCNode(inputs, graph);
MS_EXCEPTION_IF_NULL(reduce_node);
reduce_node->set_scope(sparse_softmax_node->scope());
auto reduce_abstract = softmax_output_node->abstract();
@ -194,7 +195,7 @@ CNodePtr CreateReduceMean(const FuncGraphPtr &graph, const CNodePtr &sparse_soft
return reduce_node;
}
CNodePtr CreateExpandDims(const FuncGraphPtr &graph, const CNodePtr &real_div_node) {
CNodePtr CreateExpandDims(const FuncGraphPtr &graph, const CNodePtr &real_div_node, const PatternProcessPass &pass) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(real_div_node);
CheckCNodeInputSize(real_div_node, kRealDivInputTensorNum);
@ -213,7 +214,7 @@ CNodePtr CreateExpandDims(const FuncGraphPtr &graph, const CNodePtr &real_div_no
expand_dims_primitive->set_attr(kAttrInputNames, MakeValue(input_names));
expand_dims_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));
std::vector<AnfNodePtr> expand_dims_inputs = {NewValueNode(expand_dims_primitive), real_div_node, axis_node};
auto expand_dims_node = graph->NewCNode(expand_dims_inputs);
auto expand_dims_node = pass.NewCNode(expand_dims_inputs, graph);
MS_EXCEPTION_IF_NULL(expand_dims_node);
expand_dims_node->set_scope(real_div_node->scope());
@ -224,7 +225,8 @@ CNodePtr CreateExpandDims(const FuncGraphPtr &graph, const CNodePtr &real_div_no
return expand_dims_node;
}
CNodePtr CreateExpandDimsPynative(const FuncGraphPtr &graph, const CNodePtr &real_div_node) {
CNodePtr CreateExpandDimsPynative(const FuncGraphPtr &graph, const CNodePtr &real_div_node,
const PatternProcessPass &pass) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(real_div_node);
CheckCNodeInputSize(real_div_node, kRealDivInputTensorNum);
@ -236,7 +238,7 @@ CNodePtr CreateExpandDimsPynative(const FuncGraphPtr &graph, const CNodePtr &rea
expand_dims_primitive->set_attr(kAttrInputNames, MakeValue(input_names));
expand_dims_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));
std::vector<AnfNodePtr> expand_dims_inputs = {NewValueNode(expand_dims_primitive), real_div_node};
auto expand_dims_node = graph->NewCNode(expand_dims_inputs);
auto expand_dims_node = pass.NewCNode(expand_dims_inputs, graph);
MS_EXCEPTION_IF_NULL(expand_dims_node);
expand_dims_node->set_scope(real_div_node->scope());
@ -249,7 +251,7 @@ CNodePtr CreateExpandDimsPynative(const FuncGraphPtr &graph, const CNodePtr &rea
}
CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, const CNodePtr &mul_node,
bool is_convert_const_to_attr = false) {
const PatternProcessPass &pass, bool is_convert_const_to_attr = false) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(sparse_softmax_node);
MS_EXCEPTION_IF_NULL(mul_node);
@ -282,7 +284,7 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no
tile_inputs = {NewValueNode(tile_primitive), mul_node->input(2), multiples_node};
}
auto tile_node = graph->NewCNode(tile_inputs);
auto tile_node = pass.NewCNode(tile_inputs, graph);
MS_EXCEPTION_IF_NULL(tile_node);
tile_node->set_scope(mul_node->scope());
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetPrevNodeOutputInferDataType(mul_node, 1)}, {labels_shape},
@ -297,7 +299,8 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no
return tile_node;
}
CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, const AnfNodePtr &tile_node) {
CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, const AnfNodePtr &tile_node,
const PatternProcessPass &pass) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(sparse_softmax_node);
MS_EXCEPTION_IF_NULL(tile_node);
@ -320,7 +323,7 @@ CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax
real_div_primitive->set_attr(kAttrInputNames, MakeValue(input_names));
real_div_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));
std::vector<AnfNodePtr> real_div_inputs = {NewValueNode(real_div_primitive), tile_node, y_node};
auto real_div_node = graph->NewCNode(real_div_inputs);
auto real_div_node = pass.NewCNode(real_div_inputs, graph);
MS_EXCEPTION_IF_NULL(real_div_node);
real_div_node->set_scope(sparse_softmax_node->scope());
@ -345,7 +348,7 @@ CNodePtr GetDependNode(const CNodePtr &mul_node) {
}
CNodePtr CreateMul(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node,
const AnfNodePtr &softmax_output_node) {
const AnfNodePtr &softmax_output_node, const PatternProcessPass &pass) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(sparse_softmax_node);
MS_EXCEPTION_IF_NULL(softmax_output_node);
@ -377,7 +380,7 @@ CNodePtr CreateMul(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_nod
mul_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));
std::vector<AnfNodePtr> mul_input = {NewValueNode(mul_primitive), softmax_output_node, y_node};
auto mul_node = graph->NewCNode(mul_input);
auto mul_node = pass.NewCNode(mul_input, graph);
MS_EXCEPTION_IF_NULL(mul_node);
mul_node->set_scope(sparse_softmax_node->scope());
@ -407,12 +410,12 @@ const AnfNodePtr SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(const F
}
CNodePtr softmax_node;
auto one_hot_node = CreateOneHot(graph, sparse_softmax_node);
softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node, one_hot_node);
auto one_hot_node = CreateOneHot(graph, sparse_softmax_node, *this);
softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node, one_hot_node, *this);
std::vector<AnfNodePtr> softmax_node_outputs;
CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
auto reduce_node = CreateReduceMean(graph, sparse_softmax_node, softmax_node_outputs[0]);
auto reduce_node = CreateReduceMean(graph, sparse_softmax_node, softmax_node_outputs[0], *this);
return reduce_node;
}
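
For orientation: SoftmaxCrossEntropyWithLogits is a multi-output op, which is why every call site immediately unpacks it with CreateMultipleOutputsOfAnfNode. A short usage sketch, assuming output 0 is the loss and output 1 is the backpropagated gradient (consistent with output 0 being reduced here and output 1 feeding the grad passes below):

    std::vector<AnfNodePtr> outputs;
    CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &outputs);
    auto loss = outputs[0];      // averaged into the forward value by CreateReduceMean
    auto backprop = outputs[1];  // combined with the incoming gradient in the grad passes
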
@ -444,23 +447,23 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con
CheckCNodeInputSize(sparse_softmax_node_grad, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
CNodePtr softmax_node;
auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad);
softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node_grad, one_hot_node);
auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad, *this);
softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node_grad, one_hot_node, *this);
std::vector<AnfNodePtr> softmax_node_outputs;
CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
auto reduce_node = CreateReduceMean(graph, sparse_softmax_node_grad, softmax_node_outputs[0]);
auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node);
auto reduce_node = CreateReduceMean(graph, sparse_softmax_node_grad, softmax_node_outputs[0], *this);
auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node, *this);
CNodePtr real_div_node;
if (tile_node == nullptr) {
real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, mul_node->input(kIndex2));
real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, mul_node->input(kIndex2), *this);
} else {
real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node);
real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node, *this);
}
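// Note: CreateTile may return nullptr when the tile step can be skipped
// (presumably when the computed multiples would leave the tensor unchanged);
// in that branch the RealDiv consumes the original mul input directly.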
auto expand_dims_node = CreateExpandDims(graph, real_div_node);
auto expand_dims_node = CreateExpandDims(graph, real_div_node, *this);
std::vector<AnfNodePtr> new_mul_inputs = {NewValueNode(std::make_shared<Primitive>(kMulOpName)),
softmax_node_outputs[1], expand_dims_node};
auto new_mul_node = graph->NewCNode(new_mul_inputs);
auto new_mul_node = NewCNode(new_mul_inputs, graph);
MS_EXCEPTION_IF_NULL(new_mul_node);
new_mul_node->set_scope(mul_node->scope());
new_mul_node->set_abstract(mul_node->abstract());
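
Taken together, the grad rewrite expands the fused sparse op into an explicit subgraph. A rough dataflow sketch, reconstructed from the calls above (grad_in stands for the original mul's second input; attributes and the Depend bookkeeping are omitted; the bracketed Tile is the optional step from the nullptr check):

    // labels ─ OneHot ─┐
    // logits ──────────┴─ SoftmaxCrossEntropyWithLogits ─ (loss, backprop)
    // loss ─ ReduceMean ─ forward value
    // grad_in ─ [Tile] ─ RealDiv ─ ExpandDims ─┐
    // backprop ────────────────────────────────┴─ Mul ─ gradient output
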
@ -496,13 +499,13 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::Process(c
auto sparse_softmax_node = GetSparseNode(depend_node, kIndex2);
CNodePtr softmax_node;
auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad);
softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node_grad, one_hot_node);
auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad, *this);
softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node_grad, one_hot_node, *this);
std::vector<AnfNodePtr> softmax_node_outputs;
CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
auto reduce_node = CreateReduceMean(graph, sparse_softmax_node_grad, softmax_node_outputs[0]);
auto mul_node = CreateMul(graph, sparse_softmax_node_grad, softmax_node_outputs[1]);
auto reduce_node = CreateReduceMean(graph, sparse_softmax_node_grad, softmax_node_outputs[0], *this);
auto mul_node = CreateMul(graph, sparse_softmax_node_grad, softmax_node_outputs[1], *this);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
@ -521,8 +524,8 @@ const AnfNodePtr PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process
CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
CNodePtr softmax_node;
auto one_hot_node = CreateOneHot(graph, sparse_softmax_node, true);
softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node, one_hot_node);
auto one_hot_node = CreateOneHot(graph, sparse_softmax_node, *this, true);
softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node, one_hot_node, *this);
std::vector<AnfNodePtr> softmax_node_outputs;
CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
@ -532,7 +535,7 @@ const AnfNodePtr PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process
AnfAlgo::GetNodeAttr<bool>(sparse_softmax_node, kAttrIsGrad)) {
return softmax_node_outputs[1];
} else {
auto reduce_node = CreateReduceMean(graph, sparse_softmax_node, softmax_node_outputs[0], true);
auto reduce_node = CreateReduceMean(graph, sparse_softmax_node, softmax_node_outputs[0], *this, true);
return reduce_node;
}
}
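
In the PyNative variants the trailing boolean selects the PyNative path, matching the bool is_pynative = false default visible in CreateReduceMean's new signature above (the corresponding CreateOneHot flag is outside this hunk and assumed analogous). The kAttrIsGrad branch reflects PyNative reusing one node for both directions; a condensed, non-verbatim paraphrase of the dispatch above:

    if (AnfAlgo::HasNodeAttr(kAttrIsGrad, sparse_softmax_node) &&
        AnfAlgo::GetNodeAttr<bool>(sparse_softmax_node, kAttrIsGrad)) {
      return softmax_node_outputs[1];  // grad call: hand back the backprop output
    }
    return CreateReduceMean(graph, sparse_softmax_node, softmax_node_outputs[0], *this, true);
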
@ -561,22 +564,22 @@ const AnfNodePtr PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Pro
CheckCNodeInputSize(sparse_softmax_node_grad, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
CNodePtr softmax_node;
auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad);
softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node_grad, one_hot_node);
auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad, *this);
softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node_grad, one_hot_node, *this);
std::vector<AnfNodePtr> softmax_node_outputs;
CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node);
auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node, *this);
CNodePtr real_div_node;
if (tile_node == nullptr) {
real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, mul_node->input(kIndex2));
real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, mul_node->input(kIndex2), *this);
} else {
real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node);
real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node, *this);
}
auto expand_dims_node = CreateExpandDimsPynative(graph, real_div_node);
auto expand_dims_node = CreateExpandDimsPynative(graph, real_div_node, *this);
std::vector<AnfNodePtr> new_mul_inputs = {NewValueNode(std::make_shared<Primitive>(kMulOpName)),
softmax_node_outputs[1], expand_dims_node};
auto new_mul_node = graph->NewCNode(new_mul_inputs);
auto new_mul_node = NewCNode(new_mul_inputs, graph);
MS_EXCEPTION_IF_NULL(new_mul_node);
new_mul_node->set_scope(mul_node->scope());
new_mul_node->set_abstract(mul_node->abstract());
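
This PyNative grad pass mirrors the graph-mode one above; the visible difference is CreateExpandDimsPynative, which omits the explicit axis input. A side-by-side sketch of the two input lists from the hunks above (how the PyNative form records the axis is not shown here; presumably it travels on the primitive rather than as an input):

    // Graph mode: the axis is an explicit value-node input.
    std::vector<AnfNodePtr> inputs = {NewValueNode(expand_dims_primitive), real_div_node, axis_node};
    // PyNative mode: no axis input on the node itself.
    std::vector<AnfNodePtr> pynative_inputs = {NewValueNode(expand_dims_primitive), real_div_node};
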

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.