From 3ad3a71fc71af580e74cd4543d540005ed53cff7 Mon Sep 17 00:00:00 2001 From: hongxing Date: Thu, 18 Jun 2020 20:57:53 +0100 Subject: [PATCH] change interface --- mindspore/context.py | 10 ++++++-- mindspore/parallel/_auto_parallel_context.py | 26 ++++++++++++++++---- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/mindspore/context.py b/mindspore/context.py index f53f58e3c4a..075ae250ab0 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -381,8 +381,8 @@ def _context(): @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str, - parameter_broadcast=bool, strategy_ckpt_load_file=str, strategy_ckpt_save_file=str, - full_batch=bool) + auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str, + strategy_ckpt_save_file=str, full_batch=bool) def set_auto_parallel_context(**kwargs): """ Set auto parallel context. @@ -414,6 +414,12 @@ def set_auto_parallel_context(**kwargs): setting parallel strategies. - auto_parallel: Achieving parallelism automatically. + auto_parallel_search_mode (str): There are two kinds of search modes: "recursive_programming" + and "dynamic_programming". + + - recursive_programming: Recursive programming search mode. + + - dynamic_programming: Dynamic programming search mode. parameter_broadcast (bool): Indicating whether to broadcast parameters before training. "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter broadcast. Default: False. 
diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py index 3aee81c4d53..56152a380a6 100644 --- a/mindspore/parallel/_auto_parallel_context.py +++ b/mindspore/parallel/_auto_parallel_context.py @@ -185,13 +185,20 @@ class _AutoParallelContext: self.check_context_handle() return self._context_handle.get_parallel_mode() - def set_strategy_search_mode(self, strategy_search_mode): + def set_strategy_search_mode(self, auto_parallel_search_mode): + """ + Set search mode of strategy. + + Args: + auto_parallel_search_mode (str): The search mode of strategy. + """ self.check_context_handle() - ret = self._context_handle.set_strategy_search_mode(strategy_search_mode) + ret = self._context_handle.set_strategy_search_mode(auto_parallel_search_mode) if ret is False: - raise ValueError("Strategy search mode does not support {}".format(strategy_search_mode)) + raise ValueError("Strategy search mode does not support {}".format(auto_parallel_search_mode)) def get_strategy_search_mode(self): + """Get search mode of strategy.""" self.check_context_handle() return self._context_handle.get_strategy_search_mode() @@ -422,6 +429,7 @@ _set_auto_parallel_context_func_map = { "cast_before_mirror": auto_parallel_context().set_cast_before_mirror, "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean, "parallel_mode": auto_parallel_context().set_parallel_mode, + "auto_parallel_search_mode": auto_parallel_context().set_strategy_search_mode, "parameter_broadcast": auto_parallel_context().set_parameter_broadcast, "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file, "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file, @@ -435,6 +443,7 @@ _get_auto_parallel_context_func_map = { "cast_before_mirror": auto_parallel_context().get_cast_before_mirror, "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean, "parallel_mode": auto_parallel_context().get_parallel_mode, + 
"auto_parallel_search_mode": auto_parallel_context().get_strategy_search_mode, "parameter_broadcast": auto_parallel_context().get_parameter_broadcast, "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file, "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file, @@ -442,8 +451,9 @@ _get_auto_parallel_context_func_map = { @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, - loss_repeated_mean=bool, parallel_mode=str, parameter_broadcast=bool, - strategy_ckpt_load_file=str, strategy_ckpt_save_file=str, full_batch=bool) + loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str, + parameter_broadcast=bool, strategy_ckpt_load_file=str, + strategy_ckpt_save_file=str, full_batch=bool) def _set_auto_parallel_context(**kwargs): """ Set auto parallel context. @@ -471,6 +481,12 @@ def _set_auto_parallel_context(**kwargs): setting parallel strategies. - auto_parallel: Achieving parallelism automatically. + auto_parallel_search_mode (str): There are two kinds of search modes: "recursive_programming" + and "dynamic_programming". + + - recursive_programming: Recursive programming search mode. + + - dynamic_programming: Dynamic programming search mode. parameter_broadcast (bool): Indicating whether to broadcast parameters before training. "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter broadcast. Default: False.