support bool str function

This commit is contained in:
liangzhibo 2022-08-26 15:17:46 +08:00
parent c443ef283a
commit 9467cdd520
18 changed files with 417 additions and 11 deletions

View File

@ -62,6 +62,7 @@
"mindspore/mindspore/python/mindspore/ops/functional.py" "unused-wildcard-import" "mindspore/mindspore/python/mindspore/ops/functional.py" "unused-wildcard-import"
"mindspore/mindspore/python/mindspore/_extends/parse/standard_method.py" "redefined-builtin" "mindspore/mindspore/python/mindspore/_extends/parse/standard_method.py" "redefined-builtin"
"mindspore/mindspore/python/mindspore/_extends/parse/standard_method.py" "protected-access" "mindspore/mindspore/python/mindspore/_extends/parse/standard_method.py" "protected-access"
"mindspore/mindspore/python/mindspore/_extends/parse/standard_method.py" "len-as-condition"
"mindspore/mindspore/python/mindspore/common/tensor.py" "redefined-builtin" "mindspore/mindspore/python/mindspore/common/tensor.py" "redefined-builtin"
"mindspore/mindspore/python/mindspore/ops/function/array_func.py" "redefined-builtin" "mindspore/mindspore/python/mindspore/ops/function/array_func.py" "redefined-builtin"
"mindspore/mindspore/python/mindspore/ops/operations/array_ops.py" "redefined-builtin" "mindspore/mindspore/python/mindspore/ops/operations/array_ops.py" "redefined-builtin"

View File

