Bring naming of some quant ops in alignment with docs and introduce a few necessary additional ops (stats_ref, stats, coupled_ref).

--

PiperOrigin-RevId: 243919195
This commit is contained in:
Stella Laurenzo 2019-04-16 18:36:24 -07:00 committed by Mehdi Amini
parent ee7bdddfb3
commit a2e08eb384
8 changed files with 294 additions and 73 deletions

View File

@ -226,14 +226,15 @@ TODO : Flesh this section out.
### Instrumentation and constraint ops
TODO : These ops are not defined yet
* instrument_stats : Assigns a unique id and signals that statistics should be
collected by the runtime when execution passes through this op.
* constrain_uniform : Constrains that for uniform quantization, the solver
should choose a type with certain characteristics such as the number of
fixed-point values, underlying storage type, or whether to constrain to
power of two scales.
* const_fake_quant : Emulates the logic of the historic TensorFlow
fake_quant_with_min_max_args op.
* stats_ref : Declares that statistics should be gathered at this point with a
unique key and made available to future passes of the solver.
* stats : Declares inline statistics (per layer and per axis) for the point in
the computation. stats_ref ops are generally converted to stats ops once
trial runs have been performed.
* coupled_ref : Declares points in the computation to be coupled from a type
inference perspective based on a unique key.
## Integration with simulated quantization at training time

View File

@ -36,45 +36,47 @@ class quant_Op<string mnemonic, list<OpTrait> traits> :
Op<!strconcat("quant.", mnemonic), traits>;
//===----------------------------------------------------------------------===//
// Quantization barriers
// Quantization casts
//===----------------------------------------------------------------------===//
class quant_BarrierOp<string mnemonic, list<OpTrait> traits> :
quant_Op<mnemonic, traits>, Arguments<(ins quant_RealValueType:$arg)>,
Results<(outs quant_RealValueType)>;
// A QuantizeBarrier (qbarrier) represents a potential type shift from a
// quantizable type to a quantized type.
// A QuantizeCast (qcast) represents a potential type shift from a quantizable
// type to a quantized type.
//
// At runtime, a qbarrier will apply the transformation expressed by its
// At runtime, a qcast will apply the transformation expressed by its
// operand and result type. For flexibility during transformation, it is also
// possible to have a qbarrier that performs no transformation (both its
// possible to have a qcast that performs no transformation (both its
// operand and result type are quantizable).
//
// A qbarrier will typically originate from either:
// A qcast will typically originate from either:
// a) An expressed or implied constraint in the source dialect which signals
// that a certain level of quantization is possible or required.
// b) An inference made by a quantization algorithm indicating that a
// quantized representation may be acceptable.
//
// Especially early in transformation, it is common to have pairs of
// qbarrier/dbarrier at points where a transition to a quantized type is
// required. In addition, it is also common to have an identity qbarrier
// qcast/dcast at points where a transition to a quantized type is
// required. In addition, it is also common to have an identity qcast
// (where the operand and result type are not quantized) at all points where
// it is legal to use a quantized representation (but is not known to be
// acceptable).
def quant_QuantizeBarrierOp : quant_BarrierOp<"qbarrier", [NoSideEffect]>;
def quant_QuantizeCastOp : quant_Op<"qcast", [NoSideEffect]> {
let arguments = (ins quant_RealValueType:$arg);
let results = (outs quant_RealValueType);
}
// A DequantizeBarrier (dbarrier) represents the inverse of a qbarrier,
// A DequantizeCast op (dcast) represents the inverse of a qcast,
// converting back from a quantized to quantizable (expressed) type.
//
// Like qbarriers, a dbarrier is allowed to have both its operand and result
// Like qcasts, a dcast is allowed to have both its operand and result
// as non quantized types. This facilitates transformations and marks edges
// where the computation must be carried out in the expressed type.
//
// Especially early in transformation, it is common to have dbarriers on
// Especially early in transformation, it is common to have dcasts on
// all operands to ops that must operate with the expressed type (typically
// math ops prior to lowering to target-specific, quantized kernels).
def quant_DequantizeBarrierOp : quant_BarrierOp<"dbarrier", [NoSideEffect]>;
def quant_DequantizeCastOp : quant_Op<"dcast", [NoSideEffect]> {
let arguments = (ins quant_RealValueType:$arg);
let results = (outs quant_RealValueType);
}
// A StorageCast (scast) represents a cast from or to a type based on the
// storage type and a type based on a corresponding quantized type.
@ -87,13 +89,13 @@ def quant_DequantizeBarrierOp : quant_BarrierOp<"dbarrier", [NoSideEffect]>;
// i8 -> !quant<"uniform[i8:f32]{1.0}">
// tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
// vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
def quant_StorageCastOp :
quant_Op<"scast", [NoSideEffect]>,
Arguments<(ins quant_RealOrStorageValueType:$arg)>,
Results<(outs quant_RealOrStorageValueType)>;
def quant_StorageCastOp : quant_Op<"scast", [NoSideEffect]> {
let arguments = (ins quant_RealOrStorageValueType:$arg);
let results = (outs quant_RealOrStorageValueType);
}
//===----------------------------------------------------------------------===//
// Training integration ops
// Training integration and instrumentation ops
//===----------------------------------------------------------------------===//
def quant_ConstFakeQuant : quant_Op<"const_fake_quant",
@ -102,11 +104,11 @@ def quant_ConstFakeQuant : quant_Op<"const_fake_quant",
"Simulates the effect of uniform quantization with const range.";
let description = [{
Given a const min, max, num_bits and narrow_range attribute, applies the same
uniform quantization simulation as is done by the TensorFlow
fake_quant_with_min_max_args op. See the fakeQuantAttrsToType() utility
method and the quant-convert-simulated-quantization pass for futher details.
}];
Given a const min, max, num_bits and narrow_range attribute, applies the
same uniform quantization simulation as is done by the TensorFlow
fake_quant_with_min_max_args op. See the fakeQuantAttrsToType() utility
method and the quant-convert-simulated-quantization pass for futher details.
}];
let arguments = (ins
F32Tensor:$inputs,
@ -123,4 +125,96 @@ method and the quant-convert-simulated-quantization pass for futher details.
);
}
def quant_StatisticsRefOp : quant_Op<"stats_ref", []> {
let summary =
"Indicates that statistics are resolved by reference.";
let description = [{
This op acts as an identity that, when encountered at runtime, should result
in statistics being collected about about the value of its operand/result.
Such statistics will be stored with the provided key, allowing this node
to later be converted to a 'stats' op if statistics with that key have been
encountered.
}];
let arguments = (ins
quant_RealValueType:$arg,
StrAttr:$statsKey
);
let results = (outs quant_RealValueType);
}
def quant_StatisticsOp : quant_Op<"stats", []> {
let summary =
"Identity op which associates statistics with the value.";
let description = [{
Associates statistics about the runtime ranges of values observed for
evaluations of this node.
Statistics about the entire type are reported in the 'layerStats' attribute
and those for each axis, in the (optional) `axisStats` attribute. The
interpretation of each is determined by the last dimension of its shape.
Currently, only dim=2 is supported, which is interpreted as [min, max].
`layerStats` must be a rank 1 tensor: [2]
`axisStats` must be a rank 2 tensor: [N, 2], where N=the rank of `arg`.
}];
let arguments = (ins
quant_RealValueType:$arg,
ElementsAttr:$layerStats,
OptionalAttr<ElementsAttr>:$axisStats);
let results = (outs quant_RealValueType);
let verifier = [{
auto tensorArg = arg()->getType().dyn_cast<TensorType>();
auto argRank = tensorArg ? tensorArg.getRank() : 0;
// Verify layerStats attribute.
{
auto layerStatsType = layerStats().getType();
if (!layerStatsType.getElementType().isa<FloatType>()) {
return emitOpError(
"layerStats must have a floating point element type");
}
if (layerStatsType.getRank() != 1 || layerStatsType.getDimSize(0) != 2) {
return emitOpError("layerStats must have shape [2]");
}
}
// Verify axisStats (optional) attribute.
if (axisStats()) {
auto axisStatsType = axisStats()->getType();
if (!axisStatsType.getElementType().isa<FloatType>()) {
return emitOpError("axisStats must have a floating point element type");
}
if (axisStatsType.getRank() != 2 ||
axisStatsType.getDimSize(1) != 2 ||
axisStatsType.getDimSize(0) != argRank) {
return emitOpError("axisStats must have shape [N,2] "
"where N = the argument rank");
}
}
return success();
}];
}
def quant_CoupledRefOp : quant_Op<"coupled_ref", []> {
let summary =
"Indicates that one point of the computation is coupled to another.";
let description = [{
Ordinarily, relationships between ops for the purposes of determining
compatible quantized types is explicit based on the use-def chain. However,
in some situations, a use may be separated from its def by arbitrary
external connections. In such a case, during analysis, all coupled_ref
nodes in a module which share a coupledKey will be considered to be
directly connected as via an identity op for the purpose of type inference.
}];
let arguments = (ins
quant_RealValueType:$arg,
StrAttr:$coupledKey);
let results = (outs quant_RealValueType);
}
#endif // QUANT_OPS

