Rewrite tensor's __bool__ for pynative mode

This commit is contained in:
simson 2020-07-21 19:54:50 +08:00
parent 28c8a5cc26
commit 5f77fbdd75
18 changed files with 54 additions and 44 deletions

View File

@ -672,7 +672,7 @@ def check_input_data(*data, data_class):
def check_output_data(data):
"""Output data check."""
if not data:
if data is None:
raise RuntimeError('Executor return data ' + str(data) + ', please check your net or input data.')

View File

@ -17,6 +17,7 @@
"""standard_method"""
from dataclasses import dataclass
from mindspore.common import dtype as mstype
from mindspore.common._register_for_tensor import tensor_operator_registry
from ...ops import functional as F
from ...ops import operations as P
from ...ops.primitive import constexpr
@ -146,7 +147,7 @@ def check_is_tensor_bool_cond(shp):
"""check if tensor is a bool condition"""
if shp in ((), (1,)):
return True
raise ValueError("tensor as bool condition, its shape should be () or (1,), but got ", shp)
raise ValueError("The truth value of an array with several elements is ambiguous.")
@constexpr
def const_tensor_to_bool(x):
@ -155,7 +156,7 @@ def const_tensor_to_bool(x):
raise ValueError("Only constant tensor bool can be converted to bool")
x = x.asnumpy()
if x.shape not in ((), (1,)):
raise ValueError("Tensor to bool should input shape () or (1), but got ", x.shape)
raise ValueError("The truth value of an array with several elements is ambiguous.")
if x.shape == ():
value = bool(x)
else:
@ -296,3 +297,5 @@ def list_append(self_, item):
def to_array(x):
"""Implementation of `to_array`."""
return x.__ms_to_array__()
tensor_operator_registry.register('__bool__', tensor_bool)

View File

@ -108,6 +108,10 @@ class Tensor(Tensor_):
out = tensor_operator_registry.get('__neg__')(self)
return out
def __bool__(self):
out = tensor_operator_registry.get('__bool__')(self)
return out
def __pos__(self):
return self

View File

@ -28,6 +28,7 @@ hastype = Primitive('hastype')
cast = P.Cast()
dtype = P.DType()
isconstant = Primitive('is_constant')
isconstant.add_prim_attr('const_value', True)
issubclass_ = P.IsSubClass()

View File

