[Linalg] Add a Linalg iterator permutation transformation

This patch closes issue tensorflow/mlir#272
We add a standalone iterator permutation transformation to Linalg.
This transformation composes a permutation map with the maps in the
"indexing_maps" attribute. It also permutes "iterator_types"
accordingly.

Change-Id: I7c1e693b8203aeecc595a7c012e738ca1100c857

Closes tensorflow/mlir#307

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/307 from tetuante:issue272 f7908d58792f4111119721885e247045104f1131
PiperOrigin-RevId: 284824102
This commit is contained in:
Jose Ignacio Gomez 2019-12-10 12:25:10 -08:00 committed by A. Unique TensorFlower
parent ad38e49806
commit b19fed5415
7 changed files with 138 additions and 14 deletions

View File

@ -24,6 +24,7 @@
include "mlir/Dialect/Linalg/IR/LinalgOps.td"
include "mlir/Dialect/Linalg/IR/LinalgLibraryOps.td"
include "mlir/Dialect/AffineOps/AffineOps.td"
def HasNoLinalgTransformMarker : CPred<[{
!$0.getAttrOfType<StringAttr>(LinalgTransforms::kLinalgTransformMarker)
@ -38,6 +39,10 @@ class HasLinalgTransformMarker<string str> : CPred<[{
class IsProducedByOpOfType<string str> :
CPred<"isProducedByOpOfType<" # str # ">($0, $1)">;
class AffineMapDomainHasDim<int n> : CPred<[{
$0.getAttrOfType<ArrayAttr>("indexing_maps").getValue()[0].
cast<AffineMapAttr>().getValue().getNumDims() ==}] # n # [{}]>;
//===----------------------------------------------------------------------===//
// Linalg fusion patterns.
//===----------------------------------------------------------------------===//
@ -86,4 +91,12 @@ class LinalgOpToVectorContraction<string OpType> : NativeCodeCall<
"if (failed(vectorizeGenericOp($_builder, $0))) " #
" return matchFailure();">;
//===----------------------------------------------------------------------===//
// Linalg generic permutation patterns.
//===----------------------------------------------------------------------===//
class PermuteGenericLinalgOp<list<int> permutation, string value> :
NativeCodeCall<
"if (failed(permuteGenericLinalgOp($_builder, $0, {" #
StrJoinInt<permutation>.result # "}, \"" # value # "\"))) " #
" return matchFailure();">;
#endif // LINALG_TRANSFORMS

View File

@ -73,23 +73,28 @@ LogicalResult tileLinalgOpAndSetMarker(PatternRewriter &rewriter, Operation *op,
StringRef linalgMarker,
ArrayRef<unsigned> permutation);
// Tiles `op` by `sizes`, fuses the producers of `operandIndicesToFuse` and sets
// the attribute `kLinalgTransformMarker` to `linalgMarker`.
/// Tiles `op` by `sizes`, fuses the producers of `operandIndicesToFuse` and
/// sets the attribute `kLinalgTransformMarker` to `linalgMarker`.
LogicalResult tileAndFuseLinalgOpAndSetMarker(
PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker);
// Emits a loop nest of `loop.for` with the proper body for `op`.
/// Emits a loop nest of `loop.for` with the proper body for `op`.
template <typename ConcreteOp>
LogicalResult linalgOpToLoops(PatternRewriter &rewriter, Operation *op);
// Emits a loop nest of `affine.for` with the proper body for `op`.
/// Emits a loop nest of `affine.for` with the proper body for `op`.
template <typename ConcreteOp>
LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, Operation *op);
// Rewrite a linalg.generic into a suitable vector.contraction op.
/// Rewrite a linalg.generic into a suitable vector.contraction op.
LogicalResult vectorizeGenericOp(PatternRewriter &rewriter, Operation *op);
/// Emits a `generic` or `indexed_generic` operation with the `indexing_maps`
/// and `iterator_types` permutated according to `permutation`.
LogicalResult permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
ArrayRef<unsigned> permutation,
StringRef linalgMarker);
} // namespace linalg
} // namespace mlir