@ -100,7 +100,7 @@ _builtin_function_or_method_type = type(abs)
# Unsupported python builtin type in graph mode. # Unsupported python builtin type in graph mode.
_unsupported_python_builtin_type = ( _unsupported_python_builtin_type = (
set, dict, slice, bool, str, complex, reversed, type, set, dict, slice, complex, reversed, type,
) )
_unsupported_internal_type = ( _unsupported_internal_type = (

View File

@ -182,4 +182,6 @@ convert_class_to_function_map = {
"class 'tuple'": M.tuple_func, "class 'tuple'": M.tuple_func,
"class 'int'": M.int_func, "class 'int'": M.int_func,
"class 'float'": M.float_func, "class 'float'": M.float_func,
"class 'bool'": M.bool_func,
"class 'str'": M.str_func
} }

View File

@ -1870,25 +1870,74 @@ def ms_round(*data):
@constexpr @constexpr
def cast_to_int(data): def cast_to_str(data):
if isinstance(data, Tensor_): return str(data)
data = Tensor(data, internal=True)
return int(data)
def str_func(*data):
"""Implementation of `str`."""
data_len = len(data)
if data_len >= 2:
const_utils.raise_type_error("str() requires 0 or 1 arguments.")
if data_len == 0:
return ''
data = data[0]
if isinstance(data, (CSRTensor, COOTensor, RowTensor)):
const_utils.raise_type_error("str() does not support sparse tensor input.")
if not F.isconstant(data):
const_utils.raise_type_error("str() does not support non-constant input.")
return cast_to_str(data)
@constexpr
def cast_to_bool(data):
return bool(data)
def bool_func(*data):
"""Implementation of `bool`."""
data_len = len(data)
if data_len >= 2:
const_utils.raise_type_error("bool() requires 0 or 1 arguments.")
if data_len == 0:
return False
data = data[0]
if isinstance(data, (CSRTensor, COOTensor, RowTensor)):
const_utils.raise_type_error("bool() does not support sparse tensor input.")
if isinstance(data, (Tensor, Tensor_)):
tensor_shape = F.shape(data)
tensor_shape_len = len(tensor_shape)
if tensor_shape_len == 0 or (tensor_shape_len == 1 and tensor_shape[0] == 1):
return data != 0
const_utils.raise_value_error("The truth value of an array with several elements is ambiguous.")
if not F.isconstant(data):
return len(data) != 0
return cast_to_bool(data)
@constexpr
def cast_to_int(*data):
target = data[0]
if isinstance(target, Tensor_):
target = Tensor(target, internal=True)
if len(data) == 1:
return int(target)
return int(target, data[1])
def int_func(*data): def int_func(*data):
"""Implementation of `int`.""" """Implementation of `int`."""
data_len = len(data) data_len = len(data)
if data_len >= 2: if data_len >= 3:
const_utils.raise_type_error("int() requires 0 or 1 arguments.") const_utils.raise_type_error("int() requires 0, 1 or 2 arguments.")
if data_len == 0: if data_len == 0:
return 0 return 0
data = data[0] target = data[0]
if not F.isconstant(data): if not F.isconstant(target):
const_utils.raise_type_error("int() does not support non-constant input.") const_utils.raise_type_error("int() does not support non-constant input.")
if isinstance(data, (CSRTensor, COOTensor, RowTensor)): if isinstance(target, (CSRTensor, COOTensor, RowTensor)):
const_utils.raise_type_error("int() does not support sparse tensor input.") const_utils.raise_type_error("int() does not support sparse tensor input.")
return cast_to_int(data) return cast_to_int(*data)
@constexpr @constexpr

View File

@ -386,3 +386,22 @@ def test_builtin_function_tuple_with_non_constant_tensor():
assert len(ret) == 2 assert len(ret) == 2
assert np.all(ret[0].asnumpy() == np.array([1, 2, 3])) assert np.all(ret[0].asnumpy() == np.array([1, 2, 3]))
assert np.all(ret[1].asnumpy() == np.array([4, 5, 6])) assert np.all(ret[1].asnumpy() == np.array([4, 5, 6]))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_bool_with_input_tensor_2():
"""
Feature: JIT Fallback
Description: Test bool() in graph mode with tensor input.
Expectation: No exception.
"""
@ms_function
def foo():
x = Tensor([10])
return bool(x)
assert foo()

View File

@ -0,0 +1,187 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test graph fallback buildin python function bool"""
import pytest
import numpy as np
from mindspore import ms_function, context, Tensor
context.set_context(mode=context.GRAPH_MODE)
def test_fallback_bool_with_input_tensor():
"""
Feature: JIT Fallback
Description: Test bool() in graph mode with tensor input.
Expectation: No exception.
"""
@ms_function
def foo(x):
return bool(x)
with pytest.raises(ValueError) as ex:
foo(Tensor([1, 2, 4]))
assert "The truth value of an array with" in str(ex.value)
def test_fallback_bool_with_input_tensor_3():
"""
Feature: JIT Fallback
Description: Test bool() in graph mode with tensor input.
Expectation: No exception.
"""
@ms_function
def foo():
x = Tensor([0])
return bool(x)
assert not foo()
def test_fallback_bool_with_input_tensor_4():
"""
Feature: JIT Fallback
Description: Test bool() in graph mode with tensor input.
Expectation: No exception.
"""
@ms_function
def foo():
x = Tensor([1, 2, 3])
return bool(x)
with pytest.raises(ValueError) as ex:
foo()
assert "The truth value of an array with" in str(ex.value)
def test_fallback_bool_with_input_scalar():
"""
Feature: JIT Fallback
Description: Test bool() in graph mode with scalar input.
Expectation: No exception.
"""
@ms_function
def foo():
return bool(10.0)
assert foo()
def test_fallback_bool_with_input_list():
"""
Feature: JIT Fallback
Description: Test bool() in graph mode with list input.
Expectation: No exception.
"""
@ms_function
def foo():
x = [1, 2, 3]
return bool(x)
assert foo()
def test_fallback_bool_with_input_list_2():
"""
Feature: JIT Fallback
Description: Test bool() in graph mode with list input.
Expectation: No exception.
"""
@ms_function
def foo(a):
x = [1, 2, 3, a]
return bool(x)
assert foo(Tensor([1, 2, 3]))
def test_fallback_bool_with_input_string():
"""
Feature: JIT Fallback
Description: Test bool() in graph mode with string input.
Expectation: No exception.
"""
@ms_function
def foo():
return bool("test")
assert foo()
def test_fallback_bool_with_input_string_2():
"""
Feature: JIT Fallback
Description: Test bool() in graph mode with string input.
Expectation: No exception.
"""
@ms_function
def foo():
return bool("")
assert not foo()
def test_fallback_bool_with_input_numpy():
"""
Feature: JIT Fallback
Description: Test bool() in graph mode with numpy input.
Expectation: No exception.
"""
@ms_function
def foo():
x = np.array([1, 2, 3, 4])
return bool(x)
with pytest.raises(ValueError) as ex:
foo()
assert "The truth value of an array" in str(ex.value)
def test_fallback_bool_with_input_numpy_2():
"""
Feature: JIT Fallback
Description: Test bool() in graph mode with numpy input.
Expectation: No exception.
"""
@ms_function
def foo():
x = np.array([1,])
return bool(x)
assert foo()
def test_fallback_bool_with_no_input():
"""
Feature: JIT Fallback
Description: Test bool() in graph mode with no input.
Expectation: No exception.
"""
@ms_function
def foo():
return bool()
assert not foo()
def test_fallback_bool_with_type_input():
"""
Feature: JIT Fallback
Description: Test bool() in graph mode with type input.
Expectation: No exception.
"""
@ms_function
def foo():
return bool(int)
assert foo()

View File

@ -173,6 +173,26 @@ def test_fallback_int_with_no_input():
assert ret == 0 assert ret == 0
def test_fallback_int_with_base_input():
"""
Feature: JIT Fallback
Description: Test int() in graph mode with string input.
Expectation: No exception.
"""
@ms_function
def foo():
x1 = int('12', 16)
x2 = int('0xa', 16)
x3 = int('10', 8)
return x1, x2, x3
ret = foo()
assert len(ret) == 3
assert ret[0] == 18
assert ret[1] == 10
assert ret[2] == 8
def test_fallback_float_with_input_tensor(): def test_fallback_float_with_input_tensor():
""" """
Feature: JIT Fallback Feature: JIT Fallback

View File

@ -0,0 +1,128 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test graph fallback buildin python function str"""
import pytest
import numpy as np
from mindspore import ms_function, context, Tensor
context.set_context(mode=context.GRAPH_MODE)
def test_fallback_str_with_input_tensor():
"""
Feature: JIT Fallback
Description: Test str() in graph mode with tensor input.
Expectation: No exception.
"""
@ms_function
def foo(x):
return str(x)
with pytest.raises(TypeError) as ex:
foo(Tensor([1, 2, 4]))
assert "str() does not support non-constant input." in str(ex.value)
def test_fallback_str_with_input_tensor_2():
"""
Feature: JIT Fallback
Description: Test str() in graph mode with tensor input.
Expectation: No exception.
"""
@ms_function
def foo():
x = Tensor([10])
return str(x)
assert foo() == "Tensor(shape=[1], dtype=Int64, value=[10])"
def test_fallback_str_with_input_scalar():
"""
Feature: JIT Fallback
Description: Test str() in graph mode with scalar input.
Expectation: No exception.
"""
@ms_function
def foo():
return str(10.0)
assert foo() == "10.0"
def test_fallback_str_with_input_list():
"""
Feature: JIT Fallback
Description: Test str() in graph mode with list input.
Expectation: No exception.
"""
@ms_function
def foo():
x = [1, 2, 3]
return str(x)
assert foo() == "[1, 2, 3]"
def test_fallback_str_with_input_string():
"""
Feature: JIT Fallback
Description: Test str() in graph mode with string input.
Expectation: No exception.
"""
@ms_function
def foo():
return str("test")
assert foo() == "test"
def test_fallback_str_with_input_numpy():
"""
Feature: JIT Fallback
Description: Test str() in graph mode with numpy input.
Expectation: No exception.
"""
@ms_function
def foo():
x = np.array([1, 2, 3])
return str(x)
assert foo() == "[1 2 3]"
def test_fallback_str_with_no_input():
"""
Feature: JIT Fallback
Description: Test str() in graph mode with no input.
Expectation: No exception.
"""
@ms_function
def foo():
return str()
assert foo() == ""
def test_fallback_str_with_type_input():
"""
Feature: JIT Fallback
Description: Test str() in graph mode with type input.
Expectation: No exception.
"""
@ms_function
def foo():
return str(int)
assert foo() == "<class 'int'>"