[mlir][Linalg] Mostly NFC - Refactor Linalg patterns and transformations.

Linalg transformations are currently exposed as DRRs.
Unfortunately RewriterGen does not play well with the line of work on named linalg ops which require variadic operands and results.
Additionally, DRR is arguably not the right abstraction to expose compositions of such patterns that don't rely on SSA use-def semantics.

This revision abandons DRRs and exposes manually written C++ patterns.

Refactorings and cleanups are performed to uniformize APIs.
This refactoring will allow replacing the currently manually specified Linalg named ops.

A collateral victim of this refactoring is the `tileAndFuse` DRR, and the one associated test, which will be revived at a later time.

Lastly, the following 2 tests do not add value and are altered:
- a dot_perm tile + interchange test does not test anything new and is removed
- a dot tile + lower to loops does not need 2-D tiling and is trimmed.
This commit is contained in:
Nicolas Vasilache 2020-05-02 01:03:37 -04:00
parent c49f83b6e9
commit 307cfdf533
19 changed files with 1002 additions and 1077 deletions

View File

@ -1,5 +1,4 @@
add_subdirectory(IR) add_subdirectory(IR)
add_subdirectory(Transforms)
set(LLVM_TARGET_DEFINITIONS Passes.td) set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls) mlir_tablegen(Passes.h.inc -gen-pass-decls)

View File

@ -1,7 +0,0 @@
set(LLVM_TARGET_DEFINITIONS LinalgTransformPatterns.td)
mlir_tablegen(LinalgTransformPatterns.h.inc -gen-rewriters)
add_public_tablegen_target(MLIRLinalgTransformPatternsIncGen)
# Including Linalg in TableGen requires to depends on generated files
add_dependencies(MLIRLinalgTransformPatternsIncGen LinalgOdsGen)

View File

@ -1,123 +0,0 @@
//===- LinalgPatterns.td - Linalg transformation patterns --*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the pattern definition file for declarative Linalg transformation.
//
//===----------------------------------------------------------------------===//
#ifndef LINALG_TRANSFORMS
#define LINALG_TRANSFORMS
include "mlir/Dialect/Linalg/IR/LinalgOps.td"
include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td"
include "mlir/Dialect/Affine/IR/AffineOps.td"
def HasNoLinalgTransformMarker : CPred<[{
!op.getAttrOfType<StringAttr>(LinalgTransforms::kLinalgTransformMarker)
}]>;
class HasLinalgTransformMarker<string str> : CPred<[{
op.getAttrOfType<StringAttr>(
LinalgTransforms::kLinalgTransformMarker) &&
op.getAttrOfType<StringAttr>(
LinalgTransforms::kLinalgTransformMarker).getValue() == "}] # str # [{"}]>;
class IsProducedByOpOfType<string str> :
CPred<"isProducedByOpOfType<" # str # ">(op, $0)">;
class AffineMapDomainHasDim<int n> : CPred<[{
op.getAttrOfType<ArrayAttr>(getIndexingMapsAttrName()).getValue()[0].
cast<AffineMapAttr>().getValue().getNumDims() ==}] # n # [{}]>;
class HasOperandsOfType<string type>: CPred<[{
llvm::any_of(op.getOperands(),
[](Value v) {
return dyn_cast_or_null<}] # type # [{>(v.getDefiningOp());
})
}]>;
//===----------------------------------------------------------------------===//
// Linalg fusion patterns.
//===----------------------------------------------------------------------===//
//
// In the future, tile sizes should be derived from op properties + machine
// description but we do not need to wait on this to start having useful
// patterns.
class TileAndFuseLinalgOp<
list<int> sizes, list<int> operandIndices, string value> : NativeCodeCall<
"if (failed(tileAndFuseLinalgOpAndSetMarker($_builder, op, {" #
StrJoinInt<sizes>.result # "}, {" # StrJoinInt<operandIndices>.result # "}," #
" \"" # value # "\")))" #
" return failure();">;
//===----------------------------------------------------------------------===//
// Linalg tiling patterns.
//===----------------------------------------------------------------------===//
//
// In the future, tile sizes should be derived from op properties + machine
// description but we do not need to wait on this to start having useful
// patterns.
// `permutation` is an optional parameter to specify the ordering of the
// tiled loops. If provided, it must be a list of integers with the same number
// of elements as `sizes`.
class TileLinalgOp<list<int> sizes, string value, list<int> permutation=[]> :
NativeCodeCall<
"if (failed(tileLinalgOpAndSetMarker($_builder, op, {" #
StrJoinInt<sizes>.result # "}, \"" # value # "\", {" #
StrJoinInt<permutation>.result # "})))" #
" return failure();">;
//===----------------------------------------------------------------------===//
// Linalg to loop patterns.
//===----------------------------------------------------------------------===//
class LinalgOpToLoops<string OpType> : NativeCodeCall<
"if (failed(linalgOpToLoops<" # OpType # ">($_builder, op))) " #
" return failure();">;
class LinalgOpToParallelLoops<string OpType> : NativeCodeCall<
"if (failed(linalgOpToParallelLoops<" # OpType # ">($_builder, op))) " #
" return failure();">;
class LinalgOpToAffineLoops<string OpType> : NativeCodeCall<
"if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, op))) " #
" return failure();">;
//===----------------------------------------------------------------------===//
// Linalg to vector patterns precondition and DRR.
//===----------------------------------------------------------------------===//
def PreconditionVectorizeLinalgOp : CPred<
"succeeded(vectorizeLinalgOpPrecondition(op))">;
def VectorizeLinalgOp : NativeCodeCall<
"vectorizeLinalgOp($_builder, op)">;
//===----------------------------------------------------------------------===//
// Linalg generic permutation patterns precondition and DRR.
//===----------------------------------------------------------------------===//
class PreconditionPermuteGenericLinalgOp<list<int> permutation> : CPred<
"succeeded(permuteGenericLinalgOpPrecondition(op, {" #
StrJoinInt<permutation>.result # "}))">;
class PermuteGenericLinalgOp<list<int> permutation, string value> :
NativeCodeCall<
"permuteGenericLinalgOp($_builder, op, {" # StrJoinInt<permutation>.result #
"}, \"" # value # "\")">;
//===----------------------------------------------------------------------===//
// Linalg promote subview operands precondition and DRR.
//===----------------------------------------------------------------------===//
def PreconditionPromoteSubviewsLinalgOp : CPred<
"succeeded(promoteSubviewsLinalgOpPrecondition(op))">;
def PromoteSubviewsLinalgOp : NativeCodeCall<
"promoteSubviewsLinalgOp($_builder, op)">;
class PromoteSelectedSubviewsLinalgOp<list<int> operands, string marker="",
int alignment=0> :
NativeCodeCall<"promoteSelectedSubviewsLinalgOpAndSetMarker($_builder, op, {" #
StrJoinInt<operands>.result # "}, \"" # marker # "\", " # alignment # ")">;
#endif // LINALG_TRANSFORMS

View File

@ -1,137 +0,0 @@
//===- LinalgTransforms.h - Linalg transformations as patterns --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef DIALECT_LINALG_TRANSFORMS_LINALGTRANSFORMS_H_
#define DIALECT_LINALG_TRANSFORMS_LINALGTRANSFORMS_H_
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/STLExtras.h"
namespace mlir {
namespace linalg {
// Marker used as attribute name in generated Linalg rewriting transformations.
struct LinalgTransforms {
static const StringLiteral kLinalgTransformMarker;
};
namespace detail {
// Implementation detail of isProducedByOpOfType avoids the need for explicit
// template instantiations.
bool isProducedByOpOfTypeImpl(Operation *consumerOp, Value consumedView,
function_ref<bool(Operation *)> isaOpType);
} // namespace detail
// Returns true if the `consumedView` value use in `consumerOp` is produced by
// an op of type `OpTy`. This is used to implement use-def type information on
// buffers.
template <typename OpTy>
bool isProducedByOpOfType(Operation *consumerOp, Value consumedView) {
return detail::isProducedByOpOfTypeImpl(
consumerOp, consumedView, [](Operation *op) { return isa<OpTy>(op); });
}
////////////////////////////////////////////////////////////////////////////////
// The following Declarative Rewrite Rule (DRR) helpers are used in rewrite
// patterns. As such, they must not call into `rewriter.erase/replace` APIs and
// it is the responsibility of the enclosing PatternRewriter to erase on
// success.
////////////////////////////////////////////////////////////////////////////////
/// Tiles `op` by `sizes` permuting the loops according to `permutation` and
/// sets the attribute `kLinalgTransformMarker` to `linalgMarker`. The
/// permutation is expressed as a list of integers that specify the new ordering
/// of the loop nest (using loop.for operations). The length of `permutation`
/// must be equal to the length of `tileSizes`.
/// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with
/// `permutation = [1,2,0]`. All values in `permutation` must be
/// integers, in the range 0..`tileSizes.size()` without duplications
/// (i.e. `[1,1,2]` is an invalid permutation). An empty list
/// states for the identity permutation.
LogicalResult tileLinalgOpAndSetMarker(PatternRewriter &rewriter, Operation *op,
ArrayRef<int64_t> sizes,
StringRef linalgMarker,
ArrayRef<unsigned> permutation);
/// Tiles ops similar to `tileLinalgOpAndSetMarker` but generates loop.parallel
/// operations instead.
LogicalResult tileLinalgOpToParallelLoopsAndSetMarker(
PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
StringRef linalgMarker, ArrayRef<unsigned> permutation);
/// Tiles `op` by `sizes`, fuses the producers of `operandIndicesToFuse` and
/// sets the attribute `kLinalgTransformMarker` to `linalgMarker`.
LogicalResult tileAndFuseLinalgOpAndSetMarker(
PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker);
/// Tiles ops similar to `tileAndFuseLinalgOpAndSetMarker` but generates
/// loop.parallel operations instead.
LogicalResult tileAndFuseLinalgOpToParallelLoopsAndSetMarker(
PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker);
using LinalgLoops = SmallVector<Operation *, 4>;
/// Emits a loop nest of with the proper body for `op`.
template <typename LoopTy, typename ConcreteOp>
Optional<LinalgLoops> linalgLowerOpToLoops(PatternRewriter &rewriter,
Operation *op);
/// Emits a loop nest of `loop.for` with the proper body for `op`.
template <typename ConcreteOp>
LogicalResult linalgOpToLoops(PatternRewriter &rewriter, Operation *op);
/// Emits a loop nest of `loop.parallel` with the proper body for `op`.
template <typename ConcreteOp>
LogicalResult linalgOpToParallelLoops(PatternRewriter &rewriter, Operation *op);
/// Emits a loop nest of `affine.for` with the proper body for `op`.
template <typename ConcreteOp>
LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, Operation *op);
/// Rewrite a linalg.generic into a suitable vector.contraction op.
LogicalResult vectorizeLinalgOpPrecondition(Operation *op);
SmallVector<Value, 0> vectorizeLinalgOp(PatternRewriter &rewriter,
Operation *op);
/// Emits a `generic` or `indexed_generic` operation with the `indexing_maps`
/// and `iterator_types` permutated according to `permutation`.
LogicalResult
permuteGenericLinalgOpPrecondition(Operation *op,
ArrayRef<unsigned> permutation);
SmallVector<Value, 0> permuteGenericLinalgOp(PatternRewriter &rewriter,
Operation *op,
ArrayRef<unsigned> permutation,
StringRef linalgMarker);
/// Promote std.subviews feeding linalg operations.
LogicalResult promoteSubviewsLinalgOpPrecondition(Operation *op);
SmallVector<Value, 0> promoteSubviewsLinalgOp(PatternRewriter &rewriter,
Operation *op);
/// Similar to `promoteSubviewsLinalgOp` but only tries to promote
/// the views corresponding to the operands specified in
/// `operandIndicesToPromote`. Generated allocations are memory-aligned
/// according to the `alignment` parameter.
/// If linalgMarker is specified and the transformation is successfull
/// sets the attribute `kLinalgTransformMarker` to `linalgMarker`.
SmallVector<Value, 0> promoteSelectedSubviewsLinalgOpAndSetMarker(
PatternRewriter &rewriter, Operation *op,
ArrayRef<int64_t> operandIndicesToPromote, StringRef linalgMarker = "",
int64_t alignment = 0);
} // namespace linalg
} // namespace mlir
#endif // DIALECT_LINALG_TRANSFORMS_LINALGTRANSFORMS_H_

