Port load/store op translation to LLVM IR dialect lowering
Implement the lowering of memref load and store standard operations into the
LLVM IR dialect. This largely follows the existing mechanism in MLIR-to-LLVM-IR
translation for the sake of compatibility.

A memref value is transformed into a memref descriptor value which holds the
pointer to the underlying data buffer and the dynamic memref sizes. The data
buffer is contiguous. Accesses to multidimensional memrefs are linearized in
row-major form. In linear address computation, statically known sizes are used
as constants while dynamic sizes are extracted from the memref descriptor.

PiperOrigin-RevId: 233043846
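For illustration only, here is a minimal standalone C++ sketch (not part of the
patch; the helper name linearizeRowMajor is made up for this note) of the
row-major index computation that the lowering emits as llvm.mul/llvm.add
operations on index values. Static sizes enter the computation as constants,
while dynamic sizes are read from positions 1 and up of the descriptor, whose
layout for e.g. memref<?x42x?xf32> is !llvm<"{ float*, i64, i64 }"> (buffer
pointer followed by one i64 per dynamic dimension), as the tests below check.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Hypothetical helper mirroring linearizeSubscripts: fold subscripts
    // (i_0, ..., i_{n-1}) over sizes (s_0, ..., s_{n-1}) in Horner form,
    //   ((i_0 * s_1 + i_1) * s_2 + i_2) * ...
    // which equals \sum_k i_k * \prod_{j>k} s_j, the row-major offset into
    // the contiguous data buffer.
    int64_t linearizeRowMajor(const std::vector<int64_t> &indices,
                              const std::vector<int64_t> &sizes) {
      assert(indices.size() == sizes.size() && !indices.empty());
      int64_t linear = indices.front();
      for (size_t i = 1, e = sizes.size(); i < e; ++i)
        linear = linear * sizes[i] + indices[i];
      return linear;
    }

    // Example: for memref<10x42xf32>, element (%i, %j) lives at buffer offset
    // %i * 42 + %j, matching the mul/add sequence in the @load test below.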
@@ -619,6 +619,131 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
  }
};

// Common base for load and store operations on MemRefs. Restricts the match
// to supported MemRef types. Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
  using LLVMLegalizationPattern<Derived>::LLVMLegalizationPattern;
  using Base = LoadStoreOpLowering<Derived>;

  PatternMatchResult match(Instruction *op) const override {
    if (!LLVMLegalizationPattern<Derived>::match(op))
      return this->matchFailure();
    auto loadOp = op->cast<Derived>();
    MemRefType type = loadOp->getMemRefType();
    return isSupportedMemRefType(type) ? this->matchSuccess()
                                       : this->matchFailure();
  }

  // Given subscript indices and array sizes in row-major order,
  //   i_n, i_{n-1}, ..., i_1
  //   s_n, s_{n-1}, ..., s_1
  // obtain a value that corresponds to the linearized subscript
  //   \sum_k i_k * \prod_{j=1}^{k-1} s_j
  // by accumulating the running linearized value.
  // Note that `indices` and `allocSizes` are passed in the same order as they
  // appear in load/store operations and memref type declarations.
  Value *linearizeSubscripts(FuncBuilder &builder, Location loc,
                             ArrayRef<Value *> indices,
                             ArrayRef<Value *> allocSizes) const {
    assert(indices.size() == allocSizes.size() &&
           "mismatching number of indices and allocation sizes");
    assert(!indices.empty() && "cannot linearize a 0-dimensional access");

    Value *linearized = indices.front();
    for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) {
      linearized = builder.create<LLVM::MulOp>(
          loc, this->getIndexType(),
          ArrayRef<Value *>{linearized, allocSizes[i]});
      linearized = builder.create<LLVM::AddOp>(
          loc, this->getIndexType(), ArrayRef<Value *>{linearized, indices[i]});
    }
    return linearized;
  }

  // Given the MemRef type, a descriptor and a list of indices, extract the
  // data buffer pointer from the descriptor, convert multi-dimensional
  // subscripts into a linearized index (using dynamic size data from the
  // descriptor if necessary) and get the pointer to the buffer element
  // identified by the indices.
  Value *getElementPtr(Location loc, MemRefType type, Value *memRefDescriptor,
                       ArrayRef<Value *> indices, FuncBuilder &rewriter) const {
    auto elementType =
        TypeConverter::convert(type.getElementType(), this->getModule());
    auto elementTypePtr = rewriter.getType<LLVM::LLVMType>(
        elementType.template cast<LLVM::LLVMType>()
            .getUnderlyingType()
            ->getPointerTo());

    // Get the list of MemRef sizes. Static sizes are defined as constants.
    // Dynamic sizes are extracted from the MemRef descriptor, where they start
    // from position 1 (the buffer pointer is at position 0).
    SmallVector<Value *, 4> sizes;
    unsigned dynamicSizeIdx = 1;
    for (int64_t s : type.getShape()) {
      if (s == -1) {
        Value *size = rewriter.create<LLVM::ExtractValueOp>(
            loc, this->getIndexType(), ArrayRef<Value *>{memRefDescriptor},
            llvm::makeArrayRef(
                this->getPositionAttribute(rewriter, dynamicSizeIdx++)));
        sizes.push_back(size);
      } else {
        sizes.push_back(this->createIndexConstant(rewriter, loc, s));
      }
    }

    // The second and subsequent operands are access subscripts. Obtain the
    // linearized address in the buffer.
    Value *subscript = linearizeSubscripts(rewriter, loc, indices, sizes);

    Value *dataPtr = rewriter.create<LLVM::ExtractValueOp>(
        loc, elementTypePtr, ArrayRef<Value *>{memRefDescriptor},
        llvm::makeArrayRef(this->getPositionAttribute(rewriter, 0)));
    return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr,
                                        ArrayRef<Value *>{dataPtr, subscript},
                                        ArrayRef<NamedAttribute>{});
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
  using Base::Base;

  SmallVector<Value *, 4> rewrite(Instruction *op, ArrayRef<Value *> operands,
                                  FuncBuilder &rewriter) const override {
    auto loadOp = op->cast<LoadOp>();
    auto type = loadOp->getMemRefType();
    auto elementType =
        TypeConverter::convert(type.getElementType(), getModule());
    Value *dataPtr = getElementPtr(op->getLoc(), type, operands.front(),
                                   operands.drop_front(), rewriter);

    SmallVector<Value *, 4> results;
    results.push_back(rewriter.create<LLVM::LoadOp>(
        op->getLoc(), elementType, ArrayRef<Value *>{dataPtr}));
    return results;
  }
};

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
  using Base::Base;

  SmallVector<Value *, 4> rewrite(Instruction *op, ArrayRef<Value *> operands,
                                  FuncBuilder &rewriter) const override {
    auto storeOp = op->cast<StoreOp>();
    auto type = storeOp->getMemRefType();
    Value *dataPtr = getElementPtr(op->getLoc(), type, operands[1],
                                   operands.drop_front(2), rewriter);

    rewriter.create<LLVM::StoreOp>(op->getLoc(), operands[0], dataPtr);
    return {};
  }
};

// Base class for LLVM IR lowering terminator operations with successors.
template <typename SourceOp, typename TargetOp>
struct OneToOneLLVMTerminatorLowering
@@ -723,12 +848,12 @@ protected:

    // FIXME: this should be tablegen'ed
    return ConversionListBuilder<
        AllocOpLowering, DeallocOpLowering, AddIOpLowering, SubIOpLowering,
        MulIOpLowering, DivISOpLowering, DivIUOpLowering, RemISOpLowering,
        RemIUOpLowering, AddFOpLowering, SubFOpLowering, MulFOpLowering,
        CmpIOpLowering, CallOpLowering, Call0OpLowering, BranchOpLowering,
        CondBranchOpLowering, ReturnOpLowering,
        ConstLLVMOpLowering>::build(&converterStorage, *llvmDialect);
        AddFOpLowering, AddIOpLowering, AllocOpLowering, BranchOpLowering,
        Call0OpLowering, CallOpLowering, CmpIOpLowering, CondBranchOpLowering,
        ConstLLVMOpLowering, DeallocOpLowering, DivISOpLowering,
        DivIUOpLowering, LoadOpLowering, MulFOpLowering, MulIOpLowering,
        RemISOpLowering, RemIUOpLowering, ReturnOpLowering, StoreOpLowering,
        SubFOpLowering, SubIOpLowering>::build(&converterStorage, *llvmDialect);
  }

  // Convert types using the stored LLVM IR module.
@@ -0,0 +1,93 @@
// RUN: mlir-opt -convert-to-llvmir %s | FileCheck %s

// CHECK-LABEL: func @alloc(%arg0: !llvm<"i64">, %arg1: !llvm<"i64">) -> !llvm<"{ float*, i64, i64 }"> {
func @alloc(%arg0: index, %arg1: index) -> memref<?x42x?xf32> {
// CHECK-NEXT: %0 = "llvm.constant"() {value: 42 : index} : () -> !llvm<"i64">
// CHECK-NEXT: %1 = "llvm.mul"(%arg0, %0) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %2 = "llvm.mul"(%1, %arg1) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %3 = "llvm.undef"() : () -> !llvm<"{ float*, i64, i64 }">
// CHECK-NEXT: %4 = "llvm.constant"() {value: 4 : index} : () -> !llvm<"i64">
// CHECK-NEXT: %5 = "llvm.mul"(%2, %4) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %6 = "llvm.call"(%5) {callee: @malloc : (!llvm<"i64">) -> !llvm<"i8*">} : (!llvm<"i64">) -> !llvm<"i8*">
// CHECK-NEXT: %7 = "llvm.bitcast"(%6) : (!llvm<"i8*">) -> !llvm<"float*">
// CHECK-NEXT: %8 = "llvm.insertvalue"(%3, %7) {position: [0]} : (!llvm<"{ float*, i64, i64 }">, !llvm<"float*">) -> !llvm<"{ float*, i64, i64 }">
// CHECK-NEXT: %9 = "llvm.insertvalue"(%8, %arg0) {position: [1]} : (!llvm<"{ float*, i64, i64 }">, !llvm<"i64">) -> !llvm<"{ float*, i64, i64 }">
// CHECK-NEXT: %10 = "llvm.insertvalue"(%9, %arg1) {position: [2]} : (!llvm<"{ float*, i64, i64 }">, !llvm<"i64">) -> !llvm<"{ float*, i64, i64 }">
  %0 = alloc(%arg0, %arg1) : memref<?x42x?xf32>
// CHECK-NEXT: "llvm.return"(%10) : (!llvm<"{ float*, i64, i64 }">) -> ()
  return %0 : memref<?x42x?xf32>
}

// CHECK-LABEL: func @dealloc(%arg0: !llvm<"{ float*, i64, i64 }">) {
func @dealloc(%arg0: memref<?x42x?xf32>) {
// CHECK-NEXT: %0 = "llvm.extractvalue"(%arg0) {position: [0]} : (!llvm<"{ float*, i64, i64 }">) -> !llvm<"float*">
// CHECK-NEXT: %1 = "llvm.bitcast"(%0) : (!llvm<"float*">) -> !llvm<"i8*">
// CHECK-NEXT: "llvm.call0"(%1) {callee: @free : (!llvm<"i8*">) -> ()} : (!llvm<"i8*">) -> ()
  dealloc %arg0 : memref<?x42x?xf32>
// CHECK-NEXT: "llvm.return"() : () -> ()
  return
}

// CHECK-LABEL: func @load
func @load(%static : memref<10x42xf32>, %dynamic : memref<?x?xf32>,
           %mixed : memref<42x?xf32>, %i : index, %j : index) {
// CHECK-NEXT: %0 = "llvm.constant"() {value: 10 : index} : () -> !llvm<"i64">
// CHECK-NEXT: %1 = "llvm.constant"() {value: 42 : index} : () -> !llvm<"i64">
// CHECK-NEXT: %2 = "llvm.mul"(%arg3, %1) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %3 = "llvm.add"(%2, %arg4) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %4 = "llvm.extractvalue"(%arg0) {position: [0]} : (!llvm<"{ float* }">) -> !llvm<"float*">
// CHECK-NEXT: %5 = "llvm.getelementptr"(%4, %3) : (!llvm<"float*">, !llvm<"i64">) -> !llvm<"float*">
// CHECK-NEXT: %6 = "llvm.load"(%5) : (!llvm<"float*">) -> !llvm<"float">
  %0 = load %static[%i, %j] : memref<10x42xf32>

// CHECK-NEXT: %7 = "llvm.extractvalue"(%arg1) {position: [1]} : (!llvm<"{ float*, i64, i64 }">) -> !llvm<"i64">
// CHECK-NEXT: %8 = "llvm.extractvalue"(%arg1) {position: [2]} : (!llvm<"{ float*, i64, i64 }">) -> !llvm<"i64">
// CHECK-NEXT: %9 = "llvm.mul"(%arg3, %8) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %10 = "llvm.add"(%9, %arg4) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %11 = "llvm.extractvalue"(%arg1) {position: [0]} : (!llvm<"{ float*, i64, i64 }">) -> !llvm<"float*">
// CHECK-NEXT: %12 = "llvm.getelementptr"(%11, %10) : (!llvm<"float*">, !llvm<"i64">) -> !llvm<"float*">
// CHECK-NEXT: %13 = "llvm.load"(%12) : (!llvm<"float*">) -> !llvm<"float">
  %1 = load %dynamic[%i, %j] : memref<?x?xf32>

// CHECK-NEXT: %14 = "llvm.constant"() {value: 42 : index} : () -> !llvm<"i64">
// CHECK-NEXT: %15 = "llvm.extractvalue"(%arg2) {position: [1]} : (!llvm<"{ float*, i64 }">) -> !llvm<"i64">
// CHECK-NEXT: %16 = "llvm.mul"(%arg3, %15) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %17 = "llvm.add"(%16, %arg4) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %18 = "llvm.extractvalue"(%arg2) {position: [0]} : (!llvm<"{ float*, i64 }">) -> !llvm<"float*">
// CHECK-NEXT: %19 = "llvm.getelementptr"(%18, %17) : (!llvm<"float*">, !llvm<"i64">) -> !llvm<"float*">
// CHECK-NEXT: %20 = "llvm.load"(%19) : (!llvm<"float*">) -> !llvm<"float">
  %2 = load %mixed[%i, %j] : memref<42x?xf32>
  return
}

// CHECK-LABEL: func @store
func @store(%static : memref<10x42xf32>, %dynamic : memref<?x?xf32>,
            %mixed : memref<42x?xf32>, %i : index, %j : index, %val : f32) {
// CHECK-NEXT: %0 = "llvm.constant"() {value: 10 : index} : () -> !llvm<"i64">
// CHECK-NEXT: %1 = "llvm.constant"() {value: 42 : index} : () -> !llvm<"i64">
// CHECK-NEXT: %2 = "llvm.mul"(%arg3, %1) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %3 = "llvm.add"(%2, %arg4) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %4 = "llvm.extractvalue"(%arg0) {position: [0]} : (!llvm<"{ float* }">) -> !llvm<"float*">
// CHECK-NEXT: %5 = "llvm.getelementptr"(%4, %3) : (!llvm<"float*">, !llvm<"i64">) -> !llvm<"float*">
// CHECK-NEXT: "llvm.store"(%arg5, %5) : (!llvm<"float">, !llvm<"float*">) -> ()
  store %val, %static[%i, %j] : memref<10x42xf32>

// CHECK-NEXT: %6 = "llvm.extractvalue"(%arg1) {position: [1]} : (!llvm<"{ float*, i64, i64 }">) -> !llvm<"i64">
// CHECK-NEXT: %7 = "llvm.extractvalue"(%arg1) {position: [2]} : (!llvm<"{ float*, i64, i64 }">) -> !llvm<"i64">
// CHECK-NEXT: %8 = "llvm.mul"(%arg3, %7) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %9 = "llvm.add"(%8, %arg4) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %10 = "llvm.extractvalue"(%arg1) {position: [0]} : (!llvm<"{ float*, i64, i64 }">) -> !llvm<"float*">
// CHECK-NEXT: %11 = "llvm.getelementptr"(%10, %9) : (!llvm<"float*">, !llvm<"i64">) -> !llvm<"float*">
// CHECK-NEXT: "llvm.store"(%arg5, %11) : (!llvm<"float">, !llvm<"float*">) -> ()
  store %val, %dynamic[%i, %j] : memref<?x?xf32>

// CHECK-NEXT: %12 = "llvm.constant"() {value: 42 : index} : () -> !llvm<"i64">
// CHECK-NEXT: %13 = "llvm.extractvalue"(%arg2) {position: [1]} : (!llvm<"{ float*, i64 }">) -> !llvm<"i64">
// CHECK-NEXT: %14 = "llvm.mul"(%arg3, %13) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %15 = "llvm.add"(%14, %arg4) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %16 = "llvm.extractvalue"(%arg2) {position: [0]} : (!llvm<"{ float*, i64 }">) -> !llvm<"float*">
// CHECK-NEXT: %17 = "llvm.getelementptr"(%16, %15) : (!llvm<"float*">, !llvm<"i64">) -> !llvm<"float*">
// CHECK-NEXT: "llvm.store"(%arg5, %17) : (!llvm<"float">, !llvm<"float*">) -> ()
  store %val, %mixed[%i, %j] : memref<42x?xf32>
  return
}
@@ -416,32 +416,3 @@ func @dfs_block_order() -> (i32) {
// CHECK-NEXT: "llvm.br"()[^bb1] : () -> ()
  br ^bb1
}

// CHECK-LABEL: func @alloc(%arg0: !llvm<"i64">, %arg1: !llvm<"i64">) -> !llvm<"{ float*, i64, i64 }"> {
func @alloc(%arg0: index, %arg1: index) -> memref<?x42x?xf32> {
// CHECK-NEXT: %0 = "llvm.constant"() {value: 42 : index} : () -> !llvm<"i64">
// CHECK-NEXT: %1 = "llvm.mul"(%arg0, %0) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %2 = "llvm.mul"(%1, %arg1) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %3 = "llvm.undef"() : () -> !llvm<"{ float*, i64, i64 }">
// CHECK-NEXT: %4 = "llvm.constant"() {value: 4 : index} : () -> !llvm<"i64">
// CHECK-NEXT: %5 = "llvm.mul"(%2, %4) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %6 = "llvm.call"(%5) {callee: @malloc : (!llvm<"i64">) -> !llvm<"i8*">} : (!llvm<"i64">) -> !llvm<"i8*">
// CHECK-NEXT: %7 = "llvm.bitcast"(%6) : (!llvm<"i8*">) -> !llvm<"float*">
// CHECK-NEXT: %8 = "llvm.insertvalue"(%3, %7) {position: [0]} : (!llvm<"{ float*, i64, i64 }">, !llvm<"float*">) -> !llvm<"{ float*, i64, i64 }">
// CHECK-NEXT: %9 = "llvm.insertvalue"(%8, %arg0) {position: [1]} : (!llvm<"{ float*, i64, i64 }">, !llvm<"i64">) -> !llvm<"{ float*, i64, i64 }">
// CHECK-NEXT: %10 = "llvm.insertvalue"(%9, %arg1) {position: [2]} : (!llvm<"{ float*, i64, i64 }">, !llvm<"i64">) -> !llvm<"{ float*, i64, i64 }">
  %0 = alloc(%arg0, %arg1) : memref<?x42x?xf32>
// CHECK-NEXT: "llvm.return"(%10) : (!llvm<"{ float*, i64, i64 }">) -> ()
  return %0 : memref<?x42x?xf32>
}


// CHECK-LABEL: func @dealloc(%arg0: !llvm<"{ float*, i64, i64 }">) {
func @dealloc(%arg0: memref<?x42x?xf32>) {
// CHECK-NEXT: %0 = "llvm.extractvalue"(%arg0) {position: [0]} : (!llvm<"{ float*, i64, i64 }">) -> !llvm<"float*">
// CHECK-NEXT: %1 = "llvm.bitcast"(%0) : (!llvm<"float*">) -> !llvm<"i8*">
// CHECK-NEXT: "llvm.call0"(%1) {callee: @free : (!llvm<"i8*">) -> ()} : (!llvm<"i8*">) -> ()
  dealloc %arg0 : memref<?x42x?xf32>
// CHECK-NEXT: "llvm.return"() : () -> ()
  return
}