forked from mindspore-Ecosystem/mindspore
Add some fallback testcases
This commit is contained in:
parent
d3ad7f8c6a
commit
7738f55f44
|
@ -29,11 +29,13 @@ context.set_context(mode=context.GRAPH_MODE)
|
|||
def add_func(x, y):
|
||||
return x + y
|
||||
|
||||
|
||||
@ms_function
|
||||
def do_increment(i):
|
||||
add_1 = F.partial(add_func, 1)
|
||||
return add_1(i)
|
||||
|
||||
|
||||
def test_increment():
|
||||
a = do_increment(9)
|
||||
assert a == 10
|
||||
|
@ -45,6 +47,7 @@ def use_monad(x, y):
|
|||
res = F.depend(res, monad.U)
|
||||
return res
|
||||
|
||||
|
||||
def test_use_monad():
|
||||
x = Tensor(1.0, mstype.float32)
|
||||
y = Tensor(1.0, mstype.float32)
|
||||
|
@ -62,6 +65,7 @@ class Net(nn.Cell):
|
|||
print(i)
|
||||
return x_len
|
||||
|
||||
|
||||
def test_builtins_len():
|
||||
net = Net()
|
||||
net()
|
||||
|
@ -75,6 +79,7 @@ def np_fallback_func():
|
|||
me_x = me_x + me_x
|
||||
return me_x
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
def test_np_fallback_func():
|
||||
print(np_fallback_func())
|
||||
|
@ -88,6 +93,7 @@ def div_mod_func1():
|
|||
a = divmod(x, y)
|
||||
return Tensor(a)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
def test_div_mod_func1():
|
||||
print(div_mod_func1()) # (2, 2)
|
||||
|
@ -99,6 +105,7 @@ def div_mod_func2(x, y):
|
|||
a = divmod(x, y)
|
||||
return Tensor(a)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
def test_div_mod_func2_scalar():
|
||||
"""
|
||||
|
@ -108,6 +115,7 @@ def test_div_mod_func2_scalar():
|
|||
"""
|
||||
print(div_mod_func2(8, 3)) # (2, 2)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
def test_div_mod_func2_tensor():
|
||||
"""
|
||||
|
@ -129,6 +137,7 @@ def select_func(cond, x, y):
|
|||
output = x
|
||||
return output
|
||||
|
||||
|
||||
def test_select_func():
|
||||
cond = Tensor([True, False])
|
||||
x = Tensor([2, 3], mstype.float32)
|
||||
|
@ -147,6 +156,7 @@ def select_func2(cond, x, y):
|
|||
output = x
|
||||
return output
|
||||
|
||||
|
||||
def test_select_func2():
|
||||
cond = Tensor([True, False])
|
||||
x = Tensor([2, 3], mstype.float32)
|
||||
|
@ -160,7 +170,62 @@ def slice_func(a, b):
|
|||
a[1:3, ::] = b
|
||||
return a
|
||||
|
||||
|
||||
def test_slice_func():
|
||||
a = Tensor(np.arange(60).reshape(3, 4, 5), dtype=mstype.float32)
|
||||
b = Tensor([1], dtype=mstype.float32)
|
||||
print(slice_func(a, b))
|
||||
|
||||
|
||||
@ms_function
|
||||
def np_fallback_func_tensor_index(x):
|
||||
array_x = tuple([2, 3, 4, 5])
|
||||
np_x = np.array(array_x).astype(np.float32)
|
||||
me_x = Tensor(np_x)
|
||||
me_x = me_x + me_x
|
||||
return me_x[x]
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
def test_np_fallback_func_tensor_index():
|
||||
"""
|
||||
Feature: Fallback feature: support Tensor index.
|
||||
Description: Fallback feature: support Tensor index.
|
||||
Expectation: Fallback feature: support Tensor index.
|
||||
"""
|
||||
x = Tensor(1, mstype.int32)
|
||||
output = np_fallback_func_tensor_index(x)
|
||||
output_expect = Tensor(6, mstype.float32)
|
||||
assert output == output_expect
|
||||
|
||||
|
||||
class ControlNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ControlNet, self).__init__()
|
||||
|
||||
def inner_function_1(self, a, b):
|
||||
return a + b
|
||||
|
||||
def inner_function_2(self, a, b):
|
||||
return a - b
|
||||
|
||||
def construct(self, x):
|
||||
a = Tensor(np.array(4), mstype.int32)
|
||||
b = Tensor(np.array(5), mstype.int32)
|
||||
if a + b > x:
|
||||
return self.inner_function_1(a, b)
|
||||
return self.inner_function_2(a, b)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
def test_fallback_control_sink_tensor():
|
||||
"""
|
||||
Feature: Fallback feature: support define Tensor in Class construct.
|
||||
Description: Fallback feature: support define Tensor in Class construct.
|
||||
Expectation: Fallback feature: support define Tensor in Class construct.
|
||||
"""
|
||||
x = Tensor(np.array(1), mstype.int32)
|
||||
net = ControlNet()
|
||||
output = net(x)
|
||||
output_expect = Tensor(9, mstype.int32)
|
||||
assert output == output_expect
|
||||
|
|
Loading…
Reference in New Issue