diff --git a/mlir/g3doc/Quantization.md b/mlir/g3doc/Quantization.md new file mode 100644 index 000000000000..ad1775c434f6 --- /dev/null +++ b/mlir/g3doc/Quantization.md @@ -0,0 +1,330 @@ +# MLIR Quantization + +This document outlines the design of the MLIR quantization system. While the +term "quantization" is highly overloaded, in this case, it refers to a fairly +narrow scope of techniques in use to enable conversion of floating-point +computations to corresponding and plausible variants expressed in integer math +for inference, as has historically been supported by low-bit depth inference +engines such as TFLite, various accelerator hardware, and many DSPs. + +Much of this is inspired by the approach taken +[in this paper](https://arxiv.org/abs/1712.05877) with many extensions and +adaptations folded in. It specifically documents the positions that MLIR has +taken on the topic, and is not a general reference. + +[TOC] + +## Uniform quantization + +The primary quantization mechanism supported by MLIR is a scheme which can +express fixed point and affine transformations via uniformly spaced points on the +Real number line. + +Further, the scheme can be applied: + +* *per-layer* : Applying to every value within the target type. +* *per-axis* (also called *per-channel*) : Applying individually to each index + along a specific axis of a tensor type. + +### Fixed point values {#fixed-point} + +[Fixed point](https://en.wikipedia.org/wiki/Fixed-point_arithmetic) values are a +[Real](https://en.wikipedia.org/wiki/Real_number) number divided by a *scale*. +We will call the result of the divided Real the *scaled value*. + +$$ real\_value = scaled\_value * scale $$ + +The scale can be interpreted as the distance, in Real units, between neighboring +scaled values. For example, if the scale is $$ \pi $$, then fixed point values +with this scale can only represent multiples of $$ \pi $$, and nothing in +between. 
The maximum rounding error to convert an arbitrary Real to a fixed +point value with a given $$ scale $$ is $$ \frac{scale}{2} $$. Continuing the +previous example, when $$ scale = \pi $$, the maximum rounding error will be $$ +\frac{\pi}{2} $$. + +Multiplication can be performed on scaled values with different scales, using +the same algorithm as multiplication of Real values (note that the product scaled +value has $$ scale_{product} = scale_{left \mbox{ } operand} * scale_{right +\mbox{ } operand} $$). Addition can be performed on scaled values, as long as +they have the same scale, using the same algorithm as addition of Real values. +This makes it convenient to represent scaled values on a computer as signed +integers, and perform arithmetic on those signed integers, because the results +will be correct scaled values. + +### Affine values {#affine} + +Mathematically speaking, affine values are the result of +[adding a Real-valued *zero point* to a scaled value](https://en.wikipedia.org/wiki/Affine_transformation#Representation). +Or equivalently, subtracting a zero point from an affine value results in a +scaled value: + +$$ real\_value = scaled\_value * scale = (affine\_value - zero\_point) * scale $$ + +Essentially, affine values are a shifting of the scaled values by some constant +amount. Arithmetic (i.e., addition, subtraction, multiplication, division) +cannot, in general, be directly performed on affine values; you must first +[convert](#affine-to-fixed-point) them to the equivalent scaled values. + +As alluded to above, the motivation for using affine values is to more +efficiently represent the Real values that will actually be encountered during +computation. Frequently, the Real values that will be encountered are not +symmetric around the Real zero. We also make the assumption that the Real zero +is encountered during computation, and should thus be represented. 
+ +In this case, it's inefficient to store scaled values represented by signed +integers, as some of the signed integers will never be used. The bit patterns +corresponding to those signed integers are going to waste. + +In order to exactly represent the Real zero with an integral-valued affine +value, the zero point must be an integer between the minimum and maximum affine +value (inclusive). For example, given an affine value represented by an 8 bit +unsigned integer, we have: $$ 0 \leq zero\_point \leq 255$$. This is important, +because in the convolution-like operations of deep neural networks, we frequently +need to zero-pad inputs and outputs, so zero must be exactly representable, or +the result will be biased. + +### Relation + +Real values, fixed point values, and affine values relate through the following +equation, which demonstrates how to convert one type of number to another: + +$$ real\_value = scaled\_value * scale = (affine\_value - zero\_point) * scale $$ + +Note that computers generally store mathematical values using a finite number of +bits. Thus, while the above conversions are exact, to store the result in a +finite number of bits, we must, in general, round the result of the conversion +(this applies to both cases: storing using floating point and storing using +fixed point). Note that a full discussion of rounding behavior is outside the +scope of this document, and it is safe to assume unless otherwise stated that +rounding should be according to the IEEE754 default of RNE (where hardware +permits). + +### Converting between Real and fixed point or affine {#converting-between} + +To convert a Real value to a fixed point value, you must know the scale. To +convert a Real value to an affine value, you must know the scale and zero point. 
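As a rough illustration of the relation `real_value = (affine_value - zero_point) * scale`, the sketch below converts between a Single-precision Real and an 8-bit unsigned affine value. The helper names `realToAffine` and `affineToReal` are hypothetical and are not part of MLIR; the sketch assumes uint8 storage using the full 0..255 range, the default floating point rounding mode (round to nearest, ties to even), and a quotient that fits in a signed 32-bit integer.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Hypothetical helper (not an MLIR API): Real -> uint8 affine value.
// Assumes realValue / scale fits in a signed 32-bit integer.
inline std::uint8_t realToAffine(float realValue, float scale,
                                 std::uint8_t zeroPoint) {
  // std::nearbyint honors the current rounding mode, which defaults to
  // round-to-nearest-even (RNE).
  const std::int32_t scaled =
      static_cast<std::int32_t>(std::nearbyint(realValue / scale));
  const std::int32_t shifted = scaled + static_cast<std::int32_t>(zeroPoint);
  // Clamp to the representable range of the storage type.
  const std::int32_t clamped =
      std::min<std::int32_t>(255, std::max<std::int32_t>(0, shifted));
  return static_cast<std::uint8_t>(clamped);
}

// Hypothetical helper (not an MLIR API): uint8 affine value -> Real.
inline float affineToReal(std::uint8_t affineValue, float scale,
                          std::uint8_t zeroPoint) {
  // The subtraction is carried out in signed 32-bit arithmetic so that the
  // intermediate scaled value may be negative.
  const std::int32_t scaled = static_cast<std::int32_t>(affineValue) -
                              static_cast<std::int32_t>(zeroPoint);
  return static_cast<float>(scaled) * scale;
}
```

With `scale = 0.5` and `zero_point = 128`, the Real zero maps exactly to the affine value 128, and Real values outside the representable range saturate at 0 or 255 rather than wrapping.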
+ +#### Real to affine + +To convert an input tensor of Real-valued elements (usually represented by a +floating point format, frequently +[Single precision](https://en.wikipedia.org/wiki/Single-precision_floating-point_format)) +to a tensor of affine elements represented by an integral type (e.g. 8-bit +unsigned integer), the following conversion can be performed (note that it is +not required that all representable values of the integral type are used): + +$$ +\begin{align*} +af&fine\_value_{uint8 \, or \, uint16} \\ + &= clampToTargetSize(roundToNearestInteger( \frac{real\_value_{Single}}{scale_{Single}})_{sint32} + zero\_point_{uint8 \, or \, uint16}) +\end{align*} +$$ + +In the above, we assume that $$real\_value$$ is a Single, $$scale$$ is a Single, +$$roundToNearestInteger$$ returns a signed 32 bit integer, and $$zero\_point$$ +is an unsigned 8 or 16 bit integer. Note that bit depth and number of fixed +point values is indicative of common types on typical hardware but is not +constrained to particular bit depths or a requirement that the entire range of +an N-bit integer is used. + +#### Affine to Real {#affine-to-real} + +To convert an output tensor of affine elements represented by uint8 +or uint16 to a tensor of Real-valued elements (usually represented with a +floating point format, frequently Single precision), the following conversion +can be performed: + +$$ +\begin{align*} +re&al\_value_{Single} \\ + &= roundToNearestFloat((affine\_value_{uint8 \, or \, uint16} - zero\_point_{uint8 \, or \, uint16})_{sint32})_{Single} * scale_{Single} +\end{align*} +$$ + +In the above, we assume that the result of subtraction is in 32-bit signed +integer format, and that $$roundToNearestFloat$$ returns a Single. + +#### Affine to fixed point {#affine-to-fixed-point} + +When the affine and fixed point scales are the same, subtract the zero point +from the affine value to get the equivalent fixed point value. 
+ +$$ +scaled\_value = affine\_value_{non\mbox{-}negative} - zero\_point_{non\mbox{-}negative} +$$ + +#### Fixed point to affine {#fixed-point-to-affine} + +When the affine and fixed point scales are the same, add the zero point to the +fixed point value to get the equivalent affine value. + +$$ +affine\_value_{non\mbox{-}negative} = scaled\_value + zero\_point_{non\mbox{-}negative} +$$ + +## Usage within MLIR {#usage-within-mlir} + +There are several components to the quantization system within MLIR: + +* *Quantization* dialect containing: + + * A family of [QuantizedTypes](#quantized-type) which represent the + mapping between *expressed* values (typically of a floating point + computer type) and *storage* values (typically of an integral computer + type). + * [Type conversion ops](#quantized-type-conversion-ops) for converting + between types based on a QuantizedType and its *expressed* and *storage* + sub-types. + * [Instrumentation ops](#instrumentation-ops) for assigning + instrumentation points within the computation where runtime statistics + may help guide the quantization process. + +* *QuantizedMath* dialect containing: + + * [Real math ops](#real-math-ops) representing common combinations of + arithmetic operations that closely match corresponding fixed-point math + concepts (as opposed to being spread across multiple ops as is typical + in source dialects). + * [Fixed-point math ops](#fixed-point-math-ops) for carrying out + computations on integers, as are typically needed by uniform + quantization schemes. + * Passes to lower from real math ops to fixed-point math ops. + +* [Solver tools](#solver-tools) which can generically operate on computations + expressed in the *QuantizedMath* dialect in order to convert from floating + point types to appropriate *QuantizedTypes*, allowing the computation to be + further lowered to integral math ops. + +Not every application of quantization will use all facilities. 
Specifically, the +TensorFlow to TensorFlow Lite conversion uses the QuantizedTypes but has its own +ops for type conversion and expression of the backing math. + +## Interactions with simulated quantization at training time {#training-time} + +TensorFlow has historically used the +[tf.quantization.fake_quant_\*](https://www.tensorflow.org/api_docs/python/tf/quantization/fake_quant_with_min_max_args) +family of operations to simulate the effect of quantization at training time. + +As originally implemented, TensorFlow Lite was the primary user of such +operations at inference time. When quantized inference was enabled, if every +eligible tensor passed through an appropriate fake_quant node (the rules for +which tensors can have fake_quant applied are somewhat involved), then +TensorFlow Lite would use the attributes of the fake_quant ops to make a +judgment about how to convert to use kernels from its quantized ops subset. + +In MLIR-based quantization, fake_quant_\* ops are handled by converting them to +a sequence of *qcast* (quantize) followed by *dcast* (dequantize) with an +appropriate *UniformQuantizedType* as the target of the qcast operation. + +This allows subsequent compiler passes to preserve the knowledge that +quantization was simulated in a certain way while giving the compiler +flexibility to move the casts as it simplifies the computation and converts +it to a form based on integral arithmetic. + +This scheme also naturally allows computations that are *partially quantized* +where the parts which could not be reduced to integral ops are still carried out +in floating point with appropriate conversions at the boundaries. + +## Quantization Dialect + +### Quantized type {#quantized-type} + +TODO : Flesh this section out. 
+ +* QuantizedType base class +* UniformQuantizedType + +### Quantized type conversion ops {#quantized-type-conversion-ops} + +* qcast : Convert from an expressed type to QuantizedType +* dcast : Convert from a QuantizedType to its expressed type +* scast : Convert between a QuantizedType and its storage type + +### Instrumentation and constraint ops {#instrumentation-ops} + +TODO : These ops are not defined yet + +* instrument_stats : Assigns a unique id and signals that statistics should be + collected by the runtime when execution passes through this op. +* constrain_uniform : Constrains that for uniform quantization, the solver + should choose a type with certain characteristics such as the number of + fixed-point values, underlying storage type, or whether to constrain to + power of two scales. + +## QuantizedMath Dialect + +### Real math ops {#real-math-ops} + +Note that these all support explicit clamps, which allows for simple fusions and +representation of some common sequences of quantization-compatible math. In +addition, some support explicit biases, which are often represented as separate +adds in source dialects. + +TODO: This op set is still evolving and needs to be completed. + +* RealBinaryOp + * RealAddEwOp + * RealSubEwOp + * RealMulEwOp + * RealDivEwOp +* RealUnaryOp + * IDENTITY + * TANH + * SIGMOID + * EXP + * LOG + * NEG + * RSQRT + * SIN + * SQUARE + * SQRT + * CMPZ + * CMPNZ + * CMPLZ + * CMPGZ + +### Fixed-point math ops {#fixed-point-math-ops} + +TODO: This op set only has enough ops to lower a simple power-of-two +RealAddEwOp. + +* RoundingDivideByPotFxpOp +* SaturatingAddFxpOp + +## Solver tools {#solver-tools} + +Solver tools exist to analyze an MLIR computation, expressed in either a +supported source dialect or in the *real math ops* set, and solve for appropriate +QuantizedTypes that allow the computation to be lowered to integral math. 
+ +These tools are an active area of work and may be expanded in the future to +adjacent areas such as solving for transformations to other kinds of lower +precision types (e.g. bfloat16 or fp16). + +Solver tools are expected to operate in several modes, depending on the +computation and the manner in which it was trained: + +* *Transform* : With all available information in the MLIR computation, infer + boundaries where the computation can be carried out with integral math and + change types accordingly to appropriate QuantizedTypes: + + * For passthrough ops which do not perform active math, change them to + operate directly on the storage type, converting in and out at the edges + via scast ops. + * For ops that have the *Quantizable* trait, the type can be set directly. + This includes ops from the [real math ops set](#real-math-ops). + * For others, encase them in appropriate dcast/qcast ops, presuming that + some follow-on pass will know what to do with them. + +* *Instrument* : Most of the time, there are not sufficient implied + constraints within a computation to perform many transformations. For this + reason, the solver can insert instrumentation ops at points where additional + runtime statistics may yield solutions. It is expected that such + computations will be lowered as-is for execution, run over an appropriate + eval set, and statistics at each instrumentation point made available for a + future invocation of the solver. + +* *Simplify* : A variety of passes and simplifications are applied once + QuantizedTypes are added in order to arrive at a computation that is + expressed as much as possible in integral math, with as few casts as + possible. 
diff --git a/mlir/include/mlir/Quantization/FakeQuantSupport.h b/mlir/include/mlir/Quantization/FakeQuantSupport.h new file mode 100644 index 000000000000..aa3b7b4dcdab --- /dev/null +++ b/mlir/include/mlir/Quantization/FakeQuantSupport.h @@ -0,0 +1,67 @@ +//===- FakeQuantSupport.h - Support utilities for FakeQuant ops -*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines support utilities for interoperating with FakeQuant* based +// QAT (Quantization Aware Training) computations, as implemented by TFLite. Note +// that FakeQuant* operators mix multiple concerns specific to how TFLite +// originally implemented quantization. As such, utilities here enforce +// opinions taken by that codebase (vs providing any amount of genericity). 
+// +// Specifically, it combines the following concerns, each of which would be +// independent variables in a more generic setup: +// - num_bits implies storage data type (quint8, int16) +// - num_bits < 8 is promoted to quint8 +// - "narrow_range" narrows the lower bound of the storage type's range by +// 1 +// - the specified min/max values are "nudged" so that the result has a zero +// that can be exactly expressed +// - min=max=0 implies scale=0 and zero_point=0 +// +// With the above assumptions applied, every conforming specified FakeQuant op +// can be represented by a UniformQuantizedType. This scheme is not expected to +// be generalized further in the future and should be considered to be a +// legacy set of rules. +// +// As canonically used in TensorFlow graphs, the presence of a FakeQuant node +// is a hint that the specific math represented here has been simulated at +// training time. As such, it is usually not advised to arbitrarily change +// quantization parameters derived from FakeQuant. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_QUANTIZATION_FAKEQUANTSUPPORT_H_ +#define MLIR_QUANTIZATION_FAKEQUANTSUPPORT_H_ + +#include "mlir/Quantization/QuantOps.h" + +namespace mlir { +namespace quant { + +/// Converts per-layer FakeQuant attributes to the corresponding type. +/// In the event that the parameters cannot be converted, returns a nullptr +/// convertible Type and issues an appropriate error. +/// Note that there are multiple variants of a per-layer FakeQuant op, so +/// this function takes the attributes discretely vs taking a reference to the +/// originating op. 
+UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits, + double rmin, double rmax, + bool narrowRange, Type expressedType); + +} // namespace quant +} // namespace mlir + +#endif // MLIR_QUANTIZATION_FAKEQUANTSUPPORT_H_ diff --git a/mlir/include/mlir/Quantization/Passes.h b/mlir/include/mlir/Quantization/Passes.h new file mode 100644 index 000000000000..090d21cb2925 --- /dev/null +++ b/mlir/include/mlir/Quantization/Passes.h @@ -0,0 +1,59 @@ +//===- Passes.h - Quantization Passes ------ --------------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines all of the passes owned by the quantization dialect. As +// things mature, it is expected that passes specific to certain frontend or +// backend dialects will move to those dialects directly. For now, they are +// incubated here. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_QUANTIZATION_PASSES_H +#define MLIR_QUANTIZATION_PASSES_H + +namespace mlir { +class FunctionPassBase; + +namespace quant { + +/// Creates a pass that lowers quantization related TensorFlow ops into +/// the quantization dialect so that express and implied constraints expressed +/// at the TensorFlow source level can be represented to the quantization +/// system. 
This will specially handle any TensorFlow op that is useful for +/// guiding quantization. +/// +/// Note that if your intent is to compile a TensorFlow graph for floating +/// point inference, you should probably not use this pass. +FunctionPassBase *createLowerTFPass(); + +/// Creates a pass that converts constants followed by a qbarrier to a +/// constant whose value is quantized. This is typically one of the last +/// passes done when lowering to express actual quantized arithmetic in a +/// low level representation. Because it modifies the constant, it is +/// destructive and cannot be undone. +FunctionPassBase *createConvertConstPass(); + +/// Creates a pass that lowers uniform-quantized real math ops to integer +/// arithmetic. This will leave unrecognized real math ops as-is and is +/// typically followed by a pass that lowers any unrecognized ops to a pure +/// floating point form. +FunctionPassBase *createLowerUniformRealMathPass(); + +} // namespace quant +} // namespace mlir + +#endif // MLIR_QUANTIZATION_PASSES_H diff --git a/mlir/include/mlir/Quantization/QuantOps.h b/mlir/include/mlir/Quantization/QuantOps.h new file mode 100644 index 000000000000..dceb2d07faeb --- /dev/null +++ b/mlir/include/mlir/Quantization/QuantOps.h @@ -0,0 +1,373 @@ +//===- Quantization/QuantOps.h - Quantization Ops and Types -----*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#ifndef MLIR_QUANTIZATION_QUANTOPS_H_ +#define MLIR_QUANTIZATION_QUANTOPS_H_ + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "llvm/Support/MathExtras.h" + +namespace mlir { +namespace quant { + +class QuantizedIntegerType; + +namespace detail { + +struct QuantizedTypeStorage; +struct UniformQuantizedTypeStorage; +struct UniformQuantizedPerAxisTypeStorage; + +} // namespace detail + +namespace QuantizationTypes { +enum Kind { + UniformQuantized = Type::FIRST_QUANTIZATION_TYPE, + UniformQuantizedPerAxis, + LAST_USED_QUANTIZATION_TYPE = UniformQuantizedPerAxis, +}; +} // namespace QuantizationTypes + +/// Enumeration of bit-mapped flags related to quantized types. +namespace QuantizationFlags { +enum FlagValue { + // Indicates that the storage type should be interpreted as a signed + // integer. The default is to interpret it as an unsigned value. + Signed = 1, +}; +} // namespace QuantizationFlags + +/// Base class for all quantized types known to this dialect. +/// All quantized types have: +/// - storageType: The (narrower) numeric type that is being used to +/// approximate some expressed type. +/// - expressedType: The type that is being approximated. +/// +/// The base class provides generic support for manipulating the types based +/// on these fields. +class QuantizedType : public Type { +public: + using ImplType = detail::QuantizedTypeStorage; + using Type::Type; + + /// The maximum number of bits supported for storage types. 
+ static constexpr unsigned MaxStorageBits = 32; + + static LogicalResult + verifyConstructionInvariants(llvm::Optional<Location> loc, + MLIRContext *context, unsigned flags, + Type storageType, Type expressedType, + int64_t storageTypeMin, int64_t storageTypeMax); + + /// Support method to enable LLVM-style type casting. + /// Matches every quantized type kind defined by this dialect. + static bool kindof(unsigned kind) { + return kind >= Type::FIRST_QUANTIZATION_TYPE && + kind <= QuantizationTypes::LAST_USED_QUANTIZATION_TYPE; + } + + /// Gets the minimum possible value stored by a storageType. storageTypeMin + /// must be greater than or equal to this value. + static int64_t getDefaultMininumForInteger(bool isSigned, + unsigned integralWidth) { + if (isSigned) { + return llvm::minIntN(integralWidth); + } + return 0; + } + + /// Gets the maximum possible value stored by a storageType. storageTypeMax + /// must be less than or equal to this value. + static int64_t getDefaultMaxinumForInteger(bool isSigned, + unsigned integralWidth) { + if (isSigned) { + return llvm::maxIntN(integralWidth); + } + return llvm::maxUIntN(integralWidth); + } + + /// Gets the original expressed type that this quantized type approximates. + /// Note that this presumes that the quantized type was always derived from + /// a floating point type, which in the broadest definition, is not true (i.e. + /// it could be some form of integral, fixed type or affine type in its own + /// right); however, at the high level, no examples of such usage are + /// presently known and the restriction serves some useful purposes (such as + /// always being able to reverse a transformation or measure error). In most + /// cases, this will be f32. + Type getExpressedType() const; + + /// Gets the flags associated with this type. Typically a more specific + /// accessor is appropriate. + unsigned getFlags() const; + + // Convenience helpers. + /// Whether the storage type should be interpreted as a signed quantity + /// (true) or an unsigned value (false). 
+ bool isSigned() const { + return (getFlags() & QuantizationFlags::Signed) == + QuantizationFlags::Signed; + } + + /// Gets the underlying type used to store values. Note that this may + /// be signed or unsigned. Use the isSigned() accessor to differentiate. + Type getStorageType() const; + + /// The minimum value that storageType can take. + int64_t getStorageTypeMin() const; + + /// The maximum value that storageType can take. + int64_t getStorageTypeMax() const; + + /// Gets the integral bit width that the underlying storage type can exactly + /// represent. For integral storage types, this will just be their width. + unsigned getStorageTypeIntegralWidth() const; + + /// Returns whether the candidateExpressedType is a match for this + /// QuantizedType. This will be true if the candidate type is either a + /// primitive type or a container type whose element type equals this + /// QuantizedType's expressed type. + /// Examples of compatible candidateExpressedType: + /// !quant<"uniform[i8:f32]{1.0}"> =~ f32 + /// !quant<"uniform[i8:f32]{1.0}"> =~ tensor<4xf32> + bool isCompatibleExpressedType(Type candidateExpressedType); + + /// Returns the element type as a QuantizedType or nullptr if it is not + /// a quantized type. If the type is primitive, returns that. If it is a + /// container (vector/tensor), return the element type. + /// Examples: + /// !quant<"uniform[i8:f32]{1.0}"> -> !quant<"uniform[i8:f32]{1.0}"> + /// tensor<4x!quant<"uniform[i8:f32]{1.0}">> -> !quant<"uniform[i8:f32]{1.0}"> + static QuantizedType getQuantizedElementType(Type primitiveOrContainerType); + + /// Casts from a type based on the storageType to a corresponding type based + /// on this type (returns nullptr if the cast is not valid). 
+ /// Examples: + /// i8 -> !quant<"uniform[i8:f32]{1.0}"> + /// tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> + /// vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">> + Type castFromStorageType(Type candidateType); + + /// Casts from a type based on a QuantizedType to a corresponding type based + /// on the storageType (returns nullptr if the cast is not valid). + /// This is the inverse of castFromStorageType(). + static Type castToStorageType(Type quantizedType); + + /// Casts from a type based on the expressedType to a corresponding type based + /// on this type (returns nullptr if the cast is not valid). + /// Examples: + /// f32 -> !quant<"uniform[i8:f32]{1.0}"> + /// tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> + /// vector<4xf32> -> vector<4x!quant<"uniform[i8:f32]{1.0}">> + Type castFromExpressedType(Type candidateType); + + /// Casts from a type based on QuantizedType to a corresponding type based + /// on the expressedType (returns nullptr if the cast is not valid). + /// This is the inverse of castFromExpressedType. + static Type castToExpressedType(Type quantizedType); + + /// Casts from a type based on the expressedType to the equivalent type + /// based on storageType by way of this QuantizedType. Equivalent to: + /// QuantizedType::castToStorageType(castFromExpressedType(candidateType)) + /// (but with validity checks). + /// Example (for this = !quant<"uniform[i8:f32]{1.0}">): + /// tensor<4xf32> -> tensor<4xi8> + Type castExpressedToStorageType(Type candidateType); +}; + +/// Represents a family of uniform, quantized types. +/// +/// Each instance of this type expresses a mapping between real values (most +/// often expressed in floating point f32) and quantized values (either fixed +/// point or affine). 
+/// +/// The relationship is: +/// real_value = scale * (quantized_value - zero_point) +/// +/// It is used as part of high level graph transformations that have the goal +/// of re-expressing parts of a computation in terms of this common form for +/// more efficient execution at runtime. In addition, it is designed to be +/// expressive enough to facilitate lowering to precise types and operations +/// in target hardware. +/// +/// As a high-level type, focused on intermediate passes, this type holds +/// opinions consistent with high-level usage. If lowering math kernels below +/// the high level arithmetic ops (i.e. to LLVM IR or hardware specific +/// instruction sets), it is expected that the information expressed here +/// will be used to drive low level codegen and target specific type selection, +/// but this type will likely be erased in the process. +/// +/// Syntax synopsis: +/// Per-layer, all parameters expressed: +/// !quant +/// Per-layer, optional parameters omitted: +/// !quant +/// +/// StorageType: 'i'|'u' NumBits +/// ExpressedType: 'f16', 'f32', 'bf16', 'f64' +/// Scale: A legal double value +/// ZeroPoint: An integer value +class UniformQuantizedType + : public Type::TypeBase<UniformQuantizedType, QuantizedType, + detail::UniformQuantizedTypeStorage> { +public: + using Base::Base; + + /// Gets an instance of the type with all parameters specified but not + /// checked. + static UniformQuantizedType get(unsigned flags, Type storageType, + Type expressedType, double scale, + int64_t zeroPoint, int64_t storageTypeMin, + int64_t storageTypeMax); + + /// Gets an instance of the type with all specified parameters checked. + /// Returns a nullptr convertible type on failure. + static UniformQuantizedType + getChecked(unsigned flags, Type storageType, Type expressedType, double scale, + int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax, + Location location); + + /// Verifies construction invariants and issues errors/warnings. 
+ static LogicalResult verifyConstructionInvariants( + llvm::Optional<Location> loc, MLIRContext *context, unsigned flags, + Type storageType, Type expressedType, double scale, int64_t zeroPoint, + int64_t storageTypeMin, int64_t storageTypeMax); + + /// Support method to enable LLVM-style type casting. + static bool kindof(unsigned kind) { + return kind == QuantizationTypes::UniformQuantized; + } + + /// Gets the scale term. The scale designates the difference between the real + /// values corresponding to consecutive quantized values differing by 1. + double getScale() const; + + /// Gets the storage value corresponding to the real value 0 in the affine + /// equation. + int64_t getZeroPoint() const; + + // Fixed point values are real numbers divided by a scale. + // Currently, only signed storage types are treated as fixed point. + // A fixed point value can be obtained from an affine value by subtracting + // the zeroPoint. + // In the future, this may be explicit versus implied by type and zeroPoint. + bool isFixedPoint() const { return isSigned() && getZeroPoint() == 0; } +}; + +/// Represents per-axis (also known as per-channel) quantization. +/// +/// Syntax synopsis: +/// Per-axis, all parameters expressed: +/// !quant +/// Per-axis, optional parameters omitted: +/// !quant +/// +/// StorageType: 'i'|'u' NumBits +/// ExpressedType: 'f16', 'f32', 'bf16', 'f64' +/// QuantizedDim: An integer value +/// QuantParams: (Scale ':' ZeroPoint)+ +/// Scale: A legal double value +/// ZeroPoint: An integer value +class UniformQuantizedPerAxisType + : public Type::TypeBase<UniformQuantizedPerAxisType, QuantizedType, + detail::UniformQuantizedPerAxisTypeStorage> { +public: + using Base::Base; + + /// Gets an instance of the type with all parameters specified but not + /// checked. + static UniformQuantizedPerAxisType + get(unsigned flags, Type storageType, Type expressedType, + ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints, + int32_t quantizedDimension, int64_t storageTypeMin, + int64_t storageTypeMax); + + /// Gets an instance of the type with all specified parameters checked. 
+ /// Returns a nullptr convertible type on failure. + static UniformQuantizedPerAxisType + getChecked(unsigned flags, Type storageType, Type expressedType, + ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints, + int32_t quantizedDimension, int64_t storageTypeMin, + int64_t storageTypeMax, Location location); + + /// Verifies construction invariants and issues errors/warnings. + static LogicalResult verifyConstructionInvariants( + llvm::Optional<Location> loc, MLIRContext *context, unsigned flags, + Type storageType, Type expressedType, ArrayRef<double> scales, + ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension, + int64_t storageTypeMin, int64_t storageTypeMax); + + /// Support method to enable LLVM-style type casting. + static bool kindof(unsigned kind) { + return kind == QuantizationTypes::UniformQuantizedPerAxis; + } + + /// Gets the quantization scales. The scales designate the difference between + /// the real values corresponding to consecutive quantized values differing + /// by 1. The ith scale corresponds to the ith slice in the + /// quantized_dimension. + ArrayRef<double> getScales() const; + + /// Gets the storage values corresponding to the real value 0 in the affine + /// equation. The ith zero point corresponds to the ith slice in the + /// quantized_dimension. + ArrayRef<int64_t> getZeroPoints() const; + + /// Specifies the dimension of the Tensor's shape that the scales and + /// zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1] + /// with quantization params: + /// scales=[1.0, 2.0, 3.0], zeroPoints=[1, 2, 3], quantizedDimension=1 + /// will be quantized across the second dimension of t. + /// t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1 + /// t[:, 1, :, :] will have scale[1]=2.0, zero_point[1]=2 + /// t[:, 2, :, :] will have scale[2]=3.0, zero_point[2]=3 + int32_t getQuantizedDimension() const; + + /// Fixed point values are real numbers divided by a scale. + /// Currently, only signed storage types are treated as fixed point. 
+ /// A fixed point value can be obtained from an affine value by subtracting + /// the zeroPoint. + /// In the future, this may be explicit versus implied by type and zeroPoint. + bool isFixedPoint() const { + if (!isSigned()) + return false; + return llvm::all_of(getZeroPoints(), + [](int64_t zeroPoint) { return zeroPoint == 0; }); + } +}; + +/// Defines the 'Quantization' dialect + class QuantizationDialect : public Dialect { +public: + QuantizationDialect(MLIRContext *context); + + /// Parse a type registered to this dialect. + Type parseType(StringRef spec, Location loc) const override; + + /// Print a type registered to this dialect. + void printType(Type type, raw_ostream &os) const override; +}; + +#define GET_OP_CLASSES +#include "mlir/Quantization/QuantOps.h.inc" + +} // namespace quant +} // namespace mlir + +#endif // MLIR_QUANTIZATION_QUANTOPS_H_ diff --git a/mlir/include/mlir/Quantization/QuantOps.td b/mlir/include/mlir/Quantization/QuantOps.td new file mode 100644 index 000000000000..8c247a3c5911 --- /dev/null +++ b/mlir/include/mlir/Quantization/QuantOps.td @@ -0,0 +1,285 @@ +//===- QuantOps.td - Quantization operation definition -----*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This is the operation definition file for Quantization. 
+// +//===----------------------------------------------------------------------===// + +#ifdef QUANTIZATION_OPS +#else + +#ifdef OP_BASE +#else +include "mlir/IR/OpBase.td" +#endif // OP_BASE + +//===----------------------------------------------------------------------===// +// Quantization type definitions +//===----------------------------------------------------------------------===// + +class quant_TypedPrimitiveOrContainer : + Type.predicate, + TypedVector.predicate]>, + "primitive/tensor/vector of " # etype.description>; + +// An implementation of QuantizedType. +def quant_QuantizedType : + Type()">, "QuantizedType">; + +// A primitive type that can represent a real value. This is either a +// floating point value or a quantized type. +def quant_RealPrimitiveType : + Type, + "real valued primitive (float or quantized type)">; + +// A primitive type that can represent a storage value. This is either an +// integer or quantized type. +def quant_StoragePrimitiveType : + Type, + "quantized storage primitive (integer or quantized type)">; + +// A primitive or container of RealPrimitiveType. +def quant_RealValueType : + quant_TypedPrimitiveOrContainer; + +// A primitive or container of StoragePrimitiveType. +def quant_StorageValueType : + quant_TypedPrimitiveOrContainer; + +// Either a real valued or storage primitive or container type. +def quant_RealOrStorageValueType : + Type>; + +// An implementation of UniformQuantizedType. +def quant_UniformQuantizedType : + Type()">, "UniformQuantizedType">; + +// Predicate for detecting a container or primitive of UniformQuantizedType. +def quant_UniformQuantizedValueType : + quant_TypedPrimitiveOrContainer; + +//===----------------------------------------------------------------------===// +// Attributes +//===----------------------------------------------------------------------===// + +// Real value for an (inclusive) min/max clamp limit. 
+def quant_ClampValueAttr : OptionalAttr; + +// Element-wise activation function to apply. +// Note that RELU activations are not here: they are expressed as clamps. +def quant_EwUnaryFnAttr : + StringBasedAttr, "element-wise unary function"> { + let returnType = [{ StringRef }]; + let defaultValue = "IDENTITY"; +} + +class quant_ConstEwUnaryFn : ConstantAttr; +def quant_EwUnaryFn_Identity: quant_ConstEwUnaryFn<"IDENTITY">; +def quant_EwUnaryFn_Tanh : quant_ConstEwUnaryFn<"TANH">; +def quant_EwUnaryFn_Sigmoid : quant_ConstEwUnaryFn<"SIGMOID">; +def quant_EwUnaryFn_Exp : quant_ConstEwUnaryFn<"EXP">; +def quant_EwUnaryFn_Log : quant_ConstEwUnaryFn<"LOG">; +def quant_EwUnaryFn_Neg : quant_ConstEwUnaryFn<"NEG">; +def quant_EwUnaryFn_Rsqrt : quant_ConstEwUnaryFn<"RSQRT">; +def quant_EwUnaryFn_Sin : quant_ConstEwUnaryFn<"SIN">; +def quant_EwUnaryFn_Square : quant_ConstEwUnaryFn<"SQUARE">; +def quant_EwUnaryFn_Sqrt : quant_ConstEwUnaryFn<"SQRT">; +def quant_EwUnaryFn_CmpZ : quant_ConstEwUnaryFn<"CMPZ">; +def quant_EwUnaryFn_CmpNZ : quant_ConstEwUnaryFn<"CMPNZ">; +def quant_EwUnaryFn_CmpLZ : quant_ConstEwUnaryFn<"CMPLZ">; +def quant_EwUnaryFn_CmpGZ : quant_ConstEwUnaryFn<"CMPGZ">; + +//===----------------------------------------------------------------------===// +// Base classes +//===----------------------------------------------------------------------===// + +class quant_Op traits> : + Op; + +//===----------------------------------------------------------------------===// +// Quantization barriers +//===----------------------------------------------------------------------===// +class quant_BarrierOp traits> : + quant_Op, Arguments<(ins quant_RealValueType:$arg)>, + Results<(outs quant_RealValueType)>; + +// A QuantizeBarrier (qbarrier) represents a potential type shift from a +// quantizable type to a quantized type. +// +// At runtime, a qbarrier will apply the transformation expressed by its +// operand and result type. 
For flexibility during transformation, it is also +// possible to have a qbarrier that performs no transformation (both its +// operand and result type are quantizable). +// +// A qbarrier will typically originate from either: +// a) An expressed or implied constraint in the source dialect which signals +// that a certain level of quantization is possible or required. +// b) An inference made by a quantization algorithm indicating that a +// quantized representation may be acceptable. +// +// Especially early in transformation, it is common to have pairs of +// qbarrier/dbarrier at points where a transition to a quantized type is +// required. It is also common to have an identity qbarrier +// (where the operand and result type are not quantized) at all points where +// it is legal to use a quantized representation (but where it is not yet +// known to be acceptable). +def quant_QuantizeBarrierOp : quant_BarrierOp<"qbarrier", [NoSideEffect]>; + +// A DequantizeBarrier (dbarrier) represents the inverse of a qbarrier, +// converting back from a quantized to a quantizable (expressed) type. +// +// Like qbarriers, a dbarrier is allowed to have both its operand and result +// as non-quantized types. This facilitates transformations and marks edges +// where the computation must be carried out in the expressed type. +// +// Especially early in transformation, it is common to have dbarriers on +// all operands to ops that must operate with the expressed type (typically +// math ops prior to lowering to target-specific, quantized kernels). +def quant_DequantizeBarrierOp : quant_BarrierOp<"dbarrier", [NoSideEffect]>; + +// A StorageCast (scast) represents a cast, in either direction, between a +// type based on the storage type and a type based on the corresponding +// quantized type. +// +// This op exists to ensure type coherency between parts of the computation +// which are operating directly on an underlying storage type and those which +// operate on quantized values. 
+// +// Examples from storage to quantized type: +// i8 -> !quant<"uniform[i8:f32]{1.0}"> +// tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> +// vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">> +def quant_StorageCastOp : + quant_Op<"scast", [NoSideEffect]>, + Arguments<(ins quant_RealOrStorageValueType:$arg)>, + Results<(outs quant_RealOrStorageValueType)>; + +//===----------------------------------------------------------------------===// +// Integral arithmetic ops used by kernels. +//===----------------------------------------------------------------------===// + +def quant_RoundingDivideByPotIOp : + quant_Op<"rounding_divide_by_poti", [NoSideEffect, SameValueType]>, + Arguments<(ins quant_StorageValueType:$x, I32Attr:$exponent)>, + Results<(outs quant_StorageValueType:$y)> { + let description = [{ + Computes integer division by a power-of-two, correctly rounded-to-nearest. + Also known as a rounding arithmetic right shift. See + gemmlowp::RoundingDivideByPOT for a reference implementation. + }]; + + let verifier = [{ + auto verifyExponent = exponent().getSExtValue(); + if (verifyExponent < 0 || verifyExponent > 31) { + return emitOpError("exponent must be in range [0..31]"); + } + return success(); + }]; +} + +def quant_SaturatingAddIOp : + quant_Op<"saturating_addi", [NoSideEffect, SameValueType]>, + Arguments<(ins quant_StorageValueType:$x, + quant_StorageValueType:$y, + I32Attr:$clamp_min, + I32Attr:$clamp_max)>, + Results<(outs quant_StorageValueType:$sum)> { + let description = [{ + Computes saturating addition of two operands, saturating to the given min + and max value. The implementation is responsible for choosing an + intermediate register size appropriate to carry out the operation without + overflow. See gemmlowp::SaturatingAdd for a reference implementation. + }]; +} + +//===----------------------------------------------------------------------===// +// Real math ops. 
+// +// Math ops on real numbers which may have a representation in quantized +// arithmetic. It is expected that eligible ops are lowered from a source +// dialect to this set of ops prior to the process of converting a computation +// to a quantized form. It is a non-goal of these ops to preserve enough +// information to convert back to the higher level, source dialect. +// +// These ops support either real/floating point or QuantizedTypes as operands +// and results. Since not all transformations are supported (globally or +// sometimes for specific targets), a computation may end up with +// untransformable RealMathOps, in which case they need to be lowered as is +// (using floating point math). +// +// This op set takes advantage of the fact that it is typically trivial to +// combine a math function with a compatible bias addition and real-valued +// clamp (which can be done at a higher accumulation bit depth). +// +// In addition, all element-wise unary functions are collapsed into a single +// quant_RealUnaryEwOp and selected via an enum-like attribute. Especially at +// low bit depths, this makes matching simpler and allows the construction of +// generic LUT-based implementations. It also allows specific lowering rules +// to consolidate runs of chained unary ops and fuse them to preceding math +// ops, potentially allowing them to operate directly on higher precision +// intermediates without resorting to lots of custom kernels for common +// formulas that can suffer from insufficient precision at low bit depths. +// +// Comparison operators are modeled as element-wise unary functions (i.e. +// CMPZ, CMPNZ, CMPLZ, CMPGZ) intended to follow a sub and output a 1-bit +// quantized value. It is expected that lowering rules can fuse them with +// the preceding sub. 
+//===----------------------------------------------------------------------===// + +class quant_RealMathOp traits = [], dag args> : + quant_Op, + Arguments; + +//===----------------------------------------------------------------------===// +// Element wise binary real math ops. +//===----------------------------------------------------------------------===// + +class quant_RealBinaryOp traits = []> : + quant_RealMathOp, + Results<(outs quant_RealValueType:$r)>; + +class quant_RealBinaryBiasOp traits = []> : + quant_RealMathOp, + Results<(outs quant_RealValueType:$r)>; + +def quant_RealAddEwOp : + quant_RealBinaryOp<"real_add_ew", [NoSideEffect]>; + +def quant_RealSubEwOp : + quant_RealBinaryOp<"real_sub_ew", [NoSideEffect]>; + +def quant_RealMulEwOp : + quant_RealBinaryOp<"real_mul_ew", [NoSideEffect]>; + +def quant_RealDivEwOp : + quant_RealBinaryOp<"real_div_ew", [NoSideEffect]>; + +//===----------------------------------------------------------------------===// +// Element wise unary real math op. +//===----------------------------------------------------------------------===// + +def quant_RealUnaryEwOp : + quant_RealMathOp<"real_unary_ew", [NoSideEffect], + (ins quant_RealValueType:$x, quant_EwUnaryFnAttr:$fn)>, + Results<(outs quant_RealValueType:$r)>; + +#endif // QUANTIZATION_OPS diff --git a/mlir/include/mlir/Quantization/QuantizeUtils.h b/mlir/include/mlir/Quantization/QuantizeUtils.h new file mode 100644 index 000000000000..0e4d04ab9f1b --- /dev/null +++ b/mlir/include/mlir/Quantization/QuantizeUtils.h @@ -0,0 +1,70 @@ +//===- QuantizeUtils.h - Support utilities for quantization -----*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef MLIR_QUANTIZATION_QUANTIZEUTILS_H_ +#define MLIR_QUANTIZATION_QUANTIZEUTILS_H_ + +namespace mlir { +class Attribute; +class Type; + +namespace quant { +class QuantizedType; +class UniformQuantizedType; +class UniformQuantizedValueConverter; + +/// Converts an attribute from a type based on +/// quantizedElementType.getExpressedType() to one based on +/// quantizedElementType.getStorageType(), where quantizedElementType is as from +/// QuantizedType::getQuantizedElementType(). +/// Returns nullptr if the conversion is not supported. On success, stores the +/// converted type in outConvertedType. +/// +/// Examples: +/// 1. realValue is a primitive value attribute: +/// (realValue: FloatAttr, quantizedElementType: UniformQuantizedType[i8:f32]) +/// -> (IntegerAttr, outConvertedType: i8) +/// 2. realValue is an elements attribute: +/// (realValue: DenseElementsAttr[tensor<2x2xf32>], +/// quantizedElementType: UniformQuantizedType[i8:f32]) +/// -> (DenseElementsAttr[tensor<2x2xi8>], outConvertedType: tensor<2x2xi8>) +Attribute quantizeAttr(Attribute realValue, QuantizedType quantizedElementType, + Type &outConvertedType); + +/// Converts an attribute from a type based on +/// quantizedElementType.getExpressedType() to one based on +/// quantizedElementType.getStorageType(), where quantizedElementType is as from +/// QuantizedType::getQuantizedElementType() and casted to an +/// UniformQuantizedType. Returns nullptr if the conversion is not supported. 
On +/// success, stores the converted type in outConvertedType. +/// +/// Examples: +/// 1. realValue is a primitive value attribute: +/// (realValue: FloatAttr, quantizedElementType: UniformQuantizedType[i8:f32]) +/// -> (IntegerAttr, outConvertedType: i8) +/// 2. realValue is an elements attribute: +/// (realValue: DenseElementsAttr[tensor<2x2xf32>], +/// quantizedElementType: UniformQuantizedType[i8:f32]) +/// -> (DenseElementsAttr[tensor<2x2xi8>], outConvertedType: tensor<2x2xi8>) +Attribute quantizeAttrUniform(Attribute realValue, + UniformQuantizedType quantizedElementType, + const UniformQuantizedValueConverter &converter, + Type &outConvertedType); +} // namespace quant +} // namespace mlir + +#endif // MLIR_QUANTIZATION_QUANTIZEUTILS_H_ diff --git a/mlir/include/mlir/Quantization/UniformSupport.h b/mlir/include/mlir/Quantization/UniformSupport.h new file mode 100644 index 000000000000..a2055ee287a5 --- /dev/null +++ b/mlir/include/mlir/Quantization/UniformSupport.h @@ -0,0 +1,119 @@ +//===- UniformSupport.h - Support utilities for uniform quant ---*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#ifndef MLIR_QUANTIZATION_UNIFORMSUPPORT_H_ +#define MLIR_QUANTIZATION_UNIFORMSUPPORT_H_ + +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/Quantization/QuantOps.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/APSInt.h" + +namespace mlir { +namespace quant { + +/// Performs type conversion from an arbitrary input type to a type +/// that is expressed by a UniformQuantizedType. +/// +/// This handles cases where the inputType is a supported primitive type +/// (e.g. f32, bf16, etc.) or a vector/tensor type based on a supported +/// elemental type. +/// +/// Since conversion often involves introspecting some attributes of the +/// input type in order to determine how to represent it, this is a two step +/// process. +struct ExpressedToUniformQuantizedConverter { + /// Creates a converter for the given input type. + static const ExpressedToUniformQuantizedConverter + forInputType(Type inputType); + + /// Converts the inputType to be based on the given elemental type, + /// returning the new type (or nullptr and emit an error on failure). + Type convert(UniformQuantizedType elementalType) const; + + /// Whether the conversion is legal. + explicit operator bool() const { return (bool)expressedType; } + + /// The input type that is being converted from. + /// This may be an elemental or composite type. + const Type inputType; + + /// Supported, elemental expressed type (e.g. f32). + /// Will be nullptr if conversion is not supported. + const Type expressedType; +}; + +/// Reference implementation of converting between real numbers and values +/// represented by a UniformQuantizedType. +/// Note that this is not expected to be speedy and may be superseded eventually +/// by a more optimal implementation. 
+/// Also, the interface assumes that quantization is done per-layer and will +/// need to be wider for various per-channel schemes. As such, this is a +/// placeholder. +class UniformQuantizedValueConverter { +public: + UniformQuantizedValueConverter(UniformQuantizedType uniformType) + : scale(uniformType.getScale()), + zeroPoint(static_cast<double>(uniformType.getZeroPoint())), + clampMin(static_cast<double>(uniformType.getStorageTypeMin())), + clampMax(static_cast<double>(uniformType.getStorageTypeMax())), + storageBitWidth(uniformType.getStorageTypeIntegralWidth()), + isSigned(uniformType.isSigned()) { + assert(uniformType.getExpressedType().isa<FloatType>()); + assert(uniformType.getStorageType().isa<IntegerType>()); + } + + virtual APInt quantizeFloatToInt(APFloat expressedValue) const { + bool lossy; + expressedValue.convert(scale.getSemantics(), APFloat::rmNearestTiesToEven, + &lossy); + // fixedpoint = clamp(clampMin, clampMax, ( + // roundHalfToEven(expressed / scale) + zeroPoint)) + APFloat scaled = (expressedValue / scale); + scaled.roundToIntegral(APFloat::rmNearestTiesToEven); + scaled.add(zeroPoint, APFloat::rmNearestTiesToEven); + APFloat fixedpoint = llvm::minimum(scaled, clampMax); + fixedpoint = llvm::maximum(fixedpoint, clampMin); + + llvm::APSInt result(storageBitWidth, !isSigned); + fixedpoint.convertToInteger(result, APFloat::rmNearestTiesToEven, &lossy); + + return result; + } + + int64_t quantizeFloatToInt64(APFloat expressedValue) const { + APInt qValue = quantizeFloatToInt(expressedValue); + return isSigned ? 
qValue.getSExtValue() : qValue.getZExtValue(); + } + + virtual ~UniformQuantizedValueConverter() {} + +private: + const APFloat scale; + const APFloat zeroPoint; + const APFloat clampMin; + const APFloat clampMax; + const uint32_t storageBitWidth; + const bool isSigned; +}; + +} // namespace quant +} // namespace mlir + +#endif // MLIR_QUANTIZATION_UNIFORMSUPPORT_H_ diff --git a/mlir/lib/Quantization/IR/DialectRegistration.cpp b/mlir/lib/Quantization/IR/DialectRegistration.cpp new file mode 100644 index 000000000000..6beb193aecd7 --- /dev/null +++ b/mlir/lib/Quantization/IR/DialectRegistration.cpp @@ -0,0 +1,24 @@ +//===- DialectRegistration.cpp - Register Quantization dialect ------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/Quantization/QuantOps.h" + +using namespace mlir; +using namespace mlir::quant; + +// Static initialization for Quantization dialect registration. 
+static mlir::DialectRegistration QuantizationOps; diff --git a/mlir/lib/Quantization/IR/FakeQuantSupport.cpp b/mlir/lib/Quantization/IR/FakeQuantSupport.cpp new file mode 100644 index 000000000000..34457d909734 --- /dev/null +++ b/mlir/lib/Quantization/IR/FakeQuantSupport.cpp @@ -0,0 +1,94 @@ +#include "mlir/Quantization/FakeQuantSupport.h" +#include "mlir/Quantization/QuantOps.h" + +using namespace mlir; +using namespace mlir::quant; + +UniformQuantizedType mlir::quant::fakeQuantAttrsToType(Location loc, + unsigned numBits, + double rmin, double rmax, + bool narrowRange, + Type expressedType) { + MLIRContext *ctx = expressedType.getContext(); + Type storageType; + unsigned flags; + int64_t qmin; + int64_t qmax; + + // Hard-coded type mapping from TFLite. + if (numBits <= 8) { + storageType = IntegerType::get(8, ctx); + flags = 0; + qmin = 0; + qmax = 255; + } else if (numBits <= 16) { + storageType = IntegerType::get(16, ctx); + flags = QuantizationFlags::Signed; + qmin = -32768; + qmax = 32767; + } else { + ctx->emitError(loc, + "unsupported FakeQuant number of bits: " + Twine(numBits)); + return nullptr; + } + + // Handle narrowRange. + if (narrowRange) { + qmin += 1; + } + + // Range must straddle zero. + if (rmin > 0.0 || rmax < 0.0) { + return (ctx->emitError(loc, "FakeQuant range must straddle zero: [" + + Twine(std::to_string(rmin)) + "," + + Twine(std::to_string(rmax)) + "]"), + nullptr); + } + + // Special case where min/max is a point. Must be 0. + if (rmin == rmax) { + return UniformQuantizedType::getChecked(flags, storageType, expressedType, + 0.0, 0, qmin, qmax, loc); + } + + // Determine the scale. + const double qminDouble = qmin; + const double qmaxDouble = qmax; + const double scale = (rmax - rmin) / (qmaxDouble - qminDouble); + + // Zero point computation. + // In float, solve the affine equation for any known pair + // (real value, corresponding quantized value), of which, two such pairs + // are known: (rmin, qmin), (rmax, qmax). 
+ // The arithmetic error on the zero point computed from either pair will be + // roughly machine_epsilon * (sum of absolute values of terms). + // Use the variant that adds the smaller error. + const double zeroPointFromMin = qminDouble - rmin / scale; + const double zeroPointFromMinError = + std::abs(qminDouble) + std::abs(rmin / scale); + const double zeroPointFromMax = qmaxDouble - rmax / scale; + const double zeroPointFromMaxError = + std::abs(qmaxDouble) + std::abs(rmax / scale); + + const double zeroPointDouble = (zeroPointFromMinError < zeroPointFromMaxError) + ? zeroPointFromMin + : zeroPointFromMax; + + // Now nudge the zero point to be an integer. + int64_t nudgedZeroPoint = 0; + if (zeroPointDouble < qminDouble) { + nudgedZeroPoint = qmin; + } else if (zeroPointDouble > qmaxDouble) { + nudgedZeroPoint = qmax; + } else { + nudgedZeroPoint = round(zeroPointDouble); + } + + // By construction, the nudged zero point should always be in range. + assert(nudgedZeroPoint >= qmin); + assert(nudgedZeroPoint <= qmax); + + return UniformQuantizedType::getChecked(flags, storageType, expressedType, + scale, nudgedZeroPoint, qmin, qmax, + loc); +} diff --git a/mlir/lib/Quantization/IR/QuantOps.cpp b/mlir/lib/Quantization/IR/QuantOps.cpp new file mode 100644 index 000000000000..05a516202c1f --- /dev/null +++ b/mlir/lib/Quantization/IR/QuantOps.cpp @@ -0,0 +1,360 @@ +//===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/Quantization/QuantOps.h" +#include "TypeDetail.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/StandardTypes.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/MathExtras.h" + +using namespace mlir; +using namespace mlir::quant; +using namespace mlir::quant::detail; + +unsigned QuantizedType::getFlags() const { + return static_cast(type)->flags; +} + +LogicalResult QuantizedType::verifyConstructionInvariants( + llvm::Optional loc, MLIRContext *context, unsigned flags, + Type storageType, Type expressedType, int64_t storageTypeMin, + int64_t storageTypeMax) { + // Verify that the expressed type is floating point. + // If this restriction is ever eliminated, the parser/printer must be + // extended. + if (!expressedType.isa()) { + if (loc) { + context->emitError(*loc, "expressed type must be floating point"); + } + return failure(); + } + + // Verify that the storage type is integral. + // This restriction may be lifted at some point in favor of using bf16 + // or f16 as exact representations on hardware where that is advantageous. + auto intStorageType = storageType.dyn_cast(); + if (!intStorageType) { + if (loc) { + context->emitError(*loc, "storage type must be integral"); + } + return failure(); + } + unsigned integralWidth = intStorageType.getWidth(); + + // Verify storage width. + if (integralWidth == 0 || integralWidth > MaxStorageBits) { + if (loc) { + context->emitError(*loc, + "illegal storage type size: " + Twine(integralWidth)); + } + return failure(); + } + + // Verify storageTypeMin and storageTypeMax. 
+ bool isSigned = + (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed; + int64_t defaultIntegerMin = + getDefaultMininumForInteger(isSigned, integralWidth); + int64_t defaultIntegerMax = + getDefaultMaxinumForInteger(isSigned, integralWidth); + if (storageTypeMax - storageTypeMin <= 0 || + storageTypeMin < defaultIntegerMin || + storageTypeMax > defaultIntegerMax) { + if (loc) { + context->emitError(*loc, "illegal storage min and storage max: (" + + Twine(storageTypeMin) + ":" + + Twine(storageTypeMax) + ")"); + } + return failure(); + } + return success(); +} + +Type QuantizedType::getStorageType() const { + return static_cast(type)->storageType; +} + +int64_t QuantizedType::getStorageTypeMin() const { + return static_cast(type)->storageTypeMin; +} + +int64_t QuantizedType::getStorageTypeMax() const { + return static_cast(type)->storageTypeMax; +} + +unsigned QuantizedType::getStorageTypeIntegralWidth() const { + // NOTE: If ever supporting non-integral storage types, some other scheme + // for determining the width will be needed. + return static_cast(type)->storageType.getIntOrFloatBitWidth(); +} + +Type QuantizedType::getExpressedType() const { + return static_cast(type)->expressedType; +} + +bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) { + if (candidateExpressedType.isa()) { + return candidateExpressedType.cast().getElementType() == + getExpressedType(); + } + return candidateExpressedType == getExpressedType(); +} + +QuantizedType +QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) { + if (primitiveOrContainerType.isa()) { + Type elementType = + primitiveOrContainerType.cast().getElementType(); + return elementType.dyn_cast(); + } + return primitiveOrContainerType.dyn_cast(); +} + +Type QuantizedType::castFromStorageType(Type candidateType) { + if (candidateType == getStorageType()) { + // i.e. i32 -> quant<"uniform[i8:f32]{1.0}"> + return *this; + } else if (candidateType.isa()) { + // i.e. 
tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> + return RankedTensorType::get( + candidateType.cast().getShape(), getStorageType()); + } else if (candidateType.isa()) { + // i.e. tensor -> tensor> + return UnrankedTensorType::get(getStorageType()); + } else if (candidateType.isa()) { + // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> + return VectorType::get(candidateType.cast().getShape(), + getStorageType()); + } + + return nullptr; +} + +Type QuantizedType::castToStorageType(Type quantizedType) { + if (quantizedType.isa()) { + // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8 + return quantizedType.cast().getStorageType(); + } else if (quantizedType.isa()) { + // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> + VectorOrTensorType vtType = quantizedType.cast(); + if (!vtType.getElementType().isa()) { + return nullptr; + } + Type storageType = + vtType.getElementType().cast().getStorageType(); + if (quantizedType.isa()) { + return RankedTensorType::get(vtType.getShape(), storageType); + } else if (quantizedType.isa()) { + return UnrankedTensorType::get(storageType); + } else if (quantizedType.isa()) { + return VectorType::get(vtType.getShape(), storageType); + } + } + + return nullptr; +} + +Type QuantizedType::castFromExpressedType(Type candidateType) { + if (candidateType == getExpressedType()) { + // i.e. f32 -> quant<"uniform[i8:f32]{1.0}"> + return *this; + } else if (candidateType.isa()) { + VectorOrTensorType candidateVtType = + candidateType.cast(); + if (candidateVtType.getElementType() != getExpressedType()) { + return nullptr; + } + + if (candidateType.isa()) { + // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> + return RankedTensorType::get(candidateVtType.getShape(), *this); + } else if (candidateType.isa()) { + // i.e. tensor -> tensor> + return UnrankedTensorType::get(*this); + } else if (candidateType.isa()) { + // i.e. 
tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> + return VectorType::get(candidateVtType.getShape(), *this); + } + } + + return nullptr; +} + +Type QuantizedType::castToExpressedType(Type quantizedType) { + if (quantizedType.isa()) { + // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32 + return quantizedType.cast().getExpressedType(); + } else if (quantizedType.isa()) { + // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> + VectorOrTensorType vtType = quantizedType.cast(); + if (!vtType.getElementType().isa()) { + return nullptr; + } + Type expressedType = + vtType.getElementType().cast().getExpressedType(); + if (quantizedType.isa()) { + return RankedTensorType::get(vtType.getShape(), expressedType); + } else if (quantizedType.isa()) { + return UnrankedTensorType::get(expressedType); + } else if (quantizedType.isa()) { + return VectorType::get(vtType.getShape(), expressedType); + } + } + + return nullptr; +} + +Type QuantizedType::castExpressedToStorageType(Type candidateType) { + Type expressedQuantizedType = castFromExpressedType(candidateType); + if (!expressedQuantizedType) { + return nullptr; + } + return QuantizedType::castToStorageType(expressedQuantizedType); +} + +UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType, + Type expressedType, double scale, + int64_t zeroPoint, + int64_t storageTypeMin, + int64_t storageTypeMax) { + return Base::get(storageType.getContext(), + QuantizationTypes::UniformQuantized, flags, storageType, + expressedType, scale, zeroPoint, storageTypeMin, + storageTypeMax); +} + +UniformQuantizedType +UniformQuantizedType::getChecked(unsigned flags, Type storageType, + Type expressedType, double scale, + int64_t zeroPoint, int64_t storageTypeMin, + int64_t storageTypeMax, Location location) { + return Base::getChecked(location, storageType.getContext(), + QuantizationTypes::UniformQuantized, flags, + storageType, expressedType, scale, zeroPoint, + storageTypeMin, storageTypeMax); +} + 
+LogicalResult UniformQuantizedType::verifyConstructionInvariants( + llvm::Optional loc, MLIRContext *context, unsigned flags, + Type storageType, Type expressedType, double scale, int64_t zeroPoint, + int64_t storageTypeMin, int64_t storageTypeMax) { + if (failed(QuantizedType::verifyConstructionInvariants( + loc, context, flags, storageType, expressedType, storageTypeMin, + storageTypeMax))) { + return failure(); + } + + // Verify scale. + if (scale <= 0.0 || isinf(scale) || isnan(scale)) { + if (loc) { + context->emitError(*loc, + "illegal scale: " + Twine(std::to_string(scale))); + } + return failure(); + } + + return success(); +} + +double UniformQuantizedType::getScale() const { return getImpl()->scale; } + +int64_t UniformQuantizedType::getZeroPoint() const { + return getImpl()->zeroPoint; +} + +UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get( + unsigned flags, Type storageType, Type expressedType, + ArrayRef scales, ArrayRef zeroPoints, + int32_t quantizedDimension, int64_t storageTypeMin, + int64_t storageTypeMax) { + return Base::get(storageType.getContext(), + QuantizationTypes::UniformQuantizedPerAxis, flags, + storageType, expressedType, scales, zeroPoints, + quantizedDimension, storageTypeMin, storageTypeMax); +} + +UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked( + unsigned flags, Type storageType, Type expressedType, + ArrayRef scales, ArrayRef zeroPoints, + int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax, + Location location) { + return Base::getChecked(location, storageType.getContext(), + QuantizationTypes::UniformQuantizedPerAxis, flags, + storageType, expressedType, scales, zeroPoints, + quantizedDimension, storageTypeMin, storageTypeMax); +} + +LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants( + llvm::Optional loc, MLIRContext *context, unsigned flags, + Type storageType, Type expressedType, ArrayRef scales, + ArrayRef zeroPoints, int32_t quantizedDimension, + 
int64_t storageTypeMin, int64_t storageTypeMax) { + if (failed(QuantizedType::verifyConstructionInvariants( + loc, context, flags, storageType, expressedType, storageTypeMin, + storageTypeMax))) { + return failure(); + } + + // Ensure that the number of scales and zeroPoints match. + if (scales.size() != zeroPoints.size()) { + if (loc) { + context->emitError(*loc, "illegal number of scales and zeroPoints: " + + Twine(scales.size()) + ", " + + Twine(zeroPoints.size())); + } + return failure(); + } + + // Verify scale. + for (double scale : scales) { + if (scale <= 0.0 || isinf(scale) || isnan(scale)) { + if (loc) { + context->emitError(*loc, + "illegal scale: " + Twine(std::to_string(scale))); + } + return failure(); + } + } + + return success(); +} + +ArrayRef UniformQuantizedPerAxisType::getScales() const { + return getImpl()->getScales(); +} + +ArrayRef UniformQuantizedPerAxisType::getZeroPoints() const { + return getImpl()->getZeroPoints(); +} + +int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const { + return getImpl()->quantizedDimension; +} + +#define GET_OP_CLASSES +#include "mlir/Quantization/QuantOps.cpp.inc" + +QuantizationDialect::QuantizationDialect(MLIRContext *context) + : Dialect(/*name=*/"quant", context) { + addTypes(); + addOperations< +#define GET_OP_LIST +#include "mlir/Quantization/QuantOps.cpp.inc" + >(); +} diff --git a/mlir/lib/Quantization/IR/TypeDetail.h b/mlir/lib/Quantization/IR/TypeDetail.h new file mode 100644 index 000000000000..d3db91e10ab5 --- /dev/null +++ b/mlir/lib/Quantization/IR/TypeDetail.h @@ -0,0 +1,219 @@ +//===- Quantization/IR/TypeDetail.h - Type detail ---------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef TYPE_DETAIL_H_ +#define TYPE_DETAIL_H_ + +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/bit.h" + +namespace mlir { +namespace quant { +namespace detail { + +struct QuantizedTypeStorage : public mlir::TypeStorage { + QuantizedTypeStorage(unsigned flags, Type storageType, Type expressedType, + int64_t storageTypeMin, int64_t storageTypeMax) + : flags(flags), storageType(storageType), expressedType(expressedType), + storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {} + + /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue. + unsigned flags; + + // Integral type for the storage point representation. + Type storageType; + + // Floating point type that the quantized type approximates. + Type expressedType; + + // The minimum value storageType can take. + int64_t storageTypeMin; + + // The maximum value storageType can take. 
+ int64_t storageTypeMax; +}; + +struct UniformQuantizedTypeStorage : public QuantizedTypeStorage { + struct KeyTy { + KeyTy(unsigned flags, Type storageType, Type expressedType, double scale, + int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax) + : flags(flags), storageType(storageType), expressedType(expressedType), + scale(scale), zeroPoint(zeroPoint), storageTypeMin(storageTypeMin), + storageTypeMax(storageTypeMax) {} + /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue. + unsigned flags; + + // Integral type for the storage point representation. + Type storageType; + + // Floating point type that the quantized type approximates. + Type expressedType; + + double scale; + int64_t zeroPoint; + int64_t storageTypeMin; + int64_t storageTypeMax; + + // Check for equality of two structures that share KeyTy data members + // (by name). + template + static bool genericIsEqual(const T &lhs, const U &rhs) { + return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType && + lhs.expressedType == rhs.expressedType && lhs.scale == rhs.scale && + lhs.zeroPoint == rhs.zeroPoint && + lhs.storageTypeMin == rhs.storageTypeMin && + lhs.storageTypeMax == rhs.storageTypeMax; + } + + bool operator==(const KeyTy &other) const { + return genericIsEqual(*this, other); + } + + unsigned getHashValue() const { + int64_t scaleBits = llvm::bit_cast(scale); + return llvm::hash_combine(flags, storageType, expressedType, scaleBits, + zeroPoint, storageTypeMin, storageTypeMax); + } + }; + + UniformQuantizedTypeStorage(const KeyTy &key) + : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType, + key.storageTypeMin, key.storageTypeMax), + scale(key.scale), zeroPoint(key.zeroPoint) {} + + bool operator==(const KeyTy &key) const { + return KeyTy::genericIsEqual(*this, key); + } + + /// Construction. 
+ static UniformQuantizedTypeStorage *construct(TypeStorageAllocator &allocator, + const KeyTy &key) { + return new (allocator.allocate()) + UniformQuantizedTypeStorage(key); + } + + static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } + + double scale; + int64_t zeroPoint; +}; + +struct UniformQuantizedPerAxisTypeStorage : public QuantizedTypeStorage { + struct KeyTy { + KeyTy(unsigned flags, Type storageType, Type expressedType, + ArrayRef scales, ArrayRef zeroPoints, + int32_t quantizedDimension, int64_t storageTypeMin, + int64_t storageTypeMax) + : flags(flags), storageType(storageType), expressedType(expressedType), + scales(scales), zeroPoints(zeroPoints), + quantizedDimension(quantizedDimension), + storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {} + /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue. + unsigned flags; + + // Integral type for the storage point representation. + Type storageType; + + // Floating point type that the quantized type approximates. + Type expressedType; + + ArrayRef scales; + ArrayRef zeroPoints; + int32_t quantizedDimension; + int64_t storageTypeMin; + int64_t storageTypeMax; + + ArrayRef getScales() const { return scales; } + + ArrayRef getZeroPoints() const { return zeroPoints; } + + // Check for equality of two structures that share KeyTy data members + // (by name). 
+ template + static bool genericIsEqual(const T &lhs, const U &rhs) { + return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType && + lhs.expressedType == rhs.expressedType && + lhs.getScales() == rhs.getScales() && + lhs.getZeroPoints() == rhs.getZeroPoints() && + lhs.quantizedDimension == rhs.quantizedDimension && + lhs.storageTypeMin == rhs.storageTypeMin && + lhs.storageTypeMax == rhs.storageTypeMax; + } + + bool operator==(const KeyTy &other) const { + return genericIsEqual(*this, other); + } + + unsigned getHashValue() const { + int64_t *scalesCast = llvm::bit_cast(scales.data()); + ArrayRef scalesBits(scalesCast, scales.size()); + return llvm::hash_combine( + flags, storageType, expressedType, + llvm::hash_combine_range(scalesBits.begin(), scalesBits.end()), + llvm::hash_combine_range(zeroPoints.begin(), zeroPoints.end()), + storageTypeMin, storageTypeMax); + } + }; + + // We pass scales and zeroPoints in directly rather than relying on KeyTy + // because we have to create new reallocated versions in `construct` below. + UniformQuantizedPerAxisTypeStorage(const KeyTy &key, ArrayRef scales, + ArrayRef zeroPoints) + : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType, + key.storageTypeMin, key.storageTypeMax), + scaleElements(scales.data()), zeroPointElements(zeroPoints.data()), + quantParamsSize(scales.size()), + quantizedDimension(key.quantizedDimension) {} + + bool operator==(const KeyTy &key) const { + return KeyTy::genericIsEqual(*this, key); + } + + /// Construction.
+ static UniformQuantizedPerAxisTypeStorage * + construct(TypeStorageAllocator &allocator, const KeyTy &key) { + ArrayRef scales = allocator.copyInto(key.scales); + ArrayRef zeroPoints = allocator.copyInto(key.zeroPoints); + return new (allocator.allocate()) + UniformQuantizedPerAxisTypeStorage(key, scales, zeroPoints); + } + + static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } + + ArrayRef getScales() const { + return ArrayRef(scaleElements, quantParamsSize); + } + + ArrayRef getZeroPoints() const { + return ArrayRef(zeroPointElements, quantParamsSize); + } + + const double *scaleElements; + const int64_t *zeroPointElements; + unsigned quantParamsSize; + int32_t quantizedDimension; +}; + +} // namespace detail +} // namespace quant +} // namespace mlir + +#endif // TYPE_DETAIL_H_ diff --git a/mlir/lib/Quantization/IR/TypeParser.cpp b/mlir/lib/Quantization/IR/TypeParser.cpp new file mode 100644 index 000000000000..352e952a79fb --- /dev/null +++ b/mlir/lib/Quantization/IR/TypeParser.cpp @@ -0,0 +1,653 @@ +//===- Quantization/IR/TypeParser.cpp - Quantization Type Parser ----------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+// ============================================================================= + +#include "mlir/IR/Location.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/Quantization/QuantOps.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/Support/Format.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace quant { + +/// Print a floating point value in a way that the parser will be able to +/// round-trip losslessly. +static void printStabilizedFloat(const APFloat &apValue, raw_ostream &os) { + // We would like to output the FP constant value in exponential notation, + // but we cannot do this if doing so will lose precision. Check here to + // make sure that we only output it in exponential format if we can parse + // the value back and get the same value. + bool isInf = apValue.isInfinity(); + bool isNaN = apValue.isNaN(); + if (!isInf && !isNaN) { + SmallString<128> strValue; + apValue.toString(strValue, 6, 0, false); + + // Check to make sure that the stringized number is not some string like + // "Inf" or NaN, that atof will accept, but the lexer will not. Check + // that the string matches the "[-+]?[0-9]" regex. + assert(((strValue[0] >= '0' && strValue[0] <= '9') || + ((strValue[0] == '-' || strValue[0] == '+') && + (strValue[1] >= '0' && strValue[1] <= '9'))) && + "[-+]?[0-9] regex does not match!"); + // Reparse stringized version! 
+ if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) { + os << strValue; + return; + } + } + + SmallVector str; + apValue.toString(str); + os << str; +} + +namespace { + +enum class TokenKind { + error, + eof, + l_bracket, + r_bracket, + l_brace, + r_brace, + l_paren, + r_paren, + colon, + comma, + alpha_ident, + integer_literal, + float_literal, +}; + +struct Token { + TokenKind kind; + StringRef spelling; +}; + +class Lexer { +public: + Lexer(StringRef source) : curBuffer(source), curPtr(curBuffer.begin()) {} + + Token lexToken(); + +private: + Token formToken(TokenKind kind, const char *tokStart) { + return Token{kind, StringRef(tokStart, curPtr - tokStart)}; + } + + Token emitError(const char *loc, const Twine &message) { + return formToken(TokenKind::error, loc); + } + + bool isEnd() const { return curPtr == curBuffer.end(); } + + // Lexer implementation methods + Token lexalpha_ident(const char *tokStart); + Token lexNumber(const char *tokStart); + + StringRef curBuffer; + const char *curPtr; +}; + +} // namespace + +Token Lexer::lexToken() { + // Ignore whitespace. 
+ while (!isEnd()) { + switch (*curPtr) { + case ' ': + case '\t': + case '\n': + case '\r': + ++curPtr; + continue; + default: + break; + } + break; + } + + if (isEnd()) { + return Token{TokenKind::eof, ""}; + } + + const char *tokStart = curPtr; + switch (*curPtr++) { + default: + if (isalpha(*tokStart)) { + return lexalpha_ident(tokStart); + } + if (isdigit(*tokStart)) { + return lexNumber(tokStart); + } + + return emitError(tokStart, "unexpected character"); + + case '[': + return formToken(TokenKind::l_bracket, tokStart); + case ']': + return formToken(TokenKind::r_bracket, tokStart); + case '{': + return formToken(TokenKind::l_brace, tokStart); + case '}': + return formToken(TokenKind::r_brace, tokStart); + case '(': + return formToken(TokenKind::l_paren, tokStart); + case ')': + return formToken(TokenKind::r_paren, tokStart); + case ':': + return formToken(TokenKind::colon, tokStart); + case ',': + return formToken(TokenKind::comma, tokStart); + case '-': + return lexNumber(tokStart); + case '+': + return lexNumber(tokStart); + } +} + +/// Lex a bare alpha identifier. Since this DSL often contains identifiers with +/// trailing numeric components, this only matches alphas. It is up to the +/// parser to handle identifiers that can be mixed alphanum. +/// +/// alpha-ident ::= (letter)(letter)* +Token Lexer::lexalpha_ident(const char *tokStart) { + while (!isEnd() && isalpha(*curPtr)) { + ++curPtr; + } + return formToken(TokenKind::alpha_ident, tokStart); +} + +/// Lex a number. +/// +/// integer-literal ::= [-+]?digit+ +/// float-literal ::= [-+]?[0-9]+[.][0-9]*([eE][-+]?[0-9]+)? +Token Lexer::lexNumber(const char *tokStart) { + // Leading '+', '-' or digit has already been consumed. + while (!isEnd() && isdigit(*curPtr)) { + ++curPtr; + } + // If not a decimal point, treat as integer. + if (isEnd() || *curPtr != '.') { + return formToken(TokenKind::integer_literal, tokStart); + } + ++curPtr; + + // Skip over [0-9]*([eE][-+]?[0-9]+)? + // Leading digits. 
+ while (!isEnd() && isdigit(*curPtr)) { + ++curPtr; + } + + // [eE][-+]?[0-9]+ + if (!isEnd() && (*curPtr == 'e' || *curPtr == 'E')) { + auto remaining = curBuffer.end() - curPtr; + if (remaining > 2 && isdigit(curPtr[1])) { + // Lookahead 2 for digit. + curPtr += 2; + while (!isEnd() && isdigit(*curPtr)) { + ++curPtr; + } + } else if (remaining > 3 && (curPtr[1] == '-' || curPtr[1] == '+') && + isdigit(curPtr[2])) { + // Lookahead 3 for [+-] digit. + curPtr += 3; + while (!isEnd() && isdigit(*curPtr)) { + ++curPtr; + } + } + } + return formToken(TokenKind::float_literal, tokStart); +} + +// --- TypeParser --- +namespace { + +class TypeParser { +public: + TypeParser(StringRef source, MLIRContext *context, Location location) + : context(context), location(location), lexer(source), + curToken(lexer.lexToken()) {} + + /// Attempts to parse the source as a type, returning a null type on + /// error. + Type parseType(); + +private: + /// Unconditionally consumes the current token. + void consumeToken() { + assert(curToken.kind != TokenKind::eof && + "should not advance past EOF or errors"); + curToken = lexer.lexToken(); + } + + /// Unconditionally consumes the current token, asserting that it is of the + /// specified kind. + void consumeToken(TokenKind kind) { + assert(curToken.kind == kind && "consumed an unexpected token"); + consumeToken(); + } + + /// Conditionally consumes a token if of the specified kind. + /// Returns true if consumed. + bool consumeIf(TokenKind kind) { + if (curToken.kind == kind) { + consumeToken(); + return true; + } + return false; + } + + /// Emits an error at the current location with a message. + void emitError(const Twine &message) { + // TODO: All errors show up at the beginning of the extended type location. + // Figure out how to make this location relative to where the error occurred + // in this instance. + context->emitError(location, message); + } + + // Parsers.
+ Type parseUniformType(); + IntegerType parseStorageType(bool &isSigned); + FloatType parseExpressedType(); + bool parseQuantParams(double &scale, int64_t &zeroPoint); + + MLIRContext *context; + Location location; + Lexer lexer; + + // The next token that has not yet been consumed. + Token curToken; +}; + +} // namespace + +Type TypeParser::parseType() { + // All types start with an identifier that we switch on. + if (curToken.kind == TokenKind::alpha_ident) { + StringRef typeNameSpelling = curToken.spelling; + consumeToken(); + + Type result; + if (typeNameSpelling == "uniform") { + result = parseUniformType(); + if (!result) { + return nullptr; + } + } else { + return (emitError("unknown quantized type " + typeNameSpelling), nullptr); + } + + // Make sure the entire input was consumed. + if (curToken.kind != TokenKind::eof) { + return (emitError("unrecognized token: " + curToken.spelling), nullptr); + } + + return result; + } else { + return (emitError("unrecognized token: " + curToken.spelling), nullptr); + } +} + +/// Parses a UniformQuantizedType. +/// +/// uniform_type ::= `uniform` type_spec quant_param_spec +/// +/// type_spec ::= `[` storage-spec `:` expressed-type (quant-dim)? `]` +/// quant-dim ::= `:` integer-literal +/// storage-spec ::= storage-type (`(` storage-range `)`)? +/// storage-range ::= integer-literal `:` integer-literal +/// storage-type ::= (`i` | `u`) integer-literal +/// expressed-type ::= (`f16` | `f32` | `f64` | `bf16`) +/// +/// quant_param_spec ::= `{` scale-zero (`,` scale-zero )* `}` +/// scale-zero ::= float-literal `:` integer-literal +Type TypeParser::parseUniformType() { + IntegerType storageType; + FloatType expressedType; + unsigned typeFlags = 0; + int64_t storageTypeMin; + int64_t storageTypeMax; + bool isPerAxis = false; + int32_t quantizedDimension; + SmallVector scales; + SmallVector zeroPoints; + + // Type specification. 
+ if (!consumeIf(TokenKind::l_bracket)) { + return (emitError("unrecognized token: " + curToken.spelling), nullptr); + } + + // Storage type. + bool isSigned = false; + storageType = parseStorageType(isSigned); + if (!storageType) { + return nullptr; + } + if (isSigned) { + typeFlags |= QuantizationFlags::Signed; + } + + // Storage type range. + int64_t defaultIntegerMin = QuantizedType::getDefaultMininumForInteger( + isSigned, storageType.getWidth()); + int64_t defaultIntegerMax = QuantizedType::getDefaultMaxinumForInteger( + isSigned, storageType.getWidth()); + if (consumeIf(TokenKind::l_paren)) { + // Explicit storage min and storage max. + if (curToken.kind != TokenKind::integer_literal) { + return (emitError("expected storage type minimum"), nullptr); + } + if (curToken.spelling.getAsInteger(10, storageTypeMin) || + storageTypeMin < defaultIntegerMin) { + return (emitError("illegal storage type minimum: " + curToken.spelling), + nullptr); + } + consumeToken(TokenKind::integer_literal); + + if (!consumeIf(TokenKind::colon)) { + return (emitError("unrecognized token: " + curToken.spelling), nullptr); + } + + if (curToken.kind != TokenKind::integer_literal) { + return (emitError("expected storage type maximum"), nullptr); + } + if (curToken.spelling.getAsInteger(10, storageTypeMax) || + storageTypeMax > defaultIntegerMax) { + return (emitError("illegal storage type maximum: " + curToken.spelling), + nullptr); + } + consumeToken(TokenKind::integer_literal); + + if (!consumeIf(TokenKind::r_paren)) { + return (emitError("unrecognized token: " + curToken.spelling), nullptr); + } + } else { + storageTypeMin = defaultIntegerMin; + storageTypeMax = defaultIntegerMax; + } + + // Repr type. + if (!consumeIf(TokenKind::colon)) { + return (emitError("unrecognized token: " + curToken.spelling), nullptr); + } + expressedType = parseExpressedType(); + if (!expressedType) { + return nullptr; + } + + // Optionally parse quantized dimension for per-axis quantization. 
+ if (consumeIf(TokenKind::colon)) { + if (curToken.kind != TokenKind::integer_literal) { + return (emitError("expected quantized dimension"), nullptr); + } + if (curToken.spelling.getAsInteger(10, quantizedDimension)) { + return (emitError("illegal quantized dimension: " + curToken.spelling), + nullptr); + } + consumeToken(TokenKind::integer_literal); + isPerAxis = true; + } + + if (!consumeIf(TokenKind::r_bracket)) { + return (emitError("unrecognized token: " + curToken.spelling), nullptr); + } + + // Parameter specification. + if (!consumeIf(TokenKind::l_brace)) { + return (emitError("unrecognized token: " + curToken.spelling), nullptr); + } + + // Parse scales/zeroPoints. + do { + scales.resize(scales.size() + 1); + zeroPoints.resize(zeroPoints.size() + 1); + if (parseQuantParams(scales.back(), zeroPoints.back())) { + return nullptr; + } + } while (consumeIf(TokenKind::comma)); + + if (!consumeIf(TokenKind::r_brace)) { + return (emitError("unrecognized token: " + curToken.spelling), nullptr); + } + + if (!isPerAxis && scales.size() > 1) { + return (emitError("multiple scales/zeroPoints provided, but " + "quantizedDimension wasn't specified"), + nullptr); + } + + if (isPerAxis) { + ArrayRef scalesRef(scales.begin(), scales.end()); + ArrayRef zeroPointsRef(zeroPoints.begin(), zeroPoints.end()); + return UniformQuantizedPerAxisType::getChecked( + typeFlags, storageType, expressedType, scalesRef, zeroPointsRef, + quantizedDimension, storageTypeMin, storageTypeMax, location); + } + + return UniformQuantizedType::getChecked( + typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(), + storageTypeMin, storageTypeMax, location); +} + +IntegerType TypeParser::parseStorageType(bool &isSigned) { + // Parse storage type (alpha_ident, integer_literal). 
+ StringRef storageTypePrefix = curToken.spelling; + unsigned storageTypeWidth; + if (curToken.kind != TokenKind::alpha_ident) { + return (emitError("expected storage type prefix"), nullptr); + } + consumeToken(); + if (curToken.kind != TokenKind::integer_literal) { + return (emitError("expected storage type width"), nullptr); + } + if (curToken.spelling.getAsInteger(10, storageTypeWidth) || + storageTypeWidth == 0 || + storageTypeWidth > QuantizedType::MaxStorageBits) { + return (emitError("illegal storage type size: " + Twine(curToken.spelling)), + nullptr); + } + consumeToken(); + + if (storageTypePrefix == "i") { + isSigned = true; + return IntegerType::get(storageTypeWidth, context); + } else if (storageTypePrefix == "u") { + isSigned = false; + return IntegerType::get(storageTypeWidth, context); + } else { + return ( + emitError("illegal storage type prefix: " + Twine(storageTypePrefix)), + nullptr); + } +} + +FloatType TypeParser::parseExpressedType() { + // Expect an alpha_ident followed by integer literal that we concat back + // together. + StringRef prefix = curToken.spelling; + if (!consumeIf(TokenKind::alpha_ident)) { + return (emitError("expected expressed type"), nullptr); + } + StringRef suffix = curToken.spelling; + if (!consumeIf(TokenKind::integer_literal)) { + return (emitError("expected expressed type"), nullptr); + } + + SmallVector holder; + StringRef typeName = (Twine(prefix) + Twine(suffix)).toStringRef(holder); + if (typeName == "f32") + return FloatType::getF32(context); + if (typeName == "f16") + return FloatType::getF16(context); + if (typeName == "bf16") + return FloatType::getBF16(context); + if (typeName == "f64") + return FloatType::getF64(context); + + return (emitError("unrecognized expressed type: " + typeName), nullptr); +} + +bool TypeParser::parseQuantParams(double &scale, int64_t &zeroPoint) { + // scale[:zeroPoint]? + // scale. 
+ StringRef scaleSpelling = curToken.spelling; + if (!consumeIf(TokenKind::float_literal) || + scaleSpelling.getAsDouble(scale)) { + return ( + emitError("expected valid uniform scale. got: " + Twine(scaleSpelling)), + true); + } + + // zero point. + zeroPoint = 0; + if (!consumeIf(TokenKind::colon)) { + // Default zero point. + return false; + } + StringRef zeroPointSpelling = curToken.spelling; + if (!consumeIf(TokenKind::integer_literal) || + zeroPointSpelling.getAsInteger(10, zeroPoint)) { + return (emitError("expected integer uniform zero point. got: " + + Twine(zeroPointSpelling)), + true); + } + + return false; +} + +/// Parse a type registered to this dialect. +Type QuantizationDialect::parseType(StringRef spec, Location loc) const { + TypeParser parser(spec, getContext(), loc); + Type parsedType = parser.parseType(); + if (parsedType == nullptr) { + // Error. + // TODO(laurenzo): Do something? + return parsedType; + } + + return parsedType; +} + +static void printStorageType(QuantizedType type, raw_ostream &out) { + // storage type + unsigned storageWidth = type.getStorageTypeIntegralWidth(); + bool isSigned = type.isSigned(); + if (isSigned) { + out << "i" << storageWidth; + } else { + out << "u" << storageWidth; + } + + // storageTypeMin and storageTypeMax if not default. 
+ int64_t defaultIntegerMin = + QuantizedType::getDefaultMininumForInteger(isSigned, storageWidth); + int64_t defaultIntegerMax = + QuantizedType::getDefaultMaxinumForInteger(isSigned, storageWidth); + if (defaultIntegerMin != type.getStorageTypeMin() || + defaultIntegerMax != type.getStorageTypeMax()) { + out << "(" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax() + << ")"; + } +} + +static void printExpressedType(QuantizedType type, raw_ostream &out) { + // repr type + Type expressedType = type.getExpressedType(); + if (expressedType.isF32()) { + out << "f32"; + } else if (expressedType.isF64()) { + out << "f64"; + } else if (expressedType.isF16()) { + out << "f16"; + } else if (expressedType.isBF16()) { + out << "bf16"; + } else { + out << "unknown"; + } +} + +static void printQuantParams(double scale, int64_t zeroPoint, + raw_ostream &out) { + printStabilizedFloat(APFloat(scale), out); + if (zeroPoint != 0) { + out << ":" << zeroPoint; + } +} + +/// Helper that prints a UniformQuantizedType. +static void printUniformQuantizedType(UniformQuantizedType type, + raw_ostream &out) { + out << "uniform["; + printStorageType(type, out); + out << ":"; + printExpressedType(type, out); + out << "]"; + + // scheme specific parameters + out << "{"; + printQuantParams(type.getScale(), type.getZeroPoint(), out); + out << "}"; +} + +/// Helper that prints a UniformQuantizedPerAxisType. 
+static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type, + raw_ostream &out) { + out << "uniform["; + printStorageType(type, out); + out << ":"; + printExpressedType(type, out); + out << ":"; + out << type.getQuantizedDimension(); + out << "]"; + + // scheme specific parameters + ArrayRef scales = type.getScales(); + ArrayRef zeroPoints = type.getZeroPoints(); + out << "{"; + for (unsigned i = 0; i < scales.size(); ++i) { + printQuantParams(scales[i], zeroPoints[i], out); + if (i != scales.size() - 1) { + out << ","; + } + } + out << "}"; +} + +/// Print a type registered to this dialect. +void QuantizationDialect::printType(Type type, raw_ostream &os) const { + switch (type.getKind()) { + default: + llvm_unreachable("Unhandled quantized type"); + case QuantizationTypes::UniformQuantized: + printUniformQuantizedType(type.cast(), os); + break; + case QuantizationTypes::UniformQuantizedPerAxis: + printUniformQuantizedPerAxisType(type.cast(), + os); + break; + } +} + +} // namespace quant +} // namespace mlir diff --git a/mlir/lib/Quantization/IR/UniformSupport.cpp b/mlir/lib/Quantization/IR/UniformSupport.cpp new file mode 100644 index 000000000000..d9549bb07476 --- /dev/null +++ b/mlir/lib/Quantization/IR/UniformSupport.cpp @@ -0,0 +1,73 @@ +//===- UniformSupport.cpp - Support utilities for uniform quant -----------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#include "mlir/Quantization/UniformSupport.h" +#include "mlir/IR/StandardTypes.h" + +using namespace mlir; +using namespace mlir::quant; + +static bool isQuantizablePrimitiveType(Type inputType) { + return inputType.isa(); +} + +const ExpressedToUniformQuantizedConverter +ExpressedToUniformQuantizedConverter::forInputType(Type inputType) { + switch (inputType.getKind()) { + default: + if (isQuantizablePrimitiveType(inputType)) { + // Supported primitive type (which just is the expressed type). + return ExpressedToUniformQuantizedConverter{inputType, inputType}; + } + // Unsupported. + return ExpressedToUniformQuantizedConverter{inputType, nullptr}; + case StandardTypes::RankedTensor: + case StandardTypes::UnrankedTensor: + case StandardTypes::Vector: { + Type elementType = inputType.cast().getElementType(); + if (!isQuantizablePrimitiveType(elementType)) { + // Unsupported. + return ExpressedToUniformQuantizedConverter{inputType, nullptr}; + } + return ExpressedToUniformQuantizedConverter{ + inputType, inputType.cast().getElementType()}; + } + } +} + +Type ExpressedToUniformQuantizedConverter::convert( + UniformQuantizedType elementalType) const { + assert(expressedType && "convert() on unsupported conversion"); + + switch (inputType.getKind()) { + default: + if (isQuantizablePrimitiveType(elementalType)) { + // For primitives, just use the new elemental type. + return elementalType; + } + // Unsupported. 
+ return nullptr; + case StandardTypes::RankedTensor: + return RankedTensorType::get(inputType.cast().getShape(), + elementalType); + case StandardTypes::UnrankedTensor: + return UnrankedTensorType::get(elementalType); + case StandardTypes::Vector: + return VectorType::get(inputType.cast().getShape(), + elementalType); + } +} diff --git a/mlir/lib/Quantization/Transforms/ConvertConst.cpp b/mlir/lib/Quantization/Transforms/ConvertConst.cpp new file mode 100644 index 000000000000..ec947f24905e --- /dev/null +++ b/mlir/lib/Quantization/Transforms/ConvertConst.cpp @@ -0,0 +1,133 @@ +//===- ConvertConst.cpp - Quantizes constant ops --------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Quantization/Passes.h" +#include "mlir/Quantization/QuantOps.h" +#include "mlir/Quantization/QuantizeUtils.h" +#include "mlir/Quantization/UniformSupport.h" +#include "mlir/StandardOps/Ops.h" + +using namespace mlir; +using namespace mlir::quant; + +namespace { + +class ConvertConstPass : public FunctionPass { +public: + void runOnFunction() override; +}; + +class QuantizedConstRewrite : public RewritePattern { +public: + struct State : PatternState { + QuantizedType quantizedElementType; + Attribute value; + }; + + QuantizedConstRewrite(MLIRContext *context) + : RewritePattern(QuantizeBarrierOp::getOperationName(), 1, context) {} + + PatternMatchResult match(Operation *op) const override; + void rewrite(Operation *op, std::unique_ptr baseState, + PatternRewriter &rewriter) const override; +}; + +} // end anonymous namespace + +/// Matches a [constant] -> [qbarrier] where the qbarrier result type is +/// quantized and the operand type is quantizable. +PatternMatchResult QuantizedConstRewrite::match(Operation *op) const { + State state; + + // Is the operand a constant? + auto qbarrier = op->cast(); + if (!matchPattern(qbarrier.arg(), m_Constant(&state.value))) { + return matchFailure(); + } + // Does the qbarrier convert to a quantized type? This will not be true + // if a quantized type has not yet been chosen or if the cast to an equivalent + // storage type is not supported.
+ Type qbarrierResultType = qbarrier.getResult()->getType(); + state.quantizedElementType = + QuantizedType::getQuantizedElementType(qbarrierResultType); + if (!state.quantizedElementType) { + return matchFailure(); + } + if (!QuantizedType::castToStorageType(qbarrierResultType)) { + return matchFailure(); + } + + // Is the operand type compatible with the expressed type of the quantized + // type? This will not be true if the qbarrier is superfluous (converts + // from and to a quantized type). + if (!state.quantizedElementType.isCompatibleExpressedType( + qbarrier.arg()->getType())) { + return matchFailure(); + } + + // Is the constant value a type expressed in a way that we support? + if (!state.value.isa() && !state.value.isa() && + !state.value.isa() && + !state.value.isa()) { + return matchFailure(); + } + + return matchSuccess(llvm::make_unique(std::move(state))); +} + +void QuantizedConstRewrite::rewrite(Operation *op, + std::unique_ptr baseState, + PatternRewriter &rewriter) const { + auto state = static_cast(baseState.get()); + + Type newConstValueType; + Attribute newConstValue = quantizeAttr( + state->value, state->quantizedElementType, newConstValueType); + if (!newConstValue) { + return; + } + + auto *origConstOp = op->getOperand(0); + // When creating the new const op, use a fused location that combines the + // original const and the qbarrier that led to the quantization. 
+ auto fusedLoc = + FusedLoc::get({origConstOp->getDefiningOp()->getLoc(), op->getLoc()}, + rewriter.getContext()); + auto newConstOp = + rewriter.create(fusedLoc, newConstValueType, newConstValue); + rewriter.replaceOpWithNewOp( + op, {origConstOp}, *op->result_type_begin(), newConstOp); +} + +void ConvertConstPass::runOnFunction() { + OwningRewritePatternList patterns; + auto &func = getFunction(); + auto *context = &getContext(); + patterns.push_back(llvm::make_unique(context)); + applyPatternsGreedily(func, std::move(patterns)); +} + +FunctionPassBase *createConvertConstPass() { return new ConvertConstPass(); } + +static PassRegistration + pass("quant-convert-const", + "Converts constants followed by qbarrier to actual quantized values"); diff --git a/mlir/lib/Quantization/Transforms/LowerTF.cpp b/mlir/lib/Quantization/Transforms/LowerTF.cpp new file mode 100644 index 000000000000..24a35c9f3179 --- /dev/null +++ b/mlir/lib/Quantization/Transforms/LowerTF.cpp @@ -0,0 +1,112 @@ +//===- LowerTF.cpp - Passes for lowering from TensorFlow ------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Quantization/FakeQuantSupport.h" +#include "mlir/Quantization/Passes.h" +#include "mlir/Quantization/QuantOps.h" +#include "mlir/Quantization/UniformSupport.h" +#include "mlir/TensorFlow/TFOps.h" + +using namespace mlir; +using namespace mlir::quant; + +namespace { + +class LowerTFPass : public FunctionPass { +public: + void runOnFunction() override; +}; + +} // end anonymous namespace + +/// Rewrites TensorFlow FakeQuantWithMinMaxArgs into a qbarrier/dbarrier pair. +class FakeQuantWithMinMaxArgsRewrite : public RewritePattern { +public: + bool *hadFailure; + + FakeQuantWithMinMaxArgsRewrite(MLIRContext *context, bool *hadFailure) + : RewritePattern(TF::FakeQuantWithMinMaxArgsOp::getOperationName(), 1, + context), + hadFailure(hadFailure) {} + + PatternMatchResult match(Operation *op) const override { + return matchSuccess(); + } + + void rewrite(Operation *op, PatternRewriter &rewriter) const override { + // TODO: If this pattern comes up more frequently, consider adding core + // support for failable rewrites. + if (failableRewrite(op, rewriter)) { + *hadFailure = true; + } + } + + bool failableRewrite(Operation *op, PatternRewriter &rewriter) const { + auto fqOp = op->template cast(); + + auto converter = + ExpressedToUniformQuantizedConverter::forInputType(fqOp.getType()); + if (!converter) { + return (op->emitError("unsupported quantized type conversion"), true); + } + + UniformQuantizedType uniformElementType = fakeQuantAttrsToType( + fqOp.getLoc(), fqOp.num_bits().getSExtValue(), + fqOp.min().convertToDouble(), fqOp.max().convertToDouble(), + fqOp.narrow_range(), converter.expressedType); + + if (!uniformElementType) { + // Note that the fakeQuantAttrsToType will have emitted the error. 
+ return true; + } + + Type quantizedType = converter.convert(uniformElementType); + assert(quantizedType && + "Converter accepted a type that it did not convert"); + + // TODO: Map to a qbarrier with an attribute like [Forced] to signal that + // this is a forced/hard-coded constraint. + auto qbarrier = rewriter.create( + op->getLoc(), quantizedType, fqOp.inputs()); + rewriter.replaceOpWithNewOp(op, converter.inputType, + qbarrier.getResult()); + + return false; + } +}; + +void LowerTFPass::runOnFunction() { + bool hadFailure = false; + OwningRewritePatternList patterns; + auto &func = getFunction(); + auto *context = &getContext(); + patterns.push_back( + llvm::make_unique(context, &hadFailure)); + applyPatternsGreedily(func, std::move(patterns)); + if (hadFailure) + signalPassFailure(); +} + +FunctionPassBase *createLowerTFPass() { return new LowerTFPass(); } + +static PassRegistration + pass("quant-lower-tf", + "Lowers TensorFlow constraint ops to the quantization dialect"); diff --git a/mlir/lib/Quantization/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Quantization/Transforms/LowerUniformRealMath.cpp new file mode 100644 index 000000000000..9ce926408bb9 --- /dev/null +++ b/mlir/lib/Quantization/Transforms/LowerUniformRealMath.cpp @@ -0,0 +1,259 @@ +//===- LowerUniformRealMath.cpp ------------------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Quantization/Passes.h" +#include "mlir/Quantization/QuantOps.h" +#include "mlir/Quantization/UniformSupport.h" + +#include + +using namespace mlir; +using namespace mlir::quant; + +namespace { + +struct LowerUniformRealMathPass + : public FunctionPass { + void runOnFunction() override; +}; + +UniformQuantizedType getUniformElementType(Type t) { + return QuantizedType::getQuantizedElementType(t) + .dyn_cast_or_null(); +} + +/// Computes the log2(x), rounded to an integral value. Returns whether 'x' can +/// be considered an exact integral value. +template bool integralLog2(F x, int &log2Result) { + const F xLog2 = std::log(x) * (1.0 / std::log(2.0)); + const F xLog2Rounded = std::round(xLog2); + const F xLog2Frac = xLog2 - xLog2Rounded; + log2Result = static_cast(xLog2Rounded); + // Allow small comparison slop below the level that would make a difference + // for 2^16 levels. + return std::abs(xLog2Frac) < 1e-6; +} + +/// Helper class for operating on binary operations where all operands +/// and the result are a UniformQuantizedType. +struct RealBinaryOpInfo { + RealBinaryOpInfo(Operation *op, Value *lhs, Value *rhs, + Optional clampMin, Optional clampMax) + : op(op), lhs(lhs), rhs(rhs), clampMin(clampMin), clampMax(clampMax), + lhsType(getUniformElementType(lhs->getType())), + rhsType(getUniformElementType(rhs->getType())), + resultType(getUniformElementType(*op->result_type_begin())), + lhsStorageType(QuantizedType::castToStorageType(lhs->getType())), + rhsStorageType(QuantizedType::castToStorageType(rhs->getType())), + resultStorageType( + QuantizedType::castToStorageType(*op->result_type_begin())) {} + + /// Returns whether this info is valid (all types defined, etc). 
+ bool isValid() const { + return lhsType && rhsType && resultType && lhsStorageType && + rhsStorageType && resultStorageType; + } + + /// Returns whether the storage type of all operands is identical. + bool isSameStorageType() const { + return lhsType.getStorageType() == rhsType.getStorageType() && + lhsType.getStorageType() == resultType.getStorageType(); + } + + /// Returns whether all operands and result are considered fixedpoint power + /// of two, setting the lhs, rhs, and result log2 scale references. + bool isFixedPointPOT(int &lhsLog2Scale, int &rhsLog2Scale, + int &resultLog2Scale) const { + if (!lhsType.isFixedPoint() || !rhsType.isFixedPoint() || + !resultType.isFixedPoint()) { + return false; + } + + if (!integralLog2(lhsType.getScale(), lhsLog2Scale) || + !integralLog2(rhsType.getScale(), rhsLog2Scale) || + !integralLog2(resultType.getScale(), resultLog2Scale)) { + return false; + } + + return true; + } + + /// Gets the result integer clamp range given the result quantized type + // and any explicit clamp provided as attributes. + std::pair getClampMinMax() const { + int64_t typeMin = resultType.getStorageTypeMin(); + int64_t typeMax = resultType.getStorageTypeMax(); + + if (clampMin || clampMax) { + UniformQuantizedValueConverter conv(resultType); + if (clampMin) { + typeMin = std::max(typeMin, conv.quantizeFloatToInt64(*clampMin)); + } + if (clampMax) { + typeMax = std::min(typeMax, conv.quantizeFloatToInt64(*clampMax)); + } + } + + // The quantized, integral ops expect clamps as 32bit ints. + return { + IntegerAttr::get(IntegerType::get(32, resultType.getContext()), + typeMin), + IntegerAttr::get(IntegerType::get(32, resultType.getContext()), + typeMax), + }; + } + + Operation *op; + Value *lhs; + Value *rhs; + Optional clampMin; + Optional clampMax; + + // Element UniformQuantizedType for operands/result. + UniformQuantizedType lhsType; + UniformQuantizedType rhsType; + UniformQuantizedType resultType; + + // Full storage-based types. 
+ Type lhsStorageType; + Type rhsStorageType; + Type resultStorageType; +}; + +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// Elementwise add +//===----------------------------------------------------------------------===// +/// Attempts to rewrite a fixed point power-of-two addition of two integers. +/// This supports a limited number of cases, but when supported, represents +/// the simplest computation. +static LogicalResult tryRewriteFixedPOTAddEw(const RealBinaryOpInfo &constInfo, + PatternRewriter &rewriter) { + if (!constInfo.isSameStorageType()) { + return failure(); + } + + int lhsLog2Scale; + int rhsLog2Scale; + int resultLog2Scale; + if (!constInfo.isFixedPointPOT(lhsLog2Scale, rhsLog2Scale, resultLog2Scale)) { + return failure(); + } + + // Adjust shifts to be relative to the output. + // Left shift of one input scale is supported. The other must match the result + // scale. + int lhsScaleShift = lhsLog2Scale - resultLog2Scale; + int rhsScaleShift = rhsLog2Scale - resultLog2Scale; + if (lhsScaleShift != 0 && rhsScaleShift != 0) { + return failure(); + } + if (lhsScaleShift > 0 || rhsScaleShift > 0) { + return failure(); + } + + // State accessed by the closure. + Operation *mathOp = constInfo.op; + const auto clampMinMax = constInfo.getClampMinMax(); + Value *lhs = constInfo.lhs; + Value *rhs = constInfo.rhs; + Type lhsStorageType = constInfo.lhsStorageType; + Type rhsStorageType = constInfo.rhsStorageType; + + // If the lhs operand is the one requiring a shift, swap it so that the shift + // happens on the rhs operand. + if (lhsScaleShift != 0) { + std::swap(lhs, rhs); + std::swap(lhsStorageType, rhsStorageType); + std::swap(lhsScaleShift, rhsScaleShift); + } + int rhsRightShift = -rhsScaleShift; + + // Cast operands to storage type.
+ Value *lhsStorageValue = + rewriter.create(mathOp->getLoc(), lhsStorageType, lhs) + .getResult(); + Value *rhsStorageValue = + rewriter.create(mathOp->getLoc(), rhsStorageType, rhs) + .getResult(); + + // Rescale the rhs operand if needed. + if (rhsRightShift != 0) { + rhsStorageValue = + rewriter + .create( + mathOp->getLoc(), rhsStorageValue, + IntegerAttr::get(IntegerType::get(32, rewriter.getContext()), + rhsRightShift)) + .getResult(); + } + + // Add. + Value *sumValue = rewriter.create( + mathOp->getLoc(), lhsStorageValue, rhsStorageValue, clampMinMax.first, + clampMinMax.second); + + // Cast back for new result. + rewriter.replaceOpWithNewOp( + mathOp, *mathOp->result_type_begin(), sumValue); + return success(); +} + +namespace { + +struct UniformRealAddEwPattern : public RewritePattern { + UniformRealAddEwPattern(MLIRContext *context) + : RewritePattern(RealAddEwOp::getOperationName(), 1, context) {} + + PatternMatchResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const { + auto addOp = op->cast(); + const RealBinaryOpInfo info(op, addOp.x(), addOp.y(), addOp.clamp_min(), + addOp.clamp_max()); + if (!info.isValid()) { + return matchFailure(); + } + + // Try all of the permutations we support. 
+ if (succeeded(tryRewriteFixedPOTAddEw(info, rewriter))) { + return matchSuccess(); + } + + return matchFailure(); + } +}; + +} // end anonymous namespace + +void LowerUniformRealMathPass::runOnFunction() { + auto &fn = getFunction(); + OwningRewritePatternList patterns; + auto *context = &getContext(); + patterns.push_back(llvm::make_unique(context)); + applyPatternsGreedily(fn, std::move(patterns)); +} + +FunctionPassBase *createLowerUniformRealMathPass() { + return new LowerUniformRealMathPass(); +} + +static PassRegistration + pass("quant-lower-uniform-real-math", + "Lowers uniform-quantized real math ops to integer arithmetic."); diff --git a/mlir/lib/Quantization/Utils/QuantizeUtils.cpp b/mlir/lib/Quantization/Utils/QuantizeUtils.cpp new file mode 100644 index 000000000000..159d6eb77306 --- /dev/null +++ b/mlir/lib/Quantization/Utils/QuantizeUtils.cpp @@ -0,0 +1,186 @@ +//===- QuantizeUtils.cpp - Support utilities for quantization -------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#include "mlir/Quantization/QuantizeUtils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Quantization/UniformSupport.h" + +namespace mlir { +namespace quant { +/// Converts a possible primitive, real expressed value attribute to a +/// corresponding storage attribute (typically FloatAttr -> IntegerAttr). +/// quantizedElementType is the QuantizedType that describes the expressed +/// origValue. +/// Returns a converted Attribute or nullptr if conversion is not possible. +static Attribute convertPrimitiveValueAttr( + Attribute origRealValue, QuantizedType quantizedElementType, + const UniformQuantizedValueConverter &converter, Type &outConvertedType) { + if (origRealValue.isa()) { + FloatAttr floatAttr = origRealValue.cast(); + outConvertedType = quantizedElementType.getStorageType(); + return IntegerAttr::get(quantizedElementType.getStorageType(), + converter.quantizeFloatToInt(floatAttr.getValue())); + } + + return nullptr; +} + +/// Converts a real expressed DenseFPElementsAttr to a corresponding +/// DenseElementsAttr (typically DenseIntElementsAttr) containing quantized +/// storage values assuming the given quantizedElementType and converter. +static DenseElementsAttr +convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr, + QuantizedType quantizedElementType, + const UniformQuantizedValueConverter &converter) { + // Read real expressed values. + SmallVector realValues; + realValues.reserve(realFPElementsAttr.getType().getNumElements()); + realFPElementsAttr.getValues(realValues); + + // Convert to corresponding quantized value attributes. + SmallVector quantValues(realValues.size()); + for (size_t i = 0, e = realValues.size(); i < e; ++i) { + quantValues[i] = converter.quantizeFloatToInt(realValues[i]); + } + + // Cast from an expressed-type-based type to storage-type-based type, + // preserving the dense shape (i.e.
tensor<4xf32> -> tensor<4xi8>). + VectorOrTensorType newDenseType = + quantizedElementType + .castExpressedToStorageType(realFPElementsAttr.getType()) + .dyn_cast_or_null(); + if (!newDenseType) { + return nullptr; + } + return DenseIntElementsAttr::get(newDenseType, quantValues); +} + +/// Converts a real expressed SplatElementsAttr to a corresponding +/// SplatElementsAttr containing quantized storage values assuming the given +/// quantizedElementType and converter. +static SplatElementsAttr +convertSplatElementsAttr(SplatElementsAttr realSplatAttr, + QuantizedType quantizedElementType, + const UniformQuantizedValueConverter &converter) { + // Since the splat just references a single primitive value, use the + // function for converting primitives. + // NOTE: When implementing per-channel, we will need to promote the + // splat to a dense and handle channels individually. + Type unusedPrimitiveType; + auto elementAttr = + convertPrimitiveValueAttr(realSplatAttr.getValue(), quantizedElementType, + converter, unusedPrimitiveType); + if (!elementAttr) { + return nullptr; + } + + // Cast from an expressed-type-based type to storage-type-based type, + // preserving the splat shape (i.e. tensor<4xf32> -> tensor<4xi8>). + VectorOrTensorType newSplatType = + quantizedElementType.castExpressedToStorageType(realSplatAttr.getType()) + .dyn_cast_or_null(); + if (!newSplatType) { + return nullptr; + } + return SplatElementsAttr::get(newSplatType, elementAttr); +} + +/// Converts a real expressed SparseElementsAttr to a corresponding +/// SparseElementsAttr containing quantized storage values assuming the given +/// quantizedElementType and converter.
+static SparseElementsAttr +convertSparseElementsAttr(SparseElementsAttr realSparseAttr, + QuantizedType quantizedElementType, + const UniformQuantizedValueConverter &converter) { + DenseElementsAttr realDenseAttr = realSparseAttr.getValues(); + if (!realDenseAttr.isa()) { + return nullptr; + } + DenseElementsAttr quantDenseAttr = + convertDenseFPElementsAttr(realDenseAttr.cast(), + quantizedElementType, converter); + if (!quantDenseAttr) { + return nullptr; + } + + // Cast from an expressed-type-based type to storage-type-based type, + // preserving the sparse shape (i.e. tensor<4xf32> -> tensor<4xi8>). + VectorOrTensorType newSparseType = + quantizedElementType.castExpressedToStorageType(realSparseAttr.getType()) + .dyn_cast_or_null(); + if (!newSparseType) { + return nullptr; + } + return SparseElementsAttr::get(newSparseType, realSparseAttr.getIndices(), + quantDenseAttr); +} + +/// Converts a real expressed Attribute to a corresponding Attribute containing +/// quantized storage values assuming the given uniform quantizedElementType and +/// converter. +Attribute quantizeAttrUniform(Attribute realValue, + UniformQuantizedType quantizedElementType, + const UniformQuantizedValueConverter &converter, + Type &outConvertedType) { + // Fork to handle different variants of constants supported. + if (realValue.isa()) { + // Splatted tensor or vector constant. + auto converted = convertSplatElementsAttr( + realValue.cast(), quantizedElementType, converter); + outConvertedType = converted.getType(); + return converted; + } else if (realValue.isa()) { + // Dense tensor or vector constant. + auto converted = convertDenseFPElementsAttr( + realValue.cast(), quantizedElementType, converter); + outConvertedType = converted.getType(); + return converted; + } else if (realValue.isa()) { + // Sparse tensor or vector constant. 
+ auto converted = convertSparseElementsAttr( + realValue.cast(), quantizedElementType, converter); + outConvertedType = converted.getType(); + return converted; + } else { + // Nothing else matched: try to convert a primitive. + return convertPrimitiveValueAttr(realValue, quantizedElementType, converter, + outConvertedType); + } +} + +/// Convert an attribute from a type based on +/// quantizedElementType.getExpressedType() to one based on +/// quantizedElementType.getStorageType(). +/// Returns nullptr if the conversion is not supported. +/// On success, stores the converted type in outConvertedType. +Attribute quantizeAttr(Attribute realValue, QuantizedType quantizedElementType, + Type &outConvertedType) { + // Hard-coded to just support UniformQuantizedType. This will need to + // be generalized when there is more than one. + auto uniformQuantizedType = + quantizedElementType.dyn_cast(); + if (!uniformQuantizedType) { + return nullptr; + } + UniformQuantizedValueConverter converter(uniformQuantizedType); + return quantizeAttrUniform(realValue, uniformQuantizedType, converter, + outConvertedType); +} + +} // namespace quant +} // namespace mlir diff --git a/mlir/test/Quantization/convert-const.mlir b/mlir/test/Quantization/convert-const.mlir new file mode 100644 index 000000000000..d0ac5d70beeb --- /dev/null +++ b/mlir/test/Quantization/convert-const.mlir @@ -0,0 +1,140 @@ +// RUN: mlir-opt %s -split-input-file -quant-convert-const | FileCheck %s --dump-input=fail + +// Magic numbers: +// 7.8125e-03 = 1/128 = 2/256 : real range = [-1.0, 0.9921875] (for 8bit, zeroPoint=128) +// 1.250000e-01 = 1/8 = 2/16 : real range = [-1.0, 0.875] (for 4bit, zeroPoint=8) + +// ----- +// Verifies u8 affine quantization on a splat tensor. +// Note that MLIR prints int attributes as signed, so the constant, when +// quantized, is the signed printed version of an unsigned quantity +// (-64 signed == 192 unsigned). 
+// CHECK-LABEL: constant_splat_tensor_u8_affine +func @constant_splat_tensor_u8_affine() -> tensor<4xf32> { + // CHECK: %cst = constant splat, -64> : tensor<4xi8> + // CHECK-NEXT: %0 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">> + %cst = constant splat, 0.5> : tensor<4xf32> + %1 = "quant.qbarrier"(%cst) : (tensor<4xf32>) -> tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">> + %2 = "quant.dbarrier"(%1) : (tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">>) -> (tensor<4xf32>) + return %2 : tensor<4xf32> +} + +// ----- +// Verifies i8 affine quantization on a splat tensor. +// CHECK-LABEL: constant_splat_tensor_i8_affine +func @constant_splat_tensor_i8_affine() -> tensor<4xf32> { + // CHECK: %cst = constant splat, 63> : tensor<4xi8> + // CHECK-NEXT: %0 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03:-1}">> + %cst = constant splat, 0.5> : tensor<4xf32> + %1 = "quant.qbarrier"(%cst) : (tensor<4xf32>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03:-1}">> + %2 = "quant.dbarrier"(%1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03:-1}">>) -> (tensor<4xf32>) + return %2 : tensor<4xf32> +} + +// ----- +// Verifies i8 fixedpoint quantization on a splat tensor. +// CHECK-LABEL: const_splat_tensor_i8_fixedpoint +func @const_splat_tensor_i8_fixedpoint() -> tensor<4xf32> { + // CHECK: %cst = constant splat, 64> : tensor<4xi8> + // CHECK-NEXT: %0 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">> + %cst = constant splat, 0.5> : tensor<4xf32> + %1 = "quant.qbarrier"(%cst) : (tensor<4xf32>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">> + %2 = "quant.dbarrier"(%1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> (tensor<4xf32>) + return %2 : tensor<4xf32> +} + +// ----- +// Verifies i8 fixedpoint quantization on a splat tensor resulting in a negative storage value. 
+// CHECK-LABEL: const_splat_tensor_i8_fixedpoint_neg +func @const_splat_tensor_i8_fixedpoint_neg() -> tensor<4xf32> { + // CHECK: %cst = constant splat, -64> : tensor<4xi8> + %cst = constant splat, -0.5> : tensor<4xf32> + %1 = "quant.qbarrier"(%cst) : (tensor<4xf32>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">> + %2 = "quant.dbarrier"(%1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> (tensor<4xf32>) + return %2 : tensor<4xf32> +} + +// ----- +// Verifies i8 fixedpoint quantization on a dense tensor, sweeping values. +// CHECK-LABEL: const_dense_tensor_i8_fixedpoint +func @const_dense_tensor_i8_fixedpoint() -> tensor<7xf32> { + // CHECK: %cst = constant dense, [-128, -128, -64, 0, 64, 127, 127]> : tensor<7xi8> + %cst = constant dense, [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32> + %1 = "quant.qbarrier"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[i8:f32]{7.812500e-03}">> + %2 = "quant.dbarrier"(%1) : (tensor<7x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> (tensor<7xf32>) + return %2 : tensor<7xf32> +} + +// ----- +// Verifies i8 fixedpoint quantization on a sparse tensor, sweeping values. +// CHECK-LABEL: const_sparse_tensor_i8_fixedpoint +func @const_sparse_tensor_i8_fixedpoint() -> tensor<7x2xf32> { + // NOTE: Ugly regex match pattern for opening "[[" of indices tensor. + // CHECK: %cst = constant sparse, {{\[}}[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6]], [-128, -128, -64, 0, 64, 127, 127]> : tensor<7x2xi8> + %cst = constant sparse, + [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6]], + [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7x2xf32> + %1 = "quant.qbarrier"(%cst) : (tensor<7x2xf32>) -> tensor<7x2x!quant<"uniform[i8:f32]{7.812500e-03}">> + %2 = "quant.dbarrier"(%1) : (tensor<7x2x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> (tensor<7x2xf32>) + return %2 : tensor<7x2xf32> +} + +// ----- +// Verifies i8 fixedpoint quantization on a primitive const. 
+// CHECK-LABEL: const_primitive_float_i8_fixedpoint +func @const_primitive_float_i8_fixedpoint() -> f32 { + // CHECK: %c64_i8 = constant 64 : i8 + // CHECK-NEXT: %0 = "quant.scast"(%c64_i8) : (i8) -> !quant<"uniform[i8:f32]{7.812500e-03}"> + %cst = constant 0.5 : f32 + %1 = "quant.qbarrier"(%cst) : (f32) -> !quant<"uniform[i8:f32]{7.812500e-03}"> + %2 = "quant.dbarrier"(%1) : (!quant<"uniform[i8:f32]{7.812500e-03}">) -> (f32) + return %2 : f32 +} + +// ----- +// Verifies u4 affine quantization on a dense tensor, sweeping values. +// CHECK-LABEL: const_dense_tensor_u4_affine +func @const_dense_tensor_u4_affine() -> tensor<7xf32> { + // NOTE: Unsigned quantities printed by MLIR as signed. + // CHECK: %cst = constant dense, [0, 0, 4, -8, -4, -1, -1]> : tensor<7xi4> + %cst = constant dense, [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32> + %1 = "quant.qbarrier"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[u4:f32]{1.250000e-01:8}">> + %2 = "quant.dbarrier"(%1) : (tensor<7x!quant<"uniform[u4:f32]{1.250000e-01:8}">>) -> (tensor<7xf32>) + return %2 : tensor<7xf32> +} + +// ----- +// Verifies i4 affine quantization on a dense tensor, sweeping values. +// CHECK-LABEL: const_dense_tensor_i4_affine +func @const_dense_tensor_i4_affine() -> tensor<7xf32> { + // NOTE: Unsigned quantities printed by MLIR as signed. + // CHECK: %cst = constant dense, [-8, -8, -5, -1, 3, 7, 7]> : tensor<7xi4> + %cst = constant dense, [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32> + %1 = "quant.qbarrier"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[i4:f32]{1.250000e-01:-1}">> + %2 = "quant.dbarrier"(%1) : (tensor<7x!quant<"uniform[i4:f32]{1.250000e-01:-1}">>) -> (tensor<7xf32>) + return %2 : tensor<7xf32> +} + +// ----- +// Verifies i4 fixed point quantization on a dense tensor, sweeping values. 
+// CHECK-LABEL: const_dense_tensor_i4_fixedpoint +func @const_dense_tensor_i4_fixedpoint() -> tensor<7xf32> { + // CHECK: %cst = constant dense<tensor<7xi4>, [-8, -8, -4, 0, 4, 7, 7]> : tensor<7xi4> + %cst = constant dense<tensor<7xf32>, [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32> + %1 = "quant.qbarrier"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[i4:f32]{1.250000e-01}">> + %2 = "quant.dbarrier"(%1) : (tensor<7x!quant<"uniform[i4:f32]{1.250000e-01}">>) -> (tensor<7xf32>) + return %2 : tensor<7xf32> +} + +// ----- +// Verifies i8 fixedpoint quantization on a dense tensor, sweeping values, and +// custom storage range. (the -128 should be clamped to -100, and the 127 should +// be clamped to 100). +// CHECK-LABEL: const_custom_storage_range_i8_fixedpoint +func @const_custom_storage_range_i8_fixedpoint() -> tensor<7xf32> { + // CHECK: %cst = constant dense<tensor<7xi8>, [-100, -100, -64, 0, 64, 100, 100]> : tensor<7xi8> + %cst = constant dense<tensor<7xf32>, [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32> + %1 = "quant.qbarrier"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[i8(-100:100):f32]{7.812500e-03}">> + %2 = "quant.dbarrier"(%1) : (tensor<7x!quant<"uniform[i8(-100:100):f32]{7.812500e-03}">>) -> (tensor<7xf32>) + return %2 : tensor<7xf32> +} diff --git a/mlir/test/Quantization/lower-uniform-real-math-addew.mlir b/mlir/test/Quantization/lower-uniform-real-math-addew.mlir new file mode 100644 index 000000000000..96c0886ba807 --- /dev/null +++ b/mlir/test/Quantization/lower-uniform-real-math-addew.mlir @@ -0,0 +1,106 @@ +// RUN: mlir-opt %s -split-input-file -quant-lower-uniform-real-math | FileCheck %s --dump-input=fail + +// ----- +// Verify lowering when operands and result have the same fixedpoint pot scale.
+// CHECK-LABEL: real_addew_fixedpoint_same_scale +// CHECK: %0 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8> +// CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8> +// CHECK-NEXT: %2 = "quant.saturating_addi"(%0, %1) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8> +// CHECK-NEXT: %3 = "quant.scast"(%2) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">> +// CHECK-NEXT: return %3 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">> +!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> +!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> +!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> +func @real_addew_fixedpoint_same_scale(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result { + %0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result) + return %0 : !type_result +} + +// ----- +// Verify lowering when the rhs is a shifted pot scale compared to lhs and result. 
+// CHECK-LABEL: real_addew_fixedpoint_rhs_shift +// CHECK: %0 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8> +// CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> tensor<4xi8> +// CHECK-NEXT: %2 = "quant.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8> +// CHECK-NEXT: %3 = "quant.saturating_addi"(%0, %2) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8> +// CHECK-NEXT: %4 = "quant.scast"(%3) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">> +// CHECK-NEXT: return %4 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">> +!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> +!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{7.8125e-03}">> +!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> +func @real_addew_fixedpoint_rhs_shift(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result { + %0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result) + return %0 : !type_result +} + +// ----- +// Verify lowering when the lhs is a shifted pot scale compared to rhs and result.
+// CHECK-LABEL: real_addew_fixedpoint_lhs_shift +// CHECK: %0 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8> +// CHECK-NEXT: %1 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> tensor<4xi8> +// CHECK-NEXT: %2 = "quant.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8> +// CHECK-NEXT: %3 = "quant.saturating_addi"(%0, %2) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8> +// CHECK-NEXT: %4 = "quant.scast"(%3) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">> +// CHECK-NEXT: return %4 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">> +!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{7.8125e-03}">> +!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> +!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> +func @real_addew_fixedpoint_lhs_shift(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result { + %0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result) + return %0 : !type_result +} + +// ----- +// The RHS quant parameters prescribe a range of [-8..8) so an explicit clamp +// of [-4..4] should result in an integral clamp range of [-64..64].
+// CHECK-LABEL: real_addew_fixedpoint_clamp +// CHECK: %0 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8> +// CHECK-NEXT: %1 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> tensor<4xi8> +// CHECK-NEXT: %2 = "quant.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8> +// CHECK-NEXT: %3 = "quant.saturating_addi"(%0, %2) {clamp_max: 64 : i32, clamp_min: -64 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8> +// CHECK-NEXT: %4 = "quant.scast"(%3) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">> +// CHECK-NEXT: return %4 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">> +!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{7.8125e-03}">> +!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> +!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> +func @real_addew_fixedpoint_clamp(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result { + %0 = "quant.real_add_ew"(%arg0, %arg1) { clamp_min:-4.0, clamp_max:4.0 } + : (!type_lhs, !type_rhs) -> (!type_result) + return %0 : !type_result +} + +// ----- +// CHECK-LABEL: real_addew_unquantized_lhs +// Verifies that the op is left as-is for an unquantized lhs. +!type_lhs = type tensor<4xf32> +!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> +!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> +func @real_addew_unquantized_lhs(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result { + // CHECK: %0 = "quant.real_add_ew"(%arg0, %arg1) + %0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result) + return %0 : !type_result +} + +// ----- +// CHECK-LABEL: real_addew_unquantized_rhs +// Verifies that the op is left as-is for an unquantized rhs.
+!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> +!type_rhs = type tensor<4xf32> +!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> +func @real_addew_unquantized_rhs(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result { + // CHECK: %0 = "quant.real_add_ew"(%arg0, %arg1) + %0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result) + return %0 : !type_result +} + +// ----- +// CHECK-LABEL: real_addew_unquantized_result +// Verifies that the op is left as-is for an unquantized result. +!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> +!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> +!type_result = type tensor<4xf32> +func @real_addew_unquantized_result(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result { + // CHECK: %0 = "quant.real_add_ew"(%arg0, %arg1) + %0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result) + return %0 : !type_result +} diff --git a/mlir/test/Quantization/parse-uniform-invalid.mlir b/mlir/test/Quantization/parse-uniform-invalid.mlir new file mode 100644 index 000000000000..91685662eddf --- /dev/null +++ b/mlir/test/Quantization/parse-uniform-invalid.mlir @@ -0,0 +1,137 @@ +// RUN: mlir-opt %s -split-input-file -verify + +// ----- +// Unknown type.
+// expected-error@+1 {{unknown quantized type foobar}} +!qalias = type !quant<"foobar"> + +// ----- +// Unrecognized token: illegal token +// expected-error@+1 {{unrecognized token: %}} +!qalias = type !quant<"%%"> + +// ----- +// Unrecognized token: trailing +// expected-error@+1 {{unrecognized token: 23}} +!qalias = type !quant<"uniform[i8(-4:3):f32]{0.99872:127} 23"> + +// ----- +// Unrecognized token: type open +// expected-error@+1 {{unrecognized token: (}} +!qalias = type !quant<"uniform(i8(-4:3):f32){0.99872:127}"> + +// ----- +// Unrecognized token: missing storage type maximum +// expected-error@+1 {{expected storage type maximum}} +!qalias = type !quant<"uniform[i8(16:f32]{0.99872:127}"> + +// ----- +// Unrecognized token: missing closing paren +// expected-error@+1 {{unrecognized token: :}} +!qalias = type !quant<"uniform[i8(-4:3:f32]{0.99872:127}"> + +// ----- +// Unrecognized token: missing type colon +// expected-error@+1 {{unrecognized token: f}} +!qalias = type !quant<"uniform[i8(-4:3)f32]{0.99872:127}"> + +// ----- +// Unrecognized token: missing closing bracket +// expected-error@+1 {{unrecognized token: {}} +!qalias = type !quant<"uniform[i8(-4:3):f32{0.99872:127}"> + +// ----- +// Unrecognized token: wrong opening brace +// expected-error@+1 {{unrecognized token: (}} +!qalias = type !quant<"uniform[i8(-4:3):f32](0.99872:127}"> + +// ----- +// Unrecognized storage type: illegal prefix +// expected-error@+1 {{illegal storage type prefix: int}} +!qalias = type !quant<"uniform[int8(-4:3):f32]{0.99872:127}"> + +// ----- +// Unrecognized storage type: no width +// expected-error@+1 {{expected storage type width}} +!qalias = type !quant<"uniform[i(-4:3):f32]{0.99872:127}"> + +// ----- +// Unrecognized storage type: storage size > 32 +// expected-error@+1 {{illegal storage type size: 33}} +!qalias = type !quant<"uniform[i33:f32]{0.99872:127}"> + +// ----- +// Unrecognized storage type: storage size < 0 +// expected-error@+1 {{illegal storage type size: 
-1}} +!qalias = type !quant<"uniform[i-1(-4:3):f32]{0.99872:127}"> + +// ----- +// Unrecognized storage type: storage size == 0 +// expected-error@+1 {{illegal storage type size: 0}} +!qalias = type !quant<"uniform[i0(-4:3):f32]{0.99872:127}"> + +// ----- +// Illegal storage min/max: max - min < 0 +// expected-error@+1 {{illegal storage min and storage max: (2:1)}} +!qalias = type !quant<"uniform[i8(2:1):f32]{0.99872:127}"> + +// ----- +// Illegal storage min/max: max - min == 0 +// expected-error@+1 {{illegal storage min and storage max: (1:1)}} +!qalias = type !quant<"uniform[i8(1:1):f32]{0.99872:127}"> + +// ----- +// Illegal storage min/max: max > defaultMax +// expected-error@+1 {{illegal storage type maximum: 9}} +!qalias = type !quant<"uniform[i4(-1:9):f32]{0.99872:127}"> + +// ----- +// Illegal storage min/max: min < defaultMin +// expected-error@+1 {{illegal storage type minimum: -9}} +!qalias = type !quant<"uniform[i4(-9:1):f32]{0.99872:127}"> + +// ----- +// Illegal uniform params: invalid scale +// expected-error@+1 {{expected valid uniform scale. got: abc}} +!qalias = type !quant<"uniform[i8(-4:3):f32]{abc:127}"> + +// ----- +// Illegal uniform params: invalid zero point separator +// expected-error@+1 {{unrecognized token: abc}} +!qalias = type !quant<"uniform[i8(-4:3):f32]{0.1abc}"> + +// ----- +// Illegal uniform params: missing zero point +// expected-error@+1 {{expected integer uniform zero point. got: }}} +!qalias = type !quant<"uniform[i8(-4:3):f32]{0.1:}"> + +// ----- +// Illegal uniform params: invalid zero point +// expected-error@+1 {{expected integer uniform zero point. 
got: abc}} +!qalias = type !quant<"uniform[i8(-4:3):f32]{0.1:abc}"> + +// ----- +// Illegal uniform params: missing closing brace +// expected-error@+1 {{unrecognized token: )}} +!qalias = type !quant<"uniform[i8(-4:3):f32]{0.1:0)"> + +// ----- +// Illegal expressed type: f33 +// expected-error@+1 {{unrecognized expressed type: f33}} +!qalias = type !quant<"uniform[i8(-4:3):f33]{0.99872:127}"> + +// ----- +// Illegal scale: negative +// expected-error@+1 {{illegal scale: -1.000000}} +!qalias = type !quant<"uniform[i8(-4:3):f32]{-1.0:127}"> + +// ----- +// Illegal uniform params: missing quantized dimension +// expected-error@+1 {{expected quantized dimension}} +!qalias = type !quant<"uniform[i8(-4:3):f32:]{2.000000e+02:-19.987200e-01:1}"> + +// ----- +// Illegal uniform params: unspecified quantized dimension, when multiple scales +// provided. +// expected-error@+1 {{multiple scales/zeroPoints provided, but quantizedDimension wasn't specified}} +!qalias = type !quant<"uniform[i8(-4:3):f32]{2.000000e+02,-19.987200e-01:1}"> diff --git a/mlir/test/Quantization/parse-uniform.mlir b/mlir/test/Quantization/parse-uniform.mlir new file mode 100644 index 000000000000..f29a93d9437a --- /dev/null +++ b/mlir/test/Quantization/parse-uniform.mlir @@ -0,0 +1,147 @@ +// RUN: mlir-opt %s -split-input-file | FileCheck %s + +// ----- +// All per-layer params specified: +// [signed] storageType, storageTypeMin, storageTypeMax, expressedType, scale, zeroPoint +// CHECK: !quant<"uniform[i8(-8:7):f32]{9.987200e-01:127}"> +!qalias = type !quant<"uniform[i8(-8:7):f32]{0.99872:127}"> +func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Trailing whitespace. 
+// CHECK: !quant<"uniform[i8(-8:7):f32]{9.987200e-01:127}"> +!qalias = type !quant<"uniform[i8(-8:7):f32]{0.99872:127} "> +func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Required per-layer params specified: +// [unsigned] storageType, expressedType, scale +// CHECK: !quant<"uniform[u8:f32]{9.987200e-01}"> +!qalias = type !quant<"uniform[u8:f32]{0.99872}"> +func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Exponential scale (-) +// CHECK: !quant<"uniform[u8:f32]{2.000000e-02}"> +!qalias = type !quant<"uniform[u8:f32]{2.0e-2}"> +func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Exponential scale (+) +// CHECK: !quant<"uniform[u8:f32]{2.000000e+02}"> +!qalias = type !quant<"uniform[u8:f32]{2.0e+2}"> +func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Storage type: i16 +// CHECK: !quant<"uniform[i16:f32]{2.000000e+02}"> +!qalias = type !quant<"uniform[i16:f32]{2.0e+2}"> +func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Storage type: u16 +// CHECK: !quant<"uniform[u16:f32]{2.000000e+02}"> +!qalias = type !quant<"uniform[u16:f32]{2.0e+2}"> +func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Storage type: i32 +// CHECK: !quant<"uniform[i32:f32]{2.000000e+02}"> +!qalias = type !quant<"uniform[i32:f32]{2.0e+2}"> +func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Storage type: u32 +// CHECK: !quant<"uniform[u32:f32]{2.000000e+02}"> +!qalias = type !quant<"uniform[u32:f32]{2.0e+2}"> +func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Expressed type: f32 +// CHECK: !quant<"uniform[u8:f32]{2.000000e+02}"> +!qalias = type !quant<"uniform[u8:f32]{2.0e+2}"> +func @parse() -> !qalias { + %0 = 
"foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Expressed type: f16 +// CHECK: !quant<"uniform[u8:f16]{2.000000e+02}"> +!qalias = type !quant<"uniform[u8:f16]{2.0e+2}"> +func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Expressed type: f64 +// CHECK: !quant<"uniform[u8:f64]{2.000000e+02}"> +!qalias = type !quant<"uniform[u8:f64]{2.0e+2}"> +func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Expressed type: bf16 +// CHECK: !quant<"uniform[u8:bf16]{2.000000e+02}"> +!qalias = type !quant<"uniform[u8:bf16]{2.0e+2}"> +func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Per-axis scales and zero points (affine) +// CHECK: !quant<"uniform[u8:f32:1]{2.000000e+02:-120,9.987200e-01:127}"> +!qalias = type !quant<"uniform[u8:f32:1]{2.0e+2:-120,0.99872:127}"> +func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Per-axis scales and no zero points (fixedpoint) +// CHECK: !quant<"uniform[i8:f32:1]{2.000000e+02,9.987200e-01}"> +!qalias = type !quant<"uniform[i8:f32:1]{2.0e+2,0.99872}"> +func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Per-axis scales and zero points (mixed affine and fixedpoint) +// CHECK: !quant<"uniform[i8:f32:1]{2.000000e+02,9.987200e-01:120}"> +!qalias = type !quant<"uniform[i8:f32:1]{2.0e+2,0.99872:120}"> +func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} diff --git a/mlir/test/Quantization/tf-lower-fakequant-invalid.mlir b/mlir/test/Quantization/tf-lower-fakequant-invalid.mlir new file mode 100644 index 000000000000..193522a56c49 --- /dev/null +++ b/mlir/test/Quantization/tf-lower-fakequant-invalid.mlir @@ -0,0 +1,36 @@ +// RUN: mlir-opt %s -split-input-file -verify -quant-lower-tf + +// ----- +// TODO(laurenzo): move this test to the TensorFlow/tf-ops-invalid.mlir +// Verify 
that a mismatched range errors. +func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { +^bb0(%arg0: tensor<8x4x3xf32>): + // expected-error@+1 {{op range failed to straddle zero: [1.100000,1.500000]}} + %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) { + min: 1.1, max: 1.5, num_bits: 8 + } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> + return %0 : tensor<8x4x3xf32> +} + +// ----- +// Verify that an invalid range (min > max) errors. +func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { +^bb0(%arg0: tensor<8x4x3xf32>): + // expected-error@+1 {{op range is invalid: [1.100000,1.000000}} + %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) { + min: 1.1, max: 1.0, num_bits: 8 + } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> + return %0 : tensor<8x4x3xf32> +} + +// ----- +// TODO(laurenzo): move this test to the TensorFlow/tf-ops-invalid.mlir +// Unsupported quantizable type (i1 is currently not a supported element type). +func @fakeQuantArgs(tensor<8x4x3xi1>) -> tensor<8x4x3xi1> { +^bb0(%arg0: tensor<8x4x3xi1>): + // expected-error@+1 {{op operand #0 must be tensor of 32-bit float values}} + %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) { + min: 1.1, max: 1.0, num_bits: 8 + } : (tensor<8x4x3xi1>) -> tensor<8x4x3xi1> + return %0 : tensor<8x4x3xi1> +} diff --git a/mlir/test/Quantization/tf-lower-fakequant.mlir b/mlir/test/Quantization/tf-lower-fakequant.mlir new file mode 100644 index 000000000000..a6c572e77901 --- /dev/null +++ b/mlir/test/Quantization/tf-lower-fakequant.mlir @@ -0,0 +1,77 @@ +// RUN: mlir-opt %s -split-input-file -quant-lower-tf | FileCheck %s --dump-input=fail + +// ----- +// Verifies a quint8 asymmetric 0..1 range.
+// CHECK-LABEL: fakeQuantArgs_Quint8_0_1 +func @fakeQuantArgs_Quint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { +^bb0(%arg0: tensor<8x4x3xf32>): + // CHECK: %0 = "quant.qbarrier"(%arg0) : (tensor<8x4x3xf32>) + // CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[u8:f32]{0.0039215686274509803}">> + // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[u8:f32]{0.0039215686274509803}">>) + // CHECK-SAME: -> tensor<8x4x3xf32> + %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) { + min: 0.0, max: 1.0, num_bits: 8 + } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> + return %0 : tensor<8x4x3xf32> +} + +// ----- +// Verifies a quint8 asymmetric 0..1 range (with narrow_range = true). +// CHECK-LABEL: fakeQuantArgs_Quint8_NarrowRange +func @fakeQuantArgs_Quint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { +^bb0(%arg0: tensor<8x4x3xf32>): + // CHECK: %0 = "quant.qbarrier"(%arg0) : (tensor<8x4x3xf32>) + // CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[u8(1:255):f32]{0.003937007874015748:1}">> + // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[u8(1:255):f32]{0.003937007874015748:1}">>) + // CHECK-SAME: -> tensor<8x4x3xf32> + %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) { + min: 0.0, max: 1.0, num_bits: 8, narrow_range: true + } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> + return %0 : tensor<8x4x3xf32> +} + +// ----- +// Verifies a quint8 symmetric range of -1..127/128.
+// CHECK-LABEL: fakeQuantArgs_Quint8_SymmetricRange +func @fakeQuantArgs_Quint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { +^bb0(%arg0: tensor<8x4x3xf32>): + // CHECK: %0 = "quant.qbarrier"(%arg0) : (tensor<8x4x3xf32>) + // CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[u8:f32]{7.812500e-03:128}">> + // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[u8:f32]{7.812500e-03:128}">>) + // CHECK-SAME: -> tensor<8x4x3xf32> + %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) { + min: -1.0, max: 0.9921875, num_bits: 8, narrow_range: false + } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> + return %0 : tensor<8x4x3xf32> +} + +// ----- +// Verifies a commonly used -1..1 symmetric 16-bit range with a zero point of +// 0 and range -1.0 .. 32767/32768. +// CHECK-LABEL: fakeQuantArgs_Qint16_Symmetric +func @fakeQuantArgs_Qint16_Symmetric(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { +^bb0(%arg0: tensor<8x4x3xf32>): + // CHECK: %0 = "quant.qbarrier"(%arg0) : (tensor<8x4x3xf32>) + // CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[i16:f32]{3.05175781185626E-5}">> + // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[i16:f32]{3.05175781185626E-5}">>) + // CHECK-SAME: -> tensor<8x4x3xf32> + %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) { + min: -1.0, max: 0.999969482, num_bits: 16 + } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> + return %0 : tensor<8x4x3xf32> +} + +// ----- +// Verify that lowering to barriers of unranked tensors functions.
+// CHECK-LABEL: fakeQuantArgs_UnrankedTensor +func @fakeQuantArgs_UnrankedTensor(tensor) -> tensor { +^bb0(%arg0: tensor): + // CHECK: %0 = "quant.qbarrier"(%arg0) : (tensor) + // CHECK-SAME: -> tensor> + // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor>) + // CHECK-SAME: -> tensor + %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) { + min: 0.0, max: 1.0, num_bits: 8 + } : (tensor) -> tensor + return %0 : tensor +} diff --git a/mlir/unittests/Quantization/QuantizationUtilsTest.cpp b/mlir/unittests/Quantization/QuantizationUtilsTest.cpp new file mode 100644 index 000000000000..9d30d286936f --- /dev/null +++ b/mlir/unittests/Quantization/QuantizationUtilsTest.cpp @@ -0,0 +1,173 @@ +//===- QuantizationUtilsTest.cpp - unit tests for quantization utils ------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Quantization/QuantizeUtils.h" +#include "mlir/Quantization/UniformSupport.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using namespace mlir; +using namespace mlir::quant; + +namespace { + +// Test UniformQuantizedValueConverter converts all APFloat to a magic number 5. 
+class TestUniformQuantizedValueConverter + : public UniformQuantizedValueConverter { +public: + TestUniformQuantizedValueConverter(UniformQuantizedType type) + : UniformQuantizedValueConverter(type), qtype(type) {} + APInt quantizeFloatToInt(APFloat expressedValue) const { + return APInt(qtype.getStorageType().cast().getWidth(), 5L); + } + +private: + UniformQuantizedType qtype; +}; + +Attribute getTestFloatAttr(double value, MLIRContext *ctx) { + return FloatAttr::get(FloatType::getF32(ctx), value); +} + +template +ConcreteAttrClass getTestElementsAttr(MLIRContext *ctx, ArrayRef shape, + Arg... value) { + auto eleType = FloatType::getF32(ctx); + VectorOrTensorType tensorType; + if (shape.size() == 1 && shape[0] == -1) { + tensorType = UnrankedTensorType::get(eleType); + } else { + tensorType = RankedTensorType::get(shape, eleType); + } + return ConcreteAttrClass::get(tensorType, value...); +} + +ElementsAttr getTestSparseElementsAttr(MLIRContext *ctx, + ArrayRef shape) { + auto eleType = FloatType::getF32(ctx); + VectorOrTensorType tensorType; + if (shape.size() == 1 && shape[0] == -1) { + tensorType = UnrankedTensorType::get(eleType); + } else { + tensorType = RankedTensorType::get(shape, eleType); + } + auto indicesType = RankedTensorType::get({1, 2}, IntegerType::get(64, ctx)); + auto indices = + DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)}); + auto valuesType = RankedTensorType::get({1}, eleType); + auto values = DenseFPElementsAttr::get(valuesType, {APFloat(0.0f)}); + return SparseElementsAttr::get(tensorType, indices, values); +} + +UniformQuantizedType getTestQuantizedType(Type storageType, MLIRContext *ctx) { + return UniformQuantizedType::get(/*flags=*/false, storageType, + FloatType::getF32(ctx), /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/0, + /*storageTypeMax=*/255); +} + +TEST(QuantizationUtilsTest, convertFloatAttrUniform) { + MLIRContext ctx; + IntegerType convertedType = IntegerType::get(8, &ctx); + auto 
quantizedType = getTestQuantizedType(convertedType, &ctx); + TestUniformQuantizedValueConverter converter(quantizedType); + + auto realValue = getTestFloatAttr(1.0, &ctx); + Type typeResult; + auto valueResult = + quantizeAttrUniform(realValue, quantizedType, converter, typeResult); + + EXPECT_EQ(valueResult.cast().getInt(), 5); + EXPECT_EQ( + valueResult.cast().getType().cast().getWidth(), + convertedType.getWidth()); +} + +TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) { + MLIRContext ctx; + IntegerType convertedType = IntegerType::get(8, &ctx); + auto quantizedType = getTestQuantizedType(convertedType, &ctx); + TestUniformQuantizedValueConverter converter(quantizedType); + auto realValue = getTestElementsAttr>( + &ctx, {1, 2}, {getTestFloatAttr(1.0, &ctx), getTestFloatAttr(2.0, &ctx)}); + + Type returnedType; + auto returnedValue = + quantizeAttrUniform(realValue, quantizedType, converter, returnedType); + + // Check Elements attribute shape and kind are not changed. + auto tensorType = returnedType.cast(); + auto expectedTensorType = realValue.getType().cast(); + EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape()); + EXPECT_EQ(tensorType.getElementType(), convertedType); + EXPECT_EQ(returnedValue.getKind(), Attribute::Kind::DenseIntElements); + + // Check Elements attribute element value is expected. + auto firstValue = returnedValue.cast().getValue({0, 0}); + EXPECT_EQ(firstValue.cast().getInt(), 5); +} + +TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) { + MLIRContext ctx; + IntegerType convertedType = IntegerType::get(8, &ctx); + auto quantizedType = getTestQuantizedType(convertedType, &ctx); + TestUniformQuantizedValueConverter converter(quantizedType); + auto realValue = getTestElementsAttr( + &ctx, {1, 2}, getTestFloatAttr(1.0, &ctx)); + + Type returnedType; + auto returnedValue = + quantizeAttrUniform(realValue, quantizedType, converter, returnedType); + + // Check Elements attribute shape and kind are not changed. 
+ auto tensorType = returnedType.cast(); + auto expectedTensorType = realValue.getType().cast(); + EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape()); + EXPECT_EQ(tensorType.getElementType(), convertedType); + EXPECT_EQ(returnedValue.getKind(), Attribute::Kind::SplatElements); + + // Check Elements attribute element value is expected. + auto firstValue = returnedValue.cast().getValue({0, 0}); + EXPECT_EQ(firstValue.cast().getInt(), 5); +} + +TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) { + MLIRContext ctx; + IntegerType convertedType = IntegerType::get(8, &ctx); + auto quantizedType = getTestQuantizedType(convertedType, &ctx); + TestUniformQuantizedValueConverter converter(quantizedType); + auto realValue = getTestSparseElementsAttr(&ctx, {1, 2}); + + Type returnedType; + auto returnedValue = + quantizeAttrUniform(realValue, quantizedType, converter, returnedType); + + // Check Elements attribute shape and kind are not changed. + auto tensorType = returnedType.cast(); + auto expectedTensorType = realValue.getType().cast(); + EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape()); + EXPECT_EQ(tensorType.getElementType(), convertedType); + EXPECT_EQ(returnedValue.getKind(), Attribute::Kind::SparseElements); + + // Check Elements attribute element value is expected. + auto firstValue = returnedValue.cast().getValue({0, 0}); + EXPECT_EQ(firstValue.cast().getInt(), 5); +} + +} // end namespace
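
Reviewer note: the expected constants in the constant-folding tests of this change can be cross-checked by hand. The sketch below is not part of the patch; it is a minimal Python model of uniform quantization, assuming `stored = clamp(round(real / scale) + zero_point, storage_min, storage_max)`. The helper names `quantize_uniform` and `as_signed` are hypothetical, and the rounding mode (half away from zero) is an assumption that the swept test inputs never exercise, since they all divide evenly by the scale.

```python
import math


def quantize_uniform(values, scale, zero_point=0, storage_min=-128, storage_max=127):
    """Map real values to clamped integer storage values (hypothetical helper)."""
    quantized = []
    for real in values:
        scaled = real / scale
        # Assumption: round half away from zero. The sweep values used in the
        # tests land on exact integers, so the rounding mode is not exercised.
        rounded = int(math.copysign(math.floor(abs(scaled) + 0.5), scaled))
        shifted = rounded + zero_point
        quantized.append(max(storage_min, min(storage_max, shifted)))
    return quantized


def as_signed(values, bits):
    """Reinterpret unsigned storage values as two's-complement signed integers."""
    half = 1 << (bits - 1)
    return [v - (1 << bits) if v >= half else v for v in values]


# The sweep used by the dense-tensor tests.
SWEEP = [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]

# uniform[i8:f32]{7.812500e-03}: fixed point, scale = 2^-7.
i8_fixed = quantize_uniform(SWEEP, scale=2.0 ** -7)

# uniform[u4:f32]{1.250000e-01:8}: affine, printed as signed i4 values.
u4_affine = as_signed(
    quantize_uniform(SWEEP, scale=0.125, zero_point=8, storage_min=0, storage_max=15),
    bits=4,
)

# uniform[i4:f32]{1.250000e-01:-1}: affine with a negative zero point.
i4_affine = quantize_uniform(SWEEP, scale=0.125, zero_point=-1, storage_min=-8, storage_max=7)

# uniform[i8(-100:100):f32]{7.812500e-03}: custom storage range.
i8_custom = quantize_uniform(SWEEP, scale=2.0 ** -7, storage_min=-100, storage_max=100)
```

Under these assumptions the four lists match the `CHECK` lines of `const_dense_tensor_i8_fixedpoint`, `const_dense_tensor_u4_affine`, `const_dense_tensor_i4_affine`, and `const_custom_storage_range_i8_fixedpoint` respectively. The `as_signed` step mirrors the note in the u4 test that unsigned quantities are printed as signed.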