!40045 Support tensor inputs for the built-in functions max and min.
Merge pull request !40045 from Margaret_wangrui/max_min_tensor
This commit is contained in:
commit
7143a5dbc5
|
@ -107,7 +107,7 @@ _unsupported_internal_type = (
|
|||
)
|
||||
|
||||
_hybrid_type = (
|
||||
print, len, enumerate, zip, map, filter, abs, all, any, round,
|
||||
print, len, enumerate, zip, map, filter, abs, all, any, round, max, min,
|
||||
)
|
||||
|
||||
# Unsupported python builtin type in JIT Fallback.
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
#
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -141,6 +141,8 @@ convert_object_map = {
|
|||
T.zip: C.zip_operation,
|
||||
T.enumerate: M.enumerate_,
|
||||
T.isinstance: Primitive('isinstance'),
|
||||
T.max: M.ms_max,
|
||||
T.min: M.ms_min,
|
||||
|
||||
# custom define operation
|
||||
T.iter: M.ms_iter,
|
||||
|
|
|
@ -1776,6 +1776,78 @@ def ms_round(*data):
|
|||
return constant_round(*data)
|
||||
|
||||
|
||||
def max_tensor(*data):
    """Get the elementwise max of tensor inputs.

    Args:
        data: One or more tensors (callers guarantee at least one, all Tensor).

    Returns:
        Tensor, the elementwise maximum across all inputs.
    """
    # Hoist the primitive: construct P.Maximum() once instead of once per
    # loop iteration.
    maximum = P.Maximum()
    max_tensor_data = data[0]
    # Start from the second input; the original compared data[0] with itself,
    # which is a no-op since max(a, a) == a, so the result is unchanged.
    for input_data in data[1:]:
        max_tensor_data = maximum(max_tensor_data, input_data)
    return max_tensor_data
|
||||
|
||||
|
||||
def ms_max(*data):
    """Implementation of `max`.

    Mirrors Python's built-in max() inside graph mode: a single iterable
    (or tensor) argument is reduced, multiple tensor arguments are combined
    elementwise, and mixing tensor and non-tensor arguments is rejected.
    """
    # NOTE: `data` is the packed varargs tuple, so for a normal call the
    # isinstance check always succeeds; the else branch is a defensive guard.
    len_data = 0
    if isinstance(data, (dict, list, str, tuple)):
        len_data = len(data)
    else:
        const_utils.raise_type_error("max() does not support the data type.")
    if len_data <= 0:
        const_utils.raise_type_error("max() requires 1 argument at least.")
    elif len_data == 1:
        # Single argument: reduce a tensor, or treat it as an iterable.
        x = data[0]
        if isinstance(x, Tensor):
            return x.max()
        if isinstance(x, dict):
            # Python's max(dict) compares the keys, not the values.
            return max_(x.keys())
        return max_(x)
    elif len_data >= 2:
        # Multiple arguments: count how many are tensors to decide between
        # the elementwise tensor path and the generic max_ fallback.
        tensor_num = 0
        for input_data in data:
            if isinstance(input_data, Tensor):
                tensor_num = tensor_num + 1
        if tensor_num == len_data:
            # All inputs are tensors: elementwise maximum.
            return max_tensor(*data)
        if tensor_num != 0:
            # Mixed tensor / non-tensor inputs are ambiguous — reject them.
            const_utils.raise_type_error("max() cannot contain both tensor and non-tensor type.")
        return max_(*data)
|
||||
|
||||
|
||||
def min_tensor(*data):
    """Get the elementwise min of tensor inputs.

    Args:
        data: One or more tensors (callers guarantee at least one, all Tensor).

    Returns:
        Tensor, the elementwise minimum across all inputs.
    """
    # Hoist the primitive: construct P.Minimum() once instead of once per
    # loop iteration.
    minimum = P.Minimum()
    min_tensor_data = data[0]
    # Start from the second input; the original compared data[0] with itself,
    # which is a no-op since min(a, a) == a, so the result is unchanged.
    for input_data in data[1:]:
        min_tensor_data = minimum(min_tensor_data, input_data)
    return min_tensor_data
|
||||
|
||||
|
||||
def ms_min(*data):
    """Implementation of `min`.

    Mirrors Python's built-in min() inside graph mode: a single iterable
    (or tensor) argument is reduced, multiple tensor arguments are combined
    elementwise, and mixing tensor and non-tensor arguments is rejected.
    """
    # NOTE: `data` is the packed varargs tuple, so for a normal call the
    # isinstance check always succeeds; the else branch is a defensive guard.
    len_data = 0
    if isinstance(data, (dict, list, str, tuple)):
        len_data = len(data)
    else:
        const_utils.raise_type_error("min() does not support the data type.")
    if len_data <= 0:
        const_utils.raise_type_error("min() requires 1 argument at least.")
    elif len_data == 1:
        # Single argument: reduce a tensor, or treat it as an iterable.
        x = data[0]
        if isinstance(x, Tensor):
            return x.min()
        if isinstance(x, dict):
            # Python's min(dict) compares the keys, not the values.
            return min_(x.keys())
        return min_(x)
    elif len_data >= 2:
        # Multiple arguments: count how many are tensors to decide between
        # the elementwise tensor path and the generic min_ fallback.
        tensor_num = 0
        for input_data in data:
            if isinstance(input_data, Tensor):
                tensor_num = tensor_num + 1
        if tensor_num == len_data:
            # All inputs are tensors: elementwise minimum.
            return min_tensor(*data)
        if tensor_num != 0:
            # Mixed tensor / non-tensor inputs are ambiguous — reject them.
            const_utils.raise_type_error("min() cannot contain both tensor and non-tensor type.")
        return min_(*data)
|
||||
|
||||
|
||||
def ms_len(data):
    """Implementation of `len`."""
    # Dispatch to __len__ directly. NOTE(review): presumably the builtin
    # len() is mapped to this very function in the convert map, so calling
    # len(data) here could recurse — confirm before changing.
    return data.__len__()
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
#
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -29,7 +29,7 @@ from operator import ( # noqa
|
|||
# support system function call
|
||||
from builtins import ( # noqa
|
||||
bool, getattr, setattr, len, iter, next, pow, range, map, zip,
|
||||
print, enumerate, isinstance, filter, abs, all, any, round,
|
||||
print, enumerate, isinstance, filter, abs, all, any, round, max, min
|
||||
)
|
||||
|
||||
# support functools
|
||||
|
@ -47,7 +47,7 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt',
|
|||
'matmul', 'getitem', 'setitem',
|
||||
'bool', 'getattr', 'setattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip',
|
||||
'partial', 'print', 'enumerate', 'isinstance', 'filter', 'abs', 'all', 'any', 'round',
|
||||
'exp', 'log', 'sin', 'cos', 'tan']
|
||||
'exp', 'log', 'sin', 'cos', 'tan', 'max', 'min']
|
||||
|
||||
|
||||
def MakeTuple(*elts): # pragma: no cover
|
||||
|
|
|
@ -50,3 +50,31 @@ def test_for_after_for_in_if_3():
|
|||
|
||||
res = func3303()
|
||||
assert res == 64
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_after_for_in_if_4():
    """
    Feature: JIT Fallback
    Description: Test fallback with control flow.
    Expectation: No exception.
    """

    @ms_function
    def func3304():
        x = Tensor([1])
        y = Tensor([2])
        # max(x, y) is Tensor([2]) and min(x, y) is Tensor([1]), so both
        # comparisons are false and this branch is never taken.
        if max(x, y) == Tensor([1]) or min(x, y) == Tensor([2]):
            return x

        # Falls through here: x = 1 * 1 * 2 * 3 = 6.
        z = (Tensor(1), Tensor(2), Tensor(3))
        for i in zip(z):
            x = x * i
        return x

    res = func3304()
    assert res == 6
|
||||
|
|
|
@ -13,9 +13,10 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test python built-in functions in graph mode"""
|
||||
import operator
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, context, nn
|
||||
from mindspore import Tensor, context, nn, ms_function
|
||||
from mindspore import dtype as mstype
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
@ -177,3 +178,63 @@ def test_fallback_round_tensor_constant():
|
|||
out = net()
|
||||
expect = Tensor(np.array([0.0, 5.0, 10.0]))
|
||||
np.testing.assert_almost_equal(out.asnumpy(), expect.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_builtin_function_max_min_with_tensor():
    """
    Feature: Support the type of the input of built-in function max is tensor.
    Description: Support the type of the input of built-in function max is tensor.
    Expectation: No exception.
    """
    # Two tensor arguments: max/min should dispatch to the elementwise path.
    @ms_function
    def max_min_of_pair(first, second):
        return max(first, second), min(first, second)

    max_out, min_out = max_min_of_pair(Tensor([1]), Tensor([2]))
    assert operator.eq(min_out, 1)
    assert operator.eq(max_out, 2)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_builtin_function_max_min_with_multiple_tensor():
    """
    Feature: Support the type of the input of built-in function max is tensor.
    Description: Support the type of the input of built-in function max is tensor.
    Expectation: No exception.
    """
    # Three tensor arguments: the elementwise reduction must cover them all.
    @ms_function
    def max_min_of_triple(a, b, c):
        return max(a, b, c), min(a, b, c)

    max_out, min_out = max_min_of_triple(Tensor([1]), Tensor([2]), Tensor([3]))
    assert operator.eq(min_out, 1)
    assert operator.eq(max_out, 3)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_builtin_function_max_min_with_tensor_list():
    """
    Feature: Support the type of the input of built-in function min is tensor.
    Description: Support the type of the input of built-in function min is tensor.
    Expectation: No exception.
    """
    # A single tensor argument: max/min reduce over its elements.
    @ms_function
    def reduce_min_max(values):
        return min(values), max(values)

    min_out, max_out = reduce_min_max(Tensor([1, 2, 3, 4, 5], dtype=mstype.float32))
    assert operator.eq(max_out, 5)
    assert operator.eq(min_out, 1)
|
||||
|
|
|
@ -67,27 +67,3 @@ def test_for_after_for_in_if_2():
|
|||
res_x, res_y = func3302()
|
||||
assert res_x == 4
|
||||
assert res_y == 4
|
||||
|
||||
|
||||
def test_for_after_for_in_if_4():
    """
    Feature: JIT Fallback
    Description: Test fallback with control flow.
    Expectation: No exception.
    """

    @ms_function
    def func3304():
        x = Tensor([1])
        y = Tensor([2])
        # max(x, y) is Tensor([2]) and min(x, y) is Tensor([1]), so both
        # comparisons are false and the raising branch is never entered.
        if max(x, y) == Tensor([1]) or min(x, y) == Tensor([2]):
            for _ in range(5):
                raise TypeError("Not expect to enter this branch")

        # Falls through here: x = 1 * 1 * 2 * 3 = 6.
        z = (Tensor(1), Tensor(2), Tensor(3))
        for i in zip(z):
            x = x * i
        return x

    res = func3304()
    assert res == 6
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
# ============================================================================
|
||||
""" test graph fallback buildin python function max and min"""
|
||||
import operator
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import ms_function, context, Tensor
|
||||
|
||||
|
@ -85,10 +84,10 @@ def test_fallback_max_with_one_input_dict():
|
|||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = max({1: 'a', 2: 'b', 3: 'c'})
|
||||
x = max({'a': 1, 'b': 2, 'c': 3})
|
||||
return x
|
||||
out = foo()
|
||||
assert out == 3
|
||||
assert out == 'c'
|
||||
|
||||
|
||||
def test_fallback_max_with_one_input_numpy_array():
|
||||
|
@ -147,16 +146,32 @@ def test_fallback_min_with_two_inputs_list():
|
|||
assert operator.eq(out, (1, 2, 3))
|
||||
|
||||
|
||||
def test_fallback_builtin_function_with_non_constance_inputs():
|
||||
def test_builtin_function_max_min_with_string():
    """
    Feature: Support the type of the input of built-in function min is string.
    Description: Support the type of the input of built-in function min is string.
    Expectation: No exception.
    """
    # max/min over a string compare its individual characters.
    @ms_function
    def string_max_min():
        return max("1, 2, 3, 4"), min("1, 2, 3, 4")

    out_max, out_min = string_max_min()
    # '4' is the largest character; the space sorts below ',' and the digits.
    assert out_max == '4'
    assert out_min == ' '
|
||||
|
||||
|
||||
def test_builtin_function_max_min_with_tuple():
    """
    Feature: Support the type of the input of built-in function min is tuple.
    Description: Support the type of the input of built-in function min is tuple.
    Expectation: No exception.
    """
    # Tuples compare lexicographically, so ('A', 1) < ('a', 1) < ('a', 2).
    @ms_function
    def tuple_max_min():
        pairs = [('a', 1), ('A', 1), ('a', 2)]
        return max(pairs), min(pairs)

    out_max, out_min = tuple_max_min()
    assert out_min == ('A', 1)
    assert out_max == ('a', 2)
|
||||
|
|
Loading…
Reference in New Issue