fix copy slice dtype
This commit is contained in:
parent
ac3c031f22
commit
6d38223b6e
|
@ -680,7 +680,7 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
|
|||
return data
|
||||
value_shape = (dim0_size,) + const_utils.tuple_slice(data.shape, 1, None)
|
||||
value = _broadcast(value_shape, value)
|
||||
return copy_slice(data, value, (start,), (stop,), (step,))
|
||||
return copy_slice(data, value.astype(data.dtype), (start,), (stop,), (step,))
|
||||
data_shape = F.shape(data)
|
||||
indices = const_utils.slice2indices(input_slice, data_shape)
|
||||
if indices is False:
|
||||
|
@ -712,7 +712,7 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
|
|||
step = (1, 1)
|
||||
value_shape = (dim1_stop - dim1_start,) + const_utils.tuple_slice(data.shape, 2, None)
|
||||
value = _broadcast(value_shape, value)
|
||||
return copy_slice(data, value, start, stop, step)
|
||||
return copy_slice(data, value.astype(data.dtype), start, stop, step)
|
||||
|
||||
tuple_index, value, idx_advanced = remove_expanded_dims(tuple_index, F.shape(data), value)
|
||||
|
||||
|
|
Loading…
Reference in New Issue