From 288bf2b5b95c2084d721461b6207787cd58946df Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Wed, 3 Apr 2019 16:07:37 -0700 Subject: [PATCH] Split the Quantization dialect. - Retains Quantization types and predicates. - Retains utilities and example (testable) passes/ops. - Retains unit tests for example passes/ops. - Moves fixed point ops (and corresponding real ops) to FxpMathOps. - Moves real -> fixed point pass to FxpMathOps. - Sever the dependency on the TF dialect from Quantization. These dialects should now be open-sourcable. -- PiperOrigin-RevId: 241825598 --- mlir/include/mlir/FxpMathOps/FxpMathOps.h | 39 ++++ mlir/include/mlir/FxpMathOps/FxpMathOps.td | 183 +++++++++++++++ mlir/include/mlir/FxpMathOps/Passes.h | 39 ++++ mlir/include/mlir/Quantization/Passes.h | 18 +- mlir/include/mlir/Quantization/QuantOps.td | 213 +++--------------- .../mlir/Quantization/QuantPredicates.td | 72 ++++++ .../lib/FxpMathOps/IR/DialectRegistration.cpp | 24 ++ mlir/lib/FxpMathOps/IR/FxpMathOps.cpp | 38 ++++ .../Transforms/LowerUniformRealMath.cpp | 11 +- .../{LowerTF.cpp => ConvertSimQuant.cpp} | 41 ++-- .../lower-uniform-real-math-addew.mlir | 36 +-- ...id.mlir => convert-fakequant-invalid.mlir} | 14 +- ...-fakequant.mlir => convert-fakequant.mlir} | 12 +- 13 files changed, 482 insertions(+), 258 deletions(-) create mode 100644 mlir/include/mlir/FxpMathOps/FxpMathOps.h create mode 100644 mlir/include/mlir/FxpMathOps/FxpMathOps.td create mode 100644 mlir/include/mlir/FxpMathOps/Passes.h create mode 100644 mlir/include/mlir/Quantization/QuantPredicates.td create mode 100644 mlir/lib/FxpMathOps/IR/DialectRegistration.cpp create mode 100644 mlir/lib/FxpMathOps/IR/FxpMathOps.cpp rename mlir/lib/{Quantization => FxpMathOps}/Transforms/LowerUniformRealMath.cpp (97%) rename mlir/lib/Quantization/Transforms/{LowerTF.cpp => ConvertSimQuant.cpp} (74%) rename mlir/test/{Quantization => FxpMathOps}/lower-uniform-real-math-addew.mlir (72%) rename 
mlir/test/Quantization/{tf-lower-fakequant-invalid.mlir => convert-fakequant-invalid.mlir} (64%) rename mlir/test/Quantization/{tf-lower-fakequant.mlir => convert-fakequant.mlir} (91%) diff --git a/mlir/include/mlir/FxpMathOps/FxpMathOps.h b/mlir/include/mlir/FxpMathOps/FxpMathOps.h new file mode 100644 index 000000000000..c8854c26ce5a --- /dev/null +++ b/mlir/include/mlir/FxpMathOps/FxpMathOps.h @@ -0,0 +1,39 @@ +//===- FxpMathOps/FxpMathOps.h - Fixed point ops ----------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef MLIR_FXPMATH_QUANTOPS_H_ +#define MLIR_FXPMATH_QUANTOPS_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace fxpmath { + +/// Defines the 'FxpMathOps' dialect. 
+class FxpMathOpsDialect : public Dialect { +public: + FxpMathOpsDialect(MLIRContext *context); +}; + +#define GET_OP_CLASSES +#include "mlir/FxpMathOps/FxpMathOps.h.inc" + +} // namespace fxpmath +} // namespace mlir + +#endif // MLIR_FXPMATH_QUANTOPS_H_ diff --git a/mlir/include/mlir/FxpMathOps/FxpMathOps.td b/mlir/include/mlir/FxpMathOps/FxpMathOps.td new file mode 100644 index 000000000000..24d5e6f57c97 --- /dev/null +++ b/mlir/include/mlir/FxpMathOps/FxpMathOps.td @@ -0,0 +1,183 @@ +//===- FxpMathOps.td - Fixed point ops --------------------*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This is the operation definition file for fixed point ops (and real +// equivalents). +// +//===----------------------------------------------------------------------===// + +#ifdef FXPMATH_OPS +#else + +#ifdef OP_BASE +#else +include "mlir/IR/OpBase.td" +include "mlir/Quantization/QuantPredicates.td" +#endif // OP_BASE + +//===----------------------------------------------------------------------===// +// Attributes +//===----------------------------------------------------------------------===// + +// Real value for an (inclusive) min/max clamp limit. +def fxpmath_ClampValueAttr : OptionalAttr; + +// Element-wise activation function to apply. +// Note that RELU activations are not here: they are expressed as clamps. 
+def fxpmath_EwUnaryFnAttr : + StringBasedAttr, "element-wise unary function"> { + let returnType = [{ StringRef }]; + let defaultValue = "IDENTITY"; +} + +class fxpmath_ConstEwUnaryFn : ConstantAttr; +def fxpmath_EwUnaryFn_Identity: fxpmath_ConstEwUnaryFn<"IDENTITY">; +def fxpmath_EwUnaryFn_Tanh : fxpmath_ConstEwUnaryFn<"TANH">; +def fxpmath_EwUnaryFn_Sigmoid : fxpmath_ConstEwUnaryFn<"SIGMOID">; +def fxpmath_EwUnaryFn_Exp : fxpmath_ConstEwUnaryFn<"EXP">; +def fxpmath_EwUnaryFn_Log : fxpmath_ConstEwUnaryFn<"LOG">; +def fxpmath_EwUnaryFn_Neg : fxpmath_ConstEwUnaryFn<"NEG">; +def fxpmath_EwUnaryFn_Rsqrt : fxpmath_ConstEwUnaryFn<"RSQRT">; +def fxpmath_EwUnaryFn_Sin : fxpmath_ConstEwUnaryFn<"SIN">; +def fxpmath_EwUnaryFn_Square : fxpmath_ConstEwUnaryFn<"SQUARE">; +def fxpmath_EwUnaryFn_Sqrt : fxpmath_ConstEwUnaryFn<"SQRT">; +def fxpmath_EwUnaryFn_CmpZ : fxpmath_ConstEwUnaryFn<"CMPZ">; +def fxpmath_EwUnaryFn_CmpNZ : fxpmath_ConstEwUnaryFn<"CMPNZ">; +def fxpmath_EwUnaryFn_CmpLZ : fxpmath_ConstEwUnaryFn<"CMPLZ">; +def fxpmath_EwUnaryFn_CmpGZ : fxpmath_ConstEwUnaryFn<"CMPGZ">; + +//===----------------------------------------------------------------------===// +// Base classes +//===----------------------------------------------------------------------===// + +class fxpmath_Op traits> : + Op; + +//===----------------------------------------------------------------------===// +// Fixed-point (fxp) arithmetic ops used by kernels. +//===----------------------------------------------------------------------===// + +def fxpmath_RoundingDivideByPotFxpOp : + fxpmath_Op<"rounding_divide_by_poti", [NoSideEffect, SameValueType]>, + Arguments<(ins quant_StorageValueType:$x, I32Attr:$exponent)>, + Results<(outs quant_StorageValueType:$y)> { + let description = [{ + Computes integer division by a power-of-two, correctly rounded-to-nearest. + Also known as a rounding arithmetic right shift. See + gemmlowp::RoundingDivideByPOT for a reference implementation. 
+ }]; + + let verifier = [{ + auto verifyExponent = exponent().getSExtValue(); + if (verifyExponent < 0 || verifyExponent > 31) { + return emitOpError("exponent must be in range [0..31]"); + } + return success(); + }]; +} + +def fxpmath_SaturatingAddFxpOp : + fxpmath_Op<"saturating_addi", [NoSideEffect, SameValueType]>, + Arguments<(ins quant_StorageValueType:$x, + quant_StorageValueType:$y, + I32Attr:$clamp_min, + I32Attr:$clamp_max)>, + Results<(outs quant_StorageValueType:$sum)> { + let description = [{ + Computes saturating addition of two operands, saturating to the given min + and max value. The implementation is responsible for choosing an + intermediate register size appropriate to carry out the operation without + overflow. See gemmlowp::SaturatingAdd for a reference implementation. + }]; +} + +//===----------------------------------------------------------------------===// +// Real math ops. +// +// Math ops on real numbers which may have a representation in quantized +// arithmetic. It is expected that eligible ops are lowered from a source +// dialect to this set of ops prior to the process of converting a computation +// to a quantized form. It is a non-goal of these ops to preserve enough +// information to convert back to the higher level, source dialect. +// +// These ops support either real/floating point or QuantizedTypes as operands +// and results. Since not all transformations are supported (globally or +// sometimes for specific targets), a computation may end up with +// untransformable RealMathOps, in which case they need to be lowered as is +// (using floating point math). +// +// This op set takes advantage of the fact that it is typically trivial to +// combine a math function with a compatible bias addition and real-valued +// clamp (which can be done at a higher accumulation bit depth). +// +// In addition, all element-wise unary functions are collapsed into a single +// fxpmath_RealUnaryEwOp and selected via an enum-like attribute. 
Especially at +// low bit depths, this makes matching simpler and allows the construction of +// generic LUT-based implementations. It also allows specific lowering rules +// to consolidate runs of chained unary ops and fuse them to preceding math +// ops, potentially allowing them to operate directly on higher precision +// intermediates without resorting to lots of custom kernels for common +// formulas that can suffer from insufficient precision at low bit depths. +// +// Comparison operators are modeled as element-wise unary functions (i.e. +// CMPZ, CMPNZ, CMPLZ, CMPGZ) intended to follow a sub and output a 1bit +// quantized value. It is expected that lowering rules can fuse them with +// the preceding sub. +//===----------------------------------------------------------------------===// + +class fxpmath_RealMathOp traits = [], dag args> : + fxpmath_Op, + Arguments; + +//===----------------------------------------------------------------------===// +// Element wise binary real math ops. +//===----------------------------------------------------------------------===// + +class fxpmath_RealBinaryOp traits = []> : + fxpmath_RealMathOp, + Results<(outs quant_RealValueType:$r)>; + +class fxpmath_RealBinaryBiasOp traits = []> : + fxpmath_RealMathOp, + Results<(outs quant_RealValueType:$r)>; + +def fxpmath_RealAddEwOp : + fxpmath_RealBinaryOp<"real_add_ew", [NoSideEffect]>; + +def fxpmath_RealSubEwOp : + fxpmath_RealBinaryOp<"real_sub_ew", [NoSideEffect]>; + +def fxpmath_RealMulEwOp : + fxpmath_RealBinaryOp<"real_mul_ew", [NoSideEffect]>; + +def fxpmath_RealDivEwOp : + fxpmath_RealBinaryOp<"real_div_ew", [NoSideEffect]>; + +//===----------------------------------------------------------------------===// +// Element wise unary real math op. 
+//===----------------------------------------------------------------------===// + +def fxpmath_RealUnaryEwOp : + fxpmath_RealMathOp<"real_unary_ew", [NoSideEffect], + (ins quant_RealValueType:$x, fxpmath_EwUnaryFnAttr:$fn)>, + Results<(outs quant_RealValueType:$r)>; + +#endif // FXPMATH_OPS diff --git a/mlir/include/mlir/FxpMathOps/Passes.h b/mlir/include/mlir/FxpMathOps/Passes.h new file mode 100644 index 000000000000..e4df24ffebf9 --- /dev/null +++ b/mlir/include/mlir/FxpMathOps/Passes.h @@ -0,0 +1,39 @@ +//===- Passes.h - Fixed point math passes -----------------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines all of the passes owned by the FxpMathOps dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_FXPMATH_PASSES_H +#define MLIR_FXPMATH_PASSES_H + +namespace mlir { +class FunctionPassBase; + +namespace fxpmath { + +/// Creates a pass that lowers uniform-quantized real math ops to integer +/// arithmetic. This will leave unrecognized real math ops as-is and is +/// typically followed by a pass that lowers any unrecognized ops to a pure +/// floating point form. 
+FunctionPassBase *createLowerUniformRealMathPass(); + +} // namespace fxpmath +} // namespace mlir + +#endif // MLIR_FXPMATH_PASSES_H diff --git a/mlir/include/mlir/Quantization/Passes.h b/mlir/include/mlir/Quantization/Passes.h index 090d21cb2925..03bd986940d8 100644 --- a/mlir/include/mlir/Quantization/Passes.h +++ b/mlir/include/mlir/Quantization/Passes.h @@ -30,15 +30,9 @@ class FunctionPassBase; namespace quant { -/// Creates a pass that lowers quantization related TensorFlow ops into -/// the quantization dialect so that express and implied constraints expressed -/// at the TensorFlow source level can be represented to the quantization -/// system. This will specially handle any TensorFlow op that is useful for -/// guiding quantization. -/// -/// Note that if your intent is to compile a TensorFlow graph for floating -/// point inference, you should probably not use this pass. -FunctionPassBase *createLowerTFPass(); +/// Creates a pass that converts quantization simulation operations (i.e. +/// FakeQuant and those like it) to casts into/out of supported QuantizedTypes. +FunctionPassBase *createConvertSimulatedQuantPass(); /// Creates a pass that converts constants followed by a qbarrier to a /// constant whose value is quantized. This is typically one of the last @@ -47,12 +41,6 @@ FunctionPassBase *createLowerTFPass(); /// destructive and cannot be undone. FunctionPassBase *createConvertConstPass(); -/// Creates a pass that lowers uniform-quantized real math ops to integer -/// arithmetic. This will leave unrecognized real math ops as-is and is -/// typically followed by a pass that lowers any unrecognized ops to a pure -/// floating point form. 
-FunctionPassBase *createLowerUniformRealMathPass(); - } // namespace quant } // namespace mlir diff --git a/mlir/include/mlir/Quantization/QuantOps.td b/mlir/include/mlir/Quantization/QuantOps.td index 8c247a3c5911..09a1a65933f3 100644 --- a/mlir/include/mlir/Quantization/QuantOps.td +++ b/mlir/include/mlir/Quantization/QuantOps.td @@ -25,86 +25,9 @@ #ifdef OP_BASE #else include "mlir/IR/OpBase.td" +include "mlir/Quantization/QuantPredicates.td" #endif // OP_BASE -//===----------------------------------------------------------------------===// -// Quantization type definitions -//===----------------------------------------------------------------------===// - -class quant_TypedPrimitiveOrContainer : - Type.predicate, - TypedVector.predicate]>, - "primitive/tensor/vector of " # etype.description>; - -// An implementation of QuantizedType. -def quant_QuantizedType : - Type()">, "QuantizedType">; - -// A primitive type that can represent a real value. This is either a -// floating point value or a quantized type. -def quant_RealPrimitiveType : - Type, - "real valued primitive (float or quantized type)">; - -// A primitive type that can represent a storage value. This is either an -// integer or quantized type. -def quant_StoragePrimitiveType : - Type, - "quantized storage primitive (integer or quantized type)">; - -// A primitive or container of RealPrimitiveType. -def quant_RealValueType : - quant_TypedPrimitiveOrContainer; - -// A primitive or container of StoragePrimitiveType. -def quant_StorageValueType : - quant_TypedPrimitiveOrContainer; - -// Either a real valued or storage primitive or container type. -def quant_RealOrStorageValueType : - Type>; - -// An implementation of UniformQuantizedType. -def quant_UniformQuantizedType : - Type()">, "UniformQuantizedType">; - -// Predicate for detecting a container or primitive of UniformQuantizedType. 
-def quant_UniformQuantizedValueType : - quant_TypedPrimitiveOrContainer; - -//===----------------------------------------------------------------------===// -// Attributes -//===----------------------------------------------------------------------===// - -// Real value for an (inclusive) min/max clamp limit. -def quant_ClampValueAttr : OptionalAttr; - -// Element-wise activation function to apply. -// Note that RELU activations are not here: they are expressed as clamps. -def quant_EwUnaryFnAttr : - StringBasedAttr, "element-wise unary function"> { - let returnType = [{ StringRef }]; - let defaultValue = "IDENTITY"; -} - -class quant_ConstEwUnaryFn : ConstantAttr; -def quant_EwUnaryFn_Identity: quant_ConstEwUnaryFn<"IDENTITY">; -def quant_EwUnaryFn_Tanh : quant_ConstEwUnaryFn<"TANH">; -def quant_EwUnaryFn_Sigmoid : quant_ConstEwUnaryFn<"SIGMOID">; -def quant_EwUnaryFn_Exp : quant_ConstEwUnaryFn<"EXP">; -def quant_EwUnaryFn_Log : quant_ConstEwUnaryFn<"LOG">; -def quant_EwUnaryFn_Neg : quant_ConstEwUnaryFn<"NEG">; -def quant_EwUnaryFn_Rsqrt : quant_ConstEwUnaryFn<"RSQRT">; -def quant_EwUnaryFn_Sin : quant_ConstEwUnaryFn<"SIN">; -def quant_EwUnaryFn_Square : quant_ConstEwUnaryFn<"SQUARE">; -def quant_EwUnaryFn_Sqrt : quant_ConstEwUnaryFn<"SQRT">; -def quant_EwUnaryFn_CmpZ : quant_ConstEwUnaryFn<"CMPZ">; -def quant_EwUnaryFn_CmpNZ : quant_ConstEwUnaryFn<"CMPNZ">; -def quant_EwUnaryFn_CmpLZ : quant_ConstEwUnaryFn<"CMPLZ">; -def quant_EwUnaryFn_CmpGZ : quant_ConstEwUnaryFn<"CMPGZ">; - //===----------------------------------------------------------------------===// // Base classes //===----------------------------------------------------------------------===// @@ -170,116 +93,34 @@ def quant_StorageCastOp : Results<(outs quant_RealOrStorageValueType)>; //===----------------------------------------------------------------------===// -// Integral arithmetic ops used by kernels. 
+// Training integration ops //===----------------------------------------------------------------------===// -def quant_RoundingDivideByPotIOp : - quant_Op<"rounding_divide_by_poti", [NoSideEffect, SameValueType]>, - Arguments<(ins quant_StorageValueType:$x, I32Attr:$exponent)>, - Results<(outs quant_StorageValueType:$y)> { - let description = [{ - Computes integer division by a power-of-two, correctly rounded-to-nearest. - Also known as a rounding arithmetic right shift. See - gemmlowp::RoundingDivideByPOT for a reference implementation. - }]; +def quant_ConstFakeQuant : quant_Op<"const_fake_quant", + [NoSideEffect]> { + let summary = + "Simulates the effect of uniform quantization with const range."; - let verifier = [{ - auto verifyExponent = exponent().getSExtValue(); - if (verifyExponent < 0 || verifyExponent > 31) { - return emitOpError("exponent must be in range [0..31]"); - } - return success(); - }]; + let description = [{ +Given a const min, max, num_bits and narrow_range attribute, applies the same +uniform quantization simulation as is done by the TensorFlow +fake_quant_with_min_max_args op. See the fakeQuantAttrsToType() utility +method and the quant-convert-simulated-quantization pass for further details. +}]; + + let arguments = (ins + F32Tensor:$inputs, + F32Attr:$min, + F32Attr:$max, + // The bitwidth of the quantization; between 2 and 16, inclusive. + I64Attr:$num_bits, + // Quantization range starts from 0 or 1; starts from 1 if true. + DefaultValuedAttr:$narrow_range + ); + + let results = (outs + F32Tensor:$outputs + ); } -def quant_SaturatingAddIOp : - quant_Op<"saturating_addi", [NoSideEffect, SameValueType]>, - Arguments<(ins quant_StorageValueType:$x, - quant_StorageValueType:$y, - I32Attr:$clamp_min, - I32Attr:$clamp_max)>, - Results<(outs quant_StorageValueType:$sum)> { - let description = [{ - Computes saturating addition of two operands, saturating to the given min - and max value. 
The implementation is responsible for choosing an - intermediate register size appropriate to carry out the operation without - overflow. See gemmlowp::SaturatingAdd for a reference implementation. - }]; -} - -//===----------------------------------------------------------------------===// -// Real math ops. -// -// Math ops on real numbers which may have a representation in quantized -// arithmetic. It is expected that eligible ops are lowered from a source -// dialect to this set of ops prior to the process of converting a compuation -// to a quantized form. It is a non-goal of these ops to preserve enough -// information to convert back to the higher level, source dialect. -// -// These ops support either real/floating point or QuantizedTypes as operands -// and results. Since not all transformations are supported (globally or -// sometimes for specific targets), a computation may end up with -// untransformable RealMathOps, in which case they need to be lowered as is -// (using floating point math). -// -// This op set takes advantage of the fact that it is typically trivial to -// combine a math function with a compatible bias addition and real-valued -// clamp (which can be done at a higher accumulation bit depth). -// -// In addition, all element-wise unary functions are collapsed into a single -// quant_RealUnaryEwOp and selected via an enum-like attribute. Especially at -// low bit depths, this makes matching simpler and allows the construction of -// generic LUT-based implementations. It also allows specific lowering rules -// to consolidate runs of chained unary ops and fuse them to preceding math -// ops, potentially allowing them to operate directly on higher precision -// intermediates without resorting to lots of custom kernels for common -// formulas that can suffer from insufficient precision at low bit depths. -// -// Comparison operators are modeled as element-wise unary functions (i.e. 
-// CMPZ, CMPNZ, CMPLZ, CMPGZ) intended to follow a sub and output a 1bit -// quantized value. It is expected that lowering rules can fuse them with -// the preceding sub. -//===----------------------------------------------------------------------===// - -class quant_RealMathOp traits = [], dag args> : - quant_Op, - Arguments; - -//===----------------------------------------------------------------------===// -// Element wise binary real math ops. -//===----------------------------------------------------------------------===// - -class quant_RealBinaryOp traits = []> : - quant_RealMathOp, - Results<(outs quant_RealValueType:$r)>; - -class quant_RealBinaryBiasOp traits = []> : - quant_RealMathOp, - Results<(outs quant_RealValueType:$r)>; - -def quant_RealAddEwOp : - quant_RealBinaryOp<"real_add_ew", [NoSideEffect]>; - -def quant_RealSubEwOp : - quant_RealBinaryOp<"real_sub_ew", [NoSideEffect]>; - -def quant_RealMulEwOp : - quant_RealBinaryOp<"real_mul_ew", [NoSideEffect]>; - -def quant_RealDivEwOp : - quant_RealBinaryOp<"real_div_ew", [NoSideEffect]>; - -//===----------------------------------------------------------------------===// -// Element wise unary real math op. -//===----------------------------------------------------------------------===// - -def quant_RealUnaryEwOp : - quant_RealMathOp<"real_unary_ew", [NoSideEffect], - (ins quant_RealValueType:$x, quant_EwUnaryFnAttr:$fn)>, - Results<(outs quant_RealValueType:$r)>; - -#endif // QUANTIZATION_OPS +#endif // QUANT_OPS diff --git a/mlir/include/mlir/Quantization/QuantPredicates.td b/mlir/include/mlir/Quantization/QuantPredicates.td new file mode 100644 index 000000000000..62a1e50568bf --- /dev/null +++ b/mlir/include/mlir/Quantization/QuantPredicates.td @@ -0,0 +1,72 @@ +//===- QuantPredicates.td - Predicates for dialect types ---*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// Predicates for types in the Quantization dialect. +// +//===----------------------------------------------------------------------===// + +#ifdef QUANTIZATION_PREDICATES_ +#else + +//===----------------------------------------------------------------------===// +// Quantization type definitions +//===----------------------------------------------------------------------===// + +class quant_TypedPrimitiveOrContainer : + Type.predicate, + TypedVector.predicate]>, + "primitive/tensor/vector of " # etype.description>; + +// An implementation of QuantizedType. +def quant_QuantizedType : + Type()">, "QuantizedType">; + +// A primitive type that can represent a real value. This is either a +// floating point value or a quantized type. +def quant_RealPrimitiveType : + Type, + "real valued primitive (float or quantized type)">; + +// A primitive type that can represent a storage value. This is either an +// integer or quantized type. +def quant_StoragePrimitiveType : + Type, + "quantized storage primitive (integer or quantized type)">; + +// A primitive or container of RealPrimitiveType. +def quant_RealValueType : + quant_TypedPrimitiveOrContainer; + +// A primitive or container of StoragePrimitiveType. 
+def quant_StorageValueType : + quant_TypedPrimitiveOrContainer; + +// Either a real valued or storage primitive or container type. +def quant_RealOrStorageValueType : + Type>; + +// An implementation of UniformQuantizedType. +def quant_UniformQuantizedType : + Type()">, "UniformQuantizedType">; + +// Predicate for detecting a container or primitive of UniformQuantizedType. +def quant_UniformQuantizedValueType : + quant_TypedPrimitiveOrContainer; + +#endif // QUANTIZATION_PREDICATES_ \ No newline at end of file diff --git a/mlir/lib/FxpMathOps/IR/DialectRegistration.cpp b/mlir/lib/FxpMathOps/IR/DialectRegistration.cpp new file mode 100644 index 000000000000..24e666860b48 --- /dev/null +++ b/mlir/lib/FxpMathOps/IR/DialectRegistration.cpp @@ -0,0 +1,24 @@ +//===- DialectRegistration.cpp - Register FxpMathOps dialect --------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/FxpMathOps/FxpMathOps.h" + +using namespace mlir; +using namespace mlir::fxpmath; + +// Static initialization for the fxpmath ops dialect registration. 
+static mlir::DialectRegistration FxpMathOps; diff --git a/mlir/lib/FxpMathOps/IR/FxpMathOps.cpp b/mlir/lib/FxpMathOps/IR/FxpMathOps.cpp new file mode 100644 index 000000000000..f276685b7868 --- /dev/null +++ b/mlir/lib/FxpMathOps/IR/FxpMathOps.cpp @@ -0,0 +1,38 @@ +//===- FxpMathOps.cpp - Op implementation for FxpMathOps ------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/FxpMathOps/FxpMathOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Quantization/QuantOps.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/MathExtras.h" + +using namespace mlir; +using namespace mlir::fxpmath; + +#define GET_OP_CLASSES +#include "mlir/FxpMathOps/FxpMathOps.cpp.inc" + +FxpMathOpsDialect::FxpMathOpsDialect(MLIRContext *context) + : Dialect(/*name=*/"fxpmath", context) { + addOperations< +#define GET_OP_LIST +#include "mlir/FxpMathOps/FxpMathOps.cpp.inc" + >(); +} diff --git a/mlir/lib/Quantization/Transforms/LowerUniformRealMath.cpp b/mlir/lib/FxpMathOps/Transforms/LowerUniformRealMath.cpp similarity index 97% rename from mlir/lib/Quantization/Transforms/LowerUniformRealMath.cpp rename to mlir/lib/FxpMathOps/Transforms/LowerUniformRealMath.cpp index 9ce926408bb9..a5ad64233c38 100644 --- 
a/mlir/lib/Quantization/Transforms/LowerUniformRealMath.cpp +++ b/mlir/lib/FxpMathOps/Transforms/LowerUniformRealMath.cpp @@ -15,15 +15,16 @@ // limitations under the License. // ============================================================================= +#include "mlir/FxpMathOps/FxpMathOps.h" +#include "mlir/FxpMathOps/Passes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" -#include "mlir/Quantization/Passes.h" -#include "mlir/Quantization/QuantOps.h" #include "mlir/Quantization/UniformSupport.h" #include using namespace mlir; +using namespace mlir::fxpmath; using namespace mlir::quant; namespace { @@ -198,7 +199,7 @@ static LogicalResult tryRewriteFixedPOTAddEw(const RealBinaryOpInfo &constInfo, if (rhsRightShift != 0) { rhsStorageValue = rewriter - .create( + .create( mathOp->getLoc(), rhsStorageValue, IntegerAttr::get(IntegerType::get(32, rewriter.getContext()), rhsRightShift)) @@ -206,7 +207,7 @@ static LogicalResult tryRewriteFixedPOTAddEw(const RealBinaryOpInfo &constInfo, } // Add. 
- Value *sumValue = rewriter.create( + Value *sumValue = rewriter.create( mathOp->getLoc(), lhsStorageValue, rhsStorageValue, clampMinMax.first, clampMinMax.second); @@ -255,5 +256,5 @@ FunctionPassBase *createLowerUniformRealMathPass() { } static PassRegistration - pass("quant-lower-uniform-real-math", + pass("fxpmath-lower-uniform-real-math", "Lowers uniform-quantized real math ops to integer arithmetic."); diff --git a/mlir/lib/Quantization/Transforms/LowerTF.cpp b/mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp similarity index 74% rename from mlir/lib/Quantization/Transforms/LowerTF.cpp rename to mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp index 24a35c9f3179..a1c9568e4224 100644 --- a/mlir/lib/Quantization/Transforms/LowerTF.cpp +++ b/mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp @@ -1,4 +1,4 @@ -//===- LowerTF.cpp - Passes for lowering from TensorFlow ------------------===// +//===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===// // // Copyright 2019 The MLIR Authors. // @@ -23,44 +23,42 @@ #include "mlir/Quantization/Passes.h" #include "mlir/Quantization/QuantOps.h" #include "mlir/Quantization/UniformSupport.h" -#include "mlir/TensorFlow/TFOps.h" using namespace mlir; using namespace mlir::quant; namespace { -class LowerTFPass : public FunctionPass { +class ConvertSimulatedQuantPass + : public FunctionPass { public: void runOnFunction() override; }; } // end anonymous namespace -/// Rewrites TensorFlow FakeQuantWithMinMaxArgs into a qbarrier/dbarrier pair. -class FakeQuantWithMinMaxArgsRewrite : public RewritePattern { +/// Rewrites ConstFakeQuant into a qbarrier/dbarrier pair. 
+class ConstFakeQuantRewrite : public RewritePattern { public: bool *hadFailure; - FakeQuantWithMinMaxArgsRewrite(MLIRContext *context, bool *hadFailure) - : RewritePattern(TF::FakeQuantWithMinMaxArgsOp::getOperationName(), 1, - context), + ConstFakeQuantRewrite(MLIRContext *context, bool *hadFailure) + : RewritePattern(ConstFakeQuant::getOperationName(), 1, context), hadFailure(hadFailure) {} - PatternMatchResult match(Operation *op) const override { - return matchSuccess(); - } - - void rewrite(Operation *op, PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { // TODO: If this pattern comes up more frequently, consider adding core // support for failable rewrites. if (failableRewrite(op, rewriter)) { *hadFailure = true; } + + return matchSuccess(); } bool failableRewrite(Operation *op, PatternRewriter &rewriter) const { - auto fqOp = op->template cast(); + auto fqOp = op->cast(); auto converter = ExpressedToUniformQuantizedConverter::forInputType(fqOp.getType()); @@ -93,20 +91,23 @@ public: } }; -void LowerTFPass::runOnFunction() { +void ConvertSimulatedQuantPass::runOnFunction() { bool hadFailure = false; OwningRewritePatternList patterns; auto &func = getFunction(); auto *context = &getContext(); patterns.push_back( - llvm::make_unique(context, &hadFailure)); + llvm::make_unique(context, &hadFailure)); applyPatternsGreedily(func, std::move(patterns)); if (hadFailure) signalPassFailure(); } -FunctionPassBase *createLowerTFPass() { return new LowerTFPass(); } +FunctionPassBase *createConvertSimulatedQuantPass() { + return new ConvertSimulatedQuantPass(); +} -static PassRegistration - pass("quant-lower-tf", - "Lowers TensorFlow constraint ops to the quantization dialect"); +static PassRegistration + pass("quant-convert-simulated-quantization", + "Converts training-time simulated quantization ops to corresponding " + "quantize/dequantize casts."); diff --git 
a/mlir/test/Quantization/lower-uniform-real-math-addew.mlir b/mlir/test/FxpMathOps/lower-uniform-real-math-addew.mlir similarity index 72% rename from mlir/test/Quantization/lower-uniform-real-math-addew.mlir rename to mlir/test/FxpMathOps/lower-uniform-real-math-addew.mlir index 96c0886ba807..29783f8500da 100644 --- a/mlir/test/Quantization/lower-uniform-real-math-addew.mlir +++ b/mlir/test/FxpMathOps/lower-uniform-real-math-addew.mlir @@ -1,18 +1,18 @@ -// RUN: mlir-opt %s -split-input-file -quant-lower-uniform-real-math | FileCheck %s --dump-input=fail +// RUN: mlir-opt %s -split-input-file -fxpmath-lower-uniform-real-math | FileCheck %s --dump-input=fail // ----- // Verify lowering when operands and result have the same fixedpoint pot scale. // CHECK-LABEL: real_addew_fixedpoint_same_scale // CHECK: %0 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8> // CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8> -// CHECK-NEXT: %2 = "quant.saturating_addi"(%0, %1) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8> +// CHECK-NEXT: %2 = "fxpmath.saturating_addi"(%0, %1) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8> // CHECK-NEXT: %3 = "quant.scast"(%2) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">> // CHECK-NEXT: return %3 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">> !type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> !type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> !type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> func @real_addew_fixedpoint_same_scale(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result { - %0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result) + %0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result) return %0 : !type_result } @@ -21,15 +21,15 
@@ func @real_addew_fixedpoint_same_scale(%arg0 : !type_lhs, %arg1: !type_rhs) -> ! // CHECK-LABEL: real_addew_fixedpoint_rhs_shift // CHECK: %0 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8> // CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> tensor<4xi8> -// CHECK-NEXT: %2 = "quant.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8> -// CHECK-NEXT: %3 = "quant.saturating_addi"(%0, %2) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8> +// CHECK-NEXT: %2 = "fxpmath.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8> +// CHECK-NEXT: %3 = "fxpmath.saturating_addi"(%0, %2) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8> // CHECK-NEXT: %4 = "quant.scast"(%3) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">> // CHECK-NEXT: return %4 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">> !type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> !type_rhs = type tensor<4x!quant<"uniform[i8:f32]{7.8125e-03}">> !type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> func @real_addew_fixedpoint_rhs_shift(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result { - %0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result) + %0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result) return %0 : !type_result } @@ -38,15 +38,15 @@ func @real_addew_fixedpoint_rhs_shift(%arg0 : !type_lhs, %arg1: !type_rhs) -> !t // CHECK-LABEL: real_addew_fixedpoint_lhs_shift // CHECK: %0 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8> // CHECK-NEXT: %1 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> tensor<4xi8> -// CHECK-NEXT: %2 = "quant.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> 
tensor<4xi8> -// CHECK-NEXT: %3 = "quant.saturating_addi"(%0, %2) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8> +// CHECK-NEXT: %2 = "fxpmath.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8> +// CHECK-NEXT: %3 = "fxpmath.saturating_addi"(%0, %2) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8> // CHECK-NEXT: %4 = "quant.scast"(%3) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">> // CHECK-NEXT: return %4 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">> !type_lhs = type tensor<4x!quant<"uniform[i8:f32]{7.8125e-03}">> !type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> !type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> func @real_addew_fixedpoint_lhs_shift(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result { - %0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result) + %0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result) return %0 : !type_result } @@ -56,15 +56,15 @@ func @real_addew_fixedpoint_lhs_shift(%arg0 : !type_lhs, %arg1: !type_rhs) -> !t // CHECK-LABEL: real_addew_fixedpoint_clamp // CHECK: %0 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8> // CHECK-NEXT: %1 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> tensor<4xi8> -// CHECK-NEXT: %2 = "quant.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8> -// CHECK-NEXT: %3 = "quant.saturating_addi"(%0, %2) {clamp_max: 64 : i32, clamp_min: -64 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8> +// CHECK-NEXT: %2 = "fxpmath.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8> +// CHECK-NEXT: %3 = "fxpmath.saturating_addi"(%0, %2) {clamp_max: 64 : i32, clamp_min: -64 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8> // CHECK-NEXT: %4 = "quant.scast"(%3) : 
(tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">> // CHECK-NEXT: return %4 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">> !type_lhs = type tensor<4x!quant<"uniform[i8:f32]{7.8125e-03}">> !type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> !type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> func @real_addew_fixedpoint_clamp(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result { - %0 = "quant.real_add_ew"(%arg0, %arg1) { clamp_min:-4.0, clamp_max:4.0 } + %0 = "fxpmath.real_add_ew"(%arg0, %arg1) { clamp_min:-4.0, clamp_max:4.0 } : (!type_lhs, !type_rhs) -> (!type_result) return %0 : !type_result } @@ -76,8 +76,8 @@ func @real_addew_fixedpoint_clamp(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_ !type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> !type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> func @real_addew_unquantized_lhs(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result { - // CHECK: %0 = "quant.real_add_ew"(%arg0, %arg1) - %0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result) + // CHECK: %0 = "fxpmath.real_add_ew"(%arg0, %arg1) + %0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result) return %0 : !type_result } @@ -88,8 +88,8 @@ func @real_addew_unquantized_lhs(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_r !type_rhs = type tensor<4xf32> !type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> func @real_addew_unquantized_rhs(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result { - // CHECK: %0 = "quant.real_add_ew"(%arg0, %arg1) - %0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result) + // CHECK: %0 = "fxpmath.real_add_ew"(%arg0, %arg1) + %0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result) return %0 : !type_result } @@ -100,7 +100,7 @@ func @real_addew_unquantized_rhs(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_r !type_rhs = type 
tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> !type_result = type tensor<4xf32> func @real_addew_unquantized_result(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result { - // CHECK: %0 = "quant.real_add_ew"(%arg0, %arg1) - %0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result) + // CHECK: %0 = "fxpmath.real_add_ew"(%arg0, %arg1) + %0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result) return %0 : !type_result } diff --git a/mlir/test/Quantization/tf-lower-fakequant-invalid.mlir b/mlir/test/Quantization/convert-fakequant-invalid.mlir similarity index 64% rename from mlir/test/Quantization/tf-lower-fakequant-invalid.mlir rename to mlir/test/Quantization/convert-fakequant-invalid.mlir index 193522a56c49..bdaab47ab630 100644 --- a/mlir/test/Quantization/tf-lower-fakequant-invalid.mlir +++ b/mlir/test/Quantization/convert-fakequant-invalid.mlir @@ -1,12 +1,11 @@ -// RUN: mlir-opt %s -split-input-file -verify -quant-lower-tf +// RUN: mlir-opt %s -split-input-file -verify -quant-convert-simulated-quantization // ----- -// TODO(laurenzo): move this test to the TensorFlow/tf-ops-invalid.mlir // Verify that a mismatched range errors. func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { ^bb0(%arg0: tensor<8x4x3xf32>): - // expected-error@+1 {{op range failed to straddle zero: [1.100000,1.500000]}} - %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) { + // expected-error@+1 {{FakeQuant range must straddle zero: [1.100000,1.500000]}} + %0 = "quant.const_fake_quant"(%arg0) { min: 1.1, max: 1.5, num_bits: 8 } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> return %0 : tensor<8x4x3xf32> @@ -16,20 +15,19 @@ func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { // Verify that a valid range errors. 
func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { ^bb0(%arg0: tensor<8x4x3xf32>): - // expected-error@+1 {{op range is invalid: [1.100000,1.000000}} - %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) { + // expected-error@+1 {{FakeQuant range must straddle zero: [1.100000,1.000000}} + %0 = "quant.const_fake_quant"(%arg0) { min: 1.1, max: 1.0, num_bits: 8 } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> return %0 : tensor<8x4x3xf32> } // ----- -// TODO(laurenzo): move this test to the TensorFlow/tf-ops-invalid.mlir // Unsupported quantizable type (i1 is currently not a supported element type). func @fakeQuantArgs(tensor<8x4x3xi1>) -> tensor<8x4x3xi1> { ^bb0(%arg0: tensor<8x4x3xi1>): // expected-error@+1 {{op operand #0 must be tensor of 32-bit float values}} - %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) { + %0 = "quant.const_fake_quant"(%arg0) { min: 1.1, max: 1.0, num_bits: 8 } : (tensor<8x4x3xi1>) -> tensor<8x4x3xi1> return %0 : tensor<8x4x3xi1> diff --git a/mlir/test/Quantization/tf-lower-fakequant.mlir b/mlir/test/Quantization/convert-fakequant.mlir similarity index 91% rename from mlir/test/Quantization/tf-lower-fakequant.mlir rename to mlir/test/Quantization/convert-fakequant.mlir index a6c572e77901..fcfa18e832fa 100644 --- a/mlir/test/Quantization/tf-lower-fakequant.mlir +++ b/mlir/test/Quantization/convert-fakequant.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -split-input-file -quant-lower-tf | FileCheck %s --dump-input=fail +// RUN: mlir-opt %s -split-input-file -quant-convert-simulated-quantization | FileCheck %s --dump-input=fail // ----- // Verifies a quint8 asymmetric 0..1 range. 
@@ -9,7 +9,7 @@ func @fakeQuantArgs_Quint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { // CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[u8:f32]{0.0039215686274509803}">> // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[u8:f32]{0.0039215686274509803}">>) // CHECK-SAME: -> tensor<8x4x3xf32> - %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) { + %0 = "quant.const_fake_quant"(%arg0) { min: 0.0, max: 1.0, num_bits: 8 } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> return %0 : tensor<8x4x3xf32> @@ -24,7 +24,7 @@ func @fakeQuantArgs_Quint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { // CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[u8(1:255):f32]{0.003937007874015748:1}">> // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[u8(1:255):f32]{0.003937007874015748:1}">>) // CHECK-SAME: -> tensor<8x4x3xf32> - %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) { + %0 = "quant.const_fake_quant"(%arg0) { min: 0.0, max: 1.0, num_bits: 8, narrow_range: true } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> return %0 : tensor<8x4x3xf32> @@ -39,7 +39,7 @@ func @fakeQuantArgs_Quint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32 // CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[u8:f32]{7.812500e-03:128}">> // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[u8:f32]{7.812500e-03:128}">>) // CHECK-SAME: -> tensor<8x4x3xf32> - %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) { + %0 = "quant.const_fake_quant"(%arg0) { min: -1.0, max: 0.9921875, num_bits: 8, narrow_range: false } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> return %0 : tensor<8x4x3xf32> @@ -55,7 +55,7 @@ func @fakeQuantArgs_Qint16_Symmetric(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { // CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[i16:f32]{3.05175781185626E-5}">> // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[i16:f32]{3.05175781185626E-5}">>) // CHECK-SAME: -> tensor<8x4x3xf32> - %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) { + %0 = "quant.const_fake_quant"(%arg0) 
{ min: -1.0, max: 0.999969482, num_bits: 16 } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> return %0 : tensor<8x4x3xf32> @@ -70,7 +70,7 @@ func @fakeQuantArgs_UnrankedTensor(tensor) -> tensor { // CHECK-SAME: -> tensor> // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor>) // CHECK-SAME: -> tensor - %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) { + %0 = "quant.const_fake_quant"(%arg0) { min: 0.0, max: 1.0, num_bits: 8 } : (tensor) -> tensor return %0 : tensor