diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h index 0ae665107594..106b086184ac 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h @@ -511,6 +511,7 @@ private: void ScalarizeVectorResult(SDNode *N, unsigned OpNo); SDValue ScalarizeVecRes_MERGE_VALUES(SDNode *N, unsigned ResNo); SDValue ScalarizeVecRes_BinOp(SDNode *N); + SDValue ScalarizeVecRes_TernaryOp(SDNode *N); SDValue ScalarizeVecRes_UnaryOp(SDNode *N); SDValue ScalarizeVecRes_InregOp(SDNode *N); @@ -555,6 +556,7 @@ private: // Vector Result Splitting: <128 x ty> -> 2 x <64 x ty>. void SplitVectorResult(SDNode *N, unsigned OpNo); void SplitVecRes_BinOp(SDNode *N, SDValue &Lo, SDValue &Hi); + void SplitVecRes_TernaryOp(SDNode *N, SDValue &Lo, SDValue &Hi); void SplitVecRes_UnaryOp(SDNode *N, SDValue &Lo, SDValue &Hi); void SplitVecRes_InregOp(SDNode *N, SDValue &Lo, SDValue &Hi); diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp index 5f23f01dafb4..d09411c42f3c 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp @@ -115,6 +115,9 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) { case ISD::SRL: R = ScalarizeVecRes_BinOp(N); break; + case ISD::FMA: + R = ScalarizeVecRes_TernaryOp(N); + break; } // If R is null, the sub-method took care of registering the result. @@ -129,6 +132,14 @@ SDValue DAGTypeLegalizer::ScalarizeVecRes_BinOp(SDNode *N) { LHS.getValueType(), LHS, RHS); } +SDValue DAGTypeLegalizer::ScalarizeVecRes_TernaryOp(SDNode *N) { + SDValue Op0 = GetScalarizedVector(N->getOperand(0)); + SDValue Op1 = GetScalarizedVector(N->getOperand(1)); + SDValue Op2 = GetScalarizedVector(N->getOperand(2)); + return DAG.getNode(N->getOpcode(), N->getDebugLoc(), + Op0.getValueType(), Op0, Op1, Op2); +} + SDValue DAGTypeLegalizer::ScalarizeVecRes_MERGE_VALUES(SDNode *N, unsigned ResNo) { SDValue Op = DisintegrateMERGE_VALUES(N, ResNo); @@ -529,6 +540,9 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) { case ISD::FREM: SplitVecRes_BinOp(N, Lo, Hi); break; + case ISD::FMA: + SplitVecRes_TernaryOp(N, Lo, Hi); + break; } // If Lo/Hi is null, the sub-method took care of registering results etc. @@ -548,6 +562,22 @@ void DAGTypeLegalizer::SplitVecRes_BinOp(SDNode *N, SDValue &Lo, Hi = DAG.getNode(N->getOpcode(), dl, LHSHi.getValueType(), LHSHi, RHSHi); } +void DAGTypeLegalizer::SplitVecRes_TernaryOp(SDNode *N, SDValue &Lo, + SDValue &Hi) { + SDValue Op0Lo, Op0Hi; + GetSplitVector(N->getOperand(0), Op0Lo, Op0Hi); + SDValue Op1Lo, Op1Hi; + GetSplitVector(N->getOperand(1), Op1Lo, Op1Hi); + SDValue Op2Lo, Op2Hi; + GetSplitVector(N->getOperand(2), Op2Lo, Op2Hi); + DebugLoc dl = N->getDebugLoc(); + + Lo = DAG.getNode(N->getOpcode(), dl, Op0Lo.getValueType(), + Op0Lo, Op1Lo, Op2Lo); + Hi = DAG.getNode(N->getOpcode(), dl, Op0Hi.getValueType(), + Op0Hi, Op1Hi, Op2Hi); +} + void DAGTypeLegalizer::SplitVecRes_BITCAST(SDNode *N, SDValue &Lo, SDValue &Hi) { // We know the result is a vector. The input may be either a vector or a diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 4426f300d56a..35366cebb507 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -715,6 +715,7 @@ X86TargetLowering::X86TargetLowering(X86TargetMachine &TM) setOperationAction(ISD::FSUB, (MVT::SimpleValueType)VT, Expand); setOperationAction(ISD::MUL , (MVT::SimpleValueType)VT, Expand); setOperationAction(ISD::FMUL, (MVT::SimpleValueType)VT, Expand); + setOperationAction(ISD::FMA, (MVT::SimpleValueType)VT, Custom); setOperationAction(ISD::SDIV, (MVT::SimpleValueType)VT, Expand); setOperationAction(ISD::UDIV, (MVT::SimpleValueType)VT, Expand); setOperationAction(ISD::FDIV, (MVT::SimpleValueType)VT, Expand); @@ -10893,6 +10894,7 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { case ISD::SUBE: return LowerADDC_ADDE_SUBC_SUBE(Op, DAG); case ISD::ADD: return LowerADD(Op, DAG); case ISD::SUB: return LowerSUB(Op, DAG); + case ISD::FMA: return SDValue(); } } diff --git a/llvm/test/CodeGen/ARM/fusedMAC.ll b/llvm/test/CodeGen/ARM/fusedMAC.ll index d35330c09f5f..303d165de0b6 100644 --- a/llvm/test/CodeGen/ARM/fusedMAC.ll +++ b/llvm/test/CodeGen/ARM/fusedMAC.ll @@ -206,7 +206,19 @@ define float @test_fma_canonicalize(float %a, float %b) nounwind { ret float %ret } +; Check that very wide vector fma's can be split into legal fma's. +define void @test_fma_v8f32(<8 x float> %a, <8 x float> %b, <8 x float> %c, <8 x float>* %p) nounwind readnone ssp { +; CHECK: test_fma_v8f32 +; CHECK: vfma.f32 +; CHECK: vfma.f32 +entry: + %call = tail call <8 x float> @llvm.fma.v8f32(<8 x float> %a, <8 x float> %b, <8 x float> %c) nounwind readnone + store <8 x float> %call, <8 x float>* %p, align 16 + ret void +} + declare float @llvm.fma.f32(float, float, float) nounwind readnone declare double @llvm.fma.f64(double, double, double) nounwind readnone declare <2 x float> @llvm.fma.v2f32(<2 x float>, <2 x float>, <2 x float>) nounwind readnone +declare <8 x float> @llvm.fma.v8f32(<8 x float>, <8 x float>, <8 x float>) nounwind readnone