forked from OSchip/llvm-project
Remove NaN constant from arith.minf, arith.maxf expansion
If any of the operands is NaN, return the operand instead of a new constant. When the rhs operand is a constant, the second arith.cmpf+select ops will be folded away. https://reviews.llvm.org/D117010 marks the two ops commutative, which will place the constant on the rhs. Reviewed By: herhut Differential Revision: https://reviews.llvm.org/D117011
This commit is contained in:
parent
bf9c8636f2
commit
be1aeb818c
|
@ -156,19 +156,16 @@ public:
|
|||
Value rhs = op.getRhs();
|
||||
|
||||
Location loc = op.getLoc();
|
||||
// If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
|
||||
static_assert(pred == arith::CmpFPredicate::UGT ||
|
||||
pred == arith::CmpFPredicate::ULT);
|
||||
Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
|
||||
Value select = rewriter.create<SelectOp>(loc, cmp, lhs, rhs);
|
||||
|
||||
auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
|
||||
// Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'.
|
||||
Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
|
||||
lhs, rhs);
|
||||
|
||||
Value nan = rewriter.create<arith::ConstantFloatOp>(
|
||||
loc, APFloat::getQNaN(floatType.getFloatSemantics()), floatType);
|
||||
if (VectorType vectorType = lhs.getType().dyn_cast<VectorType>())
|
||||
nan = rewriter.create<SplatOp>(loc, vectorType, nan);
|
||||
|
||||
rewriter.replaceOpWithNewOp<SelectOp>(op, isNaN, nan, select);
|
||||
rhs, rhs);
|
||||
rewriter.replaceOpWithNewOp<SelectOp>(op, isNaN, rhs, select);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -226,8 +223,8 @@ void mlir::arith::populateArithmeticExpandOpsPatterns(
|
|||
CeilDivSIOpConverter,
|
||||
CeilDivUIOpConverter,
|
||||
FloorDivSIOpConverter,
|
||||
MaxMinFOpConverter<MaxFOp, arith::CmpFPredicate::OGT>,
|
||||
MaxMinFOpConverter<MinFOp, arith::CmpFPredicate::OLT>,
|
||||
MaxMinFOpConverter<MaxFOp, arith::CmpFPredicate::UGT>,
|
||||
MaxMinFOpConverter<MinFOp, arith::CmpFPredicate::ULT>,
|
||||
MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
|
||||
MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
|
||||
MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
|
||||
|
|
|
@ -154,11 +154,10 @@ func @maxf(%a: f32, %b: f32) -> f32 {
|
|||
return %result : f32
|
||||
}
|
||||
// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
|
||||
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ogt, %[[LHS]], %[[RHS]] : f32
|
||||
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ugt, %[[LHS]], %[[RHS]] : f32
|
||||
// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] : f32
|
||||
// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[RHS]] : f32
|
||||
// CHECK-NEXT: %[[NAN:.*]] = arith.constant 0x7FC00000 : f32
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[NAN]], %[[SELECT]] : f32
|
||||
// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : f32
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32
|
||||
// CHECK-NEXT: return %[[RESULT]] : f32
|
||||
|
||||
// -----
|
||||
|
@ -169,12 +168,10 @@ func @maxf_vector(%a: vector<4xf16>, %b: vector<4xf16>) -> vector<4xf16> {
|
|||
return %result : vector<4xf16>
|
||||
}
|
||||
// CHECK-SAME: %[[LHS:.*]]: vector<4xf16>, %[[RHS:.*]]: vector<4xf16>)
|
||||
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ogt, %[[LHS]], %[[RHS]] : vector<4xf16>
|
||||
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ugt, %[[LHS]], %[[RHS]] : vector<4xf16>
|
||||
// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]]
|
||||
// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[RHS]] : vector<4xf16>
|
||||
// CHECK-NEXT: %[[NAN:.*]] = arith.constant 0x7E00 : f16
|
||||
// CHECK-NEXT: %[[SPLAT_NAN:.*]] = splat %[[NAN]] : vector<4xf16>
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[SPLAT_NAN]], %[[SELECT]]
|
||||
// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : vector<4xf16>
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[RHS]], %[[SELECT]]
|
||||
// CHECK-NEXT: return %[[RESULT]] : vector<4xf16>
|
||||
|
||||
// -----
|
||||
|
@ -185,11 +182,10 @@ func @minf(%a: f32, %b: f32) -> f32 {
|
|||
return %result : f32
|
||||
}
|
||||
// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
|
||||
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf olt, %[[LHS]], %[[RHS]] : f32
|
||||
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ult, %[[LHS]], %[[RHS]] : f32
|
||||
// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] : f32
|
||||
// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[RHS]] : f32
|
||||
// CHECK-NEXT: %[[NAN:.*]] = arith.constant 0x7FC00000 : f32
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[NAN]], %[[SELECT]] : f32
|
||||
// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : f32
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32
|
||||
// CHECK-NEXT: return %[[RESULT]] : f32
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue