forked from OSchip/llvm-project
[mlir][tosa] Rework tosa.apply_scale lowering for 32-bit
Added handling rounding behavior in 32-bits for when possible. This avoids kernel compilation generating scalarized code on platforms where 64-bit vectors are not available. As the 48-bit lowering requires 64-bit anyway, we added a full 64-bit solution simplifying the old path. Reviewed By: dcaballe, mravishankar Differential Revision: https://reviews.llvm.org/D125583
This commit is contained in:
parent
d4545e6fa0
commit
9294a1e9a8
|
@ -756,7 +756,10 @@ def TosaToArith : Pass<"tosa-to-arith"> {
|
|||
let options = [
|
||||
Option<"includeApplyRescale", "include-apply-rescale",
|
||||
"bool", /*default=*/"false",
|
||||
"Whether to include the lowering for tosa.apply_rescale to arith">
|
||||
"Whether to include the lowering for tosa.apply_rescale to arith">,
|
||||
Option<"use32Bit", "use-32-bit",
|
||||
"bool", /*default=*/"false",
|
||||
"Whether to prioritze lowering to 32-bit operations">
|
||||
];
|
||||
|
||||
let constructor = "tosa::createTosaToArith()";
|
||||
|
|
|
@ -22,7 +22,8 @@ std::unique_ptr<Pass> createTosaToArith();
|
|||
|
||||
void populateTosaToArithConversionPatterns(RewritePatternSet *patterns);
|
||||
|
||||
void populateTosaRescaleToArithConversionPatterns(RewritePatternSet *patterns);
|
||||
void populateTosaRescaleToArithConversionPatterns(RewritePatternSet *patterns,
|
||||
bool include32Bit = false);
|
||||
|
||||
} // namespace tosa
|
||||
} // namespace mlir
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -49,103 +50,194 @@ Attribute getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
|
|||
return rewriter.getIntegerAttr(type, value);
|
||||
}
|
||||
|
||||
Value getConstantValue(Location loc, Type type, int64_t value,
|
||||
PatternRewriter &rewriter) {
|
||||
return rewriter.create<arith::ConstantOp>(
|
||||
loc, getConstantAttr(type, value, rewriter));
|
||||
}
|
||||
|
||||
// This converts the TOSA ApplyScale operator to a set of arithmetic ops,
|
||||
// using 64-bit operations to perform the necessary multiply, bias, and shift.
|
||||
// Multiple types are used to use minimal bit width operations.
|
||||
class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
|
||||
class ApplyScaleGenericOpConverter
|
||||
: public OpRewritePattern<tosa::ApplyScaleOp> {
|
||||
public:
|
||||
using OpRewritePattern<tosa::ApplyScaleOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
Location loc = op.getLoc();
|
||||
Value value32 = op.value();
|
||||
Value value = op.value();
|
||||
Value multiplier32 = op.multiplier();
|
||||
Value shift8 = op.shift();
|
||||
|
||||
bool doubleRound = op.double_round();
|
||||
Type inType = op.value().getType();
|
||||
Type resultTy = op.getType();
|
||||
|
||||
Type i8Ty = matchContainerType(rewriter.getIntegerType(8), resultTy);
|
||||
Type valueTy = value.getType();
|
||||
Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
|
||||
Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy);
|
||||
|
||||
Value one8 = rewriter.create<arith::ConstantOp>(
|
||||
loc, getConstantAttr(i8Ty, 1, rewriter));
|
||||
Value one64 = rewriter.create<arith::ConstantOp>(
|
||||
loc, getConstantAttr(i64Ty, 1, rewriter));
|
||||
Value zero = getConstantValue(loc, valueTy, 0, rewriter);
|
||||
Value one64 = getConstantValue(loc, i64Ty, 1, rewriter);
|
||||
Value thirtyOne32 = getConstantValue(loc, i32Ty, 31, rewriter);
|
||||
|
||||
Value shiftSubOne8 = rewriter.create<arith::SubIOp>(loc, shift8, one8);
|
||||
Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.shift());
|
||||
|
||||
// The rounding value semantics below equate to the following code:
|
||||
// int64_t round = 1 << (shift - 1);
|
||||
// if (double_round) {
|
||||
// if (shift > 31 && value >= 0) round += 1<<30;
|
||||
// if (shift > 31 && value < 0) round -= 1<<30;
|
||||
// }
|
||||
//
|
||||
// Note that minimal bitwidth operators are used throughout the block.
|
||||
// Compute the multiplication in 64-bits then select the high / low parts.
|
||||
Value value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value);
|
||||
Value multiplier64 =
|
||||
rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
|
||||
Value multiply64 =
|
||||
rewriter.create<arith::MulIOp>(loc, value64, multiplier64);
|
||||
|
||||
Value round64 = rewriter.create<arith::ShLIOp>(
|
||||
loc, one64, rewriter.create<arith::ExtSIOp>(loc, i64Ty, shiftSubOne8));
|
||||
// Apply normal rounding.
|
||||
Value shift64 = rewriter.create<arith::ExtUIOp>(loc, i64Ty, shift32);
|
||||
Value round = rewriter.create<arith::ShLIOp>(loc, one64, shift64);
|
||||
round = rewriter.create<arith::ShRUIOp>(loc, round, one64);
|
||||
multiply64 = rewriter.create<arith::AddIOp>(loc, multiply64, round);
|
||||
|
||||
// Double rounding is performing a round operation before the shift
|
||||
if (doubleRound) {
|
||||
Value one32 = rewriter.create<arith::ConstantOp>(
|
||||
loc, getConstantAttr(i32Ty, 1, rewriter));
|
||||
Value shift32 = rewriter.create<arith::ExtSIOp>(loc, i32Ty, shift8);
|
||||
Value thirty32 = rewriter.create<arith::ConstantOp>(
|
||||
loc, getConstantAttr(i32Ty, 30, rewriter));
|
||||
|
||||
Value shiftThirty32 =
|
||||
rewriter.create<arith::ShLIOp>(loc, one32, thirty32);
|
||||
Value shiftThirty64 =
|
||||
rewriter.create<arith::ExtSIOp>(loc, i64Ty, shiftThirty32);
|
||||
|
||||
// Round value needs to with be added or subtracted depending on the sign
|
||||
// of the input value.
|
||||
Value roundAdd64 =
|
||||
rewriter.create<arith::AddIOp>(loc, round64, shiftThirty64);
|
||||
Value roundSub64 =
|
||||
rewriter.create<arith::SubIOp>(loc, round64, shiftThirty64);
|
||||
|
||||
Value zero32 =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(inType));
|
||||
Value valueGreaterThanZero = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sge, value32, zero32);
|
||||
|
||||
Value doubleRound64 = rewriter.create<arith::SelectOp>(
|
||||
loc, valueGreaterThanZero, roundAdd64, roundSub64);
|
||||
|
||||
// We only perform double rounding if the shift value is greater than 32.
|
||||
Value thirtyTwo32 = rewriter.create<arith::ConstantOp>(
|
||||
loc, getConstantAttr(i32Ty, 32, rewriter));
|
||||
Value shiftGreaterThanThirtyTwo = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
|
||||
round64 = rewriter.create<arith::SelectOp>(loc, shiftGreaterThanThirtyTwo,
|
||||
doubleRound64, round64);
|
||||
// Apply double rounding if necessary.
|
||||
if (op.double_round()) {
|
||||
int64_t roundInt = 1 << 30;
|
||||
Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
|
||||
Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
|
||||
Value positive = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sge, value, zero);
|
||||
Value dir =
|
||||
rewriter.create<arith::SelectOp>(loc, positive, roundUp, roundDown);
|
||||
Value val = rewriter.create<arith::AddIOp>(loc, dir, multiply64);
|
||||
Value valid = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32);
|
||||
multiply64 =
|
||||
rewriter.create<arith::SelectOp>(loc, valid, val, multiply64);
|
||||
}
|
||||
|
||||
// The computation below equates to the following pseudocode:
|
||||
// int64_t result = (int64_t)value * multiplier + round;
|
||||
// result = result >> shift;
|
||||
//
|
||||
// Note that multiply and shift need to be perform in i64 to preserve bits.
|
||||
Value result64 = rewriter.create<arith::ShRSIOp>(loc, multiply64, shift64);
|
||||
Value result32 = rewriter.create<arith::TruncIOp>(loc, i32Ty, result64);
|
||||
|
||||
rewriter.replaceOp(op, result32);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
|
||||
public:
|
||||
using OpRewritePattern<tosa::ApplyScaleOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
Location loc = op.getLoc();
|
||||
|
||||
Type resultTy = op.getType();
|
||||
Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
|
||||
Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy);
|
||||
|
||||
Value value = op.value();
|
||||
if (getElementTypeOrSelf(value.getType()).getIntOrFloatBitWidth() > 32) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
Value value32 = op.value();
|
||||
Value multiplier32 = op.multiplier();
|
||||
Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.shift());
|
||||
|
||||
// Constants used during the scaling operation.
|
||||
Value zero32 = getConstantValue(loc, i32Ty, 0, rewriter);
|
||||
Value one32 = getConstantValue(loc, i32Ty, 1, rewriter);
|
||||
Value two32 = getConstantValue(loc, i32Ty, 2, rewriter);
|
||||
Value thirty32 = getConstantValue(loc, i32Ty, 30, rewriter);
|
||||
Value thirtyTwo32 = getConstantValue(loc, i32Ty, 32, rewriter);
|
||||
Value thirtyTwo64 = getConstantValue(loc, i64Ty, 32, rewriter);
|
||||
|
||||
// Compute the multiplication in 64-bits then select the high / low parts.
|
||||
Value value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value32);
|
||||
Value multiplier64 =
|
||||
rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
|
||||
Value shift64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, shift8);
|
||||
Value multiply64 =
|
||||
rewriter.create<arith::MulIOp>(loc, value64, multiplier64);
|
||||
|
||||
// Multiply as a pair of i64 values to guarantee the end value fits.
|
||||
Value result64 = rewriter.create<arith::MulIOp>(loc, value64, multiplier64);
|
||||
result64 = rewriter.create<arith::AddIOp>(loc, result64, round64);
|
||||
result64 = rewriter.create<arith::ShRSIOp>(loc, result64, shift64);
|
||||
// Grab out the high/low of the computation
|
||||
Value high64 =
|
||||
rewriter.create<arith::ShRUIOp>(loc, multiply64, thirtyTwo64);
|
||||
Value high32 = rewriter.create<arith::TruncIOp>(loc, i32Ty, high64);
|
||||
Value low32 = rewriter.create<arith::MulIOp>(loc, value32, multiplier32);
|
||||
|
||||
Value result32 = rewriter.create<arith::TruncIOp>(loc, resultTy, result64);
|
||||
// Determine the direction and amount to shift the high bits.
|
||||
Value shiftOver32 = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
|
||||
Value roundHighBits = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32);
|
||||
|
||||
rewriter.replaceOp(op, result32);
|
||||
Value shiftHighL =
|
||||
rewriter.create<arith::SubIOp>(loc, thirtyTwo32, shift32);
|
||||
Value shiftHighR =
|
||||
rewriter.create<arith::SubIOp>(loc, shift32, thirtyTwo32);
|
||||
|
||||
shiftHighL =
|
||||
rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, shiftHighL);
|
||||
shiftHighR =
|
||||
rewriter.create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32);
|
||||
|
||||
// Conditionally perform our double round.
|
||||
if (op.double_round()) {
|
||||
Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
|
||||
Value valuePositive = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sge, value32, zero32);
|
||||
|
||||
Value roundDir =
|
||||
rewriter.create<arith::SelectOp>(loc, valuePositive, one32, negOne32);
|
||||
roundDir =
|
||||
rewriter.create<arith::SelectOp>(loc, shiftOver32, roundDir, zero32);
|
||||
|
||||
Value shiftLow = rewriter.create<arith::ShRUIOp>(loc, low32, thirty32);
|
||||
Value rounded = rewriter.create<arith::AddIOp>(loc, shiftLow, roundDir);
|
||||
Value carry = rewriter.create<arith::ShRSIOp>(loc, rounded, two32);
|
||||
|
||||
Value shiftRound =
|
||||
rewriter.create<arith::ShLIOp>(loc, roundDir, thirty32);
|
||||
|
||||
low32 = rewriter.create<arith::AddIOp>(loc, low32, shiftRound);
|
||||
high32 = rewriter.create<arith::AddIOp>(loc, high32, carry);
|
||||
}
|
||||
|
||||
// Conditionally apply rounding in the low bits.
|
||||
{
|
||||
Value shiftSubOne = rewriter.create<arith::SubIOp>(loc, shift32, one32);
|
||||
Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne);
|
||||
roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, zero32,
|
||||
roundBit);
|
||||
|
||||
Value newLow32 = rewriter.create<arith::AddIOp>(loc, low32, roundBit);
|
||||
Value wasRounded = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::ugt, low32, newLow32);
|
||||
low32 = newLow32;
|
||||
|
||||
Value rounded32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, wasRounded);
|
||||
high32 = rewriter.create<arith::AddIOp>(loc, high32, rounded32);
|
||||
}
|
||||
|
||||
// Conditionally apply rounding in the high bits.
|
||||
{
|
||||
Value shiftSubOne =
|
||||
rewriter.create<arith::SubIOp>(loc, shiftHighR, one32);
|
||||
Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne);
|
||||
roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, roundBit,
|
||||
zero32);
|
||||
high32 = rewriter.create<arith::AddIOp>(loc, high32, roundBit);
|
||||
}
|
||||
|
||||
// Combine the correct high/low bits into the final rescale result.
|
||||
high32 = rewriter.create<arith::ShLIOp>(loc, high32, shiftHighL);
|
||||
high32 = rewriter.create<arith::ShRSIOp>(loc, high32, shiftHighR);
|
||||
low32 = rewriter.create<arith::ShRUIOp>(loc, low32, shift32);
|
||||
low32 = rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, low32);
|
||||
|
||||
// Apply the rounding behavior and shift to the final alignment.
|
||||
Value result = rewriter.create<arith::AddIOp>(loc, low32, high32);
|
||||
|
||||
// Truncate if necessary.
|
||||
if (!getElementTypeOrSelf(resultTy).isInteger(32)) {
|
||||
result = rewriter.create<arith::TruncIOp>(loc, resultTy, result);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -158,6 +250,9 @@ void mlir::tosa::populateTosaToArithConversionPatterns(
|
|||
}
|
||||
|
||||
void mlir::tosa::populateTosaRescaleToArithConversionPatterns(
|
||||
RewritePatternSet *patterns) {
|
||||
patterns->add<ApplyScaleOpConverter>(patterns->getContext());
|
||||
RewritePatternSet *patterns, bool include32Bit) {
|
||||
patterns->add<ApplyScaleGenericOpConverter>(patterns->getContext(), 100);
|
||||
if (include32Bit) {
|
||||
patterns->add<ApplyScale32BitOpConverter>(patterns->getContext(), 200);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -36,7 +36,8 @@ public:
|
|||
mlir::tosa::populateTosaToArithConversionPatterns(&patterns);
|
||||
|
||||
if (this->includeApplyRescale) {
|
||||
mlir::tosa::populateTosaRescaleToArithConversionPatterns(&patterns);
|
||||
mlir::tosa::populateTosaRescaleToArithConversionPatterns(&patterns,
|
||||
this->use32Bit);
|
||||
target.addIllegalOp<tosa::ApplyScaleOp>();
|
||||
}
|
||||
|
||||
|
|
|
@ -1,119 +1,126 @@
|
|||
// RUN: mlir-opt --split-input-file --tosa-to-arith="include-apply-rescale=true" %s -verify-diagnostics -o -| FileCheck %s
|
||||
// RUN: mlir-opt --split-input-file --tosa-to-arith="include-apply-rescale=true use-32-bit=true" %s -verify-diagnostics -o -| FileCheck %s
|
||||
// RUN: mlir-opt --split-input-file --tosa-to-arith="include-apply-rescale=false" %s -verify-diagnostics -o -| FileCheck --check-prefix="SCALE" %s
|
||||
|
||||
// CHECK-LABEL: func @const_test
|
||||
func.func @const_test() -> (tensor<i32>) {
|
||||
// CHECK: [[C3:%.+]] = arith.constant dense<3> : tensor<i32>
|
||||
%0 = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||
%result = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||
|
||||
// CHECK: return [[C3]]
|
||||
return %0 : tensor<i32>
|
||||
return %result : tensor<i32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @apply_scale_test_i32
|
||||
// SCALE: "tosa.apply_scale"
|
||||
func.func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
|
||||
// CHECK-DAG: [[C1_8:%.+]] = arith.constant 1 : i8
|
||||
// CHECK-DAG: [[C1_32:%.+]] = arith.constant 1 : i32
|
||||
// CHECK-DAG: [[C1_64:%.+]] = arith.constant 1 : i64
|
||||
// CHECK-DAG: [[SHIFT_MINUS_ONE_8:%.+]] = arith.subi %arg2, [[C1_8]]
|
||||
// CHECK-DAG: %[[S32:.+]] = arith.extui %arg2 : i8 to i32
|
||||
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32
|
||||
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i32
|
||||
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : i32
|
||||
// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : i32
|
||||
// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : i32
|
||||
// CHECK-DAG: %[[C32L:.+]] = arith.constant 32 : i64
|
||||
|
||||
// CHECK-DAG: [[SHIFT_32:%.+]] = arith.extsi %arg2 : i8 to i32
|
||||
// CHECK-DAG: [[SHIFT_MINUS_ONE_64:%.+]] = arith.extsi [[SHIFT_MINUS_ONE_8]] : i8 to i64
|
||||
// CHECK-DAG: [[SHIFTED_64:%.+]] = arith.shli [[C1_64]], [[SHIFT_MINUS_ONE_64]]
|
||||
// Compute the high-low values of the matmul in 64-bits.
|
||||
// CHECK-DAG: %[[V64:.+]] = arith.extsi %arg0 : i32 to i64
|
||||
// CHECK-DAG: %[[M64:.+]] = arith.extsi %arg1 : i32 to i64
|
||||
// CHECK-DAG: %[[MUL64:.+]] = arith.muli %[[V64]], %[[M64]]
|
||||
// CHECK-DAG: %[[HI64:.+]] = arith.shrui %[[MUL64]], %[[C32L]]
|
||||
// CHECK-DAG: %[[HI:.+]] = arith.trunci %[[HI64]] : i64 to i32
|
||||
// CHECK-DAG: %[[LOW:.+]] = arith.muli %arg0, %arg1
|
||||
|
||||
// CHECK-DAG: [[C0_32:%.+]] = arith.constant 0 : i32
|
||||
// CHECK-DAG: [[C30_32:%.+]] = arith.constant 30 : i32
|
||||
// CHECK-DAG: [[SECOND_BIAS:%.+]] = arith.shli [[C1_32]], [[C30_32]]
|
||||
// CHECK-DAG: [[SECOND_BIAS_64:%.+]] = arith.extsi [[SECOND_BIAS]] : i32 to i64
|
||||
// CHECK-DAG: [[POSITIVE_ROUND:%.+]] = arith.addi [[SHIFTED_64]], [[SECOND_BIAS_64]]
|
||||
// CHECK-DAG: [[NEGATIVE_ROUND:%.+]] = arith.subi [[SHIFTED_64]], [[SECOND_BIAS_64]]
|
||||
// CHECK-DAG: [[VALUE_NEGATIVE:%.+]] = arith.cmpi sge, %arg0, [[C0_32]] : i32
|
||||
// CHECK-DAG: [[DOUBLE_ROUNDED:%.+]] = arith.select [[VALUE_NEGATIVE]], [[POSITIVE_ROUND]], [[NEGATIVE_ROUND]] : i64
|
||||
// CHECK-DAG: [[C32_32:%.+]] = arith.constant 32 : i32
|
||||
// CHECK-DAG: [[IS_32BIT_SHIFT:%.+]] = arith.cmpi sge, [[SHIFT_32]], [[C32_32]]
|
||||
// CHECK-DAG: [[ROUND:%.+]] = arith.select [[IS_32BIT_SHIFT]], [[DOUBLE_ROUNDED]], [[SHIFTED_64]]
|
||||
// Determine whether the high bits need to shift left or right and by how much.
|
||||
// CHECK-DAG: %[[OVER31:.+]] = arith.cmpi sge, %[[S32]], %[[C32]]
|
||||
// CHECK-DAG: %[[OVER32:.+]] = arith.cmpi sgt, %[[S32]], %[[C32]]
|
||||
// CHECK-DAG: %[[HISHLN:.+]] = arith.subi %[[C32]], %[[S32]]
|
||||
// CHECK-DAG: %[[HISHRN:.+]] = arith.subi %[[S32]], %[[C32]]
|
||||
// CHECK-DAG: %[[HISHL:.+]] = arith.select %[[OVER31]], %[[C0]], %[[HISHLN]]
|
||||
// CHECK-DAG: %[[HISHR:.+]] = arith.select %[[OVER31]], %[[HISHRN]], %[[C0]]
|
||||
|
||||
// CHECK-DAG: [[VAL_64:%.+]] = arith.extsi %arg0 : i32 to i64
|
||||
// CHECK-DAG: [[MULTIPLY_64:%.+]] = arith.extsi %arg1 : i32 to i64
|
||||
// CHECK-DAG: [[SHIFT_64:%.+]] = arith.extsi %arg2 : i8 to i64
|
||||
// CHECK-DAG: [[SCALED:%.+]] = arith.muli [[VAL_64]], [[MULTIPLY_64]]
|
||||
// CHECK-DAG: [[BIASED:%.+]] = arith.addi [[SCALED]], [[ROUND]]
|
||||
// CHECK-DAG: [[DOWNSHIFTED:%.+]] = arith.shrsi [[BIASED]], [[SHIFT_64]]
|
||||
// CHECK: [[TRUNCATED:%.+]] = arith.trunci [[DOWNSHIFTED]]
|
||||
// Apply double rounding.
|
||||
// CHECK-DAG: %[[CN1:.+]] = arith.constant -1
|
||||
// CHECK-DAG: %[[POS:.+]] = arith.cmpi sge, %arg0, %[[C0]]
|
||||
// CHECK-DAG: %[[DIR:.+]] = arith.select %[[POS]], %[[C1]], %[[CN1]]
|
||||
// CHECK-DAG: %[[DRND:.+]] = arith.select %[[OVER31]], %[[DIR]], %[[C0]]
|
||||
// CHECK-DAG: %[[DSHFTR:.+]] = arith.shrui %[[LOW]], %[[C30]]
|
||||
// CHECK-DAG: %[[DRNDED:.+]] = arith.addi %[[DSHFTR]], %[[DRND]]
|
||||
// CHECK-DAG: %[[DCARRY:.+]] = arith.shrsi %[[DRNDED]], %[[C2:.+]]
|
||||
// CHECK-DAG: %[[DBIT:.+]] = arith.shli %[[DRND]], %[[C30]]
|
||||
// CHECK-DAG: %[[DLOW:.+]] = arith.addi %[[LOW]], %[[DBIT]]
|
||||
// CHECK-DAG: %[[DHI:.+]] = arith.addi %[[HI]], %[[DCARRY]]
|
||||
|
||||
// SCALE: "tosa.apply_scale"
|
||||
%0 = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (i32, i32, i8) -> i32
|
||||
return %0 : i32
|
||||
// Apply low-bit rounding.
|
||||
// CHECK-DAG: %[[SHFTM1:.+]] = arith.subi %[[S32]], %[[C1]]
|
||||
// CHECK-DAG: %[[LBIT:.+]] = arith.shli %[[C1]], %[[SHFTM1]]
|
||||
// CHECK-DAG: %[[HALF:.+]] = arith.select %[[OVER32]], %[[C0]], %[[LBIT]]
|
||||
// CHECK-DAG: %[[LADD:.+]] = arith.addi %[[DLOW]], %[[HALF]]
|
||||
// CHECK-DAG: %[[LLO:.+]] = arith.cmpi ugt, %[[DLOW]], %[[LADD]]
|
||||
// CHECK-DAG: %[[LCARRY:.+]] = arith.extui %[[LLO]] : i1 to i32
|
||||
// CHECK-DAG: %[[LRNDED:.+]] = arith.addi %[[DHI]], %[[LCARRY]]
|
||||
|
||||
// Apply high-bit rounding.
|
||||
// CHECK-DAG: %[[HISHRM1:.+]] = arith.subi %[[HISHR]], %[[C1]]
|
||||
// CHECK-DAG: %[[LHISHFT:.+]] = arith.shli %[[C1]], %[[HISHRM1]]
|
||||
// CHECK-DAG: %[[LHI:.+]] = arith.select %[[OVER32]], %[[LHISHFT]], %[[C0]]
|
||||
// CHECK-DAG: %[[FHI:.+]] = arith.addi %[[LRNDED]], %[[LHI]]
|
||||
|
||||
// Combine hi-low into the final result.
|
||||
// CHECK-DAG: %[[HIL:.+]] = arith.shli %[[FHI]], %[[HISHL]]
|
||||
// CHECK-DAG: %[[HIALIGN:.+]] = arith.shrsi %[[HIL:.+]], %[[HISHR]]
|
||||
// CHECK-DAG: %[[LOR:.+]] = arith.shrui %[[LADD]], %[[S32]]
|
||||
// CHECK-DAG: %[[LOWALIGN:.+]] = arith.select %[[OVER31]], %[[C0]], %[[LOR]]
|
||||
// CHECK-DAG: %[[RESULT:.+]] = arith.addi %[[LOWALIGN]], %[[HIALIGN]]
|
||||
// CHECK: return %[[RESULT]]
|
||||
%res = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (i32, i32, i8) -> i32
|
||||
return %res : i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @apply_scale_test_vector
|
||||
// SCALE: "tosa.apply_scale"
|
||||
func.func @apply_scale_test_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>, %arg2 : vector<4xi8>) -> (vector<4xi32>) {
|
||||
// CHECK-DAG: [[C1_8:%.+]] = arith.constant dense<1> : vector<4xi8>
|
||||
// CHECK-DAG: [[C1_32:%.+]] = arith.constant dense<1> : vector<4xi32>
|
||||
// CHECK-DAG: [[C1_64:%.+]] = arith.constant dense<1> : vector<4xi64>
|
||||
// CHECK-DAG: [[SHIFT_MINUS_ONE_8:%.+]] = arith.subi %arg2, [[C1_8]]
|
||||
|
||||
// CHECK-DAG: [[SHIFT_32:%.+]] = arith.extsi %arg2 : vector<4xi8> to vector<4xi32>
|
||||
// CHECK-DAG: [[SHIFT_MINUS_ONE_64:%.+]] = arith.extsi [[SHIFT_MINUS_ONE_8]] : vector<4xi8> to vector<4xi64>
|
||||
// CHECK-DAG: [[SHIFTED_64:%.+]] = arith.shli [[C1_64]], [[SHIFT_MINUS_ONE_64]]
|
||||
|
||||
// CHECK-DAG: [[C0_32:%.+]] = arith.constant dense<0> : vector<4xi32>
|
||||
// CHECK-DAG: [[C30_32:%.+]] = arith.constant dense<30> : vector<4xi32>
|
||||
// CHECK-DAG: [[SECOND_BIAS:%.+]] = arith.shli [[C1_32]], [[C30_32]]
|
||||
// CHECK-DAG: [[SECOND_BIAS_64:%.+]] = arith.extsi [[SECOND_BIAS]] : vector<4xi32> to vector<4xi64>
|
||||
// CHECK-DAG: [[POSITIVE_ROUND:%.+]] = arith.addi [[SHIFTED_64]], [[SECOND_BIAS_64]]
|
||||
// CHECK-DAG: [[NEGATIVE_ROUND:%.+]] = arith.subi [[SHIFTED_64]], [[SECOND_BIAS_64]]
|
||||
// CHECK-DAG: [[VALUE_NEGATIVE:%.+]] = arith.cmpi sge, %arg0, [[C0_32]] : vector<4xi32>
|
||||
// CHECK-DAG: [[DOUBLE_ROUNDED:%.+]] = arith.select [[VALUE_NEGATIVE]], [[POSITIVE_ROUND]], [[NEGATIVE_ROUND]] : vector<4xi1>, vector<4xi64>
|
||||
// CHECK-DAG: [[C32_32:%.+]] = arith.constant dense<32> : vector<4xi32>
|
||||
// CHECK-DAG: [[IS_32BIT_SHIFT:%.+]] = arith.cmpi sge, [[SHIFT_32]], [[C32_32]]
|
||||
// CHECK-DAG: [[ROUND:%.+]] = arith.select [[IS_32BIT_SHIFT]], [[DOUBLE_ROUNDED]], [[SHIFTED_64]]
|
||||
|
||||
// CHECK-DAG: [[VAL_64:%.+]] = arith.extsi %arg0 : vector<4xi32> to vector<4xi64>
|
||||
// CHECK-DAG: [[MULTIPLY_64:%.+]] = arith.extsi %arg1 : vector<4xi32> to vector<4xi64>
|
||||
// CHECK-DAG: [[SHIFT_64:%.+]] = arith.extsi %arg2 : vector<4xi8> to vector<4xi64>
|
||||
// CHECK-DAG: [[SCALED:%.+]] = arith.muli [[VAL_64]], [[MULTIPLY_64]]
|
||||
// CHECK-DAG: [[BIASED:%.+]] = arith.addi [[SCALED]], [[ROUND]]
|
||||
// CHECK-DAG: [[DOWNSHIFTED:%.+]] = arith.shrsi [[BIASED]], [[SHIFT_64]]
|
||||
// CHECK: [[TRUNCATED:%.+]] = arith.trunci [[DOWNSHIFTED]]
|
||||
|
||||
%0 = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (vector<4xi32>, vector<4xi32>, vector<4xi8>) -> vector<4xi32>
|
||||
return %0 : vector<4xi32>
|
||||
// CHECK-NOT: "tosa.apply_scale"
|
||||
%res = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (vector<4xi32>, vector<4xi32>, vector<4xi8>) -> vector<4xi32>
|
||||
return %res : vector<4xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @apply_scale_test_i48
|
||||
// SCALE: "tosa.apply_scale"
|
||||
func.func @apply_scale_test_i48(%arg0 : i48, %arg1 : i32, %arg2 : i8) -> (i32) {
|
||||
// CHECK-DAG: [[C1_8:%.+]] = arith.constant 1 : i8
|
||||
// CHECK-DAG: [[C1_32:%.+]] = arith.constant 1 : i32
|
||||
// CHECK-DAG: [[C1_64:%.+]] = arith.constant 1 : i64
|
||||
// CHECK-DAG: [[C30_32:%.+]] = arith.constant 30 : i32
|
||||
// CHECK-DAG: [[C0_32:%.+]] = arith.constant 0 : i48
|
||||
// CHECK-DAG: [[C32_32:%.+]] = arith.constant 32 : i32
|
||||
// CHECK-DAG: [[SHIFT_MINUS_ONE_8:%.+]] = arith.subi %arg2, [[C1_8]]
|
||||
// CHECK-DAG: [[SHIFT_32:%.+]] = arith.extsi %arg2 : i8 to i32
|
||||
// CHECK-DAG: [[SHIFT_MINUS_ONE_64:%.+]] = arith.extsi [[SHIFT_MINUS_ONE_8]] : i8 to i64
|
||||
// CHECK-DAG: [[SHIFTED_64:%.+]] = arith.shli [[C1_64]], [[SHIFT_MINUS_ONE_64]]
|
||||
// CHECK-DAG: [[SECOND_BIAS:%.+]] = arith.shli [[C1_32]], [[C30_32]]
|
||||
// CHECK-DAG: [[SECOND_BIAS_64:%.+]] = arith.extsi [[SECOND_BIAS]] : i32 to i64
|
||||
// CHECK-DAG: [[POSITIVE_ROUND:%.+]] = arith.addi [[SHIFTED_64]], [[SECOND_BIAS_64]]
|
||||
// CHECK-DAG: [[NEGATIVE_ROUND:%.+]] = arith.subi [[SHIFTED_64]], [[SECOND_BIAS_64]]
|
||||
// CHECK-DAG: [[VALUE_NEGATIVE:%.+]] = arith.cmpi sge, %arg0, [[C0_32]] : i48
|
||||
// CHECK-DAG: [[DOUBLE_ROUNDED:%.+]] = arith.select [[VALUE_NEGATIVE]], [[POSITIVE_ROUND]], [[NEGATIVE_ROUND]] : i64
|
||||
// CHECK-DAG: [[IS_32BIT_SHIFT:%.+]] = arith.cmpi sge, [[SHIFT_32]], [[C32_32]]
|
||||
// CHECK-DAG: [[ROUND:%.+]] = arith.select [[IS_32BIT_SHIFT]], [[DOUBLE_ROUNDED]], [[SHIFTED_64]]
|
||||
// CHECK-DAG: [[VAL_64:%.+]] = arith.extsi %arg0 : i48 to i64
|
||||
// CHECK-DAG: [[MULTIPLY_64:%.+]] = arith.extsi %arg1 : i32 to i64
|
||||
// CHECK-DAG: [[SHIFT_64:%.+]] = arith.extsi %arg2 : i8 to i64
|
||||
// CHECK-DAG: [[SCALED:%.+]] = arith.muli [[VAL_64]], [[MULTIPLY_64]]
|
||||
// CHECK-DAG: [[BIASED:%.+]] = arith.addi [[SCALED]], [[ROUND]]
|
||||
// CHECK-DAG: [[DOWNSHIFTED:%.+]] = arith.shrsi [[BIASED]], [[SHIFT_64]]
|
||||
// CHECK: [[TRUNCATED:%.+]] = arith.trunci [[DOWNSHIFTED]]
|
||||
%0 = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (i48, i32, i8) -> i32
|
||||
return %0 : i32
|
||||
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i48
|
||||
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i64
|
||||
// CHECK-DAG: %[[C31:.+]] = arith.constant 31 : i32
|
||||
|
||||
// Multiply in 64 bits.
|
||||
// CHECK-DAG: %[[V64:.+]] = arith.extsi %arg0 : i48 to i64
|
||||
// CHECK-DAG: %[[M64:.+]] = arith.extsi %arg1 : i32 to i64
|
||||
// CHECK-DAG: %[[MUL:.+]] = arith.muli %[[V64]], %[[M64]]
|
||||
|
||||
// Round normally.
|
||||
// CHECK-DAG: %[[S32:.+]] = arith.extui %arg2 : i8 to i32
|
||||
// CHECK-DAG: %[[S64:.+]] = arith.extui %[[S32]] : i32 to i64
|
||||
// CHECK-DAG: %[[ONEL:.+]] = arith.shli %[[C1]], %[[S64]] : i64
|
||||
// CHECK-DAG: %[[ONER:.+]] = arith.shrui %[[ONEL]], %[[C1]]
|
||||
// CHECK-DAG: %[[ROUND:.+]] = arith.addi %[[MUL]], %[[ONER]]
|
||||
|
||||
// Apply double rounding.
|
||||
// CHECK-DAG: %[[DUP:.+]] = arith.constant 1073741824 : i64
|
||||
// CHECK-DAG: %[[DDOWN:.+]] = arith.constant -1073741824 : i64
|
||||
// CHECK-DAG: %[[POS:.+]] = arith.cmpi sge, %arg0, %[[C0]]
|
||||
// CHECK-DAG: %[[DBIT:.+]] = arith.select %[[POS]], %[[DUP]], %[[DDOWN]]
|
||||
// CHECK-DAG: %[[DRND:.+]] = arith.addi %[[DBIT]], %[[ROUND]]
|
||||
// CHECK-DAG: %[[USED:.+]] = arith.cmpi sgt, %[[S32]], %[[C31]] : i32
|
||||
// CHECK-DAG: %[[RES64:.+]] = arith.select %[[USED]], %[[DRND]], %[[ROUND]] : i64
|
||||
|
||||
// Shift and truncate final answer.
|
||||
// CHECK-DAG: %[[SHR:.+]] = arith.shrsi %[[RES64]], %[[S64]]
|
||||
// CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i64 to i32
|
||||
// CHECK: return %[[TRUNC]]
|
||||
%res = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (i48, i32, i8) -> i32
|
||||
return %res : i32
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue