[Polly] Remove uses of type-less CreateLoad() APIs (NFC)

These are incompatible with opaque pointers and are going away.
Explicitly specify the loaded type instead.
This commit is contained in:
Nikita Popov 2021-03-11 17:01:48 +01:00
parent 0890b39ee9
commit ff9b37e95f
5 changed files with 46 additions and 45 deletions

View File

@ -687,8 +687,6 @@ private:
Value *getVectorValue(ScopStmt &Stmt, Value *Old, ValueMapT &VectorMap, Value *getVectorValue(ScopStmt &Stmt, Value *Old, ValueMapT &VectorMap,
VectorValueMapT &ScalarMaps, Loop *L); VectorValueMapT &ScalarMaps, Loop *L);
Type *getVectorPtrTy(const Value *V, int Width);
/// Load a vector from a set of adjacent scalars /// Load a vector from a set of adjacent scalars
/// ///
/// In case a set of scalars is known to be next to each other in memory, /// In case a set of scalars is known to be next to each other in memory,

View File

@ -171,8 +171,10 @@ public:
/// @param Expr The ast expression of type isl_ast_op_access /// @param Expr The ast expression of type isl_ast_op_access
/// for which we generate LLVM-IR. /// for which we generate LLVM-IR.
/// ///
/// @return The llvm::Value* containing the result of the computation. /// @return A pair of the llvm::Value* containing the result of the
llvm::Value *createAccessAddress(__isl_take isl_ast_expr *Expr); /// computation and the llvm::Type* it points to.
std::pair<llvm::Value *, llvm::Type *>
createAccessAddress(__isl_take isl_ast_expr *Expr);
/// Check if an @p Expr contains integer constants larger than 64 bit. /// Check if an @p Expr contains integer constants larger than 64 bit.
/// ///

View File

