forked from OSchip/llvm-project
Remove the constraint that min / max should straddle zero
Since we apply nudging for the zero point to make sure the nudged zero points can be in the range of [qmin, qmax], the constraint that rmin / rmax should straddle zero isn't necessary. This also matches the documentation of tensorflow's FakeQuantWithMinMaxArgs op, where min and max don't need to straddle zero: https://www.tensorflow.org/api_docs/python/tf/quantization/fake_quant_with_min_max_args PiperOrigin-RevId: 268296285
This commit is contained in:
parent
c68d5467d6
commit
cf0a782339
|
@ -54,8 +54,17 @@ bool getDefaultStorageParams(unsigned numBits, bool narrowRange, bool isSigned,
|
|||
return false;
|
||||
}
|
||||
|
||||
void getScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin, double rmax,
|
||||
double &scale, int64_t &nudgedZeroPoint) {
|
||||
// This is a specific implementation of nudging:
|
||||
// If 0.0 < rmin < rmax or rmin < rmax < 0.0, the range will be shifted
|
||||
// to include 0.0, but the range width size (rmax-rmin) isn't changed. The zero
|
||||
// point is derived from the shifted range, and the scale isn't changed. As
|
||||
// a consequence some values, which are supposed in the original [rmin, rmax]
|
||||
// range will be outside the shifted range and be clamped during quantization.
|
||||
// TODO(fengliuai): we should nudge the scale as well, but that requires the
|
||||
// fake quant op used in the training to use the nudged scale as well.
|
||||
void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin,
|
||||
double rmax, double &scale,
|
||||
int64_t &nudgedZeroPoint) {
|
||||
// Determine the scale.
|
||||
const double qminDouble = qmin;
|
||||
const double qmaxDouble = qmax;
|
||||
|
@ -100,14 +109,6 @@ UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits,
|
|||
double rmin, double rmax,
|
||||
bool narrowRange, Type expressedType,
|
||||
bool isSigned) {
|
||||
// Range must straddle zero.
|
||||
// TODO(b/140641593): remove this constraint.
|
||||
if (rmin > 0.0 || rmax < 0.0) {
|
||||
return (emitError(loc, "FakeQuant range must straddle zero: [")
|
||||
<< rmin << "," << rmax << "]",
|
||||
nullptr);
|
||||
}
|
||||
|
||||
MLIRContext *ctx = expressedType.getContext();
|
||||
unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
|
||||
Type storageType;
|
||||
|
@ -129,7 +130,7 @@ UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits,
|
|||
|
||||
double scale;
|
||||
int64_t nudgedZeroPoint;
|
||||
getScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
|
||||
getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
|
||||
|
||||
return UniformQuantizedType::getChecked(flags, storageType, expressedType,
|
||||
scale, nudgedZeroPoint, qmin, qmax,
|
||||
|
@ -172,7 +173,7 @@ fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension,
|
|||
|
||||
double scale;
|
||||
int64_t nudgedZeroPoint;
|
||||
getScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
|
||||
getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
|
||||
scales.push_back(scale);
|
||||
zeroPoints.push_back(nudgedZeroPoint);
|
||||
}
|
||||
|
|
|
@ -1,27 +1,5 @@
|
|||
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -quant-convert-simulated-quantization
|
||||
|
||||
// -----
|
||||
// Verify that a mismatched range errors.
|
||||
func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
|
||||
^bb0(%arg0: tensor<8x4x3xf32>):
|
||||
// expected-error@+1 {{FakeQuant range must straddle zero: [1.100000e+00,1.500000e+00]}}
|
||||
%0 = "quant.const_fake_quant"(%arg0) {
|
||||
min = 1.1 : f32, max = 1.5 : f32, num_bits = 8
|
||||
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
|
||||
return %0 : tensor<8x4x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// Verify that a valid range errors.
|
||||
func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
|
||||
^bb0(%arg0: tensor<8x4x3xf32>):
|
||||
// expected-error@+1 {{FakeQuant range must straddle zero: [1.100000e+00,1.000000e+00}}
|
||||
%0 = "quant.const_fake_quant"(%arg0) {
|
||||
min = 1.1 : f32, max = 1.0 : f32, num_bits = 8
|
||||
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
|
||||
return %0 : tensor<8x4x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// Unsupported quantizable type (i1 is currently not a supported element type).
|
||||
func @fakeQuantArgs(tensor<8x4x3xi1>) -> tensor<8x4x3xi1> {
|
||||
|
|
|
@ -47,7 +47,7 @@ func @fakeQuantArgs_Quint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
|
|||
|
||||
// -----
|
||||
// Verifies a quint8 asymmetric 0..1 range (with narrow_range = true).
|
||||
// CHECK_LABEL: fakeQuantArgs_Quint8_NarrowRange
|
||||
// CHECK-LABEL: fakeQuantArgs_Quint8_NarrowRange
|
||||
func @fakeQuantArgs_Quint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
|
||||
^bb0(%arg0: tensor<8x4x3xf32>):
|
||||
// CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
|
||||
|
@ -62,7 +62,7 @@ func @fakeQuantArgs_Quint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
|
|||
|
||||
// -----
|
||||
// Verifies a quint8 symmetric range of -1..127/128.
|
||||
// CHECK_LABEL: fakeQuantArgs_Quint8_SymmetricRange
|
||||
// CHECK-LABEL: fakeQuantArgs_Quint8_SymmetricRange
|
||||
func @fakeQuantArgs_Quint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
|
||||
^bb0(%arg0: tensor<8x4x3xf32>):
|
||||
// CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
|
||||
|
@ -122,7 +122,7 @@ func @fakeQuantArgs_Qint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
|
|||
|
||||
// -----
|
||||
// Verifies a qint8 asymmetric 0..1 range (with narrow_range = true).
|
||||
// CHECK_LABEL: fakeQuantArgs_Qint8_NarrowRange
|
||||
// CHECK-LABEL: fakeQuantArgs_Qint8_NarrowRange
|
||||
func @fakeQuantArgs_Qint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
|
||||
^bb0(%arg0: tensor<8x4x3xf32>):
|
||||
// CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
|
||||
|
@ -137,7 +137,7 @@ func @fakeQuantArgs_Qint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
|
|||
|
||||
// -----
|
||||
// Verifies a qint8 symmetric range of -1..127/128.
|
||||
// CHECK_LABEL: fakeQuantArgs_Qint8_SymmetricRange
|
||||
// CHECK-LABEL: fakeQuantArgs_Qint8_SymmetricRange
|
||||
func @fakeQuantArgs_Qint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
|
||||
^bb0(%arg0: tensor<8x4x3xf32>):
|
||||
// CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
|
||||
|
@ -181,9 +181,41 @@ func @fakeQuantArgs_UnrankedTensor(tensor<f32>) -> tensor<f32> {
|
|||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: fakeQuantArgs_all_positive
|
||||
func @fakeQuantArgs_all_positive(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
|
||||
^bb0(%arg0: tensor<8x4x3xf32>):
|
||||
|
||||
// CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
|
||||
// CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:-128>>
|
||||
// CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:-128>>)
|
||||
// CHECK-SAME: -> tensor<8x4x3xf32>
|
||||
|
||||
%0 = "quant.const_fake_quant"(%arg0) {
|
||||
min = 0.5 : f32, max = 1.5 : f32, num_bits = 8, narrow_range = false, is_signed = true
|
||||
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
|
||||
return %0 : tensor<8x4x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: fakeQuantArgs_all_negative
|
||||
func @fakeQuantArgs_all_negative(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
|
||||
^bb0(%arg0: tensor<8x4x3xf32>):
|
||||
|
||||
// CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
|
||||
// CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:127>>
|
||||
// CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:127>>)
|
||||
// CHECK-SAME: -> tensor<8x4x3xf32>
|
||||
|
||||
%0 = "quant.const_fake_quant"(%arg0) {
|
||||
min = -1.5 : f32, max = -0.5 : f32, num_bits = 8, narrow_range = false, is_signed = true
|
||||
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
|
||||
return %0 : tensor<8x4x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// Verifies a qint8 per axis
|
||||
// CHECK_LABEL: fakeQuantPerAxis
|
||||
// CHECK-LABEL: fakeQuantPerAxis
|
||||
func @fakeQuantPerAxis(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
|
||||
^bb0(%arg0: tensor<8x4x3xf32>):
|
||||
|
||||
|
|
Loading…
Reference in New Issue