From 0e777e4ad7d554436a1c181674bdbaeab9053c31 Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Thu, 22 Apr 2021 08:22:37 +0000 Subject: [PATCH] [mlir][linalg] remove interchange option on linalg to loop lowering. The interchange option attached to the linalg to loop lowering affects only the loops and does not update the memory accesses generated in the body of the operation. Instead of performing the interchange during the loop lowering, use the interchange pattern. Differential Revision: https://reviews.llvm.org/D100758 --- mlir/include/mlir/Dialect/Linalg/Passes.td | 18 ---- .../Dialect/Linalg/Transforms/Transforms.h | 39 +++---- mlir/lib/Dialect/Linalg/Transforms/Loops.cpp | 102 +++++++----------- mlir/test/Dialect/Linalg/loop-order.mlir | 72 ------------- 4 files changed, 48 insertions(+), 183 deletions(-) delete mode 100644 mlir/test/Dialect/Linalg/loop-order.mlir diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td index 344ffe977caf..8d411d5964c5 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -62,12 +62,6 @@ def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> { let summary = "Lower the operations from the linalg dialect into affine " "loops"; let constructor = "mlir::createConvertLinalgToAffineLoopsPass()"; - let options = [ - ListOption<"interchangeVector", "interchange-vector", "unsigned", - "Permute the loops in the nest following the given " - "interchange vector", - "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated"> - ]; let dependentDialects = [ "AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"]; } @@ -75,12 +69,6 @@ def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> { def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> { let summary = "Lower the operations from the linalg dialect into loops"; let constructor = "mlir::createConvertLinalgToLoopsPass()"; - 
let options = [ - ListOption<"interchangeVector", "interchange-vector", "unsigned", - "Permute the loops in the nest following the given " - "interchange vector", - "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated"> - ]; let dependentDialects = [ "linalg::LinalgDialect", "scf::SCFDialect", @@ -103,12 +91,6 @@ def LinalgLowerToParallelLoops let summary = "Lower the operations from the linalg dialect into parallel " "loops"; let constructor = "mlir::createConvertLinalgToParallelLoopsPass()"; - let options = [ - ListOption<"interchangeVector", "interchange-vector", "unsigned", - "Permute the loops in the nest following the given " - "interchange vector", - "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated"> - ]; let dependentDialects = [ "AffineDialect", "linalg::LinalgDialect", diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 251a2f8e6d03..2338198b5f2e 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -338,28 +338,16 @@ LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op, /// Emits a loop nest of `LoopTy` with the proper body for `op`. template -Optional -linalgLowerOpToLoops(OpBuilder &builder, Operation *op, - ArrayRef interchangeVector = {}); +Optional linalgLowerOpToLoops(OpBuilder &builder, Operation *op); -/// Emits a loop nest of `scf.for` with the proper body for `op`. The generated -/// loop nest will follow the `interchangeVector`-permutated iterator order. If -/// `interchangeVector` is empty, then no permutation happens. -LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op, - ArrayRef interchangeVector = {}); +/// Emits a loop nest of `scf.for` with the proper body for `op`. +LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op); -/// Emits a loop nest of `scf.parallel` with the proper body for `op`. 
The -/// generated loop nest will follow the `interchangeVector`-permutated -// iterator order. If `interchangeVector` is empty, then no permutation happens. -LogicalResult -linalgOpToParallelLoops(OpBuilder &builder, Operation *op, - ArrayRef interchangeVector = {}); +/// Emits a loop nest of `scf.parallel` with the proper body for `op`. +LogicalResult linalgOpToParallelLoops(OpBuilder &builder, Operation *op); -/// Emits a loop nest of `affine.for` with the proper body for `op`. The -/// generated loop nest will follow the `interchangeVector`-permutated -// iterator order. If `interchangeVector` is empty, then no permutation happens. -LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op, - ArrayRef interchangeVector = {}); +/// Emits a loop nest of `affine.for` with the proper body for `op`. +LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op); //===----------------------------------------------------------------------===// // Preconditions that ensure the corresponding transformation succeeds and can @@ -808,10 +796,9 @@ struct LinalgLoweringPattern : public RewritePattern { LinalgLoweringPattern( MLIRContext *context, LinalgLoweringType loweringType, LinalgTransformationFilter filter = LinalgTransformationFilter(), - ArrayRef interchangeVector = {}, PatternBenefit benefit = 1) + PatternBenefit benefit = 1) : RewritePattern(OpTy::getOperationName(), benefit, context), - filter(filter), loweringType(loweringType), - interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} + filter(filter), loweringType(loweringType) {} // TODO: Move implementation to .cpp once named ops are auto-generated. LogicalResult matchAndRewrite(Operation *op, @@ -827,15 +814,15 @@ struct LinalgLoweringPattern : public RewritePattern { // TODO: Move lowering to library calls here. 
return failure(); case LinalgLoweringType::Loops: - if (failed(linalgOpToLoops(rewriter, op, interchangeVector))) + if (failed(linalgOpToLoops(rewriter, op))) return failure(); break; case LinalgLoweringType::AffineLoops: - if (failed(linalgOpToAffineLoops(rewriter, op, interchangeVector))) + if (failed(linalgOpToAffineLoops(rewriter, op))) return failure(); break; case LinalgLoweringType::ParallelLoops: - if (failed(linalgOpToParallelLoops(rewriter, op, interchangeVector))) + if (failed(linalgOpToParallelLoops(rewriter, op))) return failure(); break; } @@ -850,8 +837,6 @@ private: /// Controls whether the pattern lowers to library calls, scf.for, affine.for /// or scf.parallel. LinalgLoweringType loweringType; - /// Permutated loop order in the generated loop nest. - SmallVector interchangeVector; }; /// Linalg generalization patterns diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp index c85f4a9abd38..f19493c3cca9 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -457,9 +457,8 @@ static void emitScalarImplementation(ArrayRef allIvs, } template -static Optional -linalgOpToLoopsImpl(Operation *op, OpBuilder &builder, - ArrayRef interchangeVector) { +static Optional linalgOpToLoopsImpl(Operation *op, + OpBuilder &builder) { using IndexedValueTy = typename GenerateLoopNest::IndexedValueTy; ScopedContext scope(builder, op->getLoc()); @@ -472,13 +471,6 @@ linalgOpToLoopsImpl(Operation *op, OpBuilder &builder, auto loopRanges = linalgOp.createLoopRanges(builder, op->getLoc()); auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue()); - if (!interchangeVector.empty()) { - assert(interchangeVector.size() == loopRanges.size()); - assert(interchangeVector.size() == iteratorTypes.size()); - applyPermutationToVector(loopRanges, interchangeVector); - applyPermutationToVector(iteratorTypes, interchangeVector); - } - SmallVector allIvs; 
GenerateLoopNest::doit( loopRanges, /*iterInitArgs=*/{}, iteratorTypes, @@ -511,11 +503,10 @@ linalgOpToLoopsImpl(Operation *op, OpBuilder &builder, } /// Replace the index operations in the body of the loop nest by the matching -/// induction variables. If available use the interchange vector to map the -/// interchanged induction variables to the dimension of the index operation. -static void replaceIndexOpsByInductionVariables( - LinalgOp linalgOp, PatternRewriter &rewriter, ArrayRef loopOps, - ArrayRef interchangeVector) { +/// induction variables. +static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp, + PatternRewriter &rewriter, + ArrayRef loopOps) { // Extract the induction variables of the loop nest from outer to inner. SmallVector allIvs; for (Operation *loopOp : loopOps) { @@ -538,16 +529,8 @@ static void replaceIndexOpsByInductionVariables( if (!loopOps.empty()) { LoopLikeOpInterface loopOp = loopOps.back(); for (IndexOp indexOp : - llvm::make_early_inc_range(loopOp.getLoopBody().getOps())) { - // Search the indexing dimension in the interchange vector if available. - assert(interchangeVector.empty() || - interchangeVector.size() == linalgOp.getNumLoops()); - const auto *it = llvm::find(interchangeVector, indexOp.dim()); - uint64_t dim = it != interchangeVector.end() - ? 
std::distance(interchangeVector.begin(), it) - : indexOp.dim(); - rewriter.replaceOp(indexOp, allIvs[dim]); - } + llvm::make_early_inc_range(loopOp.getLoopBody().getOps())) + rewriter.replaceOp(indexOp, allIvs[indexOp.dim()]); } } @@ -555,39 +538,31 @@ namespace { template class LinalgRewritePattern : public RewritePattern { public: - LinalgRewritePattern(MLIRContext *context, - ArrayRef interchangeVector) - : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context), - interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} + LinalgRewritePattern(MLIRContext *context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { auto linalgOp = dyn_cast(op); if (!isa(op)) return failure(); - Optional loopOps = - linalgOpToLoopsImpl(op, rewriter, interchangeVector); + Optional loopOps = linalgOpToLoopsImpl(op, rewriter); if (!loopOps.hasValue()) return failure(); - replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue(), - interchangeVector); + replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue()); rewriter.eraseOp(op); return success(); } - -private: - SmallVector interchangeVector; }; struct FoldAffineOp; } // namespace template -static void lowerLinalgToLoopsImpl(FuncOp funcOp, - ArrayRef interchangeVector) { +static void lowerLinalgToLoopsImpl(FuncOp funcOp) { MLIRContext *context = funcOp.getContext(); RewritePatternSet patterns(context); - patterns.add>(context, interchangeVector); + patterns.add>(context); memref::DimOp::getCanonicalizationPatterns(patterns, context); AffineApplyOp::getCanonicalizationPatterns(patterns, context); patterns.add(context); @@ -639,7 +614,7 @@ struct LowerToAffineLoops registry.insert(); } void runOnFunction() override { - lowerLinalgToLoopsImpl(getFunction(), interchangeVector); + lowerLinalgToLoopsImpl(getFunction()); } }; @@ -648,14 +623,14 @@ struct LowerToLoops : public 
LinalgLowerToLoopsBase { registry.insert(); } void runOnFunction() override { - lowerLinalgToLoopsImpl(getFunction(), interchangeVector); + lowerLinalgToLoopsImpl(getFunction()); } }; struct LowerToParallelLoops : public LinalgLowerToParallelLoopsBase { void runOnFunction() override { - lowerLinalgToLoopsImpl(getFunction(), interchangeVector); + lowerLinalgToLoopsImpl(getFunction()); } }; } // namespace @@ -676,43 +651,38 @@ mlir::createConvertLinalgToAffineLoopsPass() { /// Emits a loop nest with the proper body for `op`. template -Optional -mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, Operation *op, - ArrayRef interchangeVector) { - return linalgOpToLoopsImpl(op, builder, interchangeVector); +Optional mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, + Operation *op) { + return linalgOpToLoopsImpl(op, builder); } -template Optional mlir::linalg::linalgLowerOpToLoops( - OpBuilder &builder, Operation *op, ArrayRef interchangeVector); -template Optional mlir::linalg::linalgLowerOpToLoops( - OpBuilder &builder, Operation *op, ArrayRef interchangeVector); template Optional -mlir::linalg::linalgLowerOpToLoops( - OpBuilder &builder, Operation *op, ArrayRef interchangeVector); +mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, + Operation *op); +template Optional +mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, + Operation *op); +template Optional +mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, + Operation *op); /// Emits a loop nest of `affine.for` with the proper body for `op`. -LogicalResult -mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder, Operation *op, - ArrayRef interchangeVector) { - Optional loops = - linalgLowerOpToLoops(builder, op, interchangeVector); +LogicalResult mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder, + Operation *op) { + Optional loops = linalgLowerOpToLoops(builder, op); return loops ? success() : failure(); } /// Emits a loop nest of `scf.for` with the proper body for `op`. 
-LogicalResult -mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op, - ArrayRef interchangeVector) { - Optional loops = - linalgLowerOpToLoops(builder, op, interchangeVector); +LogicalResult mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op) { + Optional loops = linalgLowerOpToLoops(builder, op); return loops ? success() : failure(); } /// Emits a loop nest of `scf.parallel` with the proper body for `op`. -LogicalResult -mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder, Operation *op, - ArrayRef interchangeVector) { +LogicalResult mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder, + Operation *op) { Optional loops = - linalgLowerOpToLoops(builder, op, interchangeVector); + linalgLowerOpToLoops(builder, op); return loops ? success() : failure(); } diff --git a/mlir/test/Dialect/Linalg/loop-order.mlir b/mlir/test/Dialect/Linalg/loop-order.mlir deleted file mode 100644 index c572967e6d10..000000000000 --- a/mlir/test/Dialect/Linalg/loop-order.mlir +++ /dev/null @@ -1,72 +0,0 @@ -// RUN: mlir-opt %s -convert-linalg-to-loops="interchange-vector=4,0,3,1,2" -split-input-file | FileCheck --check-prefix=LOOP %s -// RUN: mlir-opt %s -convert-linalg-to-parallel-loops="interchange-vector=4,0,3,1,2" -split-input-file | FileCheck --check-prefix=PARALLEL %s -// RUN: mlir-opt %s -convert-linalg-to-affine-loops="interchange-vector=4,0,3,1,2" -split-input-file | FileCheck --check-prefix=AFFINE %s - -func @copy(%input: memref<1x2x3x4x5xf32>, %output: memref<1x2x3x4x5xf32>) { - linalg.copy(%input, %output): memref<1x2x3x4x5xf32>, memref<1x2x3x4x5xf32> - return -} - -// LOOP: scf.for %{{.*}} = %c0 to %c5 step %c1 -// LOOP: scf.for %{{.*}} = %c0 to %c1 step %c1 -// LOOP: scf.for %{{.*}} = %c0 to %c4 step %c1 -// LOOP: scf.for %{{.*}} = %c0 to %c2 step %c1 -// LOOP: scf.for %{{.*}} = %c0 to %c3 step %c1 - -// PARALLEL: scf.parallel -// PARALLEL-SAME: to (%c5, %c1, %c4, %c2, %c3) - -// AFFINE: affine.for %{{.*}} = 0 to 5 -// AFFINE: affine.for 
%{{.*}} = 0 to 1 -// AFFINE: affine.for %{{.*}} = 0 to 4 -// AFFINE: affine.for %{{.*}} = 0 to 2 -// AFFINE: affine.for %{{.*}} = 0 to 3 - -// ----- - -#map = affine_map<(i, j, k, l, m) -> (i, j, k, l, m)> -func @generic(%output: memref<1x2x3x4x5xindex>) { - linalg.generic {indexing_maps = [#map], - iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} - outs(%output : memref<1x2x3x4x5xindex>) { - ^bb0(%arg0 : index): - %i = linalg.index 0 : index - %j = linalg.index 1 : index - %k = linalg.index 2 : index - %l = linalg.index 3 : index - %m = linalg.index 4 : index - %0 = addi %i, %j : index - %1 = addi %0, %k : index - %2 = addi %1, %l : index - %3 = addi %2, %m : index - linalg.yield %3: index - } - return -} - -// LOOP: scf.for %[[m:.*]] = %c0 to %c5 step %c1 -// LOOP: scf.for %[[i:.*]] = %c0 to %c1 step %c1 -// LOOP: scf.for %[[l:.*]] = %c0 to %c4 step %c1 -// LOOP: scf.for %[[j:.*]] = %c0 to %c2 step %c1 -// LOOP: scf.for %[[k:.*]] = %c0 to %c3 step %c1 -// LOOP: %{{.*}} = addi %[[i]], %[[j]] : index -// LOOP: %{{.*}} = addi %{{.*}}, %[[k]] : index -// LOOP: %{{.*}} = addi %{{.*}}, %[[l]] : index -// LOOP: %{{.*}} = addi %{{.*}}, %[[m]] : index - -// PARALLEL: scf.parallel (%[[m:.*]], %[[i:.*]], %[[l:.*]], %[[j:.*]], %[[k:.*]]) = -// PARALLEL-SAME: to (%c5, %c1, %c4, %c2, %c3) -// PARALLEL: %{{.*}} = addi %[[i]], %[[j]] : index -// PARALLEL: %{{.*}} = addi %{{.*}}, %[[k]] : index -// PARALLEL: %{{.*}} = addi %{{.*}}, %[[l]] : index -// PARALLEL: %{{.*}} = addi %{{.*}}, %[[m]] : index - -// AFFINE: affine.for %[[m:.*]] = 0 to 5 -// AFFINE: affine.for %[[i:.*]] = 0 to 1 -// AFFINE: affine.for %[[l:.*]] = 0 to 4 -// AFFINE: affine.for %[[j:.*]] = 0 to 2 -// AFFINE: affine.for %[[k:.*]] = 0 to 3 -// AFFINE: %{{.*}} = addi %[[i]], %[[j]] : index -// AFFINE: %{{.*}} = addi %{{.*}}, %[[k]] : index -// AFFINE: %{{.*}} = addi %{{.*}}, %[[l]] : index -// AFFINE: %{{.*}} = addi %{{.*}}, %[[m]] : index