Port load/store op translation to LLVM IR dialect lowering

Implement the lowering of memref load and store standard operations into the
LLVM IR dialect.  This largely follows the existing mechanism in
MLIR-to-LLVM-IR translation for the sake of compatibility.  A memref value is
transformed into a memref descriptor value which holds the pointer to the
underlying data buffer and the dynamic memref sizes.  The data buffer is
contiguous.  Accesses to multidimensional memrefs are linearized in row-major
form.  In linear address computation, statically known sizes are used as
constants while dynamic sizes are extracted from the memref descriptor.

PiperOrigin-RevId: 233043846
This commit is contained in:
Alex Zinenko 2019-02-08 05:26:20 -08:00 committed by jpienaar
parent c419accea3
commit 4c35bbbb51
3 changed files with 224 additions and 35 deletions

View File

@ -619,6 +619,131 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
}
};
// Common base for load and store operations on MemRefs. Restricts the match
// to supported MemRef types. Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
  using LLVMLegalizationPattern<Derived>::LLVMLegalizationPattern;
  using Base = LoadStoreOpLowering<Derived>;

  // Match load/store-like operations whose MemRef operand has a type that
  // this lowering supports (the check is delegated to isSupportedMemRefType).
  PatternMatchResult match(Instruction *op) const override {
    if (!LLVMLegalizationPattern<Derived>::match(op))
      return this->matchFailure();
    auto loadOp = op->cast<Derived>();
    MemRefType type = loadOp->getMemRefType();
    return isSupportedMemRefType(type) ? this->matchSuccess()
                                       : this->matchFailure();
  }

  // Given subscript indices and array sizes in row-major order,
  //   i_n, i_{n-1}, ..., i_1
  //   s_n, s_{n-1}, ..., s_1
  // obtain a value that corresponds to the linearized subscript
  //   \sum_k i_k * \prod_{j=1}^{k-1} s_j
  // by accumulating the running linearized value (Horner-style: at each step
  // the accumulator is multiplied by the next size and the next index is
  // added).
  // Note that `indices` and `allocSizes` are passed in the same order as they
  // appear in load/store operations and memref type declarations.
  Value *linearizeSubscripts(FuncBuilder &builder, Location loc,
                             ArrayRef<Value *> indices,
                             ArrayRef<Value *> allocSizes) const {
    assert(indices.size() == allocSizes.size() &&
           "mismatching number of indices and allocation sizes");
    assert(!indices.empty() && "cannot linearize a 0-dimensional access");

    Value *linearized = indices.front();
    for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) {
      linearized = builder.create<LLVM::MulOp>(
          loc, this->getIndexType(),
          ArrayRef<Value *>{linearized, allocSizes[i]});
      linearized = builder.create<LLVM::AddOp>(
          loc, this->getIndexType(), ArrayRef<Value *>{linearized, indices[i]});
    }
    return linearized;
  }

  // Given the MemRef type, a descriptor and a list of indices, extract the
  // data buffer pointer from the descriptor, convert multi-dimensional
  // subscripts into a linearized index (using dynamic size data from the
  // descriptor if necessary) and get the pointer to the buffer element
  // identified by the indices.
  Value *getElementPtr(Location loc, MemRefType type, Value *memRefDescriptor,
                       ArrayRef<Value *> indices, FuncBuilder &rewriter) const {
    // The GEP below produces a pointer to the lowered element type.
    auto elementType =
        TypeConverter::convert(type.getElementType(), this->getModule());
    auto elementTypePtr = rewriter.getType<LLVM::LLVMType>(
        elementType.template cast<LLVM::LLVMType>()
            .getUnderlyingType()
            ->getPointerTo());

    // Get the list of MemRef sizes. Static sizes are defined as constants.
    // Dynamic sizes are extracted from the MemRef descriptor, where they start
    // from the position 1 (the buffer pointer is at position 0).
    SmallVector<Value *, 4> sizes;
    unsigned dynamicSizeIdx = 1;
    for (int64_t s : type.getShape()) {
      if (s == -1) {
        // -1 in the shape encodes a dynamic dimension.
        Value *size = rewriter.create<LLVM::ExtractValueOp>(
            loc, this->getIndexType(), ArrayRef<Value *>{memRefDescriptor},
            llvm::makeArrayRef(
                this->getPositionAttribute(rewriter, dynamicSizeIdx++)));
        sizes.push_back(size);
      } else {
        sizes.push_back(this->createIndexConstant(rewriter, loc, s));
      }
    }

    // Linearize the access subscripts, then address the element via a GEP
    // from the buffer pointer stored at position 0 of the descriptor.
    Value *subscript = linearizeSubscripts(rewriter, loc, indices, sizes);
    Value *dataPtr = rewriter.create<LLVM::ExtractValueOp>(
        loc, elementTypePtr, ArrayRef<Value *>{memRefDescriptor},
        llvm::makeArrayRef(this->getPositionAttribute(rewriter, 0)));
    return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr,
                                        ArrayRef<Value *>{dataPtr, subscript},
                                        ArrayRef<NamedAttribute>{});
  }
};
// A standard load is lowered to a computation of the address of the accessed
// element within the data buffer, followed by an LLVM IR dialect load from
// that address.
struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
  using Base::Base;

  SmallVector<Value *, 4> rewrite(Instruction *op, ArrayRef<Value *> operands,
                                  FuncBuilder &rewriter) const override {
    auto load = op->cast<LoadOp>();
    auto memRefType = load->getMemRefType();
    auto loc = op->getLoc();

    // The load produces a value of the lowered element type of the memref.
    auto loweredElementType =
        TypeConverter::convert(memRefType.getElementType(), getModule());

    // Operand 0 is the memref descriptor; the remaining operands are the
    // access subscripts.
    Value *elementPtr = getElementPtr(loc, memRefType, operands.front(),
                                      operands.drop_front(), rewriter);
    Value *loaded = rewriter.create<LLVM::LoadOp>(
        loc, loweredElementType, ArrayRef<Value *>{elementPtr});
    return {loaded};
  }
};
// A standard store is lowered to a computation of the address of the accessed
// element within the data buffer, followed by an LLVM IR dialect store of the
// value to that address.
struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
  using Base::Base;

  SmallVector<Value *, 4> rewrite(Instruction *op, ArrayRef<Value *> operands,
                                  FuncBuilder &rewriter) const override {
    auto store = op->cast<StoreOp>();
    auto memRefType = store->getMemRefType();
    auto loc = op->getLoc();

    // Operand 0 is the value being stored, operand 1 the memref descriptor,
    // and the remaining operands are the access subscripts.
    Value *elementPtr = getElementPtr(loc, memRefType, operands[1],
                                      operands.drop_front(2), rewriter);
    rewriter.create<LLVM::StoreOp>(loc, operands[0], elementPtr);

    // A store produces no SSA results.
    return {};
  }
};
// Base class for LLVM IR lowering terminator operations with successors.
template <typename SourceOp, typename TargetOp>
struct OneToOneLLVMTerminatorLowering
@ -723,12 +848,12 @@ protected:
// FIXME: this should be tablegen'ed
return ConversionListBuilder<
AllocOpLowering, DeallocOpLowering, AddIOpLowering, SubIOpLowering,
MulIOpLowering, DivISOpLowering, DivIUOpLowering, RemISOpLowering,
RemIUOpLowering, AddFOpLowering, SubFOpLowering, MulFOpLowering,
CmpIOpLowering, CallOpLowering, Call0OpLowering, BranchOpLowering,
CondBranchOpLowering, ReturnOpLowering,
ConstLLVMOpLowering>::build(&converterStorage, *llvmDialect);
AddFOpLowering, AddIOpLowering, AllocOpLowering, BranchOpLowering,
Call0OpLowering, CallOpLowering, CmpIOpLowering, CondBranchOpLowering,
ConstLLVMOpLowering, DeallocOpLowering, DivISOpLowering,
DivIUOpLowering, LoadOpLowering, MulFOpLowering, MulIOpLowering,
RemISOpLowering, RemIUOpLowering, ReturnOpLowering, StoreOpLowering,
SubFOpLowering, SubIOpLowering>::build(&converterStorage, *llvmDialect);
}
// Convert types using the stored LLVM IR module.

View File

@ -0,0 +1,93 @@
// RUN: mlir-opt -convert-to-llvmir %s | FileCheck %s
// Allocation of a memref with two dynamic dimensions and one static (42):
// all sizes are multiplied together, scaled by the element size (constant 4),
// passed to @malloc, and the resulting buffer pointer plus the two dynamic
// sizes are packed into a { float*, i64, i64 } descriptor.
// CHECK-LABEL: func @alloc(%arg0: !llvm<"i64">, %arg1: !llvm<"i64">) -> !llvm<"{ float*, i64, i64 }"> {
func @alloc(%arg0: index, %arg1: index) -> memref<?x42x?xf32> {
// CHECK-NEXT: %0 = "llvm.constant"() {value: 42 : index} : () -> !llvm<"i64">
// CHECK-NEXT: %1 = "llvm.mul"(%arg0, %0) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %2 = "llvm.mul"(%1, %arg1) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %3 = "llvm.undef"() : () -> !llvm<"{ float*, i64, i64 }">
// CHECK-NEXT: %4 = "llvm.constant"() {value: 4 : index} : () -> !llvm<"i64">
// CHECK-NEXT: %5 = "llvm.mul"(%2, %4) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %6 = "llvm.call"(%5) {callee: @malloc : (!llvm<"i64">) -> !llvm<"i8*">} : (!llvm<"i64">) -> !llvm<"i8*">
// CHECK-NEXT: %7 = "llvm.bitcast"(%6) : (!llvm<"i8*">) -> !llvm<"float*">
// CHECK-NEXT: %8 = "llvm.insertvalue"(%3, %7) {position: [0]} : (!llvm<"{ float*, i64, i64 }">, !llvm<"float*">) -> !llvm<"{ float*, i64, i64 }">
// CHECK-NEXT: %9 = "llvm.insertvalue"(%8, %arg0) {position: [1]} : (!llvm<"{ float*, i64, i64 }">, !llvm<"i64">) -> !llvm<"{ float*, i64, i64 }">
// CHECK-NEXT: %10 = "llvm.insertvalue"(%9, %arg1) {position: [2]} : (!llvm<"{ float*, i64, i64 }">, !llvm<"i64">) -> !llvm<"{ float*, i64, i64 }">
%0 = alloc(%arg0, %arg1) : memref<?x42x?xf32>
// CHECK-NEXT: "llvm.return"(%10) : (!llvm<"{ float*, i64, i64 }">) -> ()
return %0 : memref<?x42x?xf32>
}
// Deallocation extracts the buffer pointer (position 0 of the descriptor),
// bitcasts it to i8* and calls @free on it.
// CHECK-LABEL: func @dealloc(%arg0: !llvm<"{ float*, i64, i64 }">) {
func @dealloc(%arg0: memref<?x42x?xf32>) {
// CHECK-NEXT: %0 = "llvm.extractvalue"(%arg0) {position: [0]} : (!llvm<"{ float*, i64, i64 }">) -> !llvm<"float*">
// CHECK-NEXT: %1 = "llvm.bitcast"(%0) : (!llvm<"float*">) -> !llvm<"i8*">
// CHECK-NEXT: "llvm.call0"(%1) {callee: @free : (!llvm<"i8*">) -> ()} : (!llvm<"i8*">) -> ()
dealloc %arg0 : memref<?x42x?xf32>
// CHECK-NEXT: "llvm.return"() : () -> ()
return
}
// Loads from statically-shaped, fully dynamic, and mixed memrefs. Accesses
// are linearized in row-major form: index = %i * size(dim 1) + %j.
// CHECK-LABEL: func @load
func @load(%static : memref<10x42xf32>, %dynamic : memref<?x?xf32>,
           %mixed : memref<42x?xf32>, %i : index, %j : index) {
// Static shape: both sizes are materialized as constants. Note that the
// leading size (10) is created but not referenced by the linearization.
// CHECK-NEXT: %0 = "llvm.constant"() {value: 10 : index} : () -> !llvm<"i64">
// CHECK-NEXT: %1 = "llvm.constant"() {value: 42 : index} : () -> !llvm<"i64">
// CHECK-NEXT: %2 = "llvm.mul"(%arg3, %1) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %3 = "llvm.add"(%2, %arg4) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %4 = "llvm.extractvalue"(%arg0) {position: [0]} : (!llvm<"{ float* }">) -> !llvm<"float*">
// CHECK-NEXT: %5 = "llvm.getelementptr"(%4, %3) : (!llvm<"float*">, !llvm<"i64">) -> !llvm<"float*">
// CHECK-NEXT: %6 = "llvm.load"(%5) : (!llvm<"float*">) -> !llvm<"float">
%0 = load %static[%i, %j] : memref<10x42xf32>
// Fully dynamic shape: both sizes come from descriptor positions 1 and 2.
// CHECK-NEXT: %7 = "llvm.extractvalue"(%arg1) {position: [1]} : (!llvm<"{ float*, i64, i64 }">) -> !llvm<"i64">
// CHECK-NEXT: %8 = "llvm.extractvalue"(%arg1) {position: [2]} : (!llvm<"{ float*, i64, i64 }">) -> !llvm<"i64">
// CHECK-NEXT: %9 = "llvm.mul"(%arg3, %8) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %10 = "llvm.add"(%9, %arg4) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %11 = "llvm.extractvalue"(%arg1) {position: [0]} : (!llvm<"{ float*, i64, i64 }">) -> !llvm<"float*">
// CHECK-NEXT: %12 = "llvm.getelementptr"(%11, %10) : (!llvm<"float*">, !llvm<"i64">) -> !llvm<"float*">
// CHECK-NEXT: %13 = "llvm.load"(%12) : (!llvm<"float*">) -> !llvm<"float">
%1 = load %dynamic[%i, %j] : memref<?x?xf32>
// Mixed shape: the static size (42) is a constant, the dynamic size is
// extracted from descriptor position 1.
// CHECK-NEXT: %14 = "llvm.constant"() {value: 42 : index} : () -> !llvm<"i64">
// CHECK-NEXT: %15 = "llvm.extractvalue"(%arg2) {position: [1]} : (!llvm<"{ float*, i64 }">) -> !llvm<"i64">
// CHECK-NEXT: %16 = "llvm.mul"(%arg3, %15) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %17 = "llvm.add"(%16, %arg4) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %18 = "llvm.extractvalue"(%arg2) {position: [0]} : (!llvm<"{ float*, i64 }">) -> !llvm<"float*">
// CHECK-NEXT: %19 = "llvm.getelementptr"(%18, %17) : (!llvm<"float*">, !llvm<"i64">) -> !llvm<"float*">
// CHECK-NEXT: %20 = "llvm.load"(%19) : (!llvm<"float*">) -> !llvm<"float">
%2 = load %mixed[%i, %j] : memref<42x?xf32>
return
}
// Stores mirror the load tests: the same address computation is emitted,
// followed by an llvm.store of the value instead of an llvm.load.
// CHECK-LABEL: func @store
func @store(%static : memref<10x42xf32>, %dynamic : memref<?x?xf32>,
            %mixed : memref<42x?xf32>, %i : index, %j : index, %val : f32) {
// Static shape: sizes are constants; linearized index is %i * 42 + %j.
// CHECK-NEXT: %0 = "llvm.constant"() {value: 10 : index} : () -> !llvm<"i64">
// CHECK-NEXT: %1 = "llvm.constant"() {value: 42 : index} : () -> !llvm<"i64">
// CHECK-NEXT: %2 = "llvm.mul"(%arg3, %1) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %3 = "llvm.add"(%2, %arg4) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %4 = "llvm.extractvalue"(%arg0) {position: [0]} : (!llvm<"{ float* }">) -> !llvm<"float*">
// CHECK-NEXT: %5 = "llvm.getelementptr"(%4, %3) : (!llvm<"float*">, !llvm<"i64">) -> !llvm<"float*">
// CHECK-NEXT: "llvm.store"(%arg5, %5) : (!llvm<"float">, !llvm<"float*">) -> ()
store %val, %static[%i, %j] : memref<10x42xf32>
// Fully dynamic shape: sizes come from descriptor positions 1 and 2.
// CHECK-NEXT: %6 = "llvm.extractvalue"(%arg1) {position: [1]} : (!llvm<"{ float*, i64, i64 }">) -> !llvm<"i64">
// CHECK-NEXT: %7 = "llvm.extractvalue"(%arg1) {position: [2]} : (!llvm<"{ float*, i64, i64 }">) -> !llvm<"i64">
// CHECK-NEXT: %8 = "llvm.mul"(%arg3, %7) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %9 = "llvm.add"(%8, %arg4) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %10 = "llvm.extractvalue"(%arg1) {position: [0]} : (!llvm<"{ float*, i64, i64 }">) -> !llvm<"float*">
// CHECK-NEXT: %11 = "llvm.getelementptr"(%10, %9) : (!llvm<"float*">, !llvm<"i64">) -> !llvm<"float*">
// CHECK-NEXT: "llvm.store"(%arg5, %11) : (!llvm<"float">, !llvm<"float*">) -> ()
store %val, %dynamic[%i, %j] : memref<?x?xf32>
// Mixed shape: constant 42 plus one dynamic size from descriptor position 1.
// CHECK-NEXT: %12 = "llvm.constant"() {value: 42 : index} : () -> !llvm<"i64">
// CHECK-NEXT: %13 = "llvm.extractvalue"(%arg2) {position: [1]} : (!llvm<"{ float*, i64 }">) -> !llvm<"i64">
// CHECK-NEXT: %14 = "llvm.mul"(%arg3, %13) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %15 = "llvm.add"(%14, %arg4) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %16 = "llvm.extractvalue"(%arg2) {position: [0]} : (!llvm<"{ float*, i64 }">) -> !llvm<"float*">
// CHECK-NEXT: %17 = "llvm.getelementptr"(%16, %15) : (!llvm<"float*">, !llvm<"i64">) -> !llvm<"float*">
// CHECK-NEXT: "llvm.store"(%arg5, %17) : (!llvm<"float">, !llvm<"float*">) -> ()
store %val, %mixed[%i, %j] : memref<42x?xf32>
return
}

