forked from mindspore-Ecosystem/mindspore
add comparison ops
fix pylint use scalar_lt primitive directly fix review
This commit is contained in:
parent
7329a1ef87
commit
d12a720fc5
|
@ -92,16 +92,16 @@ convert_object_map = {
|
|||
T.and_: multitype_ops.logical_and,
|
||||
T.or_: multitype_ops.logical_or,
|
||||
T.xor: NO_IMPLEMENT,
|
||||
T.pos: F.scalar_uadd,
|
||||
T.pos: multitype_ops.uadd,
|
||||
T.neg: multitype_ops.negative,
|
||||
T.invert: NO_IMPLEMENT,
|
||||
T.not_: F.bool_not,
|
||||
T.not_: multitype_ops.logical_not,
|
||||
T.eq: multitype_ops.equal,
|
||||
T.ne: F.scalar_ne,
|
||||
T.ne: multitype_ops.not_equal,
|
||||
T.lt: multitype_ops.less,
|
||||
T.gt: F.scalar_gt,
|
||||
T.gt: multitype_ops.greater,
|
||||
T.le: multitype_ops.less_equal,
|
||||
T.ge: F.scalar_ge,
|
||||
T.ge: multitype_ops.greater_equal,
|
||||
T.is_: F.is_,
|
||||
T.is_not: F.is_not,
|
||||
T.contains: NO_IMPLEMENT,
|
||||
|
|
|
@ -23,23 +23,33 @@ from .getitem_impl import getitem
|
|||
from .zeros_like_impl import zeros_like
|
||||
from .ones_like_impl import ones_like
|
||||
from .equal_impl import equal
|
||||
from .not_equal_impl import not_equal
|
||||
from .less_impl import less
|
||||
from .less_equal_impl import less_equal
|
||||
from .greater_impl import greater
|
||||
from .greater_equal_impl import greater_equal
|
||||
from .negative_impl import negative
|
||||
from .logical_and_impl import logical_and
|
||||
from .logical_or_impl import logical_or
|
||||
from .logic_not_impl import logical_not
|
||||
from .uadd_impl import uadd
|
||||
__all__ = [
|
||||
'add',
|
||||
'sub',
|
||||
'mul',
|
||||
'div',
|
||||
'uadd',
|
||||
'zeros_like',
|
||||
'ones_like',
|
||||
'equal',
|
||||
'not_equal',
|
||||
'less',
|
||||
'less_equal',
|
||||
'greater',
|
||||
'greater_equal',
|
||||
'negative',
|
||||
'getitem',
|
||||
'logical_and',
|
||||
'logical_or'
|
||||
'logical_or',
|
||||
'logical_not'
|
||||
]
|
||||
|
|
|
@ -190,7 +190,8 @@ def _none_equal_tuple(x, y):
|
|||
"""
|
||||
return False
|
||||
|
||||
|
||||
@equal.register("Tensor", "Number")
|
||||
@equal.register("Number", "Tensor")
|
||||
@equal.register("Tensor", "Tensor")
|
||||
def _tensor_equal_tensor(x, y):
|
||||
"""
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""greater_equal_impl"""
|
||||
from mindspore.ops.composite import base
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
# greater_equal is a metagraph object which will determine if two objects are greater_equal according to input type
|
||||
# using ".register" decorator
|
||||
greater_equal = base.MultitypeFuncGraph("greater_equal")
|
||||
|
||||
|
||||
@greater_equal.register("Number", "Number")
|
||||
def _greater_equal_scala(x, y):
|
||||
"""
|
||||
Determine whether x is greater equal than y
|
||||
|
||||
Args:
|
||||
x(Number): Number.
|
||||
y(Number): Number.
|
||||
|
||||
Returns:
|
||||
bool, if x >= y return true, x < y return false.
|
||||
"""
|
||||
return F.scalar_ge(x, y)
|
||||
|
||||
@greater_equal.register("Tensor", "Number")
|
||||
@greater_equal.register("Number", "Tensor")
|
||||
@greater_equal.register("Tensor", "Tensor")
|
||||
def _greater_equal_tensor(x, y):
|
||||
"""
|
||||
Determine whether tensor x is greater equal than tensor y elementwise
|
||||
|
||||
Args:
|
||||
x(Tensor): Tensor.
|
||||
y(Tensor): Tensor.
|
||||
|
||||
Returns:
|
||||
Tensor, return value by operator P.GreaterEqual.
|
||||
"""
|
||||
return F.tensor_ge(x, y)
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Ungreater required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""equal_impl"""
|
||||
from mindspore.ops.composite import base
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
# greater is a metafuncgraph object which will determine if two objects are greater according to input type
|
||||
# using ".register" decorator
|
||||
greater = base.MultitypeFuncGraph("greater")
|
||||
|
||||
|
||||
@greater.register("Number", "Number")
|
||||
def _greater_scala(x, y):
|
||||
"""
|
||||
Determine whether two numbers are greater.
|
||||
|
||||
Args:
|
||||
x(Number): Number.
|
||||
y(Number): Number.
|
||||
|
||||
Returns:
|
||||
bool, if x > y return true, x <= y return false.
|
||||
"""
|
||||
return F.scalar_gt(x, y)
|
||||
|
||||
@greater.register("Tensor", "Number")
|
||||
@greater.register("Number", "Tensor")
|
||||
@greater.register("Tensor", "Tensor")
|
||||
def _greater_tensor(x, y):
|
||||
"""
|
||||
Determine whether two tensor are greater by element.
|
||||
|
||||
Args:
|
||||
x(Tensor): Tensor.
|
||||
y(Tensor): Tensor.
|
||||
|
||||
Returns:
|
||||
tensor, return operation of x and y by P.Greater
|
||||
"""
|
||||
return F.tensor_gt(x, y)
|
|
@ -36,7 +36,8 @@ def _less_equal_scala(x, y):
|
|||
"""
|
||||
return F.scalar_le(x, y)
|
||||
|
||||
|
||||
@less_equal.register("Tensor", "Number")
|
||||
@less_equal.register("Number", "Tensor")
|
||||
@less_equal.register("Tensor", "Tensor")
|
||||
def _less_equal_tensor(x, y):
|
||||
"""
|
||||
|
|
|
@ -36,7 +36,8 @@ def _less_scala(x, y):
|
|||
"""
|
||||
return F.scalar_lt(x, y)
|
||||
|
||||
|
||||
@less.register("Tensor", "Number")
|
||||
@less.register("Number", "Tensor")
|
||||
@less.register("Tensor", "Tensor")
|
||||
def _less_tensor(x, y):
|
||||
"""
|
||||
|
@ -47,6 +48,6 @@ def _less_tensor(x, y):
|
|||
y(Tensor): Tensor.
|
||||
|
||||
Returns:
|
||||
bool, if x and y are less elements by element return true, else return false.
|
||||
Tensor, return value of x and y by operation P.Less()
|
||||
"""
|
||||
return F.tensor_lt(x, y)
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""logical_not_impl"""
|
||||
from mindspore.ops.composite import base
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
# logical_not is a metagraph object which will generate function according to input type
|
||||
# using ".register" decorator
|
||||
logical_not = base.MultitypeFuncGraph("logical_not")
|
||||
|
||||
|
||||
@logical_not.register("Number")
|
||||
def _logical_not_scala(x):
|
||||
"""
|
||||
Return logical not operation result of x
|
||||
|
||||
Args:
|
||||
x(Number): Number.
|
||||
|
||||
Returns:
|
||||
bool, Return logical not operation result of x
|
||||
"""
|
||||
return F.bool_not(x.__bool__())
|
||||
|
||||
|
||||
@logical_not.register("Tensor")
|
||||
def _logical_not_tensor(x):
|
||||
"""
|
||||
Return logical not operation result of x
|
||||
Args:
|
||||
x(Tensor): Tensor.
|
||||
Returns:
|
||||
Tensor, Return logical not operation result of x
|
||||
"""
|
||||
return F.logical_not(x)
|
|
@ -0,0 +1,237 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""not_equal_impl"""
|
||||
|
||||
from ...composite import base
|
||||
from ... import functional as F
|
||||
|
||||
|
||||
not_equal = base.MultitypeFuncGraph("not_equal")
|
||||
"""
|
||||
not_equal is a metafuncgraph object which will determine if two objects are not_equal according to input type
|
||||
using ".register" decorator
|
||||
"""
|
||||
|
||||
|
||||
@not_equal.register("Number", "Number")
|
||||
def _not_equal_scalar(x, y):
|
||||
"""
|
||||
Determine if two numbers is not equal.
|
||||
|
||||
Args:
|
||||
x (Number): x
|
||||
y (NUmber): y
|
||||
|
||||
Returns:
|
||||
bool, if x != y return true, x == y return false.
|
||||
"""
|
||||
return not F.scalar_eq(x, y)
|
||||
|
||||
|
||||
@not_equal.register("String", "String")
|
||||
def _not_equal_string(x, y):
|
||||
"""
|
||||
Determine if two strings are not equal.
|
||||
|
||||
Args:
|
||||
x: str
|
||||
y: str
|
||||
|
||||
Returns:
|
||||
bool, if x != y return true, x == y return false.
|
||||
"""
|
||||
return not F.string_eq(x, y)
|
||||
|
||||
|
||||
@not_equal.register("String", "None")
|
||||
def _string_not_equal_none(x, y):
|
||||
"""
|
||||
Determine if string not equals none.
|
||||
|
||||
Args:
|
||||
x: str.
|
||||
y: None.
|
||||
|
||||
Returns:
|
||||
bool, return True.
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
@not_equal.register("None", "String")
|
||||
def _none_not_equal_string(x, y):
|
||||
"""
|
||||
Determine if string not equals none.
|
||||
|
||||
Args:
|
||||
x: None.
|
||||
y: str.
|
||||
|
||||
Returns:
|
||||
bool, return True.
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
@not_equal.register("None", "None")
|
||||
def _none_not_equal_none(x, y):
|
||||
"""
|
||||
Determine if none not equals none.
|
||||
|
||||
Args:
|
||||
x: None.
|
||||
y: None.
|
||||
|
||||
Returns:
|
||||
bool, return False.
|
||||
"""
|
||||
return False
|
||||
|
||||
|
||||
@not_equal.register("Number", "None")
|
||||
def _scalar_not_equal_none(x, y):
|
||||
"""
|
||||
Determine if number not equals none.
|
||||
|
||||
Args:
|
||||
x: Number.
|
||||
y: None.
|
||||
|
||||
Returns:
|
||||
bool, return True.
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
@not_equal.register("None", "Number")
|
||||
def _none_not_equal_scalar(x, y):
|
||||
"""
|
||||
Determine if number not_equals none.
|
||||
|
||||
Args:
|
||||
x: None.
|
||||
y: NUmber.
|
||||
|
||||
Returns:
|
||||
bool, return True.
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
@not_equal.register("Tuple", "Tuple")
|
||||
def _euqal_tuple(x, y):
|
||||
"""
|
||||
Determine if two tuples are not equal by element.
|
||||
|
||||
Args:
|
||||
x (tuple): x
|
||||
y (tuple): y
|
||||
|
||||
Returns:
|
||||
bool, if x and y are not equal by element return true, else return false.
|
||||
"""
|
||||
return not F.tuple_equal(x, y)
|
||||
|
||||
|
||||
@not_equal.register("List", "List")
|
||||
def _euqal_list(x, y):
|
||||
"""
|
||||
Determine if two lists are not equal by element.
|
||||
|
||||
Args:
|
||||
x (list): x
|
||||
y (list): y
|
||||
|
||||
Returns:
|
||||
bool, if x and y are not equal by element return true, else return false.
|
||||
"""
|
||||
return not F.list_equal(x, y)
|
||||
|
||||
|
||||
@not_equal.register("Tuple", "None")
|
||||
def _tuple_euqal_none(x, y):
|
||||
"""
|
||||
Determine if tuple element not equals none element.
|
||||
|
||||
Args:
|
||||
x: Tuple.
|
||||
y: None.
|
||||
|
||||
Returns:
|
||||
bool, return True.
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
@not_equal.register("None", "Tuple")
|
||||
def _none_not_equal_tuple(x, y):
|
||||
"""
|
||||
Determine if tuple element not equals none element.
|
||||
|
||||
Args:
|
||||
x: None.
|
||||
y: Tuple.
|
||||
|
||||
Returns:
|
||||
bool, return True.
|
||||
"""
|
||||
return True
|
||||
|
||||
@not_equal.register("Tensor", "Number")
|
||||
@not_equal.register("Number", "Tensor")
|
||||
@not_equal.register("Tensor", "Tensor")
|
||||
def _tensor_not_equal_tensor(x, y):
|
||||
"""
|
||||
Determine if two tensors are not_equal.
|
||||
|
||||
Args:
|
||||
x : Tensor.
|
||||
y : Tensor.
|
||||
|
||||
Returns:
|
||||
bool, if x == y return true, x != y return false.
|
||||
"""
|
||||
return F.not_equal(x, y)
|
||||
|
||||
|
||||
@not_equal.register("Tensor", "None")
|
||||
def _tensor_not_equal_none(x, y):
|
||||
"""
|
||||
Determine if tensor not_equal none.
|
||||
|
||||
Args:
|
||||
x : Tensor.
|
||||
y : None.
|
||||
|
||||
Returns:
|
||||
bool, return True.
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
@not_equal.register("None", "Tensor")
|
||||
def _none_not_equal_tensor(x, y):
|
||||
"""
|
||||
Determine if tensor not equal none.
|
||||
|
||||
Args:
|
||||
x : None.
|
||||
y : Tensor.
|
||||
|
||||
Returns:
|
||||
bool, return True.
|
||||
"""
|
||||
return True
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""uadd_impl"""
|
||||
from mindspore.ops.composite import base
|
||||
|
||||
# uadd is a metagraph object which will return operation result regarding input
|
||||
# using ".register" decorator
|
||||
uadd = base.MultitypeFuncGraph("uadd")
|
||||
|
||||
@uadd.register("Tensor")
|
||||
@uadd.register("Number")
|
||||
def _uadd_scala(x):
|
||||
return x
|
|
@ -43,12 +43,15 @@ tensor_add = P.TensorAdd()
|
|||
neg_tensor = P.Neg()
|
||||
tensor_lt = P.Less()
|
||||
tensor_le = P.LessEqual()
|
||||
tensor_gt = P.Greater()
|
||||
tensor_ge = P.GreaterEqual()
|
||||
tensor_sub = P.Sub()
|
||||
tensor_mul = P.Mul()
|
||||
tensor_div = P.RealDiv()
|
||||
strided_slice = P.StridedSlice()
|
||||
same_type_shape = P.SameTypeShape()
|
||||
equal = P.Equal()
|
||||
not_equal = P.NotEqual()
|
||||
assign_sub = P.AssignSub()
|
||||
assign = P.Assign()
|
||||
square = P.Square()
|
||||
|
@ -97,6 +100,7 @@ bool_or = Primitive("bool_or")
|
|||
bool_and = Primitive("bool_and")
|
||||
logical_and = P.LogicalAnd()
|
||||
logical_or = P.LogicalOr()
|
||||
logical_not = P.LogicalNot()
|
||||
array_to_scalar = Primitive('array_to_scalar')
|
||||
is_ = Primitive("is_")
|
||||
is_not = Primitive("is_not")
|
||||
|
|
|
@ -17,6 +17,7 @@ from mindspore.ops import Primitive
|
|||
|
||||
scala_add = Primitive('scalar_add')
|
||||
scala_mul = Primitive('scalar_mul')
|
||||
scalar_gt = Primitive('scalar_gt')
|
||||
def scalar_add(x, y):
|
||||
"""Implement `scalar_add`."""
|
||||
return scala_add(x, y)
|
||||
|
@ -26,6 +27,6 @@ def scalar_mul(x, y):
|
|||
return scala_mul(x, y)
|
||||
|
||||
def test_if(x, y):
|
||||
if x > y:
|
||||
if scalar_gt(x, y):
|
||||
return x
|
||||
return y
|
||||
|
|
|
@ -31,8 +31,20 @@ class ComparisonOpsNet(nn.Cell):
|
|||
def __init__(self):
|
||||
super(ComparisonOpsNet, self).__init__()
|
||||
def construct(self, x, y):
|
||||
ret = x <= y
|
||||
return ret
|
||||
a = x <= y
|
||||
b = x <= 1.0
|
||||
c = y >= 1.0
|
||||
d = y >= x
|
||||
e = x < y
|
||||
f = x < 1.0
|
||||
g = 1.0 > y
|
||||
h = y > x
|
||||
i = y == 3.0
|
||||
j = x != 4
|
||||
k = + x
|
||||
l = + 1.0
|
||||
m = k != l
|
||||
return a or b or c or d or e or f or g or h or i or j or m
|
||||
|
||||
class LogicalNumberOpsNet(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -41,7 +53,7 @@ class LogicalNumberOpsNet(nn.Cell):
|
|||
self.one = 0
|
||||
self.zero = 0.0
|
||||
def construct(self, x, y):
|
||||
if self.cond and self.one or self.zero:
|
||||
if self.cond and self.one or self.zero and not self.one:
|
||||
return x + y
|
||||
return x - y
|
||||
|
||||
|
@ -51,7 +63,7 @@ class LogicalTensorOpsNet(nn.Cell):
|
|||
super(LogicalTensorOpsNet, self).__init__()
|
||||
self.const_true = Tensor(True, dtype=mstype.bool_)
|
||||
def construct(self, x, y):
|
||||
ret = x and y and (y or self.const_true)
|
||||
ret = x and y and (y or self.const_true) and (not self.const_true)
|
||||
return ret
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue