forked from mindspore-Ecosystem/mindspore
Revert "Eliminate AllReduce when the input is a constant"
This reverts commit f3a9fbdd78.
This commit is contained in:
parent
b7d9d5e0cf
commit
7a7e499475
|
@ -81,8 +81,6 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
||||||
reset_defer_inline_ =
|
reset_defer_inline_ =
|
||||||
MakeSubstitution(std::make_shared<ResetDeferInline>(), "reset_defer_inline", IsValueNode<FuncGraph>);
|
MakeSubstitution(std::make_shared<ResetDeferInline>(), "reset_defer_inline", IsValueNode<FuncGraph>);
|
||||||
depend_value_elim_ = MakeSubstitution(std::make_shared<DependValueElim>(), "depend_value_elim", prim::kPrimDepend);
|
depend_value_elim_ = MakeSubstitution(std::make_shared<DependValueElim>(), "depend_value_elim", prim::kPrimDepend);
|
||||||
all_reduce_const_elim_ =
|
|
||||||
MakeSubstitution(std::make_shared<AllReduceConstElim>(), "reduce_all_const_elim", prim::kPrimAllReduce);
|
|
||||||
|
|
||||||
// Env Item Eliminate
|
// Env Item Eliminate
|
||||||
env_get_item_eliminate_ =
|
env_get_item_eliminate_ =
|
||||||
|
|
|
@ -50,7 +50,6 @@ class OptimizeIRPassLib {
|
||||||
SubstitutionPtr check_bprop_eliminate_;
|
SubstitutionPtr check_bprop_eliminate_;
|
||||||
SubstitutionPtr reset_defer_inline_;
|
SubstitutionPtr reset_defer_inline_;
|
||||||
SubstitutionPtr depend_value_elim_;
|
SubstitutionPtr depend_value_elim_;
|
||||||
SubstitutionPtr all_reduce_const_elim_;
|
|
||||||
|
|
||||||
// Env Item Eliminate
|
// Env Item Eliminate
|
||||||
SubstitutionPtr env_get_item_eliminate_;
|
SubstitutionPtr env_get_item_eliminate_;
|
||||||
|
|
|
@ -29,8 +29,6 @@
|
||||||
#include "frontend/optimizer/irpass.h"
|
#include "frontend/optimizer/irpass.h"
|
||||||
#include "frontend/optimizer/irpass/prim_eliminate.h"
|
#include "frontend/optimizer/irpass/prim_eliminate.h"
|
||||||
#include "frontend/optimizer/optimizer.h"
|
#include "frontend/optimizer/optimizer.h"
|
||||||
#include "utils/comm_manager.h"
|
|
||||||
#include "frontend/parallel/context.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
|
@ -205,57 +203,6 @@ class DependValueElim : public OptimizerCaller {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class AllReduceConstElim : public OptimizerCaller {
|
|
||||||
public:
|
|
||||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
||||||
PatternNode<AnfNodePtr> x;
|
|
||||||
auto pattern = PPrimitive(prim::kPrimAllReduce, x);
|
|
||||||
// If AllReduce takes contant value as input and values across devices are all the same(ensured by parallel mode)
|
|
||||||
if (pattern.TryCapture(node) && IsVNode(x.GetNode(node)) &&
|
|
||||||
(pattern.GetFuncGraph()->has_flag(parallel::AUTO_PARALLEL) ||
|
|
||||||
pattern.GetFuncGraph()->has_flag(parallel::SEMI_AUTO_PARALLEL))) {
|
|
||||||
auto cur_func_graph = pattern.GetFuncGraph();
|
|
||||||
// If reduce operation is sum, then multiply constant by number of devices, otherwise just return the contant
|
|
||||||
auto prim_cnode = pattern.GetOriginalNode();
|
|
||||||
MS_EXCEPTION_IF_NULL(prim_cnode);
|
|
||||||
auto primitive = GetCNodePrimitive(prim_cnode);
|
|
||||||
auto reduce_op = primitive->GetAttr("op");
|
|
||||||
auto group = primitive->GetAttr("group")->ToString();
|
|
||||||
// For sum operation, multiply constant tensor by number of devices
|
|
||||||
if (reduce_op->ToString() == "sum") {
|
|
||||||
unsigned int num_of_devices;
|
|
||||||
// Get number of devices
|
|
||||||
if (!CommManager::GetInstance().GetRankSize(group, &num_of_devices)) {
|
|
||||||
MS_LOG(EXCEPTION) << "Failed to get num of devices for group [" + group + "]";
|
|
||||||
}
|
|
||||||
// Multiply constant by number of devices then return
|
|
||||||
std::vector<AnfNodePtr> mul_inputs;
|
|
||||||
auto constant_node = x.GetNode(node);
|
|
||||||
MS_EXCEPTION_IF_NULL(constant_node);
|
|
||||||
auto constant_value_node = constant_node->cast<ValueNodePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(constant_value_node);
|
|
||||||
if (!constant_value_node->value()->isa<tensor::Tensor>()) {
|
|
||||||
MS_LOG(EXCEPTION) << "Expect the constant input for AllReduce to be a Tensor. Got " +
|
|
||||||
constant_value_node->value()->ToString();
|
|
||||||
}
|
|
||||||
auto constant_tensor = constant_value_node->value()->cast<tensor::TensorPtr>();
|
|
||||||
auto tensor_dtype = constant_tensor->Dtype();
|
|
||||||
auto num_of_device_node = NewValueNode(std::make_shared<tensor::Tensor>((int64_t)num_of_devices, tensor_dtype));
|
|
||||||
// Multiply nodes
|
|
||||||
auto mul_prim = prim::GetPythonOps("tensor_mul", "mindspore.ops.functional");
|
|
||||||
MS_EXCEPTION_IF_NULL(mul_prim);
|
|
||||||
mul_inputs.push_back(NewValueNode(mul_prim));
|
|
||||||
mul_inputs.push_back(constant_node);
|
|
||||||
mul_inputs.push_back(num_of_device_node);
|
|
||||||
return cur_func_graph->NewCNode(mul_inputs);
|
|
||||||
} else {
|
|
||||||
return x.GetNode(node);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace irpass
|
} // namespace irpass
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -128,7 +128,6 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||||
irpass.incorporate_env_getitem_switch_,
|
irpass.incorporate_env_getitem_switch_,
|
||||||
irpass.new_env_get_item_,
|
irpass.new_env_get_item_,
|
||||||
irpass.depend_value_elim_,
|
irpass.depend_value_elim_,
|
||||||
irpass.all_reduce_const_elim_,
|
|
||||||
});
|
});
|
||||||
opt::OptPassConfig a_after_grad = opt::OptPassConfig({
|
opt::OptPassConfig a_after_grad = opt::OptPassConfig({
|
||||||
irpass.inline_without_move_,
|
irpass.inline_without_move_,
|
||||||
|
|
Loading…
Reference in New Issue