mirror of https://github.com/tracel-ai/burn.git
Rename per-tensor strategies
This commit is contained in:
parent
9f25cec4a6
commit
e3cadc4ad3
|
@ -550,7 +550,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn should_support_quantization() {
|
||||
let quant = QuantizationStrategy::Int8Affine(AffineQuantization::new(-1.8, 0.5));
|
||||
let quant = QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::new(-1.8, 0.5));
|
||||
let data1 = Data::<f32, 2>::from([[-1.8, -1.0, 0.0, 0.5]]);
|
||||
let data2 = data1.clone().with_quantization(quant.clone());
|
||||
|
||||
|
|
|
@ -147,26 +147,26 @@ impl<E: Float, Q: PrimInt> Quantization<E, Q> for SymmetricQuantization<E, Q> {
|
|||
/// Quantization scheme/strategy.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum QuantizationStrategy {
|
||||
/// `int8` affine/asymmetric quantization.
|
||||
Int8Affine(AffineQuantization<f32, i8, i32>),
|
||||
/// `int8` symmetric quantization.
|
||||
Int8Symmetric(SymmetricQuantization<f32, i8>),
|
||||
/// Per-tensor `int8` affine/asymmetric quantization.
|
||||
PerTensorAffineInt8(AffineQuantization<f32, i8, i32>),
|
||||
/// Per-tensor `int8` symmetric quantization.
|
||||
PerTensorSymmetricInt8(SymmetricQuantization<f32, i8>),
|
||||
}
|
||||
|
||||
impl QuantizationStrategy {
|
||||
/// Convert the values to a lower precision data type.
|
||||
pub fn quantize(&self, values: &[f32]) -> Vec<i8> {
|
||||
match self {
|
||||
Self::Int8Affine(m) => m.quantize(values),
|
||||
Self::Int8Symmetric(m) => m.quantize(values),
|
||||
Self::PerTensorAffineInt8(m) => m.quantize(values),
|
||||
Self::PerTensorSymmetricInt8(m) => m.quantize(values),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert the values back to a higher precision data type.
|
||||
pub fn dequantize(&self, values: &[i8]) -> Vec<f32> {
|
||||
match self {
|
||||
Self::Int8Affine(m) => m.dequantize(values),
|
||||
Self::Int8Symmetric(m) => m.dequantize(values),
|
||||
Self::PerTensorAffineInt8(m) => m.dequantize(values),
|
||||
Self::PerTensorSymmetricInt8(m) => m.dequantize(values),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -182,7 +182,7 @@ mod tests {
|
|||
let expected_q = vec![-128, -39, 72, 127];
|
||||
let expected_d = vec![-1.8039216, -1.0011765, 0.0, 0.49607843];
|
||||
|
||||
let affine = QuantizationStrategy::Int8Affine(AffineQuantization::new(-1.8, 0.5));
|
||||
let affine = QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::new(-1.8, 0.5));
|
||||
|
||||
let q = affine.quantize(&x);
|
||||
assert_eq!(q, expected_q);
|
||||
|
@ -198,7 +198,8 @@ mod tests {
|
|||
let expected_q = vec![-127, -71, 0, 35];
|
||||
let expected_d = vec![-1.8, -1.0062993, 0.0, 0.496063];
|
||||
|
||||
let symmetric = QuantizationStrategy::Int8Symmetric(SymmetricQuantization::new(-1.8, 0.5));
|
||||
let symmetric =
|
||||
QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::new(-1.8, 0.5));
|
||||
|
||||
let q = symmetric.quantize(&x);
|
||||
assert_eq!(q, expected_q);
|
||||
|
|
Loading…
Reference in New Issue