Convert MemRefCastOp to the LLVM IR dialect

Add support for converting `memref_cast` operations into the LLVM IR dialect.
This goes beyond want is currently implemented in the MLIR standard ops to LLVM
IR translation, but follows the general principles of the memref descriptors.
A memref cast creates a new descriptor containing the same buffer pointer but a
potentially different number of dynamic sizes (as many as dynamic dimensions in
the target memref type).  The lowering copies the buffer pointer to the new
descriptor and inserts dynamic sizes to it.  If the size is static in the
source type, a constant value is inserted as the dynamic size, otherwise a
dynamic value is copied from the source descriptor, taking into account the
difference in dynamic size positions in the descriptor.

PiperOrigin-RevId: 233082035
This commit is contained in:
Alex Zinenko 2019-02-08 10:20:01 -08:00 committed by jpienaar
parent 366ebcf6aa
commit d7e6b33e93
2 changed files with 149 additions and 9 deletions

View File

@ -48,6 +48,11 @@ public:
// Dispatches to the private functions below based on the actual type.
static Type convert(Type t, llvm::Module &llvmModule);
// Convert the element type of the memref `t` to to an LLVM type, get a
// pointer LLVM type pointing to the converted `t`, wrap it into the MLIR LLVM
// dialect type and return.
static Type getMemRefElementPtrType(MemRefType t, llvm::Module &llvmModule);
// Convert a non-empty list of types to an LLVM IR dialect type wrapping an
// LLVM IR structure type, elements of which are formed by converting
// individual types in the given list. Register the type in the `llvmModule`.
@ -264,6 +269,16 @@ Type TypeConverter::convert(Type t, llvm::Module &module) {
return TypeConverter(module, t.getContext()).convertType(t);
}
Type TypeConverter::getMemRefElementPtrType(MemRefType t,
llvm::Module &module) {
auto elementType = t.getElementType();
auto converted = convert(elementType, module);
if (!converted)
return {};
llvm::Type *llvmType = converted.cast<LLVM::LLVMType>().getUnderlyingType();
return LLVM::LLVMType::get(t.getContext(), llvmType->getPointerTo());
}
Type TypeConverter::pack(ArrayRef<Type> types, llvm::Module &module,
MLIRContext &mlirContext) {
return TypeConverter(module, &mlirContext).getPackedResultType(types);
@ -619,6 +634,80 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
}
};
struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
using LLVMLegalizationPattern<MemRefCastOp>::LLVMLegalizationPattern;
PatternMatchResult match(Instruction *op) const override {
if (!LLVMLegalizationPattern<MemRefCastOp>::match(op))
return matchFailure();
auto memRefCastOp = op->cast<MemRefCastOp>();
MemRefType sourceType =
memRefCastOp->getOperand()->getType().cast<MemRefType>();
MemRefType targetType = memRefCastOp->getType();
return (isSupportedMemRefType(targetType) &&
isSupportedMemRefType(sourceType))
? matchSuccess()
: matchFailure();
}
SmallVector<Value *, 4> rewrite(Instruction *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
auto memRefCastOp = op->cast<MemRefCastOp>();
auto targetType = memRefCastOp->getType();
auto sourceType = memRefCastOp->getOperand()->getType().cast<MemRefType>();
// Create the new MemRef descriptor.
auto structType = TypeConverter::convert(targetType, getModule());
Value *newDescriptor = rewriter.create<LLVM::UndefOp>(
op->getLoc(), structType, ArrayRef<Value *>{});
// Copy the data buffer pointer.
auto elementTypePtr =
TypeConverter::getMemRefElementPtrType(targetType, getModule());
Value *oldDescriptor = operands[0];
Value *buffer = rewriter.create<LLVM::ExtractValueOp>(
op->getLoc(), elementTypePtr, ArrayRef<Value *>{oldDescriptor},
getPositionAttribute(rewriter, 0));
newDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType, ArrayRef<Value *>{newDescriptor, buffer},
getPositionAttribute(rewriter, 0));
// Fill in the dynamic sizes of the new descriptor. If the size was
// dynamic, copy it from the old descriptor. If the size was static, insert
// the constant. Note that the positions of dynamic sizes in the
// descriptors start from 1 (the buffer pointer is at position zero).
int64_t sourceDynamicDimIdx = 1;
int64_t targetDynamicDimIdx = 1;
for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
// Ignore new static sizes (they will be known from the type). If the
// size was dynamic, update the index of dynamic types.
if (targetType.getShape()[i] != -1) {
if (sourceType.getShape()[i] == -1)
++sourceDynamicDimIdx;
continue;
}
auto sourceSize = sourceType.getShape()[i];
Value *size =
sourceSize == -1
? rewriter.create<LLVM::ExtractValueOp>(
op->getLoc(), getIndexType(),
ArrayRef<Value *>{oldDescriptor},
getPositionAttribute(rewriter, sourceDynamicDimIdx++))
: createIndexConstant(rewriter, op->getLoc(), sourceSize);
newDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType, ArrayRef<Value *>{newDescriptor, size},
getPositionAttribute(rewriter, targetDynamicDimIdx++));
}
assert(sourceDynamicDimIdx - 1 == sourceType.getNumDynamicDims() &&
"source dynamic dimensions were not processed");
assert(targetDynamicDimIdx - 1 == targetType.getNumDynamicDims() &&
"target dynamic dimensions were not set up");
return {newDescriptor};
}
};
// Common base for load and store operations on MemRefs. Restricts the match
// to supported MemRef types. Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
@ -669,12 +758,8 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
// indies.
Value *getElementPtr(Location loc, MemRefType type, Value *memRefDescriptor,
ArrayRef<Value *> indices, FuncBuilder &rewriter) const {
auto elementType =
TypeConverter::convert(type.getElementType(), this->getModule());
auto elementTypePtr = rewriter.getType<LLVM::LLVMType>(
elementType.template cast<LLVM::LLVMType>()
.getUnderlyingType()
->getPointerTo());
auto elementTypePtr =
TypeConverter::getMemRefElementPtrType(type, this->getModule());
// Get the list of MemRef sizes. Static sizes are defined as constants.
// Dynamic sizes are extracted from the MemRef descriptor, where they start
@ -851,9 +936,10 @@ protected:
AddFOpLowering, AddIOpLowering, AllocOpLowering, BranchOpLowering,
Call0OpLowering, CallOpLowering, CmpIOpLowering, CondBranchOpLowering,
ConstLLVMOpLowering, DeallocOpLowering, DivISOpLowering,
DivIUOpLowering, LoadOpLowering, MulFOpLowering, MulIOpLowering,
RemISOpLowering, RemIUOpLowering, ReturnOpLowering, StoreOpLowering,
SubFOpLowering, SubIOpLowering>::build(&converterStorage, *llvmDialect);
DivIUOpLowering, LoadOpLowering, MemRefCastOpLowering, MulFOpLowering,
MulIOpLowering, RemISOpLowering, RemIUOpLowering, ReturnOpLowering,
StoreOpLowering, SubFOpLowering,
SubIOpLowering>::build(&converterStorage, *llvmDialect);
}
// Convert types using the stored LLVM IR module.

