[TRE] Allow accumulator elimination when base case returns non-constant

Remove the requirement, that when performing accumulator elimination,
all other cases must return the same dynamic constant. We can do this by
initializing the accumulator with the identity value of the accumulation
operation, and inserting an additional operation before any return.

Differential Revision: https://reviews.llvm.org/D80844
This commit is contained in:
Layton Kifer 2020-06-04 10:33:03 -07:00 committed by Eli Friedman
parent bd43f78c76
commit 7381fcdf62
2 changed files with 270 additions and 165 deletions

View File

@ -354,89 +354,23 @@ static bool canMoveAboveCall(Instruction *I, CallInst *CI, AliasAnalysis *AA) {
return !is_contained(I->operands(), CI);
}
/// Return true if the specified value is the same when the return would exit
/// as it was when the initial iteration of the recursive function was executed.
///
/// We currently handle static constants and arguments that are not modified as
/// part of the recursion.
static bool isDynamicConstant(Value *V, CallInst *CI, ReturnInst *RI) {
if (isa<Constant>(V)) return true; // Static constants are always dyn consts
static bool canTransformAccumulatorRecursion(Instruction *I, CallInst *CI) {
if (!I->isAssociative() || !I->isCommutative())
return false;
// Check to see if this is an immutable argument, if so, the value
// will be available to initialize the accumulator.
if (Argument *Arg = dyn_cast<Argument>(V)) {
// Figure out which argument number this is...
unsigned ArgNo = 0;
Function *F = CI->getParent()->getParent();
for (Function::arg_iterator AI = F->arg_begin(); &*AI != Arg; ++AI)
++ArgNo;
// If we are passing this argument into call as the corresponding
// argument operand, then the argument is dynamically constant.
// Otherwise, we cannot transform this function safely.
if (CI->getArgOperand(ArgNo) == Arg)
return true;
}
// Switch cases are always constant integers. If the value is being switched
// on and the return is only reachable from one of its cases, it's
// effectively constant.
if (BasicBlock *UniquePred = RI->getParent()->getUniquePredecessor())
if (SwitchInst *SI = dyn_cast<SwitchInst>(UniquePred->getTerminator()))
if (SI->getCondition() == V)
return SI->getDefaultDest() != RI->getParent();
// Not a constant or immutable argument, we can't safely transform.
return false;
}
/// Check to see if the function containing the specified tail call consistently
/// returns the same runtime-constant value at all exit points except for
/// IgnoreRI. If so, return the returned value.
static Value *getCommonReturnValue(ReturnInst *IgnoreRI, CallInst *CI) {
Function *F = CI->getParent()->getParent();
Value *ReturnedValue = nullptr;
for (BasicBlock &BBI : *F) {
ReturnInst *RI = dyn_cast<ReturnInst>(BBI.getTerminator());
if (RI == nullptr || RI == IgnoreRI) continue;
// We can only perform this transformation if the value returned is
// evaluatable at the start of the initial invocation of the function,
// instead of at the end of the evaluation.
//
Value *RetOp = RI->getOperand(0);
if (!isDynamicConstant(RetOp, CI, RI))
return nullptr;
if (ReturnedValue && RetOp != ReturnedValue)
return nullptr; // Cannot transform if differing values are returned.
ReturnedValue = RetOp;
}
return ReturnedValue;
}
/// If the specified instruction can be transformed using accumulator recursion
/// elimination, return the constant which is the start of the accumulator
/// value. Otherwise return null.
static Value *canTransformAccumulatorRecursion(Instruction *I, CallInst *CI) {
if (!I->isAssociative() || !I->isCommutative()) return nullptr;
assert(I->getNumOperands() == 2 &&
"Associative/commutative operations should have 2 args!");
// Exactly one operand should be the result of the call instruction.
if ((I->getOperand(0) == CI && I->getOperand(1) == CI) ||
(I->getOperand(0) != CI && I->getOperand(1) != CI))
return nullptr;
return false;
// The only user of this instruction we allow is a single return instruction.
if (!I->hasOneUse() || !isa<ReturnInst>(I->user_back()))
return nullptr;
return false;
// Ok, now we have to check all of the other return instructions in this
// function. If they return non-constants or differing values, then we cannot
// transform the function safely.
return getCommonReturnValue(cast<ReturnInst>(I->user_back()), CI);
return true;
}
static Instruction *firstNonDbg(BasicBlock::iterator I) {
@ -470,6 +404,16 @@ class TailRecursionEliminator {
// to either propagate RetPN or select a new return value.
SmallVector<SelectInst *, 8> RetSelects;
// The below are shared state needed when performing accumulator recursion.
// There values should be populated by insertAccumulator the first time we
// find an elimination that requires an accumulator.
// PHI node to store our current accumulated value.
PHINode *AccPN = nullptr;
// The instruction doing the accumulating.
Instruction *AccumulatorRecursionInstr = nullptr;
TailRecursionEliminator(Function &F, const TargetTransformInfo *TTI,
AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
DomTreeUpdater &DTU)
@ -480,7 +424,7 @@ class TailRecursionEliminator {
void createTailRecurseLoopHeader(CallInst *CI);
PHINode *insertAccumulator(Value *AccumulatorRecursionEliminationInitVal);
void insertAccumulator(Instruction *AccRecInstr);
bool eliminateCall(CallInst *CI);
@ -608,47 +552,44 @@ void TailRecursionEliminator::createTailRecurseLoopHeader(CallInst *CI) {
DTU.recalculate(*NewEntry->getParent());
}
PHINode *TailRecursionEliminator::insertAccumulator(
Value *AccumulatorRecursionEliminationInitVal) {
void TailRecursionEliminator::insertAccumulator(Instruction *AccRecInstr) {
assert(!AccPN && "Trying to insert multiple accumulators");
AccumulatorRecursionInstr = AccRecInstr;
// Start by inserting a new PHI node for the accumulator.
pred_iterator PB = pred_begin(HeaderBB), PE = pred_end(HeaderBB);
PHINode *AccPN = PHINode::Create(
AccumulatorRecursionEliminationInitVal->getType(),
std::distance(PB, PE) + 1, "accumulator.tr", &HeaderBB->front());
AccPN = PHINode::Create(F.getReturnType(), std::distance(PB, PE) + 1,
"accumulator.tr", &HeaderBB->front());
// Loop over all of the predecessors of the tail recursion block. For the
// real entry into the function we seed the PHI with the initial value,
// computed earlier. For any other existing branches to this block (due to
// other tail recursions eliminated) the accumulator is not modified.
// real entry into the function we seed the PHI with the identity constant for
// the accumulation operation. For any other existing branches to this block
// (due to other tail recursions eliminated) the accumulator is not modified.
// Because we haven't added the branch in the current block to HeaderBB yet,
// it will not show up as a predecessor.
for (pred_iterator PI = PB; PI != PE; ++PI) {
BasicBlock *P = *PI;
if (P == &F.getEntryBlock())
AccPN->addIncoming(AccumulatorRecursionEliminationInitVal, P);
else
if (P == &F.getEntryBlock()) {
Constant *Identity = ConstantExpr::getBinOpIdentity(
AccRecInstr->getOpcode(), AccRecInstr->getType());
AccPN->addIncoming(Identity, P);
} else {
AccPN->addIncoming(AccPN, P);
}
}
return AccPN;
++NumAccumAdded;
}
bool TailRecursionEliminator::eliminateCall(CallInst *CI) {
ReturnInst *Ret = cast<ReturnInst>(CI->getParent()->getTerminator());
// If we are introducing accumulator recursion to eliminate operations after
// the call instruction that are both associative and commutative, the initial
// value for the accumulator is placed in this variable. If this value is set
// then we actually perform accumulator recursion elimination instead of
// simple tail recursion elimination. If the operation is an LLVM instruction
// (eg: "add") then it is recorded in AccumulatorRecursionInstr.
Value *AccumulatorRecursionEliminationInitVal = nullptr;
Instruction *AccumulatorRecursionInstr = nullptr;
// Ok, we found a potential tail call. We can currently only transform the
// tail call if all of the instructions between the call and the return are
// movable to above the call itself, leaving the call next to the return.
// Check that this is the case now.
Instruction *AccRecInstr = nullptr;
BasicBlock::iterator BBI(CI);
for (++BBI; &*BBI != Ret; ++BBI) {
if (canMoveAboveCall(&*BBI, CI, AA))
@ -657,15 +598,13 @@ bool TailRecursionEliminator::eliminateCall(CallInst *CI) {
// If we can't move the instruction above the call, it might be because it
// is an associative and commutative operation that could be transformed
// using accumulator recursion elimination. Check to see if this is the
// case, and if so, remember the initial accumulator value for later.
if ((AccumulatorRecursionEliminationInitVal =
canTransformAccumulatorRecursion(&*BBI, CI))) {
// Yes, this is accumulator recursion. Remember which instruction
// accumulates.
AccumulatorRecursionInstr = &*BBI;
} else {
return false; // Otherwise, we cannot eliminate the tail recursion!
}
// case, and if so, remember which instruction accumulates for later.
if (AccPN || !canTransformAccumulatorRecursion(&*BBI, CI))
return false; // We cannot eliminate the tail recursion!
// Yes, this is accumulator recursion. Remember which instruction
// accumulates.
AccRecInstr = &*BBI;
}
BasicBlock *BB = Ret->getParent();
@ -690,37 +629,18 @@ bool TailRecursionEliminator::eliminateCall(CallInst *CI) {
for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i)
ArgumentPHIs[i]->addIncoming(CI->getArgOperand(i), BB);
// If we are introducing an accumulator variable to eliminate the recursion,
// do so now. Note that we _know_ that no subsequent tail recursion
// eliminations will happen on this function because of the way the
// accumulator recursion predicate is set up.
//
if (AccumulatorRecursionEliminationInitVal) {
PHINode *AccPN = insertAccumulator(AccumulatorRecursionEliminationInitVal);
if (AccRecInstr) {
insertAccumulator(AccRecInstr);
Instruction *AccRecInstr = AccumulatorRecursionInstr;
// Add an incoming argument for the current block, which is computed by
// our associative and commutative accumulator instruction.
AccPN->addIncoming(AccRecInstr, BB);
// Next, rewrite the accumulator recursion instruction so that it does not
// use the result of the call anymore, instead, use the PHI node we just
// Rewrite the accumulator recursion instruction so that it does not use
// the result of the call anymore, instead, use the PHI node we just
// inserted.
AccRecInstr->setOperand(AccRecInstr->getOperand(0) != CI, AccPN);
// Finally, rewrite any return instructions in the program to return the PHI
// node instead of the "initval" that they do currently. This loop will
// actually rewrite the return value we are destroying, but that's ok.
for (BasicBlock &BBI : F)
if (ReturnInst *RI = dyn_cast<ReturnInst>(BBI.getTerminator()))
RI->setOperand(0, AccPN);
++NumAccumAdded;
}
// Update our return value tracking
if (RetPN) {
if (Ret->getReturnValue() == CI || AccumulatorRecursionEliminationInitVal) {
if (Ret->getReturnValue() == CI || AccRecInstr) {
// Defer selecting a return value
RetPN->addIncoming(RetPN, BB);
RetKnownPN->addIncoming(RetKnownPN, BB);
@ -735,6 +655,9 @@ bool TailRecursionEliminator::eliminateCall(CallInst *CI) {
RetPN->addIncoming(SI, BB);
RetKnownPN->addIncoming(ConstantInt::getTrue(RetKnownPN->getType()), BB);
}
if (AccPN)
AccPN->addIncoming(AccRecInstr ? AccRecInstr : AccPN, BB);
}
// Now that all of the PHI nodes are in place, remove the call and
@ -829,6 +752,24 @@ void TailRecursionEliminator::cleanupAndFinalize() {
RetKnownPN->dropAllReferences();
RetKnownPN->eraseFromParent();
if (AccPN) {
// We need to insert a copy of our accumulator instruction before any
// return in the function, and return its result instead.
Instruction *AccRecInstr = AccumulatorRecursionInstr;
for (BasicBlock &BB : F) {
ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator());
if (!RI)
continue;
Instruction *AccRecInstrNew = AccRecInstr->clone();
AccRecInstrNew->setName("accumulator.ret.tr");
AccRecInstrNew->setOperand(AccRecInstr->getOperand(0) == AccPN,
RI->getOperand(0));
AccRecInstrNew->insertBefore(RI);
RI->setOperand(0, AccRecInstrNew);
}
}
} else {
// We need to insert a select instruction before any return left in the
// function to select our stored return value if we have one.
@ -839,8 +780,23 @@ void TailRecursionEliminator::cleanupAndFinalize() {
SelectInst *SI = SelectInst::Create(
RetKnownPN, RetPN, RI->getOperand(0), "current.ret.tr", RI);
RetSelects.push_back(SI);
RI->setOperand(0, SI);
}
if (AccPN) {
// We need to insert a copy of our accumulator instruction before any
// of the selects we inserted, and select its result instead.
Instruction *AccRecInstr = AccumulatorRecursionInstr;
for (SelectInst *SI : RetSelects) {
Instruction *AccRecInstrNew = AccRecInstr->clone();
AccRecInstrNew->setName("accumulator.ret.tr");
AccRecInstrNew->setOperand(AccRecInstr->getOperand(0) == AccPN,
SI->getFalseValue());
AccRecInstrNew->insertBefore(SI);
SI->setFalseValue(AccRecInstrNew);
}
}
}
}
}

View File

@ -3,73 +3,222 @@
define i32 @test1_factorial(i32 %x) {
entry:
%tmp.1 = icmp sgt i32 %x, 0 ; <i1> [#uses=1]
%tmp.1 = icmp sgt i32 %x, 0
br i1 %tmp.1, label %then, label %else
then: ; preds = %entry
%tmp.6 = add i32 %x, -1 ; <i32> [#uses=1]
%tmp.4 = call i32 @test1_factorial( i32 %tmp.6 ) ; <i32> [#uses=1]
%tmp.7 = mul i32 %tmp.4, %x ; <i32> [#uses=1]
ret i32 %tmp.7
else: ; preds = %entry
then:
%tmp.6 = add i32 %x, -1
%recurse = call i32 @test1_factorial( i32 %tmp.6 )
%accumulate = mul i32 %recurse, %x
ret i32 %accumulate
else:
ret i32 1
}
; CHECK-LABEL: define i32 @test1_factorial(
; CHECK: phi i32
; CHECK-NOT: call i32
; CHECK: tailrecurse:
; CHECK: %accumulator.tr = phi i32 [ 1, %entry ], [ %accumulate, %then ]
; CHECK: then:
; CHECK-NOT: %recurse
; CHECK: %accumulate = mul i32 %accumulator.tr, %x.tr
; CHECK: else:
; CHECK: %accumulator.ret.tr = mul i32 %accumulator.tr, 1
; CHECK: ret i32 %accumulator.ret.tr
; This is a more aggressive form of accumulator recursion insertion, which
; requires noticing that X doesn't change as we perform the tailcall.
define i32 @test2_mul(i32 %x, i32 %y) {
entry:
%tmp.1 = icmp eq i32 %y, 0 ; <i1> [#uses=1]
%tmp.1 = icmp eq i32 %y, 0
br i1 %tmp.1, label %return, label %endif
endif: ; preds = %entry
%tmp.8 = add i32 %y, -1 ; <i32> [#uses=1]
%tmp.5 = call i32 @test2_mul( i32 %x, i32 %tmp.8 ) ; <i32> [#uses=1]
%tmp.9 = add i32 %tmp.5, %x ; <i32> [#uses=1]
ret i32 %tmp.9
return: ; preds = %entry
endif:
%tmp.8 = add i32 %y, -1
%recurse = call i32 @test2_mul( i32 %x, i32 %tmp.8 )
%accumulate = add i32 %recurse, %x
ret i32 %accumulate
return:
ret i32 %x
}
; CHECK-LABEL: define i32 @test2_mul(
; CHECK: phi i32
; CHECK-NOT: call i32
; CHECK: tailrecurse:
; CHECK: %accumulator.tr = phi i32 [ 0, %entry ], [ %accumulate, %endif ]
; CHECK: endif:
; CHECK-NOT: %recurse
; CHECK: %accumulate = add i32 %accumulator.tr, %x
; CHECK: return:
; CHECK: %accumulator.ret.tr = add i32 %accumulator.tr, %x
; CHECK: ret i32 %accumulator.ret.tr
define i64 @test3_fib(i64 %n) nounwind readnone {
; CHECK-LABEL: @test3_fib(
entry:
; CHECK: tailrecurse:
; CHECK: %accumulator.tr = phi i64 [ %n, %entry ], [ %3, %bb1 ]
; CHECK: %n.tr = phi i64 [ %n, %entry ], [ %2, %bb1 ]
switch i64 %n, label %bb1 [
; CHECK: switch i64 %n.tr, label %bb1 [
i64 0, label %bb2
i64 1, label %bb2
]
bb1:
; CHECK: bb1:
%0 = add i64 %n, -1
; CHECK: %0 = add i64 %n.tr, -1
%1 = tail call i64 @test3_fib(i64 %0) nounwind
; CHECK: %1 = tail call i64 @test3_fib(i64 %0)
%2 = add i64 %n, -2
; CHECK: %2 = add i64 %n.tr, -2
%3 = tail call i64 @test3_fib(i64 %2) nounwind
; CHECK-NOT: tail call i64 @test3_fib
%4 = add nsw i64 %3, %1
; CHECK: add nsw i64 %accumulator.tr, %1
ret i64 %4
; CHECK: br label %tailrecurse
%recurse1 = tail call i64 @test3_fib(i64 %0) nounwind
%1 = add i64 %n, -2
%recurse2 = tail call i64 @test3_fib(i64 %1) nounwind
%accumulate = add nsw i64 %recurse2, %recurse1
ret i64 %accumulate
bb2:
; CHECK: bb2:
ret i64 %n
; CHECK: ret i64 %accumulator.tr
}
; CHECK-LABEL: define i64 @test3_fib(
; CHECK: tailrecurse:
; CHECK: %accumulator.tr = phi i64 [ 0, %entry ], [ %accumulate, %bb1 ]
; CHECK: bb1:
; CHECK-NOT: %recurse2
; CHECK: %accumulate = add nsw i64 %accumulator.tr, %recurse1
; CHECK: bb2:
; CHECK: %accumulator.ret.tr = add nsw i64 %accumulator.tr, %n.tr
; CHECK: ret i64 %accumulator.ret.tr
define i32 @test4_base_case_call() local_unnamed_addr {
entry:
%base = call i32 @test4_helper()
switch i32 %base, label %sw.default [
i32 1, label %cleanup
i32 5, label %cleanup
i32 7, label %cleanup
]
sw.default:
%recurse = call i32 @test4_base_case_call()
%accumulate = add nsw i32 %recurse, 1
br label %cleanup
cleanup:
%retval.0 = phi i32 [ %accumulate, %sw.default ], [ %base, %entry ], [ %base, %entry ], [ %base, %entry ]
ret i32 %retval.0
}
declare i32 @test4_helper()
; CHECK-LABEL: define i32 @test4_base_case_call(
; CHECK: tailrecurse:
; CHECK: %accumulator.tr = phi i32 [ 0, %entry ], [ %accumulate, %sw.default ]
; CHECK: sw.default:
; CHECK-NOT: %recurse
; CHECK: %accumulate = add nsw i32 %accumulator.tr, 1
; CHECK: cleanup:
; CHECK: %accumulator.ret.tr = add nsw i32 %accumulator.tr, %base
; CHECK: ret i32 %accumulator.ret.tr
define i32 @test5_base_case_load(i32* nocapture %A, i32 %n) local_unnamed_addr {
entry:
%cmp = icmp eq i32 %n, 0
br i1 %cmp, label %if.then, label %if.end
if.then:
%base = load i32, i32* %A, align 4
ret i32 %base
if.end:
%idxprom = zext i32 %n to i64
%arrayidx1 = getelementptr inbounds i32, i32* %A, i64 %idxprom
%load = load i32, i32* %arrayidx1, align 4
%sub = add i32 %n, -1
%recurse = tail call i32 @test5_base_case_load(i32* %A, i32 %sub)
%accumulate = add i32 %recurse, %load
ret i32 %accumulate
}
; CHECK-LABEL: define i32 @test5_base_case_load(
; CHECK: tailrecurse:
; CHECK: %accumulator.tr = phi i32 [ 0, %entry ], [ %accumulate, %if.end ]
; CHECK: if.then:
; CHECK: %accumulator.ret.tr = add i32 %accumulator.tr, %base
; CHECK: ret i32 %accumulator.ret.tr
; CHECK: if.end:
; CHECK-NOT: %recurse
; CHECK: %accumulate = add i32 %accumulator.tr, %load
define i32 @test6_multiple_returns(i32 %x, i32 %y) local_unnamed_addr {
entry:
switch i32 %x, label %default [
i32 0, label %case0
i32 99, label %case99
]
case0:
%helper = call i32 @test6_helper()
ret i32 %helper
case99:
%sub1 = add i32 %x, -1
%recurse1 = call i32 @test6_multiple_returns(i32 %sub1, i32 %y)
ret i32 18
default:
%sub2 = add i32 %x, -1
%recurse2 = call i32 @test6_multiple_returns(i32 %sub2, i32 %y)
%accumulate = add i32 %recurse2, %y
ret i32 %accumulate
}
declare i32 @test6_helper()
; CHECK-LABEL: define i32 @test6_multiple_returns(
; CHECK: tailrecurse:
; CHECK: %accumulator.tr = phi i32 [ %accumulator.tr, %case99 ], [ 0, %entry ], [ %accumulate, %default ]
; CHECK: %ret.tr = phi i32 [ undef, %entry ], [ %current.ret.tr, %case99 ], [ %ret.tr, %default ]
; CHECK: %ret.known.tr = phi i1 [ false, %entry ], [ true, %case99 ], [ %ret.known.tr, %default ]
; CHECK: case0:
; CHECK: %accumulator.ret.tr2 = add i32 %accumulator.tr, %helper
; CHECK: %current.ret.tr1 = select i1 %ret.known.tr, i32 %ret.tr, i32 %accumulator.ret.tr2
; CHECK: case99:
; CHECK-NOT: %recurse
; CHECK: %accumulator.ret.tr = add i32 %accumulator.tr, 18
; CHECK: %current.ret.tr = select i1 %ret.known.tr, i32 %ret.tr, i32 %accumulator.ret.tr
; CHECK: default:
; CHECK-NOT: %recurse
; CHECK: %accumulate = add i32 %accumulator.tr, %y
; It is only safe to transform one accumulator per function, make sure we don't
; try to remove more.
define i32 @test7_multiple_accumulators(i32 %a) local_unnamed_addr {
entry:
%tobool = icmp eq i32 %a, 0
br i1 %tobool, label %return, label %if.end
if.end:
%and = and i32 %a, 1
%tobool1 = icmp eq i32 %and, 0
%sub = add nsw i32 %a, -1
br i1 %tobool1, label %if.end3, label %if.then2
if.then2:
%recurse1 = tail call i32 @test7_multiple_accumulators(i32 %sub)
%accumulate1 = add nsw i32 %recurse1, 1
br label %return
if.end3:
%recurse2 = tail call i32 @test7_multiple_accumulators(i32 %sub)
%accumulate2 = mul nsw i32 %recurse2, 2
br label %return
return:
%retval.0 = phi i32 [ %accumulate1, %if.then2 ], [ %accumulate2, %if.end3 ], [ 0, %entry ]
ret i32 %retval.0
}
; CHECK-LABEL: define i32 @test7_multiple_accumulators(
; CHECK: tailrecurse:
; CHECK: %accumulator.tr = phi i32 [ 0, %entry ], [ %accumulate1, %if.then2 ]
; CHECK: if.then2:
; CHECK-NOT: %recurse1
; CHECK: %accumulate1 = add nsw i32 %accumulator.tr, 1
; CHECK: if.end3:
; CHECK: %recurse2
; CHECK: %accumulator.ret.tr = add nsw i32 %accumulator.tr, %accumulate2
; CHECK: ret i32 %accumulator.ret.tr
; CHECK: return:
; CHECK: %accumulator.ret.tr1 = add nsw i32 %accumulator.tr, 0
; CHECK: ret i32 %accumulator.ret.tr1