llvm-project/mlir/tools/mlir-opt/mlir-opt.cpp

230 lines
7.9 KiB
C++
Raw Normal View History

//===- mlir-opt.cpp - MLIR Optimizer Driver -------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This is a command line utility that parses an MLIR file, runs an optimization
// pass, then prints the result back out. It is designed to support unit
// testing.
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/Parser.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/ConvertToCFG.h"
#include "mlir/Transforms/Loop.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
using namespace mlir;
using namespace llvm;
// Positional input file; the default "-" reads from stdin (see main).
static cl::opt<std::string> inputFilename(cl::Positional,
                                          cl::desc("<input file>"),
                                          cl::init("-"));

// Destination for the printed module; defaults to "-" (see getOutputStream).
static cl::opt<std::string> outputFilename("o",
                                           cl::desc("Output filename"),
                                           cl::value_desc("filename"),
                                           cl::init("-"));

// When set, main runs splitMemoryBufferForErrorChecking instead of
// parseAndPrintMemoryBuffer.
static cl::opt<bool> checkParserErrors("check-parser-errors",
                                       cl::desc("Check for parser errors"),
                                       cl::init(false));

// Transformations applied by parseAndPrintMemoryBuffer before printing.
static cl::opt<bool> convertToCFGOpt(
    "convert-to-cfg",
    cl::desc("Convert all ML functions in the module to CFG ones"));
static cl::opt<bool> unrollInnermostLoops("unroll-innermost-loops",
                                          cl::desc("Unroll innermost loops"),
                                          cl::init(false));
