Rename MemRefDescriptor::getElementType() to MemRefDescriptor::getElementPtrType().

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D87284
This commit is contained in:
Christian Sigg 2020-09-09 07:41:56 +02:00
parent 8427885e27
commit 3a577f5446
3 changed files with 12 additions and 9 deletions

View File

@ -34,6 +34,7 @@ class UnrankedMemRefType;
namespace LLVM {
class LLVMDialect;
class LLVMType;
class LLVMPointerType;
} // namespace LLVM
/// Callback to convert function argument types. It converts a MemRef function
@ -281,8 +282,8 @@ public:
void setConstantStride(OpBuilder &builder, Location loc, unsigned pos,
uint64_t stride);
/// Returns the (LLVM) type this descriptor points to.
LLVM::LLVMType getElementType();
/// Returns the (LLVM) pointer type this descriptor contains.
LLVM::LLVMPointerType getElementPtrType();
/// Builds IR populating a MemRef descriptor structure from a list of
/// individual values composing that descriptor, in the following order:

View File

@ -642,9 +642,11 @@ void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
createIndexAttrConstant(builder, loc, indexType, stride));
}
LLVM::LLVMType MemRefDescriptor::getElementType() {
return value.getType().cast<LLVM::LLVMType>().getStructElementType(
kAlignedPtrPosInMemRefDescriptor);
LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() {
return value.getType()
.cast<LLVM::LLVMType>()
.getStructElementType(kAlignedPtrPosInMemRefDescriptor)
.cast<LLVM::LLVMPointerType>();
}
/// Creates a MemRef descriptor structure from a list of individual values
@ -894,7 +896,7 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
Value ConvertToLLVMPattern::getDataPtr(
Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
ConversionPatternRewriter &rewriter) const {
LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType();
LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementPtrType();
int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides = getStridesAndOffset(type, strides, offset);

View File

@ -198,7 +198,7 @@ static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
Value base;
if (failed(getBase(rewriter, loc, memref, memRefType, base)))
return failure();
auto pType = MemRefDescriptor(memref).getElementType();
auto pType = MemRefDescriptor(memref).getElementPtrType();
ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
return success();
}
@ -225,7 +225,7 @@ static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
Value base;
if (failed(getBase(rewriter, loc, memref, memRefType, base)))
return failure();
auto pType = MemRefDescriptor(memref).getElementType();
auto pType = MemRefDescriptor(memref).getElementPtrType();
auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0));
ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
return success();
@ -1151,7 +1151,7 @@ public:
// Create descriptor.
auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
Type llvmTargetElementTy = desc.getElementType();
Type llvmTargetElementTy = desc.getElementPtrType();
// Set allocated ptr.
Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
allocated =