mirror of https://github.com/tracel-ai/burn.git
Add more quantization support for burn-jit (#2275)
* Add cubecl quantization kernels and QTensorOps for burn-jit * Fix typo * Fix output vec factor * Fix output dtype size_of * Remove unused code in dequantize test * Fix dequantize vectorization * Handle tensors when number of elems is not a multiple of 4 * Support quantize for tensors with less than 4 elems (no vectorization) * Fix equal 0 test * Add quantize/dequantize tests * Add q_to_device * Refactor kernels for latest cubecl * intermediate i32 cast * Fix size_of output type * Use strict=false to ignore floating point precision issues with qparams equality * Only check that lhs & rhs strategies match (but not strict on qparams values) * Use assert_approx_eq on dequant values * Reduce precision for flaky test * Remove todo comment * Add comment for cast to unsigned * More comment --------- Co-authored-by: louisfd <louisfd94@gmail.com>
This commit is contained in:
parent
834005eadb
commit
aa79e36a8d
|
@ -34,7 +34,8 @@ where
|
||||||
type FloatTensorPrimitive<const D: usize> = JitTensor<R, Self::FloatElem, D>;
|
type FloatTensorPrimitive<const D: usize> = JitTensor<R, Self::FloatElem, D>;
|
||||||
type IntTensorPrimitive<const D: usize> = JitTensor<R, Self::IntElem, D>;
|
type IntTensorPrimitive<const D: usize> = JitTensor<R, Self::IntElem, D>;
|
||||||
type BoolTensorPrimitive<const D: usize> = JitTensor<R, u32, D>;
|
type BoolTensorPrimitive<const D: usize> = JitTensor<R, u32, D>;
|
||||||
type QuantizedTensorPrimitive<const D: usize> = QJitTensor<R, D>;
|
type QuantizedTensorPrimitive<const D: usize> =
|
||||||
|
QJitTensor<R, Self::FloatElem, Self::IntElem, D>;
|
||||||
|
|
||||||
fn name() -> String {
|
fn name() -> String {
|
||||||
format!("jit<{}>", R::name())
|
format!("jit<{}>", R::name())
|
||||||
|
|
|
@ -25,6 +25,8 @@ pub mod matmul;
|
||||||
pub mod pool;
|
pub mod pool;
|
||||||
/// Pseudo-random number generator kernels
|
/// Pseudo-random number generator kernels
|
||||||
pub mod prng;
|
pub mod prng;
|
||||||
|
/// Quantization operations
|
||||||
|
pub mod quantization;
|
||||||
/// Reduction algorithms
|
/// Reduction algorithms
|
||||||
pub mod reduce;
|
pub mod reduce;
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,211 @@
|
||||||
|
use crate::tensor::{JitTensor, QJitTensor};
|
||||||
|
use crate::FloatElement;
|
||||||
|
use crate::{IntElement, JitElement, JitRuntime};
|
||||||
|
use burn_tensor::quantization::{QuantizationScheme, QuantizationType};
|
||||||
|
use cubecl::calculate_cube_count_elemwise;
|
||||||
|
use cubecl::prelude::*;
|
||||||
|
|
||||||
|
#[cube]
|
||||||
|
pub(crate) fn dequantize_affine_int8<F: Float>(value: i32, scale: F, offset: i32) -> F {
|
||||||
|
// x = scale * (x_q - offset)
|
||||||
|
scale * (F::cast_from(value) - F::cast_from(offset))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cube]
|
||||||
|
pub(crate) fn extract_i8(value: u32, offset: u32) -> i32 {
|
||||||
|
// Extract 8-bit segment
|
||||||
|
let value = (value >> offset) & 0xFF;
|
||||||
|
// Check if the value is negative by inspecting the MSB and subtract 256 if it is
|
||||||
|
// Subtract 0 or 256 to circumvent unsupported conditional assignment (let x = if {} else {};)
|
||||||
|
let sub = i32::cast_from(value & 0x80 != 0) * 256;
|
||||||
|
i32::cast_from(value) - sub
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cube(launch_unchecked)]
|
||||||
|
pub(crate) fn dequantize_per_tensor_affine_int8_kernel(
|
||||||
|
input: &Tensor<u32>,
|
||||||
|
scale: &Tensor<f32>,
|
||||||
|
offset: &Tensor<i32>,
|
||||||
|
output: &mut Tensor<f32>,
|
||||||
|
#[comptime] vectorized: bool,
|
||||||
|
) {
|
||||||
|
if ABSOLUTE_POS >= output.len() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let scale = scale[0];
|
||||||
|
let offset = offset[0];
|
||||||
|
|
||||||
|
let num_packed = 4;
|
||||||
|
let value = input[ABSOLUTE_POS];
|
||||||
|
let output_pos = ABSOLUTE_POS * num_packed;
|
||||||
|
|
||||||
|
if vectorized {
|
||||||
|
let vectorization_factor = vectorization_of(input);
|
||||||
|
#[unroll]
|
||||||
|
for i in 0..vectorization_factor {
|
||||||
|
// Extract each 8-bit segment
|
||||||
|
let v1 = extract_i8(value[i], 24);
|
||||||
|
let v2 = extract_i8(value[i], 16);
|
||||||
|
let v3 = extract_i8(value[i], 8);
|
||||||
|
let v4 = extract_i8(value[i], 0);
|
||||||
|
|
||||||
|
output[output_pos * vectorization_factor + i * num_packed] =
|
||||||
|
dequantize_affine_int8::<f32>(v1, scale, offset);
|
||||||
|
output[output_pos * vectorization_factor + i * num_packed + 1] =
|
||||||
|
dequantize_affine_int8::<f32>(v2, scale, offset);
|
||||||
|
output[output_pos * vectorization_factor + i * num_packed + 2] =
|
||||||
|
dequantize_affine_int8::<f32>(v3, scale, offset);
|
||||||
|
output[output_pos * vectorization_factor + i * num_packed + 3] =
|
||||||
|
dequantize_affine_int8::<f32>(v4, scale, offset);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Extract each 8-bit segment
|
||||||
|
let v1 = extract_i8(value, 24);
|
||||||
|
let v2 = extract_i8(value, 16);
|
||||||
|
let v3 = extract_i8(value, 8);
|
||||||
|
let v4 = extract_i8(value, 0);
|
||||||
|
|
||||||
|
output[output_pos] = dequantize_affine_int8::<f32>(v1, scale, offset);
|
||||||
|
output[output_pos + 1] = dequantize_affine_int8::<f32>(v2, scale, offset);
|
||||||
|
output[output_pos + 2] = dequantize_affine_int8::<f32>(v3, scale, offset);
|
||||||
|
output[output_pos + 3] = dequantize_affine_int8::<f32>(v4, scale, offset);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cube]
|
||||||
|
pub(crate) fn dequantize_symmetric_int8<F: Float>(value: i32, scale: F) -> F {
|
||||||
|
// x = scale * x_q
|
||||||
|
scale * F::cast_from(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Would have wrapped symmetric with the same affine kernel but cube doesn't support Option<Tensor> for offset.
|
||||||
|
#[cube(launch_unchecked)]
|
||||||
|
pub(crate) fn dequantize_per_tensor_symmetric_int8_kernel(
|
||||||
|
input: &Tensor<u32>,
|
||||||
|
scale: &Tensor<f32>,
|
||||||
|
output: &mut Tensor<f32>,
|
||||||
|
#[comptime] vectorized: bool,
|
||||||
|
) {
|
||||||
|
if ABSOLUTE_POS >= output.len() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let scale = scale[0];
|
||||||
|
|
||||||
|
let num_packed = 4;
|
||||||
|
let value = input[ABSOLUTE_POS];
|
||||||
|
let output_pos = ABSOLUTE_POS * num_packed;
|
||||||
|
|
||||||
|
if vectorized {
|
||||||
|
let vectorization_factor = vectorization_of(input);
|
||||||
|
#[unroll]
|
||||||
|
for i in 0..vectorization_factor {
|
||||||
|
for j in 0..num_packed {
|
||||||
|
let output_idx = output_pos * vectorization_factor + i * num_packed + j;
|
||||||
|
if output_idx >= output.len() {
|
||||||
|
return; // value not quantized (padding)
|
||||||
|
}
|
||||||
|
// Extract each 8-bit segment
|
||||||
|
let v = extract_i8(value[i], (3 - j) * 8);
|
||||||
|
output[output_idx] = dequantize_symmetric_int8::<f32>(v, scale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Extract each 8-bit segment
|
||||||
|
for j in 0..num_packed {
|
||||||
|
let output_idx = output_pos + j;
|
||||||
|
if output_idx >= output.len() {
|
||||||
|
return; // value not quantized (padding)
|
||||||
|
}
|
||||||
|
// Extract each 8-bit segment
|
||||||
|
let v = extract_i8(value, (3 - j) * 8);
|
||||||
|
output[output_pos + j] = dequantize_symmetric_int8::<f32>(v, scale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn dequantize_per_tensor<R, F, I, const D: usize>(
|
||||||
|
tensor: JitTensor<R, u32, D>,
|
||||||
|
scale: JitTensor<R, F, 1>,
|
||||||
|
offset: Option<JitTensor<R, I, 1>>,
|
||||||
|
) -> JitTensor<R, F, D>
|
||||||
|
where
|
||||||
|
R: JitRuntime,
|
||||||
|
F: JitElement,
|
||||||
|
I: IntElement,
|
||||||
|
{
|
||||||
|
// The actual number of elements is 1/4 (four int8 values packed in a single u32)
|
||||||
|
// so we choose a vectorization factor to match a valid input binding size.
|
||||||
|
let num_out_elems = tensor.shape.num_elements();
|
||||||
|
let num_elems = usize::div_ceil(num_out_elems, 4);
|
||||||
|
let vectorization_factor = [4u8, 2, 1]
|
||||||
|
.iter()
|
||||||
|
.filter_map(|&v| {
|
||||||
|
if num_elems >= v as usize {
|
||||||
|
Some(v)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.next()
|
||||||
|
.unwrap();
|
||||||
|
let cube_dim = CubeDim::default();
|
||||||
|
let cube_count =
|
||||||
|
calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);
|
||||||
|
|
||||||
|
let shape_output = tensor.shape.clone();
|
||||||
|
let client = tensor.client.clone();
|
||||||
|
let handle = client.empty(num_out_elems * core::mem::size_of::<F>());
|
||||||
|
let output =
|
||||||
|
JitTensor::new_contiguous(client.clone(), tensor.device.clone(), shape_output, handle);
|
||||||
|
|
||||||
|
let dummy_array = [1; D];
|
||||||
|
if let Some(offset) = offset {
|
||||||
|
unsafe {
|
||||||
|
dequantize_per_tensor_affine_int8_kernel::launch_unchecked::<R>(
|
||||||
|
&client,
|
||||||
|
cube_count,
|
||||||
|
cube_dim,
|
||||||
|
tensor.as_tensor_arg(vectorization_factor),
|
||||||
|
// Ignore shape and stride
|
||||||
|
TensorArg::from_raw_parts(&scale.handle, &dummy_array, &dummy_array, 1),
|
||||||
|
TensorArg::from_raw_parts(&offset.handle, &dummy_array, &dummy_array, 1),
|
||||||
|
output.as_tensor_arg(1),
|
||||||
|
vectorization_factor > 1,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
unsafe {
|
||||||
|
dequantize_per_tensor_symmetric_int8_kernel::launch_unchecked::<R>(
|
||||||
|
&client,
|
||||||
|
cube_count,
|
||||||
|
cube_dim,
|
||||||
|
tensor.as_tensor_arg(vectorization_factor),
|
||||||
|
// Ignore shape and stride
|
||||||
|
TensorArg::from_raw_parts(&scale.handle, &dummy_array, &dummy_array, 1),
|
||||||
|
output.as_tensor_arg(1),
|
||||||
|
vectorization_factor > 1,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
output
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert the tensor back to a higher precision data type.
|
||||||
|
pub fn dequantize<R, F, I, const D: usize>(tensor: QJitTensor<R, F, I, D>) -> JitTensor<R, F, D>
|
||||||
|
where
|
||||||
|
R: JitRuntime,
|
||||||
|
F: FloatElement,
|
||||||
|
I: IntElement,
|
||||||
|
{
|
||||||
|
match tensor.scheme {
|
||||||
|
QuantizationScheme::PerTensorAffine(dtype)
|
||||||
|
| QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
|
||||||
|
QuantizationType::QInt8 => {
|
||||||
|
dequantize_per_tensor(tensor.qtensor, tensor.qparams.scale, tensor.qparams.offset)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,5 @@
|
||||||
|
mod dequantize;
|
||||||
|
mod quantize;
|
||||||
|
|
||||||
|
pub use dequantize::*;
|
||||||
|
pub use quantize::*;
|
|
@ -0,0 +1,219 @@
|
||||||
|
use crate::tensor::{JitQuantizationParameters, JitTensor, QJitTensor};
|
||||||
|
use crate::FloatElement;
|
||||||
|
use crate::{IntElement, JitElement, JitRuntime};
|
||||||
|
use burn_tensor::quantization::{QuantizationScheme, QuantizationType};
|
||||||
|
use cubecl::calculate_cube_count_elemwise;
|
||||||
|
use cubecl::prelude::*;
|
||||||
|
|
||||||
|
#[cube]
|
||||||
|
pub(crate) fn quantize_affine_int8<F: Float>(
|
||||||
|
value: F,
|
||||||
|
scale: F,
|
||||||
|
offset: i32,
|
||||||
|
range_min: F,
|
||||||
|
range_max: F,
|
||||||
|
) -> u32 {
|
||||||
|
let offset = F::cast_from(offset);
|
||||||
|
|
||||||
|
// x_q = clamp(round(x / scale + offset), a, b)
|
||||||
|
// NOTE: we add 256 before casting to unsigned to correctly represent negative values
|
||||||
|
u32::cast_from(
|
||||||
|
i32::cast_from(F::clamp(
|
||||||
|
F::round((value / scale) + offset),
|
||||||
|
range_min,
|
||||||
|
range_max,
|
||||||
|
)) + 256,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cube(launch_unchecked)]
|
||||||
|
pub(crate) fn quantize_per_tensor_affine_int8_kernel(
|
||||||
|
input: &Tensor<f32>,
|
||||||
|
scale: &Tensor<f32>,
|
||||||
|
offset: &Tensor<i32>,
|
||||||
|
range_min: f32,
|
||||||
|
range_max: f32,
|
||||||
|
output: &mut Tensor<u32>,
|
||||||
|
#[comptime] vectorized: bool,
|
||||||
|
) {
|
||||||
|
if ABSOLUTE_POS >= output.len() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let scale = scale[0];
|
||||||
|
let offset = offset[0];
|
||||||
|
|
||||||
|
let num_packed = 4;
|
||||||
|
let mut v_packed = 0;
|
||||||
|
|
||||||
|
if vectorized {
|
||||||
|
// Assuming a vectorization factor of 4 (equal to the number of values packed)
|
||||||
|
let value = input[ABSOLUTE_POS];
|
||||||
|
let vectorization_factor = vectorization_of(input);
|
||||||
|
#[unroll]
|
||||||
|
for i in 0..vectorization_factor {
|
||||||
|
let v = quantize_affine_int8::<f32>(value[i], scale, offset, range_min, range_max);
|
||||||
|
// Shift and combine into u32
|
||||||
|
v_packed |= (v & 0xFF) << (8 * (num_packed - i - 1));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i in 0..num_packed {
|
||||||
|
let v = quantize_affine_int8::<f32>(
|
||||||
|
input[ABSOLUTE_POS + i],
|
||||||
|
scale,
|
||||||
|
offset,
|
||||||
|
range_min,
|
||||||
|
range_max,
|
||||||
|
);
|
||||||
|
// Shift and combine into u32
|
||||||
|
v_packed |= (v & 0xFF) << (8 * (num_packed - i - 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
output[ABSOLUTE_POS] = v_packed;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cube]
|
||||||
|
pub(crate) fn quantize_symmetric_int8<F: Float>(
|
||||||
|
value: F,
|
||||||
|
scale: F,
|
||||||
|
range_min: F,
|
||||||
|
range_max: F,
|
||||||
|
) -> u32 {
|
||||||
|
// x_q = clamp(round(x / scale), a, b)
|
||||||
|
// NOTE: we add 256 before casting to unsigned to correctly represent negative values
|
||||||
|
u32::cast_from(i32::cast_from(F::clamp(F::round(value / scale), range_min, range_max)) + 256)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Would have wrapped symmetric with the same affine kernel but cube doesn't support Option<Tensor> for offset.
|
||||||
|
#[cube(launch_unchecked)]
|
||||||
|
pub(crate) fn quantize_per_tensor_symmetric_int8_kernel(
|
||||||
|
input: &Tensor<f32>,
|
||||||
|
scale: &Tensor<f32>,
|
||||||
|
range_min: f32,
|
||||||
|
range_max: f32,
|
||||||
|
output: &mut Tensor<u32>,
|
||||||
|
#[comptime] vectorized: bool,
|
||||||
|
) {
|
||||||
|
if ABSOLUTE_POS >= output.len() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let scale = scale[0];
|
||||||
|
|
||||||
|
let num_packed = 4;
|
||||||
|
let mut v_packed = 0;
|
||||||
|
|
||||||
|
if vectorized {
|
||||||
|
// Assuming a vectorization factor of 4 (equal to the number of values packed)
|
||||||
|
let value = input[ABSOLUTE_POS];
|
||||||
|
let vectorization_factor = vectorization_of(input);
|
||||||
|
#[unroll]
|
||||||
|
for i in 0..vectorization_factor {
|
||||||
|
let v = quantize_symmetric_int8::<f32>(value[i], scale, range_min, range_max);
|
||||||
|
// Shift and combine into u32
|
||||||
|
v_packed |= (v & 0xFF) << (8 * (num_packed - i - 1));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i in 0..num_packed {
|
||||||
|
let v = quantize_symmetric_int8::<f32>(
|
||||||
|
input[ABSOLUTE_POS + i],
|
||||||
|
scale,
|
||||||
|
range_min,
|
||||||
|
range_max,
|
||||||
|
);
|
||||||
|
// Shift and combine into u32
|
||||||
|
v_packed |= (v & 0xFF) << (8 * (num_packed - i - 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
output[ABSOLUTE_POS] = v_packed;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn quantize_per_tensor<R, F, I, const D: usize>(
|
||||||
|
tensor: JitTensor<R, F, D>,
|
||||||
|
scale: JitTensor<R, F, 1>,
|
||||||
|
offset: Option<JitTensor<R, I, 1>>,
|
||||||
|
) -> JitTensor<R, u32, D>
|
||||||
|
where
|
||||||
|
R: JitRuntime,
|
||||||
|
F: JitElement,
|
||||||
|
I: IntElement,
|
||||||
|
{
|
||||||
|
let num_elems = tensor.shape.num_elements();
|
||||||
|
let shape_output = tensor.shape.clone();
|
||||||
|
let client = tensor.client.clone();
|
||||||
|
// Output tensor contains 4x less elements (four int8 values packed in a single u32)
|
||||||
|
let handle = client.empty(usize::div_ceil(num_elems, 4) * core::mem::size_of::<u32>());
|
||||||
|
let output =
|
||||||
|
JitTensor::new_contiguous(client.clone(), tensor.device.clone(), shape_output, handle);
|
||||||
|
|
||||||
|
// Force vectorization to process 4 quantized values packed for 1 output value
|
||||||
|
let vectorization_factor: u8 = if num_elems < 4 { 1 } else { 4 };
|
||||||
|
let cube_dim = CubeDim::default();
|
||||||
|
let cube_count =
|
||||||
|
calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);
|
||||||
|
|
||||||
|
let dummy_array = [1; D];
|
||||||
|
if let Some(offset) = offset {
|
||||||
|
unsafe {
|
||||||
|
quantize_per_tensor_affine_int8_kernel::launch_unchecked::<R>(
|
||||||
|
&client,
|
||||||
|
cube_count,
|
||||||
|
cube_dim,
|
||||||
|
tensor.as_tensor_arg(vectorization_factor),
|
||||||
|
// Ignore shape and stride
|
||||||
|
TensorArg::from_raw_parts(&scale.handle, &dummy_array, &dummy_array, 1),
|
||||||
|
TensorArg::from_raw_parts(&offset.handle, &dummy_array, &dummy_array, 1),
|
||||||
|
ScalarArg::new(i8::MIN as f32),
|
||||||
|
ScalarArg::new(i8::MAX as f32),
|
||||||
|
output.as_tensor_arg(1),
|
||||||
|
vectorization_factor > 1,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
unsafe {
|
||||||
|
quantize_per_tensor_symmetric_int8_kernel::launch_unchecked::<R>(
|
||||||
|
&client,
|
||||||
|
cube_count,
|
||||||
|
cube_dim,
|
||||||
|
tensor.as_tensor_arg(vectorization_factor),
|
||||||
|
// Ignore shape and stride
|
||||||
|
TensorArg::from_raw_parts(&scale.handle, &dummy_array, &dummy_array, 1),
|
||||||
|
ScalarArg::new(-i8::MAX as f32),
|
||||||
|
ScalarArg::new(i8::MAX as f32),
|
||||||
|
output.as_tensor_arg(1),
|
||||||
|
vectorization_factor > 1,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
output
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert the tensor to a lower precision data type based on the quantization scheme and parameters.
|
||||||
|
pub fn quantize<R, F, I, const D: usize>(
|
||||||
|
tensor: JitTensor<R, F, D>,
|
||||||
|
scheme: &QuantizationScheme,
|
||||||
|
qparams: JitQuantizationParameters<R, F, I>,
|
||||||
|
) -> QJitTensor<R, F, I, D>
|
||||||
|
where
|
||||||
|
R: JitRuntime,
|
||||||
|
F: FloatElement,
|
||||||
|
I: IntElement,
|
||||||
|
{
|
||||||
|
let qtensor = match scheme {
|
||||||
|
QuantizationScheme::PerTensorAffine(dtype)
|
||||||
|
| QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
|
||||||
|
QuantizationType::QInt8 => {
|
||||||
|
quantize_per_tensor(tensor, qparams.scale.clone(), qparams.offset.clone())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
QJitTensor {
|
||||||
|
qtensor,
|
||||||
|
scheme: scheme.clone(),
|
||||||
|
qparams,
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,12 +1,47 @@
|
||||||
use std::ops::Range;
|
use std::ops::Range;
|
||||||
|
|
||||||
|
use alloc::vec::Vec;
|
||||||
use burn_tensor::{
|
use burn_tensor::{
|
||||||
ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
|
ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
|
||||||
quantization::{QuantizationParametersPrimitive, QuantizationScheme},
|
quantization::{
|
||||||
Device, Shape, TensorData,
|
QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme,
|
||||||
|
QuantizationStrategy, QuantizationType,
|
||||||
|
},
|
||||||
|
DType, Device, ElementConversion, Shape, TensorData,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{tensor::QJitTensor, FloatElement, IntElement, JitBackend, JitRuntime};
|
use crate::{
|
||||||
|
kernel,
|
||||||
|
tensor::{JitQuantizationParameters, JitTensor, QJitTensor},
|
||||||
|
FloatElement, IntElement, JitBackend, JitRuntime,
|
||||||
|
};
|
||||||
|
use cubecl::CubeElement;
|
||||||
|
|
||||||
|
fn pack_i8s_to_u32s(data: &TensorData) -> Vec<u32> {
|
||||||
|
// Shift and combine groups of four 8-bit values into a u32.
|
||||||
|
// Same as doing this:
|
||||||
|
// let result = (a_u8 & 0xFF) << 24 | (b_u8 & 0xFF) << 16 | (c_u8 & 0xFF) << 8 | (d_u8 & 0xFF);
|
||||||
|
data.as_bytes()
|
||||||
|
.chunks(4)
|
||||||
|
.map(|x| {
|
||||||
|
x.iter().enumerate().fold(0u32, |acc, (i, x)| {
|
||||||
|
acc | (*x as i8 as u32 & 0xFF) << ((3 - i) * 8)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a quantized tensor with packed values (u32).
|
||||||
|
fn packed_tensor<R: JitRuntime, S: Into<Shape<D>>, const D: usize>(
|
||||||
|
data: Vec<u32>,
|
||||||
|
shape: S,
|
||||||
|
device: &R::Device,
|
||||||
|
) -> JitTensor<R, u32, D> {
|
||||||
|
let client = R::client(device);
|
||||||
|
let buffer = client.create(u32::as_bytes(&data));
|
||||||
|
|
||||||
|
JitTensor::new_contiguous(client, device.clone(), shape.into(), buffer)
|
||||||
|
}
|
||||||
|
|
||||||
impl<R, F, I> QTensorOps<Self> for JitBackend<R, F, I>
|
impl<R, F, I> QTensorOps<Self> for JitBackend<R, F, I>
|
||||||
where
|
where
|
||||||
|
@ -15,22 +50,49 @@ where
|
||||||
I: IntElement,
|
I: IntElement,
|
||||||
{
|
{
|
||||||
fn q_from_data<const D: usize>(
|
fn q_from_data<const D: usize>(
|
||||||
_data: TensorData,
|
data: TensorData,
|
||||||
_device: &Device<Self>,
|
device: &Device<Self>,
|
||||||
) -> QuantizedTensor<Self, D> {
|
) -> QuantizedTensor<Self, D> {
|
||||||
todo!()
|
match data.dtype {
|
||||||
|
DType::QFloat(strategy) => match strategy {
|
||||||
|
QuantizationStrategy::PerTensorAffineInt8(q) => {
|
||||||
|
// Convert quantized values to packed u32s
|
||||||
|
QJitTensor {
|
||||||
|
qtensor: packed_tensor(pack_i8s_to_u32s(&data), data.shape, device),
|
||||||
|
scheme: strategy.scheme(),
|
||||||
|
qparams: JitQuantizationParameters::new(
|
||||||
|
q.scale.elem(),
|
||||||
|
Some(q.offset.elem()),
|
||||||
|
device,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
QuantizationStrategy::PerTensorSymmetricInt8(q) => {
|
||||||
|
// Convert quantized values to packed u32s
|
||||||
|
QJitTensor {
|
||||||
|
qtensor: packed_tensor(pack_i8s_to_u32s(&data), data.shape, device),
|
||||||
|
scheme: strategy.scheme(),
|
||||||
|
qparams: JitQuantizationParameters::new(q.scale.elem(), None, device),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ => panic!(
|
||||||
|
"Invalid dtype (expected DType::QFloat, got {:?})",
|
||||||
|
data.dtype
|
||||||
|
),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn quantize<const D: usize>(
|
fn quantize<const D: usize>(
|
||||||
_tensor: FloatTensor<Self, D>,
|
tensor: FloatTensor<Self, D>,
|
||||||
_scheme: &QuantizationScheme,
|
scheme: &QuantizationScheme,
|
||||||
_qparams: QuantizationParametersPrimitive<Self>,
|
qparams: QuantizationParametersPrimitive<Self>,
|
||||||
) -> QuantizedTensor<Self, D> {
|
) -> QuantizedTensor<Self, D> {
|
||||||
unimplemented!()
|
kernel::quantization::quantize(tensor, scheme, qparams.into())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn dequantize<const D: usize>(_tensor: QuantizedTensor<Self, D>) -> FloatTensor<Self, D> {
|
fn dequantize<const D: usize>(tensor: QuantizedTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||||
unimplemented!()
|
kernel::quantization::dequantize(tensor)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn q_shape<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Shape<D> {
|
fn q_shape<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Shape<D> {
|
||||||
|
@ -42,10 +104,12 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
fn q_to_device<const D: usize>(
|
fn q_to_device<const D: usize>(
|
||||||
_tensor: QuantizedTensor<Self, D>,
|
tensor: QuantizedTensor<Self, D>,
|
||||||
_device: &Device<Self>,
|
device: &Device<Self>,
|
||||||
) -> QuantizedTensor<Self, D> {
|
) -> QuantizedTensor<Self, D> {
|
||||||
unimplemented!()
|
let mut tensor = tensor;
|
||||||
|
tensor.qtensor = super::to_device(tensor.qtensor, device);
|
||||||
|
tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
fn q_reshape<const D1: usize, const D2: usize>(
|
fn q_reshape<const D1: usize, const D2: usize>(
|
||||||
|
@ -55,11 +119,43 @@ where
|
||||||
QJitTensor {
|
QJitTensor {
|
||||||
qtensor: super::reshape(tensor.qtensor, shape),
|
qtensor: super::reshape(tensor.qtensor, shape),
|
||||||
scheme: tensor.scheme,
|
scheme: tensor.scheme,
|
||||||
|
qparams: tensor.qparams,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn q_into_data<const D: usize>(_tensor: QuantizedTensor<Self, D>) -> TensorData {
|
async fn q_into_data<const D: usize>(tensor: QuantizedTensor<Self, D>) -> TensorData {
|
||||||
unimplemented!()
|
let strategy = tensor.strategy();
|
||||||
|
let numel = tensor.qtensor.shape.num_elements();
|
||||||
|
let qtensor = kernel::into_contiguous(tensor.qtensor);
|
||||||
|
|
||||||
|
let bytes = qtensor.client.read_async(qtensor.handle.binding()).await;
|
||||||
|
|
||||||
|
// Convert packed bytes to quantized dtype (TensorData can be used with other backends,
|
||||||
|
// which don't have the prior knowledge of this packed representation)
|
||||||
|
match &tensor.scheme {
|
||||||
|
QuantizationScheme::PerTensorAffine(dtype)
|
||||||
|
| QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
|
||||||
|
QuantizationType::QInt8 => TensorData::quantized(
|
||||||
|
u32::from_bytes(&bytes)
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.flat_map(|(i, packed)| {
|
||||||
|
// A single u32 could contain less than four 8-bit values...
|
||||||
|
let n = core::cmp::min(4, numel - i * 4);
|
||||||
|
// Extract each 8-bit segment from u32 and cast back to i8
|
||||||
|
// Same as doing this (when 4 values are fully packed):
|
||||||
|
// let a = ((packed >> 24) & 0xFF) as i8;
|
||||||
|
// let b = ((packed >> 16) & 0xFF) as i8;
|
||||||
|
// let c = ((packed >> 8) & 0xFF) as i8;
|
||||||
|
// let d = (packed & 0xFF) as i8;
|
||||||
|
(0..n).map(move |i| (packed >> ((3 - i) * 8) & 0xFF) as i8)
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
qtensor.shape,
|
||||||
|
strategy,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn q_swap_dims<const D: usize>(
|
fn q_swap_dims<const D: usize>(
|
||||||
|
|
|
@ -1,34 +1,111 @@
|
||||||
use burn_tensor::quantization::{QTensorPrimitive, QuantizationScheme, QuantizationStrategy};
|
use burn_tensor::{
|
||||||
|
quantization::{
|
||||||
|
AffineQuantization, QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme,
|
||||||
|
QuantizationStrategy, QuantizationType, SymmetricQuantization,
|
||||||
|
},
|
||||||
|
read_sync, TensorData,
|
||||||
|
};
|
||||||
|
|
||||||
use crate::JitRuntime;
|
use crate::{ops::into_data, FloatElement, IntElement, JitBackend, JitRuntime};
|
||||||
|
|
||||||
use super::JitTensor;
|
use super::JitTensor;
|
||||||
|
|
||||||
/// A quantized tensor primitive.
|
/// A quantized tensor primitive.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct QJitTensor<R: JitRuntime, const D: usize> {
|
pub struct QJitTensor<R: JitRuntime, F: FloatElement, I: IntElement, const D: usize> {
|
||||||
/// The quantized tensor.
|
/// The quantized tensor.
|
||||||
// TODO: implement `JitElement` / `CubeElement` for quantized type
|
/// Values are stored as multiple packed quantized values in u32.
|
||||||
pub qtensor: JitTensor<R, u32, D>,
|
pub qtensor: JitTensor<R, u32, D>,
|
||||||
/// The quantization scheme.
|
/// The quantization scheme.
|
||||||
pub scheme: QuantizationScheme,
|
pub scheme: QuantizationScheme,
|
||||||
|
/// The quantization parameters.
|
||||||
|
pub qparams: JitQuantizationParameters<R, F, I>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<R: JitRuntime, const D: usize> QTensorPrimitive for QJitTensor<R, D> {
|
impl<R: JitRuntime, F: FloatElement, I: IntElement, const D: usize> QTensorPrimitive
|
||||||
|
for QJitTensor<R, F, I, D>
|
||||||
|
{
|
||||||
fn scheme(&self) -> &QuantizationScheme {
|
fn scheme(&self) -> &QuantizationScheme {
|
||||||
&self.scheme
|
&self.scheme
|
||||||
}
|
}
|
||||||
|
|
||||||
fn strategy(&self) -> QuantizationStrategy {
|
fn strategy(&self) -> QuantizationStrategy {
|
||||||
todo!()
|
match &self.scheme {
|
||||||
|
QuantizationScheme::PerTensorAffine(dtype) => match dtype {
|
||||||
|
QuantizationType::QInt8 => {
|
||||||
|
let scale = read_sync(into_data(self.qparams.scale.clone()))
|
||||||
|
.iter()
|
||||||
|
.next()
|
||||||
|
.unwrap();
|
||||||
|
let offset = read_sync(into_data(self.qparams.offset.clone().unwrap()))
|
||||||
|
.iter()
|
||||||
|
.next()
|
||||||
|
.unwrap();
|
||||||
|
QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(
|
||||||
|
scale, offset,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
|
||||||
|
QuantizationType::QInt8 => {
|
||||||
|
let scale = read_sync(into_data(self.qparams.scale.clone()))
|
||||||
|
.iter()
|
||||||
|
.next()
|
||||||
|
.unwrap();
|
||||||
|
QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(scale))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<R: JitRuntime, const D: usize> Clone for QJitTensor<R, D> {
|
impl<R: JitRuntime, F: FloatElement, I: IntElement, const D: usize> Clone
|
||||||
|
for QJitTensor<R, F, I, D>
|
||||||
|
{
|
||||||
fn clone(&self) -> Self {
|
fn clone(&self) -> Self {
|
||||||
Self {
|
Self {
|
||||||
qtensor: self.qtensor.clone(),
|
qtensor: self.qtensor.clone(),
|
||||||
scheme: self.scheme.clone(),
|
scheme: self.scheme.clone(),
|
||||||
|
qparams: self.qparams.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The quantization parameters.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct JitQuantizationParameters<R: JitRuntime, F: FloatElement, I: IntElement> {
|
||||||
|
/// The scaling factor.
|
||||||
|
pub scale: JitTensor<R, F, 1>,
|
||||||
|
/// The zero-point offset.
|
||||||
|
pub offset: Option<JitTensor<R, I, 1>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R, F, I> Clone for JitQuantizationParameters<R, F, I>
where
    R: JitRuntime,
    F: FloatElement,
    I: IntElement,
{
    /// Clones the scale tensor and the optional offset tensor.
    fn clone(&self) -> Self {
        let scale = self.scale.clone();
        let offset = self.offset.clone();
        Self { scale, offset }
    }
}
|
||||||
|
|
||||||
|
impl<R: JitRuntime, F: FloatElement, I: IntElement>
|
||||||
|
From<QuantizationParametersPrimitive<JitBackend<R, F, I>>>
|
||||||
|
for JitQuantizationParameters<R, F, I>
|
||||||
|
{
|
||||||
|
fn from(value: QuantizationParametersPrimitive<JitBackend<R, F, I>>) -> Self {
|
||||||
|
JitQuantizationParameters {
|
||||||
|
scale: value.scale,
|
||||||
|
offset: value.offset,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R: JitRuntime, F: FloatElement, I: IntElement> JitQuantizationParameters<R, F, I> {
|
||||||
|
pub fn new(scale: F, offset: Option<I>, device: &R::Device) -> Self {
|
||||||
|
Self {
|
||||||
|
scale: crate::ops::from_data(TensorData::new(vec![scale], [1]), device),
|
||||||
|
offset: offset.map(|o| crate::ops::from_data(TensorData::new(vec![o], [1]), device)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@ mod matmul;
|
||||||
mod max_pool2d;
|
mod max_pool2d;
|
||||||
mod max_pool2d_backward;
|
mod max_pool2d_backward;
|
||||||
mod normal;
|
mod normal;
|
||||||
|
mod quantization;
|
||||||
mod reduce;
|
mod reduce;
|
||||||
mod repeat_dim;
|
mod repeat_dim;
|
||||||
mod scatter;
|
mod scatter;
|
||||||
|
@ -73,6 +74,8 @@ macro_rules! testgen_all {
|
||||||
burn_jit::testgen_cat!();
|
burn_jit::testgen_cat!();
|
||||||
burn_jit::testgen_clamp!();
|
burn_jit::testgen_clamp!();
|
||||||
burn_jit::testgen_unary!();
|
burn_jit::testgen_unary!();
|
||||||
|
|
||||||
|
burn_jit::testgen_quantization!();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
mod jit_fusion {
|
mod jit_fusion {
|
||||||
|
@ -100,6 +103,14 @@ macro_rules! testgen_jit {
|
||||||
|
|
||||||
burn_tensor::testgen_all!();
|
burn_tensor::testgen_all!();
|
||||||
burn_autodiff::testgen_all!();
|
burn_autodiff::testgen_all!();
|
||||||
|
|
||||||
|
// Not all ops are implemented for quantization yet, notably missing:
|
||||||
|
// `q_swap_dims`, `q_permute`, `q_flip`, `q_gather`, `q_select`, `q_slice`, `q_expand`
|
||||||
|
// burn_tensor::testgen_quantization!();
|
||||||
|
// test quantization
|
||||||
|
burn_tensor::testgen_calibration!();
|
||||||
|
burn_tensor::testgen_scheme!();
|
||||||
|
burn_tensor::testgen_quantize!();
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,82 @@
|
||||||
|
#[burn_tensor_testgen::testgen(quantization)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use burn_tensor::{
|
||||||
|
quantization::{QuantizationScheme, QuantizationType},
|
||||||
|
Tensor,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_quantize_dequantize_symmetric_single() {
|
||||||
|
let scheme = QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8);
|
||||||
|
let input = Tensor::<TestBackend, 1>::from_floats([-1.8], &Default::default());
|
||||||
|
let input_ref =
|
||||||
|
Tensor::<ReferenceBackend, 1>::from_data(input.to_data(), &Default::default());
|
||||||
|
|
||||||
|
let output = input.quantize_dynamic(&scheme);
|
||||||
|
let output_ref = input_ref.quantize_dynamic(&scheme);
|
||||||
|
|
||||||
|
output.to_data().assert_eq(&output_ref.to_data(), false);
|
||||||
|
|
||||||
|
let output = output.dequantize();
|
||||||
|
let output_ref = output_ref.dequantize();
|
||||||
|
|
||||||
|
output.to_data().assert_approx_eq(&output_ref.to_data(), 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_quantize_dequantize_affine_single() {
|
||||||
|
let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8);
|
||||||
|
let input = Tensor::<TestBackend, 1>::from_floats([-1.8], &Default::default());
|
||||||
|
let input_ref =
|
||||||
|
Tensor::<ReferenceBackend, 1>::from_data(input.to_data(), &Default::default());
|
||||||
|
|
||||||
|
let output = input.quantize_dynamic(&scheme);
|
||||||
|
let output_ref = input_ref.quantize_dynamic(&scheme);
|
||||||
|
|
||||||
|
output.to_data().assert_eq(&output_ref.to_data(), false);
|
||||||
|
|
||||||
|
let output = output.dequantize();
|
||||||
|
let output_ref = output_ref.dequantize();
|
||||||
|
|
||||||
|
output.to_data().assert_approx_eq(&output_ref.to_data(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_quantize_dequantize_symmetric_multiple() {
|
||||||
|
let scheme = QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8);
|
||||||
|
let input =
|
||||||
|
Tensor::<TestBackend, 1>::from_floats([-1.8, -1.0, 0.0, 0.5, 0.0], &Default::default());
|
||||||
|
let input_ref =
|
||||||
|
Tensor::<ReferenceBackend, 1>::from_data(input.to_data(), &Default::default());
|
||||||
|
|
||||||
|
let output = input.quantize_dynamic(&scheme);
|
||||||
|
let output_ref = input_ref.quantize_dynamic(&scheme);
|
||||||
|
|
||||||
|
output.to_data().assert_eq(&output_ref.to_data(), false);
|
||||||
|
|
||||||
|
let output = output.dequantize();
|
||||||
|
let output_ref = output_ref.dequantize();
|
||||||
|
|
||||||
|
output.to_data().assert_approx_eq(&output_ref.to_data(), 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_quantize_dequantize_affine_multiple() {
|
||||||
|
let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8);
|
||||||
|
let input =
|
||||||
|
Tensor::<TestBackend, 1>::from_floats([-1.8, -1.0, 0.0, 0.5, 0.0], &Default::default());
|
||||||
|
let input_ref =
|
||||||
|
Tensor::<ReferenceBackend, 1>::from_data(input.to_data(), &Default::default());
|
||||||
|
|
||||||
|
let output = input.quantize_dynamic(&scheme);
|
||||||
|
let output_ref = input_ref.quantize_dynamic(&scheme);
|
||||||
|
|
||||||
|
output.to_data().assert_eq(&output_ref.to_data(), false);
|
||||||
|
|
||||||
|
let output = output.dequantize();
|
||||||
|
let output_ref = output_ref.dequantize();
|
||||||
|
|
||||||
|
output.to_data().assert_approx_eq(&output_ref.to_data(), 3);
|
||||||
|
}
|
||||||
|
}
|
|
@ -361,22 +361,21 @@ impl TensorData {
|
||||||
DType::Bool => self.assert_eq_elem::<bool>(other),
|
DType::Bool => self.assert_eq_elem::<bool>(other),
|
||||||
DType::QFloat(q) => {
|
DType::QFloat(q) => {
|
||||||
// Strict or not, it doesn't make sense to compare quantized data to not quantized data for equality
|
// Strict or not, it doesn't make sense to compare quantized data to not quantized data for equality
|
||||||
if let DType::QFloat(q_other) = other.dtype {
|
let q_other = if let DType::QFloat(q_other) = other.dtype {
|
||||||
assert_eq!(
|
q_other
|
||||||
q, q_other,
|
|
||||||
"Quantization strategies differ ({:?} != {:?})",
|
|
||||||
q, q_other
|
|
||||||
)
|
|
||||||
} else {
|
} else {
|
||||||
panic!("Quantized data differs from other not quantized data")
|
panic!("Quantized data differs from other not quantized data")
|
||||||
}
|
};
|
||||||
match q {
|
match (q, q_other) {
|
||||||
QuantizationStrategy::PerTensorAffineInt8(_) => {
|
(
|
||||||
self.assert_eq_elem::<i8>(other)
|
QuantizationStrategy::PerTensorAffineInt8(_),
|
||||||
}
|
QuantizationStrategy::PerTensorAffineInt8(_),
|
||||||
QuantizationStrategy::PerTensorSymmetricInt8(_) => {
|
) => self.assert_eq_elem::<i8>(other),
|
||||||
self.assert_eq_elem::<i8>(other)
|
(
|
||||||
}
|
QuantizationStrategy::PerTensorSymmetricInt8(_),
|
||||||
|
QuantizationStrategy::PerTensorSymmetricInt8(_),
|
||||||
|
) => self.assert_eq_elem::<i8>(other),
|
||||||
|
_ => panic!("Quantization strategies differ ({:?} != {:?})", q, q_other),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -194,7 +194,7 @@ mod tests {
|
||||||
output
|
output
|
||||||
.dequantize()
|
.dequantize()
|
||||||
.into_data()
|
.into_data()
|
||||||
.assert_eq(&TensorData::from([0.0]), false);
|
.assert_approx_eq(&TensorData::from([0.0]), 5);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
@ -55,15 +55,6 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn should_support_dequantize() {
|
fn should_support_dequantize() {
|
||||||
let device = Default::default();
|
let device = Default::default();
|
||||||
let tensor = Tensor::<TestBackend, 1>::from_floats([-1.8, -1.0, 0.0, 0.5], &device);
|
|
||||||
let scheme = QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8);
|
|
||||||
let qparams = QuantizationParameters {
|
|
||||||
scale: Tensor::from_floats([0.014_173_228], &device),
|
|
||||||
offset: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
let x_q = tensor.quantize(&scheme, qparams);
|
|
||||||
|
|
||||||
// Quantized [-1.8, -1.0, 0.0, 0.5]
|
// Quantized [-1.8, -1.0, 0.0, 0.5]
|
||||||
let data = TensorData::quantized(
|
let data = TensorData::quantized(
|
||||||
vec![-127i8, -71, 0, 35],
|
vec![-127i8, -71, 0, 35],
|
||||||
|
@ -97,6 +88,6 @@ mod tests {
|
||||||
QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.05882353, 42)),
|
QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.05882353, 42)),
|
||||||
);
|
);
|
||||||
|
|
||||||
x_q.to_data().assert_eq(&expected, true);
|
x_q.to_data().assert_eq(&expected, false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue