forked from mindspore-Ecosystem/mindspore
!30431 allreduce allgather fusion
Merge pull request !30431 from jiahongQian/master
This commit is contained in:
commit
14393503b7
|
@ -39,41 +39,80 @@ void SetMirrorFusion(const CNodePtr &mirror_cnode, int64_t fusion, const std::st
|
||||||
(void)node_prim->AddAttr(PARAMETER, MakeValue(std::make_shared<StringImm>(parameter_name)));
|
(void)node_prim->AddAttr(PARAMETER, MakeValue(std::make_shared<StringImm>(parameter_name)));
|
||||||
}
|
}
|
||||||
|
|
||||||
Status AllreduceFusion::SetFusionBySize(const CNodePtr &ret, int64_t threshold) {
|
Status AllCommFusion::SetFusionBySize(const CNodePtr &ret, int64_t threshold, const PrimitivePtr &primp) {
|
||||||
auto filter = [](const AnfNodePtr &node) { return !IsPrimitiveCNode(node, prim::kPrimMirror); };
|
auto filter = [primp](const AnfNodePtr &node) { return !IsPrimitiveCNode(node, primp); };
|
||||||
auto todo = DeepScopedGraphSearchWithFilter(ret, AlwaysInclude, filter);
|
auto todo = DeepScopedGraphSearchWithFilter(ret, AlwaysInclude, filter);
|
||||||
auto temp = threshold;
|
auto temp = threshold;
|
||||||
int64_t fusion = 1;
|
int64_t fusion = 1;
|
||||||
bool init = true;
|
bool init = true;
|
||||||
|
std::string parameter_name;
|
||||||
|
std::string name;
|
||||||
for (auto &node : todo) {
|
for (auto &node : todo) {
|
||||||
auto cnode = node->cast<CNodePtr>();
|
auto cnode = node->cast<CNodePtr>();
|
||||||
if (cnode->input(1)->Shape() == nullptr) continue;
|
if (cnode->input(1)->Shape() == nullptr) continue;
|
||||||
auto input_shapes = GetNodeShape(cnode->input(1));
|
auto input_shapes = GetNodeShape(cnode->input(1));
|
||||||
int64_t input_size = std::accumulate(input_shapes[0].begin(), input_shapes[0].end(), 1, std::multiplies<int64_t>());
|
int64_t input_size = std::accumulate(input_shapes[0].begin(), input_shapes[0].end(), 1, std::multiplies<int64_t>());
|
||||||
FuncGraphPtr func_graph = cnode->func_graph();
|
FuncGraphPtr func_graph = cnode->func_graph();
|
||||||
std::pair<AnfNodePtr, bool> param_node_pair = FindParameter(cnode->input(1), func_graph);
|
if (IsPrimitiveEquals(primp, prim::kPrimMirror)) {
|
||||||
if (!param_node_pair.first) {
|
name = ALL_REDUCE;
|
||||||
continue;
|
std::pair<AnfNodePtr, bool> param_node_pair = FindParameter(cnode->input(1), func_graph);
|
||||||
|
if (!param_node_pair.first) continue;
|
||||||
|
parameter_name = ParameterName(param_node_pair.first);
|
||||||
}
|
}
|
||||||
auto parameter_name = ParameterName(param_node_pair.first);
|
|
||||||
if (input_size < temp) {
|
if (IsPrimitiveEquals(primp, prim::kPrimMicroStepAllGather) || IsPrimitiveEquals(primp, prim::kPrimAllGather)) {
|
||||||
|
name = ALL_GATHER;
|
||||||
|
if (!cnode->input(0) || !cnode->input(1)) continue;
|
||||||
|
PrimitivePtr primp = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||||
|
if (!primp->HasAttr(RECOMPUTE) || GetValue<bool>(primp->GetAttr(RECOMPUTE))) continue;
|
||||||
|
std::pair<AnfNodePtr, bool> param_node_pair = FindParameterWithAllgather(cnode->input(1), func_graph, name);
|
||||||
|
if (!param_node_pair.first) continue;
|
||||||
|
parameter_name = ParameterName(param_node_pair.first);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (init || input_size < temp) {
|
||||||
temp -= input_size;
|
temp -= input_size;
|
||||||
|
init = false;
|
||||||
} else {
|
} else {
|
||||||
temp = threshold;
|
temp = threshold;
|
||||||
fusion++;
|
fusion++;
|
||||||
}
|
}
|
||||||
if (init) {
|
SetMirrorFusion(cnode, fusion, parameter_name);
|
||||||
SetMirrorFusion(cnode, 1, parameter_name);
|
|
||||||
} else {
|
|
||||||
SetMirrorFusion(cnode, fusion, parameter_name);
|
|
||||||
}
|
|
||||||
init = false;
|
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "Allreduce fusion by size succeed.";
|
MS_LOG(INFO) << name << " fusion by size succeed.";
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) {
|
Status AllCommFusion::SetFusionBySizeReduceScatter(const CNodePtr &ret, int64_t threshold, const PrimitivePtr &primp) {
|
||||||
|
auto filter = [primp](const AnfNodePtr &node) { return !IsPrimitiveCNode(node, primp); };
|
||||||
|
auto todo = DeepScopedGraphSearchWithFilter(ret, AlwaysInclude, filter);
|
||||||
|
auto temp = threshold;
|
||||||
|
int64_t fusion = 1;
|
||||||
|
bool init = true;
|
||||||
|
for (auto &node : todo) {
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
if (cnode->input(1) == nullptr) continue;
|
||||||
|
FuncGraphPtr func_graph = cnode->func_graph();
|
||||||
|
std::pair<AnfNodePtr, bool> param_node_pair =
|
||||||
|
FindParameterWithAllgather(cnode->input(1), func_graph, REDUCE_SCATTER);
|
||||||
|
if (!param_node_pair.first) continue;
|
||||||
|
auto parameter_name = ParameterName(param_node_pair.first);
|
||||||
|
auto input_shapes = GetNodeShape(param_node_pair.first);
|
||||||
|
int64_t input_size = std::accumulate(input_shapes[0].begin(), input_shapes[0].end(), 1, std::multiplies<int64_t>());
|
||||||
|
if (init || input_size < temp) {
|
||||||
|
temp -= input_size;
|
||||||
|
init = false;
|
||||||
|
} else {
|
||||||
|
temp = threshold;
|
||||||
|
fusion++;
|
||||||
|
}
|
||||||
|
SetMirrorFusion(cnode, fusion, parameter_name);
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Reduce_Scatter fusion by size succeed.";
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status AllCommFusion::ProcessCommOpsFusion(const CNodePtr &ret, const std::string &comm_name) {
|
||||||
if (ret == nullptr) {
|
if (ret == nullptr) {
|
||||||
MS_LOG(ERROR) << "ret is nullptr.";
|
MS_LOG(ERROR) << "ret is nullptr.";
|
||||||
return FAILED;
|
return FAILED;
|
||||||
|
@ -83,7 +122,7 @@ Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) {
|
||||||
MS_EXCEPTION_IF_NULL(root_graph_);
|
MS_EXCEPTION_IF_NULL(root_graph_);
|
||||||
auto graph_set = ForwardGraph(root_graph_);
|
auto graph_set = ForwardGraph(root_graph_);
|
||||||
if (graph_set.size() > 1) {
|
if (graph_set.size() > 1) {
|
||||||
MS_LOG(WARNING) << "AllReduce fusion don't support multiple subgraphs now.";
|
MS_LOG(WARNING) << comm_name << "fusion don't support multiple subgraphs now.";
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
auto forward_graph = *(graph_set.begin());
|
auto forward_graph = *(graph_set.begin());
|
||||||
|
@ -91,16 +130,35 @@ Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) {
|
||||||
forward_ret_ = forward_graph->get_return();
|
forward_ret_ = forward_graph->get_return();
|
||||||
MS_EXCEPTION_IF_NULL(forward_ret_);
|
MS_EXCEPTION_IF_NULL(forward_ret_);
|
||||||
if (allreduce_graph_.set_head_cnode(forward_ret_) != SUCCESS) {
|
if (allreduce_graph_.set_head_cnode(forward_ret_) != SUCCESS) {
|
||||||
MS_LOG(ERROR) << "AllreduceGraph set_head_cnode failed.";
|
MS_LOG(ERROR) << comm_name << "Graph set_head_cnode failed.";
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
int64_t threshold = ParallelContext::GetInstance()->fusion_threshold_mb() * 1024 * 1024 / 4;
|
int64_t threshold = 0;
|
||||||
|
if (comm_name == ALL_REDUCE) {
|
||||||
|
threshold = ParallelContext::GetInstance()->fusion_threshold_mb();
|
||||||
|
} else if (comm_name == ALL_GATHER) {
|
||||||
|
threshold = ParallelContext::GetInstance()->allgather_fusion_threshold_mb();
|
||||||
|
} else if (comm_name == REDUCE_SCATTER) {
|
||||||
|
threshold = ParallelContext::GetInstance()->reducescatter_fusion_threshold_mb();
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << " Comm Ops must be ALL_REDUCE, ALL_GATHER or REDUCE_SCATTER, but got " << comm_name;
|
||||||
|
}
|
||||||
|
threshold *= DEFAULT_THRESHOLD_MB_TO_BYTE;
|
||||||
if (threshold <= 0) {
|
if (threshold <= 0) {
|
||||||
MS_LOG(ERROR) << "The threshold of SetFusionBySize must be larger than 0, but got " << threshold << ".";
|
MS_LOG(ERROR) << "The threshold of" << comm_name << "fusion must be larger than 0, but got " << threshold << ".";
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
|
if (comm_name == REDUCE_SCATTER) {
|
||||||
|
(void)SetFusionBySizeReduceScatter(ret, threshold, prim::kPrimVirtualAssignAdd);
|
||||||
|
}
|
||||||
|
if (comm_name == ALL_REDUCE) {
|
||||||
|
(void)SetFusionBySize(ret, threshold, prim::kPrimMirror);
|
||||||
|
}
|
||||||
|
if (comm_name == ALL_GATHER) {
|
||||||
|
(void)SetFusionBySize(ret, threshold, prim::kPrimMicroStepAllGather);
|
||||||
|
(void)SetFusionBySize(ret, threshold, prim::kPrimAllGather);
|
||||||
|
}
|
||||||
|
|
||||||
(void)SetFusionBySize(ret, threshold);
|
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_FUSION_H_
|
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_FUSION_H_
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
#include "utils/hash_map.h"
|
#include "utils/hash_map.h"
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
#include "frontend/parallel/allreduce_fusion/allreduce_graph.h"
|
#include "frontend/parallel/allreduce_fusion/allreduce_graph.h"
|
||||||
|
@ -34,22 +35,24 @@ constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_TIME = 0.1;
|
||||||
constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_INHERENT_TIME = 0.1;
|
constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_INHERENT_TIME = 0.1;
|
||||||
constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_BANDWIDTH = 0.1;
|
constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_BANDWIDTH = 0.1;
|
||||||
constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_COMPUTATION_TIME_PARAMETER = 0.1;
|
constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_COMPUTATION_TIME_PARAMETER = 0.1;
|
||||||
|
constexpr int64_t DEFAULT_THRESHOLD_MB_TO_BYTE = 262144;
|
||||||
|
|
||||||
const uint64_t MAX_RECURSIVE_CALL_TIMES = 100;
|
const uint64_t MAX_RECURSIVE_CALL_TIMES = 100;
|
||||||
class AllreduceFusion {
|
class AllCommFusion {
|
||||||
public:
|
public:
|
||||||
AllreduceFusion()
|
AllCommFusion()
|
||||||
: allreduce_graph_(),
|
: allreduce_graph_(),
|
||||||
ret_(nullptr),
|
ret_(nullptr),
|
||||||
forward_ret_(nullptr),
|
forward_ret_(nullptr),
|
||||||
root_graph_(nullptr),
|
root_graph_(nullptr),
|
||||||
tail_time_(0),
|
tail_time_(0),
|
||||||
computation_time_parameter_(0) {}
|
computation_time_parameter_(0) {}
|
||||||
virtual ~AllreduceFusion() = default;
|
virtual ~AllCommFusion() = default;
|
||||||
Status ProcessAllreduceFusion(const CNodePtr &ret);
|
Status ProcessCommOpsFusion(const CNodePtr &ret, const std::string &comm_name);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Status SetFusionBySize(const CNodePtr &ret, int64_t threshold);
|
Status SetFusionBySize(const CNodePtr &ret, int64_t threshold, const PrimitivePtr &primp);
|
||||||
|
Status SetFusionBySizeReduceScatter(const CNodePtr &ret, int64_t threshold, const PrimitivePtr &primp);
|
||||||
AllreduceGraph allreduce_graph_;
|
AllreduceGraph allreduce_graph_;
|
||||||
CNodePtr ret_;
|
CNodePtr ret_;
|
||||||
CNodePtr forward_ret_;
|
CNodePtr forward_ret_;
|
||||||
|
|
|
@ -33,12 +33,15 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti
|
||||||
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
|
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
|
||||||
std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
|
std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
|
||||||
bool enable_all_reduce_fusion = ParallelContext::GetInstance()->enable_all_reduce_fusion();
|
bool enable_all_reduce_fusion = ParallelContext::GetInstance()->enable_all_reduce_fusion();
|
||||||
|
bool enable_all_gather_fusion = ParallelContext::GetInstance()->enable_all_gather_fusion();
|
||||||
|
bool enable_reduce_scatter_fusion = ParallelContext::GetInstance()->enable_reduce_scatter_fusion();
|
||||||
auto graph_set = ForwardGraph(root);
|
auto graph_set = ForwardGraph(root);
|
||||||
// assume no change to graph
|
// assume no change to graph
|
||||||
bool changes = false;
|
bool changes = false;
|
||||||
// control whether use model_parallel mode
|
// control whether use model_parallel mode
|
||||||
if (!root->has_flag(AUTO_PARALLEL) || ((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) ||
|
if (!root->has_flag(AUTO_PARALLEL) || ((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) ||
|
||||||
(!enable_all_reduce_fusion) || (root->has_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY)) || graph_set.size() < 1) {
|
((!enable_all_reduce_fusion) && (!enable_all_gather_fusion) && (!enable_reduce_scatter_fusion)) ||
|
||||||
|
(root->has_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY)) || graph_set.size() < 1) {
|
||||||
return changes;
|
return changes;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,7 +53,8 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti
|
||||||
}, end_time{0};
|
}, end_time{0};
|
||||||
(void)gettimeofday(&start_time, nullptr);
|
(void)gettimeofday(&start_time, nullptr);
|
||||||
#endif
|
#endif
|
||||||
MS_LOG(INFO) << "Now entering allreduce fusion by size, and allreduce fusion before will be overlapped!";
|
MS_LOG(INFO) << "Now entering comm ops (allreduce, allgather, reducescatter) fusion by size, and fusion before will "
|
||||||
|
"be overlapped!";
|
||||||
DumpGraph(root, std::string(ALLREDUCE_FUSION_BEGIN));
|
DumpGraph(root, std::string(ALLREDUCE_FUSION_BEGIN));
|
||||||
|
|
||||||
pipeline::ResourceBasePtr res = optimizer->resource();
|
pipeline::ResourceBasePtr res = optimizer->resource();
|
||||||
|
@ -61,9 +65,15 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti
|
||||||
CNodePtr ret = root->get_return();
|
CNodePtr ret = root->get_return();
|
||||||
MS_EXCEPTION_IF_NULL(ret);
|
MS_EXCEPTION_IF_NULL(ret);
|
||||||
|
|
||||||
AllreduceFusion allreduce_fusion;
|
AllCommFusion allcomm_fusion;
|
||||||
if (allreduce_fusion.ProcessAllreduceFusion(ret) != SUCCESS) {
|
vector<std::string> comm_ops = {ALL_REDUCE, ALL_GATHER, REDUCE_SCATTER};
|
||||||
MS_LOG(EXCEPTION) << "ProcessAllreduceFusion failed";
|
vector<bool> fusionlist = {enable_all_reduce_fusion, enable_all_gather_fusion, enable_reduce_scatter_fusion};
|
||||||
|
for (size_t i = 0; i < comm_ops.size(); i++) {
|
||||||
|
if (fusionlist[i]) {
|
||||||
|
if (allcomm_fusion.ProcessCommOpsFusion(ret, comm_ops[i]) != SUCCESS) {
|
||||||
|
MS_LOG(EXCEPTION) << "Process" << comm_ops[i] << "Fusion failed";
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
DumpGraph(root, std::string(ALLREDUCE_FUSION_END));
|
DumpGraph(root, std::string(ALLREDUCE_FUSION_END));
|
||||||
|
|
|
@ -63,6 +63,8 @@ void ParallelContext::Reset() {
|
||||||
parameter_broadcast_ = false;
|
parameter_broadcast_ = false;
|
||||||
parameter_broadcast_is_set_ = false;
|
parameter_broadcast_is_set_ = false;
|
||||||
enable_all_reduce_fusion_ = true;
|
enable_all_reduce_fusion_ = true;
|
||||||
|
enable_all_gather_fusion_ = true;
|
||||||
|
enable_reduce_scatter_fusion_ = true;
|
||||||
strategy_ckpt_load_file_ = "";
|
strategy_ckpt_load_file_ = "";
|
||||||
strategy_ckpt_save_file_ = "";
|
strategy_ckpt_save_file_ = "";
|
||||||
enable_parallel_optimizer_ = false;
|
enable_parallel_optimizer_ = false;
|
||||||
|
@ -79,6 +81,8 @@ void ParallelContext::Reset() {
|
||||||
sharding_propagation_ = false;
|
sharding_propagation_ = false;
|
||||||
dataset_strategy_.clear();
|
dataset_strategy_.clear();
|
||||||
fusion_threshold_mb_ = FUSUION_THRESHOLD;
|
fusion_threshold_mb_ = FUSUION_THRESHOLD;
|
||||||
|
allgather_fusion_threshold_mb_ = FUSUION_THRESHOLD;
|
||||||
|
reducescatter_fusion_threshold_mb_ = FUSUION_THRESHOLD;
|
||||||
fusion_threshold_is_set_ = true;
|
fusion_threshold_is_set_ = true;
|
||||||
fusion_mode_ = FUSION_AUTO;
|
fusion_mode_ = FUSION_AUTO;
|
||||||
}
|
}
|
||||||
|
@ -94,6 +98,16 @@ void ParallelContext::set_fusion_threshold_mb(int64_t fusion_threshold) {
|
||||||
enable_all_reduce_fusion_ = true;
|
enable_all_reduce_fusion_ = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ParallelContext::set_allgather_fusion_threshold_mb(int64_t fusion_threshold) {
|
||||||
|
allgather_fusion_threshold_mb_ = fusion_threshold;
|
||||||
|
enable_all_gather_fusion_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ParallelContext::set_reducescatter_fusion_threshold_mb(int64_t fusion_threshold) {
|
||||||
|
reducescatter_fusion_threshold_mb_ = fusion_threshold;
|
||||||
|
enable_reduce_scatter_fusion_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
bool ParallelContext::set_fusion_mode(const std::string &fusion_mode) {
|
bool ParallelContext::set_fusion_mode(const std::string &fusion_mode) {
|
||||||
auto iter = std::find(FUSION_MODE_LIST.begin(), FUSION_MODE_LIST.end(), fusion_mode);
|
auto iter = std::find(FUSION_MODE_LIST.begin(), FUSION_MODE_LIST.end(), fusion_mode);
|
||||||
if (iter == FUSION_MODE_LIST.end()) {
|
if (iter == FUSION_MODE_LIST.end()) {
|
||||||
|
|
|
@ -85,6 +85,13 @@ class ParallelContext {
|
||||||
|
|
||||||
void set_fusion_threshold_mb(int64_t fusion_threshold);
|
void set_fusion_threshold_mb(int64_t fusion_threshold);
|
||||||
int64_t fusion_threshold_mb() const { return fusion_threshold_mb_; }
|
int64_t fusion_threshold_mb() const { return fusion_threshold_mb_; }
|
||||||
|
|
||||||
|
void set_allgather_fusion_threshold_mb(int64_t allgather_fusion_threshold);
|
||||||
|
int64_t allgather_fusion_threshold_mb() const { return allgather_fusion_threshold_mb_; }
|
||||||
|
|
||||||
|
void set_reducescatter_fusion_threshold_mb(int64_t rs_fusion_threshold);
|
||||||
|
int64_t reducescatter_fusion_threshold_mb() const { return reducescatter_fusion_threshold_mb_; }
|
||||||
|
|
||||||
bool set_fusion_mode(const std::string &fusion_mode);
|
bool set_fusion_mode(const std::string &fusion_mode);
|
||||||
std::string get_fusion_mode() const { return fusion_mode_; }
|
std::string get_fusion_mode() const { return fusion_mode_; }
|
||||||
|
|
||||||
|
@ -123,6 +130,14 @@ class ParallelContext {
|
||||||
enable_all_reduce_fusion_ = enable_all_reduce_fusion;
|
enable_all_reduce_fusion_ = enable_all_reduce_fusion;
|
||||||
}
|
}
|
||||||
bool enable_all_reduce_fusion() const { return enable_all_reduce_fusion_; }
|
bool enable_all_reduce_fusion() const { return enable_all_reduce_fusion_; }
|
||||||
|
void set_enable_all_gather_fusion(bool enable_all_gather_fusion) {
|
||||||
|
enable_all_gather_fusion_ = enable_all_gather_fusion;
|
||||||
|
}
|
||||||
|
bool enable_all_gather_fusion() const { return enable_all_gather_fusion_; }
|
||||||
|
void set_enable_reduce_scatter_fusion(bool enable_reduce_scatter_fusion) {
|
||||||
|
enable_reduce_scatter_fusion_ = enable_reduce_scatter_fusion;
|
||||||
|
}
|
||||||
|
bool enable_reduce_scatter_fusion() const { return enable_reduce_scatter_fusion_; }
|
||||||
|
|
||||||
void set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file);
|
void set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file);
|
||||||
std::string strategy_ckpt_load_file() const { return strategy_ckpt_load_file_; }
|
std::string strategy_ckpt_load_file() const { return strategy_ckpt_load_file_; }
|
||||||
|
@ -170,6 +185,8 @@ class ParallelContext {
|
||||||
bool loss_repeated_mean_;
|
bool loss_repeated_mean_;
|
||||||
int64_t device_num_;
|
int64_t device_num_;
|
||||||
int64_t fusion_threshold_mb_;
|
int64_t fusion_threshold_mb_;
|
||||||
|
int64_t allgather_fusion_threshold_mb_;
|
||||||
|
int64_t reducescatter_fusion_threshold_mb_; // reducescatter
|
||||||
int64_t global_rank_;
|
int64_t global_rank_;
|
||||||
int64_t grad_accumulation_step_;
|
int64_t grad_accumulation_step_;
|
||||||
std::string parallel_mode_;
|
std::string parallel_mode_;
|
||||||
|
@ -181,6 +198,8 @@ class ParallelContext {
|
||||||
bool global_rank_is_set_;
|
bool global_rank_is_set_;
|
||||||
bool parameter_broadcast_is_set_;
|
bool parameter_broadcast_is_set_;
|
||||||
bool enable_all_reduce_fusion_;
|
bool enable_all_reduce_fusion_;
|
||||||
|
bool enable_all_gather_fusion_;
|
||||||
|
bool enable_reduce_scatter_fusion_;
|
||||||
std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_indices_;
|
std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_indices_;
|
||||||
std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_sizes_;
|
std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_sizes_;
|
||||||
std::string strategy_ckpt_load_file_;
|
std::string strategy_ckpt_load_file_;
|
||||||
|
|
|
@ -564,7 +564,8 @@ static bool IsOriginWeight(const ParameterPtr ¶m) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
|
static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph,
|
||||||
|
const std::string &name = ALL_REDUCE) {
|
||||||
if (IsValueNode<RefKey>(node)) {
|
if (IsValueNode<RefKey>(node)) {
|
||||||
std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph);
|
std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph);
|
||||||
if (param_v.size() != 1) {
|
if (param_v.size() != 1) {
|
||||||
|
@ -572,7 +573,8 @@ static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &no
|
||||||
<< param_v.size();
|
<< param_v.size();
|
||||||
}
|
}
|
||||||
auto param_ptr = param_v[0]->user_data<parallel::TensorLayout>();
|
auto param_ptr = param_v[0]->user_data<parallel::TensorLayout>();
|
||||||
if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty()) {
|
if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty() &&
|
||||||
|
name == ALL_REDUCE) {
|
||||||
return std::make_pair(nullptr, true);
|
return std::make_pair(nullptr, true);
|
||||||
}
|
}
|
||||||
return std::make_pair(node, true);
|
return std::make_pair(node, true);
|
||||||
|
@ -580,9 +582,11 @@ static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &no
|
||||||
return std::make_pair(nullptr, false);
|
return std::make_pair(nullptr, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::pair<AnfNodePtr, bool> FindParameterByParameter(const AnfNodePtr &node) {
|
static std::pair<AnfNodePtr, bool> FindParameterByParameter(const AnfNodePtr &node,
|
||||||
|
const std::string &name = ALL_REDUCE) {
|
||||||
auto param_ptr = node->user_data<parallel::TensorLayout>();
|
auto param_ptr = node->user_data<parallel::TensorLayout>();
|
||||||
if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty()) {
|
if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty() &&
|
||||||
|
name == ALL_REDUCE) {
|
||||||
return std::make_pair(nullptr, false);
|
return std::make_pair(nullptr, false);
|
||||||
}
|
}
|
||||||
return std::make_pair(node, false);
|
return std::make_pair(node, false);
|
||||||
|
@ -637,6 +641,34 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
|
||||||
return std::make_pair(nullptr, false);
|
return std::make_pair(nullptr, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Used for allgather and reducescatter
|
||||||
|
std::pair<AnfNodePtr, bool> FindParameterWithAllgather(const AnfNodePtr &node, const FuncGraphPtr &func_graph,
|
||||||
|
const std::string &name) {
|
||||||
|
if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {
|
||||||
|
return std::make_pair(nullptr, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (node->isa<Parameter>()) {
|
||||||
|
return FindParameterByParameter(node, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (node->isa<ValueNode>()) {
|
||||||
|
return FindParameterByValueNode(node, func_graph, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
CNodePtr cnode = node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
for (size_t index = 0; index < cnode->inputs().size(); ++index) {
|
||||||
|
if (index != 1) continue;
|
||||||
|
auto res = FindParameterWithAllgather(cnode->input(index), func_graph, name);
|
||||||
|
if (!res.first) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
return std::make_pair(nullptr, false);
|
||||||
|
}
|
||||||
|
|
||||||
std::unordered_map<std::string, std::shared_ptr<TensorLayout>> AdaSumParamTensorLayout(const FuncGraphPtr &root) {
|
std::unordered_map<std::string, std::shared_ptr<TensorLayout>> AdaSumParamTensorLayout(const FuncGraphPtr &root) {
|
||||||
MS_EXCEPTION_IF_NULL(root);
|
MS_EXCEPTION_IF_NULL(root);
|
||||||
std::unordered_map<std::string, std::shared_ptr<TensorLayout>> adasum_param_map;
|
std::unordered_map<std::string, std::shared_ptr<TensorLayout>> adasum_param_map;
|
||||||
|
|
|
@ -43,6 +43,8 @@ void HandleFullySplitParameters(const FuncGraphPtr &root);
|
||||||
void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root);
|
void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root);
|
||||||
void HandleAdaFactorOpt(const FuncGraphPtr &root);
|
void HandleAdaFactorOpt(const FuncGraphPtr &root);
|
||||||
std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph);
|
std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph);
|
||||||
|
std::pair<AnfNodePtr, bool> FindParameterWithAllgather(const AnfNodePtr &node, const FuncGraphPtr &func_graph,
|
||||||
|
const std::string &name);
|
||||||
std::unordered_map<std::string, std::shared_ptr<TensorLayout>> AdaSumParamTensorLayout(const FuncGraphPtr &root);
|
std::unordered_map<std::string, std::shared_ptr<TensorLayout>> AdaSumParamTensorLayout(const FuncGraphPtr &root);
|
||||||
bool HandleAdaSum(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
|
bool HandleAdaSum(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
|
||||||
std::unordered_map<std::string, std::shared_ptr<TensorLayout>> *adasum_param_tensor_layout_map);
|
std::unordered_map<std::string, std::shared_ptr<TensorLayout>> *adasum_param_tensor_layout_map);
|
||||||
|
|
|
@ -25,6 +25,7 @@
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
#include <queue>
|
||||||
|
|
||||||
#include "utils/hash_map.h"
|
#include "utils/hash_map.h"
|
||||||
#include "base/core_ops.h"
|
#include "base/core_ops.h"
|
||||||
|
|
|
@ -146,7 +146,15 @@ PYBIND11_MODULE(_c_expression, m) {
|
||||||
.def("set_device_num", &ParallelContext::set_device_num, "Set device num.")
|
.def("set_device_num", &ParallelContext::set_device_num, "Set device num.")
|
||||||
.def("get_device_num_is_set", &ParallelContext::device_num_is_set, "Get device num is set.")
|
.def("get_device_num_is_set", &ParallelContext::device_num_is_set, "Get device num is set.")
|
||||||
.def("set_fusion_threshold_mb", &ParallelContext::set_fusion_threshold_mb, "Set fusion threshold.")
|
.def("set_fusion_threshold_mb", &ParallelContext::set_fusion_threshold_mb, "Set fusion threshold.")
|
||||||
.def("fusion_threshold_mb", &ParallelContext::fusion_threshold_mb, "Get fusion threshold.")
|
.def("set_allgather_fusion_threshold_mb", &ParallelContext::set_allgather_fusion_threshold_mb,
|
||||||
|
"Set allgather fusion threshold.")
|
||||||
|
.def("set_reducescatter_fusion_threshold_mb", &ParallelContext::set_reducescatter_fusion_threshold_mb,
|
||||||
|
"Set reducescatter fusion threshold.")
|
||||||
|
.def("fusion_threshold_mb", &ParallelContext::fusion_threshold_mb, "Get allreduce fusion threshold.")
|
||||||
|
.def("allgather_fusion_threshold_mb", &ParallelContext::allgather_fusion_threshold_mb,
|
||||||
|
"Get allgather fusion threshold.")
|
||||||
|
.def("reducescatter_fusion_threshold_mb", &ParallelContext::reducescatter_fusion_threshold_mb,
|
||||||
|
"Get reduce_scatter fusion threshold.")
|
||||||
.def("set_fusion_mode", &ParallelContext::set_fusion_mode, "Get fusion mode.")
|
.def("set_fusion_mode", &ParallelContext::set_fusion_mode, "Get fusion mode.")
|
||||||
.def("get_fusion_mode", &ParallelContext::get_fusion_mode, "Get fusion mode.")
|
.def("get_fusion_mode", &ParallelContext::get_fusion_mode, "Get fusion mode.")
|
||||||
.def("get_global_rank", &ParallelContext::global_rank, "Get global rank.")
|
.def("get_global_rank", &ParallelContext::global_rank, "Get global rank.")
|
||||||
|
@ -178,6 +186,14 @@ PYBIND11_MODULE(_c_expression, m) {
|
||||||
"Set enable/disable all reduce fusion.")
|
"Set enable/disable all reduce fusion.")
|
||||||
.def("get_enable_all_reduce_fusion", &ParallelContext::enable_all_reduce_fusion,
|
.def("get_enable_all_reduce_fusion", &ParallelContext::enable_all_reduce_fusion,
|
||||||
"Get enable/disable all reduce fusion.")
|
"Get enable/disable all reduce fusion.")
|
||||||
|
.def("set_enable_all_gather_fusion", &ParallelContext::set_enable_all_gather_fusion,
|
||||||
|
"Set enable/disable all gather fusion.")
|
||||||
|
.def("get_enable_all_gather_fusion", &ParallelContext::enable_all_gather_fusion,
|
||||||
|
"Get enable/disable all gather fusion.")
|
||||||
|
.def("set_enable_reduce_scatter_fusion", &ParallelContext::set_enable_reduce_scatter_fusion,
|
||||||
|
"Set enable/disable reduce scatter fusion.")
|
||||||
|
.def("get_enable_reduce_scatter_fusion", &ParallelContext::enable_reduce_scatter_fusion,
|
||||||
|
"Get enable/disable reduce scatter fusion.")
|
||||||
.def("get_parameter_broadcast", &ParallelContext::parameter_broadcast, "Get parameter broadcast.")
|
.def("get_parameter_broadcast", &ParallelContext::parameter_broadcast, "Get parameter broadcast.")
|
||||||
.def("get_parameter_broadcast_is_set", &ParallelContext::parameter_broadcast_is_set,
|
.def("get_parameter_broadcast_is_set", &ParallelContext::parameter_broadcast_is_set,
|
||||||
"Get parameter broadcast is set.")
|
"Get parameter broadcast is set.")
|
||||||
|
|
|
@ -32,6 +32,8 @@ class _ParallelFusionConfig:
|
||||||
The key of the Parallel fusion method configuration.
|
The key of the Parallel fusion method configuration.
|
||||||
"""
|
"""
|
||||||
ALLREDUCE = "allreduce"
|
ALLREDUCE = "allreduce"
|
||||||
|
ALLGATHER = "allgather"
|
||||||
|
REDUCESCATTER = "reducescatter"
|
||||||
MODE = "mode"
|
MODE = "mode"
|
||||||
FUSION_CONFIG = "config"
|
FUSION_CONFIG = "config"
|
||||||
AUTO = "auto"
|
AUTO = "auto"
|
||||||
|
@ -116,8 +118,12 @@ class _AutoParallelContext:
|
||||||
for key in list(config.keys()):
|
for key in list(config.keys()):
|
||||||
if key == _ParallelFusionConfig.ALLREDUCE:
|
if key == _ParallelFusionConfig.ALLREDUCE:
|
||||||
self._set_allreduce_comm_fusion(config[key])
|
self._set_allreduce_comm_fusion(config[key])
|
||||||
|
elif key == _ParallelFusionConfig.ALLGATHER:
|
||||||
|
self._set_allgather_comm_fusion(config[key], key)
|
||||||
|
elif key == _ParallelFusionConfig.REDUCESCATTER:
|
||||||
|
self._set_allgather_comm_fusion(config[key], key)
|
||||||
else:
|
else:
|
||||||
raise KeyError("comm fusion type must be allreduce, but got {}".format(key))
|
raise KeyError("comm fusion type must be allreduce, allgather or reducescatter, but got {}".format(key))
|
||||||
|
|
||||||
def _set_allreduce_comm_fusion(self, comm_fusion):
|
def _set_allreduce_comm_fusion(self, comm_fusion):
|
||||||
"""
|
"""
|
||||||
|
@ -154,6 +160,42 @@ class _AutoParallelContext:
|
||||||
if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.INDEX:
|
if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.INDEX:
|
||||||
self.set_all_reduce_fusion_split_indices(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG])
|
self.set_all_reduce_fusion_split_indices(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG])
|
||||||
|
|
||||||
|
def _set_allgather_comm_fusion(self, comm_fusion, comm_type="allgather"):
|
||||||
|
"""
|
||||||
|
Set allgather and reducescatter fusion method for auto parallel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
comm_fusion (dict): A dict contains the methods and values for setting the fusion method. Currently it
|
||||||
|
supports four fusion methods: `auto` and `size`.
|
||||||
|
comm_type (str): The name of the communication operator, `allgather` or `reducescatter`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: When key of comm_fusion is not 'mode' or 'config'.
|
||||||
|
KeyError: When `mode` is not 'auto', 'size'.
|
||||||
|
"""
|
||||||
|
self.check_context_handle()
|
||||||
|
if comm_type == "allgather" and not self.get_enable_all_gather_fusion():
|
||||||
|
return
|
||||||
|
if comm_type == "reducescatter" and not self.get_enable_reduce_scatter_fusion():
|
||||||
|
return
|
||||||
|
if not isinstance(comm_fusion, dict):
|
||||||
|
raise TypeError("For 'comm_fusion', {} config must be dict, but got the type : {}.".format(
|
||||||
|
comm_type, type(comm_fusion)))
|
||||||
|
if _ParallelFusionConfig.MODE not in comm_fusion:
|
||||||
|
raise KeyError("For 'comm_fusion', the key 'mode' should be contained.")
|
||||||
|
if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion:
|
||||||
|
raise KeyError("For 'comm_fusion', the key 'config' should be contained.")
|
||||||
|
check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.SIZE]
|
||||||
|
if comm_fusion[_ParallelFusionConfig.MODE] in check_mode:
|
||||||
|
self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE])
|
||||||
|
else:
|
||||||
|
raise KeyError("fusion method mode must be auto or size, but got {}".format(
|
||||||
|
comm_fusion[_ParallelFusionConfig.MODE]))
|
||||||
|
|
||||||
|
fusion_threshold = 64 if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.AUTO else \
|
||||||
|
comm_fusion[_ParallelFusionConfig.FUSION_CONFIG]
|
||||||
|
self.set_fusion_threshold_mb(fusion_threshold, comm_type)
|
||||||
|
|
||||||
def get_comm_fusion(self):
|
def get_comm_fusion(self):
|
||||||
"""Get comm fusion config."""
|
"""Get comm fusion config."""
|
||||||
self.check_context_handle()
|
self.check_context_handle()
|
||||||
|
@ -166,12 +208,13 @@ class _AutoParallelContext:
|
||||||
_ParallelFusionConfig.FUSION_CONFIG: config}}
|
_ParallelFusionConfig.FUSION_CONFIG: config}}
|
||||||
|
|
||||||
|
|
||||||
def set_fusion_threshold_mb(self, fusion_threshold=64):
|
def set_fusion_threshold_mb(self, fusion_threshold=64, comm_type="allreduce"):
|
||||||
"""
|
"""
|
||||||
Set fusion threshold (MB) for auto parallel.
|
Set fusion threshold (MB) for auto parallel.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fusion_threshold (int): The fusion threshold (unit: MB). Default: 64.
|
fusion_threshold (int): The fusion threshold (unit: MB). Default: 64.
|
||||||
|
comm_type (str): The name of the communication operator, `allreduce`, `allgather` or `reducescatter`.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If the fusion threshold is not in [0, +inf].
|
ValueError: If the fusion threshold is not in [0, +inf].
|
||||||
|
@ -179,13 +222,30 @@ class _AutoParallelContext:
|
||||||
self.check_context_handle()
|
self.check_context_handle()
|
||||||
if fusion_threshold < 0:
|
if fusion_threshold < 0:
|
||||||
raise ValueError("fusion threshold must be larger than 0, but got {}".format(fusion_threshold))
|
raise ValueError("fusion threshold must be larger than 0, but got {}".format(fusion_threshold))
|
||||||
self._context_handle.set_fusion_threshold_mb(fusion_threshold)
|
|
||||||
|
if comm_type == _ParallelFusionConfig.ALLREDUCE:
|
||||||
|
self._context_handle.set_fusion_threshold_mb(fusion_threshold)
|
||||||
|
if comm_type == _ParallelFusionConfig.ALLGATHER:
|
||||||
|
self._context_handle.set_allgather_fusion_threshold_mb(fusion_threshold)
|
||||||
|
if comm_type == _ParallelFusionConfig.REDUCESCATTER:
|
||||||
|
self._context_handle.set_reducescatter_fusion_threshold_mb(fusion_threshold)
|
||||||
|
|
||||||
|
|
||||||
def fusion_threshold_mb(self):
|
def fusion_threshold_mb(self):
|
||||||
"""Get device num."""
|
"""Get all reduce threshold."""
|
||||||
self.check_context_handle()
|
self.check_context_handle()
|
||||||
return self._context_handle.fusion_threshold_mb()
|
return self._context_handle.fusion_threshold_mb()
|
||||||
|
|
||||||
|
def allgather_fusion_threshold_mb(self):
|
||||||
|
"""Get allgather threshold."""
|
||||||
|
self.check_context_handle()
|
||||||
|
return self._context_handle.allgather_fusion_threshold_mb()
|
||||||
|
|
||||||
|
def reducescatter_fusion_threshold_mb(self):
|
||||||
|
"""Get reducescatter threshold."""
|
||||||
|
self.check_context_handle()
|
||||||
|
return self._context_handle.reducescatter_fusion_threshold_mb()
|
||||||
|
|
||||||
def set_global_rank(self, global_rank):
|
def set_global_rank(self, global_rank):
|
||||||
"""
|
"""
|
||||||
Set global rank for auto parallel.
|
Set global rank for auto parallel.
|
||||||
|
@ -624,15 +684,53 @@ class _AutoParallelContext:
|
||||||
self.check_context_handle()
|
self.check_context_handle()
|
||||||
if not isinstance(enable_all_reduce_fusion, bool):
|
if not isinstance(enable_all_reduce_fusion, bool):
|
||||||
raise TypeError("For 'set_auto_parallel_context().set_enable_all_reduce_fusion', "
|
raise TypeError("For 'set_auto_parallel_context().set_enable_all_reduce_fusion', "
|
||||||
"the argument 'enable_all_reduce_fusion' must be bool, but got the type : {}."
|
"the argument 'enable_fusion' must be bool, but got the type : {}."
|
||||||
.format(type(enable_all_reduce_fusion)))
|
.format(type(enable_all_reduce_fusion)))
|
||||||
self._context_handle.set_enable_all_reduce_fusion(enable_all_reduce_fusion)
|
self._context_handle.set_enable_all_reduce_fusion(enable_all_reduce_fusion)
|
||||||
|
|
||||||
|
def set_enable_all_gather_fusion(self, enable_all_gather_fusion):
|
||||||
|
"""
|
||||||
|
Set enable/disable all gather fusion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
enable_all_gather_fusion (bool): Enable/disable all gather fusion.
|
||||||
|
"""
|
||||||
|
self.check_context_handle()
|
||||||
|
if not isinstance(enable_all_gather_fusion, bool):
|
||||||
|
raise TypeError("For 'set_auto_parallel_context().set_enable_all_gather_fusion', "
|
||||||
|
"the argument 'enable_fusion' must be bool, but got the type : {}."
|
||||||
|
.format(type(enable_all_gather_fusion)))
|
||||||
|
self._context_handle.set_enable_all_gather_fusion(enable_all_gather_fusion)
|
||||||
|
|
||||||
|
def set_enable_reduce_scatter_fusion(self, enable_reduce_scatter_fusion):
|
||||||
|
"""
|
||||||
|
Set enable/disable reduce scatter fusion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
enable_reduce_scatter_fusion (bool): Enable/disable reduce scatter fusion.
|
||||||
|
"""
|
||||||
|
self.check_context_handle()
|
||||||
|
if not isinstance(enable_reduce_scatter_fusion, bool):
|
||||||
|
raise TypeError("For 'set_auto_parallel_context().set_enable_reduce_scatter_fusion', "
|
||||||
|
"the argument 'enable_fusion' must be bool, but got the type : {}."
|
||||||
|
.format(type(enable_reduce_scatter_fusion)))
|
||||||
|
self._context_handle.set_enable_reduce_scatter_fusion(enable_reduce_scatter_fusion)
|
||||||
|
|
||||||
def get_enable_all_reduce_fusion(self):
|
def get_enable_all_reduce_fusion(self):
|
||||||
"""Get all reduce fusion flag."""
|
"""Get all reduce fusion flag."""
|
||||||
self.check_context_handle()
|
self.check_context_handle()
|
||||||
return self._context_handle.get_enable_all_reduce_fusion()
|
return self._context_handle.get_enable_all_reduce_fusion()
|
||||||
|
|
||||||
|
def get_enable_all_gather_fusion(self):
|
||||||
|
"""Get all gather fusion flag."""
|
||||||
|
self.check_context_handle()
|
||||||
|
return self._context_handle.get_enable_all_gather_fusion()
|
||||||
|
|
||||||
|
def get_enable_reduce_scatter_fusion(self):
|
||||||
|
"""Get reduce scatter flag."""
|
||||||
|
self.check_context_handle()
|
||||||
|
return self._context_handle.get_enable_reduce_scatter_fusion()
|
||||||
|
|
||||||
def get_device_num_is_set(self):
|
def get_device_num_is_set(self):
|
||||||
"""Get device number is set or not."""
|
"""Get device number is set or not."""
|
||||||
self.check_context_handle()
|
self.check_context_handle()
|
||||||
|
|
|
@ -0,0 +1,148 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import mindspore as ms
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.common.parameter import Parameter
|
||||||
|
from mindspore.common.initializer import initializer
|
||||||
|
from mindspore.train.model import Model
|
||||||
|
from mindspore.nn.wrap.cell_wrapper import PipelineCell
|
||||||
|
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetLenet():
|
||||||
|
def __init__(self, data, label, length=3):
|
||||||
|
self.data = data
|
||||||
|
self.label = label
|
||||||
|
self.index = 1
|
||||||
|
self.length = length
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
if self.index >= self.length:
|
||||||
|
raise StopIteration
|
||||||
|
self.index += 1
|
||||||
|
return self.data, self.label
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.index = 0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_dataset_size():
|
||||||
|
return 32
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_repeat_count():
|
||||||
|
return 1
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_batch_size():
|
||||||
|
return 32
|
||||||
|
|
||||||
|
def create_tuple_iterator(self, num_epochs=1, do_copy=True):
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class MatMulCell(nn.Cell):
|
||||||
|
def __init__(self, strategy1, strategy2, param=None):
|
||||||
|
super().__init__()
|
||||||
|
self.param = Parameter(initializer("zeros", [64, 64]), name="param")
|
||||||
|
if param is not None:
|
||||||
|
self.param = param
|
||||||
|
self.param1 = Parameter(initializer("zeros", [64, 64]), name="param1")
|
||||||
|
self.matmul = P.MatMul().shard(strategy1)
|
||||||
|
self.matmul1 = P.MatMul().shard(strategy2)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
out = self.matmul(x, self.param)
|
||||||
|
out = self.matmul1(out, self.param1)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self, strategy1, strategy2, param=None):
|
||||||
|
super().__init__()
|
||||||
|
self.block = nn.CellList()
|
||||||
|
for i in range(2):
|
||||||
|
cell = MatMulCell(strategy1, strategy2, param)
|
||||||
|
cell.pipeline_stage = i
|
||||||
|
self.block.append(cell)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
for i in range(2):
|
||||||
|
x = self.block[i](x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineSplit(nn.Cell):
|
||||||
|
def __init__(self, strategy1, strategy2):
|
||||||
|
super().__init__()
|
||||||
|
self.cell = Net(strategy1, strategy2)
|
||||||
|
self.cell.block[0].matmul.add_prim_attr("parameter_start", 0)
|
||||||
|
|
||||||
|
def construct(self, x, label):
|
||||||
|
x = self.cell(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def test_fusion_size():
|
||||||
|
"""
|
||||||
|
Feature: test_fusion_auto in size mode
|
||||||
|
Description: allgather and reduce scatter fusion in size mode
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
allgather_threshold = 8
|
||||||
|
reducescatter_threshold = 16
|
||||||
|
comm_fusion_dict = {"allgather": {"mode": "size", "config": allgather_threshold},
|
||||||
|
"reducescatter": {"mode": "size", "config": reducescatter_threshold}}
|
||||||
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", comm_fusion=comm_fusion_dict)
|
||||||
|
context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2, enable_parallel_optimizer=True)
|
||||||
|
data = Tensor(np.ones([32, 64]), dtype=ms.float32)
|
||||||
|
label = Tensor(np.ones([64]), dtype=ms.float32)
|
||||||
|
strategy1 = ((4, 1), (1, 1))
|
||||||
|
strategy2 = ((2, 1), (1, 1))
|
||||||
|
net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
|
||||||
|
dataset = DatasetLenet(data, label, 3)
|
||||||
|
optimizer = nn.Lamb(net.trainable_params(), learning_rate=0.01)
|
||||||
|
model = Model(net, optimizer=optimizer)
|
||||||
|
model.train(2, dataset, dataset_sink_mode=False)
|
||||||
|
assert auto_parallel_context().allgather_fusion_threshold_mb() == allgather_threshold
|
||||||
|
assert auto_parallel_context().reducescatter_fusion_threshold_mb() == reducescatter_threshold
|
||||||
|
|
||||||
|
def test_fusion_auto():
|
||||||
|
"""
|
||||||
|
Feature: test_fusion_auto in auto mode
|
||||||
|
Description: allgather and reduce scatter fusion in auto mode
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
comm_fusion_dict = {"allgather": {"mode": "auto", "config": None},
|
||||||
|
"reducescatter": {"mode": "auto", "config": None}}
|
||||||
|
context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2, enable_parallel_optimizer=True,
|
||||||
|
parallel_mode="semi_auto_parallel", comm_fusion=comm_fusion_dict)
|
||||||
|
data = Tensor(np.ones([32, 64]), dtype=ms.float32)
|
||||||
|
label = Tensor(np.ones([64]), dtype=ms.float32)
|
||||||
|
strategy1 = ((4, 1), (1, 1))
|
||||||
|
strategy2 = ((2, 1), (1, 1))
|
||||||
|
net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
|
||||||
|
dataset = DatasetLenet(data, label, 3)
|
||||||
|
optimizer = nn.Lamb(net.trainable_params(), learning_rate=0.01)
|
||||||
|
model = Model(net, optimizer=optimizer)
|
||||||
|
model.train(2, dataset, dataset_sink_mode=False)
|
||||||
|
assert auto_parallel_context().allgather_fusion_threshold_mb() == 64
|
||||||
|
assert auto_parallel_context().reducescatter_fusion_threshold_mb() == 64
|
Loading…
Reference in New Issue