@ -316,8 +316,9 @@ Value *BlockGenerator::generateArrayLoad(ScopStmt &Stmt, LoadInst *Load,
Value *NewPointer = Value *NewPointer =
generateLocationAccessed(Stmt, Load, BBMap, LTS, NewAccesses); generateLocationAccessed(Stmt, Load, BBMap, LTS, NewAccesses);
Value *ScalarLoad = Builder.CreateAlignedLoad(NewPointer, Load->getAlign(), Value *ScalarLoad =
Load->getName() + "_p_scalar_"); Builder.CreateAlignedLoad(Load->getType(), NewPointer, Load->getAlign(),
Load->getName() + "_p_scalar_");
if (PollyDebugPrinting) if (PollyDebugPrinting)
RuntimeDebugBuilder::createCPUPrinter(Builder, "Load from ", NewPointer, RuntimeDebugBuilder::createCPUPrinter(Builder, "Load from ", NewPointer,
@ -575,8 +576,8 @@ void BlockGenerator::generateScalarLoads(
DT.dominates(cast<Instruction>(Address)->getParent(), DT.dominates(cast<Instruction>(Address)->getParent(),
Builder.GetInsertBlock())) && Builder.GetInsertBlock())) &&
"Domination violation"); "Domination violation");
BBMap[MA->getAccessValue()] = BBMap[MA->getAccessValue()] = Builder.CreateLoad(
Builder.CreateLoad(Address, Address->getName() + ".reload"); MA->getElementType(), Address, Address->getName() + ".reload");
} }
} }
@ -875,11 +876,12 @@ void BlockGenerator::createScalarFinalization(Scop &S) {
Instruction *EscapeInst = EscapeMapping.first; Instruction *EscapeInst = EscapeMapping.first;
const auto &EscapeMappingValue = EscapeMapping.second; const auto &EscapeMappingValue = EscapeMapping.second;
const EscapeUserVectorTy &EscapeUsers = EscapeMappingValue.second; const EscapeUserVectorTy &EscapeUsers = EscapeMappingValue.second;
Value *ScalarAddr = EscapeMappingValue.first; auto *ScalarAddr = cast<AllocaInst>(&*EscapeMappingValue.first);
// Reload the demoted instruction in the optimized version of the SCoP. // Reload the demoted instruction in the optimized version of the SCoP.
Value *EscapeInstReload = Value *EscapeInstReload =
Builder.CreateLoad(ScalarAddr, EscapeInst->getName() + ".final_reload"); Builder.CreateLoad(ScalarAddr->getAllocatedType(), ScalarAddr,
EscapeInst->getName() + ".final_reload");
EscapeInstReload = EscapeInstReload =
Builder.CreateBitOrPointerCast(EscapeInstReload, EscapeInst->getType()); Builder.CreateBitOrPointerCast(EscapeInstReload, EscapeInst->getType());
@ -959,7 +961,8 @@ void BlockGenerator::createExitPHINodeMerges(Scop &S) {
std::string Name = PHI->getName().str(); std::string Name = PHI->getName().str();
Value *ScalarAddr = getOrCreateAlloca(SAI); Value *ScalarAddr = getOrCreateAlloca(SAI);
Value *Reload = Builder.CreateLoad(ScalarAddr, Name + ".ph.final_reload"); Value *Reload = Builder.CreateLoad(SAI->getElementType(), ScalarAddr,
Name + ".ph.final_reload");
Reload = Builder.CreateBitOrPointerCast(Reload, PHI->getType()); Reload = Builder.CreateBitOrPointerCast(Reload, PHI->getType());
Value *OriginalValue = PHI->getIncomingValueForBlock(MergeBB); Value *OriginalValue = PHI->getIncomingValueForBlock(MergeBB);
assert((!isa<Instruction>(OriginalValue) || assert((!isa<Instruction>(OriginalValue) ||
@ -1037,30 +1040,21 @@ Value *VectorBlockGenerator::getVectorValue(ScopStmt &Stmt, Value *Old,
return Vector; return Vector;
} }
Type *VectorBlockGenerator::getVectorPtrTy(const Value *Val, int Width) {
auto *PointerTy = cast<PointerType>(Val->getType());
unsigned AddrSpace = PointerTy->getAddressSpace();
Type *ScalarType = PointerTy->getElementType();
auto *FVTy = FixedVectorType::get(ScalarType, Width);
return PointerType::get(FVTy, AddrSpace);
}
Value *VectorBlockGenerator::generateStrideOneLoad( Value *VectorBlockGenerator::generateStrideOneLoad(
ScopStmt &Stmt, LoadInst *Load, VectorValueMapT &ScalarMaps, ScopStmt &Stmt, LoadInst *Load, VectorValueMapT &ScalarMaps,
__isl_keep isl_id_to_ast_expr *NewAccesses, bool NegativeStride = false) { __isl_keep isl_id_to_ast_expr *NewAccesses, bool NegativeStride = false) {
unsigned VectorWidth = getVectorWidth(); unsigned VectorWidth = getVectorWidth();
auto *Pointer = Load->getPointerOperand(); Type *VectorType = FixedVectorType::get(Load->getType(), VectorWidth);
Type *VectorPtrType = getVectorPtrTy(Pointer, VectorWidth); Type *VectorPtrType =
PointerType::get(VectorType, Load->getPointerAddressSpace());
unsigned Offset = NegativeStride ? VectorWidth - 1 : 0; unsigned Offset = NegativeStride ? VectorWidth - 1 : 0;
Value *NewPointer = generateLocationAccessed(Stmt, Load, ScalarMaps[Offset], Value *NewPointer = generateLocationAccessed(Stmt, Load, ScalarMaps[Offset],
VLTS[Offset], NewAccesses); VLTS[Offset], NewAccesses);
Value *VectorPtr = Value *VectorPtr =
Builder.CreateBitCast(NewPointer, VectorPtrType, "vector_ptr"); Builder.CreateBitCast(NewPointer, VectorPtrType, "vector_ptr");
LoadInst *VecLoad = LoadInst *VecLoad = Builder.CreateLoad(VectorType, VectorPtr,
Builder.CreateLoad(VectorPtr, Load->getName() + "_p_vec_full"); Load->getName() + "_p_vec_full");
if (!Aligned) if (!Aligned)
VecLoad->setAlignment(Align(8)); VecLoad->setAlignment(Align(8));
@ -1080,14 +1074,15 @@ Value *VectorBlockGenerator::generateStrideOneLoad(
Value *VectorBlockGenerator::generateStrideZeroLoad( Value *VectorBlockGenerator::generateStrideZeroLoad(
ScopStmt &Stmt, LoadInst *Load, ValueMapT &BBMap, ScopStmt &Stmt, LoadInst *Load, ValueMapT &BBMap,
__isl_keep isl_id_to_ast_expr *NewAccesses) { __isl_keep isl_id_to_ast_expr *NewAccesses) {
auto *Pointer = Load->getPointerOperand(); Type *VectorType = FixedVectorType::get(Load->getType(), 1);
Type *VectorPtrType = getVectorPtrTy(Pointer, 1); Type *VectorPtrType =
PointerType::get(VectorType, Load->getPointerAddressSpace());
Value *NewPointer = Value *NewPointer =
generateLocationAccessed(Stmt, Load, BBMap, VLTS[0], NewAccesses); generateLocationAccessed(Stmt, Load, BBMap, VLTS[0], NewAccesses);
Value *VectorPtr = Builder.CreateBitCast(NewPointer, VectorPtrType, Value *VectorPtr = Builder.CreateBitCast(NewPointer, VectorPtrType,
Load->getName() + "_p_vec_p"); Load->getName() + "_p_vec_p");
LoadInst *ScalarLoad = LoadInst *ScalarLoad = Builder.CreateLoad(VectorType, VectorPtr,
Builder.CreateLoad(VectorPtr, Load->getName() + "_p_splat_one"); Load->getName() + "_p_splat_one");
if (!Aligned) if (!Aligned)
ScalarLoad->setAlignment(Align(8)); ScalarLoad->setAlignment(Align(8));
@ -1104,9 +1099,8 @@ Value *VectorBlockGenerator::generateUnknownStrideLoad(
ScopStmt &Stmt, LoadInst *Load, VectorValueMapT &ScalarMaps, ScopStmt &Stmt, LoadInst *Load, VectorValueMapT &ScalarMaps,
__isl_keep isl_id_to_ast_expr *NewAccesses) { __isl_keep isl_id_to_ast_expr *NewAccesses) {
int VectorWidth = getVectorWidth(); int VectorWidth = getVectorWidth();
auto *Pointer = Load->getPointerOperand(); Type *ElemTy = Load->getType();
auto *FVTy = FixedVectorType::get( auto *FVTy = FixedVectorType::get(ElemTy, VectorWidth);
dyn_cast<PointerType>(Pointer->getType())->getElementType(), VectorWidth);
Value *Vector = UndefValue::get(FVTy); Value *Vector = UndefValue::get(FVTy);
@ -1114,7 +1108,7 @@ Value *VectorBlockGenerator::generateUnknownStrideLoad(
Value *NewPointer = generateLocationAccessed(Stmt, Load, ScalarMaps[i], Value *NewPointer = generateLocationAccessed(Stmt, Load, ScalarMaps[i],
VLTS[i], NewAccesses); VLTS[i], NewAccesses);
Value *ScalarLoad = Value *ScalarLoad =
Builder.CreateLoad(NewPointer, Load->getName() + "_p_scalar_"); Builder.CreateLoad(ElemTy, NewPointer, Load->getName() + "_p_scalar_");
Vector = Builder.CreateInsertElement( Vector = Builder.CreateInsertElement(
Vector, ScalarLoad, Builder.getInt32(i), Load->getName() + "_p_vec_"); Vector, ScalarLoad, Builder.getInt32(i), Load->getName() + "_p_vec_");
} }
@ -1192,7 +1186,6 @@ void VectorBlockGenerator::copyStore(
VectorValueMapT &ScalarMaps, __isl_keep isl_id_to_ast_expr *NewAccesses) { VectorValueMapT &ScalarMaps, __isl_keep isl_id_to_ast_expr *NewAccesses) {
const MemoryAccess &Access = Stmt.getArrayAccessFor(Store); const MemoryAccess &Access = Stmt.getArrayAccessFor(Store);
auto *Pointer = Store->getPointerOperand();
Value *Vector = getVectorValue(Stmt, Store->getValueOperand(), VectorMap, Value *Vector = getVectorValue(Stmt, Store->getValueOperand(), VectorMap,
ScalarMaps, getLoopForStmt(Stmt)); ScalarMaps, getLoopForStmt(Stmt));
@ -1201,7 +1194,10 @@ void VectorBlockGenerator::copyStore(
extractScalarValues(Store, VectorMap, ScalarMaps); extractScalarValues(Store, VectorMap, ScalarMaps);
if (Access.isStrideOne(isl::manage_copy(Schedule))) { if (Access.isStrideOne(isl::manage_copy(Schedule))) {
Type *VectorPtrType = getVectorPtrTy(Pointer, getVectorWidth()); Type *VectorType = FixedVectorType::get(Store->getValueOperand()->getType(),
getVectorWidth());
Type *VectorPtrType =
PointerType::get(VectorType, Store->getPointerAddressSpace());
Value *NewPointer = generateLocationAccessed(Stmt, Store, ScalarMaps[0], Value *NewPointer = generateLocationAccessed(Stmt, Store, ScalarMaps[0],
VLTS[0], NewAccesses); VLTS[0], NewAccesses);
@ -1339,10 +1335,13 @@ void VectorBlockGenerator::generateScalarVectorLoads(
continue; continue;
auto *Address = getOrCreateAlloca(*MA); auto *Address = getOrCreateAlloca(*MA);
Type *VectorPtrType = getVectorPtrTy(Address, 1); Type *VectorType = FixedVectorType::get(MA->getElementType(), 1);
Type *VectorPtrType = PointerType::get(
VectorType, Address->getType()->getPointerAddressSpace());
Value *VectorPtr = Builder.CreateBitCast(Address, VectorPtrType, Value *VectorPtr = Builder.CreateBitCast(Address, VectorPtrType,
Address->getName() + "_p_vec_p"); Address->getName() + "_p_vec_p");
auto *Val = Builder.CreateLoad(VectorPtr, Address->getName() + ".reload"); auto *Val = Builder.CreateLoad(VectorType, VectorPtr,
Address->getName() + ".reload");
Constant *SplatVector = Constant::getNullValue( Constant *SplatVector = Constant::getNullValue(
FixedVectorType::get(Builder.getInt32Ty(), getVectorWidth())); FixedVectorType::get(Builder.getInt32Ty(), getVectorWidth()));

View File

@ -231,7 +231,8 @@ Value *IslExprBuilder::createOpNAry(__isl_take isl_ast_expr *Expr) {
return V; return V;
} }
Value *IslExprBuilder::createAccessAddress(isl_ast_expr *Expr) { std::pair<Value *, Type *>
IslExprBuilder::createAccessAddress(isl_ast_expr *Expr) {
assert(isl_ast_expr_get_type(Expr) == isl_ast_expr_op && assert(isl_ast_expr_get_type(Expr) == isl_ast_expr_op &&
"isl ast expression not of type isl_ast_op"); "isl ast expression not of type isl_ast_op");
assert(isl_ast_expr_get_op_type(Expr) == isl_ast_op_access && assert(isl_ast_expr_get_op_type(Expr) == isl_ast_op_access &&
@ -281,7 +282,7 @@ Value *IslExprBuilder::createAccessAddress(isl_ast_expr *Expr) {
isl_ast_expr_free(Expr); isl_ast_expr_free(Expr);
if (PollyDebugPrinting) if (PollyDebugPrinting)
RuntimeDebugBuilder::createCPUPrinter(Builder, "\n"); RuntimeDebugBuilder::createCPUPrinter(Builder, "\n");
return Base; return {Base, SAI->getElementType()};
} }
IndexOp = nullptr; IndexOp = nullptr;
@ -338,13 +339,14 @@ Value *IslExprBuilder::createAccessAddress(isl_ast_expr *Expr) {
if (PollyDebugPrinting) if (PollyDebugPrinting)
RuntimeDebugBuilder::createCPUPrinter(Builder, "\n"); RuntimeDebugBuilder::createCPUPrinter(Builder, "\n");
isl_ast_expr_free(Expr); isl_ast_expr_free(Expr);
return Access; return {Access, SAI->getElementType()};
} }
Value *IslExprBuilder::createOpAccess(isl_ast_expr *Expr) { Value *IslExprBuilder::createOpAccess(isl_ast_expr *Expr) {
Value *Addr = createAccessAddress(Expr); auto Info = createAccessAddress(Expr);
assert(Addr && "Could not create op access address"); assert(Info.first && "Could not create op access address");
return Builder.CreateLoad(Addr, Addr->getName() + ".load"); return Builder.CreateLoad(Info.second, Info.first,
Info.first->getName() + ".load");
} }
Value *IslExprBuilder::createOpBin(__isl_take isl_ast_expr *Expr) { Value *IslExprBuilder::createOpBin(__isl_take isl_ast_expr *Expr) {
@ -704,7 +706,7 @@ Value *IslExprBuilder::createOpAddressOf(__isl_take isl_ast_expr *Expr) {
assert(isl_ast_expr_get_op_type(Op) == isl_ast_op_access && assert(isl_ast_expr_get_op_type(Op) == isl_ast_op_access &&
"Expected address of operator to be an access expression."); "Expected address of operator to be an access expression.");
Value *V = createAccessAddress(Op); Value *V = createAccessAddress(Op).first;
isl_ast_expr_free(Expr); isl_ast_expr_free(Expr);

View File

@ -952,7 +952,7 @@ void IslNodeBuilder::generateCopyStmt(
auto *LoadValue = ExprBuilder.create(AccessExpr); auto *LoadValue = ExprBuilder.create(AccessExpr);
AccessExpr = AccessExpr =
isl_id_to_ast_expr_get(NewAccesses, (*WriteAccess)->getId().release()); isl_id_to_ast_expr_get(NewAccesses, (*WriteAccess)->getId().release());
auto *StoreAddr = ExprBuilder.createAccessAddress(AccessExpr); auto *StoreAddr = ExprBuilder.createAccessAddress(AccessExpr).first;
Builder.CreateStore(LoadValue, StoreAddr); Builder.CreateStore(LoadValue, StoreAddr);
} }