NFC: Various code cleanups for Ch3.

This change refactors the toyc driver to be much cleaner and easier to extend. It also cleans up a few comments in the combiner.

PiperOrigin-RevId: 274973808
This commit is contained in:
River Riddle 2019-10-16 00:33:43 -07:00 committed by A. Unique TensorFlower
parent 950979745a
commit a08482c1ad
3 changed files with 56 additions and 50 deletions

View File

@ -15,8 +15,8 @@
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
// //
// This file implements a simple combiner for optimizing pattern in the Toy // This file implements a set of simple combiners for optimizing operations in
// dialect. // the Toy dialect.
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -32,7 +32,8 @@ namespace {
#include "ToyCombine.inc" #include "ToyCombine.inc"
} // end anonymous namespace } // end anonymous namespace
/// Fold transpose(transpose(x) -> transpose(x) /// This is an example of a c++ rewrite pattern for the TransposeOp. It
/// optimizes the following scenario: transpose(transpose(x)) -> transpose(x)
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> { struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
/// We register this pattern to match every toy.transpose in the IR. /// We register this pattern to match every toy.transpose in the IR.
/// The "benefit" is used by the framework to order the patterns and process /// The "benefit" is used by the framework to order the patterns and process
@ -41,8 +42,8 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
: OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {} : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
/// This method attempts to match a pattern and rewrite it. The rewriter /// This method attempts to match a pattern and rewrite it. The rewriter
/// argument is the orchestrator of the sequence of rewrites. It is expected /// argument is the orchestrator of the sequence of rewrites. The pattern is
/// to interact with it to perform any changes to the IR from here. /// expected to interact with it to perform any changes to the IR from here.
mlir::PatternMatchResult mlir::PatternMatchResult
matchAndRewrite(TransposeOp op, matchAndRewrite(TransposeOp op,
mlir::PatternRewriter &rewriter) const override { mlir::PatternRewriter &rewriter) const override {
@ -55,19 +56,21 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
if (!transposeInputOp) if (!transposeInputOp)
return matchFailure(); return matchFailure();
// Use the rewriter to perform the replacement // Use the rewriter to perform the replacement.
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp}); rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
return matchSuccess(); return matchSuccess();
} }
}; };
/// Register our patterns for rewrite by the Canonicalization framework. /// Register our patterns as "canonicalization" patterns on the TransposeOp so
/// that they can be picked up by the Canonicalization framework.
void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) { MLIRContext *context) {
results.insert<SimplifyRedundantTranspose>(context); results.insert<SimplifyRedundantTranspose>(context);
} }
/// Register our patterns for rewrite by the Canonicalization framework. /// Register our patterns as "canonicalization" patterns on the ReshapeOp so
/// that they can be picked up by the Canonicalization framework.
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) { MLIRContext *context) {
results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern, results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,

View File

@ -27,12 +27,13 @@
include "toy/Ops.td" include "toy/Ops.td"
#endif // OP_BASE #endif // OP_BASE
/* Pattern-Match and Rewrite using DRR: /// Note: The DRR definition used for defining patterns is shown below:
class Pattern< ///
dag sourcePattern, list<dag> resultPatterns, /// class Pattern<
list<dag> additionalConstraints = [], /// dag sourcePattern, list<dag> resultPatterns,
dag benefitsAdded = (addBenefit 0)>; /// list<dag> additionalConstraints = [],
*/ /// dag benefitsAdded = (addBenefit 0)
/// >;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Basic Pattern-Match and Rewrite // Basic Pattern-Match and Rewrite

View File

@ -79,29 +79,24 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
return parser.ParseModule(); return parser.ParseModule();
} }
mlir::LogicalResult optimize(mlir::ModuleOp module) { int loadMLIR(mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
mlir::PassManager pm(module.getContext()); // Handle '.toy' input to the compiler.
pm.addPass(mlir::createCanonicalizerPass()); if (inputType != InputType::MLIR &&
!llvm::StringRef(inputFilename).endswith(".mlir")) {
// Apply any generic pass manager command line options and run the pipeline. auto moduleAST = parseInputFile(inputFilename);
applyPassManagerCLOptions(pm); module = mlirGen(context, *moduleAST);
return pm.run(module); return !module ? 1 : 0;
} }
int dumpMLIR() { // Otherwise, the input is '.mlir'.
// Register our Dialect with MLIR
mlir::registerDialect<mlir::toy::ToyDialect>();
mlir::MLIRContext context;
mlir::OwningModuleRef module;
if (inputType == InputType::MLIR ||
llvm::StringRef(inputFilename).endswith(".mlir")) {
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr = llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
if (std::error_code EC = fileOrErr.getError()) { if (std::error_code EC = fileOrErr.getError()) {
llvm::errs() << "Could not open input file: " << EC.message() << "\n"; llvm::errs() << "Could not open input file: " << EC.message() << "\n";
return -1; return -1;
} }
// Parse the input mlir.
llvm::SourceMgr sourceMgr; llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
module = mlir::parseSourceFile(sourceMgr, &context); module = mlir::parseSourceFile(sourceMgr, &context);
@ -109,22 +104,29 @@ int dumpMLIR() {
llvm::errs() << "Error can't load file " << inputFilename << "\n"; llvm::errs() << "Error can't load file " << inputFilename << "\n";
return 3; return 3;
} }
if (failed(mlir::verify(*module))) { return 0;
llvm::errs() << "Error verifying MLIR module\n"; }
int dumpMLIR() {
// Register our Dialect with MLIR.
mlir::registerDialect<mlir::toy::ToyDialect>();
mlir::MLIRContext context;
mlir::OwningModuleRef module;
if (int error = loadMLIR(context, module))
return error;
if (EnableOpt) {
mlir::PassManager pm(&context);
// Apply any generic pass manager command line options and run the pipeline.
applyPassManagerCLOptions(pm);
// Add a run of the canonicalizer to optimize the mlir module.
pm.addPass(mlir::createCanonicalizerPass());
if (mlir::failed(pm.run(*module)))
return 4; return 4;
} }
} else {
auto moduleAST = parseInputFile(inputFilename);
module = mlirGen(context, *moduleAST);
}
if (!module)
return 1;
if (EnableOpt) {
if (failed(optimize(*module))) {
llvm::errs() << "Module optimization failed\n";
return 7;
}
}
module->dump(); module->dump();
return 0; return 0;
} }