fix copy slice dtype

This commit is contained in:
huangmengxi 2021-07-16 09:52:17 +08:00
parent ac3c031f22
commit 6d38223b6e
1 changed files with 2 additions and 2 deletions

View File

@ -680,7 +680,7 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
return data
value_shape = (dim0_size,) + const_utils.tuple_slice(data.shape, 1, None)
value = _broadcast(value_shape, value)
return copy_slice(data, value, (start,), (stop,), (step,))
return copy_slice(data, value.astype(data.dtype), (start,), (stop,), (step,))
data_shape = F.shape(data)
indices = const_utils.slice2indices(input_slice, data_shape)
if indices is False:
@ -712,7 +712,7 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
step = (1, 1)
value_shape = (dim1_stop - dim1_start,) + const_utils.tuple_slice(data.shape, 2, None)
value = _broadcast(value_shape, value)
return copy_slice(data, value, start, stop, step)
return copy_slice(data, value.astype(data.dtype), start, stop, step)
tuple_index, value, idx_advanced = remove_expanded_dims(tuple_index, F.shape(data), value)