[mlir][memref] Tighten verification of memref.reinterpret_cast

We allow the omission of a layout map in memref.reinterpret_cast under the
assumption that the cast produces an identity layout. This change adds
verification that the static knowledge present in the reinterpret_cast actually
supports this assumption.

Differential Revision: https://reviews.llvm.org/D116601
This commit is contained in:
Stephan Herhut 2022-01-07 09:53:14 +01:00
parent e3c84fb948
commit 33cec20dbd
5 changed files with 81 additions and 30 deletions

View File

@ -1155,40 +1155,44 @@ static LogicalResult verify(ReinterpretCastOp op) {
extractFromI64ArrayAttr(op.static_sizes())))) {
int64_t resultSize = std::get<0>(en.value());
int64_t expectedSize = std::get<1>(en.value());
if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
if (!ShapedType::isDynamic(resultSize) &&
!ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
return op.emitError("expected result type with size = ")
<< expectedSize << " instead of " << resultSize
<< " in dim = " << en.index();
}
// Match offset and strides in static_offset and static_strides attributes if
// result memref type has an affine map specified.
if (!resultType.getLayout().isIdentity()) {
int64_t resultOffset;
SmallVector<int64_t, 4> resultStrides;
if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
return failure();
// Match offset and strides in static_offset and static_strides attributes. If
// result memref type has no affine map specified, this will assume an
// identity layout.
int64_t resultOffset;
SmallVector<int64_t, 4> resultStrides;
if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
return op.emitError(
"expected result type to have strided layout but found ")
<< resultType;
// Match offset in result memref type and in static_offsets attribute.
int64_t expectedOffset =
extractFromI64ArrayAttr(op.static_offsets()).front();
if (!ShapedType::isDynamicStrideOrOffset(resultOffset) &&
resultOffset != expectedOffset)
return op.emitError("expected result type with offset = ")
<< resultOffset << " instead of " << expectedOffset;
// Match offset in result memref type and in static_offsets attribute.
int64_t expectedOffset = extractFromI64ArrayAttr(op.static_offsets()).front();
if (!ShapedType::isDynamicStrideOrOffset(resultOffset) &&
!ShapedType::isDynamicStrideOrOffset(expectedOffset) &&
resultOffset != expectedOffset)
return op.emitError("expected result type with offset = ")
<< resultOffset << " instead of " << expectedOffset;
// Match strides in result memref type and in static_strides attribute.
for (auto &en : llvm::enumerate(llvm::zip(
resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
int64_t resultStride = std::get<0>(en.value());
int64_t expectedStride = std::get<1>(en.value());
if (!ShapedType::isDynamicStrideOrOffset(resultStride) &&
resultStride != expectedStride)
return op.emitError("expected result type with stride = ")
<< expectedStride << " instead of " << resultStride
<< " in dim = " << en.index();
}
// Match strides in result memref type and in static_strides attribute.
for (auto &en : llvm::enumerate(llvm::zip(
resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
int64_t resultStride = std::get<0>(en.value());
int64_t expectedStride = std::get<1>(en.value());
if (!ShapedType::isDynamicStrideOrOffset(resultStride) &&
!ShapedType::isDynamicStrideOrOffset(expectedStride) &&
resultStride != expectedStride)
return op.emitError("expected result type with stride = ")
<< expectedStride << " instead of " << resultStride
<< " in dim = " << en.index();
}
return success();
}

View File

@ -151,7 +151,7 @@ func @multiple_reducing_dims_all_dynamic(%arg0 : memref<?x?x?xf32, offset: ?, st
// CHECK: return %[[SIZE]] : index
// Takes the dim of a reinterpret_cast result; the cast must use unit stride
// and zero offset so the identity-layout result type passes verification.
func @dim_of_sized_view(%arg : memref<?xi8>, %size: index) -> index {
%c0 = arith.constant 0 : index
%0 = memref.reinterpret_cast %arg to offset: [0], sizes: [%size], strides: [1] : memref<?xi8> to memref<?xi8>
%1 = memref.dim %0, %c0 : memref<?xi8>
return %1 : index
}

View File

@ -208,6 +208,44 @@ func @memref_reinterpret_cast_offset_mismatch(%in: memref<?xf32>) {
// -----
// A result type without a layout map implies an identity layout (offset 0),
// so a static offset of 2 must be rejected by the verifier.
func @memref_reinterpret_cast_no_map_but_offset(%in: memref<?xf32>) {
// expected-error @+1 {{expected result type with offset = 0 instead of 2}}
%out = memref.reinterpret_cast %in to offset: [2], sizes: [10], strides: [1]
: memref<?xf32> to memref<10xf32>
return
}
// -----
// An identity layout on a 1-D memref<10xf32> implies stride 1, so a static
// stride of 10 must be rejected by the verifier.
func @memref_reinterpret_cast_no_map_but_stride(%in: memref<?xf32>) {
// expected-error @+1 {{expected result type with stride = 10 instead of 1 in dim = 0}}
%out = memref.reinterpret_cast %in to offset: [0], sizes: [10], strides: [10]
: memref<?xf32> to memref<10xf32>
return
}
// -----
// For an identity layout on memref<9x10xf32> the row stride is 10; a static
// stride of 42 in dim 0 must be rejected by the verifier.
func @memref_reinterpret_cast_no_map_but_strides(%in: memref<?x?xf32>) {
// expected-error @+1 {{expected result type with stride = 42 instead of 10 in dim = 0}}
%out = memref.reinterpret_cast %in to
offset: [0], sizes: [9, 10], strides: [42, 1]
: memref<?x?xf32> to memref<9x10xf32>
return
}
// -----
// The result type's affine map drops a dimension and therefore has no strided
// form; the verifier must report that a strided layout was expected.
func @memref_reinterpret_cast_non_strided_layout(%in: memref<?x?xf32>) {
// expected-error @+1 {{expected result type to have strided layout but found 'memref<9x10xf32, affine_map<(d0, d1) -> (d0)>>}}
%out = memref.reinterpret_cast %in to
offset: [0], sizes: [9, 10], strides: [42, 1]
: memref<?x?xf32> to memref<9x10xf32, affine_map<(d0, d1) -> (d0)>>
return
}
// -----
func @memref_reshape_element_type_mismatch(
%buf: memref<*xf32>, %shape: memref<1xi32>) {
// expected-error @+1 {{element types of source and destination memref types should be the same}}

View File

@ -27,6 +27,15 @@ func @memref_reinterpret_cast_static_to_dynamic_sizes(%in: memref<?xf32>)
return %out : memref<10x?xf32, offset: ?, strides: [?, 1]>
}
// CHECK-LABEL: func @memref_reinterpret_cast_dynamic_offset
// Positive round-trip test: a dynamic offset operand is compatible with a
// result type whose layout declares a dynamic offset, so this must verify.
func @memref_reinterpret_cast_dynamic_offset(%in: memref<?xf32>, %offset: index)
-> memref<10x?xf32, offset: ?, strides: [?, 1]> {
%out = memref.reinterpret_cast %in to
offset: [%offset], sizes: [10, 10], strides: [1, 1]
: memref<?xf32> to memref<10x?xf32, offset: ?, strides: [?, 1]>
return %out : memref<10x?xf32, offset: ?, strides: [?, 1]>
}
// CHECK-LABEL: func @memref_reshape(
func @memref_reshape(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
%shape2: memref<2xi32>, %shape3: memref<?xi32>) -> memref<*xf32> {

View File

@ -35,9 +35,9 @@ func @main() -> () {
// CHECK-NEXT: [3, 4, 5]
%copy_two = memref.alloc() : memref<3x2xf32>
%copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2,3], strides:[1, 2]
: memref<3x2xf32> to memref<2x3xf32>
memref.copy %input, %copy_two_casted : memref<2x3xf32> to memref<2x3xf32>
%copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2, 3], strides:[1, 2]
: memref<3x2xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
memref.copy %input, %copy_two_casted : memref<2x3xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
%unranked_copy_two = memref.cast %copy_two : memref<3x2xf32> to memref<*xf32>
call @print_memref_f32(%unranked_copy_two) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1]