forked from OSchip/llvm-project
NFC: Various code cleanups for Ch3.
This change refactors the toyc driver to be much cleaner and easier to extend.
It also cleans up a few comments in the combiner.

PiperOrigin-RevId: 274973808
This commit is contained in:
parent 950979745a
commit a08482c1ad
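For orientation, the hunks below split the old monolithic dumpMLIR() into a loadMLIR() helper plus a slimmer dumpMLIR(). The sketch below shows where dumpMLIR() typically sits in the toyc driver; main() is not part of this diff, so the option names (emitAction, Action) and the dumpAST() path are assumptions based on the surrounding Toy tutorial, not verbatim code.

// Hypothetical sketch only; not part of this commit's diff.
// Shows where dumpMLIR() (refactored below) is invoked from the toyc driver.
int main(int argc, char **argv) {
  // Make the generic pass-manager flags (e.g. -print-ir-after-all) available.
  mlir::registerPassManagerCLOptions();
  llvm::cl::ParseCommandLineOptions(argc, argv, "toy compiler\n");

  switch (emitAction) {   // `emitAction`/`Action` assumed from the tutorial.
  case Action::DumpAST:
    return dumpAST();
  case Action::DumpMLIR:
    return dumpMLIR();    // Calls loadMLIR(), then optionally canonicalizes.
  default:
    llvm::errs() << "No action specified (parsing only?), use -emit=<action>\n";
  }
  return 0;
}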
@@ -15,8 +15,8 @@
 // limitations under the License.
 // =============================================================================
 //
-// This file implements a simple combiner for optimizing pattern in the Toy
-// dialect.
+// This file implements a set of simple combiners for optimizing operations in
+// the Toy dialect.
 //
 //===----------------------------------------------------------------------===//

@@ -32,7 +32,8 @@ namespace {
 #include "ToyCombine.inc"
 } // end anonymous namespace

-/// Fold transpose(transpose(x) -> transpose(x)
+/// This is an example of a c++ rewrite pattern for the TransposeOp. It
+/// optimizes the following scenario: transpose(transpose(x)) -> transpose(x)
 struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
   /// We register this pattern to match every toy.transpose in the IR.
   /// The "benefit" is used by the framework to order the patterns and process
@ -41,8 +42,8 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
|||
: OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
|
||||
|
||||
/// This method attempts to match a pattern and rewrite it. The rewriter
|
||||
/// argument is the orchestrator of the sequence of rewrites. It is expected
|
||||
/// to interact with it to perform any changes to the IR from here.
|
||||
/// argument is the orchestrator of the sequence of rewrites. The pattern is
|
||||
/// expected to interact with it to perform any changes to the IR from here.
|
||||
mlir::PatternMatchResult
|
||||
matchAndRewrite(TransposeOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
@@ -55,19 +56,21 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
     if (!transposeInputOp)
       return matchFailure();

-    // Use the rewriter to perform the replacement
+    // Use the rewriter to perform the replacement.
     rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
     return matchSuccess();
   }
 };

-/// Register our patterns for rewrite by the Canonicalization framework.
+/// Register our patterns as "canonicalization" patterns on the TransposeOp so
+/// that they can be picked up by the Canonicalization framework.
 void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                               MLIRContext *context) {
   results.insert<SimplifyRedundantTranspose>(context);
 }

-/// Register our patterns for rewrite by the Canonicalization framework.
+/// Register our patterns as "canonicalization" patterns on the ReshapeOp so
+/// that they can be picked up by the Canonicalization framework.
 void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
   results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
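The hunk above begins inside matchAndRewrite(), so the lines that produce transposeInputOp are not shown. For readability, here is an illustrative reconstruction of the whole method; only the lines that appear in the hunk are verbatim, and the operand-fetching lines are an assumption based on the MLIR API of this era (Value pointers, getDefiningOp()).

/// Illustrative reconstruction; the operand lookup is assumed, not quoted.
mlir::PatternMatchResult
matchAndRewrite(TransposeOp op,
                mlir::PatternRewriter &rewriter) const override {
  // Look through the input of the current transpose.
  mlir::Value *transposeInput = op.getOperand();
  TransposeOp transposeInputOp =
      llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());

  // If the input is not defined by another transpose, there is nothing to fold.
  if (!transposeInputOp)
    return matchFailure();

  // Use the rewriter to perform the replacement.
  rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
  return matchSuccess();
}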
@@ -27,12 +27,13 @@
 include "toy/Ops.td"
 #endif // OP_BASE

-/* Pattern-Match and Rewrite using DRR:
-class Pattern<
-   dag sourcePattern, list<dag> resultPatterns,
-   list<dag> additionalConstraints = [],
-   dag benefitsAdded = (addBenefit 0)>;
-*/
+/// Note: The DRR definition used for defining patterns is shown below:
+///
+/// class Pattern<
+///    dag sourcePattern, list<dag> resultPatterns,
+///    list<dag> additionalConstraints = [],
+///    dag benefitsAdded = (addBenefit 0)
+/// >;

 //===----------------------------------------------------------------------===//
 // Basic Pattern-Match and Rewrite
@@ -79,29 +79,24 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
   return parser.ParseModule();
 }

-mlir::LogicalResult optimize(mlir::ModuleOp module) {
-  mlir::PassManager pm(module.getContext());
-  pm.addPass(mlir::createCanonicalizerPass());
-
-  // Apply any generic pass manager command line options and run the pipeline.
-  applyPassManagerCLOptions(pm);
-  return pm.run(module);
-}
+int loadMLIR(mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
+  // Handle '.toy' input to the compiler.
+  if (inputType != InputType::MLIR &&
+      !llvm::StringRef(inputFilename).endswith(".mlir")) {
+    auto moduleAST = parseInputFile(inputFilename);
+    module = mlirGen(context, *moduleAST);
+    return !module ? 1 : 0;
+  }

-int dumpMLIR() {
-  // Register our Dialect with MLIR
-  mlir::registerDialect<mlir::toy::ToyDialect>();
-
-  mlir::MLIRContext context;
-  mlir::OwningModuleRef module;
-  if (inputType == InputType::MLIR ||
-      llvm::StringRef(inputFilename).endswith(".mlir")) {
+  // Otherwise, the input is '.mlir'.
   llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
       llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
   if (std::error_code EC = fileOrErr.getError()) {
     llvm::errs() << "Could not open input file: " << EC.message() << "\n";
     return -1;
   }

   // Parse the input mlir.
   llvm::SourceMgr sourceMgr;
   sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
   module = mlir::parseSourceFile(sourceMgr, &context);
@@ -109,22 +104,29 @@ int dumpMLIR() {
     llvm::errs() << "Error can't load file " << inputFilename << "\n";
     return 3;
   }
-  if (failed(mlir::verify(*module))) {
-    llvm::errs() << "Error verifying MLIR module\n";
-    return 0;
-  }
+  return 0;
+}
+
+int dumpMLIR() {
+  // Register our Dialect with MLIR.
+  mlir::registerDialect<mlir::toy::ToyDialect>();
+
+  mlir::MLIRContext context;
+  mlir::OwningModuleRef module;
+  if (int error = loadMLIR(context, module))
+    return error;
+
+  if (EnableOpt) {
+    mlir::PassManager pm(&context);
+    // Apply any generic pass manager command line options and run the pipeline.
+    applyPassManagerCLOptions(pm);
+
+    // Add a run of the canonicalizer to optimize the mlir module.
+    pm.addPass(mlir::createCanonicalizerPass());
+    if (mlir::failed(pm.run(*module)))
+      return 4;
+  }
-  } else {
-    auto moduleAST = parseInputFile(inputFilename);
-    module = mlirGen(context, *moduleAST);
-  }
-  if (!module)
-    return 1;
-  if (EnableOpt) {
-    if (failed(optimize(*module))) {
-      llvm::errs() << "Module optimization failed\n";
-      return 7;
-    }
-  }

   module->dump();
   return 0;
 }
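A brief illustration of the "easier to extend" claim from the commit message: with loading split into loadMLIR(), later chapters only need to append passes to the pipeline inside dumpMLIR(). The extra passes shown in comments are placeholders, not part of this commit.

// Hypothetical extension point; not part of this commit's diff.
// dumpMLIR() now owns the PassManager, so growing the pipeline is a local edit:
if (EnableOpt) {
  mlir::PassManager pm(&context);
  applyPassManagerCLOptions(pm);

  pm.addPass(mlir::createCanonicalizerPass());
  // A later chapter could simply add more passes here, for example:
  //   pm.addPass(mlir::createCSEPass());        // existing MLIR pass
  //   pm.addPass(createShapeInferencePass());   // hypothetical Toy pass
  if (mlir::failed(pm.run(*module)))
    return 4;
}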