[mlir][spirv] Add lowering for std cmp ops.

Differential Revision: https://reviews.llvm.org/D72296
This commit is contained in:
Denis Khalikov 2020-01-07 21:47:49 -05:00 committed by Lei Zhang
parent 9883b14cd1
commit dd495e8a87
3 changed files with 123 additions and 4 deletions

View File

@ -39,6 +39,16 @@ public:
ConversionPatternRewriter &rewriter) const override;
};
/// Convert floating-point comparison operations to SPIR-V dialect.
class CmpFOpConversion final : public SPIRVOpLowering<CmpFOp> {
public:
using SPIRVOpLowering<CmpFOp>::SPIRVOpLowering;
PatternMatchResult
matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
/// Convert compare operation to SPIR-V dialect.
class CmpIOpConversion final : public SPIRVOpLowering<CmpIOp> {
public:
@ -195,6 +205,46 @@ PatternMatchResult ConstantIndexOpConversion::matchAndRewrite(
return matchSuccess();
}
//===----------------------------------------------------------------------===//
// CmpFOp
//===----------------------------------------------------------------------===//
PatternMatchResult
CmpFOpConversion::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
CmpFOpOperandAdaptor cmpFOpOperands(operands);
switch (cmpFOp.getPredicate()) {
#define DISPATCH(cmpPredicate, spirvOp) \
case cmpPredicate: \
rewriter.replaceOpWithNewOp<spirvOp>( \
cmpFOp, cmpFOp.getResult()->getType(), cmpFOpOperands.lhs(), \
cmpFOpOperands.rhs()); \
return matchSuccess();
// Ordered.
DISPATCH(CmpFPredicate::OEQ, spirv::FOrdEqualOp);
DISPATCH(CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
DISPATCH(CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
DISPATCH(CmpFPredicate::OLT, spirv::FOrdLessThanOp);
DISPATCH(CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
DISPATCH(CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
// Unordered.
DISPATCH(CmpFPredicate::UEQ, spirv::FUnordEqualOp);
DISPATCH(CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
DISPATCH(CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
DISPATCH(CmpFPredicate::ULT, spirv::FUnordLessThanOp);
DISPATCH(CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
DISPATCH(CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
#undef DISPATCH
default:
break;
}
return matchFailure();
}
//===----------------------------------------------------------------------===//
// CmpIOp
//===----------------------------------------------------------------------===//
@ -218,11 +268,12 @@ CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
DISPATCH(CmpIPredicate::sle, spirv::SLessThanEqualOp);
DISPATCH(CmpIPredicate::sgt, spirv::SGreaterThanOp);
DISPATCH(CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
DISPATCH(CmpIPredicate::ult, spirv::ULessThanOp);
DISPATCH(CmpIPredicate::ule, spirv::ULessThanEqualOp);
DISPATCH(CmpIPredicate::ugt, spirv::UGreaterThanOp);
DISPATCH(CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
#undef DISPATCH
default:
break;
}
return matchFailure();
}
@ -302,7 +353,7 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
OwningRewritePatternList &patterns) {
// Add patterns that lower operations into SPIR-V dialect.
populateWithGenerated(context, &patterns);
patterns.insert<ConstantIndexOpConversion, CmpIOpConversion,
patterns.insert<ConstantIndexOpConversion, CmpFOpConversion, CmpIOpConversion,
IntegerOpConversion<AddIOp, spirv::IAddOp>,
IntegerOpConversion<MulIOp, spirv::IMulOp>,
IntegerOpConversion<SignedDivIOp, spirv::SDivOp>,

View File

@ -142,6 +142,39 @@ func @shift_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) {
return
}
//===----------------------------------------------------------------------===//
// std.cmpf
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @cmpf
func @cmpf(%arg0 : f32, %arg1 : f32) {
// CHECK: spv.FOrdEqual
%1 = cmpf "oeq", %arg0, %arg1 : f32
// CHECK: spv.FOrdGreaterThan
%2 = cmpf "ogt", %arg0, %arg1 : f32
// CHECK: spv.FOrdGreaterThanEqual
%3 = cmpf "oge", %arg0, %arg1 : f32
// CHECK: spv.FOrdLessThan
%4 = cmpf "olt", %arg0, %arg1 : f32
// CHECK: spv.FOrdLessThanEqual
%5 = cmpf "ole", %arg0, %arg1 : f32
// CHECK: spv.FOrdNotEqual
%6 = cmpf "one", %arg0, %arg1 : f32
// CHECK: spv.FUnordEqual
%7 = cmpf "ueq", %arg0, %arg1 : f32
// CHECK: spv.FUnordGreaterThan
%8 = cmpf "ugt", %arg0, %arg1 : f32
// CHECK: spv.FUnordGreaterThanEqual
%9 = cmpf "uge", %arg0, %arg1 : f32
// CHECK: spv.FUnordLessThan
%10 = cmpf "ult", %arg0, %arg1 : f32
// CHECK: FUnordLessThanEqual
%11 = cmpf "ule", %arg0, %arg1 : f32
// CHECK: spv.FUnordNotEqual
%12 = cmpf "une", %arg0, %arg1 : f32
return
}
//===----------------------------------------------------------------------===//
// std.cmpi
//===----------------------------------------------------------------------===//
@ -160,6 +193,14 @@ func @cmpi(%arg0 : i32, %arg1 : i32) {
%4 = cmpi "sgt", %arg0, %arg1 : i32
// CHECK: spv.SGreaterThanEqual
%5 = cmpi "sge", %arg0, %arg1 : i32
// CHECK: spv.ULessThan
%6 = cmpi "ult", %arg0, %arg1 : i32
// CHECK: spv.ULessThanEqual
%7 = cmpi "ule", %arg0, %arg1 : i32
// CHECK: spv.UGreaterThan
%8 = cmpi "ugt", %arg0, %arg1 : i32
// CHECK: spv.UGreaterThanEqual
%9 = cmpi "uge", %arg0, %arg1 : i32
return
}

View File

@ -51,6 +51,33 @@ spv.module "Logical" "GLSL450" {
%0 = spv.ULessThanEqual %arg0, %arg1 : vector<4xi32>
spv.Return
}
func @cmpf(%arg0 : f32, %arg1 : f32) {
// CHECK: spv.FOrdEqual
%1 = spv.FOrdEqual %arg0, %arg1 : f32
// CHECK: spv.FOrdGreaterThan
%2 = spv.FOrdGreaterThan %arg0, %arg1 : f32
// CHECK: spv.FOrdGreaterThanEqual
%3 = spv.FOrdGreaterThanEqual %arg0, %arg1 : f32
// CHECK: spv.FOrdLessThan
%4 = spv.FOrdLessThan %arg0, %arg1 : f32
// CHECK: spv.FOrdLessThanEqual
%5 = spv.FOrdLessThanEqual %arg0, %arg1 : f32
// CHECK: spv.FOrdNotEqual
%6 = spv.FOrdNotEqual %arg0, %arg1 : f32
// CHECK: spv.FUnordEqual
%7 = spv.FUnordEqual %arg0, %arg1 : f32
// CHECK: spv.FUnordGreaterThan
%8 = spv.FUnordGreaterThan %arg0, %arg1 : f32
// CHECK: spv.FUnordGreaterThanEqual
%9 = spv.FUnordGreaterThanEqual %arg0, %arg1 : f32
// CHECK: spv.FUnordLessThan
%10 = spv.FUnordLessThan %arg0, %arg1 : f32
// CHECK: spv.FUnordLessThanEqual
%11 = spv.FUnordLessThanEqual %arg0, %arg1 : f32
// CHECK: spv.FUnordNotEqual
%12 = spv.FUnordNotEqual %arg0, %arg1 : f32
spv.Return
}
}
// -----