forked from mindspore-Ecosystem/mindspore
!31000 add fallback example of function
Merge pull request !31000 from lianliguang/add-fallback-builtin-function
This commit is contained in:
commit
b32daf3cb4
|
@ -14,17 +14,18 @@
|
|||
# ============================================================================
|
||||
"""Tensor implementation."""
|
||||
import numbers
|
||||
|
||||
import numpy as np
|
||||
from mindspore.communication.management import get_rank, get_group_size
|
||||
|
||||
from mindspore import log as logger
|
||||
from mindspore.communication.management import get_rank, get_group_size
|
||||
from . import dtype as mstype
|
||||
from ._register_for_tensor import tensor_operator_registry
|
||||
from .._c_expression import Tensor as Tensor_
|
||||
from .._c_expression import CSRTensor as CSRTensor_
|
||||
from .._c_expression import COOTensor as COOTensor_
|
||||
from .._checkparam import Validator as validator
|
||||
from .._c_expression import CSRTensor as CSRTensor_
|
||||
from .._c_expression import Tensor as Tensor_
|
||||
from .._checkparam import Rel
|
||||
from .._checkparam import Validator as validator
|
||||
|
||||
__all__ = ['Tensor', 'RowTensor', 'SparseTensor', 'COOTensor', 'CSRTensor']
|
||||
np_types = (np.int8, np.int16, np.int32, np.int64,
|
||||
|
@ -221,6 +222,21 @@ class Tensor(Tensor_):
|
|||
return bool(data[0])
|
||||
raise ValueError("The truth value of an array with several elements is ambiguous.")
|
||||
|
||||
def _convert_scalar_(self, data, func, message):
|
||||
if data.shape == ():
|
||||
return func(data)
|
||||
if data.shape == (1,):
|
||||
return func(data[0])
|
||||
raise ValueError(message)
|
||||
|
||||
def __int__(self):
|
||||
data = self.asnumpy()
|
||||
return self._convert_scalar_(data, int, "only one element tensors can be converted to Python scalars")
|
||||
|
||||
def __float__(self):
|
||||
data = self.asnumpy()
|
||||
return self._convert_scalar_(data, float, "only one element tensors can be converted to Python scalars")
|
||||
|
||||
def __index__(self):
|
||||
data = self.asnumpy()
|
||||
if not (data.dtype == "int8"
|
||||
|
@ -229,15 +245,15 @@ class Tensor(Tensor_):
|
|||
or data.dtype == "int64"
|
||||
or data.dtype == "bool"):
|
||||
raise ValueError("Only integer tensors of a single element can be converted to an index.")
|
||||
if data.shape == ():
|
||||
return int(data)
|
||||
if data.shape == (1,):
|
||||
return int(data[0])
|
||||
raise ValueError("Only integer tensors of a single element can be converted to an index.")
|
||||
return self._convert_scalar_(data, int,
|
||||
"Only integer tensors of a single element can be converted to an index.")
|
||||
|
||||
def __pos__(self):
|
||||
return self
|
||||
|
||||
def __abs__(self):
|
||||
return Tensor(abs(self.asnumpy()))
|
||||
|
||||
def __add__(self, other):
|
||||
return tensor_operator_registry.get('__add__')(self, other)
|
||||
|
||||
|
@ -825,7 +841,7 @@ class Tensor(Tensor_):
|
|||
if order == 'C':
|
||||
return reshape_op(self, (-1,))
|
||||
|
||||
perm = tuple(range(self.ndim-1, -1, -1))
|
||||
perm = tuple(range(self.ndim - 1, -1, -1))
|
||||
return reshape_op(trans_op(self, perm), (-1,))
|
||||
|
||||
def narrow(self, axis, start, length):
|
||||
|
@ -902,11 +918,11 @@ class Tensor(Tensor_):
|
|||
|
||||
perm = tuple(range(0, self.ndim))
|
||||
if axis2 + 1 < self.ndim:
|
||||
new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \
|
||||
perm[axis1+1:axis2] + perm[axis1:axis1+1] + perm[axis2+1:]
|
||||
new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \
|
||||
perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1] + perm[axis2 + 1:]
|
||||
else:
|
||||
new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \
|
||||
perm[axis1+1:axis2] + perm[axis1:axis1+1]
|
||||
new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \
|
||||
perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1]
|
||||
|
||||
return tensor_operator_registry.get('transpose')()(self, new_perm)
|
||||
|
||||
|
@ -1723,11 +1739,11 @@ class Tensor(Tensor_):
|
|||
e = e.astype(mstype.float32)
|
||||
if offset > 0:
|
||||
e_left = tensor_operator_registry.get('fill')(dtype, (n, offset), 0)
|
||||
e_right = e[..., 0:m-offset:1]
|
||||
e_right = e[..., 0:m - offset:1]
|
||||
e = tensor_operator_registry.get('concatenate')(1)((e_left, e_right)).astype(dtype)
|
||||
elif offset < 0:
|
||||
e_upper = tensor_operator_registry.get('fill')(dtype, (-offset, m), 0)
|
||||
e_lower = e[0:n+offset:1, ...]
|
||||
e_lower = e[0:n + offset:1, ...]
|
||||
e = tensor_operator_registry.get('concatenate')(0)((e_upper, e_lower)).astype(dtype)
|
||||
e = tensor_operator_registry.get('broadcast_to')(shape)(e)
|
||||
|
||||
|
@ -1735,7 +1751,7 @@ class Tensor(Tensor_):
|
|||
res = tensor_operator_registry.get('reduce_sum')(prod.astype(mstype.float32), -1)
|
||||
|
||||
begin = ()
|
||||
for i in range(ndim-2):
|
||||
for i in range(ndim - 2):
|
||||
begin += (0,)
|
||||
last_dim_begin = max(0, -offset)
|
||||
begin += (last_dim_begin,)
|
||||
|
@ -1986,7 +2002,7 @@ class Tensor(Tensor_):
|
|||
|
||||
sort_range = tuple(range(validator.get_log2_size(tensor_operator_registry.get('shape_mul')(a.shape) + 1)))
|
||||
for _ in sort_range:
|
||||
mid = (i - -j)//2
|
||||
mid = (i - -j) // 2
|
||||
mask = less_op(v, tensor_operator_registry.get('gather_nd')(a, mid.reshape(mid.shape + (1,))))
|
||||
i = tensor_operator_registry.get('select')(mask, i, mid)
|
||||
j = tensor_operator_registry.get('select')(mask, mid, j)
|
||||
|
|
|
@ -134,6 +134,21 @@ def _dict_setitem_with_number(data, key, value):
|
|||
return F.dict_setitem(data, key, value)
|
||||
|
||||
|
||||
@setitem.register("Dictionary", "String", "List")
|
||||
def _dict_setitem_with_list(data, key, value):
|
||||
"""
|
||||
Assigns value to dictionary.
|
||||
|
||||
Inputs:
|
||||
data (dict): Data of type dict.
|
||||
key (str): Key of the data.
|
||||
value (List): Value given.
|
||||
|
||||
Outputs:
|
||||
dict, type is as same as the element type of data.
|
||||
"""
|
||||
return F.dict_setitem(data, key, value)
|
||||
|
||||
@setitem.register("Dictionary", "String", "Tuple")
|
||||
def _dict_setitem_with_tuple(data, key, value):
|
||||
"""
|
||||
|
|
|
@ -0,0 +1,150 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test graph fallback """
|
||||
import math
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mindspore import ms_function, context, Tensor, nn
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_fallback_abs_integer():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test abs(int) in graph mode
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = -1
|
||||
return abs(x)
|
||||
|
||||
assert foo() == 1
|
||||
|
||||
|
||||
def test_fallback_abs_float():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test abs(float) in graph mode
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = -1.0
|
||||
return abs(x)
|
||||
|
||||
assert math.isclose(foo(), 1.0, abs_tol=1e-5)
|
||||
|
||||
|
||||
def test_fallback_abs_complex():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test abs(complex) in graph mode
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = complex(-1, 2)
|
||||
return abs(x)
|
||||
|
||||
assert math.isclose(foo(), abs(-1 + 2j), abs_tol=1e-5)
|
||||
|
||||
|
||||
def test_fallback_abs_numpy():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test abs(np.array) in graph mode
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = abs(np.array([1, -2, 3]))
|
||||
return Tensor(x)
|
||||
|
||||
assert np.all(foo().asnumpy() == abs(np.array([-1, 2, -3])))
|
||||
|
||||
|
||||
@pytest.mark.skip("Not Supported yet need to convert C++ Tensor To python")
|
||||
def test_fallback_abs_cell_construct_tensor():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test abs(Tensor) the tensor is construct in construct function in graph mode
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
class TestCell(nn.Cell):
|
||||
def construct(self):
|
||||
x = Tensor([-1, 2])
|
||||
return abs(x)
|
||||
|
||||
test_cell = TestCell()
|
||||
assert np.all(test_cell().asnumpy() == np.array([1, 2]))
|
||||
|
||||
|
||||
@pytest.mark.skip("Not Supported yet not support variable")
|
||||
def test_fallback_abs_cell_variable_tensor():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test abs(Tensor) a variable tensor in construct function in graph mode
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
class TestCell(nn.Cell):
|
||||
def construct(self, y):
|
||||
x = Tensor([-1, 2])
|
||||
return abs(x + y)
|
||||
|
||||
test_cell = TestCell()
|
||||
assert np.all(test_cell(Tensor([-1, 2])).asnumpy() == np.array([2, 4]))
|
||||
|
||||
|
||||
def test_fallback_abs_cell_init_tensor():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test abs(Tensor) the tensor is construct in construct function in graph mode
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
class TestCell(nn.Cell):
|
||||
def __init__(self):
|
||||
super(TestCell, self).__init__()
|
||||
self.x = Tensor([-1, 2])
|
||||
|
||||
def construct(self):
|
||||
return abs(self.x)
|
||||
|
||||
test_cell = TestCell()
|
||||
assert np.allclose(test_cell().asnumpy(), np.array([1, 2]))
|
||||
|
||||
|
||||
def test_fallback_abs_ms_function_tensor():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test abs(Tensor) the tensor is construct in ms_function
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = abs(Tensor(np.array([1, -2, 3])))
|
||||
return x
|
||||
|
||||
assert np.allclose(foo().asnumpy(), abs(np.array([-1, 2, -3])))
|
|
@ -0,0 +1,189 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test graph fallback """
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mindspore import ms_function, context, Tensor
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_fallback_all_tuple():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test all(Tuple) in graph mode
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = (0, 1, 2, 3)
|
||||
y = (1, 1)
|
||||
return all(x), all(y)
|
||||
|
||||
x, y = foo()
|
||||
assert (not x) and y
|
||||
|
||||
|
||||
def test_fallback_all_list():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test all(List) in graph mode
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = [0, 1, 2, 3]
|
||||
y = [1, 1]
|
||||
return all(x), all(y)
|
||||
|
||||
x, y = foo()
|
||||
assert (not x) and y
|
||||
|
||||
|
||||
def test_fallback_all_numpy():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test all(numpy.array) in graph mode
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = np.array([0, 1, 2, 3])
|
||||
y = np.array([1, 1])
|
||||
return all(x), all(y)
|
||||
|
||||
x, y = foo()
|
||||
assert (not x) and y
|
||||
|
||||
|
||||
def test_fallback_all_tensor():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test all(Tensor) in graph mode
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
return all(Tensor(np.array([0, 1, 2, 3]))), all(Tensor(np.array([1, 1])))
|
||||
|
||||
x, y = foo()
|
||||
assert (not x) and y
|
||||
|
||||
|
||||
@pytest.mark.skip("Not support yet should convert C++ Tensor to python")
|
||||
def test_fallback_all_tensor_construct():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test all(numpy.array) in graph mode
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = Tensor(np.array([0, 1, 2, 3]))
|
||||
y = Tensor(np.array([1, 1]))
|
||||
return all(x), all(y)
|
||||
|
||||
x, y = foo()
|
||||
assert (not x) and not y
|
||||
|
||||
|
||||
def test_fallback_any_tuple():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test any(Tuple) in graph mode
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = (0, 0, 0, 0)
|
||||
y = (1, 0)
|
||||
return any(x), any(y)
|
||||
|
||||
x, y = foo()
|
||||
assert (not x) and y
|
||||
|
||||
|
||||
def test_fallback_any_list():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test any(List) in graph mode
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = [0, 0, 0, 0]
|
||||
y = [1, 0]
|
||||
return any(x), any(y)
|
||||
|
||||
x, y = foo()
|
||||
assert (not x) and y
|
||||
|
||||
|
||||
def test_fallback_any_numpy():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test any(numpy.array) in graph mode
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = np.array([0, 0, 0])
|
||||
y = np.array([1, 0])
|
||||
return any(x), any(y)
|
||||
|
||||
x, y = foo()
|
||||
assert (not x) and y
|
||||
|
||||
|
||||
def test_fallback_any_tensor():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test all(Tensor) in graph mode
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
return any(Tensor(np.array([0, 0]))), any(Tensor(np.array([1, 0])))
|
||||
|
||||
x, y = foo()
|
||||
assert (not x) and y
|
||||
|
||||
|
||||
@pytest.mark.skip("Not support yet should convert C++ Tensor to python")
|
||||
def test_fallback_any_tensor_construct():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test all(Tensor) in graph mode
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = Tensor(np.array([0, 0, 0]))
|
||||
y = Tensor(np.array([1, 0]))
|
||||
return any(x), any(y)
|
||||
|
||||
x, y = foo()
|
||||
assert (not x) and not y
|
|
@ -0,0 +1,239 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test graph fallback """
|
||||
import math
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mindspore import ms_function, Tensor
|
||||
|
||||
|
||||
def test_fallback_bool_int():
|
||||
"""
|
||||
Feature : JIT Fallback
|
||||
Description: Test bool(int) in graph mode.
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = bool(int)
|
||||
return x
|
||||
|
||||
assert foo()
|
||||
|
||||
|
||||
def test_fallback_bool_empty():
|
||||
"""
|
||||
Feature : JIT Fallback
|
||||
Description: Test bool() in graph mode.
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = bool()
|
||||
return x
|
||||
|
||||
assert not foo()
|
||||
|
||||
|
||||
def test_fallback_bool_seq():
|
||||
"""
|
||||
Feature : JIT Fallback
|
||||
Description: Test bool(sequence) in graph mode.
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x1 = bool([1, 2, 3, 4])
|
||||
y1 = bool((1, 2))
|
||||
x2 = bool([])
|
||||
y2 = bool(tuple())
|
||||
return x1, y1, x2, y2
|
||||
x1, y1, x2, y2 = foo()
|
||||
assert x1 and y1 and not x2 and not y2
|
||||
|
||||
|
||||
def test_fallback_bool_str():
|
||||
"""
|
||||
Feature : JIT Fallback
|
||||
Description: Test bool(str) in graph mode.
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = bool("")
|
||||
y = bool("123")
|
||||
return x, y
|
||||
|
||||
x, y = foo()
|
||||
assert not x and y
|
||||
|
||||
|
||||
def test_fallback_bool_None_and_complex():
|
||||
"""
|
||||
Feature : JIT Fallback
|
||||
Description: Test bool(None) and bool(complex) in graph mode.
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x1 = bool(None)
|
||||
x2 = bool(complex(0, 0))
|
||||
x3 = bool(complex(1, 0))
|
||||
x4 = bool(complex(0, 1))
|
||||
return x1, x2, x3, x4
|
||||
|
||||
x1, x2, x3, x4 = foo()
|
||||
assert (not x1) and (not x2) and x3 and x4
|
||||
|
||||
|
||||
def test_fallback_bool_tensor():
|
||||
"""
|
||||
Feature : JIT Fallback
|
||||
Description: Test bool(Tensor) in graph mode.
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = bool(Tensor([1]))
|
||||
y = bool(Tensor([0]))
|
||||
return x, y
|
||||
|
||||
x, y = foo()
|
||||
assert x and not y
|
||||
|
||||
|
||||
@pytest.mark.skip("Not support yet should convert C++ Tensor to Python")
|
||||
def test_fallback_bool_tensor_construct():
|
||||
"""
|
||||
Feature : JIT Fallback
|
||||
Description: Test bool(Tensor) in graph mode.
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = Tensor([1])
|
||||
y = Tensor([0])
|
||||
x = bool(x)
|
||||
y = bool(y)
|
||||
return x, y
|
||||
x, y = foo()
|
||||
assert x and not y
|
||||
|
||||
|
||||
def test_fallback_float():
|
||||
"""
|
||||
Feature : JIT Fallback
|
||||
Description: Test float(int) in graph mode.
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = float(5)
|
||||
return x
|
||||
|
||||
assert math.isclose(foo(), 5.0, abs_tol=1e-5)
|
||||
|
||||
|
||||
def test_fallback_float_empty():
|
||||
"""
|
||||
Feature : JIT Fallback
|
||||
Description: Test float() in graph mode.
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = float()
|
||||
return x
|
||||
|
||||
assert math.isclose(foo(), 0.0, abs_tol=1e-5)
|
||||
|
||||
|
||||
def test_fallback_float_str():
|
||||
"""
|
||||
Feature : JIT Fallback
|
||||
Description: Test float(str) in graph mode.
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x1 = float("12.3")
|
||||
x2 = float("-12.3")
|
||||
x3 = float("1e-003")
|
||||
x4 = float("-1234\n")
|
||||
x5 = float("-Infinity")
|
||||
return x1, x2, x3, x4, x5
|
||||
|
||||
x1, x2, x3, x4, x5 = foo()
|
||||
assert math.isclose(x1, 12.3, abs_tol=1e-5) \
|
||||
and math.isclose(x2, -12.3, abs_tol=1e-5) \
|
||||
and math.isclose(x3, 1e-003, abs_tol=1e-5) \
|
||||
and math.isclose(x4, -1234, abs_tol=1e-5) \
|
||||
and x5 == float("-Infinity")
|
||||
|
||||
|
||||
def test_fallback_float_tensor():
|
||||
"""
|
||||
Feature : JIT Fallback
|
||||
Description: Test float(Tensor) in graph mode.
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = float(Tensor([1.5]))
|
||||
return x
|
||||
|
||||
assert math.isclose(foo(), 1.5, abs_tol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.skip("Not supported need to convert C++ Tensor to py")
|
||||
def test_fallback_float_tensor_construct():
|
||||
"""
|
||||
Feature : JIT Fallback
|
||||
Description: Test float(Tensor) in graph mode.
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = Tensor([1.5])
|
||||
x = float(x)
|
||||
return x
|
||||
|
||||
assert math.isclose(foo(), 1.5, abs_tol=1e-5)
|
||||
|
||||
|
||||
def test_fallback_float_numpy():
|
||||
"""
|
||||
Feature : JIT Fallback
|
||||
Description: Test float(np.array) in graph mode.
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = float(np.array([1.5]))
|
||||
return x
|
||||
|
||||
assert math.isclose(foo(), 1.5, abs_tol=1e-5)
|
|
@ -0,0 +1,52 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test graph fallback """
|
||||
|
||||
from mindspore import ms_function
|
||||
|
||||
|
||||
def test_fallback_dict_empty():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test dict() in graoh mode.
|
||||
Expectation:No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
dict_x = dict()
|
||||
dict_x['a'] = [1, 2, 3]
|
||||
return dict_x["a"]
|
||||
|
||||
assert foo() == (1, 2, 3)
|
||||
|
||||
|
||||
def test_fallback_dict_zip_iter_assign():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test dict() in graoh mode.
|
||||
Expectation:No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
dict_x1 = dict(zip(['one', 'two', 'three'], [1, 2, 3]))
|
||||
dict_x2 = dict([("one", 1), ("two", 2)])
|
||||
dict_x3 = dict(one=1, two=2)
|
||||
dict_x4 = dict({'one': 1, 'two': 2})
|
||||
return dict_x1["one"], dict_x2["one"], dict_x3["one"], dict_x4["one"]
|
||||
|
||||
x1, x2, x3, x4 = foo()
|
||||
assert x1 == 1 and x2 == 1 and x3 == 1 and x4 == 1
|
|
@ -13,53 +13,11 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test graph fallback """
|
||||
import math
|
||||
import numpy as np
|
||||
from mindspore import ms_function, context, Tensor
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_fallback_abs():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test abs() in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = -1
|
||||
return abs(x)
|
||||
assert foo() == 1
|
||||
|
||||
|
||||
def test_fallback_all():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test all() in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = (0, 1, 2, 3)
|
||||
return all(x)
|
||||
assert not foo()
|
||||
|
||||
|
||||
def test_fallback_any():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test any() in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = (0, 1, 0, 0)
|
||||
return any(x)
|
||||
out = foo()
|
||||
assert out
|
||||
|
||||
|
||||
def test_fallback_bin():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -72,20 +30,6 @@ def test_fallback_bin():
|
|||
return x
|
||||
assert foo() == '0b11'
|
||||
|
||||
|
||||
def test_fallback_bool():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test bool() in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = bool(1)
|
||||
return x
|
||||
assert foo()
|
||||
|
||||
|
||||
def test_fallback_chr():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -115,19 +59,6 @@ def test_fallback_complex():
|
|||
assert np.all(res.asnumpy() == expect_res)
|
||||
|
||||
|
||||
def test_fallback_dict():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test dict() in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
dict_x = dict(a=1, b=2, c=3)
|
||||
return dict_x
|
||||
print(foo())
|
||||
|
||||
|
||||
def test_fallback_divmod():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -141,21 +72,6 @@ def test_fallback_divmod():
|
|||
assert foo() == (3, 1)
|
||||
|
||||
|
||||
def test_fallback_float():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test float() in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = float(1)
|
||||
return x
|
||||
|
||||
out = foo()
|
||||
assert math.isclose(out, 1, abs_tol=1e-5)
|
||||
|
||||
|
||||
def test_fallback_hash():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -181,20 +97,6 @@ def test_fallback_hex():
|
|||
return x
|
||||
assert foo() == '0xff'
|
||||
|
||||
|
||||
def test_fallback_int():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test int() in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = int(5.0)
|
||||
return x
|
||||
assert foo() == 5
|
||||
|
||||
|
||||
def test_fallback_oct():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
|
Loading…
Reference in New Issue