From f8a00d52af65f69bb6c81509963f64e3a72023d2 Mon Sep 17 00:00:00 2001 From: buxue Date: Wed, 24 Feb 2021 16:16:49 +0800 Subject: [PATCH] support invert bool tensor --- mindspore/_extends/parse/resources.py | 2 +- mindspore/common/tensor.py | 4 ++ mindspore/ops/functional.py | 1 + mindspore/ops/operations/math_ops.py | 24 ++++++- tests/ut/python/pipeline/parse/test_invert.py | 63 +++++++++++++++++++ 5 files changed, 92 insertions(+), 2 deletions(-) create mode 100644 tests/ut/python/pipeline/parse/test_invert.py diff --git a/mindspore/_extends/parse/resources.py b/mindspore/_extends/parse/resources.py index c9f0f05f0b7..53938f2f722 100644 --- a/mindspore/_extends/parse/resources.py +++ b/mindspore/_extends/parse/resources.py @@ -93,7 +93,7 @@ convert_object_map = { T.xor: NO_IMPLEMENT, T.pos: multitype_ops.uadd, T.neg: multitype_ops.negative, - T.invert: NO_IMPLEMENT, + T.invert: F.logical_not, T.not_: multitype_ops.logical_not, T.eq: multitype_ops.equal, T.ne: multitype_ops.not_equal, diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index 24d0e654dec..24c0f2c6687 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -153,6 +153,10 @@ class Tensor(Tensor_): out = tensor_operator_registry.get('__neg__')(self) return out + def __invert__(self): + out = tensor_operator_registry.get('__logical_not__')(self) + return out + def __bool__(self): data = self.asnumpy() if data.shape == (): diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py index b5651b2ea92..5507e0fc3de 100644 --- a/mindspore/ops/functional.py +++ b/mindspore/ops/functional.py @@ -223,6 +223,7 @@ tensor_operator_registry.register('__lt__', tensor_lt) tensor_operator_registry.register('__le__', tensor_le) tensor_operator_registry.register('__gt__', tensor_gt) tensor_operator_registry.register('__ge__', tensor_ge) +tensor_operator_registry.register('__logical_not__', logical_not) tensor_operator_registry.register('shape', shape) tensor_operator_registry.register('squeeze', squeeze) # support GE backend for no compare operators diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 42a6d8a9101..4bc3192771b 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -3105,9 +3105,15 @@ class LogicalNot(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_dtype_valid("x", x_dtype, [mstype.bool_], self.name) + validator.check_tensor_dtype_valid("x", x_dtype, [mstype.bool_], self.name + " or '~' operator") return mstype.tensor_type(mstype.bool_) + def infer_value(self, x): + if x is not None: + x = x.asnumpy() + return Tensor(np.logical_not(x)) + return None + class LogicalAnd(_LogicBinaryOp): """ @@ -3146,6 +3152,14 @@ class LogicalAnd(_LogicBinaryOp): def infer_dtype(self, x_dtype, y_dtype): return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.name) + def infer_value(self, x, y): + if x is not None and y is not None: + x = x.asnumpy() + y = y.asnumpy() + out = np.array(np.logical_and(x, y)) + return Tensor(out) + return None + class LogicalOr(_LogicBinaryOp): """ @@ -3184,6 +3198,14 @@ class LogicalOr(_LogicBinaryOp): def infer_dtype(self, x_dtype, y_dtype): return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.name) + def infer_value(self, x, y): + if x is not None and y is not None: + x = x.asnumpy() + y = y.asnumpy() + out = np.array(np.logical_or(x, y)) + return Tensor(out) + return None + class IsNan(PrimitiveWithInfer): """ diff --git a/tests/ut/python/pipeline/parse/test_invert.py b/tests/ut/python/pipeline/parse/test_invert.py new file mode 100644 index 00000000000..d67cf2c985e --- /dev/null +++ b/tests/ut/python/pipeline/parse/test_invert.py @@ -0,0 +1,63 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test '~' """ +import numpy as np +import pytest + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context + + +class InvertNet(nn.Cell): + def __init__(self): + super(InvertNet, self).__init__() + self.t = Tensor(np.array([True, False, True])) + + def construct(self, x): + invert_t = ~self.t + invert_x = ~x + ret = (invert_t, invert_x) + return ret + + +def test_invert_bool_tensor(): + net = InvertNet() + input_x = Tensor(np.array([False, True, False])) + + context.set_context(mode=context.PYNATIVE_MODE) + ret = net(input_x) + assert (ret[0].asnumpy() == np.array([False, True, False])).all() + assert (ret[1].asnumpy() == np.array([True, False, True])).all() + + context.set_context(mode=context.GRAPH_MODE) + net(input_x) + + +def test_invert_int_tensor(): + net = InvertNet() + input_x = Tensor(np.array([1, 2, 3], np.int32)) + + context.set_context(mode=context.PYNATIVE_MODE) + with pytest.raises(TypeError) as err: + net(input_x) + assert "For 'LogicalNot or '~' operator', the type of `x` should be subclass of Tensor[Bool], " \ + "but got Tensor[Int32]" in str(err.value) + + context.set_context(mode=context.GRAPH_MODE) + with pytest.raises(TypeError) as err: + net(input_x) + assert "For 'LogicalNot or '~' operator', the type of `x` should be subclass of Tensor[Bool], " \ + "but got Tensor[Int32]" in str(err.value)