Add indented raw_ostream class

Class simplifies keeping track of the indentation while emitting. For every new line the current indentation is simply prefixed (if not at start of line, then it just emits as normal). Add a simple Region helper that makes it easy to have the C++ scope match the emitted scope.

Use this in op doc generator and rewrite generator.

Differential Revision: https://reviews.llvm.org/D84107
This commit is contained in:
Jacques Pienaar 2020-10-03 08:53:43 -07:00
parent 7feafa0286
commit 78530ce653
8 changed files with 414 additions and 164 deletions

View File

@ -0,0 +1,102 @@
//===- IndentedOstream.h - raw ostream wrapper to indent --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// raw_ostream subclass that keeps track of indentation for textual output
// where indentation helps readability.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_SUPPORT_INDENTEDOSTREAM_H_
#define MLIR_SUPPORT_INDENTEDOSTREAM_H_
#include "mlir/Support/LLVM.h"
#include "llvm/Support/raw_ostream.h"
namespace mlir {
/// raw_ostream subclass that simplifies indention a sequence of code.
class raw_indented_ostream : public raw_ostream {
public:
explicit raw_indented_ostream(llvm::raw_ostream &os) : os(os) {
SetUnbuffered();
}
/// Simple RAII struct to use to indentation around entering/exiting region.
struct DelimitedScope {
explicit DelimitedScope(raw_indented_ostream &os, StringRef open = "",
StringRef close = "")
: os(os), open(open), close(close) {
os << open;
os.indent();
}
~DelimitedScope() {
os.unindent();
os << close;
}
raw_indented_ostream &os;
private:
llvm::StringRef open, close;
};
/// Returns DelimitedScope.
DelimitedScope scope(StringRef open = "", StringRef close = "") {
return DelimitedScope(*this, open, close);
}
/// Re-indents by removing the leading whitespace from the first non-empty
/// line from every line of the the string, skipping over empty lines at the
/// start.
raw_indented_ostream &reindent(StringRef str);
/// Increases the indent and returning this raw_indented_ostream.
raw_indented_ostream &indent() {
currentIndent += indentSize;
return *this;
}
/// Decreases the indent and returning this raw_indented_ostream.
raw_indented_ostream &unindent() {
currentIndent = std::max(0, currentIndent - indentSize);
return *this;
}
/// Emits whitespace and sets the indendation for the stream.
raw_indented_ostream &indent(int with) {
os.indent(with);
atStartOfLine = false;
currentIndent = with;
return *this;
}
private:
void write_impl(const char *ptr, size_t size) override;
/// Return the current position within the stream, not counting the bytes
/// currently in the buffer.
uint64_t current_pos() const override { return os.tell(); }
/// Constant indent added/removed.
static constexpr int indentSize = 2;
// Tracker for current indentation.
int currentIndent = 0;
// The leading whitespace of the string being printed, if reindent is used.
int leadingWs = 0;
// Tracks whether at start of line and so indent is required or not.
bool atStartOfLine = true;
// The underlying raw_ostream.
raw_ostream &os;
};
} // namespace mlir
#endif // MLIR_SUPPORT_INDENTEDOSTREAM_H_

View File

