forked from OSchip/llvm-project
NFC: Various code cleanups for Ch3.
This change refactors the toyc driver to be much cleaner and easier to extend. It also cleans up a few comments in the combiner. PiperOrigin-RevId: 274973808
This commit is contained in:
parent
950979745a
commit
a08482c1ad
|
@ -15,8 +15,8 @@
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
//
|
//
|
||||||
// This file implements a simple combiner for optimizing pattern in the Toy
|
// This file implements a set of simple combiners for optimizing operations in
|
||||||
// dialect.
|
// the Toy dialect.
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
@ -32,7 +32,8 @@ namespace {
|
||||||
#include "ToyCombine.inc"
|
#include "ToyCombine.inc"
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
|
||||||
/// Fold transpose(transpose(x) -> transpose(x)
|
/// This is an example of a c++ rewrite pattern for the TransposeOp. It
|
||||||
|
/// optimizes the following scenario: transpose(transpose(x)) -> transpose(x)
|
||||||
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
||||||
/// We register this pattern to match every toy.transpose in the IR.
|
/// We register this pattern to match every toy.transpose in the IR.
|
||||||
/// The "benefit" is used by the framework to order the patterns and process
|
/// The "benefit" is used by the framework to order the patterns and process
|
||||||
|
@ -41,8 +42,8 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
||||||
: OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
|
: OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
|
||||||
|
|
||||||
/// This method attempts to match a pattern and rewrite it. The rewriter
|
/// This method attempts to match a pattern and rewrite it. The rewriter
|
||||||
/// argument is the orchestrator of the sequence of rewrites. It is expected
|
/// argument is the orchestrator of the sequence of rewrites. The pattern is
|
||||||
/// to interact with it to perform any changes to the IR from here.
|
/// expected to interact with it to perform any changes to the IR from here.
|
||||||
mlir::PatternMatchResult
|
mlir::PatternMatchResult
|
||||||
matchAndRewrite(TransposeOp op,
|
matchAndRewrite(TransposeOp op,
|
||||||
mlir::PatternRewriter &rewriter) const override {
|
mlir::PatternRewriter &rewriter) const override {
|
||||||
|
@ -55,19 +56,21 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
||||||
if (!transposeInputOp)
|
if (!transposeInputOp)
|
||||||
return matchFailure();
|
return matchFailure();
|
||||||
|
|
||||||
// Use the rewriter to perform the replacement
|
// Use the rewriter to perform the replacement.
|
||||||
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
|
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Register our patterns for rewrite by the Canonicalization framework.
|
/// Register our patterns as "canonicalization" patterns on the TransposeOp so
|
||||||
|
/// that they can be picked up by the Canonicalization framework.
|
||||||
void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
results.insert<SimplifyRedundantTranspose>(context);
|
results.insert<SimplifyRedundantTranspose>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Register our patterns for rewrite by the Canonicalization framework.
|
/// Register our patterns as "canonicalization" patterns on the ReshapeOp so
|
||||||
|
/// that they can be picked up by the Canonicalization framework.
|
||||||
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
|
results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
|
||||||
|
|
|
@ -27,12 +27,13 @@
|
||||||
include "toy/Ops.td"
|
include "toy/Ops.td"
|
||||||
#endif // OP_BASE
|
#endif // OP_BASE
|
||||||
|
|
||||||
/* Pattern-Match and Rewrite using DRR:
|
/// Note: The DRR definition used for defining patterns is shown below:
|
||||||
class Pattern<
|
///
|
||||||
dag sourcePattern, list<dag> resultPatterns,
|
/// class Pattern<
|
||||||
list<dag> additionalConstraints = [],
|
/// dag sourcePattern, list<dag> resultPatterns,
|
||||||
dag benefitsAdded = (addBenefit 0)>;
|
/// list<dag> additionalConstraints = [],
|
||||||
*/
|
/// dag benefitsAdded = (addBenefit 0)
|
||||||
|
/// >;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Basic Pattern-Match and Rewrite
|
// Basic Pattern-Match and Rewrite
|
||||||
|
|
|
@ -79,29 +79,24 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
|
||||||
return parser.ParseModule();
|
return parser.ParseModule();
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::LogicalResult optimize(mlir::ModuleOp module) {
|
int loadMLIR(mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
|
||||||
mlir::PassManager pm(module.getContext());
|
// Handle '.toy' input to the compiler.
|
||||||
pm.addPass(mlir::createCanonicalizerPass());
|
if (inputType != InputType::MLIR &&
|
||||||
|
!llvm::StringRef(inputFilename).endswith(".mlir")) {
|
||||||
// Apply any generic pass manager command line options and run the pipeline.
|
auto moduleAST = parseInputFile(inputFilename);
|
||||||
applyPassManagerCLOptions(pm);
|
module = mlirGen(context, *moduleAST);
|
||||||
return pm.run(module);
|
return !module ? 1 : 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
int dumpMLIR() {
|
// Otherwise, the input is '.mlir'.
|
||||||
// Register our Dialect with MLIR
|
|
||||||
mlir::registerDialect<mlir::toy::ToyDialect>();
|
|
||||||
|
|
||||||
mlir::MLIRContext context;
|
|
||||||
mlir::OwningModuleRef module;
|
|
||||||
if (inputType == InputType::MLIR ||
|
|
||||||
llvm::StringRef(inputFilename).endswith(".mlir")) {
|
|
||||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
||||||
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
|
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
|
||||||
if (std::error_code EC = fileOrErr.getError()) {
|
if (std::error_code EC = fileOrErr.getError()) {
|
||||||
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
|
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Parse the input mlir.
|
||||||
llvm::SourceMgr sourceMgr;
|
llvm::SourceMgr sourceMgr;
|
||||||
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
|
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
|
||||||
module = mlir::parseSourceFile(sourceMgr, &context);
|
module = mlir::parseSourceFile(sourceMgr, &context);
|
||||||
|
@ -109,22 +104,29 @@ int dumpMLIR() {
|
||||||
llvm::errs() << "Error can't load file " << inputFilename << "\n";
|
llvm::errs() << "Error can't load file " << inputFilename << "\n";
|
||||||
return 3;
|
return 3;
|
||||||
}
|
}
|
||||||
if (failed(mlir::verify(*module))) {
|
return 0;
|
||||||
llvm::errs() << "Error verifying MLIR module\n";
|
}
|
||||||
|
|
||||||
|
int dumpMLIR() {
|
||||||
|
// Register our Dialect with MLIR.
|
||||||
|
mlir::registerDialect<mlir::toy::ToyDialect>();
|
||||||
|
|
||||||
|
mlir::MLIRContext context;
|
||||||
|
mlir::OwningModuleRef module;
|
||||||
|
if (int error = loadMLIR(context, module))
|
||||||
|
return error;
|
||||||
|
|
||||||
|
if (EnableOpt) {
|
||||||
|
mlir::PassManager pm(&context);
|
||||||
|
// Apply any generic pass manager command line options and run the pipeline.
|
||||||
|
applyPassManagerCLOptions(pm);
|
||||||
|
|
||||||
|
// Add a run of the canonicalizer to optimize the mlir module.
|
||||||
|
pm.addPass(mlir::createCanonicalizerPass());
|
||||||
|
if (mlir::failed(pm.run(*module)))
|
||||||
return 4;
|
return 4;
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
auto moduleAST = parseInputFile(inputFilename);
|
|
||||||
module = mlirGen(context, *moduleAST);
|
|
||||||
}
|
|
||||||
if (!module)
|
|
||||||
return 1;
|
|
||||||
if (EnableOpt) {
|
|
||||||
if (failed(optimize(*module))) {
|
|
||||||
llvm::errs() << "Module optimization failed\n";
|
|
||||||
return 7;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
module->dump();
|
module->dump();
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue