forked from OSchip/llvm-project
[mlir][linalg] Move generalization pattern to Transforms (NFC).
Move the generalization pattern to the other Linalg transforms to make it available to the codegen strategy. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D110728
This commit is contained in:
parent
cf818b55e7
commit
e826db6240
|
@ -234,6 +234,10 @@ tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
|
|||
void interchangeGenericOp(PatternRewriter &rewriter, GenericOp genericOp,
|
||||
ArrayRef<unsigned> interchangeVector);
|
||||
|
||||
/// Creates a GenericOp from the given named operation `namedOp`. Assumes
|
||||
/// `namedOp` is not a GenericOp and has a region builder.
|
||||
GenericOp generalizeNamedOp(PatternRewriter &rewriter, LinalgOp namedOp);
|
||||
|
||||
/// Callback function type used to perform the allocation for the promoted
|
||||
/// `subView`. In `boundingSubViewsize` a best attempt is made to find the
|
||||
/// smallest constant value for the size of the buffer needed for each
|
||||
|
@ -380,6 +384,9 @@ LogicalResult
|
|||
interchangeGenericOpPrecondition(GenericOp genericOp,
|
||||
ArrayRef<unsigned> interchangeVector);
|
||||
|
||||
/// Generalize named operations to generic operations.
|
||||
LogicalResult generalizeNamedOpPrecondition(Operation *op);
|
||||
|
||||
/// Promote std.subviews feeding linalg operations.
|
||||
LogicalResult promoteSubviewsPrecondition(Operation *op,
|
||||
LinalgPromotionOptions options);
|
||||
|
@ -701,6 +708,31 @@ private:
|
|||
SmallVector<unsigned, 8> interchangeVector;
|
||||
};
|
||||
|
||||
///
|
||||
/// Linalg generalization pattern.
|
||||
///
|
||||
/// Apply the `generalization` transformation as a pattern.
|
||||
/// `filter` controls LinalgTransformMarker matching and update when specified.
|
||||
/// See `generalization` for more details.
|
||||
struct LinalgGeneralizationPattern : public RewritePattern {
|
||||
// Entry point to match any LinalgOp OpInterface.
|
||||
LinalgGeneralizationPattern(
|
||||
MLIRContext *context,
|
||||
LinalgTransformationFilter filter = LinalgTransformationFilter(),
|
||||
PatternBenefit benefit = 1);
|
||||
// Entry point to match a specific Linalg op.
|
||||
LinalgGeneralizationPattern(
|
||||
StringRef opName, MLIRContext *context,
|
||||
LinalgTransformationFilter filter = LinalgTransformationFilter(),
|
||||
PatternBenefit benefit = 1);
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
|
||||
private:
|
||||
/// LinalgTransformMarker handles special attribute manipulations.
|
||||
LinalgTransformationFilter filter;
|
||||
};
|
||||
|
||||
///
|
||||
/// Linalg promotion patterns.
|
||||
///
|
||||
|
|
|
@ -29,10 +29,19 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::linalg;
|
||||
|
||||
// Creates a linalg.generic op from the given `namedOp`. Returns a null op if
|
||||
// the given `namedOp` does not have a region builder.
|
||||
static GenericOp createGenericOpFromNamedOp(LinalgOp namedOp,
|
||||
PatternRewriter &rewriter) {
|
||||
LogicalResult mlir::linalg::generalizeNamedOpPrecondition(Operation *op) {
|
||||
LinalgOp namedOp = dyn_cast<LinalgOp>(op);
|
||||
// Check if the operation is a LinalgOp but not a GenericOp.
|
||||
if (!namedOp || isa<GenericOp>(op))
|
||||
return failure();
|
||||
// Check if the operation has a region builder.
|
||||
if (!namedOp.getRegionBuilder())
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
GenericOp mlir::linalg::generalizeNamedOp(PatternRewriter &rewriter,
|
||||
LinalgOp namedOp) {
|
||||
SmallVector<Value> inputOperands = namedOp.getInputOperands();
|
||||
SmallVector<Value> outputOperands = namedOp.getOutputOperands();
|
||||
SmallVector<AffineMap> indexingMaps = namedOp.getIndexingMaps();
|
||||
|
@ -54,10 +63,7 @@ static GenericOp createGenericOpFromNamedOp(LinalgOp namedOp,
|
|||
// Otherwise use the region builder to generate a new region.
|
||||
// TODO: Remove this path once all linag operations have a region attached.
|
||||
auto regionBuilder = namedOp.getRegionBuilder();
|
||||
if (!regionBuilder) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "no region builder for op: " << namedOp << "\n");
|
||||
return nullptr;
|
||||
}
|
||||
assert(regionBuilder && "expect the operation to have region builder");
|
||||
return rewriter.create<GenericOp>(
|
||||
namedOp.getLoc(), types, inputOperands, outputOperands, indexingMaps,
|
||||
iterators,
|
||||
|
@ -112,41 +118,6 @@ struct GeneralizeConvOp
|
|||
GenericOp createGenericOp(ConvOp convOp, OpBuilder &builder) const;
|
||||
};
|
||||
|
||||
/// Catch-all pattern for converting all named ops with a region builder into
|
||||
/// linalg.generic.
|
||||
struct LinalgNamedOpGeneralizationPattern : RewritePattern {
|
||||
LinalgNamedOpGeneralizationPattern(MLIRContext *context,
|
||||
LinalgTransformationFilter marker,
|
||||
PatternBenefit benefit = 1)
|
||||
: RewritePattern(MatchAnyOpTypeTag(), benefit, context),
|
||||
marker(std::move(marker)) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *rootOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto linalgOp = dyn_cast<LinalgOp>(rootOp);
|
||||
if (!linalgOp)
|
||||
return failure();
|
||||
if (failed(marker.checkAndNotify(rewriter, linalgOp)))
|
||||
return failure();
|
||||
|
||||
// No nothing to do for linalg.generic.
|
||||
if (isa<GenericOp>(rootOp))
|
||||
return failure();
|
||||
|
||||
GenericOp genericOp = createGenericOpFromNamedOp(linalgOp, rewriter);
|
||||
if (!genericOp)
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(rootOp, genericOp.getResults());
|
||||
marker.replaceLinalgTransformationFilter(rewriter,
|
||||
genericOp.getOperation());
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
LinalgTransformationFilter marker;
|
||||
};
|
||||
|
||||
struct LinalgGeneralizationPass
|
||||
: public LinalgGeneralizationBase<LinalgGeneralizationPass> {
|
||||
void runOnFunction() override;
|
||||
|
@ -187,8 +158,7 @@ void mlir::linalg::populateLinalgConvGeneralizationPatterns(
|
|||
|
||||
void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
|
||||
RewritePatternSet &patterns, LinalgTransformationFilter marker) {
|
||||
patterns.add<LinalgNamedOpGeneralizationPattern>(patterns.getContext(),
|
||||
marker);
|
||||
patterns.add<LinalgGeneralizationPattern>(patterns.getContext(), marker);
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() {
|
||||
|
|
|
@ -488,6 +488,30 @@ LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
/// Linalg generalization pattern.
|
||||
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
|
||||
MLIRContext *context, LinalgTransformationFilter filter,
|
||||
PatternBenefit benefit)
|
||||
: RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}
|
||||
|
||||
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
|
||||
StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
|
||||
PatternBenefit benefit)
|
||||
: RewritePattern(opName, benefit, context, {}), filter(filter) {}
|
||||
|
||||
LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
|
||||
Operation *op, PatternRewriter &rewriter) const {
|
||||
if (failed(filter.checkAndNotify(rewriter, op)))
|
||||
return failure();
|
||||
if (failed(generalizeNamedOpPrecondition(op)))
|
||||
return failure();
|
||||
|
||||
GenericOp genericOp = generalizeNamedOp(rewriter, op);
|
||||
rewriter.replaceOp(op, genericOp.getResults());
|
||||
filter.replaceLinalgTransformationFilter(rewriter, genericOp);
|
||||
return success();
|
||||
}
|
||||
|
||||
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
|
||||
MLIRContext *context, LinalgTransformationFilter filter,
|
||||
LinalgPromotionOptions options, PatternBenefit benefit)
|
||||
|
|
Loading…
Reference in New Issue