View File

@ -205,6 +205,18 @@ promoteSubViews(OpBuilder &b, Location loc, ArrayRef<Value *> subViews,
/// tiling to just use the values when cloning `linalgOp`.
llvm::SmallVector<Value *, 4> getAssumedNonViewOperands(LinalgOp linalgOp);
/// Apply the permutation defined by `permutation` to `inVec`.
/// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
/// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`.
template <typename T, unsigned N>
void applyPermutationToVector(SmallVector<T, N> &inVec,
ArrayRef<unsigned> permutation) {
SmallVector<T, N> auxVec(inVec.size());
for (unsigned i = 0; i < permutation.size(); ++i)
auxVec[i] = inVec[permutation[i]];
inVec = auxVec;
}
} // namespace linalg
} // namespace mlir

View File

@ -22,8 +22,10 @@
#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/VectorOps/VectorOps.h"
#include "mlir/EDSC/Helpers.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
@ -35,7 +37,10 @@
#define DEBUG_TYPE "linalg-transforms"
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;
using namespace mlir::linalg::intrinsics;
using llvm::dbgs;
@ -193,3 +198,35 @@ LogicalResult mlir::linalg::vectorizeGenericOp(PatternRewriter &rewriter,
std_store(vRes, vectorMemRefC);
return success();
}
LogicalResult
mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
ArrayRef<unsigned> permutation,
StringRef linalgMarker) {
// If permutation is empty, there is nothing to be done.
if (permutation.empty())
return failure();
auto linOp = cast<LinalgOp>(op);
auto permutationMap = inversePermutation(
AffineMap::getPermutationMap(permutation, rewriter.getContext()));
SmallVector<AffineMap, 4> newIndexingMap;
auto indexingMaps =
linOp.getAttrOfType<ArrayAttr>("indexing_maps").getValue();
for (unsigned i = 0, e = linOp.getNumInputsAndOutputs(); i != e; ++i) {
AffineMap m = indexingMaps[i].cast<AffineMapAttr>().getValue().compose(
permutationMap);
newIndexingMap.push_back(m);
}
auto itTypes = linOp.getAttrOfType<ArrayAttr>("iterator_types").getValue();
SmallVector<StringRef, 4> itTypesVector;
for (unsigned i = 0, e = itTypes.size(); i != e; ++i)
itTypesVector.push_back(itTypes[i].cast<StringAttr>().getValue());
applyPermutationToVector(itTypesVector, permutation);
op->setAttr("indexing_maps", rewriter.getAffineMapArrayAttr(newIndexingMap));
op->setAttr("iterator_types", rewriter.getStrArrayAttr(itTypesVector));
op->setAttr(LinalgTransforms::kLinalgTransformMarker,
rewriter.getStringAttr(linalgMarker));
linOp.clone(rewriter, linOp.getLoc(), op->getOperands());
return success();
}

View File

