forked from OSchip/llvm-project
Split the Quantization dialect.
- Retains Quantization types and predicates. - Retains utilities and example (testable) passes/ops. - Retains unit tests for example passes/ops. - Moves fixed point ops (and corresponding real ops) to FxpMathOps. - Moves real -> fixed point pass to FxpMathOps. - Sever the dependency on the TF dialect from Quantization. These dialects should now be open-sourcable. -- PiperOrigin-RevId: 241825598
This commit is contained in:
parent
1b56ce3087
commit
288bf2b5b9
|
@ -0,0 +1,39 @@
|
|||
//===- FxpMathOps/FxpMathOps.h - Fixed point ops ----------------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#ifndef MLIR_FXPMATH_QUANTOPS_H_
|
||||
#define MLIR_FXPMATH_QUANTOPS_H_
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace fxpmath {
|
||||
|
||||
/// Defines the 'FxpMathOps' dialect.
|
||||
class FxpMathOpsDialect : public Dialect {
|
||||
public:
|
||||
FxpMathOpsDialect(MLIRContext *context);
|
||||
};
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/FxpMathOps/FxpMathOps.h.inc"
|
||||
|
||||
} // namespace fxpmath
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_FXPMATH_QUANTOPS_H_
|
|
@ -0,0 +1,183 @@
|
|||
//===- FxpMathOps.td - Fixed point ops --------------------*- tablegen -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This is the operation definition file for fixed point ops (and real
|
||||
// equivalents).
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifdef FXPMATH_OPS
|
||||
#else
|
||||
|
||||
#ifdef OP_BASE
|
||||
#else
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Quantization/QuantPredicates.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Attributes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Real value for an (inclusive) min/max clamp limit.
|
||||
def fxpmath_ClampValueAttr : OptionalAttr<F64Attr>;
|
||||
|
||||
// Element-wise activation function to apply.
|
||||
// Note that RELU activations are not here: they are expressed as clamps.
|
||||
def fxpmath_EwUnaryFnAttr :
|
||||
StringBasedAttr<CPred<"true">, "element-wise unary function"> {
|
||||
let returnType = [{ StringRef }];
|
||||
let defaultValue = "IDENTITY";
|
||||
}
|
||||
|
||||
class fxpmath_ConstEwUnaryFn<string val> : ConstantAttr<fxpmath_EwUnaryFnAttr, val>;
|
||||
def fxpmath_EwUnaryFn_Identity: fxpmath_ConstEwUnaryFn<"IDENTITY">;
|
||||
def fxpmath_EwUnaryFn_Tanh : fxpmath_ConstEwUnaryFn<"TANH">;
|
||||
def fxpmath_EwUnaryFn_Sigmoid : fxpmath_ConstEwUnaryFn<"SIGMOID">;
|
||||
def fxpmath_EwUnaryFn_Exp : fxpmath_ConstEwUnaryFn<"EXP">;
|
||||
def fxpmath_EwUnaryFn_Log : fxpmath_ConstEwUnaryFn<"LOG">;
|
||||
def fxpmath_EwUnaryFn_Neg : fxpmath_ConstEwUnaryFn<"NEG">;
|
||||
def fxpmath_EwUnaryFn_Rsqrt : fxpmath_ConstEwUnaryFn<"RSQRT">;
|
||||
def fxpmath_EwUnaryFn_Sin : fxpmath_ConstEwUnaryFn<"SIN">;
|
||||
def fxpmath_EwUnaryFn_Square : fxpmath_ConstEwUnaryFn<"SQUARE">;
|
||||
def fxpmath_EwUnaryFn_Sqrt : fxpmath_ConstEwUnaryFn<"SQRT">;
|
||||
def fxpmath_EwUnaryFn_CmpZ : fxpmath_ConstEwUnaryFn<"CMPZ">;
|
||||
def fxpmath_EwUnaryFn_CmpNZ : fxpmath_ConstEwUnaryFn<"CMPNZ">;
|
||||
def fxpmath_EwUnaryFn_CmpLZ : fxpmath_ConstEwUnaryFn<"CMPLZ">;
|
||||
def fxpmath_EwUnaryFn_CmpGZ : fxpmath_ConstEwUnaryFn<"CMPGZ">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Base classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class fxpmath_Op<string mnemonic, list<OpTrait> traits> :
|
||||
Op<!strconcat("fxpmath.", mnemonic), traits>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Fixed-point (fxp) arithmetic ops used by kernels.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def fxpmath_RoundingDivideByPotFxpOp :
|
||||
fxpmath_Op<"rounding_divide_by_poti", [NoSideEffect, SameValueType]>,
|
||||
Arguments<(ins quant_StorageValueType:$x, I32Attr:$exponent)>,
|
||||
Results<(outs quant_StorageValueType:$y)> {
|
||||
let description = [{
|
||||
Computes integer division by a power-of-two, correctly rounded-to-nearest.
|
||||
Also known as a rounding arithmetic right shift. See
|
||||
gemmlowp::RoundingDivideByPOT for a reference implementation.
|
||||
}];
|
||||
|
||||
let verifier = [{
|
||||
auto verifyExponent = exponent().getSExtValue();
|
||||
if (verifyExponent < 0 || verifyExponent > 31) {
|
||||
return emitOpError("exponent must be in range [0..31]");
|
||||
}
|
||||
return success();
|
||||
}];
|
||||
}
|
||||
|
||||
def fxpmath_SaturatingAddFxpOp :
|
||||
fxpmath_Op<"saturating_addi", [NoSideEffect, SameValueType]>,
|
||||
Arguments<(ins quant_StorageValueType:$x,
|
||||
quant_StorageValueType:$y,
|
||||
I32Attr:$clamp_min,
|
||||
I32Attr:$clamp_max)>,
|
||||
Results<(outs quant_StorageValueType:$sum)> {
|
||||
let description = [{
|
||||
Computes saturating addition of two operands, saturating to the given min
|
||||
and max value. The implementation is responsible for choosing an
|
||||
intermediate register size appropriate to carry out the operation without
|
||||
overflow. See gemmlowp::SaturatingAdd for a reference implementation.
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Real math ops.
|
||||
//
|
||||
// Math ops on real numbers which may have a representation in quantized
|
||||
// arithmetic. It is expected that eligible ops are lowered from a source
|
||||
// dialect to this set of ops prior to the process of converting a compuation
|
||||
// to a quantized form. It is a non-goal of these ops to preserve enough
|
||||
// information to convert back to the higher level, source dialect.
|
||||
//
|
||||
// These ops support either real/floating point or QuantizedTypes as operands
|
||||
// and results. Since not all transformations are supported (globally or
|
||||
// sometimes for specific targets), a computation may end up with
|
||||
// untransformable RealMathOps, in which case they need to be lowered as is
|
||||
// (using floating point math).
|
||||
//
|
||||
// This op set takes advantage of the fact that it is typically trivial to
|
||||
// combine a math function with a compatible bias addition and real-valued
|
||||
// clamp (which can be done at a higher accumulation bit depth).
|
||||
//
|
||||
// In addition, all element-wise unary functions are collapsed into a single
|
||||
// fxpmath_RealUnaryEwOp and selected via an enum-like attribute. Especially at
|
||||
// low bit depths, this makes matching simpler and allows the construction of
|
||||
// generic LUT-based implementations. It also allows specific lowering rules
|
||||
// to consolidate runs of chained unary ops and fuse them to preceding math
|
||||
// ops, potentially allowing them to operate directly on higher precision
|
||||
// intermediates without resorting to lots of custom kernels for common
|
||||
// formulas that can suffer from insufficient precision at low bit depths.
|
||||
//
|
||||
// Comparison operators are modeled as element-wise unary functions (i.e.
|
||||
// CMPZ, CMPNZ, CMPLZ, CMPGZ) intended to follow a sub and output a 1bit
|
||||
// quantized value. It is expected that lowering rules can fuse them with
|
||||
// the preceding sub.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class fxpmath_RealMathOp<string mnemonic, list<OpTrait> traits = [], dag args> :
|
||||
fxpmath_Op<mnemonic, traits>,
|
||||
Arguments<!con(args, (ins
|
||||
fxpmath_ClampValueAttr:$clamp_min, fxpmath_ClampValueAttr:$clamp_max))>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Element wise binary real math ops.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class fxpmath_RealBinaryOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
fxpmath_RealMathOp<mnemonic, traits,
|
||||
(ins quant_RealValueType:$x, quant_RealValueType:$y)>,
|
||||
Results<(outs quant_RealValueType:$r)>;
|
||||
|
||||
class fxpmath_RealBinaryBiasOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
fxpmath_RealMathOp<mnemonic, traits,
|
||||
(ins quant_RealValueType:$x, quant_RealValueType:$y,
|
||||
quant_RealValueType:$bias)>,
|
||||
Results<(outs quant_RealValueType:$r)>;
|
||||
|
||||
def fxpmath_RealAddEwOp :
|
||||
fxpmath_RealBinaryOp<"real_add_ew", [NoSideEffect]>;
|
||||
|
||||
def fxpmath_RealSubEwOp :
|
||||
fxpmath_RealBinaryOp<"real_sub_ew", [NoSideEffect]>;
|
||||
|
||||
def fxpmath_RealMulEwOp :
|
||||
fxpmath_RealBinaryOp<"real_mul_ew", [NoSideEffect]>;
|
||||
|
||||
def fxpmath_RealDivEwOp :
|
||||
fxpmath_RealBinaryOp<"real_div_ew", [NoSideEffect]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Element wise unary real math op.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def fxpmath_RealUnaryEwOp :
|
||||
fxpmath_RealMathOp<"real_unary_ew", [NoSideEffect],
|
||||
(ins quant_RealValueType:$x, fxpmath_EwUnaryFnAttr:$fn)>,
|
||||
Results<(outs quant_RealValueType:$r)>;
|
||||
|
||||
#endif // FXPMATH_OPS
|
|
@ -0,0 +1,39 @@
|
|||
//===- Passes.h - Fixed point math passes -----------------------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file defines all of the passes owned by the FxpMathOps dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_FXPMATH_PASSES_H
|
||||
#define MLIR_FXPMATH_PASSES_H
|
||||
|
||||
namespace mlir {
|
||||
class FunctionPassBase;
|
||||
|
||||
namespace fxpmath {
|
||||
|
||||
/// Creates a pass that lowers uniform-quantized real math ops to integer
|
||||
/// arithmetic. This will leave unrecognized real math ops as-is and is
|
||||
/// typically followed by a pass that lowers any unrecognized ops to a pure
|
||||
/// floating point form.
|
||||
FunctionPassBase *createLowerUniformRealMathPass();
|
||||
|
||||
} // namespace fxpmath
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_FXPMATH_PASSES_H
|
|
@ -30,15 +30,9 @@ class FunctionPassBase;
|
|||
|
||||
namespace quant {
|
||||
|
||||
/// Creates a pass that lowers quantization related TensorFlow ops into
|
||||
/// the quantization dialect so that express and implied constraints expressed
|
||||
/// at the TensorFlow source level can be represented to the quantization
|
||||
/// system. This will specially handle any TensorFlow op that is useful for
|
||||
/// guiding quantization.
|
||||
///
|
||||
/// Note that if your intent is to compile a TensorFlow graph for floating
|
||||
/// point inference, you should probably not use this pass.
|
||||
FunctionPassBase *createLowerTFPass();
|
||||
/// Creates a pass that converts quantization simulation operations (i.e.
|
||||
/// FakeQuant and those like it) to casts into/out of supported QuantizedTypes.
|
||||
FunctionPassBase *createConvertSimulatedQuantPass();
|
||||
|
||||
/// Creates a pass that converts constants followed by a qbarrier to a
|
||||
/// constant whose value is quantized. This is typically one of the last
|
||||
|
@ -47,12 +41,6 @@ FunctionPassBase *createLowerTFPass();
|
|||
/// destructive and cannot be undone.
|
||||
FunctionPassBase *createConvertConstPass();
|
||||
|
||||
/// Creates a pass that lowers uniform-quantized real math ops to integer
|
||||
/// arithmetic. This will leave unrecognized real math ops as-is and is
|
||||
/// typically followed by a pass that lowers any unrecognized ops to a pure
|
||||
/// floating point form.
|
||||
FunctionPassBase *createLowerUniformRealMathPass();
|
||||
|
||||
} // namespace quant
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -25,86 +25,9 @@
|
|||
#ifdef OP_BASE
|
||||
#else
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Quantization/QuantPredicates.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Quantization type definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class quant_TypedPrimitiveOrContainer<Type etype> :
|
||||
Type<AnyOf<[etype.predicate,
|
||||
TypedTensor<etype>.predicate,
|
||||
TypedVector<etype>.predicate]>,
|
||||
"primitive/tensor/vector of " # etype.description>;
|
||||
|
||||
// An implementation of QuantizedType.
|
||||
def quant_QuantizedType :
|
||||
Type<CPred<"{0}.isa<QuantizedType>()">, "QuantizedType">;
|
||||
|
||||
// A primitive type that can represent a real value. This is either a
|
||||
// floating point value or a quantized type.
|
||||
def quant_RealPrimitiveType :
|
||||
Type<AnyOf<[Float.predicate, quant_QuantizedType.predicate]>,
|
||||
"real valued primitive (float or quantized type)">;
|
||||
|
||||
// A primitive type that can represent a storage value. This is either an
|
||||
// integer or quantized type.
|
||||
def quant_StoragePrimitiveType :
|
||||
Type<AnyOf<[Integer.predicate, quant_QuantizedType.predicate]>,
|
||||
"quantized storage primitive (integer or quantized type)">;
|
||||
|
||||
// A primitive or container of RealPrimitiveType.
|
||||
def quant_RealValueType :
|
||||
quant_TypedPrimitiveOrContainer<quant_RealPrimitiveType>;
|
||||
|
||||
// A primitive or container of StoragePrimitiveType.
|
||||
def quant_StorageValueType :
|
||||
quant_TypedPrimitiveOrContainer<quant_StoragePrimitiveType>;
|
||||
|
||||
// Either a real valued or storage primitive or container type.
|
||||
def quant_RealOrStorageValueType :
|
||||
Type<AnyOf<[quant_RealValueType.predicate,
|
||||
quant_StorageValueType.predicate]>>;
|
||||
|
||||
// An implementation of UniformQuantizedType.
|
||||
def quant_UniformQuantizedType :
|
||||
Type<CPred<"{0}.isa<UniformQuantizedType>()">, "UniformQuantizedType">;
|
||||
|
||||
// Predicate for detecting a container or primitive of UniformQuantizedType.
|
||||
def quant_UniformQuantizedValueType :
|
||||
quant_TypedPrimitiveOrContainer<quant_UniformQuantizedType>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Attributes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Real value for an (inclusive) min/max clamp limit.
|
||||
def quant_ClampValueAttr : OptionalAttr<F64Attr>;
|
||||
|
||||
// Element-wise activation function to apply.
|
||||
// Note that RELU activations are not here: they are expressed as clamps.
|
||||
def quant_EwUnaryFnAttr :
|
||||
StringBasedAttr<CPred<"true">, "element-wise unary function"> {
|
||||
let returnType = [{ StringRef }];
|
||||
let defaultValue = "IDENTITY";
|
||||
}
|
||||
|
||||
class quant_ConstEwUnaryFn<string val> : ConstantAttr<quant_EwUnaryFnAttr, val>;
|
||||
def quant_EwUnaryFn_Identity: quant_ConstEwUnaryFn<"IDENTITY">;
|
||||
def quant_EwUnaryFn_Tanh : quant_ConstEwUnaryFn<"TANH">;
|
||||
def quant_EwUnaryFn_Sigmoid : quant_ConstEwUnaryFn<"SIGMOID">;
|
||||
def quant_EwUnaryFn_Exp : quant_ConstEwUnaryFn<"EXP">;
|
||||
def quant_EwUnaryFn_Log : quant_ConstEwUnaryFn<"LOG">;
|
||||
def quant_EwUnaryFn_Neg : quant_ConstEwUnaryFn<"NEG">;
|
||||
def quant_EwUnaryFn_Rsqrt : quant_ConstEwUnaryFn<"RSQRT">;
|
||||
def quant_EwUnaryFn_Sin : quant_ConstEwUnaryFn<"SIN">;
|
||||
def quant_EwUnaryFn_Square : quant_ConstEwUnaryFn<"SQUARE">;
|
||||
def quant_EwUnaryFn_Sqrt : quant_ConstEwUnaryFn<"SQRT">;
|
||||
def quant_EwUnaryFn_CmpZ : quant_ConstEwUnaryFn<"CMPZ">;
|
||||
def quant_EwUnaryFn_CmpNZ : quant_ConstEwUnaryFn<"CMPNZ">;
|
||||
def quant_EwUnaryFn_CmpLZ : quant_ConstEwUnaryFn<"CMPLZ">;
|
||||
def quant_EwUnaryFn_CmpGZ : quant_ConstEwUnaryFn<"CMPGZ">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Base classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -170,116 +93,34 @@ def quant_StorageCastOp :
|
|||
Results<(outs quant_RealOrStorageValueType)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Integral arithmetic ops used by kernels.
|
||||
// Training integration ops
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def quant_RoundingDivideByPotIOp :
|
||||
quant_Op<"rounding_divide_by_poti", [NoSideEffect, SameValueType]>,
|
||||
Arguments<(ins quant_StorageValueType:$x, I32Attr:$exponent)>,
|
||||
Results<(outs quant_StorageValueType:$y)> {
|
||||
let description = [{
|
||||
Computes integer division by a power-of-two, correctly rounded-to-nearest.
|
||||
Also known as a rounding arithmetic right shift. See
|
||||
gemmlowp::RoundingDivideByPOT for a reference implementation.
|
||||
}];
|
||||
def quant_ConstFakeQuant : quant_Op<"const_fake_quant",
|
||||
[NoSideEffect]> {
|
||||
let summary =
|
||||
"Simulates the effect of uniform quantization with const range.";
|
||||
|
||||
let verifier = [{
|
||||
auto verifyExponent = exponent().getSExtValue();
|
||||
if (verifyExponent < 0 || verifyExponent > 31) {
|
||||
return emitOpError("exponent must be in range [0..31]");
|
||||
}
|
||||
return success();
|
||||
}];
|
||||
let description = [{
|
||||
Given a const min, max, num_bits and narrow_range attribute, applies the same
|
||||
uniform quantization simulation as is done by the TensorFlow
|
||||
fake_quant_with_min_max_args op. See the fakeQuantAttrsToType() utility
|
||||
method and the quant-convert-simulated-quantization pass for futher details.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
F32Tensor:$inputs,
|
||||
F32Attr:$min,
|
||||
F32Attr:$max,
|
||||
// The bitwidth of the quantization; between 2 and 16, inclusive.
|
||||
I64Attr:$num_bits,
|
||||
// Quantization range starts from 0 or 1; starts from 1 if true.
|
||||
DefaultValuedAttr<BoolAttr, "false">:$narrow_range
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
F32Tensor:$outputs
|
||||
);
|
||||
}
|
||||
|
||||
def quant_SaturatingAddIOp :
|
||||
quant_Op<"saturating_addi", [NoSideEffect, SameValueType]>,
|
||||
Arguments<(ins quant_StorageValueType:$x,
|
||||
quant_StorageValueType:$y,
|
||||
I32Attr:$clamp_min,
|
||||
I32Attr:$clamp_max)>,
|
||||
Results<(outs quant_StorageValueType:$sum)> {
|
||||
let description = [{
|
||||
Computes saturating addition of two operands, saturating to the given min
|
||||
and max value. The implementation is responsible for choosing an
|
||||
intermediate register size appropriate to carry out the operation without
|
||||
overflow. See gemmlowp::SaturatingAdd for a reference implementation.
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Real math ops.
|
||||
//
|
||||
// Math ops on real numbers which may have a representation in quantized
|
||||
// arithmetic. It is expected that eligible ops are lowered from a source
|
||||
// dialect to this set of ops prior to the process of converting a compuation
|
||||
// to a quantized form. It is a non-goal of these ops to preserve enough
|
||||
// information to convert back to the higher level, source dialect.
|
||||
//
|
||||
// These ops support either real/floating point or QuantizedTypes as operands
|
||||
// and results. Since not all transformations are supported (globally or
|
||||
// sometimes for specific targets), a computation may end up with
|
||||
// untransformable RealMathOps, in which case they need to be lowered as is
|
||||
// (using floating point math).
|
||||
//
|
||||
// This op set takes advantage of the fact that it is typically trivial to
|
||||
// combine a math function with a compatible bias addition and real-valued
|
||||
// clamp (which can be done at a higher accumulation bit depth).
|
||||
//
|
||||
// In addition, all element-wise unary functions are collapsed into a single
|
||||
// quant_RealUnaryEwOp and selected via an enum-like attribute. Especially at
|
||||
// low bit depths, this makes matching simpler and allows the construction of
|
||||
// generic LUT-based implementations. It also allows specific lowering rules
|
||||
// to consolidate runs of chained unary ops and fuse them to preceding math
|
||||
// ops, potentially allowing them to operate directly on higher precision
|
||||
// intermediates without resorting to lots of custom kernels for common
|
||||
// formulas that can suffer from insufficient precision at low bit depths.
|
||||
//
|
||||
// Comparison operators are modeled as element-wise unary functions (i.e.
|
||||
// CMPZ, CMPNZ, CMPLZ, CMPGZ) intended to follow a sub and output a 1bit
|
||||
// quantized value. It is expected that lowering rules can fuse them with
|
||||
// the preceding sub.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class quant_RealMathOp<string mnemonic, list<OpTrait> traits = [], dag args> :
|
||||
quant_Op<mnemonic, traits>,
|
||||
Arguments<!con(args, (ins
|
||||
quant_ClampValueAttr:$clamp_min, quant_ClampValueAttr:$clamp_max))>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Element wise binary real math ops.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class quant_RealBinaryOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
quant_RealMathOp<mnemonic, traits,
|
||||
(ins quant_RealValueType:$x, quant_RealValueType:$y)>,
|
||||
Results<(outs quant_RealValueType:$r)>;
|
||||
|
||||
class quant_RealBinaryBiasOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
quant_RealMathOp<mnemonic, traits,
|
||||
(ins quant_RealValueType:$x, quant_RealValueType:$y,
|
||||
quant_RealValueType:$bias)>,
|
||||
Results<(outs quant_RealValueType:$r)>;
|
||||
|
||||
def quant_RealAddEwOp :
|
||||
quant_RealBinaryOp<"real_add_ew", [NoSideEffect]>;
|
||||
|
||||
def quant_RealSubEwOp :
|
||||
quant_RealBinaryOp<"real_sub_ew", [NoSideEffect]>;
|
||||
|
||||
def quant_RealMulEwOp :
|
||||
quant_RealBinaryOp<"real_mul_ew", [NoSideEffect]>;
|
||||
|
||||
def quant_RealDivEwOp :
|
||||
quant_RealBinaryOp<"real_div_ew", [NoSideEffect]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Element wise unary real math op.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def quant_RealUnaryEwOp :
|
||||
quant_RealMathOp<"real_unary_ew", [NoSideEffect],
|
||||
(ins quant_RealValueType:$x, quant_EwUnaryFnAttr:$fn)>,
|
||||
Results<(outs quant_RealValueType:$r)>;
|
||||
|
||||
#endif // QUANTIZATION_OPS
|
||||
#endif // QUANT_OPS
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
//===- QuantPredicates.td - Predicates for dialect types ---*- tablegen -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// Predicates for types in the Quantization dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifdef QUANTIZATION_PREDICATES_
|
||||
#else
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Quantization type definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class quant_TypedPrimitiveOrContainer<Type etype> :
|
||||
Type<AnyOf<[etype.predicate,
|
||||
TypedTensor<etype>.predicate,
|
||||
TypedVector<etype>.predicate]>,
|
||||
"primitive/tensor/vector of " # etype.description>;
|
||||
|
||||
// An implementation of QuantizedType.
|
||||
def quant_QuantizedType :
|
||||
Type<CPred<"{0}.isa<mlir::quant::QuantizedType>()">, "QuantizedType">;
|
||||
|
||||
// A primitive type that can represent a real value. This is either a
|
||||
// floating point value or a quantized type.
|
||||
def quant_RealPrimitiveType :
|
||||
Type<AnyOf<[Float.predicate, quant_QuantizedType.predicate]>,
|
||||
"real valued primitive (float or quantized type)">;
|
||||
|
||||
// A primitive type that can represent a storage value. This is either an
|
||||
// integer or quantized type.
|
||||
def quant_StoragePrimitiveType :
|
||||
Type<AnyOf<[Integer.predicate, quant_QuantizedType.predicate]>,
|
||||
"quantized storage primitive (integer or quantized type)">;
|
||||
|
||||
// A primitive or container of RealPrimitiveType.
|
||||
def quant_RealValueType :
|
||||
quant_TypedPrimitiveOrContainer<quant_RealPrimitiveType>;
|
||||
|
||||
// A primitive or container of StoragePrimitiveType.
|
||||
def quant_StorageValueType :
|
||||
quant_TypedPrimitiveOrContainer<quant_StoragePrimitiveType>;
|
||||
|
||||
// Either a real valued or storage primitive or container type.
|
||||
def quant_RealOrStorageValueType :
|
||||
Type<AnyOf<[quant_RealValueType.predicate,
|
||||
quant_StorageValueType.predicate]>>;
|
||||
|
||||
// An implementation of UniformQuantizedType.
|
||||
def quant_UniformQuantizedType :
|
||||
Type<CPred<"{0}.isa<UniformQuantizedType>()">, "UniformQuantizedType">;
|
||||
|
||||
// Predicate for detecting a container or primitive of UniformQuantizedType.
|
||||
def quant_UniformQuantizedValueType :
|
||||
quant_TypedPrimitiveOrContainer<quant_UniformQuantizedType>;
|
||||
|
||||
#endif // QUANTIZATION_PREDICATES_
|
|
@ -0,0 +1,24 @@
|
|||
//===- DialectRegistration.cpp - Register FxpMathOps dialect --------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#include "mlir/FxpMathOps/FxpMathOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::fxpmath;
|
||||
|
||||
// Static initialization for the fxpmath ops dialect registration.
|
||||
static mlir::DialectRegistration<FxpMathOpsDialect> FxpMathOps;
|
|
@ -0,0 +1,38 @@
|
|||
//===- FxpMathOps.cpp - Op implementation for FxpMathOps ------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#include "mlir/FxpMathOps/FxpMathOps.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Quantization/QuantOps.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/MathExtras.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::fxpmath;
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/FxpMathOps/FxpMathOps.cpp.inc"
|
||||
|
||||
FxpMathOpsDialect::FxpMathOpsDialect(MLIRContext *context)
|
||||
: Dialect(/*name=*/"fxpmath", context) {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/FxpMathOps/FxpMathOps.cpp.inc"
|
||||
>();
|
||||
}
|
|
@ -15,15 +15,16 @@
|
|||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#include "mlir/FxpMathOps/FxpMathOps.h"
|
||||
#include "mlir/FxpMathOps/Passes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Quantization/Passes.h"
|
||||
#include "mlir/Quantization/QuantOps.h"
|
||||
#include "mlir/Quantization/UniformSupport.h"
|
||||
|
||||
#include <functional>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::fxpmath;
|
||||
using namespace mlir::quant;
|
||||
|
||||
namespace {
|
||||
|
@ -198,7 +199,7 @@ static LogicalResult tryRewriteFixedPOTAddEw(const RealBinaryOpInfo &constInfo,
|
|||
if (rhsRightShift != 0) {
|
||||
rhsStorageValue =
|
||||
rewriter
|
||||
.create<RoundingDivideByPotIOp>(
|
||||
.create<RoundingDivideByPotFxpOp>(
|
||||
mathOp->getLoc(), rhsStorageValue,
|
||||
IntegerAttr::get(IntegerType::get(32, rewriter.getContext()),
|
||||
rhsRightShift))
|
||||
|
@ -206,7 +207,7 @@ static LogicalResult tryRewriteFixedPOTAddEw(const RealBinaryOpInfo &constInfo,
|
|||
}
|
||||
|
||||
// Add.
|
||||
Value *sumValue = rewriter.create<SaturatingAddIOp>(
|
||||
Value *sumValue = rewriter.create<SaturatingAddFxpOp>(
|
||||
mathOp->getLoc(), lhsStorageValue, rhsStorageValue, clampMinMax.first,
|
||||
clampMinMax.second);
|
||||
|
||||
|
@ -255,5 +256,5 @@ FunctionPassBase *createLowerUniformRealMathPass() {
|
|||
}
|
||||
|
||||
static PassRegistration<LowerUniformRealMathPass>
|
||||
pass("quant-lower-uniform-real-math",
|
||||
pass("fxpmath-lower-uniform-real-math",
|
||||
"Lowers uniform-quantized real math ops to integer arithmetic.");
|
|
@ -1,4 +1,4 @@
|
|||
//===- LowerTF.cpp - Passes for lowering from TensorFlow ------------------===//
|
||||
//===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
|
@ -23,44 +23,42 @@
|
|||
#include "mlir/Quantization/Passes.h"
|
||||
#include "mlir/Quantization/QuantOps.h"
|
||||
#include "mlir/Quantization/UniformSupport.h"
|
||||
#include "mlir/TensorFlow/TFOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::quant;
|
||||
|
||||
namespace {
|
||||
|
||||
class LowerTFPass : public FunctionPass<LowerTFPass> {
|
||||
class ConvertSimulatedQuantPass
|
||||
: public FunctionPass<ConvertSimulatedQuantPass> {
|
||||
public:
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Rewrites TensorFlow FakeQuantWithMinMaxArgs into a qbarrier/dbarrier pair.
|
||||
class FakeQuantWithMinMaxArgsRewrite : public RewritePattern {
|
||||
/// Rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
|
||||
class ConstFakeQuantRewrite : public RewritePattern {
|
||||
public:
|
||||
bool *hadFailure;
|
||||
|
||||
FakeQuantWithMinMaxArgsRewrite(MLIRContext *context, bool *hadFailure)
|
||||
: RewritePattern(TF::FakeQuantWithMinMaxArgsOp::getOperationName(), 1,
|
||||
context),
|
||||
ConstFakeQuantRewrite(MLIRContext *context, bool *hadFailure)
|
||||
: RewritePattern(ConstFakeQuant::getOperationName(), 1, context),
|
||||
hadFailure(hadFailure) {}
|
||||
|
||||
PatternMatchResult match(Operation *op) const override {
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// TODO: If this pattern comes up more frequently, consider adding core
|
||||
// support for failable rewrites.
|
||||
if (failableRewrite(op, rewriter)) {
|
||||
*hadFailure = true;
|
||||
}
|
||||
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
bool failableRewrite(Operation *op, PatternRewriter &rewriter) const {
|
||||
auto fqOp = op->template cast<TF::FakeQuantWithMinMaxArgsOp>();
|
||||
auto fqOp = op->cast<ConstFakeQuant>();
|
||||
|
||||
auto converter =
|
||||
ExpressedToUniformQuantizedConverter::forInputType(fqOp.getType());
|
||||
|
@ -93,20 +91,23 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
void LowerTFPass::runOnFunction() {
|
||||
void ConvertSimulatedQuantPass::runOnFunction() {
|
||||
bool hadFailure = false;
|
||||
OwningRewritePatternList patterns;
|
||||
auto &func = getFunction();
|
||||
auto *context = &getContext();
|
||||
patterns.push_back(
|
||||
llvm::make_unique<FakeQuantWithMinMaxArgsRewrite>(context, &hadFailure));
|
||||
llvm::make_unique<ConstFakeQuantRewrite>(context, &hadFailure));
|
||||
applyPatternsGreedily(func, std::move(patterns));
|
||||
if (hadFailure)
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
FunctionPassBase *createLowerTFPass() { return new LowerTFPass(); }
|
||||
FunctionPassBase *createConvertSimulatedQuantPass() {
|
||||
return new ConvertSimulatedQuantPass();
|
||||
}
|
||||
|
||||
static PassRegistration<LowerTFPass>
|
||||
pass("quant-lower-tf",
|
||||
"Lowers TensorFlow constraint ops to the quantization dialect");
|
||||
static PassRegistration<ConvertSimulatedQuantPass>
|
||||
pass("quant-convert-simulated-quantization",
|
||||
"Converts training-time simulated quantization ops to corresponding "
|
||||
"quantize/dequantize casts.");
|
|
@ -1,18 +1,18 @@
|
|||
// RUN: mlir-opt %s -split-input-file -quant-lower-uniform-real-math | FileCheck %s --dump-input=fail
|
||||
// RUN: mlir-opt %s -split-input-file -fxpmath-lower-uniform-real-math | FileCheck %s --dump-input=fail
|
||||
|
||||
// -----
|
||||
// Verify lowering when operands and result have the same fixedpoint pot scale.
|
||||
// CHECK-LABEL: real_addew_fixedpoint_same_scale
|
||||
// CHECK: %0 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %2 = "quant.saturating_addi"(%0, %1) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %2 = "fxpmath.saturating_addi"(%0, %1) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %3 = "quant.scast"(%2) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
|
||||
// CHECK-NEXT: return %3 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
|
||||
!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
func @real_addew_fixedpoint_same_scale(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
|
||||
%0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
|
||||
%0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
|
||||
return %0 : !type_result
|
||||
}
|
||||
|
||||
|
@ -21,15 +21,15 @@ func @real_addew_fixedpoint_same_scale(%arg0 : !type_lhs, %arg1: !type_rhs) -> !
|
|||
// CHECK-LABEL: real_addew_fixedpoint_rhs_shift
|
||||
// CHECK: %0 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %2 = "quant.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %3 = "quant.saturating_addi"(%0, %2) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %2 = "fxpmath.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %3 = "fxpmath.saturating_addi"(%0, %2) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %4 = "quant.scast"(%3) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
|
||||
// CHECK-NEXT: return %4 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
|
||||
!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{7.8125e-03}">>
|
||||
!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
func @real_addew_fixedpoint_rhs_shift(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
|
||||
%0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
|
||||
%0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
|
||||
return %0 : !type_result
|
||||
}
|
||||
|
||||
|
@ -38,15 +38,15 @@ func @real_addew_fixedpoint_rhs_shift(%arg0 : !type_lhs, %arg1: !type_rhs) -> !t
|
|||
// CHECK-LABEL: real_addew_fixedpoint_lhs_shift
|
||||
// CHECK: %0 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %1 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %2 = "quant.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %3 = "quant.saturating_addi"(%0, %2) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %2 = "fxpmath.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %3 = "fxpmath.saturating_addi"(%0, %2) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %4 = "quant.scast"(%3) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
|
||||
// CHECK-NEXT: return %4 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
|
||||
!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{7.8125e-03}">>
|
||||
!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
func @real_addew_fixedpoint_lhs_shift(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
|
||||
%0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
|
||||
%0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
|
||||
return %0 : !type_result
|
||||
}
|
||||
|
||||
|
@ -56,15 +56,15 @@ func @real_addew_fixedpoint_lhs_shift(%arg0 : !type_lhs, %arg1: !type_rhs) -> !t
|
|||
// CHECK-LABEL: real_addew_fixedpoint_clamp
|
||||
// CHECK: %0 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %1 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %2 = "quant.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %3 = "quant.saturating_addi"(%0, %2) {clamp_max: 64 : i32, clamp_min: -64 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %2 = "fxpmath.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %3 = "fxpmath.saturating_addi"(%0, %2) {clamp_max: 64 : i32, clamp_min: -64 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %4 = "quant.scast"(%3) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
|
||||
// CHECK-NEXT: return %4 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
|
||||
!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{7.8125e-03}">>
|
||||
!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
func @real_addew_fixedpoint_clamp(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
|
||||
%0 = "quant.real_add_ew"(%arg0, %arg1) { clamp_min:-4.0, clamp_max:4.0 }
|
||||
%0 = "fxpmath.real_add_ew"(%arg0, %arg1) { clamp_min:-4.0, clamp_max:4.0 }
|
||||
: (!type_lhs, !type_rhs) -> (!type_result)
|
||||
return %0 : !type_result
|
||||
}
|
||||
|
@ -76,8 +76,8 @@ func @real_addew_fixedpoint_clamp(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_
|
|||
!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
func @real_addew_unquantized_lhs(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
|
||||
// CHECK: %0 = "quant.real_add_ew"(%arg0, %arg1)
|
||||
%0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
|
||||
// CHECK: %0 = "fxpmath.real_add_ew"(%arg0, %arg1)
|
||||
%0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
|
||||
return %0 : !type_result
|
||||
}
|
||||
|
||||
|
@ -88,8 +88,8 @@ func @real_addew_unquantized_lhs(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_r
|
|||
!type_rhs = type tensor<4xf32>
|
||||
!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
func @real_addew_unquantized_rhs(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
|
||||
// CHECK: %0 = "quant.real_add_ew"(%arg0, %arg1)
|
||||
%0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
|
||||
// CHECK: %0 = "fxpmath.real_add_ew"(%arg0, %arg1)
|
||||
%0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
|
||||
return %0 : !type_result
|
||||
}
|
||||
|
||||
|
@ -100,7 +100,7 @@ func @real_addew_unquantized_rhs(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_r
|
|||
!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
!type_result = type tensor<4xf32>
|
||||
func @real_addew_unquantized_result(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
|
||||
// CHECK: %0 = "quant.real_add_ew"(%arg0, %arg1)
|
||||
%0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
|
||||
// CHECK: %0 = "fxpmath.real_add_ew"(%arg0, %arg1)
|
||||
%0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
|
||||
return %0 : !type_result
|
||||
}
|
|
@ -1,12 +1,11 @@
|
|||
// RUN: mlir-opt %s -split-input-file -verify -quant-lower-tf
|
||||
// RUN: mlir-opt %s -split-input-file -verify -quant-convert-simulated-quantization
|
||||
|
||||
// -----
|
||||
// TODO(laurenzo): move this test to the TensorFlow/tf-ops-invalid.mlir
|
||||
// Verify that a mismatched range errors.
|
||||
func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
|
||||
^bb0(%arg0: tensor<8x4x3xf32>):
|
||||
// expected-error@+1 {{op range failed to straddle zero: [1.100000,1.500000]}}
|
||||
%0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {
|
||||
// expected-error@+1 {{FakeQuant range must straddle zero: [1.100000,1.500000]}}
|
||||
%0 = "quant.const_fake_quant"(%arg0) {
|
||||
min: 1.1, max: 1.5, num_bits: 8
|
||||
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
|
||||
return %0 : tensor<8x4x3xf32>
|
||||
|
@ -16,20 +15,19 @@ func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
|
|||
// Verify that a valid range errors.
|
||||
func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
|
||||
^bb0(%arg0: tensor<8x4x3xf32>):
|
||||
// expected-error@+1 {{op range is invalid: [1.100000,1.000000}}
|
||||
%0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {
|
||||
// expected-error@+1 {{FakeQuant range must straddle zero: [1.100000,1.000000}}
|
||||
%0 = "quant.const_fake_quant"(%arg0) {
|
||||
min: 1.1, max: 1.0, num_bits: 8
|
||||
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
|
||||
return %0 : tensor<8x4x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// TODO(laurenzo): move this test to the TensorFlow/tf-ops-invalid.mlir
|
||||
// Unsupported quantizable type (i1 is currently not a supported element type).
|
||||
func @fakeQuantArgs(tensor<8x4x3xi1>) -> tensor<8x4x3xi1> {
|
||||
^bb0(%arg0: tensor<8x4x3xi1>):
|
||||
// expected-error@+1 {{op operand #0 must be tensor of 32-bit float values}}
|
||||
%0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {
|
||||
%0 = "quant.const_fake_quant"(%arg0) {
|
||||
min: 1.1, max: 1.0, num_bits: 8
|
||||
} : (tensor<8x4x3xi1>) -> tensor<8x4x3xi1>
|
||||
return %0 : tensor<8x4x3xi1>
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-opt %s -split-input-file -quant-lower-tf | FileCheck %s --dump-input=fail
|
||||
// RUN: mlir-opt %s -split-input-file -quant-convert-simulated-quantization | FileCheck %s --dump-input=fail
|
||||
|
||||
// -----
|
||||
// Verifies a quint8 asymmetric 0..1 range.
|
||||
|
@ -9,7 +9,7 @@ func @fakeQuantArgs_Quint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
|
|||
// CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[u8:f32]{0.0039215686274509803}">>
|
||||
// CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[u8:f32]{0.0039215686274509803}">>)
|
||||
// CHECK-SAME: -> tensor<8x4x3xf32>
|
||||
%0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {
|
||||
%0 = "quant.const_fake_quant"(%arg0) {
|
||||
min: 0.0, max: 1.0, num_bits: 8
|
||||
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
|
||||
return %0 : tensor<8x4x3xf32>
|
||||
|
@ -24,7 +24,7 @@ func @fakeQuantArgs_Quint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
|
|||
// CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[u8(1:255):f32]{0.003937007874015748:1}">>
|
||||
// CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[u8(1:255):f32]{0.003937007874015748:1}">>)
|
||||
// CHECK-SAME: -> tensor<8x4x3xf32>
|
||||
%0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {
|
||||
%0 = "quant.const_fake_quant"(%arg0) {
|
||||
min: 0.0, max: 1.0, num_bits: 8, narrow_range: true
|
||||
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
|
||||
return %0 : tensor<8x4x3xf32>
|
||||
|
@ -39,7 +39,7 @@ func @fakeQuantArgs_Quint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32
|
|||
// CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[u8:f32]{7.812500e-03:128}">>
|
||||
// CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[u8:f32]{7.812500e-03:128}">>)
|
||||
// CHECK-SAME: -> tensor<8x4x3xf32>
|
||||
%0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {
|
||||
%0 = "quant.const_fake_quant"(%arg0) {
|
||||
min: -1.0, max: 0.9921875, num_bits: 8, narrow_range: false
|
||||
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
|
||||
return %0 : tensor<8x4x3xf32>
|
||||
|
@ -55,7 +55,7 @@ func @fakeQuantArgs_Qint16_Symmetric(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
|
|||
// CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[i16:f32]{3.05175781185626E-5}">>
|
||||
// CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[i16:f32]{3.05175781185626E-5}">>)
|
||||
// CHECK-SAME: -> tensor<8x4x3xf32>
|
||||
%0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {
|
||||
%0 = "quant.const_fake_quant"(%arg0) {
|
||||
min: -1.0, max: 0.999969482, num_bits: 16
|
||||
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
|
||||
return %0 : tensor<8x4x3xf32>
|
||||
|
@ -70,7 +70,7 @@ func @fakeQuantArgs_UnrankedTensor(tensor<f32>) -> tensor<f32> {
|
|||
// CHECK-SAME: -> tensor<!quant<"uniform[u8:f32]{0.0039215686274509803}">>
|
||||
// CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<!quant<"uniform[u8:f32]{0.0039215686274509803}">>)
|
||||
// CHECK-SAME: -> tensor<f32>
|
||||
%0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {
|
||||
%0 = "quant.const_fake_quant"(%arg0) {
|
||||
min: 0.0, max: 1.0, num_bits: 8
|
||||
} : (tensor<f32>) -> tensor<f32>
|
||||
return %0 : tensor<f32>
|
Loading…
Reference in New Issue