View File

@ -0,0 +1,375 @@
//===- Transforms.h - Linalg transformations as patterns --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_
#define DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
namespace linalg {
//===----------------------------------------------------------------------===//
// Transformations exposed as function calls.
//===----------------------------------------------------------------------===//
using LinalgLoops = SmallVector<Operation *, 4>;
struct TiledLinalgOp {
LinalgOp op;
SmallVector<Operation *, 8> loops;
};
/// Performs standalone tiling of a single LinalgOp by `tileSizes`.
/// and permute the loop nest according to `interchangeVector`
/// The permutation is expressed as a list of integers that specify
/// the new ordering of the loop nest. The length of `interchangeVector`
/// must be equal to the length of `tileSizes`.
/// An empty vector is interpreted as the identity permutation and the
/// transformation returns early.
///
/// When non-null, the optional pointer `folder` is used to call into the
/// `createAndFold` builder method. If `folder` is null, the regular `create`
/// method is called.
///
/// Returns a struct containing the tiled loops in the specified order
/// and the cloned op if successful, llvm::None otherwise.
///
/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed by
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
/// integers, in the range 0..`tileSizes.size()` without duplications
/// (i.e. `[1,1,2]` is an invalid permutation).
Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
ArrayRef<Value> tileSizes,
ArrayRef<unsigned> interchangeVector = {},
OperationFolder *folder = nullptr);
Optional<TiledLinalgOp>
tileLinalgOpToParallelLoops(OpBuilder &b, LinalgOp op,
ArrayRef<Value> tileSizes,
ArrayRef<unsigned> interchangeVector = {},
OperationFolder *folder = nullptr);
/// Performs standalone tiling of a single LinalgOp by constant `tileSizes`.
/// See `tileLinalgOp(... ArrayRef<Value> tileSizes,)` for more details
Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
ArrayRef<int64_t> tileSizes,
ArrayRef<unsigned> interchangeVector = {},
OperationFolder *folder = nullptr);
Optional<TiledLinalgOp>
tileLinalgOpToParallelLoops(OpBuilder &b, LinalgOp op,
ArrayRef<int64_t> tileSizes,
ArrayRef<unsigned> interchangeVector = {},
OperationFolder *folder = nullptr);
/// Interchanges the `iterator_types` and `iterator_maps` dimensions of `op`.
/// This is an in-place transformation controlled by `interchangeVector`.
/// An empty vector is interpreted as the identity permutation and the
/// transformation returns early.
///
/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed with
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
/// integers, in the range 0..`op.rank` without duplications
/// (i.e. `[1,1,2]` is an invalid permutation).
LinalgOp interchange(LinalgOp op, ArrayRef<unsigned> interchangeVector);
/// Promotes the `subViews` into a new buffer allocated at the insertion point
/// `b`. Promotion occurs in 3 steps:
/// 1. Create a new buffer for a full tile (i.e. not clipped at the boundary).
/// 2. Take a full view on the buffer and `linalg.fill` it with zeros (use
/// float zero for now).
/// 3. Take a partial slice of the full view in step 2. and copy into it.
/// Infers statically sized buffers from subViews unless `dynamicBuffers` is
/// true.
///
/// Returns a list of PromotionInfo which hold the promoted buffer and the
/// full and partial views indexing into the buffer.
// TODO: revisit dynamicBuffers option.
LinalgOp promoteSubViewOperands(OpBuilder &b, LinalgOp op,
llvm::SetVector<Value> subViews,
bool dynamicBuffers = false,
int64_t alignment = 0,
OperationFolder *folder = nullptr);
/// Emit a suitable vector form for a Linalg op with fully static shape.
void vectorizeLinalgOp(OpBuilder &builder, Operation *op);
/// Emits a loop nest of `LoopTy` with the proper body for `op`.
template <typename LoopTy, typename ConcreteOp>
Optional<LinalgLoops> linalgLowerOpToLoops(OpBuilder &builder, Operation *op);
/// Emits a loop nest of `loop.for` with the proper body for `op`.
template <typename ConcreteOp>
LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op);
/// Emits a loop nest of `loop.parallel` with the proper body for `op`.
template <typename ConcreteOp>
LogicalResult linalgOpToParallelLoops(OpBuilder &builder, Operation *op);
/// Emits a loop nest of `affine.for` with the proper body for `op`.
template <typename ConcreteOp>
LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op);
//===----------------------------------------------------------------------===//
// Preconditions that ensure the corresponding transformation suceeds and can be
// applied as a rewrite pattern.
//===----------------------------------------------------------------------===//
/// Emits a `generic` or `indexed_generic` operation with the `indexing_maps`
/// and `iterator_types` permutated according to `permutation`.
LogicalResult
interchangeGenericLinalgOpPrecondition(Operation *op,
ArrayRef<unsigned> interchangeVector);
/// Promote std.subviews feeding linalg operations.
LogicalResult promoteSubviewsLinalgOpPrecondition(
Operation *op, Optional<DenseSet<unsigned>> operandIndicesToPromote = None);
/// Rewrite a linalg.generic into a suitable vector.contraction op.
LogicalResult vectorizeLinalgOpPrecondition(Operation *op);
//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
struct LinalgTransforms {
static const StringLiteral kLinalgTransformMarker;
};
/// Helper class to control common attribute matching and setting behavior.
struct LinalgMarker {
LinalgMarker(ArrayRef<StringRef> matchDisjunction = {},
Optional<StringRef> replacement = None);
LinalgMarker(ArrayRef<StringRef> matchDisjunction, StringRef replacement);
LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
void replaceLinalgMarker(PatternRewriter &rewriter, Operation *op) const;
private:
SmallVector<StringRef, 4> matchDisjunction;
Optional<StringRef> replacement;
};
///
/// Linalg tiling patterns.
///
/// Apply the `tileLinalgOp` transformation as a pattern.
/// `marker` controls LinalgTransformMarker matching and update when specified.
/// See `tileLinalgOp` for more details.
enum class LinalgTilingLoopType {
Loops = 0,
AffineLoops = 1,
ParallelLoops = 2
};
struct LinalgTilingOptions {
/// The tile sizes by which to tile.
SmallVector<int64_t, 4> tileSizes{};
LinalgTilingOptions &setTileSizes(ArrayRef<int64_t> ts) {
tileSizes.assign(ts.begin(), ts.end());
return *this;
}
/// The interchange vector to reorder the tiled loops.
SmallVector<unsigned, 4> interchangeVector{};
LinalgTilingOptions &setInterchange(ArrayRef<unsigned> interchange) {
interchangeVector.assign(interchange.begin(), interchange.end());
return *this;
}
/// The type of tile loops to generate.
LinalgTilingLoopType loopType{LinalgTilingLoopType::Loops};
LinalgTilingOptions &setLoopType(LinalgTilingLoopType lt) {
loopType = lt;
return *this;
}
};
struct LinalgBaseTilingPattern : public RewritePattern {
LinalgBaseTilingPattern(StringRef opName, MLIRContext *context,
LinalgTilingOptions options,
LinalgMarker marker = LinalgMarker(),
PatternBenefit benefit = 1);
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
private:
/// LinalgTransformMarker handles special attribute manipulations.
LinalgMarker marker;
/// Options to control tiling;
LinalgTilingOptions options;
};
template <typename OpTy>
struct LinalgTilingPattern : public LinalgBaseTilingPattern {
LinalgTilingPattern(MLIRContext *context, LinalgTilingOptions options,
LinalgMarker marker = LinalgMarker(),
PatternBenefit benefit = 1)
: LinalgBaseTilingPattern(OpTy::getOperationName(), context, options,
marker, benefit) {}
};
///
/// Linalg interchange patterns.
///
/// Apply the `interchange` transformation as a pattern.
/// `marker` controls LinalgTransformMarker matching and update when specified.
/// See `interchange` for more details.
struct LinalgBaseInterchangePattern : public RewritePattern {
LinalgBaseInterchangePattern(StringRef opName, MLIRContext *context,
ArrayRef<unsigned> interchangeVector,
LinalgMarker marker = LinalgMarker(),
PatternBenefit benefit = 1);
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
private:
/// LinalgTransformMarker handles special attribute manipulations.
LinalgMarker marker;
/// The interchange vector to reorder the iterators and indexing_maps dims.
SmallVector<unsigned, 8> interchangeVector;
};
template <typename OpTy>
struct LinalgInterchangePattern : public LinalgBaseInterchangePattern {
LinalgInterchangePattern(MLIRContext *context,
ArrayRef<unsigned> interchangeVector,
LinalgMarker marker = LinalgMarker(),
PatternBenefit benefit = 1)
: LinalgBaseInterchangePattern(OpTy::getOperationName(), context,
interchangeVector, marker, benefit) {}
};
///
/// Linalg promotion patterns.
///
/// Apply the `promoteSubViewOperands` transformation as a pattern.
/// `marker` controls LinalgTransformMarker matching and update when specified.
/// See `promoteSubViewOperands` for more details.
struct LinalgBasePromotionPattern : public RewritePattern {
LinalgBasePromotionPattern(StringRef opName, MLIRContext *context,
ArrayRef<unsigned> operandsToPromote = {},
unsigned alignment = 0,
LinalgMarker marker = LinalgMarker(),
PatternBenefit benefit = 1);
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
private:
/// LinalgTransformMarker handles special attribute manipulations.
LinalgMarker marker;
/// Indices of subViews to promote.
SmallVector<unsigned, 4> operandsToPromote;
/// Alignment of promoted buffer.
unsigned alignment;
};
template <typename OpTy>
struct LinalgPromotionPattern : public LinalgBasePromotionPattern {
LinalgPromotionPattern(MLIRContext *context,
ArrayRef<unsigned> operandsToPromote = {},
unsigned alignment = 0,
LinalgMarker marker = LinalgMarker(),
PatternBenefit benefit = 1)
: LinalgBasePromotionPattern(OpTy::getOperationName(), context,
operandsToPromote, alignment, marker,
benefit) {}
LinalgPromotionPattern(MLIRContext *context,
ArrayRef<unsigned> operandsToPromote,
LinalgMarker marker = LinalgMarker(),
PatternBenefit benefit = 1)
: LinalgPromotionPattern(context, operandsToPromote, 0, marker, benefit) {
}
LinalgPromotionPattern(MLIRContext *context, unsigned alignment,
LinalgMarker marker = LinalgMarker(),
PatternBenefit benefit = 1)
: LinalgPromotionPattern(context, {}, alignment, marker, benefit) {}
LinalgPromotionPattern(MLIRContext *context, LinalgMarker marker,
PatternBenefit benefit = 1)
: LinalgPromotionPattern(context, {}, 0, marker, benefit) {}
};
///
/// Linalg vectorization patterns.
///
/// Apply the `vectorizeLinalgOp` transformation as a pattern.
/// `marker` controls LinalgTransformMarker matching and update when specified.
/// See `vectorizeLinalgOp` for more details.
struct LinalgBaseVectorizationPattern : public RewritePattern {
LinalgBaseVectorizationPattern(StringRef opName, MLIRContext *context,
LinalgMarker marker = LinalgMarker(),
PatternBenefit benefit = 1);
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
private:
/// LinalgTransformMarker handles special attribute manipulations.
LinalgMarker marker;
};
template <typename OpTy>
struct LinalgVectorizationPattern : public LinalgBaseVectorizationPattern {
LinalgVectorizationPattern(MLIRContext *context,
LinalgMarker marker = LinalgMarker(),
PatternBenefit benefit = 1)
: LinalgBaseVectorizationPattern(OpTy::getOperationName(), context,
marker, benefit) {}
};
///
/// Linalg lowering patterns.
///
/// Apply the `linalgLowerOpToLoops` transformation as a pattern.
/// `marker` controls LinalgTransformMarker matching and update when specified.
/// See `linalgLowerOpToLoops` for more details.
enum class LinalgLoweringType {
LibraryCall = 0,
Loops = 1,
AffineLoops = 2,
ParallelLoops = 3
};
template <typename OpTy>
struct LinalgLoweringPattern : public RewritePattern {
LinalgLoweringPattern(MLIRContext *context, LinalgLoweringType loweringType,
LinalgMarker marker = LinalgMarker(),
PatternBenefit benefit = 1)
: RewritePattern(OpTy::getOperationName(), {}, benefit, context),
marker(marker), loweringType(loweringType) {}
// TODO: Move implementation to .cpp once named ops are auto-generated.
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
if (!linalgOp)
return failure();
if (failed(marker.checkAndNotify(rewriter, linalgOp)))
return failure();
if (failed(promoteSubviewsLinalgOpPrecondition(op)))
return failure();
if (loweringType == LinalgLoweringType::LibraryCall) {
// TODO: Move lowering to library calls here.
return failure();
} else if (loweringType == LinalgLoweringType::Loops) {
if (failed(linalgOpToLoops<OpTy>(rewriter, op)))
return failure();
} else if (loweringType == LinalgLoweringType::AffineLoops) {
if (failed(linalgOpToAffineLoops<OpTy>(rewriter, op)))
return failure();
} else if (failed(linalgOpToParallelLoops<OpTy>(rewriter, op))) {
return failure();
}
rewriter.eraseOp(op);
return success();
}
private:
/// LinalgTransformMarker handles special attribute manipulations.
LinalgMarker marker;
/// Controls whether the pattern lowers to library calls, loop.for, affine.for
/// or loop.parallel.
LinalgLoweringType loweringType;
};
} // namespace linalg
} // namespace mlir
#endif // DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_

