!32437 bug fix for pynative auto parallel

Merge pull request !32437 from wangjun/pynative_autoparallel_fix
This commit is contained in:
i-robot 2022-04-02 06:43:39 +00:00 committed by Gitee
commit 6124c82785
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 7 additions and 2 deletions

View File

@ -359,8 +359,9 @@ class _AutoParallelContext:
if run_mode == context.PYNATIVE_MODE and parallel_mode not in (
context.ParallelMode.DATA_PARALLEL, context.ParallelMode.STAND_ALONE,
context.ParallelMode.AUTO_PARALLEL):
raise ValueError(f"Pynative Only support STAND_ALONE, DATA_PARALLEL and AUTO_PARALLEL under shard function"
f"for ParallelMode, "
raise ValueError(f"Pynative only supports STAND_ALONE, DATA_PARALLEL and AUTO_PARALLEL using"
f" sharding_propagation under shard function"
f" for ParallelMode, "
f"but got {parallel_mode.upper()}.")
ret = self._context_handle.set_parallel_mode(parallel_mode)
if ret is False:
@ -383,6 +384,10 @@ class _AutoParallelContext:
search_mode (str): The search mode of strategy.
"""
self.check_context_handle()
run_mode = context.get_context("mode")
if run_mode == context.PYNATIVE_MODE and search_mode != "sharding_propagation":
raise ValueError(f"PyNative only supports AUTO_PARALLEL using sharding_propagation"
f" but got search_mode of {search_mode}.")
ret = self._context_handle.set_strategy_search_mode(search_mode)
if ret is False:
raise ValueError("The context configuration parameter 'auto_parallel_search_mode' only support "