[mlir][Linalg] NFC - Cleanup Linalg Declarative Transformations

Summary:
This is part of an ongoing cleanup and uniformization work.

This diff performs 3 types of cleanups:
1. Uniformize transformation names.
2. Replace all pattern operands that need not be captured by `$_`
3. Replace all usage of pattern captured op by the normalized `op` name (instead of positional parameters such as `$0`)

Reviewers: ftynse

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D72081
This commit is contained in:
Nicolas Vasilache 2020-01-02 09:54:47 -05:00
parent 87a004d0f8
commit a9d9aadcdf
4 changed files with 80 additions and 79 deletions

View File

@ -18,24 +18,24 @@ include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td"
include "mlir/Dialect/AffineOps/AffineOps.td"
def HasNoLinalgTransformMarker : CPred<[{
!$0.getAttrOfType<StringAttr>(LinalgTransforms::kLinalgTransformMarker)
!op.getAttrOfType<StringAttr>(LinalgTransforms::kLinalgTransformMarker)
}]>;
class HasLinalgTransformMarker<string str> : CPred<[{
$0.getAttrOfType<StringAttr>(
op.getAttrOfType<StringAttr>(
LinalgTransforms::kLinalgTransformMarker) &&
$0.getAttrOfType<StringAttr>(
op.getAttrOfType<StringAttr>(
LinalgTransforms::kLinalgTransformMarker).getValue() == "}] # str # [{"}]>;
class IsProducedByOpOfType<string str> :
CPred<"isProducedByOpOfType<" # str # ">($0, $1)">;
CPred<"isProducedByOpOfType<" # str # ">(op, $0)">;
class AffineMapDomainHasDim<int n> : CPred<[{
$0.getAttrOfType<ArrayAttr>(getIndexingMapsAttrName()).getValue()[0].
op.getAttrOfType<ArrayAttr>(getIndexingMapsAttrName()).getValue()[0].
cast<AffineMapAttr>().getValue().getNumDims() ==}] # n # [{}]>;
class HasOperandsOfType<string type>: CPred<[{
llvm::any_of($0.getOperands(),
llvm::any_of(op.getOperands(),
[](Value v) {
return dyn_cast_or_null<}] # type # [{>(v->getDefiningOp());
})
@ -50,7 +50,7 @@ class HasOperandsOfType<string type>: CPred<[{
// patterns.
class TileAndFuseLinalgOp<
list<int> sizes, list<int> operandIndices, string value> : NativeCodeCall<
"if (failed(tileAndFuseLinalgOpAndSetMarker($_builder, $0, {" #
"if (failed(tileAndFuseLinalgOpAndSetMarker($_builder, op, {" #
StrJoinInt<sizes>.result # "}, {" # StrJoinInt<operandIndices>.result # "}," #
" \"" # value # "\")))" #
" return matchFailure();">;
@ -67,7 +67,7 @@ class TileAndFuseLinalgOp<
// of elements as `sizes`.
class TileLinalgOp<list<int> sizes, string value, list<int> permutation=[]> :
NativeCodeCall<
"if (failed(tileLinalgOpAndSetMarker($_builder, $0, {" #
"if (failed(tileLinalgOpAndSetMarker($_builder, op, {" #
StrJoinInt<sizes>.result # "}, \"" # value # "\", {" #
StrJoinInt<permutation>.result # "})))" #
" return matchFailure();">;
@ -76,18 +76,18 @@ class TileLinalgOp<list<int> sizes, string value, list<int> permutation=[]> :
// Linalg to loop patterns.
//===----------------------------------------------------------------------===//
class LinalgOpToLoops<string OpType> : NativeCodeCall<
"if (failed(linalgOpToLoops<" # OpType # ">($_builder, $0))) " #
"if (failed(linalgOpToLoops<" # OpType # ">($_builder, op))) " #
" return matchFailure();">;
class LinalgOpToAffineLoops<string OpType> : NativeCodeCall<
"if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, $0))) " #
"if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, op))) " #
" return matchFailure();">;
//===----------------------------------------------------------------------===//
// Linalg to vector contraction patterns.
//===----------------------------------------------------------------------===//
class LinalgOpToVectorContraction<string OpType> : NativeCodeCall<
"if (failed(vectorizeGenericOp($_builder, $0))) " #
class VectorizeGenericLinalgOp<string OpType> : NativeCodeCall<
"if (failed(vectorizeGenericLinalgOp($_builder, op))) " #
" return matchFailure();">;
//===----------------------------------------------------------------------===//
@ -95,14 +95,14 @@ class LinalgOpToVectorContraction<string OpType> : NativeCodeCall<
//===----------------------------------------------------------------------===//
class PermuteGenericLinalgOp<list<int> permutation, string value> :
NativeCodeCall<
"if (failed(permuteGenericLinalgOp($_builder, $0, {" #
"if (failed(permuteGenericLinalgOp($_builder, op, {" #
StrJoinInt<permutation>.result # "}, \"" # value # "\"))) " #
" return matchFailure();">;
//===----------------------------------------------------------------------===//
// Linalg promote subview operands.
//===----------------------------------------------------------------------===//
class LinalgOpPromoteSubviews<string OpType> : NativeCodeCall<
"if (failed(linalgOpPromoteSubviews($_builder, $0))) " #
class PromoteSubviewsLinalgOp<string OpType> : NativeCodeCall<
"if (failed(promoteSubviewsLinalgOp($_builder, op))) " #
" return matchFailure();">;
#endif // LINALG_TRANSFORMS

View File

@ -79,7 +79,8 @@ template <typename ConcreteOp>
LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, Operation *op);
/// Rewrite a linalg.generic into a suitable vector.contraction op.
LogicalResult vectorizeGenericOp(PatternRewriter &rewriter, Operation *op);
LogicalResult vectorizeGenericLinalgOp(PatternRewriter &rewriter,
Operation *op);
/// Emits a `generic` or `indexed_generic` operation with the `indexing_maps`
/// and `iterator_types` permutated according to `permutation`.
@ -88,7 +89,7 @@ LogicalResult permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
StringRef linalgMarker);
/// Promote std.subviews feeding linalg operations
LogicalResult linalgOpPromoteSubviews(PatternRewriter &rewriter, Operation *op);
LogicalResult promoteSubviewsLinalgOp(PatternRewriter &rewriter, Operation *op);
} // namespace linalg
} // namespace mlir

View File

@ -153,8 +153,8 @@ static bool isMatmul(linalg::GenericOp genericOp) {
genericOp.indexing_maps() == maps && hasMultiplyAddBody(genericOp);
}
LogicalResult mlir::linalg::vectorizeGenericOp(PatternRewriter &rewriter,
Operation *op) {
LogicalResult mlir::linalg::vectorizeGenericLinalgOp(PatternRewriter &rewriter,
Operation *op) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
"]: Rewrite linalg op as vector.contract: "
<< *op << ":\n");
@ -223,7 +223,7 @@ mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
return success();
}
LogicalResult mlir::linalg::linalgOpPromoteSubviews(PatternRewriter &rewriter,
LogicalResult mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter,
Operation *op) {
LinalgOp linOp = dyn_cast<LinalgOp>(op);
SetVector<Value> subViews;

View File

@ -19,11 +19,11 @@ include "mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td"
//===----------------------------------------------------------------------===//
// Test Linalg fusion patterns.
//===----------------------------------------------------------------------===//
def : Pat<(MatmulOp:$consumer $A, $B, $C),
(TileAndFuseLinalgOp<[100, 150], [0], "L1"> $consumer),
def : Pat<(MatmulOp:$op $A, $_, $_),
(TileAndFuseLinalgOp<[100, 150], [0], "L1">),
[
(Constraint<HasNoLinalgTransformMarker> $consumer),
(Constraint<IsProducedByOpOfType<"MatmulOp">> $consumer, $A),
(Constraint<HasNoLinalgTransformMarker>),
(Constraint<IsProducedByOpOfType<"MatmulOp">> $A),
],
// In the buffer world there is no use-def chains or dags so benefits
// cannot be computed automatically from the length of the matched
@ -36,91 +36,91 @@ def : Pat<(MatmulOp:$consumer $A, $B, $C),
//===----------------------------------------------------------------------===//
// Linalg tiling patterns.
//===----------------------------------------------------------------------===//
def : Pat<(MatmulOp:$op $A, $B, $C),
(TileLinalgOp<[2000, 3000, 4000], "L3"> $op),
def : Pat<(MatmulOp:$op $_, $_, $_),
(TileLinalgOp<[2000, 3000, 4000], "L3">),
[(Constraint<Or<[HasNoLinalgTransformMarker,
HasLinalgTransformMarker<"MEM">]>> $op)]>;
def : Pat<(MatmulOp:$op $A, $B, $C),
(TileLinalgOp<[200, 300, 400], "L2"> $op),
[(Constraint<HasLinalgTransformMarker<"L3">> $op)]>;
def : Pat<(MatmulOp:$op $A, $B, $C),
(TileLinalgOp<[20, 30, 40], "L1"> $op),
[(Constraint<HasLinalgTransformMarker<"L2">> $op)]>;
def : Pat<(MatmulOp:$op $A, $B, $C),
(TileLinalgOp<[2, 3, 4], "REG"> $op),
[(Constraint<HasLinalgTransformMarker<"L1">> $op)]>;
HasLinalgTransformMarker<"MEM">]>>)]>;
def : Pat<(MatmulOp:$op $_, $_, $_),
(TileLinalgOp<[200, 300, 400], "L2">),
[(Constraint<HasLinalgTransformMarker<"L3">>)]>;
def : Pat<(MatmulOp:$op $_, $_, $_),
(TileLinalgOp<[20, 30, 40], "L1">),
[(Constraint<HasLinalgTransformMarker<"L2">>)]>;
def : Pat<(MatmulOp:$op $_, $_, $_),
(TileLinalgOp<[2, 3, 4], "REG">),
[(Constraint<HasLinalgTransformMarker<"L1">>)]>;
def : Pattern<(MatvecOp:$op $A, $b, $c),
[(TileLinalgOp<[5, 6], "L1"> $op)],
[(Constraint<HasNoLinalgTransformMarker> $op)]>;
def : Pattern<(MatvecOp:$op $_, $_, $_),
[(TileLinalgOp<[5, 6], "L1">)],
[(Constraint<HasNoLinalgTransformMarker>)]>;
def : Pattern<(DotOp:$op $a, $b, $c),
[(TileLinalgOp<[8000], "L1"> $op)],
def : Pattern<(DotOp:$op $_, $_, $_),
[(TileLinalgOp<[8000], "L1">)],
[(Constraint<Or<[HasNoLinalgTransformMarker,
HasLinalgTransformMarker<"MEM">,
HasLinalgTransformMarker<"L3">,
HasLinalgTransformMarker<"L2">]>> $op)]>;
def : Pattern<(DotOp:$op $a, $b, $c),
[(TileLinalgOp<[8], "REG"> $op)],
[(Constraint<HasLinalgTransformMarker<"L1">> $op)]>;
HasLinalgTransformMarker<"L2">]>>)]>;
def : Pattern<(DotOp:$op $_, $_, $_),
[(TileLinalgOp<[8], "REG">)],
[(Constraint<HasLinalgTransformMarker<"L1">>)]>;
//===----------------------------------------------------------------------===//
// Linalg tiling and permutation patterns.
//===----------------------------------------------------------------------===//
def : Pat<(MatmulOp:$op $A, $B, $C),
(TileLinalgOp<[2000, 3000, 4000], "L2__with_perm__", [1,2,0]> $op),
[(Constraint<HasLinalgTransformMarker<"__with_perm__">> $op)]>;
def : Pat<(MatmulOp:$op $A, $B, $C),
(TileLinalgOp<[200, 300, 400], "L1__with_perm__", [1,0,2]> $op),
[(Constraint<HasLinalgTransformMarker<"L2__with_perm__">> $op)]>;
def : Pat<(MatmulOp:$op $A, $B, $C),
(TileLinalgOp<[20, 30, 40], "REG__with_perm__"> $op),
[(Constraint<HasLinalgTransformMarker<"L1__with_perm__">> $op)]>;
def : Pat<(MatmulOp:$op $_, $_, $_),
(TileLinalgOp<[2000, 3000, 4000], "L2__with_perm__", [1,2,0]>),
[(Constraint<HasLinalgTransformMarker<"__with_perm__">>)]>;
def : Pat<(MatmulOp:$op $_, $_, $_),
(TileLinalgOp<[200, 300, 400], "L1__with_perm__", [1,0,2]>),
[(Constraint<HasLinalgTransformMarker<"L2__with_perm__">>)]>;
def : Pat<(MatmulOp:$op $_, $_, $_),
(TileLinalgOp<[20, 30, 40], "REG__with_perm__">),
[(Constraint<HasLinalgTransformMarker<"L1__with_perm__">>)]>;
def : Pattern<(MatvecOp:$op $A, $b, $c),
[(TileLinalgOp<[5, 6], "L1__with_perm__", [1,0]> $op)],
[(Constraint<HasLinalgTransformMarker<"__with_perm__">> $op)]>;
def : Pattern<(MatvecOp:$op $_, $_, $_),
[(TileLinalgOp<[5, 6], "L1__with_perm__", [1,0]>)],
[(Constraint<HasLinalgTransformMarker<"__with_perm__">>)]>;
def : Pattern<(DotOp:$op $a, $b, $c),
[(TileLinalgOp<[8000], "L1__with_perm__"> $op)],
[(Constraint<HasLinalgTransformMarker<"__with_perm__">> $op)]>;
def : Pattern<(DotOp:$op $a, $b, $c),
[(TileLinalgOp<[8], "REG__with_perm__"> $op)],
[(Constraint<HasLinalgTransformMarker<"L1__with_perm__">> $op)]>;
def : Pattern<(DotOp:$op $_, $_, $_),
[(TileLinalgOp<[8000], "L1__with_perm__">)],
[(Constraint<HasLinalgTransformMarker<"__with_perm__">>)]>;
def : Pattern<(DotOp:$op $_, $_, $_),
[(TileLinalgOp<[8], "REG__with_perm__">)],
[(Constraint<HasLinalgTransformMarker<"L1__with_perm__">>)]>;
//===----------------------------------------------------------------------===//
// Linalg to loops patterns.
//===----------------------------------------------------------------------===//
def : Pattern<(DotOp:$op $a, $b, $c),
[(LinalgOpToLoops<"DotOp"> $op)],
[(Constraint<HasLinalgTransformMarker<"REG">> $op)]>;
def : Pattern<(DotOp:$op $_, $_, $_),
[(LinalgOpToLoops<"DotOp">)],
[(Constraint<HasLinalgTransformMarker<"REG">>)]>;
//===----------------------------------------------------------------------===//
// Linalg to vector contraction patterns.
//===----------------------------------------------------------------------===//
def : Pattern<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8),
[(LinalgOpToVectorContraction<"GenericOp"> $op)],
[(Constraint<HasLinalgTransformMarker<"_marked_matmul_">> $op)]>;
def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
[(VectorizeGenericLinalgOp<"GenericOp">)],
[(Constraint<HasLinalgTransformMarker<"_marked_matmul_">>)]>;
//===----------------------------------------------------------------------===//
// Linalg generic permutation patterns.
//===----------------------------------------------------------------------===//
def : Pat<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8),
(PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op),
def : Pat<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
(PermuteGenericLinalgOp<[1,2,0],"PERMUTED">),
[(Constraint<And<[HasNoLinalgTransformMarker,
AffineMapDomainHasDim<3>]>> $op)]>;
AffineMapDomainHasDim<3>]>>)]>;
def : Pat<(IndexedGenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8),
(PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op),
def : Pat<(IndexedGenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
(PermuteGenericLinalgOp<[1,2,0],"PERMUTED">),
[(Constraint<And<[HasNoLinalgTransformMarker,
AffineMapDomainHasDim<3>]>> $op)]>;
AffineMapDomainHasDim<3>]>>)]>;
//===----------------------------------------------------------------------===//
// Linalg subview operands promotion.
//===----------------------------------------------------------------------===//
def : Pat<(MatmulOp:$op $A, $B, $C),
(LinalgOpPromoteSubviews<"MatmulOp"> $op),
[(Constraint<HasOperandsOfType<"SubViewOp">> $op),
(Constraint<HasLinalgTransformMarker<"_promote_views_">> $op)]>;
def : Pat<(MatmulOp:$op $_, $_, $_),
(PromoteSubviewsLinalgOp<"MatmulOp">),
[(Constraint<HasOperandsOfType<"SubViewOp">>),
(Constraint<HasLinalgTransformMarker<"_promote_views_">>)]>;
#endif // TEST_LINALG_TRANSFORMS_PATTERNS