[mlir][Linalg] Generalize Vectorization of Linalg contractions

This revision adds support for vectorizing named and generic contraction ops
to vector.contract. Cases in which the memref is 0-D are special-cased to emit
std.load/std.store instead of vector.transfer. Relevant tests are added.

Differential Revision: https://reviews.llvm.org/D83307

commit 56c638b5c1 (parent 015a0faa5e)
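To illustrate the rewrite, consider a small statically shaped matmul (a
hand-written sketch with invented shapes and SSA names, not actual pass
output):

    linalg.matmul %A, %B, %C :
      (memref<8x16xf32>, memref<16x12xf32>, memref<8x12xf32>)

After vectorization, each whole view is read once, the payload becomes a
single vector.contract, and the result is written back:

    %c0 = constant 0 : index
    %f0 = constant 0.000000e+00 : f32
    %a = vector.transfer_read %A[%c0, %c0], %f0 : memref<8x16xf32>, vector<8x16xf32>
    %b = vector.transfer_read %B[%c0, %c0], %f0 : memref<16x12xf32>, vector<16x12xf32>
    %c = vector.transfer_read %C[%c0, %c0], %f0 : memref<8x12xf32>, vector<8x12xf32>
    %d = vector.contract {indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                                           affine_map<(m, n, k) -> (k, n)>,
                                           affine_map<(m, n, k) -> (m, n)>],
                          iterator_types = ["parallel", "parallel", "reduction"]}
           %a, %b, %c : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
    vector.transfer_write %d, %C[%c0, %c0] : vector<8x12xf32>, memref<8x12xf32>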
mlir/include/mlir/IR/Attributes.h
@@ -286,6 +286,12 @@ public:
     return llvm::make_range(attr_value_iterator<AttrTy>(begin()),
                             attr_value_iterator<AttrTy>(end()));
   }
+  template <typename AttrTy, typename UnderlyingTy>
+  auto getAsRange() {
+    return llvm::map_range(getAsRange<AttrTy>(), [](AttrTy attr) {
+      return static_cast<UnderlyingTy>(attr.getValue());
+    });
+  }
 };
 
 //===----------------------------------------------------------------------===//
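The new two-parameter getAsRange overload lets callers iterate an ArrayAttr
directly as the underlying C++ type. A minimal usage sketch (variable names
invented), matching how Vectorization.cpp uses it below:

    // Iterate an ArrayAttr of AffineMapAttr as a range of AffineMap.
    ArrayAttr maps = genericOp.indexing_maps();
    for (AffineMap m : maps.getAsRange<AffineMapAttr, AffineMap>())
      llvm::errs() << "indexing map: " << m << "\n";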
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -36,8 +36,7 @@ using llvm::dbgs;
 
 #define DEBUG_TYPE "linalg-vectorization"
 
-static bool hasMultiplyAddBody(linalg::GenericOp op) {
-  auto &r = op.region();
+static bool hasMultiplyAddBody(Region &r) {
   if (!llvm::hasSingleElement(r))
     return false;
   if (!llvm::hasNItems(r.front().begin(), r.front().end(), 3))
@@ -59,14 +58,26 @@ static bool hasMultiplyAddBody(linalg::GenericOp op) {
 }
 
-// TODO: Should be Tablegen'd from a single source that generates the op itself.
-static bool isRowMajorMatmul(linalg::GenericOp genericOp) {
-  return genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
-         isRowMajorMatmul(genericOp.indexing_maps()) &&
-         hasMultiplyAddBody(genericOp);
+static LogicalResult isContraction(Operation *op) {
+  // TODO: interface for named ops.
+  if (isa<linalg::BatchMatmulOp, linalg::MatmulOp, linalg::MatvecOp,
+          linalg::DotOp>(op))
+    return success();
+
+  auto genericOp = dyn_cast<linalg::GenericOp>(op);
+  if (!genericOp)
+    return failure();
+
+  auto mapRange =
+      genericOp.indexing_maps().getAsRange<AffineMapAttr, AffineMap>();
+
+  return success(
+      genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
+      llvm::all_of(mapRange,
+                   [](AffineMap m) { return m.isProjectedPermutation(); }) &&
+      hasMultiplyAddBody(genericOp.region()));
 }
 
 // TODO: This is in fact much more general than just vectorization for matmul
 // and fill ops.
 LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   auto linalgOp = cast<linalg::LinalgOp>(op);
   // All types must be static shape to go to vector.
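A projected permutation is an affine map whose results are simply a subset of
the iteration dimensions, with no compound expressions. For a row-major matmul
expressed as linalg.generic, for example, all three indexing maps qualify:

    affine_map<(m, n, k) -> (m, k)>   // A
    affine_map<(m, n, k) -> (k, n)>   // B
    affine_map<(m, n, k) -> (m, n)>   // C

hasMultiplyAddBody then additionally constrains the region to exactly three
ops: a multiply, an add accumulating into the output, and the terminator.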
@@ -76,33 +87,16 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   for (Type outputTensorType : linalgOp.getOutputTensorTypes())
     if (!outputTensorType.cast<ShapedType>().hasStaticShape())
       return failure();
-  if (isa<linalg::MatmulOp, linalg::FillOp>(op))
+
+  if (isa<linalg::FillOp>(op))
     return success();
 
-  auto genericOp = dyn_cast<linalg::GenericOp>(op);
-  if (!genericOp || !::isRowMajorMatmul(genericOp))
-    return failure();
-
-  // TODO: non-identity layout.
-  auto isStaticMemRefWithIdentityLayout = [](Value v) {
-    auto m = v.getType().dyn_cast<MemRefType>();
-    if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty())
-      return false;
-    return true;
-  };
-  return success(llvm::all_of(genericOp.getInputsAndOutputBuffers(),
-                              isStaticMemRefWithIdentityLayout));
+  return isContraction(op);
 }
 
 void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
   assert(succeeded(vectorizeLinalgOpPrecondition(op)));
 
   if (auto convOp = dyn_cast<linalg::ConvOp>(op)) {
     // TODO: add a level of indirection to linalg.generic.
     if (convOp.padding())
       llvm_unreachable("Unexpected conv with padding");
   }
 
   StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
   (void)dbgPref;
   edsc::ScopedContext scope(builder, op->getLoc());
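Note that the static-shape requirement still gates everything: a matmul on
dynamically sized memrefs, e.g.

    linalg.matmul %A, %B, %C :
      (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)

fails vectorizeLinalgOpPrecondition and is left for other patterns to handle.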
@@ -117,33 +111,47 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
     return;
   }
 
-  // Vectorize other ops as vector contraction (currently only matmul).
+  assert(succeeded(isContraction(op)) && "Expected contraction");
+
+  // Vectorize other ops as vector contraction.
   // TODO: interface.
   LLVM_DEBUG(dbgs() << dbgPref
                     << "Rewrite linalg op as vector.contract: " << *op);
+  // In the case of 0-D memrefs, return null and special case to scalar load or
+  // store later.
   auto extractVectorTypeFromScalarView = [](Value v) {
     MemRefType mt = v.getType().cast<MemRefType>();
-    return VectorType::get(mt.getShape(), mt.getElementType());
+    return mt.getShape().empty()
+               ? VectorType()
+               : VectorType::get(mt.getShape(), mt.getElementType());
   };
   auto linalgOp = cast<linalg::LinalgOp>(op);
   Value viewA = linalgOp.getInput(0);
   Value viewB = linalgOp.getInput(1);
   Value viewC = linalgOp.getOutputBuffer(0);
+  VectorType vtA = extractVectorTypeFromScalarView(viewA);
+  VectorType vtB = extractVectorTypeFromScalarView(viewB);
+  VectorType vtC = extractVectorTypeFromScalarView(viewC);
   Value zero = std_constant_index(0);
-  SmallVector<Value, 4> indicesA(linalgOp.getInputShapedType(0).getRank(),
-                                 zero);
-  SmallVector<Value, 4> indicesB(linalgOp.getInputShapedType(1).getRank(),
-                                 zero);
-  SmallVector<Value, 4> indicesC(linalgOp.getOutputShapedType(0).getRank(),
-                                 zero);
-  Value a = vector_transfer_read(extractVectorTypeFromScalarView(viewA), viewA,
-                                 indicesA);
-  Value b = vector_transfer_read(extractVectorTypeFromScalarView(viewB), viewB,
-                                 indicesB);
-  Value c = vector_transfer_read(extractVectorTypeFromScalarView(viewC), viewC,
-                                 indicesC);
+  SmallVector<Value, 4> indicesA, indicesB, indicesC;
+  if (vtA)
+    indicesA = SmallVector<Value, 4>(vtA.getRank(), zero);
+  if (vtB)
+    indicesB = SmallVector<Value, 4>(vtB.getRank(), zero);
+  if (vtC)
+    indicesC = SmallVector<Value, 4>(vtC.getRank(), zero);
+  Value a = vtA ? vector_transfer_read(vtA, viewA, indicesA).value
+                : std_load(viewA, indicesA).value;
+  Value b = vtB ? vector_transfer_read(vtB, viewB, indicesB).value
+                : std_load(viewB, indicesB).value;
+  Value c = vtC ? vector_transfer_read(vtC, viewC, indicesC).value
+                : std_load(viewC, indicesC).value;
   Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
                               linalgOp.iterator_types());
-  vector_transfer_write(res, viewC, indicesC);
+  if (vtC)
+    vector_transfer_write(res, viewC, indicesC);
+  else
+    std_store(res, viewC, indicesC);
 }
 
 /// Check whether there is any interleaved use of any `values` between `firstOp`
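For a 0-D output such as the memref<f32> result of linalg.dot (tested below),
extractVectorTypeFromScalarView returns a null VectorType, so the accumulator
is accessed with std.load/std.store rather than vector.transfer. Roughly, as a
hand-written sketch (not actual pass output):

    %c0 = constant 0 : index
    %f0 = constant 0.000000e+00 : f32
    %a = vector.transfer_read %A[%c0], %f0 : memref<1584xf32>, vector<1584xf32>
    %b = vector.transfer_read %B[%c0], %f0 : memref<1584xf32>, vector<1584xf32>
    %c = load %C[] : memref<f32>
    %d = vector.contract {indexing_maps = [affine_map<(k) -> (k)>,
                                           affine_map<(k) -> (k)>,
                                           affine_map<(k) -> ()>],
                          iterator_types = ["reduction"]}
           %a, %b, %c : vector<1584xf32>, vector<1584xf32> into f32
    store %d, %C[] : memref<f32>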
mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s
 // RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-contraction-to-vector-patterns | FileCheck %s --check-prefix=VECTOR-CONTRACTION
 
 func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
              %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
@@ -30,3 +31,38 @@ func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
 // CHECK-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
 //
 // CHECK: linalg.copy
+
+// VECTOR-CONTRACTION-LABEL: contraction_dot
+func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
+  // VECTOR-CONTRACTION: vector.contract
+  // VECTOR-CONTRACTION-SAME: vector<1584xf32>, vector<1584xf32> into f32
+  linalg.dot(%A, %B, %C) : memref<1584xf32>, memref<1584xf32>, memref<f32>
+  return
+}
+
+// VECTOR-CONTRACTION-LABEL: contraction_matvec
+func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) {
+  // VECTOR-CONTRACTION: vector.contract
+  // VECTOR-CONTRACTION-SAME: vector<1584x1584xf32>, vector<1584xf32> into vector<1584xf32>
+  linalg.matvec %A, %B, %C :
+    (memref<1584x1584xf32>, memref<1584xf32>, memref<1584xf32>)
+  return
+}
+
+// VECTOR-CONTRACTION-LABEL: contraction_matmul
+func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
+  // VECTOR-CONTRACTION: vector.contract
+  // VECTOR-CONTRACTION-SAME: vector<1584x1584xf32>, vector<1584x1584xf32> into vector<1584x1584xf32>
+  linalg.matmul %A, %B, %C :
+    (memref<1584x1584xf32>, memref<1584x1584xf32>, memref<1584x1584xf32>)
+  return
+}
+
+// VECTOR-CONTRACTION-LABEL: contraction_batch_matmul
+func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) {
+  // VECTOR-CONTRACTION: vector.contract
+  // VECTOR-CONTRACTION-SAME: vector<1584x1584x1584xf32>, vector<1584x1584x1584xf32> into vector<1584x1584x1584xf32>
+  linalg.batch_matmul %A, %B, %C :
+    (memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>)
+  return
+}
mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -54,6 +54,11 @@ struct TestLinalgTransforms
       llvm::cl::desc(
           "Test a fused pass that forwards linalg.copy to vector.transfer"),
       llvm::cl::init(false)};
+  Option<bool> testGenericToVectorPattern{
+      *this, "test-contraction-to-vector-patterns",
+      llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
+                     "in vector.contract form"),
+      llvm::cl::init(false)};
 };
 } // end anonymous namespace
 
@@ -300,6 +305,16 @@ static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
   applyPatternsAndFoldGreedily(funcOp, forwardPattern);
 }
 
+static void applyContractionToVectorPatterns(FuncOp funcOp) {
+  OwningRewritePatternList patterns;
+  patterns.insert<LinalgVectorizationPattern<BatchMatmulOp>,
+                  LinalgVectorizationPattern<MatmulOp>,
+                  LinalgVectorizationPattern<MatvecOp>,
+                  LinalgVectorizationPattern<DotOp>,
+                  LinalgVectorizationPattern<GenericOp>>(funcOp.getContext());
+  applyPatternsAndFoldGreedily(funcOp, patterns);
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnFunction() {
   auto lambda = [&](void *) {
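Each op kind gets its own LinalgVectorizationPattern instantiation, so covering
an additional named contraction would be a one-line change; e.g. (hypothetical
op name, for illustration only):

    patterns.insert<LinalgVectorizationPattern<VecmatOp>>(funcOp.getContext());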
@@ -323,6 +338,8 @@ void TestLinalgTransforms::runOnFunction() {
                            testMatmulToVectorPatterns2dTiling);
   if (testVectorTransferForwardingPatterns)
     return applyVectorTransferForwardingPatterns(getFunction());
+  if (testGenericToVectorPattern)
+    return applyContractionToVectorPatterns(getFunction());
 }
 
 namespace mlir {