[mlir][linalg] Move generalization pattern to Transforms (NFC).

Move the generalization pattern to the other Linalg transforms to make it available to the codegen strategy.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D110728
This commit is contained in:
Tobias Gysi 2021-10-05 12:24:19 +00:00
parent cf818b55e7
commit e826db6240
3 changed files with 71 additions and 45 deletions

View File

@ -234,6 +234,10 @@ tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
void interchangeGenericOp(PatternRewriter &rewriter, GenericOp genericOp,
ArrayRef<unsigned> interchangeVector);
/// Creates a GenericOp from the given named operation `namedOp`. Assumes
/// `namedOp` is not a GenericOp and has a region builder.
GenericOp generalizeNamedOp(PatternRewriter &rewriter, LinalgOp namedOp);
/// Callback function type used to perform the allocation for the promoted
/// `subView`. In `boundingSubViewsize` a best attempt is made to find the
/// smallest constant value for the size of the buffer needed for each
@ -380,6 +384,9 @@ LogicalResult
interchangeGenericOpPrecondition(GenericOp genericOp,
ArrayRef<unsigned> interchangeVector);
/// Generalize named operations to generic operations.
LogicalResult generalizeNamedOpPrecondition(Operation *op);
/// Promote std.subviews feeding linalg operations.
LogicalResult promoteSubviewsPrecondition(Operation *op,
LinalgPromotionOptions options);
@ -701,6 +708,31 @@ private:
SmallVector<unsigned, 8> interchangeVector;
};
///
/// Linalg generalization pattern.
///
/// Apply the `generalization` transformation as a pattern.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `generalization` for more details.
struct LinalgGeneralizationPattern : public RewritePattern {
// Entry point to match any LinalgOp OpInterface.
LinalgGeneralizationPattern(
MLIRContext *context,
LinalgTransformationFilter filter = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
// Entry point to match a specific Linalg op.
LinalgGeneralizationPattern(
StringRef opName, MLIRContext *context,
LinalgTransformationFilter filter = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
private:
/// LinalgTransformMarker handles special attribute manipulations.
LinalgTransformationFilter filter;
};
///
/// Linalg promotion patterns.
///

View File

@ -29,10 +29,19 @@
using namespace mlir;
using namespace mlir::linalg;
// Creates a linalg.generic op from the given `namedOp`. Returns a null op if
// the given `namedOp` does not have a region builder.
static GenericOp createGenericOpFromNamedOp(LinalgOp namedOp,
PatternRewriter &rewriter) {
LogicalResult mlir::linalg::generalizeNamedOpPrecondition(Operation *op) {
LinalgOp namedOp = dyn_cast<LinalgOp>(op);
// Check if the operation is a LinalgOp but not a GenericOp.
if (!namedOp || isa<GenericOp>(op))
return failure();
// Check if the operation has a region builder.
if (!namedOp.getRegionBuilder())
return failure();
return success();
}
GenericOp mlir::linalg::generalizeNamedOp(PatternRewriter &rewriter,
LinalgOp namedOp) {
SmallVector<Value> inputOperands = namedOp.getInputOperands();
SmallVector<Value> outputOperands = namedOp.getOutputOperands();
SmallVector<AffineMap> indexingMaps = namedOp.getIndexingMaps();
@ -54,10 +63,7 @@ static GenericOp createGenericOpFromNamedOp(LinalgOp namedOp,
// Otherwise use the region builder to generate a new region.
// TODO: Remove this path once all linag operations have a region attached.
auto regionBuilder = namedOp.getRegionBuilder();
if (!regionBuilder) {
LLVM_DEBUG(llvm::dbgs() << "no region builder for op: " << namedOp << "\n");
return nullptr;
}
assert(regionBuilder && "expect the operation to have region builder");
return rewriter.create<GenericOp>(
namedOp.getLoc(), types, inputOperands, outputOperands, indexingMaps,
iterators,
@ -112,41 +118,6 @@ struct GeneralizeConvOp
GenericOp createGenericOp(ConvOp convOp, OpBuilder &builder) const;
};
/// Catch-all pattern for converting all named ops with a region builder into
/// linalg.generic.
struct LinalgNamedOpGeneralizationPattern : RewritePattern {
LinalgNamedOpGeneralizationPattern(MLIRContext *context,
LinalgTransformationFilter marker,
PatternBenefit benefit = 1)
: RewritePattern(MatchAnyOpTypeTag(), benefit, context),
marker(std::move(marker)) {}
LogicalResult matchAndRewrite(Operation *rootOp,
PatternRewriter &rewriter) const override {
auto linalgOp = dyn_cast<LinalgOp>(rootOp);
if (!linalgOp)
return failure();
if (failed(marker.checkAndNotify(rewriter, linalgOp)))
return failure();
// No nothing to do for linalg.generic.
if (isa<GenericOp>(rootOp))
return failure();
GenericOp genericOp = createGenericOpFromNamedOp(linalgOp, rewriter);
if (!genericOp)
return failure();
rewriter.replaceOp(rootOp, genericOp.getResults());
marker.replaceLinalgTransformationFilter(rewriter,
genericOp.getOperation());
return success();
}
private:
LinalgTransformationFilter marker;
};
struct LinalgGeneralizationPass
: public LinalgGeneralizationBase<LinalgGeneralizationPass> {
void runOnFunction() override;
@ -187,8 +158,7 @@ void mlir::linalg::populateLinalgConvGeneralizationPatterns(
void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
RewritePatternSet &patterns, LinalgTransformationFilter marker) {
patterns.add<LinalgNamedOpGeneralizationPattern>(patterns.getContext(),
marker);
patterns.add<LinalgGeneralizationPattern>(patterns.getContext(), marker);
}
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() {

View File

@ -488,6 +488,30 @@ LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
return success();
}
/// Linalg generalization pattern.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
MLIRContext *context, LinalgTransformationFilter filter,
PatternBenefit benefit)
: RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
PatternBenefit benefit)
: RewritePattern(opName, benefit, context, {}), filter(filter) {}
LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
if (failed(filter.checkAndNotify(rewriter, op)))
return failure();
if (failed(generalizeNamedOpPrecondition(op)))
return failure();
GenericOp genericOp = generalizeNamedOp(rewriter, op);
rewriter.replaceOp(op, genericOp.getResults());
filter.replaceLinalgTransformationFilter(rewriter, genericOp);
return success();
}
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
MLIRContext *context, LinalgTransformationFilter filter,
LinalgPromotionOptions options, PatternBenefit benefit)