[ME][Fallback] Add tensor testcase

This commit is contained in:
Margaret_wangrui 2021-12-10 12:00:42 +08:00
parent 31b17f273d
commit 2779760409
3 changed files with 97 additions and 50 deletions

View File

@ -902,7 +902,9 @@ AnfNodePtr Parser::ParseCompare(const FunctionBlockPtr &block, const py::object
MS_LOG(EXCEPTION) << "Comparators can't be empty.";
}
AnfNodePtr left_node = ParseExprNode(block, left);
left_node = HandleInterpret(block, left_node, left);
AnfNodePtr right_node = ParseExprNode(block, comparators[0]);
right_node = HandleInterpret(block, right_node, comparators[0]);
MS_EXCEPTION_IF_NULL(block);
AnfNodePtr op_node = block->MakeResolveAstOp(ops[0]);
@ -1108,6 +1110,7 @@ AnfNodePtr Parser::ParseUnaryOp(const FunctionBlockPtr &block, const py::object
py::object operand = python_adapter::GetPyObjAttr(node, "operand");
AnfNodePtr operand_node = ParseExprNode(block, operand);
operand_node = HandleInterpret(block, operand_node, operand);
return block->func_graph()->NewCNodeInOrder({op_node, operand_node});
}

View File

@ -266,3 +266,97 @@ def test_binop_new_tensor():
net = BinOpNet()
print(net())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_tensor_compare():
"""
Feature: Fallback feature
Description: support compare op's interpreted nodes.
Expectation: No exception.
"""
class CompareNet(nn.Cell):
def __init__(self):
super(CompareNet, self).__init__()
def construct(self):
np_array_1 = np.array(1)
np_array_2 = np.array(2)
res = Tensor(np_array_1) < Tensor(np_array_2)
return res
compare_net = CompareNet()
print(compare_net())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_tensor_not():
"""
Feature: Fallback feature
Description: support bool op's interpreted nodes.
Expectation: No exception.
"""
class NotNet(nn.Cell):
def __init__(self):
super(NotNet, self).__init__()
def construct(self):
np_array_1 = np.array(True, dtype=np.bool_)
res = not Tensor(np_array_1)
return res
net = NotNet()
res = net()
print("res:", res)
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_fallback_tensor_and():
"""
Feature: Fallback feature
Description: support bool op's interpreted nodes.
Expectation: No exception.
"""
class AndNet(nn.Cell):
def __init__(self):
super(AndNet, self).__init__()
def construct(self):
np_array_1 = np.array(True, dtype=np.bool_)
np_array_2 = np.array(False, dtype=np.bool_)
res = Tensor(np_array_1) and Tensor(np_array_2)
return res
net = AndNet()
res = net()
print("res:", res)
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_fallback_tensor_or():
"""
Feature: Fallback feature
Description: support bool op's interpreted nodes.
Expectation: No exception.
"""
class OrNet(nn.Cell):
def __init__(self):
super(OrNet, self).__init__()
def construct(self):
np_array_1 = np.array(True, dtype=np.bool_)
np_array_2 = np.array(False, dtype=np.bool_)
res = Tensor(np_array_1) or Tensor(np_array_2)
return res
net = OrNet()
res = net()
print("res:", res)

View File

@ -65,7 +65,6 @@ def test_np_array_3():
assert np.all(res.asnumpy() == expect_res.asnumpy())
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_np_array_4():
"""
Feature: JIT Fallback
@ -108,37 +107,6 @@ def test_np_dtype_2():
assert np.all(res.asnumpy() == Tensor(np.array([1, 2, 3], dtype=np.int32)).asnumpy())
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_np_dtype_3():
"""
Feature: JIT Fallback
Description: Test numpy with dtype in graph mode.
Expectation: No exception.
"""
@ms_function
def np_dtype_3():
t = np.dtype([('age', np.int8)])
return Tensor(np.array([1, 2, 3], dtype=t))
res = np_dtype_3()
print("res:", res)
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_np_dtype_4():
"""
Feature: JIT Fallback
Description: Test numpy with dtype in graph mode.
Expectation: No exception.
"""
@ms_function
def np_dtype_4():
student = np.dtype([('name', 'S20'), ('age', 'i1'), ('marks', 'f4')])
a = np.array([('abc', 21, 50), ('xyz', 18, 75)], dtype=student)
return Tensor(a)
res = np_dtype_4()
print("res:", res)
def test_np_array_ndim():
"""
Feature: JIT Fallback
@ -153,7 +121,6 @@ def test_np_array_ndim():
assert res == 1
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_np_array_reshape_1():
"""
Feature: JIT Fallback
@ -185,7 +152,6 @@ def test_np_array_reshape_2():
print("res:", res)
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_np_array_itemsize():
"""
Feature: JIT Fallback
@ -265,22 +231,6 @@ def test_np_asarray_tuple():
assert np.all(res.asnumpy() == except_res.asnumpy())
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_np_asarray_tuple_list():
"""
Feature: JIT Fallback
Description: Test numpy with tuple list to array in graph mode.
Expectation: No exception.
"""
@ms_function
def np_asarray_tuple_list():
x = [(1, 2, 3), (4, 5)]
y = np.asarray(x)
return Tensor(y)
res = np_asarray_tuple_list()
print("res:", res)
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_np_frombuffer():
"""