From c7d2be06ed924372cbab5f943f29980b4bcf822a Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Thu, 26 Sep 2024 11:22:26 -0400 Subject: [PATCH] Add fusion quantization ops (#2301) * Add fusion quantize/dequantize and from/into data * Add quantization tests for fusion * Fix jit q_to_device * Add q_to_device * Add comment * Add note on handles/streams order * Remove unused field * Fix clippy * Add QuantizedEncoding associated type --- crates/burn-autodiff/src/backend.rs | 1 + crates/burn-candle/src/backend.rs | 1 + crates/burn-fusion/src/backend.rs | 2 + crates/burn-fusion/src/client/base.rs | 21 +- crates/burn-fusion/src/client/mutex.rs | 72 ++++- crates/burn-fusion/src/ops/qtensor.rs | 250 ++++++++++++++++-- crates/burn-fusion/src/server.rs | 55 +++- crates/burn-fusion/src/stream/context.rs | 33 +++ crates/burn-fusion/src/tensor.rs | 80 +++++- crates/burn-jit/src/backend.rs | 1 + crates/burn-jit/src/fusion/base.rs | 68 ++++- crates/burn-jit/src/ops/qtensor.rs | 3 + crates/burn-jit/src/tests/mod.rs | 8 + crates/burn-ndarray/src/backend.rs | 1 + crates/burn-tch/src/backend.rs | 1 + crates/burn-tensor/src/repr/backend.rs | 26 +- crates/burn-tensor/src/repr/handle.rs | 64 ++++- crates/burn-tensor/src/repr/mod.rs | 2 + crates/burn-tensor/src/repr/operation.rs | 42 +++ crates/burn-tensor/src/repr/quantization.rs | 25 ++ crates/burn-tensor/src/tensor/backend/base.rs | 2 + .../src/tensor/quantization/scheme.rs | 6 +- 22 files changed, 705 insertions(+), 59 deletions(-) create mode 100644 crates/burn-tensor/src/repr/quantization.rs diff --git a/crates/burn-autodiff/src/backend.rs b/crates/burn-autodiff/src/backend.rs index 1508022f5..5918f2d24 100644 --- a/crates/burn-autodiff/src/backend.rs +++ b/crates/burn-autodiff/src/backend.rs @@ -36,6 +36,7 @@ impl Backend for Autodiff { type BoolTensorPrimitive = B::BoolTensorPrimitive; type QuantizedTensorPrimitive = B::QuantizedTensorPrimitive; + type QuantizedEncoding = B::QuantizedEncoding; fn ad_enabled() -> bool { true diff --git a/crates/burn-candle/src/backend.rs b/crates/burn-candle/src/backend.rs index fe42039dd..88694fe1a 100644 --- a/crates/burn-candle/src/backend.rs +++ b/crates/burn-candle/src/backend.rs @@ -172,6 +172,7 @@ impl Backend for Candle { type BoolTensorPrimitive = CandleTensor; type QuantizedTensorPrimitive = CandleQTensor; + type QuantizedEncoding = u8; fn ad_enabled() -> bool { false diff --git a/crates/burn-fusion/src/backend.rs b/crates/burn-fusion/src/backend.rs index 77965cb26..89dd28f7b 100644 --- a/crates/burn-fusion/src/backend.rs +++ b/crates/burn-fusion/src/backend.rs @@ -40,6 +40,8 @@ impl Backend for Fusion { type QuantizedTensorPrimitive = QFusionTensor; + type QuantizedEncoding = B::QuantizedEncoding; + fn name() -> String { format!("fusion<{}>", B::name()) } diff --git a/crates/burn-fusion/src/client/base.rs b/crates/burn-fusion/src/client/base.rs index 87358a30f..98c35c51a 100644 --- a/crates/burn-fusion/src/client/base.rs +++ b/crates/burn-fusion/src/client/base.rs @@ -2,10 +2,10 @@ use std::future::Future; use crate::{ stream::{execution::Operation, StreamId}, - FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionTensor, + FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionTensor, QFusionTensor, }; use burn_tensor::{ - repr::{OperationDescription, TensorDescription, TensorId}, + repr::{OperationDescription, QuantizedTensorDescription, TensorDescription, TensorId}, DType, TensorData, }; @@ -56,6 +56,14 @@ where tensor: TensorDescription, stream: StreamId, ) -> impl Future + Send + 
where + B: FusionBackend; + /// Read the values contained by a quantized tensor. + fn read_tensor_quantized( + &self, + tensor: QuantizedTensorDescription, + streams: Vec, + ) -> impl Future + Send where B: FusionBackend; /// Change the client of the given float tensor. @@ -83,6 +91,15 @@ where client: Self, stream: StreamId, ) -> FusionTensor + where + B: FusionBackend; + /// Change the client of the given quantized tensor. + fn change_client_quantized( + &self, + tensor: QuantizedTensorDescription, + client: Self, + streams: Vec, + ) -> QFusionTensor where B: FusionBackend; /// Drop the tensor with the given [tensor id](TensorId). diff --git a/crates/burn-fusion/src/client/mutex.rs b/crates/burn-fusion/src/client/mutex.rs index 7d1921cce..bcc1a87e9 100644 --- a/crates/burn-fusion/src/client/mutex.rs +++ b/crates/burn-fusion/src/client/mutex.rs @@ -1,10 +1,11 @@ use super::FusionClient; use crate::{ stream::{execution::Operation, StreamId}, - FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionServer, FusionTensor, + FusionBackend, FusionDevice, FusionHandle, FusionQuantizationParameters, FusionRuntime, + FusionServer, FusionTensor, QFusionTensor, }; use burn_tensor::{ - repr::{OperationDescription, TensorDescription, TensorId}, + repr::{OperationDescription, QuantizedTensorDescription, TensorDescription, TensorId}, DType, }; use spin::Mutex; @@ -111,6 +112,20 @@ where self.server.lock().read_bool::(tensor, stream).await } + async fn read_tensor_quantized( + &self, + tensor: QuantizedTensorDescription, + streams: Vec, + ) -> burn_tensor::TensorData + where + B: FusionBackend, + { + self.server + .lock() + .read_quantized::(tensor, streams) + .await + } + fn change_client_float( &self, tensor: TensorDescription, @@ -175,6 +190,59 @@ where FusionTensor::new(id, tensor.shape, tensor.dtype, client, StreamId::current()) } + fn change_client_quantized( + &self, + tensor: QuantizedTensorDescription, + client: Self, + streams: Vec, + ) -> QFusionTensor + where + B: FusionBackend, + { + let mut server_other = client.server.lock(); + let mut server_current = self.server.lock(); + for stream in streams { + server_current.drain_stream(stream); + } + + let mut ids = + server_current.change_server_quantized::(&tensor, &client.device, &mut server_other); + + core::mem::drop(server_other); + core::mem::drop(server_current); + + // NOTE: the expected order is known [qtensor, scale, ] + let offset = tensor.qparams.offset.map(|desc| { + FusionTensor::new( + ids.pop().unwrap(), + desc.shape, + desc.dtype, + client.clone(), + StreamId::current(), + ) + }); + let scale = FusionTensor::new( + ids.pop().unwrap(), + tensor.qparams.scale.shape, + tensor.qparams.scale.dtype, + client.clone(), + StreamId::current(), + ); + let qtensor = FusionTensor::new( + ids.pop().unwrap(), + tensor.tensor.shape, + tensor.tensor.dtype, + client, + StreamId::current(), + ); + + QFusionTensor { + qtensor, + scheme: tensor.scheme, + qparams: FusionQuantizationParameters { scale, offset }, + } + } + fn register_orphan(&self, id: &TensorId) { self.server.lock().drop_tensor_handle(*id); } diff --git a/crates/burn-fusion/src/ops/qtensor.rs b/crates/burn-fusion/src/ops/qtensor.rs index e6479f5d3..a6a51ce3a 100644 --- a/crates/burn-fusion/src/ops/qtensor.rs +++ b/crates/burn-fusion/src/ops/qtensor.rs @@ -1,35 +1,214 @@ -use std::ops::Range; +use std::{marker::PhantomData, ops::Range}; use burn_tensor::{ - ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, - quantization::{QuantizationParametersPrimitive, 
QuantizationScheme}, - Device, Shape, TensorData, + ops::{FloatElem, FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, + quantization::{QuantizationParametersPrimitive, QuantizationScheme, QuantizationStrategy}, + repr::{ + DequantizeOperationDescription, FloatOperationDescription, HandleContainer, + OperationDescription, QuantizationParametersDescription, QuantizeOperationDescription, + }, + DType, Device, Element, Shape, TensorData, }; -use crate::{client::FusionClient, Fusion, FusionBackend}; +use crate::{ + client::FusionClient, + get_client, + stream::{execution::Operation, StreamId}, + Fusion, FusionBackend, FusionQuantizationParameters, QFusionTensor, +}; impl QTensorOps for Fusion { - fn q_from_data(_data: TensorData, _device: &Device) -> QuantizedTensor { - unimplemented!() + fn q_from_data(data: TensorData, device: &Device) -> QuantizedTensor { + match data.dtype { + DType::QFloat(strategy) => { + let client = get_client::(device); + let tensor = B::q_from_data(data, device); + let shape = B::q_shape(&tensor); + + let mut handles = B::quantized_tensor_handle(tensor); + let qparams = match strategy { + QuantizationStrategy::PerTensorAffineInt8(_) => { + let num_handles = handles.len(); + assert_eq!( + num_handles, 3, + "Expected 3 handles for quantized tensor, got {num_handles}" + ); + let offset = handles.pop().unwrap(); + let scale = handles.pop().unwrap(); + FusionQuantizationParameters { + scale: client.register_tensor( + scale, + vec![1], + StreamId::current(), + B::FloatElem::dtype(), + ), + offset: Some(client.register_tensor( + offset, + vec![1], + StreamId::current(), + B::IntElem::dtype(), + )), + } + } + QuantizationStrategy::PerTensorSymmetricInt8(_) => { + let num_handles = handles.len(); + assert_eq!( + num_handles, 2, + "Expected 2 handles for quantized tensor, got {num_handles}" + ); + let scale = handles.pop().unwrap(); + FusionQuantizationParameters { + scale: client.register_tensor( + scale, + vec![1], + StreamId::current(), + B::FloatElem::dtype(), + ), + offset: None, + } + } + }; + let qtensor = client.register_tensor( + handles.pop().unwrap(), + shape.dims, + StreamId::current(), + B::QuantizedEncoding::dtype(), + ); + QFusionTensor { + qtensor, + qparams, + scheme: strategy.scheme(), + } + } + _ => panic!( + "Invalid dtype (expected DType::QFloat, got {:?})", + data.dtype + ), + } } fn quantize( - _tensor: FloatTensor, - _scheme: &QuantizationScheme, - _qparams: QuantizationParametersPrimitive, + tensor: FloatTensor, + scheme: &QuantizationScheme, + qparams: QuantizationParametersPrimitive, ) -> QuantizedTensor { - unimplemented!() + #[derive(new)] + struct QuantizeOp { + desc: QuantizeOperationDescription, + _b: PhantomData, + } + + impl Operation for QuantizeOp { + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_float_tensor::(&self.desc.tensor); + let scale = handles.get_float_tensor::(&self.desc.qparams.scale); + let offset = self + .desc + .qparams + .offset + .as_ref() + .map(|x| handles.get_int_tensor::(x)); + + let qparams = QuantizationParametersPrimitive { scale, offset }; + let output = B::quantize(tensor, &self.desc.scheme, qparams); + if let Some(offset) = &self.desc.qparams.offset { + handles.register_quantized_tensor::( + &[&self.desc.out.id, &self.desc.qparams.scale.id, &offset.id], + output, + ); + } else { + handles.register_quantized_tensor::( + &[&self.desc.out.id, &self.desc.qparams.scale.id], + output, + ); + } + } + } + + let shape: Vec = tensor.shape.clone(); + let out = tensor + .client + 
.tensor_uninitialized(shape, B::QuantizedEncoding::dtype()); + + let streams = if let Some(offset) = &qparams.offset { + vec![tensor.stream, qparams.scale.stream, offset.stream] + } else { + vec![tensor.stream, qparams.scale.stream] + }; + + let desc = QuantizeOperationDescription { + tensor: tensor.into_description(), + qparams: QuantizationParametersDescription { + scale: qparams.scale.clone().into_description(), + offset: qparams.offset.clone().map(|x| x.into_description()), + }, + scheme: scheme.clone(), + out: out.to_description_out(), + }; + + out.client.register( + streams, + OperationDescription::Float( + FloatElem::::dtype(), + FloatOperationDescription::Quantize(desc.clone()), + ), + QuantizeOp::::new(desc), + ); + + QFusionTensor { + qtensor: out, + scheme: scheme.clone(), + qparams: qparams.into(), + } } - fn quantize_dynamic( - _tensor: FloatTensor, - _scheme: &QuantizationScheme, - ) -> QuantizedTensor { - unimplemented!() - } + fn dequantize(tensor: QuantizedTensor) -> FloatTensor { + #[derive(new)] + struct DequantizeOp { + desc: DequantizeOperationDescription, + _b: PhantomData, + } - fn dequantize(_tensor: QuantizedTensor) -> FloatTensor { - unimplemented!() + impl Operation for DequantizeOp { + fn execute(self: Box, handles: &mut HandleContainer) { + let tensor = handles.get_quantized_tensor::(&self.desc.qtensor); + + let output = B::dequantize(tensor); + handles.register_float_tensor::(&self.desc.out.id, output); + } + } + + let shape: Vec = tensor.qtensor.shape.clone(); + let out = tensor + .qtensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); + + let streams = if let Some(offset) = &tensor.qparams.offset { + vec![ + tensor.qtensor.stream, + tensor.qparams.scale.stream, + offset.stream, + ] + } else { + vec![tensor.qtensor.stream, tensor.qparams.scale.stream] + }; + + let desc = DequantizeOperationDescription { + qtensor: tensor.into_description(), + out: out.to_description_out(), + }; + + out.client.register( + streams, + OperationDescription::Float( + FloatElem::::dtype(), + FloatOperationDescription::Dequantize(desc.clone()), + ), + DequantizeOp::::new(desc), + ); + + out } fn q_shape(tensor: &QuantizedTensor) -> Shape { @@ -40,19 +219,38 @@ impl QTensorOps for Fusion { tensor.qtensor.client.device().clone() } - fn q_to_device( - _tensor: QuantizedTensor, - _device: &Device, - ) -> QuantizedTensor { - unimplemented!() + fn q_to_device(tensor: QuantizedTensor, device: &Device) -> QuantizedTensor { + // Quantization parameters are on the same device as the qtensor + let device_original: &B::Device = tensor.qtensor.client.device(); + let device_target: B::Device = device.clone(); + + if device_original == &device_target { + return tensor; + } + println!("q_to_device {:?} {:?}", device_original, device_target); + + let client_target = get_client::(&device_target); + let client_original = tensor.qtensor.client.clone(); + + let ids = if let Some(offset) = &tensor.qparams.offset { + vec![ + tensor.qtensor.stream, + tensor.qparams.scale.stream, + offset.stream, + ] + } else { + vec![tensor.qtensor.stream, tensor.qparams.scale.stream] + }; + + client_original.change_client_quantized::(tensor.into_description(), client_target, ids) } fn q_reshape(_tensor: QuantizedTensor, _shape: Shape) -> QuantizedTensor { unimplemented!() } - async fn q_into_data(_tensor: QuantizedTensor) -> TensorData { - unimplemented!() + async fn q_into_data(tensor: QuantizedTensor) -> TensorData { + tensor.into_data::().await } fn q_swap_dims( diff --git 
a/crates/burn-fusion/src/server.rs b/crates/burn-fusion/src/server.rs index 49507e878..58cb7f960 100644 --- a/crates/burn-fusion/src/server.rs +++ b/crates/burn-fusion/src/server.rs @@ -2,7 +2,9 @@ use crate::{ stream::{execution::Operation, MultiStream, StreamId}, FusionBackend, FusionRuntime, }; -use burn_tensor::repr::{HandleContainer, OperationDescription, TensorDescription, TensorId}; +use burn_tensor::repr::{ + HandleContainer, OperationDescription, QuantizedTensorDescription, TensorDescription, TensorId, +}; use std::sync::Arc; pub struct FusionServer { @@ -87,6 +89,24 @@ where B::bool_into_data(tensor).await } + pub async fn read_quantized( + &mut self, + tensor: QuantizedTensorDescription, + ids: Vec, + ) -> burn_tensor::TensorData + where + B: FusionBackend, + { + // Make sure all registered operations are executed. + // The underlying backend can still be async. + for id in ids { + self.drain_stream(id); + } + + let tensor = self.handles.get_quantized_tensor::(&tensor); + B::q_into_data(tensor).await + } + pub fn change_server_float( &mut self, tensor: &TensorDescription, @@ -147,6 +167,39 @@ where id } + pub fn change_server_quantized( + &mut self, + desc: &QuantizedTensorDescription, + device: &R::FusionDevice, + server_device: &mut Self, + ) -> Vec> + where + B: FusionBackend, + { + let tensor = self.handles.get_quantized_tensor::(desc); + let tensor = B::q_to_device(tensor, device); + if desc.qparams.offset.is_some() { + let tensor_id = server_device.create_empty_handle(); + let scale_id = server_device.create_empty_handle(); + let offset_id = server_device.create_empty_handle(); + + server_device + .handles + .register_quantized_tensor::(&[&tensor_id, &scale_id, &offset_id], tensor); + + vec![tensor_id, scale_id, offset_id] + } else { + let tensor_id = server_device.create_empty_handle(); + let scale_id = server_device.create_empty_handle(); + + server_device + .handles + .register_quantized_tensor::(&[&tensor_id, &scale_id], tensor); + + vec![tensor_id, scale_id] + } + } + pub fn drop_tensor_handle(&mut self, id: TensorId) { self.handles.handles_orphan.push(id); } diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index 6eee274bd..571a1254b 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -487,6 +487,39 @@ impl RelativeOpsScalar for FloatOperationDescription { out: desc.out.to_relative(converter), }) } + FloatOperationDescription::Quantize(desc) => { + FloatOperationDescription::Quantize(QuantizeOperationDescription { + tensor: desc.tensor.to_relative(converter), + qparams: QuantizationParametersDescription { + scale: desc.qparams.scale.to_relative(converter), + offset: desc + .qparams + .offset + .as_ref() + .map(|x| x.to_relative(converter)), + }, + scheme: desc.scheme.clone(), + out: desc.out.to_relative(converter), + }) + } + FloatOperationDescription::Dequantize(desc) => { + FloatOperationDescription::Dequantize(DequantizeOperationDescription { + qtensor: QuantizedTensorDescription { + tensor: desc.qtensor.tensor.to_relative(converter), + qparams: QuantizationParametersDescription { + scale: desc.qtensor.qparams.scale.to_relative(converter), + offset: desc + .qtensor + .qparams + .offset + .as_ref() + .map(|x| x.to_relative(converter)), + }, + scheme: desc.qtensor.scheme.clone(), + }, + out: desc.out.to_relative(converter), + }) + } } } } diff --git a/crates/burn-fusion/src/tensor.rs b/crates/burn-fusion/src/tensor.rs index 3239662e1..15edf0da2 100644 --- 
a/crates/burn-fusion/src/tensor.rs +++ b/crates/burn-fusion/src/tensor.rs @@ -1,7 +1,12 @@ -use crate::{client::FusionClient, stream::StreamId, Client, FusionBackend, FusionRuntime}; +use crate::{client::FusionClient, stream::StreamId, Client, Fusion, FusionBackend, FusionRuntime}; use burn_tensor::{ - quantization::{QTensorPrimitive, QuantizationScheme, QuantizationStrategy}, - repr::{TensorDescription, TensorId, TensorStatus}, + quantization::{ + QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme, QuantizationStrategy, + }, + repr::{ + QuantizationParametersDescription, QuantizedTensorDescription, TensorDescription, TensorId, + TensorStatus, + }, DType, Shape, TensorData, }; use std::sync::Arc; @@ -166,6 +171,8 @@ pub struct QFusionTensor { pub qtensor: FusionTensor, /// The quantization scheme. pub scheme: QuantizationScheme, + /// The quantization parameters. + pub qparams: FusionQuantizationParameters, } impl QTensorPrimitive for QFusionTensor { @@ -174,6 +181,7 @@ impl QTensorPrimitive for QFusionTensor { } fn strategy(&self) -> QuantizationStrategy { + // TODO todo!() } } @@ -183,6 +191,72 @@ impl Clone for QFusionTensor { Self { qtensor: self.qtensor.clone(), scheme: self.scheme.clone(), + qparams: self.qparams.clone(), + } + } +} + +impl QFusionTensor { + pub(crate) async fn into_data(self) -> TensorData + where + B: FusionBackend, + { + let streams = if let Some(offset) = &self.qparams.offset { + vec![ + self.qtensor.stream, + self.qparams.scale.stream, + offset.stream, + ] + } else { + vec![self.qtensor.stream, self.qparams.scale.stream] + }; + + // Quantized tensor and qparams tensors client are the same + self.qtensor + .client + .clone() + .read_tensor_quantized::(self.into_description(), streams) + .await + } + + /// Description to be used when using an initialized tensor used as input. + pub(crate) fn into_description(self) -> QuantizedTensorDescription { + QuantizedTensorDescription { + tensor: self.qtensor.into_description(), + qparams: QuantizationParametersDescription { + scale: self.qparams.scale.into_description(), + offset: self.qparams.offset.map(|x| x.into_description()), + }, + scheme: self.scheme, + } + } +} + +/// The quantization parameters. +#[derive(Debug)] +pub struct FusionQuantizationParameters { + /// The scaling factor. + pub scale: FusionTensor, + /// The zero-point offset. 
+ pub offset: Option>, +} + +impl Clone for FusionQuantizationParameters { + fn clone(&self) -> Self { + Self { + scale: self.scale.clone(), + offset: self.offset.clone(), + } + } +} + +impl From>> + for FusionQuantizationParameters +{ + fn from(value: QuantizationParametersPrimitive>) -> Self { + FusionQuantizationParameters { + scale: value.scale, + offset: value.offset, } } } diff --git a/crates/burn-jit/src/backend.rs b/crates/burn-jit/src/backend.rs index 14b7edaab..2e4271d6c 100644 --- a/crates/burn-jit/src/backend.rs +++ b/crates/burn-jit/src/backend.rs @@ -35,6 +35,7 @@ where type IntTensorPrimitive = JitTensor; type BoolTensorPrimitive = JitTensor; type QuantizedTensorPrimitive = QJitTensor; + type QuantizedEncoding = u32; fn name() -> String { format!("jit<{}>", R::name()) diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs index 4fa5c3a68..e0dcde3bc 100644 --- a/crates/burn-jit/src/fusion/base.rs +++ b/crates/burn-jit/src/fusion/base.rs @@ -1,10 +1,12 @@ use super::{ElementWise, ElementWiseState}; -use crate::tensor::is_contiguous; +use crate::tensor::{is_contiguous, JitQuantizationParameters, QJitTensor}; use crate::{ element::JitElement, fusion::ElementWiseBuilder, kernel, tensor::JitTensor, FloatElement, IntElement, JitBackend, JitRuntime, }; use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime}; +use burn_tensor::quantization::QuantizationScheme; +use burn_tensor::repr::TensorHandle; use burn_tensor::{repr::ReprBackend, Shape}; use core::marker::PhantomData; use cubecl::client::ComputeClient; @@ -63,16 +65,55 @@ where impl ReprBackend for JitBackend { type Handle = JitFusionHandle; - fn float_tensor(handle: Self::Handle, shape: Shape) -> burn_tensor::ops::FloatTensor { - handle.into_tensor(shape) + fn float_tensor(handle: TensorHandle) -> burn_tensor::ops::FloatTensor { + handle.handle.into_tensor(handle.shape) } - fn int_tensor(handle: Self::Handle, shape: Shape) -> burn_tensor::ops::IntTensor { - handle.into_tensor(shape) + fn int_tensor(handle: TensorHandle) -> burn_tensor::ops::IntTensor { + handle.handle.into_tensor(handle.shape) } - fn bool_tensor(handle: Self::Handle, shape: Shape) -> burn_tensor::ops::BoolTensor { - handle.into_tensor(shape) + fn bool_tensor(handle: TensorHandle) -> burn_tensor::ops::BoolTensor { + handle.handle.into_tensor(handle.shape) + } + + fn quantized_tensor( + handles: Vec>, + scheme: QuantizationScheme, + ) -> burn_tensor::ops::QuantizedTensor { + match handles.len() { + // NOTE: the order of the handles is known [qtensor, scale, ] + 3 => { + let mut handles = handles; + let offset = handles.pop().unwrap(); + let scale = handles.pop().unwrap(); + let qtensor = handles.pop().unwrap(); + QJitTensor { + qtensor: qtensor.handle.into_tensor(qtensor.shape), + scheme, + qparams: JitQuantizationParameters { + scale: scale.handle.into_tensor(scale.shape), + offset: Some(offset.handle.into_tensor(offset.shape)), + }, + } + } + 2 => { + let mut handles = handles; + let scale = handles.pop().unwrap(); + let qtensor = handles.pop().unwrap(); + QJitTensor { + qtensor: qtensor.handle.into_tensor(qtensor.shape), + scheme, + qparams: JitQuantizationParameters { + scale: scale.handle.into_tensor(scale.shape), + offset: None, + }, + } + } + _ => { + panic!("Expected handles for the quantized tensor and its quantization parameters.") + } + } } fn float_tensor_handle(tensor: burn_tensor::ops::FloatTensor) -> Self::Handle { @@ -86,6 +127,19 @@ impl ReprBackend for JitBackend) -> Self::Handle { tensor.into() } 
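// Illustrative sketch (not part of this patch): the flattened handle order for a
// quantized tensor is always [qtensor, scale] or [qtensor, scale, offset].
// `split_quantized_handles` is a hypothetical helper over any handle type `H`,
// shown only to make the ordering convention relied on by `quantized_tensor` and
// `quantized_tensor_handle` explicit.
fn split_quantized_handles<H>(mut handles: Vec<H>) -> (H, H, Option<H>) {
    match handles.len() {
        3 => {
            // Pop in reverse order: offset last, qtensor first.
            let offset = handles.pop().unwrap();
            let scale = handles.pop().unwrap();
            let qtensor = handles.pop().unwrap();
            (qtensor, scale, Some(offset))
        }
        2 => {
            let scale = handles.pop().unwrap();
            let qtensor = handles.pop().unwrap();
            (qtensor, scale, None)
        }
        n => panic!("Expected 2 or 3 handles for a quantized tensor, got {n}"),
    }
}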
+ + fn quantized_tensor_handle( + tensor: burn_tensor::ops::QuantizedTensor, + ) -> Vec { + let qtensor: JitFusionHandle = tensor.qtensor.into(); + let scale: JitFusionHandle = tensor.qparams.scale.into(); + if let Some(offset) = tensor.qparams.offset { + let offset: JitFusionHandle = offset.into(); + vec![qtensor, scale, offset] + } else { + vec![qtensor, scale] + } + } } impl FusionRuntime for FusionJitRuntime { diff --git a/crates/burn-jit/src/ops/qtensor.rs b/crates/burn-jit/src/ops/qtensor.rs index 3d9c84374..494bf4d59 100644 --- a/crates/burn-jit/src/ops/qtensor.rs +++ b/crates/burn-jit/src/ops/qtensor.rs @@ -103,6 +103,9 @@ where fn q_to_device(tensor: QuantizedTensor, device: &Device) -> QuantizedTensor { let mut tensor = tensor; tensor.qtensor = super::to_device(tensor.qtensor, device); + tensor.qparams.scale = super::to_device(tensor.qparams.scale, device); + tensor.qparams.offset = tensor.qparams.offset.map(|x| super::to_device(x, device)); + tensor } diff --git a/crates/burn-jit/src/tests/mod.rs b/crates/burn-jit/src/tests/mod.rs index c805d9084..8730f7098 100644 --- a/crates/burn-jit/src/tests/mod.rs +++ b/crates/burn-jit/src/tests/mod.rs @@ -133,5 +133,13 @@ macro_rules! testgen_jit_fusion { burn_tensor::testgen_all!(); burn_autodiff::testgen_all!(); + + // Not all ops are implemented for quantization yet, notably missing: + // `q_swap_dims`, `q_permute`, `q_flip`, `q_gather`, `q_select`, `q_slice`, `q_expand` + // burn_tensor::testgen_quantization!(); + // test quantization + burn_tensor::testgen_calibration!(); + burn_tensor::testgen_scheme!(); + burn_tensor::testgen_quantize!(); }; } diff --git a/crates/burn-ndarray/src/backend.rs b/crates/burn-ndarray/src/backend.rs index b166cf71f..671b999ed 100644 --- a/crates/burn-ndarray/src/backend.rs +++ b/crates/burn-ndarray/src/backend.rs @@ -53,6 +53,7 @@ impl Backend for NdArray { type BoolTensorPrimitive = NdArrayTensor; type QuantizedTensorPrimitive = NdArrayQTensor; + type QuantizedEncoding = Q; fn ad_enabled() -> bool { false diff --git a/crates/burn-tch/src/backend.rs b/crates/burn-tch/src/backend.rs index d5ab268e0..2d9864cdd 100644 --- a/crates/burn-tch/src/backend.rs +++ b/crates/burn-tch/src/backend.rs @@ -104,6 +104,7 @@ impl Backend for LibTorch { type BoolTensorPrimitive = TchTensor; type QuantizedTensorPrimitive = TchQTensor; + type QuantizedEncoding = Q; fn seed(seed: u64) { tch::manual_seed(seed as i64); diff --git a/crates/burn-tensor/src/repr/backend.rs b/crates/burn-tensor/src/repr/backend.rs index 61b31f089..9853a2eaf 100644 --- a/crates/burn-tensor/src/repr/backend.rs +++ b/crates/burn-tensor/src/repr/backend.rs @@ -1,8 +1,18 @@ use crate::{ backend::Backend, - ops::{BoolTensor, FloatTensor, IntTensor}, + ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor}, + quantization::QuantizationScheme, Shape, }; +use alloc::vec::Vec; + +/// A tensor representation containing a reference to a tensor resource with a given shape. +pub struct TensorHandle { + /// The type that can be used to point to a tensor of any kind. + pub handle: H, + /// The shape associated to the tensor. + pub shape: Shape, +} /// Backend extension trait that allows an existing [backend](Backend) to use the Burn tensor representation /// for compilation purpose or other... @@ -11,11 +21,16 @@ pub trait ReprBackend: Backend { type Handle: Sync + Send + Clone; /// Convert a [handle](ReprBackend::Handle) to a [float tensor](Backend::FloatTensorPrimitive). 
- fn float_tensor(handle: Self::Handle, shape: Shape) -> FloatTensor; + fn float_tensor(handle: TensorHandle) -> FloatTensor; /// Convert a [handle](ReprBackend::Handle) to an [int tensor](Backend::IntTensorPrimitive). - fn int_tensor(handle: Self::Handle, shape: Shape) -> IntTensor; + fn int_tensor(handle: TensorHandle) -> IntTensor; /// Convert a [handle](ReprBackend::Handle) to a [bool tensor](Backend::BoolTensorPrimitive). - fn bool_tensor(handle: Self::Handle, shape: Shape) -> BoolTensor; + fn bool_tensor(handle: TensorHandle) -> BoolTensor; + /// Convert a [handle](ReprBackend::Handle) to a [quantized tensor](Backend::QuantizedTensorPrimitive). + fn quantized_tensor( + handles: Vec>, + scheme: QuantizationScheme, + ) -> QuantizedTensor; /// Convert a [float tensor](Backend::FloatTensorPrimitive) to a [handle](ReprBackend::Handle). fn float_tensor_handle(tensor: FloatTensor) -> Self::Handle; @@ -23,4 +38,7 @@ pub trait ReprBackend: Backend { fn int_tensor_handle(tensor: IntTensor) -> Self::Handle; /// Convert a [bool tensor](Backend::BoolTensorPrimitive) to a [handle](ReprBackend::Handle). fn bool_tensor_handle(tensor: BoolTensor) -> Self::Handle; + /// Convert a [quantized tensor](Backend::QuantizedTensorPrimitive) to a [handle](ReprBackend::Handle). + /// A quantized tensor has multiple handles for the tensor itself and the quantization parameters. + fn quantized_tensor_handle(tensor: QuantizedTensor) -> Vec; } diff --git a/crates/burn-tensor/src/repr/handle.rs b/crates/burn-tensor/src/repr/handle.rs index 868c0cf5e..88b4b8934 100644 --- a/crates/burn-tensor/src/repr/handle.rs +++ b/crates/burn-tensor/src/repr/handle.rs @@ -7,6 +7,8 @@ use crate::{ }; use std::{collections::HashMap, sync::Arc}; +use super::{QuantizedTensorDescription, TensorHandle}; + /// Keep all [tensor handles](ReprBackend::Handle) in one place and ensure that all resources /// are used optimally. #[derive(Default)] @@ -66,16 +68,21 @@ impl HandleContainer { } } + /// Get the tensor handle for the given [tensor description](TensorDescription). + fn get_tensor_handle(&mut self, tensor: &TensorDescription) -> TensorHandle { + TensorHandle { + handle: self.get_handle(&tensor.id, &tensor.status), + shape: Shape::from(&tensor.shape), + } + } + /// Get the [float tensor](crate::backend::Backend::FloatTensorPrimitive) corresponding to the /// given [tensor description](TensorDescription). pub fn get_float_tensor(&mut self, tensor: &TensorDescription) -> B::FloatTensorPrimitive where B: ReprBackend, { - B::float_tensor( - self.get_handle(&tensor.id, &tensor.status), - Shape::from(&tensor.shape), - ) + B::float_tensor(self.get_tensor_handle(tensor)) } /// Get the [int tensor](crate::backend::Backend::IntTensorPrimitive) corresponding to the @@ -84,10 +91,7 @@ impl HandleContainer { where B: ReprBackend, { - B::int_tensor( - self.get_handle(&tensor.id, &tensor.status), - Shape::from(&tensor.shape), - ) + B::int_tensor(self.get_tensor_handle(tensor)) } /// Get the [bool tensor](crate::backend::Backend::BoolTensorPrimitive) corresponding to the @@ -96,10 +100,26 @@ impl HandleContainer { where B: ReprBackend, { - B::bool_tensor( - self.get_handle(&tensor.id, &tensor.status), - Shape::from(&tensor.shape), - ) + B::bool_tensor(self.get_tensor_handle(tensor)) + } + + /// Get the [quantized tensor](crate::backend::Backend::QuantizedTensorPrimitive) corresponding to the + /// given [tensor description](TensorDescription). 
+ pub fn get_quantized_tensor( + &mut self, + tensor: &QuantizedTensorDescription, + ) -> B::QuantizedTensorPrimitive + where + B: ReprBackend, + { + let qtensor = self.get_tensor_handle(&tensor.tensor); + let scale = self.get_tensor_handle(&tensor.qparams.scale); + let handles = if let Some(offset) = &tensor.qparams.offset { + vec![qtensor, scale, self.get_tensor_handle(offset)] + } else { + vec![qtensor, scale] + }; + B::quantized_tensor(handles, tensor.scheme.clone()) } /// Register a new [float tensor](crate::backend::Backend::FloatTensorPrimitive) with the corresponding [tensor id](TensorId). @@ -111,6 +131,26 @@ impl HandleContainer { self.handles.insert(*id, Handle::Existing(handle)); } + /// Register a new [quantized tensor](crate::backend::Backend::QuantizedTensorPrimitive) with the corresponding [tensor ids](TensorId). + pub fn register_quantized_tensor( + &mut self, + ids: &[&TensorId], + tensor: B::QuantizedTensorPrimitive, + ) where + B: ReprBackend, + { + let handles = B::quantized_tensor_handle(tensor); + assert_eq!( + ids.len(), + handles.len(), + "Number of tensor ids and handles must match" + ); + + for (handle, id) in handles.into_iter().zip(ids) { + self.handles.insert(**id, Handle::Existing(handle)); + } + } + /// Register a new [int tensor](crate::backend::Backend::IntTensorPrimitive) with the corresponding [tensor id](TensorId). pub fn register_int_tensor(&mut self, id: &TensorId, tensor: B::IntTensorPrimitive) where diff --git a/crates/burn-tensor/src/repr/mod.rs b/crates/burn-tensor/src/repr/mod.rs index b98e43ba3..e26565b35 100644 --- a/crates/burn-tensor/src/repr/mod.rs +++ b/crates/burn-tensor/src/repr/mod.rs @@ -1,9 +1,11 @@ mod backend; mod handle; mod operation; +mod quantization; mod tensor; pub use backend::*; pub use handle::*; pub use operation::*; +pub use quantization::*; pub use tensor::*; diff --git a/crates/burn-tensor/src/repr/operation.rs b/crates/burn-tensor/src/repr/operation.rs index 634867875..4e5148bdd 100644 --- a/crates/burn-tensor/src/repr/operation.rs +++ b/crates/burn-tensor/src/repr/operation.rs @@ -5,10 +5,13 @@ use crate::{ ops::{ ConvOptions, ConvTransposeOptions, DeformConvOptions, InterpolateMode, InterpolateOptions, }, + quantization::QuantizationScheme, repr::tensor::TensorDescription, DType, Distribution, Element, }; +use super::{QuantizationParametersDescription, QuantizedTensorDescription}; + /// Describe all tensor operations possible. #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] pub enum OperationDescription { @@ -61,6 +64,10 @@ pub enum FloatOperationDescription { Random(RandomOperationDescription), /// Operation corresponding to [recip](crate::ops::FloatTensorOps::float_recip). Recip(UnaryOperationDescription), + /// Operation corresponding to [quantize](crate::ops::QTensorOps::quantize). + Quantize(QuantizeOperationDescription), + /// Operation corresponding to [dequantize](crate::ops::QTensorOps::dequantize). + Dequantize(DequantizeOperationDescription), } /// Operation description specific to module. 
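// Sketch (not part of the patch): how the new enum variants are wrapped when a fused
// op is registered, mirroring burn-fusion/src/ops/qtensor.rs above. The `wrap_*`
// helpers are hypothetical; the dtype tag is the backend's float element type, as in
// `OperationDescription::Float(FloatElem::<Self>::dtype(), ...)`.
use burn_tensor::repr::{
    DequantizeOperationDescription, FloatOperationDescription, OperationDescription,
    QuantizeOperationDescription,
};
use burn_tensor::DType;

fn wrap_quantize(desc: QuantizeOperationDescription, float_dtype: DType) -> OperationDescription {
    OperationDescription::Float(float_dtype, FloatOperationDescription::Quantize(desc))
}

fn wrap_dequantize(desc: DequantizeOperationDescription, float_dtype: DType) -> OperationDescription {
    OperationDescription::Float(float_dtype, FloatOperationDescription::Dequantize(desc))
}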
@@ -830,6 +837,22 @@ pub struct ConvTranspose3dOptionsDescription { pub groups: usize, } +#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] +#[allow(missing_docs)] +pub struct QuantizeOperationDescription { + pub tensor: TensorDescription, + pub qparams: QuantizationParametersDescription, + pub scheme: QuantizationScheme, + pub out: TensorDescription, +} + +#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] +#[allow(missing_docs)] +pub struct DequantizeOperationDescription { + pub qtensor: QuantizedTensorDescription, + pub out: TensorDescription, +} + impl From> for Conv1dOptionsDescription { fn from(value: ConvOptions<1>) -> Self { Self { @@ -1421,6 +1444,25 @@ impl FloatOperationDescription { FloatOperationDescription::Sin(desc) => vec![&desc.input, &desc.out], FloatOperationDescription::Tanh(desc) => vec![&desc.input, &desc.out], FloatOperationDescription::IntoInt(desc) => vec![&desc.input, &desc.out], + FloatOperationDescription::Quantize(desc) => { + if let Some(offset) = &desc.qparams.offset { + vec![&desc.tensor, &desc.qparams.scale, &offset, &desc.out] + } else { + vec![&desc.tensor, &desc.qparams.scale, &desc.out] + } + } + FloatOperationDescription::Dequantize(desc) => { + if let Some(offset) = &desc.qtensor.qparams.offset { + vec![ + &desc.qtensor.tensor, + &desc.qtensor.qparams.scale, + &offset, + &desc.out, + ] + } else { + vec![&desc.qtensor.tensor, &desc.qtensor.qparams.scale, &desc.out] + } + } } } } diff --git a/crates/burn-tensor/src/repr/quantization.rs b/crates/burn-tensor/src/repr/quantization.rs new file mode 100644 index 000000000..8a38a35a4 --- /dev/null +++ b/crates/burn-tensor/src/repr/quantization.rs @@ -0,0 +1,25 @@ +use serde::{Deserialize, Serialize}; + +use crate::quantization::QuantizationScheme; + +use super::TensorDescription; + +/// A quantized tensor description represents a snapshot of a quantized tensor when it was used. +#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] +pub struct QuantizedTensorDescription { + /// The quantized tensor. + pub tensor: TensorDescription, + /// The quantization parameters. + pub qparams: QuantizationParametersDescription, + /// The quantization scheme + pub scheme: QuantizationScheme, +} + +/// Quantization parameters description. +#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] +pub struct QuantizationParametersDescription { + /// The scaling factor. + pub scale: TensorDescription, + /// The zero-point offset. + pub offset: Option, +} diff --git a/crates/burn-tensor/src/tensor/backend/base.rs b/crates/burn-tensor/src/tensor/backend/base.rs index d7018484e..df78fbf5e 100644 --- a/crates/burn-tensor/src/tensor/backend/base.rs +++ b/crates/burn-tensor/src/tensor/backend/base.rs @@ -88,6 +88,8 @@ pub trait Backend: /// Tensor primitive to be used for all quantized operations. type QuantizedTensorPrimitive: QTensorPrimitive + Clone + Send + 'static + core::fmt::Debug; + /// Quantized tensor encoding type. + type QuantizedEncoding: Element; /// If autodiff is enabled. 
fn ad_enabled() -> bool { diff --git a/crates/burn-tensor/src/tensor/quantization/scheme.rs b/crates/burn-tensor/src/tensor/quantization/scheme.rs index 2526fe1f1..d91b23162 100644 --- a/crates/burn-tensor/src/tensor/quantization/scheme.rs +++ b/crates/burn-tensor/src/tensor/quantization/scheme.rs @@ -1,16 +1,18 @@ +use serde::{Deserialize, Serialize}; + use crate::{backend::Backend, Tensor, TensorPrimitive}; use super::{CalibrationRange, QuantizationParameters, QuantizationParametersPrimitive}; /// Quantization data type. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] pub enum QuantizationType { /// 8-bit signed integer. QInt8, } /// Quantization scheme. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] pub enum QuantizationScheme { /// Per-tensor affine/asymmetric quantization. PerTensorAffine(QuantizationType),
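// Sketch (not part of the patch): with the new derives, a QuantizationScheme can be
// hashed, compared, and serialized as part of an operation description. Assumes
// serde_json is available as a dev-dependency; shown for illustration only.
use burn_tensor::quantization::{QuantizationScheme, QuantizationType};

fn main() {
    let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8);
    let json = serde_json::to_string(&scheme).expect("scheme should serialize");
    let back: QuantizationScheme = serde_json::from_str(&json).expect("and deserialize");
    assert_eq!(scheme, back);
}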