forked from OSchip/llvm-project
Use MemRefDescriptor in Linalg-to-LLVM conversion
Following up on the consolidation of MemRef descriptor conversion, update Linalg-to-LLVM conversion to use the helper class that abstracts away the implementation details of the MemRef descriptor. This required MemRefDescriptor to become publicly visible. Since this conversion is heavily EDSC-based, introduce locally an additional wrapper that uses builder and location pointed to by the EDSC context while emitting descriptor manipulation operations. PiperOrigin-RevId: 280429228
This commit is contained in:
parent
a007d4395a
commit
7c28de4aef
|
@ -133,6 +133,61 @@ private:
|
|||
LLVM::LLVMType unwrap(Type type);
|
||||
};
|
||||
|
||||
/// Helper class to produce LLVM dialect operations extracting or inserting
|
||||
/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
|
||||
/// The Value may be null, in which case none of the operations are valid.
|
||||
class MemRefDescriptor {
|
||||
public:
|
||||
/// Construct a helper for the given descriptor value.
|
||||
explicit MemRefDescriptor(Value *descriptor);
|
||||
/// Builds IR creating an `undef` value of the descriptor type.
|
||||
static MemRefDescriptor undef(OpBuilder &builder, Location loc,
|
||||
Type descriptorType);
|
||||
/// Builds IR extracting the allocated pointer from the descriptor.
|
||||
Value *allocatedPtr(OpBuilder &builder, Location loc);
|
||||
/// Builds IR inserting the allocated pointer into the descriptor.
|
||||
void setAllocatedPtr(OpBuilder &builder, Location loc, Value *ptr);
|
||||
|
||||
/// Builds IR extracting the aligned pointer from the descriptor.
|
||||
Value *alignedPtr(OpBuilder &builder, Location loc);
|
||||
|
||||
/// Builds IR inserting the aligned pointer into the descriptor.
|
||||
void setAlignedPtr(OpBuilder &builder, Location loc, Value *ptr);
|
||||
|
||||
/// Builds IR extracting the offset from the descriptor.
|
||||
Value *offset(OpBuilder &builder, Location loc);
|
||||
|
||||
/// Builds IR inserting the offset into the descriptor.
|
||||
void setOffset(OpBuilder &builder, Location loc, Value *offset);
|
||||
|
||||
/// Builds IR extracting the pos-th size from the descriptor.
|
||||
Value *size(OpBuilder &builder, Location loc, unsigned pos);
|
||||
|
||||
/// Builds IR inserting the pos-th size into the descriptor
|
||||
void setSize(OpBuilder &builder, Location loc, unsigned pos, Value *size);
|
||||
|
||||
/// Builds IR extracting the pos-th size from the descriptor.
|
||||
Value *stride(OpBuilder &builder, Location loc, unsigned pos);
|
||||
|
||||
/// Builds IR inserting the pos-th stride into the descriptor
|
||||
void setStride(OpBuilder &builder, Location loc, unsigned pos, Value *stride);
|
||||
|
||||
/*implicit*/ operator Value *() { return value; }
|
||||
|
||||
private:
|
||||
Value *extractPtr(OpBuilder &builder, Location loc, unsigned pos);
|
||||
void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value *ptr);
|
||||
|
||||
// Cached descriptor type.
|
||||
Type structType;
|
||||
|
||||
// Cached index type.
|
||||
Type indexType;
|
||||
|
||||
// Actual descriptor.
|
||||
Value *value;
|
||||
};
|
||||
|
||||
/// Base class for operation conversions targeting the LLVM IR dialect. Provides
|
||||
/// conversion patterns with an access to the containing LLVMLowering for the
|
||||
/// purpose of type conversions.
|
||||
|
|
|
@ -234,126 +234,117 @@ LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
|
|||
PatternBenefit benefit)
|
||||
: ConversionPattern(rootOpName, benefit, context), lowering(lowering_) {}
|
||||
|
||||
/*============================================================================*/
|
||||
/* MemRefDescriptor implementation */
|
||||
/*============================================================================*/
|
||||
|
||||
/// Construct a helper for the given descriptor value.
|
||||
MemRefDescriptor::MemRefDescriptor(Value *descriptor) : value(descriptor) {
|
||||
if (value) {
|
||||
structType = value->getType().cast<LLVM::LLVMType>();
|
||||
indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType(
|
||||
LLVMTypeConverter::kOffsetPosInMemRefDescriptor);
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds IR creating an `undef` value of the descriptor type.
|
||||
MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
|
||||
Type descriptorType) {
|
||||
Value *descriptor =
|
||||
builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>());
|
||||
return MemRefDescriptor(descriptor);
|
||||
}
|
||||
|
||||
/// Builds IR extracting the allocated pointer from the descriptor.
|
||||
Value *MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) {
|
||||
return extractPtr(builder, loc,
|
||||
LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor);
|
||||
}
|
||||
|
||||
/// Builds IR inserting the allocated pointer into the descriptor.
|
||||
void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
|
||||
Value *ptr) {
|
||||
setPtr(builder, loc, LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor,
|
||||
ptr);
|
||||
}
|
||||
|
||||
/// Builds IR extracting the aligned pointer from the descriptor.
|
||||
Value *MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) {
|
||||
return extractPtr(builder, loc,
|
||||
LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
|
||||
}
|
||||
|
||||
/// Builds IR inserting the aligned pointer into the descriptor.
|
||||
void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
|
||||
Value *ptr) {
|
||||
setPtr(builder, loc, LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor,
|
||||
ptr);
|
||||
}
|
||||
|
||||
/// Builds IR extracting the offset from the descriptor.
|
||||
Value *MemRefDescriptor::offset(OpBuilder &builder, Location loc) {
|
||||
return builder.create<LLVM::ExtractValueOp>(
|
||||
loc, indexType, value,
|
||||
builder.getI64ArrayAttr(LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
|
||||
}
|
||||
|
||||
/// Builds IR inserting the offset into the descriptor.
|
||||
void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
|
||||
Value *offset) {
|
||||
value = builder.create<LLVM::InsertValueOp>(
|
||||
loc, structType, value, offset,
|
||||
builder.getI64ArrayAttr(LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
|
||||
}
|
||||
|
||||
/// Builds IR extracting the pos-th size from the descriptor.
|
||||
Value *MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
|
||||
return builder.create<LLVM::ExtractValueOp>(
|
||||
loc, indexType, value,
|
||||
builder.getI64ArrayAttr(
|
||||
{LLVMTypeConverter::kSizePosInMemRefDescriptor, pos}));
|
||||
}
|
||||
|
||||
/// Builds IR inserting the pos-th size into the descriptor
|
||||
void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
|
||||
Value *size) {
|
||||
value = builder.create<LLVM::InsertValueOp>(
|
||||
loc, structType, value, size,
|
||||
builder.getI64ArrayAttr(
|
||||
{LLVMTypeConverter::kSizePosInMemRefDescriptor, pos}));
|
||||
}
|
||||
|
||||
/// Builds IR extracting the pos-th size from the descriptor.
|
||||
Value *MemRefDescriptor::stride(OpBuilder &builder, Location loc,
|
||||
unsigned pos) {
|
||||
return builder.create<LLVM::ExtractValueOp>(
|
||||
loc, indexType, value,
|
||||
builder.getI64ArrayAttr(
|
||||
{LLVMTypeConverter::kStridePosInMemRefDescriptor, pos}));
|
||||
}
|
||||
|
||||
/// Builds IR inserting the pos-th stride into the descriptor
|
||||
void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
|
||||
Value *stride) {
|
||||
value = builder.create<LLVM::InsertValueOp>(
|
||||
loc, structType, value, stride,
|
||||
builder.getI64ArrayAttr(
|
||||
{LLVMTypeConverter::kStridePosInMemRefDescriptor, pos}));
|
||||
}
|
||||
|
||||
Value *MemRefDescriptor::extractPtr(OpBuilder &builder, Location loc,
|
||||
unsigned pos) {
|
||||
Type type = structType.cast<LLVM::LLVMType>().getStructElementType(pos);
|
||||
return builder.create<LLVM::ExtractValueOp>(loc, type, value,
|
||||
builder.getI64ArrayAttr(pos));
|
||||
}
|
||||
|
||||
void MemRefDescriptor::setPtr(OpBuilder &builder, Location loc, unsigned pos,
|
||||
Value *ptr) {
|
||||
value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr,
|
||||
builder.getI64ArrayAttr(pos));
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// Helper class to produce LLVM dialect operations extracting or inserting
|
||||
/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
|
||||
/// The Value may be null, in which case none of the operations are valid.
|
||||
class MemRefDescriptor {
|
||||
public:
|
||||
/// Construct a helper for the given descriptor value.
|
||||
explicit MemRefDescriptor(Value *descriptor) : value(descriptor) {
|
||||
if (value) {
|
||||
structType = value->getType().cast<LLVM::LLVMType>();
|
||||
indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType(
|
||||
LLVMTypeConverter::kOffsetPosInMemRefDescriptor);
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds IR creating an `undef` value of the descriptor type.
|
||||
static MemRefDescriptor undef(OpBuilder &builder, Location loc,
|
||||
Type descriptorType) {
|
||||
Value *descriptor = builder.create<LLVM::UndefOp>(
|
||||
loc, descriptorType.cast<LLVM::LLVMType>());
|
||||
return MemRefDescriptor(descriptor);
|
||||
}
|
||||
|
||||
/// Builds IR extracting the allocated pointer from the descriptor.
|
||||
Value *allocatedPtr(OpBuilder &builder, Location loc) {
|
||||
return extractPtr(builder, loc,
|
||||
LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor);
|
||||
}
|
||||
|
||||
/// Builds IR inserting the allocated pointer into the descriptor.
|
||||
void setAllocatedPtr(OpBuilder &builder, Location loc, Value *ptr) {
|
||||
setPtr(builder, loc, LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor,
|
||||
ptr);
|
||||
}
|
||||
|
||||
/// Builds IR extracting the aligned pointer from the descriptor.
|
||||
Value *alignedPtr(OpBuilder &builder, Location loc) {
|
||||
return extractPtr(builder, loc,
|
||||
LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
|
||||
}
|
||||
|
||||
/// Builds IR inserting the aligned pointer into the descriptor.
|
||||
void setAlignedPtr(OpBuilder &builder, Location loc, Value *ptr) {
|
||||
setPtr(builder, loc, LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor,
|
||||
ptr);
|
||||
}
|
||||
|
||||
/// Builds IR extracting the offset from the descriptor.
|
||||
Value *offset(OpBuilder &builder, Location loc) {
|
||||
return builder.create<LLVM::ExtractValueOp>(
|
||||
loc, indexType, value,
|
||||
builder.getI64ArrayAttr(
|
||||
LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
|
||||
}
|
||||
|
||||
/// Builds IR inserting the offset into the descriptor.
|
||||
void setOffset(OpBuilder &builder, Location loc, Value *offset) {
|
||||
value = builder.create<LLVM::InsertValueOp>(
|
||||
loc, structType, value, offset,
|
||||
builder.getI64ArrayAttr(
|
||||
LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
|
||||
}
|
||||
|
||||
/// Builds IR extracting the pos-th size from the descriptor.
|
||||
Value *size(OpBuilder &builder, Location loc, unsigned pos) {
|
||||
return builder.create<LLVM::ExtractValueOp>(
|
||||
loc, indexType, value,
|
||||
builder.getI64ArrayAttr(
|
||||
{LLVMTypeConverter::kSizePosInMemRefDescriptor, pos}));
|
||||
}
|
||||
|
||||
/// Builds IR inserting the pos-th size into the descriptor
|
||||
void setSize(OpBuilder &builder, Location loc, unsigned pos, Value *size) {
|
||||
value = builder.create<LLVM::InsertValueOp>(
|
||||
loc, structType, value, size,
|
||||
builder.getI64ArrayAttr(
|
||||
{LLVMTypeConverter::kSizePosInMemRefDescriptor, pos}));
|
||||
}
|
||||
|
||||
/// Builds IR extracting the pos-th size from the descriptor.
|
||||
Value *stride(OpBuilder &builder, Location loc, unsigned pos) {
|
||||
return builder.create<LLVM::ExtractValueOp>(
|
||||
loc, indexType, value,
|
||||
builder.getI64ArrayAttr(
|
||||
{LLVMTypeConverter::kStridePosInMemRefDescriptor, pos}));
|
||||
}
|
||||
|
||||
/// Builds IR inserting the pos-th stride into the descriptor
|
||||
void setStride(OpBuilder &builder, Location loc, unsigned pos,
|
||||
Value *stride) {
|
||||
value = builder.create<LLVM::InsertValueOp>(
|
||||
loc, structType, value, stride,
|
||||
builder.getI64ArrayAttr(
|
||||
{LLVMTypeConverter::kStridePosInMemRefDescriptor, pos}));
|
||||
}
|
||||
|
||||
/*implicit*/ operator Value *() { return value; }
|
||||
|
||||
private:
|
||||
Value *extractPtr(OpBuilder &builder, Location loc, unsigned pos) {
|
||||
Type type = structType.getStructElementType(pos);
|
||||
return builder.create<LLVM::ExtractValueOp>(loc, type, value,
|
||||
builder.getI64ArrayAttr(pos));
|
||||
}
|
||||
|
||||
void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value *ptr) {
|
||||
value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr,
|
||||
builder.getI64ArrayAttr(pos));
|
||||
}
|
||||
|
||||
// Cached descriptor type.
|
||||
LLVM::LLVMType structType;
|
||||
|
||||
// Cached index type.
|
||||
LLVM::LLVMType indexType;
|
||||
|
||||
// Actual descriptor.
|
||||
Value *value;
|
||||
};
|
||||
|
||||
// Base class for Standard to LLVM IR op conversions. Matches the Op type
|
||||
// provided as template argument. Carries a reference to the LLVM dialect in
|
||||
// case it is necessary for rewriters.
|
||||
|
|
|
@ -128,33 +128,33 @@ static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) {
|
|||
}
|
||||
|
||||
namespace {
|
||||
/// Factor out the common information for all view conversions:
|
||||
/// 1. common types in (standard and LLVM dialects)
|
||||
/// 2. `pos` method
|
||||
/// 3. view descriptor construction `desc`.
|
||||
/// EDSC-compatible wrapper for MemRefDescriptor.
|
||||
class BaseViewConversionHelper {
|
||||
public:
|
||||
BaseViewConversionHelper(Location loc, MemRefType memRefType,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
LLVMTypeConverter &lowering)
|
||||
: zeroDMemRef(memRefType.getRank() == 0),
|
||||
elementTy(getPtrToElementType(memRefType, lowering)),
|
||||
int64Ty(
|
||||
lowering.convertType(rewriter.getIntegerType(64)).cast<LLVMType>()),
|
||||
desc(nullptr), rewriter(rewriter) {
|
||||
assert(isStrided(memRefType) && "expected strided memref type");
|
||||
viewDescriptorTy = lowering.convertType(memRefType).cast<LLVMType>();
|
||||
desc = rewriter.create<LLVM::UndefOp>(loc, viewDescriptorTy);
|
||||
}
|
||||
BaseViewConversionHelper(Type type)
|
||||
: d(MemRefDescriptor::undef(rewriter(), loc(), type)) {}
|
||||
|
||||
ArrayAttr pos(ArrayRef<int64_t> values) const {
|
||||
return rewriter.getI64ArrayAttr(values);
|
||||
};
|
||||
BaseViewConversionHelper(Value *v) : d(v) {}
|
||||
|
||||
bool zeroDMemRef;
|
||||
LLVMType elementTy, int64Ty, viewDescriptorTy;
|
||||
Value *desc;
|
||||
ConversionPatternRewriter &rewriter;
|
||||
/// Wrappers around MemRefDescriptor that use EDSC builder and location.
|
||||
Value *allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); }
|
||||
void setAllocatedPtr(Value *v) { d.setAllocatedPtr(rewriter(), loc(), v); }
|
||||
Value *alignedPtr() { return d.alignedPtr(rewriter(), loc()); }
|
||||
void setAlignedPtr(Value *v) { d.setAlignedPtr(rewriter(), loc(), v); }
|
||||
Value *offset() { return d.offset(rewriter(), loc()); }
|
||||
void setOffset(Value *v) { d.setOffset(rewriter(), loc(), v); }
|
||||
Value *size(unsigned i) { return d.size(rewriter(), loc(), i); }
|
||||
void setSize(unsigned i, Value *v) { d.setSize(rewriter(), loc(), i, v); }
|
||||
Value *stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
|
||||
void setStride(unsigned i, Value *v) { d.setStride(rewriter(), loc(), i, v); }
|
||||
|
||||
operator Value *() { return d; }
|
||||
|
||||
private:
|
||||
OpBuilder &rewriter() { return ScopedContext::getBuilder(); }
|
||||
Location loc() { return ScopedContext::getLocation(); }
|
||||
|
||||
MemRefDescriptor d;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
@ -200,53 +200,46 @@ public:
|
|||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
edsc::ScopedContext context(rewriter, op->getLoc());
|
||||
SliceOpOperandAdaptor adaptor(operands);
|
||||
Value *baseDesc = adaptor.view();
|
||||
BaseViewConversionHelper baseDesc(adaptor.view());
|
||||
|
||||
auto sliceOp = cast<SliceOp>(op);
|
||||
auto memRefType = sliceOp.getBaseViewType();
|
||||
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64))
|
||||
.cast<LLVM::LLVMType>();
|
||||
|
||||
BaseViewConversionHelper helper(op->getLoc(), sliceOp.getViewType(),
|
||||
rewriter, lowering);
|
||||
LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty;
|
||||
Value *desc = helper.desc;
|
||||
|
||||
edsc::ScopedContext context(rewriter, op->getLoc());
|
||||
BaseViewConversionHelper desc(lowering.convertType(sliceOp.getViewType()));
|
||||
|
||||
// TODO(ntv): extract sizes and emit asserts.
|
||||
SmallVector<Value *, 4> strides(memRefType.getRank());
|
||||
for (int i = 0, e = memRefType.getRank(); i < e; ++i)
|
||||
strides[i] = extractvalue(
|
||||
int64Ty, baseDesc,
|
||||
helper.pos({LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
|
||||
strides[i] = baseDesc.stride(i);
|
||||
|
||||
auto pos = [&rewriter](ArrayRef<int64_t> values) {
|
||||
return rewriter.getI64ArrayAttr(values);
|
||||
};
|
||||
|
||||
// Compute base offset.
|
||||
Value *baseOffset = extractvalue(
|
||||
int64Ty, baseDesc,
|
||||
helper.pos(LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
|
||||
Value *baseOffset = baseDesc.offset();
|
||||
for (int i = 0, e = memRefType.getRank(); i < e; ++i) {
|
||||
Value *indexing = adaptor.indexings()[i];
|
||||
Value *min = indexing;
|
||||
if (sliceOp.indexing(i)->getType().isa<RangeType>())
|
||||
min = extractvalue(int64Ty, indexing, helper.pos(0));
|
||||
min = extractvalue(int64Ty, indexing, pos(0));
|
||||
baseOffset = add(baseOffset, mul(min, strides[i]));
|
||||
}
|
||||
|
||||
// Insert the base and aligned pointers.
|
||||
auto ptrPos =
|
||||
helper.pos(LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor);
|
||||
desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);
|
||||
ptrPos = helper.pos(LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
|
||||
desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);
|
||||
desc.setAllocatedPtr(baseDesc.allocatedPtr());
|
||||
desc.setAlignedPtr(baseDesc.alignedPtr());
|
||||
|
||||
// Insert base offset.
|
||||
desc = insertvalue(
|
||||
desc, baseOffset,
|
||||
helper.pos(LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
|
||||
desc.setOffset(baseOffset);
|
||||
|
||||
// Corner case, no sizes or strides: early return the descriptor.
|
||||
if (helper.zeroDMemRef)
|
||||
return rewriter.replaceOp(op, desc), matchSuccess();
|
||||
if (sliceOp.getViewType().getRank() == 0)
|
||||
return rewriter.replaceOp(op, {desc}), matchSuccess();
|
||||
|
||||
Value *zero =
|
||||
constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
|
||||
|
@ -258,12 +251,11 @@ public:
|
|||
if (indexing->getType().isa<RangeType>()) {
|
||||
int rank = en.index();
|
||||
Value *rangeDescriptor = adaptor.indexings()[rank];
|
||||
Value *min = extractvalue(int64Ty, rangeDescriptor, helper.pos(0));
|
||||
Value *max = extractvalue(int64Ty, rangeDescriptor, helper.pos(1));
|
||||
Value *step = extractvalue(int64Ty, rangeDescriptor, helper.pos(2));
|
||||
Value *baseSize = extractvalue(
|
||||
int64Ty, baseDesc,
|
||||
helper.pos({LLVMTypeConverter::kSizePosInMemRefDescriptor, rank}));
|
||||
Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
|
||||
Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
|
||||
Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
|
||||
Value *baseSize = baseDesc.size(rank);
|
||||
|
||||
// Bound upper by base view upper bound.
|
||||
max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
|
||||
baseSize);
|
||||
|
@ -272,19 +264,13 @@ public:
|
|||
size =
|
||||
llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
|
||||
Value *stride = mul(strides[rank], step);
|
||||
desc = insertvalue(
|
||||
desc, size,
|
||||
helper.pos(
|
||||
{LLVMTypeConverter::kSizePosInMemRefDescriptor, numNewDims}));
|
||||
desc = insertvalue(
|
||||
desc, stride,
|
||||
helper.pos(
|
||||
{LLVMTypeConverter::kStridePosInMemRefDescriptor, numNewDims}));
|
||||
desc.setSize(numNewDims, size);
|
||||
desc.setStride(numNewDims, stride);
|
||||
++numNewDims;
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, desc);
|
||||
rewriter.replaceOp(op, {desc});
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
@ -306,56 +292,35 @@ public:
|
|||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Initialize the common boilerplate and alloca at the top of the FuncOp.
|
||||
edsc::ScopedContext context(rewriter, op->getLoc());
|
||||
TransposeOpOperandAdaptor adaptor(operands);
|
||||
Value *baseDesc = adaptor.view();
|
||||
BaseViewConversionHelper baseDesc(adaptor.view());
|
||||
|
||||
auto transposeOp = cast<TransposeOp>(op);
|
||||
// No permutation, early exit.
|
||||
if (transposeOp.permutation().isIdentity())
|
||||
return rewriter.replaceOp(op, baseDesc), matchSuccess();
|
||||
return rewriter.replaceOp(op, {baseDesc}), matchSuccess();
|
||||
|
||||
BaseViewConversionHelper helper(op->getLoc(), transposeOp.getViewType(),
|
||||
rewriter, lowering);
|
||||
LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty;
|
||||
Value *desc = helper.desc;
|
||||
BaseViewConversionHelper desc(
|
||||
lowering.convertType(transposeOp.getViewType()));
|
||||
|
||||
edsc::ScopedContext context(rewriter, op->getLoc());
|
||||
// Copy the base and aligned pointers from the old descriptor to the new
|
||||
// one.
|
||||
ArrayAttr ptrPos =
|
||||
helper.pos(LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor);
|
||||
desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);
|
||||
ptrPos = helper.pos(LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
|
||||
desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);
|
||||
desc.setAllocatedPtr(baseDesc.allocatedPtr());
|
||||
desc.setAlignedPtr(baseDesc.alignedPtr());
|
||||
|
||||
// Copy the offset pointer from the old descriptor to the new one.
|
||||
ArrayAttr offPos =
|
||||
helper.pos(LLVMTypeConverter::kOffsetPosInMemRefDescriptor);
|
||||
desc = insertvalue(desc, extractvalue(int64Ty, baseDesc, offPos), offPos);
|
||||
desc.setOffset(baseDesc.offset());
|
||||
|
||||
// Iterate over the dimensions and apply size/stride permutation.
|
||||
for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) {
|
||||
int sourcePos = en.index();
|
||||
int targetPos = en.value().cast<AffineDimExpr>().getPosition();
|
||||
Value *size = extractvalue(
|
||||
int64Ty, baseDesc,
|
||||
helper.pos(
|
||||
{LLVMTypeConverter::kSizePosInMemRefDescriptor, sourcePos}));
|
||||
desc =
|
||||
insertvalue(desc, size,
|
||||
helper.pos({LLVMTypeConverter::kSizePosInMemRefDescriptor,
|
||||
targetPos}));
|
||||
Value *stride = extractvalue(
|
||||
int64Ty, baseDesc,
|
||||
helper.pos(
|
||||
{LLVMTypeConverter::kStridePosInMemRefDescriptor, sourcePos}));
|
||||
desc = insertvalue(
|
||||
desc, stride,
|
||||
helper.pos(
|
||||
{LLVMTypeConverter::kStridePosInMemRefDescriptor, targetPos}));
|
||||
desc.setSize(targetPos, baseDesc.size(sourcePos));
|
||||
desc.setStride(targetPos, baseDesc.stride(sourcePos));
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, desc);
|
||||
rewriter.replaceOp(op, {desc});
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue