!40045 Support the type of the input of built-in function max and min is tensor.

Merge pull request !40045 from Margaret_wangrui/max_min_tensor
This commit is contained in:
i-robot 2022-08-11 03:14:50 +00:00 committed by Gitee
commit 7143a5dbc5
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
8 changed files with 196 additions and 42 deletions

View File

@ -107,7 +107,7 @@ _unsupported_internal_type = (
)
_hybrid_type = (
print, len, enumerate, zip, map, filter, abs, all, any, round,
print, len, enumerate, zip, map, filter, abs, all, any, round, max, min,
)
# Unsupported python builtin type in JIT Fallback.

View File

@ -1,6 +1,6 @@
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -141,6 +141,8 @@ convert_object_map = {
T.zip: C.zip_operation,
T.enumerate: M.enumerate_,
T.isinstance: Primitive('isinstance'),
T.max: M.ms_max,
T.min: M.ms_min,
# custom define operation
T.iter: M.ms_iter,

View File

@ -1776,6 +1776,78 @@ def ms_round(*data):
return constant_round(*data)
def max_tensor(*data):
"""Get the max of tensor inputs."""
max_tensor_data = data[0]
for input_data in data:
max_tensor_data = P.Maximum()(max_tensor_data, input_data)
return max_tensor_data
def ms_max(*data):
"""Implementation of `max`."""
len_data = 0
if isinstance(data, (dict, list, str, tuple)):
len_data = len(data)
else:
const_utils.raise_type_error("max() does not support the data type.")
if len_data <= 0:
const_utils.raise_type_error("max() requires 1 argument at least.")
elif len_data == 1:
x = data[0]
if isinstance(x, Tensor):
return x.max()
if isinstance(x, dict):
return max_(x.keys())
return max_(x)
elif len_data >= 2:
tensor_num = 0
for input_data in data:
if isinstance(input_data, Tensor):
tensor_num = tensor_num + 1
if tensor_num == len_data:
return max_tensor(*data)
if tensor_num != 0:
const_utils.raise_type_error("max() cannot contain both tensor and non-tensor type.")
return max_(*data)
def min_tensor(*data):
"""Get the min of tensor inputs."""
min_tensor_data = data[0]
for input_data in data:
min_tensor_data = P.Minimum()(min_tensor_data, input_data)
return min_tensor_data
def ms_min(*data):
"""Implementation of `min`."""
len_data = 0
if isinstance(data, (dict, list, str, tuple)):
len_data = len(data)
else:
const_utils.raise_type_error("min() does not support the data type.")
if len_data <= 0:
const_utils.raise_type_error("min() requires 1 argument at least.")
elif len_data == 1:
x = data[0]
if isinstance(x, Tensor):
return x.min()
if isinstance(x, dict):
return min_(x.keys())
return min_(x)
elif len_data >= 2:
tensor_num = 0
for input_data in data:
if isinstance(input_data, Tensor):
tensor_num = tensor_num + 1
if tensor_num == len_data:
return min_tensor(*data)
if tensor_num != 0:
const_utils.raise_type_error("min() cannot contain both tensor and non-tensor type.")
return min_(*data)
def ms_len(data):
"""Implementation of `len`."""
return data.__len__()

View File

@ -1,6 +1,6 @@
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -29,7 +29,7 @@ from operator import ( # noqa
# support system function call
from builtins import ( # noqa
bool, getattr, setattr, len, iter, next, pow, range, map, zip,
print, enumerate, isinstance, filter, abs, all, any, round,
print, enumerate, isinstance, filter, abs, all, any, round, max, min
)
# support functools
@ -47,7 +47,7 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt',
'matmul', 'getitem', 'setitem',
'bool', 'getattr', 'setattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip',
'partial', 'print', 'enumerate', 'isinstance', 'filter', 'abs', 'all', 'any', 'round',
'exp', 'log', 'sin', 'cos', 'tan']
'exp', 'log', 'sin', 'cos', 'tan', 'max', 'min']
def MakeTuple(*elts): # pragma: no cover

View File

