!26807 [Auto parallel] [Sharding propagation] Interface change of sharding propagation

Merge pull request !26807 from Xiaoda/113-auto-parallel-search-mode-changes-to-search-mode
This commit is contained in:
i-robot 2021-11-26 01:48:58 +00:00 committed by Gitee
commit 9f8ec2c5ab
14 changed files with 43 additions and 84 deletions

View File

@ -31,7 +31,7 @@ std::map<std::string, Shape> param_shapes;
std::vector<std::string> PARALLEL_MODE_LIST = {STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL,
AUTO_PARALLEL};
std::vector<std::string> STRATEGY_SEARCH_MODE_LIST = {DYNAMIC_PROGRAMMING, RECURSIVE_PROGRAMMING};
std::vector<std::string> STRATEGY_SEARCH_MODE_LIST = {DYNAMIC_PROGRAMMING, RECURSIVE_PROGRAMMING, SHARDING_PROPAGATION};
std::vector<std::string> COMMUNI_PARALLEL_MODE_LIST = {ALL_GROUP_PARALLEL, SAME_SERVER_GROUP_PARALLEL,
NO_GROUP_PARALLEL};
@ -72,7 +72,6 @@ void ParallelContext::Reset() {
communi_parallel_mode_ = ALL_GROUP_PARALLEL;
optimizer_weight_shard_size_ = -1;
optimizer_weight_shard_aggregated_save_ = false;
sharding_propagation_ = false;
enable_all2all_ = false;
grad_accumulation_shard_ = true;
dataset_strategy_.clear();
@ -276,8 +275,6 @@ void ParallelContext::ParallelParameterContextCkptShape(const FuncGraphPtr &func
MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
}
void ParallelContext::set_sharding_propagation(const bool stra_pto) { sharding_propagation_ = stra_pto; }
void ParallelContext::set_enable_all2all(const bool enable) { enable_all2all_ = enable; }
} // namespace parallel
} // namespace mindspore

View File

