forked from mindspore-Ecosystem/mindspore
fix optimizer weight shard config
commit 95ac0f6d58
parent 88be613cdc
@@ -70,7 +70,7 @@ void ParallelContext::Reset() {
   grad_accumulation_step_ = 1;
   communi_parallel_mode_ = ALL_GROUP_PARALLEL;
   optimizer_weight_shard_size_ = -1;
-  optimizer_weight_shard_integrated_save_ = false;
+  optimizer_weight_shard_aggregated_save_ = false;
 }

 void ParallelContext::set_device_num(int64_t device_num) {
@@ -138,8 +138,8 @@ void ParallelContext::set_optimizer_weight_shard_size(int64_t optimizer_weight_s
   optimizer_weight_shard_size_ = optimizer_weight_shard_size;
 }

-void ParallelContext::set_optimizer_weight_shard_integrated_save(bool optimizer_weight_shard_integrated_save) {
-  optimizer_weight_shard_integrated_save_ = optimizer_weight_shard_integrated_save;
+void ParallelContext::set_optimizer_weight_shard_aggregated_save(bool optimizer_weight_shard_aggregated_save) {
+  optimizer_weight_shard_aggregated_save_ = optimizer_weight_shard_aggregated_save;
 }

 void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group) {
@@ -97,8 +97,8 @@ class ParallelContext {

   void set_optimizer_weight_shard_size(int64_t optimizer_weight_shard_size);
   int64_t optimizer_weight_shard_size() const { return optimizer_weight_shard_size_; }
-  void set_optimizer_weight_shard_integrated_save(bool optimizer_weight_shard_integrated_save);
-  bool optimizer_weight_shard_integrated_save() const { return optimizer_weight_shard_integrated_save_; }
+  void set_optimizer_weight_shard_aggregated_save(bool optimizer_weight_shard_aggregated_save);
+  bool optimizer_weight_shard_aggregated_save() const { return optimizer_weight_shard_aggregated_save_; }

   void SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group);
   const std::vector<uint32_t> GetAllReduceFusionSplitIndices(const std::string &group) const;
@@ -158,7 +158,7 @@ class ParallelContext {
   bool init_param_shape_;
   std::string communi_parallel_mode_;
   int64_t optimizer_weight_shard_size_;
-  bool optimizer_weight_shard_integrated_save_;
+  bool optimizer_weight_shard_aggregated_save_;
 };

 } // namespace parallel
@@ -557,7 +557,7 @@ Status OperatorInfo::CreateGroupForOptShard(TensorLayout *const tensor_layout, s
     MS_LOG(INFO) << "Parallel optimizer: create allgather group " << allgather_group.name();
   }
   // save in tensor_layout for strategy ckpt
-  auto integrated_save = ParallelContext::GetInstance()->optimizer_weight_shard_integrated_save();
+  auto integrated_save = ParallelContext::GetInstance()->optimizer_weight_shard_aggregated_save();
   if (!integrated_save) {
     tensor_layout->set_opt_weight_shard_size(LongToInt(optimizer_weight_shard_size));
     int64_t opt_weight_shard_step =
@@ -177,9 +177,9 @@ PYBIND11_MODULE(_c_expression, m) {
          "Set opt shard group size when not fully use parallel optimizer.")
     .def("get_optimizer_weight_shard_size", &ParallelContext::optimizer_weight_shard_size,
          "Get opt shard group size when not fully use parallel optimizer.")
-    .def("set_optimizer_weight_shard_integrated_save", &ParallelContext::set_optimizer_weight_shard_integrated_save,
+    .def("set_optimizer_weight_shard_aggregated_save", &ParallelContext::set_optimizer_weight_shard_aggregated_save,
          "Set whether to integrated save weight shard when enable parallel optimizer.")
-    .def("get_optimizer_weight_shard_integrated_save", &ParallelContext::optimizer_weight_shard_integrated_save,
+    .def("get_optimizer_weight_shard_aggregated_save", &ParallelContext::optimizer_weight_shard_aggregated_save,
          "Get whether to integrated save weight shard when enable parallel optimizer.")
     .def("reset", &ParallelContext::Reset, "Reset auto parallel context.");

@@ -479,24 +479,24 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_optimizer_weight_shard_size()

-    def set_optimizer_weight_shard_integrated_save(self, optimizer_weight_shard_integrated_save):
+    def set_optimizer_weight_shard_aggregated_save(self, optimizer_weight_shard_aggregated_save):
         """
-        Set optimizer_weight_shard_integrated_save.
+        Set optimizer_weight_shard_aggregated_save.

         Args:
-            optimizer_weight_shard_integrated_save (bool): Whether to integrated save weight shard when
+            optimizer_weight_shard_aggregated_save (bool): Whether to integrated save weight shard when
                 enable parallel optimizer.
         """
         self.check_context_handle()
-        if not isinstance(optimizer_weight_shard_integrated_save, bool):
-            raise TypeError('optimizer_weight_shard_integrated_save is invalid type')
-        self._context_handle.set_optimizer_weight_shard_integrated_save(optimizer_weight_shard_integrated_save)
+        if not isinstance(optimizer_weight_shard_aggregated_save, bool):
+            raise TypeError('optimizer_weight_shard_aggregated_save is invalid type')
+        self._context_handle.set_optimizer_weight_shard_aggregated_save(optimizer_weight_shard_aggregated_save)


-    def get_optimizer_weight_shard_integrated_save(self):
+    def get_optimizer_weight_shard_aggregated_save(self):
         """Get optimizer_weight_shard_size."""
         self.check_context_handle()
-        return self._context_handle.get_optimizer_weight_shard_integrated_save()
+        return self._context_handle.get_optimizer_weight_shard_aggregated_save()


     def reset(self):
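For reference, the renamed setter/getter pair above can be driven directly through the context singleton, mirroring the updated unit test at the end of this diff. A minimal sketch, assuming the module path mindspore.parallel._auto_parallel_context (file paths are not shown in this view) and a build where the _c_expression context handle is available:

    from mindspore.parallel._auto_parallel_context import auto_parallel_context

    ctx = auto_parallel_context()
    # New names introduced by this commit.
    ctx.set_optimizer_weight_shard_aggregated_save(True)
    assert ctx.get_optimizer_weight_shard_aggregated_save()

    # A non-bool argument trips the isinstance check above and raises TypeError.
    try:
        ctx.set_optimizer_weight_shard_aggregated_save("yes")
    except TypeError:
        pass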
@@ -557,7 +557,7 @@ _set_auto_parallel_context_func_map = {
     "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices,
     "communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode,
     "optimizer_weight_shard_size": auto_parallel_context().set_optimizer_weight_shard_size,
-    "optimizer_weight_shard_integrated_save": auto_parallel_context().set_optimizer_weight_shard_integrated_save}
+    "optimizer_weight_shard_aggregated_save": auto_parallel_context().set_optimizer_weight_shard_aggregated_save}


 _get_auto_parallel_context_func_map = {
@@ -578,7 +578,7 @@ _get_auto_parallel_context_func_map = {
     "all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices,
     "communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode,
     "optimizer_weight_shard_size": auto_parallel_context().get_optimizer_weight_shard_size,
-    "optimizer_weight_shard_integrated_save": auto_parallel_context().get_optimizer_weight_shard_integrated_save}
+    "optimizer_weight_shard_aggregated_save": auto_parallel_context().get_optimizer_weight_shard_aggregated_save}


 @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
@@ -587,7 +587,7 @@ _get_auto_parallel_context_func_map = {
                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
                  grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
                  communi_parallel_mode=str, optimizer_weight_shard_size=int,
-                 optimizer_weight_shard_integrated_save=bool)
+                 optimizer_weight_shard_aggregated_save=bool)

 def _set_auto_parallel_context(**kwargs):
     """
@@ -647,7 +647,7 @@ def _set_auto_parallel_context(**kwargs):
         optimizer_weight_shard_size (int): Set optimizer shard group size when not fully use parallel optimizer.
             It should be larger than one and less than or equal with the data parallel size.
            Default: -1, which means fully use parallel optimizer in data parallel dimension.
-        optimizer_weight_shard_integrated_save (bool): Whether to integrated save weight shard when enable parallel
+        optimizer_weight_shard_aggregated_save (bool): Whether to integrated save weight shard when enable parallel
             optimizer. Default: False.

     Raises:
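The keyword interface documented above accepts the renamed flag as well. A hedged sketch, again assuming the module path used in the previous example; the group size of 2 is only an illustrative value, and per the @args_type_check signature above the flag must be a bool:

    from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context

    # Shard optimizer state across groups of 2 devices and keep the aggregated save behavior off by default.
    _set_auto_parallel_context(optimizer_weight_shard_size=2,
                               optimizer_weight_shard_aggregated_save=True)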
@@ -59,8 +59,8 @@ def test_set_auto_parallel_context():
     parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set()
     assert parameter_broadcast_is_set

-    auto_parallel_context().set_optimizer_weight_shard_integrated_save(True)
-    integrated_save = auto_parallel_context().get_optimizer_weight_shard_integrated_save()
+    auto_parallel_context().set_optimizer_weight_shard_aggregated_save(True)
+    integrated_save = auto_parallel_context().get_optimizer_weight_shard_aggregated_save()
     assert integrated_save

     with pytest.raises(ValueError):
@@ -109,7 +109,7 @@ def test_reset_auto_parallel_context():
     parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set()
     stage = auto_parallel_context().get_pipeline_stages()
     communi_parallel_mode = context.get_auto_parallel_context("communi_parallel_mode")
-    integrated_save = auto_parallel_context().get_optimizer_weight_shard_integrated_save()
+    integrated_save = auto_parallel_context().get_optimizer_weight_shard_aggregated_save()

     assert device_num == 1
     assert global_rank == 0