[mlir][Linalg] NFC - Cleanup explicitly instantiated patterns 2/n - Loops.cpp
This revision belongs to a series of patches that reduce the reliance of Linalg transformations on templated rewrite and conversion patterns. Instead, it registers a single pattern with MatchAnyOpTypeTag for the vast majority of cases and dispatches on the concrete op internally.

Differential Revision: https://reviews.llvm.org/D89133
parent e0dc3dba3b
commit c303d9b394
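A condensed sketch of the dispatch idiom this patch adopts, assembled from the hunks below (it assumes the MLIR/Linalg declarations and the helpers defined in Loops.cpp, and is not a standalone translation unit): one pattern matches any op, filters for LinalgOp, and recovers the per-op behavior with a TypeSwitch instead of a ConcreteOpTy template parameter.

    // A single rewrite pattern that matches any operation and lowers Linalg ops
    // to loops; the per-op templating of the old LinalgRewritePattern is gone.
    template <typename LoopType>
    class LinalgRewritePattern : public RewritePattern {
    public:
      LinalgRewritePattern()
          : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}

      LogicalResult matchAndRewrite(Operation *op,
                                    PatternRewriter &rewriter) const override {
        // Bail out on anything that is not a Linalg op.
        if (!isa<LinalgOp>(op))
          return failure();
        if (!linalgOpToLoopsImpl<LoopType>(op, rewriter))
          return failure();
        rewriter.eraseOp(op);
        return success();
      }
    };

    // Inside linalgOpToLoopsImpl, the concrete op is recovered dynamically:
    llvm::TypeSwitch<Operation *>(op)
        .Case<CopyOp, FillOp, ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp,
              IndexedGenericOp, LinalgOp>([&](auto op) {
          emitScalarImplementation<IndexedValueTy>(allIvs, op);
        })
        .Default([&](Operation *op) { assert(false && "unexpected op"); });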
@@ -23,6 +23,8 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/FoldUtils.h"
 
+#include "llvm/ADT/TypeSwitch.h"
+
 using namespace mlir;
 using namespace mlir::edsc;
 using namespace mlir::edsc::intrinsics;
@@ -65,7 +67,7 @@ static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
   assert(op.getOperation()->getNumRegions() == 1 &&
          "Expected single region op");
   auto &b = ScopedContext::getBuilderRef();
-  auto &block = op.region().front();
+  auto &block = op.getOperation()->getRegion(0).front();
   BlockAndValueMapping map;
   map.map(block.getArguments(), indexedValues);
   for (auto &op : block.without_terminator()) {
@@ -102,8 +104,6 @@ static InputAndOutputIndices getInputAndOutputIndices(ArrayRef<Value> allIvs,
       makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
 }
 
-namespace {
-
 /// Emits the MLIR for the scalar part of the generic op by:
 /// 1. Emitting load ops for each input and output view in order. This is
 ///    achieved by applying the appropriate input or output map to the
@@ -134,10 +134,9 @@ namespace {
 /// }
 /// }
 /// ```
-// TODO: need a LinalgStructuredOpInterface.
-template <typename IndexedValueType, typename LinalgStructuredOpType>
-void emitScalarImplementation(ArrayRef<Value> allIvs,
-                              LinalgStructuredOpType linalgOp) {
+template <typename IndexedValueType>
+static void emitScalarImplementation(ArrayRef<Value> allIvs,
+                                     LinalgOp linalgOp) {
   assert(linalgOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   auto &b = ScopedContext::getBuilderRef();
@@ -150,7 +149,7 @@ void emitScalarImplementation(ArrayRef<Value> allIvs,
   auto attr = linalgOp.template getAttrOfType<IntegerAttr>("symbol_source");
   auto allIvsPlusDims = SmallVector<Value, 4>(allIvs.begin(), allIvs.end());
   if (attr) {
-    auto operand = linalgOp.getOperand(attr.getInt());
+    auto operand = linalgOp.getOperation()->getOperand(attr.getInt());
     auto shapedType = operand.getType().template cast<ShapedType>();
     allIvsPlusDims.reserve(allIvs.size() + shapedType.getRank());
     for (unsigned idx = 0, e = shapedType.getRank(); idx < e; ++idx)
@@ -190,7 +189,7 @@ void emitScalarImplementation(ArrayRef<Value> allIvs,
 }
 
 template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
+static void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
   assert(copyOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   auto nPar = copyOp.getNumParallelLoops();
@@ -211,7 +210,7 @@ void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
 }
 
 template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
+static void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
   assert(fillOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   auto nPar = fillOp.getNumParallelLoops();
@@ -224,8 +223,8 @@ void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
 }
 
 template <typename IndexedValueType>
-Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
-                     MutableArrayRef<Value> imIdx) {
+static Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
+                            MutableArrayRef<Value> imIdx) {
   // TODO: add a level of indirection to linalg.generic.
   if (!convOp.padding())
     return im(imIdx);
@@ -311,8 +310,9 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
   }
 }
 
-template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
+template <typename IndexedValueType, typename OpType>
+static void emitPoolingMinMaxScalarImplementation(ArrayRef<Value> allIvs,
+                                                  OpType op) {
   InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op);
   // Emit scalar form.
   IndexedValueType output(op.output());
@@ -320,30 +320,34 @@ void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
   Value lhs = output(indices.outputs);
   Value rhs = input(indices.inputs);
   using edsc::op::sgt;
-  Value maxValue = std_select(sgt(lhs, rhs), lhs, rhs);
-  output(indices.outputs) = maxValue;
+  using edsc::op::slt;
+  Value value = std::is_same<OpType, PoolingMinOp>()
+                    ? std_select(slt(lhs, rhs), lhs, rhs)
+                    : std_select(sgt(lhs, rhs), lhs, rhs);
+  output(indices.outputs) = value;
 }
 
 template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) {
-  InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op);
-  // Emit scalar form.
-  IndexedValueType output(op.output());
-  IndexedValueType input(op.input());
-  Value lhs = output(indices.outputs);
-  Value rhs = input(indices.inputs);
-  using edsc::op::slt;
-  Value minValue = std_select(slt(lhs, rhs), lhs, rhs);
-  output(indices.outputs) = minValue;
+static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
+  emitPoolingMinMaxScalarImplementation<IndexedValueType, PoolingMaxOp>(allIvs,
+                                                                        op);
 }
 
 template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
+static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) {
+  emitPoolingMinMaxScalarImplementation<IndexedValueType, PoolingMinOp>(allIvs,
+                                                                        op);
+}
+
+template <typename IndexedValueType>
+static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
   auto indices = getInputAndOutputIndices(allIvs, op);
   IndexedValueType input(op.input()), output(op.output());
 
   // Emit scalar form.
   output(indices.outputs) += input(indices.inputs);
 }
 
 /// Emits the MLIR for the scalar part of the indexed generic op by:
 /// 1. Emitting load ops for each input and output view in order. This is
 ///    achieved by applying the appropriate input or output map to the
@@ -422,15 +426,16 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                              indexing, outputBuffers);
 }
 
-template <typename LoopTy, typename ConcreteOpTy>
-Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
+template <typename LoopTy>
+static Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op,
+                                                 OpBuilder &builder) {
   using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
 
   ScopedContext scope(builder, op->getLoc());
 
   // The flattened loopToOperandRangesMaps is expected to be an invertible
   // permutation map (which is asserted in the inverse calculation).
-  auto linalgOp = cast<ConcreteOpTy>(op);
+  auto linalgOp = cast<LinalgOp>(op);
   assert(linalgOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   auto mapsRange =
@@ -447,7 +452,12 @@ Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
       [&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector {
         assert(iterArgs.empty() && "unexpected iterArgs");
         allIvs.append(ivs.begin(), ivs.end());
-        emitScalarImplementation<IndexedValueTy>(allIvs, linalgOp);
+        llvm::TypeSwitch<Operation *>(op)
+            .Case<CopyOp, FillOp, ConvOp, PoolingMaxOp, PoolingMinOp,
+                  PoolingSumOp, IndexedGenericOp, LinalgOp>([&](auto op) {
+              emitScalarImplementation<IndexedValueTy>(allIvs, op);
+            })
+            .Default([&](Operation *op) { assert(false && "unexpected op"); });
         return scf::ValueVector{};
       });
   // Number of loop ops might be different from the number of ivs since some
@@ -467,32 +477,38 @@ Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
   return loops;
 }
 
-template <typename LoopType, typename ConcreteOp>
+namespace {
+template <typename LoopType>
 class LinalgRewritePattern : public RewritePattern {
 public:
-  explicit LinalgRewritePattern(MLIRContext *context)
-      : RewritePattern(ConcreteOp::getOperationName(), 1, context) {}
+  LinalgRewritePattern() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
 
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    if (!linalgOpToLoopsImpl<LoopType, ConcreteOp>(op, rewriter))
+    if (!isa<LinalgOp>(op))
+      return failure();
+    if (!linalgOpToLoopsImpl<LoopType>(op, rewriter))
       return failure();
     rewriter.eraseOp(op);
     return success();
   }
 };
 
-template <typename LoopType, typename ConcreteOp>
-void insertOnePattern(OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  patterns.insert<LinalgRewritePattern<LoopType, ConcreteOp>>(ctx);
-}
-
-template <typename LoopType, typename... Args>
-void insertPatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  (void)std::initializer_list<int>{
-      0, (insertOnePattern<LoopType, Args>(patterns, ctx), 0)...};
+struct FoldAffineOp;
+} // namespace
+
+template <typename LoopType>
+static void lowerLinalgToLoopsImpl(FuncOp funcOp, MLIRContext *context) {
+  OwningRewritePatternList patterns;
+  patterns.insert<LinalgRewritePattern<LoopType>>();
+  DimOp::getCanonicalizationPatterns(patterns, context);
+  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
+  patterns.insert<FoldAffineOp>(context);
+  // Just apply the patterns greedily.
+  applyPatternsAndFoldGreedily(funcOp, patterns);
 }
 
 namespace {
 /// Local folding pattern for AffineApplyOp that we can apply greedily.
 /// This replaces AffineApplyOp by the proper value in cases where the
 /// associated map is trivial.
@@ -529,38 +545,20 @@ struct FoldAffineOp : public RewritePattern {
     return failure();
   }
 };
 } // namespace
 
-template <typename LoopType>
-static void lowerLinalgToLoopsImpl(FuncOp funcOp, MLIRContext *context) {
-  OwningRewritePatternList patterns;
-  // Canonicalization and folding patterns applied greedily allow cleaning up
-  // the emitted IR on the fly.
-  // TODO: fold view and subview ops?
-  insertPatterns<LoopType,
-#define GET_OP_LIST
-#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
-                 >(patterns, context);
-
-  DimOp::getCanonicalizationPatterns(patterns, context);
-  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
-  patterns.insert<FoldAffineOp>(context);
-  // Just apply the patterns greedily.
-  applyPatternsAndFoldGreedily(funcOp, patterns);
-}
-
 namespace {
 struct LowerToAffineLoops
     : public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
   void runOnFunction() override {
     lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), &getContext());
   }
 };
 
 struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
   void runOnFunction() override {
     lowerLinalgToLoopsImpl<scf::ForOp>(getFunction(), &getContext());
   }
 };
 
 struct LowerToParallelLoops
     : public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
   void runOnFunction() override {
@@ -583,60 +581,6 @@ mlir::createConvertLinalgToAffineLoopsPass() {
   return std::make_unique<LowerToAffineLoops>();
 }
 
-// TODO: gradually remove this layer as more ops become "named".
-template <typename LoopTy>
-static Optional<LinalgLoops> linalgOpToLoopsImplSwitch(Operation *op,
-                                                       OpBuilder &builder) {
-  assert(isa<LinalgOp>(op) && "LinalgOp expected");
-  if (isa<CopyOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, CopyOp>(op, builder);
-  if (isa<FillOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, FillOp>(op, builder);
-  if (isa<ConvOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvOp>(op, builder);
-  if (isa<PoolingMaxOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, PoolingMaxOp>(op, builder);
-  if (isa<PoolingMinOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, PoolingMinOp>(op, builder);
-  if (isa<PoolingSumOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, PoolingSumOp>(op, builder);
-  if (isa<IndexedGenericOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, IndexedGenericOp>(op, builder);
-
-  // TODO: Cases below are generic and need a LinalgStructuredOpInterface.
-  if (isa<GenericOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, GenericOp>(op, builder);
-  if (isa<MatmulOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, MatmulOp>(op, builder);
-  if (isa<MatvecOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, MatvecOp>(op, builder);
-  if (isa<VecmatOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, VecmatOp>(op, builder);
-  if (isa<DotOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, DotOp>(op, builder);
-  if (isa<BatchMatmulOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, BatchMatmulOp>(op, builder);
-  if (isa<ConvWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvWOp>(op, builder);
-  if (isa<ConvNWCOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNWCOp>(op, builder);
-  if (isa<ConvNCWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNCWOp>(op, builder);
-  if (isa<ConvHWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvHWOp>(op, builder);
-  if (isa<ConvNHWCOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNHWCOp>(op, builder);
-  if (isa<ConvNCHWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNCHWOp>(op, builder);
-  if (isa<ConvDHWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvDHWOp>(op, builder);
-  if (isa<ConvNDHWCOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNDHWCOp>(op, builder);
-  if (isa<ConvNCDHWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNCDHWOp>(op, builder);
-  llvm_unreachable("Unexpected op in linalgOpToLoopsImpl");
-}
-
 SmallVector<Range, 4> mlir::linalg::emitLoopRanges(OpBuilder &b, Location loc,
                                                    AffineMap map,
                                                    ValueRange viewSizes) {
@@ -705,7 +649,7 @@ SmallVector<Range, 4> mlir::linalg::emitLoopRanges(OpBuilder &b, Location loc,
 template <typename LoopTy>
 Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder,
                                                          Operation *op) {
-  return linalgOpToLoopsImplSwitch<LoopTy>(op, builder);
+  return linalgOpToLoopsImpl<LoopTy>(op, builder);
 }
 
 template Optional<LinalgLoops>
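As a usage sketch (condensed from the added lines above, with pass boilerplate elided): the per-op insertPatterns/GET_OP_LIST machinery is replaced by registering the single MatchAnyOpTypeTag pattern once per loop type, alongside the canonicalization and FoldAffineOp patterns.

    // Populate and greedily apply the loop-lowering patterns for one loop type.
    template <typename LoopType>
    static void lowerLinalgToLoopsImpl(FuncOp funcOp, MLIRContext *context) {
      OwningRewritePatternList patterns;
      patterns.insert<LinalgRewritePattern<LoopType>>();
      DimOp::getCanonicalizationPatterns(patterns, context);
      AffineApplyOp::getCanonicalizationPatterns(patterns, context);
      patterns.insert<FoldAffineOp>(context);
      // Just apply the patterns greedily.
      applyPatternsAndFoldGreedily(funcOp, patterns);
    }

    // Each lowering pass then only picks the loop type, e.g.:
    //   lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), &getContext());
    //   lowerLinalgToLoopsImpl<scf::ForOp>(getFunction(), &getContext());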