[Linalg] Expose subview promotion as a declarative pattern

This PR targest issue tensorflow/mlir#295. It exposes the already existing
subiew promotion pass as a declarative pattern

Change-Id: If901ebef9fb53fcd0b12ecc536f6b174ce320b92

Closes tensorflow/mlir#315

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/315 from tetuante:issue295 8e5f268b6d85f31015c33505329dbd7a4db97ac5
PiperOrigin-RevId: 285801463
This commit is contained in:
Jose Ignacio Gomez 2019-12-16 10:36:06 -08:00 committed by A. Unique TensorFlower
parent c290e993b2
commit 3ae56c4135
7 changed files with 113 additions and 8 deletions

View File

@ -43,6 +43,13 @@ class AffineMapDomainHasDim<int n> : CPred<[{
$0.getAttrOfType<ArrayAttr>(getIndexingMapsAttrName()).getValue()[0]. $0.getAttrOfType<ArrayAttr>(getIndexingMapsAttrName()).getValue()[0].
cast<AffineMapAttr>().getValue().getNumDims() ==}] # n # [{}]>; cast<AffineMapAttr>().getValue().getNumDims() ==}] # n # [{}]>;
class HasOperandsOfType<string type>: CPred<[{
llvm::any_of($0.getOperands(),
[](Value* v) {
return dyn_cast_or_null<}] # type # [{>(v->getDefiningOp());
})
}]>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Linalg fusion patterns. // Linalg fusion patterns.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -101,4 +108,10 @@ class PermuteGenericLinalgOp<list<int> permutation, string value> :
StrJoinInt<permutation>.result # "}, \"" # value # "\"))) " # StrJoinInt<permutation>.result # "}, \"" # value # "\"))) " #
" return matchFailure();">; " return matchFailure();">;
//===----------------------------------------------------------------------===//
// Linalg promote subview operands.
//===----------------------------------------------------------------------===//
class LinalgOpPromoteSubviews<string OpType> : NativeCodeCall<
"if (failed(linalgOpPromoteSubviews($_builder, $0))) " #
" return matchFailure();">;
#endif // LINALG_TRANSFORMS #endif // LINALG_TRANSFORMS

View File

@ -95,6 +95,10 @@ LogicalResult vectorizeGenericOp(PatternRewriter &rewriter, Operation *op);
LogicalResult permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op, LogicalResult permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
ArrayRef<unsigned> permutation, ArrayRef<unsigned> permutation,
StringRef linalgMarker); StringRef linalgMarker);
/// Promote std.subviews feeding linalg operations
LogicalResult linalgOpPromoteSubviews(PatternRewriter &rewriter, Operation *op);
} // namespace linalg } // namespace linalg
} // namespace mlir } // namespace mlir

View File

@ -23,6 +23,8 @@
#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/EDSC/Helpers.h" #include "mlir/EDSC/Helpers.h"
#include "llvm/ADT/SetVector.h"
namespace mlir { namespace mlir {
class AffineExpr; class AffineExpr;
class AffineMap; class AffineMap;
@ -217,6 +219,17 @@ void applyPermutationToVector(SmallVector<T, N> &inVec,
auxVec[i] = inVec[permutation[i]]; auxVec[i] = inVec[permutation[i]];
inVec = auxVec; inVec = auxVec;
} }
/// Prepares the SubView promotion later performed by `promoteSubViews`
/// (where most of the transformation happens). It arranges the new
/// operands for `LinalgOp op` and deallocates the new buffer(s)
/// It is the entry point for declarative transformation
/// Returns the cloned `LinalgOp` with the new operands
LinalgOp promoteSubViewOperands(OpBuilder &b, LinalgOp op,
llvm::SetVector<Value *> subViews,
bool dynamicBuffers = false,
OperationFolder *folder = nullptr);
} // namespace linalg } // namespace linalg
} // namespace mlir } // namespace mlir

View File

