diff --git a/mindspore/_extends/parse/resources.py b/mindspore/_extends/parse/resources.py index 1f6f4b91b55..5dd24ccf800 100644 --- a/mindspore/_extends/parse/resources.py +++ b/mindspore/_extends/parse/resources.py @@ -92,16 +92,16 @@ convert_object_map = { T.and_: multitype_ops.logical_and, T.or_: multitype_ops.logical_or, T.xor: NO_IMPLEMENT, - T.pos: F.scalar_uadd, + T.pos: multitype_ops.uadd, T.neg: multitype_ops.negative, T.invert: NO_IMPLEMENT, - T.not_: F.bool_not, + T.not_: multitype_ops.logical_not, T.eq: multitype_ops.equal, - T.ne: F.scalar_ne, + T.ne: multitype_ops.not_equal, T.lt: multitype_ops.less, - T.gt: F.scalar_gt, + T.gt: multitype_ops.greater, T.le: multitype_ops.less_equal, - T.ge: F.scalar_ge, + T.ge: multitype_ops.greater_equal, T.is_: F.is_, T.is_not: F.is_not, T.contains: NO_IMPLEMENT, diff --git a/mindspore/ops/composite/multitype_ops/__init__.py b/mindspore/ops/composite/multitype_ops/__init__.py index 0ab8527ab45..db28b1b5f6c 100644 --- a/mindspore/ops/composite/multitype_ops/__init__.py +++ b/mindspore/ops/composite/multitype_ops/__init__.py @@ -23,23 +23,33 @@ from .getitem_impl import getitem from .zeros_like_impl import zeros_like from .ones_like_impl import ones_like from .equal_impl import equal +from .not_equal_impl import not_equal from .less_impl import less from .less_equal_impl import less_equal +from .greater_impl import greater +from .greater_equal_impl import greater_equal from .negative_impl import negative from .logical_and_impl import logical_and from .logical_or_impl import logical_or +from .logic_not_impl import logical_not +from .uadd_impl import uadd __all__ = [ 'add', 'sub', 'mul', 'div', + 'uadd', 'zeros_like', 'ones_like', 'equal', + 'not_equal', 'less', 'less_equal', + 'greater', + 'greater_equal', 'negative', 'getitem', 'logical_and', - 'logical_or' + 'logical_or', + 'logical_not' ] diff --git a/mindspore/ops/composite/multitype_ops/equal_impl.py b/mindspore/ops/composite/multitype_ops/equal_impl.py index 9ff7e6671e5..428cdf4705d 100644 --- a/mindspore/ops/composite/multitype_ops/equal_impl.py +++ b/mindspore/ops/composite/multitype_ops/equal_impl.py @@ -190,7 +190,8 @@ def _none_equal_tuple(x, y): """ return False - +@equal.register("Tensor", "Number") +@equal.register("Number", "Tensor") @equal.register("Tensor", "Tensor") def _tensor_equal_tensor(x, y): """ diff --git a/mindspore/ops/composite/multitype_ops/greater_equal_impl.py b/mindspore/ops/composite/multitype_ops/greater_equal_impl.py new file mode 100644 index 00000000000..2073abb762f --- /dev/null +++ b/mindspore/ops/composite/multitype_ops/greater_equal_impl.py @@ -0,0 +1,53 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""greater_equal_impl""" +from mindspore.ops.composite import base +from mindspore.ops import functional as F + +# greater_equal is a metagraph object which will determine if two objects are greater_equal according to input type +# using ".register" decorator +greater_equal = base.MultitypeFuncGraph("greater_equal") + + +@greater_equal.register("Number", "Number") +def _greater_equal_scala(x, y): + """ + Determine whether x is greater equal than y + + Args: + x(Number): Number. + y(Number): Number. + + Returns: + bool, if x >= y return true, x < y return false. + """ + return F.scalar_ge(x, y) + +@greater_equal.register("Tensor", "Number") +@greater_equal.register("Number", "Tensor") +@greater_equal.register("Tensor", "Tensor") +def _greater_equal_tensor(x, y): + """ + Determine whether tensor x is greater equal than tensor y elementwise + + Args: + x(Tensor): Tensor. + y(Tensor): Tensor. + + Returns: + Tensor, return value by operator P.GreaterEqual. + """ + return F.tensor_ge(x, y) diff --git a/mindspore/ops/composite/multitype_ops/greater_impl.py b/mindspore/ops/composite/multitype_ops/greater_impl.py new file mode 100644 index 00000000000..7bbf53da492 --- /dev/null +++ b/mindspore/ops/composite/multitype_ops/greater_impl.py @@ -0,0 +1,53 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Ungreater required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""equal_impl""" +from mindspore.ops.composite import base +from mindspore.ops import functional as F + +# greater is a metafuncgraph object which will determine if two objects are greater according to input type +# using ".register" decorator +greater = base.MultitypeFuncGraph("greater") + + +@greater.register("Number", "Number") +def _greater_scala(x, y): + """ + Determine whether two numbers are greater. + + Args: + x(Number): Number. + y(Number): Number. + + Returns: + bool, if x > y return true, x <= y return false. + """ + return F.scalar_gt(x, y) + +@greater.register("Tensor", "Number") +@greater.register("Number", "Tensor") +@greater.register("Tensor", "Tensor") +def _greater_tensor(x, y): + """ + Determine whether two tensor are greater by element. + + Args: + x(Tensor): Tensor. + y(Tensor): Tensor. + + Returns: + tensor, return operation of x and y by P.Greater + """ + return F.tensor_gt(x, y) diff --git a/mindspore/ops/composite/multitype_ops/less_equal_impl.py b/mindspore/ops/composite/multitype_ops/less_equal_impl.py index f02ab61da12..dc1438da2c1 100644 --- a/mindspore/ops/composite/multitype_ops/less_equal_impl.py +++ b/mindspore/ops/composite/multitype_ops/less_equal_impl.py @@ -36,7 +36,8 @@ def _less_equal_scala(x, y): """ return F.scalar_le(x, y) - +@less_equal.register("Tensor", "Number") +@less_equal.register("Number", "Tensor") @less_equal.register("Tensor", "Tensor") def _less_equal_tensor(x, y): """ diff --git a/mindspore/ops/composite/multitype_ops/less_impl.py b/mindspore/ops/composite/multitype_ops/less_impl.py index c9c20657e58..6e50e54c826 100644 --- a/mindspore/ops/composite/multitype_ops/less_impl.py +++ b/mindspore/ops/composite/multitype_ops/less_impl.py @@ -36,7 +36,8 @@ def _less_scala(x, y): """ return F.scalar_lt(x, y) - +@less.register("Tensor", "Number") +@less.register("Number", "Tensor") @less.register("Tensor", "Tensor") def _less_tensor(x, y): """ @@ -47,6 +48,6 @@ def _less_tensor(x, y): y(Tensor): Tensor. Returns: - bool, if x and y are less elements by element return true, else return false. + Tensor, return value of x and y by operation P.Less() """ return F.tensor_lt(x, y) diff --git a/mindspore/ops/composite/multitype_ops/logic_not_impl.py b/mindspore/ops/composite/multitype_ops/logic_not_impl.py new file mode 100644 index 00000000000..35ae766433f --- /dev/null +++ b/mindspore/ops/composite/multitype_ops/logic_not_impl.py @@ -0,0 +1,48 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""logical_not_impl""" +from mindspore.ops.composite import base +from mindspore.ops import functional as F + +# logical_not is a metagraph object which will generate function according to input type +# using ".register" decorator +logical_not = base.MultitypeFuncGraph("logical_not") + + +@logical_not.register("Number") +def _logical_not_scala(x): + """ + Return logical not operation result of x + + Args: + x(Number): Number. + + Returns: + bool, Return logical not operation result of x + """ + return F.bool_not(x.__bool__()) + + +@logical_not.register("Tensor") +def _logical_not_tensor(x): + """ + Return logical not operation result of x + Args: + x(Tensor): Tensor. + Returns: + Tensor, Return logical not operation result of x + """ + return F.logical_not(x) diff --git a/mindspore/ops/composite/multitype_ops/not_equal_impl.py b/mindspore/ops/composite/multitype_ops/not_equal_impl.py new file mode 100644 index 00000000000..de099a2b8f1 --- /dev/null +++ b/mindspore/ops/composite/multitype_ops/not_equal_impl.py @@ -0,0 +1,237 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""not_equal_impl""" + +from ...composite import base +from ... import functional as F + + +not_equal = base.MultitypeFuncGraph("not_equal") +""" +not_equal is a metafuncgraph object which will determine if two objects are not_equal according to input type +using ".register" decorator +""" + + +@not_equal.register("Number", "Number") +def _not_equal_scalar(x, y): + """ + Determine if two numbers is not equal. + + Args: + x (Number): x + y (NUmber): y + + Returns: + bool, if x != y return true, x == y return false. + """ + return not F.scalar_eq(x, y) + + +@not_equal.register("String", "String") +def _not_equal_string(x, y): + """ + Determine if two strings are not equal. + + Args: + x: str + y: str + + Returns: + bool, if x != y return true, x == y return false. + """ + return not F.string_eq(x, y) + + +@not_equal.register("String", "None") +def _string_not_equal_none(x, y): + """ + Determine if string not equals none. + + Args: + x: str. + y: None. + + Returns: + bool, return True. + """ + return True + + +@not_equal.register("None", "String") +def _none_not_equal_string(x, y): + """ + Determine if string not equals none. + + Args: + x: None. + y: str. + + Returns: + bool, return True. + """ + return True + + +@not_equal.register("None", "None") +def _none_not_equal_none(x, y): + """ + Determine if none not equals none. + + Args: + x: None. + y: None. + + Returns: + bool, return False. + """ + return False + + +@not_equal.register("Number", "None") +def _scalar_not_equal_none(x, y): + """ + Determine if number not equals none. + + Args: + x: Number. + y: None. + + Returns: + bool, return True. + """ + return True + + +@not_equal.register("None", "Number") +def _none_not_equal_scalar(x, y): + """ + Determine if number not_equals none. + + Args: + x: None. + y: NUmber. + + Returns: + bool, return True. + """ + return True + + +@not_equal.register("Tuple", "Tuple") +def _euqal_tuple(x, y): + """ + Determine if two tuples are not equal by element. + + Args: + x (tuple): x + y (tuple): y + + Returns: + bool, if x and y are not equal by element return true, else return false. + """ + return not F.tuple_equal(x, y) + + +@not_equal.register("List", "List") +def _euqal_list(x, y): + """ + Determine if two lists are not equal by element. + + Args: + x (list): x + y (list): y + + Returns: + bool, if x and y are not equal by element return true, else return false. + """ + return not F.list_equal(x, y) + + +@not_equal.register("Tuple", "None") +def _tuple_euqal_none(x, y): + """ + Determine if tuple element not equals none element. + + Args: + x: Tuple. + y: None. + + Returns: + bool, return True. + """ + return True + + +@not_equal.register("None", "Tuple") +def _none_not_equal_tuple(x, y): + """ + Determine if tuple element not equals none element. + + Args: + x: None. + y: Tuple. + + Returns: + bool, return True. + """ + return True + +@not_equal.register("Tensor", "Number") +@not_equal.register("Number", "Tensor") +@not_equal.register("Tensor", "Tensor") +def _tensor_not_equal_tensor(x, y): + """ + Determine if two tensors are not_equal. + + Args: + x : Tensor. + y : Tensor. + + Returns: + bool, if x == y return true, x != y return false. + """ + return F.not_equal(x, y) + + +@not_equal.register("Tensor", "None") +def _tensor_not_equal_none(x, y): + """ + Determine if tensor not_equal none. + + Args: + x : Tensor. + y : None. + + Returns: + bool, return True. + """ + return True + + +@not_equal.register("None", "Tensor") +def _none_not_equal_tensor(x, y): + """ + Determine if tensor not equal none. + + Args: + x : None. + y : Tensor. + + Returns: + bool, return True. + """ + return True diff --git a/mindspore/ops/composite/multitype_ops/uadd_impl.py b/mindspore/ops/composite/multitype_ops/uadd_impl.py new file mode 100644 index 00000000000..163120b5410 --- /dev/null +++ b/mindspore/ops/composite/multitype_ops/uadd_impl.py @@ -0,0 +1,26 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""uadd_impl""" +from mindspore.ops.composite import base + +# uadd is a metagraph object which will return operation result regarding input +# using ".register" decorator +uadd = base.MultitypeFuncGraph("uadd") + +@uadd.register("Tensor") +@uadd.register("Number") +def _uadd_scala(x): + return x diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py index b8411d42c16..4da725145f7 100644 --- a/mindspore/ops/functional.py +++ b/mindspore/ops/functional.py @@ -43,12 +43,15 @@ tensor_add = P.TensorAdd() neg_tensor = P.Neg() tensor_lt = P.Less() tensor_le = P.LessEqual() +tensor_gt = P.Greater() +tensor_ge = P.GreaterEqual() tensor_sub = P.Sub() tensor_mul = P.Mul() tensor_div = P.RealDiv() strided_slice = P.StridedSlice() same_type_shape = P.SameTypeShape() equal = P.Equal() +not_equal = P.NotEqual() assign_sub = P.AssignSub() assign = P.Assign() square = P.Square() @@ -97,6 +100,7 @@ bool_or = Primitive("bool_or") bool_and = Primitive("bool_and") logical_and = P.LogicalAnd() logical_or = P.LogicalOr() +logical_not = P.LogicalNot() array_to_scalar = Primitive('array_to_scalar') is_ = Primitive("is_") is_not = Primitive("is_not") diff --git a/tests/ut/cpp/python_input/gtest_input/vm/vm_test.py b/tests/ut/cpp/python_input/gtest_input/vm/vm_test.py index bdd3c900d65..947e9fa2c32 100644 --- a/tests/ut/cpp/python_input/gtest_input/vm/vm_test.py +++ b/tests/ut/cpp/python_input/gtest_input/vm/vm_test.py @@ -17,6 +17,7 @@ from mindspore.ops import Primitive scala_add = Primitive('scalar_add') scala_mul = Primitive('scalar_mul') +scalar_gt = Primitive('scalar_gt') def scalar_add(x, y): """Implement `scalar_add`.""" return scala_add(x, y) @@ -26,6 +27,6 @@ def scalar_mul(x, y): return scala_mul(x, y) def test_if(x, y): - if x > y: + if scalar_gt(x, y): return x return y diff --git a/tests/ut/python/ops/test_python_operators.py b/tests/ut/python/ops/test_python_operators.py index d6c6c037606..eb65a7f373e 100644 --- a/tests/ut/python/ops/test_python_operators.py +++ b/tests/ut/python/ops/test_python_operators.py @@ -31,8 +31,20 @@ class ComparisonOpsNet(nn.Cell): def __init__(self): super(ComparisonOpsNet, self).__init__() def construct(self, x, y): - ret = x <= y - return ret + a = x <= y + b = x <= 1.0 + c = y >= 1.0 + d = y >= x + e = x < y + f = x < 1.0 + g = 1.0 > y + h = y > x + i = y == 3.0 + j = x != 4 + k = + x + l = + 1.0 + m = k != l + return a or b or c or d or e or f or g or h or i or j or m class LogicalNumberOpsNet(nn.Cell): def __init__(self): @@ -41,7 +53,7 @@ class LogicalNumberOpsNet(nn.Cell): self.one = 0 self.zero = 0.0 def construct(self, x, y): - if self.cond and self.one or self.zero: + if self.cond and self.one or self.zero and not self.one: return x + y return x - y @@ -51,7 +63,7 @@ class LogicalTensorOpsNet(nn.Cell): super(LogicalTensorOpsNet, self).__init__() self.const_true = Tensor(True, dtype=mstype.bool_) def construct(self, x, y): - ret = x and y and (y or self.const_true) + ret = x and y and (y or self.const_true) and (not self.const_true) return ret