forked from OSchip/llvm-project
[mlir][Linalg] Mostly NFC - Refactor Linalg patterns and transformations.
Linalg transformations are currently exposed as DRRs. Unfortunately RewriterGen does not play well with the line of work on named linalg ops which require variadic operands and results. Additionally, DRR is arguably not the right abstraction to expose compositions of such patterns that don't rely on SSA use-def semantics. This revision abandons DRRs and exposes manually written C++ patterns. Refactorings and cleanups are performed to uniformize APIs. This refactoring will allow replacing the currently manually specified Linalg named ops. A collateral victim of this refactoring is the `tileAndFuse` DRR, and the one associated test, which will be revived at a later time. Lastly, the following 2 tests do not add value and are altered: - a dot_perm tile + interchange test does not test anything new and is removed - a dot tile + lower to loops does not need 2-D tiling and is trimmed.
This commit is contained in:
parent
c49f83b6e9
commit
307cfdf533
|
@ -1,5 +1,4 @@
|
|||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||
|
|
|
@ -1,7 +0,0 @@
|
|||
set(LLVM_TARGET_DEFINITIONS LinalgTransformPatterns.td)
|
||||
mlir_tablegen(LinalgTransformPatterns.h.inc -gen-rewriters)
|
||||
add_public_tablegen_target(MLIRLinalgTransformPatternsIncGen)
|
||||
|
||||
# Including Linalg in TableGen requires to depends on generated files
|
||||
add_dependencies(MLIRLinalgTransformPatternsIncGen LinalgOdsGen)
|
||||
|
|
@ -1,123 +0,0 @@
|
|||
//===- LinalgPatterns.td - Linalg transformation patterns --*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This is the pattern definition file for declarative Linalg transformation.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef LINALG_TRANSFORMS
|
||||
#define LINALG_TRANSFORMS
|
||||
|
||||
include "mlir/Dialect/Linalg/IR/LinalgOps.td"
|
||||
include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td"
|
||||
include "mlir/Dialect/Affine/IR/AffineOps.td"
|
||||
|
||||
def HasNoLinalgTransformMarker : CPred<[{
|
||||
!op.getAttrOfType<StringAttr>(LinalgTransforms::kLinalgTransformMarker)
|
||||
}]>;
|
||||
|
||||
class HasLinalgTransformMarker<string str> : CPred<[{
|
||||
op.getAttrOfType<StringAttr>(
|
||||
LinalgTransforms::kLinalgTransformMarker) &&
|
||||
op.getAttrOfType<StringAttr>(
|
||||
LinalgTransforms::kLinalgTransformMarker).getValue() == "}] # str # [{"}]>;
|
||||
|
||||
class IsProducedByOpOfType<string str> :
|
||||
CPred<"isProducedByOpOfType<" # str # ">(op, $0)">;
|
||||
|
||||
class AffineMapDomainHasDim<int n> : CPred<[{
|
||||
op.getAttrOfType<ArrayAttr>(getIndexingMapsAttrName()).getValue()[0].
|
||||
cast<AffineMapAttr>().getValue().getNumDims() ==}] # n # [{}]>;
|
||||
|
||||
class HasOperandsOfType<string type>: CPred<[{
|
||||
llvm::any_of(op.getOperands(),
|
||||
[](Value v) {
|
||||
return dyn_cast_or_null<}] # type # [{>(v.getDefiningOp());
|
||||
})
|
||||
}]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg fusion patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// In the future, tile sizes should be derived from op properties + machine
|
||||
// description but we do not need to wait on this to start having useful
|
||||
// patterns.
|
||||
class TileAndFuseLinalgOp<
|
||||
list<int> sizes, list<int> operandIndices, string value> : NativeCodeCall<
|
||||
"if (failed(tileAndFuseLinalgOpAndSetMarker($_builder, op, {" #
|
||||
StrJoinInt<sizes>.result # "}, {" # StrJoinInt<operandIndices>.result # "}," #
|
||||
" \"" # value # "\")))" #
|
||||
" return failure();">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg tiling patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// In the future, tile sizes should be derived from op properties + machine
|
||||
// description but we do not need to wait on this to start having useful
|
||||
// patterns.
|
||||
// `permutation` is an optional parameter to specify the ordering of the
|
||||
// tiled loops. If provided, it must be a list of integers with the same number
|
||||
// of elements as `sizes`.
|
||||
class TileLinalgOp<list<int> sizes, string value, list<int> permutation=[]> :
|
||||
NativeCodeCall<
|
||||
"if (failed(tileLinalgOpAndSetMarker($_builder, op, {" #
|
||||
StrJoinInt<sizes>.result # "}, \"" # value # "\", {" #
|
||||
StrJoinInt<permutation>.result # "})))" #
|
||||
" return failure();">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg to loop patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
class LinalgOpToLoops<string OpType> : NativeCodeCall<
|
||||
"if (failed(linalgOpToLoops<" # OpType # ">($_builder, op))) " #
|
||||
" return failure();">;
|
||||
|
||||
class LinalgOpToParallelLoops<string OpType> : NativeCodeCall<
|
||||
"if (failed(linalgOpToParallelLoops<" # OpType # ">($_builder, op))) " #
|
||||
" return failure();">;
|
||||
|
||||
class LinalgOpToAffineLoops<string OpType> : NativeCodeCall<
|
||||
"if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, op))) " #
|
||||
" return failure();">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg to vector patterns precondition and DRR.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def PreconditionVectorizeLinalgOp : CPred<
|
||||
"succeeded(vectorizeLinalgOpPrecondition(op))">;
|
||||
def VectorizeLinalgOp : NativeCodeCall<
|
||||
"vectorizeLinalgOp($_builder, op)">;
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg generic permutation patterns precondition and DRR.
|
||||
//===----------------------------------------------------------------------===//
|
||||
class PreconditionPermuteGenericLinalgOp<list<int> permutation> : CPred<
|
||||
"succeeded(permuteGenericLinalgOpPrecondition(op, {" #
|
||||
StrJoinInt<permutation>.result # "}))">;
|
||||
class PermuteGenericLinalgOp<list<int> permutation, string value> :
|
||||
NativeCodeCall<
|
||||
"permuteGenericLinalgOp($_builder, op, {" # StrJoinInt<permutation>.result #
|
||||
"}, \"" # value # "\")">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg promote subview operands precondition and DRR.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def PreconditionPromoteSubviewsLinalgOp : CPred<
|
||||
"succeeded(promoteSubviewsLinalgOpPrecondition(op))">;
|
||||
def PromoteSubviewsLinalgOp : NativeCodeCall<
|
||||
"promoteSubviewsLinalgOp($_builder, op)">;
|
||||
|
||||
class PromoteSelectedSubviewsLinalgOp<list<int> operands, string marker="",
|
||||
int alignment=0> :
|
||||
NativeCodeCall<"promoteSelectedSubviewsLinalgOpAndSetMarker($_builder, op, {" #
|
||||
StrJoinInt<operands>.result # "}, \"" # marker # "\", " # alignment # ")">;
|
||||
|
||||
#endif // LINALG_TRANSFORMS
|
|
@ -1,137 +0,0 @@
|
|||
//===- LinalgTransforms.h - Linalg transformations as patterns --*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef DIALECT_LINALG_TRANSFORMS_LINALGTRANSFORMS_H_
|
||||
#define DIALECT_LINALG_TRANSFORMS_LINALGTRANSFORMS_H_
|
||||
|
||||
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/Passes.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace linalg {
|
||||
|
||||
// Marker used as attribute name in generated Linalg rewriting transformations.
|
||||
struct LinalgTransforms {
|
||||
static const StringLiteral kLinalgTransformMarker;
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
// Implementation detail of isProducedByOpOfType avoids the need for explicit
|
||||
// template instantiations.
|
||||
bool isProducedByOpOfTypeImpl(Operation *consumerOp, Value consumedView,
|
||||
function_ref<bool(Operation *)> isaOpType);
|
||||
} // namespace detail
|
||||
|
||||
// Returns true if the `consumedView` value use in `consumerOp` is produced by
|
||||
// an op of type `OpTy`. This is used to implement use-def type information on
|
||||
// buffers.
|
||||
template <typename OpTy>
|
||||
bool isProducedByOpOfType(Operation *consumerOp, Value consumedView) {
|
||||
return detail::isProducedByOpOfTypeImpl(
|
||||
consumerOp, consumedView, [](Operation *op) { return isa<OpTy>(op); });
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// The following Declarative Rewrite Rule (DRR) helpers are used in rewrite
|
||||
// patterns. As such, they must not call into `rewriter.erase/replace` APIs and
|
||||
// it is the responsibility of the enclosing PatternRewriter to erase on
|
||||
// success.
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Tiles `op` by `sizes` permuting the loops according to `permutation` and
|
||||
/// sets the attribute `kLinalgTransformMarker` to `linalgMarker`. The
|
||||
/// permutation is expressed as a list of integers that specify the new ordering
|
||||
/// of the loop nest (using loop.for operations). The length of `permutation`
|
||||
/// must be equal to the length of `tileSizes`.
|
||||
/// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with
|
||||
/// `permutation = [1,2,0]`. All values in `permutation` must be
|
||||
/// integers, in the range 0..`tileSizes.size()` without duplications
|
||||
/// (i.e. `[1,1,2]` is an invalid permutation). An empty list
|
||||
/// states for the identity permutation.
|
||||
LogicalResult tileLinalgOpAndSetMarker(PatternRewriter &rewriter, Operation *op,
|
||||
ArrayRef<int64_t> sizes,
|
||||
StringRef linalgMarker,
|
||||
ArrayRef<unsigned> permutation);
|
||||
|
||||
/// Tiles ops similar to `tileLinalgOpAndSetMarker` but generates loop.parallel
|
||||
/// operations instead.
|
||||
LogicalResult tileLinalgOpToParallelLoopsAndSetMarker(
|
||||
PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
|
||||
StringRef linalgMarker, ArrayRef<unsigned> permutation);
|
||||
|
||||
/// Tiles `op` by `sizes`, fuses the producers of `operandIndicesToFuse` and
|
||||
/// sets the attribute `kLinalgTransformMarker` to `linalgMarker`.
|
||||
LogicalResult tileAndFuseLinalgOpAndSetMarker(
|
||||
PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
|
||||
ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker);
|
||||
|
||||
/// Tiles ops similar to `tileAndFuseLinalgOpAndSetMarker` but generates
|
||||
/// loop.parallel operations instead.
|
||||
LogicalResult tileAndFuseLinalgOpToParallelLoopsAndSetMarker(
|
||||
PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
|
||||
ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker);
|
||||
|
||||
using LinalgLoops = SmallVector<Operation *, 4>;
|
||||
|
||||
/// Emits a loop nest of with the proper body for `op`.
|
||||
template <typename LoopTy, typename ConcreteOp>
|
||||
Optional<LinalgLoops> linalgLowerOpToLoops(PatternRewriter &rewriter,
|
||||
Operation *op);
|
||||
|
||||
/// Emits a loop nest of `loop.for` with the proper body for `op`.
|
||||
template <typename ConcreteOp>
|
||||
LogicalResult linalgOpToLoops(PatternRewriter &rewriter, Operation *op);
|
||||
|
||||
/// Emits a loop nest of `loop.parallel` with the proper body for `op`.
|
||||
template <typename ConcreteOp>
|
||||
LogicalResult linalgOpToParallelLoops(PatternRewriter &rewriter, Operation *op);
|
||||
|
||||
/// Emits a loop nest of `affine.for` with the proper body for `op`.
|
||||
template <typename ConcreteOp>
|
||||
LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, Operation *op);
|
||||
|
||||
/// Rewrite a linalg.generic into a suitable vector.contraction op.
|
||||
LogicalResult vectorizeLinalgOpPrecondition(Operation *op);
|
||||
SmallVector<Value, 0> vectorizeLinalgOp(PatternRewriter &rewriter,
|
||||
Operation *op);
|
||||
|
||||
/// Emits a `generic` or `indexed_generic` operation with the `indexing_maps`
|
||||
/// and `iterator_types` permutated according to `permutation`.
|
||||
LogicalResult
|
||||
permuteGenericLinalgOpPrecondition(Operation *op,
|
||||
ArrayRef<unsigned> permutation);
|
||||
SmallVector<Value, 0> permuteGenericLinalgOp(PatternRewriter &rewriter,
|
||||
Operation *op,
|
||||
ArrayRef<unsigned> permutation,
|
||||
StringRef linalgMarker);
|
||||
|
||||
/// Promote std.subviews feeding linalg operations.
|
||||
LogicalResult promoteSubviewsLinalgOpPrecondition(Operation *op);
|
||||
SmallVector<Value, 0> promoteSubviewsLinalgOp(PatternRewriter &rewriter,
|
||||
Operation *op);
|
||||
|
||||
/// Similar to `promoteSubviewsLinalgOp` but only tries to promote
|
||||
/// the views corresponding to the operands specified in
|
||||
/// `operandIndicesToPromote`. Generated allocations are memory-aligned
|
||||
/// according to the `alignment` parameter.
|
||||
/// If linalgMarker is specified and the transformation is successfull
|
||||
/// sets the attribute `kLinalgTransformMarker` to `linalgMarker`.
|
||||
SmallVector<Value, 0> promoteSelectedSubviewsLinalgOpAndSetMarker(
|
||||
PatternRewriter &rewriter, Operation *op,
|
||||
ArrayRef<int64_t> operandIndicesToPromote, StringRef linalgMarker = "",
|
||||
int64_t alignment = 0);
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
||||
#endif // DIALECT_LINALG_TRANSFORMS_LINALGTRANSFORMS_H_
|
|
@ -0,0 +1,375 @@
|
|||
//===- Transforms.h - Linalg transformations as patterns --------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_
|
||||
#define DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_
|
||||
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace linalg {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Transformations exposed as function calls.
|
||||
//===----------------------------------------------------------------------===//
|
||||
using LinalgLoops = SmallVector<Operation *, 4>;
|
||||
|
||||
struct TiledLinalgOp {
|
||||
LinalgOp op;
|
||||
SmallVector<Operation *, 8> loops;
|
||||
};
|
||||
|
||||
/// Performs standalone tiling of a single LinalgOp by `tileSizes`.
|
||||
/// and permute the loop nest according to `interchangeVector`
|
||||
/// The permutation is expressed as a list of integers that specify
|
||||
/// the new ordering of the loop nest. The length of `interchangeVector`
|
||||
/// must be equal to the length of `tileSizes`.
|
||||
/// An empty vector is interpreted as the identity permutation and the
|
||||
/// transformation returns early.
|
||||
///
|
||||
/// When non-null, the optional pointer `folder` is used to call into the
|
||||
/// `createAndFold` builder method. If `folder` is null, the regular `create`
|
||||
/// method is called.
|
||||
///
|
||||
/// Returns a struct containing the tiled loops in the specified order
|
||||
/// and the cloned op if successful, llvm::None otherwise.
|
||||
///
|
||||
/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed by
|
||||
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
|
||||
/// integers, in the range 0..`tileSizes.size()` without duplications
|
||||
/// (i.e. `[1,1,2]` is an invalid permutation).
|
||||
Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
|
||||
ArrayRef<Value> tileSizes,
|
||||
ArrayRef<unsigned> interchangeVector = {},
|
||||
OperationFolder *folder = nullptr);
|
||||
Optional<TiledLinalgOp>
|
||||
tileLinalgOpToParallelLoops(OpBuilder &b, LinalgOp op,
|
||||
ArrayRef<Value> tileSizes,
|
||||
ArrayRef<unsigned> interchangeVector = {},
|
||||
OperationFolder *folder = nullptr);
|
||||
|
||||
/// Performs standalone tiling of a single LinalgOp by constant `tileSizes`.
|
||||
/// See `tileLinalgOp(... ArrayRef<Value> tileSizes,)` for more details
|
||||
Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
|
||||
ArrayRef<int64_t> tileSizes,
|
||||
ArrayRef<unsigned> interchangeVector = {},
|
||||
OperationFolder *folder = nullptr);
|
||||
Optional<TiledLinalgOp>
|
||||
tileLinalgOpToParallelLoops(OpBuilder &b, LinalgOp op,
|
||||
ArrayRef<int64_t> tileSizes,
|
||||
ArrayRef<unsigned> interchangeVector = {},
|
||||
OperationFolder *folder = nullptr);
|
||||
|
||||
/// Interchanges the `iterator_types` and `iterator_maps` dimensions of `op`.
|
||||
/// This is an in-place transformation controlled by `interchangeVector`.
|
||||
/// An empty vector is interpreted as the identity permutation and the
|
||||
/// transformation returns early.
|
||||
///
|
||||
/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed with
|
||||
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
|
||||
/// integers, in the range 0..`op.rank` without duplications
|
||||
/// (i.e. `[1,1,2]` is an invalid permutation).
|
||||
LinalgOp interchange(LinalgOp op, ArrayRef<unsigned> interchangeVector);
|
||||
|
||||
/// Promotes the `subViews` into a new buffer allocated at the insertion point
|
||||
/// `b`. Promotion occurs in 3 steps:
|
||||
/// 1. Create a new buffer for a full tile (i.e. not clipped at the boundary).
|
||||
/// 2. Take a full view on the buffer and `linalg.fill` it with zeros (use
|
||||
/// float zero for now).
|
||||
/// 3. Take a partial slice of the full view in step 2. and copy into it.
|
||||
/// Infers statically sized buffers from subViews unless `dynamicBuffers` is
|
||||
/// true.
|
||||
///
|
||||
/// Returns a list of PromotionInfo which hold the promoted buffer and the
|
||||
/// full and partial views indexing into the buffer.
|
||||
// TODO: revisit dynamicBuffers option.
|
||||
LinalgOp promoteSubViewOperands(OpBuilder &b, LinalgOp op,
|
||||
llvm::SetVector<Value> subViews,
|
||||
bool dynamicBuffers = false,
|
||||
int64_t alignment = 0,
|
||||
OperationFolder *folder = nullptr);
|
||||
|
||||
/// Emit a suitable vector form for a Linalg op with fully static shape.
|
||||
void vectorizeLinalgOp(OpBuilder &builder, Operation *op);
|
||||
|
||||
/// Emits a loop nest of `LoopTy` with the proper body for `op`.
|
||||
template <typename LoopTy, typename ConcreteOp>
|
||||
Optional<LinalgLoops> linalgLowerOpToLoops(OpBuilder &builder, Operation *op);
|
||||
|
||||
/// Emits a loop nest of `loop.for` with the proper body for `op`.
|
||||
template <typename ConcreteOp>
|
||||
LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op);
|
||||
|
||||
/// Emits a loop nest of `loop.parallel` with the proper body for `op`.
|
||||
template <typename ConcreteOp>
|
||||
LogicalResult linalgOpToParallelLoops(OpBuilder &builder, Operation *op);
|
||||
|
||||
/// Emits a loop nest of `affine.for` with the proper body for `op`.
|
||||
template <typename ConcreteOp>
|
||||
LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Preconditions that ensure the corresponding transformation suceeds and can be
|
||||
// applied as a rewrite pattern.
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Emits a `generic` or `indexed_generic` operation with the `indexing_maps`
|
||||
/// and `iterator_types` permutated according to `permutation`.
|
||||
LogicalResult
|
||||
interchangeGenericLinalgOpPrecondition(Operation *op,
|
||||
ArrayRef<unsigned> interchangeVector);
|
||||
|
||||
/// Promote std.subviews feeding linalg operations.
|
||||
LogicalResult promoteSubviewsLinalgOpPrecondition(
|
||||
Operation *op, Optional<DenseSet<unsigned>> operandIndicesToPromote = None);
|
||||
|
||||
/// Rewrite a linalg.generic into a suitable vector.contraction op.
|
||||
LogicalResult vectorizeLinalgOpPrecondition(Operation *op);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Transformations exposed as rewrite patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Marker used as attribute name in generated Linalg rewriting transformations.
|
||||
struct LinalgTransforms {
|
||||
static const StringLiteral kLinalgTransformMarker;
|
||||
};
|
||||
|
||||
/// Helper class to control common attribute matching and setting behavior.
|
||||
struct LinalgMarker {
|
||||
LinalgMarker(ArrayRef<StringRef> matchDisjunction = {},
|
||||
Optional<StringRef> replacement = None);
|
||||
LinalgMarker(ArrayRef<StringRef> matchDisjunction, StringRef replacement);
|
||||
LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
|
||||
void replaceLinalgMarker(PatternRewriter &rewriter, Operation *op) const;
|
||||
|
||||
private:
|
||||
SmallVector<StringRef, 4> matchDisjunction;
|
||||
Optional<StringRef> replacement;
|
||||
};
|
||||
|
||||
///
|
||||
/// Linalg tiling patterns.
|
||||
///
|
||||
/// Apply the `tileLinalgOp` transformation as a pattern.
|
||||
/// `marker` controls LinalgTransformMarker matching and update when specified.
|
||||
/// See `tileLinalgOp` for more details.
|
||||
enum class LinalgTilingLoopType {
|
||||
Loops = 0,
|
||||
AffineLoops = 1,
|
||||
ParallelLoops = 2
|
||||
};
|
||||
struct LinalgTilingOptions {
|
||||
/// The tile sizes by which to tile.
|
||||
SmallVector<int64_t, 4> tileSizes{};
|
||||
LinalgTilingOptions &setTileSizes(ArrayRef<int64_t> ts) {
|
||||
tileSizes.assign(ts.begin(), ts.end());
|
||||
return *this;
|
||||
}
|
||||
/// The interchange vector to reorder the tiled loops.
|
||||
SmallVector<unsigned, 4> interchangeVector{};
|
||||
LinalgTilingOptions &setInterchange(ArrayRef<unsigned> interchange) {
|
||||
interchangeVector.assign(interchange.begin(), interchange.end());
|
||||
return *this;
|
||||
}
|
||||
/// The type of tile loops to generate.
|
||||
LinalgTilingLoopType loopType{LinalgTilingLoopType::Loops};
|
||||
LinalgTilingOptions &setLoopType(LinalgTilingLoopType lt) {
|
||||
loopType = lt;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
struct LinalgBaseTilingPattern : public RewritePattern {
|
||||
LinalgBaseTilingPattern(StringRef opName, MLIRContext *context,
|
||||
LinalgTilingOptions options,
|
||||
LinalgMarker marker = LinalgMarker(),
|
||||
PatternBenefit benefit = 1);
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
|
||||
private:
|
||||
/// LinalgTransformMarker handles special attribute manipulations.
|
||||
LinalgMarker marker;
|
||||
/// Options to control tiling;
|
||||
LinalgTilingOptions options;
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
struct LinalgTilingPattern : public LinalgBaseTilingPattern {
|
||||
LinalgTilingPattern(MLIRContext *context, LinalgTilingOptions options,
|
||||
LinalgMarker marker = LinalgMarker(),
|
||||
PatternBenefit benefit = 1)
|
||||
: LinalgBaseTilingPattern(OpTy::getOperationName(), context, options,
|
||||
marker, benefit) {}
|
||||
};
|
||||
|
||||
///
|
||||
/// Linalg interchange patterns.
|
||||
///
|
||||
/// Apply the `interchange` transformation as a pattern.
|
||||
/// `marker` controls LinalgTransformMarker matching and update when specified.
|
||||
/// See `interchange` for more details.
|
||||
struct LinalgBaseInterchangePattern : public RewritePattern {
|
||||
LinalgBaseInterchangePattern(StringRef opName, MLIRContext *context,
|
||||
ArrayRef<unsigned> interchangeVector,
|
||||
LinalgMarker marker = LinalgMarker(),
|
||||
PatternBenefit benefit = 1);
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
|
||||
private:
|
||||
/// LinalgTransformMarker handles special attribute manipulations.
|
||||
LinalgMarker marker;
|
||||
/// The interchange vector to reorder the iterators and indexing_maps dims.
|
||||
SmallVector<unsigned, 8> interchangeVector;
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
struct LinalgInterchangePattern : public LinalgBaseInterchangePattern {
|
||||
LinalgInterchangePattern(MLIRContext *context,
|
||||
ArrayRef<unsigned> interchangeVector,
|
||||
LinalgMarker marker = LinalgMarker(),
|
||||
PatternBenefit benefit = 1)
|
||||
: LinalgBaseInterchangePattern(OpTy::getOperationName(), context,
|
||||
interchangeVector, marker, benefit) {}
|
||||
};
|
||||
|
||||
///
|
||||
/// Linalg promotion patterns.
|
||||
///
|
||||
/// Apply the `promoteSubViewOperands` transformation as a pattern.
|
||||
/// `marker` controls LinalgTransformMarker matching and update when specified.
|
||||
/// See `promoteSubViewOperands` for more details.
|
||||
struct LinalgBasePromotionPattern : public RewritePattern {
|
||||
LinalgBasePromotionPattern(StringRef opName, MLIRContext *context,
|
||||
ArrayRef<unsigned> operandsToPromote = {},
|
||||
unsigned alignment = 0,
|
||||
LinalgMarker marker = LinalgMarker(),
|
||||
PatternBenefit benefit = 1);
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
|
||||
private:
|
||||
/// LinalgTransformMarker handles special attribute manipulations.
|
||||
LinalgMarker marker;
|
||||
/// Indices of subViews to promote.
|
||||
SmallVector<unsigned, 4> operandsToPromote;
|
||||
/// Alignment of promoted buffer.
|
||||
unsigned alignment;
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
struct LinalgPromotionPattern : public LinalgBasePromotionPattern {
|
||||
LinalgPromotionPattern(MLIRContext *context,
|
||||
ArrayRef<unsigned> operandsToPromote = {},
|
||||
unsigned alignment = 0,
|
||||
LinalgMarker marker = LinalgMarker(),
|
||||
PatternBenefit benefit = 1)
|
||||
: LinalgBasePromotionPattern(OpTy::getOperationName(), context,
|
||||
operandsToPromote, alignment, marker,
|
||||
benefit) {}
|
||||
LinalgPromotionPattern(MLIRContext *context,
|
||||
ArrayRef<unsigned> operandsToPromote,
|
||||
LinalgMarker marker = LinalgMarker(),
|
||||
PatternBenefit benefit = 1)
|
||||
: LinalgPromotionPattern(context, operandsToPromote, 0, marker, benefit) {
|
||||
}
|
||||
LinalgPromotionPattern(MLIRContext *context, unsigned alignment,
|
||||
LinalgMarker marker = LinalgMarker(),
|
||||
PatternBenefit benefit = 1)
|
||||
: LinalgPromotionPattern(context, {}, alignment, marker, benefit) {}
|
||||
LinalgPromotionPattern(MLIRContext *context, LinalgMarker marker,
|
||||
PatternBenefit benefit = 1)
|
||||
: LinalgPromotionPattern(context, {}, 0, marker, benefit) {}
|
||||
};
|
||||
|
||||
///
|
||||
/// Linalg vectorization patterns.
|
||||
///
|
||||
/// Apply the `vectorizeLinalgOp` transformation as a pattern.
|
||||
/// `marker` controls LinalgTransformMarker matching and update when specified.
|
||||
/// See `vectorizeLinalgOp` for more details.
|
||||
struct LinalgBaseVectorizationPattern : public RewritePattern {
|
||||
LinalgBaseVectorizationPattern(StringRef opName, MLIRContext *context,
|
||||
LinalgMarker marker = LinalgMarker(),
|
||||
PatternBenefit benefit = 1);
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
|
||||
private:
|
||||
/// LinalgTransformMarker handles special attribute manipulations.
|
||||
LinalgMarker marker;
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
struct LinalgVectorizationPattern : public LinalgBaseVectorizationPattern {
|
||||
LinalgVectorizationPattern(MLIRContext *context,
|
||||
LinalgMarker marker = LinalgMarker(),
|
||||
PatternBenefit benefit = 1)
|
||||
: LinalgBaseVectorizationPattern(OpTy::getOperationName(), context,
|
||||
marker, benefit) {}
|
||||
};
|
||||
|
||||
///
|
||||
/// Linalg lowering patterns.
|
||||
///
|
||||
/// Apply the `linalgLowerOpToLoops` transformation as a pattern.
|
||||
/// `marker` controls LinalgTransformMarker matching and update when specified.
|
||||
/// See `linalgLowerOpToLoops` for more details.
|
||||
enum class LinalgLoweringType {
|
||||
LibraryCall = 0,
|
||||
Loops = 1,
|
||||
AffineLoops = 2,
|
||||
ParallelLoops = 3
|
||||
};
|
||||
template <typename OpTy>
|
||||
struct LinalgLoweringPattern : public RewritePattern {
|
||||
LinalgLoweringPattern(MLIRContext *context, LinalgLoweringType loweringType,
|
||||
LinalgMarker marker = LinalgMarker(),
|
||||
PatternBenefit benefit = 1)
|
||||
: RewritePattern(OpTy::getOperationName(), {}, benefit, context),
|
||||
marker(marker), loweringType(loweringType) {}
|
||||
// TODO: Move implementation to .cpp once named ops are auto-generated.
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
|
||||
if (!linalgOp)
|
||||
return failure();
|
||||
if (failed(marker.checkAndNotify(rewriter, linalgOp)))
|
||||
return failure();
|
||||
if (failed(promoteSubviewsLinalgOpPrecondition(op)))
|
||||
return failure();
|
||||
|
||||
if (loweringType == LinalgLoweringType::LibraryCall) {
|
||||
// TODO: Move lowering to library calls here.
|
||||
return failure();
|
||||
} else if (loweringType == LinalgLoweringType::Loops) {
|
||||
if (failed(linalgOpToLoops<OpTy>(rewriter, op)))
|
||||
return failure();
|
||||
} else if (loweringType == LinalgLoweringType::AffineLoops) {
|
||||
if (failed(linalgOpToAffineLoops<OpTy>(rewriter, op)))
|
||||
return failure();
|
||||
} else if (failed(linalgOpToParallelLoops<OpTy>(rewriter, op))) {
|
||||
return failure();
|
||||
}
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
/// LinalgTransformMarker handles special attribute manipulations.
|
||||
LinalgMarker marker;
|
||||
/// Controls whether the pattern lowers to library calls, loop.for, affine.for
|
||||
/// or loop.parallel.
|
||||
LinalgLoweringType loweringType;
|
||||
};
|
||||
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
||||
#endif // DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_
|
|
@ -101,63 +101,6 @@ SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
|
|||
AffineMap map, ArrayRef<Value> values,
|
||||
OperationFolder *folder = nullptr);
|
||||
|
||||
struct TiledLinalgOp {
|
||||
LinalgOp op;
|
||||
SmallVector<Operation *, 8> loops;
|
||||
};
|
||||
|
||||
/// Performs standalone tiling of a single LinalgOp by `tileSizes`.
|
||||
/// and permute the loop nest according to `permutation`
|
||||
/// The permutation is expressed as a list of integers that specify
|
||||
/// the new ordering of the loop nest. The length of `permutation`
|
||||
/// must be equal to the length of `tileSizes`.
|
||||
/// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with
|
||||
/// `permutation = [1,2,0]`. All values in `permutation` must be
|
||||
/// integers, in the range 0..`tileSizes.size()` without duplications
|
||||
/// (i.e. `[1,1,2]` is an invalid permutation). An empty list
|
||||
/// states for the identity permutation.
|
||||
/// Returns a struct containing the tiled loops in the specified order
|
||||
/// and the cloned op if successful, llvm::None otherwise.
|
||||
/// When non-null, the optional pointer `folder` is used to call into the
|
||||
/// `createAndFold` builder method. If `folder` is null, the regular `create`
|
||||
/// method is called.
|
||||
Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
|
||||
ArrayRef<Value> tileSizes,
|
||||
ArrayRef<unsigned> permutation = {},
|
||||
OperationFolder *folder = nullptr);
|
||||
Optional<TiledLinalgOp> tileLinalgOpToParallelLoops(
|
||||
OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
|
||||
ArrayRef<unsigned> permutation = {}, OperationFolder *folder = nullptr);
|
||||
|
||||
/// Performs standalone tiling of a single LinalgOp by constant `tileSizes`.
|
||||
/// and permute the loop nest according to `permutation`
|
||||
/// The permutation is expressed as a list of integers that specify
|
||||
/// the new ordering of the loop nest. The length of `permutation`
|
||||
/// must be equal to the length of `tileSizes`.
|
||||
/// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with
|
||||
/// `permutation = [1,2,0]`. All values in `permutation` must be
|
||||
/// integers, in the range 0..`tileSizes.size()` without duplications
|
||||
/// (i.e. `[1,1,2]` is an invalid permutation). An empty list
|
||||
/// states for the identity permutation.
|
||||
/// Returns a struct containing the tiled loops in the specified order
|
||||
/// and the cloned op if successful, llvm::None otherwise.
|
||||
/// When non-null, the optional pointer `folder` is used to call into the
|
||||
/// `createAndFold` builder method. If `folder` is null, the regular `create`
|
||||
/// method is called.
|
||||
Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
|
||||
ArrayRef<int64_t> tileSizes,
|
||||
ArrayRef<unsigned> permutation = {},
|
||||
OperationFolder *folder = nullptr);
|
||||
Optional<TiledLinalgOp> tileLinalgOpToParallelLoops(
|
||||
OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes,
|
||||
ArrayRef<unsigned> permutation = {}, OperationFolder *folder = nullptr);
|
||||
|
||||
template <typename... Args>
|
||||
Optional<TiledLinalgOp> tileLinalgOperation(OpBuilder &b, Operation *op,
|
||||
Args... args) {
|
||||
return tileLinalgOp(b, cast<LinalgOp>(op), args...);
|
||||
}
|
||||
|
||||
struct PromotionInfo {
|
||||
Value buffer;
|
||||
Value fullLocalView;
|
||||
|
@ -198,17 +141,6 @@ void applyPermutationToVector(SmallVector<T, N> &inVec,
|
|||
inVec = auxVec;
|
||||
}
|
||||
|
||||
/// Prepares the SubView promotion later performed by `promoteSubViews`
|
||||
/// (where most of the transformation happens). It arranges the new
|
||||
/// operands for `LinalgOp op` and deallocates the new buffer(s)
|
||||
/// It is the entry point for declarative transformation
|
||||
/// Returns the cloned `LinalgOp` with the new operands
|
||||
LinalgOp promoteSubViewOperands(OpBuilder &b, LinalgOp op,
|
||||
llvm::SetVector<Value> subViews,
|
||||
bool dynamicBuffers = false,
|
||||
int64_t alignment = 0,
|
||||
OperationFolder *folder = nullptr);
|
||||
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
add_mlir_dialect_library(MLIRLinalgTransforms
|
||||
Fusion.cpp
|
||||
LinalgTransforms.cpp
|
||||
LinalgToLoops.cpp
|
||||
Interchange.cpp
|
||||
Loops.cpp
|
||||
Promotion.cpp
|
||||
Tiling.cpp
|
||||
Transforms.cpp
|
||||
Vectorization.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
|
||||
|
@ -11,7 +13,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
|
|||
DEPENDS
|
||||
intrinsics_gen
|
||||
MLIRLinalgPassIncGen
|
||||
MLIRLinalgTransformPatternsIncGen
|
||||
)
|
||||
target_link_libraries(MLIRLinalgTransforms
|
||||
PUBLIC
|
||||
|
|
|
@ -0,0 +1,85 @@
|
|||
//===- Interchange.cpp - Linalg interchange transformation ----------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements the linalg interchange transformation.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
|
||||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
||||
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <type_traits>
|
||||
|
||||
#define DEBUG_TYPE "linalg-interchange"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::linalg;
|
||||
|
||||
LogicalResult mlir::linalg::interchangeGenericLinalgOpPrecondition(
|
||||
Operation *op, ArrayRef<unsigned> interchangeVector) {
|
||||
if (interchangeVector.empty())
|
||||
return failure();
|
||||
// Transformation applies to generic ops only.
|
||||
if (!isa<GenericOp>(op) && !isa<IndexedGenericOp>(op))
|
||||
return failure();
|
||||
LinalgOp linOp = cast<LinalgOp>(op);
|
||||
// Transformation applies to buffers only.
|
||||
if (!linOp.hasBufferSemantics())
|
||||
return failure();
|
||||
// Permutation must be applicable.
|
||||
if (linOp.getIndexingMap(0).getNumInputs() != interchangeVector.size())
|
||||
return failure();
|
||||
// Permutation map must be invertible.
|
||||
if (!inversePermutation(
|
||||
AffineMap::getPermutationMap(interchangeVector, op->getContext())))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
LinalgOp mlir::linalg::interchange(LinalgOp op,
|
||||
ArrayRef<unsigned> interchangeVector) {
|
||||
if (interchangeVector.empty())
|
||||
return op;
|
||||
|
||||
MLIRContext *context = op.getContext();
|
||||
auto permutationMap = inversePermutation(
|
||||
AffineMap::getPermutationMap(interchangeVector, context));
|
||||
assert(permutationMap && "expected permutation to be invertible");
|
||||
SmallVector<Attribute, 4> newIndexingMaps;
|
||||
auto indexingMaps = op.indexing_maps().getValue();
|
||||
for (unsigned i = 0, e = op.getNumInputsAndOutputs(); i != e; ++i) {
|
||||
AffineMap m = indexingMaps[i].cast<AffineMapAttr>().getValue();
|
||||
if (!permutationMap.isEmpty())
|
||||
m = m.compose(permutationMap);
|
||||
newIndexingMaps.push_back(AffineMapAttr::get(m));
|
||||
}
|
||||
auto itTypes = op.iterator_types().getValue();
|
||||
SmallVector<Attribute, 4> itTypesVector;
|
||||
for (unsigned i = 0, e = itTypes.size(); i != e; ++i)
|
||||
itTypesVector.push_back(itTypes[i]);
|
||||
applyPermutationToVector(itTypesVector, interchangeVector);
|
||||
|
||||
op.setAttr(getIndexingMapsAttrName(),
|
||||
ArrayAttr::get(newIndexingMaps, context));
|
||||
op.setAttr(getIteratorTypesAttrName(),
|
||||
ArrayAttr::get(itTypesVector, context));
|
||||
|
||||
return op;
|
||||
}
|
|
@ -1,381 +0,0 @@
|
|||
//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements logic for transforming Linalg operations.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h"
|
||||
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
|
||||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
||||
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <type_traits>
|
||||
|
||||
#define DEBUG_TYPE "linalg-transforms"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::edsc;
|
||||
using namespace mlir::edsc::intrinsics;
|
||||
using namespace mlir::linalg;
|
||||
|
||||
using llvm::dbgs;
|
||||
using llvm::SetVector;
|
||||
|
||||
// Marker used as attribute name in generated Linalg rewriting transformations.
|
||||
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
|
||||
"__internal_linalg_transform__";
|
||||
|
||||
using TileFn = Optional<TiledLinalgOp>(OpBuilder &, LinalgOp, ArrayRef<int64_t>,
|
||||
ArrayRef<unsigned>, OperationFolder *);
|
||||
|
||||
static LogicalResult
|
||||
tileLinalgOpAndSetMarkerImpl(TileFn tileFn, PatternRewriter &rewriter,
|
||||
Operation *op, ArrayRef<int64_t> sizes,
|
||||
StringRef linalgMarker,
|
||||
ArrayRef<unsigned> permutation) {
|
||||
assert(permutation.empty() || permutation.size() == sizes.size());
|
||||
auto tileRes = tileFn(rewriter, op, sizes, permutation, /*folder=*/nullptr);
|
||||
if (!tileRes)
|
||||
return failure();
|
||||
tileRes->op.setAttr(LinalgTransforms::kLinalgTransformMarker,
|
||||
rewriter.getStringAttr(linalgMarker));
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult mlir::linalg::tileLinalgOpAndSetMarker(
|
||||
PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
|
||||
StringRef linalgMarker, ArrayRef<unsigned> permutation) {
|
||||
return tileLinalgOpAndSetMarkerImpl(tileLinalgOp, rewriter, op, sizes,
|
||||
linalgMarker, permutation);
|
||||
}
|
||||
LogicalResult mlir::linalg::tileLinalgOpToParallelLoopsAndSetMarker(
|
||||
PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
|
||||
StringRef linalgMarker, ArrayRef<unsigned> permutation) {
|
||||
return tileLinalgOpAndSetMarkerImpl(tileLinalgOpToParallelLoops, rewriter, op,
|
||||
sizes, linalgMarker, permutation);
|
||||
}
|
||||
|
||||
static LogicalResult
|
||||
tileAndFuseLinalgOpAndSetMarkerImpl(TileFn tileFn, PatternRewriter &rewriter,
|
||||
Operation *op, ArrayRef<int64_t> sizes,
|
||||
ArrayRef<int64_t> operandIndicesToFuse,
|
||||
StringRef linalgMarker) {
|
||||
auto tileRes =
|
||||
tileFn(rewriter, op, sizes, /*permutation=*/{}, /*folder=*/nullptr);
|
||||
if (!tileRes)
|
||||
return failure();
|
||||
tileRes->op.setAttr(LinalgTransforms::kLinalgTransformMarker,
|
||||
rewriter.getStringAttr(linalgMarker));
|
||||
Aliases aliases;
|
||||
auto G = LinalgDependenceGraph::buildDependenceGraph(
|
||||
aliases, op->getParentOfType<FuncOp>());
|
||||
SmallVector<Operation *, 4> originalProducers;
|
||||
for (auto operandIdx : operandIndicesToFuse) {
|
||||
auto fusionRes = fuseProducerOf(rewriter, tileRes->op, operandIdx, G);
|
||||
if (!fusionRes) {
|
||||
// Linalg fusion requires tiled loops to even determine whether it is
|
||||
// possible to fuse. As a consequence, the pattern may fail even though a
|
||||
// tiled version of op has already been introduced.
|
||||
// So we need to remove the tiled version ourselves in case of failure.
|
||||
// Another possibility is to ensure the constraints on the pattern
|
||||
// guarantee that fusion will occur and just assert here. As we develop
|
||||
// more complex patterns we can choose what is best.
|
||||
rewriter.eraseOp(tileRes->loops[0]);
|
||||
return failure();
|
||||
}
|
||||
fusionRes->fusedProducer.setAttr(LinalgTransforms::kLinalgTransformMarker,
|
||||
rewriter.getStringAttr(linalgMarker));
|
||||
originalProducers.push_back(fusionRes->originalProducer);
|
||||
}
|
||||
|
||||
// The originalProducers can now be safely erased. This is similar to
|
||||
// SSA-value use-def but in the world of buffer + structured ops.
|
||||
for (auto *originalProducer : originalProducers)
|
||||
rewriter.eraseOp(originalProducer);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker(
|
||||
PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
|
||||
ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker) {
|
||||
return tileAndFuseLinalgOpAndSetMarkerImpl(
|
||||
tileLinalgOp, rewriter, op, sizes, operandIndicesToFuse, linalgMarker);
|
||||
}
|
||||
LogicalResult mlir::linalg::tileAndFuseLinalgOpToParallelLoopsAndSetMarker(
|
||||
PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
|
||||
ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker) {
|
||||
return tileAndFuseLinalgOpAndSetMarkerImpl(
|
||||
tileLinalgOpToParallelLoops, rewriter, op, sizes, operandIndicesToFuse,
|
||||
linalgMarker);
|
||||
}
|
||||
|
||||
bool mlir::linalg::detail::isProducedByOpOfTypeImpl(
|
||||
Operation *consumerOp, Value consumedView,
|
||||
function_ref<bool(Operation *)> isaOpType) {
|
||||
LinalgOp consumer = dyn_cast<LinalgOp>(consumerOp);
|
||||
assert(consumer.hasBufferSemantics() &&
|
||||
"expected linalg op with buffer semantics");
|
||||
if (!consumer)
|
||||
return false;
|
||||
|
||||
auto maybeConsumerIndex = consumer.getIndexOfInput(consumedView);
|
||||
if (!maybeConsumerIndex)
|
||||
return false;
|
||||
|
||||
Aliases aliases;
|
||||
auto G = LinalgDependenceGraph::buildDependenceGraph(
|
||||
aliases, consumer.getParentOfType<FuncOp>());
|
||||
for (auto dependence : G.getDependencesInto(
|
||||
consumer, LinalgDependenceGraph::DependenceType::RAW)) {
|
||||
auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
|
||||
if (!isProducerLastWriteOfView(G, consumer, consumedView, producer))
|
||||
continue;
|
||||
if (isaOpType(dependence.dependentOpView.op))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
//============================================================================//
|
||||
// Precondition and transformation for vectorization of Linalg generic ops.
|
||||
//============================================================================//
|
||||
static bool hasMultiplyAddBody(linalg::GenericOp op) {
|
||||
auto &r = op.region();
|
||||
if (r.empty())
|
||||
return false;
|
||||
if (r.getBlocks().size() != 1)
|
||||
return false;
|
||||
auto &ops = r.front().getOperations();
|
||||
if (ops.size() != 3)
|
||||
return false;
|
||||
|
||||
using mlir::matchers::m_Val;
|
||||
auto a = m_Val(r.front().getArgument(0));
|
||||
auto b = m_Val(r.front().getArgument(1));
|
||||
auto c = m_Val(r.front().getArgument(2));
|
||||
// TODO(ntv) Update this detection once we have matcher support for
|
||||
// specifying that any permutation of operands matches.
|
||||
auto pattern1 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
|
||||
auto pattern2 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
|
||||
auto pattern3 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
|
||||
auto pattern4 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
|
||||
return pattern1.match(&ops.back()) || pattern2.match(&ops.back()) ||
|
||||
pattern3.match(&ops.back()) || pattern4.match(&ops.back());
|
||||
}
|
||||
|
||||
// TODO(ntv) should be Tablegen'd from a single source that generates the op
|
||||
// itself.
|
||||
static bool isRowMajorMatmul(linalg::GenericOp genericOp) {
|
||||
return genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
|
||||
isRowMajorMatmul(genericOp.indexing_maps()) &&
|
||||
hasMultiplyAddBody(genericOp);
|
||||
}
|
||||
|
||||
// TODO(ntv, ataei): This is in fact much more general than just vectorization
|
||||
// for matmul and fill ops.
|
||||
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
|
||||
auto linalgOp = cast<linalg::LinalgOp>(op);
|
||||
// All types must be static shape to go to vector.
|
||||
for (Value operand : linalgOp.getInputsAndOutputBuffers())
|
||||
if (!operand.getType().cast<ShapedType>().hasStaticShape())
|
||||
return failure();
|
||||
for (Type outputTensorType : linalgOp.getOutputTensorTypes())
|
||||
if (!outputTensorType.cast<ShapedType>().hasStaticShape())
|
||||
return failure();
|
||||
if (isa<linalg::MatmulOp>(op) || isa<linalg::FillOp>(op))
|
||||
return success();
|
||||
|
||||
auto genericOp = dyn_cast<linalg::GenericOp>(op);
|
||||
if (!genericOp || !::isRowMajorMatmul(genericOp))
|
||||
return failure();
|
||||
|
||||
// TODO(ntv): non-identity layout.
|
||||
auto isStaticMemRefWithIdentityLayout = [](Value v) {
|
||||
auto m = v.getType().dyn_cast<MemRefType>();
|
||||
if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty())
|
||||
return false;
|
||||
return true;
|
||||
};
|
||||
if (!llvm::all_of(genericOp.getInputsAndOutputBuffers(),
|
||||
isStaticMemRefWithIdentityLayout))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<Value, 0> mlir::linalg::vectorizeLinalgOp(PatternRewriter &rewriter,
|
||||
Operation *op) {
|
||||
assert(succeeded(vectorizeLinalgOpPrecondition(op)) &&
|
||||
"DRR failure case must be a precondition");
|
||||
auto linalgOp = cast<linalg::LinalgOp>(op);
|
||||
assert(linalgOp.hasBufferSemantics() &&
|
||||
"expected linalg op with buffer semantics");
|
||||
if (auto convOp = dyn_cast<linalg::ConvOp>(op)) {
|
||||
// TODO(ntv): add a level of indirection to linalg.generic.
|
||||
if (convOp.padding())
|
||||
llvm_unreachable("Unexpected conv with padding");
|
||||
}
|
||||
|
||||
edsc::ScopedContext scope(rewriter, op->getLoc());
|
||||
|
||||
if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
|
||||
// Vectorize fill as a vector.broadcast.
|
||||
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
|
||||
"]: Rewrite linalg.fill as vector.broadcast: "
|
||||
<< *op << ":\n");
|
||||
auto dstMemrefVec = vector_type_cast(fillOp.getOutputBuffer(0));
|
||||
Value dstVec = std_load(dstMemrefVec);
|
||||
auto resVec = vector_broadcast(dstVec.getType(), fillOp.value());
|
||||
std_store(resVec, dstMemrefVec);
|
||||
} else {
|
||||
// Vectorize other ops as vector contraction (currently only matmul).
|
||||
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
|
||||
"]: Rewrite linalg op as vector.contract: "
|
||||
<< *op << ":\n");
|
||||
auto vA = std_load(vector_type_cast(linalgOp.getInput(0)));
|
||||
auto vB = std_load(vector_type_cast(linalgOp.getInput(1)));
|
||||
auto vectorMemRefC = vector_type_cast(linalgOp.getOutputBuffer(0));
|
||||
auto vC = std_load(vectorMemRefC);
|
||||
auto vRes = vector_contract(vA, vB, vC, linalgOp.indexing_maps(),
|
||||
linalgOp.iterator_types());
|
||||
std_store(vRes, vectorMemRefC);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
//============================================================================//
|
||||
// Precondition and transformation for permutation of Linalg generic ops.
|
||||
//============================================================================//
|
||||
LogicalResult mlir::linalg::permuteGenericLinalgOpPrecondition(
|
||||
Operation *op, ArrayRef<unsigned> permutation) {
|
||||
if (permutation.empty())
|
||||
return failure();
|
||||
// Transformation applies to generic ops only.
|
||||
if (!isa<GenericOp>(op) && !isa<IndexedGenericOp>(op))
|
||||
return failure();
|
||||
LinalgOp linOp = cast<LinalgOp>(op);
|
||||
// Transformation applies to buffers only.
|
||||
if (!linOp.hasBufferSemantics())
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<Value, 0>
|
||||
mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
|
||||
ArrayRef<unsigned> permutation,
|
||||
StringRef linalgMarker) {
|
||||
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Permute dims for linalg op: " << *op
|
||||
<< ":\n");
|
||||
|
||||
assert(succeeded(permuteGenericLinalgOpPrecondition(op, permutation)) &&
|
||||
"DRR failure case must be a precondition");
|
||||
|
||||
auto linOp = cast<LinalgOp>(op);
|
||||
auto permutationMap = inversePermutation(
|
||||
AffineMap::getPermutationMap(permutation, rewriter.getContext()));
|
||||
assert(permutationMap && "expected permutation to be invertible");
|
||||
SmallVector<AffineMap, 4> newIndexingMap;
|
||||
auto indexingMaps = linOp.indexing_maps().getValue();
|
||||
for (unsigned i = 0, e = linOp.getNumInputsAndOutputs(); i != e; ++i) {
|
||||
AffineMap m = indexingMaps[i].cast<AffineMapAttr>().getValue();
|
||||
if (!permutationMap.isEmpty())
|
||||
m = m.compose(permutationMap);
|
||||
newIndexingMap.push_back(m);
|
||||
}
|
||||
auto itTypes = linOp.iterator_types().getValue();
|
||||
SmallVector<Attribute, 4> itTypesVector;
|
||||
for (unsigned i = 0, e = itTypes.size(); i != e; ++i)
|
||||
itTypesVector.push_back(itTypes[i]);
|
||||
applyPermutationToVector(itTypesVector, permutation);
|
||||
op->setAttr(getIndexingMapsAttrName(),
|
||||
rewriter.getAffineMapArrayAttr(newIndexingMap));
|
||||
op->setAttr(getIteratorTypesAttrName(), rewriter.getArrayAttr(itTypesVector));
|
||||
op->setAttr(LinalgTransforms::kLinalgTransformMarker,
|
||||
rewriter.getStringAttr(linalgMarker));
|
||||
linOp.clone(rewriter, linOp.getLoc(), op->getOperands());
|
||||
return {};
|
||||
}
|
||||
|
||||
//============================================================================//
|
||||
// Precondition and transformation for Linalg subview promotion.
|
||||
//============================================================================//
|
||||
LogicalResult mlir::linalg::promoteSubviewsLinalgOpPrecondition(Operation *op) {
|
||||
LinalgOp linOp = dyn_cast<LinalgOp>(op);
|
||||
// Transformation applies to buffers only.
|
||||
if (!linOp || !linOp.hasBufferSemantics())
|
||||
return failure();
|
||||
if (llvm::none_of(linOp.getInputsAndOutputBuffers(), [](Value v) {
|
||||
return isa_and_nonnull<SubViewOp>(v.getDefiningOp());
|
||||
}))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<Value, 0>
|
||||
mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter,
|
||||
Operation *op) {
|
||||
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Promote subviews for linalg op: "
|
||||
<< *op << ":\n");
|
||||
|
||||
assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) &&
|
||||
"DRR failure case must be a precondition");
|
||||
|
||||
LinalgOp linOp = cast<LinalgOp>(op);
|
||||
SmallVector<int64_t, 4> toPromote;
|
||||
int64_t nBuffers = linOp.getNumInputsAndOutputBuffers();
|
||||
toPromote.reserve(nBuffers);
|
||||
for (int64_t i = 0; i < nBuffers; ++i)
|
||||
toPromote.push_back(i);
|
||||
return promoteSelectedSubviewsLinalgOpAndSetMarker(rewriter, op, toPromote);
|
||||
}
|
||||
|
||||
SmallVector<Value, 0> mlir::linalg::promoteSelectedSubviewsLinalgOpAndSetMarker(
|
||||
PatternRewriter &rewriter, Operation *op,
|
||||
ArrayRef<int64_t> operandIndicesToPromote, StringRef linalgMarker,
|
||||
int64_t alignment) {
|
||||
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Promote subviews for linalg op: "
|
||||
<< *op << ":\n");
|
||||
|
||||
assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) &&
|
||||
"DRR failure case must be a precondition");
|
||||
|
||||
if (auto convOp = dyn_cast<linalg::ConvOp>(op)) {
|
||||
// TODO(ntv): add a level of indirection to linalg.generic.
|
||||
if (convOp.padding())
|
||||
llvm_unreachable("Unexpected conv with padding");
|
||||
}
|
||||
|
||||
LinalgOp linOp = cast<LinalgOp>(op);
|
||||
assert(linOp.hasBufferSemantics() &&
|
||||
"expected linalg op with buffer semantics");
|
||||
SetVector<Value> subViews;
|
||||
for (int64_t index : operandIndicesToPromote)
|
||||
if (auto sv =
|
||||
dyn_cast_or_null<SubViewOp>(linOp.getBuffer(index).getDefiningOp()))
|
||||
subViews.insert(sv);
|
||||
|
||||
if (!subViews.empty()) {
|
||||
auto newOp =
|
||||
promoteSubViewOperands(rewriter, linOp, subViews, false, alignment);
|
||||
if (!linalgMarker.empty())
|
||||
newOp.setAttr(LinalgTransforms::kLinalgTransformMarker,
|
||||
rewriter.getStringAttr(linalgMarker));
|
||||
return {};
|
||||
}
|
||||
llvm_unreachable("DRR failure case must be a precondition");
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
//===- LinalgToLoops.cpp - conversion from Linalg library ops to loops-----===//
|
||||
//===- Loops.cpp - conversion from Linalg named and generic ops to loops --===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
@ -12,7 +12,7 @@
|
|||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||
#include "mlir/Dialect/Linalg/Passes.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/Dialect/LoopOps/EDSC/Builders.h"
|
||||
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
|
||||
|
@ -489,20 +489,6 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/// This struct is for factoring out the implementation and support template
|
||||
/// instantiations in the following 2 cases:
|
||||
/// 1. Appending to a list of patterns via RewritePatternList.
|
||||
/// 2. Direct invocation via `linalgOpToLoops` and `linalgOpToAffineLoops`.
|
||||
/// The implementation must work both in DRR and inside a RewritePattern. As a
|
||||
/// consequence, (1) it is only allowed to emit new ops if the match is
|
||||
/// guaranteed to be a success, (2) it is not allowed erase/replace, and (3) an
|
||||
/// encompassing pattern must take care of the erasure logic.
|
||||
template <typename LoopTy, typename ConcreteOpTy>
|
||||
class LinalgOpToLoopsImpl {
|
||||
public:
|
||||
static Optional<LinalgLoops> doit(Operation *op, PatternRewriter &rewriter);
|
||||
};
|
||||
|
||||
namespace {
|
||||
/// Helper struct to generate the loop nest for the op. This factored out here
|
||||
/// to be able to partially specialize this for different LoopTy.
|
||||
|
@ -573,14 +559,12 @@ public:
|
|||
} // namespace
|
||||
|
||||
template <typename LoopTy, typename ConcreteOpTy>
|
||||
Optional<LinalgLoops>
|
||||
LinalgOpToLoopsImpl<LoopTy, ConcreteOpTy>::doit(Operation *op,
|
||||
PatternRewriter &rewriter) {
|
||||
Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
|
||||
using Impl = GenerateLoopNest<LoopTy, ConcreteOpTy>;
|
||||
using IndexedValueTy =
|
||||
typename GenerateLoopNest<LoopTy, ConcreteOpTy>::IndexedValueTy;
|
||||
|
||||
ScopedContext scope(rewriter, op->getLoc());
|
||||
ScopedContext scope(builder, op->getLoc());
|
||||
|
||||
// The flattened loopToOperandRangesMaps is expected to be an invertible
|
||||
// permutation map (which is asserted in the inverse calculation).
|
||||
|
@ -607,7 +591,7 @@ LinalgOpToLoopsImpl<LoopTy, ConcreteOpTy>::doit(Operation *op,
|
|||
SmallVector<Value, 4> allIvs(nLoops);
|
||||
auto loopRanges =
|
||||
emitLoopRanges(scope.getBuilderRef(), scope.getLocation(), invertedMap,
|
||||
getViewSizes(rewriter, linalgOp));
|
||||
getViewSizes(builder, linalgOp));
|
||||
assert(loopRanges.size() == allIvs.size());
|
||||
Impl::doit(linalgOp, loopRanges, allIvs);
|
||||
// Number of loop ops might be different from the number of ivs since some
|
||||
|
@ -635,8 +619,7 @@ public:
|
|||
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
using Impl = LinalgOpToLoopsImpl<LoopType, ConcreteOp>;
|
||||
if (!Impl::doit(op, rewriter))
|
||||
if (!linalgOpToLoopsImpl<LoopType, ConcreteOp>(op, rewriter))
|
||||
return failure();
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
|
@ -662,7 +645,7 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/// Populate the given list with patterns that convert from Linalg to LLVM.
|
||||
/// Populate the given list with patterns that convert from Linalg to loops.
|
||||
template <typename LoopType>
|
||||
void FillRewritePatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) {
|
||||
RewritePatternList<LoopType,
|
||||
|
@ -759,50 +742,49 @@ mlir::createConvertLinalgToAffineLoopsPass() {
|
|||
|
||||
/// Emits a loop nest with the proper body for `op`.
|
||||
template <typename LoopTy, typename ConcreteOp>
|
||||
Optional<LinalgLoops>
|
||||
mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter, Operation *op) {
|
||||
return LinalgOpToLoopsImpl<LoopTy, ConcreteOp>::doit(op, rewriter);
|
||||
Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder,
|
||||
Operation *op) {
|
||||
return linalgOpToLoopsImpl<LoopTy, ConcreteOp>(op, builder);
|
||||
}
|
||||
|
||||
/// Emits a loop nest of `loop.for` with the proper body for `op`.
|
||||
template <typename ConcreteOp>
|
||||
LogicalResult mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
|
||||
Operation *op) {
|
||||
LogicalResult mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op) {
|
||||
Optional<LinalgLoops> loops =
|
||||
linalgLowerOpToLoops<loop::ForOp, ConcreteOp>(rewriter, op);
|
||||
linalgLowerOpToLoops<loop::ForOp, ConcreteOp>(builder, op);
|
||||
return loops ? success() : failure();
|
||||
}
|
||||
|
||||
/// Emits a loop nest of `affine.for` with the proper body for `op`.
|
||||
template <typename ConcreteOp>
|
||||
LogicalResult mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
|
||||
LogicalResult mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder,
|
||||
Operation *op) {
|
||||
Optional<LinalgLoops> loops =
|
||||
linalgLowerOpToLoops<AffineForOp, ConcreteOp>(rewriter, op);
|
||||
linalgLowerOpToLoops<AffineForOp, ConcreteOp>(builder, op);
|
||||
return loops ? success() : failure();
|
||||
}
|
||||
|
||||
/// Emits a loop nest of `loop.parallel` with the proper body for `op`.
|
||||
template <typename ConcreteOp>
|
||||
LogicalResult mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
|
||||
LogicalResult mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder,
|
||||
Operation *op) {
|
||||
Optional<LinalgLoops> loops =
|
||||
linalgLowerOpToLoops<loop::ParallelOp, ConcreteOp>(rewriter, op);
|
||||
linalgLowerOpToLoops<loop::ParallelOp, ConcreteOp>(builder, op);
|
||||
return loops ? success() : failure();
|
||||
}
|
||||
|
||||
// TODO(ntv) Need to make these instantiations more future-proof to avoid the
|
||||
// need to update as soon as we add new ops.
|
||||
// TODO Need to make these instantiations more future-proof to avoid the need to
|
||||
// update as soon as we add new ops.
|
||||
#define INSTANTIATE_LINALG_OP_TO_LOOPS(OP_TYPE) \
|
||||
template LogicalResult mlir::linalg::linalgOpToLoops<OP_TYPE>( \
|
||||
PatternRewriter & rewriter, Operation * op); \
|
||||
OpBuilder & builder, Operation * op); \
|
||||
template LogicalResult mlir::linalg::linalgOpToAffineLoops<OP_TYPE>( \
|
||||
PatternRewriter & rewriter, Operation * op); \
|
||||
OpBuilder & builder, Operation * op); \
|
||||
template LogicalResult mlir::linalg::linalgOpToParallelLoops<OP_TYPE>( \
|
||||
PatternRewriter & rewriter, Operation * op); \
|
||||
OpBuilder & builder, Operation * op); \
|
||||
template Optional<LinalgLoops> \
|
||||
mlir::linalg::linalgLowerOpToLoops<loop::ParallelOp, OP_TYPE>( \
|
||||
PatternRewriter & rewriter, Operation * op);
|
||||
OpBuilder & builder, Operation * op);
|
||||
|
||||
INSTANTIATE_LINALG_OP_TO_LOOPS(CopyOp)
|
||||
INSTANTIATE_LINALG_OP_TO_LOOPS(FillOp)
|
|
@ -16,6 +16,7 @@
|
|||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||
#include "mlir/Dialect/Linalg/Passes.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/Dialect/LoopOps/LoopOps.h"
|
||||
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
|
||||
|
@ -264,6 +265,21 @@ static void promoteSubViews(FuncOp f, bool dynamicBuffers) {
|
|||
op.erase();
|
||||
}
|
||||
|
||||
LogicalResult mlir::linalg::promoteSubviewsLinalgOpPrecondition(
|
||||
Operation *op, llvm::Optional<DenseSet<unsigned>> operandIndicesToPromote) {
|
||||
LinalgOp linOp = dyn_cast<LinalgOp>(op);
|
||||
// Transformation applies to buffers only.
|
||||
if (!linOp || !linOp.hasBufferSemantics())
|
||||
return failure();
|
||||
for (auto en : llvm::enumerate(linOp.getInputsAndOutputBuffers())) {
|
||||
auto sv = isa_and_nonnull<SubViewOp>(en.value().getDefiningOp());
|
||||
if (sv && (!operandIndicesToPromote.hasValue() ||
|
||||
operandIndicesToPromote->count(en.index())))
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct LinalgPromotionPass : public LinalgPromotionBase<LinalgPromotionPass> {
|
||||
LinalgPromotionPass() = default;
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||
#include "mlir/Dialect/Linalg/Passes.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/Dialect/LoopOps/EDSC/Builders.h"
|
||||
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
|
||||
|
@ -320,10 +321,9 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp,
|
|||
}
|
||||
|
||||
template <typename LoopTy>
|
||||
Optional<TiledLinalgOp> static tileLinalgOpImpl(OpBuilder &b, LinalgOp op,
|
||||
ArrayRef<Value> tileSizes,
|
||||
ArrayRef<unsigned> permutation,
|
||||
OperationFolder *folder) {
|
||||
Optional<TiledLinalgOp> static tileLinalgOpImpl(
|
||||
OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
|
||||
ArrayRef<unsigned> interchangeVector, OperationFolder *folder) {
|
||||
assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
|
||||
// 1. Enforce the convention that "tiling by zero" skips tiling a particular
|
||||
// dimension. This convention is significantly simpler to handle instead of
|
||||
|
@ -342,13 +342,13 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(OpBuilder &b, LinalgOp op,
|
|||
return llvm::None;
|
||||
}
|
||||
|
||||
// If permutation is empty, use the identity. Build the permutation map
|
||||
// If interchangeVector is empty, use the identity. Build the permutation map
|
||||
// otherwise.
|
||||
auto invPermutationMap = AffineMap::getMultiDimIdentityMap(
|
||||
tileSizes.size(), ScopedContext::getContext());
|
||||
if (!permutation.empty())
|
||||
invPermutationMap = inversePermutation(
|
||||
AffineMap::getPermutationMap(permutation, ScopedContext::getContext()));
|
||||
if (!interchangeVector.empty())
|
||||
invPermutationMap = inversePermutation(AffineMap::getPermutationMap(
|
||||
interchangeVector, ScopedContext::getContext()));
|
||||
if (!invPermutationMap)
|
||||
return llvm::None;
|
||||
|
||||
|
@ -371,8 +371,8 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(OpBuilder &b, LinalgOp op,
|
|||
std::tie(loopRanges, loopIndexToRangeIndex) =
|
||||
makeTiledLoopRanges(b, scope.getLocation(), viewSizesToLoopsMap,
|
||||
viewSizes, tileSizes, folder);
|
||||
if (!permutation.empty())
|
||||
applyPermutationToVector(loopRanges, permutation);
|
||||
if (!interchangeVector.empty())
|
||||
applyPermutationToVector(loopRanges, interchangeVector);
|
||||
|
||||
// 3. Create the tiled loops.
|
||||
LinalgOp res = op;
|
||||
|
@ -393,7 +393,7 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(OpBuilder &b, LinalgOp op,
|
|||
// assuming that loopRanges have previously been permuted by
|
||||
// (i,j,k)->(k,i,j) So this permutation should be the inversePermutation of
|
||||
// that one: (d0,d1,d2)->(d2,d0,d1)
|
||||
if (!permutation.empty())
|
||||
if (!interchangeVector.empty())
|
||||
ivValues = applyMapToValues(b, loc, invPermutationMap, ivValues, folder);
|
||||
|
||||
auto views =
|
||||
|
@ -420,7 +420,8 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(OpBuilder &b, LinalgOp op,
|
|||
template <typename LoopTy>
|
||||
static Optional<TiledLinalgOp>
|
||||
tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes,
|
||||
ArrayRef<unsigned> permutation, OperationFolder *folder) {
|
||||
ArrayRef<unsigned> interchangeVector,
|
||||
OperationFolder *folder) {
|
||||
assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
|
||||
if (tileSizes.empty())
|
||||
return llvm::None;
|
||||
|
@ -459,33 +460,36 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes,
|
|||
tileSizeValues.push_back(folded_std_constant_index(folder, 0));
|
||||
}
|
||||
|
||||
return tileLinalgOpImpl<LoopTy>(b, op, tileSizeValues, permutation, folder);
|
||||
return tileLinalgOpImpl<LoopTy>(b, op, tileSizeValues, interchangeVector,
|
||||
folder);
|
||||
}
|
||||
|
||||
Optional<TiledLinalgOp>
|
||||
mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
|
||||
ArrayRef<unsigned> permutation,
|
||||
ArrayRef<unsigned> interchangeVector,
|
||||
OperationFolder *folder) {
|
||||
return tileLinalgOpImpl<loop::ForOp>(b, op, tileSizes, permutation, folder);
|
||||
return tileLinalgOpImpl<loop::ForOp>(b, op, tileSizes, interchangeVector,
|
||||
folder);
|
||||
}
|
||||
|
||||
Optional<TiledLinalgOp> mlir::linalg::tileLinalgOpToParallelLoops(
|
||||
OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
|
||||
ArrayRef<unsigned> permutation, OperationFolder *folder) {
|
||||
return tileLinalgOpImpl<loop::ParallelOp>(b, op, tileSizes, permutation,
|
||||
ArrayRef<unsigned> interchangeVector, OperationFolder *folder) {
|
||||
return tileLinalgOpImpl<loop::ParallelOp>(b, op, tileSizes, interchangeVector,
|
||||
folder);
|
||||
}
|
||||
|
||||
Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
|
||||
OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes,
|
||||
ArrayRef<unsigned> permutation, OperationFolder *folder) {
|
||||
return tileLinalgOpImpl<loop::ForOp>(b, op, tileSizes, permutation, folder);
|
||||
ArrayRef<unsigned> interchangeVector, OperationFolder *folder) {
|
||||
return tileLinalgOpImpl<loop::ForOp>(b, op, tileSizes, interchangeVector,
|
||||
folder);
|
||||
}
|
||||
|
||||
Optional<TiledLinalgOp> mlir::linalg::tileLinalgOpToParallelLoops(
|
||||
OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes,
|
||||
ArrayRef<unsigned> permutation, OperationFolder *folder) {
|
||||
return tileLinalgOpImpl<loop::ParallelOp>(b, op, tileSizes, permutation,
|
||||
ArrayRef<unsigned> interchangeVector, OperationFolder *folder) {
|
||||
return tileLinalgOpImpl<loop::ParallelOp>(b, op, tileSizes, interchangeVector,
|
||||
folder);
|
||||
}
|
||||
|
||||
|
@ -496,8 +500,8 @@ static void tileLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes) {
|
|||
f.walk([tileSizes, &b, &folder](LinalgOp op) {
|
||||
if (!op.hasBufferSemantics())
|
||||
return;
|
||||
auto opLoopsPair =
|
||||
tileLinalgOpImpl<LoopTy>(b, op, tileSizes, /*permutation=*/{}, &folder);
|
||||
auto opLoopsPair = tileLinalgOpImpl<LoopTy>(
|
||||
b, op, tileSizes, /*interchangeVector=*/{}, &folder);
|
||||
// If tiling occurred successfully, erase old op.
|
||||
if (opLoopsPair)
|
||||
op.erase();
|
||||
|
|
|
@ -0,0 +1,228 @@
|
|||
//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements logic and helpers to expose Linalg transforms as rewrite
|
||||
// patterns.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
|
||||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
||||
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <type_traits>
|
||||
|
||||
#define DEBUG_TYPE "linalg-transforms"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::edsc;
|
||||
using namespace mlir::edsc::intrinsics;
|
||||
using namespace mlir::linalg;
|
||||
|
||||
using llvm::dbgs;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Transformations exposed as rewrite patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Marker used as attribute name in generated Linalg rewriting transformations.
|
||||
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
|
||||
"__internal_linalg_transform__";
|
||||
|
||||
mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction,
|
||||
llvm::Optional<StringRef> replacement)
|
||||
: matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
|
||||
replacement(replacement) {}
|
||||
|
||||
mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction,
|
||||
StringRef replacement)
|
||||
: LinalgMarker(matchDisjunction, llvm::Optional<StringRef>{replacement}) {}
|
||||
|
||||
LogicalResult
|
||||
mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
|
||||
Operation *op) const {
|
||||
auto attr = op->template getAttrOfType<StringAttr>(
|
||||
LinalgTransforms::kLinalgTransformMarker);
|
||||
|
||||
if (!attr) {
|
||||
// 1. Has no marker case and matchDisjunction is empty.
|
||||
if (matchDisjunction.empty())
|
||||
return success();
|
||||
|
||||
// 2. Has no marker and matchDisjuntion matches the no-moarker case.
|
||||
for (auto marker : matchDisjunction)
|
||||
if (marker.empty())
|
||||
return success();
|
||||
|
||||
// 3. Has no marker but was expecting a marker.
|
||||
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
|
||||
diag << " does not have any marker from list: ";
|
||||
llvm::interleaveComma(matchDisjunction, diag);
|
||||
});
|
||||
}
|
||||
|
||||
// 4. Match explicit marker.
|
||||
for (auto marker : matchDisjunction)
|
||||
if (attr.getValue() == marker)
|
||||
return success();
|
||||
|
||||
// 5. Fail to match.
|
||||
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
|
||||
diag << " does not have any marker from list: ";
|
||||
llvm::interleaveComma(matchDisjunction, diag);
|
||||
});
|
||||
}
|
||||
|
||||
void mlir::linalg::LinalgMarker::replaceLinalgMarker(PatternRewriter &rewriter,
|
||||
Operation *op) const {
|
||||
if (replacement.hasValue())
|
||||
op->setAttr(LinalgTransforms::kLinalgTransformMarker,
|
||||
rewriter.getStringAttr(replacement.getValue()));
|
||||
else
|
||||
op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
|
||||
rewriter.getContext()));
|
||||
}
|
||||
|
||||
/// Linalg base tiling pattern.
|
||||
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
|
||||
StringRef opName, MLIRContext *context, LinalgTilingOptions options,
|
||||
LinalgMarker marker, PatternBenefit benefit)
|
||||
: RewritePattern(opName, {}, benefit, context), marker(marker),
|
||||
options(options) {}
|
||||
|
||||
LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewrite(
|
||||
Operation *op, PatternRewriter &rewriter) const {
|
||||
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
|
||||
if (!linalgOp)
|
||||
return failure();
|
||||
if (failed(marker.checkAndNotify(rewriter, linalgOp)))
|
||||
return failure();
|
||||
Optional<TiledLinalgOp> res;
|
||||
if (options.loopType == LinalgTilingLoopType::Loops)
|
||||
res = tileLinalgOp(rewriter, linalgOp, options.tileSizes,
|
||||
options.interchangeVector);
|
||||
else if (options.loopType == LinalgTilingLoopType::ParallelLoops)
|
||||
res = tileLinalgOpToParallelLoops(rewriter, linalgOp, options.tileSizes,
|
||||
options.interchangeVector);
|
||||
// TODO: Impl tiling to affine loops when it makes sense.
|
||||
|
||||
if (!res)
|
||||
return failure();
|
||||
|
||||
// New marker if specified.
|
||||
marker.replaceLinalgMarker(rewriter, res->op.getOperation());
|
||||
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Linalg base interchange pattern.
|
||||
mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
|
||||
StringRef opName, MLIRContext *context,
|
||||
ArrayRef<unsigned> interchangeVector, LinalgMarker marker,
|
||||
PatternBenefit benefit)
|
||||
: RewritePattern(opName, {}, benefit, context), marker(marker),
|
||||
interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
|
||||
|
||||
LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
|
||||
Operation *op, PatternRewriter &rewriter) const {
|
||||
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
|
||||
if (!linalgOp)
|
||||
return failure();
|
||||
if (failed(marker.checkAndNotify(rewriter, linalgOp)))
|
||||
return failure();
|
||||
if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector)))
|
||||
return failure();
|
||||
|
||||
// TODO: figure out how this interplays with named ops. In particular this
|
||||
// should break the named op property.
|
||||
rewriter.updateRootInPlace(op, [&]() {
|
||||
interchange(linalgOp, interchangeVector);
|
||||
// New marker if specified.
|
||||
marker.replaceLinalgMarker(rewriter, op);
|
||||
});
|
||||
return success();
|
||||
}
|
||||
|
||||
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
|
||||
StringRef opName, MLIRContext *context,
|
||||
ArrayRef<unsigned> operandsToPromote, unsigned alignment,
|
||||
LinalgMarker marker, PatternBenefit benefit)
|
||||
: RewritePattern(opName, {}, benefit, context), marker(marker),
|
||||
operandsToPromote(operandsToPromote.begin(), operandsToPromote.end()),
|
||||
alignment(alignment) {}
|
||||
|
||||
LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
|
||||
Operation *op, PatternRewriter &rewriter) const {
|
||||
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
|
||||
if (!linalgOp)
|
||||
return failure();
|
||||
if (failed(marker.checkAndNotify(rewriter, linalgOp)))
|
||||
return failure();
|
||||
if (operandsToPromote.empty()) {
|
||||
if (failed(promoteSubviewsLinalgOpPrecondition(op, llvm::None)))
|
||||
return failure();
|
||||
} else {
|
||||
DenseSet<unsigned> set;
|
||||
set.insert(operandsToPromote.begin(), operandsToPromote.end());
|
||||
if (failed(promoteSubviewsLinalgOpPrecondition(op, set)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
llvm::SetVector<Value> subViews;
|
||||
if (!operandsToPromote.empty()) {
|
||||
for (unsigned idx : operandsToPromote) {
|
||||
auto *op = linalgOp.getBuffer(idx).getDefiningOp();
|
||||
if (auto sv = dyn_cast_or_null<SubViewOp>(op))
|
||||
subViews.insert(sv);
|
||||
}
|
||||
} else {
|
||||
unsigned nBuffers = linalgOp.getNumInputsAndOutputBuffers();
|
||||
for (unsigned idx = 0; idx < nBuffers; ++idx) {
|
||||
auto *op = linalgOp.getBuffer(idx).getDefiningOp();
|
||||
if (auto sv = dyn_cast_or_null<SubViewOp>(op))
|
||||
subViews.insert(sv);
|
||||
}
|
||||
}
|
||||
|
||||
auto promotedOp =
|
||||
promoteSubViewOperands(rewriter, op, subViews, /*dynamicBuffers=*/false,
|
||||
/*alignment=*/alignment);
|
||||
marker.replaceLinalgMarker(rewriter, promotedOp.getOperation());
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
|
||||
StringRef opName, MLIRContext *context, LinalgMarker marker,
|
||||
PatternBenefit benefit)
|
||||
: RewritePattern(opName, {}, benefit, context), marker(marker) {}
|
||||
|
||||
LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
|
||||
Operation *op, PatternRewriter &rewriter) const {
|
||||
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
|
||||
if (!linalgOp)
|
||||
return failure();
|
||||
if (failed(marker.checkAndNotify(rewriter, linalgOp)))
|
||||
return failure();
|
||||
if (failed(vectorizeLinalgOpPrecondition(op)))
|
||||
return failure();
|
||||
vectorizeLinalgOp(rewriter, op);
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
|
@ -0,0 +1,131 @@
|
|||
//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements the linalg dialect Vectorization transformations.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
|
||||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
||||
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <type_traits>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::edsc;
|
||||
using namespace mlir::edsc::intrinsics;
|
||||
using namespace mlir::linalg;
|
||||
|
||||
using llvm::dbgs;
|
||||
|
||||
#define DEBUG_TYPE "linalg-vectorization"
|
||||
|
||||
static bool hasMultiplyAddBody(linalg::GenericOp op) {
|
||||
auto &r = op.region();
|
||||
if (!llvm::hasSingleElement(r))
|
||||
return false;
|
||||
if (!llvm::hasNItems(r.front().begin(), r.front().end(), 3))
|
||||
return false;
|
||||
|
||||
using mlir::matchers::m_Val;
|
||||
auto a = m_Val(r.front().getArgument(0));
|
||||
auto b = m_Val(r.front().getArgument(1));
|
||||
auto c = m_Val(r.front().getArgument(2));
|
||||
// TODO: Update this detection once we have matcher support for specifying
|
||||
// that any permutation of operands matches.
|
||||
auto pattern1 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
|
||||
auto pattern2 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
|
||||
auto pattern3 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
|
||||
auto pattern4 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
|
||||
return pattern1.match(&r.front().back()) ||
|
||||
pattern2.match(&r.front().back()) ||
|
||||
pattern3.match(&r.front().back()) || pattern4.match(&r.front().back());
|
||||
}
|
||||
|
||||
// TODO: Should be Tablegen'd from a single source that generates the op itself.
|
||||
static bool isRowMajorMatmul(linalg::GenericOp genericOp) {
|
||||
return genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
|
||||
isRowMajorMatmul(genericOp.indexing_maps()) &&
|
||||
hasMultiplyAddBody(genericOp);
|
||||
}
|
||||
|
||||
// TODO: This is in fact much more general than just vectorization for matmul
|
||||
// and fill ops.
|
||||
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
|
||||
auto linalgOp = cast<linalg::LinalgOp>(op);
|
||||
// All types must be static shape to go to vector.
|
||||
for (Value operand : linalgOp.getInputsAndOutputBuffers())
|
||||
if (!operand.getType().cast<ShapedType>().hasStaticShape())
|
||||
return failure();
|
||||
for (Type outputTensorType : linalgOp.getOutputTensorTypes())
|
||||
if (!outputTensorType.cast<ShapedType>().hasStaticShape())
|
||||
return failure();
|
||||
if (isa<linalg::MatmulOp>(op) || isa<linalg::FillOp>(op))
|
||||
return success();
|
||||
|
||||
auto genericOp = dyn_cast<linalg::GenericOp>(op);
|
||||
if (!genericOp || !::isRowMajorMatmul(genericOp))
|
||||
return failure();
|
||||
|
||||
// TODO(ntv): non-identity layout.
|
||||
auto isStaticMemRefWithIdentityLayout = [](Value v) {
|
||||
auto m = v.getType().dyn_cast<MemRefType>();
|
||||
if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty())
|
||||
return false;
|
||||
return true;
|
||||
};
|
||||
return success(llvm::all_of(genericOp.getInputsAndOutputBuffers(),
|
||||
isStaticMemRefWithIdentityLayout));
|
||||
}
|
||||
|
||||
void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
|
||||
assert(succeeded(vectorizeLinalgOpPrecondition(op)));
|
||||
|
||||
if (auto convOp = dyn_cast<linalg::ConvOp>(op)) {
|
||||
// TODO: add a level of indirection to linalg.generic.
|
||||
if (convOp.padding())
|
||||
llvm_unreachable("Unexpected conv with padding");
|
||||
}
|
||||
|
||||
edsc::ScopedContext scope(builder, op->getLoc());
|
||||
if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
|
||||
// Vectorize fill as a vector.broadcast.
|
||||
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
|
||||
"]: Rewrite linalg.fill as vector.broadcast: "
|
||||
<< *op << ":\n");
|
||||
Value memref = vector_type_cast(fillOp.getOutputBuffer(0));
|
||||
Value dst = std_load(memref);
|
||||
Value res = vector_broadcast(dst.getType(), fillOp.value());
|
||||
std_store(res, memref);
|
||||
return;
|
||||
}
|
||||
|
||||
// Vectorize other ops as vector contraction (currently only matmul).
|
||||
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
|
||||
"]: Rewrite linalg op as vector.contract: "
|
||||
<< *op << ":\n");
|
||||
auto linalgOp = cast<linalg::LinalgOp>(op);
|
||||
Value a = std_load(vector_type_cast(linalgOp.getInput(0)));
|
||||
Value b = std_load(vector_type_cast(linalgOp.getInput(1)));
|
||||
Value memref = vector_type_cast(linalgOp.getOutputBuffer(0));
|
||||
Value c = std_load(memref);
|
||||
Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
|
||||
linalgOp.iterator_types());
|
||||
std_store(res, memref);
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-opt %s -test-linalg-transform-patterns | FileCheck %s
|
||||
// RUN: mlir-opt %s -test-linalg-transform-patterns=test-patterns | FileCheck %s
|
||||
|
||||
// CHECK-DAG: #[[STRIDED_1D:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
|
||||
// Map corresponding to a 2D memory access where the stride along the last dim is known to be 1.
|
||||
|
@ -22,10 +22,8 @@ func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
|
|||
// CHECK-LABEL: func @dot
|
||||
// CHECK-DAG: %[[c0:.*]] = constant 0 : index
|
||||
// CHECK-DAG: %[[c1:.*]] = constant 1 : index
|
||||
// CHECK-DAG: %[[c8:.*]] = constant 8 : index
|
||||
// CHECK-DAG: %[[c8000:.*]] = constant 8000 : index
|
||||
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8000]] {
|
||||
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8]] {
|
||||
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c1]] {
|
||||
// CHECK: load
|
||||
// CHECK: load
|
||||
|
@ -46,8 +44,7 @@ func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
|||
// CHECK-DAG: %[[c0:.*]] = constant 0 : index
|
||||
// CHECK-DAG: %[[c5:.*]] = constant 5 : index
|
||||
// CHECK-DAG: %[[c6:.*]] = constant 6 : index
|
||||
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c5]]
|
||||
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c6]]
|
||||
// CHECK: loop.parallel {{.*}} step (%[[c5]], %[[c6]])
|
||||
// CHECK: linalg.matvec({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?xf32, #[[STRIDED_1D]]>, memref<?xf32, #[[STRIDED_1D]]>
|
||||
|
||||
func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
|
@ -86,88 +83,6 @@ func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
|||
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4]] {
|
||||
// CHECK: linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
|
||||
|
||||
#some_generic_trait = {
|
||||
args_in = 1,
|
||||
args_out = 1,
|
||||
indexing_maps = [
|
||||
affine_map<(i, j) -> (i, j)>,
|
||||
affine_map<(i, j) -> (i, j)>
|
||||
],
|
||||
iterator_types = ["parallel", "parallel"]
|
||||
}
|
||||
func @fusion_test(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
%B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
%C: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
%D: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
%E: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
|
||||
// This should not be fused as it would violate dependencies. It will get
|
||||
// tiled for all levels of the memory hierarchy.
|
||||
linalg.matmul(%A, %A, %C) : memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
memref<?x?xf32, offset: ?, strides: [?, 1]>
|
||||
|
||||
// This should be fused.
|
||||
linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
memref<?x?xf32, offset: ?, strides: [?, 1]>
|
||||
|
||||
// This should not be fused or transformed at all since there are no patterns
|
||||
// on it. However it will be reordered because there are no dependencies.
|
||||
linalg.generic #some_generic_trait %A, %D {
|
||||
^bb(%a: f32, %b: f32) :
|
||||
linalg.yield %a : f32
|
||||
} : memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
memref<?x?xf32, offset: ?, strides: [?, 1]>
|
||||
|
||||
linalg.matmul(%C, %D, %E) : memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
memref<?x?xf32, offset: ?, strides: [?, 1]>
|
||||
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @fusion_test
|
||||
// CHECK-DAG: %[[c0:.*]] = constant 0 : index
|
||||
// CHECK-DAG: %[[c2:.*]] = constant 2 : index
|
||||
// CHECK-DAG: %[[c3:.*]] = constant 3 : index
|
||||
// CHECK-DAG: %[[c4:.*]] = constant 4 : index
|
||||
// CHECK-DAG: %[[c20:.*]] = constant 20 : index
|
||||
// CHECK-DAG: %[[c30:.*]] = constant 30 : index
|
||||
// CHECK-DAG: %[[c40:.*]] = constant 40 : index
|
||||
// CHECK-DAG: %[[c100:.*]] = constant 100 : index
|
||||
// CHECK-DAG: %[[c150:.*]] = constant 150 : index
|
||||
// CHECK-DAG: %[[c200:.*]] = constant 200 : index
|
||||
// CHECK-DAG: %[[c300:.*]] = constant 300 : index
|
||||
// CHECK-DAG: %[[c400:.*]] = constant 400 : index
|
||||
// CHECK-DAG: %[[c2000:.*]] = constant 2000 : index
|
||||
// CHECK-DAG: %[[c3000:.*]] = constant 3000 : index
|
||||
// CHECK-DAG: %[[c4000:.*]] = constant 4000 : index
|
||||
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] {
|
||||
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] {
|
||||
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] {
|
||||
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c200]] {
|
||||
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c300]] {
|
||||
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c400]] {
|
||||
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c20]] {
|
||||
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] {
|
||||
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] {
|
||||
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c2]] {
|
||||
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c3]] {
|
||||
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4]] {
|
||||
// CHECK: linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
|
||||
//
|
||||
// CHECK: linalg.generic
|
||||
//
|
||||
// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c100]] {
|
||||
// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c150]] {
|
||||
// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c2]] {
|
||||
// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c3]] {
|
||||
// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c4]] {
|
||||
// CHECK: linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
|
||||
// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c2]] {
|
||||
// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c3]] {
|
||||
// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c4]] {
|
||||
// CHECK: linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
|
||||
|
||||
#matmul_trait = {
|
||||
args_in = 2,
|
||||
args_out = 1,
|
||||
|
@ -280,23 +195,6 @@ func @permute_generic_indexed(
|
|||
// CHECK-SAME: memref<?x?xf32, #[[STRIDED_2D_u_1]]>,
|
||||
// CHECK-SAME: memref<?x?xf32, #[[STRIDED_2D_u_1]]>
|
||||
|
||||
func @dot_perm(%x: memref<?xf32, offset: ?, strides: [1]>,
|
||||
%y: memref<?xf32, offset: ?, strides: [1]>,
|
||||
%v: memref<f32>) {
|
||||
linalg.dot(%x, %y, %v) {__internal_linalg_transform__ = "__with_perm__"} :
|
||||
memref<?xf32, offset: ?, strides: [1]>,
|
||||
memref<?xf32, offset: ?, strides: [1]>,
|
||||
memref<f32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @dot_perm
|
||||
// CHECK-DAG: %[[c0:.*]] = constant 0 : index
|
||||
// CHECK-DAG: %[[c8:.*]] = constant 8 : index
|
||||
// CHECK-DAG: %[[c8000:.*]] = constant 8000 : index
|
||||
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8000]] {
|
||||
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8]] {
|
||||
// CHECK: linalg.dot({{.*}}, {{.*}}, {{.*}}) : memref<?xf32, #[[STRIDED_1D]]>, memref<?xf32, #[[STRIDED_1D]]>, memref<f32>
|
||||
|
||||
func @matvec_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
%x: memref<?xf32, offset: ?, strides: [1]>,
|
||||
%y: memref<?xf32, offset: ?, strides: [1]>) {
|
||||
|
|
|
@ -1,9 +1,3 @@
|
|||
set(LLVM_TARGET_DEFINITIONS TestLinalgTransformPatterns.td)
|
||||
mlir_tablegen(TestLinalgTransformPatterns.h.inc -gen-rewriters)
|
||||
add_public_tablegen_target(MLIRTestLinalgTransformPatternsIncGen)
|
||||
# Including Linalg in TableGen requires to depends on generated files
|
||||
add_dependencies(MLIRTestLinalgTransformPatternsIncGen LinalgOdsGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS TestVectorTransformPatterns.td)
|
||||
mlir_tablegen(TestVectorTransformPatterns.h.inc -gen-rewriters)
|
||||
add_public_tablegen_target(MLIRTestVectorTransformPatternsIncGen)
|
||||
|
|
|
@ -1,168 +0,0 @@
|
|||
//===- TestLinalgTransformPatterns.td - Test patterns --*- tablegen ----*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This is the pattern definition file for declarative Linalg transformations
|
||||
// tests.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TEST_LINALG_TRANSFORMS_PATTERNS
|
||||
#define TEST_LINALG_TRANSFORMS_PATTERNS
|
||||
|
||||
include "mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Linalg fusion patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def : Pat<(MatmulOp:$op $A, $_, $_),
|
||||
(TileAndFuseLinalgOp<[100, 150], [0], "L1">),
|
||||
[
|
||||
(Constraint<HasNoLinalgTransformMarker>),
|
||||
(Constraint<IsProducedByOpOfType<"MatmulOp">> $A),
|
||||
],
|
||||
// In the buffer world there is no use-def chains or dags so benefits
|
||||
// cannot be computed automatically from the length of the matched
|
||||
// pattern. Instead we specify the benefit ourselves for now.
|
||||
// This is not expected to be a big challenge long-term because
|
||||
// pattern benefits are akin to feature engineering: features should
|
||||
// be learned.
|
||||
(addBenefit 1)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg tiling patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def : Pat<(MatmulOp:$op $_, $_, $_),
|
||||
(TileLinalgOp<[2000, 3000, 4000], "L3">),
|
||||
[(Constraint<Or<[HasNoLinalgTransformMarker,
|
||||
HasLinalgTransformMarker<"MEM">]>>)]>;
|
||||
def : Pat<(MatmulOp:$op $_, $_, $_),
|
||||
(TileLinalgOp<[200, 300, 400], "L2">),
|
||||
[(Constraint<HasLinalgTransformMarker<"L3">>)]>;
|
||||
def : Pat<(MatmulOp:$op $_, $_, $_),
|
||||
(TileLinalgOp<[20, 30, 40], "L1">),
|
||||
[(Constraint<HasLinalgTransformMarker<"L2">>)]>;
|
||||
def : Pat<(MatmulOp:$op $_, $_, $_),
|
||||
(TileLinalgOp<[2, 3, 4], "REG">),
|
||||
[(Constraint<HasLinalgTransformMarker<"L1">>)]>;
|
||||
|
||||
def : Pattern<(MatvecOp:$op $_, $_, $_),
|
||||
[(TileLinalgOp<[5, 6], "L1">)],
|
||||
[(Constraint<HasNoLinalgTransformMarker>)]>;
|
||||
|
||||
def : Pattern<(DotOp:$op $_, $_, $_),
|
||||
[(TileLinalgOp<[8000], "L1">)],
|
||||
[(Constraint<Or<[HasNoLinalgTransformMarker,
|
||||
HasLinalgTransformMarker<"MEM">,
|
||||
HasLinalgTransformMarker<"L3">,
|
||||
HasLinalgTransformMarker<"L2">]>>)]>;
|
||||
def : Pattern<(DotOp:$op $_, $_, $_),
|
||||
[(TileLinalgOp<[8], "REG">)],
|
||||
[(Constraint<HasLinalgTransformMarker<"L1">>)]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg tiling and permutation patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def : Pat<(MatmulOp:$op $_, $_, $_),
|
||||
(TileLinalgOp<[2000, 3000, 4000], "L2__with_perm__", [1,2,0]>),
|
||||
[(Constraint<HasLinalgTransformMarker<"__with_perm__">>)]>;
|
||||
def : Pat<(MatmulOp:$op $_, $_, $_),
|
||||
(TileLinalgOp<[200, 300, 400], "L1__with_perm__", [1,0,2]>),
|
||||
[(Constraint<HasLinalgTransformMarker<"L2__with_perm__">>)]>;
|
||||
def : Pat<(MatmulOp:$op $_, $_, $_),
|
||||
(TileLinalgOp<[20, 30, 40], "REG__with_perm__">),
|
||||
[(Constraint<HasLinalgTransformMarker<"L1__with_perm__">>)]>;
|
||||
|
||||
|
||||
def : Pattern<(MatvecOp:$op $_, $_, $_),
|
||||
[(TileLinalgOp<[5, 6], "L1__with_perm__", [1,0]>)],
|
||||
[(Constraint<HasLinalgTransformMarker<"__with_perm__">>)]>;
|
||||
|
||||
def : Pattern<(DotOp:$op $_, $_, $_),
|
||||
[(TileLinalgOp<[8000], "L1__with_perm__">)],
|
||||
[(Constraint<HasLinalgTransformMarker<"__with_perm__">>)]>;
|
||||
def : Pattern<(DotOp:$op $_, $_, $_),
|
||||
[(TileLinalgOp<[8], "REG__with_perm__">)],
|
||||
[(Constraint<HasLinalgTransformMarker<"L1__with_perm__">>)]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg to loops patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def : Pattern<(DotOp:$op $_, $_, $_),
|
||||
[(LinalgOpToLoops<"DotOp">)],
|
||||
[(Constraint<HasLinalgTransformMarker<"REG">>)]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg to vector contraction patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def : Pattern<(MatmulOp:$op $_, $_, $_),
|
||||
[(VectorizeLinalgOp)],
|
||||
[(Constraint<And<[
|
||||
HasLinalgTransformMarker<"VECTORIZE">,
|
||||
PreconditionVectorizeLinalgOp
|
||||
]>>)]>;
|
||||
def : Pattern<(FillOp:$op $_, $_),
|
||||
[(VectorizeLinalgOp)],
|
||||
[(Constraint<And<[
|
||||
HasLinalgTransformMarker<"VECTORIZE">,
|
||||
PreconditionVectorizeLinalgOp
|
||||
]>>)]>;
|
||||
def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_),
|
||||
[(VectorizeLinalgOp)],
|
||||
[(Constraint<And<[
|
||||
HasLinalgTransformMarker<"VECTORIZE">,
|
||||
PreconditionVectorizeLinalgOp
|
||||
]>>)]>;
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg generic permutation patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def : Pat<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_),
|
||||
(PermuteGenericLinalgOp<[1, 2, 0], "PERMUTE"> $op),
|
||||
[(Constraint<And<[
|
||||
HasNoLinalgTransformMarker,
|
||||
AffineMapDomainHasDim<3>,
|
||||
PreconditionPermuteGenericLinalgOp<[1, 2, 0]>
|
||||
]>>)]>;
|
||||
|
||||
def : Pat<(IndexedGenericOp:$op $_, $_, $_, $_, $_, $_, $_),
|
||||
(PermuteGenericLinalgOp<[1, 2, 0], "PERMUTE"> $op),
|
||||
[(Constraint<And<[
|
||||
HasNoLinalgTransformMarker,
|
||||
AffineMapDomainHasDim<3>,
|
||||
PreconditionPermuteGenericLinalgOp<[1, 2, 0]>
|
||||
]>>)]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg subview operands promotion.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def : Pat<(MatmulOp:$op $_, $_, $_),
|
||||
(PromoteSubviewsLinalgOp),
|
||||
[(Constraint<And<[
|
||||
PreconditionPromoteSubviewsLinalgOp,
|
||||
HasOperandsOfType<"SubViewOp">,
|
||||
HasLinalgTransformMarker<"_promote_views_">]>>
|
||||
)]>;
|
||||
|
||||
def : Pat<(MatmulOp:$op $_, $_, $_),
|
||||
(PromoteSelectedSubviewsLinalgOp<[0], "first_view_promotion">),
|
||||
[(Constraint<And<[
|
||||
PreconditionPromoteSubviewsLinalgOp,
|
||||
HasOperandsOfType<"SubViewOp">,
|
||||
HasLinalgTransformMarker<"_promote_first_view_">]>>
|
||||
)]>;
|
||||
|
||||
def : Pat<(FillOp:$op $_, $_),
|
||||
(PromoteSelectedSubviewsLinalgOp<[0], "aligned_promotion", 32>),
|
||||
[(Constraint<And<[
|
||||
PreconditionPromoteSubviewsLinalgOp,
|
||||
HasOperandsOfType<"SubViewOp">,
|
||||
HasLinalgTransformMarker<"_promote_views_aligned_">]>>
|
||||
)]>;
|
||||
|
||||
#endif // TEST_LINALG_TRANSFORMS_PATTERNS
|
|
@ -25,7 +25,6 @@ add_llvm_library(MLIRTestTransforms
|
|||
|
||||
DEPENDS
|
||||
MLIRStandardOpsIncGen
|
||||
MLIRTestLinalgTransformPatternsIncGen
|
||||
MLIRTestVectorTransformPatternsIncGen
|
||||
)
|
||||
|
||||
|
|
|
@ -10,36 +10,127 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::linalg;
|
||||
|
||||
namespace mlir {
|
||||
namespace linalg {
|
||||
namespace {
|
||||
#include "TestLinalgTransformPatterns.h.inc"
|
||||
} // end namespace
|
||||
} // end namespace linalg
|
||||
} // end namespace mlir
|
||||
|
||||
namespace {
|
||||
struct TestLinalgTransforms
|
||||
: public PassWrapper<TestLinalgTransforms, FunctionPass> {
|
||||
TestLinalgTransforms() = default;
|
||||
TestLinalgTransforms(const TestLinalgTransforms &pass) {}
|
||||
|
||||
void runOnFunction() override;
|
||||
|
||||
Option<bool> testPatterns{*this, "test-patterns",
|
||||
llvm::cl::desc("Test a mixed set of patterns"),
|
||||
llvm::cl::init(false)};
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Apply transformations specified as patterns.
|
||||
void TestLinalgTransforms::runOnFunction() {
|
||||
static void applyPatterns(FuncOp funcOp) {
|
||||
MLIRContext *ctx = funcOp.getContext();
|
||||
OwningRewritePatternList patterns;
|
||||
auto funcOp = getFunction();
|
||||
|
||||
// Add the generated patterns to the list.
|
||||
linalg::populateWithGenerated(&getContext(), &patterns);
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Linalg tiling patterns.
|
||||
//===--------------------------------------------------------------------===//
|
||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
|
||||
LinalgMarker({"MEM", {}}, "L3"));
|
||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
|
||||
LinalgMarker({"L3"}, "L2"));
|
||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
|
||||
LinalgMarker({"L2"}, "L1"));
|
||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
|
||||
LinalgMarker({"L1"}, "REG"));
|
||||
|
||||
patterns.insert<LinalgTilingPattern<MatvecOp>>(
|
||||
ctx,
|
||||
LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
|
||||
LinalgTilingLoopType::ParallelLoops),
|
||||
LinalgMarker({}, "L1"));
|
||||
|
||||
patterns.insert<LinalgTilingPattern<DotOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes(8000),
|
||||
LinalgMarker({"MEM", "L3", "L2", {}}, "REG"));
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Linalg tiling and permutation patterns.
|
||||
//===--------------------------------------------------------------------===//
|
||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx,
|
||||
LinalgTilingOptions()
|
||||
.setTileSizes({2000, 3000, 4000})
|
||||
.setInterchange({1, 2, 0}),
|
||||
LinalgMarker({"__with_perm__"}, "L2__with_perm__"));
|
||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx,
|
||||
LinalgTilingOptions()
|
||||
.setTileSizes({200, 300, 400})
|
||||
.setInterchange({1, 0, 2}),
|
||||
LinalgMarker({"L2__with_perm__"}, "L1__with_perm__"));
|
||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
|
||||
LinalgMarker({"L1__with_perm__"}, "REG__with_perm__"));
|
||||
|
||||
patterns.insert<LinalgTilingPattern<MatvecOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
|
||||
LinalgMarker({"__with_perm__"}, "L1__with_perm__"));
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Linalg to loops patterns.
|
||||
//===--------------------------------------------------------------------===//
|
||||
patterns.insert<LinalgLoweringPattern<DotOp>>(
|
||||
ctx,
|
||||
/*loweringType=*/LinalgLoweringType::Loops, LinalgMarker({"REG"}));
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Linalg to vector contraction patterns.
|
||||
//===--------------------------------------------------------------------===//
|
||||
patterns.insert<LinalgVectorizationPattern<MatmulOp>,
|
||||
LinalgVectorizationPattern<FillOp>,
|
||||
LinalgVectorizationPattern<GenericOp>>(
|
||||
ctx, LinalgMarker({"VECTORIZE"}));
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Linalg generic permutation patterns.
|
||||
//===--------------------------------------------------------------------===//
|
||||
patterns.insert<LinalgInterchangePattern<GenericOp>>(
|
||||
ctx,
|
||||
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
|
||||
LinalgMarker({}, "PERMUTED"));
|
||||
patterns.insert<LinalgInterchangePattern<IndexedGenericOp>>(
|
||||
ctx,
|
||||
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
|
||||
LinalgMarker({}, "PERMUTED"));
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Linalg subview operands promotion.
|
||||
//===--------------------------------------------------------------------===//
|
||||
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
|
||||
ctx, LinalgMarker({"_promote_views_"}, "_views_promoted_"));
|
||||
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
|
||||
ctx,
|
||||
/*operandsToPromote=*/ArrayRef<unsigned>{0},
|
||||
LinalgMarker({"_promote_first_view_"}, "_first_view_promoted_"));
|
||||
patterns.insert<LinalgPromotionPattern<FillOp>>(
|
||||
ctx,
|
||||
/*operandsToPromote=*/ArrayRef<unsigned>{0},
|
||||
/*alignment=*/32,
|
||||
LinalgMarker({"_promote_views_aligned_"}, "_views_aligned_promoted_"));
|
||||
|
||||
applyPatternsAndFoldGreedily(funcOp, patterns);
|
||||
|
||||
// Drop the marker.
|
||||
|
@ -48,9 +139,15 @@ void TestLinalgTransforms::runOnFunction() {
|
|||
});
|
||||
}
|
||||
|
||||
/// Apply transformations specified as patterns.
|
||||
void TestLinalgTransforms::runOnFunction() {
|
||||
if (testPatterns)
|
||||
return applyPatterns(getFunction());
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
void registerTestLinalgTransforms() {
|
||||
PassRegistration<TestLinalgTransforms>(
|
||||
PassRegistration<TestLinalgTransforms> testTransformPatternsPass(
|
||||
"test-linalg-transform-patterns",
|
||||
"Test Linalg transformation patterns by applying them greedily.");
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue