[mlir] MemRefToLLVM: convert memref.view operations for empty memrefs

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D126094
This commit is contained in:
Eugene Zhulenev 2022-05-20 14:59:23 -07:00
parent 8801a5d185
commit 705f048cbb
2 changed files with 29 additions and 3 deletions

View File

@ -1848,6 +1848,12 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
assert(offset == 0 && "expected offset to be 0");
// Target memref must be contiguous in memory (innermost stride is 1), or
// empty (special case when at least one of the memref dimensions is 0).
if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
return viewOp.emitWarning("cannot cast to non-contiguous shape"),
failure();
// Create the descriptor.
MemRefDescriptor sourceMemRef(adaptor.source());
auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
@ -1884,9 +1890,6 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
return rewriter.replaceOp(viewOp, {targetMemRef}), success();
// Fields 4 and 5: Update sizes and strides.
if (strides.back() != 1)
return viewOp.emitWarning("cannot cast to non-contiguous shape"),
failure();
Value stride = nullptr, nextSize = nullptr;
for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
// Update size.

View File

@ -89,6 +89,29 @@ func.func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
// -----
// CHECK-LABL: func @view_empty_memref(
// CHECK: %[[ARG0:.*]]: index,
// CHECK: %[[ARG1:.*]]: memref<0xi8>)
func.func @view_empty_memref(%offset: index, %mem: memref<0xi8>) {
// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.mlir.constant(0 : index) : i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.mlir.constant(4 : index) : i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.mlir.constant(0 : index) : i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.mlir.constant(0 : index) : i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.mlir.constant(0 : index) : i64
// CHECK: = llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
%0 = memref.view %mem[%offset][] : memref<0xi8> to memref<0x4xf32>
return
}
// -----
// CHECK-LABEL: func @subview(
// CHECK: %[[MEM:.*]]: memref<{{.*}}>,
// CHECK: %[[ARG0f:[a-zA-Z0-9]*]]: index,