View File

@ -101,63 +101,6 @@ SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
AffineMap map, ArrayRef<Value> values, AffineMap map, ArrayRef<Value> values,
OperationFolder *folder = nullptr); OperationFolder *folder = nullptr);
struct TiledLinalgOp {
LinalgOp op;
SmallVector<Operation *, 8> loops;
};
/// Performs standalone tiling of a single LinalgOp by `tileSizes`.
/// and permute the loop nest according to `permutation`
/// The permutation is expressed as a list of integers that specify
/// the new ordering of the loop nest. The length of `permutation`
/// must be equal to the length of `tileSizes`.
/// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with
/// `permutation = [1,2,0]`. All values in `permutation` must be
/// integers, in the range 0..`tileSizes.size()` without duplications
/// (i.e. `[1,1,2]` is an invalid permutation). An empty list
/// states for the identity permutation.
/// Returns a struct containing the tiled loops in the specified order
/// and the cloned op if successful, llvm::None otherwise.
/// When non-null, the optional pointer `folder` is used to call into the
/// `createAndFold` builder method. If `folder` is null, the regular `create`
/// method is called.
Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
ArrayRef<Value> tileSizes,
ArrayRef<unsigned> permutation = {},
OperationFolder *folder = nullptr);
Optional<TiledLinalgOp> tileLinalgOpToParallelLoops(
OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
ArrayRef<unsigned> permutation = {}, OperationFolder *folder = nullptr);
/// Performs standalone tiling of a single LinalgOp by constant `tileSizes`.
/// and permute the loop nest according to `permutation`
/// The permutation is expressed as a list of integers that specify
/// the new ordering of the loop nest. The length of `permutation`
/// must be equal to the length of `tileSizes`.
/// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with
/// `permutation = [1,2,0]`. All values in `permutation` must be
/// integers, in the range 0..`tileSizes.size()` without duplications
/// (i.e. `[1,1,2]` is an invalid permutation). An empty list
/// states for the identity permutation.
/// Returns a struct containing the tiled loops in the specified order
/// and the cloned op if successful, llvm::None otherwise.
/// When non-null, the optional pointer `folder` is used to call into the
/// `createAndFold` builder method. If `folder` is null, the regular `create`
/// method is called.
Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
ArrayRef<int64_t> tileSizes,
ArrayRef<unsigned> permutation = {},
OperationFolder *folder = nullptr);
Optional<TiledLinalgOp> tileLinalgOpToParallelLoops(
OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes,
ArrayRef<unsigned> permutation = {}, OperationFolder *folder = nullptr);
template <typename... Args>
Optional<TiledLinalgOp> tileLinalgOperation(OpBuilder &b, Operation *op,
Args... args) {
return tileLinalgOp(b, cast<LinalgOp>(op), args...);
}
struct PromotionInfo { struct PromotionInfo {
Value buffer; Value buffer;
Value fullLocalView; Value fullLocalView;
@ -198,17 +141,6 @@ void applyPermutationToVector(SmallVector<T, N> &inVec,
inVec = auxVec; inVec = auxVec;
} }
/// Prepares the SubView promotion later performed by `promoteSubViews`
/// (where most of the transformation happens). It arranges the new
/// operands for `LinalgOp op` and deallocates the new buffer(s)
/// It is the entry point for declarative transformation
/// Returns the cloned `LinalgOp` with the new operands
LinalgOp promoteSubViewOperands(OpBuilder &b, LinalgOp op,
llvm::SetVector<Value> subViews,
bool dynamicBuffers = false,
int64_t alignment = 0,
OperationFolder *folder = nullptr);
} // namespace linalg } // namespace linalg
} // namespace mlir } // namespace mlir

View File

