forked from OSchip/llvm-project
[mlir][StandardToSPIRV] Handle conversion of cmpi operation with i1
type operands. The instructions used to convert std.cmpi cannot have i1 types according to SPIR-V specification. A different set of operations are specified in the SPIR-V spec for comparing boolean types. Enhance the StandardToSPIRV lowering to target these instructions when operands to std.cmpi operation are of i1 type. Differential Revision: https://reviews.llvm.org/D79049
This commit is contained in:
parent
dcdb1b94e1
commit
1c12a95d9c
|
@ -184,6 +184,16 @@ public:
|
|||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
/// Converts integer compare operation on i1 type opearnds to SPIR-V ops.
|
||||
class BoolCmpIOpPattern final : public SPIRVOpLowering<CmpIOp> {
|
||||
public:
|
||||
using SPIRVOpLowering<CmpIOp>::SPIRVOpLowering;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
/// Converts integer compare operation to SPIR-V ops.
|
||||
class CmpIOpPattern final : public SPIRVOpLowering<CmpIOp> {
|
||||
public:
|
||||
|
@ -453,11 +463,43 @@ CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
|
|||
// CmpIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult
|
||||
BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
CmpIOpOperandAdaptor cmpIOpOperands(operands);
|
||||
|
||||
Type operandType = cmpIOp.lhs().getType();
|
||||
if (!operandType.isa<IntegerType>() ||
|
||||
operandType.cast<IntegerType>().getWidth() != 1)
|
||||
return failure();
|
||||
|
||||
switch (cmpIOp.getPredicate()) {
|
||||
#define DISPATCH(cmpPredicate, spirvOp) \
|
||||
case cmpPredicate: \
|
||||
rewriter.replaceOpWithNewOp<spirvOp>(cmpIOp, cmpIOp.getResult().getType(), \
|
||||
cmpIOpOperands.lhs(), \
|
||||
cmpIOpOperands.rhs()); \
|
||||
return success();
|
||||
|
||||
DISPATCH(CmpIPredicate::eq, spirv::LogicalEqualOp);
|
||||
DISPATCH(CmpIPredicate::ne, spirv::LogicalNotEqualOp);
|
||||
|
||||
#undef DISPATCH
|
||||
default:;
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
CmpIOpOperandAdaptor cmpIOpOperands(operands);
|
||||
|
||||
Type operandType = cmpIOp.lhs().getType();
|
||||
if (operandType.isa<IntegerType>() &&
|
||||
operandType.cast<IntegerType>().getWidth() == 1)
|
||||
return failure();
|
||||
|
||||
switch (cmpIOp.getPredicate()) {
|
||||
#define DISPATCH(cmpPredicate, spirvOp) \
|
||||
case cmpPredicate: \
|
||||
|
@ -599,9 +641,10 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
|
|||
UnaryAndBinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>,
|
||||
BitwiseOpPattern<AndOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
|
||||
BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
|
||||
ConstantCompositeOpPattern, ConstantScalarOpPattern, CmpFOpPattern,
|
||||
CmpIOpPattern, LoadOpPattern, ReturnOpPattern, SelectOpPattern,
|
||||
StoreOpPattern, TypeCastingOpPattern<SIToFPOp, spirv::ConvertSToFOp>,
|
||||
BoolCmpIOpPattern, ConstantCompositeOpPattern, ConstantScalarOpPattern,
|
||||
CmpFOpPattern, CmpIOpPattern, LoadOpPattern, ReturnOpPattern,
|
||||
SelectOpPattern, StoreOpPattern,
|
||||
TypeCastingOpPattern<SIToFPOp, spirv::ConvertSToFOp>,
|
||||
TypeCastingOpPattern<FPExtOp, spirv::FConvertOp>,
|
||||
TypeCastingOpPattern<FPTruncOp, spirv::FConvertOp>, XOrOpPattern>(
|
||||
context, typeConverter);
|
||||
|
|
|
@ -285,6 +285,15 @@ func @cmpi(%arg0 : i32, %arg1 : i32) {
|
|||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @boolcmpi
|
||||
func @boolcmpi(%arg0 : i1, %arg1 : i1) {
|
||||
// CHECK: spv.LogicalEqual
|
||||
%0 = cmpi "eq", %arg0, %arg1 : i1
|
||||
// CHECK: spv.LogicalNotEqual
|
||||
%1 = cmpi "ne", %arg0, %arg1 : i1
|
||||
return
|
||||
}
|
||||
|
||||
} // end module
|
||||
|
||||
// -----
|
||||
|
|
Loading…
Reference in New Issue