Use MemRefDescriptor in Linalg-to-LLVM conversion

Following up on the consolidation of MemRef descriptor conversion, update
Linalg-to-LLVM conversion to use the helper class that abstracts away the
implementation details of the MemRef descriptor. This required MemRefDescriptor
to become publicly visible. Since this conversion is heavily EDSC-based,
introduce locally an additional wrapper that uses builder and location pointed
to by the EDSC context while emitting descriptor manipulation operations.

PiperOrigin-RevId: 280429228
This commit is contained in:
Alex Zinenko 2019-11-14 08:03:39 -08:00 committed by A. Unique TensorFlower
parent a007d4395a
commit 7c28de4aef
3 changed files with 224 additions and 213 deletions

View File

@ -133,6 +133,61 @@ private:
LLVM::LLVMType unwrap(Type type);
};
/// Helper class to produce LLVM dialect operations extracting or inserting
/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
/// The Value may be null, in which case none of the operations are valid.
class MemRefDescriptor {
public:
/// Construct a helper for the given descriptor value.
explicit MemRefDescriptor(Value *descriptor);
/// Builds IR creating an `undef` value of the descriptor type.
static MemRefDescriptor undef(OpBuilder &builder, Location loc,
Type descriptorType);
/// Builds IR extracting the allocated pointer from the descriptor.
Value *allocatedPtr(OpBuilder &builder, Location loc);
/// Builds IR inserting the allocated pointer into the descriptor.
void setAllocatedPtr(OpBuilder &builder, Location loc, Value *ptr);
/// Builds IR extracting the aligned pointer from the descriptor.
Value *alignedPtr(OpBuilder &builder, Location loc);
/// Builds IR inserting the aligned pointer into the descriptor.
void setAlignedPtr(OpBuilder &builder, Location loc, Value *ptr);
/// Builds IR extracting the offset from the descriptor.
Value *offset(OpBuilder &builder, Location loc);
/// Builds IR inserting the offset into the descriptor.
void setOffset(OpBuilder &builder, Location loc, Value *offset);
/// Builds IR extracting the pos-th size from the descriptor.
Value *size(OpBuilder &builder, Location loc, unsigned pos);
/// Builds IR inserting the pos-th size into the descriptor
void setSize(OpBuilder &builder, Location loc, unsigned pos, Value *size);
/// Builds IR extracting the pos-th size from the descriptor.
Value *stride(OpBuilder &builder, Location loc, unsigned pos);
/// Builds IR inserting the pos-th stride into the descriptor
void setStride(OpBuilder &builder, Location loc, unsigned pos, Value *stride);
/*implicit*/ operator Value *() { return value; }
private:
Value *extractPtr(OpBuilder &builder, Location loc, unsigned pos);
void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value *ptr);
// Cached descriptor type.
Type structType;
// Cached index type.
Type indexType;
// Actual descriptor.
Value *value;
};
/// Base class for operation conversions targeting the LLVM IR dialect. Provides
/// conversion patterns with an access to the containing LLVMLowering for the
/// purpose of type conversions.

View File

