[mlir][Linalg] Further improve codegen strategy and add a linalg.matmul_i8_i8_i32

This revision adds a layer of SFINAE to the composable codegen strategy so it does
not have to require statically defined ops but instead can also be used with OpInterfaces, Operation* and an op name string.

A linalg.matmul_i8_i8_i32 is added to the .tc spec to demonstrate how all this works end to end.

Differential Revision: https://reviews.llvm.org/D95600
This commit is contained in:
Nicolas Vasilache 2021-01-28 12:55:40 +00:00
parent 42635856ed
commit 299cc5da6d
16 changed files with 666 additions and 294 deletions

View File

@ -8,6 +8,13 @@ def matmul_column_major(A: f32(K, M), B: f32(N, K)) -> (C: f32(N, M)) {
C(n, m) = std_addf<k>(std_mulf(A(k, m), B(n, k)));
}
ods_def<MatmulI8I8I32Op>:
def matmul_i8_i8_i32(A: i8(M, K), B: i8(K, N)) -> (C: i32(M, N)) {
// TODO: ideally something closer to
// C(m, n) += cast<i32>(A(m, k)) * cast<i32>(B(k, n))
C(m, n) = std_addi<k>(std_sexti32(std_muli(A(m, k), B(k, n))));
}
ods_def<MatvecOp>:
def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) {
x(m) = std_addf<n>(std_mulf(A(m, n), y(n)));

View File

@ -21,27 +21,63 @@ namespace linalg {
/// Abstract Transformation class applied in a sequence that also handles state
/// through markers.
struct Transformation {
explicit Transformation(linalg::LinalgTransformationFilter::FilterFunction f)
: filter(f) {}
virtual ~Transformation() = default;
virtual OwningRewritePatternList
buildRewritePatterns(MLIRContext *context, linalg::LinalgMarker m) = 0;
linalg::LinalgMarker marker;
buildRewritePatterns(MLIRContext *context,
linalg::LinalgTransformationFilter m) = 0;
linalg::LinalgTransformationFilter::FilterFunction filter = nullptr;
};
/// SFINAE: Enqueue helper for ConcreteOpType that have a `getOperationName`.
template <template <typename> class PatternType, typename ConcreteOpType,
typename OptionsType,
typename std::enable_if<std::is_member_function_pointer<
decltype(&ConcreteOpType::getOperationName)>::value>>
void sfinae_enqueue(OwningRewritePatternList &patterList, OptionsType options,
MLIRContext *context, StringRef opName,
linalg::LinalgTransformationFilter m) {
assert(opName.empty() ||
opName == ConcreteOpType::getOperationName() &&
"explicit name must match ConcreteOpType::getOperationName");
patterList.insert<PatternType<ConcreteOpType>>(context, options, m);
}
/// SFINAE: Enqueue helper for OpType that do not have a `getOperationName`
/// (e.g. LinalgOp, other interfaces, Operation*).
template <template <typename> class PatternType, typename OpType,
typename OptionsType>
void sfinae_enqueue(OwningRewritePatternList &patterList, OptionsType options,
MLIRContext *context, StringRef opName,
linalg::LinalgTransformationFilter m) {
assert(!opName.empty() && "opName must not be empty");
patterList.insert<PatternType<OpType>>(opName, context, options, m);
}
/// Promotion transformation enqueues a particular stage-1 pattern for
/// `Tile<LinalgOpType>`with the appropriate `options`.
template <typename LinalgOpType>
struct Tile : public Transformation {
explicit Tile(linalg::LinalgTilingOptions options) : options(options) {}
explicit Tile(linalg::LinalgTilingOptions options,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
: Transformation(f), opName(""), options(options) {}
Tile(StringRef name, linalg::LinalgTilingOptions options,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
: Transformation(f), opName(name), options(options) {}
OwningRewritePatternList
buildRewritePatterns(MLIRContext *context, linalg::LinalgMarker m) override {
buildRewritePatterns(MLIRContext *context,
linalg::LinalgTransformationFilter m) override {
OwningRewritePatternList tilingPatterns;
tilingPatterns.insert<linalg::LinalgTilingPattern<LinalgOpType>>(
context, options, m);
sfinae_enqueue<linalg::LinalgTilingPattern, LinalgOpType>(
tilingPatterns, options, context, opName, m);
return tilingPatterns;
}
private:
std::string opName;
linalg::LinalgTilingOptions options;
};
@ -49,17 +85,26 @@ private:
/// `Promote<LinalgOpType>`with the appropriate `options`.
template <typename LinalgOpType>
struct Promote : public Transformation {
explicit Promote(linalg::LinalgPromotionOptions options) : options(options) {}
explicit Promote(
linalg::LinalgPromotionOptions options,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
: Transformation(f), opName(""), options(options) {}
Promote(StringRef name, linalg::LinalgPromotionOptions options,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
: Transformation(f), opName(name), options(options) {}
OwningRewritePatternList
buildRewritePatterns(MLIRContext *context, linalg::LinalgMarker m) override {
buildRewritePatterns(MLIRContext *context,
linalg::LinalgTransformationFilter m) override {
OwningRewritePatternList promotionPatterns;
promotionPatterns.insert<linalg::LinalgPromotionPattern<LinalgOpType>>(
context, options, m);
sfinae_enqueue<linalg::LinalgPromotionPattern, LinalgOpType>(
promotionPatterns, options, context, opName, m);
return promotionPatterns;
}
private:
std::string opName;
linalg::LinalgPromotionOptions options;
};
@ -68,25 +113,36 @@ private:
/// transfer rewrite forwarding patterns.
template <typename LinalgOpType>
struct Vectorize : public Transformation {
explicit Vectorize(
linalg::LinalgVectorizationOptions options,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
: Transformation(f), opName(""), options(options) {}
Vectorize(StringRef name, linalg::LinalgVectorizationOptions options,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
: Transformation(f), opName(name), options(options) {}
OwningRewritePatternList
buildRewritePatterns(MLIRContext *context, linalg::LinalgMarker m) override {
buildRewritePatterns(MLIRContext *context,
linalg::LinalgTransformationFilter m) override {
OwningRewritePatternList vectorizationPatterns;
// FillOp may interfere with forwarding patterns atm, so we bump up the
// priority of LinalgCopyVTRForwardingPattern /
// LinalgCopyVTWForwardingPattern.
vectorizationPatterns
.insert<linalg::LinalgVectorizationPattern<LinalgOpType>>(context, m);
sfinae_enqueue<linalg::LinalgVectorizationPattern, LinalgOpType>(
vectorizationPatterns, options, context, opName, m);
vectorizationPatterns.insert<linalg::LinalgCopyVTRForwardingPattern,
linalg::LinalgCopyVTWForwardingPattern>(
context, /*benefit=*/2);
return vectorizationPatterns;
}
private:
std::string opName;
linalg::LinalgVectorizationOptions options;
};
/// Codegen strategy controls how a Linalg op is progressively lowered.
/// The application uses a 3-level staged patterns strategy which allows
/// ordering transformations by using the Linalg `applyStagedPatterns` function,
/// where:
/// ordering transformations by using the Linalg `applyStagedPatterns`
/// function, where:
/// 1. The first stage consists of the successive `tile`, `promote` and
/// `vectorize` patterns, applied sequentially.
/// 2. The second stage consists of common local canonicalization patterns
@ -97,41 +153,112 @@ struct CodegenStrategy {
/// Append a pattern to add a level of tiling for `LinalgOpType` with tiling
/// `options`.
template <typename LinalgOpType>
CodegenStrategy &tile(linalg::LinalgTilingOptions options) {
transformationSequence.emplace_back(new Tile<LinalgOpType>(options));
CodegenStrategy &
tile(linalg::LinalgTilingOptions options,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
transformationSequence.emplace_back(
std::make_unique<Tile<LinalgOpType>>(options, f));
return *this;
}
/// Conditionally append a pattern to add a level of tiling for `LinalgOpType`
/// with tiling `options`.
/// Append a pattern to add a level of tiling for `LinalgOpType` with tiling
/// `options`.
template <typename LinalgOpType>
CodegenStrategy &tileIf(bool b, linalg::LinalgTilingOptions options) {
CodegenStrategy &
tile(StringRef opName, linalg::LinalgTilingOptions options,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
transformationSequence.emplace_back(
std::make_unique<Tile<LinalgOpType>>(opName, options, f));
return *this;
}
/// Conditionally append a pattern to add a level of tiling for
/// `LinalgOpType` with tiling `options`.
template <typename LinalgOpType>
CodegenStrategy &
tileIf(bool b, linalg::LinalgTilingOptions options,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
return b ? tile<LinalgOpType>(options) : *this;
}
/// Conditionally append a pattern to add a level of tiling for
/// `LinalgOpType` with tiling `options`.
template <typename LinalgOpType>
CodegenStrategy &
tileIf(bool b, StringRef opName, linalg::LinalgTilingOptions options,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
return b ? tile<LinalgOpType>(opName, options) : *this;
}
/// Append a pattern to add a level of promotion for `LinalgOpType` with
/// promotion `options`.
template <typename LinalgOpType>
CodegenStrategy &promote(linalg::LinalgPromotionOptions options) {
transformationSequence.emplace_back(new Promote<LinalgOpType>(options));
CodegenStrategy &
promote(linalg::LinalgPromotionOptions options,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
transformationSequence.emplace_back(
std::make_unique<Promote<LinalgOpType>>(options, f));
return *this;
}
/// Append a pattern to add a level of promotion for `LinalgOpType` with
/// promotion `options`.
template <typename LinalgOpType>
CodegenStrategy &
promote(StringRef opName, linalg::LinalgPromotionOptions options,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
transformationSequence.emplace_back(
std::make_unique<Promote<LinalgOpType>>(opName, options, f));
return *this;
}
/// Conditionally append a pattern to add a level of promotion for
/// `LinalgOpType` with promotion `options`.
template <typename LinalgOpType>
CodegenStrategy &promoteIf(bool b, linalg::LinalgPromotionOptions options) {
return b ? promote<LinalgOpType>(options) : *this;
CodegenStrategy &
promoteIf(bool b, StringRef opName, linalg::LinalgPromotionOptions options,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
return b ? promote<LinalgOpType>(opName, options, f) : *this;
return *this;
}
/// Conditionally append a pattern to add a level of promotion for
/// `LinalgOpType` with promotion `options`.
template <typename LinalgOpType>
CodegenStrategy &
promoteIf(bool b, linalg::LinalgPromotionOptions options,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
return b ? promote<LinalgOpType>(options, f) : *this;
return *this;
}
/// Append a pattern to rewrite `LinalgOpType` as a vector operation.
template <typename LinalgOpType>
CodegenStrategy &vectorize() {
transformationSequence.emplace_back(new Vectorize<LinalgOpType>());
CodegenStrategy &
vectorize(linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
transformationSequence.emplace_back(
std::make_unique<Vectorize<LinalgOpType>>(
linalg::LinalgVectorizationOptions(), f));
return *this;
}
/// Append a pattern to rewrite `LinalgOpType` as a vector operation.
template <typename LinalgOpType>
CodegenStrategy &
vectorize(StringRef opName,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
transformationSequence.emplace_back(
std::make_unique<Vectorize<LinalgOpType>>(
opName, linalg::LinalgVectorizationOptions(), f));
return *this;
}
/// Conditionally append a pattern to rewrite `LinalgOpType` as a vector
/// operation.
template <typename LinalgOpType>
CodegenStrategy &vectorizeIf(bool b) {
return b ? vectorize<LinalgOpType>() : *this;
CodegenStrategy &
vectorizeIf(bool b,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
return b ? vectorize<LinalgOpType>(f) : *this;
return *this;
}
/// Conditionally append a pattern to rewrite `LinalgOpType` as a vector
/// operation.
template <typename LinalgOpType>
CodegenStrategy &
vectorizeIf(bool b, StringRef opName,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
return b ? vectorize<LinalgOpType>(opName, f) : *this;
return *this;
}
/// Configure the post staged-patterns late vector transformations.
@ -140,15 +267,22 @@ struct CodegenStrategy {
vectorTransformsOptions = options;
return *this;
}
/// Configure the post staged-patterns late vector.transfer to scf conversion.
/// Configure the post staged-patterns late vector.transfer to scf
/// conversion.
CodegenStrategy &
setVectorTransferToSCFOptions(VectorTransferToSCFOptions options) {
vectorToSCFOptions = options;
return *this;
}
/// Configure the post staged-patterns late vector.transfer to scf
/// conversion.
CodegenStrategy &setHoistInvariantCode(bool enableLICM) {
this->enableLICM = enableLICM;
return *this;
}
/// Apply the transformation patterns in sequence with cleanup transformations
/// interleaved.
/// Apply the transformation patterns in sequence with cleanup
/// transformations interleaved.
void transform(FuncOp func) const;
private:
@ -157,6 +291,7 @@ private:
vector::VectorTransformsOptions vectorTransformsOptions;
VectorTransferToSCFOptions vectorToSCFOptions;
SmallVector<std::unique_ptr<Transformation>, 4> transformationSequence;
bool enableLICM = true;
};
} // namespace linalg

View File

@ -316,16 +316,32 @@ struct LinalgTransforms {
static const StringLiteral kLinalgTransformMarker;
};
/// Helper class to control common attribute matching and setting behavior.
struct LinalgMarker {
explicit LinalgMarker(ArrayRef<Identifier> matchDisjunction = {},
Optional<Identifier> replacement = None);
LinalgMarker(LinalgMarker &&) = default;
LinalgMarker(const LinalgMarker &) = default;
/// Helper class to control application of linalg transformation patterns.
/// Control comes in 2 forms:
/// 1. attribute matching and setting behavior using the attribute named
/// `kLinalgTransformMarker`. This can be used to build a state machine
/// using attributes and incrementally applying patterns to advance states.
/// 2. filter function, which is a simple lambda on the Operation* that
/// returns a LogicalResult.
struct LinalgTransformationFilter {
using FilterFunction = std::function<LogicalResult(Operation *)>;
explicit LinalgTransformationFilter(
ArrayRef<Identifier> matchDisjunction = {},
Optional<Identifier> replacement = None);
explicit LinalgTransformationFilter(
FilterFunction f, ArrayRef<Identifier> matchDisjunction = {},
Optional<Identifier> replacement = None);
LinalgTransformationFilter(LinalgTransformationFilter &&) = default;
LinalgTransformationFilter(const LinalgTransformationFilter &) = default;
LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
void replaceLinalgMarker(PatternRewriter &rewriter, Operation *op) const;
void replaceLinalgTransformationFilter(PatternRewriter &rewriter,
Operation *op) const;
private:
FilterFunction filter;
SmallVector<Identifier, 4> matchDisjunction;
Optional<Identifier> replacement;
};
@ -425,31 +441,44 @@ void populateLinalgTilingCanonicalizationPatterns(
/// and some operand shape cannot be bounded statically.
struct LinalgBaseTilingPattern : public RewritePattern {
// Entry point to match any LinalgOp OpInterface.
LinalgBaseTilingPattern(LinalgTilingOptions options,
LinalgMarker marker = LinalgMarker(),
PatternBenefit benefit = 1);
LinalgBaseTilingPattern(
LinalgTilingOptions options,
LinalgTransformationFilter marker = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
// Entry point to match a specific Linalg op.
LinalgBaseTilingPattern(StringRef opName, MLIRContext *context,
LinalgTilingOptions options,
LinalgMarker marker = LinalgMarker(),
PatternBenefit benefit = 1);
LinalgBaseTilingPattern(
StringRef opName, MLIRContext *context, LinalgTilingOptions options,
LinalgTransformationFilter marker = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
LogicalResult matchAndRewriteBase(Operation *op, PatternRewriter &rewriter,
TiledLinalgOp &result) const;
private:
/// LinalgTransformMarker handles special attribute manipulations.
LinalgMarker marker;
LinalgTransformationFilter marker;
/// Options to control tiling;
LinalgTilingOptions options;
};
template <typename OpTy>
struct LinalgTilingPattern : public LinalgBaseTilingPattern {
LinalgTilingPattern(MLIRContext *context, LinalgTilingOptions options,
LinalgMarker marker = LinalgMarker(),
PatternBenefit benefit = 1)
: LinalgBaseTilingPattern(OpTy::getOperationName(), context, options,
marker, benefit) {}
/// SFINAE: This constructor can only trigger for concrete ops that have a
/// static `getOperationName` method.
template <typename ConcreateOpTy = OpTy>
LinalgTilingPattern(
MLIRContext *context, LinalgTilingOptions options,
LinalgTransformationFilter marker = LinalgTransformationFilter(),
PatternBenefit benefit = 1)
: LinalgBaseTilingPattern(ConcreateOpTy::getOperationName(), context,
options, marker, benefit) {}
/// This constructor is available to anyone.
LinalgTilingPattern(
StringRef opName, MLIRContext *context, LinalgTilingOptions options,
LinalgTransformationFilter marker = LinalgTransformationFilter(),
PatternBenefit benefit = 1)
: LinalgBaseTilingPattern(opName, context, options, marker, benefit) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
TiledLinalgOp tiledLinalgOp;
@ -474,14 +503,15 @@ struct LinalgFusionOptions {
};
struct LinalgBaseTileAndFusePattern : public RewritePattern {
LinalgBaseTileAndFusePattern(StringRef opName, MLIRContext *context,
const LinalgDependenceGraph &dependenceGraph,
LinalgTilingOptions tilingOptions,
LinalgFusionOptions fusionOptions,
LinalgMarker marker = LinalgMarker(),
LinalgMarker fusedOpMarker = LinalgMarker(),
LinalgMarker originalOpMarker = LinalgMarker(),
PatternBenefit benefit = 1);
LinalgBaseTileAndFusePattern(
StringRef opName, MLIRContext *context,
const LinalgDependenceGraph &dependenceGraph,
LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
LinalgTransformationFilter marker = LinalgTransformationFilter(),
LinalgTransformationFilter fusedOpMarker = LinalgTransformationFilter(),
LinalgTransformationFilter originalOpMarker =
LinalgTransformationFilter(),
PatternBenefit benefit = 1);
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
@ -493,27 +523,27 @@ private:
/// Options to control fusion.
LinalgFusionOptions fusionOptions;
/// Marker to control application of the pattern.
LinalgMarker marker;
LinalgTransformationFilter marker;
/// Marker set on the fused op after tile and fuse.
LinalgMarker fusedOpMarker;
LinalgTransformationFilter fusedOpMarker;
/// The dependenceGraph is not modifiable, i.e. if the Linalg operations used
/// to build the dependence graph changes then the dependenceGraph needs to be
/// recomputed right now. To not invalidate the dependenceGraph as
/// transformation happens, the original producer can be tagged with a marker
/// that can be later used to delete the original operations.
LinalgMarker originalOpMarker;
LinalgTransformationFilter originalOpMarker;
};
template <typename OpTy>
struct LinalgTileAndFusePattern : public LinalgBaseTileAndFusePattern {
LinalgTileAndFusePattern(MLIRContext *context,
const LinalgDependenceGraph &dependenceGraph,
LinalgTilingOptions tilingOptions,
LinalgFusionOptions fusionOptions,
LinalgMarker marker = LinalgMarker(),
LinalgMarker fusedOpMarker = LinalgMarker(),
LinalgMarker originalOpMarker = LinalgMarker(),
PatternBenefit benefit = 1)
LinalgTileAndFusePattern(
MLIRContext *context, const LinalgDependenceGraph &dependenceGraph,
LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
LinalgTransformationFilter marker = LinalgTransformationFilter(),
LinalgTransformationFilter fusedOpMarker = LinalgTransformationFilter(),
LinalgTransformationFilter originalOpMarker =
LinalgTransformationFilter(),
PatternBenefit benefit = 1)
: LinalgBaseTileAndFusePattern(
OpTy::getOperationName(), context, dependenceGraph, tilingOptions,
fusionOptions, marker, fusedOpMarker, originalOpMarker, benefit) {}
@ -526,26 +556,27 @@ struct LinalgTileAndFusePattern : public LinalgBaseTileAndFusePattern {
/// `marker` controls LinalgTransformMarker matching and update when specified.
/// See `interchange` for more details.
struct LinalgBaseInterchangePattern : public RewritePattern {
LinalgBaseInterchangePattern(StringRef opName, MLIRContext *context,
ArrayRef<unsigned> interchangeVector,
LinalgMarker marker = LinalgMarker(),
PatternBenefit benefit = 1);
LinalgBaseInterchangePattern(
StringRef opName, MLIRContext *context,
ArrayRef<unsigned> interchangeVector,
LinalgTransformationFilter marker = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
private:
/// LinalgTransformMarker handles special attribute manipulations.
LinalgMarker marker;
LinalgTransformationFilter marker;
/// The interchange vector to reorder the iterators and indexing_maps dims.
SmallVector<unsigned, 8> interchangeVector;
};
template <typename OpTy>
struct LinalgInterchangePattern : public LinalgBaseInterchangePattern {
LinalgInterchangePattern(MLIRContext *context,
ArrayRef<unsigned> interchangeVector,
LinalgMarker marker = LinalgMarker(),
PatternBenefit benefit = 1)
LinalgInterchangePattern(
MLIRContext *context, ArrayRef<unsigned> interchangeVector,
LinalgTransformationFilter marker = LinalgTransformationFilter(),
PatternBenefit benefit = 1)
: LinalgBaseInterchangePattern(OpTy::getOperationName(), context,
interchangeVector, marker, benefit) {}
};
@ -557,27 +588,38 @@ struct LinalgInterchangePattern : public LinalgBaseInterchangePattern {
/// `marker` controls LinalgTransformMarker matching and update when specified.
/// See `promoteSubViews` for more details.
struct LinalgBasePromotionPattern : public RewritePattern {
LinalgBasePromotionPattern(StringRef opName, MLIRContext *context,
LinalgPromotionOptions options,
LinalgMarker marker = LinalgMarker(),
PatternBenefit benefit = 1);
LinalgBasePromotionPattern(
StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
LinalgTransformationFilter marker = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
private:
/// LinalgTransformMarker handles special attribute manipulations.
LinalgMarker marker;
LinalgTransformationFilter marker;
/// Promotion options.
LinalgPromotionOptions options;
};
template <typename OpTy>
struct LinalgPromotionPattern : public LinalgBasePromotionPattern {
LinalgPromotionPattern(MLIRContext *context, LinalgPromotionOptions options,
LinalgMarker marker = LinalgMarker(),
PatternBenefit benefit = 1)
/// SFINAE: This constructor can only trigger for concrete ops that have a
/// static `getOperationName` method.
template <typename ConcreateOpTy = OpTy>
LinalgPromotionPattern(
MLIRContext *context, LinalgPromotionOptions options,
LinalgTransformationFilter marker = LinalgTransformationFilter(),
PatternBenefit benefit = 1)
: LinalgBasePromotionPattern(OpTy::getOperationName(), context, options,
marker, benefit) {}
/// This constructor is available to anyone.
LinalgPromotionPattern(
StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
LinalgTransformationFilter marker = LinalgTransformationFilter(),
PatternBenefit benefit = 1)
: LinalgBasePromotionPattern(opName, context, options, marker, benefit) {}
};
///
@ -586,25 +628,42 @@ struct LinalgPromotionPattern : public LinalgBasePromotionPattern {
/// Apply the `vectorizeLinalgOp` transformation as a pattern.
/// `marker` controls LinalgTransformMarker matching and update when specified.
/// See `vectorizeLinalgOp` for more details.
/// Empty for now, used for SFINAE purposes only.
struct LinalgVectorizationOptions {};
struct LinalgBaseVectorizationPattern : public RewritePattern {
LinalgBaseVectorizationPattern(StringRef opName, MLIRContext *context,
LinalgMarker marker = LinalgMarker(),
PatternBenefit benefit = 1);
LinalgBaseVectorizationPattern(
StringRef opName, MLIRContext *context,
LinalgTransformationFilter marker = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
private:
/// LinalgTransformMarker handles special attribute manipulations.
LinalgMarker marker;
LinalgTransformationFilter marker;
};
template <typename OpTy>
struct LinalgVectorizationPattern : public LinalgBaseVectorizationPattern {
LinalgVectorizationPattern(MLIRContext *context,
LinalgMarker marker = LinalgMarker(),
PatternBenefit benefit = 1)
/// SFINAE: This constructor can only trigger for concrete ops that have a
/// static `getOperationName` method.
template <typename ConcreateOpTy = OpTy>
LinalgVectorizationPattern(
MLIRContext *context,
LinalgVectorizationOptions options = LinalgVectorizationOptions(),
LinalgTransformationFilter marker = LinalgTransformationFilter(),
PatternBenefit benefit = 1)
: LinalgBaseVectorizationPattern(OpTy::getOperationName(), context,
marker, benefit) {}
/// This constructor is available to anyone.
LinalgVectorizationPattern(
StringRef opName, MLIRContext *context,
LinalgVectorizationOptions options = LinalgVectorizationOptions(),
LinalgTransformationFilter marker = LinalgTransformationFilter(),
PatternBenefit benefit = 1)
: LinalgBaseVectorizationPattern(opName, context, marker, benefit) {}
};
///
@ -622,10 +681,10 @@ enum class LinalgLoweringType {
template <typename OpTy>
struct LinalgLoweringPattern : public RewritePattern {
LinalgLoweringPattern(MLIRContext *context, LinalgLoweringType loweringType,
LinalgMarker marker = LinalgMarker(),
ArrayRef<unsigned> interchangeVector = {},
PatternBenefit benefit = 1)
LinalgLoweringPattern(
MLIRContext *context, LinalgLoweringType loweringType,
LinalgTransformationFilter marker = LinalgTransformationFilter(),
ArrayRef<unsigned> interchangeVector = {}, PatternBenefit benefit = 1)
: RewritePattern(OpTy::getOperationName(), {}, benefit, context),
marker(marker), loweringType(loweringType),
interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
@ -663,7 +722,7 @@ struct LinalgLoweringPattern : public RewritePattern {
private:
/// LinalgTransformMarker handles special attribute manipulations.
LinalgMarker marker;
LinalgTransformationFilter marker;
/// Controls whether the pattern lowers to library calls, scf.for, affine.for
/// or scf.parallel.
LinalgLoweringType loweringType;
@ -677,13 +736,13 @@ private:
/// linalg.generic ops.
void populateLinalgNamedOpsGeneralizationPatterns(
MLIRContext *context, OwningRewritePatternList &patterns,
LinalgMarker marker = LinalgMarker());
LinalgTransformationFilter marker = LinalgTransformationFilter());
/// Populates `patterns` with patterns to convert linalg.conv ops to
/// linalg.generic ops.
void populateLinalgConvGeneralizationPatterns(
MLIRContext *context, OwningRewritePatternList &patterns,
LinalgMarker marker = LinalgMarker());
LinalgTransformationFilter marker = LinalgTransformationFilter());
//===----------------------------------------------------------------------===//
// Op-specific patterns.

View File

@ -38,6 +38,7 @@ using std_ret = OperationBuilder<ReturnOp>;
using std_rsqrt = ValueBuilder<RsqrtOp>;
using std_select = ValueBuilder<SelectOp>;
using std_load = ValueBuilder<LoadOp>;
using std_sign_extendi = ValueBuilder<SignExtendIOp>;
using std_splat = ValueBuilder<SplatOp>;
using std_store = OperationBuilder<StoreOp>;
using std_subf = ValueBuilder<SubFOp>;
@ -48,9 +49,19 @@ using std_tensor_load = ValueBuilder<TensorLoadOp>;
using std_tensor_store = OperationBuilder<TensorStoreOp>;
using std_view = ValueBuilder<ViewOp>;
using std_zero_extendi = ValueBuilder<ZeroExtendIOp>;
using std_sign_extendi = ValueBuilder<SignExtendIOp>;
using tensor_extract = ValueBuilder<tensor::ExtractOp>;
template <int N>
struct SExtiValueBuilder : public ValueBuilder<SignExtendIOp> {
using ValueBuilder<SignExtendIOp>::ValueBuilder;
template <typename... Args>
SExtiValueBuilder(Args... args)
: ValueBuilder<SignExtendIOp>(ScopedContext::getBuilderRef().getI32Type(),
args...) {}
};
using std_sexti32 = SExtiValueBuilder<32>;
/// Branches into `block` with `operands`.
BranchOp std_br(Block *block, ValueRange operands);

View File

@ -14,9 +14,12 @@
// RUN: tee -a /dev/stderr | FileCheck %s
!row_major_A = type memref<${M}x${K}xf32>
!row_major_B = type memref<${K}x${N}xf32>
!row_major_C = type memref<${M}x${N}xf32>
!elem_type_a = type f32
!elem_type_b = type f32
!elem_type_c = type f32
!row_major_A = type memref<${M}x${K}x!elem_type_a>
!row_major_B = type memref<${K}x${N}x!elem_type_b>
!row_major_C = type memref<${M}x${N}x!elem_type_c>
func @matmul(%a: !row_major_A, %b: !row_major_B, %c: !row_major_C)
// TODO: activate manually for now.
@ -48,16 +51,16 @@ func @print_perf(%iters: index, %total_time: f64) {
}
func @main() {
%f0 = constant 0.0 : f32
%f1 = constant 1.0 : f32
%v0 = constant 0.0 : !elem_type_a
%v1 = constant 1.0 : !elem_type_a
%A = alloc() : !row_major_A
%B = alloc() : !row_major_B
%C = alloc() : !row_major_C
linalg.fill(%A, %f1) : !row_major_A, f32
linalg.fill(%B, %f1) : !row_major_B, f32
linalg.fill(%C, %f0) : !row_major_C, f32
linalg.fill(%A, %v1) : !row_major_A, !elem_type_a
linalg.fill(%B, %v1) : !row_major_B, !elem_type_b
linalg.fill(%C, %v0) : !row_major_C, !elem_type_c
%c0 = constant 0: index
%c1 = constant 1: index
@ -66,7 +69,8 @@ func @main() {
/// Run and dump performance for matmul.
/// Preheating run:
scf.for %arg0 = %c0 to %iters step %c1 {
linalg.fill(%C, %f0) : !row_major_C, f32
%z = constant 0.0 : !elem_type_c
linalg.fill(%C, %z) : !row_major_C, !elem_type_c
call @matmul(%A, %B, %C) : (!row_major_A, !row_major_B, !row_major_C) -> ()
}
%t_start_matmul = call @rtclock() : () -> f64
@ -75,7 +79,8 @@ func @main() {
// This is accounts for about 10-15% perf hit on small sizes.
// Once linalg on tensors is ready, fusing fill at teh register level will
// be easy.
linalg.fill(%C, %f0) : !row_major_C, f32
%z = constant 0.0 : !elem_type_c
linalg.fill(%C, %z) : !row_major_C, !elem_type_c
call @matmul(%A, %B, %C) : (!row_major_A, !row_major_B, !row_major_C) -> ()
}
%t_end_matmul = call @rtclock() : () -> f64

View File

@ -15,12 +15,15 @@
// Use tee to both print to stderr and FileCheck
// RUN: tee -a /dev/stderr | FileCheck %s
!row_major_A = type memref<${M}x${K}xf32>
!row_major_B = type memref<${K}x${N}xf32>
!row_major_C = type memref<${M}x${N}xf32>
!column_major_A = type memref<${K}x${M}xf32>
!column_major_B = type memref<${N}x${K}xf32>
!column_major_C = type memref<${N}x${M}xf32>
!elem_type_a = type f32
!elem_type_b = type f32
!elem_type_c = type f32
!row_major_A = type memref<${M}x${K}x!elem_type_a>
!row_major_B = type memref<${K}x${N}x!elem_type_b>
!row_major_C = type memref<${M}x${N}x!elem_type_c>
!column_major_A = type memref<${K}x${M}x!elem_type_a>
!column_major_B = type memref<${N}x${K}x!elem_type_b>
!column_major_C = type memref<${N}x${M}x!elem_type_c>
func @matmul_column_major(%a: !column_major_A, %b: !column_major_B, %c: !column_major_C)
// TODO: activate manually for now.
@ -52,16 +55,16 @@ func @print_perf(%iters: index, %total_time: f64) {
}
func @main() {
%f0 = constant 0.0 : f32
%f1 = constant 1.0 : f32
%f0 = constant 0.0 : !elem_type_c
%f1 = constant 1.0 : !elem_type_a
%cA = alloc() : !column_major_A
%cB = alloc() : !column_major_B
%cC = alloc() : !column_major_C
linalg.fill(%cA, %f1) : !column_major_A, f32
linalg.fill(%cB, %f1) : !column_major_B, f32
linalg.fill(%cC, %f0) : !column_major_C, f32
linalg.fill(%cA, %f1) : !column_major_A, !elem_type_a
linalg.fill(%cB, %f1) : !column_major_B, !elem_type_b
linalg.fill(%cC, %f0) : !column_major_C, !elem_type_c
%c0 = constant 0: index
%c1 = constant 1: index
@ -74,7 +77,7 @@ func @main() {
// This is accounts for about 10-15% perf hit on small sizes.
// Once linalg on tensors is ready, fusing fill at teh register level will
// be easy.
linalg.fill(%cC, %f0) : !column_major_C, f32
linalg.fill(%cC, %f0) : !column_major_C, !elem_type_c
call @matmul_column_major(%cA, %cB, %cC) : (!column_major_A, !column_major_B, !column_major_C) -> ()
}
%t_end_matmul_column_major = call @rtclock() : () -> f64
@ -83,7 +86,7 @@ func @main() {
%res = load %cC[%c0, %c0]: !column_major_C
// CHECK: 64
vector.print %res: f32
vector.print %res: !elem_type_c
dealloc %cA : !column_major_A
dealloc %cB : !column_major_B

View File

@ -16,12 +16,15 @@
// Use tee to both print to stderr and FileCheck
// RUN: tee -a /dev/stderr | FileCheck %s
!row_major_A = type memref<${M}x${K}xf32>
!row_major_B = type memref<${K}x${N}xf32>
!row_major_C = type memref<${M}x${N}xf32>
!column_major_A = type memref<${K}x${M}xf32>
!column_major_B = type memref<${N}x${K}xf32>
!column_major_C = type memref<${N}x${M}xf32>
!elem_type_a = type f32
!elem_type_b = type f32
!elem_type_c = type f32
!row_major_A = type memref<${M}x${K}x!elem_type_a>
!row_major_B = type memref<${K}x${N}x!elem_type_b>
!row_major_C = type memref<${M}x${N}x!elem_type_c>
!column_major_A = type memref<${K}x${M}x!elem_type_a>
!column_major_B = type memref<${N}x${K}x!elem_type_b>
!column_major_C = type memref<${N}x${M}x!elem_type_c>
func @matmul_column_major_as_row_major(
%ca: !column_major_A, %cb: !column_major_B, %cc: !column_major_C,
@ -58,16 +61,16 @@ func @print_perf(%iters: index, %total_time: f64) {
}
func @main() {
%f0 = constant 0.0 : f32
%f1 = constant 1.0 : f32
%f0 = constant 0.0 : !elem_type_c
%f1 = constant 1.0 : !elem_type_a
%cA = alloc() : !column_major_A
%cB = alloc() : !column_major_B
%cC = alloc() : !column_major_C
linalg.fill(%cA, %f1) : !column_major_A, f32
linalg.fill(%cB, %f1) : !column_major_B, f32
linalg.fill(%cC, %f0) : !column_major_C, f32
linalg.fill(%cA, %f1) : !column_major_A, !elem_type_a
linalg.fill(%cB, %f1) : !column_major_B, !elem_type_b
linalg.fill(%cC, %f0) : !column_major_C, !elem_type_c
%c0 = constant 0: index
%c1 = constant 1: index
@ -83,7 +86,7 @@ func @main() {
// This is accounts for about 10-15% perf hit on small sizes.
// Once linalg on tensors is ready, fusing fill at teh register level will
// be easy.
linalg.fill(%C, %f0) : !row_major_C, f32
linalg.fill(%C, %f0) : !row_major_C, !elem_type_c
call @matmul_column_major_as_row_major(%cA, %cB, %cC, %A, %B, %C) :
(!column_major_A, !column_major_B, !column_major_C,
!row_major_A, !row_major_B, !row_major_C) -> ()
@ -94,10 +97,10 @@ func @main() {
%res = load %cC[%c0, %c0]: !column_major_C
// CHECK: 64
vector.print %res: f32
vector.print %res: !elem_type_c
%res2 = load %C[%c0, %c0]: !row_major_C
// CHECK: 64
vector.print %res2: f32
vector.print %res2: !elem_type_c
dealloc %A : !row_major_A
dealloc %B : !row_major_B

View File

@ -0,0 +1,103 @@
// RUN: export M=24 && export K=64 && export N=192 && export ITERS=10 && \
// RUN: cat %s | sed 's@${M}@'"$M"'@g'| sed 's@${K}@'"$K"'@g' | sed 's@${N}@'"$N"'@g'| sed 's@${ITERS}@'"$ITERS"'@g'| \
// TODO: extend vectorization with interfaces so that it works with sexti
// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.matmul_i8_i8_i32 register-tile-sizes=12,32,16" | \
// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.fill register-tile-sizes=4,32 vectorize" | \
// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.copy register-tile-sizes=4,32 vectorize" | \
// RUN: mlir-opt -canonicalize -convert-vector-to-scf -lower-affine -convert-linalg-to-loops | \
// RUN: mlir-opt -canonicalize -convert-scf-to-std -convert-vector-to-llvm | \
// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \
// Activate to dump assembly
// R_UN: -dump-object-file -object-filename=/tmp/a.o \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext | \
// Use tee to both print to stderr and FileCheck
// RUN: tee -a /dev/stderr | FileCheck %s
!elem_type_a = type i8
!elem_type_b = type i8
!elem_type_c = type i32
!row_major_A = type memref<${M}x${K}x!elem_type_a>
!row_major_B = type memref<${K}x${N}x!elem_type_b>
!row_major_C = type memref<${M}x${N}x!elem_type_c>
func @matmul(%a: !row_major_A, %b: !row_major_B, %c: !row_major_C)
// TODO: activate manually for now.
// attributes { passthrough = [["target-cpu", "skylake-avx512"], ["prefer-vector-width", "512"]]}
{
linalg.matmul_i8_i8_i32 ins(%a, %b : !row_major_A, !row_major_B)
outs(%c: !row_major_C)
return
}
func @print_perf(%iters: index, %total_time: f64) {
%c2 = constant 2 : index
%cM = constant ${M} : index
%cN = constant ${N} : index
%cK = constant ${K} : index
%mn = muli %cM, %cN : index
%mnk = muli %mn, %cK : index
// 2*M*N*K.
%flops_per_iter = muli %c2, %mnk : index
%flops = muli %iters, %flops_per_iter : index
%flops_i64 = index_cast %flops : index to i64
%flops_f = sitofp %flops_i64 : i64 to f64
%flops_per_s = divf %flops_f, %total_time : f64
vector.print %flops_per_s : f64
return
}
func @main() {
%v0 = constant 0 : !elem_type_c
%v1 = constant 1 : !elem_type_a
%A = alloc() : !row_major_A
%B = alloc() : !row_major_B
%C = alloc() : !row_major_C
linalg.fill(%A, %v1) : !row_major_A, !elem_type_a
linalg.fill(%B, %v1) : !row_major_B, !elem_type_b
linalg.fill(%C, %v0) : !row_major_C, !elem_type_c
%c0 = constant 0: index
%c1 = constant 1: index
%iters = constant ${ITERS}: index
/// Run and dump performance for matmul.
/// Preheating run:
scf.for %arg0 = %c0 to %iters step %c1 {
linalg.fill(%C, %v0) : !row_major_C, !elem_type_c
call @matmul(%A, %B, %C) : (!row_major_A, !row_major_B, !row_major_C) -> ()
}
%t_start_matmul = call @rtclock() : () -> f64
scf.for %arg0 = %c0 to %iters step %c1 {
// linalg.matmul writes %C in place, need to reset it to zero every time.
// This is accounts for about 10-15% perf hit on small sizes.
// Once linalg on tensors is ready, fusing fill at teh register level will
// be easy.
linalg.fill(%C, %v0) : !row_major_C, !elem_type_c
call @matmul(%A, %B, %C) : (!row_major_A, !row_major_B, !row_major_C) -> ()
}
%t_end_matmul = call @rtclock() : () -> f64
%tmatmul = subf %t_end_matmul, %t_start_matmul: f64
call @print_perf(%iters, %tmatmul) : (index, f64) -> ()
%res = load %C[%c0, %c0]: !row_major_C
// CHECK: 64
vector.print %res: !elem_type_c
dealloc %A : !row_major_A
dealloc %B : !row_major_B
dealloc %C : !row_major_C
return
}
func private @rtclock() -> f64
// TODO: init with random, run and check output.
// func private @fill_random_f32(memref<*xf32>)

View File

@ -37,8 +37,10 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
for (const std::unique_ptr<Transformation> &t : transformationSequence) {
auto nextState = Identifier::get(std::to_string(++stepCount), context);
auto marker = (currentState == zeroState)
? linalg::LinalgMarker({}, nextState)
: linalg::LinalgMarker(currentState, nextState);
? linalg::LinalgTransformationFilter(
t->filter, ArrayRef<Identifier>{}, nextState)
: linalg::LinalgTransformationFilter(
t->filter, currentState, nextState);
stage1Patterns.emplace_back(t->buildRewritePatterns(context, marker));
currentState = nextState;
}
@ -47,15 +49,17 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
linalg::getLinalgTilingCanonicalizationPatterns(context);
stage2Patterns.insert<AffineMinSCFCanonicalizationPattern>(context);
auto stage3Transforms = [](Operation *op) {
auto stage3Transforms = [&](Operation *op) {
// Some of these may be too aggressive as a stage 3 that is applied on each
// stage 1 application and may have to be split out to post staged patterns
// application (in which case they could just be passes, TBD).
op->walk([&](LoopLikeOpInterface loopLike) {
LLVM_DEBUG(loopLike.print(llvm::dbgs() << "\nOriginal loop:\n"));
if (failed(moveLoopInvariantCode(loopLike)))
llvm_unreachable("unexpected LICM failure");
});
if (enableLICM) {
op->walk([&](LoopLikeOpInterface loopLike) {
LLVM_DEBUG(loopLike.print(llvm::dbgs() << "\nOriginal loop:\n"));
if (failed(moveLoopInvariantCode(loopLike)))
llvm_unreachable("unexpected LICM failure");
});
}
promoteSingleIterationLoops(cast<FuncOp>(op));
hoistViewAllocOps(cast<FuncOp>(op));
hoistRedundantVectorTransfers(cast<FuncOp>(op));

View File

@ -63,7 +63,8 @@ namespace {
// into auto-generated ones.
template <typename ConcretePattern, typename RootOp>
struct LinalgGeneralizationPattern : OpRewritePattern<RootOp> {
LinalgGeneralizationPattern(MLIRContext *context, linalg::LinalgMarker marker,
LinalgGeneralizationPattern(MLIRContext *context,
linalg::LinalgTransformationFilter marker,
PatternBenefit benefit = 1)
: OpRewritePattern<RootOp>(context, benefit), marker(std::move(marker)) {}
@ -81,12 +82,13 @@ struct LinalgGeneralizationPattern : OpRewritePattern<RootOp> {
return failure();
rewriter.replaceOp(rootOp, genericOp.getResults());
marker.replaceLinalgMarker(rewriter, genericOp.getOperation());
marker.replaceLinalgTransformationFilter(rewriter,
genericOp.getOperation());
return success();
}
private:
linalg::LinalgMarker marker;
linalg::LinalgTransformationFilter marker;
};
struct GeneralizeConvOp
@ -100,7 +102,7 @@ struct GeneralizeConvOp
/// linalg.generic.
struct LinalgNamedOpGeneralizationPattern : RewritePattern {
LinalgNamedOpGeneralizationPattern(MLIRContext *context,
linalg::LinalgMarker marker,
linalg::LinalgTransformationFilter marker,
PatternBenefit benefit = 1)
: RewritePattern(benefit, MatchAnyOpTypeTag()),
marker(std::move(marker)) {}
@ -123,12 +125,13 @@ struct LinalgNamedOpGeneralizationPattern : RewritePattern {
return failure();
rewriter.replaceOp(rootOp, genericOp.getResults());
marker.replaceLinalgMarker(rewriter, genericOp.getOperation());
marker.replaceLinalgTransformationFilter(rewriter,
genericOp.getOperation());
return success();
}
private:
linalg::LinalgMarker marker;
linalg::LinalgTransformationFilter marker;
};
struct LinalgGeneralizationPass
@ -165,13 +168,13 @@ linalg::GenericOp GeneralizeConvOp::createGenericOp(linalg::ConvOp convOp,
void mlir::linalg::populateLinalgConvGeneralizationPatterns(
MLIRContext *context, OwningRewritePatternList &patterns,
linalg::LinalgMarker marker) {
linalg::LinalgTransformationFilter marker) {
patterns.insert<GeneralizeConvOp>(context, marker);
}
void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
MLIRContext *context, OwningRewritePatternList &patterns,
linalg::LinalgMarker marker) {
linalg::LinalgTransformationFilter marker) {
patterns.insert<LinalgNamedOpGeneralizationPattern>(context, marker);
}

View File

@ -536,7 +536,9 @@ public:
static void insert(OwningRewritePatternList &patterns,
const LinalgTilingOptions &options, MLIRContext *ctx) {
patterns.insert<LinalgTilingPattern<OpTy>>(
ctx, options, LinalgMarker({}, Identifier::get("tiled", ctx)));
ctx, options,
LinalgTransformationFilter(ArrayRef<Identifier>{},
Identifier::get("tiled", ctx)));
RewritePatternList<OpTypes...>::insert(patterns, options, ctx);
}
};

View File

@ -46,14 +46,23 @@ using namespace mlir::linalg;
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
"__internal_linalg_transform__";
mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<Identifier> matchDisjunction,
Optional<Identifier> replacement)
: matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement)
: LinalgTransformationFilter([](Operation *) { return success(); },
matchDisjunction, replacement) {}
mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
FilterFunction f, ArrayRef<Identifier> matchDisjunction,
Optional<Identifier> replacement)
: filter(f),
matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
replacement(replacement) {}
LogicalResult
mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
Operation *op) const {
LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
PatternRewriter &rewriter, Operation *op) const {
if (filter && failed(filter(op)))
return failure();
auto attr = op->template getAttrOfType<StringAttr>(
LinalgTransforms::kLinalgTransformMarker);
@ -81,8 +90,9 @@ mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
});
}
void mlir::linalg::LinalgMarker::replaceLinalgMarker(PatternRewriter &rewriter,
Operation *op) const {
void mlir::linalg::LinalgTransformationFilter::
replaceLinalgTransformationFilter(PatternRewriter &rewriter,
Operation *op) const {
if (replacement.hasValue())
op->setAttr(LinalgTransforms::kLinalgTransformMarker,
rewriter.getStringAttr(replacement.getValue()));
@ -219,12 +229,13 @@ static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
/// Linalg base tiling pattern.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
StringRef opName, MLIRContext *context, LinalgTilingOptions options,
LinalgMarker marker, PatternBenefit benefit)
LinalgTransformationFilter marker, PatternBenefit benefit)
: RewritePattern(opName, {}, benefit, context), marker(marker),
options(options) {}
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
LinalgTilingOptions options, LinalgMarker marker, PatternBenefit benefit)
LinalgTilingOptions options, LinalgTransformationFilter marker,
PatternBenefit benefit)
: RewritePattern(benefit, MatchAnyOpTypeTag()), marker(marker),
options(options) {}
@ -250,9 +261,9 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
// Return relevant information to derived pattern.
result = *res;
// Replace marker on both tiledOp and tiledAndPaddedOp, if necessary.
marker.replaceLinalgMarker(rewriter, tiledOp);
marker.replaceLinalgTransformationFilter(rewriter, tiledOp);
if (tiledOp != res->op)
marker.replaceLinalgMarker(rewriter, res->op);
marker.replaceLinalgTransformationFilter(rewriter, res->op);
});
// Consider padding on the fly only if the op has tensor semantics.
@ -276,8 +287,8 @@ mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
StringRef opName, MLIRContext *context,
const LinalgDependenceGraph &dependenceGraph,
LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
LinalgMarker marker, LinalgMarker fusedOpMarker,
LinalgMarker originalOpMarker, PatternBenefit benefit)
LinalgTransformationFilter marker, LinalgTransformationFilter fusedOpMarker,
LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
: RewritePattern(opName, {}, benefit, context),
dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
fusionOptions(fusionOptions), marker(marker),
@ -352,23 +363,26 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
tiledAndFusedOps->op = unfusedTiledOp->op;
}
marker.replaceLinalgMarker(rewriter, tiledAndFusedOps->op.getOperation());
marker.replaceLinalgTransformationFilter(rewriter,
tiledAndFusedOps->op.getOperation());
for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
fusedOpMarker.replaceLinalgMarker(rewriter, fusedOp.getOperation());
fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
fusedOp.getOperation());
}
for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
originalOpMarker.replaceLinalgMarker(rewriter,
origProducerOp.getOperation());
originalOpMarker.replaceLinalgTransformationFilter(
rewriter, origProducerOp.getOperation());
}
rewriter.updateRootInPlace(
op, [&]() { originalOpMarker.replaceLinalgMarker(rewriter, op); });
rewriter.updateRootInPlace(op, [&]() {
originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
});
return success();
}
/// Linalg base interchange pattern.
mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
StringRef opName, MLIRContext *context,
ArrayRef<unsigned> interchangeVector, LinalgMarker marker,
ArrayRef<unsigned> interchangeVector, LinalgTransformationFilter marker,
PatternBenefit benefit)
: RewritePattern(opName, {}, benefit, context), marker(marker),
interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
@ -388,14 +402,14 @@ LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
rewriter.updateRootInPlace(op, [&]() {
interchange(linalgOp, interchangeVector);
// New marker if specified.
marker.replaceLinalgMarker(rewriter, op);
marker.replaceLinalgTransformationFilter(rewriter, op);
});
return success();
}
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
LinalgMarker marker, PatternBenefit benefit)
LinalgTransformationFilter marker, PatternBenefit benefit)
: RewritePattern(opName, {}, benefit, context), marker(marker),
options(options) {}
@ -417,12 +431,12 @@ LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
return op->emitError("subview promotion failed");
}
rewriter.finalizeRootUpdate(op);
marker.replaceLinalgMarker(rewriter, op);
marker.replaceLinalgTransformationFilter(rewriter, op);
return success();
}
mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
StringRef opName, MLIRContext *context, LinalgMarker marker,
StringRef opName, MLIRContext *context, LinalgTransformationFilter marker,
PatternBenefit benefit)
: RewritePattern(opName, {}, benefit, context), marker(marker) {}

View File

@ -607,12 +607,13 @@ populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns,
constexpr static StringRef kPromotedMarker = "PROMOTED";
tilingPatterns.insert<LinalgTilingPattern<ConvOp>>(
context, LinalgTilingOptions().setTileSizes(tileSizes),
LinalgMarker({}, Identifier::get(kTiledMarker, context)));
LinalgTransformationFilter(ArrayRef<Identifier>{},
Identifier::get(kTiledMarker, context)));
promotionPatterns.insert<LinalgPromotionPattern<ConvOp>>(
context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
LinalgMarker(Identifier::get(kTiledMarker, context),
Identifier::get(kPromotedMarker, context)));
LinalgTransformationFilter(Identifier::get(kTiledMarker, context),
Identifier::get(kPromotedMarker, context)));
SmallVector<bool, 4> mask(N);
int offset = tileSizes.size() - N;

View File

@ -107,8 +107,12 @@ struct TestLinalgCodegenStrategy
};
} // end anonymous namespace
template <typename LinalgNamedOp>
void TestLinalgCodegenStrategy::applyStrategyToNamedLinalgOp() {
/// Apply transformations specified as patterns.
void TestLinalgCodegenStrategy::runOnFunction() {
linalg::LinalgTransformationFilter::FilterFunction filterOpName =
[&](Operation *op) -> LogicalResult {
return success(op->getName().getStringRef() == anchorOpName);
};
LinalgTilingOptions tilingOptions;
if (!tileSizes.empty())
tilingOptions = tilingOptions.setTileSizes(tileSizes);
@ -134,19 +138,20 @@ void TestLinalgCodegenStrategy::applyStrategyToNamedLinalgOp() {
.Default(vector::VectorTransferSplit::None);
CodegenStrategy strategy;
strategy.template tileIf<LinalgNamedOp>(!tileSizes.empty(), tilingOptions)
.template promoteIf<LinalgNamedOp>(
promote, LinalgPromotionOptions()
.setAlignment(16)
.setUseFullTileBuffersByDefault(promoteFullTile))
.template tileIf<LinalgNamedOp>(!registerTileSizes.empty(),
registerTilingOptions)
.template promoteIf<LinalgNamedOp>(
registerPromote,
strategy.tileIf<LinalgOp>(!tileSizes.empty(), anchorOpName, tilingOptions)
.promoteIf<LinalgOp>(promote, anchorOpName,
LinalgPromotionOptions()
.setAlignment(16)
.setUseFullTileBuffersByDefault(promoteFullTile),
filterOpName)
.tileIf<LinalgOp>(!registerTileSizes.empty(), anchorOpName,
registerTilingOptions)
.promoteIf<LinalgOp>(
registerPromote, anchorOpName,
LinalgPromotionOptions()
.setAlignment(16)
.setUseFullTileBuffersByDefault(registerPromoteFullTile))
.template vectorizeIf<LinalgNamedOp>(vectorize)
.vectorizeIf<LinalgOp>(vectorize, anchorOpName)
.setVectorTransformsOptions(
vector::VectorTransformsOptions()
.setVectorTransformsOptions(vectorContractLowering)
@ -156,20 +161,6 @@ void TestLinalgCodegenStrategy::applyStrategyToNamedLinalgOp() {
strategy.transform(getFunction());
}
/// Apply transformations specified as patterns.
void TestLinalgCodegenStrategy::runOnFunction() {
if (anchorOpName == MatmulOp::getOperationName())
applyStrategyToNamedLinalgOp<MatmulOp>();
else if (anchorOpName == MatmulColumnMajorOp::getOperationName())
applyStrategyToNamedLinalgOp<MatmulColumnMajorOp>();
else if (anchorOpName == CopyOp::getOperationName())
applyStrategyToNamedLinalgOp<CopyOp>();
else if (anchorOpName == FillOp::getOperationName())
applyStrategyToNamedLinalgOp<FillOp>();
else
llvm_unreachable("Unsupported anchor op");
}
namespace mlir {
namespace test {
void registerTestLinalgCodegenStrategy() {

View File

@ -45,12 +45,15 @@ static void fillFusionPatterns(MLIRContext *context,
.setTileSizes({32, 64, 16})
.setLoopType(LinalgTilingLoopType::ParallelLoops),
LinalgFusionOptions().setIndicesToFuse({2}),
LinalgMarker(Identifier::get("basic_fusion", context),
Identifier::get("after_basic_fusion", context)),
LinalgMarker(ArrayRef<Identifier>(),
Identifier::get("after_basic_fusion_producer", context)),
LinalgMarker(ArrayRef<Identifier>(),
Identifier::get("after_basic_fusion_original", context)));
LinalgTransformationFilter(
Identifier::get("basic_fusion", context),
Identifier::get("after_basic_fusion", context)),
LinalgTransformationFilter(
ArrayRef<Identifier>(),
Identifier::get("after_basic_fusion_producer", context)),
LinalgTransformationFilter(
ArrayRef<Identifier>(),
Identifier::get("after_basic_fusion_original", context)));
patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
context, dependenceGraph,
@ -58,12 +61,14 @@ static void fillFusionPatterns(MLIRContext *context,
.setTileSizes({32, 64, 16})
.setLoopType(LinalgTilingLoopType::ParallelLoops),
LinalgFusionOptions().setIndicesToFuse({0}),
LinalgMarker(Identifier::get("lhs_fusion", context),
Identifier::get("after_lhs_fusion", context)),
LinalgMarker(ArrayRef<Identifier>(),
Identifier::get("after_lhs_fusion_producer", context)),
LinalgMarker(ArrayRef<Identifier>(),
Identifier::get("after_lhs_fusion_original", context)));
LinalgTransformationFilter(Identifier::get("lhs_fusion", context),
Identifier::get("after_lhs_fusion", context)),
LinalgTransformationFilter(
ArrayRef<Identifier>(),
Identifier::get("after_lhs_fusion_producer", context)),
LinalgTransformationFilter(
ArrayRef<Identifier>(),
Identifier::get("after_lhs_fusion_original", context)));
patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
context, dependenceGraph,
@ -71,12 +76,14 @@ static void fillFusionPatterns(MLIRContext *context,
.setTileSizes({32, 64, 16})
.setLoopType(LinalgTilingLoopType::ParallelLoops),
LinalgFusionOptions().setIndicesToFuse({1}),
LinalgMarker(Identifier::get("rhs_fusion", context),
Identifier::get("after_rhs_fusion", context)),
LinalgMarker(ArrayRef<Identifier>(),
Identifier::get("after_rhs_fusion_producer", context)),
LinalgMarker(ArrayRef<Identifier>(),
Identifier::get("after_rhs_fusion_original", context)));
LinalgTransformationFilter(Identifier::get("rhs_fusion", context),
Identifier::get("after_rhs_fusion", context)),
LinalgTransformationFilter(
ArrayRef<Identifier>(),
Identifier::get("after_rhs_fusion_producer", context)),
LinalgTransformationFilter(
ArrayRef<Identifier>(),
Identifier::get("after_rhs_fusion_original", context)));
patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
context, dependenceGraph,
@ -84,12 +91,13 @@ static void fillFusionPatterns(MLIRContext *context,
.setTileSizes({32, 64, 16})
.setLoopType(LinalgTilingLoopType::ParallelLoops),
LinalgFusionOptions().setIndicesToFuse({0, 2}),
LinalgMarker(Identifier::get("two_operand_fusion", context),
Identifier::get("after_two_operand_fusion", context)),
LinalgMarker(
LinalgTransformationFilter(
Identifier::get("two_operand_fusion", context),
Identifier::get("after_two_operand_fusion", context)),
LinalgTransformationFilter(
ArrayRef<Identifier>(),
Identifier::get("after_two_operand_fusion_producer", context)),
LinalgMarker(
LinalgTransformationFilter(
ArrayRef<Identifier>(),
Identifier::get("after_two_operand_fusion_original", context)));
@ -98,11 +106,13 @@ static void fillFusionPatterns(MLIRContext *context,
LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(
LinalgTilingLoopType::ParallelLoops),
LinalgFusionOptions().setIndicesToFuse({0, 1}),
LinalgMarker(Identifier::get("transpose_fusion", context),
Identifier::get("after_transpose_fusion", context)),
LinalgMarker(ArrayRef<Identifier>(),
Identifier::get("after_transpose_fusion_producer", context)),
LinalgMarker(
LinalgTransformationFilter(
Identifier::get("transpose_fusion", context),
Identifier::get("after_transpose_fusion", context)),
LinalgTransformationFilter(
ArrayRef<Identifier>(),
Identifier::get("after_transpose_fusion_producer", context)),
LinalgTransformationFilter(
ArrayRef<Identifier>(),
Identifier::get("after_transpose_fusion_original", context)));
}

View File

@ -98,29 +98,35 @@ static void applyPatterns(FuncOp funcOp) {
//===--------------------------------------------------------------------===//
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
LinalgMarker(Identifier::get("MEM", ctx), Identifier::get("L3", ctx)));
LinalgTransformationFilter(Identifier::get("MEM", ctx),
Identifier::get("L3", ctx)));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
LinalgMarker(Identifier::get("L3", ctx), Identifier::get("L2", ctx)));
LinalgTransformationFilter(Identifier::get("L3", ctx),
Identifier::get("L2", ctx)));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
LinalgMarker(Identifier::get("L2", ctx), Identifier::get("L1", ctx)));
LinalgTransformationFilter(Identifier::get("L2", ctx),
Identifier::get("L1", ctx)));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
LinalgMarker(Identifier::get("L1", ctx), Identifier::get("REG", ctx)));
LinalgTransformationFilter(Identifier::get("L1", ctx),
Identifier::get("REG", ctx)));
patterns.insert<LinalgTilingPattern<MatvecOp>>(
ctx,
LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
LinalgTilingLoopType::ParallelLoops),
LinalgMarker({}, Identifier::get("L1", ctx)));
LinalgTransformationFilter(ArrayRef<Identifier>{},
Identifier::get("L1", ctx)));
patterns.insert<LinalgTilingPattern<DotOp>>(
ctx, LinalgTilingOptions().setTileSizes(8000),
LinalgMarker(ArrayRef<Identifier>{Identifier::get("MEM", ctx),
Identifier::get("L3", ctx),
Identifier::get("L2", ctx)},
Identifier::get("REG", ctx)));
LinalgTransformationFilter(
ArrayRef<Identifier>{Identifier::get("MEM", ctx),
Identifier::get("L3", ctx),
Identifier::get("L2", ctx)},
Identifier::get("REG", ctx)));
//===--------------------------------------------------------------------===//
// Linalg tiling and permutation patterns.
@ -130,24 +136,24 @@ static void applyPatterns(FuncOp funcOp) {
LinalgTilingOptions()
.setTileSizes({2000, 3000, 4000})
.setInterchange({1, 2, 0}),
LinalgMarker(Identifier::get("__with_perm__", ctx),
Identifier::get("L2__with_perm__", ctx)));
LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
Identifier::get("L2__with_perm__", ctx)));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx,
LinalgTilingOptions()
.setTileSizes({200, 300, 400})
.setInterchange({1, 0, 2}),
LinalgMarker(Identifier::get("L2__with_perm__", ctx),
Identifier::get("L1__with_perm__", ctx)));
LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx),
Identifier::get("L1__with_perm__", ctx)));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
LinalgMarker(Identifier::get("L1__with_perm__", ctx),
Identifier::get("REG__with_perm__", ctx)));
LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx),
Identifier::get("REG__with_perm__", ctx)));
patterns.insert<LinalgTilingPattern<MatvecOp>>(
ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
LinalgMarker(Identifier::get("__with_perm__", ctx),
Identifier::get("L1__with_perm__", ctx)));
LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
Identifier::get("L1__with_perm__", ctx)));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx,
@ -155,8 +161,9 @@ static void applyPatterns(FuncOp funcOp) {
.setTileSizes({16, 8, 4})
.setInterchange({1, 2, 0})
.setLoopType(LinalgTilingLoopType::ParallelLoops),
LinalgMarker(Identifier::get("par__with_perm__", ctx),
Identifier::get("after_par__with_perm__", ctx)));
LinalgTransformationFilter(
Identifier::get("par__with_perm__", ctx),
Identifier::get("after_par__with_perm__", ctx)));
//===--------------------------------------------------------------------===//
// Linalg to loops patterns.
@ -164,7 +171,7 @@ static void applyPatterns(FuncOp funcOp) {
patterns.insert<LinalgLoweringPattern<DotOp>>(
ctx,
/*loweringType=*/LinalgLoweringType::Loops,
LinalgMarker(Identifier::get("REG", ctx)));
LinalgTransformationFilter(Identifier::get("REG", ctx)));
//===--------------------------------------------------------------------===//
// Linalg distribution patterns.
@ -178,7 +185,8 @@ static void applyPatterns(FuncOp funcOp) {
LinalgVectorizationPattern<FillOp>,
LinalgVectorizationPattern<CopyOp>,
LinalgVectorizationPattern<GenericOp>>(
ctx, LinalgMarker(Identifier::get("VECTORIZE", ctx)));
ctx, LinalgVectorizationOptions(),
LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx)));
//===--------------------------------------------------------------------===//
// Linalg generic permutation patterns.
@ -186,34 +194,38 @@ static void applyPatterns(FuncOp funcOp) {
patterns.insert<LinalgInterchangePattern<GenericOp>>(
ctx,
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
LinalgMarker({}, Identifier::get("PERMUTED", ctx)));
LinalgTransformationFilter(ArrayRef<Identifier>{},
Identifier::get("PERMUTED", ctx)));
patterns.insert<LinalgInterchangePattern<IndexedGenericOp>>(
ctx,
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
LinalgMarker({}, Identifier::get("PERMUTED", ctx)));
LinalgTransformationFilter(ArrayRef<Identifier>{},
Identifier::get("PERMUTED", ctx)));
//===--------------------------------------------------------------------===//
// Linalg subview operands promotion.
//===--------------------------------------------------------------------===//
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
LinalgMarker(Identifier::get("_promote_views_", ctx),
Identifier::get("_views_promoted_", ctx)));
LinalgTransformationFilter(Identifier::get("_promote_views_", ctx),
Identifier::get("_views_promoted_", ctx)));
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
ctx,
LinalgPromotionOptions()
.setOperandsToPromote({0})
.setUseFullTileBuffersByDefault(true),
LinalgMarker(Identifier::get("_promote_first_view_", ctx),
Identifier::get("_first_view_promoted_", ctx)));
LinalgTransformationFilter(
Identifier::get("_promote_first_view_", ctx),
Identifier::get("_first_view_promoted_", ctx)));
patterns.insert<LinalgPromotionPattern<FillOp>>(
ctx,
LinalgPromotionOptions()
.setOperandsToPromote({0})
.setUseFullTileBuffers({true})
.setAlignment(32),
LinalgMarker(Identifier::get("_promote_views_aligned_", ctx),
Identifier::get("_views_aligned_promoted_", ctx)));
LinalgTransformationFilter(
Identifier::get("_promote_views_aligned_", ctx),
Identifier::get("_views_aligned_promoted_", ctx)));
applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
@ -230,18 +242,19 @@ static void fillL1TilingAndMatmulToVectorPatterns(
patternsVector.emplace_back(std::make_unique<LinalgTilingPattern<MatmulOp>>(
ctx,
LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2}),
LinalgMarker(Identifier::get(startMarker, ctx),
Identifier::get("L1", ctx))));
LinalgTransformationFilter(Identifier::get(startMarker, ctx),
Identifier::get("L1", ctx))));
patternsVector.emplace_back(
std::make_unique<LinalgPromotionPattern<MatmulOp>>(
ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
LinalgMarker(Identifier::get("L1", ctx),
Identifier::get("VEC", ctx))));
LinalgTransformationFilter(Identifier::get("L1", ctx),
Identifier::get("VEC", ctx))));
patternsVector.emplace_back(
std::make_unique<LinalgVectorizationPattern<MatmulOp>>(
ctx, LinalgMarker(Identifier::get("VEC", ctx))));
ctx, LinalgVectorizationOptions(),
LinalgTransformationFilter(Identifier::get("VEC", ctx))));
patternsVector.back()
.insert<LinalgVectorizationPattern<FillOp>,
LinalgVectorizationPattern<CopyOp>>(ctx);
@ -289,8 +302,8 @@ static void fillPromotionCallBackPatterns(MLIRContext *ctx,
OwningRewritePatternList &patterns) {
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
LinalgMarker(Identifier::get("START", ctx),
Identifier::get("PROMOTE", ctx)));
LinalgTransformationFilter(Identifier::get("START", ctx),
Identifier::get("PROMOTE", ctx)));
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
ctx,
LinalgPromotionOptions()
@ -306,7 +319,7 @@ static void fillPromotionCallBackPatterns(MLIRContext *ctx,
copyCallBackFn(b, src, dst, true);
return success();
}),
LinalgMarker(Identifier::get("PROMOTE", ctx)));
LinalgTransformationFilter(Identifier::get("PROMOTE", ctx)));
}
template <typename IdOp, typename NProcsOp>
@ -335,8 +348,9 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::ParallelLoops)
.setDistributionOptions(cyclicNprocsEqNiters),
LinalgMarker(Identifier::get("distribute1", context),
Identifier::get("after_distribute1", context)));
LinalgTransformationFilter(
Identifier::get("distribute1", context),
Identifier::get("after_distribute1", context)));
}
{
@ -351,8 +365,9 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::ParallelLoops)
.setDistributionOptions(cyclicNprocsGeNiters),
LinalgMarker(Identifier::get("distribute2", context),
Identifier::get("after_distribute2", context)));
LinalgTransformationFilter(
Identifier::get("distribute2", context),
Identifier::get("after_distribute2", context)));
}
{
@ -367,8 +382,9 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::ParallelLoops)
.setDistributionOptions(cyclicNprocsDefault),
LinalgMarker(Identifier::get("distribute3", context),
Identifier::get("after_distribute3", context)));
LinalgTransformationFilter(
Identifier::get("distribute3", context),
Identifier::get("after_distribute3", context)));
}
{
@ -383,8 +399,9 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::ParallelLoops)
.setDistributionOptions(cyclicNprocsMixed1),
LinalgMarker(Identifier::get("distribute4", context),
Identifier::get("after_distribute4", context)));
LinalgTransformationFilter(
Identifier::get("distribute4", context),
Identifier::get("after_distribute4", context)));
}
{
@ -399,8 +416,9 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::ParallelLoops)
.setDistributionOptions(cyclicNprocsMixed2),
LinalgMarker(Identifier::get("distribute5", context),
Identifier::get("after_distribute5", context)));
LinalgTransformationFilter(
Identifier::get("distribute5", context),
Identifier::get("after_distribute5", context)));
}
{
@ -416,8 +434,9 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::ParallelLoops)
.setDistributionOptions(cyclicNprocsMixed3),
LinalgMarker(Identifier::get("distribute6", context),
Identifier::get("after_distribute6", context)));
LinalgTransformationFilter(
Identifier::get("distribute6", context),
Identifier::get("after_distribute6", context)));
}
{
@ -432,8 +451,9 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::Loops)
.setDistributionOptions(cyclicNprocsEqNiters),
LinalgMarker(Identifier::get("tensors_distribute1", context),
Identifier::get("tensors_after_distribute1", context)));
LinalgTransformationFilter(
Identifier::get("tensors_distribute1", context),
Identifier::get("tensors_after_distribute1", context)));
}
}
@ -452,8 +472,8 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
LinalgTilingOptions()
.setTileSizes({768, 264, 768})
.setInterchange({1, 2, 0}),
LinalgMarker(Identifier::get("START", ctx),
Identifier::get("L2", ctx))));
LinalgTransformationFilter(Identifier::get("START", ctx),
Identifier::get("L2", ctx))));
fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
stage1Patterns);
}
@ -511,7 +531,8 @@ static void applyTileAndPadPattern(FuncOp funcOp) {
.setPaddingValueComputationFunction(getNeutralOfLinalgOp);
tilingPattern.insert<linalg::LinalgTilingPattern<linalg::MatmulOp>>(
context, linalgTilingOptions,
linalg::LinalgMarker(Identifier::get("tile-and-pad", context)));
linalg::LinalgTransformationFilter(
Identifier::get("tile-and-pad", context)));
applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
}