forked from OSchip/llvm-project
LLVM IR Dialect: port DimOp lowering from the translator
DimOp is converted to a constant LLVM IR dialect operation for static dimensions and to an access to the dynamic size info stored in the memref descriptor for the dynamic dimensions. This behavior is consistent with the existing mlir-translator. This completes the porting of MLIR -> LLVM lowering to the dialect conversion infrastructure. PiperOrigin-RevId: 233665634
This commit is contained in:
parent
2f11f86846
commit
465746f262
|
@ -712,6 +712,50 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
|
|||
}
|
||||
};
|
||||
|
||||
// A `dim` is converted to a constant for static sizes and to an access to the
|
||||
// size stored in the memref descriptor for dynamic sizes.
|
||||
struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
|
||||
using LLVMLegalizationPattern<DimOp>::LLVMLegalizationPattern;
|
||||
|
||||
PatternMatchResult match(Instruction *op) const override {
|
||||
if (!LLVMLegalizationPattern<DimOp>::match(op))
|
||||
return this->matchFailure();
|
||||
auto dimOp = op->cast<DimOp>();
|
||||
MemRefType type = dimOp->getOperand()->getType().cast<MemRefType>();
|
||||
return isSupportedMemRefType(type) ? matchSuccess() : matchFailure();
|
||||
}
|
||||
|
||||
SmallVector<Value *, 4> rewrite(Instruction *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
assert(operands.size() == 1 && "expected exactly one operand");
|
||||
auto dimOp = op->cast<DimOp>();
|
||||
MemRefType type = dimOp->getOperand()->getType().cast<MemRefType>();
|
||||
|
||||
SmallVector<Value *, 4> results;
|
||||
auto shape = type.getShape();
|
||||
uint64_t index = dimOp->getIndex();
|
||||
// Extract dynamic size from the memref descriptor and define static size
|
||||
// as a constant.
|
||||
if (shape[index] == -1) {
|
||||
// Find the position of the dynamic dimension in the list of dynamic sizes
|
||||
// by counting the number of preceding dynamic dimensions. Start from 1
|
||||
// because the buffer pointer is at position zero.
|
||||
int64_t position = 1;
|
||||
for (uint64_t i = 0; i < index; ++i) {
|
||||
if (shape[i] == -1)
|
||||
++position;
|
||||
}
|
||||
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
|
||||
op->getLoc(), getIndexType(), operands,
|
||||
getPositionAttribute(rewriter, position)));
|
||||
} else {
|
||||
results.push_back(
|
||||
createIndexConstant(rewriter, op->getLoc(), shape[index]));
|
||||
}
|
||||
return results;
|
||||
}
|
||||
};
|
||||
|
||||
// Common base for load and store operations on MemRefs. Restricts the match
|
||||
// to supported MemRef types. Provides functionality to emit code accessing a
|
||||
// specific element of the underlying data buffer.
|
||||
|
@ -939,7 +983,7 @@ protected:
|
|||
return ConversionListBuilder<
|
||||
AddFOpLowering, AddIOpLowering, AllocOpLowering, BranchOpLowering,
|
||||
Call0OpLowering, CallOpLowering, CmpIOpLowering, CondBranchOpLowering,
|
||||
ConstLLVMOpLowering, DeallocOpLowering, DivISOpLowering,
|
||||
ConstLLVMOpLowering, DeallocOpLowering, DimOpLowering, DivISOpLowering,
|
||||
DivIUOpLowering, LoadOpLowering, MemRefCastOpLowering, MulFOpLowering,
|
||||
MulIOpLowering, RemISOpLowering, RemIUOpLowering, ReturnOpLowering,
|
||||
SelectOpLowering, StoreOpLowering, SubFOpLowering,
|
||||
|
|
|
@ -145,3 +145,18 @@ func @memref_cast(%static : memref<10x42xf32>, %dynamic : memref<?x?xf32>,
|
|||
%6 = memref_cast %mixed : memref<42x?xf32> to memref<?x1xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @memref_dim(%arg0: !llvm<"{ float*, i64, i64, i64 }">)
|
||||
func @memref_dim(%mixed : memref<42x?x?x13x?xf32>) {
|
||||
// CHECK-NEXT: %0 = "llvm.constant"() {value: 42 : index} : () -> !llvm<"i64">
|
||||
%0 = dim %mixed, 0 : memref<42x?x?x13x?xf32>
|
||||
// CHECK-NEXT: %1 = "llvm.extractvalue"(%arg0) {position: [1]} : (!llvm<"{ float*, i64, i64, i64 }">) -> !llvm<"i64">
|
||||
%1 = dim %mixed, 1 : memref<42x?x?x13x?xf32>
|
||||
// CHECK-NEXT: %2 = "llvm.extractvalue"(%arg0) {position: [2]} : (!llvm<"{ float*, i64, i64, i64 }">) -> !llvm<"i64">
|
||||
%2 = dim %mixed, 2 : memref<42x?x?x13x?xf32>
|
||||
// CHECK-NEXT: %3 = "llvm.constant"() {value: 13 : index} : () -> !llvm<"i64">
|
||||
%3 = dim %mixed, 3 : memref<42x?x?x13x?xf32>
|
||||
// CHECK-NEXT: %4 = "llvm.extractvalue"(%arg0) {position: [3]} : (!llvm<"{ float*, i64, i64, i64 }">) -> !llvm<"i64">
|
||||
%4 = dim %mixed, 4 : memref<42x?x?x13x?xf32>
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue