forked from OSchip/llvm-project
332 lines
10 KiB
C++
332 lines
10 KiB
C++
//===- NodePrinter.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Tools/PDLL/AST/Context.h"
|
|
#include "mlir/Tools/PDLL/AST/Nodes.h"
|
|
#include "llvm/ADT/StringExtras.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
#include "llvm/Support/SaveAndRestore.h"
|
|
#include "llvm/Support/ScopedPrinter.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::pdll::ast;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// NodePrinter
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
class NodePrinter {
|
|
public:
|
|
NodePrinter(raw_ostream &os) : os(os) {}
|
|
|
|
/// Print the given type to the stream.
|
|
void print(Type type);
|
|
|
|
/// Print the given node to the stream.
|
|
void print(const Node *node);
|
|
|
|
private:
|
|
/// Print a range containing children of a node.
|
|
template <typename RangeT,
|
|
std::enable_if_t<!std::is_convertible<RangeT, const Node *>::value>
|
|
* = nullptr>
|
|
void printChildren(RangeT &&range) {
|
|
if (llvm::empty(range))
|
|
return;
|
|
|
|
// Print the first N-1 elements with a prefix of "|-".
|
|
auto it = std::begin(range);
|
|
for (unsigned i = 0, e = llvm::size(range) - 1; i < e; ++i, ++it)
|
|
print(*it);
|
|
|
|
// Print the last element.
|
|
elementIndentStack.back() = true;
|
|
print(*it);
|
|
}
|
|
template <typename RangeT, typename... OthersT,
|
|
std::enable_if_t<std::is_convertible<RangeT, const Node *>::value>
|
|
* = nullptr>
|
|
void printChildren(RangeT &&range, OthersT &&...others) {
|
|
printChildren(ArrayRef<const Node *>({range, others...}));
|
|
}
|
|
/// Print a range containing children of a node, nesting the children under
|
|
/// the given label.
|
|
template <typename RangeT>
|
|
void printChildren(StringRef label, RangeT &&range) {
|
|
if (llvm::empty(range))
|
|
return;
|
|
elementIndentStack.reserve(elementIndentStack.size() + 1);
|
|
llvm::SaveAndRestore<bool> lastElement(elementIndentStack.back(), true);
|
|
|
|
printIndent();
|
|
os << label << "`\n";
|
|
elementIndentStack.push_back(/*isLastElt*/ false);
|
|
printChildren(std::forward<RangeT>(range));
|
|
elementIndentStack.pop_back();
|
|
}
|
|
|
|
/// Print the given derived node to the stream.
|
|
void printImpl(const CompoundStmt *stmt);
|
|
void printImpl(const EraseStmt *stmt);
|
|
void printImpl(const LetStmt *stmt);
|
|
void printImpl(const ReplaceStmt *stmt);
|
|
void printImpl(const RewriteStmt *stmt);
|
|
|
|
void printImpl(const AttributeExpr *expr);
|
|
void printImpl(const DeclRefExpr *expr);
|
|
void printImpl(const MemberAccessExpr *expr);
|
|
void printImpl(const OperationExpr *expr);
|
|
void printImpl(const TupleExpr *expr);
|
|
void printImpl(const TypeExpr *expr);
|
|
|
|
void printImpl(const AttrConstraintDecl *decl);
|
|
void printImpl(const OpConstraintDecl *decl);
|
|
void printImpl(const TypeConstraintDecl *decl);
|
|
void printImpl(const TypeRangeConstraintDecl *decl);
|
|
void printImpl(const ValueConstraintDecl *decl);
|
|
void printImpl(const ValueRangeConstraintDecl *decl);
|
|
void printImpl(const NamedAttributeDecl *decl);
|
|
void printImpl(const OpNameDecl *decl);
|
|
void printImpl(const PatternDecl *decl);
|
|
void printImpl(const VariableDecl *decl);
|
|
void printImpl(const Module *module);
|
|
|
|
/// Print the current indent stack.
|
|
void printIndent() {
|
|
if (elementIndentStack.empty())
|
|
return;
|
|
|
|
for (bool isLastElt : llvm::makeArrayRef(elementIndentStack).drop_back())
|
|
os << (isLastElt ? " " : " |");
|
|
os << (elementIndentStack.back() ? " `" : " |");
|
|
}
|
|
|
|
/// The raw output stream.
|
|
raw_ostream &os;
|
|
|
|
/// A stack of indents and a flag indicating if the current element being
|
|
/// printed at that indent is the last element.
|
|
SmallVector<bool> elementIndentStack;
|
|
};
|
|
} // namespace
|
|
|
|
/// Write a textual form of `type` to the stream. A null type is printed as the
/// placeholder "Type<NULL>"; an unrecognised type kind is unreachable.
void NodePrinter::print(Type type) {
  // Null types are tolerated and rendered with a placeholder.
  if (!type) {
    os << "Type<NULL>";
    return;
  }

  TypeSwitch<Type>(type)
      .Case([&](AttributeType) { os << "Attr"; })
      .Case([&](ConstraintType) { os << "Constraint"; })
      .Case([&](OperationType opType) {
        // The operation name is optional and only printed when present.
        os << "Op";
        Optional<StringRef> opName = opType.getName();
        if (opName)
          os << "<" << *opName << ">";
      })
      .Case([&](RangeType rangeType) {
        // A range is spelled as its element type followed by "Range".
        print(rangeType.getElementType());
        os << "Range";
      })
      .Case([&](TupleType tupleType) {
        // Each element is printed as "name: type"; unnamed elements omit the
        // "name: " prefix.
        os << "Tuple<";
        llvm::interleaveComma(llvm::zip(tupleType.getElementNames(),
                                        tupleType.getElementTypes()),
                              os, [&](auto element) {
                                if (!std::get<0>(element).empty())
                                  os << std::get<0>(element) << ": ";
                                this->print(std::get<1>(element));
                              });
        os << ">";
      })
      .Case([&](TypeType) { os << "Type"; })
      .Case([&](ValueType) { os << "Value"; })
      .Default([](Type) { llvm_unreachable("unknown AST type"); });
}
|
|
|
|
/// Print `node`: emit the current indent and a "-" marker, then dispatch on
/// the node's dynamic kind to the matching printImpl overload. An unknown node
/// kind is unreachable.
void NodePrinter::print(const Node *node) {
  printIndent();
  os << "-";

  // Anything the node prints beneath itself lives one indent level deeper.
  elementIndentStack.push_back(/*isLastElt*/ false);

  TypeSwitch<const Node *>(node)
      .Case<
          // Statements.
          const CompoundStmt, const EraseStmt, const LetStmt,
          const ReplaceStmt, const RewriteStmt,

          // Expressions.
          const AttributeExpr, const DeclRefExpr, const MemberAccessExpr,
          const OperationExpr, const TupleExpr, const TypeExpr,

          // Decls.
          const AttrConstraintDecl, const OpConstraintDecl,
          const TypeConstraintDecl, const TypeRangeConstraintDecl,
          const ValueConstraintDecl, const ValueRangeConstraintDecl,
          const NamedAttributeDecl, const OpNameDecl, const PatternDecl,
          const VariableDecl,

          const Module>([&](auto derived) { this->printImpl(derived); })
      .Default([](const Node *) { llvm_unreachable("unknown AST node"); });

  // Leave the level that was opened for this node.
  elementIndentStack.pop_back();
}
|
|
|
|
/// Print a compound statement header (kind and node address) followed by its
/// children.
void NodePrinter::printImpl(const CompoundStmt *stmt) {
  os << "CompoundStmt ";
  os << stmt;
  os << "\n";
  printChildren(stmt->getChildren());
}
|
|
|
|
/// Print an erase statement header followed by its root operation expression.
void NodePrinter::printImpl(const EraseStmt *stmt) {
  os << "EraseStmt ";
  os << stmt;
  os << "\n";
  printChildren(stmt->getRootOpExpr());
}
|
|
|
|
/// Print a let statement header followed by the variable declaration it
/// introduces.
void NodePrinter::printImpl(const LetStmt *stmt) {
  os << "LetStmt ";
  os << stmt;
  os << "\n";
  printChildren(stmt->getVarDecl());
}
|
|
|
|
/// Print a replace statement header, its root operation expression, and then
/// the replacement expressions nested under a "ReplValues" label.
void NodePrinter::printImpl(const ReplaceStmt *stmt) {
  os << "ReplaceStmt ";
  os << stmt;
  os << "\n";

  // The root operation comes first, then the labelled replacement values.
  printChildren(stmt->getRootOpExpr());
  printChildren("ReplValues", stmt->getReplExprs());
}
|
|
|
|
/// Print a rewrite statement header followed by two children: the root
/// operation expression and the rewrite body.
void NodePrinter::printImpl(const RewriteStmt *stmt) {
  os << "RewriteStmt ";
  os << stmt;
  os << "\n";
  printChildren(stmt->getRootOpExpr(), stmt->getRewriteBody());
}
|
|
|
|
/// Print an attribute expression on one line, with its value in quotes.
void NodePrinter::printImpl(const AttributeExpr *expr) {
  os << "AttributeExpr " << expr;
  os << " Value<\"" << expr->getValue();
  os << "\">\n";
}
|
|
|
|
/// Print a declaration reference header including the expression type,
/// followed by the referenced declaration as a child.
void NodePrinter::printImpl(const DeclRefExpr *expr) {
  os << "DeclRefExpr ";
  os << expr;

  // The result type is wrapped in "Type<...>".
  os << " Type<";
  print(expr->getType());
  os << ">\n";

  printChildren(expr->getDecl());
}
|
|
|
|
/// Print a member access header (member name and result type), followed by the
/// parent expression being accessed as a child.
void NodePrinter::printImpl(const MemberAccessExpr *expr) {
  os << "MemberAccessExpr " << expr;
  os << " Member<" << expr->getMemberName() << ">";

  // The result type is wrapped in "Type<...>".
  os << " Type<";
  print(expr->getType());
  os << ">\n";

  printChildren(expr->getParentExpr());
}
|
|
|
|
/// Print an operation expression header with its type, then its name
/// declaration and the labelled operand, result type, and attribute groups.
void NodePrinter::printImpl(const OperationExpr *expr) {
  os << "OperationExpr ";
  os << expr;
  os << " Type<";
  print(expr->getType());
  os << ">\n";

  // The name declaration is printed directly; the remaining groups are nested
  // under labels, and a group with no entries prints nothing.
  printChildren(expr->getNameDecl());
  printChildren("Operands", expr->getOperands());
  printChildren("Result Types", expr->getResultTypes());
  printChildren("Attributes", expr->getAttributes());
}
|
|
|
|
/// Print a tuple expression header with its type, followed by the tuple
/// elements as children.
void NodePrinter::printImpl(const TupleExpr *expr) {
  os << "TupleExpr ";
  os << expr;
  os << " Type<";
  print(expr->getType());
  os << ">\n";

  printChildren(expr->getElements());
}
|
|
|
|
/// Print a type expression on one line, with its value in quotes.
void NodePrinter::printImpl(const TypeExpr *expr) {
  os << "TypeExpr " << expr;
  os << " Value<\"" << expr->getValue();
  os << "\">\n";
}
|
|
|
|
void NodePrinter::printImpl(const AttrConstraintDecl *decl) {
|
|
os << "AttrConstraintDecl " << decl << "\n";
|
|
if (const auto *typeExpr = decl->getTypeExpr())
|
|
printChildren(typeExpr);
|
|
}
|
|
|
|
/// Print an operation constraint header followed by its name declaration.
void NodePrinter::printImpl(const OpConstraintDecl *decl) {
  os << "OpConstraintDecl ";
  os << decl;
  os << "\n";
  printChildren(decl->getNameDecl());
}
|
|
|
|
/// Print a type constraint as a single header line; it has no children.
void NodePrinter::printImpl(const TypeConstraintDecl *decl) {
  os << "TypeConstraintDecl ";
  os << decl;
  os << "\n";
}
|
|
|
|
/// Print a type range constraint as a single header line; it has no children.
void NodePrinter::printImpl(const TypeRangeConstraintDecl *decl) {
  os << "TypeRangeConstraintDecl ";
  os << decl;
  os << "\n";
}
|
|
|
|
void NodePrinter::printImpl(const ValueConstraintDecl *decl) {
|
|
os << "ValueConstraintDecl " << decl << "\n";
|
|
if (const auto *typeExpr = decl->getTypeExpr())
|
|
printChildren(typeExpr);
|
|
}
|
|
|
|
void NodePrinter::printImpl(const ValueRangeConstraintDecl *decl) {
|
|
os << "ValueRangeConstraintDecl " << decl << "\n";
|
|
if (const auto *typeExpr = decl->getTypeExpr())
|
|
printChildren(typeExpr);
|
|
}
|
|
|
|
/// Print a named attribute header including the attribute name, followed by
/// its value as a child.
void NodePrinter::printImpl(const NamedAttributeDecl *decl) {
  os << "NamedAttributeDecl " << decl;
  os << " Name<" << decl->getName().getName();
  os << ">\n";
  printChildren(decl->getValue());
}
|
|
|
|
void NodePrinter::printImpl(const OpNameDecl *decl) {
|
|
os << "OpNameDecl " << decl;
|
|
if (Optional<StringRef> name = decl->getName())
|
|
os << " Name<" << name << ">";
|
|
os << "\n";
|
|
}
|
|
|
|
void NodePrinter::printImpl(const PatternDecl *decl) {
|
|
os << "PatternDecl " << decl;
|
|
if (const Name *name = decl->getName())
|
|
os << " Name<" << name->getName() << ">";
|
|
if (Optional<uint16_t> benefit = decl->getBenefit())
|
|
os << " Benefit<" << *benefit << ">";
|
|
if (decl->hasBoundedRewriteRecursion())
|
|
os << " Recursion";
|
|
|
|
os << "\n";
|
|
printChildren(decl->getBody());
|
|
}
|
|
|
|
void NodePrinter::printImpl(const VariableDecl *decl) {
|
|
os << "VariableDecl " << decl << " Name<" << decl->getName().getName()
|
|
<< "> Type<";
|
|
print(decl->getType());
|
|
os << ">\n";
|
|
if (Expr *initExpr = decl->getInitExpr())
|
|
printChildren(initExpr);
|
|
|
|
auto constraints =
|
|
llvm::map_range(decl->getConstraints(),
|
|
[](const ConstraintRef &ref) { return ref.constraint; });
|
|
printChildren("Constraints", constraints);
|
|
}
|
|
|
|
/// Print a module header followed by its children.
void NodePrinter::printImpl(const Module *module) {
  os << "Module ";
  os << module;
  os << "\n";
  printChildren(module->getChildren());
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Entry point
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Print this node to `os` using a fresh NodePrinter.
void Node::print(raw_ostream &os) const {
  NodePrinter printer(os);
  printer.print(this);
}
|
|
|
|
/// Print this type to `os` using a fresh NodePrinter.
void Type::print(raw_ostream &os) const {
  NodePrinter printer(os);
  printer.print(*this);
}
|