[mlir][Linalg] Add loop.parallel lowering for all Linalg Ops.

The outer parallel loops of a linalg operation are lowered to
loop.parallel, with the remaining loops lowered to loop.for. This brings
the loop.parallel lowering on par with the loop.for lowering. In the
future, the reduction loops could also be lowered to loop.parallel.
Also add a utility function that returns the loops that are
created.

Differential Revision: https://reviews.llvm.org/D77678
MaheshRavishankar 2020-04-13 09:33:34 -07:00
parent dffbeffa39
commit 03391df90e
4 changed files with 780 additions and 369 deletions
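
As a side note, here is a minimal sketch of how the new utility might be called from a rewrite pattern. This caller is hypothetical and not part of the commit; the header paths and the lowerToParallelAndReport name are assumptions, while linalgLowerOpToLoops, LinalgLoops, and loop::ParallelOp come from the diffs below.

// Hypothetical caller; header paths assumed for this era of MLIR.
#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h"
#include "mlir/Dialect/LoopOps/LoopOps.h"

using namespace mlir;

template <typename ConcreteOp>
LogicalResult lowerToParallelAndReport(PatternRewriter &rewriter,
                                       Operation *op) {
  // Returns the generated loop ops: the outer parallel dims become one
  // loop.parallel, the remaining dims become loop.for.
  Optional<linalg::LinalgLoops> loops =
      linalg::linalgLowerOpToLoops<loop::ParallelOp, ConcreteOp>(rewriter, op);
  if (!loops)
    return failure();
  for (Operation *loop : *loops)
    loop->emitRemark("generated by linalg lowering");
  return success();
}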


@@ -70,6 +70,13 @@ LogicalResult tileAndFuseLinalgOpAndSetMarker(
PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker);
using LinalgLoops = SmallVector<Operation *, 4>;
/// Emits a loop nest with the proper body for `op`.
template <typename LoopTy, typename ConcreteOp>
Optional<LinalgLoops> linalgLowerOpToLoops(PatternRewriter &rewriter,
Operation *op);
/// Emits a loop nest of `loop.for` with the proper body for `op`.
template <typename ConcreteOp>
LogicalResult linalgOpToLoops(PatternRewriter &rewriter, Operation *op);


@@ -533,26 +533,111 @@ public:
// consequence, (1) it is only allowed to emit new ops if the match is
// guaranteed to be a success, (2) it is not allowed to erase/replace, and (3)
// encompassing pattern must take care of the erasure logic.
template <typename LoopTy, typename IndexedValueTy, typename ConcreteOpTy>
template <typename LoopTy, typename ConcreteOpTy>
class LinalgOpToLoopsImpl {
public:
static LogicalResult doit(Operation *op, PatternRewriter &rewriter);
static Optional<LinalgLoops> doit(Operation *op, PatternRewriter &rewriter);
};
template <typename LoopTy>
bool loweringIsAllowed(int numParallelLoops, int numLoops) {
return true;
}
template <>
bool loweringIsAllowed<loop::ParallelOp>(int numParallelLoops, int numLoops) {
return numParallelLoops == numLoops;
}
namespace {
/// Helper struct to generate the loop nest for the op. This is factored out
/// here to allow partial specialization for different LoopTy.
template <typename LoopTy, typename ConcreteOpTy>
class GenerateLoopNest {
public:
using IndexedValueTy =
typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
AffineIndexedValue, StdIndexedValue>::type;
static void doit(ConcreteOpTy linalgOp, ArrayRef<Value> loopRanges,
MutableArrayRef<ValueHandle> allIvs) {
SmallVector<ValueHandle *, 4> allPIvs =
makeHandlePointers(MutableArrayRef<ValueHandle>(allIvs));
template <typename LoopTy, typename IndexedValueTy, typename ConcreteOpTy>
LogicalResult LinalgOpToLoopsImpl<LoopTy, IndexedValueTy, ConcreteOpTy>::doit(
Operation *op, PatternRewriter &rewriter) {
OpBuilder b(op);
ScopedContext scope(b, op->getLoc());
GenericLoopNestRangeBuilder<LoopTy>(allPIvs, loopRanges)([&] {
SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
LinalgScopedEmitter<IndexedValueTy,
ConcreteOpTy>::emitScalarImplementation(allIvValues,
linalgOp);
});
}
};
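
As an aside, the dispatch above is plain C++ class-template partial specialization: the primary GenerateLoopNest handles the affine.for/loop.for cases, and the specialization that follows overrides the behavior when LoopTy is loop.parallel. A minimal standalone sketch of the same pattern, with illustrative names unrelated to the MLIR types:

#include <iostream>

// Primary template: the generic code path.
template <typename LoopTy, typename OpTy>
struct NestBuilder {
  static void doit() { std::cout << "generic loop nest\n"; }
};

struct ParallelLoopTag {};

// Partial specialization: LoopTy is pinned, OpTy stays generic, so only
// the parallel case takes the special code path.
template <typename OpTy>
struct NestBuilder<ParallelLoopTag, OpTy> {
  static void doit() { std::cout << "parallel loop nest\n"; }
};

struct MatmulTag {};

int main() {
  NestBuilder<int, MatmulTag>::doit();             // generic loop nest
  NestBuilder<ParallelLoopTag, MatmulTag>::doit(); // parallel loop nest
}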
/// Generates a loop nest using loop.parallel. loop.parallel is only used for
/// the outer parallel loops. All other loops are generated using the loop.for
/// operation.
template <typename ConcreteOpTy>
class GenerateLoopNest<loop::ParallelOp, ConcreteOpTy> {
public:
using IndexedValueTy = StdIndexedValue;
static void doit(ConcreteOpTy linalgOp, ArrayRef<Value> loopRanges,
MutableArrayRef<ValueHandle> allIvs) {
// Only generate loop.parallel for outer consecutive "parallel"
// iterator_types.
// TODO(ravishankarm): Generate loop.parallel for all "parallel" iterator
// types, not just the outermost ones. Also handle "reduction" iterator
// types.
auto nPar = linalgOp.getNumParallelLoops();
auto nRed = linalgOp.getNumReductionLoops();
auto nWin = linalgOp.getNumWindowLoops();
auto nLoops = nPar + nRed + nWin;
auto nOuterPar = linalgOp.iterator_types()
.getValue()
.take_while([](Attribute attr) {
return attr.cast<StringAttr>().getValue() ==
getParallelIteratorTypeName();
})
.size();
// If there are no outer parallel loops, then the number of loop ops is the
// same as the number of loops, and they are all loop.for ops.
auto nLoopOps = (nOuterPar ? nLoops - nOuterPar + 1 : nLoops);
SmallVector<ValueHandle *, 4> allPIvs =
makeHandlePointers(MutableArrayRef<ValueHandle>(allIvs));
SmallVector<OperationHandle, 4> allLoops(nLoopOps, OperationHandle());
SmallVector<OperationHandle *, 4> allPLoops;
allPLoops.reserve(allLoops.size());
for (OperationHandle &loop : allLoops)
allPLoops.push_back(&loop);
ArrayRef<ValueHandle *> allPIvsRef(allPIvs);
ArrayRef<OperationHandle *> allPLoopsRef(allPLoops);
if (nOuterPar) {
GenericLoopNestRangeBuilder<loop::ParallelOp>(
allPIvsRef.take_front(nOuterPar),
loopRanges.take_front(nOuterPar))([&] {
GenericLoopNestRangeBuilder<loop::ForOp>(
allPIvsRef.drop_front(nOuterPar),
loopRanges.drop_front(nOuterPar))([&] {
SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
LinalgScopedEmitter<StdIndexedValue, ConcreteOpTy>::
emitScalarImplementation(allIvValues, linalgOp);
});
});
} else {
// If there are no parallel loops, then fall back to generating all loop.for
// operations.
GenericLoopNestRangeBuilder<loop::ForOp>(allPIvsRef, loopRanges)([&] {
SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
LinalgScopedEmitter<StdIndexedValue,
ConcreteOpTy>::emitScalarImplementation(allIvValues,
linalgOp);
});
}
}
};
} // namespace
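
The take_while computation above amounts to counting the leading "parallel" entries and folding them into a single loop.parallel op. A minimal standalone sketch of that arithmetic (plain C++, no MLIR types; the values match the test case at the end of this commit):

#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Count leading "parallel" iterator types; any parallel dimension after
// the first non-parallel one is still lowered with loop.for.
static size_t countOuterParallel(const std::vector<std::string> &iters) {
  size_t n = 0;
  while (n < iters.size() && iters[n] == "parallel")
    ++n;
  return n;
}

int main() {
  const std::vector<std::string> iters = {"parallel", "parallel", "reduction",
                                          "parallel"};
  const size_t nLoops = iters.size();
  const size_t nOuterPar = countOuterParallel(iters);
  // A single loop.parallel covers all outer parallel dims, hence the +1.
  const size_t nLoopOps = nOuterPar ? nLoops - nOuterPar + 1 : nLoops;
  assert(nOuterPar == 2 && nLoopOps == 3); // one loop.parallel + two loop.for
  (void)nLoopOps;
}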
template <typename LoopTy, typename ConcreteOpTy>
Optional<LinalgLoops>
LinalgOpToLoopsImpl<LoopTy, ConcreteOpTy>::doit(Operation *op,
PatternRewriter &rewriter) {
using Impl = GenerateLoopNest<LoopTy, ConcreteOpTy>;
using IndexedValueTy =
typename GenerateLoopNest<LoopTy, ConcreteOpTy>::IndexedValueTy;
ScopedContext scope(rewriter, op->getLoc());
// The flattened loopToOperandRangesMaps is expected to be an invertible
// permutation map (which is asserted in the inverse calculation).
@@ -563,8 +648,6 @@ LogicalResult LinalgOpToLoopsImpl<LoopTy, IndexedValueTy, ConcreteOpTy>::doit(
auto nRed = linalgOp.getNumReductionLoops();
auto nWin = linalgOp.getNumWindowLoops();
auto nLoops = nPar + nRed + nWin;
if (!loweringIsAllowed<LoopTy>(nPar, nLoops))
return failure();
auto mapsRange =
linalgOp.indexing_maps().template getAsRange<AffineMapAttr>();
auto maps =
@@ -573,25 +656,34 @@ LogicalResult LinalgOpToLoopsImpl<LoopTy, IndexedValueTy, ConcreteOpTy>::doit(
if (!invertedMap) {
LinalgScopedEmitter<IndexedValueTy, ConcreteOpTy>::emitScalarImplementation(
{}, linalgOp);
return success();
return LinalgLoops();
}
SmallVector<ValueHandle, 4> allIvs(nLoops, ValueHandle(b.getIndexType()));
SmallVector<ValueHandle *, 4> allPIvs =
makeHandlePointers(MutableArrayRef<ValueHandle>(allIvs));
auto loopRanges = emitLoopRanges(scope.getBuilder(), scope.getLocation(),
invertedMap, getViewSizes(b, linalgOp));
SmallVector<ValueHandle, 4> allIvs(nLoops,
ValueHandle(rewriter.getIndexType()));
auto loopRanges =
emitLoopRanges(scope.getBuilder(), scope.getLocation(), invertedMap,
getViewSizes(rewriter, linalgOp));
assert(loopRanges.size() == allIvs.size());
GenericLoopNestRangeBuilder<LoopTy>(allPIvs, loopRanges)([&] {
SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
LinalgScopedEmitter<IndexedValueTy, ConcreteOpTy>::emitScalarImplementation(
allIvValues, linalgOp);
});
return success();
Impl::doit(linalgOp, loopRanges, allIvs);
// The number of loop ops might differ from the number of ivs since some
// loops, like affine.parallel and loop.parallel, have multiple ivs.
llvm::SetVector<Operation *> loopSet;
for (ValueHandle &iv : allIvs) {
if (!iv.hasValue())
return {};
// The induction variable is a block argument of the entry block of the
// loop operation.
BlockArgument ivVal = iv.getValue().dyn_cast<BlockArgument>();
if (!ivVal)
return {};
loopSet.insert(ivVal.getOwner()->getParentOp());
}
LinalgLoops loops(loopSet.begin(), loopSet.end());
return loops;
}
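
The loop-collection step above works because every induction variable is a block argument of its loop's body block, and llvm::SetVector deduplicates the owners (a single loop.parallel op owns one iv per parallel dimension). A restatement of just that step as a standalone helper, using the same MLIR APIs as the code above:

#include "llvm/ADT/SetVector.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"

using namespace mlir;

// Map induction variables back to the unique loop ops that own them.
static llvm::SetVector<Operation *> loopsFromIvs(ArrayRef<Value> ivs) {
  llvm::SetVector<Operation *> loops;
  for (Value iv : ivs)
    if (auto ivVal = iv.dyn_cast<BlockArgument>())
      loops.insert(ivVal.getOwner()->getParentOp());
  return loops;
}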
template <typename LoopType, typename IndexedValueType, typename ConcreteOp>
template <typename LoopType, typename ConcreteOp>
class LinalgRewritePattern : public RewritePattern {
public:
explicit LinalgRewritePattern(MLIRContext *context)
@@ -599,8 +691,8 @@ public:
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
using Impl = LinalgOpToLoopsImpl<LoopType, IndexedValueType, ConcreteOp>;
if (failed(Impl::doit(op, rewriter)))
using Impl = LinalgOpToLoopsImpl<LoopType, ConcreteOp>;
if (!Impl::doit(op, rewriter))
return failure();
rewriter.eraseOp(op);
return success();
@@ -608,32 +700,28 @@ public:
};
// Helper classes for type list expansion.
template <typename LoopType, typename IndexedValueType, typename... LinalgOps>
template <typename LoopType, typename... LinalgOps>
class RewritePatternList;
template <typename LoopType, typename IndexedValueType>
class RewritePatternList<LoopType, IndexedValueType> {
template <typename LoopType>
class RewritePatternList<LoopType> {
public:
static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {}
};
template <typename LoopType, typename IndexedValueType, typename ConcreteOp,
typename... LinalgOps>
class RewritePatternList<LoopType, IndexedValueType, ConcreteOp, LinalgOps...> {
template <typename LoopType, typename ConcreteOp, typename... LinalgOps>
class RewritePatternList<LoopType, ConcreteOp, LinalgOps...> {
public:
static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {
patterns
.insert<LinalgRewritePattern<LoopType, IndexedValueType, ConcreteOp>>(
ctx);
RewritePatternList<LoopType, IndexedValueType, LinalgOps...>::build(
patterns, ctx);
patterns.insert<LinalgRewritePattern<LoopType, ConcreteOp>>(ctx);
RewritePatternList<LoopType, LinalgOps...>::build(patterns, ctx);
}
};
/// Populate the given list with patterns that convert from Linalg to loops.
template <typename LoopType, typename IndexedValueType>
template <typename LoopType>
void FillRewritePatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) {
RewritePatternList<LoopType, IndexedValueType,
RewritePatternList<LoopType,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
>::build(patterns, ctx);
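
RewritePatternList above is the standard recursive variadic-template expansion: a base case for the empty pack and a peeling specialization that registers one pattern and recurses on the rest. A minimal standalone analogue with illustrative names and no MLIR dependencies:

#include <iostream>
#include <typeinfo>

// Base case: empty pack, nothing left to register.
template <typename... Ops>
struct RegisterAll {
  static void run() {}
};

// Recursive case: peel one op type off the pack, handle it, recurse.
template <typename Op, typename... Rest>
struct RegisterAll<Op, Rest...> {
  static void run() {
    std::cout << "registering pattern for " << typeid(Op).name() << "\n";
    RegisterAll<Rest...>::run();
  }
};

struct CopyOp {};
struct FillOp {};

int main() { RegisterAll<CopyOp, FillOp>::run(); }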
@@ -677,13 +765,13 @@ struct FoldAffineOp : public RewritePattern {
};
} // namespace
template <typename LoopType, typename IndexedValueType>
template <typename LoopType>
static void lowerLinalgToLoopsImpl(Operation *op, MLIRContext *context) {
OwningRewritePatternList patterns;
// Canonicalization and folding patterns applied greedily allow cleaning up
// the emitted IR on the fly.
// TODO(ntv) fold view and subview ops?
FillRewritePatterns<LoopType, IndexedValueType>(patterns, context);
FillRewritePatterns<LoopType>(patterns, context);
DimOp::getCanonicalizationPatterns(patterns, context);
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
patterns.insert<FoldAffineOp>(context);
@@ -695,21 +783,18 @@ namespace {
struct LowerToAffineLoops
: public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
void runOnFunction() override {
lowerLinalgToLoopsImpl<AffineForOp, AffineIndexedValue>(getFunction(),
&getContext());
lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), &getContext());
}
};
struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
void runOnFunction() override {
lowerLinalgToLoopsImpl<loop::ForOp, StdIndexedValue>(getFunction(),
&getContext());
lowerLinalgToLoopsImpl<loop::ForOp>(getFunction(), &getContext());
}
};
struct LowerToParallelLoops
: public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
void runOnFunction() override {
lowerLinalgToLoopsImpl<loop::ParallelOp, StdIndexedValue>(getFunction(),
&getContext());
lowerLinalgToLoopsImpl<loop::ParallelOp>(getFunction(), &getContext());
}
};
} // namespace
@@ -728,28 +813,38 @@ mlir::createConvertLinalgToAffineLoopsPass() {
return std::make_unique<LowerToAffineLoops>();
}
/// Emits a loop nest with the proper body for `op`.
template <typename LoopTy, typename ConcreteOp>
Optional<LinalgLoops>
mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter, Operation *op) {
return LinalgOpToLoopsImpl<LoopTy, ConcreteOp>::doit(op, rewriter);
}
/// Emits a loop nest of `loop.for` with the proper body for `op`.
template <typename ConcreteOp>
LogicalResult mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
Operation *op) {
return LinalgOpToLoopsImpl<loop::ForOp, StdIndexedValue, ConcreteOp>::doit(
op, rewriter);
Optional<LinalgLoops> loops =
linalgLowerOpToLoops<loop::ForOp, ConcreteOp>(rewriter, op);
return loops ? success() : failure();
}
/// Emits a loop nest of `affine.for` with the proper body for `op`.
template <typename ConcreteOp>
LogicalResult mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
Operation *op) {
return LinalgOpToLoopsImpl<AffineForOp, AffineIndexedValue, ConcreteOp>::doit(
op, rewriter);
Optional<LinalgLoops> loops =
linalgLowerOpToLoops<AffineForOp, ConcreteOp>(rewriter, op);
return loops ? success() : failure();
}
/// Emits a loop nest of `loop.parallel` with the proper body for `op`.
template <typename ConcreteOp>
LogicalResult mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
Operation *op) {
return LinalgOpToLoopsImpl<loop::ParallelOp, StdIndexedValue,
ConcreteOp>::doit(op, rewriter);
Optional<LinalgLoops> loops =
linalgLowerOpToLoops<loop::ParallelOp, ConcreteOp>(rewriter, op);
return loops ? success() : failure();
}
// TODO(ntv) Need to make these instantiations more future-proof to avoid the
@@ -758,7 +853,12 @@ LogicalResult mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
template LogicalResult mlir::linalg::linalgOpToLoops<OP_TYPE>( \
PatternRewriter & rewriter, Operation * op); \
template LogicalResult mlir::linalg::linalgOpToAffineLoops<OP_TYPE>( \
PatternRewriter & rewriter, Operation * op);
PatternRewriter & rewriter, Operation * op); \
template LogicalResult mlir::linalg::linalgOpToParallelLoops<OP_TYPE>( \
PatternRewriter & rewriter, Operation * op); \
template Optional<LinalgLoops> \
mlir::linalg::linalgLowerOpToLoops<loop::ParallelOp, OP_TYPE>( \
PatternRewriter & rewriter, Operation * op);
INSTANTIATE_LINALG_OP_TO_LOOPS(CopyOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(FillOp)
@@ -771,9 +871,3 @@ INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingMinOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingSumOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(GenericOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(IndexedGenericOp)
// TODO(pifon): Enable lowering to parallel loops for ops other than
// linalg.generic for now to be on the safe side.
template LogicalResult
mlir::linalg::linalgOpToParallelLoops<GenericOp>(PatternRewriter &rewriter,
Operation *op);
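
The INSTANTIATE_LINALG_OP_TO_LOOPS macro above is ordinary explicit template instantiation, expanded once per op so the template bodies are emitted in this translation unit. A minimal standalone analogue (illustrative names):

#include <iostream>

template <typename OpTy>
void lower(int x) {
  std::cout << "lowering, x = " << x << "\n";
}

struct CopyOp {};
struct FillOp {};

// One macro invocation per op keeps the instantiation list next to the
// op list; each line forces the template body to be emitted for that op.
#define INSTANTIATE_LOWER(OP_TYPE) template void lower<OP_TYPE>(int);

INSTANTIATE_LOWER(CopyOp)
INSTANTIATE_LOWER(FillOp)

int main() {
  lower<CopyOp>(1);
  return 0;
}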

File diff suppressed because it is too large


@@ -32,22 +32,32 @@ func @linalg_generic_sum(%lhs: memref<2x2xf32>,
// -----
#accesses = [
affine_map<(m, n) -> (m, n)>,
affine_map<(m, n) -> (m)>
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
]
#trait = {
args_in = 1,
args_out = 1,
iterator_types = ["parallel", "reduction"],
iterator_types = ["parallel", "parallel", "reduction", "parallel"],
indexing_maps = #accesses
}
func @do_not_lower_reduce(%A: memref<2x4xf32>, %B: memref<2xf32>) {
func @lower_outer_parallel(%A: memref<?x?x?x?xf32>, %B: memref<?x?x?xf32>) {
linalg.generic #trait %A, %B {
^bb0(%a: f32, %b: f32):
linalg.yield %a: f32
} : memref<2x4xf32>, memref<2xf32>
} : memref<?x?x?x?xf32>, memref<?x?x?xf32>
return
}
// CHECK-LABEL: @do_not_lower_reduce
// CHECK: linalg.generic
// CHECK-LABEL: @lower_outer_parallel
// CHECK-DAG: %[[C0:.*]] = constant 0
// CHECK-DAG: %[[C1:.*]] = constant 1
// CHECK-DAG: %[[D0:.*]] = dim %{{.*}}, 0
// CHECK-DAG: %[[D1:.*]] = dim %{{.*}}, 1
// CHECK-DAG: %[[D2:.*]] = dim %{{.*}}, 2
// CHECK-DAG: %[[D3:.*]] = dim %{{.*}}, 3
// CHECK: loop.parallel (%[[IV0:.*]], %[[IV1:.*]]) = (%[[C0]], %[[C0]]) to (%[[D0]], %[[D1]]) step (%[[C1]], %[[C1]])
// CHECK: loop.for %[[IV2:.*]] = %[[C0]] to %[[D2]] step %[[C1]]
// CHECK: loop.for %[[IV3:.*]] = %[[C0]] to %[[D3]] step %[[C1]]
// CHECK: load %{{.*}}[%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
// CHECK: store %{{.*}}, %{{.*}}[%[[IV0]], %[[IV1]], %[[IV3]]]