!20391 fix setitem dtype using copy slice

Merge pull request !20391 from huangmengxi/setitem_fix
This commit is contained in:
i-robot 2021-07-20 13:36:10 +00:00 committed by Gitee
commit b5dc887fcd
1 changed files with 2 additions and 2 deletions

View File

@ -680,7 +680,7 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
return data
value_shape = (dim0_size,) + const_utils.tuple_slice(data.shape, 1, None)
value = _broadcast(value_shape, value)
return copy_slice(data, value, (start,), (stop,), (step,))
return copy_slice(data, value.astype(data.dtype), (start,), (stop,), (step,))
data_shape = F.shape(data)
indices = const_utils.slice2indices(input_slice, data_shape)
if indices is False:
@ -712,7 +712,7 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
step = (1, 1)
value_shape = (dim1_stop - dim1_start,) + const_utils.tuple_slice(data.shape, 2, None)
value = _broadcast(value_shape, value)
return copy_slice(data, value, start, stop, step)
return copy_slice(data, value.astype(data.dtype), start, stop, step)
tuple_index, value, idx_advanced = remove_expanded_dims(tuple_index, F.shape(data), value)