forked from OSchip/llvm-project
[mlir] Handle linalg.index correctly in TilingInterface
The existing implementation of the TilingInterface for Linalg ops was not modifying the `linalg.index` ops contained within other Linalg ops (they need to be summed up with the values of respective tile loop induction variables), which led to the interface-based tiling being incorrect for any Linalg op with index semantics. In the process, fix the function performing the index offsetting to use the pattern rewriter API instead of RAUW as it is being called from patterns and may mess up the internal state of the rewriter. Also rename the function so that all of its existing uses are clearly caught and updated. Depends on D129365 Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D129366
This commit is contained in:
parent
e15b855e09
commit
81b62f7feb
|
@ -243,10 +243,11 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &builder, Location loc,
|
|||
ArrayRef<Value> sizeBounds,
|
||||
bool omitPartialTileCheck);
|
||||
|
||||
/// Add the tile loop induction variables `ivs` to the IndexOp results found in
|
||||
/// the body of the `tiledOp` to account for the tile offset.
|
||||
void addTileLoopIvsToIndexOpResults(OpBuilder &b, LinalgOp tiledOp,
|
||||
ArrayRef<Value> ivs);
|
||||
/// Add the specified offsets to any `linalg.index` ops contained in the given
|
||||
/// `linalgOp`. The offsets are provided in the same order as iteration space
|
||||
/// dimensions. Null offsets are assumed to be zero.
|
||||
void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef<Value> offests);
|
||||
void offsetIndices(RewriterBase &b, LinalgOp linalgOp, ArrayRef<Value> offests);
|
||||
|
||||
using FusableOpDependencesTy = llvm::MapVector<
|
||||
Operation *,
|
||||
|
|
|
@ -170,7 +170,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
|
|||
SmallVector<Value> allIvs;
|
||||
llvm::transform(loopRanges, std::back_inserter(allIvs),
|
||||
[](Range range) { return range.offset; });
|
||||
addTileLoopIvsToIndexOpResults(b, clonedOp, allIvs);
|
||||
offsetIndices(b, clonedOp, allIvs);
|
||||
|
||||
return clonedOp;
|
||||
}
|
||||
|
|
|
@ -186,7 +186,7 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
|
|||
LinalgOp clonedOp = producerOp.clone(b, loc, resultTypes, tiledOperands);
|
||||
|
||||
// Shift all IndexOp results by the tile offset.
|
||||
addTileLoopIvsToIndexOpResults(b, clonedOp, allIvs);
|
||||
offsetIndices(b, clonedOp, allIvs);
|
||||
|
||||
return clonedOp;
|
||||
}
|
||||
|
|
|
@ -139,8 +139,7 @@ std::pair<LinalgOp, LinalgOp> linalg::splitOp(RewriterBase &rewriter,
|
|||
SmallVector<Value> ivAdditions;
|
||||
ivAdditions.resize(splitIterationSpace.size());
|
||||
ivAdditions[dimension] = splitPointValue;
|
||||
linalg::addTileLoopIvsToIndexOpResults(builder, cast<LinalgOp>(second),
|
||||
ivAdditions);
|
||||
linalg::offsetIndices(rewriter, cast<LinalgOp>(second), ivAdditions);
|
||||
|
||||
// Replace the original op with the results of the two newly created ops.
|
||||
rewriter.replaceOp(op, secondResults);
|
||||
|
|
|
@ -80,7 +80,7 @@ void mlir::linalg::transformIndexOps(
|
|||
continue;
|
||||
en.value() = ivs[rangeIndex->second];
|
||||
}
|
||||
addTileLoopIvsToIndexOpResults(b, op, allIvs);
|
||||
offsetIndices(b, op, allIvs);
|
||||
}
|
||||
|
||||
/// Asserts that the given index-typed value is strictly positive. If the value
|
||||
|
|
|
@ -71,9 +71,10 @@ struct LinalgOpTilingInterface
|
|||
Location loc = op->getLoc();
|
||||
LinalgOp linalgOp = cast<LinalgOp>(op);
|
||||
SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();
|
||||
SmallVector<Value> offsetValues =
|
||||
getValueOrCreateConstantIndexOp(b, loc, offsets);
|
||||
SmallVector<Value, 4> tiledOperands = makeTiledShapes(
|
||||
b, loc, linalgOp, valuesToTile,
|
||||
getValueOrCreateConstantIndexOp(b, loc, offsets),
|
||||
b, loc, linalgOp, valuesToTile, offsetValues,
|
||||
getValueOrCreateConstantIndexOp(b, loc, sizes), {}, true);
|
||||
|
||||
SmallVector<Type> resultTensorTypes = llvm::to_vector(llvm::map_range(
|
||||
|
@ -83,6 +84,7 @@ struct LinalgOpTilingInterface
|
|||
|
||||
Operation *tiledOp =
|
||||
linalgOp.clone(b, loc, resultTensorTypes, tiledOperands);
|
||||
offsetIndices(b, cast<LinalgOp>(tiledOp), offsetValues);
|
||||
|
||||
return {tiledOp};
|
||||
}
|
||||
|
|
|
@ -1048,21 +1048,29 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &b, Location loc,
|
|||
return tiledShapes;
|
||||
}
|
||||
|
||||
void addTileLoopIvsToIndexOpResults(OpBuilder &b, LinalgOp tiledOp,
|
||||
ArrayRef<Value> ivs) {
|
||||
if (tiledOp.hasIndexSemantics()) {
|
||||
for (IndexOp indexOp : tiledOp.getBlock()->getOps<IndexOp>()) {
|
||||
if (ivs[indexOp.dim()] == nullptr)
|
||||
continue;
|
||||
OpBuilder::InsertionGuard guard(b);
|
||||
b.setInsertionPointAfter(indexOp);
|
||||
AffineExpr index, offset;
|
||||
bindDims(b.getContext(), index, offset);
|
||||
AffineApplyOp applyOp = makeComposedAffineApply(
|
||||
b, indexOp.getLoc(), index + offset,
|
||||
ValueRange{indexOp.getResult(), ivs[indexOp.dim()]});
|
||||
indexOp.getResult().replaceAllUsesExcept(applyOp, applyOp);
|
||||
}
|
||||
void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef<Value> offsets) {
|
||||
IRRewriter rewriter(b);
|
||||
offsetIndices(rewriter, linalgOp, offsets);
|
||||
}
|
||||
|
||||
void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
|
||||
ArrayRef<Value> offsets) {
|
||||
if (!linalgOp.hasIndexSemantics())
|
||||
return;
|
||||
|
||||
for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) {
|
||||
if (indexOp.dim() >= offsets.size() || offsets[indexOp.dim()] == nullptr)
|
||||
continue;
|
||||
OpBuilder::InsertionGuard guard(b);
|
||||
b.setInsertionPointAfter(indexOp);
|
||||
AffineExpr index, offset;
|
||||
bindDims(b.getContext(), index, offset);
|
||||
AffineApplyOp applyOp = makeComposedAffineApply(
|
||||
b, indexOp.getLoc(), index + offset,
|
||||
ValueRange{indexOp.getResult(), offsets[indexOp.dim()]});
|
||||
b.replaceOpWithIf(indexOp, applyOp.getResult(), [&](OpOperand &use) {
|
||||
return use.getOwner() != applyOp;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -192,3 +192,37 @@ func.func @conv2D(%arg0 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>,
|
|||
// CHECK-SAME: outs(%[[INIT_TILE]] :
|
||||
// CHECK: tensor.insert_slice %[[CONV_TILE]] into %[[INIT2]]
|
||||
// CHECK-SAME: [0, 0, 0, 0] [%[[N]], %[[R]], %[[S]], %[[F]]]
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: #[[$MAP_ADD:.+]] = affine_map<(d0, d1) -> (d0 + d1)>
|
||||
|
||||
// CHECK-LABEL: @indexed_semantics
|
||||
func.func @indexed_semantics(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
// Check that we correctly amend "linalg.index" results.
|
||||
|
||||
// CHECK: scf.for %[[I0:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
|
||||
// CHECK: scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
|
||||
%0 = linalg.generic {
|
||||
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
|
||||
affine_map<(d0, d1) -> (d0, d1)>],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
{__internal_linalg_transform__ = "indexed_semantics"}
|
||||
ins(%arg0: tensor<?x?xf32>)
|
||||
outs(%arg1: tensor<?x?xf32>) {
|
||||
^bb0(%arg2: f32, %arg3: f32):
|
||||
// CHECK: %[[INDEX0:.+]] = linalg.index 0
|
||||
// CHECK: %[[INDEX0_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX0]], %[[I0]])
|
||||
%1 = linalg.index 0 : index
|
||||
// CHECK: %[[INDEX1:.+]] = linalg.index 1
|
||||
// CHECK: %[[INDEX1_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX1]], %[[I1]])
|
||||
%2 = linalg.index 1 : index
|
||||
// CHECK: arith.addi %[[INDEX0_AMENDED]], %[[INDEX1_AMENDED]]
|
||||
%3 = arith.addi %1, %2 : index
|
||||
%4 = arith.index_cast %3 : index to i64
|
||||
%5 = arith.uitofp %4 : i64 to f32
|
||||
%6 = arith.addf %5, %arg2 : f32
|
||||
linalg.yield %6 : f32
|
||||
} -> (tensor<?x?xf32>)
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
|
|
@ -171,6 +171,9 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
|
|||
// 4. Tiling 2D conv op.
|
||||
addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
|
||||
context, {0, 0, 0, 0, 10, 20, 30}, "simple_conv", patterns);
|
||||
// 5. Tiling a simple op with `linalg.index` inside.
|
||||
addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
|
||||
context, {10, 20}, "indexed_semantics", patterns);
|
||||
return;
|
||||
}
|
||||
if (testTileConsumerAndFuseProducer) {
|
||||
|
|
Loading…
Reference in New Issue