[MLIR] Fix AsmPrinter for short-hand bound notation

This CL retricts shorthand notation printing to only the bounds that can
be roundtripped unambiguously; i.e.:
1. ()[]->(%some_cst) ()[]
2. ()[s0]->(s0) ()[%some_symbol]

Upon inspection it turns out that the constant case was lossy so this CL also
updates it.

Note however that fixing this issue exhibits a potential issues in unroll.mlir.
L488 exhibits a map ()[s0] -> (1)()[%arg0] which could be simplified down to
()[]->(1)()[].
This does not seem like a bug but maybe an undesired complexity in the maps
generated by unrolling.
bondhugula@, care to take a look?

PiperOrigin-RevId: 214531410
This commit is contained in:
Nicolas Vasilache 2018-09-25 17:15:54 -07:00 committed by jpienaar
parent 0f7fddfd65
commit 54e5b4b4c0
3 changed files with 46 additions and 23 deletions

View File

@ -1422,31 +1422,36 @@ void MLFunctionPrinter::printBound(AffineBound bound, const char *prefix) {
AffineMap *map = bound.getMap();
// Check if this bound should be printed using short-hand notation.
// The decision to restrict printing short-hand notation to trivial cases
// comes from the will to roundtrip MLIR binary -> text -> binary in a
// lossless way.
// Therefore, short-hand parsing and printing is only supported for
// zero-operand constant maps and single symbol operand identity maps.
if (map->getNumResults() == 1) {
AffineExpr *expr = map->getResult(0);
// Print constant bound.
if (auto *constExpr = dyn_cast<AffineConstantExpr>(expr)) {
os << constExpr->getValue();
return;
if (map->getNumDims() == 0 && map->getNumSymbols() == 0) {
if (auto *constExpr = dyn_cast<AffineConstantExpr>(expr)) {
os << constExpr->getValue();
return;
}
}
// Print bound that consists of a single SSA id, we need an indirection
// to achieve this.
if (auto *dimExpr = dyn_cast<AffineDimExpr>(expr)) {
printOperand(bound.getOperand(dimExpr->getPosition()));
return;
} else if (auto *symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
printOperand(
bound.getOperand(map->getNumDims() + symExpr->getPosition()));
return;
// Print bound that consists of a single SSA symbol if the map is over a
// single symbol.
if (map->getNumDims() == 0 && map->getNumSymbols() == 1) {
if (auto *symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
printOperand(bound.getOperand(0));
return;
}
}
} else {
// Map has multiple results. Print 'min' or 'max' prefix.
os << prefix << ' ';
}
// Print the map and the operands.
// Print the map and its operands.
printAffineMapReference(map);
printDimAndSymbolList(bound.getStmtOperands(), map->getNumDims());
}

View File

@ -191,7 +191,7 @@ mlfunc @complex_loops() {
mlfunc @triang_loop(%arg0 : affineint, %arg1 : memref<?x?xi32>) {
%c = constant 0 : i32 // CHECK: %c0_i32 = constant 0 : i32
for %i0 = 1 to %arg0 { // CHECK: for %i0 = 1 to %arg0 {
for %i1 = %i0 to %arg0 { // CHECK: for %i1 = %i0 to %arg0 {
for %i1 = %i0 to %arg0 { // CHECK: for %i1 = #map1(%i0) to %arg0 {
store %c, %arg1[%i0, %i1] : memref<?x?xi32> // CHECK: store %c0_i32, %arg1[%i0, %i1]
} // CHECK: }
} // CHECK: }
@ -214,7 +214,7 @@ mlfunc @loop_bounds(%N : affineint) {
%s = "foo"(%N) : (affineint) -> affineint
// CHECK: for %i0 = %0 to %arg0
for %i = %s to %N {
// CHECK: for %i1 = %i0 to 0 step -1
// CHECK: for %i1 = #map1(%i0) to 0 step -1
for %j = %i to 0 step -1 {
// CHECK: %1 = affine_apply #map{{.*}}(%i0, %i1)[%0]
%w = affine_apply(d0, d1)[s0] -> (d0+d1, s0+1) (%i, %j) [%s]
@ -456,12 +456,30 @@ mlfunc @mlfuncattrempty() -> ()
}
// CHECK-label mlfunc @mlfuncsimplemap
#mapsimple0 = ()[s0, s1, s2] -> (s0)
#mapsimple1 = (d0)[s0, s1, s2] -> (s1)
mlfunc @mlfuncsimplemap(%arg0 : affineint, %arg1 : affineint, %arg2 : affineint) -> () {
for %i0 = 0 to #mapsimple0()[%arg0, %arg1, %arg2] { // CHECK: for %i0 = 0 to %arg0 {
for %i1 = 0 to #mapsimple1(%i0)[%arg0, %arg1, %arg2] { // CHECK: for %i1 = 0 to %arg1 {
%c42_i32 = constant 42 : i32
#map_simple0 = ()[] -> (10)
#map_simple1 = ()[s0] -> (s0)
#map_non_simple0 = (d0)[] -> (d0)
#map_non_simple1 = (d0)[s0] -> (d0 + s0)
#map_non_simple2 = ()[s0, s1] -> (s0 + s1)
#map_non_simple3 = ()[s0] -> (s0 + 3)
mlfunc @mlfuncsimplemap(%arg0 : affineint, %arg1 : affineint) -> () {
for %i0 = 0 to #map_simple0()[] {
// CHECK: for %i0 = 0 to 10 {
for %i1 = 0 to #map_simple1()[%arg1] {
// CHECK: for %i1 = 0 to %arg1 {
for %i2 = 0 to #map_non_simple0(%i0)[] {
// CHECK: for %i2 = 0 to #map{{[a-z_0-9]*}}(%i0) {
for %i3 = 0 to #map_non_simple1(%i0)[%arg1] {
// CHECK: for %i3 = 0 to #map{{[a-z_0-9]*}}(%i0)[%arg1] {
for %i4 = 0 to #map_non_simple2()[%arg1, %arg0] {
// CHECK: for %i4 = 0 to #map{{[a-z_0-9]*}}()[%arg1, %arg0] {
for %i5 = 0 to #map_non_simple3()[%arg0] {
// CHECK: for %i5 = 0 to #map{{[a-z_0-9]*}}()[%arg0] {
%c42_i32 = constant 42 : i32
}
}
}
}
}
}
return

View File

@ -467,7 +467,7 @@ mlfunc @loop_nest_operand2() {
mlfunc @loop_nest_operand3() {
// UNROLL-BY-4: for %i0 = 1 to 100 step 2 {
for %i = 1 to 100 step 2 {
// UNROLL-BY-4: for %i1 = %i0 to #map{{[0-9]+}}(%i0) step 4 {
// UNROLL-BY-4: for %i1 = (d0) -> (d0)(%i0) to #map{{[0-9]+}}(%i0) step 4 {
// UNROLL-BY-4-NEXT: %0 = "foo"() : () -> i32
// UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32
// UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32
@ -485,7 +485,7 @@ mlfunc @loop_nest_operand3() {
mlfunc @loop_nest_operand4(%N : affineint) {
// UNROLL-BY-4: for %i0 = 1 to 100 {
for %i = 1 to 100 {
// UNROLL-BY-4: for %i1 = 1 to #map{{[0-9]+}}()[%arg0] step 4 {
// UNROLL-BY-4: for %i1 = ()[s0] -> (1)()[%arg0] to #map{{[0-9]+}}()[%arg0] step 4 {
// UNROLL-BY-4: %0 = "foo"() : () -> i32
// UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32
// UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32