forked from OSchip/llvm-project
[mlir][memref] ReinterpretCast: allow static sizes/strides/offset where affine map expects dynamic
* There is no reason to forbid that case * Also, user will get very unfriendly error like `expected result type with offset = -9223372036854775808 instead of 1` Differential Revision: https://reviews.llvm.org/D114678
This commit is contained in:
parent
6f1a501fdd
commit
28ab10f404
|
@ -1158,7 +1158,7 @@ static LogicalResult verify(ReinterpretCastOp op) {
|
|||
extractFromI64ArrayAttr(op.static_sizes())))) {
|
||||
int64_t resultSize = std::get<0>(en.value());
|
||||
int64_t expectedSize = std::get<1>(en.value());
|
||||
if (resultSize != expectedSize)
|
||||
if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
|
||||
return op.emitError("expected result type with size = ")
|
||||
<< expectedSize << " instead of " << resultSize
|
||||
<< " in dim = " << en.index();
|
||||
|
@ -1175,7 +1175,8 @@ static LogicalResult verify(ReinterpretCastOp op) {
|
|||
// Match offset in result memref type and in static_offsets attribute.
|
||||
int64_t expectedOffset =
|
||||
extractFromI64ArrayAttr(op.static_offsets()).front();
|
||||
if (resultOffset != expectedOffset)
|
||||
if (!ShapedType::isDynamicStrideOrOffset(resultOffset) &&
|
||||
resultOffset != expectedOffset)
|
||||
return op.emitError("expected result type with offset = ")
|
||||
<< resultOffset << " instead of " << expectedOffset;
|
||||
|
||||
|
@ -1184,7 +1185,8 @@ static LogicalResult verify(ReinterpretCastOp op) {
|
|||
resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
|
||||
int64_t resultStride = std::get<0>(en.value());
|
||||
int64_t expectedStride = std::get<1>(en.value());
|
||||
if (resultStride != expectedStride)
|
||||
if (!ShapedType::isDynamicStrideOrOffset(resultStride) &&
|
||||
resultStride != expectedStride)
|
||||
return op.emitError("expected result type with stride = ")
|
||||
<< expectedStride << " instead of " << resultStride
|
||||
<< " in dim = " << en.index();
|
||||
|
|
|
@ -208,18 +208,6 @@ func @memref_reinterpret_cast_offset_mismatch(%in: memref<?xf32>) {
|
|||
|
||||
// -----
|
||||
|
||||
func @memref_reinterpret_cast_offset_mismatch(%in: memref<?xf32>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c10 = arith.constant 10 : index
|
||||
// expected-error @+1 {{expected result type with size = 10 instead of -1 in dim = 0}}
|
||||
%out = memref.reinterpret_cast %in to
|
||||
offset: [%c0], sizes: [10, %c10], strides: [%c10, 1]
|
||||
: memref<?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @memref_reshape_element_type_mismatch(
|
||||
%buf: memref<*xf32>, %shape: memref<1xi32>) {
|
||||
// expected-error @+1 {{element types of source and destination memref types should be the same}}
|
||||
|
|
|
@ -18,6 +18,15 @@ func @memref_reinterpret_cast(%in: memref<?xf32>)
|
|||
return %out : memref<10x?xf32, offset: ?, strides: [?, 1]>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @memref_reinterpret_cast_static_to_dynamic_sizes
|
||||
func @memref_reinterpret_cast_static_to_dynamic_sizes(%in: memref<?xf32>)
|
||||
-> memref<10x?xf32, offset: ?, strides: [?, 1]> {
|
||||
%out = memref.reinterpret_cast %in to
|
||||
offset: [1], sizes: [10, 10], strides: [1, 1]
|
||||
: memref<?xf32> to memref<10x?xf32, offset: ?, strides: [?, 1]>
|
||||
return %out : memref<10x?xf32, offset: ?, strides: [?, 1]>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @memref_reshape(
|
||||
func @memref_reshape(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
|
||||
%shape2: memref<2xi32>, %shape3: memref<?xi32>) -> memref<*xf32> {
|
||||
|
|
Loading…
Reference in New Issue