!1226 Add reduce_scatter communication op fusion pass

Merge pull request !1226 from YuJianfeng/reduce_scatter
This commit is contained in:
mindspore-ci-bot 2020-05-18 21:12:59 +08:00 committed by Gitee
commit c77e82b811
2 changed files with 8 additions and 0 deletions

View File

@ -287,6 +287,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
auto other_pm = std::make_shared<PassManager>("other_pm");
other_pm->AddPass(std::make_shared<AllReduceFusion>());
other_pm->AddPass(std::make_shared<AllGatherFusion>());
other_pm->AddPass(std::make_shared<ReduceScatterFusion>());
other_pm->AddPass(std::make_shared<BroadcastFusion>());
other_pm->AddPass(std::make_shared<ParameterTransOpFusion>());
other_pm->AddPass(std::make_shared<RefreshParameterFormat>());

View File

@ -68,6 +68,13 @@ class BroadcastFusion : public CommunicationOpFusion {
explicit BroadcastFusion(size_t groups = 1) : CommunicationOpFusion("broadcast_fusion", kBroadcastOpName, groups) {}
~BroadcastFusion() override = default;
};
class ReduceScatterFusion : public CommunicationOpFusion {
public:
explicit ReduceScatterFusion(size_t groups = 1)
: CommunicationOpFusion("reduce_scatter_fusion", kReduceScatterOpName, groups) {}
~ReduceScatterFusion() override = default;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_