fix optimizer weight shard config

Ziyan 2021-06-08 10:44:17 +08:00
parent 88be613cdc
commit 95ac0f6d58
6 changed files with 24 additions and 24 deletions

View File

@@ -70,7 +70,7 @@ void ParallelContext::Reset() {
grad_accumulation_step_ = 1;
communi_parallel_mode_ = ALL_GROUP_PARALLEL;
optimizer_weight_shard_size_ = -1;
-optimizer_weight_shard_integrated_save_ = false;
+optimizer_weight_shard_aggregated_save_ = false;
}
void ParallelContext::set_device_num(int64_t device_num) {
@@ -138,8 +138,8 @@ void ParallelContext::set_optimizer_weight_shard_size(int64_t optimizer_weight_s
optimizer_weight_shard_size_ = optimizer_weight_shard_size;
}
-void ParallelContext::set_optimizer_weight_shard_integrated_save(bool optimizer_weight_shard_integrated_save) {
-optimizer_weight_shard_integrated_save_ = optimizer_weight_shard_integrated_save;
+void ParallelContext::set_optimizer_weight_shard_aggregated_save(bool optimizer_weight_shard_aggregated_save) {
+optimizer_weight_shard_aggregated_save_ = optimizer_weight_shard_aggregated_save;
}
void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group) {

View File

@@ -97,8 +97,8 @@ class ParallelContext {
void set_optimizer_weight_shard_size(int64_t optimizer_weight_shard_size);
int64_t optimizer_weight_shard_size() const { return optimizer_weight_shard_size_; }
-void set_optimizer_weight_shard_integrated_save(bool optimizer_weight_shard_integrated_save);
-bool optimizer_weight_shard_integrated_save() const { return optimizer_weight_shard_integrated_save_; }
+void set_optimizer_weight_shard_aggregated_save(bool optimizer_weight_shard_aggregated_save);
+bool optimizer_weight_shard_aggregated_save() const { return optimizer_weight_shard_aggregated_save_; }
void SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group);
const std::vector<uint32_t> GetAllReduceFusionSplitIndices(const std::string &group) const;
@@ -158,7 +158,7 @@ class ParallelContext {
bool init_param_shape_;
std::string communi_parallel_mode_;
int64_t optimizer_weight_shard_size_;
-bool optimizer_weight_shard_integrated_save_;
+bool optimizer_weight_shard_aggregated_save_;
};
} // namespace parallel

View File

@@ -557,7 +557,7 @@ Status OperatorInfo::CreateGroupForOptShard(TensorLayout *const tensor_layout, s
MS_LOG(INFO) << "Parallel optimizer: create allgather group " << allgather_group.name();
}
// save in tensor_layout for strategy ckpt
-auto integrated_save = ParallelContext::GetInstance()->optimizer_weight_shard_integrated_save();
+auto integrated_save = ParallelContext::GetInstance()->optimizer_weight_shard_aggregated_save();
if (!integrated_save) {
tensor_layout->set_opt_weight_shard_size(LongToInt(optimizer_weight_shard_size));
int64_t opt_weight_shard_step =
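For context, a minimal sketch of the save behaviour this flag selects (plain Python, not MindSpore code; save_weight_shard and gather_fn are hypothetical names): with aggregated save off, each rank keeps only its optimizer-sharded slice and records the shard size, which is what tensor_layout->set_opt_weight_shard_size(...) above persists for the strategy checkpoint; with aggregated save on, the shards are gathered back into the full weight before saving.

def save_weight_shard(shard, shard_size, aggregated_save, gather_fn):
    # Hypothetical illustration only; gather_fn stands in for an AllGather
    # across the optimizer-shard group.
    if aggregated_save:
        # Gather every rank's slice and checkpoint the full parameter.
        return {"weight": gather_fn(shard)}
    # Keep only the local slice and record how it was sharded, mirroring
    # set_opt_weight_shard_size(...) in the hunk above.
    return {"weight": shard, "opt_weight_shard_size": shard_size}

# Toy usage on "rank 0": two ranks each hold half of a 4-element weight.
def gather(local):
    return local + [30, 40]  # stand-in for the other rank's slice

print(save_weight_shard([10, 20], 2, False, gather))  # sharded checkpoint entry
print(save_weight_shard([10, 20], 2, True, gather))   # aggregated full weight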

View File

@@ -177,9 +177,9 @@ PYBIND11_MODULE(_c_expression, m) {
"Set opt shard group size when not fully use parallel optimizer.")
.def("get_optimizer_weight_shard_size", &ParallelContext::optimizer_weight_shard_size,
"Get opt shard group size when not fully use parallel optimizer.")
.def("set_optimizer_weight_shard_integrated_save", &ParallelContext::set_optimizer_weight_shard_integrated_save,
.def("set_optimizer_weight_shard_aggregated_save", &ParallelContext::set_optimizer_weight_shard_aggregated_save,
"Set whether to integrated save weight shard when enable parallel optimizer.")
.def("get_optimizer_weight_shard_integrated_save", &ParallelContext::optimizer_weight_shard_integrated_save,
.def("get_optimizer_weight_shard_aggregated_save", &ParallelContext::optimizer_weight_shard_aggregated_save,
"Get whether to integrated save weight shard when enable parallel optimizer.")
.def("reset", &ParallelContext::Reset, "Reset auto parallel context.");

View File

@@ -479,24 +479,24 @@ class _AutoParallelContext:
self.check_context_handle()
return self._context_handle.get_optimizer_weight_shard_size()
-def set_optimizer_weight_shard_integrated_save(self, optimizer_weight_shard_integrated_save):
+def set_optimizer_weight_shard_aggregated_save(self, optimizer_weight_shard_aggregated_save):
"""
-Set optimizer_weight_shard_integrated_save.
+Set optimizer_weight_shard_aggregated_save.
Args:
-optimizer_weight_shard_integrated_save (bool): Whether to integrated save weight shard when
+optimizer_weight_shard_aggregated_save (bool): Whether to integrated save weight shard when
enable parallel optimizer.
"""
self.check_context_handle()
-if not isinstance(optimizer_weight_shard_integrated_save, bool):
-raise TypeError('optimizer_weight_shard_integrated_save is invalid type')
-self._context_handle.set_optimizer_weight_shard_integrated_save(optimizer_weight_shard_integrated_save)
+if not isinstance(optimizer_weight_shard_aggregated_save, bool):
+raise TypeError('optimizer_weight_shard_aggregated_save is invalid type')
+self._context_handle.set_optimizer_weight_shard_aggregated_save(optimizer_weight_shard_aggregated_save)
-def get_optimizer_weight_shard_integrated_save(self):
+def get_optimizer_weight_shard_aggregated_save(self):
"""Get optimizer_weight_shard_size."""
self.check_context_handle()
-return self._context_handle.get_optimizer_weight_shard_integrated_save()
+return self._context_handle.get_optimizer_weight_shard_aggregated_save()
def reset(self):
@@ -557,7 +557,7 @@ _set_auto_parallel_context_func_map = {
"all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices,
"communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode,
"optimizer_weight_shard_size": auto_parallel_context().set_optimizer_weight_shard_size,
"optimizer_weight_shard_integrated_save": auto_parallel_context().set_optimizer_weight_shard_integrated_save}
"optimizer_weight_shard_aggregated_save": auto_parallel_context().set_optimizer_weight_shard_aggregated_save}
_get_auto_parallel_context_func_map = {
@@ -578,7 +578,7 @@ _get_auto_parallel_context_func_map = {
"all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices,
"communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode,
"optimizer_weight_shard_size": auto_parallel_context().get_optimizer_weight_shard_size,
"optimizer_weight_shard_integrated_save": auto_parallel_context().get_optimizer_weight_shard_integrated_save}
"optimizer_weight_shard_aggregated_save": auto_parallel_context().get_optimizer_weight_shard_aggregated_save}
@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
@@ -587,7 +587,7 @@ _get_auto_parallel_context_func_map = {
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
communi_parallel_mode=str, optimizer_weight_shard_size=int,
-optimizer_weight_shard_integrated_save=bool)
+optimizer_weight_shard_aggregated_save=bool)
def _set_auto_parallel_context(**kwargs):
"""
@@ -647,7 +647,7 @@ def _set_auto_parallel_context(**kwargs):
optimizer_weight_shard_size (int): Set optimizer shard group size when not fully use parallel optimizer.
It should be larger than one and less than or equal with the data parallel size.
Default: -1, which means fully use parallel optimizer in data parallel dimension.
-optimizer_weight_shard_integrated_save (bool): Whether to integrated save weight shard when enable parallel
+optimizer_weight_shard_aggregated_save (bool): Whether to integrated save weight shard when enable parallel
optimizer. Default: False.
Raises:
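For completeness, a usage sketch of the renamed keyword; it assumes the public mindspore.context.set_auto_parallel_context and get_auto_parallel_context wrappers forward these keys to the maps above, which this diff does not show.

from mindspore import context

# Shard optimizer states across groups of 2 devices and gather the shards
# back into full weights at checkpoint time.
context.set_auto_parallel_context(enable_parallel_optimizer=True,
                                  optimizer_weight_shard_size=2,
                                  optimizer_weight_shard_aggregated_save=True)
assert context.get_auto_parallel_context("optimizer_weight_shard_aggregated_save")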

View File

@@ -59,8 +59,8 @@ def test_set_auto_parallel_context():
parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set()
assert parameter_broadcast_is_set
-auto_parallel_context().set_optimizer_weight_shard_integrated_save(True)
-integrated_save = auto_parallel_context().get_optimizer_weight_shard_integrated_save()
+auto_parallel_context().set_optimizer_weight_shard_aggregated_save(True)
+integrated_save = auto_parallel_context().get_optimizer_weight_shard_aggregated_save()
assert integrated_save
with pytest.raises(ValueError):
@@ -109,7 +109,7 @@ def test_reset_auto_parallel_context():
parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set()
stage = auto_parallel_context().get_pipeline_stages()
communi_parallel_mode = context.get_auto_parallel_context("communi_parallel_mode")
-integrated_save = auto_parallel_context().get_optimizer_weight_shard_integrated_save()
+integrated_save = auto_parallel_context().get_optimizer_weight_shard_aggregated_save()
assert device_num == 1
assert global_rank == 0