Add more quantization support for burn-jit (#2275)

* Add cubecl quantization kernels and QTensorOps for burn-jit

* Fix typo

* Fix output vec factor

* Fix output dtype size_of

* Remove unused code in dequantize test

* Fix dequantize vectorization

* Handle tensors when number of elems is not a multiple of 4

* Support quantize for tensors with less than 4 elems (no vectorization)

* Fix equal 0 test

* Add quantize/dequantize tests

* Add q_to_device

* Refactor kernels for latest cubecl

* Intermediate i32 cast

* Fix size_of output type

* Use strict=false to ignore floating point precision issues with qparams equality

* Only check that lhs & rhs strategies match (but not strict on qparams values)

* Use assert_approx_eq on dequant values

* Reduce precision for flaky test

* Remove todo comment

* Add comment for cast to unsigned

* More comments

---------

Co-authored-by: louisfd <louisfd94@gmail.com>
Guillaume Lagrange, 2024-09-17 10:08:20 -04:00, committed by GitHub
commit aa79e36a8d (parent 834005eadb)
12 changed files with 744 additions and 50 deletions
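All of the changes below share one storage layout: four signed 8-bit quantized values are packed into a single u32, most-significant byte first, and a trailing group with fewer than four values is padded (the padding bytes are simply ignored when unpacking, as the commit notes above describe). As a rough host-side illustration of that round trip only — the names `pack_i8s` and `unpack_i8s` are invented for this sketch and do not appear in the diff:

// Sketch only: mirrors the packing convention used by the kernels in this commit.
fn pack_i8s(values: &[i8]) -> Vec<u32> {
    values
        .chunks(4)
        .map(|chunk| {
            // The most-significant byte holds the first value of the group.
            chunk.iter().enumerate().fold(0u32, |acc, (i, &v)| {
                acc | ((v as u8 as u32) << ((3 - i) * 8))
            })
        })
        .collect()
}

fn unpack_i8s(packed: &[u32], numel: usize) -> Vec<i8> {
    packed
        .iter()
        .enumerate()
        .flat_map(|(i, &word)| {
            // The last word may hold fewer than four valid values.
            let n = core::cmp::min(4, numel - i * 4);
            (0..n).map(move |j| ((word >> ((3 - j) * 8)) & 0xFF) as u8 as i8)
        })
        .collect()
}

fn main() {
    let values = [-128i8, -1, 0, 127, 35]; // 5 values -> 2 packed u32 words
    let packed = pack_i8s(&values);
    assert_eq!(unpack_i8s(&packed, values.len()), values.to_vec());
    println!("{:08x?}", packed);
}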

View File

@@ -34,7 +34,8 @@ where
     type FloatTensorPrimitive<const D: usize> = JitTensor<R, Self::FloatElem, D>;
     type IntTensorPrimitive<const D: usize> = JitTensor<R, Self::IntElem, D>;
     type BoolTensorPrimitive<const D: usize> = JitTensor<R, u32, D>;
-    type QuantizedTensorPrimitive<const D: usize> = QJitTensor<R, D>;
+    type QuantizedTensorPrimitive<const D: usize> =
+        QJitTensor<R, Self::FloatElem, Self::IntElem, D>;

     fn name() -> String {
         format!("jit<{}>", R::name())

View File

@@ -25,6 +25,8 @@ pub mod matmul;
 pub mod pool;
 /// Pseudo-random number generator kernels
 pub mod prng;
+/// Quantization operations
+pub mod quantization;
 /// Reduction algorithms
 pub mod reduce;

View File

@@ -0,0 +1,211 @@
use crate::tensor::{JitTensor, QJitTensor};
use crate::FloatElement;
use crate::{IntElement, JitElement, JitRuntime};
use burn_tensor::quantization::{QuantizationScheme, QuantizationType};
use cubecl::calculate_cube_count_elemwise;
use cubecl::prelude::*;
#[cube]
pub(crate) fn dequantize_affine_int8<F: Float>(value: i32, scale: F, offset: i32) -> F {
// x = scale * (x_q - offset)
scale * (F::cast_from(value) - F::cast_from(offset))
}
#[cube]
pub(crate) fn extract_i8(value: u32, offset: u32) -> i32 {
// Extract 8-bit segment
let value = (value >> offset) & 0xFF;
// Check if the value is negative by inspecting the MSB and subtract 256 if it is
// Subtract 0 or 256 to circumvent unsupported conditional assignment (let x = if {} else {};)
let sub = i32::cast_from(value & 0x80 != 0) * 256;
i32::cast_from(value) - sub
}
#[cube(launch_unchecked)]
pub(crate) fn dequantize_per_tensor_affine_int8_kernel(
input: &Tensor<u32>,
scale: &Tensor<f32>,
offset: &Tensor<i32>,
output: &mut Tensor<f32>,
#[comptime] vectorized: bool,
) {
if ABSOLUTE_POS >= output.len() {
return;
}
let scale = scale[0];
let offset = offset[0];
let num_packed = 4;
let value = input[ABSOLUTE_POS];
let output_pos = ABSOLUTE_POS * num_packed;
if vectorized {
let vectorization_factor = vectorization_of(input);
#[unroll]
for i in 0..vectorization_factor {
// Extract each 8-bit segment
let v1 = extract_i8(value[i], 24);
let v2 = extract_i8(value[i], 16);
let v3 = extract_i8(value[i], 8);
let v4 = extract_i8(value[i], 0);
output[output_pos * vectorization_factor + i * num_packed] =
dequantize_affine_int8::<f32>(v1, scale, offset);
output[output_pos * vectorization_factor + i * num_packed + 1] =
dequantize_affine_int8::<f32>(v2, scale, offset);
output[output_pos * vectorization_factor + i * num_packed + 2] =
dequantize_affine_int8::<f32>(v3, scale, offset);
output[output_pos * vectorization_factor + i * num_packed + 3] =
dequantize_affine_int8::<f32>(v4, scale, offset);
}
} else {
// Extract each 8-bit segment
let v1 = extract_i8(value, 24);
let v2 = extract_i8(value, 16);
let v3 = extract_i8(value, 8);
let v4 = extract_i8(value, 0);
output[output_pos] = dequantize_affine_int8::<f32>(v1, scale, offset);
output[output_pos + 1] = dequantize_affine_int8::<f32>(v2, scale, offset);
output[output_pos + 2] = dequantize_affine_int8::<f32>(v3, scale, offset);
output[output_pos + 3] = dequantize_affine_int8::<f32>(v4, scale, offset);
}
}
#[cube]
pub(crate) fn dequantize_symmetric_int8<F: Float>(value: i32, scale: F) -> F {
// x = scale * x_q
scale * F::cast_from(value)
}
// Would have wrapped symmetric with the same affine kernel but cube doesn't support Option<Tensor> for offset.
#[cube(launch_unchecked)]
pub(crate) fn dequantize_per_tensor_symmetric_int8_kernel(
input: &Tensor<u32>,
scale: &Tensor<f32>,
output: &mut Tensor<f32>,
#[comptime] vectorized: bool,
) {
if ABSOLUTE_POS >= output.len() {
return;
}
let scale = scale[0];
let num_packed = 4;
let value = input[ABSOLUTE_POS];
let output_pos = ABSOLUTE_POS * num_packed;
if vectorized {
let vectorization_factor = vectorization_of(input);
#[unroll]
for i in 0..vectorization_factor {
for j in 0..num_packed {
let output_idx = output_pos * vectorization_factor + i * num_packed + j;
if output_idx >= output.len() {
return; // value not quantized (padding)
}
// Extract each 8-bit segment
let v = extract_i8(value[i], (3 - j) * 8);
output[output_idx] = dequantize_symmetric_int8::<f32>(v, scale);
}
}
} else {
// Extract each 8-bit segment
for j in 0..num_packed {
let output_idx = output_pos + j;
if output_idx >= output.len() {
return; // value not quantized (padding)
}
// Extract each 8-bit segment
let v = extract_i8(value, (3 - j) * 8);
output[output_pos + j] = dequantize_symmetric_int8::<f32>(v, scale);
}
}
}
pub(crate) fn dequantize_per_tensor<R, F, I, const D: usize>(
tensor: JitTensor<R, u32, D>,
scale: JitTensor<R, F, 1>,
offset: Option<JitTensor<R, I, 1>>,
) -> JitTensor<R, F, D>
where
R: JitRuntime,
F: JitElement,
I: IntElement,
{
// The actual number of elements is 1/4 (four int8 values packed in a single u32)
// so we choose a vectorization factor to match a valid input binding size.
let num_out_elems = tensor.shape.num_elements();
let num_elems = usize::div_ceil(num_out_elems, 4);
let vectorization_factor = [4u8, 2, 1]
.iter()
.filter_map(|&v| {
if num_elems >= v as usize {
Some(v)
} else {
None
}
})
.next()
.unwrap();
let cube_dim = CubeDim::default();
let cube_count =
calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);
let shape_output = tensor.shape.clone();
let client = tensor.client.clone();
let handle = client.empty(num_out_elems * core::mem::size_of::<F>());
let output =
JitTensor::new_contiguous(client.clone(), tensor.device.clone(), shape_output, handle);
let dummy_array = [1; D];
if let Some(offset) = offset {
unsafe {
dequantize_per_tensor_affine_int8_kernel::launch_unchecked::<R>(
&client,
cube_count,
cube_dim,
tensor.as_tensor_arg(vectorization_factor),
// Ignore shape and stride
TensorArg::from_raw_parts(&scale.handle, &dummy_array, &dummy_array, 1),
TensorArg::from_raw_parts(&offset.handle, &dummy_array, &dummy_array, 1),
output.as_tensor_arg(1),
vectorization_factor > 1,
)
};
} else {
unsafe {
dequantize_per_tensor_symmetric_int8_kernel::launch_unchecked::<R>(
&client,
cube_count,
cube_dim,
tensor.as_tensor_arg(vectorization_factor),
// Ignore shape and stride
TensorArg::from_raw_parts(&scale.handle, &dummy_array, &dummy_array, 1),
output.as_tensor_arg(1),
vectorization_factor > 1,
)
};
}
output
}
/// Convert the tensor back to a higher precision data type.
pub fn dequantize<R, F, I, const D: usize>(tensor: QJitTensor<R, F, I, D>) -> JitTensor<R, F, D>
where
R: JitRuntime,
F: FloatElement,
I: IntElement,
{
match tensor.scheme {
QuantizationScheme::PerTensorAffine(dtype)
| QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
QuantizationType::QInt8 => {
dequantize_per_tensor(tensor.qtensor, tensor.qparams.scale, tensor.qparams.offset)
}
},
}
}
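The `extract_i8` helper above avoids a conditional assignment (unsupported in the cube DSL, per its own comment) by subtracting 256 whenever the high bit of the extracted byte is set. A plain-Rust sketch of the same arithmetic — the function name is reused here purely for illustration, outside the kernel — showing that it reproduces an ordinary `as i8` sign extension:

fn extract_i8(value: u32, offset: u32) -> i32 {
    // Extract the 8-bit segment.
    let byte = (value >> offset) & 0xFF;
    // If bit 7 is set the byte encodes a negative int8, so subtract 256.
    let sub = (((byte & 0x80) != 0) as i32) * 256;
    byte as i32 - sub
}

fn main() {
    let packed = 0x85FF007Fu32; // bytes: 0x85, 0xFF, 0x00, 0x7F
    assert_eq!(extract_i8(packed, 24), -123); // 0x85 = 133, 133 - 256 = -123
    assert_eq!(extract_i8(packed, 16), -1);   // 0xFF = 255, 255 - 256 = -1
    assert_eq!(extract_i8(packed, 8), 0);
    assert_eq!(extract_i8(packed, 0), 127);
    // Matches a direct two's-complement reinterpretation.
    assert_eq!(extract_i8(packed, 24), 0x85u8 as i8 as i32);
}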

View File

@@ -0,0 +1,5 @@
mod dequantize;
mod quantize;
pub use dequantize::*;
pub use quantize::*;

View File

@@ -0,0 +1,219 @@
use crate::tensor::{JitQuantizationParameters, JitTensor, QJitTensor};
use crate::FloatElement;
use crate::{IntElement, JitElement, JitRuntime};
use burn_tensor::quantization::{QuantizationScheme, QuantizationType};
use cubecl::calculate_cube_count_elemwise;
use cubecl::prelude::*;
#[cube]
pub(crate) fn quantize_affine_int8<F: Float>(
value: F,
scale: F,
offset: i32,
range_min: F,
range_max: F,
) -> u32 {
let offset = F::cast_from(offset);
// x_q = clamp(round(x / scale + offset), a, b)
// NOTE: we add 256 before casting to unsigned to correctly represent negative values
u32::cast_from(
i32::cast_from(F::clamp(
F::round((value / scale) + offset),
range_min,
range_max,
)) + 256,
)
}
#[cube(launch_unchecked)]
pub(crate) fn quantize_per_tensor_affine_int8_kernel(
input: &Tensor<f32>,
scale: &Tensor<f32>,
offset: &Tensor<i32>,
range_min: f32,
range_max: f32,
output: &mut Tensor<u32>,
#[comptime] vectorized: bool,
) {
if ABSOLUTE_POS >= output.len() {
return;
}
let scale = scale[0];
let offset = offset[0];
let num_packed = 4;
let mut v_packed = 0;
if vectorized {
// Assuming a vectorization factor of 4 (equal to the number of values packed)
let value = input[ABSOLUTE_POS];
let vectorization_factor = vectorization_of(input);
#[unroll]
for i in 0..vectorization_factor {
let v = quantize_affine_int8::<f32>(value[i], scale, offset, range_min, range_max);
// Shift and combine into u32
v_packed |= (v & 0xFF) << (8 * (num_packed - i - 1));
}
} else {
for i in 0..num_packed {
let v = quantize_affine_int8::<f32>(
input[ABSOLUTE_POS + i],
scale,
offset,
range_min,
range_max,
);
// Shift and combine into u32
v_packed |= (v & 0xFF) << (8 * (num_packed - i - 1));
}
}
output[ABSOLUTE_POS] = v_packed;
}
#[cube]
pub(crate) fn quantize_symmetric_int8<F: Float>(
value: F,
scale: F,
range_min: F,
range_max: F,
) -> u32 {
// x_q = clamp(round(x / scale), a, b)
// NOTE: we add 256 before casting to unsigned to correctly represent negative values
u32::cast_from(i32::cast_from(F::clamp(F::round(value / scale), range_min, range_max)) + 256)
}
// Would have wrapped symmetric with the same affine kernel but cube doesn't support Option<Tensor> for offset.
#[cube(launch_unchecked)]
pub(crate) fn quantize_per_tensor_symmetric_int8_kernel(
input: &Tensor<f32>,
scale: &Tensor<f32>,
range_min: f32,
range_max: f32,
output: &mut Tensor<u32>,
#[comptime] vectorized: bool,
) {
if ABSOLUTE_POS >= output.len() {
return;
}
let scale = scale[0];
let num_packed = 4;
let mut v_packed = 0;
if vectorized {
// Assuming a vectorization factor of 4 (equal to the number of values packed)
let value = input[ABSOLUTE_POS];
let vectorization_factor = vectorization_of(input);
#[unroll]
for i in 0..vectorization_factor {
let v = quantize_symmetric_int8::<f32>(value[i], scale, range_min, range_max);
// Shift and combine into u32
v_packed |= (v & 0xFF) << (8 * (num_packed - i - 1));
}
} else {
for i in 0..num_packed {
let v = quantize_symmetric_int8::<f32>(
input[ABSOLUTE_POS + i],
scale,
range_min,
range_max,
);
// Shift and combine into u32
v_packed |= (v & 0xFF) << (8 * (num_packed - i - 1));
}
}
output[ABSOLUTE_POS] = v_packed;
}
pub(crate) fn quantize_per_tensor<R, F, I, const D: usize>(
tensor: JitTensor<R, F, D>,
scale: JitTensor<R, F, 1>,
offset: Option<JitTensor<R, I, 1>>,
) -> JitTensor<R, u32, D>
where
R: JitRuntime,
F: JitElement,
I: IntElement,
{
let num_elems = tensor.shape.num_elements();
let shape_output = tensor.shape.clone();
let client = tensor.client.clone();
// Output tensor contains 4x less elements (four int8 values packed in a single u32)
let handle = client.empty(usize::div_ceil(num_elems, 4) * core::mem::size_of::<u32>());
let output =
JitTensor::new_contiguous(client.clone(), tensor.device.clone(), shape_output, handle);
// Force vectorization to process 4 quantized values packed for 1 output value
let vectorization_factor: u8 = if num_elems < 4 { 1 } else { 4 };
let cube_dim = CubeDim::default();
let cube_count =
calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);
let dummy_array = [1; D];
if let Some(offset) = offset {
unsafe {
quantize_per_tensor_affine_int8_kernel::launch_unchecked::<R>(
&client,
cube_count,
cube_dim,
tensor.as_tensor_arg(vectorization_factor),
// Ignore shape and stride
TensorArg::from_raw_parts(&scale.handle, &dummy_array, &dummy_array, 1),
TensorArg::from_raw_parts(&offset.handle, &dummy_array, &dummy_array, 1),
ScalarArg::new(i8::MIN as f32),
ScalarArg::new(i8::MAX as f32),
output.as_tensor_arg(1),
vectorization_factor > 1,
)
};
} else {
unsafe {
quantize_per_tensor_symmetric_int8_kernel::launch_unchecked::<R>(
&client,
cube_count,
cube_dim,
tensor.as_tensor_arg(vectorization_factor),
// Ignore shape and stride
TensorArg::from_raw_parts(&scale.handle, &dummy_array, &dummy_array, 1),
ScalarArg::new(-i8::MAX as f32),
ScalarArg::new(i8::MAX as f32),
output.as_tensor_arg(1),
vectorization_factor > 1,
)
};
}
output
}
/// Convert the tensor to a lower precision data type based on the quantization scheme and parameters.
pub fn quantize<R, F, I, const D: usize>(
tensor: JitTensor<R, F, D>,
scheme: &QuantizationScheme,
qparams: JitQuantizationParameters<R, F, I>,
) -> QJitTensor<R, F, I, D>
where
R: JitRuntime,
F: FloatElement,
I: IntElement,
{
let qtensor = match scheme {
QuantizationScheme::PerTensorAffine(dtype)
| QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
QuantizationType::QInt8 => {
quantize_per_tensor(tensor, qparams.scale.clone(), qparams.offset.clone())
}
},
};
QJitTensor {
qtensor,
scheme: scheme.clone(),
qparams,
}
}
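Both quantize kernels above add 256 to the clamped int8 value before casting to u32 (the "cast to unsigned" note from the commit message): the shifted value is non-negative, and its low byte — which is all the packing keeps after `& 0xFF` — matches the two's-complement byte of the signed value. A scalar sketch of that identity, with a made-up helper name and not kernel code:

// Sketch only: the low byte after `+ 256` equals the two's-complement byte
// of the signed value, for the whole int8 range.
fn quantized_byte(q: i32) -> u32 {
    ((q + 256) as u32) & 0xFF
}

fn main() {
    for q in i8::MIN..=i8::MAX {
        assert_eq!(quantized_byte(q as i32), (q as u8) as u32);
    }
    // Example: -5 + 256 = 251 = 0xFB, the two's-complement byte of -5.
    assert_eq!(quantized_byte(-5), 0xFB);
}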

View File

@@ -1,12 +1,47 @@
 use std::ops::Range;

+use alloc::vec::Vec;
 use burn_tensor::{
     ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
-    quantization::{QuantizationParametersPrimitive, QuantizationScheme},
-    Device, Shape, TensorData,
+    quantization::{
+        QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme,
+        QuantizationStrategy, QuantizationType,
+    },
+    DType, Device, ElementConversion, Shape, TensorData,
 };

-use crate::{tensor::QJitTensor, FloatElement, IntElement, JitBackend, JitRuntime};
+use crate::{
+    kernel,
+    tensor::{JitQuantizationParameters, JitTensor, QJitTensor},
+    FloatElement, IntElement, JitBackend, JitRuntime,
+};
+use cubecl::CubeElement;
+
+fn pack_i8s_to_u32s(data: &TensorData) -> Vec<u32> {
+    // Shift and combine groups of four 8-bit values into a u32.
+    // Same as doing this:
+    // let result = (a_u8 & 0xFF) << 24 | (b_u8 & 0xFF) << 16 | (c_u8 & 0xFF) << 8 | (d_u8 & 0xFF);
+    data.as_bytes()
+        .chunks(4)
+        .map(|x| {
+            x.iter().enumerate().fold(0u32, |acc, (i, x)| {
+                acc | (*x as i8 as u32 & 0xFF) << ((3 - i) * 8)
+            })
+        })
+        .collect()
+}
+
+/// Create a quantized tensor with packed values (u32).
+fn packed_tensor<R: JitRuntime, S: Into<Shape<D>>, const D: usize>(
+    data: Vec<u32>,
+    shape: S,
+    device: &R::Device,
+) -> JitTensor<R, u32, D> {
+    let client = R::client(device);
+    let buffer = client.create(u32::as_bytes(&data));
+
+    JitTensor::new_contiguous(client, device.clone(), shape.into(), buffer)
+}

 impl<R, F, I> QTensorOps<Self> for JitBackend<R, F, I>
 where
@@ -15,22 +50,49 @@ where
     I: IntElement,
 {
     fn q_from_data<const D: usize>(
-        _data: TensorData,
-        _device: &Device<Self>,
+        data: TensorData,
+        device: &Device<Self>,
     ) -> QuantizedTensor<Self, D> {
-        todo!()
+        match data.dtype {
+            DType::QFloat(strategy) => match strategy {
+                QuantizationStrategy::PerTensorAffineInt8(q) => {
+                    // Convert quantized values to packed u32s
+                    QJitTensor {
+                        qtensor: packed_tensor(pack_i8s_to_u32s(&data), data.shape, device),
+                        scheme: strategy.scheme(),
+                        qparams: JitQuantizationParameters::new(
+                            q.scale.elem(),
+                            Some(q.offset.elem()),
+                            device,
+                        ),
+                    }
+                }
+                QuantizationStrategy::PerTensorSymmetricInt8(q) => {
+                    // Convert quantized values to packed u32s
+                    QJitTensor {
+                        qtensor: packed_tensor(pack_i8s_to_u32s(&data), data.shape, device),
+                        scheme: strategy.scheme(),
+                        qparams: JitQuantizationParameters::new(q.scale.elem(), None, device),
+                    }
+                }
+            },
+            _ => panic!(
+                "Invalid dtype (expected DType::QFloat, got {:?})",
+                data.dtype
+            ),
+        }
     }

     fn quantize<const D: usize>(
-        _tensor: FloatTensor<Self, D>,
-        _scheme: &QuantizationScheme,
-        _qparams: QuantizationParametersPrimitive<Self>,
+        tensor: FloatTensor<Self, D>,
+        scheme: &QuantizationScheme,
+        qparams: QuantizationParametersPrimitive<Self>,
     ) -> QuantizedTensor<Self, D> {
-        unimplemented!()
+        kernel::quantization::quantize(tensor, scheme, qparams.into())
     }

-    fn dequantize<const D: usize>(_tensor: QuantizedTensor<Self, D>) -> FloatTensor<Self, D> {
-        unimplemented!()
+    fn dequantize<const D: usize>(tensor: QuantizedTensor<Self, D>) -> FloatTensor<Self, D> {
+        kernel::quantization::dequantize(tensor)
     }

     fn q_shape<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Shape<D> {
@@ -42,10 +104,12 @@ where
     }

     fn q_to_device<const D: usize>(
-        _tensor: QuantizedTensor<Self, D>,
-        _device: &Device<Self>,
+        tensor: QuantizedTensor<Self, D>,
+        device: &Device<Self>,
     ) -> QuantizedTensor<Self, D> {
-        unimplemented!()
+        let mut tensor = tensor;
+        tensor.qtensor = super::to_device(tensor.qtensor, device);
+        tensor
     }

     fn q_reshape<const D1: usize, const D2: usize>(
@@ -55,11 +119,43 @@ where
         QJitTensor {
             qtensor: super::reshape(tensor.qtensor, shape),
             scheme: tensor.scheme,
+            qparams: tensor.qparams,
         }
     }

-    async fn q_into_data<const D: usize>(_tensor: QuantizedTensor<Self, D>) -> TensorData {
-        unimplemented!()
+    async fn q_into_data<const D: usize>(tensor: QuantizedTensor<Self, D>) -> TensorData {
+        let strategy = tensor.strategy();
+        let numel = tensor.qtensor.shape.num_elements();
+        let qtensor = kernel::into_contiguous(tensor.qtensor);
+        let bytes = qtensor.client.read_async(qtensor.handle.binding()).await;
+
+        // Convert packed bytes to quantized dtype (TensorData can be used with other backends,
+        // which don't have the prior knowledge of this packed representation)
+        match &tensor.scheme {
+            QuantizationScheme::PerTensorAffine(dtype)
+            | QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
+                QuantizationType::QInt8 => TensorData::quantized(
+                    u32::from_bytes(&bytes)
+                        .iter()
+                        .enumerate()
+                        .flat_map(|(i, packed)| {
+                            // A single u32 could contain less than four 8-bit values...
+                            let n = core::cmp::min(4, numel - i * 4);
+                            // Extract each 8-bit segment from u32 and cast back to i8
+                            // Same as doing this (when 4 values are fully packed):
+                            // let a = ((packed >> 24) & 0xFF) as i8;
+                            // let b = ((packed >> 16) & 0xFF) as i8;
+                            // let c = ((packed >> 8) & 0xFF) as i8;
+                            // let d = (packed & 0xFF) as i8;
+                            (0..n).map(move |i| (packed >> ((3 - i) * 8) & 0xFF) as i8)
+                        })
+                        .collect(),
+                    qtensor.shape,
+                    strategy,
+                ),
+            },
+        }
     }

     fn q_swap_dims<const D: usize>(
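With these ops in place, the jit backend is reached through the usual burn tensor API; a rough usage sketch that mirrors the calls appearing in the tests further down in this diff. The function name `roundtrip` is made up, `B` stands for any backend that implements the new ops, and the scale value is just an example:

use burn_tensor::{
    backend::Backend,
    quantization::{QuantizationParameters, QuantizationScheme, QuantizationType},
    Tensor,
};

fn roundtrip<B: Backend>(device: &B::Device) {
    let tensor = Tensor::<B, 1>::from_floats([-1.8, -1.0, 0.0, 0.5], device);
    let scheme = QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8);
    let qparams = QuantizationParameters {
        scale: Tensor::from_floats([0.014_173_228], device),
        offset: None,
    };

    // On burn-jit this dispatches to the quantize kernel above and stores the
    // values packed four-per-u32.
    let x_q = tensor.quantize(&scheme, qparams);

    // Reads back through dequantize / q_into_data underneath.
    let _dequantized = x_q.dequantize();
}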

View File

@@ -1,34 +1,111 @@
-use burn_tensor::quantization::{QTensorPrimitive, QuantizationScheme, QuantizationStrategy};
+use burn_tensor::{
+    quantization::{
+        AffineQuantization, QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme,
+        QuantizationStrategy, QuantizationType, SymmetricQuantization,
+    },
+    read_sync, TensorData,
+};

-use crate::JitRuntime;
+use crate::{ops::into_data, FloatElement, IntElement, JitBackend, JitRuntime};

 use super::JitTensor;

 /// A quantized tensor primitive.
 #[derive(Debug)]
-pub struct QJitTensor<R: JitRuntime, const D: usize> {
+pub struct QJitTensor<R: JitRuntime, F: FloatElement, I: IntElement, const D: usize> {
     /// The quantized tensor.
-    // TODO: implement `JitElement` / `CubeElement` for quantized type
+    /// Values are stored as multiple packed quantized values in u32.
     pub qtensor: JitTensor<R, u32, D>,
     /// The quantization scheme.
     pub scheme: QuantizationScheme,
+    /// The quantization parameters.
+    pub qparams: JitQuantizationParameters<R, F, I>,
 }

-impl<R: JitRuntime, const D: usize> QTensorPrimitive for QJitTensor<R, D> {
+impl<R: JitRuntime, F: FloatElement, I: IntElement, const D: usize> QTensorPrimitive
+    for QJitTensor<R, F, I, D>
+{
     fn scheme(&self) -> &QuantizationScheme {
         &self.scheme
     }

     fn strategy(&self) -> QuantizationStrategy {
-        todo!()
+        match &self.scheme {
+            QuantizationScheme::PerTensorAffine(dtype) => match dtype {
+                QuantizationType::QInt8 => {
+                    let scale = read_sync(into_data(self.qparams.scale.clone()))
+                        .iter()
+                        .next()
+                        .unwrap();
+                    let offset = read_sync(into_data(self.qparams.offset.clone().unwrap()))
+                        .iter()
+                        .next()
+                        .unwrap();
+                    QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(
+                        scale, offset,
+                    ))
+                }
+            },
+            QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
+                QuantizationType::QInt8 => {
+                    let scale = read_sync(into_data(self.qparams.scale.clone()))
+                        .iter()
+                        .next()
+                        .unwrap();
+                    QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(scale))
+                }
+            },
+        }
     }
 }

-impl<R: JitRuntime, const D: usize> Clone for QJitTensor<R, D> {
+impl<R: JitRuntime, F: FloatElement, I: IntElement, const D: usize> Clone
+    for QJitTensor<R, F, I, D>
+{
     fn clone(&self) -> Self {
         Self {
             qtensor: self.qtensor.clone(),
             scheme: self.scheme.clone(),
+            qparams: self.qparams.clone(),
         }
     }
 }
+
+/// The quantization parameters.
+#[derive(Debug)]
+pub struct JitQuantizationParameters<R: JitRuntime, F: FloatElement, I: IntElement> {
+    /// The scaling factor.
+    pub scale: JitTensor<R, F, 1>,
+    /// The zero-point offset.
+    pub offset: Option<JitTensor<R, I, 1>>,
+}
+
+impl<R: JitRuntime, F: FloatElement, I: IntElement> Clone for JitQuantizationParameters<R, F, I> {
+    fn clone(&self) -> Self {
+        Self {
+            scale: self.scale.clone(),
+            offset: self.offset.clone(),
+        }
+    }
+}
+
+impl<R: JitRuntime, F: FloatElement, I: IntElement>
+    From<QuantizationParametersPrimitive<JitBackend<R, F, I>>>
+    for JitQuantizationParameters<R, F, I>
+{
+    fn from(value: QuantizationParametersPrimitive<JitBackend<R, F, I>>) -> Self {
+        JitQuantizationParameters {
+            scale: value.scale,
+            offset: value.offset,
+        }
+    }
+}
+
+impl<R: JitRuntime, F: FloatElement, I: IntElement> JitQuantizationParameters<R, F, I> {
+    pub fn new(scale: F, offset: Option<I>, device: &R::Device) -> Self {
+        Self {
+            scale: crate::ops::from_data(TensorData::new(vec![scale], [1]), device),
+            offset: offset.map(|o| crate::ops::from_data(TensorData::new(vec![o], [1]), device)),
+        }
+    }
+}

View File

@@ -16,6 +16,7 @@ mod matmul;
 mod max_pool2d;
 mod max_pool2d_backward;
 mod normal;
+mod quantization;
 mod reduce;
 mod repeat_dim;
 mod scatter;
@@ -73,6 +74,8 @@ macro_rules! testgen_all {
         burn_jit::testgen_cat!();
         burn_jit::testgen_clamp!();
         burn_jit::testgen_unary!();
+
+        burn_jit::testgen_quantization!();
     }
 }
 mod jit_fusion {
@@ -100,6 +103,14 @@ macro_rules! testgen_jit {
         burn_tensor::testgen_all!();
         burn_autodiff::testgen_all!();

+        // Not all ops are implemented for quantization yet, notably missing:
+        // `q_swap_dims`, `q_permute`, `q_flip`, `q_gather`, `q_select`, `q_slice`, `q_expand`
+        // burn_tensor::testgen_quantization!();
+
+        // test quantization
+        burn_tensor::testgen_calibration!();
+        burn_tensor::testgen_scheme!();
+        burn_tensor::testgen_quantize!();
     };
 }

View File

@@ -0,0 +1,82 @@
#[burn_tensor_testgen::testgen(quantization)]
mod tests {
use super::*;
use burn_tensor::{
quantization::{QuantizationScheme, QuantizationType},
Tensor,
};
#[test]
fn should_quantize_dequantize_symmetric_single() {
let scheme = QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8);
let input = Tensor::<TestBackend, 1>::from_floats([-1.8], &Default::default());
let input_ref =
Tensor::<ReferenceBackend, 1>::from_data(input.to_data(), &Default::default());
let output = input.quantize_dynamic(&scheme);
let output_ref = input_ref.quantize_dynamic(&scheme);
output.to_data().assert_eq(&output_ref.to_data(), false);
let output = output.dequantize();
let output_ref = output_ref.dequantize();
output.to_data().assert_approx_eq(&output_ref.to_data(), 3);
}
#[test]
fn should_quantize_dequantize_affine_single() {
let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8);
let input = Tensor::<TestBackend, 1>::from_floats([-1.8], &Default::default());
let input_ref =
Tensor::<ReferenceBackend, 1>::from_data(input.to_data(), &Default::default());
let output = input.quantize_dynamic(&scheme);
let output_ref = input_ref.quantize_dynamic(&scheme);
output.to_data().assert_eq(&output_ref.to_data(), false);
let output = output.dequantize();
let output_ref = output_ref.dequantize();
output.to_data().assert_approx_eq(&output_ref.to_data(), 2);
}
#[test]
fn should_quantize_dequantize_symmetric_multiple() {
let scheme = QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8);
let input =
Tensor::<TestBackend, 1>::from_floats([-1.8, -1.0, 0.0, 0.5, 0.0], &Default::default());
let input_ref =
Tensor::<ReferenceBackend, 1>::from_data(input.to_data(), &Default::default());
let output = input.quantize_dynamic(&scheme);
let output_ref = input_ref.quantize_dynamic(&scheme);
output.to_data().assert_eq(&output_ref.to_data(), false);
let output = output.dequantize();
let output_ref = output_ref.dequantize();
output.to_data().assert_approx_eq(&output_ref.to_data(), 3);
}
#[test]
fn should_quantize_dequantize_affine_multiple() {
let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8);
let input =
Tensor::<TestBackend, 1>::from_floats([-1.8, -1.0, 0.0, 0.5, 0.0], &Default::default());
let input_ref =
Tensor::<ReferenceBackend, 1>::from_data(input.to_data(), &Default::default());
let output = input.quantize_dynamic(&scheme);
let output_ref = input_ref.quantize_dynamic(&scheme);
output.to_data().assert_eq(&output_ref.to_data(), false);
let output = output.dequantize();
let output_ref = output_ref.dequantize();
output.to_data().assert_approx_eq(&output_ref.to_data(), 3);
}
}

View File

@@ -361,22 +361,21 @@ impl TensorData
             DType::Bool => self.assert_eq_elem::<bool>(other),
             DType::QFloat(q) => {
                 // Strict or not, it doesn't make sense to compare quantized data to not quantized data for equality
-                if let DType::QFloat(q_other) = other.dtype {
-                    assert_eq!(
-                        q, q_other,
-                        "Quantization strategies differ ({:?} != {:?})",
-                        q, q_other
-                    )
+                let q_other = if let DType::QFloat(q_other) = other.dtype {
+                    q_other
                 } else {
                     panic!("Quantized data differs from other not quantized data")
-                }
-                match q {
-                    QuantizationStrategy::PerTensorAffineInt8(_) => {
-                        self.assert_eq_elem::<i8>(other)
-                    }
-                    QuantizationStrategy::PerTensorSymmetricInt8(_) => {
-                        self.assert_eq_elem::<i8>(other)
-                    }
+                };
+                match (q, q_other) {
+                    (
+                        QuantizationStrategy::PerTensorAffineInt8(_),
+                        QuantizationStrategy::PerTensorAffineInt8(_),
+                    ) => self.assert_eq_elem::<i8>(other),
+                    (
+                        QuantizationStrategy::PerTensorSymmetricInt8(_),
+                        QuantizationStrategy::PerTensorSymmetricInt8(_),
+                    ) => self.assert_eq_elem::<i8>(other),
+                    _ => panic!("Quantization strategies differ ({:?} != {:?})", q, q_other),
                 }
             }
         }
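Together with the `strict = false` usage mentioned in the commit message, the effect of this change is that two quantized TensorData values compare equal as long as their strategy variants and int8 values match, even if the scale differs by floating-point noise. A hedged sketch of that usage, with made-up scale values:

use burn_tensor::{
    quantization::{QuantizationStrategy, SymmetricQuantization},
    TensorData,
};

fn main() {
    // Same quantized values; scales differ only in the last decimal place.
    let a = TensorData::quantized(
        vec![-127i8, -71, 0, 35],
        [4],
        QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.014_173_228)),
    );
    let b = TensorData::quantized(
        vec![-127i8, -71, 0, 35],
        [4],
        QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.014_173_229)),
    );

    // strict = false: only the strategy variants and the int8 values must match.
    a.assert_eq(&b, false);
}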

View File

@@ -194,7 +194,7 @@ mod tests {
         output
             .dequantize()
             .into_data()
-            .assert_eq(&TensorData::from([0.0]), false);
+            .assert_approx_eq(&TensorData::from([0.0]), 5);
     }

     #[test]

View File

@@ -55,15 +55,6 @@ mod tests {
     #[test]
     fn should_support_dequantize() {
         let device = Default::default();
-        let tensor = Tensor::<TestBackend, 1>::from_floats([-1.8, -1.0, 0.0, 0.5], &device);
-        let scheme = QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8);
-        let qparams = QuantizationParameters {
-            scale: Tensor::from_floats([0.014_173_228], &device),
-            offset: None,
-        };
-        let x_q = tensor.quantize(&scheme, qparams);
-
         // Quantized [-1.8, -1.0, 0.0, 0.5]
         let data = TensorData::quantized(
             vec![-127i8, -71, 0, 35],
@@ -97,6 +88,6 @@ mod tests {
             QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.05882353, 42)),
         );
-        x_q.to_data().assert_eq(&expected, true);
+        x_q.to_data().assert_eq(&expected, false);
     }
 }