!1530 support tensor set item the number value type is similar as tensor dtype

Merge pull request !1530 from zhangbuxue/support_tensor_getitem_number_value_type_similar_as_tensor_dtype
This commit is contained in:
mindspore-ci-bot 2020-05-27 21:29:24 +08:00 committed by Gitee
commit 77547cdfc0
2 changed files with 7 additions and 7 deletions

View File

@ -276,7 +276,7 @@ def check_value_elements(data_dtype, types):
else: else:
raise TypeError(f"For '{TENSOR_SETITEM}', the data type of {i}th tensor '{ele_dtype}' " raise TypeError(f"For '{TENSOR_SETITEM}', the data type of {i}th tensor '{ele_dtype}' "
f"in value tuple is not consistent with assigned tensor data type '{data_dtype}'.") f"in value tuple is not consistent with assigned tensor data type '{data_dtype}'.")
elif mstype.issubclass_(ele, data_dtype): elif mstype.dtype_to_pytype(ele) == mstype.dtype_to_pytype(data_dtype):
scalars_number += 1 scalars_number += 1
else: else:
raise TypeError(f"For '{TENSOR_SETITEM}', the {i}th element type '{ele}' in " raise TypeError(f"For '{TENSOR_SETITEM}', the {i}th element type '{ele}' in "

View File

@ -278,8 +278,8 @@ class TensorSetItemByMixedTensors_1(Cell):
class TensorSetItemByMixedTensors_2(Cell): class TensorSetItemByMixedTensors_2(Cell):
def __init__(self, value): def __init__(self, value):
super(TensorSetItemByMixedTensors_2, self).__init__() super(TensorSetItemByMixedTensors_2, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float32)) self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float16))
self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float16),
name="x") name="x")
self.value = value self.value = value
@ -911,7 +911,7 @@ test_cases = [
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}), }),
('TensorSetItemByMixedTensorsWithTensor_2', { ('TensorSetItemByMixedTensorsWithTensor_2', {
'block': TensorSetItemByMixedTensors_2(value=Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float32))), 'block': TensorSetItemByMixedTensors_2(value=Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float16))),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
@ -923,9 +923,9 @@ test_cases = [
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}), }),
('TensorGetItemByMixedTensorsWithTupleOfTensor_2', { ('TensorGetItemByMixedTensorsWithTupleOfTensor_2', {
'block': TensorSetItemByMixedTensors_2(value=(Tensor(np.ones((4, 5), np.float32)), 'block': TensorSetItemByMixedTensors_2(value=(Tensor(np.ones((4, 5), np.float16)),
Tensor(np.zeros((4, 5), np.float32)), Tensor(np.zeros((4, 5), np.float16)),
Tensor(np.ones((4, 5), np.float32)))), Tensor(np.ones((4, 5), np.float16)))),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],