diff --git a/crates/burn-jit/src/ops/qtensor.rs b/crates/burn-jit/src/ops/qtensor.rs index 3d9c84374..494bf4d59 100644 --- a/crates/burn-jit/src/ops/qtensor.rs +++ b/crates/burn-jit/src/ops/qtensor.rs @@ -103,6 +103,9 @@ where fn q_to_device(tensor: QuantizedTensor, device: &Device) -> QuantizedTensor { let mut tensor = tensor; tensor.qtensor = super::to_device(tensor.qtensor, device); + tensor.qparams.scale = super::to_device(tensor.qparams.scale, device); + tensor.qparams.offset = tensor.qparams.offset.map(|x| super::to_device(x, device)); + tensor }