diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 501c34f5c46b..491c59d838c4 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -342,21 +342,17 @@ Optional promoteSubViews(OpBuilder &b, LinalgOp op, LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op, SmallVectorImpl &newResults); -/// Emits a loop nest of `LoopTy` with the proper body for `linalgOp`. -template -Optional linalgLowerOpToLoops(PatternRewriter &rewriter, - LinalgOp linalgOp); - /// Emits a loop nest of `scf.for` with the proper body for `linalgOp`. -LogicalResult linalgOpToLoops(PatternRewriter &rewriter, LinalgOp linalgOp); - -/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`. -LogicalResult linalgOpToParallelLoops(PatternRewriter &rewriter, +Optional linalgOpToLoops(PatternRewriter &rewriter, LinalgOp linalgOp); +/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`. +Optional linalgOpToParallelLoops(PatternRewriter &rewriter, + LinalgOp linalgOp); + /// Emits a loop nest of `affine.for` with the proper body for `linalgOp`. -LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, - LinalgOp linalgOp); +Optional linalgOpToAffineLoops(PatternRewriter &rewriter, + LinalgOp linalgOp); //===----------------------------------------------------------------------===// // Preconditions that ensure the corresponding transformation succeeds and can @@ -814,15 +810,15 @@ struct LinalgLoweringPattern : public RewritePattern { // TODO: Move lowering to library calls here. return failure(); case LinalgLoweringType::Loops: - if (failed(linalgOpToLoops(rewriter, op))) + if (!linalgOpToLoops(rewriter, op)) return failure(); break; case LinalgLoweringType::AffineLoops: - if (failed(linalgOpToAffineLoops(rewriter, op))) + if (!linalgOpToAffineLoops(rewriter, op)) return failure(); break; case LinalgLoweringType::ParallelLoops: - if (failed(linalgOpToParallelLoops(rewriter, op))) + if (!linalgOpToParallelLoops(rewriter, op)) return failure(); break; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp index b1bf213e9cbb..317a9864516a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -378,51 +378,6 @@ static void emitScalarImplementation(ArrayRef allIvs, PoolingSumOp op) { getPoolingInput(op, indices.inputs); } -template -static Optional linalgOpToLoopsImpl(LinalgOp linalgOp, - OpBuilder &builder) { - using IndexedValueTy = typename GenerateLoopNest::IndexedValueTy; - ScopedContext scope(builder, linalgOp.getLoc()); - - // The flattened loopToOperandRangesMaps is expected to be an invertible - // permutation map (which is asserted in the inverse calculation). - assert(linalgOp.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - - auto loopRanges = linalgOp.createLoopRanges(builder, linalgOp.getLoc()); - auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue()); - - SmallVector allIvs; - GenerateLoopNest::doit( - loopRanges, linalgOp, iteratorTypes, - [&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector { - assert(iterArgs.empty() && "unexpected iterArgs"); - allIvs.append(ivs.begin(), ivs.end()); - llvm::TypeSwitch(linalgOp) - .Case( - [&](auto op) { - emitScalarImplementation(allIvs, op); - }) - .Default([&](Operation *op) { assert(false && "unexpected op"); }); - return scf::ValueVector{}; - }); - // Number of loop ops might be different from the number of ivs since some - // loops like affine.parallel and scf.parallel have multiple ivs. - SetVector loopSet; - for (Value iv : allIvs) { - if (!iv) - return {}; - // The induction variable is a block argument of the entry block of the - // loop operation. - BlockArgument ivVal = iv.dyn_cast(); - if (!ivVal) - return {}; - loopSet.insert(ivVal.getOwner()->getParentOp()); - } - LinalgLoops loops(loopSet.begin(), loopSet.end()); - return loops; -} - /// Replace the index operations in the body of the loop nest by the matching /// induction variables. static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp, @@ -455,6 +410,57 @@ static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp, } } +template +static Optional linalgOpToLoopsImpl(PatternRewriter &rewriter, + LinalgOp linalgOp) { + using IndexedValueTy = typename GenerateLoopNest::IndexedValueTy; + ScopedContext scope(rewriter, linalgOp.getLoc()); + + // Canonicalize indexed_generic operations before lowering them to loops. + if (isa(linalgOp)) + return llvm::None; + + // The flattened loopToOperandRangesMaps is expected to be an invertible + // permutation map (which is asserted in the inverse calculation). + assert(linalgOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + + auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc()); + auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue()); + + SmallVector allIvs; + GenerateLoopNest::doit( + loopRanges, linalgOp, iteratorTypes, + [&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector { + assert(iterArgs.empty() && "unexpected iterArgs"); + allIvs.append(ivs.begin(), ivs.end()); + llvm::TypeSwitch(linalgOp) + .Case( + [&](auto op) { + emitScalarImplementation(allIvs, op); + }) + .Default([&](Operation *op) { assert(false && "unexpected op"); }); + return scf::ValueVector{}; + }); + // Number of loop ops might be different from the number of ivs since some + // loops like affine.parallel and scf.parallel have multiple ivs. + SetVector loopSet; + for (Value iv : allIvs) { + if (!iv) + return {}; + // The induction variable is a block argument of the entry block of the + // loop operation. + BlockArgument ivVal = iv.dyn_cast(); + if (!ivVal) + return {}; + loopSet.insert(ivVal.getOwner()->getParentOp()); + } + LinalgLoops loops(loopSet.begin(), loopSet.end()); + // Replace all index operations in the loop body. + replaceIndexOpsByInductionVariables(linalgOp, rewriter, loops); + return loops; +} + namespace { template class LinalgRewritePattern : public RewritePattern { @@ -467,7 +473,7 @@ public: auto linalgOp = dyn_cast(op); if (!isa(op)) return failure(); - if (!linalgLowerOpToLoops(rewriter, linalgOp)) + if (!linalgOpToLoopsImpl(rewriter, linalgOp)) return failure(); rewriter.eraseOp(op); return success(); @@ -614,52 +620,22 @@ mlir::createConvertLinalgToAffineLoopsPass() { return std::make_unique(); } -/// Emits a loop nest with the proper body for `linalgOp`. -template -Optional -mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter, - LinalgOp linalgOp) { - // Convert indexed_generic ops to generic ops before lowering them to loops. - if (isa(linalgOp)) - return llvm::None; - - Optional loopOps = - linalgOpToLoopsImpl(linalgOp.getOperation(), rewriter); - if (loopOps.hasValue()) - replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue()); - return loopOps; -} - -template Optional -mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter, - LinalgOp linalgOp); -template Optional -mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter, - LinalgOp linalgOp); -template Optional -mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter, - LinalgOp linalgOp); - /// Emits a loop nest of `affine.for` with the proper body for `linalgOp`. -LogicalResult mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter, - LinalgOp linalgOp) { - Optional loops = - linalgLowerOpToLoops(rewriter, linalgOp); - return loops ? success() : failure(); +Optional +mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter, + LinalgOp linalgOp) { + return linalgOpToLoopsImpl(rewriter, linalgOp); } /// Emits a loop nest of `scf.for` with the proper body for `linalgOp`. -LogicalResult mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter, - LinalgOp linalgOp) { - Optional loops = - linalgLowerOpToLoops(rewriter, linalgOp); - return loops ? success() : failure(); +Optional mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter, + LinalgOp linalgOp) { + return linalgOpToLoopsImpl(rewriter, linalgOp); } /// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`. -LogicalResult mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter, - LinalgOp linalgOp) { - Optional loops = - linalgLowerOpToLoops(rewriter, linalgOp); - return loops ? success() : failure(); +Optional +mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter, + LinalgOp linalgOp) { + return linalgOpToLoopsImpl(rewriter, linalgOp); }