View File

@ -416,32 +416,3 @@ func @dfs_block_order() -> (i32) {
// CHECK-NEXT: "llvm.br"()[^bb1] : () -> ()
br ^bb1
}
// NOTE(review): this hunk deletes the @alloc test below; identical coverage
// appears in the dedicated memref conversion test file added by this commit.
// CHECK-LABEL: func @alloc(%arg0: !llvm<"i64">, %arg1: !llvm<"i64">) -> !llvm<"{ float*, i64, i64 }"> {
func @alloc(%arg0: index, %arg1: index) -> memref<?x42x?xf32> {
// CHECK-NEXT: %0 = "llvm.constant"() {value: 42 : index} : () -> !llvm<"i64">
// CHECK-NEXT: %1 = "llvm.mul"(%arg0, %0) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %2 = "llvm.mul"(%1, %arg1) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %3 = "llvm.undef"() : () -> !llvm<"{ float*, i64, i64 }">
// CHECK-NEXT: %4 = "llvm.constant"() {value: 4 : index} : () -> !llvm<"i64">
// CHECK-NEXT: %5 = "llvm.mul"(%2, %4) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
// CHECK-NEXT: %6 = "llvm.call"(%5) {callee: @malloc : (!llvm<"i64">) -> !llvm<"i8*">} : (!llvm<"i64">) -> !llvm<"i8*">
// CHECK-NEXT: %7 = "llvm.bitcast"(%6) : (!llvm<"i8*">) -> !llvm<"float*">
// CHECK-NEXT: %8 = "llvm.insertvalue"(%3, %7) {position: [0]} : (!llvm<"{ float*, i64, i64 }">, !llvm<"float*">) -> !llvm<"{ float*, i64, i64 }">
// CHECK-NEXT: %9 = "llvm.insertvalue"(%8, %arg0) {position: [1]} : (!llvm<"{ float*, i64, i64 }">, !llvm<"i64">) -> !llvm<"{ float*, i64, i64 }">
// CHECK-NEXT: %10 = "llvm.insertvalue"(%9, %arg1) {position: [2]} : (!llvm<"{ float*, i64, i64 }">, !llvm<"i64">) -> !llvm<"{ float*, i64, i64 }">
%0 = alloc(%arg0, %arg1) : memref<?x42x?xf32>
// CHECK-NEXT: "llvm.return"(%10) : (!llvm<"{ float*, i64, i64 }">) -> ()
return %0 : memref<?x42x?xf32>
}
// NOTE(review): this hunk deletes the @dealloc test below; identical coverage
// appears in the dedicated memref conversion test file added by this commit.
// CHECK-LABEL: func @dealloc(%arg0: !llvm<"{ float*, i64, i64 }">) {
func @dealloc(%arg0: memref<?x42x?xf32>) {
// CHECK-NEXT: %0 = "llvm.extractvalue"(%arg0) {position: [0]} : (!llvm<"{ float*, i64, i64 }">) -> !llvm<"float*">
// CHECK-NEXT: %1 = "llvm.bitcast"(%0) : (!llvm<"float*">) -> !llvm<"i8*">
// CHECK-NEXT: "llvm.call0"(%1) {callee: @free : (!llvm<"i8*">) -> ()} : (!llvm<"i8*">) -> ()
dealloc %arg0 : memref<?x42x?xf32>
// CHECK-NEXT: "llvm.return"() : () -> ()
return
}