@ -42,6 +42,7 @@ constexpr char SEMI_AUTO_PARALLEL[] = "semi_auto_parallel";
constexpr char DYNAMIC_PROGRAMMING[] = "dynamic_programming";
constexpr char RECURSIVE_PROGRAMMING[] = "recursive_programming";
constexpr char SHARDING_PROPAGATION[] = "sharding_propagation";
constexpr char TRAINING[] = "training";
constexpr char ACCUMULATION[] = "accumulation";
@ -134,8 +135,6 @@ class ParallelContext {
bool set_communi_parallel_mode(const std::string &communi_parallel_mode);
std::string communi_parallel_mode() const { return communi_parallel_mode_; }
void set_sharding_propagation(const bool);
bool sharding_propagation() const { return sharding_propagation_; }
void set_enable_all2all(const bool);
bool enable_all2all() const { return enable_all2all_; }
void set_dataset_repeat_dim_right(const bool dataset_repeat_dim_right) {
@ -179,9 +178,6 @@ class ParallelContext {
int64_t optimizer_weight_shard_size_;
bool optimizer_weight_shard_aggregated_save_;
bool grad_accumulation_shard_;
// In AUTO_PARALLEL mode, 'sharding_propagation_' = True indicates that sharding-configured operators
// will propagate the sharding strategies to other operators with minimum redistribution cost.
bool sharding_propagation_;
// Enable AllToAll or not. If false, use AllGather and Split.
bool enable_all2all_;
std::vector<std::vector<int64_t>> dataset_strategy_;

View File

@ -71,13 +71,8 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
return changes;
}
// check whether strategy_search_mode is valid
std::string strategy_search_mode = ParallelContext::GetInstance()->strategy_search_mode();
if ((strategy_search_mode != DYNAMIC_PROGRAMMING) && (strategy_search_mode != RECURSIVE_PROGRAMMING)) {
// Setting searching mode: dynamic programming as default.
strategy_search_mode = DYNAMIC_PROGRAMMING;
MS_LOG(INFO) << "Non-idicated strategy searching mode, using DP searching mode as default";
}
MS_LOG(INFO) << "search_mode: " << strategy_search_mode;
struct timeval start_time {
0
@ -108,7 +103,7 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
}
// search parallelization strategy
if (strategy_search_mode == DYNAMIC_PROGRAMMING) {
if ((strategy_search_mode == DYNAMIC_PROGRAMMING) || (strategy_search_mode == SHARDING_PROPAGATION)) {
if (ParallelStrategySearch(all_nodes, root) != SUCCESS) {
MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using DP searching mode";
}
@ -374,7 +369,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
return nullptr;
}
if (ParallelContext::GetInstance()->sharding_propagation() &&
if ((ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION) &&
(operator_info->name().find(VIRTUAL_DATA_SET_INFO) != std::string::npos)) {
const auto &swc_vec = operator_info->GetStrategyCost();
if (swc_vec.empty()) {
@ -644,7 +639,7 @@ void CreateEdgeBetweenTwoOps(const OperatorInfoPtr &prev_op_info, const Operator
node_op_info->AddPrevEdge(edge_ptr);
prev_op_info->AddSuccEdge(edge_ptr);
entire_costgraph->AddEdge(prev_op_info, node_op_info, edge_ptr);
if (ParallelContext::GetInstance()->sharding_propagation() && (prev_prim->name() == CAST) &&
if ((ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION) && (prev_prim->name() == CAST) &&
(configured_stra_ops_.find(node_op_info) != configured_stra_ops_.end())) {
const auto next_op_stra = configured_stra_ops_[node_op_info];
const auto cast_stra = edge_ptr->GetPrevOpStrategyByNextOpStrategyWithMiniComm(next_op_stra);
@ -995,7 +990,7 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
}
// Step 4: run the strategy searching algorithm
if (ParallelContext::GetInstance()->sharding_propagation()) {
if ((ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION)) {
entire_costgraph->StrategyPropagate(configured_stra_ops_);
configured_stra_ops_.clear();
} else if (GetStrategy(entire_costgraph) != SUCCESS) {

View File

@ -206,9 +206,6 @@ PYBIND11_MODULE(_c_expression, m) {
"Set whether to integrated save weight shard when enable parallel optimizer.")
.def("get_optimizer_weight_shard_aggregated_save", &ParallelContext::optimizer_weight_shard_aggregated_save,
"Get whether to integrated save weight shard when enable parallel optimizer.")
.def("set_sharding_propagation", &ParallelContext::set_sharding_propagation,
"Set sharding strategy propagation value.")
.def("get_sharding_propagation", &ParallelContext::sharding_propagation, "Get sharding strategy propagation value.")
.def("set_enable_alltoall", &ParallelContext::set_enable_all2all, "Set the enabling AllToAll value.")
.def("get_enable_alltoall", &ParallelContext::enable_all2all, "Get the enabling AllToAll value.")
.def("reset", &ParallelContext::Reset, "Reset auto parallel context.");

View File

@ -370,7 +370,7 @@ def _context():
@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
all_reduce_fusion_config=list, pipeline_stages=int, grad_accumulation_step=int,
parallel_optimizer_config=dict)
@ -395,7 +395,7 @@ def set_auto_parallel_context(**kwargs):
=========================== ===========================
device_num gradient_fp32_sync
global_rank loss_repeated_mean
gradients_mean auto_parallel_search_mode
gradients_mean search_mode
parallel_mode strategy_ckpt_load_file
all_reduce_fusion_config strategy_ckpt_save_file
enable_parallel_optimizer dataset_strategy
@ -423,12 +423,14 @@ def set_auto_parallel_context(**kwargs):
- semi_auto_parallel: Achieves data and model parallelism by setting parallel strategies.
- auto_parallel: Achieving parallelism automatically.
auto_parallel_search_mode (str): There are two kinds of shard strategy search modes, "recursive_programming"
and "dynamic_programming". Default: "dynamic_programming".
search_mode (str): There are three kinds of shard strategy search modes: "recursive_programming",
"dynamic_programming" and "sharding_propagation". Default: "dynamic_programming".
- recursive_programming: Recursive programming search mode.
- dynamic_programming: Dynamic programming search mode.
- sharding_propagation: Propagate shardings from configured ops to non-configured ops.
parameter_broadcast (bool): Whether to broadcast parameters before training. Before training, in order to have
the same network initialization parameter values for all devices, broadcast the parameters
on device 0 to other devices. Parameter broadcasting in different parallel modes is different,
@ -482,7 +484,7 @@ def set_auto_parallel_context(**kwargs):
>>> context.set_auto_parallel_context(gradients_mean=True)
>>> context.set_auto_parallel_context(gradient_fp32_sync=False)
>>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
>>> context.set_auto_parallel_context(auto_parallel_search_mode="dynamic_programming")
>>> context.set_auto_parallel_context(search_mode="dynamic_programming")
>>> context.set_auto_parallel_context(parameter_broadcast=False)
>>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
>>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
@ -526,7 +528,7 @@ def reset_auto_parallel_context():
- gradients_mean: False.
- gradient_fp32_sync: True.
- parallel_mode: 'stand_alone'.
- auto_parallel_search_mode: 'dynamic_programming'.
- search_mode: 'dynamic_programming'.
- parameter_broadcast: False.
- strategy_ckpt_load_file: ''.
- strategy_ckpt_save_file: ''.

View File

@ -214,19 +214,19 @@ class _AutoParallelContext:
return context.ParallelMode.STAND_ALONE
return self._context_handle.get_parallel_mode()
def set_strategy_search_mode(self, auto_parallel_search_mode):
def set_strategy_search_mode(self, search_mode):
"""
Set search mode of strategy.
Args:
auto_parallel_search_mode (str): The search mode of strategy.
search_mode (str): The search mode of strategy.
"""
self.check_context_handle()
ret = self._context_handle.set_strategy_search_mode(auto_parallel_search_mode)
ret = self._context_handle.set_strategy_search_mode(search_mode)
if ret is False:
raise ValueError("The context configuration parameter 'auto_parallel_search_mode' only support "
raise ValueError("The context configuration parameter 'search_mode' only support "
"'recursive_programming' and 'dynamic_programming', but got the value : {}."
.format(auto_parallel_search_mode))
.format(search_mode))
def get_strategy_search_mode(self):
"""Get search mode of strategy."""
@ -542,27 +542,6 @@ class _AutoParallelContext:
self.check_context_handle()
return self._context_handle.get_grad_accumulation_shard()
def set_sharding_propagation(self, sharding_propagation):
"""
Set the value of sharding strategy propagation in AUTO_PARALLEL mode. If True, the strategy-configured operators
will propagate the strategies to other operators with minimum redistribution cost; otherwise, the algorithm
will search the desired strategies.
Default: False.
Args:
sharding_propagation (bool): Enable/disable strategy propagation.
"""
self.check_context_handle()
if not isinstance(sharding_propagation, bool):
raise TypeError("The type of parameter 'sharding_propagation' must be bool, "
"but got the type : {}.".format(type(sharding_propagation)))
self._context_handle.set_sharding_propagation(sharding_propagation)
def get_sharding_propagation(self):
"""Get the value of sharding strategy propagation."""
self.check_context_handle()
return self._context_handle.get_sharding_propagation()
def set_enable_alltoall(self, enable_a2a):
"""
Set the value of enabling AllToAll. If False, AllGather and Split are used to circumvent AllToAll.
@ -697,7 +676,7 @@ _set_auto_parallel_context_func_map = {
"loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
"pipeline_stages": auto_parallel_context().set_pipeline_stages,
"parallel_mode": auto_parallel_context().set_parallel_mode,
"auto_parallel_search_mode": auto_parallel_context().set_strategy_search_mode,
"search_mode": auto_parallel_context().set_strategy_search_mode,
"parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
"strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
"strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
@ -711,7 +690,6 @@ _set_auto_parallel_context_func_map = {
"communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode,
"optimizer_weight_shard_size": auto_parallel_context().set_optimizer_weight_shard_size,
"optimizer_weight_shard_aggregated_save": auto_parallel_context().set_optimizer_weight_shard_aggregated_save,
"sharding_propagation": auto_parallel_context().set_sharding_propagation,
"enable_alltoall": auto_parallel_context().set_enable_alltoall}
@ -723,7 +701,7 @@ _get_auto_parallel_context_func_map = {
"loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
"pipeline_stages": auto_parallel_context().get_pipeline_stages,
"parallel_mode": auto_parallel_context().get_parallel_mode,
"auto_parallel_search_mode": auto_parallel_context().get_strategy_search_mode,
"search_mode": auto_parallel_context().get_strategy_search_mode,
"parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
"strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
"strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
@ -735,18 +713,16 @@ _get_auto_parallel_context_func_map = {
"communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode,
"optimizer_weight_shard_size": auto_parallel_context().get_optimizer_weight_shard_size,
"optimizer_weight_shard_aggregated_save": auto_parallel_context().get_optimizer_weight_shard_aggregated_save,
"sharding_propagation": auto_parallel_context().get_sharding_propagation,
"enable_alltoall": auto_parallel_context().get_enable_alltoall}
@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
loss_repeated_mean=bool, parallel_mode=str, search_mode=str,
parameter_broadcast=bool, strategy_ckpt_load_file=str,
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
communi_parallel_mode=str, optimizer_weight_shard_size=int,
optimizer_weight_shard_aggregated_save=bool,
sharding_propagation=bool, enable_alltoall=bool)
optimizer_weight_shard_aggregated_save=bool, enable_alltoall=bool)
def _set_auto_parallel_context(**kwargs):
"""
@ -776,12 +752,14 @@ def _set_auto_parallel_context(**kwargs):
setting parallel strategies.
- auto_parallel: Achieving parallelism automatically.
auto_parallel_search_mode (str): There are two kinds of search modes, "recursive_programming"
and "dynamic_programming". Default: "dynamic_programming".
search_mode (str): There are two kinds of search modes: "recursive_programming", "dynamic_programming"
and "sharding_propagation". Default: "dynamic_programming".
- recursive_programming: Recursive programming search mode.
- dynamic_programming: Dynamic programming search mode.
- sharding_propagation: Propagate shardings from configured ops to non-configured ops.
parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
"stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
broadcast. Default: False.
@ -858,7 +836,7 @@ def _reset_auto_parallel_context():
- strategy_ckpt_load_file: ""
- strategy_ckpt_save_file: ""
- enable_parallel_optimizer: False
- auto_parallel_search_mode: dynamic_programming
- search_mode: dynamic_programming
- pipeline_stages: 0
- gradient_accumulation_shard: True
"""

View File

@ -296,13 +296,13 @@ class ParallelStrategySearchFactory:
newest_ckpt_file = find_newest_ckpt_file(ckpt_path)
return load_checkpoint(newest_ckpt_file)
def mindspore_auto_parallel_impl(self, dataset, epoch, device_num, auto_parallel_search_mode="dynamic_programming"):
def mindspore_auto_parallel_impl(self, dataset, epoch, device_num, search_mode="dynamic_programming"):
parallel_mode_net = self.parallel_mode_net
set_algo_parameters(fully_use_devices=False)
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL,
device_num=device_num,
auto_parallel_search_mode=auto_parallel_search_mode)
search_mode=search_mode)
self.parallel_ckpt = self._model_train_and_save_ckpt(net=parallel_mode_net,
dataset=dataset, epoch=epoch)
context.reset_auto_parallel_context()
@ -379,5 +379,5 @@ def test_auto_parallel_recursive_strategy_search_axis_1_basic():
image_size=(3, 224, 224), use_parallel=True,
num_classes=12)
fact.mindspore_auto_parallel_impl(dataset=parallel_dataset,
epoch=2, device_num=8, auto_parallel_search_mode="recursive_programming")
epoch=2, device_num=8, search_mode="recursive_programming")
fact.checkpoint_cmp(inputs_np=inputs_np)

View File

@ -290,7 +290,7 @@ class DatasetLenet():
def test_train_64k_8p(batch_size=32, num_classes=65536): # 1048576 #131072 #32768 #8192
dev_num = 8
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num,
sharding_propagation=True)
search_mode="sharding_propagation")
cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0)
set_algo_parameters(elementwise_op_strategy_follow=True)
resset_op_id()

View File

@ -289,7 +289,7 @@ class DatasetLenet():
def test_train_32k_8p(batch_size=32, num_classes=32768):
dev_num = 8
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num,
sharding_propagation=True)
search_mode="sharding_propagation")
set_algo_parameters(elementwise_op_strategy_follow=True)
resset_op_id()
np.random.seed(6)

View File

@ -51,7 +51,7 @@ def compile_net(net):
def test_auto_parallel_activation1():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0,
sharding_propagation=True)
search_mode="sharding_propagation")
strategy1 = ((4, 4), (4, 4))
strategy2 = None
net = Net(_w1, strategy1, strategy2)
@ -60,7 +60,7 @@ def test_auto_parallel_activation1():
def test_auto_parallel_activation2():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0,
sharding_propagation=True)
search_mode="sharding_propagation")
strategy1 = None
strategy2 = ((4, 4),)
net = Net(_w1, strategy1, strategy2)
@ -68,7 +68,7 @@ def test_auto_parallel_activation2():
def auto_parallel_activation3():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0,
sharding_propagation=True)
search_mode="sharding_propagation")
strategy1 = ((4, 4), (4, 4))
strategy2 = None
strategy3 = ((4, 4),)
@ -77,7 +77,7 @@ def auto_parallel_activation3():
def test_auto_parallel_activation4():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0,
sharding_propagation=True)
search_mode="sharding_propagation")
strategy1 = ((4, 4), (4, 4))
strategy2 = None
strategy3 = ((8, 2),)

View File

@ -52,7 +52,7 @@ def compile_net(net):
def test_auto_parallel_activation4():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0,
sharding_propagation=True)
search_mode="sharding_propagation")
strategy1 = ((4, 4), (4, 4))
strategy2 = None
strategy3 = ((8, 2),)

View File

@ -53,7 +53,7 @@ def compile_net(net):
def test_auto_parallel_activation1():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0,
sharding_propagation=True)
search_mode="sharding_propagation")
strategy1 = None
strategy2 = ((8, 1),)
strategy3 = ((1, 8), (1, 1))
@ -62,7 +62,7 @@ def test_auto_parallel_activation1():
def test_auto_parallel_activation2():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0,
sharding_propagation=True)
search_mode="sharding_propagation")
strategy1 = ((1, 8),)
strategy2 = ((1, 1),)
strategy3 = ((1, 8), (1, 1))

View File

@ -68,12 +68,6 @@ def test_two_matmul():
size = 16
context.set_auto_parallel_context(device_num=size, global_rank=0)
strategy_pro = context.get_auto_parallel_context("sharding_propagation")
assert not strategy_pro
context.set_auto_parallel_context(sharding_propagation=True)
strategy_pro = context.get_auto_parallel_context("sharding_propagation")
assert strategy_pro
context.set_auto_parallel_context(sharding_propagation=False)
cost_model_context.set_cost_model_context(device_memory_capacity=32.0 * 1024.0 * 1024.0 * 1024.0,
costmodel_alpha=1.0,
costmodel_beta=60.0,

View File

@ -69,7 +69,7 @@ class GradWrapTwoInput(nn.Cell):
def compile_graph(net, device_num, x):
context.set_auto_parallel_context(device_num=device_num, global_rank=0, parallel_mode="auto_parallel",
sharding_propagation=True)
search_mode="sharding_propagation")
net.set_auto_parallel()
net.set_train()
_cell_graph_executor.compile(net, x)
@ -77,7 +77,7 @@ def compile_graph(net, device_num, x):
def compile_graph_two_input(net, device_num, x, y):
context.set_auto_parallel_context(device_num=device_num, global_rank=0, parallel_mode="auto_parallel",
sharding_propagation=True)
search_mode="sharding_propagation")
net.set_auto_parallel()
net.set_train()
_cell_graph_executor.compile(net, x, y)