forked from mindspore-Ecosystem/mindspore
!1226 Add reduce_scatter communication op fusion pass
Merge pull request !1226 from YuJianfeng/reduce_scatter
This commit is contained in:
commit
c77e82b811
|
@ -287,6 +287,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
|
|||
auto other_pm = std::make_shared<PassManager>("other_pm");
|
||||
other_pm->AddPass(std::make_shared<AllReduceFusion>());
|
||||
other_pm->AddPass(std::make_shared<AllGatherFusion>());
|
||||
other_pm->AddPass(std::make_shared<ReduceScatterFusion>());
|
||||
other_pm->AddPass(std::make_shared<BroadcastFusion>());
|
||||
other_pm->AddPass(std::make_shared<ParameterTransOpFusion>());
|
||||
other_pm->AddPass(std::make_shared<RefreshParameterFormat>());
|
||||
|
|
|
@ -68,6 +68,13 @@ class BroadcastFusion : public CommunicationOpFusion {
|
|||
explicit BroadcastFusion(size_t groups = 1) : CommunicationOpFusion("broadcast_fusion", kBroadcastOpName, groups) {}
|
||||
~BroadcastFusion() override = default;
|
||||
};
|
||||
|
||||
class ReduceScatterFusion : public CommunicationOpFusion {
|
||||
public:
|
||||
explicit ReduceScatterFusion(size_t groups = 1)
|
||||
: CommunicationOpFusion("reduce_scatter_fusion", kReduceScatterOpName, groups) {}
|
||||
~ReduceScatterFusion() override = default;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_
|
||||
|
|
Loading…
Reference in New Issue