Implement lowering of element-wise fixed-point add and mul to the standard dialect.

This also does the following:
      - Removes the proof-of-concept POT add implementation in favor of a version that does not rescale.
      - Adds a handful of FxpMathOps which are needed (these are open for comment, and we may want to move them to the StandardOps dialect).
      - Adds a canonicalizer to the StorageCastOp, which removes some cruft once conversions have been done.
      - Adds a couple of predicates to OpBase.

--

PiperOrigin-RevId: 244287706
Stella Laurenzo 2019-04-18 17:06:05 -07:00 committed by Mehdi Amini
parent 7977e62b96
commit e8d551e2bd
10 changed files with 646 additions and 241 deletions


@ -90,18 +90,60 @@ class fxpmath_Op<string mnemonic, list<OpTrait> traits> :
//===----------------------------------------------------------------------===//
// Fixed-point (fxp) arithmetic ops used by kernels.
// Some of these are temporary pending inclusion into a more core dialect.
//===----------------------------------------------------------------------===//
def fxpmath_RoundingDivideByPotFxpOp :
fxpmath_Op<"rounding_divide_by_poti", [NoSideEffect, SameValueType]>,
Arguments<(ins quant_StorageValueType:$x, I32Attr:$exponent)>,
Results<(outs quant_StorageValueType:$y)> {
def fxpmath_ClampISOp : fxpmath_Op<"clampis", [NoSideEffect, SameValueType]> {
let summary =
"Clamps a signed-integer like argument to a min/max range.";
let description = [{
Element-wise equivalent to:
r = std::min(clamp_max, std::max(e, clamp_min))
}];
let arguments = (ins IntegerLike:$arg,
APIntAttr:$clamp_min,
APIntAttr:$clamp_max);
let results = (outs IntegerLike);
}
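For reference, the per-element semantics reduce to the following scalar C++ sketch (illustrative only; the op itself also accepts vector/tensor operands):

// Scalar sketch of fxpmath.clampis semantics, assuming the clamp attributes
// have been materialized in the same integer type as the argument.
#include <algorithm>
#include <cstdint>

inline std::int32_t clampis(std::int32_t e, std::int32_t clamp_min,
                            std::int32_t clamp_max) {
  return std::min(clamp_max, std::max(e, clamp_min));
}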
def fxpmath_ConvertISOp :
fxpmath_Op<"convertis",
[NoSideEffect, SameValueShape]> {
let summary =
"Does an element-wise conversion from a signed integer to signed integer";
let description = [{
Similar to an element-wise static_cast in C++, from one signed integer
element type to another.
}];
let arguments = (ins IntegerLike:$arg);
let results = (outs IntegerLike);
}
def fxpmath_VecScalarSaturatingRoundingDoublingHighMulISOp :
fxpmath_Op<"vs_saturating_rounding_doubling_high_mulis",
[NoSideEffect, SameValueType]> {
let summary = "Implements equivalent functionality to ARMv7 NEON VQRDMULH";
let description = [{
Equivalent to the ARMv7 NEON VQRDMULH instruction.
See gemmlowp::SaturatingRoundingDoublingHighMul for a reference
implementation.
}];
let arguments = (ins IntegerLike:$a, APIntAttr:$b);
let results = (outs IntegerLike);
}
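The referenced gemmlowp routine returns the high 32 bits of the doubled 64-bit product, rounded to nearest, saturating the single overflow case. A scalar sketch of those semantics (drawn from the gemmlowp reference implementation, not part of this commit):

#include <cstdint>
#include <limits>

inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a,
                                                      std::int32_t b) {
  // The only overflow case: INT32_MIN * INT32_MIN doubled is 2^63, which
  // does not fit; saturate it to INT32_MAX.
  bool overflow = a == b && a == std::numeric_limits<std::int32_t>::min();
  std::int64_t ab64 =
      static_cast<std::int64_t>(a) * static_cast<std::int64_t>(b);
  // Round to nearest before taking the high 32 bits of 2*a*b.
  std::int32_t nudge = ab64 >= 0 ? (1 << 30) : (1 - (1 << 30));
  std::int32_t ab_x2_high32 =
      static_cast<std::int32_t>((ab64 + nudge) / (1ll << 31));
  return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32;
}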
def fxpmath_RoundingDivideByPotISOp :
fxpmath_Op<"rounding_divide_by_potis", [NoSideEffect, SameValueType]> {
let summary = [{
Computes a rounding arithmetic right shift.
}];
let description = [{
Computes integer division by a power-of-two, correctly rounded-to-nearest.
Also known as a rounding arithmetic right shift. See
gemmlowp::RoundingDivideByPOT for a reference implementation.
}];
let arguments = (ins IntegerLike:$x, APIntAttr:$exponent);
let results = (outs IntegerLike:$y);
let verifier = [{
auto verifyExponent = exponent().getSExtValue();
if (verifyExponent < 0 || verifyExponent > 31) {
@ -111,21 +153,6 @@ def fxpmath_RoundingDivideByPotFxpOp :
}];
}
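A scalar sketch of the referenced gemmlowp::RoundingDivideByPOT semantics (round-to-nearest with ties away from zero; drawn from the gemmlowp reference implementation, not part of this commit):

#include <cassert>
#include <cstdint>

inline std::int32_t RoundingDivideByPOT(std::int32_t x, int exponent) {
  assert(exponent >= 0 && exponent <= 31);
  const std::int32_t mask = static_cast<std::int32_t>((1ll << exponent) - 1);
  const std::int32_t remainder = x & mask;
  // For negative x, bump the threshold so that ties round away from zero.
  const std::int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}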
def fxpmath_SaturatingAddFxpOp :
fxpmath_Op<"saturating_addi", [NoSideEffect, SameValueType]>,
Arguments<(ins quant_StorageValueType:$x,
quant_StorageValueType:$y,
I32Attr:$clamp_min,
I32Attr:$clamp_max)>,
Results<(outs quant_StorageValueType:$sum)> {
let description = [{
Computes saturating addition of two operands, saturating to the given min
and max value. The implementation is responsible for choosing an
intermediate register size appropriate to carry out the operation without
overflow. See gemmlowp::SaturatingAdd for a reference implementation.
}];
}
//===----------------------------------------------------------------------===//
// Real math ops.
//


@ -336,11 +336,11 @@ class TypedStaticShapeTensor<Type t>
: Type<AllOf<[ TypedTensor<t>.predicate, IsStaticShapeTensorTypePred ]>,
"statically shaped tensor">;
def I1Tensor : TypedTensor<I1>;
def I8Tensor : TypedTensor<I8>;
def I16Tensor : TypedTensor<I16>;
def I32Tensor : TypedTensor<I32>;
def I64Tensor : TypedTensor<I64>;
def I1Tensor : TypedTensor<I1>;
def I8Tensor : TypedTensor<I8>;
def I16Tensor : TypedTensor<I16>;
def I32Tensor : TypedTensor<I32>;
def I64Tensor : TypedTensor<I64>;
def BF16Tensor : TypedTensor<BF16>;
def F16Tensor : TypedTensor<F16>;
@ -503,6 +503,12 @@ class IntegerAttrBase<I attrValType, string descr> :
let returnType = [{ APInt }];
}
def APIntAttr : Attr<CPred<"$_self.isa<IntegerAttr>()">,
"arbitrary integer attribute"> {
let storageType = [{ IntegerAttr }];
let returnType = [{ APInt }];
}
def I32Attr : IntegerAttrBase<I32, "32-bit integer attribute">;
def I64Attr : IntegerAttrBase<I64, "64-bit integer attribute">;


@ -92,6 +92,7 @@ def quant_DequantizeCastOp : quant_Op<"dcast", [NoSideEffect]> {
def quant_StorageCastOp : quant_Op<"scast", [NoSideEffect]> {
let arguments = (ins quant_RealOrStorageValueType:$arg);
let results = (outs quant_RealOrStorageValueType);
let hasCanonicalizer = 0b1;
}
//===----------------------------------------------------------------------===//


@ -131,6 +131,10 @@ def RemIUOp : IntArithmeticOp<"std.remiu"> {
let hasConstantFolder = 0b1;
}
def ShlISOp : IntArithmeticOp<"std.shlis"> {
let summary = "signed integer shift left";
}
def SubFOp : FloatArithmeticOp<"std.subf"> {
let summary = "floating point subtraction operation";
let hasConstantFolder = 0b1;


@ -15,17 +15,17 @@
// limitations under the License.
// =============================================================================
#include "UniformKernelUtils.h"
#include "mlir/FxpMathOps/FxpMathOps.h"
#include "mlir/FxpMathOps/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Quantization/QuantOps.h"
#include "mlir/Quantization/UniformSupport.h"
#include <functional>
#include "mlir/StandardOps/Ops.h"
using namespace mlir;
using namespace mlir::fxpmath;
using namespace mlir::fxpmath::detail;
using namespace mlir::quant;
namespace {
@ -35,186 +35,176 @@ struct LowerUniformRealMathPass
void runOnFunction() override;
};
UniformQuantizedType getUniformElementType(Type t) {
return QuantizedType::getQuantizedElementType(t)
.dyn_cast_or_null<UniformQuantizedType>();
}
/// Computes the log2(x), rounded to an integral value. Returns whether 'x' can
/// be considered an exact integral value.
template <typename F> bool integralLog2(F x, int &log2Result) {
const F xLog2 = std::log(x) * (1.0 / std::log(2.0));
const F xLog2Rounded = std::round(xLog2);
const F xLog2Frac = xLog2 - xLog2Rounded;
log2Result = static_cast<int>(xLog2Rounded);
// Allow small comparison slop below the level that would make a difference
// for 2^16 levels.
return std::abs(xLog2Frac) < 1e-6;
}
/// Helper class for operating on binary operations where all operands
/// and the result are a UniformQuantizedType.
struct RealBinaryOpInfo {
RealBinaryOpInfo(Operation *op, Value *lhs, Value *rhs,
Optional<APFloat> clampMin, Optional<APFloat> clampMax)
: op(op), lhs(lhs), rhs(rhs), clampMin(clampMin), clampMax(clampMax),
lhsType(getUniformElementType(lhs->getType())),
rhsType(getUniformElementType(rhs->getType())),
resultType(getUniformElementType(*op->result_type_begin())),
lhsStorageType(QuantizedType::castToStorageType(lhs->getType())),
rhsStorageType(QuantizedType::castToStorageType(rhs->getType())),
resultStorageType(
QuantizedType::castToStorageType(*op->result_type_begin())) {}
/// Returns whether this info is valid (all types defined, etc).
bool isValid() const {
return lhsType && rhsType && resultType && lhsStorageType &&
rhsStorageType && resultStorageType;
}
/// Returns whether the storage type of all operands is identical.
bool isSameStorageType() const {
return lhsType.getStorageType() == rhsType.getStorageType() &&
lhsType.getStorageType() == resultType.getStorageType();
}
/// Returns whether all operands and result are considered fixedpoint power
/// of two, setting the lhs, rhs, and result log2 scale references.
bool isFixedPointPOT(int &lhsLog2Scale, int &rhsLog2Scale,
int &resultLog2Scale) const {
if (!lhsType.isFixedPoint() || !rhsType.isFixedPoint() ||
!resultType.isFixedPoint()) {
return false;
}
if (!integralLog2(lhsType.getScale(), lhsLog2Scale) ||
!integralLog2(rhsType.getScale(), rhsLog2Scale) ||
!integralLog2(resultType.getScale(), resultLog2Scale)) {
return false;
}
return true;
}
/// Gets the result integer clamp range given the result quantized type
// and any explicit clamp provided as attributes.
std::pair<IntegerAttr, IntegerAttr> getClampMinMax() const {
int64_t typeMin = resultType.getStorageTypeMin();
int64_t typeMax = resultType.getStorageTypeMax();
if (clampMin || clampMax) {
UniformQuantizedValueConverter conv(resultType);
if (clampMin) {
typeMin = std::max(typeMin, conv.quantizeFloatToInt64(*clampMin));
}
if (clampMax) {
typeMax = std::min(typeMax, conv.quantizeFloatToInt64(*clampMax));
}
}
// The quantized, integral ops expect clamps as 32bit ints.
return {
IntegerAttr::get(IntegerType::get(32, resultType.getContext()),
typeMin),
IntegerAttr::get(IntegerType::get(32, resultType.getContext()),
typeMax),
};
}
Operation *op;
Value *lhs;
Value *rhs;
Optional<APFloat> clampMin;
Optional<APFloat> clampMax;
// Element UniformQuantizedType for operands/result.
UniformQuantizedType lhsType;
UniformQuantizedType rhsType;
UniformQuantizedType resultType;
// Full storage-based types.
Type lhsStorageType;
Type rhsStorageType;
Type resultStorageType;
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// Elementwise add
//===----------------------------------------------------------------------===//
/// Attempts to rewrite a fixed point power-of-two addition of two integers.
/// This supports a limited number of cases, but when supported, represents
/// the simplest computation.
static LogicalResult tryRewriteFixedPOTAddEw(const RealBinaryOpInfo &constInfo,
PatternRewriter &rewriter) {
if (!constInfo.isSameStorageType()) {
static LogicalResult
tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info,
PatternRewriter &rewriter) {
if (!info.resultType.isSigned() || info.lhsType != info.resultType ||
info.rhsType != info.resultType) {
return failure();
}
int lhsLog2Scale;
int rhsLog2Scale;
int resultLog2Scale;
if (!constInfo.isFixedPointPOT(lhsLog2Scale, rhsLog2Scale, resultLog2Scale)) {
return failure();
}
// Adjust shifts to be relative to the output.
// Left shift of one input scale is supported. The other must match the result
// scale.
int lhsScaleShift = lhsLog2Scale - resultLog2Scale;
int rhsScaleShift = rhsLog2Scale - resultLog2Scale;
if (lhsScaleShift != 0 && rhsScaleShift != 0) {
return failure();
}
if (lhsScaleShift > 0 || rhsScaleShift > 0) {
return failure();
}
// State accessed by the closure.
Operation *mathOp = constInfo.op;
const auto clampMinMax = constInfo.getClampMinMax();
Value *lhs = constInfo.lhs;
Value *rhs = constInfo.rhs;
Type lhsStorageType = constInfo.lhsStorageType;
Type rhsStorageType = constInfo.rhsStorageType;
// If the lhs operand is the one requiring a shift, swap it so that the shift
// happens on the rhs operand.
if (lhsScaleShift != 0) {
std::swap(lhs, rhs);
std::swap(lhsStorageType, rhsStorageType);
std::swap(lhsScaleShift, rhsScaleShift);
}
int rhsRightShift = -rhsScaleShift;
// Choose a byte aligned intermediate width big enough to perform the
// calculation without overflow.
// TODO: This should probably be made just big enough to avoid overflow and
// leave the downstream tooling to decide how to align that to machine
// word sizes.
unsigned intermediateWidth =
info.resultType.getStorageTypeIntegralWidth() <= 8 ? 16 : 32;
IntegerType intermediateElementType =
IntegerType::get(intermediateWidth, rewriter.getContext());
Type intermediateType =
castElementType(info.resultStorageType, intermediateElementType);
// Cast operands to storage type.
Value *lhsStorageValue =
rewriter.create<StorageCastOp>(mathOp->getLoc(), lhsStorageType, lhs)
.getResult();
Value *rhsStorageValue =
rewriter.create<StorageCastOp>(mathOp->getLoc(), rhsStorageType, rhs)
.getResult();
Value *lhsValue = rewriter
.create<StorageCastOp>(info.op->getLoc(),
info.lhsStorageType, info.lhs)
.getResult();
Value *rhsValue = rewriter
.create<StorageCastOp>(info.op->getLoc(),
info.rhsStorageType, info.rhs)
.getResult();
// Rescale the rhs operand if needed.
if (rhsRightShift != 0) {
rhsStorageValue =
rewriter
.create<RoundingDivideByPotFxpOp>(
mathOp->getLoc(), rhsStorageValue,
IntegerAttr::get(IntegerType::get(32, rewriter.getContext()),
rhsRightShift))
.getResult();
}
// Cast to the intermediate sized type.
lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
lhsValue);
rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
rhsValue);
// Add.
Value *sumValue = rewriter.create<SaturatingAddFxpOp>(
mathOp->getLoc(), lhsStorageValue, rhsStorageValue, clampMinMax.first,
clampMinMax.second);
Value *resultValue =
rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, rhsValue);
// Zero point offset adjustment: with a common scale S and zero point zp on
// both operands and the result,
//   S * (result - zp) = S * (lhs - zp) + S * (rhs - zp)
// result = (lhs - zp) + (rhs - zp) + zp
// zpOffset = -zp
int zpOffset = -1 * info.resultType.getZeroPoint();
if (zpOffset != 0) {
Value *zpOffsetConst = rewriter.create<ConstantOp>(
info.op->getLoc(),
broadcastScalarConstIntValue(intermediateType, zpOffset));
resultValue =
rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
}
// Clamp.
auto clampMinMax = info.getClampMinMax(intermediateElementType);
resultValue = rewriter.create<ClampISOp>(
info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);
// Convert back to original type.
resultValue = rewriter.create<ConvertISOp>(
info.op->getLoc(), info.resultStorageType, resultValue);
// Cast back for new result.
rewriter.replaceOpWithNewOp<StorageCastOp>(
mathOp, *mathOp->result_type_begin(), sumValue);
info.op, info.getQuantizedResultType(), resultValue);
return success();
}
//===----------------------------------------------------------------------===//
// Elementwise mul
//===----------------------------------------------------------------------===//
static LogicalResult
tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info,
PatternRewriter &rewriter) {
if (!info.resultType.isSigned()) {
return failure();
}
double outputMultiplierReal = info.lhsType.getScale() *
info.rhsType.getScale() /
info.resultType.getScale();
if (outputMultiplierReal > 1.0) {
info.op->emitWarning("unimplemented: cannot multiply with multiplier > 1.0");
return failure();
}
// TODO: Choose an appropriate intermediate width for muls > 8 bits to
// avoid overflow.
unsigned intermediateWidth = 32;
IntegerType intermediateElementType =
IntegerType::get(intermediateWidth, rewriter.getContext());
Type intermediateType =
castElementType(info.resultStorageType, intermediateElementType);
// Cast operands to storage type.
Value *lhsValue = rewriter
.create<StorageCastOp>(info.op->getLoc(),
info.lhsStorageType, info.lhs)
.getResult();
Value *rhsValue = rewriter
.create<StorageCastOp>(info.op->getLoc(),
info.rhsStorageType, info.rhs)
.getResult();
// Cast to the intermediate sized type.
lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
lhsValue);
rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
rhsValue);
// Apply argument zeroPoints.
if (info.lhsType.getZeroPoint() != 0) {
Value *zpOffsetConst = rewriter.create<ConstantOp>(
info.op->getLoc(), broadcastScalarConstIntValue(
intermediateType, -info.lhsType.getZeroPoint()));
lhsValue =
rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, zpOffsetConst);
}
if (info.rhsType.getZeroPoint() != 0) {
Value *zpOffsetConst = rewriter.create<ConstantOp>(
info.op->getLoc(), broadcastScalarConstIntValue(
intermediateType, -info.rhsType.getZeroPoint()));
rhsValue =
rewriter.create<AddIOp>(info.op->getLoc(), rhsValue, zpOffsetConst);
}
// Mul.
Value *resultValue =
rewriter.create<MulIOp>(info.op->getLoc(), lhsValue, rhsValue);
// Scale output.
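// The real rescale by outputMultiplierReal is realized in integer math as a
// saturating rounding doubling high mul by m followed by a rounding right
// shift, where frexp(outputMultiplierReal) = q * 2^exponent and
// m = round(q * 2^31) (see QuantizedMultiplierSmallerThanOneExp).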
QuantizedMultiplierSmallerThanOneExp outputMultiplier(outputMultiplierReal);
resultValue = rewriter.create<VecScalarSaturatingRoundingDoublingHighMulISOp>(
info.op->getLoc(), resultValue,
IntegerAttr::get(intermediateElementType, outputMultiplier.multiplier));
resultValue = rewriter.create<RoundingDivideByPotISOp>(
info.op->getLoc(), resultValue,
IntegerAttr::get(intermediateElementType, -outputMultiplier.exponent));
// Zero point offset adjustment.
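// result = multiplier * (lhs - lhs_zp) * (rhs - rhs_zp) + result_zp, so the
// result zero point is added back after the scaled multiply.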
if (info.resultType.getZeroPoint() != 0) {
Value *zpOffsetConst = rewriter.create<ConstantOp>(
info.op->getLoc(),
broadcastScalarConstIntValue(intermediateType,
info.resultType.getZeroPoint()));
resultValue =
rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
}
// Clamp.
auto clampMinMax = info.getClampMinMax(intermediateElementType);
resultValue = rewriter.create<ClampISOp>(
info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);
// Convert back to original type.
resultValue = rewriter.create<ConvertISOp>(
info.op->getLoc(), info.resultStorageType, resultValue);
// Cast back for new result.
rewriter.replaceOpWithNewOp<StorageCastOp>(
info.op, info.getQuantizedResultType(), resultValue);
return success();
}
@ -227,14 +217,36 @@ struct UniformRealAddEwPattern : public RewritePattern {
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
auto addOp = op->cast<RealAddEwOp>();
const RealBinaryOpInfo info(op, addOp.x(), addOp.y(), addOp.clamp_min(),
addOp.clamp_max());
const UniformBinaryOpInfo info(op, addOp.x(), addOp.y(), addOp.clamp_min(),
addOp.clamp_max());
if (!info.isValid()) {
return matchFailure();
}
// Try all of the permutations we support.
if (succeeded(tryRewriteFixedPOTAddEw(info, rewriter))) {
if (succeeded(tryRewriteAffineAddEwIsomorphicSigned(info, rewriter))) {
return matchSuccess();
}
return matchFailure();
}
};
struct UniformRealMulEwPattern : public RewritePattern {
UniformRealMulEwPattern(MLIRContext *context)
: RewritePattern(RealMulEwOp::getOperationName(), 1, context) {}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
auto mulOp = op->cast<RealMulEwOp>();
const UniformBinaryOpInfo info(op, mulOp.x(), mulOp.y(), mulOp.clamp_min(),
mulOp.clamp_max());
if (!info.isValid()) {
return matchFailure();
}
// Try all of the permutations we support.
if (succeeded(tryRewriteAffineMulEwSigned(info, rewriter))) {
return matchSuccess();
}
@ -249,6 +261,7 @@ void LowerUniformRealMathPass::runOnFunction() {
OwningRewritePatternList patterns;
auto *context = &getContext();
patterns.push_back(llvm::make_unique<UniformRealAddEwPattern>(context));
patterns.push_back(llvm::make_unique<UniformRealMulEwPattern>(context));
applyPatternsGreedily(fn, std::move(patterns));
}


@ -0,0 +1,203 @@
//===- UniformKernelUtils.h - Utilities for lowering uniform math -*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
#define MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
#include "mlir/IR/Operation.h"
#include "mlir/Quantization/QuantOps.h"
#include "mlir/Quantization/QuantTypes.h"
#include "mlir/Quantization/UniformSupport.h"
#include <cmath>
namespace mlir {
namespace fxpmath {
namespace detail {
inline quant::UniformQuantizedType getUniformElementType(Type t) {
return quant::QuantizedType::getQuantizedElementType(t)
.dyn_cast_or_null<quant::UniformQuantizedType>();
}
inline bool hasStorageBitWidth(quant::QuantizedType t,
llvm::ArrayRef<unsigned> checkWidths) {
unsigned w = t.getStorageType().getIntOrFloatBitWidth();
for (unsigned checkWidth : checkWidths) {
if (w == checkWidth)
return true;
}
return false;
}
/// Computes the log2(x), rounded to an integral value. Returns whether 'x' can
/// be considered an exact integral value.
template <typename F> bool integralLog2(F x, int &log2Result) {
const F xLog2 = std::log(x) * (1.0 / std::log(2.0));
const F xLog2Rounded = std::round(xLog2);
const F xLog2Frac = xLog2 - xLog2Rounded;
log2Result = static_cast<int>(xLog2Rounded);
// Allow small comparison slop below the level that would make a difference
// for 2^16 levels.
return std::abs(xLog2Frac) < 1e-6;
}
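// Illustrative check (hypothetical helper, not part of this header): a
// power-of-two scale such as 2^-4 = 0.0625 is recognized exactly, while a
// non-POT scale such as 0.1 is rejected.
inline bool integralLog2Example() {
  int log2Result;
  bool pot = integralLog2(0.0625, log2Result); // true, log2Result == -4
  bool nonPot = integralLog2(0.1, log2Result); // false (log2(0.1) ~ -3.32)
  return pot && !nonPot;
}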
/// Helper class for operating on binary operations where all operands
/// and the result are a UniformQuantizedType.
struct UniformBinaryOpInfo {
UniformBinaryOpInfo(Operation *op, Value *lhs, Value *rhs,
Optional<APFloat> clampMin, Optional<APFloat> clampMax)
: op(op), lhs(lhs), rhs(rhs), clampMin(clampMin), clampMax(clampMax),
lhsType(getUniformElementType(lhs->getType())),
rhsType(getUniformElementType(rhs->getType())),
resultType(getUniformElementType(*op->result_type_begin())),
lhsStorageType(quant::QuantizedType::castToStorageType(lhs->getType())),
rhsStorageType(quant::QuantizedType::castToStorageType(rhs->getType())),
resultStorageType(
quant::QuantizedType::castToStorageType(*op->result_type_begin())) {
}
/// Returns whether this info is valid (all types defined, etc).
bool isValid() const {
return lhsType && rhsType && resultType && lhsStorageType &&
rhsStorageType && resultStorageType;
}
/// Gets the final quantized type of the result.
Type getQuantizedResultType() const { return *op->result_type_begin(); }
/// Returns whether the storage type of all operands is identical.
bool isSameStorageType() const {
return lhsType.getStorageType() == rhsType.getStorageType() &&
lhsType.getStorageType() == resultType.getStorageType();
}
/// Returns whether all operands and the result are fixed-point types with
/// power-of-two scales, setting the lhs, rhs, and result log2 scale
/// references.
bool isFixedPointPOT(int &lhsLog2Scale, int &rhsLog2Scale,
int &resultLog2Scale) const {
if (!lhsType.isFixedPoint() || !rhsType.isFixedPoint() ||
!resultType.isFixedPoint()) {
return false;
}
if (!integralLog2(lhsType.getScale(), lhsLog2Scale) ||
!integralLog2(rhsType.getScale(), rhsLog2Scale) ||
!integralLog2(resultType.getScale(), resultLog2Scale)) {
return false;
}
return true;
}
/// Gets the result integer clamp range given the result quantized type
/// and any explicit clamp provided as attributes.
std::pair<IntegerAttr, IntegerAttr> getClampMinMax(IntegerType ty) const {
int64_t typeMin = resultType.getStorageTypeMin();
int64_t typeMax = resultType.getStorageTypeMax();
if (clampMin || clampMax) {
quant::UniformQuantizedValueConverter conv(resultType);
if (clampMin) {
typeMin = std::max(typeMin, conv.quantizeFloatToInt64(*clampMin));
}
if (clampMax) {
typeMax = std::min(typeMax, conv.quantizeFloatToInt64(*clampMax));
}
}
// The quantized, integral ops expect clamps in the given integer type.
return {
IntegerAttr::get(ty, typeMin),
IntegerAttr::get(ty, typeMax),
};
}
Operation *op;
Value *lhs;
Value *rhs;
Optional<APFloat> clampMin;
Optional<APFloat> clampMax;
// Element UniformQuantizedType for operands/result.
quant::UniformQuantizedType lhsType;
quant::UniformQuantizedType rhsType;
quant::UniformQuantizedType resultType;
// Full storage-based types.
Type lhsStorageType;
Type rhsStorageType;
Type resultStorageType;
};
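For instance, with a result type of uniform[i8:f32]{6.25e-2} and an explicit clamp of [-4.0, 4.0], quantizing gives round(±4.0 / 0.0625) = ±64, tighter than the i8 storage range of [-128, 127]; this matches the clamp_min/clamp_max of -64/64 checked in the fixedpoint clamp test below.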
/// Derives a quantized multiplier and shift from a real valued multiplier
/// less than 1.
struct QuantizedMultiplierSmallerThanOneExp {
QuantizedMultiplierSmallerThanOneExp(double realMultiplier) {
assert(realMultiplier < 1.0);
assert(realMultiplier > 0.0);
const double q = std::frexp(realMultiplier, &exponent);
auto qFixed = static_cast<int64_t>(std::round(q * (1ll << 31)));
assert(qFixed <= (1ll << 31));
if (qFixed == (1ll << 31)) {
qFixed /= 2;
++exponent;
}
assert(qFixed <= std::numeric_limits<int32_t>::max());
multiplier = static_cast<int32_t>(qFixed);
}
int32_t multiplier;
int exponent;
};
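// Hypothetical usage sketch: decompose the output multiplier exercised by
// the fixedpoint mul test, 0.0625 * 0.03875 / 0.1065 ~= 0.0227406.
inline void quantizedMultiplierExample() {
  QuantizedMultiplierSmallerThanOneExp m(0.0625 * 0.03875 / 0.1065);
  // m.multiplier == 1562722842 (~ round(0.72770 * 2^31)), m.exponent == -5.
  // The mul lowering then emits vs_saturating_rounding_doubling_high_mulis
  // by m.multiplier followed by rounding_divide_by_potis with exponent
  // -m.exponent == 5.
  (void)m;
}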
/// Casts an integer or floating point based type to a new element type.
inline Type castElementType(Type t, Type newElementType) {
if (auto vt = t.dyn_cast<VectorOrTensorType>()) {
switch (vt.getKind()) {
case StandardTypes::Kind::Vector:
return VectorType::get(vt.getShape(), newElementType);
case StandardTypes::Kind::RankedTensor:
return RankedTensorType::get(vt.getShape(), newElementType);
case StandardTypes::Kind::UnrankedTensor:
return UnrankedTensorType::get(newElementType);
}
}
assert(t.isIntOrFloat());
return newElementType;
}
/// Creates an IntegerAttr with a type that matches the shape of 't' (which can
/// be a primitive/vector/tensor).
inline Attribute broadcastScalarConstIntValue(Type t, int64_t value) {
if (auto vt = t.dyn_cast<VectorOrTensorType>()) {
assert(vt.getElementType().isa<IntegerType>());
return SplatElementsAttr::get(vt,
IntegerAttr::get(vt.getElementType(), value));
}
assert(t.isa<IntegerType>() && "integer broadcast must be of integer type");
auto integerType = t.cast<IntegerType>();
return IntegerAttr::get(integerType, value);
}
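// Hypothetical sketch combining the two helpers above: widen a storage type
// to an i32 intermediate type and build a matching splat constant for a
// zero-point offset, as the add/mul lowerings do.
inline Attribute makeZpOffsetSplat(Type storageType, int64_t zpOffset,
                                   MLIRContext *context) {
  Type intermediateType =
      castElementType(storageType, IntegerType::get(32, context));
  return broadcastScalarConstIntValue(intermediateType, zpOffset);
}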
} // namespace detail
} // namespace fxpmath
} // namespace mlir
#endif // MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_


@ -18,6 +18,8 @@
#include "mlir/Quantization/QuantOps.h"
#include "TypeDetail.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Quantization/QuantTypes.h"
#include "llvm/ADT/StringRef.h"
@ -31,6 +33,41 @@ using namespace mlir::quant::detail;
#define GET_OP_CLASSES
#include "mlir/Quantization/QuantOps.cpp.inc"
namespace {
/// Matches x -> [scast -> scast] -> y, replacing the second scast with the
/// value of x if the casts invert each other.
class RemoveRedundantStorageCastsRewrite : public RewritePattern {
public:
RemoveRedundantStorageCastsRewrite(MLIRContext *context)
: RewritePattern(StorageCastOp::getOperationName(), 1, context) {}
PatternMatchResult match(Operation *op) const override {
auto scastOp = op->cast<StorageCastOp>();
if (matchPattern(scastOp.arg(), m_Op<StorageCastOp>())) {
auto srcScastOp = scastOp.arg()->getDefiningOp()->cast<StorageCastOp>();
if (srcScastOp.arg()->getType() == scastOp.getResult()->getType()) {
return matchSuccess();
}
}
return matchFailure();
}
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
auto scastOp = op->cast<StorageCastOp>();
auto srcScastOp = scastOp.arg()->getDefiningOp()->cast<StorageCastOp>();
rewriter.replaceOp(op, srcScastOp.arg());
}
};
} // end anonymous namespace
void StorageCastOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.push_back(
llvm::make_unique<RemoveRedundantStorageCastsRewrite>(context));
}
QuantizationDialect::QuantizationDialect(MLIRContext *context)
: Dialect(/*name=*/"quant", context) {
addTypes<UniformQuantizedType, UniformQuantizedPerAxisType>();


@ -1,51 +1,44 @@
// RUN: mlir-opt %s -split-input-file -fxpmath-lower-uniform-real-math | FileCheck %s --dump-input=fail
// RUN: mlir-opt %s -split-input-file -fxpmath-lower-uniform-real-math -canonicalize | FileCheck %s --dump-input=always
// -----
// Verify lowering when operands and result have the same fixedpoint pot scale.
// CHECK-LABEL: real_addew_fixedpoint_same_scale
// CHECK: %0 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
// CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
// CHECK-NEXT: %2 = "fxpmath.saturating_addi"(%0, %1) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
// CHECK-NEXT: %3 = "quant.scast"(%2) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
// CHECK-NEXT: return %3 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
// Verify lowering when operands and result have the same fixedpoint scale.
// CHECK-LABEL: real_addew_fixedpoint_isomorphic
!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
func @real_addew_fixedpoint_same_scale(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
func @real_addew_fixedpoint_isomorphic(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
// CHECK-NEXT: %0 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
// CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
// CHECK-NEXT: %2 = "fxpmath.convertis"(%0) : (tensor<4xi8>) -> tensor<4xi16>
// CHECK-NEXT: %3 = "fxpmath.convertis"(%1) : (tensor<4xi8>) -> tensor<4xi16>
// CHECK-NEXT: %4 = addi %2, %3 : tensor<4xi16>
// CHECK-NEXT: %5 = "fxpmath.clampis"(%4) {clamp_max: 127 : i16, clamp_min: -128 : i16} : (tensor<4xi16>) -> tensor<4xi16>
// CHECK-NEXT: %6 = "fxpmath.convertis"(%5) : (tensor<4xi16>) -> tensor<4xi8>
// CHECK-NEXT: %7 = "quant.scast"(%6) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
// CHECK-NEXT: return %7 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
%0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
return %0 : !type_result
}
// -----
// Verify lowering when the rhs is a shifted pot scale compared to lhs and result.
// CHECK-LABEL: real_addew_fixedpoint_rhs_shift
// CHECK: %0 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
// CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> tensor<4xi8>
// CHECK-NEXT: %2 = "fxpmath.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8>
// CHECK-NEXT: %3 = "fxpmath.saturating_addi"(%0, %2) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
// CHECK-NEXT: %4 = "quant.scast"(%3) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
// CHECK-NEXT: return %4 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{7.8125e-03}">>
!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
func @real_addew_fixedpoint_rhs_shift(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
%0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
return %0 : !type_result
}
// -----
// Verify lowering when the lhs is a shifted pot scale compared to the rhs and result.
// CHECK-LABEL: real_addew_fixedpoint_lhs_shift
// CHECK: %0 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
// CHECK-NEXT: %1 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> tensor<4xi8>
// CHECK-NEXT: %2 = "fxpmath.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8>
// CHECK-NEXT: %3 = "fxpmath.saturating_addi"(%0, %2) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
// CHECK-NEXT: %4 = "quant.scast"(%3) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
// CHECK-NEXT: return %4 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{7.8125e-03}">>
!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
func @real_addew_fixedpoint_lhs_shift(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
// Verify lowering when operands and result have the same fixedpoint scale
// and non-zero zero points.
// CHECK-LABEL: real_addew_affine_isomorphic
!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2:-5}">>
!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2:-5}">>
!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2:-5}">>
func @real_addew_affine_isomorphic(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
// CHECK-NEXT: %cst = constant splat<tensor<4xi16>, 5> : tensor<4xi16>
// CHECK-NEXT: %0 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02:-5}">>) -> tensor<4xi8>
// CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02:-5}">>) -> tensor<4xi8>
// CHECK-NEXT: %2 = "fxpmath.convertis"(%0) : (tensor<4xi8>) -> tensor<4xi16>
// CHECK-NEXT: %3 = "fxpmath.convertis"(%1) : (tensor<4xi8>) -> tensor<4xi16>
// CHECK-NEXT: %4 = addi %2, %3 : tensor<4xi16>
// CHECK-NEXT: %5 = addi %4, %cst : tensor<4xi16>
// CHECK-NEXT: %6 = "fxpmath.clampis"(%5) {clamp_max: 127 : i16, clamp_min: -128 : i16} : (tensor<4xi16>) -> tensor<4xi16>
// CHECK-NEXT: %7 = "fxpmath.convertis"(%6) : (tensor<4xi16>) -> tensor<4xi8>
// CHECK-NEXT: %8 = "quant.scast"(%7) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02:-5}">>
// CHECK-NEXT: return %8 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02:-5}">>
%0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
return %0 : !type_result
}
@ -54,16 +47,19 @@ func @real_addew_fixedpoint_lhs_shift(%arg0 : !type_lhs, %arg1: !type_rhs) -> !t
// The RHS quant parameters prescribe a range of [-8..8) so an explicit clamp
// of [-4..4] should result in an integral clamp range of [-64..64].
// CHECK-LABEL: real_addew_fixedpoint_clamp
// CHECK: %0 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
// CHECK-NEXT: %1 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> tensor<4xi8>
// CHECK-NEXT: %2 = "fxpmath.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8>
// CHECK-NEXT: %3 = "fxpmath.saturating_addi"(%0, %2) {clamp_max: 64 : i32, clamp_min: -64 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
// CHECK-NEXT: %4 = "quant.scast"(%3) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
// CHECK-NEXT: return %4 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{7.8125e-03}">>
!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
func @real_addew_fixedpoint_clamp(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
// CHECK-NEXT: %0 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
// CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
// CHECK-NEXT: %2 = "fxpmath.convertis"(%0) : (tensor<4xi8>) -> tensor<4xi16>
// CHECK-NEXT: %3 = "fxpmath.convertis"(%1) : (tensor<4xi8>) -> tensor<4xi16>
// CHECK-NEXT: %4 = addi %2, %3 : tensor<4xi16>
// CHECK-NEXT: %5 = "fxpmath.clampis"(%4) {clamp_max: 64 : i16, clamp_min: -64 : i16} : (tensor<4xi16>) -> tensor<4xi16>
// CHECK-NEXT: %6 = "fxpmath.convertis"(%5) : (tensor<4xi16>) -> tensor<4xi8>
// CHECK-NEXT: %7 = "quant.scast"(%6) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
// CHECK-NEXT: return %7 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
%0 = "fxpmath.real_add_ew"(%arg0, %arg1) { clamp_min:-4.0, clamp_max:4.0 }
: (!type_lhs, !type_rhs) -> (!type_result)
return %0 : !type_result


@ -0,0 +1,94 @@
// RUN: mlir-opt %s -split-input-file -fxpmath-lower-uniform-real-math -canonicalize -verify | FileCheck %s --dump-input=always
// -----
// Verify lowering of a fixedpoint mul where the output multiplier is < 1.0.
// CHECK-LABEL: real_mulew_fixedpoint
!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{3.875e-2}">>
!type_result = type tensor<4x!quant<"uniform[i8:f32]{1.065e-1}">>
func @real_mulew_fixedpoint(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
// CHECK: %0 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
// CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{3.875000e-02}">>) -> tensor<4xi8>
// CHECK-NEXT: %2 = "fxpmath.convertis"(%0) : (tensor<4xi8>) -> tensor<4xi32>
// CHECK-NEXT: %3 = "fxpmath.convertis"(%1) : (tensor<4xi8>) -> tensor<4xi32>
// CHECK-NEXT: %4 = muli %2, %3 : tensor<4xi32>
// CHECK-NEXT: %5 = "fxpmath.vs_saturating_rounding_doubling_high_mulis"(%4) {b: 1562722842 : i32} : (tensor<4xi32>) -> tensor<4xi32>
// CHECK-NEXT: %6 = "fxpmath.rounding_divide_by_potis"(%5) {exponent: 5 : i32} : (tensor<4xi32>) -> tensor<4xi32>
// CHECK-NEXT: %7 = "fxpmath.clampis"(%6) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi32>) -> tensor<4xi32>
// CHECK-NEXT: %8 = "fxpmath.convertis"(%7) : (tensor<4xi32>) -> tensor<4xi8>
// CHECK-NEXT: %9 = "quant.scast"(%8) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{1.065000e-01}">>
// CHECK-NEXT: return %9 : tensor<4x!quant<"uniform[i8:f32]{1.065000e-01}">>
%0 = "fxpmath.real_mul_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
return %0 : !type_result
}
// -----
// Verify lowering when operands and result have the same fixedpoint scale
// and non-zero zero points.
// CHECK-LABEL: real_mulew_affine_clamp
!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2:-3}">>
!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2:-5}">>
!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2:-9}">>
func @real_mulew_affine_clamp(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
// Just verify that the affine adds/constants and clamps are present.
// CHECK: %cst = constant splat<tensor<4xi32>, 3> : tensor<4xi32>
// CHECK: %cst_0 = constant splat<tensor<4xi32>, 5> : tensor<4xi32>
// CHECK: %cst_1 = constant splat<tensor<4xi32>, -9> : tensor<4xi32>
// CHECK: addi %2, %cst : tensor<4xi32>
// CHECK: addi %3, %cst_0 : tensor<4xi32>
// CHECK: muli %4, %5 : tensor<4xi32>
// CHECK: addi %8, %cst_1 : tensor<4xi32>
// CHECK: {clamp_max: 55 : i32, clamp_min: -73 : i32}
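// The integral clamp range follows from round(±4.0 / 6.25e-2) + (-9) = [-73, 55].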
%0 = "fxpmath.real_mul_ew"(%arg0, %arg1) { clamp_min:-4.0, clamp_max:4.0 } : (!type_lhs, !type_rhs) -> (!type_result)
return %0 : !type_result
}
// -----
// CHECK-LABEL: real_mulew_unquantized_lhs
// Verifies that the op is left as-is for an unquantized lhs.
!type_lhs = type tensor<4xf32>
!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
func @real_mulew_unquantized_lhs(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
// CHECK: %0 = "fxpmath.real_mul_ew"(%arg0, %arg1)
%0 = "fxpmath.real_mul_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
return %0 : !type_result
}
// -----
// CHECK-LABEL: real_mulew_unquantized_rhs
// Verifies that the op is left as-is for an unquantized rhs.
!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
!type_rhs = type tensor<4xf32>
!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
func @real_mulew_unquantized_rhs(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
// CHECK: %0 = "fxpmath.real_mul_ew"(%arg0, %arg1)
%0 = "fxpmath.real_mul_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
return %0 : !type_result
}
// -----
// CHECK-LABEL: real_mulew_unquantized_result
// Verifies that the op is left as-is for an unquantized result.
!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
!type_result = type tensor<4xf32>
func @real_mulew_unquantized_result(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
// CHECK: %0 = "fxpmath.real_mul_ew"(%arg0, %arg1)
%0 = "fxpmath.real_mul_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
return %0 : !type_result
}
// -----
// Verify that lowering is rejected when the output multiplier exceeds 1.0.
// Note that the multiplier = lhs_scale * rhs_scale / result_scale
//                          = 22.740610328638496
// CHECK-LABEL: real_mulew_multiplier_gt_1
!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{3.875e-2}">>
!type_result = type tensor<4x!quant<"uniform[i8:f32]{1.065e-4}">>
func @real_mulew_multiplier_gt_1(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
// expected-warning@+1 {{unimplemented: cannot multiply with multiplier > 1.0}}
%0 = "fxpmath.real_mul_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
return %0 : !type_result
}


@ -0,0 +1,24 @@
// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s --dump-input=fail
// -----
// CHECK-LABEL: redundant_scast
func @redundant_scast() -> tensor<4xi8> {
// CHECK-NEXT: constant splat<tensor<4xi8>, 10>
// CHECK-NEXT: return
%cst = constant splat<tensor<4xi8>, 5> : tensor<4xi8>
%1 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">>
%2 = "quant.scast"(%1) : (tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">>) -> tensor<4xi8>
%3 = addi %2, %2 : tensor<4xi8>
return %3 : tensor<4xi8>
}
// -----
// CHECK-LABEL: non_redundant_scast
func @non_redundant_scast() -> tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">> {
// CHECK-NEXT: constant splat<tensor<4xi8>, 5>
// CHECK-NEXT: scast
// CHECK-NEXT: return
%cst = constant splat<tensor<4xi8>, 5> : tensor<4xi8>
%1 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">>
return %1 : tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">>
}