forked from OSchip/llvm-project
Revert "Add indented raw_ostream class"
This reverts commit 78530ce653
.
Fails on shared_lib build.
This commit is contained in:
parent
b82a7486d1
commit
be185b6a73
|
@ -1,102 +0,0 @@
|
|||
//===- IndentedOstream.h - raw ostream wrapper to indent --------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// raw_ostream subclass that keeps track of indentation for textual output
|
||||
// where indentation helps readability.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_SUPPORT_INDENTEDOSTREAM_H_
|
||||
#define MLIR_SUPPORT_INDENTEDOSTREAM_H_
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
/// raw_ostream subclass that simplifies indention a sequence of code.
|
||||
class raw_indented_ostream : public raw_ostream {
|
||||
public:
|
||||
explicit raw_indented_ostream(llvm::raw_ostream &os) : os(os) {
|
||||
SetUnbuffered();
|
||||
}
|
||||
|
||||
/// Simple RAII struct to use to indentation around entering/exiting region.
|
||||
struct DelimitedScope {
|
||||
explicit DelimitedScope(raw_indented_ostream &os, StringRef open = "",
|
||||
StringRef close = "")
|
||||
: os(os), open(open), close(close) {
|
||||
os << open;
|
||||
os.indent();
|
||||
}
|
||||
~DelimitedScope() {
|
||||
os.unindent();
|
||||
os << close;
|
||||
}
|
||||
|
||||
raw_indented_ostream &os;
|
||||
|
||||
private:
|
||||
llvm::StringRef open, close;
|
||||
};
|
||||
|
||||
/// Returns DelimitedScope.
|
||||
DelimitedScope scope(StringRef open = "", StringRef close = "") {
|
||||
return DelimitedScope(*this, open, close);
|
||||
}
|
||||
|
||||
/// Re-indents by removing the leading whitespace from the first non-empty
|
||||
/// line from every line of the the string, skipping over empty lines at the
|
||||
/// start.
|
||||
raw_indented_ostream &reindent(StringRef str);
|
||||
|
||||
/// Increases the indent and returning this raw_indented_ostream.
|
||||
raw_indented_ostream &indent() {
|
||||
currentIndent += indentSize;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Decreases the indent and returning this raw_indented_ostream.
|
||||
raw_indented_ostream &unindent() {
|
||||
currentIndent = std::max(0, currentIndent - indentSize);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Emits whitespace and sets the indendation for the stream.
|
||||
raw_indented_ostream &indent(int with) {
|
||||
os.indent(with);
|
||||
atStartOfLine = false;
|
||||
currentIndent = with;
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
void write_impl(const char *ptr, size_t size) override;
|
||||
|
||||
/// Return the current position within the stream, not counting the bytes
|
||||
/// currently in the buffer.
|
||||
uint64_t current_pos() const override { return os.tell(); }
|
||||
|
||||
/// Constant indent added/removed.
|
||||
static constexpr int indentSize = 2;
|
||||
|
||||
// Tracker for current indentation.
|
||||
int currentIndent = 0;
|
||||
|
||||
// The leading whitespace of the string being printed, if reindent is used.
|
||||
int leadingWs = 0;
|
||||
|
||||
// Tracks whether at start of line and so indent is required or not.
|
||||
bool atStartOfLine = true;
|
||||
|
||||
// The underlying raw_ostream.
|
||||
raw_ostream &os;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
#endif // MLIR_SUPPORT_INDENTEDOSTREAM_H_
|
|
@ -1,6 +1,5 @@
|
|||
set(LLVM_OPTIONAL_SOURCES
|
||||
FileUtilities.cpp
|
||||
IndentedOstream.cpp
|
||||
MlirOptMain.cpp
|
||||
StorageUniquer.cpp
|
||||
ToolUtilities.cpp
|
||||
|
@ -28,10 +27,3 @@ add_mlir_library(MLIROptLib
|
|||
MLIRParser
|
||||
MLIRSupport
|
||||
)
|
||||
|
||||
# This doesn't use add_mlir_library as it is used in mlir-tblgen and else
|
||||
# mlir-tblgen ends up depending on mlir-generic-headers.
|
||||
add_llvm_library(MLIRSupportIdentedOstream
|
||||
IndentedOstream.cpp
|
||||
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Support)
|
||||
|
|
|
@ -1,65 +0,0 @@
|
|||
//===- IndentedOstream.cpp - raw ostream wrapper to indent ----------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// raw_ostream subclass that keeps track of indentation for textual output
|
||||
// where indentation helps readability.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Support/IndentedOstream.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
raw_indented_ostream &mlir::raw_indented_ostream::reindent(StringRef str) {
|
||||
StringRef remaining = str;
|
||||
// Find leading whitespace indent.
|
||||
while (!remaining.empty()) {
|
||||
auto split = remaining.split('\n');
|
||||
size_t indent = split.first.find_first_not_of(" \t");
|
||||
if (indent != StringRef::npos) {
|
||||
leadingWs = indent;
|
||||
break;
|
||||
}
|
||||
remaining = split.second;
|
||||
}
|
||||
// Print, skipping the empty lines.
|
||||
*this << remaining;
|
||||
leadingWs = 0;
|
||||
return *this;
|
||||
}
|
||||
|
||||
void mlir::raw_indented_ostream::write_impl(const char *ptr, size_t size) {
|
||||
StringRef str(ptr, size);
|
||||
// Print out indented.
|
||||
auto print = [this](StringRef str) {
|
||||
if (atStartOfLine)
|
||||
os.indent(currentIndent) << str.substr(leadingWs);
|
||||
else
|
||||
os << str.substr(leadingWs);
|
||||
};
|
||||
|
||||
while (!str.empty()) {
|
||||
size_t idx = str.find('\n');
|
||||
if (idx == StringRef::npos) {
|
||||
if (!str.substr(leadingWs).empty()) {
|
||||
print(str);
|
||||
atStartOfLine = false;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
auto split =
|
||||
std::make_pair(str.slice(0, idx), str.slice(idx + 1, StringRef::npos));
|
||||
// Print empty new line without spaces if line only has spaces.
|
||||
if (!split.first.ltrim().empty())
|
||||
print(split.first);
|
||||
os << '\n';
|
||||
atStartOfLine = true;
|
||||
str = split.second;
|
||||
}
|
||||
}
|
|
@ -25,7 +25,6 @@ add_tablegen(mlir-tblgen MLIR
|
|||
set_target_properties(mlir-tblgen PROPERTIES FOLDER "Tablegenning")
|
||||
target_link_libraries(mlir-tblgen
|
||||
PRIVATE
|
||||
MLIRSupportIdentedOstream
|
||||
MLIRTableGen)
|
||||
|
||||
mlir_check_all_link_libraries(mlir-tblgen)
|
||||
|
|
|
@ -12,7 +12,6 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "DocGenUtilities.h"
|
||||
#include "mlir/Support/IndentedOstream.h"
|
||||
#include "mlir/TableGen/GenInfo.h"
|
||||
#include "mlir/TableGen/Operator.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
@ -36,8 +35,39 @@ using mlir::tblgen::Operator;
|
|||
// in a way the user wanted but has some additional indenting due to being
|
||||
// nested in the op definition.
|
||||
void mlir::tblgen::emitDescription(StringRef description, raw_ostream &os) {
|
||||
raw_indented_ostream ros(os);
|
||||
ros.reindent(description.rtrim(" \t"));
|
||||
// Determine the minimum number of spaces in a line.
|
||||
size_t min_indent = -1;
|
||||
StringRef remaining = description;
|
||||
while (!remaining.empty()) {
|
||||
auto split = remaining.split('\n');
|
||||
size_t indent = split.first.find_first_not_of(" \t");
|
||||
if (indent != StringRef::npos)
|
||||
min_indent = std::min(indent, min_indent);
|
||||
remaining = split.second;
|
||||
}
|
||||
|
||||
// Print out the description indented.
|
||||
os << "\n";
|
||||
remaining = description;
|
||||
bool printed = false;
|
||||
while (!remaining.empty()) {
|
||||
auto split = remaining.split('\n');
|
||||
if (split.second.empty()) {
|
||||
// Skip last line with just spaces.
|
||||
if (split.first.ltrim().empty())
|
||||
break;
|
||||
}
|
||||
// Print empty new line without spaces if line only has spaces, unless no
|
||||
// text has been emitted before.
|
||||
if (split.first.ltrim().empty()) {
|
||||
if (printed)
|
||||
os << "\n";
|
||||
} else {
|
||||
os << split.first.substr(min_indent) << "\n";
|
||||
printed = true;
|
||||
}
|
||||
remaining = split.second;
|
||||
}
|
||||
}
|
||||
|
||||
// Emits `str` with trailing newline if not empty.
|
||||
|
@ -86,7 +116,7 @@ static void emitOpDoc(Operator op, raw_ostream &os) {
|
|||
|
||||
// Emit the summary, syntax, and description if present.
|
||||
if (op.hasSummary())
|
||||
os << "\n" << op.getSummary() << "\n\n";
|
||||
os << "\n" << op.getSummary() << "\n";
|
||||
if (op.hasAssemblyFormat())
|
||||
emitAssemblyFormat(op.getOperationName(), op.getAssemblyFormat().trim(),
|
||||
os);
|
||||
|
@ -198,7 +228,7 @@ static void emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
|||
}
|
||||
|
||||
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
|
||||
for (const auto &dialectWithOps : dialectOps)
|
||||
for (auto dialectWithOps : dialectOps)
|
||||
emitDialectDoc(dialectWithOps.first, dialectWithOps.second,
|
||||
dialectTypes[dialectWithOps.first], os);
|
||||
}
|
||||
|
|
|
@ -10,7 +10,6 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Support/IndentedOstream.h"
|
||||
#include "mlir/TableGen/Attribute.h"
|
||||
#include "mlir/TableGen/Format.h"
|
||||
#include "mlir/TableGen/GenInfo.h"
|
||||
|
@ -78,11 +77,11 @@ private:
|
|||
|
||||
// Emits C++ statements for matching the `argIndex`-th argument of the given
|
||||
// DAG `tree` as an operand.
|
||||
void emitOperandMatch(DagNode tree, int argIndex, int depth);
|
||||
void emitOperandMatch(DagNode tree, int argIndex, int depth, int indent);
|
||||
|
||||
// Emits C++ statements for matching the `argIndex`-th argument of the given
|
||||
// DAG `tree` as an attribute.
|
||||
void emitAttributeMatch(DagNode tree, int argIndex, int depth);
|
||||
void emitAttributeMatch(DagNode tree, int argIndex, int depth, int indent);
|
||||
|
||||
// Emits C++ for checking a match with a corresponding match failure
|
||||
// diagnostic.
|
||||
|
@ -185,7 +184,7 @@ private:
|
|||
// The next unused ID for newly created values.
|
||||
unsigned nextValueId;
|
||||
|
||||
raw_indented_ostream os;
|
||||
raw_ostream &os;
|
||||
|
||||
// Format contexts containing placeholder substitutions.
|
||||
FmtContext fmtCtx;
|
||||
|
@ -226,7 +225,8 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
|
|||
// Skip the operand matching at depth 0 as the pattern rewriter already does.
|
||||
if (depth != 0) {
|
||||
// Skip if there is no defining operation (e.g., arguments to function).
|
||||
os << formatv("if (!castedOp{0})\n return failure();\n", depth);
|
||||
os.indent(indent) << formatv("if (!castedOp{0}) return failure();\n",
|
||||
depth);
|
||||
}
|
||||
if (tree.getNumArgs() != op.getNumArgs()) {
|
||||
PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
|
||||
|
@ -238,7 +238,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
|
|||
// If the operand's name is set, set to that variable.
|
||||
auto name = tree.getSymbol();
|
||||
if (!name.empty())
|
||||
os << formatv("{0} = castedOp{1};\n", name, depth);
|
||||
os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth);
|
||||
|
||||
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
|
||||
auto opArg = op.getArg(i);
|
||||
|
@ -253,23 +253,24 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
|
|||
PrintFatalError(loc, error);
|
||||
}
|
||||
}
|
||||
os << "{\n";
|
||||
os.indent(indent) << "{\n";
|
||||
|
||||
os.indent() << formatv(
|
||||
os.indent(indent + 2) << formatv(
|
||||
"auto *op{0} = "
|
||||
"(*castedOp{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
|
||||
depth + 1, depth, i);
|
||||
emitOpMatch(argTree, depth + 1);
|
||||
os << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
|
||||
os.unindent() << "}\n";
|
||||
os.indent(indent + 2)
|
||||
<< formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
|
||||
os.indent(indent) << "}\n";
|
||||
continue;
|
||||
}
|
||||
|
||||
// Next handle DAG leaf: operand or attribute
|
||||
if (opArg.is<NamedTypeConstraint *>()) {
|
||||
emitOperandMatch(tree, i, depth);
|
||||
emitOperandMatch(tree, i, depth, indent);
|
||||
} else if (opArg.is<NamedAttribute *>()) {
|
||||
emitAttributeMatch(tree, i, depth);
|
||||
emitAttributeMatch(tree, i, depth, indent);
|
||||
} else {
|
||||
PrintFatalError(loc, "unhandled case when matching op");
|
||||
}
|
||||
|
@ -279,7 +280,8 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
|
|||
<< '\n');
|
||||
}
|
||||
|
||||
void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) {
|
||||
void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth,
|
||||
int indent) {
|
||||
Operator &op = tree.getDialectOp(opMap);
|
||||
auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
|
||||
auto matcher = tree.getArgAsLeaf(argIndex);
|
||||
|
@ -326,28 +328,30 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) {
|
|||
op.arg_begin(), op.arg_begin() + argIndex,
|
||||
[](const Argument &arg) { return arg.is<NamedAttribute *>(); });
|
||||
|
||||
os << formatv("{0} = castedOp{1}.getODSOperands({2});\n", name, depth,
|
||||
argIndex - numPrevAttrs);
|
||||
os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
|
||||
name, depth, argIndex - numPrevAttrs);
|
||||
}
|
||||
}
|
||||
|
||||
void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
|
||||
void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
|
||||
int indent) {
|
||||
Operator &op = tree.getDialectOp(opMap);
|
||||
auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
|
||||
const auto &attr = namedAttr->attr;
|
||||
|
||||
os << "{\n";
|
||||
os.indent() << formatv(
|
||||
"auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\"); "
|
||||
os.indent(indent) << "{\n";
|
||||
indent += 2;
|
||||
os.indent(indent) << formatv(
|
||||
"auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");"
|
||||
"(void)tblgen_attr;\n",
|
||||
depth, attr.getStorageType(), namedAttr->name);
|
||||
|
||||
// TODO: This should use getter method to avoid duplication.
|
||||
if (attr.hasDefaultValue()) {
|
||||
os << "if (!tblgen_attr) tblgen_attr = "
|
||||
<< std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
|
||||
attr.getDefaultValue()))
|
||||
<< ";\n";
|
||||
os.indent(indent) << "if (!tblgen_attr) tblgen_attr = "
|
||||
<< std::string(tgfmt(attr.getConstBuilderTemplate(),
|
||||
&fmtCtx, attr.getDefaultValue()))
|
||||
<< ";\n";
|
||||
} else if (attr.isOptional()) {
|
||||
// For a missing attribute that is optional according to definition, we
|
||||
// should just capture a mlir::Attribute() to signal the missing state.
|
||||
|
@ -383,20 +387,27 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
|
|||
auto name = tree.getArgName(argIndex);
|
||||
// `$_` is a special symbol to ignore op argument matching.
|
||||
if (!name.empty() && name != "_") {
|
||||
os << formatv("{0} = tblgen_attr;\n", name);
|
||||
os.indent(indent) << formatv("{0} = tblgen_attr;\n", name);
|
||||
}
|
||||
|
||||
os.unindent() << "}\n";
|
||||
indent -= 2;
|
||||
os.indent(indent) << "}\n";
|
||||
}
|
||||
|
||||
void PatternEmitter::emitMatchCheck(
|
||||
int depth, const FmtObjectBase &matchFmt,
|
||||
const llvm::formatv_object_base &failureFmt) {
|
||||
os << "if (!(" << matchFmt.str() << "))";
|
||||
os.scope("{\n", "\n}\n").os
|
||||
<< "return rewriter.notifyMatchFailure(op" << depth
|
||||
<< ", [&](::mlir::Diagnostic &diag) {\n diag << " << failureFmt.str()
|
||||
<< ";\n});";
|
||||
// {0} The match depth (used to get the operation that failed to match).
|
||||
// {1} The format for the match string.
|
||||
// {2} The format for the failure string.
|
||||
const char *matchStr = R"(
|
||||
if (!({1})) {
|
||||
return rewriter.notifyMatchFailure(op{0}, [&](::mlir::Diagnostic &diag) {
|
||||
diag << {2};
|
||||
});
|
||||
})";
|
||||
os << llvm::formatv(matchStr, depth, matchFmt.str(), failureFmt.str())
|
||||
<< "\n";
|
||||
}
|
||||
|
||||
void PatternEmitter::emitMatchLogic(DagNode tree) {
|
||||
|
@ -480,7 +491,7 @@ void PatternEmitter::emit(StringRef rewriteName) {
|
|||
|
||||
// Emit RewritePattern for Pattern.
|
||||
auto locs = pattern.getLocation();
|
||||
os << formatv("/* Generated from:\n {0:$[ instantiating\n ]}\n*/\n",
|
||||
os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
|
||||
make_range(locs.rbegin(), locs.rend()));
|
||||
os << formatv(R"(struct {0} : public ::mlir::RewritePattern {
|
||||
{0}(::mlir::MLIRContext *context)
|
||||
|
@ -498,48 +509,44 @@ void PatternEmitter::emit(StringRef rewriteName) {
|
|||
os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
|
||||
|
||||
// Emit matchAndRewrite() function.
|
||||
{
|
||||
auto classScope = os.scope();
|
||||
os.reindent(R"(
|
||||
::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0,
|
||||
::mlir::PatternRewriter &rewriter) const override {)")
|
||||
<< '\n';
|
||||
{
|
||||
auto functionScope = os.scope();
|
||||
os << R"(
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(::mlir::Operation *op0,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
)";
|
||||
|
||||
// Register all symbols bound in the source pattern.
|
||||
pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
|
||||
// Register all symbols bound in the source pattern.
|
||||
pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
|
||||
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "start creating local variables for capturing matches\n");
|
||||
os << "// Variables for capturing values and attributes used while "
|
||||
"creating ops\n";
|
||||
// Create local variables for storing the arguments and results bound
|
||||
// to symbols.
|
||||
for (const auto &symbolInfoPair : symbolInfoMap) {
|
||||
StringRef symbol = symbolInfoPair.getKey();
|
||||
auto &info = symbolInfoPair.getValue();
|
||||
os << info.getVarDecl(symbol);
|
||||
}
|
||||
// TODO: capture ops with consistent numbering so that it can be
|
||||
// reused for fused loc.
|
||||
os << formatv("::mlir::Operation *tblgen_ops[{0}];\n\n",
|
||||
pattern.getSourcePattern().getNumOps());
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "done creating local variables for capturing matches\n");
|
||||
|
||||
os << "// Match\n";
|
||||
os << "tblgen_ops[0] = op0;\n";
|
||||
emitMatchLogic(sourceTree);
|
||||
|
||||
os << "\n// Rewrite\n";
|
||||
emitRewriteLogic();
|
||||
|
||||
os << "return success();\n";
|
||||
}
|
||||
os << "};\n";
|
||||
LLVM_DEBUG(
|
||||
llvm::dbgs() << "start creating local variables for capturing matches\n");
|
||||
os.indent(4) << "// Variables for capturing values and attributes used for "
|
||||
"creating ops\n";
|
||||
// Create local variables for storing the arguments and results bound
|
||||
// to symbols.
|
||||
for (const auto &symbolInfoPair : symbolInfoMap) {
|
||||
StringRef symbol = symbolInfoPair.getKey();
|
||||
auto &info = symbolInfoPair.getValue();
|
||||
os.indent(4) << info.getVarDecl(symbol);
|
||||
}
|
||||
os << "};\n\n";
|
||||
// TODO: capture ops with consistent numbering so that it can be
|
||||
// reused for fused loc.
|
||||
os.indent(4) << formatv("::mlir::Operation *tblgen_ops[{0}];\n\n",
|
||||
pattern.getSourcePattern().getNumOps());
|
||||
LLVM_DEBUG(
|
||||
llvm::dbgs() << "done creating local variables for capturing matches\n");
|
||||
|
||||
os.indent(4) << "// Match\n";
|
||||
os.indent(4) << "tblgen_ops[0] = op0;\n";
|
||||
emitMatchLogic(sourceTree);
|
||||
os << "\n";
|
||||
|
||||
os.indent(4) << "// Rewrite\n";
|
||||
emitRewriteLogic();
|
||||
|
||||
os.indent(4) << "return success();\n";
|
||||
os << " };\n";
|
||||
os << "};\n";
|
||||
}
|
||||
|
||||
void PatternEmitter::emitRewriteLogic() {
|
||||
|
@ -579,7 +586,7 @@ void PatternEmitter::emitRewriteLogic() {
|
|||
PrintFatalError(loc, error);
|
||||
}
|
||||
|
||||
os << "auto odsLoc = rewriter.getFusedLoc({";
|
||||
os.indent(4) << "auto odsLoc = rewriter.getFusedLoc({";
|
||||
for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
|
||||
os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
|
||||
}
|
||||
|
@ -594,21 +601,22 @@ void PatternEmitter::emitRewriteLogic() {
|
|||
// we are handling auxiliary patterns so we want the side effect even if
|
||||
// NativeCodeCall is not replacing matched root op's results.
|
||||
if (resultTree.isNativeCodeCall())
|
||||
os << val << ";\n";
|
||||
os.indent(4) << val << ";\n";
|
||||
}
|
||||
|
||||
if (numExpectedResults == 0) {
|
||||
assert(replStartIndex >= numResultPatterns &&
|
||||
"invalid auxiliary vs. replacement pattern division!");
|
||||
// No result to replace. Just erase the op.
|
||||
os << "rewriter.eraseOp(op0);\n";
|
||||
os.indent(4) << "rewriter.eraseOp(op0);\n";
|
||||
} else {
|
||||
// Process replacement result patterns.
|
||||
os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
|
||||
os.indent(4)
|
||||
<< "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
|
||||
for (int i = replStartIndex; i < numResultPatterns; ++i) {
|
||||
DagNode resultTree = pattern.getResultPattern(i);
|
||||
auto val = handleResultPattern(resultTree, offsets[i], 0);
|
||||
os << "\n";
|
||||
os.indent(4) << "\n";
|
||||
// Resolve each symbol for all range use so that we can loop over them.
|
||||
// We need an explicit cast to `SmallVector` to capture the cases where
|
||||
// `{0}` resolves to an `Operation::result_range` as well as cases that
|
||||
|
@ -617,11 +625,12 @@ void PatternEmitter::emitRewriteLogic() {
|
|||
// TODO: Revisit the need for materializing a vector.
|
||||
os << symbolInfoMap.getAllRangeUse(
|
||||
val,
|
||||
"for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
|
||||
" tblgen_repl_values.push_back(v);\n}\n",
|
||||
" for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{ "
|
||||
"tblgen_repl_values.push_back(v); }",
|
||||
"\n");
|
||||
}
|
||||
os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
|
||||
os.indent(4) << "\n";
|
||||
os.indent(4) << "rewriter.replaceOp(op0, tblgen_repl_values);\n";
|
||||
}
|
||||
|
||||
LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
|
||||
|
@ -870,8 +879,9 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
|
|||
}
|
||||
|
||||
// Create the local variable for this op.
|
||||
os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(),
|
||||
valuePackName);
|
||||
os.indent(4) << formatv("{0} {1};\n", resultOp.getQualCppClassName(),
|
||||
valuePackName);
|
||||
os.indent(4) << "{\n";
|
||||
|
||||
// Right now ODS don't have general type inference support. Except a few
|
||||
// special cases listed below, DRR needs to supply types for all results
|
||||
|
@ -890,9 +900,10 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
|
|||
createAggregateLocalVarsForOpArgs(tree, childNodeNames);
|
||||
|
||||
// Then create the op.
|
||||
os.scope("", "\n}\n").os << formatv(
|
||||
"{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
|
||||
os.indent(6) << formatv(
|
||||
"{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);\n",
|
||||
valuePackName, resultOp.getQualCppClassName(), locToUse);
|
||||
os.indent(4) << "}\n";
|
||||
return resultValue;
|
||||
}
|
||||
|
||||
|
@ -909,10 +920,11 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
|
|||
// aggregate-parameter builders.
|
||||
createSeparateLocalVarsForOpArgs(tree, childNodeNames);
|
||||
|
||||
os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
|
||||
resultOp.getQualCppClassName(), locToUse);
|
||||
os.indent(6) << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
|
||||
resultOp.getQualCppClassName(), locToUse);
|
||||
supplyValuesForOpArgs(tree, childNodeNames);
|
||||
os << "\n );\n}\n";
|
||||
os << "\n );\n";
|
||||
os.indent(4) << "}\n";
|
||||
return resultValue;
|
||||
}
|
||||
|
||||
|
@ -926,19 +938,20 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
|
|||
|
||||
// Then prepare the result types. We need to specify the types for all
|
||||
// results.
|
||||
os.indent() << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; "
|
||||
"(void)tblgen_types;\n");
|
||||
os.indent(6) << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; "
|
||||
"(void)tblgen_types;\n");
|
||||
int numResults = resultOp.getNumResults();
|
||||
if (numResults != 0) {
|
||||
for (int i = 0; i < numResults; ++i)
|
||||
os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n"
|
||||
" tblgen_types.push_back(v.getType());\n}\n",
|
||||
resultIndex + i);
|
||||
os.indent(6) << formatv("for (auto v : castedOp0.getODSResults({0})) {{"
|
||||
"tblgen_types.push_back(v.getType()); }\n",
|
||||
resultIndex + i);
|
||||
}
|
||||
os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
|
||||
"tblgen_values, tblgen_attrs);\n",
|
||||
valuePackName, resultOp.getQualCppClassName(), locToUse);
|
||||
os.unindent() << "}\n";
|
||||
os.indent(6) << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
|
||||
"tblgen_values, tblgen_attrs);\n",
|
||||
valuePackName, resultOp.getQualCppClassName(),
|
||||
locToUse);
|
||||
os.indent(4) << "}\n";
|
||||
return resultValue;
|
||||
}
|
||||
|
||||
|
@ -955,15 +968,16 @@ void PatternEmitter::createSeparateLocalVarsForOpArgs(
|
|||
for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
|
||||
const auto *operand =
|
||||
resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
|
||||
// We do not need special handling for attributes.
|
||||
if (!operand)
|
||||
if (!operand) {
|
||||
// We do not need special handling for attributes.
|
||||
continue;
|
||||
}
|
||||
|
||||
raw_indented_ostream::DelimitedScope scope(os);
|
||||
std::string varName;
|
||||
if (operand->isVariadic()) {
|
||||
varName = std::string(formatv("tblgen_values_{0}", valueIndex++));
|
||||
os << formatv("::mlir::SmallVector<::mlir::Value, 4> {0};\n", varName);
|
||||
os.indent(6) << formatv("::mlir::SmallVector<::mlir::Value, 4> {0};\n",
|
||||
varName);
|
||||
std::string range;
|
||||
if (node.isNestedDagArg(argIndex)) {
|
||||
range = childNodeNames[argIndex];
|
||||
|
@ -973,11 +987,11 @@ void PatternEmitter::createSeparateLocalVarsForOpArgs(
|
|||
// Resolve the symbol for all range use so that we have a uniform way of
|
||||
// capturing the values.
|
||||
range = symbolInfoMap.getValueAndRangeUse(range);
|
||||
os << formatv("for (auto v: {0}) {{\n {1}.push_back(v);\n}\n", range,
|
||||
varName);
|
||||
os.indent(6) << formatv("for (auto v : {0}) {1}.push_back(v);\n", range,
|
||||
varName);
|
||||
} else {
|
||||
varName = std::string(formatv("tblgen_value_{0}", valueIndex++));
|
||||
os << formatv("::mlir::Value {0} = ", varName);
|
||||
os.indent(6) << formatv("::mlir::Value {0} = ", varName);
|
||||
if (node.isNestedDagArg(argIndex)) {
|
||||
os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
|
||||
} else {
|
||||
|
@ -1005,7 +1019,7 @@ void PatternEmitter::supplyValuesForOpArgs(
|
|||
for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
|
||||
argIndex != numOpArgs; ++argIndex) {
|
||||
// Start each argument on its own line.
|
||||
os << ",\n ";
|
||||
(os << ",\n").indent(8);
|
||||
|
||||
Argument opArg = resultOp.getArg(argIndex);
|
||||
// Handle the case of operand first.
|
||||
|
@ -1046,16 +1060,14 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
|
|||
DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
|
||||
Operator &resultOp = node.getDialectOp(opMap);
|
||||
|
||||
auto scope = os.scope();
|
||||
os << formatv("::mlir::SmallVector<::mlir::Value, 4> "
|
||||
"tblgen_values; (void)tblgen_values;\n");
|
||||
os << formatv("::mlir::SmallVector<::mlir::NamedAttribute, 4> "
|
||||
"tblgen_attrs; (void)tblgen_attrs;\n");
|
||||
os.indent(6) << formatv("::mlir::SmallVector<::mlir::Value, 4> "
|
||||
"tblgen_values; (void)tblgen_values;\n");
|
||||
os.indent(6) << formatv("::mlir::SmallVector<::mlir::NamedAttribute, 4> "
|
||||
"tblgen_attrs; (void)tblgen_attrs;\n");
|
||||
|
||||
const char *addAttrCmd =
|
||||
"if (auto tmpAttr = {1}) {\n"
|
||||
" tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), "
|
||||
"tmpAttr);\n}\n";
|
||||
"if (auto tmpAttr = {1}) "
|
||||
"tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), tmpAttr);\n";
|
||||
for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
|
||||
if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
|
||||
// The argument in the op definition.
|
||||
|
@ -1064,14 +1076,14 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
|
|||
if (!subTree.isNativeCodeCall())
|
||||
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
|
||||
"for creating attribute");
|
||||
os << formatv(addAttrCmd, opArgName,
|
||||
handleReplaceWithNativeCodeCall(subTree));
|
||||
os.indent(6) << formatv(addAttrCmd, opArgName,
|
||||
handleReplaceWithNativeCodeCall(subTree));
|
||||
} else {
|
||||
auto leaf = node.getArgAsLeaf(argIndex);
|
||||
// The argument in the result DAG pattern.
|
||||
auto patArgName = node.getArgName(argIndex);
|
||||
os << formatv(addAttrCmd, opArgName,
|
||||
handleOpArgument(leaf, patArgName));
|
||||
os.indent(6) << formatv(addAttrCmd, opArgName,
|
||||
handleOpArgument(leaf, patArgName));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
@ -1089,10 +1101,10 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
|
|||
// Resolve the symbol for all range use so that we have a uniform way of
|
||||
// capturing the values.
|
||||
range = symbolInfoMap.getValueAndRangeUse(range);
|
||||
os << formatv("for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n",
|
||||
range);
|
||||
os.indent(6) << formatv(
|
||||
"for (auto v : {0}) tblgen_values.push_back(v);\n", range);
|
||||
} else {
|
||||
os << formatv("tblgen_values.push_back(", varName);
|
||||
os.indent(6) << formatv("tblgen_values.push_back(", varName);
|
||||
if (node.isNestedDagArg(argIndex)) {
|
||||
os << symbolInfoMap.getValueAndRangeUse(
|
||||
childNodeNames.lookup(argIndex));
|
||||
|
|
|
@ -1,6 +0,0 @@
|
|||
add_mlir_unittest(MLIRSupportTests
|
||||
IndentedOstreamTest.cpp
|
||||
)
|
||||
|
||||
target_link_libraries(MLIRSupportTests
|
||||
PRIVATE MLIRSupportIdentedOstream MLIRSupport)
|
|
@ -1,110 +0,0 @@
|
|||
//===- IndentedOstreamTest.cpp - Indented raw ostream Tests ---------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Support/IndentedOstream.h"
|
||||
#include "gmock/gmock.h"
|
||||
|
||||
using namespace mlir;
|
||||
using ::testing::StrEq;
|
||||
|
||||
TEST(FormatTest, SingleLine) {
|
||||
std::string str;
|
||||
llvm::raw_string_ostream os(str);
|
||||
raw_indented_ostream ros(os);
|
||||
ros << 10;
|
||||
ros.flush();
|
||||
EXPECT_THAT(os.str(), StrEq("10"));
|
||||
}
|
||||
|
||||
TEST(FormatTest, SimpleMultiLine) {
|
||||
std::string str;
|
||||
llvm::raw_string_ostream os(str);
|
||||
raw_indented_ostream ros(os);
|
||||
ros << "a";
|
||||
ros << "b";
|
||||
ros << "\n";
|
||||
ros << "c";
|
||||
ros << "\n";
|
||||
ros.flush();
|
||||
EXPECT_THAT(os.str(), StrEq("ab\nc\n"));
|
||||
}
|
||||
|
||||
TEST(FormatTest, SimpleMultiLineIndent) {
|
||||
std::string str;
|
||||
llvm::raw_string_ostream os(str);
|
||||
raw_indented_ostream ros(os);
|
||||
ros.indent(2) << "a";
|
||||
ros.indent(4) << "b";
|
||||
ros << "\n";
|
||||
ros << "c";
|
||||
ros << "\n";
|
||||
ros.flush();
|
||||
EXPECT_THAT(os.str(), StrEq(" a b\n c\n"));
|
||||
}
|
||||
|
||||
TEST(FormatTest, SingleRegion) {
|
||||
std::string str;
|
||||
llvm::raw_string_ostream os(str);
|
||||
raw_indented_ostream ros(os);
|
||||
ros << "before\n";
|
||||
{
|
||||
raw_indented_ostream::DelimitedScope scope(ros);
|
||||
ros << "inside " << 10;
|
||||
ros << "\n two\n";
|
||||
{
|
||||
raw_indented_ostream::DelimitedScope scope(ros, "{\n", "\n}\n");
|
||||
ros << "inner inner";
|
||||
}
|
||||
}
|
||||
ros << "after";
|
||||
ros.flush();
|
||||
const auto *expected =
|
||||
R"(before
|
||||
inside 10
|
||||
two
|
||||
{
|
||||
inner inner
|
||||
}
|
||||
after)";
|
||||
EXPECT_THAT(os.str(), StrEq(expected));
|
||||
|
||||
// Repeat the above with inline form.
|
||||
str.clear();
|
||||
ros << "before\n";
|
||||
ros.scope().os << "inside " << 10 << "\n two\n";
|
||||
ros.scope().os.scope("{\n", "\n}\n").os << "inner inner";
|
||||
ros << "after";
|
||||
ros.flush();
|
||||
EXPECT_THAT(os.str(), StrEq(expected));
|
||||
}
|
||||
|
||||
TEST(FormatTest, Reindent) {
|
||||
std::string str;
|
||||
llvm::raw_string_ostream os(str);
|
||||
raw_indented_ostream ros(os);
|
||||
|
||||
// String to print with some additional empty lines at the start and lines
|
||||
// with just spaces.
|
||||
const auto *desc = R"(
|
||||
|
||||
|
||||
First line
|
||||
second line
|
||||
|
||||
|
||||
)";
|
||||
ros.reindent(desc);
|
||||
ros.flush();
|
||||
const auto *expected =
|
||||
R"(First line
|
||||
second line
|
||||
|
||||
|
||||
)";
|
||||
EXPECT_THAT(os.str(), StrEq(expected));
|
||||
}
|
Loading…
Reference in New Issue