forked from mindspore-Ecosystem/mindspore
Restrict tensor getitem or setitem not support mixed tensor.
This commit is contained in:
parent
b06c802807
commit
850171a34b
|
@ -254,7 +254,7 @@ def tuple_element_is_int(indexs):
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def tuple_elements_type(types):
|
def tuple_index_elements_type(types, op_name):
|
||||||
"""Judges the type of all elements of the tuple."""
|
"""Judges the type of all elements of the tuple."""
|
||||||
tensors_number = 0
|
tensors_number = 0
|
||||||
for ele in types:
|
for ele in types:
|
||||||
|
@ -264,7 +264,7 @@ def tuple_elements_type(types):
|
||||||
return ALL_TENSOR
|
return ALL_TENSOR
|
||||||
if tensors_number == 0:
|
if tensors_number == 0:
|
||||||
return NO_TENSOR
|
return NO_TENSOR
|
||||||
return CONTAIN_TENSOR
|
raise IndexError(f"For '{op_name}', the index does not support mixed tensor.")
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
|
|
|
@ -247,7 +247,7 @@ def _tensor_getitem_by_tuple(data, tuple_index):
|
||||||
Tensor, element type is same as the element type of data.
|
Tensor, element type is same as the element type of data.
|
||||||
"""
|
"""
|
||||||
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
|
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
|
||||||
index_elements_type = multi_utils.tuple_elements_type(index_types)
|
index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_GETITEM)
|
||||||
result = None
|
result = None
|
||||||
if index_elements_type == multi_utils.NO_TENSOR:
|
if index_elements_type == multi_utils.NO_TENSOR:
|
||||||
result = _tensor_slice(data, tuple_index)
|
result = _tensor_slice(data, tuple_index)
|
||||||
|
|
|
@ -191,7 +191,7 @@ def _tensor_setitem_by_tuple_with_number(data, tuple_index, value):
|
||||||
Tensor, element type and shape is same as data.
|
Tensor, element type and shape is same as data.
|
||||||
"""
|
"""
|
||||||
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
|
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
|
||||||
index_elements_type = multi_utils.tuple_elements_type(index_types)
|
index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_SETITEM)
|
||||||
result = None
|
result = None
|
||||||
if index_elements_type == multi_utils.NO_TENSOR:
|
if index_elements_type == multi_utils.NO_TENSOR:
|
||||||
result = _tensor_assgin_number(data, tuple_index, value)
|
result = _tensor_assgin_number(data, tuple_index, value)
|
||||||
|
@ -222,7 +222,7 @@ def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
|
||||||
Tensor, element type and shape is same as data.
|
Tensor, element type and shape is same as data.
|
||||||
"""
|
"""
|
||||||
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
|
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
|
||||||
index_elements_type = multi_utils.tuple_elements_type(index_types)
|
index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_SETITEM)
|
||||||
result = None
|
result = None
|
||||||
if index_elements_type == multi_utils.NO_TENSOR:
|
if index_elements_type == multi_utils.NO_TENSOR:
|
||||||
result = _tensor_assgin_tensor(data, tuple_index, value)
|
result = _tensor_assgin_tensor(data, tuple_index, value)
|
||||||
|
@ -254,7 +254,7 @@ def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
|
||||||
Tensor, element type and shape is same as data.
|
Tensor, element type and shape is same as data.
|
||||||
"""
|
"""
|
||||||
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
|
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
|
||||||
index_elements_type = multi_utils.tuple_elements_type(index_types)
|
index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_SETITEM)
|
||||||
result = None
|
result = None
|
||||||
if index_elements_type == multi_utils.ALL_TENSOR:
|
if index_elements_type == multi_utils.ALL_TENSOR:
|
||||||
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM)
|
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM)
|
||||||
|
|
|
@ -146,9 +146,9 @@ class TensorAssignWithSlice(Cell):
|
||||||
return z
|
return z
|
||||||
|
|
||||||
|
|
||||||
class TensorIndexByOneTensor(Cell):
|
class TensorGetItemByOneTensor(Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(TensorIndexByOneTensor, self).__init__()
|
super(TensorGetItemByOneTensor, self).__init__()
|
||||||
self.const = Tensor(np.ones((5, 4, 7, 8)), mstype.int32)
|
self.const = Tensor(np.ones((5, 4, 7, 8)), mstype.int32)
|
||||||
|
|
||||||
def construct(self, x, index):
|
def construct(self, x, index):
|
||||||
|
@ -156,9 +156,9 @@ class TensorIndexByOneTensor(Cell):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
class TensorIndexByTwoTensors(Cell):
|
class TensorGetItemByTwoTensors(Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(TensorIndexByTwoTensors, self).__init__()
|
super(TensorGetItemByTwoTensors, self).__init__()
|
||||||
self.const = Tensor(np.ones((3, 4, 5, 8)), mstype.int32)
|
self.const = Tensor(np.ones((3, 4, 5, 8)), mstype.int32)
|
||||||
|
|
||||||
def construct(self, x, index_0, index_1):
|
def construct(self, x, index_0, index_1):
|
||||||
|
@ -166,9 +166,9 @@ class TensorIndexByTwoTensors(Cell):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
class TensorIndexByThreeTensors(Cell):
|
class TensorGetItemByThreeTensors(Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(TensorIndexByThreeTensors, self).__init__()
|
super(TensorGetItemByThreeTensors, self).__init__()
|
||||||
self.const = Tensor(np.ones((5, 3, 4, 5)), mstype.int32)
|
self.const = Tensor(np.ones((5, 3, 4, 5)), mstype.int32)
|
||||||
|
|
||||||
def construct(self, x, index_0, index_1, index_2):
|
def construct(self, x, index_0, index_1, index_2):
|
||||||
|
@ -176,6 +176,15 @@ class TensorIndexByThreeTensors(Cell):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class TensorGetItemByMixedTensors(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(TensorGetItemByMixedTensors, self).__init__()
|
||||||
|
|
||||||
|
def construct(self, x, index_0, index_1):
|
||||||
|
ret = x[index_0, index_1, 0:6]
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
class TensorSetItemByOneTensorWithNumber(Cell):
|
class TensorSetItemByOneTensorWithNumber(Cell):
|
||||||
def __init__(self, value):
|
def __init__(self, value):
|
||||||
super(TensorSetItemByOneTensorWithNumber, self).__init__()
|
super(TensorSetItemByOneTensorWithNumber, self).__init__()
|
||||||
|
@ -300,6 +309,19 @@ class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class TensorSetItemByMixedTensors(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(TensorSetItemByMixedTensors, self).__init__()
|
||||||
|
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
|
||||||
|
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
|
||||||
|
self.value = 99.0
|
||||||
|
|
||||||
|
def construct(self, index_0, index_1):
|
||||||
|
self.param[index_0, index_1, 0:6] = self.value
|
||||||
|
ret = self.param + self.const
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def test_tensor_assign():
|
def test_tensor_assign():
|
||||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||||
net = TensorAssignWithSlice()
|
net = TensorAssignWithSlice()
|
||||||
|
@ -596,19 +618,19 @@ test_cases = [
|
||||||
'block': NetWorkSliceEllipsis(),
|
'block': NetWorkSliceEllipsis(),
|
||||||
'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))],
|
'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))],
|
||||||
}),
|
}),
|
||||||
('TensorIndexByOneTensor', {
|
('TensorGetItemByOneTensor', {
|
||||||
'block': TensorIndexByOneTensor(),
|
'block': TensorGetItemByOneTensor(),
|
||||||
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
||||||
Tensor(np.random.randint(6, size=(5, 4)), mstype.int32)],
|
Tensor(np.random.randint(6, size=(5, 4)), mstype.int32)],
|
||||||
}),
|
}),
|
||||||
('TensorIndexByTwoTensors', {
|
('TensorGetItemByTwoTensors', {
|
||||||
'block': TensorIndexByTwoTensors(),
|
'block': TensorGetItemByTwoTensors(),
|
||||||
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
||||||
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||||
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)],
|
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)],
|
||||||
}),
|
}),
|
||||||
('TensorIndexByThreeTensors', {
|
('TensorGetItemByThreeTensors', {
|
||||||
'block': TensorIndexByThreeTensors(),
|
'block': TensorGetItemByThreeTensors(),
|
||||||
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
||||||
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||||
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||||
|
@ -665,37 +687,43 @@ test_cases = [
|
||||||
]
|
]
|
||||||
|
|
||||||
raise_error_set = [
|
raise_error_set = [
|
||||||
('TensorIndexByOneTensorDtypeError', {
|
('TensorGetItemByOneTensorDtypeError', {
|
||||||
'block': (TensorIndexByOneTensor(), {'exception': TypeError}),
|
'block': (TensorGetItemByOneTensor(), {'exception': TypeError}),
|
||||||
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
||||||
Tensor(np.random.randint(6, size=(5, 4)), mstype.int8)],
|
Tensor(np.random.randint(6, size=(5, 4)), mstype.int8)],
|
||||||
}),
|
}),
|
||||||
('TensorIndexByTwoTensorsShapeError', {
|
('TensorGetItemByTwoTensorsShapeError', {
|
||||||
'block': (TensorIndexByTwoTensors(), {'exception': ValueError}),
|
'block': (TensorGetItemByTwoTensors(), {'exception': ValueError}),
|
||||||
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
||||||
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||||
Tensor(np.random.randint(7, size=(2, 3, 5)), mstype.int32)],
|
Tensor(np.random.randint(7, size=(2, 3, 5)), mstype.int32)],
|
||||||
}),
|
}),
|
||||||
('TensorIndexByTwoTensorsDtypeError', {
|
('TensorGetItemByTwoTensorsDtypeError', {
|
||||||
'block': (TensorIndexByTwoTensors(), {'exception': TypeError}),
|
'block': (TensorGetItemByTwoTensors(), {'exception': TypeError}),
|
||||||
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
||||||
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||||
Tensor(np.random.randint(7, size=(4, 5)), mstype.float32)],
|
Tensor(np.random.randint(7, size=(4, 5)), mstype.float32)],
|
||||||
}),
|
}),
|
||||||
('TensorIndexByThreeTensorsShapeError', {
|
('TensorGetItemByThreeTensorsShapeError', {
|
||||||
'block': (TensorIndexByThreeTensors(), {'exception': ValueError}),
|
'block': (TensorGetItemByThreeTensors(), {'exception': ValueError}),
|
||||||
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
||||||
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||||
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32),
|
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32),
|
||||||
Tensor(np.random.randint(8, size=(5, 2, 4, 5)), mstype.int32)],
|
Tensor(np.random.randint(8, size=(5, 2, 4, 5)), mstype.int32)],
|
||||||
}),
|
}),
|
||||||
('TensorIndexByThreeTensorsDtypeError', {
|
('TensorGetItemByThreeTensorsDtypeError', {
|
||||||
'block': (TensorIndexByThreeTensors(), {'exception': TypeError}),
|
'block': (TensorGetItemByThreeTensors(), {'exception': TypeError}),
|
||||||
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
||||||
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||||
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int64),
|
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int64),
|
||||||
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
|
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
|
||||||
}),
|
}),
|
||||||
|
('TensorGetItemByMixedTensors', {
|
||||||
|
'block': (TensorGetItemByMixedTensors(), {'exception': IndexError}),
|
||||||
|
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int64)],
|
||||||
|
}),
|
||||||
('TensorSetItemByOneTensorWithNumberTypeError', {
|
('TensorSetItemByOneTensorWithNumberTypeError', {
|
||||||
'block': (TensorSetItemByOneTensorWithNumber(value=0), {'exception': TypeError}),
|
'block': (TensorSetItemByOneTensorWithNumber(value=0), {'exception': TypeError}),
|
||||||
'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)],
|
'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)],
|
||||||
|
@ -781,6 +809,11 @@ raise_error_set = [
|
||||||
Tensor(np.zeros((4, 5)), mstype.float32),
|
Tensor(np.zeros((4, 5)), mstype.float32),
|
||||||
Tensor(np.ones((4, 5)), mstype.int32),
|
Tensor(np.ones((4, 5)), mstype.int32),
|
||||||
Tensor(np.ones((4, 5)) * 2, mstype.int32)],
|
Tensor(np.ones((4, 5)) * 2, mstype.int32)],
|
||||||
|
}),
|
||||||
|
('TensorSetItemByMixedTensors', {
|
||||||
|
'block': (TensorSetItemByMixedTensors(), {'exception': IndexError}),
|
||||||
|
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)],
|
||||||
})
|
})
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue