forked from OSchip/llvm-project
[mlir:Linalg] Populate LinalgOp patterns on LinalgDialect as opposed to each op
Interface patterns are unique in that they get added to every operation that also implements that interface, given that they aren't tied to individual operations. When the same interface pattern gets added to multiple operations (such as the current behavior with Linalg), an reference to each of these patterns is added to every op (meaning that an operation will now have N references to effectively the same pattern). This revision fixes this problematic behavior in Linalg, and can bring upwards of a 25% reduction in compile time in Linalg based workloads. Differential Revision: https://reviews.llvm.org/D104160
This commit is contained in:
parent
75d3b46ad2
commit
66e2708205
|
@ -35,6 +35,7 @@ def Linalg_Dialect : Dialect {
|
|||
let dependentDialects = [
|
||||
"AffineDialect", "StandardOpsDialect", "tensor::TensorDialect"
|
||||
];
|
||||
let hasCanonicalizer = 1;
|
||||
let hasOperationAttrVerify = 1;
|
||||
let extraClassDeclaration = [{
|
||||
/// Attribute name used to to memoize indexing maps for named ops.
|
||||
|
|
|
@ -178,7 +178,6 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
|
|||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
let skipDefaultBuilders = 1;
|
||||
}
|
||||
|
||||
|
@ -230,7 +229,6 @@ def FillOp : LinalgStructured_Op<"fill", []> {
|
|||
let verifier = [{ return ::verify(*this); }];
|
||||
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
/// A base class for pooling operation such as conv. The arguments must contain
|
||||
|
@ -427,7 +425,6 @@ def ConvOp : PoolingBase_Op<"conv", []> {
|
|||
let verifier = [{ return ::verify(*this); }];
|
||||
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
// Only support buffer semantics.
|
||||
|
@ -490,7 +487,6 @@ class SingleInputPoolingBase_Op<string mnemonic>
|
|||
let verifier = [{ return ::verify(*this); }];
|
||||
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def PoolingMaxOp: SingleInputPoolingBase_Op<"pooling_max"> {
|
||||
|
@ -673,7 +669,6 @@ def GenericOp : GenericOpBase<"generic"> {
|
|||
let verifier = [{ return ::verify(*this); }];
|
||||
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
/// GenericOp with Indexing (i.e. multi-for style in which the region is passed
|
||||
|
|
|
@ -2787,11 +2787,6 @@ DEFINE_POOLING_OP_GET_EFFECTS(PoolingMaxOp)
|
|||
DEFINE_POOLING_OP_GET_EFFECTS(PoolingMinOp)
|
||||
DEFINE_POOLING_OP_GET_EFFECTS(PoolingSumOp)
|
||||
|
||||
namespace {
|
||||
struct EraseDeadLinalgOp;
|
||||
struct FoldTensorCastOp;
|
||||
} // namespace
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.tcgen.cpp.inc"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
|
||||
|
||||
|
@ -3374,25 +3369,29 @@ struct RemoveIdentityLinalgOps : public OpInterfaceRewritePattern<LinalgOp> {
|
|||
};
|
||||
} // namespace
|
||||
|
||||
#define CANONICALIZERS_AND_FOLDERS(XXX) \
|
||||
void XXX::getCanonicalizationPatterns(RewritePatternSet &results, \
|
||||
MLIRContext *context) { \
|
||||
results.add<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp, \
|
||||
RemoveIdentityLinalgOps>(context); \
|
||||
} \
|
||||
\
|
||||
#define LINALGOP_FOLDERS(XXX) \
|
||||
LogicalResult XXX::fold(ArrayRef<Attribute>, \
|
||||
SmallVectorImpl<OpFoldResult> &) { \
|
||||
return foldMemRefCast(*this); \
|
||||
}
|
||||
|
||||
CANONICALIZERS_AND_FOLDERS(ConvOp)
|
||||
CANONICALIZERS_AND_FOLDERS(PoolingMaxOp)
|
||||
CANONICALIZERS_AND_FOLDERS(PoolingMinOp)
|
||||
CANONICALIZERS_AND_FOLDERS(PoolingSumOp)
|
||||
CANONICALIZERS_AND_FOLDERS(CopyOp)
|
||||
CANONICALIZERS_AND_FOLDERS(FillOp)
|
||||
CANONICALIZERS_AND_FOLDERS(GenericOp)
|
||||
LINALGOP_FOLDERS(ConvOp)
|
||||
LINALGOP_FOLDERS(PoolingMaxOp)
|
||||
LINALGOP_FOLDERS(PoolingMinOp)
|
||||
LINALGOP_FOLDERS(PoolingSumOp)
|
||||
LINALGOP_FOLDERS(CopyOp)
|
||||
LINALGOP_FOLDERS(FillOp)
|
||||
LINALGOP_FOLDERS(GenericOp)
|
||||
|
||||
// All named ops canonicalizers and folders are auto-generated in the
|
||||
// .cpp.inc.
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LinalgDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void LinalgDialect::getCanonicalizationPatterns(
|
||||
RewritePatternSet &results) const {
|
||||
results.add<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp,
|
||||
RemoveIdentityLinalgOps>(getContext());
|
||||
}
|
||||
|
|
|
@ -1405,6 +1405,8 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
|
|||
IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
|
||||
TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
|
||||
TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
|
||||
context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
|
||||
patterns);
|
||||
}
|
||||
|
||||
void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
|
||||
|
|
|
@ -414,6 +414,7 @@ void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
|
|||
memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
|
||||
tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
|
||||
memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);
|
||||
ctx->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(patterns);
|
||||
CanonicalizationPatternList<
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
|
||||
|
|
|
@ -1959,7 +1959,6 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
|
|||
return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/);
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
let extraClassDeclaration = structuredOpsBaseDecls # [{{
|
||||
// Auto-generated.
|
||||
|
@ -2094,13 +2093,7 @@ void TCParser::printReferenceIterators(llvm::raw_ostream &os,
|
|||
|
||||
void TCParser::printCanonicalizersAndFolders(llvm::raw_ostream &os,
|
||||
StringRef cppOpName) {
|
||||
const char *canonicalizersAndFoldersFmt = R"FMT(
|
||||
void {0}::getCanonicalizationPatterns(
|
||||
RewritePatternSet &results,
|
||||
MLIRContext *context) {{
|
||||
results.add<EraseDeadLinalgOp>(context);
|
||||
results.add<FoldTensorCastOp>(context);
|
||||
}
|
||||
const char *foldersFmt = R"FMT(
|
||||
LogicalResult {0}::fold(ArrayRef<Attribute>,
|
||||
SmallVectorImpl<OpFoldResult> &) {{
|
||||
return foldMemRefCast(*this);
|
||||
|
@ -2112,7 +2105,7 @@ void TCParser::printCanonicalizersAndFolders(llvm::raw_ostream &os,
|
|||
getGenericEffectsImpl(effects,
|
||||
getOperation()->getResults(), inputBuffers, outputBuffers);
|
||||
})FMT";
|
||||
os << llvm::formatv(canonicalizersAndFoldersFmt, cppOpName);
|
||||
os << llvm::formatv(foldersFmt, cppOpName);
|
||||
}
|
||||
|
||||
// Prints methods for querying whether the current named op has attributes that
|
||||
|
|
|
@ -503,7 +503,6 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
|
|||
return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/);
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
let extraClassDeclaration = structuredOpsBaseDecls # [{{
|
||||
// Auto-generated.
|
||||
|
@ -535,16 +534,10 @@ ArrayAttr {0}::iterator_types() {
|
|||
}
|
||||
)FMT";
|
||||
|
||||
// Implementations of getCanonicalizationPatterns, fold and getEffects.
|
||||
// Implementations of fold and getEffects.
|
||||
// Parameters:
|
||||
// {0}: Class name
|
||||
const char structuredOpCanonicalizersAndFoldersFormat[] = R"FMT(
|
||||
void {0}::getCanonicalizationPatterns(
|
||||
RewritePatternSet &results,
|
||||
MLIRContext *context) {{
|
||||
results.add<EraseDeadLinalgOp>(context);
|
||||
results.add<FoldTensorCastOp>(context);
|
||||
}
|
||||
const char structuredOpFoldersFormat[] = R"FMT(
|
||||
LogicalResult {0}::fold(ArrayRef<Attribute>,
|
||||
SmallVectorImpl<OpFoldResult> &) {{
|
||||
return foldMemRefCast(*this);
|
||||
|
@ -880,7 +873,7 @@ void {0}::regionBuilder(
|
|||
}
|
||||
|
||||
// Canonicalizers and folders.
|
||||
os << llvm::formatv(structuredOpCanonicalizersAndFoldersFormat, className);
|
||||
os << llvm::formatv(structuredOpFoldersFormat, className);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue