Add some fallback testcases

Margaret_wangrui 2021-10-27 17:07:33 +08:00
parent d3ad7f8c6a
commit 7738f55f44
1 changed file with 65 additions and 0 deletions


@@ -29,11 +29,13 @@ context.set_context(mode=context.GRAPH_MODE)
def add_func(x, y):
    return x + y


@ms_function
def do_increment(i):
    add_1 = F.partial(add_func, 1)
    return add_1(i)


def test_increment():
    a = do_increment(9)
    assert a == 10
@@ -45,6 +47,7 @@ def use_monad(x, y):
    res = F.depend(res, monad.U)
    return res


def test_use_monad():
    x = Tensor(1.0, mstype.float32)
    y = Tensor(1.0, mstype.float32)
@@ -62,6 +65,7 @@ class Net(nn.Cell):
            print(i)
        return x_len


def test_builtins_len():
    net = Net()
    net()
@@ -75,6 +79,7 @@ def np_fallback_func():
    me_x = me_x + me_x
    return me_x


@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_np_fallback_func():
    print(np_fallback_func())
@@ -88,6 +93,7 @@ def div_mod_func1():
    a = divmod(x, y)
    return Tensor(a)


@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_div_mod_func1():
    print(div_mod_func1())  # (2, 2)
@@ -99,6 +105,7 @@ def div_mod_func2(x, y):
    a = divmod(x, y)
    return Tensor(a)


@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_div_mod_func2_scalar():
    """
@@ -108,6 +115,7 @@ def test_div_mod_func2_scalar():
    """
    print(div_mod_func2(8, 3))  # (2, 2)


@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_div_mod_func2_tensor():
    """
@@ -129,6 +137,7 @@ def select_func(cond, x, y):
        output = x
    return output


def test_select_func():
    cond = Tensor([True, False])
    x = Tensor([2, 3], mstype.float32)
@@ -147,6 +156,7 @@ def select_func2(cond, x, y):
        output = x
    return output


def test_select_func2():
    cond = Tensor([True, False])
    x = Tensor([2, 3], mstype.float32)
@@ -160,7 +170,62 @@ def slice_func(a, b):
    a[1:3, ::] = b
    return a


def test_slice_func():
    a = Tensor(np.arange(60).reshape(3, 4, 5), dtype=mstype.float32)
    b = Tensor([1], dtype=mstype.float32)
    print(slice_func(a, b))


@ms_function
def np_fallback_func_tensor_index(x):
    array_x = tuple([2, 3, 4, 5])
    np_x = np.array(array_x).astype(np.float32)
    me_x = Tensor(np_x)
    me_x = me_x + me_x
    return me_x[x]


@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_np_fallback_func_tensor_index():
"""
Feature: Fallback feature: support Tensor index.
Description: Fallback feature: support Tensor index.
Expectation: Fallback feature: support Tensor index.
"""
    x = Tensor(1, mstype.int32)
    output = np_fallback_func_tensor_index(x)
    output_expect = Tensor(6, mstype.float32)
    assert output == output_expect

class ControlNet(nn.Cell):
    def __init__(self):
        super(ControlNet, self).__init__()

    def inner_function_1(self, a, b):
        return a + b

    def inner_function_2(self, a, b):
        return a - b

    def construct(self, x):
        a = Tensor(np.array(4), mstype.int32)
        b = Tensor(np.array(5), mstype.int32)
        if a + b > x:
            return self.inner_function_1(a, b)
        return self.inner_function_2(a, b)


@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_fallback_control_sink_tensor():
"""
Feature: Fallback feature: support define Tensor in Class construct.
Description: Fallback feature: support define Tensor in Class construct.
Expectation: Fallback feature: support define Tensor in Class construct.
"""
    x = Tensor(np.array(1), mstype.int32)
    net = ControlNet()
    output = net(x)
    output_expect = Tensor(9, mstype.int32)
    assert output == output_expect
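
Note: the ControlNet test above is skipped under GRAPH_MODE until graph fallback lands, but the expected value of 9 can be sanity-checked by running the same cell under PYNATIVE_MODE, where construct executes as ordinary Python. A minimal sketch, not part of the commit, assuming the usual MindSpore imports:

import numpy as np
from mindspore import Tensor, nn, context
from mindspore import dtype as mstype

# Assumption: PyNative mode instead of the GRAPH_MODE used by the test module,
# so defining Tensors inside construct works without the fallback feature.
context.set_context(mode=context.PYNATIVE_MODE)

class ControlNet(nn.Cell):
    def inner_function_1(self, a, b):
        return a + b

    def inner_function_2(self, a, b):
        return a - b

    def construct(self, x):
        a = Tensor(np.array(4), mstype.int32)
        b = Tensor(np.array(5), mstype.int32)
        if a + b > x:  # 4 + 5 > 1, so the first branch is taken
            return self.inner_function_1(a, b)
        return self.inner_function_2(a, b)

out = ControlNet()(Tensor(np.array(1), mstype.int32))
print(out)  # expected: 9, matching output_expect in the skipped test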