forked from mindspore-Ecosystem/mindspore
!27024 add allreduce fusion by size
Merge pull request !27024 from jiahongQian/master
commit
2d23b698a6
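This change adds size-based allreduce fusion, configured through the new `comm_fusion` option of `set_auto_parallel_context`. A minimal usage sketch based on the docstring and tests added in this commit (the 32 MB threshold is an illustrative value):

from mindspore import context
from mindspore.context import ParallelMode

# Fuse gradient AllReduce operators by size; 32 MB is an illustrative threshold.
comm_fusion_config = {"allreduce": {"mode": "size", "config": 32}}
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
                                  comm_fusion=comm_fusion_config)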
|
@ -24,6 +24,7 @@
|
|||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <utility>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "utils/hash_map.h"
|
||||
#include "ir/anf.h"
|
||||
|
@ -38,7 +39,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ad {
|
||||
using Registry = mindspore::HashMap<PrimitivePtr, FuncGraphPtr, PrimitiveHasher, PrimitiveTotalEqual>;
|
||||
using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr, PrimitiveHasher, PrimitiveTotalEqual>;
|
||||
class KPrim;
|
||||
extern KPrim g_k_prims;
|
||||
class DFunctor;
|
||||
|
|
|
@ -43,7 +43,7 @@ using PrimBpropOptGraphInfoPtr = std::shared_ptr<PrimBpropOptGraphInfo>;
|
|||
|
||||
using PrimBpropOptGraphLevel2InfoPtr = std::shared_ptr<PrimBpropOptGraphLevel2Info>;
|
||||
|
||||
using PrimBpropCache = mindspore::HashMap<PrimitivePtr, PrimBpropOptGraphInfoPtr, PrimitiveHasher, PrimitiveTotalEqual>;
|
||||
using PrimBpropCache = std::unordered_map<PrimitivePtr, PrimBpropOptGraphInfoPtr, PrimitiveHasher, PrimitiveTotalEqual>;
|
||||
|
||||
using TupleListKey = std::pair<PrimitivePtr, abstract::AbstractBasePtrList>;
|
||||
|
||||
|
|
|
@ -18,6 +18,8 @@
|
|||
#include <memory>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
#include "utils/hash_set.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "frontend/parallel/costmodel_context.h"
|
||||
|
@ -28,375 +30,53 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
mindspore::HashSet<CNodePtr> FindCNodesWithPara(const AnfNodePtr &para, uint64_t recursive_times = 0) {
|
||||
if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
|
||||
MS_LOG(EXCEPTION) << "FindCNodesWithPara exceeds max recursive call times! Max recursive call times is "
|
||||
<< MAX_RECURSIVE_CALL_TIMES;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(para);
|
||||
MS_EXCEPTION_IF_NULL(para->func_graph());
|
||||
FuncGraphManagerPtr manager = para->func_graph()->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto node_set = manager->node_users()[para];
|
||||
mindspore::HashSet<CNodePtr> cnode_set;
|
||||
for (auto &node_pair : node_set) {
|
||||
auto cnode = node_pair.first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (!IsValueNode<Primitive>(cnode->input(0))) {
|
||||
continue;
|
||||
}
|
||||
auto node_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
MS_EXCEPTION_IF_NULL(node_prim);
|
||||
if (node_prim->name() == DEPEND && node_pair.second != 1) {
|
||||
continue;
|
||||
}
|
||||
if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
|
||||
(void)cnode_set.emplace(cnode);
|
||||
} else {
|
||||
auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1);
|
||||
for (auto &cnode_sub : cnode_set_sub) {
|
||||
(void)cnode_set.emplace(cnode_sub);
|
||||
}
|
||||
}
|
||||
}
|
||||
return cnode_set;
|
||||
}
|
||||
|
||||
Status AllreduceFusion::AddNodeToGraph() {
|
||||
const auto &parameters = root_graph_->parameters();
|
||||
for (auto &parameter : parameters) {
|
||||
if (!ParameterRequireGrad(parameter)) {
|
||||
continue;
|
||||
}
|
||||
auto cnode_set = FindCNodesWithPara(parameter);
|
||||
if (cnode_set.empty()) {
|
||||
continue;
|
||||
}
|
||||
for (auto &cnode : cnode_set) {
|
||||
MS_LOG(DEBUG) << "AddNode " << cnode->DebugString();
|
||||
if (allreduce_graph_.AddNode(cnode, parameter) != SUCCESS) {
|
||||
MS_LOG(ERROR) << "AddNode failed! cnode: " << cnode->DebugString();
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint64_t recursive_times) const {
|
||||
if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
|
||||
MS_LOG(EXCEPTION) << "FindCNode exceeds max recursive call times! Max recursive call times is "
|
||||
<< MAX_RECURSIVE_CALL_TIMES;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(from);
|
||||
mindspore::HashMap<CNodePtr, double> cnode_dist;
|
||||
if (!from->isa<CNode>()) {
|
||||
return cnode_dist;
|
||||
}
|
||||
auto cnode = from->cast<CNodePtr>();
|
||||
if (!IsValueNode<Primitive>(cnode->input(0))) {
|
||||
return cnode_dist;
|
||||
}
|
||||
|
||||
auto operator_info = cnode->user_data<OperatorInfo>();
|
||||
MS_LOG(DEBUG) << "cnode " << cnode->ToString() << " IsParallelCareNode: " << IsParallelCareNode(cnode)
|
||||
<< " operator_info: " << (operator_info != nullptr);
|
||||
|
||||
if (IsParallelCareNode(cnode) && (operator_info != nullptr)) {
|
||||
auto cost = operator_info->GetForwardMemoryCostFromCNode();
|
||||
MS_LOG(DEBUG) << "cnode " << cnode->DebugString() << " cost: " << cost;
|
||||
|
||||
if (allreduce_graph_.NodeInGraph(cnode)) {
|
||||
cnode_dist[cnode] = cost;
|
||||
return cnode_dist;
|
||||
} else {
|
||||
auto cnode_dist_next = FindNextCNodes(cnode, recursive_times + 1);
|
||||
for (auto &ele_next : cnode_dist_next) {
|
||||
cnode_dist[ele_next.first] = cost + ele_next.second;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto cnode_dist_next = FindNextCNodes(cnode);
|
||||
for (auto &ele : cnode_dist_next) {
|
||||
cnode_dist[ele.first] = ele.second;
|
||||
}
|
||||
}
|
||||
return cnode_dist;
|
||||
}
|
||||
|
||||
CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr &from, uint64_t recursive_times) const {
|
||||
if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
|
||||
MS_LOG(EXCEPTION) << "FindNextCNodes exceeds max recursive call times! Max recursive call times is "
|
||||
<< MAX_RECURSIVE_CALL_TIMES;
|
||||
}
|
||||
const auto &from_inputs = from->inputs();
|
||||
mindspore::HashMap<CNodePtr, double> dist_map;
|
||||
MS_LOG(DEBUG) << "from cnode " << from->DebugString() << " has " << from_inputs.size() << " inputs";
|
||||
for (auto &input_node : from_inputs) {
|
||||
auto cnode_dist = FindCNode(input_node, recursive_times + 1);
|
||||
for (auto &ele : cnode_dist) {
|
||||
(void)dist_map.emplace(ele);
|
||||
}
|
||||
}
|
||||
return dist_map;
|
||||
}
|
||||
|
||||
Status AllreduceFusion::AddEdgeToGraph() {
|
||||
mindspore::HashMap<CNodePtr, int64_t> cnode_state_map;
|
||||
const auto &cnodes = allreduce_graph_.cnode_set();
|
||||
for (auto &cnode : cnodes) {
|
||||
cnode_state_map[cnode] = 0;
|
||||
}
|
||||
const auto &head_cnode = allreduce_graph_.head_cnode();
|
||||
std::queue<CNodePtr> cnode_queue;
|
||||
cnode_queue.emplace(head_cnode);
|
||||
cnode_state_map[head_cnode] = 1;
|
||||
|
||||
while (!cnode_queue.empty()) {
|
||||
const auto cur_cnode = cnode_queue.front();
|
||||
cnode_queue.pop();
|
||||
cnode_state_map[cur_cnode] = 2;
|
||||
auto next = FindNextCNodes(cur_cnode);
|
||||
for (auto &ele : next) {
|
||||
auto &cnode = ele.first;
|
||||
auto &dist = ele.second;
|
||||
if (cnode_state_map[cnode] == 0) {
|
||||
cnode_queue.emplace(cnode);
|
||||
cnode_state_map[cnode] = 1;
|
||||
}
|
||||
if (allreduce_graph_.AddEdge(cur_cnode, cnode, dist) != SUCCESS) {
|
||||
MS_LOG(ERROR) << "AddEdge error";
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG(DEBUG) << "from " << cur_cnode->DebugString() << ", to " << cnode->DebugString() << " dist " << dist;
|
||||
}
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
std::vector<CNodePtr> FindMirror(const AnfNodePtr &para, uint64_t recursive_times = 0) {
|
||||
if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
|
||||
MS_LOG(EXCEPTION) << "FindMirror exceeds max recursive call times! Max recursive call times is "
|
||||
<< MAX_RECURSIVE_CALL_TIMES;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(para);
|
||||
MS_EXCEPTION_IF_NULL(para->func_graph());
|
||||
FuncGraphManagerPtr manager = para->func_graph()->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
AnfNodeIndexSet node_set = manager->node_users()[para];
|
||||
std::vector<CNodePtr> cnode_list;
|
||||
for (auto &node_pair : node_set) {
|
||||
auto cnode = node_pair.first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (!IsValueNode<Primitive>(cnode->input(0))) {
|
||||
continue;
|
||||
}
|
||||
auto node_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
MS_EXCEPTION_IF_NULL(node_prim);
|
||||
if (node_prim->name() == CAST) {
|
||||
auto mirror_cnodes = FindMirror(node_pair.first, recursive_times + 1);
|
||||
if (mirror_cnodes.empty()) {
|
||||
MS_LOG(WARNING) << "mirror node after cast not found";
|
||||
continue;
|
||||
}
|
||||
if (mirror_cnodes.size() > 1) {
|
||||
MS_LOG(EXCEPTION) << "mirror node after cast number is not 1";
|
||||
}
|
||||
cnode_list.emplace_back(mirror_cnodes[0]);
|
||||
}
|
||||
if (node_prim->name() == MIRROR_OPERATOR) {
|
||||
cnode_list.emplace_back(cnode);
|
||||
}
|
||||
}
|
||||
return cnode_list;
|
||||
}
|
||||
|
||||
void SetMirrorFusion(const CNodePtr &mirror_cnode, int64_t fusion, const std::string &parameter_name) {
|
||||
MS_EXCEPTION_IF_NULL(mirror_cnode);
|
||||
MS_LOG(DEBUG) << "Set Mirror " << mirror_cnode->DebugString() << " fusion " << fusion;
|
||||
auto node_prim = GetValueNode<PrimitivePtr>(mirror_cnode->input(0));
|
||||
auto old_value_ptr = node_prim->GetAttr(FUSION);
|
||||
if (old_value_ptr != nullptr) {
|
||||
if (old_value_ptr->isa<Int64Imm>()) {
|
||||
int64_t old_value = old_value_ptr->cast<Int64ImmPtr>()->value();
|
||||
if (old_value < fusion) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
(void)node_prim->AddAttr(FUSION, MakeValue(std::make_shared<Int64Imm>(fusion)));
|
||||
(void)node_prim->AddAttr(PARAMETER, MakeValue(std::make_shared<StringImm>(parameter_name)));
|
||||
}
|
||||
|
||||
Status FindMirrorAndSetFusion(const AnfNodePtr &para, int64_t fusion) {
|
||||
auto mirror_cnodes = FindMirror(para);
|
||||
if (mirror_cnodes.empty()) {
|
||||
MS_LOG(WARNING) << para->ToString() << " 0 Mirror CNode found.";
|
||||
return SUCCESS;
|
||||
}
|
||||
if (mirror_cnodes.size() > 2) {
|
||||
for (auto &mirror_cnode_1 : mirror_cnodes) {
|
||||
MS_EXCEPTION_IF_NULL(mirror_cnode_1);
|
||||
MS_LOG(INFO) << mirror_cnode_1->DebugString();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(para);
|
||||
MS_LOG(ERROR) << para->ToString() << " FindMirror is more than 2. " << mirror_cnodes.size()
|
||||
<< "Mirror CNode found.";
|
||||
return FAILED;
|
||||
}
|
||||
for (auto &mirror_cnode : mirror_cnodes) {
|
||||
auto parameter_name = ParameterName(para);
|
||||
SetMirrorFusion(mirror_cnode, fusion, parameter_name);
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status FindMirrorAndSetFusion(const std::vector<AnfNodePtr> &paras, int64_t fusion) {
|
||||
for (auto &param_node : paras) {
|
||||
if (FindMirrorAndSetFusion(param_node, fusion) != SUCCESS) {
|
||||
MS_LOG(ERROR) << "FindMirrorAndSetFusion failed";
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status AllreduceFusion::SetFusion(const std::vector<double> &cost_map) {
|
||||
if (cost_map.size() < 2) {
|
||||
MS_LOG(ERROR) << "cost_map must has at least 2 items, cost_map size is " << cost_map.size();
|
||||
return FAILED;
|
||||
}
|
||||
Status AllreduceFusion::SetFusionBySize(const CNodePtr &ret, int64_t threshold) {
|
||||
auto filter = [](const AnfNodePtr &node) { return !IsPrimitiveCNode(node, prim::kPrimMirror); };
|
||||
auto todo = DeepScopedGraphSearchWithFilter(ret, AlwaysInclude, filter);
|
||||
auto temp = threshold;
|
||||
int64_t fusion = 1;
|
||||
for (auto cost_iter = cost_map.end() - 1; cost_iter != cost_map.begin(); --cost_iter) {
|
||||
auto paras = allreduce_graph_.GetParaByCost(*(cost_iter - 1), *cost_iter);
|
||||
if (FindMirrorAndSetFusion(paras, fusion) != SUCCESS) {
|
||||
MS_LOG(ERROR) << "FindMirrorAndSetFusion failed";
|
||||
return FAILED;
|
||||
bool init = true;
|
||||
for (auto &node : todo) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode->input(1)->Shape() == nullptr) continue;
|
||||
auto input_shapes = GetNodeShape(cnode->input(1));
|
||||
int64_t input_size = std::accumulate(input_shapes[0].begin(), input_shapes[0].end(), 1, std::multiplies<int64_t>());
|
||||
FuncGraphPtr func_graph = cnode->func_graph();
|
||||
std::pair<AnfNodePtr, bool> param_node_pair = FindParameter(cnode->input(1), func_graph);
|
||||
if (!param_node_pair.first) {
|
||||
continue;
|
||||
}
|
||||
fusion++;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
std::vector<double> AllreduceFusion::GenerateCostMap(int64_t fusion_times, double tail_percent) const {
|
||||
double offset = allreduce_graph_.max() * (1 - tail_percent) / (fusion_times - 1);
|
||||
MS_LOG(DEBUG) << "max = " << allreduce_graph_.max() << ", offset = " << offset;
|
||||
std::vector<double> cost_map;
|
||||
double begin = 0;
|
||||
for (auto i = 0; i < fusion_times - 1; i++) {
|
||||
cost_map.push_back(begin);
|
||||
begin += offset;
|
||||
}
|
||||
cost_map.push_back(allreduce_graph_.max() * (1 - tail_percent));
|
||||
cost_map.push_back(allreduce_graph_.max());
|
||||
MS_LOG(DEBUG) << "cost_map = " << cost_map;
|
||||
return cost_map;
|
||||
}
|
||||
|
||||
Status AllreduceFusion::SetFusionByBackwardCompTime() {
|
||||
auto fusion_times = CostModelContext::GetInstance()->costmodel_allreduce_fusion_times();
|
||||
if (fusion_times < 2) {
|
||||
MS_LOG(INFO) << "'costmodel_allreduce_fusion_times' is " << fusion_times << ". Bypass ProcessAllreduceFusion";
|
||||
return SUCCESS;
|
||||
}
|
||||
auto tail_percent = CostModelContext::GetInstance()->costmodel_allreduce_fusion_tail_percent();
|
||||
if (tail_percent < 0 || tail_percent >= 1) {
|
||||
MS_LOG(INFO) << "'costmodel_allreduce_fusion_tail_percent' is " << tail_percent
|
||||
<< ". Bypass ProcessAllreduceFusion";
|
||||
return SUCCESS;
|
||||
}
|
||||
const auto cost_map = GenerateCostMap(fusion_times, tail_percent);
|
||||
MS_LOG(DEBUG) << "AllreduceGraph GenerateCostMap succeed.";
|
||||
if (SetFusion(cost_map) != SUCCESS) {
|
||||
MS_LOG(ERROR) << "SetFusion failed.";
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG(DEBUG) << "AllreduceGraph SetFusion succeed.";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status AllreduceFusion::GetSetFusionByBackwardCompAndAllreduceTimeParams() {
|
||||
tail_time_ = CostModelContext::GetInstance()->costmodel_allreduce_fusion_tail_time();
|
||||
if (tail_time_ <= 0) {
|
||||
MS_LOG(INFO) << "'costmodel_allreduce_tail_time' is " << tail_time_ << ". Bypass ProcessAllreduceFusion";
|
||||
return FAILED;
|
||||
}
|
||||
allreduce_inherent_time_ = CostModelContext::GetInstance()->costmodel_allreduce_fusion_allreduce_inherent_time();
|
||||
if (allreduce_inherent_time_ <= 0) {
|
||||
MS_LOG(INFO) << "'costmodel_allreduce_fusion_allreduce_inherent_time' is " << allreduce_inherent_time_
|
||||
<< ". Bypass ProcessAllreduceFusion";
|
||||
return FAILED;
|
||||
}
|
||||
if (tail_time_ <= allreduce_inherent_time_) {
|
||||
MS_LOG(INFO) << "'costmodel_allreduce_tail_time' is " << tail_time_
|
||||
<< "'costmodel_allreduce_fusion_allreduce_inherent_time' is " << allreduce_inherent_time_
|
||||
<< ".tail_time is not more than allreduce_inherent_time. Bypass ProcessAllreduceFusion";
|
||||
return FAILED;
|
||||
}
|
||||
allreduce_bandwidth_ = CostModelContext::GetInstance()->costmodel_allreduce_fusion_allreduce_bandwidth();
|
||||
if (allreduce_bandwidth_ <= 0) {
|
||||
MS_LOG(INFO) << "'costmodel_allreduce_fusion_allreduce_bandwidth' is " << allreduce_bandwidth_
|
||||
<< ". Bypass ProcessAllreduceFusion";
|
||||
return FAILED;
|
||||
}
|
||||
computation_time_parameter_ =
|
||||
CostModelContext::GetInstance()->costmodel_allreduce_fusion_computation_time_parameter();
|
||||
if (computation_time_parameter_ <= 0) {
|
||||
MS_LOG(INFO) << "'costmodel_allreduce_fusion_computation_time_parameter' is " << computation_time_parameter_
|
||||
<< ". Bypass ProcessAllreduceFusion";
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status AllreduceFusion::SetFusionByBackwardCompAndAllreduceTime() {
|
||||
if (GetSetFusionByBackwardCompAndAllreduceTimeParams() != SUCCESS) {
|
||||
MS_LOG(ERROR) << "GetSetFusionByBackwardCompAndAllreduceTimeParams failed!";
|
||||
return FAILED;
|
||||
}
|
||||
allreduce_graph_.SortArnode();
|
||||
if (allreduce_graph_.RemoveExtraParas() != SUCCESS) {
|
||||
MS_LOG(ERROR) << "RemoveExtraParas failed!";
|
||||
return FAILED;
|
||||
}
|
||||
double para_size = (tail_time_ - allreduce_inherent_time_) / allreduce_bandwidth_;
|
||||
double to_cost = allreduce_graph_.max();
|
||||
int64_t fusion = 1;
|
||||
while (to_cost != 0) {
|
||||
MS_LOG(INFO) << "to_cost: " << to_cost << " para_size: " << para_size;
|
||||
auto node_cost_pair = allreduce_graph_.GetParaByParaSize(to_cost, para_size);
|
||||
MS_LOG(INFO) << "para size: " << node_cost_pair.first.size() << " from_cost: " << node_cost_pair.second;
|
||||
auto paras = node_cost_pair.first;
|
||||
if (FindMirrorAndSetFusion(paras, fusion) != SUCCESS) {
|
||||
MS_LOG(ERROR) << "FindMirrorAndSetFusion failed";
|
||||
return FAILED;
|
||||
auto parameter_name = ParameterName(param_node_pair.first);
|
||||
if (input_size < temp) {
|
||||
temp -= input_size;
|
||||
} else {
|
||||
temp = threshold;
|
||||
fusion++;
|
||||
}
|
||||
fusion++;
|
||||
para_size = ((to_cost - node_cost_pair.second) * computation_time_parameter_ - allreduce_inherent_time_) /
|
||||
allreduce_bandwidth_;
|
||||
to_cost = node_cost_pair.second;
|
||||
if (init) {
|
||||
SetMirrorFusion(cnode, 1, parameter_name);
|
||||
} else {
|
||||
SetMirrorFusion(cnode, fusion, parameter_name);
|
||||
}
|
||||
init = false;
|
||||
}
|
||||
MS_LOG(DEBUG) << "AllreduceGraph SetFusionByBackwardCompAndAllreduceTime succeed.";
|
||||
MS_LOG(INFO) << "Allreduce fusion by size succeed.";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status AllreduceFusion::SetFusionByAlgorithm(int64_t algorithm) {
|
||||
if (algorithm == 1) {
|
||||
return SetFusionByBackwardCompTime();
|
||||
}
|
||||
return SetFusionByBackwardCompAndAllreduceTime();
|
||||
}
|
||||
|
||||
Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) {
|
||||
if (ret == nullptr) {
|
||||
MS_LOG(ERROR) << "ret is nullptr.";
|
||||
return FAILED;
|
||||
}
|
||||
auto algorithm = CostModelContext::GetInstance()->costmodel_allreduce_fusion_algorithm();
|
||||
if (algorithm < 1 || algorithm > 2) {
|
||||
MS_LOG(INFO) << "'costmodel_allreduce_fusion_algorithm' is " << algorithm << ". Bypass ProcessAllreduceFusion";
|
||||
return SUCCESS;
|
||||
}
|
||||
ret_ = ret;
|
||||
root_graph_ = ret_->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(root_graph_);
|
||||
|
@ -409,27 +89,20 @@ Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) {
|
|||
MS_EXCEPTION_IF_NULL(forward_graph);
|
||||
forward_ret_ = forward_graph->get_return();
|
||||
MS_EXCEPTION_IF_NULL(forward_ret_);
|
||||
|
||||
if (allreduce_graph_.set_head_cnode(forward_ret_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << "AllreduceGraph set_head_cnode failed.";
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG(DEBUG) << "AllreduceGraph set_head_cnode succeed.";
|
||||
if (AddNodeToGraph() != SUCCESS) {
|
||||
MS_LOG(ERROR) << "AddNodeToGraph failed.";
|
||||
auto threshold = ParallelContext::GetInstance()->fusion_threshold_mb() * 1024 * 1024 / 4.0;
|
||||
if (threshold > 0) {
|
||||
if (SetFusionBySize(ret, threshold) != SUCCESS) {
|
||||
MS_LOG(ERROR) << "SetFusionBySize failed.";
|
||||
return FAILED;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "The threshold of SetFusionBySize must be larger than 0, but got " << threshold << ".";
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG(DEBUG) << "AllreduceGraph AddNodeToGraph succeed.";
|
||||
if (AddEdgeToGraph() != SUCCESS) {
|
||||
MS_LOG(ERROR) << "AddNodeToGraph failed.";
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG(DEBUG) << "AllreduceGraph AddEdgeToGraph succeed.";
|
||||
if (SetFusionByAlgorithm(algorithm) != SUCCESS) {
|
||||
MS_LOG(ERROR) << "SetFusionByAlgorithm failed.";
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG(DEBUG) << "AllreduceGraph SetFusionByAlgorithm succeed.";
|
||||
return SUCCESS;
|
||||
}
|
||||
} // namespace parallel
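Note: ProcessAllreduceFusion above converts the user-facing threshold from megabytes to an element count before calling SetFusionBySize (fusion_threshold_mb() * 1024 * 1024 / 4.0). A minimal sketch of that conversion, assuming float32 gradients (4 bytes per element):

# Sketch: MB threshold -> number of float32 elements compared against each gradient's size.
fusion_threshold_mb = 64                      # default threshold introduced by this change
threshold_elements = fusion_threshold_mb * 1024 * 1024 / 4.0
print(threshold_elements)                     # 16777216.0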
|
||||
|
|
|
@ -27,8 +27,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
using CNodeCostMap = mindspore::HashMap<CNodePtr, double>;
|
||||
|
||||
constexpr int64_t DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM = 0;
|
||||
constexpr int64_t DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES = 0;
|
||||
constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT = 0.1;
|
||||
|
@ -46,31 +44,17 @@ class AllreduceFusion {
|
|||
forward_ret_(nullptr),
|
||||
root_graph_(nullptr),
|
||||
tail_time_(0),
|
||||
allreduce_inherent_time_(0),
|
||||
allreduce_bandwidth_(0),
|
||||
computation_time_parameter_(0) {}
|
||||
virtual ~AllreduceFusion() = default;
|
||||
Status ProcessAllreduceFusion(const CNodePtr &ret);
|
||||
|
||||
private:
|
||||
Status AddNodeToGraph();
|
||||
CNodeCostMap FindCNode(const AnfNodePtr &from, uint64_t recursive_times = 0) const;
|
||||
CNodeCostMap FindNextCNodes(const CNodePtr &from, uint64_t recursive_times = 0) const;
|
||||
Status AddEdgeToGraph();
|
||||
std::vector<double> GenerateCostMap(int64_t fusion_times, double tail_percent) const;
|
||||
Status SetFusion(const std::vector<double> &cost_map);
|
||||
Status SetFusionByAlgorithm(int64_t algorithm);
|
||||
Status SetFusionByBackwardCompTime();
|
||||
Status SetFusionByBackwardCompAndAllreduceTime();
|
||||
Status GetSetFusionByBackwardCompAndAllreduceTimeParams();
|
||||
|
||||
Status SetFusionBySize(const CNodePtr &ret, int64_t threshold);
|
||||
AllreduceGraph allreduce_graph_;
|
||||
CNodePtr ret_;
|
||||
CNodePtr forward_ret_;
|
||||
FuncGraphPtr root_graph_;
|
||||
double tail_time_;
|
||||
double allreduce_inherent_time_;
|
||||
double allreduce_bandwidth_;
|
||||
double computation_time_parameter_;
|
||||
};
|
||||
} // namespace parallel
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "frontend/parallel/context.h"
|
||||
#include "frontend/parallel/graph_util/graph_info.h"
|
||||
#include "frontend/parallel/status.h"
|
||||
#include "frontend/parallel/step_parallel.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -32,13 +33,15 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti
|
|||
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
|
||||
std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
|
||||
bool enable_all_reduce_fusion = ParallelContext::GetInstance()->enable_all_reduce_fusion();
|
||||
auto graph_set = ForwardGraph(root);
|
||||
// assume no change to graph
|
||||
bool changes = false;
|
||||
// control whether use model_parallel mode
|
||||
if (!root->has_flag(AUTO_PARALLEL) || ((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) ||
|
||||
(!enable_all_reduce_fusion) || (root->has_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY))) {
|
||||
(!enable_all_reduce_fusion) || (root->has_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY)) || graph_set.size() < 1) {
|
||||
return changes;
|
||||
}
|
||||
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
auto start_time = std::chrono::steady_clock::now();
|
||||
#else
|
||||
|
@ -47,7 +50,7 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti
|
|||
}, end_time{0};
|
||||
(void)gettimeofday(&start_time, nullptr);
|
||||
#endif
|
||||
MS_LOG(INFO) << "Now entering allreduce fusion";
|
||||
MS_LOG(INFO) << "Now entering allreduce fusion by size, and allreduce fusion before will be overlapped!";
|
||||
DumpGraph(root, std::string(ALLREDUCE_FUSION_BEGIN));
|
||||
|
||||
pipeline::ResourceBasePtr res = optimizer->resource();
|
||||
|
|
|
@ -36,6 +36,8 @@ std::vector<std::string> STRATEGY_SEARCH_MODE_LIST = {DYNAMIC_PROGRAMMING, RECUR
|
|||
std::vector<std::string> COMMUNI_PARALLEL_MODE_LIST = {ALL_GROUP_PARALLEL, SAME_SERVER_GROUP_PARALLEL,
|
||||
NO_GROUP_PARALLEL};
|
||||
|
||||
std::vector<std::string> FUSION_MODE_LIST = {FUSION_AUTO, FUSION_SIZE, FUSION_INDEX};
|
||||
|
||||
std::shared_ptr<ParallelContext> ParallelContext::inst_context_ = nullptr;
|
||||
|
||||
std::shared_ptr<ParallelContext> ParallelContext::GetInstance() {
|
||||
|
@ -60,7 +62,7 @@ void ParallelContext::Reset() {
|
|||
parallel_mode_ = STAND_ALONE;
|
||||
parameter_broadcast_ = false;
|
||||
parameter_broadcast_is_set_ = false;
|
||||
enable_all_reduce_fusion_ = false;
|
||||
enable_all_reduce_fusion_ = true;
|
||||
strategy_ckpt_load_file_ = "";
|
||||
strategy_ckpt_save_file_ = "";
|
||||
enable_parallel_optimizer_ = false;
|
||||
|
@ -76,6 +78,9 @@ void ParallelContext::Reset() {
|
|||
grad_accumulation_shard_ = true;
|
||||
sharding_propagation_ = false;
|
||||
dataset_strategy_.clear();
|
||||
fusion_threshold_mb_ = FUSUION_THRESHOLD;
|
||||
fusion_threshold_is_set_ = true;
|
||||
fusion_mode_ = FUSION_AUTO;
|
||||
}
|
||||
|
||||
void ParallelContext::set_device_num(int64_t device_num) {
|
||||
|
@ -83,6 +88,22 @@ void ParallelContext::set_device_num(int64_t device_num) {
|
|||
device_num_is_set_ = true;
|
||||
}
|
||||
|
||||
void ParallelContext::set_fusion_threshold_mb(int64_t fusion_threshold) {
|
||||
fusion_threshold_mb_ = fusion_threshold;
|
||||
fusion_threshold_is_set_ = true;
|
||||
enable_all_reduce_fusion_ = true;
|
||||
}
|
||||
|
||||
bool ParallelContext::set_fusion_mode(const std::string &fusion_mode) {
|
||||
auto iter = std::find(FUSION_MODE_LIST.begin(), FUSION_MODE_LIST.end(), fusion_mode);
|
||||
if (iter == FUSION_MODE_LIST.end()) {
|
||||
MS_LOG(INFO) << "Invalid fusion mode:" << fusion_mode;
|
||||
return false;
|
||||
}
|
||||
fusion_mode_ = fusion_mode;
|
||||
return true;
|
||||
}
|
||||
|
||||
void ParallelContext::set_global_rank(int64_t global_rank) {
|
||||
global_rank_ = global_rank;
|
||||
global_rank_is_set_ = true;
|
||||
|
|
|
@ -52,6 +52,11 @@ constexpr char SAME_SERVER_GROUP_PARALLEL[] = "same_server_group_parallel";
|
|||
constexpr char NO_GROUP_PARALLEL[] = "no_group_parallel";
|
||||
|
||||
constexpr char IS_FIRST_ITERATION[] = "is_first_iteration";
|
||||
|
||||
constexpr char FUSION_AUTO[] = "auto";
|
||||
constexpr char FUSION_SIZE[] = "size";
|
||||
constexpr char FUSION_INDEX[] = "index";
|
||||
constexpr int64_t FUSUION_THRESHOLD = 64;
|
||||
class ParallelContext {
|
||||
public:
|
||||
~ParallelContext() = default;
|
||||
|
@ -78,6 +83,11 @@ class ParallelContext {
|
|||
void set_device_num(int64_t device_num);
|
||||
int64_t device_num() const { return device_num_; }
|
||||
|
||||
void set_fusion_threshold_mb(int64_t fusion_threshold);
|
||||
int64_t fusion_threshold_mb() const { return fusion_threshold_mb_; }
|
||||
bool set_fusion_mode(const std::string &fusion_mode);
|
||||
std::string get_fusion_mode() const { return fusion_mode_; }
|
||||
|
||||
void set_pipeline_stage_split_num(const int64_t stages);
|
||||
int64_t pipeline_stage_split_num() const { return pipeline_stage_split_num_; }
|
||||
|
||||
|
@ -159,6 +169,7 @@ class ParallelContext {
|
|||
bool gradient_fp32_sync_;
|
||||
bool loss_repeated_mean_;
|
||||
int64_t device_num_;
|
||||
int64_t fusion_threshold_mb_;
|
||||
int64_t global_rank_;
|
||||
int64_t grad_accumulation_step_;
|
||||
std::string parallel_mode_;
|
||||
|
@ -166,6 +177,7 @@ class ParallelContext {
|
|||
int64_t pipeline_stage_split_num_;
|
||||
bool parameter_broadcast_;
|
||||
bool device_num_is_set_;
|
||||
bool fusion_threshold_is_set_;
|
||||
bool global_rank_is_set_;
|
||||
bool parameter_broadcast_is_set_;
|
||||
bool enable_all_reduce_fusion_;
|
||||
|
@ -186,6 +198,7 @@ class ParallelContext {
|
|||
bool dataset_repeat_dim_right_ = false;
|
||||
bool hccl_test_available_ = false;
|
||||
bool sharding_propagation_;
|
||||
std::string fusion_mode_;
|
||||
};
|
||||
|
||||
} // namespace parallel
|
||||
|
|
|
@ -148,6 +148,10 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
.def("set_hccl_test_avaible", &ParallelContext::set_hccl_test_available, "Set hccl test available.")
|
||||
.def("set_device_num", &ParallelContext::set_device_num, "Set device num.")
|
||||
.def("get_device_num_is_set", &ParallelContext::device_num_is_set, "Get device num is set.")
|
||||
.def("set_fusion_threshold_mb", &ParallelContext::set_fusion_threshold_mb, "Set fusion threshold.")
|
||||
.def("fusion_threshold_mb", &ParallelContext::fusion_threshold_mb, "Get fusion threshold.")
|
||||
.def("set_fusion_mode", &ParallelContext::set_fusion_mode, "Get fusion mode.")
|
||||
.def("get_fusion_mode", &ParallelContext::get_fusion_mode, "Get fusion mode.")
|
||||
.def("get_global_rank", &ParallelContext::global_rank, "Get global rank.")
|
||||
.def("set_global_rank", &ParallelContext::set_global_rank, "Set global rank.")
|
||||
.def("get_grad_accumulation_shard", &ParallelContext::grad_accumulation_shard, "Get grad_accumulation_shard.")
|
||||
|
|
|
@ -373,7 +373,7 @@ def _context():
|
|||
auto_parallel_search_mode=str, search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
|
||||
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
|
||||
all_reduce_fusion_config=list, pipeline_stages=int, grad_accumulation_step=int,
|
||||
parallel_optimizer_config=dict)
|
||||
parallel_optimizer_config=dict, comm_fusion=dict)
|
||||
def set_auto_parallel_context(**kwargs):
|
||||
r"""
|
||||
Set auto parallel context, which is valid only for Ascend and GPU target.
|
||||
|
@ -402,6 +402,7 @@ def set_auto_parallel_context(**kwargs):
|
|||
parallel_optimizer_config pipeline_stages
|
||||
\ grad_accumulation_step
|
||||
\ auto_parallel_search_mode
|
||||
\ comm_fusion
|
||||
=========================== ===========================
|
||||
|
||||
Args:
|
||||
|
@ -476,6 +477,16 @@ def set_auto_parallel_context(**kwargs):
|
|||
with larger batch size. This configure is effective only
|
||||
when the model runs on pipeline training or gradient
|
||||
accumulation with data parallel.
|
||||
comm_fusion (dict): A dict that contains the types and configurations for setting the communication fusion. Each
|
||||
communication fusion config has two keys: "mode" and "config".
|
||||
It supports the following communication fusion types and configurations:
|
||||
|
||||
- allreduce: If the communication fusion type is `allreduce`, the `mode` can be `auto`, `size`
|
||||
and `index`. In `auto` mode, allreduce fusion is configured by gradient size, and the default
|
||||
fusion threshold is `64` MB. In `size` mode, allreduce fusion is configured by gradient size
|
||||
manually, and the fusion threshold must be larger than `0` MB. In `index` mode, it is the same as
|
||||
`all_reduce_fusion_config`.
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If input key is not attribute in auto parallel context.
|
||||
|
@ -498,6 +509,8 @@ def set_auto_parallel_context(**kwargs):
|
|||
>>> context.set_auto_parallel_context(pipeline_stages=2)
|
||||
>>> parallel_config = {"gradient_accumulation_shard": True}
|
||||
>>> context.set_auto_parallel_context(parallel_optimizer_config=parallel_config, enable_parallel_optimizer=True)
|
||||
>>> comm_fusion_config = {"allreduce": {"mode": "size", "config": 32}}
|
||||
>>> context.set_auto_parallel_context(comm_fusion=comm_fusion_config)
|
||||
"""
|
||||
_set_auto_parallel_context(**kwargs)
|
||||
|
||||
|
@ -540,6 +553,7 @@ def reset_auto_parallel_context():
|
|||
- full_batch: False.
|
||||
- enable_parallel_optimizer: False.
|
||||
- pipeline_stages: 1.
|
||||
- fusion_threshold: 64.
|
||||
"""
|
||||
_reset_auto_parallel_context()
|
||||
|
||||
|
|
|
@ -26,6 +26,16 @@ _MAX_GROUP_NAME_LEN = 127
|
|||
_DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1"
|
||||
_DEFAULT_NCCL_FUSION_GROUP_NAME = "nccl_world_groupsum1"
|
||||
|
||||
class _ParallelFusionConfig:
|
||||
"""
|
||||
Keys of the parallel fusion method configuration.
|
||||
"""
|
||||
ALLREDUCE = "allreduce"
|
||||
MODE = "mode"
|
||||
FUSION_CONFIG = "config"
|
||||
AUTO = "auto"
|
||||
INDEX = "index"
|
||||
SIZE = "size"
|
||||
|
||||
class _ParallelOptimizerConfig:
|
||||
"""
|
||||
|
@ -89,6 +99,86 @@ class _AutoParallelContext:
|
|||
self.check_context_handle()
|
||||
return self._context_handle.get_device_num()
|
||||
|
||||
def set_comm_fusion(self, config):
|
||||
"""
|
||||
Set fusion method for auto parallel.
|
||||
|
||||
Args:
|
||||
config (dict): A dict that contains the methods and values for setting the communication fusion. Currently it
|
||||
supports: `allreduce`.
|
||||
|
||||
Raises:
|
||||
KeyError: When key of comm_fusion is not 'allreduce'.
|
||||
"""
|
||||
self.check_context_handle()
|
||||
for key in list(config.keys()):
|
||||
if key == _ParallelFusionConfig.ALLREDUCE:
|
||||
self._set_allreduce_comm_fusion(config[key])
|
||||
else:
|
||||
raise KeyError("comm fusion type must be allreduce, but got {}".format(key))
|
||||
|
||||
def _set_allreduce_comm_fusion(self, comm_fusion):
|
||||
"""
|
||||
Set fusion method for auto parallel.
|
||||
|
||||
Args:
|
||||
comm_fusion (dict): A dict that contains the methods and values for setting the fusion method. Currently it
|
||||
supports three fusion methods: `auto`, `size` and `index`.
|
||||
|
||||
Raises:
|
||||
KeyError: When key of comm_fusion is not 'mode' or 'config'.
|
||||
KeyError: When `mode` is not 'auto', 'size' or 'index'.
|
||||
"""
|
||||
self.check_context_handle()
|
||||
if _ParallelFusionConfig.MODE not in comm_fusion:
|
||||
raise KeyError("For 'comm_fusion', the key 'mode' should be contained.")
|
||||
if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion:
|
||||
raise KeyError("For 'comm_fusion', the key 'config' should be contained.")
|
||||
check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.INDEX, _ParallelFusionConfig.SIZE]
|
||||
if comm_fusion[_ParallelFusionConfig.MODE] in check_mode:
|
||||
self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE])
|
||||
else:
|
||||
raise KeyError("fusion method mode must be auto, index or size, but got {}".format(
|
||||
comm_fusion[_ParallelFusionConfig.MODE]))
|
||||
if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.AUTO:
|
||||
self.set_fusion_threshold_mb(fusion_threshold=64)
|
||||
if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.SIZE:
|
||||
self.set_fusion_threshold_mb(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG])
|
||||
if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.INDEX:
|
||||
self.set_all_reduce_fusion_split_indices(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG])
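# For illustration, the three payload shapes accepted by the dispatch above
# (values mirror the new tests; they are examples, not required settings):
auto_cfg = {"mode": "auto", "config": None}            # keep the default 64 MB threshold
size_cfg = {"mode": "size", "config": 32}              # fuse by gradient size, 32 MB threshold
index_cfg = {"mode": "index", "config": [2, 4, 6, 8]}  # same behavior as all_reduce_fusion_config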
|
||||
|
||||
def get_comm_fusion(self):
|
||||
"""Get comm fusion config."""
|
||||
self.check_context_handle()
|
||||
mode = self._context_handle.get_fusion_mode()
|
||||
if mode in (_ParallelFusionConfig.AUTO, _ParallelFusionConfig.SIZE):
|
||||
config = self.fusion_threshold_mb()
|
||||
if mode == _ParallelFusionConfig.INDEX:
|
||||
config = self.get_all_reduce_fusion_split_indices()
|
||||
return {_ParallelFusionConfig.ALLREDUCE: {_ParallelFusionConfig.MODE: mode,
|
||||
_ParallelFusionConfig.FUSION_CONFIG: config}}
|
||||
|
||||
|
||||
def set_fusion_threshold_mb(self, fusion_threshold=64):
|
||||
"""
|
||||
Set fusion threshold (MB) for auto parallel.
|
||||
|
||||
Args:
|
||||
fusion_threshold (int): The fusion threshold (unit: MB). Default: 64.
|
||||
|
||||
Raises:
|
||||
ValueError: If the fusion threshold is less than 0.
|
||||
"""
|
||||
self.check_context_handle()
|
||||
if fusion_threshold < 0:
|
||||
raise ValueError("fusion threshold must be larger than 0, but got {}".format(fusion_threshold))
|
||||
self._context_handle.set_fusion_threshold_mb(fusion_threshold)
|
||||
|
||||
def fusion_threshold_mb(self):
|
||||
"""Get device num."""
|
||||
self.check_context_handle()
|
||||
return self._context_handle.fusion_threshold_mb()
|
||||
|
||||
def set_global_rank(self, global_rank):
|
||||
"""
|
||||
Set global rank for auto parallel.
|
||||
|
@ -741,7 +831,8 @@ _set_auto_parallel_context_func_map = {
|
|||
"optimizer_weight_shard_size": auto_parallel_context().set_optimizer_weight_shard_size,
|
||||
"optimizer_weight_shard_aggregated_save": auto_parallel_context().set_optimizer_weight_shard_aggregated_save,
|
||||
"sharding_propagation": auto_parallel_context().set_sharding_propagation,
|
||||
"enable_alltoall": auto_parallel_context().set_enable_alltoall}
|
||||
"enable_alltoall": auto_parallel_context().set_enable_alltoall,
|
||||
"comm_fusion": auto_parallel_context().set_comm_fusion}
|
||||
|
||||
|
||||
_get_auto_parallel_context_func_map = {
|
||||
|
@ -766,7 +857,8 @@ _get_auto_parallel_context_func_map = {
|
|||
"optimizer_weight_shard_size": auto_parallel_context().get_optimizer_weight_shard_size,
|
||||
"optimizer_weight_shard_aggregated_save": auto_parallel_context().get_optimizer_weight_shard_aggregated_save,
|
||||
"sharding_propagation": auto_parallel_context().get_sharding_propagation,
|
||||
"enable_alltoall": auto_parallel_context().get_enable_alltoall}
|
||||
"enable_alltoall": auto_parallel_context().get_enable_alltoall,
|
||||
"comm_fusion": auto_parallel_context().get_comm_fusion}
|
||||
|
||||
|
||||
@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
|
||||
|
@ -775,7 +867,7 @@ _get_auto_parallel_context_func_map = {
|
|||
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
|
||||
grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
|
||||
communi_parallel_mode=str, optimizer_weight_shard_size=int, sharding_propagation=bool,
|
||||
optimizer_weight_shard_aggregated_save=bool, enable_alltoall=bool)
|
||||
optimizer_weight_shard_aggregated_save=bool, enable_alltoall=bool, comm_fusion=dict)
|
||||
|
||||
def _set_auto_parallel_context(**kwargs):
|
||||
"""
|
||||
|
@ -848,6 +940,16 @@ def _set_auto_parallel_context(**kwargs):
|
|||
search the desired strategies. Default: False.
|
||||
enable_alltoall (bool): Set the value of enabling AllToAll. If False, AllGather and Split are used to
|
||||
circumvent AllToAll. Default: False.
|
||||
comm_fusion (dict): A dict that contains the types and configurations for setting the communication fusion. Each
|
||||
communication fusion config has two keys: "mode" and "config".
|
||||
It supports the following communication fusion types and configurations:
|
||||
|
||||
- allreduce: If the communication fusion type is `allreduce`, the `mode` can be `auto`, `size`
|
||||
and `index`. In `auto` mode, allreduce fusion is configured by gradient size, and the default
|
||||
fusion threshold is `64` MB. In `size` mode, allreduce fusion is configured by gradient size
|
||||
manually, and the fusion threshold must be larger than `0` MB. In `index` mode, it is the same as
|
||||
`all_reduce_fusion_config`.
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If input key is not attribute in auto parallel context.
|
||||
|
@ -896,5 +998,6 @@ def _reset_auto_parallel_context():
|
|||
- sharding_propagation: False
|
||||
- pipeline_stages: 0
|
||||
- gradient_accumulation_shard: True
|
||||
- fusion_threshold: 64
|
||||
"""
|
||||
auto_parallel_context().reset()
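# For reference, a minimal round-trip sketch of the setter/getter pair added above; the
# "size" mode and 32 MB value are illustrative and mirror the new test_allreduce_fusion_size test.
from mindspore.parallel._auto_parallel_context import auto_parallel_context

fusion = {"allreduce": {"mode": "size", "config": 32}}
auto_parallel_context().set_comm_fusion(fusion)
# get_comm_fusion() reports the active mode together with its threshold or index list.
assert auto_parallel_context().get_comm_fusion() == fusion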
|
||||
|
|
|
@ -13,20 +13,46 @@
|
|||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, context
|
||||
from mindspore.common.api import _cell_graph_executor
|
||||
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
|
||||
from mindspore.nn.optim import Lamb
|
||||
from mindspore.nn.optim.momentum import Momentum
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.parallel import _cost_model_context as cost_model_context
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
from mindspore.train import Model
|
||||
from mindspore.context import ParallelMode
|
||||
from tests.dataset_mock import MindData
|
||||
context.set_context(mode=context.PYNATIVE_MODE, save_graphs=True)
|
||||
|
||||
class Net(nn.Cell):
|
||||
"""Net definition"""
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.fc1 = nn.Dense(128, 768, activation='relu')
|
||||
self.fc2 = nn.Dense(128, 768, activation='relu')
|
||||
self.fc3 = nn.Dense(128, 768, activation='relu')
|
||||
self.fc4 = nn.Dense(768, 768, activation='relu')
|
||||
self.relu4 = nn.ReLU()
|
||||
self.relu5 = nn.ReLU()
|
||||
self.transpose = P.Transpose()
|
||||
self.matmul1 = P.MatMul()
|
||||
self.matmul2 = P.MatMul()
|
||||
|
||||
def construct(self, x):
|
||||
q = self.fc1(x)
|
||||
k = self.fc2(x)
|
||||
v = self.fc3(x)
|
||||
k = self.transpose(k, (1, 0))
|
||||
c = self.relu4(self.matmul1(q, k))
|
||||
s = self.relu5(self.matmul2(c, v))
|
||||
s = self.fc4(s)
|
||||
return s
|
||||
|
||||
class Dataset(MindData):
|
||||
def __init__(self, predict, label, length=3):
|
||||
|
@ -107,7 +133,6 @@ def train_common(net):
|
|||
momentum = 0.9
|
||||
epoch_size = 2
|
||||
device_num = 4
|
||||
auto_parallel_context().set_enable_all_reduce_fusion(enable_all_reduce_fusion=True)
|
||||
context.set_auto_parallel_context(device_num=device_num, parameter_broadcast=False)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -121,196 +146,78 @@ def train_common(net):
|
|||
|
||||
model.train(epoch_size, dataset, dataset_sink_mode=False)
|
||||
allreduce_fusion_dict = _cell_graph_executor._get_allreduce_fusion(model._train_network)
|
||||
|
||||
print(allreduce_fusion_dict)
|
||||
return allreduce_fusion_dict
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="depreciated feature")
|
||||
def test_allreduce_fusion_parameters():
|
||||
cost_model_context.reset_cost_model_context()
|
||||
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2)
|
||||
algorithm = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_algorithm')
|
||||
assert algorithm == 2
|
||||
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
|
||||
algorithm = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_algorithm')
|
||||
assert algorithm == 1
|
||||
cost_model_context.reset_cost_model_context()
|
||||
algorithm = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_algorithm')
|
||||
assert algorithm == 0
|
||||
|
||||
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
|
||||
fusion_times = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_times')
|
||||
assert fusion_times == 2
|
||||
|
||||
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.2)
|
||||
tail_percent = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_tail_percent')
|
||||
assert tail_percent == 0.2
|
||||
cost_model_context.reset_cost_model_context()
|
||||
tail_percent = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_tail_percent')
|
||||
assert tail_percent == 0.1
|
||||
|
||||
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_time=0.2)
|
||||
tail_time = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_tail_time')
|
||||
assert tail_time == 0.2
|
||||
cost_model_context.reset_cost_model_context()
|
||||
tail_time = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_tail_time')
|
||||
assert tail_time == 0.1
|
||||
|
||||
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_inherent_time=0.2)
|
||||
allreduce_inherent_time = cost_model_context.get_cost_model_context(
|
||||
'costmodel_allreduce_fusion_allreduce_inherent_time')
|
||||
assert allreduce_inherent_time == 0.2
|
||||
cost_model_context.reset_cost_model_context()
|
||||
allreduce_inherent_time = cost_model_context.get_cost_model_context(
|
||||
'costmodel_allreduce_fusion_allreduce_inherent_time')
|
||||
assert allreduce_inherent_time == 0.1
|
||||
|
||||
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_bandwidth=0.2)
|
||||
allreduce_bandwidth = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_allreduce_bandwidth')
|
||||
assert allreduce_bandwidth == 0.2
|
||||
cost_model_context.reset_cost_model_context()
|
||||
allreduce_bandwidth = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_allreduce_bandwidth')
|
||||
assert allreduce_bandwidth == 0.1
|
||||
|
||||
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_computation_time_parameter=0.2)
|
||||
computation_time_parameter = cost_model_context.get_cost_model_context(
|
||||
'costmodel_allreduce_fusion_computation_time_parameter')
|
||||
assert computation_time_parameter == 0.2
|
||||
cost_model_context.reset_cost_model_context()
|
||||
computation_time_parameter = cost_model_context.get_cost_model_context(
|
||||
'costmodel_allreduce_fusion_computation_time_parameter')
|
||||
assert computation_time_parameter == 0.1
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="depreciated feature")
|
||||
def test_allreduce_fusion1():
|
||||
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
|
||||
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
|
||||
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5)
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
|
||||
def test_allreduce_fusion_auto():
|
||||
"""
|
||||
Feature: test_allreduce_fusion in auto mode
|
||||
Description: allreduce fusion in auto mode
|
||||
Expectation: success
|
||||
"""
|
||||
comm_fusion_dict = {"allreduce": {"mode": "auto", "config": None}}
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, comm_fusion=comm_fusion_dict)
|
||||
net = SimpleDMLNet(DenseNet1(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
|
||||
allreduce_fusion_dict = train_common(net)
|
||||
expect_dict = {'backbone2.fc8.weight': 2,
|
||||
'backbone2.fc7.weight': 2,
|
||||
'backbone2.fc6.weight': 2,
|
||||
'backbone1.fc4.weight': 2,
|
||||
'backbone1.fc3.weight': 2,
|
||||
'backbone1.fc2.weight': 2,
|
||||
'backbone2.fc5.weight': 1,
|
||||
'backbone2.fc4.weight': 1,
|
||||
'backbone2.fc3.weight': 1,
|
||||
'backbone2.fc2.weight': 1,
|
||||
'backbone2.fc1.weight': 1,
|
||||
'backbone1.fc1.weight': 1}
|
||||
assert allreduce_fusion_dict == expect_dict
|
||||
cost_model_context.reset_cost_model_context()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="depreciated feature")
|
||||
# reset_cost_model_context is called, the default value of costmodel_allreduce_fusion_times is 0, step_allreduce_fusion
|
||||
# is bypassed.
|
||||
def test_allreduce_fusion2():
|
||||
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
|
||||
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5)
|
||||
cost_model_context.reset_cost_model_context()
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
|
||||
net = SimpleDMLNet(DenseNet1(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
|
||||
allreduce_fusion_dict = train_common(net)
|
||||
expect_dict = {}
|
||||
assert allreduce_fusion_dict == expect_dict
|
||||
cost_model_context.reset_cost_model_context()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="depreciated feature")
|
||||
def test_allreduce_fusion3():
|
||||
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
|
||||
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=3)
|
||||
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.3333333)
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
|
||||
net = SimpleDMLNet(DenseNet1(has_bias=True, activation='relu'), DenseNet2(has_bias=False, activation='relu'))
|
||||
allreduce_fusion_dict = train_common(net)
|
||||
expect_dict = {'backbone2.fc8.weight': 3,
|
||||
'backbone2.fc7.weight': 3,
|
||||
'backbone2.fc6.weight': 2,
|
||||
'backbone2.fc5.weight': 2,
|
||||
'backbone2.fc4.weight': 2,
|
||||
'backbone2.fc3.weight': 1,
|
||||
'backbone2.fc2.weight': 1,
|
||||
'backbone2.fc1.weight': 1,
|
||||
'backbone1.fc4.bias': 3,
|
||||
'backbone1.fc4.weight': 3,
|
||||
'backbone1.fc3.bias': 3,
|
||||
'backbone1.fc3.weight': 2,
|
||||
'backbone1.fc2.bias': 2,
|
||||
'backbone1.fc2.weight': 2,
|
||||
'backbone1.fc1.bias': 2,
|
||||
'backbone1.fc1.weight': 2}
|
||||
assert allreduce_fusion_dict == expect_dict
|
||||
cost_model_context.reset_cost_model_context()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="depreciated feature")
|
||||
def test_allreduce_fusion4():
|
||||
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
|
||||
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
|
||||
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5)
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
|
||||
net = SimpleDMLNet(DenseNet2(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
|
||||
allreduce_fusion_dict = train_common(net)
|
||||
expect_dict = {'backbone2.fc8.weight': 2,
|
||||
'backbone2.fc7.weight': 2,
|
||||
'backbone2.fc6.weight': 2,
|
||||
'backbone1.fc8.weight': 2,
|
||||
'backbone1.fc7.weight': 2,
|
||||
'backbone1.fc6.weight': 2,
|
||||
'backbone2.fc5.weight': 1,
|
||||
'backbone2.fc4.weight': 1,
|
||||
'backbone2.fc3.weight': 1,
|
||||
'backbone2.fc2.weight': 1,
|
||||
'backbone2.fc1.weight': 1,
|
||||
'backbone1.fc5.weight': 1,
|
||||
expect_dict = {'backbone2.fc8.weight': 1,
|
||||
'backbone2.fc7.weight': 1,
|
||||
'backbone2.fc6.weight': 1,
|
||||
'backbone1.fc4.weight': 1,
|
||||
'backbone1.fc3.weight': 1,
|
||||
'backbone1.fc2.weight': 1,
|
||||
'backbone1.fc1.weight': 1}
|
||||
|
||||
assert allreduce_fusion_dict == expect_dict
|
||||
cost_model_context.reset_cost_model_context()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="depreciated feature")
|
||||
def test_allreduce_fusion5():
|
||||
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2)
|
||||
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_time=0.1)
|
||||
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_inherent_time=0.05)
|
||||
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_bandwidth=0.000001)
|
||||
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_computation_time_parameter=0.0000015)
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
|
||||
net = SimpleDMLNet(DenseNet2(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
|
||||
allreduce_fusion_dict = train_common(net)
|
||||
|
||||
expect_dict = {'backbone2.fc8.weight': 3,
|
||||
'backbone2.fc7.weight': 3,
|
||||
'backbone2.fc6.weight': 3,
|
||||
'backbone2.fc5.weight': 3,
|
||||
'backbone2.fc4.weight': 2,
|
||||
'backbone2.fc3.weight': 2,
|
||||
'backbone2.fc5.weight': 1,
|
||||
'backbone2.fc4.weight': 1,
|
||||
'backbone2.fc3.weight': 1,
|
||||
'backbone2.fc2.weight': 1,
|
||||
'backbone2.fc1.weight': 1,
|
||||
'backbone1.fc8.weight': 3,
|
||||
'backbone1.fc7.weight': 3,
|
||||
'backbone1.fc6.weight': 3,
|
||||
'backbone1.fc5.weight': 3,
|
||||
'backbone1.fc4.weight': 2,
|
||||
'backbone1.fc3.weight': 2,
|
||||
'backbone1.fc2.weight': 1,
|
||||
'backbone1.fc1.weight': 1,}
|
||||
|
||||
'backbone1.fc1.weight': 1}
|
||||
assert allreduce_fusion_dict == expect_dict
|
||||
cost_model_context.reset_cost_model_context()
|
||||
|
||||
def test_allreduce_fusion_size():
|
||||
"""
|
||||
Feature: test_allreduce_fusion in size mode
|
||||
Description: allreduce fusion in size mode
|
||||
Expectation: success
|
||||
"""
|
||||
comm_fusion_dict = {"allreduce": {"mode": "size", "config": 32}}
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, comm_fusion=comm_fusion_dict)
|
||||
net = SimpleDMLNet(DenseNet1(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
|
||||
allreduce_fusion_dict = train_common(net)
|
||||
expect_dict = {'backbone2.fc8.weight': 1,
|
||||
'backbone2.fc7.weight': 1,
|
||||
'backbone2.fc6.weight': 1,
|
||||
'backbone1.fc4.weight': 1,
|
||||
'backbone1.fc3.weight': 1,
|
||||
'backbone1.fc2.weight': 1,
|
||||
'backbone2.fc5.weight': 1,
|
||||
'backbone2.fc4.weight': 1,
|
||||
'backbone2.fc3.weight': 1,
|
||||
'backbone2.fc2.weight': 1,
|
||||
'backbone2.fc1.weight': 1,
|
||||
'backbone1.fc1.weight': 1}
|
||||
assert allreduce_fusion_dict == expect_dict
|
||||
cost_model_context.reset_cost_model_context()
|
||||
comm_fusion = auto_parallel_context().get_comm_fusion()
|
||||
assert comm_fusion_dict == comm_fusion
|
||||
|
||||
def test_lamb_split_fusion_in_index():
|
||||
"""
|
||||
Feature: test_allreduce_fusion in index mode
|
||||
Description: allreduce fusion in index mode
|
||||
Expectation: success
|
||||
"""
|
||||
comm_fusion_dict = {"allreduce": {"mode": "index", "config": [2, 4, 6, 8]}}
|
||||
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True,
|
||||
comm_fusion=comm_fusion_dict)
|
||||
inputs = Tensor(np.ones([32, 128]).astype(np.float32))
|
||||
label = Tensor(np.zeros([32, 768]).astype(np.float32))
|
||||
net = Net()
|
||||
net.set_train()
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||
optimizer = Lamb(net.trainable_params(), learning_rate=0.1)
|
||||
|
||||
net_with_loss = WithLossCell(net, loss)
|
||||
train_network = TrainOneStepCell(net_with_loss, optimizer)
|
||||
_cell_graph_executor.compile(train_network, inputs, label)
|
||||
context.reset_auto_parallel_context()
|
||||
|
|