@ -215,14 +215,6 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp,
return res;
}
void applyPermutationToLoopRanges(SmallVector<SubViewOp::Range, 4> &loopRanges,
ArrayRef<unsigned> permutation) {
SmallVector<SubViewOp::Range, 4> auxVec(loopRanges.size());
for (unsigned i = 0; i < permutation.size(); ++i)
auxVec[i] = loopRanges[permutation[i]];
loopRanges = auxVec;
}
llvm::Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
OpBuilder &b, LinalgOp op, ArrayRef<Value *> tileSizes,
ArrayRef<unsigned> permutation, OperationFolder *folder) {
@ -256,7 +248,7 @@ llvm::Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
makeTiledLoopRanges(b, scope.getLocation(), viewSizesToLoopsMap,
viewSizes, tileSizes, folder);
if (!permutation.empty())
applyPermutationToLoopRanges(loopRanges, permutation);
applyPermutationToVector(loopRanges, permutation);
// 3. Create the tiled loops.
LinalgOp res = op;

View File

@ -5,6 +5,8 @@
// CHECK-DAG: #[[mk:.*]] = (d0, d1, d2) -> (d0, d2)
// CHECK-DAG: #[[kn:.*]] = (d0, d1, d2) -> (d2, d1)
// CHECK-DAG: #[[mn:.*]] = (d0, d1, d2) -> (d0, d1)
// CHECK-DAG: #[[nm:.*]] = (d0, d1, d2) -> (d1, d0)
// CHECK-DAG: #[[km:.*]] = (d0, d1, d2) -> (d2, d0)
func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
%y: memref<?xf32, offset: ?, strides: [1]>,
@ -191,3 +193,53 @@ func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
// CHECK: load %{{.*}}[] : memref<vector<8x32xf32>>
// CHECK: vector.contract {indexing_maps = [#[[mk]], #[[kn]], #[[mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
// CHECK: store %{{.*}}, %{{.*}}[] : memref<vector<8x32xf32>>
func @fma(%a: f32, %b: f32, %c: f32) -> f32 {
%d = mulf %a, %b: f32
%e = addf %c, %d: f32
return %e: f32
}
#matmul_accesses = [
(m, n, k) -> (m, k),
(m, n, k) -> (k, n),
(m, n, k) -> (m, n)
]
#generic_matmul_trait = {
fun = @fma,
indexing_maps = #matmul_accesses,
library_call = "linalg_matmul",
n_views = [2, 1],
iterator_types = ["parallel", "parallel", "reduction"]
}
func @permute_generic(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
linalg.generic #generic_matmul_trait %A, %B, %C : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
return
}
// CHECK-LABEL : func @fma
// CHECK-LABEL : func @permute_generic
// CHECK : linalg.generic {fun = @fma, indexing_maps = [#[[kn]], #[[nm]], #[[km]]], iterator_types = ["parallel", "reduction", "parallel"], library_call = "linalg_matmul", n_views = [2, 1]} %{{.*}}, %{{.*}}, %{{.*}} : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
func @fma_indexed(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32) -> f32 {
%d = mulf %a, %b: f32
%e = addf %c, %d: f32
return %e: f32
}
#indexed_matmul_trait = {
fun = @fma_indexed,
indexing_maps = #matmul_accesses,
library_call = "linalg_matmul_indexed",
n_views = [2, 1],
iterator_types = ["parallel", "parallel", "reduction"]
}
func @permute_generic_indexed(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
linalg.indexed_generic #indexed_matmul_trait %A, %B, %C : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
return
}
// CHECK-LABEL : func @fma_indexed
// CHECK-LABEL : func @permute_generic_indexed
// CHECK : linalg.indexed_generic {fun = @fma, indexing_maps = [#[[kn]], #[[nm]], #[[km]]], iterator_types = ["parallel", "reduction", "parallel"], library_call = "linalg_matmul_indexed", n_views = [2, 1]} %{{.*}}, %{{.*}}, %{{.*}} : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>

View File

@ -87,4 +87,17 @@ def : Pattern<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7),
[(LinalgOpToVectorContraction<"GenericOp"> $op)],
[(Constraint<HasLinalgTransformMarker<"_marked_matmul_">> $op)]>;
//===----------------------------------------------------------------------===//
// Linalg generic permutation patterns.
//===----------------------------------------------------------------------===//
def : Pat<(GenericOp:$op $input, $imap, $itypes, $nviews, $doc, $fun, $libcall),
(PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op),
[(Constraint<And<[HasNoLinalgTransformMarker,
AffineMapDomainHasDim<3>]>> $op)]>;
def : Pat<(IndexedGenericOp:$op $input, $imap, $itypes, $nviews, $doc, $fun, $libcall),
(PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op),
[(Constraint<And<[HasNoLinalgTransformMarker,
AffineMapDomainHasDim<3>]>> $op)]>;
#endif // TEST_LINALG_TRANSFORMS_PATTERNS