forked from OSchip/llvm-project
[mlir][tosa] Fix tosa.cast semantics to perform rounding/clipping
Casting to integers requires rounding (for floating-point sources) and clipping to the min/max values of the destination range. This commit adds that behavior and updates the tests accordingly. Reviewed By: sjarus, silvas. Differential Revision: https://reviews.llvm.org/D102375
This commit is contained in:
parent
f6907152db
commit
3f8aafd790
|
@ -491,9 +491,34 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
|
|||
args.front(), zero);
|
||||
}
|
||||
|
||||
if (mlir::FPToSIOp::areCastCompatible(srcTy, dstTy))
|
||||
return rewriter.create<mlir::FPToSIOp>(loc, resultTypes, args,
|
||||
mlir::None);
|
||||
if (mlir::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
|
||||
auto zero =
|
||||
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
||||
auto half =
|
||||
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.5f));
|
||||
|
||||
auto intMin = rewriter.create<ConstantOp>(
|
||||
loc, rewriter.getF32FloatAttr(
|
||||
APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
|
||||
.getSExtValue()));
|
||||
|
||||
auto intMax = rewriter.create<ConstantOp>(
|
||||
loc, rewriter.getF32FloatAttr(
|
||||
APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
|
||||
.getSExtValue()));
|
||||
|
||||
auto added = rewriter.create<AddFOp>(loc, args[0], half);
|
||||
auto subbed = rewriter.create<SubFOp>(loc, args[0], half);
|
||||
auto negative =
|
||||
rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT, args[0], zero);
|
||||
auto rounded =
|
||||
rewriter.create<mlir::SelectOp>(loc, negative, subbed, added);
|
||||
|
||||
auto clamped = clampHelper<mlir::CmpFOp>(loc, rounded, intMin, intMax,
|
||||
CmpFPredicate::OLT, rewriter);
|
||||
|
||||
return rewriter.create<mlir::FPToSIOp>(loc, dstTy, clamped);
|
||||
}
|
||||
|
||||
// Casting to boolean, integers need to only be checked as not-equal to
|
||||
// zero.
|
||||
|
@ -508,9 +533,23 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
|
|||
return rewriter.create<mlir::SignExtendIOp>(loc, resultTypes, args,
|
||||
mlir::None);
|
||||
|
||||
if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !bitExtend)
|
||||
return rewriter.create<mlir::TruncateIOp>(loc, resultTypes, args,
|
||||
mlir::None);
|
||||
if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !bitExtend) {
|
||||
auto intMin = rewriter.create<ConstantIntOp>(
|
||||
loc,
|
||||
APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
|
||||
.getSExtValue(),
|
||||
srcTy.getIntOrFloatBitWidth());
|
||||
|
||||
auto intMax = rewriter.create<ConstantIntOp>(
|
||||
loc,
|
||||
APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
|
||||
.getSExtValue(),
|
||||
srcTy.getIntOrFloatBitWidth());
|
||||
|
||||
auto clamped = clampHelper<mlir::CmpIOp>(loc, args[0], intMin, intMax,
|
||||
CmpIPredicate::slt, rewriter);
|
||||
return rewriter.create<mlir::TruncateIOp>(loc, dstTy, clamped);
|
||||
}
|
||||
}
|
||||
|
||||
(void)rewriter.notifyMatchFailure(
|
||||
|
|
|
@ -213,6 +213,18 @@ func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
|
|||
%20 = "tosa.sigmoid"(%0) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: constant 0.000000e+00
|
||||
// CHECK: constant 5.000000e-01
|
||||
// CHECK: constant -2.14748365E+9
|
||||
// CHECK: constant 2.14748365E+9
|
||||
// CHECK: addf
|
||||
// CHECK: subf
|
||||
// CHECK: cmpf olt
|
||||
// CHECK: select
|
||||
// CHECK: cmpf olt
|
||||
// CHECK: select
|
||||
// CHECK: cmpf olt
|
||||
// CHECK: select
|
||||
// CHECK: fptosi
|
||||
%21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32>
|
||||
|
||||
|
@ -358,6 +370,12 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
|
|||
%18 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
|
||||
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: constant -32768
|
||||
// CHECK: constant 32767
|
||||
// CHECK: cmpi slt
|
||||
// CHECK: select
|
||||
// CHECK: cmpi slt
|
||||
// CHECK: select
|
||||
// CHECK: trunci
|
||||
%19 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
|
||||
|
||||
|
|
Loading…
Reference in New Issue