Remove the constraint that min / max should stride zero

Since we apply nudging for the zero point to make sure the nudged zerop points
can be in the range of [qmin, qmax], the constraint that rmin / rmax should
stride zero isn't necessary.

This also matches the documentation of tensorflow's FakeQuantWithMinMaxArgs op,
where min and max don't need to stride zero:
https://www.tensorflow.org/api_docs/python/tf/quantization/fake_quant_with_min_max_args

PiperOrigin-RevId: 268296285
This commit is contained in:
Feng Liu 2019-09-10 13:26:14 -07:00 committed by A. Unique TensorFlower
parent c68d5467d6
commit cf0a782339
3 changed files with 50 additions and 39 deletions

View File

@ -54,8 +54,17 @@ bool getDefaultStorageParams(unsigned numBits, bool narrowRange, bool isSigned,
return false;
}
void getScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin, double rmax,
double &scale, int64_t &nudgedZeroPoint) {
// This is a specific implementation of nudging:
// If 0.0 < rmin < rmax or rmin < rmax < 0.0, the range will be shifted
// to include 0.0, but the range width size (rmax-rmin) isn't changed. The zero
// point is derived from the shifted range, and the scale isn't changed. As
// a consequence some values, which are supposeed in the original [rmin, rmax]
// range will be outside the shifted range and be clamped during quantization.
// TODO(fengliuai): we should nudge the scale as well, but that requires the
// fake quant op used in the training to use the nudged scale as well.
void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin,
double rmax, double &scale,
int64_t &nudgedZeroPoint) {
// Determine the scale.
const double qminDouble = qmin;
const double qmaxDouble = qmax;
@ -100,14 +109,6 @@ UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits,
double rmin, double rmax,
bool narrowRange, Type expressedType,
bool isSigned) {
// Range must straddle zero.
// TODO(b/140641593): remove this constraint.
if (rmin > 0.0 || rmax < 0.0) {
return (emitError(loc, "FakeQuant range must straddle zero: [")
<< rmin << "," << rmax << "]",
nullptr);
}
MLIRContext *ctx = expressedType.getContext();
unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
Type storageType;
@ -129,7 +130,7 @@ UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits,
double scale;
int64_t nudgedZeroPoint;
getScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
return UniformQuantizedType::getChecked(flags, storageType, expressedType,
scale, nudgedZeroPoint, qmin, qmax,
@ -172,7 +173,7 @@ fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension,
double scale;
int64_t nudgedZeroPoint;
getScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
scales.push_back(scale);
zeroPoints.push_back(nudgedZeroPoint);
}

View File

@ -1,27 +1,5 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -quant-convert-simulated-quantization
// -----
// Verify that a mismatched range errors.
func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// expected-error@+1 {{FakeQuant range must straddle zero: [1.100000e+00,1.500000e+00]}}
%0 = "quant.const_fake_quant"(%arg0) {
min = 1.1 : f32, max = 1.5 : f32, num_bits = 8
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %0 : tensor<8x4x3xf32>
}
// -----
// Verify that a valid range errors.
func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// expected-error@+1 {{FakeQuant range must straddle zero: [1.100000e+00,1.000000e+00}}
%0 = "quant.const_fake_quant"(%arg0) {
min = 1.1 : f32, max = 1.0 : f32, num_bits = 8
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %0 : tensor<8x4x3xf32>
}
// -----
// Unsupported quantizable type (i1 is currently not a supported element type).
func @fakeQuantArgs(tensor<8x4x3xi1>) -> tensor<8x4x3xi1> {

View File

@ -47,7 +47,7 @@ func @fakeQuantArgs_Quint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
// -----
// Verifies a quint8 asymmetric 0..1 range (with narrow_range = true).
// CHECK_LABEL: fakeQuantArgs_Quint8_NarrowRange
// CHECK-LABEL: fakeQuantArgs_Quint8_NarrowRange
func @fakeQuantArgs_Quint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
@ -62,7 +62,7 @@ func @fakeQuantArgs_Quint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
// -----
// Verifies a quint8 symmetric range of -1..127/128.
// CHECK_LABEL: fakeQuantArgs_Quint8_SymmetricRange
// CHECK-LABEL: fakeQuantArgs_Quint8_SymmetricRange
func @fakeQuantArgs_Quint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
@ -122,7 +122,7 @@ func @fakeQuantArgs_Qint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
// -----
// Verifies a qint8 asymmetric 0..1 range (with narrow_range = true).
// CHECK_LABEL: fakeQuantArgs_Qint8_NarrowRange
// CHECK-LABEL: fakeQuantArgs_Qint8_NarrowRange
func @fakeQuantArgs_Qint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
@ -137,7 +137,7 @@ func @fakeQuantArgs_Qint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
// -----
// Verifies a qint8 symmetric range of -1..127/128.
// CHECK_LABEL: fakeQuantArgs_Qint8_SymmetricRange
// CHECK-LABEL: fakeQuantArgs_Qint8_SymmetricRange
func @fakeQuantArgs_Qint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
@ -181,9 +181,41 @@ func @fakeQuantArgs_UnrankedTensor(tensor<f32>) -> tensor<f32> {
return %0 : tensor<f32>
}
// -----
// CHECK-LABEL: fakeQuantArgs_all_positive
func @fakeQuantArgs_all_positive(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
// CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:-128>>
// CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:-128>>)
// CHECK-SAME: -> tensor<8x4x3xf32>
%0 = "quant.const_fake_quant"(%arg0) {
min = 0.5 : f32, max = 1.5 : f32, num_bits = 8, narrow_range = false, is_signed = true
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %0 : tensor<8x4x3xf32>
}
// -----
// CHECK-LABEL: fakeQuantArgs_all_negative
func @fakeQuantArgs_all_negative(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
// CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:127>>
// CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:127>>)
// CHECK-SAME: -> tensor<8x4x3xf32>
%0 = "quant.const_fake_quant"(%arg0) {
min = -1.5 : f32, max = -0.5 : f32, num_bits = 8, narrow_range = false, is_signed = true
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %0 : tensor<8x4x3xf32>
}
// -----
// Verifies a qint8 per axis
// CHECK_LABEL: fakeQuantPerAxis
// CHECK-LABEL: fakeQuantPerAxis
func @fakeQuantPerAxis(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):