@ -1,9 +1,11 @@
add_mlir_dialect_library(MLIRLinalgTransforms add_mlir_dialect_library(MLIRLinalgTransforms
Fusion.cpp Fusion.cpp
LinalgTransforms.cpp Interchange.cpp
LinalgToLoops.cpp Loops.cpp
Promotion.cpp Promotion.cpp
Tiling.cpp Tiling.cpp
Transforms.cpp
Vectorization.cpp
ADDITIONAL_HEADER_DIRS ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
@ -11,7 +13,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
DEPENDS DEPENDS
intrinsics_gen intrinsics_gen
MLIRLinalgPassIncGen MLIRLinalgPassIncGen
MLIRLinalgTransformPatternsIncGen
) )
target_link_libraries(MLIRLinalgTransforms target_link_libraries(MLIRLinalgTransforms
PUBLIC PUBLIC

View File

@ -0,0 +1,85 @@
//===- Interchange.cpp - Linalg interchange transformation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg interchange transformation.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
#define DEBUG_TYPE "linalg-interchange"
using namespace mlir;
using namespace mlir::linalg;
LogicalResult mlir::linalg::interchangeGenericLinalgOpPrecondition(
Operation *op, ArrayRef<unsigned> interchangeVector) {
if (interchangeVector.empty())
return failure();
// Transformation applies to generic ops only.
if (!isa<GenericOp>(op) && !isa<IndexedGenericOp>(op))
return failure();
LinalgOp linOp = cast<LinalgOp>(op);
// Transformation applies to buffers only.
if (!linOp.hasBufferSemantics())
return failure();
// Permutation must be applicable.
if (linOp.getIndexingMap(0).getNumInputs() != interchangeVector.size())
return failure();
// Permutation map must be invertible.
if (!inversePermutation(
AffineMap::getPermutationMap(interchangeVector, op->getContext())))
return failure();
return success();
}
LinalgOp mlir::linalg::interchange(LinalgOp op,
ArrayRef<unsigned> interchangeVector) {
if (interchangeVector.empty())
return op;
MLIRContext *context = op.getContext();
auto permutationMap = inversePermutation(
AffineMap::getPermutationMap(interchangeVector, context));
assert(permutationMap && "expected permutation to be invertible");
SmallVector<Attribute, 4> newIndexingMaps;
auto indexingMaps = op.indexing_maps().getValue();
for (unsigned i = 0, e = op.getNumInputsAndOutputs(); i != e; ++i) {
AffineMap m = indexingMaps[i].cast<AffineMapAttr>().getValue();
if (!permutationMap.isEmpty())
m = m.compose(permutationMap);
newIndexingMaps.push_back(AffineMapAttr::get(m));
}
auto itTypes = op.iterator_types().getValue();
SmallVector<Attribute, 4> itTypesVector;
for (unsigned i = 0, e = itTypes.size(); i != e; ++i)
itTypesVector.push_back(itTypes[i]);
applyPermutationToVector(itTypesVector, interchangeVector);
op.setAttr(getIndexingMapsAttrName(),
ArrayAttr::get(newIndexingMaps, context));
op.setAttr(getIteratorTypesAttrName(),
ArrayAttr::get(itTypesVector, context));
return op;
}

View File

@ -1,381 +0,0 @@
//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic for transforming Linalg operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
#define DEBUG_TYPE "linalg-transforms"
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;
using llvm::dbgs;
using llvm::SetVector;
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
"__internal_linalg_transform__";
using TileFn = Optional<TiledLinalgOp>(OpBuilder &, LinalgOp, ArrayRef<int64_t>,
ArrayRef<unsigned>, OperationFolder *);
static LogicalResult
tileLinalgOpAndSetMarkerImpl(TileFn tileFn, PatternRewriter &rewriter,
Operation *op, ArrayRef<int64_t> sizes,
StringRef linalgMarker,
ArrayRef<unsigned> permutation) {
assert(permutation.empty() || permutation.size() == sizes.size());
auto tileRes = tileFn(rewriter, op, sizes, permutation, /*folder=*/nullptr);
if (!tileRes)
return failure();
tileRes->op.setAttr(LinalgTransforms::kLinalgTransformMarker,
rewriter.getStringAttr(linalgMarker));
return success();
}
LogicalResult mlir::linalg::tileLinalgOpAndSetMarker(
PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
StringRef linalgMarker, ArrayRef<unsigned> permutation) {
return tileLinalgOpAndSetMarkerImpl(tileLinalgOp, rewriter, op, sizes,
linalgMarker, permutation);
}
LogicalResult mlir::linalg::tileLinalgOpToParallelLoopsAndSetMarker(
PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
StringRef linalgMarker, ArrayRef<unsigned> permutation) {
return tileLinalgOpAndSetMarkerImpl(tileLinalgOpToParallelLoops, rewriter, op,
sizes, linalgMarker, permutation);
}
static LogicalResult
tileAndFuseLinalgOpAndSetMarkerImpl(TileFn tileFn, PatternRewriter &rewriter,
Operation *op, ArrayRef<int64_t> sizes,
ArrayRef<int64_t> operandIndicesToFuse,
StringRef linalgMarker) {
auto tileRes =
tileFn(rewriter, op, sizes, /*permutation=*/{}, /*folder=*/nullptr);
if (!tileRes)
return failure();
tileRes->op.setAttr(LinalgTransforms::kLinalgTransformMarker,
rewriter.getStringAttr(linalgMarker));
Aliases aliases;
auto G = LinalgDependenceGraph::buildDependenceGraph(
aliases, op->getParentOfType<FuncOp>());
SmallVector<Operation *, 4> originalProducers;
for (auto operandIdx : operandIndicesToFuse) {
auto fusionRes = fuseProducerOf(rewriter, tileRes->op, operandIdx, G);
if (!fusionRes) {
// Linalg fusion requires tiled loops to even determine whether it is
// possible to fuse. As a consequence, the pattern may fail even though a
// tiled version of op has already been introduced.
// So we need to remove the tiled version ourselves in case of failure.
// Another possibility is to ensure the constraints on the pattern
// guarantee that fusion will occur and just assert here. As we develop
// more complex patterns we can choose what is best.
rewriter.eraseOp(tileRes->loops[0]);
return failure();
}
fusionRes->fusedProducer.setAttr(LinalgTransforms::kLinalgTransformMarker,
rewriter.getStringAttr(linalgMarker));
originalProducers.push_back(fusionRes->originalProducer);
}
// The originalProducers can now be safely erased. This is similar to
// SSA-value use-def but in the world of buffer + structured ops.
for (auto *originalProducer : originalProducers)
rewriter.eraseOp(originalProducer);
return success();
}
LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker(
PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker) {
return tileAndFuseLinalgOpAndSetMarkerImpl(
tileLinalgOp, rewriter, op, sizes, operandIndicesToFuse, linalgMarker);
}
LogicalResult mlir::linalg::tileAndFuseLinalgOpToParallelLoopsAndSetMarker(
PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker) {
return tileAndFuseLinalgOpAndSetMarkerImpl(
tileLinalgOpToParallelLoops, rewriter, op, sizes, operandIndicesToFuse,
linalgMarker);
}
bool mlir::linalg::detail::isProducedByOpOfTypeImpl(
Operation *consumerOp, Value consumedView,
function_ref<bool(Operation *)> isaOpType) {
LinalgOp consumer = dyn_cast<LinalgOp>(consumerOp);
assert(consumer.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
if (!consumer)
return false;
auto maybeConsumerIndex = consumer.getIndexOfInput(consumedView);
if (!maybeConsumerIndex)
return false;
Aliases aliases;
auto G = LinalgDependenceGraph::buildDependenceGraph(
aliases, consumer.getParentOfType<FuncOp>());
for (auto dependence : G.getDependencesInto(
consumer, LinalgDependenceGraph::DependenceType::RAW)) {
auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
if (!isProducerLastWriteOfView(G, consumer, consumedView, producer))
continue;
if (isaOpType(dependence.dependentOpView.op))
return true;
}
return false;
}
//============================================================================//
// Precondition and transformation for vectorization of Linalg generic ops.
//============================================================================//
static bool hasMultiplyAddBody(linalg::GenericOp op) {
auto &r = op.region();
if (r.empty())
return false;
if (r.getBlocks().size() != 1)
return false;
auto &ops = r.front().getOperations();
if (ops.size() != 3)
return false;
using mlir::matchers::m_Val;
auto a = m_Val(r.front().getArgument(0));
auto b = m_Val(r.front().getArgument(1));
auto c = m_Val(r.front().getArgument(2));
// TODO(ntv) Update this detection once we have matcher support for
// specifying that any permutation of operands matches.
auto pattern1 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
auto pattern2 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
auto pattern3 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
auto pattern4 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
return pattern1.match(&ops.back()) || pattern2.match(&ops.back()) ||
pattern3.match(&ops.back()) || pattern4.match(&ops.back());
}
// TODO(ntv) should be Tablegen'd from a single source that generates the op
// itself.
static bool isRowMajorMatmul(linalg::GenericOp genericOp) {
return genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
isRowMajorMatmul(genericOp.indexing_maps()) &&
hasMultiplyAddBody(genericOp);
}
// TODO(ntv, ataei): This is in fact much more general than just vectorization
// for matmul and fill ops.
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
auto linalgOp = cast<linalg::LinalgOp>(op);
// All types must be static shape to go to vector.
for (Value operand : linalgOp.getInputsAndOutputBuffers())
if (!operand.getType().cast<ShapedType>().hasStaticShape())
return failure();
for (Type outputTensorType : linalgOp.getOutputTensorTypes())
if (!outputTensorType.cast<ShapedType>().hasStaticShape())
return failure();
if (isa<linalg::MatmulOp>(op) || isa<linalg::FillOp>(op))
return success();
auto genericOp = dyn_cast<linalg::GenericOp>(op);
if (!genericOp || !::isRowMajorMatmul(genericOp))
return failure();
// TODO(ntv): non-identity layout.
auto isStaticMemRefWithIdentityLayout = [](Value v) {
auto m = v.getType().dyn_cast<MemRefType>();
if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty())
return false;
return true;
};
if (!llvm::all_of(genericOp.getInputsAndOutputBuffers(),
isStaticMemRefWithIdentityLayout))
return failure();
return success();
}
SmallVector<Value, 0> mlir::linalg::vectorizeLinalgOp(PatternRewriter &rewriter,
Operation *op) {
assert(succeeded(vectorizeLinalgOpPrecondition(op)) &&
"DRR failure case must be a precondition");
auto linalgOp = cast<linalg::LinalgOp>(op);
assert(linalgOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
if (auto convOp = dyn_cast<linalg::ConvOp>(op)) {
// TODO(ntv): add a level of indirection to linalg.generic.
if (convOp.padding())
llvm_unreachable("Unexpected conv with padding");
}
edsc::ScopedContext scope(rewriter, op->getLoc());
if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
// Vectorize fill as a vector.broadcast.
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
"]: Rewrite linalg.fill as vector.broadcast: "
<< *op << ":\n");
auto dstMemrefVec = vector_type_cast(fillOp.getOutputBuffer(0));
Value dstVec = std_load(dstMemrefVec);
auto resVec = vector_broadcast(dstVec.getType(), fillOp.value());
std_store(resVec, dstMemrefVec);
} else {
// Vectorize other ops as vector contraction (currently only matmul).
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
"]: Rewrite linalg op as vector.contract: "
<< *op << ":\n");
auto vA = std_load(vector_type_cast(linalgOp.getInput(0)));
auto vB = std_load(vector_type_cast(linalgOp.getInput(1)));
auto vectorMemRefC = vector_type_cast(linalgOp.getOutputBuffer(0));
auto vC = std_load(vectorMemRefC);
auto vRes = vector_contract(vA, vB, vC, linalgOp.indexing_maps(),
linalgOp.iterator_types());
std_store(vRes, vectorMemRefC);
}
return {};
}
//============================================================================//
// Precondition and transformation for permutation of Linalg generic ops.
//============================================================================//
LogicalResult mlir::linalg::permuteGenericLinalgOpPrecondition(
Operation *op, ArrayRef<unsigned> permutation) {
if (permutation.empty())
return failure();
// Transformation applies to generic ops only.
if (!isa<GenericOp>(op) && !isa<IndexedGenericOp>(op))
return failure();
LinalgOp linOp = cast<LinalgOp>(op);
// Transformation applies to buffers only.
if (!linOp.hasBufferSemantics())
return failure();
return success();
}
SmallVector<Value, 0>
mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
ArrayRef<unsigned> permutation,
StringRef linalgMarker) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Permute dims for linalg op: " << *op
<< ":\n");
assert(succeeded(permuteGenericLinalgOpPrecondition(op, permutation)) &&
"DRR failure case must be a precondition");
auto linOp = cast<LinalgOp>(op);
auto permutationMap = inversePermutation(
AffineMap::getPermutationMap(permutation, rewriter.getContext()));
assert(permutationMap && "expected permutation to be invertible");
SmallVector<AffineMap, 4> newIndexingMap;
auto indexingMaps = linOp.indexing_maps().getValue();
for (unsigned i = 0, e = linOp.getNumInputsAndOutputs(); i != e; ++i) {
AffineMap m = indexingMaps[i].cast<AffineMapAttr>().getValue();
if (!permutationMap.isEmpty())
m = m.compose(permutationMap);
newIndexingMap.push_back(m);
}
auto itTypes = linOp.iterator_types().getValue();
SmallVector<Attribute, 4> itTypesVector;
for (unsigned i = 0, e = itTypes.size(); i != e; ++i)
itTypesVector.push_back(itTypes[i]);
applyPermutationToVector(itTypesVector, permutation);
op->setAttr(getIndexingMapsAttrName(),
rewriter.getAffineMapArrayAttr(newIndexingMap));
op->setAttr(getIteratorTypesAttrName(), rewriter.getArrayAttr(itTypesVector));
op->setAttr(LinalgTransforms::kLinalgTransformMarker,
rewriter.getStringAttr(linalgMarker));
linOp.clone(rewriter, linOp.getLoc(), op->getOperands());
return {};
}
//============================================================================//
// Precondition and transformation for Linalg subview promotion.
//============================================================================//
LogicalResult mlir::linalg::promoteSubviewsLinalgOpPrecondition(Operation *op) {
LinalgOp linOp = dyn_cast<LinalgOp>(op);
// Transformation applies to buffers only.
if (!linOp || !linOp.hasBufferSemantics())
return failure();
if (llvm::none_of(linOp.getInputsAndOutputBuffers(), [](Value v) {
return isa_and_nonnull<SubViewOp>(v.getDefiningOp());
}))
return failure();
return success();
}
SmallVector<Value, 0>
mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter,
Operation *op) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Promote subviews for linalg op: "
<< *op << ":\n");
assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) &&
"DRR failure case must be a precondition");
LinalgOp linOp = cast<LinalgOp>(op);
SmallVector<int64_t, 4> toPromote;
int64_t nBuffers = linOp.getNumInputsAndOutputBuffers();
toPromote.reserve(nBuffers);
for (int64_t i = 0; i < nBuffers; ++i)
toPromote.push_back(i);
return promoteSelectedSubviewsLinalgOpAndSetMarker(rewriter, op, toPromote);
}
SmallVector<Value, 0> mlir::linalg::promoteSelectedSubviewsLinalgOpAndSetMarker(
PatternRewriter &rewriter, Operation *op,
ArrayRef<int64_t> operandIndicesToPromote, StringRef linalgMarker,
int64_t alignment) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Promote subviews for linalg op: "
<< *op << ":\n");
assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) &&
"DRR failure case must be a precondition");
if (auto convOp = dyn_cast<linalg::ConvOp>(op)) {
// TODO(ntv): add a level of indirection to linalg.generic.
if (convOp.padding())
llvm_unreachable("Unexpected conv with padding");
}
LinalgOp linOp = cast<LinalgOp>(op);
assert(linOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
SetVector<Value> subViews;
for (int64_t index : operandIndicesToPromote)
if (auto sv =
dyn_cast_or_null<SubViewOp>(linOp.getBuffer(index).getDefiningOp()))
subViews.insert(sv);
if (!subViews.empty()) {
auto newOp =
promoteSubViewOperands(rewriter, linOp, subViews, false, alignment);
if (!linalgMarker.empty())
newOp.setAttr(LinalgTransforms::kLinalgTransformMarker,
rewriter.getStringAttr(linalgMarker));
return {};
}
llvm_unreachable("DRR failure case must be a precondition");
}

View File

