[mlir][linalg] Remove template parameter from loop lowering.

Replace the templated linalgLowerOpToLoops method by three specialized methods linalgOpToLoops, LinalgOpToParallelLoops, and linalgOpToAffineLoops.

Differential Revision: https://reviews.llvm.org/D102324
This commit is contained in:
Tobias Gysi 2021-05-17 08:50:15 +00:00
parent 900c898994
commit 7c16f93c44
2 changed files with 73 additions and 101 deletions

View File

@ -342,21 +342,17 @@ Optional<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op,
SmallVectorImpl<Value> &newResults);
/// Emits a loop nest of `LoopTy` with the proper body for `linalgOp`.
template <typename LoopTy>
Optional<LinalgLoops> linalgLowerOpToLoops(PatternRewriter &rewriter,
LinalgOp linalgOp);
/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
LogicalResult linalgOpToLoops(PatternRewriter &rewriter, LinalgOp linalgOp);
/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
LogicalResult linalgOpToParallelLoops(PatternRewriter &rewriter,
Optional<LinalgLoops> linalgOpToLoops(PatternRewriter &rewriter,
LinalgOp linalgOp);
/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
Optional<LinalgLoops> linalgOpToParallelLoops(PatternRewriter &rewriter,
LinalgOp linalgOp);
/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter,
LinalgOp linalgOp);
Optional<LinalgLoops> linalgOpToAffineLoops(PatternRewriter &rewriter,
LinalgOp linalgOp);
//===----------------------------------------------------------------------===//
// Preconditions that ensure the corresponding transformation succeeds and can
@ -814,15 +810,15 @@ struct LinalgLoweringPattern : public RewritePattern {
// TODO: Move lowering to library calls here.
return failure();
case LinalgLoweringType::Loops:
if (failed(linalgOpToLoops(rewriter, op)))
if (!linalgOpToLoops(rewriter, op))
return failure();
break;
case LinalgLoweringType::AffineLoops:
if (failed(linalgOpToAffineLoops(rewriter, op)))
if (!linalgOpToAffineLoops(rewriter, op))
return failure();
break;
case LinalgLoweringType::ParallelLoops:
if (failed(linalgOpToParallelLoops(rewriter, op)))
if (!linalgOpToParallelLoops(rewriter, op))
return failure();
break;
}

View File

