forked from mindspore-Ecosystem/mindspore
add comparison ops
fix pylint use scalar_lt primitive directly fix review
This commit is contained in:
parent
7329a1ef87
commit
d12a720fc5
|
@ -92,16 +92,16 @@ convert_object_map = {
|
||||||
T.and_: multitype_ops.logical_and,
|
T.and_: multitype_ops.logical_and,
|
||||||
T.or_: multitype_ops.logical_or,
|
T.or_: multitype_ops.logical_or,
|
||||||
T.xor: NO_IMPLEMENT,
|
T.xor: NO_IMPLEMENT,
|
||||||
T.pos: F.scalar_uadd,
|
T.pos: multitype_ops.uadd,
|
||||||
T.neg: multitype_ops.negative,
|
T.neg: multitype_ops.negative,
|
||||||
T.invert: NO_IMPLEMENT,
|
T.invert: NO_IMPLEMENT,
|
||||||
T.not_: F.bool_not,
|
T.not_: multitype_ops.logical_not,
|
||||||
T.eq: multitype_ops.equal,
|
T.eq: multitype_ops.equal,
|
||||||
T.ne: F.scalar_ne,
|
T.ne: multitype_ops.not_equal,
|
||||||
T.lt: multitype_ops.less,
|
T.lt: multitype_ops.less,
|
||||||
T.gt: F.scalar_gt,
|
T.gt: multitype_ops.greater,
|
||||||
T.le: multitype_ops.less_equal,
|
T.le: multitype_ops.less_equal,
|
||||||
T.ge: F.scalar_ge,
|
T.ge: multitype_ops.greater_equal,
|
||||||
T.is_: F.is_,
|
T.is_: F.is_,
|
||||||
T.is_not: F.is_not,
|
T.is_not: F.is_not,
|
||||||
T.contains: NO_IMPLEMENT,
|
T.contains: NO_IMPLEMENT,
|
||||||
|
|
|
@ -23,23 +23,33 @@ from .getitem_impl import getitem
|
||||||
from .zeros_like_impl import zeros_like
|
from .zeros_like_impl import zeros_like
|
||||||
from .ones_like_impl import ones_like
|
from .ones_like_impl import ones_like
|
||||||
from .equal_impl import equal
|
from .equal_impl import equal
|
||||||
|
from .not_equal_impl import not_equal
|
||||||
from .less_impl import less
|
from .less_impl import less
|
||||||
from .less_equal_impl import less_equal
|
from .less_equal_impl import less_equal
|
||||||
|
from .greater_impl import greater
|
||||||
|
from .greater_equal_impl import greater_equal
|
||||||
from .negative_impl import negative
|
from .negative_impl import negative
|
||||||
from .logical_and_impl import logical_and
|
from .logical_and_impl import logical_and
|
||||||
from .logical_or_impl import logical_or
|
from .logical_or_impl import logical_or
|
||||||
|
from .logic_not_impl import logical_not
|
||||||
|
from .uadd_impl import uadd
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'add',
|
'add',
|
||||||
'sub',
|
'sub',
|
||||||
'mul',
|
'mul',
|
||||||
'div',
|
'div',
|
||||||
|
'uadd',
|
||||||
'zeros_like',
|
'zeros_like',
|
||||||
'ones_like',
|
'ones_like',
|
||||||
'equal',
|
'equal',
|
||||||
|
'not_equal',
|
||||||
'less',
|
'less',
|
||||||
'less_equal',
|
'less_equal',
|
||||||
|
'greater',
|
||||||
|
'greater_equal',
|
||||||
'negative',
|
'negative',
|
||||||
'getitem',
|
'getitem',
|
||||||
'logical_and',
|
'logical_and',
|
||||||
'logical_or'
|
'logical_or',
|
||||||
|
'logical_not'
|
||||||
]
|
]
|
||||||
|
|
|
@ -190,7 +190,8 @@ def _none_equal_tuple(x, y):
|
||||||
"""
|
"""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@equal.register("Tensor", "Number")
|
||||||
|
@equal.register("Number", "Tensor")
|
||||||
@equal.register("Tensor", "Tensor")
|
@equal.register("Tensor", "Tensor")
|
||||||
def _tensor_equal_tensor(x, y):
|
def _tensor_equal_tensor(x, y):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -0,0 +1,53 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""greater_equal_impl"""
|
||||||
|
from mindspore.ops.composite import base
|
||||||
|
from mindspore.ops import functional as F
|
||||||
|
|
||||||
|
# greater_equal is a metagraph object which will determine if two objects are greater_equal according to input type
|
||||||
|
# using ".register" decorator
|
||||||
|
greater_equal = base.MultitypeFuncGraph("greater_equal")
|
||||||
|
|
||||||
|
|
||||||
|
@greater_equal.register("Number", "Number")
|
||||||
|
def _greater_equal_scala(x, y):
|
||||||
|
"""
|
||||||
|
Determine whether x is greater equal than y
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x(Number): Number.
|
||||||
|
y(Number): Number.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool, if x >= y return true, x < y return false.
|
||||||
|
"""
|
||||||
|
return F.scalar_ge(x, y)
|
||||||
|
|
||||||
|
@greater_equal.register("Tensor", "Number")
|
||||||
|
@greater_equal.register("Number", "Tensor")
|
||||||
|
@greater_equal.register("Tensor", "Tensor")
|
||||||
|
def _greater_equal_tensor(x, y):
|
||||||
|
"""
|
||||||
|
Determine whether tensor x is greater equal than tensor y elementwise
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x(Tensor): Tensor.
|
||||||
|
y(Tensor): Tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor, return value by operator P.GreaterEqual.
|
||||||
|
"""
|
||||||
|
return F.tensor_ge(x, y)
|
|
@ -0,0 +1,53 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Ungreater required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""equal_impl"""
|
||||||
|
from mindspore.ops.composite import base
|
||||||
|
from mindspore.ops import functional as F
|
||||||
|
|
||||||
|
# greater is a metafuncgraph object which will determine if two objects are greater according to input type
|
||||||
|
# using ".register" decorator
|
||||||
|
greater = base.MultitypeFuncGraph("greater")
|
||||||
|
|
||||||
|
|
||||||
|
@greater.register("Number", "Number")
|
||||||
|
def _greater_scala(x, y):
|
||||||
|
"""
|
||||||
|
Determine whether two numbers are greater.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x(Number): Number.
|
||||||
|
y(Number): Number.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool, if x > y return true, x <= y return false.
|
||||||
|
"""
|
||||||
|
return F.scalar_gt(x, y)
|
||||||
|
|
||||||
|
@greater.register("Tensor", "Number")
|
||||||
|
@greater.register("Number", "Tensor")
|
||||||
|
@greater.register("Tensor", "Tensor")
|
||||||
|
def _greater_tensor(x, y):
|
||||||
|
"""
|
||||||
|
Determine whether two tensor are greater by element.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x(Tensor): Tensor.
|
||||||
|
y(Tensor): Tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tensor, return operation of x and y by P.Greater
|
||||||
|
"""
|
||||||
|
return F.tensor_gt(x, y)
|
|
@ -36,7 +36,8 @@ def _less_equal_scala(x, y):
|
||||||
"""
|
"""
|
||||||
return F.scalar_le(x, y)
|
return F.scalar_le(x, y)
|
||||||
|
|
||||||
|
@less_equal.register("Tensor", "Number")
|
||||||
|
@less_equal.register("Number", "Tensor")
|
||||||
@less_equal.register("Tensor", "Tensor")
|
@less_equal.register("Tensor", "Tensor")
|
||||||
def _less_equal_tensor(x, y):
|
def _less_equal_tensor(x, y):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -36,7 +36,8 @@ def _less_scala(x, y):
|
||||||
"""
|
"""
|
||||||
return F.scalar_lt(x, y)
|
return F.scalar_lt(x, y)
|
||||||
|
|
||||||
|
@less.register("Tensor", "Number")
|
||||||
|
@less.register("Number", "Tensor")
|
||||||
@less.register("Tensor", "Tensor")
|
@less.register("Tensor", "Tensor")
|
||||||
def _less_tensor(x, y):
|
def _less_tensor(x, y):
|
||||||
"""
|
"""
|
||||||
|
@ -47,6 +48,6 @@ def _less_tensor(x, y):
|
||||||
y(Tensor): Tensor.
|
y(Tensor): Tensor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool, if x and y are less elements by element return true, else return false.
|
Tensor, return value of x and y by operation P.Less()
|
||||||
"""
|
"""
|
||||||
return F.tensor_lt(x, y)
|
return F.tensor_lt(x, y)
|
||||||
|
|
|
@ -0,0 +1,48 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""logical_not_impl"""
|
||||||
|
from mindspore.ops.composite import base
|
||||||
|
from mindspore.ops import functional as F
|
||||||
|
|
||||||
|
# logical_not is a metagraph object which will generate function according to input type
|
||||||
|
# using ".register" decorator
|
||||||
|
logical_not = base.MultitypeFuncGraph("logical_not")
|
||||||
|
|
||||||
|
|
||||||
|
@logical_not.register("Number")
|
||||||
|
def _logical_not_scala(x):
|
||||||
|
"""
|
||||||
|
Return logical not operation result of x
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x(Number): Number.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool, Return logical not operation result of x
|
||||||
|
"""
|
||||||
|
return F.bool_not(x.__bool__())
|
||||||
|
|
||||||
|
|
||||||
|
@logical_not.register("Tensor")
|
||||||
|
def _logical_not_tensor(x):
|
||||||
|
"""
|
||||||
|
Return logical not operation result of x
|
||||||
|
Args:
|
||||||
|
x(Tensor): Tensor.
|
||||||
|
Returns:
|
||||||
|
Tensor, Return logical not operation result of x
|
||||||
|
"""
|
||||||
|
return F.logical_not(x)
|
|
@ -0,0 +1,237 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""not_equal_impl"""
|
||||||
|
|
||||||
|
from ...composite import base
|
||||||
|
from ... import functional as F
|
||||||
|
|
||||||
|
|
||||||
|
not_equal = base.MultitypeFuncGraph("not_equal")
|
||||||
|
"""
|
||||||
|
not_equal is a metafuncgraph object which will determine if two objects are not_equal according to input type
|
||||||
|
using ".register" decorator
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@not_equal.register("Number", "Number")
|
||||||
|
def _not_equal_scalar(x, y):
|
||||||
|
"""
|
||||||
|
Determine if two numbers is not equal.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Number): x
|
||||||
|
y (NUmber): y
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool, if x != y return true, x == y return false.
|
||||||
|
"""
|
||||||
|
return not F.scalar_eq(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
@not_equal.register("String", "String")
|
||||||
|
def _not_equal_string(x, y):
|
||||||
|
"""
|
||||||
|
Determine if two strings are not equal.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: str
|
||||||
|
y: str
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool, if x != y return true, x == y return false.
|
||||||
|
"""
|
||||||
|
return not F.string_eq(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
@not_equal.register("String", "None")
|
||||||
|
def _string_not_equal_none(x, y):
|
||||||
|
"""
|
||||||
|
Determine if string not equals none.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: str.
|
||||||
|
y: None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool, return True.
|
||||||
|
"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@not_equal.register("None", "String")
|
||||||
|
def _none_not_equal_string(x, y):
|
||||||
|
"""
|
||||||
|
Determine if string not equals none.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: None.
|
||||||
|
y: str.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool, return True.
|
||||||
|
"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@not_equal.register("None", "None")
|
||||||
|
def _none_not_equal_none(x, y):
|
||||||
|
"""
|
||||||
|
Determine if none not equals none.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: None.
|
||||||
|
y: None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool, return False.
|
||||||
|
"""
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@not_equal.register("Number", "None")
|
||||||
|
def _scalar_not_equal_none(x, y):
|
||||||
|
"""
|
||||||
|
Determine if number not equals none.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Number.
|
||||||
|
y: None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool, return True.
|
||||||
|
"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@not_equal.register("None", "Number")
|
||||||
|
def _none_not_equal_scalar(x, y):
|
||||||
|
"""
|
||||||
|
Determine if number not_equals none.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: None.
|
||||||
|
y: NUmber.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool, return True.
|
||||||
|
"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@not_equal.register("Tuple", "Tuple")
|
||||||
|
def _euqal_tuple(x, y):
|
||||||
|
"""
|
||||||
|
Determine if two tuples are not equal by element.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (tuple): x
|
||||||
|
y (tuple): y
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool, if x and y are not equal by element return true, else return false.
|
||||||
|
"""
|
||||||
|
return not F.tuple_equal(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
@not_equal.register("List", "List")
|
||||||
|
def _euqal_list(x, y):
|
||||||
|
"""
|
||||||
|
Determine if two lists are not equal by element.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (list): x
|
||||||
|
y (list): y
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool, if x and y are not equal by element return true, else return false.
|
||||||
|
"""
|
||||||
|
return not F.list_equal(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
@not_equal.register("Tuple", "None")
|
||||||
|
def _tuple_euqal_none(x, y):
|
||||||
|
"""
|
||||||
|
Determine if tuple element not equals none element.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Tuple.
|
||||||
|
y: None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool, return True.
|
||||||
|
"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@not_equal.register("None", "Tuple")
|
||||||
|
def _none_not_equal_tuple(x, y):
|
||||||
|
"""
|
||||||
|
Determine if tuple element not equals none element.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: None.
|
||||||
|
y: Tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool, return True.
|
||||||
|
"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
@not_equal.register("Tensor", "Number")
|
||||||
|
@not_equal.register("Number", "Tensor")
|
||||||
|
@not_equal.register("Tensor", "Tensor")
|
||||||
|
def _tensor_not_equal_tensor(x, y):
|
||||||
|
"""
|
||||||
|
Determine if two tensors are not_equal.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x : Tensor.
|
||||||
|
y : Tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool, if x == y return true, x != y return false.
|
||||||
|
"""
|
||||||
|
return F.not_equal(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
@not_equal.register("Tensor", "None")
|
||||||
|
def _tensor_not_equal_none(x, y):
|
||||||
|
"""
|
||||||
|
Determine if tensor not_equal none.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x : Tensor.
|
||||||
|
y : None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool, return True.
|
||||||
|
"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@not_equal.register("None", "Tensor")
|
||||||
|
def _none_not_equal_tensor(x, y):
|
||||||
|
"""
|
||||||
|
Determine if tensor not equal none.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x : None.
|
||||||
|
y : Tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool, return True.
|
||||||
|
"""
|
||||||
|
return True
|
|
@ -0,0 +1,26 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""uadd_impl"""
|
||||||
|
from mindspore.ops.composite import base
|
||||||
|
|
||||||
|
# uadd is a metagraph object which will return operation result regarding input
|
||||||
|
# using ".register" decorator
|
||||||
|
uadd = base.MultitypeFuncGraph("uadd")
|
||||||
|
|
||||||
|
@uadd.register("Tensor")
|
||||||
|
@uadd.register("Number")
|
||||||
|
def _uadd_scala(x):
|
||||||
|
return x
|
|
@ -43,12 +43,15 @@ tensor_add = P.TensorAdd()
|
||||||
neg_tensor = P.Neg()
|
neg_tensor = P.Neg()
|
||||||
tensor_lt = P.Less()
|
tensor_lt = P.Less()
|
||||||
tensor_le = P.LessEqual()
|
tensor_le = P.LessEqual()
|
||||||
|
tensor_gt = P.Greater()
|
||||||
|
tensor_ge = P.GreaterEqual()
|
||||||
tensor_sub = P.Sub()
|
tensor_sub = P.Sub()
|
||||||
tensor_mul = P.Mul()
|
tensor_mul = P.Mul()
|
||||||
tensor_div = P.RealDiv()
|
tensor_div = P.RealDiv()
|
||||||
strided_slice = P.StridedSlice()
|
strided_slice = P.StridedSlice()
|
||||||
same_type_shape = P.SameTypeShape()
|
same_type_shape = P.SameTypeShape()
|
||||||
equal = P.Equal()
|
equal = P.Equal()
|
||||||
|
not_equal = P.NotEqual()
|
||||||
assign_sub = P.AssignSub()
|
assign_sub = P.AssignSub()
|
||||||
assign = P.Assign()
|
assign = P.Assign()
|
||||||
square = P.Square()
|
square = P.Square()
|
||||||
|
@ -97,6 +100,7 @@ bool_or = Primitive("bool_or")
|
||||||
bool_and = Primitive("bool_and")
|
bool_and = Primitive("bool_and")
|
||||||
logical_and = P.LogicalAnd()
|
logical_and = P.LogicalAnd()
|
||||||
logical_or = P.LogicalOr()
|
logical_or = P.LogicalOr()
|
||||||
|
logical_not = P.LogicalNot()
|
||||||
array_to_scalar = Primitive('array_to_scalar')
|
array_to_scalar = Primitive('array_to_scalar')
|
||||||
is_ = Primitive("is_")
|
is_ = Primitive("is_")
|
||||||
is_not = Primitive("is_not")
|
is_not = Primitive("is_not")
|
||||||
|
|
|
@ -17,6 +17,7 @@ from mindspore.ops import Primitive
|
||||||
|
|
||||||
scala_add = Primitive('scalar_add')
|
scala_add = Primitive('scalar_add')
|
||||||
scala_mul = Primitive('scalar_mul')
|
scala_mul = Primitive('scalar_mul')
|
||||||
|
scalar_gt = Primitive('scalar_gt')
|
||||||
def scalar_add(x, y):
|
def scalar_add(x, y):
|
||||||
"""Implement `scalar_add`."""
|
"""Implement `scalar_add`."""
|
||||||
return scala_add(x, y)
|
return scala_add(x, y)
|
||||||
|
@ -26,6 +27,6 @@ def scalar_mul(x, y):
|
||||||
return scala_mul(x, y)
|
return scala_mul(x, y)
|
||||||
|
|
||||||
def test_if(x, y):
|
def test_if(x, y):
|
||||||
if x > y:
|
if scalar_gt(x, y):
|
||||||
return x
|
return x
|
||||||
return y
|
return y
|
||||||
|
|
|
@ -31,8 +31,20 @@ class ComparisonOpsNet(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(ComparisonOpsNet, self).__init__()
|
super(ComparisonOpsNet, self).__init__()
|
||||||
def construct(self, x, y):
|
def construct(self, x, y):
|
||||||
ret = x <= y
|
a = x <= y
|
||||||
return ret
|
b = x <= 1.0
|
||||||
|
c = y >= 1.0
|
||||||
|
d = y >= x
|
||||||
|
e = x < y
|
||||||
|
f = x < 1.0
|
||||||
|
g = 1.0 > y
|
||||||
|
h = y > x
|
||||||
|
i = y == 3.0
|
||||||
|
j = x != 4
|
||||||
|
k = + x
|
||||||
|
l = + 1.0
|
||||||
|
m = k != l
|
||||||
|
return a or b or c or d or e or f or g or h or i or j or m
|
||||||
|
|
||||||
class LogicalNumberOpsNet(nn.Cell):
|
class LogicalNumberOpsNet(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -41,7 +53,7 @@ class LogicalNumberOpsNet(nn.Cell):
|
||||||
self.one = 0
|
self.one = 0
|
||||||
self.zero = 0.0
|
self.zero = 0.0
|
||||||
def construct(self, x, y):
|
def construct(self, x, y):
|
||||||
if self.cond and self.one or self.zero:
|
if self.cond and self.one or self.zero and not self.one:
|
||||||
return x + y
|
return x + y
|
||||||
return x - y
|
return x - y
|
||||||
|
|
||||||
|
@ -51,7 +63,7 @@ class LogicalTensorOpsNet(nn.Cell):
|
||||||
super(LogicalTensorOpsNet, self).__init__()
|
super(LogicalTensorOpsNet, self).__init__()
|
||||||
self.const_true = Tensor(True, dtype=mstype.bool_)
|
self.const_true = Tensor(True, dtype=mstype.bool_)
|
||||||
def construct(self, x, y):
|
def construct(self, x, y):
|
||||||
ret = x and y and (y or self.const_true)
|
ret = x and y and (y or self.const_true) and (not self.const_true)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue