Remove FxpMathOps dialect and Quantizer tool.
Summary:
* Removal of FxpMathOps was discussed on the mailing list.
* Will send a courtesy note about also removing the Quantizer (which had some
  dependencies on FxpMathOps).
* These were only ever used for experimental purposes and we know how to get
  them back from history as needed.
* There is a new proposal for more generalized quantization tooling, so moving
  these older experiments out of the way helps clean things up.

Subscribers: mgorny, mehdi_amini, rriddle, jpienaar, burmako, shauheen,
antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb,
Joonsoo, grosul1, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D77479
This commit is contained in:
parent da4ffc64e4
commit f5deb0878d
@@ -12,18 +12,10 @@ This document describes the available MLIR passes and their contracts.

[include "ConversionPasses.md"]

## Quantizer Passes

[include "QuantizerPasses.md"]

## `affine` Dialect Passes

[include "AffinePasses.md"]

## `fxpmath` Dialect Passes

[include "FxpMathPasses.md"]

## `gpu` Dialect Passes

[include "GPUPasses.md"]

@@ -188,23 +188,6 @@ MLIR:

*   Passes and tools exist to convert directly from the *TensorFlow* dialect
    to the TFLite quantized operation set.

*   [*FxpMath* dialect](#fxpmath-dialect) containing (experimental) generalized
    representations of fixed-point math operations and conversions:

    *   [Real math ops](#real-math-ops) representing common combinations of
        arithmetic operations that closely match corresponding fixed-point math
        concepts (as opposed to being spread across multiple ops as is typical
        in source dialects).
    *   [Fixed-point math ops](#fixed-point-math-ops) for carrying out
        computations on integers, as are typically needed by uniform
        quantization schemes.
    *   Passes to lower from real math operations to fixed-point math
        operations.

*   [Solver tools](#solver-tools) which can (experimentally and generically)
    operate on computations expressed in the *FxpMath* dialect in order to
    convert from floating point types to appropriate *QuantizedTypes*, allowing
    the computation to be further lowered to integral math operations.

Not every application of quantization will use all of these facilities.
Specifically, the TensorFlow to TensorFlow Lite conversion uses the
QuantizedTypes but has its own operations for type conversion and expression
of the supporting math.

@@ -279,81 +262,3 @@ TODO: Flesh this out

1.  Run a quantization pass that takes (tfl.DQ (for both input and weights) ->
    op -> tfl.Q) and replaces it with (op). Also replace (constant_float ->
    tfl.Q) with (constant_quant).

## FxpMath dialect

### Real math operations

Note that these all support explicit clamps, which allows for simple fusions
and representation of some common sequences of quantization-compatible math.
In addition, some support explicit biases, which are often represented as
separate adds in source dialects.

TODO: This operation set is still evolving and needs to be completed.

*   RealBinaryOp
    *   RealAddEwOp
    *   RealSubEwOp
    *   RealMulEwOp
    *   RealDivEwOp
*   RealUnaryOp
    *   IDENTITY
    *   TANH
    *   SIGMOID
    *   EXP
    *   LOG
    *   NEG
    *   RSQRT
    *   SIN
    *   SQUARE
    *   SQRT
    *   CMPZ
    *   CMPNZ
    *   CMPLZ
    *   CMPGZ

### Fixed-point math operations

TODO: This operation set only has enough operations to lower a simple
power-of-two RealAddEwOp.

*   RoundingDivideByPotFxpOp
*   SaturatingAddFxpOp

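RoundingDivideByPotFxpOp and SaturatingAddFxpOp mirror well-known gemmlowp building blocks. As a rough sketch of the intended per-element semantics (illustrative C++ modeled on the gemmlowp reference code, not the dialect's actual lowering; the function names are ours):

```cpp
#include <cstdint>
#include <limits>

// Rounding arithmetic right shift: integer division by 2^exponent,
// rounded to nearest (modeled on gemmlowp::RoundingDivideByPOT).
std::int32_t RoundingDivideByPOT(std::int32_t x, int exponent) {
  const std::int32_t mask = (std::int32_t{1} << exponent) - 1;
  const std::int32_t remainder = x & mask;
  // For negative inputs the threshold is nudged up by one so that
  // rounding behaves consistently across signs.
  const std::int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}

// Saturating add: computes the sum at 64-bit width, then clamps it to
// the representable int32 range instead of wrapping around.
std::int32_t SaturatingAdd(std::int32_t a, std::int32_t b) {
  const std::int64_t sum = std::int64_t{a} + std::int64_t{b};
  const std::int64_t lo = std::numeric_limits<std::int32_t>::min();
  const std::int64_t hi = std::numeric_limits<std::int32_t>::max();
  return static_cast<std::int32_t>(sum < lo ? lo : (sum > hi ? hi : sum));
}
```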
## Solver tools

Solver tools exist to analyze an MLIR computation, expressed in either a
supported source dialect or in the *real math ops* set, and solve for
appropriate QuantizedTypes that allow the computation to be lowered to
integral math.

These tools are an active area of work and may be expanded in the future to
adjacent areas such as solving for transformations to other kinds of lower
precision types (e.g. bfloat16 or fp16).

Solver tools are expected to operate in several modes, depending on the
computation and the training characteristics of the model:

*   *Transform*: With all available information in the MLIR computation, infer
    boundaries where the computation can be carried out with integral math and
    change types accordingly to appropriate QuantizedTypes:

    *   For passthrough ops which do not perform active math, change them to
        operate directly on the storage type, converting in and out at the
        edges via scast operations.
    *   For operations that have the *Quantizable* trait, the type can be set
        directly. This includes operations from the
        [real math ops set](#real-math-ops).
    *   For others, encase them in appropriate dcast/qcast operations,
        presuming that some follow-on pass will know what to do with them.

*   *Instrument*: Most of the time, there are not sufficient implied
    constraints within a computation to perform many transformations. For this
    reason, the solver can insert instrumentation operations at points where
    additional runtime statistics may yield solutions. It is expected that
    such computations will be lowered as-is for execution, run over an
    appropriate evaluation set, and the statistics at each instrumentation
    point made available for a future invocation of the solver.

*   *Simplify*: A variety of passes and simplifications are applied once
    QuantizedTypes are added in order to arrive at a computation that is
    expressed with as much integral math, and as few casts, as possible.
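To make the qcast/dcast notion concrete: under a uniform-affine scheme, a QuantizedType boils down to a (scale, zero point) pair, and quantize/dequantize are the maps sketched below (illustrative C++ only; the struct and function names are ours, not MLIR Quant dialect APIs):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// A minimal stand-in for a uniform-affine quantized type:
// real = (stored - zeroPoint) * scale.
struct UniformQuantizedType {
  double scale;            // real value spanned by one integer step
  std::int32_t zeroPoint;  // stored integer that represents real 0.0
};

// qcast-like direction: real -> stored int8, rounding to nearest and
// saturating to the storage range.
std::int8_t Quantize(double real, const UniformQuantizedType &t) {
  const std::int32_t q =
      static_cast<std::int32_t>(std::lround(real / t.scale)) + t.zeroPoint;
  const std::int32_t clamped =
      std::min<std::int32_t>(127, std::max<std::int32_t>(-128, q));
  return static_cast<std::int8_t>(clamped);
}

// dcast-like direction: stored int8 -> approximate real value.
double Dequantize(std::int8_t q, const UniformQuantizedType &t) {
  return (static_cast<std::int32_t>(q) - t.zeroPoint) * t.scale;
}
```

Solving for a QuantizedType then amounts to choosing the scale and zero point so that this round trip loses as little information as possible over the observed value range.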
@@ -3,7 +3,6 @@ get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)

set(LIBS
        ${dialect_libs}
        ${conversion_libs}
        MLIRQuantizerTransforms
        MLIROptLib
        MLIRStandalone
        )

@@ -2,5 +2,4 @@ add_subdirectory(Conversion)

add_subdirectory(Dialect)
add_subdirectory(IR)
add_subdirectory(Interfaces)
add_subdirectory(Quantizer)
add_subdirectory(Transforms)

@@ -1,6 +1,5 @@

add_subdirectory(Affine)
add_subdirectory(AVX512)
add_subdirectory(FxpMathOps)
add_subdirectory(GPU)
add_subdirectory(Linalg)
add_subdirectory(LLVMIR)

@@ -1,8 +0,0 @@

add_mlir_dialect(FxpMathOps fxpmath)
add_mlir_doc(FxpMathOps -gen-dialect-doc FxpMathDialect Dialects/)

set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
add_public_tablegen_target(MLIRFxpMathPassIncGen)

add_mlir_doc(Passes -gen-pass-doc FxpMathPasses ./)

@@ -1,28 +0,0 @@

//===- FxpMathOps.h - Fixed point ops ---------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_FXPMATHOPS_FXPMATHOPS_H_
#define MLIR_DIALECT_FXPMATHOPS_FXPMATHOPS_H_

#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Interfaces/SideEffects.h"

namespace mlir {
namespace fxpmath {

#include "mlir/Dialect/FxpMathOps/FxpMathOpsDialect.h.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/FxpMathOps/FxpMathOps.h.inc"

} // namespace fxpmath
} // namespace mlir

#endif // MLIR_DIALECT_FXPMATHOPS_FXPMATHOPS_H_

@@ -1,278 +0,0 @@

//===- FxpMathOps.td - Fixed point ops --------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the operation definition file for fixed point ops (and real
// equivalents).
//
//===----------------------------------------------------------------------===//

#ifndef DIALECT_FXPMATHOPS_FXPMATH_OPS_
#define DIALECT_FXPMATHOPS_FXPMATH_OPS_

include "mlir/IR/OpBase.td"
include "mlir/Dialect/Quant/QuantOpsBase.td"
include "mlir/Interfaces/SideEffects.td"

def FxpMathOps_Dialect : Dialect {
  let name = "fxpmath";
}

//===----------------------------------------------------------------------===//
// Attributes
//===----------------------------------------------------------------------===//

// Real value for an (inclusive) min/max clamp limit.
def fxpmath_ClampValueAttr : OptionalAttr<F64Attr>;

// Element-wise activation function to apply.
// Note that RELU activations are not here: they are expressed as clamps.
def fxpmath_EwUnaryFnAttr :
    StringBasedAttr<CPred<"true">, "element-wise unary function"> {
  let returnType = [{ StringRef }];
  let defaultValue = "IDENTITY";
}

class fxpmath_ConstEwUnaryFn<string val> : ConstantAttr<fxpmath_EwUnaryFnAttr, val>;
def fxpmath_EwUnaryFn_Abs : fxpmath_ConstEwUnaryFn<"ABS">;
def fxpmath_EwUnaryFn_Exp : fxpmath_ConstEwUnaryFn<"EXP">;
def fxpmath_EwUnaryFn_Identity: fxpmath_ConstEwUnaryFn<"IDENTITY">;
def fxpmath_EwUnaryFn_Log : fxpmath_ConstEwUnaryFn<"LOG">;
def fxpmath_EwUnaryFn_Neg : fxpmath_ConstEwUnaryFn<"NEG">;
def fxpmath_EwUnaryFn_Rsqrt : fxpmath_ConstEwUnaryFn<"RSQRT">;
def fxpmath_EwUnaryFn_Sigmoid : fxpmath_ConstEwUnaryFn<"SIGMOID">;
def fxpmath_EwUnaryFn_Sign : fxpmath_ConstEwUnaryFn<"SIGN">;
def fxpmath_EwUnaryFn_Sin : fxpmath_ConstEwUnaryFn<"SIN">;
def fxpmath_EwUnaryFn_Sqrt : fxpmath_ConstEwUnaryFn<"SQRT">;
def fxpmath_EwUnaryFn_Square : fxpmath_ConstEwUnaryFn<"SQUARE">;
def fxpmath_EwUnaryFn_Tanh : fxpmath_ConstEwUnaryFn<"TANH">;

//===----------------------------------------------------------------------===//
// Comparison functions (compares relative to zero on a subtraction result).
//===----------------------------------------------------------------------===//

def fxpmath_CompareZ : StrEnumAttrCase<"CMPZ">;
def fxpmath_CompareNZ : StrEnumAttrCase<"CMPNZ">;
def fxpmath_CompareLZ : StrEnumAttrCase<"CMPLZ">;
def fxpmath_CompareLZE : StrEnumAttrCase<"CMPLZE">;
def fxpmath_CompareGZ : StrEnumAttrCase<"CMPGZ">;
def fxpmath_CompareGZE : StrEnumAttrCase<"CMPGZE">;

def fxpmath_CompareFnAttr : StrEnumAttr<"ComparisonFn",
    "Type of subtraction-result comparison to perform.",
    [
      fxpmath_CompareZ,
      fxpmath_CompareNZ,
      fxpmath_CompareLZ,
      fxpmath_CompareLZE,
      fxpmath_CompareGZ,
      fxpmath_CompareGZE
    ]>;

//===----------------------------------------------------------------------===//
// Base classes
//===----------------------------------------------------------------------===//

class fxpmath_Op<string mnemonic, list<OpTrait> traits> :
    Op<FxpMathOps_Dialect, mnemonic, traits>;

//===----------------------------------------------------------------------===//
// Fixed-point (fxp) arithmetic ops used by kernels.
// Some of these are temporary pending inclusion into a more core dialect.
//===----------------------------------------------------------------------===//

def fxpmath_ClampISOp : fxpmath_Op<"clampis", [NoSideEffect, SameOperandsAndResultType]> {
  let summary =
      "Clamps a signed-integer-like argument to a min/max range.";
  let description = [{
    Element-wise equivalent to:
      r = std::min(clamp_max, std::max(e, clamp_min))
  }];
  let arguments = (ins SignlessIntegerLike:$operand,
                   APIntAttr:$clamp_min,
                   APIntAttr:$clamp_max);
  let results = (outs SignlessIntegerLike);
}
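The element-wise formula in the clampis description is the familiar integer clamp; for a single element it is equivalent to (illustrative C++ only, not part of the dialect):

```cpp
#include <algorithm>
#include <cstdint>

// Per-element semantics of the clampis formula:
// r = std::min(clamp_max, std::max(e, clamp_min)).
std::int32_t ClampIS(std::int32_t e, std::int32_t clampMin,
                     std::int32_t clampMax) {
  return std::min(clampMax, std::max(e, clampMin));
}
```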

def fxpmath_ConvertISOp :
    fxpmath_Op<"convertis",
               [NoSideEffect, SameOperandsAndResultShape]> {
  let summary =
      "Does an element-wise conversion from a signed integer to signed integer";
  let description = [{
    Similar to an element-wise static_cast in C++, from one signed integer
    element type to another.
  }];
  let arguments = (ins SignlessIntegerLike:$operand);
  let results = (outs SignlessIntegerLike);
}

def fxpmath_ConvertISToFOp :
    fxpmath_Op<"convertistof",
               [NoSideEffect, SameOperandsAndResultShape]> {
  let summary =
      "Does an element-wise conversion from a signed integer to a float";
  let description = [{
    Similar to an element-wise static_cast in C++, from a signed integer
    element type to a floating point element type, rounding to the nearest
    floating point value.
  }];
  let arguments = (ins SignlessIntegerLike:$operand);
  let results = (outs FloatLike);
}

def fxpmath_VecScalarSaturatingRoundingDoublingHighMulISOp :
    fxpmath_Op<"vs_saturating_rounding_doubling_high_mulis",
               [NoSideEffect, SameOperandsAndResultType]> {
  let summary = "Implements equivalent functionality to ARMv7 NEON VQRDMULH";
  let description = [{
    Equivalent to the ARMv7 NEON VQRDMULH instruction.
    See gemmlowp::SaturatingRoundingDoublingHighMul for a reference
    implementation.
  }];
  let arguments = (ins SignlessIntegerLike:$a, APIntAttr:$b);
  let results = (outs SignlessIntegerLike);
}
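The gemmlowp routine referenced above computes the high 32 bits of the doubled 64-bit product, with round-to-nearest and saturation of the one overflowing case. A hedged C++ sketch of that semantics (following the gemmlowp reference implementation):

```cpp
#include <cstdint>
#include <limits>

// Saturating, rounding, doubling high multiply: returns the high 32 bits
// of 2*a*b, rounded to nearest, saturating min*min to int32 max.
std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a,
                                               std::int32_t b) {
  // The only overflow case: min*min doubles to one past int32 max.
  const bool overflow =
      a == b && a == std::numeric_limits<std::int32_t>::min();
  const std::int64_t ab = std::int64_t{a} * std::int64_t{b};
  // Round-to-nearest nudge before extracting the high bits of 2*a*b.
  const std::int32_t nudge = ab >= 0 ? (1 << 30) : (1 - (1 << 30));
  const std::int32_t high =
      static_cast<std::int32_t>((ab + nudge) / (std::int64_t{1} << 31));
  return overflow ? std::numeric_limits<std::int32_t>::max() : high;
}
```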

def fxpmath_RoundingDivideByPotISOp :
    fxpmath_Op<"rounding_divide_by_potis", [NoSideEffect, SameOperandsAndResultType]> {
  let summary = [{
    Computes a rounding arithmetic right shift.
  }];
  let description = [{
    Computes integer division by a power-of-two, correctly rounded-to-nearest.
    Also known as a rounding arithmetic right shift. See
    gemmlowp::RoundingDivideByPOT for a reference implementation.
  }];
  let arguments = (ins SignlessIntegerLike:$operand, APIntAttr:$exponent);
  let results = (outs SignlessIntegerLike:$res);
  let verifier = [{
    auto verifyExponent = exponent().getSExtValue();
    if (verifyExponent < 0 || verifyExponent > 31) {
      return emitOpError("exponent must be in range [0..31]");
    }
    return success();
  }];
}

//===----------------------------------------------------------------------===//
// Real math ops.
//
// Math ops on real numbers which may have a representation in quantized
// arithmetic. It is expected that eligible ops are lowered from a source
// dialect to this set of ops prior to the process of converting a computation
// to a quantized form. It is a non-goal of these ops to preserve enough
// information to convert back to the higher level, source dialect.
//
// These ops support either real/floating point or QuantizedTypes as operands
// and results. Since not all transformations are supported (globally or
// sometimes for specific targets), a computation may end up with
// untransformable RealMathOps, in which case they need to be lowered as-is
// (using floating point math).
//
// This op set takes advantage of the fact that it is typically trivial to
// combine a math function with a compatible bias addition and real-valued
// clamp (which can be done at a higher accumulation bit depth).
//
// In addition, all element-wise unary functions are collapsed into a single
// fxpmath_RealUnaryEwOp and selected via an enum-like attribute. Especially at
// low bit depths, this makes matching simpler and allows the construction of
// generic LUT-based implementations. It also allows specific lowering rules
// to consolidate runs of chained unary ops and fuse them to preceding math
// ops, potentially allowing them to operate directly on higher precision
// intermediates without resorting to lots of custom kernels for common
// formulas that can suffer from insufficient precision at low bit depths.
//
// Comparison operators are modeled as element-wise unary functions (i.e.
// CMPZ, CMPNZ, CMPLZ, CMPGZ) intended to follow a sub and output a 1-bit
// quantized value. It is expected that lowering rules can fuse them with
// the preceding sub.
//===----------------------------------------------------------------------===//

class fxpmath_RealMathOp<string mnemonic, list<OpTrait> traits = [], dag args> :
    fxpmath_Op<mnemonic, traits>,
    Arguments<!con(args, (ins
        fxpmath_ClampValueAttr:$clamp_min, fxpmath_ClampValueAttr:$clamp_max))>;

//===----------------------------------------------------------------------===//
// Element-wise binary real math ops.
//===----------------------------------------------------------------------===//

class fxpmath_RealBinaryOp<string mnemonic, list<OpTrait> traits = []> :
    fxpmath_RealMathOp<mnemonic, traits,
                       (ins quant_RealValueType:$lhs,
                        quant_RealValueType:$rhs)>,
    Results<(outs quant_RealValueType:$res)>;

class fxpmath_RealBinaryBiasOp<string mnemonic, list<OpTrait> traits = []> :
    fxpmath_RealMathOp<mnemonic, traits,
                       (ins quant_RealValueType:$lhs, quant_RealValueType:$rhs,
                        quant_RealValueType:$bias)>,
    Results<(outs quant_RealValueType:$res)>;

def fxpmath_RealAddEwOp :
    fxpmath_RealBinaryOp<"real_add_ew", [NoSideEffect]>;

def fxpmath_RealSubEwOp :
    fxpmath_RealBinaryOp<"real_sub_ew", [NoSideEffect]>;

def fxpmath_RealMulEwOp :
    fxpmath_RealBinaryOp<"real_mul_ew", [NoSideEffect]>;

def fxpmath_RealDivEwOp :
    fxpmath_RealBinaryOp<"real_div_ew", [NoSideEffect]>;

//===----------------------------------------------------------------------===//
// Element-wise unary real math op.
//===----------------------------------------------------------------------===//

def fxpmath_RealUnaryEwOp :
    fxpmath_RealMathOp<"real_unary_ew", [NoSideEffect],
        (ins quant_RealValueType:$operand, fxpmath_EwUnaryFnAttr:$fn)>,
    Results<(outs quant_RealValueType:$res)>;

def fxpmath_RealCompareZeroEwOp : fxpmath_Op<"compare", [NoSideEffect]>,
    Arguments<(ins quant_RealValueType:$operand, fxpmath_CompareFnAttr:$fn)>,
    Results<(outs I1Tensor:$res)> {
  let description = [{
    Compares a real value to zero, returning an I1 (boolean) tensor with the
    result of applying the comparison function.
  }];
}

//===----------------------------------------------------------------------===//
// Dot op with fused bias addition.
//===----------------------------------------------------------------------===//

def fxpmath_RealMatMulOp :
    fxpmath_RealBinaryOp<"real_matmul", [NoSideEffect]> {
  let summary = "Matmul";
  let description = [{
    A matrix multiply of [m, k] and [k, n] -> [m, n], where the bias vector is
    of shape [n]. Also accepts rank 3 or more input tensors, in which case
    the leading dimensions are batch dims.

    Many real systems have specific library calls optimized for this precise
    operation, which is why it is handled explicitly versus purely as a
    generalized tensor contraction.
  }];
}

def fxpmath_RealMatMulBiasOp :
    fxpmath_RealBinaryBiasOp<"real_matmul_bias", [NoSideEffect]> {
  let summary = "Matmul with bias";
  let description = [{
    A specialization of a RealMatMulOp that also accepts an [n] dimension
    bias vector.

    In addition, there is often special support for a fused bias and clamp,
    which is why they are included.
  }];
}

#endif // DIALECT_FXPMATHOPS_FXPMATH_OPS_

@@ -1,37 +0,0 @@

//===- Passes.h - Fixed point math passes -----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines all of the passes owned by the FxpMathOps dialect.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_FXPMATHOPS_PASSES_H
#define MLIR_DIALECT_FXPMATHOPS_PASSES_H

#include <memory>

namespace mlir {
class FuncOp;
template <typename T> class OpPassBase;

namespace fxpmath {

/// Creates a pass that lowers uniform-quantized real math ops to integer
/// arithmetic. This will leave unrecognized real math ops as-is and is
/// typically followed by a pass that lowers any unrecognized ops to a pure
/// floating point form.
std::unique_ptr<OpPassBase<FuncOp>> createLowerUniformRealMathPass();

/// Creates a pass that lowers uniform-quantized qcast/dcast ops to equivalent
/// operations that perform quantize/dequantize.
std::unique_ptr<OpPassBase<FuncOp>> createLowerUniformCastsPass();

} // namespace fxpmath
} // namespace mlir

#endif // MLIR_DIALECT_FXPMATHOPS_PASSES_H

@@ -1,24 +0,0 @@

//===-- Passes.td - FxpMath pass definition file -----------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_FXPMATH_PASSES
#define MLIR_DIALECT_FXPMATH_PASSES

include "mlir/Pass/PassBase.td"

def FxpMathLowerUniformCasts : Pass<"fxpmath-lower-uniform-casts"> {
  let summary = "Lowers uniform-quantized casts";
  let constructor = "mlir::fxpmath::createLowerUniformCastsPass()";
}

def FxpMathLowerUniformRealMath : Pass<"fxpmath-lower-uniform-real-math"> {
  let summary = "Lowers uniform-quantized real math ops to integer arithmetic";
  let constructor = "mlir::fxpmath::createLowerUniformRealMathPass()";
}

#endif // MLIR_DIALECT_FXPMATH_PASSES

@@ -16,7 +16,6 @@

#include "mlir/Dialect/AVX512/AVX512Dialect.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"

@@ -41,7 +40,6 @@ inline void registerAllDialects() {

  static bool init_once = []() {
    registerDialect<AffineDialect>();
    registerDialect<avx512::AVX512Dialect>();
    registerDialect<fxpmath::FxpMathOpsDialect>();
    registerDialect<gpu::GPUDialect>();
    registerDialect<LLVM::LLVMAVX512Dialect>();
    registerDialect<LLVM::LLVMDialect>();

@@ -28,14 +28,12 @@

#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/FxpMathOps/Passes.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/LoopOps/Passes.h"
#include "mlir/Dialect/Quant/Passes.h"
#include "mlir/Dialect/SPIRV/Passes.h"
#include "mlir/Quantizer/Transforms/Passes.h"
#include "mlir/Transforms/LocationSnapshot.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/ViewOpGraph.h"

@@ -65,10 +63,6 @@ inline void registerAllPasses() {

#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Affine/Passes.h.inc"

  // FxpMath
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/FxpMathOps/Passes.h.inc"

  // GPU
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/GPU/Passes.h.inc"

@@ -88,8 +82,6 @@ inline void registerAllPasses() {

  // Quant
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Quant/Passes.h.inc"
#define GEN_PASS_REGISTRATION
#include "mlir/Quantizer/Transforms/Passes.h.inc"

  // SPIR-V
#define GEN_PASS_REGISTRATION

@@ -1 +0,0 @@

add_subdirectory(Transforms)

@@ -1,41 +0,0 @@

//===- FxpMathConfig.h - Reference fixed point config -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a TargetConfiguration for a reference fixed-point math
// quantization scheme based on the FxpMathOps dialect (plus a small category
// of extension ops that can be added from other dialects).
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_QUANTIZER_CONFIGURATIONS_FXPMATHCONFIG_H
#define MLIR_QUANTIZER_CONFIGURATIONS_FXPMATHCONFIG_H

#include "mlir/Quantizer/Support/Configuration.h"
#include "mlir/Quantizer/Support/Metadata.h"

namespace mlir {
namespace quantizer {

/// Target configuration for a reference affine/fixed-point quantization
/// scheme defined in terms of the FxpMathOps dialect. This can be extended
/// with select ops from other dialects by way of the following public
/// methods:
///   - addValueIdentityOp
class FxpMathTargetConfig : public TargetConfiguration {
public:
  /// Creates an FxpMathTargetConfig instance which can be further customized.
  static std::unique_ptr<FxpMathTargetConfig> create(SolverContext &context);

protected:
  FxpMathTargetConfig(SolverContext &context) : TargetConfiguration(context) {}
};

} // namespace quantizer
} // namespace mlir

#endif // MLIR_QUANTIZER_CONFIGURATIONS_FXPMATHCONFIG_H

@@ -1,146 +0,0 @@

//===- Configuration.h - Configuration object base classes ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The quantizer is relatively agnostic to source and target dialects, with
// the specifics represented by configuration policy objects derived from
// classes in this file.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_QUANTIZER_SUPPORT_CONFIGURATION_H
#define MLIR_QUANTIZER_SUPPORT_CONFIGURATION_H

#include <functional>

#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/Identifier.h"
#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
#include "mlir/Quantizer/Support/Metadata.h"
#include "mlir/Quantizer/Support/Rules.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringSet.h"

namespace mlir {
class Operation;

namespace quantizer {

class CAGSlice;

/// Defines quantization configuration for the target.
/// The settings here depend on a variety of details about the deployment
/// environment, although, where we have control over such things, we do
/// try to standardize as much as possible.
///
/// Non-const methods are used to set up the configuration. It is expected
/// that const instances/references are used post-build.
class TargetConfiguration {
public:
  static constexpr size_t MaxSchemeIndex = 31;
  using OpHandlerFn = std::function<void(Operation *op, CAGSlice &cag)>;

  TargetConfiguration(SolverContext &context);
  virtual ~TargetConfiguration() = default;

  /// Adds a candidate type, returning its ordinal.
  unsigned addCandidateType(quant::AnyQuantizedType quantizedType,
                            CandidateQuantizedType::Scheme scheme) {
    unsigned ordinal = candidateTypes.size();
    assert(allCandidateTypesMask.size() == ordinal);
    CandidateQuantizedType ct{ordinal, quantizedType, scheme};
    candidateTypes.push_back(ct);
    allCandidateTypesMask.push_back(true);
    return ordinal;
  }

  /// Gets a prototype scheme by index.
  const CandidateQuantizedType &getCandidateType(unsigned index) const {
    assert(index < candidateTypes.size());
    return candidateTypes[index];
  }

  ArrayRef<CandidateQuantizedType> getCandidateTypes() const {
    return candidateTypes;
  }

  /// Gets a mask of all enabled candidate types by ordinal.
  llvm::SmallBitVector getAllCandidateTypesMask() const {
    return allCandidateTypesMask;
  }

  /// Gets a mask with every candidate type except those in the given mask.
  llvm::SmallBitVector
  getCandidateTypeDisabledExceptMask(ArrayRef<unsigned> exceptOrdinals) const {
    llvm::SmallBitVector disabled(allCandidateTypesMask);
    for (unsigned ordinal : exceptOrdinals) {
      disabled.reset(ordinal);
    }
    return disabled;
  }

  /// Adds an op handler.
  template <typename OpTy>
  void addOpHandler(OpHandlerFn fn) {
    addOpHandlerByName(OpTy::getOperationName(), fn);
  }

  /// Adds an operation which requires statistics at its result nodes for
  /// best quantization performance. Note that the opName StringRef is
  /// expected to come from getOperationName() and be static.
  template <typename OpTy>
  void addRequireStatsOp() {
    addRequireStatsOpByName(OpTy::getOperationName());
  }

  /// Returns whether opName is a RequireStatsOp.
  bool isRequireStatsOp(Operation *op) const;

  /// Adds an op which does not mutate its values but may mutate its shape
  /// or combine its operands in an arbitrary way.
  /// Such ops are expected to have the same types for operands and results
  /// and must be capable of operating on storage types.
  template <typename OpTy>
  void addValueIdentityOp() {
    addValueIdentityOpByName(OpTy::getOperationName());
  }

  /// Handles the operation if a handler is defined for it.
  void handleOp(Operation *op, CAGSlice &cag) const;

  /// Finalizes the CAG after all anchors have been added.
  virtual void finalizeAnchors(CAGSlice &cag) const {}

  /// Whether an operand or result type is subject to analysis by this config.
  virtual bool isHandledType(Type t) const = 0;

protected:
  virtual void addValueIdentityOpByName(StringRef opName) = 0;
  void addOpHandlerByName(StringRef name, OpHandlerFn fn);

private:
  void addRequireStatsOpByName(StringRef opName);

  /// Vector of all candidate type constraints, indexed by ordinal.
  std::vector<CandidateQuantizedType> candidateTypes;

  /// A SmallBitVector with bits set for all known candidate types.
  llvm::SmallBitVector allCandidateTypesMask;

  /// Map of all op handlers.
  llvm::StringMap<OpHandlerFn> opHandlers;

  /// Names of operations which should have their results annotated with
  /// statistics.
  llvm::StringSet<> requireStatsOpNames;
};

} // namespace quantizer
} // namespace mlir

#endif // MLIR_QUANTIZER_SUPPORT_CONFIGURATION_H

@@ -1,360 +0,0 @@
//===- ConstraintAnalysisGraph.h - Graphs type for constraints --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides graph-based data structures for representing anchors
// and constraints between them.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPH_H
#define MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPH_H

#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "mlir/IR/Function.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/Quantizer/Support/Metadata.h"
#include "llvm/ADT/DenseMap.h"

namespace mlir {
namespace quantizer {

class CAGNode;
class CAGSlice;
class TargetConfiguration;

/// A node in the Constraint Analysis Graph.
/// Nodes are either anchors (representing results and operands) or constraints.
/// Anchor nodes are connected to other anchor nodes via constraints.
/// Nodes exist within graph slices, which are typically analyses attached to
/// the function or module. Slices can contain other slices, which mirrors
/// the nesting of analyses.
///
/// Nodes have directed relationships which propagate successor-ward when dirty.
/// Relationships can be bi-directional, in which case the constraint's
/// propagation mechanism must ensure convergence.
class CAGNode {
public:
  enum class Kind {
    /// Anchors.
    Anchor,
    OperandAnchor,
    ResultAnchor,
    LastAnchor = ResultAnchor,

    /// Constraints.
    Constraint,
    SolveUniformConstraint,
    UniformPropagateExplicitScale,
    LastConstraint = UniformPropagateExplicitScale,
  };

  // Vector and iterator over nodes.
  using node_vector = SmallVector<CAGNode *, 1>;
  using iterator = node_vector::iterator;
  using const_iterator = node_vector::const_iterator;

  virtual ~CAGNode() = default;

  Kind getKind() const { return kind; }

  /// Unique id of the node within the slice.
  int getNodeId() const { return nodeId; }

  /// Whether the node is dirty, requiring one or more calls to propagate().
  bool isDirty() const { return dirty; }
  void markDirty() { dirty = true; }
  void clearDirty() { dirty = false; }

  /// Iterator over this node's children (outgoing) nodes.
  const_iterator begin() const { return outgoing.begin(); }
  const_iterator end() const { return outgoing.end(); }
  iterator begin() { return outgoing.begin(); }
  iterator end() { return outgoing.end(); }

  /// Iterator over this node's parent (incoming) nodes.
  const_iterator incoming_begin() const { return incoming.begin(); }
  const_iterator incoming_end() const { return incoming.end(); }
  iterator incoming_begin() { return incoming.begin(); }
  iterator incoming_end() { return incoming.end(); }

  virtual void propagate(SolverContext &solverContext,
                         const TargetConfiguration &config) {}

  /// Prints the node label, suitable for one-line display.
  virtual void printLabel(raw_ostream &os) const;

  template <typename T> void findChildrenOfKind(SmallVectorImpl<T *> &found) {
    for (CAGNode *child : *this) {
      T *ofKind = dyn_cast<T>(child);
      if (ofKind) {
        found.push_back(ofKind);
      }
    }
  }

  /// Replaces this node by rerouting any parent nodes to have otherNode
  /// as a child.
  void replaceIncoming(CAGNode *otherNode);

  /// Adds an outgoing connection to this node (and corresponding back
  /// incoming connection).
  void addOutgoing(CAGNode *toNode);

  /// Whether this node is an orphan (has no incoming or outgoing connections).
  bool isOrphan() const { return incoming.empty() && outgoing.empty(); }

protected:
  CAGNode(Kind kind) : kind(kind) {}

private:
  Kind kind;
  int nodeId = -1;
  node_vector outgoing;
  node_vector incoming;
  bool dirty = false;

  friend class CAGSlice;
};

/// Anchor nodes represent points in the source IR where we may choose to
/// introduce a type transition. These include operands, results, arguments,
/// returns, etc.
class CAGAnchorNode : public CAGNode {
public:
  enum class TypeTransformRule {
    /// The owning op directly supports all transformed types. In practice,
    /// this means that the op supports QuantizedType for this anchor.
    Direct,

    /// The type of this anchor should be set to the QuantizedType storage
    /// type. This will only be valid if constraints are such that all
    /// inputs/outputs converge to the same storage type (i.e. coupled).
    DirectStorage,

    /// The anchor must only be typed based on the expressed type. This is
    /// used for ops that do not natively support quantization, and suitable
    /// casts will be inserted.
    ExpressedOnly,
  };

  /// Metadata for solving uniform quantization params.
  CAGUniformMetadata &getUniformMetadata() { return uniformMetadata; }
  const CAGUniformMetadata &getUniformMetadata() const {
    return uniformMetadata;
  }

  virtual Operation *getOp() const = 0;
  virtual Value getValue() const = 0;

  static bool classof(const CAGNode *n) {
    return n->getKind() >= Kind::Anchor && n->getKind() <= Kind::LastAnchor;
  }

  void propagate(SolverContext &solverContext,
                 const TargetConfiguration &config) override;

  void printLabel(raw_ostream &os) const override;

  /// Given the anchor metadata and resolved solutions, chooses the most
  /// salient and returns an appropriate type to represent it.
  Type getTransformedType();

  TypeTransformRule getTypeTransformRule() const { return typeTransformRule; }
  void setTypeTransformRule(TypeTransformRule r) { typeTransformRule = r; }

  /// Gets the Type that was defined for this anchor at the time of
  /// construction.
  Type getOriginalType() const { return originalType; }

protected:
  CAGAnchorNode(Kind kind, Type originalType)
      : CAGNode(kind), originalType(originalType) {}

private:
  CAGUniformMetadata uniformMetadata;
  Type originalType;
  TypeTransformRule typeTransformRule = TypeTransformRule::Direct;
};

/// An anchor tied to a specific operand.
/// Since operand anchors can be rewritten so that the operand refers to
/// a new result, they are maintained by reference (to the op and index).
class CAGOperandAnchor : public CAGAnchorNode {
public:
  CAGOperandAnchor(Operation *op, unsigned operandIdx);

  Operation *getOp() const final { return op; }
  unsigned getOperandIdx() const { return operandIdx; }

  static bool classof(const CAGNode *n) {
    return n->getKind() == Kind::Anchor || n->getKind() == Kind::OperandAnchor;
  }

  Value getValue() const final { return op->getOperand(operandIdx); }

  void printLabel(raw_ostream &os) const override;

private:
  Operation *op;
  unsigned operandIdx;
};

/// An anchor tied to a specific result.
/// Since a result is already anchored to its defining op, result anchors refer
/// directly to the underlying Value.
class CAGResultAnchor : public CAGAnchorNode {
public:
  CAGResultAnchor(Operation *op, unsigned resultIdx);

  static bool classof(const CAGNode *n) {
    return n->getKind() == Kind::Anchor || n->getKind() == Kind::ResultAnchor;
  }

  Operation *getOp() const final { return resultValue.getDefiningOp(); }
  Value getValue() const final { return resultValue; }

  void printLabel(raw_ostream &os) const override;

private:
  Value resultValue;
};

/// Base class for constraint nodes.
class CAGConstraintNode : public CAGNode {
public:
  CAGConstraintNode(Kind kind) : CAGNode(kind) {}

  static bool classof(const CAGNode *n) {
    return n->getKind() >= Kind::Constraint &&
           n->getKind() <= Kind::LastConstraint;
  }
};

/// A slice of a CAG (which may be the whole graph).
class CAGSlice {
public:
  CAGSlice(SolverContext &context);
  ~CAGSlice();

  using node_vector = std::vector<CAGNode *>;
  using iterator = node_vector::iterator;
  using const_iterator = node_vector::const_iterator;

  iterator begin() { return allNodes.begin(); }
  iterator end() { return allNodes.end(); }
  const_iterator begin() const { return allNodes.begin(); }
  const_iterator end() const { return allNodes.end(); }

  /// Gets an operand anchor node.
  CAGOperandAnchor *getOperandAnchor(Operation *op, unsigned operandIdx);

  /// Gets a result anchor node.
  CAGResultAnchor *getResultAnchor(Operation *op, unsigned resultIdx);

  /// Adds a relation constraint with incoming 'from' anchors and outgoing 'to'
  /// anchors.
  template <typename T, typename... Args>
  T *addUniqueConstraint(ArrayRef<CAGAnchorNode *> anchors, Args... args) {
    static_assert(std::is_convertible<T *, CAGConstraintNode *>(),
                  "T must be a CAGConstraintNode");
    T *constraintNode = addNode(std::make_unique<T>(args...));
    for (auto *anchor : anchors)
      anchor->addOutgoing(constraintNode);
    return constraintNode;
  }

  /// Adds a unidirectional constraint from a node to an array of target nodes.
  template <typename T, typename... Args>
  T *addUnidirectionalConstraint(CAGAnchorNode *fromAnchor,
                                 ArrayRef<CAGAnchorNode *> toAnchors,
                                 Args... args) {
    static_assert(std::is_convertible<T *, CAGConstraintNode *>(),
                  "T must be a CAGConstraintNode");
    T *constraintNode = addNode(std::make_unique<T>(args...));
    fromAnchor->addOutgoing(constraintNode);
    for (auto *toAnchor : toAnchors) {
      constraintNode->addOutgoing(toAnchor);
    }
    return constraintNode;
  }

  template <typename T>
  T *addClusteredConstraint(ArrayRef<CAGAnchorNode *> anchors) {
    static_assert(std::is_convertible<T *, CAGConstraintNode *>(),
                  "T must be a CAGConstraintNode");
    SmallVector<T *, 8> cluster;
    for (auto *anchor : anchors) {
      anchor->findChildrenOfKind<T>(cluster);
    }

    T *constraintNode;
    if (cluster.empty()) {
      // Create new.
      constraintNode = addNode(std::make_unique<T>());
    } else {
      // Merge existing.
      constraintNode = cluster[0];
      for (size_t i = 1, e = cluster.size(); i < e; ++i) {
        cluster[i]->replaceIncoming(constraintNode);
      }
    }
    for (auto *anchor : anchors) {
      anchor->addOutgoing(constraintNode);
    }
    return constraintNode;
  }

  /// Enumerates all implied connections in the slice.
  /// An implied connection is any two nodes that physically refer to the
  /// same value in the IR, such as result->operand.
  /// Typically this will be modeled with some kind of strong or weak
  /// identity constraint such that types propagate.
  /// This is usually called when the slice has been fully constructed in
  /// order to add final constraints.
  /// It is legal for the callback to modify the graph by adding constraints.
  void enumerateImpliedConnections(
      std::function<void(CAGAnchorNode *from, CAGAnchorNode *to)> callback);

  /// Performs one round of propagation, returning the number of nodes
  /// propagated. If it returns > 0, then additional propagate() rounds are
  /// required.
  unsigned propagate(const TargetConfiguration &config);

private:
  /// Adds a node to the graph.
  /// The node should be a subclass of CAGNode.
  /// Returns the raw pointer to the node.
  template <typename T>
  T *addNode(std::unique_ptr<T> node) {
    node->nodeId = allNodes.size();
    T *unownedNode = node.release();
    allNodes.push_back(unownedNode);
    return unownedNode;
  }

  SolverContext &context;
  std::vector<CAGNode *> allNodes;
  DenseMap<std::pair<Operation *, unsigned>, CAGOperandAnchor *> operandAnchors;
  DenseMap<std::pair<Operation *, unsigned>, CAGResultAnchor *> resultAnchors;
};

inline raw_ostream &operator<<(raw_ostream &os, const CAGNode &node) {
  node.printLabel(os);
  return os;
}

} // namespace quantizer
} // namespace mlir

#endif // MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPH_H
@@ -1,49 +0,0 @@
//===- ConstraintAnalysisGraphTraits.h - Traits for CAGs --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Provides graph traits for constraint analysis graphs.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPHTRAITS_H
#define MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPHTRAITS_H

#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
#include "llvm/ADT/GraphTraits.h"

namespace llvm {

template <>
struct GraphTraits<const mlir::quantizer::CAGNode *> {
  using NodeRef = const mlir::quantizer::CAGNode *;

  static NodeRef getEntryNode(NodeRef node) { return node; }

  // Successors.
  using ChildIteratorType = mlir::quantizer::CAGNode::const_iterator;
  static ChildIteratorType child_begin(NodeRef node) { return node->begin(); }
  static ChildIteratorType child_end(NodeRef node) { return node->end(); }
};

template <>
struct GraphTraits<const mlir::quantizer::CAGSlice *>
    : public llvm::GraphTraits<const mlir::quantizer::CAGNode *> {
  using nodes_iterator = mlir::quantizer::CAGSlice::const_iterator;
  static mlir::quantizer::CAGSlice::const_iterator
  nodes_begin(const mlir::quantizer::CAGSlice *G) {
    return G->begin();
  }
  static mlir::quantizer::CAGSlice::const_iterator
  nodes_end(const mlir::quantizer::CAGSlice *G) {
    return G->end();
  }
};

} // end namespace llvm

#endif // MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPHTRAITS_H
@@ -1,101 +0,0 @@
//===- Metadata.h - Top level types and metadata ----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains top level types needed to construct constraint graphs,
// including context/allocator support and concrete metadata structs for
// different quantization schemes (which must be attached to anchor nodes).
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_QUANTIZER_SUPPORT_METADATA_H
#define MLIR_QUANTIZER_SUPPORT_METADATA_H

#include <limits>

#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Quantizer/Support/Rules.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/Allocator.h"

namespace mlir {
namespace quantizer {

class SolverContext {
public:
  SolverContext(MLIRContext &mlirContext) : mlirContext(mlirContext) {}

  MLIRContext &getMlirContext() { return mlirContext; }

  llvm::BumpPtrAllocator &getAllocator() { return allocator; }

  // Optional path to write a debug DOT file for the CAG.
  StringRef getDebugCAGDotPath() const { return debugCAGDotPath; }
  void setDebugCAGDotPath(StringRef p) { debugCAGDotPath = std::string(p); }

private:
  MLIRContext &mlirContext;
  llvm::BumpPtrAllocator allocator;
  std::string debugCAGDotPath;
};

/// Candidate for a quantized type conversion.
struct CandidateQuantizedType {
  // Note that scheme encodes more than just the target type: it also encodes
  // additional constraints.
  enum class Scheme {
    // Uses aggregate range information for all nodes in the cluster to
    // solve for uniform scale and zero point.
    UniformPerLayer,
    // Uses aggregate per-axis range information for all nodes in the cluster
    // to solve for per-axis uniform scale and zero point.
    UniformPerAxisFixedPoint,
    // Uses the |explicitScaleZeroPoint| to set the scale (and zero point = 0)
    // for the uniform type. This typically overrides all other constraints
    // and is used for wide accumulator types (i.e. i32 bias vectors).
    UniformExplicitFixedPointScale,
  };
  unsigned ordinal;
  quant::AnyQuantizedType quantizedType;
  Scheme scheme;
};

struct CAGUniformMetadata {
  /// Default salience for facts that are derived from data either statically
  /// discovered in the computation or observed from an outside source.
  static constexpr int SalienceDefault = 0;

  /// Salience level for facts derived from overrides provided explicitly.
  static constexpr int SalienceForced = 100;

  /// Salience for facts derived from constraints in how the math is
  /// expressed which must be satisfied.
  static constexpr int SalienceRequired = 200;

  /// The range that the scheme must represent in order to accommodate the
  /// underlying data.
  ExpandingMinMaxFact requiredRange;

  /// Bool vector of scheme ordinals that are disabled.
  llvm::SmallBitVector disabledCandidateTypes;

  /// If set, then a solution has converged for the given per-layer scheme.
  quant::QuantizedType selectedType;

  /// Optional scale and zero point to be used by types which solve via the
  /// UniformExplicitFixedPointScale scheme.
  DiscreteScaleZeroPointFact explicitScaleZeroPoint;

  /// Prints a summary of the metadata suitable for display in a graph label.
  void printSummary(raw_ostream &os) const;
};

} // end namespace quantizer
} // end namespace mlir

#endif // MLIR_QUANTIZER_SUPPORT_METADATA_H
@@ -1,200 +0,0 @@
//===- Rules.h - Helpers for declaring facts and rules ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines helper classes and functions for managing state (facts),
// merging and tracking modification for various data types important for
// quantization.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_QUANTIZER_SUPPORT_RULES_H
#define MLIR_QUANTIZER_SUPPORT_RULES_H

#include "llvm/ADT/Optional.h"

#include <algorithm>
#include <limits>
#include <utility>

namespace mlir {
namespace quantizer {

/// Typed indicator of whether a mutator produces a modification.
struct ModificationResult {
  enum ModificationEnum { Retained, Modified } value;
  ModificationResult(ModificationEnum v) : value(v) {}

  ModificationResult operator|(ModificationResult other) {
    if (value == Modified || other.value == Modified) {
      return ModificationResult(Modified);
    } else {
      return ModificationResult(Retained);
    }
  }

  ModificationResult operator|=(ModificationResult other) {
    value =
        (value == Modified || other.value == Modified) ? Modified : Retained;
    return *this;
  }
};

inline ModificationResult modify(bool isModified = true) {
  return ModificationResult{isModified ? ModificationResult::Modified
                                       : ModificationResult::Retained};
}

inline bool modified(ModificationResult m) {
  return m.value == ModificationResult::Modified;
}

/// A fact that can converge through forward propagation alone without the
/// need to track ownership or individual assertions. In practice, this works
/// for static assertions that are either minimized or maximized and do not
/// vary dynamically.
///
/// It is expected that ValueTy is appropriate to pass by value and has an
/// operator==. The BinaryReducer type should provide:
///   using ValueTy : Type of the value.
///   ValueTy initialValue() : Returns the initial value of the fact.
///   ValueTy reduce(ValueTy lhs, ValueTy rhs) : Reduces two values.
template <typename BinaryReducer>
class BasePropagatedFact {
public:
  using ValueTy = typename BinaryReducer::ValueTy;
  using ThisTy = BasePropagatedFact<BinaryReducer>;
  BasePropagatedFact()
      : value(BinaryReducer::initialValue()),
        salience(std::numeric_limits<int>::min()) {}

  int getSalience() const { return salience; }
  bool hasValue() const { return salience != std::numeric_limits<int>::min(); }
  ValueTy getValue() const { return value; }
  ModificationResult assertValue(int assertSalience, ValueTy assertValue) {
    if (assertSalience > salience) {
      // New salience band.
      value = assertValue;
      salience = assertSalience;
      return modify(true);
    } else if (assertSalience < salience) {
      // Lower salience - ignore.
      return modify(false);
    }
    // Merge within same salience band.
    ValueTy updatedValue = BinaryReducer::reduce(value, assertValue);
    auto mod = modify(value != updatedValue);
    value = updatedValue;
    return mod;
  }
  ModificationResult mergeFrom(const ThisTy &other) {
    if (other.hasValue()) {
      return assertValue(other.getSalience(), other.getValue());
    }
    return modify(false);
  }

private:
  ValueTy value;
  int salience;
};

/// A binary reducer that expands a min/max range represented by a pair
/// of doubles such that the result contains all inputs.
/// The initial value is (Inf, -Inf).
struct ExpandingMinMaxReducer {
  using ValueTy = std::pair<double, double>;
  static ValueTy initialValue() {
    return std::make_pair(std::numeric_limits<double>::infinity(),
                          -std::numeric_limits<double>::infinity());
  }
  static ValueTy reduce(ValueTy lhs, ValueTy rhs) {
    return std::make_pair(std::min(lhs.first, rhs.first),
                          std::max(lhs.second, rhs.second));
  }
};
using ExpandingMinMaxFact = BasePropagatedFact<ExpandingMinMaxReducer>;

/// A binary reducer that minimizes a numeric type.
template <typename T>
struct MinimizingNumericReducer {
  using ValueTy = T;
  static ValueTy initialValue() {
    if (std::numeric_limits<T>::has_infinity) {
      return std::numeric_limits<T>::infinity();
    } else {
      return std::numeric_limits<T>::max();
    }
  }
  static ValueTy reduce(ValueTy lhs, ValueTy rhs) { return std::min(lhs, rhs); }
};
using MinimizingDoubleFact =
    BasePropagatedFact<MinimizingNumericReducer<double>>;
using MinimizingIntFact = BasePropagatedFact<MinimizingNumericReducer<int>>;

/// A binary reducer that maximizes a numeric type.
template <typename T>
struct MaximizingNumericReducer {
  using ValueTy = T;
  static ValueTy initialValue() {
    if (std::numeric_limits<T>::has_infinity) {
      return -std::numeric_limits<T>::infinity();
    } else {
      return std::numeric_limits<T>::min();
    }
  }
  static ValueTy reduce(ValueTy lhs, ValueTy rhs) { return std::max(lhs, rhs); }
};
using MaximizingDoubleFact =
    BasePropagatedFact<MaximizingNumericReducer<double>>;
using MaximizingIntFact = BasePropagatedFact<MaximizingNumericReducer<int>>;

/// A fact and reducer for tracking agreement of discrete values. The value
/// type consists of a |T| value and a flag indicating whether there is a
/// conflict (in which case, the preserved value is arbitrary).
template <typename T>
struct DiscreteReducer {
  struct ValueTy {
    ValueTy() : conflict(false) {}
    ValueTy(T value) : value(value), conflict(false) {}
    ValueTy(T value, bool conflict) : value(value), conflict(conflict) {}
    llvm::Optional<T> value;
    bool conflict;
    bool operator==(const ValueTy &other) const {
      if (conflict != other.conflict)
        return false;
      if (value && other.value) {
        return *value == *other.value;
      } else {
        return !value && !other.value;
      }
    }
    bool operator!=(const ValueTy &other) const { return !(*this == other); }
  };
  static ValueTy initialValue() { return ValueTy(); }
  static ValueTy reduce(ValueTy lhs, ValueTy rhs) {
    if (!lhs.value && !rhs.value)
      return lhs;
    else if (!lhs.value)
      return rhs;
    else if (!rhs.value)
      return lhs;
    else
      return ValueTy(*lhs.value, *lhs.value != *rhs.value);
  }
};

template <typename T>
using DiscreteFact = BasePropagatedFact<DiscreteReducer<T>>;

/// Discrete scale/zeroPoint fact.
using DiscreteScaleZeroPointFact = DiscreteFact<std::pair<double, int64_t>>;

} // end namespace quantizer
} // end namespace mlir

#endif // MLIR_QUANTIZER_SUPPORT_RULES_H
@@ -1,102 +0,0 @@
//===- Statistics.h - Collects statistics over tensors ----------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines adapters for extracting various (per-layer and per-axis)
// statistics over tensors.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_QUANTIZER_SUPPORT_STATISTICS_H
#define MLIR_QUANTIZER_SUPPORT_STATISTICS_H

#include "mlir/IR/Attributes.h"

namespace mlir {
namespace quantizer {

/// Statistics about a tensor axis (or the whole tensor).
struct TensorAxisStatistics {
  int64_t sampleSize = 0;
  double minValue = 0;
  double maxValue = 0;
  double mean = 0;
  double variance = 0;

  int64_t sampleSizePerAxis = 0;
  SmallVector<double, 4> minValuePerAxis;
  SmallVector<double, 4> maxValuePerAxis;
  SmallVector<double, 4> meanPerAxis;
  SmallVector<double, 4> variancePerAxis;

  TensorAxisStatistics() {}
  TensorAxisStatistics(int64_t sampleSize, double minValue, double maxValue,
                       double mean, double variance)
      : sampleSize(sampleSize), minValue(minValue), maxValue(maxValue),
        mean(mean), variance(variance) {}
  TensorAxisStatistics(int64_t sampleSize, ArrayRef<double> minValues,
                       ArrayRef<double> maxValues, ArrayRef<double> means,
                       ArrayRef<double> variances)
      : sampleSizePerAxis(sampleSize),
        minValuePerAxis(minValues.begin(), minValues.end()),
        maxValuePerAxis(maxValues.begin(), maxValues.end()),
        meanPerAxis(means.begin(), means.end()),
        variancePerAxis(variances.begin(), variances.end()) {}
  void clear() { *this = TensorAxisStatistics(); }
};

/// Base class for querying statistics about a tensor.
class AbstractTensorStatistics {
public:
  virtual ~AbstractTensorStatistics() = default;

  /// Gets statistics across the whole tensor.
  /// Returns true if statistics are valid and were populated.
  virtual bool get(TensorAxisStatistics &stats) const { return false; }

  /// Whether this instance supports querying per-axis statistics. If true,
  /// then getForAxis(...) can be used.
  virtual bool supportsPerAxis() const { return false; }

  /// Count of axes supported in a per-axis query.
  virtual unsigned getAxisCount() const { return 0; }

  /// Gets statistics for a specific axis (0..getAxisCount() - 1).
  /// Returns true if statistics are valid and were populated.
  virtual bool getForAxis(unsigned axis, TensorAxisStatistics &stats) const {
    return false;
  }
};

/// Wraps an MLIR Attribute and returns statistics about it.
/// It is expected that the attribute be one of:
///   FloatAttr (scalar)
///   DenseFPElementsAttr
///   OpaqueElementsAttr (with a float-based type)
///   SparseElementsAttr (with a float-based type)
class AttributeTensorStatistics : public AbstractTensorStatistics {
public:
  AttributeTensorStatistics(Attribute attr) : attr(attr) {}

  bool get(TensorAxisStatistics &stats) const override;

  bool supportsPerAxis() const override;

  unsigned getAxisCount() const override;

  bool getForAxis(unsigned axis, TensorAxisStatistics &stats) const override;

private:
  Attribute attr;
};

raw_ostream &operator<<(raw_ostream &os, const TensorAxisStatistics &stats);

} // end namespace quantizer
} // end namespace mlir

#endif // MLIR_QUANTIZER_SUPPORT_STATISTICS_H
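The TensorAxisStatistics record above carries the whole-tensor sample size, min, max, mean, and variance. As a rough illustration of how an implementation such as AttributeTensorStatistics::get() might populate those layer-wide fields from dense float data, here is a minimal standalone sketch; the Stats struct and computeStats helper are hypothetical stand-ins, not part of the removed API:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in mirroring the whole-tensor fields of
// TensorAxisStatistics.
struct Stats {
  int64_t sampleSize = 0;
  double minValue = 0, maxValue = 0, mean = 0, variance = 0;
};

// Computes layer-wide statistics over a flat buffer, roughly the way a
// dense-attribute implementation of get() might.
Stats computeStats(const std::vector<double> &data) {
  Stats s;
  s.sampleSize = static_cast<int64_t>(data.size());
  if (data.empty())
    return s;
  s.minValue = s.maxValue = data[0];
  double sum = 0;
  for (double v : data) {
    s.minValue = std::min(s.minValue, v);
    s.maxValue = std::max(s.maxValue, v);
    sum += v;
  }
  s.mean = sum / static_cast<double>(s.sampleSize);
  // Population variance (divide by N), matching a plain moment estimate.
  double sqDiff = 0;
  for (double v : data)
    sqDiff += (v - s.mean) * (v - s.mean);
  s.variance = sqDiff / static_cast<double>(s.sampleSize);
  return s;
}
```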
@@ -1,31 +0,0 @@
//===- TypeUtils.h - Helper function for manipulating types -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines various helper functions for manipulating types. The
// process of quantizing typically involves a number of type manipulations
// that are not very common elsewhere, and it is best to name them and define
// them here versus inline in the rest of the tool.
//
//===----------------------------------------------------------------------===//

#ifndef THIRD_PARTY_MLIR_EDGE_FXPSOLVER_SUPPORT_TYPEUTILS_H_
#define THIRD_PARTY_MLIR_EDGE_FXPSOLVER_SUPPORT_TYPEUTILS_H_

#include "mlir/IR/Types.h"

namespace mlir {
namespace quantizer {

/// Given an arbitrary container or primitive type, returns the element type,
/// where the element type is just the type for non-containers.
Type getElementOrPrimitiveType(Type t);

} // namespace quantizer
} // namespace mlir

#endif // THIRD_PARTY_MLIR_EDGE_FXPSOLVER_SUPPORT_TYPEUTILS_H_
@@ -1,60 +0,0 @@
//===- UniformConstraints.h - Constraints for uniform quant -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a builder that lets you attach constraints necessary to
// perform a variety of uniform quantization conversions to CAG anchors.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_QUANTIZER_SUPPORT_UNIFORMCONSTRAINTS_H
#define MLIR_QUANTIZER_SUPPORT_UNIFORMCONSTRAINTS_H

#include "mlir/Quantizer/Support/Statistics.h"

namespace mlir {
namespace quantizer {

class CAGAnchorNode;
class CAGSlice;

/// Factory methods for adding CAG constraints of various kinds suitable
/// for solving for uniform quantization.
class UniformConstraintsBuilder {
public:
  UniformConstraintsBuilder(CAGSlice &slice) : slice(slice) {}

  /// Adds a coupling constraint between two nodes, effectively treating
  /// them as a hard identity relationship.
  void coupleAnchors(CAGAnchorNode *a, CAGAnchorNode *b);

  /// Applies statistics constraints to the given anchor, such that the solver
  /// ensures that the statistics are representable by chosen types.
  void applyStats(CAGAnchorNode *a, TensorAxisStatistics stats);

  /// Applies a constraint to a node which allows solutions that do not extend
  /// beyond given min/max bounds (this is a hint that the tensor will not
  /// take values outside of these bounds). If either minValue or maxValue is
  /// NAN, then that side is considered open.
  void clamp(CAGAnchorNode *a, APFloat minValue, APFloat maxValue);

  /// Propagates an explicit scale from an anchor that may have a uniform
  /// |selectedType| to the |explicitScaleZeroPoint| field of the |to| node.
  /// This is typically used with a |to| node that has a candidate quantized
  /// type of |UniformExplicitFixedPointScale|, indicating that it can be
  /// an arbitrary (signed) type that is expected to share the same scale
  /// as the originating node.
  void propagateExplicitScale(CAGAnchorNode *from, CAGAnchorNode *to);

private:
  CAGSlice &slice;
};

} // namespace quantizer
} // namespace mlir

#endif // MLIR_QUANTIZER_SUPPORT_UNIFORMCONSTRAINTS_H
@@ -1,86 +0,0 @@
//===- UniformSolvers.h - Uniform type solver algorithms --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines algorithms for solving uniform type parameters for various
// conditions (e.g. fixed-point, affine, scale matching, etc).
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_QUANTIZER_SUPPORT_UNIFORMSOLVERS_H
#define MLIR_QUANTIZER_SUPPORT_UNIFORMSOLVERS_H

#include <cstdint>
#include <limits>

namespace llvm {
class raw_ostream;
} // end namespace llvm

namespace mlir {
namespace quantizer {

struct UniformStorageParams {
  static UniformStorageParams getQuint8() { return {255, 0}; }
  static UniformStorageParams getQuint8SymmetricRight() { return {254, 1}; }
  static UniformStorageParams getQuint16() { return {32767, 0}; }

  uint64_t numLevels;
  int64_t minValue;
};

/// Solves for the uniform quantization scheme parameters delta and z given
/// bounding min/max.
class UniformParamsFromMinMaxSolver {
public:
  UniformParamsFromMinMaxSolver(const UniformStorageParams &storageParams,
                                double boundingMin, double boundingMax)
      : storageParams(storageParams), boundingMin(boundingMin),
        boundingMax(boundingMax) {}

  /// Performs the computation, returning whether satisfied.
  bool compute();

  // Params.
  double getBoundingMin() const { return boundingMin; }
  double getBoundingMax() const { return boundingMax; }
  bool isSatisfied() const { return satisfied; }
  double getAdjMin() const { return adjMin; }
  double getAdjMax() const { return adjMax; }
  double getScale() const { return delta; }
  int64_t getZp() const { return zp; }
  int getStepCount() const { return stepCount; }

  // Quantize and dequantize.
  int64_t quantize(double x) const;
  double dequantize(int64_t xq) const;

private:
  const UniformStorageParams storageParams;
  const double boundingMin;
  const double boundingMax;

  // Results.
  int stepCount = 0;
  double adjMin = std::numeric_limits<double>::quiet_NaN();
  double adjMax = std::numeric_limits<double>::quiet_NaN();
  double delta = std::numeric_limits<double>::quiet_NaN();
  int64_t zp = 0;

  bool satisfied = false;
};

llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
                              const UniformStorageParams &p);

llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
                              const UniformParamsFromMinMaxSolver &s);

} // end namespace quantizer
} // end namespace mlir

#endif // MLIR_QUANTIZER_SUPPORT_UNIFORMSOLVERS_H
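The solver above derives the scale (delta) and zero point (z) from bounding min/max and a storage range. A minimal sketch of the textbook asymmetric scheme it targets follows; this is an assumption about the solver's contract based on its interface, not the removed implementation, which additionally iterated (see getStepCount) to adjust the bounds so the zero point lands exactly on an integer:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical result pair mirroring getScale()/getZp().
struct UniformParams {
  double scale;
  int64_t zeroPoint;
};

// Given bounding min/max and a storage range (numLevels levels starting at
// storageMin, e.g. 255 and 0 for quint8), derive scale and zero point.
UniformParams solveFromMinMax(double boundingMin, double boundingMax,
                              uint64_t numLevels, int64_t storageMin) {
  // The represented range must contain 0.0 so that real zero is exactly
  // representable.
  boundingMin = std::min(boundingMin, 0.0);
  boundingMax = std::max(boundingMax, 0.0);
  double scale = (boundingMax - boundingMin) / static_cast<double>(numLevels);
  int64_t zeroPoint =
      storageMin + static_cast<int64_t>(std::round(-boundingMin / scale));
  return {scale, zeroPoint};
}
```

For example, bounding [-1, 1] over quint8's 255 levels gives scale 2/255 and zero point 128.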
@@ -1,6 +0,0 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
add_public_tablegen_target(MLIRQuantizerPassIncGen)

add_mlir_doc(Passes -gen-pass-doc QuantizerPasses ./)
@@ -1,43 +0,0 @@
//===- Passes.h - Quantizer passes -----------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines entry points to create passes to perform various kinds
// of quantization related transforms.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_QUANTIZER_TRANSFORMS_PASSES_H
#define MLIR_QUANTIZER_TRANSFORMS_PASSES_H

#include "mlir/Pass/Pass.h"

namespace mlir {
namespace quantizer {

class SolverContext;
class TargetConfiguration;

/// Creates a pass that infers quantized types based on metadata discovered
/// in the computation.
std::unique_ptr<OpPassBase<ModuleOp>>
createInferQuantizedTypesPass(SolverContext &solverContext,
                              const TargetConfiguration &config);
std::unique_ptr<OpPassBase<ModuleOp>> createInferQuantizedTypesPass();

/// Creates a pass which removes any instrumentation and hint ops which have
/// no effect on final runtime.
std::unique_ptr<OpPassBase<FuncOp>> createRemoveInstrumentationPass();

/// Adds default (dummy) statistics to ops that can benefit from runtime stats.
/// Meant for testing.
std::unique_ptr<OpPassBase<FuncOp>> createAddDefaultStatsPass();

} // namespace quantizer
} // namespace mlir

#endif // MLIR_QUANTIZER_TRANSFORMS_PASSES_H
@@ -1,31 +0,0 @@
//===-- Passes.td - Quantizer pass definition file ---------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_QUANTIZER_TRANSFORMS_PASSES
#define MLIR_QUANTIZER_TRANSFORMS_PASSES

include "mlir/Pass/PassBase.td"

def QuantizerAddDefaultStats : Pass<"quantizer-add-default-stats-test"> {
  let summary = "Add default (dummy) statistics to all ops that can benefit "
                "from runtime statistics";
  let constructor = "mlir::quantizer::createAddDefaultStatsPass()";
}

def QuantizerInferQuantizedTypes : Pass<"quantizer-infer-quantized-types"> {
  let summary = "Infer quantized types for a module";
  let constructor = "mlir::quantizer::createInferQuantizedTypesPass()";
}

def QuantizerRemoveInstrumentation : Pass<"quantizer-remove-instrumentation"> {
  let summary = "Remove instrumentation and hints which have no effect on "
                "final execution";
  let constructor = "mlir::quantizer::createRemoveInstrumentationPass()";
}

#endif // MLIR_QUANTIZER_TRANSFORMS_PASSES
@@ -7,7 +7,6 @@ add_subdirectory(IR)
add_subdirectory(Interfaces)
add_subdirectory(Parser)
add_subdirectory(Pass)
add_subdirectory(Quantizer)
add_subdirectory(Support)
add_subdirectory(TableGen)
add_subdirectory(Target)
@@ -1,6 +1,5 @@
add_subdirectory(Affine)
add_subdirectory(AVX512)
add_subdirectory(FxpMathOps)
add_subdirectory(GPU)
add_subdirectory(Linalg)
add_subdirectory(LLVMIR)
@@ -1,21 +0,0 @@
add_mlir_dialect_library(MLIRFxpMathOps
  IR/FxpMathOps.cpp
  Transforms/LowerUniformRealMath.cpp

  ADDITIONAL_HEADER_DIRS
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/FxpMathOps

  DEPENDS
  MLIRFxpMathOpsIncGen
  MLIRFxpMathPassIncGen
  )

target_link_libraries(MLIRFxpMathOps
  PUBLIC
  MLIRQuant
  MLIRIR
  MLIRPass
  MLIRSideEffects
  MLIRSupport
  MLIRStandardOps
  )
@@ -1,29 +0,0 @@
//===- FxpMathOps.cpp - Op implementation for FxpMathOps ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/MathExtras.h"

using namespace mlir;
using namespace mlir::fxpmath;

#define GET_OP_CLASSES
#include "mlir/Dialect/FxpMathOps/FxpMathOps.cpp.inc"

FxpMathOpsDialect::FxpMathOpsDialect(MLIRContext *context)
    : Dialect(/*name=*/"fxpmath", context) {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/FxpMathOps/FxpMathOps.cpp.inc"
      >();
}
@@ -1,394 +0,0 @@
//===- LowerUniformRealMath.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "UniformKernelUtils.h"

#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
#include "mlir/Dialect/FxpMathOps/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::fxpmath;
using namespace mlir::fxpmath::detail;
using namespace mlir::quant;

namespace {
struct LowerUniformRealMathPass
    : public FunctionPass<LowerUniformRealMathPass> {
  /// Include the generated pass utilities.
#define GEN_PASS_FxpMathLowerUniformRealMath
#include "mlir/Dialect/FxpMathOps/Passes.h.inc"

  void runOnFunction() override;
};

struct LowerUniformCastsPass : public FunctionPass<LowerUniformCastsPass> {
  /// Include the generated pass utilities.
#define GEN_PASS_FxpMathLowerUniformCasts
#include "mlir/Dialect/FxpMathOps/Passes.h.inc"

  void runOnFunction() override;
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
// Dequantize
//===----------------------------------------------------------------------===//

static Value emitUniformPerLayerDequantize(Location loc, Value input,
                                           UniformQuantizedType elementType,
                                           PatternRewriter &rewriter) {
  // Pre-conditions.
  if (!elementType.isSigned()) {
    // TODO: Support unsigned storage type.
    emitWarning(loc, "unimplemented: dequantize unsigned uniform");
    return nullptr;
  }

  Type storageType = elementType.castToStorageType(input.getType());
  Type realType = elementType.castToExpressedType(input.getType());
  Type intermediateType =
      castElementType(storageType, IntegerType::get(32, rewriter.getContext()));
  assert(storageType && "cannot cast to storage type");
  assert(realType && "cannot cast to expressed type");

  // Cast to storage type.
  input = rewriter.create<StorageCastOp>(loc, storageType, input);

  // Promote to intermediate type.
  input = rewriter.create<ConvertISOp>(loc, intermediateType, input);

  // Apply zero-point offset.
  if (elementType.getZeroPoint() != 0) {
    Value negZeroPointConst = rewriter.create<ConstantOp>(
        loc, broadcastScalarConstIntValue(intermediateType,
                                          -elementType.getZeroPoint()));
    input = rewriter.create<AddIOp>(loc, input, negZeroPointConst);
  }

  // Convert to float.
  input = rewriter.create<ConvertISToFOp>(loc, realType, input);

  // Mul by scale.
  Value scaleConst = rewriter.create<ConstantOp>(
      loc, broadcastScalarConstFloatValue(realType,
                                          APFloat(elementType.getScale())));
  return rewriter.create<MulFOp>(loc, input, scaleConst);
}

static Value
emitUniformPerAxisDequantize(Location loc, Value input,
                             UniformQuantizedPerAxisType elementType,
                             PatternRewriter &rewriter) {
  // TODO: Support per-axis dequantize.
  rewriter.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Warning)
      << "unimplemented: per-axis uniform dequantization";
  return nullptr;
}

static Value emitDequantize(Location loc, Value input,
                            PatternRewriter &rewriter) {
  Type inputType = input.getType();
  QuantizedType qElementType =
      QuantizedType::getQuantizedElementType(inputType);
  if (auto uperLayerElementType =
          qElementType.dyn_cast_or_null<UniformQuantizedType>()) {
    return emitUniformPerLayerDequantize(loc, input, uperLayerElementType,
                                         rewriter);
  } else if (auto uperAxisElementType =
                 qElementType.dyn_cast_or_null<UniformQuantizedPerAxisType>()) {
    return emitUniformPerAxisDequantize(loc, input, uperAxisElementType,
                                        rewriter);
  } else {
    return nullptr;
  }
}

namespace {

struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> {
  using OpRewritePattern<DequantizeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DequantizeCastOp op,
                                PatternRewriter &rewriter) const override {
    Type inputType = op.arg().getType();
    Type outputType = op.getResult().getType();

    QuantizedType inputElementType =
        QuantizedType::getQuantizedElementType(inputType);
    Type expressedOutputType = inputElementType.castToExpressedType(inputType);
    if (expressedOutputType != outputType) {
      // Not a valid uniform cast.
      return failure();
    }

    Value dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter);
    if (!dequantizedValue) {
      return failure();
    }

    rewriter.replaceOp(op, dequantizedValue);
    return success();
  }
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
// Elementwise add
//===----------------------------------------------------------------------===//

static LogicalResult
tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info,
                                      PatternRewriter &rewriter) {
  if (!info.resultType.isSigned() || info.lhsType != info.resultType ||
      info.rhsType != info.resultType) {
    return failure();
  }

  // Choose a byte aligned intermediate width big enough to perform the
  // calculation without overflow.
  // TODO: This should probably be made just big enough to avoid overflow and
  // leave the downstream tooling to decide how to align that to machine
  // word sizes.
  unsigned intermediateWidth =
      info.resultType.getStorageTypeIntegralWidth() <= 8 ? 16 : 32;
  IntegerType intermediateElementType =
      IntegerType::get(intermediateWidth, rewriter.getContext());
  Type intermediateType =
      castElementType(info.resultStorageType, intermediateElementType);

  // Cast operands to storage type.
  Value lhsValue = rewriter
                       .create<StorageCastOp>(info.op->getLoc(),
                                              info.lhsStorageType, info.lhs)
                       .getResult();
  Value rhsValue = rewriter
                       .create<StorageCastOp>(info.op->getLoc(),
                                              info.rhsStorageType, info.rhs)
                       .getResult();

  // Cast to the intermediate sized type.
  lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
                                          lhsValue);
  rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
                                          rhsValue);

  // Add.
  Value resultValue =
      rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, rhsValue);

  // Zero point offset adjustment.
  // result = (lhs - zp) + (rhs - zp) + zp
  // zpOffset = -zp
  int zpOffset = -1 * info.resultType.getZeroPoint();
  if (zpOffset != 0) {
    Value zpOffsetConst = rewriter.create<ConstantOp>(
        info.op->getLoc(),
        broadcastScalarConstIntValue(intermediateType, zpOffset));
    resultValue =
        rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
  }

  // Clamp.
  auto clampMinMax = info.getClampMinMax(intermediateElementType);
  resultValue = rewriter.create<ClampISOp>(
      info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);

  // Convert back to original type.
  resultValue = rewriter.create<ConvertISOp>(
      info.op->getLoc(), info.resultStorageType, resultValue);

  // Cast back for new result.
  rewriter.replaceOpWithNewOp<StorageCastOp>(
      info.op, info.getQuantizedResultType(), resultValue);

  return success();
}

//===----------------------------------------------------------------------===//
// Elementwise mul
//===----------------------------------------------------------------------===//

static LogicalResult
tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info,
                            PatternRewriter &rewriter) {
  if (!info.resultType.isSigned()) {
    return failure();
  }

  double outputMultiplierReal = info.lhsType.getScale() *
                                info.rhsType.getScale() /
                                info.resultType.getScale();
  if (outputMultiplierReal > 1.0) {
    info.op->emitWarning(
        "unimplemented: cannot multiply with multiplier > 1.0");
    return failure();
  }

  // TODO: Choose an appropriate intermediate width for muls > 8 bits to
  // avoid overflow.
  unsigned intermediateWidth = 32;
  IntegerType intermediateElementType =
      IntegerType::get(intermediateWidth, rewriter.getContext());
  Type intermediateType =
      castElementType(info.resultStorageType, intermediateElementType);

  // Cast operands to storage type.
  Value lhsValue = rewriter
                       .create<StorageCastOp>(info.op->getLoc(),
                                              info.lhsStorageType, info.lhs)
                       .getResult();
  Value rhsValue = rewriter
                       .create<StorageCastOp>(info.op->getLoc(),
                                              info.rhsStorageType, info.rhs)
                       .getResult();

  // Cast to the intermediate sized type.
  lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
                                          lhsValue);
  rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
                                          rhsValue);

  // Apply argument zeroPoints.
  if (info.lhsType.getZeroPoint() != 0) {
    Value zpOffsetConst = rewriter.create<ConstantOp>(
        info.op->getLoc(), broadcastScalarConstIntValue(
                               intermediateType, -info.lhsType.getZeroPoint()));
    lhsValue =
        rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, zpOffsetConst);
  }

  if (info.rhsType.getZeroPoint() != 0) {
    Value zpOffsetConst = rewriter.create<ConstantOp>(
        info.op->getLoc(), broadcastScalarConstIntValue(
                               intermediateType, -info.rhsType.getZeroPoint()));
    rhsValue =
        rewriter.create<AddIOp>(info.op->getLoc(), rhsValue, zpOffsetConst);
  }

  // Mul.
  Value resultValue =
      rewriter.create<MulIOp>(info.op->getLoc(), lhsValue, rhsValue);

  // Scale output.
  QuantizedMultiplierSmallerThanOneExp outputMultiplier(outputMultiplierReal);
  resultValue = rewriter.create<VecScalarSaturatingRoundingDoublingHighMulISOp>(
      info.op->getLoc(), resultValue,
      IntegerAttr::get(intermediateElementType, outputMultiplier.multiplier));
  resultValue = rewriter.create<RoundingDivideByPotISOp>(
      info.op->getLoc(), resultValue,
      IntegerAttr::get(intermediateElementType, -outputMultiplier.exponent));

  // Zero point offset adjustment.
  if (info.resultType.getZeroPoint() != 0) {
    Value zpOffsetConst = rewriter.create<ConstantOp>(
        info.op->getLoc(),
        broadcastScalarConstIntValue(intermediateType,
                                     info.resultType.getZeroPoint()));
    resultValue =
        rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
  }

  // Clamp.
  auto clampMinMax = info.getClampMinMax(intermediateElementType);
  resultValue = rewriter.create<ClampISOp>(
      info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);

  // Convert back to original type.
  resultValue = rewriter.create<ConvertISOp>(
      info.op->getLoc(), info.resultStorageType, resultValue);

  // Cast back for new result.
  rewriter.replaceOpWithNewOp<StorageCastOp>(
      info.op, info.getQuantizedResultType(), resultValue);

  return success();
}

namespace {

struct UniformRealAddEwPattern : public OpRewritePattern<RealAddEwOp> {
  using OpRewritePattern<RealAddEwOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(RealAddEwOp op,
                                PatternRewriter &rewriter) const override {
    const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
                                   op.clamp_max());
    if (!info.isValid()) {
      return failure();
    }

    // Try all of the permutations we support.
    if (succeeded(tryRewriteAffineAddEwIsomorphicSigned(info, rewriter))) {
      return success();
    }

    return failure();
  }
};

struct UniformRealMulEwPattern : public OpRewritePattern<RealMulEwOp> {
  using OpRewritePattern<RealMulEwOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(RealMulEwOp op,
                                PatternRewriter &rewriter) const override {
    const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
                                   op.clamp_max());
    if (!info.isValid()) {
      return failure();
    }

    // Try all of the permutations we support.
    if (succeeded(tryRewriteAffineMulEwSigned(info, rewriter))) {
      return success();
    }

    return failure();
  }
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
// LowerUniformRealMath pass
//===----------------------------------------------------------------------===//

void LowerUniformRealMathPass::runOnFunction() {
  auto fn = getFunction();
  OwningRewritePatternList patterns;
  auto *context = &getContext();
  patterns.insert<UniformRealAddEwPattern, UniformRealMulEwPattern>(context);
  applyPatternsGreedily(fn, patterns);
}

std::unique_ptr<OpPassBase<FuncOp>>
mlir::fxpmath::createLowerUniformRealMathPass() {
  return std::make_unique<LowerUniformRealMathPass>();
}

//===----------------------------------------------------------------------===//
// LowerUniformCasts pass
//===----------------------------------------------------------------------===//

void LowerUniformCastsPass::runOnFunction() {
  auto fn = getFunction();
  OwningRewritePatternList patterns;
  auto *context = &getContext();
  patterns.insert<UniformDequantizePattern>(context);
  applyPatternsGreedily(fn, patterns);
}

std::unique_ptr<OpPassBase<FuncOp>>
mlir::fxpmath::createLowerUniformCastsPass() {
  return std::make_unique<LowerUniformCastsPass>();
}
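The mul lowering above scales its integer product by a real multiplier that is strictly less than one, decomposed by QuantizedMultiplierSmallerThanOneExp into a fixed-point multiplier and a power-of-two exponent. The following sketch shows one plausible decomposition (an assumption about that helper's contract, inferred from how the pass consumes its .multiplier and .exponent fields): a real value in (0, 1] becomes a Q31 multiplier m and exponent e with real ~= (m / 2^31) * 2^e, which the generated ops then apply as a saturating rounding doubling high multiply followed by a rounding divide by a power of two.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical decomposition of a real multiplier in (0, 1] into a Q31
// fixed-point multiplier and an exponent.
void decomposeMultiplier(double real, int32_t &multiplier, int &exponent) {
  // frexp yields frac in [0.5, 1) with real == frac * 2^exponent.
  double frac = std::frexp(real, &exponent);
  // Scale the fraction into Q31.
  int64_t q = static_cast<int64_t>(std::round(frac * (1ll << 31)));
  if (q == (1ll << 31)) { // Rounding carried all the way up; renormalize.
    q /= 2;
    ++exponent;
  }
  multiplier = static_cast<int32_t>(q);
}
```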
@ -1,227 +0,0 @@
|
|||
//===- UniformKernelUtils.h - Utilities for lowering uniform math - C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
#define MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_

#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/Dialect/Quant/UniformSupport.h"
#include "mlir/IR/Operation.h"

#include <cmath>

namespace mlir {
namespace fxpmath {
namespace detail {

inline quant::UniformQuantizedType getUniformElementType(Type t) {
  return quant::QuantizedType::getQuantizedElementType(t)
      .dyn_cast_or_null<quant::UniformQuantizedType>();
}

inline bool hasStorageBitWidth(quant::QuantizedType t,
                               ArrayRef<unsigned> checkWidths) {
  unsigned w = t.getStorageType().getIntOrFloatBitWidth();
  for (unsigned checkWidth : checkWidths) {
    if (w == checkWidth)
      return true;
  }
  return false;
}

/// Computes the log2(x), rounded to an integral value. Returns whether 'x' can
/// be considered an exact integral value.
template <typename F> bool integralLog2(F x, int &log2Result) {
  const F xLog2 = std::log(x) * (1.0 / std::log(2.0));
  const F xLog2Rounded = std::round(xLog2);
  const F xLog2Frac = xLog2 - xLog2Rounded;
  log2Result = static_cast<int>(xLog2Rounded);
  // Allow small comparison slop below the level that would make a difference
  // for 2^16 levels.
  return std::abs(xLog2Frac) < 1e-6;
}

/// Helper class for operating on binary operations where all operands
/// and the result are a UniformQuantizedType.
struct UniformBinaryOpInfo {
  UniformBinaryOpInfo(Operation *op, Value lhs, Value rhs,
                      Optional<APFloat> clampMin, Optional<APFloat> clampMax)
      : op(op), lhs(lhs), rhs(rhs), clampMin(clampMin), clampMax(clampMax),
        lhsType(getUniformElementType(lhs.getType())),
        rhsType(getUniformElementType(rhs.getType())),
        resultType(getUniformElementType(*op->result_type_begin())),
        lhsStorageType(quant::QuantizedType::castToStorageType(lhs.getType())),
        rhsStorageType(quant::QuantizedType::castToStorageType(rhs.getType())),
        resultStorageType(
            quant::QuantizedType::castToStorageType(*op->result_type_begin())) {
  }

  /// Returns whether this info is valid (all types defined, etc).
  bool isValid() const {
    return lhsType && rhsType && resultType && lhsStorageType &&
           rhsStorageType && resultStorageType;
  }

  /// Gets the final quantized result type of the result.
  Type getQuantizedResultType() const { return *op->result_type_begin(); }

  /// Returns whether the storage type of all operands is identical.
  bool isSameStorageType() const {
    return lhsType.getStorageType() == rhsType.getStorageType() &&
           lhsType.getStorageType() == resultType.getStorageType();
  }

  /// Returns whether all operands and result are considered fixedpoint power
  /// of two, setting the lhs, rhs, and result log2 scale references.
  bool isFixedPointPOT(int &lhsLog2Scale, int &rhsLog2Scale,
                       int &resultLog2Scale) const {
    if (!lhsType.isFixedPoint() || !rhsType.isFixedPoint() ||
        !resultType.isFixedPoint()) {
      return false;
    }

    if (!integralLog2(lhsType.getScale(), lhsLog2Scale) ||
        !integralLog2(rhsType.getScale(), rhsLog2Scale) ||
        !integralLog2(resultType.getScale(), resultLog2Scale)) {
      return false;
    }

    return true;
  }

  /// Gets the result integer clamp range given the result quantized type
  /// and any explicit clamp provided as attributes.
  std::pair<IntegerAttr, IntegerAttr> getClampMinMax(IntegerType ty) const {
    int64_t typeMin = resultType.getStorageTypeMin();
    int64_t typeMax = resultType.getStorageTypeMax();

    if (clampMin || clampMax) {
      quant::UniformQuantizedValueConverter conv(resultType);
      if (clampMin) {
        typeMin = std::max(typeMin, conv.quantizeFloatToInt64(*clampMin));
      }
      if (clampMax) {
        typeMax = std::min(typeMax, conv.quantizeFloatToInt64(*clampMax));
      }
    }

    // The quantized, integral ops expect clamps as 32bit ints.
    return {
        IntegerAttr::get(ty, typeMin),
        IntegerAttr::get(ty, typeMax),
    };
  }

  Operation *op;
  Value lhs;
  Value rhs;
  Optional<APFloat> clampMin;
  Optional<APFloat> clampMax;

  // Element UniformQuantizedType for operands/result.
  quant::UniformQuantizedType lhsType;
  quant::UniformQuantizedType rhsType;
  quant::UniformQuantizedType resultType;

  // Full storage-based types.
  Type lhsStorageType;
  Type rhsStorageType;
  Type resultStorageType;
};

/// Derives a quantized multiplier and shift from a real valued multiplier
/// less than 1.
struct QuantizedMultiplierSmallerThanOneExp {
  QuantizedMultiplierSmallerThanOneExp(double realMultiplier) {
    assert(realMultiplier < 1.0);
    assert(realMultiplier > 0.0);

    const double q = std::frexp(realMultiplier, &exponent);
    auto qFixed = static_cast<int64_t>(std::round(q * (1ll << 31)));
    assert(qFixed <= (1ll << 31));
    if (qFixed == (1ll << 31)) {
      qFixed /= 2;
      ++exponent;
    }
    assert(qFixed <= std::numeric_limits<int32_t>::max());
    multiplier = static_cast<int32_t>(qFixed);
  }

  int32_t multiplier;
  int exponent;
};

/// Casts an integer or floating point based shaped type to a new element type.
inline Type castElementType(Type t, Type newElementType) {
  if (auto st = t.dyn_cast<ShapedType>()) {
    switch (st.getKind()) {
    case StandardTypes::Kind::Vector:
      return VectorType::get(st.getShape(), newElementType);
    case StandardTypes::Kind::RankedTensor:
      return RankedTensorType::get(st.getShape(), newElementType);
    case StandardTypes::Kind::UnrankedTensor:
      return UnrankedTensorType::get(newElementType);
    case StandardTypes::Kind::MemRef:
      return MemRefType::Builder(st.cast<MemRefType>())
          .setElementType(newElementType);
    }
  }
  assert(t.isSignlessIntOrFloat());
  return newElementType;
}

/// Creates an IntegerAttr with a type that matches the shape of 't' (which can
/// be a scalar primitive or a shaped type).
inline Attribute broadcastScalarConstIntValue(Type t, int64_t value) {
  if (auto st = t.dyn_cast<ShapedType>()) {
    assert(st.getElementType().isSignlessInteger());
    return DenseElementsAttr::get(st,
                                  IntegerAttr::get(st.getElementType(), value));
  }

  auto integerType = t.cast<IntegerType>();
  assert(t.isSignlessInteger() && "integer broadcast must be of integer type");
  return IntegerAttr::get(integerType, value);
}

/// Given an APFloat, converts it to the float semantics that matches the
/// given FloatType, silently ignoring inexact conversions.
inline APFloat convertFloatToType(FloatType ft, APFloat value) {
  bool losesInfo;
  auto status = value.convert(ft.getFloatSemantics(),
                              APFloat::rmNearestTiesToEven, &losesInfo);
  (void)status; // unused in opt mode
  assert((status & (APFloat::opDivByZero | APFloat::opInvalidOp)) == 0 &&
         "could not convert to float const");
  return value;
}

/// Creates a FloatAttr with a type that matches the shape of 't' (which can be
/// a scalar primitive or a shaped type).
inline Attribute broadcastScalarConstFloatValue(Type t, APFloat value) {
  if (auto st = t.dyn_cast<ShapedType>()) {
    FloatType floatElementType = st.getElementType().dyn_cast<FloatType>();
    assert(floatElementType &&
           "float broadcast element type must be float like");
    APFloat apValue = convertFloatToType(floatElementType, value);
    return DenseElementsAttr::get(st,
                                  FloatAttr::get(st.getElementType(), apValue));
  } else {
    auto floatType = t.dyn_cast<FloatType>();
    assert(floatType && "float broadcast must be of float type");
    APFloat apValue = convertFloatToType(floatType, value);
    return FloatAttr::get(floatType, apValue);
  }
}

} // namespace detail
} // namespace fxpmath
} // namespace mlir

#endif // MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
@@ -1,63 +0,0 @@
# Support.
add_mlir_library(MLIRQuantizerSupport
  Support/Configuration.cpp
  Support/ConstraintAnalysisGraph.cpp
  Support/Metadata.cpp
  Support/Statistics.cpp
  Support/TypeUtils.cpp
  Support/UniformConstraints.cpp
  Support/UniformSolvers.cpp

  ADDITIONAL_HEADER_DIRS
  )

target_link_libraries(MLIRQuantizerSupport
  PUBLIC
  MLIRIR
  MLIRQuant
  MLIRSupport
  MLIRStandardOps
  LLVMSupport
  )

# Configurations.
add_mlir_library(MLIRQuantizerFxpMathConfig
  Configurations/FxpMathConfig.cpp

  ADDITIONAL_HEADER_DIRS

  DEPENDS
  MLIRFxpMathOpsIncGen
  )

target_link_libraries(MLIRQuantizerFxpMathConfig
  PUBLIC
  MLIRIR
  MLIRFxpMathOps
  MLIRQuant
  MLIRQuantizerSupport
  MLIRStandardOps
  MLIRSupport
  LLVMSupport
  )

# Transforms.
add_mlir_library(MLIRQuantizerTransforms
  Transforms/AddDefaultStatsTestPass.cpp
  Transforms/InferQuantizedTypesPass.cpp
  Transforms/RemoveInstrumentationPass.cpp

  ADDITIONAL_HEADER_DIRS

  DEPENDS
  MLIRQuantizerPassIncGen
  )
target_link_libraries(MLIRQuantizerTransforms
  PUBLIC
  MLIRIR
  MLIRQuantizerFxpMathConfig
  MLIRQuantizerSupport
  MLIRQuant
  MLIRPass
  LLVMSupport
  )
@@ -1,278 +0,0 @@
//===- FxpMathConfig.cpp - Reference fixed point config -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a TargetConfiguration for a reference fixed-point math
// quantization scheme based on the FxpMathOps (plus a small category of
// extension ops that can be added from other dialects).
//
//===----------------------------------------------------------------------===//

#include "mlir/Quantizer/Configurations/FxpMathConfig.h"

#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
#include "mlir/Quantizer/Support/Metadata.h"
#include "mlir/Quantizer/Support/Statistics.h"
#include "mlir/Quantizer/Support/UniformConstraints.h"

using namespace mlir;
using namespace mlir::quantizer;
using namespace mlir::fxpmath;
using namespace mlir::quant;
using namespace std::placeholders;

namespace {

struct FxpMathTargetConfigImpl : public FxpMathTargetConfig {
  FxpMathTargetConfigImpl(SolverContext &context)
      : FxpMathTargetConfig(context) {
    Builder b(&context.getMlirContext());
    IntegerType i8Type = b.getIntegerType(8);
    IntegerType i16Type = b.getIntegerType(16);
    IntegerType i32Type = b.getIntegerType(32);

    q8 = addCandidateType(
        AnyQuantizedType::get(QuantizationFlags::Signed, i8Type, nullptr,
                              std::numeric_limits<int8_t>::min(),
                              std::numeric_limits<int8_t>::max()),
        CandidateQuantizedType::Scheme::UniformPerLayer);
    q16 = addCandidateType(
        AnyQuantizedType::get(QuantizationFlags::Signed, i16Type, nullptr,
                              std::numeric_limits<int16_t>::min(),
                              std::numeric_limits<int16_t>::max()),
        CandidateQuantizedType::Scheme::UniformPerLayer);
    q32ExplicitFixedPoint = addCandidateType(
        AnyQuantizedType::get(QuantizationFlags::Signed, i32Type, nullptr,
                              std::numeric_limits<int32_t>::min(),
                              std::numeric_limits<int32_t>::max()),
        CandidateQuantizedType::Scheme::UniformExplicitFixedPointScale);

    // Op handlers.
    addOpHandler<ConstantOp>(
        std::bind(&FxpMathTargetConfigImpl::handleConstant, this, _1, _2));
    addOpHandler<mlir::ReturnOp>(
        std::bind(&FxpMathTargetConfigImpl::handleTerminal, this, _1, _2));
    addOpHandler<quant::StatisticsOp>(
        std::bind(&FxpMathTargetConfigImpl::handleStats, this, _1, _2));

    // FxpMathOps.
    addOpHandler<RealAddEwOp>(
        std::bind(&FxpMathTargetConfigImpl::handleAdd, this, _1, _2));
    addOpHandler<RealMulEwOp>(
        std::bind(&FxpMathTargetConfigImpl::handleMul, this, _1, _2));
    addOpHandler<RealMatMulOp>(
        std::bind(&FxpMathTargetConfigImpl::handleMatMul, this, _1, _2));
    addOpHandler<RealMatMulBiasOp>(
        std::bind(&FxpMathTargetConfigImpl::handleMatMulBias, this, _1, _2));

    // Require stats ops.
    addRequireStatsOp<RealAddEwOp>();
    addRequireStatsOp<RealSubEwOp>();
    addRequireStatsOp<RealDivEwOp>();
    addRequireStatsOp<RealMulEwOp>();
    addRequireStatsOp<RealMatMulOp>();
    addRequireStatsOp<RealMatMulBiasOp>();
  }

  bool isHandledType(Type t) const final {
    if (t.isa<FloatType>())
      return true;
    return (t.isa<VectorType>() || t.isa<TensorType>()) &&
           t.cast<ShapedType>().getElementType().isa<FloatType>();
  }

  void finalizeAnchors(CAGSlice &cag) const override {
    cag.enumerateImpliedConnections(
        [&](CAGAnchorNode *from, CAGAnchorNode *to) {
          UniformConstraintsBuilder(cag).coupleAnchors(from, to);
        });
  }

  void addValueIdentityOpByName(StringRef opName) override {
    addOpHandlerByName(
        opName,
        std::bind(&FxpMathTargetConfigImpl::handleValueIdentity, this, _1, _2));
  }

  void handleValueIdentity(Operation *op, CAGSlice &cag) const {
    assert(op->getNumResults() == 1);
    if (!isHandledType(op->getResult(0).getType()))
      return;

    auto resultNode = cag.getResultAnchor(op, 0);
    resultNode->setTypeTransformRule(
        CAGAnchorNode::TypeTransformRule::DirectStorage);

    for (unsigned opIdx = 0, e = op->getNumOperands(); opIdx < e; ++opIdx) {
      if (!isHandledType(op->getOperand(opIdx).getType()))
        continue;
      auto operandNode = cag.getOperandAnchor(op, opIdx);
      operandNode->setTypeTransformRule(
          CAGAnchorNode::TypeTransformRule::DirectStorage);
      UniformConstraintsBuilder(cag).coupleAnchors(operandNode, resultNode);
    }
  }

  void handleConstant(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0).getType()))
      return;

    auto resultNode = cag.getResultAnchor(op, 0);
    resultNode->setTypeTransformRule(
        CAGAnchorNode::TypeTransformRule::ExpressedOnly);
    Attribute valueAttr;
    if (!matchPattern(op, m_Constant(&valueAttr))) {
      return;
    }

    AttributeTensorStatistics stats(valueAttr);
    TensorAxisStatistics layerStats;
    if (!stats.get(layerStats)) {
      op->emitOpError("could not compute statistics");
      return;
    }

    UniformConstraintsBuilder(cag).applyStats(resultNode, layerStats);
  }

  void handleTerminal(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getOperand(0).getType()))
      return;
    auto operandNode = cag.getOperandAnchor(op, 0);
    operandNode->setTypeTransformRule(
        CAGAnchorNode::TypeTransformRule::ExpressedOnly);
  }

  void handleStats(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0).getType()))
      return;

    auto argNode = cag.getOperandAnchor(op, 0);
    auto resultNode = cag.getResultAnchor(op, 0);
    UniformConstraintsBuilder(cag).coupleAnchors(argNode, resultNode);

    TensorAxisStatistics layerStats;
    auto statsOp = cast<quant::StatisticsOp>(op);
    auto layerStatsAttr = statsOp.layerStats();
    layerStats.minValue =
        layerStatsAttr.getValue<FloatAttr>(0).getValueAsDouble();
    layerStats.maxValue =
        layerStatsAttr.getValue<FloatAttr>(1).getValueAsDouble();
    UniformConstraintsBuilder(cag).applyStats(resultNode, layerStats);
  }

  void handleAdd(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0).getType()))
      return;

    auto lhs = cag.getOperandAnchor(op, 0);
    auto rhs = cag.getOperandAnchor(op, 1);
    auto resultNode = cag.getResultAnchor(op, 0);
    // Add supports 8/16 bit math.
    llvm::SmallBitVector disableMask =
        getCandidateTypeDisabledExceptMask({q8, q16});
    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
    // NOTE: We couple the add such that the scale/zeroPoint match between
    // both args and the result. This is overly constrained in that it is
    // possible to write efficient add kernels with a bit more freedom (i.e.
    // zeroPoints can vary, scales can differ by a power of two, etc).
    // However, fully coupled yields the simplest solutions on the fast path.
    // Further efficiency can be had by constraining the zeroPoint to 0, but
    // there isn't a constraint for this yet (and there are tradeoffs).
    UniformConstraintsBuilder(cag).coupleAnchors(lhs, resultNode);
    UniformConstraintsBuilder(cag).coupleAnchors(rhs, resultNode);
    addRealMathOptionalConstraints(op, resultNode, cag);
  }

  void handleMul(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0).getType()))
      return;

    auto lhs = cag.getOperandAnchor(op, 0);
    auto rhs = cag.getOperandAnchor(op, 1);
    auto resultNode = cag.getResultAnchor(op, 0);
    // Mul supports 8/16 bit math.
    llvm::SmallBitVector disableMask =
        getCandidateTypeDisabledExceptMask({q8, q16});
    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
    addRealMathOptionalConstraints(op, resultNode, cag);
  }

  void handleMatMul(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0).getType()))
      return;

    auto lhs = cag.getOperandAnchor(op, 0);
    auto rhs = cag.getOperandAnchor(op, 1);
    auto resultNode = cag.getResultAnchor(op, 0);
    // MatMul supports 8/16 bit math.
    llvm::SmallBitVector disableMask =
        getCandidateTypeDisabledExceptMask({q8, q16});
    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
    addRealMathOptionalConstraints(op, resultNode, cag);
  }

  void handleMatMulBias(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0).getType()))
      return;

    auto lhs = cag.getOperandAnchor(op, 0);
    auto rhs = cag.getOperandAnchor(op, 1);
    auto bias = cag.getOperandAnchor(op, 2);
    bias->getUniformMetadata().disabledCandidateTypes =
        getCandidateTypeDisabledExceptMask({q32ExplicitFixedPoint});

    auto resultNode = cag.getResultAnchor(op, 0);
    UniformConstraintsBuilder(cag).propagateExplicitScale(resultNode, bias);

    // MatMulBias supports 8/16 bit math.
    llvm::SmallBitVector disableMask =
        getCandidateTypeDisabledExceptMask({q8, q16});
    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
    addRealMathOptionalConstraints(op, resultNode, cag);
  }

  void addRealMathOptionalConstraints(Operation *op, CAGAnchorNode *anchor,
                                      CAGSlice &cag) const {
    // TODO: It would be nice if these all extended some base trait instead
    // of requiring name lookup.
    auto clampMinAttr = op->getAttrOfType<FloatAttr>("clamp_min");
    auto clampMaxAttr = op->getAttrOfType<FloatAttr>("clamp_max");

    if (clampMinAttr || clampMaxAttr) {
      auto nan = APFloat::getQNaN(APFloat::IEEEdouble());
      auto clampMin = clampMinAttr ? clampMinAttr.getValue() : nan;
      auto clampMax = clampMaxAttr ? clampMaxAttr.getValue() : nan;
      UniformConstraintsBuilder(cag).clamp(anchor, clampMin, clampMax);
    }
  }

  unsigned q8;
  unsigned q16;
  unsigned q32ExplicitFixedPoint;
};

} // anonymous namespace

std::unique_ptr<FxpMathTargetConfig>
FxpMathTargetConfig::create(SolverContext &context) {
  return std::make_unique<FxpMathTargetConfigImpl>(context);
}
@@ -1,39 +0,0 @@
//===- Configuration.cpp - Configuration object base classes --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Quantizer/Support/Configuration.h"

#include <limits>

#include "mlir/IR/Builders.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;
using namespace mlir::quantizer;

TargetConfiguration::TargetConfiguration(SolverContext &context) {}

void TargetConfiguration::addOpHandlerByName(StringRef name, OpHandlerFn fn) {
  opHandlers[name] = fn;
}

void TargetConfiguration::addRequireStatsOpByName(StringRef opName) {
  requireStatsOpNames.insert(opName);
}

bool TargetConfiguration::isRequireStatsOp(Operation *op) const {
  return requireStatsOpNames.find(op->getName().getStringRef()) !=
         requireStatsOpNames.end();
}

void TargetConfiguration::handleOp(Operation *op, CAGSlice &cag) const {
  auto found_it = opHandlers.find(op->getName().getStringRef());
  if (found_it != opHandlers.end())
    found_it->second(op, cag);
}
@@ -1,172 +0,0 @@
//===- ConstraintAnalysisGraph.cpp - Graphs type for constraints ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"

#include "mlir/IR/MLIRContext.h"
#include "mlir/Quantizer/Support/Configuration.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::quantizer;

void CAGNode::replaceIncoming(CAGNode *otherNode) {
  if (this == otherNode)
    return;
  for (CAGNode *parentNode : incoming) {
    for (CAGNode *&it : parentNode->outgoing) {
      if (it == this) {
        it = otherNode;
        otherNode->incoming.push_back(parentNode);
      }
    }
  }
  incoming.clear();
}

void CAGNode::addOutgoing(CAGNode *toNode) {
  if (!llvm::is_contained(outgoing, toNode)) {
    outgoing.push_back(toNode);
    toNode->incoming.push_back(this);
  }
}

CAGOperandAnchor::CAGOperandAnchor(Operation *op, unsigned operandIdx)
    : CAGAnchorNode(Kind::OperandAnchor, op->getOperand(operandIdx).getType()),
      op(op), operandIdx(operandIdx) {}

CAGResultAnchor::CAGResultAnchor(Operation *op, unsigned resultIdx)
    : CAGAnchorNode(Kind::ResultAnchor, op->getResult(resultIdx).getType()),
      resultValue(op->getResult(resultIdx)) {}

CAGSlice::CAGSlice(SolverContext &context) : context(context) {}
CAGSlice::~CAGSlice() { llvm::DeleteContainerPointers(allNodes); }

CAGOperandAnchor *CAGSlice::getOperandAnchor(Operation *op,
                                             unsigned operandIdx) {
  assert(operandIdx < op->getNumOperands() && "illegal operand index");

  // Dedup.
  auto key = std::make_pair(op, operandIdx);
  auto foundIt = operandAnchors.find(key);
  if (foundIt != operandAnchors.end()) {
    return foundIt->second;
  }

  // Create.
  auto anchor = std::make_unique<CAGOperandAnchor>(op, operandIdx);
  auto *unowned = anchor.release();
  unowned->nodeId = allNodes.size();
  allNodes.push_back(unowned);
  operandAnchors.insert(std::make_pair(key, unowned));
  return unowned;
}

CAGResultAnchor *CAGSlice::getResultAnchor(Operation *op, unsigned resultIdx) {
  assert(resultIdx < op->getNumResults() && "illegal result index");

  // Dedup.
  auto key = std::make_pair(op, resultIdx);
  auto foundIt = resultAnchors.find(key);
  if (foundIt != resultAnchors.end()) {
    return foundIt->second;
  }

  // Create.
  auto anchor = std::make_unique<CAGResultAnchor>(op, resultIdx);
  auto *unowned = anchor.release();
  unowned->nodeId = allNodes.size();
  allNodes.push_back(unowned);
  resultAnchors.insert(std::make_pair(key, unowned));
  return unowned;
}

void CAGSlice::enumerateImpliedConnections(
    std::function<void(CAGAnchorNode *from, CAGAnchorNode *to)> callback) {
  // Discover peer identity pairs (i.e. implied edges from Result->Operand and
  // Arg->Call). Use an intermediate vector so that the callback can modify.
  std::vector<std::pair<CAGAnchorNode *, CAGAnchorNode *>> impliedPairs;
  for (auto &resultAnchorPair : resultAnchors) {
    CAGResultAnchor *resultAnchor = resultAnchorPair.second;
    Value resultValue = resultAnchor->getValue();
    for (auto &use : resultValue.getUses()) {
      Operation *operandOp = use.getOwner();
      unsigned operandIdx = use.getOperandNumber();
      auto foundIt = operandAnchors.find(std::make_pair(operandOp, operandIdx));
      if (foundIt != operandAnchors.end()) {
        impliedPairs.push_back(std::make_pair(resultAnchor, foundIt->second));
      }
    }
  }

  // Callback for each pair.
  for (auto &impliedPair : impliedPairs) {
    callback(impliedPair.first, impliedPair.second);
  }
}

unsigned CAGSlice::propagate(const TargetConfiguration &config) {
  std::vector<CAGNode *> dirtyNodes;
  dirtyNodes.reserve(allNodes.size());
  // Note that because iteration happens in nodeId order, there is no need
  // to sort in order to make deterministic. If the selection method changes,
  // a sort should be explicitly done.
  for (CAGNode *child : *this) {
    if (child->isDirty()) {
      dirtyNodes.push_back(child);
    }
  }

  if (dirtyNodes.empty()) {
    return 0;
  }
  for (auto dirtyNode : dirtyNodes) {
    dirtyNode->clearDirty();
    dirtyNode->propagate(context, config);
  }

  return dirtyNodes.size();
}

void CAGAnchorNode::propagate(SolverContext &solverContext,
                              const TargetConfiguration &config) {
  for (CAGNode *child : *this) {
    child->markDirty();
  }
}

Type CAGAnchorNode::getTransformedType() {
  if (!getUniformMetadata().selectedType) {
    return nullptr;
  }
  return getUniformMetadata().selectedType.castFromExpressedType(
      getOriginalType());
}

void CAGNode::printLabel(raw_ostream &os) const {
  os << "Node<" << static_cast<const void *>(this) << ">";
}

void CAGAnchorNode::printLabel(raw_ostream &os) const {
  getUniformMetadata().printSummary(os);
}

void CAGOperandAnchor::printLabel(raw_ostream &os) const {
  os << "Operand<";
  op->getName().print(os);
  os << "," << operandIdx;
  os << ">";
  CAGAnchorNode::printLabel(os);
}

void CAGResultAnchor::printLabel(raw_ostream &os) const {
  os << "Result<";
  getOp()->getName().print(os);
  os << ">";
  CAGAnchorNode::printLabel(os);
}
@@ -1,33 +0,0 @@
//===- Metadata.cpp - Top level types and metadata ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Quantizer/Support/Metadata.h"

#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::quantizer;

void CAGUniformMetadata::printSummary(raw_ostream &os) const {
  if (requiredRange.hasValue()) {
    os << "\n[" << requiredRange.getValue().first << ","
       << requiredRange.getValue().second << "]";
  }

  if (disabledCandidateTypes.any()) {
    os << "\n![";
    mlir::interleaveComma(disabledCandidateTypes.set_bits(), os);
    os << "]";
  }

  if (selectedType) {
    os << "\n" << selectedType;
  }
}
@@ -1,201 +0,0 @@
//===- Statistics.cpp - Collects statistics over tensors ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Quantizer/Support/Statistics.h"

#include "mlir/IR/Attributes.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::quantizer;

//===----------------------------------------------------------------------===//
// AttributeTensorStatistics implementation
//===----------------------------------------------------------------------===//

static void collectElementsStatisticsDim(ElementsAttr attr,
                                         unsigned numElements,
                                         ArrayRef<int64_t> shape,
                                         SmallVectorImpl<uint64_t> &indices,
                                         uint64_t dim,
                                         TensorAxisStatistics &statistics) {
  // Recursive terminating condition.
  if (dim >= shape.size())
    return;

  if (dim < (shape.size() - 1)) {
    // Recurse past dim.
    for (uint64_t i = 0, s = shape[dim]; i < s; ++i) {
      indices[dim] = i;
      collectElementsStatisticsDim(attr, numElements, shape, indices, dim + 1,
                                   statistics);
    }
    return;
  }

  // Collection dim.
  for (uint64_t i = 0, s = shape[dim]; i < s; ++i) {
    indices[dim] = i;
    double value = attr.getValue<FloatAttr>(indices).getValueAsDouble();
    statistics.minValue = std::min(statistics.minValue, value);
    statistics.maxValue = std::max(statistics.maxValue, value);
    statistics.mean += value / numElements;
    // TODO: Calculate a running variance.
  }
}

static void collectElementsStatisticsDimForAxis(
    unsigned axis, ElementsAttr attr, unsigned numElements,
    ArrayRef<int64_t> shape, SmallVectorImpl<uint64_t> &indices, uint64_t dim,
    TensorAxisStatistics &statistics) {
  // Recursive terminating condition.
  if (dim >= shape.size())
    return;

  // Axis is passed separately.
  if (dim == axis) {
    collectElementsStatisticsDimForAxis(axis, attr, numElements, shape, indices,
                                        dim + 1, statistics);
    return;
  }

  // Go to the last non-axis dim.
  if (dim < (shape.size() - 2) ||
      (dim == (shape.size() - 2) && axis != (shape.size() - 1))) {
    // Recurse past dim.
    for (uint64_t i = 0, s = shape[dim]; i < s; ++i) {
      indices[dim] = i;
      collectElementsStatisticsDimForAxis(axis, attr, numElements, shape,
                                          indices, dim + 1, statistics);
    }
    return;
  }

  // Pass axis.
  uint64_t axisSize = shape[axis];
  for (uint64_t axisIdx = 0; axisIdx < axisSize; ++axisIdx) {
    indices[axis] = axisIdx;
    // Collection dim.
    for (uint64_t i = 0, s = shape[dim]; i < s; ++i) {
      indices[dim] = i;
      double value = attr.getValue<FloatAttr>(indices).getValueAsDouble();
      statistics.minValuePerAxis[axisIdx] =
          std::min(statistics.minValuePerAxis[axisIdx], value);
      statistics.maxValuePerAxis[axisIdx] =
          std::max(statistics.maxValuePerAxis[axisIdx], value);
      statistics.meanPerAxis[axisIdx] += value / numElements;
      // TODO: Calculate a running variance.
    }
  }
}

static bool getElementsStatistics(ElementsAttr attr,
                                  TensorAxisStatistics &statistics) {
  ShapedType sType = attr.getType();
  if (!sType.hasStaticShape())
    return false;
  Type elementTy = sType.getElementType();
  if (!elementTy.isa<FloatType>())
    return false;

  SmallVector<uint64_t, 4> indices;
  indices.resize(sType.getRank());
  ArrayRef<int64_t> shape = sType.getShape();

  statistics.minValue = std::numeric_limits<double>::infinity();
  statistics.maxValue = -std::numeric_limits<double>::infinity();
  statistics.mean = 0;
  statistics.variance = 0;
auto numElements = sType.getNumElements();
|
||||
collectElementsStatisticsDim(attr, numElements, shape, indices, 0,
|
||||
statistics);
|
||||
statistics.sampleSize = numElements;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool getElementsStatisticsForAxis(unsigned axis, ElementsAttr attr,
|
||||
TensorAxisStatistics &statistics) {
|
||||
ShapedType sType = attr.getType();
|
||||
if (!sType.hasStaticShape() || axis >= sType.getRank())
|
||||
return false;
|
||||
Type elementTy = sType.getElementType();
|
||||
if (!elementTy.isa<FloatType>())
|
||||
return false;
|
||||
|
||||
SmallVector<uint64_t, 4> indices;
|
||||
indices.resize(sType.getRank());
|
||||
ArrayRef<int64_t> shape = sType.getShape();
|
||||
|
||||
uint64_t axisSize = shape[axis];
|
||||
statistics.minValuePerAxis.assign(axisSize,
|
||||
std::numeric_limits<double>::infinity());
|
||||
statistics.maxValuePerAxis.assign(axisSize,
|
||||
-std::numeric_limits<double>::infinity());
|
||||
statistics.meanPerAxis.assign(axisSize, 0);
|
||||
statistics.variancePerAxis.assign(axisSize, 0);
|
||||
|
||||
uint64_t numElements = sType.getNumElements() / shape[axis];
|
||||
collectElementsStatisticsDimForAxis(axis, attr, numElements, shape, indices,
|
||||
0, statistics);
|
||||
statistics.sampleSizePerAxis = numElements;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AttributeTensorStatistics::get(TensorAxisStatistics &stats) const {
|
||||
if (FloatAttr floatAttr = attr.dyn_cast<FloatAttr>()) {
|
||||
double value = floatAttr.getValueAsDouble();
|
||||
stats = TensorAxisStatistics(1, value, value, value, 0);
|
||||
return true;
|
||||
} else if (auto eltAttr = attr.dyn_cast<ElementsAttr>()) {
|
||||
return getElementsStatistics(eltAttr, stats);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool AttributeTensorStatistics::supportsPerAxis() const {
|
||||
if (auto eltAttr = attr.dyn_cast<ElementsAttr>())
|
||||
return eltAttr.getType().getRank() > 1;
|
||||
return false;
|
||||
}
|
||||
|
||||
unsigned AttributeTensorStatistics::getAxisCount() const {
|
||||
if (!supportsPerAxis())
|
||||
return 0;
|
||||
return attr.cast<ElementsAttr>().getType().getRank();
|
||||
}
|
||||
|
||||
bool AttributeTensorStatistics::getForAxis(unsigned axis,
|
||||
TensorAxisStatistics &stats) const {
|
||||
if (!supportsPerAxis())
|
||||
return false;
|
||||
auto eltAttr = attr.cast<ElementsAttr>();
|
||||
return getElementsStatisticsForAxis(axis, eltAttr, stats);
|
||||
}
|
||||
|
||||
raw_ostream &mlir::quantizer::operator<<(raw_ostream &os,
|
||||
const TensorAxisStatistics &stats) {
|
||||
os << "STATS[sampleSizeLayer=" << stats.sampleSize
|
||||
<< ", minValueLayer=" << stats.minValue
|
||||
<< ", maxValueLayer=" << stats.maxValue << ", meanLayer=" << stats.mean
|
||||
<< ", varianceLayer=" << stats.variance
|
||||
<< ", sampleSizePerAxis=" << stats.sampleSizePerAxis << ", statsPerAxis={";
|
||||
for (unsigned i = 0, n = stats.minValuePerAxis.size(); i < n; ++i) {
|
||||
os << "minValue=" << stats.minValuePerAxis[i]
|
||||
<< ", maxValue=" << stats.maxValuePerAxis[i]
|
||||
<< ", mean=" << stats.meanPerAxis[i]
|
||||
<< ", variance=" << stats.variancePerAxis[i];
|
||||
if (i != n - 1)
|
||||
os << "; ";
|
||||
}
|
||||
os << "}]";
|
||||
return os;
|
||||
}
|
|
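The removed `Statistics.cpp` above recursively walks a tensor's shape to accumulate layer-wide min, max, and incremental mean. A minimal standalone sketch of the same accumulation over a flat element buffer (the `SimpleStats`/`collectStats` names are illustrative, not the removed MLIR API):

```cpp
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Hypothetical simplified sketch (not the removed MLIR API): accumulate the
// same layer-wide statistics that collectElementsStatisticsDim gathers, but
// over a flat buffer of elements instead of a recursive shape walk.
struct SimpleStats {
  double minValue = std::numeric_limits<double>::infinity();
  double maxValue = -std::numeric_limits<double>::infinity();
  double mean = 0.0;
};

SimpleStats collectStats(const std::vector<double> &elements) {
  SimpleStats s;
  for (double v : elements) {
    s.minValue = std::min(s.minValue, v);
    s.maxValue = std::max(s.maxValue, v);
    // Incremental mean, mirroring "statistics.mean += value / numElements".
    s.mean += v / elements.size();
  }
  return s;
}
```

For `{1.0, -2.0, 3.0}` this yields min -2, max 3, and mean 2/3, which is what the recursive walk computes for a rank-1 tensor of those values.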
@@ -1,22 +0,0 @@
//===- TypeUtils.cpp - Helper function for manipulating types -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Quantizer/Support/TypeUtils.h"

#include "mlir/IR/StandardTypes.h"

using namespace mlir;
using namespace mlir::quantizer;

Type mlir::quantizer::getElementOrPrimitiveType(Type t) {
  if (auto sType = t.dyn_cast<ShapedType>()) {
    return sType.getElementType();
  } else {
    return t;
  }
}
@@ -1,256 +0,0 @@
//===- UniformConstraints.cpp - Constraints for uniform quant -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Quantizer/Support/UniformConstraints.h"

#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Quantizer/Support/Configuration.h"
#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
#include "mlir/Quantizer/Support/Metadata.h"
#include "mlir/Quantizer/Support/Rules.h"
#include "mlir/Quantizer/Support/TypeUtils.h"
#include "mlir/Quantizer/Support/UniformSolvers.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::quantizer;
using namespace mlir::quant;

namespace {

struct ClusteredFacts {
  ExpandingMinMaxFact requiredRange;
  DiscreteScaleZeroPointFact explicitScaleZeroPoint;
};

} // end anonymous namespace

static QuantizedType solveUniformType(SolverContext &solverContext,
                                      const ClusteredFacts &clusteredFacts,
                                      const CandidateQuantizedType &ct,
                                      Type originalElementType, Location loc) {
  switch (ct.scheme) {
  default:
    emitError(loc, "unsupported scheme for uniform type conversion");
    return nullptr;

  case CandidateQuantizedType::Scheme::UniformPerLayer: {
    if (!clusteredFacts.requiredRange.hasValue()) {
      // TODO: Issue some kind of diagnostic. This is not an error.
      return nullptr;
    }

    uint64_t numLevels = ct.quantizedType.getStorageTypeMax() -
                         ct.quantizedType.getStorageTypeMin();
    UniformStorageParams params{numLevels,
                                ct.quantizedType.getStorageTypeMin()};
    UniformParamsFromMinMaxSolver solver(
        params, clusteredFacts.requiredRange.getValue().first,
        clusteredFacts.requiredRange.getValue().second);
    if (!solver.compute()) {
      emitWarning(loc) << "unable to solve uniform type with "
                       << "UniformParamsFromMinMaxSolver";
      return nullptr;
    }

    return UniformQuantizedType::getChecked(
        ct.quantizedType.getFlags(), ct.quantizedType.getStorageType(),
        originalElementType, solver.getScale(), solver.getZp(),
        ct.quantizedType.getStorageTypeMin(),
        ct.quantizedType.getStorageTypeMax(), loc);
  }
  case CandidateQuantizedType::Scheme::UniformExplicitFixedPointScale: {
    if (!clusteredFacts.explicitScaleZeroPoint.hasValue()) {
      emitRemark(loc)
          << "unable to solve uniform type with UniformExplicitFixedPointScale "
          << "(no explicitScaleZeroPoint)";
      return nullptr;
    }

    const auto &scaleZp = clusteredFacts.explicitScaleZeroPoint.getValue();
    assert(scaleZp.value && "optional value not set on fact");

    if (scaleZp.conflict) {
      emitWarning(loc)
          << "conflicting explicit scale/zeroPoint on node cluster: "
          << "an arbitrary scale/zeroPoint will be used";
    }

    return UniformQuantizedType::getChecked(
        ct.quantizedType.getFlags(), ct.quantizedType.getStorageType(),
        originalElementType,
        scaleZp.value->first, // scale
        0, // zeroPoint (fixed point solutions only for this scheme)
        ct.quantizedType.getStorageTypeMin(),
        ct.quantizedType.getStorageTypeMax(), loc);

    return nullptr;
  }
  }
}

namespace {

class PropagateExplicitScale : public CAGConstraintNode {
public:
  PropagateExplicitScale()
      : CAGConstraintNode(Kind::UniformPropagateExplicitScale) {}
  static bool classof(const CAGNode *n) {
    return n->getKind() == Kind::Constraint ||
           n->getKind() == Kind::UniformPropagateExplicitScale;
  }

private:
  void printLabel(raw_ostream &os) const override {
    os << "PropagateExplicitScale";
  }
  void propagate(SolverContext &solverContext,
                 const TargetConfiguration &config) override {
    DiscreteScaleZeroPointFact scaleZp;

    // Get scale/zp from all parents.
    for (auto it = incoming_begin(), e = incoming_end(); it != e; ++it) {
      auto parentAnchor = cast<CAGAnchorNode>(*it);
      auto selectedType = parentAnchor->getUniformMetadata().selectedType;
      if (auto uqType = selectedType.dyn_cast_or_null<UniformQuantizedType>()) {
        scaleZp.assertValue(
            CAGUniformMetadata::SalienceRequired,
            std::make_pair(uqType.getScale(), static_cast<int64_t>(0)));
      }
    }

    // Propagate to children.
    if (scaleZp.hasValue()) {
      for (auto it = begin(), e = end(); it != e; ++it) {
        auto childAnchor = cast<CAGAnchorNode>(*it);
        if (modified(childAnchor->getUniformMetadata()
                         .explicitScaleZeroPoint.mergeFrom(scaleZp))) {
          childAnchor->markDirty();
        }
      }
    }
  }
};

/// A constraint node which will solve uniform quantization for all parents
/// of the constraint, assuming that they are coupled.
class SolveUniformConstraintNode : public CAGConstraintNode {
public:
  SolveUniformConstraintNode()
      : CAGConstraintNode(Kind::SolveUniformConstraint) {
    markDirty();
  }
  static bool classof(const CAGNode *n) {
    return n->getKind() == Kind::Constraint ||
           n->getKind() == Kind::SolveUniformConstraint;
  }

private:
  void printLabel(raw_ostream &os) const override { os << "SolveUniform"; }

  void propagate(SolverContext &solverContext,
                 const TargetConfiguration &config) override {
    // First determine the required min/max range and type constraints.
    Location fusedLoc = UnknownLoc::get(&solverContext.getMlirContext());
    llvm::SmallBitVector enabledCandidateTypesMask(
        config.getAllCandidateTypesMask());
    ClusteredFacts clusteredFacts;
    Type originalElementType;
    for (auto it = incoming_begin(), e = incoming_end(); it != e; ++it) {
      auto parentAnchor = cast<CAGAnchorNode>(*it);
      auto metadata = parentAnchor->getUniformMetadata();
      // TODO: Possibly use a location that fuses all involved parents.
      fusedLoc = parentAnchor->getOp()->getLoc();

      // Shared element type.
      auto parentOriginalElementType =
          getElementOrPrimitiveType(parentAnchor->getOriginalType());
      if (!originalElementType) {
        originalElementType = parentOriginalElementType;
      } else {
        if (originalElementType != parentOriginalElementType) {
          parentAnchor->getOp()->emitError()
              << "cannot compute uniform type: parent element types mismatch";
          return;
        }
      }
      // Range.
      clusteredFacts.requiredRange.mergeFrom(metadata.requiredRange);

      // Explicit scale and zero point.
      clusteredFacts.explicitScaleZeroPoint.mergeFrom(
          metadata.explicitScaleZeroPoint);

      // Shared candidate types.
      enabledCandidateTypesMask.reset(metadata.disabledCandidateTypes);
    }

    // Find the first enabled candidate type.
    const CandidateQuantizedType *bestCandidateType = nullptr;
    for (auto &ct : config.getCandidateTypes()) {
      if (enabledCandidateTypesMask.test(ct.ordinal)) {
        bestCandidateType = &ct;
        break;
      }
    }

    if (!bestCandidateType || !originalElementType) {
      emitRemark(fusedLoc)
          << "not solving uniform type (no viable candidate type)";
      return;
    }

    // Solve for the type.
    QuantizedType selectedType =
        solveUniformType(solverContext, clusteredFacts, *bestCandidateType,
                         originalElementType, fusedLoc);

    // Apply it to all parents.
    for (auto it = incoming_begin(), e = incoming_end(); it != e; ++it) {
      auto parentAnchor = cast<CAGAnchorNode>(*it);
      auto &metadata = parentAnchor->getUniformMetadata();
      if (metadata.selectedType != selectedType) {
        metadata.selectedType = selectedType;
        // And mark all children of the parent dirty (except us).
        for (auto child : *parentAnchor) {
          if (child != this) {
            child->markDirty();
          }
        }
      }
    }
  }
};

} // end anonymous namespace

void UniformConstraintsBuilder::coupleAnchors(CAGAnchorNode *a,
                                              CAGAnchorNode *b) {
  slice.addClusteredConstraint<SolveUniformConstraintNode>({a, b});
}

void UniformConstraintsBuilder::applyStats(CAGAnchorNode *a,
                                           TensorAxisStatistics stats) {
  a->getUniformMetadata().requiredRange.assertValue(
      CAGUniformMetadata::SalienceDefault, {stats.minValue, stats.maxValue});
}

void UniformConstraintsBuilder::clamp(CAGAnchorNode *a, APFloat minValue,
                                      APFloat maxValue) {
  a->getUniformMetadata().requiredRange.assertValue(
      CAGUniformMetadata::SalienceDefault,
      {minValue.convertToDouble(), maxValue.convertToDouble()});
}

void UniformConstraintsBuilder::propagateExplicitScale(CAGAnchorNode *from,
                                                       CAGAnchorNode *to) {
  slice.addUnidirectionalConstraint<PropagateExplicitScale>(from, {to});
}
@@ -1,143 +0,0 @@
//===- UniformSolvers.cpp - Uniform type solver algorithms ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Quantizer/Support/UniformSolvers.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/raw_ostream.h"
#include <cmath>

using namespace mlir;
using namespace mlir::quantizer;

bool UniformParamsFromMinMaxSolver::compute() {
  // Compute adjMin, adjMax, clamping to ensure that they straddle zero.
  if (boundingMin > 0 && boundingMax >= boundingMin) {
    // Lop-sided to the positive.
    adjMin = 0;
    adjMax = boundingMax;
  } else if (boundingMax < 0 && boundingMin <= boundingMax) {
    // Lop-sided to the negative.
    adjMin = boundingMin;
    adjMax = 0;
  } else if (boundingMin <= 0 && boundingMax >= 0) {
    adjMin = boundingMin;
    adjMax = boundingMax;
  } else {
    // Illegal bounds.
    return satisfied = false;
  }

  const double origMinAdj = adjMin;
  const double origMaxAdj = adjMax;
  const double numLevelsDouble = storageParams.numLevels;

  struct fns {
    static std::pair<double, double>
    computeMinMax(double boundingMin, double numLevels, double delta) {
      double adjMin = delta * std::floor(boundingMin / delta);
      return std::make_pair(adjMin, adjMin + numLevels * delta);
    }
    static double overshoot(double boundingMin, double boundingMax,
                            double numLevels, double delta) {
      auto adjMinMax = computeMinMax(boundingMin, numLevels, delta);
      double maxOvershoot = adjMinMax.second - boundingMax;
      double minOvershoot = boundingMin - adjMinMax.first;
      // If undershooting on the min or max end, return that because it is
      // to be unconditionally avoided. Otherwise return the end with the
      // greatest magnitude of overshoot.
      if (maxOvershoot < 0)
        return maxOvershoot;
      if (minOvershoot < 0)
        return minOvershoot;
      return std::max(maxOvershoot, minOvershoot);
    }
  };

  // Bisect to find a suitable delta, starting with bounds of deltaInit
  // and deltaMax.
  double deltaInit = (adjMax - adjMin) / numLevelsDouble;
  double deltaMax =
      ((numLevelsDouble * deltaInit) + 2 * deltaInit) / numLevelsDouble;
  double deltaMid;
  double prevDeltaMid = 0.0;
  for (stepCount = 0; stepCount < 60; ++stepCount) {
    deltaMid = (deltaInit + deltaMax) / 2.0;
    auto fInit =
        fns::overshoot(origMinAdj, origMaxAdj, numLevelsDouble, deltaInit);
    auto fMid =
        fns::overshoot(origMinAdj, origMaxAdj, numLevelsDouble, deltaMid);
    if (fMid == 0 || (fMid > 0 && std::fabs(deltaMid - prevDeltaMid) < 1e-15)) {
      // Solution found (or step size is infinitesimal and an overshoot).
      // Empirically, this seems to terminate around 30-50 steps or so.
      // This will find a zero point for exactly representable ranges and
      // will terminate on a small step size for inexact, biasing towards
      // overshooting.
      delta = deltaMid;
      break;
    }
    bool signMid = fMid > 0;
    bool signInit = fInit > 0;
    if (signMid == signInit) {
      deltaInit = deltaMid;
    } else {
      deltaMax = deltaMid;
    }
    prevDeltaMid = deltaMid;
  }
  delta = deltaMid;

  // Recalculate adjMin/adjMax based on new delta.
  auto adjMinMax = fns::computeMinMax(origMinAdj, numLevelsDouble, delta);
  adjMin = adjMinMax.first;
  adjMax = adjMinMax.second;

  satisfied = false;
  zp = 0;

  if (!std::isnan(delta) && !std::isnan(adjMin) && !std::isnan(adjMax)) {
    satisfied = true;
    // Finally, scale and zeroPoint. Since it casts to integer, only valid
    // if the inputs are valid.
    zp = std::round(storageParams.minValue - adjMin / delta);
  }

  return satisfied;
}

int64_t UniformParamsFromMinMaxSolver::quantize(double x) const {
  int64_t xq = std::round(x / delta + zp);
  return std::max<int64_t>(0, std::min<int64_t>(storageParams.numLevels, xq));
}

double UniformParamsFromMinMaxSolver::dequantize(int64_t xq) const {
  return (xq - zp) * delta;
}

raw_ostream &mlir::quantizer::operator<<(raw_ostream &os,
                                         const UniformStorageParams &p) {
  os << "UniformStorageParams{" << p.numLevels << ", " << p.minValue << "}";
  return os;
}

raw_ostream &
mlir::quantizer::operator<<(raw_ostream &os,
                            const UniformParamsFromMinMaxSolver &s) {
  os << "UniformParamsFromMinMaxSolver(" << s.getStepCount() << "){";
  os << "(" << s.getBoundingMin() << ":" << s.getBoundingMax() << ") -> ";
  if (!s.isSatisfied()) {
    os << "unsat}";
    return os;
  }

  os << "(" << s.getAdjMin() << ":" << s.getAdjMax() << ")";
  os << ", scale = " << s.getScale();
  os << ", zp = " << s.getZp();
  os << "}";

  return os;
}
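The bisection in `UniformParamsFromMinMaxSolver::compute` above searches for a step size `delta` whose quantized grid covers the bounding range with minimal overshoot. For intuition, here is a hypothetical closed-form sketch of the simpler asymmetric variant of the same problem, deriving scale and zero point directly from the range (the `SimpleUniformParams`/`solveFromMinMax` names are illustrative, not the removed solver's API, and it assumes a non-degenerate range):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Hypothetical closed-form sketch (not the removed solver): pick a uniform
// scale and zero point so that [boundingMin, boundingMax] maps onto the
// integer storage range [storageMin, storageMax]. Assumes the clamped range
// is non-degenerate (boundingMin < boundingMax after clamping).
struct SimpleUniformParams {
  double scale;
  int64_t zeroPoint;
};

SimpleUniformParams solveFromMinMax(double boundingMin, double boundingMax,
                                    int64_t storageMin, int64_t storageMax) {
  // Clamp the range so it straddles zero, as compute() does above, so that
  // real 0.0 is exactly representable.
  boundingMin = std::min(boundingMin, 0.0);
  boundingMax = std::max(boundingMax, 0.0);
  double scale = (boundingMax - boundingMin) /
                 static_cast<double>(storageMax - storageMin);
  // boundingMin should map to storageMin: boundingMin/scale + zp == storageMin.
  int64_t zeroPoint =
      static_cast<int64_t>(std::round(storageMin - boundingMin / scale));
  return {scale, zeroPoint};
}
```

For the range [-1, 1] on 8-bit storage [0, 255], this gives scale 2/255 and zero point 128 — the same kind of result the bisection above converges to, without its overshoot control.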
@@ -1,118 +0,0 @@
//===- AddDefaultStatsTestPass.cpp - Testing pass to add default stats ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a testing pass to add default statistics nodes to every
// quantization eligible op. Useful for unit testing.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/Quantizer/Configurations/FxpMathConfig.h"
#include "mlir/Quantizer/Support/Configuration.h"
#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
#include "mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h"
#include "mlir/Quantizer/Transforms/Passes.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/GraphWriter.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::quantizer;
using namespace mlir::quant;

namespace {
class AddDefaultStatsPass : public FunctionPass<AddDefaultStatsPass> {
public:
  /// Include the generated pass utilities.
#define GEN_PASS_QuantizerAddDefaultStats
#include "mlir/Quantizer/Transforms/Passes.h.inc"

  AddDefaultStatsPass() = default;
  AddDefaultStatsPass(SolverContext &solverContext,
                      const TargetConfiguration &config)
      : explicitSolverContext(&solverContext), explicitConfig(&config) {}

  void runOnFunction() override;
  void runWithConfig(SolverContext &solverContext,
                     const TargetConfiguration &config);

private:
  SolverContext *explicitSolverContext = nullptr;
  const TargetConfiguration *explicitConfig = nullptr;
};

} // end anonymous namespace

void AddDefaultStatsPass::runOnFunction() {
  if (explicitSolverContext && explicitConfig) {
    // If explicitly constructed with a config and context.
    runWithConfig(*explicitSolverContext, *explicitConfig);
    return;
  }
  // For global pass registration, use defaults.
  SolverContext solverContext(*getFunction().getContext());
  auto config = FxpMathTargetConfig::create(solverContext);
  runWithConfig(solverContext, *config);
}

void AddDefaultStatsPass::runWithConfig(SolverContext &solverContext,
                                        const TargetConfiguration &config) {
  auto func = getFunction();

  // Insert stats for each argument.
  for (auto arg : func.getArguments()) {
    if (!config.isHandledType(arg.getType()))
      continue;
    OpBuilder b(func.getBody());
    APFloat minValue(-1.0f);
    APFloat maxValue(1.0f);
    ElementsAttr layerStats = DenseFPElementsAttr::get(
        RankedTensorType::get({2}, b.getF32Type()), {minValue, maxValue});
    auto statsOp = b.create<StatisticsOp>(func.getLoc(), arg, layerStats,
                                          nullptr, nullptr);
    arg.replaceAllUsesWith(statsOp);

    // StatsOp contained a use to 'arg' so make sure to reset it after replacing
    // all of the uses of 'arg'.
    statsOp.getOperation()->replaceUsesOfWith(statsOp, arg);
  }

  // Walk the ops and insert stats.
  func.walk([&](Operation *op) {
    if (!config.isRequireStatsOp(op)) {
      return;
    }
    assert(op->getNumResults() == 1);

    auto originalResult = op->getResult(0);
    if (!config.isHandledType(originalResult.getType()))
      return;

    OpBuilder b(op->getBlock(), ++op->getIterator());

    APFloat minValue(-1.0f);
    APFloat maxValue(1.0f);
    ElementsAttr layerStats = DenseFPElementsAttr::get(
        RankedTensorType::get({2}, b.getF32Type()), {minValue, maxValue});
    auto statsOp = b.create<StatisticsOp>(op->getLoc(), op->getResult(0),
                                          layerStats, nullptr, nullptr);
    originalResult.replaceAllUsesWith(statsOp);

    // StatsOp contained a use to 'op' so make sure to reset it after replacing
    // all of the uses of 'op'.
    statsOp.getOperation()->replaceUsesOfWith(statsOp, originalResult);
  });
}

std::unique_ptr<OpPassBase<FuncOp>>
mlir::quantizer::createAddDefaultStatsPass() {
  return std::make_unique<AddDefaultStatsPass>();
}
@@ -1,292 +0,0 @@
//===- InferQuantizedTypesPass.cpp - Infers quantized types ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the primary pass for instantiating a CAG, running it to
// convergence on a module to determine eligible quantized type transforms, and
// applying those transforms to the IR.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/Quantizer/Configurations/FxpMathConfig.h"
#include "mlir/Quantizer/Support/Configuration.h"
#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
#include "mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h"
#include "mlir/Quantizer/Transforms/Passes.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/DOTGraphTraits.h"
#include "llvm/Support/GraphWriter.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::quantizer;
using namespace mlir::quant;

namespace llvm {

template <>
struct DOTGraphTraits<const CAGSlice *>
    : public DOTGraphTraits<const CAGNode *> {
  DOTGraphTraits(bool isSimple = false)
      : DOTGraphTraits<const CAGNode *>(isSimple) {}

  std::string getNodeLabel(const CAGNode *node, const CAGSlice *graph) {
    std::string s;
    llvm::raw_string_ostream out(s);
    node->printLabel(out);
    return out.str();
  }

  static std::string getGraphProperties(const CAGSlice *) {
    return "rankdir=LR;";
  }

  static bool isNodeHidden(const CAGNode *node) {
    // Filter constraint nodes with no incoming or outgoing connections.
    // These orphans are often created as part of graph merging operations.
    return llvm::isa<CAGConstraintNode>(node) && node->isOrphan();
  }

  std::string getNodeAttributes(const CAGNode *node, const CAGSlice *graph) {
    switch (node->getKind()) {
    default:
      return std::string();
    case CAGNode::Kind::OperandAnchor:
      return "shape=record,color=yellow,style=filled";
    case CAGNode::Kind::ResultAnchor:
      return "shape=record,color=lightblue,style=filled";
    case CAGNode::Kind::Constraint:
      return "shape=record,style=dotted";
    }
  }
};

} // end namespace llvm

namespace {
class InferQuantizedTypesPass : public ModulePass<InferQuantizedTypesPass> {
public:
  /// Include the generated pass utilities.
#define GEN_PASS_QuantizerInferQuantizedTypes
#include "mlir/Quantizer/Transforms/Passes.h.inc"

  InferQuantizedTypesPass() = default;
  InferQuantizedTypesPass(SolverContext &solverContext,
                          const TargetConfiguration &config)
      : explicitSolverContext(&solverContext), explicitConfig(&config) {}

  void runOnModule() override;
  void runWithConfig(SolverContext &solverContext,
                     const TargetConfiguration &config);

  void transformOperandType(CAGOperandAnchor *anchor, Type newType);
  void transformResultType(CAGResultAnchor *anchor, Type newType);

private:
  SolverContext *explicitSolverContext = nullptr;
  const TargetConfiguration *explicitConfig = nullptr;
};

} // end anonymous namespace

/// Maximum number of propagation rounds to run to converge the CAG before
/// signalling an error.
static const int kMaximumPropagationRounds = 1000;

static LogicalResult validateTypeConversion(Type newType, Type origType,
                                            Operation *op) {
  if (!newType) {
    return op->emitOpError() << "unsupported type conversion from " << newType;
  }
  return success();
}

void InferQuantizedTypesPass::runOnModule() {
  if (explicitSolverContext && explicitConfig) {
    // If explicitly constructed with a config and context.
    runWithConfig(*explicitSolverContext, *explicitConfig);
    return;
  }

  // For global pass registration, use defaults.
  SolverContext solverContext(*getModule().getContext());
  auto config = FxpMathTargetConfig::create(solverContext);
  runWithConfig(solverContext, *config);
}

void InferQuantizedTypesPass::runWithConfig(SolverContext &solverContext,
                                            const TargetConfiguration &config) {
  CAGSlice cag(solverContext);
  for (auto f : getModule().getOps<FuncOp>()) {
    f.walk([&cag, &config](Operation *op) { config.handleOp(op, cag); });
  }
  config.finalizeAnchors(cag);

  // Propagate.
  int propRound;
  for (propRound = kMaximumPropagationRounds; propRound > 0; --propRound) {
    auto propCount = cag.propagate(config);
    if (propCount == 0)
      break;
  }
  if (propRound == 0) {
    emitError(UnknownLoc::get(&getContext()),
              "exceeded maximum number of solver iterations (infinite loop?)");
    return;
  }

  // TODO: Only dump the GraphViz if a flag is set and move to a utility.
  // GraphViz.
  if (!solverContext.getDebugCAGDotPath().empty()) {
    auto actFileName = llvm::WriteGraph(
        const_cast<const CAGSlice *>(&cag), "CAG",
        /*ShortNames=*/false,
        /*Title=*/"CAG",
        /*Filename=*/std::string(solverContext.getDebugCAGDotPath()));
    llvm::errs() << "Wrote graphviz file: " << actFileName << "\n";
  }

  // Start transforming the types in order of anchor type (results, then
  // operands).
  // Apply result types.
  for (auto *node : cag) {
    auto anchorNode = dyn_cast<CAGResultAnchor>(node);
    if (!anchorNode)
      continue;
    if (Type newType = anchorNode->getTransformedType())
      transformResultType(anchorNode, newType);
  }

  // Apply operand types.
  for (auto *node : cag) {
    auto anchorNode = dyn_cast<CAGOperandAnchor>(node);
    if (!anchorNode)
      continue;
    if (Type newType = anchorNode->getTransformedType())
      transformOperandType(anchorNode, newType);
  }
}

void InferQuantizedTypesPass::transformOperandType(CAGOperandAnchor *anchor,
                                                   Type newType) {
  Value inputValue = anchor->getValue();
  Operation *op = anchor->getOp();
  OpBuilder b(op->getBlock(), Block::iterator(op));

  SmallVector<Value, 1> removeValuesIfDead;

  // Because we've already run the result transforms at this phase, it is
  // very likely that inputValue points to a dcast op whose input matches
  // our type. We detect that situation and route around just to save some
  // bulk in the IR.
  Value newTypedInputValue = inputValue;
  auto inputDcastOp =
      dyn_cast_or_null<DequantizeCastOp>(inputValue.getDefiningOp());
  if (inputDcastOp && inputDcastOp.arg().getType() == newType) {
    // Can just use the dcast's input value.
    newTypedInputValue = inputDcastOp.arg();
    removeValuesIfDead.push_back(inputDcastOp);
  } else {
    // Need to synthesize a qcast.
    newTypedInputValue =
        b.create<QuantizeCastOp>(op->getLoc(), newType, inputValue);
|
||||
}
|
||||
|
||||
switch (anchor->getTypeTransformRule()) {
|
||||
case CAGAnchorNode::TypeTransformRule::Direct:
|
||||
anchor->getOp()->setOperand(anchor->getOperandIdx(), newTypedInputValue);
|
||||
break;
|
||||
|
||||
case CAGAnchorNode::TypeTransformRule::DirectStorage: {
|
||||
Type storageType = QuantizedType::castToStorageType(newType);
|
||||
if (failed(validateTypeConversion(storageType, newType, op)))
|
||||
return;
|
||||
anchor->getOp()->setOperand(
|
||||
anchor->getOperandIdx(),
|
||||
b.create<StorageCastOp>(op->getLoc(), storageType, newTypedInputValue));
|
||||
break;
|
||||
}
|
||||
|
||||
case CAGAnchorNode::TypeTransformRule::ExpressedOnly:
|
||||
// Leave the anchor as-is and just cast in/out after it.
|
||||
anchor->getOp()->setOperand(
|
||||
anchor->getOperandIdx(),
|
||||
b.create<DequantizeCastOp>(op->getLoc(), anchor->getOriginalType(),
|
||||
newTypedInputValue));
|
||||
break;
|
||||
}
|
||||
|
||||
for (Value removeValueIfDead : removeValuesIfDead) {
|
||||
if (removeValueIfDead.use_empty()) {
|
||||
removeValueIfDead.getDefiningOp()->erase();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void InferQuantizedTypesPass::transformResultType(CAGResultAnchor *anchor,
|
||||
Type newType) {
|
||||
Value origResultValue = anchor->getValue();
|
||||
Operation *op = origResultValue.getDefiningOp();
|
||||
OpBuilder b(op->getBlock(), ++Block::iterator(op));
|
||||
|
||||
Value replacedResultValue = nullptr;
|
||||
Value newResultValue = nullptr;
|
||||
switch (anchor->getTypeTransformRule()) {
|
||||
case CAGAnchorNode::TypeTransformRule::Direct:
|
||||
origResultValue.setType(newType);
|
||||
replacedResultValue = newResultValue = b.create<DequantizeCastOp>(
|
||||
op->getLoc(), anchor->getOriginalType(), origResultValue);
|
||||
break;
|
||||
|
||||
case CAGAnchorNode::TypeTransformRule::DirectStorage: {
|
||||
Type storageType = QuantizedType::castToStorageType(newType);
|
||||
if (failed(validateTypeConversion(storageType, newType, op)))
|
||||
return;
|
||||
origResultValue.setType(storageType);
|
||||
replacedResultValue =
|
||||
b.create<StorageCastOp>(op->getLoc(), newType, origResultValue);
|
||||
newResultValue = b.create<DequantizeCastOp>(
|
||||
op->getLoc(), anchor->getOriginalType(), replacedResultValue);
|
||||
break;
|
||||
}
|
||||
|
||||
case CAGAnchorNode::TypeTransformRule::ExpressedOnly:
|
||||
// Leave the anchor as-is and just cast in/out after it.
|
||||
replacedResultValue =
|
||||
b.create<QuantizeCastOp>(op->getLoc(), newType, origResultValue);
|
||||
newResultValue = b.create<DequantizeCastOp>(
|
||||
op->getLoc(), anchor->getOriginalType(), replacedResultValue);
|
||||
break;
|
||||
}
|
||||
|
||||
if (replacedResultValue) {
|
||||
// Transform:
|
||||
// origResultValue --> replaceResultValue -> newResultValue
|
||||
// \-> [original uses]
|
||||
// To:
|
||||
// origResultValue -> replaceResultValue ->
|
||||
// newResultValue -> [original uses]
|
||||
// Note that replaceResultValue may equal newResultValue or there may
|
||||
// be operands between the two.
|
||||
origResultValue.replaceAllUsesWith(newResultValue);
|
||||
replacedResultValue.getDefiningOp()->replaceUsesOfWith(newResultValue,
|
||||
origResultValue);
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<OpPassBase<ModuleOp>>
|
||||
mlir::quantizer::createInferQuantizedTypesPass(
|
||||
SolverContext &solverContext, const TargetConfiguration &config) {
|
||||
return std::make_unique<InferQuantizedTypesPass>(solverContext, config);
|
||||
}
|
||||
std::unique_ptr<OpPassBase<ModuleOp>>
|
||||
mlir::quantizer::createInferQuantizedTypesPass() {
|
||||
return std::make_unique<InferQuantizedTypesPass>();
|
||||
}
|
|
@@ -1,66 +0,0 @@
//===- RemoveInstrumentationPass.cpp - Removes instrumentation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a pass to remove any instrumentation ops. It is often one
// of the final steps when performing quantization and is run after any
// decisions requiring instrumentation have been made.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Quantizer/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::quantizer;
using namespace mlir::quant;

namespace {
class RemoveInstrumentationPass
    : public FunctionPass<RemoveInstrumentationPass> {
  /// Include the generated pass utilities.
#define GEN_PASS_QuantizerRemoveInstrumentation
#include "mlir/Quantizer/Transforms/Passes.h.inc"

  void runOnFunction() override;
};

template <typename OpTy>
class RemoveIdentityOpRewrite : public RewritePattern {
public:
  RemoveIdentityOpRewrite(MLIRContext *context)
      : RewritePattern(OpTy::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    assert(op->getNumOperands() == 1);
    assert(op->getNumResults() == 1);

    rewriter.replaceOp(op, op->getOperand(0));
    return success();
  }
};

} // end anonymous namespace

void RemoveInstrumentationPass::runOnFunction() {
  OwningRewritePatternList patterns;
  auto func = getFunction();
  auto *context = &getContext();
  patterns.insert<RemoveIdentityOpRewrite<StatisticsOp>,
                  RemoveIdentityOpRewrite<StatisticsRefOp>,
                  RemoveIdentityOpRewrite<CoupledRefOp>>(context);
  applyPatternsGreedily(func, patterns);
}

std::unique_ptr<OpPassBase<FuncOp>>
mlir::quantizer::createRemoveInstrumentationPass() {
  return std::make_unique<RemoveInstrumentationPass>();
}
@@ -1,64 +0,0 @@
// RUN: mlir-opt %s -split-input-file -fxpmath-lower-uniform-casts | FileCheck %s --dump-input=always

// -----
// CHECK-LABEL: dequantize_per_layer_fixedpoint
!type_input = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_result = type tensor<4xf32>
func @dequantize_per_layer_fixedpoint(%arg0 : !type_input) -> !type_result {
  // CHECK: %cst = constant dense<6.250000e-02> : tensor<4xf32>
  // CHECK-NEXT: %0 = "quant.scast"(%arg0) : (tensor<4x!quant.uniform<i8:f32, 6.250000e-02>>) -> tensor<4xi8>
  // CHECK-NEXT: %1 = "fxpmath.convertis"(%0) : (tensor<4xi8>) -> tensor<4xi32>
  // CHECK-NEXT: %2 = "fxpmath.convertistof"(%1) : (tensor<4xi32>) -> tensor<4xf32>
  // CHECK-NEXT: %3 = mulf %2, %cst : tensor<4xf32>
  // CHECK-NEXT: return %3 : tensor<4xf32>
  %0 = "quant.dcast"(%arg0) : (!type_input) -> (!type_result)
  return %0 : !type_result
}

// -----
// CHECK-LABEL: dequantize_per_layer_affine
!type_input = type tensor<4x!quant.uniform<i8:f32, 6.25e-2:-36>>
!type_result = type tensor<4xf32>
func @dequantize_per_layer_affine(%arg0 : !type_input) -> !type_result {
  // CHECK: %cst = constant dense<36> : tensor<4xi32>
  // CHECK-NEXT: %cst_0 = constant dense<6.250000e-02> : tensor<4xf32>
  // CHECK-NEXT: %0 = "quant.scast"(%arg0) : (tensor<4x!quant.uniform<i8:f32, 6.250000e-02:-36>>) -> tensor<4xi8>
  // CHECK-NEXT: %1 = "fxpmath.convertis"(%0) : (tensor<4xi8>) -> tensor<4xi32>
  // CHECK-NEXT: %2 = addi %1, %cst : tensor<4xi32>
  // CHECK-NEXT: %3 = "fxpmath.convertistof"(%2) : (tensor<4xi32>) -> tensor<4xf32>
  // CHECK-NEXT: %4 = mulf %3, %cst_0 : tensor<4xf32>
  // CHECK-NEXT: return %4 : tensor<4xf32>
  %0 = "quant.dcast"(%arg0) : (!type_input) -> (!type_result)
  return %0 : !type_result
}

// -----
// CHECK-LABEL: dequantize_per_axis_fixedpoint
!type_input = type tensor<4x!quant.uniform<i8:f32:0, {6.25e-2,3.26e-2,4.25e-2,1.23e-2}>>
!type_result = type tensor<4xf32>
func @dequantize_per_axis_fixedpoint(%arg0 : !type_input) -> !type_result {
  // expected-warning@+1 {{unimplemented: per-axis uniform dequantization}}
  %0 = "quant.dcast"(%arg0) : (!type_input) -> (!type_result)
  return %0 : !type_result
}

// -----
// CHECK-LABEL: dequantize_per_axis_affine
!type_input = type tensor<4x!quant.uniform<i8:f32:0, {6.25e-2,3.26e-2,4.25e-2,1.23e-2}>>
!type_result = type tensor<4xf32>
func @dequantize_per_axis_affine(%arg0 : !type_input) -> !type_result {
  // expected-warning@+1 {{unimplemented: per-axis uniform dequantization}}
  %0 = "quant.dcast"(%arg0) : (!type_input) -> (!type_result)
  return %0 : !type_result
}

// -----
// Noop dequantize should be skipped (will be canonicalized away later).
// CHECK-LABEL: dequantize_noop
!type_input = type tensor<4x!quant.uniform<i8:f32, 6.25e-2:-36>>
!type_result = type tensor<4x!quant.uniform<i8:f32, 6.25e-2:-36>>
func @dequantize_noop(%arg0 : !type_input) -> !type_result {
  // CHECK: %0 = "quant.dcast"(%arg0)
  %0 = "quant.dcast"(%arg0) : (!type_input) -> (!type_result)
  return %0 : !type_result
}
@@ -1,102 +0,0 @@
// RUN: mlir-opt %s -split-input-file -fxpmath-lower-uniform-real-math -pass-pipeline='func(canonicalize)' | FileCheck %s --dump-input=always

// -----
// Verify lowering when operands and result have the same fixedpoint scale.
// CHECK-LABEL: real_addew_fixedpoint_isomorphic
!type_lhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_rhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_result = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
func @real_addew_fixedpoint_isomorphic(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
  // CHECK-NEXT: %0 = "quant.scast"(%arg0) : (tensor<4x!quant.uniform<i8:f32, 6.250000e-02>>) -> tensor<4xi8>
  // CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant.uniform<i8:f32, 6.250000e-02>>) -> tensor<4xi8>
  // CHECK-NEXT: %2 = "fxpmath.convertis"(%0) : (tensor<4xi8>) -> tensor<4xi16>
  // CHECK-NEXT: %3 = "fxpmath.convertis"(%1) : (tensor<4xi8>) -> tensor<4xi16>
  // CHECK-NEXT: %4 = addi %2, %3 : tensor<4xi16>
  // CHECK-NEXT: %5 = "fxpmath.clampis"(%4) {clamp_max = 127 : i16, clamp_min = -128 : i16} : (tensor<4xi16>) -> tensor<4xi16>
  // CHECK-NEXT: %6 = "fxpmath.convertis"(%5) : (tensor<4xi16>) -> tensor<4xi8>
  // CHECK-NEXT: %7 = "quant.scast"(%6) : (tensor<4xi8>) -> tensor<4x!quant.uniform<i8:f32, 6.250000e-02>>
  // CHECK-NEXT: return %7 : tensor<4x!quant.uniform<i8:f32, 6.250000e-02>>
  %0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
  return %0 : !type_result
}

// -----
// Verify lowering when operands and result have the same scale
// and non-zero zero points.
// CHECK-LABEL: real_addew_affine_isomorphic
!type_lhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2:-5>>
!type_rhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2:-5>>
!type_result = type tensor<4x!quant.uniform<i8:f32, 6.25e-2:-5>>
func @real_addew_affine_isomorphic(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
  // CHECK-NEXT: %cst = constant dense<5> : tensor<4xi16>
  // CHECK-NEXT: %0 = "quant.scast"(%arg0) : (tensor<4x!quant.uniform<i8:f32, 6.250000e-02:-5>>) -> tensor<4xi8>
  // CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant.uniform<i8:f32, 6.250000e-02:-5>>) -> tensor<4xi8>
  // CHECK-NEXT: %2 = "fxpmath.convertis"(%0) : (tensor<4xi8>) -> tensor<4xi16>
  // CHECK-NEXT: %3 = "fxpmath.convertis"(%1) : (tensor<4xi8>) -> tensor<4xi16>
  // CHECK-NEXT: %4 = addi %2, %3 : tensor<4xi16>
  // CHECK-NEXT: %5 = addi %4, %cst : tensor<4xi16>
  // CHECK-NEXT: %6 = "fxpmath.clampis"(%5) {clamp_max = 127 : i16, clamp_min = -128 : i16} : (tensor<4xi16>) -> tensor<4xi16>
  // CHECK-NEXT: %7 = "fxpmath.convertis"(%6) : (tensor<4xi16>) -> tensor<4xi8>
  // CHECK-NEXT: %8 = "quant.scast"(%7) : (tensor<4xi8>) -> tensor<4x!quant.uniform<i8:f32, 6.250000e-02:-5>>
  // CHECK-NEXT: return %8 : tensor<4x!quant.uniform<i8:f32, 6.250000e-02:-5>>
  %0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
  return %0 : !type_result
}

// -----
// The RHS quant parameters prescribe a range of [-8..8) so an explicit clamp
// of [-4..4] should result in an integral clamp range of [-64..64].
// CHECK-LABEL: real_addew_fixedpoint_clamp
!type_lhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_rhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_result = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
func @real_addew_fixedpoint_clamp(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
  // CHECK-NEXT: %0 = "quant.scast"(%arg0) : (tensor<4x!quant.uniform<i8:f32, 6.250000e-02>>) -> tensor<4xi8>
  // CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant.uniform<i8:f32, 6.250000e-02>>) -> tensor<4xi8>
  // CHECK-NEXT: %2 = "fxpmath.convertis"(%0) : (tensor<4xi8>) -> tensor<4xi16>
  // CHECK-NEXT: %3 = "fxpmath.convertis"(%1) : (tensor<4xi8>) -> tensor<4xi16>
  // CHECK-NEXT: %4 = addi %2, %3 : tensor<4xi16>
  // CHECK-NEXT: %5 = "fxpmath.clampis"(%4) {clamp_max = 64 : i16, clamp_min = -64 : i16} : (tensor<4xi16>) -> tensor<4xi16>
  // CHECK-NEXT: %6 = "fxpmath.convertis"(%5) : (tensor<4xi16>) -> tensor<4xi8>
  // CHECK-NEXT: %7 = "quant.scast"(%6) : (tensor<4xi8>) -> tensor<4x!quant.uniform<i8:f32, 6.250000e-02>>
  // CHECK-NEXT: return %7 : tensor<4x!quant.uniform<i8:f32, 6.250000e-02>>
  %0 = "fxpmath.real_add_ew"(%arg0, %arg1) { clamp_min = -4.0, clamp_max = 4.0 }
      : (!type_lhs, !type_rhs) -> (!type_result)
  return %0 : !type_result
}

// -----
// CHECK-LABEL: real_addew_unquantized_lhs
// Verifies that the op is left as-is for an unquantized lhs.
!type_lhs = type tensor<4xf32>
!type_rhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_result = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
func @real_addew_unquantized_lhs(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
  // CHECK: %0 = "fxpmath.real_add_ew"(%arg0, %arg1)
  %0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
  return %0 : !type_result
}

// -----
// CHECK-LABEL: real_addew_unquantized_rhs
// Verifies that the op is left as-is for an unquantized rhs.
!type_lhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_rhs = type tensor<4xf32>
!type_result = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
func @real_addew_unquantized_rhs(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
  // CHECK: %0 = "fxpmath.real_add_ew"(%arg0, %arg1)
  %0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
  return %0 : !type_result
}

// -----
// CHECK-LABEL: real_addew_unquantized_result
// Verifies that the op is left as-is for an unquantized result.
!type_lhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_rhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_result = type tensor<4xf32>
func @real_addew_unquantized_result(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
  // CHECK: %0 = "fxpmath.real_add_ew"(%arg0, %arg1)
  %0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
  return %0 : !type_result
}
@@ -1,94 +0,0 @@
// RUN: mlir-opt %s -split-input-file -fxpmath-lower-uniform-real-math -pass-pipeline='func(canonicalize)' -verify-diagnostics | FileCheck %s --dump-input=always

// -----
// Verify lowering when operands and result are fixedpoint (zero zero points)
// with differing scales.
// CHECK-LABEL: real_mulew_fixedpoint
!type_lhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_rhs = type tensor<4x!quant.uniform<i8:f32, 3.875e-2>>
!type_result = type tensor<4x!quant.uniform<i8:f32, 1.065e-1>>
func @real_mulew_fixedpoint(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
  // CHECK: %0 = "quant.scast"(%arg0) : (tensor<4x!quant.uniform<i8:f32, 6.250000e-02>>) -> tensor<4xi8>
  // CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant.uniform<i8:f32, 3.875000e-02>>) -> tensor<4xi8>
  // CHECK-NEXT: %2 = "fxpmath.convertis"(%0) : (tensor<4xi8>) -> tensor<4xi32>
  // CHECK-NEXT: %3 = "fxpmath.convertis"(%1) : (tensor<4xi8>) -> tensor<4xi32>
  // CHECK-NEXT: %4 = muli %2, %3 : tensor<4xi32>
  // CHECK-NEXT: %5 = "fxpmath.vs_saturating_rounding_doubling_high_mulis"(%4) {b = 1562722842 : i32} : (tensor<4xi32>) -> tensor<4xi32>
  // CHECK-NEXT: %6 = "fxpmath.rounding_divide_by_potis"(%5) {exponent = 5 : i32} : (tensor<4xi32>) -> tensor<4xi32>
  // CHECK-NEXT: %7 = "fxpmath.clampis"(%6) {clamp_max = 127 : i32, clamp_min = -128 : i32} : (tensor<4xi32>) -> tensor<4xi32>
  // CHECK-NEXT: %8 = "fxpmath.convertis"(%7) : (tensor<4xi32>) -> tensor<4xi8>
  // CHECK-NEXT: %9 = "quant.scast"(%8) : (tensor<4xi8>) -> tensor<4x!quant.uniform<i8:f32, 1.065000e-01>>
  // CHECK-NEXT: return %9 : tensor<4x!quant.uniform<i8:f32, 1.065000e-01>>
  %0 = "fxpmath.real_mul_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
  return %0 : !type_result
}

// -----
// Verify lowering when operands and result have the same scale
// and non-zero zero points.
// CHECK-LABEL: real_mulew_affine_clamp
!type_lhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2:-3>>
!type_rhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2:-5>>
!type_result = type tensor<4x!quant.uniform<i8:f32, 6.25e-2:-9>>
func @real_mulew_affine_clamp(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
  // Just verify that the affine adds/constants and clamps are present.
  // CHECK: %cst = constant dense<3> : tensor<4xi32>
  // CHECK: %cst_0 = constant dense<5> : tensor<4xi32>
  // CHECK: %cst_1 = constant dense<-9> : tensor<4xi32>
  // CHECK: addi %2, %cst : tensor<4xi32>
  // CHECK: addi %3, %cst_0 : tensor<4xi32>
  // CHECK: muli %4, %5 : tensor<4xi32>
  // CHECK: addi %8, %cst_1 : tensor<4xi32>
  // CHECK: {clamp_max = 55 : i32, clamp_min = -73 : i32}
  %0 = "fxpmath.real_mul_ew"(%arg0, %arg1) { clamp_min = -4.0, clamp_max = 4.0 } : (!type_lhs, !type_rhs) -> (!type_result)
  return %0 : !type_result
}

// -----
// CHECK-LABEL: real_mulew_unquantized_lhs
// Verifies that the op is left as-is for an unquantized lhs.
!type_lhs = type tensor<4xf32>
!type_rhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_result = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
func @real_mulew_unquantized_lhs(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
  // CHECK: %0 = "fxpmath.real_mul_ew"(%arg0, %arg1)
  %0 = "fxpmath.real_mul_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
  return %0 : !type_result
}

// -----
// CHECK-LABEL: real_mulew_unquantized_rhs
// Verifies that the op is left as-is for an unquantized rhs.
!type_lhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_rhs = type tensor<4xf32>
!type_result = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
func @real_mulew_unquantized_rhs(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
  // CHECK: %0 = "fxpmath.real_mul_ew"(%arg0, %arg1)
  %0 = "fxpmath.real_mul_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
  return %0 : !type_result
}

// -----
// CHECK-LABEL: real_mulew_unquantized_result
// Verifies that the op is left as-is for an unquantized result.
!type_lhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_rhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_result = type tensor<4xf32>
func @real_mulew_unquantized_result(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
  // CHECK: %0 = "fxpmath.real_mul_ew"(%arg0, %arg1)
  %0 = "fxpmath.real_mul_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
  return %0 : !type_result
}

// -----
// Verify that lowering is rejected when the real multiplier exceeds 1.0.
// Note that the multiplier = lhs_scale * rhs_scale / result_scale
//                          = 22.740610328638496
// CHECK-LABEL: real_mulew_multiplier_gt_1
!type_lhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_rhs = type tensor<4x!quant.uniform<i8:f32, 3.875e-2>>
!type_result = type tensor<4x!quant.uniform<i8:f32, 1.065e-4>>
func @real_mulew_multiplier_gt_1(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
  // expected-warning@+1 {{unimplemented: cannot multiply with multiplier > 1.0}}
  %0 = "fxpmath.real_mul_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
  return %0 : !type_result
}
@@ -1,51 +0,0 @@
// RUN: mlir-opt %s -quantizer-infer-quantized-types -quant-convert-const -quantizer-remove-instrumentation -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s

// -----
// A matmul without fused clamp or bias.
// CHECK-LABEL: @matmul
// CHECK: %cst = constant dense{{.*}}tensor<3x5xi8>
// CHECK-NEXT: %0 = "quant.qcast"(%arg0) : (tensor<300x3xf32>) -> tensor<300x3x!quant.uniform<i8:f32, 0.037564418067230126:35>>
// CHECK-NEXT: %1 = "quant.scast"(%cst) : (tensor<3x5xi8>) -> tensor<3x5x!quant.uniform<i8:f32, 0.0062823070315864236:-1>>
// CHECK-NEXT: %2 = "fxpmath.real_matmul"(%0, %1) : (tensor<300x3x!quant.uniform<i8:f32, 0.037564418067230126:35>>, tensor<3x5x!quant.uniform<i8:f32, 0.0062823070315864236:-1>>) -> tensor<300x5x!quant.uniform<i8:f32, 0.0629921259842528:-1>>
// CHECK-NEXT: %3 = "quant.dcast"(%2) : (tensor<300x5x!quant.uniform<i8:f32, 0.0629921259842528:-1>>) -> tensor<300x5xf32>
func @matmul(%arg0: tensor<300x3xf32>) -> tensor<300x5xf32> {
  %0 = "quant.stats"(%arg0) {layerStats = dense<[-6.123e+00, 3.45e+00]> : tensor<2xf32>} : (tensor<300x3xf32>) -> tensor<300x3xf32>
  %cst = constant {name = "constant.35"} dense<[[-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32>
  %1 = "fxpmath.real_matmul"(%0, %cst) : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32>
  %2 = "quant.stats"(%1) {layerStats = dense<[-8.000000e+00, 8.000000e+00]> : tensor<2xf32>} : (tensor<300x5xf32>) -> tensor<300x5xf32>
  return %2 : tensor<300x5xf32>
}

// -----
// A matmul with fused clamp which serves as statistics for the result.
// CHECK-LABEL: @matmul_clamp
// CHECK: %cst = constant dense{{.*}}tensor<3x5xi8>
// CHECK-NEXT: %0 = "quant.qcast"(%arg0) : (tensor<300x3xf32>) -> tensor<300x3x!quant.uniform<i8:f32, 0.037564418067230126:35>>
// CHECK-NEXT: %1 = "quant.scast"(%cst) : (tensor<3x5xi8>) -> tensor<3x5x!quant.uniform<i8:f32, 0.0062823070315864236:-1>>
// CHECK-NEXT: %2 = "fxpmath.real_matmul"(%0, %1) {clamp_max = 6.100000e+00 : f64, clamp_min = -1.225000e+01 : f64} : (tensor<300x3x!quant.uniform<i8:f32, 0.037564418067230126:35>>, tensor<3x5x!quant.uniform<i8:f32, 0.0062823070315864236:-1>>) -> tensor<300x5x!quant.uniform<i8:f32, 0.072058823529412216:42>>
// CHECK-NEXT: %3 = "quant.dcast"(%2) : (tensor<300x5x!quant.uniform<i8:f32, 0.072058823529412216:42>>) -> tensor<300x5xf32>
func @matmul_clamp(%arg0: tensor<300x3xf32>) -> tensor<300x5xf32> {
  %0 = "quant.stats"(%arg0) {layerStats = dense<[-6.123e+00, 3.45e+00]> : tensor<2xf32>} : (tensor<300x3xf32>) -> tensor<300x3xf32>
  %cst = constant {name = "constant.35"} dense<[[-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32>
  %1 = "fxpmath.real_matmul"(%0, %cst) {clamp_max = 6.10, clamp_min = -12.25} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32>
  return %1 : tensor<300x5xf32>
}

// -----
// A matmul with bias and clamp.
// CHECK-LABEL: @matmul_add_clamp
// CHECK: %cst = constant dense{{.*}}tensor<3x5xi8>
// CHECK-NEXT: %cst_0 = constant dense<[14, 28, 42, 56, 69]> : tensor<5xi32>
// CHECK-NEXT: %0 = "quant.qcast"(%arg0) : (tensor<300x3xf32>) -> tensor<300x3x!quant.uniform<i8:f32, 0.037564418067230126:35>>
// CHECK-NEXT: %1 = "quant.scast"(%cst) : (tensor<3x5xi8>) -> tensor<3x5x!quant.uniform<i8:f32, 0.0062823070315864236:-1>>
// CHECK-NEXT: %2 = "quant.scast"(%cst_0) : (tensor<5xi32>) -> tensor<5x!quant.uniform<i32:f32, 0.072058823529412216>>
// CHECK-NEXT: %3 = "fxpmath.real_matmul_bias"(%0, %1, %2) {clamp_max = 6.100000e+00 : f64, clamp_min = -1.225000e+01 : f64} : (tensor<300x3x!quant.uniform<i8:f32, 0.037564418067230126:35>>, tensor<3x5x!quant.uniform<i8:f32, 0.0062823070315864236:-1>>, tensor<5x!quant.uniform<i32:f32, 0.072058823529412216>>) -> tensor<300x5x!quant.uniform<i8:f32, 0.072058823529412216:42>>
// CHECK-NEXT: %4 = "quant.dcast"(%3) : (tensor<300x5x!quant.uniform<i8:f32, 0.072058823529412216:42>>) -> tensor<300x5xf32>
func @matmul_add_clamp(%arg0: tensor<300x3xf32>) -> tensor<300x5xf32> {
  %0 = "quant.stats"(%arg0) {layerStats = dense<[-6.123e+00, 3.45e+00]> : tensor<2xf32>} : (tensor<300x3xf32>) -> tensor<300x3xf32>
  %cst = constant {name = "constant.35"} dense<[[-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32>
  %cst_0 = constant {name = "constant.37"} dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<5xf32>
  %1 = "fxpmath.real_matmul_bias"(%0, %cst, %cst_0) {clamp_max = 6.10, clamp_min = -12.25} : (tensor<300x3xf32>, tensor<3x5xf32>, tensor<5xf32>) -> tensor<300x5xf32>
  return %1 : tensor<300x5xf32>
}
@@ -1,15 +0,0 @@
// RUN: mlir-opt %s -quantizer-remove-instrumentation -split-input-file | FileCheck %s

// -----
// CHECK-LABEL: remove_ops
func @remove_ops(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
  %0 = "quant.stats"(%arg0) {
    layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>
  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
  %1 = "quant.coupled_ref"(%0) { coupledKey = "foobar" } :
      (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
  %2 = "quant.stats_ref"(%1) { statsKey = "foobar" } :
      (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
  // CHECK: return %arg0 : tensor<8x4x3xf32>
  return %2 : tensor<8x4x3xf32>
}
@@ -1,7 +1,6 @@
// RUN: mlir-opt --show-dialects | FileCheck %s
// CHECK: Registered Dialects:
// CHECK: affine
// CHECK: fxpmath
// CHECK: gpu
// CHECK: linalg
// CHECK: llvm
@@ -16,9 +16,6 @@ set(LIBS
  MLIROptLib
  MLIRParser
  MLIRPass
  MLIRQuantizerFxpMathConfig
  MLIRQuantizerSupport
  MLIRQuantizerTransforms
  MLIRSPIRV
  MLIRSPIRVTestPasses
  MLIRSPIRVTransforms
@@ -1,99 +0,0 @@
//===- RulesTest.cpp - Rules unit tests -----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Quantizer/Support/Rules.h"
#include "llvm/Support/raw_ostream.h"
#include "gtest/gtest.h"

using namespace mlir;
using namespace mlir::quantizer;

namespace {

using TestDiscreteFact = DiscreteFact<int>;

TEST(ExpandingMinMaxReducer, Basic) {
  ExpandingMinMaxFact f;
  EXPECT_FALSE(f.hasValue());

  // First assertion always modifies.
  EXPECT_TRUE(modified(f.assertValue(0, {-1.0, 1.0})));
  EXPECT_TRUE(f.hasValue());
  EXPECT_EQ(std::make_pair(-1.0, 1.0), f.getValue());
  EXPECT_EQ(0, f.getSalience());

  // Assertion in the same band expands.
  EXPECT_TRUE(modified(f.assertValue(0, {-0.5, 2.0})));
  EXPECT_TRUE(f.hasValue());
  EXPECT_EQ(std::make_pair(-1.0, 2.0), f.getValue());
  EXPECT_EQ(0, f.getSalience());

  EXPECT_TRUE(modified(f.assertValue(0, {-2.0, 0.5})));
  EXPECT_TRUE(f.hasValue());
  EXPECT_EQ(std::make_pair(-2.0, 2.0), f.getValue());
  EXPECT_EQ(0, f.getSalience());

  // Same band smaller bound does not modify.
  EXPECT_FALSE(modified(f.assertValue(0, {-0.5, 0.5})));
  EXPECT_TRUE(f.hasValue());
  EXPECT_EQ(std::make_pair(-2.0, 2.0), f.getValue());
  EXPECT_EQ(0, f.getSalience());

  // Higher salience overrides.
  EXPECT_TRUE(modified(f.assertValue(10, {-0.2, 0.2})));
  EXPECT_TRUE(f.hasValue());
  EXPECT_EQ(std::make_pair(-0.2, 0.2), f.getValue());
  EXPECT_EQ(10, f.getSalience());

  // Lower salience no-ops.
  EXPECT_FALSE(modified(f.assertValue(5, {-2.0, 2.0})));
  EXPECT_TRUE(f.hasValue());
  EXPECT_EQ(std::make_pair(-0.2, 0.2), f.getValue());
  EXPECT_EQ(10, f.getSalience());

  // Merge from a fact without a value no-ops.
  ExpandingMinMaxFact f1;
  EXPECT_FALSE(modified(f.mergeFrom(f1)));
  EXPECT_TRUE(f.hasValue());
  EXPECT_EQ(std::make_pair(-0.2, 0.2), f.getValue());
  EXPECT_EQ(10, f.getSalience());

  // Merge from a fact with a value merges.
  EXPECT_TRUE(modified(f1.mergeFrom(f)));
  EXPECT_TRUE(f1.hasValue());
  EXPECT_EQ(std::make_pair(-0.2, 0.2), f1.getValue());
  EXPECT_EQ(10, f1.getSalience());
}

TEST(TestDiscreteFact, Basic) {
  TestDiscreteFact f;
  EXPECT_FALSE(f.hasValue());

  // Initial value.
  EXPECT_TRUE(modified(f.assertValue(0, {2})));
  EXPECT_FALSE(modified(f.assertValue(0, {2})));
  EXPECT_EQ(2, f.getValue().value);
  EXPECT_FALSE(f.getValue().conflict);

  // Conflicting update.
  EXPECT_TRUE(modified(f.assertValue(0, {4})));
  EXPECT_EQ(2, f.getValue().value); // Arbitrary but known to be first wins.
  EXPECT_TRUE(f.getValue().conflict);

  // Further update still conflicts.
  EXPECT_FALSE(modified(f.assertValue(0, {4})));
  EXPECT_EQ(2, f.getValue().value); // Arbitrary but known to be first wins.
|
||||
EXPECT_TRUE(f.getValue().conflict);
|
||||
|
||||
// Different salience update does not conflict.
|
||||
EXPECT_TRUE(modified(f.assertValue(1, {6})));
|
||||
EXPECT_EQ(6, f.getValue().value);
|
||||
EXPECT_FALSE(f.getValue().conflict);
|
||||
}
|
||||
|
||||
} // end anonymous namespace
|
|
@ -1,142 +0,0 @@
|
|||
//===- UniformSolversTest.cpp - Tests for uniform solvers -----------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Quantizer/Support/UniformSolvers.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::quantizer;
|
||||
|
||||
namespace {
|
||||
|
||||
const double kEpsilon = 1e-12;
|
||||
|
||||
TEST(UniformMathTest, testAsym) {
|
||||
UniformParamsFromMinMaxSolver s(UniformStorageParams::getQuint8(), -8, 8.123);
|
||||
ASSERT_TRUE(s.compute());
|
||||
|
||||
llvm::errs() << "testBasic Results: " << s << "\n";
|
||||
|
||||
EXPECT_EQ(0.0, s.dequantize(s.getZp())); // Exact.
|
||||
EXPECT_EQ(s.getZp(), s.quantize(0.0));
|
||||
EXPECT_GE(s.getAdjMax() + kEpsilon, s.getBoundingMax());
|
||||
EXPECT_LE(s.getAdjMin() - kEpsilon, s.getBoundingMin());
|
||||
}
|
||||
|
||||
TEST(UniformMathTest, testPOT) {
|
||||
UniformParamsFromMinMaxSolver s(UniformStorageParams::getQuint8(), -8,
|
||||
7.9375);
|
||||
ASSERT_TRUE(s.compute());
|
||||
|
||||
llvm::errs() << "testPOT Results: " << s << "\n";
|
||||
|
||||
// POT ranges should be exact.
|
||||
EXPECT_EQ(128, s.getZp());
|
||||
EXPECT_NEAR(6.25e-2, s.getScale(), kEpsilon);
|
||||
EXPECT_EQ(0.0, s.dequantize(s.getZp())); // Exact.
|
||||
EXPECT_EQ(s.getZp(), s.quantize(0.0));
|
||||
EXPECT_GE(s.getAdjMax() + kEpsilon, s.getBoundingMax());
|
||||
EXPECT_LE(s.getAdjMin() - kEpsilon, s.getBoundingMin());
|
||||
}
|
||||
|
||||
TEST(UniformMathTest, testLopsidedPositive) {
|
||||
UniformParamsFromMinMaxSolver s(UniformStorageParams::getQuint8(), 1.0, 8.0);
|
||||
ASSERT_TRUE(s.compute());
|
||||
|
||||
llvm::errs() << "testLopsidedPositive Results: " << s << "\n";
|
||||
|
||||
EXPECT_EQ(0, s.getZp());
|
||||
EXPECT_EQ(0.0, s.dequantize(s.getZp())); // Exact.
|
||||
EXPECT_EQ(0, s.quantize(0.0));
|
||||
EXPECT_GE(s.getAdjMax() + kEpsilon, s.getBoundingMax());
|
||||
EXPECT_LE(s.getAdjMin() - kEpsilon, s.getBoundingMin());
|
||||
}
|
||||
|
||||
TEST(UniformMathTest, testLopsidedNegative) {
|
||||
UniformParamsFromMinMaxSolver s(UniformStorageParams::getQuint8(), -72.0,
|
||||
-4.0);
|
||||
ASSERT_TRUE(s.compute());
|
||||
|
||||
llvm::errs() << "testLopsidedNegative Results: " << s << "\n";
|
||||
|
||||
EXPECT_EQ(255, s.getZp());
|
||||
EXPECT_EQ(0.0, s.dequantize(s.getZp())); // Exact.
|
||||
EXPECT_EQ(255, s.quantize(0.0));
|
||||
EXPECT_GE(s.getAdjMax() + kEpsilon, s.getBoundingMax());
|
||||
EXPECT_LE(s.getAdjMin() - kEpsilon, s.getBoundingMin());
|
||||
}
|
||||
|
||||
TEST(UniformMathTest, testLargeRange) {
|
||||
UniformParamsFromMinMaxSolver s(UniformStorageParams::getQuint8(), -123.23389,
|
||||
231.1289);
|
||||
ASSERT_TRUE(s.compute());
|
||||
|
||||
llvm::errs() << "testLargeRange Results: " << s << "\n";
|
||||
|
||||
// EXPECT_EQ(255, s.getZp());
|
||||
EXPECT_EQ(0.0, s.dequantize(s.getZp())); // Exact.
|
||||
EXPECT_EQ(s.getZp(), s.quantize(0.0));
|
||||
EXPECT_GE(s.getAdjMax() + kEpsilon, s.getBoundingMax());
|
||||
EXPECT_LE(s.getAdjMin() - kEpsilon, s.getBoundingMin());
|
||||
}
|
||||
|
||||
TEST(UniformMathTest, test16BitLargeRange) {
|
||||
UniformParamsFromMinMaxSolver s(UniformStorageParams::getQuint16(),
|
||||
-123.23389, 231.1289);
|
||||
ASSERT_TRUE(s.compute());
|
||||
|
||||
llvm::errs() << "test16BitLargeRange Results: " << s << "\n";
|
||||
|
||||
EXPECT_EQ(0.0, s.dequantize(s.getZp())); // Exact.
|
||||
EXPECT_EQ(s.getZp(), s.quantize(0.0));
|
||||
EXPECT_GE(s.getAdjMax() + kEpsilon, s.getBoundingMax());
|
||||
EXPECT_LE(s.getAdjMin() - kEpsilon, s.getBoundingMin());
|
||||
}
|
||||
|
||||
TEST(UniformMathTest, testQuint8SymmetricRight) {
|
||||
UniformParamsFromMinMaxSolver s(
|
||||
UniformStorageParams::getQuint8SymmetricRight(), -123.23389, 231.1289);
|
||||
ASSERT_TRUE(s.compute());
|
||||
|
||||
llvm::errs() << "testQuint8SymmetricRight Results: " << s << "\n";
|
||||
|
||||
EXPECT_EQ(0.0, s.dequantize(s.getZp())); // Exact.
|
||||
EXPECT_EQ(s.getZp(), s.quantize(0.0));
|
||||
EXPECT_GE(s.getAdjMax() + kEpsilon, s.getBoundingMax());
|
||||
EXPECT_LE(s.getAdjMin() - kEpsilon, s.getBoundingMin());
|
||||
}
|
||||
|
||||
TEST(UniformMathTest, testQuint4) {
|
||||
UniformParamsFromMinMaxSolver s({15, 0}, -1.0, 1.0);
|
||||
ASSERT_TRUE(s.compute());
|
||||
|
||||
llvm::errs() << "testQuint4 Results: " << s << "\n";
|
||||
|
||||
EXPECT_EQ(0.0, s.dequantize(s.getZp())); // Exact.
|
||||
EXPECT_EQ(s.getZp(), s.quantize(0.0));
|
||||
EXPECT_GE(s.getAdjMax() + kEpsilon, s.getBoundingMax());
|
||||
EXPECT_LE(s.getAdjMin() - kEpsilon, s.getBoundingMin());
|
||||
}
|
||||
|
||||
TEST(UniformMathTest, testNan) {
|
||||
UniformParamsFromMinMaxSolver s({0, 0}, -1.0, 1.0);
|
||||
ASSERT_FALSE(s.compute());
|
||||
}
|
||||
|
||||
TEST(UniformMathTest, testBadBounds) {
|
||||
UniformParamsFromMinMaxSolver s(UniformStorageParams::getQuint16(), 123.23389,
|
||||
-231.1289);
|
||||
ASSERT_FALSE(s.compute());
|
||||
}
|
||||
|
||||
TEST(UniformMathTest, testZeroBounds) {
|
||||
UniformParamsFromMinMaxSolver s(UniformStorageParams::getQuint16(), 0, 0);
|
||||
ASSERT_FALSE(s.compute());
|
||||
}
|
||||
|
||||
} // end namespace
|
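The deleted UniformSolversTest exercised one core property: given a storage range and a real-valued min/max range, the solver must pick a scale and zero point such that 0.0 quantizes exactly to the zero point (`dequantize(zp) == 0.0` and `quantize(0.0) == zp`), and must fail on degenerate bounds. A minimal standalone C++ sketch of that min/max-to-(scale, zero point) arithmetic follows; the function names and the exact rounding policy are illustrative, not the deleted `UniformParamsFromMinMaxSolver` API.

```cpp
#include <algorithm>
#include <cmath>
#include <optional>
#include <utility>

// Derive (scale, zeroPoint) from a storage range [qmin, qmax] and a real
// range [fmin, fmax], keeping 0.0 exactly representable.
std::optional<std::pair<double, int>>
solveUniformParams(int qmin, int qmax, double fmin, double fmax) {
  if (!(fmin < fmax) || qmin >= qmax)
    return std::nullopt; // mirrors compute() failing on bad or zero bounds
  // Expand the real range to include zero so the zero point lands in range
  // (cf. the lopsided-positive/negative test cases above).
  fmin = std::min(fmin, 0.0);
  fmax = std::max(fmax, 0.0);
  double scale = (fmax - fmin) / static_cast<double>(qmax - qmin);
  int zp = static_cast<int>(std::lround(qmin - fmin / scale));
  zp = std::clamp(zp, qmin, qmax);
  return std::make_pair(scale, zp);
}

int quantize(double x, double scale, int zp, int qmin, int qmax) {
  return std::clamp(static_cast<int>(std::lround(x / scale)) + zp, qmin, qmax);
}

double dequantize(int q, double scale, int zp) { return (q - zp) * scale; }
```

With quint8 storage (`qmin=0, qmax=255`) and the asymmetric range `[-8, 8.123]` from `testAsym`, the sketch reproduces the invariants the tests assert: quantizing 0.0 yields the zero point, dequantizing the zero point yields exactly 0.0, and reversed or zero-width bounds produce no solution.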