Refactor tensor quantization for q_* ops (#2025)

* Move QuantizationScheme to burn-tensor

* Refactor QuantizedTensorPrimitive to include the quantization strategy

* Fix QFloat tensor data display

* Refactor quantization methods to use scheme and qparams (on backend device)

* Fix clippy

* Fix fmt

* Add qtensor primitive tests
Guillaume Lagrange 2024-07-19 10:39:50 -04:00 committed by GitHub
parent 3204cbe345
commit 0d5025edbb
63 changed files with 830 additions and 395 deletions
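
Before this change, `Tensor::quantize` took a fully-resolved `QuantizationStrategy` computed on the host; it now takes a `QuantizationScheme` plus quantization parameters that live on the backend device. Below is a minimal sketch of the new call path assembled from the diffs in this commit (the `compute_range` / `compute_q_params` helpers are the ones introduced here; treat the snippet as illustrative rather than a verbatim API reference):

```rust , ignore
use burn::tensor::quantization::{
    Calibration, MinMaxCalibration, QuantizationScheme, QuantizationType,
};

let scheme = QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8);

// Observe the value range, derive the qparams on the backend device, quantize.
let range = MinMaxCalibration {}.compute_range(&tensor);
let qparams = scheme.compute_q_params(range);
let quantized = tensor.quantize(&scheme, qparams);

// Dequantization no longer needs a strategy threaded through explicitly.
let restored = quantized.dequantize();
```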

View File

@ -307,7 +307,7 @@ Those operations are only available for `Float` tensors on backends that impleme
| Burn API | PyTorch Equivalent |
| ------------------------------------ | ------------------------------- |
| `tensor.quantize(strategy)` | N/A |
| `tensor.quantize(scheme, qparams)` | N/A |
| `tensor.dequantize()` | N/A |
## Activation Functions

View File

@ -44,13 +44,13 @@ tensors and can collect their statistics, such as the min and max value when usi
`MinMaxCalibration`, to compute the quantization parameters.
```rust , ignore
# use burn::quantization::{MinMaxCalibration, QuantizationScheme, QuantizationType, Quantizer};
# use burn::module::Quantizer;
# use burn::tensor::quantization::{MinMaxCalibration, QuantizationScheme, QuantizationType};
#
// Quantization config
let mut quantizer = Quantizer {
calibration: MinMaxCalibration {
scheme: QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8),
},
calibration: MinMaxCalibration {},
scheme: QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8),
};
// Quantize the weights

View File

@ -1,7 +1,8 @@
use burn_tensor::{
backend::Backend,
ops::{FloatTensor, QTensorOps, QuantizedTensor},
Device, QuantizationStrategy, Shape, TensorData,
quantization::{QuantizationParametersPrimitive, QuantizationScheme},
Device, Shape, TensorData,
};
use crate::{checkpoint::strategy::CheckpointStrategy, Autodiff};
@ -16,15 +17,13 @@ impl<B: Backend, C: CheckpointStrategy> QTensorOps<Self> for Autodiff<B, C> {
fn quantize<const D: usize>(
_tensor: FloatTensor<Self, D>,
_strategy: &QuantizationStrategy,
_scheme: &QuantizationScheme,
_qparams: QuantizationParametersPrimitive<Self>,
) -> QuantizedTensor<Self, D> {
todo!() // required for QAT
}
fn dequantize<const D: usize>(
_tensor: QuantizedTensor<Self, D>,
_strategy: &QuantizationStrategy,
) -> FloatTensor<Self, D> {
fn dequantize<const D: usize>(_tensor: QuantizedTensor<Self, D>) -> FloatTensor<Self, D> {
todo!()
}
@ -43,10 +42,7 @@ impl<B: Backend, C: CheckpointStrategy> QTensorOps<Self> for Autodiff<B, C> {
B::q_reshape(tensor, shape)
}
async fn q_into_data<const D: usize>(
tensor: QuantizedTensor<Self, D>,
strategy: QuantizationStrategy,
) -> TensorData {
B::q_into_data(tensor, strategy).await
async fn q_into_data<const D: usize>(tensor: QuantizedTensor<Self, D>) -> TensorData {
B::q_into_data(tensor).await
}
}

View File

@ -2,13 +2,14 @@ use std::marker::PhantomData;
use burn_tensor::{
backend::{Backend, DeviceId, DeviceOps, SyncType},
quantization::{QTensorPrimitive, QuantizationStrategy},
Device,
};
use candle_core::DeviceLocation;
use crate::{
element::{CandleElement, FloatCandleElement, IntCandleElement},
CandleTensor, PrecisionBridge,
CandleQTensor, CandleTensor, PrecisionBridge,
};
/// Tensor backend that uses the [candle](candle_core) crate for executing tensor operations.
@ -92,8 +93,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
type BoolTensorPrimitive<const D: usize> = CandleTensor<u8, D>;
// NOTE: candle does not implement `WithDType` for i8
type QuantizedTensorPrimitive<const D: usize> = CandleTensor<u8, D>;
type QuantizedTensorPrimitive<const D: usize> = CandleQTensor<D>;
fn ad_enabled() -> bool {
false

View File

@ -1,12 +1,13 @@
use burn_tensor::{
backend::Backend,
ops::{FloatTensor, QTensorOps, QuantizedTensor},
DType, Device, QuantizationStrategy, Shape, TensorData,
quantization::{QuantizationParametersPrimitive, QuantizationScheme, QuantizationStrategy},
DType, Device, Shape, TensorData,
};
use crate::{
element::{FloatCandleElement, IntCandleElement},
Candle,
Candle, CandleQTensor,
};
impl<F: FloatCandleElement, I: IntCandleElement> QTensorOps<Self> for Candle<F, I> {
@ -19,37 +20,35 @@ impl<F: FloatCandleElement, I: IntCandleElement> QTensorOps<Self> for Candle<F,
fn quantize<const D: usize>(
_tensor: FloatTensor<Self, D>,
_strategy: &QuantizationStrategy,
_scheme: &QuantizationScheme,
_qparams: QuantizationParametersPrimitive<Self>,
) -> QuantizedTensor<Self, D> {
unimplemented!()
}
fn dequantize<const D: usize>(
_tensor: QuantizedTensor<Self, D>,
_strategy: &QuantizationStrategy,
) -> FloatTensor<Self, D> {
fn dequantize<const D: usize>(_tensor: QuantizedTensor<Self, D>) -> FloatTensor<Self, D> {
unimplemented!()
}
fn q_shape<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Shape<D> {
super::base::shape(tensor)
super::base::shape(&tensor.qtensor)
}
fn q_device<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Device<Self> {
super::base::device(tensor)
super::base::device(&tensor.qtensor)
}
fn q_reshape<const D1: usize, const D2: usize>(
tensor: QuantizedTensor<Self, D1>,
shape: Shape<D2>,
) -> QuantizedTensor<Self, D2> {
super::base::reshape(tensor, shape)
CandleQTensor {
qtensor: super::base::reshape(tensor.qtensor, shape),
scheme: tensor.scheme,
}
}
async fn q_into_data<const D: usize>(
tensor: QuantizedTensor<Self, D>,
strategy: QuantizationStrategy,
) -> TensorData {
super::base::into_data(tensor)
async fn q_into_data<const D: usize>(tensor: QuantizedTensor<Self, D>) -> TensorData {
unimplemented!()
}
}

View File

@ -1,6 +1,9 @@
use std::marker::PhantomData;
use burn_tensor::{Element, Shape, TensorData};
use burn_tensor::{
quantization::{QTensorPrimitive, QuantizationScheme, QuantizationStrategy},
Element, Shape, TensorData,
};
use crate::{element::CandleElement, CandleDevice};
@ -45,3 +48,23 @@ impl<E: CandleElement, const D: usize> CandleTensor<E, D> {
Shape::from(x)
}
}
/// A quantized tensor for the candle backend.
#[derive(Clone, Debug)]
pub struct CandleQTensor<const D: usize> {
/// The quantized tensor.
// NOTE: candle does not implement `WithDType` for i8
pub qtensor: CandleTensor<u8, D>,
/// The quantization scheme.
pub scheme: QuantizationScheme,
}
impl<const D: usize> QTensorPrimitive for CandleQTensor<D> {
fn scheme(&self) -> &QuantizationScheme {
&self.scheme
}
fn strategy(&self) -> QuantizationStrategy {
todo!()
}
}
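
Each backend's quantized primitive now pairs the raw tensor with its `QuantizationScheme` and implements the new `QTensorPrimitive` trait, which the `Backend` trait bounds `QuantizedTensorPrimitive` on later in this commit. The trait definition itself is not part of the shown hunks; the sketch below is inferred from the backend impls and is only an approximation:

```rust , ignore
use burn_tensor::quantization::{QuantizationScheme, QuantizationStrategy};

/// Inferred shape of the trait implemented by `CandleQTensor`, `QFusionTensor`,
/// `QJitTensor`, `NdArrayQTensor` and `TchQTensor` in this commit.
pub trait QTensorPrimitive {
    /// The scheme (data type + granularity) the tensor was quantized with.
    fn scheme(&self) -> &QuantizationScheme;
    /// The concrete strategy (scale/offset values) used to (de)quantize it.
    fn strategy(&self) -> QuantizationStrategy;
}
```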

View File

@ -33,9 +33,6 @@ pub mod module;
/// Neural network module.
pub mod nn;
/// Quantization module.
pub mod quantization;
/// Module for the recorder.
pub mod record;

View File

@ -1,12 +1,11 @@
use super::ParamId;
use super::{ParamId, Quantizer};
use crate::{
quantization::{Calibration, Quantizer},
record::Record,
tensor::backend::{AutodiffBackend, Backend},
};
use alloc::vec::Vec;
pub use burn_derive::Module;
use burn_tensor::{Bool, Int, Tensor};
use burn_tensor::{quantization::Calibration, Bool, Int, Tensor};
/// Type alias to `Vec<B::Device>` which supports `no_std` environments, but automatically using
/// the `alloc` crate.

View File

@ -1,7 +1,9 @@
mod base;
mod display;
mod param;
mod quantize;
pub use base::*;
pub use display::*;
pub use param::*;
pub use quantize::*;

View File

@ -1,18 +1,23 @@
use burn_tensor::{backend::Backend, Tensor};
use burn_tensor::{
backend::Backend,
quantization::{Calibration, QuantizationScheme},
Tensor,
};
use crate::module::{ModuleMapper, ParamId};
use super::Calibration;
/// Describes how to quantize a module.
pub struct Quantizer<C: Calibration> {
/// The calibration method used in quantization.
pub calibration: C,
/// The quantization scheme.
pub scheme: QuantizationScheme,
}
impl<B: Backend, C: Calibration> ModuleMapper<B> for Quantizer<C> {
fn map_float<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
let strategy = self.calibration.configure(&tensor);
tensor.quantize(strategy)
let range = self.calibration.compute_range(&tensor);
let qparams = self.scheme.compute_q_params(range);
tensor.quantize(&self.scheme, qparams)
}
}
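
The `qparams` handed down to the backend are device-resident tensors rather than plain floats. Their primitive form is not defined in the shown hunks, but judging from how the ndarray and tch implementations consume it (`qparams.scale`, `qparams.offset.unwrap()`), it carries roughly the following; the field types here are assumptions:

```rust , ignore
use burn_tensor::backend::Backend;
use burn_tensor::ops::{FloatTensor, IntTensor};

/// Sketch of the parameters passed to `QTensorOps::quantize` (assumed types).
pub struct QuantizationParametersPrimitive<B: Backend> {
    /// Per-tensor scale, already on the backend device.
    pub scale: FloatTensor<B, 1>,
    /// Zero-point offset; `None` for symmetric schemes.
    pub offset: Option<IntTensor<B, 1>>,
}
```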

View File

@ -1,80 +0,0 @@
use burn_tensor::{
backend::Backend, AffineQuantization, ElementConversion, Quantization, QuantizationStrategy,
SymmetricQuantization, Tensor,
};
use super::{QuantizationScheme, QuantizationType};
/// Calibration method used to compute the quantization range mapping.
pub trait Calibration {
/// Configure the quantization strategy.
fn configure<B: Backend, const D: usize>(&self, tensor: &Tensor<B, D>) -> QuantizationStrategy;
}
/// Computes the quantization range mapping based on the running min and max values.
pub struct MinMaxCalibration {
/// Quantization scheme to be used.
pub scheme: QuantizationScheme,
}
impl Calibration for MinMaxCalibration {
fn configure<B: Backend, const D: usize>(&self, tensor: &Tensor<B, D>) -> QuantizationStrategy {
let min = tensor.clone().min().into_scalar().elem::<f32>();
let max = tensor.clone().max().into_scalar().elem::<f32>();
match &self.scheme {
QuantizationScheme::PerTensorAffine(dtype) => match dtype {
QuantizationType::QInt8 => {
QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::new(min, max))
}
},
QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
QuantizationType::QInt8 => QuantizationStrategy::PerTensorSymmetricInt8(
SymmetricQuantization::new(min, max),
),
},
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
#[test]
fn min_max_calibration_per_tensor_affine_int8() {
let device = <TestBackend as Backend>::Device::default();
let tensor = Tensor::<TestBackend, 1>::from_floats([-1.8, -1.0, 0.0, 0.5], &device);
let calibration = MinMaxCalibration {
scheme: QuantizationScheme::PerTensorAffine(QuantizationType::QInt8),
};
let strategy = calibration.configure(&tensor);
if let QuantizationStrategy::PerTensorAffineInt8(q) = strategy {
assert_eq!(q.scale, 0.009_019_608);
assert_eq!(q.offset, 72);
} else {
panic!("Wrong quantization strategy");
}
}
#[test]
fn min_max_calibration_per_tensor_symmetric_int8() {
let device = <TestBackend as Backend>::Device::default();
let tensor = Tensor::<TestBackend, 1>::from_floats([-1.8, -1.0, 0.0, 0.5], &device);
let calibration = MinMaxCalibration {
scheme: QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8),
};
let strategy = calibration.configure(&tensor);
if let QuantizationStrategy::PerTensorSymmetricInt8(q) = strategy {
assert_eq!(q.scale, 0.014_173_228);
} else {
panic!("Wrong quantization strategy");
}
}
}

View File

@ -1,7 +0,0 @@
mod calibration;
mod quantize;
mod scheme;
pub use calibration::*;
pub use quantize::*;
pub use scheme::*;

View File

@ -1,17 +0,0 @@
/// Quantization data type.
pub enum QuantizationType {
/// 8-bit signed integer.
QInt8,
}
/// Quantization scheme.
pub enum QuantizationScheme {
/// Per-tensor affine/asymmetric quantization.
PerTensorAffine(QuantizationType),
/// Per-tensor symmetric quantization.
PerTensorSymmetric(QuantizationType),
// /// Per-channel affine/asymmetric quantization.
// PerChannelAffine,
// /// Per-channel symmetric quantization.
// PerChannelSymmetric,
}

View File

@ -1,5 +1,6 @@
use crate::{
client::FusionClient, stream::Context, FusionClientLocator, FusionTensor, PrecisionBridge,
QFusionTensor,
};
use burn_tensor::{
backend::{Backend, DeviceOps, SyncType},
@ -37,7 +38,7 @@ impl<B: FusionBackend> Backend for Fusion<B> {
type BoolTensorPrimitive<const D: usize> = FusionTensor<B::FusionRuntime>;
type QuantizedTensorPrimitive<const D: usize> = FusionTensor<B::FusionRuntime>;
type QuantizedTensorPrimitive<const D: usize> = QFusionTensor<B::FusionRuntime>;
fn name() -> String {
format!("fusion<{}>", B::name())

View File

@ -1,7 +1,8 @@
use burn_tensor::{
backend::Backend,
ops::{QTensorOps, QuantizedTensor},
Device, QuantizationStrategy, Shape, TensorData,
quantization::{QuantizationParametersPrimitive, QuantizationScheme},
Device, Shape, TensorData,
};
use crate::{client::FusionClient, Fusion, FusionBackend};
@ -16,24 +17,24 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
fn quantize<const D: usize>(
_tensor: <Self as Backend>::FloatTensorPrimitive<D>,
_strategy: &QuantizationStrategy,
_scheme: &QuantizationScheme,
_qparams: QuantizationParametersPrimitive<Self>,
) -> <Self as Backend>::QuantizedTensorPrimitive<D> {
unimplemented!()
}
fn dequantize<const D: usize>(
_tensor: <Self as Backend>::QuantizedTensorPrimitive<D>,
_strategy: &QuantizationStrategy,
) -> <Self as Backend>::FloatTensorPrimitive<D> {
unimplemented!()
}
fn q_shape<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Shape<D> {
tensor.shape()
tensor.qtensor.shape()
}
fn q_device<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Device<Self> {
tensor.client.device().clone()
tensor.qtensor.client.device().clone()
}
fn q_reshape<const D1: usize, const D2: usize>(
@ -43,10 +44,7 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
unimplemented!()
}
async fn q_into_data<const D: usize>(
_tensor: QuantizedTensor<Self, D>,
_strategy: QuantizationStrategy,
) -> TensorData {
async fn q_into_data<const D: usize>(_tensor: QuantizedTensor<Self, D>) -> TensorData {
unimplemented!()
}
}

View File

@ -1,5 +1,6 @@
use crate::{client::FusionClient, stream::StreamId, Client, FusionBackend, FusionRuntime};
use burn_tensor::{
quantization::{QTensorPrimitive, QuantizationScheme, QuantizationStrategy},
repr::{TensorDescription, TensorId, TensorStatus},
DType, Shape, TensorData,
};
@ -157,3 +158,31 @@ impl<R: FusionRuntime> Drop for FusionTensor<R> {
}
}
}
/// A quantized tensor primitive for fusion backends.
#[derive(Debug)]
pub struct QFusionTensor<R: FusionRuntime> {
/// The quantized tensor.
pub qtensor: FusionTensor<R>,
/// The quantization scheme.
pub scheme: QuantizationScheme,
}
impl<R: FusionRuntime> QTensorPrimitive for QFusionTensor<R> {
fn scheme(&self) -> &QuantizationScheme {
&self.scheme
}
fn strategy(&self) -> QuantizationStrategy {
todo!()
}
}
impl<R: FusionRuntime> Clone for QFusionTensor<R> {
fn clone(&self) -> Self {
Self {
qtensor: self.qtensor.clone(),
scheme: self.scheme.clone(),
}
}
}

View File

@ -1,5 +1,6 @@
use crate::{
tensor::JitTensor, FloatElement, IntElement, JitAutotuneKey, JitRuntime, PrecisionBridge,
tensor::{JitTensor, QJitTensor},
FloatElement, IntElement, JitAutotuneKey, JitRuntime, PrecisionBridge,
};
use burn_compute::server::ComputeServer;
use burn_tensor::backend::{Backend, SyncType};
@ -33,8 +34,7 @@ where
type FloatTensorPrimitive<const D: usize> = JitTensor<R, Self::FloatElem, D>;
type IntTensorPrimitive<const D: usize> = JitTensor<R, Self::IntElem, D>;
type BoolTensorPrimitive<const D: usize> = JitTensor<R, u32, D>;
// TODO: implement `JitElement` / `CubeElement` for quantized type
type QuantizedTensorPrimitive<const D: usize> = JitTensor<R, u32, D>;
type QuantizedTensorPrimitive<const D: usize> = QJitTensor<R, D>;
fn name() -> String {
format!("jit<{}>", R::name())

View File

@ -1,9 +1,10 @@
use burn_tensor::{
ops::{FloatTensor, QTensorOps, QuantizedTensor},
Device, QuantizationStrategy, Shape, TensorData,
quantization::{QuantizationParametersPrimitive, QuantizationScheme},
Device, Shape, TensorData,
};
use crate::{FloatElement, IntElement, JitBackend, JitRuntime};
use crate::{tensor::QJitTensor, FloatElement, IntElement, JitBackend, JitRuntime};
impl<R, F, I> QTensorOps<Self> for JitBackend<R, F, I>
where
@ -20,37 +21,35 @@ where
fn quantize<const D: usize>(
_tensor: FloatTensor<Self, D>,
_strategy: &QuantizationStrategy,
_scheme: &QuantizationScheme,
_qparams: QuantizationParametersPrimitive<Self>,
) -> QuantizedTensor<Self, D> {
unimplemented!()
}
fn dequantize<const D: usize>(
_tensor: QuantizedTensor<Self, D>,
_strategy: &QuantizationStrategy,
) -> FloatTensor<Self, D> {
fn dequantize<const D: usize>(_tensor: QuantizedTensor<Self, D>) -> FloatTensor<Self, D> {
unimplemented!()
}
fn q_shape<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Shape<D> {
tensor.shape.clone()
tensor.qtensor.shape.clone()
}
fn q_device<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Device<Self> {
tensor.device.clone()
tensor.qtensor.device.clone()
}
fn q_reshape<const D1: usize, const D2: usize>(
tensor: QuantizedTensor<Self, D1>,
shape: Shape<D2>,
) -> QuantizedTensor<Self, D2> {
super::reshape(tensor, shape)
QJitTensor {
qtensor: super::reshape(tensor.qtensor, shape),
scheme: tensor.scheme,
}
}
async fn q_into_data<const D: usize>(
_tensor: QuantizedTensor<Self, D>,
_strategy: QuantizationStrategy,
) -> TensorData {
async fn q_into_data<const D: usize>(_tensor: QuantizedTensor<Self, D>) -> TensorData {
unimplemented!()
}
}

View File

@ -1,4 +1,7 @@
mod base;
mod layout;
mod qtensor;
pub use base::*;
pub(crate) use layout::*;
pub(crate) use qtensor::*;

View File

@ -0,0 +1,34 @@
use burn_tensor::quantization::{QTensorPrimitive, QuantizationScheme, QuantizationStrategy};
use crate::JitRuntime;
use super::JitTensor;
/// A quantized tensor primitive.
#[derive(Debug)]
pub struct QJitTensor<R: JitRuntime, const D: usize> {
/// The quantized tensor.
// TODO: implement `JitElement` / `CubeElement` for quantized type
pub qtensor: JitTensor<R, u32, D>,
/// The quantization scheme.
pub scheme: QuantizationScheme,
}
impl<R: JitRuntime, const D: usize> QTensorPrimitive for QJitTensor<R, D> {
fn scheme(&self) -> &QuantizationScheme {
&self.scheme
}
fn strategy(&self) -> QuantizationStrategy {
todo!()
}
}
impl<R: JitRuntime, const D: usize> Clone for QJitTensor<R, D> {
fn clone(&self) -> Self {
Self {
qtensor: self.qtensor.clone(),
scheme: self.scheme.clone(),
}
}
}

View File

@ -1,5 +1,6 @@
use crate::NdArrayTensor;
use crate::{element::FloatNdArrayElement, PrecisionBridge};
use crate::element::{FloatNdArrayElement, QuantElement};
use crate::PrecisionBridge;
use crate::{NdArrayQTensor, NdArrayTensor};
use alloc::string::String;
use burn_common::stub::Mutex;
use burn_tensor::backend::{Backend, DeviceId, DeviceOps};
@ -34,11 +35,12 @@ impl Default for NdArrayDevice {
/// This backend is compatible with CPUs and can be compiled for almost any platform, including
/// `wasm`, `arm`, and `x86`.
#[derive(Clone, Copy, Default, Debug)]
pub struct NdArray<E = f32> {
phantom: PhantomData<E>,
pub struct NdArray<E = f32, Q = i8> {
_e: PhantomData<E>,
_q: PhantomData<Q>,
}
impl<E: FloatNdArrayElement> Backend for NdArray<E> {
impl<E: FloatNdArrayElement, Q: QuantElement> Backend for NdArray<E, Q> {
type Device = NdArrayDevice;
type FullPrecisionBridge = PrecisionBridge<f32>;
@ -50,7 +52,7 @@ impl<E: FloatNdArrayElement> Backend for NdArray<E> {
type BoolTensorPrimitive<const D: usize> = NdArrayTensor<bool, D>;
type QuantizedTensorPrimitive<const D: usize> = NdArrayTensor<i8, D>;
type QuantizedTensorPrimitive<const D: usize> = NdArrayQTensor<Q, D>;
fn ad_enabled() -> bool {
false

View File

@ -1,4 +1,4 @@
use crate::{FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayTensor};
use crate::{element::QuantElement, FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayTensor};
use burn_tensor::{backend::BackendBridge, ops::FloatTensor};
use core::marker::PhantomData;
@ -8,10 +8,11 @@ pub struct PrecisionBridge<E: FloatNdArrayElement> {
_e: PhantomData<E>,
}
impl<TElem, OElem> BackendBridge<NdArray<OElem>> for PrecisionBridge<TElem>
impl<TElem, OElem, QElem> BackendBridge<NdArray<OElem, QElem>> for PrecisionBridge<TElem>
where
TElem: FloatNdArrayElement,
OElem: FloatNdArrayElement,
QElem: QuantElement,
{
type Target = NdArray<TElem>;

View File

@ -41,6 +41,11 @@ pub trait ExpElement {
fn int_abs_elem(self) -> Self;
}
/// A quantized element for the ndarray backend.
pub trait QuantElement: NdArrayElement {}
impl QuantElement for i8 {}
impl FloatNdArrayElement for f64 {}
impl FloatNdArrayElement for f32 {}

View File

@ -1,7 +1,11 @@
use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArray};
use crate::{
element::{FloatNdArrayElement, QuantElement},
tensor::NdArrayTensor,
NdArray,
};
use burn_tensor::{ops::ActivationOps, ElementConversion};
impl<E: FloatNdArrayElement> ActivationOps<Self> for NdArray<E> {
impl<E: FloatNdArrayElement, Q: QuantElement> ActivationOps<Self> for NdArray<E, Q> {
fn relu<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
let zero = 0.elem();
let array = tensor

View File

@ -7,7 +7,7 @@ use core::ops::Range;
use ndarray::IntoDimension;
// Current crate
use crate::element::FloatNdArrayElement;
use crate::element::{FloatNdArrayElement, QuantElement};
use crate::NdArrayDevice;
use crate::{tensor::NdArrayTensor, NdArray};
@ -16,7 +16,7 @@ use burn_tensor::{backend::Backend, Shape, TensorData};
use super::NdArrayOps;
impl<E: FloatNdArrayElement> BoolTensorOps<Self> for NdArray<E> {
impl<E: FloatNdArrayElement, Q: QuantElement> BoolTensorOps<Self> for NdArray<E, Q> {
fn bool_from_data<const D: usize>(
data: TensorData,
_device: &NdArrayDevice,

View File

@ -11,7 +11,7 @@ use ndarray::{
};
use crate::{
element::FloatNdArrayElement,
element::{FloatNdArrayElement, QuantElement},
ops::padding::{apply_padding_4d, apply_padding_5d},
sharing::UnsafeSharedRef,
tensor::NdArrayTensor,
@ -98,7 +98,7 @@ fn conv3d_mad_inner<E: FloatNdArrayElement>(
}
}
pub(crate) fn conv2d<E: FloatNdArrayElement>(
pub(crate) fn conv2d<E: FloatNdArrayElement, Q: QuantElement>(
x: NdArrayTensor<E, 4>,
weight: NdArrayTensor<E, 4>,
bias: Option<NdArrayTensor<E, 1>>,
@ -125,7 +125,7 @@ pub(crate) fn conv2d<E: FloatNdArrayElement>(
in_width,
);
let x = apply_padding_4d(x, options.padding, 0i32.elem()).array;
let x = apply_padding_4d::<E, Q>(x, options.padding, 0i32.elem()).array;
// Convert inputs from dynamic indexes to static to improve perf.
let x = x.into_dimensionality::<ndarray::Ix4>().unwrap();
@ -309,7 +309,7 @@ pub(crate) fn conv_transpose2d<E: FloatNdArrayElement>(
NdArrayTensor::new(output.into_dyn().into_shared())
}
pub(crate) fn conv3d<E: FloatNdArrayElement>(
pub(crate) fn conv3d<E: FloatNdArrayElement, Q: QuantElement>(
x: NdArrayTensor<E, 5>,
weight: NdArrayTensor<E, 5>,
bias: Option<NdArrayTensor<E, 1>>,
@ -344,7 +344,7 @@ pub(crate) fn conv3d<E: FloatNdArrayElement>(
in_width,
);
let x = apply_padding_5d(x, options.padding, 0i32.elem()).array;
let x = apply_padding_5d::<E, Q>(x, options.padding, 0i32.elem()).array;
// Convert inputs from dynamic indexes to static to improve perf.
let x = x.into_dimensionality::<ndarray::Ix5>().unwrap();

View File

@ -12,6 +12,7 @@ use ndarray::IntoDimension;
// Current crate
use crate::element::ExpElement;
use crate::element::FloatNdArrayElement;
use crate::element::QuantElement;
use crate::{tensor::NdArrayTensor, NdArray};
use crate::{NdArrayDevice, SEED};
@ -20,7 +21,7 @@ use burn_tensor::{backend::Backend, Shape, TensorData};
use super::{NdArrayMathOps, NdArrayOps};
impl<E: FloatNdArrayElement> IntTensorOps<Self> for NdArray<E> {
impl<E: FloatNdArrayElement, Q: QuantElement> IntTensorOps<Self> for NdArray<E, Q> {
fn int_from_data<const D: usize>(
data: TensorData,
_device: &NdArrayDevice,

View File

@ -1,5 +1,7 @@
use crate::{
element::FloatNdArrayElement, ops::padding::apply_padding_4d, sharing::UnsafeSharedRef,
element::{FloatNdArrayElement, QuantElement},
ops::padding::apply_padding_4d,
sharing::UnsafeSharedRef,
tensor::NdArrayTensor,
};
@ -7,7 +9,7 @@ use burn_common::{iter_range_par, run_par};
use burn_tensor::ElementConversion;
use ndarray::Array4;
pub(crate) fn max_pool2d<E: FloatNdArrayElement>(
pub(crate) fn max_pool2d<E: FloatNdArrayElement, Q: QuantElement>(
x: NdArrayTensor<E, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
@ -28,7 +30,7 @@ pub(crate) fn max_pool2d<E: FloatNdArrayElement>(
/ stride_width)
+ 1;
let x = apply_padding_4d(x, padding, inf).array;
let x = apply_padding_4d::<E, Q>(x, padding, inf).array;
let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf);
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
@ -67,7 +69,7 @@ pub(crate) fn max_pool2d<E: FloatNdArrayElement>(
NdArrayTensor::new(output.into_dyn().into_shared())
}
pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement>(
pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement, Q: QuantElement>(
x: NdArrayTensor<E, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
@ -88,7 +90,7 @@ pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement>(
/ stride_width)
+ 1;
let x = apply_padding_4d(x, padding, inf).array;
let x = apply_padding_4d::<E, Q>(x, padding, inf).array;
let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf);
let mut indices = Array4::<i64>::zeros((batch_size, channels, out_height, out_width));

View File

@ -5,18 +5,18 @@ use super::{
interpolate::{bicubic_interpolate, bilinear_interpolate, nearest_interpolate},
maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices},
};
use crate::ops::interpolate::nearest_interpolate_backward;
use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArray};
use crate::{element::QuantElement, ops::interpolate::nearest_interpolate_backward};
use burn_tensor::ops::*;
impl<E: FloatNdArrayElement> ModuleOps<Self> for NdArray<E> {
impl<E: FloatNdArrayElement, Q: QuantElement> ModuleOps<Self> for NdArray<E, Q> {
fn conv2d(
x: NdArrayTensor<E, 4>,
weight: NdArrayTensor<E, 4>,
bias: Option<NdArrayTensor<E, 1>>,
options: ConvOptions<2>,
) -> NdArrayTensor<E, 4> {
conv2d(x, weight, bias, options)
conv2d::<E, Q>(x, weight, bias, options)
}
fn conv_transpose2d(
@ -56,7 +56,7 @@ impl<E: FloatNdArrayElement> ModuleOps<Self> for NdArray<E> {
padding: [usize; 2],
dilation: [usize; 2],
) -> NdArrayTensor<E, 4> {
max_pool2d(x, kernel_size, stride, padding, dilation)
max_pool2d::<E, Q>(x, kernel_size, stride, padding, dilation)
}
fn max_pool2d_with_indices(
@ -65,8 +65,9 @@ impl<E: FloatNdArrayElement> ModuleOps<Self> for NdArray<E> {
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> MaxPool2dWithIndices<NdArray<E>> {
let (output, indices) = max_pool2d_with_indices(x, kernel_size, stride, padding, dilation);
) -> MaxPool2dWithIndices<NdArray<E, Q>> {
let (output, indices) =
max_pool2d_with_indices::<E, Q>(x, kernel_size, stride, padding, dilation);
MaxPool2dWithIndices::new(output, indices)
}
@ -79,7 +80,7 @@ impl<E: FloatNdArrayElement> ModuleOps<Self> for NdArray<E> {
dilation: [usize; 2],
output_grad: NdArrayTensor<E, 4>,
indices: NdArrayTensor<i64, 4>,
) -> MaxPool2dBackward<NdArray<E>> {
) -> MaxPool2dBackward<NdArray<E, Q>> {
MaxPool2dBackward::new(max_pool2d_backward(
x,
kernel_size,
@ -137,7 +138,7 @@ impl<E: FloatNdArrayElement> ModuleOps<Self> for NdArray<E> {
bias: Option<NdArrayTensor<E, 1>>,
options: ConvOptions<3>,
) -> NdArrayTensor<E, 5> {
conv3d(x, weight, bias, options)
conv3d::<E, Q>(x, weight, bias, options)
}
fn conv_transpose3d(

View File

@ -1,8 +1,12 @@
use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArray};
use crate::{
element::{FloatNdArrayElement, QuantElement},
tensor::NdArrayTensor,
NdArray,
};
use burn_tensor::ops::FloatTensorOps;
use ndarray::{Array4, Array5};
pub(crate) fn apply_padding_4d<E: FloatNdArrayElement>(
pub(crate) fn apply_padding_4d<E: FloatNdArrayElement, Q: QuantElement>(
x: NdArrayTensor<E, 4>,
padding: [usize; 2],
elem: E,
@ -18,7 +22,7 @@ pub(crate) fn apply_padding_4d<E: FloatNdArrayElement>(
);
let mut x_new = NdArrayTensor::new(x_new.into_shared().into_dyn());
x_new = NdArray::float_slice_assign(
x_new = NdArray::<E, Q>::float_slice_assign(
x_new,
[
0..batch_size,
@ -32,7 +36,7 @@ pub(crate) fn apply_padding_4d<E: FloatNdArrayElement>(
x_new
}
pub(crate) fn apply_padding_5d<E: FloatNdArrayElement>(
pub(crate) fn apply_padding_5d<E: FloatNdArrayElement, Q: QuantElement>(
x: NdArrayTensor<E, 5>,
padding: [usize; 3],
elem: E,
@ -55,7 +59,7 @@ pub(crate) fn apply_padding_5d<E: FloatNdArrayElement>(
);
let mut x_new = NdArrayTensor::new(x_new.into_shared().into_dyn());
x_new = NdArray::float_slice_assign(
x_new = NdArray::<E, Q>::float_slice_assign(
x_new,
[
0..batch_size,

View File

@ -1,9 +1,16 @@
use burn_tensor::{
ops::{FloatTensor, QTensorOps, QuantizedTensor},
DType, Quantization, QuantizationStrategy, Shape, TensorData,
quantization::{
AffineQuantization, Quantization, QuantizationParametersPrimitive, QuantizationScheme,
QuantizationStrategy, QuantizationType, SymmetricQuantization,
},
DType, Shape, TensorData,
};
use crate::{element::NdArrayElement, FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayTensor};
use crate::{
element::{NdArrayElement, QuantElement},
FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor,
};
use super::NdArrayOps;
@ -13,7 +20,7 @@ fn into_data<E: NdArrayElement, const D: usize>(tensor: NdArrayTensor<E, D>) ->
TensorData::new(values, shape)
}
impl<E: FloatNdArrayElement> QTensorOps<Self> for NdArray<E> {
impl<E: FloatNdArrayElement, Q: QuantElement> QTensorOps<Self> for NdArray<E, Q> {
fn q_from_data<const D: usize>(
data: TensorData,
_device: &NdArrayDevice,
@ -22,11 +29,19 @@ impl<E: FloatNdArrayElement> QTensorOps<Self> for NdArray<E> {
DType::QFloat(strategy) => match strategy {
QuantizationStrategy::PerTensorAffineInt8(_) => {
let data = data.convert::<i8>();
NdArrayTensor::<i8, D>::from_data(data)
NdArrayQTensor {
qtensor: NdArrayTensor::<Q, D>::from_data(data),
scheme: strategy.scheme(),
strategy,
}
}
QuantizationStrategy::PerTensorSymmetricInt8(_) => {
let data = data.convert::<i8>();
NdArrayTensor::<i8, D>::from_data(data)
NdArrayQTensor {
qtensor: NdArrayTensor::<Q, D>::from_data(data),
scheme: strategy.scheme(),
strategy,
}
}
},
_ => panic!(
@ -38,18 +53,36 @@ impl<E: FloatNdArrayElement> QTensorOps<Self> for NdArray<E> {
fn quantize<const D: usize>(
tensor: FloatTensor<Self, D>,
strategy: &QuantizationStrategy,
scheme: &QuantizationScheme,
qparams: QuantizationParametersPrimitive<Self>,
) -> QuantizedTensor<Self, D> {
let data = into_data(tensor).with_quantization(*strategy);
NdArrayTensor::<i8, D>::from_data(data)
let strategy = match scheme {
QuantizationScheme::PerTensorAffine(dtype) => match dtype {
QuantizationType::QInt8 => {
QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(
into_data(qparams.scale).iter().next().unwrap(),
into_data(qparams.offset.unwrap()).iter().next().unwrap(),
))
}
},
QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
QuantizationType::QInt8 => QuantizationStrategy::PerTensorSymmetricInt8(
SymmetricQuantization::init(into_data(qparams.scale).iter().next().unwrap()),
),
},
};
let data = into_data(tensor).with_quantization(strategy);
NdArrayQTensor {
qtensor: NdArrayTensor::<Q, D>::from_data(data),
strategy,
scheme: scheme.clone(),
}
}
fn dequantize<const D: usize>(
tensor: QuantizedTensor<Self, D>,
strategy: &QuantizationStrategy,
) -> FloatTensor<Self, D> {
let data = into_data(tensor);
let values = match strategy {
fn dequantize<const D: usize>(tensor: QuantizedTensor<Self, D>) -> FloatTensor<Self, D> {
let data = into_data(tensor.qtensor);
let values = match tensor.strategy {
QuantizationStrategy::PerTensorAffineInt8(s) => s.dequantize(data.as_slice().unwrap()),
QuantizationStrategy::PerTensorSymmetricInt8(s) => {
s.dequantize(data.as_slice().unwrap())
@ -59,7 +92,7 @@ impl<E: FloatNdArrayElement> QTensorOps<Self> for NdArray<E> {
}
fn q_shape<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Shape<D> {
tensor.shape()
tensor.qtensor.shape()
}
fn q_device<const D: usize>(_tensor: &QuantizedTensor<Self, D>) -> NdArrayDevice {
@ -70,15 +103,16 @@ impl<E: FloatNdArrayElement> QTensorOps<Self> for NdArray<E> {
tensor: QuantizedTensor<Self, D1>,
shape: Shape<D2>,
) -> QuantizedTensor<Self, D2> {
NdArrayOps::reshape(tensor, shape)
NdArrayQTensor {
qtensor: NdArrayOps::reshape(tensor.qtensor, shape),
scheme: tensor.scheme,
strategy: tensor.strategy,
}
}
async fn q_into_data<const D: usize>(
tensor: QuantizedTensor<Self, D>,
strategy: QuantizationStrategy,
) -> TensorData {
let shape = tensor.shape();
let values = tensor.array.into_iter().collect();
TensorData::quantized(values, shape, strategy)
async fn q_into_data<const D: usize>(tensor: QuantizedTensor<Self, D>) -> TensorData {
let shape = tensor.qtensor.shape();
let values = tensor.qtensor.array.into_iter().collect();
TensorData::quantized(values, shape, tensor.strategy)
}
}

View File

@ -5,7 +5,7 @@ use ndarray::IntoDimension;
// Current crate
use super::{matmul::matmul, NdArrayMathOps, NdArrayOps};
use crate::element::FloatNdArrayElement;
use crate::element::{FloatNdArrayElement, QuantElement};
use crate::{tensor::NdArrayTensor, NdArray};
use crate::{NdArrayDevice, SEED};
@ -20,7 +20,7 @@ use num_traits::Float;
use libm::erf;
impl<E: FloatNdArrayElement> FloatTensorOps<Self> for NdArray<E> {
impl<E: FloatNdArrayElement, Q: QuantElement> FloatTensorOps<Self> for NdArray<E, Q> {
fn float_from_data<const D: usize>(
data: TensorData,
_device: &NdArrayDevice,

View File

@ -1,7 +1,12 @@
use burn_tensor::{Element, Shape, TensorData};
use burn_tensor::{
quantization::{QTensorPrimitive, QuantizationScheme, QuantizationStrategy},
Element, Shape, TensorData,
};
use ndarray::{ArcArray, Array, Dim, IxDyn};
use crate::element::QuantElement;
/// Tensor primitive used by the [ndarray backend](crate::NdArray).
#[derive(new, Debug, Clone)]
pub struct NdArrayTensor<E, const D: usize> {
@ -111,11 +116,38 @@ where
}
}
/// A quantized tensor for the ndarray backend.
#[derive(Clone, Debug)]
pub struct NdArrayQTensor<Q: QuantElement, const D: usize> {
/// The quantized tensor.
pub qtensor: NdArrayTensor<Q, D>,
/// The quantization scheme.
pub scheme: QuantizationScheme,
/// The quantization strategy.
pub strategy: QuantizationStrategy,
}
impl<Q: QuantElement, const D: usize> QTensorPrimitive for NdArrayQTensor<Q, D> {
fn scheme(&self) -> &QuantizationScheme {
&self.scheme
}
fn strategy(&self) -> QuantizationStrategy {
self.strategy
}
}
#[cfg(test)]
mod tests {
use crate::NdArray;
use super::*;
use burn_common::rand::get_seeded_rng;
use burn_tensor::Distribution;
use burn_tensor::{
ops::QTensorOps,
quantization::{AffineQuantization, QuantizationParametersPrimitive, QuantizationType},
Distribution,
};
#[test]
fn should_support_into_and_from_data_1d() {
@ -172,4 +204,21 @@ mod tests {
assert_eq!(data_expected, data_actual);
}
#[test]
fn should_support_qtensor_strategy() {
let tensor = NdArrayTensor::<f32, 1>::from_data(TensorData::from([-1.8, -1.0, 0.0, 0.5]));
let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8);
let qparams = QuantizationParametersPrimitive {
scale: NdArrayTensor::from_data(TensorData::from([0.009_019_608])),
offset: Some(NdArrayTensor::from_data(TensorData::from([72]))),
};
let qtensor: NdArrayQTensor<i8, 1> = NdArray::quantize(tensor, &scheme, qparams);
assert_eq!(qtensor.scheme(), &scheme);
assert_eq!(
qtensor.strategy(),
QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.009_019_608, 72))
);
}
}
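
For reference, the expected values in this test follow from the standard per-tensor affine int8 mapping over the observed range `[-1.8, 0.5]`; a quick sanity check (not part of the commit):

```rust , ignore
let (min, max) = (-1.8f32, 0.5f32);
let scale = (max - min) / 255.0;                   // 2.3 / 255 ≈ 0.009_019_608
let offset = (-128.0 - min / scale).round() as i8; // zero-point ≈ 72
```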

View File

@ -1,4 +1,4 @@
use crate::PrecisionBridge;
use crate::{PrecisionBridge, QuantElement, TchQTensor};
use super::element::TchElement;
use super::TchTensor;
@ -86,11 +86,12 @@ impl Default for LibTorchDevice {
///
/// Refer to the [tch] crate for more information.
#[derive(Clone, Copy, Default, Debug)]
pub struct LibTorch<E = f32> {
pub struct LibTorch<E = f32, Q = i8> {
_e: E,
_q: Q,
}
impl<E: TchElement> Backend for LibTorch<E> {
impl<E: TchElement, Q: QuantElement> Backend for LibTorch<E, Q> {
type Device = LibTorchDevice;
type FullPrecisionBridge = PrecisionBridge<f32>;
@ -102,7 +103,7 @@ impl<E: TchElement> Backend for LibTorch<E> {
type BoolTensorPrimitive<const D: usize> = TchTensor<bool, D>;
type QuantizedTensorPrimitive<const D: usize> = TchTensor<i8, D>;
type QuantizedTensorPrimitive<const D: usize> = TchQTensor<Q, D>;
fn seed(seed: u64) {
tch::manual_seed(seed as i64);

View File

@ -1,4 +1,4 @@
use crate::{ops::TchOps, LibTorch, TchElement, TchTensor};
use crate::{ops::TchOps, LibTorch, QuantElement, TchElement, TchTensor};
use burn_tensor::{backend::BackendBridge, ops::FloatTensor, Device};
use std::marker::PhantomData;
@ -8,10 +8,11 @@ pub struct PrecisionBridge<E: TchElement> {
_e: PhantomData<E>,
}
impl<TElem, OElem> BackendBridge<LibTorch<OElem>> for PrecisionBridge<TElem>
impl<TElem, OElem, QElem> BackendBridge<LibTorch<OElem, QElem>> for PrecisionBridge<TElem>
where
TElem: TchElement,
OElem: TchElement,
QElem: QuantElement,
{
type Target = LibTorch<TElem>;

View File

@ -12,5 +12,11 @@ impl TchElement for bf16 {}
impl TchElement for i64 {}
impl TchElement for i32 {}
impl TchElement for i16 {}
impl TchElement for i8 {}
impl TchElement for u8 {}
/// A quantized element for the tch backend.
pub trait QuantElement: TchElement {}
impl QuantElement for i8 {}

View File

@ -1,7 +1,7 @@
use crate::{element::TchElement, LibTorch, TchTensor};
use crate::{element::TchElement, LibTorch, QuantElement, TchTensor};
use burn_tensor::ops::ActivationOps;
impl<E: TchElement> ActivationOps<Self> for LibTorch<E> {
impl<E: TchElement, Q: QuantElement> ActivationOps<Self> for LibTorch<E, Q> {
fn relu<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu())
}

View File

@ -1,4 +1,4 @@
use burn_tensor::{QuantizationStrategy, Shape};
use burn_tensor::{quantization::QuantizationStrategy, Shape};
use tch::Scalar;
use crate::{LibTorchDevice, TchShape, TchTensor};

View File

@ -1,9 +1,9 @@
use super::TchOps;
use crate::{element::TchElement, LibTorch, LibTorchDevice, TchTensor};
use crate::{element::TchElement, LibTorch, LibTorchDevice, QuantElement, TchTensor};
use burn_tensor::{backend::Backend, ops::BoolTensorOps, Shape, TensorData};
use std::ops::Range;
impl<E: TchElement> BoolTensorOps<Self> for LibTorch<E> {
impl<E: TchElement, Q: QuantElement> BoolTensorOps<Self> for LibTorch<E, Q> {
fn bool_from_data<const D: usize>(
data: TensorData,
device: &LibTorchDevice,

View File

@ -2,11 +2,11 @@ use std::ops::Range;
use burn_tensor::{backend::Backend, ops::IntTensorOps, Distribution, Shape, TensorData};
use crate::{element::TchElement, LibTorch, LibTorchDevice, TchShape, TchTensor};
use crate::{element::TchElement, LibTorch, LibTorchDevice, QuantElement, TchShape, TchTensor};
use super::TchOps;
impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
impl<E: TchElement, Q: QuantElement> IntTensorOps<Self> for LibTorch<E, Q> {
fn int_from_data<const D: usize>(
data: TensorData,
device: &LibTorchDevice,

View File

@ -1,10 +1,10 @@
use crate::{element::TchElement, LibTorch, TchTensor};
use crate::{element::TchElement, LibTorch, QuantElement, TchTensor};
use burn_tensor::ops::{
ConvOptions, ConvTransposeOptions, InterpolateMode, InterpolateOptions, MaxPool1dWithIndices,
MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
};
impl<E: TchElement> ModuleOps<Self> for LibTorch<E> {
impl<E: TchElement, Q: QuantElement> ModuleOps<Self> for LibTorch<E, Q> {
fn embedding(weights: TchTensor<E, 2>, indices: TchTensor<i64, 2>) -> TchTensor<E, 3> {
let tensor = tch::Tensor::embedding(&weights.tensor, &indices.tensor, -1, false, false);
@ -231,7 +231,7 @@ impl<E: TchElement> ModuleOps<Self> for LibTorch<E> {
stride: usize,
padding: usize,
dilation: usize,
) -> MaxPool1dWithIndices<LibTorch<E>> {
) -> MaxPool1dWithIndices<LibTorch<E, Q>> {
let (tensor, indices) = tch::Tensor::max_pool1d_with_indices(
&x.tensor,
kernel_size as i64,
@ -269,7 +269,7 @@ impl<E: TchElement> ModuleOps<Self> for LibTorch<E> {
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> MaxPool2dWithIndices<LibTorch<E>> {
) -> MaxPool2dWithIndices<LibTorch<E, Q>> {
let (tensor, indices) = tch::Tensor::max_pool2d_with_indices(
&x.tensor,
[kernel_size[0] as i64, kernel_size[1] as i64],
@ -290,7 +290,7 @@ impl<E: TchElement> ModuleOps<Self> for LibTorch<E> {
dilation: [usize; 2],
output_grad: TchTensor<E, 4>,
indices: TchTensor<i64, 4>,
) -> MaxPool2dBackward<LibTorch<E>> {
) -> MaxPool2dBackward<LibTorch<E, Q>> {
let grad = tch::Tensor::max_pool2d_with_indices_backward(
&x.tensor,
&output_grad.tensor,

View File

@ -1,13 +1,17 @@
use burn_tensor::{
ops::{FloatTensor, QTensorOps, QuantizedTensor},
DType, Quantization, QuantizationStrategy, Shape, TensorData,
quantization::{
QTensorPrimitive, Quantization, QuantizationParametersPrimitive, QuantizationScheme,
QuantizationStrategy, QuantizationType,
},
DType, Shape, TensorData,
};
use crate::{LibTorch, LibTorchDevice, TchElement, TchShape, TchTensor};
use crate::{LibTorch, LibTorchDevice, QuantElement, TchElement, TchQTensor, TchShape, TchTensor};
use super::TchOps;
impl<E: TchElement> QTensorOps<Self> for LibTorch<E> {
impl<E: TchElement, Q: QuantElement> QTensorOps<Self> for LibTorch<E, Q> {
fn q_from_data<const D: usize>(
data: TensorData,
device: &LibTorchDevice,
@ -19,25 +23,27 @@ impl<E: TchElement> QTensorOps<Self> for LibTorch<E> {
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/quantized/Quantizer.cpp#L322
// So for now we have to load the dequantized values to quantize them back since the dequantization
// methods take the values provided when quantizing.
let tensor = match data.dtype {
let (tensor, scheme) = match data.dtype {
DType::QFloat(strategy) => match strategy {
QuantizationStrategy::PerTensorAffineInt8(q) => {
let values = q.dequantize(&data.iter::<i8>().collect::<Vec<_>>());
let tensor = tch::Tensor::from_slice(&values).to(device);
TchOps::<E>::quantize::<D, i8>(
let tensor = TchOps::<E>::quantize::<D, i8>(
TchTensor::new(tensor.reshape(shape_tch.dims)),
&strategy,
)
.tensor
.tensor;
(tensor, strategy.scheme())
}
QuantizationStrategy::PerTensorSymmetricInt8(q) => {
let values = q.dequantize(&data.iter::<i8>().collect::<Vec<_>>());
let tensor = tch::Tensor::from_slice(&values).to(device);
TchOps::<E>::quantize::<D, i8>(
let tensor = TchOps::<E>::quantize::<D, i8>(
TchTensor::new(tensor.reshape(shape_tch.dims)),
&strategy,
)
.tensor
.tensor;
(tensor, strategy.scheme())
}
},
_ => panic!(
@ -45,12 +51,16 @@ impl<E: TchElement> QTensorOps<Self> for LibTorch<E> {
data.dtype
),
};
TchTensor::new(tensor)
TchQTensor {
qtensor: TchTensor::new(tensor),
scheme,
}
}
fn quantize<const D: usize>(
tensor: FloatTensor<Self, D>,
strategy: &QuantizationStrategy,
scheme: &QuantizationScheme,
qparams: QuantizationParametersPrimitive<Self>,
) -> QuantizedTensor<Self, D> {
let mut tensor = tensor;
// Quantize only works on Float Tensor
@ -58,52 +68,58 @@ impl<E: TchElement> QTensorOps<Self> for LibTorch<E> {
tensor.tensor = tensor.tensor.to_kind(tch::Kind::Float);
}
match strategy {
QuantizationStrategy::PerTensorAffineInt8(ref q) => {
TchTensor::new(tensor.tensor.quantize_per_tensor(
q.scale.into(),
q.offset.into(),
let qtensor = match scheme {
QuantizationScheme::PerTensorAffine(dtype) => match dtype {
QuantizationType::QInt8 => tensor.tensor.quantize_per_tensor_tensor_qparams(
&qparams.scale.tensor,
&qparams.offset.unwrap().tensor,
tch::Kind::QInt8,
))
),
},
QuantizationScheme::PerTensorSymmetric(_) => {
tensor.tensor.quantize_per_tensor_tensor_qparams(
&qparams.scale.tensor,
&tch::Tensor::zeros_like(&qparams.scale.tensor),
tch::Kind::QInt8,
)
}
QuantizationStrategy::PerTensorSymmetricInt8(ref q) => TchTensor::new(
tensor
.tensor
.quantize_per_tensor(q.scale.into(), 0, tch::Kind::QInt8),
),
};
TchQTensor {
qtensor: TchTensor::new(qtensor),
scheme: scheme.clone(),
}
}
fn dequantize<const D: usize>(
tensor: QuantizedTensor<Self, D>,
_strategy: &QuantizationStrategy,
) -> FloatTensor<Self, D> {
TchTensor::new(tensor.tensor.dequantize().to_kind(E::KIND))
fn dequantize<const D: usize>(tensor: QuantizedTensor<Self, D>) -> FloatTensor<Self, D> {
TchTensor::new(tensor.qtensor.tensor.dequantize().to_kind(E::KIND))
}
fn q_shape<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Shape<D> {
tensor.shape()
tensor.qtensor.shape()
}
fn q_device<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> LibTorchDevice {
tensor.tensor.device().into()
tensor.qtensor.tensor.device().into()
}
fn q_reshape<const D1: usize, const D2: usize>(
tensor: QuantizedTensor<Self, D1>,
shape: Shape<D2>,
) -> QuantizedTensor<Self, D2> {
TchOps::reshape(tensor, shape)
TchQTensor {
qtensor: TchOps::reshape(tensor.qtensor, shape),
scheme: tensor.scheme,
}
}
async fn q_into_data<const D: usize>(
tensor: QuantizedTensor<Self, D>,
strategy: QuantizationStrategy,
) -> TensorData {
async fn q_into_data<const D: usize>(tensor: QuantizedTensor<Self, D>) -> TensorData {
let shape = Self::q_shape(&tensor);
let tensor = Self::q_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
let strategy = tensor.strategy();
// To get the integer values we have to call `int_repr()`
let values: Result<Vec<i8>, tch::TchError> = tensor.tensor.int_repr().try_into();
let values: Result<Vec<i8>, tch::TchError> = tensor.qtensor.tensor.int_repr().try_into();
TensorData::quantized(values.unwrap(), shape, strategy)
}

View File

@ -1,11 +1,11 @@
use super::TchOps;
use crate::{element::TchElement, LibTorch, LibTorchDevice, TchShape, TchTensor};
use crate::{element::TchElement, LibTorch, LibTorchDevice, QuantElement, TchShape, TchTensor};
use burn_tensor::{
backend::Backend, ops::FloatTensorOps, Distribution, ElementConversion, Shape, TensorData,
};
use std::ops::Range;
impl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {
impl<E: TchElement, Q: QuantElement> FloatTensorOps<Self> for LibTorch<E, Q> {
fn float_from_data<const D: usize>(
data: TensorData,
device: &LibTorchDevice,

View File

@ -1,5 +1,11 @@
use crate::{element::TchElement, LibTorch, LibTorchDevice};
use burn_tensor::{ops::FloatTensorOps, Element, Shape, TensorData};
use crate::{LibTorchDevice, QuantElement};
use burn_tensor::{
quantization::{
AffineQuantization, QTensorPrimitive, QuantizationScheme, QuantizationStrategy,
QuantizationType, SymmetricQuantization,
},
Element, Shape, TensorData,
};
use libc::c_void;
use std::{marker::PhantomData, sync::Arc};
@ -139,14 +145,6 @@ impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
}
}
impl<E: TchElement, const D: usize> std::ops::Add for TchTensor<E, D> {
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
LibTorch::float_add(self, rhs)
}
}
impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
pub(crate) fn shape(&self) -> Shape<D> {
Shape::from(self.tensor.size())
@ -314,9 +312,51 @@ impl<E: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> T
}
}
/// A quantized tensor for the tch backend.
#[derive(Clone, Debug)]
pub struct TchQTensor<Q: QuantElement, const D: usize> {
/// The quantized tensor.
pub qtensor: TchTensor<Q, D>,
/// The quantization scheme.
pub scheme: QuantizationScheme,
}
impl<Q: QuantElement, const D: usize> QTensorPrimitive for TchQTensor<Q, D> {
fn scheme(&self) -> &QuantizationScheme {
&self.scheme
}
fn strategy(&self) -> QuantizationStrategy {
match &self.scheme {
QuantizationScheme::PerTensorAffine(dtype) => match dtype {
QuantizationType::QInt8 => {
let scale = self.qtensor.tensor.q_scale();
let offset = self.qtensor.tensor.q_zero_point();
QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(
scale as f32,
offset as i8,
))
}
},
QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
QuantizationType::QInt8 => {
let scale = self.qtensor.tensor.q_scale();
QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(
scale as f32,
))
}
},
}
}
}
#[cfg(test)]
mod tests {
use crate::LibTorch;
use super::*;
use burn_tensor::ops::QTensorOps;
use burn_tensor::quantization::QuantizationParametersPrimitive;
use burn_tensor::{Distribution, Tensor, TensorPrimitive};
use rand::prelude::StdRng;
use rand::SeedableRng;
@ -376,4 +416,27 @@ mod tests {
tensor_1.to_data().as_slice::<f32>().unwrap()
);
}
#[test]
fn should_support_qtensor_strategy() {
let tensor = TchTensor::<f32, 1>::from_data(
TensorData::from([-1.8, -1.0, 0.0, 0.5]),
tch::Device::Cpu,
);
let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8);
let qparams = QuantizationParametersPrimitive {
scale: TchTensor::from_data(TensorData::from([0.009_019_608]), tch::Device::Cpu),
offset: Some(TchTensor::from_data(
TensorData::from([72]),
tch::Device::Cpu,
)),
};
let qtensor: TchQTensor<i8, 1> = LibTorch::quantize(tensor, &scheme, qparams);
assert_eq!(qtensor.scheme(), &scheme);
assert_eq!(
qtensor.strategy(),
QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.009_019_608, 72))
);
}
}

View File

@ -18,10 +18,7 @@ impl<const D: usize, B: AutodiffBackend> Tensor<B, D> {
TensorPrimitive::Float(tensor) => B::grad(tensor, grads)
.map(TensorPrimitive::Float)
.map(Tensor::new),
TensorPrimitive::QFloat {
tensor: _,
strategy: _,
} => B::grad(&self.primitive.clone().tensor(), grads)
TensorPrimitive::QFloat(_tensor) => B::grad(&self.primitive.clone().tensor(), grads)
.map(TensorPrimitive::Float)
.map(Tensor::new),
}
@ -33,12 +30,11 @@ impl<const D: usize, B: AutodiffBackend> Tensor<B, D> {
TensorPrimitive::Float(tensor) => B::grad_remove(tensor, grads)
.map(TensorPrimitive::Float)
.map(Tensor::new),
TensorPrimitive::QFloat {
tensor: _,
strategy: _,
} => B::grad_remove(&self.primitive.clone().tensor(), grads)
.map(TensorPrimitive::Float)
.map(Tensor::new),
TensorPrimitive::QFloat(_tensor) => {
B::grad_remove(&self.primitive.clone().tensor(), grads)
.map(TensorPrimitive::Float)
.map(Tensor::new)
}
}
}
@ -49,10 +45,7 @@ impl<const D: usize, B: AutodiffBackend> Tensor<B, D> {
TensorPrimitive::Float(tensor) => {
B::grad_replace(tensor, grads, grad.primitive.tensor())
}
TensorPrimitive::QFloat {
tensor: _,
strategy: _,
} => B::grad_replace(
TensorPrimitive::QFloat(_tensor) => B::grad_replace(
&self.primitive.clone().tensor(),
grads,
grad.primitive.tensor(),
@ -89,10 +82,7 @@ impl<B: AutodiffBackend> BasicAutodiffOps<B> for Float {
) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive<D> {
match tensor {
TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::inner(tensor)),
TensorPrimitive::QFloat { tensor, strategy } => TensorPrimitive::QFloat {
tensor: B::q_inner(tensor),
strategy,
},
TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_inner(tensor)),
}
}
@ -101,10 +91,7 @@ impl<B: AutodiffBackend> BasicAutodiffOps<B> for Float {
) -> <Self as TensorKind<B>>::Primitive<D> {
match inner {
TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::from_inner(tensor)),
TensorPrimitive::QFloat { tensor, strategy } => TensorPrimitive::QFloat {
tensor: B::q_from_inner(tensor),
strategy,
},
TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_from_inner(tensor)),
}
}
}

View File

@ -1687,10 +1687,7 @@ impl<B: Backend> BasicOps<B> for Float {
fn shape<const D: usize>(tensor: &Self::Primitive<D>) -> Shape<D> {
match tensor {
TensorPrimitive::Float(tensor) => B::float_shape(tensor),
TensorPrimitive::QFloat {
tensor,
strategy: _,
} => B::q_shape(tensor),
TensorPrimitive::QFloat(tensor) => B::q_shape(tensor),
}
}
@ -1702,10 +1699,7 @@ impl<B: Backend> BasicOps<B> for Float {
TensorPrimitive::Float(tensor) => {
TensorPrimitive::Float(B::float_reshape(tensor, shape))
}
TensorPrimitive::QFloat { tensor, strategy } => TensorPrimitive::QFloat {
tensor: B::q_reshape(tensor, shape),
strategy,
},
TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_reshape(tensor, shape)),
}
}
@ -1744,10 +1738,7 @@ impl<B: Backend> BasicOps<B> for Float {
fn device<const D: usize>(tensor: &Self::Primitive<D>) -> <B as Backend>::Device {
match tensor {
TensorPrimitive::Float(tensor) => B::float_device(tensor),
TensorPrimitive::QFloat {
tensor,
strategy: _,
} => B::q_device(tensor),
TensorPrimitive::QFloat(tensor) => B::q_device(tensor),
}
}
@ -1761,16 +1752,13 @@ impl<B: Backend> BasicOps<B> for Float {
async fn into_data_async<const D: usize>(tensor: Self::Primitive<D>) -> TensorData {
match tensor {
TensorPrimitive::Float(tensor) => B::float_into_data(tensor).await,
TensorPrimitive::QFloat { tensor, strategy } => B::q_into_data(tensor, strategy).await,
TensorPrimitive::QFloat(tensor) => B::q_into_data(tensor).await,
}
}
fn from_data<const D: usize>(data: TensorData, device: &B::Device) -> Self::Primitive<D> {
match data.dtype {
DType::QFloat(strategy) => TensorPrimitive::QFloat {
tensor: B::q_from_data(data, device),
strategy,
},
DType::QFloat(_strategy) => TensorPrimitive::QFloat(B::q_from_data(data, device)),
_ => TensorPrimitive::Float(B::float_from_data(data, device)),
}
}

View File

@ -1,13 +1,14 @@
use alloc::vec::Vec;
use core::convert::TryInto;
use crate::check;
use crate::check::TensorCheck;
use crate::ops::FullPrecisionBackend;
use crate::quantization::{QuantizationParameters, QuantizationScheme};
use crate::tensor::backend::Backend;
use crate::tensor::stats;
use crate::tensor::{Distribution, Shape, TensorData};
use crate::Tensor;
use crate::{check, QuantizationStrategy};
use crate::{Int, TensorPrimitive};
impl<const D: usize, B> Tensor<B, D>
@ -270,10 +271,7 @@ where
pub fn is_require_grad(&self) -> bool {
match &self.primitive {
TensorPrimitive::Float(tensor) => B::float_is_require_grad(tensor),
TensorPrimitive::QFloat {
tensor,
strategy: _,
} => B::q_is_require_grad(tensor),
TensorPrimitive::QFloat(tensor) => B::q_is_require_grad(tensor),
}
}
@ -286,10 +284,9 @@ where
TensorPrimitive::Float(tensor) => {
TensorPrimitive::Float(B::float_set_require_grad(tensor, require_grad))
}
TensorPrimitive::QFloat { tensor, strategy } => TensorPrimitive::QFloat {
tensor: B::q_set_require_grad(tensor, require_grad),
strategy,
},
TensorPrimitive::QFloat(tensor) => {
TensorPrimitive::QFloat(B::q_set_require_grad(tensor, require_grad))
}
};
Self::new(primitive)
}
@ -315,20 +312,26 @@ where
.div_scalar(n as f32 - correction_factor as f32)
}
/// Convert the tensor to a lower precision data type based on the quantization strategy.
/// Convert the tensor to a lower precision data type based on the quantization scheme.
///
/// # Arguments
///
/// * `strategy` - The quantization strategy.
/// * `scheme` - The quantization scheme.
/// * `qparams` - The pre-computed quantization parameters.
///
/// # Returns
///
/// The quantized tensor.
pub fn quantize(self, strategy: QuantizationStrategy) -> Tensor<B, D> {
Tensor::new(TensorPrimitive::QFloat {
tensor: B::quantize(self.primitive.tensor(), &strategy),
strategy,
})
pub fn quantize(
self,
scheme: &QuantizationScheme,
qparams: QuantizationParameters<B>,
) -> Tensor<B, D> {
Tensor::new(TensorPrimitive::QFloat(B::quantize(
self.primitive.tensor(),
scheme,
qparams.into(),
)))
}
/// Convert the tensor back to a higher precision data type.

View File

@ -1,4 +1,4 @@
use crate::{backend::Backend, QuantizationStrategy};
use crate::backend::Backend;
/// A type-level representation of the kind of a float tensor
#[derive(Clone, Debug)]
@ -18,19 +18,14 @@ pub enum TensorPrimitive<B: Backend, const D: usize> {
/// Float tensor primitive.
Float(B::FloatTensorPrimitive<D>),
/// Quantized float tensor primitive.
QFloat {
/// The underlying quantized tensor.
tensor: B::QuantizedTensorPrimitive<D>,
/// The tensor quantization strategy.
strategy: QuantizationStrategy,
},
QFloat(B::QuantizedTensorPrimitive<D>),
}
impl<B: Backend, const D: usize> TensorPrimitive<B, D> {
/// Returns the full tensor representation.
pub fn tensor(self) -> B::FloatTensorPrimitive<D> {
match self {
Self::QFloat { tensor, strategy } => B::dequantize(tensor, &strategy),
Self::QFloat(tensor) => B::dequantize(tensor),
Self::Float(tensor) => tensor,
}
}

View File

@ -1,8 +1,8 @@
use alloc::string::String;
pub use burn_common::sync_type::SyncType;
use crate::ops::*;
use crate::tensor::Element;
use crate::{ops::*, quantization::QTensorPrimitive};
use super::{BackendBridge, DeviceOps};
@ -87,7 +87,11 @@ pub trait Backend:
type BoolTensorPrimitive<const D: usize>: Clone + Send + 'static + core::fmt::Debug;
/// Tensor primitive to be used for all quantized operations.
type QuantizedTensorPrimitive<const D: usize>: Clone + Send + 'static + core::fmt::Debug;
type QuantizedTensorPrimitive<const D: usize>: QTensorPrimitive
+ Clone
+ Send
+ 'static
+ core::fmt::Debug;
/// If autodiff is enabled.
fn ad_enabled() -> bool {

View File

@ -7,8 +7,9 @@ use alloc::vec::Vec;
use half::{bf16, f16};
use crate::{
tensor::Shape, DType, Distribution, Element, ElementConversion, Quantization,
QuantizationStrategy,
quantization::{Quantization, QuantizationStrategy},
tensor::Shape,
DType, Distribution, Element, ElementConversion,
};
use num_traits::pow::Pow;
@ -66,10 +67,14 @@ impl TensorData {
}
}
fn try_as_slice<E: Element>(&self) -> Result<&[E], DataError> {
bytemuck::checked::try_cast_slice(&self.bytes).map_err(DataError::CastError)
}
/// Returns the immutable slice view of the tensor data.
pub fn as_slice<E: Element>(&self) -> Result<&[E], DataError> {
if E::dtype() == self.dtype {
bytemuck::checked::try_cast_slice(&self.bytes).map_err(DataError::CastError)
self.try_as_slice()
} else {
Err(DataError::TypeMismatch(format!(
"Invalid target element type (expected {:?}, got {:?})",
@ -598,10 +603,10 @@ impl core::fmt::Display for TensorData {
DType::Bool => format!("{:?}", self.as_slice::<bool>().unwrap()),
DType::QFloat(q) => match &q {
QuantizationStrategy::PerTensorAffineInt8(_) => {
format!("{:?} {q:?}", self.as_slice::<i8>().unwrap())
format!("{:?} {q:?}", self.try_as_slice::<i8>().unwrap())
}
QuantizationStrategy::PerTensorSymmetricInt8(_) => {
format!("{:?} {q:?}", self.as_slice::<i8>().unwrap())
format!("{:?} {q:?}", self.try_as_slice::<i8>().unwrap())
}
},
};

View File

@ -1,6 +1,6 @@
use core::cmp::Ordering;
use crate::{cast::ToElement, Distribution, QuantizationStrategy};
use crate::{cast::ToElement, quantization::QuantizationStrategy, Distribution};
use half::{bf16, f16};
use rand::RngCore;
use serde::{Deserialize, Serialize};

View File

@ -4,14 +4,12 @@ mod api;
mod data;
mod distribution;
mod element;
mod quantization_strategy;
mod shape;
pub use api::*;
pub use data::*;
pub use distribution::*;
pub use element::*;
pub use quantization_strategy::*;
pub use shape::*;
/// The activation module.
@ -32,6 +30,9 @@ pub mod module;
/// Operations on tensors module.
pub mod ops;
/// Tensor quantization module.
pub mod quantization;
#[cfg(feature = "experimental-named-tensor")]
mod named;
#[cfg(feature = "experimental-named-tensor")]

View File

@ -1,6 +1,10 @@
use core::future::Future;
use crate::{backend::Backend, Device, QuantizationStrategy, Shape, TensorData};
use crate::{
backend::Backend,
quantization::{QuantizationParametersPrimitive, QuantizationScheme},
Device, Shape, TensorData,
};
use super::{FloatTensor, QuantizedTensor};
@ -19,17 +23,15 @@ pub trait QTensorOps<B: Backend> {
/// The tensor with the given data.
fn q_from_data<const D: usize>(data: TensorData, device: &Device<B>) -> QuantizedTensor<B, D>;
/// Convert the tensor to a lower precision data type based on the quantization strategy.
/// Convert the tensor to a lower precision data type based on the quantization scheme and parameters.
fn quantize<const D: usize>(
tensor: FloatTensor<B, D>,
strategy: &QuantizationStrategy,
scheme: &QuantizationScheme,
qparams: QuantizationParametersPrimitive<B>,
) -> QuantizedTensor<B, D>;
/// Convert the tensor back to a higher precision data type based on the quantization strategy.
fn dequantize<const D: usize>(
tensor: QuantizedTensor<B, D>,
strategy: &QuantizationStrategy,
) -> FloatTensor<B, D>;
/// Convert the tensor back to a higher precision data type.
fn dequantize<const D: usize>(tensor: QuantizedTensor<B, D>) -> FloatTensor<B, D>;
/// Gets the shape of the tensor.
///
@ -79,7 +81,6 @@ pub trait QTensorOps<B: Backend> {
/// The data structure with the tensor's data.
fn q_into_data<const D: usize>(
tensor: QuantizedTensor<B, D>,
strategy: QuantizationStrategy,
) -> impl Future<Output = TensorData> + Send;
/// Sets the `require_grad` flag of a tensor.

View File

@ -0,0 +1,34 @@
use crate::{backend::Backend, Tensor};
/// The observed input calibration range.
#[derive(Clone, Debug)]
pub struct CalibrationRange<B: Backend> {
/// Minimum observed value.
pub min: Tensor<B, 1>,
/// Maximum observed value.
pub max: Tensor<B, 1>,
}
/// Calibration method used to compute the quantization range mapping.
pub trait Calibration {
/// Compute the input tensor range.
fn compute_range<B: Backend, const D: usize>(
&self,
tensor: &Tensor<B, D>,
) -> CalibrationRange<B>;
}
/// Computes the per-tensor quantization range mapping based on the min and max values.
pub struct MinMaxCalibration {}
impl Calibration for MinMaxCalibration {
fn compute_range<B: Backend, const D: usize>(
&self,
tensor: &Tensor<B, D>,
) -> CalibrationRange<B> {
let min = tensor.clone().min();
let max = tensor.clone().max();
CalibrationRange { min, max }
}
}
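The `Calibration` trait is the extension point for other range estimators. A hedged sketch of a hypothetical calibration that clamps the observed range to fixed bounds (the `ClampedCalibration` type and its `bound` field are illustrative, not part of this commit):

```rust , ignore
use burn_tensor::{
    backend::Backend,
    quantization::{Calibration, CalibrationRange},
    Tensor,
};

/// Hypothetical calibration that clamps the observed range to `[-bound, bound]`,
/// e.g. to ignore extreme outliers.
struct ClampedCalibration {
    bound: f32,
}

impl Calibration for ClampedCalibration {
    fn compute_range<B: Backend, const D: usize>(
        &self,
        tensor: &Tensor<B, D>,
    ) -> CalibrationRange<B> {
        // Same min/max statistics as MinMaxCalibration, but clipped to the bound.
        let min = tensor.clone().min().clamp_min(-self.bound);
        let max = tensor.clone().max().clamp_max(self.bound);
        CalibrationRange { min, max }
    }
}
```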

View File

@ -0,0 +1,11 @@
mod calibration;
mod parameters;
mod primitive;
mod scheme;
mod strategy;
pub use calibration::*;
pub use parameters::*;
pub use primitive::*;
pub use scheme::*;
pub use strategy::*;

View File

@ -0,0 +1,35 @@
use crate::{backend::Backend, Int, Tensor};
/// The quantization parameters.
#[derive(Clone, Debug)]
pub struct QuantizationParameters<B: Backend> {
/// The scaling factor.
pub scale: Tensor<B, 1>,
/// The zero-point offset.
pub offset: Option<Tensor<B, 1, Int>>,
}
/// The quantization parameters primitive.
///
/// # Remarks
///
/// This is a low-level struct used internally by the library to provide the quantization parameters
/// to the backends. It is not designed for direct use by users, and it is not recommended to
/// import or use this struct directly.
///
/// Users should prefer the [QuantizationParameters] struct, which is designed for public use.
pub struct QuantizationParametersPrimitive<B: Backend> {
/// The scaling factor.
pub scale: B::FloatTensorPrimitive<1>,
/// The zero-point offset.
pub offset: Option<B::IntTensorPrimitive<1>>,
}
impl<B: Backend> From<QuantizationParameters<B>> for QuantizationParametersPrimitive<B> {
fn from(value: QuantizationParameters<B>) -> Self {
QuantizationParametersPrimitive {
scale: value.scale.primitive.tensor(),
offset: value.offset.map(|x| x.primitive),
}
}
}
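A minimal sketch of how the conversion is meant to be used: the public struct is built from regular tensors and turned into the backend primitive via `From`/`Into` right before the backend op is called (the `to_primitive` helper is hypothetical):

```rust , ignore
use burn_tensor::{
    backend::Backend,
    quantization::{QuantizationParameters, QuantizationParametersPrimitive},
    Int, Tensor,
};

/// Hypothetical helper: convert user-facing parameters into the low-level
/// primitive handed to the backend quantization op.
fn to_primitive<B: Backend>(
    scale: Tensor<B, 1>,
    offset: Option<Tensor<B, 1, Int>>,
) -> QuantizationParametersPrimitive<B> {
    QuantizationParameters { scale, offset }.into()
}
```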

View File

@ -0,0 +1,13 @@
use super::{QuantizationScheme, QuantizationStrategy};
/// Quantized tensor primitive.
pub trait QTensorPrimitive {
/// Returns the quantization scheme for the given tensor.
fn scheme(&self) -> &QuantizationScheme;
/// Returns the quantization strategy for the given tensor.
///
/// # Remarks
/// Retrieving the quantization strategy with its corresponding parameters might require
/// synchronization on the backend.
fn strategy(&self) -> QuantizationStrategy;
}
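A hedged sketch of what a backend-side quantized tensor might look like for this trait, assuming a per-tensor symmetric int8 tensor whose scale has already been synced to the host (`MyQTensor` is illustrative, not a real backend type):

```rust , ignore
use burn_tensor::quantization::{
    QTensorPrimitive, QuantizationScheme, QuantizationStrategy, QuantizationType,
    SymmetricQuantization,
};

/// Hypothetical quantized tensor handle; real backends may need to sync the
/// scale (and offset) from device memory before building the strategy.
struct MyQTensor {
    values: Vec<i8>,
    scale: f32,
    scheme: QuantizationScheme,
}

impl QTensorPrimitive for MyQTensor {
    fn scheme(&self) -> &QuantizationScheme {
        &self.scheme
    }

    fn strategy(&self) -> QuantizationStrategy {
        // Per-tensor symmetric int8: the strategy is fully described by the scale.
        QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(self.scale))
    }
}
```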

View File

@ -0,0 +1,71 @@
use crate::{backend::Backend, Int, Tensor};
use super::{CalibrationRange, QuantizationParameters};
/// Quantization data type.
#[derive(Clone, Debug, PartialEq)]
pub enum QuantizationType {
/// 8-bit signed integer.
QInt8,
}
/// Quantization scheme.
#[derive(Clone, Debug, PartialEq)]
pub enum QuantizationScheme {
/// Per-tensor affine/asymmetric quantization.
PerTensorAffine(QuantizationType),
/// Per-tensor symmetric quantization.
PerTensorSymmetric(QuantizationType),
// /// Per-channel affine/asymmetric quantization.
// PerChannelAffine,
// /// Per-channel symmetric quantization.
// PerChannelSymmetric,
}
/// Round the tensor to the nearest integer.
fn round<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> Tensor<B, D, Int> {
tensor.add_scalar(0.5).int()
}
impl QuantizationScheme {
/// Compute the quantization parameters.
pub fn compute_q_params<B: Backend>(
&self,
range: CalibrationRange<B>,
) -> QuantizationParameters<B> {
match self {
QuantizationScheme::PerTensorAffine(dtype) => match dtype {
QuantizationType::QInt8 => {
// Quantized range `[a, b]`
let a = i8::MIN as i32;
let b = i8::MAX as i32;
// Input range `[alpha, beta]`
let input_range = range.max.clone().sub(range.min.clone());
QuantizationParameters {
scale: input_range.clone().div_scalar(b - a),
offset: Some(round(
(range.max.mul_scalar(a) - range.min.mul_scalar(b)).div(input_range),
)),
}
}
},
QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
QuantizationType::QInt8 => {
// Quantized range `[a, b]`
let b = i8::MAX as i32;
let a = -b;
// Compute scale to convert an input value in range `[-alpha, alpha]` to the quantized range
let values_range = range.min.abs().max_pair(range.max.abs()).mul_scalar(2);
QuantizationParameters {
scale: values_range.div_scalar(b - a),
offset: None,
}
}
},
}
}
}
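For the calibration range `[alpha, beta] = [-1.8, 0.5]` exercised by the scheme tests below, the same arithmetic works out as follows. This is a host-side restatement of the formulas above, not backend tensor code:

```rust , ignore
let (alpha, beta) = (-1.8f32, 0.5f32);

// Per-tensor affine int8: quantized range [a, b] = [-128, 127].
let (a, b) = (i8::MIN as f32, i8::MAX as f32);
let scale = (beta - alpha) / (b - a); // 2.3 / 255 ~= 0.009_019_608
let offset = ((beta * a - alpha * b) / (beta - alpha)).round(); // 164.6 / 2.3 ~= 71.57 -> 72

// Per-tensor symmetric int8: quantized range [-127, 127], no offset.
let sym_scale = 2.0 * alpha.abs().max(beta.abs()) / (127.0 - (-127.0)); // 3.6 / 254 ~= 0.014_173_228
```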

View File

@ -8,6 +8,10 @@ use burn_common::{iter_par, run_par};
use num_traits::{Float, PrimInt};
use serde::{Deserialize, Serialize};
use super::{QuantizationScheme, QuantizationType};
// NOTE: QuantizationStrategy is used for TensorData (sync).
/// Quantization strategy.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationStrategy {
@ -17,6 +21,20 @@ pub enum QuantizationStrategy {
PerTensorSymmetricInt8(SymmetricQuantization<f32, i8>),
}
impl QuantizationStrategy {
/// Returns the corresponding quantization scheme.
pub fn scheme(&self) -> QuantizationScheme {
match self {
QuantizationStrategy::PerTensorAffineInt8(_) => {
QuantizationScheme::PerTensorAffine(QuantizationType::QInt8)
}
QuantizationStrategy::PerTensorSymmetricInt8(_) => {
QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8)
}
}
}
}
/// Quantization scheme to convert elements of a higher precision data type `E` to a lower precision
/// data type `Q` and vice-versa.
pub trait Quantization<E: Float, Q: PrimInt> {
@ -41,6 +59,17 @@ pub struct AffineQuantization<E: Float, Q: PrimInt, A: PrimInt> {
_a: PhantomData<A>,
}
impl<E: Float, Q: PrimInt, A: PrimInt> AffineQuantization<E, Q, A> {
/// Initialize an affine quantization scheme with the given parameters.
pub fn init(scale: E, offset: Q) -> Self {
Self {
scale,
offset,
_a: PhantomData,
}
}
}
impl<E: Float, Q: PrimInt, A: PrimInt> Quantization<E, Q> for AffineQuantization<E, Q, A> {
fn new(alpha: E, beta: E) -> Self {
// Q range `[a, b]`
@ -49,11 +78,10 @@ impl<E: Float, Q: PrimInt, A: PrimInt> Quantization<E, Q> for AffineQuantization
// Compute scale and offset to convert a floating point value in range `[alpha, beta]` to the quantized range
let range = beta - alpha;
Self {
scale: range / (b - a),
offset: Q::from(E::round(((beta * a) - (alpha * b)) / range)).unwrap(),
_a: PhantomData,
}
Self::init(
range / (b - a),
Q::from(E::round(((beta * a) - (alpha * b)) / range)).unwrap(),
)
}
fn quantize(&self, values: &[E]) -> Vec<Q> {
@ -97,6 +125,16 @@ pub struct SymmetricQuantization<E: Float, Q: PrimInt> {
_q: PhantomData<Q>,
}
impl<E: Float, Q: PrimInt> SymmetricQuantization<E, Q> {
/// Initialize a symmetric quantization scheme with the given parameters.
pub fn init(scale: E) -> Self {
Self {
scale,
_q: PhantomData,
}
}
}
impl<E: Float, Q: PrimInt> Quantization<E, Q> for SymmetricQuantization<E, Q> {
fn new(alpha: E, beta: E) -> Self {
assert!(
@ -110,10 +148,7 @@ impl<E: Float, Q: PrimInt> Quantization<E, Q> for SymmetricQuantization<E, Q> {
// Compute scale to convert a floating point value in range `[-alpha, alpha]` to the quantized range
let alpha = alpha.abs().max(beta.abs());
Self {
scale: (alpha + alpha) / (b - a),
_q: PhantomData,
}
Self::init((alpha + alpha) / (b - a))
}
fn quantize(&self, values: &[E]) -> Vec<Q> {
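The new `init` constructors pair with `QuantizationStrategy::scheme` to rebuild a strategy from parameters that are already known, for example values synced back from the backend, instead of recomputing them from an input range with `new`. A minimal sketch (the scale literal is illustrative):

```rust , ignore
use burn_tensor::quantization::{
    QuantizationScheme, QuantizationStrategy, QuantizationType, SymmetricQuantization,
};

// Build the strategy directly from a known scale, then recover its scheme.
let strategy =
    QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.014_173_228));
assert_eq!(
    strategy.scheme(),
    QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8)
);
```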

View File

@ -2,6 +2,7 @@ mod activation;
mod clone_invariance;
mod module;
mod ops;
mod quantization;
mod stats;
#[allow(missing_docs)]
@ -113,5 +114,9 @@ macro_rules! testgen_all {
// test padding
burn_tensor::testgen_padding!();
// test quantization
burn_tensor::testgen_calibration!();
burn_tensor::testgen_scheme!();
};
}

View File

@ -0,0 +1,26 @@
#[burn_tensor_testgen::testgen(calibration)]
mod tests {
use super::*;
use burn_tensor::{
quantization::{Calibration, MinMaxCalibration, QuantizationType},
Tensor, TensorData,
};
#[test]
fn min_max_calibration_range() {
let tensor =
Tensor::<TestBackend, 1>::from_floats([-1.8, -1.0, 0.0, 0.5], &Default::default());
let calibration = MinMaxCalibration {};
let range = calibration.compute_range(&tensor);
range
.min
.into_data()
.assert_eq(&TensorData::from([-1.8]), false);
range
.max
.into_data()
.assert_eq(&TensorData::from([0.5]), false);
}
}

View File

@ -0,0 +1,2 @@
mod calibration;
mod scheme;

View File

@ -0,0 +1,48 @@
#[burn_tensor_testgen::testgen(scheme)]
mod tests {
use super::*;
use burn_tensor::{
quantization::{CalibrationRange, QuantizationScheme, QuantizationType},
Tensor, TensorData,
};
#[test]
fn per_tensor_affine_int8() {
let device = Default::default();
let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8);
let range = CalibrationRange {
min: Tensor::<TestBackend, 1>::from_floats([-1.8], &device),
max: Tensor::<TestBackend, 1>::from_floats([0.5], &device),
};
let qparams = scheme.compute_q_params(range);
qparams
.scale
.into_data()
.assert_approx_eq(&TensorData::from([0.009_019_608]), 9);
qparams
.offset
.unwrap()
.into_data()
.assert_eq(&TensorData::from([72]), false);
}
#[test]
fn per_tensor_symmetric_int8() {
let device = Default::default();
let scheme = QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8);
let range = CalibrationRange {
min: Tensor::<TestBackend, 1>::from_floats([-1.8], &device),
max: Tensor::<TestBackend, 1>::from_floats([0.5], &device),
};
let qparams = scheme.compute_q_params(range);
qparams
.scale
.into_data()
.assert_approx_eq(&TensorData::from([0.014_173_228]), 9);
assert!(qparams.offset.is_none());
}
}