Fix pretty printer corner case in mlir_runner_utils.cpp.

In the particular case where the size of a memref dimension is 1, double printing would happen because printLast was called unconditionally.
This CL fixes the print and updates an incorrect test that should have caught this in the first place.

PiperOrigin-RevId: 281345142
This commit is contained in:
Nicolas Vasilache 2019-11-19 11:51:53 -08:00 committed by A. Unique TensorFlower
parent c017704cd9
commit 3732ba4def
2 changed files with 17 additions and 27 deletions

View File

@ -86,14 +86,6 @@ template <typename T> struct MemRefDataPrinter<T, 0> {
int64_t *sizes = nullptr, int64_t *strides = nullptr);
};
template <typename T>
static void printNewLineIfVector(std::ostream &os, T &t) {}
template <typename T, int Dim, int... Dims>
static void printNewLineIfVector(std::ostream &os, Vector<T, Dim, Dims...> &t) {
os << "\n";
}
template <typename T, int N>
void MemRefDataPrinter<T, N>::printFirst(std::ostream &os, T *base,
int64_t rank, int64_t offset,
@ -101,10 +93,12 @@ void MemRefDataPrinter<T, N>::printFirst(std::ostream &os, T *base,
os << "[";
MemRefDataPrinter<T, N - 1>::print(os, base, rank, offset, sizes + 1,
strides + 1);
if (sizes[0] > 0) {
os << ", ";
printNewLineIfVector(os, *base);
// If single element, close square bracket and return early.
if (sizes[0] <= 1) {
os << "]";
return;
}
os << ", ";
if (N > 1)
os << "\n";
}
@ -116,13 +110,14 @@ void MemRefDataPrinter<T, N>::print(std::ostream &os, T *base, int64_t rank,
printFirst(os, base, rank, offset, sizes, strides);
for (unsigned i = 1; i + 1 < sizes[0]; ++i) {
printSpace(os, rank - N + 1);
MemRefDataPrinter<T, N - 1>::print(os, base, rank, offset + i * (*strides),
MemRefDataPrinter<T, N - 1>::print(os, base, rank, offset + i * strides[0],
sizes + 1, strides + 1);
os << ", ";
printNewLineIfVector(os, *base);
if (N > 1)
os << "\n";
}
if (sizes[0] <= 1)
return;
printLast(os, base, rank, offset, sizes, strides);
}

View File

@ -11,6 +11,7 @@ func @print_0d() {
dealloc %A : memref<f32>
return
}
// PRINT-0D: Memref base@ = {{.*}} rank = 0 offset = 0 data = [2]
func @print_1d() {
%f = constant 2.00000e+00 : f32
@ -21,6 +22,8 @@ func @print_1d() {
dealloc %A : memref<16xf32>
return
}
// PRINT-1D: Memref base@ = {{.*}} rank = 1 offset = 0 sizes = [16] strides = [1] data =
// PRINT-1D-NEXT: [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
func @print_3d() {
%f = constant 2.00000e+00 : f32
@ -36,16 +39,6 @@ func @print_3d() {
dealloc %A : memref<3x4x5xf32>
return
}
func @print_memref_0d_f32(memref<f32>)
func @print_memref_1d_f32(memref<?xf32>)
func @print_memref_3d_f32(memref<?x?x?xf32>)
// PRINT-0D: Memref base@ = {{.*}} rank = 0 offset = 0 data = [2]
// PRINT-1D: Memref base@ = {{.*}} rank = 1 offset = 0 sizes = [16] strides = [1] data =
// PRINT-1D-NEXT: [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
// PRINT-3D: Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [3, 4, 5] strides = [20, 5, 1] data =
// PRINT-3D-COUNT-4: {{.*[[:space:]].*}}2, 2, 2, 2, 2
// PRINT-3D-COUNT-4: {{.*[[:space:]].*}}2, 2, 2, 2, 2
@ -53,6 +46,11 @@ func @print_memref_3d_f32(memref<?x?x?xf32>)
// PRINT-3D-NEXT: 2, 2, 4, 2, 2
// PRINT-3D-NEXT: 2, 2, 2, 2, 2
func @print_memref_0d_f32(memref<f32>)
func @print_memref_1d_f32(memref<?xf32>)
func @print_memref_3d_f32(memref<?x?x?xf32>)
!vector_type_C = type vector<4x4xf32>
!matrix_type_CC = type memref<1x1x!vector_type_C>
func @vector_splat_2d() {
@ -70,9 +68,6 @@ func @vector_splat_2d() {
}
// PRINT-VECTOR-SPLAT-2D: Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [1, 1] strides = [1, 1] data =
// PRINT-VECTOR-SPLAT-2D-NEXT: [((10, 10, 10, 10), (10, 10, 10, 10), (10, 10, 10, 10), (10, 10, 10, 10)),
// PRINT-VECTOR-SPLAT-2D-NEXT: ((10, 10, 10, 10), (10, 10, 10, 10), (10, 10, 10, 10), (10, 10, 10, 10))],
// PRINT-VECTOR-SPLAT-2D: [((10, 10, 10, 10), (10, 10, 10, 10), (10, 10, 10, 10), (10, 10, 10, 10)),
// PRINT-VECTOR-SPLAT-2D-NEXT: ((10, 10, 10, 10), (10, 10, 10, 10), (10, 10, 10, 10), (10, 10, 10, 10))]
// PRINT-VECTOR-SPLAT-2D-NEXT: [((10, 10, 10, 10), (10, 10, 10, 10), (10, 10, 10, 10), (10, 10, 10, 10))]
func @print_memref_vector_4x4xf32(memref<?x?x!vector_type_C>)