View File

@ -44,7 +44,7 @@ public:
};
QuantizedConstRewrite(MLIRContext *context)
: RewritePattern(QuantizeBarrierOp::getOperationName(), 1, context) {}
: RewritePattern(QuantizeCastOp::getOperationName(), 1, context) {}
PatternMatchResult match(Operation *op) const override;
void rewrite(Operation *op, std::unique_ptr<PatternState> baseState,
@ -59,7 +59,7 @@ PatternMatchResult QuantizedConstRewrite::match(Operation *op) const {
State state;
// Is the operand a constant?
auto qbarrier = op->cast<QuantizeBarrierOp>();
auto qbarrier = op->cast<QuantizeCastOp>();
if (!matchPattern(qbarrier.arg(), m_Constant(&state.value))) {
return matchFailure();
}

View File

@ -82,10 +82,10 @@ public:
// TODO: Map to a qbarrier with an attribute like [Forced] to signal that
// this is a forced/hard-coded constraint.
auto qbarrier = rewriter.create<QuantizeBarrierOp>(
op->getLoc(), quantizedType, fqOp.inputs());
rewriter.replaceOpWithNewOp<DequantizeBarrierOp>(op, converter.inputType,
qbarrier.getResult());
auto qbarrier = rewriter.create<QuantizeCastOp>(op->getLoc(), quantizedType,
fqOp.inputs());
rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType,
qbarrier.getResult());
return false;
}

View File

