Remove FxpMathOps dialect and Quantizer tool.

Summary:
* Removal of FxpMathOps was discussed on the mailing list.
* Will send a courtesy note about also removing the Quantizer (which had some dependencies on FxpMathOps).
* These were only ever used for experimental purposes and we know how to get them back from history as needed.
* There is a new proposal for more generalized quantization tooling, so moving these older experiments out of the way helps clean things up.

Subscribers: mgorny, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, grosul1, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D77479
Stella Laurenzo 2020-04-04 19:22:05 -07:00
parent da4ffc64e4
commit f5deb0878d
53 changed files with 0 additions and 4675 deletions


@ -12,18 +12,10 @@ This document describes the available MLIR passes and their contracts.
[include "ConversionPasses.md"]
## Quantizer Passes
[include "QuantizerPasses.md"]
## `affine` Dialect Passes
[include "AffinePasses.md"]
## `fxpmath` Dialect Passes
[include "FxpMathPasses.md"]
## `gpu` Dialect Passes
[include "GPUPasses.md"]


@ -188,23 +188,6 @@ MLIR:
* Passes and tools exist to convert directly from the *TensorFlow* dialect
to the TFLite quantized operation set.
* [*FxpMath* dialect](#fxpmath-dialect) containing (experimental) generalized
representations of fixed-point math operations and conversions:
* [Real math ops](#real-math-ops) representing common combinations of
arithmetic operations that closely match corresponding fixed-point math
concepts (as opposed to being spread across multiple ops as is typical
in source dialects).
* [Fixed-point math ops](#fixed-point-math-ops) for carrying out
computations on integers, as are typically needed by uniform
quantization schemes.
* Passes to lower from real math operations to fixed-point math operations.
* [Solver tools](#solver-tools) which can (experimentally and generically)
operate on computations expressed in the *FxpMath* dialect in order to
convert from floating point types to appropriate *QuantizedTypes*, allowing
the computation to be further lowered to integral math operations.
Not every application of quantization will use all of these facilities. Specifically, the
TensorFlow to TensorFlow Lite conversion uses the QuantizedTypes but has its own
operations for type conversion and expression of the supporting math.
@ -279,81 +262,3 @@ TODO : Flesh this out
1. Run a quantization pass that takes (tfl.DQ (for both input and weights) -> op
-> tfl.Q) and replaces it with (op). Also replace (constant_float -> tfl.Q)
with (constant_quant).
## FxpMath dialect
### Real math operations
Note that these all support explicit clamps, which allows for simple fusions and
representation of some common sequences of quantization-compatible math. In
addition, some support explicit biases, which are often represented as separate
adds in source dialects.
TODO: This operation set is still evolving and needs to be completed.
* RealBinaryOp
* RealAddEwOp
* RealSubEwOp
* RealMulEwOp
* RealDivEwOp
* RealUnaryOp
* IDENTITY
* TANH
* SIGMOID
* EXP
* LOG
* NEG
* RSQRT
* SIN
* SQUARE
* SQRT
* CMPZ
* CMPNZ
* CMPLZ
* CMPGZ
### Fixed-point math operations
TODO: This operation set only has enough operations to lower a simple power-of-two
RealAddEwOp.
* RoundingDivideByPotFxpOp
* SaturatingAddFxpOp
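As a rough illustration of what the two fixed-point operations listed above compute, here is a minimal C++ sketch on 32-bit integers. It assumes the rounding behavior of `gemmlowp::RoundingDivideByPOT` (round to nearest, ties away from zero) and a saturating add that clamps to the `int32_t` range. The function names `roundingDivideByPot` and `saturatingAdd` are hypothetical and are not part of the dialect.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>

// Hypothetical helper: integer division by 2^exponent, rounded to nearest
// with ties away from zero (modeled on gemmlowp::RoundingDivideByPOT).
// Relies on arithmetic right shift of negative values, which is
// implementation-defined before C++20 but universal in practice.
std::int32_t roundingDivideByPot(std::int32_t x, int exponent) {
  assert(exponent >= 0 && exponent <= 31);
  const std::int32_t mask =
      static_cast<std::int32_t>((std::int64_t{1} << exponent) - 1);
  const std::int32_t remainder = x & mask;
  const std::int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}

// Hypothetical helper: addition that clamps to the int32_t range instead of
// wrapping around on overflow.
std::int32_t saturatingAdd(std::int32_t a, std::int32_t b) {
  const std::int64_t wide =
      static_cast<std::int64_t>(a) + static_cast<std::int64_t>(b);
  const std::int64_t lo = std::numeric_limits<std::int32_t>::min();
  const std::int64_t hi = std::numeric_limits<std::int32_t>::max();
  return static_cast<std::int32_t>(std::min(hi, std::max(wide, lo)));
}
```

For example, `roundingDivideByPot(5, 1)` yields 3, and adding 1 to the largest `int32_t` value with `saturatingAdd` leaves it unchanged.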
## Solver tools
Solver tools exist to analyze an MLIR-computation, expressed in either a
supported source dialect or in the *real math ops* set and solve for appropriate
QuantizedTypes that allow the computation to be lowered to integral math.
These tools are an active area of work and may be expanded in the future to
adjacent areas such as solving for transformations to other kinds of lower
precision types (e.g. bfloat16 or fp16).
Solver tools are expected to operate in several modes, depending on the
computation and the training characteristics of the model:
* *Transform* : With all available information in the MLIR computation, infer
boundaries where the computation can be carried out with integral math and
change types accordingly to appropriate QuantizedTypes:
* For passthrough ops which do not perform active math, change them to
operate directly on the storage type, converting in and out at the edges
via scast operations.
* For operations that have the *Quantizable* trait, the type can be set directly.
This includes operations from the [real math ops set](#real-math-ops).
* For others, encase them in appropriate dcast/qcast operations, presuming that
some follow-on pass will know what to do with them.
* *Instrument* : Most of the time, there are not sufficient implied
constraints within a computation to perform many transformations. For this
reason, the solver can insert instrumentation operations at points where additional
runtime statistics may yield solutions. It is expected that such
computations will be lowered as-is for execution, run over an appropriate
evaluation set, and statistics at each instrumentation point made available for a
future invocation of the solver.
* *Simplify* : A variety of passes and simplifications are applied once
QuantizedTypes are added in order to arrive at a computation that is
expressed in as much integral math as possible, with the fewest casts
possible.
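The dcast/qcast operations mentioned above convert between real values and a quantized storage type. A minimal C++ sketch of the usual uniform (affine) mapping follows, assuming a per-tensor scale and zero point; the `UniformParams` struct, the free functions `qcast` and `dcast`, and the round-half-away-from-zero behavior of `std::round` are illustrative assumptions rather than the solver's actual implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical parameter bundle for a uniform quantized type.
struct UniformParams {
  double scale;             // real-value step between adjacent storage values
  std::int32_t zeroPoint;   // storage value that represents real 0.0
  std::int32_t storageMin;  // e.g. -128 for signed 8-bit storage
  std::int32_t storageMax;  // e.g. 127 for signed 8-bit storage
};

// Quantize: real -> storage, clamping to the storage range.
std::int32_t qcast(double real, const UniformParams &p) {
  assert(p.scale > 0.0);
  const double scaled = std::round(real / p.scale) + p.zeroPoint;
  const double clamped =
      std::min<double>(p.storageMax, std::max<double>(scaled, p.storageMin));
  return static_cast<std::int32_t>(clamped);
}

// Dequantize: storage -> real.
double dcast(std::int32_t stored, const UniformParams &p) {
  return (static_cast<double>(stored) - p.zeroPoint) * p.scale;
}
```

With a scale of 0.5, a zero point of 0 and signed 8-bit storage, `qcast(1.0, p)` gives 2 and `dcast(2, p)` recovers 1.0, while out-of-range inputs saturate at the storage limits.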


@ -3,7 +3,6 @@ get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
set(LIBS
${dialect_libs}
${conversion_libs}
MLIRQuantizerTransforms
MLIROptLib
MLIRStandalone
)


@ -2,5 +2,4 @@ add_subdirectory(Conversion)
add_subdirectory(Dialect)
add_subdirectory(IR)
add_subdirectory(Interfaces)
add_subdirectory(Quantizer)
add_subdirectory(Transforms)


@ -1,6 +1,5 @@
add_subdirectory(Affine)
add_subdirectory(AVX512)
add_subdirectory(FxpMathOps)
add_subdirectory(GPU)
add_subdirectory(Linalg)
add_subdirectory(LLVMIR)


@ -1,8 +0,0 @@
add_mlir_dialect(FxpMathOps fxpmath)
add_mlir_doc(FxpMathOps -gen-dialect-doc FxpMathDialect Dialects/)
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
add_public_tablegen_target(MLIRFxpMathPassIncGen)
add_mlir_doc(Passes -gen-pass-doc FxpMathPasses ./)


@ -1,28 +0,0 @@
//===- FxpMathOps.h - Fixed point ops ---------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_FXPMATHOPS_FXPMATHOPS_H_
#define MLIR_DIALECT_FXPMATHOPS_FXPMATHOPS_H_
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Interfaces/SideEffects.h"
namespace mlir {
namespace fxpmath {
#include "mlir/Dialect/FxpMathOps/FxpMathOpsDialect.h.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/FxpMathOps/FxpMathOps.h.inc"
} // namespace fxpmath
} // namespace mlir
#endif // MLIR_DIALECT_FXPMATHOPS_FXPMATHOPS_H_


@ -1,278 +0,0 @@
//===- FxpMathOps.td - Fixed point ops --------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the operation definition file for fixed point ops (and real
// equivalents).
//
//===----------------------------------------------------------------------===//
#ifndef DIALECT_FXPMATHOPS_FXPMATH_OPS_
#define DIALECT_FXPMATHOPS_FXPMATH_OPS_
include "mlir/IR/OpBase.td"
include "mlir/Dialect/Quant/QuantOpsBase.td"
include "mlir/Interfaces/SideEffects.td"
def FxpMathOps_Dialect : Dialect {
let name = "fxpmath";
}
//===----------------------------------------------------------------------===//
// Attributes
//===----------------------------------------------------------------------===//
// Real value for an (inclusive) min/max clamp limit.
def fxpmath_ClampValueAttr : OptionalAttr<F64Attr>;
// Element-wise activation function to apply.
// Note that RELU activations are not here: they are expressed as clamps.
def fxpmath_EwUnaryFnAttr :
StringBasedAttr<CPred<"true">, "element-wise unary function"> {
let returnType = [{ StringRef }];
let defaultValue = "IDENTITY";
}
class fxpmath_ConstEwUnaryFn<string val> : ConstantAttr<fxpmath_EwUnaryFnAttr, val>;
def fxpmath_EwUnaryFn_Abs : fxpmath_ConstEwUnaryFn<"ABS">;
def fxpmath_EwUnaryFn_Exp : fxpmath_ConstEwUnaryFn<"EXP">;
def fxpmath_EwUnaryFn_Identity: fxpmath_ConstEwUnaryFn<"IDENTITY">;
def fxpmath_EwUnaryFn_Log : fxpmath_ConstEwUnaryFn<"LOG">;
def fxpmath_EwUnaryFn_Neg : fxpmath_ConstEwUnaryFn<"NEG">;
def fxpmath_EwUnaryFn_Rsqrt : fxpmath_ConstEwUnaryFn<"RSQRT">;
def fxpmath_EwUnaryFn_Sigmoid : fxpmath_ConstEwUnaryFn<"SIGMOID">;
def fxpmath_EwUnaryFn_Sign : fxpmath_ConstEwUnaryFn<"SIGN">;
def fxpmath_EwUnaryFn_Sin : fxpmath_ConstEwUnaryFn<"SIN">;
def fxpmath_EwUnaryFn_Sqrt : fxpmath_ConstEwUnaryFn<"SQRT">;
def fxpmath_EwUnaryFn_Square : fxpmath_ConstEwUnaryFn<"SQUARE">;
def fxpmath_EwUnaryFn_Tanh : fxpmath_ConstEwUnaryFn<"TANH">;
//===----------------------------------------------------------------------===//
// Comparison functions (compares relative to zero on a subtraction result).
//===----------------------------------------------------------------------===//
def fxpmath_CompareZ : StrEnumAttrCase<"CMPZ">;
def fxpmath_CompareNZ : StrEnumAttrCase<"CMPNZ">;
def fxpmath_CompareLZ : StrEnumAttrCase<"CMPLZ">;
def fxpmath_CompareLZE : StrEnumAttrCase<"CMPLZE">;
def fxpmath_CompareGZ : StrEnumAttrCase<"CMPGZ">;
def fxpmath_CompareGZE : StrEnumAttrCase<"CMPGZE">;
def fxpmath_CompareFnAttr : StrEnumAttr<"ComparisonFn",
"Type of subtraction-result comparison to perform.",
[
fxpmath_CompareZ,
fxpmath_CompareNZ,
fxpmath_CompareLZ,
fxpmath_CompareLZE,
fxpmath_CompareGZ,
fxpmath_CompareGZE
]>;
//===----------------------------------------------------------------------===//
// Base classes
//===----------------------------------------------------------------------===//
class fxpmath_Op<string mnemonic, list<OpTrait> traits> :
Op<FxpMathOps_Dialect, mnemonic, traits>;
//===----------------------------------------------------------------------===//
// Fixed-point (fxp) arithmetic ops used by kernels.
// Some of these are temporary pending inclusion into a more core dialect.
//===----------------------------------------------------------------------===//
def fxpmath_ClampISOp : fxpmath_Op<"clampis", [NoSideEffect, SameOperandsAndResultType]> {
let summary =
"Clamps a signed-integer like argument to a min/max range.";
let description = [{
Element-wise equivalent to:
r = std::min(clamp_max, std::max(e, clamp_min))
}];
let arguments = (ins SignlessIntegerLike:$operand,
APIntAttr:$clamp_min,
APIntAttr:$clamp_max);
let results = (outs SignlessIntegerLike);
}
def fxpmath_ConvertISOp :
fxpmath_Op<"convertis",
[NoSideEffect, SameOperandsAndResultShape]> {
let summary =
"Does an element-wise conversion from a signed integer to signed integer";
let description = [{
Similar to an element-wise static_cast in C++, from one signed integer
element type to another.
}];
let arguments = (ins SignlessIntegerLike:$operand);
let results = (outs SignlessIntegerLike);
}
def fxpmath_ConvertISToFOp :
fxpmath_Op<"convertistof",
[NoSideEffect, SameOperandsAndResultShape]> {
let summary =
"Does an element-wise conversion from a signed integer to a float";
let description = [{
Similar to an element-wise static_cast in C++, from a signed integer
element type to a floating point element type, rounding to the nearest
floating point value.
}];
let arguments = (ins SignlessIntegerLike:$operand);
let results = (outs FloatLike);
}
def fxpmath_VecScalarSaturatingRoundingDoublingHighMulISOp :
fxpmath_Op<"vs_saturating_rounding_doubling_high_mulis",
[NoSideEffect, SameOperandsAndResultType]> {
let summary = "Implements equivalent functionality to ARMv7 NEON VQRDMULH";
let description = [{
Equivalent to the ARMv7 NEON VQRDMULH instruction.
See gemmlowp::SaturatingRoundingDoublingHighMul for a reference
implementation.
}];
let arguments = (ins SignlessIntegerLike:$a, APIntAttr:$b);
let results = (outs SignlessIntegerLike);
}
def fxpmath_RoundingDivideByPotISOp :
fxpmath_Op<"rounding_divide_by_potis", [NoSideEffect, SameOperandsAndResultType]> {
let summary = [{
Computes a rounding arithmetic right shift.
}];
let description = [{
Computes integer division by a power-of-two, correctly rounded-to-nearest.
Also known as a rounding arithmetic right shift. See
gemmlowp::RoundingDivideByPOT for a reference implementation.
}];
let arguments = (ins SignlessIntegerLike:$operand, APIntAttr:$exponent);
let results = (outs SignlessIntegerLike:$res);
let verifier = [{
auto verifyExponent = exponent().getSExtValue();
if (verifyExponent < 0 || verifyExponent > 31) {
return emitOpError("exponent must be in range [0..31]");
}
return success();
}];
}
//===----------------------------------------------------------------------===//
// Real math ops.
//
// Math ops on real numbers which may have a representation in quantized
// arithmetic. It is expected that eligible ops are lowered from a source
// dialect to this set of ops prior to the process of converting a computation
// to a quantized form. It is a non-goal of these ops to preserve enough
// information to convert back to the higher level, source dialect.
//
// These ops support either real/floating point or QuantizedTypes as operands
// and results. Since not all transformations are supported (globally or
// sometimes for specific targets), a computation may end up with
// untransformable RealMathOps, in which case they need to be lowered as is
// (using floating point math).
//
// This op set takes advantage of the fact that it is typically trivial to
// combine a math function with a compatible bias addition and real-valued
// clamp (which can be done at a higher accumulation bit depth).
//
// In addition, all element-wise unary functions are collapsed into a single
// fxpmath_RealUnaryEwOp and selected via an enum-like attribute. Especially at
// low bit depths, this makes matching simpler and allows the construction of
// generic LUT-based implementations. It also allows specific lowering rules
// to consolidate runs of chained unary ops and fuse them to preceding math
// ops, potentially allowing them to operate directly on higher precision
// intermediates without resorting to lots of custom kernels for common
// formulas that can suffer from insufficient precision at low bit depths.
//
// Comparison operators are modeled as element-wise unary functions (i.e.
// CMPZ, CMPNZ, CMPLZ, CMPGZ) intended to follow a sub and output a 1bit
// quantized value. It is expected that lowering rules can fuse them with
// the preceding sub.
//===----------------------------------------------------------------------===//
class fxpmath_RealMathOp<string mnemonic, list<OpTrait> traits = [], dag args> :
fxpmath_Op<mnemonic, traits>,
Arguments<!con(args, (ins
fxpmath_ClampValueAttr:$clamp_min, fxpmath_ClampValueAttr:$clamp_max))>;
//===----------------------------------------------------------------------===//
// Element wise binary real math ops.
//===----------------------------------------------------------------------===//
class fxpmath_RealBinaryOp<string mnemonic, list<OpTrait> traits = []> :
fxpmath_RealMathOp<mnemonic, traits,
(ins quant_RealValueType:$lhs,
quant_RealValueType:$rhs)>,
Results<(outs quant_RealValueType:$res)>;
class fxpmath_RealBinaryBiasOp<string mnemonic, list<OpTrait> traits = []> :
fxpmath_RealMathOp<mnemonic, traits,
(ins quant_RealValueType:$lhs, quant_RealValueType:$rhs,
quant_RealValueType:$bias)>,
Results<(outs quant_RealValueType:$res)>;
def fxpmath_RealAddEwOp :
fxpmath_RealBinaryOp<"real_add_ew", [NoSideEffect]>;
def fxpmath_RealSubEwOp :
fxpmath_RealBinaryOp<"real_sub_ew", [NoSideEffect]>;
def fxpmath_RealMulEwOp :
fxpmath_RealBinaryOp<"real_mul_ew", [NoSideEffect]>;
def fxpmath_RealDivEwOp :
fxpmath_RealBinaryOp<"real_div_ew", [NoSideEffect]>;
//===----------------------------------------------------------------------===//
// Element wise unary real math op.
//===----------------------------------------------------------------------===//
def fxpmath_RealUnaryEwOp :
fxpmath_RealMathOp<"real_unary_ew", [NoSideEffect],
(ins quant_RealValueType:$operand, fxpmath_EwUnaryFnAttr:$fn)>,
Results<(outs quant_RealValueType:$res)>;
def fxpmath_RealCompareZeroEwOp : fxpmath_Op<"compare", [NoSideEffect]>,
Arguments<(ins quant_RealValueType:$operand, fxpmath_CompareFnAttr:$fn)>,
Results<(outs I1Tensor:$res)> {
let description = [{
Compares a real value to zero, returning an I1 (boolean) tensor with the
result of applying the comparison function.
}];
}
//===----------------------------------------------------------------------===//
// Dot op with fused bias addition.
//===----------------------------------------------------------------------===//
def fxpmath_RealMatMulOp :
fxpmath_RealBinaryOp<"real_matmul", [NoSideEffect]> {
let summary = "Matmul";
let description = [{
A matrix multiply of [m, k] and [k, n] -> [m, n] where the bias vector is
of shape [n]. Also accepts rank 3 or more input tensors, in which case
the leading dimensions are batch dims.
Many real systems have specific library calls optimized for this precise
operation, which is why it is handled explicitly versus purely as a
generalized tensor contraction.
}];
}
def fxpmath_RealMatMulBiasOp :
fxpmath_RealBinaryBiasOp<"real_matmul_bias", [NoSideEffect]> {
let summary = "Matmul with bias";
let description = [{
A specialization of a RealMatMulOp that also accepts an [n] dimension
bias vector.
In addition, there is often special support for a fused bias and clamp,
which is why they are included.
}];
}
#endif // DIALECT_FXPMATHOPS_FXPMATH_OPS_
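
The `vs_saturating_rounding_doubling_high_mulis` op in the file above defers to `gemmlowp::SaturatingRoundingDoublingHighMul` for its semantics. The following C++ sketch shows the scalar computation as it is commonly described for that reference (the high 32 bits of the doubled 64-bit product, rounded to nearest, saturating in the single overflow case). It is written from general knowledge of the technique rather than taken from this commit, and the function name is a hypothetical stand-in.

```cpp
#include <cstdint>
#include <limits>

// Hypothetical scalar model of a saturating, rounding, doubling high multiply
// (the operation performed per lane by ARM NEON VQRDMULH on 32-bit lanes).
std::int32_t saturatingRoundingDoublingHighMul(std::int32_t a, std::int32_t b) {
  // The only overflowing case is INT32_MIN * INT32_MIN, which saturates.
  if (a == b && a == std::numeric_limits<std::int32_t>::min()) {
    return std::numeric_limits<std::int32_t>::max();
  }
  const std::int64_t product =
      static_cast<std::int64_t>(a) * static_cast<std::int64_t>(b);
  // Nudge by half of the divisor so the truncating division rounds to nearest.
  const std::int32_t nudge = product >= 0 ? (1 << 30) : (1 - (1 << 30));
  // Dividing by 2^31 is equivalent to doubling and keeping the high 32 bits.
  return static_cast<std::int32_t>((product + nudge) / (std::int64_t{1} << 31));
}
```

Interpreting the operands as Q31 fixed-point values, multiplying 0.5 (`1 << 30`) by 0.5 gives 0.25 (`1 << 29`).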


@ -1,37 +0,0 @@
//===- Passes.h - Fixed point math passes -----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines all of the passes owned by the FxpMathOps dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_FXPMATHOPS_PASSES_H
#define MLIR_DIALECT_FXPMATHOPS_PASSES_H
#include <memory>
namespace mlir {
class FuncOp;
template <typename T> class OpPassBase;
namespace fxpmath {
/// Creates a pass that lowers uniform-quantized real math ops to integer
/// arithmetic. This will leave unrecognized real math ops as-is and is
/// typically followed by a pass that lowers any unrecognized ops to a pure
/// floating point form.
std::unique_ptr<OpPassBase<FuncOp>> createLowerUniformRealMathPass();
/// Creates a pass that lowers uniform-quantized qcast/dcast ops to equivalent
/// operations that perform quantize/dequantize.
std::unique_ptr<OpPassBase<FuncOp>> createLowerUniformCastsPass();
} // namespace fxpmath
} // namespace mlir
#endif // MLIR_DIALECT_FXPMATHOPS_PASSES_H


@ -1,24 +0,0 @@
//===-- Passes.td - FxpMath pass definition file -----------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_FXPMATH_PASSES
#define MLIR_DIALECT_FXPMATH_PASSES
include "mlir/Pass/PassBase.td"
def FxpMathLowerUniformCasts : Pass<"fxpmath-lower-uniform-casts"> {
let summary = "Lowers uniform-quantized casts";
let constructor = "mlir::fxpmath::createLowerUniformCastsPass()";
}
def FxpMathLowerUniformRealMath : Pass<"fxpmath-lower-uniform-real-math"> {
let summary = "Lowers uniform-quantized real math ops to integer arithmetic";
let constructor = "mlir::fxpmath::createLowerUniformRealMathPass()";
}
#endif // MLIR_DIALECT_FXPMATH_PASSES


@ -16,7 +16,6 @@
#include "mlir/Dialect/AVX512/AVX512Dialect.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@ -41,7 +40,6 @@ inline void registerAllDialects() {
static bool init_once = []() {
registerDialect<AffineDialect>();
registerDialect<avx512::AVX512Dialect>();
registerDialect<fxpmath::FxpMathOpsDialect>();
registerDialect<gpu::GPUDialect>();
registerDialect<LLVM::LLVMAVX512Dialect>();
registerDialect<LLVM::LLVMDialect>();


@ -28,14 +28,12 @@
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/FxpMathOps/Passes.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/LoopOps/Passes.h"
#include "mlir/Dialect/Quant/Passes.h"
#include "mlir/Dialect/SPIRV/Passes.h"
#include "mlir/Quantizer/Transforms/Passes.h"
#include "mlir/Transforms/LocationSnapshot.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/ViewOpGraph.h"
@ -65,10 +63,6 @@ inline void registerAllPasses() {
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Affine/Passes.h.inc"
// FxpMath
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/FxpMathOps/Passes.h.inc"
// GPU
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/GPU/Passes.h.inc"
@ -88,8 +82,6 @@ inline void registerAllPasses() {
// Quant
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Quant/Passes.h.inc"
#define GEN_PASS_REGISTRATION
#include "mlir/Quantizer/Transforms/Passes.h.inc"
// SPIR-V
#define GEN_PASS_REGISTRATION


@ -1 +0,0 @@
add_subdirectory(Transforms)


@ -1,41 +0,0 @@
//===- FxpMathConfig.h - Reference fixed point config -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a TargetConfiguration for reference fixed-point math
// quantization scheme based on the FxpMathOps (plus a small category of
// extension ops that can be added from other dialects).
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_QUANTIZER_CONFIGURATIONS_FXPMATHCONFIG_H
#define MLIR_QUANTIZER_CONFIGURATIONS_FXPMATHCONFIG_H
#include "mlir/Quantizer/Support/Configuration.h"
#include "mlir/Quantizer/Support/Metadata.h"
namespace mlir {
namespace quantizer {
/// Target configuration for a reference affine/fixed-point quantization
/// scheme defined in terms of the FxpMathOps dialect. This can be extended
/// with select ops from other dialects by way of the following public
/// methods:
/// - addValueIdentityOp
class FxpMathTargetConfig : public TargetConfiguration {
public:
/// Creates an FxpMathTargetConfig instance which can be further customized.
static std::unique_ptr<FxpMathTargetConfig> create(SolverContext &context);
protected:
FxpMathTargetConfig(SolverContext &context) : TargetConfiguration(context) {}
};
} // namespace quantizer
} // namespace mlir
#endif // MLIR_QUANTIZER_CONFIGURATIONS_FXPMATHCONFIG_H


@ -1,146 +0,0 @@
//===- Configuration.h - Configuration object base classes ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The quantizer is relatively agnostic to source and target dialects, with
// the specifics represented by configuration policy objects derived from
// classes in this file.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_QUANTIZER_SUPPORT_CONFIGURATION_H
#define MLIR_QUANTIZER_SUPPORT_CONFIGURATION_H
#include <functional>
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/Identifier.h"
#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
#include "mlir/Quantizer/Support/Metadata.h"
#include "mlir/Quantizer/Support/Rules.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringSet.h"
namespace mlir {
class Operation;
namespace quantizer {
class CAGSlice;
/// Defines quantization configuration for the target.
/// The settings here depend on a variety of details about the deployment
/// environment, although, where we have control over such things, we do
/// try to standardize as much as possible.
///
/// Non-const methods are used to setup the configuration. It is expected that
/// const instances/references are used post-build.
class TargetConfiguration {
public:
static constexpr size_t MaxSchemeIndex = 31;
using OpHandlerFn = std::function<void(Operation *op, CAGSlice &cag)>;
TargetConfiguration(SolverContext &context);
virtual ~TargetConfiguration() = default;
/// Adds a candidate type, returning its ordinal.
unsigned addCandidateType(quant::AnyQuantizedType quantizedType,
CandidateQuantizedType::Scheme scheme) {
unsigned ordinal = candidateTypes.size();
assert(allCandidateTypesMask.size() == ordinal);
CandidateQuantizedType ct{ordinal, quantizedType, scheme};
candidateTypes.push_back(ct);
allCandidateTypesMask.push_back(true);
return ordinal;
}
/// Gets a prototype scheme by index.
const CandidateQuantizedType &getCandidateType(unsigned index) const {
assert(index < candidateTypes.size());
return candidateTypes[index];
}
ArrayRef<CandidateQuantizedType> getCandidateTypes() const {
return candidateTypes;
}
/// Gets a mask of all enabled candidate types by ordinal.
llvm::SmallBitVector getAllCandidateTypesMask() const {
return allCandidateTypesMask;
}
/// Gets a mask with every candidate type except those with the given ordinals.
llvm::SmallBitVector
getCandidateTypeDisabledExceptMask(ArrayRef<unsigned> exceptOrdinals) const {
llvm::SmallBitVector disabled(allCandidateTypesMask);
for (unsigned ordinal : exceptOrdinals) {
disabled.reset(ordinal);
}
return disabled;
}
/// Adds an op handler.
template <typename OpTy>
void addOpHandler(OpHandlerFn fn) {
addOpHandlerByName(OpTy::getOperationName(), fn);
}
/// Adds an operation which requires statistics at its result nodes for
/// best quantization performance. Note that the opName StringRef is
/// expected to come from getOperationName() and be static.
template <typename OpTy>
void addRequireStatsOp() {
addRequireStatsOpByName(OpTy::getOperationName());
}
/// Returns whether op is a RequireStatsOp.
bool isRequireStatsOp(Operation *op) const;
/// Adds an op which does not mutate its values but may mutate its shape
/// or combine its operands in an arbitrary way.
/// Such ops are expected to have the same types for operands and results
/// and must be capable of operating on storage types.
template <typename OpTy>
void addValueIdentityOp() {
addValueIdentityOpByName(OpTy::getOperationName());
}
/// Handles the operation if a handler is defined for it.
void handleOp(Operation *op, CAGSlice &cag) const;
/// Finalizes the CAG after all anchors have been added.
virtual void finalizeAnchors(CAGSlice &cag) const {}
/// Whether an operand or result type is subject to analysis by this config.
virtual bool isHandledType(Type t) const = 0;
protected:
virtual void addValueIdentityOpByName(StringRef opName) = 0;
void addOpHandlerByName(StringRef name, OpHandlerFn fn);
private:
void addRequireStatsOpByName(StringRef opName);
/// Vector of all candidate type constraints, indexed by ordinal.
std::vector<CandidateQuantizedType> candidateTypes;
// A SmallBitVector with bits set for all known candidate types.
llvm::SmallBitVector allCandidateTypesMask;
/// Map of all op handlers.
llvm::StringMap<OpHandlerFn> opHandlers;
/// Names of operations which should have their results annotated with
/// statistics.
llvm::StringSet<> requireStatsOpNames;
};
} // namespace quantizer
} // namespace mlir
#endif // MLIR_QUANTIZER_SUPPORT_CONFIGURATION_H
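
The candidate-type masks in `TargetConfiguration` may be easier to follow with a small standalone model. The sketch below mirrors the logic of `getCandidateTypeDisabledExceptMask`, using `std::vector<bool>` in place of `llvm::SmallBitVector` so that it builds without LLVM; the free function `disabledExceptMask` and its `numCandidates` parameter are hypothetical stand-ins for the member function and the size of `allCandidateTypesMask`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in: start with every candidate ordinal marked as
// disabled, then clear the bit for each ordinal that should stay enabled.
std::vector<bool> disabledExceptMask(std::size_t numCandidates,
                                     const std::vector<unsigned> &exceptOrdinals) {
  std::vector<bool> disabled(numCandidates, true);
  for (unsigned ordinal : exceptOrdinals) {
    assert(ordinal < numCandidates);
    disabled[ordinal] = false;
  }
  return disabled;
}
```

With four candidate types and the ordinals 1 and 3 excepted, the resulting mask has bits 0 and 2 set (disabled) and bits 1 and 3 clear.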


@ -1,360 +0,0 @@
//===- ConstraintAnalysisGraph.h - Graphs type for constraints --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides graph-based data structures for representing anchors
// and constraints between them.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPH_H
#define MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPH_H
#include <utility>
#include <vector>
#include "mlir/IR/Function.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/Quantizer/Support/Metadata.h"
#include "llvm/ADT/DenseMap.h"
namespace mlir {
namespace quantizer {
class CAGNode;
class CAGSlice;
class TargetConfiguration;
/// A node in the Constraint Analysis Graph.
/// Nodes are either anchors (representing results and operands) or constraints.
/// Anchor nodes are connected to other anchor nodes via constraints.
/// Nodes exist within graph slices, which are typically analyses attached to
/// the function or module. Slices can contain other slices, which mirrors
/// the nesting of analyses.
///
/// Nodes have directed relationships which propagate successor-ward when dirty.
/// Relationships can be bi-directional, in which case, the constraint's
/// propagation mechanism must ensure convergence.
class CAGNode {
public:
enum class Kind {
/// Anchors.
Anchor,
OperandAnchor,
ResultAnchor,
LastAnchor = ResultAnchor,
/// Constraints.
Constraint,
SolveUniformConstraint,
UniformPropagateExplicitScale,
LastConstraint = UniformPropagateExplicitScale,
};
// Vector and iterator over nodes.
using node_vector = SmallVector<CAGNode *, 1>;
using iterator = node_vector::iterator;
using const_iterator = node_vector::const_iterator;
virtual ~CAGNode() = default;
Kind getKind() const { return kind; }
/// Unique id of the node within the slice.
int getNodeId() const { return nodeId; }
/// Whether the node is dirty, requiring one or more calls to propagate().
bool isDirty() const { return dirty; }
void markDirty() { dirty = true; }
void clearDirty() { dirty = false; }
/// Iterator over this node's child (outgoing) nodes.
const_iterator begin() const { return outgoing.begin(); }
const_iterator end() const { return outgoing.end(); }
iterator begin() { return outgoing.begin(); }
iterator end() { return outgoing.end(); }
/// Iterator over this node's parent (incoming) nodes.
const_iterator incoming_begin() const { return incoming.begin(); }
const_iterator incoming_end() const { return incoming.end(); }
iterator incoming_begin() { return incoming.begin(); }
iterator incoming_end() { return incoming.end(); }
virtual void propagate(SolverContext &solverContext,
const TargetConfiguration &config) {}
/// Prints the node label, suitable for one-line display.
virtual void printLabel(raw_ostream &os) const;
template <typename T> void findChildrenOfKind(SmallVectorImpl<T *> &found) {
for (CAGNode *child : *this) {
T *ofKind = dyn_cast<T>(child);
if (ofKind) {
found.push_back(ofKind);
}
}
}
/// Replaces this node by rerouting any parent nodes to have otherNode
/// as a child.
void replaceIncoming(CAGNode *otherNode);
/// Adds an outgoing connection to this node (and corresponding back
/// incoming connection).
void addOutgoing(CAGNode *toNode);
/// Whether this node is an orphan (has no incoming or outgoing connections).
bool isOrphan() const { return incoming.empty() && outgoing.empty(); }
protected:
CAGNode(Kind kind) : kind(kind) {}
private:
Kind kind;
int nodeId = -1;
node_vector outgoing;
node_vector incoming;
bool dirty = false;
friend class CAGSlice;
};
/// Anchor nodes represent points in the source IR where we may choose to
/// introduce a type transition. These include operands, results, arguments,
/// returns, etc.
class CAGAnchorNode : public CAGNode {
public:
enum class TypeTransformRule {
/// The owning op directly supports all transformed types. In practice,
/// this means that the op supports QuantizedType for this anchor.
Direct,
/// The type of this anchor should be set to the QuantizedType storage
/// type. This will only be valid if constraints are such that all
/// inputs/outputs converge to the same storage type (i.e. coupled).
DirectStorage,
/// The anchor must only be typed based on the expressed type. This is
/// used for ops that do not natively support quantization, and suitable
/// casts will be inserted.
ExpressedOnly,
};
/// Metadata for solving uniform quantization params.
CAGUniformMetadata &getUniformMetadata() { return uniformMetadata; }
const CAGUniformMetadata &getUniformMetadata() const {
return uniformMetadata;
}
virtual Operation *getOp() const = 0;
virtual Value getValue() const = 0;
static bool classof(const CAGNode *n) {
return n->getKind() >= Kind::Anchor && n->getKind() <= Kind::LastAnchor;
}
void propagate(SolverContext &solverContext,
const TargetConfiguration &config) override;
void printLabel(raw_ostream &os) const override;
/// Given the anchor metadata and resolved solutions, chooses the most
/// salient and returns an appropriate type to represent it.
Type getTransformedType();
TypeTransformRule getTypeTransformRule() const { return typeTransformRule; }
void setTypeTransformRule(TypeTransformRule r) { typeTransformRule = r; }
/// Gets the Type that was defined for this anchor at the time of
/// construction.
Type getOriginalType() const { return originalType; }
protected:
CAGAnchorNode(Kind kind, Type originalType)
: CAGNode(kind), originalType(originalType) {}
private:
CAGUniformMetadata uniformMetadata;
Type originalType;
TypeTransformRule typeTransformRule = TypeTransformRule::Direct;
};
/// An anchor tied to a specific operand.
/// Since operand anchors can be rewritten so that the operand refers to
/// a new result, they are maintained by reference (to the op and index).
class CAGOperandAnchor : public CAGAnchorNode {
public:
CAGOperandAnchor(Operation *op, unsigned operandIdx);
Operation *getOp() const final { return op; }
unsigned getOperandIdx() const { return operandIdx; }
static bool classof(const CAGNode *n) {
return n->getKind() == Kind::Anchor || n->getKind() == Kind::OperandAnchor;
}
Value getValue() const final { return op->getOperand(operandIdx); }
void printLabel(raw_ostream &os) const override;
private:
Operation *op;
unsigned operandIdx;
};
/// An anchor tied to a specific result.
/// Since a result is already anchored to its defining op, result anchors refer
/// directly to the underlying Value.
class CAGResultAnchor : public CAGAnchorNode {
public:
CAGResultAnchor(Operation *op, unsigned resultIdx);
static bool classof(const CAGNode *n) {
return n->getKind() == Kind::Anchor || n->getKind() == Kind::ResultAnchor;
}
Operation *getOp() const final { return resultValue.getDefiningOp(); }
Value getValue() const final { return resultValue; }
void printLabel(raw_ostream &os) const override;
private:
Value resultValue;
};
/// Base class for constraint nodes.
class CAGConstraintNode : public CAGNode {
public:
CAGConstraintNode(Kind kind) : CAGNode(kind) {}
static bool classof(const CAGNode *n) {
return n->getKind() >= Kind::Constraint &&
n->getKind() <= Kind::LastConstraint;
}
};
/// A slice of a CAG (which may be the whole graph).
class CAGSlice {
public:
CAGSlice(SolverContext &context);
~CAGSlice();
using node_vector = std::vector<CAGNode *>;
using iterator = node_vector::iterator;
using const_iterator = node_vector::const_iterator;
iterator begin() { return allNodes.begin(); }
iterator end() { return allNodes.end(); }
const_iterator begin() const { return allNodes.begin(); }
const_iterator end() const { return allNodes.end(); }
/// Gets an operand anchor node.
CAGOperandAnchor *getOperandAnchor(Operation *op, unsigned operandIdx);
/// Gets a result anchor node.
CAGResultAnchor *getResultAnchor(Operation *op, unsigned resultIdx);
/// Adds a new constraint node of type T and connects each of the given
/// anchors to it as an incoming node.
template <typename T, typename... Args>
T *addUniqueConstraint(ArrayRef<CAGAnchorNode *> anchors, Args... args) {
static_assert(std::is_convertible<T *, CAGConstraintNode *>(),
"T must be a CAGConstraintNode");
T *constraintNode = addNode(std::make_unique<T>(args...));
for (auto *anchor : anchors)
anchor->addOutgoing(constraintNode);
return constraintNode;
}
/// Adds a unidirectional constraint from a node to an array of target nodes.
template <typename T, typename... Args>
T *addUnidirectionalConstraint(CAGAnchorNode *fromAnchor,
ArrayRef<CAGAnchorNode *> toAnchors,
Args... args) {
static_assert(std::is_convertible<T *, CAGConstraintNode *>(),
"T must be a CAGConstraintNode");
T *constraintNode = addNode(std::make_unique<T>(args...));
fromAnchor->addOutgoing(constraintNode);
for (auto *toAnchor : toAnchors) {
constraintNode->addOutgoing(toAnchor);
}
return constraintNode;
}
template <typename T>
T *addClusteredConstraint(ArrayRef<CAGAnchorNode *> anchors) {
static_assert(std::is_convertible<T *, CAGConstraintNode *>(),
"T must be a CAGConstraintNode");
SmallVector<T *, 8> cluster;
for (auto *anchor : anchors) {
anchor->findChildrenOfKind<T>(cluster);
}
T *constraintNode;
if (cluster.empty()) {
// Create new.
constraintNode = addNode(std::make_unique<T>());
} else {
// Merge existing.
constraintNode = cluster[0];
for (size_t i = 1, e = cluster.size(); i < e; ++i) {
cluster[i]->replaceIncoming(constraintNode);
}
}
for (auto *anchor : anchors) {
anchor->addOutgoing(constraintNode);
}
return constraintNode;
}
/// Enumerates all implied connections in the slice.
/// An implied connection is any two nodes that physically refer to the
/// same value in the IR, such as result->operand.
/// Typically this will be modeled with some kind of strong or weak
/// identity constraint such that types propagate.
/// This is usually called when the slice has been fully constructed in
/// order to add final constraints.
/// It is legal for the callback to modify the graph by adding constraints.
void enumerateImpliedConnections(
std::function<void(CAGAnchorNode *from, CAGAnchorNode *to)> callback);
/// Performs one round of propagation, returning the number of nodes
/// propagated. If the result is > 0, then additional propagate() rounds are
/// required.
unsigned propagate(const TargetConfiguration &config);
private:
/// Adds a node to the graph.
/// The node should be a subclass of CAGNode.
/// Returns the raw pointer to the node.
template <typename T>
T *addNode(std::unique_ptr<T> node) {
node->nodeId = allNodes.size();
T *unownedNode = node.release();
allNodes.push_back(unownedNode);
return unownedNode;
}
SolverContext &context;
std::vector<CAGNode *> allNodes;
DenseMap<std::pair<Operation *, unsigned>, CAGOperandAnchor *> operandAnchors;
DenseMap<std::pair<Operation *, unsigned>, CAGResultAnchor *> resultAnchors;
};
inline raw_ostream &operator<<(raw_ostream &os, const CAGNode &node) {
node.printLabel(os);
return os;
}
} // namespace quantizer
} // namespace mlir
#endif // MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPH_H
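
The contract documented on `CAGSlice::propagate` (keep running rounds until a round propagates zero nodes) can be sketched in isolation. This is a minimal standalone illustration, not the removed implementation: `NodeSketch`, `propagateOnce`, and `solveToFixpoint` are hypothetical names, and the fact carried by each node is assumed to be a single integer merged with `max`, which is monotone and therefore converges even when edges form a cycle.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a CAG node: one integer "fact", a dirty flag,
// and indices of successor nodes.
struct NodeSketch {
  int value = 0;
  bool dirty = false;
  std::vector<std::size_t> outgoing;
};

// Performs one round: every dirty node is cleaned and pushes its value to its
// successors. A successor whose value grows is marked dirty in turn.
// Returns the number of nodes that were propagated in this round.
inline unsigned propagateOnce(std::vector<NodeSketch> &nodes) {
  unsigned count = 0;
  for (std::size_t i = 0; i < nodes.size(); ++i) {
    if (!nodes[i].dirty)
      continue;
    nodes[i].dirty = false;
    ++count;
    for (std::size_t succ : nodes[i].outgoing) {
      if (nodes[succ].value < nodes[i].value) {
        nodes[succ].value = nodes[i].value;
        nodes[succ].dirty = true;
      }
    }
  }
  return count;
}

// Repeats rounds until one reports no propagated nodes, or until maxRounds
// is reached. Returns the number of rounds that did work.
inline unsigned solveToFixpoint(std::vector<NodeSketch> &nodes,
                                unsigned maxRounds) {
  assert(maxRounds > 0 && "need at least one round");
  unsigned rounds = 0;
  while (rounds < maxRounds && propagateOnce(nodes) > 0)
    ++rounds;
  return rounds;
}
```

The round cap is a design choice of this sketch: a merge rule that is not monotone could oscillate forever on bi-directional relationships, which is why the header requires the constraint's propagation mechanism to ensure convergence.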

View File

@ -1,49 +0,0 @@
//===- ConstraintAnalysisGraphTraits.h - Traits for CAGs --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Provides graph traits for constraint analysis graphs.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPHTRAITS_H
#define MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPHTRAITS_H
#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
#include "llvm/ADT/GraphTraits.h"
namespace llvm {
template <>
struct GraphTraits<const mlir::quantizer::CAGNode *> {
using NodeRef = const mlir::quantizer::CAGNode *;
static NodeRef getEntryNode(NodeRef node) { return node; }
// Successors.
using ChildIteratorType = mlir::quantizer::CAGNode::const_iterator;
static ChildIteratorType child_begin(NodeRef node) { return node->begin(); }
static ChildIteratorType child_end(NodeRef node) { return node->end(); }
};
template <>
struct GraphTraits<const mlir::quantizer::CAGSlice *>
: public llvm::GraphTraits<const mlir::quantizer::CAGNode *> {
using nodes_iterator = mlir::quantizer::CAGSlice::const_iterator;
static mlir::quantizer::CAGSlice::const_iterator
nodes_begin(const mlir::quantizer::CAGSlice *G) {
return G->begin();
}
static mlir::quantizer::CAGSlice::const_iterator
nodes_end(const mlir::quantizer::CAGSlice *G) {
return G->end();
}
};
} // end namespace llvm
#endif // MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPHTRAITS_H

View File

@ -1,101 +0,0 @@
//===- Metadata.h - Top level types and metadata ----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains top level types needed to construct constraint graphs,
// including context/allocator support and concrete metadata structs for
// different quantization schemes (which must be attached to anchor nodes).
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_QUANTIZER_SUPPORT_METADATA_H
#define MLIR_QUANTIZER_SUPPORT_METADATA_H
#include <limits>
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Quantizer/Support/Rules.h"
#include "llvm/ADT/SmallBitVector.h"
namespace mlir {
namespace quantizer {
class SolverContext {
public:
SolverContext(MLIRContext &mlirContext) : mlirContext(mlirContext) {}
MLIRContext &getMlirContext() { return mlirContext; }
llvm::BumpPtrAllocator &getAllocator() { return allocator; }
// Optional path to write a debug DOT file for the CAG.
StringRef getDebugCAGDotPath() const { return debugCAGDotPath; }
void setDebugCAGDotPath(StringRef p) { debugCAGDotPath = std::string(p); }
private:
MLIRContext &mlirContext;
llvm::BumpPtrAllocator allocator;
std::string debugCAGDotPath;
};
/// Candidate for a quantized type conversion.
struct CandidateQuantizedType {
// Note that scheme encodes more than just the target type: it also encodes
// additional constraints.
enum class Scheme {
// Uses aggregate range information for all nodes in the cluster to
// solve for uniform scale and zero point.
UniformPerLayer,
// Uses aggregate per-axis range information for all nodes in the cluster
// to solve for per-axis uniform scale and zero point.
UniformPerAxisFixedPoint,
// Uses the |explicitScaleZeroPoint| to set the scale (and zero point = 0)
// for the uniform type. This typically overrides all other constraints
// and is used for wide accumulator types (e.g. i32 bias vectors).
UniformExplicitFixedPointScale,
};
unsigned ordinal;
quant::AnyQuantizedType quantizedType;
Scheme scheme;
};
struct CAGUniformMetadata {
/// Default salience for facts that are derived from data either statically
/// discovered in the computation or observed from an outside source.
static constexpr int SalienceDefault = 0;
/// Highest salience level for facts derived from overrides provided
/// explicitly.
static constexpr int SalienceForced = 100;
/// Salience for facts derived from constraints in how the math is
/// expressed which must be satisfied.
static constexpr int SalienceRequired = 200;
/// The range that the scheme must represent in order to accommodate the
/// underlying data.
ExpandingMinMaxFact requiredRange;
/// Bool vector of scheme ordinals that are disabled.
llvm::SmallBitVector disabledCandidateTypes;
/// If set, then a solution has converged for the given per-layer scheme.
quant::QuantizedType selectedType;
/// Optional scale and zero point to be used by types which solve via the
/// UniformExplicitFixedPointScale scheme.
DiscreteScaleZeroPointFact explicitScaleZeroPoint;
/// Prints a summary of the metadata suitable for display in a graph label.
void printSummary(raw_ostream &os) const;
};
} // end namespace quantizer
} // end namespace mlir
#endif // MLIR_QUANTIZER_SUPPORT_METADATA_H

View File

@ -1,200 +0,0 @@
//===- Rules.h - Helpers for declaring facts and rules ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines helper classes and functions for managing state (facts),
// merging and tracking modification for various data types important for
// quantization.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_QUANTIZER_SUPPORT_RULES_H
#define MLIR_QUANTIZER_SUPPORT_RULES_H
#include "llvm/ADT/Optional.h"
#include <algorithm>
#include <limits>
#include <utility>
namespace mlir {
namespace quantizer {
/// Typed indicator of whether a mutator produces a modification.
struct ModificationResult {
enum ModificationEnum { Retained, Modified } value;
ModificationResult(ModificationEnum v) : value(v) {}
ModificationResult operator|(ModificationResult other) {
if (value == Modified || other.value == Modified) {
return ModificationResult(Modified);
} else {
return ModificationResult(Retained);
}
}
ModificationResult operator|=(ModificationResult other) {
value =
(value == Modified || other.value == Modified) ? Modified : Retained;
return *this;
}
};
inline ModificationResult modify(bool isModified = true) {
return ModificationResult{isModified ? ModificationResult::Modified
: ModificationResult::Retained};
}
inline bool modified(ModificationResult m) {
return m.value == ModificationResult::Modified;
}
/// A fact that can converge through forward propagation alone without the
/// need to track ownership or individual assertions. In practice, this works
/// for static assertions that are either minimized or maximized and do not
/// vary dynamically.
///
/// It is expected that ValueTy is appropriate to pass by value and supports
/// equality comparison. The BinaryReducer type should define a type alias and
/// two static methods:
/// using ValueTy : Type of the value.
/// ValueTy initialValue() : Returns the initial value of the fact.
/// ValueTy reduce(ValueTy lhs, ValueTy rhs) : Reduces two values.
template <typename BinaryReducer>
class BasePropagatedFact {
public:
using ValueTy = typename BinaryReducer::ValueTy;
using ThisTy = BasePropagatedFact<BinaryReducer>;
BasePropagatedFact()
: value(BinaryReducer::initialValue()),
salience(std::numeric_limits<int>::min()) {}
int getSalience() const { return salience; }
bool hasValue() const { return salience != std::numeric_limits<int>::min(); }
ValueTy getValue() const { return value; }
ModificationResult assertValue(int assertSalience, ValueTy assertValue) {
if (assertSalience > salience) {
// New salience band.
value = assertValue;
salience = assertSalience;
return modify(true);
} else if (assertSalience < salience) {
// Lower salience - ignore.
return modify(false);
}
// Merge within same salience band.
ValueTy updatedValue = BinaryReducer::reduce(value, assertValue);
auto mod = modify(value != updatedValue);
value = updatedValue;
return mod;
}
ModificationResult mergeFrom(const ThisTy &other) {
if (other.hasValue()) {
return assertValue(other.getSalience(), other.getValue());
}
return modify(false);
}
private:
ValueTy value;
int salience;
};
/// A binary reducer that expands a min/max range represented by a pair
/// of doubles such that it represents the largest of all inputs.
/// The initial value is (Inf, -Inf).
struct ExpandingMinMaxReducer {
using ValueTy = std::pair<double, double>;
static ValueTy initialValue() {
return std::make_pair(std::numeric_limits<double>::infinity(),
-std::numeric_limits<double>::infinity());
}
static ValueTy reduce(ValueTy lhs, ValueTy rhs) {
return std::make_pair(std::min(lhs.first, rhs.first),
std::max(lhs.second, rhs.second));
}
};
using ExpandingMinMaxFact = BasePropagatedFact<ExpandingMinMaxReducer>;
/// A binary reducer that minimizes a numeric type.
template <typename T>
struct MinimizingNumericReducer {
using ValueTy = T;
static ValueTy initialValue() {
if (std::numeric_limits<T>::has_infinity) {
return std::numeric_limits<T>::infinity();
} else {
return std::numeric_limits<T>::max();
}
}
static ValueTy reduce(ValueTy lhs, ValueTy rhs) { return std::min(lhs, rhs); }
};
using MinimizingDoubleFact =
BasePropagatedFact<MinimizingNumericReducer<double>>;
using MinimizingIntFact = BasePropagatedFact<MinimizingNumericReducer<int>>;
/// A binary reducer that maximizes a numeric type.
template <typename T>
struct MaximizingNumericReducer {
using ValueTy = T;
static ValueTy initialValue() {
if (std::numeric_limits<T>::has_infinity) {
return -std::numeric_limits<T>::infinity();
} else {
return std::numeric_limits<T>::min();
}
}
static ValueTy reduce(ValueTy lhs, ValueTy rhs) { return std::max(lhs, rhs); }
};
using MaximizingDoubleFact =
BasePropagatedFact<MaximizingNumericReducer<double>>;
using MaximizingIntFact = BasePropagatedFact<MaximizingNumericReducer<int>>;
/// A fact and reducer for tracking agreement of discrete values. The value
/// type consists of a |T| value and a flag indicating whether there is a
/// conflict (in which case, the preserved value is arbitrary).
template <typename T>
struct DiscreteReducer {
struct ValueTy {
ValueTy() : conflict(false) {}
ValueTy(T value) : value(value), conflict(false) {}
ValueTy(T value, bool conflict) : value(value), conflict(conflict) {}
llvm::Optional<T> value;
bool conflict;
bool operator==(const ValueTy &other) const {
if (conflict != other.conflict)
return false;
if (value && other.value) {
return *value == *other.value;
} else {
return !value && !other.value;
}
}
bool operator!=(const ValueTy &other) const { return !(*this == other); }
};
static ValueTy initialValue() { return ValueTy(); }
static ValueTy reduce(ValueTy lhs, ValueTy rhs) {
if (!lhs.value && !rhs.value)
return lhs;
else if (!lhs.value)
return rhs;
else if (!rhs.value)
return lhs;
else
return ValueTy(*lhs.value, *lhs.value != *rhs.value);
}
};
template <typename T>
using DiscreteFact = BasePropagatedFact<DiscreteReducer<T>>;
/// Discrete scale/zeroPoint fact.
using DiscreteScaleZeroPointFact = DiscreteFact<std::pair<double, int64_t>>;
} // end namespace quantizer
} // end namespace mlir
#endif // MLIR_QUANTIZER_SUPPORT_RULES_H
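
The salience-band behavior of `BasePropagatedFact::assertValue` is easier to follow with concrete values. The sketch below re-expresses the same three cases (higher salience replaces, lower salience is ignored, equal salience merges) for a min/max range without any LLVM dependencies. `MinMaxFactSketch` is a hypothetical name used only for illustration, and the method returns a plain `bool` instead of `ModificationResult`.

```cpp
#include <algorithm>
#include <cassert>
#include <limits>
#include <utility>

// Hypothetical standalone analogue of ExpandingMinMaxFact.
struct MinMaxFactSketch {
  // Starts as the empty range (Inf, -Inf) with the lowest possible salience.
  std::pair<double, double> value{std::numeric_limits<double>::infinity(),
                                  -std::numeric_limits<double>::infinity()};
  int salience = std::numeric_limits<int>::min();

  bool hasValue() const {
    return salience != std::numeric_limits<int>::min();
  }

  // Returns true when the stored fact changed.
  bool assertValue(int assertSalience, std::pair<double, double> asserted) {
    if (assertSalience > salience) {
      // New salience band: the asserted value replaces the old one.
      value = asserted;
      salience = assertSalience;
      return true;
    }
    if (assertSalience < salience) {
      // Lower salience: ignored.
      return false;
    }
    // Same band: expand the range to cover both values.
    std::pair<double, double> merged{std::min(value.first, asserted.first),
                                     std::max(value.second, asserted.second)};
    bool changed = merged != value;
    value = merged;
    return changed;
  }
};
```

For example, asserting (-1, 1) and then (-2, 0.5) at salience 0 yields (-2, 1). A later assertion of (0, 6) at salience 100 replaces the range outright, and further salience-0 assertions are then ignored.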

View File

@ -1,102 +0,0 @@
//===- Statistics.h - Collects statistics over tensors ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines adapters for extracting various (per layer and per axis)
// statistics over tensors.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_QUANTIZER_SUPPORT_STATISTICS_H
#define MLIR_QUANTIZER_SUPPORT_STATISTICS_H
#include "mlir/IR/Attributes.h"
namespace mlir {
namespace quantizer {
/// Statistics about a tensor axis (or the whole tensor).
struct TensorAxisStatistics {
int64_t sampleSize = 0;
double minValue = 0;
double maxValue = 0;
double mean = 0;
double variance = 0;
int64_t sampleSizePerAxis = 0;
SmallVector<double, 4> minValuePerAxis;
SmallVector<double, 4> maxValuePerAxis;
SmallVector<double, 4> meanPerAxis;
SmallVector<double, 4> variancePerAxis;
TensorAxisStatistics() {}
TensorAxisStatistics(int64_t sampleSize, double minValue, double maxValue,
double mean, double variance)
: sampleSize(sampleSize), minValue(minValue), maxValue(maxValue),
mean(mean), variance(variance) {}
TensorAxisStatistics(int64_t sampleSize, ArrayRef<double> minValues,
ArrayRef<double> maxValues, ArrayRef<double> means,
ArrayRef<double> variances)
: sampleSizePerAxis(sampleSize),
minValuePerAxis(minValues.begin(), minValues.end()),
maxValuePerAxis(maxValues.begin(), maxValues.end()),
meanPerAxis(means.begin(), means.end()),
variancePerAxis(variances.begin(), variances.end()) {}
void clear() { *this = TensorAxisStatistics(); }
};
/// Base class for querying statistics about a tensor.
class AbstractTensorStatistics {
public:
virtual ~AbstractTensorStatistics() = default;
/// Gets statistics across the whole tensor.
/// Returns true if statistics are valid and were populated.
virtual bool get(TensorAxisStatistics &stats) const { return false; }
/// Whether this instance supports querying per axis statistics. If true,
/// then getForAxis(...) can be used.
virtual bool supportsPerAxis() const { return false; }
/// Count of axes supported in a per-axis query.
virtual unsigned getAxisCount() const { return 0; }
/// Gets statistics for a specific axis (0..getAxisCount() - 1).
/// Returns true if statistics are valid and were populated.
virtual bool getForAxis(unsigned axis, TensorAxisStatistics &stats) const {
return false;
}
};
/// Wraps an MLIR Attribute and returns statistics about it.
/// It is expected that the attribute be one of:
/// FloatAttr (scalar)
/// DenseFPElementsAttr
/// OpaqueElementsAttr (with Float based type)
/// SparseElementsAttr (with Float based type)
class AttributeTensorStatistics : public AbstractTensorStatistics {
public:
AttributeTensorStatistics(Attribute attr) : attr(attr) {}
bool get(TensorAxisStatistics &stats) const override;
bool supportsPerAxis() const override;
unsigned getAxisCount() const override;
bool getForAxis(unsigned axis, TensorAxisStatistics &stats) const override;
private:
Attribute attr;
};
raw_ostream &operator<<(raw_ostream &os, const TensorAxisStatistics &stats);
} // end namespace quantizer
} // end namespace mlir
#endif // MLIR_QUANTIZER_SUPPORT_STATISTICS_H
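
To make the whole-tensor fields of `TensorAxisStatistics` concrete, here is a minimal sketch of how such statistics could be computed from a flat buffer of doubles. It is an assumption-laden illustration rather than the removed `AttributeTensorStatistics` implementation: `StatsSketch` and `computeStatsSketch` are hypothetical names, the input is a `std::vector<double>` instead of an MLIR attribute, and the variance is assumed to be the population variance.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical mirror of the whole-tensor fields of TensorAxisStatistics.
struct StatsSketch {
  int64_t sampleSize = 0;
  double minValue = 0;
  double maxValue = 0;
  double mean = 0;
  double variance = 0;
};

// Fills stats from values. Returns false (leaving stats untouched) when there
// is nothing to measure, matching the "returns true if populated" convention.
inline bool computeStatsSketch(const std::vector<double> &values,
                               StatsSketch &stats) {
  if (values.empty())
    return false;

  double minValue = values.front();
  double maxValue = values.front();
  double sum = 0;
  for (double v : values) {
    minValue = std::min(minValue, v);
    maxValue = std::max(maxValue, v);
    sum += v;
  }
  double mean = sum / static_cast<double>(values.size());

  // Second pass: population variance (assumed; divides by N, not N - 1).
  double sumSquaredDiff = 0;
  for (double v : values)
    sumSquaredDiff += (v - mean) * (v - mean);

  stats.sampleSize = static_cast<int64_t>(values.size());
  stats.minValue = minValue;
  stats.maxValue = maxValue;
  stats.mean = mean;
  stats.variance = sumSquaredDiff / static_cast<double>(values.size());
  assert(stats.minValue <= stats.maxValue);
  return true;
}
```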

View File

@ -1,31 +0,0 @@
//===- TypeUtils.h - Helper function for manipulating types -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines various helper functions for manipulating types. The
// process of quantizing typically involves a number of type manipulations
// that are not very common elsewhere, and it is best to name them and define
// them here versus inline in the rest of the tool.
//
//===----------------------------------------------------------------------===//
#ifndef THIRD_PARTY_MLIR_EDGE_FXPSOLVER_SUPPORT_TYPEUTILS_H_
#define THIRD_PARTY_MLIR_EDGE_FXPSOLVER_SUPPORT_TYPEUTILS_H_
#include "mlir/IR/Types.h"
namespace mlir {
namespace quantizer {
/// Given an arbitrary container or primitive type, returns the element type,
/// where the element type is just the type for non-containers.
Type getElementOrPrimitiveType(Type t);
} // namespace quantizer
} // namespace mlir
#endif // THIRD_PARTY_MLIR_EDGE_FXPSOLVER_SUPPORT_TYPEUTILS_H_

View File

@ -1,60 +0,0 @@
//===- UniformConstraints.h - Constraints for uniform quant -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a builder that lets you attach constraints necessary to
// perform a variety of uniform quantization conversions to CAG anchors.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_QUANTIZER_SUPPORT_UNIFORMCONSTRAINTS_H
#define MLIR_QUANTIZER_SUPPORT_UNIFORMCONSTRAINTS_H
#include "mlir/Quantizer/Support/Statistics.h"
namespace mlir {
namespace quantizer {
class CAGAnchorNode;
class CAGSlice;
/// Factory methods for adding CAG constraints of various kinds suitable
/// for solving for uniform quantization.
class UniformConstraintsBuilder {
public:
UniformConstraintsBuilder(CAGSlice &slice) : slice(slice) {}
/// Adds a coupling constraint between two nodes, effectively treating
/// them as a hard identity relationship.
void coupleAnchors(CAGAnchorNode *a, CAGAnchorNode *b);
/// Applies statistics constraints to the given anchor, such that the solver
/// ensures that the statistics are representable by chosen types.
void applyStats(CAGAnchorNode *a, TensorAxisStatistics stats);
/// Applies a constraint to a node which allows solutions that do not extend
/// beyond given min/max bounds (this is a hint that the tensor will not
/// take values outside of these bounds). If either minValue or maxValue is
/// NAN, then that side is considered open.
void clamp(CAGAnchorNode *a, APFloat minValue, APFloat maxValue);
/// Propagates an explicit scale from an anchor that may have a uniform
/// |selectedType| to the |explicitScaleZeroPoint| field of the to node.
/// This is typically used with a to node that has a candidate quantized
/// type of |UniformExplicitFixedPointScale|, indicating that it can be
/// an arbitrary (signed) type that is expected to share the same scale
/// as the originating node.
void propagateExplicitScale(CAGAnchorNode *from, CAGAnchorNode *to);
private:
CAGSlice &slice;
};
} // namespace quantizer
} // namespace mlir
#endif // MLIR_QUANTIZER_SUPPORT_UNIFORMCONSTRAINTS_H

View File

@ -1,86 +0,0 @@
//===- UniformSolvers.h - Uniform type solver algorithms --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines algorithms for solving uniform type parameters for various
// conditions (e.g. fixed-point, affine, scale matching).
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_QUANTIZER_SUPPORT_UNIFORMSOLVERS_H
#define MLIR_QUANTIZER_SUPPORT_UNIFORMSOLVERS_H
#include <cstdint>
#include <limits>
namespace llvm {
class raw_ostream;
} // end namespace llvm
namespace mlir {
namespace quantizer {
struct UniformStorageParams {
static UniformStorageParams getQuint8() { return {255, 0}; }
static UniformStorageParams getQuint8SymmetricRight() { return {254, 1}; }
static UniformStorageParams getQuint16() { return {32767, 0}; }
uint64_t numLevels;
int64_t minValue;
};
/// Solves for the uniform quantization scheme parameters delta and z given
/// bounding min/max.
class UniformParamsFromMinMaxSolver {
public:
UniformParamsFromMinMaxSolver(const UniformStorageParams &storageParams,
double boundingMin, double boundingMax)
: storageParams(storageParams), boundingMin(boundingMin),
boundingMax(boundingMax) {}
/// Performs the computation, returning whether satisfied.
bool compute();
// Params.
double getBoundingMin() const { return boundingMin; }
double getBoundingMax() const { return boundingMax; }
bool isSatisfied() const { return satisfied; }
double getAdjMin() const { return adjMin; }
double getAdjMax() const { return adjMax; }
double getScale() const { return delta; }
int64_t getZp() const { return zp; }
int getStepCount() const { return stepCount; }
// Quantize and dequantize.
int64_t quantize(double x) const;
double dequantize(int64_t xq) const;
private:
const UniformStorageParams storageParams;
const double boundingMin;
const double boundingMax;
// Results
int stepCount = 0;
double adjMin = std::numeric_limits<double>::quiet_NaN();
double adjMax = std::numeric_limits<double>::quiet_NaN();
double delta = std::numeric_limits<double>::quiet_NaN();
int64_t zp = 0;
bool satisfied = false;
};
llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
const UniformStorageParams &p);
llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
const UniformParamsFromMinMaxSolver &s);
} // end namespace quantizer
} // end namespace mlir
#endif // MLIR_QUANTIZER_SUPPORT_UNIFORMSOLVERS_H
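
The body of `UniformParamsFromMinMaxSolver::compute()` is not part of this header, so the following is only a sketch of the textbook affine scheme that the interface suggests: derive a scale (delta) and zero point from a bounding min/max, then quantize and dequantize with them. All names ending in `Sketch` are hypothetical, and the choice to widen the range so that real 0.0 maps exactly to the zero point is an assumption of this example.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical result type: scale, zero point, and the storage bounds.
struct AffineParamsSketch {
  double scale;
  int64_t zeroPoint;
  int64_t storageMin;
  int64_t storageMax;
};

// numLevels is the number of steps in the storage type (e.g. 255 for 8 bits)
// and minValue is the lowest storage value, as in UniformStorageParams.
inline AffineParamsSketch solveAffineSketch(double boundingMin,
                                            double boundingMax,
                                            uint64_t numLevels,
                                            int64_t minValue) {
  assert(numLevels > 0 && boundingMin < boundingMax);
  // Widen the range to include 0.0 so that zero is exactly representable.
  double lo = std::min(boundingMin, 0.0);
  double hi = std::max(boundingMax, 0.0);
  double scale = (hi - lo) / static_cast<double>(numLevels);
  int64_t zeroPoint =
      minValue + static_cast<int64_t>(std::llround(-lo / scale));
  return {scale, zeroPoint, minValue,
          minValue + static_cast<int64_t>(numLevels)};
}

// Maps a real value to storage, rounding to nearest and clamping to range.
inline int64_t quantizeSketch(const AffineParamsSketch &p, double x) {
  int64_t q = static_cast<int64_t>(std::llround(x / p.scale)) + p.zeroPoint;
  return std::min(std::max(q, p.storageMin), p.storageMax);
}

// Maps a storage value back to the real value it represents.
inline double dequantizeSketch(const AffineParamsSketch &p, int64_t q) {
  return static_cast<double>(q - p.zeroPoint) * p.scale;
}
```

With a bounding range of [-64, 191] and 255 levels starting at storage value 0, the scale is 1.0 and the zero point is 64, so the real value 10.0 quantizes to 74 and dequantizes back to 10.0. Values outside the range saturate at the storage bounds.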

View File

@ -1,6 +0,0 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
add_public_tablegen_target(MLIRQuantizerPassIncGen)
add_mlir_doc(Passes -gen-pass-doc QuantizerPasses ./)

View File

@ -1,43 +0,0 @@
//===- Passes.h - Quantizer passes -----------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines entry points to create passes to perform various kinds
// of quantization related transforms.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_QUANTIZER_TRANSFORMS_PASSES_H
#define MLIR_QUANTIZER_TRANSFORMS_PASSES_H
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace quantizer {
class SolverContext;
class TargetConfiguration;
/// Creates a pass that infers quantized types based on metadata discovered
/// in the computation.
std::unique_ptr<OpPassBase<ModuleOp>>
createInferQuantizedTypesPass(SolverContext &solverContext,
const TargetConfiguration &config);
std::unique_ptr<OpPassBase<ModuleOp>> createInferQuantizedTypesPass();
/// Creates a pass which removes any instrumentation and hint ops which have
/// no effect on final runtime.
std::unique_ptr<OpPassBase<FuncOp>> createRemoveInstrumentationPass();
/// Adds default (dummy) statistics to ops that can benefit from runtime stats.
/// Meant for testing.
std::unique_ptr<OpPassBase<FuncOp>> createAddDefaultStatsPass();
} // namespace quantizer
} // namespace mlir
#endif // MLIR_QUANTIZER_TRANSFORMS_PASSES_H

View File

@ -1,31 +0,0 @@
//===-- Passes.td - Quantizer pass definition file ---------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_QUANTIZER_TRANSFORMS_PASSES
#define MLIR_QUANTIZER_TRANSFORMS_PASSES
include "mlir/Pass/PassBase.td"
def QuantizerAddDefaultStats : Pass<"quantizer-add-default-stats-test"> {
let summary = "Add default (dummy) statistics to all ops that can benefit "
"from runtime statistics";
let constructor = "mlir::quantizer::createAddDefaultStatsPass()";
}
def QuantizerInferQuantizedTypes : Pass<"quantizer-infer-quantized-types"> {
let summary = "Infer quantized types for a module";
let constructor = "mlir::quantizer::createInferQuantizedTypesPass()";
}
def QuantizerRemoveInstrumentation : Pass<"quantizer-remove-instrumentation"> {
let summary = "Remove instrumentation and hints which have no effect on "
"final execution";
let constructor = "mlir::quantizer::createRemoveInstrumentationPass()";
}
#endif // MLIR_QUANTIZER_TRANSFORMS_PASSES

@ -7,7 +7,6 @@ add_subdirectory(IR)
add_subdirectory(Interfaces)
add_subdirectory(Parser)
add_subdirectory(Pass)
add_subdirectory(Quantizer)
add_subdirectory(Support)
add_subdirectory(TableGen)
add_subdirectory(Target)

@ -1,6 +1,5 @@
add_subdirectory(Affine)
add_subdirectory(AVX512)
add_subdirectory(FxpMathOps)
add_subdirectory(GPU)
add_subdirectory(Linalg)
add_subdirectory(LLVMIR)

@ -1,21 +0,0 @@
add_mlir_dialect_library(MLIRFxpMathOps
IR/FxpMathOps.cpp
Transforms/LowerUniformRealMath.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/FxpMathOps
DEPENDS
MLIRFxpMathOpsIncGen
MLIRFxpMathPassIncGen
)
target_link_libraries(MLIRFxpMathOps
PUBLIC
MLIRQuant
MLIRIR
MLIRPass
MLIRSideEffects
MLIRSupport
MLIRStandardOps
)

@ -1,29 +0,0 @@
//===- FxpMathOps.cpp - Op implementation for FxpMathOps ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/MathExtras.h"
using namespace mlir;
using namespace mlir::fxpmath;
#define GET_OP_CLASSES
#include "mlir/Dialect/FxpMathOps/FxpMathOps.cpp.inc"
FxpMathOpsDialect::FxpMathOpsDialect(MLIRContext *context)
: Dialect(/*name=*/"fxpmath", context) {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/FxpMathOps/FxpMathOps.cpp.inc"
>();
}

@ -1,394 +0,0 @@
//===- LowerUniformRealMath.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "UniformKernelUtils.h"
#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
#include "mlir/Dialect/FxpMathOps/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
using namespace mlir::fxpmath;
using namespace mlir::fxpmath::detail;
using namespace mlir::quant;
namespace {
struct LowerUniformRealMathPass
: public FunctionPass<LowerUniformRealMathPass> {
/// Include the generated pass utilities.
#define GEN_PASS_FxpMathLowerUniformRealMath
#include "mlir/Dialect/FxpMathOps/Passes.h.inc"
void runOnFunction() override;
};
struct LowerUniformCastsPass : public FunctionPass<LowerUniformCastsPass> {
/// Include the generated pass utilities.
#define GEN_PASS_FxpMathLowerUniformCasts
#include "mlir/Dialect/FxpMathOps/Passes.h.inc"
void runOnFunction() override;
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// Dequantize
//===----------------------------------------------------------------------===//
static Value emitUniformPerLayerDequantize(Location loc, Value input,
UniformQuantizedType elementType,
PatternRewriter &rewriter) {
// Pre-conditions.
if (!elementType.isSigned()) {
// TODO: Support unsigned storage type.
emitWarning(loc, "unimplemented: dequantize unsigned uniform");
return nullptr;
}
Type storageType = elementType.castToStorageType(input.getType());
Type realType = elementType.castToExpressedType(input.getType());
// Check the casts before the storage type is used to derive other types.
assert(storageType && "cannot cast to storage type");
assert(realType && "cannot cast to expressed type");
Type intermediateType =
castElementType(storageType, IntegerType::get(32, rewriter.getContext()));
// Cast to storage type.
input = rewriter.create<StorageCastOp>(loc, storageType, input);
// Promote to intermediate type.
input = rewriter.create<ConvertISOp>(loc, intermediateType, input);
// Apply zero-point offset.
if (elementType.getZeroPoint() != 0) {
Value negZeroPointConst = rewriter.create<ConstantOp>(
loc, broadcastScalarConstIntValue(intermediateType,
-elementType.getZeroPoint()));
input = rewriter.create<AddIOp>(loc, input, negZeroPointConst);
}
// Convert to float.
input = rewriter.create<ConvertISToFOp>(loc, realType, input);
// Mul by scale.
Value scaleConst = rewriter.create<ConstantOp>(
loc, broadcastScalarConstFloatValue(realType,
APFloat(elementType.getScale())));
return rewriter.create<MulFOp>(loc, input, scaleConst);
}
static Value
emitUniformPerAxisDequantize(Location loc, Value input,
UniformQuantizedPerAxisType elementType,
PatternRewriter &rewriter) {
// TODO: Support per-axis dequantize.
rewriter.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Warning)
<< "unimplemented: per-axis uniform dequantization";
return nullptr;
}
static Value emitDequantize(Location loc, Value input,
PatternRewriter &rewriter) {
Type inputType = input.getType();
QuantizedType qElementType =
QuantizedType::getQuantizedElementType(inputType);
if (auto uperLayerElementType =
qElementType.dyn_cast_or_null<UniformQuantizedType>()) {
return emitUniformPerLayerDequantize(loc, input, uperLayerElementType,
rewriter);
} else if (auto uperAxisElementType =
qElementType.dyn_cast_or_null<UniformQuantizedPerAxisType>()) {
return emitUniformPerAxisDequantize(loc, input, uperAxisElementType,
rewriter);
} else {
return nullptr;
}
}
namespace {
struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> {
using OpRewritePattern<DequantizeCastOp>::OpRewritePattern;
LogicalResult matchAndRewrite(DequantizeCastOp op,
PatternRewriter &rewriter) const override {
Type inputType = op.arg().getType();
Type outputType = op.getResult().getType();
QuantizedType inputElementType =
QuantizedType::getQuantizedElementType(inputType);
Type expressedOutputType = inputElementType.castToExpressedType(inputType);
if (expressedOutputType != outputType) {
// Not a valid uniform cast.
return failure();
}
Value dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter);
if (!dequantizedValue) {
return failure();
}
rewriter.replaceOp(op, dequantizedValue);
return success();
}
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// Elementwise add
//===----------------------------------------------------------------------===//
static LogicalResult
tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info,
PatternRewriter &rewriter) {
if (!info.resultType.isSigned() || info.lhsType != info.resultType ||
info.rhsType != info.resultType) {
return failure();
}
// Choose a byte aligned intermediate width big enough to perform the
// calculation without overflow.
// TODO: This should probably be made just big enough to avoid overflow and
// leave the downstream tooling to decide how to align that to machine
// word sizes.
unsigned intermediateWidth =
info.resultType.getStorageTypeIntegralWidth() <= 8 ? 16 : 32;
IntegerType intermediateElementType =
IntegerType::get(intermediateWidth, rewriter.getContext());
Type intermediateType =
castElementType(info.resultStorageType, intermediateElementType);
// Cast operands to storage type.
Value lhsValue = rewriter
.create<StorageCastOp>(info.op->getLoc(),
info.lhsStorageType, info.lhs)
.getResult();
Value rhsValue = rewriter
.create<StorageCastOp>(info.op->getLoc(),
info.rhsStorageType, info.rhs)
.getResult();
// Cast to the intermediate sized type.
lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
lhsValue);
rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
rhsValue);
// Add.
Value resultValue =
rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, rhsValue);
// Zero point offset adjustment.
// result = (lhs - zp) + (rhs - zp) + zp = lhs + rhs - zp
// zpOffset = -zp
int64_t zpOffset = -info.resultType.getZeroPoint();
if (zpOffset != 0) {
Value zpOffsetConst = rewriter.create<ConstantOp>(
info.op->getLoc(),
broadcastScalarConstIntValue(intermediateType, zpOffset));
resultValue =
rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
}
// Clamp.
auto clampMinMax = info.getClampMinMax(intermediateElementType);
resultValue = rewriter.create<ClampISOp>(
info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);
// Convert back to original type.
resultValue = rewriter.create<ConvertISOp>(
info.op->getLoc(), info.resultStorageType, resultValue);
// Cast back for new result.
rewriter.replaceOpWithNewOp<StorageCastOp>(
info.op, info.getQuantizedResultType(), resultValue);
return success();
}
//===----------------------------------------------------------------------===//
// Elementwise mul
//===----------------------------------------------------------------------===//
static LogicalResult
tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info,
PatternRewriter &rewriter) {
if (!info.resultType.isSigned()) {
return failure();
}
double outputMultiplierReal = info.lhsType.getScale() *
info.rhsType.getScale() /
info.resultType.getScale();
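// QuantizedMultiplierSmallerThanOneExp requires a multiplier strictly below
// 1.0, so a multiplier of exactly 1.0 must also be rejected here.
if (outputMultiplierReal >= 1.0) {
info.op->emitWarning(
"unimplemented: cannot multiply with multiplier >= 1.0");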
return failure();
}
// TODO: Choose an appropriate intermediate width for muls > 8 bits to
// avoid overflow.
unsigned intermediateWidth = 32;
IntegerType intermediateElementType =
IntegerType::get(intermediateWidth, rewriter.getContext());
Type intermediateType =
castElementType(info.resultStorageType, intermediateElementType);
// Cast operands to storage type.
Value lhsValue = rewriter
.create<StorageCastOp>(info.op->getLoc(),
info.lhsStorageType, info.lhs)
.getResult();
Value rhsValue = rewriter
.create<StorageCastOp>(info.op->getLoc(),
info.rhsStorageType, info.rhs)
.getResult();
// Cast to the intermediate sized type.
lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
lhsValue);
rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
rhsValue);
// Apply argument zeroPoints.
if (info.lhsType.getZeroPoint() != 0) {
Value zpOffsetConst = rewriter.create<ConstantOp>(
info.op->getLoc(), broadcastScalarConstIntValue(
intermediateType, -info.lhsType.getZeroPoint()));
lhsValue =
rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, zpOffsetConst);
}
if (info.rhsType.getZeroPoint() != 0) {
Value zpOffsetConst = rewriter.create<ConstantOp>(
info.op->getLoc(), broadcastScalarConstIntValue(
intermediateType, -info.rhsType.getZeroPoint()));
rhsValue =
rewriter.create<AddIOp>(info.op->getLoc(), rhsValue, zpOffsetConst);
}
// Mul.
Value resultValue =
rewriter.create<MulIOp>(info.op->getLoc(), lhsValue, rhsValue);
// Scale output.
QuantizedMultiplierSmallerThanOneExp outputMultiplier(outputMultiplierReal);
resultValue = rewriter.create<VecScalarSaturatingRoundingDoublingHighMulISOp>(
info.op->getLoc(), resultValue,
IntegerAttr::get(intermediateElementType, outputMultiplier.multiplier));
resultValue = rewriter.create<RoundingDivideByPotISOp>(
info.op->getLoc(), resultValue,
IntegerAttr::get(intermediateElementType, -outputMultiplier.exponent));
// Zero point offset adjustment.
if (info.resultType.getZeroPoint() != 0) {
Value zpOffsetConst = rewriter.create<ConstantOp>(
info.op->getLoc(),
broadcastScalarConstIntValue(intermediateType,
info.resultType.getZeroPoint()));
resultValue =
rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
}
// Clamp.
auto clampMinMax = info.getClampMinMax(intermediateElementType);
resultValue = rewriter.create<ClampISOp>(
info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);
// Convert back to original type.
resultValue = rewriter.create<ConvertISOp>(
info.op->getLoc(), info.resultStorageType, resultValue);
// Cast back for new result.
rewriter.replaceOpWithNewOp<StorageCastOp>(
info.op, info.getQuantizedResultType(), resultValue);
return success();
}
namespace {
struct UniformRealAddEwPattern : public OpRewritePattern<RealAddEwOp> {
using OpRewritePattern<RealAddEwOp>::OpRewritePattern;
LogicalResult matchAndRewrite(RealAddEwOp op,
PatternRewriter &rewriter) const override {
const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
op.clamp_max());
if (!info.isValid()) {
return failure();
}
// Try all of the permutations we support.
if (succeeded(tryRewriteAffineAddEwIsomorphicSigned(info, rewriter))) {
return success();
}
return failure();
}
};
struct UniformRealMulEwPattern : public OpRewritePattern<RealMulEwOp> {
using OpRewritePattern<RealMulEwOp>::OpRewritePattern;
LogicalResult matchAndRewrite(RealMulEwOp op,
PatternRewriter &rewriter) const override {
const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
op.clamp_max());
if (!info.isValid()) {
return failure();
}
// Try all of the permutations we support.
if (succeeded(tryRewriteAffineMulEwSigned(info, rewriter))) {
return success();
}
return failure();
}
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// LowerUniformRealMath pass
//===----------------------------------------------------------------------===//
void LowerUniformRealMathPass::runOnFunction() {
auto fn = getFunction();
OwningRewritePatternList patterns;
auto *context = &getContext();
patterns.insert<UniformRealAddEwPattern, UniformRealMulEwPattern>(context);
applyPatternsGreedily(fn, patterns);
}
std::unique_ptr<OpPassBase<FuncOp>>
mlir::fxpmath::createLowerUniformRealMathPass() {
return std::make_unique<LowerUniformRealMathPass>();
}
//===----------------------------------------------------------------------===//
// LowerUniformCasts pass
//===----------------------------------------------------------------------===//
void LowerUniformCastsPass::runOnFunction() {
auto fn = getFunction();
OwningRewritePatternList patterns;
auto *context = &getContext();
patterns.insert<UniformDequantizePattern>(context);
applyPatternsGreedily(fn, patterns);
}
std::unique_ptr<OpPassBase<FuncOp>>
mlir::fxpmath::createLowerUniformCastsPass() {
return std::make_unique<LowerUniformCastsPass>();
}
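
The dequantize and isomorphic add lowerings above can be modelled on plain scalars. The following is only an illustrative sketch, not part of the removed sources: `UniformParams`, `dequantize` and `addIsomorphic` are hypothetical names, and the model assumes a signed per-layer type whose operands and result share one scale and zero point, as `tryRewriteAffineAddEwIsomorphicSigned` requires.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical scalar stand-in for a signed per-layer uniform quantized type.
struct UniformParams {
  double scale;
  std::int32_t zeroPoint;
  std::int32_t storageMin;
  std::int32_t storageMax;
};

// real = (stored - zeroPoint) * scale, following the emitted op order:
// widen, add the negated zero point, convert to float, multiply by scale.
inline double dequantize(std::int32_t stored, const UniformParams &p) {
  const std::int32_t widened = stored - p.zeroPoint;
  return static_cast<double>(widened) * p.scale;
}

// Isomorphic add: (lhs - zp) + (rhs - zp) + zp == lhs + rhs - zp, followed
// by a clamp to the storage range of the result type.
inline std::int32_t addIsomorphic(std::int32_t lhs, std::int32_t rhs,
                                  const UniformParams &p) {
  assert(p.storageMin <= p.storageMax && "invalid storage range");
  const std::int32_t sum = lhs + rhs - p.zeroPoint;
  return std::min(std::max(sum, p.storageMin), p.storageMax);
}
```

With `UniformParams{0.5, -5, -128, 127}`, adding the stored values 20 and 30 gives a stored value whose dequantized result equals the sum of the two dequantized inputs, while 100 + 100 saturates at the storage maximum.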

@ -1,227 +0,0 @@
//===- UniformKernelUtils.h - Utilities for lowering uniform math - C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
#define MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/Dialect/Quant/UniformSupport.h"
#include "mlir/IR/Operation.h"
#include <algorithm>
#include <cmath>
#include <limits>
#include <utility>
namespace mlir {
namespace fxpmath {
namespace detail {
inline quant::UniformQuantizedType getUniformElementType(Type t) {
return quant::QuantizedType::getQuantizedElementType(t)
.dyn_cast_or_null<quant::UniformQuantizedType>();
}
inline bool hasStorageBitWidth(quant::QuantizedType t,
ArrayRef<unsigned> checkWidths) {
unsigned w = t.getStorageType().getIntOrFloatBitWidth();
for (unsigned checkWidth : checkWidths) {
if (w == checkWidth)
return true;
}
return false;
}
/// Computes log2(x) rounded to the nearest integer and stores it in
/// 'log2Result'. Returns whether log2(x) is an integer to within a small
/// tolerance, i.e. whether 'x' can be considered an exact power of two.
template <typename F> bool integralLog2(F x, int &log2Result) {
const F xLog2 = std::log(x) * (1.0 / std::log(2.0));
const F xLog2Rounded = std::round(xLog2);
const F xLog2Frac = xLog2 - xLog2Rounded;
log2Result = static_cast<int>(xLog2Rounded);
// Allow small comparison slop below the level that would make a difference
// for 2^16 levels.
return std::abs(xLog2Frac) < 1e-6;
}
/// Helper class for operating on binary operations where all operands
/// and the result are a UniformQuantizedType.
struct UniformBinaryOpInfo {
UniformBinaryOpInfo(Operation *op, Value lhs, Value rhs,
Optional<APFloat> clampMin, Optional<APFloat> clampMax)
: op(op), lhs(lhs), rhs(rhs), clampMin(clampMin), clampMax(clampMax),
lhsType(getUniformElementType(lhs.getType())),
rhsType(getUniformElementType(rhs.getType())),
resultType(getUniformElementType(*op->result_type_begin())),
lhsStorageType(quant::QuantizedType::castToStorageType(lhs.getType())),
rhsStorageType(quant::QuantizedType::castToStorageType(rhs.getType())),
resultStorageType(
quant::QuantizedType::castToStorageType(*op->result_type_begin())) {
}
/// Returns whether this info is valid (all types defined, etc).
bool isValid() const {
return lhsType && rhsType && resultType && lhsStorageType &&
rhsStorageType && resultStorageType;
}
/// Gets the final quantized result type of the result.
Type getQuantizedResultType() const { return *op->result_type_begin(); }
/// Returns whether the storage type of all operands and the result is
/// identical.
bool isSameStorageType() const {
return lhsType.getStorageType() == rhsType.getStorageType() &&
lhsType.getStorageType() == resultType.getStorageType();
}
/// Returns whether all operands and result are considered fixedpoint power
/// of two, setting the lhs, rhs, and result log2 scale references.
bool isFixedPointPOT(int &lhsLog2Scale, int &rhsLog2Scale,
int &resultLog2Scale) const {
if (!lhsType.isFixedPoint() || !rhsType.isFixedPoint() ||
!resultType.isFixedPoint()) {
return false;
}
if (!integralLog2(lhsType.getScale(), lhsLog2Scale) ||
!integralLog2(rhsType.getScale(), rhsLog2Scale) ||
!integralLog2(resultType.getScale(), resultLog2Scale)) {
return false;
}
return true;
}
/// Gets the result integer clamp range given the result quantized type
/// and any explicit clamp provided as attributes.
std::pair<IntegerAttr, IntegerAttr> getClampMinMax(IntegerType ty) const {
int64_t typeMin = resultType.getStorageTypeMin();
int64_t typeMax = resultType.getStorageTypeMax();
if (clampMin || clampMax) {
quant::UniformQuantizedValueConverter conv(resultType);
if (clampMin) {
typeMin = std::max(typeMin, conv.quantizeFloatToInt64(*clampMin));
}
if (clampMax) {
typeMax = std::min(typeMax, conv.quantizeFloatToInt64(*clampMax));
}
}
// The quantized, integral ops expect clamps as integer attributes of the
// caller-provided type 'ty'.
return {
IntegerAttr::get(ty, typeMin),
IntegerAttr::get(ty, typeMax),
};
}
Operation *op;
Value lhs;
Value rhs;
Optional<APFloat> clampMin;
Optional<APFloat> clampMax;
// Element UniformQuantizedType for operands/result.
quant::UniformQuantizedType lhsType;
quant::UniformQuantizedType rhsType;
quant::UniformQuantizedType resultType;
// Full storage-based types.
Type lhsStorageType;
Type rhsStorageType;
Type resultStorageType;
};
/// Derives a quantized multiplier and shift from a real valued multiplier
/// in the open interval (0, 1).
struct QuantizedMultiplierSmallerThanOneExp {
QuantizedMultiplierSmallerThanOneExp(double realMultiplier) {
assert(realMultiplier < 1.0);
assert(realMultiplier > 0.0);
const double q = std::frexp(realMultiplier, &exponent);
auto qFixed = static_cast<int64_t>(std::round(q * (1ll << 31)));
assert(qFixed <= (1ll << 31));
if (qFixed == (1ll << 31)) {
qFixed /= 2;
++exponent;
}
assert(qFixed <= std::numeric_limits<int32_t>::max());
multiplier = static_cast<int32_t>(qFixed);
}
int32_t multiplier;
int exponent;
};
/// Casts an integer or floating point based shaped type to a new element type.
inline Type castElementType(Type t, Type newElementType) {
if (auto st = t.dyn_cast<ShapedType>()) {
switch (st.getKind()) {
case StandardTypes::Kind::Vector:
return VectorType::get(st.getShape(), newElementType);
case StandardTypes::Kind::RankedTensor:
return RankedTensorType::get(st.getShape(), newElementType);
case StandardTypes::Kind::UnrankedTensor:
return UnrankedTensorType::get(newElementType);
case StandardTypes::Kind::MemRef:
return MemRefType::Builder(st.cast<MemRefType>())
.setElementType(newElementType);
}
}
assert(t.isSignlessIntOrFloat());
return newElementType;
}
/// Creates an IntegerAttr with a type that matches the shape of 't' (which can
/// be a scalar primitive or a shaped type).
inline Attribute broadcastScalarConstIntValue(Type t, int64_t value) {
if (auto st = t.dyn_cast<ShapedType>()) {
assert(st.getElementType().isSignlessInteger());
return DenseElementsAttr::get(st,
IntegerAttr::get(st.getElementType(), value));
}
auto integerType = t.cast<IntegerType>();
assert(t.isSignlessInteger() && "integer broadcast must be of integer type");
return IntegerAttr::get(integerType, value);
}
/// Given an APFloat, converts it to the float semantics that matches the
/// given FloatType, silently ignoring inexact conversions.
inline APFloat convertFloatToType(FloatType ft, APFloat value) {
bool losesInfo;
auto status = value.convert(ft.getFloatSemantics(),
APFloat::rmNearestTiesToEven, &losesInfo);
(void)status; // unused in opt mode
assert((status & (APFloat::opDivByZero | APFloat::opInvalidOp)) == 0 &&
"could not convert to float const");
return value;
}
/// Creates a FloatAttr with a type that matches the shape of 't' (which can be
/// a scalar primitive or a shaped type).
inline Attribute broadcastScalarConstFloatValue(Type t, APFloat value) {
if (auto st = t.dyn_cast<ShapedType>()) {
FloatType floatElementType = st.getElementType().dyn_cast<FloatType>();
assert(floatElementType &&
"float broadcast element type must be float like");
APFloat apValue = convertFloatToType(floatElementType, value);
return DenseElementsAttr::get(st,
FloatAttr::get(st.getElementType(), apValue));
} else {
auto floatType = t.dyn_cast<FloatType>();
assert(floatType && "float broadcast must be of float type");
APFloat apValue = convertFloatToType(floatType, value);
return FloatAttr::get(floatType, apValue);
}
}
} // namespace detail
} // namespace fxpmath
} // namespace mlir
#endif // MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
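
The multiplier decomposition performed by `QuantizedMultiplierSmallerThanOneExp` can be tried outside of MLIR. The sketch below is a hypothetical standalone mirror of that struct (`QuantizedMultiplier`, `decomposeMultiplier` and `reconstruct` are names introduced here for illustration only). It assumes the intended reading is `real ~= multiplier * 2^(exponent - 31)`, which is consistent with how the lowering feeds `multiplier` to a doubling high multiply and `-exponent` to a rounding divide by a power of two.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical standalone mirror of QuantizedMultiplierSmallerThanOneExp.
struct QuantizedMultiplier {
  std::int32_t multiplier; // Q31 mantissa, in [2^30, 2^31)
  int exponent;            // power-of-two exponent reported by frexp
};

inline QuantizedMultiplier decomposeMultiplier(double realMultiplier) {
  assert(realMultiplier > 0.0 && realMultiplier < 1.0);
  int exponent = 0;
  // frexp yields q in [0.5, 1) with realMultiplier == q * 2^exponent.
  const double q = std::frexp(realMultiplier, &exponent);
  auto qFixed = static_cast<std::int64_t>(std::round(q * (1ll << 31)));
  // Rounding can push the mantissa up to exactly 2^31; renormalize.
  if (qFixed == (1ll << 31)) {
    qFixed /= 2;
    ++exponent;
  }
  return {static_cast<std::int32_t>(qFixed), exponent};
}

// Recovers the approximated real multiplier: multiplier * 2^(exponent - 31).
inline double reconstruct(const QuantizedMultiplier &m) {
  return std::ldexp(static_cast<double>(m.multiplier), m.exponent - 31);
}
```

Powers of two such as 0.25 round-trip exactly; other inputs such as 0.3 are recovered to within the resolution of the 31-bit mantissa.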

@ -1,63 +0,0 @@
# Support.
add_mlir_library(MLIRQuantizerSupport
Support/Configuration.cpp
Support/ConstraintAnalysisGraph.cpp
Support/Metadata.cpp
Support/Statistics.cpp
Support/TypeUtils.cpp
Support/UniformConstraints.cpp
Support/UniformSolvers.cpp
ADDITIONAL_HEADER_DIRS
)
target_link_libraries(MLIRQuantizerSupport
PUBLIC
MLIRIR
MLIRQuant
MLIRSupport
MLIRStandardOps
LLVMSupport
)
# Configurations.
add_mlir_library(MLIRQuantizerFxpMathConfig
Configurations/FxpMathConfig.cpp
ADDITIONAL_HEADER_DIRS
DEPENDS
MLIRFxpMathOpsIncGen
)
target_link_libraries(MLIRQuantizerFxpMathConfig
PUBLIC
MLIRIR
MLIRFxpMathOps
MLIRQuant
MLIRQuantizerSupport
MLIRStandardOps
MLIRSupport
LLVMSupport
)
# Transforms.
add_mlir_library(MLIRQuantizerTransforms
Transforms/AddDefaultStatsTestPass.cpp
Transforms/InferQuantizedTypesPass.cpp
Transforms/RemoveInstrumentationPass.cpp
ADDITIONAL_HEADER_DIRS
DEPENDS
MLIRQuantizerPassIncGen
)
target_link_libraries(MLIRQuantizerTransforms
PUBLIC
MLIRIR
MLIRQuantizerFxpMathConfig
MLIRQuantizerSupport
MLIRQuant
MLIRPass
LLVMSupport
)

@ -1,278 +0,0 @@
//===- FxpMathConfig.cpp - Reference fixed point config -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a TargetConfiguration for a reference fixed-point math
// quantization scheme based on the FxpMathOps (plus a small category of
// extension ops that can be added from other dialects).
//
//===----------------------------------------------------------------------===//
#include "mlir/Quantizer/Configurations/FxpMathConfig.h"
#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
#include "mlir/Quantizer/Support/Metadata.h"
#include "mlir/Quantizer/Support/Statistics.h"
#include "mlir/Quantizer/Support/UniformConstraints.h"
using namespace mlir;
using namespace mlir::quantizer;
using namespace mlir::fxpmath;
using namespace mlir::quant;
using namespace std::placeholders;
namespace {
struct FxpMathTargetConfigImpl : public FxpMathTargetConfig {
FxpMathTargetConfigImpl(SolverContext &context)
: FxpMathTargetConfig(context) {
Builder b(&context.getMlirContext());
IntegerType i8Type = b.getIntegerType(8);
IntegerType i16Type = b.getIntegerType(16);
IntegerType i32Type = b.getIntegerType(32);
q8 = addCandidateType(
AnyQuantizedType::get(QuantizationFlags::Signed, i8Type, nullptr,
std::numeric_limits<int8_t>::min(),
std::numeric_limits<int8_t>::max()),
CandidateQuantizedType::Scheme::UniformPerLayer);
q16 = addCandidateType(
AnyQuantizedType::get(QuantizationFlags::Signed, i16Type, nullptr,
std::numeric_limits<int16_t>::min(),
std::numeric_limits<int16_t>::max()),
CandidateQuantizedType::Scheme::UniformPerLayer);
q32ExplicitFixedPoint = addCandidateType(
AnyQuantizedType::get(QuantizationFlags::Signed, i32Type, nullptr,
std::numeric_limits<int32_t>::min(),
std::numeric_limits<int32_t>::max()),
CandidateQuantizedType::Scheme::UniformExplicitFixedPointScale);
// Op handlers.
addOpHandler<ConstantOp>(
std::bind(&FxpMathTargetConfigImpl::handleConstant, this, _1, _2));
addOpHandler<mlir::ReturnOp>(
std::bind(&FxpMathTargetConfigImpl::handleTerminal, this, _1, _2));
addOpHandler<quant::StatisticsOp>(
std::bind(&FxpMathTargetConfigImpl::handleStats, this, _1, _2));
// FxpMathOps.
addOpHandler<RealAddEwOp>(
std::bind(&FxpMathTargetConfigImpl::handleAdd, this, _1, _2));
addOpHandler<RealMulEwOp>(
std::bind(&FxpMathTargetConfigImpl::handleMul, this, _1, _2));
addOpHandler<RealMatMulOp>(
std::bind(&FxpMathTargetConfigImpl::handleMatMul, this, _1, _2));
addOpHandler<RealMatMulBiasOp>(
std::bind(&FxpMathTargetConfigImpl::handleMatMulBias, this, _1, _2));
// Require stats ops.
addRequireStatsOp<RealAddEwOp>();
addRequireStatsOp<RealSubEwOp>();
addRequireStatsOp<RealDivEwOp>();
addRequireStatsOp<RealMulEwOp>();
addRequireStatsOp<RealMatMulOp>();
addRequireStatsOp<RealMatMulBiasOp>();
}
bool isHandledType(Type t) const final {
if (t.isa<FloatType>())
return true;
return (t.isa<VectorType>() || t.isa<TensorType>()) &&
t.cast<ShapedType>().getElementType().isa<FloatType>();
}
void finalizeAnchors(CAGSlice &cag) const override {
cag.enumerateImpliedConnections(
[&](CAGAnchorNode *from, CAGAnchorNode *to) {
UniformConstraintsBuilder(cag).coupleAnchors(from, to);
});
}
void addValueIdentityOpByName(StringRef opName) override {
addOpHandlerByName(
opName,
std::bind(&FxpMathTargetConfigImpl::handleValueIdentity, this, _1, _2));
}
void handleValueIdentity(Operation *op, CAGSlice &cag) const {
assert(op->getNumResults() == 1);
if (!isHandledType(op->getResult(0).getType()))
return;
auto resultNode = cag.getResultAnchor(op, 0);
resultNode->setTypeTransformRule(
CAGAnchorNode::TypeTransformRule::DirectStorage);
for (unsigned opIdx = 0, e = op->getNumOperands(); opIdx < e; ++opIdx) {
if (!isHandledType(op->getOperand(opIdx).getType()))
continue;
auto operandNode = cag.getOperandAnchor(op, opIdx);
operandNode->setTypeTransformRule(
CAGAnchorNode::TypeTransformRule::DirectStorage);
UniformConstraintsBuilder(cag).coupleAnchors(operandNode, resultNode);
}
}
void handleConstant(Operation *op, CAGSlice &cag) const {
if (!isHandledType(op->getResult(0).getType()))
return;
auto resultNode = cag.getResultAnchor(op, 0);
resultNode->setTypeTransformRule(
CAGAnchorNode::TypeTransformRule::ExpressedOnly);
Attribute valueAttr;
if (!matchPattern(op, m_Constant(&valueAttr))) {
return;
}
AttributeTensorStatistics stats(valueAttr);
TensorAxisStatistics layerStats;
if (!stats.get(layerStats)) {
op->emitOpError("could not compute statistics");
return;
}
UniformConstraintsBuilder(cag).applyStats(resultNode, layerStats);
}
void handleTerminal(Operation *op, CAGSlice &cag) const {
if (!isHandledType(op->getOperand(0).getType()))
return;
auto operandNode = cag.getOperandAnchor(op, 0);
operandNode->setTypeTransformRule(
CAGAnchorNode::TypeTransformRule::ExpressedOnly);
}
void handleStats(Operation *op, CAGSlice &cag) const {
if (!isHandledType(op->getResult(0).getType()))
return;
auto argNode = cag.getOperandAnchor(op, 0);
auto resultNode = cag.getResultAnchor(op, 0);
UniformConstraintsBuilder(cag).coupleAnchors(argNode, resultNode);
TensorAxisStatistics layerStats;
auto statsOp = cast<quant::StatisticsOp>(op);
auto layerStatsAttr = statsOp.layerStats();
layerStats.minValue =
layerStatsAttr.getValue<FloatAttr>(0).getValueAsDouble();
layerStats.maxValue =
layerStatsAttr.getValue<FloatAttr>(1).getValueAsDouble();
UniformConstraintsBuilder(cag).applyStats(resultNode, layerStats);
}
void handleAdd(Operation *op, CAGSlice &cag) const {
if (!isHandledType(op->getResult(0).getType()))
return;
auto lhs = cag.getOperandAnchor(op, 0);
auto rhs = cag.getOperandAnchor(op, 1);
auto resultNode = cag.getResultAnchor(op, 0);
// Add supports 8/16 bit math.
llvm::SmallBitVector disableMask =
getCandidateTypeDisabledExceptMask({q8, q16});
lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
// NOTE: We couple the add such that the scale/zeroPoint match between
// both args and the result. This is overly constrained in that it is
// possible to write efficient add kernels with a bit more freedom (i.e.
// zeroPoints can vary, scales can differ by a power of two, etc).
// However, fully coupled yields the simplest solutions on the fast path.
// Further efficiency can be had by constraining the zeroPoint to 0, but
// there isn't a constraint for this yet (and there are tradeoffs).
UniformConstraintsBuilder(cag).coupleAnchors(lhs, resultNode);
UniformConstraintsBuilder(cag).coupleAnchors(rhs, resultNode);
addRealMathOptionalConstraints(op, resultNode, cag);
}
void handleMul(Operation *op, CAGSlice &cag) const {
if (!isHandledType(op->getResult(0).getType()))
return;
auto lhs = cag.getOperandAnchor(op, 0);
auto rhs = cag.getOperandAnchor(op, 1);
auto resultNode = cag.getResultAnchor(op, 0);
// Mul supports 8/16 bit math.
llvm::SmallBitVector disableMask =
getCandidateTypeDisabledExceptMask({q8, q16});
lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
addRealMathOptionalConstraints(op, resultNode, cag);
}
void handleMatMul(Operation *op, CAGSlice &cag) const {
if (!isHandledType(op->getResult(0).getType()))
return;
auto lhs = cag.getOperandAnchor(op, 0);
auto rhs = cag.getOperandAnchor(op, 1);
auto resultNode = cag.getResultAnchor(op, 0);
// MatMul supports 8/16 bit math.
llvm::SmallBitVector disableMask =
getCandidateTypeDisabledExceptMask({q8, q16});
lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
addRealMathOptionalConstraints(op, resultNode, cag);
}
void handleMatMulBias(Operation *op, CAGSlice &cag) const {
if (!isHandledType(op->getResult(0).getType()))
return;
auto lhs = cag.getOperandAnchor(op, 0);
auto rhs = cag.getOperandAnchor(op, 1);
auto bias = cag.getOperandAnchor(op, 2);
bias->getUniformMetadata().disabledCandidateTypes =
getCandidateTypeDisabledExceptMask({q32ExplicitFixedPoint});
auto resultNode = cag.getResultAnchor(op, 0);
UniformConstraintsBuilder(cag).propagateExplicitScale(resultNode, bias);
// MatMul supports 8/16 bit math for its lhs/rhs operands and result.
llvm::SmallBitVector disableMask =
getCandidateTypeDisabledExceptMask({q8, q16});
lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
addRealMathOptionalConstraints(op, resultNode, cag);
}
void addRealMathOptionalConstraints(Operation *op, CAGAnchorNode *anchor,
CAGSlice &cag) const {
// TODO: It would be nice if these all extended some base trait instead
// of requiring name lookup.
auto clampMinAttr = op->getAttrOfType<FloatAttr>("clamp_min");
auto clampMaxAttr = op->getAttrOfType<FloatAttr>("clamp_max");
if (clampMinAttr || clampMaxAttr) {
auto nan = APFloat::getQNaN(APFloat::IEEEdouble());
auto clampMin = clampMinAttr ? clampMinAttr.getValue() : nan;
auto clampMax = clampMaxAttr ? clampMaxAttr.getValue() : nan;
UniformConstraintsBuilder(cag).clamp(anchor, clampMin, clampMax);
}
}
unsigned q8;
unsigned q16;
unsigned q32ExplicitFixedPoint;
};
} // anonymous namespace
std::unique_ptr<FxpMathTargetConfig>
FxpMathTargetConfig::create(SolverContext &context) {
return std::make_unique<FxpMathTargetConfigImpl>(context);
}

@ -1,39 +0,0 @@
//===- Configuration.cpp - Configuration object base classes --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Quantizer/Support/Configuration.h"
#include <limits>
#include "mlir/IR/Builders.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/MLIRContext.h"
using namespace mlir;
using namespace mlir::quantizer;
TargetConfiguration::TargetConfiguration(SolverContext &context) {}
void TargetConfiguration::addOpHandlerByName(StringRef name, OpHandlerFn fn) {
opHandlers[name] = fn;
}
void TargetConfiguration::addRequireStatsOpByName(StringRef opName) {
requireStatsOpNames.insert(opName);
}
bool TargetConfiguration::isRequireStatsOp(Operation *op) const {
return requireStatsOpNames.find(op->getName().getStringRef()) !=
requireStatsOpNames.end();
}
void TargetConfiguration::handleOp(Operation *op, CAGSlice &cag) const {
auto found_it = opHandlers.find(op->getName().getStringRef());
if (found_it != opHandlers.end())
found_it->second(op, cag);
}
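
The name-keyed dispatch in `TargetConfiguration::handleOp` can be sketched outside of MLIR as follows. This is a minimal standalone illustration, not the removed class: `FakeOp` and `HandlerRegistry` are hypothetical stand-ins, and `std::map` replaces whatever container the original header used.

```cpp
#include <functional>
#include <map>
#include <string>
#include <utility>

// Hypothetical stand-in for mlir::Operation; only the name matters here.
struct FakeOp {
  std::string name;
  int handledCount = 0;
};

// Hypothetical registry mirroring addOpHandlerByName/handleOp.
class HandlerRegistry {
public:
  using HandlerFn = std::function<void(FakeOp &)>;

  // Registers (or overwrites) the handler for an operation name.
  void addHandlerByName(const std::string &name, HandlerFn fn) {
    handlers[name] = std::move(fn);
  }

  // Runs the handler registered for op.name. Returns false when no handler
  // exists, in which case the op is left untouched.
  bool handle(FakeOp &op) const {
    auto it = handlers.find(op.name);
    if (it == handlers.end())
      return false;
    it->second(op);
    return true;
  }

private:
  std::map<std::string, HandlerFn> handlers;
};
```

Unlike this sketch, the removed `handleOp` returns `void` and silently ignores ops without a handler; the boolean result is added here only to make the lookup outcome observable.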

@ -1,172 +0,0 @@
//===- ConstraintAnalysisGraph.cpp - Graph types for constraints ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Quantizer/Support/Configuration.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::quantizer;
void CAGNode::replaceIncoming(CAGNode *otherNode) {
if (this == otherNode)
return;
for (CAGNode *parentNode : incoming) {
for (CAGNode *&it : parentNode->outgoing) {
if (it == this) {
it = otherNode;
otherNode->incoming.push_back(parentNode);
}
}
}
incoming.clear();
}
void CAGNode::addOutgoing(CAGNode *toNode) {
if (!llvm::is_contained(outgoing, toNode)) {
outgoing.push_back(toNode);
toNode->incoming.push_back(this);
}
}
CAGOperandAnchor::CAGOperandAnchor(Operation *op, unsigned operandIdx)
: CAGAnchorNode(Kind::OperandAnchor, op->getOperand(operandIdx).getType()),
op(op), operandIdx(operandIdx) {}
CAGResultAnchor::CAGResultAnchor(Operation *op, unsigned resultIdx)
: CAGAnchorNode(Kind::ResultAnchor, op->getResult(resultIdx).getType()),
resultValue(op->getResult(resultIdx)) {}
CAGSlice::CAGSlice(SolverContext &context) : context(context) {}
CAGSlice::~CAGSlice() { llvm::DeleteContainerPointers(allNodes); }
CAGOperandAnchor *CAGSlice::getOperandAnchor(Operation *op,
unsigned operandIdx) {
assert(operandIdx < op->getNumOperands() && "illegal operand index");
// Dedup.
auto key = std::make_pair(op, operandIdx);
auto foundIt = operandAnchors.find(key);
if (foundIt != operandAnchors.end()) {
return foundIt->second;
}
// Create.
auto anchor = std::make_unique<CAGOperandAnchor>(op, operandIdx);
auto *unowned = anchor.release();
unowned->nodeId = allNodes.size();
allNodes.push_back(unowned);
operandAnchors.insert(std::make_pair(key, unowned));
return unowned;
}
CAGResultAnchor *CAGSlice::getResultAnchor(Operation *op, unsigned resultIdx) {
assert(resultIdx < op->getNumResults() && "illegal result index");
// Dedup.
auto key = std::make_pair(op, resultIdx);
auto foundIt = resultAnchors.find(key);
if (foundIt != resultAnchors.end()) {
return foundIt->second;
}
// Create.
auto anchor = std::make_unique<CAGResultAnchor>(op, resultIdx);
auto *unowned = anchor.release();
unowned->nodeId = allNodes.size();
allNodes.push_back(unowned);
resultAnchors.insert(std::make_pair(key, unowned));
return unowned;
}
void CAGSlice::enumerateImpliedConnections(
std::function<void(CAGAnchorNode *from, CAGAnchorNode *to)> callback) {
// Discover peer identity pairs (i.e. implied edges from Result->Operand and
// Arg->Call). Use an intermediate vector so that the callback can modify.
std::vector<std::pair<CAGAnchorNode *, CAGAnchorNode *>> impliedPairs;
for (auto &resultAnchorPair : resultAnchors) {
CAGResultAnchor *resultAnchor = resultAnchorPair.second;
Value resultValue = resultAnchor->getValue();
for (auto &use : resultValue.getUses()) {
Operation *operandOp = use.getOwner();
unsigned operandIdx = use.getOperandNumber();
auto foundIt = operandAnchors.find(std::make_pair(operandOp, operandIdx));
if (foundIt != operandAnchors.end()) {
impliedPairs.push_back(std::make_pair(resultAnchor, foundIt->second));
}
}
}
// Callback for each pair.
for (auto &impliedPair : impliedPairs) {
callback(impliedPair.first, impliedPair.second);
}
}
unsigned CAGSlice::propagate(const TargetConfiguration &config) {
std::vector<CAGNode *> dirtyNodes;
dirtyNodes.reserve(allNodes.size());
// Note that because iteration happens in nodeId order, there is no need
// to sort in order to make deterministic. If the selection method changes,
// a sort should be explicitly done.
for (CAGNode *child : *this) {
if (child->isDirty()) {
dirtyNodes.push_back(child);
}
}
if (dirtyNodes.empty()) {
return 0;
}
for (auto dirtyNode : dirtyNodes) {
dirtyNode->clearDirty();
dirtyNode->propagate(context, config);
}
return dirtyNodes.size();
}
void CAGAnchorNode::propagate(SolverContext &solverContext,
const TargetConfiguration &config) {
for (CAGNode *child : *this) {
child->markDirty();
}
}
Type CAGAnchorNode::getTransformedType() {
if (!getUniformMetadata().selectedType) {
return nullptr;
}
return getUniformMetadata().selectedType.castFromExpressedType(
getOriginalType());
}
void CAGNode::printLabel(raw_ostream &os) const {
os << "Node<" << static_cast<const void *>(this) << ">";
}
void CAGAnchorNode::printLabel(raw_ostream &os) const {
getUniformMetadata().printSummary(os);
}
void CAGOperandAnchor::printLabel(raw_ostream &os) const {
os << "Operand<";
op->getName().print(os);
os << "," << operandIdx;
os << ">";
CAGAnchorNode::printLabel(os);
}
void CAGResultAnchor::printLabel(raw_ostream &os) const {
os << "Result<";
getOp()->getName().print(os);
os << ">";
CAGAnchorNode::printLabel(os);
}
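
The dirty-node sweep in `CAGSlice::propagate` is meant to be called repeatedly until it reports zero processed nodes. A minimal standalone sketch of that loop follows, assuming a toy graph in which each node pushes an integer value to its children; `Node`, `Graph`, `propagateOnce`, and `runToConvergence` are hypothetical names, not part of the removed code.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical node: carries a value and a dirty flag.
struct Node {
  int value = 0;
  bool dirty = false;
  std::vector<std::size_t> children; // indices into Graph::nodes
};

struct Graph {
  std::vector<Node> nodes;

  // One sweep: snapshot the dirty nodes in index order, then clear each flag
  // and propagate. Returns the number of nodes processed in this sweep.
  unsigned propagateOnce() {
    std::vector<std::size_t> dirtyIds;
    for (std::size_t i = 0; i < nodes.size(); ++i)
      if (nodes[i].dirty)
        dirtyIds.push_back(i);
    for (std::size_t id : dirtyIds) {
      nodes[id].dirty = false;
      for (std::size_t child : nodes[id].children) {
        // Only mark a child dirty when its state actually changes; this is
        // what lets the loop below terminate.
        if (nodes[child].value < nodes[id].value) {
          nodes[child].value = nodes[id].value;
          nodes[child].dirty = true;
        }
      }
    }
    return static_cast<unsigned>(dirtyIds.size());
  }

  // Repeats sweeps until nothing is dirty. Returns the number of sweeps that
  // processed at least one node.
  unsigned runToConvergence() {
    unsigned sweeps = 0;
    while (propagateOnce() != 0)
      ++sweeps;
    return sweeps;
  }
};
```

Snapshotting the dirty set before propagating mirrors the intermediate `dirtyNodes` vector above: nodes marked dirty during a sweep are handled in the next one.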

@ -1,33 +0,0 @@
//===- Metadata.cpp - Top level types and metadata ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Quantizer/Support/Metadata.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::quantizer;
void CAGUniformMetadata::printSummary(raw_ostream &os) const {
if (requiredRange.hasValue()) {
os << "\n[" << requiredRange.getValue().first << ","
<< requiredRange.getValue().second << "]";
}
if (disabledCandidateTypes.any()) {
os << "\n![";
mlir::interleaveComma(disabledCandidateTypes.set_bits(), os);
os << "]";
}
if (selectedType) {
os << "\n" << selectedType;
}
}

@ -1,201 +0,0 @@
//===- Statistics.cpp - Collects statistics over tensors ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Quantizer/Support/Statistics.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::quantizer;
//===----------------------------------------------------------------------===//
// AttributeTensorStatistics implementation
//===----------------------------------------------------------------------===//
static void collectElementsStatisticsDim(ElementsAttr attr,
unsigned numElements,
ArrayRef<int64_t> shape,
SmallVectorImpl<uint64_t> &indices,
uint64_t dim,
TensorAxisStatistics &statistics) {
// Recursive terminating condition.
if (dim >= shape.size())
return;
if (dim < (shape.size() - 1)) {
// Recurse past dim.
for (uint64_t i = 0, s = shape[dim]; i < s; ++i) {
indices[dim] = i;
collectElementsStatisticsDim(attr, numElements, shape, indices, dim + 1,
statistics);
}
return;
}
// Collection dim.
for (uint64_t i = 0, s = shape[dim]; i < s; ++i) {
indices[dim] = i;
double value = attr.getValue<FloatAttr>(indices).getValueAsDouble();
statistics.minValue = std::min(statistics.minValue, value);
statistics.maxValue = std::max(statistics.maxValue, value);
statistics.mean += value / numElements;
// TODO: Calculate a running variance.
}
}
static void collectElementsStatisticsDimForAxis(
unsigned axis, ElementsAttr attr, unsigned numElements,
ArrayRef<int64_t> shape, SmallVectorImpl<uint64_t> &indices, uint64_t dim,
TensorAxisStatistics &statistics) {
// Recursive terminating condition.
if (dim >= shape.size())
return;
// The axis dim is iterated separately below, so skip past it here.
if (dim == axis) {
collectElementsStatisticsDimForAxis(axis, attr, numElements, shape, indices,
dim + 1, statistics);
return;
}
// Recurse until reaching the last non-axis dim.
if (dim < (shape.size() - 2) ||
(dim == (shape.size() - 2) && axis != (shape.size() - 1))) {
// Recurse past dim.
for (uint64_t i = 0, s = shape[dim]; i < s; ++i) {
indices[dim] = i;
collectElementsStatisticsDimForAxis(axis, attr, numElements, shape,
indices, dim + 1, statistics);
}
return;
}
// Iterate over the axis, collecting along the last non-axis dim.
uint64_t axisSize = shape[axis];
for (uint64_t axisIdx = 0; axisIdx < axisSize; ++axisIdx) {
indices[axis] = axisIdx;
// Collection dim.
for (uint64_t i = 0, s = shape[dim]; i < s; ++i) {
indices[dim] = i;
double value = attr.getValue<FloatAttr>(indices).getValueAsDouble();
statistics.minValuePerAxis[axisIdx] =
std::min(statistics.minValuePerAxis[axisIdx], value);
statistics.maxValuePerAxis[axisIdx] =
std::max(statistics.maxValuePerAxis[axisIdx], value);
statistics.meanPerAxis[axisIdx] += value / numElements;
// TODO: Calculate a running variance.
}
}
}
static bool getElementsStatistics(ElementsAttr attr,
TensorAxisStatistics &statistics) {
ShapedType sType = attr.getType();
if (!sType.hasStaticShape())
return false;
Type elementTy = sType.getElementType();
if (!elementTy.isa<FloatType>())
return false;
SmallVector<uint64_t, 4> indices;
indices.resize(sType.getRank());
ArrayRef<int64_t> shape = sType.getShape();
statistics.minValue = std::numeric_limits<double>::infinity();
statistics.maxValue = -std::numeric_limits<double>::infinity();
statistics.mean = 0;
statistics.variance = 0;
auto numElements = sType.getNumElements();
collectElementsStatisticsDim(attr, numElements, shape, indices, 0,
statistics);
statistics.sampleSize = numElements;
return true;
}
static bool getElementsStatisticsForAxis(unsigned axis, ElementsAttr attr,
TensorAxisStatistics &statistics) {
ShapedType sType = attr.getType();
if (!sType.hasStaticShape() || axis >= sType.getRank())
return false;
Type elementTy = sType.getElementType();
if (!elementTy.isa<FloatType>())
return false;
SmallVector<uint64_t, 4> indices;
indices.resize(sType.getRank());
ArrayRef<int64_t> shape = sType.getShape();
uint64_t axisSize = shape[axis];
statistics.minValuePerAxis.assign(axisSize,
std::numeric_limits<double>::infinity());
statistics.maxValuePerAxis.assign(axisSize,
-std::numeric_limits<double>::infinity());
statistics.meanPerAxis.assign(axisSize, 0);
statistics.variancePerAxis.assign(axisSize, 0);
uint64_t numElements = sType.getNumElements() / shape[axis];
collectElementsStatisticsDimForAxis(axis, attr, numElements, shape, indices,
0, statistics);
statistics.sampleSizePerAxis = numElements;
return true;
}
bool AttributeTensorStatistics::get(TensorAxisStatistics &stats) const {
if (FloatAttr floatAttr = attr.dyn_cast<FloatAttr>()) {
double value = floatAttr.getValueAsDouble();
stats = TensorAxisStatistics(1, value, value, value, 0);
return true;
} else if (auto eltAttr = attr.dyn_cast<ElementsAttr>()) {
return getElementsStatistics(eltAttr, stats);
}
return false;
}
bool AttributeTensorStatistics::supportsPerAxis() const {
if (auto eltAttr = attr.dyn_cast<ElementsAttr>())
return eltAttr.getType().getRank() > 1;
return false;
}
unsigned AttributeTensorStatistics::getAxisCount() const {
if (!supportsPerAxis())
return 0;
return attr.cast<ElementsAttr>().getType().getRank();
}
bool AttributeTensorStatistics::getForAxis(unsigned axis,
TensorAxisStatistics &stats) const {
if (!supportsPerAxis())
return false;
auto eltAttr = attr.cast<ElementsAttr>();
return getElementsStatisticsForAxis(axis, eltAttr, stats);
}
raw_ostream &mlir::quantizer::operator<<(raw_ostream &os,
const TensorAxisStatistics &stats) {
os << "STATS[sampleSizeLayer=" << stats.sampleSize
<< ", minValueLayer=" << stats.minValue
<< ", maxValueLayer=" << stats.maxValue << ", meanLayer=" << stats.mean
<< ", varianceLayer=" << stats.variance
<< ", sampleSizePerAxis=" << stats.sampleSizePerAxis << ", statsPerAxis={";
for (unsigned i = 0, n = stats.minValuePerAxis.size(); i < n; ++i) {
os << "minValue=" << stats.minValuePerAxis[i]
<< ", maxValue=" << stats.maxValuePerAxis[i]
<< ", mean=" << stats.meanPerAxis[i]
<< ", variance=" << stats.variancePerAxis[i];
if (i != n - 1)
os << "; ";
}
os << "}]";
return os;
}
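
The collection loops above leave variance as a TODO. One common way to fill it in is Welford's single-pass update, sketched here over a flat buffer. This is an assumption about how the TODO could be resolved, not what the removed code did; `RunningStats` and `collectStats` are hypothetical names.

```cpp
#include <algorithm>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical accumulator using Welford's online algorithm.
struct RunningStats {
  std::size_t count = 0;
  double minValue = std::numeric_limits<double>::infinity();
  double maxValue = -std::numeric_limits<double>::infinity();
  double mean = 0.0;
  double m2 = 0.0; // running sum of squared deviations from the mean

  void add(double x) {
    ++count;
    minValue = std::min(minValue, x);
    maxValue = std::max(maxValue, x);
    double delta = x - mean;
    mean += delta / static_cast<double>(count);
    m2 += delta * (x - mean);
  }

  // Population variance; 0 when no samples have been added.
  double variance() const {
    return count ? m2 / static_cast<double>(count) : 0.0;
  }
};

// Accumulates statistics over every element of a flat buffer.
inline RunningStats collectStats(const std::vector<double> &values) {
  RunningStats stats;
  for (double v : values)
    stats.add(v);
  return stats;
}
```

Updating the mean incrementally also avoids needing `numElements` up front, which the loops above rely on when accumulating `value / numElements`.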

@ -1,22 +0,0 @@
//===- TypeUtils.cpp - Helper function for manipulating types -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Quantizer/Support/TypeUtils.h"
#include "mlir/IR/StandardTypes.h"
using namespace mlir;
using namespace mlir::quantizer;
Type mlir::quantizer::getElementOrPrimitiveType(Type t) {
if (auto sType = t.dyn_cast<ShapedType>()) {
return sType.getElementType();
} else {
return t;
}
}

@ -1,256 +0,0 @@
//===- UniformConstraints.cpp - Constraints for uniform quant -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Quantizer/Support/UniformConstraints.h"
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Quantizer/Support/Configuration.h"
#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
#include "mlir/Quantizer/Support/Metadata.h"
#include "mlir/Quantizer/Support/Rules.h"
#include "mlir/Quantizer/Support/TypeUtils.h"
#include "mlir/Quantizer/Support/UniformSolvers.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::quantizer;
using namespace mlir::quant;
namespace {
struct ClusteredFacts {
ExpandingMinMaxFact requiredRange;
DiscreteScaleZeroPointFact explicitScaleZeroPoint;
};
} // end anonymous namespace
static QuantizedType solveUniformType(SolverContext &solverContext,
const ClusteredFacts &clusteredFacts,
const CandidateQuantizedType &ct,
Type originalElementType, Location loc) {
switch (ct.scheme) {
default:
emitError(loc, "unsupported scheme for uniform type conversion");
return nullptr;
case CandidateQuantizedType::Scheme::UniformPerLayer: {
if (!clusteredFacts.requiredRange.hasValue()) {
// TODO: Issue some kind of diagnostic. This is not an error.
return nullptr;
}
uint64_t numLevels = ct.quantizedType.getStorageTypeMax() -
ct.quantizedType.getStorageTypeMin();
UniformStorageParams params{numLevels,
ct.quantizedType.getStorageTypeMin()};
UniformParamsFromMinMaxSolver solver(
params, clusteredFacts.requiredRange.getValue().first,
clusteredFacts.requiredRange.getValue().second);
if (!solver.compute()) {
emitWarning(loc) << "unable to solve uniform type with "
<< "UniformParamsFromMinMaxSolver";
return nullptr;
}
return UniformQuantizedType::getChecked(
ct.quantizedType.getFlags(), ct.quantizedType.getStorageType(),
originalElementType, solver.getScale(), solver.getZp(),
ct.quantizedType.getStorageTypeMin(),
ct.quantizedType.getStorageTypeMax(), loc);
}
case CandidateQuantizedType::Scheme::UniformExplicitFixedPointScale: {
if (!clusteredFacts.explicitScaleZeroPoint.hasValue()) {
emitRemark(loc)
<< "unable to solve uniform type with UniformExplicitFixedPointScale "
<< "(no explicitScaleZeroPoint)";
return nullptr;
}
const auto &scaleZp = clusteredFacts.explicitScaleZeroPoint.getValue();
assert(scaleZp.value && "optional value not set on fact");
if (scaleZp.conflict) {
emitWarning(loc)
<< "conflicting explicit scale/zeroPoint on node cluster: "
<< "an arbitrary scale/zeroPoint will be used";
}
return UniformQuantizedType::getChecked(
ct.quantizedType.getFlags(), ct.quantizedType.getStorageType(),
originalElementType,
scaleZp.value->first, // scale
0, // zeroPoint (fixed point solutions only for this scheme)
ct.quantizedType.getStorageTypeMin(),
ct.quantizedType.getStorageTypeMax(), loc);
}
}
}
namespace {
class PropagateExplicitScale : public CAGConstraintNode {
public:
PropagateExplicitScale()
: CAGConstraintNode(Kind::UniformPropagateExplicitScale) {}
static bool classof(const CAGNode *n) {
return n->getKind() == Kind::Constraint ||
n->getKind() == Kind::UniformPropagateExplicitScale;
}
private:
void printLabel(raw_ostream &os) const override {
os << "PropagateExplicitScale";
}
void propagate(SolverContext &solverContext,
const TargetConfiguration &config) override {
DiscreteScaleZeroPointFact scaleZp;
// Get scale/zp from all parents.
for (auto it = incoming_begin(), e = incoming_end(); it != e; ++it) {
auto parentAnchor = cast<CAGAnchorNode>(*it);
auto selectedType = parentAnchor->getUniformMetadata().selectedType;
if (auto uqType = selectedType.dyn_cast_or_null<UniformQuantizedType>()) {
scaleZp.assertValue(
CAGUniformMetadata::SalienceRequired,
std::make_pair(uqType.getScale(), static_cast<int64_t>(0)));
}
}
// Propagate to children.
if (scaleZp.hasValue()) {
for (auto it = begin(), e = end(); it != e; ++it) {
auto childAnchor = cast<CAGAnchorNode>(*it);
if (modified(childAnchor->getUniformMetadata()
.explicitScaleZeroPoint.mergeFrom(scaleZp))) {
childAnchor->markDirty();
}
}
}
}
};
/// A constraint node which will solve uniform quantization for all parents
/// of the constraint, assuming that they are coupled.
class SolveUniformConstraintNode : public CAGConstraintNode {
public:
SolveUniformConstraintNode()
: CAGConstraintNode(Kind::SolveUniformConstraint) {
markDirty();
}
static bool classof(const CAGNode *n) {
return n->getKind() == Kind::Constraint ||
n->getKind() == Kind::SolveUniformConstraint;
}
private:
void printLabel(raw_ostream &os) const override { os << "SolveUniform"; }
void propagate(SolverContext &solverContext,
const TargetConfiguration &config) override {
// First determine the required min/max range and type constraints.
Location fusedLoc = UnknownLoc::get(&solverContext.getMlirContext());
llvm::SmallBitVector enabledCandidateTypesMask(
config.getAllCandidateTypesMask());
ClusteredFacts clusteredFacts;
Type originalElementType;
for (auto it = incoming_begin(), e = incoming_end(); it != e; ++it) {
auto parentAnchor = cast<CAGAnchorNode>(*it);
auto metadata = parentAnchor->getUniformMetadata();
// TODO: Possibly use a location that fuses all involved parents.
fusedLoc = parentAnchor->getOp()->getLoc();
// Shared element type.
auto parentOriginalElementType =
getElementOrPrimitiveType(parentAnchor->getOriginalType());
if (!originalElementType) {
originalElementType = parentOriginalElementType;
} else {
if (originalElementType != parentOriginalElementType) {
parentAnchor->getOp()->emitError()
<< "cannot compute uniform type: parent element types mismatch";
return;
}
}
// Range.
clusteredFacts.requiredRange.mergeFrom(metadata.requiredRange);
// Explicit scale and zero point.
clusteredFacts.explicitScaleZeroPoint.mergeFrom(
metadata.explicitScaleZeroPoint);
// Shared candidate types.
enabledCandidateTypesMask.reset(metadata.disabledCandidateTypes);
}
// Find the first enabled candidate type.
const CandidateQuantizedType *bestCandidateType = nullptr;
for (auto &ct : config.getCandidateTypes()) {
if (enabledCandidateTypesMask.test(ct.ordinal)) {
bestCandidateType = &ct;
break;
}
}
if (!bestCandidateType || !originalElementType) {
emitRemark(fusedLoc)
<< "not solving uniform type (no viable candidate type)";
return;
}
// Solve for the type.
QuantizedType selectedType =
solveUniformType(solverContext, clusteredFacts, *bestCandidateType,
originalElementType, fusedLoc);
// Apply it to all parents.
for (auto it = incoming_begin(), e = incoming_end(); it != e; ++it) {
auto parentAnchor = cast<CAGAnchorNode>(*it);
auto &metadata = parentAnchor->getUniformMetadata();
if (metadata.selectedType != selectedType) {
metadata.selectedType = selectedType;
// And mark all children of the parent dirty (except us).
for (auto child : *parentAnchor) {
if (child != this) {
child->markDirty();
}
}
}
}
}
};
} // end anonymous namespace
void UniformConstraintsBuilder::coupleAnchors(CAGAnchorNode *a,
CAGAnchorNode *b) {
slice.addClusteredConstraint<SolveUniformConstraintNode>({a, b});
}
void UniformConstraintsBuilder::applyStats(CAGAnchorNode *a,
TensorAxisStatistics stats) {
a->getUniformMetadata().requiredRange.assertValue(
CAGUniformMetadata::SalienceDefault, {stats.minValue, stats.maxValue});
}
void UniformConstraintsBuilder::clamp(CAGAnchorNode *a, APFloat minValue,
APFloat maxValue) {
a->getUniformMetadata().requiredRange.assertValue(
CAGUniformMetadata::SalienceDefault,
{minValue.convertToDouble(), maxValue.convertToDouble()});
}
void UniformConstraintsBuilder::propagateExplicitScale(CAGAnchorNode *from,
CAGAnchorNode *to) {
slice.addUnidirectionalConstraint<PropagateExplicitScale>(from, {to});
}

@ -1,143 +0,0 @@
//===- UniformSolvers.cpp - Uniform type solver algorithms ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Quantizer/Support/UniformSolvers.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/raw_ostream.h"
#include <cmath>
using namespace mlir;
using namespace mlir::quantizer;
bool UniformParamsFromMinMaxSolver::compute() {
// Compute adjMin, adjMax, clamping to ensure that they straddle zero.
if (boundingMin > 0 && boundingMax >= boundingMin) {
// Lop-sided to the positive.
adjMin = 0;
adjMax = boundingMax;
} else if (boundingMax < 0 && boundingMin <= boundingMax) {
// Lop-sided to the negative.
adjMin = boundingMin;
adjMax = 0;
} else if (boundingMin <= 0 && boundingMax >= 0) {
adjMin = boundingMin;
adjMax = boundingMax;
} else {
// Illegal bounds.
return satisfied = false;
}
const double origMinAdj = adjMin;
const double origMaxAdj = adjMax;
const double numLevelsDouble = storageParams.numLevels;
struct fns {
static std::pair<double, double>
computeMinMax(double boundingMin, double numLevels, double delta) {
double adjMin = delta * std::floor(boundingMin / delta);
return std::make_pair(adjMin, adjMin + numLevels * delta);
}
static double overshoot(double boundingMin, double boundingMax,
double numLevels, double delta) {
auto adjMinMax = computeMinMax(boundingMin, numLevels, delta);
double maxOvershoot = adjMinMax.second - boundingMax;
double minOvershoot = boundingMin - adjMinMax.first;
// If undershooting on the min or max end, return that because it is
// to be unconditionally avoided. Otherwise return the end with the
// greatest magnitude of overshoot.
if (maxOvershoot < 0)
return maxOvershoot;
if (minOvershoot < 0)
return minOvershoot;
return std::max(maxOvershoot, minOvershoot);
}
};
// Bisect to find a suitable delta, starting with bounds of deltaInit
// and deltaMax.
double deltaInit = (adjMax - adjMin) / numLevelsDouble;
double deltaMax =
((numLevelsDouble * deltaInit) + 2 * deltaInit) / numLevelsDouble;
double deltaMid;
double prevDeltaMid = 0.0;
for (stepCount = 0; stepCount < 60; ++stepCount) {
deltaMid = (deltaInit + deltaMax) / 2.0;
auto fInit =
fns::overshoot(origMinAdj, origMaxAdj, numLevelsDouble, deltaInit);
auto fMid =
fns::overshoot(origMinAdj, origMaxAdj, numLevelsDouble, deltaMid);
if (fMid == 0 || (fMid > 0 && std::fabs(deltaMid - prevDeltaMid) < 1e-15)) {
// Solution found (or step size is infinitesimal and an overshoot).
// Empirically, this seems to terminate around 30-50 steps or so.
// This will find a zero point for exactly representable ranges and
// will terminate on a small step size for inexact, biasing towards
// overshooting.
delta = deltaMid;
break;
}
bool signMid = fMid > 0;
bool signInit = fInit > 0;
if (signMid == signInit) {
deltaInit = deltaMid;
} else {
deltaMax = deltaMid;
}
prevDeltaMid = deltaMid;
}
delta = deltaMid;
// Recalculate adjMin/adjMax based on new delta.
auto adjMinMax = fns::computeMinMax(origMinAdj, numLevelsDouble, delta);
adjMin = adjMinMax.first;
adjMax = adjMinMax.second;
satisfied = false;
zp = 0;
if (!std::isnan(delta) && !std::isnan(adjMin) && !std::isnan(adjMax)) {
satisfied = true;
// Finally, scale and zeroPoint. Since it casts to integer, only valid
// if the inputs are valid.
zp = std::round(storageParams.minValue - adjMin / delta);
}
return satisfied;
}
int64_t UniformParamsFromMinMaxSolver::quantize(double x) const {
int64_t xq = std::round(x / delta + zp);
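// Clamp to the storage range [minValue, minValue + numLevels].
return std::max<int64_t>(
storageParams.minValue,
std::min<int64_t>(storageParams.minValue + storageParams.numLevels, xq));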
}
double UniformParamsFromMinMaxSolver::dequantize(int64_t xq) const {
return (xq - zp) * delta;
}
raw_ostream &mlir::quantizer::operator<<(raw_ostream &os,
const UniformStorageParams &p) {
os << "UniformStorageParams{" << p.numLevels << ", " << p.minValue << "}";
return os;
}
raw_ostream &
mlir::quantizer::operator<<(raw_ostream &os,
const UniformParamsFromMinMaxSolver &s) {
os << "UniformParamsFromMinMaxSolver(" << s.getStepCount() << "){";
os << "(" << s.getBoundingMin() << ":" << s.getBoundingMax() << ") -> ";
if (!s.isSatisfied()) {
os << "unsat}";
return os;
}
os << "(" << s.getAdjMin() << ":" << s.getAdjMax() << ")";
os << ", scale = " << s.getScale();
os << ", zp = " << s.getZp();
os << "}";
return os;
}
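
The relationship between a real range, `scale`, and `zp` can be sketched with a closed-form computation. Note that this is a simplification: the removed `compute()` additionally bisects on the step size so that the adjusted range lands on exact multiples of `delta`, which this sketch does not attempt. `AffineParams` and the free functions below are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical stand-in for the solver outputs; not the removed class.
struct AffineParams {
  double scale;       // real-valued step between adjacent quantized levels
  int64_t zeroPoint;  // quantized value that represents real 0.0
  int64_t storageMin; // smallest storable quantized value
  int64_t storageMax; // largest storable quantized value
};

// Derives parameters from a real range, first widening the range so that it
// straddles zero (as the removed solver also does).
inline AffineParams paramsFromMinMax(double rMin, double rMax,
                                     int64_t storageMin, int64_t storageMax) {
  assert(rMin <= rMax && storageMin < storageMax);
  rMin = std::min(rMin, 0.0);
  rMax = std::max(rMax, 0.0);
  double numLevels = static_cast<double>(storageMax - storageMin);
  double scale = (rMax - rMin) / numLevels;
  if (scale == 0.0)
    scale = 1.0; // degenerate all-zero range
  int64_t zp = static_cast<int64_t>(
      std::llround(static_cast<double>(storageMin) - rMin / scale));
  return {scale, zp, storageMin, storageMax};
}

// Maps a real value to storage, clamping to the storage range.
inline int64_t quantize(const AffineParams &p, double x) {
  int64_t q = static_cast<int64_t>(std::llround(x / p.scale)) + p.zeroPoint;
  return std::min(std::max(q, p.storageMin), p.storageMax);
}

// Maps a stored value back to the real domain.
inline double dequantize(const AffineParams &p, int64_t q) {
  return static_cast<double>(q - p.zeroPoint) * p.scale;
}
```

With a positive-only range such as `[2.0, 510.0]` and unsigned 8-bit storage, the range is widened to `[0.0, 510.0]`, which yields a scale of 2.0 and a zero point of 0.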

@ -1,118 +0,0 @@
//===- AddDefaultStatsTestPass.cpp - Testing pass to add default stats ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a testing pass to add default statistics nodes to every
// quantization eligible op. Useful for unit testing.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/Quantizer/Configurations/FxpMathConfig.h"
#include "mlir/Quantizer/Support/Configuration.h"
#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
#include "mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h"
#include "mlir/Quantizer/Transforms/Passes.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/GraphWriter.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::quantizer;
using namespace mlir::quant;
namespace {
class AddDefaultStatsPass : public FunctionPass<AddDefaultStatsPass> {
public:
/// Include the generated pass utilities.
#define GEN_PASS_QuantizerAddDefaultStats
#include "mlir/Quantizer/Transforms/Passes.h.inc"
AddDefaultStatsPass() = default;
AddDefaultStatsPass(SolverContext &solverContext,
const TargetConfiguration &config)
: explicitSolverContext(&solverContext), explicitConfig(&config) {}
void runOnFunction() override;
void runWithConfig(SolverContext &solverContext,
const TargetConfiguration &config);
private:
SolverContext *explicitSolverContext = nullptr;
const TargetConfiguration *explicitConfig = nullptr;
};
} // end anonymous namespace
void AddDefaultStatsPass::runOnFunction() {
if (explicitSolverContext && explicitConfig) {
// If explicitly constructed with a config and context.
runWithConfig(*explicitSolverContext, *explicitConfig);
return;
}
// For global pass registration, use defaults.
SolverContext solverContext(*getFunction().getContext());
auto config = FxpMathTargetConfig::create(solverContext);
runWithConfig(solverContext, *config);
}
void AddDefaultStatsPass::runWithConfig(SolverContext &solverContext,
const TargetConfiguration &config) {
auto func = getFunction();
// Insert stats for each argument.
for (auto arg : func.getArguments()) {
if (!config.isHandledType(arg.getType()))
continue;
OpBuilder b(func.getBody());
APFloat minValue(-1.0f);
APFloat maxValue(1.0f);
ElementsAttr layerStats = DenseFPElementsAttr::get(
RankedTensorType::get({2}, b.getF32Type()), {minValue, maxValue});
auto statsOp = b.create<StatisticsOp>(func.getLoc(), arg, layerStats,
nullptr, nullptr);
arg.replaceAllUsesWith(statsOp);
// The replacement above also rewrote statsOp's own operand, so point it back
// at 'arg'.
statsOp.getOperation()->replaceUsesOfWith(statsOp, arg);
}
// Walk the ops and insert stats.
func.walk([&](Operation *op) {
if (!config.isRequireStatsOp(op)) {
return;
}
assert(op->getNumResults() == 1);
auto originalResult = op->getResult(0);
if (!config.isHandledType(originalResult.getType()))
return;
OpBuilder b(op->getBlock(), ++op->getIterator());
APFloat minValue(-1.0f);
APFloat maxValue(1.0f);
ElementsAttr layerStats = DenseFPElementsAttr::get(
RankedTensorType::get({2}, b.getF32Type()), {minValue, maxValue});
auto statsOp = b.create<StatisticsOp>(op->getLoc(), op->getResult(0),
layerStats, nullptr, nullptr);
originalResult.replaceAllUsesWith(statsOp);
// The replacement above also rewrote statsOp's own operand, so point it back
// at the original result of 'op'.
statsOp.getOperation()->replaceUsesOfWith(statsOp, originalResult);
});
}
std::unique_ptr<OpPassBase<FuncOp>>
mlir::quantizer::createAddDefaultStatsPass() {
return std::make_unique<AddDefaultStatsPass>();
}

@ -1,292 +0,0 @@
//===- InferQuantizedTypesPass.cpp - Infers quantized types ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the primary pass for instantiating a CAG, running it to
// convergence on a module to determine eligible quantized type transforms, and
// applying those transforms to the IR.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/Quantizer/Configurations/FxpMathConfig.h"
#include "mlir/Quantizer/Support/Configuration.h"
#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
#include "mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h"
#include "mlir/Quantizer/Transforms/Passes.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/DOTGraphTraits.h"
#include "llvm/Support/GraphWriter.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::quantizer;
using namespace mlir::quant;
namespace llvm {
template <>
struct DOTGraphTraits<const CAGSlice *>
: public DOTGraphTraits<const CAGNode *> {
DOTGraphTraits(bool isSimple = false)
: DOTGraphTraits<const CAGNode *>(isSimple) {}
std::string getNodeLabel(const CAGNode *node, const CAGSlice *graph) {
std::string s;
llvm::raw_string_ostream out(s);
node->printLabel(out);
return out.str();
}
static std::string getGraphProperties(const CAGSlice *) {
return "rankdir=LR;";
}
static bool isNodeHidden(const CAGNode *node) {
// Filter constraint nodes with no incoming or outgoing connections.
// These orphans are often created as part of graph merging operations.
return llvm::isa<CAGConstraintNode>(node) && node->isOrphan();
}
std::string getNodeAttributes(const CAGNode *node, const CAGSlice *graph) {
switch (node->getKind()) {
default:
return std::string();
case CAGNode::Kind::OperandAnchor:
return "shape=record,color=yellow,style=filled";
case CAGNode::Kind::ResultAnchor:
return "shape=record,color=lightblue,style=filled";
case CAGNode::Kind::Constraint:
return "shape=record,style=dotted";
}
}
};
} // end namespace llvm
namespace {
class InferQuantizedTypesPass : public ModulePass<InferQuantizedTypesPass> {
public:
/// Include the generated pass utilities.
#define GEN_PASS_QuantizerInferQuantizedTypes
#include "mlir/Quantizer/Transforms/Passes.h.inc"
InferQuantizedTypesPass() = default;
InferQuantizedTypesPass(SolverContext &solverContext,
const TargetConfiguration &config)
: explicitSolverContext(&solverContext), explicitConfig(&config) {}
void runOnModule() override;
void runWithConfig(SolverContext &solverContext,
const TargetConfiguration &config);
void transformOperandType(CAGOperandAnchor *anchor, Type newType);
void transformResultType(CAGResultAnchor *anchor, Type newType);
private:
SolverContext *explicitSolverContext = nullptr;
const TargetConfiguration *explicitConfig = nullptr;
};
} // end anonymous namespace
/// Maximum number of propagation rounds to run to converge the CAG before
/// signalling an error.
static const int kMaximumPropagationRounds = 1000;
static LogicalResult validateTypeConversion(Type newType, Type origType,
Operation *op) {
if (!newType) {
return op->emitOpError() << "unsupported type conversion from " << origType;
}
return success();
}
void InferQuantizedTypesPass::runOnModule() {
if (explicitSolverContext && explicitConfig) {
// If explicitly constructed with a config and context.
runWithConfig(*explicitSolverContext, *explicitConfig);
return;
}
// For global pass registration, use defaults.
SolverContext solverContext(*getModule().getContext());
auto config = FxpMathTargetConfig::create(solverContext);
runWithConfig(solverContext, *config);
}
void InferQuantizedTypesPass::runWithConfig(SolverContext &solverContext,
const TargetConfiguration &config) {
CAGSlice cag(solverContext);
for (auto f : getModule().getOps<FuncOp>()) {
f.walk([&cag, &config](Operation *op) { config.handleOp(op, cag); });
}
config.finalizeAnchors(cag);
// Propagate.
int propRound;
for (propRound = kMaximumPropagationRounds; propRound > 0; --propRound) {
auto propCount = cag.propagate(config);
if (propCount == 0)
break;
}
if (propRound == 0) {
emitError(UnknownLoc::get(&getContext()),
"exceeded maximum number of solver iterations (infinite loop?)");
return;
}
// TODO: Only dump the GraphViz if a flag is set and move to a utility.
// GraphViz.
if (!solverContext.getDebugCAGDotPath().empty()) {
auto actFileName = llvm::WriteGraph(
const_cast<const CAGSlice *>(&cag), "CAG",
/*ShortNames=*/false,
/*Title=*/"CAG",
/*Filename=*/std::string(solverContext.getDebugCAGDotPath()));
llvm::errs() << "Wrote graphviz file: " << actFileName << "\n";
}
// Start transforming the types in order of anchor type (results, then
// operands).
// Apply result types.
for (auto *node : cag) {
auto anchorNode = dyn_cast<CAGResultAnchor>(node);
if (!anchorNode)
continue;
if (Type newType = anchorNode->getTransformedType())
transformResultType(anchorNode, newType);
}
// Apply operand types.
for (auto *node : cag) {
auto anchorNode = dyn_cast<CAGOperandAnchor>(node);
if (!anchorNode)
continue;
if (Type newType = anchorNode->getTransformedType())
transformOperandType(anchorNode, newType);
}
}
void InferQuantizedTypesPass::transformOperandType(CAGOperandAnchor *anchor,
Type newType) {
Value inputValue = anchor->getValue();
Operation *op = anchor->getOp();
OpBuilder b(op->getBlock(), Block::iterator(op));
SmallVector<Value, 1> removeValuesIfDead;
// Because we've already run the result transforms at this phase, it is
// very likely that inputValue points to a dcast op whose input matches
// our type. We detect that situation and route around just to save some
// bulk in the IR.
Value newTypedInputValue = inputValue;
auto inputDcastOp =
dyn_cast_or_null<DequantizeCastOp>(inputValue.getDefiningOp());
if (inputDcastOp && inputDcastOp.arg().getType() == newType) {
// Can just use the dcast's input value.
newTypedInputValue = inputDcastOp.arg();
removeValuesIfDead.push_back(inputDcastOp);
} else {
// Need to synthesize a qcast.
newTypedInputValue =
b.create<QuantizeCastOp>(op->getLoc(), newType, inputValue);
}
switch (anchor->getTypeTransformRule()) {
case CAGAnchorNode::TypeTransformRule::Direct:
anchor->getOp()->setOperand(anchor->getOperandIdx(), newTypedInputValue);
break;
case CAGAnchorNode::TypeTransformRule::DirectStorage: {
Type storageType = QuantizedType::castToStorageType(newType);
if (failed(validateTypeConversion(storageType, newType, op)))
return;
anchor->getOp()->setOperand(
anchor->getOperandIdx(),
b.create<StorageCastOp>(op->getLoc(), storageType, newTypedInputValue));
break;
}
case CAGAnchorNode::TypeTransformRule::ExpressedOnly:
// Leave the anchor as-is and just cast in/out after it.
anchor->getOp()->setOperand(
anchor->getOperandIdx(),
b.create<DequantizeCastOp>(op->getLoc(), anchor->getOriginalType(),
newTypedInputValue));
break;
}
for (Value removeValueIfDead : removeValuesIfDead) {
if (removeValueIfDead.use_empty()) {
removeValueIfDead.getDefiningOp()->erase();
}
}
}
void InferQuantizedTypesPass::transformResultType(CAGResultAnchor *anchor,
Type newType) {
Value origResultValue = anchor->getValue();
Operation *op = origResultValue.getDefiningOp();
OpBuilder b(op->getBlock(), ++Block::iterator(op));
Value replacedResultValue = nullptr;
Value newResultValue = nullptr;
switch (anchor->getTypeTransformRule()) {
case CAGAnchorNode::TypeTransformRule::Direct:
origResultValue.setType(newType);
replacedResultValue = newResultValue = b.create<DequantizeCastOp>(
op->getLoc(), anchor->getOriginalType(), origResultValue);
break;
case CAGAnchorNode::TypeTransformRule::DirectStorage: {
Type storageType = QuantizedType::castToStorageType(newType);
if (failed(validateTypeConversion(storageType, newType, op)))
return;
origResultValue.setType(storageType);
replacedResultValue =
b.create<StorageCastOp>(op->getLoc(), newType, origResultValue);
newResultValue = b.create<DequantizeCastOp>(
op->getLoc(), anchor->getOriginalType(), replacedResultValue);
break;
}
case CAGAnchorNode::TypeTransformRule::ExpressedOnly:
// Leave the anchor as-is and just cast in/out after it.
replacedResultValue =
b.create<QuantizeCastOp>(op->getLoc(), newType, origResultValue);
newResultValue = b.create<DequantizeCastOp>(
op->getLoc(), anchor->getOriginalType(), replacedResultValue);
break;
}
if (replacedResultValue) {
// Transform:
// origResultValue --> replacedResultValue -> newResultValue
// \-> [original uses]
// To:
// origResultValue -> replacedResultValue ->
// newResultValue -> [original uses]
// Note that replacedResultValue may equal newResultValue or there may
// be operations between the two.
origResultValue.replaceAllUsesWith(newResultValue);
replacedResultValue.getDefiningOp()->replaceUsesOfWith(newResultValue,
origResultValue);
}
}
std::unique_ptr<OpPassBase<ModuleOp>>
mlir::quantizer::createInferQuantizedTypesPass(
SolverContext &solverContext, const TargetConfiguration &config) {
return std::make_unique<InferQuantizedTypesPass>(solverContext, config);
}
std::unique_ptr<OpPassBase<ModuleOp>>
mlir::quantizer::createInferQuantizedTypesPass() {
return std::make_unique<InferQuantizedTypesPass>();
}
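
The propagation loop in `runWithConfig` above is a bounded fixed-point iteration: it keeps calling `cag.propagate(config)` until a round reports zero changes, and gives up after `kMaximumPropagationRounds`. A minimal standalone sketch of that pattern follows. It does not depend on MLIR, and `runToFixedPoint` is a hypothetical helper name that does not appear in the removed code.

```cpp
#include <cassert>
#include <functional>

// Hypothetical helper (not part of the removed Quantizer sources).
// Calls `step` repeatedly; `step` returns how many facts changed in that
// round. Stops as soon as a round changes nothing, or after `maxRounds`
// rounds. Returns true if a fixed point was reached within the budget.
inline bool runToFixedPoint(const std::function<unsigned()> &step,
                            int maxRounds) {
  assert(maxRounds > 0 && "need at least one round");
  for (int round = 0; round < maxRounds; ++round) {
    if (step() == 0)
      return true; // Converged: the last round modified nothing.
  }
  return false; // Budget exhausted; the caller should report an error.
}
```

A caller would wrap the solver step in a lambda and emit a diagnostic when the function returns false, which mirrors the "exceeded maximum number of solver iterations" error above.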


@ -1,66 +0,0 @@
//===- RemoveInstrumentationPass.cpp - Removes instrumentation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a pass to remove any instrumentation ops. It is often one
// of the final steps when performing quantization and is run after any
// decisions requiring instrumentation have been made.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Quantizer/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::quantizer;
using namespace mlir::quant;
namespace {
class RemoveInstrumentationPass
: public FunctionPass<RemoveInstrumentationPass> {
/// Include the generated pass utilities.
#define GEN_PASS_QuantizerRemoveInstrumentation
#include "mlir/Quantizer/Transforms/Passes.h.inc"
void runOnFunction() override;
};
template <typename OpTy>
class RemoveIdentityOpRewrite : public RewritePattern {
public:
RemoveIdentityOpRewrite(MLIRContext *context)
: RewritePattern(OpTy::getOperationName(), 1, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
assert(op->getNumOperands() == 1);
assert(op->getNumResults() == 1);
rewriter.replaceOp(op, op->getOperand(0));
return success();
}
};
} // end anonymous namespace
void RemoveInstrumentationPass::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
auto *context = &getContext();
patterns.insert<RemoveIdentityOpRewrite<StatisticsOp>,
RemoveIdentityOpRewrite<StatisticsRefOp>,
RemoveIdentityOpRewrite<CoupledRefOp>>(context);
applyPatternsGreedily(func, patterns);
}
std::unique_ptr<OpPassBase<FuncOp>>
mlir::quantizer::createRemoveInstrumentationPass() {
return std::make_unique<RemoveInstrumentationPass>();
}


@ -1,64 +0,0 @@
// RUN: mlir-opt %s -split-input-file -fxpmath-lower-uniform-casts | FileCheck %s --dump-input=always
// -----
// CHECK-LABEL: dequantize_per_layer_fixedpoint
!type_input = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_result = type tensor<4xf32>
func @dequantize_per_layer_fixedpoint(%arg0 : !type_input) -> !type_result {
// CHECK: %cst = constant dense<6.250000e-02> : tensor<4xf32>
// CHECK-NEXT: %0 = "quant.scast"(%arg0) : (tensor<4x!quant.uniform<i8:f32, 6.250000e-02>>) -> tensor<4xi8>
// CHECK-NEXT: %1 = "fxpmath.convertis"(%0) : (tensor<4xi8>) -> tensor<4xi32>
// CHECK-NEXT: %2 = "fxpmath.convertistof"(%1) : (tensor<4xi32>) -> tensor<4xf32>
// CHECK-NEXT: %3 = mulf %2, %cst : tensor<4xf32>
// CHECK-NEXT: return %3 : tensor<4xf32>
%0 = "quant.dcast"(%arg0) : (!type_input) -> (!type_result)
return %0 : !type_result
}
// -----
// CHECK-LABEL: dequantize_per_layer_affine
!type_input = type tensor<4x!quant.uniform<i8:f32, 6.25e-2:-36>>
!type_result = type tensor<4xf32>
func @dequantize_per_layer_affine(%arg0 : !type_input) -> !type_result {
// CHECK: %cst = constant dense<36> : tensor<4xi32>
// CHECK-NEXT: %cst_0 = constant dense<6.250000e-02> : tensor<4xf32>
// CHECK-NEXT: %0 = "quant.scast"(%arg0) : (tensor<4x!quant.uniform<i8:f32, 6.250000e-02:-36>>) -> tensor<4xi8>
// CHECK-NEXT: %1 = "fxpmath.convertis"(%0) : (tensor<4xi8>) -> tensor<4xi32>
// CHECK-NEXT: %2 = addi %1, %cst : tensor<4xi32>
// CHECK-NEXT: %3 = "fxpmath.convertistof"(%2) : (tensor<4xi32>) -> tensor<4xf32>
// CHECK-NEXT: %4 = mulf %3, %cst_0 : tensor<4xf32>
// CHECK-NEXT: return %4 : tensor<4xf32>
%0 = "quant.dcast"(%arg0) : (!type_input) -> (!type_result)
return %0 : !type_result
}
// -----
// CHECK-LABEL: dequantize_per_axis_fixedpoint
!type_input = type tensor<4x!quant.uniform<i8:f32:0, {6.25e-2,3.26e-2,4.25e-2,1.23e-2}>>
!type_result = type tensor<4xf32>
func @dequantize_per_axis_fixedpoint(%arg0 : !type_input) -> !type_result {
// expected-warning@+1 {{unimplemented: per-axis uniform dequantization}}
%0 = "quant.dcast"(%arg0) : (!type_input) -> (!type_result)
return %0 : !type_result
}
// -----
// CHECK-LABEL: dequantize_per_axis_affine
!type_input = type tensor<4x!quant.uniform<i8:f32:0, {6.25e-2,3.26e-2,4.25e-2,1.23e-2}>>
!type_result = type tensor<4xf32>
func @dequantize_per_axis_affine(%arg0 : !type_input) -> !type_result {
// expected-warning@+1 {{unimplemented: per-axis uniform dequantization}}
%0 = "quant.dcast"(%arg0) : (!type_input) -> (!type_result)
return %0 : !type_result
}
// -----
// Noop dequantize should be skipped (will be canonicalized away later).
// CHECK-LABEL: dequantize_noop
!type_input = type tensor<4x!quant.uniform<i8:f32, 6.25e-2:-36>>
!type_result = type tensor<4x!quant.uniform<i8:f32, 6.25e-2:-36>>
func @dequantize_noop(%arg0 : !type_input) -> !type_result {
// CHECK: %0 = "quant.dcast"(%arg0)
%0 = "quant.dcast"(%arg0) : (!type_input) -> (!type_result)
return %0 : !type_result
}
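
The CHECK lines in `dequantize_per_layer_affine` above spell out the usual affine dequantization, real = scale * (stored - zeroPoint): widen the stored `i8`, add the negated zero point (`36` for a zero point of `-36`), convert to float, then multiply by the scale. A minimal scalar sketch of the same arithmetic follows. The function name `dequantizeAffine` is an assumption for illustration and is not taken from the removed sources.

```cpp
#include <cstdint>

// Hypothetical scalar model of the lowered op sequence in the test above.
// Each step is annotated with the op it stands in for.
inline float dequantizeAffine(int8_t stored, float scale, int32_t zeroPoint) {
  int32_t widened = static_cast<int32_t>(stored); // fxpmath.convertis i8 -> i32
  int32_t shifted = widened - zeroPoint;          // addi with constant -zeroPoint
  float asFloat = static_cast<float>(shifted);    // fxpmath.convertistof
  return asFloat * scale;                         // mulf with the scale constant
}
```

With a scale of `6.25e-2` and a zero point of `-36`, a stored value of `-36` maps to `0.0` and a stored value of `0` maps to `2.25`. In the fixed-point case (zero point `0`) the `addi` step disappears, as in `dequantize_per_layer_fixedpoint`.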


@ -1,102 +0,0 @@
// RUN: mlir-opt %s -split-input-file -fxpmath-lower-uniform-real-math -pass-pipeline='func(canonicalize)' | FileCheck %s --dump-input=always
// -----
// Verify lowering when operands and result have the same fixedpoint scale.
// CHECK-LABEL: real_addew_fixedpoint_isomorphic
!type_lhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_rhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_result = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
func @real_addew_fixedpoint_isomorphic(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
// CHECK-NEXT: %0 = "quant.scast"(%arg0) : (tensor<4x!quant.uniform<i8:f32, 6.250000e-02>>) -> tensor<4xi8>
// CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant.uniform<i8:f32, 6.250000e-02>>) -> tensor<4xi8>
// CHECK-NEXT: %2 = "fxpmath.convertis"(%0) : (tensor<4xi8>) -> tensor<4xi16>
// CHECK-NEXT: %3 = "fxpmath.convertis"(%1) : (tensor<4xi8>) -> tensor<4xi16>
// CHECK-NEXT: %4 = addi %2, %3 : tensor<4xi16>
// CHECK-NEXT: %5 = "fxpmath.clampis"(%4) {clamp_max = 127 : i16, clamp_min = -128 : i16} : (tensor<4xi16>) -> tensor<4xi16>
// CHECK-NEXT: %6 = "fxpmath.convertis"(%5) : (tensor<4xi16>) -> tensor<4xi8>
// CHECK-NEXT: %7 = "quant.scast"(%6) : (tensor<4xi8>) -> tensor<4x!quant.uniform<i8:f32, 6.250000e-02>>
// CHECK-NEXT: return %7 : tensor<4x!quant.uniform<i8:f32, 6.250000e-02>>
%0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
return %0 : !type_result
}
// -----
// Verify lowering when operands and result have the same fixedpoint scale
// and non-zero zero points.
// CHECK-LABEL: real_addew_affine_isomorphic
!type_lhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2:-5>>
!type_rhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2:-5>>
!type_result = type tensor<4x!quant.uniform<i8:f32, 6.25e-2:-5>>
func @real_addew_affine_isomorphic(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
// CHECK-NEXT: %cst = constant dense<5> : tensor<4xi16>
// CHECK-NEXT: %0 = "quant.scast"(%arg0) : (tensor<4x!quant.uniform<i8:f32, 6.250000e-02:-5>>) -> tensor<4xi8>
// CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant.uniform<i8:f32, 6.250000e-02:-5>>) -> tensor<4xi8>
// CHECK-NEXT: %2 = "fxpmath.convertis"(%0) : (tensor<4xi8>) -> tensor<4xi16>
// CHECK-NEXT: %3 = "fxpmath.convertis"(%1) : (tensor<4xi8>) -> tensor<4xi16>
// CHECK-NEXT: %4 = addi %2, %3 : tensor<4xi16>
// CHECK-NEXT: %5 = addi %4, %cst : tensor<4xi16>
// CHECK-NEXT: %6 = "fxpmath.clampis"(%5) {clamp_max = 127 : i16, clamp_min = -128 : i16} : (tensor<4xi16>) -> tensor<4xi16>
// CHECK-NEXT: %7 = "fxpmath.convertis"(%6) : (tensor<4xi16>) -> tensor<4xi8>
// CHECK-NEXT: %8 = "quant.scast"(%7) : (tensor<4xi8>) -> tensor<4x!quant.uniform<i8:f32, 6.250000e-02:-5>>
// CHECK-NEXT: return %8 : tensor<4x!quant.uniform<i8:f32, 6.250000e-02:-5>>
%0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
return %0 : !type_result
}
// -----
// The RHS quant parameters prescribe a range of [-8..8) so an explicit clamp
// of [-4..4] should result in an integral clamp range of [-64..64].
// CHECK-LABEL: real_addew_fixedpoint_clamp
!type_lhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_rhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_result = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
func @real_addew_fixedpoint_clamp(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
// CHECK-NEXT: %0 = "quant.scast"(%arg0) : (tensor<4x!quant.uniform<i8:f32, 6.250000e-02>>) -> tensor<4xi8>
// CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant.uniform<i8:f32, 6.250000e-02>>) -> tensor<4xi8>
// CHECK-NEXT: %2 = "fxpmath.convertis"(%0) : (tensor<4xi8>) -> tensor<4xi16>
// CHECK-NEXT: %3 = "fxpmath.convertis"(%1) : (tensor<4xi8>) -> tensor<4xi16>
// CHECK-NEXT: %4 = addi %2, %3 : tensor<4xi16>
// CHECK-NEXT: %5 = "fxpmath.clampis"(%4) {clamp_max = 64 : i16, clamp_min = -64 : i16} : (tensor<4xi16>) -> tensor<4xi16>
// CHECK-NEXT: %6 = "fxpmath.convertis"(%5) : (tensor<4xi16>) -> tensor<4xi8>
// CHECK-NEXT: %7 = "quant.scast"(%6) : (tensor<4xi8>) -> tensor<4x!quant.uniform<i8:f32, 6.250000e-02>>
// CHECK-NEXT: return %7 : tensor<4x!quant.uniform<i8:f32, 6.250000e-02>>
%0 = "fxpmath.real_add_ew"(%arg0, %arg1) { clamp_min = -4.0, clamp_max = 4.0 }
: (!type_lhs, !type_rhs) -> (!type_result)
return %0 : !type_result
}
// -----
// CHECK-LABEL: real_addew_unquantized_lhs
// Verifies that the op is left as-is when the lhs is unquantized.
!type_lhs = type tensor<4xf32>
!type_rhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_result = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
func @real_addew_unquantized_lhs(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
// CHECK: %0 = "fxpmath.real_add_ew"(%arg0, %arg1)
%0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
return %0 : !type_result
}
// -----
// CHECK-LABEL: real_addew_unquantized_rhs
// Verifies that the op is left as-is when the rhs is unquantized.
!type_lhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_rhs = type tensor<4xf32>
!type_result = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
func @real_addew_unquantized_rhs(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
// CHECK: %0 = "fxpmath.real_add_ew"(%arg0, %arg1)
%0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
return %0 : !type_result
}
// -----
// CHECK-LABEL: real_addew_unquantized_result
// Verifies that the op is left as-is when the result is unquantized.
!type_lhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_rhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_result = type tensor<4xf32>
func @real_addew_unquantized_result(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
// CHECK: %0 = "fxpmath.real_add_ew"(%arg0, %arg1)
%0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
return %0 : !type_result
}
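
The clamp test above expects a real clamp of `[-4..4]` at scale `6.25e-2` to become an integral clamp of `[-64..64]`, and the unclamped tests fall back to the full `i8` range `[-128..127]`. One reading consistent with those numbers is q = round(real / scale) + zeroPoint, limited to the storage range. The sketch below encodes that reading. The name `quantizeClampBound` and its parameters are assumptions for illustration, not the pass's actual implementation.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Hypothetical helper: maps a real-valued clamp bound into the integer
// domain of a uniform quantized type and limits it to the storage range.
inline int32_t quantizeClampBound(double realBound, double scale,
                                  int32_t zeroPoint, int32_t storageMin,
                                  int32_t storageMax) {
  // Work in 64 bits so the zero-point addition cannot overflow.
  int64_t q = std::llround(realBound / scale) + zeroPoint;
  q = std::max<int64_t>(storageMin, std::min<int64_t>(storageMax, q));
  return static_cast<int32_t>(q);
}
```

For scale `6.25e-2` and zero point `0`, the bounds `-4.0` and `4.0` give `-64` and `64`. With a result zero point of `-9` the same bounds give `-73` and `55`, which matches the `clamp_min`/`clamp_max` attributes checked in the `real_mulew_affine_clamp` test elsewhere in this diff.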


@ -1,94 +0,0 @@
// RUN: mlir-opt %s -split-input-file -fxpmath-lower-uniform-real-math -pass-pipeline='func(canonicalize)' -verify-diagnostics | FileCheck %s --dump-input=always
// -----
// Verify lowering when operands and result have different fixedpoint scales.
// CHECK-LABEL: real_mulew_fixedpoint
!type_lhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_rhs = type tensor<4x!quant.uniform<i8:f32, 3.875e-2>>
!type_result = type tensor<4x!quant.uniform<i8:f32, 1.065e-1>>
func @real_mulew_fixedpoint(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
// CHECK: %0 = "quant.scast"(%arg0) : (tensor<4x!quant.uniform<i8:f32, 6.250000e-02>>) -> tensor<4xi8>
// CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant.uniform<i8:f32, 3.875000e-02>>) -> tensor<4xi8>
// CHECK-NEXT: %2 = "fxpmath.convertis"(%0) : (tensor<4xi8>) -> tensor<4xi32>
// CHECK-NEXT: %3 = "fxpmath.convertis"(%1) : (tensor<4xi8>) -> tensor<4xi32>
// CHECK-NEXT: %4 = muli %2, %3 : tensor<4xi32>
// CHECK-NEXT: %5 = "fxpmath.vs_saturating_rounding_doubling_high_mulis"(%4) {b = 1562722842 : i32} : (tensor<4xi32>) -> tensor<4xi32>
// CHECK-NEXT: %6 = "fxpmath.rounding_divide_by_potis"(%5) {exponent = 5 : i32} : (tensor<4xi32>) -> tensor<4xi32>
// CHECK-NEXT: %7 = "fxpmath.clampis"(%6) {clamp_max = 127 : i32, clamp_min = -128 : i32} : (tensor<4xi32>) -> tensor<4xi32>
// CHECK-NEXT: %8 = "fxpmath.convertis"(%7) : (tensor<4xi32>) -> tensor<4xi8>
// CHECK-NEXT: %9 = "quant.scast"(%8) : (tensor<4xi8>) -> tensor<4x!quant.uniform<i8:f32, 1.065000e-01>>
// CHECK-NEXT: return %9 : tensor<4x!quant.uniform<i8:f32, 1.065000e-01>>
%0 = "fxpmath.real_mul_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
return %0 : !type_result
}
// -----
// Verify lowering when operands and result have the same fixedpoint scale
// and non-zero zero points.
// CHECK-LABEL: real_mulew_affine_clamp
!type_lhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2:-3>>
!type_rhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2:-5>>
!type_result = type tensor<4x!quant.uniform<i8:f32, 6.25e-2:-9>>
func @real_mulew_affine_clamp(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
// Just verify that the affine adds/constants and clamps are present.
// CHECK: %cst = constant dense<3> : tensor<4xi32>
// CHECK: %cst_0 = constant dense<5> : tensor<4xi32>
// CHECK: %cst_1 = constant dense<-9> : tensor<4xi32>
// CHECK: addi %2, %cst : tensor<4xi32>
// CHECK: addi %3, %cst_0 : tensor<4xi32>
// CHECK: muli %4, %5 : tensor<4xi32>
// CHECK: addi %8, %cst_1 : tensor<4xi32>
// CHECK: {clamp_max = 55 : i32, clamp_min = -73 : i32}
%0 = "fxpmath.real_mul_ew"(%arg0, %arg1) { clamp_min = -4.0, clamp_max = 4.0 } : (!type_lhs, !type_rhs) -> (!type_result)
return %0 : !type_result
}
// -----
// CHECK-LABEL: real_mulew_unquantized_lhs
// Verifies that the op is left as-is when the lhs is unquantized.
!type_lhs = type tensor<4xf32>
!type_rhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_result = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
func @real_mulew_unquantized_lhs(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
// CHECK: %0 = "fxpmath.real_mul_ew"(%arg0, %arg1)
%0 = "fxpmath.real_mul_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
return %0 : !type_result
}
// -----
// CHECK-LABEL: real_mulew_unquantized_rhs
// Verifies that the op is left as-is when the rhs is unquantized.
!type_lhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_rhs = type tensor<4xf32>
!type_result = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
func @real_mulew_unquantized_rhs(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
// CHECK: %0 = "fxpmath.real_mul_ew"(%arg0, %arg1)
%0 = "fxpmath.real_mul_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
return %0 : !type_result
}
// -----
// CHECK-LABEL: real_mulew_unquantized_result
// Verifies that the op is left as-is when the result is unquantized.
!type_lhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_rhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_result = type tensor<4xf32>
func @real_mulew_unquantized_result(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
// CHECK: %0 = "fxpmath.real_mul_ew"(%arg0, %arg1)
%0 = "fxpmath.real_mul_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
return %0 : !type_result
}
// -----
// Verify that a real multiplier greater than 1.0 is reported as unimplemented.
// Note that the multiplier = lhs_scale * rhs_scale / result_scale
// = 22.740610328638496
// CHECK-LABEL: real_mulew_multiplier_gt_1
!type_lhs = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
!type_rhs = type tensor<4x!quant.uniform<i8:f32, 3.875e-2>>
!type_result = type tensor<4x!quant.uniform<i8:f32, 1.065e-4>>
func @real_mulew_multiplier_gt_1(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
// expected-warning@+1 {{unimplemented: cannot multiply with multiplier > 1.0}}
%0 = "fxpmath.real_mul_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
return %0 : !type_result
}
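
The tests above rely on the real multiplier lhs_scale * rhs_scale / result_scale. In `real_mulew_fixedpoint` that is 6.25e-2 * 3.875e-2 / 1.065e-1, roughly 0.02274, and the CHECK lines expect it to appear as `b = 1562722842` followed by a rounding divide with `exponent = 5`. Those values are what the common "Q0.31 mantissa plus right shift" decomposition produces. The sketch below shows that decomposition. The names `FixedPointMultiplier` and `quantizeMultiplier` are assumptions for illustration and are not claimed to match the removed lowering code.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical result type: realMultiplier ~= mantissa * 2^-31 * 2^-rightShift.
struct FixedPointMultiplier {
  int32_t mantissa; // Q0.31 value in [2^30, 2^31).
  int rightShift;   // Power-of-two divide applied after the high multiply.
};

// Decomposes a real multiplier in (0, 1) into a Q0.31 mantissa and a right
// shift. Multipliers >= 1.0 are rejected here, mirroring the
// "cannot multiply with multiplier > 1.0" warning in the test above.
inline FixedPointMultiplier quantizeMultiplier(double realMultiplier) {
  assert(realMultiplier > 0.0 && realMultiplier < 1.0);
  int exponent = 0;
  // frexp returns a fraction in [0.5, 1) with real = fraction * 2^exponent.
  double fraction = std::frexp(realMultiplier, &exponent);
  int64_t mantissa = std::llround(fraction * 2147483648.0); // fraction * 2^31
  if (mantissa == (int64_t{1} << 31)) {
    // Rounding pushed the fraction up to 1.0; renormalize.
    mantissa /= 2;
    ++exponent;
  }
  return {static_cast<int32_t>(mantissa), -exponent};
}
```

Calling `quantizeMultiplier(0.0625 * 0.03875 / 0.1065)` yields a mantissa of `1562722842` and a right shift of `5`, the same constants the first test checks for. The `real_mulew_multiplier_gt_1` case has a multiplier of about 22.74, which falls outside the range this sketch accepts.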


@ -1,51 +0,0 @@
// RUN: mlir-opt %s -quantizer-infer-quantized-types -quant-convert-const -quantizer-remove-instrumentation -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s
// ----
// A matmul without fused clamp or bias.
// CHECK-LABEL: @matmul
// CHECK: %cst = constant dense{{.*}}tensor<3x5xi8>
// CHECK-NEXT: %0 = "quant.qcast"(%arg0) : (tensor<300x3xf32>) -> tensor<300x3x!quant.uniform<i8:f32, 0.037564418067230126:35>>
// CHECK-NEXT: %1 = "quant.scast"(%cst) : (tensor<3x5xi8>) -> tensor<3x5x!quant.uniform<i8:f32, 0.0062823070315864236:-1>>
// CHECK-NEXT: %2 = "fxpmath.real_matmul"(%0, %1) : (tensor<300x3x!quant.uniform<i8:f32, 0.037564418067230126:35>>, tensor<3x5x!quant.uniform<i8:f32, 0.0062823070315864236:-1>>) -> tensor<300x5x!quant.uniform<i8:f32, 0.0629921259842528:-1>>
// CHECK-NEXT: %3 = "quant.dcast"(%2) : (tensor<300x5x!quant.uniform<i8:f32, 0.0629921259842528:-1>>) -> tensor<300x5xf32>
func @matmul(%arg0: tensor<300x3xf32>) -> tensor<300x5xf32> {
%0 = "quant.stats"(%arg0) {layerStats = dense<[-6.123e+00, 3.45e+00]> : tensor<2xf32>} : (tensor<300x3xf32>) -> tensor<300x3xf32>
%cst = constant {name = "constant.35"} dense<[[-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32>
%1 = "fxpmath.real_matmul"(%0, %cst) : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32>
%2 = "quant.stats"(%1) {layerStats = dense<[-8.000000e+00, 8.000000e+00]> : tensor<2xf32>} : (tensor<300x5xf32>) -> tensor<300x5xf32>
return %2 : tensor<300x5xf32>
}
// ----
// A matmul with fused clamp which serves as statistics for the result.
// CHECK-LABEL: @matmul_clamp
// CHECK: %cst = constant dense{{.*}}tensor<3x5xi8>
// CHECK-NEXT: %0 = "quant.qcast"(%arg0) : (tensor<300x3xf32>) -> tensor<300x3x!quant.uniform<i8:f32, 0.037564418067230126:35>>
// CHECK-NEXT: %1 = "quant.scast"(%cst) : (tensor<3x5xi8>) -> tensor<3x5x!quant.uniform<i8:f32, 0.0062823070315864236:-1>>
// CHECK-NEXT: %2 = "fxpmath.real_matmul"(%0, %1) {clamp_max = 6.100000e+00 : f64, clamp_min = -1.225000e+01 : f64} : (tensor<300x3x!quant.uniform<i8:f32, 0.037564418067230126:35>>, tensor<3x5x!quant.uniform<i8:f32, 0.0062823070315864236:-1>>) -> tensor<300x5x!quant.uniform<i8:f32, 0.072058823529412216:42>>
// CHECK-NEXT: %3 = "quant.dcast"(%2) : (tensor<300x5x!quant.uniform<i8:f32, 0.072058823529412216:42>>) -> tensor<300x5xf32>
func @matmul_clamp(%arg0: tensor<300x3xf32>) -> tensor<300x5xf32> {
%0 = "quant.stats"(%arg0) {layerStats = dense<[-6.123e+00, 3.45e+00]> : tensor<2xf32>} : (tensor<300x3xf32>) -> tensor<300x3xf32>
%cst = constant {name = "constant.35"} dense<[[-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32>
%1 = "fxpmath.real_matmul"(%0, %cst) {clamp_max = 6.10, clamp_min = -12.25} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32>
return %1 : tensor<300x5xf32>
}
// ----
// A matmul with bias and clamp.
// CHECK-LABEL: @matmul_add_clamp
// CHECK: %cst = constant dense{{.*}}tensor<3x5xi8>
// CHECK-NEXT: %cst_0 = constant dense<[14, 28, 42, 56, 69]> : tensor<5xi32>
// CHECK-NEXT: %0 = "quant.qcast"(%arg0) : (tensor<300x3xf32>) -> tensor<300x3x!quant.uniform<i8:f32, 0.037564418067230126:35>>
// CHECK-NEXT: %1 = "quant.scast"(%cst) : (tensor<3x5xi8>) -> tensor<3x5x!quant.uniform<i8:f32, 0.0062823070315864236:-1>>
// CHECK-NEXT: %2 = "quant.scast"(%cst_0) : (tensor<5xi32>) -> tensor<5x!quant.uniform<i32:f32, 0.072058823529412216>>
// CHECK-NEXT: %3 = "fxpmath.real_matmul_bias"(%0, %1, %2) {clamp_max = 6.100000e+00 : f64, clamp_min = -1.225000e+01 : f64} : (tensor<300x3x!quant.uniform<i8:f32, 0.037564418067230126:35>>, tensor<3x5x!quant.uniform<i8:f32, 0.0062823070315864236:-1>>, tensor<5x!quant.uniform<i32:f32, 0.072058823529412216>>) -> tensor<300x5x!quant.uniform<i8:f32, 0.072058823529412216:42>>
// CHECK-NEXT: %4 = "quant.dcast"(%3) : (tensor<300x5x!quant.uniform<i8:f32, 0.072058823529412216:42>>) -> tensor<300x5xf32>
func @matmul_add_clamp(%arg0: tensor<300x3xf32>) -> tensor<300x5xf32> {
%0 = "quant.stats"(%arg0) {layerStats = dense<[-6.123e+00, 3.45e+00]> : tensor<2xf32>} : (tensor<300x3xf32>) -> tensor<300x3xf32>
%cst = constant {name = "constant.35"} dense<[[-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32>
%cst_0 = constant {name = "constant.37"} dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<5xf32>
%1 = "fxpmath.real_matmul_bias"(%0, %cst, %cst_0) {clamp_max = 6.10, clamp_min = -12.25} : (tensor<300x3xf32>, tensor<3x5xf32>, tensor<5xf32>) -> tensor<300x5xf32>
return %1 : tensor<300x5xf32>
}


@ -1,15 +0,0 @@
// RUN: mlir-opt %s -quantizer-remove-instrumentation -split-input-file | FileCheck %s
// -----
// CHECK-LABEL: remove_ops
func @remove_ops(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
%0 = "quant.stats"(%arg0) {
layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
%1 = "quant.coupled_ref"(%0) { coupledKey = "foobar" } :
(tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
%2 = "quant.stats_ref"(%1) { statsKey = "foobar" } :
(tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
// CHECK: return %arg0 : tensor<8x4x3xf32>
return %2 : tensor<8x4x3xf32>
}


@ -1,7 +1,6 @@
// RUN: mlir-opt --show-dialects | FileCheck %s
// CHECK: Registered Dialects:
// CHECK: affine
// CHECK: fxpmath
// CHECK: gpu
// CHECK: linalg
// CHECK: llvm


@ -16,9 +16,6 @@ set(LIBS
MLIROptLib
MLIRParser
MLIRPass
MLIRQuantizerFxpMathConfig
MLIRQuantizerSupport
MLIRQuantizerTransforms
MLIRSPIRV
MLIRSPIRVTestPasses
MLIRSPIRVTransforms


@ -1,99 +0,0 @@
//===- RulesTest.cpp - Rules unit tests -----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Quantizer/Support/Rules.h"
#include "llvm/Support/raw_ostream.h"
#include "gtest/gtest.h"
using namespace mlir;
using namespace mlir::quantizer;
namespace {
using TestDiscreteFact = DiscreteFact<int>;
TEST(ExpandingMinMaxReducer, Basic) {
ExpandingMinMaxFact f;
EXPECT_FALSE(f.hasValue());
// First assertion always modifies.
EXPECT_TRUE(modified(f.assertValue(0, {-1.0, 1.0})));
EXPECT_TRUE(f.hasValue());
EXPECT_EQ(std::make_pair(-1.0, 1.0), f.getValue());
EXPECT_EQ(0, f.getSalience());
// Assertion in the same band expands.
EXPECT_TRUE(modified(f.assertValue(0, {-0.5, 2.0})));
EXPECT_TRUE(f.hasValue());
EXPECT_EQ(std::make_pair(-1.0, 2.0), f.getValue());
EXPECT_EQ(0, f.getSalience());
EXPECT_TRUE(modified(f.assertValue(0, {-2.0, 0.5})));
EXPECT_TRUE(f.hasValue());
EXPECT_EQ(std::make_pair(-2.0, 2.0), f.getValue());
EXPECT_EQ(0, f.getSalience());
// Same band smaller bound does not modify.
EXPECT_FALSE(modified(f.assertValue(0, {-0.5, 0.5})));
EXPECT_TRUE(f.hasValue());
EXPECT_EQ(std::make_pair(-2.0, 2.0), f.getValue());
EXPECT_EQ(0, f.getSalience());
// Higher salience overrides.
EXPECT_TRUE(modified(f.assertValue(10, {-0.2, 0.2})));
EXPECT_TRUE(f.hasValue());
EXPECT_EQ(std::make_pair(-0.2, 0.2), f.getValue());
EXPECT_EQ(10, f.getSalience());
// Lower salience no-ops.
EXPECT_FALSE(modified(f.assertValue(5, {-2.0, 2.0})));
EXPECT_TRUE(f.hasValue());
EXPECT_EQ(std::make_pair(-0.2, 0.2), f.getValue());
EXPECT_EQ(10, f.getSalience());
// Merge from a fact without a value no-ops.
ExpandingMinMaxFact f1;
EXPECT_FALSE(modified(f.mergeFrom(f1)));
EXPECT_TRUE(f.hasValue());
EXPECT_EQ(std::make_pair(-0.2, 0.2), f.getValue());
EXPECT_EQ(10, f.getSalience());
// Merge from a fact with a value merges.
EXPECT_TRUE(modified(f1.mergeFrom(f)));
EXPECT_TRUE(f1.hasValue());
EXPECT_EQ(std::make_pair(-0.2, 0.2), f1.getValue());
EXPECT_EQ(10, f1.getSalience());
}
TEST(TestDiscreteFact, Basic) {
TestDiscreteFact f;
EXPECT_FALSE(f.hasValue());
// Initial value.
EXPECT_TRUE(modified(f.assertValue(0, {2})));
EXPECT_FALSE(modified(f.assertValue(0, {2})));
EXPECT_EQ(2, f.getValue().value);
EXPECT_FALSE(f.getValue().conflict);
// Conflicting update.
EXPECT_TRUE(modified(f.assertValue(0, {4})));
EXPECT_EQ(2, f.getValue().value); // Arbitrary but known to be first wins.
EXPECT_TRUE(f.getValue().conflict);
// Further update still conflicts.
EXPECT_FALSE(modified(f.assertValue(0, {4})));
EXPECT_EQ(2, f.getValue().value); // Arbitrary but known to be first wins.
EXPECT_TRUE(f.getValue().conflict);
// Different salience update does not conflict.
EXPECT_TRUE(modified(f.assertValue(1, {6})));
EXPECT_EQ(6, f.getValue().value);
EXPECT_FALSE(f.getValue().conflict);
}
} // end anonymous namespace

@ -1,142 +0,0 @@
//===- UniformSolversTest.cpp - Tests for uniform solvers -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Quantizer/Support/UniformSolvers.h"
#include "gtest/gtest.h"
using namespace mlir;
using namespace mlir::quantizer;
namespace {
const double kEpsilon = 1e-12;
TEST(UniformMathTest, testAsym) {
UniformParamsFromMinMaxSolver s(UniformStorageParams::getQuint8(), -8, 8.123);
ASSERT_TRUE(s.compute());
llvm::errs() << "testBasic Results: " << s << "\n";
EXPECT_EQ(0.0, s.dequantize(s.getZp())); // Exact.
EXPECT_EQ(s.getZp(), s.quantize(0.0));
EXPECT_GE(s.getAdjMax() + kEpsilon, s.getBoundingMax());
EXPECT_LE(s.getAdjMin() - kEpsilon, s.getBoundingMin());
}
TEST(UniformMathTest, testPOT) {
UniformParamsFromMinMaxSolver s(UniformStorageParams::getQuint8(), -8,
7.9375);
ASSERT_TRUE(s.compute());
llvm::errs() << "testPOT Results: " << s << "\n";
// POT ranges should be exact.
EXPECT_EQ(128, s.getZp());
EXPECT_NEAR(6.25e-2, s.getScale(), kEpsilon);
EXPECT_EQ(0.0, s.dequantize(s.getZp())); // Exact.
EXPECT_EQ(s.getZp(), s.quantize(0.0));
EXPECT_GE(s.getAdjMax() + kEpsilon, s.getBoundingMax());
EXPECT_LE(s.getAdjMin() - kEpsilon, s.getBoundingMin());
}
TEST(UniformMathTest, testLopsidedPositive) {
UniformParamsFromMinMaxSolver s(UniformStorageParams::getQuint8(), 1.0, 8.0);
ASSERT_TRUE(s.compute());
llvm::errs() << "testLopsidedPositive Results: " << s << "\n";
EXPECT_EQ(0, s.getZp());
EXPECT_EQ(0.0, s.dequantize(s.getZp())); // Exact.
EXPECT_EQ(0, s.quantize(0.0));
EXPECT_GE(s.getAdjMax() + kEpsilon, s.getBoundingMax());
EXPECT_LE(s.getAdjMin() - kEpsilon, s.getBoundingMin());
}
TEST(UniformMathTest, testLopsidedNegative) {
UniformParamsFromMinMaxSolver s(UniformStorageParams::getQuint8(), -72.0,
-4.0);
ASSERT_TRUE(s.compute());
llvm::errs() << "testLopsidedNegative Results: " << s << "\n";
EXPECT_EQ(255, s.getZp());
EXPECT_EQ(0.0, s.dequantize(s.getZp())); // Exact.
EXPECT_EQ(255, s.quantize(0.0));
EXPECT_GE(s.getAdjMax() + kEpsilon, s.getBoundingMax());
EXPECT_LE(s.getAdjMin() - kEpsilon, s.getBoundingMin());
}
TEST(UniformMathTest, testLargeRange) {
UniformParamsFromMinMaxSolver s(UniformStorageParams::getQuint8(), -123.23389,
231.1289);
ASSERT_TRUE(s.compute());
llvm::errs() << "testLargeRange Results: " << s << "\n";
// EXPECT_EQ(255, s.getZp());
EXPECT_EQ(0.0, s.dequantize(s.getZp())); // Exact.
EXPECT_EQ(s.getZp(), s.quantize(0.0));
EXPECT_GE(s.getAdjMax() + kEpsilon, s.getBoundingMax());
EXPECT_LE(s.getAdjMin() - kEpsilon, s.getBoundingMin());
}
TEST(UniformMathTest, test16BitLargeRange) {
UniformParamsFromMinMaxSolver s(UniformStorageParams::getQuint16(),
-123.23389, 231.1289);
ASSERT_TRUE(s.compute());
llvm::errs() << "test16BitLargeRange Results: " << s << "\n";
EXPECT_EQ(0.0, s.dequantize(s.getZp())); // Exact.
EXPECT_EQ(s.getZp(), s.quantize(0.0));
EXPECT_GE(s.getAdjMax() + kEpsilon, s.getBoundingMax());
EXPECT_LE(s.getAdjMin() - kEpsilon, s.getBoundingMin());
}
TEST(UniformMathTest, testQuint8SymmetricRight) {
UniformParamsFromMinMaxSolver s(
UniformStorageParams::getQuint8SymmetricRight(), -123.23389, 231.1289);
ASSERT_TRUE(s.compute());
llvm::errs() << "testQuint8SymmetricRight Results: " << s << "\n";
EXPECT_EQ(0.0, s.dequantize(s.getZp())); // Exact.
EXPECT_EQ(s.getZp(), s.quantize(0.0));
EXPECT_GE(s.getAdjMax() + kEpsilon, s.getBoundingMax());
EXPECT_LE(s.getAdjMin() - kEpsilon, s.getBoundingMin());
}
TEST(UniformMathTest, testQuint4) {
UniformParamsFromMinMaxSolver s({15, 0}, -1.0, 1.0);
ASSERT_TRUE(s.compute());
llvm::errs() << "testQuint4 Results: " << s << "\n";
EXPECT_EQ(0.0, s.dequantize(s.getZp())); // Exact.
EXPECT_EQ(s.getZp(), s.quantize(0.0));
EXPECT_GE(s.getAdjMax() + kEpsilon, s.getBoundingMax());
EXPECT_LE(s.getAdjMin() - kEpsilon, s.getBoundingMin());
}
TEST(UniformMathTest, testNan) {
UniformParamsFromMinMaxSolver s({0, 0}, -1.0, 1.0);
ASSERT_FALSE(s.compute());
}
TEST(UniformMathTest, testBadBounds) {
UniformParamsFromMinMaxSolver s(UniformStorageParams::getQuint16(), 123.23389,
-231.1289);
ASSERT_FALSE(s.compute());
}
TEST(UniformMathTest, testZeroBounds) {
UniformParamsFromMinMaxSolver s(UniformStorageParams::getQuint16(), 0, 0);
ASSERT_FALSE(s.compute());
}
} // end namespace