change interface

This commit is contained in:
hongxing 2020-06-18 20:57:53 +01:00
parent ea87b6c443
commit 3ad3a71fc7
2 changed files with 29 additions and 7 deletions

View File

@ -381,8 +381,8 @@ def _context():
@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str,
parameter_broadcast=bool, strategy_ckpt_load_file=str, strategy_ckpt_save_file=str,
full_batch=bool)
auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
strategy_ckpt_save_file=str, full_batch=bool)
def set_auto_parallel_context(**kwargs):
"""
Set auto parallel context.
@ -414,6 +414,12 @@ def set_auto_parallel_context(**kwargs):
setting parallel strategies.
- auto_parallel: Achieving parallelism automatically.
auto_parallel_search_mode (str): There are two kinds of search modes, "recursive_programming"
and "dynamic_programming".
- recursive_programming: Recursive programming search mode.
- dynamic_programming: Dynamic programming search mode.
parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
"stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
broadcast. Default: False.

View File

@ -185,13 +185,20 @@ class _AutoParallelContext:
self.check_context_handle()
return self._context_handle.get_parallel_mode()
def set_strategy_search_mode(self, strategy_search_mode):
def set_strategy_search_mode(self, auto_parallel_search_mode):
"""
Set search mode of strategy.
Args:
auto_parallel_search_mode (str): The search mode of strategy.
"""
self.check_context_handle()
ret = self._context_handle.set_strategy_search_mode(strategy_search_mode)
ret = self._context_handle.set_strategy_search_mode(auto_parallel_search_mode)
if ret is False:
raise ValueError("Strategy search mode does not support {}".format(strategy_search_mode))
raise ValueError("Strategy search mode does not support {}".format(auto_parallel_search_mode))
def get_strategy_search_mode(self):
"""Get search mode of strategy."""
self.check_context_handle()
return self._context_handle.get_strategy_search_mode()
@ -422,6 +429,7 @@ _set_auto_parallel_context_func_map = {
"cast_before_mirror": auto_parallel_context().set_cast_before_mirror,
"loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
"parallel_mode": auto_parallel_context().set_parallel_mode,
"auto_parallel_search_mode": auto_parallel_context().set_strategy_search_mode,
"parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
"strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
"strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
@ -435,6 +443,7 @@ _get_auto_parallel_context_func_map = {
"cast_before_mirror": auto_parallel_context().get_cast_before_mirror,
"loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
"parallel_mode": auto_parallel_context().get_parallel_mode,
"auto_parallel_search_mode": auto_parallel_context().get_strategy_search_mode,
"parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
"strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
"strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
@ -442,8 +451,9 @@ _get_auto_parallel_context_func_map = {
@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool,
loss_repeated_mean=bool, parallel_mode=str, parameter_broadcast=bool,
strategy_ckpt_load_file=str, strategy_ckpt_save_file=str, full_batch=bool)
loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
parameter_broadcast=bool, strategy_ckpt_load_file=str,
strategy_ckpt_save_file=str, full_batch=bool)
def _set_auto_parallel_context(**kwargs):
"""
Set auto parallel context.
@ -471,6 +481,12 @@ def _set_auto_parallel_context(**kwargs):
setting parallel strategies.
- auto_parallel: Achieving parallelism automatically.
auto_parallel_search_mode (str): There are two kinds of search modes, "recursive_programming"
and "dynamic_programming".
- recursive_programming: Recursive programming search mode.
- dynamic_programming: Dynamic programming search mode.
parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
"stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
broadcast. Default: False.