forked from OSchip/llvm-project
[mlir][Linalg] Further improve codegen strategy and add a linalg.matmul_i8_i8_i32
This revision adds a layer of SFINAE to the composable codegen strategy so it does not have to require statically defined ops but instead can also be used with OpInterfaces, Operation* and an op name string. A linalg.matmul_i8_i8_i32 is added to the .tc spec to demonstrate how all this works end to end. Differential Revision: https://reviews.llvm.org/D95600
This commit is contained in:
parent
42635856ed
commit
299cc5da6d
|
@ -8,6 +8,13 @@ def matmul_column_major(A: f32(K, M), B: f32(N, K)) -> (C: f32(N, M)) {
|
|||
C(n, m) = std_addf<k>(std_mulf(A(k, m), B(n, k)));
|
||||
}
|
||||
|
||||
ods_def<MatmulI8I8I32Op>:
|
||||
def matmul_i8_i8_i32(A: i8(M, K), B: i8(K, N)) -> (C: i32(M, N)) {
|
||||
// TODO: ideally something closer to
|
||||
// C(m, n) += cast<i32>(A(m, k)) * cast<i32>(B(k, n))
|
||||
C(m, n) = std_addi<k>(std_sexti32(std_muli(A(m, k), B(k, n))));
|
||||
}
|
||||
|
||||
ods_def<MatvecOp>:
|
||||
def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) {
|
||||
x(m) = std_addf<n>(std_mulf(A(m, n), y(n)));
|
||||
|
|
|
@ -21,27 +21,63 @@ namespace linalg {
|
|||
/// Abstract Transformation class applied in a sequence that also handles state
|
||||
/// through markers.
|
||||
struct Transformation {
|
||||
explicit Transformation(linalg::LinalgTransformationFilter::FilterFunction f)
|
||||
: filter(f) {}
|
||||
virtual ~Transformation() = default;
|
||||
virtual OwningRewritePatternList
|
||||
buildRewritePatterns(MLIRContext *context, linalg::LinalgMarker m) = 0;
|
||||
linalg::LinalgMarker marker;
|
||||
buildRewritePatterns(MLIRContext *context,
|
||||
linalg::LinalgTransformationFilter m) = 0;
|
||||
linalg::LinalgTransformationFilter::FilterFunction filter = nullptr;
|
||||
};
|
||||
|
||||
/// SFINAE: Enqueue helper for ConcreteOpType that have a `getOperationName`.
|
||||
template <template <typename> class PatternType, typename ConcreteOpType,
|
||||
typename OptionsType,
|
||||
typename std::enable_if<std::is_member_function_pointer<
|
||||
decltype(&ConcreteOpType::getOperationName)>::value>>
|
||||
void sfinae_enqueue(OwningRewritePatternList &patterList, OptionsType options,
|
||||
MLIRContext *context, StringRef opName,
|
||||
linalg::LinalgTransformationFilter m) {
|
||||
assert(opName.empty() ||
|
||||
opName == ConcreteOpType::getOperationName() &&
|
||||
"explicit name must match ConcreteOpType::getOperationName");
|
||||
patterList.insert<PatternType<ConcreteOpType>>(context, options, m);
|
||||
}
|
||||
|
||||
/// SFINAE: Enqueue helper for OpType that do not have a `getOperationName`
|
||||
/// (e.g. LinalgOp, other interfaces, Operation*).
|
||||
template <template <typename> class PatternType, typename OpType,
|
||||
typename OptionsType>
|
||||
void sfinae_enqueue(OwningRewritePatternList &patterList, OptionsType options,
|
||||
MLIRContext *context, StringRef opName,
|
||||
linalg::LinalgTransformationFilter m) {
|
||||
assert(!opName.empty() && "opName must not be empty");
|
||||
patterList.insert<PatternType<OpType>>(opName, context, options, m);
|
||||
}
|
||||
|
||||
/// Promotion transformation enqueues a particular stage-1 pattern for
|
||||
/// `Tile<LinalgOpType>`with the appropriate `options`.
|
||||
template <typename LinalgOpType>
|
||||
struct Tile : public Transformation {
|
||||
explicit Tile(linalg::LinalgTilingOptions options) : options(options) {}
|
||||
explicit Tile(linalg::LinalgTilingOptions options,
|
||||
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
|
||||
: Transformation(f), opName(""), options(options) {}
|
||||
|
||||
Tile(StringRef name, linalg::LinalgTilingOptions options,
|
||||
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
|
||||
: Transformation(f), opName(name), options(options) {}
|
||||
|
||||
OwningRewritePatternList
|
||||
buildRewritePatterns(MLIRContext *context, linalg::LinalgMarker m) override {
|
||||
buildRewritePatterns(MLIRContext *context,
|
||||
linalg::LinalgTransformationFilter m) override {
|
||||
OwningRewritePatternList tilingPatterns;
|
||||
tilingPatterns.insert<linalg::LinalgTilingPattern<LinalgOpType>>(
|
||||
context, options, m);
|
||||
sfinae_enqueue<linalg::LinalgTilingPattern, LinalgOpType>(
|
||||
tilingPatterns, options, context, opName, m);
|
||||
return tilingPatterns;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string opName;
|
||||
linalg::LinalgTilingOptions options;
|
||||
};
|
||||
|
||||
|
@ -49,17 +85,26 @@ private:
|
|||
/// `Promote<LinalgOpType>`with the appropriate `options`.
|
||||
template <typename LinalgOpType>
|
||||
struct Promote : public Transformation {
|
||||
explicit Promote(linalg::LinalgPromotionOptions options) : options(options) {}
|
||||
explicit Promote(
|
||||
linalg::LinalgPromotionOptions options,
|
||||
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
|
||||
: Transformation(f), opName(""), options(options) {}
|
||||
|
||||
Promote(StringRef name, linalg::LinalgPromotionOptions options,
|
||||
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
|
||||
: Transformation(f), opName(name), options(options) {}
|
||||
|
||||
OwningRewritePatternList
|
||||
buildRewritePatterns(MLIRContext *context, linalg::LinalgMarker m) override {
|
||||
buildRewritePatterns(MLIRContext *context,
|
||||
linalg::LinalgTransformationFilter m) override {
|
||||
OwningRewritePatternList promotionPatterns;
|
||||
promotionPatterns.insert<linalg::LinalgPromotionPattern<LinalgOpType>>(
|
||||
context, options, m);
|
||||
sfinae_enqueue<linalg::LinalgPromotionPattern, LinalgOpType>(
|
||||
promotionPatterns, options, context, opName, m);
|
||||
return promotionPatterns;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string opName;
|
||||
linalg::LinalgPromotionOptions options;
|
||||
};
|
||||
|
||||
|
@ -68,25 +113,36 @@ private:
|
|||
/// transfer rewrite forwarding patterns.
|
||||
template <typename LinalgOpType>
|
||||
struct Vectorize : public Transformation {
|
||||
explicit Vectorize(
|
||||
linalg::LinalgVectorizationOptions options,
|
||||
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
|
||||
: Transformation(f), opName(""), options(options) {}
|
||||
|
||||
Vectorize(StringRef name, linalg::LinalgVectorizationOptions options,
|
||||
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
|
||||
: Transformation(f), opName(name), options(options) {}
|
||||
|
||||
OwningRewritePatternList
|
||||
buildRewritePatterns(MLIRContext *context, linalg::LinalgMarker m) override {
|
||||
buildRewritePatterns(MLIRContext *context,
|
||||
linalg::LinalgTransformationFilter m) override {
|
||||
OwningRewritePatternList vectorizationPatterns;
|
||||
// FillOp may interfere with forwarding patterns atm, so we bump up the
|
||||
// priority of LinalgCopyVTRForwardingPattern /
|
||||
// LinalgCopyVTWForwardingPattern.
|
||||
vectorizationPatterns
|
||||
.insert<linalg::LinalgVectorizationPattern<LinalgOpType>>(context, m);
|
||||
sfinae_enqueue<linalg::LinalgVectorizationPattern, LinalgOpType>(
|
||||
vectorizationPatterns, options, context, opName, m);
|
||||
vectorizationPatterns.insert<linalg::LinalgCopyVTRForwardingPattern,
|
||||
linalg::LinalgCopyVTWForwardingPattern>(
|
||||
context, /*benefit=*/2);
|
||||
return vectorizationPatterns;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string opName;
|
||||
linalg::LinalgVectorizationOptions options;
|
||||
};
|
||||
|
||||
/// Codegen strategy controls how a Linalg op is progressively lowered.
|
||||
/// The application uses a 3-level staged patterns strategy which allows
|
||||
/// ordering transformations by using the Linalg `applyStagedPatterns` function,
|
||||
/// where:
|
||||
/// ordering transformations by using the Linalg `applyStagedPatterns`
|
||||
/// function, where:
|
||||
/// 1. The first stage consists of the successive `tile`, `promote` and
|
||||
/// `vectorize` patterns, applied sequentially.
|
||||
/// 2. The second stage consists of common local canonicalization patterns
|
||||
|
@ -97,41 +153,112 @@ struct CodegenStrategy {
|
|||
/// Append a pattern to add a level of tiling for `LinalgOpType` with tiling
|
||||
/// `options`.
|
||||
template <typename LinalgOpType>
|
||||
CodegenStrategy &tile(linalg::LinalgTilingOptions options) {
|
||||
transformationSequence.emplace_back(new Tile<LinalgOpType>(options));
|
||||
CodegenStrategy &
|
||||
tile(linalg::LinalgTilingOptions options,
|
||||
linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
|
||||
transformationSequence.emplace_back(
|
||||
std::make_unique<Tile<LinalgOpType>>(options, f));
|
||||
return *this;
|
||||
}
|
||||
/// Conditionally append a pattern to add a level of tiling for `LinalgOpType`
|
||||
/// with tiling `options`.
|
||||
/// Append a pattern to add a level of tiling for `LinalgOpType` with tiling
|
||||
/// `options`.
|
||||
template <typename LinalgOpType>
|
||||
CodegenStrategy &tileIf(bool b, linalg::LinalgTilingOptions options) {
|
||||
CodegenStrategy &
|
||||
tile(StringRef opName, linalg::LinalgTilingOptions options,
|
||||
linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
|
||||
transformationSequence.emplace_back(
|
||||
std::make_unique<Tile<LinalgOpType>>(opName, options, f));
|
||||
return *this;
|
||||
}
|
||||
/// Conditionally append a pattern to add a level of tiling for
|
||||
/// `LinalgOpType` with tiling `options`.
|
||||
template <typename LinalgOpType>
|
||||
CodegenStrategy &
|
||||
tileIf(bool b, linalg::LinalgTilingOptions options,
|
||||
linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
|
||||
return b ? tile<LinalgOpType>(options) : *this;
|
||||
}
|
||||
/// Conditionally append a pattern to add a level of tiling for
|
||||
/// `LinalgOpType` with tiling `options`.
|
||||
template <typename LinalgOpType>
|
||||
CodegenStrategy &
|
||||
tileIf(bool b, StringRef opName, linalg::LinalgTilingOptions options,
|
||||
linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
|
||||
return b ? tile<LinalgOpType>(opName, options) : *this;
|
||||
}
|
||||
/// Append a pattern to add a level of promotion for `LinalgOpType` with
|
||||
/// promotion `options`.
|
||||
template <typename LinalgOpType>
|
||||
CodegenStrategy &promote(linalg::LinalgPromotionOptions options) {
|
||||
transformationSequence.emplace_back(new Promote<LinalgOpType>(options));
|
||||
CodegenStrategy &
|
||||
promote(linalg::LinalgPromotionOptions options,
|
||||
linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
|
||||
transformationSequence.emplace_back(
|
||||
std::make_unique<Promote<LinalgOpType>>(options, f));
|
||||
return *this;
|
||||
}
|
||||
/// Append a pattern to add a level of promotion for `LinalgOpType` with
|
||||
/// promotion `options`.
|
||||
template <typename LinalgOpType>
|
||||
CodegenStrategy &
|
||||
promote(StringRef opName, linalg::LinalgPromotionOptions options,
|
||||
linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
|
||||
transformationSequence.emplace_back(
|
||||
std::make_unique<Promote<LinalgOpType>>(opName, options, f));
|
||||
return *this;
|
||||
}
|
||||
/// Conditionally append a pattern to add a level of promotion for
|
||||
/// `LinalgOpType` with promotion `options`.
|
||||
template <typename LinalgOpType>
|
||||
CodegenStrategy &promoteIf(bool b, linalg::LinalgPromotionOptions options) {
|
||||
return b ? promote<LinalgOpType>(options) : *this;
|
||||
CodegenStrategy &
|
||||
promoteIf(bool b, StringRef opName, linalg::LinalgPromotionOptions options,
|
||||
linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
|
||||
return b ? promote<LinalgOpType>(opName, options, f) : *this;
|
||||
return *this;
|
||||
}
|
||||
/// Conditionally append a pattern to add a level of promotion for
|
||||
/// `LinalgOpType` with promotion `options`.
|
||||
template <typename LinalgOpType>
|
||||
CodegenStrategy &
|
||||
promoteIf(bool b, linalg::LinalgPromotionOptions options,
|
||||
linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
|
||||
return b ? promote<LinalgOpType>(options, f) : *this;
|
||||
return *this;
|
||||
}
|
||||
/// Append a pattern to rewrite `LinalgOpType` as a vector operation.
|
||||
template <typename LinalgOpType>
|
||||
CodegenStrategy &vectorize() {
|
||||
transformationSequence.emplace_back(new Vectorize<LinalgOpType>());
|
||||
CodegenStrategy &
|
||||
vectorize(linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
|
||||
transformationSequence.emplace_back(
|
||||
std::make_unique<Vectorize<LinalgOpType>>(
|
||||
linalg::LinalgVectorizationOptions(), f));
|
||||
return *this;
|
||||
}
|
||||
/// Append a pattern to rewrite `LinalgOpType` as a vector operation.
|
||||
template <typename LinalgOpType>
|
||||
CodegenStrategy &
|
||||
vectorize(StringRef opName,
|
||||
linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
|
||||
transformationSequence.emplace_back(
|
||||
std::make_unique<Vectorize<LinalgOpType>>(
|
||||
opName, linalg::LinalgVectorizationOptions(), f));
|
||||
return *this;
|
||||
}
|
||||
/// Conditionally append a pattern to rewrite `LinalgOpType` as a vector
|
||||
/// operation.
|
||||
template <typename LinalgOpType>
|
||||
CodegenStrategy &vectorizeIf(bool b) {
|
||||
return b ? vectorize<LinalgOpType>() : *this;
|
||||
CodegenStrategy &
|
||||
vectorizeIf(bool b,
|
||||
linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
|
||||
return b ? vectorize<LinalgOpType>(f) : *this;
|
||||
return *this;
|
||||
}
|
||||
/// Conditionally append a pattern to rewrite `LinalgOpType` as a vector
|
||||
/// operation.
|
||||
template <typename LinalgOpType>
|
||||
CodegenStrategy &
|
||||
vectorizeIf(bool b, StringRef opName,
|
||||
linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
|
||||
return b ? vectorize<LinalgOpType>(opName, f) : *this;
|
||||
return *this;
|
||||
}
|
||||
/// Configure the post staged-patterns late vector transformations.
|
||||
|
@ -140,15 +267,22 @@ struct CodegenStrategy {
|
|||
vectorTransformsOptions = options;
|
||||
return *this;
|
||||
}
|
||||
/// Configure the post staged-patterns late vector.transfer to scf conversion.
|
||||
/// Configure the post staged-patterns late vector.transfer to scf
|
||||
/// conversion.
|
||||
CodegenStrategy &
|
||||
setVectorTransferToSCFOptions(VectorTransferToSCFOptions options) {
|
||||
vectorToSCFOptions = options;
|
||||
return *this;
|
||||
}
|
||||
/// Configure the post staged-patterns late vector.transfer to scf
|
||||
/// conversion.
|
||||
CodegenStrategy &setHoistInvariantCode(bool enableLICM) {
|
||||
this->enableLICM = enableLICM;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Apply the transformation patterns in sequence with cleanup transformations
|
||||
/// interleaved.
|
||||
/// Apply the transformation patterns in sequence with cleanup
|
||||
/// transformations interleaved.
|
||||
void transform(FuncOp func) const;
|
||||
|
||||
private:
|
||||
|
@ -157,6 +291,7 @@ private:
|
|||
vector::VectorTransformsOptions vectorTransformsOptions;
|
||||
VectorTransferToSCFOptions vectorToSCFOptions;
|
||||
SmallVector<std::unique_ptr<Transformation>, 4> transformationSequence;
|
||||
bool enableLICM = true;
|
||||
};
|
||||
|
||||
} // namespace linalg
|
||||
|
|
|
@ -316,16 +316,32 @@ struct LinalgTransforms {
|
|||
static const StringLiteral kLinalgTransformMarker;
|
||||
};
|
||||
|
||||
/// Helper class to control common attribute matching and setting behavior.
|
||||
struct LinalgMarker {
|
||||
explicit LinalgMarker(ArrayRef<Identifier> matchDisjunction = {},
|
||||
Optional<Identifier> replacement = None);
|
||||
LinalgMarker(LinalgMarker &&) = default;
|
||||
LinalgMarker(const LinalgMarker &) = default;
|
||||
/// Helper class to control application of linalg transformation patterns.
|
||||
/// Control comes in 2 forms:
|
||||
/// 1. attribute matching and setting behavior using the attribute named
|
||||
/// `kLinalgTransformMarker`. This can be used to build a state machine
|
||||
/// using attributes and incrementally applying patterns to advance states.
|
||||
/// 2. filter function, which is a simple lambda on the Operation* that
|
||||
/// returns a LogicalResult.
|
||||
struct LinalgTransformationFilter {
|
||||
using FilterFunction = std::function<LogicalResult(Operation *)>;
|
||||
|
||||
explicit LinalgTransformationFilter(
|
||||
ArrayRef<Identifier> matchDisjunction = {},
|
||||
Optional<Identifier> replacement = None);
|
||||
|
||||
explicit LinalgTransformationFilter(
|
||||
FilterFunction f, ArrayRef<Identifier> matchDisjunction = {},
|
||||
Optional<Identifier> replacement = None);
|
||||
|
||||
LinalgTransformationFilter(LinalgTransformationFilter &&) = default;
|
||||
LinalgTransformationFilter(const LinalgTransformationFilter &) = default;
|
||||
LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
|
||||
void replaceLinalgMarker(PatternRewriter &rewriter, Operation *op) const;
|
||||
void replaceLinalgTransformationFilter(PatternRewriter &rewriter,
|
||||
Operation *op) const;
|
||||
|
||||
private:
|
||||
FilterFunction filter;
|
||||
SmallVector<Identifier, 4> matchDisjunction;
|
||||
Optional<Identifier> replacement;
|
||||
};
|
||||
|
@ -425,31 +441,44 @@ void populateLinalgTilingCanonicalizationPatterns(
|
|||
/// and some operand shape cannot be bounded statically.
|
||||
struct LinalgBaseTilingPattern : public RewritePattern {
|
||||
// Entry point to match any LinalgOp OpInterface.
|
||||
LinalgBaseTilingPattern(LinalgTilingOptions options,
|
||||
LinalgMarker marker = LinalgMarker(),
|
||||
PatternBenefit benefit = 1);
|
||||
LinalgBaseTilingPattern(
|
||||
LinalgTilingOptions options,
|
||||
LinalgTransformationFilter marker = LinalgTransformationFilter(),
|
||||
PatternBenefit benefit = 1);
|
||||
// Entry point to match a specific Linalg op.
|
||||
LinalgBaseTilingPattern(StringRef opName, MLIRContext *context,
|
||||
LinalgTilingOptions options,
|
||||
LinalgMarker marker = LinalgMarker(),
|
||||
PatternBenefit benefit = 1);
|
||||
LinalgBaseTilingPattern(
|
||||
StringRef opName, MLIRContext *context, LinalgTilingOptions options,
|
||||
LinalgTransformationFilter marker = LinalgTransformationFilter(),
|
||||
PatternBenefit benefit = 1);
|
||||
LogicalResult matchAndRewriteBase(Operation *op, PatternRewriter &rewriter,
|
||||
TiledLinalgOp &result) const;
|
||||
|
||||
private:
|
||||
/// LinalgTransformMarker handles special attribute manipulations.
|
||||
LinalgMarker marker;
|
||||
LinalgTransformationFilter marker;
|
||||
/// Options to control tiling;
|
||||
LinalgTilingOptions options;
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
struct LinalgTilingPattern : public LinalgBaseTilingPattern {
|
||||
LinalgTilingPattern(MLIRContext *context, LinalgTilingOptions options,
|
||||
LinalgMarker marker = LinalgMarker(),
|
||||
PatternBenefit benefit = 1)
|
||||
: LinalgBaseTilingPattern(OpTy::getOperationName(), context, options,
|
||||
marker, benefit) {}
|
||||
/// SFINAE: This constructor can only trigger for concrete ops that have a
|
||||
/// static `getOperationName` method.
|
||||
template <typename ConcreateOpTy = OpTy>
|
||||
LinalgTilingPattern(
|
||||
MLIRContext *context, LinalgTilingOptions options,
|
||||
LinalgTransformationFilter marker = LinalgTransformationFilter(),
|
||||
PatternBenefit benefit = 1)
|
||||
: LinalgBaseTilingPattern(ConcreateOpTy::getOperationName(), context,
|
||||
options, marker, benefit) {}
|
||||
|
||||
/// This constructor is available to anyone.
|
||||
LinalgTilingPattern(
|
||||
StringRef opName, MLIRContext *context, LinalgTilingOptions options,
|
||||
LinalgTransformationFilter marker = LinalgTransformationFilter(),
|
||||
PatternBenefit benefit = 1)
|
||||
: LinalgBaseTilingPattern(opName, context, options, marker, benefit) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
TiledLinalgOp tiledLinalgOp;
|
||||
|
@ -474,14 +503,15 @@ struct LinalgFusionOptions {
|
|||
};
|
||||
|
||||
struct LinalgBaseTileAndFusePattern : public RewritePattern {
|
||||
LinalgBaseTileAndFusePattern(StringRef opName, MLIRContext *context,
|
||||
const LinalgDependenceGraph &dependenceGraph,
|
||||
LinalgTilingOptions tilingOptions,
|
||||
LinalgFusionOptions fusionOptions,
|
||||
LinalgMarker marker = LinalgMarker(),
|
||||
LinalgMarker fusedOpMarker = LinalgMarker(),
|
||||
LinalgMarker originalOpMarker = LinalgMarker(),
|
||||
PatternBenefit benefit = 1);
|
||||
LinalgBaseTileAndFusePattern(
|
||||
StringRef opName, MLIRContext *context,
|
||||
const LinalgDependenceGraph &dependenceGraph,
|
||||
LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
|
||||
LinalgTransformationFilter marker = LinalgTransformationFilter(),
|
||||
LinalgTransformationFilter fusedOpMarker = LinalgTransformationFilter(),
|
||||
LinalgTransformationFilter originalOpMarker =
|
||||
LinalgTransformationFilter(),
|
||||
PatternBenefit benefit = 1);
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
|
||||
|
@ -493,27 +523,27 @@ private:
|
|||
/// Options to control fusion.
|
||||
LinalgFusionOptions fusionOptions;
|
||||
/// Marker to control application of the pattern.
|
||||
LinalgMarker marker;
|
||||
LinalgTransformationFilter marker;
|
||||
/// Marker set on the fused op after tile and fuse.
|
||||
LinalgMarker fusedOpMarker;
|
||||
LinalgTransformationFilter fusedOpMarker;
|
||||
/// The dependenceGraph is not modifiable, i.e. if the Linalg operations used
|
||||
/// to build the dependence graph changes then the dependenceGraph needs to be
|
||||
/// recomputed right now. To not invalidate the dependenceGraph as
|
||||
/// transformation happens, the original producer can be tagged with a marker
|
||||
/// that can be later used to delete the original operations.
|
||||
LinalgMarker originalOpMarker;
|
||||
LinalgTransformationFilter originalOpMarker;
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
struct LinalgTileAndFusePattern : public LinalgBaseTileAndFusePattern {
|
||||
LinalgTileAndFusePattern(MLIRContext *context,
|
||||
const LinalgDependenceGraph &dependenceGraph,
|
||||
LinalgTilingOptions tilingOptions,
|
||||
LinalgFusionOptions fusionOptions,
|
||||
LinalgMarker marker = LinalgMarker(),
|
||||
LinalgMarker fusedOpMarker = LinalgMarker(),
|
||||
LinalgMarker originalOpMarker = LinalgMarker(),
|
||||
PatternBenefit benefit = 1)
|
||||
LinalgTileAndFusePattern(
|
||||
MLIRContext *context, const LinalgDependenceGraph &dependenceGraph,
|
||||
LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
|
||||
LinalgTransformationFilter marker = LinalgTransformationFilter(),
|
||||
LinalgTransformationFilter fusedOpMarker = LinalgTransformationFilter(),
|
||||
LinalgTransformationFilter originalOpMarker =
|
||||
LinalgTransformationFilter(),
|
||||
PatternBenefit benefit = 1)
|
||||
: LinalgBaseTileAndFusePattern(
|
||||
OpTy::getOperationName(), context, dependenceGraph, tilingOptions,
|
||||
fusionOptions, marker, fusedOpMarker, originalOpMarker, benefit) {}
|
||||
|
@ -526,26 +556,27 @@ struct LinalgTileAndFusePattern : public LinalgBaseTileAndFusePattern {
|
|||
/// `marker` controls LinalgTransformMarker matching and update when specified.
|
||||
/// See `interchange` for more details.
|
||||
struct LinalgBaseInterchangePattern : public RewritePattern {
|
||||
LinalgBaseInterchangePattern(StringRef opName, MLIRContext *context,
|
||||
ArrayRef<unsigned> interchangeVector,
|
||||
LinalgMarker marker = LinalgMarker(),
|
||||
PatternBenefit benefit = 1);
|
||||
LinalgBaseInterchangePattern(
|
||||
StringRef opName, MLIRContext *context,
|
||||
ArrayRef<unsigned> interchangeVector,
|
||||
LinalgTransformationFilter marker = LinalgTransformationFilter(),
|
||||
PatternBenefit benefit = 1);
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
|
||||
private:
|
||||
/// LinalgTransformMarker handles special attribute manipulations.
|
||||
LinalgMarker marker;
|
||||
LinalgTransformationFilter marker;
|
||||
/// The interchange vector to reorder the iterators and indexing_maps dims.
|
||||
SmallVector<unsigned, 8> interchangeVector;
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
struct LinalgInterchangePattern : public LinalgBaseInterchangePattern {
|
||||
LinalgInterchangePattern(MLIRContext *context,
|
||||
ArrayRef<unsigned> interchangeVector,
|
||||
LinalgMarker marker = LinalgMarker(),
|
||||
PatternBenefit benefit = 1)
|
||||
LinalgInterchangePattern(
|
||||
MLIRContext *context, ArrayRef<unsigned> interchangeVector,
|
||||
LinalgTransformationFilter marker = LinalgTransformationFilter(),
|
||||
PatternBenefit benefit = 1)
|
||||
: LinalgBaseInterchangePattern(OpTy::getOperationName(), context,
|
||||
interchangeVector, marker, benefit) {}
|
||||
};
|
||||
|
@ -557,27 +588,38 @@ struct LinalgInterchangePattern : public LinalgBaseInterchangePattern {
|
|||
/// `marker` controls LinalgTransformMarker matching and update when specified.
|
||||
/// See `promoteSubViews` for more details.
|
||||
struct LinalgBasePromotionPattern : public RewritePattern {
|
||||
LinalgBasePromotionPattern(StringRef opName, MLIRContext *context,
|
||||
LinalgPromotionOptions options,
|
||||
LinalgMarker marker = LinalgMarker(),
|
||||
PatternBenefit benefit = 1);
|
||||
LinalgBasePromotionPattern(
|
||||
StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
|
||||
LinalgTransformationFilter marker = LinalgTransformationFilter(),
|
||||
PatternBenefit benefit = 1);
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
|
||||
private:
|
||||
/// LinalgTransformMarker handles special attribute manipulations.
|
||||
LinalgMarker marker;
|
||||
LinalgTransformationFilter marker;
|
||||
/// Promotion options.
|
||||
LinalgPromotionOptions options;
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
struct LinalgPromotionPattern : public LinalgBasePromotionPattern {
|
||||
LinalgPromotionPattern(MLIRContext *context, LinalgPromotionOptions options,
|
||||
LinalgMarker marker = LinalgMarker(),
|
||||
PatternBenefit benefit = 1)
|
||||
/// SFINAE: This constructor can only trigger for concrete ops that have a
|
||||
/// static `getOperationName` method.
|
||||
template <typename ConcreateOpTy = OpTy>
|
||||
LinalgPromotionPattern(
|
||||
MLIRContext *context, LinalgPromotionOptions options,
|
||||
LinalgTransformationFilter marker = LinalgTransformationFilter(),
|
||||
PatternBenefit benefit = 1)
|
||||
: LinalgBasePromotionPattern(OpTy::getOperationName(), context, options,
|
||||
marker, benefit) {}
|
||||
/// This constructor is available to anyone.
|
||||
LinalgPromotionPattern(
|
||||
StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
|
||||
LinalgTransformationFilter marker = LinalgTransformationFilter(),
|
||||
PatternBenefit benefit = 1)
|
||||
: LinalgBasePromotionPattern(opName, context, options, marker, benefit) {}
|
||||
};
|
||||
|
||||
///
|
||||
|
@ -586,25 +628,42 @@ struct LinalgPromotionPattern : public LinalgBasePromotionPattern {
|
|||
/// Apply the `vectorizeLinalgOp` transformation as a pattern.
|
||||
/// `marker` controls LinalgTransformMarker matching and update when specified.
|
||||
/// See `vectorizeLinalgOp` for more details.
|
||||
|
||||
/// Empty for now, used for SFINAE purposes only.
|
||||
struct LinalgVectorizationOptions {};
|
||||
|
||||
struct LinalgBaseVectorizationPattern : public RewritePattern {
|
||||
LinalgBaseVectorizationPattern(StringRef opName, MLIRContext *context,
|
||||
LinalgMarker marker = LinalgMarker(),
|
||||
PatternBenefit benefit = 1);
|
||||
LinalgBaseVectorizationPattern(
|
||||
StringRef opName, MLIRContext *context,
|
||||
LinalgTransformationFilter marker = LinalgTransformationFilter(),
|
||||
PatternBenefit benefit = 1);
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
|
||||
private:
|
||||
/// LinalgTransformMarker handles special attribute manipulations.
|
||||
LinalgMarker marker;
|
||||
LinalgTransformationFilter marker;
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
struct LinalgVectorizationPattern : public LinalgBaseVectorizationPattern {
|
||||
LinalgVectorizationPattern(MLIRContext *context,
|
||||
LinalgMarker marker = LinalgMarker(),
|
||||
PatternBenefit benefit = 1)
|
||||
/// SFINAE: This constructor can only trigger for concrete ops that have a
|
||||
/// static `getOperationName` method.
|
||||
template <typename ConcreateOpTy = OpTy>
|
||||
LinalgVectorizationPattern(
|
||||
MLIRContext *context,
|
||||
LinalgVectorizationOptions options = LinalgVectorizationOptions(),
|
||||
LinalgTransformationFilter marker = LinalgTransformationFilter(),
|
||||
PatternBenefit benefit = 1)
|
||||
: LinalgBaseVectorizationPattern(OpTy::getOperationName(), context,
|
||||
marker, benefit) {}
|
||||
/// This constructor is available to anyone.
|
||||
LinalgVectorizationPattern(
|
||||
StringRef opName, MLIRContext *context,
|
||||
LinalgVectorizationOptions options = LinalgVectorizationOptions(),
|
||||
LinalgTransformationFilter marker = LinalgTransformationFilter(),
|
||||
PatternBenefit benefit = 1)
|
||||
: LinalgBaseVectorizationPattern(opName, context, marker, benefit) {}
|
||||
};
|
||||
|
||||
///
|
||||
|
@ -622,10 +681,10 @@ enum class LinalgLoweringType {
|
|||
|
||||
template <typename OpTy>
|
||||
struct LinalgLoweringPattern : public RewritePattern {
|
||||
LinalgLoweringPattern(MLIRContext *context, LinalgLoweringType loweringType,
|
||||
LinalgMarker marker = LinalgMarker(),
|
||||
ArrayRef<unsigned> interchangeVector = {},
|
||||
PatternBenefit benefit = 1)
|
||||
LinalgLoweringPattern(
|
||||
MLIRContext *context, LinalgLoweringType loweringType,
|
||||
LinalgTransformationFilter marker = LinalgTransformationFilter(),
|
||||
ArrayRef<unsigned> interchangeVector = {}, PatternBenefit benefit = 1)
|
||||
: RewritePattern(OpTy::getOperationName(), {}, benefit, context),
|
||||
marker(marker), loweringType(loweringType),
|
||||
interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
|
||||
|
@ -663,7 +722,7 @@ struct LinalgLoweringPattern : public RewritePattern {
|
|||
|
||||
private:
|
||||
/// LinalgTransformMarker handles special attribute manipulations.
|
||||
LinalgMarker marker;
|
||||
LinalgTransformationFilter marker;
|
||||
/// Controls whether the pattern lowers to library calls, scf.for, affine.for
|
||||
/// or scf.parallel.
|
||||
LinalgLoweringType loweringType;
|
||||
|
@ -677,13 +736,13 @@ private:
|
|||
/// linalg.generic ops.
|
||||
void populateLinalgNamedOpsGeneralizationPatterns(
|
||||
MLIRContext *context, OwningRewritePatternList &patterns,
|
||||
LinalgMarker marker = LinalgMarker());
|
||||
LinalgTransformationFilter marker = LinalgTransformationFilter());
|
||||
|
||||
/// Populates `patterns` with patterns to convert linalg.conv ops to
|
||||
/// linalg.generic ops.
|
||||
void populateLinalgConvGeneralizationPatterns(
|
||||
MLIRContext *context, OwningRewritePatternList &patterns,
|
||||
LinalgMarker marker = LinalgMarker());
|
||||
LinalgTransformationFilter marker = LinalgTransformationFilter());
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Op-specific patterns.
|
||||
|
|
|
@ -38,6 +38,7 @@ using std_ret = OperationBuilder<ReturnOp>;
|
|||
using std_rsqrt = ValueBuilder<RsqrtOp>;
|
||||
using std_select = ValueBuilder<SelectOp>;
|
||||
using std_load = ValueBuilder<LoadOp>;
|
||||
using std_sign_extendi = ValueBuilder<SignExtendIOp>;
|
||||
using std_splat = ValueBuilder<SplatOp>;
|
||||
using std_store = OperationBuilder<StoreOp>;
|
||||
using std_subf = ValueBuilder<SubFOp>;
|
||||
|
@ -48,9 +49,19 @@ using std_tensor_load = ValueBuilder<TensorLoadOp>;
|
|||
using std_tensor_store = OperationBuilder<TensorStoreOp>;
|
||||
using std_view = ValueBuilder<ViewOp>;
|
||||
using std_zero_extendi = ValueBuilder<ZeroExtendIOp>;
|
||||
using std_sign_extendi = ValueBuilder<SignExtendIOp>;
|
||||
using tensor_extract = ValueBuilder<tensor::ExtractOp>;
|
||||
|
||||
template <int N>
|
||||
struct SExtiValueBuilder : public ValueBuilder<SignExtendIOp> {
|
||||
using ValueBuilder<SignExtendIOp>::ValueBuilder;
|
||||
template <typename... Args>
|
||||
SExtiValueBuilder(Args... args)
|
||||
: ValueBuilder<SignExtendIOp>(ScopedContext::getBuilderRef().getI32Type(),
|
||||
args...) {}
|
||||
};
|
||||
|
||||
using std_sexti32 = SExtiValueBuilder<32>;
|
||||
|
||||
/// Branches into `block` with `operands`.
|
||||
BranchOp std_br(Block *block, ValueRange operands);
|
||||
|
||||
|
|
|
@ -14,9 +14,12 @@
|
|||
// RUN: tee -a /dev/stderr | FileCheck %s
|
||||
|
||||
|
||||
!row_major_A = type memref<${M}x${K}xf32>
|
||||
!row_major_B = type memref<${K}x${N}xf32>
|
||||
!row_major_C = type memref<${M}x${N}xf32>
|
||||
!elem_type_a = type f32
|
||||
!elem_type_b = type f32
|
||||
!elem_type_c = type f32
|
||||
!row_major_A = type memref<${M}x${K}x!elem_type_a>
|
||||
!row_major_B = type memref<${K}x${N}x!elem_type_b>
|
||||
!row_major_C = type memref<${M}x${N}x!elem_type_c>
|
||||
|
||||
func @matmul(%a: !row_major_A, %b: !row_major_B, %c: !row_major_C)
|
||||
// TODO: activate manually for now.
|
||||
|
@ -48,16 +51,16 @@ func @print_perf(%iters: index, %total_time: f64) {
|
|||
}
|
||||
|
||||
func @main() {
|
||||
%f0 = constant 0.0 : f32
|
||||
%f1 = constant 1.0 : f32
|
||||
%v0 = constant 0.0 : !elem_type_a
|
||||
%v1 = constant 1.0 : !elem_type_a
|
||||
|
||||
%A = alloc() : !row_major_A
|
||||
%B = alloc() : !row_major_B
|
||||
%C = alloc() : !row_major_C
|
||||
|
||||
linalg.fill(%A, %f1) : !row_major_A, f32
|
||||
linalg.fill(%B, %f1) : !row_major_B, f32
|
||||
linalg.fill(%C, %f0) : !row_major_C, f32
|
||||
linalg.fill(%A, %v1) : !row_major_A, !elem_type_a
|
||||
linalg.fill(%B, %v1) : !row_major_B, !elem_type_b
|
||||
linalg.fill(%C, %v0) : !row_major_C, !elem_type_c
|
||||
|
||||
%c0 = constant 0: index
|
||||
%c1 = constant 1: index
|
||||
|
@ -66,7 +69,8 @@ func @main() {
|
|||
/// Run and dump performance for matmul.
|
||||
/// Preheating run:
|
||||
scf.for %arg0 = %c0 to %iters step %c1 {
|
||||
linalg.fill(%C, %f0) : !row_major_C, f32
|
||||
%z = constant 0.0 : !elem_type_c
|
||||
linalg.fill(%C, %z) : !row_major_C, !elem_type_c
|
||||
call @matmul(%A, %B, %C) : (!row_major_A, !row_major_B, !row_major_C) -> ()
|
||||
}
|
||||
%t_start_matmul = call @rtclock() : () -> f64
|
||||
|
@ -75,7 +79,8 @@ func @main() {
|
|||
// This is accounts for about 10-15% perf hit on small sizes.
|
||||
// Once linalg on tensors is ready, fusing fill at teh register level will
|
||||
// be easy.
|
||||
linalg.fill(%C, %f0) : !row_major_C, f32
|
||||
%z = constant 0.0 : !elem_type_c
|
||||
linalg.fill(%C, %z) : !row_major_C, !elem_type_c
|
||||
call @matmul(%A, %B, %C) : (!row_major_A, !row_major_B, !row_major_C) -> ()
|
||||
}
|
||||
%t_end_matmul = call @rtclock() : () -> f64
|
||||
|
|
|
@ -15,12 +15,15 @@
|
|||
// Use tee to both print to stderr and FileCheck
|
||||
// RUN: tee -a /dev/stderr | FileCheck %s
|
||||
|
||||
!row_major_A = type memref<${M}x${K}xf32>
|
||||
!row_major_B = type memref<${K}x${N}xf32>
|
||||
!row_major_C = type memref<${M}x${N}xf32>
|
||||
!column_major_A = type memref<${K}x${M}xf32>
|
||||
!column_major_B = type memref<${N}x${K}xf32>
|
||||
!column_major_C = type memref<${N}x${M}xf32>
|
||||
!elem_type_a = type f32
|
||||
!elem_type_b = type f32
|
||||
!elem_type_c = type f32
|
||||
!row_major_A = type memref<${M}x${K}x!elem_type_a>
|
||||
!row_major_B = type memref<${K}x${N}x!elem_type_b>
|
||||
!row_major_C = type memref<${M}x${N}x!elem_type_c>
|
||||
!column_major_A = type memref<${K}x${M}x!elem_type_a>
|
||||
!column_major_B = type memref<${N}x${K}x!elem_type_b>
|
||||
!column_major_C = type memref<${N}x${M}x!elem_type_c>
|
||||
|
||||
func @matmul_column_major(%a: !column_major_A, %b: !column_major_B, %c: !column_major_C)
|
||||
// TODO: activate manually for now.
|
||||
|
@ -52,16 +55,16 @@ func @print_perf(%iters: index, %total_time: f64) {
|
|||
}
|
||||
|
||||
func @main() {
|
||||
%f0 = constant 0.0 : f32
|
||||
%f1 = constant 1.0 : f32
|
||||
%f0 = constant 0.0 : !elem_type_c
|
||||
%f1 = constant 1.0 : !elem_type_a
|
||||
|
||||
%cA = alloc() : !column_major_A
|
||||
%cB = alloc() : !column_major_B
|
||||
%cC = alloc() : !column_major_C
|
||||
|
||||
linalg.fill(%cA, %f1) : !column_major_A, f32
|
||||
linalg.fill(%cB, %f1) : !column_major_B, f32
|
||||
linalg.fill(%cC, %f0) : !column_major_C, f32
|
||||
linalg.fill(%cA, %f1) : !column_major_A, !elem_type_a
|
||||
linalg.fill(%cB, %f1) : !column_major_B, !elem_type_b
|
||||
linalg.fill(%cC, %f0) : !column_major_C, !elem_type_c
|
||||
|
||||
%c0 = constant 0: index
|
||||
%c1 = constant 1: index
|
||||
|
@ -74,7 +77,7 @@ func @main() {
|
|||
// This is accounts for about 10-15% perf hit on small sizes.
|
||||
// Once linalg on tensors is ready, fusing fill at teh register level will
|
||||
// be easy.
|
||||
linalg.fill(%cC, %f0) : !column_major_C, f32
|
||||
linalg.fill(%cC, %f0) : !column_major_C, !elem_type_c
|
||||
call @matmul_column_major(%cA, %cB, %cC) : (!column_major_A, !column_major_B, !column_major_C) -> ()
|
||||
}
|
||||
%t_end_matmul_column_major = call @rtclock() : () -> f64
|
||||
|
@ -83,7 +86,7 @@ func @main() {
|
|||
|
||||
%res = load %cC[%c0, %c0]: !column_major_C
|
||||
// CHECK: 64
|
||||
vector.print %res: f32
|
||||
vector.print %res: !elem_type_c
|
||||
|
||||
dealloc %cA : !column_major_A
|
||||
dealloc %cB : !column_major_B
|
||||
|
|
|
@ -16,12 +16,15 @@
|
|||
// Use tee to both print to stderr and FileCheck
|
||||
// RUN: tee -a /dev/stderr | FileCheck %s
|
||||
|
||||
!row_major_A = type memref<${M}x${K}xf32>
|
||||
!row_major_B = type memref<${K}x${N}xf32>
|
||||
!row_major_C = type memref<${M}x${N}xf32>
|
||||
!column_major_A = type memref<${K}x${M}xf32>
|
||||
!column_major_B = type memref<${N}x${K}xf32>
|
||||
!column_major_C = type memref<${N}x${M}xf32>
|
||||
!elem_type_a = type f32
|
||||
!elem_type_b = type f32
|
||||
!elem_type_c = type f32
|
||||
!row_major_A = type memref<${M}x${K}x!elem_type_a>
|
||||
!row_major_B = type memref<${K}x${N}x!elem_type_b>
|
||||
!row_major_C = type memref<${M}x${N}x!elem_type_c>
|
||||
!column_major_A = type memref<${K}x${M}x!elem_type_a>
|
||||
!column_major_B = type memref<${N}x${K}x!elem_type_b>
|
||||
!column_major_C = type memref<${N}x${M}x!elem_type_c>
|
||||
|
||||
func @matmul_column_major_as_row_major(
|
||||
%ca: !column_major_A, %cb: !column_major_B, %cc: !column_major_C,
|
||||
|
@ -58,16 +61,16 @@ func @print_perf(%iters: index, %total_time: f64) {
|
|||
}
|
||||
|
||||
func @main() {
|
||||
%f0 = constant 0.0 : f32
|
||||
%f1 = constant 1.0 : f32
|
||||
%f0 = constant 0.0 : !elem_type_c
|
||||
%f1 = constant 1.0 : !elem_type_a
|
||||
|
||||
%cA = alloc() : !column_major_A
|
||||
%cB = alloc() : !column_major_B
|
||||
%cC = alloc() : !column_major_C
|
||||
|
||||
linalg.fill(%cA, %f1) : !column_major_A, f32
|
||||
linalg.fill(%cB, %f1) : !column_major_B, f32
|
||||
linalg.fill(%cC, %f0) : !column_major_C, f32
|
||||
linalg.fill(%cA, %f1) : !column_major_A, !elem_type_a
|
||||
linalg.fill(%cB, %f1) : !column_major_B, !elem_type_b
|
||||
linalg.fill(%cC, %f0) : !column_major_C, !elem_type_c
|
||||
|
||||
%c0 = constant 0: index
|
||||
%c1 = constant 1: index
|
||||
|
@ -83,7 +86,7 @@ func @main() {
|
|||
// This is accounts for about 10-15% perf hit on small sizes.
|
||||
// Once linalg on tensors is ready, fusing fill at teh register level will
|
||||
// be easy.
|
||||
linalg.fill(%C, %f0) : !row_major_C, f32
|
||||
linalg.fill(%C, %f0) : !row_major_C, !elem_type_c
|
||||
call @matmul_column_major_as_row_major(%cA, %cB, %cC, %A, %B, %C) :
|
||||
(!column_major_A, !column_major_B, !column_major_C,
|
||||
!row_major_A, !row_major_B, !row_major_C) -> ()
|
||||
|
@ -94,10 +97,10 @@ func @main() {
|
|||
|
||||
%res = load %cC[%c0, %c0]: !column_major_C
|
||||
// CHECK: 64
|
||||
vector.print %res: f32
|
||||
vector.print %res: !elem_type_c
|
||||
%res2 = load %C[%c0, %c0]: !row_major_C
|
||||
// CHECK: 64
|
||||
vector.print %res2: f32
|
||||
vector.print %res2: !elem_type_c
|
||||
|
||||
dealloc %A : !row_major_A
|
||||
dealloc %B : !row_major_B
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
// RUN: export M=24 && export K=64 && export N=192 && export ITERS=10 && \
|
||||
// RUN: cat %s | sed 's@${M}@'"$M"'@g'| sed 's@${K}@'"$K"'@g' | sed 's@${N}@'"$N"'@g'| sed 's@${ITERS}@'"$ITERS"'@g'| \
|
||||
// TODO: extend vectorization with interfaces so that it works with sexti
|
||||
// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.matmul_i8_i8_i32 register-tile-sizes=12,32,16" | \
|
||||
// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.fill register-tile-sizes=4,32 vectorize" | \
|
||||
// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.copy register-tile-sizes=4,32 vectorize" | \
|
||||
// RUN: mlir-opt -canonicalize -convert-vector-to-scf -lower-affine -convert-linalg-to-loops | \
|
||||
|
||||
// RUN: mlir-opt -canonicalize -convert-scf-to-std -convert-vector-to-llvm | \
|
||||
// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \
|
||||
// Activate to dump assembly
|
||||
// R_UN: -dump-object-file -object-filename=/tmp/a.o \
|
||||
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext | \
|
||||
// Use tee to both print to stderr and FileCheck
|
||||
// RUN: tee -a /dev/stderr | FileCheck %s
|
||||
|
||||
|
||||
!elem_type_a = type i8
|
||||
!elem_type_b = type i8
|
||||
!elem_type_c = type i32
|
||||
!row_major_A = type memref<${M}x${K}x!elem_type_a>
|
||||
!row_major_B = type memref<${K}x${N}x!elem_type_b>
|
||||
!row_major_C = type memref<${M}x${N}x!elem_type_c>
|
||||
|
||||
func @matmul(%a: !row_major_A, %b: !row_major_B, %c: !row_major_C)
|
||||
// TODO: activate manually for now.
|
||||
// attributes { passthrough = [["target-cpu", "skylake-avx512"], ["prefer-vector-width", "512"]]}
|
||||
{
|
||||
linalg.matmul_i8_i8_i32 ins(%a, %b : !row_major_A, !row_major_B)
|
||||
outs(%c: !row_major_C)
|
||||
return
|
||||
}
|
||||
|
||||
func @print_perf(%iters: index, %total_time: f64) {
|
||||
%c2 = constant 2 : index
|
||||
%cM = constant ${M} : index
|
||||
%cN = constant ${N} : index
|
||||
%cK = constant ${K} : index
|
||||
|
||||
%mn = muli %cM, %cN : index
|
||||
%mnk = muli %mn, %cK : index
|
||||
|
||||
// 2*M*N*K.
|
||||
%flops_per_iter = muli %c2, %mnk : index
|
||||
%flops = muli %iters, %flops_per_iter : index
|
||||
%flops_i64 = index_cast %flops : index to i64
|
||||
%flops_f = sitofp %flops_i64 : i64 to f64
|
||||
%flops_per_s = divf %flops_f, %total_time : f64
|
||||
vector.print %flops_per_s : f64
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func @main() {
|
||||
%v0 = constant 0 : !elem_type_c
|
||||
%v1 = constant 1 : !elem_type_a
|
||||
|
||||
%A = alloc() : !row_major_A
|
||||
%B = alloc() : !row_major_B
|
||||
%C = alloc() : !row_major_C
|
||||
|
||||
linalg.fill(%A, %v1) : !row_major_A, !elem_type_a
|
||||
linalg.fill(%B, %v1) : !row_major_B, !elem_type_b
|
||||
linalg.fill(%C, %v0) : !row_major_C, !elem_type_c
|
||||
|
||||
%c0 = constant 0: index
|
||||
%c1 = constant 1: index
|
||||
%iters = constant ${ITERS}: index
|
||||
|
||||
/// Run and dump performance for matmul.
|
||||
/// Preheating run:
|
||||
scf.for %arg0 = %c0 to %iters step %c1 {
|
||||
linalg.fill(%C, %v0) : !row_major_C, !elem_type_c
|
||||
call @matmul(%A, %B, %C) : (!row_major_A, !row_major_B, !row_major_C) -> ()
|
||||
}
|
||||
%t_start_matmul = call @rtclock() : () -> f64
|
||||
scf.for %arg0 = %c0 to %iters step %c1 {
|
||||
// linalg.matmul writes %C in place, need to reset it to zero every time.
|
||||
// This is accounts for about 10-15% perf hit on small sizes.
|
||||
// Once linalg on tensors is ready, fusing fill at teh register level will
|
||||
// be easy.
|
||||
linalg.fill(%C, %v0) : !row_major_C, !elem_type_c
|
||||
call @matmul(%A, %B, %C) : (!row_major_A, !row_major_B, !row_major_C) -> ()
|
||||
}
|
||||
%t_end_matmul = call @rtclock() : () -> f64
|
||||
%tmatmul = subf %t_end_matmul, %t_start_matmul: f64
|
||||
call @print_perf(%iters, %tmatmul) : (index, f64) -> ()
|
||||
|
||||
%res = load %C[%c0, %c0]: !row_major_C
|
||||
// CHECK: 64
|
||||
vector.print %res: !elem_type_c
|
||||
|
||||
dealloc %A : !row_major_A
|
||||
dealloc %B : !row_major_B
|
||||
dealloc %C : !row_major_C
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func private @rtclock() -> f64
|
||||
|
||||
// TODO: init with random, run and check output.
|
||||
// func private @fill_random_f32(memref<*xf32>)
|
|
@ -37,8 +37,10 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
|
|||
for (const std::unique_ptr<Transformation> &t : transformationSequence) {
|
||||
auto nextState = Identifier::get(std::to_string(++stepCount), context);
|
||||
auto marker = (currentState == zeroState)
|
||||
? linalg::LinalgMarker({}, nextState)
|
||||
: linalg::LinalgMarker(currentState, nextState);
|
||||
? linalg::LinalgTransformationFilter(
|
||||
t->filter, ArrayRef<Identifier>{}, nextState)
|
||||
: linalg::LinalgTransformationFilter(
|
||||
t->filter, currentState, nextState);
|
||||
stage1Patterns.emplace_back(t->buildRewritePatterns(context, marker));
|
||||
currentState = nextState;
|
||||
}
|
||||
|
@ -47,15 +49,17 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
|
|||
linalg::getLinalgTilingCanonicalizationPatterns(context);
|
||||
stage2Patterns.insert<AffineMinSCFCanonicalizationPattern>(context);
|
||||
|
||||
auto stage3Transforms = [](Operation *op) {
|
||||
auto stage3Transforms = [&](Operation *op) {
|
||||
// Some of these may be too aggressive as a stage 3 that is applied on each
|
||||
// stage 1 application and may have to be split out to post staged patterns
|
||||
// application (in which case they could just be passes, TBD).
|
||||
op->walk([&](LoopLikeOpInterface loopLike) {
|
||||
LLVM_DEBUG(loopLike.print(llvm::dbgs() << "\nOriginal loop:\n"));
|
||||
if (failed(moveLoopInvariantCode(loopLike)))
|
||||
llvm_unreachable("unexpected LICM failure");
|
||||
});
|
||||
if (enableLICM) {
|
||||
op->walk([&](LoopLikeOpInterface loopLike) {
|
||||
LLVM_DEBUG(loopLike.print(llvm::dbgs() << "\nOriginal loop:\n"));
|
||||
if (failed(moveLoopInvariantCode(loopLike)))
|
||||
llvm_unreachable("unexpected LICM failure");
|
||||
});
|
||||
}
|
||||
promoteSingleIterationLoops(cast<FuncOp>(op));
|
||||
hoistViewAllocOps(cast<FuncOp>(op));
|
||||
hoistRedundantVectorTransfers(cast<FuncOp>(op));
|
||||
|
|
|
@ -63,7 +63,8 @@ namespace {
|
|||
// into auto-generated ones.
|
||||
template <typename ConcretePattern, typename RootOp>
|
||||
struct LinalgGeneralizationPattern : OpRewritePattern<RootOp> {
|
||||
LinalgGeneralizationPattern(MLIRContext *context, linalg::LinalgMarker marker,
|
||||
LinalgGeneralizationPattern(MLIRContext *context,
|
||||
linalg::LinalgTransformationFilter marker,
|
||||
PatternBenefit benefit = 1)
|
||||
: OpRewritePattern<RootOp>(context, benefit), marker(std::move(marker)) {}
|
||||
|
||||
|
@ -81,12 +82,13 @@ struct LinalgGeneralizationPattern : OpRewritePattern<RootOp> {
|
|||
return failure();
|
||||
|
||||
rewriter.replaceOp(rootOp, genericOp.getResults());
|
||||
marker.replaceLinalgMarker(rewriter, genericOp.getOperation());
|
||||
marker.replaceLinalgTransformationFilter(rewriter,
|
||||
genericOp.getOperation());
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
linalg::LinalgMarker marker;
|
||||
linalg::LinalgTransformationFilter marker;
|
||||
};
|
||||
|
||||
struct GeneralizeConvOp
|
||||
|
@ -100,7 +102,7 @@ struct GeneralizeConvOp
|
|||
/// linalg.generic.
|
||||
struct LinalgNamedOpGeneralizationPattern : RewritePattern {
|
||||
LinalgNamedOpGeneralizationPattern(MLIRContext *context,
|
||||
linalg::LinalgMarker marker,
|
||||
linalg::LinalgTransformationFilter marker,
|
||||
PatternBenefit benefit = 1)
|
||||
: RewritePattern(benefit, MatchAnyOpTypeTag()),
|
||||
marker(std::move(marker)) {}
|
||||
|
@ -123,12 +125,13 @@ struct LinalgNamedOpGeneralizationPattern : RewritePattern {
|
|||
return failure();
|
||||
|
||||
rewriter.replaceOp(rootOp, genericOp.getResults());
|
||||
marker.replaceLinalgMarker(rewriter, genericOp.getOperation());
|
||||
marker.replaceLinalgTransformationFilter(rewriter,
|
||||
genericOp.getOperation());
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
linalg::LinalgMarker marker;
|
||||
linalg::LinalgTransformationFilter marker;
|
||||
};
|
||||
|
||||
struct LinalgGeneralizationPass
|
||||
|
@ -165,13 +168,13 @@ linalg::GenericOp GeneralizeConvOp::createGenericOp(linalg::ConvOp convOp,
|
|||
|
||||
void mlir::linalg::populateLinalgConvGeneralizationPatterns(
|
||||
MLIRContext *context, OwningRewritePatternList &patterns,
|
||||
linalg::LinalgMarker marker) {
|
||||
linalg::LinalgTransformationFilter marker) {
|
||||
patterns.insert<GeneralizeConvOp>(context, marker);
|
||||
}
|
||||
|
||||
void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
|
||||
MLIRContext *context, OwningRewritePatternList &patterns,
|
||||
linalg::LinalgMarker marker) {
|
||||
linalg::LinalgTransformationFilter marker) {
|
||||
patterns.insert<LinalgNamedOpGeneralizationPattern>(context, marker);
|
||||
}
|
||||
|
||||
|
|
|
@ -536,7 +536,9 @@ public:
|
|||
static void insert(OwningRewritePatternList &patterns,
|
||||
const LinalgTilingOptions &options, MLIRContext *ctx) {
|
||||
patterns.insert<LinalgTilingPattern<OpTy>>(
|
||||
ctx, options, LinalgMarker({}, Identifier::get("tiled", ctx)));
|
||||
ctx, options,
|
||||
LinalgTransformationFilter(ArrayRef<Identifier>{},
|
||||
Identifier::get("tiled", ctx)));
|
||||
RewritePatternList<OpTypes...>::insert(patterns, options, ctx);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -46,14 +46,23 @@ using namespace mlir::linalg;
|
|||
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
|
||||
"__internal_linalg_transform__";
|
||||
|
||||
mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<Identifier> matchDisjunction,
|
||||
Optional<Identifier> replacement)
|
||||
: matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
|
||||
mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
|
||||
ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement)
|
||||
: LinalgTransformationFilter([](Operation *) { return success(); },
|
||||
matchDisjunction, replacement) {}
|
||||
|
||||
mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
|
||||
FilterFunction f, ArrayRef<Identifier> matchDisjunction,
|
||||
Optional<Identifier> replacement)
|
||||
: filter(f),
|
||||
matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
|
||||
replacement(replacement) {}
|
||||
|
||||
LogicalResult
|
||||
mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
|
||||
Operation *op) const {
|
||||
LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
|
||||
PatternRewriter &rewriter, Operation *op) const {
|
||||
if (filter && failed(filter(op)))
|
||||
return failure();
|
||||
|
||||
auto attr = op->template getAttrOfType<StringAttr>(
|
||||
LinalgTransforms::kLinalgTransformMarker);
|
||||
|
||||
|
@ -81,8 +90,9 @@ mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
|
|||
});
|
||||
}
|
||||
|
||||
void mlir::linalg::LinalgMarker::replaceLinalgMarker(PatternRewriter &rewriter,
|
||||
Operation *op) const {
|
||||
void mlir::linalg::LinalgTransformationFilter::
|
||||
replaceLinalgTransformationFilter(PatternRewriter &rewriter,
|
||||
Operation *op) const {
|
||||
if (replacement.hasValue())
|
||||
op->setAttr(LinalgTransforms::kLinalgTransformMarker,
|
||||
rewriter.getStringAttr(replacement.getValue()));
|
||||
|
@ -219,12 +229,13 @@ static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
|
|||
/// Linalg base tiling pattern.
|
||||
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
|
||||
StringRef opName, MLIRContext *context, LinalgTilingOptions options,
|
||||
LinalgMarker marker, PatternBenefit benefit)
|
||||
LinalgTransformationFilter marker, PatternBenefit benefit)
|
||||
: RewritePattern(opName, {}, benefit, context), marker(marker),
|
||||
options(options) {}
|
||||
|
||||
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
|
||||
LinalgTilingOptions options, LinalgMarker marker, PatternBenefit benefit)
|
||||
LinalgTilingOptions options, LinalgTransformationFilter marker,
|
||||
PatternBenefit benefit)
|
||||
: RewritePattern(benefit, MatchAnyOpTypeTag()), marker(marker),
|
||||
options(options) {}
|
||||
|
||||
|
@ -250,9 +261,9 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
|
|||
// Return relevant information to derived pattern.
|
||||
result = *res;
|
||||
// Replace marker on both tiledOp and tiledAndPaddedOp, if necessary.
|
||||
marker.replaceLinalgMarker(rewriter, tiledOp);
|
||||
marker.replaceLinalgTransformationFilter(rewriter, tiledOp);
|
||||
if (tiledOp != res->op)
|
||||
marker.replaceLinalgMarker(rewriter, res->op);
|
||||
marker.replaceLinalgTransformationFilter(rewriter, res->op);
|
||||
});
|
||||
|
||||
// Consider padding on the fly only if the op has tensor semantics.
|
||||
|
@ -276,8 +287,8 @@ mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
|
|||
StringRef opName, MLIRContext *context,
|
||||
const LinalgDependenceGraph &dependenceGraph,
|
||||
LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
|
||||
LinalgMarker marker, LinalgMarker fusedOpMarker,
|
||||
LinalgMarker originalOpMarker, PatternBenefit benefit)
|
||||
LinalgTransformationFilter marker, LinalgTransformationFilter fusedOpMarker,
|
||||
LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
|
||||
: RewritePattern(opName, {}, benefit, context),
|
||||
dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
|
||||
fusionOptions(fusionOptions), marker(marker),
|
||||
|
@ -352,23 +363,26 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
|
|||
tiledAndFusedOps->op = unfusedTiledOp->op;
|
||||
}
|
||||
|
||||
marker.replaceLinalgMarker(rewriter, tiledAndFusedOps->op.getOperation());
|
||||
marker.replaceLinalgTransformationFilter(rewriter,
|
||||
tiledAndFusedOps->op.getOperation());
|
||||
for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
|
||||
fusedOpMarker.replaceLinalgMarker(rewriter, fusedOp.getOperation());
|
||||
fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
|
||||
fusedOp.getOperation());
|
||||
}
|
||||
for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
|
||||
originalOpMarker.replaceLinalgMarker(rewriter,
|
||||
origProducerOp.getOperation());
|
||||
originalOpMarker.replaceLinalgTransformationFilter(
|
||||
rewriter, origProducerOp.getOperation());
|
||||
}
|
||||
rewriter.updateRootInPlace(
|
||||
op, [&]() { originalOpMarker.replaceLinalgMarker(rewriter, op); });
|
||||
rewriter.updateRootInPlace(op, [&]() {
|
||||
originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
|
||||
});
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Linalg base interchange pattern.
|
||||
mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
|
||||
StringRef opName, MLIRContext *context,
|
||||
ArrayRef<unsigned> interchangeVector, LinalgMarker marker,
|
||||
ArrayRef<unsigned> interchangeVector, LinalgTransformationFilter marker,
|
||||
PatternBenefit benefit)
|
||||
: RewritePattern(opName, {}, benefit, context), marker(marker),
|
||||
interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
|
||||
|
@ -388,14 +402,14 @@ LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
|
|||
rewriter.updateRootInPlace(op, [&]() {
|
||||
interchange(linalgOp, interchangeVector);
|
||||
// New marker if specified.
|
||||
marker.replaceLinalgMarker(rewriter, op);
|
||||
marker.replaceLinalgTransformationFilter(rewriter, op);
|
||||
});
|
||||
return success();
|
||||
}
|
||||
|
||||
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
|
||||
StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
|
||||
LinalgMarker marker, PatternBenefit benefit)
|
||||
LinalgTransformationFilter marker, PatternBenefit benefit)
|
||||
: RewritePattern(opName, {}, benefit, context), marker(marker),
|
||||
options(options) {}
|
||||
|
||||
|
@ -417,12 +431,12 @@ LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
|
|||
return op->emitError("subview promotion failed");
|
||||
}
|
||||
rewriter.finalizeRootUpdate(op);
|
||||
marker.replaceLinalgMarker(rewriter, op);
|
||||
marker.replaceLinalgTransformationFilter(rewriter, op);
|
||||
return success();
|
||||
}
|
||||
|
||||
mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
|
||||
StringRef opName, MLIRContext *context, LinalgMarker marker,
|
||||
StringRef opName, MLIRContext *context, LinalgTransformationFilter marker,
|
||||
PatternBenefit benefit)
|
||||
: RewritePattern(opName, {}, benefit, context), marker(marker) {}
|
||||
|
||||
|
|
|
@ -607,12 +607,13 @@ populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns,
|
|||
constexpr static StringRef kPromotedMarker = "PROMOTED";
|
||||
tilingPatterns.insert<LinalgTilingPattern<ConvOp>>(
|
||||
context, LinalgTilingOptions().setTileSizes(tileSizes),
|
||||
LinalgMarker({}, Identifier::get(kTiledMarker, context)));
|
||||
LinalgTransformationFilter(ArrayRef<Identifier>{},
|
||||
Identifier::get(kTiledMarker, context)));
|
||||
|
||||
promotionPatterns.insert<LinalgPromotionPattern<ConvOp>>(
|
||||
context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
|
||||
LinalgMarker(Identifier::get(kTiledMarker, context),
|
||||
Identifier::get(kPromotedMarker, context)));
|
||||
LinalgTransformationFilter(Identifier::get(kTiledMarker, context),
|
||||
Identifier::get(kPromotedMarker, context)));
|
||||
|
||||
SmallVector<bool, 4> mask(N);
|
||||
int offset = tileSizes.size() - N;
|
||||
|
|
|
@ -107,8 +107,12 @@ struct TestLinalgCodegenStrategy
|
|||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
template <typename LinalgNamedOp>
|
||||
void TestLinalgCodegenStrategy::applyStrategyToNamedLinalgOp() {
|
||||
/// Apply transformations specified as patterns.
|
||||
void TestLinalgCodegenStrategy::runOnFunction() {
|
||||
linalg::LinalgTransformationFilter::FilterFunction filterOpName =
|
||||
[&](Operation *op) -> LogicalResult {
|
||||
return success(op->getName().getStringRef() == anchorOpName);
|
||||
};
|
||||
LinalgTilingOptions tilingOptions;
|
||||
if (!tileSizes.empty())
|
||||
tilingOptions = tilingOptions.setTileSizes(tileSizes);
|
||||
|
@ -134,19 +138,20 @@ void TestLinalgCodegenStrategy::applyStrategyToNamedLinalgOp() {
|
|||
.Default(vector::VectorTransferSplit::None);
|
||||
|
||||
CodegenStrategy strategy;
|
||||
strategy.template tileIf<LinalgNamedOp>(!tileSizes.empty(), tilingOptions)
|
||||
.template promoteIf<LinalgNamedOp>(
|
||||
promote, LinalgPromotionOptions()
|
||||
.setAlignment(16)
|
||||
.setUseFullTileBuffersByDefault(promoteFullTile))
|
||||
.template tileIf<LinalgNamedOp>(!registerTileSizes.empty(),
|
||||
registerTilingOptions)
|
||||
.template promoteIf<LinalgNamedOp>(
|
||||
registerPromote,
|
||||
strategy.tileIf<LinalgOp>(!tileSizes.empty(), anchorOpName, tilingOptions)
|
||||
.promoteIf<LinalgOp>(promote, anchorOpName,
|
||||
LinalgPromotionOptions()
|
||||
.setAlignment(16)
|
||||
.setUseFullTileBuffersByDefault(promoteFullTile),
|
||||
filterOpName)
|
||||
.tileIf<LinalgOp>(!registerTileSizes.empty(), anchorOpName,
|
||||
registerTilingOptions)
|
||||
.promoteIf<LinalgOp>(
|
||||
registerPromote, anchorOpName,
|
||||
LinalgPromotionOptions()
|
||||
.setAlignment(16)
|
||||
.setUseFullTileBuffersByDefault(registerPromoteFullTile))
|
||||
.template vectorizeIf<LinalgNamedOp>(vectorize)
|
||||
.vectorizeIf<LinalgOp>(vectorize, anchorOpName)
|
||||
.setVectorTransformsOptions(
|
||||
vector::VectorTransformsOptions()
|
||||
.setVectorTransformsOptions(vectorContractLowering)
|
||||
|
@ -156,20 +161,6 @@ void TestLinalgCodegenStrategy::applyStrategyToNamedLinalgOp() {
|
|||
strategy.transform(getFunction());
|
||||
}
|
||||
|
||||
/// Apply transformations specified as patterns.
|
||||
void TestLinalgCodegenStrategy::runOnFunction() {
|
||||
if (anchorOpName == MatmulOp::getOperationName())
|
||||
applyStrategyToNamedLinalgOp<MatmulOp>();
|
||||
else if (anchorOpName == MatmulColumnMajorOp::getOperationName())
|
||||
applyStrategyToNamedLinalgOp<MatmulColumnMajorOp>();
|
||||
else if (anchorOpName == CopyOp::getOperationName())
|
||||
applyStrategyToNamedLinalgOp<CopyOp>();
|
||||
else if (anchorOpName == FillOp::getOperationName())
|
||||
applyStrategyToNamedLinalgOp<FillOp>();
|
||||
else
|
||||
llvm_unreachable("Unsupported anchor op");
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestLinalgCodegenStrategy() {
|
||||
|
|
|
@ -45,12 +45,15 @@ static void fillFusionPatterns(MLIRContext *context,
|
|||
.setTileSizes({32, 64, 16})
|
||||
.setLoopType(LinalgTilingLoopType::ParallelLoops),
|
||||
LinalgFusionOptions().setIndicesToFuse({2}),
|
||||
LinalgMarker(Identifier::get("basic_fusion", context),
|
||||
Identifier::get("after_basic_fusion", context)),
|
||||
LinalgMarker(ArrayRef<Identifier>(),
|
||||
Identifier::get("after_basic_fusion_producer", context)),
|
||||
LinalgMarker(ArrayRef<Identifier>(),
|
||||
Identifier::get("after_basic_fusion_original", context)));
|
||||
LinalgTransformationFilter(
|
||||
Identifier::get("basic_fusion", context),
|
||||
Identifier::get("after_basic_fusion", context)),
|
||||
LinalgTransformationFilter(
|
||||
ArrayRef<Identifier>(),
|
||||
Identifier::get("after_basic_fusion_producer", context)),
|
||||
LinalgTransformationFilter(
|
||||
ArrayRef<Identifier>(),
|
||||
Identifier::get("after_basic_fusion_original", context)));
|
||||
|
||||
patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
|
||||
context, dependenceGraph,
|
||||
|
@ -58,12 +61,14 @@ static void fillFusionPatterns(MLIRContext *context,
|
|||
.setTileSizes({32, 64, 16})
|
||||
.setLoopType(LinalgTilingLoopType::ParallelLoops),
|
||||
LinalgFusionOptions().setIndicesToFuse({0}),
|
||||
LinalgMarker(Identifier::get("lhs_fusion", context),
|
||||
Identifier::get("after_lhs_fusion", context)),
|
||||
LinalgMarker(ArrayRef<Identifier>(),
|
||||
Identifier::get("after_lhs_fusion_producer", context)),
|
||||
LinalgMarker(ArrayRef<Identifier>(),
|
||||
Identifier::get("after_lhs_fusion_original", context)));
|
||||
LinalgTransformationFilter(Identifier::get("lhs_fusion", context),
|
||||
Identifier::get("after_lhs_fusion", context)),
|
||||
LinalgTransformationFilter(
|
||||
ArrayRef<Identifier>(),
|
||||
Identifier::get("after_lhs_fusion_producer", context)),
|
||||
LinalgTransformationFilter(
|
||||
ArrayRef<Identifier>(),
|
||||
Identifier::get("after_lhs_fusion_original", context)));
|
||||
|
||||
patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
|
||||
context, dependenceGraph,
|
||||
|
@ -71,12 +76,14 @@ static void fillFusionPatterns(MLIRContext *context,
|
|||
.setTileSizes({32, 64, 16})
|
||||
.setLoopType(LinalgTilingLoopType::ParallelLoops),
|
||||
LinalgFusionOptions().setIndicesToFuse({1}),
|
||||
LinalgMarker(Identifier::get("rhs_fusion", context),
|
||||
Identifier::get("after_rhs_fusion", context)),
|
||||
LinalgMarker(ArrayRef<Identifier>(),
|
||||
Identifier::get("after_rhs_fusion_producer", context)),
|
||||
LinalgMarker(ArrayRef<Identifier>(),
|
||||
Identifier::get("after_rhs_fusion_original", context)));
|
||||
LinalgTransformationFilter(Identifier::get("rhs_fusion", context),
|
||||
Identifier::get("after_rhs_fusion", context)),
|
||||
LinalgTransformationFilter(
|
||||
ArrayRef<Identifier>(),
|
||||
Identifier::get("after_rhs_fusion_producer", context)),
|
||||
LinalgTransformationFilter(
|
||||
ArrayRef<Identifier>(),
|
||||
Identifier::get("after_rhs_fusion_original", context)));
|
||||
|
||||
patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
|
||||
context, dependenceGraph,
|
||||
|
@ -84,12 +91,13 @@ static void fillFusionPatterns(MLIRContext *context,
|
|||
.setTileSizes({32, 64, 16})
|
||||
.setLoopType(LinalgTilingLoopType::ParallelLoops),
|
||||
LinalgFusionOptions().setIndicesToFuse({0, 2}),
|
||||
LinalgMarker(Identifier::get("two_operand_fusion", context),
|
||||
Identifier::get("after_two_operand_fusion", context)),
|
||||
LinalgMarker(
|
||||
LinalgTransformationFilter(
|
||||
Identifier::get("two_operand_fusion", context),
|
||||
Identifier::get("after_two_operand_fusion", context)),
|
||||
LinalgTransformationFilter(
|
||||
ArrayRef<Identifier>(),
|
||||
Identifier::get("after_two_operand_fusion_producer", context)),
|
||||
LinalgMarker(
|
||||
LinalgTransformationFilter(
|
||||
ArrayRef<Identifier>(),
|
||||
Identifier::get("after_two_operand_fusion_original", context)));
|
||||
|
||||
|
@ -98,11 +106,13 @@ static void fillFusionPatterns(MLIRContext *context,
|
|||
LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(
|
||||
LinalgTilingLoopType::ParallelLoops),
|
||||
LinalgFusionOptions().setIndicesToFuse({0, 1}),
|
||||
LinalgMarker(Identifier::get("transpose_fusion", context),
|
||||
Identifier::get("after_transpose_fusion", context)),
|
||||
LinalgMarker(ArrayRef<Identifier>(),
|
||||
Identifier::get("after_transpose_fusion_producer", context)),
|
||||
LinalgMarker(
|
||||
LinalgTransformationFilter(
|
||||
Identifier::get("transpose_fusion", context),
|
||||
Identifier::get("after_transpose_fusion", context)),
|
||||
LinalgTransformationFilter(
|
||||
ArrayRef<Identifier>(),
|
||||
Identifier::get("after_transpose_fusion_producer", context)),
|
||||
LinalgTransformationFilter(
|
||||
ArrayRef<Identifier>(),
|
||||
Identifier::get("after_transpose_fusion_original", context)));
|
||||
}
|
||||
|
|
|
@ -98,29 +98,35 @@ static void applyPatterns(FuncOp funcOp) {
|
|||
//===--------------------------------------------------------------------===//
|
||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
|
||||
LinalgMarker(Identifier::get("MEM", ctx), Identifier::get("L3", ctx)));
|
||||
LinalgTransformationFilter(Identifier::get("MEM", ctx),
|
||||
Identifier::get("L3", ctx)));
|
||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
|
||||
LinalgMarker(Identifier::get("L3", ctx), Identifier::get("L2", ctx)));
|
||||
LinalgTransformationFilter(Identifier::get("L3", ctx),
|
||||
Identifier::get("L2", ctx)));
|
||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
|
||||
LinalgMarker(Identifier::get("L2", ctx), Identifier::get("L1", ctx)));
|
||||
LinalgTransformationFilter(Identifier::get("L2", ctx),
|
||||
Identifier::get("L1", ctx)));
|
||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
|
||||
LinalgMarker(Identifier::get("L1", ctx), Identifier::get("REG", ctx)));
|
||||
LinalgTransformationFilter(Identifier::get("L1", ctx),
|
||||
Identifier::get("REG", ctx)));
|
||||
|
||||
patterns.insert<LinalgTilingPattern<MatvecOp>>(
|
||||
ctx,
|
||||
LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
|
||||
LinalgTilingLoopType::ParallelLoops),
|
||||
LinalgMarker({}, Identifier::get("L1", ctx)));
|
||||
LinalgTransformationFilter(ArrayRef<Identifier>{},
|
||||
Identifier::get("L1", ctx)));
|
||||
|
||||
patterns.insert<LinalgTilingPattern<DotOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes(8000),
|
||||
LinalgMarker(ArrayRef<Identifier>{Identifier::get("MEM", ctx),
|
||||
Identifier::get("L3", ctx),
|
||||
Identifier::get("L2", ctx)},
|
||||
Identifier::get("REG", ctx)));
|
||||
LinalgTransformationFilter(
|
||||
ArrayRef<Identifier>{Identifier::get("MEM", ctx),
|
||||
Identifier::get("L3", ctx),
|
||||
Identifier::get("L2", ctx)},
|
||||
Identifier::get("REG", ctx)));
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Linalg tiling and permutation patterns.
|
||||
|
@ -130,24 +136,24 @@ static void applyPatterns(FuncOp funcOp) {
|
|||
LinalgTilingOptions()
|
||||
.setTileSizes({2000, 3000, 4000})
|
||||
.setInterchange({1, 2, 0}),
|
||||
LinalgMarker(Identifier::get("__with_perm__", ctx),
|
||||
Identifier::get("L2__with_perm__", ctx)));
|
||||
LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
|
||||
Identifier::get("L2__with_perm__", ctx)));
|
||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx,
|
||||
LinalgTilingOptions()
|
||||
.setTileSizes({200, 300, 400})
|
||||
.setInterchange({1, 0, 2}),
|
||||
LinalgMarker(Identifier::get("L2__with_perm__", ctx),
|
||||
Identifier::get("L1__with_perm__", ctx)));
|
||||
LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx),
|
||||
Identifier::get("L1__with_perm__", ctx)));
|
||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
|
||||
LinalgMarker(Identifier::get("L1__with_perm__", ctx),
|
||||
Identifier::get("REG__with_perm__", ctx)));
|
||||
LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx),
|
||||
Identifier::get("REG__with_perm__", ctx)));
|
||||
|
||||
patterns.insert<LinalgTilingPattern<MatvecOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
|
||||
LinalgMarker(Identifier::get("__with_perm__", ctx),
|
||||
Identifier::get("L1__with_perm__", ctx)));
|
||||
LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
|
||||
Identifier::get("L1__with_perm__", ctx)));
|
||||
|
||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx,
|
||||
|
@ -155,8 +161,9 @@ static void applyPatterns(FuncOp funcOp) {
|
|||
.setTileSizes({16, 8, 4})
|
||||
.setInterchange({1, 2, 0})
|
||||
.setLoopType(LinalgTilingLoopType::ParallelLoops),
|
||||
LinalgMarker(Identifier::get("par__with_perm__", ctx),
|
||||
Identifier::get("after_par__with_perm__", ctx)));
|
||||
LinalgTransformationFilter(
|
||||
Identifier::get("par__with_perm__", ctx),
|
||||
Identifier::get("after_par__with_perm__", ctx)));
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Linalg to loops patterns.
|
||||
|
@ -164,7 +171,7 @@ static void applyPatterns(FuncOp funcOp) {
|
|||
patterns.insert<LinalgLoweringPattern<DotOp>>(
|
||||
ctx,
|
||||
/*loweringType=*/LinalgLoweringType::Loops,
|
||||
LinalgMarker(Identifier::get("REG", ctx)));
|
||||
LinalgTransformationFilter(Identifier::get("REG", ctx)));
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Linalg distribution patterns.
|
||||
|
@ -178,7 +185,8 @@ static void applyPatterns(FuncOp funcOp) {
|
|||
LinalgVectorizationPattern<FillOp>,
|
||||
LinalgVectorizationPattern<CopyOp>,
|
||||
LinalgVectorizationPattern<GenericOp>>(
|
||||
ctx, LinalgMarker(Identifier::get("VECTORIZE", ctx)));
|
||||
ctx, LinalgVectorizationOptions(),
|
||||
LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx)));
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Linalg generic permutation patterns.
|
||||
|
@ -186,34 +194,38 @@ static void applyPatterns(FuncOp funcOp) {
|
|||
patterns.insert<LinalgInterchangePattern<GenericOp>>(
|
||||
ctx,
|
||||
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
|
||||
LinalgMarker({}, Identifier::get("PERMUTED", ctx)));
|
||||
LinalgTransformationFilter(ArrayRef<Identifier>{},
|
||||
Identifier::get("PERMUTED", ctx)));
|
||||
patterns.insert<LinalgInterchangePattern<IndexedGenericOp>>(
|
||||
ctx,
|
||||
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
|
||||
LinalgMarker({}, Identifier::get("PERMUTED", ctx)));
|
||||
LinalgTransformationFilter(ArrayRef<Identifier>{},
|
||||
Identifier::get("PERMUTED", ctx)));
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Linalg subview operands promotion.
|
||||
//===--------------------------------------------------------------------===//
|
||||
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
|
||||
ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
|
||||
LinalgMarker(Identifier::get("_promote_views_", ctx),
|
||||
Identifier::get("_views_promoted_", ctx)));
|
||||
LinalgTransformationFilter(Identifier::get("_promote_views_", ctx),
|
||||
Identifier::get("_views_promoted_", ctx)));
|
||||
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
|
||||
ctx,
|
||||
LinalgPromotionOptions()
|
||||
.setOperandsToPromote({0})
|
||||
.setUseFullTileBuffersByDefault(true),
|
||||
LinalgMarker(Identifier::get("_promote_first_view_", ctx),
|
||||
Identifier::get("_first_view_promoted_", ctx)));
|
||||
LinalgTransformationFilter(
|
||||
Identifier::get("_promote_first_view_", ctx),
|
||||
Identifier::get("_first_view_promoted_", ctx)));
|
||||
patterns.insert<LinalgPromotionPattern<FillOp>>(
|
||||
ctx,
|
||||
LinalgPromotionOptions()
|
||||
.setOperandsToPromote({0})
|
||||
.setUseFullTileBuffers({true})
|
||||
.setAlignment(32),
|
||||
LinalgMarker(Identifier::get("_promote_views_aligned_", ctx),
|
||||
Identifier::get("_views_aligned_promoted_", ctx)));
|
||||
LinalgTransformationFilter(
|
||||
Identifier::get("_promote_views_aligned_", ctx),
|
||||
Identifier::get("_views_aligned_promoted_", ctx)));
|
||||
|
||||
applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
|
||||
|
||||
|
@ -230,18 +242,19 @@ static void fillL1TilingAndMatmulToVectorPatterns(
|
|||
patternsVector.emplace_back(std::make_unique<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx,
|
||||
LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2}),
|
||||
LinalgMarker(Identifier::get(startMarker, ctx),
|
||||
Identifier::get("L1", ctx))));
|
||||
LinalgTransformationFilter(Identifier::get(startMarker, ctx),
|
||||
Identifier::get("L1", ctx))));
|
||||
|
||||
patternsVector.emplace_back(
|
||||
std::make_unique<LinalgPromotionPattern<MatmulOp>>(
|
||||
ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
|
||||
LinalgMarker(Identifier::get("L1", ctx),
|
||||
Identifier::get("VEC", ctx))));
|
||||
LinalgTransformationFilter(Identifier::get("L1", ctx),
|
||||
Identifier::get("VEC", ctx))));
|
||||
|
||||
patternsVector.emplace_back(
|
||||
std::make_unique<LinalgVectorizationPattern<MatmulOp>>(
|
||||
ctx, LinalgMarker(Identifier::get("VEC", ctx))));
|
||||
ctx, LinalgVectorizationOptions(),
|
||||
LinalgTransformationFilter(Identifier::get("VEC", ctx))));
|
||||
patternsVector.back()
|
||||
.insert<LinalgVectorizationPattern<FillOp>,
|
||||
LinalgVectorizationPattern<CopyOp>>(ctx);
|
||||
|
@ -289,8 +302,8 @@ static void fillPromotionCallBackPatterns(MLIRContext *ctx,
|
|||
OwningRewritePatternList &patterns) {
|
||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
|
||||
LinalgMarker(Identifier::get("START", ctx),
|
||||
Identifier::get("PROMOTE", ctx)));
|
||||
LinalgTransformationFilter(Identifier::get("START", ctx),
|
||||
Identifier::get("PROMOTE", ctx)));
|
||||
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
|
||||
ctx,
|
||||
LinalgPromotionOptions()
|
||||
|
@ -306,7 +319,7 @@ static void fillPromotionCallBackPatterns(MLIRContext *ctx,
|
|||
copyCallBackFn(b, src, dst, true);
|
||||
return success();
|
||||
}),
|
||||
LinalgMarker(Identifier::get("PROMOTE", ctx)));
|
||||
LinalgTransformationFilter(Identifier::get("PROMOTE", ctx)));
|
||||
}
|
||||
|
||||
template <typename IdOp, typename NProcsOp>
|
||||
|
@ -335,8 +348,9 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
|
|||
.setTileSizes({8, 8, 4})
|
||||
.setLoopType(LinalgTilingLoopType::ParallelLoops)
|
||||
.setDistributionOptions(cyclicNprocsEqNiters),
|
||||
LinalgMarker(Identifier::get("distribute1", context),
|
||||
Identifier::get("after_distribute1", context)));
|
||||
LinalgTransformationFilter(
|
||||
Identifier::get("distribute1", context),
|
||||
Identifier::get("after_distribute1", context)));
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -351,8 +365,9 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
|
|||
.setTileSizes({8, 8, 4})
|
||||
.setLoopType(LinalgTilingLoopType::ParallelLoops)
|
||||
.setDistributionOptions(cyclicNprocsGeNiters),
|
||||
LinalgMarker(Identifier::get("distribute2", context),
|
||||
Identifier::get("after_distribute2", context)));
|
||||
LinalgTransformationFilter(
|
||||
Identifier::get("distribute2", context),
|
||||
Identifier::get("after_distribute2", context)));
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -367,8 +382,9 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
|
|||
.setTileSizes({8, 8, 4})
|
||||
.setLoopType(LinalgTilingLoopType::ParallelLoops)
|
||||
.setDistributionOptions(cyclicNprocsDefault),
|
||||
LinalgMarker(Identifier::get("distribute3", context),
|
||||
Identifier::get("after_distribute3", context)));
|
||||
LinalgTransformationFilter(
|
||||
Identifier::get("distribute3", context),
|
||||
Identifier::get("after_distribute3", context)));
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -383,8 +399,9 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
|
|||
.setTileSizes({8, 8, 4})
|
||||
.setLoopType(LinalgTilingLoopType::ParallelLoops)
|
||||
.setDistributionOptions(cyclicNprocsMixed1),
|
||||
LinalgMarker(Identifier::get("distribute4", context),
|
||||
Identifier::get("after_distribute4", context)));
|
||||
LinalgTransformationFilter(
|
||||
Identifier::get("distribute4", context),
|
||||
Identifier::get("after_distribute4", context)));
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -399,8 +416,9 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
|
|||
.setTileSizes({8, 8, 4})
|
||||
.setLoopType(LinalgTilingLoopType::ParallelLoops)
|
||||
.setDistributionOptions(cyclicNprocsMixed2),
|
||||
LinalgMarker(Identifier::get("distribute5", context),
|
||||
Identifier::get("after_distribute5", context)));
|
||||
LinalgTransformationFilter(
|
||||
Identifier::get("distribute5", context),
|
||||
Identifier::get("after_distribute5", context)));
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -416,8 +434,9 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
|
|||
.setTileSizes({8, 8, 4})
|
||||
.setLoopType(LinalgTilingLoopType::ParallelLoops)
|
||||
.setDistributionOptions(cyclicNprocsMixed3),
|
||||
LinalgMarker(Identifier::get("distribute6", context),
|
||||
Identifier::get("after_distribute6", context)));
|
||||
LinalgTransformationFilter(
|
||||
Identifier::get("distribute6", context),
|
||||
Identifier::get("after_distribute6", context)));
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -432,8 +451,9 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
|
|||
.setTileSizes({8, 8, 4})
|
||||
.setLoopType(LinalgTilingLoopType::Loops)
|
||||
.setDistributionOptions(cyclicNprocsEqNiters),
|
||||
LinalgMarker(Identifier::get("tensors_distribute1", context),
|
||||
Identifier::get("tensors_after_distribute1", context)));
|
||||
LinalgTransformationFilter(
|
||||
Identifier::get("tensors_distribute1", context),
|
||||
Identifier::get("tensors_after_distribute1", context)));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -452,8 +472,8 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
|
|||
LinalgTilingOptions()
|
||||
.setTileSizes({768, 264, 768})
|
||||
.setInterchange({1, 2, 0}),
|
||||
LinalgMarker(Identifier::get("START", ctx),
|
||||
Identifier::get("L2", ctx))));
|
||||
LinalgTransformationFilter(Identifier::get("START", ctx),
|
||||
Identifier::get("L2", ctx))));
|
||||
fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
|
||||
stage1Patterns);
|
||||
}
|
||||
|
@ -511,7 +531,8 @@ static void applyTileAndPadPattern(FuncOp funcOp) {
|
|||
.setPaddingValueComputationFunction(getNeutralOfLinalgOp);
|
||||
tilingPattern.insert<linalg::LinalgTilingPattern<linalg::MatmulOp>>(
|
||||
context, linalgTilingOptions,
|
||||
linalg::LinalgMarker(Identifier::get("tile-and-pad", context)));
|
||||
linalg::LinalgTransformationFilter(
|
||||
Identifier::get("tile-and-pad", context)));
|
||||
applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue