Support all(Tensor), any(Tensor), round(Tensor) in graph mode.

huangbingjian 2022-07-25 14:46:41 +08:00
parent acd873e378
commit 1686142679
8 changed files with 200 additions and 76 deletions
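In short: after this change, the Python builtins all() and any() (and round(), covered further down) accept Tensor inputs when code is compiled in graph mode. A minimal usage sketch, mirroring the new test cases added later in this commit (ms_function and GRAPH_MODE as used by the existing tests):

import numpy as np
from mindspore import Tensor, context, ms_function

context.set_context(mode=context.GRAPH_MODE)

@ms_function
def truth_reductions(x, y):
    # all()/any() now reduce a Tensor's truth values inside the compiled graph.
    return all(x), any(y)

out1, out2 = truth_reductions(Tensor(np.array([1, 1])), Tensor(np.array([1, 0])))
# both outputs evaluate as True (scalar bool tensors from the reductions)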

View File

@ -218,7 +218,7 @@ BuiltInTypeMap &GetMethodMap() {
{"min", std::string("min")}, // P.reduce_min()
{"pow", std::string("pow")}, // P.Pow()
{"cosh", std::string("cosh")}, // P.Cosh()
{"round", std::string("round")}, // P.Round()
{"round", std::string("round_")}, // P.Round()
{"fill", std::string("fill")}, // P.fill()
{"fills", std::string("fills")}, // P.fills
{"ptp", std::string("ptp")}, // P.reduce_max() - P.reduce_min()

View File

@ -106,7 +106,7 @@ _unsupported_internal_type = (
)
_hybrid_type = (
print, len, enumerate, zip, map, filter, abs,
print, len, enumerate, zip, map, filter, abs, all, any, round,
)
# Unsupported python builtin type in JIT Fallback.

View File

@ -128,6 +128,9 @@ convert_object_map = {
# system function
T.abs: M.ms_abs,
T.all: M.ms_all,
T.any: M.ms_any,
T.round: M.ms_round,
T.len: M.ms_len,
T.bool_: M.bool_,
T.map: C.Map(),

View File

@ -790,7 +790,7 @@ def pow(x, y): # pylint: disable=redefined-builtin
return F.pow(x, y)
def round(x): # pylint: disable=redefined-builtin
def round_(x):
"""
Returns half to even of a tensor element-wise.
"""
@ -1660,7 +1660,7 @@ def hasnext(it):
def constant_abs(x):
"""Returns the absolute value of the constant."""
if x is None:
raise ValueError("For abs(), the parameter should be a constant or Tensor type.")
raise ValueError("For abs(), the input should be a constant or Tensor type.")
return abs(x)
@ -1671,6 +1671,50 @@ def ms_abs(x):
return constant_abs(x)
def ms_all(x):
"""Implementation of `all`."""
if isinstance(x, Tensor):
return all_(x.astype(mstype.bool_))
for element in x:
if not element:
return False
return True
def ms_any(x):
"""Implementation of `any`."""
if isinstance(x, Tensor):
return any_(x.astype(mstype.bool_))
for element in x:
if element:
return True
return False
@constexpr
def constant_round(*data):
"""Returns the rounded value of the constant."""
for x in data:
if x is None:
raise ValueError("For round(), the input should be a Tensor or 1-2 constants.")
return round(*data)
def ms_round(*data):
"""Implementation of `round`."""
len_data = len(data)
if len_data <= 0 or len_data > 2:
const_utils.raise_type_error("round() requires 1 or 2 arguments.")
if len_data == 1:
x = data[0]
if isinstance(x, Tensor):
return round_(x)
return constant_round(*data)
if isinstance(data[0], Tensor) or isinstance(data[1], Tensor):
const_utils.raise_type_error("When applying round() to tensor, only one tensor is supported as input.")
return constant_round(*data)
def ms_len(data):
"""Implementation of `len`."""
return data.__len__()
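To summarize the dispatch implemented by ms_round above: a single Tensor argument lowers to the Round primitive (round_); one or two constant arguments are folded at compile time through the constexpr constant_round, which defers to Python's round; and a Tensor combined with a second argument raises a type error. A sketch of the two supported paths (values chosen arbitrarily):

import numpy as np
from mindspore import Tensor, context, ms_function
from mindspore import dtype as mstype

context.set_context(mode=context.GRAPH_MODE)

@ms_function
def round_variants(t):
    a = round(t)            # Tensor path: lowered to the Round primitive
    b = round(3.14159, 2)   # constant path: evaluated at compile time by constant_round
    return a, b

print(round_variants(Tensor(np.array([0.1, 4.51, 9.9]), mstype.float32)))
# expected: ([0. 5. 10.], 3.14)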

View File

@ -28,7 +28,7 @@ from operator import ( # noqa
# support system function call
from builtins import ( # noqa
bool, getattr, setattr, len, iter, next, pow, range, map, zip,
print, enumerate, isinstance, filter, abs
print, enumerate, isinstance, filter, abs, all, any, round,
)
# support functools
@ -45,7 +45,7 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt',
'not_', 'and_', 'or_', 'xor', 'lshift', 'rshift', 'invert', 'is_', 'is_not', 'contains',
'matmul', 'getitem', 'setitem',
'bool', 'getattr', 'setattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip',
'partial', 'print', 'enumerate', 'isinstance', 'filter', 'abs',
'partial', 'print', 'enumerate', 'isinstance', 'filter', 'abs', 'all', 'any', 'round',
'exp', 'log', 'sin', 'cos', 'tan']

View File

@ -222,6 +222,10 @@ class Tensor(Tensor_):
out = tensor_operator_registry.get('__logical_not__')(self)
return out
def __round__(self):
out = tensor_operator_registry.get('round')()(self)
return out
def __bool__(self):
data = self.asnumpy()
if data.shape == ():
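The new __round__ hook means Python's rounding protocol also covers Tensor outside of graph compilation: the builtin round(t) on a Tensor instance calls t.__round__(), which applies the operator registered under 'round'. A PyNative-mode sketch (assuming, as the method above implies, that the 'round' registry entry is the Round primitive):

import numpy as np
from mindspore import Tensor, context
from mindspore import dtype as mstype

context.set_context(mode=context.PYNATIVE_MODE)

t = Tensor(np.array([0.4, 2.6]), mstype.float32)
print(round(t))   # dispatches to Tensor.__round__, i.e. the registered Round operator
# expected: [0. 3.]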

View File

@ -16,6 +16,7 @@
import pytest
import numpy as np
from mindspore import Tensor, context, nn
from mindspore import dtype as mstype
context.set_context(mode=context.GRAPH_MODE)
@ -28,14 +29,151 @@ context.set_context(mode=context.GRAPH_MODE)
def test_fallback_abs_tensor():
"""
Feature: JIT Fallback
Description: Test abs(Tensor) a variable tensor in construct function in graph mode
Description: Test abs(Tensor) with a variable tensor in graph mode
Expectation: No exception
"""
class TestCell(nn.Cell):
class Net(nn.Cell):
def construct(self, y):
x = Tensor([-1, 2])
return abs(x + y)
test_cell = TestCell()
assert np.all(test_cell(Tensor([-1, 2])).asnumpy() == np.array([2, 4]))
net = Net()
assert np.all(net(Tensor([-1, 2])).asnumpy() == np.array([2, 4]))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_all_tensor():
"""
Feature: JIT Fallback
Description: Test all(Tensor) with a variable tensor in graph mode
Expectation: No exception
"""
class Net(nn.Cell):
def construct(self, x, y):
return all(x), all(y)
net = Net()
x = Tensor(np.array([0, 1, 2, 3]))
y = Tensor(np.array([1, 1]))
out1, out2 = net(x, y)
assert (not out1) and out2
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_all_tensor_constant():
"""
Feature: JIT Fallback
Description: Test all(Tensor) with a constant tensor in graph mode
Expectation: No exception
"""
class Net(nn.Cell):
def construct(self):
x = Tensor(np.array([0, 1, 2, 3]))
y = Tensor(np.array([1, 1]))
return all(x), all(y)
net = Net()
out1, out2 = net()
assert (not out1) and out2
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_any_tensor():
"""
Feature: JIT Fallback
Description: Test any(Tensor) with a variable tensor in graph mode
Expectation: No exception
"""
class Net(nn.Cell):
def construct(self, x, y):
return any(x), any(y)
net = Net()
x = Tensor(np.array([0, 0]))
y = Tensor(np.array([1, 0]))
out1, out2 = net(x, y)
assert (not out1) and out2
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_any_tensor_constant():
"""
Feature: JIT Fallback
Description: Test any(Tensor) with a constant tensor in graph mode
Expectation: No exception
"""
class Net(nn.Cell):
def construct(self):
x = Tensor(np.array([0, 0]))
y = Tensor(np.array([1, 0]))
return any(x), any(y)
net = Net()
out1, out2 = net()
assert (not out1) and out2
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_round_tensor():
"""
Feature: JIT Fallback
Description: Test round(Tensor) with a variable tensor in graph mode
Expectation: No exception
"""
class Net(nn.Cell):
def construct(self, x):
return round(x)
net = Net()
x = Tensor(np.array([0.1, 4.51, 9.9]), mstype.float32)
out = net(x)
expect = Tensor(np.array([0.0, 5.0, 10.0]))
np.testing.assert_almost_equal(out.asnumpy(), expect.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_round_tensor_constant():
"""
Feature: JIT Fallback
Description: Test round(Tensor) with a constant tensor in graph mode
Expectation: No exception
"""
class Net(nn.Cell):
def construct(self):
x = Tensor(np.array([0.1, 4.51, 9.9]), mstype.float32)
return round(x)
net = Net()
out = net()
expect = Tensor(np.array([0.0, 5.0, 10.0]))
np.testing.assert_almost_equal(out.asnumpy(), expect.asnumpy())

View File

@ -14,8 +14,7 @@
# ============================================================================
""" test graph fallback """
import numpy as np
from mindspore import ms_function, context, Tensor
from mindspore import ms_function, context
context.set_context(mode=context.GRAPH_MODE)
@ -71,38 +70,6 @@ def test_fallback_all_numpy():
assert (not x) and y
def test_fallback_all_tensor():
"""
Feature: JIT Fallback
Description: Test all(Tensor) in graph mode
Expectation: No exception
"""
@ms_function
def foo():
return all(Tensor(np.array([0, 1, 2, 3]))), all(Tensor(np.array([1, 1])))
x, y = foo()
assert (not x) and y
def test_fallback_all_tensor_construct():
"""
Feature: JIT Fallback
Description: Test all(Tensor) in graph mode
Expectation: No exception
"""
@ms_function
def foo():
x = Tensor(np.array([0, 1, 2, 3]))
y = Tensor(np.array([1, 1]))
return all(x), all(y)
x, y = foo()
assert (not x) and y
def test_fallback_any_tuple():
"""
Feature: JIT Fallback
@ -152,35 +119,3 @@ def test_fallback_any_numpy():
x, y = foo()
assert (not x) and y
def test_fallback_any_tensor():
"""
Feature: JIT Fallback
Description: Test all(Tensor) in graph mode
Expectation: No exception
"""
@ms_function
def foo():
return any(Tensor(np.array([0, 0]))), any(Tensor(np.array([1, 0])))
x, y = foo()
assert (not x) and y
def test_fallback_any_tensor_construct():
"""
Feature: JIT Fallback
Description: Test all(Tensor) in graph mode
Expectation: No exception
"""
@ms_function
def foo():
x = Tensor(np.array([0, 0, 0]))
y = Tensor(np.array([1, 0]))
return any(x), any(y)
x, y = foo()
assert (not x) and y