diff --git a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp index cda6567de89c..2b1e8f384751 100644 --- a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp +++ b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp @@ -15,6 +15,10 @@ // 1. Trivial instructions between the call and return do not prevent the // transformation from taking place, though currently the analysis cannot // support moving any really useful instructions (only dead ones). +// 2. This pass transforms functions that are prevented from being tail +// recursive by an associative expression to use an accumulator variable, +// thus compiling the typical naive factorial or 'fib' implementation into +// efficient code. // // There are several improvements that could be made: // @@ -37,10 +41,6 @@ // requires some substantial analysis (such as with DSA) to prove safe to // move ahead of the call, but doing so could allow many more TREs to be // performed, for example in TreeAdd/TreeAlloc from the treeadd benchmark. -// 5. This pass could transform functions that are prevented from being tail -// recursive by a commutative expression to use an accumulator helper -// function, thus compiling the typical naive factorial or 'fib' -// implementation into efficient code. 
// //===----------------------------------------------------------------------===// @@ -49,11 +49,13 @@ #include "llvm/Function.h" #include "llvm/Instructions.h" #include "llvm/Pass.h" +#include "llvm/Support/CFG.h" #include "Support/Statistic.h" using namespace llvm; namespace { Statistic<> NumEliminated("tailcallelim", "Number of tail calls removed"); + Statistic<> NumAccumAdded("tailcallelim","Number of accumulators introduced"); struct TailCallElim : public FunctionPass { virtual bool runOnFunction(Function &F); @@ -62,6 +64,7 @@ namespace { bool ProcessReturningBlock(ReturnInst *RI, BasicBlock *&OldEntry, std::vector<PHINode*> &ArgumentPHIs); bool CanMoveAboveCall(Instruction *I, CallInst *CI); + Value *CanTransformAccumulatorRecursion(Instruction *I, CallInst *CI); }; RegisterOpt<TailCallElim> X("tailcallelim", "Tail Call Elimination"); } @@ -90,10 +93,10 @@ bool TailCallElim::runOnFunction(Function &F) { } -// CanMoveAboveCall - Return true if it is safe to move the specified -// instruction from after the call to before the call, assuming that all -// instructions between the call and this instruction are movable. -// +/// CanMoveAboveCall - Return true if it is safe to move the specified +/// instruction from after the call to before the call, assuming that all +/// instructions between the call and this instruction are movable. +/// bool TailCallElim::CanMoveAboveCall(Instruction *I, CallInst *CI) { // FIXME: We can move load/store/call/free instructions above the call if the // call does not mod/ref the memory location being processed. @@ -112,6 +115,49 @@ bool TailCallElim::CanMoveAboveCall(Instruction *I, CallInst *CI) { } +/// CanTransformAccumulatorRecursion - If the specified instruction can be +/// transformed using accumulator recursion elimination, return the constant +/// which is the start of the accumulator value. Otherwise return null. 
+/// +Value *TailCallElim::CanTransformAccumulatorRecursion(Instruction *I, + CallInst *CI) { + if (!I->isAssociative()) return 0; + assert(I->getNumOperands() == 2 && + "Associative operations should have 2 args!"); + + // Exactly one operand should be the result of the call instruction... + if (I->getOperand(0) == CI && I->getOperand(1) == CI || + I->getOperand(0) != CI && I->getOperand(1) != CI) + return 0; + + // The only user of this instruction we allow is a single return instruction. + if (!I->hasOneUse() || !isa<ReturnInst>(I->use_back())) + return 0; + + // Ok, now we have to check all of the other return instructions in this + // function. If they return non-constants or differing values, then we cannot + // transform the function safely. + Value *ReturnedValue = 0; + Function *F = CI->getParent()->getParent(); + + for (Function::iterator BBI = F->begin(), E = F->end(); BBI != E; ++BBI) + if (ReturnInst *RI = dyn_cast<ReturnInst>(BBI->getTerminator())) { + Value *RetOp = RI->getOperand(0); + if (isa<Constant>(RetOp)) { + if (ReturnedValue && RetOp != ReturnedValue) + return 0; // Cannot transform if differing constants are returned. + ReturnedValue = RetOp; + + } else if (RetOp != I) { // Ignore the one returning I. + return 0; // Not returning a constant, cannot transform. + } + } + + // Ok, if we passed this battery of tests, we can perform accumulator + // recursion elimination. + return ReturnedValue; +} + bool TailCallElim::ProcessReturningBlock(ReturnInst *Ret, BasicBlock *&OldEntry, std::vector<PHINode*> &ArgumentPHIs) { BasicBlock *BB = Ret->getParent(); @@ -134,17 +180,38 @@ bool TailCallElim::ProcessReturningBlock(ReturnInst *Ret, BasicBlock *&OldEntry, --BBI; } + // If we are introducing accumulator recursion to eliminate associative + // operations after the call instruction, this variable contains the initial + // value for the accumulator. If this value is set, we actually perform + // accumulator recursion elimination instead of simple tail recursion + // elimination. 
+ Value *AccumulatorRecursionEliminationInitVal = 0; + Instruction *AccumulatorRecursionInstr = 0; + // Ok, we found a potential tail call. We can currently only transform the // tail call if all of the instructions between the call and the return are // movable to above the call itself, leaving the call next to the return. // Check that this is the case now. for (BBI = CI, ++BBI; &*BBI != Ret; ++BBI) - if (!CanMoveAboveCall(BBI, CI)) - return false; // Cannot move this instruction out of the way. + if (!CanMoveAboveCall(BBI, CI)) { + // If we can't move the instruction above the call, it might be because it + // is an associative operation that could be transformed using accumulator + // recursion elimination. Check to see if this is the case, and if so, + // remember the initial accumulator value for later. + if ((AccumulatorRecursionEliminationInitVal = + CanTransformAccumulatorRecursion(BBI, CI))) { + // Yes, this is accumulator recursion. Remember which instruction + // accumulates. + AccumulatorRecursionInstr = BBI; + } else { + return false; // Otherwise, we cannot eliminate the tail recursion! + } + } // We can only transform call/return pairs that either ignore the return value // of the call and return void, or return the value returned by the tail call. - if (Ret->getNumOperands() != 0 && Ret->getReturnValue() != CI) + if (Ret->getNumOperands() != 0 && Ret->getReturnValue() != CI && + AccumulatorRecursionEliminationInitVal == 0) return false; // OK! We can transform this tail call. If this is the first one found, @@ -174,11 +241,54 @@ bool TailCallElim::ProcessReturningBlock(ReturnInst *Ret, BasicBlock *&OldEntry, for (unsigned i = 0, e = CI->getNumOperands()-1; i != e; ++i) ArgumentPHIs[i]->addIncoming(CI->getOperand(i+1), BB); + // If we are introducing an accumulator variable to eliminate the recursion, + // do so now. 
Note that we _know_ that no subsequent tail recursion + // eliminations will happen on this function because of the way the + // accumulator recursion predicate is set up. + // + if (AccumulatorRecursionEliminationInitVal) { + Instruction *AccRecInstr = AccumulatorRecursionInstr; + // Start by inserting a new PHI node for the accumulator. + PHINode *AccPN = new PHINode(AccRecInstr->getType(), "accumulator.tr", + OldEntry->begin()); + + // Loop over all of the predecessors of the tail recursion block. For the + // real entry into the function we seed the PHI with the initial value, + // computed earlier. For any other existing branches to this block (due to + // other tail recursions eliminated) the accumulator is not modified. + // Because we haven't added the branch in the current block to OldEntry yet, + // it will not show up as a predecessor. + for (pred_iterator PI = pred_begin(OldEntry), PE = pred_end(OldEntry); + PI != PE; ++PI) { + if (*PI == &F->getEntryBlock()) + AccPN->addIncoming(AccumulatorRecursionEliminationInitVal, *PI); + else + AccPN->addIncoming(AccPN, *PI); + } + + // Add an incoming argument for the current block, which is computed by our + // associative accumulator instruction. + AccPN->addIncoming(AccRecInstr, BB); + + // Next, rewrite the accumulator recursion instruction so that it does not + // use the result of the call anymore, instead, use the PHI node we just + // inserted. + AccRecInstr->setOperand(AccRecInstr->getOperand(0) != CI, AccPN); + + // Finally, rewrite any return instructions in the program to return the PHI + // node instead of the "initval" that they do currently. This loop will + // actually rewrite the return value we are destroying, but that's ok. 
+ for (Function::iterator BBI = F->begin(), E = F->end(); BBI != E; ++BBI) + if (ReturnInst *RI = dyn_cast<ReturnInst>(BBI->getTerminator())) + RI->setOperand(0, AccPN); + ++NumAccumAdded; + } + // Now that all of the PHI nodes are in place, remove the call and // ret instructions, replacing them with an unconditional branch. new BranchInst(OldEntry, Ret); BB->getInstList().erase(Ret); // Remove return. BB->getInstList().erase(CI); // Remove call. - NumEliminated++; + ++NumEliminated; return true; }