set device id master 0813

This commit is contained in:
dingpeifei 2021-08-13 10:18:01 +08:00 committed by d00455729
parent d1555e3056
commit b4bc6000dc
10 changed files with 50 additions and 19 deletions

View File

@@ -951,3 +951,33 @@ def args_type_check(*type_args, **type_kwargs):
return wrapper
return type_check
# Process-wide record of single-set parameters: name -> first value set.
# Shared by every function decorated with args_unreset_check.
_set_record = {}


def args_unreset_check(*unreset_args, **unreset_kwargs):
    """Return a decorator that forbids re-setting of the named parameters.

    The names (with their declared types) passed to this factory mark
    keyword arguments that may be supplied at most once per process.
    A second call to the decorated function that passes one of them
    raises ``TypeError``. State lives in the module-level ``_set_record``.

    Args:
        *unreset_args: positional parameter specs bound against the
            decorated function's signature.
        **unreset_kwargs: keyword parameter specs, e.g. ``device_id=int``.

    Raises:
        TypeError: when a recorded single-set parameter is supplied again.
    """
    def unreset_check(func):
        sig = inspect.signature(func)
        # Map of non-resettable parameter name -> declared type.
        bound_unreset = sig.bind_partial(*unreset_args, **unreset_kwargs).arguments

        @wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal bound_unreset
            bound_values = sig.bind(*args, **kwargs)
            argument_dict = bound_values.arguments
            # When func declares only **kwargs, both mappings arrive nested
            # under the "kwargs" key; unwrap once (sticky via nonlocal).
            if "kwargs" in bound_unreset:
                bound_unreset = bound_unreset["kwargs"]
            if "kwargs" in argument_dict:
                argument_dict = argument_dict["kwargs"]
            for name, value in argument_dict.items():
                # Check membership in this decorator's unreset list first:
                # _set_record is shared across decorated functions, so the
                # reverse order could KeyError on bound_unreset[name].
                if name in bound_unreset:
                    if name in _set_record:
                        raise TypeError('Argument {} is a non-resettable parameter {}.'
                                        .format(name, bound_unreset[name]))
                    _set_record[name] = value
            return func(*args, **kwargs)
        return wrapper
    return unreset_check

View File

@@ -24,7 +24,7 @@ from collections import namedtuple
from types import FunctionType
from mindspore import log as logger
from mindspore._c_expression import MSContext, ms_ctx_param
from mindspore._checkparam import args_type_check, Validator
from mindspore._checkparam import args_type_check, Validator, args_unreset_check
from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
_reset_auto_parallel_context
from mindspore.parallel._ps_context import _set_ps_context, _get_ps_context, _reset_ps_context
@@ -507,7 +507,7 @@ def _check_target_specific_cfgs(device, arg_key):
", ignore it.")
return False
@args_unreset_check(device_id=int, variable_memory_max_size=str, max_device_memory=str)
@args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool,
save_graphs_path=str, enable_dump=bool, auto_tune_mode=str,
save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,

View File

@@ -21,7 +21,7 @@ from mindspore.common.tensor import Tensor
from mindspore import nn
from mindspore.ops.operations import _quant_ops as Q
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU', device_id=0)
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
class Net(nn.Cell):

View File

@@ -20,7 +20,7 @@ import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops.operations import _quant_ops as Q
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU', device_id=0)
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
class Net(nn.Cell):

View File

@@ -20,7 +20,7 @@ import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops.operations import _quant_ops as Q
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU', device_id=0)
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
class Net(nn.Cell):

View File

@@ -21,9 +21,8 @@ from mindspore.context import ParallelMode
from mindspore.communication._comm_helper import GlobalComm
from .test_auto_parallel_resnet import resnet50
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=0)
context.set_context.__wrapped__(device_id=0)
GlobalComm.CHECK_ENVS = False
init()
GlobalComm.CHECK_ENVS = True

View File

@@ -34,7 +34,7 @@ from mindspore.context import ParallelMode
from mindspore.communication._comm_helper import GlobalComm
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=0)
context.set_context.__wrapped__(device_id=0)
GlobalComm.CHECK_ENVS = False
init()
GlobalComm.CHECK_ENVS = True

View File

@@ -33,7 +33,7 @@ from mindspore.context import ParallelMode
from mindspore.communication._comm_helper import GlobalComm
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=0)
context.set_context.__wrapped__(device_id=0)
GlobalComm.CHECK_ENVS = False
init()
GlobalComm.CHECK_ENVS = True

View File

@@ -49,9 +49,8 @@ def test_switch_mode():
def test_set_device_id():
""" test_set_device_id """
with pytest.raises(TypeError):
context.set_context(device_id=1)
context.set_context(device_id="cpu")
assert context.get_context("device_id") == 0
context.set_context(device_id=1)
assert context.get_context("device_id") == 1
@@ -115,14 +114,17 @@ def test_variable_memory_max_size():
context.set_context(variable_memory_max_size=True)
with pytest.raises(TypeError):
context.set_context(variable_memory_max_size=1)
with pytest.raises(ValueError):
context.set_context(variable_memory_max_size="")
with pytest.raises(ValueError):
context.set_context(variable_memory_max_size="1G")
with pytest.raises(ValueError):
context.set_context(variable_memory_max_size="32GB")
context.set_context(variable_memory_max_size="3GB")
context.set_context.__wrapped__(variable_memory_max_size="3GB")
def test_max_device_memory_size():
"""test_max_device_memory_size"""
with pytest.raises(TypeError):
context.set_context(max_device_memory=True)
with pytest.raises(TypeError):
context.set_context(max_device_memory=1)
context.set_context(max_device_memory="3.5G")
context.set_context.__wrapped__(max_device_memory="3GB")
def test_print_file_path():
"""test_print_file_path"""
@@ -132,8 +134,9 @@ def test_print_file_path():
def test_set_context():
""" test_set_context """
context.set_context.__wrapped__(device_id=0)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
device_id=0, save_graphs=True, save_graphs_path="mindspore_ir_path")
save_graphs=True, save_graphs_path="mindspore_ir_path")
assert context.get_context("device_id") == 0
assert context.get_context("device_target") == "Ascend"
assert context.get_context("save_graphs")

View File

@@ -21,7 +21,6 @@ from mindspore.common.tensor import Tensor
def setup_module(module):
context.set_context(mode=context.PYNATIVE_MODE, save_graphs=False, device_target="Ascend")
context.set_context(device_id=0)
c1 = Tensor([2], mstype.int32)