forked from OSchip/llvm-project
[fir] Add fir.cmpc conversion
This patch adds the codegen for fir.cmpc. The real and imaginary parts are extracted and compared separately. For the .EQ. predicate the results are AND'd, for the .NE. predicate the results are OR'd, and for other predicates we keep only the result on the real parts. This patch is part of the upstreaming effort from fir-dev. Differential Revision: https://reviews.llvm.org/D113976 Co-authored-by: Eric Schweitz <eschweitz@nvidia.com> Co-authored-by: Jean Perier <jperier@nvidia.com>
This commit is contained in:
parent
738e7f1231
commit
f1dfc0275c
|
@ -487,6 +487,52 @@ static mlir::Type getComplexEleTy(mlir::Type complex) {
|
|||
return complex.cast<fir::ComplexType>().getElementType();
|
||||
}
|
||||
|
||||
/// Compare complex values
|
||||
///
|
||||
/// Per 10.1, the only comparisons available are .EQ. (oeq) and .NE. (une).
|
||||
///
|
||||
/// For completeness, all other comparison are done on the real component only.
|
||||
struct CmpcOpConversion : public FIROpConversion<fir::CmpcOp> {
|
||||
using FIROpConversion::FIROpConversion;
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(fir::CmpcOp cmp, OpAdaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::ValueRange operands = adaptor.getOperands();
|
||||
mlir::MLIRContext *ctxt = cmp.getContext();
|
||||
mlir::Type eleTy = convertType(getComplexEleTy(cmp.lhs().getType()));
|
||||
mlir::Type resTy = convertType(cmp.getType());
|
||||
mlir::Location loc = cmp.getLoc();
|
||||
auto pos0 = mlir::ArrayAttr::get(ctxt, rewriter.getI32IntegerAttr(0));
|
||||
SmallVector<mlir::Value, 2> rp{rewriter.create<mlir::LLVM::ExtractValueOp>(
|
||||
loc, eleTy, operands[0], pos0),
|
||||
rewriter.create<mlir::LLVM::ExtractValueOp>(
|
||||
loc, eleTy, operands[1], pos0)};
|
||||
auto rcp =
|
||||
rewriter.create<mlir::LLVM::FCmpOp>(loc, resTy, rp, cmp->getAttrs());
|
||||
auto pos1 = mlir::ArrayAttr::get(ctxt, rewriter.getI32IntegerAttr(1));
|
||||
SmallVector<mlir::Value, 2> ip{rewriter.create<mlir::LLVM::ExtractValueOp>(
|
||||
loc, eleTy, operands[0], pos1),
|
||||
rewriter.create<mlir::LLVM::ExtractValueOp>(
|
||||
loc, eleTy, operands[1], pos1)};
|
||||
auto icp =
|
||||
rewriter.create<mlir::LLVM::FCmpOp>(loc, resTy, ip, cmp->getAttrs());
|
||||
SmallVector<mlir::Value, 2> cp{rcp, icp};
|
||||
switch (cmp.getPredicate()) {
|
||||
case mlir::arith::CmpFPredicate::OEQ: // .EQ.
|
||||
rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(cmp, resTy, cp);
|
||||
break;
|
||||
case mlir::arith::CmpFPredicate::UNE: // .NE.
|
||||
rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(cmp, resTy, cp);
|
||||
break;
|
||||
default:
|
||||
rewriter.replaceOp(cmp, rcp.getResult());
|
||||
break;
|
||||
}
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// convert value of from-type to value of to-type
|
||||
struct ConvertOpConversion : public FIROpConversion<fir::ConvertOp> {
|
||||
using FIROpConversion::FIROpConversion;
|
||||
|
@ -1514,15 +1560,17 @@ public:
|
|||
AllocaOpConversion, BoxAddrOpConversion, BoxDimsOpConversion,
|
||||
BoxEleSizeOpConversion, BoxIsAllocOpConversion, BoxIsArrayOpConversion,
|
||||
BoxIsPtrOpConversion, BoxRankOpConversion, CallOpConversion,
|
||||
ConvertOpConversion, DispatchOpConversion, DispatchTableOpConversion,
|
||||
DTEntryOpConversion, DivcOpConversion, EmboxCharOpConversion,
|
||||
ExtractValueOpConversion, HasValueOpConversion, GlobalLenOpConversion,
|
||||
GlobalOpConversion, InsertOnRangeOpConversion, InsertValueOpConversion,
|
||||
IsPresentOpConversion, LoadOpConversion, NegcOpConversion,
|
||||
MulcOpConversion, SelectCaseOpConversion, SelectOpConversion,
|
||||
SelectRankOpConversion, SelectTypeOpConversion, StoreOpConversion,
|
||||
SubcOpConversion, UnboxCharOpConversion, UndefOpConversion,
|
||||
UnreachableOpConversion, ZeroOpConversion>(typeConverter);
|
||||
CmpcOpConversion, ConvertOpConversion, DispatchOpConversion,
|
||||
DispatchTableOpConversion, DTEntryOpConversion, DivcOpConversion,
|
||||
EmboxCharOpConversion, ExtractValueOpConversion, HasValueOpConversion,
|
||||
GlobalLenOpConversion, GlobalOpConversion, InsertOnRangeOpConversion,
|
||||
InsertValueOpConversion, IsPresentOpConversion, LoadOpConversion,
|
||||
NegcOpConversion, MulcOpConversion, SelectCaseOpConversion,
|
||||
SelectOpConversion, SelectRankOpConversion, SelectTypeOpConversion,
|
||||
StoreOpConversion, SubcOpConversion, UnboxCharOpConversion,
|
||||
UndefOpConversion, UnreachableOpConversion, ZeroOpConversion>(
|
||||
typeConverter);
|
||||
|
||||
mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
|
||||
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
||||
pattern);
|
||||
|
|
|
@ -521,6 +521,57 @@ func @fir_complex_neg(%a: !fir.complex<16>) -> !fir.complex<16> {
|
|||
|
||||
// -----
|
||||
|
||||
// Test FIR complex compare conversion
|
||||
|
||||
func @compare_complex_eq(%a : !fir.complex<8>, %b : !fir.complex<8>) -> i1 {
|
||||
%r = fir.cmpc "oeq", %a, %b : !fir.complex<8>
|
||||
return %r : i1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: llvm.func @compare_complex_eq
|
||||
// CHECK-SAME: [[A:%.*]]: !llvm.struct<(f64, f64)>,
|
||||
// CHECK-SAME: [[B:%.*]]: !llvm.struct<(f64, f64)>
|
||||
// CHECK-DAG: [[RA:%.*]] = llvm.extractvalue [[A]][0 : i32] : !llvm.struct<(f64, f64)>
|
||||
// CHECK-DAG: [[IA:%.*]] = llvm.extractvalue [[A]][1 : i32] : !llvm.struct<(f64, f64)>
|
||||
// CHECK-DAG: [[RB:%.*]] = llvm.extractvalue [[B]][0 : i32] : !llvm.struct<(f64, f64)>
|
||||
// CHECK-DAG: [[IB:%.*]] = llvm.extractvalue [[B]][1 : i32] : !llvm.struct<(f64, f64)>
|
||||
// CHECK-DAG: [[RESR:%.*]] = llvm.fcmp "oeq" [[RA]], [[RB]] : f64
|
||||
// CHECK-DAG: [[RESI:%.*]] = llvm.fcmp "oeq" [[IA]], [[IB]] : f64
|
||||
// CHECK: [[RES:%.*]] = llvm.and [[RESR]], [[RESI]] : i1
|
||||
// CHECK: return [[RES]] : i1
|
||||
|
||||
func @compare_complex_ne(%a : !fir.complex<8>, %b : !fir.complex<8>) -> i1 {
|
||||
%r = fir.cmpc "une", %a, %b : !fir.complex<8>
|
||||
return %r : i1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: llvm.func @compare_complex_ne
|
||||
// CHECK-SAME: [[A:%.*]]: !llvm.struct<(f64, f64)>,
|
||||
// CHECK-SAME: [[B:%.*]]: !llvm.struct<(f64, f64)>
|
||||
// CHECK-DAG: [[RA:%.*]] = llvm.extractvalue [[A]][0 : i32] : !llvm.struct<(f64, f64)>
|
||||
// CHECK-DAG: [[IA:%.*]] = llvm.extractvalue [[A]][1 : i32] : !llvm.struct<(f64, f64)>
|
||||
// CHECK-DAG: [[RB:%.*]] = llvm.extractvalue [[B]][0 : i32] : !llvm.struct<(f64, f64)>
|
||||
// CHECK-DAG: [[IB:%.*]] = llvm.extractvalue [[B]][1 : i32] : !llvm.struct<(f64, f64)>
|
||||
// CHECK-DAG: [[RESR:%.*]] = llvm.fcmp "une" [[RA]], [[RB]] : f64
|
||||
// CHECK-DAG: [[RESI:%.*]] = llvm.fcmp "une" [[IA]], [[IB]] : f64
|
||||
// CHECK: [[RES:%.*]] = llvm.or [[RESR]], [[RESI]] : i1
|
||||
// CHECK: return [[RES]] : i1
|
||||
|
||||
func @compare_complex_other(%a : !fir.complex<8>, %b : !fir.complex<8>) -> i1 {
|
||||
%r = fir.cmpc "ogt", %a, %b : !fir.complex<8>
|
||||
return %r : i1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: llvm.func @compare_complex_other
|
||||
// CHECK-SAME: [[A:%.*]]: !llvm.struct<(f64, f64)>,
|
||||
// CHECK-SAME: [[B:%.*]]: !llvm.struct<(f64, f64)>
|
||||
// CHECK-DAG: [[RA:%.*]] = llvm.extractvalue [[A]][0 : i32] : !llvm.struct<(f64, f64)>
|
||||
// CHECK-DAG: [[RB:%.*]] = llvm.extractvalue [[B]][0 : i32] : !llvm.struct<(f64, f64)>
|
||||
// CHECK: [[RESR:%.*]] = llvm.fcmp "ogt" [[RA]], [[RB]] : f64
|
||||
// CHECK: return [[RESR]] : i1
|
||||
|
||||
// -----
|
||||
|
||||
// Test `fir.convert` operation conversion from Float type.
|
||||
|
||||
func @convert_from_float(%arg0 : f32) {
|
||||
|
|
Loading…
Reference in New Issue