forked from OSchip/llvm-project
[mlir][linalg] Move generalization pattern to Transforms (NFC).
Move the generalization pattern to the other Linalg transforms to make it available to the codegen strategy. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D110728
This commit is contained in:
parent
cf818b55e7
commit
e826db6240
|
@ -234,6 +234,10 @@ tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
|
||||||
void interchangeGenericOp(PatternRewriter &rewriter, GenericOp genericOp,
|
void interchangeGenericOp(PatternRewriter &rewriter, GenericOp genericOp,
|
||||||
ArrayRef<unsigned> interchangeVector);
|
ArrayRef<unsigned> interchangeVector);
|
||||||
|
|
||||||
|
/// Creates a GenericOp from the given named operation `namedOp`. Assumes
|
||||||
|
/// `namedOp` is not a GenericOp and has a region builder.
|
||||||
|
GenericOp generalizeNamedOp(PatternRewriter &rewriter, LinalgOp namedOp);
|
||||||
|
|
||||||
/// Callback function type used to perform the allocation for the promoted
|
/// Callback function type used to perform the allocation for the promoted
|
||||||
/// `subView`. In `boundingSubViewsize` a best attempt is made to find the
|
/// `subView`. In `boundingSubViewsize` a best attempt is made to find the
|
||||||
/// smallest constant value for the size of the buffer needed for each
|
/// smallest constant value for the size of the buffer needed for each
|
||||||
|
@ -380,6 +384,9 @@ LogicalResult
|
||||||
interchangeGenericOpPrecondition(GenericOp genericOp,
|
interchangeGenericOpPrecondition(GenericOp genericOp,
|
||||||
ArrayRef<unsigned> interchangeVector);
|
ArrayRef<unsigned> interchangeVector);
|
||||||
|
|
||||||
|
/// Generalize named operations to generic operations.
|
||||||
|
LogicalResult generalizeNamedOpPrecondition(Operation *op);
|
||||||
|
|
||||||
/// Promote std.subviews feeding linalg operations.
|
/// Promote std.subviews feeding linalg operations.
|
||||||
LogicalResult promoteSubviewsPrecondition(Operation *op,
|
LogicalResult promoteSubviewsPrecondition(Operation *op,
|
||||||
LinalgPromotionOptions options);
|
LinalgPromotionOptions options);
|
||||||
|
@ -701,6 +708,31 @@ private:
|
||||||
SmallVector<unsigned, 8> interchangeVector;
|
SmallVector<unsigned, 8> interchangeVector;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
///
|
||||||
|
/// Linalg generalization pattern.
|
||||||
|
///
|
||||||
|
/// Apply the `generalization` transformation as a pattern.
|
||||||
|
/// `filter` controls LinalgTransformMarker matching and update when specified.
|
||||||
|
/// See `generalization` for more details.
|
||||||
|
struct LinalgGeneralizationPattern : public RewritePattern {
|
||||||
|
// Entry point to match any LinalgOp OpInterface.
|
||||||
|
LinalgGeneralizationPattern(
|
||||||
|
MLIRContext *context,
|
||||||
|
LinalgTransformationFilter filter = LinalgTransformationFilter(),
|
||||||
|
PatternBenefit benefit = 1);
|
||||||
|
// Entry point to match a specific Linalg op.
|
||||||
|
LinalgGeneralizationPattern(
|
||||||
|
StringRef opName, MLIRContext *context,
|
||||||
|
LinalgTransformationFilter filter = LinalgTransformationFilter(),
|
||||||
|
PatternBenefit benefit = 1);
|
||||||
|
LogicalResult matchAndRewrite(Operation *op,
|
||||||
|
PatternRewriter &rewriter) const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// LinalgTransformMarker handles special attribute manipulations.
|
||||||
|
LinalgTransformationFilter filter;
|
||||||
|
};
|
||||||
|
|
||||||
///
|
///
|
||||||
/// Linalg promotion patterns.
|
/// Linalg promotion patterns.
|
||||||
///
|
///
|
||||||
|
|
|
@ -29,10 +29,19 @@
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::linalg;
|
using namespace mlir::linalg;
|
||||||
|
|
||||||
// Creates a linalg.generic op from the given `namedOp`. Returns a null op if
|
LogicalResult mlir::linalg::generalizeNamedOpPrecondition(Operation *op) {
|
||||||
// the given `namedOp` does not have a region builder.
|
LinalgOp namedOp = dyn_cast<LinalgOp>(op);
|
||||||
static GenericOp createGenericOpFromNamedOp(LinalgOp namedOp,
|
// Check if the operation is a LinalgOp but not a GenericOp.
|
||||||
PatternRewriter &rewriter) {
|
if (!namedOp || isa<GenericOp>(op))
|
||||||
|
return failure();
|
||||||
|
// Check if the operation has a region builder.
|
||||||
|
if (!namedOp.getRegionBuilder())
|
||||||
|
return failure();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
GenericOp mlir::linalg::generalizeNamedOp(PatternRewriter &rewriter,
|
||||||
|
LinalgOp namedOp) {
|
||||||
SmallVector<Value> inputOperands = namedOp.getInputOperands();
|
SmallVector<Value> inputOperands = namedOp.getInputOperands();
|
||||||
SmallVector<Value> outputOperands = namedOp.getOutputOperands();
|
SmallVector<Value> outputOperands = namedOp.getOutputOperands();
|
||||||
SmallVector<AffineMap> indexingMaps = namedOp.getIndexingMaps();
|
SmallVector<AffineMap> indexingMaps = namedOp.getIndexingMaps();
|
||||||
|
@ -54,10 +63,7 @@ static GenericOp createGenericOpFromNamedOp(LinalgOp namedOp,
|
||||||
// Otherwise use the region builder to generate a new region.
|
// Otherwise use the region builder to generate a new region.
|
||||||
// TODO: Remove this path once all linag operations have a region attached.
|
// TODO: Remove this path once all linag operations have a region attached.
|
||||||
auto regionBuilder = namedOp.getRegionBuilder();
|
auto regionBuilder = namedOp.getRegionBuilder();
|
||||||
if (!regionBuilder) {
|
assert(regionBuilder && "expect the operation to have region builder");
|
||||||
LLVM_DEBUG(llvm::dbgs() << "no region builder for op: " << namedOp << "\n");
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return rewriter.create<GenericOp>(
|
return rewriter.create<GenericOp>(
|
||||||
namedOp.getLoc(), types, inputOperands, outputOperands, indexingMaps,
|
namedOp.getLoc(), types, inputOperands, outputOperands, indexingMaps,
|
||||||
iterators,
|
iterators,
|
||||||
|
@ -112,41 +118,6 @@ struct GeneralizeConvOp
|
||||||
GenericOp createGenericOp(ConvOp convOp, OpBuilder &builder) const;
|
GenericOp createGenericOp(ConvOp convOp, OpBuilder &builder) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Catch-all pattern for converting all named ops with a region builder into
|
|
||||||
/// linalg.generic.
|
|
||||||
struct LinalgNamedOpGeneralizationPattern : RewritePattern {
|
|
||||||
LinalgNamedOpGeneralizationPattern(MLIRContext *context,
|
|
||||||
LinalgTransformationFilter marker,
|
|
||||||
PatternBenefit benefit = 1)
|
|
||||||
: RewritePattern(MatchAnyOpTypeTag(), benefit, context),
|
|
||||||
marker(std::move(marker)) {}
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(Operation *rootOp,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
auto linalgOp = dyn_cast<LinalgOp>(rootOp);
|
|
||||||
if (!linalgOp)
|
|
||||||
return failure();
|
|
||||||
if (failed(marker.checkAndNotify(rewriter, linalgOp)))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// No nothing to do for linalg.generic.
|
|
||||||
if (isa<GenericOp>(rootOp))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
GenericOp genericOp = createGenericOpFromNamedOp(linalgOp, rewriter);
|
|
||||||
if (!genericOp)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
rewriter.replaceOp(rootOp, genericOp.getResults());
|
|
||||||
marker.replaceLinalgTransformationFilter(rewriter,
|
|
||||||
genericOp.getOperation());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
LinalgTransformationFilter marker;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct LinalgGeneralizationPass
|
struct LinalgGeneralizationPass
|
||||||
: public LinalgGeneralizationBase<LinalgGeneralizationPass> {
|
: public LinalgGeneralizationBase<LinalgGeneralizationPass> {
|
||||||
void runOnFunction() override;
|
void runOnFunction() override;
|
||||||
|
@ -187,8 +158,7 @@ void mlir::linalg::populateLinalgConvGeneralizationPatterns(
|
||||||
|
|
||||||
void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
|
void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
|
||||||
RewritePatternSet &patterns, LinalgTransformationFilter marker) {
|
RewritePatternSet &patterns, LinalgTransformationFilter marker) {
|
||||||
patterns.add<LinalgNamedOpGeneralizationPattern>(patterns.getContext(),
|
patterns.add<LinalgGeneralizationPattern>(patterns.getContext(), marker);
|
||||||
marker);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() {
|
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() {
|
||||||
|
|
|
@ -488,6 +488,30 @@ LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Linalg generalization pattern.
|
||||||
|
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
|
||||||
|
MLIRContext *context, LinalgTransformationFilter filter,
|
||||||
|
PatternBenefit benefit)
|
||||||
|
: RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}
|
||||||
|
|
||||||
|
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
|
||||||
|
StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
|
||||||
|
PatternBenefit benefit)
|
||||||
|
: RewritePattern(opName, benefit, context, {}), filter(filter) {}
|
||||||
|
|
||||||
|
LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
|
||||||
|
Operation *op, PatternRewriter &rewriter) const {
|
||||||
|
if (failed(filter.checkAndNotify(rewriter, op)))
|
||||||
|
return failure();
|
||||||
|
if (failed(generalizeNamedOpPrecondition(op)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
GenericOp genericOp = generalizeNamedOp(rewriter, op);
|
||||||
|
rewriter.replaceOp(op, genericOp.getResults());
|
||||||
|
filter.replaceLinalgTransformationFilter(rewriter, genericOp);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
|
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
|
||||||
MLIRContext *context, LinalgTransformationFilter filter,
|
MLIRContext *context, LinalgTransformationFilter filter,
|
||||||
LinalgPromotionOptions options, PatternBenefit benefit)
|
LinalgPromotionOptions options, PatternBenefit benefit)
|
||||||
|
|
Loading…
Reference in New Issue