Support signed and unsigned quantization types

This patch added a new argument to the fakeQuantAttrsToType utility method, so
it can be used to convert min/max to quantized type with different signed
storage types.

PiperOrigin-RevId: 258382538
This commit is contained in:
Feng Liu 2019-07-16 09:32:18 -07:00 committed by Mehdi Amini
parent 0ede23010f
commit a6d2223584
2 changed files with 17 additions and 11 deletions

View File

@ -23,8 +23,8 @@
//
// Specifically, it combines the following concerns, each of which would be
// independent variables in a more generic setup:
// - num_bits implies storage data type (quint8, int16)
// - num_bits < 8 is promoted to quint8
// - numBits and isSigned imply storage data type (uint8, int8, int16)
// - numBits < 8 is promoted to uint8 or int8
// - "narrow_range" narrows the lower bound of the storage type's range by
// 1
// - the specified min/max values are "nudged" so that the result has a zero
@ -59,7 +59,8 @@ namespace quant {
/// originating op.
UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits,
double rmin, double rmax,
bool narrowRange, Type expressedType);
bool narrowRange, Type expressedType,
bool isSigned = false);
} // namespace quant
} // namespace mlir

View File

@ -21,11 +21,10 @@
using namespace mlir;
using namespace mlir::quant;
UniformQuantizedType mlir::quant::fakeQuantAttrsToType(Location loc,
unsigned numBits,
double rmin, double rmax,
bool narrowRange,
Type expressedType) {
UniformQuantizedType
mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin,
double rmax, bool narrowRange,
Type expressedType, bool isSigned) {
MLIRContext *ctx = expressedType.getContext();
Type storageType;
unsigned flags;
@ -35,9 +34,15 @@ UniformQuantizedType mlir::quant::fakeQuantAttrsToType(Location loc,
// Hard-coded type mapping from TFLite.
if (numBits <= 8) {
storageType = IntegerType::get(8, ctx);
flags = 0;
qmin = 0;
qmax = 255;
if (isSigned) {
flags = QuantizationFlags::Signed;
qmin = -128;
qmax = 127;
} else {
flags = 0;
qmin = 0;
qmax = 255;
}
} else if (numBits <= 16) {
storageType = IntegerType::get(16, ctx);
flags = QuantizationFlags::Signed;