forked from OSchip/llvm-project
[Linalg] Expose subview promotion as a declarative pattern
This PR targest issue tensorflow/mlir#295. It exposes the already existing subiew promotion pass as a declarative pattern Change-Id: If901ebef9fb53fcd0b12ecc536f6b174ce320b92 Closes tensorflow/mlir#315 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/315 from tetuante:issue295 8e5f268b6d85f31015c33505329dbd7a4db97ac5 PiperOrigin-RevId: 285801463
This commit is contained in:
parent
c290e993b2
commit
3ae56c4135
|
@ -43,6 +43,13 @@ class AffineMapDomainHasDim<int n> : CPred<[{
|
||||||
$0.getAttrOfType<ArrayAttr>(getIndexingMapsAttrName()).getValue()[0].
|
$0.getAttrOfType<ArrayAttr>(getIndexingMapsAttrName()).getValue()[0].
|
||||||
cast<AffineMapAttr>().getValue().getNumDims() ==}] # n # [{}]>;
|
cast<AffineMapAttr>().getValue().getNumDims() ==}] # n # [{}]>;
|
||||||
|
|
||||||
|
class HasOperandsOfType<string type>: CPred<[{
|
||||||
|
llvm::any_of($0.getOperands(),
|
||||||
|
[](Value* v) {
|
||||||
|
return dyn_cast_or_null<}] # type # [{>(v->getDefiningOp());
|
||||||
|
})
|
||||||
|
}]>;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Linalg fusion patterns.
|
// Linalg fusion patterns.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -101,4 +108,10 @@ class PermuteGenericLinalgOp<list<int> permutation, string value> :
|
||||||
StrJoinInt<permutation>.result # "}, \"" # value # "\"))) " #
|
StrJoinInt<permutation>.result # "}, \"" # value # "\"))) " #
|
||||||
" return matchFailure();">;
|
" return matchFailure();">;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Linalg promote subview operands.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
class LinalgOpPromoteSubviews<string OpType> : NativeCodeCall<
|
||||||
|
"if (failed(linalgOpPromoteSubviews($_builder, $0))) " #
|
||||||
|
" return matchFailure();">;
|
||||||
#endif // LINALG_TRANSFORMS
|
#endif // LINALG_TRANSFORMS
|
||||||
|
|
|
@ -95,6 +95,10 @@ LogicalResult vectorizeGenericOp(PatternRewriter &rewriter, Operation *op);
|
||||||
LogicalResult permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
|
LogicalResult permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
|
||||||
ArrayRef<unsigned> permutation,
|
ArrayRef<unsigned> permutation,
|
||||||
StringRef linalgMarker);
|
StringRef linalgMarker);
|
||||||
|
|
||||||
|
/// Promote std.subviews feeding linalg operations
|
||||||
|
LogicalResult linalgOpPromoteSubviews(PatternRewriter &rewriter, Operation *op);
|
||||||
|
|
||||||
} // namespace linalg
|
} // namespace linalg
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,8 @@
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||||
#include "mlir/EDSC/Helpers.h"
|
#include "mlir/EDSC/Helpers.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SetVector.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
class AffineExpr;
|
class AffineExpr;
|
||||||
class AffineMap;
|
class AffineMap;
|
||||||
|
@ -217,6 +219,17 @@ void applyPermutationToVector(SmallVector<T, N> &inVec,
|
||||||
auxVec[i] = inVec[permutation[i]];
|
auxVec[i] = inVec[permutation[i]];
|
||||||
inVec = auxVec;
|
inVec = auxVec;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Prepares the SubView promotion later performed by `promoteSubViews`
|
||||||
|
/// (where most of the transformation happens). It arranges the new
|
||||||
|
/// operands for `LinalgOp op` and deallocates the new buffer(s)
|
||||||
|
/// It is the entry point for declarative transformation
|
||||||
|
/// Returns the cloned `LinalgOp` with the new operands
|
||||||
|
LinalgOp promoteSubViewOperands(OpBuilder &b, LinalgOp op,
|
||||||
|
llvm::SetVector<Value *> subViews,
|
||||||
|
bool dynamicBuffers = false,
|
||||||
|
OperationFolder *folder = nullptr);
|
||||||
|
|
||||||
} // namespace linalg
|
} // namespace linalg
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
|
|
@ -43,6 +43,7 @@ using namespace mlir::linalg;
|
||||||
using namespace mlir::linalg::intrinsics;
|
using namespace mlir::linalg::intrinsics;
|
||||||
|
|
||||||
using llvm::dbgs;
|
using llvm::dbgs;
|
||||||
|
using llvm::SetVector;
|
||||||
|
|
||||||
// Marker used as attribute name in generated Linalg rewriting transformations.
|
// Marker used as attribute name in generated Linalg rewriting transformations.
|
||||||
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
|
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
|
||||||
|
@ -230,3 +231,17 @@ mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
|
||||||
linOp.clone(rewriter, linOp.getLoc(), op->getOperands());
|
linOp.clone(rewriter, linOp.getLoc(), op->getOperands());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult mlir::linalg::linalgOpPromoteSubviews(PatternRewriter &rewriter,
|
||||||
|
Operation *op) {
|
||||||
|
LinalgOp linOp = dyn_cast<LinalgOp>(op);
|
||||||
|
SetVector<Value *> subViews;
|
||||||
|
for (auto it : linOp.getInputsAndOutputs())
|
||||||
|
if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp()))
|
||||||
|
subViews.insert(sv);
|
||||||
|
if (!subViews.empty()) {
|
||||||
|
auto resOp = promoteSubViewOperands(rewriter, linOp, subViews);
|
||||||
|
return success(resOp);
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
|
@ -160,11 +160,11 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void promoteSubViewOperands(LinalgOp op, SetVector<Value *> subViews,
|
LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op,
|
||||||
|
SetVector<Value *> subViews,
|
||||||
bool dynamicBuffers,
|
bool dynamicBuffers,
|
||||||
OperationFolder *folder) {
|
OperationFolder *folder) {
|
||||||
// 1. Promote the specified views and use them in the new op.
|
// 1. Promote the specified views and use them in the new op.
|
||||||
OpBuilder b(op);
|
|
||||||
ScopedContext scope(b, op.getLoc());
|
ScopedContext scope(b, op.getLoc());
|
||||||
auto promotedBufferAndViews = promoteSubViews(
|
auto promotedBufferAndViews = promoteSubViews(
|
||||||
b, op.getLoc(), subViews.getArrayRef(), dynamicBuffers, folder);
|
b, op.getLoc(), subViews.getArrayRef(), dynamicBuffers, folder);
|
||||||
|
@ -189,11 +189,12 @@ static void promoteSubViewOperands(LinalgOp op, SetVector<Value *> subViews,
|
||||||
// extra scalars etc.
|
// extra scalars etc.
|
||||||
auto operands = getAssumedNonViewOperands(op);
|
auto operands = getAssumedNonViewOperands(op);
|
||||||
opViews.append(operands.begin(), operands.end());
|
opViews.append(operands.begin(), operands.end());
|
||||||
op.clone(b, op.getLoc(), opViews);
|
LinalgOp res = op.clone(b, op.getLoc(), opViews);
|
||||||
|
|
||||||
// 3. Emit write-back for the promoted output views: copy the partial view.
|
// 3. Emit write-back for the promoted output views: copy the partial view.
|
||||||
for (auto viewAndPartialLocalView : writebackViews) {
|
for (auto viewAndPartialLocalView : writebackViews) {
|
||||||
// Note: use the old op to determine whether the operand view is an output.
|
// WARNING: MUST use the old op to determine whether the operand view is an
|
||||||
|
// output.
|
||||||
bool isOutput =
|
bool isOutput =
|
||||||
op.getIndexOfOutput(viewAndPartialLocalView.first).hasValue();
|
op.getIndexOfOutput(viewAndPartialLocalView.first).hasValue();
|
||||||
if (isOutput)
|
if (isOutput)
|
||||||
|
@ -203,6 +204,8 @@ static void promoteSubViewOperands(LinalgOp op, SetVector<Value *> subViews,
|
||||||
// 4. Dealloc local buffers.
|
// 4. Dealloc local buffers.
|
||||||
for (const auto &pi : promotedBufferAndViews)
|
for (const auto &pi : promotedBufferAndViews)
|
||||||
dealloc(pi.buffer);
|
dealloc(pi.buffer);
|
||||||
|
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void promoteSubViews(FuncOp f, bool dynamicBuffers) {
|
static void promoteSubViews(FuncOp f, bool dynamicBuffers) {
|
||||||
|
@ -212,11 +215,12 @@ static void promoteSubViews(FuncOp f, bool dynamicBuffers) {
|
||||||
// TODO(ntv) some heuristic here to decide what to promote. Atm it is all or
|
// TODO(ntv) some heuristic here to decide what to promote. Atm it is all or
|
||||||
// nothing.
|
// nothing.
|
||||||
SetVector<Value *> subViews;
|
SetVector<Value *> subViews;
|
||||||
|
OpBuilder b(op);
|
||||||
for (auto it : op.getInputsAndOutputs())
|
for (auto it : op.getInputsAndOutputs())
|
||||||
if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp()))
|
if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp()))
|
||||||
subViews.insert(sv);
|
subViews.insert(sv);
|
||||||
if (!subViews.empty()) {
|
if (!subViews.empty()) {
|
||||||
promoteSubViewOperands(op, subViews, dynamicBuffers, &folder);
|
promoteSubViewOperands(b, op, subViews, dynamicBuffers, &folder);
|
||||||
toErase.push_back(op);
|
toErase.push_back(op);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
|
@ -315,3 +315,53 @@ func @matmul_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||||
// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] {
|
// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] {
|
||||||
// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] {
|
// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] {
|
||||||
// CHECK : linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
|
// CHECK : linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
|
||||||
|
|
||||||
|
func @promote_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||||
|
%arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||||
|
%arg2: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
|
||||||
|
%c2000 = constant 2000 : index
|
||||||
|
%c3000 = constant 3000 : index
|
||||||
|
%c4000 = constant 4000 : index
|
||||||
|
%c0 = constant 0 : index
|
||||||
|
%c1 = constant 1 : index
|
||||||
|
%0 = dim %arg0, 0 : memref<?x?xf32, offset: ?, strides: [?, 1]>
|
||||||
|
%1 = dim %arg0, 1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
|
||||||
|
%2 = dim %arg1, 1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
|
||||||
|
loop.for %arg3 = %c0 to %0 step %c2000 {
|
||||||
|
loop.for %arg4 = %c0 to %2 step %c3000 {
|
||||||
|
loop.for %arg5 = %c0 to %1 step %c4000 {
|
||||||
|
%3 = std.subview %arg0[%arg3, %arg5][%c2000, %c4000][%c1, %c1] :
|
||||||
|
memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
|
||||||
|
%4 = std.subview %arg1[%arg5, %arg4][%c4000, %c3000][%c1, %c1] :
|
||||||
|
memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
|
||||||
|
%5 = std.subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] :
|
||||||
|
memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
|
||||||
|
linalg.matmul(%3, %4, %5) {__internal_linalg_transform__ = "_promote_views_"} :
|
||||||
|
memref<?x?xf32, offset: ?, strides: [?, ?]>,
|
||||||
|
memref<?x?xf32, offset: ?, strides: [?, ?]>,
|
||||||
|
memref<?x?xf32, offset: ?, strides: [?, ?]>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: func @promote_subview_matmul
|
||||||
|
// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] {
|
||||||
|
// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] {
|
||||||
|
// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] {
|
||||||
|
// CHECK : %[[s0:.*]] = std.subview {{%.*}}[{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}>
|
||||||
|
// CHECK : %[[s1:.*]] = std.subview {{%.*}}[{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}>
|
||||||
|
// CHECK : %[[s2:.*]] = std.subview {{%.*}}[{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}>
|
||||||
|
// CHECK : %[[a0:.*]] = alloc({{%.*}}) : memref<?xi8>
|
||||||
|
// CHECK : %[[v0:.*]] = std.view %[[a0]][][{{%.*}}, {{%.*}}]: memref<?xi8> to memref<?x?xf32>
|
||||||
|
// CHECK : %[[l0:.*]] = linalg.slice %[[v0]][{{%.*}}, {{%.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #map{{.*}}>
|
||||||
|
// CHECK : %[[a1:.*]] = alloc({{%.*}}) : memref<?xi8>
|
||||||
|
// CHECK : %[[v1:.*]] = std.view %[[a1]][][{{%.*}}, {{%.*}}]: memref<?xi8> to memref<?x?xf32>
|
||||||
|
// CHECK : %[[l1:.*]] = linalg.slice %[[v1]][{{%.*}}, {{%.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #map{{.*}}>
|
||||||
|
// CHECK : %[[a2:.*]] = alloc({{%.*}}) : memref<?xi8>
|
||||||
|
// CHECK : %[[v2:.*]] = std.view %[[a2]][][{{%.*}}, {{%.*}}]: memref<?xi8> to memref<?x?xf32>
|
||||||
|
// CHECK : %[[l2:.*]] = linalg.slice %[[v2]][{{%.*}}, {{%.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #map{{.*}}>
|
||||||
|
// CHECK : linalg.copy(%[[s0]], %[[l0]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
|
||||||
|
// CHECK : linalg.copy(%[[s1]], %[[l1]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
|
||||||
|
// CHECK : linalg.copy(%[[s2]], %[[l2]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
|
||||||
|
// CHECK : linalg.matmul(%[[v0]], %[[v1]], %[[v2]]) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
|
||||||
|
|
|
@ -115,7 +115,6 @@ def : Pattern<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8),
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Linalg generic permutation patterns.
|
// Linalg generic permutation patterns.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def : Pat<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8),
|
def : Pat<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8),
|
||||||
(PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op),
|
(PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op),
|
||||||
[(Constraint<And<[HasNoLinalgTransformMarker,
|
[(Constraint<And<[HasNoLinalgTransformMarker,
|
||||||
|
@ -126,4 +125,11 @@ def : Pat<(IndexedGenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8),
|
||||||
[(Constraint<And<[HasNoLinalgTransformMarker,
|
[(Constraint<And<[HasNoLinalgTransformMarker,
|
||||||
AffineMapDomainHasDim<3>]>> $op)]>;
|
AffineMapDomainHasDim<3>]>> $op)]>;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Linalg subview operands promotion.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
def : Pat<(MatmulOp:$op $A, $B, $C),
|
||||||
|
(LinalgOpPromoteSubviews<"MatmulOp"> $op),
|
||||||
|
[(Constraint<HasOperandsOfType<"SubViewOp">> $op),
|
||||||
|
(Constraint<HasLinalgTransformMarker<"_promote_views_">> $op)]>;
|
||||||
#endif // TEST_LINALG_TRANSFORMS_PATTERNS
|
#endif // TEST_LINALG_TRANSFORMS_PATTERNS
|
||||||
|
|
Loading…
Reference in New Issue