!17082 set default context mode to GRAPH_MODE

Merge pull request !17082 from huangbingjian/change_context_mode
i-robot 2021-06-10 21:21:46 +08:00 committed by Gitee
commit 21513404c4
57 changed files with 106 additions and 72 deletions

View File

@@ -147,7 +147,7 @@ class _Context:
    def __init__(self):
        self._thread_local_info = _ThreadLocalInfo()
-        self._context_switches = _ContextSwitchInfo(True)
+        self._context_switches = _ContextSwitchInfo(False)
        self._context_handle = MSContext.get_instance()

    def __new__(cls, *args, **kwargs):
@@ -522,7 +522,7 @@ def set_context(**kwargs):
    Context should be configured before running your program. If there is no configuration,
    it will automatic acquisition according to device target by default. GRAPH_MODE or
    PYNATIVE_MODE can be set by `mode` attribute and both modes support all backends, default
-    mode is PYNATIVE_MODE.
+    mode is GRAPH_MODE.

    When the `save_graphs` attribute is set to True, attribute of `save_graphs_path` is used to set the
    intermediate compilation graph storage path. By default, the graphs are saved in the current directory.
@@ -532,7 +532,7 @@ def set_context(**kwargs):
    Note:
        Attribute name is required for setting attributes.
        The mode is not recommended to be changed after net was initialized because the implementations of some
-        operations are different in graph mode and pynative mode. Default: PYNATIVE_MODE.
+        operations are different in graph mode and pynative mode. Default: GRAPH_MODE.

    Some configurations are device specific, see the below table for details:
@@ -555,7 +555,7 @@ def set_context(**kwargs):
    =========================== =========================== =================

    Args:
-        mode (int): Running in GRAPH_MODE(0) or PYNATIVE_MODE(1). Default: PYNATIVE_MODE(1).
+        mode (int): Running in GRAPH_MODE(0) or PYNATIVE_MODE(1). Default: GRAPH_MODE(0).
        precompile_only (bool): Whether to only precompile the network. If set, the network will only be compiled and
            not executed. Default: False.
        device_target (str): The target device to run, support "Ascend", "GPU", and "CPU".
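
With the default flipped to GRAPH_MODE, scripts that relied on the old PyNative default must now ask for it explicitly. A minimal sketch of the migration (the "CPU" device_target here is only illustrative):

import mindspore.context as context

# GRAPH_MODE is now implied; the old eager behavior must be requested explicitly.
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")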

View File

@@ -61,7 +61,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
  set_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH, MAX_CALL_DEPTH_DEFAULT);
  set_param<std::string>(MS_CTX_DEVICE_TARGET, target);
-  set_param<int>(MS_CTX_EXECUTION_MODE, kPynativeMode);
+  set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
  set_param<bool>(MS_CTX_ENABLE_TASK_SINK, true);
  set_param<bool>(MS_CTX_IR_FUSION_FLAG, true);
  set_param<bool>(MS_CTX_ENABLE_HCCL, false);
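
This C++ constructor default is what the Python side reports through get_context. A quick sanity check of the new default in a fresh process (a sketch, not part of this diff):

import mindspore.context as context

# GRAPH_MODE is 0 and PYNATIVE_MODE is 1; after this commit a fresh process reports 0.
assert context.get_context("mode") == context.GRAPH_MODE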

View File

@@ -6,6 +6,7 @@ from mindspore.common.parameter import Parameter
from mindspore.nn import Cell
import mindspore.ops.operations as P

+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

@pytest.mark.level0
@@ -34,7 +35,6 @@ def test_if_by_if_basic():
    class Net(Cell):
        def __init__(self):
            super().__init__()
-            context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
            self.subnet = SubNet()
            self.relu = P.ReLU()
            self.add = P.Add()

View File

@@ -19,6 +19,7 @@ import pytest
import mindspore as ms
import mindspore.nn as nn
+from mindspore import context

from mindspore.explainer._utils import (
    ForwardProbe,
@@ -27,6 +28,7 @@ from mindspore.explainer._utils import (
    retrieve_layer_by_name)
from mindspore.explainer.explanation._attribution._backprop.backprop_utils import GradNet, get_bp_weights

+context.set_context(mode=context.PYNATIVE_MODE)

class CustomNet(nn.Cell):
    """Simple net for test."""

View File

@@ -18,12 +18,14 @@ import numpy as np
import pytest

import mindspore.nn as nn
-from mindspore import Parameter, Tensor
+from mindspore import Parameter, Tensor, context
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.train.serialization import export

+context.set_context(mode=context.PYNATIVE_MODE)

class MeanConv(nn.Cell):
    def __init__(self,

View File

@@ -18,10 +18,12 @@ import pytest
import numpy as onp

import mindspore.numpy as mnp
+from mindspore import context

from .utils import rand_int, rand_bool, match_array, match_res, match_meta, \
    match_all_arrays, run_multi_test, to_tensor

+context.set_context(mode=context.PYNATIVE_MODE)

class Cases():
    def __init__(self):

View File

@@ -20,11 +20,14 @@ import pytest
import numpy as onp

import mindspore.numpy as mnp
+from mindspore import context
from mindspore.nn import Cell

from .utils import rand_int, run_non_kw_test, check_all_results, match_array, \
    rand_bool, match_res, run_multi_test, to_tensor, match_all_arrays

+context.set_context(mode=context.PYNATIVE_MODE)

class Cases():
    def __init__(self):

View File

@@ -18,10 +18,13 @@ import pytest
import numpy as onp

import mindspore.numpy as mnp
+from mindspore import context

from .utils import rand_int, rand_bool, run_binop_test, run_logical_test, match_res, \
    match_all_arrays, to_tensor

+context.set_context(mode=context.PYNATIVE_MODE)

class Cases():
    def __init__(self):

View File

@@ -18,11 +18,14 @@ import pytest
import numpy as onp

import mindspore.numpy as mnp
+from mindspore import context
from mindspore.common.dtype import dtype_to_nptype

from .utils import rand_int, rand_bool, run_binop_test, run_unary_test, run_multi_test, \
    run_single_test, match_res, match_array, match_meta, match_all_arrays, to_tensor

+context.set_context(mode=context.PYNATIVE_MODE)

class Cases():
    def __init__(self):
        self.arrs = [

View File

@@ -22,7 +22,7 @@ import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function

-context.set_context(device_target='CPU')
+context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')

class NetNorm(nn.Cell):
    def __init__(self):

View File

@@ -21,7 +21,7 @@ import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function

-context.set_context(device_target='CPU')
+context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')

class NetOneHot(nn.Cell):

View File

@@ -21,7 +21,7 @@ import mindspore.nn as nn
import mindspore.context as context
from mindspore.common.api import ms_function

-context.set_context(device_target="CPU")
+context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")

class NetReduce(nn.Cell):

View File

@@ -23,7 +23,7 @@ from mindspore.common.parameter import Parameter
import mindspore.nn as nn
import mindspore.context as context

-context.set_context(device_target='CPU')
+context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')

class Transpose(nn.Cell):

View File

@@ -23,7 +23,7 @@ from mindspore import Tensor, ops
from mindspore.ops import operations as P
from mindspore.common.api import ms_function

-context.set_context(device_target='GPU')
+context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')

class Net(nn.Cell):

View File

@@ -60,7 +60,7 @@ class AddNet(nn.Cell):
def add(nptype):
-    context.set_context(device_target='GPU')
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    add_net = AddNet(nptype)
    output = add_net()

View File

@@ -22,8 +22,6 @@ from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.ops import operations as P

-context.set_context(device_target='GPU')

class Net(nn.Cell):
    def __init__(self):
@@ -39,6 +37,7 @@ class Net(nn.Cell):
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_net():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    x = np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4).astype(np.float32)
    y = np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4).astype(np.float32)
    z = np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4).astype(np.float32)
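
The hunk above shows the pattern repeated throughout the test diffs that follow: the module-level context.set_context(device_target=...) call is dropped, and each test sets the execution mode itself, so no test file silently depends on the flipped default. A sketch of the resulting shape (IdentityNet and test_identity are illustrative names, not from this commit):

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor

class IdentityNet(nn.Cell):
    # Hypothetical stand-in for the nets under test.
    def construct(self, x):
        return x

def test_identity():
    # The mode is chosen per test, since GRAPH_MODE is now the process-wide default.
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    x = np.ones((2, 2)).astype(np.float32)
    out = IdentityNet()(Tensor(x))
    assert np.allclose(out.asnumpy(), x)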

View File

@@ -23,7 +23,7 @@ from mindspore.common.api import ms_function
from mindspore.ops import operations as P
from mindspore.ops.operations import _quant_ops as Q

-context.set_context(device_target='GPU')
+context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')

class Net(nn.Cell):

View File

@@ -22,7 +22,7 @@ from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.ops.operations import _quant_ops as Q

-context.set_context(device_target='GPU')
+context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')

class Net(nn.Cell):

View File

@@ -24,7 +24,7 @@ from mindspore.common.initializer import initializer
from mindspore.common.api import ms_function
from mindspore.ops.operations import _quant_ops as Q

-context.set_context(device_target='GPU')
+context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')

class Net(nn.Cell):

View File

@@ -41,7 +41,7 @@ class ConcatV32(nn.Cell):
def axis32(nptype):
-    context.set_context(device_target='GPU')
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    cat = ConcatV32(nptype)
    output = cat()
@@ -98,7 +98,7 @@ class ConcatV43(nn.Cell):
def axis43(nptype):
-    context.set_context(device_target='GPU')
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    cat = ConcatV43(nptype)
    output = cat()
@@ -159,6 +159,8 @@ class ConcatV21(nn.Cell):
def axis21(nptype):
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    cat = ConcatV21(nptype)
    output = cat()
    expect = np.array([[0., 1., 0., 1., 2.],
@@ -206,8 +208,9 @@ class Concat3INet(nn.Cell):
def concat_3i(nptype):
-    cat = Concat3INet()
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
+    cat = Concat3INet()
    x1_np = np.random.randn(32, 4, 224, 224).astype(nptype)
    x2_np = np.random.randn(32, 8, 224, 224).astype(nptype)
    x3_np = np.random.randn(32, 10, 224, 224).astype(nptype)
@@ -250,6 +253,7 @@ def test_concat_3i_uint8():
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_concat_3i_bool():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    cat = Concat3INet()
    x1_np = np.random.choice([True, False], (32, 4, 224, 224)).astype(np.bool)
@@ -275,8 +279,9 @@ class Concat4INet(nn.Cell):
def concat_4i(nptype):
-    cat = Concat4INet()
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
+    cat = Concat4INet()
    x1_np = np.random.randn(32, 4, 224, 224).astype(nptype)
    x2_np = np.random.randn(32, 8, 224, 224).astype(nptype)
    x3_np = np.random.randn(32, 10, 224, 224).astype(nptype)
@@ -321,8 +326,9 @@ def test_concat_4i_uint8():
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_concat_4i_bool():
-    cat = Concat4INet()
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
+    cat = Concat4INet()
    x1_np = np.random.choice([True, False], (32, 4, 224, 224)).astype(np.bool)
    x2_np = np.random.choice([True, False], (32, 8, 224, 224)).astype(np.bool)
    x3_np = np.random.choice([True, False], (32, 10, 224, 224)).astype(np.bool)

View File

@@ -23,7 +23,7 @@ from mindspore.common.api import ms_function
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G

-context.set_context(device_target='GPU')
+context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')

class Conv2dFilter(nn.Cell):

View File

@@ -22,7 +22,7 @@ from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.ops import operations as P

-context.set_context(device_target='GPU')
+context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')

class Conv2dInput(nn.Cell):

View File

@@ -22,7 +22,7 @@ from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.ops.operations import _quant_ops as Q

-context.set_context(device_target='GPU')
+context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')

class Net(nn.Cell):

View File

@@ -22,7 +22,7 @@ from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.ops.operations import _quant_ops as Q

-context.set_context(device_target='GPU')
+context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')

class Net(nn.Cell):

View File

@@ -65,8 +65,8 @@ def test_inplace_fusion1():
    x2 = Tensor(x2_np.astype(np.float32))
    x3 = Tensor(x3_np.astype(np.float32))
-    net = Conv2dBpropInputInplace(w1, w2)
+    context.set_context(device_target='GPU', mode=context.GRAPH_MODE)
+    net = Conv2dBpropInputInplace(w1, w2)
    fusion_output = net(x1, x2, x3)
    context.set_context(device_target='GPU', mode=context.PYNATIVE_MODE)
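
This hunk differs from the per-test migrations: the fusion under test only exists in graph compilation, so the test now pins GRAPH_MODE just before building the net, then flips the same process back to PYNATIVE_MODE, presumably for the non-fused reference computed afterwards. Roughly, reusing the names from the hunk:

# Compile and run the fused graph, then switch the mode for the eager follow-up.
context.set_context(device_target='GPU', mode=context.GRAPH_MODE)
net = Conv2dBpropInputInplace(w1, w2)
fusion_output = net(x1, x2, x3)
context.set_context(device_target='GPU', mode=context.PYNATIVE_MODE)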

View File

@@ -23,7 +23,7 @@ from mindspore.common.api import ms_function
from mindspore.ops import operations as P

def cum_prod(nptype):
-    context.set_context(device_target='GPU')
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    x0 = np.random.rand(2, 3, 4, 4).astype(nptype)
    axis0 = 3

View File

@@ -23,7 +23,7 @@ from mindspore.common.api import ms_function
from mindspore.ops import operations as P

def cum_sum(nptype):
-    context.set_context(device_target='GPU')
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    x0 = np.random.rand(2, 3, 4, 4).astype(nptype)
    axis0 = 3

View File

@@ -34,6 +34,7 @@ class Net(nn.Cell):
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_dropout():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    x_shape = [32, 16, 2, 5]
    x = np.ones(x_shape).astype(np.float32)
    keep_prob = 0.4

View File

@@ -21,7 +21,7 @@ from mindspore.common.tensor import Tensor
from mindspore import nn
from mindspore.ops.operations import _quant_ops as Q

-context.set_context(device_target='GPU', device_id=0)
+context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU', device_id=0)

class Net(nn.Cell):
@@ -240,7 +240,7 @@ def test_fake_quant_perchannel10():
    x = np.array([-0.1, 0.0, 0.1, 0.25, 0.5, 0.75,
                  1.0, 1.25, 1.5, 1.75, 2.0, 2.25,
                  63.0, 63.25, 63.5, 63.7, 63.75, 63.8,
-                  63.9, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
+                  63.9, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape((1, 4, 2, 3)).astype(np.float32)
    expect = np.array([0.0, 0.0, 0.0, 0.25, 0.5, 0.75,
                       1.0, 1.25, 1.5, 1.75, 2.0, 2.25,
                       63.0, 63.25, 63.5, 63.75, 63.75, 63.75,
@@ -269,7 +269,7 @@ def test_fake_quant_perchannel11():
    x = np.array([-0.1, 0.0, 0.1, 0.25, 0.5, 0.75,
                  1.0, 1.25, 1.5, 1.75, 2.0, 2.25,
                  63.0, 63.25, 63.3, 63.4, 63.5, 63.6,
-                  63.7, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
+                  63.7, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape((1, 4, 2, 3)).astype(np.float32)
    expect = np.array([0.0, 0.0, 0.0, 0.25, 0.5, 0.75,
                       1.0, 1.25, 1.5, 1.75, 2.0, 2.25,
                       63.0, 63.25, 63.25, 63.5, 63.5, 63.5,
@@ -296,7 +296,7 @@ def test_fake_quant_perchannel12():
    x = np.array([-0.3, -0.25, -0.2, 0.0, 0.25, 0.5,
                  0.75, 1.0, 1.25, 1.5, 1.75, 2.0,
                  63.0, 63.25, 63.4, 63.5, 63.6, 63.7,
-                  100.0, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
+                  100.0, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape((1, 4, 2, 3)).astype(np.float32)
    expect = np.array([-0.25, -0.25, -0.25, 0.0, 0.25, 0.5,
                       0.75, 1.0, 1.25, 1.5, 1.75, 2.0,
                       63.0, 63.25, 63.5, 63.5, 63.5, 63.5,
@@ -325,7 +325,7 @@ def test_fake_quant_perchannel13():
    x = np.array([-0.3, -0.25, -0.2, 0.0, 0.25, 0.5,
                  0.75, 1.0, 1.25, 1.5, 1.75, 2.0,
                  63.0, 63.2, 63.25, 63.3, 63.4, 63.5,
-                  100.0, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
+                  100.0, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape((1, 4, 2, 3)).astype(np.float32)
    expect = np.array([-0.25, -0.25, -0.25, 0.0, 0.25, 0.5,
                       0.75, 1.0, 1.25, 1.5, 1.75, 2.0,
                       63.0, 63.25, 63.25, 63.25, 63.25, 63.25,
@@ -526,7 +526,7 @@ def test_fake_quant_perchannel22():
    x = np.array([-0.1, 0.0, 0.1, 0.5, 1.0, 1.5,
                  1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
                  6.0, 6.5, 7.0, 7.4, 7.5, 7.7,
-                  7.8, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
+                  7.8, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape((1, 4, 2, 3)).astype(np.float32)
    expect = np.array([0.0, 0.0, 0.0, 0.5, 1.0, 1.5,
                       1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
                       6.0, 6.5, 7.0, 7.5, 7.5, 7.5,
@@ -553,7 +553,7 @@ def test_fake_quant_perchannel23():
    x = np.array([-0.1, 0.0, 0.1, 0.5, 1.0, 1.5,
                  1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
                  6.0, 6.5, 6.8, 6.9, 7.0, 7.1,
-                  7.2, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
+                  7.2, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape((1, 4, 2, 3)).astype(np.float32)
    expect = np.array([0.0, 0.0, 0.0, 0.5, 1.0, 1.5,
                       1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
                       6.0, 6.5, 7.0, 7.0, 7.0, 7.0,
@@ -580,7 +580,7 @@ def test_fake_quant_perchannel24():
    x = np.array([-0.6, -0.5, -0.4, 0.0, 0.5, 1.0,
                  1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
                  6.0, 6.5, 6.9, 7.0, 7.1, 7.7,
-                  100.0, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
+                  100.0, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape((1, 4, 2, 3)).astype(np.float32)
    expect = np.array([-0.5, -0.5, -0.5, 0.0, 0.5, 1.0,
                       1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
                       6.0, 6.5, 7.0, 7.0, 7.0, 7.0,
@@ -607,7 +607,7 @@ def test_fake_quant_perchannel25():
    x = np.array([-0.6, -0.5, -0.4, 0.0, 0.5, 1.0,
                  1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
                  5.5, 6.0, 6.4, 6.5, 6.6, 6.7,
-                  100.0, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
+                  100.0, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape((1, 4, 2, 3)).astype(np.float32)
    expect = np.array([-0.5, -0.5, -0.5, 0.0, 0.5, 1.0,
                       1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
                       5.5, 6.0, 6.5, 6.5, 6.5, 6.5,

View File

@@ -20,7 +20,7 @@ import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops.operations import _quant_ops as Q

-context.set_context(device_target='GPU', device_id=0)
+context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU', device_id=0)

class Net(nn.Cell):
@@ -256,7 +256,7 @@ def test_fake_quant_grad10():
    x = np.array([-0.1, 0.0, 63.75, 63.8, -0.1, 0.0,
                  63.75, 63.8, -0.1, 0.0, 63.75, 63.8,
                  -0.1, 0.0, 63.75, 63.8, -0.1, 0.0,
-                  63.75, 63.8, -0.1, 0.0, 63.75, 63.8]).reshape(4, 3, 2, 1).astype(np.float32)
+                  63.75, 63.8, -0.1, 0.0, 63.75, 63.8]).reshape((4, 3, 2, 1)).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1, -0.1]).astype(np.float32)
    max_val = np.array([63.65, 63.65, 63.65, 63.65]).astype(np.float32)
    dout = read_dout.flatten()
@@ -286,7 +286,7 @@ def test_fake_quant_grad11():
    # WithVarsPerChannelDim4GradientNudgedDown_NarrowRange
    read_dout = np.random.uniform(-1, 1, size=[4, 3, 2, 1]).astype('float32')
    x = np.array([-0.1, 0.0, 63.5, 63.6, -0.1, 0.0, 63.5, 63.6, -0.1, 0.0, 63.5, 63.6, -0.1, 0.0, 63.5,
-                  63.6, -0.1, 0.0, 63.5, 63.6, -0.1, 0.0, 63.5, 63.6]).reshape(4, 3, 2, 1).astype(np.float32)
+                  63.6, -0.1, 0.0, 63.5, 63.6, -0.1, 0.0, 63.5, 63.6]).reshape((4, 3, 2, 1)).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1, -0.1]).astype(np.float32)
    max_val = np.array([63.4, 63.4, 63.4, 63.4]).astype(np.float32)
    dout = read_dout.flatten()
@@ -318,7 +318,7 @@ def test_fake_quant_grad12():
    x = np.array([-0.3, -0.25, 63.5, 63.6, -0.3, -0.25,
                  63.5, 63.6, -0.3, -0.25, 63.5, 63.6,
                  -0.3, -0.25, 63.5, 63.6, -0.3, -0.25,
-                  63.5, 63.6, -0.3, -0.25, 63.5, 63.6]).reshape(4, 3, 2, 1).astype(np.float32)
+                  63.5, 63.6, -0.3, -0.25, 63.5, 63.6]).reshape((4, 3, 2, 1)).astype(np.float32)
    min_val = np.array([-0.125, -0.125, -0.125, -0.125]).astype(np.float32)
    max_val = np.array([63.625, 63.625, 63.625, 63.625]).astype(np.float32)
    dout = read_dout.flatten()
@@ -350,7 +350,7 @@ def test_fake_quant_grad13():
    x = np.array([-0.3, -0.25, 63.25, 63.3, -0.3, -0.25,
                  63.25, 63.3, -0.3, -0.25, 63.25, 63.3,
                  -0.3, -0.25, 63.25, 63.3, -0.3, -0.25,
-                  63.25, 63.3, -0.3, -0.25, 63.25, 63.3]).reshape(4, 3, 2, 1).astype(np.float32)
+                  63.25, 63.3, -0.3, -0.25, 63.25, 63.3]).reshape((4, 3, 2, 1)).astype(np.float32)
    min_val = np.array([-0.125, -0.125, -0.125, -0.125]).astype(np.float32)
    max_val = np.array([63.375, 63.375, 63.375, 63.375]).astype(np.float32)
    dout = read_dout.flatten()

View File

@@ -21,7 +21,7 @@ from mindspore.common.tensor import Tensor
import mindspore.nn as nn
from mindspore.ops.operations import _quant_ops as Q

-context.set_context(device_target='GPU', device_id=0)
+context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU', device_id=0)

class Net(nn.Cell):

View File

@@ -20,7 +20,7 @@ import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops.operations import _quant_ops as Q

-context.set_context(device_target='GPU', device_id=0)
+context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU', device_id=0)

class Net(nn.Cell):

View File

@@ -103,6 +103,7 @@ def test_in_top_k_float32():
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_in_top_k_invalid_input():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    # k must be > 0
    with pytest.raises(ValueError):
        in_top_k_net = InTopKNet(0)

View File

@@ -224,6 +224,7 @@ def test_index_add_int16():
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_index_add_invalid_inputs():
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    x = np.arange(2 * 3 * 4).reshape(2, 3, 4).astype(np.uint8)
    y = np.ones((2, 2, 4), dtype=np.uint8)
    with pytest.raises(TypeError):

View File

@@ -82,6 +82,8 @@ class CustomLoss(Loss):
        return self.get_loss(x, weights=2.0)

def custom_loss(nptype):
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    loss = L1Loss()
    input_data = Tensor(np.array([[1, 2, 3], [2, 3, 4]]).astype(nptype))
    target_data = Tensor(np.array([[0, 2, 5], [3, 1, 1]]).astype(nptype))

View File

@@ -25,7 +25,7 @@ from mindspore.common.tensor import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P

-context.set_context(device_target='GPU')
+context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')

class LstmNet(nn.Cell):

View File

@@ -21,7 +21,7 @@ import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function

-context.set_context(device_target='GPU')
+context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')

class NetOneHot(nn.Cell):

View File

@@ -21,6 +21,9 @@ import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

+context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

class RangeNet(nn.Cell):
    def __init__(self, maxlen=50):
        super(RangeNet, self).__init__()
@@ -34,8 +37,6 @@ class RangeNet(nn.Cell):
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_range_precision_end_equals_last_element():
-    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    range_net = RangeNet(100)
    ms_out = range_net(Tensor(1000.04, mstype.float32),
                       Tensor(1001.04, mstype.float32),
@@ -68,8 +69,6 @@ def test_range_precision_end_equals_last_element():
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_range_int():
-    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    range_net = RangeNet()
    ms_out = range_net(Tensor(2, mstype.int32), Tensor(5, mstype.int32), Tensor(1, mstype.int32)).asnumpy()
    np_expected = np.array([2, 3, 4])
@@ -94,8 +93,6 @@ def test_range_int():
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_range_float():
-    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    range_net = RangeNet()
    ms_out = range_net(Tensor(2.3, mstype.float32), Tensor(5.5, mstype.float32), Tensor(1.2, mstype.float32)).asnumpy()
    np_expected = np.array([2.3, 3.5, 4.7])

View File

@@ -39,8 +39,6 @@ x3 = np.array([[True, True], [True, False], [False, False]])
axis3 = 1
keep_dims3 = False

-context.set_context(device_target='GPU')

class ReduceAll(nn.Cell):
    def __init__(self):
@@ -75,6 +73,7 @@ class ReduceAll(nn.Cell):
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ReduceAll():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    reduce_all = ReduceAll()
    output = reduce_all()

View File

@@ -39,8 +39,6 @@ x3 = np.array([[True, True], [True, False], [False, False]])
axis3 = 1
keep_dims3 = False

-context.set_context(device_target='GPU')

class ReduceAny(nn.Cell):
    def __init__(self):
@@ -75,6 +73,7 @@ class ReduceAny(nn.Cell):
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ReduceAny():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    reduce_any = ReduceAny()
    output = reduce_any()

View File

@@ -63,8 +63,6 @@ axis8 = ()
np_axis8 = None
keep_dims8 = True

-context.set_context(device_target='GPU')

class ReduceMax(nn.Cell):
    def __init__(self):
@@ -123,6 +121,7 @@ class ReduceMax(nn.Cell):
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ReduceMax():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    reduce_max = ReduceMax()
    output = reduce_max()

View File

@@ -84,8 +84,6 @@ axis14 = ()
np_axis14 = None
keep_dims14 = True

-context.set_context(device_target='GPU')

class ReduceMean(nn.Cell):
    def __init__(self):
@@ -174,6 +172,7 @@ class ReduceMean(nn.Cell):
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ReduceMean():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    reduce_mean = ReduceMean()
    output = reduce_mean()

View File

@@ -63,8 +63,6 @@ axis8 = ()
np_axis8 = None
keep_dims8 = True

-context.set_context(device_target='GPU')

class ReduceMin(nn.Cell):
    def __init__(self):
@@ -123,6 +121,7 @@ class ReduceMin(nn.Cell):
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ReduceMin():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    reduce_min = ReduceMin()
    output = reduce_min()

View File

@@ -86,8 +86,6 @@ axis14 = ()
np_axis14 = None
keep_dims14 = True

-context.set_context(device_target='GPU')

class ReduceSum(nn.Cell):
    def __init__(self):
@@ -176,6 +174,7 @@ class ReduceSum(nn.Cell):
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ReduceSum():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    reduce_sum = ReduceSum()
    output = reduce_sum()

View File

@@ -101,6 +101,7 @@ def test_reverse_v2_int64():
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_reverse_v2_invalid_axis():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    x = Tensor(np.arange(60).reshape(1, 2, 3, 2, 5).astype(np.int32))
    with pytest.raises(ValueError) as info:

View File

@@ -135,6 +135,7 @@ def test_sampled_softmax_loss_none_sampler():
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sampledsoftmaxloss_reduction_invalid():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    # Check 'reduction'
    with pytest.raises(ValueError):
        nn.SampledSoftmaxLoss(num_sampled=4, num_classes=7, reduction="")

View File

@@ -62,6 +62,7 @@ def test_slice_4d():
    x_np = np.random.randn(32, 24, 224, 224).astype(np.float32)
    output_np = x_np[:, 11:18, :, :]

+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    x_ms = Tensor(x_np)
    net = SliceNet()
    output_ms = net(x_ms)

View File

@@ -22,7 +22,7 @@ from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.ops.operations import _grad_ops as G

-context.set_context(device_target='GPU')
+context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')

class SliceGrad(nn.Cell):

View File

@@ -85,6 +85,7 @@ class Transpose_dynamic2(nn.Cell):
        return (out_1, out_2)

def transpose1(nptype):
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    transpose = Transpose(nptype)
    output = transpose()
    expect0 = np.array([[[0, 6, 12, 18, 24],

View File

@@ -23,7 +23,6 @@ from mindspore.common import dtype as mstype
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops import operations as P

-context.set_context(device_target='GPU')

class UnsortedSegmentSumNet(nn.Cell):
    def __init__(self, num_segments):
@@ -39,6 +38,7 @@ class UnsortedSegmentSumNet(nn.Cell):
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_1D():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    input_x = Tensor([1, 2, 3, 4], mstype.float32)
    segment_ids = Tensor([0, 0, 1, 2], mstype.int32)
    num_segments = 4
@@ -53,6 +53,7 @@ def test_1D():
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_2D():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    input_x = Tensor([[1, 2, 3, 4],
                      [5, 6, 7, 8],
                      [9, 10, 11, 12]], mstype.float32)
@@ -72,6 +73,7 @@ def test_2D():
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_3D():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    input_x = Tensor(np.arange(4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3))
    segment_ids = Tensor([2, 1, 1, -1], mstype.int32)
    num_segments = 5

View File

@@ -157,6 +157,7 @@ def test_zeros_like_dynamic_float64():
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_zeros_like_dynamic_multiple_inputs():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    net = ZerosLikeDynamicNet()
    x = Tensor(np.arange(4).reshape(4).astype(np.float32))

View File

@@ -191,6 +191,7 @@ class ArgMaxWithValueFactory(OpsFactory):
        return input_grad.asnumpy()

    def forward_cmp(self):
+        context.set_context(mode=context.PYNATIVE_MODE, device_target=context.get_context('device_target'))
        out_numpy = self.forward_numpy_impl()
        out_mindspore = self.forward_mindspore_impl()
        allclose_nparray(out_numpy[0], out_mindspore[0], self.loss, self.loss)

View File

@@ -18,7 +18,10 @@ import numpy as np
import pytest

import mindspore.nn as nn
-from mindspore import Tensor
+from mindspore import context, Tensor

+context.set_context(mode=context.PYNATIVE_MODE)

weight = Tensor(np.ones([2, 2]))
conv2 = nn.Conv2d(3, 64, (3, 3), stride=2, padding=0)

View File

@@ -34,6 +34,8 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
from ....mindspore_test_framework.pipeline.forward.verify_exception \
    import pipeline_for_verify_exception_for_case_by_case_config

+context.set_context(mode=context.PYNATIVE_MODE)

def test_expand_dims():
    input_tensor = Tensor(np.array([[2, 2], [2, 2]]))

View File

@@ -16,13 +16,15 @@
import numpy as np

import mindspore.nn as nn
-from mindspore import Tensor
+from mindspore import context, Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
    import pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception

+context.set_context(mode=context.PYNATIVE_MODE)

class ExpandDimsNet(nn.Cell):
    def __init__(self, axis):

View File

@@ -27,7 +27,7 @@ from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
    import pipeline_for_compile_forward_ge_graph_for_case_by_case_config

+context.set_context(mode=context.PYNATIVE_MODE)

grad_by_list_with_sens = C.GradOperation(get_by_list=True, sens_param=True)

View File

@@ -18,7 +18,10 @@ import numpy as np
import pytest

import mindspore.nn as nn
-from mindspore import Tensor
+from mindspore import context, Tensor

+context.set_context(mode=context.PYNATIVE_MODE)

weight = Tensor(np.ones([2, 2]))
conv2 = nn.Conv2d(3, 64, (3, 3), stride=2, padding=0)