forked from OSchip/llvm-project
Add support to AffineApplyOp::fold for folding dim and symbol expression results.
PiperOrigin-RevId: 251512700
This commit is contained in:
parent
23cf3b39e0
commit
f59f64e838
|
@ -185,16 +185,16 @@ TEST_FUNC(matmul_as_matvec_as_affine) {
|
|||
// CHECK-NOT: {{.*}} = linalg.
|
||||
// CHECK: affine.for %i1 = 0 to (d0) -> (d0)(%[[M]]) {
|
||||
// CHECK: affine.for %i2 = 0 to (d0) -> (d0)(%[[K]]) {
|
||||
// CHECK: %4 = cmpi "eq", %i2, %c0 : index
|
||||
// CHECK: %6 = load %arg2[%5, %3] : memref<?x?xf32>
|
||||
// CHECK: %7 = select %4, %cst, %6 : f32
|
||||
// CHECK: %3 = cmpi "eq", %i2, %c0 : index
|
||||
// CHECK: %4 = load %arg2[%i1, %i0] : memref<?x?xf32>
|
||||
// CHECK: %5 = select %3, %cst, %4 : f32
|
||||
// CHECK-NOT: {{.*}} = linalg.
|
||||
// CHECK: %9 = load %arg1[%8, %3] : memref<?x?xf32>
|
||||
// CHECK: %10 = load %arg0[%5, %8] : memref<?x?xf32>
|
||||
// CHECK: %11 = mulf %10, %9 : f32
|
||||
// CHECK: %12 = addf %7, %11 : f32
|
||||
// CHECK: %6 = load %arg1[%i2, %i0] : memref<?x?xf32>
|
||||
// CHECK: %7 = load %arg0[%i1, %i2] : memref<?x?xf32>
|
||||
// CHECK: %8 = mulf %7, %6 : f32
|
||||
// CHECK: %9 = addf %5, %8 : f32
|
||||
// CHECK-NOT: {{.*}} = linalg.
|
||||
// CHECK: store %12, %arg2[%5, %3] : memref<?x?xf32>
|
||||
// CHECK: store %9, %arg2[%i1, %i0] : memref<?x?xf32>
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
|
|
|
@ -83,16 +83,13 @@ TEST_FUNC(matmul_tiled_loops) {
|
|||
// CHECK: affine.for %i3 = max (d0)[s0] -> (s0, d0)(%i0)[%{{.*}}] to min (d0)[s0] -> (s0, d0 + 8)(%i0)[%[[M]]] {
|
||||
// CHECK: affine.for %i4 = max (d0)[s0] -> (s0, d0)(%i1)[%{{.*}}] to min (d0)[s0] -> (s0, d0 + 9)(%i1)[%[[N]]] {
|
||||
// CHECK-NEXT: %{{.*}} = cmpi "eq", %i2, %{{.*}} : index
|
||||
// CHECK-NEXT: %[[I3:.*]] = affine.apply (d0) -> (d0)(%i3)
|
||||
// CHECK-NEXT: %[[I4:.*]] = affine.apply (d0) -> (d0)(%i4)
|
||||
// CHECK-NEXT: %{{.*}} = load %arg2[%[[I3]], %[[I4]]] : memref<?x?xf32>
|
||||
// CHECK-NEXT: %{{.*}} = load %arg2[%i3, %i4] : memref<?x?xf32>
|
||||
// CHECK-NEXT: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : f32
|
||||
// CHECK-NEXT: %[[I2:.*]] = affine.apply (d0) -> (d0)(%i2)
|
||||
// CHECK-NEXT: %{{.*}} = load %arg1[%[[I2]], %[[I4]]] : memref<?x?xf32>
|
||||
// CHECK-NEXT: %{{.*}} = load %arg0[%[[I3]], %[[I2]]] : memref<?x?xf32>
|
||||
// CHECK-NEXT: %{{.*}} = mulf %10, %9 : f32
|
||||
// CHECK-NEXT: %{{.*}} = addf %7, %11 : f32
|
||||
// CHECK-NEXT: store %{{.*}}, %arg2[%[[I3]], %[[I4]]] : memref<?x?xf32>
|
||||
// CHECK-NEXT: %{{.*}} = load %arg1[%i2, %i4] : memref<?x?xf32>
|
||||
// CHECK-NEXT: %{{.*}} = load %arg0[%i3, %i2] : memref<?x?xf32>
|
||||
// CHECK-NEXT: %{{.*}} = mulf %7, %6 : f32
|
||||
// CHECK-NEXT: %{{.*}} = addf %5, %8 : f32
|
||||
// CHECK-NEXT: store %{{.*}}, %arg2[%i3, %i4] : memref<?x?xf32>
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
|
@ -112,16 +109,14 @@ TEST_FUNC(matmul_tiled_views) {
|
|||
// CHECK: %[[K:.*]] = dim %arg0, 1 : memref<?x?xf32>
|
||||
// CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[M]]) step 8 {
|
||||
// CHECK-NEXT: affine.for %i1 = 0 to (d0) -> (d0)(%[[N]]) step 9 {
|
||||
// CHECK-NEXT: %[[i0min:.*]] = affine.apply (d0) -> (d0)(%i0)
|
||||
// CHECK-NEXT: %[[i0max:.*]] = affine.apply (d0) -> (d0 + 8)(%i0)
|
||||
// CHECK-NEXT: %[[ri0:.*]] = linalg.range %[[i0min]]:%[[i0max]]:{{.*}} : !linalg.range
|
||||
// CHECK-NEXT: %[[ri0:.*]] = linalg.range %i0:%[[i0max]]:{{.*}} : !linalg.range
|
||||
// CHECK: %[[rK:.*]] = linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg.range
|
||||
// CHECK: %[[vA:.*]] = linalg.view %arg0[%[[ri0]], %[[rK]]] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
|
||||
// CHECK: %[[i1min:.*]] = affine.apply (d0) -> (d0)(%i1)
|
||||
// CHECK-NEXT: %[[i1max:.*]] = affine.apply (d0) -> (d0 + 9)(%i1)
|
||||
// CHECK-NEXT: %[[ri1:.*]] = linalg.range %[[i1min]]:%[[i1max]]:%{{.*}} : !linalg.range
|
||||
// CHECK-NEXT: %[[vB:.*]] = linalg.view %arg1[%10, %13] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
|
||||
// CHECK-NEXT: %[[vC:.*]] = linalg.view %arg2[%5, %13] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
|
||||
// CHECK: %[[i1max:.*]] = affine.apply (d0) -> (d0 + 9)(%i1)
|
||||
// CHECK-NEXT: %[[ri1:.*]] = linalg.range %i1:%[[i1max]]:%{{.*}} : !linalg.range
|
||||
// CHECK-NEXT: %[[vB:.*]] = linalg.view %arg1[%7, %9] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
|
||||
// CHECK-NEXT: %[[vC:.*]] = linalg.view %arg2[%4, %9] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
|
||||
// CHECK-NEXT: linalg.matmul(%[[vA]], %[[vB]], %[[vC]]) : !linalg.view<?x?xf32>
|
||||
// clang-format on
|
||||
cleanupAndPrintFunction(f);
|
||||
|
@ -148,16 +143,14 @@ TEST_FUNC(matmul_tiled_views_as_loops) {
|
|||
// CHECK: %[[K:.*]] = dim %arg0, 1 : memref<?x?xf32>
|
||||
// CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[M]]) step 8 {
|
||||
// CHECK-NEXT: affine.for %i1 = 0 to (d0) -> (d0)(%[[N]]) step 9 {
|
||||
// CHECK-NEXT: %[[i0min:.*]] = affine.apply (d0) -> (d0)(%i0)
|
||||
// CHECK-NEXT: %[[i0max:.*]] = affine.apply (d0) -> (d0 + 8)(%i0)
|
||||
// CHECK-NEXT: %[[ri0:.*]] = linalg.range %[[i0min]]:%[[i0max]]:{{.*}} : !linalg.range
|
||||
// CHECK-NEXT: %[[ri0:.*]] = linalg.range %i0:%[[i0max]]:{{.*}} : !linalg.range
|
||||
// CHECK: %[[rK:.*]] = linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg.range
|
||||
// CHECK: %[[vA:.*]] = linalg.view %arg0[%[[ri0]], %[[rK]]] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
|
||||
// CHECK: %[[i1min:.*]] = affine.apply (d0) -> (d0)(%i1)
|
||||
// CHECK-NEXT: %[[i1max:.*]] = affine.apply (d0) -> (d0 + 9)(%i1)
|
||||
// CHECK-NEXT: %[[ri1:.*]] = linalg.range %[[i1min]]:%[[i1max]]:%{{.*}} : !linalg.range
|
||||
// CHECK-NEXT: %[[vB:.*]] = linalg.view %arg1[%10, %13] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
|
||||
// CHECK-NEXT: %[[vC:.*]] = linalg.view %arg2[%5, %13] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
|
||||
// CHECK: %[[i1max:.*]] = affine.apply (d0) -> (d0 + 9)(%i1)
|
||||
// CHECK-NEXT: %[[ri1:.*]] = linalg.range %i1:%[[i1max]]:%{{.*}} : !linalg.range
|
||||
// CHECK-NEXT: %[[vB:.*]] = linalg.view %arg1[%7, %9] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
|
||||
// CHECK-NEXT: %[[vC:.*]] = linalg.view %arg2[%4, %9] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
|
||||
// CHECK-NEXT: affine.for %i2 = (d0) -> (d0)(%i0) to (d0) -> (d0)(%[[i0max]]) {
|
||||
// CHECK-NEXT: affine.for %i3 = (d0) -> (d0)(%i1) to (d0) -> (d0)(%[[i1max]]) {
|
||||
// CHECK-NEXT: affine.for %i4 = 0 to (d0) -> (d0)(%[[K]]) {
|
||||
|
|
|
@ -241,7 +241,10 @@ public:
|
|||
if (!result)
|
||||
return failure();
|
||||
|
||||
results.push_back(result);
|
||||
// Check if the operation was folded in place. In this case, the operation
|
||||
// returns itself.
|
||||
if (result.template dyn_cast<Value *>() != op->getResult(0))
|
||||
results.push_back(result);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -203,6 +203,15 @@ bool AffineApplyOp::isValidSymbol() {
|
|||
|
||||
OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto map = getAffineMap();
|
||||
|
||||
// Fold dims and symbols to existing values.
|
||||
auto expr = map.getResult(0);
|
||||
if (auto dim = expr.dyn_cast<AffineDimExpr>())
|
||||
return getOperand(dim.getPosition());
|
||||
if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
|
||||
return getOperand(map.getNumDims() + sym.getPosition());
|
||||
|
||||
// Otherwise, default to folding the map.
|
||||
SmallVector<Attribute, 1> result;
|
||||
if (failed(map.constantFold(operands, result)))
|
||||
return {};
|
||||
|
|
|
@ -22,9 +22,6 @@
|
|||
// CHECK-DAG: [[MAP13A:#map[0-9]+]] = (d0) -> ((d0 + 6) ceildiv 8)
|
||||
// CHECK-DAG: [[MAP13B:#map[0-9]+]] = (d0) -> ((d0 * 4 - 4) floordiv 3)
|
||||
|
||||
// Affine maps for test case: arg_used_as_dim_and_symbol
|
||||
// CHECK-DAG: [[MAP14:#map[0-9]+]] = (d0) -> (d0)
|
||||
|
||||
// Affine maps for test case: partial_fold_map
|
||||
// CHECK-DAG: [[MAP15:#map[0-9]+]] = ()[s0, s1] -> (s0 - s1)
|
||||
|
||||
|
@ -55,8 +52,7 @@ func @compose_affine_maps_1dto2d_no_symbols() {
|
|||
%x1_1 = affine.apply (d0, d1) -> (d1) (%x0, %x0)
|
||||
|
||||
// CHECK: [[I0A:%[0-9]+]] = affine.apply [[MAP0]](%i0)
|
||||
// CHECK-NEXT: [[I0B:%[0-9]+]] = affine.apply [[MAP0]](%i0)
|
||||
// CHECK-NEXT: load %0{{\[}}[[I0A]], [[I0B]]{{\]}}
|
||||
// CHECK-NEXT: load %0{{\[}}[[I0A]], [[I0A]]{{\]}}
|
||||
%v0 = load %0[%x1_0, %x1_1] : memref<4x4xf32>
|
||||
|
||||
// Test load[%y, %y]
|
||||
|
@ -65,25 +61,20 @@ func @compose_affine_maps_1dto2d_no_symbols() {
|
|||
%y1_1 = affine.apply (d0, d1) -> (d1) (%y0, %y0)
|
||||
|
||||
// CHECK-NEXT: [[I1A:%[0-9]+]] = affine.apply [[MAP1]](%i0)
|
||||
// CHECK-NEXT: [[I1B:%[0-9]+]] = affine.apply [[MAP1]](%i0)
|
||||
// CHECK-NEXT: load %0{{\[}}[[I1A]], [[I1B]]{{\]}}
|
||||
// CHECK-NEXT: load %0{{\[}}[[I1A]], [[I1A]]{{\]}}
|
||||
%v1 = load %0[%y1_0, %y1_1] : memref<4x4xf32>
|
||||
|
||||
// Test load[%x, %y]
|
||||
%xy_0 = affine.apply (d0, d1) -> (d0) (%x0, %y0)
|
||||
%xy_1 = affine.apply (d0, d1) -> (d1) (%x0, %y0)
|
||||
|
||||
// CHECK-NEXT: [[I2A:%[0-9]+]] = affine.apply [[MAP0]](%i0)
|
||||
// CHECK-NEXT: [[I2B:%[0-9]+]] = affine.apply [[MAP1]](%i0)
|
||||
// CHECK-NEXT: load %0{{\[}}[[I2A]], [[I2B]]{{\]}}
|
||||
// CHECK-NEXT: load %0{{\[}}[[I0A]], [[I1A]]{{\]}}
|
||||
%v2 = load %0[%xy_0, %xy_1] : memref<4x4xf32>
|
||||
|
||||
// Test load[%y, %x]
|
||||
%yx_0 = affine.apply (d0, d1) -> (d0) (%y0, %x0)
|
||||
%yx_1 = affine.apply (d0, d1) -> (d1) (%y0, %x0)
|
||||
// CHECK-NEXT: [[I3A:%[0-9]+]] = affine.apply [[MAP1]](%i0)
|
||||
// CHECK-NEXT: [[I3B:%[0-9]+]] = affine.apply [[MAP0]](%i0)
|
||||
// CHECK-NEXT: load %0{{\[}}[[I3A]], [[I3B]]{{\]}}
|
||||
// CHECK-NEXT: load %0{{\[}}[[I1A]], [[I0A]]{{\]}}
|
||||
%v3 = load %0[%yx_0, %yx_1] : memref<4x4xf32>
|
||||
}
|
||||
return
|
||||
|
@ -238,8 +229,7 @@ func @arg_used_as_dim_and_symbol(%arg0: memref<100x100xf32>, %arg1: index) {
|
|||
(%i0, %i1)[%arg1, %c9]
|
||||
%4 = affine.apply (d0, d1, d3) -> (d3 - (d0 + d1))
|
||||
(%arg1, %c9, %3)
|
||||
// CHECK: [[I0:%[0-9]+]] = affine.apply [[MAP14]](%i1)
|
||||
// CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I0]], %arg1{{\]}}
|
||||
// CHECK: load %{{[0-9]+}}{{\[}}%i1, %arg1{{\]}}
|
||||
%5 = load %1[%4, %arg1] : memref<100x100xf32, 1>
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue