diff --git a/crates/burn-tensor/src/tensor/data.rs b/crates/burn-tensor/src/tensor/data.rs index 89e82a091..1480c07f3 100644 --- a/crates/burn-tensor/src/tensor/data.rs +++ b/crates/burn-tensor/src/tensor/data.rs @@ -550,7 +550,7 @@ mod tests { #[test] fn should_support_quantization() { - let quant = QuantizationStrategy::Int8Affine(AffineQuantization::new(-1.8, 0.5)); + let quant = QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::new(-1.8, 0.5)); let data1 = Data::::from([[-1.8, -1.0, 0.0, 0.5]]); let data2 = data1.clone().with_quantization(quant.clone()); diff --git a/crates/burn-tensor/src/tensor/quantization.rs b/crates/burn-tensor/src/tensor/quantization.rs index 0cbbe65a1..e08d78693 100644 --- a/crates/burn-tensor/src/tensor/quantization.rs +++ b/crates/burn-tensor/src/tensor/quantization.rs @@ -147,26 +147,26 @@ impl Quantization for SymmetricQuantization { /// Quantization scheme/strategy. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum QuantizationStrategy { - /// `int8` affine/asymmetric quantization. - Int8Affine(AffineQuantization), - /// `int8` symmetric quantization. - Int8Symmetric(SymmetricQuantization), + /// Per-tensor `int8` affine/asymmetric quantization. + PerTensorAffineInt8(AffineQuantization), + /// Per-tensor `int8` symmetric quantization. + PerTensorSymmetricInt8(SymmetricQuantization), } impl QuantizationStrategy { /// Convert the values to a lower precision data type. pub fn quantize(&self, values: &[f32]) -> Vec { match self { - Self::Int8Affine(m) => m.quantize(values), - Self::Int8Symmetric(m) => m.quantize(values), + Self::PerTensorAffineInt8(m) => m.quantize(values), + Self::PerTensorSymmetricInt8(m) => m.quantize(values), } } /// Convert the values back to a higher precision data type. pub fn dequantize(&self, values: &[i8]) -> Vec { match self { - Self::Int8Affine(m) => m.dequantize(values), - Self::Int8Symmetric(m) => m.dequantize(values), + Self::PerTensorAffineInt8(m) => m.dequantize(values), + Self::PerTensorSymmetricInt8(m) => m.dequantize(values), } } } @@ -182,7 +182,7 @@ mod tests { let expected_q = vec![-128, -39, 72, 127]; let expected_d = vec![-1.8039216, -1.0011765, 0.0, 0.49607843]; - let affine = QuantizationStrategy::Int8Affine(AffineQuantization::new(-1.8, 0.5)); + let affine = QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::new(-1.8, 0.5)); let q = affine.quantize(&x); assert_eq!(q, expected_q); @@ -198,7 +198,8 @@ mod tests { let expected_q = vec![-127, -71, 0, 35]; let expected_d = vec![-1.8, -1.0062993, 0.0, 0.496063]; - let symmetric = QuantizationStrategy::Int8Symmetric(SymmetricQuantization::new(-1.8, 0.5)); + let symmetric = + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::new(-1.8, 0.5)); let q = symmetric.quantize(&x); assert_eq!(q, expected_q);