mirror of https://github.com/tracel-ai/burn.git
Add fusion quantization ops (#2301)
* Add fusion quantize/dequantize and from/into data * Add quantization tests for fusion * Fix jit q_to_device * Add q_to_device * Add comment * Add note on handles/streams order * Remove unused field * Fix clippy * Add QuantizedEncoding associated type
This commit is contained in:
parent
ce2d8e0465
commit
c7d2be06ed
|
@ -36,6 +36,7 @@ impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
|
|||
type BoolTensorPrimitive = B::BoolTensorPrimitive;
|
||||
|
||||
type QuantizedTensorPrimitive = B::QuantizedTensorPrimitive;
|
||||
type QuantizedEncoding = B::QuantizedEncoding;
|
||||
|
||||
fn ad_enabled() -> bool {
|
||||
true
|
||||
|
|
|
@ -172,6 +172,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
|
|||
type BoolTensorPrimitive = CandleTensor<u8>;
|
||||
|
||||
type QuantizedTensorPrimitive = CandleQTensor;
|
||||
type QuantizedEncoding = u8;
|
||||
|
||||
fn ad_enabled() -> bool {
|
||||
false
|
||||
|
|
|
@ -40,6 +40,8 @@ impl<B: FusionBackend> Backend for Fusion<B> {
|
|||
|
||||
type QuantizedTensorPrimitive = QFusionTensor<B::FusionRuntime>;
|
||||
|
||||
type QuantizedEncoding = B::QuantizedEncoding;
|
||||
|
||||
fn name() -> String {
|
||||
format!("fusion<{}>", B::name())
|
||||
}
|
||||
|
|
|
@ -2,10 +2,10 @@ use std::future::Future;
|
|||
|
||||
use crate::{
|
||||
stream::{execution::Operation, StreamId},
|
||||
FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionTensor,
|
||||
FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionTensor, QFusionTensor,
|
||||
};
|
||||
use burn_tensor::{
|
||||
repr::{OperationDescription, TensorDescription, TensorId},
|
||||
repr::{OperationDescription, QuantizedTensorDescription, TensorDescription, TensorId},
|
||||
DType, TensorData,
|
||||
};
|
||||
|
||||
|
@ -56,6 +56,14 @@ where
|
|||
tensor: TensorDescription,
|
||||
stream: StreamId,
|
||||
) -> impl Future<Output = TensorData> + Send
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>;
|
||||
/// Read the values contained by a quantized tensor.
|
||||
fn read_tensor_quantized<B>(
|
||||
&self,
|
||||
tensor: QuantizedTensorDescription,
|
||||
streams: Vec<StreamId>,
|
||||
) -> impl Future<Output = TensorData> + Send
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>;
|
||||
/// Change the client of the given float tensor.
|
||||
|
@ -83,6 +91,15 @@ where
|
|||
client: Self,
|
||||
stream: StreamId,
|
||||
) -> FusionTensor<R>
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>;
|
||||
/// Change the client of the given quantized tensor.
|
||||
fn change_client_quantized<B>(
|
||||
&self,
|
||||
tensor: QuantizedTensorDescription,
|
||||
client: Self,
|
||||
streams: Vec<StreamId>,
|
||||
) -> QFusionTensor<R>
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>;
|
||||
/// Drop the tensor with the given [tensor id](TensorId).
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
use super::FusionClient;
|
||||
use crate::{
|
||||
stream::{execution::Operation, StreamId},
|
||||
FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionServer, FusionTensor,
|
||||
FusionBackend, FusionDevice, FusionHandle, FusionQuantizationParameters, FusionRuntime,
|
||||
FusionServer, FusionTensor, QFusionTensor,
|
||||
};
|
||||
use burn_tensor::{
|
||||
repr::{OperationDescription, TensorDescription, TensorId},
|
||||
repr::{OperationDescription, QuantizedTensorDescription, TensorDescription, TensorId},
|
||||
DType,
|
||||
};
|
||||
use spin::Mutex;
|
||||
|
@ -111,6 +112,20 @@ where
|
|||
self.server.lock().read_bool::<B>(tensor, stream).await
|
||||
}
|
||||
|
||||
async fn read_tensor_quantized<B>(
|
||||
&self,
|
||||
tensor: QuantizedTensorDescription,
|
||||
streams: Vec<StreamId>,
|
||||
) -> burn_tensor::TensorData
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
self.server
|
||||
.lock()
|
||||
.read_quantized::<B>(tensor, streams)
|
||||
.await
|
||||
}
|
||||
|
||||
fn change_client_float<B>(
|
||||
&self,
|
||||
tensor: TensorDescription,
|
||||
|
@ -175,6 +190,59 @@ where
|
|||
FusionTensor::new(id, tensor.shape, tensor.dtype, client, StreamId::current())
|
||||
}
|
||||
|
||||
fn change_client_quantized<B>(
|
||||
&self,
|
||||
tensor: QuantizedTensorDescription,
|
||||
client: Self,
|
||||
streams: Vec<StreamId>,
|
||||
) -> QFusionTensor<R>
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
let mut server_other = client.server.lock();
|
||||
let mut server_current = self.server.lock();
|
||||
for stream in streams {
|
||||
server_current.drain_stream(stream);
|
||||
}
|
||||
|
||||
let mut ids =
|
||||
server_current.change_server_quantized::<B>(&tensor, &client.device, &mut server_other);
|
||||
|
||||
core::mem::drop(server_other);
|
||||
core::mem::drop(server_current);
|
||||
|
||||
// NOTE: the expected order is known [qtensor, scale, <offset>]
|
||||
let offset = tensor.qparams.offset.map(|desc| {
|
||||
FusionTensor::new(
|
||||
ids.pop().unwrap(),
|
||||
desc.shape,
|
||||
desc.dtype,
|
||||
client.clone(),
|
||||
StreamId::current(),
|
||||
)
|
||||
});
|
||||
let scale = FusionTensor::new(
|
||||
ids.pop().unwrap(),
|
||||
tensor.qparams.scale.shape,
|
||||
tensor.qparams.scale.dtype,
|
||||
client.clone(),
|
||||
StreamId::current(),
|
||||
);
|
||||
let qtensor = FusionTensor::new(
|
||||
ids.pop().unwrap(),
|
||||
tensor.tensor.shape,
|
||||
tensor.tensor.dtype,
|
||||
client,
|
||||
StreamId::current(),
|
||||
);
|
||||
|
||||
QFusionTensor {
|
||||
qtensor,
|
||||
scheme: tensor.scheme,
|
||||
qparams: FusionQuantizationParameters { scale, offset },
|
||||
}
|
||||
}
|
||||
|
||||
fn register_orphan(&self, id: &TensorId) {
|
||||
self.server.lock().drop_tensor_handle(*id);
|
||||
}
|
||||
|
|
|
@ -1,35 +1,214 @@
|
|||
use std::ops::Range;
|
||||
use std::{marker::PhantomData, ops::Range};
|
||||
|
||||
use burn_tensor::{
|
||||
ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
|
||||
quantization::{QuantizationParametersPrimitive, QuantizationScheme},
|
||||
Device, Shape, TensorData,
|
||||
ops::{FloatElem, FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
|
||||
quantization::{QuantizationParametersPrimitive, QuantizationScheme, QuantizationStrategy},
|
||||
repr::{
|
||||
DequantizeOperationDescription, FloatOperationDescription, HandleContainer,
|
||||
OperationDescription, QuantizationParametersDescription, QuantizeOperationDescription,
|
||||
},
|
||||
DType, Device, Element, Shape, TensorData,
|
||||
};
|
||||
|
||||
use crate::{client::FusionClient, Fusion, FusionBackend};
|
||||
use crate::{
|
||||
client::FusionClient,
|
||||
get_client,
|
||||
stream::{execution::Operation, StreamId},
|
||||
Fusion, FusionBackend, FusionQuantizationParameters, QFusionTensor,
|
||||
};
|
||||
|
||||
impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
|
||||
fn q_from_data(_data: TensorData, _device: &Device<Self>) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {
|
||||
match data.dtype {
|
||||
DType::QFloat(strategy) => {
|
||||
let client = get_client::<B>(device);
|
||||
let tensor = B::q_from_data(data, device);
|
||||
let shape = B::q_shape(&tensor);
|
||||
|
||||
let mut handles = B::quantized_tensor_handle(tensor);
|
||||
let qparams = match strategy {
|
||||
QuantizationStrategy::PerTensorAffineInt8(_) => {
|
||||
let num_handles = handles.len();
|
||||
assert_eq!(
|
||||
num_handles, 3,
|
||||
"Expected 3 handles for quantized tensor, got {num_handles}"
|
||||
);
|
||||
let offset = handles.pop().unwrap();
|
||||
let scale = handles.pop().unwrap();
|
||||
FusionQuantizationParameters {
|
||||
scale: client.register_tensor(
|
||||
scale,
|
||||
vec![1],
|
||||
StreamId::current(),
|
||||
B::FloatElem::dtype(),
|
||||
),
|
||||
offset: Some(client.register_tensor(
|
||||
offset,
|
||||
vec![1],
|
||||
StreamId::current(),
|
||||
B::IntElem::dtype(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
QuantizationStrategy::PerTensorSymmetricInt8(_) => {
|
||||
let num_handles = handles.len();
|
||||
assert_eq!(
|
||||
num_handles, 2,
|
||||
"Expected 2 handles for quantized tensor, got {num_handles}"
|
||||
);
|
||||
let scale = handles.pop().unwrap();
|
||||
FusionQuantizationParameters {
|
||||
scale: client.register_tensor(
|
||||
scale,
|
||||
vec![1],
|
||||
StreamId::current(),
|
||||
B::FloatElem::dtype(),
|
||||
),
|
||||
offset: None,
|
||||
}
|
||||
}
|
||||
};
|
||||
let qtensor = client.register_tensor(
|
||||
handles.pop().unwrap(),
|
||||
shape.dims,
|
||||
StreamId::current(),
|
||||
B::QuantizedEncoding::dtype(),
|
||||
);
|
||||
QFusionTensor {
|
||||
qtensor,
|
||||
qparams,
|
||||
scheme: strategy.scheme(),
|
||||
}
|
||||
}
|
||||
_ => panic!(
|
||||
"Invalid dtype (expected DType::QFloat, got {:?})",
|
||||
data.dtype
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn quantize(
|
||||
_tensor: FloatTensor<Self>,
|
||||
_scheme: &QuantizationScheme,
|
||||
_qparams: QuantizationParametersPrimitive<Self>,
|
||||
tensor: FloatTensor<Self>,
|
||||
scheme: &QuantizationScheme,
|
||||
qparams: QuantizationParametersPrimitive<Self>,
|
||||
) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
#[derive(new)]
|
||||
struct QuantizeOp<B: FusionBackend> {
|
||||
desc: QuantizeOperationDescription,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for QuantizeOp<B> {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
|
||||
let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
|
||||
let scale = handles.get_float_tensor::<B>(&self.desc.qparams.scale);
|
||||
let offset = self
|
||||
.desc
|
||||
.qparams
|
||||
.offset
|
||||
.as_ref()
|
||||
.map(|x| handles.get_int_tensor::<B>(x));
|
||||
|
||||
let qparams = QuantizationParametersPrimitive { scale, offset };
|
||||
let output = B::quantize(tensor, &self.desc.scheme, qparams);
|
||||
if let Some(offset) = &self.desc.qparams.offset {
|
||||
handles.register_quantized_tensor::<B>(
|
||||
&[&self.desc.out.id, &self.desc.qparams.scale.id, &offset.id],
|
||||
output,
|
||||
);
|
||||
} else {
|
||||
handles.register_quantized_tensor::<B>(
|
||||
&[&self.desc.out.id, &self.desc.qparams.scale.id],
|
||||
output,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let shape: Vec<usize> = tensor.shape.clone();
|
||||
let out = tensor
|
||||
.client
|
||||
.tensor_uninitialized(shape, B::QuantizedEncoding::dtype());
|
||||
|
||||
let streams = if let Some(offset) = &qparams.offset {
|
||||
vec![tensor.stream, qparams.scale.stream, offset.stream]
|
||||
} else {
|
||||
vec![tensor.stream, qparams.scale.stream]
|
||||
};
|
||||
|
||||
let desc = QuantizeOperationDescription {
|
||||
tensor: tensor.into_description(),
|
||||
qparams: QuantizationParametersDescription {
|
||||
scale: qparams.scale.clone().into_description(),
|
||||
offset: qparams.offset.clone().map(|x| x.into_description()),
|
||||
},
|
||||
scheme: scheme.clone(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
out.client.register(
|
||||
streams,
|
||||
OperationDescription::Float(
|
||||
FloatElem::<Self>::dtype(),
|
||||
FloatOperationDescription::Quantize(desc.clone()),
|
||||
),
|
||||
QuantizeOp::<B>::new(desc),
|
||||
);
|
||||
|
||||
QFusionTensor {
|
||||
qtensor: out,
|
||||
scheme: scheme.clone(),
|
||||
qparams: qparams.into(),
|
||||
}
|
||||
}
|
||||
|
||||
fn quantize_dynamic(
|
||||
_tensor: FloatTensor<Self>,
|
||||
_scheme: &QuantizationScheme,
|
||||
) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
|
||||
#[derive(new)]
|
||||
struct DequantizeOp<B: FusionBackend> {
|
||||
desc: DequantizeOperationDescription,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
fn dequantize(_tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
|
||||
unimplemented!()
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for DequantizeOp<B> {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
|
||||
let tensor = handles.get_quantized_tensor::<B>(&self.desc.qtensor);
|
||||
|
||||
let output = B::dequantize(tensor);
|
||||
handles.register_float_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let shape: Vec<usize> = tensor.qtensor.shape.clone();
|
||||
let out = tensor
|
||||
.qtensor
|
||||
.client
|
||||
.tensor_uninitialized(shape, B::FloatElem::dtype());
|
||||
|
||||
let streams = if let Some(offset) = &tensor.qparams.offset {
|
||||
vec![
|
||||
tensor.qtensor.stream,
|
||||
tensor.qparams.scale.stream,
|
||||
offset.stream,
|
||||
]
|
||||
} else {
|
||||
vec![tensor.qtensor.stream, tensor.qparams.scale.stream]
|
||||
};
|
||||
|
||||
let desc = DequantizeOperationDescription {
|
||||
qtensor: tensor.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
out.client.register(
|
||||
streams,
|
||||
OperationDescription::Float(
|
||||
FloatElem::<Self>::dtype(),
|
||||
FloatOperationDescription::Dequantize(desc.clone()),
|
||||
),
|
||||
DequantizeOp::<B>::new(desc),
|
||||
);
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn q_shape(tensor: &QuantizedTensor<Self>) -> Shape {
|
||||
|
@ -40,19 +219,38 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
|
|||
tensor.qtensor.client.device().clone()
|
||||
}
|
||||
|
||||
fn q_to_device(
|
||||
_tensor: QuantizedTensor<Self>,
|
||||
_device: &Device<Self>,
|
||||
) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
fn q_to_device(tensor: QuantizedTensor<Self>, device: &Device<Self>) -> QuantizedTensor<Self> {
|
||||
// Quantization parameters are on the same device as the qtensor
|
||||
let device_original: &B::Device = tensor.qtensor.client.device();
|
||||
let device_target: B::Device = device.clone();
|
||||
|
||||
if device_original == &device_target {
|
||||
return tensor;
|
||||
}
|
||||
println!("q_to_device {:?} {:?}", device_original, device_target);
|
||||
|
||||
let client_target = get_client::<B>(&device_target);
|
||||
let client_original = tensor.qtensor.client.clone();
|
||||
|
||||
let ids = if let Some(offset) = &tensor.qparams.offset {
|
||||
vec![
|
||||
tensor.qtensor.stream,
|
||||
tensor.qparams.scale.stream,
|
||||
offset.stream,
|
||||
]
|
||||
} else {
|
||||
vec![tensor.qtensor.stream, tensor.qparams.scale.stream]
|
||||
};
|
||||
|
||||
client_original.change_client_quantized::<B>(tensor.into_description(), client_target, ids)
|
||||
}
|
||||
|
||||
fn q_reshape(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
async fn q_into_data(_tensor: QuantizedTensor<Self>) -> TensorData {
|
||||
unimplemented!()
|
||||
async fn q_into_data(tensor: QuantizedTensor<Self>) -> TensorData {
|
||||
tensor.into_data::<B>().await
|
||||
}
|
||||
|
||||
fn q_swap_dims(
|
||||
|
|
|
@ -2,7 +2,9 @@ use crate::{
|
|||
stream::{execution::Operation, MultiStream, StreamId},
|
||||
FusionBackend, FusionRuntime,
|
||||
};
|
||||
use burn_tensor::repr::{HandleContainer, OperationDescription, TensorDescription, TensorId};
|
||||
use burn_tensor::repr::{
|
||||
HandleContainer, OperationDescription, QuantizedTensorDescription, TensorDescription, TensorId,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct FusionServer<R: FusionRuntime> {
|
||||
|
@ -87,6 +89,24 @@ where
|
|||
B::bool_into_data(tensor).await
|
||||
}
|
||||
|
||||
pub async fn read_quantized<B>(
|
||||
&mut self,
|
||||
tensor: QuantizedTensorDescription,
|
||||
ids: Vec<StreamId>,
|
||||
) -> burn_tensor::TensorData
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
// Make sure all registered operations are executed.
|
||||
// The underlying backend can still be async.
|
||||
for id in ids {
|
||||
self.drain_stream(id);
|
||||
}
|
||||
|
||||
let tensor = self.handles.get_quantized_tensor::<B>(&tensor);
|
||||
B::q_into_data(tensor).await
|
||||
}
|
||||
|
||||
pub fn change_server_float<B>(
|
||||
&mut self,
|
||||
tensor: &TensorDescription,
|
||||
|
@ -147,6 +167,39 @@ where
|
|||
id
|
||||
}
|
||||
|
||||
pub fn change_server_quantized<B>(
|
||||
&mut self,
|
||||
desc: &QuantizedTensorDescription,
|
||||
device: &R::FusionDevice,
|
||||
server_device: &mut Self,
|
||||
) -> Vec<Arc<TensorId>>
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
let tensor = self.handles.get_quantized_tensor::<B>(desc);
|
||||
let tensor = B::q_to_device(tensor, device);
|
||||
if desc.qparams.offset.is_some() {
|
||||
let tensor_id = server_device.create_empty_handle();
|
||||
let scale_id = server_device.create_empty_handle();
|
||||
let offset_id = server_device.create_empty_handle();
|
||||
|
||||
server_device
|
||||
.handles
|
||||
.register_quantized_tensor::<B>(&[&tensor_id, &scale_id, &offset_id], tensor);
|
||||
|
||||
vec![tensor_id, scale_id, offset_id]
|
||||
} else {
|
||||
let tensor_id = server_device.create_empty_handle();
|
||||
let scale_id = server_device.create_empty_handle();
|
||||
|
||||
server_device
|
||||
.handles
|
||||
.register_quantized_tensor::<B>(&[&tensor_id, &scale_id], tensor);
|
||||
|
||||
vec![tensor_id, scale_id]
|
||||
}
|
||||
}
|
||||
|
||||
pub fn drop_tensor_handle(&mut self, id: TensorId) {
|
||||
self.handles.handles_orphan.push(id);
|
||||
}
|
||||
|
|
|
@ -487,6 +487,39 @@ impl RelativeOpsScalar<f32> for FloatOperationDescription {
|
|||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
FloatOperationDescription::Quantize(desc) => {
|
||||
FloatOperationDescription::Quantize(QuantizeOperationDescription {
|
||||
tensor: desc.tensor.to_relative(converter),
|
||||
qparams: QuantizationParametersDescription {
|
||||
scale: desc.qparams.scale.to_relative(converter),
|
||||
offset: desc
|
||||
.qparams
|
||||
.offset
|
||||
.as_ref()
|
||||
.map(|x| x.to_relative(converter)),
|
||||
},
|
||||
scheme: desc.scheme.clone(),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
FloatOperationDescription::Dequantize(desc) => {
|
||||
FloatOperationDescription::Dequantize(DequantizeOperationDescription {
|
||||
qtensor: QuantizedTensorDescription {
|
||||
tensor: desc.qtensor.tensor.to_relative(converter),
|
||||
qparams: QuantizationParametersDescription {
|
||||
scale: desc.qtensor.qparams.scale.to_relative(converter),
|
||||
offset: desc
|
||||
.qtensor
|
||||
.qparams
|
||||
.offset
|
||||
.as_ref()
|
||||
.map(|x| x.to_relative(converter)),
|
||||
},
|
||||
scheme: desc.qtensor.scheme.clone(),
|
||||
},
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,12 @@
|
|||
use crate::{client::FusionClient, stream::StreamId, Client, FusionBackend, FusionRuntime};
|
||||
use crate::{client::FusionClient, stream::StreamId, Client, Fusion, FusionBackend, FusionRuntime};
|
||||
use burn_tensor::{
|
||||
quantization::{QTensorPrimitive, QuantizationScheme, QuantizationStrategy},
|
||||
repr::{TensorDescription, TensorId, TensorStatus},
|
||||
quantization::{
|
||||
QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme, QuantizationStrategy,
|
||||
},
|
||||
repr::{
|
||||
QuantizationParametersDescription, QuantizedTensorDescription, TensorDescription, TensorId,
|
||||
TensorStatus,
|
||||
},
|
||||
DType, Shape, TensorData,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
@ -166,6 +171,8 @@ pub struct QFusionTensor<R: FusionRuntime> {
|
|||
pub qtensor: FusionTensor<R>,
|
||||
/// The quantization scheme.
|
||||
pub scheme: QuantizationScheme,
|
||||
/// The quantization parameters.
|
||||
pub qparams: FusionQuantizationParameters<R>,
|
||||
}
|
||||
|
||||
impl<R: FusionRuntime> QTensorPrimitive for QFusionTensor<R> {
|
||||
|
@ -174,6 +181,7 @@ impl<R: FusionRuntime> QTensorPrimitive for QFusionTensor<R> {
|
|||
}
|
||||
|
||||
fn strategy(&self) -> QuantizationStrategy {
|
||||
// TODO
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
@ -183,6 +191,72 @@ impl<R: FusionRuntime> Clone for QFusionTensor<R> {
|
|||
Self {
|
||||
qtensor: self.qtensor.clone(),
|
||||
scheme: self.scheme.clone(),
|
||||
qparams: self.qparams.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: FusionRuntime> QFusionTensor<R> {
|
||||
pub(crate) async fn into_data<B>(self) -> TensorData
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
let streams = if let Some(offset) = &self.qparams.offset {
|
||||
vec![
|
||||
self.qtensor.stream,
|
||||
self.qparams.scale.stream,
|
||||
offset.stream,
|
||||
]
|
||||
} else {
|
||||
vec![self.qtensor.stream, self.qparams.scale.stream]
|
||||
};
|
||||
|
||||
// Quantized tensor and qparams tensors client are the same
|
||||
self.qtensor
|
||||
.client
|
||||
.clone()
|
||||
.read_tensor_quantized::<B>(self.into_description(), streams)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Description to be used when using an initialized tensor used as input.
|
||||
pub(crate) fn into_description(self) -> QuantizedTensorDescription {
|
||||
QuantizedTensorDescription {
|
||||
tensor: self.qtensor.into_description(),
|
||||
qparams: QuantizationParametersDescription {
|
||||
scale: self.qparams.scale.into_description(),
|
||||
offset: self.qparams.offset.map(|x| x.into_description()),
|
||||
},
|
||||
scheme: self.scheme,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The quantization parameters.
|
||||
#[derive(Debug)]
|
||||
pub struct FusionQuantizationParameters<R: FusionRuntime> {
|
||||
/// The scaling factor.
|
||||
pub scale: FusionTensor<R>,
|
||||
/// The zero-point offset.
|
||||
pub offset: Option<FusionTensor<R>>,
|
||||
}
|
||||
|
||||
impl<R: FusionRuntime> Clone for FusionQuantizationParameters<R> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
scale: self.scale.clone(),
|
||||
offset: self.offset.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> From<QuantizationParametersPrimitive<Fusion<B>>>
|
||||
for FusionQuantizationParameters<B::FusionRuntime>
|
||||
{
|
||||
fn from(value: QuantizationParametersPrimitive<Fusion<B>>) -> Self {
|
||||
FusionQuantizationParameters {
|
||||
scale: value.scale,
|
||||
offset: value.offset,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,6 +35,7 @@ where
|
|||
type IntTensorPrimitive = JitTensor<R, Self::IntElem>;
|
||||
type BoolTensorPrimitive = JitTensor<R, u32>;
|
||||
type QuantizedTensorPrimitive = QJitTensor<R, Self::FloatElem, Self::IntElem>;
|
||||
type QuantizedEncoding = u32;
|
||||
|
||||
fn name() -> String {
|
||||
format!("jit<{}>", R::name())
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
use super::{ElementWise, ElementWiseState};
|
||||
use crate::tensor::is_contiguous;
|
||||
use crate::tensor::{is_contiguous, JitQuantizationParameters, QJitTensor};
|
||||
use crate::{
|
||||
element::JitElement, fusion::ElementWiseBuilder, kernel, tensor::JitTensor, FloatElement,
|
||||
IntElement, JitBackend, JitRuntime,
|
||||
};
|
||||
use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime};
|
||||
use burn_tensor::quantization::QuantizationScheme;
|
||||
use burn_tensor::repr::TensorHandle;
|
||||
use burn_tensor::{repr::ReprBackend, Shape};
|
||||
use core::marker::PhantomData;
|
||||
use cubecl::client::ComputeClient;
|
||||
|
@ -63,16 +65,55 @@ where
|
|||
impl<R: JitRuntime, F: FloatElement, I: IntElement> ReprBackend for JitBackend<R, F, I> {
|
||||
type Handle = JitFusionHandle<R>;
|
||||
|
||||
fn float_tensor(handle: Self::Handle, shape: Shape) -> burn_tensor::ops::FloatTensor<Self> {
|
||||
handle.into_tensor(shape)
|
||||
fn float_tensor(handle: TensorHandle<Self::Handle>) -> burn_tensor::ops::FloatTensor<Self> {
|
||||
handle.handle.into_tensor(handle.shape)
|
||||
}
|
||||
|
||||
fn int_tensor(handle: Self::Handle, shape: Shape) -> burn_tensor::ops::IntTensor<Self> {
|
||||
handle.into_tensor(shape)
|
||||
fn int_tensor(handle: TensorHandle<Self::Handle>) -> burn_tensor::ops::IntTensor<Self> {
|
||||
handle.handle.into_tensor(handle.shape)
|
||||
}
|
||||
|
||||
fn bool_tensor(handle: Self::Handle, shape: Shape) -> burn_tensor::ops::BoolTensor<Self> {
|
||||
handle.into_tensor(shape)
|
||||
fn bool_tensor(handle: TensorHandle<Self::Handle>) -> burn_tensor::ops::BoolTensor<Self> {
|
||||
handle.handle.into_tensor(handle.shape)
|
||||
}
|
||||
|
||||
fn quantized_tensor(
|
||||
handles: Vec<TensorHandle<Self::Handle>>,
|
||||
scheme: QuantizationScheme,
|
||||
) -> burn_tensor::ops::QuantizedTensor<Self> {
|
||||
match handles.len() {
|
||||
// NOTE: the order of the handles is known [qtensor, scale, <offset>]
|
||||
3 => {
|
||||
let mut handles = handles;
|
||||
let offset = handles.pop().unwrap();
|
||||
let scale = handles.pop().unwrap();
|
||||
let qtensor = handles.pop().unwrap();
|
||||
QJitTensor {
|
||||
qtensor: qtensor.handle.into_tensor(qtensor.shape),
|
||||
scheme,
|
||||
qparams: JitQuantizationParameters {
|
||||
scale: scale.handle.into_tensor(scale.shape),
|
||||
offset: Some(offset.handle.into_tensor(offset.shape)),
|
||||
},
|
||||
}
|
||||
}
|
||||
2 => {
|
||||
let mut handles = handles;
|
||||
let scale = handles.pop().unwrap();
|
||||
let qtensor = handles.pop().unwrap();
|
||||
QJitTensor {
|
||||
qtensor: qtensor.handle.into_tensor(qtensor.shape),
|
||||
scheme,
|
||||
qparams: JitQuantizationParameters {
|
||||
scale: scale.handle.into_tensor(scale.shape),
|
||||
offset: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
panic!("Expected handles for the quantized tensor and its quantization parameters.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn float_tensor_handle(tensor: burn_tensor::ops::FloatTensor<Self>) -> Self::Handle {
|
||||
|
@ -86,6 +127,19 @@ impl<R: JitRuntime, F: FloatElement, I: IntElement> ReprBackend for JitBackend<R
|
|||
fn bool_tensor_handle(tensor: burn_tensor::ops::BoolTensor<Self>) -> Self::Handle {
|
||||
tensor.into()
|
||||
}
|
||||
|
||||
fn quantized_tensor_handle(
|
||||
tensor: burn_tensor::ops::QuantizedTensor<Self>,
|
||||
) -> Vec<Self::Handle> {
|
||||
let qtensor: JitFusionHandle<R> = tensor.qtensor.into();
|
||||
let scale: JitFusionHandle<R> = tensor.qparams.scale.into();
|
||||
if let Some(offset) = tensor.qparams.offset {
|
||||
let offset: JitFusionHandle<R> = offset.into();
|
||||
vec![qtensor, scale, offset]
|
||||
} else {
|
||||
vec![qtensor, scale]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: JitRuntime> FusionRuntime for FusionJitRuntime<R> {
|
||||
|
|
|
@ -103,6 +103,9 @@ where
|
|||
fn q_to_device(tensor: QuantizedTensor<Self>, device: &Device<Self>) -> QuantizedTensor<Self> {
|
||||
let mut tensor = tensor;
|
||||
tensor.qtensor = super::to_device(tensor.qtensor, device);
|
||||
tensor.qparams.scale = super::to_device(tensor.qparams.scale, device);
|
||||
tensor.qparams.offset = tensor.qparams.offset.map(|x| super::to_device(x, device));
|
||||
|
||||
tensor
|
||||
}
|
||||
|
||||
|
|
|
@ -133,5 +133,13 @@ macro_rules! testgen_jit_fusion {
|
|||
|
||||
burn_tensor::testgen_all!();
|
||||
burn_autodiff::testgen_all!();
|
||||
|
||||
// Not all ops are implemented for quantization yet, notably missing:
|
||||
// `q_swap_dims`, `q_permute`, `q_flip`, `q_gather`, `q_select`, `q_slice`, `q_expand`
|
||||
// burn_tensor::testgen_quantization!();
|
||||
// test quantization
|
||||
burn_tensor::testgen_calibration!();
|
||||
burn_tensor::testgen_scheme!();
|
||||
burn_tensor::testgen_quantize!();
|
||||
};
|
||||
}
|
||||
|
|
|
@ -53,6 +53,7 @@ impl<E: FloatNdArrayElement, Q: QuantElement> Backend for NdArray<E, Q> {
|
|||
type BoolTensorPrimitive = NdArrayTensor<bool>;
|
||||
|
||||
type QuantizedTensorPrimitive = NdArrayQTensor<Q>;
|
||||
type QuantizedEncoding = Q;
|
||||
|
||||
fn ad_enabled() -> bool {
|
||||
false
|
||||
|
|
|
@ -104,6 +104,7 @@ impl<E: TchElement, Q: QuantElement> Backend for LibTorch<E, Q> {
|
|||
type BoolTensorPrimitive = TchTensor<bool>;
|
||||
|
||||
type QuantizedTensorPrimitive = TchQTensor<Q>;
|
||||
type QuantizedEncoding = Q;
|
||||
|
||||
fn seed(seed: u64) {
|
||||
tch::manual_seed(seed as i64);
|
||||
|
|
|
@ -1,8 +1,18 @@
|
|||
use crate::{
|
||||
backend::Backend,
|
||||
ops::{BoolTensor, FloatTensor, IntTensor},
|
||||
ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
|
||||
quantization::QuantizationScheme,
|
||||
Shape,
|
||||
};
|
||||
use alloc::vec::Vec;
|
||||
|
||||
/// A tensor representation containing a reference to a tensor resource with a given shape.
|
||||
pub struct TensorHandle<H> {
|
||||
/// The type that can be used to point to a tensor of any kind.
|
||||
pub handle: H,
|
||||
/// The shape associated to the tensor.
|
||||
pub shape: Shape,
|
||||
}
|
||||
|
||||
/// Backend extension trait that allows an existing [backend](Backend) to use the Burn tensor representation
|
||||
/// for compilation purpose or other...
|
||||
|
@ -11,11 +21,16 @@ pub trait ReprBackend: Backend {
|
|||
type Handle: Sync + Send + Clone;
|
||||
|
||||
/// Convert a [handle](ReprBackend::Handle) to a [float tensor](Backend::FloatTensorPrimitive).
|
||||
fn float_tensor(handle: Self::Handle, shape: Shape) -> FloatTensor<Self>;
|
||||
fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self>;
|
||||
/// Convert a [handle](ReprBackend::Handle) to an [int tensor](Backend::IntTensorPrimitive).
|
||||
fn int_tensor(handle: Self::Handle, shape: Shape) -> IntTensor<Self>;
|
||||
fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self>;
|
||||
/// Convert a [handle](ReprBackend::Handle) to a [bool tensor](Backend::BoolTensorPrimitive).
|
||||
fn bool_tensor(handle: Self::Handle, shape: Shape) -> BoolTensor<Self>;
|
||||
fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self>;
|
||||
/// Convert a [handle](ReprBackend::Handle) to a [quantized tensor](Backend::QuantizedTensorPrimitive).
|
||||
fn quantized_tensor(
|
||||
handles: Vec<TensorHandle<Self::Handle>>,
|
||||
scheme: QuantizationScheme,
|
||||
) -> QuantizedTensor<Self>;
|
||||
|
||||
/// Convert a [float tensor](Backend::FloatTensorPrimitive) to a [handle](ReprBackend::Handle).
|
||||
fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle;
|
||||
|
@ -23,4 +38,7 @@ pub trait ReprBackend: Backend {
|
|||
fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle;
|
||||
/// Convert a [bool tensor](Backend::BoolTensorPrimitive) to a [handle](ReprBackend::Handle).
|
||||
fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle;
|
||||
/// Convert a [quantized tensor](Backend::QuantizedTensorPrimitive) to a [handle](ReprBackend::Handle).
|
||||
/// A quantized tensor has multiple handles for the tensor itself and the quantization parameters.
|
||||
fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Vec<Self::Handle>;
|
||||
}
|
||||
|
|
|
@ -7,6 +7,8 @@ use crate::{
|
|||
};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use super::{QuantizedTensorDescription, TensorHandle};
|
||||
|
||||
/// Keep all [tensor handles](ReprBackend::Handle) in one place and ensure that all resources
|
||||
/// are used optimally.
|
||||
#[derive(Default)]
|
||||
|
@ -66,16 +68,21 @@ impl<H: Clone> HandleContainer<H> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Get the tensor handle for the given [tensor description](TensorDescription).
|
||||
fn get_tensor_handle(&mut self, tensor: &TensorDescription) -> TensorHandle<H> {
|
||||
TensorHandle {
|
||||
handle: self.get_handle(&tensor.id, &tensor.status),
|
||||
shape: Shape::from(&tensor.shape),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the [float tensor](crate::backend::Backend::FloatTensorPrimitive) corresponding to the
|
||||
/// given [tensor description](TensorDescription).
|
||||
pub fn get_float_tensor<B>(&mut self, tensor: &TensorDescription) -> B::FloatTensorPrimitive
|
||||
where
|
||||
B: ReprBackend<Handle = H>,
|
||||
{
|
||||
B::float_tensor(
|
||||
self.get_handle(&tensor.id, &tensor.status),
|
||||
Shape::from(&tensor.shape),
|
||||
)
|
||||
B::float_tensor(self.get_tensor_handle(tensor))
|
||||
}
|
||||
|
||||
/// Get the [int tensor](crate::backend::Backend::IntTensorPrimitive) corresponding to the
|
||||
|
@ -84,10 +91,7 @@ impl<H: Clone> HandleContainer<H> {
|
|||
where
|
||||
B: ReprBackend<Handle = H>,
|
||||
{
|
||||
B::int_tensor(
|
||||
self.get_handle(&tensor.id, &tensor.status),
|
||||
Shape::from(&tensor.shape),
|
||||
)
|
||||
B::int_tensor(self.get_tensor_handle(tensor))
|
||||
}
|
||||
|
||||
/// Get the [bool tensor](crate::backend::Backend::BoolTensorPrimitive) corresponding to the
|
||||
|
@ -96,10 +100,26 @@ impl<H: Clone> HandleContainer<H> {
|
|||
where
|
||||
B: ReprBackend<Handle = H>,
|
||||
{
|
||||
B::bool_tensor(
|
||||
self.get_handle(&tensor.id, &tensor.status),
|
||||
Shape::from(&tensor.shape),
|
||||
)
|
||||
B::bool_tensor(self.get_tensor_handle(tensor))
|
||||
}
|
||||
|
||||
/// Get the [quantized tensor](crate::backend::Backend::QuantizedTensorPrimitive) corresponding to the
|
||||
/// given [tensor description](TensorDescription).
|
||||
pub fn get_quantized_tensor<B>(
|
||||
&mut self,
|
||||
tensor: &QuantizedTensorDescription,
|
||||
) -> B::QuantizedTensorPrimitive
|
||||
where
|
||||
B: ReprBackend<Handle = H>,
|
||||
{
|
||||
let qtensor = self.get_tensor_handle(&tensor.tensor);
|
||||
let scale = self.get_tensor_handle(&tensor.qparams.scale);
|
||||
let handles = if let Some(offset) = &tensor.qparams.offset {
|
||||
vec![qtensor, scale, self.get_tensor_handle(offset)]
|
||||
} else {
|
||||
vec![qtensor, scale]
|
||||
};
|
||||
B::quantized_tensor(handles, tensor.scheme.clone())
|
||||
}
|
||||
|
||||
/// Register a new [float tensor](crate::backend::Backend::FloatTensorPrimitive) with the corresponding [tensor id](TensorId).
|
||||
|
@ -111,6 +131,26 @@ impl<H: Clone> HandleContainer<H> {
|
|||
self.handles.insert(*id, Handle::Existing(handle));
|
||||
}
|
||||
|
||||
/// Register a new [quantized tensor](crate::backend::Backend::QuantizedTensorPrimitive) with the corresponding [tensor ids](TensorId).
|
||||
pub fn register_quantized_tensor<B>(
|
||||
&mut self,
|
||||
ids: &[&TensorId],
|
||||
tensor: B::QuantizedTensorPrimitive,
|
||||
) where
|
||||
B: ReprBackend<Handle = H>,
|
||||
{
|
||||
let handles = B::quantized_tensor_handle(tensor);
|
||||
assert_eq!(
|
||||
ids.len(),
|
||||
handles.len(),
|
||||
"Number of tensor ids and handles must match"
|
||||
);
|
||||
|
||||
for (handle, id) in handles.into_iter().zip(ids) {
|
||||
self.handles.insert(**id, Handle::Existing(handle));
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a new [int tensor](crate::backend::Backend::IntTensorPrimitive) with the corresponding [tensor id](TensorId).
|
||||
pub fn register_int_tensor<B>(&mut self, id: &TensorId, tensor: B::IntTensorPrimitive)
|
||||
where
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
mod backend;
|
||||
mod handle;
|
||||
mod operation;
|
||||
mod quantization;
|
||||
mod tensor;
|
||||
|
||||
pub use backend::*;
|
||||
pub use handle::*;
|
||||
pub use operation::*;
|
||||
pub use quantization::*;
|
||||
pub use tensor::*;
|
||||
|
|
|
@ -5,10 +5,13 @@ use crate::{
|
|||
ops::{
|
||||
ConvOptions, ConvTransposeOptions, DeformConvOptions, InterpolateMode, InterpolateOptions,
|
||||
},
|
||||
quantization::QuantizationScheme,
|
||||
repr::tensor::TensorDescription,
|
||||
DType, Distribution, Element,
|
||||
};
|
||||
|
||||
use super::{QuantizationParametersDescription, QuantizedTensorDescription};
|
||||
|
||||
/// Describe all tensor operations possible.
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
|
||||
pub enum OperationDescription {
|
||||
|
@ -61,6 +64,10 @@ pub enum FloatOperationDescription {
|
|||
Random(RandomOperationDescription),
|
||||
/// Operation corresponding to [recip](crate::ops::FloatTensorOps::float_recip).
|
||||
Recip(UnaryOperationDescription),
|
||||
/// Operation corresponding to [quantize](crate::ops::QTensorOps::quantize).
|
||||
Quantize(QuantizeOperationDescription),
|
||||
/// Operation corresponding to [dequantize](crate::ops::QTensorOps::dequantize).
|
||||
Dequantize(DequantizeOperationDescription),
|
||||
}
|
||||
|
||||
/// Operation description specific to module.
|
||||
|
@ -830,6 +837,22 @@ pub struct ConvTranspose3dOptionsDescription {
|
|||
pub groups: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct QuantizeOperationDescription {
|
||||
pub tensor: TensorDescription,
|
||||
pub qparams: QuantizationParametersDescription,
|
||||
pub scheme: QuantizationScheme,
|
||||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct DequantizeOperationDescription {
|
||||
pub qtensor: QuantizedTensorDescription,
|
||||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
impl From<ConvOptions<1>> for Conv1dOptionsDescription {
|
||||
fn from(value: ConvOptions<1>) -> Self {
|
||||
Self {
|
||||
|
@ -1421,6 +1444,25 @@ impl FloatOperationDescription {
|
|||
FloatOperationDescription::Sin(desc) => vec![&desc.input, &desc.out],
|
||||
FloatOperationDescription::Tanh(desc) => vec![&desc.input, &desc.out],
|
||||
FloatOperationDescription::IntoInt(desc) => vec![&desc.input, &desc.out],
|
||||
FloatOperationDescription::Quantize(desc) => {
|
||||
if let Some(offset) = &desc.qparams.offset {
|
||||
vec![&desc.tensor, &desc.qparams.scale, &offset, &desc.out]
|
||||
} else {
|
||||
vec![&desc.tensor, &desc.qparams.scale, &desc.out]
|
||||
}
|
||||
}
|
||||
FloatOperationDescription::Dequantize(desc) => {
|
||||
if let Some(offset) = &desc.qtensor.qparams.offset {
|
||||
vec![
|
||||
&desc.qtensor.tensor,
|
||||
&desc.qtensor.qparams.scale,
|
||||
&offset,
|
||||
&desc.out,
|
||||
]
|
||||
} else {
|
||||
vec![&desc.qtensor.tensor, &desc.qtensor.qparams.scale, &desc.out]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::quantization::QuantizationScheme;
|
||||
|
||||
use super::TensorDescription;
|
||||
|
||||
/// A quantized tensor description represents a snapshot of a quantized tensor when it was used.
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct QuantizedTensorDescription {
|
||||
/// The quantized tensor.
|
||||
pub tensor: TensorDescription,
|
||||
/// The quantization parameters.
|
||||
pub qparams: QuantizationParametersDescription,
|
||||
/// The quantization scheme
|
||||
pub scheme: QuantizationScheme,
|
||||
}
|
||||
|
||||
/// Quantization parameters description.
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct QuantizationParametersDescription {
|
||||
/// The scaling factor.
|
||||
pub scale: TensorDescription,
|
||||
/// The zero-point offset.
|
||||
pub offset: Option<TensorDescription>,
|
||||
}
|
|
@ -88,6 +88,8 @@ pub trait Backend:
|
|||
|
||||
/// Tensor primitive to be used for all quantized operations.
|
||||
type QuantizedTensorPrimitive: QTensorPrimitive + Clone + Send + 'static + core::fmt::Debug;
|
||||
/// Quantized tensor encoding type.
|
||||
type QuantizedEncoding: Element;
|
||||
|
||||
/// If autodiff is enabled.
|
||||
fn ad_enabled() -> bool {
|
||||
|
|
|
@ -1,16 +1,18 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{backend::Backend, Tensor, TensorPrimitive};
|
||||
|
||||
use super::{CalibrationRange, QuantizationParameters, QuantizationParametersPrimitive};
|
||||
|
||||
/// Quantization data type.
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum QuantizationType {
|
||||
/// 8-bit signed integer.
|
||||
QInt8,
|
||||
}
|
||||
|
||||
/// Quantization scheme.
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum QuantizationScheme {
|
||||
/// Per-tensor affine/asymmetric quantization.
|
||||
PerTensorAffine(QuantizationType),
|
||||
|
|
Loading…
Reference in New Issue