forked from mindspore-Ecosystem/mindspore
!26825 [Fallback] supports the use of attr/method on the interpreted nodes
Merge pull request !26825 from huangbingjian/support_numpy_method
This commit is contained in:
commit
6f95837c0e
|
@ -844,7 +844,7 @@ AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::objec
|
|||
|
||||
// Create the apply node
|
||||
auto attr_cnode = block->func_graph()->NewCNodeInOrder({op_node, value_node, attr_node});
|
||||
if (value_node->interpret()) {
|
||||
if (value_node->interpret() || IsPrimitiveCNode(value_node, prim::kPrimPyInterpret)) {
|
||||
attr_cnode->set_interpret(true);
|
||||
}
|
||||
return attr_cnode;
|
||||
|
|
|
@ -81,3 +81,56 @@ def test_np_tensor_list():
|
|||
tensor_list = np_tensor_list()
|
||||
print("tensor_list:", tensor_list)
|
||||
assert len(tensor_list) == 3
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_count():
|
||||
"""
|
||||
Feature: Fallback feature
|
||||
Description: support attr/method of builtin type.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def list_count():
|
||||
x = list([1, 2, 3])
|
||||
res = x.count(1)
|
||||
return res
|
||||
assert list_count() == 1
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_append():
|
||||
"""
|
||||
Feature: Fallback feature
|
||||
Description: support attr/method of builtin type.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def list_append():
|
||||
x = list([1, 2, 3])
|
||||
x.append(4)
|
||||
return Tensor(x)
|
||||
assert np.all(list_append().asnumpy() == np.array([1, 2, 3, 4]))
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
def test_list_insert():
|
||||
"""
|
||||
Feature: Fallback feature
|
||||
Description: support attr/method of builtin type.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def list_insert():
|
||||
x = list([1, 3, 4])
|
||||
x.insert(1, 2)
|
||||
return Tensor(x)
|
||||
assert np.all(list_insert().asnumpy() == np.array([1, 2, 3, 4]))
|
||||
|
|
|
@ -20,7 +20,11 @@ from mindspore import ms_function, context, Tensor
|
|||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_np_linspace():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -33,16 +37,14 @@ def test_np_linspace():
|
|||
b = Tensor(np.linspace(1, 1, 10))
|
||||
c = Tensor(np.linspace(10, 20, 5, endpoint=False))
|
||||
d = Tensor(np.linspace(10, 20, 5, endpoint=True))
|
||||
e = Tensor(np.linspace(1, 10, 10, retstep=True))
|
||||
f = Tensor(np.linspace(1, 10, 10).reshape([10, 1]))
|
||||
return a, b, c, d, e, f
|
||||
a, b, c, d, e, f = np_linspace()
|
||||
e = Tensor(np.linspace(1, 10, 10).reshape([10, 1]))
|
||||
return a, b, c, d, e
|
||||
a, b, c, d, e = np_linspace()
|
||||
print("a:", a)
|
||||
print("b:", b)
|
||||
print("c:", c)
|
||||
print("d:", d)
|
||||
print("e:", e)
|
||||
print("f:", f)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -130,6 +132,7 @@ def test_np_array_advanced_index_1():
|
|||
assert np.all(e.asnumpy() == np.array([[1, 2], [4, 5], [7, 8], [10, 11]]))
|
||||
|
||||
|
||||
# Not support <class 'complex'> yet.
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
def test_np_array_advanced_index_2():
|
||||
"""
|
||||
|
@ -178,7 +181,11 @@ def test_np_array_advanced_index_3():
|
|||
print("c:", c)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_np_reshape():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -193,24 +200,11 @@ def test_np_reshape():
|
|||
assert np.all(np_reshape().asnumpy() == np.array([[0, 1, 2, 3], [4, 5, 6, 7]]))
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
def test_np_ndarray_flat():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test numpy.flat() method in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def np_ndarray_flat():
|
||||
x = np.arange(9).reshape(3, 3)
|
||||
out = 0
|
||||
for element in x.flat:
|
||||
out += element
|
||||
return out
|
||||
assert np_ndarray_flat() == 36
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_np_ndarray_flatten():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -225,7 +219,11 @@ def test_np_ndarray_flatten():
|
|||
assert np.all(np_ndarray_flatten().asnumpy() == np.array([0, 1, 2, 3, 4, 5, 6, 7]))
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_np_ravel():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -259,7 +257,11 @@ def test_np_transpose():
|
|||
assert np.all(np_transpose().asnumpy() == np.array([0, 1, 2, 3]))
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_np_rollaxis():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -269,13 +271,19 @@ def test_np_rollaxis():
|
|||
@ms_function
|
||||
def np_rollaxis():
|
||||
x = np.arange(8).reshape(2, 2, 2)
|
||||
tensor_x = Tensor(x)
|
||||
y = np.rollaxis(x, 2, 0)
|
||||
return x[1, 1, 0], y[1, 1, 0]
|
||||
tensor_y = Tensor(y)
|
||||
return tensor_x[1, 1, 0], tensor_y[1, 1, 0]
|
||||
x, y = np_rollaxis()
|
||||
assert x == 6 and y == 5
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_np_swapaxes():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -285,8 +293,10 @@ def test_np_swapaxes():
|
|||
@ms_function
|
||||
def np_swapaxes():
|
||||
x = np.arange(8).reshape(2, 2, 2)
|
||||
tensor_x = Tensor(x)
|
||||
y = np.swapaxes(x, 2, 0)
|
||||
return x[1, 1, 0], y[1, 1, 0]
|
||||
tensor_y = Tensor(y)
|
||||
return tensor_x[1, 1, 0], tensor_y[1, 1, 0]
|
||||
x, y = np_swapaxes()
|
||||
assert x == 6 and y == 3
|
||||
|
||||
|
|
Loading…
Reference in New Issue