[StructuredOps][Linalg] Add a primitive pattern to rewrite the linalg.generic form of matmul to vector form.

This CL uses the newly expanded matcher support to easily detect when a linalg.generic has a multiply-accumulate body. A linalg.generic with such a body is rewritten as a vector contraction.
This CL additionally limits the rewrite to the case of matrix multiplication on contiguous and statically shaped memrefs for now.
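For example (a sketch; SSA value names are illustrative, the full test is in the transform-patterns test below), a linalg.generic carrying the matmul trait:

  linalg.generic #matmul_trait %A, %B, %C {
  ^bb(%a: f32, %b: f32, %c: f32):
    %d = mulf %a, %b : f32
    %e = addf %c, %d : f32
    linalg.yield %e : f32
  } : memref<8x16xf32>, memref<16x32xf32>, memref<8x32xf32>

becomes a single contraction on vectors loaded from the statically shaped operands:

  %vRes = vector.contract {indexing_maps = [#mk, #kn, #mn],
                           iterator_types = ["parallel", "parallel", "reduction"]}
      %vA, %vB, %vC : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>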

Before expanding further, we should harden the infrastructure for expressing custom ops with the structured ops abstraction.

PiperOrigin-RevId: 284566659
Nicolas Vasilache 2019-12-09 09:14:05 -08:00 committed by A. Unique TensorFlower
parent 70aeb4566e
commit 91c0074624
9 changed files with 168 additions and 6 deletions


@@ -484,9 +484,10 @@ def GenericOp : GenericOpBase<"generic"> {
The external library is assumed to be dynamically linked and no strong
compile-time guarantees are provided. In the absence of such a library
call, linalg.generic will always lower to loops.
-      - iterator_types: an ArrayAttr they type of the enclosing loops; Each element of
-        the list represents and iterator of one of the following types:
-        parallel, reduction, window
+      - iterator_types: an ArrayAttr specifying the type of the enclosing loops.
+        Each element of the list represents an iterator of one of the following
+        types:
+        parallel, reduction, window
- n_views: a pair of I64Attr representing the number of input (readonly)
and output (readwrite) views.


@@ -30,6 +30,8 @@ def HasNoLinalgTransformMarker : CPred<[{
}]>;
class HasLinalgTransformMarker<string str> : CPred<[{
$0.getAttrOfType<StringAttr>(
LinalgTransforms::kLinalgTransformMarker) &&
$0.getAttrOfType<StringAttr>(
LinalgTransforms::kLinalgTransformMarker).getValue() == "}] # str # [{"}]>;
@@ -77,4 +79,11 @@ class LinalgOpToAffineLoops<string OpType> : NativeCodeCall<
"if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, $0))) " #
" return matchFailure();">;
//===----------------------------------------------------------------------===//
// Linalg to vector contraction patterns.
//===----------------------------------------------------------------------===//
class LinalgOpToVectorContraction<string OpType> : NativeCodeCall<
"if (failed(vectorizeGenericOp($_builder, $0))) " #
" return matchFailure();">;
#endif // LINALG_TRANSFORMS


@@ -87,6 +87,9 @@ LogicalResult linalgOpToLoops(PatternRewriter &rewriter, Operation *op);
template <typename ConcreteOp>
LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, Operation *op);
// Rewrite a linalg.generic into a suitable vector.contract op.
LogicalResult vectorizeGenericOp(PatternRewriter &rewriter, Operation *op);
} // namespace linalg
} // namespace mlir


@@ -127,7 +127,16 @@ def Vector_ContractionOp :
%5 = vector.contract #contraction_trait %0, %1, %2, %lhs_mask, %rhs_mask
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
}];
let builders = [OpBuilder<
"Builder *builder, OperationState &result, Value *lhs, Value *rhs, "
"Value *acc, ArrayAttr indexingMaps, ArrayAttr iteratorTypes">];
let extraClassDeclaration = [{
static constexpr StringLiteral getIndexingMapsAttrName() {
return "indexing_maps";
}
static constexpr StringLiteral getIteratorTypesAttrName() {
return "iterator_types";
}
VectorType getLhsType() {
return lhs()->getType().cast<VectorType>();
}
@@ -148,7 +157,7 @@ def Vector_ContractionOp :
VectorType getResultType() {
return getResult()->getType().cast<VectorType>();
}
-    SmallVector<StringRef, 2> getTraitAttrNames();
+    ArrayRef<StringRef> getTraitAttrNames();
SmallVector<AffineMap, 4> getIndexingMaps();
static StringRef getReductionIteratorTypeName() {
return "reduction";


@@ -25,4 +25,5 @@ add_dependencies(MLIRLinalg
MLIRLinalgTransformPatternsIncGen
MLIRStandardOps
MLIRStandardToLLVM
MLIRVectorOps
)


@@ -23,12 +23,22 @@
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/VectorOps/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
#define DEBUG_TYPE "linalg-transforms"
using namespace mlir;
using namespace mlir::linalg;
using llvm::dbgs;
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
"__internal_linalg_transform__";
@@ -106,3 +116,80 @@ bool mlir::linalg::detail::isProducedByOpOfTypeImpl(
}
return false;
}

// Returns true if the single block of `op`'s region computes a
// multiply-accumulate on the three block arguments and yields the result.
static bool hasMultiplyAddBody(linalg::GenericOp op) {
  auto &r = op.region();
  if (r.empty())
    return false;
  if (r.getBlocks().size() != 1)
    return false;
  auto &ops = r.front().getOperations();
  if (ops.size() != 3)
    return false;

  using mlir::matchers::m_Val;
  auto a = m_Val(r.front().getArgument(0));
  auto b = m_Val(r.front().getArgument(1));
  auto c = m_Val(r.front().getArgument(2));
  // TODO(ntv) Update this detection once we have matcher support for
  // specifying that any permutation of operands matches.
  auto pattern1 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
  auto pattern2 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
  auto pattern3 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
  auto pattern4 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
  return pattern1.match(&ops.back()) || pattern2.match(&ops.back()) ||
         pattern3.match(&ops.back()) || pattern4.match(&ops.back());
}
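
In IR terms, hasMultiplyAddBody accepts exactly a three-operation region body of the following shape (a sketch; either operand order of mulf and addf matches, hence the four patterns above):

  ^bb(%a: f32, %b: f32, %c: f32):
    %d = mulf %a, %b : f32   // operands may also appear as (%b, %a)
    %e = addf %d, %c : f32   // or as addf %c, %d
    linalg.yield %e : f32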

// Returns true if `genericOp` has the 2-input, 1-output form of matmul:
// canonical (m, k) x (k, n) -> (m, n) indexing maps and a multiply-accumulate
// body.
// TODO(ntv) should be Tablegen'd from a single source that generates the op
// itself.
static bool isMatmul(linalg::GenericOp genericOp) {
  auto *ctx = genericOp.getContext();
  auto m = getAffineDimExpr(0, ctx);
  auto n = getAffineDimExpr(1, ctx);
  auto k = getAffineDimExpr(2, ctx);
  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}));
  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}));
  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}));
  auto maps = ArrayAttr::get({mapA, mapB, mapC}, ctx);
  return genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
         genericOp.indexing_maps() == maps && hasMultiplyAddBody(genericOp);
}
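
In the trait syntax used by the test below, the maps compared against are the canonical matmul indexing maps:

  indexing_maps = [
    (m, n, k) -> (m, k),  // A
    (m, n, k) -> (k, n),  // B
    (m, n, k) -> (m, n)   // C
  ]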

LogicalResult mlir::linalg::vectorizeGenericOp(PatternRewriter &rewriter,
                                               Operation *op) {
  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
                       "]: Rewrite linalg op as vector.contract: "
                    << *op << ":\n");

  // TODO(ntv): This is in fact much more general than just vectorization for
  // matmul ops.
  auto genericOp = dyn_cast<linalg::GenericOp>(op);
  if (!genericOp || !isMatmul(genericOp))
    return failure();

  // TODO(ntv): non-identity layout.
  auto isStaticMemRefWithIdentityLayout = [](Value *v) {
    auto m = v->getType().dyn_cast<MemRefType>();
    if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty())
      return false;
    return true;
  };
  if (!llvm::all_of(genericOp.getInputsAndOutputs(),
                    isStaticMemRefWithIdentityLayout))
    return failure();

  edsc::ScopedContext scope(rewriter, op->getLoc());
  using edsc::intrinsics::std_load;
  using edsc::intrinsics::std_store;
  using vector_contract = edsc::intrinsics::ValueBuilder<vector::ContractionOp>;
  using vector_type_cast = edsc::intrinsics::ValueBuilder<vector::TypeCastOp>;
  auto vA = std_load(vector_type_cast(genericOp.getInput(0)));
  auto vB = std_load(vector_type_cast(genericOp.getInput(1)));
  auto vectorMemRefC = vector_type_cast(genericOp.getOutput(0));
  auto vC = std_load(vectorMemRefC);
  auto vRes = vector_contract(vA, vB, vC, genericOp.indexing_maps(),
                              genericOp.iterator_types());
  std_store(vRes, vectorMemRefC);
  return success();
}
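
For the 8x16 times 16x32 matmul in the test below, this produces IR of the following shape (a sketch mirroring the CHECK lines; value names are illustrative):

  %0 = vector.type_cast %A : memref<8x16xf32> to memref<vector<8x16xf32>>
  %vA = load %0[] : memref<vector<8x16xf32>>
  %1 = vector.type_cast %B : memref<16x32xf32> to memref<vector<16x32xf32>>
  %vB = load %1[] : memref<vector<16x32xf32>>
  %2 = vector.type_cast %C : memref<8x32xf32> to memref<vector<8x32xf32>>
  %vC = load %2[] : memref<vector<8x32xf32>>
  %3 = vector.contract {indexing_maps = [#mk, #kn, #mn],
                        iterator_types = ["parallel", "parallel", "reduction"]}
      %vA, %vB, %vC : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
  store %3, %2[] : memref<vector<8x32xf32>>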


@@ -51,6 +51,16 @@ mlir::vector::VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
// ContractionOp
//===----------------------------------------------------------------------===//
void vector::ContractionOp::build(Builder *builder, OperationState &result,
                                  Value *lhs, Value *rhs, Value *acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc->getType());
  result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
  result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
}

static ParseResult parseContractionOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType lhsInfo;
@@ -235,8 +245,10 @@ static LogicalResult verify(ContractionOp op) {
return success();
}
-SmallVector<StringRef, 2> ContractionOp::getTraitAttrNames() {
-  return SmallVector<StringRef, 2>{"indexing_maps", "iterator_types"};
+ArrayRef<StringRef> ContractionOp::getTraitAttrNames() {
+  static constexpr StringRef names[2] = {getIndexingMapsAttrName(),
+                                         getIteratorTypesAttrName()};
+  return ArrayRef<StringRef>(names);
}
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {


@@ -2,6 +2,9 @@
// CHECK-DAG: #[[STRIDED_1D:.*]] = (d0)[s0] -> (d0 + s0)
// CHECK-DAG: #[[STRIDED_2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
// CHECK-DAG: #[[mk:.*]] = (d0, d1, d2) -> (d0, d2)
// CHECK-DAG: #[[kn:.*]] = (d0, d1, d2) -> (d2, d1)
// CHECK-DAG: #[[mn:.*]] = (d0, d1, d2) -> (d0, d1)
func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
%y: memref<?xf32, offset: ?, strides: [1]>,
@@ -158,3 +161,33 @@ func @fusion_test(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// CHECK : loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c4]] {
// CHECK : linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
#matmul_trait = {
indexing_maps = [
(m, n, k) -> (m, k),
(m, n, k) -> (k, n),
(m, n, k) -> (m, n)
],
n_views = [2, 1],
iterator_types = ["parallel", "parallel", "reduction"],
__internal_linalg_transform__ = "_marked_matmul_"
}
func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
%C: memref<8x32xf32>) {
linalg.generic #matmul_trait %A, %B, %C {
^bb(%a: f32, %b: f32, %c: f32) :
%d = mulf %a, %b: f32
%e = addf %c, %d: f32
linalg.yield %e : f32
} : memref<8x16xf32>, memref<16x32xf32>, memref<8x32xf32>
return
}
// CHECK-LABEL: func @vectorization_test
// CHECK: vector.type_cast %{{.*}} : memref<8x16xf32> to memref<vector<8x16xf32>>
// CHECK: load %{{.*}}[] : memref<vector<8x16xf32>>
// CHECK: vector.type_cast %{{.*}} : memref<16x32xf32> to memref<vector<16x32xf32>>
// CHECK: load %{{.*}}[] : memref<vector<16x32xf32>>
// CHECK: vector.type_cast %{{.*}} : memref<8x32xf32> to memref<vector<8x32xf32>>
// CHECK: load %{{.*}}[] : memref<vector<8x32xf32>>
// CHECK: vector.contract {indexing_maps = [#[[mk]], #[[kn]], #[[mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
// CHECK: store %{{.*}}, %{{.*}}[] : memref<vector<8x32xf32>>


@@ -80,4 +80,11 @@ def : Pattern<(DotOp:$op $a, $b, $c),
[(LinalgOpToLoops<"DotOp"> $op)],
[(Constraint<HasLinalgTransformMarker<"REG">> $op)]>;
//===----------------------------------------------------------------------===//
// Linalg to vector contraction patterns.
//===----------------------------------------------------------------------===//
def : Pattern<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7),
[(LinalgOpToVectorContraction<"GenericOp"> $op)],
[(Constraint<HasLinalgTransformMarker<"_marked_matmul_">> $op)]>;
#endif // TEST_LINALG_TRANSFORMS_PATTERNS