Make vectorization aware of loop semantics

Now that we have a dependence analysis, we can check that loops are indeed parallel and make vectorization correct.

PiperOrigin-RevId: 240682727
Nicolas Vasilache 2019-03-27 17:50:34 -07:00 committed by jpienaar
parent 21547ace87
commit 4dc7af9da8
5 changed files with 94 additions and 74 deletions
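In short, the pass now walks the function once, records every affine.for that the dependence analysis proves parallel, and only lets loops from that set reach the vectorization patterns; the existing structural checks (no nested conditionals, scalar load/stores only) still apply on top. Below is a minimal sketch of that gating, assuming the 2019-era MLIR APIs exercised in this diff (isLoopParallel from Analysis/Utils.h, isVectorizableLoopBody from Analysis/LoopAnalysis.h); the helpers collectParallelLoops and isVectorizationCandidate are illustrative names, not part of the patch.

// A minimal sketch, not the literal patch: gate vectorization on the new
// dependence analysis by (1) collecting every affine.for that is provably
// parallel and (2) only treating loops from that set as candidates.
// Include paths follow the 2019-era MLIR tree this commit targets.
#include "mlir/AffineOps/AffineOps.h"    // AffineForOp
#include "mlir/Analysis/LoopAnalysis.h"  // isVectorizableLoopBody
#include "mlir/Analysis/Utils.h"         // isLoopParallel
#include "mlir/IR/Function.h"
#include "llvm/ADT/DenseSet.h"

using namespace mlir;

// Walk the function once and record the loops the dependence analysis proves
// parallel (this mirrors the walk added to Vectorize::runOnFunction below).
static llvm::DenseSet<Operation *> collectParallelLoops(Function &f) {
  llvm::DenseSet<Operation *> parallelLoops;
  f.walkPostOrder([&parallelLoops](Operation *op) {
    if (auto loop = op->dyn_cast<AffineForOp>())
      if (isLoopParallel(loop))
        parallelLoops.insert(op);
  });
  return parallelLoops;
}

// A loop is a vectorization candidate only if it is parallel *and* its body
// is structurally vectorizable; both checks appear in the diff below.
static bool isVectorizationCandidate(
    const llvm::DenseSet<Operation *> &parallelLoops, AffineForOp loop) {
  return parallelLoops.count(loop.getOperation()) &&
         isVectorizableLoopBody(loop);
}

Computing the parallel-loop set up front keeps the per-pattern filter (isVectorizableLoopPtrFactory in the Vectorize.cpp hunk) a cheap set lookup rather than re-running the dependence analysis for every match.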

View File

@@ -85,20 +85,20 @@ bool isAccessInvariant(Value &iv, Value &index);
llvm::DenseSet<Value *, llvm::DenseMapInfo<Value *>>
getInvariantAccesses(Value &iv, llvm::ArrayRef<Value *> indices);
using VectorizableLoopFun = std::function<bool(AffineForOp)>;
/// Checks whether the loop is structurally vectorizable; i.e.:
/// 1. the loop has proper dependence semantics (parallel, reduction, etc);
/// 2. no conditionals are nested under the loop;
/// 3. all nested load/stores are to scalar MemRefs.
/// TODO(ntv): implement dependence semantics
/// 1. no conditionals are nested under the loop;
/// 2. all nested load/stores are to scalar MemRefs.
/// TODO(ntv): relax the no-conditionals restriction
bool isVectorizableLoop(AffineForOp loop);
bool isVectorizableLoopBody(AffineForOp loop);
/// Checks whether the loop is structurally vectorizable and that all the LoadOp
/// and StoreOp matched have access indexing functions that are either:
/// 1. invariant along the loop induction variable created by 'loop';
/// 2. varying along the 'fastestVaryingDim' memory dimension.
bool isVectorizableLoopAlongFastestVaryingMemRefDim(AffineForOp loop,
unsigned fastestVaryingDim);
bool isVectorizableLoopBodyAlongFastestVaryingMemRefDim(
AffineForOp loop, unsigned fastestVaryingDim);
/// Checks whether SSA dominance would be violated if a for inst's body
/// operations are shifted by the specified shifts. This method checks if a

View File

@@ -274,20 +274,17 @@ static bool isVectorTransferReadOrWrite(Operation &op) {
return op.isa<VectorTransferReadOp>() || op.isa<VectorTransferWriteOp>();
}
using VectorizableInstFun = std::function<bool(AffineForOp, Operation &)>;
using VectorizableOpFun = std::function<bool(AffineForOp, Operation &)>;
static bool isVectorizableLoopWithCond(AffineForOp loop,
VectorizableInstFun isVectorizableInst) {
auto *forInst = loop.getOperation();
if (!matcher::isParallelLoop(*forInst) &&
!matcher::isReductionLoop(*forInst)) {
return false;
}
static bool
isVectorizableLoopBodyWithOpCond(AffineForOp loop,
VectorizableOpFun isVectorizableOp) {
auto *forOp = loop.getOperation();
// No vectorization across conditionals for now.
auto conditionals = matcher::If();
SmallVector<NestedMatch, 8> conditionalsMatched;
conditionals.match(forInst, &conditionalsMatched);
conditionals.match(forOp, &conditionalsMatched);
if (!conditionalsMatched.empty()) {
return false;
}
@@ -298,21 +295,21 @@ static bool isVectorizableLoopWithCond(AffineForOp loop,
!(op.isa<AffineIfOp>() || op.isa<AffineForOp>());
});
SmallVector<NestedMatch, 8> regionsMatched;
regions.match(forInst, &regionsMatched);
regions.match(forOp, &regionsMatched);
if (!regionsMatched.empty()) {
return false;
}
auto vectorTransfers = matcher::Op(isVectorTransferReadOrWrite);
SmallVector<NestedMatch, 8> vectorTransfersMatched;
vectorTransfers.match(forInst, &vectorTransfersMatched);
vectorTransfers.match(forOp, &vectorTransfersMatched);
if (!vectorTransfersMatched.empty()) {
return false;
}
auto loadAndStores = matcher::Op(matcher::isLoadOrStore);
SmallVector<NestedMatch, 8> loadAndStoresMatched;
loadAndStores.match(forInst, &loadAndStoresMatched);
loadAndStores.match(forOp, &loadAndStoresMatched);
for (auto ls : loadAndStoresMatched) {
auto *op = ls.getMatchedOperation();
auto load = op->dyn_cast<LoadOp>();
@@ -324,16 +321,16 @@ static bool isVectorizableLoopWithCond(AffineForOp loop,
if (vector) {
return false;
}
if (!isVectorizableInst(loop, *op)) {
if (isVectorizableOp && !isVectorizableOp(loop, *op)) {
return false;
}
}
return true;
}
bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim(
bool mlir::isVectorizableLoopBodyAlongFastestVaryingMemRefDim(
AffineForOp loop, unsigned fastestVaryingDim) {
VectorizableInstFun fun([fastestVaryingDim](AffineForOp loop, Operation &op) {
VectorizableOpFun fun([fastestVaryingDim](AffineForOp loop, Operation &op) {
auto load = op.dyn_cast<LoadOp>();
auto store = op.dyn_cast<StoreOp>();
return load ? isContiguousAccess(*loop.getInductionVar(), load,
@@ -341,14 +338,11 @@ bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim(
: isContiguousAccess(*loop.getInductionVar(), store,
fastestVaryingDim);
});
return isVectorizableLoopWithCond(loop, fun);
return isVectorizableLoopBodyWithOpCond(loop, fun);
}
bool mlir::isVectorizableLoop(AffineForOp loop) {
VectorizableInstFun fun(
// TODO: implement me
[](AffineForOp loop, Operation &op) { return true; });
return isVectorizableLoopWithCond(loop, fun);
bool mlir::isVectorizableLoopBody(AffineForOp loop) {
return isVectorizableLoopBodyWithOpCond(loop, nullptr);
}
/// Checks whether SSA dominance would be violated if a for op's body

View File

@@ -153,18 +153,6 @@ NestedPattern For(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
nested, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
}
// TODO(ntv): parallel annotation on loops.
bool isParallelLoop(Operation &op) {
auto loop = op.cast<AffineForOp>();
return loop || true; // loop->isParallel();
};
// TODO(ntv): reduction annotation on loops.
bool isReductionLoop(Operation &op) {
auto loop = op.cast<AffineForOp>();
return loop || true; // loop->isReduction();
};
bool isLoadOrStore(Operation &op) {
return op.isa<LoadOp>() || op.isa<StoreOp>();
};

View File

@@ -24,6 +24,7 @@
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/NestedMatcher.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/Analysis/VectorAnalysis.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
@@ -565,7 +566,8 @@ static llvm::cl::list<int> clFastestVaryingPattern(
/// Forward declaration.
static FilterFunctionType
isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension);
isVectorizableLoopPtrFactory(const llvm::DenseSet<Operation *> &parallelLoops,
unsigned fastestVaryingMemRefDimension);
// Build a bunch of predetermined patterns that will be traversed in order.
// Due to the recursive nature of NestedPatterns, this captures
@@ -573,77 +575,84 @@ isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension);
/// Note that this currently only matches 2 nested loops and will be extended.
// TODO(ntv): support 3-D loop patterns with a common reduction loop that can
// be matched to GEMMs.
static std::vector<NestedPattern> defaultPatterns() {
static std::vector<NestedPattern>
defaultPatterns(const llvm::DenseSet<Operation *> &parallelLoops) {
using matcher::For;
return std::vector<NestedPattern>{
// 3-D patterns
For(isVectorizableLoopPtrFactory(2),
For(isVectorizableLoopPtrFactory(1),
For(isVectorizableLoopPtrFactory(0)))),
For(isVectorizableLoopPtrFactory(parallelLoops, 2),
For(isVectorizableLoopPtrFactory(parallelLoops, 1),
For(isVectorizableLoopPtrFactory(parallelLoops, 0)))),
// for i { for j { A[??f(not i, not j), f(i, not j), f(not i, j)];}}
// test independently with:
// --test-fastest-varying=1 --test-fastest-varying=0
For(isVectorizableLoopPtrFactory(1),
For(isVectorizableLoopPtrFactory(0))),
For(isVectorizableLoopPtrFactory(parallelLoops, 1),
For(isVectorizableLoopPtrFactory(parallelLoops, 0))),
// for i { for j { A[??f(not i, not j), f(i, not j), ?, f(not i, j)];}}
// test independently with:
// --test-fastest-varying=2 --test-fastest-varying=0
For(isVectorizableLoopPtrFactory(2),
For(isVectorizableLoopPtrFactory(0))),
For(isVectorizableLoopPtrFactory(parallelLoops, 2),
For(isVectorizableLoopPtrFactory(parallelLoops, 0))),
// for i { for j { A[??f(not i, not j), f(i, not j), ?, ?, f(not i, j)];}}
// test independently with:
// --test-fastest-varying=3 --test-fastest-varying=0
For(isVectorizableLoopPtrFactory(3),
For(isVectorizableLoopPtrFactory(0))),
For(isVectorizableLoopPtrFactory(parallelLoops, 3),
For(isVectorizableLoopPtrFactory(parallelLoops, 0))),
// for i { for j { A[??f(not i, not j), f(not i, j), f(i, not j)];}}
// test independently with:
// --test-fastest-varying=0 --test-fastest-varying=1
For(isVectorizableLoopPtrFactory(0),
For(isVectorizableLoopPtrFactory(1))),
For(isVectorizableLoopPtrFactory(parallelLoops, 0),
For(isVectorizableLoopPtrFactory(parallelLoops, 1))),
// for i { for j { A[??f(not i, not j), f(not i, j), ?, f(i, not j)];}}
// test independently with:
// --test-fastest-varying=0 --test-fastest-varying=2
For(isVectorizableLoopPtrFactory(0),
For(isVectorizableLoopPtrFactory(2))),
For(isVectorizableLoopPtrFactory(parallelLoops, 0),
For(isVectorizableLoopPtrFactory(parallelLoops, 2))),
// for i { for j { A[??f(not i, not j), f(not i, j), ?, ?, f(i, not j)];}}
// test independently with:
// --test-fastest-varying=0 --test-fastest-varying=3
For(isVectorizableLoopPtrFactory(0),
For(isVectorizableLoopPtrFactory(3))),
For(isVectorizableLoopPtrFactory(parallelLoops, 0),
For(isVectorizableLoopPtrFactory(parallelLoops, 3))),
// for i { A[??f(not i) , f(i)];}
// test independently with: --test-fastest-varying=0
For(isVectorizableLoopPtrFactory(0)),
For(isVectorizableLoopPtrFactory(parallelLoops, 0)),
// for i { A[??f(not i) , f(i), ?];}
// test independently with: --test-fastest-varying=1
For(isVectorizableLoopPtrFactory(1)),
For(isVectorizableLoopPtrFactory(parallelLoops, 1)),
// for i { A[??f(not i) , f(i), ?, ?];}
// test independently with: --test-fastest-varying=2
For(isVectorizableLoopPtrFactory(2)),
For(isVectorizableLoopPtrFactory(parallelLoops, 2)),
// for i { A[??f(not i) , f(i), ?, ?, ?];}
// test independently with: --test-fastest-varying=3
For(isVectorizableLoopPtrFactory(3))};
For(isVectorizableLoopPtrFactory(parallelLoops, 3))};
}
/// Creates a vectorization pattern from the command line arguments.
/// Up to 3-D patterns are supported.
/// If the command line argument requests a pattern of higher order, returns an
/// empty pattern list which will conservatively result in no vectorization.
static std::vector<NestedPattern> makePatterns() {
static std::vector<NestedPattern>
makePatterns(const llvm::DenseSet<Operation *> &parallelLoops) {
using matcher::For;
if (clFastestVaryingPattern.empty()) {
return defaultPatterns();
return defaultPatterns(parallelLoops);
}
switch (clFastestVaryingPattern.size()) {
case 1:
return {For(isVectorizableLoopPtrFactory(clFastestVaryingPattern[0]))};
return {For(isVectorizableLoopPtrFactory(parallelLoops,
clFastestVaryingPattern[0]))};
case 2:
return {For(isVectorizableLoopPtrFactory(clFastestVaryingPattern[0]),
For(isVectorizableLoopPtrFactory(clFastestVaryingPattern[1])))};
return {For(
isVectorizableLoopPtrFactory(parallelLoops, clFastestVaryingPattern[0]),
For(isVectorizableLoopPtrFactory(parallelLoops,
clFastestVaryingPattern[1])))};
case 3:
return {For(
isVectorizableLoopPtrFactory(clFastestVaryingPattern[0]),
For(isVectorizableLoopPtrFactory(clFastestVaryingPattern[1]),
For(isVectorizableLoopPtrFactory(clFastestVaryingPattern[2]))))};
isVectorizableLoopPtrFactory(parallelLoops, clFastestVaryingPattern[0]),
For(isVectorizableLoopPtrFactory(parallelLoops,
clFastestVaryingPattern[1]),
For(isVectorizableLoopPtrFactory(parallelLoops,
clFastestVaryingPattern[2]))))};
default:
return std::vector<NestedPattern>();
}
@@ -905,10 +914,14 @@ static LogicalResult vectorizeAffineForOp(AffineForOp loop, int64_t step,
/// once we understand better the performance implications and we are confident
/// we can build a cost model and a search procedure.
static FilterFunctionType
isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) {
return [fastestVaryingMemRefDimension](Operation &forOp) {
isVectorizableLoopPtrFactory(const llvm::DenseSet<Operation *> &parallelLoops,
unsigned fastestVaryingMemRefDimension) {
return [&parallelLoops, fastestVaryingMemRefDimension](Operation &forOp) {
auto loop = forOp.cast<AffineForOp>();
return isVectorizableLoopAlongFastestVaryingMemRefDim(
auto parallelIt = parallelLoops.find(loop);
if (parallelIt == parallelLoops.end())
return false;
return isVectorizableLoopBodyAlongFastestVaryingMemRefDim(
loop, fastestVaryingMemRefDimension);
};
}
@@ -1168,7 +1181,7 @@ static LogicalResult vectorizeRootMatch(NestedMatch m,
// vectorizable. If a pattern is not vectorizable anymore, we just skip it.
// TODO(ntv): implement a non-greedy profitability analysis that keeps only
// non-intersecting patterns.
if (!isVectorizableLoop(loop)) {
if (!isVectorizableLoopBody(loop)) {
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ loop is not vectorizable");
return failure();
}
@@ -1240,7 +1253,16 @@ void Vectorize::runOnFunction() {
NestedPatternContext mlContext;
Function &f = getFunction();
for (auto &pat : makePatterns()) {
llvm::DenseSet<Operation *> parallelLoops;
f.walkPostOrder([&parallelLoops](Operation *op) {
if (auto loop = op->dyn_cast<AffineForOp>()) {
if (isLoopParallel(loop)) {
parallelLoops.insert(op);
}
}
});
for (auto &pat : makePatterns(parallelLoops)) {
LLVM_DEBUG(dbgs() << "\n******************************************");
LLVM_DEBUG(dbgs() << "\n******************************************");
LLVM_DEBUG(dbgs() << "\n[early-vect] new pattern on Function\n");

View File

@@ -11,6 +11,7 @@
#set0 = (i) : (i >= 0)
// Maps introduced to vectorize fastest varying memory index.
// CHECK-LABEL: vec1d
func @vec1d(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
// CHECK-DAG: [[C0:%[a-z0-9_]+]] = constant 0 : index
// CHECK-DAG: [[ARG_M:%[0-9]+]] = dim %arg0, 0 : memref<?x?xf32>
@@ -133,6 +134,7 @@ func @vec1d(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
return
}
// CHECK-LABEL: vector_add_2d
func @vector_add_2d(%M : index, %N : index) -> f32 {
%A = alloc (%M, %N) : memref<?x?xf32, 0>
%B = alloc (%M, %N) : memref<?x?xf32, 0>
@@ -201,3 +203,17 @@ func @vec_rejected(%A : memref<?x?xf32>, %C : memref<?x?xf32>) {
}
return
}
// This should not vectorize due to the sequential dependence in the loop.
// CHECK-LABEL: @vec_rejected_sequential
func @vec_rejected_sequential(%A : memref<?xf32>) {
%N = dim %A, 0 : memref<?xf32>
affine.for %i = 0 to %N {
// CHECK-NOT: vector
%a = load %A[%i] : memref<?xf32>
// CHECK-NOT: vector
%ip1 = affine.apply (d0)->(d0 + 1) (%i)
store %a, %A[%ip1] : memref<?xf32>
}
return
}