[fir] Add fir.cmpc conversion

This patch adds the codegen for fir.cmpc. The real and imaginary parts
are extracted and compared separately. For the .EQ. predicate the
results are AND'd, for the .NE. predicate the results are OR'd, and for
other predicates we keep only the result on the real parts.

This patch is part of the upstreaming effort from fir-dev.

Differential Revision: https://reviews.llvm.org/D113976

Co-authored-by: Eric Schweitz <eschweitz@nvidia.com>
Co-authored-by: Jean Perier <jperier@nvidia.com>
This commit is contained in:
Diana Picus 2021-11-16 10:17:29 +00:00
parent 738e7f1231
commit f1dfc0275c
2 changed files with 108 additions and 9 deletions

View File

@ -487,6 +487,52 @@ static mlir::Type getComplexEleTy(mlir::Type complex) {
return complex.cast<fir::ComplexType>().getElementType();
}
/// Compare complex values
///
/// Per 10.1, the only comparisons available are .EQ. (oeq) and .NE. (une).
///
/// For completeness, all other comparison are done on the real component only.
struct CmpcOpConversion : public FIROpConversion<fir::CmpcOp> {
using FIROpConversion::FIROpConversion;
mlir::LogicalResult
matchAndRewrite(fir::CmpcOp cmp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::ValueRange operands = adaptor.getOperands();
mlir::MLIRContext *ctxt = cmp.getContext();
mlir::Type eleTy = convertType(getComplexEleTy(cmp.lhs().getType()));
mlir::Type resTy = convertType(cmp.getType());
mlir::Location loc = cmp.getLoc();
auto pos0 = mlir::ArrayAttr::get(ctxt, rewriter.getI32IntegerAttr(0));
SmallVector<mlir::Value, 2> rp{rewriter.create<mlir::LLVM::ExtractValueOp>(
loc, eleTy, operands[0], pos0),
rewriter.create<mlir::LLVM::ExtractValueOp>(
loc, eleTy, operands[1], pos0)};
auto rcp =
rewriter.create<mlir::LLVM::FCmpOp>(loc, resTy, rp, cmp->getAttrs());
auto pos1 = mlir::ArrayAttr::get(ctxt, rewriter.getI32IntegerAttr(1));
SmallVector<mlir::Value, 2> ip{rewriter.create<mlir::LLVM::ExtractValueOp>(
loc, eleTy, operands[0], pos1),
rewriter.create<mlir::LLVM::ExtractValueOp>(
loc, eleTy, operands[1], pos1)};
auto icp =
rewriter.create<mlir::LLVM::FCmpOp>(loc, resTy, ip, cmp->getAttrs());
SmallVector<mlir::Value, 2> cp{rcp, icp};
switch (cmp.getPredicate()) {
case mlir::arith::CmpFPredicate::OEQ: // .EQ.
rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(cmp, resTy, cp);
break;
case mlir::arith::CmpFPredicate::UNE: // .NE.
rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(cmp, resTy, cp);
break;
default:
rewriter.replaceOp(cmp, rcp.getResult());
break;
}
return success();
}
};
/// convert value of from-type to value of to-type
struct ConvertOpConversion : public FIROpConversion<fir::ConvertOp> {
using FIROpConversion::FIROpConversion;
@ -1514,15 +1560,17 @@ public:
AllocaOpConversion, BoxAddrOpConversion, BoxDimsOpConversion,
BoxEleSizeOpConversion, BoxIsAllocOpConversion, BoxIsArrayOpConversion,
BoxIsPtrOpConversion, BoxRankOpConversion, CallOpConversion,
ConvertOpConversion, DispatchOpConversion, DispatchTableOpConversion,
DTEntryOpConversion, DivcOpConversion, EmboxCharOpConversion,
ExtractValueOpConversion, HasValueOpConversion, GlobalLenOpConversion,
GlobalOpConversion, InsertOnRangeOpConversion, InsertValueOpConversion,
IsPresentOpConversion, LoadOpConversion, NegcOpConversion,
MulcOpConversion, SelectCaseOpConversion, SelectOpConversion,
SelectRankOpConversion, SelectTypeOpConversion, StoreOpConversion,
SubcOpConversion, UnboxCharOpConversion, UndefOpConversion,
UnreachableOpConversion, ZeroOpConversion>(typeConverter);
CmpcOpConversion, ConvertOpConversion, DispatchOpConversion,
DispatchTableOpConversion, DTEntryOpConversion, DivcOpConversion,
EmboxCharOpConversion, ExtractValueOpConversion, HasValueOpConversion,
GlobalLenOpConversion, GlobalOpConversion, InsertOnRangeOpConversion,
InsertValueOpConversion, IsPresentOpConversion, LoadOpConversion,
NegcOpConversion, MulcOpConversion, SelectCaseOpConversion,
SelectOpConversion, SelectRankOpConversion, SelectTypeOpConversion,
StoreOpConversion, SubcOpConversion, UnboxCharOpConversion,
UndefOpConversion, UnreachableOpConversion, ZeroOpConversion>(
typeConverter);
mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
pattern);

