forked from mindspore-Ecosystem/mindspore
!30431 allreduce allgather fusion
Merge pull request !30431 from jiahongQian/master
commit 14393503b7
@@ -39,41 +39,80 @@ void SetMirrorFusion(const CNodePtr &mirror_cnode, int64_t fusion, const std::st
(void)node_prim->AddAttr(PARAMETER, MakeValue(std::make_shared<StringImm>(parameter_name)));
}

Status AllreduceFusion::SetFusionBySize(const CNodePtr &ret, int64_t threshold) {
auto filter = [](const AnfNodePtr &node) { return !IsPrimitiveCNode(node, prim::kPrimMirror); };
Status AllCommFusion::SetFusionBySize(const CNodePtr &ret, int64_t threshold, const PrimitivePtr &primp) {
auto filter = [primp](const AnfNodePtr &node) { return !IsPrimitiveCNode(node, primp); };
auto todo = DeepScopedGraphSearchWithFilter(ret, AlwaysInclude, filter);
auto temp = threshold;
int64_t fusion = 1;
bool init = true;
std::string parameter_name;
std::string name;
for (auto &node : todo) {
auto cnode = node->cast<CNodePtr>();
if (cnode->input(1)->Shape() == nullptr) continue;
auto input_shapes = GetNodeShape(cnode->input(1));
int64_t input_size = std::accumulate(input_shapes[0].begin(), input_shapes[0].end(), 1, std::multiplies<int64_t>());
FuncGraphPtr func_graph = cnode->func_graph();
std::pair<AnfNodePtr, bool> param_node_pair = FindParameter(cnode->input(1), func_graph);
if (!param_node_pair.first) {
continue;
if (IsPrimitiveEquals(primp, prim::kPrimMirror)) {
name = ALL_REDUCE;
std::pair<AnfNodePtr, bool> param_node_pair = FindParameter(cnode->input(1), func_graph);
if (!param_node_pair.first) continue;
parameter_name = ParameterName(param_node_pair.first);
}
auto parameter_name = ParameterName(param_node_pair.first);
if (input_size < temp) {
if (IsPrimitiveEquals(primp, prim::kPrimMicroStepAllGather) || IsPrimitiveEquals(primp, prim::kPrimAllGather)) {
name = ALL_GATHER;
if (!cnode->input(0) || !cnode->input(1)) continue;
PrimitivePtr primp = GetValueNode<PrimitivePtr>(cnode->input(0));
if (!primp->HasAttr(RECOMPUTE) || GetValue<bool>(primp->GetAttr(RECOMPUTE))) continue;
std::pair<AnfNodePtr, bool> param_node_pair = FindParameterWithAllgather(cnode->input(1), func_graph, name);
if (!param_node_pair.first) continue;
parameter_name = ParameterName(param_node_pair.first);
}
if (init || input_size < temp) {
temp -= input_size;
init = false;
} else {
temp = threshold;
fusion++;
}
if (init) {
SetMirrorFusion(cnode, 1, parameter_name);
} else {
SetMirrorFusion(cnode, fusion, parameter_name);
}
init = false;
SetMirrorFusion(cnode, fusion, parameter_name);
}
MS_LOG(INFO) << "Allreduce fusion by size succeed.";
MS_LOG(INFO) << name << " fusion by size succeed.";
return SUCCESS;
}

Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) {
Status AllCommFusion::SetFusionBySizeReduceScatter(const CNodePtr &ret, int64_t threshold, const PrimitivePtr &primp) {
auto filter = [primp](const AnfNodePtr &node) { return !IsPrimitiveCNode(node, primp); };
auto todo = DeepScopedGraphSearchWithFilter(ret, AlwaysInclude, filter);
auto temp = threshold;
int64_t fusion = 1;
bool init = true;
for (auto &node : todo) {
auto cnode = node->cast<CNodePtr>();
if (cnode->input(1) == nullptr) continue;
FuncGraphPtr func_graph = cnode->func_graph();
std::pair<AnfNodePtr, bool> param_node_pair =
FindParameterWithAllgather(cnode->input(1), func_graph, REDUCE_SCATTER);
if (!param_node_pair.first) continue;
auto parameter_name = ParameterName(param_node_pair.first);
auto input_shapes = GetNodeShape(param_node_pair.first);
int64_t input_size = std::accumulate(input_shapes[0].begin(), input_shapes[0].end(), 1, std::multiplies<int64_t>());
if (init || input_size < temp) {
temp -= input_size;
init = false;
} else {
temp = threshold;
fusion++;
}
SetMirrorFusion(cnode, fusion, parameter_name);
}
MS_LOG(INFO) << "Reduce_Scatter fusion by size succeed.";
return SUCCESS;
}
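
The SetFusionBySize and SetFusionBySizeReduceScatter routines above share one grouping rule: walk the matched communication nodes in order, subtract each tensor's element count from a running budget, and start a new fusion group (bumping the fusion index) once the budget is exhausted. A minimal Python sketch of that rule, illustrative only (plain element counts instead of MindSpore graph nodes, hypothetical function name):

def group_by_size(tensor_sizes, threshold):
    # tensor_sizes: element counts in traversal order (a stand-in for the
    # GetNodeShape/accumulate step in the C++ above).
    # threshold: fusion budget in elements, already converted from MB.
    budget = threshold
    fusion = 1
    first = True
    groups = []
    for size in tensor_sizes:
        if first or size < budget:
            budget -= size      # still room in the current group
            first = False
        else:
            budget = threshold  # budget exhausted: open a new group
            fusion += 1
        groups.append(fusion)
    return groups

# With a 100-element budget, four tensors split into two fusion groups.
print(group_by_size([40, 50, 30, 20], 100))  # [1, 1, 2, 2]
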
Status AllCommFusion::ProcessCommOpsFusion(const CNodePtr &ret, const std::string &comm_name) {
if (ret == nullptr) {
MS_LOG(ERROR) << "ret is nullptr.";
return FAILED;

@@ -83,7 +122,7 @@ Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) {
MS_EXCEPTION_IF_NULL(root_graph_);
auto graph_set = ForwardGraph(root_graph_);
if (graph_set.size() > 1) {
MS_LOG(WARNING) << "AllReduce fusion don't support multiple subgraphs now.";
MS_LOG(WARNING) << comm_name << "fusion don't support multiple subgraphs now.";
return SUCCESS;
}
auto forward_graph = *(graph_set.begin());

@@ -91,16 +130,35 @@ Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) {
forward_ret_ = forward_graph->get_return();
MS_EXCEPTION_IF_NULL(forward_ret_);
if (allreduce_graph_.set_head_cnode(forward_ret_) != SUCCESS) {
MS_LOG(ERROR) << "AllreduceGraph set_head_cnode failed.";
MS_LOG(ERROR) << comm_name << "Graph set_head_cnode failed.";
return FAILED;
}
int64_t threshold = ParallelContext::GetInstance()->fusion_threshold_mb() * 1024 * 1024 / 4;
int64_t threshold = 0;
if (comm_name == ALL_REDUCE) {
threshold = ParallelContext::GetInstance()->fusion_threshold_mb();
} else if (comm_name == ALL_GATHER) {
threshold = ParallelContext::GetInstance()->allgather_fusion_threshold_mb();
} else if (comm_name == REDUCE_SCATTER) {
threshold = ParallelContext::GetInstance()->reducescatter_fusion_threshold_mb();
} else {
MS_LOG(ERROR) << " Comm Ops must be ALL_REDUCE, ALL_GATHER or REDUCE_SCATTER, but got " << comm_name;
}
threshold *= DEFAULT_THRESHOLD_MB_TO_BYTE;
if (threshold <= 0) {
MS_LOG(ERROR) << "The threshold of SetFusionBySize must be larger than 0, but got " << threshold << ".";
MS_LOG(ERROR) << "The threshold of" << comm_name << "fusion must be larger than 0, but got " << threshold << ".";
return FAILED;
}
if (comm_name == REDUCE_SCATTER) {
(void)SetFusionBySizeReduceScatter(ret, threshold, prim::kPrimVirtualAssignAdd);
}
if (comm_name == ALL_REDUCE) {
(void)SetFusionBySize(ret, threshold, prim::kPrimMirror);
}
if (comm_name == ALL_GATHER) {
(void)SetFusionBySize(ret, threshold, prim::kPrimMicroStepAllGather);
(void)SetFusionBySize(ret, threshold, prim::kPrimAllGather);
}
(void)SetFusionBySize(ret, threshold);
return SUCCESS;
}
} // namespace parallel
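
One detail worth noting in ProcessCommOpsFusion: the per-op threshold is configured in MB, but SetFusionBySize compares it against tensor element counts, so the code multiplies by DEFAULT_THRESHOLD_MB_TO_BYTE (262144 = 1024 * 1024 / 4, the number of float32 elements in one MB), matching the old inline fusion_threshold_mb() * 1024 * 1024 / 4 expression. A hedged Python sketch of the per-op threshold selection and conversion (the getter names mirror the diff; the dict-based dispatch and string keys are illustrative):

# 1 MB of float32 data holds 1024 * 1024 / 4 = 262144 elements.
MB_TO_FP32_ELEMENTS = 262144

def comm_threshold_in_elements(comm_name, ctx):
    # ctx is assumed to expose the three per-op getters added in this commit.
    getters = {
        "all_reduce": ctx.fusion_threshold_mb,
        "all_gather": ctx.allgather_fusion_threshold_mb,
        "reduce_scatter": ctx.reducescatter_fusion_threshold_mb,
    }
    if comm_name not in getters:
        raise ValueError("comm op must be all_reduce, all_gather or reduce_scatter, got " + comm_name)
    threshold = getters[comm_name]() * MB_TO_FP32_ELEMENTS
    if threshold <= 0:
        raise ValueError("{} fusion threshold must be larger than 0".format(comm_name))
    return threshold
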

@@ -18,6 +18,7 @@
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_FUSION_H_

#include <vector>
#include <string>
#include "utils/hash_map.h"
#include "ir/anf.h"
#include "frontend/parallel/allreduce_fusion/allreduce_graph.h"

@@ -34,22 +35,24 @@ constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_TIME = 0.1;
constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_INHERENT_TIME = 0.1;
constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_BANDWIDTH = 0.1;
constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_COMPUTATION_TIME_PARAMETER = 0.1;
constexpr int64_t DEFAULT_THRESHOLD_MB_TO_BYTE = 262144;

const uint64_t MAX_RECURSIVE_CALL_TIMES = 100;
class AllreduceFusion {
class AllCommFusion {
public:
AllreduceFusion()
AllCommFusion()
: allreduce_graph_(),
ret_(nullptr),
forward_ret_(nullptr),
root_graph_(nullptr),
tail_time_(0),
computation_time_parameter_(0) {}
virtual ~AllreduceFusion() = default;
Status ProcessAllreduceFusion(const CNodePtr &ret);
virtual ~AllCommFusion() = default;
Status ProcessCommOpsFusion(const CNodePtr &ret, const std::string &comm_name);

private:
Status SetFusionBySize(const CNodePtr &ret, int64_t threshold);
Status SetFusionBySize(const CNodePtr &ret, int64_t threshold, const PrimitivePtr &primp);
Status SetFusionBySizeReduceScatter(const CNodePtr &ret, int64_t threshold, const PrimitivePtr &primp);
AllreduceGraph allreduce_graph_;
CNodePtr ret_;
CNodePtr forward_ret_;
@@ -33,12 +33,15 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
bool enable_all_reduce_fusion = ParallelContext::GetInstance()->enable_all_reduce_fusion();
bool enable_all_gather_fusion = ParallelContext::GetInstance()->enable_all_gather_fusion();
bool enable_reduce_scatter_fusion = ParallelContext::GetInstance()->enable_reduce_scatter_fusion();
auto graph_set = ForwardGraph(root);
// assume no change to graph
bool changes = false;
// control whether use model_parallel mode
if (!root->has_flag(AUTO_PARALLEL) || ((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) ||
(!enable_all_reduce_fusion) || (root->has_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY)) || graph_set.size() < 1) {
((!enable_all_reduce_fusion) && (!enable_all_gather_fusion) && (!enable_reduce_scatter_fusion)) ||
(root->has_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY)) || graph_set.size() < 1) {
return changes;
}

@@ -50,7 +53,8 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti
}, end_time{0};
(void)gettimeofday(&start_time, nullptr);
#endif
MS_LOG(INFO) << "Now entering allreduce fusion by size, and allreduce fusion before will be overlapped!";
MS_LOG(INFO) << "Now entering comm ops (allreduce, allgather, reducescatter) fusion by size, and fusion before will "
"be overlapped!";
DumpGraph(root, std::string(ALLREDUCE_FUSION_BEGIN));

pipeline::ResourceBasePtr res = optimizer->resource();

@@ -61,9 +65,15 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti
CNodePtr ret = root->get_return();
MS_EXCEPTION_IF_NULL(ret);

AllreduceFusion allreduce_fusion;
if (allreduce_fusion.ProcessAllreduceFusion(ret) != SUCCESS) {
MS_LOG(EXCEPTION) << "ProcessAllreduceFusion failed";
AllCommFusion allcomm_fusion;
vector<std::string> comm_ops = {ALL_REDUCE, ALL_GATHER, REDUCE_SCATTER};
vector<bool> fusionlist = {enable_all_reduce_fusion, enable_all_gather_fusion, enable_reduce_scatter_fusion};
for (size_t i = 0; i < comm_ops.size(); i++) {
if (fusionlist[i]) {
if (allcomm_fusion.ProcessCommOpsFusion(ret, comm_ops[i]) != SUCCESS) {
MS_LOG(EXCEPTION) << "Process" << comm_ops[i] << "Fusion failed";
}
}
}

DumpGraph(root, std::string(ALLREDUCE_FUSION_END));

@@ -63,6 +63,8 @@ void ParallelContext::Reset() {
parameter_broadcast_ = false;
parameter_broadcast_is_set_ = false;
enable_all_reduce_fusion_ = true;
enable_all_gather_fusion_ = true;
enable_reduce_scatter_fusion_ = true;
strategy_ckpt_load_file_ = "";
strategy_ckpt_save_file_ = "";
enable_parallel_optimizer_ = false;

@@ -79,6 +81,8 @@ void ParallelContext::Reset() {
sharding_propagation_ = false;
dataset_strategy_.clear();
fusion_threshold_mb_ = FUSUION_THRESHOLD;
allgather_fusion_threshold_mb_ = FUSUION_THRESHOLD;
reducescatter_fusion_threshold_mb_ = FUSUION_THRESHOLD;
fusion_threshold_is_set_ = true;
fusion_mode_ = FUSION_AUTO;
}

@@ -94,6 +98,16 @@ void ParallelContext::set_fusion_threshold_mb(int64_t fusion_threshold) {
enable_all_reduce_fusion_ = true;
}

void ParallelContext::set_allgather_fusion_threshold_mb(int64_t fusion_threshold) {
allgather_fusion_threshold_mb_ = fusion_threshold;
enable_all_gather_fusion_ = true;
}

void ParallelContext::set_reducescatter_fusion_threshold_mb(int64_t fusion_threshold) {
reducescatter_fusion_threshold_mb_ = fusion_threshold;
enable_reduce_scatter_fusion_ = true;
}

bool ParallelContext::set_fusion_mode(const std::string &fusion_mode) {
auto iter = std::find(FUSION_MODE_LIST.begin(), FUSION_MODE_LIST.end(), fusion_mode);
if (iter == FUSION_MODE_LIST.end()) {

@@ -85,6 +85,13 @@ class ParallelContext {
void set_fusion_threshold_mb(int64_t fusion_threshold);
int64_t fusion_threshold_mb() const { return fusion_threshold_mb_; }

void set_allgather_fusion_threshold_mb(int64_t allgather_fusion_threshold);
int64_t allgather_fusion_threshold_mb() const { return allgather_fusion_threshold_mb_; }

void set_reducescatter_fusion_threshold_mb(int64_t rs_fusion_threshold);
int64_t reducescatter_fusion_threshold_mb() const { return reducescatter_fusion_threshold_mb_; }

bool set_fusion_mode(const std::string &fusion_mode);
std::string get_fusion_mode() const { return fusion_mode_; }

@@ -123,6 +130,14 @@
enable_all_reduce_fusion_ = enable_all_reduce_fusion;
}
bool enable_all_reduce_fusion() const { return enable_all_reduce_fusion_; }
void set_enable_all_gather_fusion(bool enable_all_gather_fusion) {
enable_all_gather_fusion_ = enable_all_gather_fusion;
}
bool enable_all_gather_fusion() const { return enable_all_gather_fusion_; }
void set_enable_reduce_scatter_fusion(bool enable_reduce_scatter_fusion) {
enable_reduce_scatter_fusion_ = enable_reduce_scatter_fusion;
}
bool enable_reduce_scatter_fusion() const { return enable_reduce_scatter_fusion_; }

void set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file);
std::string strategy_ckpt_load_file() const { return strategy_ckpt_load_file_; }

@@ -170,6 +185,8 @@
bool loss_repeated_mean_;
int64_t device_num_;
int64_t fusion_threshold_mb_;
int64_t allgather_fusion_threshold_mb_;
int64_t reducescatter_fusion_threshold_mb_;  // reducescatter
int64_t global_rank_;
int64_t grad_accumulation_step_;
std::string parallel_mode_;

@@ -181,6 +198,8 @@
bool global_rank_is_set_;
bool parameter_broadcast_is_set_;
bool enable_all_reduce_fusion_;
bool enable_all_gather_fusion_;
bool enable_reduce_scatter_fusion_;
std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_indices_;
std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_sizes_;
std::string strategy_ckpt_load_file_;

@@ -564,7 +564,8 @@ static bool IsOriginWeight(const ParameterPtr &param) {
return true;
}

static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph,
const std::string &name = ALL_REDUCE) {
if (IsValueNode<RefKey>(node)) {
std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph);
if (param_v.size() != 1) {

@@ -572,7 +573,8 @@ static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &no
<< param_v.size();
}
auto param_ptr = param_v[0]->user_data<parallel::TensorLayout>();
if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty()) {
if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty() &&
name == ALL_REDUCE) {
return std::make_pair(nullptr, true);
}
return std::make_pair(node, true);

@@ -580,9 +582,11 @@ static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &no
return std::make_pair(nullptr, false);
}

static std::pair<AnfNodePtr, bool> FindParameterByParameter(const AnfNodePtr &node) {
static std::pair<AnfNodePtr, bool> FindParameterByParameter(const AnfNodePtr &node,
const std::string &name = ALL_REDUCE) {
auto param_ptr = node->user_data<parallel::TensorLayout>();
if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty()) {
if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty() &&
name == ALL_REDUCE) {
return std::make_pair(nullptr, false);
}
return std::make_pair(node, false);

@@ -637,6 +641,34 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
return std::make_pair(nullptr, false);
}

// Used for allgather and reducescatter
std::pair<AnfNodePtr, bool> FindParameterWithAllgather(const AnfNodePtr &node, const FuncGraphPtr &func_graph,
const std::string &name) {
if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {
return std::make_pair(nullptr, false);
}

if (node->isa<Parameter>()) {
return FindParameterByParameter(node, name);
}

if (node->isa<ValueNode>()) {
return FindParameterByValueNode(node, func_graph, name);
}

CNodePtr cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
for (size_t index = 0; index < cnode->inputs().size(); ++index) {
if (index != 1) continue;
auto res = FindParameterWithAllgather(cnode->input(index), func_graph, name);
if (!res.first) {
continue;
}
return res;
}
return std::make_pair(nullptr, false);
}
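
FindParameterWithAllgather walks back from a communication node's data input until it reaches the Parameter (or value node) that the AllGather/ReduceScatter ultimately reads, only ever following input index 1 of each intermediate CNode. A rough Python analogue of that traversal over a toy node structure (the Node class and its fields are illustrative, not MindSpore APIs):

class Node:
    """Toy stand-in for an ANF node: a parameter leaf or an op with inputs."""
    def __init__(self, kind, inputs=None, name=None):
        self.kind = kind            # "parameter", "value" or "cnode"
        self.inputs = inputs or []  # inputs[0] is the primitive, inputs[1] the data input
        self.name = name

def find_parameter_chain(node):
    """Follow input index 1 until a parameter (or nothing) is found."""
    if node.kind == "parameter":
        return node
    if node.kind != "cnode" or len(node.inputs) < 2:
        return None
    return find_parameter_chain(node.inputs[1])

# weight -> Load -> AllGather input: the search skips the intermediate op.
weight = Node("parameter", name="w")
load = Node("cnode", inputs=[Node("value", name="Load"), weight])
assert find_parameter_chain(load) is weight
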

std::unordered_map<std::string, std::shared_ptr<TensorLayout>> AdaSumParamTensorLayout(const FuncGraphPtr &root) {
MS_EXCEPTION_IF_NULL(root);
std::unordered_map<std::string, std::shared_ptr<TensorLayout>> adasum_param_map;

@@ -43,6 +43,8 @@ void HandleFullySplitParameters(const FuncGraphPtr &root);
void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root);
void HandleAdaFactorOpt(const FuncGraphPtr &root);
std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph);
std::pair<AnfNodePtr, bool> FindParameterWithAllgather(const AnfNodePtr &node, const FuncGraphPtr &func_graph,
const std::string &name);
std::unordered_map<std::string, std::shared_ptr<TensorLayout>> AdaSumParamTensorLayout(const FuncGraphPtr &root);
bool HandleAdaSum(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
std::unordered_map<std::string, std::shared_ptr<TensorLayout>> *adasum_param_tensor_layout_map);

@@ -25,6 +25,7 @@
#include <set>
#include <string>
#include <utility>
#include <queue>

#include "utils/hash_map.h"
#include "base/core_ops.h"

@@ -146,7 +146,15 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_device_num", &ParallelContext::set_device_num, "Set device num.")
.def("get_device_num_is_set", &ParallelContext::device_num_is_set, "Get device num is set.")
.def("set_fusion_threshold_mb", &ParallelContext::set_fusion_threshold_mb, "Set fusion threshold.")
.def("fusion_threshold_mb", &ParallelContext::fusion_threshold_mb, "Get fusion threshold.")
.def("set_allgather_fusion_threshold_mb", &ParallelContext::set_allgather_fusion_threshold_mb,
"Set allgather fusion threshold.")
.def("set_reducescatter_fusion_threshold_mb", &ParallelContext::set_reducescatter_fusion_threshold_mb,
"Set reducescatter fusion threshold.")
.def("fusion_threshold_mb", &ParallelContext::fusion_threshold_mb, "Get allreduce fusion threshold.")
.def("allgather_fusion_threshold_mb", &ParallelContext::allgather_fusion_threshold_mb,
"Get allgather fusion threshold.")
.def("reducescatter_fusion_threshold_mb", &ParallelContext::reducescatter_fusion_threshold_mb,
"Get reduce_scatter fusion threshold.")
.def("set_fusion_mode", &ParallelContext::set_fusion_mode, "Get fusion mode.")
.def("get_fusion_mode", &ParallelContext::get_fusion_mode, "Get fusion mode.")
.def("get_global_rank", &ParallelContext::global_rank, "Get global rank.")

@@ -178,6 +186,14 @@ PYBIND11_MODULE(_c_expression, m) {
"Set enable/disable all reduce fusion.")
.def("get_enable_all_reduce_fusion", &ParallelContext::enable_all_reduce_fusion,
"Get enable/disable all reduce fusion.")
.def("set_enable_all_gather_fusion", &ParallelContext::set_enable_all_gather_fusion,
"Set enable/disable all gather fusion.")
.def("get_enable_all_gather_fusion", &ParallelContext::enable_all_gather_fusion,
"Get enable/disable all gather fusion.")
.def("set_enable_reduce_scatter_fusion", &ParallelContext::set_enable_reduce_scatter_fusion,
"Set enable/disable reduce scatter fusion.")
.def("get_enable_reduce_scatter_fusion", &ParallelContext::enable_reduce_scatter_fusion,
"Get enable/disable reduce scatter fusion.")
.def("get_parameter_broadcast", &ParallelContext::parameter_broadcast, "Get parameter broadcast.")
.def("get_parameter_broadcast_is_set", &ParallelContext::parameter_broadcast_is_set,
"Get parameter broadcast is set.")

@@ -32,6 +32,8 @@ class _ParallelFusionConfig:
The key of the Parallel fusion method configuration.
"""
ALLREDUCE = "allreduce"
ALLGATHER = "allgather"
REDUCESCATTER = "reducescatter"
MODE = "mode"
FUSION_CONFIG = "config"
AUTO = "auto"

@@ -116,8 +118,12 @@ class _AutoParallelContext:
for key in list(config.keys()):
if key == _ParallelFusionConfig.ALLREDUCE:
self._set_allreduce_comm_fusion(config[key])
elif key == _ParallelFusionConfig.ALLGATHER:
self._set_allgather_comm_fusion(config[key], key)
elif key == _ParallelFusionConfig.REDUCESCATTER:
self._set_allgather_comm_fusion(config[key], key)
else:
raise KeyError("comm fusion type must be allreduce, but got {}".format(key))
raise KeyError("comm fusion type must be allreduce, allgather or reducescatter, but got {}".format(key))

def _set_allreduce_comm_fusion(self, comm_fusion):
"""

@@ -154,6 +160,42 @@ class _AutoParallelContext:
if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.INDEX:
self.set_all_reduce_fusion_split_indices(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG])

def _set_allgather_comm_fusion(self, comm_fusion, comm_type="allgather"):
"""
Set allgather and reducescatter fusion method for auto parallel.

Args:
comm_fusion (dict): A dict contains the methods and values for setting the fusion method. Currently it
supports four fusion methods: `auto` and `size`.
comm_type (str): The name of the communication operator, `allgather` or `reducescatter`.

Raises:
KeyError: When key of comm_fusion is not 'mode' or 'config'.
KeyError: When `mode` is not 'auto', 'size'.
"""
self.check_context_handle()
if comm_type == "allgather" and not self.get_enable_all_gather_fusion():
return
if comm_type == "reducescatter" and not self.get_enable_reduce_scatter_fusion():
return
if not isinstance(comm_fusion, dict):
raise TypeError("For 'comm_fusion', {} config must be dict, but got the type : {}.".format(
comm_type, type(comm_fusion)))
if _ParallelFusionConfig.MODE not in comm_fusion:
raise KeyError("For 'comm_fusion', the key 'mode' should be contained.")
if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion:
raise KeyError("For 'comm_fusion', the key 'config' should be contained.")
check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.SIZE]
if comm_fusion[_ParallelFusionConfig.MODE] in check_mode:
self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE])
else:
raise KeyError("fusion method mode must be auto or size, but got {}".format(
comm_fusion[_ParallelFusionConfig.MODE]))

fusion_threshold = 64 if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.AUTO else \
comm_fusion[_ParallelFusionConfig.FUSION_CONFIG]
self.set_fusion_threshold_mb(fusion_threshold, comm_type)

def get_comm_fusion(self):
"""Get comm fusion config."""
self.check_context_handle()

@@ -166,12 +208,13 @@ class _AutoParallelContext:
_ParallelFusionConfig.FUSION_CONFIG: config}}

def set_fusion_threshold_mb(self, fusion_threshold=64):
def set_fusion_threshold_mb(self, fusion_threshold=64, comm_type="allreduce"):
"""
Set fusion threshold (MB) for auto parallel.

Args:
fusion_threshold (int): The fusion threshold (unit: MB). Default: 64.
comm_type (str): The name of the communication operator, `allreduce`, `allgather` or `reducescatter`.

Raises:
ValueError: If the fusion threshold is not in [0, +inf].

@@ -179,13 +222,30 @@ class _AutoParallelContext:
self.check_context_handle()
if fusion_threshold < 0:
raise ValueError("fusion threshold must be larger than 0, but got {}".format(fusion_threshold))
self._context_handle.set_fusion_threshold_mb(fusion_threshold)

if comm_type == _ParallelFusionConfig.ALLREDUCE:
self._context_handle.set_fusion_threshold_mb(fusion_threshold)
if comm_type == _ParallelFusionConfig.ALLGATHER:
self._context_handle.set_allgather_fusion_threshold_mb(fusion_threshold)
if comm_type == _ParallelFusionConfig.REDUCESCATTER:
self._context_handle.set_reducescatter_fusion_threshold_mb(fusion_threshold)

def fusion_threshold_mb(self):
"""Get device num."""
"""Get all reduce threshold."""
self.check_context_handle()
return self._context_handle.fusion_threshold_mb()

def allgather_fusion_threshold_mb(self):
"""Get allgather threshold."""
self.check_context_handle()
return self._context_handle.allgather_fusion_threshold_mb()

def reducescatter_fusion_threshold_mb(self):
"""Get reducescatter threshold."""
self.check_context_handle()
return self._context_handle.reducescatter_fusion_threshold_mb()

def set_global_rank(self, global_rank):
"""
Set global rank for auto parallel.

@@ -624,15 +684,53 @@ class _AutoParallelContext:
self.check_context_handle()
if not isinstance(enable_all_reduce_fusion, bool):
raise TypeError("For 'set_auto_parallel_context().set_enable_all_reduce_fusion', "
"the argument 'enable_all_reduce_fusion' must be bool, but got the type : {}."
"the argument 'enable_fusion' must be bool, but got the type : {}."
.format(type(enable_all_reduce_fusion)))
self._context_handle.set_enable_all_reduce_fusion(enable_all_reduce_fusion)

def set_enable_all_gather_fusion(self, enable_all_gather_fusion):
"""
Set enable/disable all gather fusion.

Args:
enable_all_gather_fusion (bool): Enable/disable all gather fusion.
"""
self.check_context_handle()
if not isinstance(enable_all_gather_fusion, bool):
raise TypeError("For 'set_auto_parallel_context().set_enable_all_gather_fusion', "
"the argument 'enable_fusion' must be bool, but got the type : {}."
.format(type(enable_all_gather_fusion)))
self._context_handle.set_enable_all_gather_fusion(enable_all_gather_fusion)

def set_enable_reduce_scatter_fusion(self, enable_reduce_scatter_fusion):
"""
Set enable/disable reduce scatter fusion.

Args:
enable_reduce_scatter_fusion (bool): Enable/disable reduce scatter fusion.
"""
self.check_context_handle()
if not isinstance(enable_reduce_scatter_fusion, bool):
raise TypeError("For 'set_auto_parallel_context().set_enable_reduce_scatter_fusion', "
"the argument 'enable_fusion' must be bool, but got the type : {}."
.format(type(enable_reduce_scatter_fusion)))
self._context_handle.set_enable_reduce_scatter_fusion(enable_reduce_scatter_fusion)

def get_enable_all_reduce_fusion(self):
"""Get all reduce fusion flag."""
self.check_context_handle()
return self._context_handle.get_enable_all_reduce_fusion()

def get_enable_all_gather_fusion(self):
"""Get all gather fusion flag."""
self.check_context_handle()
return self._context_handle.get_enable_all_gather_fusion()

def get_enable_reduce_scatter_fusion(self):
"""Get reduce scatter flag."""
self.check_context_handle()
return self._context_handle.get_enable_reduce_scatter_fusion()
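
The enable/disable flags above let a user keep one of the communication ops out of size-based fusion entirely. A short hedged usage note, using the internal helper exposed in this diff (internal API, shown only for illustration):

from mindspore.parallel._auto_parallel_context import auto_parallel_context

# Leave allreduce fusion on, but opt reducescatter out of size-based fusion.
auto_parallel_context().set_enable_reduce_scatter_fusion(False)
assert not auto_parallel_context().get_enable_reduce_scatter_fusion()
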

def get_device_num_is_set(self):
"""Get device number is set or not."""
self.check_context_handle()

@@ -0,0 +1,148 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.train.model import Model
from mindspore.nn.wrap.cell_wrapper import PipelineCell
from mindspore.parallel._auto_parallel_context import auto_parallel_context


class DatasetLenet():
    def __init__(self, data, label, length=3):
        self.data = data
        self.label = label
        self.index = 1
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.data, self.label

    def reset(self):
        self.index = 0

    @staticmethod
    def get_dataset_size():
        return 32

    @staticmethod
    def get_repeat_count():
        return 1

    @staticmethod
    def get_batch_size():
        return 32

    def create_tuple_iterator(self, num_epochs=1, do_copy=True):
        return self


class MatMulCell(nn.Cell):
    def __init__(self, strategy1, strategy2, param=None):
        super().__init__()
        self.param = Parameter(initializer("zeros", [64, 64]), name="param")
        if param is not None:
            self.param = param
        self.param1 = Parameter(initializer("zeros", [64, 64]), name="param1")
        self.matmul = P.MatMul().shard(strategy1)
        self.matmul1 = P.MatMul().shard(strategy2)

    def construct(self, x):
        out = self.matmul(x, self.param)
        out = self.matmul1(out, self.param1)
        return out


class Net(nn.Cell):
    def __init__(self, strategy1, strategy2, param=None):
        super().__init__()
        self.block = nn.CellList()
        for i in range(2):
            cell = MatMulCell(strategy1, strategy2, param)
            cell.pipeline_stage = i
            self.block.append(cell)

    def construct(self, x):
        for i in range(2):
            x = self.block[i](x)
        return x


class PipelineSplit(nn.Cell):
    def __init__(self, strategy1, strategy2):
        super().__init__()
        self.cell = Net(strategy1, strategy2)
        self.cell.block[0].matmul.add_prim_attr("parameter_start", 0)

    def construct(self, x, label):
        x = self.cell(x)
        return x


def test_fusion_size():
    """
    Feature: test_fusion_auto in size mode
    Description: allgather and reduce scatter fusion in size mode
    Expectation: success
    """
    allgather_threshold = 8
    reducescatter_threshold = 16
    comm_fusion_dict = {"allgather": {"mode": "size", "config": allgather_threshold},
                        "reducescatter": {"mode": "size", "config": reducescatter_threshold}}
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", comm_fusion=comm_fusion_dict)
    context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2, enable_parallel_optimizer=True)
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64]), dtype=ms.float32)
    strategy1 = ((4, 1), (1, 1))
    strategy2 = ((2, 1), (1, 1))
    net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(net.trainable_params(), learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
    assert auto_parallel_context().allgather_fusion_threshold_mb() == allgather_threshold
    assert auto_parallel_context().reducescatter_fusion_threshold_mb() == reducescatter_threshold


def test_fusion_auto():
    """
    Feature: test_fusion_auto in auto mode
    Description: allgather and reduce scatter fusion in auto mode
    Expectation: success
    """
    comm_fusion_dict = {"allgather": {"mode": "auto", "config": None},
                        "reducescatter": {"mode": "auto", "config": None}}
    context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2, enable_parallel_optimizer=True,
                                      parallel_mode="semi_auto_parallel", comm_fusion=comm_fusion_dict)
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64]), dtype=ms.float32)
    strategy1 = ((4, 1), (1, 1))
    strategy2 = ((2, 1), (1, 1))
    net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(net.trainable_params(), learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
    assert auto_parallel_context().allgather_fusion_threshold_mb() == 64
    assert auto_parallel_context().reducescatter_fusion_threshold_mb() == 64