diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 8099fe9fed51..6c044bc26c93 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -437,6 +437,7 @@ struct LinalgTransformationFilter { LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const; void replaceLinalgTransformationFilter(PatternRewriter &rewriter, Operation *op) const; + bool hasReplacementFilter(Operation *op) const; LinalgTransformationFilter &addFilter(FilterFunction f) { if (f) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 57cc48fd314c..657f2b760558 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -107,6 +107,15 @@ void mlir::linalg::LinalgTransformationFilter:: rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker)); } +bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter( + Operation *op) const { + if (!replacement) + return false; + auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker) + .dyn_cast(); + return attr && attr == replacement.getValue(); +} + LinalgTilingOptions & mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef ts) { assert(!tileSizeComputationFunction && "tile sizes already set");