forked from OSchip/llvm-project
Implement lowering of element-wise fixed point add and mul to the standard dialect.
This also does the following: - Removes the poc POT add implementation in favor of a version that does not rescale. - Adds a handful of FxpMathOps which are needed (these are for comment and we may want to move them to the StandardOps dialect). - Adds a canonicalizer to the StorageCastOp, which removes some cruft once conversions have been done. - Adds a couple of predicates to OpBase. -- PiperOrigin-RevId: 244287706
This commit is contained in:
parent
7977e62b96
commit
e8d551e2bd
|
@ -90,18 +90,60 @@ class fxpmath_Op<string mnemonic, list<OpTrait> traits> :
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Fixed-point (fxp) arithmetic ops used by kernels.
|
||||
// Some of these are temporary pending inclusion into a more core dialect.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def fxpmath_RoundingDivideByPotFxpOp :
|
||||
fxpmath_Op<"rounding_divide_by_poti", [NoSideEffect, SameValueType]>,
|
||||
Arguments<(ins quant_StorageValueType:$x, I32Attr:$exponent)>,
|
||||
Results<(outs quant_StorageValueType:$y)> {
|
||||
def fxpmath_ClampISOp : fxpmath_Op<"clampis", [NoSideEffect, SameValueType]> {
|
||||
let summary =
|
||||
"Clamps a signed-integer like argument to a min/max range.";
|
||||
let description = [{
|
||||
Element-wise equivalent to:
|
||||
r = std::min(clamp_max, std::max(e, clamp_min))
|
||||
}];
|
||||
let arguments = (ins IntegerLike:$arg,
|
||||
APIntAttr:$clamp_min,
|
||||
APIntAttr:$clamp_max);
|
||||
let results = (outs IntegerLike);
|
||||
}
|
||||
|
||||
def fxpmath_ConvertISOp :
|
||||
fxpmath_Op<"convertis",
|
||||
[NoSideEffect, SameValueShape]> {
|
||||
let summary =
|
||||
"Does an element-wise conversion from a signed integer to signed integer";
|
||||
let description = [{
|
||||
Similar to an element-wise static_cast in C++, from one signed integer
|
||||
element type to another.
|
||||
}];
|
||||
let arguments = (ins IntegerLike:$arg);
|
||||
let results = (outs IntegerLike);
|
||||
}
|
||||
|
||||
def fxpmath_VecScalarSaturatingRoundingDoublingHighMulISOp :
|
||||
fxpmath_Op<"vs_saturating_rounding_doubling_high_mulis",
|
||||
[NoSideEffect, SameValueType]> {
|
||||
let summary = "Implements equivalent functionality to ARMv7 NEON VQRDMULH";
|
||||
let description = [{
|
||||
Equivalent to the ARMv7 NEON VQRDMULH instruction.
|
||||
See gemmlowp::SaturatingRoundingDoublingHighMul for a reference
|
||||
implementation.
|
||||
}];
|
||||
let arguments = (ins IntegerLike:$a, APIntAttr:$b);
|
||||
let results = (outs IntegerLike);
|
||||
}
|
||||
|
||||
def fxpmath_RoundingDivideByPotISOp :
|
||||
fxpmath_Op<"rounding_divide_by_potis", [NoSideEffect, SameValueType]> {
|
||||
let summary = [{
|
||||
Computes a rounding arithmetic right shift.
|
||||
}];
|
||||
let description = [{
|
||||
Computes integer division by a power-of-two, correctly rounded-to-nearest.
|
||||
Also known as a rounding arithmetic right shift. See
|
||||
gemmlowp::RoundingDivideByPOT for a reference implementation.
|
||||
}];
|
||||
|
||||
let arguments = (ins IntegerLike:$x, APIntAttr:$exponent);
|
||||
let results = (outs IntegerLike:$y);
|
||||
let verifier = [{
|
||||
auto verifyExponent = exponent().getSExtValue();
|
||||
if (verifyExponent < 0 || verifyExponent > 31) {
|
||||
|
@ -111,21 +153,6 @@ def fxpmath_RoundingDivideByPotFxpOp :
|
|||
}];
|
||||
}
|
||||
|
||||
def fxpmath_SaturatingAddFxpOp :
|
||||
fxpmath_Op<"saturating_addi", [NoSideEffect, SameValueType]>,
|
||||
Arguments<(ins quant_StorageValueType:$x,
|
||||
quant_StorageValueType:$y,
|
||||
I32Attr:$clamp_min,
|
||||
I32Attr:$clamp_max)>,
|
||||
Results<(outs quant_StorageValueType:$sum)> {
|
||||
let description = [{
|
||||
Computes saturating addition of two operands, saturating to the given min
|
||||
and max value. The implementation is responsible for choosing an
|
||||
intermediate register size appropriate to carry out the operation without
|
||||
overflow. See gemmlowp::SaturatingAdd for a reference implementation.
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Real math ops.
|
||||
//
|
||||
|
|
|
@ -336,11 +336,11 @@ class TypedStaticShapeTensor<Type t>
|
|||
: Type<AllOf<[ TypedTensor<t>.predicate, IsStaticShapeTensorTypePred ]>,
|
||||
"statically shaped tensor">;
|
||||
|
||||
def I1Tensor : TypedTensor<I1>;
|
||||
def I8Tensor : TypedTensor<I8>;
|
||||
def I16Tensor : TypedTensor<I16>;
|
||||
def I32Tensor : TypedTensor<I32>;
|
||||
def I64Tensor : TypedTensor<I64>;
|
||||
def I1Tensor : TypedTensor<I1>;
|
||||
def I8Tensor : TypedTensor<I8>;
|
||||
def I16Tensor : TypedTensor<I16>;
|
||||
def I32Tensor : TypedTensor<I32>;
|
||||
def I64Tensor : TypedTensor<I64>;
|
||||
|
||||
def BF16Tensor : TypedTensor<BF16>;
|
||||
def F16Tensor : TypedTensor<F16>;
|
||||
|
@ -503,6 +503,12 @@ class IntegerAttrBase<I attrValType, string descr> :
|
|||
let returnType = [{ APInt }];
|
||||
}
|
||||
|
||||
def APIntAttr : Attr<CPred<"$_self.isa<IntegerAttr>()">,
|
||||
"arbitrary integer attribute"> {
|
||||
let storageType = [{ IntegerAttr }];
|
||||
let returnType = [{ APInt }];
|
||||
}
|
||||
|
||||
def I32Attr : IntegerAttrBase<I32, "32-bit integer attribute">;
|
||||
def I64Attr : IntegerAttrBase<I64, "64-bit integer attribute">;
|
||||
|
||||
|
|
|
@ -92,6 +92,7 @@ def quant_DequantizeCastOp : quant_Op<"dcast", [NoSideEffect]> {
|
|||
def quant_StorageCastOp : quant_Op<"scast", [NoSideEffect]> {
|
||||
let arguments = (ins quant_RealOrStorageValueType:$arg);
|
||||
let results = (outs quant_RealOrStorageValueType);
|
||||
let hasCanonicalizer = 0b1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -131,6 +131,10 @@ def RemIUOp : IntArithmeticOp<"std.remiu"> {
|
|||
let hasConstantFolder = 0b1;
|
||||
}
|
||||
|
||||
def ShlISOp : IntArithmeticOp<"std.shlis"> {
|
||||
let summary = "signed integer shift left";
|
||||
}
|
||||
|
||||
def SubFOp : FloatArithmeticOp<"std.subf"> {
|
||||
let summary = "floating point subtraction operation";
|
||||
let hasConstantFolder = 0b1;
|
||||
|
|
|
@ -15,17 +15,17 @@
|
|||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#include "UniformKernelUtils.h"
|
||||
|
||||
#include "mlir/FxpMathOps/FxpMathOps.h"
|
||||
#include "mlir/FxpMathOps/Passes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Quantization/QuantOps.h"
|
||||
#include "mlir/Quantization/UniformSupport.h"
|
||||
|
||||
#include <functional>
|
||||
#include "mlir/StandardOps/Ops.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::fxpmath;
|
||||
using namespace mlir::fxpmath::detail;
|
||||
using namespace mlir::quant;
|
||||
|
||||
namespace {
|
||||
|
@ -35,186 +35,176 @@ struct LowerUniformRealMathPass
|
|||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
UniformQuantizedType getUniformElementType(Type t) {
|
||||
return QuantizedType::getQuantizedElementType(t)
|
||||
.dyn_cast_or_null<UniformQuantizedType>();
|
||||
}
|
||||
|
||||
/// Computes the log2(x), rounded to an integral value. Returns whether 'x' can
|
||||
/// be considered an exact integral value.
|
||||
template <typename F> bool integralLog2(F x, int &log2Result) {
|
||||
const F xLog2 = std::log(x) * (1.0 / std::log(2.0));
|
||||
const F xLog2Rounded = std::round(xLog2);
|
||||
const F xLog2Frac = xLog2 - xLog2Rounded;
|
||||
log2Result = static_cast<int>(xLog2Rounded);
|
||||
// Allow small comparison slop below the level that would make a difference
|
||||
// for 2^16 levels.
|
||||
return std::abs(xLog2Frac) < 1e-6;
|
||||
}
|
||||
|
||||
/// Helper class for operating on binary operations where all operands
|
||||
/// and the result are a UniformQuantizedType.
|
||||
struct RealBinaryOpInfo {
|
||||
RealBinaryOpInfo(Operation *op, Value *lhs, Value *rhs,
|
||||
Optional<APFloat> clampMin, Optional<APFloat> clampMax)
|
||||
: op(op), lhs(lhs), rhs(rhs), clampMin(clampMin), clampMax(clampMax),
|
||||
lhsType(getUniformElementType(lhs->getType())),
|
||||
rhsType(getUniformElementType(rhs->getType())),
|
||||
resultType(getUniformElementType(*op->result_type_begin())),
|
||||
lhsStorageType(QuantizedType::castToStorageType(lhs->getType())),
|
||||
rhsStorageType(QuantizedType::castToStorageType(rhs->getType())),
|
||||
resultStorageType(
|
||||
QuantizedType::castToStorageType(*op->result_type_begin())) {}
|
||||
|
||||
/// Returns whether this info is valid (all types defined, etc).
|
||||
bool isValid() const {
|
||||
return lhsType && rhsType && resultType && lhsStorageType &&
|
||||
rhsStorageType && resultStorageType;
|
||||
}
|
||||
|
||||
/// Returns whether the storage type of all operands is identical.
|
||||
bool isSameStorageType() const {
|
||||
return lhsType.getStorageType() == rhsType.getStorageType() &&
|
||||
lhsType.getStorageType() == resultType.getStorageType();
|
||||
}
|
||||
|
||||
/// Returns whether all operands and result are considered fixedpoint power
|
||||
/// of two, setting the lhs, rhs, and result log2 scale references.
|
||||
bool isFixedPointPOT(int &lhsLog2Scale, int &rhsLog2Scale,
|
||||
int &resultLog2Scale) const {
|
||||
if (!lhsType.isFixedPoint() || !rhsType.isFixedPoint() ||
|
||||
!resultType.isFixedPoint()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!integralLog2(lhsType.getScale(), lhsLog2Scale) ||
|
||||
!integralLog2(rhsType.getScale(), rhsLog2Scale) ||
|
||||
!integralLog2(resultType.getScale(), resultLog2Scale)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Gets the result integer clamp range given the result quantized type
|
||||
// and any explicit clamp provided as attributes.
|
||||
std::pair<IntegerAttr, IntegerAttr> getClampMinMax() const {
|
||||
int64_t typeMin = resultType.getStorageTypeMin();
|
||||
int64_t typeMax = resultType.getStorageTypeMax();
|
||||
|
||||
if (clampMin || clampMax) {
|
||||
UniformQuantizedValueConverter conv(resultType);
|
||||
if (clampMin) {
|
||||
typeMin = std::max(typeMin, conv.quantizeFloatToInt64(*clampMin));
|
||||
}
|
||||
if (clampMax) {
|
||||
typeMax = std::min(typeMax, conv.quantizeFloatToInt64(*clampMax));
|
||||
}
|
||||
}
|
||||
|
||||
// The quantized, integral ops expect clamps as 32bit ints.
|
||||
return {
|
||||
IntegerAttr::get(IntegerType::get(32, resultType.getContext()),
|
||||
typeMin),
|
||||
IntegerAttr::get(IntegerType::get(32, resultType.getContext()),
|
||||
typeMax),
|
||||
};
|
||||
}
|
||||
|
||||
Operation *op;
|
||||
Value *lhs;
|
||||
Value *rhs;
|
||||
Optional<APFloat> clampMin;
|
||||
Optional<APFloat> clampMax;
|
||||
|
||||
// Element UniformQuantizedType for operands/result.
|
||||
UniformQuantizedType lhsType;
|
||||
UniformQuantizedType rhsType;
|
||||
UniformQuantizedType resultType;
|
||||
|
||||
// Full storage-based types.
|
||||
Type lhsStorageType;
|
||||
Type rhsStorageType;
|
||||
Type resultStorageType;
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Elementwise add
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Attempts to rewrite a fixed point power-of-two addition of two integers.
|
||||
/// This supports a limited number of cases, but when supported, represents
|
||||
/// the simplest computation.
|
||||
static LogicalResult tryRewriteFixedPOTAddEw(const RealBinaryOpInfo &constInfo,
|
||||
PatternRewriter &rewriter) {
|
||||
if (!constInfo.isSameStorageType()) {
|
||||
|
||||
static LogicalResult
|
||||
tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info,
|
||||
PatternRewriter &rewriter) {
|
||||
if (!info.resultType.isSigned() || info.lhsType != info.resultType ||
|
||||
info.rhsType != info.resultType) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
int lhsLog2Scale;
|
||||
int rhsLog2Scale;
|
||||
int resultLog2Scale;
|
||||
if (!constInfo.isFixedPointPOT(lhsLog2Scale, rhsLog2Scale, resultLog2Scale)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Adjust shifts to be relative to the output.
|
||||
// Left shift of one input scale is supported. The other must match the result
|
||||
// scale.
|
||||
int lhsScaleShift = lhsLog2Scale - resultLog2Scale;
|
||||
int rhsScaleShift = rhsLog2Scale - resultLog2Scale;
|
||||
if (lhsScaleShift != 0 && rhsScaleShift != 0) {
|
||||
return failure();
|
||||
}
|
||||
if (lhsScaleShift > 0 || rhsScaleShift > 0) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// State accessed by the closure.
|
||||
Operation *mathOp = constInfo.op;
|
||||
const auto clampMinMax = constInfo.getClampMinMax();
|
||||
Value *lhs = constInfo.lhs;
|
||||
Value *rhs = constInfo.rhs;
|
||||
Type lhsStorageType = constInfo.lhsStorageType;
|
||||
Type rhsStorageType = constInfo.rhsStorageType;
|
||||
|
||||
// If the lhs operand is the one requiring a shift, swap it so that the shift
|
||||
// happens on the rhs operand.
|
||||
if (lhsScaleShift != 0) {
|
||||
std::swap(lhs, rhs);
|
||||
std::swap(lhsStorageType, rhsStorageType);
|
||||
std::swap(lhsScaleShift, rhsScaleShift);
|
||||
}
|
||||
int rhsRightShift = -rhsScaleShift;
|
||||
// Choose a byte aligned intermediate width big enough to perform the
|
||||
// calculation without overflow.
|
||||
// TODO: This should probably be made just big enough to avoid overflow and
|
||||
// leave the downstream tooling to decide how to align that to machine
|
||||
// word sizes.
|
||||
unsigned intermediateWidth =
|
||||
info.resultType.getStorageTypeIntegralWidth() <= 8 ? 16 : 32;
|
||||
IntegerType intermediateElementType =
|
||||
IntegerType::get(intermediateWidth, rewriter.getContext());
|
||||
Type intermediateType =
|
||||
castElementType(info.resultStorageType, intermediateElementType);
|
||||
|
||||
// Cast operands to storage type.
|
||||
Value *lhsStorageValue =
|
||||
rewriter.create<StorageCastOp>(mathOp->getLoc(), lhsStorageType, lhs)
|
||||
.getResult();
|
||||
Value *rhsStorageValue =
|
||||
rewriter.create<StorageCastOp>(mathOp->getLoc(), rhsStorageType, rhs)
|
||||
.getResult();
|
||||
Value *lhsValue = rewriter
|
||||
.create<StorageCastOp>(info.op->getLoc(),
|
||||
info.lhsStorageType, info.lhs)
|
||||
.getResult();
|
||||
Value *rhsValue = rewriter
|
||||
.create<StorageCastOp>(info.op->getLoc(),
|
||||
info.rhsStorageType, info.rhs)
|
||||
.getResult();
|
||||
|
||||
// Rescale the rhs operand if needed.
|
||||
if (rhsRightShift != 0) {
|
||||
rhsStorageValue =
|
||||
rewriter
|
||||
.create<RoundingDivideByPotFxpOp>(
|
||||
mathOp->getLoc(), rhsStorageValue,
|
||||
IntegerAttr::get(IntegerType::get(32, rewriter.getContext()),
|
||||
rhsRightShift))
|
||||
.getResult();
|
||||
}
|
||||
// Cast to the intermediate sized type.
|
||||
lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
|
||||
lhsValue);
|
||||
rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
|
||||
rhsValue);
|
||||
|
||||
// Add.
|
||||
Value *sumValue = rewriter.create<SaturatingAddFxpOp>(
|
||||
mathOp->getLoc(), lhsStorageValue, rhsStorageValue, clampMinMax.first,
|
||||
clampMinMax.second);
|
||||
Value *resultValue =
|
||||
rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, rhsValue);
|
||||
|
||||
// Zero point offset adjustment.
|
||||
// result = (lhs - zp) + (rhs - zp) + zp
|
||||
// zpOffset = -zp
|
||||
int zpOffset = -1 * info.resultType.getZeroPoint();
|
||||
if (zpOffset != 0) {
|
||||
Value *zpOffsetConst = rewriter.create<ConstantOp>(
|
||||
info.op->getLoc(),
|
||||
broadcastScalarConstIntValue(intermediateType, zpOffset));
|
||||
resultValue =
|
||||
rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
|
||||
}
|
||||
|
||||
// Clamp.
|
||||
auto clampMinMax = info.getClampMinMax(intermediateElementType);
|
||||
resultValue = rewriter.create<ClampISOp>(
|
||||
info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);
|
||||
|
||||
// Convert back to original type.
|
||||
resultValue = rewriter.create<ConvertISOp>(
|
||||
info.op->getLoc(), info.resultStorageType, resultValue);
|
||||
|
||||
// Cast back for new result.
|
||||
rewriter.replaceOpWithNewOp<StorageCastOp>(
|
||||
mathOp, *mathOp->result_type_begin(), sumValue);
|
||||
info.op, info.getQuantizedResultType(), resultValue);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Elementwise mul
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Attempts to rewrite an element-wise multiplication of uniform quantized
/// values whose result type is signed into integer arithmetic.
///
/// The emitted sequence is: storage-cast both operands, widen them to a 32-bit
/// intermediate element type, subtract each operand's zero point, multiply,
/// rescale by (lhsScale * rhsScale / resultScale) using a fixed-point
/// multiplier plus a rounding right shift, add the result zero point, clamp,
/// narrow back to the result storage type and storage-cast to the quantized
/// result type.
///
/// Returns failure() without modifying the IR when the result type is not
/// signed or when the combined output multiplier is not strictly below 1.0.
static LogicalResult
tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info,
                            PatternRewriter &rewriter) {
  if (!info.resultType.isSigned()) {
    return failure();
  }

  double outputMultiplierReal = info.lhsType.getScale() *
                                info.rhsType.getScale() /
                                info.resultType.getScale();
  // QuantizedMultiplierSmallerThanOneExp requires a multiplier strictly below
  // 1.0. A multiplier of exactly 1.0 would decompose with exponent == 1, which
  // turns into a negative shift amount for RoundingDivideByPotISOp below, so it
  // has to be rejected here as well.
  if (outputMultiplierReal >= 1.0) {
    info.op->emitWarning(
        "unimplemented: cannot multiply with multiplier >= 1.0");
    return failure();
  }

  // TODO: Choose an appropriate intermediate width for muls > 8 bits to
  // avoid overflow.
  unsigned intermediateWidth = 32;
  IntegerType intermediateElementType =
      IntegerType::get(intermediateWidth, rewriter.getContext());
  Type intermediateType =
      castElementType(info.resultStorageType, intermediateElementType);

  // Cast operands to storage type.
  Value *lhsValue = rewriter
                        .create<StorageCastOp>(info.op->getLoc(),
                                               info.lhsStorageType, info.lhs)
                        .getResult();
  Value *rhsValue = rewriter
                        .create<StorageCastOp>(info.op->getLoc(),
                                               info.rhsStorageType, info.rhs)
                        .getResult();

  // Cast to the intermediate sized type.
  lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
                                          lhsValue);
  rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
                                          rhsValue);

  // Apply argument zeroPoints (added as negated constants).
  if (info.lhsType.getZeroPoint() != 0) {
    Value *zpOffsetConst = rewriter.create<ConstantOp>(
        info.op->getLoc(), broadcastScalarConstIntValue(
                               intermediateType, -info.lhsType.getZeroPoint()));
    lhsValue =
        rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, zpOffsetConst);
  }

  if (info.rhsType.getZeroPoint() != 0) {
    Value *zpOffsetConst = rewriter.create<ConstantOp>(
        info.op->getLoc(), broadcastScalarConstIntValue(
                               intermediateType, -info.rhsType.getZeroPoint()));
    rhsValue =
        rewriter.create<AddIOp>(info.op->getLoc(), rhsValue, zpOffsetConst);
  }

  // Mul.
  Value *resultValue =
      rewriter.create<MulIOp>(info.op->getLoc(), lhsValue, rhsValue);

  // Scale output: fixed-point multiply followed by a rounding right shift by
  // the negated exponent of the decomposed multiplier.
  QuantizedMultiplierSmallerThanOneExp outputMultiplier(outputMultiplierReal);
  resultValue = rewriter.create<VecScalarSaturatingRoundingDoublingHighMulISOp>(
      info.op->getLoc(), resultValue,
      IntegerAttr::get(intermediateElementType, outputMultiplier.multiplier));
  resultValue = rewriter.create<RoundingDivideByPotISOp>(
      info.op->getLoc(), resultValue,
      IntegerAttr::get(intermediateElementType, -outputMultiplier.exponent));

  // Zero point offset adjustment.
  if (info.resultType.getZeroPoint() != 0) {
    Value *zpOffsetConst = rewriter.create<ConstantOp>(
        info.op->getLoc(),
        broadcastScalarConstIntValue(intermediateType,
                                     info.resultType.getZeroPoint()));
    resultValue =
        rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
  }

  // Clamp.
  auto clampMinMax = info.getClampMinMax(intermediateElementType);
  resultValue = rewriter.create<ClampISOp>(
      info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);

  // Convert back to original type.
  resultValue = rewriter.create<ConvertISOp>(
      info.op->getLoc(), info.resultStorageType, resultValue);

  // Cast back for new result.
  rewriter.replaceOpWithNewOp<StorageCastOp>(
      info.op, info.getQuantizedResultType(), resultValue);

  return success();
}
|
||||
|
||||
|
@ -227,14 +217,36 @@ struct UniformRealAddEwPattern : public RewritePattern {
|
|||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const {
|
||||
auto addOp = op->cast<RealAddEwOp>();
|
||||
const RealBinaryOpInfo info(op, addOp.x(), addOp.y(), addOp.clamp_min(),
|
||||
addOp.clamp_max());
|
||||
const UniformBinaryOpInfo info(op, addOp.x(), addOp.y(), addOp.clamp_min(),
|
||||
addOp.clamp_max());
|
||||
if (!info.isValid()) {
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
// Try all of the permutations we support.
|
||||
if (succeeded(tryRewriteFixedPOTAddEw(info, rewriter))) {
|
||||
if (succeeded(tryRewriteAffineAddEwIsomorphicSigned(info, rewriter))) {
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
return matchFailure();
|
||||
}
|
||||
};
|
||||
|
||||
struct UniformRealMulEwPattern : public RewritePattern {
|
||||
UniformRealMulEwPattern(MLIRContext *context)
|
||||
: RewritePattern(RealMulEwOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const {
|
||||
auto mulOp = op->cast<RealMulEwOp>();
|
||||
const UniformBinaryOpInfo info(op, mulOp.x(), mulOp.y(), mulOp.clamp_min(),
|
||||
mulOp.clamp_max());
|
||||
if (!info.isValid()) {
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
// Try all of the permutations we support.
|
||||
if (succeeded(tryRewriteAffineMulEwSigned(info, rewriter))) {
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
|
@ -249,6 +261,7 @@ void LowerUniformRealMathPass::runOnFunction() {
|
|||
OwningRewritePatternList patterns;
|
||||
auto *context = &getContext();
|
||||
patterns.push_back(llvm::make_unique<UniformRealAddEwPattern>(context));
|
||||
patterns.push_back(llvm::make_unique<UniformRealMulEwPattern>(context));
|
||||
applyPatternsGreedily(fn, std::move(patterns));
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,203 @@
|
|||
//===- UniformKernelUtils.h - Utilities for lowering uniform math - C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#ifndef MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
|
||||
#define MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/Quantization/QuantOps.h"
|
||||
#include "mlir/Quantization/QuantTypes.h"
|
||||
#include "mlir/Quantization/UniformSupport.h"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
namespace mlir {
|
||||
namespace fxpmath {
|
||||
namespace detail {
|
||||
|
||||
/// Returns the quantized element type of 't' as a UniformQuantizedType, or a
/// null type when the element type is absent or not uniform quantized.
inline quant::UniformQuantizedType getUniformElementType(Type t) {
  auto quantizedElementType =
      quant::QuantizedType::getQuantizedElementType(t);
  return quantizedElementType.dyn_cast_or_null<quant::UniformQuantizedType>();
}
|
||||
|
||||
/// Returns true if the bit width of the storage type of 't' equals any of the
/// widths listed in 'checkWidths'.
inline bool hasStorageBitWidth(quant::QuantizedType t,
                               llvm::ArrayRef<unsigned> checkWidths) {
  const unsigned storageWidth = t.getStorageType().getIntOrFloatBitWidth();
  for (auto it = checkWidths.begin(), end = checkWidths.end(); it != end;
       ++it) {
    if (*it == storageWidth)
      return true;
  }
  return false;
}
|
||||
|
||||
/// Computes log2(x) rounded to the nearest integer and stores it in
/// 'log2Result'. Returns whether 'x' can be considered an exact power of two
/// (within a small tolerance).
///
/// A non-positive or non-finite 'x' has no usable log2: 'log2Result' is set to
/// 0 and false is returned.
template <typename F> bool integralLog2(F x, int &log2Result) {
  // For such inputs std::log yields NaN or an infinity, and converting that to
  // int below would be undefined behavior.
  if (!(x > F(0)) || !std::isfinite(x)) {
    log2Result = 0;
    return false;
  }

  const F xLog2 = std::log(x) * (1.0 / std::log(2.0));
  const F xLog2Rounded = std::round(xLog2);
  const F xLog2Frac = xLog2 - xLog2Rounded;
  log2Result = static_cast<int>(xLog2Rounded);
  // Allow small comparison slop below the level that would make a difference
  // for 2^16 levels.
  return std::abs(xLog2Frac) < 1e-6;
}
|
||||
|
||||
/// Helper class for operating on binary operations where all operands
|
||||
/// and the result are a UniformQuantizedType.
|
||||
struct UniformBinaryOpInfo {
|
||||
UniformBinaryOpInfo(Operation *op, Value *lhs, Value *rhs,
|
||||
Optional<APFloat> clampMin, Optional<APFloat> clampMax)
|
||||
: op(op), lhs(lhs), rhs(rhs), clampMin(clampMin), clampMax(clampMax),
|
||||
lhsType(getUniformElementType(lhs->getType())),
|
||||
rhsType(getUniformElementType(rhs->getType())),
|
||||
resultType(getUniformElementType(*op->result_type_begin())),
|
||||
lhsStorageType(quant::QuantizedType::castToStorageType(lhs->getType())),
|
||||
rhsStorageType(quant::QuantizedType::castToStorageType(rhs->getType())),
|
||||
resultStorageType(
|
||||
quant::QuantizedType::castToStorageType(*op->result_type_begin())) {
|
||||
}
|
||||
|
||||
/// Returns whether this info is valid (all types defined, etc).
|
||||
bool isValid() const {
|
||||
return lhsType && rhsType && resultType && lhsStorageType &&
|
||||
rhsStorageType && resultStorageType;
|
||||
}
|
||||
|
||||
/// Gets the final quantized result type of the result.
|
||||
Type getQuantizedResultType() const { return *op->result_type_begin(); }
|
||||
|
||||
/// Returns whether the storage type of all operands is identical.
|
||||
bool isSameStorageType() const {
|
||||
return lhsType.getStorageType() == rhsType.getStorageType() &&
|
||||
lhsType.getStorageType() == resultType.getStorageType();
|
||||
}
|
||||
|
||||
/// Returns whether all operands and result are considered fixedpoint power
|
||||
/// of two, setting the lhs, rhs, and result log2 scale references.
|
||||
bool isFixedPointPOT(int &lhsLog2Scale, int &rhsLog2Scale,
|
||||
int &resultLog2Scale) const {
|
||||
if (!lhsType.isFixedPoint() || !rhsType.isFixedPoint() ||
|
||||
!resultType.isFixedPoint()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!integralLog2(lhsType.getScale(), lhsLog2Scale) ||
|
||||
!integralLog2(rhsType.getScale(), rhsLog2Scale) ||
|
||||
!integralLog2(resultType.getScale(), resultLog2Scale)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Gets the result integer clamp range given the result quantized type
|
||||
// and any explicit clamp provided as attributes.
|
||||
std::pair<IntegerAttr, IntegerAttr> getClampMinMax(IntegerType ty) const {
|
||||
int64_t typeMin = resultType.getStorageTypeMin();
|
||||
int64_t typeMax = resultType.getStorageTypeMax();
|
||||
|
||||
if (clampMin || clampMax) {
|
||||
quant::UniformQuantizedValueConverter conv(resultType);
|
||||
if (clampMin) {
|
||||
typeMin = std::max(typeMin, conv.quantizeFloatToInt64(*clampMin));
|
||||
}
|
||||
if (clampMax) {
|
||||
typeMax = std::min(typeMax, conv.quantizeFloatToInt64(*clampMax));
|
||||
}
|
||||
}
|
||||
|
||||
// The quantized, integral ops expect clamps as 32bit ints.
|
||||
return {
|
||||
IntegerAttr::get(ty, typeMin),
|
||||
IntegerAttr::get(ty, typeMax),
|
||||
};
|
||||
}
|
||||
|
||||
Operation *op;
|
||||
Value *lhs;
|
||||
Value *rhs;
|
||||
Optional<APFloat> clampMin;
|
||||
Optional<APFloat> clampMax;
|
||||
|
||||
// Element UniformQuantizedType for operands/result.
|
||||
quant::UniformQuantizedType lhsType;
|
||||
quant::UniformQuantizedType rhsType;
|
||||
quant::UniformQuantizedType resultType;
|
||||
|
||||
// Full storage-based types.
|
||||
Type lhsStorageType;
|
||||
Type rhsStorageType;
|
||||
Type resultStorageType;
|
||||
};
|
||||
|
||||
/// Decomposes a real multiplier in the open interval (0, 1) into a 32-bit
/// fixed-point 'multiplier' and a binary 'exponent' obtained from std::frexp.
struct QuantizedMultiplierSmallerThanOneExp {
  QuantizedMultiplierSmallerThanOneExp(double realMultiplier) {
    assert(realMultiplier < 1.0);
    assert(realMultiplier > 0.0);

    // 2^31: the scale applied to the frexp mantissa.
    constexpr int64_t kFixedOne = int64_t{1} << 31;

    int binaryExponent = 0;
    const double mantissa = std::frexp(realMultiplier, &binaryExponent);
    int64_t fixedMantissa =
        static_cast<int64_t>(std::round(mantissa * kFixedOne));
    assert(fixedMantissa <= kFixedOne);

    // Rounding can push the mantissa up to exactly 2^31, which does not fit in
    // int32_t; halve it and compensate in the exponent.
    if (fixedMantissa == kFixedOne) {
      fixedMantissa /= 2;
      ++binaryExponent;
    }
    assert(fixedMantissa <= std::numeric_limits<int32_t>::max());

    multiplier = static_cast<int32_t>(fixedMantissa);
    exponent = binaryExponent;
  }

  int32_t multiplier;
  int exponent;
};
|
||||
|
||||
/// Casts an integer or floating point based type to a new element type.
|
||||
inline Type castElementType(Type t, Type newElementType) {
|
||||
if (auto vt = t.dyn_cast<VectorOrTensorType>()) {
|
||||
switch (vt.getKind()) {
|
||||
case StandardTypes::Kind::Vector:
|
||||
return VectorType::get(vt.getShape(), newElementType);
|
||||
case StandardTypes::Kind::RankedTensor:
|
||||
return RankedTensorType::get(vt.getShape(), newElementType);
|
||||
case StandardTypes::Kind::UnrankedTensor:
|
||||
return UnrankedTensorType::get(newElementType);
|
||||
}
|
||||
}
|
||||
assert(t.isIntOrFloat());
|
||||
return newElementType;
|
||||
}
|
||||
|
||||
/// Creates an IntegerAttr with a type that matches the shape of 't' (which can
|
||||
/// be a primitive/vector/tensor).
|
||||
inline Attribute broadcastScalarConstIntValue(Type t, int64_t value) {
|
||||
if (auto vt = t.dyn_cast<VectorOrTensorType>()) {
|
||||
assert(vt.getElementType().isa<IntegerType>());
|
||||
return SplatElementsAttr::get(vt,
|
||||
IntegerAttr::get(vt.getElementType(), value));
|
||||
}
|
||||
|
||||
auto integerType = t.cast<IntegerType>();
|
||||
assert(t.isa<IntegerType>() && "integer broadcast must be of integer type");
|
||||
return IntegerAttr::get(integerType, value);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace fxpmath
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
|
|
@ -18,6 +18,8 @@
|
|||
#include "mlir/Quantization/QuantOps.h"
|
||||
#include "TypeDetail.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Quantization/QuantTypes.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
@ -31,6 +33,41 @@ using namespace mlir::quant::detail;
|
|||
#define GET_OP_CLASSES
|
||||
#include "mlir/Quantization/QuantOps.cpp.inc"
|
||||
|
||||
namespace {
|
||||
|
||||
/// Matches x -> [scast -> scast] -> y, replacing the second scast with the
|
||||
/// value of x if the casts invert each other.
|
||||
class RemoveRedundantStorageCastsRewrite : public RewritePattern {
|
||||
public:
|
||||
RemoveRedundantStorageCastsRewrite(MLIRContext *context)
|
||||
: RewritePattern(StorageCastOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult match(Operation *op) const override {
|
||||
auto scastOp = op->cast<StorageCastOp>();
|
||||
if (matchPattern(scastOp.arg(), m_Op<StorageCastOp>())) {
|
||||
auto srcScastOp = scastOp.arg()->getDefiningOp()->cast<StorageCastOp>();
|
||||
if (srcScastOp.arg()->getType() == scastOp.getResult()->getType()) {
|
||||
return matchSuccess();
|
||||
}
|
||||
}
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
|
||||
auto scastOp = op->cast<StorageCastOp>();
|
||||
auto srcScastOp = scastOp.arg()->getDefiningOp()->cast<StorageCastOp>();
|
||||
rewriter.replaceOp(op, srcScastOp.arg());
|
||||
}
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
void StorageCastOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  // Contribute the pattern that removes back-to-back storage casts.
  auto redundantCastPattern =
      llvm::make_unique<RemoveRedundantStorageCastsRewrite>(context);
  patterns.push_back(std::move(redundantCastPattern));
}
|
||||
|
||||
QuantizationDialect::QuantizationDialect(MLIRContext *context)
|
||||
: Dialect(/*name=*/"quant", context) {
|
||||
addTypes<UniformQuantizedType, UniformQuantizedPerAxisType>();
|
||||
|
|
|
@ -1,51 +1,44 @@
|
|||
// RUN: mlir-opt %s -split-input-file -fxpmath-lower-uniform-real-math | FileCheck %s --dump-input=fail
|
||||
// RUN: mlir-opt %s -split-input-file -fxpmath-lower-uniform-real-math -canonicalize | FileCheck %s --dump-input=always
|
||||
|
||||
// -----
|
||||
// Verify lowering when operands and result have the same fixedpoint pot scale.
|
||||
// CHECK-LABEL: real_addew_fixedpoint_same_scale
|
||||
// CHECK: %0 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %2 = "fxpmath.saturating_addi"(%0, %1) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %3 = "quant.scast"(%2) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
|
||||
// CHECK-NEXT: return %3 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
|
||||
// Verify lowering when operands and result have the same fixedpoint scale.
|
||||
// CHECK-LABEL: real_addew_fixedpoint_isomorphic
|
||||
!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
func @real_addew_fixedpoint_same_scale(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
|
||||
func @real_addew_fixedpoint_isomorphic(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
|
||||
// CHECK-NEXT: %0 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %2 = "fxpmath.convertis"(%0) : (tensor<4xi8>) -> tensor<4xi16>
|
||||
// CHECK-NEXT: %3 = "fxpmath.convertis"(%1) : (tensor<4xi8>) -> tensor<4xi16>
|
||||
// CHECK-NEXT: %4 = addi %2, %3 : tensor<4xi16>
|
||||
// CHECK-NEXT: %5 = "fxpmath.clampis"(%4) {clamp_max: 127 : i16, clamp_min: -128 : i16} : (tensor<4xi16>) -> tensor<4xi16>
|
||||
// CHECK-NEXT: %6 = "fxpmath.convertis"(%5) : (tensor<4xi16>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %7 = "quant.scast"(%6) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
|
||||
// CHECK-NEXT: return %7 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
|
||||
%0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
|
||||
return %0 : !type_result
|
||||
}
|
||||
|
||||
// -----
|
||||
// Verify lowering when the rhs is a shifted pot scale compared to lhs and result.
|
||||
// CHECK-LABEL: real_addew_fixedpoint_rhs_shift
|
||||
// CHECK: %0 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %2 = "fxpmath.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %3 = "fxpmath.saturating_addi"(%0, %2) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %4 = "quant.scast"(%3) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
|
||||
// CHECK-NEXT: return %4 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
|
||||
!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{7.8125e-03}">>
|
||||
!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
func @real_addew_fixedpoint_rhs_shift(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
|
||||
%0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
|
||||
return %0 : !type_result
|
||||
}
|
||||
|
||||
// -----
|
||||
// Verify lowering when the lhs is a shifted pot scale compared to rhs and result.
|
||||
// CHECK-LABEL: real_addew_fixedpoint_lhs_shift
|
||||
// CHECK: %0 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %1 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %2 = "fxpmath.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %3 = "fxpmath.saturating_addi"(%0, %2) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %4 = "quant.scast"(%3) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
|
||||
// CHECK-NEXT: return %4 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
|
||||
!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{7.8125e-03}">>
|
||||
!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
func @real_addew_fixedpoint_lhs_shift(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
|
||||
// Verify lowering when operands and result have the same fixedpoint scale
|
||||
// and non-zero zero points.
|
||||
// CHECK-LABEL: real_addew_affine_isomorphic
|
||||
!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2:-5}">>
|
||||
!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2:-5}">>
|
||||
!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2:-5}">>
|
||||
func @real_addew_affine_isomorphic(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
|
||||
// CHECK-NEXT: %cst = constant splat<tensor<4xi16>, 5> : tensor<4xi16>
|
||||
// CHECK-NEXT: %0 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02:-5}">>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02:-5}">>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %2 = "fxpmath.convertis"(%0) : (tensor<4xi8>) -> tensor<4xi16>
|
||||
// CHECK-NEXT: %3 = "fxpmath.convertis"(%1) : (tensor<4xi8>) -> tensor<4xi16>
|
||||
// CHECK-NEXT: %4 = addi %2, %3 : tensor<4xi16>
|
||||
// CHECK-NEXT: %5 = addi %4, %cst : tensor<4xi16>
|
||||
// CHECK-NEXT: %6 = "fxpmath.clampis"(%5) {clamp_max: 127 : i16, clamp_min: -128 : i16} : (tensor<4xi16>) -> tensor<4xi16>
|
||||
// CHECK-NEXT: %7 = "fxpmath.convertis"(%6) : (tensor<4xi16>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %8 = "quant.scast"(%7) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02:-5}">>
|
||||
// CHECK-NEXT: return %8 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02:-5}">>
|
||||
%0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
|
||||
return %0 : !type_result
|
||||
}
|
||||
|
@ -54,16 +47,19 @@ func @real_addew_fixedpoint_lhs_shift(%arg0 : !type_lhs, %arg1: !type_rhs) -> !t
|
|||
// The RHS quant parameters prescribe a range of [-8..8) so an explicit clamp
|
||||
// of [-4..4] should result in an integral clamp range of [-64..64].
|
||||
// CHECK-LABEL: real_addew_fixedpoint_clamp
|
||||
// CHECK: %0 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %1 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %2 = "fxpmath.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %3 = "fxpmath.saturating_addi"(%0, %2) {clamp_max: 64 : i32, clamp_min: -64 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %4 = "quant.scast"(%3) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
|
||||
// CHECK-NEXT: return %4 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
|
||||
!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{7.8125e-03}">>
|
||||
!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
func @real_addew_fixedpoint_clamp(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
|
||||
// CHECK-NEXT: %0 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %2 = "fxpmath.convertis"(%0) : (tensor<4xi8>) -> tensor<4xi16>
|
||||
// CHECK-NEXT: %3 = "fxpmath.convertis"(%1) : (tensor<4xi8>) -> tensor<4xi16>
|
||||
// CHECK-NEXT: %4 = addi %2, %3 : tensor<4xi16>
|
||||
// CHECK-NEXT: %5 = "fxpmath.clampis"(%4) {clamp_max: 64 : i16, clamp_min: -64 : i16} : (tensor<4xi16>) -> tensor<4xi16>
|
||||
// CHECK-NEXT: %6 = "fxpmath.convertis"(%5) : (tensor<4xi16>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %7 = "quant.scast"(%6) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
|
||||
// CHECK-NEXT: return %7 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
|
||||
%0 = "fxpmath.real_add_ew"(%arg0, %arg1) { clamp_min:-4.0, clamp_max:4.0 }
|
||||
: (!type_lhs, !type_rhs) -> (!type_result)
|
||||
return %0 : !type_result
|
||||
|
|
|
@ -0,0 +1,94 @@
|
|||
// RUN: mlir-opt %s -split-input-file -fxpmath-lower-uniform-real-math -canonicalize -verify | FileCheck %s --dump-input=always
|
||||
|
||||
// -----
|
||||
// Verify lowering when operands and result have different fixedpoint scales.
|
||||
// CHECK-LABEL: real_mulew_fixedpoint
|
||||
!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{3.875e-2}">>
|
||||
!type_result = type tensor<4x!quant<"uniform[i8:f32]{1.065e-1}">>
|
||||
func @real_mulew_fixedpoint(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
|
||||
// CHECK: %0 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{3.875000e-02}">>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %2 = "fxpmath.convertis"(%0) : (tensor<4xi8>) -> tensor<4xi32>
|
||||
// CHECK-NEXT: %3 = "fxpmath.convertis"(%1) : (tensor<4xi8>) -> tensor<4xi32>
|
||||
// CHECK-NEXT: %4 = muli %2, %3 : tensor<4xi32>
|
||||
// CHECK-NEXT: %5 = "fxpmath.vs_saturating_rounding_doubling_high_mulis"(%4) {b: 1562722842 : i32} : (tensor<4xi32>) -> tensor<4xi32>
|
||||
// CHECK-NEXT: %6 = "fxpmath.rounding_divide_by_potis"(%5) {exponent: 5 : i32} : (tensor<4xi32>) -> tensor<4xi32>
|
||||
// CHECK-NEXT: %7 = "fxpmath.clampis"(%6) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi32>) -> tensor<4xi32>
|
||||
// CHECK-NEXT: %8 = "fxpmath.convertis"(%7) : (tensor<4xi32>) -> tensor<4xi8>
|
||||
// CHECK-NEXT: %9 = "quant.scast"(%8) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{1.065000e-01}">>
|
||||
// CHECK-NEXT: return %9 : tensor<4x!quant<"uniform[i8:f32]{1.065000e-01}">>
|
||||
%0 = "fxpmath.real_mul_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
|
||||
return %0 : !type_result
|
||||
}
|
||||
|
||||
// -----
|
||||
// Verify lowering when operands and result have the same fixedpoint scale
|
||||
// and non-zero zero points.
|
||||
// CHECK-LABEL: real_mulew_affine_clamp
|
||||
!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2:-3}">>
|
||||
!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2:-5}">>
|
||||
!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2:-9}">>
|
||||
func @real_mulew_affine_clamp(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
|
||||
// Just verify that the affine adds/constants and clamps are present.
|
||||
// CHECK: %cst = constant splat<tensor<4xi32>, 3> : tensor<4xi32>
|
||||
// CHECK: %cst_0 = constant splat<tensor<4xi32>, 5> : tensor<4xi32>
|
||||
// CHECK: %cst_1 = constant splat<tensor<4xi32>, -9> : tensor<4xi32>
|
||||
// CHECK: addi %2, %cst : tensor<4xi32>
|
||||
// CHECK: addi %3, %cst_0 : tensor<4xi32>
|
||||
// CHECK: muli %4, %5 : tensor<4xi32>
|
||||
// CHECK: addi %8, %cst_1 : tensor<4xi32>
|
||||
// CHECK: {clamp_max: 55 : i32, clamp_min: -73 : i32}
|
||||
%0 = "fxpmath.real_mul_ew"(%arg0, %arg1) { clamp_min:-4.0, clamp_max:4.0 } : (!type_lhs, !type_rhs) -> (!type_result)
|
||||
return %0 : !type_result
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: real_mulew_unquantized_lhs
|
||||
// Verifies that the op is left as-is for an unquantized lhs.
|
||||
!type_lhs = type tensor<4xf32>
|
||||
!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
func @real_mulew_unquantized_lhs(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
|
||||
// CHECK: %0 = "fxpmath.real_mul_ew"(%arg0, %arg1)
|
||||
%0 = "fxpmath.real_mul_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
|
||||
return %0 : !type_result
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: real_mulew_unquantized_rhs
|
||||
// Verifies that the op is left as-is for an unquantized rhs.
|
||||
!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
!type_rhs = type tensor<4xf32>
|
||||
!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
func @real_mulew_unquantized_rhs(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
|
||||
// CHECK: %0 = "fxpmath.real_mul_ew"(%arg0, %arg1)
|
||||
%0 = "fxpmath.real_mul_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
|
||||
return %0 : !type_result
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: real_mulew_unquantized_result
|
||||
// Verifies that the op is left as-is for an unquantized result.
|
||||
!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
!type_result = type tensor<4xf32>
|
||||
func @real_mulew_unquantized_result(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
|
||||
// CHECK: %0 = "fxpmath.real_mul_ew"(%arg0, %arg1)
|
||||
%0 = "fxpmath.real_mul_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
|
||||
return %0 : !type_result
|
||||
}
|
||||
|
||||
// -----
|
||||
// Verify that the op is not lowered (and a warning is emitted) when the
// combined multiplier is greater than 1.0.
|
||||
// Note that the multiplier = lhs_scale * rhs_scale / result_scale
|
||||
// = 22.740610328638496
|
||||
// CHECK-LABEL: real_mulew_multiplier_gt_1
|
||||
!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
|
||||
!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{3.875e-2}">>
|
||||
!type_result = type tensor<4x!quant<"uniform[i8:f32]{1.065e-4}">>
|
||||
func @real_mulew_multiplier_gt_1(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
|
||||
// expected-warning@+1 {{unimplemented: cannot multiply with multipler > 1.0}}
|
||||
%0 = "fxpmath.real_mul_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
|
||||
return %0 : !type_result
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s --dump-input=fail
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: redundant_scast
|
||||
func @redundant_scast() -> tensor<4xi8> {
|
||||
// CHECK-NEXT: constant splat<tensor<4xi8>, 10>
|
||||
// CHECK-NEXT: return
|
||||
%cst = constant splat<tensor<4xi8>, 5> : tensor<4xi8>
|
||||
%1 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">>
|
||||
%2 = "quant.scast"(%1) : (tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">>) -> tensor<4xi8>
|
||||
%3 = addi %2, %2 : tensor<4xi8>
|
||||
return %3 : tensor<4xi8>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: non_redundant_scast
|
||||
func @non_redundant_scast() -> tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">> {
|
||||
// CHECK-NEXT: constant splat<tensor<4xi8>, 5>
|
||||
// CHECK-NEXT: scast
|
||||
// CHECK-NEXT: return
|
||||
%cst = constant splat<tensor<4xi8>, 5> : tensor<4xi8>
|
||||
%1 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">>
|
||||
return %1 : tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">>
|
||||
}
|
Loading…
Reference in New Issue