[mlir] Handle linalg.index correctly in TilingInterface

The existing implementation of the TilingInterface for Linalg ops did not
update the `linalg.index` ops contained in the body of the op being tiled
(their results must be summed with the values of the respective tile loop
induction variables), which made interface-based tiling incorrect for any
Linalg op with index semantics.
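To illustrate the problem (a hand-written sketch, not part of the patch;
operands, shapes and bodies are elided): inside a tiled op, `linalg.index`
yields coordinates local to the current tile, so the tile offset carried by
the loop induction variable has to be added back to recover the original
index.

  // Original op: linalg.index returns the index in the full iteration space.
  %0 = linalg.generic ... {
  ^bb0(%in: f32, %out: f32):
    %i = linalg.index 0 : index
    ...
  }

  // After tiling dimension 0 by 10: the op now runs on a tile inside a loop,
  // so the index must be offset by the induction variable %iv to stay
  // equivalent. This is the computation the fix inserts.
  scf.for %iv = %c0 to %dim0 step %c10 ... {
    %1 = linalg.generic ... {
    ^bb0(%in: f32, %out: f32):
      %local = linalg.index 0 : index
      %global = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%local, %iv)
      ...
    }
  }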

In the process, fix the function that performs the index offsetting to use
the pattern rewriter API instead of RAUW: it is called from patterns, and
mutating the IR behind the rewriter's back may corrupt the rewriter's
internal state. Also rename the function so that all remaining uses are
caught at compile time and updated.
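Schematically, the rewriter-state issue looks like this (a sketch with
assumed local names such as `indexOp` and `applyOp`, not code from the
patch): inside a pattern, every IR mutation should go through the provided
rewriter so the pattern driver can track it.

  // Inside some RewritePattern::matchAndRewrite(..., PatternRewriter &rewriter).

  // Problematic: replaces uses behind the rewriter's back, so the driver's
  // bookkeeping (worklist, pending updates) can become stale.
  indexOp.getResult().replaceAllUsesExcept(applyOp.getResult(), applyOp);

  // Preferred: report the replacement through the rewriter; the predicate
  // keeps the use inside the replacement op itself intact.
  rewriter.replaceOpWithIf(indexOp, applyOp.getResult(), [&](OpOperand &use) {
    return use.getOwner() != applyOp;
  });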

Depends On D129365

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D129366
Alex Zinenko 2022-07-08 13:49:47 +02:00
parent e15b855e09
commit 81b62f7feb
9 changed files with 73 additions and 26 deletions


@@ -243,10 +243,11 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &builder, Location loc,
                                       ArrayRef<Value> sizeBounds,
                                       bool omitPartialTileCheck);

-/// Add the tile loop induction variables `ivs` to the IndexOp results found in
-/// the body of the `tiledOp` to account for the tile offset.
-void addTileLoopIvsToIndexOpResults(OpBuilder &b, LinalgOp tiledOp,
-                                    ArrayRef<Value> ivs);
+/// Add the specified offsets to any `linalg.index` ops contained in the given
+/// `linalgOp`. The offsets are provided in the same order as iteration space
+/// dimensions. Null offsets are assumed to be zero.
+void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef<Value> offsets);
+void offsetIndices(RewriterBase &b, LinalgOp linalgOp, ArrayRef<Value> offsets);

 using FusableOpDependencesTy = llvm::MapVector<
     Operation *,


@@ -170,7 +170,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
   SmallVector<Value> allIvs;
   llvm::transform(loopRanges, std::back_inserter(allIvs),
                   [](Range range) { return range.offset; });
-  addTileLoopIvsToIndexOpResults(b, clonedOp, allIvs);
+  offsetIndices(b, clonedOp, allIvs);
   return clonedOp;
 }


@@ -186,7 +186,7 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
   LinalgOp clonedOp = producerOp.clone(b, loc, resultTypes, tiledOperands);

   // Shift all IndexOp results by the tile offset.
-  addTileLoopIvsToIndexOpResults(b, clonedOp, allIvs);
+  offsetIndices(b, clonedOp, allIvs);

   return clonedOp;
 }


@@ -139,8 +139,7 @@ std::pair<LinalgOp, LinalgOp> linalg::splitOp(RewriterBase &rewriter,
   SmallVector<Value> ivAdditions;
   ivAdditions.resize(splitIterationSpace.size());
   ivAdditions[dimension] = splitPointValue;
-  linalg::addTileLoopIvsToIndexOpResults(builder, cast<LinalgOp>(second),
-                                         ivAdditions);
+  linalg::offsetIndices(rewriter, cast<LinalgOp>(second), ivAdditions);

   // Replace the original op with the results of the two newly created ops.
   rewriter.replaceOp(op, secondResults);


@@ -80,7 +80,7 @@ void mlir::linalg::transformIndexOps(
       continue;
     en.value() = ivs[rangeIndex->second];
   }
-  addTileLoopIvsToIndexOpResults(b, op, allIvs);
+  offsetIndices(b, op, allIvs);
 }

 /// Asserts that the given index-typed value is strictly positive. If the value


@@ -71,9 +71,10 @@ struct LinalgOpTilingInterface
     Location loc = op->getLoc();
     LinalgOp linalgOp = cast<LinalgOp>(op);
     SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();
+    SmallVector<Value> offsetValues =
+        getValueOrCreateConstantIndexOp(b, loc, offsets);
     SmallVector<Value, 4> tiledOperands = makeTiledShapes(
-        b, loc, linalgOp, valuesToTile,
-        getValueOrCreateConstantIndexOp(b, loc, offsets),
+        b, loc, linalgOp, valuesToTile, offsetValues,
         getValueOrCreateConstantIndexOp(b, loc, sizes), {}, true);

     SmallVector<Type> resultTensorTypes = llvm::to_vector(llvm::map_range(
@@ -83,6 +84,7 @@ struct LinalgOpTilingInterface
     Operation *tiledOp =
         linalgOp.clone(b, loc, resultTensorTypes, tiledOperands);
+    offsetIndices(b, cast<LinalgOp>(tiledOp), offsetValues);

     return {tiledOp};
   }


@@ -1048,21 +1048,29 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &b, Location loc,
   return tiledShapes;
 }

-void addTileLoopIvsToIndexOpResults(OpBuilder &b, LinalgOp tiledOp,
-                                    ArrayRef<Value> ivs) {
-  if (tiledOp.hasIndexSemantics()) {
-    for (IndexOp indexOp : tiledOp.getBlock()->getOps<IndexOp>()) {
-      if (ivs[indexOp.dim()] == nullptr)
-        continue;
-      OpBuilder::InsertionGuard guard(b);
-      b.setInsertionPointAfter(indexOp);
-      AffineExpr index, offset;
-      bindDims(b.getContext(), index, offset);
-      AffineApplyOp applyOp = makeComposedAffineApply(
-          b, indexOp.getLoc(), index + offset,
-          ValueRange{indexOp.getResult(), ivs[indexOp.dim()]});
-      indexOp.getResult().replaceAllUsesExcept(applyOp, applyOp);
-    }
-  }
-}
+void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef<Value> offsets) {
+  IRRewriter rewriter(b);
+  offsetIndices(rewriter, linalgOp, offsets);
+}
+
+void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
+                   ArrayRef<Value> offsets) {
+  if (!linalgOp.hasIndexSemantics())
+    return;
+
+  for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) {
+    if (indexOp.dim() >= offsets.size() || offsets[indexOp.dim()] == nullptr)
+      continue;
+    OpBuilder::InsertionGuard guard(b);
+    b.setInsertionPointAfter(indexOp);
+    AffineExpr index, offset;
+    bindDims(b.getContext(), index, offset);
+    AffineApplyOp applyOp = makeComposedAffineApply(
+        b, indexOp.getLoc(), index + offset,
+        ValueRange{indexOp.getResult(), offsets[indexOp.dim()]});
+    b.replaceOpWithIf(indexOp, applyOp.getResult(), [&](OpOperand &use) {
+      return use.getOwner() != applyOp;
+    });
+  }
+}
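The net effect on the IR is roughly the following (a sketch, with constants
and surrounding ops abbreviated): each `linalg.index` result is redirected
through an affine add of the corresponding offset, except for the use inside
the newly created `affine.apply` itself.

  // Before offsetIndices:
  %i = linalg.index 0 : index
  %u = arith.addi %i, %c1 : index

  // After offsetIndices with a non-null offset %off for dimension 0:
  %i = linalg.index 0 : index
  %s = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %off)
  %u = arith.addi %s, %c1 : index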


@@ -192,3 +192,37 @@ func.func @conv2D(%arg0 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>,
 // CHECK-SAME:     outs(%[[INIT_TILE]] :
 //      CHECK:   tensor.insert_slice %[[CONV_TILE]] into %[[INIT2]]
 // CHECK-SAME:     [0, 0, 0, 0] [%[[N]], %[[R]], %[[S]], %[[F]]]
+
+// -----
+
+// CHECK: #[[$MAP_ADD:.+]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-LABEL: @indexed_semantics
+func.func @indexed_semantics(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // Check that we correctly amend "linalg.index" results.
+  // CHECK: scf.for %[[I0:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
+  // CHECK:   scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]}
+    {__internal_linalg_transform__ = "indexed_semantics"}
+    ins(%arg0: tensor<?x?xf32>)
+    outs(%arg1: tensor<?x?xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32):
+    // CHECK: %[[INDEX0:.+]] = linalg.index 0
+    // CHECK: %[[INDEX0_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX0]], %[[I0]])
+    %1 = linalg.index 0 : index
+    // CHECK: %[[INDEX1:.+]] = linalg.index 1
+    // CHECK: %[[INDEX1_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX1]], %[[I1]])
+    %2 = linalg.index 1 : index
+    // CHECK: arith.addi %[[INDEX0_AMENDED]], %[[INDEX1_AMENDED]]
+    %3 = arith.addi %1, %2 : index
+    %4 = arith.index_cast %3 : index to i64
+    %5 = arith.uitofp %4 : i64 to f32
+    %6 = arith.addf %5, %arg2 : f32
+    linalg.yield %6 : f32
+  } -> (tensor<?x?xf32>)
+  return %0 : tensor<?x?xf32>
+}


@@ -171,6 +171,9 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
     // 4. Tiling 2D conv op.
     addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
         context, {0, 0, 0, 0, 10, 20, 30}, "simple_conv", patterns);
+    // 5. Tiling a simple op with `linalg.index` inside.
+    addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
+        context, {10, 20}, "indexed_semantics", patterns);
     return;
   }
   if (testTileConsumerAndFuseProducer) {