LLVM IR Dialect: port DimOp lowering from the translator

DimOp is converted to a constant LLVM IR dialect operation for static
dimensions and to an access to the dynamic size info stored in the memref
descriptor for the dynamic dimensions.  This behavior is consistent with the
existing mlir-translator.

This completes the porting of MLIR -> LLVM lowering to the dialect conversion
infrastructure.

PiperOrigin-RevId: 233665634
This commit is contained in:
Alex Zinenko 2019-02-12 13:47:25 -08:00 committed by jpienaar
parent 2f11f86846
commit 465746f262
2 changed files with 60 additions and 1 deletions

View File

@ -712,6 +712,50 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
}
};
// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
using LLVMLegalizationPattern<DimOp>::LLVMLegalizationPattern;
PatternMatchResult match(Instruction *op) const override {
if (!LLVMLegalizationPattern<DimOp>::match(op))
return this->matchFailure();
auto dimOp = op->cast<DimOp>();
MemRefType type = dimOp->getOperand()->getType().cast<MemRefType>();
return isSupportedMemRefType(type) ? matchSuccess() : matchFailure();
}
SmallVector<Value *, 4> rewrite(Instruction *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
assert(operands.size() == 1 && "expected exactly one operand");
auto dimOp = op->cast<DimOp>();
MemRefType type = dimOp->getOperand()->getType().cast<MemRefType>();
SmallVector<Value *, 4> results;
auto shape = type.getShape();
uint64_t index = dimOp->getIndex();
// Extract dynamic size from the memref descriptor and define static size
// as a constant.
if (shape[index] == -1) {
// Find the position of the dynamic dimension in the list of dynamic sizes
// by counting the number of preceding dynamic dimensions. Start from 1
// because the buffer pointer is at position zero.
int64_t position = 1;
for (uint64_t i = 0; i < index; ++i) {
if (shape[i] == -1)
++position;
}
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
op->getLoc(), getIndexType(), operands,
getPositionAttribute(rewriter, position)));
} else {
results.push_back(
createIndexConstant(rewriter, op->getLoc(), shape[index]));
}
return results;
}
};
// Common base for load and store operations on MemRefs. Restricts the match
// to supported MemRef types. Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
@ -939,7 +983,7 @@ protected:
return ConversionListBuilder<
AddFOpLowering, AddIOpLowering, AllocOpLowering, BranchOpLowering,
Call0OpLowering, CallOpLowering, CmpIOpLowering, CondBranchOpLowering,
ConstLLVMOpLowering, DeallocOpLowering, DivISOpLowering,
ConstLLVMOpLowering, DeallocOpLowering, DimOpLowering, DivISOpLowering,
DivIUOpLowering, LoadOpLowering, MemRefCastOpLowering, MulFOpLowering,
MulIOpLowering, RemISOpLowering, RemIUOpLowering, ReturnOpLowering,
SelectOpLowering, StoreOpLowering, SubFOpLowering,

View File

@ -145,3 +145,18 @@ func @memref_cast(%static : memref<10x42xf32>, %dynamic : memref<?x?xf32>,
%6 = memref_cast %mixed : memref<42x?xf32> to memref<?x1xf32>
return
}
// CHECK-LABEL: func @memref_dim(%arg0: !llvm<"{ float*, i64, i64, i64 }">)
func @memref_dim(%mixed : memref<42x?x?x13x?xf32>) {
// CHECK-NEXT: %0 = "llvm.constant"() {value: 42 : index} : () -> !llvm<"i64">
%0 = dim %mixed, 0 : memref<42x?x?x13x?xf32>
// CHECK-NEXT: %1 = "llvm.extractvalue"(%arg0) {position: [1]} : (!llvm<"{ float*, i64, i64, i64 }">) -> !llvm<"i64">
%1 = dim %mixed, 1 : memref<42x?x?x13x?xf32>
// CHECK-NEXT: %2 = "llvm.extractvalue"(%arg0) {position: [2]} : (!llvm<"{ float*, i64, i64, i64 }">) -> !llvm<"i64">
%2 = dim %mixed, 2 : memref<42x?x?x13x?xf32>
// CHECK-NEXT: %3 = "llvm.constant"() {value: 13 : index} : () -> !llvm<"i64">
%3 = dim %mixed, 3 : memref<42x?x?x13x?xf32>
// CHECK-NEXT: %4 = "llvm.extractvalue"(%arg0) {position: [3]} : (!llvm<"{ float*, i64, i64, i64 }">) -> !llvm<"i64">
%4 = dim %mixed, 4 : memref<42x?x?x13x?xf32>
return
}