@ -50,3 +50,31 @@ def test_for_after_for_in_if_3():
res = func3303()
assert res == 64
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_after_for_in_if_4():
"""
Feature: JIT Fallback
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
def func3304():
x = Tensor([1])
y = Tensor([2])
if max(x, y) == Tensor([1]) or min(x, y) == Tensor([2]):
return x
z = (Tensor(1), Tensor(2), Tensor(3))
for i in zip(z):
x = x * i
return x
res = func3304()
assert res == 6

View File

@ -13,9 +13,10 @@
# limitations under the License.
# ============================================================================
"""test python built-in functions in graph mode"""
import operator
import pytest
import numpy as np
from mindspore import Tensor, context, nn
from mindspore import Tensor, context, nn, ms_function
from mindspore import dtype as mstype
context.set_context(mode=context.GRAPH_MODE)
@ -177,3 +178,63 @@ def test_fallback_round_tensor_constant():
out = net()
expect = Tensor(np.array([0.0, 5.0, 10.0]))
np.testing.assert_almost_equal(out.asnumpy(), expect.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_builtin_function_max_min_with_tensor():
"""
Feature: Support the type of the input of built-in function max is tensor.
Description: Support the type of the input of built-in function max is tensor.
Expectation: No exception.
"""
@ms_function
def foo(x, y):
return max(x, y), min(x, y)
max_out, min_out = foo(Tensor([1]), Tensor([2]))
assert operator.eq(max_out, 2)
assert operator.eq(min_out, 1)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_builtin_function_max_min_with_multiple_tensor():
"""
Feature: Support the type of the input of built-in function max is tensor.
Description: Support the type of the input of built-in function max is tensor.
Expectation: No exception.
"""
@ms_function
def foo(x, y, z):
return max(x, y, z), min(x, y, z)
max_out, min_out = foo(Tensor([1]), Tensor([2]), Tensor([3]))
assert operator.eq(max_out, 3)
assert operator.eq(min_out, 1)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_builtin_function_max_min_with_tensor_list():
"""
Feature: Support the type of the input of built-in function min is tensor.
Description: Support the type of the input of built-in function min is tensor.
Expectation: No exception.
"""
@ms_function
def foo(x):
return min(x), max(x)
min_out, max_out = foo(Tensor([1, 2, 3, 4, 5], dtype=mstype.float32))
assert operator.eq(min_out, 1)
assert operator.eq(max_out, 5)

View File

@ -67,27 +67,3 @@ def test_for_after_for_in_if_2():
res_x, res_y = func3302()
assert res_x == 4
assert res_y == 4
def test_for_after_for_in_if_4():
"""
Feature: JIT Fallback
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
def func3304():
x = Tensor([1])
y = Tensor([2])
if max(x, y) == Tensor([1]) or min(x, y) == Tensor([2]):
for _ in range(5):
raise TypeError("Not expect to enter this branch")
z = (Tensor(1), Tensor(2), Tensor(3))
for i in zip(z):
x = x * i
return x
res = func3304()
assert res == 6

View File

@ -14,7 +14,6 @@
# ============================================================================
""" test graph fallback buildin python function max and min"""
import operator
import pytest
import numpy as np
from mindspore import ms_function, context, Tensor
@ -85,10 +84,10 @@ def test_fallback_max_with_one_input_dict():
"""
@ms_function
def foo():
x = max({1: 'a', 2: 'b', 3: 'c'})
x = max({'a': 1, 'b': 2, 'c': 3})
return x
out = foo()
assert out == 3
assert out == 'c'
def test_fallback_max_with_one_input_numpy_array():
@ -147,16 +146,32 @@ def test_fallback_min_with_two_inputs_list():
assert operator.eq(out, (1, 2, 3))
def test_fallback_builtin_function_with_non_constance_inputs():
def test_builtin_function_max_min_with_string():
"""
Feature: JIT Fallback
Description: Test builtin function in graph mode with non-constant inputs.
Expectation: Raise value error.
Feature: Support the type of the input of built-in function min is string.
Description: Support the type of the input of built-in function min is string.
Expectation: No exception.
"""
@ms_function
def foo(x, y):
return max(x, y)
def foo():
return max("1, 2, 3, 4"), min("1, 2, 3, 4")
with pytest.raises(ValueError) as ex:
foo(Tensor([1]), Tensor([1]))
assert "the inputs should be constant, but found variable" in str(ex.value)
out_max, out_min = foo()
assert out_max == '4'
assert out_min == ' '
def test_builtin_function_max_min_with_tuple():
"""
Feature: Support the type of the input of built-in function min is tuple.
Description: Support the type of the input of built-in function min is tuple.
Expectation: No exception.
"""
@ms_function
def foo():
x = [('a', 1), ('A', 1), ('a', 2)]
return max(x), min(x)
out_max, out_min = foo()
assert out_max == ('a', 2)
assert out_min == ('A', 1)