@ -1,4 +1,4 @@
//===- LinalgToLoops.cpp - conversion from Linalg library ops to loops-----===// //===- Loops.cpp - conversion from Linalg named and generic ops to loops --===//
// //
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information. // See https://llvm.org/LICENSE.txt for license information.
@ -12,7 +12,7 @@
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/LoopOps/EDSC/Builders.h" #include "mlir/Dialect/LoopOps/EDSC/Builders.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
@ -489,20 +489,6 @@ public:
} }
}; };
/// This struct is for factoring out the implementation and support template
/// instantiations in the following 2 cases:
/// 1. Appending to a list of patterns via RewritePatternList.
/// 2. Direct invocation via `linalgOpToLoops` and `linalgOpToAffineLoops`.
/// The implementation must work both in DRR and inside a RewritePattern. As a
/// consequence, (1) it is only allowed to emit new ops if the match is
/// guaranteed to be a success, (2) it is not allowed erase/replace, and (3) an
/// encompassing pattern must take care of the erasure logic.
template <typename LoopTy, typename ConcreteOpTy>
class LinalgOpToLoopsImpl {
public:
static Optional<LinalgLoops> doit(Operation *op, PatternRewriter &rewriter);
};
namespace { namespace {
/// Helper struct to generate the loop nest for the op. This factored out here /// Helper struct to generate the loop nest for the op. This factored out here
/// to be able to partially specialize this for different LoopTy. /// to be able to partially specialize this for different LoopTy.
@ -573,14 +559,12 @@ public:
} // namespace } // namespace
template <typename LoopTy, typename ConcreteOpTy> template <typename LoopTy, typename ConcreteOpTy>
Optional<LinalgLoops> Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
LinalgOpToLoopsImpl<LoopTy, ConcreteOpTy>::doit(Operation *op,
PatternRewriter &rewriter) {
using Impl = GenerateLoopNest<LoopTy, ConcreteOpTy>; using Impl = GenerateLoopNest<LoopTy, ConcreteOpTy>;
using IndexedValueTy = using IndexedValueTy =
typename GenerateLoopNest<LoopTy, ConcreteOpTy>::IndexedValueTy; typename GenerateLoopNest<LoopTy, ConcreteOpTy>::IndexedValueTy;
ScopedContext scope(rewriter, op->getLoc()); ScopedContext scope(builder, op->getLoc());
// The flattened loopToOperandRangesMaps is expected to be an invertible // The flattened loopToOperandRangesMaps is expected to be an invertible
// permutation map (which is asserted in the inverse calculation). // permutation map (which is asserted in the inverse calculation).
@ -607,7 +591,7 @@ LinalgOpToLoopsImpl<LoopTy, ConcreteOpTy>::doit(Operation *op,
SmallVector<Value, 4> allIvs(nLoops); SmallVector<Value, 4> allIvs(nLoops);
auto loopRanges = auto loopRanges =
emitLoopRanges(scope.getBuilderRef(), scope.getLocation(), invertedMap, emitLoopRanges(scope.getBuilderRef(), scope.getLocation(), invertedMap,
getViewSizes(rewriter, linalgOp)); getViewSizes(builder, linalgOp));
assert(loopRanges.size() == allIvs.size()); assert(loopRanges.size() == allIvs.size());
Impl::doit(linalgOp, loopRanges, allIvs); Impl::doit(linalgOp, loopRanges, allIvs);
// Number of loop ops might be different from the number of ivs since some // Number of loop ops might be different from the number of ivs since some
@ -635,8 +619,7 @@ public:
LogicalResult matchAndRewrite(Operation *op, LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
using Impl = LinalgOpToLoopsImpl<LoopType, ConcreteOp>; if (!linalgOpToLoopsImpl<LoopType, ConcreteOp>(op, rewriter))
if (!Impl::doit(op, rewriter))
return failure(); return failure();
rewriter.eraseOp(op); rewriter.eraseOp(op);
return success(); return success();
@ -662,7 +645,7 @@ public:
} }
}; };
/// Populate the given list with patterns that convert from Linalg to LLVM. /// Populate the given list with patterns that convert from Linalg to loops.
template <typename LoopType> template <typename LoopType>
void FillRewritePatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) { void FillRewritePatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) {
RewritePatternList<LoopType, RewritePatternList<LoopType,
@ -759,50 +742,49 @@ mlir::createConvertLinalgToAffineLoopsPass() {
/// Emits a loop nest with the proper body for `op`. /// Emits a loop nest with the proper body for `op`.
template <typename LoopTy, typename ConcreteOp> template <typename LoopTy, typename ConcreteOp>
Optional<LinalgLoops> Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder,
mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter, Operation *op) { Operation *op) {
return LinalgOpToLoopsImpl<LoopTy, ConcreteOp>::doit(op, rewriter); return linalgOpToLoopsImpl<LoopTy, ConcreteOp>(op, builder);
} }
/// Emits a loop nest of `loop.for` with the proper body for `op`. /// Emits a loop nest of `loop.for` with the proper body for `op`.
template <typename ConcreteOp> template <typename ConcreteOp>
LogicalResult mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter, LogicalResult mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op) {
Operation *op) {
Optional<LinalgLoops> loops = Optional<LinalgLoops> loops =
linalgLowerOpToLoops<loop::ForOp, ConcreteOp>(rewriter, op); linalgLowerOpToLoops<loop::ForOp, ConcreteOp>(builder, op);
return loops ? success() : failure(); return loops ? success() : failure();
} }
/// Emits a loop nest of `affine.for` with the proper body for `op`. /// Emits a loop nest of `affine.for` with the proper body for `op`.
template <typename ConcreteOp> template <typename ConcreteOp>
LogicalResult mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter, LogicalResult mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder,
Operation *op) { Operation *op) {
Optional<LinalgLoops> loops = Optional<LinalgLoops> loops =
linalgLowerOpToLoops<AffineForOp, ConcreteOp>(rewriter, op); linalgLowerOpToLoops<AffineForOp, ConcreteOp>(builder, op);
return loops ? success() : failure(); return loops ? success() : failure();
} }
/// Emits a loop nest of `loop.parallel` with the proper body for `op`. /// Emits a loop nest of `loop.parallel` with the proper body for `op`.
template <typename ConcreteOp> template <typename ConcreteOp>
LogicalResult mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter, LogicalResult mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder,
Operation *op) { Operation *op) {
Optional<LinalgLoops> loops = Optional<LinalgLoops> loops =
linalgLowerOpToLoops<loop::ParallelOp, ConcreteOp>(rewriter, op); linalgLowerOpToLoops<loop::ParallelOp, ConcreteOp>(builder, op);
return loops ? success() : failure(); return loops ? success() : failure();
} }
// TODO(ntv) Need to make these instantiations more future-proof to avoid the // TODO Need to make these instantiations more future-proof to avoid the need to
// need to update as soon as we add new ops. // update as soon as we add new ops.
#define INSTANTIATE_LINALG_OP_TO_LOOPS(OP_TYPE) \ #define INSTANTIATE_LINALG_OP_TO_LOOPS(OP_TYPE) \
template LogicalResult mlir::linalg::linalgOpToLoops<OP_TYPE>( \ template LogicalResult mlir::linalg::linalgOpToLoops<OP_TYPE>( \
PatternRewriter & rewriter, Operation * op); \ OpBuilder & builder, Operation * op); \
template LogicalResult mlir::linalg::linalgOpToAffineLoops<OP_TYPE>( \ template LogicalResult mlir::linalg::linalgOpToAffineLoops<OP_TYPE>( \
PatternRewriter & rewriter, Operation * op); \ OpBuilder & builder, Operation * op); \
template LogicalResult mlir::linalg::linalgOpToParallelLoops<OP_TYPE>( \ template LogicalResult mlir::linalg::linalgOpToParallelLoops<OP_TYPE>( \
PatternRewriter & rewriter, Operation * op); \ OpBuilder & builder, Operation * op); \
template Optional<LinalgLoops> \ template Optional<LinalgLoops> \
mlir::linalg::linalgLowerOpToLoops<loop::ParallelOp, OP_TYPE>( \ mlir::linalg::linalgLowerOpToLoops<loop::ParallelOp, OP_TYPE>( \
PatternRewriter & rewriter, Operation * op); OpBuilder & builder, Operation * op);
INSTANTIATE_LINALG_OP_TO_LOOPS(CopyOp) INSTANTIATE_LINALG_OP_TO_LOOPS(CopyOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(FillOp) INSTANTIATE_LINALG_OP_TO_LOOPS(FillOp)

View File

@ -16,6 +16,7 @@
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/Dialect/LoopOps/LoopOps.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
@ -264,6 +265,21 @@ static void promoteSubViews(FuncOp f, bool dynamicBuffers) {
op.erase(); op.erase();
} }
LogicalResult mlir::linalg::promoteSubviewsLinalgOpPrecondition(
Operation *op, llvm::Optional<DenseSet<unsigned>> operandIndicesToPromote) {
LinalgOp linOp = dyn_cast<LinalgOp>(op);
// Transformation applies to buffers only.
if (!linOp || !linOp.hasBufferSemantics())
return failure();
for (auto en : llvm::enumerate(linOp.getInputsAndOutputBuffers())) {
auto sv = isa_and_nonnull<SubViewOp>(en.value().getDefiningOp());
if (sv && (!operandIndicesToPromote.hasValue() ||
operandIndicesToPromote->count(en.index())))
return success();
}
return failure();
}
namespace { namespace {
struct LinalgPromotionPass : public LinalgPromotionBase<LinalgPromotionPass> { struct LinalgPromotionPass : public LinalgPromotionBase<LinalgPromotionPass> {
LinalgPromotionPass() = default; LinalgPromotionPass() = default;

View File

@ -15,6 +15,7 @@
#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h" #include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/LoopOps/EDSC/Builders.h" #include "mlir/Dialect/LoopOps/EDSC/Builders.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
@ -320,10 +321,9 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp,
} }
template <typename LoopTy> template <typename LoopTy>
Optional<TiledLinalgOp> static tileLinalgOpImpl(OpBuilder &b, LinalgOp op, Optional<TiledLinalgOp> static tileLinalgOpImpl(
ArrayRef<Value> tileSizes, OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
ArrayRef<unsigned> permutation, ArrayRef<unsigned> interchangeVector, OperationFolder *folder) {
OperationFolder *folder) {
assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
// 1. Enforce the convention that "tiling by zero" skips tiling a particular // 1. Enforce the convention that "tiling by zero" skips tiling a particular
// dimension. This convention is significantly simpler to handle instead of // dimension. This convention is significantly simpler to handle instead of
@ -342,13 +342,13 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(OpBuilder &b, LinalgOp op,
return llvm::None; return llvm::None;
} }
// If permutation is empty, use the identity. Build the permutation map // If interchangeVector is empty, use the identity. Build the permutation map
// otherwise. // otherwise.
auto invPermutationMap = AffineMap::getMultiDimIdentityMap( auto invPermutationMap = AffineMap::getMultiDimIdentityMap(
tileSizes.size(), ScopedContext::getContext()); tileSizes.size(), ScopedContext::getContext());
if (!permutation.empty()) if (!interchangeVector.empty())
invPermutationMap = inversePermutation( invPermutationMap = inversePermutation(AffineMap::getPermutationMap(
AffineMap::getPermutationMap(permutation, ScopedContext::getContext())); interchangeVector, ScopedContext::getContext()));
if (!invPermutationMap) if (!invPermutationMap)
return llvm::None; return llvm::None;
@ -371,8 +371,8 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(OpBuilder &b, LinalgOp op,
std::tie(loopRanges, loopIndexToRangeIndex) = std::tie(loopRanges, loopIndexToRangeIndex) =
makeTiledLoopRanges(b, scope.getLocation(), viewSizesToLoopsMap, makeTiledLoopRanges(b, scope.getLocation(), viewSizesToLoopsMap,
viewSizes, tileSizes, folder); viewSizes, tileSizes, folder);
if (!permutation.empty()) if (!interchangeVector.empty())
applyPermutationToVector(loopRanges, permutation); applyPermutationToVector(loopRanges, interchangeVector);
// 3. Create the tiled loops. // 3. Create the tiled loops.
LinalgOp res = op; LinalgOp res = op;
@ -393,7 +393,7 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(OpBuilder &b, LinalgOp op,
// assuming that loopRanges have previously been permuted by // assuming that loopRanges have previously been permuted by
// (i,j,k)->(k,i,j) So this permutation should be the inversePermutation of // (i,j,k)->(k,i,j) So this permutation should be the inversePermutation of
// that one: (d0,d1,d2)->(d2,d0,d1) // that one: (d0,d1,d2)->(d2,d0,d1)
if (!permutation.empty()) if (!interchangeVector.empty())
ivValues = applyMapToValues(b, loc, invPermutationMap, ivValues, folder); ivValues = applyMapToValues(b, loc, invPermutationMap, ivValues, folder);
auto views = auto views =
@ -420,7 +420,8 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(OpBuilder &b, LinalgOp op,
template <typename LoopTy> template <typename LoopTy>
static Optional<TiledLinalgOp> static Optional<TiledLinalgOp>
tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes, tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes,
ArrayRef<unsigned> permutation, OperationFolder *folder) { ArrayRef<unsigned> interchangeVector,
OperationFolder *folder) {
assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
if (tileSizes.empty()) if (tileSizes.empty())
return llvm::None; return llvm::None;
@ -459,33 +460,36 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes,
tileSizeValues.push_back(folded_std_constant_index(folder, 0)); tileSizeValues.push_back(folded_std_constant_index(folder, 0));
} }
return tileLinalgOpImpl<LoopTy>(b, op, tileSizeValues, permutation, folder); return tileLinalgOpImpl<LoopTy>(b, op, tileSizeValues, interchangeVector,
folder);
} }
Optional<TiledLinalgOp> Optional<TiledLinalgOp>
mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes, mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
ArrayRef<unsigned> permutation, ArrayRef<unsigned> interchangeVector,
OperationFolder *folder) { OperationFolder *folder) {
return tileLinalgOpImpl<loop::ForOp>(b, op, tileSizes, permutation, folder); return tileLinalgOpImpl<loop::ForOp>(b, op, tileSizes, interchangeVector,
folder);
} }
Optional<TiledLinalgOp> mlir::linalg::tileLinalgOpToParallelLoops( Optional<TiledLinalgOp> mlir::linalg::tileLinalgOpToParallelLoops(
OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes, OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
ArrayRef<unsigned> permutation, OperationFolder *folder) { ArrayRef<unsigned> interchangeVector, OperationFolder *folder) {
return tileLinalgOpImpl<loop::ParallelOp>(b, op, tileSizes, permutation, return tileLinalgOpImpl<loop::ParallelOp>(b, op, tileSizes, interchangeVector,
folder); folder);
} }
Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp( Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes, OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes,
ArrayRef<unsigned> permutation, OperationFolder *folder) { ArrayRef<unsigned> interchangeVector, OperationFolder *folder) {
return tileLinalgOpImpl<loop::ForOp>(b, op, tileSizes, permutation, folder); return tileLinalgOpImpl<loop::ForOp>(b, op, tileSizes, interchangeVector,
folder);
} }
Optional<TiledLinalgOp> mlir::linalg::tileLinalgOpToParallelLoops( Optional<TiledLinalgOp> mlir::linalg::tileLinalgOpToParallelLoops(
OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes, OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes,
ArrayRef<unsigned> permutation, OperationFolder *folder) { ArrayRef<unsigned> interchangeVector, OperationFolder *folder) {
return tileLinalgOpImpl<loop::ParallelOp>(b, op, tileSizes, permutation, return tileLinalgOpImpl<loop::ParallelOp>(b, op, tileSizes, interchangeVector,
folder); folder);
} }
@ -496,8 +500,8 @@ static void tileLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes) {
f.walk([tileSizes, &b, &folder](LinalgOp op) { f.walk([tileSizes, &b, &folder](LinalgOp op) {
if (!op.hasBufferSemantics()) if (!op.hasBufferSemantics())
return; return;
auto opLoopsPair = auto opLoopsPair = tileLinalgOpImpl<LoopTy>(
tileLinalgOpImpl<LoopTy>(b, op, tileSizes, /*permutation=*/{}, &folder); b, op, tileSizes, /*interchangeVector=*/{}, &folder);
// If tiling occurred successfully, erase old op. // If tiling occurred successfully, erase old op.
if (opLoopsPair) if (opLoopsPair)
op.erase(); op.erase();

View File

@ -0,0 +1,228 @@
//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
#define DEBUG_TYPE "linalg-transforms"
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;
using llvm::dbgs;
//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
"__internal_linalg_transform__";
mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction,
llvm::Optional<StringRef> replacement)
: matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
replacement(replacement) {}
mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction,
StringRef replacement)
: LinalgMarker(matchDisjunction, llvm::Optional<StringRef>{replacement}) {}
LogicalResult
mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
Operation *op) const {
auto attr = op->template getAttrOfType<StringAttr>(
LinalgTransforms::kLinalgTransformMarker);
if (!attr) {
// 1. Has no marker case and matchDisjunction is empty.
if (matchDisjunction.empty())
return success();
// 2. Has no marker and matchDisjuntion matches the no-moarker case.
for (auto marker : matchDisjunction)
if (marker.empty())
return success();
// 3. Has no marker but was expecting a marker.
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << " does not have any marker from list: ";
llvm::interleaveComma(matchDisjunction, diag);
});
}
// 4. Match explicit marker.
for (auto marker : matchDisjunction)
if (attr.getValue() == marker)
return success();
// 5. Fail to match.
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << " does not have any marker from list: ";
llvm::interleaveComma(matchDisjunction, diag);
});
}
void mlir::linalg::LinalgMarker::replaceLinalgMarker(PatternRewriter &rewriter,
Operation *op) const {
if (replacement.hasValue())
op->setAttr(LinalgTransforms::kLinalgTransformMarker,
rewriter.getStringAttr(replacement.getValue()));
else
op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
rewriter.getContext()));
}
/// Linalg base tiling pattern.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
StringRef opName, MLIRContext *context, LinalgTilingOptions options,
LinalgMarker marker, PatternBenefit benefit)
: RewritePattern(opName, {}, benefit, context), marker(marker),
options(options) {}
LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
if (!linalgOp)
return failure();
if (failed(marker.checkAndNotify(rewriter, linalgOp)))
return failure();
Optional<TiledLinalgOp> res;
if (options.loopType == LinalgTilingLoopType::Loops)
res = tileLinalgOp(rewriter, linalgOp, options.tileSizes,
options.interchangeVector);
else if (options.loopType == LinalgTilingLoopType::ParallelLoops)
res = tileLinalgOpToParallelLoops(rewriter, linalgOp, options.tileSizes,
options.interchangeVector);
// TODO: Impl tiling to affine loops when it makes sense.
if (!res)
return failure();
// New marker if specified.
marker.replaceLinalgMarker(rewriter, res->op.getOperation());
rewriter.eraseOp(op);
return success();
}
/// Linalg base interchange pattern.
mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
StringRef opName, MLIRContext *context,
ArrayRef<unsigned> interchangeVector, LinalgMarker marker,
PatternBenefit benefit)
: RewritePattern(opName, {}, benefit, context), marker(marker),
interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
if (!linalgOp)
return failure();
if (failed(marker.checkAndNotify(rewriter, linalgOp)))
return failure();
if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector)))
return failure();
// TODO: figure out how this interplays with named ops. In particular this
// should break the named op property.
rewriter.updateRootInPlace(op, [&]() {
interchange(linalgOp, interchangeVector);
// New marker if specified.
marker.replaceLinalgMarker(rewriter, op);
});
return success();
}
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
StringRef opName, MLIRContext *context,
ArrayRef<unsigned> operandsToPromote, unsigned alignment,
LinalgMarker marker, PatternBenefit benefit)
: RewritePattern(opName, {}, benefit, context), marker(marker),
operandsToPromote(operandsToPromote.begin(), operandsToPromote.end()),
alignment(alignment) {}
LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
if (!linalgOp)
return failure();
if (failed(marker.checkAndNotify(rewriter, linalgOp)))
return failure();
if (operandsToPromote.empty()) {
if (failed(promoteSubviewsLinalgOpPrecondition(op, llvm::None)))
return failure();
} else {
DenseSet<unsigned> set;
set.insert(operandsToPromote.begin(), operandsToPromote.end());
if (failed(promoteSubviewsLinalgOpPrecondition(op, set)))
return failure();
}
llvm::SetVector<Value> subViews;
if (!operandsToPromote.empty()) {
for (unsigned idx : operandsToPromote) {
auto *op = linalgOp.getBuffer(idx).getDefiningOp();
if (auto sv = dyn_cast_or_null<SubViewOp>(op))
subViews.insert(sv);
}
} else {
unsigned nBuffers = linalgOp.getNumInputsAndOutputBuffers();
for (unsigned idx = 0; idx < nBuffers; ++idx) {
auto *op = linalgOp.getBuffer(idx).getDefiningOp();
if (auto sv = dyn_cast_or_null<SubViewOp>(op))
subViews.insert(sv);
}
}
auto promotedOp =
promoteSubViewOperands(rewriter, op, subViews, /*dynamicBuffers=*/false,
/*alignment=*/alignment);
marker.replaceLinalgMarker(rewriter, promotedOp.getOperation());
rewriter.eraseOp(op);
return success();
}
mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
StringRef opName, MLIRContext *context, LinalgMarker marker,
PatternBenefit benefit)
: RewritePattern(opName, {}, benefit, context), marker(marker) {}
LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
if (!linalgOp)
return failure();
if (failed(marker.checkAndNotify(rewriter, linalgOp)))
return failure();
if (failed(vectorizeLinalgOpPrecondition(op)))
return failure();
vectorizeLinalgOp(rewriter, op);
rewriter.eraseOp(op);
return success();
}

