forked from OSchip/llvm-project
[Linalg] Add a Linalg iterator permutation transformation
This patch closes issue tensorflow/mlir#272 We add a standalone iterator permutation transformation to Linalg. This transformation composes a permutation map with the maps in the "indexing_maps" attribute. It also permutes "iterator_types" accordingly. Change-Id: I7c1e693b8203aeecc595a7c012e738ca1100c857 Closes tensorflow/mlir#307 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/307 from tetuante:issue272 f7908d58792f4111119721885e247045104f1131 PiperOrigin-RevId: 284824102
This commit is contained in:
parent
ad38e49806
commit
b19fed5415
|
@ -24,6 +24,7 @@
|
|||
|
||||
include "mlir/Dialect/Linalg/IR/LinalgOps.td"
|
||||
include "mlir/Dialect/Linalg/IR/LinalgLibraryOps.td"
|
||||
include "mlir/Dialect/AffineOps/AffineOps.td"
|
||||
|
||||
def HasNoLinalgTransformMarker : CPred<[{
|
||||
!$0.getAttrOfType<StringAttr>(LinalgTransforms::kLinalgTransformMarker)
|
||||
|
@ -38,6 +39,10 @@ class HasLinalgTransformMarker<string str> : CPred<[{
|
|||
class IsProducedByOpOfType<string str> :
|
||||
CPred<"isProducedByOpOfType<" # str # ">($0, $1)">;
|
||||
|
||||
class AffineMapDomainHasDim<int n> : CPred<[{
|
||||
$0.getAttrOfType<ArrayAttr>("indexing_maps").getValue()[0].
|
||||
cast<AffineMapAttr>().getValue().getNumDims() ==}] # n # [{}]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg fusion patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -86,4 +91,12 @@ class LinalgOpToVectorContraction<string OpType> : NativeCodeCall<
|
|||
"if (failed(vectorizeGenericOp($_builder, $0))) " #
|
||||
" return matchFailure();">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg generic permutation patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
class PermuteGenericLinalgOp<list<int> permutation, string value> :
|
||||
NativeCodeCall<
|
||||
"if (failed(permuteGenericLinalgOp($_builder, $0, {" #
|
||||
StrJoinInt<permutation>.result # "}, \"" # value # "\"))) " #
|
||||
" return matchFailure();">;
|
||||
#endif // LINALG_TRANSFORMS
|
||||
|
|
|
@ -73,23 +73,28 @@ LogicalResult tileLinalgOpAndSetMarker(PatternRewriter &rewriter, Operation *op,
|
|||
StringRef linalgMarker,
|
||||
ArrayRef<unsigned> permutation);
|
||||
|
||||
// Tiles `op` by `sizes`, fuses the producers of `operandIndicesToFuse` and sets
|
||||
// the attribute `kLinalgTransformMarker` to `linalgMarker`.
|
||||
/// Tiles `op` by `sizes`, fuses the producers of `operandIndicesToFuse` and
|
||||
/// sets the attribute `kLinalgTransformMarker` to `linalgMarker`.
|
||||
LogicalResult tileAndFuseLinalgOpAndSetMarker(
|
||||
PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
|
||||
ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker);
|
||||
|
||||
// Emits a loop nest of `loop.for` with the proper body for `op`.
|
||||
/// Emits a loop nest of `loop.for` with the proper body for `op`.
|
||||
template <typename ConcreteOp>
|
||||
LogicalResult linalgOpToLoops(PatternRewriter &rewriter, Operation *op);
|
||||
|
||||
// Emits a loop nest of `affine.for` with the proper body for `op`.
|
||||
/// Emits a loop nest of `affine.for` with the proper body for `op`.
|
||||
template <typename ConcreteOp>
|
||||
LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, Operation *op);
|
||||
|
||||
// Rewrite a linalg.generic into a suitable vector.contraction op.
|
||||
/// Rewrite a linalg.generic into a suitable vector.contraction op.
|
||||
LogicalResult vectorizeGenericOp(PatternRewriter &rewriter, Operation *op);
|
||||
|
||||
/// Emits a `generic` or `indexed_generic` operation with the `indexing_maps`
|
||||
/// and `iterator_types` permutated according to `permutation`.
|
||||
LogicalResult permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
|
||||
ArrayRef<unsigned> permutation,
|
||||
StringRef linalgMarker);
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -205,6 +205,18 @@ promoteSubViews(OpBuilder &b, Location loc, ArrayRef<Value *> subViews,
|
|||
/// tiling to just use the values when cloning `linalgOp`.
|
||||
llvm::SmallVector<Value *, 4> getAssumedNonViewOperands(LinalgOp linalgOp);
|
||||
|
||||
/// Apply the permutation defined by `permutation` to `inVec`.
|
||||
/// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
|
||||
/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
|
||||
/// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`.
|
||||
template <typename T, unsigned N>
|
||||
void applyPermutationToVector(SmallVector<T, N> &inVec,
|
||||
ArrayRef<unsigned> permutation) {
|
||||
SmallVector<T, N> auxVec(inVec.size());
|
||||
for (unsigned i = 0; i < permutation.size(); ++i)
|
||||
auxVec[i] = inVec[permutation[i]];
|
||||
inVec = auxVec;
|
||||
}
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -22,8 +22,10 @@
|
|||
#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h"
|
||||
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/Dialect/VectorOps/VectorOps.h"
|
||||
#include "mlir/EDSC/Helpers.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
@ -35,7 +37,10 @@
|
|||
#define DEBUG_TYPE "linalg-transforms"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::edsc;
|
||||
using namespace mlir::edsc::intrinsics;
|
||||
using namespace mlir::linalg;
|
||||
using namespace mlir::linalg::intrinsics;
|
||||
|
||||
using llvm::dbgs;
|
||||
|
||||
|
@ -193,3 +198,35 @@ LogicalResult mlir::linalg::vectorizeGenericOp(PatternRewriter &rewriter,
|
|||
std_store(vRes, vectorMemRefC);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
|
||||
ArrayRef<unsigned> permutation,
|
||||
StringRef linalgMarker) {
|
||||
// If permutation is empty, there is nothing to be done.
|
||||
if (permutation.empty())
|
||||
return failure();
|
||||
|
||||
auto linOp = cast<LinalgOp>(op);
|
||||
auto permutationMap = inversePermutation(
|
||||
AffineMap::getPermutationMap(permutation, rewriter.getContext()));
|
||||
SmallVector<AffineMap, 4> newIndexingMap;
|
||||
auto indexingMaps =
|
||||
linOp.getAttrOfType<ArrayAttr>("indexing_maps").getValue();
|
||||
for (unsigned i = 0, e = linOp.getNumInputsAndOutputs(); i != e; ++i) {
|
||||
AffineMap m = indexingMaps[i].cast<AffineMapAttr>().getValue().compose(
|
||||
permutationMap);
|
||||
newIndexingMap.push_back(m);
|
||||
}
|
||||
auto itTypes = linOp.getAttrOfType<ArrayAttr>("iterator_types").getValue();
|
||||
SmallVector<StringRef, 4> itTypesVector;
|
||||
for (unsigned i = 0, e = itTypes.size(); i != e; ++i)
|
||||
itTypesVector.push_back(itTypes[i].cast<StringAttr>().getValue());
|
||||
applyPermutationToVector(itTypesVector, permutation);
|
||||
op->setAttr("indexing_maps", rewriter.getAffineMapArrayAttr(newIndexingMap));
|
||||
op->setAttr("iterator_types", rewriter.getStrArrayAttr(itTypesVector));
|
||||
op->setAttr(LinalgTransforms::kLinalgTransformMarker,
|
||||
rewriter.getStringAttr(linalgMarker));
|
||||
linOp.clone(rewriter, linOp.getLoc(), op->getOperands());
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -215,14 +215,6 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp,
|
|||
return res;
|
||||
}
|
||||
|
||||
void applyPermutationToLoopRanges(SmallVector<SubViewOp::Range, 4> &loopRanges,
|
||||
ArrayRef<unsigned> permutation) {
|
||||
SmallVector<SubViewOp::Range, 4> auxVec(loopRanges.size());
|
||||
for (unsigned i = 0; i < permutation.size(); ++i)
|
||||
auxVec[i] = loopRanges[permutation[i]];
|
||||
loopRanges = auxVec;
|
||||
}
|
||||
|
||||
llvm::Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
|
||||
OpBuilder &b, LinalgOp op, ArrayRef<Value *> tileSizes,
|
||||
ArrayRef<unsigned> permutation, OperationFolder *folder) {
|
||||
|
@ -256,7 +248,7 @@ llvm::Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
|
|||
makeTiledLoopRanges(b, scope.getLocation(), viewSizesToLoopsMap,
|
||||
viewSizes, tileSizes, folder);
|
||||
if (!permutation.empty())
|
||||
applyPermutationToLoopRanges(loopRanges, permutation);
|
||||
applyPermutationToVector(loopRanges, permutation);
|
||||
|
||||
// 3. Create the tiled loops.
|
||||
LinalgOp res = op;
|
||||
|
|
|
@ -5,6 +5,8 @@
|
|||
// CHECK-DAG: #[[mk:.*]] = (d0, d1, d2) -> (d0, d2)
|
||||
// CHECK-DAG: #[[kn:.*]] = (d0, d1, d2) -> (d2, d1)
|
||||
// CHECK-DAG: #[[mn:.*]] = (d0, d1, d2) -> (d0, d1)
|
||||
// CHECK-DAG: #[[nm:.*]] = (d0, d1, d2) -> (d1, d0)
|
||||
// CHECK-DAG: #[[km:.*]] = (d0, d1, d2) -> (d2, d0)
|
||||
|
||||
func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
|
||||
%y: memref<?xf32, offset: ?, strides: [1]>,
|
||||
|
@ -191,3 +193,53 @@ func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
|
|||
// CHECK: load %{{.*}}[] : memref<vector<8x32xf32>>
|
||||
// CHECK: vector.contract {indexing_maps = [#[[mk]], #[[kn]], #[[mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
|
||||
// CHECK: store %{{.*}}, %{{.*}}[] : memref<vector<8x32xf32>>
|
||||
func @fma(%a: f32, %b: f32, %c: f32) -> f32 {
|
||||
%d = mulf %a, %b: f32
|
||||
%e = addf %c, %d: f32
|
||||
return %e: f32
|
||||
}
|
||||
#matmul_accesses = [
|
||||
(m, n, k) -> (m, k),
|
||||
(m, n, k) -> (k, n),
|
||||
(m, n, k) -> (m, n)
|
||||
]
|
||||
#generic_matmul_trait = {
|
||||
fun = @fma,
|
||||
indexing_maps = #matmul_accesses,
|
||||
library_call = "linalg_matmul",
|
||||
n_views = [2, 1],
|
||||
iterator_types = ["parallel", "parallel", "reduction"]
|
||||
}
|
||||
|
||||
func @permute_generic(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
%B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
%C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
|
||||
linalg.generic #generic_matmul_trait %A, %B, %C : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
|
||||
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL : func @fma
|
||||
// CHECK-LABEL : func @permute_generic
|
||||
// CHECK : linalg.generic {fun = @fma, indexing_maps = [#[[kn]], #[[nm]], #[[km]]], iterator_types = ["parallel", "reduction", "parallel"], library_call = "linalg_matmul", n_views = [2, 1]} %{{.*}}, %{{.*}}, %{{.*}} : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
|
||||
|
||||
func @fma_indexed(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32) -> f32 {
|
||||
%d = mulf %a, %b: f32
|
||||
%e = addf %c, %d: f32
|
||||
return %e: f32
|
||||
}
|
||||
#indexed_matmul_trait = {
|
||||
fun = @fma_indexed,
|
||||
indexing_maps = #matmul_accesses,
|
||||
library_call = "linalg_matmul_indexed",
|
||||
n_views = [2, 1],
|
||||
iterator_types = ["parallel", "parallel", "reduction"]
|
||||
}
|
||||
func @permute_generic_indexed(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
%B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
%C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
|
||||
linalg.indexed_generic #indexed_matmul_trait %A, %B, %C : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL : func @fma_indexed
|
||||
// CHECK-LABEL : func @permute_generic_indexed
|
||||
// CHECK : linalg.indexed_generic {fun = @fma, indexing_maps = [#[[kn]], #[[nm]], #[[km]]], iterator_types = ["parallel", "reduction", "parallel"], library_call = "linalg_matmul_indexed", n_views = [2, 1]} %{{.*}}, %{{.*}}, %{{.*}} : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
|
||||
|
|
|
@ -87,4 +87,17 @@ def : Pattern<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7),
|
|||
[(LinalgOpToVectorContraction<"GenericOp"> $op)],
|
||||
[(Constraint<HasLinalgTransformMarker<"_marked_matmul_">> $op)]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg generic permutation patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def : Pat<(GenericOp:$op $input, $imap, $itypes, $nviews, $doc, $fun, $libcall),
|
||||
(PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op),
|
||||
[(Constraint<And<[HasNoLinalgTransformMarker,
|
||||
AffineMapDomainHasDim<3>]>> $op)]>;
|
||||
|
||||
def : Pat<(IndexedGenericOp:$op $input, $imap, $itypes, $nviews, $doc, $fun, $libcall),
|
||||
(PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op),
|
||||
[(Constraint<And<[HasNoLinalgTransformMarker,
|
||||
AffineMapDomainHasDim<3>]>> $op)]>;
|
||||
#endif // TEST_LINALG_TRANSFORMS_PATTERNS
|
||||
|
|
Loading…
Reference in New Issue