mirror of https://github.com/tracel-ai/burn.git
Refactor tensor quantization for q_* ops (#2025)
* Move QuantizationScheme to burn-tensor
* Refactor QuantizedTensorPrimitive to include the quantization strategy
* Fix QFloat tensor data display
* Refactor quantization methods to use scheme and qparams (on backend device)
* Fix clippy
* Fix fmt
* Add qtensor primitive tests
This commit is contained in:
parent 3204cbe345
commit 0d5025edbb
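In short, quantization is now driven by a `QuantizationScheme` plus quantization parameters computed on the backend device, instead of a single eager `QuantizationStrategy` value. A minimal sketch of the resulting user-facing flow, assuming a `QuantizationParameters` struct that mirrors the backend-level `QuantizationParametersPrimitive { scale, offset }` used throughout this diff (the exact user-facing names are an assumption, not shown here):

```rust
use burn::tensor::quantization::{QuantizationScheme, QuantizationType};

// Illustrative only: `tensor` and `device` are assumed to be in scope.
// The scheme says *how* to quantize; the qparams hold the calibrated
// scale (and offset, for affine schemes) as tensors on the device.
let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8);
let qparams = QuantizationParameters {
    scale: Tensor::from_floats([0.009_019_608], &device),
    offset: Some(Tensor::from_ints([72], &device)),
};
let quantized = tensor.quantize(&scheme, qparams); // was: tensor.quantize(strategy)
let restored = quantized.dequantize(); // dequantize no longer takes a strategy
```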
@@ -307,7 +307,7 @@ Those operations are only available for `Float` tensors on backends that impleme

 | Burn API                             | PyTorch Equivalent              |
 | ------------------------------------ | ------------------------------- |
-| `tensor.quantize(strategy)`          | N/A                             |
+| `tensor.quantize(scheme, qparams)`   | N/A                             |
 | `tensor.dequantize()`                | N/A                             |

 ## Activation Functions
@@ -44,13 +44,13 @@ tensors and can collect their statistics, such as the min and max value when usi
 `MinMaxCalibration`, to compute the quantization parameters.

 ```rust , ignore
-# use burn::quantization::{MinMaxCalibration, QuantizationScheme, QuantizationType, Quantizer};
+# use burn::module::Quantizer;
+# use burn::tensor::quantization::{MinMaxCalibration, QuantizationScheme, QuantizationType};
 #
 // Quantization config
 let mut quantizer = Quantizer {
-    calibration: MinMaxCalibration {
-        scheme: QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8),
-    },
+    calibration: MinMaxCalibration {},
+    scheme: QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8),
 };

 // Quantize the weights
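Since `Quantizer` implements `ModuleMapper` (see the `module/quantize.rs` hunk further down), applying the configured quantizer to a model is just a module mapping. A hedged usage sketch, with `model` standing in for any burn module:

```rust
// `map` visits every float parameter and runs `Quantizer::map_float`,
// which calibrates the tensor and quantizes it with the configured scheme.
let model = model.map(&mut quantizer);
```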
@@ -1,7 +1,8 @@
 use burn_tensor::{
     backend::Backend,
     ops::{FloatTensor, QTensorOps, QuantizedTensor},
-    Device, QuantizationStrategy, Shape, TensorData,
+    quantization::{QuantizationParametersPrimitive, QuantizationScheme},
+    Device, Shape, TensorData,
 };

 use crate::{checkpoint::strategy::CheckpointStrategy, Autodiff};
@@ -16,15 +17,13 @@ impl<B: Backend, C: CheckpointStrategy> QTensorOps<Self> for Autodiff<B, C> {

     fn quantize<const D: usize>(
         _tensor: FloatTensor<Self, D>,
-        _strategy: &QuantizationStrategy,
+        _scheme: &QuantizationScheme,
+        _qparams: QuantizationParametersPrimitive<Self>,
     ) -> QuantizedTensor<Self, D> {
         todo!() // required for QAT
     }

-    fn dequantize<const D: usize>(
-        _tensor: QuantizedTensor<Self, D>,
-        _strategy: &QuantizationStrategy,
-    ) -> FloatTensor<Self, D> {
+    fn dequantize<const D: usize>(_tensor: QuantizedTensor<Self, D>) -> FloatTensor<Self, D> {
         todo!()
     }

@@ -43,10 +42,7 @@ impl<B: Backend, C: CheckpointStrategy> QTensorOps<Self> for Autodiff<B, C> {
         B::q_reshape(tensor, shape)
     }

-    async fn q_into_data<const D: usize>(
-        tensor: QuantizedTensor<Self, D>,
-        strategy: QuantizationStrategy,
-    ) -> TensorData {
-        B::q_into_data(tensor, strategy).await
+    async fn q_into_data<const D: usize>(tensor: QuantizedTensor<Self, D>) -> TensorData {
+        B::q_into_data(tensor).await
     }
 }
@@ -2,13 +2,14 @@ use std::marker::PhantomData;

 use burn_tensor::{
     backend::{Backend, DeviceId, DeviceOps, SyncType},
+    quantization::{QTensorPrimitive, QuantizationStrategy},
     Device,
 };
 use candle_core::DeviceLocation;

 use crate::{
     element::{CandleElement, FloatCandleElement, IntCandleElement},
-    CandleTensor, PrecisionBridge,
+    CandleQTensor, CandleTensor, PrecisionBridge,
 };

 /// Tensor backend that uses the [candle](candle_core) crate for executing tensor operations.
@@ -92,8 +93,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {

     type BoolTensorPrimitive<const D: usize> = CandleTensor<u8, D>;

-    // NOTE: candle does not implement `WithDType` for i8
-    type QuantizedTensorPrimitive<const D: usize> = CandleTensor<u8, D>;
+    type QuantizedTensorPrimitive<const D: usize> = CandleQTensor<D>;

     fn ad_enabled() -> bool {
         false
@@ -1,12 +1,13 @@
 use burn_tensor::{
     backend::Backend,
     ops::{FloatTensor, QTensorOps, QuantizedTensor},
-    DType, Device, QuantizationStrategy, Shape, TensorData,
+    quantization::{QuantizationParametersPrimitive, QuantizationScheme, QuantizationStrategy},
+    DType, Device, Shape, TensorData,
 };

 use crate::{
     element::{FloatCandleElement, IntCandleElement},
-    Candle,
+    Candle, CandleQTensor,
 };

 impl<F: FloatCandleElement, I: IntCandleElement> QTensorOps<Self> for Candle<F, I> {
@@ -19,37 +20,35 @@ impl<F: FloatCandleElement, I: IntCandleElement> QTensorOps<Self> for Candle<F,

     fn quantize<const D: usize>(
         _tensor: FloatTensor<Self, D>,
-        _strategy: &QuantizationStrategy,
+        _scheme: &QuantizationScheme,
+        _qparams: QuantizationParametersPrimitive<Self>,
     ) -> QuantizedTensor<Self, D> {
         unimplemented!()
     }

-    fn dequantize<const D: usize>(
-        _tensor: QuantizedTensor<Self, D>,
-        _strategy: &QuantizationStrategy,
-    ) -> FloatTensor<Self, D> {
+    fn dequantize<const D: usize>(_tensor: QuantizedTensor<Self, D>) -> FloatTensor<Self, D> {
         unimplemented!()
     }

     fn q_shape<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Shape<D> {
-        super::base::shape(tensor)
+        super::base::shape(&tensor.qtensor)
     }

     fn q_device<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Device<Self> {
-        super::base::device(tensor)
+        super::base::device(&tensor.qtensor)
     }

     fn q_reshape<const D1: usize, const D2: usize>(
         tensor: QuantizedTensor<Self, D1>,
         shape: Shape<D2>,
     ) -> QuantizedTensor<Self, D2> {
-        super::base::reshape(tensor, shape)
+        CandleQTensor {
+            qtensor: super::base::reshape(tensor.qtensor, shape),
+            scheme: tensor.scheme,
+        }
     }

-    async fn q_into_data<const D: usize>(
-        tensor: QuantizedTensor<Self, D>,
-        strategy: QuantizationStrategy,
-    ) -> TensorData {
-        super::base::into_data(tensor)
+    async fn q_into_data<const D: usize>(tensor: QuantizedTensor<Self, D>) -> TensorData {
+        unimplemented!()
     }
 }
@@ -1,6 +1,9 @@
 use std::marker::PhantomData;

-use burn_tensor::{Element, Shape, TensorData};
+use burn_tensor::{
+    quantization::{QTensorPrimitive, QuantizationScheme, QuantizationStrategy},
+    Element, Shape, TensorData,
+};

 use crate::{element::CandleElement, CandleDevice};

@@ -45,3 +48,23 @@ impl<E: CandleElement, const D: usize> CandleTensor<E, D> {
         Shape::from(x)
     }
 }
+
+/// A quantized tensor for the candle backend.
+#[derive(Clone, Debug)]
+pub struct CandleQTensor<const D: usize> {
+    /// The quantized tensor.
+    // NOTE: candle does not implement `WithDType` for i8
+    pub qtensor: CandleTensor<u8, D>,
+    /// The quantization scheme.
+    pub scheme: QuantizationScheme,
+}
+
+impl<const D: usize> QTensorPrimitive for CandleQTensor<D> {
+    fn scheme(&self) -> &QuantizationScheme {
+        &self.scheme
+    }
+
+    fn strategy(&self) -> QuantizationStrategy {
+        todo!()
+    }
+}
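Each backend now wraps its raw storage in a `*QTensor` struct (here `CandleQTensor`) that implements `QTensorPrimitive`. The trait definition itself lives in `burn_tensor::quantization` and is not part of this diff; reconstructed from the impls in this commit, its shape is roughly:

```rust
// Inferred from the impls below; the real definition may carry more items.
pub trait QTensorPrimitive {
    /// The quantization scheme attached to this tensor.
    fn scheme(&self) -> &QuantizationScheme;
    /// The concrete strategy (scheme + quantization parameters) used to (de)quantize values.
    fn strategy(&self) -> QuantizationStrategy;
}
```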
@@ -33,9 +33,6 @@ pub mod module;
 /// Neural network module.
 pub mod nn;

-/// Quantization module.
-pub mod quantization;
-
 /// Module for the recorder.
 pub mod record;
@@ -1,12 +1,11 @@
-use super::ParamId;
+use super::{ParamId, Quantizer};
 use crate::{
-    quantization::{Calibration, Quantizer},
     record::Record,
     tensor::backend::{AutodiffBackend, Backend},
 };
 use alloc::vec::Vec;
 pub use burn_derive::Module;
-use burn_tensor::{Bool, Int, Tensor};
+use burn_tensor::{quantization::Calibration, Bool, Int, Tensor};

 /// Type alias to `Vec<B::Device>` which supports `no_std` environments, but automatically using
 /// the `alloc` crate.
@@ -1,7 +1,9 @@
 mod base;
 mod display;
 mod param;
+mod quantize;

 pub use base::*;
 pub use display::*;
 pub use param::*;
+pub use quantize::*;
@@ -1,18 +1,23 @@
-use burn_tensor::{backend::Backend, Tensor};
+use burn_tensor::{
+    backend::Backend,
+    quantization::{Calibration, QuantizationScheme},
+    Tensor,
+};

 use crate::module::{ModuleMapper, ParamId};

-use super::Calibration;
-
 /// Describes how to quantize a module.
 pub struct Quantizer<C: Calibration> {
     /// The calibration method used in quantization.
     pub calibration: C,
+    /// The quantization scheme.
+    pub scheme: QuantizationScheme,
 }

 impl<B: Backend, C: Calibration> ModuleMapper<B> for Quantizer<C> {
     fn map_float<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
-        let strategy = self.calibration.configure(&tensor);
-        tensor.quantize(strategy)
+        let range = self.calibration.compute_range(&tensor);
+        let qparams = self.scheme.compute_q_params(range);
+        tensor.quantize(&self.scheme, qparams)
     }
 }
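The new `map_float` above makes the division of labor explicit: calibration produces a value range, and the scheme converts that range into quantization parameters on the device. `CalibrationRange` and `compute_q_params` are not shown in this diff, so the following sketch of their shapes is an assumption inferred from the calls above:

```rust
// Assumed shapes, reconstructed from `map_float`; actual definitions live in burn-tensor.
pub trait Calibration {
    /// Compute the value range (e.g. running min/max) of the tensor to quantize.
    fn compute_range<B: Backend, const D: usize>(
        &self,
        tensor: &Tensor<B, D>,
    ) -> CalibrationRange<B>;
}

impl QuantizationScheme {
    /// Turn a calibrated range into scale (and offset, for affine schemes)
    /// tensors on the backend device.
    pub fn compute_q_params<B: Backend>(
        &self,
        range: CalibrationRange<B>,
    ) -> QuantizationParameters<B> {
        // For symmetric int8, for instance: scale = max(|min|, |max|) / 127, offset = None.
        todo!()
    }
}
```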
@@ -1,80 +0,0 @@
-use burn_tensor::{
-    backend::Backend, AffineQuantization, ElementConversion, Quantization, QuantizationStrategy,
-    SymmetricQuantization, Tensor,
-};
-
-use super::{QuantizationScheme, QuantizationType};
-
-/// Calibration method used to compute the quantization range mapping.
-pub trait Calibration {
-    /// Configure the quantization strategy.
-    fn configure<B: Backend, const D: usize>(&self, tensor: &Tensor<B, D>) -> QuantizationStrategy;
-}
-
-/// Computes the quantization range mapping based on the running min and max values.
-pub struct MinMaxCalibration {
-    /// Quantization scheme to be used.
-    pub scheme: QuantizationScheme,
-}
-
-impl Calibration for MinMaxCalibration {
-    fn configure<B: Backend, const D: usize>(&self, tensor: &Tensor<B, D>) -> QuantizationStrategy {
-        let min = tensor.clone().min().into_scalar().elem::<f32>();
-        let max = tensor.clone().max().into_scalar().elem::<f32>();
-
-        match &self.scheme {
-            QuantizationScheme::PerTensorAffine(dtype) => match dtype {
-                QuantizationType::QInt8 => {
-                    QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::new(min, max))
-                }
-            },
-            QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
-                QuantizationType::QInt8 => QuantizationStrategy::PerTensorSymmetricInt8(
-                    SymmetricQuantization::new(min, max),
-                ),
-            },
-        }
-    }
-}
-
-#[cfg(test)]
-mod tests {
-
-    use super::*;
-    use crate::TestBackend;
-
-    #[test]
-    fn min_max_calibration_per_tensor_affine_int8() {
-        let device = <TestBackend as Backend>::Device::default();
-        let tensor = Tensor::<TestBackend, 1>::from_floats([-1.8, -1.0, 0.0, 0.5], &device);
-        let calibration = MinMaxCalibration {
-            scheme: QuantizationScheme::PerTensorAffine(QuantizationType::QInt8),
-        };
-
-        let strategy = calibration.configure(&tensor);
-
-        if let QuantizationStrategy::PerTensorAffineInt8(q) = strategy {
-            assert_eq!(q.scale, 0.009_019_608);
-            assert_eq!(q.offset, 72);
-        } else {
-            panic!("Wrong quantization strategy");
-        }
-    }
-
-    #[test]
-    fn min_max_calibration_per_tensor_symmetric_int8() {
-        let device = <TestBackend as Backend>::Device::default();
-        let tensor = Tensor::<TestBackend, 1>::from_floats([-1.8, -1.0, 0.0, 0.5], &device);
-        let calibration = MinMaxCalibration {
-            scheme: QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8),
-        };
-
-        let strategy = calibration.configure(&tensor);
-
-        if let QuantizationStrategy::PerTensorSymmetricInt8(q) = strategy {
-            assert_eq!(q.scale, 0.014_173_228);
-        } else {
-            panic!("Wrong quantization strategy");
-        }
-    }
-}
@ -1,7 +0,0 @@
|
|||
mod calibration;
|
||||
mod quantize;
|
||||
mod scheme;
|
||||
|
||||
pub use calibration::*;
|
||||
pub use quantize::*;
|
||||
pub use scheme::*;
|
|
@@ -1,17 +0,0 @@
-/// Quantization data type.
-pub enum QuantizationType {
-    /// 8-bit signed integer.
-    QInt8,
-}
-
-/// Quantization scheme.
-pub enum QuantizationScheme {
-    /// Per-tensor affine/asymmetric quantization.
-    PerTensorAffine(QuantizationType),
-    /// Per-tensor symmetric quantization.
-    PerTensorSymmetric(QuantizationType),
-    // /// Per-channel affine/asymmetric quantization.
-    // PerChannelAffine,
-    // /// Per-channel symmetric quantization.
-    // PerChannelSymmetric,
-}
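The `QuantizationType` and `QuantizationScheme` enums deleted above are not gone: per the commit message they move to burn-tensor, and the imports throughout this diff reference their new path:

```rust
// New canonical location after this commit.
use burn_tensor::quantization::{QuantizationScheme, QuantizationType};
```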
@@ -1,5 +1,6 @@
 use crate::{
     client::FusionClient, stream::Context, FusionClientLocator, FusionTensor, PrecisionBridge,
+    QFusionTensor,
 };
 use burn_tensor::{
     backend::{Backend, DeviceOps, SyncType},
@@ -37,7 +38,7 @@ impl<B: FusionBackend> Backend for Fusion<B> {

     type BoolTensorPrimitive<const D: usize> = FusionTensor<B::FusionRuntime>;

-    type QuantizedTensorPrimitive<const D: usize> = FusionTensor<B::FusionRuntime>;
+    type QuantizedTensorPrimitive<const D: usize> = QFusionTensor<B::FusionRuntime>;

     fn name() -> String {
         format!("fusion<{}>", B::name())
@@ -1,7 +1,8 @@
 use burn_tensor::{
     backend::Backend,
     ops::{QTensorOps, QuantizedTensor},
-    Device, QuantizationStrategy, Shape, TensorData,
+    quantization::{QuantizationParametersPrimitive, QuantizationScheme},
+    Device, Shape, TensorData,
 };

 use crate::{client::FusionClient, Fusion, FusionBackend};
@@ -16,24 +17,24 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {

     fn quantize<const D: usize>(
         _tensor: <Self as Backend>::FloatTensorPrimitive<D>,
-        _strategy: &QuantizationStrategy,
+        _scheme: &QuantizationScheme,
+        _qparams: QuantizationParametersPrimitive<Self>,
     ) -> <Self as Backend>::QuantizedTensorPrimitive<D> {
         unimplemented!()
     }

     fn dequantize<const D: usize>(
         _tensor: <Self as Backend>::QuantizedTensorPrimitive<D>,
-        _strategy: &QuantizationStrategy,
     ) -> <Self as Backend>::FloatTensorPrimitive<D> {
         unimplemented!()
     }

     fn q_shape<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Shape<D> {
-        tensor.shape()
+        tensor.qtensor.shape()
     }

     fn q_device<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Device<Self> {
-        tensor.client.device().clone()
+        tensor.qtensor.client.device().clone()
     }

     fn q_reshape<const D1: usize, const D2: usize>(
@@ -43,10 +44,7 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
         unimplemented!()
     }

-    async fn q_into_data<const D: usize>(
-        _tensor: QuantizedTensor<Self, D>,
-        _strategy: QuantizationStrategy,
-    ) -> TensorData {
+    async fn q_into_data<const D: usize>(_tensor: QuantizedTensor<Self, D>) -> TensorData {
         unimplemented!()
     }
 }
@@ -1,5 +1,6 @@
 use crate::{client::FusionClient, stream::StreamId, Client, FusionBackend, FusionRuntime};
 use burn_tensor::{
+    quantization::{QTensorPrimitive, QuantizationScheme, QuantizationStrategy},
     repr::{TensorDescription, TensorId, TensorStatus},
     DType, Shape, TensorData,
 };
@@ -157,3 +158,31 @@ impl<R: FusionRuntime> Drop for FusionTensor<R> {
         }
     }
 }
+
+/// A quantized tensor primitive for fusion backends.
+#[derive(Debug)]
+pub struct QFusionTensor<R: FusionRuntime> {
+    /// The quantized tensor.
+    pub qtensor: FusionTensor<R>,
+    /// The quantization scheme.
+    pub scheme: QuantizationScheme,
+}
+
+impl<R: FusionRuntime> QTensorPrimitive for QFusionTensor<R> {
+    fn scheme(&self) -> &QuantizationScheme {
+        &self.scheme
+    }
+
+    fn strategy(&self) -> QuantizationStrategy {
+        todo!()
+    }
+}
+
+impl<R: FusionRuntime> Clone for QFusionTensor<R> {
+    fn clone(&self) -> Self {
+        Self {
+            qtensor: self.qtensor.clone(),
+            scheme: self.scheme.clone(),
+        }
+    }
+}
@@ -1,5 +1,6 @@
 use crate::{
-    tensor::JitTensor, FloatElement, IntElement, JitAutotuneKey, JitRuntime, PrecisionBridge,
+    tensor::{JitTensor, QJitTensor},
+    FloatElement, IntElement, JitAutotuneKey, JitRuntime, PrecisionBridge,
 };
 use burn_compute::server::ComputeServer;
 use burn_tensor::backend::{Backend, SyncType};
@@ -33,8 +34,7 @@ where
     type FloatTensorPrimitive<const D: usize> = JitTensor<R, Self::FloatElem, D>;
     type IntTensorPrimitive<const D: usize> = JitTensor<R, Self::IntElem, D>;
     type BoolTensorPrimitive<const D: usize> = JitTensor<R, u32, D>;
-    // TODO: implement `JitElement` / `CubeElement` for quantized type
-    type QuantizedTensorPrimitive<const D: usize> = JitTensor<R, u32, D>;
+    type QuantizedTensorPrimitive<const D: usize> = QJitTensor<R, D>;

     fn name() -> String {
         format!("jit<{}>", R::name())
@@ -1,9 +1,10 @@
 use burn_tensor::{
     ops::{FloatTensor, QTensorOps, QuantizedTensor},
-    Device, QuantizationStrategy, Shape, TensorData,
+    quantization::{QuantizationParametersPrimitive, QuantizationScheme},
+    Device, Shape, TensorData,
 };

-use crate::{FloatElement, IntElement, JitBackend, JitRuntime};
+use crate::{tensor::QJitTensor, FloatElement, IntElement, JitBackend, JitRuntime};

 impl<R, F, I> QTensorOps<Self> for JitBackend<R, F, I>
 where
@@ -20,37 +21,35 @@ where

     fn quantize<const D: usize>(
         _tensor: FloatTensor<Self, D>,
-        _strategy: &QuantizationStrategy,
+        _scheme: &QuantizationScheme,
+        _qparams: QuantizationParametersPrimitive<Self>,
     ) -> QuantizedTensor<Self, D> {
         unimplemented!()
     }

-    fn dequantize<const D: usize>(
-        _tensor: QuantizedTensor<Self, D>,
-        _strategy: &QuantizationStrategy,
-    ) -> FloatTensor<Self, D> {
+    fn dequantize<const D: usize>(_tensor: QuantizedTensor<Self, D>) -> FloatTensor<Self, D> {
        unimplemented!()
     }

     fn q_shape<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Shape<D> {
-        tensor.shape.clone()
+        tensor.qtensor.shape.clone()
     }

     fn q_device<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Device<Self> {
-        tensor.device.clone()
+        tensor.qtensor.device.clone()
     }

     fn q_reshape<const D1: usize, const D2: usize>(
         tensor: QuantizedTensor<Self, D1>,
         shape: Shape<D2>,
     ) -> QuantizedTensor<Self, D2> {
-        super::reshape(tensor, shape)
+        QJitTensor {
+            qtensor: super::reshape(tensor.qtensor, shape),
+            scheme: tensor.scheme,
+        }
     }

-    async fn q_into_data<const D: usize>(
-        _tensor: QuantizedTensor<Self, D>,
-        _strategy: QuantizationStrategy,
-    ) -> TensorData {
+    async fn q_into_data<const D: usize>(_tensor: QuantizedTensor<Self, D>) -> TensorData {
         unimplemented!()
     }
 }
@@ -1,4 +1,7 @@
 mod base;
 mod layout;
+mod qtensor;

 pub use base::*;
 pub(crate) use layout::*;
+pub(crate) use qtensor::*;
@@ -0,0 +1,34 @@
+use burn_tensor::quantization::{QTensorPrimitive, QuantizationScheme, QuantizationStrategy};
+
+use crate::JitRuntime;
+
+use super::JitTensor;
+
+/// A quantized tensor primitive.
+#[derive(Debug)]
+pub struct QJitTensor<R: JitRuntime, const D: usize> {
+    /// The quantized tensor.
+    // TODO: implement `JitElement` / `CubeElement` for quantized type
+    pub qtensor: JitTensor<R, u32, D>,
+    /// The quantization scheme.
+    pub scheme: QuantizationScheme,
+}
+
+impl<R: JitRuntime, const D: usize> QTensorPrimitive for QJitTensor<R, D> {
+    fn scheme(&self) -> &QuantizationScheme {
+        &self.scheme
+    }
+
+    fn strategy(&self) -> QuantizationStrategy {
+        todo!()
+    }
+}
+
+impl<R: JitRuntime, const D: usize> Clone for QJitTensor<R, D> {
+    fn clone(&self) -> Self {
+        Self {
+            qtensor: self.qtensor.clone(),
+            scheme: self.scheme.clone(),
+        }
+    }
+}
@@ -1,5 +1,6 @@
-use crate::NdArrayTensor;
-use crate::{element::FloatNdArrayElement, PrecisionBridge};
+use crate::element::{FloatNdArrayElement, QuantElement};
+use crate::PrecisionBridge;
+use crate::{NdArrayQTensor, NdArrayTensor};
 use alloc::string::String;
 use burn_common::stub::Mutex;
 use burn_tensor::backend::{Backend, DeviceId, DeviceOps};
@@ -34,11 +35,12 @@ impl Default for NdArrayDevice {
 /// This backend is compatible with CPUs and can be compiled for almost any platform, including
 /// `wasm`, `arm`, and `x86`.
 #[derive(Clone, Copy, Default, Debug)]
-pub struct NdArray<E = f32> {
-    phantom: PhantomData<E>,
+pub struct NdArray<E = f32, Q = i8> {
+    _e: PhantomData<E>,
+    _q: PhantomData<Q>,
 }

-impl<E: FloatNdArrayElement> Backend for NdArray<E> {
+impl<E: FloatNdArrayElement, Q: QuantElement> Backend for NdArray<E, Q> {
     type Device = NdArrayDevice;
     type FullPrecisionBridge = PrecisionBridge<f32>;

@@ -50,7 +52,7 @@ impl<E: FloatNdArrayElement> Backend for NdArray<E> {

     type BoolTensorPrimitive<const D: usize> = NdArrayTensor<bool, D>;

-    type QuantizedTensorPrimitive<const D: usize> = NdArrayTensor<i8, D>;
+    type QuantizedTensorPrimitive<const D: usize> = NdArrayQTensor<Q, D>;

     fn ad_enabled() -> bool {
         false
@@ -1,4 +1,4 @@
-use crate::{FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayTensor};
+use crate::{element::QuantElement, FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayTensor};
 use burn_tensor::{backend::BackendBridge, ops::FloatTensor};
 use core::marker::PhantomData;

@@ -8,10 +8,11 @@ pub struct PrecisionBridge<E: FloatNdArrayElement> {
     _e: PhantomData<E>,
 }

-impl<TElem, OElem> BackendBridge<NdArray<OElem>> for PrecisionBridge<TElem>
+impl<TElem, OElem, QElem> BackendBridge<NdArray<OElem, QElem>> for PrecisionBridge<TElem>
 where
     TElem: FloatNdArrayElement,
     OElem: FloatNdArrayElement,
+    QElem: QuantElement,
 {
     type Target = NdArray<TElem>;
@@ -41,6 +41,11 @@ pub trait ExpElement {
     fn int_abs_elem(self) -> Self;
 }

+/// A quantized element for the ndarray backend.
+pub trait QuantElement: NdArrayElement {}
+
+impl QuantElement for i8 {}
+
 impl FloatNdArrayElement for f64 {}
 impl FloatNdArrayElement for f32 {}
@@ -1,7 +1,11 @@
-use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArray};
+use crate::{
+    element::{FloatNdArrayElement, QuantElement},
+    tensor::NdArrayTensor,
+    NdArray,
+};
 use burn_tensor::{ops::ActivationOps, ElementConversion};

-impl<E: FloatNdArrayElement> ActivationOps<Self> for NdArray<E> {
+impl<E: FloatNdArrayElement, Q: QuantElement> ActivationOps<Self> for NdArray<E, Q> {
     fn relu<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
         let zero = 0.elem();
         let array = tensor
@@ -7,7 +7,7 @@ use core::ops::Range;
 use ndarray::IntoDimension;

 // Current crate
-use crate::element::FloatNdArrayElement;
+use crate::element::{FloatNdArrayElement, QuantElement};
 use crate::NdArrayDevice;
 use crate::{tensor::NdArrayTensor, NdArray};

@@ -16,7 +16,7 @@ use burn_tensor::{backend::Backend, Shape, TensorData};

 use super::NdArrayOps;

-impl<E: FloatNdArrayElement> BoolTensorOps<Self> for NdArray<E> {
+impl<E: FloatNdArrayElement, Q: QuantElement> BoolTensorOps<Self> for NdArray<E, Q> {
     fn bool_from_data<const D: usize>(
         data: TensorData,
         _device: &NdArrayDevice,
@@ -11,7 +11,7 @@ use ndarray::{
 };

 use crate::{
-    element::FloatNdArrayElement,
+    element::{FloatNdArrayElement, QuantElement},
     ops::padding::{apply_padding_4d, apply_padding_5d},
     sharing::UnsafeSharedRef,
     tensor::NdArrayTensor,
@@ -98,7 +98,7 @@ fn conv3d_mad_inner<E: FloatNdArrayElement>(
     }
 }

-pub(crate) fn conv2d<E: FloatNdArrayElement>(
+pub(crate) fn conv2d<E: FloatNdArrayElement, Q: QuantElement>(
     x: NdArrayTensor<E, 4>,
     weight: NdArrayTensor<E, 4>,
     bias: Option<NdArrayTensor<E, 1>>,
@@ -125,7 +125,7 @@ pub(crate) fn conv2d<E: FloatNdArrayElement>(
         in_width,
     );

-    let x = apply_padding_4d(x, options.padding, 0i32.elem()).array;
+    let x = apply_padding_4d::<E, Q>(x, options.padding, 0i32.elem()).array;

     // Convert inputs from dynamic indexes to static to improve perf.
     let x = x.into_dimensionality::<ndarray::Ix4>().unwrap();
@@ -309,7 +309,7 @@ pub(crate) fn conv_transpose2d<E: FloatNdArrayElement>(
     NdArrayTensor::new(output.into_dyn().into_shared())
 }

-pub(crate) fn conv3d<E: FloatNdArrayElement>(
+pub(crate) fn conv3d<E: FloatNdArrayElement, Q: QuantElement>(
     x: NdArrayTensor<E, 5>,
     weight: NdArrayTensor<E, 5>,
     bias: Option<NdArrayTensor<E, 1>>,
@@ -344,7 +344,7 @@ pub(crate) fn conv3d<E: FloatNdArrayElement>(
         in_width,
     );

-    let x = apply_padding_5d(x, options.padding, 0i32.elem()).array;
+    let x = apply_padding_5d::<E, Q>(x, options.padding, 0i32.elem()).array;

     // Convert inputs from dynamic indexes to static to improve perf.
     let x = x.into_dimensionality::<ndarray::Ix5>().unwrap();
@@ -12,6 +12,7 @@ use ndarray::IntoDimension;
 // Current crate
 use crate::element::ExpElement;
 use crate::element::FloatNdArrayElement;
+use crate::element::QuantElement;
 use crate::{tensor::NdArrayTensor, NdArray};
 use crate::{NdArrayDevice, SEED};

@@ -20,7 +21,7 @@ use burn_tensor::{backend::Backend, Shape, TensorData};

 use super::{NdArrayMathOps, NdArrayOps};

-impl<E: FloatNdArrayElement> IntTensorOps<Self> for NdArray<E> {
+impl<E: FloatNdArrayElement, Q: QuantElement> IntTensorOps<Self> for NdArray<E, Q> {
     fn int_from_data<const D: usize>(
         data: TensorData,
         _device: &NdArrayDevice,
@@ -1,5 +1,7 @@
 use crate::{
-    element::FloatNdArrayElement, ops::padding::apply_padding_4d, sharing::UnsafeSharedRef,
+    element::{FloatNdArrayElement, QuantElement},
+    ops::padding::apply_padding_4d,
+    sharing::UnsafeSharedRef,
     tensor::NdArrayTensor,
 };

@@ -7,7 +9,7 @@ use burn_common::{iter_range_par, run_par};
 use burn_tensor::ElementConversion;
 use ndarray::Array4;

-pub(crate) fn max_pool2d<E: FloatNdArrayElement>(
+pub(crate) fn max_pool2d<E: FloatNdArrayElement, Q: QuantElement>(
     x: NdArrayTensor<E, 4>,
     kernel_size: [usize; 2],
     stride: [usize; 2],
@@ -28,7 +30,7 @@ pub(crate) fn max_pool2d<E: FloatNdArrayElement>(
         / stride_width)
         + 1;

-    let x = apply_padding_4d(x, padding, inf).array;
+    let x = apply_padding_4d::<E, Q>(x, padding, inf).array;

     let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf);
     let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
@@ -67,7 +69,7 @@ pub(crate) fn max_pool2d<E: FloatNdArrayElement>(
     NdArrayTensor::new(output.into_dyn().into_shared())
 }

-pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement>(
+pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement, Q: QuantElement>(
     x: NdArrayTensor<E, 4>,
     kernel_size: [usize; 2],
     stride: [usize; 2],
@@ -88,7 +90,7 @@ pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement>(
         / stride_width)
         + 1;

-    let x = apply_padding_4d(x, padding, inf).array;
+    let x = apply_padding_4d::<E, Q>(x, padding, inf).array;

     let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf);
     let mut indices = Array4::<i64>::zeros((batch_size, channels, out_height, out_width));
@@ -5,18 +5,18 @@ use super::{
     interpolate::{bicubic_interpolate, bilinear_interpolate, nearest_interpolate},
     maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices},
 };
-use crate::ops::interpolate::nearest_interpolate_backward;
 use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArray};
+use crate::{element::QuantElement, ops::interpolate::nearest_interpolate_backward};
 use burn_tensor::ops::*;

-impl<E: FloatNdArrayElement> ModuleOps<Self> for NdArray<E> {
+impl<E: FloatNdArrayElement, Q: QuantElement> ModuleOps<Self> for NdArray<E, Q> {
     fn conv2d(
         x: NdArrayTensor<E, 4>,
         weight: NdArrayTensor<E, 4>,
         bias: Option<NdArrayTensor<E, 1>>,
         options: ConvOptions<2>,
     ) -> NdArrayTensor<E, 4> {
-        conv2d(x, weight, bias, options)
+        conv2d::<E, Q>(x, weight, bias, options)
     }

     fn conv_transpose2d(
@@ -56,7 +56,7 @@ impl<E: FloatNdArrayElement> ModuleOps<Self> for NdArray<E> {
         padding: [usize; 2],
         dilation: [usize; 2],
     ) -> NdArrayTensor<E, 4> {
-        max_pool2d(x, kernel_size, stride, padding, dilation)
+        max_pool2d::<E, Q>(x, kernel_size, stride, padding, dilation)
     }

     fn max_pool2d_with_indices(
@@ -65,8 +65,9 @@ impl<E: FloatNdArrayElement> ModuleOps<Self> for NdArray<E> {
         stride: [usize; 2],
         padding: [usize; 2],
         dilation: [usize; 2],
-    ) -> MaxPool2dWithIndices<NdArray<E>> {
-        let (output, indices) = max_pool2d_with_indices(x, kernel_size, stride, padding, dilation);
+    ) -> MaxPool2dWithIndices<NdArray<E, Q>> {
+        let (output, indices) =
+            max_pool2d_with_indices::<E, Q>(x, kernel_size, stride, padding, dilation);

         MaxPool2dWithIndices::new(output, indices)
     }
@@ -79,7 +80,7 @@ impl<E: FloatNdArrayElement> ModuleOps<Self> for NdArray<E> {
         dilation: [usize; 2],
         output_grad: NdArrayTensor<E, 4>,
         indices: NdArrayTensor<i64, 4>,
-    ) -> MaxPool2dBackward<NdArray<E>> {
+    ) -> MaxPool2dBackward<NdArray<E, Q>> {
         MaxPool2dBackward::new(max_pool2d_backward(
             x,
             kernel_size,
@@ -137,7 +138,7 @@ impl<E: FloatNdArrayElement> ModuleOps<Self> for NdArray<E> {
         bias: Option<NdArrayTensor<E, 1>>,
         options: ConvOptions<3>,
     ) -> NdArrayTensor<E, 5> {
-        conv3d(x, weight, bias, options)
+        conv3d::<E, Q>(x, weight, bias, options)
     }

     fn conv_transpose3d(
@@ -1,8 +1,12 @@
-use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArray};
+use crate::{
+    element::{FloatNdArrayElement, QuantElement},
+    tensor::NdArrayTensor,
+    NdArray,
+};
 use burn_tensor::ops::FloatTensorOps;
 use ndarray::{Array4, Array5};

-pub(crate) fn apply_padding_4d<E: FloatNdArrayElement>(
+pub(crate) fn apply_padding_4d<E: FloatNdArrayElement, Q: QuantElement>(
     x: NdArrayTensor<E, 4>,
     padding: [usize; 2],
     elem: E,
@@ -18,7 +22,7 @@ pub(crate) fn apply_padding_4d<E: FloatNdArrayElement>(
     );
     let mut x_new = NdArrayTensor::new(x_new.into_shared().into_dyn());

-    x_new = NdArray::float_slice_assign(
+    x_new = NdArray::<E, Q>::float_slice_assign(
         x_new,
         [
             0..batch_size,
@@ -32,7 +36,7 @@ pub(crate) fn apply_padding_4d<E: FloatNdArrayElement>(
     x_new
 }

-pub(crate) fn apply_padding_5d<E: FloatNdArrayElement>(
+pub(crate) fn apply_padding_5d<E: FloatNdArrayElement, Q: QuantElement>(
     x: NdArrayTensor<E, 5>,
     padding: [usize; 3],
     elem: E,
@@ -55,7 +59,7 @@ pub(crate) fn apply_padding_5d<E: FloatNdArrayElement>(
     );
     let mut x_new = NdArrayTensor::new(x_new.into_shared().into_dyn());

-    x_new = NdArray::float_slice_assign(
+    x_new = NdArray::<E, Q>::float_slice_assign(
         x_new,
         [
             0..batch_size,
@@ -1,9 +1,16 @@
 use burn_tensor::{
     ops::{FloatTensor, QTensorOps, QuantizedTensor},
-    DType, Quantization, QuantizationStrategy, Shape, TensorData,
+    quantization::{
+        AffineQuantization, Quantization, QuantizationParametersPrimitive, QuantizationScheme,
+        QuantizationStrategy, QuantizationType, SymmetricQuantization,
+    },
+    DType, Shape, TensorData,
 };

-use crate::{element::NdArrayElement, FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayTensor};
+use crate::{
+    element::{NdArrayElement, QuantElement},
+    FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor,
+};

 use super::NdArrayOps;

@@ -13,7 +20,7 @@ fn into_data<E: NdArrayElement, const D: usize>(tensor: NdArrayTensor<E, D>) ->
     TensorData::new(values, shape)
 }

-impl<E: FloatNdArrayElement> QTensorOps<Self> for NdArray<E> {
+impl<E: FloatNdArrayElement, Q: QuantElement> QTensorOps<Self> for NdArray<E, Q> {
     fn q_from_data<const D: usize>(
         data: TensorData,
         _device: &NdArrayDevice,
@@ -22,11 +29,19 @@ impl<E: FloatNdArrayElement> QTensorOps<Self> for NdArray<E> {
             DType::QFloat(strategy) => match strategy {
                 QuantizationStrategy::PerTensorAffineInt8(_) => {
                     let data = data.convert::<i8>();
-                    NdArrayTensor::<i8, D>::from_data(data)
+                    NdArrayQTensor {
+                        qtensor: NdArrayTensor::<Q, D>::from_data(data),
+                        scheme: strategy.scheme(),
+                        strategy,
+                    }
                 }
                 QuantizationStrategy::PerTensorSymmetricInt8(_) => {
                     let data = data.convert::<i8>();
-                    NdArrayTensor::<i8, D>::from_data(data)
+                    NdArrayQTensor {
+                        qtensor: NdArrayTensor::<Q, D>::from_data(data),
+                        scheme: strategy.scheme(),
+                        strategy,
+                    }
                 }
             },
             _ => panic!(
@@ -38,18 +53,36 @@ impl<E: FloatNdArrayElement> QTensorOps<Self> for NdArray<E> {

     fn quantize<const D: usize>(
         tensor: FloatTensor<Self, D>,
-        strategy: &QuantizationStrategy,
+        scheme: &QuantizationScheme,
+        qparams: QuantizationParametersPrimitive<Self>,
     ) -> QuantizedTensor<Self, D> {
-        let data = into_data(tensor).with_quantization(*strategy);
-        NdArrayTensor::<i8, D>::from_data(data)
+        let strategy = match scheme {
+            QuantizationScheme::PerTensorAffine(dtype) => match dtype {
+                QuantizationType::QInt8 => {
+                    QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(
+                        into_data(qparams.scale).iter().next().unwrap(),
+                        into_data(qparams.offset.unwrap()).iter().next().unwrap(),
+                    ))
+                }
+            },
+            QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
+                QuantizationType::QInt8 => QuantizationStrategy::PerTensorSymmetricInt8(
+                    SymmetricQuantization::init(into_data(qparams.scale).iter().next().unwrap()),
+                ),
+            },
+        };
+
+        let data = into_data(tensor).with_quantization(strategy);
+        NdArrayQTensor {
+            qtensor: NdArrayTensor::<Q, D>::from_data(data),
+            strategy,
+            scheme: scheme.clone(),
+        }
     }

-    fn dequantize<const D: usize>(
-        tensor: QuantizedTensor<Self, D>,
-        strategy: &QuantizationStrategy,
-    ) -> FloatTensor<Self, D> {
-        let data = into_data(tensor);
-        let values = match strategy {
+    fn dequantize<const D: usize>(tensor: QuantizedTensor<Self, D>) -> FloatTensor<Self, D> {
+        let data = into_data(tensor.qtensor);
+        let values = match tensor.strategy {
             QuantizationStrategy::PerTensorAffineInt8(s) => s.dequantize(data.as_slice().unwrap()),
             QuantizationStrategy::PerTensorSymmetricInt8(s) => {
                 s.dequantize(data.as_slice().unwrap())
@@ -59,7 +92,7 @@ impl<E: FloatNdArrayElement> QTensorOps<Self> for NdArray<E> {
     }

     fn q_shape<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Shape<D> {
-        tensor.shape()
+        tensor.qtensor.shape()
     }

     fn q_device<const D: usize>(_tensor: &QuantizedTensor<Self, D>) -> NdArrayDevice {
@@ -70,15 +103,16 @@ impl<E: FloatNdArrayElement> QTensorOps<Self> for NdArray<E> {
         tensor: QuantizedTensor<Self, D1>,
         shape: Shape<D2>,
     ) -> QuantizedTensor<Self, D2> {
-        NdArrayOps::reshape(tensor, shape)
+        NdArrayQTensor {
+            qtensor: NdArrayOps::reshape(tensor.qtensor, shape),
+            scheme: tensor.scheme,
+            strategy: tensor.strategy,
+        }
     }

-    async fn q_into_data<const D: usize>(
-        tensor: QuantizedTensor<Self, D>,
-        strategy: QuantizationStrategy,
-    ) -> TensorData {
-        let shape = tensor.shape();
-        let values = tensor.array.into_iter().collect();
-        TensorData::quantized(values, shape, strategy)
+    async fn q_into_data<const D: usize>(tensor: QuantizedTensor<Self, D>) -> TensorData {
+        let shape = tensor.qtensor.shape();
+        let values = tensor.qtensor.array.into_iter().collect();
+        TensorData::quantized(values, shape, tensor.strategy)
     }
 }
@@ -5,7 +5,7 @@ use ndarray::IntoDimension;

 // Current crate
 use super::{matmul::matmul, NdArrayMathOps, NdArrayOps};
-use crate::element::FloatNdArrayElement;
+use crate::element::{FloatNdArrayElement, QuantElement};
 use crate::{tensor::NdArrayTensor, NdArray};
 use crate::{NdArrayDevice, SEED};

@@ -20,7 +20,7 @@ use num_traits::Float;

 use libm::erf;

-impl<E: FloatNdArrayElement> FloatTensorOps<Self> for NdArray<E> {
+impl<E: FloatNdArrayElement, Q: QuantElement> FloatTensorOps<Self> for NdArray<E, Q> {
     fn float_from_data<const D: usize>(
         data: TensorData,
         _device: &NdArrayDevice,
@@ -1,7 +1,12 @@
-use burn_tensor::{Element, Shape, TensorData};
+use burn_tensor::{
+    quantization::{QTensorPrimitive, QuantizationScheme, QuantizationStrategy},
+    Element, Shape, TensorData,
+};

 use ndarray::{ArcArray, Array, Dim, IxDyn};

+use crate::element::QuantElement;
+
 /// Tensor primitive used by the [ndarray backend](crate::NdArray).
 #[derive(new, Debug, Clone)]
 pub struct NdArrayTensor<E, const D: usize> {
@@ -111,11 +116,38 @@ where
     }
 }

+/// A quantized tensor for the ndarray backend.
+#[derive(Clone, Debug)]
+pub struct NdArrayQTensor<Q: QuantElement, const D: usize> {
+    /// The quantized tensor.
+    pub qtensor: NdArrayTensor<Q, D>,
+    /// The quantization scheme.
+    pub scheme: QuantizationScheme,
+    /// The quantization strategy.
+    pub strategy: QuantizationStrategy,
+}
+
+impl<Q: QuantElement, const D: usize> QTensorPrimitive for NdArrayQTensor<Q, D> {
+    fn scheme(&self) -> &QuantizationScheme {
+        &self.scheme
+    }
+
+    fn strategy(&self) -> QuantizationStrategy {
+        self.strategy
+    }
+}
+
 #[cfg(test)]
 mod tests {
+    use crate::NdArray;
+
     use super::*;
     use burn_common::rand::get_seeded_rng;
-    use burn_tensor::Distribution;
+    use burn_tensor::{
+        ops::QTensorOps,
+        quantization::{AffineQuantization, QuantizationParametersPrimitive, QuantizationType},
+        Distribution,
+    };

     #[test]
     fn should_support_into_and_from_data_1d() {
@@ -172,4 +204,21 @@ mod tests {

         assert_eq!(data_expected, data_actual);
     }
+
+    #[test]
+    fn should_support_qtensor_strategy() {
+        let tensor = NdArrayTensor::<f32, 1>::from_data(TensorData::from([-1.8, -1.0, 0.0, 0.5]));
+        let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8);
+        let qparams = QuantizationParametersPrimitive {
+            scale: NdArrayTensor::from_data(TensorData::from([0.009_019_608])),
+            offset: Some(NdArrayTensor::from_data(TensorData::from([72]))),
+        };
+        let qtensor: NdArrayQTensor<i8, 1> = NdArray::quantize(tensor, &scheme, qparams);
+
+        assert_eq!(qtensor.scheme(), &scheme);
+        assert_eq!(
+            qtensor.strategy(),
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.009_019_608, 72))
+        );
+    }
 }
@@ -1,4 +1,4 @@
-use crate::PrecisionBridge;
+use crate::{PrecisionBridge, QuantElement, TchQTensor};

 use super::element::TchElement;
 use super::TchTensor;

@@ -86,11 +86,12 @@ impl Default for LibTorchDevice {
 ///
 /// Refer to the [tch] crate for more information.
 #[derive(Clone, Copy, Default, Debug)]
-pub struct LibTorch<E = f32> {
+pub struct LibTorch<E = f32, Q = i8> {
     _e: E,
+    _q: Q,
 }

-impl<E: TchElement> Backend for LibTorch<E> {
+impl<E: TchElement, Q: QuantElement> Backend for LibTorch<E, Q> {
     type Device = LibTorchDevice;
     type FullPrecisionBridge = PrecisionBridge<f32>;

@@ -102,7 +103,7 @@ impl<E: TchElement> Backend for LibTorch<E> {

     type BoolTensorPrimitive<const D: usize> = TchTensor<bool, D>;

-    type QuantizedTensorPrimitive<const D: usize> = TchTensor<i8, D>;
+    type QuantizedTensorPrimitive<const D: usize> = TchQTensor<Q, D>;

     fn seed(seed: u64) {
         tch::manual_seed(seed as i64);
@@ -1,4 +1,4 @@
-use crate::{ops::TchOps, LibTorch, TchElement, TchTensor};
+use crate::{ops::TchOps, LibTorch, QuantElement, TchElement, TchTensor};
 use burn_tensor::{backend::BackendBridge, ops::FloatTensor, Device};
 use std::marker::PhantomData;

@@ -8,10 +8,11 @@ pub struct PrecisionBridge<E: TchElement> {
     _e: PhantomData<E>,
 }

-impl<TElem, OElem> BackendBridge<LibTorch<OElem>> for PrecisionBridge<TElem>
+impl<TElem, OElem, QElem> BackendBridge<LibTorch<OElem, QElem>> for PrecisionBridge<TElem>
 where
     TElem: TchElement,
     OElem: TchElement,
+    QElem: QuantElement,
 {
     type Target = LibTorch<TElem>;
@@ -12,5 +12,11 @@ impl TchElement for bf16 {}
 impl TchElement for i64 {}
 impl TchElement for i32 {}
 impl TchElement for i16 {}
+impl TchElement for i8 {}

 impl TchElement for u8 {}
+
+/// A quantized element for the tch backend.
+pub trait QuantElement: TchElement {}
+
+impl QuantElement for i8 {}
@@ -1,7 +1,7 @@
-use crate::{element::TchElement, LibTorch, TchTensor};
+use crate::{element::TchElement, LibTorch, QuantElement, TchTensor};
 use burn_tensor::ops::ActivationOps;

-impl<E: TchElement> ActivationOps<Self> for LibTorch<E> {
+impl<E: TchElement, Q: QuantElement> ActivationOps<Self> for LibTorch<E, Q> {
     fn relu<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
         tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu())
     }
@@ -1,4 +1,4 @@
-use burn_tensor::{QuantizationStrategy, Shape};
+use burn_tensor::{quantization::QuantizationStrategy, Shape};
 use tch::Scalar;

 use crate::{LibTorchDevice, TchShape, TchTensor};
@@ -1,9 +1,9 @@
 use super::TchOps;
-use crate::{element::TchElement, LibTorch, LibTorchDevice, TchTensor};
+use crate::{element::TchElement, LibTorch, LibTorchDevice, QuantElement, TchTensor};
 use burn_tensor::{backend::Backend, ops::BoolTensorOps, Shape, TensorData};
 use std::ops::Range;

-impl<E: TchElement> BoolTensorOps<Self> for LibTorch<E> {
+impl<E: TchElement, Q: QuantElement> BoolTensorOps<Self> for LibTorch<E, Q> {
     fn bool_from_data<const D: usize>(
         data: TensorData,
         device: &LibTorchDevice,
@@ -2,11 +2,11 @@ use std::ops::Range;

 use burn_tensor::{backend::Backend, ops::IntTensorOps, Distribution, Shape, TensorData};

-use crate::{element::TchElement, LibTorch, LibTorchDevice, TchShape, TchTensor};
+use crate::{element::TchElement, LibTorch, LibTorchDevice, QuantElement, TchShape, TchTensor};

 use super::TchOps;

-impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
+impl<E: TchElement, Q: QuantElement> IntTensorOps<Self> for LibTorch<E, Q> {
     fn int_from_data<const D: usize>(
         data: TensorData,
         device: &LibTorchDevice,
@@ -1,10 +1,10 @@
-use crate::{element::TchElement, LibTorch, TchTensor};
+use crate::{element::TchElement, LibTorch, QuantElement, TchTensor};
 use burn_tensor::ops::{
     ConvOptions, ConvTransposeOptions, InterpolateMode, InterpolateOptions, MaxPool1dWithIndices,
     MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
 };

-impl<E: TchElement> ModuleOps<Self> for LibTorch<E> {
+impl<E: TchElement, Q: QuantElement> ModuleOps<Self> for LibTorch<E, Q> {
     fn embedding(weights: TchTensor<E, 2>, indices: TchTensor<i64, 2>) -> TchTensor<E, 3> {
         let tensor = tch::Tensor::embedding(&weights.tensor, &indices.tensor, -1, false, false);

@@ -231,7 +231,7 @@ impl<E: TchElement> ModuleOps<Self> for LibTorch<E> {
         stride: usize,
         padding: usize,
         dilation: usize,
-    ) -> MaxPool1dWithIndices<LibTorch<E>> {
+    ) -> MaxPool1dWithIndices<LibTorch<E, Q>> {
         let (tensor, indices) = tch::Tensor::max_pool1d_with_indices(
             &x.tensor,
             kernel_size as i64,
@@ -269,7 +269,7 @@ impl<E: TchElement> ModuleOps<Self> for LibTorch<E> {
         stride: [usize; 2],
         padding: [usize; 2],
         dilation: [usize; 2],
-    ) -> MaxPool2dWithIndices<LibTorch<E>> {
+    ) -> MaxPool2dWithIndices<LibTorch<E, Q>> {
         let (tensor, indices) = tch::Tensor::max_pool2d_with_indices(
             &x.tensor,
             [kernel_size[0] as i64, kernel_size[1] as i64],
@@ -290,7 +290,7 @@ impl<E: TchElement> ModuleOps<Self> for LibTorch<E> {
         dilation: [usize; 2],
         output_grad: TchTensor<E, 4>,
         indices: TchTensor<i64, 4>,
-    ) -> MaxPool2dBackward<LibTorch<E>> {
+    ) -> MaxPool2dBackward<LibTorch<E, Q>> {
         let grad = tch::Tensor::max_pool2d_with_indices_backward(
             &x.tensor,
             &output_grad.tensor,
@@ -1,13 +1,17 @@
 use burn_tensor::{
     ops::{FloatTensor, QTensorOps, QuantizedTensor},
-    DType, Quantization, QuantizationStrategy, Shape, TensorData,
+    quantization::{
+        QTensorPrimitive, Quantization, QuantizationParametersPrimitive, QuantizationScheme,
+        QuantizationStrategy, QuantizationType,
+    },
+    DType, Shape, TensorData,
 };

-use crate::{LibTorch, LibTorchDevice, TchElement, TchShape, TchTensor};
+use crate::{LibTorch, LibTorchDevice, QuantElement, TchElement, TchQTensor, TchShape, TchTensor};

 use super::TchOps;

-impl<E: TchElement> QTensorOps<Self> for LibTorch<E> {
+impl<E: TchElement, Q: QuantElement> QTensorOps<Self> for LibTorch<E, Q> {
     fn q_from_data<const D: usize>(
         data: TensorData,
         device: &LibTorchDevice,
@@ -19,25 +23,27 @@ impl<E: TchElement> QTensorOps<Self> for LibTorch<E> {
         // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/quantized/Quantizer.cpp#L322
         // So for now we have to load the dequantized values to quantize them back since the dequantization
         // methods take the values provided when quantizing.
-        let tensor = match data.dtype {
+        let (tensor, scheme) = match data.dtype {
             DType::QFloat(strategy) => match strategy {
                 QuantizationStrategy::PerTensorAffineInt8(q) => {
                     let values = q.dequantize(&data.iter::<i8>().collect::<Vec<_>>());
                     let tensor = tch::Tensor::from_slice(&values).to(device);
-                    TchOps::<E>::quantize::<D, i8>(
+                    let tensor = TchOps::<E>::quantize::<D, i8>(
                         TchTensor::new(tensor.reshape(shape_tch.dims)),
                         &strategy,
                     )
-                    .tensor
+                    .tensor;
+                    (tensor, strategy.scheme())
                 }
                 QuantizationStrategy::PerTensorSymmetricInt8(q) => {
                     let values = q.dequantize(&data.iter::<i8>().collect::<Vec<_>>());
                     let tensor = tch::Tensor::from_slice(&values).to(device);
-                    TchOps::<E>::quantize::<D, i8>(
+                    let tensor = TchOps::<E>::quantize::<D, i8>(
                         TchTensor::new(tensor.reshape(shape_tch.dims)),
                         &strategy,
                     )
-                    .tensor
+                    .tensor;
+                    (tensor, strategy.scheme())
                 }
             },
             _ => panic!(
@@ -45,12 +51,16 @@ impl<E: TchElement> QTensorOps<Self> for LibTorch<E> {
                 data.dtype
             ),
         };
-        TchTensor::new(tensor)
+        TchQTensor {
+            qtensor: TchTensor::new(tensor),
+            scheme,
+        }
     }

     fn quantize<const D: usize>(
         tensor: FloatTensor<Self, D>,
-        strategy: &QuantizationStrategy,
+        scheme: &QuantizationScheme,
+        qparams: QuantizationParametersPrimitive<Self>,
     ) -> QuantizedTensor<Self, D> {
         let mut tensor = tensor;
         // Quantize only works on Float Tensor
@@ -58,52 +68,58 @@ impl<E: TchElement> QTensorOps<Self> for LibTorch<E> {
             tensor.tensor = tensor.tensor.to_kind(tch::Kind::Float);
         }

-        match strategy {
-            QuantizationStrategy::PerTensorAffineInt8(ref q) => {
-                TchTensor::new(tensor.tensor.quantize_per_tensor(
-                    q.scale.into(),
-                    q.offset.into(),
-                    tch::Kind::QInt8,
-                ))
-            }
-            QuantizationStrategy::PerTensorSymmetricInt8(ref q) => TchTensor::new(
-                tensor
-                    .tensor
-                    .quantize_per_tensor(q.scale.into(), 0, tch::Kind::QInt8),
-            ),
-        }
+        let qtensor = match scheme {
+            QuantizationScheme::PerTensorAffine(dtype) => match dtype {
+                QuantizationType::QInt8 => tensor.tensor.quantize_per_tensor_tensor_qparams(
+                    &qparams.scale.tensor,
+                    &qparams.offset.unwrap().tensor,
+                    tch::Kind::QInt8,
+                ),
+            },
+            QuantizationScheme::PerTensorSymmetric(_) => {
+                tensor.tensor.quantize_per_tensor_tensor_qparams(
+                    &qparams.scale.tensor,
+                    &tch::Tensor::zeros_like(&qparams.scale.tensor),
+                    tch::Kind::QInt8,
+                )
+            }
+        };
+
+        TchQTensor {
+            qtensor: TchTensor::new(qtensor),
+            scheme: scheme.clone(),
+        }
     }

-    fn dequantize<const D: usize>(
-        tensor: QuantizedTensor<Self, D>,
-        _strategy: &QuantizationStrategy,
-    ) -> FloatTensor<Self, D> {
-        TchTensor::new(tensor.tensor.dequantize().to_kind(E::KIND))
+    fn dequantize<const D: usize>(tensor: QuantizedTensor<Self, D>) -> FloatTensor<Self, D> {
+        TchTensor::new(tensor.qtensor.tensor.dequantize().to_kind(E::KIND))
     }

     fn q_shape<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Shape<D> {
-        tensor.shape()
+        tensor.qtensor.shape()
     }

     fn q_device<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> LibTorchDevice {
-        tensor.tensor.device().into()
+        tensor.qtensor.tensor.device().into()
     }

     fn q_reshape<const D1: usize, const D2: usize>(
         tensor: QuantizedTensor<Self, D1>,
         shape: Shape<D2>,
     ) -> QuantizedTensor<Self, D2> {
-        TchOps::reshape(tensor, shape)
+        TchQTensor {
+            qtensor: TchOps::reshape(tensor.qtensor, shape),
+            scheme: tensor.scheme,
+        }
     }

-    async fn q_into_data<const D: usize>(
-        tensor: QuantizedTensor<Self, D>,
-        strategy: QuantizationStrategy,
-    ) -> TensorData {
+    async fn q_into_data<const D: usize>(tensor: QuantizedTensor<Self, D>) -> TensorData {
         let shape = Self::q_shape(&tensor);
         let tensor = Self::q_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
+        let strategy = tensor.strategy();

         // To get the integer values we have to call `int_repr()`
-        let values: Result<Vec<i8>, tch::TchError> = tensor.tensor.int_repr().try_into();
+        let values: Result<Vec<i8>, tch::TchError> = tensor.qtensor.tensor.int_repr().try_into();

         TensorData::quantized(values.unwrap(), shape, strategy)
     }
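A subtlety in the tch ops above: LibTorch stores scale and zero-point inside the quantized `tch::Tensor` itself, which is why `q_into_data` can drop its strategy argument and rebuild one on demand (see `TchQTensor::strategy()` in the tensor.rs hunk below). A sketch of that recovery for the affine case, using only accessors that appear in this diff:

```rust
// Rebuild a strategy from a quantized tch tensor (affine int8 case).
fn recover_affine_strategy(qtensor: &tch::Tensor) -> QuantizationStrategy {
    let scale = qtensor.q_scale() as f32;
    let offset = qtensor.q_zero_point() as i8;
    QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(scale, offset))
}
```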
@@ -1,11 +1,11 @@
 use super::TchOps;
-use crate::{element::TchElement, LibTorch, LibTorchDevice, TchShape, TchTensor};
+use crate::{element::TchElement, LibTorch, LibTorchDevice, QuantElement, TchShape, TchTensor};
 use burn_tensor::{
     backend::Backend, ops::FloatTensorOps, Distribution, ElementConversion, Shape, TensorData,
 };
 use std::ops::Range;

-impl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {
+impl<E: TchElement, Q: QuantElement> FloatTensorOps<Self> for LibTorch<E, Q> {
     fn float_from_data<const D: usize>(
         data: TensorData,
         device: &LibTorchDevice,
@ -1,5 +1,11 @@
|
|||
use crate::{element::TchElement, LibTorch, LibTorchDevice};
|
||||
use burn_tensor::{ops::FloatTensorOps, Element, Shape, TensorData};
|
||||
use crate::{LibTorchDevice, QuantElement};
|
||||
use burn_tensor::{
|
||||
quantization::{
|
||||
AffineQuantization, QTensorPrimitive, QuantizationScheme, QuantizationStrategy,
|
||||
QuantizationType, SymmetricQuantization,
|
||||
},
|
||||
Element, Shape, TensorData,
|
||||
};
|
||||
use libc::c_void;
|
||||
use std::{marker::PhantomData, sync::Arc};
|
||||
|
||||
|
@ -139,14 +145,6 @@ impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<E: TchElement, const D: usize> std::ops::Add for TchTensor<E, D> {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
LibTorch::float_add(self, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
|
||||
pub(crate) fn shape(&self) -> Shape<D> {
|
||||
Shape::from(self.tensor.size())
|
||||
|
@@ -314,9 +312,51 @@ impl<E: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> T
     }
 }
 
+/// A quantized tensor for the tch backend.
+#[derive(Clone, Debug)]
+pub struct TchQTensor<Q: QuantElement, const D: usize> {
+    /// The quantized tensor.
+    pub qtensor: TchTensor<Q, D>,
+    /// The quantization scheme.
+    pub scheme: QuantizationScheme,
+}
+
+impl<Q: QuantElement, const D: usize> QTensorPrimitive for TchQTensor<Q, D> {
+    fn scheme(&self) -> &QuantizationScheme {
+        &self.scheme
+    }
+
+    fn strategy(&self) -> QuantizationStrategy {
+        match &self.scheme {
+            QuantizationScheme::PerTensorAffine(dtype) => match dtype {
+                QuantizationType::QInt8 => {
+                    let scale = self.qtensor.tensor.q_scale();
+                    let offset = self.qtensor.tensor.q_zero_point();
+                    QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(
+                        scale as f32,
+                        offset as i8,
+                    ))
+                }
+            },
+            QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
+                QuantizationType::QInt8 => {
+                    let scale = self.qtensor.tensor.q_scale();
+                    QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(
+                        scale as f32,
+                    ))
+                }
+            },
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use crate::LibTorch;
 
     use super::*;
+    use burn_tensor::ops::QTensorOps;
+    use burn_tensor::quantization::QuantizationParametersPrimitive;
+    use burn_tensor::{Distribution, Tensor, TensorPrimitive};
     use rand::prelude::StdRng;
     use rand::SeedableRng;
@@ -376,4 +416,27 @@ mod tests {
             tensor_1.to_data().as_slice::<f32>().unwrap()
         );
     }
+
+    #[test]
+    fn should_support_qtensor_strategy() {
+        let tensor = TchTensor::<f32, 1>::from_data(
+            TensorData::from([-1.8, -1.0, 0.0, 0.5]),
+            tch::Device::Cpu,
+        );
+        let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8);
+        let qparams = QuantizationParametersPrimitive {
+            scale: TchTensor::from_data(TensorData::from([0.009_019_608]), tch::Device::Cpu),
+            offset: Some(TchTensor::from_data(
+                TensorData::from([72]),
+                tch::Device::Cpu,
+            )),
+        };
+        let qtensor: TchQTensor<i8, 1> = LibTorch::quantize(tensor, &scheme, qparams);
+
+        assert_eq!(qtensor.scheme(), &scheme);
+        assert_eq!(
+            qtensor.strategy(),
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.009_019_608, 72))
+        );
+    }
 }
@@ -18,10 +18,7 @@ impl<const D: usize, B: AutodiffBackend> Tensor<B, D> {
             TensorPrimitive::Float(tensor) => B::grad(tensor, grads)
                 .map(TensorPrimitive::Float)
                 .map(Tensor::new),
-            TensorPrimitive::QFloat {
-                tensor: _,
-                strategy: _,
-            } => B::grad(&self.primitive.clone().tensor(), grads)
+            TensorPrimitive::QFloat(_tensor) => B::grad(&self.primitive.clone().tensor(), grads)
                 .map(TensorPrimitive::Float)
                 .map(Tensor::new),
         }
@@ -33,12 +30,11 @@ impl<const D: usize, B: AutodiffBackend> Tensor<B, D> {
             TensorPrimitive::Float(tensor) => B::grad_remove(tensor, grads)
                 .map(TensorPrimitive::Float)
                 .map(Tensor::new),
-            TensorPrimitive::QFloat {
-                tensor: _,
-                strategy: _,
-            } => B::grad_remove(&self.primitive.clone().tensor(), grads)
-                .map(TensorPrimitive::Float)
-                .map(Tensor::new),
+            TensorPrimitive::QFloat(_tensor) => {
+                B::grad_remove(&self.primitive.clone().tensor(), grads)
+                    .map(TensorPrimitive::Float)
+                    .map(Tensor::new)
+            }
         }
     }
@@ -49,10 +45,7 @@ impl<const D: usize, B: AutodiffBackend> Tensor<B, D> {
             TensorPrimitive::Float(tensor) => {
                 B::grad_replace(tensor, grads, grad.primitive.tensor())
             }
-            TensorPrimitive::QFloat {
-                tensor: _,
-                strategy: _,
-            } => B::grad_replace(
+            TensorPrimitive::QFloat(_tensor) => B::grad_replace(
                 &self.primitive.clone().tensor(),
                 grads,
                 grad.primitive.tensor(),
@@ -89,10 +82,7 @@ impl<B: AutodiffBackend> BasicAutodiffOps<B> for Float {
     ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive<D> {
         match tensor {
             TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::inner(tensor)),
-            TensorPrimitive::QFloat { tensor, strategy } => TensorPrimitive::QFloat {
-                tensor: B::q_inner(tensor),
-                strategy,
-            },
+            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_inner(tensor)),
         }
     }
@@ -101,10 +91,7 @@ impl<B: AutodiffBackend> BasicAutodiffOps<B> for Float {
     ) -> <Self as TensorKind<B>>::Primitive<D> {
         match inner {
             TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::from_inner(tensor)),
-            TensorPrimitive::QFloat { tensor, strategy } => TensorPrimitive::QFloat {
-                tensor: B::q_from_inner(tensor),
-                strategy,
-            },
+            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_from_inner(tensor)),
         }
     }
 }
@@ -1687,10 +1687,7 @@ impl<B: Backend> BasicOps<B> for Float {
     fn shape<const D: usize>(tensor: &Self::Primitive<D>) -> Shape<D> {
         match tensor {
             TensorPrimitive::Float(tensor) => B::float_shape(tensor),
-            TensorPrimitive::QFloat {
-                tensor,
-                strategy: _,
-            } => B::q_shape(tensor),
+            TensorPrimitive::QFloat(tensor) => B::q_shape(tensor),
         }
     }
@@ -1702,10 +1699,7 @@ impl<B: Backend> BasicOps<B> for Float {
             TensorPrimitive::Float(tensor) => {
                 TensorPrimitive::Float(B::float_reshape(tensor, shape))
             }
-            TensorPrimitive::QFloat { tensor, strategy } => TensorPrimitive::QFloat {
-                tensor: B::q_reshape(tensor, shape),
-                strategy,
-            },
+            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_reshape(tensor, shape)),
         }
     }
@@ -1744,10 +1738,7 @@ impl<B: Backend> BasicOps<B> for Float {
     fn device<const D: usize>(tensor: &Self::Primitive<D>) -> <B as Backend>::Device {
         match tensor {
             TensorPrimitive::Float(tensor) => B::float_device(tensor),
-            TensorPrimitive::QFloat {
-                tensor,
-                strategy: _,
-            } => B::q_device(tensor),
+            TensorPrimitive::QFloat(tensor) => B::q_device(tensor),
         }
     }
@@ -1761,16 +1752,13 @@ impl<B: Backend> BasicOps<B> for Float {
     async fn into_data_async<const D: usize>(tensor: Self::Primitive<D>) -> TensorData {
         match tensor {
             TensorPrimitive::Float(tensor) => B::float_into_data(tensor).await,
-            TensorPrimitive::QFloat { tensor, strategy } => B::q_into_data(tensor, strategy).await,
+            TensorPrimitive::QFloat(tensor) => B::q_into_data(tensor).await,
         }
     }
 
     fn from_data<const D: usize>(data: TensorData, device: &B::Device) -> Self::Primitive<D> {
         match data.dtype {
-            DType::QFloat(strategy) => TensorPrimitive::QFloat {
-                tensor: B::q_from_data(data, device),
-                strategy,
-            },
+            DType::QFloat(_strategy) => TensorPrimitive::QFloat(B::q_from_data(data, device)),
             _ => TensorPrimitive::Float(B::float_from_data(data, device)),
         }
     }
@@ -1,13 +1,14 @@
 use alloc::vec::Vec;
 use core::convert::TryInto;
 
+use crate::check;
 use crate::check::TensorCheck;
 use crate::ops::FullPrecisionBackend;
+use crate::quantization::{QuantizationParameters, QuantizationScheme};
 use crate::tensor::backend::Backend;
 use crate::tensor::stats;
 use crate::tensor::{Distribution, Shape, TensorData};
 use crate::Tensor;
-use crate::{check, QuantizationStrategy};
 use crate::{Int, TensorPrimitive};
 
 impl<const D: usize, B> Tensor<B, D>
@@ -270,10 +271,7 @@ where
     pub fn is_require_grad(&self) -> bool {
         match &self.primitive {
             TensorPrimitive::Float(tensor) => B::float_is_require_grad(tensor),
-            TensorPrimitive::QFloat {
-                tensor,
-                strategy: _,
-            } => B::q_is_require_grad(tensor),
+            TensorPrimitive::QFloat(tensor) => B::q_is_require_grad(tensor),
         }
     }
@@ -286,10 +284,9 @@ where
             TensorPrimitive::Float(tensor) => {
                 TensorPrimitive::Float(B::float_set_require_grad(tensor, require_grad))
             }
-            TensorPrimitive::QFloat { tensor, strategy } => TensorPrimitive::QFloat {
-                tensor: B::q_set_require_grad(tensor, require_grad),
-                strategy,
-            },
+            TensorPrimitive::QFloat(tensor) => {
+                TensorPrimitive::QFloat(B::q_set_require_grad(tensor, require_grad))
+            }
         };
         Self::new(primitive)
     }
@@ -315,20 +312,26 @@ where
             .div_scalar(n as f32 - correction_factor as f32)
     }
 
-    /// Convert the tensor to a lower precision data type based on the quantization strategy.
+    /// Convert the tensor to a lower precision data type based on the quantization scheme.
     ///
     /// # Arguments
     ///
-    /// * `strategy` - The quantization strategy.
+    /// * `scheme` - The quantization scheme.
+    /// * `qparams` - The pre-computed quantization parameters.
     ///
     /// # Returns
     ///
    /// The quantized tensor.
-    pub fn quantize(self, strategy: QuantizationStrategy) -> Tensor<B, D> {
-        Tensor::new(TensorPrimitive::QFloat {
-            tensor: B::quantize(self.primitive.tensor(), &strategy),
-            strategy,
-        })
+    pub fn quantize(
+        self,
+        scheme: &QuantizationScheme,
+        qparams: QuantizationParameters<B>,
+    ) -> Tensor<B, D> {
+        Tensor::new(TensorPrimitive::QFloat(B::quantize(
+            self.primitive.tensor(),
+            scheme,
+            qparams.into(),
+        )))
     }
 
     /// Convert the tensor back to a higher precision data type.
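Taken together with the calibration and scheme types introduced further down in this commit, the user-facing flow now looks roughly like this (a sketch, not code from the commit; `B` stands for any backend and the helper name is made up):

```rust
use burn_tensor::quantization::{
    Calibration, MinMaxCalibration, QuantizationScheme, QuantizationType,
};
use burn_tensor::{backend::Backend, Tensor};

fn quantize_roundtrip<B: Backend>(tensor: Tensor<B, 2>) -> Tensor<B, 2> {
    // 1. Observe the input range.
    let calibration = MinMaxCalibration {};
    let range = calibration.compute_range(&tensor);

    // 2. Derive the quantization parameters for the chosen scheme.
    let scheme = QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8);
    let qparams = scheme.compute_q_params(range);

    // 3. Quantize, then dequantize back to floats.
    tensor.quantize(&scheme, qparams).dequantize()
}
```

Note that `qparams.offset` is only populated for the affine scheme; the symmetric scheme carries a scale alone.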
@@ -1,4 +1,4 @@
-use crate::{backend::Backend, QuantizationStrategy};
+use crate::backend::Backend;
 
 /// A type-level representation of the kind of a float tensor
 #[derive(Clone, Debug)]
@@ -18,19 +18,14 @@ pub enum TensorPrimitive<B: Backend, const D: usize> {
     /// Float tensor primitive.
     Float(B::FloatTensorPrimitive<D>),
     /// Quantized float tensor primitive.
-    QFloat {
-        /// The underlying quantized tensor.
-        tensor: B::QuantizedTensorPrimitive<D>,
-        /// The tensor quantization strategy.
-        strategy: QuantizationStrategy,
-    },
+    QFloat(B::QuantizedTensorPrimitive<D>),
 }
 
 impl<B: Backend, const D: usize> TensorPrimitive<B, D> {
     /// Returns the full tensor representation.
     pub fn tensor(self) -> B::FloatTensorPrimitive<D> {
         match self {
-            Self::QFloat { tensor, strategy } => B::dequantize(tensor, &strategy),
+            Self::QFloat(tensor) => B::dequantize(tensor),
             Self::Float(tensor) => tensor,
         }
     }
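Since `QFloat` now carries only the backend primitive, `tensor()` is the single point where dequantization happens, so any float op that lacks a quantized counterpart can fall back through it. A minimal sketch (the helper is hypothetical, not part of this commit):

```rust
use burn_tensor::{backend::Backend, TensorPrimitive};

// Dequantizes on demand: QFloat goes through `B::dequantize`, Float is a no-op.
fn to_float<B: Backend, const D: usize>(
    primitive: TensorPrimitive<B, D>,
) -> B::FloatTensorPrimitive<D> {
    primitive.tensor()
}
```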
@@ -1,8 +1,8 @@
 use alloc::string::String;
 pub use burn_common::sync_type::SyncType;
 
-use crate::ops::*;
 use crate::tensor::Element;
+use crate::{ops::*, quantization::QTensorPrimitive};
 
 use super::{BackendBridge, DeviceOps};
@@ -87,7 +87,11 @@ pub trait Backend:
     type BoolTensorPrimitive<const D: usize>: Clone + Send + 'static + core::fmt::Debug;
 
     /// Tensor primitive to be used for all quantized operations.
-    type QuantizedTensorPrimitive<const D: usize>: Clone + Send + 'static + core::fmt::Debug;
+    type QuantizedTensorPrimitive<const D: usize>: QTensorPrimitive
+        + Clone
+        + Send
+        + 'static
+        + core::fmt::Debug;
 
     /// If autodiff is enabled.
     fn ad_enabled() -> bool {
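The added `QTensorPrimitive` bound means backend-generic code can now ask any quantized primitive for its scheme or strategy. A small sketch (the `describe` helper is hypothetical):

```rust
use burn_tensor::{backend::Backend, quantization::QTensorPrimitive};

// Compiles for any backend because the associated type is bound by QTensorPrimitive.
fn describe<B: Backend, const D: usize>(qtensor: &B::QuantizedTensorPrimitive<D>) -> String {
    format!("{:?}", qtensor.scheme())
}
```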
@@ -7,8 +7,9 @@ use alloc::vec::Vec;
 use half::{bf16, f16};
 
 use crate::{
-    tensor::Shape, DType, Distribution, Element, ElementConversion, Quantization,
-    QuantizationStrategy,
+    quantization::{Quantization, QuantizationStrategy},
+    tensor::Shape,
+    DType, Distribution, Element, ElementConversion,
 };
 
 use num_traits::pow::Pow;
@@ -66,10 +67,14 @@ impl TensorData {
         }
     }
 
+    fn try_as_slice<E: Element>(&self) -> Result<&[E], DataError> {
+        bytemuck::checked::try_cast_slice(&self.bytes).map_err(DataError::CastError)
+    }
+
     /// Returns the immutable slice view of the tensor data.
     pub fn as_slice<E: Element>(&self) -> Result<&[E], DataError> {
         if E::dtype() == self.dtype {
-            bytemuck::checked::try_cast_slice(&self.bytes).map_err(DataError::CastError)
+            self.try_as_slice()
         } else {
             Err(DataError::TypeMismatch(format!(
                 "Invalid target element type (expected {:?}, got {:?})",
@@ -598,10 +603,10 @@ impl core::fmt::Display for TensorData {
             DType::Bool => format!("{:?}", self.as_slice::<bool>().unwrap()),
             DType::QFloat(q) => match &q {
                 QuantizationStrategy::PerTensorAffineInt8(_) => {
-                    format!("{:?} {q:?}", self.as_slice::<i8>().unwrap())
+                    format!("{:?} {q:?}", self.try_as_slice::<i8>().unwrap())
                 }
                 QuantizationStrategy::PerTensorSymmetricInt8(_) => {
-                    format!("{:?} {q:?}", self.as_slice::<i8>().unwrap())
+                    format!("{:?} {q:?}", self.try_as_slice::<i8>().unwrap())
                 }
             },
         };
@@ -1,6 +1,6 @@
 use core::cmp::Ordering;
 
-use crate::{cast::ToElement, Distribution, QuantizationStrategy};
+use crate::{cast::ToElement, quantization::QuantizationStrategy, Distribution};
 use half::{bf16, f16};
 use rand::RngCore;
 use serde::{Deserialize, Serialize};
@@ -4,14 +4,12 @@ mod api;
 mod data;
 mod distribution;
 mod element;
-mod quantization_strategy;
 mod shape;
 
 pub use api::*;
 pub use data::*;
 pub use distribution::*;
 pub use element::*;
-pub use quantization_strategy::*;
 pub use shape::*;
 
 /// The activation module.
@@ -32,6 +30,9 @@ pub mod module;
 /// Operations on tensors module.
 pub mod ops;
 
+/// Tensor quantization module.
+pub mod quantization;
+
 #[cfg(feature = "experimental-named-tensor")]
 mod named;
 #[cfg(feature = "experimental-named-tensor")]
@@ -1,6 +1,10 @@
 use core::future::Future;
 
-use crate::{backend::Backend, Device, QuantizationStrategy, Shape, TensorData};
+use crate::{
+    backend::Backend,
+    quantization::{QuantizationParametersPrimitive, QuantizationScheme},
+    Device, Shape, TensorData,
+};
 
 use super::{FloatTensor, QuantizedTensor};
@@ -19,17 +23,15 @@ pub trait QTensorOps<B: Backend> {
     /// The tensor with the given data.
     fn q_from_data<const D: usize>(data: TensorData, device: &Device<B>) -> QuantizedTensor<B, D>;
 
-    /// Convert the tensor to a lower precision data type based on the quantization strategy.
+    /// Convert the tensor to a lower precision data type based on the quantization scheme and parameters.
     fn quantize<const D: usize>(
         tensor: FloatTensor<B, D>,
-        strategy: &QuantizationStrategy,
+        scheme: &QuantizationScheme,
+        qparams: QuantizationParametersPrimitive<B>,
     ) -> QuantizedTensor<B, D>;
 
-    /// Convert the tensor back to a higher precision data type based on the quantization strategy.
-    fn dequantize<const D: usize>(
-        tensor: QuantizedTensor<B, D>,
-        strategy: &QuantizationStrategy,
-    ) -> FloatTensor<B, D>;
+    /// Convert the tensor back to a higher precision data type.
+    fn dequantize<const D: usize>(tensor: QuantizedTensor<B, D>) -> FloatTensor<B, D>;
 
     /// Gets the shape of the tensor.
     ///
@@ -79,7 +81,6 @@ pub trait QTensorOps<B: Backend> {
     /// The data structure with the tensor's data.
     fn q_into_data<const D: usize>(
         tensor: QuantizedTensor<B, D>,
-        strategy: QuantizationStrategy,
     ) -> impl Future<Output = TensorData> + Send;
 
     /// Sets the `require_grad` flag of a tensor.
@@ -0,0 +1,34 @@
+use crate::{backend::Backend, Tensor};
+
+/// The observed input calibration range.
+#[derive(Clone, Debug)]
+pub struct CalibrationRange<B: Backend> {
+    /// Minimum observed value.
+    pub min: Tensor<B, 1>,
+    /// Maximum observed value.
+    pub max: Tensor<B, 1>,
+}
+
+/// Calibration method used to compute the quantization range mapping.
+pub trait Calibration {
+    /// Compute the input tensor range.
+    fn compute_range<B: Backend, const D: usize>(
+        &self,
+        tensor: &Tensor<B, D>,
+    ) -> CalibrationRange<B>;
+}
+
+/// Computes the per-tensor quantization range mapping based on the min and max values.
+pub struct MinMaxCalibration {}
+
+impl Calibration for MinMaxCalibration {
+    fn compute_range<B: Backend, const D: usize>(
+        &self,
+        tensor: &Tensor<B, D>,
+    ) -> CalibrationRange<B> {
+        let min = tensor.clone().min();
+        let max = tensor.clone().max();
+
+        CalibrationRange { min, max }
+    }
+}
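The trait keeps the door open for other range estimators. As a purely hypothetical sketch (not part of this commit), a calibration that clamps outliers before the min/max reduction could look like this, assuming `Tensor::clamp` with scalar bounds:

```rust
use burn_tensor::quantization::{Calibration, CalibrationRange};
use burn_tensor::{backend::Backend, Tensor};

/// Hypothetical: clamp extreme values so a single outlier does not blow up the scale.
pub struct ClampedCalibration {
    pub bound: f32,
}

impl Calibration for ClampedCalibration {
    fn compute_range<B: Backend, const D: usize>(
        &self,
        tensor: &Tensor<B, D>,
    ) -> CalibrationRange<B> {
        let clamped = tensor.clone().clamp(-self.bound, self.bound);
        CalibrationRange {
            min: clamped.clone().min(),
            max: clamped.max(),
        }
    }
}
```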
@@ -0,0 +1,11 @@
+mod calibration;
+mod parameters;
+mod primitive;
+mod scheme;
+mod strategy;
+
+pub use calibration::*;
+pub use parameters::*;
+pub use primitive::*;
+pub use scheme::*;
+pub use strategy::*;
@@ -0,0 +1,35 @@
+use crate::{backend::Backend, Int, Tensor};
+
+/// The quantization parameters.
+#[derive(Clone, Debug)]
+pub struct QuantizationParameters<B: Backend> {
+    /// The scaling factor.
+    pub scale: Tensor<B, 1>,
+    /// The zero-point offset.
+    pub offset: Option<Tensor<B, 1, Int>>,
+}
+
+/// The quantization parameters primitive.
+///
+/// # Remarks
+///
+/// This is a low-level struct used internally by the library to provide the quantization parameters
+/// to the backends. It is not designed for direct usage by users, and it is not recommended to import
+/// or use this struct directly.
+///
+/// Users should prefer the [QuantizationParameters] struct, which is designed for public use.
+pub struct QuantizationParametersPrimitive<B: Backend> {
+    /// The scaling factor.
+    pub scale: B::FloatTensorPrimitive<1>,
+    /// The zero-point offset.
+    pub offset: Option<B::IntTensorPrimitive<1>>,
+}
+
+impl<B: Backend> From<QuantizationParameters<B>> for QuantizationParametersPrimitive<B> {
+    fn from(value: QuantizationParameters<B>) -> Self {
+        QuantizationParametersPrimitive {
+            scale: value.scale.primitive.tensor(),
+            offset: value.offset.map(|x| x.primitive),
+        }
+    }
+}
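A short sketch of how the split plays out in practice: users build `QuantizationParameters` from ordinary tensors, and the `From` impl above strips them down to primitives at the backend boundary (the `quantize` op receives `qparams.into()`). The helper name is made up for illustration:

```rust
use burn_tensor::quantization::{QuantizationParameters, QuantizationParametersPrimitive};
use burn_tensor::{backend::Backend, Int, Tensor};

// Hypothetical helper: converts user-facing params to the backend primitive form.
fn to_primitive<B: Backend>(
    scale: Tensor<B, 1>,
    offset: Option<Tensor<B, 1, Int>>,
) -> QuantizationParametersPrimitive<B> {
    QuantizationParameters { scale, offset }.into()
}
```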
@@ -0,0 +1,13 @@
+use super::{QuantizationScheme, QuantizationStrategy};
+
+/// Quantized tensor primitive.
+pub trait QTensorPrimitive {
+    /// Returns the quantization scheme for the given tensor.
+    fn scheme(&self) -> &QuantizationScheme;
+    /// Returns the quantization strategy for the given tensor.
+    ///
+    /// # Remarks
+    /// Retrieving the quantization strategy with its corresponding parameters might require
+    /// synchronization on the backend.
+    fn strategy(&self) -> QuantizationStrategy;
+}
@@ -0,0 +1,71 @@
+use crate::{backend::Backend, Int, Tensor};
+
+use super::{CalibrationRange, QuantizationParameters};
+
+/// Quantization data type.
+#[derive(Clone, Debug, PartialEq)]
+pub enum QuantizationType {
+    /// 8-bit signed integer.
+    QInt8,
+}
+
+/// Quantization scheme.
+#[derive(Clone, Debug, PartialEq)]
+pub enum QuantizationScheme {
+    /// Per-tensor affine/asymmetric quantization.
+    PerTensorAffine(QuantizationType),
+    /// Per-tensor symmetric quantization.
+    PerTensorSymmetric(QuantizationType),
+    // /// Per-channel affine/asymmetric quantization.
+    // PerChannelAffine,
+    // /// Per-channel symmetric quantization.
+    // PerChannelSymmetric,
+}
+
+/// Round the tensor to the nearest integer.
+fn round<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> Tensor<B, D, Int> {
+    tensor.add_scalar(0.5).int()
+}
+
+impl QuantizationScheme {
+    /// Compute the quantization parameters.
+    pub fn compute_q_params<B: Backend>(
+        &self,
+        range: CalibrationRange<B>,
+    ) -> QuantizationParameters<B> {
+        match self {
+            QuantizationScheme::PerTensorAffine(dtype) => match dtype {
+                QuantizationType::QInt8 => {
+                    // Quantized range `[a, b]`
+                    let a = i8::MIN as i32;
+                    let b = i8::MAX as i32;
+
+                    // Input range `[alpha, beta]`
+                    let input_range = range.max.clone().sub(range.min.clone());
+
+                    QuantizationParameters {
+                        scale: input_range.clone().div_scalar(b - a),
+                        offset: Some(round(
+                            (range.max.mul_scalar(a) - range.min.mul_scalar(b)).div(input_range),
+                        )),
+                    }
+                }
+            },
+            QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
+                QuantizationType::QInt8 => {
+                    // Quantized range `[a, b]`
+                    let b = i8::MAX as i32;
+                    let a = -b;
+
+                    // Compute scale to convert an input value in range `[-alpha, alpha]`
+                    let values_range = range.min.abs().max_pair(range.max.abs()).mul_scalar(2);
+
+                    QuantizationParameters {
+                        scale: values_range.div_scalar(b - a),
+                        offset: None,
+                    }
+                }
+            },
+        }
+    }
+}
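As a quick sanity check of the affine branch against the values asserted in the tests below (calibration range [-1.8, 0.5], QInt8), the arithmetic works out as follows; this is a standalone illustration, not code from the commit:

```rust
fn main() {
    // scale  = (beta - alpha) / (b - a) = (0.5 - (-1.8)) / 255 ≈ 0.009019608
    let scale = (0.5_f32 - (-1.8)) / 255.0;
    // offset = round((beta*a - alpha*b) / (beta - alpha))
    //        = round((0.5 * -128 - (-1.8) * 127) / 2.3)
    //        = round(164.6 / 2.3) = round(71.57) = 72
    // (the tensor-level `round` above adds 0.5 and truncates)
    let offset = ((0.5_f32 * -128.0 - (-1.8) * 127.0) / 2.3 + 0.5) as i32;
    assert!((scale - 0.009_019_608).abs() < 1e-6);
    assert_eq!(offset, 72);
}
```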
@@ -8,6 +8,10 @@ use burn_common::{iter_par, run_par};
 use num_traits::{Float, PrimInt};
 use serde::{Deserialize, Serialize};
 
+use super::{QuantizationScheme, QuantizationType};
+
+// NOTE: QuantizationStrategy is used for TensorData (sync).
+
 /// Quantization strategy.
 #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
 pub enum QuantizationStrategy {
@@ -17,6 +21,20 @@ pub enum QuantizationStrategy {
     PerTensorSymmetricInt8(SymmetricQuantization<f32, i8>),
 }
 
+impl QuantizationStrategy {
+    /// Returns the corresponding quantization scheme.
+    pub fn scheme(&self) -> QuantizationScheme {
+        match self {
+            QuantizationStrategy::PerTensorAffineInt8(_) => {
+                QuantizationScheme::PerTensorAffine(QuantizationType::QInt8)
+            }
+            QuantizationStrategy::PerTensorSymmetricInt8(_) => {
+                QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8)
+            }
+        }
+    }
+}
+
 /// Quantization scheme to convert elements of a higher precision data type `E` to a lower precision
 /// data type `Q` and vice-versa.
 pub trait Quantization<E: Float, Q: PrimInt> {
@@ -41,6 +59,17 @@ pub struct AffineQuantization<E: Float, Q: PrimInt, A: PrimInt> {
     _a: PhantomData<A>,
 }
 
+impl<E: Float, Q: PrimInt, A: PrimInt> AffineQuantization<E, Q, A> {
+    /// Initialize an affine quantization scheme with the given parameters.
+    pub fn init(scale: E, offset: Q) -> Self {
+        Self {
+            scale,
+            offset,
+            _a: PhantomData,
+        }
+    }
+}
+
 impl<E: Float, Q: PrimInt, A: PrimInt> Quantization<E, Q> for AffineQuantization<E, Q, A> {
     fn new(alpha: E, beta: E) -> Self {
         // Q range `[a, b]`
@@ -49,11 +78,10 @@ impl<E: Float, Q: PrimInt, A: PrimInt> Quantization<E, Q> for AffineQuantization
 
         // Compute scale and offset to convert a floating point value in range `[alpha, beta]` to the quantized range
         let range = beta - alpha;
-        Self {
-            scale: range / (b - a),
-            offset: Q::from(E::round(((beta * a) - (alpha * b)) / range)).unwrap(),
-            _a: PhantomData,
-        }
+        Self::init(
+            range / (b - a),
+            Q::from(E::round(((beta * a) - (alpha * b)) / range)).unwrap(),
+        )
     }
 
     fn quantize(&self, values: &[E]) -> Vec<Q> {
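A small sketch of the slice-level API, assuming the `Quantization` trait also exposes a `dequantize` counterpart mirroring `quantize` (only `quantize` is visible in this hunk); the third type parameter is the accumulation integer type from the struct definition above:

```rust
use burn_tensor::quantization::{AffineQuantization, Quantization};

fn main() {
    // Derive scale/offset from the observed float range [-1.8, 0.5].
    let affine = AffineQuantization::<f32, i8, i32>::new(-1.8, 0.5);
    let q: Vec<i8> = affine.quantize(&[-1.8, -1.0, 0.0, 0.5]);
    // Assumed counterpart on the same trait.
    let values: Vec<f32> = affine.dequantize(&q);
    println!("{q:?} -> {values:?}");
}
```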
@@ -97,6 +125,16 @@ pub struct SymmetricQuantization<E: Float, Q: PrimInt> {
     _q: PhantomData<Q>,
 }
 
+impl<E: Float, Q: PrimInt> SymmetricQuantization<E, Q> {
+    /// Initialize a symmetric quantization scheme with the given parameters.
+    pub fn init(scale: E) -> Self {
+        Self {
+            scale,
+            _q: PhantomData,
+        }
+    }
+}
+
 impl<E: Float, Q: PrimInt> Quantization<E, Q> for SymmetricQuantization<E, Q> {
     fn new(alpha: E, beta: E) -> Self {
         assert!(
@@ -110,10 +148,7 @@ impl<E: Float, Q: PrimInt> Quantization<E, Q> for SymmetricQuantization<E, Q> {
 
         // Compute scale to convert a floating point value in range `[-alpha, alpha]` to the quantized range
        let alpha = alpha.abs().max(beta.abs());
-        Self {
-            scale: (alpha + alpha) / (b - a),
-            _q: PhantomData,
-        }
+        Self::init((alpha + alpha) / (b - a))
     }
 
     fn quantize(&self, values: &[E]) -> Vec<Q> {
@@ -2,6 +2,7 @@ mod activation;
 mod clone_invariance;
 mod module;
 mod ops;
+mod quantization;
 mod stats;
 
 #[allow(missing_docs)]
@@ -113,5 +114,9 @@ macro_rules! testgen_all {
 
         // test padding
         burn_tensor::testgen_padding!();
+
+        // test quantization
+        burn_tensor::testgen_calibration!();
+        burn_tensor::testgen_scheme!();
     };
 }
@@ -0,0 +1,26 @@
+#[burn_tensor_testgen::testgen(calibration)]
+mod tests {
+    use super::*;
+    use burn_tensor::{
+        quantization::{Calibration, MinMaxCalibration, QuantizationType},
+        Tensor, TensorData,
+    };
+
+    #[test]
+    fn min_max_calibration_range() {
+        let tensor =
+            Tensor::<TestBackend, 1>::from_floats([-1.8, -1.0, 0.0, 0.5], &Default::default());
+        let calibration = MinMaxCalibration {};
+
+        let range = calibration.compute_range(&tensor);
+
+        range
+            .min
+            .into_data()
+            .assert_eq(&TensorData::from([-1.8]), false);
+        range
+            .max
+            .into_data()
+            .assert_eq(&TensorData::from([0.5]), false);
+    }
+}
@@ -0,0 +1,2 @@
+mod calibration;
+mod scheme;
@@ -0,0 +1,48 @@
+#[burn_tensor_testgen::testgen(scheme)]
+mod tests {
+    use super::*;
+    use burn_tensor::{
+        quantization::{CalibrationRange, QuantizationScheme, QuantizationType},
+        Tensor, TensorData,
+    };
+
+    #[test]
+    fn per_tensor_affine_int8() {
+        let device = Default::default();
+        let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8);
+        let range = CalibrationRange {
+            min: Tensor::<TestBackend, 1>::from_floats([-1.8], &device),
+            max: Tensor::<TestBackend, 1>::from_floats([0.5], &device),
+        };
+
+        let qparams = scheme.compute_q_params(range);
+
+        qparams
+            .scale
+            .into_data()
+            .assert_approx_eq(&TensorData::from([0.009_019_608]), 9);
+        qparams
+            .offset
+            .unwrap()
+            .into_data()
+            .assert_eq(&TensorData::from([72]), false);
+    }
+
+    #[test]
+    fn per_tensor_symmetric_int8() {
+        let device = Default::default();
+        let scheme = QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8);
+        let range = CalibrationRange {
+            min: Tensor::<TestBackend, 1>::from_floats([-1.8], &device),
+            max: Tensor::<TestBackend, 1>::from_floats([0.5], &device),
+        };
+
+        let qparams = scheme.compute_q_params(range);
+
+        qparams
+            .scale
+            .into_data()
+            .assert_approx_eq(&TensorData::from([0.014_173_228]), 9);
+        assert!(qparams.offset.is_none());
+    }
+}