View File

@ -0,0 +1,131 @@
//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Vectorization transformations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;
using llvm::dbgs;
#define DEBUG_TYPE "linalg-vectorization"
static bool hasMultiplyAddBody(linalg::GenericOp op) {
auto &r = op.region();
if (!llvm::hasSingleElement(r))
return false;
if (!llvm::hasNItems(r.front().begin(), r.front().end(), 3))
return false;
using mlir::matchers::m_Val;
auto a = m_Val(r.front().getArgument(0));
auto b = m_Val(r.front().getArgument(1));
auto c = m_Val(r.front().getArgument(2));
// TODO: Update this detection once we have matcher support for specifying
// that any permutation of operands matches.
auto pattern1 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
auto pattern2 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
auto pattern3 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
auto pattern4 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
return pattern1.match(&r.front().back()) ||
pattern2.match(&r.front().back()) ||
pattern3.match(&r.front().back()) || pattern4.match(&r.front().back());
}
// TODO: Should be Tablegen'd from a single source that generates the op itself.
static bool isRowMajorMatmul(linalg::GenericOp genericOp) {
return genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
isRowMajorMatmul(genericOp.indexing_maps()) &&
hasMultiplyAddBody(genericOp);
}
// TODO: This is in fact much more general than just vectorization for matmul
// and fill ops.
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
auto linalgOp = cast<linalg::LinalgOp>(op);
// All types must be static shape to go to vector.
for (Value operand : linalgOp.getInputsAndOutputBuffers())
if (!operand.getType().cast<ShapedType>().hasStaticShape())
return failure();
for (Type outputTensorType : linalgOp.getOutputTensorTypes())
if (!outputTensorType.cast<ShapedType>().hasStaticShape())
return failure();
if (isa<linalg::MatmulOp>(op) || isa<linalg::FillOp>(op))
return success();
auto genericOp = dyn_cast<linalg::GenericOp>(op);
if (!genericOp || !::isRowMajorMatmul(genericOp))
return failure();
// TODO(ntv): non-identity layout.
auto isStaticMemRefWithIdentityLayout = [](Value v) {
auto m = v.getType().dyn_cast<MemRefType>();
if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty())
return false;
return true;
};
return success(llvm::all_of(genericOp.getInputsAndOutputBuffers(),
isStaticMemRefWithIdentityLayout));
}
void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
assert(succeeded(vectorizeLinalgOpPrecondition(op)));
if (auto convOp = dyn_cast<linalg::ConvOp>(op)) {
// TODO: add a level of indirection to linalg.generic.
if (convOp.padding())
llvm_unreachable("Unexpected conv with padding");
}
edsc::ScopedContext scope(builder, op->getLoc());
if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
// Vectorize fill as a vector.broadcast.
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
"]: Rewrite linalg.fill as vector.broadcast: "
<< *op << ":\n");
Value memref = vector_type_cast(fillOp.getOutputBuffer(0));
Value dst = std_load(memref);
Value res = vector_broadcast(dst.getType(), fillOp.value());
std_store(res, memref);
return;
}
// Vectorize other ops as vector contraction (currently only matmul).
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
"]: Rewrite linalg op as vector.contract: "
<< *op << ":\n");
auto linalgOp = cast<linalg::LinalgOp>(op);
Value a = std_load(vector_type_cast(linalgOp.getInput(0)));
Value b = std_load(vector_type_cast(linalgOp.getInput(1)));
Value memref = vector_type_cast(linalgOp.getOutputBuffer(0));
Value c = std_load(memref);
Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
linalgOp.iterator_types());
std_store(res, memref);
}

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s -test-linalg-transform-patterns | FileCheck %s // RUN: mlir-opt %s -test-linalg-transform-patterns=test-patterns | FileCheck %s
// CHECK-DAG: #[[STRIDED_1D:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> // CHECK-DAG: #[[STRIDED_1D:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
// Map corresponding to a 2D memory access where the stride along the last dim is known to be 1. // Map corresponding to a 2D memory access where the stride along the last dim is known to be 1.
@ -22,10 +22,8 @@ func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
// CHECK-LABEL: func @dot // CHECK-LABEL: func @dot
// CHECK-DAG: %[[c0:.*]] = constant 0 : index // CHECK-DAG: %[[c0:.*]] = constant 0 : index
// CHECK-DAG: %[[c1:.*]] = constant 1 : index // CHECK-DAG: %[[c1:.*]] = constant 1 : index
// CHECK-DAG: %[[c8:.*]] = constant 8 : index
// CHECK-DAG: %[[c8000:.*]] = constant 8000 : index // CHECK-DAG: %[[c8000:.*]] = constant 8000 : index
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8000]] { // CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8000]] {
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8]] {
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c1]] { // CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c1]] {
// CHECK: load // CHECK: load
// CHECK: load // CHECK: load
@ -46,8 +44,7 @@ func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// CHECK-DAG: %[[c0:.*]] = constant 0 : index // CHECK-DAG: %[[c0:.*]] = constant 0 : index
// CHECK-DAG: %[[c5:.*]] = constant 5 : index // CHECK-DAG: %[[c5:.*]] = constant 5 : index
// CHECK-DAG: %[[c6:.*]] = constant 6 : index // CHECK-DAG: %[[c6:.*]] = constant 6 : index
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c5]] // CHECK: loop.parallel {{.*}} step (%[[c5]], %[[c6]])
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c6]]
// CHECK: linalg.matvec({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?xf32, #[[STRIDED_1D]]>, memref<?xf32, #[[STRIDED_1D]]> // CHECK: linalg.matvec({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?xf32, #[[STRIDED_1D]]>, memref<?xf32, #[[STRIDED_1D]]>
func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>, func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
@ -86,88 +83,6 @@ func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4]] { // CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4]] {
// CHECK: linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]> // CHECK: linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
#some_generic_trait = {
args_in = 1,
args_out = 1,
indexing_maps = [
affine_map<(i, j) -> (i, j)>,
affine_map<(i, j) -> (i, j)>
],
iterator_types = ["parallel", "parallel"]
}
func @fusion_test(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%C: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%D: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%E: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
// This should not be fused as it would violate dependencies. It will get
// tiled for all levels of the memory hierarchy.
linalg.matmul(%A, %A, %C) : memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?x?xf32, offset: ?, strides: [?, 1]>
// This should be fused.
linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?x?xf32, offset: ?, strides: [?, 1]>
// This should not be fused or transformed at all since there are no patterns
// on it. However it will be reordered because there are no dependencies.
linalg.generic #some_generic_trait %A, %D {
^bb(%a: f32, %b: f32) :
linalg.yield %a : f32
} : memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?x?xf32, offset: ?, strides: [?, 1]>
linalg.matmul(%C, %D, %E) : memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?x?xf32, offset: ?, strides: [?, 1]>
return
}
// CHECK-LABEL: func @fusion_test
// CHECK-DAG: %[[c0:.*]] = constant 0 : index
// CHECK-DAG: %[[c2:.*]] = constant 2 : index
// CHECK-DAG: %[[c3:.*]] = constant 3 : index
// CHECK-DAG: %[[c4:.*]] = constant 4 : index
// CHECK-DAG: %[[c20:.*]] = constant 20 : index
// CHECK-DAG: %[[c30:.*]] = constant 30 : index
// CHECK-DAG: %[[c40:.*]] = constant 40 : index
// CHECK-DAG: %[[c100:.*]] = constant 100 : index
// CHECK-DAG: %[[c150:.*]] = constant 150 : index
// CHECK-DAG: %[[c200:.*]] = constant 200 : index
// CHECK-DAG: %[[c300:.*]] = constant 300 : index
// CHECK-DAG: %[[c400:.*]] = constant 400 : index
// CHECK-DAG: %[[c2000:.*]] = constant 2000 : index
// CHECK-DAG: %[[c3000:.*]] = constant 3000 : index
// CHECK-DAG: %[[c4000:.*]] = constant 4000 : index
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] {
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] {
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] {
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c200]] {
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c300]] {
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c400]] {
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c20]] {
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] {
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] {
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c2]] {
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c3]] {
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4]] {
// CHECK: linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
//
// CHECK: linalg.generic
//
// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c100]] {
// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c150]] {
// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c2]] {
// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c3]] {
// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c4]] {
// CHECK: linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c2]] {
// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c3]] {
// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c4]] {
// CHECK: linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
#matmul_trait = { #matmul_trait = {
args_in = 2, args_in = 2,
args_out = 1, args_out = 1,
@ -280,23 +195,6 @@ func @permute_generic_indexed(
// CHECK-SAME: memref<?x?xf32, #[[STRIDED_2D_u_1]]>, // CHECK-SAME: memref<?x?xf32, #[[STRIDED_2D_u_1]]>,
// CHECK-SAME: memref<?x?xf32, #[[STRIDED_2D_u_1]]> // CHECK-SAME: memref<?x?xf32, #[[STRIDED_2D_u_1]]>
func @dot_perm(%x: memref<?xf32, offset: ?, strides: [1]>,
%y: memref<?xf32, offset: ?, strides: [1]>,
%v: memref<f32>) {
linalg.dot(%x, %y, %v) {__internal_linalg_transform__ = "__with_perm__"} :
memref<?xf32, offset: ?, strides: [1]>,
memref<?xf32, offset: ?, strides: [1]>,
memref<f32>
return
}
// CHECK-LABEL: func @dot_perm
// CHECK-DAG: %[[c0:.*]] = constant 0 : index
// CHECK-DAG: %[[c8:.*]] = constant 8 : index
// CHECK-DAG: %[[c8000:.*]] = constant 8000 : index
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8000]] {
// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8]] {
// CHECK: linalg.dot({{.*}}, {{.*}}, {{.*}}) : memref<?xf32, #[[STRIDED_1D]]>, memref<?xf32, #[[STRIDED_1D]]>, memref<f32>
func @matvec_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>, func @matvec_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%x: memref<?xf32, offset: ?, strides: [1]>, %x: memref<?xf32, offset: ?, strides: [1]>,
%y: memref<?xf32, offset: ?, strides: [1]>) { %y: memref<?xf32, offset: ?, strides: [1]>) {

View File

@ -1,9 +1,3 @@
set(LLVM_TARGET_DEFINITIONS TestLinalgTransformPatterns.td)
mlir_tablegen(TestLinalgTransformPatterns.h.inc -gen-rewriters)
add_public_tablegen_target(MLIRTestLinalgTransformPatternsIncGen)
# Including Linalg in TableGen requires to depends on generated files
add_dependencies(MLIRTestLinalgTransformPatternsIncGen LinalgOdsGen)
set(LLVM_TARGET_DEFINITIONS TestVectorTransformPatterns.td) set(LLVM_TARGET_DEFINITIONS TestVectorTransformPatterns.td)
mlir_tablegen(TestVectorTransformPatterns.h.inc -gen-rewriters) mlir_tablegen(TestVectorTransformPatterns.h.inc -gen-rewriters)
add_public_tablegen_target(MLIRTestVectorTransformPatternsIncGen) add_public_tablegen_target(MLIRTestVectorTransformPatternsIncGen)

View File

@ -1,168 +0,0 @@
//===- TestLinalgTransformPatterns.td - Test patterns --*- tablegen ----*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the pattern definition file for declarative Linalg transformations
// tests.
//
//===----------------------------------------------------------------------===//
#ifndef TEST_LINALG_TRANSFORMS_PATTERNS
#define TEST_LINALG_TRANSFORMS_PATTERNS
include "mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td"
//===----------------------------------------------------------------------===//
// Test Linalg fusion patterns.
//===----------------------------------------------------------------------===//
def : Pat<(MatmulOp:$op $A, $_, $_),
(TileAndFuseLinalgOp<[100, 150], [0], "L1">),
[
(Constraint<HasNoLinalgTransformMarker>),
(Constraint<IsProducedByOpOfType<"MatmulOp">> $A),
],
// In the buffer world there is no use-def chains or dags so benefits
// cannot be computed automatically from the length of the matched
// pattern. Instead we specify the benefit ourselves for now.
// This is not expected to be a big challenge long-term because
// pattern benefits are akin to feature engineering: features should
// be learned.
(addBenefit 1)>;
//===----------------------------------------------------------------------===//
// Linalg tiling patterns.
//===----------------------------------------------------------------------===//
def : Pat<(MatmulOp:$op $_, $_, $_),
(TileLinalgOp<[2000, 3000, 4000], "L3">),
[(Constraint<Or<[HasNoLinalgTransformMarker,
HasLinalgTransformMarker<"MEM">]>>)]>;
def : Pat<(MatmulOp:$op $_, $_, $_),
(TileLinalgOp<[200, 300, 400], "L2">),
[(Constraint<HasLinalgTransformMarker<"L3">>)]>;
def : Pat<(MatmulOp:$op $_, $_, $_),
(TileLinalgOp<[20, 30, 40], "L1">),
[(Constraint<HasLinalgTransformMarker<"L2">>)]>;
def : Pat<(MatmulOp:$op $_, $_, $_),
(TileLinalgOp<[2, 3, 4], "REG">),
[(Constraint<HasLinalgTransformMarker<"L1">>)]>;
def : Pattern<(MatvecOp:$op $_, $_, $_),
[(TileLinalgOp<[5, 6], "L1">)],
[(Constraint<HasNoLinalgTransformMarker>)]>;
def : Pattern<(DotOp:$op $_, $_, $_),
[(TileLinalgOp<[8000], "L1">)],
[(Constraint<Or<[HasNoLinalgTransformMarker,
HasLinalgTransformMarker<"MEM">,
HasLinalgTransformMarker<"L3">,
HasLinalgTransformMarker<"L2">]>>)]>;
def : Pattern<(DotOp:$op $_, $_, $_),
[(TileLinalgOp<[8], "REG">)],
[(Constraint<HasLinalgTransformMarker<"L1">>)]>;
//===----------------------------------------------------------------------===//
// Linalg tiling and permutation patterns.
//===----------------------------------------------------------------------===//
def : Pat<(MatmulOp:$op $_, $_, $_),
(TileLinalgOp<[2000, 3000, 4000], "L2__with_perm__", [1,2,0]>),
[(Constraint<HasLinalgTransformMarker<"__with_perm__">>)]>;
def : Pat<(MatmulOp:$op $_, $_, $_),
(TileLinalgOp<[200, 300, 400], "L1__with_perm__", [1,0,2]>),
[(Constraint<HasLinalgTransformMarker<"L2__with_perm__">>)]>;
def : Pat<(MatmulOp:$op $_, $_, $_),
(TileLinalgOp<[20, 30, 40], "REG__with_perm__">),
[(Constraint<HasLinalgTransformMarker<"L1__with_perm__">>)]>;
def : Pattern<(MatvecOp:$op $_, $_, $_),
[(TileLinalgOp<[5, 6], "L1__with_perm__", [1,0]>)],
[(Constraint<HasLinalgTransformMarker<"__with_perm__">>)]>;
def : Pattern<(DotOp:$op $_, $_, $_),
[(TileLinalgOp<[8000], "L1__with_perm__">)],
[(Constraint<HasLinalgTransformMarker<"__with_perm__">>)]>;
def : Pattern<(DotOp:$op $_, $_, $_),
[(TileLinalgOp<[8], "REG__with_perm__">)],
[(Constraint<HasLinalgTransformMarker<"L1__with_perm__">>)]>;
//===----------------------------------------------------------------------===//
// Linalg to loops patterns.
//===----------------------------------------------------------------------===//
def : Pattern<(DotOp:$op $_, $_, $_),
[(LinalgOpToLoops<"DotOp">)],
[(Constraint<HasLinalgTransformMarker<"REG">>)]>;
//===----------------------------------------------------------------------===//
// Linalg to vector contraction patterns.
//===----------------------------------------------------------------------===//
def : Pattern<(MatmulOp:$op $_, $_, $_),
[(VectorizeLinalgOp)],
[(Constraint<And<[
HasLinalgTransformMarker<"VECTORIZE">,
PreconditionVectorizeLinalgOp
]>>)]>;
def : Pattern<(FillOp:$op $_, $_),
[(VectorizeLinalgOp)],
[(Constraint<And<[
HasLinalgTransformMarker<"VECTORIZE">,
PreconditionVectorizeLinalgOp
]>>)]>;
def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_),
[(VectorizeLinalgOp)],
[(Constraint<And<[
HasLinalgTransformMarker<"VECTORIZE">,
PreconditionVectorizeLinalgOp
]>>)]>;
//===----------------------------------------------------------------------===//
// Linalg generic permutation patterns.
//===----------------------------------------------------------------------===//
def : Pat<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_),
(PermuteGenericLinalgOp<[1, 2, 0], "PERMUTE"> $op),
[(Constraint<And<[
HasNoLinalgTransformMarker,
AffineMapDomainHasDim<3>,
PreconditionPermuteGenericLinalgOp<[1, 2, 0]>
]>>)]>;
def : Pat<(IndexedGenericOp:$op $_, $_, $_, $_, $_, $_, $_),
(PermuteGenericLinalgOp<[1, 2, 0], "PERMUTE"> $op),
[(Constraint<And<[
HasNoLinalgTransformMarker,
AffineMapDomainHasDim<3>,
PreconditionPermuteGenericLinalgOp<[1, 2, 0]>
]>>)]>;
//===----------------------------------------------------------------------===//
// Linalg subview operands promotion.
//===----------------------------------------------------------------------===//
def : Pat<(MatmulOp:$op $_, $_, $_),
(PromoteSubviewsLinalgOp),
[(Constraint<And<[
PreconditionPromoteSubviewsLinalgOp,
HasOperandsOfType<"SubViewOp">,
HasLinalgTransformMarker<"_promote_views_">]>>
)]>;
def : Pat<(MatmulOp:$op $_, $_, $_),
(PromoteSelectedSubviewsLinalgOp<[0], "first_view_promotion">),
[(Constraint<And<[
PreconditionPromoteSubviewsLinalgOp,
HasOperandsOfType<"SubViewOp">,
HasLinalgTransformMarker<"_promote_first_view_">]>>
)]>;
def : Pat<(FillOp:$op $_, $_),
(PromoteSelectedSubviewsLinalgOp<[0], "aligned_promotion", 32>),
[(Constraint<And<[
PreconditionPromoteSubviewsLinalgOp,
HasOperandsOfType<"SubViewOp">,
HasLinalgTransformMarker<"_promote_views_aligned_">]>>
)]>;
#endif // TEST_LINALG_TRANSFORMS_PATTERNS

View File

@ -25,7 +25,6 @@ add_llvm_library(MLIRTestTransforms
DEPENDS DEPENDS
MLIRStandardOpsIncGen MLIRStandardOpsIncGen
MLIRTestLinalgTransformPatternsIncGen
MLIRTestVectorTransformPatternsIncGen MLIRTestVectorTransformPatternsIncGen
) )

View File

@ -10,36 +10,127 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "llvm/ADT/SetVector.h"
using namespace mlir; using namespace mlir;
using namespace mlir::linalg; using namespace mlir::linalg;
namespace mlir {
namespace linalg {
namespace {
#include "TestLinalgTransformPatterns.h.inc"
} // end namespace
} // end namespace linalg
} // end namespace mlir
namespace { namespace {
struct TestLinalgTransforms struct TestLinalgTransforms
: public PassWrapper<TestLinalgTransforms, FunctionPass> { : public PassWrapper<TestLinalgTransforms, FunctionPass> {
TestLinalgTransforms() = default;
TestLinalgTransforms(const TestLinalgTransforms &pass) {}
void runOnFunction() override; void runOnFunction() override;
Option<bool> testPatterns{*this, "test-patterns",
llvm::cl::desc("Test a mixed set of patterns"),
llvm::cl::init(false)};
}; };
} // end anonymous namespace } // end anonymous namespace
/// Apply transformations specified as patterns. static void applyPatterns(FuncOp funcOp) {
void TestLinalgTransforms::runOnFunction() { MLIRContext *ctx = funcOp.getContext();
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
auto funcOp = getFunction();
// Add the generated patterns to the list. //===--------------------------------------------------------------------===//
linalg::populateWithGenerated(&getContext(), &patterns); // Linalg tiling patterns.
//===--------------------------------------------------------------------===//
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
LinalgMarker({"MEM", {}}, "L3"));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
LinalgMarker({"L3"}, "L2"));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
LinalgMarker({"L2"}, "L1"));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
LinalgMarker({"L1"}, "REG"));
patterns.insert<LinalgTilingPattern<MatvecOp>>(
ctx,
LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
LinalgTilingLoopType::ParallelLoops),
LinalgMarker({}, "L1"));
patterns.insert<LinalgTilingPattern<DotOp>>(
ctx, LinalgTilingOptions().setTileSizes(8000),
LinalgMarker({"MEM", "L3", "L2", {}}, "REG"));
//===--------------------------------------------------------------------===//
// Linalg tiling and permutation patterns.
//===--------------------------------------------------------------------===//
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx,
LinalgTilingOptions()
.setTileSizes({2000, 3000, 4000})
.setInterchange({1, 2, 0}),
LinalgMarker({"__with_perm__"}, "L2__with_perm__"));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx,
LinalgTilingOptions()
.setTileSizes({200, 300, 400})
.setInterchange({1, 0, 2}),
LinalgMarker({"L2__with_perm__"}, "L1__with_perm__"));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
LinalgMarker({"L1__with_perm__"}, "REG__with_perm__"));
patterns.insert<LinalgTilingPattern<MatvecOp>>(
ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
LinalgMarker({"__with_perm__"}, "L1__with_perm__"));
//===--------------------------------------------------------------------===//
// Linalg to loops patterns.
//===--------------------------------------------------------------------===//
patterns.insert<LinalgLoweringPattern<DotOp>>(
ctx,
/*loweringType=*/LinalgLoweringType::Loops, LinalgMarker({"REG"}));
//===--------------------------------------------------------------------===//
// Linalg to vector contraction patterns.
//===--------------------------------------------------------------------===//
patterns.insert<LinalgVectorizationPattern<MatmulOp>,
LinalgVectorizationPattern<FillOp>,
LinalgVectorizationPattern<GenericOp>>(
ctx, LinalgMarker({"VECTORIZE"}));
//===--------------------------------------------------------------------===//
// Linalg generic permutation patterns.
//===--------------------------------------------------------------------===//
patterns.insert<LinalgInterchangePattern<GenericOp>>(
ctx,
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
LinalgMarker({}, "PERMUTED"));
patterns.insert<LinalgInterchangePattern<IndexedGenericOp>>(
ctx,
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
LinalgMarker({}, "PERMUTED"));
//===--------------------------------------------------------------------===//
// Linalg subview operands promotion.
//===--------------------------------------------------------------------===//
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
ctx, LinalgMarker({"_promote_views_"}, "_views_promoted_"));
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
ctx,
/*operandsToPromote=*/ArrayRef<unsigned>{0},
LinalgMarker({"_promote_first_view_"}, "_first_view_promoted_"));
patterns.insert<LinalgPromotionPattern<FillOp>>(
ctx,
/*operandsToPromote=*/ArrayRef<unsigned>{0},
/*alignment=*/32,
LinalgMarker({"_promote_views_aligned_"}, "_views_aligned_promoted_"));
applyPatternsAndFoldGreedily(funcOp, patterns); applyPatternsAndFoldGreedily(funcOp, patterns);
// Drop the marker. // Drop the marker.
@ -48,9 +139,15 @@ void TestLinalgTransforms::runOnFunction() {
}); });
} }
/// Apply transformations specified as patterns.
void TestLinalgTransforms::runOnFunction() {
if (testPatterns)
return applyPatterns(getFunction());
}
namespace mlir { namespace mlir {
void registerTestLinalgTransforms() { void registerTestLinalgTransforms() {
PassRegistration<TestLinalgTransforms>( PassRegistration<TestLinalgTransforms> testTransformPatternsPass(
"test-linalg-transform-patterns", "test-linalg-transform-patterns",
"Test Linalg transformation patterns by applying them greedily."); "Test Linalg transformation patterns by applying them greedily.");
} }