[mlir:Linalg] Populate LinalgOp patterns on LinalgDialect as opposed to each op

Interface patterns are unique in that they get added to every operation that also implements that interface, given that they aren't tied to individual operations. When the same interface pattern gets added to multiple operations (such as the current behavior with Linalg), an reference to each of these patterns is added to every op (meaning that an operation will now have N references to effectively the same pattern). This revision fixes this problematic behavior in Linalg, and can bring upwards of a 25% reduction in compile time in Linalg based workloads.

Differential Revision: https://reviews.llvm.org/D104160
This commit is contained in:
River Riddle 2021-06-14 11:09:43 -07:00
parent 75d3b46ad2
commit 66e2708205
7 changed files with 27 additions and 43 deletions

View File

@ -35,6 +35,7 @@ def Linalg_Dialect : Dialect {
let dependentDialects = [
"AffineDialect", "StandardOpsDialect", "tensor::TensorDialect"
];
let hasCanonicalizer = 1;
let hasOperationAttrVerify = 1;
let extraClassDeclaration = [{
/// Attribute name used to to memoize indexing maps for named ops.

View File

@ -178,7 +178,6 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
let skipDefaultBuilders = 1;
}
@ -230,7 +229,6 @@ def FillOp : LinalgStructured_Op<"fill", []> {
let verifier = [{ return ::verify(*this); }];
let hasFolder = 1;
let hasCanonicalizer = 1;
}
/// A base class for pooling operation such as conv. The arguments must contain
@ -427,7 +425,6 @@ def ConvOp : PoolingBase_Op<"conv", []> {
let verifier = [{ return ::verify(*this); }];
let hasFolder = 1;
let hasCanonicalizer = 1;
}
// Only support buffer semantics.
@ -490,7 +487,6 @@ class SingleInputPoolingBase_Op<string mnemonic>
let verifier = [{ return ::verify(*this); }];
let hasFolder = 1;
let hasCanonicalizer = 1;
}
def PoolingMaxOp: SingleInputPoolingBase_Op<"pooling_max"> {
@ -673,7 +669,6 @@ def GenericOp : GenericOpBase<"generic"> {
let verifier = [{ return ::verify(*this); }];
let hasFolder = 1;
let hasCanonicalizer = 1;
}
/// GenericOp with Indexing (i.e. multi-for style in which the region is passed

View File

@ -2787,11 +2787,6 @@ DEFINE_POOLING_OP_GET_EFFECTS(PoolingMaxOp)
DEFINE_POOLING_OP_GET_EFFECTS(PoolingMinOp)
DEFINE_POOLING_OP_GET_EFFECTS(PoolingSumOp)
namespace {
struct EraseDeadLinalgOp;
struct FoldTensorCastOp;
} // namespace
#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.tcgen.cpp.inc"
#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
@ -3374,25 +3369,29 @@ struct RemoveIdentityLinalgOps : public OpInterfaceRewritePattern<LinalgOp> {
};
} // namespace
#define CANONICALIZERS_AND_FOLDERS(XXX) \
void XXX::getCanonicalizationPatterns(RewritePatternSet &results, \
MLIRContext *context) { \
results.add<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp, \
RemoveIdentityLinalgOps>(context); \
} \
\
#define LINALGOP_FOLDERS(XXX) \
LogicalResult XXX::fold(ArrayRef<Attribute>, \
SmallVectorImpl<OpFoldResult> &) { \
return foldMemRefCast(*this); \
}
CANONICALIZERS_AND_FOLDERS(ConvOp)
CANONICALIZERS_AND_FOLDERS(PoolingMaxOp)
CANONICALIZERS_AND_FOLDERS(PoolingMinOp)
CANONICALIZERS_AND_FOLDERS(PoolingSumOp)
CANONICALIZERS_AND_FOLDERS(CopyOp)
CANONICALIZERS_AND_FOLDERS(FillOp)
CANONICALIZERS_AND_FOLDERS(GenericOp)
LINALGOP_FOLDERS(ConvOp)
LINALGOP_FOLDERS(PoolingMaxOp)
LINALGOP_FOLDERS(PoolingMinOp)
LINALGOP_FOLDERS(PoolingSumOp)
LINALGOP_FOLDERS(CopyOp)
LINALGOP_FOLDERS(FillOp)
LINALGOP_FOLDERS(GenericOp)
// All named ops canonicalizers and folders are auto-generated in the
// .cpp.inc.
//===----------------------------------------------------------------------===//
// LinalgDialect
//===----------------------------------------------------------------------===//
void LinalgDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
results.add<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp,
RemoveIdentityLinalgOps>(getContext());
}

View File

@ -1405,6 +1405,8 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
patterns);
}
void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {

View File

@ -414,6 +414,7 @@ void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);
ctx->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(patterns);
CanonicalizationPatternList<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"

View File

@ -1959,7 +1959,6 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/);
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
let extraClassDeclaration = structuredOpsBaseDecls # [{{
// Auto-generated.
@ -2094,13 +2093,7 @@ void TCParser::printReferenceIterators(llvm::raw_ostream &os,
void TCParser::printCanonicalizersAndFolders(llvm::raw_ostream &os,
StringRef cppOpName) {
const char *canonicalizersAndFoldersFmt = R"FMT(
void {0}::getCanonicalizationPatterns(
RewritePatternSet &results,
MLIRContext *context) {{
results.add<EraseDeadLinalgOp>(context);
results.add<FoldTensorCastOp>(context);
}
const char *foldersFmt = R"FMT(
LogicalResult {0}::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {{
return foldMemRefCast(*this);
@ -2112,7 +2105,7 @@ void TCParser::printCanonicalizersAndFolders(llvm::raw_ostream &os,
getGenericEffectsImpl(effects,
getOperation()->getResults(), inputBuffers, outputBuffers);
})FMT";
os << llvm::formatv(canonicalizersAndFoldersFmt, cppOpName);
os << llvm::formatv(foldersFmt, cppOpName);
}
// Prints methods for querying whether the current named op has attributes that

View File

@ -503,7 +503,6 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/);
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
let extraClassDeclaration = structuredOpsBaseDecls # [{{
// Auto-generated.
@ -535,16 +534,10 @@ ArrayAttr {0}::iterator_types() {
}
)FMT";
// Implementations of getCanonicalizationPatterns, fold and getEffects.
// Implementations of fold and getEffects.
// Parameters:
// {0}: Class name
const char structuredOpCanonicalizersAndFoldersFormat[] = R"FMT(
void {0}::getCanonicalizationPatterns(
RewritePatternSet &results,
MLIRContext *context) {{
results.add<EraseDeadLinalgOp>(context);
results.add<FoldTensorCastOp>(context);
}
const char structuredOpFoldersFormat[] = R"FMT(
LogicalResult {0}::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {{
return foldMemRefCast(*this);
@ -880,7 +873,7 @@ void {0}::regionBuilder(
}
// Canonicalizers and folders.
os << llvm::formatv(structuredOpCanonicalizersAndFoldersFormat, className);
os << llvm::formatv(structuredOpFoldersFormat, className);
return success();
}