forked from OSchip/llvm-project
[StructuredOps][Linalg] Add a primitive pattern to rewrite the linalg.generic form of matmul to vector form.
This CL uses the newly expanded matcher support to easily detect when a linalg.generic has a multiply-accumulate body. A linalg.generic with such a body is rewritten as a vector contraction. This CL additionally limits the rewrite to the case of matrix multiplication on contiguous and statically shaped memrefs for now. Before expanding further, we should harden the infrastructure for expressing custom ops with the structured ops abstraction. PiperOrigin-RevId: 284566659
This commit is contained in:
parent
70aeb4566e
commit
91c0074624
|
@ -484,9 +484,10 @@ def GenericOp : GenericOpBase<"generic"> {
|
|||
The external library is assumed to be dynamically linked and no strong
|
||||
compile-time guarantees are provided. In the absence of such a library
|
||||
call, linalg.generic will always lower to loops.
|
||||
- iterator_types: an ArrayAttr they type of the enclosing loops; Each element of
|
||||
the list represents and iterator of one of the following types:
|
||||
parallel, reduction, window
|
||||
- iterator_types: an ArrayAttr specifying the type of the enclosing loops.
|
||||
Each element of the list represents and iterator of one of the following
|
||||
types:
|
||||
parallel, reduction, window
|
||||
- n_views: a pair of I64Attr representing the number of input (readonly)
|
||||
and output (readwrite) views.
|
||||
|
||||
|
|
|
@ -30,6 +30,8 @@ def HasNoLinalgTransformMarker : CPred<[{
|
|||
}]>;
|
||||
|
||||
class HasLinalgTransformMarker<string str> : CPred<[{
|
||||
$0.getAttrOfType<StringAttr>(
|
||||
LinalgTransforms::kLinalgTransformMarker) &&
|
||||
$0.getAttrOfType<StringAttr>(
|
||||
LinalgTransforms::kLinalgTransformMarker).getValue() == "}] # str # [{"}]>;
|
||||
|
||||
|
@ -77,4 +79,11 @@ class LinalgOpToAffineLoops<string OpType> : NativeCodeCall<
|
|||
"if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, $0))) " #
|
||||
" return matchFailure();">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg to vector contraction patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
class LinalgOpToVectorContraction<string OpType> : NativeCodeCall<
|
||||
"if (failed(vectorizeGenericOp($_builder, $0))) " #
|
||||
" return matchFailure();">;
|
||||
|
||||
#endif // LINALG_TRANSFORMS
|
||||
|
|
|
@ -87,6 +87,9 @@ LogicalResult linalgOpToLoops(PatternRewriter &rewriter, Operation *op);
|
|||
template <typename ConcreteOp>
|
||||
LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, Operation *op);
|
||||
|
||||
// Rewrite a linalg.generic into a suitable vector.contraction op.
|
||||
LogicalResult vectorizeGenericOp(PatternRewriter &rewriter, Operation *op);
|
||||
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -127,7 +127,16 @@ def Vector_ContractionOp :
|
|||
%5 = vector.contract #contraction_trait %0, %1, %2, %lhs_mask, %rhs_mask
|
||||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
|
||||
}];
|
||||
let builders = [OpBuilder<
|
||||
"Builder *builder, OperationState &result, Value *lhs, Value *rhs, "
|
||||
"Value *acc, ArrayAttr indexingMaps, ArrayAttr iteratorTypes">];
|
||||
let extraClassDeclaration = [{
|
||||
static constexpr StringLiteral getIndexingMapsAttrName() {
|
||||
return "indexing_maps";
|
||||
}
|
||||
static constexpr StringLiteral getIteratorTypesAttrName() {
|
||||
return "iterator_types";
|
||||
}
|
||||
VectorType getLhsType() {
|
||||
return lhs()->getType().cast<VectorType>();
|
||||
}
|
||||
|
@ -148,7 +157,7 @@ def Vector_ContractionOp :
|
|||
VectorType getResultType() {
|
||||
return getResult()->getType().cast<VectorType>();
|
||||
}
|
||||
SmallVector<StringRef, 2> getTraitAttrNames();
|
||||
ArrayRef<StringRef> getTraitAttrNames();
|
||||
SmallVector<AffineMap, 4> getIndexingMaps();
|
||||
static StringRef getReductionIteratorTypeName() {
|
||||
return "reduction";
|
||||
|
|
|
@ -25,4 +25,5 @@ add_dependencies(MLIRLinalg
|
|||
MLIRLinalgTransformPatternsIncGen
|
||||
MLIRStandardOps
|
||||
MLIRStandardToLLVM
|
||||
MLIRVectorOps
|
||||
)
|
||||
|
|
|
@ -23,12 +23,22 @@
|
|||
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/Dialect/VectorOps/VectorOps.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <type_traits>
|
||||
|
||||
#define DEBUG_TYPE "linalg-transforms"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::linalg;
|
||||
|
||||
using llvm::dbgs;
|
||||
|
||||
// Marker used as attribute name in generated Linalg rewriting transformations.
|
||||
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
|
||||
"__internal_linalg_transform__";
|
||||
|
@ -106,3 +116,80 @@ bool mlir::linalg::detail::isProducedByOpOfTypeImpl(
|
|||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool hasMultiplyAddBody(linalg::GenericOp op) {
|
||||
auto &r = op.region();
|
||||
if (r.empty())
|
||||
return false;
|
||||
if (r.getBlocks().size() != 1)
|
||||
return false;
|
||||
auto &ops = r.front().getOperations();
|
||||
if (ops.size() != 3)
|
||||
return false;
|
||||
|
||||
using mlir::matchers::m_Val;
|
||||
auto a = m_Val(r.front().getArgument(0));
|
||||
auto b = m_Val(r.front().getArgument(1));
|
||||
auto c = m_Val(r.front().getArgument(2));
|
||||
// TODO(ntv) Update this detection once we have matcher support for
|
||||
// specifying that any permutation of operands matches.
|
||||
auto pattern1 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
|
||||
auto pattern2 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
|
||||
auto pattern3 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
|
||||
auto pattern4 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
|
||||
return pattern1.match(&ops.back()) || pattern2.match(&ops.back()) ||
|
||||
pattern3.match(&ops.back()) || pattern4.match(&ops.back());
|
||||
}
|
||||
|
||||
// TODO(ntv) should be Tablegen'd from a single source that generates the op
|
||||
// itself.
|
||||
static bool isMatmul(linalg::GenericOp genericOp) {
|
||||
auto *ctx = genericOp.getContext();
|
||||
auto m = getAffineDimExpr(0, ctx);
|
||||
auto n = getAffineDimExpr(1, ctx);
|
||||
auto k = getAffineDimExpr(2, ctx);
|
||||
auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}));
|
||||
auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}));
|
||||
auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}));
|
||||
auto maps = ArrayAttr::get({mapA, mapB, mapC}, ctx);
|
||||
return genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
|
||||
genericOp.indexing_maps() == maps && hasMultiplyAddBody(genericOp);
|
||||
}
|
||||
|
||||
LogicalResult mlir::linalg::vectorizeGenericOp(PatternRewriter &rewriter,
|
||||
Operation *op) {
|
||||
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
|
||||
"]: Rewrite linalg op as vector.contract: "
|
||||
<< *op << ":\n");
|
||||
|
||||
// TODO(ntv): This is in fact much more general than just vectorization for
|
||||
// matmul ops.
|
||||
auto genericOp = dyn_cast<linalg::GenericOp>(op);
|
||||
if (!genericOp || !isMatmul(genericOp))
|
||||
return failure();
|
||||
|
||||
// TODO(ntv): non-identity layout.
|
||||
auto isStaticMemRefWithIdentityLayout = [](Value *v) {
|
||||
auto m = v->getType().dyn_cast<MemRefType>();
|
||||
if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty())
|
||||
return false;
|
||||
return true;
|
||||
};
|
||||
if (!llvm::all_of(genericOp.getInputsAndOutputs(),
|
||||
isStaticMemRefWithIdentityLayout))
|
||||
return failure();
|
||||
|
||||
edsc::ScopedContext scope(rewriter, op->getLoc());
|
||||
using edsc::intrinsics::std_load;
|
||||
using edsc::intrinsics::std_store;
|
||||
using vector_contract = edsc::intrinsics::ValueBuilder<vector::ContractionOp>;
|
||||
using vector_type_cast = edsc::intrinsics::ValueBuilder<vector::TypeCastOp>;
|
||||
auto vA = std_load(vector_type_cast(genericOp.getInput(0)));
|
||||
auto vB = std_load(vector_type_cast(genericOp.getInput(1)));
|
||||
auto vectorMemRefC = vector_type_cast(genericOp.getOutput(0));
|
||||
auto vC = std_load(vectorMemRefC);
|
||||
auto vRes = vector_contract(vA, vB, vC, genericOp.indexing_maps(),
|
||||
genericOp.iterator_types());
|
||||
std_store(vRes, vectorMemRefC);
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -51,6 +51,16 @@ mlir::vector::VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
|
|||
// ContractionOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void vector::ContractionOp::build(Builder *builder, OperationState &result,
|
||||
Value *lhs, Value *rhs, Value *acc,
|
||||
ArrayAttr indexingMaps,
|
||||
ArrayAttr iteratorTypes) {
|
||||
result.addOperands({lhs, rhs, acc});
|
||||
result.addTypes(acc->getType());
|
||||
result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
|
||||
result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
|
||||
}
|
||||
|
||||
static ParseResult parseContractionOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
OpAsmParser::OperandType lhsInfo;
|
||||
|
@ -235,8 +245,10 @@ static LogicalResult verify(ContractionOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
SmallVector<StringRef, 2> ContractionOp::getTraitAttrNames() {
|
||||
return SmallVector<StringRef, 2>{"indexing_maps", "iterator_types"};
|
||||
ArrayRef<StringRef> ContractionOp::getTraitAttrNames() {
|
||||
static constexpr StringRef names[2] = {getIndexingMapsAttrName(),
|
||||
getIteratorTypesAttrName()};
|
||||
return ArrayRef<StringRef>(names);
|
||||
}
|
||||
|
||||
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
|
||||
|
|
|
@ -2,6 +2,9 @@
|
|||
|
||||
// CHECK-DAG: #[[STRIDED_1D:.*]] = (d0)[s0] -> (d0 + s0)
|
||||
// CHECK-DAG: #[[STRIDED_2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
|
||||
// CHECK-DAG: #[[mk:.*]] = (d0, d1, d2) -> (d0, d2)
|
||||
// CHECK-DAG: #[[kn:.*]] = (d0, d1, d2) -> (d2, d1)
|
||||
// CHECK-DAG: #[[mn:.*]] = (d0, d1, d2) -> (d0, d1)
|
||||
|
||||
func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
|
||||
%y: memref<?xf32, offset: ?, strides: [1]>,
|
||||
|
@ -158,3 +161,33 @@ func @fusion_test(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
|||
// CHECK : loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c4]] {
|
||||
// CHECK : linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
|
||||
|
||||
#matmul_trait = {
|
||||
indexing_maps = [
|
||||
(m, n, k) -> (m, k),
|
||||
(m, n, k) -> (k, n),
|
||||
(m, n, k) -> (m, n)
|
||||
],
|
||||
n_views = [2, 1],
|
||||
iterator_types = ["parallel", "parallel", "reduction"],
|
||||
__internal_linalg_transform__ = "_marked_matmul_"
|
||||
}
|
||||
func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
|
||||
%C: memref<8x32xf32>) {
|
||||
linalg.generic #matmul_trait %A, %B, %C {
|
||||
^bb(%a: f32, %b: f32, %c: f32) :
|
||||
%d = mulf %a, %b: f32
|
||||
%e = addf %c, %d: f32
|
||||
linalg.yield %e : f32
|
||||
} : memref<8x16xf32>, memref<16x32xf32>, memref<8x32xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @vectorization_test
|
||||
// CHECK: vector.type_cast %{{.*}} : memref<8x16xf32> to memref<vector<8x16xf32>>
|
||||
// CHECK: load %{{.*}}[] : memref<vector<8x16xf32>>
|
||||
// CHECK: vector.type_cast %{{.*}} : memref<16x32xf32> to memref<vector<16x32xf32>>
|
||||
// CHECK: load %{{.*}}[] : memref<vector<16x32xf32>>
|
||||
// CHECK: vector.type_cast %{{.*}} : memref<8x32xf32> to memref<vector<8x32xf32>>
|
||||
// CHECK: load %{{.*}}[] : memref<vector<8x32xf32>>
|
||||
// CHECK: vector.contract {indexing_maps = [#[[mk]], #[[kn]], #[[mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
|
||||
// CHECK: store %{{.*}}, %{{.*}}[] : memref<vector<8x32xf32>>
|
||||
|
|
|
@ -80,4 +80,11 @@ def : Pattern<(DotOp:$op $a, $b, $c),
|
|||
[(LinalgOpToLoops<"DotOp"> $op)],
|
||||
[(Constraint<HasLinalgTransformMarker<"REG">> $op)]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg to vector contraction patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def : Pattern<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7),
|
||||
[(LinalgOpToVectorContraction<"GenericOp"> $op)],
|
||||
[(Constraint<HasLinalgTransformMarker<"_marked_matmul_">> $op)]>;
|
||||
|
||||
#endif // TEST_LINALG_TRANSFORMS_PATTERNS
|
||||
|
|
Loading…
Reference in New Issue