forked from OSchip/llvm-project
1743 lines
52 KiB
C++
1743 lines
52 KiB
C++
//===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===//
|
|
//
|
|
// Copyright 2019 The MLIR Authors.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
// =============================================================================
|
|
//
|
|
// This file implements the MLIR AsmPrinter class, which is used to implement
|
|
// the various print() methods on the core IR objects.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/IR/AffineExpr.h"
|
|
#include "mlir/IR/AffineMap.h"
|
|
#include "mlir/IR/Attributes.h"
|
|
#include "mlir/IR/Dialect.h"
|
|
#include "mlir/IR/Function.h"
|
|
#include "mlir/IR/IntegerSet.h"
|
|
#include "mlir/IR/MLIRContext.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/Module.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "mlir/IR/Operation.h"
|
|
#include "mlir/IR/StandardTypes.h"
|
|
#include "mlir/Support/STLExtras.h"
|
|
#include "llvm/ADT/APFloat.h"
|
|
#include "llvm/ADT/DenseMap.h"
|
|
#include "llvm/ADT/MapVector.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/SetVector.h"
|
|
#include "llvm/ADT/SmallString.h"
|
|
#include "llvm/ADT/StringExtras.h"
|
|
#include "llvm/ADT/StringSet.h"
|
|
#include "llvm/Support/CommandLine.h"
|
|
#include "llvm/Support/Regex.h"
|
|
using namespace mlir;
|
|
|
|
/// Write the identifier's string contents to `os`.
void Identifier::print(raw_ostream &os) const {
  os << str();
}
|
|
|
|
/// Debugging aid: print the identifier to stderr.
void Identifier::dump() const {
  print(llvm::errs());
}
|
|
|
|
/// Write the operation name (as returned by getStringRef()) to `os`.
void OperationName::print(raw_ostream &os) const {
  os << getStringRef();
}
|
|
|
|
/// Debugging aid: print the operation name to stderr.
void OperationName::dump() const {
  print(llvm::errs());
}
|
|
|
|
// Out-of-line destructor definition for OpAsmPrinter.
OpAsmPrinter::~OpAsmPrinter() = default;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ModuleState
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// TODO(riverriddle) Rethink this flag when we have a pass that can remove debug
|
|
// info or when we have a system for printer flags.
|
|
static llvm::cl::opt<bool>
|
|
shouldPrintDebugInfoOpt("mlir-print-debuginfo",
|
|
llvm::cl::desc("Print debug info in MLIR output"),
|
|
llvm::cl::init(false));
|
|
|
|
static llvm::cl::opt<bool> printPrettyDebugInfo(
|
|
"mlir-pretty-debuginfo",
|
|
llvm::cl::desc("Print pretty debug info in MLIR output"),
|
|
llvm::cl::init(false));
|
|
|
|
// Use the generic op output form in the function printer even if the custom
|
|
// form is defined.
|
|
static llvm::cl::opt<bool>
|
|
printGenericOpForm("mlir-print-op-generic",
|
|
llvm::cl::desc("Print the generic op form"),
|
|
llvm::cl::init(false), llvm::cl::Hidden);
|
|
|
|
namespace {
|
|
/// A special index constant used for non-kind attribute aliases.
|
|
static constexpr int kNonAttrKindAlias = -1;
|
|
|
|
class ModuleState {
|
|
|
|
public:
|
|
/// This is the current context if it is knowable, otherwise this is null.
|
|
MLIRContext *const context;
|
|
|
|
explicit ModuleState(MLIRContext *context) : context(context) {}
|
|
|
|
// Initializes module state, populating affine map state.
|
|
void initialize(Module *module);
|
|
|
|
Twine getAttributeAlias(Attribute attr) const {
|
|
auto alias = attrToAlias.find(attr);
|
|
if (alias == attrToAlias.end())
|
|
return Twine();
|
|
|
|
// Return the alias for this attribute, along with the index if this was
|
|
// generated by a kind alias.
|
|
int kindIndex = alias->second.second;
|
|
return alias->second.first +
|
|
(kindIndex == kNonAttrKindAlias ? Twine() : Twine(kindIndex));
|
|
}
|
|
|
|
void printAttributeAliases(raw_ostream &os) const {
|
|
auto printAlias = [&](StringRef alias, Attribute attr, int index) {
|
|
os << '#' << alias;
|
|
if (index != kNonAttrKindAlias)
|
|
os << index;
|
|
os << " = " << attr << '\n';
|
|
};
|
|
|
|
// Print all of the attribute kind aliases.
|
|
for (auto &kindAlias : attrKindToAlias) {
|
|
for (unsigned i = 0, e = kindAlias.second.second.size(); i != e; ++i)
|
|
printAlias(kindAlias.second.first, kindAlias.second.second[i], i);
|
|
os << "\n";
|
|
}
|
|
|
|
// In a second pass print all of the remaining attribute aliases that aren't
|
|
// kind aliases.
|
|
for (Attribute attr : usedAttributes) {
|
|
auto alias = attrToAlias.find(attr);
|
|
if (alias != attrToAlias.end() &&
|
|
alias->second.second == kNonAttrKindAlias)
|
|
printAlias(alias->second.first, attr, alias->second.second);
|
|
}
|
|
}
|
|
|
|
StringRef getTypeAlias(Type ty) const { return typeToAlias.lookup(ty); }
|
|
|
|
void printTypeAliases(raw_ostream &os) const {
|
|
for (Type type : usedTypes) {
|
|
auto alias = typeToAlias.find(type);
|
|
if (alias != typeToAlias.end())
|
|
os << '!' << alias->second << " = type " << type << '\n';
|
|
}
|
|
}
|
|
|
|
private:
|
|
void recordAttributeReference(Attribute attr) {
|
|
// Don't recheck attributes that have already been seen or those that
|
|
// already have an alias.
|
|
if (!usedAttributes.insert(attr) || attrToAlias.count(attr))
|
|
return;
|
|
|
|
// If this attribute kind has an alias, then record one for this attribute.
|
|
auto alias = attrKindToAlias.find(static_cast<unsigned>(attr.getKind()));
|
|
if (alias == attrKindToAlias.end())
|
|
return;
|
|
std::pair<StringRef, int> attrAlias(alias->second.first,
|
|
alias->second.second.size());
|
|
attrToAlias.insert({attr, attrAlias});
|
|
alias->second.second.push_back(attr);
|
|
}
|
|
|
|
void recordTypeReference(Type ty) { usedTypes.insert(ty); }
|
|
|
|
// Visit functions.
|
|
void visitOperation(Operation *op);
|
|
void visitType(Type type);
|
|
void visitAttribute(Attribute attr);
|
|
|
|
// Initialize symbol aliases.
|
|
void initializeSymbolAliases();
|
|
|
|
/// Set of attributes known to be used within the module.
|
|
llvm::SetVector<Attribute> usedAttributes;
|
|
|
|
/// Mapping between attribute and a pair comprised of a base alias name and a
|
|
/// count suffix. If the suffix is set to -1, it is not displayed.
|
|
llvm::MapVector<Attribute, std::pair<StringRef, int>> attrToAlias;
|
|
|
|
/// Mapping between attribute kind and a pair comprised of a base alias name
|
|
/// and a unique list of attributes belonging to this kind sorted by location
|
|
/// seen in the module.
|
|
llvm::MapVector<unsigned, std::pair<StringRef, std::vector<Attribute>>>
|
|
attrKindToAlias;
|
|
|
|
/// Set of types known to be used within the module.
|
|
llvm::SetVector<Type> usedTypes;
|
|
|
|
/// A mapping between a type and a given alias.
|
|
DenseMap<Type, StringRef> typeToAlias;
|
|
};
|
|
} // end anonymous namespace
|
|
|
|
// TODO Support visiting other types/operations when implemented.
|
|
void ModuleState::visitType(Type type) {
|
|
recordTypeReference(type);
|
|
if (auto funcType = type.dyn_cast<FunctionType>()) {
|
|
// Visit input and result types for functions.
|
|
for (auto input : funcType.getInputs())
|
|
visitType(input);
|
|
for (auto result : funcType.getResults())
|
|
visitType(result);
|
|
} else if (auto memref = type.dyn_cast<MemRefType>()) {
|
|
// Visit affine maps in memref type.
|
|
for (auto map : memref.getAffineMaps())
|
|
recordAttributeReference(AffineMapAttr::get(map));
|
|
} else if (auto vecOrTensor = type.dyn_cast<VectorOrTensorType>()) {
|
|
visitType(vecOrTensor.getElementType());
|
|
}
|
|
}
|
|
|
|
// Record `attr`, and the attributes and types nested inside it, as used by the
// module. Types embedded in attributes are printed through printType(), which
// substitutes registered aliases, so they must be visited for the alias
// definitions to be emitted.
void ModuleState::visitAttribute(Attribute attr) {
  recordAttributeReference(attr);
  if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
    for (auto elt : arrayAttr.getValue())
      visitAttribute(elt);
  } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
    visitType(typeAttr.getValue());
  } else if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
    // splat<T, value>: both the type and the value are printed.
    visitType(splatAttr.getType());
    visitAttribute(splatAttr.getValue());
  } else if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>()) {
    visitType(denseAttr.getType());
  } else if (auto sparseAttr = attr.dyn_cast<SparseElementsAttr>()) {
    visitType(sparseAttr.getType());
  } else if (auto opaqueAttr = attr.dyn_cast<OpaqueElementsAttr>()) {
    visitType(opaqueAttr.getType());
  }
}
|
|
|
|
// Record the types of an operation's operands and results, and every attribute
// attached to it, as used by the module.
void ModuleState::visitOperation(Operation *op) {
  for (auto *operand : op->getOperands())
    visitType(operand->getType());
  for (auto *result : op->getResults())
    visitType(result->getType());

  for (auto &namedAttr : op->getAttrs())
    visitAttribute(namedAttr.second);
}
|
|
|
|
// Utility to generate a function to register a symbol alias.
|
|
static bool canRegisterAlias(StringRef name, llvm::StringSet<> &usedAliases) {
|
|
assert(!name.empty() && "expected alias name to be non-empty");
|
|
// TODO(riverriddle) Assert that the provided alias name can be lexed as
|
|
// an identifier.
|
|
|
|
// Check that the alias doesn't contain a '.' character and the name is not
|
|
// already in use.
|
|
return !name.contains('.') && usedAliases.insert(name).second;
|
|
}
|
|
|
|
void ModuleState::initializeSymbolAliases() {
|
|
// Track the identifiers in use for each symbol so that the same identifier
|
|
// isn't used twice.
|
|
llvm::StringSet<> usedAliases;
|
|
|
|
// Get the currently registered dialects.
|
|
auto dialects = context->getRegisteredDialects();
|
|
|
|
// Collect the set of aliases from each dialect.
|
|
SmallVector<std::pair<unsigned, StringRef>, 8> attributeKindAliases;
|
|
SmallVector<std::pair<Attribute, StringRef>, 8> attributeAliases;
|
|
SmallVector<std::pair<Type, StringRef>, 16> typeAliases;
|
|
|
|
// AffineMap/Integer set have specific kind aliases.
|
|
attributeKindAliases.emplace_back(StandardAttributes::AffineMap, "map");
|
|
attributeKindAliases.emplace_back(StandardAttributes::IntegerSet, "set");
|
|
|
|
for (auto *dialect : dialects) {
|
|
dialect->getAttributeKindAliases(attributeKindAliases);
|
|
dialect->getAttributeAliases(attributeAliases);
|
|
dialect->getTypeAliases(typeAliases);
|
|
}
|
|
|
|
// Setup the attribute kind aliases.
|
|
StringRef alias;
|
|
unsigned attrKind;
|
|
for (auto &attrAliasPair : attributeKindAliases) {
|
|
std::tie(attrKind, alias) = attrAliasPair;
|
|
assert(!alias.empty() && "expected non-empty alias string");
|
|
if (!usedAliases.count(alias) && !alias.contains('.'))
|
|
attrKindToAlias.insert({attrKind, {alias, {}}});
|
|
}
|
|
|
|
// Clear the set of used identifiers so that the attribute kind aliases are
|
|
// just a prefix and not the full alias, i.e. there may be some overlap.
|
|
usedAliases.clear();
|
|
|
|
// Register the attribute aliases.
|
|
// Create a regex for the attribute kind alias names, these have a prefix with
|
|
// a counter appended to the end. We prevent normal aliases from having these
|
|
// names to avoid collisions.
|
|
llvm::Regex reservedAttrNames("[0-9]+$");
|
|
|
|
// Attribute value aliases.
|
|
Attribute attr;
|
|
for (auto &attrAliasPair : attributeAliases) {
|
|
std::tie(attr, alias) = attrAliasPair;
|
|
if (!reservedAttrNames.match(alias) && canRegisterAlias(alias, usedAliases))
|
|
attrToAlias.insert({attr, {alias, kNonAttrKindAlias}});
|
|
}
|
|
|
|
// Clear the set of used identifiers as types can have the same identifiers as
|
|
// affine structures.
|
|
usedAliases.clear();
|
|
|
|
// Type aliases.
|
|
for (auto &typeAliasPair : typeAliases)
|
|
if (canRegisterAlias(typeAliasPair.second, usedAliases))
|
|
typeToAlias.insert(typeAliasPair);
|
|
}
|
|
|
|
// Initializes module state, populating affine map and integer set state.
|
|
void ModuleState::initialize(Module *module) {
|
|
// Initialize the symbol aliases.
|
|
initializeSymbolAliases();
|
|
|
|
// Walk the module and visit each operation.
|
|
for (auto &fn : *module) {
|
|
visitType(fn.getType());
|
|
for (auto attr : fn.getAttrs())
|
|
ModuleState::visitAttribute(attr.second);
|
|
for (auto attrList : fn.getAllArgAttrs())
|
|
for (auto attr : attrList.getAttrs())
|
|
ModuleState::visitAttribute(attr.second);
|
|
|
|
fn.walk([&](Operation *op) { ModuleState::visitOperation(op); });
|
|
}
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ModulePrinter
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
class ModulePrinter {
|
|
public:
|
|
ModulePrinter(raw_ostream &os, ModuleState &state) : os(os), state(state) {}
|
|
explicit ModulePrinter(ModulePrinter &printer)
|
|
: os(printer.os), state(printer.state) {}
|
|
|
|
template <typename Container, typename UnaryFunctor>
|
|
inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const {
|
|
interleave(c.begin(), c.end(), each_fn, [&]() { os << ", "; });
|
|
}
|
|
|
|
void print(Module *module);
|
|
void printFunctionReference(Function *func);
|
|
void printAttributeAndType(Attribute attr) {
|
|
printAttributeOptionalType(attr, /*includeType=*/true);
|
|
}
|
|
void printAttribute(Attribute attr) {
|
|
printAttributeOptionalType(attr, /*includeType=*/false);
|
|
}
|
|
|
|
void printType(Type type);
|
|
void print(Function *fn);
|
|
void printLocation(Location loc);
|
|
|
|
void printAffineMap(AffineMap map);
|
|
void printAffineExpr(AffineExpr expr);
|
|
void printAffineConstraint(AffineExpr expr, bool isEq);
|
|
void printIntegerSet(IntegerSet set);
|
|
|
|
protected:
|
|
raw_ostream &os;
|
|
ModuleState &state;
|
|
|
|
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
|
|
ArrayRef<StringRef> elidedAttrs = {});
|
|
void printAttributeOptionalType(Attribute attr, bool includeType);
|
|
void printTrailingLocation(Location loc);
|
|
void printLocationInternal(Location loc, bool pretty = false);
|
|
void printDenseElementsAttr(DenseElementsAttr attr);
|
|
|
|
/// This enum is used to represent the binding stength of the enclosing
|
|
/// context that an AffineExprStorage is being printed in, so we can
|
|
/// intelligently produce parens.
|
|
enum class BindingStrength {
|
|
Weak, // + and -
|
|
Strong, // All other binary operators.
|
|
};
|
|
void printAffineExprInternal(AffineExpr expr,
|
|
BindingStrength enclosingTightness);
|
|
};
|
|
} // end anonymous namespace
|
|
|
|
// Append " <location>" to the output, but only when -mlir-print-debuginfo is
// enabled.
void ModulePrinter::printTrailingLocation(Location loc) {
  if (shouldPrintDebugInfoOpt) {
    os << " ";
    printLocation(loc);
  }
}
|
|
|
|
// Print `loc` without the `loc(...)` wrapper. In `pretty` mode file names are
// unquoted, the `callsite(...)` / `fused` keywords are dropped, and most call
// edges are split onto separate lines.
void ModulePrinter::printLocationInternal(Location loc, bool pretty) {
  switch (loc.getKind()) {
  case Location::Kind::Unknown:
    os << (pretty ? "[unknown]" : "unknown");
    break;
  case Location::Kind::FileLineCol: {
    auto fileLoc = loc.cast<FileLineColLoc>();
    // The file name is only quoted in the non-pretty form.
    const char *quote = pretty ? "" : "\"";
    os << quote << fileLoc.getFilename() << quote << ':' << fileLoc.getLine()
       << ':' << fileLoc.getColumn();
    break;
  }
  case Location::Kind::Name:
    os << '\"' << loc.cast<NameLoc>().getName() << '\"';
    break;
  case Location::Kind::CallSite: {
    auto callLocation = loc.cast<CallSiteLoc>();
    auto caller = callLocation.getCaller();
    auto callee = callLocation.getCallee();
    // Non-pretty output is always single line. In pretty mode only a named
    // callee invoked from a file location stays on one line.
    bool sameLine = !pretty ||
                    (callee.isa<NameLoc>() && caller.isa<FileLineColLoc>());
    if (!pretty)
      os << "callsite(";
    printLocationInternal(callee, pretty);
    os << (sameLine ? " at " : "\n at ");
    printLocationInternal(caller, pretty);
    if (!pretty)
      os << ")";
    break;
  }
  case Location::Kind::FusedLocation: {
    auto fusedLoc = loc.cast<FusedLoc>();
    if (!pretty)
      os << "fused";
    if (auto metadata = fusedLoc.getMetadata())
      os << '<' << metadata << '>';
    os << '[';
    interleaveComma(fusedLoc.getLocations(), [&](Location child) {
      printLocationInternal(child, pretty);
    });
    os << ']';
    break;
  }
  }
}
|
|
|
|
void ModulePrinter::print(Module *module) {
  // Alias definitions go first so that later references to them resolve.
  state.printAttributeAliases(os);
  state.printTypeAliases(os);

  // Then each function of the module, in order.
  for (auto &function : *module)
    print(&function);
}
|
|
|
|
/// Print a floating point value in a way that the parser will be able to
|
|
/// round-trip losslessly.
|
|
static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
|
|
// We would like to output the FP constant value in exponential notation,
|
|
// but we cannot do this if doing so will lose precision. Check here to
|
|
// make sure that we only output it in exponential format if we can parse
|
|
// the value back and get the same value.
|
|
bool isInf = apValue.isInfinity();
|
|
bool isNaN = apValue.isNaN();
|
|
if (!isInf && !isNaN) {
|
|
SmallString<128> strValue;
|
|
apValue.toString(strValue, 6, 0, false);
|
|
|
|
// Check to make sure that the stringized number is not some string like
|
|
// "Inf" or NaN, that atof will accept, but the lexer will not. Check
|
|
// that the string matches the "[-+]?[0-9]" regex.
|
|
assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
|
|
((strValue[0] == '-' || strValue[0] == '+') &&
|
|
(strValue[1] >= '0' && strValue[1] <= '9'))) &&
|
|
"[-+]?[0-9] regex does not match!");
|
|
// Reparse stringized version!
|
|
if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
|
|
os << strValue;
|
|
return;
|
|
}
|
|
}
|
|
|
|
SmallVector<char, 16> str;
|
|
apValue.toString(str);
|
|
os << str;
|
|
}
|
|
|
|
// A function is referenced as '@' followed by its name.
void ModulePrinter::printFunctionReference(Function *func) {
  os << '@';
  os << func->getName();
}
|
|
|
|
// Print a location either wrapped as `loc(...)` or, with
// -mlir-pretty-debuginfo, in the bare human-readable form.
void ModulePrinter::printLocation(Location loc) {
  if (!printPrettyDebugInfo) {
    os << "loc(";
    printLocationInternal(loc);
    os << ')';
    return;
  }
  printLocationInternal(loc, /*pretty=*/true);
}
|
|
|
|
// Print `attr`, using its alias when one exists. For integer and float
// attributes, `includeType` controls whether a trailing " : type" is printed.
void ModulePrinter::printAttributeOptionalType(Attribute attr,
                                               bool includeType) {
  if (!attr) {
    os << "<<NULL ATTRIBUTE>>";
    return;
  }

  // Check for an alias for this attribute.
  Twine alias = state.getAttributeAlias(attr);
  if (!alias.isTriviallyEmpty()) {
    os << '#' << alias;
    return;
  }

  // Emits " : <type>" only when the caller asked for the type.
  auto printOptionalTrailingType = [&](Type attrType) {
    if (!includeType)
      return;
    os << " : ";
    printType(attrType);
  };

  switch (attr.getKind()) {
  default:
    // TODO(riverriddle) Support parsing/printing dialect attributes.
    llvm_unreachable("unhandled attribute kind");

  case StandardAttributes::Unit:
    os << "unit";
    break;
  case StandardAttributes::Bool:
    os << (attr.cast<BoolAttr>().getValue() ? "true" : "false");
    break;
  case StandardAttributes::Integer: {
    auto intAttr = attr.cast<IntegerAttr>();
    Type intType = intAttr.getType();
    // Print all integer attributes as signed unless i1.
    bool isSigned = intType.isIndex() || intType.getIntOrFloatBitWidth() != 1;
    intAttr.getValue().print(os, isSigned);
    printOptionalTrailingType(intType);
    break;
  }
  case StandardAttributes::Float: {
    auto floatAttr = attr.cast<FloatAttr>();
    printFloatValue(floatAttr.getValue(), os);
    printOptionalTrailingType(floatAttr.getType());
    break;
  }
  case StandardAttributes::String:
    os << '"';
    printEscapedString(attr.cast<StringAttr>().getValue(), os);
    os << '"';
    break;
  case StandardAttributes::Array:
    os << '[';
    interleaveComma(attr.cast<ArrayAttr>().getValue(),
                    [&](Attribute element) { printAttribute(element); });
    os << ']';
    break;
  case StandardAttributes::AffineMap:
    attr.cast<AffineMapAttr>().getValue().print(os);
    break;
  case StandardAttributes::IntegerSet:
    attr.cast<IntegerSetAttr>().getValue().print(os);
    break;
  case StandardAttributes::Type:
    printType(attr.cast<TypeAttr>().getValue());
    break;
  case StandardAttributes::Function: {
    auto *function = attr.cast<FunctionAttr>().getValue();
    if (function) {
      // Function references always carry their type.
      printFunctionReference(function);
      os << " : ";
      printType(function->getType());
    } else {
      os << "<<FUNCTION ATTR FOR DELETED FUNCTION>>";
    }
    break;
  }
  case StandardAttributes::OpaqueElements: {
    auto opaqueAttr = attr.cast<OpaqueElementsAttr>();
    os << "opaque<";
    os << '"' << opaqueAttr.getDialect()->getNamespace() << "\", ";
    printType(opaqueAttr.getType());
    os << ", " << '"' << "0x" << llvm::toHex(opaqueAttr.getValue()) << '"'
       << '>';
    break;
  }
  case StandardAttributes::DenseIntElements:
  case StandardAttributes::DenseFPElements: {
    auto denseAttr = attr.cast<DenseElementsAttr>();
    os << "dense<";
    printType(denseAttr.getType());
    os << ", ";
    printDenseElementsAttr(denseAttr);
    os << '>';
    break;
  }
  case StandardAttributes::SplatElements: {
    auto splatAttr = attr.cast<SplatElementsAttr>();
    os << "splat<";
    printType(splatAttr.getType());
    os << ", ";
    printAttribute(splatAttr.getValue());
    os << '>';
    break;
  }
  case StandardAttributes::SparseElements: {
    auto sparseAttr = attr.cast<SparseElementsAttr>();
    os << "sparse<";
    printType(sparseAttr.getType());
    os << ", ";
    printDenseElementsAttr(sparseAttr.getIndices());
    os << ", ";
    printDenseElementsAttr(sparseAttr.getValues());
    os << '>';
    break;
  }
  }
}
|
|
|
|
// Print the elements of a dense attribute as nested, comma-separated,
// bracketed lists that follow the attribute's shape, e.g. [[1, 2], [3, 4]].
void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
  auto type = attr.getType();
  auto shape = type.getShape();
  auto rank = type.getRank();

  SmallVector<Attribute, 16> elements;
  attr.getValues(elements);

  // A 0-d value prints as its single element, with no brackets.
  if (rank == 0) {
    printAttribute(elements[0]);
    return;
  }

  // A shape with no elements prints as `rank` nested empty brackets.
  if (elements.empty()) {
    for (int depth = 0; depth < rank; ++depth)
      os << '[';
    for (int depth = 0; depth < rank; ++depth)
      os << ']';
    return;
  }

  // Walk the elements with a mixed-radix counter (one digit per dimension,
  // radices taken from `shape`). Each carry out of a digit closes a bracket;
  // closed brackets are re-opened before the next element is printed.
  SmallVector<unsigned, 4> counter(rank, 0);
  // The number of brackets that have been opened and not closed.
  unsigned openBrackets = 0;

  auto advanceCounter = [&]() {
    ++counter.back();
    for (unsigned dim = rank - 1; dim > 0; --dim) {
      if (counter[dim] < shape[dim])
        continue;
      // This dimension wrapped around: carry into the next outer one.
      counter[dim] = 0;
      ++counter[dim - 1];
      --openBrackets;
      os << ']';
    }
  };

  for (unsigned idx = 0, e = elements.size(); idx != e; ++idx) {
    if (idx != 0)
      os << ", ";
    for (; openBrackets < rank; ++openBrackets)
      os << '[';
    printAttribute(elements[idx]);
    advanceCounter();
  }
  for (; openBrackets > 0; --openBrackets)
    os << ']';
}
|
|
|
|
/// Returns true if a dialect type body can be printed in the pretty form
/// `!dialect.body` rather than as a quoted string. The body must start with a
/// letter and consist of alphanumerics and '.'. It may optionally end in one
/// `<...>` group whose (), [], {}, <> punctuation is balanced and that contains
/// no NUL characters.
static bool isDialectTypeSimpleEnoughForPrettyForm(StringRef typeName) {
  // The type name must start with an identifier. llvm::isAlpha is used instead
  // of ::isalpha, which has undefined behavior for negative `char` values
  // (non-ASCII bytes where `char` is signed) and depends on the locale.
  if (typeName.empty() || !llvm::isAlpha(typeName.front()))
    return false;

  // Ignore all the characters that are valid in an identifier in the type
  // name.
  typeName =
      typeName.drop_while([](char c) { return llvm::isAlnum(c) || c == '.'; });
  if (typeName.empty())
    return true;

  // If we got to an unexpected character, then it must be a <>. Check those
  // recursively.
  if (typeName.front() != '<' || typeName.back() != '>')
    return false;

  // The first character is '<', so the stack is non-empty at the start of
  // every iteration after the first and pop_back_val() below is safe.
  SmallVector<char, 8> nestedPunctuation;
  do {
    // If we ran out of characters, then we had a punctuation mismatch.
    if (typeName.empty())
      return false;

    auto c = typeName.front();
    typeName = typeName.drop_front();

    switch (c) {
    // We never allow nul characters. This is an EOF indicator for the lexer
    // which we could handle, but isn't important for any known dialect.
    case '\0':
      return false;
    case '<':
    case '[':
    case '(':
    case '{':
      nestedPunctuation.push_back(c);
      continue;
    // Reject types with mismatched brackets.
    case '>':
      if (nestedPunctuation.pop_back_val() != '<')
        return false;
      break;
    case ']':
      if (nestedPunctuation.pop_back_val() != '[')
        return false;
      break;
    case ')':
      if (nestedPunctuation.pop_back_val() != '(')
        return false;
      break;
    case '}':
      if (nestedPunctuation.pop_back_val() != '{')
        return false;
      break;
    default:
      continue;
    }

    // We're done when the punctuation is fully matched.
  } while (!nestedPunctuation.empty());

  // If there were extra characters, then we failed.
  return typeName.empty();
}
|
|
|
|
// Print `type`, or a reference to its alias if one is registered. Nested types
// (element types, function inputs/results, tuple members) are printed through
// printType() as well, so aliases are substituted consistently at every
// nesting level.
void ModulePrinter::printType(Type type) {
  // Check for an alias for this type.
  StringRef alias = state.getTypeAlias(type);
  if (!alias.empty()) {
    os << '!' << alias;
    return;
  }

  // Prints `!dialect.body` when the body is simple enough, otherwise
  // `!dialect<"body">`.
  auto printDialectType = [&](StringRef dialectName, StringRef typeString) {
    os << '!' << dialectName;

    // If this type name is simple enough, print it directly in pretty form,
    // otherwise, we print it as a quoted string.
    if (isDialectTypeSimpleEnoughForPrettyForm(typeString)) {
      os << '.' << typeString;
      return;
    }

    // TODO: escape the type name, it could contain " characters.
    os << "<\"" << typeString << "\">";
  };

  switch (type.getKind()) {
  default: {
    auto &dialect = type.getDialect();

    // Ask the dialect to serialize the type to a string.
    std::string typeName;
    {
      llvm::raw_string_ostream typeNameStr(typeName);
      dialect.printType(type, typeNameStr);
    }

    printDialectType(dialect.getNamespace(), typeName);
    return;
  }
  case Type::Kind::Opaque: {
    auto opaqueTy = type.cast<OpaqueType>();
    printDialectType(opaqueTy.getDialectNamespace(), opaqueTy.getTypeData());
    return;
  }
  case StandardTypes::Index:
    os << "index";
    return;
  case StandardTypes::BF16:
    os << "bf16";
    return;
  case StandardTypes::F16:
    os << "f16";
    return;
  case StandardTypes::F32:
    os << "f32";
    return;
  case StandardTypes::F64:
    os << "f64";
    return;

  case StandardTypes::Integer: {
    auto integer = type.cast<IntegerType>();
    os << 'i' << integer.getWidth();
    return;
  }
  case Type::Kind::Function: {
    auto func = type.cast<FunctionType>();
    os << '(';
    interleaveComma(func.getInputs(), [&](Type type) { printType(type); });
    os << ") -> ";
    auto results = func.getResults();
    // A single non-function result is printed without parentheses. Use
    // printType() (not `os <<`) so that an alias for it is honored.
    if (results.size() == 1 && !results[0].isa<FunctionType>()) {
      printType(results[0]);
    } else {
      os << '(';
      interleaveComma(results, [&](Type type) { printType(type); });
      os << ')';
    }
    return;
  }
  case StandardTypes::Vector: {
    auto v = type.cast<VectorType>();
    os << "vector<";
    for (auto dim : v.getShape())
      os << dim << 'x';
    printType(v.getElementType());
    os << '>';
    return;
  }
  case StandardTypes::RankedTensor: {
    auto v = type.cast<RankedTensorType>();
    os << "tensor<";
    for (auto dim : v.getShape()) {
      // Negative sizes denote dynamic dimensions.
      if (dim < 0)
        os << '?';
      else
        os << dim;
      os << 'x';
    }
    printType(v.getElementType());
    os << '>';
    return;
  }
  case StandardTypes::UnrankedTensor: {
    auto v = type.cast<UnrankedTensorType>();
    os << "tensor<*x";
    printType(v.getElementType());
    os << '>';
    return;
  }
  case StandardTypes::MemRef: {
    auto v = type.cast<MemRefType>();
    os << "memref<";
    for (auto dim : v.getShape()) {
      // Negative sizes denote dynamic dimensions.
      if (dim < 0)
        os << '?';
      else
        os << dim;
      os << 'x';
    }
    printType(v.getElementType());
    for (auto map : v.getAffineMaps()) {
      os << ", ";
      printAttribute(AffineMapAttr::get(map));
    }
    // Only print the memory space if it is the non-default one.
    if (v.getMemorySpace())
      os << ", " << v.getMemorySpace();
    os << '>';
    return;
  }
  case StandardTypes::Complex:
    os << "complex<";
    printType(type.cast<ComplexType>().getElementType());
    os << '>';
    return;
  case StandardTypes::Tuple: {
    auto tuple = type.cast<TupleType>();
    os << "tuple<";
    interleaveComma(tuple.getTypes(), [&](Type type) { printType(type); });
    os << '>';
    return;
  }
  case StandardTypes::None:
    os << "none";
    return;
  }
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Affine expressions and maps
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Print a top-level affine expression. It has no enclosing operator, so it is
// printed in a weak binding context and gets no outer parentheses.
void ModulePrinter::printAffineExpr(AffineExpr expr) {
  const BindingStrength topLevel = BindingStrength::Weak;
  printAffineExprInternal(expr, topLevel);
}
|
|
|
|
// Print `expr`, wrapping binary expressions in parentheses when the enclosing
// context binds more tightly (`enclosingTightness == Strong`). Additions of
// negated terms or negative constants are pretty-printed as subtractions, and
// multiplication by -1 is printed as a negation.
void ModulePrinter::printAffineExprInternal(
    AffineExpr expr, BindingStrength enclosingTightness) {
  const char *binopSpelling = nullptr;
  switch (expr.getKind()) {
  case AffineExprKind::SymbolId:
    os << 's' << expr.cast<AffineSymbolExpr>().getPosition();
    return;
  case AffineExprKind::DimId:
    os << 'd' << expr.cast<AffineDimExpr>().getPosition();
    return;
  case AffineExprKind::Constant:
    os << expr.cast<AffineConstantExpr>().getValue();
    return;
  case AffineExprKind::Add:
    binopSpelling = " + ";
    break;
  case AffineExprKind::Mul:
    binopSpelling = " * ";
    break;
  case AffineExprKind::FloorDiv:
    binopSpelling = " floordiv ";
    break;
  case AffineExprKind::CeilDiv:
    binopSpelling = " ceildiv ";
    break;
  case AffineExprKind::Mod:
    binopSpelling = " mod ";
    break;
  }

  auto binOp = expr.cast<AffineBinaryOpExpr>();
  AffineExpr lhsExpr = binOp.getLHS();
  AffineExpr rhsExpr = binOp.getRHS();

  // Handle tightly binding binary operators.
  if (binOp.getKind() != AffineExprKind::Add) {
    if (enclosingTightness == BindingStrength::Strong)
      os << '(';

    // Pretty print multiplication with -1 as a negation. This is restricted to
    // multiplication: for the other operators the shortcut drops the operator
    // (e.g. `d0 mod -1` would be printed as `-d0`).
    auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>();
    if (binOp.getKind() == AffineExprKind::Mul && rhsConst &&
        rhsConst.getValue() == -1) {
      os << "-";
      printAffineExprInternal(lhsExpr, BindingStrength::Strong);
    } else {
      printAffineExprInternal(lhsExpr, BindingStrength::Strong);
      os << binopSpelling;
      printAffineExprInternal(rhsExpr, BindingStrength::Strong);
    }

    // Close the parenthesis on every path; the negation shortcut previously
    // returned early and left it unbalanced.
    if (enclosingTightness == BindingStrength::Strong)
      os << ')';
    return;
  }

  // Print out special "pretty" forms for add.
  if (enclosingTightness == BindingStrength::Strong)
    os << '(';

  // Pretty print addition to a product that has a negative operand as a
  // subtraction.
  if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) {
    if (rhs.getKind() == AffineExprKind::Mul) {
      AffineExpr rrhsExpr = rhs.getRHS();
      if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) {
        if (rrhs.getValue() == -1) {
          printAffineExprInternal(lhsExpr, BindingStrength::Weak);
          os << " - ";
          if (rhs.getLHS().getKind() == AffineExprKind::Add) {
            printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong);
          } else {
            printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak);
          }

          if (enclosingTightness == BindingStrength::Strong)
            os << ')';
          return;
        }

        if (rrhs.getValue() < -1) {
          printAffineExprInternal(lhsExpr, BindingStrength::Weak);
          os << " - ";
          printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong);
          os << " * " << -rrhs.getValue();
          if (enclosingTightness == BindingStrength::Strong)
            os << ')';
          return;
        }
      }
    }
  }

  // Pretty print addition to a negative number as a subtraction.
  if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) {
    if (rhsConst.getValue() < 0) {
      printAffineExprInternal(lhsExpr, BindingStrength::Weak);
      os << " - " << -rhsConst.getValue();
      if (enclosingTightness == BindingStrength::Strong)
        os << ')';
      return;
    }
  }

  printAffineExprInternal(lhsExpr, BindingStrength::Weak);
  os << " + ";
  printAffineExprInternal(rhsExpr, BindingStrength::Weak);

  if (enclosingTightness == BindingStrength::Strong)
    os << ')';
}
|
|
|
|
/// Prints a single integer-set constraint: `expr == 0` when `isEq` is set,
/// otherwise `expr >= 0`.
void ModulePrinter::printAffineConstraint(AffineExpr expr, bool isEq) {
  printAffineExprInternal(expr, BindingStrength::Weak);
  if (isEq)
    os << " == 0";
  else
    os << " >= 0";
}
|
|
|
|
/// Prints an affine map as
///   (d0, ..., dN-1)[s0, ..., sM-1] -> (expr, ...)
/// The symbol list is omitted when the map has no symbols. Bounded maps are
/// additionally followed by ` size (expr, ...)`.
void ModulePrinter::printAffineMap(AffineMap map) {
  // Dimension identifiers.
  os << '(';
  for (unsigned dim = 0, numDims = map.getNumDims(); dim != numDims; ++dim) {
    if (dim != 0)
      os << ", ";
    os << 'd' << dim;
  }
  os << ')';

  // Symbolic identifiers, only when the map has any.
  if (unsigned numSymbols = map.getNumSymbols()) {
    os << '[';
    for (unsigned sym = 0; sym != numSymbols; ++sym) {
      if (sym != 0)
        os << ", ";
      os << 's' << sym;
    }
    os << ']';
  }

  // AffineMap should have at least one result.
  assert(!map.getResults().empty());

  auto printExpr = [&](AffineExpr expr) { printAffineExpr(expr); };

  // Result affine expressions.
  os << " -> (";
  interleaveComma(map.getResults(), printExpr);
  os << ')';

  // Only bounded maps carry range sizes.
  if (map.isBounded()) {
    os << " size (";
    interleaveComma(map.getRangeSizes(), printExpr);
    os << ')';
  }
}
|
|
|
|
/// Prints an integer set as
///   (d0, ..., dN-1)[s0, ..., sM-1] : (c0, c1, ...)
/// where each constraint is `expr == 0` or `expr >= 0`. The symbol list is
/// omitted when the set has no symbols.
void ModulePrinter::printIntegerSet(IntegerSet set) {
  // Dimension identifiers.
  os << '(';
  for (unsigned dim = 0, numDims = set.getNumDims(); dim != numDims; ++dim) {
    if (dim != 0)
      os << ", ";
    os << 'd' << dim;
  }
  os << ')';

  // Symbolic identifiers, only when the set has any.
  if (unsigned numSymbols = set.getNumSymbols()) {
    os << '[';
    for (unsigned sym = 0; sym != numSymbols; ++sym) {
      if (sym != 0)
        os << ", ";
      os << 's' << sym;
    }
    os << ']';
  }

  // Constraints, comma separated.
  os << " : (";
  for (unsigned idx = 0, e = set.getNumConstraints(); idx != e; ++idx) {
    if (idx != 0)
      os << ", ";
    printAffineConstraint(set.getConstraint(idx), set.isEq(idx));
  }
  os << ')';
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Function printing
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Prints ` {name: value, ...}` for every attribute in `attrs` whose name is
/// not listed in `elidedAttrs`. Nothing at all is printed when no attribute
/// survives the filtering. Unit attributes are printed by name only.
void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                                          ArrayRef<StringRef> elidedAttrs) {
  if (attrs.empty())
    return;

  // True if the caller asked for this attribute to be dropped.
  auto isElided = [&](NamedAttribute attr) {
    return llvm::any_of(elidedAttrs,
                        [&](StringRef name) { return attr.first.is(name); });
  };

  SmallVector<NamedAttribute, 8> toPrint;
  for (NamedAttribute attr : attrs)
    if (!isElided(attr))
      toPrint.push_back(attr);

  if (toPrint.empty())
    return;

  os << " {";
  interleaveComma(toPrint, [&](NamedAttribute attr) {
    os << attr.first;
    // Pretty printing elides the attribute value for unit attributes.
    if (!attr.second.isa<UnitAttr>()) {
      os << ": ";
      printAttributeAndType(attr.second);
    }
  });
  os << '}';
}
|
|
|
|
namespace {
|
|
|
|
// FunctionPrinter contains common functionality for printing
|
|
// CFG and ML functions.
|
|
class FunctionPrinter : public ModulePrinter, private OpAsmPrinter {
|
|
public:
|
|
FunctionPrinter(Function *function, ModulePrinter &other);
|
|
|
|
// Prints the function as a whole.
|
|
void print();
|
|
|
|
// Print the function signature.
|
|
void printFunctionSignature();
|
|
|
|
// Methods to print operations.
|
|
void print(Operation *op);
|
|
void print(Block *block, bool printBlockArgs = true,
|
|
bool printBlockTerminator = true);
|
|
|
|
void printOperation(Operation *op);
|
|
void printGenericOp(Operation *op) override;
|
|
|
|
// Implement OpAsmPrinter.
|
|
raw_ostream &getStream() const override { return os; }
|
|
void printType(Type type) override { ModulePrinter::printType(type); }
|
|
void printAttribute(Attribute attr) override {
|
|
ModulePrinter::printAttribute(attr);
|
|
}
|
|
void printAttributeAndType(Attribute attr) override {
|
|
ModulePrinter::printAttributeAndType(attr);
|
|
}
|
|
void printFunctionReference(Function *func) override {
|
|
return ModulePrinter::printFunctionReference(func);
|
|
}
|
|
void printOperand(Value *value) override { printValueID(value); }
|
|
|
|
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
|
|
ArrayRef<StringRef> elidedAttrs = {}) override {
|
|
return ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs);
|
|
};
|
|
|
|
enum { nameSentinel = ~0U };
|
|
|
|
void printBlockName(Block *block) {
|
|
auto id = getBlockID(block);
|
|
if (id != ~0U)
|
|
os << "^bb" << id;
|
|
else
|
|
os << "^INVALIDBLOCK";
|
|
}
|
|
|
|
unsigned getBlockID(Block *block) {
|
|
auto it = blockIDs.find(block);
|
|
return it != blockIDs.end() ? it->second : ~0U;
|
|
}
|
|
|
|
void printSuccessorAndUseList(Operation *term, unsigned index) override;
|
|
|
|
/// Print a region.
|
|
void printRegion(Region &blocks, bool printEntryBlockArgs,
|
|
bool printBlockTerminators) override {
|
|
os << " {\n";
|
|
if (!blocks.empty()) {
|
|
auto *entryBlock = &blocks.front();
|
|
print(entryBlock,
|
|
printEntryBlockArgs && entryBlock->getNumArguments() != 0,
|
|
printBlockTerminators);
|
|
for (auto &b : llvm::drop_begin(blocks.getBlocks(), 1))
|
|
print(&b);
|
|
}
|
|
os.indent(currentIndent) << "}";
|
|
}
|
|
|
|
// Number of spaces used for indenting nested operations.
|
|
const static unsigned indentWidth = 2;
|
|
|
|
protected:
|
|
void numberValueID(Value *value);
|
|
void numberValuesInBlock(Block &block);
|
|
void printValueID(Value *value, bool printResultNo = true) const;
|
|
|
|
private:
|
|
Function *function;
|
|
|
|
/// This is the value ID for each SSA value in the current function. If this
|
|
/// returns ~0, then the valueID has an entry in valueNames.
|
|
DenseMap<Value *, unsigned> valueIDs;
|
|
DenseMap<Value *, StringRef> valueNames;
|
|
|
|
/// This is the block ID for each block in the current function.
|
|
DenseMap<Block *, unsigned> blockIDs;
|
|
|
|
/// This keeps track of all of the non-numeric names that are in flight,
|
|
/// allowing us to check for duplicates.
|
|
llvm::StringSet<> usedNames;
|
|
|
|
// This is the current indentation level for nested structures.
|
|
unsigned currentIndent = 0;
|
|
|
|
/// This is the next value ID to assign in numbering.
|
|
unsigned nextValueID = 0;
|
|
/// This is the ID to assign to the next region entry block argument.
|
|
unsigned nextRegionArgumentID = 0;
|
|
/// This is the next ID to assign to a Function argument.
|
|
unsigned nextArgumentID = 0;
|
|
/// This is the next ID to assign when a name conflict is detected.
|
|
unsigned nextConflictID = 0;
|
|
/// This is the next block ID to assign in numbering.
|
|
unsigned nextBlockID = 0;
|
|
};
|
|
} // end anonymous namespace
|
|
|
|
/// Creates a printer for `function` that inherits the module-level printing
/// state of `other`. Every block and SSA value in the function is numbered
/// before anything is printed.
FunctionPrinter::FunctionPrinter(Function *function, ModulePrinter &other)
    : ModulePrinter(other), function(function) {
  for (Block &topLevelBlock : *function)
    numberValuesInBlock(topLevelBlock);
}
|
|
|
|
/// Number all of the SSA values in the specified block. Values get numbered
|
|
/// continuously throughout regions. In particular, we traverse the regions
|
|
/// held by operations and number values in depth-first pre-order.
|
|
void FunctionPrinter::numberValuesInBlock(Block &block) {
|
|
// Each block gets a unique ID, and all of the operations within it get
|
|
// numbered as well.
|
|
blockIDs[&block] = nextBlockID++;
|
|
|
|
for (auto *arg : block.getArguments())
|
|
numberValueID(arg);
|
|
|
|
for (auto &op : block) {
|
|
// We number operation that have results, and we only number the first
|
|
// result.
|
|
if (op.getNumResults() != 0)
|
|
numberValueID(op.getResult(0));
|
|
for (auto ®ion : op.getRegions())
|
|
for (auto &block : region)
|
|
numberValuesInBlock(block);
|
|
}
|
|
}
|
|
|
|
/// Assigns a printed identity to `value`.
///
/// Constants get readable names:
///   - `%cN` for index constants,
///   - `%true` / `%false` for i1 constants,
///   - `%cN_type` for other integer constants,
///   - `%f` for function constants,
///   - `%cst` otherwise.
/// Entry-block arguments are named `%argN` when they belong to the function
/// body and `%iN` when they belong to an operation region. Everything else
/// gets the next sequential number. When a special name is already taken, a
/// `_N` suffix is appended until it is unique.
void FunctionPrinter::numberValueID(Value *value) {
  assert(!valueIDs.count(value) && "Value numbered multiple times");

  SmallString<32> specialNameBuffer;
  llvm::raw_svector_ostream specialName(specialNameBuffer);

  // Give constant values special names.
  if (auto *op = value->getDefiningOp()) {
    Attribute cst;
    if (m_Constant(&cst).match(op)) {
      Type type = op->getResult(0)->getType();
      if (auto intCst = cst.dyn_cast<IntegerAttr>()) {
        if (type.isIndex()) {
          specialName << 'c' << intCst.getInt();
        } else if (type.isInteger(1)) {
          // i1 constants get special names. Query the Type directly:
          // cast<IntegerType> would assert for a constant whose result type
          // is not an IntegerType.
          specialName << (intCst.getInt() ? "true" : "false");
        } else {
          specialName << 'c' << intCst.getInt() << '_' << type;
        }
      } else if (cst.isa<FunctionAttr>()) {
        specialName << 'f';
      } else {
        specialName << "cst";
      }
    }
  }

  if (specialNameBuffer.empty()) {
    switch (value->getKind()) {
    case Value::Kind::BlockArgument:
      // If this is an argument to the function, give it an 'arg' name. If the
      // argument is to an entry block of an operation region, give it an 'i'
      // name.
      if (auto *block = cast<BlockArgument>(value)->getOwner()) {
        auto *parentRegion = block->getParent();
        if (parentRegion && block == &parentRegion->front()) {
          if (parentRegion->getContainingFunction())
            specialName << "arg" << nextArgumentID++;
          else
            specialName << "i" << nextRegionArgumentID++;
          break;
        }
      }
      // Otherwise number it normally.
      valueIDs[value] = nextValueID++;
      return;
    case Value::Kind::OpResult:
      // This is an uninteresting result, give it a boring number and be
      // done with it.
      valueIDs[value] = nextValueID++;
      return;
    }
  }

  // Ok, this value had an interesting name. Remember it with a sentinel.
  valueIDs[value] = nameSentinel;

  // Remember that we've used this name, checking to see if we had a conflict.
  auto insertRes = usedNames.insert(specialName.str());
  if (insertRes.second) {
    // If this is the first use of the name, then we're successful!
    valueNames[value] = insertRes.first->first();
    return;
  }

  // Otherwise, we had a conflict - probe until we find a unique name. This
  // is guaranteed to terminate (and usually in a single iteration) because it
  // generates new names by incrementing nextConflictID. The base name does not
  // change between probes, so build it once.
  const std::string baseName = specialName.str().str();
  while (true) {
    std::string probeName = baseName + "_" + llvm::utostr(nextConflictID++);
    insertRes = usedNames.insert(probeName);
    if (insertRes.second) {
      // If this is the first use of the name, then we're successful!
      valueNames[value] = insertRes.first->first();
      return;
    }
  }
}
|
|
|
|
void FunctionPrinter::print() {
|
|
printFunctionSignature();
|
|
|
|
// Print out function attributes, if present.
|
|
auto attrs = function->getAttrs();
|
|
if (!attrs.empty()) {
|
|
os << "\n attributes ";
|
|
printOptionalAttrDict(attrs);
|
|
}
|
|
|
|
// Print the trailing location.
|
|
printTrailingLocation(function->getLoc());
|
|
|
|
if (!function->empty()) {
|
|
printRegion(function->getBody(), /*printEntryBlockArgs=*/false,
|
|
/*printBlockTerminators=*/true);
|
|
os << "\n";
|
|
}
|
|
os << '\n';
|
|
}
|
|
|
|
void FunctionPrinter::printFunctionSignature() {
|
|
os << "func @" << function->getName() << '(';
|
|
|
|
auto fnType = function->getType();
|
|
bool isExternal = function->isExternal();
|
|
for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) {
|
|
if (i > 0)
|
|
os << ", ";
|
|
|
|
// If this is an external function, don't print argument labels.
|
|
if (!isExternal) {
|
|
printOperand(function->getArgument(i));
|
|
os << ": ";
|
|
}
|
|
|
|
printType(fnType.getInput(i));
|
|
|
|
// Print the attributes for this argument.
|
|
printOptionalAttrDict(function->getArgAttrs(i));
|
|
}
|
|
os << ')';
|
|
|
|
switch (fnType.getResults().size()) {
|
|
case 0:
|
|
break;
|
|
case 1: {
|
|
os << " -> ";
|
|
auto resultType = fnType.getResults()[0];
|
|
bool resultIsFunc = resultType.isa<FunctionType>();
|
|
if (resultIsFunc)
|
|
os << '(';
|
|
printType(resultType);
|
|
if (resultIsFunc)
|
|
os << ')';
|
|
break;
|
|
}
|
|
default:
|
|
os << " -> (";
|
|
interleaveComma(fnType.getResults(),
|
|
[&](Type eltType) { printType(eltType); });
|
|
os << ')';
|
|
break;
|
|
}
|
|
}
|
|
|
|
/// Prints `block`.
///
/// When `printBlockArgs` is set, a label line is printed first: the block
/// name, its argument list, and a comment describing its predecessors.
/// Then each operation is printed on its own line, one indentation level
/// deeper. The final operation is skipped when `printBlockTerminator` is
/// false.
void FunctionPrinter::print(Block *block, bool printBlockArgs,
                            bool printBlockTerminator) {
  // Emits the `// ... preds` annotation that follows a block label.
  auto printPredecessorComment = [&] {
    if (!block->getFunction()) {
      os << "\t// block is not in a function!";
      return;
    }
    if (block->hasNoPredecessors()) {
      os << "\t// no predecessors";
      return;
    }
    if (auto *onlyPred = block->getSinglePredecessor()) {
      os << "\t// pred: ";
      printBlockName(onlyPred);
      return;
    }
    // Sort by block ID so the output is independent of use-list order.
    SmallVector<std::pair<unsigned, Block *>, 4> sortedPreds;
    for (auto *pred : block->getPredecessors())
      sortedPreds.push_back({getBlockID(pred), pred});
    llvm::array_pod_sort(sortedPreds.begin(), sortedPreds.end());

    os << "\t// " << sortedPreds.size() << " preds: ";
    interleaveComma(sortedPreds, [&](std::pair<unsigned, Block *> entry) {
      printBlockName(entry.second);
    });
  };

  if (printBlockArgs) {
    os.indent(currentIndent);
    printBlockName(block);
    if (!block->args_empty()) {
      os << '(';
      interleaveComma(block->getArguments(), [&](BlockArgument *arg) {
        printValueID(arg);
        os << ": ";
        printType(arg->getType());
      });
      os << ')';
    }
    os << ':';
    printPredecessorComment();
    os << '\n';
  }

  auto &ops = block->getOperations();
  auto stop = printBlockTerminator ? ops.end() : std::prev(ops.end());

  currentIndent += indentWidth;
  for (Operation &op : llvm::make_range(ops.begin(), stop)) {
    print(&op);
    os << '\n';
  }
  currentIndent -= indentWidth;
}
|
|
|
|
/// Prints `op` at the current indentation, followed by its trailing location.
/// The caller is responsible for the terminating newline.
void FunctionPrinter::print(Operation *op) {
  Location opLoc = op->getLoc();
  os.indent(currentIndent);
  printOperation(op);
  printTrailingLocation(opLoc);
}
|
|
|
|
/// Prints the SSA name of `value` as `%N` or `%name`.
///
/// Results of a multi-result operation share the ID of the operation's first
/// result. They get a `#K` suffix when `printResultNo` is set. Values this
/// printer never numbered print as `<<INVALID SSA VALUE>>`.
void FunctionPrinter::printValueID(Value *value, bool printResultNo) const {
  Value *lookupValue = value;
  int resultNo = -1;

  auto *result = dyn_cast<OpResult>(value);
  if (result && result->getOwner()->getNumResults() != 1) {
    resultNo = result->getResultNumber();
    lookupValue = result->getOwner()->getResult(0);
  }

  auto idIt = valueIDs.find(lookupValue);
  if (idIt == valueIDs.end()) {
    os << "<<INVALID SSA VALUE>>";
    return;
  }

  os << '%';
  if (idIt->second == nameSentinel) {
    auto nameIt = valueNames.find(lookupValue);
    assert(nameIt != valueNames.end() && "Didn't have a name entry?");
    os << nameIt->second;
  } else {
    os << idIt->second;
  }

  if (printResultNo && resultNo != -1)
    os << '#' << resultNo;
}
|
|
|
|
/// Prints `op` without indentation or location.
///
/// The result list comes first: `%x = `, or `%x:N = ` when there are N > 1
/// results. The body is printed with the operation's registered custom
/// printer. The generic form is used instead when `printGenericOpForm` is set
/// or the operation is not registered.
void FunctionPrinter::printOperation(Operation *op) {
  size_t numResults = op->getNumResults();
  if (numResults != 0) {
    printValueID(op->getResult(0), /*printResultNo=*/false);
    if (numResults > 1)
      os << ':' << numResults;
    os << " = ";
  }

  auto *opInfo = printGenericOpForm ? nullptr : op->getAbstractOperation();
  if (opInfo)
    opInfo->printAssembly(op, this);
  else
    printGenericOp(op);
}
|
|
|
|
/// Prints `op` in the generic form:
///   "name"(operands)[successors] (regions) {attrs} : (types) -> results
/// Successor operands are excluded from the main operand list. They are
/// printed together with their successor inside the brackets.
void FunctionPrinter::printGenericOp(Operation *op) {
  os << '"';
  printEscapedString(op->getName().getStringRef(), os);
  os << "\"(";

  // Successor operands trail the regular operands.
  const unsigned numSuccessors = op->getNumSuccessors();
  unsigned numSuccessorOperands = 0;
  for (unsigned succ = 0; succ != numSuccessors; ++succ)
    numSuccessorOperands += op->getNumSuccessorOperands(succ);

  auto operandsBegin = op->operand_begin();
  SmallVector<Value *, 8> properOperands(
      operandsBegin,
      std::next(operandsBegin, op->getNumOperands() - numSuccessorOperands));

  interleaveComma(properOperands,
                  [&](Value *operand) { printValueID(operand); });
  os << ')';

  // Terminators list each successor with its operands.
  if (numSuccessors != 0) {
    os << '[';
    for (unsigned succ = 0; succ != numSuccessors; ++succ) {
      if (succ != 0)
        os << ", ";
      printSuccessorAndUseList(op, succ);
    }
    os << ']';
  }

  // Regions are printed in full, entry-block labels included.
  if (op->getNumRegions() != 0) {
    os << " (";
    interleaveComma(op->getRegions(), [&](Region &region) {
      printRegion(region, /*printEntryBlockArgs=*/true,
                  /*printBlockTerminators=*/true);
    });
    os << ')';
  }

  printOptionalAttrDict(op->getAttrs());

  // Functional type. A single result is printed bare unless it is itself a
  // function type.
  os << " : (";
  interleaveComma(properOperands,
                  [&](Value *operand) { printType(operand->getType()); });
  os << ") -> ";

  const bool bareResult = op->getNumResults() == 1 &&
                          !op->getResult(0)->getType().isa<FunctionType>();
  if (bareResult) {
    printType(op->getResult(0)->getType());
    return;
  }
  os << '(';
  interleaveComma(op->getResults(),
                  [&](Value *result) { printType(result->getType()); });
  os << ')';
}
|
|
|
|
/// Prints successor `index` of `term` as `^bbN`. When the successor receives
/// operands, it is followed by `(%a, %b : type, type)`.
void FunctionPrinter::printSuccessorAndUseList(Operation *term,
                                               unsigned index) {
  printBlockName(term->getSuccessor(index));

  auto operands = term->getSuccessorOperands(index);
  if (operands.begin() != operands.end()) {
    os << '(';
    interleaveComma(operands,
                    [this](Value *operand) { printValueID(operand); });
    os << " : ";
    interleaveComma(operands,
                    [this](Value *operand) { printType(operand->getType()); });
    os << ')';
  }
}
|
|
|
|
// Prints function with initialized module state.
|
|
void ModulePrinter::print(Function *fn) { FunctionPrinter(fn, *this).print(); }
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// print and dump methods
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Prints this attribute (and its type, where applicable) to `os`. No
/// context is available here, so module state is built without one.
void Attribute::print(raw_ostream &os) const {
  ModuleState state(/*no context is known*/ nullptr);
  ModulePrinter printer(os, state);
  printer.printAttributeAndType(*this);
}
|
|
|
|
void Attribute::dump() const {
|
|
print(llvm::errs());
|
|
llvm::errs() << "\n";
|
|
}
|
|
|
|
/// Prints this type to `os` using module state built from its context.
void Type::print(raw_ostream &os) const {
  ModuleState state(getContext());
  ModulePrinter printer(os, state);
  printer.printType(*this);
}
|
|
|
|
/// Prints this type to stderr. No trailing newline is emitted.
void Type::dump() const {
  print(llvm::errs());
}
|
|
|
|
void AffineMap::dump() const {
|
|
print(llvm::errs());
|
|
llvm::errs() << "\n";
|
|
}
|
|
|
|
void IntegerSet::dump() const {
|
|
print(llvm::errs());
|
|
llvm::errs() << "\n";
|
|
}
|
|
|
|
/// Prints this affine expression to `os`, or `null affine expr` when it is
/// null.
void AffineExpr::print(raw_ostream &os) const {
  if (!expr) {
    os << "null affine expr";
    return;
  }
  ModuleState state(getContext());
  ModulePrinter printer(os, state);
  printer.printAffineExpr(*this);
}
|
|
|
|
void AffineExpr::dump() const {
|
|
print(llvm::errs());
|
|
llvm::errs() << "\n";
|
|
}
|
|
|
|
/// Prints this affine map to `os`, or `null affine map` when it is null.
void AffineMap::print(raw_ostream &os) const {
  if (!map) {
    os << "null affine map";
    return;
  }
  ModuleState state(getContext());
  ModulePrinter printer(os, state);
  printer.printAffineMap(*this);
}
|
|
|
|
/// Prints this integer set to `os`, or `null integer set` when it is null.
/// The null check mirrors AffineMap::print and AffineExpr::print:
/// printIntegerSet queries the dimension, symbol and constraint counts, which
/// would dereference a null set.
void IntegerSet::print(raw_ostream &os) const {
  if (set == nullptr) {
    os << "null integer set";
    return;
  }
  ModuleState state(/*no context is known*/ nullptr);
  ModulePrinter(os, state).printIntegerSet(*this);
}
|
|
|
|
/// Prints an operation result by printing its defining operation. Block
/// arguments are printed as a placeholder line.
void Value::print(raw_ostream &os) {
  Value::Kind kind = getKind();
  if (kind == Value::Kind::OpResult) {
    getDefiningOp()->print(os);
    return;
  }
  if (kind == Value::Kind::BlockArgument) {
    // TODO: Improve this.
    os << "<block argument>\n";
  }
}
|
|
|
|
/// Prints this value to stderr.
void Value::dump() {
  print(llvm::errs());
}
|
|
|
|
/// Prints this operation using a printer for its enclosing function.
/// Operations not attached to a function print `<<UNLINKED INSTRUCTION>>`.
void Operation::print(raw_ostream &os) {
  if (Function *function = getFunction()) {
    ModuleState state(function->getContext());
    ModulePrinter modulePrinter(os, state);
    FunctionPrinter functionPrinter(function, modulePrinter);
    functionPrinter.print(this);
    return;
  }
  os << "<<UNLINKED INSTRUCTION>>\n";
}
|
|
|
|
void Operation::dump() {
|
|
print(llvm::errs());
|
|
llvm::errs() << "\n";
|
|
}
|
|
|
|
/// Prints this block using a printer for its enclosing function. Blocks not
/// attached to a function print `<<UNLINKED BLOCK>>`.
void Block::print(raw_ostream &os) {
  if (Function *function = getFunction()) {
    ModuleState state(function->getContext());
    ModulePrinter modulePrinter(os, state);
    FunctionPrinter functionPrinter(function, modulePrinter);
    functionPrinter.print(this);
    return;
  }
  os << "<<UNLINKED BLOCK>>\n";
}
|
|
|
|
/// Prints this block to stderr.
void Block::dump() {
  print(llvm::errs());
}
|
|
|
|
/// Print out the name of the block without printing its body.
|
|
void Block::printAsOperand(raw_ostream &os, bool printType) {
|
|
if (!getFunction()) {
|
|
os << "<<UNLINKED BLOCK>>\n";
|
|
return;
|
|
}
|
|
ModuleState state(getFunction()->getContext());
|
|
ModulePrinter modulePrinter(os, state);
|
|
FunctionPrinter(getFunction(), modulePrinter).printBlockName(this);
|
|
}
|
|
|
|
/// Prints this function to `os` using module state built from its context.
void Function::print(raw_ostream &os) {
  ModuleState state(getContext());
  ModulePrinter printer(os, state);
  printer.print(this);
}
|
|
|
|
/// Prints this function to stderr.
void Function::dump() {
  print(llvm::errs());
}
|
|
|
|
/// Prints this module to `os`. The module state is initialized from the
/// module itself before printing starts.
void Module::print(raw_ostream &os) {
  ModuleState state(getContext());
  state.initialize(this);
  ModulePrinter printer(os, state);
  printer.print(this);
}
|
|
|
|
/// Prints this module to stderr.
void Module::dump() {
  print(llvm::errs());
}
|
|
|
|
/// Prints this location to `os`. No context is used for the module state.
void Location::print(raw_ostream &os) const {
  ModuleState state(/*context=*/nullptr);
  ModulePrinter printer(os, state);
  printer.printLocation(*this);
}
|
|
|
|
/// Prints this location to stderr.
void Location::dump() const {
  print(llvm::errs());
}
|