forked from OSchip/llvm-project
[mlir][memref] Tighten verification of memref.reinterpret_cast
We allow the omission of a map in memref.reinterpret_cast under the assumption, that the cast might cast to an identity layout. This change adds verification that the static knowledge that is present in the reinterpret_cast supports this assumption. Differential Revision: https://reviews.llvm.org/D116601
This commit is contained in:
parent
e3c84fb948
commit
33cec20dbd
|
@ -1155,40 +1155,44 @@ static LogicalResult verify(ReinterpretCastOp op) {
|
|||
extractFromI64ArrayAttr(op.static_sizes())))) {
|
||||
int64_t resultSize = std::get<0>(en.value());
|
||||
int64_t expectedSize = std::get<1>(en.value());
|
||||
if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
|
||||
if (!ShapedType::isDynamic(resultSize) &&
|
||||
!ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
|
||||
return op.emitError("expected result type with size = ")
|
||||
<< expectedSize << " instead of " << resultSize
|
||||
<< " in dim = " << en.index();
|
||||
}
|
||||
|
||||
// Match offset and strides in static_offset and static_strides attributes if
|
||||
// result memref type has an affine map specified.
|
||||
if (!resultType.getLayout().isIdentity()) {
|
||||
int64_t resultOffset;
|
||||
SmallVector<int64_t, 4> resultStrides;
|
||||
if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
|
||||
return failure();
|
||||
// Match offset and strides in static_offset and static_strides attributes. If
|
||||
// result memref type has no affine map specified, this will assume an
|
||||
// identity layout.
|
||||
int64_t resultOffset;
|
||||
SmallVector<int64_t, 4> resultStrides;
|
||||
if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
|
||||
return op.emitError(
|
||||
"expected result type to have strided layout but found ")
|
||||
<< resultType;
|
||||
|
||||
// Match offset in result memref type and in static_offsets attribute.
|
||||
int64_t expectedOffset =
|
||||
extractFromI64ArrayAttr(op.static_offsets()).front();
|
||||
if (!ShapedType::isDynamicStrideOrOffset(resultOffset) &&
|
||||
resultOffset != expectedOffset)
|
||||
return op.emitError("expected result type with offset = ")
|
||||
<< resultOffset << " instead of " << expectedOffset;
|
||||
// Match offset in result memref type and in static_offsets attribute.
|
||||
int64_t expectedOffset = extractFromI64ArrayAttr(op.static_offsets()).front();
|
||||
if (!ShapedType::isDynamicStrideOrOffset(resultOffset) &&
|
||||
!ShapedType::isDynamicStrideOrOffset(expectedOffset) &&
|
||||
resultOffset != expectedOffset)
|
||||
return op.emitError("expected result type with offset = ")
|
||||
<< resultOffset << " instead of " << expectedOffset;
|
||||
|
||||
// Match strides in result memref type and in static_strides attribute.
|
||||
for (auto &en : llvm::enumerate(llvm::zip(
|
||||
resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
|
||||
int64_t resultStride = std::get<0>(en.value());
|
||||
int64_t expectedStride = std::get<1>(en.value());
|
||||
if (!ShapedType::isDynamicStrideOrOffset(resultStride) &&
|
||||
resultStride != expectedStride)
|
||||
return op.emitError("expected result type with stride = ")
|
||||
<< expectedStride << " instead of " << resultStride
|
||||
<< " in dim = " << en.index();
|
||||
}
|
||||
// Match strides in result memref type and in static_strides attribute.
|
||||
for (auto &en : llvm::enumerate(llvm::zip(
|
||||
resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
|
||||
int64_t resultStride = std::get<0>(en.value());
|
||||
int64_t expectedStride = std::get<1>(en.value());
|
||||
if (!ShapedType::isDynamicStrideOrOffset(resultStride) &&
|
||||
!ShapedType::isDynamicStrideOrOffset(expectedStride) &&
|
||||
resultStride != expectedStride)
|
||||
return op.emitError("expected result type with stride = ")
|
||||
<< expectedStride << " instead of " << resultStride
|
||||
<< " in dim = " << en.index();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -151,7 +151,7 @@ func @multiple_reducing_dims_all_dynamic(%arg0 : memref<?x?x?xf32, offset: ?, st
|
|||
// CHECK: return %[[SIZE]] : index
|
||||
func @dim_of_sized_view(%arg : memref<?xi8>, %size: index) -> index {
|
||||
%c0 = arith.constant 0 : index
|
||||
%0 = memref.reinterpret_cast %arg to offset: [0], sizes: [%size], strides: [0] : memref<?xi8> to memref<?xi8>
|
||||
%0 = memref.reinterpret_cast %arg to offset: [0], sizes: [%size], strides: [1] : memref<?xi8> to memref<?xi8>
|
||||
%1 = memref.dim %0, %c0 : memref<?xi8>
|
||||
return %1 : index
|
||||
}
|
||||
|
|
|
@ -208,6 +208,44 @@ func @memref_reinterpret_cast_offset_mismatch(%in: memref<?xf32>) {
|
|||
|
||||
// -----
|
||||
|
||||
func @memref_reinterpret_cast_no_map_but_offset(%in: memref<?xf32>) {
|
||||
// expected-error @+1 {{expected result type with offset = 0 instead of 2}}
|
||||
%out = memref.reinterpret_cast %in to offset: [2], sizes: [10], strides: [1]
|
||||
: memref<?xf32> to memref<10xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @memref_reinterpret_cast_no_map_but_stride(%in: memref<?xf32>) {
|
||||
// expected-error @+1 {{expected result type with stride = 10 instead of 1 in dim = 0}}
|
||||
%out = memref.reinterpret_cast %in to offset: [0], sizes: [10], strides: [10]
|
||||
: memref<?xf32> to memref<10xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @memref_reinterpret_cast_no_map_but_strides(%in: memref<?x?xf32>) {
|
||||
// expected-error @+1 {{expected result type with stride = 42 instead of 10 in dim = 0}}
|
||||
%out = memref.reinterpret_cast %in to
|
||||
offset: [0], sizes: [9, 10], strides: [42, 1]
|
||||
: memref<?x?xf32> to memref<9x10xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @memref_reinterpret_cast_non_strided_layout(%in: memref<?x?xf32>) {
|
||||
// expected-error @+1 {{expected result type to have strided layout but found 'memref<9x10xf32, affine_map<(d0, d1) -> (d0)>>}}
|
||||
%out = memref.reinterpret_cast %in to
|
||||
offset: [0], sizes: [9, 10], strides: [42, 1]
|
||||
: memref<?x?xf32> to memref<9x10xf32, affine_map<(d0, d1) -> (d0)>>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @memref_reshape_element_type_mismatch(
|
||||
%buf: memref<*xf32>, %shape: memref<1xi32>) {
|
||||
// expected-error @+1 {{element types of source and destination memref types should be the same}}
|
||||
|
|
|
@ -27,6 +27,15 @@ func @memref_reinterpret_cast_static_to_dynamic_sizes(%in: memref<?xf32>)
|
|||
return %out : memref<10x?xf32, offset: ?, strides: [?, 1]>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @memref_reinterpret_cast_dynamic_offset
|
||||
func @memref_reinterpret_cast_dynamic_offset(%in: memref<?xf32>, %offset: index)
|
||||
-> memref<10x?xf32, offset: ?, strides: [?, 1]> {
|
||||
%out = memref.reinterpret_cast %in to
|
||||
offset: [%offset], sizes: [10, 10], strides: [1, 1]
|
||||
: memref<?xf32> to memref<10x?xf32, offset: ?, strides: [?, 1]>
|
||||
return %out : memref<10x?xf32, offset: ?, strides: [?, 1]>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @memref_reshape(
|
||||
func @memref_reshape(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
|
||||
%shape2: memref<2xi32>, %shape3: memref<?xi32>) -> memref<*xf32> {
|
||||
|
|
|
@ -35,9 +35,9 @@ func @main() -> () {
|
|||
// CHECK-NEXT: [3, 4, 5]
|
||||
|
||||
%copy_two = memref.alloc() : memref<3x2xf32>
|
||||
%copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2,3], strides:[1, 2]
|
||||
: memref<3x2xf32> to memref<2x3xf32>
|
||||
memref.copy %input, %copy_two_casted : memref<2x3xf32> to memref<2x3xf32>
|
||||
%copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2, 3], strides:[1, 2]
|
||||
: memref<3x2xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
|
||||
memref.copy %input, %copy_two_casted : memref<2x3xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
|
||||
%unranked_copy_two = memref.cast %copy_two : memref<3x2xf32> to memref<*xf32>
|
||||
call @print_memref_f32(%unranked_copy_two) : (memref<*xf32>) -> ()
|
||||
// CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1]
|
||||
|
|
Loading…
Reference in New Issue