From b4bc6000dc29bf53986a26a9be83ba6aa78676ea Mon Sep 17 00:00:00 2001
From: dingpeifei
Date: Fri, 13 Aug 2021 10:18:01 +0800
Subject: [PATCH] set device id master 0813

---
 mindspore/_checkparam.py                      | 30 +++++++++++++++++++
 mindspore/context.py                          |  4 +--
 .../st/ops/gpu/test_fake_quant_perchannel.py  |  2 +-
 .../gpu/test_fake_quant_perchannel_grad.py    |  2 +-
 .../ops/gpu/test_fake_quant_perlayer_grad.py  |  2 +-
 .../test_auto_parallel_resnet_predict.py      |  3 +-
 ...to_parallel_resnet_sharding_propagation.py |  2 +-
 ...o_parallel_resnet_sharding_propagation2.py |  2 +-
 tests/ut/python/pynative_mode/test_context.py | 21 +++++++------
 .../pynative_mode/test_multigraph_sink.py     |  1 -
 10 files changed, 50 insertions(+), 19 deletions(-)

diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py
index 58cec1666a4..27441d6a448 100644
--- a/mindspore/_checkparam.py
+++ b/mindspore/_checkparam.py
@@ -951,3 +951,33 @@ def args_type_check(*type_args, **type_kwargs):
         return wrapper
 
     return type_check
+
+
+_set_record = {}
+
+
+def args_unreset_check(*unreset_args, **unreset_kwargs):
+    """Check that non-resettable arguments are not set more than once."""
+
+    def unreset_check(func):
+        sig = inspect.signature(func)
+        bound_unreset = sig.bind_partial(*unreset_args, **unreset_kwargs).arguments
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            nonlocal bound_unreset
+            bound_values = sig.bind(*args, **kwargs)
+            argument_dict = bound_values.arguments
+            if "kwargs" in bound_unreset:
+                bound_unreset = bound_unreset["kwargs"]
+            if "kwargs" in argument_dict:
+                argument_dict = argument_dict["kwargs"]
+            for name, value in argument_dict.items():
+                if name in _set_record:
+                    raise TypeError("Argument '{}' ({}) cannot be set repeatedly.".format(name, bound_unreset[name]))
+                if name in bound_unreset:
+                    _set_record[name] = value
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return unreset_check
diff --git a/mindspore/context.py b/mindspore/context.py
index 7b96088fb77..043057a663b 100644
--- a/mindspore/context.py
+++ b/mindspore/context.py
@@ -24,7 +24,7 @@ from collections import namedtuple
 from types import FunctionType
 from mindspore import log as logger
 from mindspore._c_expression import MSContext, ms_ctx_param
-from mindspore._checkparam import args_type_check, Validator
+from mindspore._checkparam import args_type_check, Validator, args_unreset_check
 from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
     _reset_auto_parallel_context
 from mindspore.parallel._ps_context import _set_ps_context, _get_ps_context, _reset_ps_context
@@ -507,7 +507,7 @@ def _check_target_specific_cfgs(device, arg_key):
                     ", ignore it.")
         return False
 
-
+@args_unreset_check(device_id=int, variable_memory_max_size=str, max_device_memory=str)
 @args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int,
                  save_graphs=bool, save_graphs_path=str, enable_dump=bool, auto_tune_mode=str,
                  save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
diff --git a/tests/st/ops/gpu/test_fake_quant_perchannel.py b/tests/st/ops/gpu/test_fake_quant_perchannel.py
index 91185bcc14e..3d76386e4b0 100644
--- a/tests/st/ops/gpu/test_fake_quant_perchannel.py
+++ b/tests/st/ops/gpu/test_fake_quant_perchannel.py
@@ -21,7 +21,7 @@ from mindspore.common.tensor import Tensor
 from mindspore import nn
 from mindspore.ops.operations import _quant_ops as Q
 
-context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU', device_id=0)
+context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
 
 
 class Net(nn.Cell):
diff --git a/tests/st/ops/gpu/test_fake_quant_perchannel_grad.py b/tests/st/ops/gpu/test_fake_quant_perchannel_grad.py
index 09f483a80c4..becdc95fae7 100644
--- a/tests/st/ops/gpu/test_fake_quant_perchannel_grad.py
+++ b/tests/st/ops/gpu/test_fake_quant_perchannel_grad.py
@@ -20,7 +20,7 @@ import mindspore.nn as nn
 import mindspore.context as context
 from mindspore.ops.operations import _quant_ops as Q
 
-context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU', device_id=0)
+context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
 
 
 class Net(nn.Cell):
diff --git a/tests/st/ops/gpu/test_fake_quant_perlayer_grad.py b/tests/st/ops/gpu/test_fake_quant_perlayer_grad.py
index 9289d2f954e..6685aa3ff26 100644
--- a/tests/st/ops/gpu/test_fake_quant_perlayer_grad.py
+++ b/tests/st/ops/gpu/test_fake_quant_perlayer_grad.py
@@ -20,7 +20,7 @@ import mindspore.nn as nn
 import mindspore.context as context
 from mindspore.ops.operations import _quant_ops as Q
 
-context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU', device_id=0)
+context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
 
 
 class Net(nn.Cell):
diff --git a/tests/ut/python/parallel/test_auto_parallel_resnet_predict.py b/tests/ut/python/parallel/test_auto_parallel_resnet_predict.py
index 7510b02c265..15a8d647d3f 100644
--- a/tests/ut/python/parallel/test_auto_parallel_resnet_predict.py
+++ b/tests/ut/python/parallel/test_auto_parallel_resnet_predict.py
@@ -21,9 +21,8 @@ from mindspore.context import ParallelMode
 from mindspore.communication._comm_helper import GlobalComm
 from .test_auto_parallel_resnet import resnet50
 
-
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-context.set_context(device_id=0)
+context.set_context.__wrapped__(device_id=0)
 GlobalComm.CHECK_ENVS = False
 init()
 GlobalComm.CHECK_ENVS = True
diff --git a/tests/ut/python/parallel/test_auto_parallel_resnet_sharding_propagation.py b/tests/ut/python/parallel/test_auto_parallel_resnet_sharding_propagation.py
index f021485b07f..947c454c052 100644
--- a/tests/ut/python/parallel/test_auto_parallel_resnet_sharding_propagation.py
+++ b/tests/ut/python/parallel/test_auto_parallel_resnet_sharding_propagation.py
@@ -34,7 +34,7 @@ from mindspore.context import ParallelMode
 from mindspore.communication._comm_helper import GlobalComm
 
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-context.set_context(device_id=0)
+context.set_context.__wrapped__(device_id=0)
 GlobalComm.CHECK_ENVS = False
 init()
 GlobalComm.CHECK_ENVS = True
diff --git a/tests/ut/python/parallel/test_auto_parallel_resnet_sharding_propagation2.py b/tests/ut/python/parallel/test_auto_parallel_resnet_sharding_propagation2.py
index 977dc8d0a23..f019979abfd 100644
--- a/tests/ut/python/parallel/test_auto_parallel_resnet_sharding_propagation2.py
+++ b/tests/ut/python/parallel/test_auto_parallel_resnet_sharding_propagation2.py
@@ -33,7 +33,7 @@ from mindspore.context import ParallelMode
 from mindspore.communication._comm_helper import GlobalComm
 
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-context.set_context(device_id=0)
+context.set_context.__wrapped__(device_id=0)
 GlobalComm.CHECK_ENVS = False
 init()
 GlobalComm.CHECK_ENVS = True
diff --git a/tests/ut/python/pynative_mode/test_context.py b/tests/ut/python/pynative_mode/test_context.py
index 56eb983d84d..0301385f9f7 100644
--- a/tests/ut/python/pynative_mode/test_context.py
+++ b/tests/ut/python/pynative_mode/test_context.py
@@ -49,9 +49,8 @@ def test_switch_mode():
 def test_set_device_id():
     """ test_set_device_id """
     with pytest.raises(TypeError):
+        context.set_context(device_id=1)
         context.set_context(device_id="cpu")
-    assert context.get_context("device_id") == 0
-    context.set_context(device_id=1)
     assert context.get_context("device_id") == 1
 
 
@@ -115,14 +114,17 @@ def test_variable_memory_max_size():
         context.set_context(variable_memory_max_size=True)
     with pytest.raises(TypeError):
         context.set_context(variable_memory_max_size=1)
-    with pytest.raises(ValueError):
-        context.set_context(variable_memory_max_size="")
-    with pytest.raises(ValueError):
         context.set_context(variable_memory_max_size="1G")
-    with pytest.raises(ValueError):
-        context.set_context(variable_memory_max_size="32GB")
-    context.set_context(variable_memory_max_size="3GB")
+    context.set_context.__wrapped__(variable_memory_max_size="3GB")
+def test_max_device_memory_size():
+    """test_max_device_memory_size"""
+    with pytest.raises(TypeError):
+        context.set_context(max_device_memory=True)
+    with pytest.raises(TypeError):
+        context.set_context(max_device_memory=1)
+        context.set_context(max_device_memory="3.5G")
+    context.set_context.__wrapped__(max_device_memory="3GB")
 
 
 def test_print_file_path():
     """test_print_file_path"""
@@ -132,8 +134,9 @@ def test_print_file_path():
 
 def test_set_context():
     """ test_set_context """
+    context.set_context.__wrapped__(device_id=0)
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
-                        device_id=0, save_graphs=True, save_graphs_path="mindspore_ir_path")
+                        save_graphs=True, save_graphs_path="mindspore_ir_path")
     assert context.get_context("device_id") == 0
     assert context.get_context("device_target") == "Ascend"
     assert context.get_context("save_graphs")
diff --git a/tests/ut/python/pynative_mode/test_multigraph_sink.py b/tests/ut/python/pynative_mode/test_multigraph_sink.py
index c4ef44ef5a2..cdf246e29dc 100644
--- a/tests/ut/python/pynative_mode/test_multigraph_sink.py
+++ b/tests/ut/python/pynative_mode/test_multigraph_sink.py
@@ -21,7 +21,6 @@ from mindspore.common.tensor import Tensor
 
 def setup_module(module):
     context.set_context(mode=context.PYNATIVE_MODE, save_graphs=False, device_target="Ascend")
-    context.set_context(device_id=0)
 
 
 c1 = Tensor([2], mstype.int32)
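
How the once-only guard behaves, as a minimal standalone sketch. This is not
the patched code itself: set_context below is a toy stand-in for
mindspore.context.set_context, and the positional-argument handling and type
bookkeeping of the real args_unreset_check are omitted.

import inspect
from functools import wraps

# Remembers the first value given for each guarded keyword argument.
_set_record = {}


def args_unreset_check(**unreset_kwargs):
    """Reject resetting of keyword arguments that may only be set once."""
    def unreset_check(func):
        sig = inspect.signature(func)

        @wraps(func)
        def wrapper(*args, **kwargs):
            arguments = sig.bind(*args, **kwargs).arguments
            if "kwargs" in arguments:
                arguments = arguments["kwargs"]
            for name, value in arguments.items():
                if name in _set_record:
                    raise TypeError("Argument '{}' cannot be set repeatedly.".format(name))
                if name in unreset_kwargs:
                    _set_record[name] = value
            return func(*args, **kwargs)
        return wrapper
    return unreset_check


@args_unreset_check(device_id=int)
def set_context(**kwargs):
    print("context updated:", kwargs)


set_context(device_id=0)        # first assignment is recorded
try:
    set_context(device_id=1)    # re-assignment raises TypeError
except TypeError as exc:
    print(exc)
# functools.wraps exposes the next callable down as __wrapped__, so this
# call bypasses the once-only guard; the updated unit tests rely on the
# same trick via context.set_context.__wrapped__(device_id=0).
set_context.__wrapped__(device_id=1)

Note that in the patched context.py, set_context carries two stacked
decorators with args_unreset_check outermost, so set_context.__wrapped__
peels off only the reset guard while args_type_check still runs; in this
single-decorator sketch, __wrapped__ is the undecorated function.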