diff --git a/tests/st/probability/distribution/test_categorical_gpu.py b/tests/st/probability/distribution/test_categorical_gpu.py index c306e7f6339..0ec57bcf4b7 100644 --- a/tests/st/probability/distribution/test_categorical_gpu.py +++ b/tests/st/probability/distribution/test_categorical_gpu.py @@ -14,13 +14,13 @@ # ============================================================================ """test cases for categorical distribution""" +import pytest import numpy as np import mindspore.context as context import mindspore.nn as nn import mindspore.nn.probability.distribution as msd from mindspore import Tensor from mindspore import dtype as ms -import pytest context.set_context(mode=context.GRAPH_MODE, device_target="GPU") diff --git a/tests/st/probability/distribution/test_cauchy_pynative.py b/tests/st/probability/distribution/test_cauchy_pynative.py index 13082d7cd87..24b626c3f76 100644 --- a/tests/st/probability/distribution/test_cauchy_pynative.py +++ b/tests/st/probability/distribution/test_cauchy_pynative.py @@ -14,13 +14,13 @@ # ============================================================================ """test cases for cauchy distribution""" +import pytest import numpy as np import mindspore.context as context import mindspore.nn as nn import mindspore.nn.probability.distribution as msd from mindspore import Tensor from mindspore import dtype as ms -import pytest context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") diff --git a/tests/st/probability/distribution/test_gamma_pynative.py b/tests/st/probability/distribution/test_gamma_pynative.py index fc80249dfea..2339b649ac9 100644 --- a/tests/st/probability/distribution/test_gamma_pynative.py +++ b/tests/st/probability/distribution/test_gamma_pynative.py @@ -14,12 +14,12 @@ # ============================================================================ """test cases for gamma distribution""" +import pytest import numpy as np import mindspore.context as context import mindspore.nn as nn import mindspore.nn.probability.distribution as msd from mindspore import dtype as ms -import pytest context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")