[MLIR] TilingInterface: Avoid map when tile divides iteration domain

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D131080
This commit is contained in:
lorenzo chelini 2022-07-19 16:13:22 +02:00
parent ea50901aa9
commit 954de25a92
2 changed files with 28 additions and 10 deletions

View File

@ -90,6 +90,20 @@ static bool isPermutation(ArrayRef<unsigned> interchange) {
// TileUsingSCFForOp pattern implementation.
//===----------------------------------------------------------------------===//
/// Check if `stride` evenly divides the trip count `size - offset` of
/// `loopRange`.
///
/// Returns false when `getConstantIntValue` yields no value for any of
/// `offset`, `size` or `stride`, and also when the constant stride is zero,
/// since divisibility by zero is not defined. Returns true only when all three
/// are known constants and `(size - offset) % stride == 0`.
static bool tileDividesIterationDomain(Range loopRange) {
  Optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
  if (!offsetAsInt)
    return false;
  Optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
  if (!sizeAsInt)
    return false;
  Optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
  if (!strideAsInt)
    return false;
  // Guard the modulo below: a remainder by zero is undefined behavior, so a
  // zero stride is conservatively reported as "does not divide".
  if (strideAsInt.value() == 0)
    return false;
  return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
}
/// Generate an empty loop nest that represents the tiled loop nest shell.
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
/// - `tileSizeVals` are the tile sizes to use. A zero tile size represents an untiled loop.
@ -134,9 +148,15 @@ generateTileLoopNest(OpBuilder &builder, Location loc,
loc, offset, size, tileSizeVals[loopRange.index()], ValueRange{},
[&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
ValueRange /*iterArgs*/) {
Value boundedTileSize = builder.create<AffineMinOp>(
bodyLoc, minMap,
ValueRange{iv, tileSizeVals[loopRange.index()], size});
bool canAvoidMap = tileDividesIterationDomain(
Range{loopRange.value().offset, loopRange.value().size,
tileSizeVals[loopRange.index()]});
Value boundedTileSize =
(canAvoidMap)
? tileSizeVals[loopRange.index()]
: builder.create<AffineMinOp>(
bodyLoc, minMap,
ValueRange{iv, tileSizeVals[loopRange.index()], size});
sizes[loopRange.index()] = boundedTileSize;
builder.create<scf::YieldOp>(loc);
});

View File

@ -101,7 +101,6 @@ func.func @multi_result(%arg0 : tensor<128x200x300xf32>) -> (tensor<128x300x200x
return %0#0, %0#1 : tensor<128x300x200xf32>, tensor<300x128x200xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
// CHECK: func.func @multi_result(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x200x300xf32>)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
@ -116,20 +115,19 @@ func.func @multi_result(%arg0 : tensor<128x200x300xf32>) -> (tensor<128x300x200x
// CHECK: %[[TS_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[C128]]]
// CHECK: %[[INNER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C300]] step %[[C20]]
// CHECK-SAME: iter_args(%[[ARG3:[a-zA-Z0-9]+]] = %[[ARG1]], %[[ARG4:[a-zA-Z0-9]+]] = %[[ARG2]])
// CHECK: %[[TS_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[C300]]]
// CHECK-DAG: %[[ARG_TILE:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK-SAME: [%[[IV0]], 0, %[[IV1]]] [%[[TS_Y]], 200, %[[TS_X]]] [1, 1, 1]
// CHECK-SAME: [%[[IV0]], 0, %[[IV1]]] [%[[TS_Y]], 200, 20] [1, 1, 1]
// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ARG3]]
// CHECK-SAME: [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], %[[TS_X]], 200] [1, 1, 1]
// CHECK-SAME: [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], 20, 200] [1, 1, 1]
// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ARG4]]
// CHECK-SAME: [%[[IV1]], %[[IV0]], 0] [%[[TS_X]], %[[TS_Y]], 200] [1, 1, 1]
// CHECK-SAME: [%[[IV1]], %[[IV0]], 0] [20, %[[TS_Y]], 200] [1, 1, 1]
// CHECK: %[[RESULT_TILE:.+]]:2 = linalg.generic
// CHECK-SAME: ins(%[[ARG_TILE]] :
// CHECK-SAME: outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
// CHECK: %[[UPDATE0:.+]] = tensor.insert_slice %[[RESULT_TILE]]#0 into %[[ARG3]]
// CHECK-SAME: [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], %[[TS_X]], 200] [1, 1, 1]
// CHECK-SAME: [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], 20, 200] [1, 1, 1]
// CHECK: %[[UPDATE1:.+]] = tensor.insert_slice %[[RESULT_TILE]]#1 into %[[ARG4]]
// CHECK-SAME: [%[[IV1]], %[[IV0]], 0] [%[[TS_X]], %[[TS_Y]], 200] [1, 1, 1]
// CHECK-SAME: [%[[IV1]], %[[IV0]], 0] [20, %[[TS_Y]], 200] [1, 1, 1]
// CHECK: scf.yield %[[UPDATE0]], %[[UPDATE1]]
// CHECK: scf.yield %[[INNER]]#0, %[[INNER]]#1
// CHECK: return %[[OUTER]]#0, %[[OUTER]]#1