[mlir][Linalg] Generalize Vectorization of Linalg contractions

This revision adds support for vectorizing named and generic contraction ops to vector.contract. Cases in which the memref is 0-D are special-cased to emit std.load/std.store instead of vector.transfer ops. Relevant tests are added.

Differential revision: https://reviews.llvm.org/D83307
Nicolas Vasilache 2020-07-10 10:21:45 -04:00
parent 015a0faa5e
commit 56c638b5c1
4 changed files with 110 additions and 43 deletions
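
To make the 0-D special case concrete, the sketch below is hand-written for this note rather than output captured from the pass (the function name @dot_vectorized is made up; the 1584 sizes are taken over from the tests). It shows roughly what vectorizing linalg.dot into vector.contract looks like: the 1-D inputs go through vector.transfer_read, while the memref<f32> accumulator is read and written with std.load/std.store.

func @dot_vectorized(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
  %c0 = constant 0 : index
  %f0 = constant 0.0 : f32
  // 1-D inputs are read as whole vectors.
  %a = vector.transfer_read %A[%c0], %f0 : memref<1584xf32>, vector<1584xf32>
  %b = vector.transfer_read %B[%c0], %f0 : memref<1584xf32>, vector<1584xf32>
  // 0-D accumulator: scalar std.load instead of a vector.transfer_read.
  %c = load %C[] : memref<f32>
  %res = vector.contract {
           indexing_maps = [affine_map<(k) -> (k)>,
                            affine_map<(k) -> (k)>,
                            affine_map<(k) -> ()>],
           iterator_types = ["reduction"]}
         %a, %b, %c : vector<1584xf32>, vector<1584xf32> into f32
  // 0-D accumulator: scalar std.store instead of a vector.transfer_write.
  store %res, %C[] : memref<f32>
  return
}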


@@ -286,6 +286,12 @@ public:
     return llvm::make_range(attr_value_iterator<AttrTy>(begin()),
                             attr_value_iterator<AttrTy>(end()));
   }
+  template <typename AttrTy, typename UnderlyingTy>
+  auto getAsRange() {
+    return llvm::map_range(getAsRange<AttrTy>(), [](AttrTy attr) {
+      return static_cast<UnderlyingTy>(attr.getValue());
+    });
+  }
 };
 
 //===----------------------------------------------------------------------===//


@@ -36,8 +36,7 @@ using llvm::dbgs;
 #define DEBUG_TYPE "linalg-vectorization"
 
-static bool hasMultiplyAddBody(linalg::GenericOp op) {
-  auto &r = op.region();
+static bool hasMultiplyAddBody(Region &r) {
   if (!llvm::hasSingleElement(r))
     return false;
   if (!llvm::hasNItems(r.front().begin(), r.front().end(), 3))
@@ -59,14 +58,26 @@ static bool hasMultiplyAddBody(linalg::GenericOp op) {
 }
 
 // TODO: Should be Tablegen'd from a single source that generates the op itself.
-static bool isRowMajorMatmul(linalg::GenericOp genericOp) {
-  return genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
-         isRowMajorMatmul(genericOp.indexing_maps()) &&
-         hasMultiplyAddBody(genericOp);
+static LogicalResult isContraction(Operation *op) {
+  // TODO: interface for named ops.
+  if (isa<linalg::BatchMatmulOp, linalg::MatmulOp, linalg::MatvecOp,
+          linalg::DotOp>(op))
+    return success();
+
+  auto genericOp = dyn_cast<linalg::GenericOp>(op);
+  if (!genericOp)
+    return failure();
+
+  auto mapRange =
+      genericOp.indexing_maps().getAsRange<AffineMapAttr, AffineMap>();
+  return success(
+      genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
+      llvm::all_of(mapRange,
+                   [](AffineMap m) { return m.isProjectedPermutation(); }) &&
+      hasMultiplyAddBody(genericOp.region()));
 }
 
 // TODO: This is in fact much more general than just vectorization for matmul
 // and fill ops.
 LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   auto linalgOp = cast<linalg::LinalgOp>(op);
   // All types must be static shape to go to vector.
@@ -76,33 +87,16 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   for (Type outputTensorType : linalgOp.getOutputTensorTypes())
     if (!outputTensorType.cast<ShapedType>().hasStaticShape())
       return failure();
-  if (isa<linalg::MatmulOp, linalg::FillOp>(op))
+  if (isa<linalg::FillOp>(op))
     return success();
 
-  auto genericOp = dyn_cast<linalg::GenericOp>(op);
-  if (!genericOp || !::isRowMajorMatmul(genericOp))
-    return failure();
-
-  // TODO: non-identity layout.
-  auto isStaticMemRefWithIdentityLayout = [](Value v) {
-    auto m = v.getType().dyn_cast<MemRefType>();
-    if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty())
-      return false;
-    return true;
-  };
-  return success(llvm::all_of(genericOp.getInputsAndOutputBuffers(),
-                              isStaticMemRefWithIdentityLayout));
+  return isContraction(op);
 }
 
 void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
   assert(succeeded(vectorizeLinalgOpPrecondition(op)));
 
-  if (auto convOp = dyn_cast<linalg::ConvOp>(op)) {
-    // TODO: add a level of indirection to linalg.generic.
-    if (convOp.padding())
-      llvm_unreachable("Unexpected conv with padding");
-  }
-
   StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
   (void)dbgPref;
   edsc::ScopedContext scope(builder, op->getLoc());
@@ -117,33 +111,47 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
     return;
   }
 
-  // Vectorize other ops as vector contraction (currently only matmul).
+  assert(succeeded(isContraction(op)) && "Expected contraction");
+
+  // Vectorize other ops as vector contraction.
+  // TODO: interface.
   LLVM_DEBUG(dbgs() << dbgPref
                     << "Rewrite linalg op as vector.contract: " << *op);
+  // In the case of 0-D memrefs, return null and special case to scalar load or
+  // store later.
   auto extractVectorTypeFromScalarView = [](Value v) {
     MemRefType mt = v.getType().cast<MemRefType>();
-    return VectorType::get(mt.getShape(), mt.getElementType());
+    return mt.getShape().empty()
+               ? VectorType()
+               : VectorType::get(mt.getShape(), mt.getElementType());
   };
   auto linalgOp = cast<linalg::LinalgOp>(op);
   Value viewA = linalgOp.getInput(0);
   Value viewB = linalgOp.getInput(1);
   Value viewC = linalgOp.getOutputBuffer(0);
+  VectorType vtA = extractVectorTypeFromScalarView(viewA);
+  VectorType vtB = extractVectorTypeFromScalarView(viewB);
+  VectorType vtC = extractVectorTypeFromScalarView(viewC);
   Value zero = std_constant_index(0);
-  SmallVector<Value, 4> indicesA(linalgOp.getInputShapedType(0).getRank(),
-                                 zero);
-  SmallVector<Value, 4> indicesB(linalgOp.getInputShapedType(1).getRank(),
-                                 zero);
-  SmallVector<Value, 4> indicesC(linalgOp.getOutputShapedType(0).getRank(),
-                                 zero);
-  Value a = vector_transfer_read(extractVectorTypeFromScalarView(viewA), viewA,
-                                 indicesA);
-  Value b = vector_transfer_read(extractVectorTypeFromScalarView(viewB), viewB,
-                                 indicesB);
-  Value c = vector_transfer_read(extractVectorTypeFromScalarView(viewC), viewC,
-                                 indicesC);
+  SmallVector<Value, 4> indicesA, indicesB, indicesC;
+  if (vtA)
+    indicesA = SmallVector<Value, 4>(vtA.getRank(), zero);
+  if (vtB)
+    indicesB = SmallVector<Value, 4>(vtB.getRank(), zero);
+  if (vtC)
+    indicesC = SmallVector<Value, 4>(vtC.getRank(), zero);
+  Value a = vtA ? vector_transfer_read(vtA, viewA, indicesA).value
+                : std_load(viewA, indicesA).value;
+  Value b = vtB ? vector_transfer_read(vtB, viewB, indicesB).value
+                : std_load(viewB, indicesB).value;
+  Value c = vtC ? vector_transfer_read(vtC, viewC, indicesC).value
+                : std_load(viewC, indicesC).value;
   Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
                               linalgOp.iterator_types());
-  vector_transfer_write(res, viewC, indicesC);
+  if (vtC)
+    vector_transfer_write(res, viewC, indicesC);
+  else
+    std_store(res, viewC, indicesC);
 }
 
 /// Check whether there is any interleaved use of any `values` between `firstOp`
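
The precondition now also accepts linalg.generic contractions: two inputs, one output buffer, every indexing map a projected permutation, and a single multiply-add body. A hypothetical generic matmul in that shape would look roughly as follows (a sketch only; @generic_matmul, the small sizes, and the exact linalg.generic syntax of this vintage, in particular the args_in/args_out attributes, are approximations rather than text from the patch):

#matmul_accesses = [
  affine_map<(m, n, k) -> (m, k)>,
  affine_map<(m, n, k) -> (k, n)>,
  affine_map<(m, n, k) -> (m, n)>
]
func @generic_matmul(%A: memref<8x16xf32>, %B: memref<16x32xf32>, %C: memref<8x32xf32>) {
  linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
                  indexing_maps = #matmul_accesses,
                  iterator_types = ["parallel", "parallel", "reduction"]} %A, %B, %C {
  ^bb0(%a: f32, %b: f32, %c: f32):
    // Multiply-add body recognized by hasMultiplyAddBody.
    %d = mulf %a, %b : f32
    %e = addf %c, %d : f32
    linalg.yield %e : f32
  } : memref<8x16xf32>, memref<16x32xf32>, memref<8x32xf32>
  return
}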


@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s
 // RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-contraction-to-vector-patterns | FileCheck %s --check-prefix=VECTOR-CONTRACTION
 
 func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
              %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
@@ -30,3 +31,38 @@ func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
 // CHECK-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
 //
 // CHECK: linalg.copy
+
+// VECTOR-CONTRACTION-LABEL: contraction_dot
+func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
+  // VECTOR-CONTRACTION: vector.contract
+  // VECTOR-CONTRACTION-SAME: vector<1584xf32>, vector<1584xf32> into f32
+  linalg.dot(%A, %B, %C) : memref<1584xf32>, memref<1584xf32>, memref<f32>
+  return
+}
+
+// VECTOR-CONTRACTION-LABEL: contraction_matvec
+func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) {
+  // VECTOR-CONTRACTION: vector.contract
+  // VECTOR-CONTRACTION-SAME: vector<1584x1584xf32>, vector<1584xf32> into vector<1584xf32>
+  linalg.matvec %A, %B, %C :
+    (memref<1584x1584xf32>, memref<1584xf32>, memref<1584xf32>)
+  return
+}
+
+// VECTOR-CONTRACTION-LABEL: contraction_matmul
+func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
+  // VECTOR-CONTRACTION: vector.contract
+  // VECTOR-CONTRACTION-SAME: vector<1584x1584xf32>, vector<1584x1584xf32> into vector<1584x1584xf32>
+  linalg.matmul %A, %B, %C :
+    (memref<1584x1584xf32>, memref<1584x1584xf32>, memref<1584x1584xf32>)
+  return
+}
+
+// VECTOR-CONTRACTION-LABEL: contraction_batch_matmul
+func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) {
+  // VECTOR-CONTRACTION: vector.contract
+  // VECTOR-CONTRACTION-SAME: vector<1584x1584x1584xf32>, vector<1584x1584x1584xf32> into vector<1584x1584x1584xf32>
+  linalg.batch_matmul %A, %B, %C :
+    (memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>)
+  return
+}
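
For the non-0-D cases, the VECTOR-CONTRACTION lines above only check the contraction itself. Spelled out, the expected shape of the rewrite for a matmul is roughly the following (an illustrative hand-written sketch with small made-up sizes and the name @matmul_vectorized, not FileCheck-verified output of the pass):

func @matmul_vectorized(%A: memref<8x16xf32>, %B: memref<16x32xf32>, %C: memref<8x32xf32>) {
  %c0 = constant 0 : index
  %f0 = constant 0.0 : f32
  // All views are full-rank, so reads and the final write use vector.transfer ops.
  %a = vector.transfer_read %A[%c0, %c0], %f0 : memref<8x16xf32>, vector<8x16xf32>
  %b = vector.transfer_read %B[%c0, %c0], %f0 : memref<16x32xf32>, vector<16x32xf32>
  %c = vector.transfer_read %C[%c0, %c0], %f0 : memref<8x32xf32>, vector<8x32xf32>
  // The contraction reuses the linalg op's indexing maps and iterator types.
  %res = vector.contract {
           indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                            affine_map<(m, n, k) -> (k, n)>,
                            affine_map<(m, n, k) -> (m, n)>],
           iterator_types = ["parallel", "parallel", "reduction"]}
         %a, %b, %c : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
  vector.transfer_write %res, %C[%c0, %c0] : vector<8x32xf32>, memref<8x32xf32>
  return
}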


@@ -54,6 +54,11 @@ struct TestLinalgTransforms
       llvm::cl::desc(
           "Test a fused pass that forwards linalg.copy to vector.transfer"),
       llvm::cl::init(false)};
+  Option<bool> testGenericToVectorPattern{
+      *this, "test-contraction-to-vector-patterns",
+      llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
+                     "in vector.contract form"),
+      llvm::cl::init(false)};
 };
 } // end anonymous namespace
@@ -300,6 +305,16 @@ static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
   applyPatternsAndFoldGreedily(funcOp, forwardPattern);
 }
 
+static void applyContractionToVectorPatterns(FuncOp funcOp) {
+  OwningRewritePatternList patterns;
+  patterns.insert<LinalgVectorizationPattern<BatchMatmulOp>,
+                  LinalgVectorizationPattern<MatmulOp>,
+                  LinalgVectorizationPattern<MatvecOp>,
+                  LinalgVectorizationPattern<DotOp>,
+                  LinalgVectorizationPattern<GenericOp>>(funcOp.getContext());
+  applyPatternsAndFoldGreedily(funcOp, patterns);
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnFunction() {
   auto lambda = [&](void *) {
@@ -323,6 +338,8 @@ void TestLinalgTransforms::runOnFunction() {
                                        testMatmulToVectorPatterns2dTiling);
   if (testVectorTransferForwardingPatterns)
     return applyVectorTransferForwardingPatterns(getFunction());
+  if (testGenericToVectorPattern)
+    return applyContractionToVectorPatterns(getFunction());
 }
 
 namespace mlir {