diff --git a/polly/include/polly/CodeGen/BlockGenerators.h b/polly/include/polly/CodeGen/BlockGenerators.h index f02524a34254..9c9fa60fba64 100644 --- a/polly/include/polly/CodeGen/BlockGenerators.h +++ b/polly/include/polly/CodeGen/BlockGenerators.h @@ -687,8 +687,6 @@ private: Value *getVectorValue(ScopStmt &Stmt, Value *Old, ValueMapT &VectorMap, VectorValueMapT &ScalarMaps, Loop *L); - Type *getVectorPtrTy(const Value *V, int Width); - /// Load a vector from a set of adjacent scalars /// /// In case a set of scalars is known to be next to each other in memory, diff --git a/polly/include/polly/CodeGen/IslExprBuilder.h b/polly/include/polly/CodeGen/IslExprBuilder.h index 998f8f6f7286..057724640525 100644 --- a/polly/include/polly/CodeGen/IslExprBuilder.h +++ b/polly/include/polly/CodeGen/IslExprBuilder.h @@ -171,8 +171,10 @@ public: /// @param Expr The ast expression of type isl_ast_op_access /// for which we generate LLVM-IR. /// - /// @return The llvm::Value* containing the result of the computation. - llvm::Value *createAccessAddress(__isl_take isl_ast_expr *Expr); + /// @return A pair of the llvm::Value* containing the result of the + /// computation and the llvm::Type* it points to. + std::pair<llvm::Value *, llvm::Type *> + createAccessAddress(__isl_take isl_ast_expr *Expr); /// Check if an @p Expr contains integer constants larger than 64 bit. 
/// diff --git a/polly/lib/CodeGen/BlockGenerators.cpp b/polly/lib/CodeGen/BlockGenerators.cpp index bb7c99822592..14e85e0feb23 100644 --- a/polly/lib/CodeGen/BlockGenerators.cpp +++ b/polly/lib/CodeGen/BlockGenerators.cpp @@ -316,8 +316,9 @@ Value *BlockGenerator::generateArrayLoad(ScopStmt &Stmt, LoadInst *Load, Value *NewPointer = generateLocationAccessed(Stmt, Load, BBMap, LTS, NewAccesses); - Value *ScalarLoad = Builder.CreateAlignedLoad(NewPointer, Load->getAlign(), - Load->getName() + "_p_scalar_"); + Value *ScalarLoad = + Builder.CreateAlignedLoad(Load->getType(), NewPointer, Load->getAlign(), + Load->getName() + "_p_scalar_"); if (PollyDebugPrinting) RuntimeDebugBuilder::createCPUPrinter(Builder, "Load from ", NewPointer, @@ -575,8 +576,8 @@ void BlockGenerator::generateScalarLoads( DT.dominates(cast<Instruction>(Address)->getParent(), Builder.GetInsertBlock())) && "Domination violation"); - BBMap[MA->getAccessValue()] = - Builder.CreateLoad(Address, Address->getName() + ".reload"); + BBMap[MA->getAccessValue()] = Builder.CreateLoad( + MA->getElementType(), Address, Address->getName() + ".reload"); } } @@ -875,11 +876,12 @@ void BlockGenerator::createScalarFinalization(Scop &S) { Instruction *EscapeInst = EscapeMapping.first; const auto &EscapeMappingValue = EscapeMapping.second; const EscapeUserVectorTy &EscapeUsers = EscapeMappingValue.second; - Value *ScalarAddr = EscapeMappingValue.first; + auto *ScalarAddr = cast<AllocaInst>(&*EscapeMappingValue.first); // Reload the demoted instruction in the optimized version of the SCoP. 
Value *EscapeInstReload = - Builder.CreateLoad(ScalarAddr, EscapeInst->getName() + ".final_reload"); + Builder.CreateLoad(ScalarAddr->getAllocatedType(), ScalarAddr, + EscapeInst->getName() + ".final_reload"); EscapeInstReload = Builder.CreateBitOrPointerCast(EscapeInstReload, EscapeInst->getType()); @@ -959,7 +961,8 @@ void BlockGenerator::createExitPHINodeMerges(Scop &S) { std::string Name = PHI->getName().str(); Value *ScalarAddr = getOrCreateAlloca(SAI); - Value *Reload = Builder.CreateLoad(ScalarAddr, Name + ".ph.final_reload"); + Value *Reload = Builder.CreateLoad(SAI->getElementType(), ScalarAddr, + Name + ".ph.final_reload"); Reload = Builder.CreateBitOrPointerCast(Reload, PHI->getType()); Value *OriginalValue = PHI->getIncomingValueForBlock(MergeBB); assert((!isa<Instruction>(OriginalValue) || @@ -1037,30 +1040,21 @@ Value *VectorBlockGenerator::getVectorValue(ScopStmt &Stmt, Value *Old, return Vector; } -Type *VectorBlockGenerator::getVectorPtrTy(const Value *Val, int Width) { - auto *PointerTy = cast<PointerType>(Val->getType()); - unsigned AddrSpace = PointerTy->getAddressSpace(); - - Type *ScalarType = PointerTy->getElementType(); - auto *FVTy = FixedVectorType::get(ScalarType, Width); - - return PointerType::get(FVTy, AddrSpace); -} - Value *VectorBlockGenerator::generateStrideOneLoad( ScopStmt &Stmt, LoadInst *Load, VectorValueMapT &ScalarMaps, __isl_keep isl_id_to_ast_expr *NewAccesses, bool NegativeStride = false) { unsigned VectorWidth = getVectorWidth(); - auto *Pointer = Load->getPointerOperand(); - Type *VectorPtrType = getVectorPtrTy(Pointer, VectorWidth); + Type *VectorType = FixedVectorType::get(Load->getType(), VectorWidth); + Type *VectorPtrType = + PointerType::get(VectorType, Load->getPointerAddressSpace()); unsigned Offset = NegativeStride ? 
VectorWidth - 1 : 0; Value *NewPointer = generateLocationAccessed(Stmt, Load, ScalarMaps[Offset], VLTS[Offset], NewAccesses); Value *VectorPtr = Builder.CreateBitCast(NewPointer, VectorPtrType, "vector_ptr"); - LoadInst *VecLoad = - Builder.CreateLoad(VectorPtr, Load->getName() + "_p_vec_full"); + LoadInst *VecLoad = Builder.CreateLoad(VectorType, VectorPtr, + Load->getName() + "_p_vec_full"); if (!Aligned) VecLoad->setAlignment(Align(8)); @@ -1080,14 +1074,15 @@ Value *VectorBlockGenerator::generateStrideOneLoad( Value *VectorBlockGenerator::generateStrideZeroLoad( ScopStmt &Stmt, LoadInst *Load, ValueMapT &BBMap, __isl_keep isl_id_to_ast_expr *NewAccesses) { - auto *Pointer = Load->getPointerOperand(); - Type *VectorPtrType = getVectorPtrTy(Pointer, 1); + Type *VectorType = FixedVectorType::get(Load->getType(), 1); + Type *VectorPtrType = + PointerType::get(VectorType, Load->getPointerAddressSpace()); Value *NewPointer = generateLocationAccessed(Stmt, Load, BBMap, VLTS[0], NewAccesses); Value *VectorPtr = Builder.CreateBitCast(NewPointer, VectorPtrType, Load->getName() + "_p_vec_p"); - LoadInst *ScalarLoad = - Builder.CreateLoad(VectorPtr, Load->getName() + "_p_splat_one"); + LoadInst *ScalarLoad = Builder.CreateLoad(VectorType, VectorPtr, + Load->getName() + "_p_splat_one"); if (!Aligned) ScalarLoad->setAlignment(Align(8)); @@ -1104,9 +1099,8 @@ Value *VectorBlockGenerator::generateUnknownStrideLoad( ScopStmt &Stmt, LoadInst *Load, VectorValueMapT &ScalarMaps, __isl_keep isl_id_to_ast_expr *NewAccesses) { int VectorWidth = getVectorWidth(); - auto *Pointer = Load->getPointerOperand(); - auto *FVTy = FixedVectorType::get( - dyn_cast<PointerType>(Pointer->getType())->getElementType(), VectorWidth); + Type *ElemTy = Load->getType(); + auto *FVTy = FixedVectorType::get(ElemTy, VectorWidth); Value *Vector = UndefValue::get(FVTy); @@ -1114,7 +1108,7 @@ Value *VectorBlockGenerator::generateUnknownStrideLoad( Value *NewPointer = generateLocationAccessed(Stmt, Load, ScalarMaps[i], 
VLTS[i], NewAccesses); Value *ScalarLoad = - Builder.CreateLoad(NewPointer, Load->getName() + "_p_scalar_"); + Builder.CreateLoad(ElemTy, NewPointer, Load->getName() + "_p_scalar_"); Vector = Builder.CreateInsertElement( Vector, ScalarLoad, Builder.getInt32(i), Load->getName() + "_p_vec_"); } @@ -1192,7 +1186,6 @@ void VectorBlockGenerator::copyStore( VectorValueMapT &ScalarMaps, __isl_keep isl_id_to_ast_expr *NewAccesses) { const MemoryAccess &Access = Stmt.getArrayAccessFor(Store); - auto *Pointer = Store->getPointerOperand(); Value *Vector = getVectorValue(Stmt, Store->getValueOperand(), VectorMap, ScalarMaps, getLoopForStmt(Stmt)); @@ -1201,7 +1194,10 @@ void VectorBlockGenerator::copyStore( extractScalarValues(Store, VectorMap, ScalarMaps); if (Access.isStrideOne(isl::manage_copy(Schedule))) { - Type *VectorPtrType = getVectorPtrTy(Pointer, getVectorWidth()); + Type *VectorType = FixedVectorType::get(Store->getValueOperand()->getType(), + getVectorWidth()); + Type *VectorPtrType = + PointerType::get(VectorType, Store->getPointerAddressSpace()); Value *NewPointer = generateLocationAccessed(Stmt, Store, ScalarMaps[0], VLTS[0], NewAccesses); @@ -1339,10 +1335,13 @@ void VectorBlockGenerator::generateScalarVectorLoads( continue; auto *Address = getOrCreateAlloca(*MA); - Type *VectorPtrType = getVectorPtrTy(Address, 1); + Type *VectorType = FixedVectorType::get(MA->getElementType(), 1); + Type *VectorPtrType = PointerType::get( + VectorType, Address->getType()->getPointerAddressSpace()); Value *VectorPtr = Builder.CreateBitCast(Address, VectorPtrType, Address->getName() + "_p_vec_p"); - auto *Val = Builder.CreateLoad(VectorPtr, Address->getName() + ".reload"); + auto *Val = Builder.CreateLoad(VectorType, VectorPtr, + Address->getName() + ".reload"); Constant *SplatVector = Constant::getNullValue( FixedVectorType::get(Builder.getInt32Ty(), getVectorWidth())); diff --git a/polly/lib/CodeGen/IslExprBuilder.cpp b/polly/lib/CodeGen/IslExprBuilder.cpp index 
86896235eb83..bd1a5642017e 100644 --- a/polly/lib/CodeGen/IslExprBuilder.cpp +++ b/polly/lib/CodeGen/IslExprBuilder.cpp @@ -231,7 +231,8 @@ Value *IslExprBuilder::createOpNAry(__isl_take isl_ast_expr *Expr) { return V; } -Value *IslExprBuilder::createAccessAddress(isl_ast_expr *Expr) { +std::pair<Value *, Type *> +IslExprBuilder::createAccessAddress(isl_ast_expr *Expr) { assert(isl_ast_expr_get_type(Expr) == isl_ast_expr_op && "isl ast expression not of type isl_ast_op"); assert(isl_ast_expr_get_op_type(Expr) == isl_ast_op_access && @@ -281,7 +282,7 @@ Value *IslExprBuilder::createAccessAddress(isl_ast_expr *Expr) { isl_ast_expr_free(Expr); if (PollyDebugPrinting) RuntimeDebugBuilder::createCPUPrinter(Builder, "\n"); - return Base; + return {Base, SAI->getElementType()}; } IndexOp = nullptr; @@ -338,13 +339,14 @@ Value *IslExprBuilder::createAccessAddress(isl_ast_expr *Expr) { if (PollyDebugPrinting) RuntimeDebugBuilder::createCPUPrinter(Builder, "\n"); isl_ast_expr_free(Expr); - return Access; + return {Access, SAI->getElementType()}; } Value *IslExprBuilder::createOpAccess(isl_ast_expr *Expr) { - Value *Addr = createAccessAddress(Expr); - assert(Addr && "Could not create op access address"); - return Builder.CreateLoad(Addr, Addr->getName() + ".load"); + auto Info = createAccessAddress(Expr); + assert(Info.first && "Could not create op access address"); + return Builder.CreateLoad(Info.second, Info.first, + Info.first->getName() + ".load"); } Value *IslExprBuilder::createOpBin(__isl_take isl_ast_expr *Expr) { @@ -704,7 +706,7 @@ Value *IslExprBuilder::createOpAddressOf(__isl_take isl_ast_expr *Expr) { assert(isl_ast_expr_get_op_type(Op) == isl_ast_op_access && "Expected address of operator to be an access expression."); - Value *V = createAccessAddress(Op); + Value *V = createAccessAddress(Op).first; isl_ast_expr_free(Expr); diff --git a/polly/lib/CodeGen/IslNodeBuilder.cpp b/polly/lib/CodeGen/IslNodeBuilder.cpp index cb176cffba5f..a575e5037acb 100644 --- 
a/polly/lib/CodeGen/IslNodeBuilder.cpp +++ b/polly/lib/CodeGen/IslNodeBuilder.cpp @@ -952,7 +952,7 @@ void IslNodeBuilder::generateCopyStmt( auto *LoadValue = ExprBuilder.create(AccessExpr); AccessExpr = isl_id_to_ast_expr_get(NewAccesses, (*WriteAccess)->getId().release()); - auto *StoreAddr = ExprBuilder.createAccessAddress(AccessExpr); + auto *StoreAddr = ExprBuilder.createAccessAddress(AccessExpr).first; Builder.CreateStore(LoadValue, StoreAddr); }