!27024 add allreduce fusion by size

Merge pull request !27024 from jiahongQian/master
i-robot 2021-12-06 05:32:12 +00:00 committed by Gitee
commit 2d23b698a6
11 changed files with 296 additions and 573 deletions

View File

@ -24,6 +24,7 @@
#include <vector>
#include <iostream>
#include <utility>
#include <unordered_map>
#include "utils/hash_map.h"
#include "ir/anf.h"
@ -38,7 +39,7 @@
namespace mindspore {
namespace ad {
using Registry = mindspore::HashMap<PrimitivePtr, FuncGraphPtr, PrimitiveHasher, PrimitiveTotalEqual>;
using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr, PrimitiveHasher, PrimitiveTotalEqual>;
class KPrim;
extern KPrim g_k_prims;
class DFunctor;

View File

@ -43,7 +43,7 @@ using PrimBpropOptGraphInfoPtr = std::shared_ptr<PrimBpropOptGraphInfo>;
using PrimBpropOptGraphLevel2InfoPtr = std::shared_ptr<PrimBpropOptGraphLevel2Info>;
using PrimBpropCache = mindspore::HashMap<PrimitivePtr, PrimBpropOptGraphInfoPtr, PrimitiveHasher, PrimitiveTotalEqual>;
using PrimBpropCache = std::unordered_map<PrimitivePtr, PrimBpropOptGraphInfoPtr, PrimitiveHasher, PrimitiveTotalEqual>;
using TupleListKey = std::pair<PrimitivePtr, abstract::AbstractBasePtrList>;

View File

@ -18,6 +18,8 @@
#include <memory>
#include <queue>
#include <string>
#include <functional>
#include <utility>
#include "utils/hash_set.h"
#include "ir/func_graph.h"
#include "frontend/parallel/costmodel_context.h"
@ -28,375 +30,53 @@
namespace mindspore {
namespace parallel {
mindspore::HashSet<CNodePtr> FindCNodesWithPara(const AnfNodePtr &para, uint64_t recursive_times = 0) {
if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
MS_LOG(EXCEPTION) << "FindCNodesWithPara exceeds max recursive call times! Max recursive call times is "
<< MAX_RECURSIVE_CALL_TIMES;
}
MS_EXCEPTION_IF_NULL(para);
MS_EXCEPTION_IF_NULL(para->func_graph());
FuncGraphManagerPtr manager = para->func_graph()->manager();
MS_EXCEPTION_IF_NULL(manager);
auto node_set = manager->node_users()[para];
mindspore::HashSet<CNodePtr> cnode_set;
for (auto &node_pair : node_set) {
auto cnode = node_pair.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!IsValueNode<Primitive>(cnode->input(0))) {
continue;
}
auto node_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_EXCEPTION_IF_NULL(node_prim);
if (node_prim->name() == DEPEND && node_pair.second != 1) {
continue;
}
if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
(void)cnode_set.emplace(cnode);
} else {
auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1);
for (auto &cnode_sub : cnode_set_sub) {
(void)cnode_set.emplace(cnode_sub);
}
}
}
return cnode_set;
}
Status AllreduceFusion::AddNodeToGraph() {
const auto &parameters = root_graph_->parameters();
for (auto &parameter : parameters) {
if (!ParameterRequireGrad(parameter)) {
continue;
}
auto cnode_set = FindCNodesWithPara(parameter);
if (cnode_set.empty()) {
continue;
}
for (auto &cnode : cnode_set) {
MS_LOG(DEBUG) << "AddNode " << cnode->DebugString();
if (allreduce_graph_.AddNode(cnode, parameter) != SUCCESS) {
MS_LOG(ERROR) << "AddNode failed! cnode: " << cnode->DebugString();
return FAILED;
}
}
}
return SUCCESS;
}
CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint64_t recursive_times) const {
if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
MS_LOG(EXCEPTION) << "FindCNode exceeds max recursive call times! Max recursive call times is "
<< MAX_RECURSIVE_CALL_TIMES;
}
MS_EXCEPTION_IF_NULL(from);
mindspore::HashMap<CNodePtr, double> cnode_dist;
if (!from->isa<CNode>()) {
return cnode_dist;
}
auto cnode = from->cast<CNodePtr>();
if (!IsValueNode<Primitive>(cnode->input(0))) {
return cnode_dist;
}
auto operator_info = cnode->user_data<OperatorInfo>();
MS_LOG(DEBUG) << "cnode " << cnode->ToString() << " IsParallelCareNode: " << IsParallelCareNode(cnode)
<< " operator_info: " << (operator_info != nullptr);
if (IsParallelCareNode(cnode) && (operator_info != nullptr)) {
auto cost = operator_info->GetForwardMemoryCostFromCNode();
MS_LOG(DEBUG) << "cnode " << cnode->DebugString() << " cost: " << cost;
if (allreduce_graph_.NodeInGraph(cnode)) {
cnode_dist[cnode] = cost;
return cnode_dist;
} else {
auto cnode_dist_next = FindNextCNodes(cnode, recursive_times + 1);
for (auto &ele_next : cnode_dist_next) {
cnode_dist[ele_next.first] = cost + ele_next.second;
}
}
} else {
auto cnode_dist_next = FindNextCNodes(cnode);
for (auto &ele : cnode_dist_next) {
cnode_dist[ele.first] = ele.second;
}
}
return cnode_dist;
}
CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr &from, uint64_t recursive_times) const {
if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
MS_LOG(EXCEPTION) << "FindNextCNodes exceeds max recursive call times! Max recursive call times is "
<< MAX_RECURSIVE_CALL_TIMES;
}
const auto &from_inputs = from->inputs();
mindspore::HashMap<CNodePtr, double> dist_map;
MS_LOG(DEBUG) << "from cnode " << from->DebugString() << " has " << from_inputs.size() << " inputs";
for (auto &input_node : from_inputs) {
auto cnode_dist = FindCNode(input_node, recursive_times + 1);
for (auto &ele : cnode_dist) {
(void)dist_map.emplace(ele);
}
}
return dist_map;
}
Status AllreduceFusion::AddEdgeToGraph() {
mindspore::HashMap<CNodePtr, int64_t> cnode_state_map;
const auto &cnodes = allreduce_graph_.cnode_set();
for (auto &cnode : cnodes) {
cnode_state_map[cnode] = 0;
}
const auto &head_cnode = allreduce_graph_.head_cnode();
std::queue<CNodePtr> cnode_queue;
cnode_queue.emplace(head_cnode);
cnode_state_map[head_cnode] = 1;
while (!cnode_queue.empty()) {
const auto cur_cnode = cnode_queue.front();
cnode_queue.pop();
cnode_state_map[cur_cnode] = 2;
auto next = FindNextCNodes(cur_cnode);
for (auto &ele : next) {
auto &cnode = ele.first;
auto &dist = ele.second;
if (cnode_state_map[cnode] == 0) {
cnode_queue.emplace(cnode);
cnode_state_map[cnode] = 1;
}
if (allreduce_graph_.AddEdge(cur_cnode, cnode, dist) != SUCCESS) {
MS_LOG(ERROR) << "AddEdge error";
return FAILED;
}
MS_LOG(DEBUG) << "from " << cur_cnode->DebugString() << ", to " << cnode->DebugString() << " dist " << dist;
}
}
return SUCCESS;
}
std::vector<CNodePtr> FindMirror(const AnfNodePtr &para, uint64_t recursive_times = 0) {
if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
MS_LOG(EXCEPTION) << "FindMirror exceeds max recursive call times! Max recursive call times is "
<< MAX_RECURSIVE_CALL_TIMES;
}
MS_EXCEPTION_IF_NULL(para);
MS_EXCEPTION_IF_NULL(para->func_graph());
FuncGraphManagerPtr manager = para->func_graph()->manager();
MS_EXCEPTION_IF_NULL(manager);
AnfNodeIndexSet node_set = manager->node_users()[para];
std::vector<CNodePtr> cnode_list;
for (auto &node_pair : node_set) {
auto cnode = node_pair.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!IsValueNode<Primitive>(cnode->input(0))) {
continue;
}
auto node_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_EXCEPTION_IF_NULL(node_prim);
if (node_prim->name() == CAST) {
auto mirror_cnodes = FindMirror(node_pair.first, recursive_times + 1);
if (mirror_cnodes.empty()) {
MS_LOG(WARNING) << "mirror node after cast not found";
continue;
}
if (mirror_cnodes.size() > 1) {
MS_LOG(EXCEPTION) << "mirror node after cast number is not 1";
}
cnode_list.emplace_back(mirror_cnodes[0]);
}
if (node_prim->name() == MIRROR_OPERATOR) {
cnode_list.emplace_back(cnode);
}
}
return cnode_list;
}
void SetMirrorFusion(const CNodePtr &mirror_cnode, int64_t fusion, const std::string &parameter_name) {
MS_EXCEPTION_IF_NULL(mirror_cnode);
MS_LOG(DEBUG) << "Set Mirror " << mirror_cnode->DebugString() << " fusion " << fusion;
auto node_prim = GetValueNode<PrimitivePtr>(mirror_cnode->input(0));
auto old_value_ptr = node_prim->GetAttr(FUSION);
if (old_value_ptr != nullptr) {
if (old_value_ptr->isa<Int64Imm>()) {
int64_t old_value = old_value_ptr->cast<Int64ImmPtr>()->value();
if (old_value < fusion) {
return;
}
}
}
(void)node_prim->AddAttr(FUSION, MakeValue(std::make_shared<Int64Imm>(fusion)));
(void)node_prim->AddAttr(PARAMETER, MakeValue(std::make_shared<StringImm>(parameter_name)));
}
Status FindMirrorAndSetFusion(const AnfNodePtr &para, int64_t fusion) {
auto mirror_cnodes = FindMirror(para);
if (mirror_cnodes.empty()) {
MS_LOG(WARNING) << para->ToString() << " 0 Mirror CNode found.";
return SUCCESS;
}
if (mirror_cnodes.size() > 2) {
for (auto &mirror_cnode_1 : mirror_cnodes) {
MS_EXCEPTION_IF_NULL(mirror_cnode_1);
MS_LOG(INFO) << mirror_cnode_1->DebugString();
}
MS_EXCEPTION_IF_NULL(para);
MS_LOG(ERROR) << para->ToString() << " FindMirror is more than 2. " << mirror_cnodes.size()
<< "Mirror CNode found.";
return FAILED;
}
for (auto &mirror_cnode : mirror_cnodes) {
auto parameter_name = ParameterName(para);
SetMirrorFusion(mirror_cnode, fusion, parameter_name);
}
return SUCCESS;
}
Status FindMirrorAndSetFusion(const std::vector<AnfNodePtr> &paras, int64_t fusion) {
for (auto &param_node : paras) {
if (FindMirrorAndSetFusion(param_node, fusion) != SUCCESS) {
MS_LOG(ERROR) << "FindMirrorAndSetFusion failed";
return FAILED;
}
}
return SUCCESS;
}
Status AllreduceFusion::SetFusion(const std::vector<double> &cost_map) {
if (cost_map.size() < 2) {
MS_LOG(ERROR) << "cost_map must has at least 2 items, cost_map size is " << cost_map.size();
return FAILED;
}
Status AllreduceFusion::SetFusionBySize(const CNodePtr &ret, int64_t threshold) {
auto filter = [](const AnfNodePtr &node) { return !IsPrimitiveCNode(node, prim::kPrimMirror); };
auto todo = DeepScopedGraphSearchWithFilter(ret, AlwaysInclude, filter);
auto temp = threshold;
int64_t fusion = 1;
for (auto cost_iter = cost_map.end() - 1; cost_iter != cost_map.begin(); --cost_iter) {
auto paras = allreduce_graph_.GetParaByCost(*(cost_iter - 1), *cost_iter);
if (FindMirrorAndSetFusion(paras, fusion) != SUCCESS) {
MS_LOG(ERROR) << "FindMirrorAndSetFusion failed";
return FAILED;
bool init = true;
for (auto &node : todo) {
auto cnode = node->cast<CNodePtr>();
if (cnode->input(1)->Shape() == nullptr) continue;
auto input_shapes = GetNodeShape(cnode->input(1));
int64_t input_size = std::accumulate(input_shapes[0].begin(), input_shapes[0].end(), 1, std::multiplies<int64_t>());
FuncGraphPtr func_graph = cnode->func_graph();
std::pair<AnfNodePtr, bool> param_node_pair = FindParameter(cnode->input(1), func_graph);
if (!param_node_pair.first) {
continue;
}
fusion++;
}
return SUCCESS;
}
std::vector<double> AllreduceFusion::GenerateCostMap(int64_t fusion_times, double tail_percent) const {
double offset = allreduce_graph_.max() * (1 - tail_percent) / (fusion_times - 1);
MS_LOG(DEBUG) << "max = " << allreduce_graph_.max() << ", offset = " << offset;
std::vector<double> cost_map;
double begin = 0;
for (auto i = 0; i < fusion_times - 1; i++) {
cost_map.push_back(begin);
begin += offset;
}
cost_map.push_back(allreduce_graph_.max() * (1 - tail_percent));
cost_map.push_back(allreduce_graph_.max());
MS_LOG(DEBUG) << "cost_map = " << cost_map;
return cost_map;
}
Status AllreduceFusion::SetFusionByBackwardCompTime() {
auto fusion_times = CostModelContext::GetInstance()->costmodel_allreduce_fusion_times();
if (fusion_times < 2) {
MS_LOG(INFO) << "'costmodel_allreduce_fusion_times' is " << fusion_times << ". Bypass ProcessAllreduceFusion";
return SUCCESS;
}
auto tail_percent = CostModelContext::GetInstance()->costmodel_allreduce_fusion_tail_percent();
if (tail_percent < 0 || tail_percent >= 1) {
MS_LOG(INFO) << "'costmodel_allreduce_fusion_tail_percent' is " << tail_percent
<< ". Bypass ProcessAllreduceFusion";
return SUCCESS;
}
const auto cost_map = GenerateCostMap(fusion_times, tail_percent);
MS_LOG(DEBUG) << "AllreduceGraph GenerateCostMap succeed.";
if (SetFusion(cost_map) != SUCCESS) {
MS_LOG(ERROR) << "SetFusion failed.";
return FAILED;
}
MS_LOG(DEBUG) << "AllreduceGraph SetFusion succeed.";
return SUCCESS;
}
Status AllreduceFusion::GetSetFusionByBackwardCompAndAllreduceTimeParams() {
tail_time_ = CostModelContext::GetInstance()->costmodel_allreduce_fusion_tail_time();
if (tail_time_ <= 0) {
MS_LOG(INFO) << "'costmodel_allreduce_tail_time' is " << tail_time_ << ". Bypass ProcessAllreduceFusion";
return FAILED;
}
allreduce_inherent_time_ = CostModelContext::GetInstance()->costmodel_allreduce_fusion_allreduce_inherent_time();
if (allreduce_inherent_time_ <= 0) {
MS_LOG(INFO) << "'costmodel_allreduce_fusion_allreduce_inherent_time' is " << allreduce_inherent_time_
<< ". Bypass ProcessAllreduceFusion";
return FAILED;
}
if (tail_time_ <= allreduce_inherent_time_) {
MS_LOG(INFO) << "'costmodel_allreduce_tail_time' is " << tail_time_
<< "'costmodel_allreduce_fusion_allreduce_inherent_time' is " << allreduce_inherent_time_
<< ".tail_time is not more than allreduce_inherent_time. Bypass ProcessAllreduceFusion";
return FAILED;
}
allreduce_bandwidth_ = CostModelContext::GetInstance()->costmodel_allreduce_fusion_allreduce_bandwidth();
if (allreduce_bandwidth_ <= 0) {
MS_LOG(INFO) << "'costmodel_allreduce_fusion_allreduce_bandwidth' is " << allreduce_bandwidth_
<< ". Bypass ProcessAllreduceFusion";
return FAILED;
}
computation_time_parameter_ =
CostModelContext::GetInstance()->costmodel_allreduce_fusion_computation_time_parameter();
if (computation_time_parameter_ <= 0) {
MS_LOG(INFO) << "'costmodel_allreduce_fusion_computation_time_parameter' is " << computation_time_parameter_
<< ". Bypass ProcessAllreduceFusion";
return FAILED;
}
return SUCCESS;
}
Status AllreduceFusion::SetFusionByBackwardCompAndAllreduceTime() {
if (GetSetFusionByBackwardCompAndAllreduceTimeParams() != SUCCESS) {
MS_LOG(ERROR) << "GetSetFusionByBackwardCompAndAllreduceTimeParams failed!";
return FAILED;
}
allreduce_graph_.SortArnode();
if (allreduce_graph_.RemoveExtraParas() != SUCCESS) {
MS_LOG(ERROR) << "RemoveExtraParas failed!";
return FAILED;
}
double para_size = (tail_time_ - allreduce_inherent_time_) / allreduce_bandwidth_;
double to_cost = allreduce_graph_.max();
int64_t fusion = 1;
while (to_cost != 0) {
MS_LOG(INFO) << "to_cost: " << to_cost << " para_size: " << para_size;
auto node_cost_pair = allreduce_graph_.GetParaByParaSize(to_cost, para_size);
MS_LOG(INFO) << "para size: " << node_cost_pair.first.size() << " from_cost: " << node_cost_pair.second;
auto paras = node_cost_pair.first;
if (FindMirrorAndSetFusion(paras, fusion) != SUCCESS) {
MS_LOG(ERROR) << "FindMirrorAndSetFusion failed";
return FAILED;
auto parameter_name = ParameterName(param_node_pair.first);
if (input_size < temp) {
temp -= input_size;
} else {
temp = threshold;
fusion++;
}
fusion++;
para_size = ((to_cost - node_cost_pair.second) * computation_time_parameter_ - allreduce_inherent_time_) /
allreduce_bandwidth_;
to_cost = node_cost_pair.second;
if (init) {
SetMirrorFusion(cnode, 1, parameter_name);
} else {
SetMirrorFusion(cnode, fusion, parameter_name);
}
init = false;
}
MS_LOG(DEBUG) << "AllreduceGraph SetFusionByBackwardCompAndAllreduceTime succeed.";
MS_LOG(INFO) << "Allreduce fusion by size succeed.";
return SUCCESS;
}
Status AllreduceFusion::SetFusionByAlgorithm(int64_t algorithm) {
if (algorithm == 1) {
return SetFusionByBackwardCompTime();
}
return SetFusionByBackwardCompAndAllreduceTime();
}
Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) {
if (ret == nullptr) {
MS_LOG(ERROR) << "ret is nullptr.";
return FAILED;
}
auto algorithm = CostModelContext::GetInstance()->costmodel_allreduce_fusion_algorithm();
if (algorithm < 1 || algorithm > 2) {
MS_LOG(INFO) << "'costmodel_allreduce_fusion_algorithm' is " << algorithm << ". Bypass ProcessAllreduceFusion";
return SUCCESS;
}
ret_ = ret;
root_graph_ = ret_->func_graph();
MS_EXCEPTION_IF_NULL(root_graph_);
@ -409,27 +89,20 @@ Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) {
MS_EXCEPTION_IF_NULL(forward_graph);
forward_ret_ = forward_graph->get_return();
MS_EXCEPTION_IF_NULL(forward_ret_);
if (allreduce_graph_.set_head_cnode(forward_ret_) != SUCCESS) {
MS_LOG(ERROR) << "AllreduceGraph set_head_cnode failed.";
return FAILED;
}
MS_LOG(DEBUG) << "AllreduceGraph set_head_cnode succeed.";
if (AddNodeToGraph() != SUCCESS) {
MS_LOG(ERROR) << "AddNodeToGraph failed.";
auto threshold = ParallelContext::GetInstance()->fusion_threshold_mb() * 1024 * 1024 / 4.0;
if (threshold > 0) {
if (SetFusionBySize(ret, threshold) != SUCCESS) {
MS_LOG(ERROR) << "SetFusionBySize failed.";
return FAILED;
}
} else {
MS_LOG(ERROR) << "The threshold of SetFusionBySize must be larger than 0, but got " << threshold << ".";
return FAILED;
}
MS_LOG(DEBUG) << "AllreduceGraph AddNodeToGraph succeed.";
if (AddEdgeToGraph() != SUCCESS) {
MS_LOG(ERROR) << "AddNodeToGraph failed.";
return FAILED;
}
MS_LOG(DEBUG) << "AllreduceGraph AddEdgeToGraph succeed.";
if (SetFusionByAlgorithm(algorithm) != SUCCESS) {
MS_LOG(ERROR) << "SetFusionByAlgorithm failed.";
return FAILED;
}
MS_LOG(DEBUG) << "AllreduceGraph SetFusionByAlgorithm succeed.";
return SUCCESS;
}
} // namespace parallel
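
The new SetFusionBySize pass above replaces the cost-model-driven fusion: it walks the graph for Mirror operators, computes each gradient's element count, and assigns fusion indices by accumulating those sizes against a threshold (ProcessAllreduceFusion converts fusion_threshold_mb() from MB into a float32 element count). A minimal Python sketch of that bucketing rule follows; the function and parameter names are illustrative, not the C++ implementation or a MindSpore API:

# Conceptual sketch of the size-based bucketing used by SetFusionBySize.
def assign_fusion_by_size(param_sizes, threshold):
    """param_sizes: list of (name, element_count); threshold: elements per fusion bucket."""
    fusion = 1
    remaining = threshold
    result = {}
    for i, (name, size) in enumerate(param_sizes):
        if size < remaining:
            remaining -= size          # current bucket still has room
        else:
            remaining = threshold      # bucket full: start the next fusion group
            fusion += 1
        result[name] = 1 if i == 0 else fusion  # first parameter always gets index 1, as in the C++ code
    return result

# Example with a threshold of 100 elements:
print(assign_fusion_by_size([("fc1.weight", 60), ("fc2.weight", 60), ("fc3.weight", 30)], 100))
# {'fc1.weight': 1, 'fc2.weight': 2, 'fc3.weight': 2}
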

View File

@ -27,8 +27,6 @@
namespace mindspore {
namespace parallel {
using CNodeCostMap = mindspore::HashMap<CNodePtr, double>;
constexpr int64_t DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM = 0;
constexpr int64_t DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES = 0;
constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT = 0.1;
@ -46,31 +44,17 @@ class AllreduceFusion {
forward_ret_(nullptr),
root_graph_(nullptr),
tail_time_(0),
allreduce_inherent_time_(0),
allreduce_bandwidth_(0),
computation_time_parameter_(0) {}
virtual ~AllreduceFusion() = default;
Status ProcessAllreduceFusion(const CNodePtr &ret);
private:
Status AddNodeToGraph();
CNodeCostMap FindCNode(const AnfNodePtr &from, uint64_t recursive_times = 0) const;
CNodeCostMap FindNextCNodes(const CNodePtr &from, uint64_t recursive_times = 0) const;
Status AddEdgeToGraph();
std::vector<double> GenerateCostMap(int64_t fusion_times, double tail_percent) const;
Status SetFusion(const std::vector<double> &cost_map);
Status SetFusionByAlgorithm(int64_t algorithm);
Status SetFusionByBackwardCompTime();
Status SetFusionByBackwardCompAndAllreduceTime();
Status GetSetFusionByBackwardCompAndAllreduceTimeParams();
Status SetFusionBySize(const CNodePtr &ret, int64_t threshold);
AllreduceGraph allreduce_graph_;
CNodePtr ret_;
CNodePtr forward_ret_;
FuncGraphPtr root_graph_;
double tail_time_;
double allreduce_inherent_time_;
double allreduce_bandwidth_;
double computation_time_parameter_;
};
} // namespace parallel

View File

@ -22,6 +22,7 @@
#include "frontend/parallel/context.h"
#include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/status.h"
#include "frontend/parallel/step_parallel.h"
#include "utils/log_adapter.h"
namespace mindspore {
@ -32,13 +33,15 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
bool enable_all_reduce_fusion = ParallelContext::GetInstance()->enable_all_reduce_fusion();
auto graph_set = ForwardGraph(root);
// assume no change to graph
bool changes = false;
// control whether use model_parallel mode
if (!root->has_flag(AUTO_PARALLEL) || ((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) ||
(!enable_all_reduce_fusion) || (root->has_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY))) {
(!enable_all_reduce_fusion) || (root->has_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY)) || graph_set.size() < 1) {
return changes;
}
#if defined(_WIN32) || defined(_WIN64)
auto start_time = std::chrono::steady_clock::now();
#else
@ -47,7 +50,7 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti
}, end_time{0};
(void)gettimeofday(&start_time, nullptr);
#endif
MS_LOG(INFO) << "Now entering allreduce fusion";
MS_LOG(INFO) << "Now entering allreduce fusion by size, and allreduce fusion before will be overlapped!";
DumpGraph(root, std::string(ALLREDUCE_FUSION_BEGIN));
pipeline::ResourceBasePtr res = optimizer->resource();

View File

@ -36,6 +36,8 @@ std::vector<std::string> STRATEGY_SEARCH_MODE_LIST = {DYNAMIC_PROGRAMMING, RECUR
std::vector<std::string> COMMUNI_PARALLEL_MODE_LIST = {ALL_GROUP_PARALLEL, SAME_SERVER_GROUP_PARALLEL,
NO_GROUP_PARALLEL};
std::vector<std::string> FUSION_MODE_LIST = {FUSION_AUTO, FUSION_SIZE, FUSION_INDEX};
std::shared_ptr<ParallelContext> ParallelContext::inst_context_ = nullptr;
std::shared_ptr<ParallelContext> ParallelContext::GetInstance() {
@ -60,7 +62,7 @@ void ParallelContext::Reset() {
parallel_mode_ = STAND_ALONE;
parameter_broadcast_ = false;
parameter_broadcast_is_set_ = false;
enable_all_reduce_fusion_ = false;
enable_all_reduce_fusion_ = true;
strategy_ckpt_load_file_ = "";
strategy_ckpt_save_file_ = "";
enable_parallel_optimizer_ = false;
@ -76,6 +78,9 @@ void ParallelContext::Reset() {
grad_accumulation_shard_ = true;
sharding_propagation_ = false;
dataset_strategy_.clear();
fusion_threshold_mb_ = FUSUION_THRESHOLD;
fusion_threshold_is_set_ = true;
fusion_mode_ = FUSION_AUTO;
}
void ParallelContext::set_device_num(int64_t device_num) {
@ -83,6 +88,22 @@ void ParallelContext::set_device_num(int64_t device_num) {
device_num_is_set_ = true;
}
void ParallelContext::set_fusion_threshold_mb(int64_t fusion_threshold) {
fusion_threshold_mb_ = fusion_threshold;
fusion_threshold_is_set_ = true;
enable_all_reduce_fusion_ = true;
}
bool ParallelContext::set_fusion_mode(const std::string &fusion_mode) {
auto iter = std::find(FUSION_MODE_LIST.begin(), FUSION_MODE_LIST.end(), fusion_mode);
if (iter == FUSION_MODE_LIST.end()) {
MS_LOG(INFO) << "Invalid fusion mode:" << fusion_mode;
return false;
}
fusion_mode_ = fusion_mode;
return true;
}
void ParallelContext::set_global_rank(int64_t global_rank) {
global_rank_ = global_rank;
global_rank_is_set_ = true;

View File

@ -52,6 +52,11 @@ constexpr char SAME_SERVER_GROUP_PARALLEL[] = "same_server_group_parallel";
constexpr char NO_GROUP_PARALLEL[] = "no_group_parallel";
constexpr char IS_FIRST_ITERATION[] = "is_first_iteration";
constexpr char FUSION_AUTO[] = "auto";
constexpr char FUSION_SIZE[] = "size";
constexpr char FUSION_INDEX[] = "index";
constexpr int64_t FUSUION_THRESHOLD = 64;
class ParallelContext {
public:
~ParallelContext() = default;
@ -78,6 +83,11 @@ class ParallelContext {
void set_device_num(int64_t device_num);
int64_t device_num() const { return device_num_; }
void set_fusion_threshold_mb(int64_t fusion_threshold);
int64_t fusion_threshold_mb() const { return fusion_threshold_mb_; }
bool set_fusion_mode(const std::string &fusion_mode);
std::string get_fusion_mode() const { return fusion_mode_; }
void set_pipeline_stage_split_num(const int64_t stages);
int64_t pipeline_stage_split_num() const { return pipeline_stage_split_num_; }
@ -159,6 +169,7 @@ class ParallelContext {
bool gradient_fp32_sync_;
bool loss_repeated_mean_;
int64_t device_num_;
int64_t fusion_threshold_mb_;
int64_t global_rank_;
int64_t grad_accumulation_step_;
std::string parallel_mode_;
@ -166,6 +177,7 @@ class ParallelContext {
int64_t pipeline_stage_split_num_;
bool parameter_broadcast_;
bool device_num_is_set_;
bool fusion_threshold_is_set_;
bool global_rank_is_set_;
bool parameter_broadcast_is_set_;
bool enable_all_reduce_fusion_;
@ -186,6 +198,7 @@ class ParallelContext {
bool dataset_repeat_dim_right_ = false;
bool hccl_test_available_ = false;
bool sharding_propagation_;
std::string fusion_mode_;
};
} // namespace parallel

View File

@ -148,6 +148,10 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_hccl_test_avaible", &ParallelContext::set_hccl_test_available, "Set hccl test available.")
.def("set_device_num", &ParallelContext::set_device_num, "Set device num.")
.def("get_device_num_is_set", &ParallelContext::device_num_is_set, "Get device num is set.")
.def("set_fusion_threshold_mb", &ParallelContext::set_fusion_threshold_mb, "Set fusion threshold.")
.def("fusion_threshold_mb", &ParallelContext::fusion_threshold_mb, "Get fusion threshold.")
.def("set_fusion_mode", &ParallelContext::set_fusion_mode, "Get fusion mode.")
.def("get_fusion_mode", &ParallelContext::get_fusion_mode, "Get fusion mode.")
.def("get_global_rank", &ParallelContext::global_rank, "Get global rank.")
.def("set_global_rank", &ParallelContext::set_global_rank, "Set global rank.")
.def("get_grad_accumulation_shard", &ParallelContext::grad_accumulation_shard, "Get grad_accumulation_shard.")

View File

@ -373,7 +373,7 @@ def _context():
auto_parallel_search_mode=str, search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
all_reduce_fusion_config=list, pipeline_stages=int, grad_accumulation_step=int,
parallel_optimizer_config=dict)
parallel_optimizer_config=dict, comm_fusion=dict)
def set_auto_parallel_context(**kwargs):
r"""
Set auto parallel context, which is valid only for Ascend and GPU target.
@ -402,6 +402,7 @@ def set_auto_parallel_context(**kwargs):
parallel_optimizer_config pipeline_stages
\ grad_accumulation_step
\ auto_parallel_search_mode
\ comm_fusion
=========================== ===========================
Args:
@ -476,6 +477,16 @@ def set_auto_parallel_context(**kwargs):
with larger batch size. This configure is effective only
when the model runs on pipeline training or gradient
accumulation with data parallel.
comm_fusion (dict): A dict that contains the types and configurations for setting the communication fusion. Each
communication fusion config has two keys: "mode" and "config".
It supports the following communication fusion types and configurations:
- allreduce: if the communication fusion type is `allreduce`, `mode` can be `auto`, `size`
or `index`. In `auto` mode, allreduce fusion is configured by gradient size, and the default
fusion threshold is `64` MB. In `size` mode, allreduce fusion is configured by gradient size
manually, and the fusion threshold must be larger than `0` MB. In `index` mode, it is the same as
`all_reduce_fusion_config`.
Raises:
ValueError: If input key is not attribute in auto parallel context.
@ -498,6 +509,8 @@ def set_auto_parallel_context(**kwargs):
>>> context.set_auto_parallel_context(pipeline_stages=2)
>>> parallel_config = {"gradient_accumulation_shard": True}
>>> context.set_auto_parallel_context(parallel_optimizer_config=parallel_config, enable_parallel_optimizer=True)
>>> comm_fusion_config = {"allreduce": {"mode": "size", "config": 32}}
>>> context.set_auto_parallel_context(comm_fusion=comm_fusion_config)
"""
_set_auto_parallel_context(**kwargs)
@ -540,6 +553,7 @@ def reset_auto_parallel_context():
- full_batch: False.
- enable_parallel_optimizer: False.
- pipeline_stages: 1.
- fusion_threshold: 64.
"""
_reset_auto_parallel_context()
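
For reference, a hedged usage sketch of the new comm_fusion option at the context API level; the threshold and indices below are arbitrary examples drawn from this PR's tests, not defaults:

# Illustrative usage of the new comm_fusion option.
from mindspore import context

# auto mode: fuse allreduce by gradient size with the built-in 64 MB threshold.
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                  comm_fusion={"allreduce": {"mode": "auto", "config": None}})

# size mode: fuse allreduce by gradient size with a user-chosen threshold in MB.
context.set_auto_parallel_context(comm_fusion={"allreduce": {"mode": "size", "config": 32}})

# index mode: behaves like the existing all_reduce_fusion_config.
context.set_auto_parallel_context(comm_fusion={"allreduce": {"mode": "index", "config": [2, 4, 6, 8]}})
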

View File

@ -26,6 +26,16 @@ _MAX_GROUP_NAME_LEN = 127
_DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1"
_DEFAULT_NCCL_FUSION_GROUP_NAME = "nccl_world_groupsum1"
class _ParallelFusionConfig:
"""
The keys of the parallel fusion method configuration.
"""
ALLREDUCE = "allreduce"
MODE = "mode"
FUSION_CONFIG = "config"
AUTO = "auto"
INDEX = "index"
SIZE = "size"
class _ParallelOptimizerConfig:
"""
@ -89,6 +99,86 @@ class _AutoParallelContext:
self.check_context_handle()
return self._context_handle.get_device_num()
def set_comm_fusion(self, config):
"""
Set fusion method for auto parallel.
Args:
config (dict): A dict that contains the fusion types and their configurations for setting the communication fusion.
Currently only `allreduce` is supported.
Raises:
KeyError: When a key of comm_fusion is not 'allreduce'.
"""
self.check_context_handle()
for key in list(config.keys()):
if key == _ParallelFusionConfig.ALLREDUCE:
self._set_allreduce_comm_fusion(config[key])
else:
raise KeyError("comm fusion type must be allreduce, but got {}".format(key))
def _set_allreduce_comm_fusion(self, comm_fusion):
"""
Set fusion method for auto parallel.
Args:
comm_fusion (dict): A dict that contains the mode and config for setting the allreduce fusion. Currently it
supports three fusion modes: `auto`, `size` and `index`.
Raises:
KeyError: When key of comm_fusion is not 'mode' or 'config'.
KeyError: When `mode` is not 'auto', 'size' or 'index'.
"""
self.check_context_handle()
if _ParallelFusionConfig.MODE not in comm_fusion:
raise KeyError("For 'comm_fusion', the key 'mode' should be contained.")
if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion:
raise KeyError("For 'comm_fusion', the key 'config' should be contained.")
check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.INDEX, _ParallelFusionConfig.SIZE]
if comm_fusion[_ParallelFusionConfig.MODE] in check_mode:
self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE])
else:
raise KeyError("fusion method mode must be auto, index or size, but got {}".format(
comm_fusion[_ParallelFusionConfig.MODE]))
if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.AUTO:
self.set_fusion_threshold_mb(fusion_threshold=64)
if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.SIZE:
self.set_fusion_threshold_mb(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG])
if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.INDEX:
self.set_all_reduce_fusion_split_indices(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG])
def get_comm_fusion(self):
"""Get comm fusion config."""
self.check_context_handle()
mode = self._context_handle.get_fusion_mode()
if mode in (_ParallelFusionConfig.AUTO, _ParallelFusionConfig.SIZE):
config = self.fusion_threshold_mb()
if mode == _ParallelFusionConfig.INDEX:
config = self.get_all_reduce_fusion_split_indices()
return {_ParallelFusionConfig.ALLREDUCE: {_ParallelFusionConfig.MODE: mode,
_ParallelFusionConfig.FUSION_CONFIG: config}}
def set_fusion_threshold_mb(self, fusion_threshold=64):
"""
Set fusion threshold (MB) for auto parallel.
Args:
fusion_threshold (int): The fusion threshold (unit: MB). Default: 64.
Raises:
ValueError: If the fusion threshold is not in [0, +inf].
"""
self.check_context_handle()
if fusion_threshold < 0:
raise ValueError("fusion threshold must be larger than 0, but got {}".format(fusion_threshold))
self._context_handle.set_fusion_threshold_mb(fusion_threshold)
def fusion_threshold_mb(self):
"""Get device num."""
self.check_context_handle()
return self._context_handle.fusion_threshold_mb()
def set_global_rank(self, global_rank):
"""
Set global rank for auto parallel.
@ -741,7 +831,8 @@ _set_auto_parallel_context_func_map = {
"optimizer_weight_shard_size": auto_parallel_context().set_optimizer_weight_shard_size,
"optimizer_weight_shard_aggregated_save": auto_parallel_context().set_optimizer_weight_shard_aggregated_save,
"sharding_propagation": auto_parallel_context().set_sharding_propagation,
"enable_alltoall": auto_parallel_context().set_enable_alltoall}
"enable_alltoall": auto_parallel_context().set_enable_alltoall,
"comm_fusion": auto_parallel_context().set_comm_fusion}
_get_auto_parallel_context_func_map = {
@ -766,7 +857,8 @@ _get_auto_parallel_context_func_map = {
"optimizer_weight_shard_size": auto_parallel_context().get_optimizer_weight_shard_size,
"optimizer_weight_shard_aggregated_save": auto_parallel_context().get_optimizer_weight_shard_aggregated_save,
"sharding_propagation": auto_parallel_context().get_sharding_propagation,
"enable_alltoall": auto_parallel_context().get_enable_alltoall}
"enable_alltoall": auto_parallel_context().get_enable_alltoall,
"comm_fusion": auto_parallel_context().get_comm_fusion}
@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
@ -775,7 +867,7 @@ _get_auto_parallel_context_func_map = {
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
communi_parallel_mode=str, optimizer_weight_shard_size=int, sharding_propagation=bool,
optimizer_weight_shard_aggregated_save=bool, enable_alltoall=bool)
optimizer_weight_shard_aggregated_save=bool, enable_alltoall=bool, comm_fusion=dict)
def _set_auto_parallel_context(**kwargs):
"""
@ -848,6 +940,16 @@ def _set_auto_parallel_context(**kwargs):
search the desired strategies. Default: False.
enable_alltoall (bool): Set the value of enabling AllToAll. If False, AllGather and Split are used to
circumvent AllToAll. Default: False.
comm_fusion (dict): A dict that contains the types and configurations for setting the communication fusion. Each
communication fusion config has two keys: "mode" and "config".
It supports the following communication fusion types and configurations:
- allreduce: if the communication fusion type is `allreduce`, `mode` can be `auto`, `size`
or `index`. In `auto` mode, allreduce fusion is configured by gradient size, and the default
fusion threshold is `64` MB. In `size` mode, allreduce fusion is configured by gradient size
manually, and the fusion threshold must be larger than `0` MB. In `index` mode, it is the same as
`all_reduce_fusion_config`.
Raises:
ValueError: If input key is not attribute in auto parallel context.
@ -896,5 +998,6 @@ def _reset_auto_parallel_context():
- sharding_propagation: False
- pipeline_stages: 0
- gradient_accumulation_shard: True
- fusion_threshold: 64
"""
auto_parallel_context().reset()
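
A short sketch of the new _AutoParallelContext helpers defined above, including the KeyError path for an unsupported fusion type; illustrative only:

# Illustrative use of set_comm_fusion / get_comm_fusion.
from mindspore.parallel._auto_parallel_context import auto_parallel_context

ctx = auto_parallel_context()
ctx.set_comm_fusion({"allreduce": {"mode": "index", "config": [2, 4, 6, 8]}})
print(ctx.get_comm_fusion())
# {'allreduce': {'mode': 'index', 'config': [2, 4, 6, 8]}}

try:
    ctx.set_comm_fusion({"allgather": {"mode": "auto", "config": None}})  # unsupported type
except KeyError as err:
    print(err)  # only the 'allreduce' key is accepted in this version
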

View File

@ -13,20 +13,46 @@
# limitations under the License.
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim import Lamb
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.parallel import _cost_model_context as cost_model_context
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData
context.set_context(mode=context.PYNATIVE_MODE, save_graphs=True)
class Net(nn.Cell):
"""Net definition"""
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Dense(128, 768, activation='relu')
self.fc2 = nn.Dense(128, 768, activation='relu')
self.fc3 = nn.Dense(128, 768, activation='relu')
self.fc4 = nn.Dense(768, 768, activation='relu')
self.relu4 = nn.ReLU()
self.relu5 = nn.ReLU()
self.transpose = P.Transpose()
self.matmul1 = P.MatMul()
self.matmul2 = P.MatMul()
def construct(self, x):
q = self.fc1(x)
k = self.fc2(x)
v = self.fc3(x)
k = self.transpose(k, (1, 0))
c = self.relu4(self.matmul1(q, k))
s = self.relu5(self.matmul2(c, v))
s = self.fc4(s)
return s
class Dataset(MindData):
def __init__(self, predict, label, length=3):
@ -107,7 +133,6 @@ def train_common(net):
momentum = 0.9
epoch_size = 2
device_num = 4
auto_parallel_context().set_enable_all_reduce_fusion(enable_all_reduce_fusion=True)
context.set_auto_parallel_context(device_num=device_num, parameter_broadcast=False)
context.set_context(mode=context.GRAPH_MODE)
@ -121,196 +146,78 @@ def train_common(net):
model.train(epoch_size, dataset, dataset_sink_mode=False)
allreduce_fusion_dict = _cell_graph_executor._get_allreduce_fusion(model._train_network)
print(allreduce_fusion_dict)
return allreduce_fusion_dict
@pytest.mark.skip(reason="depreciated feature")
def test_allreduce_fusion_parameters():
cost_model_context.reset_cost_model_context()
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2)
algorithm = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_algorithm')
assert algorithm == 2
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
algorithm = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_algorithm')
assert algorithm == 1
cost_model_context.reset_cost_model_context()
algorithm = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_algorithm')
assert algorithm == 0
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
fusion_times = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_times')
assert fusion_times == 2
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.2)
tail_percent = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_tail_percent')
assert tail_percent == 0.2
cost_model_context.reset_cost_model_context()
tail_percent = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_tail_percent')
assert tail_percent == 0.1
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_time=0.2)
tail_time = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_tail_time')
assert tail_time == 0.2
cost_model_context.reset_cost_model_context()
tail_time = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_tail_time')
assert tail_time == 0.1
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_inherent_time=0.2)
allreduce_inherent_time = cost_model_context.get_cost_model_context(
'costmodel_allreduce_fusion_allreduce_inherent_time')
assert allreduce_inherent_time == 0.2
cost_model_context.reset_cost_model_context()
allreduce_inherent_time = cost_model_context.get_cost_model_context(
'costmodel_allreduce_fusion_allreduce_inherent_time')
assert allreduce_inherent_time == 0.1
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_bandwidth=0.2)
allreduce_bandwidth = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_allreduce_bandwidth')
assert allreduce_bandwidth == 0.2
cost_model_context.reset_cost_model_context()
allreduce_bandwidth = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_allreduce_bandwidth')
assert allreduce_bandwidth == 0.1
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_computation_time_parameter=0.2)
computation_time_parameter = cost_model_context.get_cost_model_context(
'costmodel_allreduce_fusion_computation_time_parameter')
assert computation_time_parameter == 0.2
cost_model_context.reset_cost_model_context()
computation_time_parameter = cost_model_context.get_cost_model_context(
'costmodel_allreduce_fusion_computation_time_parameter')
assert computation_time_parameter == 0.1
@pytest.mark.skip(reason="depreciated feature")
def test_allreduce_fusion1():
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5)
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
def test_allreduce_fusion_auto():
"""
Feature: test_allreduce_fusion in auto mode
Description: allreduce fusion in auto mode
Expectation: success
"""
comm_fusion_dict = {"allreduce": {"mode": "auto", "config": None}}
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, comm_fusion=comm_fusion_dict)
net = SimpleDMLNet(DenseNet1(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
allreduce_fusion_dict = train_common(net)
expect_dict = {'backbone2.fc8.weight': 2,
'backbone2.fc7.weight': 2,
'backbone2.fc6.weight': 2,
'backbone1.fc4.weight': 2,
'backbone1.fc3.weight': 2,
'backbone1.fc2.weight': 2,
'backbone2.fc5.weight': 1,
'backbone2.fc4.weight': 1,
'backbone2.fc3.weight': 1,
'backbone2.fc2.weight': 1,
'backbone2.fc1.weight': 1,
'backbone1.fc1.weight': 1}
assert allreduce_fusion_dict == expect_dict
cost_model_context.reset_cost_model_context()
@pytest.mark.skip(reason="depreciated feature")
# Since reset_cost_model_context is called, the default value of costmodel_allreduce_fusion_times is 0, so
# step_allreduce_fusion is bypassed.
def test_allreduce_fusion2():
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5)
cost_model_context.reset_cost_model_context()
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
net = SimpleDMLNet(DenseNet1(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
allreduce_fusion_dict = train_common(net)
expect_dict = {}
assert allreduce_fusion_dict == expect_dict
cost_model_context.reset_cost_model_context()
@pytest.mark.skip(reason="depreciated feature")
def test_allreduce_fusion3():
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=3)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.3333333)
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
net = SimpleDMLNet(DenseNet1(has_bias=True, activation='relu'), DenseNet2(has_bias=False, activation='relu'))
allreduce_fusion_dict = train_common(net)
expect_dict = {'backbone2.fc8.weight': 3,
'backbone2.fc7.weight': 3,
'backbone2.fc6.weight': 2,
'backbone2.fc5.weight': 2,
'backbone2.fc4.weight': 2,
'backbone2.fc3.weight': 1,
'backbone2.fc2.weight': 1,
'backbone2.fc1.weight': 1,
'backbone1.fc4.bias': 3,
'backbone1.fc4.weight': 3,
'backbone1.fc3.bias': 3,
'backbone1.fc3.weight': 2,
'backbone1.fc2.bias': 2,
'backbone1.fc2.weight': 2,
'backbone1.fc1.bias': 2,
'backbone1.fc1.weight': 2}
assert allreduce_fusion_dict == expect_dict
cost_model_context.reset_cost_model_context()
@pytest.mark.skip(reason="depreciated feature")
def test_allreduce_fusion4():
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5)
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
net = SimpleDMLNet(DenseNet2(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
allreduce_fusion_dict = train_common(net)
expect_dict = {'backbone2.fc8.weight': 2,
'backbone2.fc7.weight': 2,
'backbone2.fc6.weight': 2,
'backbone1.fc8.weight': 2,
'backbone1.fc7.weight': 2,
'backbone1.fc6.weight': 2,
'backbone2.fc5.weight': 1,
'backbone2.fc4.weight': 1,
'backbone2.fc3.weight': 1,
'backbone2.fc2.weight': 1,
'backbone2.fc1.weight': 1,
'backbone1.fc5.weight': 1,
expect_dict = {'backbone2.fc8.weight': 1,
'backbone2.fc7.weight': 1,
'backbone2.fc6.weight': 1,
'backbone1.fc4.weight': 1,
'backbone1.fc3.weight': 1,
'backbone1.fc2.weight': 1,
'backbone1.fc1.weight': 1}
assert allreduce_fusion_dict == expect_dict
cost_model_context.reset_cost_model_context()
@pytest.mark.skip(reason="depreciated feature")
def test_allreduce_fusion5():
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_time=0.1)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_inherent_time=0.05)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_bandwidth=0.000001)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_computation_time_parameter=0.0000015)
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
net = SimpleDMLNet(DenseNet2(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
allreduce_fusion_dict = train_common(net)
expect_dict = {'backbone2.fc8.weight': 3,
'backbone2.fc7.weight': 3,
'backbone2.fc6.weight': 3,
'backbone2.fc5.weight': 3,
'backbone2.fc4.weight': 2,
'backbone2.fc3.weight': 2,
'backbone2.fc5.weight': 1,
'backbone2.fc4.weight': 1,
'backbone2.fc3.weight': 1,
'backbone2.fc2.weight': 1,
'backbone2.fc1.weight': 1,
'backbone1.fc8.weight': 3,
'backbone1.fc7.weight': 3,
'backbone1.fc6.weight': 3,
'backbone1.fc5.weight': 3,
'backbone1.fc4.weight': 2,
'backbone1.fc3.weight': 2,
'backbone1.fc2.weight': 1,
'backbone1.fc1.weight': 1,}
'backbone1.fc1.weight': 1}
assert allreduce_fusion_dict == expect_dict
cost_model_context.reset_cost_model_context()
def test_allreduce_fusion_size():
"""
Feature: test_allreduce_fusion in size mode
Description: allreduce fusion in size mode
Expectation: success
"""
comm_fusion_dict = {"allreduce": {"mode": "size", "config": 32}}
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, comm_fusion=comm_fusion_dict)
net = SimpleDMLNet(DenseNet1(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
allreduce_fusion_dict = train_common(net)
expect_dict = {'backbone2.fc8.weight': 1,
'backbone2.fc7.weight': 1,
'backbone2.fc6.weight': 1,
'backbone1.fc4.weight': 1,
'backbone1.fc3.weight': 1,
'backbone1.fc2.weight': 1,
'backbone2.fc5.weight': 1,
'backbone2.fc4.weight': 1,
'backbone2.fc3.weight': 1,
'backbone2.fc2.weight': 1,
'backbone2.fc1.weight': 1,
'backbone1.fc1.weight': 1}
assert allreduce_fusion_dict == expect_dict
cost_model_context.reset_cost_model_context()
comm_fusion = auto_parallel_context().get_comm_fusion()
assert comm_fusion_dict == comm_fusion
def test_lamb_split_fusion_in_index():
"""
Feature: test_allreduce_fusion in index mode
Description: allreduce fusion in index mode
Expectation: success
"""
comm_fusion_dict = {"allreduce": {"mode": "index", "config": [2, 4, 6, 8]}}
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True,
comm_fusion=comm_fusion_dict)
inputs = Tensor(np.ones([32, 128]).astype(np.float32))
label = Tensor(np.zeros([32, 768]).astype(np.float32))
net = Net()
net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = Lamb(net.trainable_params(), learning_rate=0.1)
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_cell_graph_executor.compile(train_network, inputs, label)
context.reset_auto_parallel_context()