diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 33de538919d6..4238bb8b8b14 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -94,47 +94,93 @@ bool splitMemoryBufferForErrorChecking(std::unique_ptr buffer) { // TODO: Enable specifying errors on different lines (@-1). // TODO: Currently only checking if substring matches, enable regex checking. bool failed = false; - SMDiagnosticHandlerTy errorChecker = [&failed](llvm::SMDiagnostic err) { + SourceMgr fileSourceMgr; + fileSourceMgr.AddNewSourceBuffer(std::move(buffer), SMLoc()); + + // Tracks offset of subbuffer into original buffer. + const char *fileOffset = + fileSourceMgr.getMemoryBuffer(fileSourceMgr.getMainFileID()) + ->getBufferStart(); + + // Create error checker that uses the helper function to relate the reported + // error to the file being parsed. + SMDiagnosticHandlerTy checker = [&](SMDiagnostic err) { + const auto &sourceMgr = *err.getSourceMgr(); + const char *bufferStart = + sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID())->getBufferStart(); + StringRef line = err.getLineContents(); + size_t offset = err.getLoc().getPointer() - bufferStart; + SMLoc loc = SMLoc::getFromPointer(fileOffset + offset); + // Extract expected substring using regex and check simple containment in // error message. llvm::Regex expected("expected-error {{(.*)}}"); SmallVector matches; bool matched = expected.match(line, &matches); if (matches.size() != 2) { - const auto& sourceMgr = *err.getSourceMgr(); - sourceMgr.PrintMessage(err.getLoc(), SourceMgr::DK_Error, - "unexpected error: " + err.getMessage()); + fileSourceMgr.PrintMessage( + loc, SourceMgr::DK_Error, + "unexpected error: " + err.getMessage()); failed = true; return; } matched = err.getMessage().contains(matches[1]); if (!matched) { - llvm::errs() << "Expected error substring (" << matches[1] - << ") not found in error `" << err.getMessage() << "`.\n"; + const char checkPrefix[] = "expected-error {{"; + loc = SMLoc::getFromPointer(fileOffset + offset + line.find(checkPrefix) - + err.getColumnNo() + strlen(checkPrefix)); + fileSourceMgr.PrintMessage(loc, SourceMgr::DK_Error, + "\"" + err.getMessage() + + "\" did not contain expected substring \"" + + matches[1] + "\""); failed = true; + return; } }; for (auto& subbuffer : sourceBuffers) { + SourceMgr sourceMgr; + // Tell sourceMgr about this buffer, which is what the parser will pick up. + sourceMgr.AddNewSourceBuffer(MemoryBuffer::getMemBufferCopy(subbuffer), + SMLoc()); + int expectedCount = subbuffer.count("expected-error"); if (expectedCount > 1) { - llvm::errs() << "Unable to verify more than 1 error per group.\n"; + size_t expectedOffset = subbuffer.find("expected-error"); + expectedOffset = subbuffer.find("expected-error", expectedOffset); + SMLoc loc = SMLoc::getFromPointer(fileOffset + expectedOffset); + fileSourceMgr.PrintMessage(loc, SourceMgr::DK_Error, + "too many errors expected: unable to verify " + "more than one error per group"); + fileOffset += subbuffer.size() + strlen(marker); failed = true; continue; } - bool parsed = parseAndPrintMemoryBuffer( - MemoryBuffer::getMemBufferCopy(subbuffer), errorChecker); + // Parse the input file. + MLIRContext context; + std::unique_ptr module( + parseSourceFile(sourceMgr, &context, checker)); + bool parsed = module != nullptr; if (parsed && expectedCount != 0) { - llvm::Regex expected("expected-error {{(.*)}}"); + llvm::Regex expected(".*expected-error {{(.*)}}"); SmallVector matches; expected.match(subbuffer, &matches); - llvm::errs() << "Expected an error (" << matches[1] - << ") but no error reported.\n"; + + // Highlight expected-error clause of unexpectedly passing test case. + size_t expectedOffset = subbuffer.find("expected-error"); + size_t endOffset = matches[0].size(); + SMLoc loc = SMLoc::getFromPointer(fileOffset + expectedOffset); + SMRange range(loc, SMLoc::getFromPointer(fileOffset + endOffset)); + fileSourceMgr.PrintMessage( + loc, SourceMgr::DK_Error, + "expected error \"" + matches[1] + "\" was not produced", range); failed = true; } + + fileOffset += subbuffer.size() + strlen(marker); } return !failed;