!30431 allreduce allgather fusion

Merge pull request !30431 from jiahongQian/master
This commit is contained in:
i-robot 2022-02-23 08:52:40 +00:00 committed by Gitee
commit 14393503b7
11 changed files with 441 additions and 40 deletions

View File

@ -39,41 +39,80 @@ void SetMirrorFusion(const CNodePtr &mirror_cnode, int64_t fusion, const std::st
(void)node_prim->AddAttr(PARAMETER, MakeValue(std::make_shared<StringImm>(parameter_name)));
}
Status AllreduceFusion::SetFusionBySize(const CNodePtr &ret, int64_t threshold) {
auto filter = [](const AnfNodePtr &node) { return !IsPrimitiveCNode(node, prim::kPrimMirror); };
Status AllCommFusion::SetFusionBySize(const CNodePtr &ret, int64_t threshold, const PrimitivePtr &primp) {
auto filter = [primp](const AnfNodePtr &node) { return !IsPrimitiveCNode(node, primp); };
auto todo = DeepScopedGraphSearchWithFilter(ret, AlwaysInclude, filter);
auto temp = threshold;
int64_t fusion = 1;
bool init = true;
std::string parameter_name;
std::string name;
for (auto &node : todo) {
auto cnode = node->cast<CNodePtr>();
if (cnode->input(1)->Shape() == nullptr) continue;
auto input_shapes = GetNodeShape(cnode->input(1));
int64_t input_size = std::accumulate(input_shapes[0].begin(), input_shapes[0].end(), 1, std::multiplies<int64_t>());
FuncGraphPtr func_graph = cnode->func_graph();
std::pair<AnfNodePtr, bool> param_node_pair = FindParameter(cnode->input(1), func_graph);
if (!param_node_pair.first) {
continue;
if (IsPrimitiveEquals(primp, prim::kPrimMirror)) {
name = ALL_REDUCE;
std::pair<AnfNodePtr, bool> param_node_pair = FindParameter(cnode->input(1), func_graph);
if (!param_node_pair.first) continue;
parameter_name = ParameterName(param_node_pair.first);
}
auto parameter_name = ParameterName(param_node_pair.first);
if (input_size < temp) {
if (IsPrimitiveEquals(primp, prim::kPrimMicroStepAllGather) || IsPrimitiveEquals(primp, prim::kPrimAllGather)) {
name = ALL_GATHER;
if (!cnode->input(0) || !cnode->input(1)) continue;
PrimitivePtr primp = GetValueNode<PrimitivePtr>(cnode->input(0));
if (!primp->HasAttr(RECOMPUTE) || GetValue<bool>(primp->GetAttr(RECOMPUTE))) continue;
std::pair<AnfNodePtr, bool> param_node_pair = FindParameterWithAllgather(cnode->input(1), func_graph, name);
if (!param_node_pair.first) continue;
parameter_name = ParameterName(param_node_pair.first);
}
if (init || input_size < temp) {
temp -= input_size;
init = false;
} else {
temp = threshold;
fusion++;
}
if (init) {
SetMirrorFusion(cnode, 1, parameter_name);
} else {
SetMirrorFusion(cnode, fusion, parameter_name);
}
init = false;
SetMirrorFusion(cnode, fusion, parameter_name);
}
MS_LOG(INFO) << "Allreduce fusion by size succeed.";
MS_LOG(INFO) << name << " fusion by size succeed.";
return SUCCESS;
}
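
For readers skimming the diff: the grouping logic in SetFusionBySize is a greedy single pass. Each operator's parameter size (in elements) is subtracted from a running budget, and once the budget is exhausted the next operator opens a new fusion group. A minimal Python sketch of that bookkeeping, for illustration only (the function and variable names here are made up, not MindSpore API):

from functools import reduce
from operator import mul

def group_by_size(param_shapes, threshold_elems):
    """Assign a 1-based fusion index to each parameter, mirroring the
    temp/fusion/init bookkeeping in SetFusionBySize above."""
    budget = threshold_elems
    fusion = 1
    first = True
    groups = []
    for shape in param_shapes:
        size = reduce(mul, shape, 1)      # element count of the parameter
        if first or size < budget:
            budget -= size                # still fits into the open group
            first = False
        else:
            budget = threshold_elems      # budget exhausted: start a new group
            fusion += 1
        groups.append(fusion)
    return groups

# Example: with a 1 MB threshold (262144 fp32 elements), three 64x64
# parameters (4096 elements each) all receive fusion index 1.
print(group_by_size([[64, 64], [64, 64], [64, 64]], 262144))  # [1, 1, 1]
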
Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) {
Status AllCommFusion::SetFusionBySizeReduceScatter(const CNodePtr &ret, int64_t threshold, const PrimitivePtr &primp) {
auto filter = [primp](const AnfNodePtr &node) { return !IsPrimitiveCNode(node, primp); };
auto todo = DeepScopedGraphSearchWithFilter(ret, AlwaysInclude, filter);
auto temp = threshold;
int64_t fusion = 1;
bool init = true;
for (auto &node : todo) {
auto cnode = node->cast<CNodePtr>();
if (cnode->input(1) == nullptr) continue;
FuncGraphPtr func_graph = cnode->func_graph();
std::pair<AnfNodePtr, bool> param_node_pair =
FindParameterWithAllgather(cnode->input(1), func_graph, REDUCE_SCATTER);
if (!param_node_pair.first) continue;
auto parameter_name = ParameterName(param_node_pair.first);
auto input_shapes = GetNodeShape(param_node_pair.first);
int64_t input_size = std::accumulate(input_shapes[0].begin(), input_shapes[0].end(), 1, std::multiplies<int64_t>());
if (init || input_size < temp) {
temp -= input_size;
init = false;
} else {
temp = threshold;
fusion++;
}
SetMirrorFusion(cnode, fusion, parameter_name);
}
MS_LOG(INFO) << "Reduce_Scatter fusion by size succeed.";
return SUCCESS;
}
Status AllCommFusion::ProcessCommOpsFusion(const CNodePtr &ret, const std::string &comm_name) {
if (ret == nullptr) {
MS_LOG(ERROR) << "ret is nullptr.";
return FAILED;
@ -83,7 +122,7 @@ Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) {
MS_EXCEPTION_IF_NULL(root_graph_);
auto graph_set = ForwardGraph(root_graph_);
if (graph_set.size() > 1) {
MS_LOG(WARNING) << "AllReduce fusion don't support multiple subgraphs now.";
MS_LOG(WARNING) << comm_name << " fusion doesn't support multiple subgraphs now.";
return SUCCESS;
}
auto forward_graph = *(graph_set.begin());
@ -91,16 +130,35 @@ Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) {
forward_ret_ = forward_graph->get_return();
MS_EXCEPTION_IF_NULL(forward_ret_);
if (allreduce_graph_.set_head_cnode(forward_ret_) != SUCCESS) {
MS_LOG(ERROR) << "AllreduceGraph set_head_cnode failed.";
MS_LOG(ERROR) << comm_name << " graph set_head_cnode failed.";
return FAILED;
}
int64_t threshold = ParallelContext::GetInstance()->fusion_threshold_mb() * 1024 * 1024 / 4;
int64_t threshold = 0;
if (comm_name == ALL_REDUCE) {
threshold = ParallelContext::GetInstance()->fusion_threshold_mb();
} else if (comm_name == ALL_GATHER) {
threshold = ParallelContext::GetInstance()->allgather_fusion_threshold_mb();
} else if (comm_name == REDUCE_SCATTER) {
threshold = ParallelContext::GetInstance()->reducescatter_fusion_threshold_mb();
} else {
MS_LOG(ERROR) << " Comm Ops must be ALL_REDUCE, ALL_GATHER or REDUCE_SCATTER, but got " << comm_name;
}
threshold *= DEFAULT_THRESHOLD_MB_TO_BYTE;
if (threshold <= 0) {
MS_LOG(ERROR) << "The threshold of SetFusionBySize must be larger than 0, but got " << threshold << ".";
MS_LOG(ERROR) << "The threshold of" << comm_name << "fusion must be larger than 0, but got " << threshold << ".";
return FAILED;
}
if (comm_name == REDUCE_SCATTER) {
(void)SetFusionBySizeReduceScatter(ret, threshold, prim::kPrimVirtualAssignAdd);
}
if (comm_name == ALL_REDUCE) {
(void)SetFusionBySize(ret, threshold, prim::kPrimMirror);
}
if (comm_name == ALL_GATHER) {
(void)SetFusionBySize(ret, threshold, prim::kPrimMicroStepAllGather);
(void)SetFusionBySize(ret, threshold, prim::kPrimAllGather);
}
(void)SetFusionBySize(ret, threshold);
return SUCCESS;
}
} // namespace parallel
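
A note on units in ProcessCommOpsFusion: the per-op setting read from ParallelContext is in MB, and DEFAULT_THRESHOLD_MB_TO_BYTE is 262144 = 1024 * 1024 / 4, so (despite the BYTE in its name) the multiplication turns MB into a count of 4-byte fp32 elements, matching the inline `* 1024 * 1024 / 4` it replaces and the element counts compared against it in SetFusionBySize. A tiny Python check of that arithmetic (the helper name is hypothetical):

def threshold_in_fp32_elements(threshold_mb):
    # 262144 == 1024 * 1024 / 4: one MB expressed as a number of 4-byte floats
    return threshold_mb * 262144

assert threshold_in_fp32_elements(64) == 64 * 1024 * 1024 // 4  # default 64 MB
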

View File

@ -18,6 +18,7 @@
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_FUSION_H_
#include <vector>
#include <string>
#include "utils/hash_map.h"
#include "ir/anf.h"
#include "frontend/parallel/allreduce_fusion/allreduce_graph.h"
@ -34,22 +35,24 @@ constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_TIME = 0.1;
constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_INHERENT_TIME = 0.1;
constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_BANDWIDTH = 0.1;
constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_COMPUTATION_TIME_PARAMETER = 0.1;
constexpr int64_t DEFAULT_THRESHOLD_MB_TO_BYTE = 262144;
const uint64_t MAX_RECURSIVE_CALL_TIMES = 100;
class AllreduceFusion {
class AllCommFusion {
public:
AllreduceFusion()
AllCommFusion()
: allreduce_graph_(),
ret_(nullptr),
forward_ret_(nullptr),
root_graph_(nullptr),
tail_time_(0),
computation_time_parameter_(0) {}
virtual ~AllreduceFusion() = default;
Status ProcessAllreduceFusion(const CNodePtr &ret);
virtual ~AllCommFusion() = default;
Status ProcessCommOpsFusion(const CNodePtr &ret, const std::string &comm_name);
private:
Status SetFusionBySize(const CNodePtr &ret, int64_t threshold);
Status SetFusionBySize(const CNodePtr &ret, int64_t threshold, const PrimitivePtr &primp);
Status SetFusionBySizeReduceScatter(const CNodePtr &ret, int64_t threshold, const PrimitivePtr &primp);
AllreduceGraph allreduce_graph_;
CNodePtr ret_;
CNodePtr forward_ret_;

View File

@ -33,12 +33,15 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
bool enable_all_reduce_fusion = ParallelContext::GetInstance()->enable_all_reduce_fusion();
bool enable_all_gather_fusion = ParallelContext::GetInstance()->enable_all_gather_fusion();
bool enable_reduce_scatter_fusion = ParallelContext::GetInstance()->enable_reduce_scatter_fusion();
auto graph_set = ForwardGraph(root);
// assume no change to graph
bool changes = false;
// control whether use model_parallel mode
if (!root->has_flag(AUTO_PARALLEL) || ((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) ||
(!enable_all_reduce_fusion) || (root->has_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY)) || graph_set.size() < 1) {
((!enable_all_reduce_fusion) && (!enable_all_gather_fusion) && (!enable_reduce_scatter_fusion)) ||
(root->has_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY)) || graph_set.size() < 1) {
return changes;
}
@ -50,7 +53,8 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti
}, end_time{0};
(void)gettimeofday(&start_time, nullptr);
#endif
MS_LOG(INFO) << "Now entering allreduce fusion by size, and allreduce fusion before will be overlapped!";
MS_LOG(INFO) << "Now entering comm ops (allreduce, allgather, reducescatter) fusion by size, and fusion before will "
"be overlapped!";
DumpGraph(root, std::string(ALLREDUCE_FUSION_BEGIN));
pipeline::ResourceBasePtr res = optimizer->resource();
@ -61,9 +65,15 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti
CNodePtr ret = root->get_return();
MS_EXCEPTION_IF_NULL(ret);
AllreduceFusion allreduce_fusion;
if (allreduce_fusion.ProcessAllreduceFusion(ret) != SUCCESS) {
MS_LOG(EXCEPTION) << "ProcessAllreduceFusion failed";
AllCommFusion allcomm_fusion;
vector<std::string> comm_ops = {ALL_REDUCE, ALL_GATHER, REDUCE_SCATTER};
vector<bool> fusionlist = {enable_all_reduce_fusion, enable_all_gather_fusion, enable_reduce_scatter_fusion};
for (size_t i = 0; i < comm_ops.size(); i++) {
if (fusionlist[i]) {
if (allcomm_fusion.ProcessCommOpsFusion(ret, comm_ops[i]) != SUCCESS) {
MS_LOG(EXCEPTION) << "Process" << comm_ops[i] << "Fusion failed";
}
}
}
DumpGraph(root, std::string(ALLREDUCE_FUSION_END));

View File

@ -63,6 +63,8 @@ void ParallelContext::Reset() {
parameter_broadcast_ = false;
parameter_broadcast_is_set_ = false;
enable_all_reduce_fusion_ = true;
enable_all_gather_fusion_ = true;
enable_reduce_scatter_fusion_ = true;
strategy_ckpt_load_file_ = "";
strategy_ckpt_save_file_ = "";
enable_parallel_optimizer_ = false;
@ -79,6 +81,8 @@ void ParallelContext::Reset() {
sharding_propagation_ = false;
dataset_strategy_.clear();
fusion_threshold_mb_ = FUSUION_THRESHOLD;
allgather_fusion_threshold_mb_ = FUSUION_THRESHOLD;
reducescatter_fusion_threshold_mb_ = FUSUION_THRESHOLD;
fusion_threshold_is_set_ = true;
fusion_mode_ = FUSION_AUTO;
}
@ -94,6 +98,16 @@ void ParallelContext::set_fusion_threshold_mb(int64_t fusion_threshold) {
enable_all_reduce_fusion_ = true;
}
void ParallelContext::set_allgather_fusion_threshold_mb(int64_t fusion_threshold) {
allgather_fusion_threshold_mb_ = fusion_threshold;
enable_all_gather_fusion_ = true;
}
void ParallelContext::set_reducescatter_fusion_threshold_mb(int64_t fusion_threshold) {
reducescatter_fusion_threshold_mb_ = fusion_threshold;
enable_reduce_scatter_fusion_ = true;
}
bool ParallelContext::set_fusion_mode(const std::string &fusion_mode) {
auto iter = std::find(FUSION_MODE_LIST.begin(), FUSION_MODE_LIST.end(), fusion_mode);
if (iter == FUSION_MODE_LIST.end()) {

View File

@ -85,6 +85,13 @@ class ParallelContext {
void set_fusion_threshold_mb(int64_t fusion_threshold);
int64_t fusion_threshold_mb() const { return fusion_threshold_mb_; }
void set_allgather_fusion_threshold_mb(int64_t allgather_fusion_threshold);
int64_t allgather_fusion_threshold_mb() const { return allgather_fusion_threshold_mb_; }
void set_reducescatter_fusion_threshold_mb(int64_t rs_fusion_threshold);
int64_t reducescatter_fusion_threshold_mb() const { return reducescatter_fusion_threshold_mb_; }
bool set_fusion_mode(const std::string &fusion_mode);
std::string get_fusion_mode() const { return fusion_mode_; }
@ -123,6 +130,14 @@ class ParallelContext {
enable_all_reduce_fusion_ = enable_all_reduce_fusion;
}
bool enable_all_reduce_fusion() const { return enable_all_reduce_fusion_; }
void set_enable_all_gather_fusion(bool enable_all_gather_fusion) {
enable_all_gather_fusion_ = enable_all_gather_fusion;
}
bool enable_all_gather_fusion() const { return enable_all_gather_fusion_; }
void set_enable_reduce_scatter_fusion(bool enable_reduce_scatter_fusion) {
enable_reduce_scatter_fusion_ = enable_reduce_scatter_fusion;
}
bool enable_reduce_scatter_fusion() const { return enable_reduce_scatter_fusion_; }
void set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file);
std::string strategy_ckpt_load_file() const { return strategy_ckpt_load_file_; }
@ -170,6 +185,8 @@ class ParallelContext {
bool loss_repeated_mean_;
int64_t device_num_;
int64_t fusion_threshold_mb_;
int64_t allgather_fusion_threshold_mb_;
int64_t reducescatter_fusion_threshold_mb_; // reducescatter
int64_t global_rank_;
int64_t grad_accumulation_step_;
std::string parallel_mode_;
@ -181,6 +198,8 @@ class ParallelContext {
bool global_rank_is_set_;
bool parameter_broadcast_is_set_;
bool enable_all_reduce_fusion_;
bool enable_all_gather_fusion_;
bool enable_reduce_scatter_fusion_;
std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_indices_;
std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_sizes_;
std::string strategy_ckpt_load_file_;

View File

@ -564,7 +564,8 @@ static bool IsOriginWeight(const ParameterPtr &param) {
return true;
}
static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph,
const std::string &name = ALL_REDUCE) {
if (IsValueNode<RefKey>(node)) {
std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph);
if (param_v.size() != 1) {
@ -572,7 +573,8 @@ static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &no
<< param_v.size();
}
auto param_ptr = param_v[0]->user_data<parallel::TensorLayout>();
if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty()) {
if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty() &&
name == ALL_REDUCE) {
return std::make_pair(nullptr, true);
}
return std::make_pair(node, true);
@ -580,9 +582,11 @@ static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &no
return std::make_pair(nullptr, false);
}
static std::pair<AnfNodePtr, bool> FindParameterByParameter(const AnfNodePtr &node) {
static std::pair<AnfNodePtr, bool> FindParameterByParameter(const AnfNodePtr &node,
const std::string &name = ALL_REDUCE) {
auto param_ptr = node->user_data<parallel::TensorLayout>();
if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty()) {
if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty() &&
name == ALL_REDUCE) {
return std::make_pair(nullptr, false);
}
return std::make_pair(node, false);
@ -637,6 +641,34 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
return std::make_pair(nullptr, false);
}
// Used for allgather and reducescatter
std::pair<AnfNodePtr, bool> FindParameterWithAllgather(const AnfNodePtr &node, const FuncGraphPtr &func_graph,
const std::string &name) {
if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {
return std::make_pair(nullptr, false);
}
if (node->isa<Parameter>()) {
return FindParameterByParameter(node, name);
}
if (node->isa<ValueNode>()) {
return FindParameterByValueNode(node, func_graph, name);
}
CNodePtr cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
for (size_t index = 0; index < cnode->inputs().size(); ++index) {
if (index != 1) continue;
auto res = FindParameterWithAllgather(cnode->input(index), func_graph, name);
if (!res.first) {
continue;
}
return res;
}
return std::make_pair(nullptr, false);
}
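
FindParameterWithAllgather differs from FindParameter in that it walks back only through each CNode's first data input (index 1) until it resolves a Parameter or ValueNode. A toy Python model of that walk, self-contained for illustration; it ignores the ValueNode/RefKey branch and the recompute/opt-shard checks, and none of these names exist in MindSpore:

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Node:
    kind: str                          # "parameter", "value", or "cnode"
    name: str = ""
    inputs: List["Node"] = field(default_factory=list)

def find_parameter_through_first_input(node: Node) -> Optional[Node]:
    """Follow input(1) of each CNode until a Parameter is reached."""
    if node.kind == "parameter":
        return node
    if node.kind != "cnode" or len(node.inputs) < 2:
        return None                    # dead end: nothing to trace back to
    return find_parameter_through_first_input(node.inputs[1])

# Example: AllGather(prim, param) resolves to param.
param = Node("parameter", "w0")
allgather = Node("cnode", "AllGather", [Node("value", "prim"), param])
assert find_parameter_through_first_input(allgather) is param
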
std::unordered_map<std::string, std::shared_ptr<TensorLayout>> AdaSumParamTensorLayout(const FuncGraphPtr &root) {
MS_EXCEPTION_IF_NULL(root);
std::unordered_map<std::string, std::shared_ptr<TensorLayout>> adasum_param_map;

View File

@ -43,6 +43,8 @@ void HandleFullySplitParameters(const FuncGraphPtr &root);
void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root);
void HandleAdaFactorOpt(const FuncGraphPtr &root);
std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph);
std::pair<AnfNodePtr, bool> FindParameterWithAllgather(const AnfNodePtr &node, const FuncGraphPtr &func_graph,
const std::string &name);
std::unordered_map<std::string, std::shared_ptr<TensorLayout>> AdaSumParamTensorLayout(const FuncGraphPtr &root);
bool HandleAdaSum(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
std::unordered_map<std::string, std::shared_ptr<TensorLayout>> *adasum_param_tensor_layout_map);

View File

@ -25,6 +25,7 @@
#include <set>
#include <string>
#include <utility>
#include <queue>
#include "utils/hash_map.h"
#include "base/core_ops.h"

View File

@ -146,7 +146,15 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_device_num", &ParallelContext::set_device_num, "Set device num.")
.def("get_device_num_is_set", &ParallelContext::device_num_is_set, "Get device num is set.")
.def("set_fusion_threshold_mb", &ParallelContext::set_fusion_threshold_mb, "Set fusion threshold.")
.def("fusion_threshold_mb", &ParallelContext::fusion_threshold_mb, "Get fusion threshold.")
.def("set_allgather_fusion_threshold_mb", &ParallelContext::set_allgather_fusion_threshold_mb,
"Set allgather fusion threshold.")
.def("set_reducescatter_fusion_threshold_mb", &ParallelContext::set_reducescatter_fusion_threshold_mb,
"Set reducescatter fusion threshold.")
.def("fusion_threshold_mb", &ParallelContext::fusion_threshold_mb, "Get allreduce fusion threshold.")
.def("allgather_fusion_threshold_mb", &ParallelContext::allgather_fusion_threshold_mb,
"Get allgather fusion threshold.")
.def("reducescatter_fusion_threshold_mb", &ParallelContext::reducescatter_fusion_threshold_mb,
"Get reduce_scatter fusion threshold.")
.def("set_fusion_mode", &ParallelContext::set_fusion_mode, "Get fusion mode.")
.def("get_fusion_mode", &ParallelContext::get_fusion_mode, "Get fusion mode.")
.def("get_global_rank", &ParallelContext::global_rank, "Get global rank.")
@ -178,6 +186,14 @@ PYBIND11_MODULE(_c_expression, m) {
"Set enable/disable all reduce fusion.")
.def("get_enable_all_reduce_fusion", &ParallelContext::enable_all_reduce_fusion,
"Get enable/disable all reduce fusion.")
.def("set_enable_all_gather_fusion", &ParallelContext::set_enable_all_gather_fusion,
"Set enable/disable all gather fusion.")
.def("get_enable_all_gather_fusion", &ParallelContext::enable_all_gather_fusion,
"Get enable/disable all gather fusion.")
.def("set_enable_reduce_scatter_fusion", &ParallelContext::set_enable_reduce_scatter_fusion,
"Set enable/disable reduce scatter fusion.")
.def("get_enable_reduce_scatter_fusion", &ParallelContext::enable_reduce_scatter_fusion,
"Get enable/disable reduce scatter fusion.")
.def("get_parameter_broadcast", &ParallelContext::parameter_broadcast, "Get parameter broadcast.")
.def("get_parameter_broadcast_is_set", &ParallelContext::parameter_broadcast_is_set,
"Get parameter broadcast is set.")

View File

@ -32,6 +32,8 @@ class _ParallelFusionConfig:
The key of the Parallel fusion method configuration.
"""
ALLREDUCE = "allreduce"
ALLGATHER = "allgather"
REDUCESCATTER = "reducescatter"
MODE = "mode"
FUSION_CONFIG = "config"
AUTO = "auto"
@ -116,8 +118,12 @@ class _AutoParallelContext:
for key in list(config.keys()):
if key == _ParallelFusionConfig.ALLREDUCE:
self._set_allreduce_comm_fusion(config[key])
elif key == _ParallelFusionConfig.ALLGATHER:
self._set_allgather_comm_fusion(config[key], key)
elif key == _ParallelFusionConfig.REDUCESCATTER:
self._set_allgather_comm_fusion(config[key], key)
else:
raise KeyError("comm fusion type must be allreduce, but got {}".format(key))
raise KeyError("comm fusion type must be allreduce, allgather or reducescatter, but got {}".format(key))
def _set_allreduce_comm_fusion(self, comm_fusion):
"""
@ -154,6 +160,42 @@ class _AutoParallelContext:
if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.INDEX:
self.set_all_reduce_fusion_split_indices(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG])
def _set_allgather_comm_fusion(self, comm_fusion, comm_type="allgather"):
"""
Set allgather and reducescatter fusion method for auto parallel.
Args:
comm_fusion (dict): A dict that contains the methods and values for setting the fusion method. Currently it
supports two fusion methods: `auto` and `size`.
comm_type (str): The name of the communication operator, `allgather` or `reducescatter`.
Raises:
KeyError: When key of comm_fusion is not 'mode' or 'config'.
KeyError: When `mode` is not 'auto' or 'size'.
"""
self.check_context_handle()
if comm_type == "allgather" and not self.get_enable_all_gather_fusion():
return
if comm_type == "reducescatter" and not self.get_enable_reduce_scatter_fusion():
return
if not isinstance(comm_fusion, dict):
raise TypeError("For 'comm_fusion', {} config must be dict, but got the type : {}.".format(
comm_type, type(comm_fusion)))
if _ParallelFusionConfig.MODE not in comm_fusion:
raise KeyError("For 'comm_fusion', the key 'mode' should be contained.")
if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion:
raise KeyError("For 'comm_fusion', the key 'config' should be contained.")
check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.SIZE]
if comm_fusion[_ParallelFusionConfig.MODE] in check_mode:
self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE])
else:
raise KeyError("fusion method mode must be auto or size, but got {}".format(
comm_fusion[_ParallelFusionConfig.MODE]))
fusion_threshold = 64 if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.AUTO else \
comm_fusion[_ParallelFusionConfig.FUSION_CONFIG]
self.set_fusion_threshold_mb(fusion_threshold, comm_type)
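
On the user-facing side these setters are reached through the comm_fusion entry of set_auto_parallel_context, which is exactly what the test added at the end of this commit exercises. A minimal usage sketch (threshold values are arbitrary):

from mindspore import context

# Fuse AllGather ops until 8 MB of parameters accumulate and ReduceScatter ops
# until 16 MB; mode "auto" would use the 64 MB default instead of "config".
context.set_auto_parallel_context(
    parallel_mode="semi_auto_parallel",
    comm_fusion={"allgather": {"mode": "size", "config": 8},
                 "reducescatter": {"mode": "size", "config": 16}})
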
def get_comm_fusion(self):
"""Get comm fusion config."""
self.check_context_handle()
@ -166,12 +208,13 @@ class _AutoParallelContext:
_ParallelFusionConfig.FUSION_CONFIG: config}}
def set_fusion_threshold_mb(self, fusion_threshold=64):
def set_fusion_threshold_mb(self, fusion_threshold=64, comm_type="allreduce"):
"""
Set fusion threshold (MB) for auto parallel.
Args:
fusion_threshold (int): The fusion threshold (unit: MB). Default: 64.
comm_type (str): The name of the communication operator, `allreduce`, `allgather` or `reducescatter`.
Raises:
ValueError: If the fusion threshold is not in [0, +inf].
@ -179,13 +222,30 @@ class _AutoParallelContext:
self.check_context_handle()
if fusion_threshold < 0:
raise ValueError("fusion threshold must be larger than 0, but got {}".format(fusion_threshold))
self._context_handle.set_fusion_threshold_mb(fusion_threshold)
if comm_type == _ParallelFusionConfig.ALLREDUCE:
self._context_handle.set_fusion_threshold_mb(fusion_threshold)
if comm_type == _ParallelFusionConfig.ALLGATHER:
self._context_handle.set_allgather_fusion_threshold_mb(fusion_threshold)
if comm_type == _ParallelFusionConfig.REDUCESCATTER:
self._context_handle.set_reducescatter_fusion_threshold_mb(fusion_threshold)
def fusion_threshold_mb(self):
"""Get device num."""
"""Get all reduce threshold."""
self.check_context_handle()
return self._context_handle.fusion_threshold_mb()
def allgather_fusion_threshold_mb(self):
"""Get allgather threshold."""
self.check_context_handle()
return self._context_handle.allgather_fusion_threshold_mb()
def reducescatter_fusion_threshold_mb(self):
"""Get reducescatter threshold."""
self.check_context_handle()
return self._context_handle.reducescatter_fusion_threshold_mb()
def set_global_rank(self, global_rank):
"""
Set global rank for auto parallel.
@ -624,15 +684,53 @@ class _AutoParallelContext:
self.check_context_handle()
if not isinstance(enable_all_reduce_fusion, bool):
raise TypeError("For 'set_auto_parallel_context().set_enable_all_reduce_fusion', "
"the argument 'enable_all_reduce_fusion' must be bool, but got the type : {}."
"the argument 'enable_fusion' must be bool, but got the type : {}."
.format(type(enable_all_reduce_fusion)))
self._context_handle.set_enable_all_reduce_fusion(enable_all_reduce_fusion)
def set_enable_all_gather_fusion(self, enable_all_gather_fusion):
"""
Set enable/disable all gather fusion.
Args:
enable_all_gather_fusion (bool): Enable/disable all gather fusion.
"""
self.check_context_handle()
if not isinstance(enable_all_gather_fusion, bool):
raise TypeError("For 'set_auto_parallel_context().set_enable_all_gather_fusion', "
"the argument 'enable_fusion' must be bool, but got the type : {}."
.format(type(enable_all_gather_fusion)))
self._context_handle.set_enable_all_gather_fusion(enable_all_gather_fusion)
def set_enable_reduce_scatter_fusion(self, enable_reduce_scatter_fusion):
"""
Set enable/disable reduce scatter fusion.
Args:
enable_reduce_scatter_fusion (bool): Enable/disable reduce scatter fusion.
"""
self.check_context_handle()
if not isinstance(enable_reduce_scatter_fusion, bool):
raise TypeError("For 'set_auto_parallel_context().set_enable_reduce_scatter_fusion', "
"the argument 'enable_fusion' must be bool, but got the type : {}."
.format(type(enable_reduce_scatter_fusion)))
self._context_handle.set_enable_reduce_scatter_fusion(enable_reduce_scatter_fusion)
def get_enable_all_reduce_fusion(self):
"""Get all reduce fusion flag."""
self.check_context_handle()
return self._context_handle.get_enable_all_reduce_fusion()
def get_enable_all_gather_fusion(self):
"""Get all gather fusion flag."""
self.check_context_handle()
return self._context_handle.get_enable_all_gather_fusion()
def get_enable_reduce_scatter_fusion(self):
"""Get reduce scatter flag."""
self.check_context_handle()
return self._context_handle.get_enable_reduce_scatter_fusion()
def get_device_num_is_set(self):
"""Get device number is set or not."""
self.check_context_handle()

View File

@ -0,0 +1,148 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.train.model import Model
from mindspore.nn.wrap.cell_wrapper import PipelineCell
from mindspore.parallel._auto_parallel_context import auto_parallel_context
class DatasetLenet():
def __init__(self, data, label, length=3):
self.data = data
self.label = label
self.index = 1
self.length = length
def __iter__(self):
return self
def __next__(self):
if self.index >= self.length:
raise StopIteration
self.index += 1
return self.data, self.label
def reset(self):
self.index = 0
@staticmethod
def get_dataset_size():
return 32
@staticmethod
def get_repeat_count():
return 1
@staticmethod
def get_batch_size():
return 32
def create_tuple_iterator(self, num_epochs=1, do_copy=True):
return self
class MatMulCell(nn.Cell):
def __init__(self, strategy1, strategy2, param=None):
super().__init__()
self.param = Parameter(initializer("zeros", [64, 64]), name="param")
if param is not None:
self.param = param
self.param1 = Parameter(initializer("zeros", [64, 64]), name="param1")
self.matmul = P.MatMul().shard(strategy1)
self.matmul1 = P.MatMul().shard(strategy2)
def construct(self, x):
out = self.matmul(x, self.param)
out = self.matmul1(out, self.param1)
return out
class Net(nn.Cell):
def __init__(self, strategy1, strategy2, param=None):
super().__init__()
self.block = nn.CellList()
for i in range(2):
cell = MatMulCell(strategy1, strategy2, param)
cell.pipeline_stage = i
self.block.append(cell)
def construct(self, x):
for i in range(2):
x = self.block[i](x)
return x
class PipelineSplit(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.cell = Net(strategy1, strategy2)
self.cell.block[0].matmul.add_prim_attr("parameter_start", 0)
def construct(self, x, label):
x = self.cell(x)
return x
def test_fusion_size():
"""
Feature: test_fusion_size in size mode
Description: allgather and reduce scatter fusion in size mode
Expectation: success
"""
allgather_threshold = 8
reducescatter_threshold = 16
comm_fusion_dict = {"allgather": {"mode": "size", "config": allgather_threshold},
"reducescatter": {"mode": "size", "config": reducescatter_threshold}}
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", comm_fusion=comm_fusion_dict)
context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2, enable_parallel_optimizer=True)
data = Tensor(np.ones([32, 64]), dtype=ms.float32)
label = Tensor(np.ones([64]), dtype=ms.float32)
strategy1 = ((4, 1), (1, 1))
strategy2 = ((2, 1), (1, 1))
net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
dataset = DatasetLenet(data, label, 3)
optimizer = nn.Lamb(net.trainable_params(), learning_rate=0.01)
model = Model(net, optimizer=optimizer)
model.train(2, dataset, dataset_sink_mode=False)
assert auto_parallel_context().allgather_fusion_threshold_mb() == allgather_threshold
assert auto_parallel_context().reducescatter_fusion_threshold_mb() == reducescatter_threshold
def test_fusion_auto():
"""
Feature: test_fusion_auto in auto mode
Description: allgather and reduce scatter fusion in auto mode
Expectation: success
"""
comm_fusion_dict = {"allgather": {"mode": "auto", "config": None},
"reducescatter": {"mode": "auto", "config": None}}
context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2, enable_parallel_optimizer=True,
parallel_mode="semi_auto_parallel", comm_fusion=comm_fusion_dict)
data = Tensor(np.ones([32, 64]), dtype=ms.float32)
label = Tensor(np.ones([64]), dtype=ms.float32)
strategy1 = ((4, 1), (1, 1))
strategy2 = ((2, 1), (1, 1))
net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
dataset = DatasetLenet(data, label, 3)
optimizer = nn.Lamb(net.trainable_params(), learning_rate=0.01)
model = Model(net, optimizer=optimizer)
model.train(2, dataset, dataset_sink_mode=False)
assert auto_parallel_context().allgather_fusion_threshold_mb() == 64
assert auto_parallel_context().reducescatter_fusion_threshold_mb() == 64