@ -43,6 +43,7 @@ using namespace mlir::linalg;
using namespace mlir::linalg::intrinsics; using namespace mlir::linalg::intrinsics;
using llvm::dbgs; using llvm::dbgs;
using llvm::SetVector;
// Marker used as attribute name in generated Linalg rewriting transformations. // Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
@ -230,3 +231,17 @@ mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
linOp.clone(rewriter, linOp.getLoc(), op->getOperands()); linOp.clone(rewriter, linOp.getLoc(), op->getOperands());
return success(); return success();
} }
LogicalResult mlir::linalg::linalgOpPromoteSubviews(PatternRewriter &rewriter,
Operation *op) {
LinalgOp linOp = dyn_cast<LinalgOp>(op);
SetVector<Value *> subViews;
for (auto it : linOp.getInputsAndOutputs())
if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp()))
subViews.insert(sv);
if (!subViews.empty()) {
auto resOp = promoteSubViewOperands(rewriter, linOp, subViews);
return success(resOp);
}
return failure();
}

View File

@ -160,11 +160,11 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
return res; return res;
} }
static void promoteSubViewOperands(LinalgOp op, SetVector<Value *> subViews, LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op,
SetVector<Value *> subViews,
bool dynamicBuffers, bool dynamicBuffers,
OperationFolder *folder) { OperationFolder *folder) {
// 1. Promote the specified views and use them in the new op. // 1. Promote the specified views and use them in the new op.
OpBuilder b(op);
ScopedContext scope(b, op.getLoc()); ScopedContext scope(b, op.getLoc());
auto promotedBufferAndViews = promoteSubViews( auto promotedBufferAndViews = promoteSubViews(
b, op.getLoc(), subViews.getArrayRef(), dynamicBuffers, folder); b, op.getLoc(), subViews.getArrayRef(), dynamicBuffers, folder);
@ -189,11 +189,12 @@ static void promoteSubViewOperands(LinalgOp op, SetVector<Value *> subViews,
// extra scalars etc. // extra scalars etc.
auto operands = getAssumedNonViewOperands(op); auto operands = getAssumedNonViewOperands(op);
opViews.append(operands.begin(), operands.end()); opViews.append(operands.begin(), operands.end());
op.clone(b, op.getLoc(), opViews); LinalgOp res = op.clone(b, op.getLoc(), opViews);
// 3. Emit write-back for the promoted output views: copy the partial view. // 3. Emit write-back for the promoted output views: copy the partial view.
for (auto viewAndPartialLocalView : writebackViews) { for (auto viewAndPartialLocalView : writebackViews) {
// Note: use the old op to determine whether the operand view is an output. // WARNING: MUST use the old op to determine whether the operand view is an
// output.
bool isOutput = bool isOutput =
op.getIndexOfOutput(viewAndPartialLocalView.first).hasValue(); op.getIndexOfOutput(viewAndPartialLocalView.first).hasValue();
if (isOutput) if (isOutput)
@ -203,6 +204,8 @@ static void promoteSubViewOperands(LinalgOp op, SetVector<Value *> subViews,
// 4. Dealloc local buffers. // 4. Dealloc local buffers.
for (const auto &pi : promotedBufferAndViews) for (const auto &pi : promotedBufferAndViews)
dealloc(pi.buffer); dealloc(pi.buffer);
return res;
} }
static void promoteSubViews(FuncOp f, bool dynamicBuffers) { static void promoteSubViews(FuncOp f, bool dynamicBuffers) {
@ -212,11 +215,12 @@ static void promoteSubViews(FuncOp f, bool dynamicBuffers) {
// TODO(ntv) some heuristic here to decide what to promote. Atm it is all or // TODO(ntv) some heuristic here to decide what to promote. Atm it is all or
// nothing. // nothing.
SetVector<Value *> subViews; SetVector<Value *> subViews;
OpBuilder b(op);
for (auto it : op.getInputsAndOutputs()) for (auto it : op.getInputsAndOutputs())
if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp())) if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp()))
subViews.insert(sv); subViews.insert(sv);
if (!subViews.empty()) { if (!subViews.empty()) {
promoteSubViewOperands(op, subViews, dynamicBuffers, &folder); promoteSubViewOperands(b, op, subViews, dynamicBuffers, &folder);
toErase.push_back(op); toErase.push_back(op);
} }
}); });

View File

