[mlir][linalg] Simplify slice dim computation for fusion on tensors (NFC).

Compute the tiled producer slice dimensions directly starting from the consumer not using the producer at all.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D110147
This commit is contained in:
Tobias Gysi 2021-09-21 15:09:32 +00:00
parent 9072f1b5f8
commit 8b5236def5
2 changed files with 54 additions and 52 deletions

View File

@ -30,61 +30,26 @@ using namespace linalg;
// StructuredOp specific helpers.
//===----------------------------------------------------------------------===//
/// Relate the producer to the consumer loop iterations that access the same
/// producer result element:
/// consumerToProducerLoops =
/// inverse(producerIndexingMap).compose(consumerIndexingMap).
/// Return `consumerToProducerLoops` or none if the inversion fails.
static Optional<AffineMap>
getConsumerToProducerLoopsMap(AffineMap producerIndexingMap,
AffineMap consumerIndexingMap) {
assert(consumerIndexingMap.getNumResults() ==
producerIndexingMap.getNumResults() &&
"expect the number of indexing map results to match");
// Ensure the producer indexing map is a projected permutation.
if (!producerIndexingMap.isProjectedPermutation())
return None;
AffineMap inverseIndexingMap =
inverseAndBroadcastProjectedPermuation(producerIndexingMap);
return inverseIndexingMap.compose(consumerIndexingMap);
}
/// Returns the producer result slice dimensions tiled by the tile loop nest or
/// an empty vector if `getConsumerToProducerLoopsMap` returns none.
// TODO: replace by Fourier-Motzkin and/or compute starting from consumer.
SmallVector<int64_t> getTiledSliceDims(OpResult producerResult,
OpOperand *consumerOperand,
/// Returns the tiled slice dimensions given the tiled consumer loop dimensions.
/// The slice defines a hyper rectangular iteration space and fusing the
/// producer is always possible. However, depending on the consumer indexing
/// map, not all slice elements may be consumed and the tiles may overlap. In
/// these cases, fusion introduces redundant computation.
SmallVector<int64_t> getTiledSliceDims(OpOperand *consumerOperand,
ArrayRef<int64_t> tiledLoopDims) {
// Get the consumer operand indexing map.
LinalgOp consumerOp = consumerOperand->getOwner();
LinalgOp producerOp = producerResult.getOwner();
OpOperand *opOperand =
producerOp.getOutputOperand(producerResult.getResultNumber());
AffineMap indexingMap = consumerOp.getTiedIndexingMap(consumerOperand);
// Compute the `consumerToProducerLoopsMap` and exit if the computation fails.
AffineMap producerIndexingMap = producerOp.getTiedIndexingMap(opOperand);
Optional<AffineMap> consumerToProducerLoopsMap =
getConsumerToProducerLoopsMap(
producerIndexingMap, consumerOp.getTiedIndexingMap(consumerOperand));
if (!consumerToProducerLoopsMap.hasValue())
return {};
// Compute the set of tiled producer loops.
DenseSet<int64_t> tiledProducerLoops;
for (auto en : enumerate(consumerToProducerLoopsMap->getResults())) {
for (int64_t dim : tiledLoopDims) {
if (en.value().isFunctionOfDim(dim))
tiledProducerLoops.insert(en.index());
// Search the slice dimensions tiled by a tile loop dimension.
DenseSet<int64_t> tiledSliceDims;
for (auto en : enumerate(indexingMap.getResults())) {
for (auto tiledLoopDim : tiledLoopDims) {
if (en.value().isFunctionOfDim(tiledLoopDim))
tiledSliceDims.insert(en.index());
}
}
// Compute the slice dimensions for the tiled producer loops.
SmallVector<int64_t> tiledSliceDims;
for (auto en : enumerate(producerIndexingMap.getResults())) {
auto dimExpr = en.value().dyn_cast<AffineDimExpr>();
if (dimExpr && tiledProducerLoops.count(dimExpr.getPosition()) != 0)
tiledSliceDims.push_back(en.index());
}
return tiledSliceDims;
return {tiledSliceDims.begin(), tiledSliceDims.end()};
}
/// Returns the producer fused in place of `sliceOp`. Tile the producer operands
@ -332,9 +297,10 @@ FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
if (!producerResult || !isa<LinalgOp>(producerResult.getOwner()))
return failure();
// Compute the slice dimensions tiled by `tileLoopNest`.
// Compute the tiled producer slice dimensions given the tiled root operation
// loop dimensions `loopDims`.
SmallVector<int64_t> tiledSliceDims =
getTiledSliceDims(producerResult, rootOpOperand, loopDims);
getTiledSliceDims(rootOpOperand, loopDims);
if (tiledSliceDims.empty())
return failure();

View File

@ -230,3 +230,39 @@ builtin.func @fuse_indexed(%arg0: tensor<24x12xi32>,
return %1 : tensor<24x25xi32>
}
// -----
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (8, -d0 - d1 + 18)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, -d1 - d2 + 18)>
#map0 = affine_map<(d0, d1) -> (d0, d0 + d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: fuse_non_rectangular
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<10x18xf32>
func @fuse_non_rectangular(%arg0: tensor<10x18xf32>,
%arg1: tensor<10x8xf32>) -> tensor<10x8xf32> {
%cst = constant 0.000000e+00 : f32
%0 = linalg.fill(%cst, %arg0) : f32, tensor<10x18xf32> -> tensor<10x18xf32>
// CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
// CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
// Compute producer on a hyper rectangular bounding box. Along the second dimenson,
// the offset is set to the sum of the induction variables and the upper bound
// to either eight (sum of the tile sizes) or eighteen (sum of the domain sizes)
// minus the induction variables.
// CHECK: %[[SUM:.*]] = affine.apply #[[MAP0]](%[[IV1]], %[[IV0]]
// CHECK: %[[TS1:.*]] = affine.min #[[MAP1]](%[[IV1]], %[[IV0]]
// CHECK: %[[UB1:.*]] = affine.min #[[MAP2]](%[[TS1]], %[[IV1]], %[[IV0]]
// CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
// CHECK-SAME: %[[IV1]], %[[SUM]]
// CHECK-SAME: , %[[UB1]]
// CHECK: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
%1 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<10x18xf32>) outs(%arg1 : tensor<10x8xf32>) {
^bb0(%arg2: f32, %arg3: f32): // no predecessors
%2 = addf %arg2, %arg3 : f32
linalg.yield %2 : f32
} -> tensor<10x8xf32>
return %1 : tensor<10x8xf32>
}