diff --git a/tests/st/ops/custom_ops_tbe/cus_square.py b/tests/st/ops/custom_ops_tbe/cus_square.py index 6a9e769f51f..d006f75b4cb 100644 --- a/tests/st/ops/custom_ops_tbe/cus_square.py +++ b/tests/st/ops/custom_ops_tbe/cus_square.py @@ -24,7 +24,7 @@ class CusSquare(PrimitiveWithInfer): def __init__(self): """init CusSquare""" self.init_prim_io_names(inputs=['x'], outputs=['y']) - from .square_impl import CusSquare + from square_impl import CusSquare def vm_impl(self, x): x = x.asnumpy() diff --git a/tests/st/ops/custom_ops_tbe/test_square.py b/tests/st/ops/custom_ops_tbe/test_square.py index c67edae3077..d8439000f85 100644 --- a/tests/st/ops/custom_ops_tbe/test_square.py +++ b/tests/st/ops/custom_ops_tbe/test_square.py @@ -16,7 +16,7 @@ import numpy as np import mindspore.nn as nn import mindspore.context as context from mindspore import Tensor -from .cus_square import CusSquare +from cus_square import CusSquare import pytest context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") @@ -32,6 +32,7 @@ class Net(nn.Cell): @pytest.mark.level0 @pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training @pytest.mark.env_onecard def test_net(): x = np.array([1.0, 4.0, 9.0]).astype(np.float32)