forked from mindspore-Ecosystem/mindspore
!21776 set device_id master 0813
Merge pull request !21776 from mindspore_ding/set_device_id_master_0813
This commit is contained in:
commit
fa12d62d4d
|
@ -951,3 +951,33 @@ def args_type_check(*type_args, **type_kwargs):
|
|||
return wrapper
|
||||
|
||||
return type_check
|
||||
|
||||
|
||||
_set_record = {}
|
||||
|
||||
|
||||
def args_unreset_check(*unreset_args, **unreset_kwargs):
|
||||
"""Check the entered non repeatable setting properties."""
|
||||
|
||||
def unreset_check(func):
|
||||
sig = inspect.signature(func)
|
||||
bound_unreset = sig.bind_partial(*unreset_args, **unreset_kwargs).arguments
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
nonlocal bound_unreset
|
||||
bound_values = sig.bind(*args, **kwargs)
|
||||
argument_dict = bound_values.arguments
|
||||
if "kwargs" in bound_unreset:
|
||||
bound_unreset = bound_unreset["kwargs"]
|
||||
if "kwargs" in argument_dict:
|
||||
argument_dict = argument_dict["kwargs"]
|
||||
for name, value in argument_dict.items():
|
||||
if name in _set_record.keys():
|
||||
raise TypeError('Argument{}non resettable parameter{}.'.format(name, bound_unreset[name]))
|
||||
if name in bound_unreset:
|
||||
_set_record[name] = value
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return unreset_check
|
||||
|
|
|
@ -25,7 +25,7 @@ from types import FunctionType
|
|||
|
||||
from mindspore import log as logger
|
||||
from mindspore._c_expression import MSContext, ms_ctx_param
|
||||
from mindspore._checkparam import args_type_check, Validator
|
||||
from mindspore._checkparam import args_type_check, Validator, args_unreset_check
|
||||
from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
|
||||
_reset_auto_parallel_context
|
||||
from mindspore.parallel._ps_context import _set_ps_context, _get_ps_context, _reset_ps_context
|
||||
|
@ -510,7 +510,7 @@ def _check_target_specific_cfgs(device, arg_key):
|
|||
", ignore it.")
|
||||
return False
|
||||
|
||||
|
||||
@args_unreset_check(device_id=int, variable_memory_max_size=str, max_device_memory=str)
|
||||
@args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool,
|
||||
save_graphs_path=str, enable_dump=bool, auto_tune_mode=str,
|
||||
save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
|
||||
|
|
|
@ -21,7 +21,7 @@ from mindspore.common.tensor import Tensor
|
|||
from mindspore import nn
|
||||
from mindspore.ops.operations import _quant_ops as Q
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU', device_id=0)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
|
|
|
@ -20,7 +20,7 @@ import mindspore.nn as nn
|
|||
import mindspore.context as context
|
||||
from mindspore.ops.operations import _quant_ops as Q
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU', device_id=0)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
|
|
|
@ -20,7 +20,7 @@ import mindspore.nn as nn
|
|||
import mindspore.context as context
|
||||
from mindspore.ops.operations import _quant_ops as Q
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU', device_id=0)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
|
|
|
@ -21,9 +21,8 @@ from mindspore.context import ParallelMode
|
|||
from mindspore.communication._comm_helper import GlobalComm
|
||||
from .test_auto_parallel_resnet import resnet50
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
context.set_context(device_id=0)
|
||||
context.set_context.__wrapped__(device_id=0)
|
||||
GlobalComm.CHECK_ENVS = False
|
||||
init()
|
||||
GlobalComm.CHECK_ENVS = True
|
||||
|
|
|
@ -34,7 +34,7 @@ from mindspore.context import ParallelMode
|
|||
from mindspore.communication._comm_helper import GlobalComm
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
context.set_context(device_id=0)
|
||||
context.set_context.__wrapped__(device_id=0)
|
||||
GlobalComm.CHECK_ENVS = False
|
||||
init()
|
||||
GlobalComm.CHECK_ENVS = True
|
||||
|
|
|
@ -33,7 +33,7 @@ from mindspore.context import ParallelMode
|
|||
from mindspore.communication._comm_helper import GlobalComm
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
context.set_context(device_id=0)
|
||||
context.set_context.__wrapped__(device_id=0)
|
||||
GlobalComm.CHECK_ENVS = False
|
||||
init()
|
||||
GlobalComm.CHECK_ENVS = True
|
||||
|
|
|
@ -49,9 +49,8 @@ def test_switch_mode():
|
|||
def test_set_device_id():
|
||||
""" test_set_device_id """
|
||||
with pytest.raises(TypeError):
|
||||
context.set_context(device_id=1)
|
||||
context.set_context(device_id="cpu")
|
||||
assert context.get_context("device_id") == 0
|
||||
context.set_context(device_id=1)
|
||||
assert context.get_context("device_id") == 1
|
||||
|
||||
|
||||
|
@ -115,14 +114,17 @@ def test_variable_memory_max_size():
|
|||
context.set_context(variable_memory_max_size=True)
|
||||
with pytest.raises(TypeError):
|
||||
context.set_context(variable_memory_max_size=1)
|
||||
with pytest.raises(ValueError):
|
||||
context.set_context(variable_memory_max_size="")
|
||||
with pytest.raises(ValueError):
|
||||
context.set_context(variable_memory_max_size="1G")
|
||||
with pytest.raises(ValueError):
|
||||
context.set_context(variable_memory_max_size="32GB")
|
||||
context.set_context(variable_memory_max_size="3GB")
|
||||
context.set_context.__wrapped__(variable_memory_max_size="3GB")
|
||||
|
||||
def test_max_device_memory_size():
|
||||
"""test_max_device_memory_size"""
|
||||
with pytest.raises(TypeError):
|
||||
context.set_context(max_device_memory=True)
|
||||
with pytest.raises(TypeError):
|
||||
context.set_context(max_device_memory=1)
|
||||
context.set_context(max_device_memory="3.5G")
|
||||
context.set_context.__wrapped__(max_device_memory="3GB")
|
||||
|
||||
def test_print_file_path():
|
||||
"""test_print_file_path"""
|
||||
|
@ -132,8 +134,9 @@ def test_print_file_path():
|
|||
|
||||
def test_set_context():
|
||||
""" test_set_context """
|
||||
context.set_context.__wrapped__(device_id=0)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
|
||||
device_id=0, save_graphs=True, save_graphs_path="mindspore_ir_path")
|
||||
save_graphs=True, save_graphs_path="mindspore_ir_path")
|
||||
assert context.get_context("device_id") == 0
|
||||
assert context.get_context("device_target") == "Ascend"
|
||||
assert context.get_context("save_graphs")
|
||||
|
|
|
@ -21,7 +21,6 @@ from mindspore.common.tensor import Tensor
|
|||
|
||||
def setup_module(module):
|
||||
context.set_context(mode=context.PYNATIVE_MODE, save_graphs=False, device_target="Ascend")
|
||||
context.set_context(device_id=0)
|
||||
|
||||
|
||||
c1 = Tensor([2], mstype.int32)
|
||||
|
|
Loading…
Reference in New Issue