@ -1,5 +1,6 @@
set(LLVM_OPTIONAL_SOURCES
FileUtilities.cpp
IndentedOstream.cpp
MlirOptMain.cpp
StorageUniquer.cpp
ToolUtilities.cpp
@ -27,3 +28,10 @@ add_mlir_library(MLIROptLib
MLIRParser
MLIRSupport
)
# This doesn't use add_mlir_library as it is used in mlir-tblgen and else
# mlir-tblgen ends up depending on mlir-generic-headers.
add_llvm_library(MLIRSupportIdentedOstream
IndentedOstream.cpp
${MLIR_MAIN_INCLUDE_DIR}/mlir/Support)

View File

@ -0,0 +1,65 @@
//===- IndentedOstream.cpp - raw ostream wrapper to indent ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// raw_ostream subclass that keeps track of indentation for textual output
// where indentation helps readability.
//
//===----------------------------------------------------------------------===//
#include "mlir/Support/IndentedOstream.h"
using namespace mlir;
raw_indented_ostream &mlir::raw_indented_ostream::reindent(StringRef str) {
StringRef remaining = str;
// Find leading whitespace indent.
while (!remaining.empty()) {
auto split = remaining.split('\n');
size_t indent = split.first.find_first_not_of(" \t");
if (indent != StringRef::npos) {
leadingWs = indent;
break;
}
remaining = split.second;
}
// Print, skipping the empty lines.
*this << remaining;
leadingWs = 0;
return *this;
}
void mlir::raw_indented_ostream::write_impl(const char *ptr, size_t size) {
StringRef str(ptr, size);
// Print out indented.
auto print = [this](StringRef str) {
if (atStartOfLine)
os.indent(currentIndent) << str.substr(leadingWs);
else
os << str.substr(leadingWs);
};
while (!str.empty()) {
size_t idx = str.find('\n');
if (idx == StringRef::npos) {
if (!str.substr(leadingWs).empty()) {
print(str);
atStartOfLine = false;
}
break;
}
auto split =
std::make_pair(str.slice(0, idx), str.slice(idx + 1, StringRef::npos));
// Print empty new line without spaces if line only has spaces.
if (!split.first.ltrim().empty())
print(split.first);
os << '\n';
atStartOfLine = true;
str = split.second;
}
}

View File

@ -25,6 +25,7 @@ add_tablegen(mlir-tblgen MLIR
set_target_properties(mlir-tblgen PROPERTIES FOLDER "Tablegenning")
target_link_libraries(mlir-tblgen
PRIVATE
MLIRSupportIdentedOstream
MLIRTableGen)
mlir_check_all_link_libraries(mlir-tblgen)

View File

@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "DocGenUtilities.h"
#include "mlir/Support/IndentedOstream.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/DenseMap.h"
@ -35,39 +36,8 @@ using mlir::tblgen::Operator;
// in a way the user wanted but has some additional indenting due to being
// nested in the op definition.
void mlir::tblgen::emitDescription(StringRef description, raw_ostream &os) {
// Determine the minimum number of spaces in a line.
size_t min_indent = -1;
StringRef remaining = description;
while (!remaining.empty()) {
auto split = remaining.split('\n');
size_t indent = split.first.find_first_not_of(" \t");
if (indent != StringRef::npos)
min_indent = std::min(indent, min_indent);
remaining = split.second;
}
// Print out the description indented.
os << "\n";
remaining = description;
bool printed = false;
while (!remaining.empty()) {
auto split = remaining.split('\n');
if (split.second.empty()) {
// Skip last line with just spaces.
if (split.first.ltrim().empty())
break;
}
// Print empty new line without spaces if line only has spaces, unless no
// text has been emitted before.
if (split.first.ltrim().empty()) {
if (printed)
os << "\n";
} else {
os << split.first.substr(min_indent) << "\n";
printed = true;
}
remaining = split.second;
}
raw_indented_ostream ros(os);
ros.reindent(description.rtrim(" \t"));
}
// Emits `str` with trailing newline if not empty.
@ -116,7 +86,7 @@ static void emitOpDoc(Operator op, raw_ostream &os) {
// Emit the summary, syntax, and description if present.
if (op.hasSummary())
os << "\n" << op.getSummary() << "\n";
os << "\n" << op.getSummary() << "\n\n";
if (op.hasAssemblyFormat())
emitAssemblyFormat(op.getOperationName(), op.getAssemblyFormat().trim(),
os);
@ -228,7 +198,7 @@ static void emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
}
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
for (auto dialectWithOps : dialectOps)
for (const auto &dialectWithOps : dialectOps)
emitDialectDoc(dialectWithOps.first, dialectWithOps.second,
dialectTypes[dialectWithOps.first], os);
}

View File

@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Support/IndentedOstream.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
@ -77,11 +78,11 @@ private:
// Emits C++ statements for matching the `argIndex`-th argument of the given
// DAG `tree` as an operand.
void emitOperandMatch(DagNode tree, int argIndex, int depth, int indent);
void emitOperandMatch(DagNode tree, int argIndex, int depth);
// Emits C++ statements for matching the `argIndex`-th argument of the given
// DAG `tree` as an attribute.
void emitAttributeMatch(DagNode tree, int argIndex, int depth, int indent);
void emitAttributeMatch(DagNode tree, int argIndex, int depth);
// Emits C++ for checking a match with a corresponding match failure
// diagnostic.
@ -184,7 +185,7 @@ private:
// The next unused ID for newly created values.
unsigned nextValueId;
raw_ostream &os;
raw_indented_ostream os;
// Format contexts containing placeholder substitutions.
FmtContext fmtCtx;
@ -225,8 +226,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
// Skip the operand matching at depth 0 as the pattern rewriter already does.
if (depth != 0) {
// Skip if there is no defining operation (e.g., arguments to function).
os.indent(indent) << formatv("if (!castedOp{0}) return failure();\n",
depth);
os << formatv("if (!castedOp{0})\n return failure();\n", depth);
}
if (tree.getNumArgs() != op.getNumArgs()) {
PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
@ -238,7 +238,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
// If the operand's name is set, set to that variable.
auto name = tree.getSymbol();
if (!name.empty())
os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth);
os << formatv("{0} = castedOp{1};\n", name, depth);
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
auto opArg = op.getArg(i);
@ -253,24 +253,23 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
PrintFatalError(loc, error);
}
}
os.indent(indent) << "{\n";
os << "{\n";
os.indent(indent + 2) << formatv(
os.indent() << formatv(
"auto *op{0} = "
"(*castedOp{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
depth + 1, depth, i);
emitOpMatch(argTree, depth + 1);
os.indent(indent + 2)
<< formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
os.indent(indent) << "}\n";
os << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
os.unindent() << "}\n";
continue;
}
// Next handle DAG leaf: operand or attribute
if (opArg.is<NamedTypeConstraint *>()) {
emitOperandMatch(tree, i, depth, indent);
emitOperandMatch(tree, i, depth);
} else if (opArg.is<NamedAttribute *>()) {
emitAttributeMatch(tree, i, depth, indent);
emitAttributeMatch(tree, i, depth);
} else {
PrintFatalError(loc, "unhandled case when matching op");
}
@ -280,8 +279,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
<< '\n');
}
void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth,
int indent) {
void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) {
Operator &op = tree.getDialectOp(opMap);
auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
auto matcher = tree.getArgAsLeaf(argIndex);
@ -328,30 +326,28 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth,
op.arg_begin(), op.arg_begin() + argIndex,
[](const Argument &arg) { return arg.is<NamedAttribute *>(); });
os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
name, depth, argIndex - numPrevAttrs);
os << formatv("{0} = castedOp{1}.getODSOperands({2});\n", name, depth,
argIndex - numPrevAttrs);
}
}
void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
int indent) {
void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
Operator &op = tree.getDialectOp(opMap);
auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
const auto &attr = namedAttr->attr;
os.indent(indent) << "{\n";
indent += 2;
os.indent(indent) << formatv(
"auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");"
os << "{\n";
os.indent() << formatv(
"auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\"); "
"(void)tblgen_attr;\n",
depth, attr.getStorageType(), namedAttr->name);
// TODO: This should use getter method to avoid duplication.
if (attr.hasDefaultValue()) {
os.indent(indent) << "if (!tblgen_attr) tblgen_attr = "
<< std::string(tgfmt(attr.getConstBuilderTemplate(),
&fmtCtx, attr.getDefaultValue()))
<< ";\n";
os << "if (!tblgen_attr) tblgen_attr = "
<< std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
attr.getDefaultValue()))
<< ";\n";
} else if (attr.isOptional()) {
// For a missing attribute that is optional according to definition, we
// should just capture a mlir::Attribute() to signal the missing state.
@ -387,27 +383,20 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
auto name = tree.getArgName(argIndex);
// `$_` is a special symbol to ignore op argument matching.
if (!name.empty() && name != "_") {
os.indent(indent) << formatv("{0} = tblgen_attr;\n", name);
os << formatv("{0} = tblgen_attr;\n", name);
}
indent -= 2;
os.indent(indent) << "}\n";
os.unindent() << "}\n";
}
void PatternEmitter::emitMatchCheck(
int depth, const FmtObjectBase &matchFmt,
const llvm::formatv_object_base &failureFmt) {
// {0} The match depth (used to get the operation that failed to match).
// {1} The format for the match string.
// {2} The format for the failure string.
const char *matchStr = R"(
if (!({1})) {
return rewriter.notifyMatchFailure(op{0}, [&](::mlir::Diagnostic &diag) {
diag << {2};
});
})";
os << llvm::formatv(matchStr, depth, matchFmt.str(), failureFmt.str())
<< "\n";
os << "if (!(" << matchFmt.str() << "))";
os.scope("{\n", "\n}\n").os
<< "return rewriter.notifyMatchFailure(op" << depth
<< ", [&](::mlir::Diagnostic &diag) {\n diag << " << failureFmt.str()
<< ";\n});";
}
void PatternEmitter::emitMatchLogic(DagNode tree) {
@ -491,7 +480,7 @@ void PatternEmitter::emit(StringRef rewriteName) {
// Emit RewritePattern for Pattern.
auto locs = pattern.getLocation();
os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
os << formatv("/* Generated from:\n {0:$[ instantiating\n ]}\n*/\n",
make_range(locs.rbegin(), locs.rend()));
os << formatv(R"(struct {0} : public ::mlir::RewritePattern {
{0}(::mlir::MLIRContext *context)
@ -509,44 +498,48 @@ void PatternEmitter::emit(StringRef rewriteName) {
os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
// Emit matchAndRewrite() function.
os << R"(
::mlir::LogicalResult
matchAndRewrite(::mlir::Operation *op0,
::mlir::PatternRewriter &rewriter) const override {
)";
{
auto classScope = os.scope();
os.reindent(R"(
::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0,
::mlir::PatternRewriter &rewriter) const override {)")
<< '\n';
{
auto functionScope = os.scope();
// Register all symbols bound in the source pattern.
pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
// Register all symbols bound in the source pattern.
pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
LLVM_DEBUG(
llvm::dbgs() << "start creating local variables for capturing matches\n");
os.indent(4) << "// Variables for capturing values and attributes used for "
"creating ops\n";
// Create local variables for storing the arguments and results bound
// to symbols.
for (const auto &symbolInfoPair : symbolInfoMap) {
StringRef symbol = symbolInfoPair.getKey();
auto &info = symbolInfoPair.getValue();
os.indent(4) << info.getVarDecl(symbol);
LLVM_DEBUG(llvm::dbgs()
<< "start creating local variables for capturing matches\n");
os << "// Variables for capturing values and attributes used while "
"creating ops\n";
// Create local variables for storing the arguments and results bound
// to symbols.
for (const auto &symbolInfoPair : symbolInfoMap) {
StringRef symbol = symbolInfoPair.getKey();
auto &info = symbolInfoPair.getValue();
os << info.getVarDecl(symbol);
}
// TODO: capture ops with consistent numbering so that it can be
// reused for fused loc.
os << formatv("::mlir::Operation *tblgen_ops[{0}];\n\n",
pattern.getSourcePattern().getNumOps());
LLVM_DEBUG(llvm::dbgs()
<< "done creating local variables for capturing matches\n");
os << "// Match\n";
os << "tblgen_ops[0] = op0;\n";
emitMatchLogic(sourceTree);
os << "\n// Rewrite\n";
emitRewriteLogic();
os << "return success();\n";
}
os << "};\n";
}
// TODO: capture ops with consistent numbering so that it can be
// reused for fused loc.
os.indent(4) << formatv("::mlir::Operation *tblgen_ops[{0}];\n\n",
pattern.getSourcePattern().getNumOps());
LLVM_DEBUG(
llvm::dbgs() << "done creating local variables for capturing matches\n");
os.indent(4) << "// Match\n";
os.indent(4) << "tblgen_ops[0] = op0;\n";
emitMatchLogic(sourceTree);
os << "\n";
os.indent(4) << "// Rewrite\n";
emitRewriteLogic();
os.indent(4) << "return success();\n";
os << " };\n";
os << "};\n";
os << "};\n\n";
}
void PatternEmitter::emitRewriteLogic() {
@ -586,7 +579,7 @@ void PatternEmitter::emitRewriteLogic() {
PrintFatalError(loc, error);
}
os.indent(4) << "auto odsLoc = rewriter.getFusedLoc({";
os << "auto odsLoc = rewriter.getFusedLoc({";
for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
}
@ -601,22 +594,21 @@ void PatternEmitter::emitRewriteLogic() {
// we are handling auxiliary patterns so we want the side effect even if
// NativeCodeCall is not replacing matched root op's results.
if (resultTree.isNativeCodeCall())
os.indent(4) << val << ";\n";
os << val << ";\n";
}
if (numExpectedResults == 0) {
assert(replStartIndex >= numResultPatterns &&
"invalid auxiliary vs. replacement pattern division!");
// No result to replace. Just erase the op.
os.indent(4) << "rewriter.eraseOp(op0);\n";
os << "rewriter.eraseOp(op0);\n";
} else {
// Process replacement result patterns.
os.indent(4)
<< "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
for (int i = replStartIndex; i < numResultPatterns; ++i) {
DagNode resultTree = pattern.getResultPattern(i);
auto val = handleResultPattern(resultTree, offsets[i], 0);
os.indent(4) << "\n";
os << "\n";
// Resolve each symbol for all range use so that we can loop over them.
// We need an explicit cast to `SmallVector` to capture the cases where
// `{0}` resolves to an `Operation::result_range` as well as cases that
@ -625,12 +617,11 @@ void PatternEmitter::emitRewriteLogic() {
// TODO: Revisit the need for materializing a vector.
os << symbolInfoMap.getAllRangeUse(
val,
" for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{ "
"tblgen_repl_values.push_back(v); }",
"for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
" tblgen_repl_values.push_back(v);\n}\n",
"\n");
}
os.indent(4) << "\n";
os.indent(4) << "rewriter.replaceOp(op0, tblgen_repl_values);\n";
os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
}
LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
@ -879,9 +870,8 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
}
// Create the local variable for this op.
os.indent(4) << formatv("{0} {1};\n", resultOp.getQualCppClassName(),
valuePackName);
os.indent(4) << "{\n";
os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(),
valuePackName);
// Right now ODS don't have general type inference support. Except a few
// special cases listed below, DRR needs to supply types for all results
@ -900,10 +890,9 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
createAggregateLocalVarsForOpArgs(tree, childNodeNames);
// Then create the op.
os.indent(6) << formatv(
"{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);\n",
os.scope("", "\n}\n").os << formatv(
"{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
valuePackName, resultOp.getQualCppClassName(), locToUse);
os.indent(4) << "}\n";
return resultValue;
}
@ -920,11 +909,10 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
// aggregate-parameter builders.
createSeparateLocalVarsForOpArgs(tree, childNodeNames);
os.indent(6) << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
resultOp.getQualCppClassName(), locToUse);
os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
resultOp.getQualCppClassName(), locToUse);
supplyValuesForOpArgs(tree, childNodeNames);
os << "\n );\n";
os.indent(4) << "}\n";
os << "\n );\n}\n";
return resultValue;
}
@ -938,20 +926,19 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
// Then prepare the result types. We need to specify the types for all
// results.
os.indent(6) << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; "
"(void)tblgen_types;\n");
os.indent() << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; "
"(void)tblgen_types;\n");
int numResults = resultOp.getNumResults();
if (numResults != 0) {
for (int i = 0; i < numResults; ++i)
os.indent(6) << formatv("for (auto v : castedOp0.getODSResults({0})) {{"
"tblgen_types.push_back(v.getType()); }\n",
resultIndex + i);
os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n"
" tblgen_types.push_back(v.getType());\n}\n",
resultIndex + i);
}
os.indent(6) << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
"tblgen_values, tblgen_attrs);\n",
valuePackName, resultOp.getQualCppClassName(),
locToUse);
os.indent(4) << "}\n";
os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
"tblgen_values, tblgen_attrs);\n",
valuePackName, resultOp.getQualCppClassName(), locToUse);
os.unindent() << "}\n";
return resultValue;
}
@ -968,16 +955,15 @@ void PatternEmitter::createSeparateLocalVarsForOpArgs(
for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
const auto *operand =
resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
if (!operand) {
// We do not need special handling for attributes.
// We do not need special handling for attributes.
if (!operand)
continue;
}
raw_indented_ostream::DelimitedScope scope(os);
std::string varName;
if (operand->isVariadic()) {
varName = std::string(formatv("tblgen_values_{0}", valueIndex++));
os.indent(6) << formatv("::mlir::SmallVector<::mlir::Value, 4> {0};\n",
varName);
os << formatv("::mlir::SmallVector<::mlir::Value, 4> {0};\n", varName);
std::string range;
if (node.isNestedDagArg(argIndex)) {
range = childNodeNames[argIndex];
@ -987,11 +973,11 @@ void PatternEmitter::createSeparateLocalVarsForOpArgs(
// Resolve the symbol for all range use so that we have a uniform way of
// capturing the values.
range = symbolInfoMap.getValueAndRangeUse(range);
os.indent(6) << formatv("for (auto v : {0}) {1}.push_back(v);\n", range,
varName);
os << formatv("for (auto v: {0}) {{\n {1}.push_back(v);\n}\n", range,
varName);
} else {
varName = std::string(formatv("tblgen_value_{0}", valueIndex++));
os.indent(6) << formatv("::mlir::Value {0} = ", varName);
os << formatv("::mlir::Value {0} = ", varName);
if (node.isNestedDagArg(argIndex)) {
os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
} else {
@ -1019,7 +1005,7 @@ void PatternEmitter::supplyValuesForOpArgs(
for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
argIndex != numOpArgs; ++argIndex) {
// Start each argument on its own line.
(os << ",\n").indent(8);
os << ",\n ";
Argument opArg = resultOp.getArg(argIndex);
// Handle the case of operand first.
@ -1060,14 +1046,16 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
Operator &resultOp = node.getDialectOp(opMap);
os.indent(6) << formatv("::mlir::SmallVector<::mlir::Value, 4> "
"tblgen_values; (void)tblgen_values;\n");
os.indent(6) << formatv("::mlir::SmallVector<::mlir::NamedAttribute, 4> "
"tblgen_attrs; (void)tblgen_attrs;\n");
auto scope = os.scope();
os << formatv("::mlir::SmallVector<::mlir::Value, 4> "
"tblgen_values; (void)tblgen_values;\n");
os << formatv("::mlir::SmallVector<::mlir::NamedAttribute, 4> "
"tblgen_attrs; (void)tblgen_attrs;\n");
const char *addAttrCmd =
"if (auto tmpAttr = {1}) "
"tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), tmpAttr);\n";
"if (auto tmpAttr = {1}) {\n"
" tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), "
"tmpAttr);\n}\n";
for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
// The argument in the op definition.
@ -1076,14 +1064,14 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
if (!subTree.isNativeCodeCall())
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
"for creating attribute");
os.indent(6) << formatv(addAttrCmd, opArgName,
handleReplaceWithNativeCodeCall(subTree));
os << formatv(addAttrCmd, opArgName,
handleReplaceWithNativeCodeCall(subTree));
} else {
auto leaf = node.getArgAsLeaf(argIndex);
// The argument in the result DAG pattern.
auto patArgName = node.getArgName(argIndex);
os.indent(6) << formatv(addAttrCmd, opArgName,
handleOpArgument(leaf, patArgName));
os << formatv(addAttrCmd, opArgName,
handleOpArgument(leaf, patArgName));
}
continue;
}
@ -1101,10 +1089,10 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
// Resolve the symbol for all range use so that we have a uniform way of
// capturing the values.
range = symbolInfoMap.getValueAndRangeUse(range);
os.indent(6) << formatv(
"for (auto v : {0}) tblgen_values.push_back(v);\n", range);
os << formatv("for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n",
range);
} else {
os.indent(6) << formatv("tblgen_values.push_back(", varName);
os << formatv("tblgen_values.push_back(", varName);
if (node.isNestedDagArg(argIndex)) {
os << symbolInfoMap.getValueAndRangeUse(
childNodeNames.lookup(argIndex));

View File

@ -0,0 +1,6 @@
add_mlir_unittest(MLIRSupportTests
IndentedOstreamTest.cpp
)
target_link_libraries(MLIRSupportTests
PRIVATE MLIRSupportIdentedOstream MLIRSupport)

View File

@ -0,0 +1,110 @@
//===- IndentedOstreamTest.cpp - Indented raw ostream Tests ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Support/IndentedOstream.h"
#include "gmock/gmock.h"
using namespace mlir;
using ::testing::StrEq;
TEST(FormatTest, SingleLine) {
std::string str;
llvm::raw_string_ostream os(str);
raw_indented_ostream ros(os);
ros << 10;
ros.flush();
EXPECT_THAT(os.str(), StrEq("10"));
}
TEST(FormatTest, SimpleMultiLine) {
std::string str;
llvm::raw_string_ostream os(str);
raw_indented_ostream ros(os);
ros << "a";
ros << "b";
ros << "\n";
ros << "c";
ros << "\n";
ros.flush();
EXPECT_THAT(os.str(), StrEq("ab\nc\n"));
}
TEST(FormatTest, SimpleMultiLineIndent) {
std::string str;
llvm::raw_string_ostream os(str);
raw_indented_ostream ros(os);
ros.indent(2) << "a";
ros.indent(4) << "b";
ros << "\n";
ros << "c";
ros << "\n";
ros.flush();
EXPECT_THAT(os.str(), StrEq(" a b\n c\n"));
}
TEST(FormatTest, SingleRegion) {
std::string str;
llvm::raw_string_ostream os(str);
raw_indented_ostream ros(os);
ros << "before\n";
{
raw_indented_ostream::DelimitedScope scope(ros);
ros << "inside " << 10;
ros << "\n two\n";
{
raw_indented_ostream::DelimitedScope scope(ros, "{\n", "\n}\n");
ros << "inner inner";
}
}
ros << "after";
ros.flush();
const auto *expected =
R"(before
inside 10
two
{
inner inner
}
after)";
EXPECT_THAT(os.str(), StrEq(expected));
// Repeat the above with inline form.
str.clear();
ros << "before\n";
ros.scope().os << "inside " << 10 << "\n two\n";
ros.scope().os.scope("{\n", "\n}\n").os << "inner inner";
ros << "after";
ros.flush();
EXPECT_THAT(os.str(), StrEq(expected));
}
TEST(FormatTest, Reindent) {
std::string str;
llvm::raw_string_ostream os(str);
raw_indented_ostream ros(os);
// String to print with some additional empty lines at the start and lines
// with just spaces.
const auto *desc = R"(
First line
second line
)";
ros.reindent(desc);
ros.flush();
const auto *expected =
R"(First line
second line
)";
EXPECT_THAT(os.str(), StrEq(expected));
}