@ -234,126 +234,117 @@ LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
PatternBenefit benefit)
: ConversionPattern(rootOpName, benefit, context), lowering(lowering_) {}
/*============================================================================*/
/* MemRefDescriptor implementation */
/*============================================================================*/
/// Construct a helper for the given descriptor value.
MemRefDescriptor::MemRefDescriptor(Value *descriptor) : value(descriptor) {
if (value) {
structType = value->getType().cast<LLVM::LLVMType>();
indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType(
LLVMTypeConverter::kOffsetPosInMemRefDescriptor);
}
}
/// Builds IR creating an `undef` value of the descriptor type.
MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
Type descriptorType) {
Value *descriptor =
builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>());
return MemRefDescriptor(descriptor);
}
/// Builds IR extracting the allocated pointer from the descriptor.
Value *MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) {
return extractPtr(builder, loc,
LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor);
}
/// Builds IR inserting the allocated pointer into the descriptor.
void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
Value *ptr) {
setPtr(builder, loc, LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor,
ptr);
}
/// Builds IR extracting the aligned pointer from the descriptor.
Value *MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) {
return extractPtr(builder, loc,
LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
}
/// Builds IR inserting the aligned pointer into the descriptor.
void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
Value *ptr) {
setPtr(builder, loc, LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor,
ptr);
}
/// Builds IR extracting the offset from the descriptor.
Value *MemRefDescriptor::offset(OpBuilder &builder, Location loc) {
return builder.create<LLVM::ExtractValueOp>(
loc, indexType, value,
builder.getI64ArrayAttr(LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
}
/// Builds IR inserting the offset into the descriptor.
void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
Value *offset) {
value = builder.create<LLVM::InsertValueOp>(
loc, structType, value, offset,
builder.getI64ArrayAttr(LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
}
/// Builds IR extracting the pos-th size from the descriptor.
Value *MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
return builder.create<LLVM::ExtractValueOp>(
loc, indexType, value,
builder.getI64ArrayAttr(
{LLVMTypeConverter::kSizePosInMemRefDescriptor, pos}));
}
/// Builds IR inserting the pos-th size into the descriptor
void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
Value *size) {
value = builder.create<LLVM::InsertValueOp>(
loc, structType, value, size,
builder.getI64ArrayAttr(
{LLVMTypeConverter::kSizePosInMemRefDescriptor, pos}));
}
/// Builds IR extracting the pos-th size from the descriptor.
Value *MemRefDescriptor::stride(OpBuilder &builder, Location loc,
unsigned pos) {
return builder.create<LLVM::ExtractValueOp>(
loc, indexType, value,
builder.getI64ArrayAttr(
{LLVMTypeConverter::kStridePosInMemRefDescriptor, pos}));
}
/// Builds IR inserting the pos-th stride into the descriptor
void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
Value *stride) {
value = builder.create<LLVM::InsertValueOp>(
loc, structType, value, stride,
builder.getI64ArrayAttr(
{LLVMTypeConverter::kStridePosInMemRefDescriptor, pos}));
}
Value *MemRefDescriptor::extractPtr(OpBuilder &builder, Location loc,
unsigned pos) {
Type type = structType.cast<LLVM::LLVMType>().getStructElementType(pos);
return builder.create<LLVM::ExtractValueOp>(loc, type, value,
builder.getI64ArrayAttr(pos));
}
void MemRefDescriptor::setPtr(OpBuilder &builder, Location loc, unsigned pos,
Value *ptr) {
value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr,
builder.getI64ArrayAttr(pos));
}
namespace {
/// Helper class to produce LLVM dialect operations extracting or inserting
/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
/// The Value may be null, in which case none of the operations are valid.
class MemRefDescriptor {
public:
/// Construct a helper for the given descriptor value.
explicit MemRefDescriptor(Value *descriptor) : value(descriptor) {
if (value) {
structType = value->getType().cast<LLVM::LLVMType>();
indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType(
LLVMTypeConverter::kOffsetPosInMemRefDescriptor);
}
}
/// Builds IR creating an `undef` value of the descriptor type.
static MemRefDescriptor undef(OpBuilder &builder, Location loc,
Type descriptorType) {
Value *descriptor = builder.create<LLVM::UndefOp>(
loc, descriptorType.cast<LLVM::LLVMType>());
return MemRefDescriptor(descriptor);
}
/// Builds IR extracting the allocated pointer from the descriptor.
Value *allocatedPtr(OpBuilder &builder, Location loc) {
return extractPtr(builder, loc,
LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor);
}
/// Builds IR inserting the allocated pointer into the descriptor.
void setAllocatedPtr(OpBuilder &builder, Location loc, Value *ptr) {
setPtr(builder, loc, LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor,
ptr);
}
/// Builds IR extracting the aligned pointer from the descriptor.
Value *alignedPtr(OpBuilder &builder, Location loc) {
return extractPtr(builder, loc,
LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
}
/// Builds IR inserting the aligned pointer into the descriptor.
void setAlignedPtr(OpBuilder &builder, Location loc, Value *ptr) {
setPtr(builder, loc, LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor,
ptr);
}
/// Builds IR extracting the offset from the descriptor.
Value *offset(OpBuilder &builder, Location loc) {
return builder.create<LLVM::ExtractValueOp>(
loc, indexType, value,
builder.getI64ArrayAttr(
LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
}
/// Builds IR inserting the offset into the descriptor.
void setOffset(OpBuilder &builder, Location loc, Value *offset) {
value = builder.create<LLVM::InsertValueOp>(
loc, structType, value, offset,
builder.getI64ArrayAttr(
LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
}
/// Builds IR extracting the pos-th size from the descriptor.
Value *size(OpBuilder &builder, Location loc, unsigned pos) {
return builder.create<LLVM::ExtractValueOp>(
loc, indexType, value,
builder.getI64ArrayAttr(
{LLVMTypeConverter::kSizePosInMemRefDescriptor, pos}));
}
/// Builds IR inserting the pos-th size into the descriptor
void setSize(OpBuilder &builder, Location loc, unsigned pos, Value *size) {
value = builder.create<LLVM::InsertValueOp>(
loc, structType, value, size,
builder.getI64ArrayAttr(
{LLVMTypeConverter::kSizePosInMemRefDescriptor, pos}));
}
/// Builds IR extracting the pos-th size from the descriptor.
Value *stride(OpBuilder &builder, Location loc, unsigned pos) {
return builder.create<LLVM::ExtractValueOp>(
loc, indexType, value,
builder.getI64ArrayAttr(
{LLVMTypeConverter::kStridePosInMemRefDescriptor, pos}));
}
/// Builds IR inserting the pos-th stride into the descriptor
void setStride(OpBuilder &builder, Location loc, unsigned pos,
Value *stride) {
value = builder.create<LLVM::InsertValueOp>(
loc, structType, value, stride,
builder.getI64ArrayAttr(
{LLVMTypeConverter::kStridePosInMemRefDescriptor, pos}));
}
/*implicit*/ operator Value *() { return value; }
private:
Value *extractPtr(OpBuilder &builder, Location loc, unsigned pos) {
Type type = structType.getStructElementType(pos);
return builder.create<LLVM::ExtractValueOp>(loc, type, value,
builder.getI64ArrayAttr(pos));
}
void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value *ptr) {
value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr,
builder.getI64ArrayAttr(pos));
}
// Cached descriptor type.
LLVM::LLVMType structType;
// Cached index type.
LLVM::LLVMType indexType;
// Actual descriptor.
Value *value;
};
// Base class for Standard to LLVM IR op conversions. Matches the Op type
// provided as template argument. Carries a reference to the LLVM dialect in
// case it is necessary for rewriters.

View File

@ -128,33 +128,33 @@ static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) {
}
namespace {
/// Factor out the common information for all view conversions:
/// 1. common types in (standard and LLVM dialects)
/// 2. `pos` method
/// 3. view descriptor construction `desc`.
/// EDSC-compatible wrapper for MemRefDescriptor.
class BaseViewConversionHelper {
public:
BaseViewConversionHelper(Location loc, MemRefType memRefType,
ConversionPatternRewriter &rewriter,
LLVMTypeConverter &lowering)
: zeroDMemRef(memRefType.getRank() == 0),
elementTy(getPtrToElementType(memRefType, lowering)),
int64Ty(
lowering.convertType(rewriter.getIntegerType(64)).cast<LLVMType>()),
desc(nullptr), rewriter(rewriter) {
assert(isStrided(memRefType) && "expected strided memref type");
viewDescriptorTy = lowering.convertType(memRefType).cast<LLVMType>();
desc = rewriter.create<LLVM::UndefOp>(loc, viewDescriptorTy);
}
BaseViewConversionHelper(Type type)
: d(MemRefDescriptor::undef(rewriter(), loc(), type)) {}
ArrayAttr pos(ArrayRef<int64_t> values) const {
return rewriter.getI64ArrayAttr(values);
};
BaseViewConversionHelper(Value *v) : d(v) {}
bool zeroDMemRef;
LLVMType elementTy, int64Ty, viewDescriptorTy;
Value *desc;
ConversionPatternRewriter &rewriter;
/// Wrappers around MemRefDescriptor that use EDSC builder and location.
Value *allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); }
void setAllocatedPtr(Value *v) { d.setAllocatedPtr(rewriter(), loc(), v); }
Value *alignedPtr() { return d.alignedPtr(rewriter(), loc()); }
void setAlignedPtr(Value *v) { d.setAlignedPtr(rewriter(), loc(), v); }
Value *offset() { return d.offset(rewriter(), loc()); }
void setOffset(Value *v) { d.setOffset(rewriter(), loc(), v); }
Value *size(unsigned i) { return d.size(rewriter(), loc(), i); }
void setSize(unsigned i, Value *v) { d.setSize(rewriter(), loc(), i, v); }
Value *stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
void setStride(unsigned i, Value *v) { d.setStride(rewriter(), loc(), i, v); }
operator Value *() { return d; }
private:
OpBuilder &rewriter() { return ScopedContext::getBuilder(); }
Location loc() { return ScopedContext::getLocation(); }
MemRefDescriptor d;
};
} // namespace
@ -200,53 +200,46 @@ public:
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
edsc::ScopedContext context(rewriter, op->getLoc());
SliceOpOperandAdaptor adaptor(operands);
Value *baseDesc = adaptor.view();
BaseViewConversionHelper baseDesc(adaptor.view());
auto sliceOp = cast<SliceOp>(op);
auto memRefType = sliceOp.getBaseViewType();
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64))
.cast<LLVM::LLVMType>();
BaseViewConversionHelper helper(op->getLoc(), sliceOp.getViewType(),
rewriter, lowering);
LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty;
Value *desc = helper.desc;
edsc::ScopedContext context(rewriter, op->getLoc());
BaseViewConversionHelper desc(lowering.convertType(sliceOp.getViewType()));
// TODO(ntv): extract sizes and emit asserts.
SmallVector<Value *, 4> strides(memRefType.getRank());
for (int i = 0, e = memRefType.getRank(); i < e; ++i)
strides[i] = extractvalue(
int64Ty, baseDesc,
helper.pos({LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
strides[i] = baseDesc.stride(i);
auto pos = [&rewriter](ArrayRef<int64_t> values) {
return rewriter.getI64ArrayAttr(values);
};
// Compute base offset.
Value *baseOffset = extractvalue(
int64Ty, baseDesc,
helper.pos(LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
Value *baseOffset = baseDesc.offset();
for (int i = 0, e = memRefType.getRank(); i < e; ++i) {
Value *indexing = adaptor.indexings()[i];
Value *min = indexing;
if (sliceOp.indexing(i)->getType().isa<RangeType>())
min = extractvalue(int64Ty, indexing, helper.pos(0));
min = extractvalue(int64Ty, indexing, pos(0));
baseOffset = add(baseOffset, mul(min, strides[i]));
}
// Insert the base and aligned pointers.
auto ptrPos =
helper.pos(LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor);
desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);
ptrPos = helper.pos(LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);
desc.setAllocatedPtr(baseDesc.allocatedPtr());
desc.setAlignedPtr(baseDesc.alignedPtr());
// Insert base offset.
desc = insertvalue(
desc, baseOffset,
helper.pos(LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
desc.setOffset(baseOffset);
// Corner case, no sizes or strides: early return the descriptor.
if (helper.zeroDMemRef)
return rewriter.replaceOp(op, desc), matchSuccess();
if (sliceOp.getViewType().getRank() == 0)
return rewriter.replaceOp(op, {desc}), matchSuccess();
Value *zero =
constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
@ -258,12 +251,11 @@ public:
if (indexing->getType().isa<RangeType>()) {
int rank = en.index();
Value *rangeDescriptor = adaptor.indexings()[rank];
Value *min = extractvalue(int64Ty, rangeDescriptor, helper.pos(0));
Value *max = extractvalue(int64Ty, rangeDescriptor, helper.pos(1));
Value *step = extractvalue(int64Ty, rangeDescriptor, helper.pos(2));
Value *baseSize = extractvalue(
int64Ty, baseDesc,
helper.pos({LLVMTypeConverter::kSizePosInMemRefDescriptor, rank}));
Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
Value *baseSize = baseDesc.size(rank);
// Bound upper by base view upper bound.
max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
baseSize);
@ -272,19 +264,13 @@ public:
size =
llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
Value *stride = mul(strides[rank], step);
desc = insertvalue(
desc, size,
helper.pos(
{LLVMTypeConverter::kSizePosInMemRefDescriptor, numNewDims}));
desc = insertvalue(
desc, stride,
helper.pos(
{LLVMTypeConverter::kStridePosInMemRefDescriptor, numNewDims}));
desc.setSize(numNewDims, size);
desc.setStride(numNewDims, stride);
++numNewDims;
}
}
rewriter.replaceOp(op, desc);
rewriter.replaceOp(op, {desc});
return matchSuccess();
}
};
@ -306,56 +292,35 @@ public:
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
// Initialize the common boilerplate and alloca at the top of the FuncOp.
edsc::ScopedContext context(rewriter, op->getLoc());
TransposeOpOperandAdaptor adaptor(operands);
Value *baseDesc = adaptor.view();
BaseViewConversionHelper baseDesc(adaptor.view());
auto transposeOp = cast<TransposeOp>(op);
// No permutation, early exit.
if (transposeOp.permutation().isIdentity())
return rewriter.replaceOp(op, baseDesc), matchSuccess();
return rewriter.replaceOp(op, {baseDesc}), matchSuccess();
BaseViewConversionHelper helper(op->getLoc(), transposeOp.getViewType(),
rewriter, lowering);
LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty;
Value *desc = helper.desc;
BaseViewConversionHelper desc(
lowering.convertType(transposeOp.getViewType()));
edsc::ScopedContext context(rewriter, op->getLoc());
// Copy the base and aligned pointers from the old descriptor to the new
// one.
ArrayAttr ptrPos =
helper.pos(LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor);
desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);
ptrPos = helper.pos(LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);
desc.setAllocatedPtr(baseDesc.allocatedPtr());
desc.setAlignedPtr(baseDesc.alignedPtr());
// Copy the offset pointer from the old descriptor to the new one.
ArrayAttr offPos =
helper.pos(LLVMTypeConverter::kOffsetPosInMemRefDescriptor);
desc = insertvalue(desc, extractvalue(int64Ty, baseDesc, offPos), offPos);
desc.setOffset(baseDesc.offset());
// Iterate over the dimensions and apply size/stride permutation.
for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) {
int sourcePos = en.index();
int targetPos = en.value().cast<AffineDimExpr>().getPosition();
Value *size = extractvalue(
int64Ty, baseDesc,
helper.pos(
{LLVMTypeConverter::kSizePosInMemRefDescriptor, sourcePos}));
desc =
insertvalue(desc, size,
helper.pos({LLVMTypeConverter::kSizePosInMemRefDescriptor,
targetPos}));
Value *stride = extractvalue(
int64Ty, baseDesc,
helper.pos(
{LLVMTypeConverter::kStridePosInMemRefDescriptor, sourcePos}));
desc = insertvalue(
desc, stride,
helper.pos(
{LLVMTypeConverter::kStridePosInMemRefDescriptor, targetPos}));
desc.setSize(targetPos, baseDesc.size(sourcePos));
desc.setStride(targetPos, baseDesc.stride(sourcePos));
}
rewriter.replaceOp(op, desc);
rewriter.replaceOp(op, {desc});
return matchSuccess();
}
};