[mlir][Linalg] NFC - Automate the printing of canonicalizers and folders for nameds Linalg ops.

This revision reduces the number of places that specific information needs to be modified when adding new named Linalg ops.

Differential Revision: https://reviews.llvm.org/D89223
This commit is contained in:
Nicolas Vasilache 2020-10-12 11:21:43 +00:00
parent 422aaf31da
commit 69d3247f35
2 changed files with 27 additions and 15 deletions

View File

@ -1150,6 +1150,11 @@ static LogicalResult verify(PoolingSumOp op) {
return verifySingleInputPoolingOp(op);
}
namespace {
struct EraseDeadLinalgOp;
struct FoldTensorCastOp;
} // namespace
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterfaces.cpp.inc"
#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc"
@ -1578,18 +1583,4 @@ CANONICALIZERS_AND_FOLDERS(FillOp)
CANONICALIZERS_AND_FOLDERS(GenericOp)
CANONICALIZERS_AND_FOLDERS(IndexedGenericOp)
// TODO: Determine whether we can generate the folders and verifiers.
CANONICALIZERS_AND_FOLDERS(BatchMatmulOp)
CANONICALIZERS_AND_FOLDERS(DotOp)
CANONICALIZERS_AND_FOLDERS(MatmulOp)
CANONICALIZERS_AND_FOLDERS(MatvecOp)
CANONICALIZERS_AND_FOLDERS(VecmatOp)
CANONICALIZERS_AND_FOLDERS(ConvWOp)
CANONICALIZERS_AND_FOLDERS(ConvNWCOp)
CANONICALIZERS_AND_FOLDERS(ConvNCWOp)
CANONICALIZERS_AND_FOLDERS(ConvHWOp)
CANONICALIZERS_AND_FOLDERS(ConvNHWCOp)
CANONICALIZERS_AND_FOLDERS(ConvNCHWOp)
CANONICALIZERS_AND_FOLDERS(ConvDHWOp)
CANONICALIZERS_AND_FOLDERS(ConvNDHWCOp)
CANONICALIZERS_AND_FOLDERS(ConvNCDHWOp)
// All named ops canonicalizers and folders are auto-generated in the .cpp.inc.

View File

@ -994,6 +994,10 @@ public:
void printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
ComprehensionParsingState &state);
/// Print the C++ impl for named ops canonicalizers and fodlers.
void printCanonicalizersAndFolders(llvm::raw_ostream &os,
StringRef cppOpName);
private:
//===--------------------------------------------------------------------===//
// Internal bookkeeping of tensors.
@ -1430,6 +1434,7 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
printReferenceIterators(ss, cppOpName, state);
printReferenceIndexingMaps(ss, cppOpName, state);
printRegionBuilder(ss, cppOpName, state);
printCanonicalizersAndFolders(ss, cppOpName);
ss.flush();
os << extraMethods << "\n";
}
@ -1571,6 +1576,22 @@ void TCParser::printReferenceIterators(llvm::raw_ostream &os,
os << llvm::formatv(referenceReferenceIteratorsFmt, cppOpName, iteratorsStr);
}
void TCParser::printCanonicalizersAndFolders(llvm::raw_ostream &os,
StringRef cppOpName) {
const char *canonicalizersAndFoldersFmt = R"FMT(
void {0}::getCanonicalizationPatterns(
OwningRewritePatternList &results,
MLIRContext *context) {{
results.insert<EraseDeadLinalgOp>();
results.insert<FoldTensorCastOp>();
}
LogicalResult {0}::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {{
return foldMemRefCast(*this);
})FMT";
os << llvm::formatv(canonicalizersAndFoldersFmt, cppOpName);
}
/// Print the C++ StructuredOpsInterface impl of `referenceIndexingMaps`.
void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
StringRef cppOpName,