Add fusion quantization ops (#2301)

* Add fusion quantize/dequantize and from/into data

* Add quantization tests for fusion

* Fix jit q_to_device

* Add q_to_device

* Add comment

* Add note on handles/streams order

* Remove unused field

* Fix clippy

* Add QuantizedEncoding associated type
This commit is contained in:
Guillaume Lagrange 2024-09-26 11:22:26 -04:00 committed by GitHub
parent ce2d8e0465
commit c7d2be06ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 705 additions and 59 deletions

View File

@ -36,6 +36,7 @@ impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
type BoolTensorPrimitive = B::BoolTensorPrimitive;
type QuantizedTensorPrimitive = B::QuantizedTensorPrimitive;
type QuantizedEncoding = B::QuantizedEncoding;
fn ad_enabled() -> bool {
true

View File

@ -172,6 +172,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
type BoolTensorPrimitive = CandleTensor<u8>;
type QuantizedTensorPrimitive = CandleQTensor;
type QuantizedEncoding = u8;
fn ad_enabled() -> bool {
false

View File

@ -40,6 +40,8 @@ impl<B: FusionBackend> Backend for Fusion<B> {
type QuantizedTensorPrimitive = QFusionTensor<B::FusionRuntime>;
type QuantizedEncoding = B::QuantizedEncoding;
fn name() -> String {
format!("fusion<{}>", B::name())
}

View File

@ -2,10 +2,10 @@ use std::future::Future;
use crate::{
stream::{execution::Operation, StreamId},
FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionTensor,
FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionTensor, QFusionTensor,
};
use burn_tensor::{
repr::{OperationDescription, TensorDescription, TensorId},
repr::{OperationDescription, QuantizedTensorDescription, TensorDescription, TensorId},
DType, TensorData,
};
@ -56,6 +56,14 @@ where
tensor: TensorDescription,
stream: StreamId,
) -> impl Future<Output = TensorData> + Send
where
B: FusionBackend<FusionRuntime = R>;
/// Read the values contained by a quantized tensor.
///
/// `tensor` describes the quantized values together with their quantization parameters.
/// `streams` lists the streams on which the tensor and its parameters were used; the
/// server-side read drains each of them before the data is fetched.
fn read_tensor_quantized<B>(
&self,
tensor: QuantizedTensorDescription,
streams: Vec<StreamId>,
) -> impl Future<Output = TensorData> + Send
where
B: FusionBackend<FusionRuntime = R>;
/// Change the client of the given float tensor.
@ -83,6 +91,15 @@ where
client: Self,
stream: StreamId,
) -> FusionTensor<R>
where
B: FusionBackend<FusionRuntime = R>;
/// Change the client of the given quantized tensor.
///
/// `tensor` describes the quantized values and their quantization parameters, `client` is
/// the client the tensor should be moved to, and `streams` lists the streams on which the
/// tensor and its parameters were used. Returns the quantized tensor attached to `client`.
fn change_client_quantized<B>(
&self,
tensor: QuantizedTensorDescription,
client: Self,
streams: Vec<StreamId>,
) -> QFusionTensor<R>
where
B: FusionBackend<FusionRuntime = R>;
/// Drop the tensor with the given [tensor id](TensorId).

View File

@ -1,10 +1,11 @@
use super::FusionClient;
use crate::{
stream::{execution::Operation, StreamId},
FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionServer, FusionTensor,
FusionBackend, FusionDevice, FusionHandle, FusionQuantizationParameters, FusionRuntime,
FusionServer, FusionTensor, QFusionTensor,
};
use burn_tensor::{
repr::{OperationDescription, TensorDescription, TensorId},
repr::{OperationDescription, QuantizedTensorDescription, TensorDescription, TensorId},
DType,
};
use spin::Mutex;
@ -111,6 +112,20 @@ where
self.server.lock().read_bool::<B>(tensor, stream).await
}
/// Reads a quantized tensor by forwarding the request to the server.
///
/// The server lock is acquired first and stays held while the read is awaited.
async fn read_tensor_quantized<B>(
    &self,
    tensor: QuantizedTensorDescription,
    streams: Vec<StreamId>,
) -> burn_tensor::TensorData
where
    B: FusionBackend<FusionRuntime = R>,
{
    let mut server = self.server.lock();
    server.read_quantized::<B>(tensor, streams).await
}
fn change_client_float<B>(
&self,
tensor: TensorDescription,
@ -175,6 +190,59 @@ where
FusionTensor::new(id, tensor.shape, tensor.dtype, client, StreamId::current())
}
/// Moves a quantized tensor, values and quantization parameters alike, to `client`.
///
/// Every stream in `streams` is drained on the current server before the handles are
/// transferred to the server owned by `client`. The returned tensor and its parameters are
/// attached to `client` on the current stream.
fn change_client_quantized<B>(
    &self,
    tensor: QuantizedTensorDescription,
    client: Self,
    streams: Vec<StreamId>,
) -> QFusionTensor<R>
where
    B: FusionBackend<FusionRuntime = R>,
{
    let mut target_server = client.server.lock();
    let mut source_server = self.server.lock();

    // Flush the pending operations of every involved stream before moving the handles.
    for stream_id in streams {
        source_server.drain_stream(stream_id);
    }

    let mut ids =
        source_server.change_server_quantized::<B>(&tensor, &client.device, &mut target_server);

    // Release both servers as soon as the transfer is done.
    drop(target_server);
    drop(source_server);

    // NOTE: the expected order is known [qtensor, scale, <offset>], so the ids are
    // consumed from the back: offset (if any) first, then scale, then the qtensor.
    let qparams = FusionQuantizationParameters {
        offset: tensor.qparams.offset.map(|desc| {
            let id = ids.pop().unwrap();
            FusionTensor::new(
                id,
                desc.shape,
                desc.dtype,
                client.clone(),
                StreamId::current(),
            )
        }),
        scale: {
            let id = ids.pop().unwrap();
            let desc = tensor.qparams.scale;
            FusionTensor::new(
                id,
                desc.shape,
                desc.dtype,
                client.clone(),
                StreamId::current(),
            )
        },
    };

    let values = tensor.tensor;
    let qtensor = FusionTensor::new(
        ids.pop().unwrap(),
        values.shape,
        values.dtype,
        client,
        StreamId::current(),
    );

    QFusionTensor {
        qtensor,
        scheme: tensor.scheme,
        qparams,
    }
}
fn register_orphan(&self, id: &TensorId) {
self.server.lock().drop_tensor_handle(*id);
}

View File

@ -1,35 +1,214 @@
use std::ops::Range;
use std::{marker::PhantomData, ops::Range};
use burn_tensor::{
ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
quantization::{QuantizationParametersPrimitive, QuantizationScheme},
Device, Shape, TensorData,
ops::{FloatElem, FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
quantization::{QuantizationParametersPrimitive, QuantizationScheme, QuantizationStrategy},
repr::{
DequantizeOperationDescription, FloatOperationDescription, HandleContainer,
OperationDescription, QuantizationParametersDescription, QuantizeOperationDescription,
},
DType, Device, Element, Shape, TensorData,
};
use crate::{client::FusionClient, Fusion, FusionBackend};
use crate::{
client::FusionClient,
get_client,
stream::{execution::Operation, StreamId},
Fusion, FusionBackend, FusionQuantizationParameters, QFusionTensor,
};
impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
fn q_from_data(_data: TensorData, _device: &Device<Self>) -> QuantizedTensor<Self> {
unimplemented!()
/// Creates a fusion quantized tensor from quantized tensor data.
///
/// The data is first turned into a tensor of the inner backend `B`, whose handles
/// (ordered `[qtensor, scale, <offset>]`) are then registered with the fusion client of
/// `device`.
///
/// # Panics
///
/// Panics if `data.dtype` is not `DType::QFloat`, or if the inner backend does not return
/// the number of handles expected for the quantization strategy.
fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {
    let DType::QFloat(strategy) = data.dtype else {
        panic!(
            "Invalid dtype (expected DType::QFloat, got {:?})",
            data.dtype
        )
    };

    let client = get_client::<B>(device);
    let inner = B::q_from_data(data, device);
    let shape = B::q_shape(&inner);
    let mut handles = B::quantized_tensor_handle(inner);
    let num_handles = handles.len();

    // Only the affine strategy carries an offset handle (last in the list).
    let offset_handle = match strategy {
        QuantizationStrategy::PerTensorAffineInt8(_) => {
            assert_eq!(
                num_handles, 3,
                "Expected 3 handles for quantized tensor, got {num_handles}"
            );
            handles.pop()
        }
        QuantizationStrategy::PerTensorSymmetricInt8(_) => {
            assert_eq!(
                num_handles, 2,
                "Expected 2 handles for quantized tensor, got {num_handles}"
            );
            None
        }
    };
    let scale_handle = handles.pop().unwrap();

    // Registration order is kept as scale, offset, then the quantized values.
    let scale = client.register_tensor(
        scale_handle,
        vec![1],
        StreamId::current(),
        B::FloatElem::dtype(),
    );
    let offset = offset_handle.map(|handle| {
        client.register_tensor(handle, vec![1], StreamId::current(), B::IntElem::dtype())
    });
    let qtensor = client.register_tensor(
        handles.pop().unwrap(),
        shape.dims,
        StreamId::current(),
        B::QuantizedEncoding::dtype(),
    );

    QFusionTensor {
        qtensor,
        qparams: FusionQuantizationParameters { scale, offset },
        scheme: strategy.scheme(),
    }
}
fn quantize(
_tensor: FloatTensor<Self>,
_scheme: &QuantizationScheme,
_qparams: QuantizationParametersPrimitive<Self>,
tensor: FloatTensor<Self>,
scheme: &QuantizationScheme,
qparams: QuantizationParametersPrimitive<Self>,
) -> QuantizedTensor<Self> {
unimplemented!()
#[derive(new)]
struct QuantizeOp<B: FusionBackend> {
desc: QuantizeOperationDescription,
_b: PhantomData<B>,
}
fn quantize_dynamic(
_tensor: FloatTensor<Self>,
_scheme: &QuantizationScheme,
) -> QuantizedTensor<Self> {
unimplemented!()
impl<B: FusionBackend> Operation<B::FusionRuntime> for QuantizeOp<B> {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
let scale = handles.get_float_tensor::<B>(&self.desc.qparams.scale);
let offset = self
.desc
.qparams
.offset
.as_ref()
.map(|x| handles.get_int_tensor::<B>(x));
let qparams = QuantizationParametersPrimitive { scale, offset };
let output = B::quantize(tensor, &self.desc.scheme, qparams);
if let Some(offset) = &self.desc.qparams.offset {
handles.register_quantized_tensor::<B>(
&[&self.desc.out.id, &self.desc.qparams.scale.id, &offset.id],
output,
);
} else {
handles.register_quantized_tensor::<B>(
&[&self.desc.out.id, &self.desc.qparams.scale.id],
output,
);
}
}
}
fn dequantize(_tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
unimplemented!()
let shape: Vec<usize> = tensor.shape.clone();
let out = tensor
.client
.tensor_uninitialized(shape, B::QuantizedEncoding::dtype());
let streams = if let Some(offset) = &qparams.offset {
vec![tensor.stream, qparams.scale.stream, offset.stream]
} else {
vec![tensor.stream, qparams.scale.stream]
};
let desc = QuantizeOperationDescription {
tensor: tensor.into_description(),
qparams: QuantizationParametersDescription {
scale: qparams.scale.clone().into_description(),
offset: qparams.offset.clone().map(|x| x.into_description()),
},
scheme: scheme.clone(),
out: out.to_description_out(),
};
out.client.register(
streams,
OperationDescription::Float(
FloatElem::<Self>::dtype(),
FloatOperationDescription::Quantize(desc.clone()),
),
QuantizeOp::<B>::new(desc),
);
QFusionTensor {
qtensor: out,
scheme: scheme.clone(),
qparams: qparams.into(),
}
}
/// Registers a dequantize operation and returns the (not yet computed) float output.
///
/// The operation is registered on the streams of the quantized values, the scale and, when
/// present, the offset. When executed, it calls `B::dequantize` on the inner backend.
fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
    #[derive(new)]
    struct DequantizeOp<B: FusionBackend> {
        desc: DequantizeOperationDescription,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for DequantizeOp<B> {
        fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
            let quantized = handles.get_quantized_tensor::<B>(&self.desc.qtensor);
            let dequantized = B::dequantize(quantized);
            handles.register_float_tensor::<B>(&self.desc.out.id, dequantized);
        }
    }

    // The output has the same shape as the quantized values, with the float dtype.
    let out = tensor
        .qtensor
        .client
        .tensor_uninitialized(tensor.qtensor.shape.clone(), B::FloatElem::dtype());

    // Stream order: [qtensor, scale, <offset>].
    let mut streams = vec![tensor.qtensor.stream, tensor.qparams.scale.stream];
    if let Some(offset) = &tensor.qparams.offset {
        streams.push(offset.stream);
    }

    let desc = DequantizeOperationDescription {
        qtensor: tensor.into_description(),
        out: out.to_description_out(),
    };
    let op_desc = OperationDescription::Float(
        FloatElem::<Self>::dtype(),
        FloatOperationDescription::Dequantize(desc.clone()),
    );
    out.client
        .register(streams, op_desc, DequantizeOp::<B>::new(desc));

    out
}
fn q_shape(tensor: &QuantizedTensor<Self>) -> Shape {
@ -40,19 +219,38 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
tensor.qtensor.client.device().clone()
}
fn q_to_device(
_tensor: QuantizedTensor<Self>,
_device: &Device<Self>,
) -> QuantizedTensor<Self> {
unimplemented!()
/// Moves a quantized tensor to `device`.
///
/// Returns the tensor unchanged when it already lives on `device`. Otherwise the tensor is
/// handed over to the client of the target device, passing along the streams of the
/// quantized values, the scale and, when present, the offset.
fn q_to_device(tensor: QuantizedTensor<Self>, device: &Device<Self>) -> QuantizedTensor<Self> {
    // Quantization parameters are on the same device as the qtensor
    let device_original: &B::Device = tensor.qtensor.client.device();
    let device_target: B::Device = device.clone();

    if device_original == &device_target {
        return tensor;
    }

    let client_target = get_client::<B>(&device_target);
    let client_original = tensor.qtensor.client.clone();

    // Stream order: [qtensor, scale, <offset>].
    let mut streams = vec![tensor.qtensor.stream, tensor.qparams.scale.stream];
    if let Some(offset) = &tensor.qparams.offset {
        streams.push(offset.stream);
    }

    client_original.change_client_quantized::<B>(
        tensor.into_description(),
        client_target,
        streams,
    )
}
fn q_reshape(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
unimplemented!()
}
async fn q_into_data(_tensor: QuantizedTensor<Self>) -> TensorData {
unimplemented!()
/// Reads the data of a quantized tensor by delegating to the tensor's own `into_data`.
async fn q_into_data(tensor: QuantizedTensor<Self>) -> TensorData {
tensor.into_data::<B>().await
}
fn q_swap_dims(

View File

@ -2,7 +2,9 @@ use crate::{
stream::{execution::Operation, MultiStream, StreamId},
FusionBackend, FusionRuntime,
};
use burn_tensor::repr::{HandleContainer, OperationDescription, TensorDescription, TensorId};
use burn_tensor::repr::{
HandleContainer, OperationDescription, QuantizedTensorDescription, TensorDescription, TensorId,
};
use std::sync::Arc;
pub struct FusionServer<R: FusionRuntime> {
@ -87,6 +89,24 @@ where
B::bool_into_data(tensor).await
}
/// Reads the data of a quantized tensor after draining the given streams.
///
/// `ids` are the stream ids whose pending operations must run before the read.
pub async fn read_quantized<B>(
    &mut self,
    tensor: QuantizedTensorDescription,
    ids: Vec<StreamId>,
) -> burn_tensor::TensorData
where
    B: FusionBackend<FusionRuntime = R>,
{
    // Make sure all registered operations are executed.
    // The underlying backend can still be async.
    for stream_id in ids {
        self.drain_stream(stream_id);
    }

    let primitive = self.handles.get_quantized_tensor::<B>(&tensor);
    B::q_into_data(primitive).await
}
pub fn change_server_float<B>(
&mut self,
tensor: &TensorDescription,
@ -147,6 +167,39 @@ where
id
}
/// Moves a quantized tensor from this server to `server_device`.
///
/// The tensor is fetched from this server's handles, sent to `device` with
/// `B::q_to_device`, and registered on `server_device` under freshly created ids.
///
/// Returns the new ids ordered `[tensor, scale, <offset>]`; the offset id is only present
/// when the description carries an offset.
pub fn change_server_quantized<B>(
    &mut self,
    desc: &QuantizedTensorDescription,
    device: &R::FusionDevice,
    server_device: &mut Self,
) -> Vec<Arc<TensorId>>
where
    B: FusionBackend<FusionRuntime = R>,
{
    let local = self.handles.get_quantized_tensor::<B>(desc);
    let moved = B::q_to_device(local, device);

    // One id for the quantized values, one for the scale, and optionally one for the offset.
    let mut ids = vec![
        server_device.create_empty_handle(),
        server_device.create_empty_handle(),
    ];
    if desc.qparams.offset.is_some() {
        ids.push(server_device.create_empty_handle());
    }

    let id_refs: Vec<&TensorId> = ids.iter().map(|id| &**id).collect();
    server_device
        .handles
        .register_quantized_tensor::<B>(&id_refs, moved);

    ids
}
pub fn drop_tensor_handle(&mut self, id: TensorId) {
self.handles.handles_orphan.push(id);
}

View File

@ -487,6 +487,39 @@ impl RelativeOpsScalar<f32> for FloatOperationDescription {
out: desc.out.to_relative(converter),
})
}
FloatOperationDescription::Quantize(desc) => {
FloatOperationDescription::Quantize(QuantizeOperationDescription {
tensor: desc.tensor.to_relative(converter),
qparams: QuantizationParametersDescription {
scale: desc.qparams.scale.to_relative(converter),
offset: desc
.qparams
.offset
.as_ref()
.map(|x| x.to_relative(converter)),
},
scheme: desc.scheme.clone(),
out: desc.out.to_relative(converter),
})
}
FloatOperationDescription::Dequantize(desc) => {
FloatOperationDescription::Dequantize(DequantizeOperationDescription {
qtensor: QuantizedTensorDescription {
tensor: desc.qtensor.tensor.to_relative(converter),
qparams: QuantizationParametersDescription {
scale: desc.qtensor.qparams.scale.to_relative(converter),
offset: desc
.qtensor
.qparams
.offset
.as_ref()
.map(|x| x.to_relative(converter)),
},
scheme: desc.qtensor.scheme.clone(),
},
out: desc.out.to_relative(converter),
})
}
}
}
}

View File

@ -1,7 +1,12 @@
use crate::{client::FusionClient, stream::StreamId, Client, FusionBackend, FusionRuntime};
use crate::{client::FusionClient, stream::StreamId, Client, Fusion, FusionBackend, FusionRuntime};
use burn_tensor::{
quantization::{QTensorPrimitive, QuantizationScheme, QuantizationStrategy},
repr::{TensorDescription, TensorId, TensorStatus},
quantization::{
QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme, QuantizationStrategy,
},
repr::{
QuantizationParametersDescription, QuantizedTensorDescription, TensorDescription, TensorId,
TensorStatus,
},
DType, Shape, TensorData,
};
use std::sync::Arc;
@ -166,6 +171,8 @@ pub struct QFusionTensor<R: FusionRuntime> {
pub qtensor: FusionTensor<R>,
/// The quantization scheme.
pub scheme: QuantizationScheme,
/// The quantization parameters.
pub qparams: FusionQuantizationParameters<R>,
}
impl<R: FusionRuntime> QTensorPrimitive for QFusionTensor<R> {
@ -174,6 +181,7 @@ impl<R: FusionRuntime> QTensorPrimitive for QFusionTensor<R> {
}
/// Returns the quantization strategy of this tensor.
///
/// Not implemented yet: the body is a `todo!()`, so calling this currently panics.
fn strategy(&self) -> QuantizationStrategy {
// TODO
todo!()
}
}
@ -183,6 +191,72 @@ impl<R: FusionRuntime> Clone for QFusionTensor<R> {
Self {
qtensor: self.qtensor.clone(),
scheme: self.scheme.clone(),
qparams: self.qparams.clone(),
}
}
}
impl<R: FusionRuntime> QFusionTensor<R> {
pub(crate) async fn into_data<B>(self) -> TensorData
where
B: FusionBackend<FusionRuntime = R>,
{
let streams = if let Some(offset) = &self.qparams.offset {
vec![
self.qtensor.stream,
self.qparams.scale.stream,
offset.stream,
]
} else {
vec![self.qtensor.stream, self.qparams.scale.stream]
};
// Quantized tensor and qparams tensors client are the same
self.qtensor
.client
.clone()
.read_tensor_quantized::<B>(self.into_description(), streams)
.await
}
/// Description to be used when using an initialized tensor used as input.
pub(crate) fn into_description(self) -> QuantizedTensorDescription {
QuantizedTensorDescription {
tensor: self.qtensor.into_description(),
qparams: QuantizationParametersDescription {
scale: self.qparams.scale.into_description(),
offset: self.qparams.offset.map(|x| x.into_description()),
},
scheme: self.scheme,
}
}
}
/// The quantization parameters of a fusion quantized tensor, each stored as a fusion tensor.
#[derive(Debug)]
pub struct FusionQuantizationParameters<R: FusionRuntime> {
/// The scaling factor.
pub scale: FusionTensor<R>,
/// The zero-point offset, or `None` when there is no offset.
pub offset: Option<FusionTensor<R>>,
}
impl<R: FusionRuntime> Clone for FusionQuantizationParameters<R> {
    /// Clones the scale and the optional offset tensors.
    fn clone(&self) -> Self {
        let Self { scale, offset } = self;
        Self {
            scale: scale.clone(),
            offset: offset.clone(),
        }
    }
}
impl<B: FusionBackend> From<QuantizationParametersPrimitive<Fusion<B>>>
for FusionQuantizationParameters<B::FusionRuntime>
{
fn from(value: QuantizationParametersPrimitive<Fusion<B>>) -> Self {
FusionQuantizationParameters {
scale: value.scale,
offset: value.offset,
}
}
}

