forked from mindspore-Ecosystem/mindspore
!31177 Add ut for fallback python buildin function with tensor input
Merge pull request !31177 from LiangZhibo/syntax
This commit is contained in:
commit
603ff25afb
|
@ -98,6 +98,52 @@ def test_fallback_list_with_input_number():
|
|||
assert "object is not iterable" in str(ex.value)
|
||||
|
||||
|
||||
def test_fallback_list_with_input_constant_tensor():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test list() in graph mode with constant tensor.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = list(Tensor([1, 2, 3]))
|
||||
x.append(Tensor([4]))
|
||||
return x
|
||||
out = foo()
|
||||
assert isinstance(out, tuple)
|
||||
assert len(out) == 4
|
||||
assert isinstance(out[0], Tensor)
|
||||
assert out[0].asnumpy() == 1
|
||||
assert isinstance(out[1], Tensor)
|
||||
assert out[1].asnumpy() == 2
|
||||
assert isinstance(out[2], Tensor)
|
||||
assert out[2].asnumpy() == 3
|
||||
assert isinstance(out[3], Tensor)
|
||||
assert out[3].asnumpy() == 4
|
||||
|
||||
|
||||
def test_fallback_list_with_input_constant_tensor_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test list() in graph mode with constant tensor.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = list(Tensor([[1, 2], [3, 4]]))
|
||||
x.append(Tensor([5, 6]))
|
||||
return x
|
||||
out = foo()
|
||||
assert isinstance(out, tuple)
|
||||
assert len(out) == 3
|
||||
assert isinstance(out[0], Tensor)
|
||||
assert np.allclose(out[0].asnumpy(), np.array([1, 2]))
|
||||
assert isinstance(out[1], Tensor)
|
||||
assert np.allclose(out[1].asnumpy(), np.array([3, 4]))
|
||||
assert isinstance(out[2], Tensor)
|
||||
assert np.allclose(out[2].asnumpy(), np.array([5, 6]))
|
||||
|
||||
|
||||
def test_fallback_tuple_with_input_list():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -157,6 +203,46 @@ def test_fallback_tuple_with_input_numpy_array():
|
|||
assert np.allclose(np.array([1, 2, 3]), out.asnumpy())
|
||||
|
||||
|
||||
def test_fallback_tuple_with_input_constant_tensor():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test tuple() in graph mode with constant tensor.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = tuple(Tensor([1, 2, 3]))
|
||||
return x
|
||||
out = foo()
|
||||
assert isinstance(out, tuple)
|
||||
assert len(out) == 3
|
||||
assert isinstance(out[0], Tensor)
|
||||
assert out[0].asnumpy() == 1
|
||||
assert isinstance(out[1], Tensor)
|
||||
assert out[1].asnumpy() == 2
|
||||
assert isinstance(out[2], Tensor)
|
||||
assert out[2].asnumpy() == 3
|
||||
|
||||
|
||||
def test_fallback_tuple_with_input_constant_tensor_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test tuple() in graph mode with constant tensor.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = list(Tensor([[1, 2], [3, 4]]))
|
||||
return x
|
||||
out = foo()
|
||||
assert isinstance(out, tuple)
|
||||
assert len(out) == 2
|
||||
assert isinstance(out[0], Tensor)
|
||||
assert np.allclose(out[0].asnumpy(), np.array([1, 2]))
|
||||
assert isinstance(out[1], Tensor)
|
||||
assert np.allclose(out[1].asnumpy(), np.array([3, 4]))
|
||||
|
||||
|
||||
def test_fallback_tuple_with_input_number():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
|
|
@ -103,6 +103,20 @@ def test_fallback_max_with_one_input_numpy_array():
|
|||
assert out == 3
|
||||
|
||||
|
||||
def test_fallback_max_with_one_input_tensor():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test max() in graph mode with one input tensor.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = max(Tensor([1, 2, 3]))
|
||||
return x
|
||||
out = foo()
|
||||
assert out == 3
|
||||
|
||||
|
||||
def test_fallback_max_with_two_inputs_list():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
|
|
@ -12,13 +12,14 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test graph fallback buildin python function round"""
|
||||
""" test graph fallback buildin python function sum"""
|
||||
import pytest
|
||||
from mindspore import ms_function, context
|
||||
import numpy as np
|
||||
from mindspore import ms_function, context, Tensor
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
def test_fallback_round_with_x_list_n_default():
|
||||
def test_fallback_sum_with_x_list_n_default():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test sum() in graph mode with input x list and input n default.
|
||||
|
@ -32,7 +33,7 @@ def test_fallback_round_with_x_list_n_default():
|
|||
assert out == 6
|
||||
|
||||
|
||||
def test_fallback_round_with_x_tuple_n_default():
|
||||
def test_fallback_sum_with_x_tuple_n_default():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test sum() in graph mode with input x tuple and input n default.
|
||||
|
@ -46,7 +47,7 @@ def test_fallback_round_with_x_tuple_n_default():
|
|||
assert out == 6
|
||||
|
||||
|
||||
def test_fallback_round_with_x_dict_n_default():
|
||||
def test_fallback_sum_with_x_dict_n_default():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test sum() in graph mode with input x dict and input n default.
|
||||
|
@ -60,7 +61,63 @@ def test_fallback_round_with_x_dict_n_default():
|
|||
assert out == 6
|
||||
|
||||
|
||||
def test_fallback_round_with_x_list_n_not_default():
|
||||
def test_fallback_sum_with_x_numpy_array_n_default():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test sum() in graph mode with input x numpy array and input n default.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = sum(np.array([1, 2, 3]))
|
||||
return Tensor(x)
|
||||
out = foo()
|
||||
assert out.asnumpy() == 6
|
||||
|
||||
|
||||
def test_fallback_sum_with_x_tensor_n_default():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test sum() in graph mode with input x tensor and input n default.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = sum(Tensor([1, 2, 3]))
|
||||
return x
|
||||
out = foo()
|
||||
assert out.asnumpy() == 6
|
||||
|
||||
|
||||
def test_fallback_sum_with_x_tensor_n_default_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test sum() in graph mode with input x tensor and input n default.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = sum(Tensor([[1, 1], [2, 2]]))
|
||||
return x
|
||||
out = foo()
|
||||
assert np.allclose(out.asnumpy(), np.array([3, 3]))
|
||||
|
||||
|
||||
def test_fallback_sum_with_x_numpy_array_n_default_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test sum() in graph mode with input x numpy array and input n default.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = sum(np.array([[1, 1], [2, 2]]))
|
||||
return Tensor(x)
|
||||
out = foo()
|
||||
assert np.allclose(out.asnumpy(), np.array([3, 3]))
|
||||
|
||||
|
||||
def test_fallback_sum_with_x_list_n_not_default():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test sum() in graph mode with input x list and input n not default.
|
||||
|
@ -74,7 +131,21 @@ def test_fallback_round_with_x_list_n_not_default():
|
|||
assert out == 16
|
||||
|
||||
|
||||
def test_fallback_round_with_x_tuple_n_not_default():
|
||||
def test_fallback_sum_with_x_tensor_n_not_default():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test sum() in graph mode with input x tensor and input n not default.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = sum(Tensor([1, 2, 3]), 10)
|
||||
return x
|
||||
out = foo()
|
||||
assert out == 16
|
||||
|
||||
|
||||
def test_fallback_sum_with_x_tuple_n_not_default():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test sum() in graph mode with input x tuple and input n not default.
|
||||
|
@ -88,7 +159,7 @@ def test_fallback_round_with_x_tuple_n_not_default():
|
|||
assert out == 16
|
||||
|
||||
|
||||
def test_fallback_round_with_x_dict_n_not_default():
|
||||
def test_fallback_sum_with_x_dict_n_not_default():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test sum() in graph mode with input x dict and input n not default.
|
||||
|
@ -102,7 +173,21 @@ def test_fallback_round_with_x_dict_n_not_default():
|
|||
assert out == 16
|
||||
|
||||
|
||||
def test_fallback_round_with_x_not_iterable():
|
||||
def test_fallback_sum_with_x_numpy_array_n_not_default():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test sum() in graph mode with input x numpy array and input n default.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = sum(np.array([[1, 1], [2, 2]]), 5)
|
||||
return Tensor(x)
|
||||
out = foo()
|
||||
assert np.allclose(out.asnumpy(), np.array([8, 8]))
|
||||
|
||||
|
||||
def test_fallback_sum_with_x_not_iterable():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test sum() in graph mode with input x not iterable.
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# ============================================================================
|
||||
""" test graph fallback buildin python function type"""
|
||||
import numpy as np
|
||||
from mindspore import ms_function, context
|
||||
from mindspore import ms_function, context, Tensor
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -100,3 +100,17 @@ def test_fallback_type_with_input_numpy_array():
|
|||
return x
|
||||
out = foo()
|
||||
assert str(out) == "<class 'numpy.ndarray'>"
|
||||
|
||||
|
||||
def test_fallback_type_with_input_tensor():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test type() in graph mode with tensor input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = type(Tensor([1, 2, 3]))
|
||||
return x
|
||||
out = foo()
|
||||
assert str(out) == "<class 'mindspore.common.tensor.Tensor'>"
|
||||
|
|
Loading…
Reference in New Issue