@ -37,7 +37,7 @@ class Bprop(Cell):
self.grad = grad_op
self.sens = sens
self.with_sens = False
if sens:
if sens is not None:
self.with_sens = True
def construct(self, *inputs):
@ -71,10 +71,10 @@ def bprop(func, *inputs, grads_wrt_outputs=None, wrt: list = None, params: list
func.set_train()
with_sens_param = False
if grads_wrt_outputs:
if grads_wrt_outputs is not None:
with_sens_param = True
if not wrt:
if wrt is None:
wrt = []
wrt_inputs = False
if 'inputs' in wrt:

View File

@ -63,7 +63,7 @@ def fill_block_config(ret, block_config, tid, group, desc_inputs, desc_bprop, ex
sampling_times, reduce_output, init_param_with, \
split_outputs, exception, error_keywords = get_function_config(block_config[-1])
if block:
if block is not None:
func_list.append({
keyword.id: tid,
keyword.group: group,

View File

@ -22,7 +22,7 @@ from mindspore.common.tensor import Tensor
def setup_module():
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
c1 = Tensor([2], mstype.int32)

View File

@ -48,7 +48,7 @@ def test_list_equal():
ret = net(x, y)
print(ret.asnumpy())
assert ret == x
assert np.all(ret.asnumpy() == x.asnumpy())
assert ret.dtype == mstype.int32
assert ret.shape == (6, 8, 10)
@ -70,7 +70,7 @@ def test_list_not_equal():
y = Tensor(np.zeros([3, 4, 5], np.int32))
z = [1, 2, 3]
net = Net(z)
assert net(x, y) == y
assert np.all(net(x, y).asnumpy() == y.asnumpy())
def test_list_expansion():
@ -91,7 +91,7 @@ def test_list_expansion():
y = Tensor(np.zeros([3, 4, 5], np.int32))
z = [1, 2, 3]
net = Net(z)
assert net(x, y) == x
assert np.all(net(x, y).asnumpy() == x.asnumpy())
def test_list_append():
@ -114,7 +114,7 @@ def test_list_append():
y = Tensor(np.zeros([3, 4, 5], np.int32))
z = [1, 2, 3]
net = Net(z)
assert net(x, y) == y
assert np.all(net(x, y).asnumpy() == y.asnumpy())
def test_class_member_list_append():

View File

@ -115,8 +115,7 @@ def test_if_none():
y = Tensor(np.zeros([3, 4, 5], np.int32))
z = None
net = Net(z)
assert net(x, y) == y
assert np.all(net(x, y).asnumpy() == y.asnumpy())
def test_if_str_is_not_none_right():
class Net(nn.Cell):
@ -136,7 +135,7 @@ def test_if_str_is_not_none_right():
y = Tensor(np.zeros([3, 4, 5], np.int32))
z = "ok"
net = Net(z)
assert net(x, y) == y
assert np.all(net(x, y).asnumpy() == y.asnumpy())
def test_if_str_is_not_none_left():
@ -157,7 +156,7 @@ def test_if_str_is_not_none_left():
y = Tensor(np.zeros([3, 4, 5], np.int32))
z = "ok"
net = Net(z)
assert net(x, y) == y
assert np.all(net(x, y).asnumpy() == y.asnumpy())
def test_if_none_equal_none():
@ -178,7 +177,7 @@ def test_if_none_equal_none():
y = Tensor(np.zeros([3, 4, 5], np.int32))
z = None
net = Net(z)
assert net(x, y) == x
assert np.all(net(x, y).asnumpy() == x.asnumpy())
def test_if_str_is_null():
@ -199,7 +198,7 @@ def test_if_str_is_null():
y = Tensor(np.zeros([3, 4, 5], np.int32))
z = ""
net = Net(z)
assert net(x, y) == y
assert np.all(net(x, y).asnumpy() == y.asnumpy())
def test_if_str_is_true():
@ -220,7 +219,7 @@ def test_if_str_is_true():
y = Tensor(np.zeros([3, 4, 5], np.int32))
z = "ok"
net = Net(z)
assert net(x, y) == x
assert np.all(net(x, y).asnumpy() == x.asnumpy())
def test_if_str_equal():
@ -241,7 +240,7 @@ def test_if_str_equal():
y = Tensor(np.zeros([3, 4, 5], np.int32))
z = "ok"
net = Net(z)
assert net(x, y) == x
assert np.all(net(x, y).asnumpy() == x.asnumpy())
def test_if_tuple_is_null():
@ -262,7 +261,7 @@ def test_if_tuple_is_null():
y = Tensor(np.zeros([3, 4, 5], np.int32))
z = ()
net = Net(z)
assert net(x, y) == y
assert np.all(net(x, y).asnumpy() == y.asnumpy())
def test_if_tuple_is_not_null():
@ -283,7 +282,7 @@ def test_if_tuple_is_not_null():
y = Tensor(np.zeros([3, 4, 5], np.int32))
z = (1, 2, 3)
net = Net(z)
assert net(x, y) == x
assert np.all(net(x, y).asnumpy() == x.asnumpy())
def test_if_dict_is_null():
@ -304,7 +303,7 @@ def test_if_dict_is_null():
y = Tensor(np.zeros([3, 4, 5], np.int32))
z = {}
net = Net(z)
assert net(x, y) == y
assert np.all(net(x, y).asnumpy() == y.asnumpy())
def test_if_dict_is_not_null():
@ -325,7 +324,7 @@ def test_if_dict_is_not_null():
y = Tensor(np.zeros([3, 4, 5], np.int32))
z = {"one": 1, "two": 2}
net = Net(z)
assert net(x, y) == x
assert np.all(net(x, y).asnumpy() == x.asnumpy())
def test_if_else_assign():
@ -355,7 +354,7 @@ def test_if_else_assign():
y = Tensor(np.zeros([3, 4, 5], np.int32))
z = [1, 2]
net = Net(z)
assert net(x, y) == x
assert np.all(net(x, y).asnumpy() == x.asnumpy())
def test_if_compile_true():

View File

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.train._utils import _to_full_shapes, _to_full_tensor
@ -33,7 +35,7 @@ def test_to_full_tensor_1():
expect = ([[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [1, 2, 3], [4, 5, 6], [0, 0, 0], [0, 0, 0]])
expect_tensor = Tensor(expect, dtype=ms.float32)
assert full_tensor[0] == expect_tensor
assert np.all(full_tensor[0].asnumpy() == expect_tensor.asnumpy())
def test_to_full_tensor_2():
@ -50,7 +52,8 @@ def test_to_full_tensor_2():
expect_tensor1 = Tensor(expect1, dtype=ms.int32)
expect_tensors = (expect_tensor0, expect_tensor1)
assert full_tensor == expect_tensors
assert np.all(full_tensor[0].asnumpy() == expect_tensors[0].asnumpy())
assert np.all(full_tensor[1].asnumpy() == expect_tensors[1].asnumpy())
def test_to_full_tensor_sens_2():
@ -68,4 +71,6 @@ def test_to_full_tensor_sens_2():
expect_tensor_sens = Tensor(0.1, dtype=ms.float32)
expect_tensors = (expect_tensor0, expect_tensor1, expect_tensor_sens)
assert full_tensor == expect_tensors
assert np.all(full_tensor[0].asnumpy() == expect_tensors[0].asnumpy())
assert np.all(full_tensor[1].asnumpy() == expect_tensors[1].asnumpy())
assert np.all(full_tensor[2].asnumpy() == expect_tensors[2].asnumpy())

View File

@ -47,7 +47,7 @@ def test_parser_three_default_mixed_args_subnet():
tensor1 = Tensor(np.full((2, 3), 2).astype(np.float32))
tensor2 = Tensor(np.full((3, 2), 4).astype(np.float32))
net = NetOut()
assert net(tensor1, tensor2) == tensor1
assert np.all(net(tensor1, tensor2).asnumpy() == tensor1.asnumpy())
# pylint: disable=keyword-arg-before-vararg

View File

@ -53,4 +53,7 @@ def test_hypermap_specialize_param():
expected_ret = (Tensor(np.full(1, 5).astype(np.int32)), Tensor(np.full(2, 5).astype(np.int32)))
ret = hypermap_specialize_param()
assert ret == (expected_ret, list(expected_ret))
assert ret[0][0].asnumpy() == expected_ret[0].asnumpy()
assert np.all(ret[0][1].asnumpy() == expected_ret[1].asnumpy())
assert ret[1][0].asnumpy() == list(expected_ret[0].asnumpy())
assert np.all(ret[1][1].asnumpy() == list(expected_ret[1].asnumpy()))

View File

@ -66,5 +66,4 @@ def test_assign_in_while():
input_shape = (1024, 512)
z = Tensor(np.random.randn(*input_shape).astype(np.float32))
net = Net(input_shape)
ret = net(x, y, z)
assert ret == z
net(x, y, z)

View File

@ -39,5 +39,5 @@ def test_tensor_orign_ops():
assert np.all(z.asnumpy() - (x.asnumpy() + y.asnumpy()) < 0.0001)
z = x * y
assert np.all(z.asnumpy() - (x.asnumpy() * y.asnumpy()) < 0.0001)
assert x == y
assert np.all(x.asnumpy() == y.asnumpy())
assert x != 'zero'

View File

@ -57,7 +57,7 @@ def test_multitype_tuple():
params1 = Parameter(tensor1, name="params1")
tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
output = op_add((params1, tensor2))
assert output == Tensor(np.array([[2.4, 4.2], [4.4, 6.4]]).astype('float32'))
assert np.all(output.asnumpy() == np.array([[2.4, 4.2], [4.4, 6.4]]).astype('float32'))
def test_multitype_scalar():

View File

@ -380,7 +380,7 @@ def test_while_net():
x = Tensor(np.ones([1, 16, 12, 12]).astype(np.float32))
z = Tensor(np.ones([1, 16, 16, 16]).astype(np.float32))
res = t1_while(x, y, z)
assert res == Tensor(np.ones([1, 16, 12, 12]).astype(np.float32) * 2306.0)
assert np.all(res.asnumpy() == np.ones([1, 16, 12, 12]).astype(np.float32) * 2306.0)
@ms_function
@ -403,7 +403,7 @@ def test_if_while():
x = Tensor(np.random.randn(1, 16, 12, 12).astype(np.float32))
z = Tensor(np.random.randn(1, 16, 16, 16).astype(np.float32))
res = if_while(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32)), x, z)
assert res == Tensor(np.ones([64, 10]).astype(np.float32) * 4.0)
assert np.all(res.asnumpy() == np.ones([64, 10]).astype(np.float32) * 4.0)
def _while(x):
@ -550,7 +550,7 @@ def test_zeros():
""" test_zeros """
x = Tensor(np.ones([2, 3]).astype(np.int32))
res = zero_like_tensor(x)
assert res == Tensor(np.zeros([2, 3]).astype(np.int32))
assert np.all(res.asnumpy() == np.zeros([2, 3]).astype(np.int32))
@ms_function
@ -811,7 +811,7 @@ def test_while_sp():
z = Tensor(np.ones([1, 3]).astype(np.float32))
x = Tensor(np.ones([1, 3]).astype(np.float32) * 2.0)
res = while_sp(x, y, z)
assert res == Tensor(np.ones([1, 3]).astype(np.float32) * 1024.0)
assert np.all(res.asnumpy() == np.ones([1, 3]).astype(np.float32) * 1024.0)
def grad_refactor_simple_1(x, y):
@ -1030,7 +1030,7 @@ def test_grad_if_defer_inline():
network.add_flags(defer_inline=False)
inp = Tensor(np.ones([128, 96]).astype(np.float32))
grads = C.grad_all(network)(inp)
assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),)
assert np.all(grads[0].asnumpy() == np.full([128, 96], 0.6, dtype=np.float32))
def test_dict_const():

View File

@ -256,7 +256,7 @@ def test_stop_gradient_4():
def stop_test(x):
return stop_gradient(x)
assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (0,)
assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (1,)
def test_stop_gradient_5():

View File

@ -294,10 +294,7 @@ class TestSummaryCollector:
summary_collector = SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir)))
assert summary_collector._is_parse_loss_success
assert summary_collector._get_loss(cb_params) == expected_loss
if expected_loss is None:
assert not summary_collector._is_parse_loss_success
def test_get_optimizer_from_cb_params_success(self):
"""Test get optimizer success from cb params."""
@ -381,7 +378,6 @@ class TestSummaryCollector:
result = get_value()
assert PluginEnum.HISTOGRAM.value == result[0][0]
assert expected_names == [data[1] for data in result]
assert expected_values == [data[2] for data in result]
@pytest.mark.parametrize("specified_data, action, expected_result", [
(None, True, SummaryCollector._DEFAULT_SPECIFIED_DATA),