forked from mindspore-Ecosystem/mindspore
support bool str function
This commit is contained in:
parent
c443ef283a
commit
9467cdd520
|
@ -62,6 +62,7 @@
|
|||
"mindspore/mindspore/python/mindspore/ops/functional.py" "unused-wildcard-import"
|
||||
"mindspore/mindspore/python/mindspore/_extends/parse/standard_method.py" "redefined-builtin"
|
||||
"mindspore/mindspore/python/mindspore/_extends/parse/standard_method.py" "protected-access"
|
||||
"mindspore/mindspore/python/mindspore/_extends/parse/standard_method.py" "len-as-condition"
|
||||
"mindspore/mindspore/python/mindspore/common/tensor.py" "redefined-builtin"
|
||||
"mindspore/mindspore/python/mindspore/ops/function/array_func.py" "redefined-builtin"
|
||||
"mindspore/mindspore/python/mindspore/ops/operations/array_ops.py" "redefined-builtin"
|
||||
|
|
|
@ -100,7 +100,7 @@ _builtin_function_or_method_type = type(abs)
|
|||
|
||||
# Unsupported python builtin type in graph mode.
|
||||
_unsupported_python_builtin_type = (
|
||||
set, dict, slice, bool, str, complex, reversed, type,
|
||||
set, dict, slice, complex, reversed, type,
|
||||
)
|
||||
|
||||
_unsupported_internal_type = (
|
||||
|
|
|
@ -182,4 +182,6 @@ convert_class_to_function_map = {
|
|||
"class 'tuple'": M.tuple_func,
|
||||
"class 'int'": M.int_func,
|
||||
"class 'float'": M.float_func,
|
||||
"class 'bool'": M.bool_func,
|
||||
"class 'str'": M.str_func
|
||||
}
|
||||
|
|
|
@ -1870,25 +1870,74 @@ def ms_round(*data):
|
|||
|
||||
|
||||
@constexpr
|
||||
def cast_to_int(data):
|
||||
if isinstance(data, Tensor_):
|
||||
data = Tensor(data, internal=True)
|
||||
return int(data)
|
||||
def cast_to_str(data):
|
||||
return str(data)
|
||||
|
||||
|
||||
def str_func(*data):
|
||||
"""Implementation of `str`."""
|
||||
data_len = len(data)
|
||||
if data_len >= 2:
|
||||
const_utils.raise_type_error("str() requires 0 or 1 arguments.")
|
||||
if data_len == 0:
|
||||
return ''
|
||||
data = data[0]
|
||||
if isinstance(data, (CSRTensor, COOTensor, RowTensor)):
|
||||
const_utils.raise_type_error("str() does not support sparse tensor input.")
|
||||
if not F.isconstant(data):
|
||||
const_utils.raise_type_error("str() does not support non-constant input.")
|
||||
return cast_to_str(data)
|
||||
|
||||
|
||||
@constexpr
|
||||
def cast_to_bool(data):
|
||||
return bool(data)
|
||||
|
||||
|
||||
def bool_func(*data):
|
||||
"""Implementation of `bool`."""
|
||||
data_len = len(data)
|
||||
if data_len >= 2:
|
||||
const_utils.raise_type_error("bool() requires 0 or 1 arguments.")
|
||||
if data_len == 0:
|
||||
return False
|
||||
data = data[0]
|
||||
if isinstance(data, (CSRTensor, COOTensor, RowTensor)):
|
||||
const_utils.raise_type_error("bool() does not support sparse tensor input.")
|
||||
if isinstance(data, (Tensor, Tensor_)):
|
||||
tensor_shape = F.shape(data)
|
||||
tensor_shape_len = len(tensor_shape)
|
||||
if tensor_shape_len == 0 or (tensor_shape_len == 1 and tensor_shape[0] == 1):
|
||||
return data != 0
|
||||
const_utils.raise_value_error("The truth value of an array with several elements is ambiguous.")
|
||||
if not F.isconstant(data):
|
||||
return len(data) != 0
|
||||
return cast_to_bool(data)
|
||||
|
||||
|
||||
@constexpr
|
||||
def cast_to_int(*data):
|
||||
target = data[0]
|
||||
if isinstance(target, Tensor_):
|
||||
target = Tensor(target, internal=True)
|
||||
if len(data) == 1:
|
||||
return int(target)
|
||||
return int(target, data[1])
|
||||
|
||||
|
||||
def int_func(*data):
|
||||
"""Implementation of `int`."""
|
||||
data_len = len(data)
|
||||
if data_len >= 2:
|
||||
const_utils.raise_type_error("int() requires 0 or 1 arguments.")
|
||||
if data_len >= 3:
|
||||
const_utils.raise_type_error("int() requires 0, 1 or 2 arguments.")
|
||||
if data_len == 0:
|
||||
return 0
|
||||
data = data[0]
|
||||
if not F.isconstant(data):
|
||||
target = data[0]
|
||||
if not F.isconstant(target):
|
||||
const_utils.raise_type_error("int() does not support non-constant input.")
|
||||
if isinstance(data, (CSRTensor, COOTensor, RowTensor)):
|
||||
if isinstance(target, (CSRTensor, COOTensor, RowTensor)):
|
||||
const_utils.raise_type_error("int() does not support sparse tensor input.")
|
||||
return cast_to_int(data)
|
||||
return cast_to_int(*data)
|
||||
|
||||
|
||||
@constexpr
|
||||
|
|
|
@ -386,3 +386,22 @@ def test_builtin_function_tuple_with_non_constant_tensor():
|
|||
assert len(ret) == 2
|
||||
assert np.all(ret[0].asnumpy() == np.array([1, 2, 3]))
|
||||
assert np.all(ret[1].asnumpy() == np.array([4, 5, 6]))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fallback_bool_with_input_tensor_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test bool() in graph mode with tensor input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = Tensor([10])
|
||||
return bool(x)
|
||||
|
||||
assert foo()
|
|
@ -0,0 +1,187 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test graph fallback buildin python function bool"""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import ms_function, context, Tensor
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_fallback_bool_with_input_tensor():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test bool() in graph mode with tensor input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo(x):
|
||||
return bool(x)
|
||||
with pytest.raises(ValueError) as ex:
|
||||
foo(Tensor([1, 2, 4]))
|
||||
assert "The truth value of an array with" in str(ex.value)
|
||||
|
||||
|
||||
def test_fallback_bool_with_input_tensor_3():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test bool() in graph mode with tensor input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = Tensor([0])
|
||||
return bool(x)
|
||||
|
||||
assert not foo()
|
||||
|
||||
|
||||
def test_fallback_bool_with_input_tensor_4():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test bool() in graph mode with tensor input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = Tensor([1, 2, 3])
|
||||
return bool(x)
|
||||
|
||||
with pytest.raises(ValueError) as ex:
|
||||
foo()
|
||||
assert "The truth value of an array with" in str(ex.value)
|
||||
|
||||
|
||||
def test_fallback_bool_with_input_scalar():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test bool() in graph mode with scalar input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
return bool(10.0)
|
||||
|
||||
assert foo()
|
||||
|
||||
|
||||
def test_fallback_bool_with_input_list():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test bool() in graph mode with list input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = [1, 2, 3]
|
||||
return bool(x)
|
||||
|
||||
assert foo()
|
||||
|
||||
|
||||
def test_fallback_bool_with_input_list_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test bool() in graph mode with list input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo(a):
|
||||
x = [1, 2, 3, a]
|
||||
return bool(x)
|
||||
|
||||
assert foo(Tensor([1, 2, 3]))
|
||||
|
||||
|
||||
def test_fallback_bool_with_input_string():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test bool() in graph mode with string input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
return bool("test")
|
||||
|
||||
assert foo()
|
||||
|
||||
|
||||
def test_fallback_bool_with_input_string_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test bool() in graph mode with string input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
return bool("")
|
||||
|
||||
assert not foo()
|
||||
|
||||
|
||||
def test_fallback_bool_with_input_numpy():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test bool() in graph mode with numpy input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = np.array([1, 2, 3, 4])
|
||||
return bool(x)
|
||||
|
||||
with pytest.raises(ValueError) as ex:
|
||||
foo()
|
||||
assert "The truth value of an array" in str(ex.value)
|
||||
|
||||
|
||||
def test_fallback_bool_with_input_numpy_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test bool() in graph mode with numpy input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = np.array([1,])
|
||||
return bool(x)
|
||||
|
||||
assert foo()
|
||||
|
||||
|
||||
def test_fallback_bool_with_no_input():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test bool() in graph mode with no input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
return bool()
|
||||
|
||||
assert not foo()
|
||||
|
||||
|
||||
def test_fallback_bool_with_type_input():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test bool() in graph mode with type input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
return bool(int)
|
||||
|
||||
assert foo()
|
|
@ -173,6 +173,26 @@ def test_fallback_int_with_no_input():
|
|||
assert ret == 0
|
||||
|
||||
|
||||
def test_fallback_int_with_base_input():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test int() in graph mode with string input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x1 = int('12', 16)
|
||||
x2 = int('0xa', 16)
|
||||
x3 = int('10', 8)
|
||||
return x1, x2, x3
|
||||
|
||||
ret = foo()
|
||||
assert len(ret) == 3
|
||||
assert ret[0] == 18
|
||||
assert ret[1] == 10
|
||||
assert ret[2] == 8
|
||||
|
||||
|
||||
def test_fallback_float_with_input_tensor():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
|
|
@ -0,0 +1,128 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test graph fallback buildin python function str"""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import ms_function, context, Tensor
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_fallback_str_with_input_tensor():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test str() in graph mode with tensor input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo(x):
|
||||
return str(x)
|
||||
with pytest.raises(TypeError) as ex:
|
||||
foo(Tensor([1, 2, 4]))
|
||||
assert "str() does not support non-constant input." in str(ex.value)
|
||||
|
||||
|
||||
def test_fallback_str_with_input_tensor_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test str() in graph mode with tensor input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = Tensor([10])
|
||||
return str(x)
|
||||
|
||||
assert foo() == "Tensor(shape=[1], dtype=Int64, value=[10])"
|
||||
|
||||
|
||||
def test_fallback_str_with_input_scalar():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test str() in graph mode with scalar input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
return str(10.0)
|
||||
|
||||
assert foo() == "10.0"
|
||||
|
||||
|
||||
def test_fallback_str_with_input_list():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test str() in graph mode with list input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = [1, 2, 3]
|
||||
return str(x)
|
||||
|
||||
assert foo() == "[1, 2, 3]"
|
||||
|
||||
|
||||
def test_fallback_str_with_input_string():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test str() in graph mode with string input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
return str("test")
|
||||
|
||||
assert foo() == "test"
|
||||
|
||||
|
||||
def test_fallback_str_with_input_numpy():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test str() in graph mode with numpy input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = np.array([1, 2, 3])
|
||||
return str(x)
|
||||
|
||||
assert foo() == "[1 2 3]"
|
||||
|
||||
|
||||
def test_fallback_str_with_no_input():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test str() in graph mode with no input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
return str()
|
||||
|
||||
assert foo() == ""
|
||||
|
||||
|
||||
def test_fallback_str_with_type_input():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test str() in graph mode with type input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
return str(int)
|
||||
|
||||
assert foo() == "<class 'int'>"
|
Loading…
Reference in New Issue