From 307cfdf5338641e3a895857ef02dc9da35cd0eb6 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Sat, 2 May 2020 01:03:37 -0400 Subject: [PATCH] [mlir][Linalg] Mostly NFC - Refactor Linalg patterns and transformations. Linalg transformations are currently exposed as DRRs. Unfortunately RewriterGen does not play well with the line of work on named linalg ops which require variadic operands and results. Additionally, DRR is arguably not the right abstraction to expose compositions of such patterns that don't rely on SSA use-def semantics. This revision abandons DRRs and exposes manually written C++ patterns. Refactorings and cleanups are performed to uniformize APIs. This refactoring will allow replacing the currently manually specified Linalg named ops. A collateral victim of this refactoring is the `tileAndFuse` DRR, and the one associated test, which will be revived at a later time. Lastly, the following 2 tests do not add value and are altered: - a dot_perm tile + interchange test does not test anything new and is removed - a dot tile + lower to loops does not need 2-D tiling and is trimmed. --- .../mlir/Dialect/Linalg/CMakeLists.txt | 1 - .../Dialect/Linalg/Transforms/CMakeLists.txt | 7 - .../Transforms/LinalgTransformPatterns.td | 123 ------ .../Linalg/Transforms/LinalgTransforms.h | 137 ------- .../Dialect/Linalg/Transforms/Transforms.h | 375 +++++++++++++++++ .../include/mlir/Dialect/Linalg/Utils/Utils.h | 68 ---- .../Dialect/Linalg/Transforms/CMakeLists.txt | 7 +- .../Dialect/Linalg/Transforms/Interchange.cpp | 85 ++++ .../Linalg/Transforms/LinalgTransforms.cpp | 381 ------------------ .../{LinalgToLoops.cpp => Loops.cpp} | 62 +-- .../Dialect/Linalg/Transforms/Promotion.cpp | 16 + mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 50 +-- .../Dialect/Linalg/Transforms/Transforms.cpp | 228 +++++++++++ .../Linalg/Transforms/Vectorization.cpp | 131 ++++++ .../Dialect/Linalg/transform-patterns.mlir | 106 +---- .../lib/DeclarativeTransforms/CMakeLists.txt | 6 - .../TestLinalgTransformPatterns.td | 168 -------- mlir/test/lib/Transforms/CMakeLists.txt | 1 - .../lib/Transforms/TestLinalgTransforms.cpp | 127 +++++- 19 files changed, 1002 insertions(+), 1077 deletions(-) delete mode 100644 mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt delete mode 100644 mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td delete mode 100644 mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h create mode 100644 mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h create mode 100644 mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp delete mode 100644 mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp rename mlir/lib/Dialect/Linalg/Transforms/{LinalgToLoops.cpp => Loops.cpp} (93%) create mode 100644 mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp create mode 100644 mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp delete mode 100644 mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td diff --git a/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt index 076c2dfbccb5..66ac74515ddd 100644 --- a/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt @@ -1,5 +1,4 @@ add_subdirectory(IR) -add_subdirectory(Transforms) set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt deleted file mode 100644 index e90626134897..000000000000 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt +++ /dev/null @@ -1,7 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS LinalgTransformPatterns.td) -mlir_tablegen(LinalgTransformPatterns.h.inc -gen-rewriters) -add_public_tablegen_target(MLIRLinalgTransformPatternsIncGen) - -# Including Linalg in TableGen requires to depends on generated files -add_dependencies(MLIRLinalgTransformPatternsIncGen LinalgOdsGen) - diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td deleted file mode 100644 index a51352cd4d0e..000000000000 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td +++ /dev/null @@ -1,123 +0,0 @@ -//===- LinalgPatterns.td - Linalg transformation patterns --*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This is the pattern definition file for declarative Linalg transformation. -// -//===----------------------------------------------------------------------===// - -#ifndef LINALG_TRANSFORMS -#define LINALG_TRANSFORMS - -include "mlir/Dialect/Linalg/IR/LinalgOps.td" -include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td" -include "mlir/Dialect/Affine/IR/AffineOps.td" - -def HasNoLinalgTransformMarker : CPred<[{ - !op.getAttrOfType(LinalgTransforms::kLinalgTransformMarker) -}]>; - -class HasLinalgTransformMarker : CPred<[{ - op.getAttrOfType( - LinalgTransforms::kLinalgTransformMarker) && - op.getAttrOfType( - LinalgTransforms::kLinalgTransformMarker).getValue() == "}] # str # [{"}]>; - -class IsProducedByOpOfType : - CPred<"isProducedByOpOfType<" # str # ">(op, $0)">; - -class AffineMapDomainHasDim : CPred<[{ - op.getAttrOfType(getIndexingMapsAttrName()).getValue()[0]. - cast().getValue().getNumDims() ==}] # n # [{}]>; - -class HasOperandsOfType: CPred<[{ - llvm::any_of(op.getOperands(), - [](Value v) { - return dyn_cast_or_null<}] # type # [{>(v.getDefiningOp()); - }) -}]>; - -//===----------------------------------------------------------------------===// -// Linalg fusion patterns. -//===----------------------------------------------------------------------===// -// -// In the future, tile sizes should be derived from op properties + machine -// description but we do not need to wait on this to start having useful -// patterns. -class TileAndFuseLinalgOp< - list sizes, list operandIndices, string value> : NativeCodeCall< - "if (failed(tileAndFuseLinalgOpAndSetMarker($_builder, op, {" # - StrJoinInt.result # "}, {" # StrJoinInt.result # "}," # - " \"" # value # "\")))" # - " return failure();">; - -//===----------------------------------------------------------------------===// -// Linalg tiling patterns. -//===----------------------------------------------------------------------===// -// -// In the future, tile sizes should be derived from op properties + machine -// description but we do not need to wait on this to start having useful -// patterns. -// `permutation` is an optional parameter to specify the ordering of the -// tiled loops. If provided, it must be a list of integers with the same number -// of elements as `sizes`. -class TileLinalgOp sizes, string value, list permutation=[]> : - NativeCodeCall< - "if (failed(tileLinalgOpAndSetMarker($_builder, op, {" # - StrJoinInt.result # "}, \"" # value # "\", {" # - StrJoinInt.result # "})))" # - " return failure();">; - -//===----------------------------------------------------------------------===// -// Linalg to loop patterns. -//===----------------------------------------------------------------------===// -class LinalgOpToLoops : NativeCodeCall< - "if (failed(linalgOpToLoops<" # OpType # ">($_builder, op))) " # - " return failure();">; - -class LinalgOpToParallelLoops : NativeCodeCall< - "if (failed(linalgOpToParallelLoops<" # OpType # ">($_builder, op))) " # - " return failure();">; - -class LinalgOpToAffineLoops : NativeCodeCall< - "if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, op))) " # - " return failure();">; - -//===----------------------------------------------------------------------===// -// Linalg to vector patterns precondition and DRR. -//===----------------------------------------------------------------------===// -def PreconditionVectorizeLinalgOp : CPred< - "succeeded(vectorizeLinalgOpPrecondition(op))">; -def VectorizeLinalgOp : NativeCodeCall< - "vectorizeLinalgOp($_builder, op)">; - - -//===----------------------------------------------------------------------===// -// Linalg generic permutation patterns precondition and DRR. -//===----------------------------------------------------------------------===// -class PreconditionPermuteGenericLinalgOp permutation> : CPred< - "succeeded(permuteGenericLinalgOpPrecondition(op, {" # - StrJoinInt.result # "}))">; -class PermuteGenericLinalgOp permutation, string value> : - NativeCodeCall< - "permuteGenericLinalgOp($_builder, op, {" # StrJoinInt.result # - "}, \"" # value # "\")">; - -//===----------------------------------------------------------------------===// -// Linalg promote subview operands precondition and DRR. -//===----------------------------------------------------------------------===// -def PreconditionPromoteSubviewsLinalgOp : CPred< - "succeeded(promoteSubviewsLinalgOpPrecondition(op))">; -def PromoteSubviewsLinalgOp : NativeCodeCall< - "promoteSubviewsLinalgOp($_builder, op)">; - -class PromoteSelectedSubviewsLinalgOp operands, string marker="", - int alignment=0> : - NativeCodeCall<"promoteSelectedSubviewsLinalgOpAndSetMarker($_builder, op, {" # - StrJoinInt.result # "}, \"" # marker # "\", " # alignment # ")">; - -#endif // LINALG_TRANSFORMS diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h deleted file mode 100644 index 78d588aaf00b..000000000000 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h +++ /dev/null @@ -1,137 +0,0 @@ -//===- LinalgTransforms.h - Linalg transformations as patterns --*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef DIALECT_LINALG_TRANSFORMS_LINALGTRANSFORMS_H_ -#define DIALECT_LINALG_TRANSFORMS_LINALGTRANSFORMS_H_ - -#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" - -#include "llvm/ADT/STLExtras.h" - -namespace mlir { -namespace linalg { - -// Marker used as attribute name in generated Linalg rewriting transformations. -struct LinalgTransforms { - static const StringLiteral kLinalgTransformMarker; -}; - -namespace detail { -// Implementation detail of isProducedByOpOfType avoids the need for explicit -// template instantiations. -bool isProducedByOpOfTypeImpl(Operation *consumerOp, Value consumedView, - function_ref isaOpType); -} // namespace detail - -// Returns true if the `consumedView` value use in `consumerOp` is produced by -// an op of type `OpTy`. This is used to implement use-def type information on -// buffers. -template -bool isProducedByOpOfType(Operation *consumerOp, Value consumedView) { - return detail::isProducedByOpOfTypeImpl( - consumerOp, consumedView, [](Operation *op) { return isa(op); }); -} - -//////////////////////////////////////////////////////////////////////////////// -// The following Declarative Rewrite Rule (DRR) helpers are used in rewrite -// patterns. As such, they must not call into `rewriter.erase/replace` APIs and -// it is the responsibility of the enclosing PatternRewriter to erase on -// success. -//////////////////////////////////////////////////////////////////////////////// - -/// Tiles `op` by `sizes` permuting the loops according to `permutation` and -/// sets the attribute `kLinalgTransformMarker` to `linalgMarker`. The -/// permutation is expressed as a list of integers that specify the new ordering -/// of the loop nest (using loop.for operations). The length of `permutation` -/// must be equal to the length of `tileSizes`. -/// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with -/// `permutation = [1,2,0]`. All values in `permutation` must be -/// integers, in the range 0..`tileSizes.size()` without duplications -/// (i.e. `[1,1,2]` is an invalid permutation). An empty list -/// states for the identity permutation. -LogicalResult tileLinalgOpAndSetMarker(PatternRewriter &rewriter, Operation *op, - ArrayRef sizes, - StringRef linalgMarker, - ArrayRef permutation); - -/// Tiles ops similar to `tileLinalgOpAndSetMarker` but generates loop.parallel -/// operations instead. -LogicalResult tileLinalgOpToParallelLoopsAndSetMarker( - PatternRewriter &rewriter, Operation *op, ArrayRef sizes, - StringRef linalgMarker, ArrayRef permutation); - -/// Tiles `op` by `sizes`, fuses the producers of `operandIndicesToFuse` and -/// sets the attribute `kLinalgTransformMarker` to `linalgMarker`. -LogicalResult tileAndFuseLinalgOpAndSetMarker( - PatternRewriter &rewriter, Operation *op, ArrayRef sizes, - ArrayRef operandIndicesToFuse, StringRef linalgMarker); - -/// Tiles ops similar to `tileAndFuseLinalgOpAndSetMarker` but generates -/// loop.parallel operations instead. -LogicalResult tileAndFuseLinalgOpToParallelLoopsAndSetMarker( - PatternRewriter &rewriter, Operation *op, ArrayRef sizes, - ArrayRef operandIndicesToFuse, StringRef linalgMarker); - -using LinalgLoops = SmallVector; - -/// Emits a loop nest of with the proper body for `op`. -template -Optional linalgLowerOpToLoops(PatternRewriter &rewriter, - Operation *op); - -/// Emits a loop nest of `loop.for` with the proper body for `op`. -template -LogicalResult linalgOpToLoops(PatternRewriter &rewriter, Operation *op); - -/// Emits a loop nest of `loop.parallel` with the proper body for `op`. -template -LogicalResult linalgOpToParallelLoops(PatternRewriter &rewriter, Operation *op); - -/// Emits a loop nest of `affine.for` with the proper body for `op`. -template -LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, Operation *op); - -/// Rewrite a linalg.generic into a suitable vector.contraction op. -LogicalResult vectorizeLinalgOpPrecondition(Operation *op); -SmallVector vectorizeLinalgOp(PatternRewriter &rewriter, - Operation *op); - -/// Emits a `generic` or `indexed_generic` operation with the `indexing_maps` -/// and `iterator_types` permutated according to `permutation`. -LogicalResult -permuteGenericLinalgOpPrecondition(Operation *op, - ArrayRef permutation); -SmallVector permuteGenericLinalgOp(PatternRewriter &rewriter, - Operation *op, - ArrayRef permutation, - StringRef linalgMarker); - -/// Promote std.subviews feeding linalg operations. -LogicalResult promoteSubviewsLinalgOpPrecondition(Operation *op); -SmallVector promoteSubviewsLinalgOp(PatternRewriter &rewriter, - Operation *op); - -/// Similar to `promoteSubviewsLinalgOp` but only tries to promote -/// the views corresponding to the operands specified in -/// `operandIndicesToPromote`. Generated allocations are memory-aligned -/// according to the `alignment` parameter. -/// If linalgMarker is specified and the transformation is successfull -/// sets the attribute `kLinalgTransformMarker` to `linalgMarker`. -SmallVector promoteSelectedSubviewsLinalgOpAndSetMarker( - PatternRewriter &rewriter, Operation *op, - ArrayRef operandIndicesToPromote, StringRef linalgMarker = "", - int64_t alignment = 0); -} // namespace linalg -} // namespace mlir - -#endif // DIALECT_LINALG_TRANSFORMS_LINALGTRANSFORMS_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h new file mode 100644 index 000000000000..b67ff776ea4a --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -0,0 +1,375 @@ +//===- Transforms.h - Linalg transformations as patterns --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_ +#define DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_ + +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +namespace linalg { + +//===----------------------------------------------------------------------===// +// Transformations exposed as function calls. +//===----------------------------------------------------------------------===// +using LinalgLoops = SmallVector; + +struct TiledLinalgOp { + LinalgOp op; + SmallVector loops; +}; + +/// Performs standalone tiling of a single LinalgOp by `tileSizes`. +/// and permute the loop nest according to `interchangeVector` +/// The permutation is expressed as a list of integers that specify +/// the new ordering of the loop nest. The length of `interchangeVector` +/// must be equal to the length of `tileSizes`. +/// An empty vector is interpreted as the identity permutation and the +/// transformation returns early. +/// +/// When non-null, the optional pointer `folder` is used to call into the +/// `createAndFold` builder method. If `folder` is null, the regular `create` +/// method is called. +/// +/// Returns a struct containing the tiled loops in the specified order +/// and the cloned op if successful, llvm::None otherwise. +/// +/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed by +/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be +/// integers, in the range 0..`tileSizes.size()` without duplications +/// (i.e. `[1,1,2]` is an invalid permutation). +Optional tileLinalgOp(OpBuilder &b, LinalgOp op, + ArrayRef tileSizes, + ArrayRef interchangeVector = {}, + OperationFolder *folder = nullptr); +Optional +tileLinalgOpToParallelLoops(OpBuilder &b, LinalgOp op, + ArrayRef tileSizes, + ArrayRef interchangeVector = {}, + OperationFolder *folder = nullptr); + +/// Performs standalone tiling of a single LinalgOp by constant `tileSizes`. +/// See `tileLinalgOp(... ArrayRef tileSizes,)` for more details +Optional tileLinalgOp(OpBuilder &b, LinalgOp op, + ArrayRef tileSizes, + ArrayRef interchangeVector = {}, + OperationFolder *folder = nullptr); +Optional +tileLinalgOpToParallelLoops(OpBuilder &b, LinalgOp op, + ArrayRef tileSizes, + ArrayRef interchangeVector = {}, + OperationFolder *folder = nullptr); + +/// Interchanges the `iterator_types` and `iterator_maps` dimensions of `op`. +/// This is an in-place transformation controlled by `interchangeVector`. +/// An empty vector is interpreted as the identity permutation and the +/// transformation returns early. +/// +/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed with +/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be +/// integers, in the range 0..`op.rank` without duplications +/// (i.e. `[1,1,2]` is an invalid permutation). +LinalgOp interchange(LinalgOp op, ArrayRef interchangeVector); + +/// Promotes the `subViews` into a new buffer allocated at the insertion point +/// `b`. Promotion occurs in 3 steps: +/// 1. Create a new buffer for a full tile (i.e. not clipped at the boundary). +/// 2. Take a full view on the buffer and `linalg.fill` it with zeros (use +/// float zero for now). +/// 3. Take a partial slice of the full view in step 2. and copy into it. +/// Infers statically sized buffers from subViews unless `dynamicBuffers` is +/// true. +/// +/// Returns a list of PromotionInfo which hold the promoted buffer and the +/// full and partial views indexing into the buffer. +// TODO: revisit dynamicBuffers option. +LinalgOp promoteSubViewOperands(OpBuilder &b, LinalgOp op, + llvm::SetVector subViews, + bool dynamicBuffers = false, + int64_t alignment = 0, + OperationFolder *folder = nullptr); + +/// Emit a suitable vector form for a Linalg op with fully static shape. +void vectorizeLinalgOp(OpBuilder &builder, Operation *op); + +/// Emits a loop nest of `LoopTy` with the proper body for `op`. +template +Optional linalgLowerOpToLoops(OpBuilder &builder, Operation *op); + +/// Emits a loop nest of `loop.for` with the proper body for `op`. +template +LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op); + +/// Emits a loop nest of `loop.parallel` with the proper body for `op`. +template +LogicalResult linalgOpToParallelLoops(OpBuilder &builder, Operation *op); + +/// Emits a loop nest of `affine.for` with the proper body for `op`. +template +LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op); + +//===----------------------------------------------------------------------===// +// Preconditions that ensure the corresponding transformation suceeds and can be +// applied as a rewrite pattern. +//===----------------------------------------------------------------------===// +/// Emits a `generic` or `indexed_generic` operation with the `indexing_maps` +/// and `iterator_types` permutated according to `permutation`. +LogicalResult +interchangeGenericLinalgOpPrecondition(Operation *op, + ArrayRef interchangeVector); + +/// Promote std.subviews feeding linalg operations. +LogicalResult promoteSubviewsLinalgOpPrecondition( + Operation *op, Optional> operandIndicesToPromote = None); + +/// Rewrite a linalg.generic into a suitable vector.contraction op. +LogicalResult vectorizeLinalgOpPrecondition(Operation *op); + +//===----------------------------------------------------------------------===// +// Transformations exposed as rewrite patterns. +//===----------------------------------------------------------------------===// +// Marker used as attribute name in generated Linalg rewriting transformations. +struct LinalgTransforms { + static const StringLiteral kLinalgTransformMarker; +}; + +/// Helper class to control common attribute matching and setting behavior. +struct LinalgMarker { + LinalgMarker(ArrayRef matchDisjunction = {}, + Optional replacement = None); + LinalgMarker(ArrayRef matchDisjunction, StringRef replacement); + LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const; + void replaceLinalgMarker(PatternRewriter &rewriter, Operation *op) const; + +private: + SmallVector matchDisjunction; + Optional replacement; +}; + +/// +/// Linalg tiling patterns. +/// +/// Apply the `tileLinalgOp` transformation as a pattern. +/// `marker` controls LinalgTransformMarker matching and update when specified. +/// See `tileLinalgOp` for more details. +enum class LinalgTilingLoopType { + Loops = 0, + AffineLoops = 1, + ParallelLoops = 2 +}; +struct LinalgTilingOptions { + /// The tile sizes by which to tile. + SmallVector tileSizes{}; + LinalgTilingOptions &setTileSizes(ArrayRef ts) { + tileSizes.assign(ts.begin(), ts.end()); + return *this; + } + /// The interchange vector to reorder the tiled loops. + SmallVector interchangeVector{}; + LinalgTilingOptions &setInterchange(ArrayRef interchange) { + interchangeVector.assign(interchange.begin(), interchange.end()); + return *this; + } + /// The type of tile loops to generate. + LinalgTilingLoopType loopType{LinalgTilingLoopType::Loops}; + LinalgTilingOptions &setLoopType(LinalgTilingLoopType lt) { + loopType = lt; + return *this; + } +}; + +struct LinalgBaseTilingPattern : public RewritePattern { + LinalgBaseTilingPattern(StringRef opName, MLIRContext *context, + LinalgTilingOptions options, + LinalgMarker marker = LinalgMarker(), + PatternBenefit benefit = 1); + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; + +private: + /// LinalgTransformMarker handles special attribute manipulations. + LinalgMarker marker; + /// Options to control tiling; + LinalgTilingOptions options; +}; + +template +struct LinalgTilingPattern : public LinalgBaseTilingPattern { + LinalgTilingPattern(MLIRContext *context, LinalgTilingOptions options, + LinalgMarker marker = LinalgMarker(), + PatternBenefit benefit = 1) + : LinalgBaseTilingPattern(OpTy::getOperationName(), context, options, + marker, benefit) {} +}; + +/// +/// Linalg interchange patterns. +/// +/// Apply the `interchange` transformation as a pattern. +/// `marker` controls LinalgTransformMarker matching and update when specified. +/// See `interchange` for more details. +struct LinalgBaseInterchangePattern : public RewritePattern { + LinalgBaseInterchangePattern(StringRef opName, MLIRContext *context, + ArrayRef interchangeVector, + LinalgMarker marker = LinalgMarker(), + PatternBenefit benefit = 1); + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; + +private: + /// LinalgTransformMarker handles special attribute manipulations. + LinalgMarker marker; + /// The interchange vector to reorder the iterators and indexing_maps dims. + SmallVector interchangeVector; +}; + +template +struct LinalgInterchangePattern : public LinalgBaseInterchangePattern { + LinalgInterchangePattern(MLIRContext *context, + ArrayRef interchangeVector, + LinalgMarker marker = LinalgMarker(), + PatternBenefit benefit = 1) + : LinalgBaseInterchangePattern(OpTy::getOperationName(), context, + interchangeVector, marker, benefit) {} +}; + +/// +/// Linalg promotion patterns. +/// +/// Apply the `promoteSubViewOperands` transformation as a pattern. +/// `marker` controls LinalgTransformMarker matching and update when specified. +/// See `promoteSubViewOperands` for more details. +struct LinalgBasePromotionPattern : public RewritePattern { + LinalgBasePromotionPattern(StringRef opName, MLIRContext *context, + ArrayRef operandsToPromote = {}, + unsigned alignment = 0, + LinalgMarker marker = LinalgMarker(), + PatternBenefit benefit = 1); + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; + +private: + /// LinalgTransformMarker handles special attribute manipulations. + LinalgMarker marker; + /// Indices of subViews to promote. + SmallVector operandsToPromote; + /// Alignment of promoted buffer. + unsigned alignment; +}; + +template +struct LinalgPromotionPattern : public LinalgBasePromotionPattern { + LinalgPromotionPattern(MLIRContext *context, + ArrayRef operandsToPromote = {}, + unsigned alignment = 0, + LinalgMarker marker = LinalgMarker(), + PatternBenefit benefit = 1) + : LinalgBasePromotionPattern(OpTy::getOperationName(), context, + operandsToPromote, alignment, marker, + benefit) {} + LinalgPromotionPattern(MLIRContext *context, + ArrayRef operandsToPromote, + LinalgMarker marker = LinalgMarker(), + PatternBenefit benefit = 1) + : LinalgPromotionPattern(context, operandsToPromote, 0, marker, benefit) { + } + LinalgPromotionPattern(MLIRContext *context, unsigned alignment, + LinalgMarker marker = LinalgMarker(), + PatternBenefit benefit = 1) + : LinalgPromotionPattern(context, {}, alignment, marker, benefit) {} + LinalgPromotionPattern(MLIRContext *context, LinalgMarker marker, + PatternBenefit benefit = 1) + : LinalgPromotionPattern(context, {}, 0, marker, benefit) {} +}; + +/// +/// Linalg vectorization patterns. +/// +/// Apply the `vectorizeLinalgOp` transformation as a pattern. +/// `marker` controls LinalgTransformMarker matching and update when specified. +/// See `vectorizeLinalgOp` for more details. +struct LinalgBaseVectorizationPattern : public RewritePattern { + LinalgBaseVectorizationPattern(StringRef opName, MLIRContext *context, + LinalgMarker marker = LinalgMarker(), + PatternBenefit benefit = 1); + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; + +private: + /// LinalgTransformMarker handles special attribute manipulations. + LinalgMarker marker; +}; + +template +struct LinalgVectorizationPattern : public LinalgBaseVectorizationPattern { + LinalgVectorizationPattern(MLIRContext *context, + LinalgMarker marker = LinalgMarker(), + PatternBenefit benefit = 1) + : LinalgBaseVectorizationPattern(OpTy::getOperationName(), context, + marker, benefit) {} +}; + +/// +/// Linalg lowering patterns. +/// +/// Apply the `linalgLowerOpToLoops` transformation as a pattern. +/// `marker` controls LinalgTransformMarker matching and update when specified. +/// See `linalgLowerOpToLoops` for more details. +enum class LinalgLoweringType { + LibraryCall = 0, + Loops = 1, + AffineLoops = 2, + ParallelLoops = 3 +}; +template +struct LinalgLoweringPattern : public RewritePattern { + LinalgLoweringPattern(MLIRContext *context, LinalgLoweringType loweringType, + LinalgMarker marker = LinalgMarker(), + PatternBenefit benefit = 1) + : RewritePattern(OpTy::getOperationName(), {}, benefit, context), + marker(marker), loweringType(loweringType) {} + // TODO: Move implementation to .cpp once named ops are auto-generated. + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + LinalgOp linalgOp = dyn_cast(op); + if (!linalgOp) + return failure(); + if (failed(marker.checkAndNotify(rewriter, linalgOp))) + return failure(); + if (failed(promoteSubviewsLinalgOpPrecondition(op))) + return failure(); + + if (loweringType == LinalgLoweringType::LibraryCall) { + // TODO: Move lowering to library calls here. + return failure(); + } else if (loweringType == LinalgLoweringType::Loops) { + if (failed(linalgOpToLoops(rewriter, op))) + return failure(); + } else if (loweringType == LinalgLoweringType::AffineLoops) { + if (failed(linalgOpToAffineLoops(rewriter, op))) + return failure(); + } else if (failed(linalgOpToParallelLoops(rewriter, op))) { + return failure(); + } + rewriter.eraseOp(op); + return success(); + } + +private: + /// LinalgTransformMarker handles special attribute manipulations. + LinalgMarker marker; + /// Controls whether the pattern lowers to library calls, loop.for, affine.for + /// or loop.parallel. + LinalgLoweringType loweringType; +}; + +} // namespace linalg +} // namespace mlir + +#endif // DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 7dea577f0a49..1a5b6d888c0c 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -101,63 +101,6 @@ SmallVector applyMapToValues(OpBuilder &b, Location loc, AffineMap map, ArrayRef values, OperationFolder *folder = nullptr); -struct TiledLinalgOp { - LinalgOp op; - SmallVector loops; -}; - -/// Performs standalone tiling of a single LinalgOp by `tileSizes`. -/// and permute the loop nest according to `permutation` -/// The permutation is expressed as a list of integers that specify -/// the new ordering of the loop nest. The length of `permutation` -/// must be equal to the length of `tileSizes`. -/// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with -/// `permutation = [1,2,0]`. All values in `permutation` must be -/// integers, in the range 0..`tileSizes.size()` without duplications -/// (i.e. `[1,1,2]` is an invalid permutation). An empty list -/// states for the identity permutation. -/// Returns a struct containing the tiled loops in the specified order -/// and the cloned op if successful, llvm::None otherwise. -/// When non-null, the optional pointer `folder` is used to call into the -/// `createAndFold` builder method. If `folder` is null, the regular `create` -/// method is called. -Optional tileLinalgOp(OpBuilder &b, LinalgOp op, - ArrayRef tileSizes, - ArrayRef permutation = {}, - OperationFolder *folder = nullptr); -Optional tileLinalgOpToParallelLoops( - OpBuilder &b, LinalgOp op, ArrayRef tileSizes, - ArrayRef permutation = {}, OperationFolder *folder = nullptr); - -/// Performs standalone tiling of a single LinalgOp by constant `tileSizes`. -/// and permute the loop nest according to `permutation` -/// The permutation is expressed as a list of integers that specify -/// the new ordering of the loop nest. The length of `permutation` -/// must be equal to the length of `tileSizes`. -/// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with -/// `permutation = [1,2,0]`. All values in `permutation` must be -/// integers, in the range 0..`tileSizes.size()` without duplications -/// (i.e. `[1,1,2]` is an invalid permutation). An empty list -/// states for the identity permutation. -/// Returns a struct containing the tiled loops in the specified order -/// and the cloned op if successful, llvm::None otherwise. -/// When non-null, the optional pointer `folder` is used to call into the -/// `createAndFold` builder method. If `folder` is null, the regular `create` -/// method is called. -Optional tileLinalgOp(OpBuilder &b, LinalgOp op, - ArrayRef tileSizes, - ArrayRef permutation = {}, - OperationFolder *folder = nullptr); -Optional tileLinalgOpToParallelLoops( - OpBuilder &b, LinalgOp op, ArrayRef tileSizes, - ArrayRef permutation = {}, OperationFolder *folder = nullptr); - -template -Optional tileLinalgOperation(OpBuilder &b, Operation *op, - Args... args) { - return tileLinalgOp(b, cast(op), args...); -} - struct PromotionInfo { Value buffer; Value fullLocalView; @@ -198,17 +141,6 @@ void applyPermutationToVector(SmallVector &inVec, inVec = auxVec; } -/// Prepares the SubView promotion later performed by `promoteSubViews` -/// (where most of the transformation happens). It arranges the new -/// operands for `LinalgOp op` and deallocates the new buffer(s) -/// It is the entry point for declarative transformation -/// Returns the cloned `LinalgOp` with the new operands -LinalgOp promoteSubViewOperands(OpBuilder &b, LinalgOp op, - llvm::SetVector subViews, - bool dynamicBuffers = false, - int64_t alignment = 0, - OperationFolder *folder = nullptr); - } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt index c8e74ea30e8d..c8464e277cbb 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -1,9 +1,11 @@ add_mlir_dialect_library(MLIRLinalgTransforms Fusion.cpp - LinalgTransforms.cpp - LinalgToLoops.cpp + Interchange.cpp + Loops.cpp Promotion.cpp Tiling.cpp + Transforms.cpp + Vectorization.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg @@ -11,7 +13,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms DEPENDS intrinsics_gen MLIRLinalgPassIncGen - MLIRLinalgTransformPatternsIncGen ) target_link_libraries(MLIRLinalgTransforms PUBLIC diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp new file mode 100644 index 000000000000..71e4969d4657 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp @@ -0,0 +1,85 @@ +//===- Interchange.cpp - Linalg interchange transformation ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the linalg interchange transformation. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/EDSC/Intrinsics.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include + +#define DEBUG_TYPE "linalg-interchange" + +using namespace mlir; +using namespace mlir::linalg; + +LogicalResult mlir::linalg::interchangeGenericLinalgOpPrecondition( + Operation *op, ArrayRef interchangeVector) { + if (interchangeVector.empty()) + return failure(); + // Transformation applies to generic ops only. + if (!isa(op) && !isa(op)) + return failure(); + LinalgOp linOp = cast(op); + // Transformation applies to buffers only. + if (!linOp.hasBufferSemantics()) + return failure(); + // Permutation must be applicable. + if (linOp.getIndexingMap(0).getNumInputs() != interchangeVector.size()) + return failure(); + // Permutation map must be invertible. + if (!inversePermutation( + AffineMap::getPermutationMap(interchangeVector, op->getContext()))) + return failure(); + return success(); +} + +LinalgOp mlir::linalg::interchange(LinalgOp op, + ArrayRef interchangeVector) { + if (interchangeVector.empty()) + return op; + + MLIRContext *context = op.getContext(); + auto permutationMap = inversePermutation( + AffineMap::getPermutationMap(interchangeVector, context)); + assert(permutationMap && "expected permutation to be invertible"); + SmallVector newIndexingMaps; + auto indexingMaps = op.indexing_maps().getValue(); + for (unsigned i = 0, e = op.getNumInputsAndOutputs(); i != e; ++i) { + AffineMap m = indexingMaps[i].cast().getValue(); + if (!permutationMap.isEmpty()) + m = m.compose(permutationMap); + newIndexingMaps.push_back(AffineMapAttr::get(m)); + } + auto itTypes = op.iterator_types().getValue(); + SmallVector itTypesVector; + for (unsigned i = 0, e = itTypes.size(); i != e; ++i) + itTypesVector.push_back(itTypes[i]); + applyPermutationToVector(itTypesVector, interchangeVector); + + op.setAttr(getIndexingMapsAttrName(), + ArrayAttr::get(newIndexingMaps, context)); + op.setAttr(getIteratorTypesAttrName(), + ArrayAttr::get(itTypesVector, context)); + + return op; +} diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp deleted file mode 100644 index 423c1c10596a..000000000000 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ /dev/null @@ -1,381 +0,0 @@ -//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements logic for transforming Linalg operations. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h" -#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" -#include "mlir/Dialect/Vector/EDSC/Intrinsics.h" -#include "mlir/Dialect/Vector/VectorOps.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" -#include - -#define DEBUG_TYPE "linalg-transforms" - -using namespace mlir; -using namespace mlir::edsc; -using namespace mlir::edsc::intrinsics; -using namespace mlir::linalg; - -using llvm::dbgs; -using llvm::SetVector; - -// Marker used as attribute name in generated Linalg rewriting transformations. -const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = - "__internal_linalg_transform__"; - -using TileFn = Optional(OpBuilder &, LinalgOp, ArrayRef, - ArrayRef, OperationFolder *); - -static LogicalResult -tileLinalgOpAndSetMarkerImpl(TileFn tileFn, PatternRewriter &rewriter, - Operation *op, ArrayRef sizes, - StringRef linalgMarker, - ArrayRef permutation) { - assert(permutation.empty() || permutation.size() == sizes.size()); - auto tileRes = tileFn(rewriter, op, sizes, permutation, /*folder=*/nullptr); - if (!tileRes) - return failure(); - tileRes->op.setAttr(LinalgTransforms::kLinalgTransformMarker, - rewriter.getStringAttr(linalgMarker)); - return success(); -} - -LogicalResult mlir::linalg::tileLinalgOpAndSetMarker( - PatternRewriter &rewriter, Operation *op, ArrayRef sizes, - StringRef linalgMarker, ArrayRef permutation) { - return tileLinalgOpAndSetMarkerImpl(tileLinalgOp, rewriter, op, sizes, - linalgMarker, permutation); -} -LogicalResult mlir::linalg::tileLinalgOpToParallelLoopsAndSetMarker( - PatternRewriter &rewriter, Operation *op, ArrayRef sizes, - StringRef linalgMarker, ArrayRef permutation) { - return tileLinalgOpAndSetMarkerImpl(tileLinalgOpToParallelLoops, rewriter, op, - sizes, linalgMarker, permutation); -} - -static LogicalResult -tileAndFuseLinalgOpAndSetMarkerImpl(TileFn tileFn, PatternRewriter &rewriter, - Operation *op, ArrayRef sizes, - ArrayRef operandIndicesToFuse, - StringRef linalgMarker) { - auto tileRes = - tileFn(rewriter, op, sizes, /*permutation=*/{}, /*folder=*/nullptr); - if (!tileRes) - return failure(); - tileRes->op.setAttr(LinalgTransforms::kLinalgTransformMarker, - rewriter.getStringAttr(linalgMarker)); - Aliases aliases; - auto G = LinalgDependenceGraph::buildDependenceGraph( - aliases, op->getParentOfType()); - SmallVector originalProducers; - for (auto operandIdx : operandIndicesToFuse) { - auto fusionRes = fuseProducerOf(rewriter, tileRes->op, operandIdx, G); - if (!fusionRes) { - // Linalg fusion requires tiled loops to even determine whether it is - // possible to fuse. As a consequence, the pattern may fail even though a - // tiled version of op has already been introduced. - // So we need to remove the tiled version ourselves in case of failure. - // Another possibility is to ensure the constraints on the pattern - // guarantee that fusion will occur and just assert here. As we develop - // more complex patterns we can choose what is best. - rewriter.eraseOp(tileRes->loops[0]); - return failure(); - } - fusionRes->fusedProducer.setAttr(LinalgTransforms::kLinalgTransformMarker, - rewriter.getStringAttr(linalgMarker)); - originalProducers.push_back(fusionRes->originalProducer); - } - - // The originalProducers can now be safely erased. This is similar to - // SSA-value use-def but in the world of buffer + structured ops. - for (auto *originalProducer : originalProducers) - rewriter.eraseOp(originalProducer); - return success(); -} - -LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker( - PatternRewriter &rewriter, Operation *op, ArrayRef sizes, - ArrayRef operandIndicesToFuse, StringRef linalgMarker) { - return tileAndFuseLinalgOpAndSetMarkerImpl( - tileLinalgOp, rewriter, op, sizes, operandIndicesToFuse, linalgMarker); -} -LogicalResult mlir::linalg::tileAndFuseLinalgOpToParallelLoopsAndSetMarker( - PatternRewriter &rewriter, Operation *op, ArrayRef sizes, - ArrayRef operandIndicesToFuse, StringRef linalgMarker) { - return tileAndFuseLinalgOpAndSetMarkerImpl( - tileLinalgOpToParallelLoops, rewriter, op, sizes, operandIndicesToFuse, - linalgMarker); -} - -bool mlir::linalg::detail::isProducedByOpOfTypeImpl( - Operation *consumerOp, Value consumedView, - function_ref isaOpType) { - LinalgOp consumer = dyn_cast(consumerOp); - assert(consumer.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - if (!consumer) - return false; - - auto maybeConsumerIndex = consumer.getIndexOfInput(consumedView); - if (!maybeConsumerIndex) - return false; - - Aliases aliases; - auto G = LinalgDependenceGraph::buildDependenceGraph( - aliases, consumer.getParentOfType()); - for (auto dependence : G.getDependencesInto( - consumer, LinalgDependenceGraph::DependenceType::RAW)) { - auto producer = cast(dependence.dependentOpView.op); - if (!isProducerLastWriteOfView(G, consumer, consumedView, producer)) - continue; - if (isaOpType(dependence.dependentOpView.op)) - return true; - } - return false; -} - -//============================================================================// -// Precondition and transformation for vectorization of Linalg generic ops. -//============================================================================// -static bool hasMultiplyAddBody(linalg::GenericOp op) { - auto &r = op.region(); - if (r.empty()) - return false; - if (r.getBlocks().size() != 1) - return false; - auto &ops = r.front().getOperations(); - if (ops.size() != 3) - return false; - - using mlir::matchers::m_Val; - auto a = m_Val(r.front().getArgument(0)); - auto b = m_Val(r.front().getArgument(1)); - auto c = m_Val(r.front().getArgument(2)); - // TODO(ntv) Update this detection once we have matcher support for - // specifying that any permutation of operands matches. - auto pattern1 = m_Op(m_Op(m_Op(a, b), c)); - auto pattern2 = m_Op(m_Op(c, m_Op(a, b))); - auto pattern3 = m_Op(m_Op(m_Op(b, a), c)); - auto pattern4 = m_Op(m_Op(c, m_Op(b, a))); - return pattern1.match(&ops.back()) || pattern2.match(&ops.back()) || - pattern3.match(&ops.back()) || pattern4.match(&ops.back()); -} - -// TODO(ntv) should be Tablegen'd from a single source that generates the op -// itself. -static bool isRowMajorMatmul(linalg::GenericOp genericOp) { - return genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 && - isRowMajorMatmul(genericOp.indexing_maps()) && - hasMultiplyAddBody(genericOp); -} - -// TODO(ntv, ataei): This is in fact much more general than just vectorization -// for matmul and fill ops. -LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) { - auto linalgOp = cast(op); - // All types must be static shape to go to vector. - for (Value operand : linalgOp.getInputsAndOutputBuffers()) - if (!operand.getType().cast().hasStaticShape()) - return failure(); - for (Type outputTensorType : linalgOp.getOutputTensorTypes()) - if (!outputTensorType.cast().hasStaticShape()) - return failure(); - if (isa(op) || isa(op)) - return success(); - - auto genericOp = dyn_cast(op); - if (!genericOp || !::isRowMajorMatmul(genericOp)) - return failure(); - - // TODO(ntv): non-identity layout. - auto isStaticMemRefWithIdentityLayout = [](Value v) { - auto m = v.getType().dyn_cast(); - if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty()) - return false; - return true; - }; - if (!llvm::all_of(genericOp.getInputsAndOutputBuffers(), - isStaticMemRefWithIdentityLayout)) - return failure(); - return success(); -} - -SmallVector mlir::linalg::vectorizeLinalgOp(PatternRewriter &rewriter, - Operation *op) { - assert(succeeded(vectorizeLinalgOpPrecondition(op)) && - "DRR failure case must be a precondition"); - auto linalgOp = cast(op); - assert(linalgOp.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - if (auto convOp = dyn_cast(op)) { - // TODO(ntv): add a level of indirection to linalg.generic. - if (convOp.padding()) - llvm_unreachable("Unexpected conv with padding"); - } - - edsc::ScopedContext scope(rewriter, op->getLoc()); - - if (auto fillOp = dyn_cast(op)) { - // Vectorize fill as a vector.broadcast. - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE - "]: Rewrite linalg.fill as vector.broadcast: " - << *op << ":\n"); - auto dstMemrefVec = vector_type_cast(fillOp.getOutputBuffer(0)); - Value dstVec = std_load(dstMemrefVec); - auto resVec = vector_broadcast(dstVec.getType(), fillOp.value()); - std_store(resVec, dstMemrefVec); - } else { - // Vectorize other ops as vector contraction (currently only matmul). - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE - "]: Rewrite linalg op as vector.contract: " - << *op << ":\n"); - auto vA = std_load(vector_type_cast(linalgOp.getInput(0))); - auto vB = std_load(vector_type_cast(linalgOp.getInput(1))); - auto vectorMemRefC = vector_type_cast(linalgOp.getOutputBuffer(0)); - auto vC = std_load(vectorMemRefC); - auto vRes = vector_contract(vA, vB, vC, linalgOp.indexing_maps(), - linalgOp.iterator_types()); - std_store(vRes, vectorMemRefC); - } - return {}; -} - -//============================================================================// -// Precondition and transformation for permutation of Linalg generic ops. -//============================================================================// -LogicalResult mlir::linalg::permuteGenericLinalgOpPrecondition( - Operation *op, ArrayRef permutation) { - if (permutation.empty()) - return failure(); - // Transformation applies to generic ops only. - if (!isa(op) && !isa(op)) - return failure(); - LinalgOp linOp = cast(op); - // Transformation applies to buffers only. - if (!linOp.hasBufferSemantics()) - return failure(); - return success(); -} - -SmallVector -mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op, - ArrayRef permutation, - StringRef linalgMarker) { - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Permute dims for linalg op: " << *op - << ":\n"); - - assert(succeeded(permuteGenericLinalgOpPrecondition(op, permutation)) && - "DRR failure case must be a precondition"); - - auto linOp = cast(op); - auto permutationMap = inversePermutation( - AffineMap::getPermutationMap(permutation, rewriter.getContext())); - assert(permutationMap && "expected permutation to be invertible"); - SmallVector newIndexingMap; - auto indexingMaps = linOp.indexing_maps().getValue(); - for (unsigned i = 0, e = linOp.getNumInputsAndOutputs(); i != e; ++i) { - AffineMap m = indexingMaps[i].cast().getValue(); - if (!permutationMap.isEmpty()) - m = m.compose(permutationMap); - newIndexingMap.push_back(m); - } - auto itTypes = linOp.iterator_types().getValue(); - SmallVector itTypesVector; - for (unsigned i = 0, e = itTypes.size(); i != e; ++i) - itTypesVector.push_back(itTypes[i]); - applyPermutationToVector(itTypesVector, permutation); - op->setAttr(getIndexingMapsAttrName(), - rewriter.getAffineMapArrayAttr(newIndexingMap)); - op->setAttr(getIteratorTypesAttrName(), rewriter.getArrayAttr(itTypesVector)); - op->setAttr(LinalgTransforms::kLinalgTransformMarker, - rewriter.getStringAttr(linalgMarker)); - linOp.clone(rewriter, linOp.getLoc(), op->getOperands()); - return {}; -} - -//============================================================================// -// Precondition and transformation for Linalg subview promotion. -//============================================================================// -LogicalResult mlir::linalg::promoteSubviewsLinalgOpPrecondition(Operation *op) { - LinalgOp linOp = dyn_cast(op); - // Transformation applies to buffers only. - if (!linOp || !linOp.hasBufferSemantics()) - return failure(); - if (llvm::none_of(linOp.getInputsAndOutputBuffers(), [](Value v) { - return isa_and_nonnull(v.getDefiningOp()); - })) - return failure(); - return success(); -} - -SmallVector -mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter, - Operation *op) { - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Promote subviews for linalg op: " - << *op << ":\n"); - - assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) && - "DRR failure case must be a precondition"); - - LinalgOp linOp = cast(op); - SmallVector toPromote; - int64_t nBuffers = linOp.getNumInputsAndOutputBuffers(); - toPromote.reserve(nBuffers); - for (int64_t i = 0; i < nBuffers; ++i) - toPromote.push_back(i); - return promoteSelectedSubviewsLinalgOpAndSetMarker(rewriter, op, toPromote); -} - -SmallVector mlir::linalg::promoteSelectedSubviewsLinalgOpAndSetMarker( - PatternRewriter &rewriter, Operation *op, - ArrayRef operandIndicesToPromote, StringRef linalgMarker, - int64_t alignment) { - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Promote subviews for linalg op: " - << *op << ":\n"); - - assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) && - "DRR failure case must be a precondition"); - - if (auto convOp = dyn_cast(op)) { - // TODO(ntv): add a level of indirection to linalg.generic. - if (convOp.padding()) - llvm_unreachable("Unexpected conv with padding"); - } - - LinalgOp linOp = cast(op); - assert(linOp.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - SetVector subViews; - for (int64_t index : operandIndicesToPromote) - if (auto sv = - dyn_cast_or_null(linOp.getBuffer(index).getDefiningOp())) - subViews.insert(sv); - - if (!subViews.empty()) { - auto newOp = - promoteSubViewOperands(rewriter, linOp, subViews, false, alignment); - if (!linalgMarker.empty()) - newOp.setAttr(LinalgTransforms::kLinalgTransformMarker, - rewriter.getStringAttr(linalgMarker)); - return {}; - } - llvm_unreachable("DRR failure case must be a precondition"); -} diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp similarity index 93% rename from mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp rename to mlir/lib/Dialect/Linalg/Transforms/Loops.cpp index de4b6f18e42d..c5e7958b84a1 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -1,4 +1,4 @@ -//===- LinalgToLoops.cpp - conversion from Linalg library ops to loops-----===// +//===- Loops.cpp - conversion from Linalg named and generic ops to loops --===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -12,7 +12,7 @@ #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/LoopOps/EDSC/Builders.h" #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" @@ -489,20 +489,6 @@ public: } }; -/// This struct is for factoring out the implementation and support template -/// instantiations in the following 2 cases: -/// 1. Appending to a list of patterns via RewritePatternList. -/// 2. Direct invocation via `linalgOpToLoops` and `linalgOpToAffineLoops`. -/// The implementation must work both in DRR and inside a RewritePattern. As a -/// consequence, (1) it is only allowed to emit new ops if the match is -/// guaranteed to be a success, (2) it is not allowed erase/replace, and (3) an -/// encompassing pattern must take care of the erasure logic. -template -class LinalgOpToLoopsImpl { -public: - static Optional doit(Operation *op, PatternRewriter &rewriter); -}; - namespace { /// Helper struct to generate the loop nest for the op. This factored out here /// to be able to partially specialize this for different LoopTy. @@ -573,14 +559,12 @@ public: } // namespace template -Optional -LinalgOpToLoopsImpl::doit(Operation *op, - PatternRewriter &rewriter) { +Optional linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) { using Impl = GenerateLoopNest; using IndexedValueTy = typename GenerateLoopNest::IndexedValueTy; - ScopedContext scope(rewriter, op->getLoc()); + ScopedContext scope(builder, op->getLoc()); // The flattened loopToOperandRangesMaps is expected to be an invertible // permutation map (which is asserted in the inverse calculation). @@ -607,7 +591,7 @@ LinalgOpToLoopsImpl::doit(Operation *op, SmallVector allIvs(nLoops); auto loopRanges = emitLoopRanges(scope.getBuilderRef(), scope.getLocation(), invertedMap, - getViewSizes(rewriter, linalgOp)); + getViewSizes(builder, linalgOp)); assert(loopRanges.size() == allIvs.size()); Impl::doit(linalgOp, loopRanges, allIvs); // Number of loop ops might be different from the number of ivs since some @@ -635,8 +619,7 @@ public: LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - using Impl = LinalgOpToLoopsImpl; - if (!Impl::doit(op, rewriter)) + if (!linalgOpToLoopsImpl(op, rewriter)) return failure(); rewriter.eraseOp(op); return success(); @@ -662,7 +645,7 @@ public: } }; -/// Populate the given list with patterns that convert from Linalg to LLVM. +/// Populate the given list with patterns that convert from Linalg to loops. template void FillRewritePatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) { RewritePatternList -Optional -mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter, Operation *op) { - return LinalgOpToLoopsImpl::doit(op, rewriter); +Optional mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, + Operation *op) { + return linalgOpToLoopsImpl(op, builder); } /// Emits a loop nest of `loop.for` with the proper body for `op`. template -LogicalResult mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter, - Operation *op) { +LogicalResult mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op) { Optional loops = - linalgLowerOpToLoops(rewriter, op); + linalgLowerOpToLoops(builder, op); return loops ? success() : failure(); } /// Emits a loop nest of `affine.for` with the proper body for `op`. template -LogicalResult mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter, +LogicalResult mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder, Operation *op) { Optional loops = - linalgLowerOpToLoops(rewriter, op); + linalgLowerOpToLoops(builder, op); return loops ? success() : failure(); } /// Emits a loop nest of `loop.parallel` with the proper body for `op`. template -LogicalResult mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter, +LogicalResult mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder, Operation *op) { Optional loops = - linalgLowerOpToLoops(rewriter, op); + linalgLowerOpToLoops(builder, op); return loops ? success() : failure(); } -// TODO(ntv) Need to make these instantiations more future-proof to avoid the -// need to update as soon as we add new ops. +// TODO Need to make these instantiations more future-proof to avoid the need to +// update as soon as we add new ops. #define INSTANTIATE_LINALG_OP_TO_LOOPS(OP_TYPE) \ template LogicalResult mlir::linalg::linalgOpToLoops( \ - PatternRewriter & rewriter, Operation * op); \ + OpBuilder & builder, Operation * op); \ template LogicalResult mlir::linalg::linalgOpToAffineLoops( \ - PatternRewriter & rewriter, Operation * op); \ + OpBuilder & builder, Operation * op); \ template LogicalResult mlir::linalg::linalgOpToParallelLoops( \ - PatternRewriter & rewriter, Operation * op); \ + OpBuilder & builder, Operation * op); \ template Optional \ mlir::linalg::linalgLowerOpToLoops( \ - PatternRewriter & rewriter, Operation * op); + OpBuilder & builder, Operation * op); INSTANTIATE_LINALG_OP_TO_LOOPS(CopyOp) INSTANTIATE_LINALG_OP_TO_LOOPS(FillOp) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index ca905116d71e..5e277b187624 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" @@ -264,6 +265,21 @@ static void promoteSubViews(FuncOp f, bool dynamicBuffers) { op.erase(); } +LogicalResult mlir::linalg::promoteSubviewsLinalgOpPrecondition( + Operation *op, llvm::Optional> operandIndicesToPromote) { + LinalgOp linOp = dyn_cast(op); + // Transformation applies to buffers only. + if (!linOp || !linOp.hasBufferSemantics()) + return failure(); + for (auto en : llvm::enumerate(linOp.getInputsAndOutputBuffers())) { + auto sv = isa_and_nonnull(en.value().getDefiningOp()); + if (sv && (!operandIndicesToPromote.hasValue() || + operandIndicesToPromote->count(en.index()))) + return success(); + } + return failure(); +} + namespace { struct LinalgPromotionPass : public LinalgPromotionBase { LinalgPromotionPass() = default; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 4b1fed8d3427..b6977e01266f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/LoopOps/EDSC/Builders.h" #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" @@ -320,10 +321,9 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, } template -Optional static tileLinalgOpImpl(OpBuilder &b, LinalgOp op, - ArrayRef tileSizes, - ArrayRef permutation, - OperationFolder *folder) { +Optional static tileLinalgOpImpl( + OpBuilder &b, LinalgOp op, ArrayRef tileSizes, + ArrayRef interchangeVector, OperationFolder *folder) { assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); // 1. Enforce the convention that "tiling by zero" skips tiling a particular // dimension. This convention is significantly simpler to handle instead of @@ -342,13 +342,13 @@ Optional static tileLinalgOpImpl(OpBuilder &b, LinalgOp op, return llvm::None; } - // If permutation is empty, use the identity. Build the permutation map + // If interchangeVector is empty, use the identity. Build the permutation map // otherwise. auto invPermutationMap = AffineMap::getMultiDimIdentityMap( tileSizes.size(), ScopedContext::getContext()); - if (!permutation.empty()) - invPermutationMap = inversePermutation( - AffineMap::getPermutationMap(permutation, ScopedContext::getContext())); + if (!interchangeVector.empty()) + invPermutationMap = inversePermutation(AffineMap::getPermutationMap( + interchangeVector, ScopedContext::getContext())); if (!invPermutationMap) return llvm::None; @@ -371,8 +371,8 @@ Optional static tileLinalgOpImpl(OpBuilder &b, LinalgOp op, std::tie(loopRanges, loopIndexToRangeIndex) = makeTiledLoopRanges(b, scope.getLocation(), viewSizesToLoopsMap, viewSizes, tileSizes, folder); - if (!permutation.empty()) - applyPermutationToVector(loopRanges, permutation); + if (!interchangeVector.empty()) + applyPermutationToVector(loopRanges, interchangeVector); // 3. Create the tiled loops. LinalgOp res = op; @@ -393,7 +393,7 @@ Optional static tileLinalgOpImpl(OpBuilder &b, LinalgOp op, // assuming that loopRanges have previously been permuted by // (i,j,k)->(k,i,j) So this permutation should be the inversePermutation of // that one: (d0,d1,d2)->(d2,d0,d1) - if (!permutation.empty()) + if (!interchangeVector.empty()) ivValues = applyMapToValues(b, loc, invPermutationMap, ivValues, folder); auto views = @@ -420,7 +420,8 @@ Optional static tileLinalgOpImpl(OpBuilder &b, LinalgOp op, template static Optional tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef tileSizes, - ArrayRef permutation, OperationFolder *folder) { + ArrayRef interchangeVector, + OperationFolder *folder) { assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); if (tileSizes.empty()) return llvm::None; @@ -459,33 +460,36 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef tileSizes, tileSizeValues.push_back(folded_std_constant_index(folder, 0)); } - return tileLinalgOpImpl(b, op, tileSizeValues, permutation, folder); + return tileLinalgOpImpl(b, op, tileSizeValues, interchangeVector, + folder); } Optional mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, ArrayRef tileSizes, - ArrayRef permutation, + ArrayRef interchangeVector, OperationFolder *folder) { - return tileLinalgOpImpl(b, op, tileSizes, permutation, folder); + return tileLinalgOpImpl(b, op, tileSizes, interchangeVector, + folder); } Optional mlir::linalg::tileLinalgOpToParallelLoops( OpBuilder &b, LinalgOp op, ArrayRef tileSizes, - ArrayRef permutation, OperationFolder *folder) { - return tileLinalgOpImpl(b, op, tileSizes, permutation, + ArrayRef interchangeVector, OperationFolder *folder) { + return tileLinalgOpImpl(b, op, tileSizes, interchangeVector, folder); } Optional mlir::linalg::tileLinalgOp( OpBuilder &b, LinalgOp op, ArrayRef tileSizes, - ArrayRef permutation, OperationFolder *folder) { - return tileLinalgOpImpl(b, op, tileSizes, permutation, folder); + ArrayRef interchangeVector, OperationFolder *folder) { + return tileLinalgOpImpl(b, op, tileSizes, interchangeVector, + folder); } Optional mlir::linalg::tileLinalgOpToParallelLoops( OpBuilder &b, LinalgOp op, ArrayRef tileSizes, - ArrayRef permutation, OperationFolder *folder) { - return tileLinalgOpImpl(b, op, tileSizes, permutation, + ArrayRef interchangeVector, OperationFolder *folder) { + return tileLinalgOpImpl(b, op, tileSizes, interchangeVector, folder); } @@ -496,8 +500,8 @@ static void tileLinalgOps(FuncOp f, ArrayRef tileSizes) { f.walk([tileSizes, &b, &folder](LinalgOp op) { if (!op.hasBufferSemantics()) return; - auto opLoopsPair = - tileLinalgOpImpl(b, op, tileSizes, /*permutation=*/{}, &folder); + auto opLoopsPair = tileLinalgOpImpl( + b, op, tileSizes, /*interchangeVector=*/{}, &folder); // If tiling occurred successfully, erase old op. if (opLoopsPair) op.erase(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp new file mode 100644 index 000000000000..e229b10072f0 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -0,0 +1,228 @@ +//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements logic and helpers to expose Linalg transforms as rewrite +// patterns. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/EDSC/Intrinsics.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include + +#define DEBUG_TYPE "linalg-transforms" + +using namespace mlir; +using namespace mlir::edsc; +using namespace mlir::edsc::intrinsics; +using namespace mlir::linalg; + +using llvm::dbgs; + +//===----------------------------------------------------------------------===// +// Transformations exposed as rewrite patterns. +//===----------------------------------------------------------------------===// +// Marker used as attribute name in generated Linalg rewriting transformations. +const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = + "__internal_linalg_transform__"; + +mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef matchDisjunction, + llvm::Optional replacement) + : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), + replacement(replacement) {} + +mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef matchDisjunction, + StringRef replacement) + : LinalgMarker(matchDisjunction, llvm::Optional{replacement}) {} + +LogicalResult +mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter, + Operation *op) const { + auto attr = op->template getAttrOfType( + LinalgTransforms::kLinalgTransformMarker); + + if (!attr) { + // 1. Has no marker case and matchDisjunction is empty. + if (matchDisjunction.empty()) + return success(); + + // 2. Has no marker and matchDisjuntion matches the no-moarker case. + for (auto marker : matchDisjunction) + if (marker.empty()) + return success(); + + // 3. Has no marker but was expecting a marker. + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag << " does not have any marker from list: "; + llvm::interleaveComma(matchDisjunction, diag); + }); + } + + // 4. Match explicit marker. + for (auto marker : matchDisjunction) + if (attr.getValue() == marker) + return success(); + + // 5. Fail to match. + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag << " does not have any marker from list: "; + llvm::interleaveComma(matchDisjunction, diag); + }); +} + +void mlir::linalg::LinalgMarker::replaceLinalgMarker(PatternRewriter &rewriter, + Operation *op) const { + if (replacement.hasValue()) + op->setAttr(LinalgTransforms::kLinalgTransformMarker, + rewriter.getStringAttr(replacement.getValue())); + else + op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker, + rewriter.getContext())); +} + +/// Linalg base tiling pattern. +mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern( + StringRef opName, MLIRContext *context, LinalgTilingOptions options, + LinalgMarker marker, PatternBenefit benefit) + : RewritePattern(opName, {}, benefit, context), marker(marker), + options(options) {} + +LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewrite( + Operation *op, PatternRewriter &rewriter) const { + LinalgOp linalgOp = dyn_cast(op); + if (!linalgOp) + return failure(); + if (failed(marker.checkAndNotify(rewriter, linalgOp))) + return failure(); + Optional res; + if (options.loopType == LinalgTilingLoopType::Loops) + res = tileLinalgOp(rewriter, linalgOp, options.tileSizes, + options.interchangeVector); + else if (options.loopType == LinalgTilingLoopType::ParallelLoops) + res = tileLinalgOpToParallelLoops(rewriter, linalgOp, options.tileSizes, + options.interchangeVector); + // TODO: Impl tiling to affine loops when it makes sense. + + if (!res) + return failure(); + + // New marker if specified. + marker.replaceLinalgMarker(rewriter, res->op.getOperation()); + + rewriter.eraseOp(op); + return success(); +} + +/// Linalg base interchange pattern. +mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern( + StringRef opName, MLIRContext *context, + ArrayRef interchangeVector, LinalgMarker marker, + PatternBenefit benefit) + : RewritePattern(opName, {}, benefit, context), marker(marker), + interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} + +LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite( + Operation *op, PatternRewriter &rewriter) const { + LinalgOp linalgOp = dyn_cast(op); + if (!linalgOp) + return failure(); + if (failed(marker.checkAndNotify(rewriter, linalgOp))) + return failure(); + if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector))) + return failure(); + + // TODO: figure out how this interplays with named ops. In particular this + // should break the named op property. + rewriter.updateRootInPlace(op, [&]() { + interchange(linalgOp, interchangeVector); + // New marker if specified. + marker.replaceLinalgMarker(rewriter, op); + }); + return success(); +} + +mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern( + StringRef opName, MLIRContext *context, + ArrayRef operandsToPromote, unsigned alignment, + LinalgMarker marker, PatternBenefit benefit) + : RewritePattern(opName, {}, benefit, context), marker(marker), + operandsToPromote(operandsToPromote.begin(), operandsToPromote.end()), + alignment(alignment) {} + +LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite( + Operation *op, PatternRewriter &rewriter) const { + LinalgOp linalgOp = dyn_cast(op); + if (!linalgOp) + return failure(); + if (failed(marker.checkAndNotify(rewriter, linalgOp))) + return failure(); + if (operandsToPromote.empty()) { + if (failed(promoteSubviewsLinalgOpPrecondition(op, llvm::None))) + return failure(); + } else { + DenseSet set; + set.insert(operandsToPromote.begin(), operandsToPromote.end()); + if (failed(promoteSubviewsLinalgOpPrecondition(op, set))) + return failure(); + } + + llvm::SetVector subViews; + if (!operandsToPromote.empty()) { + for (unsigned idx : operandsToPromote) { + auto *op = linalgOp.getBuffer(idx).getDefiningOp(); + if (auto sv = dyn_cast_or_null(op)) + subViews.insert(sv); + } + } else { + unsigned nBuffers = linalgOp.getNumInputsAndOutputBuffers(); + for (unsigned idx = 0; idx < nBuffers; ++idx) { + auto *op = linalgOp.getBuffer(idx).getDefiningOp(); + if (auto sv = dyn_cast_or_null(op)) + subViews.insert(sv); + } + } + + auto promotedOp = + promoteSubViewOperands(rewriter, op, subViews, /*dynamicBuffers=*/false, + /*alignment=*/alignment); + marker.replaceLinalgMarker(rewriter, promotedOp.getOperation()); + rewriter.eraseOp(op); + return success(); +} + +mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern( + StringRef opName, MLIRContext *context, LinalgMarker marker, + PatternBenefit benefit) + : RewritePattern(opName, {}, benefit, context), marker(marker) {} + +LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite( + Operation *op, PatternRewriter &rewriter) const { + LinalgOp linalgOp = dyn_cast(op); + if (!linalgOp) + return failure(); + if (failed(marker.checkAndNotify(rewriter, linalgOp))) + return failure(); + if (failed(vectorizeLinalgOpPrecondition(op))) + return failure(); + vectorizeLinalgOp(rewriter, op); + rewriter.eraseOp(op); + return success(); +} diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp new file mode 100644 index 000000000000..f27baa3c662a --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -0,0 +1,131 @@ +//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the linalg dialect Vectorization transformations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/EDSC/Intrinsics.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace mlir; +using namespace mlir::edsc; +using namespace mlir::edsc::intrinsics; +using namespace mlir::linalg; + +using llvm::dbgs; + +#define DEBUG_TYPE "linalg-vectorization" + +static bool hasMultiplyAddBody(linalg::GenericOp op) { + auto &r = op.region(); + if (!llvm::hasSingleElement(r)) + return false; + if (!llvm::hasNItems(r.front().begin(), r.front().end(), 3)) + return false; + + using mlir::matchers::m_Val; + auto a = m_Val(r.front().getArgument(0)); + auto b = m_Val(r.front().getArgument(1)); + auto c = m_Val(r.front().getArgument(2)); + // TODO: Update this detection once we have matcher support for specifying + // that any permutation of operands matches. + auto pattern1 = m_Op(m_Op(m_Op(a, b), c)); + auto pattern2 = m_Op(m_Op(c, m_Op(a, b))); + auto pattern3 = m_Op(m_Op(m_Op(b, a), c)); + auto pattern4 = m_Op(m_Op(c, m_Op(b, a))); + return pattern1.match(&r.front().back()) || + pattern2.match(&r.front().back()) || + pattern3.match(&r.front().back()) || pattern4.match(&r.front().back()); +} + +// TODO: Should be Tablegen'd from a single source that generates the op itself. +static bool isRowMajorMatmul(linalg::GenericOp genericOp) { + return genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 && + isRowMajorMatmul(genericOp.indexing_maps()) && + hasMultiplyAddBody(genericOp); +} + +// TODO: This is in fact much more general than just vectorization for matmul +// and fill ops. +LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) { + auto linalgOp = cast(op); + // All types must be static shape to go to vector. + for (Value operand : linalgOp.getInputsAndOutputBuffers()) + if (!operand.getType().cast().hasStaticShape()) + return failure(); + for (Type outputTensorType : linalgOp.getOutputTensorTypes()) + if (!outputTensorType.cast().hasStaticShape()) + return failure(); + if (isa(op) || isa(op)) + return success(); + + auto genericOp = dyn_cast(op); + if (!genericOp || !::isRowMajorMatmul(genericOp)) + return failure(); + + // TODO(ntv): non-identity layout. + auto isStaticMemRefWithIdentityLayout = [](Value v) { + auto m = v.getType().dyn_cast(); + if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty()) + return false; + return true; + }; + return success(llvm::all_of(genericOp.getInputsAndOutputBuffers(), + isStaticMemRefWithIdentityLayout)); +} + +void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) { + assert(succeeded(vectorizeLinalgOpPrecondition(op))); + + if (auto convOp = dyn_cast(op)) { + // TODO: add a level of indirection to linalg.generic. + if (convOp.padding()) + llvm_unreachable("Unexpected conv with padding"); + } + + edsc::ScopedContext scope(builder, op->getLoc()); + if (auto fillOp = dyn_cast(op)) { + // Vectorize fill as a vector.broadcast. + LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE + "]: Rewrite linalg.fill as vector.broadcast: " + << *op << ":\n"); + Value memref = vector_type_cast(fillOp.getOutputBuffer(0)); + Value dst = std_load(memref); + Value res = vector_broadcast(dst.getType(), fillOp.value()); + std_store(res, memref); + return; + } + + // Vectorize other ops as vector contraction (currently only matmul). + LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE + "]: Rewrite linalg op as vector.contract: " + << *op << ":\n"); + auto linalgOp = cast(op); + Value a = std_load(vector_type_cast(linalgOp.getInput(0))); + Value b = std_load(vector_type_cast(linalgOp.getInput(1))); + Value memref = vector_type_cast(linalgOp.getOutputBuffer(0)); + Value c = std_load(memref); + Value res = vector_contract(a, b, c, linalgOp.indexing_maps(), + linalgOp.iterator_types()); + std_store(res, memref); +} diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir index d230aa993611..f5ef4fff8165 100644 --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-linalg-transform-patterns | FileCheck %s +// RUN: mlir-opt %s -test-linalg-transform-patterns=test-patterns | FileCheck %s // CHECK-DAG: #[[STRIDED_1D:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> // Map corresponding to a 2D memory access where the stride along the last dim is known to be 1. @@ -22,10 +22,8 @@ func @dot(%x: memref, // CHECK-LABEL: func @dot // CHECK-DAG: %[[c0:.*]] = constant 0 : index // CHECK-DAG: %[[c1:.*]] = constant 1 : index -// CHECK-DAG: %[[c8:.*]] = constant 8 : index // CHECK-DAG: %[[c8000:.*]] = constant 8000 : index // CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8000]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8]] { // CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c1]] { // CHECK: load // CHECK: load @@ -46,8 +44,7 @@ func @matvec(%A: memref, // CHECK-DAG: %[[c0:.*]] = constant 0 : index // CHECK-DAG: %[[c5:.*]] = constant 5 : index // CHECK-DAG: %[[c6:.*]] = constant 6 : index -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c5]] -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c6]] +// CHECK: loop.parallel {{.*}} step (%[[c5]], %[[c6]]) // CHECK: linalg.matvec({{.*}}, {{.*}}, {{.*}}) : memref, memref, memref func @matmul(%A: memref, @@ -86,88 +83,6 @@ func @matmul(%A: memref, // CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4]] { // CHECK: linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref, memref, memref -#some_generic_trait = { - args_in = 1, - args_out = 1, - indexing_maps = [ - affine_map<(i, j) -> (i, j)>, - affine_map<(i, j) -> (i, j)> - ], - iterator_types = ["parallel", "parallel"] -} -func @fusion_test(%A: memref, - %B: memref, - %C: memref, - %D: memref, - %E: memref) { - // This should not be fused as it would violate dependencies. It will get - // tiled for all levels of the memory hierarchy. - linalg.matmul(%A, %A, %C) : memref, - memref, - memref - - // This should be fused. - linalg.matmul(%A, %B, %C) : memref, - memref, - memref - - // This should not be fused or transformed at all since there are no patterns - // on it. However it will be reordered because there are no dependencies. - linalg.generic #some_generic_trait %A, %D { - ^bb(%a: f32, %b: f32) : - linalg.yield %a : f32 - } : memref, - memref - - linalg.matmul(%C, %D, %E) : memref, - memref, - memref - - return -} -// CHECK-LABEL: func @fusion_test -// CHECK-DAG: %[[c0:.*]] = constant 0 : index -// CHECK-DAG: %[[c2:.*]] = constant 2 : index -// CHECK-DAG: %[[c3:.*]] = constant 3 : index -// CHECK-DAG: %[[c4:.*]] = constant 4 : index -// CHECK-DAG: %[[c20:.*]] = constant 20 : index -// CHECK-DAG: %[[c30:.*]] = constant 30 : index -// CHECK-DAG: %[[c40:.*]] = constant 40 : index -// CHECK-DAG: %[[c100:.*]] = constant 100 : index -// CHECK-DAG: %[[c150:.*]] = constant 150 : index -// CHECK-DAG: %[[c200:.*]] = constant 200 : index -// CHECK-DAG: %[[c300:.*]] = constant 300 : index -// CHECK-DAG: %[[c400:.*]] = constant 400 : index -// CHECK-DAG: %[[c2000:.*]] = constant 2000 : index -// CHECK-DAG: %[[c3000:.*]] = constant 3000 : index -// CHECK-DAG: %[[c4000:.*]] = constant 4000 : index -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c200]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c300]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c400]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c20]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c2]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c3]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4]] { -// CHECK: linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref, memref, memref -// -// CHECK: linalg.generic -// -// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c100]] { -// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c150]] { -// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c2]] { -// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c3]] { -// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c4]] { -// CHECK: linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) : memref, memref, memref -// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c2]] { -// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c3]] { -// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c4]] { -// CHECK: linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) : memref, memref, memref - #matmul_trait = { args_in = 2, args_out = 1, @@ -280,23 +195,6 @@ func @permute_generic_indexed( // CHECK-SAME: memref, // CHECK-SAME: memref -func @dot_perm(%x: memref, - %y: memref, - %v: memref) { - linalg.dot(%x, %y, %v) {__internal_linalg_transform__ = "__with_perm__"} : - memref, - memref, - memref - return -} -// CHECK-LABEL: func @dot_perm -// CHECK-DAG: %[[c0:.*]] = constant 0 : index -// CHECK-DAG: %[[c8:.*]] = constant 8 : index -// CHECK-DAG: %[[c8000:.*]] = constant 8000 : index -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8000]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8]] { -// CHECK: linalg.dot({{.*}}, {{.*}}, {{.*}}) : memref, memref, memref - func @matvec_perm(%A: memref, %x: memref, %y: memref) { diff --git a/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt b/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt index 4ca78772b68a..67d194ff868a 100644 --- a/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt +++ b/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt @@ -1,9 +1,3 @@ -set(LLVM_TARGET_DEFINITIONS TestLinalgTransformPatterns.td) -mlir_tablegen(TestLinalgTransformPatterns.h.inc -gen-rewriters) -add_public_tablegen_target(MLIRTestLinalgTransformPatternsIncGen) -# Including Linalg in TableGen requires to depends on generated files -add_dependencies(MLIRTestLinalgTransformPatternsIncGen LinalgOdsGen) - set(LLVM_TARGET_DEFINITIONS TestVectorTransformPatterns.td) mlir_tablegen(TestVectorTransformPatterns.h.inc -gen-rewriters) add_public_tablegen_target(MLIRTestVectorTransformPatternsIncGen) diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td deleted file mode 100644 index 313e2f8171a8..000000000000 --- a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td +++ /dev/null @@ -1,168 +0,0 @@ -//===- TestLinalgTransformPatterns.td - Test patterns --*- tablegen ----*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This is the pattern definition file for declarative Linalg transformations -// tests. -// -//===----------------------------------------------------------------------===// - -#ifndef TEST_LINALG_TRANSFORMS_PATTERNS -#define TEST_LINALG_TRANSFORMS_PATTERNS - -include "mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td" - -//===----------------------------------------------------------------------===// -// Test Linalg fusion patterns. -//===----------------------------------------------------------------------===// -def : Pat<(MatmulOp:$op $A, $_, $_), - (TileAndFuseLinalgOp<[100, 150], [0], "L1">), - [ - (Constraint), - (Constraint> $A), - ], - // In the buffer world there is no use-def chains or dags so benefits - // cannot be computed automatically from the length of the matched - // pattern. Instead we specify the benefit ourselves for now. - // This is not expected to be a big challenge long-term because - // pattern benefits are akin to feature engineering: features should - // be learned. - (addBenefit 1)>; - -//===----------------------------------------------------------------------===// -// Linalg tiling patterns. -//===----------------------------------------------------------------------===// -def : Pat<(MatmulOp:$op $_, $_, $_), - (TileLinalgOp<[2000, 3000, 4000], "L3">), - [(Constraint]>>)]>; -def : Pat<(MatmulOp:$op $_, $_, $_), - (TileLinalgOp<[200, 300, 400], "L2">), - [(Constraint>)]>; -def : Pat<(MatmulOp:$op $_, $_, $_), - (TileLinalgOp<[20, 30, 40], "L1">), - [(Constraint>)]>; -def : Pat<(MatmulOp:$op $_, $_, $_), - (TileLinalgOp<[2, 3, 4], "REG">), - [(Constraint>)]>; - -def : Pattern<(MatvecOp:$op $_, $_, $_), - [(TileLinalgOp<[5, 6], "L1">)], - [(Constraint)]>; - -def : Pattern<(DotOp:$op $_, $_, $_), - [(TileLinalgOp<[8000], "L1">)], - [(Constraint, - HasLinalgTransformMarker<"L3">, - HasLinalgTransformMarker<"L2">]>>)]>; -def : Pattern<(DotOp:$op $_, $_, $_), - [(TileLinalgOp<[8], "REG">)], - [(Constraint>)]>; - -//===----------------------------------------------------------------------===// -// Linalg tiling and permutation patterns. -//===----------------------------------------------------------------------===// -def : Pat<(MatmulOp:$op $_, $_, $_), - (TileLinalgOp<[2000, 3000, 4000], "L2__with_perm__", [1,2,0]>), - [(Constraint>)]>; -def : Pat<(MatmulOp:$op $_, $_, $_), - (TileLinalgOp<[200, 300, 400], "L1__with_perm__", [1,0,2]>), - [(Constraint>)]>; -def : Pat<(MatmulOp:$op $_, $_, $_), - (TileLinalgOp<[20, 30, 40], "REG__with_perm__">), - [(Constraint>)]>; - - -def : Pattern<(MatvecOp:$op $_, $_, $_), - [(TileLinalgOp<[5, 6], "L1__with_perm__", [1,0]>)], - [(Constraint>)]>; - -def : Pattern<(DotOp:$op $_, $_, $_), - [(TileLinalgOp<[8000], "L1__with_perm__">)], - [(Constraint>)]>; -def : Pattern<(DotOp:$op $_, $_, $_), - [(TileLinalgOp<[8], "REG__with_perm__">)], - [(Constraint>)]>; - -//===----------------------------------------------------------------------===// -// Linalg to loops patterns. -//===----------------------------------------------------------------------===// -def : Pattern<(DotOp:$op $_, $_, $_), - [(LinalgOpToLoops<"DotOp">)], - [(Constraint>)]>; - -//===----------------------------------------------------------------------===// -// Linalg to vector contraction patterns. -//===----------------------------------------------------------------------===// -def : Pattern<(MatmulOp:$op $_, $_, $_), - [(VectorizeLinalgOp)], - [(Constraint, - PreconditionVectorizeLinalgOp - ]>>)]>; -def : Pattern<(FillOp:$op $_, $_), - [(VectorizeLinalgOp)], - [(Constraint, - PreconditionVectorizeLinalgOp - ]>>)]>; -def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_), - [(VectorizeLinalgOp)], - [(Constraint, - PreconditionVectorizeLinalgOp - ]>>)]>; - - -//===----------------------------------------------------------------------===// -// Linalg generic permutation patterns. -//===----------------------------------------------------------------------===// -def : Pat<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_), - (PermuteGenericLinalgOp<[1, 2, 0], "PERMUTE"> $op), - [(Constraint, - PreconditionPermuteGenericLinalgOp<[1, 2, 0]> - ]>>)]>; - -def : Pat<(IndexedGenericOp:$op $_, $_, $_, $_, $_, $_, $_), - (PermuteGenericLinalgOp<[1, 2, 0], "PERMUTE"> $op), - [(Constraint, - PreconditionPermuteGenericLinalgOp<[1, 2, 0]> - ]>>)]>; - -//===----------------------------------------------------------------------===// -// Linalg subview operands promotion. -//===----------------------------------------------------------------------===// -def : Pat<(MatmulOp:$op $_, $_, $_), - (PromoteSubviewsLinalgOp), - [(Constraint, - HasLinalgTransformMarker<"_promote_views_">]>> - )]>; - -def : Pat<(MatmulOp:$op $_, $_, $_), - (PromoteSelectedSubviewsLinalgOp<[0], "first_view_promotion">), - [(Constraint, - HasLinalgTransformMarker<"_promote_first_view_">]>> - )]>; - -def : Pat<(FillOp:$op $_, $_), - (PromoteSelectedSubviewsLinalgOp<[0], "aligned_promotion", 32>), - [(Constraint, - HasLinalgTransformMarker<"_promote_views_aligned_">]>> - )]>; - -#endif // TEST_LINALG_TRANSFORMS_PATTERNS diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt index 0e0e15fb2a93..33129a9a9e0b 100644 --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -25,7 +25,6 @@ add_llvm_library(MLIRTestTransforms DEPENDS MLIRStandardOpsIncGen - MLIRTestLinalgTransformPatternsIncGen MLIRTestVectorTransformPatternsIncGen ) diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp index 7fc1138ff8d4..f3861c38fa60 100644 --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -10,36 +10,127 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" -#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" +#include "llvm/ADT/SetVector.h" + using namespace mlir; using namespace mlir::linalg; -namespace mlir { -namespace linalg { -namespace { -#include "TestLinalgTransformPatterns.h.inc" -} // end namespace -} // end namespace linalg -} // end namespace mlir - namespace { struct TestLinalgTransforms : public PassWrapper { + TestLinalgTransforms() = default; + TestLinalgTransforms(const TestLinalgTransforms &pass) {} + void runOnFunction() override; + + Option testPatterns{*this, "test-patterns", + llvm::cl::desc("Test a mixed set of patterns"), + llvm::cl::init(false)}; }; } // end anonymous namespace -/// Apply transformations specified as patterns. -void TestLinalgTransforms::runOnFunction() { +static void applyPatterns(FuncOp funcOp) { + MLIRContext *ctx = funcOp.getContext(); OwningRewritePatternList patterns; - auto funcOp = getFunction(); - // Add the generated patterns to the list. - linalg::populateWithGenerated(&getContext(), &patterns); + //===--------------------------------------------------------------------===// + // Linalg tiling patterns. + //===--------------------------------------------------------------------===// + patterns.insert>( + ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}), + LinalgMarker({"MEM", {}}, "L3")); + patterns.insert>( + ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}), + LinalgMarker({"L3"}, "L2")); + patterns.insert>( + ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), + LinalgMarker({"L2"}, "L1")); + patterns.insert>( + ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}), + LinalgMarker({"L1"}, "REG")); + + patterns.insert>( + ctx, + LinalgTilingOptions().setTileSizes({5, 6}).setLoopType( + LinalgTilingLoopType::ParallelLoops), + LinalgMarker({}, "L1")); + + patterns.insert>( + ctx, LinalgTilingOptions().setTileSizes(8000), + LinalgMarker({"MEM", "L3", "L2", {}}, "REG")); + + //===--------------------------------------------------------------------===// + // Linalg tiling and permutation patterns. + //===--------------------------------------------------------------------===// + patterns.insert>( + ctx, + LinalgTilingOptions() + .setTileSizes({2000, 3000, 4000}) + .setInterchange({1, 2, 0}), + LinalgMarker({"__with_perm__"}, "L2__with_perm__")); + patterns.insert>( + ctx, + LinalgTilingOptions() + .setTileSizes({200, 300, 400}) + .setInterchange({1, 0, 2}), + LinalgMarker({"L2__with_perm__"}, "L1__with_perm__")); + patterns.insert>( + ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), + LinalgMarker({"L1__with_perm__"}, "REG__with_perm__")); + + patterns.insert>( + ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}), + LinalgMarker({"__with_perm__"}, "L1__with_perm__")); + + //===--------------------------------------------------------------------===// + // Linalg to loops patterns. + //===--------------------------------------------------------------------===// + patterns.insert>( + ctx, + /*loweringType=*/LinalgLoweringType::Loops, LinalgMarker({"REG"})); + + //===--------------------------------------------------------------------===// + // Linalg to vector contraction patterns. + //===--------------------------------------------------------------------===// + patterns.insert, + LinalgVectorizationPattern, + LinalgVectorizationPattern>( + ctx, LinalgMarker({"VECTORIZE"})); + + //===--------------------------------------------------------------------===// + // Linalg generic permutation patterns. + //===--------------------------------------------------------------------===// + patterns.insert>( + ctx, + /*interchangeVector=*/ArrayRef{1, 2, 0}, + LinalgMarker({}, "PERMUTED")); + patterns.insert>( + ctx, + /*interchangeVector=*/ArrayRef{1, 2, 0}, + LinalgMarker({}, "PERMUTED")); + + //===--------------------------------------------------------------------===// + // Linalg subview operands promotion. + //===--------------------------------------------------------------------===// + patterns.insert>( + ctx, LinalgMarker({"_promote_views_"}, "_views_promoted_")); + patterns.insert>( + ctx, + /*operandsToPromote=*/ArrayRef{0}, + LinalgMarker({"_promote_first_view_"}, "_first_view_promoted_")); + patterns.insert>( + ctx, + /*operandsToPromote=*/ArrayRef{0}, + /*alignment=*/32, + LinalgMarker({"_promote_views_aligned_"}, "_views_aligned_promoted_")); + applyPatternsAndFoldGreedily(funcOp, patterns); // Drop the marker. @@ -48,9 +139,15 @@ void TestLinalgTransforms::runOnFunction() { }); } +/// Apply transformations specified as patterns. +void TestLinalgTransforms::runOnFunction() { + if (testPatterns) + return applyPatterns(getFunction()); +} + namespace mlir { void registerTestLinalgTransforms() { - PassRegistration( + PassRegistration testTransformPatternsPass( "test-linalg-transform-patterns", "Test Linalg transformation patterns by applying them greedily."); }