forked from OSchip/llvm-project
[MLIR] Fix AsmPrinter for short-hand bound notation
This CL retricts shorthand notation printing to only the bounds that can be roundtripped unambiguously; i.e.: 1. ()[]->(%some_cst) ()[] 2. ()[s0]->(s0) ()[%some_symbol] Upon inspection it turns out that the constant case was lossy so this CL also updates it. Note however that fixing this issue exhibits a potential issues in unroll.mlir. L488 exhibits a map ()[s0] -> (1)()[%arg0] which could be simplified down to ()[]->(1)()[]. This does not seem like a bug but maybe an undesired complexity in the maps generated by unrolling. bondhugula@, care to take a look? PiperOrigin-RevId: 214531410
This commit is contained in:
parent
0f7fddfd65
commit
54e5b4b4c0
|
@ -1422,31 +1422,36 @@ void MLFunctionPrinter::printBound(AffineBound bound, const char *prefix) {
|
|||
AffineMap *map = bound.getMap();
|
||||
|
||||
// Check if this bound should be printed using short-hand notation.
|
||||
// The decision to restrict printing short-hand notation to trivial cases
|
||||
// comes from the will to roundtrip MLIR binary -> text -> binary in a
|
||||
// lossless way.
|
||||
// Therefore, short-hand parsing and printing is only supported for
|
||||
// zero-operand constant maps and single symbol operand identity maps.
|
||||
if (map->getNumResults() == 1) {
|
||||
AffineExpr *expr = map->getResult(0);
|
||||
|
||||
// Print constant bound.
|
||||
if (auto *constExpr = dyn_cast<AffineConstantExpr>(expr)) {
|
||||
os << constExpr->getValue();
|
||||
return;
|
||||
if (map->getNumDims() == 0 && map->getNumSymbols() == 0) {
|
||||
if (auto *constExpr = dyn_cast<AffineConstantExpr>(expr)) {
|
||||
os << constExpr->getValue();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Print bound that consists of a single SSA id, we need an indirection
|
||||
// to achieve this.
|
||||
if (auto *dimExpr = dyn_cast<AffineDimExpr>(expr)) {
|
||||
printOperand(bound.getOperand(dimExpr->getPosition()));
|
||||
return;
|
||||
} else if (auto *symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
|
||||
printOperand(
|
||||
bound.getOperand(map->getNumDims() + symExpr->getPosition()));
|
||||
return;
|
||||
// Print bound that consists of a single SSA symbol if the map is over a
|
||||
// single symbol.
|
||||
if (map->getNumDims() == 0 && map->getNumSymbols() == 1) {
|
||||
if (auto *symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
|
||||
printOperand(bound.getOperand(0));
|
||||
return;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Map has multiple results. Print 'min' or 'max' prefix.
|
||||
os << prefix << ' ';
|
||||
}
|
||||
|
||||
// Print the map and the operands.
|
||||
// Print the map and its operands.
|
||||
printAffineMapReference(map);
|
||||
printDimAndSymbolList(bound.getStmtOperands(), map->getNumDims());
|
||||
}
|
||||
|
|
|
@ -191,7 +191,7 @@ mlfunc @complex_loops() {
|
|||
mlfunc @triang_loop(%arg0 : affineint, %arg1 : memref<?x?xi32>) {
|
||||
%c = constant 0 : i32 // CHECK: %c0_i32 = constant 0 : i32
|
||||
for %i0 = 1 to %arg0 { // CHECK: for %i0 = 1 to %arg0 {
|
||||
for %i1 = %i0 to %arg0 { // CHECK: for %i1 = %i0 to %arg0 {
|
||||
for %i1 = %i0 to %arg0 { // CHECK: for %i1 = #map1(%i0) to %arg0 {
|
||||
store %c, %arg1[%i0, %i1] : memref<?x?xi32> // CHECK: store %c0_i32, %arg1[%i0, %i1]
|
||||
} // CHECK: }
|
||||
} // CHECK: }
|
||||
|
@ -214,7 +214,7 @@ mlfunc @loop_bounds(%N : affineint) {
|
|||
%s = "foo"(%N) : (affineint) -> affineint
|
||||
// CHECK: for %i0 = %0 to %arg0
|
||||
for %i = %s to %N {
|
||||
// CHECK: for %i1 = %i0 to 0 step -1
|
||||
// CHECK: for %i1 = #map1(%i0) to 0 step -1
|
||||
for %j = %i to 0 step -1 {
|
||||
// CHECK: %1 = affine_apply #map{{.*}}(%i0, %i1)[%0]
|
||||
%w = affine_apply(d0, d1)[s0] -> (d0+d1, s0+1) (%i, %j) [%s]
|
||||
|
@ -456,12 +456,30 @@ mlfunc @mlfuncattrempty() -> ()
|
|||
}
|
||||
|
||||
// CHECK-label mlfunc @mlfuncsimplemap
|
||||
#mapsimple0 = ()[s0, s1, s2] -> (s0)
|
||||
#mapsimple1 = (d0)[s0, s1, s2] -> (s1)
|
||||
mlfunc @mlfuncsimplemap(%arg0 : affineint, %arg1 : affineint, %arg2 : affineint) -> () {
|
||||
for %i0 = 0 to #mapsimple0()[%arg0, %arg1, %arg2] { // CHECK: for %i0 = 0 to %arg0 {
|
||||
for %i1 = 0 to #mapsimple1(%i0)[%arg0, %arg1, %arg2] { // CHECK: for %i1 = 0 to %arg1 {
|
||||
%c42_i32 = constant 42 : i32
|
||||
#map_simple0 = ()[] -> (10)
|
||||
#map_simple1 = ()[s0] -> (s0)
|
||||
#map_non_simple0 = (d0)[] -> (d0)
|
||||
#map_non_simple1 = (d0)[s0] -> (d0 + s0)
|
||||
#map_non_simple2 = ()[s0, s1] -> (s0 + s1)
|
||||
#map_non_simple3 = ()[s0] -> (s0 + 3)
|
||||
mlfunc @mlfuncsimplemap(%arg0 : affineint, %arg1 : affineint) -> () {
|
||||
for %i0 = 0 to #map_simple0()[] {
|
||||
// CHECK: for %i0 = 0 to 10 {
|
||||
for %i1 = 0 to #map_simple1()[%arg1] {
|
||||
// CHECK: for %i1 = 0 to %arg1 {
|
||||
for %i2 = 0 to #map_non_simple0(%i0)[] {
|
||||
// CHECK: for %i2 = 0 to #map{{[a-z_0-9]*}}(%i0) {
|
||||
for %i3 = 0 to #map_non_simple1(%i0)[%arg1] {
|
||||
// CHECK: for %i3 = 0 to #map{{[a-z_0-9]*}}(%i0)[%arg1] {
|
||||
for %i4 = 0 to #map_non_simple2()[%arg1, %arg0] {
|
||||
// CHECK: for %i4 = 0 to #map{{[a-z_0-9]*}}()[%arg1, %arg0] {
|
||||
for %i5 = 0 to #map_non_simple3()[%arg0] {
|
||||
// CHECK: for %i5 = 0 to #map{{[a-z_0-9]*}}()[%arg0] {
|
||||
%c42_i32 = constant 42 : i32
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
|
|
|
@ -467,7 +467,7 @@ mlfunc @loop_nest_operand2() {
|
|||
mlfunc @loop_nest_operand3() {
|
||||
// UNROLL-BY-4: for %i0 = 1 to 100 step 2 {
|
||||
for %i = 1 to 100 step 2 {
|
||||
// UNROLL-BY-4: for %i1 = %i0 to #map{{[0-9]+}}(%i0) step 4 {
|
||||
// UNROLL-BY-4: for %i1 = (d0) -> (d0)(%i0) to #map{{[0-9]+}}(%i0) step 4 {
|
||||
// UNROLL-BY-4-NEXT: %0 = "foo"() : () -> i32
|
||||
// UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32
|
||||
// UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32
|
||||
|
@ -485,7 +485,7 @@ mlfunc @loop_nest_operand3() {
|
|||
mlfunc @loop_nest_operand4(%N : affineint) {
|
||||
// UNROLL-BY-4: for %i0 = 1 to 100 {
|
||||
for %i = 1 to 100 {
|
||||
// UNROLL-BY-4: for %i1 = 1 to #map{{[0-9]+}}()[%arg0] step 4 {
|
||||
// UNROLL-BY-4: for %i1 = ()[s0] -> (1)()[%arg0] to #map{{[0-9]+}}()[%arg0] step 4 {
|
||||
// UNROLL-BY-4: %0 = "foo"() : () -> i32
|
||||
// UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32
|
||||
// UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32
|
||||
|
|
Loading…
Reference in New Issue