@ -14,8 +14,8 @@ func @constant_splat_tensor_u8_affine() -> tensor<4xf32> {
// CHECK: %cst = constant splat<tensor<4xi8>, -64> : tensor<4xi8>
// CHECK-NEXT: %0 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">>
%cst = constant splat<tensor<4xf32>, 0.5> : tensor<4xf32>
%1 = "quant.qbarrier"(%cst) : (tensor<4xf32>) -> tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">>
%2 = "quant.dbarrier"(%1) : (tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">>) -> (tensor<4xf32>)
%1 = "quant.qcast"(%cst) : (tensor<4xf32>) -> tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">>
%2 = "quant.dcast"(%1) : (tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">>) -> (tensor<4xf32>)
return %2 : tensor<4xf32>
}
@ -26,8 +26,8 @@ func @constant_splat_tensor_i8_affine() -> tensor<4xf32> {
// CHECK: %cst = constant splat<tensor<4xi8>, 63> : tensor<4xi8>
// CHECK-NEXT: %0 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03:-1}">>
%cst = constant splat<tensor<4xf32>, 0.5> : tensor<4xf32>
%1 = "quant.qbarrier"(%cst) : (tensor<4xf32>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03:-1}">>
%2 = "quant.dbarrier"(%1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03:-1}">>) -> (tensor<4xf32>)
%1 = "quant.qcast"(%cst) : (tensor<4xf32>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03:-1}">>
%2 = "quant.dcast"(%1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03:-1}">>) -> (tensor<4xf32>)
return %2 : tensor<4xf32>
}
@ -38,8 +38,8 @@ func @const_splat_tensor_i8_fixedpoint() -> tensor<4xf32> {
// CHECK: %cst = constant splat<tensor<4xi8>, 64> : tensor<4xi8>
// CHECK-NEXT: %0 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>
%cst = constant splat<tensor<4xf32>, 0.5> : tensor<4xf32>
%1 = "quant.qbarrier"(%cst) : (tensor<4xf32>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>
%2 = "quant.dbarrier"(%1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> (tensor<4xf32>)
%1 = "quant.qcast"(%cst) : (tensor<4xf32>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>
%2 = "quant.dcast"(%1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> (tensor<4xf32>)
return %2 : tensor<4xf32>
}
@ -49,8 +49,8 @@ func @const_splat_tensor_i8_fixedpoint() -> tensor<4xf32> {
func @const_splat_tensor_i8_fixedpoint_neg() -> tensor<4xf32> {
// CHECK: %cst = constant splat<tensor<4xi8>, -64> : tensor<4xi8>
%cst = constant splat<tensor<4xf32>, -0.5> : tensor<4xf32>
%1 = "quant.qbarrier"(%cst) : (tensor<4xf32>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>
%2 = "quant.dbarrier"(%1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> (tensor<4xf32>)
%1 = "quant.qcast"(%cst) : (tensor<4xf32>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>
%2 = "quant.dcast"(%1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> (tensor<4xf32>)
return %2 : tensor<4xf32>
}
@ -60,8 +60,8 @@ func @const_splat_tensor_i8_fixedpoint_neg() -> tensor<4xf32> {
func @const_dense_tensor_i8_fixedpoint() -> tensor<7xf32> {
// CHECK: %cst = constant dense<tensor<7xi8>, [-128, -128, -64, 0, 64, 127, 127]> : tensor<7xi8>
%cst = constant dense<tensor<7xf32>, [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32>
%1 = "quant.qbarrier"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[i8:f32]{7.812500e-03}">>
%2 = "quant.dbarrier"(%1) : (tensor<7x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> (tensor<7xf32>)
%1 = "quant.qcast"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[i8:f32]{7.812500e-03}">>
%2 = "quant.dcast"(%1) : (tensor<7x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> (tensor<7xf32>)
return %2 : tensor<7xf32>
}
@ -74,8 +74,8 @@ func @const_sparse_tensor_i8_fixedpoint() -> tensor<7x2xf32> {
%cst = constant sparse<tensor<7x2xf32>,
[[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6]],
[-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7x2xf32>
%1 = "quant.qbarrier"(%cst) : (tensor<7x2xf32>) -> tensor<7x2x!quant<"uniform[i8:f32]{7.812500e-03}">>
%2 = "quant.dbarrier"(%1) : (tensor<7x2x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> (tensor<7x2xf32>)
%1 = "quant.qcast"(%cst) : (tensor<7x2xf32>) -> tensor<7x2x!quant<"uniform[i8:f32]{7.812500e-03}">>
%2 = "quant.dcast"(%1) : (tensor<7x2x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> (tensor<7x2xf32>)
return %2 : tensor<7x2xf32>
}
@ -86,8 +86,8 @@ func @const_primitive_float_i8_fixedpoint() -> f32 {
// CHECK: %c64_i8 = constant 64 : i8
// CHECK-NEXT: %0 = "quant.scast"(%c64_i8) : (i8) -> !quant<"uniform[i8:f32]{7.812500e-03}">
%cst = constant 0.5 : f32
%1 = "quant.qbarrier"(%cst) : (f32) -> !quant<"uniform[i8:f32]{7.812500e-03}">
%2 = "quant.dbarrier"(%1) : (!quant<"uniform[i8:f32]{7.812500e-03}">) -> (f32)
%1 = "quant.qcast"(%cst) : (f32) -> !quant<"uniform[i8:f32]{7.812500e-03}">
%2 = "quant.dcast"(%1) : (!quant<"uniform[i8:f32]{7.812500e-03}">) -> (f32)
return %2 : f32
}
@ -98,8 +98,8 @@ func @const_dense_tensor_u4_affine() -> tensor<7xf32> {
// NOTE: Unsigned quantities printed by MLIR as signed.
// CHECK: %cst = constant dense<tensor<7xi4>, [0, 0, 4, -8, -4, -1, -1]> : tensor<7xi4>
%cst = constant dense<tensor<7xf32>, [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32>
%1 = "quant.qbarrier"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[u4:f32]{1.250000e-01:8}">>
%2 = "quant.dbarrier"(%1) : (tensor<7x!quant<"uniform[u4:f32]{1.250000e-01:8}">>) -> (tensor<7xf32>)
%1 = "quant.qcast"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[u4:f32]{1.250000e-01:8}">>
%2 = "quant.dcast"(%1) : (tensor<7x!quant<"uniform[u4:f32]{1.250000e-01:8}">>) -> (tensor<7xf32>)
return %2 : tensor<7xf32>
}
@ -110,8 +110,8 @@ func @const_dense_tensor_i4_affine() -> tensor<7xf32> {
// NOTE: Unsigned quantities printed by MLIR as signed.
// CHECK: %cst = constant dense<tensor<7xi4>, [-8, -8, -5, -1, 3, 7, 7]> : tensor<7xi4>
%cst = constant dense<tensor<7xf32>, [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32>
%1 = "quant.qbarrier"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[i4:f32]{1.250000e-01:-1}">>
%2 = "quant.dbarrier"(%1) : (tensor<7x!quant<"uniform[i4:f32]{1.250000e-01:-1}">>) -> (tensor<7xf32>)
%1 = "quant.qcast"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[i4:f32]{1.250000e-01:-1}">>
%2 = "quant.dcast"(%1) : (tensor<7x!quant<"uniform[i4:f32]{1.250000e-01:-1}">>) -> (tensor<7xf32>)
return %2 : tensor<7xf32>
}
@ -121,8 +121,8 @@ func @const_dense_tensor_i4_affine() -> tensor<7xf32> {
func @const_dense_tensor_i4_fixedpoint() -> tensor<7xf32> {
// CHECK: %cst = constant dense<tensor<7xi4>, [-8, -8, -4, 0, 4, 7, 7]> : tensor<7xi4>
%cst = constant dense<tensor<7xf32>, [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32>
%1 = "quant.qbarrier"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[i4:f32]{1.250000e-01}">>
%2 = "quant.dbarrier"(%1) : (tensor<7x!quant<"uniform[i4:f32]{1.250000e-01}">>) -> (tensor<7xf32>)
%1 = "quant.qcast"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[i4:f32]{1.250000e-01}">>
%2 = "quant.dcast"(%1) : (tensor<7x!quant<"uniform[i4:f32]{1.250000e-01}">>) -> (tensor<7xf32>)
return %2 : tensor<7xf32>
}
@ -134,7 +134,7 @@ func @const_dense_tensor_i4_fixedpoint() -> tensor<7xf32> {
func @const_custom_storage_range_i8_fixedpoint() -> tensor<7xf32> {
// CHECK: %cst = constant dense<tensor<7xi8>, [-100, -100, -64, 0, 64, 100, 100]> : tensor<7xi8>
%cst = constant dense<tensor<7xf32>, [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32>
%1 = "quant.qbarrier"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[i8(-100:100):f32]{7.812500e-03}">>
%2 = "quant.dbarrier"(%1) : (tensor<7x!quant<"uniform[i8(-100:100):f32]{7.812500e-03}">>) -> (tensor<7xf32>)
%1 = "quant.qcast"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[i8(-100:100):f32]{7.812500e-03}">>
%2 = "quant.dcast"(%1) : (tensor<7x!quant<"uniform[i8(-100:100):f32]{7.812500e-03}">>) -> (tensor<7xf32>)
return %2 : tensor<7xf32>
}

View File

@ -5,9 +5,9 @@
// CHECK-LABEL: fakeQuantArgs_Quint8_0_1
func @fakeQuantArgs_Quint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// CHECK: %0 = "quant.qbarrier"(%arg0) : (tensor<8x4x3xf32>)
// CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
// CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[u8:f32]{0.0039215686274509803}">>
// CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[u8:f32]{0.0039215686274509803}">>)
// CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant<"uniform[u8:f32]{0.0039215686274509803}">>)
// CHECK-SAME: -> tensor<8x4x3xf32>
%0 = "quant.const_fake_quant"(%arg0) {
min: 0.0 : f32, max: 1.0 : f32, num_bits: 8
@ -20,9 +20,9 @@ func @fakeQuantArgs_Quint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
// CHECK_LABEL: fakeQuantArgs_Quint8_NarrowRange
func @fakeQuantArgs_Quint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// CHECK: %0 = "quant.qbarrier"(%arg0) : (tensor<8x4x3xf32>)
// CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
// CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[u8(1:255):f32]{0.003937007874015748:1}">>
// CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[u8(1:255):f32]{0.003937007874015748:1}">>)
// CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant<"uniform[u8(1:255):f32]{0.003937007874015748:1}">>)
// CHECK-SAME: -> tensor<8x4x3xf32>
%0 = "quant.const_fake_quant"(%arg0) {
min: 0.0 : f32, max: 1.0 : f32, num_bits: 8, narrow_range: true
@ -35,9 +35,9 @@ func @fakeQuantArgs_Quint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
// CHECK_LABEL: fakeQuantArgs_Quint8_SymmetricRange
func @fakeQuantArgs_Quint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// CHECK: %0 = "quant.qbarrier"(%arg0) : (tensor<8x4x3xf32>)
// CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
// CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[u8:f32]{7.812500e-03:128}">>
// CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[u8:f32]{7.812500e-03:128}">>)
// CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant<"uniform[u8:f32]{7.812500e-03:128}">>)
// CHECK-SAME: -> tensor<8x4x3xf32>
%0 = "quant.const_fake_quant"(%arg0) {
min: -1.0 : f32, max: 0.9921875 : f32, num_bits: 8, narrow_range: false
@ -51,9 +51,9 @@ func @fakeQuantArgs_Quint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32
// CHECK-LABEL: fakeQuantArgs_Qint16_Symmetric
func @fakeQuantArgs_Qint16_Symmetric(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// CHECK: %0 = "quant.qbarrier"(%arg0) : (tensor<8x4x3xf32>)
// CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
// CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[i16:f32]{3.0517578125E-5}">>
// CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[i16:f32]{3.0517578125E-5}">>)
// CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant<"uniform[i16:f32]{3.0517578125E-5}">>)
// CHECK-SAME: -> tensor<8x4x3xf32>
%0 = "quant.const_fake_quant"(%arg0) {
min: -1.0 : f32, max: 0.999969482 : f32, num_bits: 16
@ -66,9 +66,9 @@ func @fakeQuantArgs_Qint16_Symmetric(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
// CHECK-LABEL: fakeQuantArgs_UnrankedTensor
func @fakeQuantArgs_UnrankedTensor(tensor<f32>) -> tensor<f32> {
^bb0(%arg0: tensor<f32>):
// CHECK: %0 = "quant.qbarrier"(%arg0) : (tensor<f32>)
// CHECK: %0 = "quant.qcast"(%arg0) : (tensor<f32>)
// CHECK-SAME: -> tensor<!quant<"uniform[u8:f32]{0.0039215686274509803}">>
// CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<!quant<"uniform[u8:f32]{0.0039215686274509803}">>)
// CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<!quant<"uniform[u8:f32]{0.0039215686274509803}">>)
// CHECK-SAME: -> tensor<f32>
%0 = "quant.const_fake_quant"(%arg0) {
min: 0.0 : f32, max: 1.0 : f32, num_bits: 8

View File

@ -0,0 +1,77 @@
// RUN: mlir-opt %s -split-input-file -verify
// -----
func @invalidStatisticsMismatchedLayerType(%arg0: tensor<8x4x3xf32>) ->
tensor<8x4x3xf32> {
// expected-error@+1 {{layerStats must have a floating point element type}}
%0 = "quant.stats"(%arg0) {
layerStats: dense<tensor<2xi8>, [-1, 1]>
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %0 : tensor<8x4x3xf32>
}
// -----
func @invalidStatisticsMismatchedLayerRank(%arg0: tensor<8x4x3xf32>) ->
tensor<8x4x3xf32> {
// expected-error@+1 {{layerStats must have shape [2]}}
%0 = "quant.stats"(%arg0) {
layerStats: dense<tensor<1x2xf32>, [[-1.0, 1.0]]>
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %0 : tensor<8x4x3xf32>
}
// -----
func @invalidStatisticsMismatchedLayerShape(%arg0: tensor<8x4x3xf32>) ->
tensor<8x4x3xf32> {
// expected-error@+1 {{layerStats must have shape [2]}}
%0 = "quant.stats"(%arg0) {
layerStats: dense<tensor<3xf32>, [-1.0, 1.0, 2.0]>
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %0 : tensor<8x4x3xf32>
}
// -----
// CHECK-LABEL: validStatistics
func @invalidStatisticsMismatchedAxisType(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
// expected-error@+1 {{axisStats must have a floating point element type}}
%0 = "quant.stats"(%0) {
layerStats: dense<tensor<2xf32>, [-1.0, 1.0]>,
axisStats: dense<tensor<3x2xi8>, [
[-1, 1],
[-8, 8],
[-1, 0]
]>
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %0 : tensor<8x4x3xf32>
}
// -----
func @invalidStatisticsMismatchedAxisRank(%arg0: tensor<8x4x3xf32>) ->
tensor<8x4x3xf32> {
// expected-error@+1 {{axisStats must have shape [N,2] where N = the argument rank}}
%0 = "quant.stats"(%arg0) {
layerStats: dense<tensor<2xf32>, [-1.0, 1.0]>,
axisStats: dense<tensor<4x2xf32>, [
[-1.0, 1.0],
[-8.0, 8.0],
[-0.5, 0.5],
[-2.0, 3.5]
]>
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %0 : tensor<8x4x3xf32>
}
// -----
func @invalidStatisticsMismatchedAxisShape(%arg0: tensor<8x4x3xf32>) ->
tensor<8x4x3xf32> {
// expected-error@+1 {{axisStats must have shape [N,2] where N = the argument rank}}
%0 = "quant.stats"(%arg0) {
layerStats: dense<tensor<2xf32>, [-1.0, 1.0]>,
axisStats: dense<tensor<3x3xf32>, [
[-1.0, 1.0, 1.0],
[-8.0, 8.0, 1.0],
[-0.5, 0.5, 1.0]
]>
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %0 : tensor<8x4x3xf32>
}

View File

@ -0,0 +1,49 @@
// RUN: mlir-opt %s -split-input-file | FileCheck %s
// -----
// CHECK-LABEL: validConstFakeQuant
func @validConstFakeQuant(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
%0 = "quant.const_fake_quant"(%arg0) {
min: 0.0 : f32, max: 1.0 : f32, num_bits: 8, narrow_range: true
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
%1 = "quant.const_fake_quant"(%0) {
min: 0.0 : f32, max: 1.0 : f32, num_bits: 8, narrow_range: false
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
%2 = "quant.const_fake_quant"(%1) {
min: 0.0 : f32, max: 1.0 : f32, num_bits: 8
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %2 : tensor<8x4x3xf32>
}
// -----
// CHECK-LABEL: validStatisticsRef
func @validStatisticsRef(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
%0 = "quant.stats_ref"(%arg0) { statsKey: "foobar" } :
(tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %0 : tensor<8x4x3xf32>
}
// -----
// CHECK-LABEL: validStatistics
func @validStatistics(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
%0 = "quant.stats"(%arg0) {
layerStats: dense<tensor<2xf32>, [-1.0, 1.0]>
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
%1 = "quant.stats"(%0) {
layerStats: dense<tensor<2xf32>, [-1.0, 1.0]>,
axisStats: dense<tensor<3x2xf32>, [
[-1.0, 1.0],
[-8.0, 8.0],
[-0.5, 0.5]
]>
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %1 : tensor<8x4x3xf32>
}
// -----
// CHECK-LABEL: validCoupledRef
func @validCoupledRef(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
%0 = "quant.coupled_ref"(%arg0) { coupledKey: "foobar" } :
(tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %0 : tensor<8x4x3xf32>
}