forked from mindspore-Ecosystem/mindspore
Support all(Tensor),any(Tensor),round(Tensor) in graph mode.
This commit is contained in:
parent
acd873e378
commit
1686142679
|
@ -218,7 +218,7 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"min", std::string("min")}, // P.reduce_min()
|
||||
{"pow", std::string("pow")}, // P.Pow()
|
||||
{"cosh", std::string("cosh")}, // P.Cosh()
|
||||
{"round", std::string("round")}, // P.Round()
|
||||
{"round", std::string("round_")}, // P.Round()
|
||||
{"fill", std::string("fill")}, // P.fill()
|
||||
{"fills", std::string("fills")}, // P.fills
|
||||
{"ptp", std::string("ptp")}, // P.reduce_max() - P.reduce_min()
|
||||
|
|
|
@ -106,7 +106,7 @@ _unsupported_internal_type = (
|
|||
)
|
||||
|
||||
_hybrid_type = (
|
||||
print, len, enumerate, zip, map, filter, abs,
|
||||
print, len, enumerate, zip, map, filter, abs, all, any, round,
|
||||
)
|
||||
|
||||
# Unsupported python builtin type in JIT Fallback.
|
||||
|
|
|
@ -128,6 +128,9 @@ convert_object_map = {
|
|||
|
||||
# system function
|
||||
T.abs: M.ms_abs,
|
||||
T.all: M.ms_all,
|
||||
T.any: M.ms_any,
|
||||
T.round: M.ms_round,
|
||||
T.len: M.ms_len,
|
||||
T.bool_: M.bool_,
|
||||
T.map: C.Map(),
|
||||
|
|
|
@ -790,7 +790,7 @@ def pow(x, y): # pylint: disable=redefined-builtin
|
|||
return F.pow(x, y)
|
||||
|
||||
|
||||
def round(x): # pylint: disable=redefined-builtin
|
||||
def round_(x):
|
||||
"""
|
||||
Returns half to even of a tensor element-wise.
|
||||
"""
|
||||
|
@ -1660,7 +1660,7 @@ def hasnext(it):
|
|||
def constant_abs(x):
|
||||
"""Returns the absolute value of the constant."""
|
||||
if x is None:
|
||||
raise ValueError("For abs(), the parameter should be a constant or Tensor type.")
|
||||
raise ValueError("For abs(), the input should be a constant or Tensor type.")
|
||||
return abs(x)
|
||||
|
||||
|
||||
|
@ -1671,6 +1671,50 @@ def ms_abs(x):
|
|||
return constant_abs(x)
|
||||
|
||||
|
||||
def ms_all(x):
|
||||
"""Implementation of `all`."""
|
||||
if isinstance(x, Tensor):
|
||||
return all_(x.astype(mstype.bool_))
|
||||
for element in x:
|
||||
if not element:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def ms_any(x):
|
||||
"""Implementation of `any`."""
|
||||
if isinstance(x, Tensor):
|
||||
return any_(x.astype(mstype.bool_))
|
||||
for element in x:
|
||||
if element:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@constexpr
|
||||
def constant_round(*data):
|
||||
"""Returns the rounded value of the constant."""
|
||||
for x in data:
|
||||
if x is None:
|
||||
raise ValueError("For round(), the input should be a Tensor or 1-2 constants.")
|
||||
return round(*data)
|
||||
|
||||
|
||||
def ms_round(*data):
|
||||
"""Implementation of `round`."""
|
||||
len_data = len(data)
|
||||
if len_data <= 0 or len_data > 2:
|
||||
const_utils.raise_type_error("round() requires 1 or 2 arguments.")
|
||||
if len_data == 1:
|
||||
x = data[0]
|
||||
if isinstance(x, Tensor):
|
||||
return round_(x)
|
||||
return constant_round(*data)
|
||||
if isinstance(data[0], Tensor) or isinstance(data[1], Tensor):
|
||||
const_utils.raise_type_error("When applying round() to tensor, only one tensor is supported as input.")
|
||||
return constant_round(*data)
|
||||
|
||||
|
||||
def ms_len(data):
|
||||
"""Implementation of `len`."""
|
||||
return data.__len__()
|
||||
|
|
|
@ -28,7 +28,7 @@ from operator import ( # noqa
|
|||
# support system function call
|
||||
from builtins import ( # noqa
|
||||
bool, getattr, setattr, len, iter, next, pow, range, map, zip,
|
||||
print, enumerate, isinstance, filter, abs
|
||||
print, enumerate, isinstance, filter, abs, all, any, round,
|
||||
)
|
||||
|
||||
# support functools
|
||||
|
@ -45,7 +45,7 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt',
|
|||
'not_', 'and_', 'or_', 'xor', 'lshift', 'rshift', 'invert', 'is_', 'is_not', 'contains',
|
||||
'matmul', 'getitem', 'setitem',
|
||||
'bool', 'getattr', 'setattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip',
|
||||
'partial', 'print', 'enumerate', 'isinstance', 'filter', 'abs',
|
||||
'partial', 'print', 'enumerate', 'isinstance', 'filter', 'abs', 'all', 'any', 'round',
|
||||
'exp', 'log', 'sin', 'cos', 'tan']
|
||||
|
||||
|
||||
|
|
|
@ -222,6 +222,10 @@ class Tensor(Tensor_):
|
|||
out = tensor_operator_registry.get('__logical_not__')(self)
|
||||
return out
|
||||
|
||||
def __round__(self):
|
||||
out = tensor_operator_registry.get('round')()(self)
|
||||
return out
|
||||
|
||||
def __bool__(self):
|
||||
data = self.asnumpy()
|
||||
if data.shape == ():
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, context, nn
|
||||
from mindspore import dtype as mstype
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -28,14 +29,151 @@ context.set_context(mode=context.GRAPH_MODE)
|
|||
def test_fallback_abs_tensor():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test abs(Tensor) a variable tensor in construct function in graph mode
|
||||
Description: Test abs(Tensor) with a variable tensor in graph mode
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
class TestCell(nn.Cell):
|
||||
class Net(nn.Cell):
|
||||
def construct(self, y):
|
||||
x = Tensor([-1, 2])
|
||||
return abs(x + y)
|
||||
|
||||
test_cell = TestCell()
|
||||
assert np.all(test_cell(Tensor([-1, 2])).asnumpy() == np.array([2, 4]))
|
||||
net = Net()
|
||||
assert np.all(net(Tensor([-1, 2])).asnumpy() == np.array([2, 4]))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fallback_all_tensor():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test all(Tensor) with a variable tensor in graph mode
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
return all(x), all(y)
|
||||
|
||||
net = Net()
|
||||
x = Tensor(np.array([0, 1, 2, 3]))
|
||||
y = Tensor(np.array([1, 1]))
|
||||
out1, out2 = net(x, y)
|
||||
assert (not out1) and out2
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fallback_all_tensor_constant():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test all(Tensor) with a constant tensor in graph mode
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self):
|
||||
x = Tensor(np.array([0, 1, 2, 3]))
|
||||
y = Tensor(np.array([1, 1]))
|
||||
return all(x), all(y)
|
||||
|
||||
net = Net()
|
||||
out1, out2 = net()
|
||||
assert (not out1) and out2
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fallback_any_tensor():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test any(Tensor) with a variable tensor in graph mode
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
return any(x), any(y)
|
||||
|
||||
net = Net()
|
||||
x = Tensor(np.array([0, 0]))
|
||||
y = Tensor(np.array([1, 0]))
|
||||
out1, out2 = net(x, y)
|
||||
assert (not out1) and out2
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fallback_any_tensor_constant():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test any(Tensor) with a constant tensor in graph mode
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self):
|
||||
x = Tensor(np.array([0, 0]))
|
||||
y = Tensor(np.array([1, 0]))
|
||||
return any(x), any(y)
|
||||
|
||||
net = Net()
|
||||
out1, out2 = net()
|
||||
assert (not out1) and out2
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fallback_round_tensor():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test round(Tensor) with a variable tensor in graph mode
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
return round(x)
|
||||
|
||||
net = Net()
|
||||
x = Tensor(np.array([0.1, 4.51, 9.9]), mstype.float32)
|
||||
out = net(x)
|
||||
expect = Tensor(np.array([0.0, 5.0, 10.0]))
|
||||
np.testing.assert_almost_equal(out.asnumpy(), expect.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fallback_round_tensor_constant():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test any(Tensor) with a constant tensor in graph mode
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def construct(self):
|
||||
x = Tensor(np.array([0.1, 4.51, 9.9]), mstype.float32)
|
||||
return round(x)
|
||||
|
||||
net = Net()
|
||||
out = net()
|
||||
expect = Tensor(np.array([0.0, 5.0, 10.0]))
|
||||
np.testing.assert_almost_equal(out.asnumpy(), expect.asnumpy())
|
||||
|
|
|
@ -14,8 +14,7 @@
|
|||
# ============================================================================
|
||||
""" test graph fallback """
|
||||
import numpy as np
|
||||
|
||||
from mindspore import ms_function, context, Tensor
|
||||
from mindspore import ms_function, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -71,38 +70,6 @@ def test_fallback_all_numpy():
|
|||
assert (not x) and y
|
||||
|
||||
|
||||
def test_fallback_all_tensor():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test all(Tensor) in graph mode
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
return all(Tensor(np.array([0, 1, 2, 3]))), all(Tensor(np.array([1, 1])))
|
||||
|
||||
x, y = foo()
|
||||
assert (not x) and y
|
||||
|
||||
|
||||
def test_fallback_all_tensor_construct():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test all(Tensor) in graph mode
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = Tensor(np.array([0, 1, 2, 3]))
|
||||
y = Tensor(np.array([1, 1]))
|
||||
return all(x), all(y)
|
||||
|
||||
x, y = foo()
|
||||
assert (not x) and y
|
||||
|
||||
|
||||
def test_fallback_any_tuple():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -152,35 +119,3 @@ def test_fallback_any_numpy():
|
|||
|
||||
x, y = foo()
|
||||
assert (not x) and y
|
||||
|
||||
|
||||
def test_fallback_any_tensor():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test all(Tensor) in graph mode
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
return any(Tensor(np.array([0, 0]))), any(Tensor(np.array([1, 0])))
|
||||
|
||||
x, y = foo()
|
||||
assert (not x) and y
|
||||
|
||||
|
||||
def test_fallback_any_tensor_construct():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test all(Tensor) in graph mode
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = Tensor(np.array([0, 0, 0]))
|
||||
y = Tensor(np.array([1, 0]))
|
||||
return any(x), any(y)
|
||||
|
||||
x, y = foo()
|
||||
assert (not x) and y
|
||||
|
|
Loading…
Reference in New Issue