View File

@ -521,6 +521,57 @@ func @fir_complex_neg(%a: !fir.complex<16>) -> !fir.complex<16> {
// -----
// Test FIR complex compare conversion
func @compare_complex_eq(%a : !fir.complex<8>, %b : !fir.complex<8>) -> i1 {
%r = fir.cmpc "oeq", %a, %b : !fir.complex<8>
return %r : i1
}
// CHECK-LABEL: llvm.func @compare_complex_eq
// CHECK-SAME: [[A:%.*]]: !llvm.struct<(f64, f64)>,
// CHECK-SAME: [[B:%.*]]: !llvm.struct<(f64, f64)>
// CHECK-DAG: [[RA:%.*]] = llvm.extractvalue [[A]][0 : i32] : !llvm.struct<(f64, f64)>
// CHECK-DAG: [[IA:%.*]] = llvm.extractvalue [[A]][1 : i32] : !llvm.struct<(f64, f64)>
// CHECK-DAG: [[RB:%.*]] = llvm.extractvalue [[B]][0 : i32] : !llvm.struct<(f64, f64)>
// CHECK-DAG: [[IB:%.*]] = llvm.extractvalue [[B]][1 : i32] : !llvm.struct<(f64, f64)>
// CHECK-DAG: [[RESR:%.*]] = llvm.fcmp "oeq" [[RA]], [[RB]] : f64
// CHECK-DAG: [[RESI:%.*]] = llvm.fcmp "oeq" [[IA]], [[IB]] : f64
// CHECK: [[RES:%.*]] = llvm.and [[RESR]], [[RESI]] : i1
// CHECK: return [[RES]] : i1
func @compare_complex_ne(%a : !fir.complex<8>, %b : !fir.complex<8>) -> i1 {
%r = fir.cmpc "une", %a, %b : !fir.complex<8>
return %r : i1
}
// CHECK-LABEL: llvm.func @compare_complex_ne
// CHECK-SAME: [[A:%.*]]: !llvm.struct<(f64, f64)>,
// CHECK-SAME: [[B:%.*]]: !llvm.struct<(f64, f64)>
// CHECK-DAG: [[RA:%.*]] = llvm.extractvalue [[A]][0 : i32] : !llvm.struct<(f64, f64)>
// CHECK-DAG: [[IA:%.*]] = llvm.extractvalue [[A]][1 : i32] : !llvm.struct<(f64, f64)>
// CHECK-DAG: [[RB:%.*]] = llvm.extractvalue [[B]][0 : i32] : !llvm.struct<(f64, f64)>
// CHECK-DAG: [[IB:%.*]] = llvm.extractvalue [[B]][1 : i32] : !llvm.struct<(f64, f64)>
// CHECK-DAG: [[RESR:%.*]] = llvm.fcmp "une" [[RA]], [[RB]] : f64
// CHECK-DAG: [[RESI:%.*]] = llvm.fcmp "une" [[IA]], [[IB]] : f64
// CHECK: [[RES:%.*]] = llvm.or [[RESR]], [[RESI]] : i1
// CHECK: return [[RES]] : i1
func @compare_complex_other(%a : !fir.complex<8>, %b : !fir.complex<8>) -> i1 {
%r = fir.cmpc "ogt", %a, %b : !fir.complex<8>
return %r : i1
}
// CHECK-LABEL: llvm.func @compare_complex_other
// CHECK-SAME: [[A:%.*]]: !llvm.struct<(f64, f64)>,
// CHECK-SAME: [[B:%.*]]: !llvm.struct<(f64, f64)>
// CHECK-DAG: [[RA:%.*]] = llvm.extractvalue [[A]][0 : i32] : !llvm.struct<(f64, f64)>
// CHECK-DAG: [[RB:%.*]] = llvm.extractvalue [[B]][0 : i32] : !llvm.struct<(f64, f64)>
// CHECK: [[RESR:%.*]] = llvm.fcmp "ogt" [[RA]], [[RB]] : f64
// CHECK: return [[RESR]] : i1
// -----
// Test `fir.convert` operation conversion from Float type.
func @convert_from_float(%arg0 : f32) {