forked from OSchip/llvm-project
Fix pretty printer corner case in mlir_runner_utils.cpp.
In the particular case where the size of a memref dimension is 1, double printing would happen because printLast was called unconditionally. This CL fixes the print and updates an incorrect test that should have caught this in the first place. PiperOrigin-RevId: 281345142
This commit is contained in:
parent
c017704cd9
commit
3732ba4def
|
@ -86,14 +86,6 @@ template <typename T> struct MemRefDataPrinter<T, 0> {
|
|||
int64_t *sizes = nullptr, int64_t *strides = nullptr);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
static void printNewLineIfVector(std::ostream &os, T &t) {}
|
||||
|
||||
template <typename T, int Dim, int... Dims>
|
||||
static void printNewLineIfVector(std::ostream &os, Vector<T, Dim, Dims...> &t) {
|
||||
os << "\n";
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
void MemRefDataPrinter<T, N>::printFirst(std::ostream &os, T *base,
|
||||
int64_t rank, int64_t offset,
|
||||
|
@ -101,10 +93,12 @@ void MemRefDataPrinter<T, N>::printFirst(std::ostream &os, T *base,
|
|||
os << "[";
|
||||
MemRefDataPrinter<T, N - 1>::print(os, base, rank, offset, sizes + 1,
|
||||
strides + 1);
|
||||
if (sizes[0] > 0) {
|
||||
os << ", ";
|
||||
printNewLineIfVector(os, *base);
|
||||
// If single element, close square bracket and return early.
|
||||
if (sizes[0] <= 1) {
|
||||
os << "]";
|
||||
return;
|
||||
}
|
||||
os << ", ";
|
||||
if (N > 1)
|
||||
os << "\n";
|
||||
}
|
||||
|
@ -116,13 +110,14 @@ void MemRefDataPrinter<T, N>::print(std::ostream &os, T *base, int64_t rank,
|
|||
printFirst(os, base, rank, offset, sizes, strides);
|
||||
for (unsigned i = 1; i + 1 < sizes[0]; ++i) {
|
||||
printSpace(os, rank - N + 1);
|
||||
MemRefDataPrinter<T, N - 1>::print(os, base, rank, offset + i * (*strides),
|
||||
MemRefDataPrinter<T, N - 1>::print(os, base, rank, offset + i * strides[0],
|
||||
sizes + 1, strides + 1);
|
||||
os << ", ";
|
||||
printNewLineIfVector(os, *base);
|
||||
if (N > 1)
|
||||
os << "\n";
|
||||
}
|
||||
if (sizes[0] <= 1)
|
||||
return;
|
||||
printLast(os, base, rank, offset, sizes, strides);
|
||||
}
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ func @print_0d() {
|
|||
dealloc %A : memref<f32>
|
||||
return
|
||||
}
|
||||
// PRINT-0D: Memref base@ = {{.*}} rank = 0 offset = 0 data = [2]
|
||||
|
||||
func @print_1d() {
|
||||
%f = constant 2.00000e+00 : f32
|
||||
|
@ -21,6 +22,8 @@ func @print_1d() {
|
|||
dealloc %A : memref<16xf32>
|
||||
return
|
||||
}
|
||||
// PRINT-1D: Memref base@ = {{.*}} rank = 1 offset = 0 sizes = [16] strides = [1] data =
|
||||
// PRINT-1D-NEXT: [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
|
||||
|
||||
func @print_3d() {
|
||||
%f = constant 2.00000e+00 : f32
|
||||
|
@ -36,16 +39,6 @@ func @print_3d() {
|
|||
dealloc %A : memref<3x4x5xf32>
|
||||
return
|
||||
}
|
||||
|
||||
func @print_memref_0d_f32(memref<f32>)
|
||||
func @print_memref_1d_f32(memref<?xf32>)
|
||||
func @print_memref_3d_f32(memref<?x?x?xf32>)
|
||||
|
||||
// PRINT-0D: Memref base@ = {{.*}} rank = 0 offset = 0 data = [2]
|
||||
|
||||
// PRINT-1D: Memref base@ = {{.*}} rank = 1 offset = 0 sizes = [16] strides = [1] data =
|
||||
// PRINT-1D-NEXT: [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
|
||||
|
||||
// PRINT-3D: Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [3, 4, 5] strides = [20, 5, 1] data =
|
||||
// PRINT-3D-COUNT-4: {{.*[[:space:]].*}}2, 2, 2, 2, 2
|
||||
// PRINT-3D-COUNT-4: {{.*[[:space:]].*}}2, 2, 2, 2, 2
|
||||
|
@ -53,6 +46,11 @@ func @print_memref_3d_f32(memref<?x?x?xf32>)
|
|||
// PRINT-3D-NEXT: 2, 2, 4, 2, 2
|
||||
// PRINT-3D-NEXT: 2, 2, 2, 2, 2
|
||||
|
||||
func @print_memref_0d_f32(memref<f32>)
|
||||
func @print_memref_1d_f32(memref<?xf32>)
|
||||
func @print_memref_3d_f32(memref<?x?x?xf32>)
|
||||
|
||||
|
||||
!vector_type_C = type vector<4x4xf32>
|
||||
!matrix_type_CC = type memref<1x1x!vector_type_C>
|
||||
func @vector_splat_2d() {
|
||||
|
@ -70,9 +68,6 @@ func @vector_splat_2d() {
|
|||
}
|
||||
|
||||
// PRINT-VECTOR-SPLAT-2D: Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [1, 1] strides = [1, 1] data =
|
||||
// PRINT-VECTOR-SPLAT-2D-NEXT: [((10, 10, 10, 10), (10, 10, 10, 10), (10, 10, 10, 10), (10, 10, 10, 10)),
|
||||
// PRINT-VECTOR-SPLAT-2D-NEXT: ((10, 10, 10, 10), (10, 10, 10, 10), (10, 10, 10, 10), (10, 10, 10, 10))],
|
||||
// PRINT-VECTOR-SPLAT-2D: [((10, 10, 10, 10), (10, 10, 10, 10), (10, 10, 10, 10), (10, 10, 10, 10)),
|
||||
// PRINT-VECTOR-SPLAT-2D-NEXT: ((10, 10, 10, 10), (10, 10, 10, 10), (10, 10, 10, 10), (10, 10, 10, 10))]
|
||||
// PRINT-VECTOR-SPLAT-2D-NEXT: [((10, 10, 10, 10), (10, 10, 10, 10), (10, 10, 10, 10), (10, 10, 10, 10))]
|
||||
|
||||
func @print_memref_vector_4x4xf32(memref<?x?x!vector_type_C>)
|
||||
|
|
Loading…
Reference in New Issue