forked from mindspore-Ecosystem/mindspore
Support determine tensor in(not in) a list(tuple)
This commit is contained in:
parent
b4ce0aa933
commit
a9e781921a
|
@ -577,3 +577,12 @@ def tensor_setitem_by_ellipsis_with_tensor(data, index, value):
|
|||
param2 = F.cast(value, data_dtype)
|
||||
result = F.tensor_mul(param1, param2)
|
||||
return result
|
||||
|
||||
|
||||
def tensor_in_sequence(x, y):
|
||||
"""Assigns whether a sequence contains the given tensor"""
|
||||
for i in y:
|
||||
if isinstance(i, mstype.tensor) and x.shape == i.shape and x.dtype == i.dtype:
|
||||
if F.equal(x, i).all():
|
||||
return const_utils.scalar_to_tensor(True)
|
||||
return const_utils.scalar_to_tensor(False)
|
||||
|
|
|
@ -39,14 +39,17 @@ TENSOR_GETITEM = "tensor getitem"
|
|||
SET_ITEM_BY_ONE_TENSOR = 0
|
||||
SET_ITEM_BY_TUPLE_OF_TENSOR = 1
|
||||
|
||||
|
||||
@constexpr
|
||||
def raise_value_error(msg):
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
@constexpr
|
||||
def raise_index_error(msg):
|
||||
raise IndexError(msg)
|
||||
|
||||
|
||||
@constexpr
|
||||
def raise_type_error(msg):
|
||||
raise TypeError(msg)
|
||||
|
@ -704,7 +707,7 @@ def get_stride_info_from_slice(data_shape, slice_index):
|
|||
def get_stride_info_from_integer(data_shape, number):
|
||||
"""Get stride info from a integer"""
|
||||
begin_strides = [number]
|
||||
end_strides = [number+1]
|
||||
end_strides = [number + 1]
|
||||
step_strides = [1]
|
||||
for end in data_shape[1:]:
|
||||
begin_strides.append(0)
|
||||
|
@ -720,7 +723,7 @@ def get_slice_stride(dim_size, index_slice):
|
|||
stop_default = dim_size
|
||||
if step < 0:
|
||||
start_default = -1
|
||||
stop_default = -(dim_size+1)
|
||||
stop_default = -(dim_size + 1)
|
||||
start = start_default if index_slice.start is None else index_slice.start
|
||||
stop = stop_default if index_slice.stop is None else index_slice.stop
|
||||
return start, stop, step
|
||||
|
@ -775,3 +778,9 @@ def mstype_eq(x, y):
|
|||
if x == y:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@constexpr
|
||||
def scalar_to_tensor(x):
|
||||
"""Convert a scalar to a tensor"""
|
||||
return Tensor(x)
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
"""Implementation for internal polymorphism `in` operations."""
|
||||
|
||||
from . import _constexpr_utils as const_utils
|
||||
from . import _compile_utils as compile_utils
|
||||
from ... import functional as F
|
||||
from ...composite import base
|
||||
|
||||
|
@ -99,3 +100,33 @@ def _str_in_dict(x, y):
|
|||
bool, if x in y return true, x not in y return false.
|
||||
"""
|
||||
return F.in_dict(x, y)
|
||||
|
||||
|
||||
@in_.register("Tensor", "List")
|
||||
def _tensor_in_list(x, y):
|
||||
"""
|
||||
Determine if a tensor in a list.
|
||||
|
||||
Args:
|
||||
x: Tensor
|
||||
y: List
|
||||
|
||||
Returns:
|
||||
bool, if x in y return true, x not in y return false.
|
||||
"""
|
||||
return compile_utils.tensor_in_sequence(x, y)
|
||||
|
||||
|
||||
@in_.register("Tensor", "Tuple")
|
||||
def _tensor_in_tuple(x, y):
|
||||
"""
|
||||
Determine if a tensor in a tuple.
|
||||
|
||||
Args:
|
||||
x: Tensor
|
||||
y: Tuple
|
||||
|
||||
Returns:
|
||||
bool, if x in y return true, x not in y return false.
|
||||
"""
|
||||
return compile_utils.tensor_in_sequence(x, y)
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
"""Implementation for internal polymorphism `not in` operations."""
|
||||
|
||||
from . import _constexpr_utils as const_utils
|
||||
from . import _compile_utils as compile_utils
|
||||
from ... import functional as F
|
||||
from ...composite import base
|
||||
|
||||
|
@ -99,3 +100,33 @@ def _str_not_in_dict(x, y):
|
|||
bool, if x not in y return true, x in y return false.
|
||||
"""
|
||||
return F.not_in_dict(x, y)
|
||||
|
||||
|
||||
@not_in_.register("Tensor", "List")
|
||||
def _tensor_not_in_list(x, y):
|
||||
"""
|
||||
Determine if a tensor not in a list.
|
||||
|
||||
Args:
|
||||
x: Tensor
|
||||
y: List
|
||||
|
||||
Returns:
|
||||
bool, if x not in y return true, x in y return false.
|
||||
"""
|
||||
return not compile_utils.tensor_in_sequence(x, y)
|
||||
|
||||
|
||||
@not_in_.register("Tensor", "Tuple")
|
||||
def _tensor_not_in_tuple(x, y):
|
||||
"""
|
||||
Determine if a tensor not in a tuple.
|
||||
|
||||
Args:
|
||||
x: Tensor
|
||||
y: Tuple
|
||||
|
||||
Returns:
|
||||
bool, if x not in y return true, x in y return false.
|
||||
"""
|
||||
return not compile_utils.tensor_in_sequence(x, y)
|
||||
|
|
|
@ -30,7 +30,6 @@ from tests.mindspore_test_framework.pipeline.forward.compile_forward \
|
|||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
grad_all = C.GradOperation(get_all=True)
|
||||
|
||||
|
||||
|
@ -258,6 +257,34 @@ class AxisListDefaultNet(nn.Cell):
|
|||
return self.reduce_sum(x)
|
||||
|
||||
|
||||
class TensorInList(nn.Cell):
|
||||
def __init__(self):
|
||||
super(TensorInList, self).__init__()
|
||||
self.t1 = Tensor(1, mstype.float32)
|
||||
self.t2 = Tensor(2, mstype.float32)
|
||||
|
||||
def construct(self, x):
|
||||
ret = x
|
||||
list_ = [1, [2, 3], "str", self.t1, self.t2, x]
|
||||
if x in list_:
|
||||
ret = x + x
|
||||
return ret
|
||||
|
||||
|
||||
class TensorNotInList(nn.Cell):
|
||||
def __init__(self):
|
||||
super(TensorNotInList, self).__init__()
|
||||
self.t1 = Tensor(1, mstype.float32)
|
||||
self.t2 = Tensor(2, mstype.float32)
|
||||
|
||||
def construct(self, x):
|
||||
ret = x
|
||||
list_ = [self.t2, x]
|
||||
if self.t1 not in list_:
|
||||
ret = x + x
|
||||
return ret
|
||||
|
||||
|
||||
test_case_ops = [
|
||||
('ListOperate', {
|
||||
'block': ListOperate(),
|
||||
|
@ -275,6 +302,12 @@ test_case_ops = [
|
|||
('InList', {
|
||||
'block': InListNet(),
|
||||
'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}),
|
||||
('TensorInList', {
|
||||
'block': TensorInList(),
|
||||
'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}),
|
||||
('TensorNotInList', {
|
||||
'block': TensorNotInList(),
|
||||
'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}),
|
||||
]
|
||||
|
||||
test_case_lists = [test_case_ops]
|
||||
|
|
|
@ -53,7 +53,7 @@ class NestTupleGraphNet(nn.Cell):
|
|||
|
||||
|
||||
class InTupleNet(nn.Cell):
|
||||
def __init__(self,):
|
||||
def __init__(self):
|
||||
super(InTupleNet, self).__init__()
|
||||
self.tuple_ = (1, 2, 3, 4, 5, "ok")
|
||||
|
||||
|
@ -66,6 +66,34 @@ class InTupleNet(nn.Cell):
|
|||
return ret
|
||||
|
||||
|
||||
class TensorInTuple(nn.Cell):
|
||||
def __init__(self):
|
||||
super(TensorInTuple, self).__init__()
|
||||
self.t1 = Tensor(1, mstype.float32)
|
||||
self.t2 = Tensor(2, mstype.float32)
|
||||
self.tuple_ = (self.t1, self.t2)
|
||||
|
||||
def construct(self, x):
|
||||
ret = x
|
||||
if self.t1 in self.tuple_:
|
||||
ret = x + x
|
||||
return ret
|
||||
|
||||
|
||||
class TensorNotInTuple(nn.Cell):
|
||||
def __init__(self):
|
||||
super(TensorNotInTuple, self).__init__()
|
||||
self.t1 = Tensor(1, mstype.float32)
|
||||
self.t2 = Tensor(2, mstype.float32)
|
||||
self.tuple_ = (self.t1, self.t2)
|
||||
|
||||
def construct(self, x):
|
||||
ret = x
|
||||
if self.t1 not in self.tuple_:
|
||||
ret = x + x
|
||||
return ret
|
||||
|
||||
|
||||
test_case_ops = [
|
||||
('TupleGraph', {
|
||||
'block': TupleGraphNet(),
|
||||
|
@ -75,7 +103,13 @@ test_case_ops = [
|
|||
'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}),
|
||||
('InTuple', {
|
||||
'block': InTupleNet(),
|
||||
'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]})
|
||||
'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}),
|
||||
('TensorInTuple', {
|
||||
'block': TensorInTuple(),
|
||||
'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}),
|
||||
('TensorNotInTuple', {
|
||||
'block': TensorNotInTuple(),
|
||||
'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}),
|
||||
]
|
||||
|
||||
test_case_lists = [test_case_ops]
|
||||
|
|
Loading…
Reference in New Issue