NFC: Various code cleanups for Ch3.

This change refactors the toyc driver to be much cleaner and easier to extend. It also cleans up a few comments in the combiner.

PiperOrigin-RevId: 274973808
This commit is contained in:
River Riddle 2019-10-16 00:33:43 -07:00 committed by A. Unique TensorFlower
parent 950979745a
commit a08482c1ad
3 changed files with 56 additions and 50 deletions

View File

@ -15,8 +15,8 @@
// limitations under the License.
// =============================================================================
//
// This file implements a simple combiner for optimizing pattern in the Toy
// dialect.
// This file implements a set of simple combiners for optimizing operations in
// the Toy dialect.
//
//===----------------------------------------------------------------------===//
@ -32,7 +32,8 @@ namespace {
#include "ToyCombine.inc"
} // end anonymous namespace
/// Fold transpose(transpose(x) -> transpose(x)
/// This is an example of a c++ rewrite pattern for the TransposeOp. It
/// optimizes the following scenario: transpose(transpose(x)) -> transpose(x)
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
/// We register this pattern to match every toy.transpose in the IR.
/// The "benefit" is used by the framework to order the patterns and process
@ -41,8 +42,8 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
: OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
/// This method attempts to match a pattern and rewrite it. The rewriter
/// argument is the orchestrator of the sequence of rewrites. It is expected
/// to interact with it to perform any changes to the IR from here.
/// argument is the orchestrator of the sequence of rewrites. The pattern is
/// expected to interact with it to perform any changes to the IR from here.
mlir::PatternMatchResult
matchAndRewrite(TransposeOp op,
mlir::PatternRewriter &rewriter) const override {
@ -55,19 +56,21 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
if (!transposeInputOp)
return matchFailure();
// Use the rewriter to perform the replacement
// Use the rewriter to perform the replacement.
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
return matchSuccess();
}
};
/// Register our patterns for rewrite by the Canonicalization framework.
/// Register our patterns as "canonicalization" patterns on the TransposeOp so
/// that they can be picked up by the Canonicalization framework.
void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<SimplifyRedundantTranspose>(context);
}
/// Register our patterns for rewrite by the Canonicalization framework.
/// Register our patterns as "canonicalization" patterns on the ReshapeOp so
/// that they can be picked up by the Canonicalization framework.
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,

View File

@ -27,12 +27,13 @@
include "toy/Ops.td"
#endif // OP_BASE
/* Pattern-Match and Rewrite using DRR:
class Pattern<
dag sourcePattern, list<dag> resultPatterns,
list<dag> additionalConstraints = [],
dag benefitsAdded = (addBenefit 0)>;
*/
/// Note: The DRR definition used for defining patterns is shown below:
///
/// class Pattern<
/// dag sourcePattern, list<dag> resultPatterns,
/// list<dag> additionalConstraints = [],
/// dag benefitsAdded = (addBenefit 0)
/// >;
//===----------------------------------------------------------------------===//
// Basic Pattern-Match and Rewrite

View File

@ -79,29 +79,24 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
return parser.ParseModule();
}
mlir::LogicalResult optimize(mlir::ModuleOp module) {
mlir::PassManager pm(module.getContext());
pm.addPass(mlir::createCanonicalizerPass());
// Apply any generic pass manager command line options and run the pipeline.
applyPassManagerCLOptions(pm);
return pm.run(module);
int loadMLIR(mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
// Handle '.toy' input to the compiler.
if (inputType != InputType::MLIR &&
!llvm::StringRef(inputFilename).endswith(".mlir")) {
auto moduleAST = parseInputFile(inputFilename);
module = mlirGen(context, *moduleAST);
return !module ? 1 : 0;
}
int dumpMLIR() {
// Register our Dialect with MLIR
mlir::registerDialect<mlir::toy::ToyDialect>();
mlir::MLIRContext context;
mlir::OwningModuleRef module;
if (inputType == InputType::MLIR ||
llvm::StringRef(inputFilename).endswith(".mlir")) {
// Otherwise, the input is '.mlir'.
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
if (std::error_code EC = fileOrErr.getError()) {
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
return -1;
}
// Parse the input mlir.
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
module = mlir::parseSourceFile(sourceMgr, &context);
@ -109,22 +104,29 @@ int dumpMLIR() {
llvm::errs() << "Error can't load file " << inputFilename << "\n";
return 3;
}
if (failed(mlir::verify(*module))) {
llvm::errs() << "Error verifying MLIR module\n";
return 0;
}
int dumpMLIR() {
// Register our Dialect with MLIR.
mlir::registerDialect<mlir::toy::ToyDialect>();
mlir::MLIRContext context;
mlir::OwningModuleRef module;
if (int error = loadMLIR(context, module))
return error;
if (EnableOpt) {
mlir::PassManager pm(&context);
// Apply any generic pass manager command line options and run the pipeline.
applyPassManagerCLOptions(pm);
// Add a run of the canonicalizer to optimize the mlir module.
pm.addPass(mlir::createCanonicalizerPass());
if (mlir::failed(pm.run(*module)))
return 4;
}
} else {
auto moduleAST = parseInputFile(inputFilename);
module = mlirGen(context, *moduleAST);
}
if (!module)
return 1;
if (EnableOpt) {
if (failed(optimize(*module))) {
llvm::errs() << "Module optimization failed\n";
return 7;
}
}
module->dump();
return 0;
}