Convert per-channel fake quant attributes to type

For per-channel fake quant attributes, the returned type should be
UniformQuantizedPerAxisType. This method is not yet under test because the
quant_ConstFakeQuantPerAxis op and the corresponding convert method have not been added.

PiperOrigin-RevId: 268084017
Feng Liu 2019-09-09 14:57:29 -07:00 committed by A. Unique TensorFlower
parent 893c86fff7
commit 27d776fa6d
2 changed files with 110 additions and 36 deletions

include/mlir/Dialect/QuantOps/FakeQuantSupport.h

@@ -62,6 +62,14 @@ UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits,
bool narrowRange, Type expressedType,
bool isSigned = false);
/// Converts per-channel FakeQuant attributes to the corresponding type.
/// In the event that the parameters cannot be converted, returns a nullptr
/// convertible Type and issues an appropriate error.
UniformQuantizedPerAxisType
fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension,
ArrayRef<double> rmins, ArrayRef<double> rmaxs,
bool narrowRange, Type expressedType,
bool isSigned = false);
} // namespace quant
} // namespace mlir
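For illustration, a minimal sketch of how a caller might use this new per-axis overload; the call site, loc, f32Ty, and the range values are hypothetical and not part of this commit:

// Hypothetical call site; loc (a Location) and f32Ty (an f32 Type) are assumed to be in scope.
SmallVector<double, 4> rmins = {-1.0, -0.5};
SmallVector<double, 4> rmaxs = {1.0, 0.5};
mlir::quant::UniformQuantizedPerAxisType perAxisTy = mlir::quant::fakeQuantAttrsToType(
    loc, /*numBits=*/8, /*quantizedDimension=*/0, rmins, rmaxs,
    /*narrowRange=*/false, /*expressedType=*/f32Ty, /*isSigned=*/false);
if (!perAxisTy) {
  // Conversion failed; an error has already been emitted at loc.
}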

FakeQuantSupport.cpp

@@ -18,71 +18,48 @@
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h"
using namespace mlir;
using namespace mlir::quant;
UniformQuantizedType
mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin,
double rmax, bool narrowRange,
Type expressedType, bool isSigned) {
MLIRContext *ctx = expressedType.getContext();
Type storageType;
unsigned flags;
int64_t qmin;
int64_t qmax;
namespace mlir {
namespace quant {
namespace {
bool getDefaultStorageParams(unsigned numBits, bool narrowRange, bool isSigned,
MLIRContext *ctx, Type &storageType, int64_t &qmin,
int64_t &qmax) {
// Hard-coded type mapping from TFLite.
if (numBits <= 8) {
storageType = IntegerType::get(8, ctx);
if (isSigned) {
flags = QuantizationFlags::Signed;
qmin = -128;
qmax = 127;
} else {
flags = 0;
qmin = 0;
qmax = 255;
}
} else if (numBits <= 16) {
storageType = IntegerType::get(16, ctx);
if (isSigned) {
flags = QuantizationFlags::Signed;
qmin = -32768;
qmax = 32767;
} else {
flags = 0;
qmin = 0;
qmax = 65535;
}
} else {
emitError(loc, "unsupported FakeQuant number of bits: ") << numBits;
return nullptr;
return true;
}
// Handle narrowRange.
if (narrowRange) {
qmin += 1;
}
return false;
}
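For example (values chosen purely for illustration), numBits = 8 with isSigned = true and narrowRange = true selects an i8 storage type and, after the narrow-range adjustment, the range qmin = -127, qmax = 127; with isSigned = false and narrowRange = false the same mapping gives qmin = 0, qmax = 255.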
// Range must straddle zero.
if (rmin > 0.0 || rmax < 0.0) {
return (emitError(loc, "FakeQuant range must straddle zero: [")
<< rmin << "," << rmax << "]",
nullptr);
}
// Special case where min/max is close enough. The tensor contents are all
// 0.0s, so the scale is set to 1.0 and the tensor can be quantized to zero
// points and dequantized to 0.0.
if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
return UniformQuantizedType::getChecked(flags, storageType, expressedType,
1.0, qmin, qmin, qmax, loc);
}
void getScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin, double rmax,
double &scale, int64_t &nudgedZeroPoint) {
// Determine the scale.
const double qminDouble = qmin;
const double qmaxDouble = qmax;
const double scale = (rmax - rmin) / (qmaxDouble - qminDouble);
scale = (rmax - rmin) / (qmaxDouble - qminDouble);
// Zero point computation.
// In float, solve the affine equation for any known pair
@@ -103,7 +80,7 @@ mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin,
: zeroPointFromMax;
// Now nudge the zero point to be an integer.
int64_t nudgedZeroPoint = 0;
nudgedZeroPoint = 0;
if (zeroPointDouble < qminDouble) {
nudgedZeroPoint = qmin;
} else if (zeroPointDouble > qmaxDouble) {
@@ -115,8 +92,97 @@ mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin,
// By construction, the nudged zero point should always be in range.
assert(nudgedZeroPoint >= qmin);
assert(nudgedZeroPoint <= qmax);
}
} // end namespace
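To make the scale and zero-point arithmetic concrete (values chosen purely for illustration): with an unsigned 8-bit range qmin = 0, qmax = 255 and rmin = -1.0, rmax = 1.0, the scale is (1.0 - (-1.0)) / (255 - 0) = 2/255 ≈ 0.00784. Both zero-point candidates, qmin - rmin/scale and qmax - rmax/scale, evaluate to 127.5, and the nudging step rounds this to the integer zero point 128, which lies inside [qmin, qmax] so no clamping is needed.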
UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits,
double rmin, double rmax,
bool narrowRange, Type expressedType,
bool isSigned) {
// Range must straddle zero.
// TODO(b/140641593): remove this constraint.
if (rmin > 0.0 || rmax < 0.0) {
return (emitError(loc, "FakeQuant range must straddle zero: [")
<< rmin << "," << rmax << "]",
nullptr);
}
MLIRContext *ctx = expressedType.getContext();
unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
Type storageType;
int64_t qmin;
int64_t qmax;
if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
qmin, qmax)) {
return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits,
nullptr);
}
// Special case where min/max is close enough. The tensor contents are all
// 0.0s, so the scale is set to 1.0 and the tensor can be quantized to zero
// points and dequantized to 0.0.
if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
return UniformQuantizedType::getChecked(flags, storageType, expressedType,
1.0, qmin, qmin, qmax, loc);
}
double scale;
int64_t nudgedZeroPoint;
getScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
return UniformQuantizedType::getChecked(flags, storageType, expressedType,
scale, nudgedZeroPoint, qmin, qmax,
loc);
}
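As with the per-axis overload above, a hedged sketch of a per-tensor call site (loc and f32Ty are assumed to be in scope; the range values are illustrative only):

mlir::quant::UniformQuantizedType qTy = mlir::quant::fakeQuantAttrsToType(
    loc, /*numBits=*/8, /*rmin=*/-1.0, /*rmax=*/1.0,
    /*narrowRange=*/false, /*expressedType=*/f32Ty, /*isSigned=*/false);
// With these values the result carries qmin = 0, qmax = 255, scale = 2/255,
// and zero point 128, matching the worked example above.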
// TODO(fengliuai): test this method once the quantizeAttr method is fixed.
UniformQuantizedPerAxisType
fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension,
ArrayRef<double> rmins, ArrayRef<double> rmaxs,
bool narrowRange, Type expressedType, bool isSigned) {
size_t axis_size = rmins.size();
if (axis_size != rmaxs.size()) {
return (emitError(loc, "mismatched per-axis min and max size: ")
<< axis_size << " vs. " << rmaxs.size(),
nullptr);
}
MLIRContext *ctx = expressedType.getContext();
Type storageType;
int64_t qmin;
int64_t qmax;
if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
qmin, qmax)) {
return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits,
nullptr);
}
SmallVector<double, 4> scales;
SmallVector<int64_t, 4> zeroPoints;
scales.reserve(axis_size);
zeroPoints.reserve(axis_size);
for (size_t axis = 0; axis != axis_size; ++axis) {
double rmin = rmins[axis];
double rmax = rmaxs[axis];
if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
scales.push_back(1.0);
zeroPoints.push_back(qmin);
continue;
}
double scale;
int64_t nudgedZeroPoint;
getScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
scales.push_back(scale);
zeroPoints.push_back(nudgedZeroPoint);
}
unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
return UniformQuantizedPerAxisType::getChecked(
flags, storageType, expressedType, scales, zeroPoints, qmin, qmax,
quantizedDimension, loc);
}
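As a concrete illustration of the per-axis loop (values invented for the example): with unsigned 8-bit storage (qmin = 0, qmax = 255) and rmins = {-1.0, 0.0}, rmaxs = {1.0, 0.0}, the first channel gets scale 2/255 with zero point 128, while the second channel hits the near-equal min/max special case and is assigned scale 1.0 with zero point qmin = 0.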
} // namespace quant
} // namespace mlir