diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 0e060ce6d73b..4ca9319cfa50 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -374,6 +374,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, if (Subtarget.hasStdExtF()) { setOperationAction(ISD::FLT_ROUNDS_, XLenVT, Custom); + setOperationAction(ISD::SET_ROUNDING, MVT::Other, Custom); } setOperationAction(ISD::GlobalAddress, XLenVT, Custom); @@ -2167,6 +2168,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, return lowerMSCATTER(Op, DAG); case ISD::FLT_ROUNDS_: return lowerGET_ROUNDING(Op, DAG); + case ISD::SET_ROUNDING: + return lowerSET_ROUNDING(Op, DAG); } } @@ -4144,6 +4147,36 @@ SDValue RISCVTargetLowering::lowerGET_ROUNDING(SDValue Op, return DAG.getMergeValues({Masked, Chain}, DL); } +SDValue RISCVTargetLowering::lowerSET_ROUNDING(SDValue Op, + SelectionDAG &DAG) const { + const MVT XLenVT = Subtarget.getXLenVT(); + SDLoc DL(Op); + SDValue Chain = Op->getOperand(0); + SDValue RMValue = Op->getOperand(1); + SDValue SysRegNo = DAG.getConstant( + RISCVSysReg::lookupSysRegByName("FRM")->Encoding, DL, XLenVT); + + // Encoding used for rounding mode in RISCV differs from that used in + // FLT_ROUNDS. To convert it the C rounding mode is used as an index in + // a table, which consists of a sequence of 4-bit fields, each representing + // corresponding RISCV mode. + static const unsigned Table = + (RISCVFPRndMode::RNE << 4 * int(RoundingMode::NearestTiesToEven)) | + (RISCVFPRndMode::RTZ << 4 * int(RoundingMode::TowardZero)) | + (RISCVFPRndMode::RDN << 4 * int(RoundingMode::TowardNegative)) | + (RISCVFPRndMode::RUP << 4 * int(RoundingMode::TowardPositive)) | + (RISCVFPRndMode::RMM << 4 * int(RoundingMode::NearestTiesToAway)); + + SDValue Shift = DAG.getNode(ISD::SHL, DL, XLenVT, RMValue, + DAG.getConstant(2, DL, XLenVT)); + SDValue Shifted = DAG.getNode(ISD::SRL, DL, XLenVT, + DAG.getConstant(Table, DL, XLenVT), Shift); + RMValue = DAG.getNode(ISD::AND, DL, XLenVT, Shifted, + DAG.getConstant(0x7, DL, XLenVT)); + return DAG.getNode(RISCVISD::WRITE_CSR, DL, MVT::Other, Chain, SysRegNo, + RMValue); +} + // Returns the opcode of the target-specific SDNode that implements the 32-bit // form of the given Opcode. static RISCVISD::NodeType getRISCVWOpcode(unsigned Opcode) { diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h index 7a09b725243f..ddccdb2838df 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.h +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h @@ -534,6 +534,7 @@ private: SDValue lowerFixedLengthVectorExtendToRVV(SDValue Op, SelectionDAG &DAG, unsigned ExtendOpc) const; SDValue lowerGET_ROUNDING(SDValue Op, SelectionDAG &DAG) const; + SDValue lowerSET_ROUNDING(SDValue Op, SelectionDAG &DAG) const; bool isEligibleForTailCallOptimization( CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF, diff --git a/llvm/test/CodeGen/RISCV/fpenv.ll b/llvm/test/CodeGen/RISCV/fpenv.ll index ed62d75b1bf6..28fac83b97d9 100644 --- a/llvm/test/CodeGen/RISCV/fpenv.ll +++ b/llvm/test/CodeGen/RISCV/fpenv.ll @@ -26,4 +26,100 @@ define i32 @func_01() { ret i32 %rm } +define void @func_02(i32 %rm) { +; RV32IF-LABEL: func_02: +; RV32IF: # %bb.0: +; RV32IF-NEXT: slli a0, a0, 2 +; RV32IF-NEXT: lui a1, 66 +; RV32IF-NEXT: addi a1, a1, 769 +; RV32IF-NEXT: srl a0, a1, a0 +; RV32IF-NEXT: andi a0, a0, 7 +; RV32IF-NEXT: fsrm a0 +; RV32IF-NEXT: ret +; +; RV64IF-LABEL: func_02: +; RV64IF: # %bb.0: +; RV64IF-NEXT: slli a0, a0, 32 +; RV64IF-NEXT: srli a0, a0, 30 +; RV64IF-NEXT: lui a1, 66 +; RV64IF-NEXT: addiw a1, a1, 769 +; RV64IF-NEXT: srl a0, a1, a0 +; RV64IF-NEXT: andi a0, a0, 7 +; RV64IF-NEXT: fsrm a0 +; RV64IF-NEXT: ret + call void @llvm.set.rounding(i32 %rm) + ret void +} + +define void @func_03() { +; RV32IF-LABEL: func_03: +; RV32IF: # %bb.0: +; RV32IF-NEXT: fsrmi 1 +; RV32IF-NEXT: ret +; +; RV64IF-LABEL: func_03: +; RV64IF: # %bb.0: +; RV64IF-NEXT: fsrmi 1 +; RV64IF-NEXT: ret + call void @llvm.set.rounding(i32 0) + ret void +} + +define void @func_04() { +; RV32IF-LABEL: func_04: +; RV32IF: # %bb.0: +; RV32IF-NEXT: fsrmi 0 +; RV32IF-NEXT: ret +; +; RV64IF-LABEL: func_04: +; RV64IF: # %bb.0: +; RV64IF-NEXT: fsrmi 0 +; RV64IF-NEXT: ret + call void @llvm.set.rounding(i32 1) + ret void +} + +define void @func_05() { +; RV32IF-LABEL: func_05: +; RV32IF: # %bb.0: +; RV32IF-NEXT: fsrmi 3 +; RV32IF-NEXT: ret +; +; RV64IF-LABEL: func_05: +; RV64IF: # %bb.0: +; RV64IF-NEXT: fsrmi 3 +; RV64IF-NEXT: ret + call void @llvm.set.rounding(i32 2) + ret void +} + +define void @func_06() { +; RV32IF-LABEL: func_06: +; RV32IF: # %bb.0: +; RV32IF-NEXT: fsrmi 2 +; RV32IF-NEXT: ret +; +; RV64IF-LABEL: func_06: +; RV64IF: # %bb.0: +; RV64IF-NEXT: fsrmi 2 +; RV64IF-NEXT: ret + call void @llvm.set.rounding(i32 3) + ret void +} + +define void @func_07() { +; RV32IF-LABEL: func_07: +; RV32IF: # %bb.0: +; RV32IF-NEXT: fsrmi 4 +; RV32IF-NEXT: ret +; +; RV64IF-LABEL: func_07: +; RV64IF: # %bb.0: +; RV64IF-NEXT: fsrmi 4 +; RV64IF-NEXT: ret + call void @llvm.set.rounding(i32 4) + ret void +} + +declare void @llvm.set.rounding(i32) declare i32 @llvm.flt.rounds()