View File

@ -91,3 +91,57 @@ func @store(%static : memref<10x42xf32>, %dynamic : memref<?x?xf32>,
store %val, %mixed[%i, %j] : memref<42x?xf32>
return
}
// CHECK-LABEL: func @memref_cast
func @memref_cast(%static : memref<10x42xf32>, %dynamic : memref<?x?xf32>,
%mixed : memref<42x?xf32>) {
// CHECK-NEXT: %0 = "llvm.undef"() : () -> !llvm<"{ float*, i64, i64 }">
// CHECK-NEXT: %1 = "llvm.extractvalue"(%arg0) {position: [0]} : (!llvm<"{ float* }">) -> !llvm<"float*">
// CHECK-NEXT: %2 = "llvm.insertvalue"(%0, %1) {position: [0]} : (!llvm<"{ float*, i64, i64 }">, !llvm<"float*">) -> !llvm<"{ float*, i64, i64 }">
// CHECK-NEXT: %3 = "llvm.constant"() {value: 10 : index} : () -> !llvm<"i64">
// CHECK-NEXT: %4 = "llvm.insertvalue"(%2, %3) {position: [1]} : (!llvm<"{ float*, i64, i64 }">, !llvm<"i64">) -> !llvm<"{ float*, i64, i64 }">
// CHECK-NEXT: %5 = "llvm.constant"() {value: 42 : index} : () -> !llvm<"i64">
// CHECK-NEXT: %6 = "llvm.insertvalue"(%4, %5) {position: [2]} : (!llvm<"{ float*, i64, i64 }">, !llvm<"i64">) -> !llvm<"{ float*, i64, i64 }">
%0 = memref_cast %static : memref<10x42xf32> to memref<?x?xf32>
// CHECK-NEXT: %7 = "llvm.undef"() : () -> !llvm<"{ float*, i64 }">
// CHECK-NEXT: %8 = "llvm.extractvalue"(%arg0) {position: [0]} : (!llvm<"{ float* }">) -> !llvm<"float*">
// CHECK-NEXT: %9 = "llvm.insertvalue"(%7, %8) {position: [0]} : (!llvm<"{ float*, i64 }">, !llvm<"float*">) -> !llvm<"{ float*, i64 }">
// CHECK-NEXT: %10 = "llvm.constant"() {value: 10 : index} : () -> !llvm<"i64">
// CHECK-NEXT: %11 = "llvm.insertvalue"(%9, %10) {position: [1]} : (!llvm<"{ float*, i64 }">, !llvm<"i64">) -> !llvm<"{ float*, i64 }">
%1 = memref_cast %static : memref<10x42xf32> to memref<?x42xf32>
// CHECK-NEXT: %12 = "llvm.undef"() : () -> !llvm<"{ float* }">
// CHECK-NEXT: %13 = "llvm.extractvalue"(%arg1) {position: [0]} : (!llvm<"{ float*, i64, i64 }">) -> !llvm<"float*">
// CHECK-NEXT: %14 = "llvm.insertvalue"(%12, %13) {position: [0]} : (!llvm<"{ float* }">, !llvm<"float*">) -> !llvm<"{ float* }">
%2 = memref_cast %dynamic : memref<?x?xf32> to memref<10x12xf32>
// CHECK-NEXT: %15 = "llvm.undef"() : () -> !llvm<"{ float*, i64 }">
// CHECK-NEXT: %16 = "llvm.extractvalue"(%arg1) {position: [0]} : (!llvm<"{ float*, i64, i64 }">) -> !llvm<"float*">
// CHECK-NEXT: %17 = "llvm.insertvalue"(%15, %16) {position: [0]} : (!llvm<"{ float*, i64 }">, !llvm<"float*">) -> !llvm<"{ float*, i64 }">
// CHECK-NEXT: %18 = "llvm.extractvalue"(%arg1) {position: [1]} : (!llvm<"{ float*, i64, i64 }">) -> !llvm<"i64">
// CHECK-NEXT: %19 = "llvm.insertvalue"(%17, %18) {position: [1]} : (!llvm<"{ float*, i64 }">, !llvm<"i64">) -> !llvm<"{ float*, i64 }">
%3 = memref_cast %dynamic : memref<?x?xf32> to memref<?x12xf32>
// CHECK-NEXT: %20 = "llvm.undef"() : () -> !llvm<"{ float*, i64, i64 }">
// CHECK-NEXT: %21 = "llvm.extractvalue"(%arg2) {position: [0]} : (!llvm<"{ float*, i64 }">) -> !llvm<"float*">
// CHECK-NEXT: %22 = "llvm.insertvalue"(%20, %21) {position: [0]} : (!llvm<"{ float*, i64, i64 }">, !llvm<"float*">) -> !llvm<"{ float*, i64, i64 }">
// CHECK-NEXT: %23 = "llvm.constant"() {value: 42 : index} : () -> !llvm<"i64">
// CHECK-NEXT: %24 = "llvm.insertvalue"(%22, %23) {position: [1]} : (!llvm<"{ float*, i64, i64 }">, !llvm<"i64">) -> !llvm<"{ float*, i64, i64 }">
// CHECK-NEXT: %25 = "llvm.extractvalue"(%arg2) {position: [1]} : (!llvm<"{ float*, i64 }">) -> !llvm<"i64">
// CHECK-NEXT: %26 = "llvm.insertvalue"(%24, %25) {position: [2]} : (!llvm<"{ float*, i64, i64 }">, !llvm<"i64">) -> !llvm<"{ float*, i64, i64 }">
%4 = memref_cast %mixed : memref<42x?xf32> to memref<?x?xf32>
// CHECK-NEXT: %27 = "llvm.undef"() : () -> !llvm<"{ float* }">
// CHECK-NEXT: %28 = "llvm.extractvalue"(%arg2) {position: [0]} : (!llvm<"{ float*, i64 }">) -> !llvm<"float*">
// CHECK-NEXT: %29 = "llvm.insertvalue"(%27, %28) {position: [0]} : (!llvm<"{ float* }">, !llvm<"float*">) -> !llvm<"{ float* }">
%5 = memref_cast %mixed : memref<42x?xf32> to memref<42x1xf32>
// CHECK-NEXT: %30 = "llvm.undef"() : () -> !llvm<"{ float*, i64 }">
// CHECK-NEXT: %31 = "llvm.extractvalue"(%arg2) {position: [0]} : (!llvm<"{ float*, i64 }">) -> !llvm<"float*">
// CHECK-NEXT: %32 = "llvm.insertvalue"(%30, %31) {position: [0]} : (!llvm<"{ float*, i64 }">, !llvm<"float*">) -> !llvm<"{ float*, i64 }">
// CHECK-NEXT: %33 = "llvm.constant"() {value: 42 : index} : () -> !llvm<"i64">
// CHECK-NEXT: %34 = "llvm.insertvalue"(%32, %33) {position: [1]} : (!llvm<"{ float*, i64 }">, !llvm<"i64">) -> !llvm<"{ float*, i64 }">
%6 = memref_cast %mixed : memref<42x?xf32> to memref<?x1xf32>
return
}