Rename per-tensor strategies

This commit is contained in:
Guillaume Lagrange 2024-06-06 13:32:04 -04:00
parent 9f25cec4a6
commit e3cadc4ad3
2 changed files with 12 additions and 11 deletions

View File

@ -550,7 +550,7 @@ mod tests {
#[test]
fn should_support_quantization() {
let quant = QuantizationStrategy::Int8Affine(AffineQuantization::new(-1.8, 0.5));
let quant = QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::new(-1.8, 0.5));
let data1 = Data::<f32, 2>::from([[-1.8, -1.0, 0.0, 0.5]]);
let data2 = data1.clone().with_quantization(quant.clone());

View File

@ -147,26 +147,26 @@ impl<E: Float, Q: PrimInt> Quantization<E, Q> for SymmetricQuantization<E, Q> {
/// Quantization scheme/strategy.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationStrategy {
/// `int8` affine/asymmetric quantization.
// NOTE(review): this variant carries the same payload as `PerTensorAffineInt8`
// below — this looks like the pre-rename name left alongside its replacement
// (rendered diff residue from the "Rename per-tensor strategies" commit); confirm
// only one of the pair should survive in the actual file.
Int8Affine(AffineQuantization<f32, i8, i32>),
/// `int8` symmetric quantization.
// NOTE(review): duplicate of `PerTensorSymmetricInt8` below — same situation as
// `Int8Affine`; presumably the old name slated for removal. Verify against HEAD.
Int8Symmetric(SymmetricQuantization<f32, i8>),
/// Per-tensor `int8` affine/asymmetric quantization.
PerTensorAffineInt8(AffineQuantization<f32, i8, i32>),
/// Per-tensor `int8` symmetric quantization.
PerTensorSymmetricInt8(SymmetricQuantization<f32, i8>),
}
impl QuantizationStrategy {
    /// Quantize the given floating-point slice into its `i8` representation,
    /// delegating to the concrete per-variant strategy held by `self`.
    pub fn quantize(&self, values: &[f32]) -> Vec<i8> {
        match self {
            Self::Int8Affine(strategy) => strategy.quantize(values),
            Self::PerTensorAffineInt8(strategy) => strategy.quantize(values),
            Self::Int8Symmetric(strategy) => strategy.quantize(values),
            Self::PerTensorSymmetricInt8(strategy) => strategy.quantize(values),
        }
    }

    /// Map quantized `i8` values back to `f32`, delegating to the concrete
    /// per-variant strategy held by `self`.
    pub fn dequantize(&self, values: &[i8]) -> Vec<f32> {
        match self {
            Self::Int8Affine(strategy) => strategy.dequantize(values),
            Self::PerTensorAffineInt8(strategy) => strategy.dequantize(values),
            Self::Int8Symmetric(strategy) => strategy.dequantize(values),
            Self::PerTensorSymmetricInt8(strategy) => strategy.dequantize(values),
        }
    }
}
@ -182,7 +182,7 @@ mod tests {
let expected_q = vec![-128, -39, 72, 127];
let expected_d = vec![-1.8039216, -1.0011765, 0.0, 0.49607843];
let affine = QuantizationStrategy::Int8Affine(AffineQuantization::new(-1.8, 0.5));
let affine = QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::new(-1.8, 0.5));
let q = affine.quantize(&x);
assert_eq!(q, expected_q);
@ -198,7 +198,8 @@ mod tests {
let expected_q = vec![-127, -71, 0, 35];
let expected_d = vec![-1.8, -1.0062993, 0.0, 0.496063];
let symmetric = QuantizationStrategy::Int8Symmetric(SymmetricQuantization::new(-1.8, 0.5));
let symmetric =
QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::new(-1.8, 0.5));
let q = symmetric.quantize(&x);
assert_eq!(q, expected_q);