forked from OSchip/llvm-project
[mlir][spirv] Add lowering for std cmp ops.
Differential Revision: https://reviews.llvm.org/D72296
This commit is contained in:
parent
9883b14cd1
commit
dd495e8a87
|
@ -39,6 +39,16 @@ public:
|
|||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
/// Convert floating-point comparison operations to SPIR-V dialect.
|
||||
class CmpFOpConversion final : public SPIRVOpLowering<CmpFOp> {
|
||||
public:
|
||||
using SPIRVOpLowering<CmpFOp>::SPIRVOpLowering;
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
/// Convert compare operation to SPIR-V dialect.
|
||||
class CmpIOpConversion final : public SPIRVOpLowering<CmpIOp> {
|
||||
public:
|
||||
|
@ -195,6 +205,46 @@ PatternMatchResult ConstantIndexOpConversion::matchAndRewrite(
|
|||
return matchSuccess();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CmpFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PatternMatchResult
|
||||
CmpFOpConversion::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
CmpFOpOperandAdaptor cmpFOpOperands(operands);
|
||||
|
||||
switch (cmpFOp.getPredicate()) {
|
||||
#define DISPATCH(cmpPredicate, spirvOp) \
|
||||
case cmpPredicate: \
|
||||
rewriter.replaceOpWithNewOp<spirvOp>( \
|
||||
cmpFOp, cmpFOp.getResult()->getType(), cmpFOpOperands.lhs(), \
|
||||
cmpFOpOperands.rhs()); \
|
||||
return matchSuccess();
|
||||
|
||||
// Ordered.
|
||||
DISPATCH(CmpFPredicate::OEQ, spirv::FOrdEqualOp);
|
||||
DISPATCH(CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
|
||||
DISPATCH(CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
|
||||
DISPATCH(CmpFPredicate::OLT, spirv::FOrdLessThanOp);
|
||||
DISPATCH(CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
|
||||
DISPATCH(CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
|
||||
// Unordered.
|
||||
DISPATCH(CmpFPredicate::UEQ, spirv::FUnordEqualOp);
|
||||
DISPATCH(CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
|
||||
DISPATCH(CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
|
||||
DISPATCH(CmpFPredicate::ULT, spirv::FUnordLessThanOp);
|
||||
DISPATCH(CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
|
||||
DISPATCH(CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
|
||||
|
||||
#undef DISPATCH
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CmpIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -218,11 +268,12 @@ CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
|
|||
DISPATCH(CmpIPredicate::sle, spirv::SLessThanEqualOp);
|
||||
DISPATCH(CmpIPredicate::sgt, spirv::SGreaterThanOp);
|
||||
DISPATCH(CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
|
||||
DISPATCH(CmpIPredicate::ult, spirv::ULessThanOp);
|
||||
DISPATCH(CmpIPredicate::ule, spirv::ULessThanEqualOp);
|
||||
DISPATCH(CmpIPredicate::ugt, spirv::UGreaterThanOp);
|
||||
DISPATCH(CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
|
||||
|
||||
#undef DISPATCH
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return matchFailure();
|
||||
}
|
||||
|
@ -302,7 +353,7 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
|
|||
OwningRewritePatternList &patterns) {
|
||||
// Add patterns that lower operations into SPIR-V dialect.
|
||||
populateWithGenerated(context, &patterns);
|
||||
patterns.insert<ConstantIndexOpConversion, CmpIOpConversion,
|
||||
patterns.insert<ConstantIndexOpConversion, CmpFOpConversion, CmpIOpConversion,
|
||||
IntegerOpConversion<AddIOp, spirv::IAddOp>,
|
||||
IntegerOpConversion<MulIOp, spirv::IMulOp>,
|
||||
IntegerOpConversion<SignedDivIOp, spirv::SDivOp>,
|
||||
|
|
|
@ -142,6 +142,39 @@ func @shift_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) {
|
|||
return
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// std.cmpf
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: @cmpf
|
||||
func @cmpf(%arg0 : f32, %arg1 : f32) {
|
||||
// CHECK: spv.FOrdEqual
|
||||
%1 = cmpf "oeq", %arg0, %arg1 : f32
|
||||
// CHECK: spv.FOrdGreaterThan
|
||||
%2 = cmpf "ogt", %arg0, %arg1 : f32
|
||||
// CHECK: spv.FOrdGreaterThanEqual
|
||||
%3 = cmpf "oge", %arg0, %arg1 : f32
|
||||
// CHECK: spv.FOrdLessThan
|
||||
%4 = cmpf "olt", %arg0, %arg1 : f32
|
||||
// CHECK: spv.FOrdLessThanEqual
|
||||
%5 = cmpf "ole", %arg0, %arg1 : f32
|
||||
// CHECK: spv.FOrdNotEqual
|
||||
%6 = cmpf "one", %arg0, %arg1 : f32
|
||||
// CHECK: spv.FUnordEqual
|
||||
%7 = cmpf "ueq", %arg0, %arg1 : f32
|
||||
// CHECK: spv.FUnordGreaterThan
|
||||
%8 = cmpf "ugt", %arg0, %arg1 : f32
|
||||
// CHECK: spv.FUnordGreaterThanEqual
|
||||
%9 = cmpf "uge", %arg0, %arg1 : f32
|
||||
// CHECK: spv.FUnordLessThan
|
||||
%10 = cmpf "ult", %arg0, %arg1 : f32
|
||||
// CHECK: FUnordLessThanEqual
|
||||
%11 = cmpf "ule", %arg0, %arg1 : f32
|
||||
// CHECK: spv.FUnordNotEqual
|
||||
%12 = cmpf "une", %arg0, %arg1 : f32
|
||||
return
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// std.cmpi
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -160,6 +193,14 @@ func @cmpi(%arg0 : i32, %arg1 : i32) {
|
|||
%4 = cmpi "sgt", %arg0, %arg1 : i32
|
||||
// CHECK: spv.SGreaterThanEqual
|
||||
%5 = cmpi "sge", %arg0, %arg1 : i32
|
||||
// CHECK: spv.ULessThan
|
||||
%6 = cmpi "ult", %arg0, %arg1 : i32
|
||||
// CHECK: spv.ULessThanEqual
|
||||
%7 = cmpi "ule", %arg0, %arg1 : i32
|
||||
// CHECK: spv.UGreaterThan
|
||||
%8 = cmpi "ugt", %arg0, %arg1 : i32
|
||||
// CHECK: spv.UGreaterThanEqual
|
||||
%9 = cmpi "uge", %arg0, %arg1 : i32
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -51,6 +51,33 @@ spv.module "Logical" "GLSL450" {
|
|||
%0 = spv.ULessThanEqual %arg0, %arg1 : vector<4xi32>
|
||||
spv.Return
|
||||
}
|
||||
func @cmpf(%arg0 : f32, %arg1 : f32) {
|
||||
// CHECK: spv.FOrdEqual
|
||||
%1 = spv.FOrdEqual %arg0, %arg1 : f32
|
||||
// CHECK: spv.FOrdGreaterThan
|
||||
%2 = spv.FOrdGreaterThan %arg0, %arg1 : f32
|
||||
// CHECK: spv.FOrdGreaterThanEqual
|
||||
%3 = spv.FOrdGreaterThanEqual %arg0, %arg1 : f32
|
||||
// CHECK: spv.FOrdLessThan
|
||||
%4 = spv.FOrdLessThan %arg0, %arg1 : f32
|
||||
// CHECK: spv.FOrdLessThanEqual
|
||||
%5 = spv.FOrdLessThanEqual %arg0, %arg1 : f32
|
||||
// CHECK: spv.FOrdNotEqual
|
||||
%6 = spv.FOrdNotEqual %arg0, %arg1 : f32
|
||||
// CHECK: spv.FUnordEqual
|
||||
%7 = spv.FUnordEqual %arg0, %arg1 : f32
|
||||
// CHECK: spv.FUnordGreaterThan
|
||||
%8 = spv.FUnordGreaterThan %arg0, %arg1 : f32
|
||||
// CHECK: spv.FUnordGreaterThanEqual
|
||||
%9 = spv.FUnordGreaterThanEqual %arg0, %arg1 : f32
|
||||
// CHECK: spv.FUnordLessThan
|
||||
%10 = spv.FUnordLessThan %arg0, %arg1 : f32
|
||||
// CHECK: spv.FUnordLessThanEqual
|
||||
%11 = spv.FUnordLessThanEqual %arg0, %arg1 : f32
|
||||
// CHECK: spv.FUnordNotEqual
|
||||
%12 = spv.FUnordNotEqual %arg0, %arg1 : f32
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
Loading…
Reference in New Issue