From 3ae56c4135025e5186e289be446246cdc719a5b1 Mon Sep 17 00:00:00 2001 From: Jose Ignacio Gomez Date: Mon, 16 Dec 2019 10:36:06 -0800 Subject: [PATCH] [Linalg] Expose subview promotion as a declarative pattern This PR targest issue tensorflow/mlir#295. It exposes the already existing subiew promotion pass as a declarative pattern Change-Id: If901ebef9fb53fcd0b12ecc536f6b174ce320b92 Closes tensorflow/mlir#315 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/315 from tetuante:issue295 8e5f268b6d85f31015c33505329dbd7a4db97ac5 PiperOrigin-RevId: 285801463 --- .../Transforms/LinalgTransformPatterns.td | 13 +++++ .../Linalg/Transforms/LinalgTransforms.h | 4 ++ .../include/mlir/Dialect/Linalg/Utils/Utils.h | 13 +++++ .../Linalg/Transforms/LinalgTransforms.cpp | 15 ++++++ .../Dialect/Linalg/Transforms/Promotion.cpp | 18 ++++--- .../Dialect/Linalg/transform-patterns.mlir | 50 +++++++++++++++++++ .../TestLinalgTransformPatterns.td | 8 ++- 7 files changed, 113 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td index 6e3ec8895035..d92eb77107f1 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td @@ -43,6 +43,13 @@ class AffineMapDomainHasDim : CPred<[{ $0.getAttrOfType(getIndexingMapsAttrName()).getValue()[0]. cast().getValue().getNumDims() ==}] # n # [{}]>; +class HasOperandsOfType: CPred<[{ + llvm::any_of($0.getOperands(), + [](Value* v) { + return dyn_cast_or_null<}] # type # [{>(v->getDefiningOp()); + }) +}]>; + //===----------------------------------------------------------------------===// // Linalg fusion patterns. //===----------------------------------------------------------------------===// @@ -101,4 +108,10 @@ class PermuteGenericLinalgOp permutation, string value> : StrJoinInt.result # "}, \"" # value # "\"))) " # " return matchFailure();">; +//===----------------------------------------------------------------------===// +// Linalg promote subview operands. +//===----------------------------------------------------------------------===// +class LinalgOpPromoteSubviews : NativeCodeCall< + "if (failed(linalgOpPromoteSubviews($_builder, $0))) " # + " return matchFailure();">; #endif // LINALG_TRANSFORMS diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h index b103625a8a47..9682948dbd7e 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h @@ -95,6 +95,10 @@ LogicalResult vectorizeGenericOp(PatternRewriter &rewriter, Operation *op); LogicalResult permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op, ArrayRef permutation, StringRef linalgMarker); + +/// Promote std.subviews feeding linalg operations +LogicalResult linalgOpPromoteSubviews(PatternRewriter &rewriter, Operation *op); + } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 994b3c9f185c..9f1a8342252f 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -23,6 +23,8 @@ #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/EDSC/Helpers.h" +#include "llvm/ADT/SetVector.h" + namespace mlir { class AffineExpr; class AffineMap; @@ -217,6 +219,17 @@ void applyPermutationToVector(SmallVector &inVec, auxVec[i] = inVec[permutation[i]]; inVec = auxVec; } + +/// Prepares the SubView promotion later performed by `promoteSubViews` +/// (where most of the transformation happens). It arranges the new +/// operands for `LinalgOp op` and deallocates the new buffer(s) +/// It is the entry point for declarative transformation +/// Returns the cloned `LinalgOp` with the new operands +LinalgOp promoteSubViewOperands(OpBuilder &b, LinalgOp op, + llvm::SetVector subViews, + bool dynamicBuffers = false, + OperationFolder *folder = nullptr); + } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp index 605122326416..74000212373d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -43,6 +43,7 @@ using namespace mlir::linalg; using namespace mlir::linalg::intrinsics; using llvm::dbgs; +using llvm::SetVector; // Marker used as attribute name in generated Linalg rewriting transformations. const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = @@ -230,3 +231,17 @@ mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op, linOp.clone(rewriter, linOp.getLoc(), op->getOperands()); return success(); } + +LogicalResult mlir::linalg::linalgOpPromoteSubviews(PatternRewriter &rewriter, + Operation *op) { + LinalgOp linOp = dyn_cast(op); + SetVector subViews; + for (auto it : linOp.getInputsAndOutputs()) + if (auto sv = dyn_cast_or_null(it->getDefiningOp())) + subViews.insert(sv); + if (!subViews.empty()) { + auto resOp = promoteSubViewOperands(rewriter, linOp, subViews); + return success(resOp); + } + return failure(); +} diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index 32b70346b975..c7fbebce3830 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -160,11 +160,11 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc, return res; } -static void promoteSubViewOperands(LinalgOp op, SetVector subViews, - bool dynamicBuffers, - OperationFolder *folder) { +LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op, + SetVector subViews, + bool dynamicBuffers, + OperationFolder *folder) { // 1. Promote the specified views and use them in the new op. - OpBuilder b(op); ScopedContext scope(b, op.getLoc()); auto promotedBufferAndViews = promoteSubViews( b, op.getLoc(), subViews.getArrayRef(), dynamicBuffers, folder); @@ -189,11 +189,12 @@ static void promoteSubViewOperands(LinalgOp op, SetVector subViews, // extra scalars etc. auto operands = getAssumedNonViewOperands(op); opViews.append(operands.begin(), operands.end()); - op.clone(b, op.getLoc(), opViews); + LinalgOp res = op.clone(b, op.getLoc(), opViews); // 3. Emit write-back for the promoted output views: copy the partial view. for (auto viewAndPartialLocalView : writebackViews) { - // Note: use the old op to determine whether the operand view is an output. + // WARNING: MUST use the old op to determine whether the operand view is an + // output. bool isOutput = op.getIndexOfOutput(viewAndPartialLocalView.first).hasValue(); if (isOutput) @@ -203,6 +204,8 @@ static void promoteSubViewOperands(LinalgOp op, SetVector subViews, // 4. Dealloc local buffers. for (const auto &pi : promotedBufferAndViews) dealloc(pi.buffer); + + return res; } static void promoteSubViews(FuncOp f, bool dynamicBuffers) { @@ -212,11 +215,12 @@ static void promoteSubViews(FuncOp f, bool dynamicBuffers) { // TODO(ntv) some heuristic here to decide what to promote. Atm it is all or // nothing. SetVector subViews; + OpBuilder b(op); for (auto it : op.getInputsAndOutputs()) if (auto sv = dyn_cast_or_null(it->getDefiningOp())) subViews.insert(sv); if (!subViews.empty()) { - promoteSubViewOperands(op, subViews, dynamicBuffers, &folder); + promoteSubViewOperands(b, op, subViews, dynamicBuffers, &folder); toErase.push_back(op); } }); diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir index 4a9d8bc95649..8a08bf850fff 100644 --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -315,3 +315,53 @@ func @matmul_perm(%A: memref, // CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] { // CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] { // CHECK : linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref, memref, memref + +func @promote_subview_matmul(%arg0: memref, + %arg1: memref, + %arg2: memref) { + %c2000 = constant 2000 : index + %c3000 = constant 3000 : index + %c4000 = constant 4000 : index + %c0 = constant 0 : index + %c1 = constant 1 : index + %0 = dim %arg0, 0 : memref + %1 = dim %arg0, 1 : memref + %2 = dim %arg1, 1 : memref + loop.for %arg3 = %c0 to %0 step %c2000 { + loop.for %arg4 = %c0 to %2 step %c3000 { + loop.for %arg5 = %c0 to %1 step %c4000 { + %3 = std.subview %arg0[%arg3, %arg5][%c2000, %c4000][%c1, %c1] : + memref to memref + %4 = std.subview %arg1[%arg5, %arg4][%c4000, %c3000][%c1, %c1] : + memref to memref + %5 = std.subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] : + memref to memref + linalg.matmul(%3, %4, %5) {__internal_linalg_transform__ = "_promote_views_"} : + memref, + memref, + memref + } + } + } + return +} +// CHECK-LABEL: func @promote_subview_matmul +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] { +// CHECK : %[[s0:.*]] = std.subview {{%.*}}[{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}] : memref to memref +// CHECK : %[[s1:.*]] = std.subview {{%.*}}[{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}] : memref to memref +// CHECK : %[[s2:.*]] = std.subview {{%.*}}[{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}] : memref to memref +// CHECK : %[[a0:.*]] = alloc({{%.*}}) : memref +// CHECK : %[[v0:.*]] = std.view %[[a0]][][{{%.*}}, {{%.*}}]: memref to memref +// CHECK : %[[l0:.*]] = linalg.slice %[[v0]][{{%.*}}, {{%.*}}] : memref, !linalg.range, !linalg.range, memref +// CHECK : %[[a1:.*]] = alloc({{%.*}}) : memref +// CHECK : %[[v1:.*]] = std.view %[[a1]][][{{%.*}}, {{%.*}}]: memref to memref +// CHECK : %[[l1:.*]] = linalg.slice %[[v1]][{{%.*}}, {{%.*}}] : memref, !linalg.range, !linalg.range, memref +// CHECK : %[[a2:.*]] = alloc({{%.*}}) : memref +// CHECK : %[[v2:.*]] = std.view %[[a2]][][{{%.*}}, {{%.*}}]: memref to memref +// CHECK : %[[l2:.*]] = linalg.slice %[[v2]][{{%.*}}, {{%.*}}] : memref, !linalg.range, !linalg.range, memref +// CHECK : linalg.copy(%[[s0]], %[[l0]]) : memref, memref +// CHECK : linalg.copy(%[[s1]], %[[l1]]) : memref, memref +// CHECK : linalg.copy(%[[s2]], %[[l2]]) : memref, memref +// CHECK : linalg.matmul(%[[v0]], %[[v1]], %[[v2]]) : memref, memref, memref diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td index 4d8c9282f2d9..d23139273988 100644 --- a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td +++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td @@ -115,7 +115,6 @@ def : Pattern<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8), //===----------------------------------------------------------------------===// // Linalg generic permutation patterns. //===----------------------------------------------------------------------===// - def : Pat<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8), (PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op), [(Constraint]>> $op)]>; +//===----------------------------------------------------------------------===// +// Linalg subview operands promotion. +//===----------------------------------------------------------------------===// +def : Pat<(MatmulOp:$op $A, $B, $C), + (LinalgOpPromoteSubviews<"MatmulOp"> $op), + [(Constraint> $op), + (Constraint> $op)]>; #endif // TEST_LINALG_TRANSFORMS_PATTERNS