[mlir][linalg] Move generalization pattern to Transforms (NFC).

Move the generalization pattern to the other Linalg transforms to make it available to the codegen strategy.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D110728
This commit is contained in:
Tobias Gysi 2021-10-05 12:24:19 +00:00
parent cf818b55e7
commit e826db6240
3 changed files with 71 additions and 45 deletions

View File

@ -234,6 +234,10 @@ tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
void interchangeGenericOp(PatternRewriter &rewriter, GenericOp genericOp, void interchangeGenericOp(PatternRewriter &rewriter, GenericOp genericOp,
ArrayRef<unsigned> interchangeVector); ArrayRef<unsigned> interchangeVector);
/// Creates a GenericOp from the given named operation `namedOp`. Assumes
/// `namedOp` is not a GenericOp and has a region builder.
GenericOp generalizeNamedOp(PatternRewriter &rewriter, LinalgOp namedOp);
/// Callback function type used to perform the allocation for the promoted /// Callback function type used to perform the allocation for the promoted
/// `subView`. In `boundingSubViewsize` a best attempt is made to find the /// `subView`. In `boundingSubViewsize` a best attempt is made to find the
/// smallest constant value for the size of the buffer needed for each /// smallest constant value for the size of the buffer needed for each
@ -380,6 +384,9 @@ LogicalResult
interchangeGenericOpPrecondition(GenericOp genericOp, interchangeGenericOpPrecondition(GenericOp genericOp,
ArrayRef<unsigned> interchangeVector); ArrayRef<unsigned> interchangeVector);
/// Generalize named operations to generic operations.
LogicalResult generalizeNamedOpPrecondition(Operation *op);
/// Promote std.subviews feeding linalg operations. /// Promote std.subviews feeding linalg operations.
LogicalResult promoteSubviewsPrecondition(Operation *op, LogicalResult promoteSubviewsPrecondition(Operation *op,
LinalgPromotionOptions options); LinalgPromotionOptions options);
@ -701,6 +708,31 @@ private:
SmallVector<unsigned, 8> interchangeVector; SmallVector<unsigned, 8> interchangeVector;
}; };
///
/// Linalg generalization pattern.
///
/// Apply the `generalization` transformation as a pattern.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `generalization` for more details.
struct LinalgGeneralizationPattern : public RewritePattern {
// Entry point to match any LinalgOp OpInterface.
LinalgGeneralizationPattern(
MLIRContext *context,
LinalgTransformationFilter filter = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
// Entry point to match a specific Linalg op.
LinalgGeneralizationPattern(
StringRef opName, MLIRContext *context,
LinalgTransformationFilter filter = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
private:
/// LinalgTransformMarker handles special attribute manipulations.
LinalgTransformationFilter filter;
};
/// ///
/// Linalg promotion patterns. /// Linalg promotion patterns.
/// ///

View File

@ -29,10 +29,19 @@
using namespace mlir; using namespace mlir;
using namespace mlir::linalg; using namespace mlir::linalg;
// Creates a linalg.generic op from the given `namedOp`. Returns a null op if LogicalResult mlir::linalg::generalizeNamedOpPrecondition(Operation *op) {
// the given `namedOp` does not have a region builder. LinalgOp namedOp = dyn_cast<LinalgOp>(op);
static GenericOp createGenericOpFromNamedOp(LinalgOp namedOp, // Check if the operation is a LinalgOp but not a GenericOp.
PatternRewriter &rewriter) { if (!namedOp || isa<GenericOp>(op))
return failure();
// Check if the operation has a region builder.
if (!namedOp.getRegionBuilder())
return failure();
return success();
}
GenericOp mlir::linalg::generalizeNamedOp(PatternRewriter &rewriter,
LinalgOp namedOp) {
SmallVector<Value> inputOperands = namedOp.getInputOperands(); SmallVector<Value> inputOperands = namedOp.getInputOperands();
SmallVector<Value> outputOperands = namedOp.getOutputOperands(); SmallVector<Value> outputOperands = namedOp.getOutputOperands();
SmallVector<AffineMap> indexingMaps = namedOp.getIndexingMaps(); SmallVector<AffineMap> indexingMaps = namedOp.getIndexingMaps();
@ -54,10 +63,7 @@ static GenericOp createGenericOpFromNamedOp(LinalgOp namedOp,
// Otherwise use the region builder to generate a new region. // Otherwise use the region builder to generate a new region.
// TODO: Remove this path once all linag operations have a region attached. // TODO: Remove this path once all linag operations have a region attached.
auto regionBuilder = namedOp.getRegionBuilder(); auto regionBuilder = namedOp.getRegionBuilder();
if (!regionBuilder) { assert(regionBuilder && "expect the operation to have region builder");
LLVM_DEBUG(llvm::dbgs() << "no region builder for op: " << namedOp << "\n");
return nullptr;
}
return rewriter.create<GenericOp>( return rewriter.create<GenericOp>(
namedOp.getLoc(), types, inputOperands, outputOperands, indexingMaps, namedOp.getLoc(), types, inputOperands, outputOperands, indexingMaps,
iterators, iterators,
@ -112,41 +118,6 @@ struct GeneralizeConvOp
GenericOp createGenericOp(ConvOp convOp, OpBuilder &builder) const; GenericOp createGenericOp(ConvOp convOp, OpBuilder &builder) const;
}; };
/// Catch-all pattern for converting all named ops with a region builder into
/// linalg.generic.
struct LinalgNamedOpGeneralizationPattern : RewritePattern {
LinalgNamedOpGeneralizationPattern(MLIRContext *context,
LinalgTransformationFilter marker,
PatternBenefit benefit = 1)
: RewritePattern(MatchAnyOpTypeTag(), benefit, context),
marker(std::move(marker)) {}
LogicalResult matchAndRewrite(Operation *rootOp,
PatternRewriter &rewriter) const override {
auto linalgOp = dyn_cast<LinalgOp>(rootOp);
if (!linalgOp)
return failure();
if (failed(marker.checkAndNotify(rewriter, linalgOp)))
return failure();
// No nothing to do for linalg.generic.
if (isa<GenericOp>(rootOp))
return failure();
GenericOp genericOp = createGenericOpFromNamedOp(linalgOp, rewriter);
if (!genericOp)
return failure();
rewriter.replaceOp(rootOp, genericOp.getResults());
marker.replaceLinalgTransformationFilter(rewriter,
genericOp.getOperation());
return success();
}
private:
LinalgTransformationFilter marker;
};
struct LinalgGeneralizationPass struct LinalgGeneralizationPass
: public LinalgGeneralizationBase<LinalgGeneralizationPass> { : public LinalgGeneralizationBase<LinalgGeneralizationPass> {
void runOnFunction() override; void runOnFunction() override;
@ -187,8 +158,7 @@ void mlir::linalg::populateLinalgConvGeneralizationPatterns(
void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns( void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
RewritePatternSet &patterns, LinalgTransformationFilter marker) { RewritePatternSet &patterns, LinalgTransformationFilter marker) {
patterns.add<LinalgNamedOpGeneralizationPattern>(patterns.getContext(), patterns.add<LinalgGeneralizationPattern>(patterns.getContext(), marker);
marker);
} }
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() { std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() {

View File

@ -488,6 +488,30 @@ LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
return success(); return success();
} }
/// Linalg generalization pattern.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
MLIRContext *context, LinalgTransformationFilter filter,
PatternBenefit benefit)
: RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
PatternBenefit benefit)
: RewritePattern(opName, benefit, context, {}), filter(filter) {}
LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
if (failed(filter.checkAndNotify(rewriter, op)))
return failure();
if (failed(generalizeNamedOpPrecondition(op)))
return failure();
GenericOp genericOp = generalizeNamedOp(rewriter, op);
rewriter.replaceOp(op, genericOp.getResults());
filter.replaceLinalgTransformationFilter(rewriter, genericOp);
return success();
}
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern( mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
MLIRContext *context, LinalgTransformationFilter filter, MLIRContext *context, LinalgTransformationFilter filter,
LinalgPromotionOptions options, PatternBenefit benefit) LinalgPromotionOptions options, PatternBenefit benefit)