forked from OSchip/llvm-project
Convert MemRefCastOp to the LLVM IR dialect
Add support for converting `memref_cast` operations into the LLVM IR dialect. This goes beyond want is currently implemented in the MLIR standard ops to LLVM IR translation, but follows the general principles of the memref descriptors. A memref cast creates a new descriptor containing the same buffer pointer but a potentially different number of dynamic sizes (as many as dynamic dimensions in the target memref type). The lowering copies the buffer pointer to the new descriptor and inserts dynamic sizes to it. If the size is static in the source type, a constant value is inserted as the dynamic size, otherwise a dynamic value is copied from the source descriptor, taking into account the difference in dynamic size positions in the descriptor. PiperOrigin-RevId: 233082035
This commit is contained in:
parent
366ebcf6aa
commit
d7e6b33e93
|
@ -48,6 +48,11 @@ public:
|
|||
// Dispatches to the private functions below based on the actual type.
|
||||
static Type convert(Type t, llvm::Module &llvmModule);
|
||||
|
||||
// Convert the element type of the memref `t` to to an LLVM type, get a
|
||||
// pointer LLVM type pointing to the converted `t`, wrap it into the MLIR LLVM
|
||||
// dialect type and return.
|
||||
static Type getMemRefElementPtrType(MemRefType t, llvm::Module &llvmModule);
|
||||
|
||||
// Convert a non-empty list of types to an LLVM IR dialect type wrapping an
|
||||
// LLVM IR structure type, elements of which are formed by converting
|
||||
// individual types in the given list. Register the type in the `llvmModule`.
|
||||
|
@ -264,6 +269,16 @@ Type TypeConverter::convert(Type t, llvm::Module &module) {
|
|||
return TypeConverter(module, t.getContext()).convertType(t);
|
||||
}
|
||||
|
||||
Type TypeConverter::getMemRefElementPtrType(MemRefType t,
|
||||
llvm::Module &module) {
|
||||
auto elementType = t.getElementType();
|
||||
auto converted = convert(elementType, module);
|
||||
if (!converted)
|
||||
return {};
|
||||
llvm::Type *llvmType = converted.cast<LLVM::LLVMType>().getUnderlyingType();
|
||||
return LLVM::LLVMType::get(t.getContext(), llvmType->getPointerTo());
|
||||
}
|
||||
|
||||
Type TypeConverter::pack(ArrayRef<Type> types, llvm::Module &module,
|
||||
MLIRContext &mlirContext) {
|
||||
return TypeConverter(module, &mlirContext).getPackedResultType(types);
|
||||
|
@ -619,6 +634,80 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
|
|||
}
|
||||
};
|
||||
|
||||
struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
|
||||
using LLVMLegalizationPattern<MemRefCastOp>::LLVMLegalizationPattern;
|
||||
|
||||
PatternMatchResult match(Instruction *op) const override {
|
||||
if (!LLVMLegalizationPattern<MemRefCastOp>::match(op))
|
||||
return matchFailure();
|
||||
auto memRefCastOp = op->cast<MemRefCastOp>();
|
||||
MemRefType sourceType =
|
||||
memRefCastOp->getOperand()->getType().cast<MemRefType>();
|
||||
MemRefType targetType = memRefCastOp->getType();
|
||||
return (isSupportedMemRefType(targetType) &&
|
||||
isSupportedMemRefType(sourceType))
|
||||
? matchSuccess()
|
||||
: matchFailure();
|
||||
}
|
||||
|
||||
SmallVector<Value *, 4> rewrite(Instruction *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
auto memRefCastOp = op->cast<MemRefCastOp>();
|
||||
auto targetType = memRefCastOp->getType();
|
||||
auto sourceType = memRefCastOp->getOperand()->getType().cast<MemRefType>();
|
||||
|
||||
// Create the new MemRef descriptor.
|
||||
auto structType = TypeConverter::convert(targetType, getModule());
|
||||
Value *newDescriptor = rewriter.create<LLVM::UndefOp>(
|
||||
op->getLoc(), structType, ArrayRef<Value *>{});
|
||||
|
||||
// Copy the data buffer pointer.
|
||||
auto elementTypePtr =
|
||||
TypeConverter::getMemRefElementPtrType(targetType, getModule());
|
||||
Value *oldDescriptor = operands[0];
|
||||
Value *buffer = rewriter.create<LLVM::ExtractValueOp>(
|
||||
op->getLoc(), elementTypePtr, ArrayRef<Value *>{oldDescriptor},
|
||||
getPositionAttribute(rewriter, 0));
|
||||
newDescriptor = rewriter.create<LLVM::InsertValueOp>(
|
||||
op->getLoc(), structType, ArrayRef<Value *>{newDescriptor, buffer},
|
||||
getPositionAttribute(rewriter, 0));
|
||||
|
||||
// Fill in the dynamic sizes of the new descriptor. If the size was
|
||||
// dynamic, copy it from the old descriptor. If the size was static, insert
|
||||
// the constant. Note that the positions of dynamic sizes in the
|
||||
// descriptors start from 1 (the buffer pointer is at position zero).
|
||||
int64_t sourceDynamicDimIdx = 1;
|
||||
int64_t targetDynamicDimIdx = 1;
|
||||
for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
|
||||
// Ignore new static sizes (they will be known from the type). If the
|
||||
// size was dynamic, update the index of dynamic types.
|
||||
if (targetType.getShape()[i] != -1) {
|
||||
if (sourceType.getShape()[i] == -1)
|
||||
++sourceDynamicDimIdx;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto sourceSize = sourceType.getShape()[i];
|
||||
Value *size =
|
||||
sourceSize == -1
|
||||
? rewriter.create<LLVM::ExtractValueOp>(
|
||||
op->getLoc(), getIndexType(),
|
||||
ArrayRef<Value *>{oldDescriptor},
|
||||
getPositionAttribute(rewriter, sourceDynamicDimIdx++))
|
||||
: createIndexConstant(rewriter, op->getLoc(), sourceSize);
|
||||
newDescriptor = rewriter.create<LLVM::InsertValueOp>(
|
||||
op->getLoc(), structType, ArrayRef<Value *>{newDescriptor, size},
|
||||
getPositionAttribute(rewriter, targetDynamicDimIdx++));
|
||||
}
|
||||
assert(sourceDynamicDimIdx - 1 == sourceType.getNumDynamicDims() &&
|
||||
"source dynamic dimensions were not processed");
|
||||
assert(targetDynamicDimIdx - 1 == targetType.getNumDynamicDims() &&
|
||||
"target dynamic dimensions were not set up");
|
||||
|
||||
return {newDescriptor};
|
||||
}
|
||||
};
|
||||
|
||||
// Common base for load and store operations on MemRefs. Restricts the match
|
||||
// to supported MemRef types. Provides functionality to emit code accessing a
|
||||
// specific element of the underlying data buffer.
|
||||
|
@ -669,12 +758,8 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
|
|||
// indies.
|
||||
Value *getElementPtr(Location loc, MemRefType type, Value *memRefDescriptor,
|
||||
ArrayRef<Value *> indices, FuncBuilder &rewriter) const {
|
||||
auto elementType =
|
||||
TypeConverter::convert(type.getElementType(), this->getModule());
|
||||
auto elementTypePtr = rewriter.getType<LLVM::LLVMType>(
|
||||
elementType.template cast<LLVM::LLVMType>()
|
||||
.getUnderlyingType()
|
||||
->getPointerTo());
|
||||
auto elementTypePtr =
|
||||
TypeConverter::getMemRefElementPtrType(type, this->getModule());
|
||||
|
||||
// Get the list of MemRef sizes. Static sizes are defined as constants.
|
||||
// Dynamic sizes are extracted from the MemRef descriptor, where they start
|
||||
|
@ -851,9 +936,10 @@ protected:
|
|||
AddFOpLowering, AddIOpLowering, AllocOpLowering, BranchOpLowering,
|
||||
Call0OpLowering, CallOpLowering, CmpIOpLowering, CondBranchOpLowering,
|
||||
ConstLLVMOpLowering, DeallocOpLowering, DivISOpLowering,
|
||||
DivIUOpLowering, LoadOpLowering, MulFOpLowering, MulIOpLowering,
|
||||
RemISOpLowering, RemIUOpLowering, ReturnOpLowering, StoreOpLowering,
|
||||
SubFOpLowering, SubIOpLowering>::build(&converterStorage, *llvmDialect);
|
||||
DivIUOpLowering, LoadOpLowering, MemRefCastOpLowering, MulFOpLowering,
|
||||
MulIOpLowering, RemISOpLowering, RemIUOpLowering, ReturnOpLowering,
|
||||
StoreOpLowering, SubFOpLowering,
|
||||
SubIOpLowering>::build(&converterStorage, *llvmDialect);
|
||||
}
|
||||
|
||||
// Convert types using the stored LLVM IR module.
|
||||
|
|
|
@ -91,3 +91,57 @@ func @store(%static : memref<10x42xf32>, %dynamic : memref<?x?xf32>,
|
|||
store %val, %mixed[%i, %j] : memref<42x?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @memref_cast
|
||||
func @memref_cast(%static : memref<10x42xf32>, %dynamic : memref<?x?xf32>,
|
||||
%mixed : memref<42x?xf32>) {
|
||||
// CHECK-NEXT: %0 = "llvm.undef"() : () -> !llvm<"{ float*, i64, i64 }">
|
||||
// CHECK-NEXT: %1 = "llvm.extractvalue"(%arg0) {position: [0]} : (!llvm<"{ float* }">) -> !llvm<"float*">
|
||||
// CHECK-NEXT: %2 = "llvm.insertvalue"(%0, %1) {position: [0]} : (!llvm<"{ float*, i64, i64 }">, !llvm<"float*">) -> !llvm<"{ float*, i64, i64 }">
|
||||
// CHECK-NEXT: %3 = "llvm.constant"() {value: 10 : index} : () -> !llvm<"i64">
|
||||
// CHECK-NEXT: %4 = "llvm.insertvalue"(%2, %3) {position: [1]} : (!llvm<"{ float*, i64, i64 }">, !llvm<"i64">) -> !llvm<"{ float*, i64, i64 }">
|
||||
// CHECK-NEXT: %5 = "llvm.constant"() {value: 42 : index} : () -> !llvm<"i64">
|
||||
// CHECK-NEXT: %6 = "llvm.insertvalue"(%4, %5) {position: [2]} : (!llvm<"{ float*, i64, i64 }">, !llvm<"i64">) -> !llvm<"{ float*, i64, i64 }">
|
||||
%0 = memref_cast %static : memref<10x42xf32> to memref<?x?xf32>
|
||||
|
||||
// CHECK-NEXT: %7 = "llvm.undef"() : () -> !llvm<"{ float*, i64 }">
|
||||
// CHECK-NEXT: %8 = "llvm.extractvalue"(%arg0) {position: [0]} : (!llvm<"{ float* }">) -> !llvm<"float*">
|
||||
// CHECK-NEXT: %9 = "llvm.insertvalue"(%7, %8) {position: [0]} : (!llvm<"{ float*, i64 }">, !llvm<"float*">) -> !llvm<"{ float*, i64 }">
|
||||
// CHECK-NEXT: %10 = "llvm.constant"() {value: 10 : index} : () -> !llvm<"i64">
|
||||
// CHECK-NEXT: %11 = "llvm.insertvalue"(%9, %10) {position: [1]} : (!llvm<"{ float*, i64 }">, !llvm<"i64">) -> !llvm<"{ float*, i64 }">
|
||||
%1 = memref_cast %static : memref<10x42xf32> to memref<?x42xf32>
|
||||
|
||||
// CHECK-NEXT: %12 = "llvm.undef"() : () -> !llvm<"{ float* }">
|
||||
// CHECK-NEXT: %13 = "llvm.extractvalue"(%arg1) {position: [0]} : (!llvm<"{ float*, i64, i64 }">) -> !llvm<"float*">
|
||||
// CHECK-NEXT: %14 = "llvm.insertvalue"(%12, %13) {position: [0]} : (!llvm<"{ float* }">, !llvm<"float*">) -> !llvm<"{ float* }">
|
||||
%2 = memref_cast %dynamic : memref<?x?xf32> to memref<10x12xf32>
|
||||
|
||||
// CHECK-NEXT: %15 = "llvm.undef"() : () -> !llvm<"{ float*, i64 }">
|
||||
// CHECK-NEXT: %16 = "llvm.extractvalue"(%arg1) {position: [0]} : (!llvm<"{ float*, i64, i64 }">) -> !llvm<"float*">
|
||||
// CHECK-NEXT: %17 = "llvm.insertvalue"(%15, %16) {position: [0]} : (!llvm<"{ float*, i64 }">, !llvm<"float*">) -> !llvm<"{ float*, i64 }">
|
||||
// CHECK-NEXT: %18 = "llvm.extractvalue"(%arg1) {position: [1]} : (!llvm<"{ float*, i64, i64 }">) -> !llvm<"i64">
|
||||
// CHECK-NEXT: %19 = "llvm.insertvalue"(%17, %18) {position: [1]} : (!llvm<"{ float*, i64 }">, !llvm<"i64">) -> !llvm<"{ float*, i64 }">
|
||||
%3 = memref_cast %dynamic : memref<?x?xf32> to memref<?x12xf32>
|
||||
|
||||
// CHECK-NEXT: %20 = "llvm.undef"() : () -> !llvm<"{ float*, i64, i64 }">
|
||||
// CHECK-NEXT: %21 = "llvm.extractvalue"(%arg2) {position: [0]} : (!llvm<"{ float*, i64 }">) -> !llvm<"float*">
|
||||
// CHECK-NEXT: %22 = "llvm.insertvalue"(%20, %21) {position: [0]} : (!llvm<"{ float*, i64, i64 }">, !llvm<"float*">) -> !llvm<"{ float*, i64, i64 }">
|
||||
// CHECK-NEXT: %23 = "llvm.constant"() {value: 42 : index} : () -> !llvm<"i64">
|
||||
// CHECK-NEXT: %24 = "llvm.insertvalue"(%22, %23) {position: [1]} : (!llvm<"{ float*, i64, i64 }">, !llvm<"i64">) -> !llvm<"{ float*, i64, i64 }">
|
||||
// CHECK-NEXT: %25 = "llvm.extractvalue"(%arg2) {position: [1]} : (!llvm<"{ float*, i64 }">) -> !llvm<"i64">
|
||||
// CHECK-NEXT: %26 = "llvm.insertvalue"(%24, %25) {position: [2]} : (!llvm<"{ float*, i64, i64 }">, !llvm<"i64">) -> !llvm<"{ float*, i64, i64 }">
|
||||
%4 = memref_cast %mixed : memref<42x?xf32> to memref<?x?xf32>
|
||||
|
||||
// CHECK-NEXT: %27 = "llvm.undef"() : () -> !llvm<"{ float* }">
|
||||
// CHECK-NEXT: %28 = "llvm.extractvalue"(%arg2) {position: [0]} : (!llvm<"{ float*, i64 }">) -> !llvm<"float*">
|
||||
// CHECK-NEXT: %29 = "llvm.insertvalue"(%27, %28) {position: [0]} : (!llvm<"{ float* }">, !llvm<"float*">) -> !llvm<"{ float* }">
|
||||
%5 = memref_cast %mixed : memref<42x?xf32> to memref<42x1xf32>
|
||||
|
||||
// CHECK-NEXT: %30 = "llvm.undef"() : () -> !llvm<"{ float*, i64 }">
|
||||
// CHECK-NEXT: %31 = "llvm.extractvalue"(%arg2) {position: [0]} : (!llvm<"{ float*, i64 }">) -> !llvm<"float*">
|
||||
// CHECK-NEXT: %32 = "llvm.insertvalue"(%30, %31) {position: [0]} : (!llvm<"{ float*, i64 }">, !llvm<"float*">) -> !llvm<"{ float*, i64 }">
|
||||
// CHECK-NEXT: %33 = "llvm.constant"() {value: 42 : index} : () -> !llvm<"i64">
|
||||
// CHECK-NEXT: %34 = "llvm.insertvalue"(%32, %33) {position: [1]} : (!llvm<"{ float*, i64 }">, !llvm<"i64">) -> !llvm<"{ float*, i64 }">
|
||||
%6 = memref_cast %mixed : memref<42x?xf32> to memref<?x1xf32>
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue