forked from OSchip/llvm-project
[mlir][Linalg] NFC - Cleanup Linalg Declarative Transformations
Summary: This is part of an ongoing cleanup and uniformization work. This diff performs 3 types of cleanups: 1. Uniformize transformation names. 2. Replace all pattern operands that need not be captured by `$_` 3. Replace all usage of pattern captured op by the normalized `op` name (instead of positional parameters such as `$0`) Reviewers: ftynse Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D72081
This commit is contained in:
parent
87a004d0f8
commit
a9d9aadcdf
|
@ -18,24 +18,24 @@ include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td"
|
|||
include "mlir/Dialect/AffineOps/AffineOps.td"
|
||||
|
||||
def HasNoLinalgTransformMarker : CPred<[{
|
||||
!$0.getAttrOfType<StringAttr>(LinalgTransforms::kLinalgTransformMarker)
|
||||
!op.getAttrOfType<StringAttr>(LinalgTransforms::kLinalgTransformMarker)
|
||||
}]>;
|
||||
|
||||
class HasLinalgTransformMarker<string str> : CPred<[{
|
||||
$0.getAttrOfType<StringAttr>(
|
||||
op.getAttrOfType<StringAttr>(
|
||||
LinalgTransforms::kLinalgTransformMarker) &&
|
||||
$0.getAttrOfType<StringAttr>(
|
||||
op.getAttrOfType<StringAttr>(
|
||||
LinalgTransforms::kLinalgTransformMarker).getValue() == "}] # str # [{"}]>;
|
||||
|
||||
class IsProducedByOpOfType<string str> :
|
||||
CPred<"isProducedByOpOfType<" # str # ">($0, $1)">;
|
||||
CPred<"isProducedByOpOfType<" # str # ">(op, $0)">;
|
||||
|
||||
class AffineMapDomainHasDim<int n> : CPred<[{
|
||||
$0.getAttrOfType<ArrayAttr>(getIndexingMapsAttrName()).getValue()[0].
|
||||
op.getAttrOfType<ArrayAttr>(getIndexingMapsAttrName()).getValue()[0].
|
||||
cast<AffineMapAttr>().getValue().getNumDims() ==}] # n # [{}]>;
|
||||
|
||||
class HasOperandsOfType<string type>: CPred<[{
|
||||
llvm::any_of($0.getOperands(),
|
||||
llvm::any_of(op.getOperands(),
|
||||
[](Value v) {
|
||||
return dyn_cast_or_null<}] # type # [{>(v->getDefiningOp());
|
||||
})
|
||||
|
@ -50,7 +50,7 @@ class HasOperandsOfType<string type>: CPred<[{
|
|||
// patterns.
|
||||
class TileAndFuseLinalgOp<
|
||||
list<int> sizes, list<int> operandIndices, string value> : NativeCodeCall<
|
||||
"if (failed(tileAndFuseLinalgOpAndSetMarker($_builder, $0, {" #
|
||||
"if (failed(tileAndFuseLinalgOpAndSetMarker($_builder, op, {" #
|
||||
StrJoinInt<sizes>.result # "}, {" # StrJoinInt<operandIndices>.result # "}," #
|
||||
" \"" # value # "\")))" #
|
||||
" return matchFailure();">;
|
||||
|
@ -67,7 +67,7 @@ class TileAndFuseLinalgOp<
|
|||
// of elements as `sizes`.
|
||||
class TileLinalgOp<list<int> sizes, string value, list<int> permutation=[]> :
|
||||
NativeCodeCall<
|
||||
"if (failed(tileLinalgOpAndSetMarker($_builder, $0, {" #
|
||||
"if (failed(tileLinalgOpAndSetMarker($_builder, op, {" #
|
||||
StrJoinInt<sizes>.result # "}, \"" # value # "\", {" #
|
||||
StrJoinInt<permutation>.result # "})))" #
|
||||
" return matchFailure();">;
|
||||
|
@ -76,18 +76,18 @@ class TileLinalgOp<list<int> sizes, string value, list<int> permutation=[]> :
|
|||
// Linalg to loop patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
class LinalgOpToLoops<string OpType> : NativeCodeCall<
|
||||
"if (failed(linalgOpToLoops<" # OpType # ">($_builder, $0))) " #
|
||||
"if (failed(linalgOpToLoops<" # OpType # ">($_builder, op))) " #
|
||||
" return matchFailure();">;
|
||||
|
||||
class LinalgOpToAffineLoops<string OpType> : NativeCodeCall<
|
||||
"if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, $0))) " #
|
||||
"if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, op))) " #
|
||||
" return matchFailure();">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg to vector contraction patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
class LinalgOpToVectorContraction<string OpType> : NativeCodeCall<
|
||||
"if (failed(vectorizeGenericOp($_builder, $0))) " #
|
||||
class VectorizeGenericLinalgOp<string OpType> : NativeCodeCall<
|
||||
"if (failed(vectorizeGenericLinalgOp($_builder, op))) " #
|
||||
" return matchFailure();">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -95,14 +95,14 @@ class LinalgOpToVectorContraction<string OpType> : NativeCodeCall<
|
|||
//===----------------------------------------------------------------------===//
|
||||
class PermuteGenericLinalgOp<list<int> permutation, string value> :
|
||||
NativeCodeCall<
|
||||
"if (failed(permuteGenericLinalgOp($_builder, $0, {" #
|
||||
"if (failed(permuteGenericLinalgOp($_builder, op, {" #
|
||||
StrJoinInt<permutation>.result # "}, \"" # value # "\"))) " #
|
||||
" return matchFailure();">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg promote subview operands.
|
||||
//===----------------------------------------------------------------------===//
|
||||
class LinalgOpPromoteSubviews<string OpType> : NativeCodeCall<
|
||||
"if (failed(linalgOpPromoteSubviews($_builder, $0))) " #
|
||||
class PromoteSubviewsLinalgOp<string OpType> : NativeCodeCall<
|
||||
"if (failed(promoteSubviewsLinalgOp($_builder, op))) " #
|
||||
" return matchFailure();">;
|
||||
#endif // LINALG_TRANSFORMS
|
||||
|
|
|
@ -79,7 +79,8 @@ template <typename ConcreteOp>
|
|||
LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, Operation *op);
|
||||
|
||||
/// Rewrite a linalg.generic into a suitable vector.contraction op.
|
||||
LogicalResult vectorizeGenericOp(PatternRewriter &rewriter, Operation *op);
|
||||
LogicalResult vectorizeGenericLinalgOp(PatternRewriter &rewriter,
|
||||
Operation *op);
|
||||
|
||||
/// Emits a `generic` or `indexed_generic` operation with the `indexing_maps`
|
||||
/// and `iterator_types` permutated according to `permutation`.
|
||||
|
@ -88,7 +89,7 @@ LogicalResult permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
|
|||
StringRef linalgMarker);
|
||||
|
||||
/// Promote std.subviews feeding linalg operations
|
||||
LogicalResult linalgOpPromoteSubviews(PatternRewriter &rewriter, Operation *op);
|
||||
LogicalResult promoteSubviewsLinalgOp(PatternRewriter &rewriter, Operation *op);
|
||||
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
|
|
@ -153,8 +153,8 @@ static bool isMatmul(linalg::GenericOp genericOp) {
|
|||
genericOp.indexing_maps() == maps && hasMultiplyAddBody(genericOp);
|
||||
}
|
||||
|
||||
LogicalResult mlir::linalg::vectorizeGenericOp(PatternRewriter &rewriter,
|
||||
Operation *op) {
|
||||
LogicalResult mlir::linalg::vectorizeGenericLinalgOp(PatternRewriter &rewriter,
|
||||
Operation *op) {
|
||||
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
|
||||
"]: Rewrite linalg op as vector.contract: "
|
||||
<< *op << ":\n");
|
||||
|
@ -223,7 +223,7 @@ mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult mlir::linalg::linalgOpPromoteSubviews(PatternRewriter &rewriter,
|
||||
LogicalResult mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter,
|
||||
Operation *op) {
|
||||
LinalgOp linOp = dyn_cast<LinalgOp>(op);
|
||||
SetVector<Value> subViews;
|
||||
|
|
|
@ -19,11 +19,11 @@ include "mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td"
|
|||
//===----------------------------------------------------------------------===//
|
||||
// Test Linalg fusion patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def : Pat<(MatmulOp:$consumer $A, $B, $C),
|
||||
(TileAndFuseLinalgOp<[100, 150], [0], "L1"> $consumer),
|
||||
def : Pat<(MatmulOp:$op $A, $_, $_),
|
||||
(TileAndFuseLinalgOp<[100, 150], [0], "L1">),
|
||||
[
|
||||
(Constraint<HasNoLinalgTransformMarker> $consumer),
|
||||
(Constraint<IsProducedByOpOfType<"MatmulOp">> $consumer, $A),
|
||||
(Constraint<HasNoLinalgTransformMarker>),
|
||||
(Constraint<IsProducedByOpOfType<"MatmulOp">> $A),
|
||||
],
|
||||
// In the buffer world there is no use-def chains or dags so benefits
|
||||
// cannot be computed automatically from the length of the matched
|
||||
|
@ -36,91 +36,91 @@ def : Pat<(MatmulOp:$consumer $A, $B, $C),
|
|||
//===----------------------------------------------------------------------===//
|
||||
// Linalg tiling patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def : Pat<(MatmulOp:$op $A, $B, $C),
|
||||
(TileLinalgOp<[2000, 3000, 4000], "L3"> $op),
|
||||
def : Pat<(MatmulOp:$op $_, $_, $_),
|
||||
(TileLinalgOp<[2000, 3000, 4000], "L3">),
|
||||
[(Constraint<Or<[HasNoLinalgTransformMarker,
|
||||
HasLinalgTransformMarker<"MEM">]>> $op)]>;
|
||||
def : Pat<(MatmulOp:$op $A, $B, $C),
|
||||
(TileLinalgOp<[200, 300, 400], "L2"> $op),
|
||||
[(Constraint<HasLinalgTransformMarker<"L3">> $op)]>;
|
||||
def : Pat<(MatmulOp:$op $A, $B, $C),
|
||||
(TileLinalgOp<[20, 30, 40], "L1"> $op),
|
||||
[(Constraint<HasLinalgTransformMarker<"L2">> $op)]>;
|
||||
def : Pat<(MatmulOp:$op $A, $B, $C),
|
||||
(TileLinalgOp<[2, 3, 4], "REG"> $op),
|
||||
[(Constraint<HasLinalgTransformMarker<"L1">> $op)]>;
|
||||
HasLinalgTransformMarker<"MEM">]>>)]>;
|
||||
def : Pat<(MatmulOp:$op $_, $_, $_),
|
||||
(TileLinalgOp<[200, 300, 400], "L2">),
|
||||
[(Constraint<HasLinalgTransformMarker<"L3">>)]>;
|
||||
def : Pat<(MatmulOp:$op $_, $_, $_),
|
||||
(TileLinalgOp<[20, 30, 40], "L1">),
|
||||
[(Constraint<HasLinalgTransformMarker<"L2">>)]>;
|
||||
def : Pat<(MatmulOp:$op $_, $_, $_),
|
||||
(TileLinalgOp<[2, 3, 4], "REG">),
|
||||
[(Constraint<HasLinalgTransformMarker<"L1">>)]>;
|
||||
|
||||
def : Pattern<(MatvecOp:$op $A, $b, $c),
|
||||
[(TileLinalgOp<[5, 6], "L1"> $op)],
|
||||
[(Constraint<HasNoLinalgTransformMarker> $op)]>;
|
||||
def : Pattern<(MatvecOp:$op $_, $_, $_),
|
||||
[(TileLinalgOp<[5, 6], "L1">)],
|
||||
[(Constraint<HasNoLinalgTransformMarker>)]>;
|
||||
|
||||
def : Pattern<(DotOp:$op $a, $b, $c),
|
||||
[(TileLinalgOp<[8000], "L1"> $op)],
|
||||
def : Pattern<(DotOp:$op $_, $_, $_),
|
||||
[(TileLinalgOp<[8000], "L1">)],
|
||||
[(Constraint<Or<[HasNoLinalgTransformMarker,
|
||||
HasLinalgTransformMarker<"MEM">,
|
||||
HasLinalgTransformMarker<"L3">,
|
||||
HasLinalgTransformMarker<"L2">]>> $op)]>;
|
||||
def : Pattern<(DotOp:$op $a, $b, $c),
|
||||
[(TileLinalgOp<[8], "REG"> $op)],
|
||||
[(Constraint<HasLinalgTransformMarker<"L1">> $op)]>;
|
||||
HasLinalgTransformMarker<"L2">]>>)]>;
|
||||
def : Pattern<(DotOp:$op $_, $_, $_),
|
||||
[(TileLinalgOp<[8], "REG">)],
|
||||
[(Constraint<HasLinalgTransformMarker<"L1">>)]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg tiling and permutation patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def : Pat<(MatmulOp:$op $A, $B, $C),
|
||||
(TileLinalgOp<[2000, 3000, 4000], "L2__with_perm__", [1,2,0]> $op),
|
||||
[(Constraint<HasLinalgTransformMarker<"__with_perm__">> $op)]>;
|
||||
def : Pat<(MatmulOp:$op $A, $B, $C),
|
||||
(TileLinalgOp<[200, 300, 400], "L1__with_perm__", [1,0,2]> $op),
|
||||
[(Constraint<HasLinalgTransformMarker<"L2__with_perm__">> $op)]>;
|
||||
def : Pat<(MatmulOp:$op $A, $B, $C),
|
||||
(TileLinalgOp<[20, 30, 40], "REG__with_perm__"> $op),
|
||||
[(Constraint<HasLinalgTransformMarker<"L1__with_perm__">> $op)]>;
|
||||
def : Pat<(MatmulOp:$op $_, $_, $_),
|
||||
(TileLinalgOp<[2000, 3000, 4000], "L2__with_perm__", [1,2,0]>),
|
||||
[(Constraint<HasLinalgTransformMarker<"__with_perm__">>)]>;
|
||||
def : Pat<(MatmulOp:$op $_, $_, $_),
|
||||
(TileLinalgOp<[200, 300, 400], "L1__with_perm__", [1,0,2]>),
|
||||
[(Constraint<HasLinalgTransformMarker<"L2__with_perm__">>)]>;
|
||||
def : Pat<(MatmulOp:$op $_, $_, $_),
|
||||
(TileLinalgOp<[20, 30, 40], "REG__with_perm__">),
|
||||
[(Constraint<HasLinalgTransformMarker<"L1__with_perm__">>)]>;
|
||||
|
||||
|
||||
def : Pattern<(MatvecOp:$op $A, $b, $c),
|
||||
[(TileLinalgOp<[5, 6], "L1__with_perm__", [1,0]> $op)],
|
||||
[(Constraint<HasLinalgTransformMarker<"__with_perm__">> $op)]>;
|
||||
def : Pattern<(MatvecOp:$op $_, $_, $_),
|
||||
[(TileLinalgOp<[5, 6], "L1__with_perm__", [1,0]>)],
|
||||
[(Constraint<HasLinalgTransformMarker<"__with_perm__">>)]>;
|
||||
|
||||
def : Pattern<(DotOp:$op $a, $b, $c),
|
||||
[(TileLinalgOp<[8000], "L1__with_perm__"> $op)],
|
||||
[(Constraint<HasLinalgTransformMarker<"__with_perm__">> $op)]>;
|
||||
def : Pattern<(DotOp:$op $a, $b, $c),
|
||||
[(TileLinalgOp<[8], "REG__with_perm__"> $op)],
|
||||
[(Constraint<HasLinalgTransformMarker<"L1__with_perm__">> $op)]>;
|
||||
def : Pattern<(DotOp:$op $_, $_, $_),
|
||||
[(TileLinalgOp<[8000], "L1__with_perm__">)],
|
||||
[(Constraint<HasLinalgTransformMarker<"__with_perm__">>)]>;
|
||||
def : Pattern<(DotOp:$op $_, $_, $_),
|
||||
[(TileLinalgOp<[8], "REG__with_perm__">)],
|
||||
[(Constraint<HasLinalgTransformMarker<"L1__with_perm__">>)]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg to loops patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def : Pattern<(DotOp:$op $a, $b, $c),
|
||||
[(LinalgOpToLoops<"DotOp"> $op)],
|
||||
[(Constraint<HasLinalgTransformMarker<"REG">> $op)]>;
|
||||
def : Pattern<(DotOp:$op $_, $_, $_),
|
||||
[(LinalgOpToLoops<"DotOp">)],
|
||||
[(Constraint<HasLinalgTransformMarker<"REG">>)]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg to vector contraction patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def : Pattern<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8),
|
||||
[(LinalgOpToVectorContraction<"GenericOp"> $op)],
|
||||
[(Constraint<HasLinalgTransformMarker<"_marked_matmul_">> $op)]>;
|
||||
def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
|
||||
[(VectorizeGenericLinalgOp<"GenericOp">)],
|
||||
[(Constraint<HasLinalgTransformMarker<"_marked_matmul_">>)]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg generic permutation patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def : Pat<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8),
|
||||
(PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op),
|
||||
def : Pat<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
|
||||
(PermuteGenericLinalgOp<[1,2,0],"PERMUTED">),
|
||||
[(Constraint<And<[HasNoLinalgTransformMarker,
|
||||
AffineMapDomainHasDim<3>]>> $op)]>;
|
||||
AffineMapDomainHasDim<3>]>>)]>;
|
||||
|
||||
def : Pat<(IndexedGenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8),
|
||||
(PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op),
|
||||
def : Pat<(IndexedGenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
|
||||
(PermuteGenericLinalgOp<[1,2,0],"PERMUTED">),
|
||||
[(Constraint<And<[HasNoLinalgTransformMarker,
|
||||
AffineMapDomainHasDim<3>]>> $op)]>;
|
||||
AffineMapDomainHasDim<3>]>>)]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg subview operands promotion.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def : Pat<(MatmulOp:$op $A, $B, $C),
|
||||
(LinalgOpPromoteSubviews<"MatmulOp"> $op),
|
||||
[(Constraint<HasOperandsOfType<"SubViewOp">> $op),
|
||||
(Constraint<HasLinalgTransformMarker<"_promote_views_">> $op)]>;
|
||||
def : Pat<(MatmulOp:$op $_, $_, $_),
|
||||
(PromoteSubviewsLinalgOp<"MatmulOp">),
|
||||
[(Constraint<HasOperandsOfType<"SubViewOp">>),
|
||||
(Constraint<HasLinalgTransformMarker<"_promote_views_">>)]>;
|
||||
#endif // TEST_LINALG_TRANSFORMS_PATTERNS
|
||||
|
|
Loading…
Reference in New Issue