@ -315,3 +315,53 @@ func @matmul_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] { // CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] {
// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] { // CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] {
// CHECK : linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]> // CHECK : linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
func @promote_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%arg2: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
%c2000 = constant 2000 : index
%c3000 = constant 3000 : index
%c4000 = constant 4000 : index
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = dim %arg0, 0 : memref<?x?xf32, offset: ?, strides: [?, 1]>
%1 = dim %arg0, 1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
%2 = dim %arg1, 1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
loop.for %arg3 = %c0 to %0 step %c2000 {
loop.for %arg4 = %c0 to %2 step %c3000 {
loop.for %arg5 = %c0 to %1 step %c4000 {
%3 = std.subview %arg0[%arg3, %arg5][%c2000, %c4000][%c1, %c1] :
memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
%4 = std.subview %arg1[%arg5, %arg4][%c4000, %c3000][%c1, %c1] :
memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
%5 = std.subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] :
memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
linalg.matmul(%3, %4, %5) {__internal_linalg_transform__ = "_promote_views_"} :
memref<?x?xf32, offset: ?, strides: [?, ?]>,
memref<?x?xf32, offset: ?, strides: [?, ?]>,
memref<?x?xf32, offset: ?, strides: [?, ?]>
}
}
}
return
}
// CHECK-LABEL: func @promote_subview_matmul
// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] {
// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] {
// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] {
// CHECK : %[[s0:.*]] = std.subview {{%.*}}[{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}>
// CHECK : %[[s1:.*]] = std.subview {{%.*}}[{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}>
// CHECK : %[[s2:.*]] = std.subview {{%.*}}[{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}>
// CHECK : %[[a0:.*]] = alloc({{%.*}}) : memref<?xi8>
// CHECK : %[[v0:.*]] = std.view %[[a0]][][{{%.*}}, {{%.*}}]: memref<?xi8> to memref<?x?xf32>
// CHECK : %[[l0:.*]] = linalg.slice %[[v0]][{{%.*}}, {{%.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #map{{.*}}>
// CHECK : %[[a1:.*]] = alloc({{%.*}}) : memref<?xi8>
// CHECK : %[[v1:.*]] = std.view %[[a1]][][{{%.*}}, {{%.*}}]: memref<?xi8> to memref<?x?xf32>
// CHECK : %[[l1:.*]] = linalg.slice %[[v1]][{{%.*}}, {{%.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #map{{.*}}>
// CHECK : %[[a2:.*]] = alloc({{%.*}}) : memref<?xi8>
// CHECK : %[[v2:.*]] = std.view %[[a2]][][{{%.*}}, {{%.*}}]: memref<?xi8> to memref<?x?xf32>
// CHECK : %[[l2:.*]] = linalg.slice %[[v2]][{{%.*}}, {{%.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #map{{.*}}>
// CHECK : linalg.copy(%[[s0]], %[[l0]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
// CHECK : linalg.copy(%[[s1]], %[[l1]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
// CHECK : linalg.copy(%[[s2]], %[[l2]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
// CHECK : linalg.matmul(%[[v0]], %[[v1]], %[[v2]]) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>

View File

@ -115,7 +115,6 @@ def : Pattern<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8),
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Linalg generic permutation patterns. // Linalg generic permutation patterns.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def : Pat<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8), def : Pat<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8),
(PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op), (PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op),
[(Constraint<And<[HasNoLinalgTransformMarker, [(Constraint<And<[HasNoLinalgTransformMarker,
@ -126,4 +125,11 @@ def : Pat<(IndexedGenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8),
[(Constraint<And<[HasNoLinalgTransformMarker, [(Constraint<And<[HasNoLinalgTransformMarker,
AffineMapDomainHasDim<3>]>> $op)]>; AffineMapDomainHasDim<3>]>> $op)]>;
//===----------------------------------------------------------------------===//
// Linalg subview operands promotion.
//===----------------------------------------------------------------------===//
def : Pat<(MatmulOp:$op $A, $B, $C),
(LinalgOpPromoteSubviews<"MatmulOp"> $op),
[(Constraint<HasOperandsOfType<"SubViewOp">> $op),
(Constraint<HasLinalgTransformMarker<"_promote_views_">> $op)]>;
#endif // TEST_LINALG_TRANSFORMS_PATTERNS #endif // TEST_LINALG_TRANSFORMS_PATTERNS