forked from OSchip/llvm-project
[mlir] Lower RankOp to LLVM for unranked memrefs.
Differential Revision: https://reviews.llvm.org/D85273
This commit is contained in:
parent
f97019ad6e
commit
a3d427d30c
|
@ -2402,6 +2402,28 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> {
|
|||
}
|
||||
};
|
||||
|
||||
struct RankOpLowering : public ConvertOpToLLVMPattern<RankOp> {
|
||||
using ConvertOpToLLVMPattern<RankOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
Type operandType = cast<RankOp>(op).memrefOrTensor().getType();
|
||||
if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
|
||||
UnrankedMemRefDescriptor desc(RankOp::Adaptor(operands).memrefOrTensor());
|
||||
rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
|
||||
return success();
|
||||
}
|
||||
if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
|
||||
rewriter.replaceOp(
|
||||
op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
// Common base for load and store operations on MemRefs. Restricts the match
|
||||
// to supported MemRef types. Provides functionality to emit code accessing a
|
||||
// specific element of the underlying data buffer.
|
||||
|
@ -3272,6 +3294,7 @@ void mlir::populateStdToLLVMMemoryConversionPatterns(
|
|||
DimOpLowering,
|
||||
LoadOpLowering,
|
||||
MemRefCastOpLowering,
|
||||
RankOpLowering,
|
||||
StoreOpLowering,
|
||||
SubViewOpLowering,
|
||||
ViewOpLowering,
|
||||
|
|
|
@ -1291,3 +1291,26 @@ func @bfloat(%arg0: bf16) -> bf16 {
|
|||
func @memref_index(%arg0: memref<32xindex>) -> memref<32xindex> {
|
||||
return %arg0 : memref<32xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @rank_of_unranked
|
||||
// CHECK32-LABEL: func @rank_of_unranked
|
||||
func @rank_of_unranked(%unranked: memref<*xi32>) {
|
||||
%rank = rank %unranked : memref<*xi32>
|
||||
return
|
||||
}
|
||||
// CHECK-NEXT: llvm.mlir.undef
|
||||
// CHECK-NEXT: llvm.insertvalue
|
||||
// CHECK-NEXT: llvm.insertvalue
|
||||
// CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i8* }">
|
||||
// CHECK32: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i8* }">
|
||||
|
||||
// CHECK-LABEL: func @rank_of_ranked
|
||||
// CHECK32-LABEL: func @rank_of_ranked
|
||||
func @rank_of_ranked(%ranked: memref<?xi32>) {
|
||||
%rank = rank %ranked : memref<?xi32>
|
||||
return
|
||||
}
|
||||
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK32: llvm.mlir.constant(1 : index) : !llvm.i32
|
||||
|
|
Loading…
Reference in New Issue