diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h index 0355d46f487..ee842f4c5d9 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h @@ -24,6 +24,7 @@ #include #include #include +#include #include "utils/hash_map.h" #include "ir/anf.h" @@ -38,7 +39,7 @@ namespace mindspore { namespace ad { -using Registry = mindspore::HashMap; +using Registry = std::unordered_map; class KPrim; extern KPrim g_k_prims; class DFunctor; diff --git a/mindspore/ccsrc/frontend/optimizer/ad/prim_bprop_optimizer.h b/mindspore/ccsrc/frontend/optimizer/ad/prim_bprop_optimizer.h index b1144625147..2442f16ff9a 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/prim_bprop_optimizer.h +++ b/mindspore/ccsrc/frontend/optimizer/ad/prim_bprop_optimizer.h @@ -43,7 +43,7 @@ using PrimBpropOptGraphInfoPtr = std::shared_ptr; using PrimBpropOptGraphLevel2InfoPtr = std::shared_ptr; -using PrimBpropCache = mindspore::HashMap; +using PrimBpropCache = std::unordered_map; using TupleListKey = std::pair; diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.cc b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.cc index cc79cd5a681..a77758f4375 100644 --- a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.cc +++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.cc @@ -18,6 +18,8 @@ #include #include #include +#include +#include #include "utils/hash_set.h" #include "ir/func_graph.h" #include "frontend/parallel/costmodel_context.h" @@ -28,375 +30,53 @@ namespace mindspore { namespace parallel { -mindspore::HashSet FindCNodesWithPara(const AnfNodePtr ¶, uint64_t recursive_times = 0) { - if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { - MS_LOG(EXCEPTION) << "FindCNodesWithPara exceeds max recursive call times! Max recursive call times is " - << MAX_RECURSIVE_CALL_TIMES; - } - MS_EXCEPTION_IF_NULL(para); - MS_EXCEPTION_IF_NULL(para->func_graph()); - FuncGraphManagerPtr manager = para->func_graph()->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto node_set = manager->node_users()[para]; - mindspore::HashSet cnode_set; - for (auto &node_pair : node_set) { - auto cnode = node_pair.first->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (!IsValueNode(cnode->input(0))) { - continue; - } - auto node_prim = GetValueNode(cnode->input(0)); - MS_EXCEPTION_IF_NULL(node_prim); - if (node_prim->name() == DEPEND && node_pair.second != 1) { - continue; - } - if (IsParallelCareNode(cnode) && cnode->has_user_data()) { - (void)cnode_set.emplace(cnode); - } else { - auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1); - for (auto &cnode_sub : cnode_set_sub) { - (void)cnode_set.emplace(cnode_sub); - } - } - } - return cnode_set; -} - -Status AllreduceFusion::AddNodeToGraph() { - const auto ¶meters = root_graph_->parameters(); - for (auto ¶meter : parameters) { - if (!ParameterRequireGrad(parameter)) { - continue; - } - auto cnode_set = FindCNodesWithPara(parameter); - if (cnode_set.empty()) { - continue; - } - for (auto &cnode : cnode_set) { - MS_LOG(DEBUG) << "AddNode " << cnode->DebugString(); - if (allreduce_graph_.AddNode(cnode, parameter) != SUCCESS) { - MS_LOG(ERROR) << "AddNode failed! 
cnode: " << cnode->DebugString(); - return FAILED; - } - } - } - return SUCCESS; -} - -CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint64_t recursive_times) const { - if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { - MS_LOG(EXCEPTION) << "FindCNode exceeds max recursive call times! Max recursive call times is " - << MAX_RECURSIVE_CALL_TIMES; - } - MS_EXCEPTION_IF_NULL(from); - mindspore::HashMap cnode_dist; - if (!from->isa()) { - return cnode_dist; - } - auto cnode = from->cast(); - if (!IsValueNode(cnode->input(0))) { - return cnode_dist; - } - - auto operator_info = cnode->user_data(); - MS_LOG(DEBUG) << "cnode " << cnode->ToString() << " IsParallelCareNode: " << IsParallelCareNode(cnode) - << " operator_info: " << (operator_info != nullptr); - - if (IsParallelCareNode(cnode) && (operator_info != nullptr)) { - auto cost = operator_info->GetForwardMemoryCostFromCNode(); - MS_LOG(DEBUG) << "cnode " << cnode->DebugString() << " cost: " << cost; - - if (allreduce_graph_.NodeInGraph(cnode)) { - cnode_dist[cnode] = cost; - return cnode_dist; - } else { - auto cnode_dist_next = FindNextCNodes(cnode, recursive_times + 1); - for (auto &ele_next : cnode_dist_next) { - cnode_dist[ele_next.first] = cost + ele_next.second; - } - } - } else { - auto cnode_dist_next = FindNextCNodes(cnode); - for (auto &ele : cnode_dist_next) { - cnode_dist[ele.first] = ele.second; - } - } - return cnode_dist; -} - -CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr &from, uint64_t recursive_times) const { - if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { - MS_LOG(EXCEPTION) << "FindNextCNodes exceeds max recursive call times! Max recursive call times is " - << MAX_RECURSIVE_CALL_TIMES; - } - const auto &from_inputs = from->inputs(); - mindspore::HashMap dist_map; - MS_LOG(DEBUG) << "from cnode " << from->DebugString() << " has " << from_inputs.size() << " inputs"; - for (auto &input_node : from_inputs) { - auto cnode_dist = FindCNode(input_node, recursive_times + 1); - for (auto &ele : cnode_dist) { - (void)dist_map.emplace(ele); - } - } - return dist_map; -} - -Status AllreduceFusion::AddEdgeToGraph() { - mindspore::HashMap cnode_state_map; - const auto &cnodes = allreduce_graph_.cnode_set(); - for (auto &cnode : cnodes) { - cnode_state_map[cnode] = 0; - } - const auto &head_cnode = allreduce_graph_.head_cnode(); - std::queue cnode_queue; - cnode_queue.emplace(head_cnode); - cnode_state_map[head_cnode] = 1; - - while (!cnode_queue.empty()) { - const auto cur_cnode = cnode_queue.front(); - cnode_queue.pop(); - cnode_state_map[cur_cnode] = 2; - auto next = FindNextCNodes(cur_cnode); - for (auto &ele : next) { - auto &cnode = ele.first; - auto &dist = ele.second; - if (cnode_state_map[cnode] == 0) { - cnode_queue.emplace(cnode); - cnode_state_map[cnode] = 1; - } - if (allreduce_graph_.AddEdge(cur_cnode, cnode, dist) != SUCCESS) { - MS_LOG(ERROR) << "AddEdge error"; - return FAILED; - } - MS_LOG(DEBUG) << "from " << cur_cnode->DebugString() << ", to " << cnode->DebugString() << " dist " << dist; - } - } - return SUCCESS; -} - -std::vector FindMirror(const AnfNodePtr ¶, uint64_t recursive_times = 0) { - if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { - MS_LOG(EXCEPTION) << "FindMirror exceeds max recursive call times! 
Max recursive call times is " - << MAX_RECURSIVE_CALL_TIMES; - } - MS_EXCEPTION_IF_NULL(para); - MS_EXCEPTION_IF_NULL(para->func_graph()); - FuncGraphManagerPtr manager = para->func_graph()->manager(); - MS_EXCEPTION_IF_NULL(manager); - AnfNodeIndexSet node_set = manager->node_users()[para]; - std::vector cnode_list; - for (auto &node_pair : node_set) { - auto cnode = node_pair.first->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (!IsValueNode(cnode->input(0))) { - continue; - } - auto node_prim = GetValueNode(cnode->input(0)); - MS_EXCEPTION_IF_NULL(node_prim); - if (node_prim->name() == CAST) { - auto mirror_cnodes = FindMirror(node_pair.first, recursive_times + 1); - if (mirror_cnodes.empty()) { - MS_LOG(WARNING) << "mirror node after cast not found"; - continue; - } - if (mirror_cnodes.size() > 1) { - MS_LOG(EXCEPTION) << "mirror node after cast number is not 1"; - } - cnode_list.emplace_back(mirror_cnodes[0]); - } - if (node_prim->name() == MIRROR_OPERATOR) { - cnode_list.emplace_back(cnode); - } - } - return cnode_list; -} - void SetMirrorFusion(const CNodePtr &mirror_cnode, int64_t fusion, const std::string ¶meter_name) { MS_EXCEPTION_IF_NULL(mirror_cnode); MS_LOG(DEBUG) << "Set Mirror " << mirror_cnode->DebugString() << " fusion " << fusion; auto node_prim = GetValueNode(mirror_cnode->input(0)); - auto old_value_ptr = node_prim->GetAttr(FUSION); - if (old_value_ptr != nullptr) { - if (old_value_ptr->isa()) { - int64_t old_value = old_value_ptr->cast()->value(); - if (old_value < fusion) { - return; - } - } - } (void)node_prim->AddAttr(FUSION, MakeValue(std::make_shared(fusion))); (void)node_prim->AddAttr(PARAMETER, MakeValue(std::make_shared(parameter_name))); } -Status FindMirrorAndSetFusion(const AnfNodePtr ¶, int64_t fusion) { - auto mirror_cnodes = FindMirror(para); - if (mirror_cnodes.empty()) { - MS_LOG(WARNING) << para->ToString() << " 0 Mirror CNode found."; - return SUCCESS; - } - if (mirror_cnodes.size() > 2) { - for (auto &mirror_cnode_1 : mirror_cnodes) { - MS_EXCEPTION_IF_NULL(mirror_cnode_1); - MS_LOG(INFO) << mirror_cnode_1->DebugString(); - } - MS_EXCEPTION_IF_NULL(para); - MS_LOG(ERROR) << para->ToString() << " FindMirror is more than 2. 
" << mirror_cnodes.size() - << "Mirror CNode found."; - return FAILED; - } - for (auto &mirror_cnode : mirror_cnodes) { - auto parameter_name = ParameterName(para); - SetMirrorFusion(mirror_cnode, fusion, parameter_name); - } - return SUCCESS; -} - -Status FindMirrorAndSetFusion(const std::vector ¶s, int64_t fusion) { - for (auto ¶m_node : paras) { - if (FindMirrorAndSetFusion(param_node, fusion) != SUCCESS) { - MS_LOG(ERROR) << "FindMirrorAndSetFusion failed"; - return FAILED; - } - } - return SUCCESS; -} - -Status AllreduceFusion::SetFusion(const std::vector &cost_map) { - if (cost_map.size() < 2) { - MS_LOG(ERROR) << "cost_map must has at least 2 items, cost_map size is " << cost_map.size(); - return FAILED; - } +Status AllreduceFusion::SetFusionBySize(const CNodePtr &ret, int64_t threshold) { + auto filter = [](const AnfNodePtr &node) { return !IsPrimitiveCNode(node, prim::kPrimMirror); }; + auto todo = DeepScopedGraphSearchWithFilter(ret, AlwaysInclude, filter); + auto temp = threshold; int64_t fusion = 1; - for (auto cost_iter = cost_map.end() - 1; cost_iter != cost_map.begin(); --cost_iter) { - auto paras = allreduce_graph_.GetParaByCost(*(cost_iter - 1), *cost_iter); - if (FindMirrorAndSetFusion(paras, fusion) != SUCCESS) { - MS_LOG(ERROR) << "FindMirrorAndSetFusion failed"; - return FAILED; + bool init = true; + for (auto &node : todo) { + auto cnode = node->cast(); + if (cnode->input(1)->Shape() == nullptr) continue; + auto input_shapes = GetNodeShape(cnode->input(1)); + int64_t input_size = std::accumulate(input_shapes[0].begin(), input_shapes[0].end(), 1, std::multiplies()); + FuncGraphPtr func_graph = cnode->func_graph(); + std::pair param_node_pair = FindParameter(cnode->input(1), func_graph); + if (!param_node_pair.first) { + continue; } - fusion++; - } - return SUCCESS; -} - -std::vector AllreduceFusion::GenerateCostMap(int64_t fusion_times, double tail_percent) const { - double offset = allreduce_graph_.max() * (1 - tail_percent) / (fusion_times - 1); - MS_LOG(DEBUG) << "max = " << allreduce_graph_.max() << ", offset = " << offset; - std::vector cost_map; - double begin = 0; - for (auto i = 0; i < fusion_times - 1; i++) { - cost_map.push_back(begin); - begin += offset; - } - cost_map.push_back(allreduce_graph_.max() * (1 - tail_percent)); - cost_map.push_back(allreduce_graph_.max()); - MS_LOG(DEBUG) << "cost_map = " << cost_map; - return cost_map; -} - -Status AllreduceFusion::SetFusionByBackwardCompTime() { - auto fusion_times = CostModelContext::GetInstance()->costmodel_allreduce_fusion_times(); - if (fusion_times < 2) { - MS_LOG(INFO) << "'costmodel_allreduce_fusion_times' is " << fusion_times << ". Bypass ProcessAllreduceFusion"; - return SUCCESS; - } - auto tail_percent = CostModelContext::GetInstance()->costmodel_allreduce_fusion_tail_percent(); - if (tail_percent < 0 || tail_percent >= 1) { - MS_LOG(INFO) << "'costmodel_allreduce_fusion_tail_percent' is " << tail_percent - << ". 
Bypass ProcessAllreduceFusion"; - return SUCCESS; - } - const auto cost_map = GenerateCostMap(fusion_times, tail_percent); - MS_LOG(DEBUG) << "AllreduceGraph GenerateCostMap succeed."; - if (SetFusion(cost_map) != SUCCESS) { - MS_LOG(ERROR) << "SetFusion failed."; - return FAILED; - } - MS_LOG(DEBUG) << "AllreduceGraph SetFusion succeed."; - return SUCCESS; -} - -Status AllreduceFusion::GetSetFusionByBackwardCompAndAllreduceTimeParams() { - tail_time_ = CostModelContext::GetInstance()->costmodel_allreduce_fusion_tail_time(); - if (tail_time_ <= 0) { - MS_LOG(INFO) << "'costmodel_allreduce_tail_time' is " << tail_time_ << ". Bypass ProcessAllreduceFusion"; - return FAILED; - } - allreduce_inherent_time_ = CostModelContext::GetInstance()->costmodel_allreduce_fusion_allreduce_inherent_time(); - if (allreduce_inherent_time_ <= 0) { - MS_LOG(INFO) << "'costmodel_allreduce_fusion_allreduce_inherent_time' is " << allreduce_inherent_time_ - << ". Bypass ProcessAllreduceFusion"; - return FAILED; - } - if (tail_time_ <= allreduce_inherent_time_) { - MS_LOG(INFO) << "'costmodel_allreduce_tail_time' is " << tail_time_ - << "'costmodel_allreduce_fusion_allreduce_inherent_time' is " << allreduce_inherent_time_ - << ".tail_time is not more than allreduce_inherent_time. Bypass ProcessAllreduceFusion"; - return FAILED; - } - allreduce_bandwidth_ = CostModelContext::GetInstance()->costmodel_allreduce_fusion_allreduce_bandwidth(); - if (allreduce_bandwidth_ <= 0) { - MS_LOG(INFO) << "'costmodel_allreduce_fusion_allreduce_bandwidth' is " << allreduce_bandwidth_ - << ". Bypass ProcessAllreduceFusion"; - return FAILED; - } - computation_time_parameter_ = - CostModelContext::GetInstance()->costmodel_allreduce_fusion_computation_time_parameter(); - if (computation_time_parameter_ <= 0) { - MS_LOG(INFO) << "'costmodel_allreduce_fusion_computation_time_parameter' is " << computation_time_parameter_ - << ". 
Bypass ProcessAllreduceFusion"; - return FAILED; - } - return SUCCESS; -} - -Status AllreduceFusion::SetFusionByBackwardCompAndAllreduceTime() { - if (GetSetFusionByBackwardCompAndAllreduceTimeParams() != SUCCESS) { - MS_LOG(ERROR) << "GetSetFusionByBackwardCompAndAllreduceTimeParams failed!"; - return FAILED; - } - allreduce_graph_.SortArnode(); - if (allreduce_graph_.RemoveExtraParas() != SUCCESS) { - MS_LOG(ERROR) << "RemoveExtraParas failed!"; - return FAILED; - } - double para_size = (tail_time_ - allreduce_inherent_time_) / allreduce_bandwidth_; - double to_cost = allreduce_graph_.max(); - int64_t fusion = 1; - while (to_cost != 0) { - MS_LOG(INFO) << "to_cost: " << to_cost << " para_size: " << para_size; - auto node_cost_pair = allreduce_graph_.GetParaByParaSize(to_cost, para_size); - MS_LOG(INFO) << "para size: " << node_cost_pair.first.size() << " from_cost: " << node_cost_pair.second; - auto paras = node_cost_pair.first; - if (FindMirrorAndSetFusion(paras, fusion) != SUCCESS) { - MS_LOG(ERROR) << "FindMirrorAndSetFusion failed"; - return FAILED; + auto parameter_name = ParameterName(param_node_pair.first); + if (input_size < temp) { + temp -= input_size; + } else { + temp = threshold; + fusion++; } - fusion++; - para_size = ((to_cost - node_cost_pair.second) * computation_time_parameter_ - allreduce_inherent_time_) / - allreduce_bandwidth_; - to_cost = node_cost_pair.second; + if (init) { + SetMirrorFusion(cnode, 1, parameter_name); + } else { + SetMirrorFusion(cnode, fusion, parameter_name); + } + init = false; } - MS_LOG(DEBUG) << "AllreduceGraph SetFusionByBackwardCompAndAllreduceTime succeed."; + MS_LOG(INFO) << "Allreduce fusion by size succeed."; return SUCCESS; } -Status AllreduceFusion::SetFusionByAlgorithm(int64_t algorithm) { - if (algorithm == 1) { - return SetFusionByBackwardCompTime(); - } - return SetFusionByBackwardCompAndAllreduceTime(); -} - Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) { if (ret == nullptr) { MS_LOG(ERROR) << "ret is nullptr."; return FAILED; } - auto algorithm = CostModelContext::GetInstance()->costmodel_allreduce_fusion_algorithm(); - if (algorithm < 1 || algorithm > 2) { - MS_LOG(INFO) << "'costmodel_allreduce_fusion_algorithm' is " << algorithm << ". 
Bypass ProcessAllreduceFusion"; - return SUCCESS; - } ret_ = ret; root_graph_ = ret_->func_graph(); MS_EXCEPTION_IF_NULL(root_graph_); @@ -409,27 +89,20 @@ Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) { MS_EXCEPTION_IF_NULL(forward_graph); forward_ret_ = forward_graph->get_return(); MS_EXCEPTION_IF_NULL(forward_ret_); - if (allreduce_graph_.set_head_cnode(forward_ret_) != SUCCESS) { MS_LOG(ERROR) << "AllreduceGraph set_head_cnode failed."; return FAILED; } - MS_LOG(DEBUG) << "AllreduceGraph set_head_cnode succeed."; - if (AddNodeToGraph() != SUCCESS) { - MS_LOG(ERROR) << "AddNodeToGraph failed."; + auto threshold = ParallelContext::GetInstance()->fusion_threshold_mb() * 1024 * 1024 / 4.0; + if (threshold > 0) { + if (SetFusionBySize(ret, threshold) != SUCCESS) { + MS_LOG(ERROR) << "SetFusionBySize failed."; + return FAILED; + } + } else { + MS_LOG(ERROR) << "The threshold of SetFusionBySize must be larger than 0, but got " << threshold << "."; return FAILED; } - MS_LOG(DEBUG) << "AllreduceGraph AddNodeToGraph succeed."; - if (AddEdgeToGraph() != SUCCESS) { - MS_LOG(ERROR) << "AddNodeToGraph failed."; - return FAILED; - } - MS_LOG(DEBUG) << "AllreduceGraph AddEdgeToGraph succeed."; - if (SetFusionByAlgorithm(algorithm) != SUCCESS) { - MS_LOG(ERROR) << "SetFusionByAlgorithm failed."; - return FAILED; - } - MS_LOG(DEBUG) << "AllreduceGraph SetFusionByAlgorithm succeed."; return SUCCESS; } } // namespace parallel diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.h b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.h index c5a0196af00..f4cedb8807e 100644 --- a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.h +++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.h @@ -27,8 +27,6 @@ namespace mindspore { namespace parallel { -using CNodeCostMap = mindspore::HashMap; - constexpr int64_t DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM = 0; constexpr int64_t DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES = 0; constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT = 0.1; @@ -46,31 +44,17 @@ class AllreduceFusion { forward_ret_(nullptr), root_graph_(nullptr), tail_time_(0), - allreduce_inherent_time_(0), - allreduce_bandwidth_(0), computation_time_parameter_(0) {} virtual ~AllreduceFusion() = default; Status ProcessAllreduceFusion(const CNodePtr &ret); private: - Status AddNodeToGraph(); - CNodeCostMap FindCNode(const AnfNodePtr &from, uint64_t recursive_times = 0) const; - CNodeCostMap FindNextCNodes(const CNodePtr &from, uint64_t recursive_times = 0) const; - Status AddEdgeToGraph(); - std::vector GenerateCostMap(int64_t fusion_times, double tail_percent) const; - Status SetFusion(const std::vector &cost_map); - Status SetFusionByAlgorithm(int64_t algorithm); - Status SetFusionByBackwardCompTime(); - Status SetFusionByBackwardCompAndAllreduceTime(); - Status GetSetFusionByBackwardCompAndAllreduceTimeParams(); - + Status SetFusionBySize(const CNodePtr &ret, int64_t threshold); AllreduceGraph allreduce_graph_; CNodePtr ret_; CNodePtr forward_ret_; FuncGraphPtr root_graph_; double tail_time_; - double allreduce_inherent_time_; - double allreduce_bandwidth_; double computation_time_parameter_; }; } // namespace parallel diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/step_allreduce_fusion.cc b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/step_allreduce_fusion.cc index 1c15d8aaa45..16902b5c0a8 100644 --- 
a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/step_allreduce_fusion.cc +++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/step_allreduce_fusion.cc @@ -22,6 +22,7 @@ #include "frontend/parallel/context.h" #include "frontend/parallel/graph_util/graph_info.h" #include "frontend/parallel/status.h" +#include "frontend/parallel/step_parallel.h" #include "utils/log_adapter.h" namespace mindspore { @@ -32,13 +33,15 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode(); bool enable_all_reduce_fusion = ParallelContext::GetInstance()->enable_all_reduce_fusion(); + auto graph_set = ForwardGraph(root); // assume no change to graph bool changes = false; // control whether use model_parallel mode if (!root->has_flag(AUTO_PARALLEL) || ((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) || - (!enable_all_reduce_fusion) || (root->has_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY))) { + (!enable_all_reduce_fusion) || (root->has_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY)) || graph_set.size() < 1) { return changes; } + #if defined(_WIN32) || defined(_WIN64) auto start_time = std::chrono::steady_clock::now(); #else @@ -47,7 +50,7 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti }, end_time{0}; (void)gettimeofday(&start_time, nullptr); #endif - MS_LOG(INFO) << "Now entering allreduce fusion"; + MS_LOG(INFO) << "Now entering allreduce fusion by size, and allreduce fusion before will be overlapped!"; DumpGraph(root, std::string(ALLREDUCE_FUSION_BEGIN)); pipeline::ResourceBasePtr res = optimizer->resource(); diff --git a/mindspore/ccsrc/frontend/parallel/context.cc b/mindspore/ccsrc/frontend/parallel/context.cc index ee39bb0d6e6..c883f3862d5 100644 --- a/mindspore/ccsrc/frontend/parallel/context.cc +++ b/mindspore/ccsrc/frontend/parallel/context.cc @@ -36,6 +36,8 @@ std::vector STRATEGY_SEARCH_MODE_LIST = {DYNAMIC_PROGRAMMING, RECUR std::vector COMMUNI_PARALLEL_MODE_LIST = {ALL_GROUP_PARALLEL, SAME_SERVER_GROUP_PARALLEL, NO_GROUP_PARALLEL}; +std::vector FUSION_MODE_LIST = {FUSION_AUTO, FUSION_SIZE, FUSION_INDEX}; + std::shared_ptr ParallelContext::inst_context_ = nullptr; std::shared_ptr ParallelContext::GetInstance() { @@ -60,7 +62,7 @@ void ParallelContext::Reset() { parallel_mode_ = STAND_ALONE; parameter_broadcast_ = false; parameter_broadcast_is_set_ = false; - enable_all_reduce_fusion_ = false; + enable_all_reduce_fusion_ = true; strategy_ckpt_load_file_ = ""; strategy_ckpt_save_file_ = ""; enable_parallel_optimizer_ = false; @@ -76,6 +78,9 @@ void ParallelContext::Reset() { grad_accumulation_shard_ = true; sharding_propagation_ = false; dataset_strategy_.clear(); + fusion_threshold_mb_ = FUSUION_THRESHOLD; + fusion_threshold_is_set_ = true; + fusion_mode_ = FUSION_AUTO; } void ParallelContext::set_device_num(int64_t device_num) { @@ -83,6 +88,22 @@ void ParallelContext::set_device_num(int64_t device_num) { device_num_is_set_ = true; } +void ParallelContext::set_fusion_threshold_mb(int64_t fusion_threshold) { + fusion_threshold_mb_ = fusion_threshold; + fusion_threshold_is_set_ = true; + enable_all_reduce_fusion_ = true; +} + +bool ParallelContext::set_fusion_mode(const std::string &fusion_mode) { + auto iter = std::find(FUSION_MODE_LIST.begin(), FUSION_MODE_LIST.end(), fusion_mode); + if (iter == FUSION_MODE_LIST.end()) { + MS_LOG(INFO) << "Invalid fusion mode:" << fusion_mode; + return false; + 
} + fusion_mode_ = fusion_mode; + return true; +} + void ParallelContext::set_global_rank(int64_t global_rank) { global_rank_ = global_rank; global_rank_is_set_ = true; diff --git a/mindspore/ccsrc/frontend/parallel/context.h b/mindspore/ccsrc/frontend/parallel/context.h index 327b1659d01..950beddc464 100644 --- a/mindspore/ccsrc/frontend/parallel/context.h +++ b/mindspore/ccsrc/frontend/parallel/context.h @@ -52,6 +52,11 @@ constexpr char SAME_SERVER_GROUP_PARALLEL[] = "same_server_group_parallel"; constexpr char NO_GROUP_PARALLEL[] = "no_group_parallel"; constexpr char IS_FIRST_ITERATION[] = "is_first_iteration"; + +constexpr char FUSION_AUTO[] = "auto"; +constexpr char FUSION_SIZE[] = "size"; +constexpr char FUSION_INDEX[] = "index"; +constexpr int64_t FUSUION_THRESHOLD = 64; class ParallelContext { public: ~ParallelContext() = default; @@ -78,6 +83,11 @@ class ParallelContext { void set_device_num(int64_t device_num); int64_t device_num() const { return device_num_; } + void set_fusion_threshold_mb(int64_t fusion_threshold); + int64_t fusion_threshold_mb() const { return fusion_threshold_mb_; } + bool set_fusion_mode(const std::string &fusion_mode); + std::string get_fusion_mode() const { return fusion_mode_; } + void set_pipeline_stage_split_num(const int64_t stages); int64_t pipeline_stage_split_num() const { return pipeline_stage_split_num_; } @@ -159,6 +169,7 @@ class ParallelContext { bool gradient_fp32_sync_; bool loss_repeated_mean_; int64_t device_num_; + int64_t fusion_threshold_mb_; int64_t global_rank_; int64_t grad_accumulation_step_; std::string parallel_mode_; @@ -166,6 +177,7 @@ class ParallelContext { int64_t pipeline_stage_split_num_; bool parameter_broadcast_; bool device_num_is_set_; + bool fusion_threshold_is_set_; bool global_rank_is_set_; bool parameter_broadcast_is_set_; bool enable_all_reduce_fusion_; @@ -186,6 +198,7 @@ class ParallelContext { bool dataset_repeat_dim_right_ = false; bool hccl_test_available_ = false; bool sharding_propagation_; + std::string fusion_mode_; }; } // namespace parallel diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index 086f1470450..33f1ab2c682 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -148,6 +148,10 @@ PYBIND11_MODULE(_c_expression, m) { .def("set_hccl_test_avaible", &ParallelContext::set_hccl_test_available, "Set hccl test available.") .def("set_device_num", &ParallelContext::set_device_num, "Set device num.") .def("get_device_num_is_set", &ParallelContext::device_num_is_set, "Get device num is set.") + .def("set_fusion_threshold_mb", &ParallelContext::set_fusion_threshold_mb, "Set fusion threshold.") + .def("fusion_threshold_mb", &ParallelContext::fusion_threshold_mb, "Get fusion threshold.") + .def("set_fusion_mode", &ParallelContext::set_fusion_mode, "Get fusion mode.") + .def("get_fusion_mode", &ParallelContext::get_fusion_mode, "Get fusion mode.") .def("get_global_rank", &ParallelContext::global_rank, "Get global rank.") .def("set_global_rank", &ParallelContext::set_global_rank, "Set global rank.") .def("get_grad_accumulation_shard", &ParallelContext::grad_accumulation_shard, "Get grad_accumulation_shard.") diff --git a/mindspore/context.py b/mindspore/context.py index 75996ca9cb5..f1fa18b8b32 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -373,7 +373,7 @@ def _context(): auto_parallel_search_mode=str, search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str, strategy_ckpt_save_file=str, 
full_batch=bool, enable_parallel_optimizer=bool, all_reduce_fusion_config=list, pipeline_stages=int, grad_accumulation_step=int, - parallel_optimizer_config=dict) + parallel_optimizer_config=dict, comm_fusion=dict) def set_auto_parallel_context(**kwargs): r""" Set auto parallel context, which is valid only for Ascend and GPU target. @@ -402,6 +402,7 @@ def set_auto_parallel_context(**kwargs): parallel_optimizer_config pipeline_stages \ grad_accumulation_step \ auto_parallel_search_mode + \ comm_fusion =========================== =========================== Args: @@ -476,6 +477,16 @@ def set_auto_parallel_context(**kwargs): with larger batch size. This configure is effective only when the model runs on pipeline training or gradient accumulation with data parallel. + comm_fusion (dict): A dict contains the types and configurations for setting the communication fusion. each + communication fusion config has two keys: "mode" and "config". + It supports following communication fusion types and configurations: + + - allreduce: if communication fusion type is `allreduce`. The `mode` contains: `auto`, `size` + and `index`. In `auto` mode, allreduce fusion is configured by gradients size, and the default + fusion threshold is `64` MB. In 'size' mode, allreduce fusion is configured by gradients size + manually, and the fusion threshold must be larger than `0` MB. In `index` mode, it is same as + `all_reduce_fusion_config`. + Raises: ValueError: If input key is not attribute in auto parallel context. @@ -498,6 +509,8 @@ def set_auto_parallel_context(**kwargs): >>> context.set_auto_parallel_context(pipeline_stages=2) >>> parallel_config = {"gradient_accumulation_shard": True} >>> context.set_auto_parallel_context(parallel_optimizer_config=parallel_config, enable_parallel_optimizer=True) + >>> comm_fusion_config = {"allreduce": {"mode": "size", "config": 32}} + >>> context.set_auto_parallel_context(comm_fusion=comm_fusion_config) """ _set_auto_parallel_context(**kwargs) @@ -540,6 +553,7 @@ def reset_auto_parallel_context(): - full_batch: False. - enable_parallel_optimizer: False. - pipeline_stages: 1. + - fusion_threshold: 64. """ _reset_auto_parallel_context() diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py index c178bb8bd6a..e855f49f401 100644 --- a/mindspore/parallel/_auto_parallel_context.py +++ b/mindspore/parallel/_auto_parallel_context.py @@ -26,6 +26,16 @@ _MAX_GROUP_NAME_LEN = 127 _DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1" _DEFAULT_NCCL_FUSION_GROUP_NAME = "nccl_world_groupsum1" +class _ParallelFusionConfig: + """ + The key of the Parallel fusion method configuration. + """ + ALLREDUCE = "allreduce" + MODE = "mode" + FUSION_CONFIG = "config" + AUTO = "auto" + INDEX = "index" + SIZE = "size" class _ParallelOptimizerConfig: """ @@ -89,6 +99,86 @@ class _AutoParallelContext: self.check_context_handle() return self._context_handle.get_device_num() + def set_comm_fusion(self, config): + """ + Set fusion method for auto parallel. + + Args: + config (dict): A dict contains the methods and values for setting the communication fusion. Currently it + supports: `allreduce`. + + Raises: + KeyError: When key of comm_fusion is not 'allreduce'. 
+ """ + self.check_context_handle() + for key in list(config.keys()): + if key == _ParallelFusionConfig.ALLREDUCE: + self._set_allreduce_comm_fusion(config[key]) + else: + raise KeyError("comm fusion type must be allreduce, but got {}".format(key)) + + def _set_allreduce_comm_fusion(self, comm_fusion): + """ + Set fusion method for auto parallel. + + Args: + comm_fusion (dict): A dict contains the methods and values for setting the fusion method. Currently it + supports four fusion methods: `auto`, `size` and `index`. + + Raises: + KeyError: When key of comm_fusion is not 'mode' or 'config'. + KeyError: When `mode` is not 'auto', 'size' or 'index'. + """ + self.check_context_handle() + if _ParallelFusionConfig.MODE not in comm_fusion: + raise KeyError("For 'comm_fusion', the key 'mode' should be contained.") + if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion: + raise KeyError("For 'comm_fusion', the key 'config' should be contained.") + check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.INDEX, _ParallelFusionConfig.SIZE] + if comm_fusion[_ParallelFusionConfig.MODE] in check_mode: + self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE]) + else: + raise KeyError("fusion method mode must be auto, index or size, but got {}".format( + comm_fusion[_ParallelFusionConfig.MODE])) + if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.AUTO: + self.set_fusion_threshold_mb(fusion_threshold=64) + if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.SIZE: + self.set_fusion_threshold_mb(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG]) + if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.INDEX: + self.set_all_reduce_fusion_split_indices(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG]) + + def get_comm_fusion(self): + """Get comm fusion config.""" + self.check_context_handle() + mode = self._context_handle.get_fusion_mode() + if mode in (_ParallelFusionConfig.AUTO, _ParallelFusionConfig.SIZE): + config = self.fusion_threshold_mb() + if mode == _ParallelFusionConfig.INDEX: + config = self.get_all_reduce_fusion_split_indices() + return {_ParallelFusionConfig.ALLREDUCE: {_ParallelFusionConfig.MODE: mode, + _ParallelFusionConfig.FUSION_CONFIG: config}} + + + def set_fusion_threshold_mb(self, fusion_threshold=64): + """ + Set fusion threshold (MB) for auto parallel. + + Args: + fusion_threshold (int): The fusion threshold (unit: MB). Default: 64. + + Raises: + ValueError: If the fusion threshold is not in [0, +inf]. + """ + self.check_context_handle() + if fusion_threshold < 0: + raise ValueError("fusion threshold must be larger than 0, but got {}".format(fusion_threshold)) + self._context_handle.set_fusion_threshold_mb(fusion_threshold) + + def fusion_threshold_mb(self): + """Get device num.""" + self.check_context_handle() + return self._context_handle.fusion_threshold_mb() + def set_global_rank(self, global_rank): """ Set global rank for auto parallel. 
@@ -741,7 +831,8 @@ _set_auto_parallel_context_func_map = { "optimizer_weight_shard_size": auto_parallel_context().set_optimizer_weight_shard_size, "optimizer_weight_shard_aggregated_save": auto_parallel_context().set_optimizer_weight_shard_aggregated_save, "sharding_propagation": auto_parallel_context().set_sharding_propagation, - "enable_alltoall": auto_parallel_context().set_enable_alltoall} + "enable_alltoall": auto_parallel_context().set_enable_alltoall, + "comm_fusion": auto_parallel_context().set_comm_fusion} _get_auto_parallel_context_func_map = { @@ -766,7 +857,8 @@ _get_auto_parallel_context_func_map = { "optimizer_weight_shard_size": auto_parallel_context().get_optimizer_weight_shard_size, "optimizer_weight_shard_aggregated_save": auto_parallel_context().get_optimizer_weight_shard_aggregated_save, "sharding_propagation": auto_parallel_context().get_sharding_propagation, - "enable_alltoall": auto_parallel_context().get_enable_alltoall} + "enable_alltoall": auto_parallel_context().get_enable_alltoall, + "comm_fusion": auto_parallel_context().get_comm_fusion} @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, @@ -775,7 +867,7 @@ _get_auto_parallel_context_func_map = { strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool, grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str, communi_parallel_mode=str, optimizer_weight_shard_size=int, sharding_propagation=bool, - optimizer_weight_shard_aggregated_save=bool, enable_alltoall=bool) + optimizer_weight_shard_aggregated_save=bool, enable_alltoall=bool, comm_fusion=dict) def _set_auto_parallel_context(**kwargs): """ @@ -848,6 +940,16 @@ def _set_auto_parallel_context(**kwargs): search the desired strategies. Default: False. enable_alltoall (bool): Set the value of enabling AllToAll. If False, AllGather and Split are used to circumvent AllToAll. Default: False. + comm_fusion (dict): A dict contains the types and configurations for setting the communication fusion. each + communication fusion config has two keys: "mode" and "config". + It supports following communication fusion types and configurations: + + - allreduce: if communication fusion type is `allreduce`. The `mode` contains: `auto`, `size` + and `index`. In `auto` mode, allreduce fusion is configured by gradients size, and the default + fusion threshold is `64` MB. In 'size' mode, allreduce fusion is configured by gradients size + manually, and the fusion threshold must be larger than `0` MB. In `index` mode, it is same as + `all_reduce_fusion_config`. + Raises: ValueError: If input key is not attribute in auto parallel context. @@ -896,5 +998,6 @@ def _reset_auto_parallel_context(): - sharding_propagation: False - pipeline_stages: 0 - gradient_accumulation_shard: True + - fusion_threshold: 64 """ auto_parallel_context().reset() diff --git a/tests/ut/python/parallel/test_allreduce_fusion.py b/tests/ut/python/parallel/test_allreduce_fusion.py index 717fc64be40..ba9bb070e41 100644 --- a/tests/ut/python/parallel/test_allreduce_fusion.py +++ b/tests/ut/python/parallel/test_allreduce_fusion.py @@ -13,20 +13,46 @@ # limitations under the License. 
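The Python front end in _auto_parallel_context.py above only validates the dict and forwards it to the C++ ParallelContext, so a minimal usage sketch of the three `allreduce` fusion modes may help; it assumes a build that already contains this patch, the dict layout ("allreduce" -> {"mode", "config"}) follows _ParallelFusionConfig, and the printed result is only indicative.

from mindspore import context
from mindspore.parallel._auto_parallel_context import auto_parallel_context

# "auto": fuse gradients by size using the built-in 64 MB threshold; "config" is ignored.
context.set_auto_parallel_context(comm_fusion={"allreduce": {"mode": "auto", "config": None}})

# "size": fuse gradients by size with a user-defined threshold in MB (must be larger than 0).
context.set_auto_parallel_context(comm_fusion={"allreduce": {"mode": "size", "config": 32}})

# "index": behaves like the existing all_reduce_fusion_config index list.
context.set_auto_parallel_context(comm_fusion={"allreduce": {"mode": "index", "config": [2, 4, 6, 8]}})

# The current setting can be read back through the auto-parallel context singleton.
print(auto_parallel_context().get_comm_fusion())
# e.g. {'allreduce': {'mode': 'index', 'config': [2, 4, 6, 8]}}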
import numpy as np -import pytest import mindspore as ms import mindspore.nn as nn from mindspore import Tensor, context from mindspore.common.api import _cell_graph_executor +from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits +from mindspore.nn.optim import Lamb from mindspore.nn.optim.momentum import Momentum +from mindspore.ops import operations as P from mindspore.parallel import _cost_model_context as cost_model_context from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.train import Model from mindspore.context import ParallelMode from tests.dataset_mock import MindData +context.set_context(mode=context.PYNATIVE_MODE, save_graphs=True) +class Net(nn.Cell): + """Net definition""" + def __init__(self): + super(Net, self).__init__() + self.fc1 = nn.Dense(128, 768, activation='relu') + self.fc2 = nn.Dense(128, 768, activation='relu') + self.fc3 = nn.Dense(128, 768, activation='relu') + self.fc4 = nn.Dense(768, 768, activation='relu') + self.relu4 = nn.ReLU() + self.relu5 = nn.ReLU() + self.transpose = P.Transpose() + self.matmul1 = P.MatMul() + self.matmul2 = P.MatMul() + + def construct(self, x): + q = self.fc1(x) + k = self.fc2(x) + v = self.fc3(x) + k = self.transpose(k, (1, 0)) + c = self.relu4(self.matmul1(q, k)) + s = self.relu5(self.matmul2(c, v)) + s = self.fc4(s) + return s class Dataset(MindData): def __init__(self, predict, label, length=3): @@ -107,7 +133,6 @@ def train_common(net): momentum = 0.9 epoch_size = 2 device_num = 4 - auto_parallel_context().set_enable_all_reduce_fusion(enable_all_reduce_fusion=True) context.set_auto_parallel_context(device_num=device_num, parameter_broadcast=False) context.set_context(mode=context.GRAPH_MODE) @@ -121,196 +146,78 @@ def train_common(net): model.train(epoch_size, dataset, dataset_sink_mode=False) allreduce_fusion_dict = _cell_graph_executor._get_allreduce_fusion(model._train_network) - print(allreduce_fusion_dict) return allreduce_fusion_dict - -@pytest.mark.skip(reason="depreciated feature") -def test_allreduce_fusion_parameters(): - cost_model_context.reset_cost_model_context() - cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2) - algorithm = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_algorithm') - assert algorithm == 2 - cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1) - algorithm = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_algorithm') - assert algorithm == 1 - cost_model_context.reset_cost_model_context() - algorithm = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_algorithm') - assert algorithm == 0 - - cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2) - fusion_times = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_times') - assert fusion_times == 2 - - cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.2) - tail_percent = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_tail_percent') - assert tail_percent == 0.2 - cost_model_context.reset_cost_model_context() - tail_percent = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_tail_percent') - assert tail_percent == 0.1 - - cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_time=0.2) - tail_time = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_tail_time') - assert 
tail_time == 0.2 - cost_model_context.reset_cost_model_context() - tail_time = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_tail_time') - assert tail_time == 0.1 - - cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_inherent_time=0.2) - allreduce_inherent_time = cost_model_context.get_cost_model_context( - 'costmodel_allreduce_fusion_allreduce_inherent_time') - assert allreduce_inherent_time == 0.2 - cost_model_context.reset_cost_model_context() - allreduce_inherent_time = cost_model_context.get_cost_model_context( - 'costmodel_allreduce_fusion_allreduce_inherent_time') - assert allreduce_inherent_time == 0.1 - - cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_bandwidth=0.2) - allreduce_bandwidth = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_allreduce_bandwidth') - assert allreduce_bandwidth == 0.2 - cost_model_context.reset_cost_model_context() - allreduce_bandwidth = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_allreduce_bandwidth') - assert allreduce_bandwidth == 0.1 - - cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_computation_time_parameter=0.2) - computation_time_parameter = cost_model_context.get_cost_model_context( - 'costmodel_allreduce_fusion_computation_time_parameter') - assert computation_time_parameter == 0.2 - cost_model_context.reset_cost_model_context() - computation_time_parameter = cost_model_context.get_cost_model_context( - 'costmodel_allreduce_fusion_computation_time_parameter') - assert computation_time_parameter == 0.1 - - -@pytest.mark.skip(reason="depreciated feature") -def test_allreduce_fusion1(): - cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1) - cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2) - cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5) - context.reset_auto_parallel_context() - context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL) +def test_allreduce_fusion_auto(): + """ + Feature: test_allreduce_fusion in auto mode + Description: allreduce fusion in auto mode + Expectation: success + """ + comm_fusion_dict = {"allreduce": {"mode": "auto", "config": None}} + context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, comm_fusion=comm_fusion_dict) net = SimpleDMLNet(DenseNet1(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None)) allreduce_fusion_dict = train_common(net) - expect_dict = {'backbone2.fc8.weight': 2, - 'backbone2.fc7.weight': 2, - 'backbone2.fc6.weight': 2, - 'backbone1.fc4.weight': 2, - 'backbone1.fc3.weight': 2, - 'backbone1.fc2.weight': 2, - 'backbone2.fc5.weight': 1, - 'backbone2.fc4.weight': 1, - 'backbone2.fc3.weight': 1, - 'backbone2.fc2.weight': 1, - 'backbone2.fc1.weight': 1, - 'backbone1.fc1.weight': 1} - assert allreduce_fusion_dict == expect_dict - cost_model_context.reset_cost_model_context() - - -@pytest.mark.skip(reason="depreciated feature") -# reset_cost_model_context is called, the default value of costmodel_allreduce_fusion_times is 0, step_allreduce_fusion -# is bypassed. 
-def test_allreduce_fusion2(): - cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2) - cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5) - cost_model_context.reset_cost_model_context() - context.reset_auto_parallel_context() - context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL) - net = SimpleDMLNet(DenseNet1(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None)) - allreduce_fusion_dict = train_common(net) - expect_dict = {} - assert allreduce_fusion_dict == expect_dict - cost_model_context.reset_cost_model_context() - - -@pytest.mark.skip(reason="depreciated feature") -def test_allreduce_fusion3(): - cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1) - cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=3) - cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.3333333) - context.reset_auto_parallel_context() - context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL) - net = SimpleDMLNet(DenseNet1(has_bias=True, activation='relu'), DenseNet2(has_bias=False, activation='relu')) - allreduce_fusion_dict = train_common(net) - expect_dict = {'backbone2.fc8.weight': 3, - 'backbone2.fc7.weight': 3, - 'backbone2.fc6.weight': 2, - 'backbone2.fc5.weight': 2, - 'backbone2.fc4.weight': 2, - 'backbone2.fc3.weight': 1, - 'backbone2.fc2.weight': 1, - 'backbone2.fc1.weight': 1, - 'backbone1.fc4.bias': 3, - 'backbone1.fc4.weight': 3, - 'backbone1.fc3.bias': 3, - 'backbone1.fc3.weight': 2, - 'backbone1.fc2.bias': 2, - 'backbone1.fc2.weight': 2, - 'backbone1.fc1.bias': 2, - 'backbone1.fc1.weight': 2} - assert allreduce_fusion_dict == expect_dict - cost_model_context.reset_cost_model_context() - - -@pytest.mark.skip(reason="depreciated feature") -def test_allreduce_fusion4(): - cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1) - cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2) - cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5) - context.reset_auto_parallel_context() - context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL) - net = SimpleDMLNet(DenseNet2(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None)) - allreduce_fusion_dict = train_common(net) - expect_dict = {'backbone2.fc8.weight': 2, - 'backbone2.fc7.weight': 2, - 'backbone2.fc6.weight': 2, - 'backbone1.fc8.weight': 2, - 'backbone1.fc7.weight': 2, - 'backbone1.fc6.weight': 2, - 'backbone2.fc5.weight': 1, - 'backbone2.fc4.weight': 1, - 'backbone2.fc3.weight': 1, - 'backbone2.fc2.weight': 1, - 'backbone2.fc1.weight': 1, - 'backbone1.fc5.weight': 1, + expect_dict = {'backbone2.fc8.weight': 1, + 'backbone2.fc7.weight': 1, + 'backbone2.fc6.weight': 1, 'backbone1.fc4.weight': 1, 'backbone1.fc3.weight': 1, 'backbone1.fc2.weight': 1, - 'backbone1.fc1.weight': 1} - - assert allreduce_fusion_dict == expect_dict - cost_model_context.reset_cost_model_context() - - -@pytest.mark.skip(reason="depreciated feature") -def test_allreduce_fusion5(): - cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2) - cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_time=0.1) - cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_inherent_time=0.05) - 
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_bandwidth=0.000001) - cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_computation_time_parameter=0.0000015) - context.reset_auto_parallel_context() - context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL) - net = SimpleDMLNet(DenseNet2(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None)) - allreduce_fusion_dict = train_common(net) - - expect_dict = {'backbone2.fc8.weight': 3, - 'backbone2.fc7.weight': 3, - 'backbone2.fc6.weight': 3, - 'backbone2.fc5.weight': 3, - 'backbone2.fc4.weight': 2, - 'backbone2.fc3.weight': 2, + 'backbone2.fc5.weight': 1, + 'backbone2.fc4.weight': 1, + 'backbone2.fc3.weight': 1, 'backbone2.fc2.weight': 1, 'backbone2.fc1.weight': 1, - 'backbone1.fc8.weight': 3, - 'backbone1.fc7.weight': 3, - 'backbone1.fc6.weight': 3, - 'backbone1.fc5.weight': 3, - 'backbone1.fc4.weight': 2, - 'backbone1.fc3.weight': 2, - 'backbone1.fc2.weight': 1, - 'backbone1.fc1.weight': 1,} - + 'backbone1.fc1.weight': 1} assert allreduce_fusion_dict == expect_dict cost_model_context.reset_cost_model_context() + +def test_allreduce_fusion_size(): + """ + Feature: test_allreduce_fusion in size mode + Description: allreduce fusion in size mode + Expectation: success + """ + comm_fusion_dict = {"allreduce": {"mode": "size", "config": 32}} + context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, comm_fusion=comm_fusion_dict) + net = SimpleDMLNet(DenseNet1(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None)) + allreduce_fusion_dict = train_common(net) + expect_dict = {'backbone2.fc8.weight': 1, + 'backbone2.fc7.weight': 1, + 'backbone2.fc6.weight': 1, + 'backbone1.fc4.weight': 1, + 'backbone1.fc3.weight': 1, + 'backbone1.fc2.weight': 1, + 'backbone2.fc5.weight': 1, + 'backbone2.fc4.weight': 1, + 'backbone2.fc3.weight': 1, + 'backbone2.fc2.weight': 1, + 'backbone2.fc1.weight': 1, + 'backbone1.fc1.weight': 1} + assert allreduce_fusion_dict == expect_dict + cost_model_context.reset_cost_model_context() + comm_fusion = auto_parallel_context().get_comm_fusion() + assert comm_fusion_dict == comm_fusion + +def test_lamb_split_fusion_in_index(): + """ + Feature: test_allreduce_fusion in index mode + Description: allreduce fusion in index mode + Expectation: success + """ + comm_fusion_dict = {"allreduce": {"mode": "index", "config": [2, 4, 6, 8]}} + context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True, + comm_fusion=comm_fusion_dict) + inputs = Tensor(np.ones([32, 128]).astype(np.float32)) + label = Tensor(np.zeros([32, 768]).astype(np.float32)) + net = Net() + net.set_train() + loss = nn.SoftmaxCrossEntropyWithLogits() + optimizer = Lamb(net.trainable_params(), learning_rate=0.1) + + net_with_loss = WithLossCell(net, loss) + train_network = TrainOneStepCell(net_with_loss, optimizer) + _cell_graph_executor.compile(train_network, inputs, label) + context.reset_auto_parallel_context()
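The size-based grouping that AllreduceFusion::SetFusionBySize performs in allreduce_fusion.cc can be summarised with a small model: mirror nodes are visited in the order produced by the graph search, and the fusion index is bumped whenever the running gradient size would cross the threshold, which the C++ code expresses in float32 elements (fusion_threshold_mb() * 1024 * 1024 / 4). The helper name assign_fusion_index and the sample gradient sizes below are illustrative only; this is a sketch of the loop, not the implementation itself.

def assign_fusion_index(grad_sizes, threshold_elems):
    """Mimic the bucketing loop of AllreduceFusion::SetFusionBySize."""
    indices = []
    remaining = threshold_elems
    fusion = 1
    for i, size in enumerate(grad_sizes):
        if size < remaining:
            remaining -= size
        else:
            remaining = threshold_elems
            fusion += 1
        # The first visited mirror node always lands in bucket 1, as in the C++ loop.
        indices.append(1 if i == 0 else fusion)
    return indices

# A 32 MB threshold corresponds to 32 * 1024 * 1024 / 4 float32 elements.
threshold = 32 * 1024 * 1024 // 4

# Gradients far smaller than the threshold all stay in bucket 1, which matches the
# expected dicts in the tests above where every weight maps to fusion index 1.
print(assign_fusion_index([128 * 768] * 12, threshold))  # twelve 1s

# Larger gradients spill into a new bucket once the threshold would be exceeded.
print(assign_fusion_index([6_000_000, 6_000_000, 6_000_000], threshold))  # [1, 2, 2]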