View File

@ -35,6 +35,7 @@ where
type IntTensorPrimitive = JitTensor<R, Self::IntElem>;
type BoolTensorPrimitive = JitTensor<R, u32>;
type QuantizedTensorPrimitive = QJitTensor<R, Self::FloatElem, Self::IntElem>;
type QuantizedEncoding = u32;
fn name() -> String {
format!("jit<{}>", R::name())

View File

@ -1,10 +1,12 @@
use super::{ElementWise, ElementWiseState};
use crate::tensor::is_contiguous;
use crate::tensor::{is_contiguous, JitQuantizationParameters, QJitTensor};
use crate::{
element::JitElement, fusion::ElementWiseBuilder, kernel, tensor::JitTensor, FloatElement,
IntElement, JitBackend, JitRuntime,
};
use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime};
use burn_tensor::quantization::QuantizationScheme;
use burn_tensor::repr::TensorHandle;
use burn_tensor::{repr::ReprBackend, Shape};
use core::marker::PhantomData;
use cubecl::client::ComputeClient;
@ -63,16 +65,55 @@ where
impl<R: JitRuntime, F: FloatElement, I: IntElement> ReprBackend for JitBackend<R, F, I> {
type Handle = JitFusionHandle<R>;
fn float_tensor(handle: Self::Handle, shape: Shape) -> burn_tensor::ops::FloatTensor<Self> {
handle.into_tensor(shape)
fn float_tensor(handle: TensorHandle<Self::Handle>) -> burn_tensor::ops::FloatTensor<Self> {
handle.handle.into_tensor(handle.shape)
}
fn int_tensor(handle: Self::Handle, shape: Shape) -> burn_tensor::ops::IntTensor<Self> {
handle.into_tensor(shape)
fn int_tensor(handle: TensorHandle<Self::Handle>) -> burn_tensor::ops::IntTensor<Self> {
handle.handle.into_tensor(handle.shape)
}
fn bool_tensor(handle: Self::Handle, shape: Shape) -> burn_tensor::ops::BoolTensor<Self> {
handle.into_tensor(shape)
fn bool_tensor(handle: TensorHandle<Self::Handle>) -> burn_tensor::ops::BoolTensor<Self> {
handle.handle.into_tensor(handle.shape)
}
/// Builds a quantized tensor from its handles and quantization scheme.
///
/// # Panics
///
/// Panics unless exactly two (`[qtensor, scale]`) or three (`[qtensor, scale, offset]`)
/// handles are provided.
fn quantized_tensor(
    handles: Vec<TensorHandle<Self::Handle>>,
    scheme: QuantizationScheme,
) -> burn_tensor::ops::QuantizedTensor<Self> {
    let mut handles = handles;

    // NOTE: the order of the handles is known [qtensor, scale, <offset>]
    let offset = match handles.len() {
        3 => handles.pop(),
        2 => None,
        _ => {
            panic!("Expected handles for the quantized tensor and its quantization parameters.")
        }
    };
    let scale = handles.pop().unwrap();
    let qtensor = handles.pop().unwrap();

    QJitTensor {
        qtensor: qtensor.handle.into_tensor(qtensor.shape),
        scheme,
        qparams: JitQuantizationParameters {
            scale: scale.handle.into_tensor(scale.shape),
            offset: offset.map(|offset| offset.handle.into_tensor(offset.shape)),
        },
    }
}
fn float_tensor_handle(tensor: burn_tensor::ops::FloatTensor<Self>) -> Self::Handle {
@ -86,6 +127,19 @@ impl<R: JitRuntime, F: FloatElement, I: IntElement> ReprBackend for JitBackend<R
fn bool_tensor_handle(tensor: burn_tensor::ops::BoolTensor<Self>) -> Self::Handle {
tensor.into()
}
/// Splits a quantized tensor into its handles, ordered `[qtensor, scale, <offset>]`.
///
/// The offset handle is only included when the tensor has an offset.
fn quantized_tensor_handle(
    tensor: burn_tensor::ops::QuantizedTensor<Self>,
) -> Vec<Self::Handle> {
    let mut handles: Vec<JitFusionHandle<R>> =
        vec![tensor.qtensor.into(), tensor.qparams.scale.into()];
    if let Some(offset) = tensor.qparams.offset {
        handles.push(offset.into());
    }
    handles
}
}
impl<R: JitRuntime> FusionRuntime for FusionJitRuntime<R> {

View File

@ -103,6 +103,9 @@ where
/// Moves a quantized tensor to `device`.
///
/// The quantized values, the scale and the optional offset are each moved with
/// `super::to_device`, so every part of the tensor is sent to the same target device.
fn q_to_device(tensor: QuantizedTensor<Self>, device: &Device<Self>) -> QuantizedTensor<Self> {
let mut tensor = tensor;
tensor.qtensor = super::to_device(tensor.qtensor, device);
// The quantization parameters follow the quantized values to the target device.
tensor.qparams.scale = super::to_device(tensor.qparams.scale, device);
tensor.qparams.offset = tensor.qparams.offset.map(|x| super::to_device(x, device));
tensor
}

View File

@ -133,5 +133,13 @@ macro_rules! testgen_jit_fusion {
burn_tensor::testgen_all!();
burn_autodiff::testgen_all!();
// Not all ops are implemented for quantization yet, notably missing:
// `q_swap_dims`, `q_permute`, `q_flip`, `q_gather`, `q_select`, `q_slice`, `q_expand`
// burn_tensor::testgen_quantization!();
// test quantization
burn_tensor::testgen_calibration!();
burn_tensor::testgen_scheme!();
burn_tensor::testgen_quantize!();
};
}

View File

@ -53,6 +53,7 @@ impl<E: FloatNdArrayElement, Q: QuantElement> Backend for NdArray<E, Q> {
type BoolTensorPrimitive = NdArrayTensor<bool>;
type QuantizedTensorPrimitive = NdArrayQTensor<Q>;
type QuantizedEncoding = Q;
fn ad_enabled() -> bool {
false

View File

@ -104,6 +104,7 @@ impl<E: TchElement, Q: QuantElement> Backend for LibTorch<E, Q> {
type BoolTensorPrimitive = TchTensor<bool>;
type QuantizedTensorPrimitive = TchQTensor<Q>;
type QuantizedEncoding = Q;
fn seed(seed: u64) {
tch::manual_seed(seed as i64);

View File

@ -1,8 +1,18 @@
use crate::{
backend::Backend,
ops::{BoolTensor, FloatTensor, IntTensor},
ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
quantization::QuantizationScheme,
Shape,
};
use alloc::vec::Vec;
/// A tensor representation containing a reference to a tensor resource with a given shape.
///
/// Bundles a backend handle with the shape of the tensor it points to.
pub struct TensorHandle<H> {
/// The type that can be used to point to a tensor of any kind.
pub handle: H,
/// The shape associated to the tensor.
pub shape: Shape,
}
/// Backend extension trait that allows an existing [backend](Backend) to use the Burn tensor representation
/// for compilation purpose or other...
@ -11,11 +21,16 @@ pub trait ReprBackend: Backend {
type Handle: Sync + Send + Clone;
/// Convert a [handle](ReprBackend::Handle) to a [float tensor](Backend::FloatTensorPrimitive).
fn float_tensor(handle: Self::Handle, shape: Shape) -> FloatTensor<Self>;
fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self>;
/// Convert a [handle](ReprBackend::Handle) to an [int tensor](Backend::IntTensorPrimitive).
fn int_tensor(handle: Self::Handle, shape: Shape) -> IntTensor<Self>;
fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self>;
/// Convert a [handle](ReprBackend::Handle) to a [bool tensor](Backend::BoolTensorPrimitive).
fn bool_tensor(handle: Self::Handle, shape: Shape) -> BoolTensor<Self>;
fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self>;
/// Convert a [handle](ReprBackend::Handle) to a [quantized tensor](Backend::QuantizedTensorPrimitive).
fn quantized_tensor(
handles: Vec<TensorHandle<Self::Handle>>,
scheme: QuantizationScheme,
) -> QuantizedTensor<Self>;
/// Convert a [float tensor](Backend::FloatTensorPrimitive) to a [handle](ReprBackend::Handle).
fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle;
@ -23,4 +38,7 @@ pub trait ReprBackend: Backend {
fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle;
/// Convert a [bool tensor](Backend::BoolTensorPrimitive) to a [handle](ReprBackend::Handle).
fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle;
/// Convert a [quantized tensor](Backend::QuantizedTensorPrimitive) to a [handle](ReprBackend::Handle).
/// A quantized tensor has multiple handles for the tensor itself and the quantization parameters.
fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Vec<Self::Handle>;
}

View File

@ -7,6 +7,8 @@ use crate::{
};
use std::{collections::HashMap, sync::Arc};
use super::{QuantizedTensorDescription, TensorHandle};
/// Keep all [tensor handles](ReprBackend::Handle) in one place and ensure that all resources
/// are used optimally.
#[derive(Default)]
@ -66,16 +68,21 @@ impl<H: Clone> HandleContainer<H> {
}
}
/// Get the tensor handle for the given [tensor description](TensorDescription).
///
/// The handle is looked up from the description's id and status, and paired with the
/// description's shape.
fn get_tensor_handle(&mut self, tensor: &TensorDescription) -> TensorHandle<H> {
    let handle = self.get_handle(&tensor.id, &tensor.status);
    let shape = Shape::from(&tensor.shape);
    TensorHandle { handle, shape }
}
/// Get the [float tensor](crate::backend::Backend::FloatTensorPrimitive) corresponding to the
/// given [tensor description](TensorDescription).
pub fn get_float_tensor<B>(&mut self, tensor: &TensorDescription) -> B::FloatTensorPrimitive
where
B: ReprBackend<Handle = H>,
{
B::float_tensor(
self.get_handle(&tensor.id, &tensor.status),
Shape::from(&tensor.shape),
)
B::float_tensor(self.get_tensor_handle(tensor))
}
/// Get the [int tensor](crate::backend::Backend::IntTensorPrimitive) corresponding to the
@ -84,10 +91,7 @@ impl<H: Clone> HandleContainer<H> {
where
B: ReprBackend<Handle = H>,
{
B::int_tensor(
self.get_handle(&tensor.id, &tensor.status),
Shape::from(&tensor.shape),
)
B::int_tensor(self.get_tensor_handle(tensor))
}
/// Get the [bool tensor](crate::backend::Backend::BoolTensorPrimitive) corresponding to the
@ -96,10 +100,26 @@ impl<H: Clone> HandleContainer<H> {
where
B: ReprBackend<Handle = H>,
{
B::bool_tensor(
self.get_handle(&tensor.id, &tensor.status),
Shape::from(&tensor.shape),
)
B::bool_tensor(self.get_tensor_handle(tensor))
}
/// Get the [quantized tensor](crate::backend::Backend::QuantizedTensorPrimitive) corresponding to the
/// given [quantized tensor description](QuantizedTensorDescription).
///
/// The handles are collected in the order `[qtensor, scale, <offset>]`; the offset handle
/// is only fetched when the description has an offset.
pub fn get_quantized_tensor<B>(
    &mut self,
    tensor: &QuantizedTensorDescription,
) -> B::QuantizedTensorPrimitive
where
    B: ReprBackend<Handle = H>,
{
    let mut handles = Vec::with_capacity(3);
    handles.push(self.get_tensor_handle(&tensor.tensor));
    handles.push(self.get_tensor_handle(&tensor.qparams.scale));
    if let Some(offset) = &tensor.qparams.offset {
        handles.push(self.get_tensor_handle(offset));
    }

    B::quantized_tensor(handles, tensor.scheme.clone())
}
/// Register a new [float tensor](crate::backend::Backend::FloatTensorPrimitive) with the corresponding [tensor id](TensorId).
@ -111,6 +131,26 @@ impl<H: Clone> HandleContainer<H> {
self.handles.insert(*id, Handle::Existing(handle));
}
/// Register a new [quantized tensor](crate::backend::Backend::QuantizedTensorPrimitive) with the corresponding [tensor ids](TensorId).
///
/// The tensor is split into its handles, and each handle is stored under the id found at
/// the same position in `ids`.
///
/// # Panics
///
/// Panics if the number of ids differs from the number of handles of the tensor.
pub fn register_quantized_tensor<B>(
    &mut self,
    ids: &[&TensorId],
    tensor: B::QuantizedTensorPrimitive,
) where
    B: ReprBackend<Handle = H>,
{
    let handles = B::quantized_tensor_handle(tensor);
    assert_eq!(
        ids.len(),
        handles.len(),
        "Number of tensor ids and handles must match"
    );

    for (&id, handle) in ids.iter().zip(handles) {
        self.handles.insert(*id, Handle::Existing(handle));
    }
}
/// Register a new [int tensor](crate::backend::Backend::IntTensorPrimitive) with the corresponding [tensor id](TensorId).
pub fn register_int_tensor<B>(&mut self, id: &TensorId, tensor: B::IntTensorPrimitive)
where

View File

@ -1,9 +1,11 @@
mod backend;
mod handle;
mod operation;
mod quantization;
mod tensor;
pub use backend::*;
pub use handle::*;
pub use operation::*;
pub use quantization::*;
pub use tensor::*;

View File

@ -5,10 +5,13 @@ use crate::{
ops::{
ConvOptions, ConvTransposeOptions, DeformConvOptions, InterpolateMode, InterpolateOptions,
},
quantization::QuantizationScheme,
repr::tensor::TensorDescription,
DType, Distribution, Element,
};
use super::{QuantizationParametersDescription, QuantizedTensorDescription};
/// Describe all tensor operations possible.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum OperationDescription {
@ -61,6 +64,10 @@ pub enum FloatOperationDescription {
Random(RandomOperationDescription),
/// Operation corresponding to [recip](crate::ops::FloatTensorOps::float_recip).
Recip(UnaryOperationDescription),
/// Operation corresponding to [quantize](crate::ops::QTensorOps::quantize).
Quantize(QuantizeOperationDescription),
/// Operation corresponding to [dequantize](crate::ops::QTensorOps::dequantize).
Dequantize(DequantizeOperationDescription),
}
/// Operation description specific to module.
@ -830,6 +837,22 @@ pub struct ConvTranspose3dOptionsDescription {
pub groups: usize,
}
/// Description of a quantize operation: the float input tensor, the quantization
/// parameters and scheme to apply, and the output tensor.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct QuantizeOperationDescription {
pub tensor: TensorDescription,
pub qparams: QuantizationParametersDescription,
pub scheme: QuantizationScheme,
pub out: TensorDescription,
}
/// Description of a dequantize operation: the quantized input tensor and the output tensor.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DequantizeOperationDescription {
pub qtensor: QuantizedTensorDescription,
pub out: TensorDescription,
}
impl From<ConvOptions<1>> for Conv1dOptionsDescription {
fn from(value: ConvOptions<1>) -> Self {
Self {
@ -1421,6 +1444,25 @@ impl FloatOperationDescription {
FloatOperationDescription::Sin(desc) => vec![&desc.input, &desc.out],
FloatOperationDescription::Tanh(desc) => vec![&desc.input, &desc.out],
FloatOperationDescription::IntoInt(desc) => vec![&desc.input, &desc.out],
FloatOperationDescription::Quantize(desc) => {
if let Some(offset) = &desc.qparams.offset {
vec![&desc.tensor, &desc.qparams.scale, &offset, &desc.out]
} else {
vec![&desc.tensor, &desc.qparams.scale, &desc.out]
}
}
FloatOperationDescription::Dequantize(desc) => {
if let Some(offset) = &desc.qtensor.qparams.offset {
vec![
&desc.qtensor.tensor,
&desc.qtensor.qparams.scale,
&offset,
&desc.out,
]
} else {
vec![&desc.qtensor.tensor, &desc.qtensor.qparams.scale, &desc.out]
}
}
}
}
}

View File

@ -0,0 +1,25 @@
use serde::{Deserialize, Serialize};
use crate::quantization::QuantizationScheme;
use super::TensorDescription;
/// A quantized tensor description represents a snapshot of a quantized tensor when it was used.
///
/// It groups the description of the quantized values with the descriptions of the
/// quantization parameters and the quantization scheme.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct QuantizedTensorDescription {
/// The quantized tensor.
pub tensor: TensorDescription,
/// The quantization parameters.
pub qparams: QuantizationParametersDescription,
/// The quantization scheme.
pub scheme: QuantizationScheme,
}
/// Quantization parameters description.
///
/// Each parameter is described as a tensor of its own.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct QuantizationParametersDescription {
/// The scaling factor.
pub scale: TensorDescription,
/// The zero-point offset, or `None` when there is no offset.
pub offset: Option<TensorDescription>,
}

View File

@ -88,6 +88,8 @@ pub trait Backend:
/// Tensor primitive to be used for all quantized operations.
type QuantizedTensorPrimitive: QTensorPrimitive + Clone + Send + 'static + core::fmt::Debug;
/// Quantized tensor encoding type.
type QuantizedEncoding: Element;
/// If autodiff is enabled.
fn ad_enabled() -> bool {

View File

@ -1,16 +1,18 @@
use serde::{Deserialize, Serialize};
use crate::{backend::Backend, Tensor, TensorPrimitive};
use super::{CalibrationRange, QuantizationParameters, QuantizationParametersPrimitive};
/// Quantization data type.
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationType {
/// 8-bit signed integer.
QInt8,
}
/// Quantization scheme.
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationScheme {
/// Per-tensor affine/asymmetric quantization.
PerTensorAffine(QuantizationType),