From 465746f262af5aa17af4fe09f31b93a1d7ad75f5 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Tue, 12 Feb 2019 13:47:25 -0800 Subject: [PATCH] LLVM IR Dialect: port DimOp lowering from the translator DimOp is converted to a constant LLVM IR dialect operation for static dimensions and to an access to the dynamic size info stored in the memref descriptor for the dynamic dimensions. This behavior is consistent with the existing mlir-translator. This completes the porting of MLIR -> LLVM lowering to the dialect conversion infrastructure. PiperOrigin-RevId: 233665634 --- .../Transforms/ConvertToLLVMDialect.cpp | 46 ++++++++++++++++++- mlir/test/LLVMIR/convert-memref-ops.mlir | 15 ++++++ 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp index 6c3954f8492d..22b7bc50cfc7 100644 --- a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp +++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp @@ -712,6 +712,50 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern { } }; +// A `dim` is converted to a constant for static sizes and to an access to the +// size stored in the memref descriptor for dynamic sizes. +struct DimOpLowering : public LLVMLegalizationPattern { + using LLVMLegalizationPattern::LLVMLegalizationPattern; + + PatternMatchResult match(Instruction *op) const override { + if (!LLVMLegalizationPattern::match(op)) + return this->matchFailure(); + auto dimOp = op->cast(); + MemRefType type = dimOp->getOperand()->getType().cast(); + return isSupportedMemRefType(type) ? matchSuccess() : matchFailure(); + } + + SmallVector rewrite(Instruction *op, ArrayRef operands, + FuncBuilder &rewriter) const override { + assert(operands.size() == 1 && "expected exactly one operand"); + auto dimOp = op->cast(); + MemRefType type = dimOp->getOperand()->getType().cast(); + + SmallVector results; + auto shape = type.getShape(); + uint64_t index = dimOp->getIndex(); + // Extract dynamic size from the memref descriptor and define static size + // as a constant. + if (shape[index] == -1) { + // Find the position of the dynamic dimension in the list of dynamic sizes + // by counting the number of preceding dynamic dimensions. Start from 1 + // because the buffer pointer is at position zero. + int64_t position = 1; + for (uint64_t i = 0; i < index; ++i) { + if (shape[i] == -1) + ++position; + } + results.push_back(rewriter.create( + op->getLoc(), getIndexType(), operands, + getPositionAttribute(rewriter, position))); + } else { + results.push_back( + createIndexConstant(rewriter, op->getLoc(), shape[index])); + } + return results; + } +}; + // Common base for load and store operations on MemRefs. Restricts the match // to supported MemRef types. Provides functionality to emit code accessing a // specific element of the underlying data buffer. @@ -939,7 +983,7 @@ protected: return ConversionListBuilder< AddFOpLowering, AddIOpLowering, AllocOpLowering, BranchOpLowering, Call0OpLowering, CallOpLowering, CmpIOpLowering, CondBranchOpLowering, - ConstLLVMOpLowering, DeallocOpLowering, DivISOpLowering, + ConstLLVMOpLowering, DeallocOpLowering, DimOpLowering, DivISOpLowering, DivIUOpLowering, LoadOpLowering, MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, RemISOpLowering, RemIUOpLowering, ReturnOpLowering, SelectOpLowering, StoreOpLowering, SubFOpLowering, diff --git a/mlir/test/LLVMIR/convert-memref-ops.mlir b/mlir/test/LLVMIR/convert-memref-ops.mlir index bec1bb9ad112..7bc4b13db3cc 100644 --- a/mlir/test/LLVMIR/convert-memref-ops.mlir +++ b/mlir/test/LLVMIR/convert-memref-ops.mlir @@ -145,3 +145,18 @@ func @memref_cast(%static : memref<10x42xf32>, %dynamic : memref, %6 = memref_cast %mixed : memref<42x?xf32> to memref return } + +// CHECK-LABEL: func @memref_dim(%arg0: !llvm<"{ float*, i64, i64, i64 }">) +func @memref_dim(%mixed : memref<42x?x?x13x?xf32>) { +// CHECK-NEXT: %0 = "llvm.constant"() {value: 42 : index} : () -> !llvm<"i64"> + %0 = dim %mixed, 0 : memref<42x?x?x13x?xf32> +// CHECK-NEXT: %1 = "llvm.extractvalue"(%arg0) {position: [1]} : (!llvm<"{ float*, i64, i64, i64 }">) -> !llvm<"i64"> + %1 = dim %mixed, 1 : memref<42x?x?x13x?xf32> +// CHECK-NEXT: %2 = "llvm.extractvalue"(%arg0) {position: [2]} : (!llvm<"{ float*, i64, i64, i64 }">) -> !llvm<"i64"> + %2 = dim %mixed, 2 : memref<42x?x?x13x?xf32> +// CHECK-NEXT: %3 = "llvm.constant"() {value: 13 : index} : () -> !llvm<"i64"> + %3 = dim %mixed, 3 : memref<42x?x?x13x?xf32> +// CHECK-NEXT: %4 = "llvm.extractvalue"(%arg0) {position: [3]} : (!llvm<"{ float*, i64, i64, i64 }">) -> !llvm<"i64"> + %4 = dim %mixed, 4 : memref<42x?x?x13x?xf32> + return +}