/// Outcome of a driver step. main() returns these enumerators directly as the
/// process exit status, so OptSuccess must remain 0 and OptFailure non-zero.
enum OptResult { OptSuccess = 0, OptFailure = 1 };
/// Open the specified output file and return it, exiting if there is any I/O or
/// other errors.
static std::unique_ptr<ToolOutputFile> getOutputStream() {
std::error_code error;
auto result = make_unique<ToolOutputFile>(outputFilename, error,
sys::fs::F_None);
if (error) {
llvm::errs() << error.message() << '\n';
exit(1);
}
return result;
}
/// Parses the memory buffer and, if successfully parsed, prints the parsed
/// output. Optionally converts ML functions into CFG functions
/// (-convert-to-cfg) and/or unrolls innermost loops (-unroll-innermost-loops)
/// before printing.
///
/// Returns OptFailure if parsing fails (nothing is printed in that case) and
/// OptSuccess otherwise. If the output file cannot be opened, getOutputStream
/// terminates the process.
/// TODO: pull parsing and printing into separate functions.
OptResult parseAndPrintMemoryBuffer(std::unique_ptr<MemoryBuffer> buffer) {
  // Tell sourceMgr about this buffer, which is what the parser will pick up.
  SourceMgr sourceMgr;
  sourceMgr.AddNewSourceBuffer(std::move(buffer), SMLoc());

  // Parse the input file.
  MLIRContext context;
  std::unique_ptr<Module> module(parseSourceFile(sourceMgr, &context));
  if (!module)
    return OptFailure;

  // Convert ML functions into CFG functions.
  if (convertToCFGOpt)
    convertToCFG(module.get());

  if (unrollInnermostLoops) {
    // The factory returns a raw pointer that the original code never freed.
    // Own it so the pass is released once it has run.
    // NOTE(review): assumes the pass is heap-allocated and that Pass /
    // MLFunctionPass has a virtual destructor — confirm in mlir/Pass.h.
    std::unique_ptr<MLFunctionPass> loopUnroll(createLoopUnrollPass());
    loopUnroll->runOnModule(module.get());
  }

  // Print the output; keep() prevents ToolOutputFile from deleting the file
  // when it goes out of scope.
  auto output = getOutputStream();
  module->print(output->os());
  output->keep();
  return OptSuccess;
}
/// Split the memory buffer into multiple buffers using the marker "-----" and
/// check each piece's parser diagnostics against its annotations.
///
/// Each piece is parsed independently. A line containing
/// `expected-error {{substring}}` declares that the parser must report an
/// error on that same line whose message contains `substring`.
/// `expected-error@+N {{...}}` and `expected-error@-N {{...}}` target the line
/// N below or above the annotation instead. A reported error that matches no
/// annotation, and an annotation that no reported error matched, are both
/// printed as errors located in the original (unsplit) buffer.
///
/// Returns OptSuccess when reported and expected errors agree in every piece,
/// OptFailure otherwise.
OptResult
splitMemoryBufferForErrorChecking(std::unique_ptr<MemoryBuffer> buffer) {
  const char marker[] = "-----";
  SmallVector<StringRef, 2> sourceBuffers;
  buffer->getBuffer().split(sourceBuffers, marker);

  // TODO: Only checking for error cases below. Could be expanded to other kinds
  // of diagnostics.
  // TODO: Currently only checking if substring matches, enable regex checking.
  OptResult opt_result = OptSuccess;
  SourceMgr fileSourceMgr;
  // fileSourceMgr now owns the buffer. Moving the unique_ptr does not move the
  // underlying characters, so the StringRefs in sourceBuffers stay valid.
  fileSourceMgr.AddNewSourceBuffer(std::move(buffer), SMLoc());

  // Record the expected error's position, substring and whether it was seen.
  struct ExpectedError {
    int lineNo;
    StringRef substring;
    SMLoc fileLoc;
    bool matched;
  };

  // Group 1: optional "@+N" / "@-N" line offset; group 2: expected substring.
  // The pattern is the same for every sub-buffer, so compile it only once.
  llvm::Regex expected("expected-error(@[+-][0-9]+)? *{{(.*)}}");

  // Tracks offset of subbuffer into original buffer.
  const char *fileOffset =
      fileSourceMgr.getMemoryBuffer(fileSourceMgr.getMainFileID())
          ->getBufferStart();
  for (auto &subbuffer : sourceBuffers) {
    SourceMgr sourceMgr;
    // Tell sourceMgr about this buffer, which is what the parser will pick up.
    sourceMgr.AddNewSourceBuffer(MemoryBuffer::getMemBufferCopy(subbuffer),
                                 SMLoc());

    // Extract the expected errors.
    SmallVector<ExpectedError, 2> expectedErrors;
    SmallVector<StringRef, 100> lines;
    subbuffer.split(lines, '\n');
    size_t bufOffset = 0;
    for (size_t lineIdx = 0, numLines = lines.size(); lineIdx != numLines;
         ++lineIdx) {
      StringRef line = lines[lineIdx];
      SmallVector<StringRef, 3> matches;
      if (expected.match(line, &matches)) {
        // Point to the start of the expected substring. Regex::match returns
        // StringRefs into `line`, so this column is exact even when text
        // (e.g. a '\r' from CRLF line endings) follows the closing "}}".
        size_t column = matches[2].data() - line.data();
        SMLoc errorStart =
            SMLoc::getFromPointer(fileOffset + bufOffset + column);
        ExpectedError expErr{static_cast<int>(lineIdx) + 1, matches[2],
                             errorStart, false};

        // Apply the optional "@+N" / "@-N" line offset.
        // StringRef::getAsInteger rejects a leading '+', which used to make
        // "@+N" silently ignored. Handle the sign here and parse only the
        // digits.
        StringRef offsetSpec = matches[1];
        if (!offsetSpec.empty()) {
          offsetSpec = offsetSpec.drop_front(); // Drop the '@'.
          bool isNegative = offsetSpec.front() == '-';
          int offset;
          if (!offsetSpec.drop_front().getAsInteger(10, offset))
            expErr.lineNo += isNegative ? -offset : offset;
        }
        expectedErrors.push_back(expErr);
      }
      bufOffset += line.size() + 1;
    }

    // Error checker that verifies each reported error was expected.
    auto checker = [&](const SMDiagnostic &err) {
      for (auto &e : expectedErrors) {
        if (err.getLineNo() == e.lineNo &&
            err.getMessage().contains(e.substring)) {
          e.matched = true;
          return;
        }
      }
      // Report error if no match found, translating its location from the
      // per-piece copy back into the original file buffer.
      const auto &sourceMgr = *err.getSourceMgr();
      const char *bufferStart =
          sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID())
              ->getBufferStart();
      size_t offset = err.getLoc().getPointer() - bufferStart;
      SMLoc loc = SMLoc::getFromPointer(fileOffset + offset);
      fileSourceMgr.PrintMessage(loc, SourceMgr::DK_Error,
                                 "unexpected error: " + err.getMessage());
      opt_result = OptFailure;
    };

    // Parse the input file.
    MLIRContext context;
    std::unique_ptr<Module> module(
        parseSourceFile(sourceMgr, &context, checker));

    // Verify that all expected errors were seen.
    for (const auto &err : expectedErrors) {
      if (!err.matched) {
        SMRange range(err.fileLoc,
                      SMLoc::getFromPointer(err.fileLoc.getPointer() +
                                            err.substring.size()));
        fileSourceMgr.PrintMessage(
            err.fileLoc, SourceMgr::DK_Error,
            "expected error \"" + err.substring + "\" was not produced", range);
        opt_result = OptFailure;
      }
    }
    // The next piece starts right after this one and the separating marker.
    fileOffset += subbuffer.size() + strlen(marker);
  }
  return opt_result;
}
/// Driver entry point: parses the command line, reads the input file (or
/// stdin when the input is "-"), then either verifies expected parser errors
/// (-check-parser-errors) or parses, optionally transforms, and prints it.
int main(int argc, char **argv) {
  InitLLVM llvmInit(argc, argv);
  cl::ParseCommandLineOptions(argc, argv, "MLIR modular optimizer driver\n");

  // Set up the input file.
  auto inputOrErr = MemoryBuffer::getFileOrSTDIN(inputFilename);
  if (std::error_code readError = inputOrErr.getError()) {
    llvm::errs() << argv[0] << ": could not open input file '" << inputFilename
                 << "': " << readError.message() << "\n";
    return 1;
  }

  // OptResult converts implicitly to the exit status (OptSuccess == 0).
  return checkParserErrors
             ? splitMemoryBufferForErrorChecking(std::move(*inputOrErr))
             : parseAndPrintMemoryBuffer(std::move(*inputOrErr));
}