forked from mindspore-Ecosystem/mindspore
change interface
This commit is contained in:
parent
ea87b6c443
commit
3ad3a71fc7
|
@ -381,8 +381,8 @@ def _context():
|
|||
|
||||
|
||||
@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str,
|
||||
parameter_broadcast=bool, strategy_ckpt_load_file=str, strategy_ckpt_save_file=str,
|
||||
full_batch=bool)
|
||||
auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
|
||||
strategy_ckpt_save_file=str, full_batch=bool)
|
||||
def set_auto_parallel_context(**kwargs):
|
||||
"""
|
||||
Set auto parallel context.
|
||||
|
@ -414,6 +414,12 @@ def set_auto_parallel_context(**kwargs):
|
|||
setting parallel strategies.
|
||||
|
||||
- auto_parallel: Achieving parallelism automatically.
|
||||
auto_parallel_search_mode (str): There are two kinds of search modes, "recursive_programming"
|
||||
and "dynamic_programming".
|
||||
|
||||
- recursive_programming: Recursive programming search mode.
|
||||
|
||||
- dynamic_programming: Dynamic programming search mode.
|
||||
parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
|
||||
"stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
|
||||
broadcast. Default: False.
|
||||
|
|
|
@ -185,13 +185,20 @@ class _AutoParallelContext:
|
|||
self.check_context_handle()
|
||||
return self._context_handle.get_parallel_mode()
|
||||
|
||||
def set_strategy_search_mode(self, strategy_search_mode):
|
||||
def set_strategy_search_mode(self, auto_parallel_search_mode):
|
||||
"""
|
||||
Set search mode of strategy.
|
||||
|
||||
Args:
|
||||
auto_parallel_search_mode (str): The search mode of strategy.
|
||||
"""
|
||||
self.check_context_handle()
|
||||
ret = self._context_handle.set_strategy_search_mode(strategy_search_mode)
|
||||
ret = self._context_handle.set_strategy_search_mode(auto_parallel_search_mode)
|
||||
if ret is False:
|
||||
raise ValueError("Strategy search mode does not support {}".format(strategy_search_mode))
|
||||
raise ValueError("Strategy search mode does not support {}".format(auto_parallel_search_mode))
|
||||
|
||||
def get_strategy_search_mode(self):
|
||||
"""Get search mode of strategy."""
|
||||
self.check_context_handle()
|
||||
return self._context_handle.get_strategy_search_mode()
|
||||
|
||||
|
@ -422,6 +429,7 @@ _set_auto_parallel_context_func_map = {
|
|||
"cast_before_mirror": auto_parallel_context().set_cast_before_mirror,
|
||||
"loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
|
||||
"parallel_mode": auto_parallel_context().set_parallel_mode,
|
||||
"auto_parallel_search_mode": auto_parallel_context().set_strategy_search_mode,
|
||||
"parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
|
||||
"strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
|
||||
"strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
|
||||
|
@ -435,6 +443,7 @@ _get_auto_parallel_context_func_map = {
|
|||
"cast_before_mirror": auto_parallel_context().get_cast_before_mirror,
|
||||
"loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
|
||||
"parallel_mode": auto_parallel_context().get_parallel_mode,
|
||||
"auto_parallel_search_mode": auto_parallel_context().get_strategy_search_mode,
|
||||
"parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
|
||||
"strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
|
||||
"strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
|
||||
|
@ -442,8 +451,9 @@ _get_auto_parallel_context_func_map = {
|
|||
|
||||
|
||||
@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool,
|
||||
loss_repeated_mean=bool, parallel_mode=str, parameter_broadcast=bool,
|
||||
strategy_ckpt_load_file=str, strategy_ckpt_save_file=str, full_batch=bool)
|
||||
loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
|
||||
parameter_broadcast=bool, strategy_ckpt_load_file=str,
|
||||
strategy_ckpt_save_file=str, full_batch=bool)
|
||||
def _set_auto_parallel_context(**kwargs):
|
||||
"""
|
||||
Set auto parallel context.
|
||||
|
@ -471,6 +481,12 @@ def _set_auto_parallel_context(**kwargs):
|
|||
setting parallel strategies.
|
||||
|
||||
- auto_parallel: Achieving parallelism automatically.
|
||||
auto_parallel_search_mode (str): There are two kinds of search modes, "recursive_programming"
|
||||
and "dynamic_programming".
|
||||
|
||||
- recursive_programming: Recursive programming search mode.
|
||||
|
||||
- dynamic_programming: Dynamic programming search mode.
|
||||
parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
|
||||
"stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
|
||||
broadcast. Default: False.
|
||||
|
|
Loading…
Reference in New Issue