[mlir] Lower RankOp to LLVM for unranked memrefs.

Differential Revision: https://reviews.llvm.org/D85273
This commit is contained in:
Alexander Belyaev 2020-08-05 12:12:45 +02:00
parent f97019ad6e
commit a3d427d30c
2 changed files with 46 additions and 0 deletions

View File

@ -2402,6 +2402,28 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> {
}
};
struct RankOpLowering : public ConvertOpToLLVMPattern<RankOp> {
using ConvertOpToLLVMPattern<RankOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Type operandType = cast<RankOp>(op).memrefOrTensor().getType();
if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
UnrankedMemRefDescriptor desc(RankOp::Adaptor(operands).memrefOrTensor());
rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
return success();
}
if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
rewriter.replaceOp(
op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
return success();
}
return failure();
}
};
// Common base for load and store operations on MemRefs. Restricts the match
// to supported MemRef types. Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
@ -3272,6 +3294,7 @@ void mlir::populateStdToLLVMMemoryConversionPatterns(
DimOpLowering,
LoadOpLowering,
MemRefCastOpLowering,
RankOpLowering,
StoreOpLowering,
SubViewOpLowering,
ViewOpLowering,

View File

@ -1291,3 +1291,26 @@ func @bfloat(%arg0: bf16) -> bf16 {
func @memref_index(%arg0: memref<32xindex>) -> memref<32xindex> {
return %arg0 : memref<32xindex>
}
// -----
// CHECK-LABEL: func @rank_of_unranked
// CHECK32-LABEL: func @rank_of_unranked
func @rank_of_unranked(%unranked: memref<*xi32>) {
%rank = rank %unranked : memref<*xi32>
return
}
// CHECK-NEXT: llvm.mlir.undef
// CHECK-NEXT: llvm.insertvalue
// CHECK-NEXT: llvm.insertvalue
// CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i8* }">
// CHECK32: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i8* }">
// CHECK-LABEL: func @rank_of_ranked
// CHECK32-LABEL: func @rank_of_ranked
func @rank_of_ranked(%ranked: memref<?xi32>) {
%rank = rank %ranked : memref<?xi32>
return
}
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK32: llvm.mlir.constant(1 : index) : !llvm.i32