From f3a9fbdd788b527da46c2fb49aab8870e1569cf4 Mon Sep 17 00:00:00 2001 From: BowenK Date: Mon, 17 Aug 2020 20:09:55 +0800 Subject: [PATCH] Eliminate AllReduce when the input is a constant --- mindspore/ccsrc/frontend/optimizer/irpass.cc | 2 + mindspore/ccsrc/frontend/optimizer/irpass.h | 1 + .../optimizer/irpass/special_op_eliminate.h | 53 +++++++++++++++++++ mindspore/ccsrc/pipeline/jit/pass.cc | 1 + 4 files changed, 57 insertions(+) diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc index 92917483117..00817abff25 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -83,6 +83,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() { reset_defer_inline_ = MakeSubstitution(std::make_shared<ResetDeferInline>(), "reset_defer_inline", IsValueNode<FuncGraph>); depend_value_elim_ = MakeSubstitution(std::make_shared<DependValueElim>(), "depend_value_elim", prim::kPrimDepend); + all_reduce_const_elim_ = + MakeSubstitution(std::make_shared<AllReduceConstElim>(), "reduce_all_const_elim", prim::kPrimAllReduce); // Env Item Eliminate env_get_item_eliminate_ = diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h index 835c4e5cf51..07fd1ce4f85 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass.h @@ -50,6 +50,7 @@ class OptimizeIRPassLib { SubstitutionPtr check_bprop_eliminate_; SubstitutionPtr reset_defer_inline_; SubstitutionPtr depend_value_elim_; + SubstitutionPtr all_reduce_const_elim_; // Env Item Eliminate SubstitutionPtr env_get_item_eliminate_; diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h index 9ad50b8b33f..a25e6a104bd 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h @@ -29,6 +29,8 @@ #include "frontend/optimizer/irpass.h" #include 
"frontend/optimizer/irpass/prim_eliminate.h" #include "frontend/optimizer/optimizer.h" +#include "utils/comm_manager.h" +#include "frontend/parallel/context.h" namespace mindspore { namespace opt { @@ -203,6 +205,57 @@ class DependValueElim : public OptimizerCaller { return nullptr; } }; + +class AllReduceConstElim : public OptimizerCaller { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + PatternNode<AnfNodePtr> x; + auto pattern = PPrimitive(prim::kPrimAllReduce, x); + // If AllReduce takes constant value as input and values across devices are all the same (ensured by parallel mode) + if (pattern.TryCapture(node) && IsVNode(x.GetNode(node)) && + (pattern.GetFuncGraph()->has_flag(parallel::AUTO_PARALLEL) || + pattern.GetFuncGraph()->has_flag(parallel::SEMI_AUTO_PARALLEL))) { + auto cur_func_graph = pattern.GetFuncGraph(); + // If reduce operation is sum, then multiply constant by number of devices, otherwise just return the constant + auto prim_cnode = pattern.GetOriginalNode(); + MS_EXCEPTION_IF_NULL(prim_cnode); + auto primitive = GetCNodePrimitive(prim_cnode); + auto reduce_op = primitive->GetAttr("op"); + auto group = primitive->GetAttr("group")->ToString(); + // For sum operation, multiply constant tensor by number of devices + if (reduce_op->ToString() == "sum") { + unsigned int num_of_devices; + // Get number of devices + if (!CommManager::GetInstance().GetRankSize(group, &num_of_devices)) { + MS_LOG(EXCEPTION) << "Failed to get num of devices for group [" + group + "]"; + } + // Multiply constant by number of devices then return + std::vector<AnfNodePtr> mul_inputs; + auto constant_node = x.GetNode(node); + MS_EXCEPTION_IF_NULL(constant_node); + auto constant_value_node = constant_node->cast<ValueNodePtr>(); + MS_EXCEPTION_IF_NULL(constant_value_node); + if (!constant_value_node->value()->isa<tensor::Tensor>()) { + MS_LOG(EXCEPTION) << "Expect the constant input for AllReduce to be a Tensor. 
Got " + + constant_value_node->value()->ToString(); + } + auto constant_tensor = constant_value_node->value()->cast<tensor::TensorPtr>(); + auto tensor_dtype = constant_tensor->Dtype(); + auto num_of_device_node = NewValueNode(std::make_shared<tensor::Tensor>((int64_t)num_of_devices, tensor_dtype)); + // Multiply nodes + auto mul_prim = prim::GetPythonOps("tensor_mul", "mindspore.ops.functional"); + MS_EXCEPTION_IF_NULL(mul_prim); + mul_inputs.push_back(NewValueNode(mul_prim)); + mul_inputs.push_back(constant_node); + mul_inputs.push_back(num_of_device_node); + return cur_func_graph->NewCNode(mul_inputs); + } else { + return x.GetNode(node); + } + } + return nullptr; + } +}; } // namespace irpass } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index 3a5fa3128b4..f862938a419 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -132,6 +132,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { irpass.incorporate_env_getitem_switch_, irpass.new_env_get_item_, irpass.depend_value_elim_, + irpass.all_reduce_const_elim_, }); opt::OptPassConfig a_3 = opt::OptPassConfig({ irpass.arithmetic_simplify2_,