@ -378,51 +378,6 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
getPoolingInput<IndexedValueType>(op, indices.inputs);
}
template <typename LoopTy>
static Optional<LinalgLoops> linalgOpToLoopsImpl(LinalgOp linalgOp,
OpBuilder &builder) {
using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
ScopedContext scope(builder, linalgOp.getLoc());
// The flattened loopToOperandRangesMaps is expected to be an invertible
// permutation map (which is asserted in the inverse calculation).
assert(linalgOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
auto loopRanges = linalgOp.createLoopRanges(builder, linalgOp.getLoc());
auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue());
SmallVector<Value, 4> allIvs;
GenerateLoopNest<LoopTy>::doit(
loopRanges, linalgOp, iteratorTypes,
[&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector {
assert(iterArgs.empty() && "unexpected iterArgs");
allIvs.append(ivs.begin(), ivs.end());
llvm::TypeSwitch<Operation *>(linalgOp)
.Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp, LinalgOp>(
[&](auto op) {
emitScalarImplementation<IndexedValueTy>(allIvs, op);
})
.Default([&](Operation *op) { assert(false && "unexpected op"); });
return scf::ValueVector{};
});
// Number of loop ops might be different from the number of ivs since some
// loops like affine.parallel and scf.parallel have multiple ivs.
SetVector<Operation *> loopSet;
for (Value iv : allIvs) {
if (!iv)
return {};
// The induction variable is a block argument of the entry block of the
// loop operation.
BlockArgument ivVal = iv.dyn_cast<BlockArgument>();
if (!ivVal)
return {};
loopSet.insert(ivVal.getOwner()->getParentOp());
}
LinalgLoops loops(loopSet.begin(), loopSet.end());
return loops;
}
/// Replace the index operations in the body of the loop nest by the matching
/// induction variables.
static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp,
@ -455,6 +410,57 @@ static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp,
}
}
template <typename LoopTy>
static Optional<LinalgLoops> linalgOpToLoopsImpl(PatternRewriter &rewriter,
LinalgOp linalgOp) {
using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
ScopedContext scope(rewriter, linalgOp.getLoc());
// Canonicalize indexed_generic operations before lowering them to loops.
if (isa<IndexedGenericOp>(linalgOp))
return llvm::None;
// The flattened loopToOperandRangesMaps is expected to be an invertible
// permutation map (which is asserted in the inverse calculation).
assert(linalgOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue());
SmallVector<Value, 4> allIvs;
GenerateLoopNest<LoopTy>::doit(
loopRanges, linalgOp, iteratorTypes,
[&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector {
assert(iterArgs.empty() && "unexpected iterArgs");
allIvs.append(ivs.begin(), ivs.end());
llvm::TypeSwitch<Operation *>(linalgOp)
.Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp, LinalgOp>(
[&](auto op) {
emitScalarImplementation<IndexedValueTy>(allIvs, op);
})
.Default([&](Operation *op) { assert(false && "unexpected op"); });
return scf::ValueVector{};
});
// Number of loop ops might be different from the number of ivs since some
// loops like affine.parallel and scf.parallel have multiple ivs.
SetVector<Operation *> loopSet;
for (Value iv : allIvs) {
if (!iv)
return {};
// The induction variable is a block argument of the entry block of the
// loop operation.
BlockArgument ivVal = iv.dyn_cast<BlockArgument>();
if (!ivVal)
return {};
loopSet.insert(ivVal.getOwner()->getParentOp());
}
LinalgLoops loops(loopSet.begin(), loopSet.end());
// Replace all index operations in the loop body.
replaceIndexOpsByInductionVariables(linalgOp, rewriter, loops);
return loops;
}
namespace {
template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
@ -467,7 +473,7 @@ public:
auto linalgOp = dyn_cast<LinalgOp>(op);
if (!isa<LinalgOp>(op))
return failure();
if (!linalgLowerOpToLoops<LoopType>(rewriter, linalgOp))
if (!linalgOpToLoopsImpl<LoopType>(rewriter, linalgOp))
return failure();
rewriter.eraseOp(op);
return success();
@ -614,52 +620,22 @@ mlir::createConvertLinalgToAffineLoopsPass() {
return std::make_unique<LowerToAffineLoops>();
}
/// Emits a loop nest with the proper body for `linalgOp`.
template <typename LoopTy>
Optional<LinalgLoops>
mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter,
LinalgOp linalgOp) {
// Convert indexed_generic ops to generic ops before lowering them to loops.
if (isa<IndexedGenericOp>(linalgOp))
return llvm::None;
Optional<LinalgLoops> loopOps =
linalgOpToLoopsImpl<LoopTy>(linalgOp.getOperation(), rewriter);
if (loopOps.hasValue())
replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue());
return loopOps;
}
template Optional<LinalgLoops>
mlir::linalg::linalgLowerOpToLoops<AffineForOp>(PatternRewriter &rewriter,
LinalgOp linalgOp);
template Optional<LinalgLoops>
mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(PatternRewriter &rewriter,
LinalgOp linalgOp);
template Optional<LinalgLoops>
mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(PatternRewriter &rewriter,
LinalgOp linalgOp);
/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
LogicalResult mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
LinalgOp linalgOp) {
Optional<LinalgLoops> loops =
linalgLowerOpToLoops<AffineForOp>(rewriter, linalgOp);
return loops ? success() : failure();
Optional<LinalgLoops>
mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
LinalgOp linalgOp) {
return linalgOpToLoopsImpl<AffineForOp>(rewriter, linalgOp);
}
/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
LogicalResult mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
LinalgOp linalgOp) {
Optional<LinalgLoops> loops =
linalgLowerOpToLoops<scf::ForOp>(rewriter, linalgOp);
return loops ? success() : failure();
Optional<LinalgLoops> mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
LinalgOp linalgOp) {
return linalgOpToLoopsImpl<scf::ForOp>(rewriter, linalgOp);
}
/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
LogicalResult mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
LinalgOp linalgOp) {
Optional<LinalgLoops> loops =
linalgLowerOpToLoops<scf::ParallelOp>(rewriter, linalgOp);
return loops ? success() : failure();
Optional<LinalgLoops>
mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
LinalgOp linalgOp) {
return linalgOpToLoopsImpl<scf::ParallelOp>(rewriter, linalgOp);
}