mirror of https://github.com/tracel-ai/burn.git

Rename per-tensor strategies

parent 9f25cec4a6
commit e3cadc4ad3
@@ -550,7 +550,7 @@ mod tests {
 
     #[test]
     fn should_support_quantization() {
-        let quant = QuantizationStrategy::Int8Affine(AffineQuantization::new(-1.8, 0.5));
+        let quant = QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::new(-1.8, 0.5));
         let data1 = Data::<f32, 2>::from([[-1.8, -1.0, 0.0, 0.5]]);
         let data2 = data1.clone().with_quantization(quant.clone());
 
@@ -147,26 +147,26 @@ impl<E: Float, Q: PrimInt> Quantization<E, Q> for SymmetricQuantization<E, Q> {
 /// Quantization scheme/strategy.
 #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 pub enum QuantizationStrategy {
-    /// `int8` affine/asymmetric quantization.
-    Int8Affine(AffineQuantization<f32, i8, i32>),
-    /// `int8` symmetric quantization.
-    Int8Symmetric(SymmetricQuantization<f32, i8>),
+    /// Per-tensor `int8` affine/asymmetric quantization.
+    PerTensorAffineInt8(AffineQuantization<f32, i8, i32>),
+    /// Per-tensor `int8` symmetric quantization.
+    PerTensorSymmetricInt8(SymmetricQuantization<f32, i8>),
 }
 
 impl QuantizationStrategy {
     /// Convert the values to a lower precision data type.
     pub fn quantize(&self, values: &[f32]) -> Vec<i8> {
         match self {
-            Self::Int8Affine(m) => m.quantize(values),
-            Self::Int8Symmetric(m) => m.quantize(values),
+            Self::PerTensorAffineInt8(m) => m.quantize(values),
+            Self::PerTensorSymmetricInt8(m) => m.quantize(values),
         }
     }
 
     /// Convert the values back to a higher precision data type.
     pub fn dequantize(&self, values: &[i8]) -> Vec<f32> {
         match self {
-            Self::Int8Affine(m) => m.dequantize(values),
-            Self::Int8Symmetric(m) => m.dequantize(values),
+            Self::PerTensorAffineInt8(m) => m.dequantize(values),
+            Self::PerTensorSymmetricInt8(m) => m.dequantize(values),
         }
     }
 }
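The rename is mechanical: call sites construct and use the strategies exactly as before, and the `PerTensor` prefix presumably leaves naming room for future per-channel strategies. A minimal sketch of the migrated call pattern, using only the constructors and methods visible in the hunk above:

// Migrated call sites; constructor arguments mirror the tests in this commit.
let values: Vec<f32> = vec![-1.8, -1.0, 0.0, 0.5];

// Before: QuantizationStrategy::Int8Affine(AffineQuantization::new(-1.8, 0.5))
let affine = QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::new(-1.8, 0.5));

// Before: QuantizationStrategy::Int8Symmetric(SymmetricQuantization::new(-1.8, 0.5))
let symmetric =
    QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::new(-1.8, 0.5));

// quantize/dequantize signatures are unchanged: f32 -> i8 and back.
let q: Vec<i8> = affine.quantize(&values);
let d: Vec<f32> = affine.dequantize(&q);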
@@ -182,7 +182,7 @@ mod tests {
         let expected_q = vec![-128, -39, 72, 127];
         let expected_d = vec![-1.8039216, -1.0011765, 0.0, 0.49607843];
 
-        let affine = QuantizationStrategy::Int8Affine(AffineQuantization::new(-1.8, 0.5));
+        let affine = QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::new(-1.8, 0.5));
 
         let q = affine.quantize(&x);
         assert_eq!(q, expected_q);
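As a sanity check, the expected values in this hunk are consistent with the standard affine int8 mapping. A sketch under the assumption that `AffineQuantization::new(alpha, beta)` derives its scale and zero point from the range `[alpha, beta]` (the internals are not part of this diff; the test name and derivation are hypothetical):

// Hypothetical sanity-check test; AffineQuantization's internals are assumed.
#[test]
fn affine_expected_values_sketch() {
    let (alpha, beta) = (-1.8f32, 0.5f32);
    let (q_min, q_max) = (i8::MIN as i32, i8::MAX as i32); // -128, 127

    // Assumed derivation: scale covers the range, zero point shifts alpha onto q_min.
    let scale = (beta - alpha) / (q_max - q_min) as f32; // 2.3 / 255 ≈ 0.00901961
    let zero = (q_min as f32 - alpha / scale).round() as i32; // 72

    let quantize = |x: f32| ((x / scale).round() as i32 + zero).clamp(q_min, q_max) as i8;

    assert_eq!(quantize(-1.8), -128); // expected_q[0]
    assert_eq!(quantize(-1.0), -39); // expected_q[1]
    assert_eq!(quantize(0.0), 72); // expected_q[2] (the zero point itself)
    assert_eq!(quantize(0.5), 127); // expected_q[3]
    // Dequantizing -128 as scale * (-128 - 72) gives -1.8039216 = expected_d[0].
}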
@@ -198,7 +198,8 @@ mod tests {
         let expected_q = vec![-127, -71, 0, 35];
         let expected_d = vec![-1.8, -1.0062993, 0.0, 0.496063];
 
-        let symmetric = QuantizationStrategy::Int8Symmetric(SymmetricQuantization::new(-1.8, 0.5));
+        let symmetric =
+            QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::new(-1.8, 0.5));
 
         let q = symmetric.quantize(&x);
         assert_eq!(q, expected_q);
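Likewise for the symmetric variant, assuming the scale comes from the largest absolute bound with the zero point fixed at 0 and the i8 range restricted to [-127, 127] to stay symmetric (again an assumption about internals not shown here; the test name is hypothetical):

// Hypothetical sanity-check test; SymmetricQuantization's internals are assumed.
#[test]
fn symmetric_expected_values_sketch() {
    let (alpha, beta) = (-1.8f32, 0.5f32);
    let scale = alpha.abs().max(beta.abs()) / 127.0; // 1.8 / 127 ≈ 0.0141732

    let quantize = |x: f32| (x / scale).round().clamp(-127.0, 127.0) as i8;

    assert_eq!(quantize(-1.8), -127); // expected_q[0]
    assert_eq!(quantize(-1.0), -71); // expected_q[1]
    assert_eq!(quantize(0.0), 0); // expected_q[2]
    assert_eq!(quantize(0.5), 35); // expected_q[3]
    // Dequantizing 35 as scale * 35.0 gives 0.496063 = expected_d[3].
}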