Add support to AffineApplyOp::fold for folding dim and symbol expression results.

PiperOrigin-RevId: 251512700
This commit is contained in:
River Riddle 2019-06-04 14:12:40 -07:00 committed by Mehdi Amini
parent 23cf3b39e0
commit f59f64e838
5 changed files with 42 additions and 47 deletions

View File

@ -185,16 +185,16 @@ TEST_FUNC(matmul_as_matvec_as_affine) {
// CHECK-NOT: {{.*}} = linalg.
// CHECK: affine.for %i1 = 0 to (d0) -> (d0)(%[[M]]) {
// CHECK: affine.for %i2 = 0 to (d0) -> (d0)(%[[K]]) {
// CHECK: %4 = cmpi "eq", %i2, %c0 : index
// CHECK: %6 = load %arg2[%5, %3] : memref<?x?xf32>
// CHECK: %7 = select %4, %cst, %6 : f32
// CHECK: %3 = cmpi "eq", %i2, %c0 : index
// CHECK: %4 = load %arg2[%i1, %i0] : memref<?x?xf32>
// CHECK: %5 = select %3, %cst, %4 : f32
// CHECK-NOT: {{.*}} = linalg.
// CHECK: %9 = load %arg1[%8, %3] : memref<?x?xf32>
// CHECK: %10 = load %arg0[%5, %8] : memref<?x?xf32>
// CHECK: %11 = mulf %10, %9 : f32
// CHECK: %12 = addf %7, %11 : f32
// CHECK: %6 = load %arg1[%i2, %i0] : memref<?x?xf32>
// CHECK: %7 = load %arg0[%i1, %i2] : memref<?x?xf32>
// CHECK: %8 = mulf %7, %6 : f32
// CHECK: %9 = addf %5, %8 : f32
// CHECK-NOT: {{.*}} = linalg.
// CHECK: store %12, %arg2[%5, %3] : memref<?x?xf32>
// CHECK: store %9, %arg2[%i1, %i0] : memref<?x?xf32>
// clang-format on
}

View File

