forked from OSchip/llvm-project
[mlir][linalg] Remove template parameter from loop lowering.
Replace the templated linalgLowerOpToLoops method by three specialized methods linalgOpToLoops, LinalgOpToParallelLoops, and linalgOpToAffineLoops. Differential Revision: https://reviews.llvm.org/D102324
This commit is contained in:
parent
900c898994
commit
7c16f93c44
|
@ -342,21 +342,17 @@ Optional<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
|
|||
LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op,
|
||||
SmallVectorImpl<Value> &newResults);
|
||||
|
||||
/// Emits a loop nest of `LoopTy` with the proper body for `linalgOp`.
|
||||
template <typename LoopTy>
|
||||
Optional<LinalgLoops> linalgLowerOpToLoops(PatternRewriter &rewriter,
|
||||
LinalgOp linalgOp);
|
||||
|
||||
/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
|
||||
LogicalResult linalgOpToLoops(PatternRewriter &rewriter, LinalgOp linalgOp);
|
||||
|
||||
/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
|
||||
LogicalResult linalgOpToParallelLoops(PatternRewriter &rewriter,
|
||||
Optional<LinalgLoops> linalgOpToLoops(PatternRewriter &rewriter,
|
||||
LinalgOp linalgOp);
|
||||
|
||||
/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
|
||||
Optional<LinalgLoops> linalgOpToParallelLoops(PatternRewriter &rewriter,
|
||||
LinalgOp linalgOp);
|
||||
|
||||
/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
|
||||
LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter,
|
||||
LinalgOp linalgOp);
|
||||
Optional<LinalgLoops> linalgOpToAffineLoops(PatternRewriter &rewriter,
|
||||
LinalgOp linalgOp);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Preconditions that ensure the corresponding transformation succeeds and can
|
||||
|
@ -814,15 +810,15 @@ struct LinalgLoweringPattern : public RewritePattern {
|
|||
// TODO: Move lowering to library calls here.
|
||||
return failure();
|
||||
case LinalgLoweringType::Loops:
|
||||
if (failed(linalgOpToLoops(rewriter, op)))
|
||||
if (!linalgOpToLoops(rewriter, op))
|
||||
return failure();
|
||||
break;
|
||||
case LinalgLoweringType::AffineLoops:
|
||||
if (failed(linalgOpToAffineLoops(rewriter, op)))
|
||||
if (!linalgOpToAffineLoops(rewriter, op))
|
||||
return failure();
|
||||
break;
|
||||
case LinalgLoweringType::ParallelLoops:
|
||||
if (failed(linalgOpToParallelLoops(rewriter, op)))
|
||||
if (!linalgOpToParallelLoops(rewriter, op))
|
||||
return failure();
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -378,51 +378,6 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
|
|||
getPoolingInput<IndexedValueType>(op, indices.inputs);
|
||||
}
|
||||
|
||||
template <typename LoopTy>
|
||||
static Optional<LinalgLoops> linalgOpToLoopsImpl(LinalgOp linalgOp,
|
||||
OpBuilder &builder) {
|
||||
using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
|
||||
ScopedContext scope(builder, linalgOp.getLoc());
|
||||
|
||||
// The flattened loopToOperandRangesMaps is expected to be an invertible
|
||||
// permutation map (which is asserted in the inverse calculation).
|
||||
assert(linalgOp.hasBufferSemantics() &&
|
||||
"expected linalg op with buffer semantics");
|
||||
|
||||
auto loopRanges = linalgOp.createLoopRanges(builder, linalgOp.getLoc());
|
||||
auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue());
|
||||
|
||||
SmallVector<Value, 4> allIvs;
|
||||
GenerateLoopNest<LoopTy>::doit(
|
||||
loopRanges, linalgOp, iteratorTypes,
|
||||
[&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector {
|
||||
assert(iterArgs.empty() && "unexpected iterArgs");
|
||||
allIvs.append(ivs.begin(), ivs.end());
|
||||
llvm::TypeSwitch<Operation *>(linalgOp)
|
||||
.Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp, LinalgOp>(
|
||||
[&](auto op) {
|
||||
emitScalarImplementation<IndexedValueTy>(allIvs, op);
|
||||
})
|
||||
.Default([&](Operation *op) { assert(false && "unexpected op"); });
|
||||
return scf::ValueVector{};
|
||||
});
|
||||
// Number of loop ops might be different from the number of ivs since some
|
||||
// loops like affine.parallel and scf.parallel have multiple ivs.
|
||||
SetVector<Operation *> loopSet;
|
||||
for (Value iv : allIvs) {
|
||||
if (!iv)
|
||||
return {};
|
||||
// The induction variable is a block argument of the entry block of the
|
||||
// loop operation.
|
||||
BlockArgument ivVal = iv.dyn_cast<BlockArgument>();
|
||||
if (!ivVal)
|
||||
return {};
|
||||
loopSet.insert(ivVal.getOwner()->getParentOp());
|
||||
}
|
||||
LinalgLoops loops(loopSet.begin(), loopSet.end());
|
||||
return loops;
|
||||
}
|
||||
|
||||
/// Replace the index operations in the body of the loop nest by the matching
|
||||
/// induction variables.
|
||||
static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp,
|
||||
|
@ -455,6 +410,57 @@ static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp,
|
|||
}
|
||||
}
|
||||
|
||||
template <typename LoopTy>
|
||||
static Optional<LinalgLoops> linalgOpToLoopsImpl(PatternRewriter &rewriter,
|
||||
LinalgOp linalgOp) {
|
||||
using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
|
||||
ScopedContext scope(rewriter, linalgOp.getLoc());
|
||||
|
||||
// Canonicalize indexed_generic operations before lowering them to loops.
|
||||
if (isa<IndexedGenericOp>(linalgOp))
|
||||
return llvm::None;
|
||||
|
||||
// The flattened loopToOperandRangesMaps is expected to be an invertible
|
||||
// permutation map (which is asserted in the inverse calculation).
|
||||
assert(linalgOp.hasBufferSemantics() &&
|
||||
"expected linalg op with buffer semantics");
|
||||
|
||||
auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
|
||||
auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue());
|
||||
|
||||
SmallVector<Value, 4> allIvs;
|
||||
GenerateLoopNest<LoopTy>::doit(
|
||||
loopRanges, linalgOp, iteratorTypes,
|
||||
[&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector {
|
||||
assert(iterArgs.empty() && "unexpected iterArgs");
|
||||
allIvs.append(ivs.begin(), ivs.end());
|
||||
llvm::TypeSwitch<Operation *>(linalgOp)
|
||||
.Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp, LinalgOp>(
|
||||
[&](auto op) {
|
||||
emitScalarImplementation<IndexedValueTy>(allIvs, op);
|
||||
})
|
||||
.Default([&](Operation *op) { assert(false && "unexpected op"); });
|
||||
return scf::ValueVector{};
|
||||
});
|
||||
// Number of loop ops might be different from the number of ivs since some
|
||||
// loops like affine.parallel and scf.parallel have multiple ivs.
|
||||
SetVector<Operation *> loopSet;
|
||||
for (Value iv : allIvs) {
|
||||
if (!iv)
|
||||
return {};
|
||||
// The induction variable is a block argument of the entry block of the
|
||||
// loop operation.
|
||||
BlockArgument ivVal = iv.dyn_cast<BlockArgument>();
|
||||
if (!ivVal)
|
||||
return {};
|
||||
loopSet.insert(ivVal.getOwner()->getParentOp());
|
||||
}
|
||||
LinalgLoops loops(loopSet.begin(), loopSet.end());
|
||||
// Replace all index operations in the loop body.
|
||||
replaceIndexOpsByInductionVariables(linalgOp, rewriter, loops);
|
||||
return loops;
|
||||
}
|
||||
|
||||
namespace {
|
||||
template <typename LoopType>
|
||||
class LinalgRewritePattern : public RewritePattern {
|
||||
|
@ -467,7 +473,7 @@ public:
|
|||
auto linalgOp = dyn_cast<LinalgOp>(op);
|
||||
if (!isa<LinalgOp>(op))
|
||||
return failure();
|
||||
if (!linalgLowerOpToLoops<LoopType>(rewriter, linalgOp))
|
||||
if (!linalgOpToLoopsImpl<LoopType>(rewriter, linalgOp))
|
||||
return failure();
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
|
@ -614,52 +620,22 @@ mlir::createConvertLinalgToAffineLoopsPass() {
|
|||
return std::make_unique<LowerToAffineLoops>();
|
||||
}
|
||||
|
||||
/// Emits a loop nest with the proper body for `linalgOp`.
|
||||
template <typename LoopTy>
|
||||
Optional<LinalgLoops>
|
||||
mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter,
|
||||
LinalgOp linalgOp) {
|
||||
// Convert indexed_generic ops to generic ops before lowering them to loops.
|
||||
if (isa<IndexedGenericOp>(linalgOp))
|
||||
return llvm::None;
|
||||
|
||||
Optional<LinalgLoops> loopOps =
|
||||
linalgOpToLoopsImpl<LoopTy>(linalgOp.getOperation(), rewriter);
|
||||
if (loopOps.hasValue())
|
||||
replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue());
|
||||
return loopOps;
|
||||
}
|
||||
|
||||
template Optional<LinalgLoops>
|
||||
mlir::linalg::linalgLowerOpToLoops<AffineForOp>(PatternRewriter &rewriter,
|
||||
LinalgOp linalgOp);
|
||||
template Optional<LinalgLoops>
|
||||
mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(PatternRewriter &rewriter,
|
||||
LinalgOp linalgOp);
|
||||
template Optional<LinalgLoops>
|
||||
mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(PatternRewriter &rewriter,
|
||||
LinalgOp linalgOp);
|
||||
|
||||
/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
|
||||
LogicalResult mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
|
||||
LinalgOp linalgOp) {
|
||||
Optional<LinalgLoops> loops =
|
||||
linalgLowerOpToLoops<AffineForOp>(rewriter, linalgOp);
|
||||
return loops ? success() : failure();
|
||||
Optional<LinalgLoops>
|
||||
mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
|
||||
LinalgOp linalgOp) {
|
||||
return linalgOpToLoopsImpl<AffineForOp>(rewriter, linalgOp);
|
||||
}
|
||||
|
||||
/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
|
||||
LogicalResult mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
|
||||
LinalgOp linalgOp) {
|
||||
Optional<LinalgLoops> loops =
|
||||
linalgLowerOpToLoops<scf::ForOp>(rewriter, linalgOp);
|
||||
return loops ? success() : failure();
|
||||
Optional<LinalgLoops> mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
|
||||
LinalgOp linalgOp) {
|
||||
return linalgOpToLoopsImpl<scf::ForOp>(rewriter, linalgOp);
|
||||
}
|
||||
|
||||
/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
|
||||
LogicalResult mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
|
||||
LinalgOp linalgOp) {
|
||||
Optional<LinalgLoops> loops =
|
||||
linalgLowerOpToLoops<scf::ParallelOp>(rewriter, linalgOp);
|
||||
return loops ? success() : failure();
|
||||
Optional<LinalgLoops>
|
||||
mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
|
||||
LinalgOp linalgOp) {
|
||||
return linalgOpToLoopsImpl<scf::ParallelOp>(rewriter, linalgOp);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue