[mlir][Linalg] NFC - Cleanup explicitly instantiated paterns 2/n - Loops.cpp

This revision belongs to a series of patches that reduce reliance of Linalg transformations on templated rewrite and conversion patterns.
Instead, this uses a MatchAnyTag pattern for the vast majority of cases and dispatches internally.

Differential revision: https://reviews.llvm.org/D89133
This commit is contained in:
Nicolas Vasilache 2020-10-09 19:15:16 +00:00
parent e0dc3dba3b
commit c303d9b394
1 changed files with 63 additions and 119 deletions

View File

@ -23,6 +23,8 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
@ -65,7 +67,7 @@ static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
assert(op.getOperation()->getNumRegions() == 1 &&
"Expected single region op");
auto &b = ScopedContext::getBuilderRef();
auto &block = op.region().front();
auto &block = op.getOperation()->getRegion(0).front();
BlockAndValueMapping map;
map.map(block.getArguments(), indexedValues);
for (auto &op : block.without_terminator()) {
@ -102,8 +104,6 @@ static InputAndOutputIndices getInputAndOutputIndices(ArrayRef<Value> allIvs,
makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
}
namespace {
/// Emits the MLIR for the scalar part of the generic op by:
/// 1. Emitting load ops for each input and output view in order. This is
/// achieved by applying the appropriate input or output map to the
@ -134,10 +134,9 @@ namespace {
/// }
/// }
/// ```
// TODO: need a LinalgStructuredOpInterface.
template <typename IndexedValueType, typename LinalgStructuredOpType>
void emitScalarImplementation(ArrayRef<Value> allIvs,
LinalgStructuredOpType linalgOp) {
template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs,
LinalgOp linalgOp) {
assert(linalgOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
auto &b = ScopedContext::getBuilderRef();
@ -150,7 +149,7 @@ void emitScalarImplementation(ArrayRef<Value> allIvs,
auto attr = linalgOp.template getAttrOfType<IntegerAttr>("symbol_source");
auto allIvsPlusDims = SmallVector<Value, 4>(allIvs.begin(), allIvs.end());
if (attr) {
auto operand = linalgOp.getOperand(attr.getInt());
auto operand = linalgOp.getOperation()->getOperand(attr.getInt());
auto shapedType = operand.getType().template cast<ShapedType>();
allIvsPlusDims.reserve(allIvs.size() + shapedType.getRank());
for (unsigned idx = 0, e = shapedType.getRank(); idx < e; ++idx)
@ -190,7 +189,7 @@ void emitScalarImplementation(ArrayRef<Value> allIvs,
}
template <typename IndexedValueType>
void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
static void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
assert(copyOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
auto nPar = copyOp.getNumParallelLoops();
@ -211,7 +210,7 @@ void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
}
template <typename IndexedValueType>
void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
static void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
assert(fillOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
auto nPar = fillOp.getNumParallelLoops();
@ -224,8 +223,8 @@ void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
}
template <typename IndexedValueType>
Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
MutableArrayRef<Value> imIdx) {
static Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
MutableArrayRef<Value> imIdx) {
// TODO: add a level of indirection to linalg.generic.
if (!convOp.padding())
return im(imIdx);
@ -311,8 +310,9 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
}
}
template <typename IndexedValueType>
void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
template <typename IndexedValueType, typename OpType>
static void emitPoolingMinMaxScalarImplementation(ArrayRef<Value> allIvs,
OpType op) {
InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op);
// Emit scalar form.
IndexedValueType output(op.output());
@ -320,30 +320,34 @@ void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
Value lhs = output(indices.outputs);
Value rhs = input(indices.inputs);
using edsc::op::sgt;
Value maxValue = std_select(sgt(lhs, rhs), lhs, rhs);
output(indices.outputs) = maxValue;
using edsc::op::slt;
Value value = std::is_same<OpType, PoolingMinOp>()
? std_select(slt(lhs, rhs), lhs, rhs)
: std_select(sgt(lhs, rhs), lhs, rhs);
output(indices.outputs) = value;
}
template <typename IndexedValueType>
void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) {
InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op);
// Emit scalar form.
IndexedValueType output(op.output());
IndexedValueType input(op.input());
Value lhs = output(indices.outputs);
Value rhs = input(indices.inputs);
using edsc::op::slt;
Value minValue = std_select(slt(lhs, rhs), lhs, rhs);
output(indices.outputs) = minValue;
static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
emitPoolingMinMaxScalarImplementation<IndexedValueType, PoolingMaxOp>(allIvs,
op);
}
template <typename IndexedValueType>
void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) {
emitPoolingMinMaxScalarImplementation<IndexedValueType, PoolingMinOp>(allIvs,
op);
}
template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
auto indices = getInputAndOutputIndices(allIvs, op);
IndexedValueType input(op.input()), output(op.output());
// Emit scalar form.
output(indices.outputs) += input(indices.inputs);
}
/// Emits the MLIR for the scalar part of the indexed generic op by:
/// 1. Emitting load ops for each input and output view in order. This is
/// achieved by applying the appropriate input or output map to the
@ -422,15 +426,16 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs,
indexing, outputBuffers);
}
template <typename LoopTy, typename ConcreteOpTy>
Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
template <typename LoopTy>
static Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op,
OpBuilder &builder) {
using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
ScopedContext scope(builder, op->getLoc());
// The flattened loopToOperandRangesMaps is expected to be an invertible
// permutation map (which is asserted in the inverse calculation).
auto linalgOp = cast<ConcreteOpTy>(op);
auto linalgOp = cast<LinalgOp>(op);
assert(linalgOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
auto mapsRange =
@ -447,7 +452,12 @@ Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
[&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector {
assert(iterArgs.empty() && "unexpected iterArgs");
allIvs.append(ivs.begin(), ivs.end());
emitScalarImplementation<IndexedValueTy>(allIvs, linalgOp);
llvm::TypeSwitch<Operation *>(op)
.Case<CopyOp, FillOp, ConvOp, PoolingMaxOp, PoolingMinOp,
PoolingSumOp, IndexedGenericOp, LinalgOp>([&](auto op) {
emitScalarImplementation<IndexedValueTy>(allIvs, op);
})
.Default([&](Operation *op) { assert(false && "unexpected op"); });
return scf::ValueVector{};
});
// Number of loop ops might be different from the number of ivs since some
@ -467,32 +477,38 @@ Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
return loops;
}
template <typename LoopType, typename ConcreteOp>
namespace {
template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
public:
explicit LinalgRewritePattern(MLIRContext *context)
: RewritePattern(ConcreteOp::getOperationName(), 1, context) {}
LinalgRewritePattern() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
if (!linalgOpToLoopsImpl<LoopType, ConcreteOp>(op, rewriter))
if (!isa<LinalgOp>(op))
return failure();
if (!linalgOpToLoopsImpl<LoopType>(op, rewriter))
return failure();
rewriter.eraseOp(op);
return success();
}
};
template <typename LoopType, typename ConcreteOp>
void insertOnePattern(OwningRewritePatternList &patterns, MLIRContext *ctx) {
patterns.insert<LinalgRewritePattern<LoopType, ConcreteOp>>(ctx);
}
template <typename LoopType, typename... Args>
void insertPatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) {
(void)std::initializer_list<int>{
0, (insertOnePattern<LoopType, Args>(patterns, ctx), 0)...};
struct FoldAffineOp;
} // namespace
template <typename LoopType>
static void lowerLinalgToLoopsImpl(FuncOp funcOp, MLIRContext *context) {
OwningRewritePatternList patterns;
patterns.insert<LinalgRewritePattern<LoopType>>();
DimOp::getCanonicalizationPatterns(patterns, context);
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
patterns.insert<FoldAffineOp>(context);
// Just apply the patterns greedily.
applyPatternsAndFoldGreedily(funcOp, patterns);
}
namespace {
/// Local folding pattern for AffineApplyOp that we can apply greedily.
/// This replaces AffineApplyOp by the proper value in cases where the
/// associated map is trivial.
@ -529,38 +545,20 @@ struct FoldAffineOp : public RewritePattern {
return failure();
}
};
} // namespace
template <typename LoopType>
static void lowerLinalgToLoopsImpl(FuncOp funcOp, MLIRContext *context) {
OwningRewritePatternList patterns;
// Canonicalization and folding patterns applied greedily allow cleaning up
// the emitted IR on the fly.
// TODO: fold view and subview ops?
insertPatterns<LoopType,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
>(patterns, context);
DimOp::getCanonicalizationPatterns(patterns, context);
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
patterns.insert<FoldAffineOp>(context);
// Just apply the patterns greedily.
applyPatternsAndFoldGreedily(funcOp, patterns);
}
namespace {
struct LowerToAffineLoops
: public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
void runOnFunction() override {
lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), &getContext());
}
};
struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
void runOnFunction() override {
lowerLinalgToLoopsImpl<scf::ForOp>(getFunction(), &getContext());
}
};
struct LowerToParallelLoops
: public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
void runOnFunction() override {
@ -583,60 +581,6 @@ mlir::createConvertLinalgToAffineLoopsPass() {
return std::make_unique<LowerToAffineLoops>();
}
// TODO: gradually remove this layer as more ops become "named".
template <typename LoopTy>
static Optional<LinalgLoops> linalgOpToLoopsImplSwitch(Operation *op,
OpBuilder &builder) {
assert(isa<LinalgOp>(op) && "LinalgOp expected");
if (isa<CopyOp>(op))
return linalgOpToLoopsImpl<LoopTy, CopyOp>(op, builder);
if (isa<FillOp>(op))
return linalgOpToLoopsImpl<LoopTy, FillOp>(op, builder);
if (isa<ConvOp>(op))
return linalgOpToLoopsImpl<LoopTy, ConvOp>(op, builder);
if (isa<PoolingMaxOp>(op))
return linalgOpToLoopsImpl<LoopTy, PoolingMaxOp>(op, builder);
if (isa<PoolingMinOp>(op))
return linalgOpToLoopsImpl<LoopTy, PoolingMinOp>(op, builder);
if (isa<PoolingSumOp>(op))
return linalgOpToLoopsImpl<LoopTy, PoolingSumOp>(op, builder);
if (isa<IndexedGenericOp>(op))
return linalgOpToLoopsImpl<LoopTy, IndexedGenericOp>(op, builder);
// TODO: Cases below are generic and need a LinalgStructuredOpInterface.
if (isa<GenericOp>(op))
return linalgOpToLoopsImpl<LoopTy, GenericOp>(op, builder);
if (isa<MatmulOp>(op))
return linalgOpToLoopsImpl<LoopTy, MatmulOp>(op, builder);
if (isa<MatvecOp>(op))
return linalgOpToLoopsImpl<LoopTy, MatvecOp>(op, builder);
if (isa<VecmatOp>(op))
return linalgOpToLoopsImpl<LoopTy, VecmatOp>(op, builder);
if (isa<DotOp>(op))
return linalgOpToLoopsImpl<LoopTy, DotOp>(op, builder);
if (isa<BatchMatmulOp>(op))
return linalgOpToLoopsImpl<LoopTy, BatchMatmulOp>(op, builder);
if (isa<ConvWOp>(op))
return linalgOpToLoopsImpl<LoopTy, ConvWOp>(op, builder);
if (isa<ConvNWCOp>(op))
return linalgOpToLoopsImpl<LoopTy, ConvNWCOp>(op, builder);
if (isa<ConvNCWOp>(op))
return linalgOpToLoopsImpl<LoopTy, ConvNCWOp>(op, builder);
if (isa<ConvHWOp>(op))
return linalgOpToLoopsImpl<LoopTy, ConvHWOp>(op, builder);
if (isa<ConvNHWCOp>(op))
return linalgOpToLoopsImpl<LoopTy, ConvNHWCOp>(op, builder);
if (isa<ConvNCHWOp>(op))
return linalgOpToLoopsImpl<LoopTy, ConvNCHWOp>(op, builder);
if (isa<ConvDHWOp>(op))
return linalgOpToLoopsImpl<LoopTy, ConvDHWOp>(op, builder);
if (isa<ConvNDHWCOp>(op))
return linalgOpToLoopsImpl<LoopTy, ConvNDHWCOp>(op, builder);
if (isa<ConvNCDHWOp>(op))
return linalgOpToLoopsImpl<LoopTy, ConvNCDHWOp>(op, builder);
llvm_unreachable("Unexpected op in linalgOpToLoopsImpl");
}
SmallVector<Range, 4> mlir::linalg::emitLoopRanges(OpBuilder &b, Location loc,
AffineMap map,
ValueRange viewSizes) {
@ -705,7 +649,7 @@ SmallVector<Range, 4> mlir::linalg::emitLoopRanges(OpBuilder &b, Location loc,
template <typename LoopTy>
Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder,
Operation *op) {
return linalgOpToLoopsImplSwitch<LoopTy>(op, builder);
return linalgOpToLoopsImpl<LoopTy>(op, builder);
}
template Optional<LinalgLoops>