!26825 [Fallback] supports the use of attr/method on the interpreted nodes

Merge pull request !26825 from huangbingjian/support_numpy_method
This commit is contained in:
i-robot 2021-11-27 03:07:24 +00:00 committed by Gitee
commit 6f95837c0e
3 changed files with 94 additions and 31 deletions

View File

@ -844,7 +844,7 @@ AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::objec
// Create the apply node
auto attr_cnode = block->func_graph()->NewCNodeInOrder({op_node, value_node, attr_node});
if (value_node->interpret()) {
if (value_node->interpret() || IsPrimitiveCNode(value_node, prim::kPrimPyInterpret)) {
attr_cnode->set_interpret(true);
}
return attr_cnode;

View File

@ -81,3 +81,56 @@ def test_np_tensor_list():
tensor_list = np_tensor_list()
print("tensor_list:", tensor_list)
assert len(tensor_list) == 3
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_count():
    """
    Feature: Fallback feature
    Description: support attr/method of builtin type: list.count() resolved
        via JIT fallback inside a compiled graph.
    Expectation: No exception.
    """
    # NOTE(review): the inner function's source is re-parsed by the MindSpore
    # graph compiler, so the exact statements (list(...) then .count(...)) are
    # the fixture under test — do not simplify them.
    @ms_function
    def list_count():
        x = list([1, 2, 3])
        res = x.count(1)
        return res
    # 1 occurs once in [1, 2, 3].
    assert list_count() == 1
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_append():
    """
    Feature: Fallback feature
    Description: support attr/method of builtin type: in-place list.append()
        on an interpreted list, then conversion to Tensor, inside a compiled
        graph.
    Expectation: No exception.
    """
    # NOTE(review): the inner function's source is re-parsed by the MindSpore
    # graph compiler; the mutation via x.append(4) is the behavior under test.
    @ms_function
    def list_append():
        x = list([1, 2, 3])
        x.append(4)
        return Tensor(x)
    # Compare element-wise against the expected appended list.
    assert np.all(list_append().asnumpy() == np.array([1, 2, 3, 4]))
# Skipped: list.insert() is not yet handled by the graph fallback feature
# (see skip reason below); keep the fixture so it can be enabled later.
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_list_insert():
    """
    Feature: Fallback feature
    Description: support attr/method of builtin type: in-place list.insert()
        on an interpreted list inside a compiled graph.
    Expectation: No exception.
    """
    @ms_function
    def list_insert():
        x = list([1, 3, 4])
        x.insert(1, 2)
        return Tensor(x)
    # insert(1, 2) places 2 at index 1, yielding [1, 2, 3, 4].
    assert np.all(list_insert().asnumpy() == np.array([1, 2, 3, 4]))

View File

@ -20,7 +20,11 @@ from mindspore import ms_function, context, Tensor
context.set_context(mode=context.GRAPH_MODE)
@pytest.mark.skip(reason='Not support graph fallback feature yet')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_linspace():
"""
Feature: JIT Fallback
@ -33,16 +37,14 @@ def test_np_linspace():
b = Tensor(np.linspace(1, 1, 10))
c = Tensor(np.linspace(10, 20, 5, endpoint=False))
d = Tensor(np.linspace(10, 20, 5, endpoint=True))
e = Tensor(np.linspace(1, 10, 10, retstep=True))
f = Tensor(np.linspace(1, 10, 10).reshape([10, 1]))
return a, b, c, d, e, f
a, b, c, d, e, f = np_linspace()
e = Tensor(np.linspace(1, 10, 10).reshape([10, 1]))
return a, b, c, d, e
a, b, c, d, e = np_linspace()
print("a:", a)
print("b:", b)
print("c:", c)
print("d:", d)
print("e:", e)
print("f:", f)
@pytest.mark.level0
@ -130,6 +132,7 @@ def test_np_array_advanced_index_1():
assert np.all(e.asnumpy() == np.array([[1, 2], [4, 5], [7, 8], [10, 11]]))
# Not support <class 'complex'> yet.
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_np_array_advanced_index_2():
"""
@ -178,7 +181,11 @@ def test_np_array_advanced_index_3():
print("c:", c)
@pytest.mark.skip(reason='Not support graph fallback feature yet')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_reshape():
"""
Feature: JIT Fallback
@ -193,24 +200,11 @@ def test_np_reshape():
assert np.all(np_reshape().asnumpy() == np.array([[0, 1, 2, 3], [4, 5, 6, 7]]))
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_np_ndarray_flat():
"""
Feature: JIT Fallback
Description: Test numpy.flat() method in graph mode.
Expectation: No exception.
"""
@ms_function
def np_ndarray_flat():
x = np.arange(9).reshape(3, 3)
out = 0
for element in x.flat:
out += element
return out
assert np_ndarray_flat() == 36
@pytest.mark.skip(reason='Not support graph fallback feature yet')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_ndarray_flatten():
"""
Feature: JIT Fallback
@ -225,7 +219,11 @@ def test_np_ndarray_flatten():
assert np.all(np_ndarray_flatten().asnumpy() == np.array([0, 1, 2, 3, 4, 5, 6, 7]))
@pytest.mark.skip(reason='Not support graph fallback feature yet')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_ravel():
"""
Feature: JIT Fallback
@ -259,7 +257,11 @@ def test_np_transpose():
assert np.all(np_transpose().asnumpy() == np.array([0, 1, 2, 3]))
@pytest.mark.skip(reason='Not support graph fallback feature yet')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_rollaxis():
"""
Feature: JIT Fallback
@ -269,13 +271,19 @@ def test_np_rollaxis():
@ms_function
def np_rollaxis():
x = np.arange(8).reshape(2, 2, 2)
tensor_x = Tensor(x)
y = np.rollaxis(x, 2, 0)
return x[1, 1, 0], y[1, 1, 0]
tensor_y = Tensor(y)
return tensor_x[1, 1, 0], tensor_y[1, 1, 0]
x, y = np_rollaxis()
assert x == 6 and y == 5
@pytest.mark.skip(reason='Not support graph fallback feature yet')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_swapaxes():
"""
Feature: JIT Fallback
@ -285,8 +293,10 @@ def test_np_swapaxes():
@ms_function
def np_swapaxes():
x = np.arange(8).reshape(2, 2, 2)
tensor_x = Tensor(x)
y = np.swapaxes(x, 2, 0)
return x[1, 1, 0], y[1, 1, 0]
tensor_y = Tensor(y)
return tensor_x[1, 1, 0], tensor_y[1, 1, 0]
x, y = np_swapaxes()
assert x == 6 and y == 3