//===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// This pass custom lowers llvm.gather and llvm.scatter instructions to
/// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to
/// produce a better final result as we go.
//
//===----------------------------------------------------------------------===//

#include "ARM.h"
#include "ARMBaseInstrInfo.h"
#include "ARMSubtarget.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/InitializePasses.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "arm-mve-gather-scatter-lowering"

cl::opt<bool> EnableMaskedGatherScatters(
    "enable-arm-maskedgatscat", cl::Hidden, cl::init(true),
    cl::desc("Enable the generation of masked gathers and scatters"));

namespace {

class MVEGatherScatterLowering : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit MVEGatherScatterLowering() : FunctionPass(ID) {
    initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "MVE gather/scatter lowering";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

private:
  LoopInfo *LI = nullptr;

  // Check this is a valid gather with correct alignment
  bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
                               Align Alignment);
  // Check whether Ptr is hidden behind a bitcast and look through it
  void lookThroughBitcast(Value *&Ptr);
  // Check for a getelementptr and deduce base and offsets from it, on success
  // returning the base directly and the offsets indirectly using the Offsets
  // argument
  Value *checkGEP(Value *&Offsets, FixedVectorType *Ty, GetElementPtrInst *GEP,
                  IRBuilder<> &Builder);
  // Compute the scale of this gather/scatter instruction
  int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);
  // If the value is a constant, or derived from constants via additions
  // and multiplications, return its numeric value
  Optional<int64_t> getIfConst(const Value *V);
  // If Inst is an add instruction, check whether one summand is a
  // constant. If so, scale this constant and return it together with
  // the other summand.
  std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale);

  Value *lowerGather(IntrinsicInst *I);
  // Create a gather from a base + vector of offsets
  Value *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
                                     Instruction *&Root, IRBuilder<> &Builder);
  // Create a gather from a vector of pointers
  Value *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
                                   IRBuilder<> &Builder, int64_t Increment = 0);
  // Create an incrementing gather from a vector of pointers
  Value *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr,
                                     IRBuilder<> &Builder,
                                     int64_t Increment = 0);

  Value *lowerScatter(IntrinsicInst *I);
  // Create a scatter to a base + vector of offsets
  Value *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets,
                                      IRBuilder<> &Builder);
  // Create a scatter to a vector of pointers
  Value *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
                                    IRBuilder<> &Builder,
                                    int64_t Increment = 0);
  // Create an incrementing scatter from a vector of pointers
  Value *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr,
                                      IRBuilder<> &Builder,
                                      int64_t Increment = 0);

  // QI gathers and scatters can increment their offsets on their own if
  // the increment is a constant value (immediate)
  Value *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *BasePtr,
                                      Value *Ptr, GetElementPtrInst *GEP,
                                      IRBuilder<> &Builder);
  // QI gathers/scatters can increment their offsets on their own if the
  // increment is a constant value (immediate) - this creates a writeback QI
  // gather/scatter
  Value *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr,
                                        Value *Ptr, unsigned TypeScale,
                                        IRBuilder<> &Builder);

  // Optimise the base and offsets of the given address
  bool optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI);
  // Try to fold consecutive geps together into one
  Value *foldGEP(GetElementPtrInst *GEP, Value *&Offsets, IRBuilder<> &Builder);
  // Check whether these offsets could be moved out of the loop they're in
  bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
  // Pushes the given add out of the loop
  void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
  // Pushes the given mul out of the loop
  void pushOutMul(PHINode *&Phi, Value *IncrementPerRound,
                  Value *OffsSecondOperand, unsigned LoopIncrement,
                  IRBuilder<> &Builder);
};

} // end anonymous namespace

char MVEGatherScatterLowering::ID = 0;

INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
                "MVE gather/scatter lowering pass", false, false)

Pass *llvm::createMVEGatherScatterLoweringPass() {
  return new MVEGatherScatterLowering();
}

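// The MVE gather/scatter instructions only operate on 128-bit vectors, so the
// legal element-count/element-size combinations below are v4i32, v8i16 and
// v16i8, plus the narrower element types (v4i16, v4i8, v8i8) that are handled
// elsewhere in this file as extending or truncating accesses.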
bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
                                                       unsigned ElemSize,
                                                       Align Alignment) {
  if (((NumElements == 4 &&
        (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 16 && ElemSize == 8)) &&
      Alignment >= ElemSize / 8)
    return true;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
                    << "valid alignment or vector type \n");
  return false;
}

static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount) {
  // Offsets that are not of type <N x i32> are sign extended by the
  // getelementptr instruction, and MVE gathers/scatters treat the offset as
  // unsigned. Thus, if the element size is smaller than 32, we can only allow
  // positive offsets - i.e., the offsets are not allowed to be variables we
  // can't look into.
  // Additionally, <N x i32> offsets have to either originate from a zext of a
  // vector with element types smaller or equal the type of the gather we're
  // looking at, or consist of constants that we can check are small enough
  // to fit into the gather type.
  // Thus we check that 0 <= value < 2^TargetElemSize.
  unsigned TargetElemSize = 128 / TargetElemCount;
  unsigned OffsetElemSize = cast<FixedVectorType>(Offsets->getType())
                                ->getElementType()
                                ->getScalarSizeInBits();
  if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) {
    Constant *ConstOff = dyn_cast<Constant>(Offsets);
    if (!ConstOff)
      return false;
    int64_t TargetElemMaxSize = (1ULL << TargetElemSize);
    auto CheckValueSize = [TargetElemMaxSize](Value *OffsetElem) {
      ConstantInt *OConst = dyn_cast<ConstantInt>(OffsetElem);
      if (!OConst)
        return false;
      int SExtValue = OConst->getSExtValue();
      if (SExtValue >= TargetElemMaxSize || SExtValue < 0)
        return false;
      return true;
    };
    if (isa<FixedVectorType>(ConstOff->getType())) {
      for (unsigned i = 0; i < TargetElemCount; i++) {
        if (!CheckValueSize(ConstOff->getAggregateElement(i)))
          return false;
      }
    } else {
      if (!CheckValueSize(ConstOff))
        return false;
    }
  }
  return true;
}

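// On success the returned value is the scalar base pointer of the gep, and
// Offsets is rewritten to an offset vector of the shape the MVE intrinsics
// expect; on failure nullptr is returned and the caller falls back to the
// other lowering strategies.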
Value *MVEGatherScatterLowering::checkGEP(Value *&Offsets, FixedVectorType *Ty,
                                          GetElementPtrInst *GEP,
                                          IRBuilder<> &Builder) {
  if (!GEP) {
    LLVM_DEBUG(
        dbgs() << "masked gathers/scatters: no getelementpointer found\n");
    return nullptr;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
                    << " Looking at intrinsic for base + vector of offsets\n");
  Value *GEPPtr = GEP->getPointerOperand();
  Offsets = GEP->getOperand(1);
  if (GEPPtr->getType()->isVectorTy() ||
      !isa<FixedVectorType>(Offsets->getType()))
    return nullptr;

  if (GEP->getNumOperands() != 2) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
                      << " operands. Expanding.\n");
    return nullptr;
  }
  Offsets = GEP->getOperand(1);
  unsigned OffsetsElemCount =
      cast<FixedVectorType>(Offsets->getType())->getNumElements();
  // Paranoid check whether the number of parallel lanes is the same
  assert(Ty->getNumElements() == OffsetsElemCount);

  ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets);
  if (ZextOffs)
    Offsets = ZextOffs->getOperand(0);
  FixedVectorType *OffsetType = cast<FixedVectorType>(Offsets->getType());

  // If the offsets are already being zext-ed to <N x i32>, that relieves us of
  // having to make sure that they won't overflow.
  if (!ZextOffs || cast<FixedVectorType>(ZextOffs->getDestTy())
                           ->getElementType()
                           ->getScalarSizeInBits() != 32)
    if (!checkOffsetSize(Offsets, OffsetsElemCount))
      return nullptr;

  // The offset sizes have been checked; if any truncating or zext-ing is
  // required to fix them, do that now
  if (Ty != Offsets->getType()) {
    if ((Ty->getElementType()->getScalarSizeInBits() <
         OffsetType->getElementType()->getScalarSizeInBits())) {
      Offsets = Builder.CreateTrunc(Offsets, Ty);
    } else {
      Offsets = Builder.CreateZExt(Offsets, VectorType::getInteger(Ty));
    }
  }
  // If none of the checks failed, return the gep's base pointer
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
  return GEPPtr;
}

void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
  // Look through bitcast instruction if #elements is the same
  if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
    auto *BCTy = cast<FixedVectorType>(BitCast->getType());
    auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType());
    if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
      LLVM_DEBUG(
          dbgs() << "masked gathers/scatters: looking through bitcast\n");
      Ptr = BitCast->getOperand(0);
    }
  }
}

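// The return value is the shift amount that converts an element index into a
// byte offset for the scaled addressing mode (2 for 32-bit, 1 for 16-bit and
// 0 for 8-bit accesses), or -1 if the combination cannot be encoded.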
int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
                                           unsigned MemoryElemSize) {
  // This can be a 32-bit load/store scaled by 4, a 16-bit load/store scaled by
  // 2, or an 8-bit, 16-bit or 32-bit load/store scaled by 1
  if (GEPElemSize == 32 && MemoryElemSize == 32)
    return 2;
  else if (GEPElemSize == 16 && MemoryElemSize == 16)
    return 1;
  else if (GEPElemSize == 8)
    return 0;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't "
                    << "create intrinsic\n");
  return -1;
}

Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
  const Constant *C = dyn_cast<Constant>(V);
  if (C != nullptr)
    return Optional<int64_t>{C->getUniqueInteger().getSExtValue()};
  if (!isa<Instruction>(V))
    return Optional<int64_t>{};

  const Instruction *I = cast<Instruction>(V);
  if (I->getOpcode() == Instruction::Add ||
      I->getOpcode() == Instruction::Mul) {
    Optional<int64_t> Op0 = getIfConst(I->getOperand(0));
    Optional<int64_t> Op1 = getIfConst(I->getOperand(1));
    if (!Op0 || !Op1)
      return Optional<int64_t>{};
    if (I->getOpcode() == Instruction::Add)
      return Optional<int64_t>{Op0.getValue() + Op1.getValue()};
    if (I->getOpcode() == Instruction::Mul)
      return Optional<int64_t>{Op0.getValue() * Op1.getValue()};
  }
  return Optional<int64_t>{};
}

std::pair<Value *, int64_t>
MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) {
  std::pair<Value *, int64_t> ReturnFalse =
      std::pair<Value *, int64_t>(nullptr, 0);
  // At this point, the instruction we're looking at must be an add or we
  // bail out
  Instruction *Add = dyn_cast<Instruction>(Inst);
  if (Add == nullptr || Add->getOpcode() != Instruction::Add)
    return ReturnFalse;

  Value *Summand;
  Optional<int64_t> Const;
  // Find out which operand is the non-constant summand
  if ((Const = getIfConst(Add->getOperand(0))))
    Summand = Add->getOperand(1);
  else if ((Const = getIfConst(Add->getOperand(1))))
    Summand = Add->getOperand(0);
  else
    return ReturnFalse;

  // Check that the constant is small enough for an incrementing gather
  int64_t Immediate = Const.getValue() << TypeScale;
  if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
    return ReturnFalse;

  return std::pair<Value *, int64_t>(Summand, Immediate);
}

Value *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n");

  // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
  // Attempt to turn the masked gather in I into a MVE intrinsic
  // Potentially optimising the addressing modes as we do so.
  auto *Ty = cast<FixedVectorType>(I->getType());
  Value *Ptr = I->getArgOperand(0);
  Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue();
  Value *Mask = I->getArgOperand(2);
  Value *PassThru = I->getArgOperand(3);

  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;
  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  Instruction *Root = I;
  Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
  if (!Load)
    Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
  if (!Load)
    return nullptr;

  if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
    LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
                      << "creating select\n");
    Load = Builder.CreateSelect(Mask, Load, PassThru);
  }

  Root->replaceAllUsesWith(Load);
  Root->eraseFromParent();
  if (Root != I)
    // If this was an extending gather, we need to get rid of the sext/zext
    // as well as of the gather itself
    I->eraseFromParent();

  LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n");
  return Load;
}

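// arm_mve_vldr_gather_base(_predicated) takes the vector of 32-bit base
// addresses plus an immediate byte offset that is applied to every lane, so
// it is only usable when the result type is a full v4i32.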
Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase(IntrinsicInst *I,
                                                           Value *Ptr,
                                                           IRBuilder<> &Builder,
                                                           int64_t Increment) {
  using namespace PatternMatch;
  auto *Ty = cast<FixedVectorType>(I->getType());
  LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(2);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
                                   {Ty, Ptr->getType()},
                                   {Ptr, Builder.getInt32(Increment)});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_base_predicated,
        {Ty, Ptr->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Mask});
}

Value *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  auto *Ty = cast<FixedVectorType>(I->getType());
  LLVM_DEBUG(
      dbgs()
      << "masked gathers: loading from vector of pointers with writeback\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(2);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
                                   {Ty, Ptr->getType()},
                                   {Ptr, Builder.getInt32(Increment)});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
        {Ty, Ptr->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Mask});
}

Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
    IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
  using namespace PatternMatch;

  Type *OriginalTy = I->getType();
  Type *ResultTy = OriginalTy;

  unsigned Unsigned = 1;
  // The size of the gather was already checked in isLegalTypeAndAlignment;
  // if it was not a full vector width an appropriate extend should follow.
  auto *Extend = Root;
  if (OriginalTy->getPrimitiveSizeInBits() < 128) {
    // Only transform gathers with exactly one use
    if (!I->hasOneUse())
      return nullptr;

    // The correct root to replace is not the CallInst itself, but the
    // instruction which extends it
    Extend = cast<Instruction>(*I->users().begin());
    if (isa<SExtInst>(Extend)) {
      Unsigned = 0;
    } else if (!isa<ZExtInst>(Extend)) {
      LLVM_DEBUG(dbgs() << "masked gathers: extend needed but not provided. "
                        << "Expanding\n");
      return nullptr;
    }
    LLVM_DEBUG(dbgs() << "masked gathers: found an extending gather\n");
    ResultTy = Extend->getType();
    // The final size of the gather must be a full vector width
    if (ResultTy->getPrimitiveSizeInBits() != 128) {
      LLVM_DEBUG(dbgs() << "masked gathers: extending from the wrong type. "
                        << "Expanding\n");
      return nullptr;
    }
  }

  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  Value *Offsets;
  Value *BasePtr =
      checkGEP(Offsets, cast<FixedVectorType>(ResultTy), GEP, Builder);
  if (!BasePtr)
    return nullptr;
  // Check whether the offset is a constant increment that could be merged into
  // a QI gather
  Value *Load = tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder);
  if (Load)
    return Load;

  int Scale = computeScale(
      BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(),
      OriginalTy->getScalarSizeInBits());
  if (Scale == -1)
    return nullptr;
  Root = Extend;

  Value *Mask = I->getArgOperand(2);
  if (!match(Mask, m_One()))
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset_predicated,
        {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()},
        {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset,
        {ResultTy, BasePtr->getType(), Offsets->getType()},
        {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned)});
}

Value *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n");

  // @llvm.masked.scatter.*(data, ptrs, alignment, mask)
  // Attempt to turn the masked scatter in I into a MVE intrinsic
  // Potentially optimising the addressing modes as we do so.
  Value *Input = I->getArgOperand(0);
  Value *Ptr = I->getArgOperand(1);
  Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue();
  auto *Ty = cast<FixedVectorType>(Input->getType());

  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;

  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  Value *Store = tryCreateMaskedScatterOffset(I, Ptr, Builder);
  if (!Store)
    Store = tryCreateMaskedScatterBase(I, Ptr, Builder);
  if (!Store)
    return nullptr;

  LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n");
  I->eraseFromParent();
  return Store;
}

Value *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  auto *Ty = cast<FixedVectorType>(Input->getType());
  // Only QR variants allow truncating
  if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) {
    // Can't build an intrinsic for this
    return nullptr;
  }
  Value *Mask = I->getArgOperand(3);
  //  int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask)
  LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n");
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
                                   {Ptr->getType(), Input->getType()},
                                   {Ptr, Builder.getInt32(Increment), Input});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_base_predicated,
        {Ptr->getType(), Input->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Input, Mask});
}

Value *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  auto *Ty = cast<FixedVectorType>(Input->getType());
  LLVM_DEBUG(
      dbgs()
      << "masked scatters: storing to a vector of pointers with writeback\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(3);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb,
                                   {Ptr->getType(), Input->getType()},
                                   {Ptr, Builder.getInt32(Increment), Input});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_base_wb_predicated,
        {Ptr->getType(), Input->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Input, Mask});
}

Value *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  Value *Mask = I->getArgOperand(3);
  Type *InputTy = Input->getType();
  Type *MemoryTy = InputTy;
  LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing"
                    << " to base + vector of offsets\n");
  // If the input has been truncated, try to integrate that trunc into the
  // scatter instruction (we don't care about alignment here)
  if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
    Value *PreTrunc = Trunc->getOperand(0);
    Type *PreTruncTy = PreTrunc->getType();
    if (PreTruncTy->getPrimitiveSizeInBits() == 128) {
      Input = PreTrunc;
      InputTy = PreTruncTy;
    }
  }
  if (InputTy->getPrimitiveSizeInBits() != 128) {
    LLVM_DEBUG(
        dbgs() << "masked scatters: cannot create scatters for non-standard"
               << " input types. Expanding.\n");
    return nullptr;
  }

  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  Value *Offsets;
  Value *BasePtr =
      checkGEP(Offsets, cast<FixedVectorType>(InputTy), GEP, Builder);
  if (!BasePtr)
    return nullptr;
  // Check whether the offset is a constant increment that could be merged into
  // a QI scatter
  Value *Store =
      tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder);
  if (Store)
    return Store;
  int Scale = computeScale(
      BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(),
      MemoryTy->getScalarSizeInBits());
  if (Scale == -1)
    return nullptr;

  if (!match(Mask, m_One()))
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_offset_predicated,
        {BasePtr->getType(), Offsets->getType(), Input->getType(),
         Mask->getType()},
        {BasePtr, Offsets, Input,
         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Mask});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_offset,
        {BasePtr->getType(), Offsets->getType(), Input->getType()},
        {BasePtr, Offsets, Input,
         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale)});
}

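// When the offsets of a gather/scatter inside a loop are an induction
// variable plus a small constant step, the whole address computation can be
// folded into one of the base+immediate (QI) intrinsic forms. This helper
// tries that, preferring the writeback variant when the gep has no other
// users.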
Value *MVEGatherScatterLowering::tryCreateIncrementingGatScat(
    IntrinsicInst *I, Value *BasePtr, Value *Offsets, GetElementPtrInst *GEP,
    IRBuilder<> &Builder) {
  FixedVectorType *Ty;
  if (I->getIntrinsicID() == Intrinsic::masked_gather)
    Ty = cast<FixedVectorType>(I->getType());
  else
    Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType());
  // Incrementing gathers only exist for v4i32
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    return nullptr;
  Loop *L = LI->getLoopFor(I->getParent());
  if (L == nullptr)
    // Incrementing gathers are not beneficial outside of a loop
    return nullptr;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
                       "wb gather/scatter\n");

  // The gep was in charge of making sure the offsets are scaled correctly
  // - calculate that factor so it can be applied by hand
  DataLayout DT = I->getParent()->getParent()->getParent()->getDataLayout();
  int TypeScale =
      computeScale(DT.getTypeSizeInBits(GEP->getOperand(0)->getType()),
                   DT.getTypeSizeInBits(GEP->getType()) /
                       cast<FixedVectorType>(GEP->getType())->getNumElements());
  if (TypeScale == -1)
    return nullptr;

  if (GEP->hasOneUse()) {
    // Only in this case do we want to build a wb gather, because the wb will
    // change the phi which does affect other users of the gep (which will
    // still be using the phi in the old way)
    Value *Load =
        tryCreateIncrementingWBGatScat(I, BasePtr, Offsets, TypeScale, Builder);
    if (Load != nullptr)
      return Load;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
                       "non-wb gather/scatter\n");

  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  if (Add.first == nullptr)
    return nullptr;
  Value *OffsetsIncoming = Add.first;
  int64_t Immediate = Add.second;

  // Make sure the offsets are scaled correctly
  Instruction *ScaledOffsets = BinaryOperator::Create(
      Instruction::Shl, OffsetsIncoming,
      Builder.CreateVectorSplat(Ty->getNumElements(),
                                Builder.getInt32(TypeScale)),
      "ScaledIndex", I);
  // Add the base to the offsets
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Add, ScaledOffsets,
      Builder.CreateVectorSplat(
          Ty->getNumElements(),
          Builder.CreatePtrToInt(
              BasePtr,
              cast<VectorType>(ScaledOffsets->getType())->getElementType())),
      "StartIndex", I);

  if (I->getIntrinsicID() == Intrinsic::masked_gather)
    return cast<IntrinsicInst>(
        tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate));
  else
    return cast<IntrinsicInst>(
        tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate));
}

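// Writeback variant: the loop's offset phi is rewritten so that the
// gather/scatter intrinsic itself produces the incremented offsets for the
// next iteration, removing the separate vector add from the loop body.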
Value *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat(
    IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale,
    IRBuilder<> &Builder) {
  // Check whether this gather's offset is incremented by a constant - if so,
  // and the load is of the right type, we can merge this into a QI gather
  Loop *L = LI->getLoopFor(I->getParent());
  // Offsets that are worth merging into this instruction will be incremented
  // by a constant, thus we're looking for an add of a phi and a constant
  PHINode *Phi = dyn_cast<PHINode>(Offsets);
  if (Phi == nullptr || Phi->getNumIncomingValues() != 2 ||
      Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2)
    // No phi means no IV to write back to; if there is a phi, we expect it
    // to have exactly two incoming values; the only phis we are interested in
    // will be loop IVs and have exactly two uses, one in their increment and
    // one in the gather's gep
    return nullptr;

  unsigned IncrementIndex =
      Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1;
  // Look through the phi to the phi increment
  Offsets = Phi->getIncomingValue(IncrementIndex);

  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  if (Add.first == nullptr)
    return nullptr;
  Value *OffsetsIncoming = Add.first;
  int64_t Immediate = Add.second;
  if (OffsetsIncoming != Phi)
    // Then the increment we are looking at is not an increment of the
    // induction variable, and we don't want to do a writeback
    return nullptr;

  Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back());
  unsigned NumElems =
      cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements();

  // Make sure the offsets are scaled correctly
  Instruction *ScaledOffsets = BinaryOperator::Create(
      Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex),
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)),
      "ScaledIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  // Add the base to the offsets
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Add, ScaledOffsets,
      Builder.CreateVectorSplat(
          NumElems,
          Builder.CreatePtrToInt(
              BasePtr,
              cast<VectorType>(ScaledOffsets->getType())->getElementType())),
      "StartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  // The gather is pre-incrementing
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Sub, OffsetsIncoming,
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)),
      "PreIncrementStartIndex",
      &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming);

  Builder.SetInsertPoint(I);

  Value *EndResult;
  Value *NewInduction;
  if (I->getIntrinsicID() == Intrinsic::masked_gather) {
    // Build the incrementing gather
    Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate);
    // One value to be handed to whoever uses the gather, one is the loop
    // increment
    EndResult = Builder.CreateExtractValue(Load, 0, "Gather");
    NewInduction = Builder.CreateExtractValue(Load, 1, "GatherIncrement");
  } else {
    // Build the incrementing scatter
    NewInduction = tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate);
    EndResult = NewInduction;
  }
  Instruction *AddInst = cast<Instruction>(Offsets);
  AddInst->replaceAllUsesWith(NewInduction);
  AddInst->eraseFromParent();
  Phi->setIncomingValue(IncrementIndex, NewInduction);

  return EndResult;
}

void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
                                          Value *OffsSecondOperand,
                                          unsigned StartIndex) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n");
  Instruction *InsertionPoint =
      &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back());
  // Initialize the phi with a vector that contains a sum of the constants
  Instruction *NewIndex = BinaryOperator::Create(
      Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
      "PushedOutAdd", InsertionPoint);
  unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;

  // Order such that start index comes first (this reduces movs)
  Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex));
  Phi->addIncoming(Phi->getIncomingValue(IncrementIndex),
                   Phi->getIncomingBlock(IncrementIndex));
  Phi->removeIncomingValue(IncrementIndex);
  Phi->removeIncomingValue(StartIndex);
}

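// The mul of the phi by a loop-invariant operand is folded into a new
// induction: the start value becomes StartValue * OffsSecondOperand and the
// per-iteration step becomes IncrementPerRound * OffsSecondOperand, so the
// multiply itself no longer has to be executed inside the loop.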
void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi,
                                          Value *IncrementPerRound,
                                          Value *OffsSecondOperand,
                                          unsigned LoopIncrement,
                                          IRBuilder<> &Builder) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");

  // Create a new scalar add outside of the loop and transform it to a splat
  // by which the loop variable can be incremented
  Instruction *InsertionPoint = &cast<Instruction>(
      Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back());

  // Create a new index
  Value *StartIndex = BinaryOperator::Create(
      Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
      OffsSecondOperand, "PushedOutMul", InsertionPoint);

  Instruction *Product =
      BinaryOperator::Create(Instruction::Mul, IncrementPerRound,
                             OffsSecondOperand, "Product", InsertionPoint);
  // Increment the phi by Product instead of recomputing the multiplication
  Instruction *NewIncrement = BinaryOperator::Create(
      Instruction::Add, Phi, Product, "IncrementPushedOutMul",
      cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back())
          .getPrevNode());

  Phi->addIncoming(StartIndex,
                   Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1));
  Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement));
  Phi->removeIncomingValue((unsigned)0);
  Phi->removeIncomingValue((unsigned)0);
  return;
}

// Check whether all usages of this instruction are as offsets of
// gathers/scatters or simple arithmetics only used by gathers/scatters
static bool hasAllGatScatUsers(Instruction *I) {
  if (I->hasNUses(0)) {
    return false;
  }
  bool Gatscat = true;
  for (User *U : I->users()) {
    if (!isa<Instruction>(U))
      return false;
    if (isa<GetElementPtrInst>(U) ||
        isGatherScatter(dyn_cast<IntrinsicInst>(U))) {
      return Gatscat;
    } else {
      unsigned OpCode = cast<Instruction>(U)->getOpcode();
      if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) &&
          hasAllGatScatUsers(cast<Instruction>(U))) {
        continue;
      }
      return false;
    }
  }
  return Gatscat;
}

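// Offsets is expected to be an add or mul of a loop-header phi and a
// loop-invariant operand; the invariant operand is folded into the phi's
// start value (and, for a mul, into its step) so that the arithmetic is
// hoisted out of the loop before the gather/scatter is lowered.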
bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
                                               LoopInfo *LI) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize\n");
  // Optimise the addresses of gathers/scatters by moving invariant
  // calculations out of the loop
  if (!isa<Instruction>(Offsets))
    return false;
  Instruction *Offs = cast<Instruction>(Offsets);
  if (Offs->getOpcode() != Instruction::Add &&
      Offs->getOpcode() != Instruction::Mul)
    return false;
  Loop *L = LI->getLoopFor(BB);
  if (L == nullptr)
    return false;
  if (!Offs->hasOneUse()) {
    if (!hasAllGatScatUsers(Offs))
      return false;
  }

  // Find out which, if any, operand of the instruction
  // is a phi node
  PHINode *Phi;
  int OffsSecondOp;
  if (isa<PHINode>(Offs->getOperand(0))) {
    Phi = cast<PHINode>(Offs->getOperand(0));
    OffsSecondOp = 1;
  } else if (isa<PHINode>(Offs->getOperand(1))) {
    Phi = cast<PHINode>(Offs->getOperand(1));
    OffsSecondOp = 0;
  } else {
    bool Changed = false;
    if (isa<Instruction>(Offs->getOperand(0)) &&
        L->contains(cast<Instruction>(Offs->getOperand(0))))
      Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI);
    if (isa<Instruction>(Offs->getOperand(1)) &&
        L->contains(cast<Instruction>(Offs->getOperand(1))))
      Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI);
    if (!Changed) {
      return false;
    } else {
      if (isa<PHINode>(Offs->getOperand(0))) {
        Phi = cast<PHINode>(Offs->getOperand(0));
        OffsSecondOp = 1;
      } else if (isa<PHINode>(Offs->getOperand(1))) {
        Phi = cast<PHINode>(Offs->getOperand(1));
        OffsSecondOp = 0;
      } else {
        return false;
      }
    }
  }
  // A phi node we want to perform this function on should be from the
  // loop header, and shouldn't have more than 2 incoming values
  if (Phi->getParent() != L->getHeader() ||
      Phi->getNumIncomingValues() != 2)
    return false;

  // The phi must be an induction variable
  Instruction *Op;
  int IncrementingBlock = -1;

  for (int i = 0; i < 2; i++)
    if ((Op = dyn_cast<Instruction>(Phi->getIncomingValue(i))) != nullptr)
      if (Op->getOpcode() == Instruction::Add &&
          (Op->getOperand(0) == Phi || Op->getOperand(1) == Phi))
        IncrementingBlock = i;
  if (IncrementingBlock == -1)
    return false;

  Instruction *IncInstruction =
      cast<Instruction>(Phi->getIncomingValue(IncrementingBlock));

  // If the phi is not used by anything else, we can just adapt it when
  // replacing the instruction; if it is, we'll have to duplicate it
  PHINode *NewPhi;
  Value *IncrementPerRound = IncInstruction->getOperand(
      (IncInstruction->getOperand(0) == Phi) ? 1 : 0);

  // Get the value that is added to/multiplied with the phi
  Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);

  if (IncrementPerRound->getType() != OffsSecondOperand->getType())
    // Something has gone wrong, abort
    return false;

  // Only proceed if the increment per round is a constant or an instruction
  // which does not originate from within the loop
  if (!isa<Constant>(IncrementPerRound) &&
      !(isa<Instruction>(IncrementPerRound) &&
        !L->contains(cast<Instruction>(IncrementPerRound))))
    return false;

  if (Phi->getNumUses() == 2) {
    // No other users -> reuse existing phi (One user is the instruction
    // we're looking at, the other is the phi increment)
    if (IncInstruction->getNumUses() != 1) {
      // If the incrementing instruction does have more users than
      // our phi, we need to copy it
      IncInstruction = BinaryOperator::Create(
          Instruction::BinaryOps(IncInstruction->getOpcode()), Phi,
          IncrementPerRound, "LoopIncrement", IncInstruction);
      Phi->setIncomingValue(IncrementingBlock, IncInstruction);
    }
    NewPhi = Phi;
  } else {
    // There are other users -> create a new phi
    NewPhi = PHINode::Create(Phi->getType(), 0, "NewPhi", Phi);
    std::vector<Value *> Increases;
    // Copy the incoming values of the old phi
    NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
                        Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
    IncInstruction = BinaryOperator::Create(
        Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi,
        IncrementPerRound, "LoopIncrement", IncInstruction);
    NewPhi->addIncoming(IncInstruction,
                        Phi->getIncomingBlock(IncrementingBlock));
    IncrementingBlock = 1;
  }

  IRBuilder<> Builder(BB->getContext());
  Builder.SetInsertPoint(Phi);
  Builder.SetCurrentDebugLocation(Offs->getDebugLoc());

  switch (Offs->getOpcode()) {
  case Instruction::Add:
    pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
    break;
  case Instruction::Mul:
    pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock,
               Builder);
    break;
  default:
    return false;
  }
  LLVM_DEBUG(
      dbgs() << "masked gathers/scatters: simplified loop variable add/mul\n");

  // The instruction has now been "absorbed" into the phi value
  Offs->replaceAllUsesWith(NewPhi);
  if (Offs->hasNUses(0))
    Offs->eraseFromParent();
  // Clean up the old increment in case it's unused because we built a new
  // one
  if (IncInstruction->hasNUses(0))
    IncInstruction->eraseFromParent();

  return true;
}

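// Add the offsets of two geps that are being merged. A scalar summand is
// splatted to a vector first, and combinations that could overflow the
// narrow (non-i32) offset element type are rejected.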
static Value *CheckAndCreateOffsetAdd(Value *X, Value *Y, Value *GEP,
                                      IRBuilder<> &Builder) {
  // Splat the non-vector value to a vector of the given type - if the value is
  // a constant (and its value isn't too big), we can even use this opportunity
  // to scale it to the size of the vector elements
  auto FixSummands = [&Builder](FixedVectorType *&VT, Value *&NonVectorVal) {
    ConstantInt *Const;
    if ((Const = dyn_cast<ConstantInt>(NonVectorVal)) &&
        VT->getElementType() != NonVectorVal->getType()) {
      unsigned TargetElemSize = VT->getElementType()->getPrimitiveSizeInBits();
      uint64_t N = Const->getZExtValue();
      if (N < (unsigned)(1 << (TargetElemSize - 1))) {
        NonVectorVal = Builder.CreateVectorSplat(
            VT->getNumElements(), Builder.getIntN(TargetElemSize, N));
        return;
      }
    }
    NonVectorVal =
        Builder.CreateVectorSplat(VT->getNumElements(), NonVectorVal);
  };

  FixedVectorType *XElType = dyn_cast<FixedVectorType>(X->getType());
  FixedVectorType *YElType = dyn_cast<FixedVectorType>(Y->getType());
  // If one of X, Y is not a vector, we have to splat it in order
  // to add the two of them.
  if (XElType && !YElType) {
    FixSummands(XElType, Y);
    YElType = cast<FixedVectorType>(Y->getType());
  } else if (YElType && !XElType) {
    FixSummands(YElType, X);
    XElType = cast<FixedVectorType>(X->getType());
  }
  assert(XElType && YElType && "Unknown vector types");
  // Check that the summands are of compatible types
  if (XElType != YElType) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: incompatible gep offsets\n");
    return nullptr;
  }

  if (XElType->getElementType()->getScalarSizeInBits() != 32) {
    // Check that by adding the vectors we do not accidentally
    // create an overflow
    Constant *ConstX = dyn_cast<Constant>(X);
    Constant *ConstY = dyn_cast<Constant>(Y);
    if (!ConstX || !ConstY)
      return nullptr;
    unsigned TargetElemSize = 128 / XElType->getNumElements();
    for (unsigned i = 0; i < XElType->getNumElements(); i++) {
      ConstantInt *ConstXEl =
          dyn_cast<ConstantInt>(ConstX->getAggregateElement(i));
      ConstantInt *ConstYEl =
          dyn_cast<ConstantInt>(ConstY->getAggregateElement(i));
      if (!ConstXEl || !ConstYEl ||
          ConstXEl->getZExtValue() + ConstYEl->getZExtValue() >=
              (unsigned)(1 << (TargetElemSize - 1)))
        return nullptr;
    }
  }

  Value *Add = Builder.CreateAdd(X, Y);

  FixedVectorType *GEPType = cast<FixedVectorType>(GEP->getType());
  if (checkOffsetSize(Add, GEPType->getNumElements()))
    return Add;
  else
    return nullptr;
}

Value *MVEGatherScatterLowering::foldGEP(GetElementPtrInst *GEP,
                                         Value *&Offsets,
                                         IRBuilder<> &Builder) {
  Value *GEPPtr = GEP->getPointerOperand();
  Offsets = GEP->getOperand(1);
  // We only merge geps with constant offsets, because only for those
  // we can make sure that we do not cause an overflow
  if (!isa<Constant>(Offsets))
    return nullptr;
  GetElementPtrInst *BaseGEP;
  if ((BaseGEP = dyn_cast<GetElementPtrInst>(GEPPtr))) {
    // Merge the two geps into one
    Value *BaseBasePtr = foldGEP(BaseGEP, Offsets, Builder);
    if (!BaseBasePtr)
      return nullptr;
    Offsets =
        CheckAndCreateOffsetAdd(Offsets, GEP->getOperand(1), GEP, Builder);
    if (Offsets == nullptr)
      return nullptr;
    return BaseBasePtr;
  }
  return GEPPtr;
}

bool MVEGatherScatterLowering::optimiseAddress(Value *Address, BasicBlock *BB,
                                               LoopInfo *LI) {
  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Address);
  if (!GEP)
    return false;
  bool Changed = false;
  if (GEP->hasOneUse() &&
      dyn_cast<GetElementPtrInst>(GEP->getPointerOperand())) {
    IRBuilder<> Builder(GEP->getContext());
    Builder.SetInsertPoint(GEP);
    Builder.SetCurrentDebugLocation(GEP->getDebugLoc());
    Value *Offsets;
    Value *Base = foldGEP(GEP, Offsets, Builder);
    // We only want to merge the geps if there is a real chance that they can
    // be used by an MVE gather; thus the offset has to have the correct size
    // (always i32 if it is not of vector type) and the base has to be a
    // pointer.
    if (Offsets && Base && Base != GEP) {
      PointerType *BaseType = cast<PointerType>(Base->getType());
      GetElementPtrInst *NewAddress = GetElementPtrInst::Create(
          BaseType->getPointerElementType(), Base, Offsets, "gep.merged", GEP);
      GEP->replaceAllUsesWith(NewAddress);
      GEP = NewAddress;
      Changed = true;
    }
  }
  Changed |= optimiseOffsets(GEP->getOperand(1), GEP->getParent(), LI);
  return Changed;
}

bool MVEGatherScatterLowering::runOnFunction(Function &F) {
  if (!EnableMaskedGatherScatters)
    return false;
  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<TargetMachine>();
  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
  if (!ST->hasMVEIntegerOps())
    return false;
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  SmallVector<IntrinsicInst *, 4> Gathers;
  SmallVector<IntrinsicInst *, 4> Scatters;

  bool Changed = false;

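  // First walk the function: remember every candidate masked gather/scatter
  // and optimise its address computation while the IR is still in its
  // original form, then lower the collected intrinsics below.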
  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
          isa<FixedVectorType>(II->getType())) {
        Gathers.push_back(II);
        Changed |= optimiseAddress(II->getArgOperand(0), II->getParent(), LI);
      } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
                 isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
        Scatters.push_back(II);
        Changed |= optimiseAddress(II->getArgOperand(1), II->getParent(), LI);
      }
    }
  }
  for (unsigned i = 0; i < Gathers.size(); i++) {
    IntrinsicInst *I = Gathers[i];
    Value *L = lowerGather(I);
    if (L == nullptr)
      continue;

    // Get rid of any now dead instructions
    SimplifyInstructionsInBlock(cast<Instruction>(L)->getParent());
    Changed = true;
  }

  for (unsigned i = 0; i < Scatters.size(); i++) {
    IntrinsicInst *I = Scatters[i];
    Value *S = lowerScatter(I);
    if (S == nullptr)
      continue;

    // Get rid of any now dead instructions
    SimplifyInstructionsInBlock(cast<Instruction>(S)->getParent());
    Changed = true;
  }
  return Changed;
}