forked from OSchip/llvm-project
Concentrate memref descriptor manipulation logic in one place
Memref descriptor is becoming increasingly complex. Memrefs are manipulated by multiple standard instructions, each of which has a non-trivial lowering to the LLVM dialect. This leads to verbose code that manipulates the descriptors exposing the internals of insert/extractelement opreations. Implement a wrapper class that contains a memref descriptor and provides semantically named methods that build the primitive IR operations instead. PiperOrigin-RevId: 280371225
This commit is contained in:
parent
d1c99e10d0
commit
ee5c2256ef
|
@ -235,6 +235,125 @@ LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
|
||||||
: ConversionPattern(rootOpName, benefit, context), lowering(lowering_) {}
|
: ConversionPattern(rootOpName, benefit, context), lowering(lowering_) {}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
/// Helper class to produce LLVM dialect operations extracting or inserting
|
||||||
|
/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
|
||||||
|
/// The Value may be null, in which case none of the operations are valid.
|
||||||
|
class MemRefDescriptor {
|
||||||
|
public:
|
||||||
|
/// Construct a helper for the given descriptor value.
|
||||||
|
explicit MemRefDescriptor(Value *descriptor) : value(descriptor) {
|
||||||
|
if (value) {
|
||||||
|
structType = value->getType().cast<LLVM::LLVMType>();
|
||||||
|
indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType(
|
||||||
|
LLVMTypeConverter::kOffsetPosInMemRefDescriptor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds IR creating an `undef` value of the descriptor type.
|
||||||
|
static MemRefDescriptor undef(OpBuilder &builder, Location loc,
|
||||||
|
Type descriptorType) {
|
||||||
|
Value *descriptor = builder.create<LLVM::UndefOp>(
|
||||||
|
loc, descriptorType.cast<LLVM::LLVMType>());
|
||||||
|
return MemRefDescriptor(descriptor);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds IR extracting the allocated pointer from the descriptor.
|
||||||
|
Value *allocatedPtr(OpBuilder &builder, Location loc) {
|
||||||
|
return extractPtr(builder, loc,
|
||||||
|
LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds IR inserting the allocated pointer into the descriptor.
|
||||||
|
void setAllocatedPtr(OpBuilder &builder, Location loc, Value *ptr) {
|
||||||
|
setPtr(builder, loc, LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor,
|
||||||
|
ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds IR extracting the aligned pointer from the descriptor.
|
||||||
|
Value *alignedPtr(OpBuilder &builder, Location loc) {
|
||||||
|
return extractPtr(builder, loc,
|
||||||
|
LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds IR inserting the aligned pointer into the descriptor.
|
||||||
|
void setAlignedPtr(OpBuilder &builder, Location loc, Value *ptr) {
|
||||||
|
setPtr(builder, loc, LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor,
|
||||||
|
ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds IR extracting the offset from the descriptor.
|
||||||
|
Value *offset(OpBuilder &builder, Location loc) {
|
||||||
|
return builder.create<LLVM::ExtractValueOp>(
|
||||||
|
loc, indexType, value,
|
||||||
|
builder.getI64ArrayAttr(
|
||||||
|
LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds IR inserting the offset into the descriptor.
|
||||||
|
void setOffset(OpBuilder &builder, Location loc, Value *offset) {
|
||||||
|
value = builder.create<LLVM::InsertValueOp>(
|
||||||
|
loc, structType, value, offset,
|
||||||
|
builder.getI64ArrayAttr(
|
||||||
|
LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds IR extracting the pos-th size from the descriptor.
|
||||||
|
Value *size(OpBuilder &builder, Location loc, unsigned pos) {
|
||||||
|
return builder.create<LLVM::ExtractValueOp>(
|
||||||
|
loc, indexType, value,
|
||||||
|
builder.getI64ArrayAttr(
|
||||||
|
{LLVMTypeConverter::kSizePosInMemRefDescriptor, pos}));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds IR inserting the pos-th size into the descriptor
|
||||||
|
void setSize(OpBuilder &builder, Location loc, unsigned pos, Value *size) {
|
||||||
|
value = builder.create<LLVM::InsertValueOp>(
|
||||||
|
loc, structType, value, size,
|
||||||
|
builder.getI64ArrayAttr(
|
||||||
|
{LLVMTypeConverter::kSizePosInMemRefDescriptor, pos}));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds IR extracting the pos-th size from the descriptor.
|
||||||
|
Value *stride(OpBuilder &builder, Location loc, unsigned pos) {
|
||||||
|
return builder.create<LLVM::ExtractValueOp>(
|
||||||
|
loc, indexType, value,
|
||||||
|
builder.getI64ArrayAttr(
|
||||||
|
{LLVMTypeConverter::kStridePosInMemRefDescriptor, pos}));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds IR inserting the pos-th stride into the descriptor
|
||||||
|
void setStride(OpBuilder &builder, Location loc, unsigned pos,
|
||||||
|
Value *stride) {
|
||||||
|
value = builder.create<LLVM::InsertValueOp>(
|
||||||
|
loc, structType, value, stride,
|
||||||
|
builder.getI64ArrayAttr(
|
||||||
|
{LLVMTypeConverter::kStridePosInMemRefDescriptor, pos}));
|
||||||
|
}
|
||||||
|
|
||||||
|
/*implicit*/ operator Value *() { return value; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
Value *extractPtr(OpBuilder &builder, Location loc, unsigned pos) {
|
||||||
|
Type type = structType.getStructElementType(pos);
|
||||||
|
return builder.create<LLVM::ExtractValueOp>(loc, type, value,
|
||||||
|
builder.getI64ArrayAttr(pos));
|
||||||
|
}
|
||||||
|
|
||||||
|
void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value *ptr) {
|
||||||
|
value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr,
|
||||||
|
builder.getI64ArrayAttr(pos));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cached descriptor type.
|
||||||
|
LLVM::LLVMType structType;
|
||||||
|
|
||||||
|
// Cached index type.
|
||||||
|
LLVM::LLVMType indexType;
|
||||||
|
|
||||||
|
// Actual descriptor.
|
||||||
|
Value *value;
|
||||||
|
};
|
||||||
|
|
||||||
// Base class for Standard to LLVM IR op conversions. Matches the Op type
|
// Base class for Standard to LLVM IR op conversions. Matches the Op type
|
||||||
// provided as template argument. Carries a reference to the LLVM dialect in
|
// provided as template argument. Carries a reference to the LLVM dialect in
|
||||||
// case it is necessary for rewriters.
|
// case it is necessary for rewriters.
|
||||||
|
@ -278,29 +397,6 @@ public:
|
||||||
return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr);
|
return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract allocated data pointer value from a value representing a memref.
|
|
||||||
static Value *
|
|
||||||
extractAllocatedMemRefElementPtr(ConversionPatternRewriter &builder,
|
|
||||||
Location loc, Value *memref,
|
|
||||||
Type elementTypePtr) {
|
|
||||||
return builder.create<LLVM::ExtractValueOp>(
|
|
||||||
loc, elementTypePtr, memref,
|
|
||||||
builder.getI64ArrayAttr(
|
|
||||||
LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract properly aligned data pointer value from a value representing a
|
|
||||||
// memref.
|
|
||||||
static Value *
|
|
||||||
extractAlignedMemRefElementPtr(ConversionPatternRewriter &builder,
|
|
||||||
Location loc, Value *memref,
|
|
||||||
Type elementTypePtr) {
|
|
||||||
return builder.create<LLVM::ExtractValueOp>(
|
|
||||||
loc, elementTypePtr, memref,
|
|
||||||
builder.getI64ArrayAttr(
|
|
||||||
LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
LLVM::LLVMDialect &dialect;
|
LLVM::LLVMDialect &dialect;
|
||||||
};
|
};
|
||||||
|
@ -786,14 +882,10 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
||||||
|
|
||||||
// Create the MemRef descriptor.
|
// Create the MemRef descriptor.
|
||||||
auto structType = lowering.convertType(type);
|
auto structType = lowering.convertType(type);
|
||||||
Value *memRefDescriptor =
|
auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
|
||||||
rewriter.create<LLVM::UndefOp>(loc, structType, ArrayRef<Value *>{});
|
|
||||||
|
|
||||||
// Field 1: Allocated pointer, used for malloc/free.
|
// Field 1: Allocated pointer, used for malloc/free.
|
||||||
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
|
memRefDescriptor.setAllocatedPtr(rewriter, loc, bitcastAllocated);
|
||||||
loc, structType, memRefDescriptor, bitcastAllocated,
|
|
||||||
rewriter.getI64ArrayAttr(
|
|
||||||
LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
|
|
||||||
// Field 2: Actual aligned pointer to payload.
|
// Field 2: Actual aligned pointer to payload.
|
||||||
Value *bitcastAligned = bitcastAllocated;
|
Value *bitcastAligned = bitcastAllocated;
|
||||||
if (align) {
|
if (align) {
|
||||||
|
@ -808,20 +900,15 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
||||||
bitcastAligned = rewriter.create<LLVM::BitcastOp>(
|
bitcastAligned = rewriter.create<LLVM::BitcastOp>(
|
||||||
loc, elementPtrType, ArrayRef<Value *>(aligned));
|
loc, elementPtrType, ArrayRef<Value *>(aligned));
|
||||||
}
|
}
|
||||||
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
|
memRefDescriptor.setAlignedPtr(rewriter, loc, bitcastAligned);
|
||||||
loc, structType, memRefDescriptor, bitcastAligned,
|
|
||||||
rewriter.getI64ArrayAttr(
|
|
||||||
LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
|
|
||||||
// Field 3: Offset in aligned pointer.
|
// Field 3: Offset in aligned pointer.
|
||||||
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
|
memRefDescriptor.setOffset(rewriter, loc,
|
||||||
loc, structType, memRefDescriptor,
|
createIndexConstant(rewriter, loc, offset));
|
||||||
createIndexConstant(rewriter, loc, offset),
|
|
||||||
rewriter.getI64ArrayAttr(
|
|
||||||
LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
|
|
||||||
|
|
||||||
if (type.getRank() == 0)
|
if (type.getRank() == 0)
|
||||||
// No size/stride descriptor in memref, return the descriptor value.
|
// No size/stride descriptor in memref, return the descriptor value.
|
||||||
return rewriter.replaceOp(op, memRefDescriptor);
|
return rewriter.replaceOp(op, {memRefDescriptor});
|
||||||
|
|
||||||
// Fields 4 and 5: Sizes and strides of the strided MemRef.
|
// Fields 4 and 5: Sizes and strides of the strided MemRef.
|
||||||
// Store all sizes in the descriptor. Only dynamic sizes are passed in as
|
// Store all sizes in the descriptor. Only dynamic sizes are passed in as
|
||||||
|
@ -846,18 +933,12 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
||||||
// Fill size and stride descriptors in memref.
|
// Fill size and stride descriptors in memref.
|
||||||
for (auto indexedSize : llvm::enumerate(sizes)) {
|
for (auto indexedSize : llvm::enumerate(sizes)) {
|
||||||
int64_t index = indexedSize.index();
|
int64_t index = indexedSize.index();
|
||||||
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
|
memRefDescriptor.setSize(rewriter, loc, index, indexedSize.value());
|
||||||
loc, structType, memRefDescriptor, indexedSize.value(),
|
memRefDescriptor.setStride(rewriter, loc, index, strideValues[index]);
|
||||||
rewriter.getI64ArrayAttr(
|
|
||||||
{LLVMTypeConverter::kSizePosInMemRefDescriptor, index}));
|
|
||||||
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
|
|
||||||
loc, structType, memRefDescriptor, strideValues[index],
|
|
||||||
rewriter.getI64ArrayAttr(
|
|
||||||
{LLVMTypeConverter::kStridePosInMemRefDescriptor, index}));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return the final value of the descriptor.
|
// Return the final value of the descriptor.
|
||||||
rewriter.replaceOp(op, memRefDescriptor);
|
rewriter.replaceOp(op, {memRefDescriptor});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -947,13 +1028,10 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
|
||||||
/*isVarArg=*/false));
|
/*isVarArg=*/false));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
|
MemRefDescriptor memref(transformed.memref());
|
||||||
Type elementPtrType = type.getStructElementType(
|
|
||||||
LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor);
|
|
||||||
Value *bufferPtr = extractAllocatedMemRefElementPtr(
|
|
||||||
rewriter, op->getLoc(), transformed.memref(), elementPtrType);
|
|
||||||
Value *casted = rewriter.create<LLVM::BitcastOp>(
|
Value *casted = rewriter.create<LLVM::BitcastOp>(
|
||||||
op->getLoc(), getVoidPtrType(), bufferPtr);
|
op->getLoc(), getVoidPtrType(),
|
||||||
|
memref.allocatedPtr(rewriter, op->getLoc()));
|
||||||
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
|
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
|
||||||
op, ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
|
op, ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
|
@ -1003,10 +1081,8 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
|
||||||
int64_t index = dimOp.getIndex();
|
int64_t index = dimOp.getIndex();
|
||||||
// Extract dynamic size from the memref descriptor.
|
// Extract dynamic size from the memref descriptor.
|
||||||
if (ShapedType::isDynamic(shape[index]))
|
if (ShapedType::isDynamic(shape[index]))
|
||||||
rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
|
rewriter.replaceOp(op, {MemRefDescriptor(transformed.memrefOrTensor())
|
||||||
op, getIndexType(), transformed.memrefOrTensor(),
|
.size(rewriter, op->getLoc(), index)});
|
||||||
rewriter.getI64ArrayAttr(
|
|
||||||
{LLVMTypeConverter::kSizePosInMemRefDescriptor, index}));
|
|
||||||
else
|
else
|
||||||
// Use constant for static size.
|
// Use constant for static size.
|
||||||
rewriter.replaceOp(
|
rewriter.replaceOp(
|
||||||
|
@ -1058,34 +1134,21 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
|
||||||
// This is a strided getElementPtr variant that linearizes subscripts as:
|
// This is a strided getElementPtr variant that linearizes subscripts as:
|
||||||
// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
|
// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
|
||||||
Value *getStridedElementPtr(Location loc, Type elementTypePtr,
|
Value *getStridedElementPtr(Location loc, Type elementTypePtr,
|
||||||
Value *memRefDescriptor,
|
Value *descriptor, ArrayRef<Value *> indices,
|
||||||
ArrayRef<Value *> indices,
|
|
||||||
ArrayRef<int64_t> strides, int64_t offset,
|
ArrayRef<int64_t> strides, int64_t offset,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
auto indexTy = this->getIndexType();
|
MemRefDescriptor memRefDescriptor(descriptor);
|
||||||
Value *base = this->extractAlignedMemRefElementPtr(
|
|
||||||
rewriter, loc, memRefDescriptor, elementTypePtr);
|
Value *base = memRefDescriptor.alignedPtr(rewriter, loc);
|
||||||
Value *offsetValue =
|
Value *offsetValue = offset == MemRefType::getDynamicStrideOrOffset()
|
||||||
offset == MemRefType::getDynamicStrideOrOffset()
|
? memRefDescriptor.offset(rewriter, loc)
|
||||||
? rewriter.create<LLVM::ExtractValueOp>(
|
|
||||||
loc, indexTy, memRefDescriptor,
|
|
||||||
rewriter.getI64ArrayAttr(
|
|
||||||
LLVMTypeConverter::kOffsetPosInMemRefDescriptor))
|
|
||||||
: this->createIndexConstant(rewriter, loc, offset);
|
: this->createIndexConstant(rewriter, loc, offset);
|
||||||
|
|
||||||
for (int i = 0, e = indices.size(); i < e; ++i) {
|
for (int i = 0, e = indices.size(); i < e; ++i) {
|
||||||
Value *stride;
|
Value *stride =
|
||||||
if (strides[i] != MemRefType::getDynamicStrideOrOffset()) {
|
strides[i] == MemRefType::getDynamicStrideOrOffset()
|
||||||
// Use static stride.
|
? memRefDescriptor.stride(rewriter, loc, i)
|
||||||
auto attr =
|
: this->createIndexConstant(rewriter, loc, strides[i]);
|
||||||
rewriter.getIntegerAttr(rewriter.getIndexType(), strides[i]);
|
|
||||||
stride = rewriter.create<LLVM::ConstantOp>(loc, indexTy, attr);
|
|
||||||
} else {
|
|
||||||
// Use dynamic stride.
|
|
||||||
stride = rewriter.create<LLVM::ExtractValueOp>(
|
|
||||||
loc, indexTy, memRefDescriptor,
|
|
||||||
rewriter.getI64ArrayAttr(
|
|
||||||
{LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
|
|
||||||
}
|
|
||||||
Value *additionalOffset =
|
Value *additionalOffset =
|
||||||
rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
|
rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
|
||||||
offsetValue =
|
offsetValue =
|
||||||
|
@ -1452,74 +1515,45 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
|
||||||
return matchFailure();
|
return matchFailure();
|
||||||
|
|
||||||
// Create the descriptor.
|
// Create the descriptor.
|
||||||
Value *desc = rewriter.create<LLVM::UndefOp>(loc, targetDescTy);
|
MemRefDescriptor sourceMemRef(adaptor.source());
|
||||||
|
auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
|
||||||
|
|
||||||
// Copy the buffer pointer from the old descriptor to the new one.
|
// Copy the buffer pointer from the old descriptor to the new one.
|
||||||
Value *sourceDescriptor = adaptor.source();
|
Value *extracted = sourceMemRef.allocatedPtr(rewriter, loc);
|
||||||
Value *extracted = rewriter.create<LLVM::ExtractValueOp>(
|
|
||||||
loc, sourceElementTy.getPointerTo(), sourceDescriptor,
|
|
||||||
rewriter.getI64ArrayAttr(
|
|
||||||
LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
|
|
||||||
Value *bitcastPtr = rewriter.create<LLVM::BitcastOp>(
|
Value *bitcastPtr = rewriter.create<LLVM::BitcastOp>(
|
||||||
loc, targetElementTy.getPointerTo(), extracted);
|
loc, targetElementTy.getPointerTo(), extracted);
|
||||||
desc = rewriter.create<LLVM::InsertValueOp>(
|
targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
|
||||||
loc, desc, bitcastPtr,
|
|
||||||
rewriter.getI64ArrayAttr(
|
extracted = sourceMemRef.alignedPtr(rewriter, loc);
|
||||||
LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
|
|
||||||
extracted = rewriter.create<LLVM::ExtractValueOp>(
|
|
||||||
loc, sourceElementTy.getPointerTo(), sourceDescriptor,
|
|
||||||
rewriter.getI64ArrayAttr(
|
|
||||||
LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
|
|
||||||
bitcastPtr = rewriter.create<LLVM::BitcastOp>(
|
bitcastPtr = rewriter.create<LLVM::BitcastOp>(
|
||||||
loc, targetElementTy.getPointerTo(), extracted);
|
loc, targetElementTy.getPointerTo(), extracted);
|
||||||
desc = rewriter.create<LLVM::InsertValueOp>(
|
targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
|
||||||
loc, desc, bitcastPtr,
|
|
||||||
rewriter.getI64ArrayAttr(
|
|
||||||
LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
|
|
||||||
|
|
||||||
// Extract strides needed to compute offset.
|
// Extract strides needed to compute offset.
|
||||||
SmallVector<Value *, 4> strideValues;
|
SmallVector<Value *, 4> strideValues;
|
||||||
strideValues.reserve(viewMemRefType.getRank());
|
strideValues.reserve(viewMemRefType.getRank());
|
||||||
for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) {
|
for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i)
|
||||||
strideValues.push_back(rewriter.create<LLVM::ExtractValueOp>(
|
strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
|
||||||
loc, getIndexType(), sourceDescriptor,
|
|
||||||
rewriter.getI64ArrayAttr(
|
|
||||||
{LLVMTypeConverter::kStridePosInMemRefDescriptor, i})));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Offset.
|
// Offset.
|
||||||
Value *baseOffset = rewriter.create<LLVM::ExtractValueOp>(
|
Value *baseOffset = sourceMemRef.offset(rewriter, loc);
|
||||||
loc, getIndexType(), sourceDescriptor,
|
|
||||||
rewriter.getI64ArrayAttr(
|
|
||||||
LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
|
|
||||||
for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) {
|
for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) {
|
||||||
Value *min = adaptor.offsets()[i];
|
Value *min = adaptor.offsets()[i];
|
||||||
baseOffset = rewriter.create<LLVM::AddOp>(
|
baseOffset = rewriter.create<LLVM::AddOp>(
|
||||||
loc, baseOffset,
|
loc, baseOffset,
|
||||||
rewriter.create<LLVM::MulOp>(loc, min, strideValues[i]));
|
rewriter.create<LLVM::MulOp>(loc, min, strideValues[i]));
|
||||||
}
|
}
|
||||||
desc = rewriter.create<LLVM::InsertValueOp>(
|
targetMemRef.setOffset(rewriter, loc, baseOffset);
|
||||||
loc, desc, baseOffset,
|
|
||||||
rewriter.getI64ArrayAttr(
|
|
||||||
LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
|
|
||||||
|
|
||||||
// Update sizes and strides.
|
// Update sizes and strides.
|
||||||
for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
|
for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
|
||||||
// Update size.
|
targetMemRef.setSize(rewriter, loc, i, adaptor.sizes()[i]);
|
||||||
desc = rewriter.create<LLVM::InsertValueOp>(
|
targetMemRef.setStride(rewriter, loc, i,
|
||||||
loc, desc, adaptor.sizes()[i],
|
rewriter.create<LLVM::MulOp>(
|
||||||
rewriter.getI64ArrayAttr(
|
loc, adaptor.strides()[i], strideValues[i]));
|
||||||
{LLVMTypeConverter::kSizePosInMemRefDescriptor, i}));
|
|
||||||
// Update stride.
|
|
||||||
desc = rewriter.create<LLVM::InsertValueOp>(
|
|
||||||
loc, desc,
|
|
||||||
rewriter.create<LLVM::MulOp>(loc, adaptor.strides()[i],
|
|
||||||
strideValues[i]),
|
|
||||||
rewriter.getI64ArrayAttr(
|
|
||||||
{LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.replaceOp(op, desc);
|
rewriter.replaceOp(op, {targetMemRef});
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -1571,10 +1605,6 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
auto viewOp = cast<ViewOp>(op);
|
auto viewOp = cast<ViewOp>(op);
|
||||||
ViewOpOperandAdaptor adaptor(operands);
|
ViewOpOperandAdaptor adaptor(operands);
|
||||||
auto sourceMemRefType = viewOp.source()->getType().cast<MemRefType>();
|
|
||||||
auto sourceElementTy =
|
|
||||||
lowering.convertType(sourceMemRefType.getElementType())
|
|
||||||
.dyn_cast<LLVM::LLVMType>();
|
|
||||||
|
|
||||||
auto viewMemRefType = viewOp.getType();
|
auto viewMemRefType = viewOp.getType();
|
||||||
auto targetElementTy = lowering.convertType(viewMemRefType.getElementType())
|
auto targetElementTy = lowering.convertType(viewMemRefType.getElementType())
|
||||||
|
@ -1593,32 +1623,20 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
|
||||||
matchFailure();
|
matchFailure();
|
||||||
|
|
||||||
// Create the descriptor.
|
// Create the descriptor.
|
||||||
Value *desc = rewriter.create<LLVM::UndefOp>(loc, targetDescTy);
|
MemRefDescriptor sourceMemRef(adaptor.source());
|
||||||
|
auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
|
||||||
|
|
||||||
// Field 1: Copy the allocated pointer, used for malloc/free.
|
// Field 1: Copy the allocated pointer, used for malloc/free.
|
||||||
Value *sourceDescriptor = adaptor.source();
|
Value *extracted = sourceMemRef.allocatedPtr(rewriter, loc);
|
||||||
Value *extracted = rewriter.create<LLVM::ExtractValueOp>(
|
|
||||||
loc, sourceElementTy.getPointerTo(), sourceDescriptor,
|
|
||||||
rewriter.getI64ArrayAttr(
|
|
||||||
LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
|
|
||||||
Value *bitcastPtr = rewriter.create<LLVM::BitcastOp>(
|
Value *bitcastPtr = rewriter.create<LLVM::BitcastOp>(
|
||||||
loc, targetElementTy.getPointerTo(), extracted);
|
loc, targetElementTy.getPointerTo(), extracted);
|
||||||
desc = rewriter.create<LLVM::InsertValueOp>(
|
targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
|
||||||
loc, desc, bitcastPtr,
|
|
||||||
rewriter.getI64ArrayAttr(
|
|
||||||
LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
|
|
||||||
|
|
||||||
// Field 2: Copy the actual aligned pointer to payload.
|
// Field 2: Copy the actual aligned pointer to payload.
|
||||||
extracted = rewriter.create<LLVM::ExtractValueOp>(
|
extracted = sourceMemRef.alignedPtr(rewriter, loc);
|
||||||
loc, sourceElementTy.getPointerTo(), sourceDescriptor,
|
|
||||||
rewriter.getI64ArrayAttr(
|
|
||||||
LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
|
|
||||||
bitcastPtr = rewriter.create<LLVM::BitcastOp>(
|
bitcastPtr = rewriter.create<LLVM::BitcastOp>(
|
||||||
loc, targetElementTy.getPointerTo(), extracted);
|
loc, targetElementTy.getPointerTo(), extracted);
|
||||||
desc = rewriter.create<LLVM::InsertValueOp>(
|
targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
|
||||||
loc, desc, bitcastPtr,
|
|
||||||
rewriter.getI64ArrayAttr(
|
|
||||||
LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
|
|
||||||
|
|
||||||
// Field 3: Copy the offset in aligned pointer.
|
// Field 3: Copy the offset in aligned pointer.
|
||||||
unsigned numDynamicSizes = llvm::size(viewOp.getDynamicSizes());
|
unsigned numDynamicSizes = llvm::size(viewOp.getDynamicSizes());
|
||||||
|
@ -1630,14 +1648,11 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
|
||||||
? createIndexConstant(rewriter, loc, offset)
|
? createIndexConstant(rewriter, loc, offset)
|
||||||
// TODO(ntv): better adaptor.
|
// TODO(ntv): better adaptor.
|
||||||
: sizeAndOffsetOperands.back();
|
: sizeAndOffsetOperands.back();
|
||||||
desc = rewriter.create<LLVM::InsertValueOp>(
|
targetMemRef.setOffset(rewriter, loc, baseOffset);
|
||||||
loc, desc, baseOffset,
|
|
||||||
rewriter.getI64ArrayAttr(
|
|
||||||
LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
|
|
||||||
|
|
||||||
// Early exit for 0-D corner case.
|
// Early exit for 0-D corner case.
|
||||||
if (viewMemRefType.getRank() == 0)
|
if (viewMemRefType.getRank() == 0)
|
||||||
return rewriter.replaceOp(op, desc), matchSuccess();
|
return rewriter.replaceOp(op, {targetMemRef}), matchSuccess();
|
||||||
|
|
||||||
// Fields 4 and 5: Update sizes and strides.
|
// Fields 4 and 5: Update sizes and strides.
|
||||||
if (strides.back() != 1)
|
if (strides.back() != 1)
|
||||||
|
@ -1648,20 +1663,14 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
|
||||||
// Update size.
|
// Update size.
|
||||||
Value *size = getSize(rewriter, loc, viewMemRefType.getShape(),
|
Value *size = getSize(rewriter, loc, viewMemRefType.getShape(),
|
||||||
sizeAndOffsetOperands, i);
|
sizeAndOffsetOperands, i);
|
||||||
desc = rewriter.create<LLVM::InsertValueOp>(
|
targetMemRef.setSize(rewriter, loc, i, size);
|
||||||
loc, desc, size,
|
|
||||||
rewriter.getI64ArrayAttr(
|
|
||||||
{LLVMTypeConverter::kSizePosInMemRefDescriptor, i}));
|
|
||||||
// Update stride.
|
// Update stride.
|
||||||
stride = getStride(rewriter, loc, strides, nextSize, stride, i);
|
stride = getStride(rewriter, loc, strides, nextSize, stride, i);
|
||||||
desc = rewriter.create<LLVM::InsertValueOp>(
|
targetMemRef.setStride(rewriter, loc, i, stride);
|
||||||
loc, desc, stride,
|
|
||||||
rewriter.getI64ArrayAttr(
|
|
||||||
{LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
|
|
||||||
nextSize = size;
|
nextSize = size;
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.replaceOp(op, desc);
|
rewriter.replaceOp(op, {targetMemRef});
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in New Issue