From 5937d0335541640b07bf84633cbc96741c7df376 Mon Sep 17 00:00:00 2001 From: chujinjin Date: Wed, 25 Aug 2021 15:40:33 +0800 Subject: [PATCH] add st for pynative synchronize --- mindspore/context.py | 8 +-- .../st/pynative/test_pynative_sync_control.py | 51 +++++++++++++++++++ 2 files changed, 56 insertions(+), 3 deletions(-) create mode 100644 tests/st/pynative/test_pynative_sync_control.py diff --git a/mindspore/context.py b/mindspore/context.py index 1470b3519a2..97f09279b88 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -516,7 +516,7 @@ def _check_target_specific_cfgs(device, arg_key): enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool, enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str, enable_sparse=bool, max_call_depth=int, env_config_path=str, graph_kernel_flags=str, - save_compile_cache=bool, load_compile_cache=bool, grad_for_scalar=bool) + save_compile_cache=bool, load_compile_cache=bool, grad_for_scalar=bool, pynative_synchronize=bool) def set_context(**kwargs): """ Set context for running environment. @@ -544,14 +544,14 @@ def set_context(**kwargs): check_bprop print_file_path max_device_memory device_id enable_dump enable_graph_kernel device_target save_dump_path graph_kernel_flags - enable_sparse enable_graph_kernel + enable_sparse enable_graph_kernel pynative_synchronize max_call_depth enable_reduce_precision mode enable_profiling reserve_class_name_in_scope profiling_options save_graphs variable_memory_max_size save_graphs_path auto_tune_mode env_config_path graph_kernel_flags - grad_for_scalar + grad_for_scalar pynative_synchronize save_compile_cache load_compile_cache =========================== =========================== ================= @@ -663,6 +663,8 @@ def set_context(**kwargs): you should make sure the network has not been changed since the last execution. By now, we have not support automatically checking the changes yet. Default: False. This is an experimental prototype that is subject to change and/or deletion. + pynative_synchronize (bool): Whether to enable asynchronous execution of the device in Pynative mode. + Default: False. Raises: ValueError: If input key is not an attribute in context. diff --git a/tests/st/pynative/test_pynative_sync_control.py b/tests/st/pynative/test_pynative_sync_control.py new file mode 100644 index 00000000000..a2fb44ff58d --- /dev/null +++ b/tests/st/pynative/test_pynative_sync_control.py @@ -0,0 +1,51 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common import dtype as mstype +from mindspore.ops import operations as P + +context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.get_next = P.GetNext([mstype.float32], [(1, 1)], 1, "test") + + def construct(self, x1,): + x = self.get_next() + x = x + x1 + return x + +def test_pynative_synchronize_true(): + context.set_context(pynative_synchronize=True) + with pytest.raises(RuntimeError) as execinfo: + x1 = np.random.randn(1, 1).astype(np.float32) + net = Net() + output = net(Tensor(x1)) + print(output.asnumpy()) + assert "GetNext" in str(execinfo.value) + +def test_pynative_synchronize_false(): + context.set_context(pynative_synchronize=False) + with pytest.raises(RuntimeError) as execinfo: + x1 = np.random.randn(1, 1).astype(np.float32) + net = Net() + output = net(Tensor(x1)) + print(output.asnumpy()) + assert "Sync stream error" in str(execinfo.value)