[ARM] Push gather/scatter shl index updates out of loops

This teaches the MVE gather/scatter lowering pass that Shl is
essentially the same as Mul: the induction of a gather/scatter
address can be optimized by pushing the shift, like the multiply,
out of the loop.
https://alive2.llvm.org/ce/z/wG4VyT

Differential Revision: https://reviews.llvm.org/D112920
David Green 2021-11-03 11:00:05 +00:00
parent 52615df0f2
commit d36dd1f842
3 changed files with 39 additions and 36 deletions
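For context, the rewrite relies on a left shift by a loop-invariant amount distributing over the add that advances the index, so both the start value and the per-round increment can be shifted once outside the loop and the loop body keeps only the add. A rough LLVM IR sketch of the offset induction before and after the transform (block and value names such as %index, %offs and %step are illustrative, not taken from the pass or the tests):

; Before: the shl is recomputed from the index phi on every iteration.
vector.body:
  %index = phi <4 x i32> [ <i32 0, i32 1, i32 2, i32 3>, %vector.ph ], [ %index.next, %vector.body ]
  %offs = shl <4 x i32> %index, <i32 2, i32 2, i32 2, i32 2>
  ; ... masked gather/scatter addressed through %offs ...
  %index.next = add <4 x i32> %index, <i32 4, i32 4, i32 4, i32 4>

; After: the shl of the start value and of the per-round increment is hoisted
; into the preheader (where it constant-folds here), and the phi carries the
; already-shifted offsets.
vector.ph:
  %start = shl <4 x i32> <i32 0, i32 1, i32 2, i32 3>, <i32 2, i32 2, i32 2, i32 2>
  %step = shl <4 x i32> <i32 4, i32 4, i32 4, i32 4>, <i32 2, i32 2, i32 2, i32 2>
  br label %vector.body

vector.body:
  %offs = phi <4 x i32> [ %start, %vector.ph ], [ %offs.next, %vector.body ]
  ; ... masked gather/scatter addressed through %offs ...
  %offs.next = add <4 x i32> %offs, %step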

View File

@@ -149,10 +149,10 @@ private:
   bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
   // Pushes the given add out of the loop
   void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
-  // Pushes the given mul out of the loop
-  void pushOutMul(PHINode *&Phi, Value *IncrementPerRound,
-                  Value *OffsSecondOperand, unsigned LoopIncrement,
-                  IRBuilder<> &Builder);
+  // Pushes the given mul or shl out of the loop
+  void pushOutMulShl(unsigned Opc, PHINode *&Phi, Value *IncrementPerRound,
+                     Value *OffsSecondOperand, unsigned LoopIncrement,
+                     IRBuilder<> &Builder);
 };

 } // end anonymous namespace
@@ -342,7 +342,8 @@ Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {

   const Instruction *I = cast<Instruction>(V);
   if (I->getOpcode() == Instruction::Add ||
-      I->getOpcode() == Instruction::Mul) {
+      I->getOpcode() == Instruction::Mul ||
+      I->getOpcode() == Instruction::Shl) {
     Optional<int64_t> Op0 = getIfConst(I->getOperand(0));
     Optional<int64_t> Op1 = getIfConst(I->getOperand(1));
     if (!Op0 || !Op1)
@@ -351,6 +352,8 @@ Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
       return Optional<int64_t>{Op0.getValue() + Op1.getValue()};
     if (I->getOpcode() == Instruction::Mul)
       return Optional<int64_t>{Op0.getValue() * Op1.getValue()};
+    if (I->getOpcode() == Instruction::Shl)
+      return Optional<int64_t>{Op0.getValue() << Op1.getValue()};
   }
   return Optional<int64_t>{};
 }
@@ -888,11 +891,11 @@ void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
   Phi->removeIncomingValue(StartIndex);
 }

-void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi,
-                                          Value *IncrementPerRound,
-                                          Value *OffsSecondOperand,
-                                          unsigned LoopIncrement,
-                                          IRBuilder<> &Builder) {
+void MVEGatherScatterLowering::pushOutMulShl(unsigned Opcode, PHINode *&Phi,
+                                             Value *IncrementPerRound,
+                                             Value *OffsSecondOperand,
+                                             unsigned LoopIncrement,
+                                             IRBuilder<> &Builder) {
   LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");

   // Create a new scalar add outside of the loop and transform it to a splat
@@ -901,12 +904,13 @@ void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi,
       Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back());

   // Create a new index
-  Value *StartIndex = BinaryOperator::Create(
-      Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
-      OffsSecondOperand, "PushedOutMul", InsertionPoint);
+  Value *StartIndex =
+      BinaryOperator::Create((Instruction::BinaryOps)Opcode,
+                             Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
+                             OffsSecondOperand, "PushedOutMul", InsertionPoint);

   Instruction *Product =
-      BinaryOperator::Create(Instruction::Mul, IncrementPerRound,
+      BinaryOperator::Create((Instruction::BinaryOps)Opcode, IncrementPerRound,
                              OffsSecondOperand, "Product", InsertionPoint);
   // Increment NewIndex by Product instead of the multiplication
   Instruction *NewIncrement = BinaryOperator::Create(
@@ -936,7 +940,8 @@ static bool hasAllGatScatUsers(Instruction *I) {
       return Gatscat;
     } else {
       unsigned OpCode = cast<Instruction>(U)->getOpcode();
-      if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) &&
+      if ((OpCode == Instruction::Add || OpCode == Instruction::Mul ||
+           OpCode == Instruction::Shl) &&
           hasAllGatScatUsers(cast<Instruction>(U))) {
         continue;
       }
@@ -956,7 +961,8 @@ bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
     return false;
   Instruction *Offs = cast<Instruction>(Offsets);
   if (Offs->getOpcode() != Instruction::Add &&
-      Offs->getOpcode() != Instruction::Mul)
+      Offs->getOpcode() != Instruction::Mul &&
+      Offs->getOpcode() != Instruction::Shl)
     return false;
   Loop *L = LI->getLoopFor(BB);
   if (L == nullptr)
@@ -1063,8 +1069,9 @@ bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
     pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
     break;
   case Instruction::Mul:
-    pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock,
-               Builder);
+  case Instruction::Shl:
+    pushOutMulShl(Offs->getOpcode(), NewPhi, IncrementPerRound,
+                  OffsSecondOperand, IncrementingBlock, Builder);
     break;
   default:
     return false;

View File

@@ -1410,24 +1410,22 @@ define void @shl(i32* nocapture %x, i32* noalias nocapture readonly %y, i32 %n)
 ; CHECK-NEXT:  .LBB15_1: @ %vector.ph
 ; CHECK-NEXT:    adr r3, .LCPI15_0
 ; CHECK-NEXT:    vldrw.u32 q0, [r3]
-; CHECK-NEXT:    vmov.i32 q1, #0x4
+; CHECK-NEXT:    vadd.i32 q0, q0, r1
 ; CHECK-NEXT:    dlstp.32 lr, r2
 ; CHECK-NEXT:  .LBB15_2: @ %vector.body
 ; CHECK-NEXT:    @ =>This Inner Loop Header: Depth=1
-; CHECK-NEXT:    vshl.i32 q2, q0, #2
-; CHECK-NEXT:    vadd.i32 q0, q0, q1
-; CHECK-NEXT:    vldrw.u32 q3, [r1, q2, uxtw #2]
-; CHECK-NEXT:    vstrw.32 q3, [r0], #16
+; CHECK-NEXT:    vldrw.u32 q1, [q0, #64]!
+; CHECK-NEXT:    vstrw.32 q1, [r0], #16
 ; CHECK-NEXT:    letp lr, .LBB15_2
 ; CHECK-NEXT:  @ %bb.3: @ %for.cond.cleanup
 ; CHECK-NEXT:    pop {r7, pc}
 ; CHECK-NEXT:    .p2align 4
 ; CHECK-NEXT:  @ %bb.4:
 ; CHECK-NEXT:  .LCPI15_0:
-; CHECK-NEXT:    .long 0 @ 0x0
-; CHECK-NEXT:    .long 1 @ 0x1
-; CHECK-NEXT:    .long 2 @ 0x2
-; CHECK-NEXT:    .long 3 @ 0x3
+; CHECK-NEXT:    .long 4294967232 @ 0xffffffc0
+; CHECK-NEXT:    .long 4294967248 @ 0xffffffd0
+; CHECK-NEXT:    .long 4294967264 @ 0xffffffe0
+; CHECK-NEXT:    .long 4294967280 @ 0xfffffff0
 entry:
   %cmp6 = icmp sgt i32 %n, 0
   br i1 %cmp6, label %vector.ph, label %for.cond.cleanup

View File

@@ -236,24 +236,22 @@ define void @shl(i32* nocapture readonly %x, i32* noalias nocapture %y, i32 %n)
 ; CHECK-NEXT:  .LBB4_1: @ %vector.ph
 ; CHECK-NEXT:    adr r3, .LCPI4_0
 ; CHECK-NEXT:    vldrw.u32 q0, [r3]
-; CHECK-NEXT:    vmov.i32 q1, #0x4
+; CHECK-NEXT:    vadd.i32 q0, q0, r1
 ; CHECK-NEXT:    dlstp.32 lr, r2
 ; CHECK-NEXT:  .LBB4_2: @ %vector.body
 ; CHECK-NEXT:    @ =>This Inner Loop Header: Depth=1
-; CHECK-NEXT:    vshl.i32 q3, q0, #2
-; CHECK-NEXT:    vadd.i32 q0, q0, q1
-; CHECK-NEXT:    vldrw.u32 q2, [r0], #16
-; CHECK-NEXT:    vstrw.32 q2, [r1, q3, uxtw #2]
+; CHECK-NEXT:    vldrw.u32 q1, [r0], #16
+; CHECK-NEXT:    vstrw.32 q1, [q0, #64]!
 ; CHECK-NEXT:    letp lr, .LBB4_2
 ; CHECK-NEXT:  @ %bb.3: @ %for.cond.cleanup
 ; CHECK-NEXT:    pop {r7, pc}
 ; CHECK-NEXT:    .p2align 4
 ; CHECK-NEXT:  @ %bb.4:
 ; CHECK-NEXT:  .LCPI4_0:
-; CHECK-NEXT:    .long 0 @ 0x0
-; CHECK-NEXT:    .long 1 @ 0x1
-; CHECK-NEXT:    .long 2 @ 0x2
-; CHECK-NEXT:    .long 3 @ 0x3
+; CHECK-NEXT:    .long 4294967232 @ 0xffffffc0
+; CHECK-NEXT:    .long 4294967248 @ 0xffffffd0
+; CHECK-NEXT:    .long 4294967264 @ 0xffffffe0
+; CHECK-NEXT:    .long 4294967280 @ 0xfffffff0
 entry:
   %cmp6 = icmp sgt i32 %n, 0
   br i1 %cmp6, label %vector.ph, label %for.cond.cleanup