@ -83,16 +83,13 @@ TEST_FUNC(matmul_tiled_loops) {
// CHECK: affine.for %i3 = max (d0)[s0] -> (s0, d0)(%i0)[%{{.*}}] to min (d0)[s0] -> (s0, d0 + 8)(%i0)[%[[M]]] {
// CHECK: affine.for %i4 = max (d0)[s0] -> (s0, d0)(%i1)[%{{.*}}] to min (d0)[s0] -> (s0, d0 + 9)(%i1)[%[[N]]] {
// CHECK-NEXT: %{{.*}} = cmpi "eq", %i2, %{{.*}} : index
// CHECK-NEXT: %[[I3:.*]] = affine.apply (d0) -> (d0)(%i3)
// CHECK-NEXT: %[[I4:.*]] = affine.apply (d0) -> (d0)(%i4)
// CHECK-NEXT: %{{.*}} = load %arg2[%[[I3]], %[[I4]]] : memref<?x?xf32>
// CHECK-NEXT: %{{.*}} = load %arg2[%i3, %i4] : memref<?x?xf32>
// CHECK-NEXT: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : f32
// CHECK-NEXT: %[[I2:.*]] = affine.apply (d0) -> (d0)(%i2)
// CHECK-NEXT: %{{.*}} = load %arg1[%[[I2]], %[[I4]]] : memref<?x?xf32>
// CHECK-NEXT: %{{.*}} = load %arg0[%[[I3]], %[[I2]]] : memref<?x?xf32>
// CHECK-NEXT: %{{.*}} = mulf %10, %9 : f32
// CHECK-NEXT: %{{.*}} = addf %7, %11 : f32
// CHECK-NEXT: store %{{.*}}, %arg2[%[[I3]], %[[I4]]] : memref<?x?xf32>
// CHECK-NEXT: %{{.*}} = load %arg1[%i2, %i4] : memref<?x?xf32>
// CHECK-NEXT: %{{.*}} = load %arg0[%i3, %i2] : memref<?x?xf32>
// CHECK-NEXT: %{{.*}} = mulf %7, %6 : f32
// CHECK-NEXT: %{{.*}} = addf %5, %8 : f32
// CHECK-NEXT: store %{{.*}}, %arg2[%i3, %i4] : memref<?x?xf32>
// clang-format on
}
@ -112,16 +109,14 @@ TEST_FUNC(matmul_tiled_views) {
// CHECK: %[[K:.*]] = dim %arg0, 1 : memref<?x?xf32>
// CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[M]]) step 8 {
// CHECK-NEXT: affine.for %i1 = 0 to (d0) -> (d0)(%[[N]]) step 9 {
// CHECK-NEXT: %[[i0min:.*]] = affine.apply (d0) -> (d0)(%i0)
// CHECK-NEXT: %[[i0max:.*]] = affine.apply (d0) -> (d0 + 8)(%i0)
// CHECK-NEXT: %[[ri0:.*]] = linalg.range %[[i0min]]:%[[i0max]]:{{.*}} : !linalg.range
// CHECK-NEXT: %[[ri0:.*]] = linalg.range %i0:%[[i0max]]:{{.*}} : !linalg.range
// CHECK: %[[rK:.*]] = linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg.range
// CHECK: %[[vA:.*]] = linalg.view %arg0[%[[ri0]], %[[rK]]] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
// CHECK: %[[i1min:.*]] = affine.apply (d0) -> (d0)(%i1)
// CHECK-NEXT: %[[i1max:.*]] = affine.apply (d0) -> (d0 + 9)(%i1)
// CHECK-NEXT: %[[ri1:.*]] = linalg.range %[[i1min]]:%[[i1max]]:%{{.*}} : !linalg.range
// CHECK-NEXT: %[[vB:.*]] = linalg.view %arg1[%10, %13] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
// CHECK-NEXT: %[[vC:.*]] = linalg.view %arg2[%5, %13] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
// CHECK: %[[i1max:.*]] = affine.apply (d0) -> (d0 + 9)(%i1)
// CHECK-NEXT: %[[ri1:.*]] = linalg.range %i1:%[[i1max]]:%{{.*}} : !linalg.range
// CHECK-NEXT: %[[vB:.*]] = linalg.view %arg1[%7, %9] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
// CHECK-NEXT: %[[vC:.*]] = linalg.view %arg2[%4, %9] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
// CHECK-NEXT: linalg.matmul(%[[vA]], %[[vB]], %[[vC]]) : !linalg.view<?x?xf32>
// clang-format on
cleanupAndPrintFunction(f);
@ -148,16 +143,14 @@ TEST_FUNC(matmul_tiled_views_as_loops) {
// CHECK: %[[K:.*]] = dim %arg0, 1 : memref<?x?xf32>
// CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[M]]) step 8 {
// CHECK-NEXT: affine.for %i1 = 0 to (d0) -> (d0)(%[[N]]) step 9 {
// CHECK-NEXT: %[[i0min:.*]] = affine.apply (d0) -> (d0)(%i0)
// CHECK-NEXT: %[[i0max:.*]] = affine.apply (d0) -> (d0 + 8)(%i0)
// CHECK-NEXT: %[[ri0:.*]] = linalg.range %[[i0min]]:%[[i0max]]:{{.*}} : !linalg.range
// CHECK-NEXT: %[[ri0:.*]] = linalg.range %i0:%[[i0max]]:{{.*}} : !linalg.range
// CHECK: %[[rK:.*]] = linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg.range
// CHECK: %[[vA:.*]] = linalg.view %arg0[%[[ri0]], %[[rK]]] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
// CHECK: %[[i1min:.*]] = affine.apply (d0) -> (d0)(%i1)
// CHECK-NEXT: %[[i1max:.*]] = affine.apply (d0) -> (d0 + 9)(%i1)
// CHECK-NEXT: %[[ri1:.*]] = linalg.range %[[i1min]]:%[[i1max]]:%{{.*}} : !linalg.range
// CHECK-NEXT: %[[vB:.*]] = linalg.view %arg1[%10, %13] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
// CHECK-NEXT: %[[vC:.*]] = linalg.view %arg2[%5, %13] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
// CHECK: %[[i1max:.*]] = affine.apply (d0) -> (d0 + 9)(%i1)
// CHECK-NEXT: %[[ri1:.*]] = linalg.range %i1:%[[i1max]]:%{{.*}} : !linalg.range
// CHECK-NEXT: %[[vB:.*]] = linalg.view %arg1[%7, %9] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
// CHECK-NEXT: %[[vC:.*]] = linalg.view %arg2[%4, %9] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
// CHECK-NEXT: affine.for %i2 = (d0) -> (d0)(%i0) to (d0) -> (d0)(%[[i0max]]) {
// CHECK-NEXT: affine.for %i3 = (d0) -> (d0)(%i1) to (d0) -> (d0)(%[[i1max]]) {
// CHECK-NEXT: affine.for %i4 = 0 to (d0) -> (d0)(%[[K]]) {

View File

@ -241,7 +241,10 @@ public:
if (!result)
return failure();
results.push_back(result);
// Check if the operation was folded in place. In this case, the fold result
// is the operation's own result value, so there is nothing new to record.
if (result.template dyn_cast<Value *>() != op->getResult(0))
results.push_back(result);
return success();
}

View File

@ -203,6 +203,15 @@ bool AffineApplyOp::isValidSymbol() {
OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
auto map = getAffineMap();
// Fold dims and symbols to existing values.
auto expr = map.getResult(0);
if (auto dim = expr.dyn_cast<AffineDimExpr>())
return getOperand(dim.getPosition());
if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
return getOperand(map.getNumDims() + sym.getPosition());
// Otherwise, default to folding the map.
SmallVector<Attribute, 1> result;
if (failed(map.constantFold(operands, result)))
return {};

View File

@ -22,9 +22,6 @@
// CHECK-DAG: [[MAP13A:#map[0-9]+]] = (d0) -> ((d0 + 6) ceildiv 8)
// CHECK-DAG: [[MAP13B:#map[0-9]+]] = (d0) -> ((d0 * 4 - 4) floordiv 3)
// Affine maps for test case: arg_used_as_dim_and_symbol
// CHECK-DAG: [[MAP14:#map[0-9]+]] = (d0) -> (d0)
// Affine maps for test case: partial_fold_map
// CHECK-DAG: [[MAP15:#map[0-9]+]] = ()[s0, s1] -> (s0 - s1)
@ -55,8 +52,7 @@ func @compose_affine_maps_1dto2d_no_symbols() {
%x1_1 = affine.apply (d0, d1) -> (d1) (%x0, %x0)
// CHECK: [[I0A:%[0-9]+]] = affine.apply [[MAP0]](%i0)
// CHECK-NEXT: [[I0B:%[0-9]+]] = affine.apply [[MAP0]](%i0)
// CHECK-NEXT: load %0{{\[}}[[I0A]], [[I0B]]{{\]}}
// CHECK-NEXT: load %0{{\[}}[[I0A]], [[I0A]]{{\]}}
%v0 = load %0[%x1_0, %x1_1] : memref<4x4xf32>
// Test load[%y, %y]
@ -65,25 +61,20 @@ func @compose_affine_maps_1dto2d_no_symbols() {
%y1_1 = affine.apply (d0, d1) -> (d1) (%y0, %y0)
// CHECK-NEXT: [[I1A:%[0-9]+]] = affine.apply [[MAP1]](%i0)
// CHECK-NEXT: [[I1B:%[0-9]+]] = affine.apply [[MAP1]](%i0)
// CHECK-NEXT: load %0{{\[}}[[I1A]], [[I1B]]{{\]}}
// CHECK-NEXT: load %0{{\[}}[[I1A]], [[I1A]]{{\]}}
%v1 = load %0[%y1_0, %y1_1] : memref<4x4xf32>
// Test load[%x, %y]
%xy_0 = affine.apply (d0, d1) -> (d0) (%x0, %y0)
%xy_1 = affine.apply (d0, d1) -> (d1) (%x0, %y0)
// CHECK-NEXT: [[I2A:%[0-9]+]] = affine.apply [[MAP0]](%i0)
// CHECK-NEXT: [[I2B:%[0-9]+]] = affine.apply [[MAP1]](%i0)
// CHECK-NEXT: load %0{{\[}}[[I2A]], [[I2B]]{{\]}}
// CHECK-NEXT: load %0{{\[}}[[I0A]], [[I1A]]{{\]}}
%v2 = load %0[%xy_0, %xy_1] : memref<4x4xf32>
// Test load[%y, %x]
%yx_0 = affine.apply (d0, d1) -> (d0) (%y0, %x0)
%yx_1 = affine.apply (d0, d1) -> (d1) (%y0, %x0)
// CHECK-NEXT: [[I3A:%[0-9]+]] = affine.apply [[MAP1]](%i0)
// CHECK-NEXT: [[I3B:%[0-9]+]] = affine.apply [[MAP0]](%i0)
// CHECK-NEXT: load %0{{\[}}[[I3A]], [[I3B]]{{\]}}
// CHECK-NEXT: load %0{{\[}}[[I1A]], [[I0A]]{{\]}}
%v3 = load %0[%yx_0, %yx_1] : memref<4x4xf32>
}
return
@ -238,8 +229,7 @@ func @arg_used_as_dim_and_symbol(%arg0: memref<100x100xf32>, %arg1: index) {
(%i0, %i1)[%arg1, %c9]
%4 = affine.apply (d0, d1, d3) -> (d3 - (d0 + d1))
(%arg1, %c9, %3)
// CHECK: [[I0:%[0-9]+]] = affine.apply [[MAP14]](%i1)
// CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I0]], %arg1{{\]}}
// CHECK: load %{{[0-9]+}}{{\[}}%i1, %arg1{{\]}}
%5 = load %1[%4, %arg1] : memref<100x100xf32, 1>
}
}