!22359 add st for pynative synchronize

Merge pull request !22359 from chujinjin/add_st_for_pynative_synchronize
This commit is contained in:
i-robot 2021-08-26 03:35:11 +00:00 committed by Gitee
commit 785e5fe6fd
2 changed files with 56 additions and 3 deletions

View File

@ -516,7 +516,7 @@ def _check_target_specific_cfgs(device, arg_key):
enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool, enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str, enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str,
enable_sparse=bool, max_call_depth=int, env_config_path=str, graph_kernel_flags=str, enable_sparse=bool, max_call_depth=int, env_config_path=str, graph_kernel_flags=str,
save_compile_cache=bool, load_compile_cache=bool, grad_for_scalar=bool) save_compile_cache=bool, load_compile_cache=bool, grad_for_scalar=bool, pynative_synchronize=bool)
def set_context(**kwargs): def set_context(**kwargs):
""" """
Set context for running environment. Set context for running environment.
@ -544,14 +544,14 @@ def set_context(**kwargs):
check_bprop print_file_path max_device_memory check_bprop print_file_path max_device_memory
device_id enable_dump enable_graph_kernel device_id enable_dump enable_graph_kernel
device_target save_dump_path graph_kernel_flags device_target save_dump_path graph_kernel_flags
enable_sparse enable_graph_kernel enable_sparse enable_graph_kernel pynative_synchronize
max_call_depth enable_reduce_precision max_call_depth enable_reduce_precision
mode enable_profiling mode enable_profiling
reserve_class_name_in_scope profiling_options reserve_class_name_in_scope profiling_options
save_graphs variable_memory_max_size save_graphs variable_memory_max_size
save_graphs_path auto_tune_mode save_graphs_path auto_tune_mode
env_config_path graph_kernel_flags env_config_path graph_kernel_flags
grad_for_scalar grad_for_scalar pynative_synchronize
save_compile_cache save_compile_cache
load_compile_cache load_compile_cache
=========================== =========================== ================= =========================== =========================== =================
@ -663,6 +663,8 @@ def set_context(**kwargs):
you should make sure the network has not been changed since the last execution. By now, we have you should make sure the network has not been changed since the last execution. By now, we have
not support automatically checking the changes yet. Default: False. not support automatically checking the changes yet. Default: False.
This is an experimental prototype that is subject to change and/or deletion. This is an experimental prototype that is subject to change and/or deletion.
pynative_synchronize (bool): Whether to enable asynchronous execution of the device in Pynative mode.
Default: False.
Raises: Raises:
ValueError: If input key is not an attribute in context. ValueError: If input key is not an attribute in context.

View File

@ -0,0 +1,51 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.get_next = P.GetNext([mstype.float32], [(1, 1)], 1, "test")
def construct(self, x1,):
x = self.get_next()
x = x + x1
return x
def test_pynative_synchronize_true():
context.set_context(pynative_synchronize=True)
with pytest.raises(RuntimeError) as execinfo:
x1 = np.random.randn(1, 1).astype(np.float32)
net = Net()
output = net(Tensor(x1))
print(output.asnumpy())
assert "GetNext" in str(execinfo.value)
def test_pynative_synchronize_false():
context.set_context(pynative_synchronize=False)
with pytest.raises(RuntimeError) as execinfo:
x1 = np.random.randn(1, 1).astype(np.float32)
net = Net()
output = net(Tensor(x1))
print(output.asnumpy())
assert "Sync stream error" in str(execinfo.value)