Concentrate memref descriptor manipulation logic in one place

Memref descriptor is becoming increasingly complex. Memrefs are manipulated by
multiple standard instructions, each of which has a non-trivial lowering to the
LLVM dialect. This leads to verbose code that manipulates the descriptors
exposing the internals of insert/extractelement opreations. Implement a wrapper
class that contains a memref descriptor and provides semantically named methods
that build the primitive IR operations instead.

PiperOrigin-RevId: 280371225
This commit is contained in:
Alex Zinenko 2019-11-14 00:48:41 -08:00 committed by A. Unique TensorFlower
parent d1c99e10d0
commit ee5c2256ef
1 changed files with 174 additions and 165 deletions

View File

@ -235,6 +235,125 @@ LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
: ConversionPattern(rootOpName, benefit, context), lowering(lowering_) {} : ConversionPattern(rootOpName, benefit, context), lowering(lowering_) {}
namespace { namespace {
/// Helper class to produce LLVM dialect operations extracting or inserting
/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
/// The Value may be null, in which case none of the operations are valid.
class MemRefDescriptor {
public:
/// Construct a helper for the given descriptor value.
explicit MemRefDescriptor(Value *descriptor) : value(descriptor) {
if (value) {
structType = value->getType().cast<LLVM::LLVMType>();
indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType(
LLVMTypeConverter::kOffsetPosInMemRefDescriptor);
}
}
/// Builds IR creating an `undef` value of the descriptor type.
static MemRefDescriptor undef(OpBuilder &builder, Location loc,
Type descriptorType) {
Value *descriptor = builder.create<LLVM::UndefOp>(
loc, descriptorType.cast<LLVM::LLVMType>());
return MemRefDescriptor(descriptor);
}
/// Builds IR extracting the allocated pointer from the descriptor.
Value *allocatedPtr(OpBuilder &builder, Location loc) {
return extractPtr(builder, loc,
LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor);
}
/// Builds IR inserting the allocated pointer into the descriptor.
void setAllocatedPtr(OpBuilder &builder, Location loc, Value *ptr) {
setPtr(builder, loc, LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor,
ptr);
}
/// Builds IR extracting the aligned pointer from the descriptor.
Value *alignedPtr(OpBuilder &builder, Location loc) {
return extractPtr(builder, loc,
LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
}
/// Builds IR inserting the aligned pointer into the descriptor.
void setAlignedPtr(OpBuilder &builder, Location loc, Value *ptr) {
setPtr(builder, loc, LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor,
ptr);
}
/// Builds IR extracting the offset from the descriptor.
Value *offset(OpBuilder &builder, Location loc) {
return builder.create<LLVM::ExtractValueOp>(
loc, indexType, value,
builder.getI64ArrayAttr(
LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
}
/// Builds IR inserting the offset into the descriptor.
void setOffset(OpBuilder &builder, Location loc, Value *offset) {
value = builder.create<LLVM::InsertValueOp>(
loc, structType, value, offset,
builder.getI64ArrayAttr(
LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
}
/// Builds IR extracting the pos-th size from the descriptor.
Value *size(OpBuilder &builder, Location loc, unsigned pos) {
return builder.create<LLVM::ExtractValueOp>(
loc, indexType, value,
builder.getI64ArrayAttr(
{LLVMTypeConverter::kSizePosInMemRefDescriptor, pos}));
}
/// Builds IR inserting the pos-th size into the descriptor
void setSize(OpBuilder &builder, Location loc, unsigned pos, Value *size) {
value = builder.create<LLVM::InsertValueOp>(
loc, structType, value, size,
builder.getI64ArrayAttr(
{LLVMTypeConverter::kSizePosInMemRefDescriptor, pos}));
}
/// Builds IR extracting the pos-th size from the descriptor.
Value *stride(OpBuilder &builder, Location loc, unsigned pos) {
return builder.create<LLVM::ExtractValueOp>(
loc, indexType, value,
builder.getI64ArrayAttr(
{LLVMTypeConverter::kStridePosInMemRefDescriptor, pos}));
}
/// Builds IR inserting the pos-th stride into the descriptor
void setStride(OpBuilder &builder, Location loc, unsigned pos,
Value *stride) {
value = builder.create<LLVM::InsertValueOp>(
loc, structType, value, stride,
builder.getI64ArrayAttr(
{LLVMTypeConverter::kStridePosInMemRefDescriptor, pos}));
}
/*implicit*/ operator Value *() { return value; }
private:
Value *extractPtr(OpBuilder &builder, Location loc, unsigned pos) {
Type type = structType.getStructElementType(pos);
return builder.create<LLVM::ExtractValueOp>(loc, type, value,
builder.getI64ArrayAttr(pos));
}
void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value *ptr) {
value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr,
builder.getI64ArrayAttr(pos));
}
// Cached descriptor type.
LLVM::LLVMType structType;
// Cached index type.
LLVM::LLVMType indexType;
// Actual descriptor.
Value *value;
};
// Base class for Standard to LLVM IR op conversions. Matches the Op type // Base class for Standard to LLVM IR op conversions. Matches the Op type
// provided as template argument. Carries a reference to the LLVM dialect in // provided as template argument. Carries a reference to the LLVM dialect in
// case it is necessary for rewriters. // case it is necessary for rewriters.
@ -278,29 +397,6 @@ public:
return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr); return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr);
} }
// Extract allocated data pointer value from a value representing a memref.
static Value *
extractAllocatedMemRefElementPtr(ConversionPatternRewriter &builder,
Location loc, Value *memref,
Type elementTypePtr) {
return builder.create<LLVM::ExtractValueOp>(
loc, elementTypePtr, memref,
builder.getI64ArrayAttr(
LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
}
// Extract properly aligned data pointer value from a value representing a
// memref.
static Value *
extractAlignedMemRefElementPtr(ConversionPatternRewriter &builder,
Location loc, Value *memref,
Type elementTypePtr) {
return builder.create<LLVM::ExtractValueOp>(
loc, elementTypePtr, memref,
builder.getI64ArrayAttr(
LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
}
protected: protected:
LLVM::LLVMDialect &dialect; LLVM::LLVMDialect &dialect;
}; };
@ -786,14 +882,10 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
// Create the MemRef descriptor. // Create the MemRef descriptor.
auto structType = lowering.convertType(type); auto structType = lowering.convertType(type);
Value *memRefDescriptor = auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
rewriter.create<LLVM::UndefOp>(loc, structType, ArrayRef<Value *>{});
// Field 1: Allocated pointer, used for malloc/free. // Field 1: Allocated pointer, used for malloc/free.
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>( memRefDescriptor.setAllocatedPtr(rewriter, loc, bitcastAllocated);
loc, structType, memRefDescriptor, bitcastAllocated,
rewriter.getI64ArrayAttr(
LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
// Field 2: Actual aligned pointer to payload. // Field 2: Actual aligned pointer to payload.
Value *bitcastAligned = bitcastAllocated; Value *bitcastAligned = bitcastAllocated;
if (align) { if (align) {
@ -808,20 +900,15 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
bitcastAligned = rewriter.create<LLVM::BitcastOp>( bitcastAligned = rewriter.create<LLVM::BitcastOp>(
loc, elementPtrType, ArrayRef<Value *>(aligned)); loc, elementPtrType, ArrayRef<Value *>(aligned));
} }
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>( memRefDescriptor.setAlignedPtr(rewriter, loc, bitcastAligned);
loc, structType, memRefDescriptor, bitcastAligned,
rewriter.getI64ArrayAttr(
LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
// Field 3: Offset in aligned pointer. // Field 3: Offset in aligned pointer.
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>( memRefDescriptor.setOffset(rewriter, loc,
loc, structType, memRefDescriptor, createIndexConstant(rewriter, loc, offset));
createIndexConstant(rewriter, loc, offset),
rewriter.getI64ArrayAttr(
LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
if (type.getRank() == 0) if (type.getRank() == 0)
// No size/stride descriptor in memref, return the descriptor value. // No size/stride descriptor in memref, return the descriptor value.
return rewriter.replaceOp(op, memRefDescriptor); return rewriter.replaceOp(op, {memRefDescriptor});
// Fields 4 and 5: Sizes and strides of the strided MemRef. // Fields 4 and 5: Sizes and strides of the strided MemRef.
// Store all sizes in the descriptor. Only dynamic sizes are passed in as // Store all sizes in the descriptor. Only dynamic sizes are passed in as
@ -846,18 +933,12 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
// Fill size and stride descriptors in memref. // Fill size and stride descriptors in memref.
for (auto indexedSize : llvm::enumerate(sizes)) { for (auto indexedSize : llvm::enumerate(sizes)) {
int64_t index = indexedSize.index(); int64_t index = indexedSize.index();
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>( memRefDescriptor.setSize(rewriter, loc, index, indexedSize.value());
loc, structType, memRefDescriptor, indexedSize.value(), memRefDescriptor.setStride(rewriter, loc, index, strideValues[index]);
rewriter.getI64ArrayAttr(
{LLVMTypeConverter::kSizePosInMemRefDescriptor, index}));
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
loc, structType, memRefDescriptor, strideValues[index],
rewriter.getI64ArrayAttr(
{LLVMTypeConverter::kStridePosInMemRefDescriptor, index}));
} }
// Return the final value of the descriptor. // Return the final value of the descriptor.
rewriter.replaceOp(op, memRefDescriptor); rewriter.replaceOp(op, {memRefDescriptor});
} }
}; };
@ -947,13 +1028,10 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
/*isVarArg=*/false)); /*isVarArg=*/false));
} }
auto type = transformed.memref()->getType().cast<LLVM::LLVMType>(); MemRefDescriptor memref(transformed.memref());
Type elementPtrType = type.getStructElementType(
LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor);
Value *bufferPtr = extractAllocatedMemRefElementPtr(
rewriter, op->getLoc(), transformed.memref(), elementPtrType);
Value *casted = rewriter.create<LLVM::BitcastOp>( Value *casted = rewriter.create<LLVM::BitcastOp>(
op->getLoc(), getVoidPtrType(), bufferPtr); op->getLoc(), getVoidPtrType(),
memref.allocatedPtr(rewriter, op->getLoc()));
rewriter.replaceOpWithNewOp<LLVM::CallOp>( rewriter.replaceOpWithNewOp<LLVM::CallOp>(
op, ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted); op, ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
return matchSuccess(); return matchSuccess();
@ -1003,10 +1081,8 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
int64_t index = dimOp.getIndex(); int64_t index = dimOp.getIndex();
// Extract dynamic size from the memref descriptor. // Extract dynamic size from the memref descriptor.
if (ShapedType::isDynamic(shape[index])) if (ShapedType::isDynamic(shape[index]))
rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>( rewriter.replaceOp(op, {MemRefDescriptor(transformed.memrefOrTensor())
op, getIndexType(), transformed.memrefOrTensor(), .size(rewriter, op->getLoc(), index)});
rewriter.getI64ArrayAttr(
{LLVMTypeConverter::kSizePosInMemRefDescriptor, index}));
else else
// Use constant for static size. // Use constant for static size.
rewriter.replaceOp( rewriter.replaceOp(
@ -1058,34 +1134,21 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
// This is a strided getElementPtr variant that linearizes subscripts as: // This is a strided getElementPtr variant that linearizes subscripts as:
// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`. // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
Value *getStridedElementPtr(Location loc, Type elementTypePtr, Value *getStridedElementPtr(Location loc, Type elementTypePtr,
Value *memRefDescriptor, Value *descriptor, ArrayRef<Value *> indices,
ArrayRef<Value *> indices,
ArrayRef<int64_t> strides, int64_t offset, ArrayRef<int64_t> strides, int64_t offset,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
auto indexTy = this->getIndexType(); MemRefDescriptor memRefDescriptor(descriptor);
Value *base = this->extractAlignedMemRefElementPtr(
rewriter, loc, memRefDescriptor, elementTypePtr); Value *base = memRefDescriptor.alignedPtr(rewriter, loc);
Value *offsetValue = Value *offsetValue = offset == MemRefType::getDynamicStrideOrOffset()
offset == MemRefType::getDynamicStrideOrOffset() ? memRefDescriptor.offset(rewriter, loc)
? rewriter.create<LLVM::ExtractValueOp>( : this->createIndexConstant(rewriter, loc, offset);
loc, indexTy, memRefDescriptor,
rewriter.getI64ArrayAttr(
LLVMTypeConverter::kOffsetPosInMemRefDescriptor))
: this->createIndexConstant(rewriter, loc, offset);
for (int i = 0, e = indices.size(); i < e; ++i) { for (int i = 0, e = indices.size(); i < e; ++i) {
Value *stride; Value *stride =
if (strides[i] != MemRefType::getDynamicStrideOrOffset()) { strides[i] == MemRefType::getDynamicStrideOrOffset()
// Use static stride. ? memRefDescriptor.stride(rewriter, loc, i)
auto attr = : this->createIndexConstant(rewriter, loc, strides[i]);
rewriter.getIntegerAttr(rewriter.getIndexType(), strides[i]);
stride = rewriter.create<LLVM::ConstantOp>(loc, indexTy, attr);
} else {
// Use dynamic stride.
stride = rewriter.create<LLVM::ExtractValueOp>(
loc, indexTy, memRefDescriptor,
rewriter.getI64ArrayAttr(
{LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
}
Value *additionalOffset = Value *additionalOffset =
rewriter.create<LLVM::MulOp>(loc, indices[i], stride); rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
offsetValue = offsetValue =
@ -1452,74 +1515,45 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
return matchFailure(); return matchFailure();
// Create the descriptor. // Create the descriptor.
Value *desc = rewriter.create<LLVM::UndefOp>(loc, targetDescTy); MemRefDescriptor sourceMemRef(adaptor.source());
auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
// Copy the buffer pointer from the old descriptor to the new one. // Copy the buffer pointer from the old descriptor to the new one.
Value *sourceDescriptor = adaptor.source(); Value *extracted = sourceMemRef.allocatedPtr(rewriter, loc);
Value *extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, sourceElementTy.getPointerTo(), sourceDescriptor,
rewriter.getI64ArrayAttr(
LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
Value *bitcastPtr = rewriter.create<LLVM::BitcastOp>( Value *bitcastPtr = rewriter.create<LLVM::BitcastOp>(
loc, targetElementTy.getPointerTo(), extracted); loc, targetElementTy.getPointerTo(), extracted);
desc = rewriter.create<LLVM::InsertValueOp>( targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
loc, desc, bitcastPtr,
rewriter.getI64ArrayAttr( extracted = sourceMemRef.alignedPtr(rewriter, loc);
LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, sourceElementTy.getPointerTo(), sourceDescriptor,
rewriter.getI64ArrayAttr(
LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
bitcastPtr = rewriter.create<LLVM::BitcastOp>( bitcastPtr = rewriter.create<LLVM::BitcastOp>(
loc, targetElementTy.getPointerTo(), extracted); loc, targetElementTy.getPointerTo(), extracted);
desc = rewriter.create<LLVM::InsertValueOp>( targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
loc, desc, bitcastPtr,
rewriter.getI64ArrayAttr(
LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
// Extract strides needed to compute offset. // Extract strides needed to compute offset.
SmallVector<Value *, 4> strideValues; SmallVector<Value *, 4> strideValues;
strideValues.reserve(viewMemRefType.getRank()); strideValues.reserve(viewMemRefType.getRank());
for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) { for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i)
strideValues.push_back(rewriter.create<LLVM::ExtractValueOp>( strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
loc, getIndexType(), sourceDescriptor,
rewriter.getI64ArrayAttr(
{LLVMTypeConverter::kStridePosInMemRefDescriptor, i})));
}
// Offset. // Offset.
Value *baseOffset = rewriter.create<LLVM::ExtractValueOp>( Value *baseOffset = sourceMemRef.offset(rewriter, loc);
loc, getIndexType(), sourceDescriptor,
rewriter.getI64ArrayAttr(
LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) { for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) {
Value *min = adaptor.offsets()[i]; Value *min = adaptor.offsets()[i];
baseOffset = rewriter.create<LLVM::AddOp>( baseOffset = rewriter.create<LLVM::AddOp>(
loc, baseOffset, loc, baseOffset,
rewriter.create<LLVM::MulOp>(loc, min, strideValues[i])); rewriter.create<LLVM::MulOp>(loc, min, strideValues[i]));
} }
desc = rewriter.create<LLVM::InsertValueOp>( targetMemRef.setOffset(rewriter, loc, baseOffset);
loc, desc, baseOffset,
rewriter.getI64ArrayAttr(
LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
// Update sizes and strides. // Update sizes and strides.
for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
// Update size. targetMemRef.setSize(rewriter, loc, i, adaptor.sizes()[i]);
desc = rewriter.create<LLVM::InsertValueOp>( targetMemRef.setStride(rewriter, loc, i,
loc, desc, adaptor.sizes()[i], rewriter.create<LLVM::MulOp>(
rewriter.getI64ArrayAttr( loc, adaptor.strides()[i], strideValues[i]));
{LLVMTypeConverter::kSizePosInMemRefDescriptor, i}));
// Update stride.
desc = rewriter.create<LLVM::InsertValueOp>(
loc, desc,
rewriter.create<LLVM::MulOp>(loc, adaptor.strides()[i],
strideValues[i]),
rewriter.getI64ArrayAttr(
{LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
} }
rewriter.replaceOp(op, desc); rewriter.replaceOp(op, {targetMemRef});
return matchSuccess(); return matchSuccess();
} }
}; };
@ -1571,10 +1605,6 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
auto loc = op->getLoc(); auto loc = op->getLoc();
auto viewOp = cast<ViewOp>(op); auto viewOp = cast<ViewOp>(op);
ViewOpOperandAdaptor adaptor(operands); ViewOpOperandAdaptor adaptor(operands);
auto sourceMemRefType = viewOp.source()->getType().cast<MemRefType>();
auto sourceElementTy =
lowering.convertType(sourceMemRefType.getElementType())
.dyn_cast<LLVM::LLVMType>();
auto viewMemRefType = viewOp.getType(); auto viewMemRefType = viewOp.getType();
auto targetElementTy = lowering.convertType(viewMemRefType.getElementType()) auto targetElementTy = lowering.convertType(viewMemRefType.getElementType())
@ -1593,32 +1623,20 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
matchFailure(); matchFailure();
// Create the descriptor. // Create the descriptor.
Value *desc = rewriter.create<LLVM::UndefOp>(loc, targetDescTy); MemRefDescriptor sourceMemRef(adaptor.source());
auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
// Field 1: Copy the allocated pointer, used for malloc/free. // Field 1: Copy the allocated pointer, used for malloc/free.
Value *sourceDescriptor = adaptor.source(); Value *extracted = sourceMemRef.allocatedPtr(rewriter, loc);
Value *extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, sourceElementTy.getPointerTo(), sourceDescriptor,
rewriter.getI64ArrayAttr(
LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
Value *bitcastPtr = rewriter.create<LLVM::BitcastOp>( Value *bitcastPtr = rewriter.create<LLVM::BitcastOp>(
loc, targetElementTy.getPointerTo(), extracted); loc, targetElementTy.getPointerTo(), extracted);
desc = rewriter.create<LLVM::InsertValueOp>( targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
loc, desc, bitcastPtr,
rewriter.getI64ArrayAttr(
LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
// Field 2: Copy the actual aligned pointer to payload. // Field 2: Copy the actual aligned pointer to payload.
extracted = rewriter.create<LLVM::ExtractValueOp>( extracted = sourceMemRef.alignedPtr(rewriter, loc);
loc, sourceElementTy.getPointerTo(), sourceDescriptor,
rewriter.getI64ArrayAttr(
LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
bitcastPtr = rewriter.create<LLVM::BitcastOp>( bitcastPtr = rewriter.create<LLVM::BitcastOp>(
loc, targetElementTy.getPointerTo(), extracted); loc, targetElementTy.getPointerTo(), extracted);
desc = rewriter.create<LLVM::InsertValueOp>( targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
loc, desc, bitcastPtr,
rewriter.getI64ArrayAttr(
LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
// Field 3: Copy the offset in aligned pointer. // Field 3: Copy the offset in aligned pointer.
unsigned numDynamicSizes = llvm::size(viewOp.getDynamicSizes()); unsigned numDynamicSizes = llvm::size(viewOp.getDynamicSizes());
@ -1630,14 +1648,11 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
? createIndexConstant(rewriter, loc, offset) ? createIndexConstant(rewriter, loc, offset)
// TODO(ntv): better adaptor. // TODO(ntv): better adaptor.
: sizeAndOffsetOperands.back(); : sizeAndOffsetOperands.back();
desc = rewriter.create<LLVM::InsertValueOp>( targetMemRef.setOffset(rewriter, loc, baseOffset);
loc, desc, baseOffset,
rewriter.getI64ArrayAttr(
LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
// Early exit for 0-D corner case. // Early exit for 0-D corner case.
if (viewMemRefType.getRank() == 0) if (viewMemRefType.getRank() == 0)
return rewriter.replaceOp(op, desc), matchSuccess(); return rewriter.replaceOp(op, {targetMemRef}), matchSuccess();
// Fields 4 and 5: Update sizes and strides. // Fields 4 and 5: Update sizes and strides.
if (strides.back() != 1) if (strides.back() != 1)
@ -1648,20 +1663,14 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
// Update size. // Update size.
Value *size = getSize(rewriter, loc, viewMemRefType.getShape(), Value *size = getSize(rewriter, loc, viewMemRefType.getShape(),
sizeAndOffsetOperands, i); sizeAndOffsetOperands, i);
desc = rewriter.create<LLVM::InsertValueOp>( targetMemRef.setSize(rewriter, loc, i, size);
loc, desc, size,
rewriter.getI64ArrayAttr(
{LLVMTypeConverter::kSizePosInMemRefDescriptor, i}));
// Update stride. // Update stride.
stride = getStride(rewriter, loc, strides, nextSize, stride, i); stride = getStride(rewriter, loc, strides, nextSize, stride, i);
desc = rewriter.create<LLVM::InsertValueOp>( targetMemRef.setStride(rewriter, loc, i, stride);
loc, desc, stride,
rewriter.getI64ArrayAttr(
{LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
nextSize = size; nextSize = size;
} }
rewriter.replaceOp(op, desc); rewriter.replaceOp(op, {targetMemRef});
return matchSuccess(); return matchSuccess();
} }
}; };