forked from OSchip/llvm-project
[mlir] MemRefToLLVM: convert memref.view operations for empty memrefs
Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D126094
This commit is contained in:
parent
8801a5d185
commit
705f048cbb
|
@ -1848,6 +1848,12 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
|
|||
return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
|
||||
assert(offset == 0 && "expected offset to be 0");
|
||||
|
||||
// Target memref must be contiguous in memory (innermost stride is 1), or
|
||||
// empty (special case when at least one of the memref dimensions is 0).
|
||||
if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
|
||||
return viewOp.emitWarning("cannot cast to non-contiguous shape"),
|
||||
failure();
|
||||
|
||||
// Create the descriptor.
|
||||
MemRefDescriptor sourceMemRef(adaptor.source());
|
||||
auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
|
||||
|
@ -1884,9 +1890,6 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
|
|||
return rewriter.replaceOp(viewOp, {targetMemRef}), success();
|
||||
|
||||
// Fields 4 and 5: Update sizes and strides.
|
||||
if (strides.back() != 1)
|
||||
return viewOp.emitWarning("cannot cast to non-contiguous shape"),
|
||||
failure();
|
||||
Value stride = nullptr, nextSize = nullptr;
|
||||
for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
|
||||
// Update size.
|
||||
|
|
|
@ -89,6 +89,29 @@ func.func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABL: func @view_empty_memref(
|
||||
// CHECK: %[[ARG0:.*]]: index,
|
||||
// CHECK: %[[ARG1:.*]]: memref<0xi8>)
|
||||
func.func @view_empty_memref(%offset: index, %mem: memref<0xi8>) {
|
||||
|
||||
// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: llvm.mlir.constant(0 : index) : i64
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: llvm.mlir.constant(4 : index) : i64
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: llvm.mlir.constant(0 : index) : i64
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: llvm.mlir.constant(0 : index) : i64
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: llvm.mlir.constant(0 : index) : i64
|
||||
// CHECK: = llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
%0 = memref.view %mem[%offset][] : memref<0xi8> to memref<0x4xf32>
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @subview(
|
||||
// CHECK: %[[MEM:.*]]: memref<{{.*}}>,
|
||||
// CHECK: %[[ARG0f:[a-zA-Z0-9]*]]: index,
|
||||
|
|
Loading…
Reference in New Issue