forked from mindspore-Ecosystem/mindspore
!26807 [Auto parallel] [Sharding propagation] Interface change of sharding propagation
Merge pull request !26807 from Xiaoda/113-auto-parallel-search-mode-changes-to-search-mode
commit 9f8ec2c5ab
@@ -31,7 +31,7 @@ std::map<std::string, Shape> param_shapes;
 std::vector<std::string> PARALLEL_MODE_LIST = {STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL,
                                                AUTO_PARALLEL};
-std::vector<std::string> STRATEGY_SEARCH_MODE_LIST = {DYNAMIC_PROGRAMMING, RECURSIVE_PROGRAMMING};
+std::vector<std::string> STRATEGY_SEARCH_MODE_LIST = {DYNAMIC_PROGRAMMING, RECURSIVE_PROGRAMMING, SHARDING_PROPAGATION};
 
 std::vector<std::string> COMMUNI_PARALLEL_MODE_LIST = {ALL_GROUP_PARALLEL, SAME_SERVER_GROUP_PARALLEL,
                                                        NO_GROUP_PARALLEL};
@@ -72,7 +72,6 @@ void ParallelContext::Reset() {
   communi_parallel_mode_ = ALL_GROUP_PARALLEL;
   optimizer_weight_shard_size_ = -1;
   optimizer_weight_shard_aggregated_save_ = false;
-  sharding_propagation_ = false;
   enable_all2all_ = false;
   grad_accumulation_shard_ = true;
   dataset_strategy_.clear();
@@ -276,8 +275,6 @@ void ParallelContext::ParallelParameterContextCkptShape(const FuncGraphPtr &func
   MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
 }
 
-void ParallelContext::set_sharding_propagation(const bool stra_pto) { sharding_propagation_ = stra_pto; }
-
 void ParallelContext::set_enable_all2all(const bool enable) { enable_all2all_ = enable; }
 } // namespace parallel
 } // namespace mindspore

@@ -42,6 +42,7 @@ constexpr char SEMI_AUTO_PARALLEL[] = "semi_auto_parallel";
 constexpr char DYNAMIC_PROGRAMMING[] = "dynamic_programming";
 constexpr char RECURSIVE_PROGRAMMING[] = "recursive_programming";
+constexpr char SHARDING_PROPAGATION[] = "sharding_propagation";
 
 constexpr char TRAINING[] = "training";
 constexpr char ACCUMULATION[] = "accumulation";
@@ -134,8 +135,6 @@ class ParallelContext {
   bool set_communi_parallel_mode(const std::string &communi_parallel_mode);
   std::string communi_parallel_mode() const { return communi_parallel_mode_; }
-  void set_sharding_propagation(const bool);
-  bool sharding_propagation() const { return sharding_propagation_; }
   void set_enable_all2all(const bool);
   bool enable_all2all() const { return enable_all2all_; }
   void set_dataset_repeat_dim_right(const bool dataset_repeat_dim_right) {
@@ -179,9 +178,6 @@ class ParallelContext {
   int64_t optimizer_weight_shard_size_;
   bool optimizer_weight_shard_aggregated_save_;
   bool grad_accumulation_shard_;
-  // In AUTO_PARALLEL mode, 'sharding_propagation_' = True indicates that sharding-configured operators
-  // will propagate the sharding strategies to other operators with minimum redistribution cost.
-  bool sharding_propagation_;
   // Enable AllToAll or not. If false, use AllGather and Split.
   bool enable_all2all_;
   std::vector<std::vector<int64_t>> dataset_strategy_;

@@ -71,13 +71,8 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
     return changes;
   }
   // check whether strategy_search_mode is valid
   std::string strategy_search_mode = ParallelContext::GetInstance()->strategy_search_mode();
-  if ((strategy_search_mode != DYNAMIC_PROGRAMMING) && (strategy_search_mode != RECURSIVE_PROGRAMMING)) {
-    // Setting searching mode: dynamic programming as default.
-    strategy_search_mode = DYNAMIC_PROGRAMMING;
-    MS_LOG(INFO) << "Non-idicated strategy searching mode, using DP searching mode as default";
-  }
   MS_LOG(INFO) << "search_mode: " << strategy_search_mode;
 
   struct timeval start_time {
     0
@@ -108,7 +103,7 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
   }
 
   // search parallelization strategy
-  if (strategy_search_mode == DYNAMIC_PROGRAMMING) {
+  if ((strategy_search_mode == DYNAMIC_PROGRAMMING) || (strategy_search_mode == SHARDING_PROPAGATION)) {
     if (ParallelStrategySearch(all_nodes, root) != SUCCESS) {
       MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using DP searching mode";
     }
@@ -374,7 +369,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
     return nullptr;
   }
 
-  if (ParallelContext::GetInstance()->sharding_propagation() &&
+  if ((ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION) &&
       (operator_info->name().find(VIRTUAL_DATA_SET_INFO) != std::string::npos)) {
     const auto &swc_vec = operator_info->GetStrategyCost();
     if (swc_vec.empty()) {
@@ -644,7 +639,7 @@ void CreateEdgeBetweenTwoOps(const OperatorInfoPtr &prev_op_info, const Operator
   node_op_info->AddPrevEdge(edge_ptr);
   prev_op_info->AddSuccEdge(edge_ptr);
   entire_costgraph->AddEdge(prev_op_info, node_op_info, edge_ptr);
-  if (ParallelContext::GetInstance()->sharding_propagation() && (prev_prim->name() == CAST) &&
+  if ((ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION) && (prev_prim->name() == CAST) &&
       (configured_stra_ops_.find(node_op_info) != configured_stra_ops_.end())) {
     const auto next_op_stra = configured_stra_ops_[node_op_info];
     const auto cast_stra = edge_ptr->GetPrevOpStrategyByNextOpStrategyWithMiniComm(next_op_stra);
@@ -995,7 +990,7 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
   }
 
   // Step 4: run the strategy searching algorithm
-  if (ParallelContext::GetInstance()->sharding_propagation()) {
+  if ((ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION)) {
    entire_costgraph->StrategyPropagate(configured_stra_ops_);
     configured_stra_ops_.clear();
   } else if (GetStrategy(entire_costgraph) != SUCCESS) {

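A note on the hunk above: under SHARDING_PROPAGATION, the cost graph starts from the operators whose strategies the user configured (gathered in configured_stra_ops_) and propagates strategies to the remaining operators, preferring candidates with low tensor-redistribution cost. The Python below is only an illustrative toy of that idea, not MindSpore's StrategyPropagate; the graph, candidate strategies, and cost function are invented for the example.

from collections import deque

def redistribution_cost(s_from, s_to):
    # Toy cost: number of dimensions whose shard count differs between two strategies.
    return sum(a != b for a, b in zip(s_from, s_to))

def propagate(graph, candidates, configured):
    """graph: {op: [neighbor ops]}; candidates: {op: [possible strategies]};
    configured: {op: strategy fixed by the user}."""
    decided = dict(configured)
    queue = deque(configured)
    while queue:
        op = queue.popleft()
        for nxt in graph[op]:
            if nxt in decided:
                continue
            # Greedily pick the candidate closest to the already-decided neighbor.
            decided[nxt] = min(candidates[nxt],
                               key=lambda s: redistribution_cost(decided[op], s))
            queue.append(nxt)
    return decided

graph = {"matmul": ["relu"], "relu": ["matmul", "cast"], "cast": ["relu"]}
candidates = {"relu": [(8, 1), (4, 2), (1, 8)], "cast": [(1, 8), (4, 2)]}
configured = {"matmul": (4, 2)}
print(propagate(graph, candidates, configured))  # every op ends up with (4, 2)

The real implementation works on the operator cost graph and per-edge redistribution cost models rather than on plain tuples.
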
@@ -206,9 +206,6 @@ PYBIND11_MODULE(_c_expression, m) {
          "Set whether to integrated save weight shard when enable parallel optimizer.")
     .def("get_optimizer_weight_shard_aggregated_save", &ParallelContext::optimizer_weight_shard_aggregated_save,
          "Get whether to integrated save weight shard when enable parallel optimizer.")
-    .def("set_sharding_propagation", &ParallelContext::set_sharding_propagation,
-         "Set sharding strategy propagation value.")
-    .def("get_sharding_propagation", &ParallelContext::sharding_propagation, "Get sharding strategy propagation value.")
     .def("set_enable_alltoall", &ParallelContext::set_enable_all2all, "Set the enabling AllToAll value.")
     .def("get_enable_alltoall", &ParallelContext::enable_all2all, "Get the enabling AllToAll value.")
     .def("reset", &ParallelContext::Reset, "Reset auto parallel context.");

@@ -370,7 +370,7 @@ def _context():
 
 
 @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
-                 auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
+                 search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
                  all_reduce_fusion_config=list, pipeline_stages=int, grad_accumulation_step=int,
                  parallel_optimizer_config=dict)
@@ -395,7 +395,7 @@ def set_auto_parallel_context(**kwargs):
     ===========================  ===========================
     device_num                   gradient_fp32_sync
     global_rank                  loss_repeated_mean
-    gradients_mean               auto_parallel_search_mode
+    gradients_mean               search_mode
     parallel_mode                strategy_ckpt_load_file
     all_reduce_fusion_config     strategy_ckpt_save_file
     enable_parallel_optimizer    dataset_strategy
@@ -423,12 +423,14 @@ def set_auto_parallel_context(**kwargs):
                      - semi_auto_parallel: Achieves data and model parallelism by setting parallel strategies.
 
                      - auto_parallel: Achieving parallelism automatically.
-        auto_parallel_search_mode (str): There are two kinds of shard strategy search modes, "recursive_programming"
-                     and "dynamic_programming". Default: "dynamic_programming".
+        search_mode (str): There are three kinds of shard strategy search modes: "recursive_programming",
+                     "dynamic_programming" and "sharding_propagation". Default: "dynamic_programming".
 
                      - recursive_programming: Recursive programming search mode.
 
                      - dynamic_programming: Dynamic programming search mode.
+
+                     - sharding_propagation: Propagate shardings from configured ops to non-configured ops.
         parameter_broadcast (bool): Whether to broadcast parameters before training. Before training, in order to have
                      the same network initialization parameter values for all devices, broadcast the parameters
                      on device 0 to other devices. Parameter broadcasting in different parallel modes is different,
@@ -482,7 +484,7 @@ def set_auto_parallel_context(**kwargs):
         >>> context.set_auto_parallel_context(gradients_mean=True)
         >>> context.set_auto_parallel_context(gradient_fp32_sync=False)
         >>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
-        >>> context.set_auto_parallel_context(auto_parallel_search_mode="dynamic_programming")
+        >>> context.set_auto_parallel_context(search_mode="dynamic_programming")
        >>> context.set_auto_parallel_context(parameter_broadcast=False)
         >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
         >>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
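With this change the propagation algorithm is selected the same way as the other two search modes. A minimal usage sketch, based on the docstring examples above and the tests updated later in this diff (the device_num value is arbitrary):

from mindspore import context

context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0,
                                  search_mode="sharding_propagation")
print(context.get_auto_parallel_context("search_mode"))  # "sharding_propagation"
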
@@ -526,7 +528,7 @@ def reset_auto_parallel_context():
     - gradients_mean: False.
     - gradient_fp32_sync: True.
     - parallel_mode: 'stand_alone'.
-    - auto_parallel_search_mode: 'dynamic_programming'.
+    - search_mode: 'dynamic_programming'.
     - parameter_broadcast: False.
     - strategy_ckpt_load_file: ''.
     - strategy_ckpt_save_file: ''.

@@ -214,19 +214,19 @@ class _AutoParallelContext:
             return context.ParallelMode.STAND_ALONE
         return self._context_handle.get_parallel_mode()
 
-    def set_strategy_search_mode(self, auto_parallel_search_mode):
+    def set_strategy_search_mode(self, search_mode):
         """
         Set search mode of strategy.
 
         Args:
-            auto_parallel_search_mode (str): The search mode of strategy.
+            search_mode (str): The search mode of strategy.
         """
         self.check_context_handle()
-        ret = self._context_handle.set_strategy_search_mode(auto_parallel_search_mode)
+        ret = self._context_handle.set_strategy_search_mode(search_mode)
         if ret is False:
-            raise ValueError("The context configuration parameter 'auto_parallel_search_mode' only support "
+            raise ValueError("The context configuration parameter 'search_mode' only support "
                              "'recursive_programming' and 'dynamic_programming', but got the value : {}."
-                             .format(auto_parallel_search_mode))
+                             .format(search_mode))
 
     def get_strategy_search_mode(self):
         """Get search mode of strategy."""
@@ -542,27 +542,6 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_grad_accumulation_shard()
 
-    def set_sharding_propagation(self, sharding_propagation):
-        """
-        Set the value of sharding strategy propagation in AUTO_PARALLEL mode. If True, the strategy-configured operators
-        will propagate the strategies to other operators with minimum redistribution cost; otherwise, the algorithm
-        will search the desired strategies.
-        Default: False.
-
-        Args:
-            sharding_propagation (bool): Enable/disable strategy propagation.
-        """
-        self.check_context_handle()
-        if not isinstance(sharding_propagation, bool):
-            raise TypeError("The type of parameter 'sharding_propagation' must be bool, "
-                            "but got the type : {}.".format(type(sharding_propagation)))
-        self._context_handle.set_sharding_propagation(sharding_propagation)
-
-    def get_sharding_propagation(self):
-        """Get the value of sharding strategy propagation."""
-        self.check_context_handle()
-        return self._context_handle.get_sharding_propagation()
-
     def set_enable_alltoall(self, enable_a2a):
         """
         Set the value of enabling AllToAll. If False, AllGather and Split are used to circumvent AllToAll.
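For callers of the Python context API, the removal above boils down to a one-line migration; a hedged before/after sketch drawn from the tests changed in this PR:

from mindspore import context

# Before this PR (boolean flag, removed here):
# context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8,
#                                   sharding_propagation=True)

# After this PR (value of the existing search_mode option):
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8,
                                  search_mode="sharding_propagation")
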
@@ -697,7 +676,7 @@ _set_auto_parallel_context_func_map = {
     "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
     "pipeline_stages": auto_parallel_context().set_pipeline_stages,
     "parallel_mode": auto_parallel_context().set_parallel_mode,
-    "auto_parallel_search_mode": auto_parallel_context().set_strategy_search_mode,
+    "search_mode": auto_parallel_context().set_strategy_search_mode,
     "parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
     "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
     "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
@@ -711,7 +690,6 @@ _set_auto_parallel_context_func_map = {
     "communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode,
     "optimizer_weight_shard_size": auto_parallel_context().set_optimizer_weight_shard_size,
     "optimizer_weight_shard_aggregated_save": auto_parallel_context().set_optimizer_weight_shard_aggregated_save,
-    "sharding_propagation": auto_parallel_context().set_sharding_propagation,
     "enable_alltoall": auto_parallel_context().set_enable_alltoall}
 
 
@@ -723,7 +701,7 @@ _get_auto_parallel_context_func_map = {
     "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
     "pipeline_stages": auto_parallel_context().get_pipeline_stages,
     "parallel_mode": auto_parallel_context().get_parallel_mode,
-    "auto_parallel_search_mode": auto_parallel_context().get_strategy_search_mode,
+    "search_mode": auto_parallel_context().get_strategy_search_mode,
     "parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
     "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
     "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
@@ -735,18 +713,16 @@ _get_auto_parallel_context_func_map = {
     "communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode,
     "optimizer_weight_shard_size": auto_parallel_context().get_optimizer_weight_shard_size,
     "optimizer_weight_shard_aggregated_save": auto_parallel_context().get_optimizer_weight_shard_aggregated_save,
-    "sharding_propagation": auto_parallel_context().get_sharding_propagation,
     "enable_alltoall": auto_parallel_context().get_enable_alltoall}
 
 
 @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
-                 loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
+                 loss_repeated_mean=bool, parallel_mode=str, search_mode=str,
                  parameter_broadcast=bool, strategy_ckpt_load_file=str,
                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
                  grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
                  communi_parallel_mode=str, optimizer_weight_shard_size=int,
-                 optimizer_weight_shard_aggregated_save=bool,
-                 sharding_propagation=bool, enable_alltoall=bool)
+                 optimizer_weight_shard_aggregated_save=bool, enable_alltoall=bool)
 
 def _set_auto_parallel_context(**kwargs):
     """
@@ -776,12 +752,14 @@ def _set_auto_parallel_context(**kwargs):
                      setting parallel strategies.
 
                      - auto_parallel: Achieving parallelism automatically.
-        auto_parallel_search_mode (str): There are two kinds of search modes, "recursive_programming"
-                     and "dynamic_programming". Default: "dynamic_programming".
+        search_mode (str): There are two kinds of search modes: "recursive_programming", "dynamic_programming"
+                     and "sharding_propagation". Default: "dynamic_programming".
 
                      - recursive_programming: Recursive programming search mode.
 
                      - dynamic_programming: Dynamic programming search mode.
+
+                     - sharding_propagation: Propagate shardings from configured ops to non-configured ops.
         parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
                      "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
                      broadcast. Default: False.
@@ -858,7 +836,7 @@ def _reset_auto_parallel_context():
     - strategy_ckpt_load_file: ""
     - strategy_ckpt_save_file: ""
    - enable_parallel_optimizer: False
-    - auto_parallel_search_mode: dynamic_programming
+    - search_mode: dynamic_programming
     - pipeline_stages: 0
     - gradient_accumulation_shard: True
     """

@@ -296,13 +296,13 @@ class ParallelStrategySearchFactory:
         newest_ckpt_file = find_newest_ckpt_file(ckpt_path)
         return load_checkpoint(newest_ckpt_file)
 
-    def mindspore_auto_parallel_impl(self, dataset, epoch, device_num, auto_parallel_search_mode="dynamic_programming"):
+    def mindspore_auto_parallel_impl(self, dataset, epoch, device_num, search_mode="dynamic_programming"):
         parallel_mode_net = self.parallel_mode_net
         set_algo_parameters(fully_use_devices=False)
         context.reset_auto_parallel_context()
         context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL,
                                           device_num=device_num,
-                                          auto_parallel_search_mode=auto_parallel_search_mode)
+                                          search_mode=search_mode)
         self.parallel_ckpt = self._model_train_and_save_ckpt(net=parallel_mode_net,
                                                              dataset=dataset, epoch=epoch)
         context.reset_auto_parallel_context()
@@ -379,5 +379,5 @@ def test_auto_parallel_recursive_strategy_search_axis_1_basic():
                                       image_size=(3, 224, 224), use_parallel=True,
                                       num_classes=12)
     fact.mindspore_auto_parallel_impl(dataset=parallel_dataset,
-                                      epoch=2, device_num=8, auto_parallel_search_mode="recursive_programming")
+                                      epoch=2, device_num=8, search_mode="recursive_programming")
     fact.checkpoint_cmp(inputs_np=inputs_np)

@@ -290,7 +290,7 @@ class DatasetLenet():
 def test_train_64k_8p(batch_size=32, num_classes=65536): # 1048576 #131072 #32768 #8192
     dev_num = 8
     context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num,
-                                      sharding_propagation=True)
+                                      search_mode="sharding_propagation")
     cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0)
     set_algo_parameters(elementwise_op_strategy_follow=True)
     resset_op_id()

@@ -289,7 +289,7 @@ class DatasetLenet():
 def test_train_32k_8p(batch_size=32, num_classes=32768):
     dev_num = 8
     context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num,
-                                      sharding_propagation=True)
+                                      search_mode="sharding_propagation")
     set_algo_parameters(elementwise_op_strategy_follow=True)
     resset_op_id()
     np.random.seed(6)

@@ -51,7 +51,7 @@ def compile_net(net):
 
 def test_auto_parallel_activation1():
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0,
-                                      sharding_propagation=True)
+                                      search_mode="sharding_propagation")
     strategy1 = ((4, 4), (4, 4))
     strategy2 = None
     net = Net(_w1, strategy1, strategy2)
@@ -60,7 +60,7 @@ def test_auto_parallel_activation1():
 
 def test_auto_parallel_activation2():
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0,
-                                      sharding_propagation=True)
+                                      search_mode="sharding_propagation")
     strategy1 = None
     strategy2 = ((4, 4),)
     net = Net(_w1, strategy1, strategy2)
@@ -68,7 +68,7 @@ def test_auto_parallel_activation2():
 
 def auto_parallel_activation3():
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0,
-                                      sharding_propagation=True)
+                                      search_mode="sharding_propagation")
     strategy1 = ((4, 4), (4, 4))
     strategy2 = None
     strategy3 = ((4, 4),)
@@ -77,7 +77,7 @@ def auto_parallel_activation3():
 
 def test_auto_parallel_activation4():
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0,
-                                      sharding_propagation=True)
+                                      search_mode="sharding_propagation")
     strategy1 = ((4, 4), (4, 4))
     strategy2 = None
     strategy3 = ((8, 2),)

@@ -52,7 +52,7 @@ def compile_net(net):
 
 def test_auto_parallel_activation4():
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0,
-                                      sharding_propagation=True)
+                                      search_mode="sharding_propagation")
     strategy1 = ((4, 4), (4, 4))
     strategy2 = None
     strategy3 = ((8, 2),)

@@ -53,7 +53,7 @@ def compile_net(net):
 
 def test_auto_parallel_activation1():
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0,
-                                      sharding_propagation=True)
+                                      search_mode="sharding_propagation")
     strategy1 = None
     strategy2 = ((8, 1),)
     strategy3 = ((1, 8), (1, 1))
@@ -62,7 +62,7 @@ def test_auto_parallel_activation1():
 
 def test_auto_parallel_activation2():
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0,
-                                      sharding_propagation=True)
+                                      search_mode="sharding_propagation")
     strategy1 = ((1, 8),)
     strategy2 = ((1, 1),)
     strategy3 = ((1, 8), (1, 1))

@@ -68,12 +68,6 @@ def test_two_matmul():
 
     size = 16
     context.set_auto_parallel_context(device_num=size, global_rank=0)
-    strategy_pro = context.get_auto_parallel_context("sharding_propagation")
-    assert not strategy_pro
-    context.set_auto_parallel_context(sharding_propagation=True)
-    strategy_pro = context.get_auto_parallel_context("sharding_propagation")
-    assert strategy_pro
-    context.set_auto_parallel_context(sharding_propagation=False)
     cost_model_context.set_cost_model_context(device_memory_capacity=32.0 * 1024.0 * 1024.0 * 1024.0,
                                               costmodel_alpha=1.0,
                                               costmodel_beta=60.0,

@@ -69,7 +69,7 @@ class GradWrapTwoInput(nn.Cell):
 
 def compile_graph(net, device_num, x):
     context.set_auto_parallel_context(device_num=device_num, global_rank=0, parallel_mode="auto_parallel",
-                                      sharding_propagation=True)
+                                      search_mode="sharding_propagation")
     net.set_auto_parallel()
     net.set_train()
     _cell_graph_executor.compile(net, x)
@@ -77,7 +77,7 @@ def compile_graph(net, device_num, x):
 
 def compile_graph_two_input(net, device_num, x, y):
     context.set_auto_parallel_context(device_num=device_num, global_rank=0, parallel_mode="auto_parallel",
-                                      sharding_propagation=True)
+                                      search_mode="sharding_propagation")
     net.set_auto_parallel()
     net.set_train()
     _cell_graph_executor.compile(net, x, y)