llvm-project/mlir/lib/IR/AsmPrinter.cpp

1625 lines
48 KiB
C++

//===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the MLIR AsmPrinter class, which is used to implement
// the various print() methods on the core IR objects.
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSet.h"
#include "mlir/IR/StandardOps.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
using namespace mlir;
void Identifier::print(raw_ostream &os) const { os << str(); }
void Identifier::dump() const { print(llvm::errs()); }
OpAsmPrinter::~OpAsmPrinter() {}
//===----------------------------------------------------------------------===//
// ModuleState
//===----------------------------------------------------------------------===//
namespace {
class ModuleState {
public:
/// This is the operation set for the current context if it is knowable (a
/// context could be determined), otherwise this is null.
OperationSet *const operationSet;
explicit ModuleState(MLIRContext *context)
: operationSet(context ? &OperationSet::get(context) : nullptr) {}
// Initializes module state, populating affine map state.
void initialize(const Module *module);
int getAffineMapId(AffineMap *affineMap) const {
auto it = affineMapIds.find(affineMap);
if (it == affineMapIds.end()) {
return -1;
}
return it->second;
}
ArrayRef<AffineMap *> getAffineMapIds() const { return affineMapsById; }
int getIntegerSetId(IntegerSet *integerSet) const {
auto it = integerSetIds.find(integerSet);
if (it == integerSetIds.end()) {
return -1;
}
return it->second;
}
ArrayRef<IntegerSet *> getIntegerSetIds() const { return integerSetsById; }
private:
void recordAffineMapReference(AffineMap *affineMap) {
if (affineMapIds.count(affineMap) == 0) {
affineMapIds[affineMap] = affineMapsById.size();
affineMapsById.push_back(affineMap);
}
}
void recordIntegerSetReference(IntegerSet *integerSet) {
if (integerSetIds.count(integerSet) == 0) {
integerSetIds[integerSet] = integerSetsById.size();
integerSetsById.push_back(integerSet);
}
}
// Return true if this map could be printed using the shorthand form.
static bool hasShorthandForm(AffineMap *boundMap) {
if (boundMap->isSingleConstant())
return true;
// Check if the affine map is single dim id or single symbol identity -
// (i)->(i) or ()[s]->(i)
return boundMap->getNumInputs() == 1 && boundMap->getNumResults() == 1 &&
(boundMap->getResult(0).isa<AffineDimExprRef>() ||
boundMap->getResult(0).isa<AffineSymbolExprRef>());
}
// Visit functions.
void visitFunction(const Function *fn);
void visitExtFunction(const ExtFunction *fn);
void visitCFGFunction(const CFGFunction *fn);
void visitMLFunction(const MLFunction *fn);
void visitStatement(const Statement *stmt);
void visitForStmt(const ForStmt *forStmt);
void visitIfStmt(const IfStmt *ifStmt);
void visitOperationStmt(const OperationStmt *opStmt);
void visitType(const Type *type);
void visitAttribute(const Attribute *attr);
void visitOperation(const Operation *op);
DenseMap<AffineMap *, int> affineMapIds;
std::vector<AffineMap *> affineMapsById;
DenseMap<IntegerSet *, int> integerSetIds;
std::vector<IntegerSet *> integerSetsById;
};
} // end anonymous namespace
// TODO Support visiting other types/instructions when implemented.
/// Records the affine maps reachable from `type`: the layout maps of a memref
/// and, recursively, those inside the input and result types of a function.
void ModuleState::visitType(const Type *type) {
  if (auto *memref = dyn_cast<MemRefType>(type)) {
    for (auto *layoutMap : memref->getAffineMaps())
      recordAffineMapReference(layoutMap);
    return;
  }
  auto *funcType = dyn_cast<FunctionType>(type);
  if (!funcType)
    return;
  for (auto *inputType : funcType->getInputs())
    visitType(inputType);
  for (auto *resultType : funcType->getResults())
    visitType(resultType);
}
/// Records affine maps held by an attribute, recursing into array elements.
void ModuleState::visitAttribute(const Attribute *attr) {
  if (auto *arrayAttr = dyn_cast<ArrayAttr>(attr)) {
    for (auto element : arrayAttr->getValue())
      visitAttribute(element);
    return;
  }
  if (auto *mapAttr = dyn_cast<AffineMapAttr>(attr))
    recordAffineMapReference(mapAttr->getValue());
}
/// Records affine maps used by an operation: those inside the types of its
/// operands and results, then those inside its attributes.
void ModuleState::visitOperation(const Operation *op) {
  for (const auto *operandValue : op->getOperands())
    visitType(operandValue->getType());
  for (const auto *resultValue : op->getResults())
    visitType(resultValue->getType());
  for (const auto &namedAttr : op->getAttrs())
    visitAttribute(namedAttr.second);
}
/// External functions have no body here, so only the signature is scanned.
void ModuleState::visitExtFunction(const ExtFunction *fn) {
  auto signature = fn->getType();
  visitType(signature);
}
/// Scans a CFG function's signature and every non-terminator operation in
/// each of its blocks.
void ModuleState::visitCFGFunction(const CFGFunction *fn) {
  visitType(fn->getType());
  for (const auto &bb : *fn)
    for (const auto &inst : bb.getOperations())
      visitOperation(&inst);
}
/// Records the if-statement's condition set, then visits both branches.
void ModuleState::visitIfStmt(const IfStmt *ifStmt) {
  recordIntegerSetReference(ifStmt->getIntegerSet());
  for (const auto &thenStmt : *ifStmt->getThen())
    visitStatement(&thenStmt);
  if (!ifStmt->hasElse())
    return;
  for (const auto &elseStmt : *ifStmt->getElse())
    visitStatement(&elseStmt);
}
void ModuleState::visitForStmt(const ForStmt *forStmt) {
AffineMap *lbMap = forStmt->getLowerBoundMap();
if (!hasShorthandForm(lbMap))
recordAffineMapReference(lbMap);
AffineMap *ubMap = forStmt->getUpperBoundMap();
if (!hasShorthandForm(ubMap))
recordAffineMapReference(ubMap);
for (auto &childStmt : *forStmt)
visitStatement(&childStmt);
}
void ModuleState::visitOperationStmt(const OperationStmt *opStmt) {
for (auto attr : opStmt->getAttrs())
visitAttribute(attr.second);
}
/// Dispatches a statement to the visitor matching its kind.
void ModuleState::visitStatement(const Statement *stmt) {
  const auto kind = stmt->getKind();
  if (kind == Statement::Kind::If) {
    visitIfStmt(cast<IfStmt>(stmt));
  } else if (kind == Statement::Kind::For) {
    visitForStmt(cast<ForStmt>(stmt));
  } else if (kind == Statement::Kind::Operation) {
    visitOperationStmt(cast<OperationStmt>(stmt));
  }
  // Any other statement kind is ignored.
}
void ModuleState::visitMLFunction(const MLFunction *fn) {
visitType(fn->getType());
for (auto &stmt : *fn) {
ModuleState::visitStatement(&stmt);
}
}
/// Dispatches a function to the visitor for its concrete kind.
void ModuleState::visitFunction(const Function *fn) {
  const auto kind = fn->getKind();
  if (kind == Function::Kind::ExtFunc)
    visitExtFunction(cast<ExtFunction>(fn));
  else if (kind == Function::Kind::CFGFunc)
    visitCFGFunction(cast<CFGFunction>(fn));
  else if (kind == Function::Kind::MLFunc)
    visitMLFunction(cast<MLFunction>(fn));
}
// Initializes module state, populating affine map and integer set state.
void ModuleState::initialize(const Module *module) {
for (auto &fn : *module) {
visitFunction(&fn);
}
}
//===----------------------------------------------------------------------===//
// ModulePrinter
//===----------------------------------------------------------------------===//
namespace {
/// Prints module-level constructs (types, attributes, affine maps, integer
/// sets and function headers) to `os`. It consults the shared ModuleState so
/// that maps and sets recorded there are printed as `#mapN` / `@@setN`
/// references rather than inline.
class ModulePrinter {
public:
ModulePrinter(raw_ostream &os, ModuleState &state) : os(os), state(state) {}
/// Copying shares the same output stream and module state; FunctionPrinter
/// uses this to build on an existing ModulePrinter.
explicit ModulePrinter(const ModulePrinter &printer)
: os(printer.os), state(printer.state) {}
/// Invokes `each_fn` on every element of `c`, printing ", " between elements.
template <typename Container, typename UnaryFunctor>
inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const {
interleave(c.begin(), c.end(), each_fn, [&]() { os << ", "; });
}
void print(const Module *module);
void printFunctionReference(const Function *func);
void printAttribute(const Attribute *attr);
void printType(const Type *type);
void print(const Function *fn);
void print(const ExtFunction *fn);
void print(const CFGFunction *fn);
void print(const MLFunction *fn);
void printAffineMap(AffineMap *map);
void printAffineExpr(AffineExprRef expr);
void printAffineConstraint(AffineExprRef expr, bool isEq);
void printIntegerSet(IntegerSet *set);
protected:
/// Destination stream (shared with copies of this printer).
raw_ostream &os;
/// Module-wide map/set numbering (shared with copies of this printer).
ModuleState &state;
void printFunctionSignature(const Function *fn);
void printFunctionAttributes(const Function *fn);
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<const char *> elidedAttrs = {});
void printFunctionResultType(const FunctionType *type);
void printAffineMapId(int affineMapId) const;
void printAffineMapReference(AffineMap *affineMap);
void printIntegerSetId(int integerSetId) const;
void printIntegerSetReference(IntegerSet *integerSet);
/// This enum is used to represent the binding strength of the enclosing
/// context that an AffineExpr is being printed in, so we can intelligently
/// produce parens.
enum class BindingStrength {
Weak, // + and -
Strong, // All other binary operators.
};
void printAffineExprInternal(AffineExprRef expr,
BindingStrength enclosingTightness);
};
} // end anonymous namespace
// Prints function with initialized module state.
// Forwards to the overload for the function's concrete kind.
void ModulePrinter::print(const Function *fn) {
  const auto kind = fn->getKind();
  if (kind == Function::Kind::ExtFunc)
    print(cast<ExtFunction>(fn));
  else if (kind == Function::Kind::CFGFunc)
    print(cast<CFGFunction>(fn));
  else if (kind == Function::Kind::MLFunc)
    print(cast<MLFunction>(fn));
}
// Prints affine map identifier, e.g. "#map3".
void ModulePrinter::printAffineMapId(int affineMapId) const {
  os << "#map";
  os << affineMapId;
}
/// Prints `affineMap` as a `#mapN` reference if the module state assigned it
/// an ID (its definition is emitted at the top of the module), otherwise
/// prints the map inline.
void ModulePrinter::printAffineMapReference(AffineMap *affineMap) {
  const int id = state.getAffineMapId(affineMap);
  if (id < 0) {
    // Map not in module state so print inline.
    affineMap->print(os);
    return;
  }
  printAffineMapId(id);
}
// Prints integer set identifier, e.g. "@@set2".
void ModulePrinter::printIntegerSetId(int integerSetId) const {
  os << "@@set";
  os << integerSetId;
}
/// Prints `integerSet` as a `@@setN` reference if the module state assigned
/// it an ID, otherwise prints the set inline.
void ModulePrinter::printIntegerSetReference(IntegerSet *integerSet) {
  const int setId = state.getIntegerSetId(integerSet);
  if (setId < 0) {
    // Set not in module state so print inline.
    integerSet->print(os);
    return;
  }
  // The set will be printed at top of module; so print reference to its id.
  printIntegerSetId(setId);
}
void ModulePrinter::print(const Module *module) {
for (const auto &map : state.getAffineMapIds()) {
printAffineMapId(state.getAffineMapId(map));
os << " = ";
map->print(os);
os << '\n';
}
for (const auto &set : state.getIntegerSetIds()) {
printIntegerSetId(state.getIntegerSetId(set));
os << " = ";
set->print(os);
os << '\n';
}
for (auto const &fn : *module)
print(&fn);
}
/// Print a floating point value in a way that the parser will be able to
/// round-trip losslessly.
///
/// Finite values whose short decimal form parses back to exactly `value` are
/// printed in that decimal form. Everything else, including infinities and
/// NaNs, is printed as "0x" followed by the 16 hex digits of the value's
/// 64-bit bit pattern, most significant nibble first.
static void printFloatValue(double value, raw_ostream &os) {
  APFloat apValue(value);
  // We would like to output the FP constant value in exponential notation,
  // but we cannot do this if doing so will lose precision. Check here to
  // make sure that we only output it in exponential format if we can parse
  // the value back and get the same value.
  bool isInf = apValue.isInfinity();
  bool isNaN = apValue.isNaN();
  if (!isInf && !isNaN) {
    SmallString<128> strValue;
    apValue.toString(strValue, 6, 0, false);
    // Check to make sure that the stringized number is not some string like
    // "Inf" or NaN, that atof will accept, but the lexer will not. Check
    // that the string matches the "[-+]?[0-9]" regex.
    assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
            ((strValue[0] == '-' || strValue[0] == '+') &&
             (strValue[1] >= '0' && strValue[1] <= '9'))) &&
           "[-+]?[0-9] regex does not match!");
    // Reparse stringized version!
    if (APFloat(APFloat::IEEEdouble(), strValue).convertToDouble() == value) {
      os << strValue;
      return;
    }
  }
  // Otherwise, print it in a hexadecimal form. Take the raw bit pattern from
  // APFloat instead of type-punning through a union. Reading a union member
  // other than the one last written is undefined behavior in C++.
  uint64_t bits = apValue.bitcastToAPInt().getZExtValue();
  os << "0x";
  // Print out 16 nibbles worth of hex digit, most significant first.
  for (unsigned i = 0; i != 16; ++i) {
    os << llvm::hexdigit(bits >> 60);
    bits <<= 4;
  }
}
/// Prints a symbolic reference to `func`, e.g. "@foo".
void ModulePrinter::printFunctionReference(const Function *func) {
  os << '@';
  os << func->getName();
}
/// Prints an attribute value in its textual form. Arrays recurse element by
/// element, and affine maps go through the module's `#mapN` numbering.
void ModulePrinter::printAttribute(const Attribute *attr) {
  switch (attr->getKind()) {
  case Attribute::Kind::Bool:
    os << (cast<BoolAttr>(attr)->getValue() ? "true" : "false");
    return;
  case Attribute::Kind::Integer:
    os << cast<IntegerAttr>(attr)->getValue();
    return;
  case Attribute::Kind::Float:
    printFloatValue(cast<FloatAttr>(attr)->getValue(), os);
    return;
  case Attribute::Kind::String:
    os << '"';
    printEscapedString(cast<StringAttr>(attr)->getValue(), os);
    os << '"';
    return;
  case Attribute::Kind::Array:
    os << '[';
    interleaveComma(cast<ArrayAttr>(attr)->getValue(),
                    [&](Attribute *element) { printAttribute(element); });
    os << ']';
    return;
  case Attribute::Kind::AffineMap:
    printAffineMapReference(cast<AffineMapAttr>(attr)->getValue());
    return;
  case Attribute::Kind::Type:
    printType(cast<TypeAttr>(attr)->getValue());
    return;
  case Attribute::Kind::Function: {
    const auto *function = cast<FunctionAttr>(attr)->getValue();
    // The attribute can outlive the function it names.
    if (function == nullptr) {
      os << "<<FUNCTION ATTR FOR DELETED FUNCTION>>";
      return;
    }
    printFunctionReference(function);
    os << " : ";
    printType(function->getType());
    return;
  }
  }
}
/// Print `type` in its textual assembly form.
///
/// Nested types are always printed recursively through this method: function
/// inputs and results, and vector/tensor/memref element types. The generic
/// `operator<<` does not have this printer's ModuleState. Before this change,
/// function results and vector/ranked-tensor element types used
/// `os << *T`. Affine maps inside them were therefore printed inline even
/// though visitType had already hoisted them as `#mapN` definitions.
void ModulePrinter::printType(const Type *type) {
  switch (type->getKind()) {
  case Type::Kind::Index:
    os << "index";
    return;
  case Type::Kind::BF16:
    os << "bf16";
    return;
  case Type::Kind::F16:
    os << "f16";
    return;
  case Type::Kind::F32:
    os << "f32";
    return;
  case Type::Kind::F64:
    os << "f64";
    return;
  case Type::Kind::TFControl:
    os << "tf_control";
    return;
  case Type::Kind::TFResource:
    os << "tf_resource";
    return;
  case Type::Kind::TFVariant:
    os << "tf_variant";
    return;
  case Type::Kind::TFComplex64:
    os << "tf_complex64";
    return;
  case Type::Kind::TFComplex128:
    os << "tf_complex128";
    return;
  case Type::Kind::TFF32REF:
    os << "tf_f32ref";
    return;
  case Type::Kind::TFString:
    os << "tf_string";
    return;
  case Type::Kind::Integer: {
    auto *integer = cast<IntegerType>(type);
    os << 'i' << integer->getWidth();
    return;
  }
  case Type::Kind::Function: {
    auto *func = cast<FunctionType>(type);
    os << '(';
    interleaveComma(func->getInputs(),
                    [&](Type *inputType) { printType(inputType); });
    os << ") -> ";
    auto results = func->getResults();
    if (results.size() == 1) {
      printType(results[0]);
    } else {
      os << '(';
      interleaveComma(results,
                      [&](Type *resultType) { printType(resultType); });
      os << ')';
    }
    return;
  }
  case Type::Kind::Vector: {
    auto *v = cast<VectorType>(type);
    os << "vector<";
    for (auto dim : v->getShape())
      os << dim << 'x';
    printType(v->getElementType());
    os << '>';
    return;
  }
  case Type::Kind::RankedTensor: {
    auto *v = cast<RankedTensorType>(type);
    os << "tensor<";
    for (auto dim : v->getShape()) {
      // Negative extents denote unknown dimensions.
      if (dim < 0)
        os << '?';
      else
        os << dim;
      os << 'x';
    }
    printType(v->getElementType());
    os << '>';
    return;
  }
  case Type::Kind::UnrankedTensor: {
    auto *v = cast<UnrankedTensorType>(type);
    os << "tensor<*x";
    printType(v->getElementType());
    os << '>';
    return;
  }
  case Type::Kind::MemRef: {
    auto *v = cast<MemRefType>(type);
    os << "memref<";
    for (auto dim : v->getShape()) {
      // Negative extents denote unknown dimensions.
      if (dim < 0)
        os << '?';
      else
        os << dim;
      os << 'x';
    }
    printType(v->getElementType());
    for (auto map : v->getAffineMaps()) {
      os << ", ";
      printAffineMapReference(map);
    }
    // Only print the memory space if it is the non-default one.
    if (v->getMemorySpace())
      os << ", " << v->getMemorySpace();
    os << '>';
    return;
  }
  }
}
//===----------------------------------------------------------------------===//
// Affine expressions and maps
//===----------------------------------------------------------------------===//
/// Prints an affine expression at top level. With no enclosing operator it
/// starts from Weak binding, so no outer parentheses are emitted.
void ModulePrinter::printAffineExpr(AffineExprRef expr) {
  const BindingStrength outermost = BindingStrength::Weak;
  printAffineExprInternal(expr, outermost);
}
/// Print `expr`, parenthesizing it when the enclosing context binds more
/// tightly than the expression's own top-level operator.
///
/// `enclosingTightness` is the binding strength of the operator this
/// expression is an operand of: Weak for + and - (and at top level), Strong
/// for *, floordiv, ceildiv and mod. An addition whose right operand is
/// `x * -1`, `x * -c` or a negative constant is printed as a subtraction.
void ModulePrinter::printAffineExprInternal(
    AffineExprRef expr, BindingStrength enclosingTightness) {
  const char *binopSpelling = nullptr;
  switch (expr->getKind()) {
  case AffineExpr::Kind::SymbolId:
    os << 's' << expr.cast<AffineSymbolExprRef>()->getPosition();
    return;
  case AffineExpr::Kind::DimId:
    os << 'd' << expr.cast<AffineDimExprRef>()->getPosition();
    return;
  case AffineExpr::Kind::Constant:
    os << expr.cast<AffineConstantExprRef>()->getValue();
    return;
  case AffineExpr::Kind::Add:
    binopSpelling = " + ";
    break;
  case AffineExpr::Kind::Mul:
    binopSpelling = " * ";
    break;
  case AffineExpr::Kind::FloorDiv:
    binopSpelling = " floordiv ";
    break;
  case AffineExpr::Kind::CeilDiv:
    binopSpelling = " ceildiv ";
    break;
  case AffineExpr::Kind::Mod:
    binopSpelling = " mod ";
    break;
  }

  auto binOp = expr.cast<AffineBinaryOpExprRef>();

  // Handle tightly binding binary operators. Both operands are printed with
  // Strong binding, so any nested binary expression is parenthesized.
  if (binOp->getKind() != AffineExpr::Kind::Add) {
    if (enclosingTightness == BindingStrength::Strong)
      os << '(';
    printAffineExprInternal(binOp->getLHS(), BindingStrength::Strong);
    os << binopSpelling;
    printAffineExprInternal(binOp->getRHS(), BindingStrength::Strong);
    if (enclosingTightness == BindingStrength::Strong)
      os << ')';
    return;
  }

  // Print out special "pretty" forms for add.
  if (enclosingTightness == BindingStrength::Strong)
    os << '(';

  // Pretty print addition to a product that has a negative operand as a
  // subtraction.
  AffineExprRef rhsExpr = binOp->getRHS();
  if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExprRef>()) {
    if (rhs->getKind() == AffineExpr::Kind::Mul) {
      AffineExprRef rrhsExpr = rhs->getRHS();
      if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExprRef>()) {
        if (rrhs->getValue() == -1) {
          // lhs + x * -1  ==>  lhs - x. A sum x must keep its parentheses:
          // a + (b + c) * -1 is a - (b + c), not a - b + c. Other kinds of x
          // bind tighter than '-' and are printed without extra parens.
          AffineExprRef subtrahend = rhs->getLHS();
          BindingStrength subtrahendStrength =
              subtrahend->getKind() == AffineExpr::Kind::Add
                  ? BindingStrength::Strong
                  : BindingStrength::Weak;
          printAffineExprInternal(binOp->getLHS(), BindingStrength::Weak);
          os << " - ";
          printAffineExprInternal(subtrahend, subtrahendStrength);
          if (enclosingTightness == BindingStrength::Strong)
            os << ')';
          return;
        }
        if (rrhs->getValue() < -1) {
          // lhs + x * -c  ==>  lhs - x * c. x is an operand of the product,
          // so it is printed with Strong binding.
          printAffineExprInternal(binOp->getLHS(), BindingStrength::Weak);
          os << " - ";
          printAffineExprInternal(rhs->getLHS(), BindingStrength::Strong);
          os << " * " << -rrhs->getValue();
          if (enclosingTightness == BindingStrength::Strong)
            os << ')';
          return;
        }
      }
    }
  }

  // Pretty print addition to a negative number as a subtraction.
  if (auto rhs = rhsExpr.dyn_cast<AffineConstantExprRef>()) {
    if (rhs->getValue() < 0) {
      printAffineExprInternal(binOp->getLHS(), BindingStrength::Weak);
      os << " - " << -rhs->getValue();
      if (enclosingTightness == BindingStrength::Strong)
        os << ')';
      return;
    }
  }

  printAffineExprInternal(binOp->getLHS(), BindingStrength::Weak);
  os << " + ";
  printAffineExprInternal(binOp->getRHS(), BindingStrength::Weak);
  if (enclosingTightness == BindingStrength::Strong)
    os << ')';
}
/// Prints one integer-set constraint as `expr == 0` (equality) or
/// `expr >= 0` (inequality).
void ModulePrinter::printAffineConstraint(AffineExprRef expr, bool isEq) {
  printAffineExprInternal(expr, BindingStrength::Weak);
  os << (isEq ? " == 0" : " >= 0");
}
/// Prints an affine map inline. The form is `(d0, ...)[s0, ...] -> (results)`,
/// followed by ` size (...)` for bounded maps. The symbol list is omitted
/// when there are no symbols.
void ModulePrinter::printAffineMap(AffineMap *map) {
  // Dimension identifiers; always present, possibly as "()".
  os << '(';
  for (unsigned d = 0, numDims = map->getNumDims(); d != numDims; ++d) {
    if (d != 0)
      os << ", ";
    os << 'd' << d;
  }
  os << ')';

  // Symbolic identifiers.
  if (unsigned numSymbols = map->getNumSymbols()) {
    os << '[';
    for (unsigned s = 0; s != numSymbols; ++s) {
      if (s != 0)
        os << ", ";
      os << 's' << s;
    }
    os << ']';
  }

  // AffineMap should have at least one result.
  assert(!map->getResults().empty());
  os << " -> (";
  interleaveComma(map->getResults(),
                  [&](AffineExprRef result) { printAffineExpr(result); });
  os << ')';

  // Bounded maps additionally carry their range sizes.
  if (map->isBounded()) {
    os << " size (";
    interleaveComma(map->getRangeSizes(),
                    [&](AffineExprRef size) { printAffineExpr(size); });
    os << ')';
  }
}
/// Prints an integer set inline as `(d0, ...)[s0, ...] : (c0, c1, ...)`.
/// Each constraint ci is `expr == 0` or `expr >= 0`. The symbol list is
/// omitted when there are no symbols.
///
/// The constraint loop previously compared a signed `int` index against the
/// unsigned constraint count (-Wsign-compare). All three lists now use the
/// same unsigned "separator before every element but the first" pattern.
void ModulePrinter::printIntegerSet(IntegerSet *set) {
  // Dimension identifiers.
  os << '(';
  for (unsigned d = 0, numDims = set->getNumDims(); d != numDims; ++d) {
    if (d != 0)
      os << ", ";
    os << 'd' << d;
  }
  os << ')';

  // Symbolic identifiers.
  if (unsigned numSymbols = set->getNumSymbols()) {
    os << '[';
    for (unsigned s = 0; s != numSymbols; ++s) {
      if (s != 0)
        os << ", ";
      os << 's' << s;
    }
    os << ']';
  }

  // Print constraints.
  os << " : (";
  for (unsigned c = 0, numConstraints = set->getNumConstraints();
       c != numConstraints; ++c) {
    if (c != 0)
      os << ", ";
    printAffineConstraint(set->getConstraint(c), set->isEq(c));
  }
  os << ')';
}
//===----------------------------------------------------------------------===//
// Function printing
//===----------------------------------------------------------------------===//
/// Prints the ` -> ...` result clause of a function signature. Nothing is
/// printed for zero results. A single result is printed bare; multiple
/// results are wrapped in parentheses.
void ModulePrinter::printFunctionResultType(const FunctionType *type) {
  auto results = type->getResults();
  if (results.empty())
    return;
  os << " -> ";
  if (results.size() == 1) {
    printType(results[0]);
    return;
  }
  os << '(';
  interleaveComma(results, [&](Type *resultType) { printType(resultType); });
  os << ')';
}
/// Prints a function's attribute dictionary on a new line, prefixed with
/// "attributes". Nothing is printed when the function has no attributes.
void ModulePrinter::printFunctionAttributes(const Function *fn) {
  auto fnAttrs = fn->getAttrs();
  if (!fnAttrs.empty()) {
    os << "\n attributes ";
    printOptionalAttrDict(fnAttrs);
  }
}
/// Prints `@name(input types)` followed by the result clause.
void ModulePrinter::printFunctionSignature(const Function *fn) {
  auto fnType = fn->getType();
  os << "@" << fn->getName() << '(';
  interleaveComma(fnType->getInputs(),
                  [&](Type *inputType) { printType(inputType); });
  os << ')';
  printFunctionResultType(fnType);
}
/// Prints ` {name: value, ...}` for the attributes in `attrs`.
///
/// Two kinds of attribute are skipped: internal ones whose name starts with
/// ':', and those whose name appears in `elidedAttrs`. Nothing is printed if
/// no attribute remains.
///
/// Previously an internal ':' attribute caused a `return`, which silently
/// dropped the whole dictionary, including every ordinary attribute. Now
/// only that one attribute is skipped.
void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                                          ArrayRef<const char *> elidedAttrs) {
  // If there are no attributes, then there is nothing to be done.
  if (attrs.empty())
    return;

  // Filter out any attributes that shouldn't be included.
  SmallVector<NamedAttribute, 8> filteredAttrs;
  for (auto attr : attrs) {
    auto attrName = attr.first.strref();
    // Never print attributes that start with a colon. These are internal
    // attributes that represent location or other internal metadata.
    if (attrName.startswith(":"))
      continue;

    // If the caller has requested that this attribute be ignored, then drop it.
    bool ignore = false;
    for (const char *elide : elidedAttrs) {
      if (attrName == StringRef(elide)) {
        ignore = true;
        break;
      }
    }

    // Otherwise add it to our filteredAttrs list.
    if (!ignore)
      filteredAttrs.push_back(attr);
  }

  // If there are no attributes left to print after filtering, then we're done.
  if (filteredAttrs.empty())
    return;

  // Otherwise, print them all out in braces.
  os << " {";
  interleaveComma(filteredAttrs, [&](NamedAttribute attr) {
    os << attr.first << ": ";
    printAttribute(attr.second);
  });
  os << '}';
}
/// Prints an external function declaration: the `extfunc` keyword, the
/// signature and any attributes, terminated by a newline.
void ModulePrinter::print(const ExtFunction *fn) {
  os << "extfunc ";
  printFunctionSignature(fn);
  printFunctionAttributes(fn);
  os << '\n';
}
namespace {
// FunctionPrinter contains common functionality for printing
// CFG and ML functions.
//
// It assigns every SSA value in the function a printable identifier. This is
// either a sequential number (%0, %1, ...) or a readable special name such
// as %c42_i32, %true/%false, %cst, %f, %arg0 or %i0. It also implements the
// OpAsmPrinter hooks used by custom operation printers, forwarding most of
// them to ModulePrinter.
class FunctionPrinter : public ModulePrinter, private OpAsmPrinter {
public:
FunctionPrinter(const ModulePrinter &other) : ModulePrinter(other) {}
void printOperation(const Operation *op);
void printDefaultOp(const Operation *op);
// Implement OpAsmPrinter.
raw_ostream &getStream() const { return os; }
void printType(const Type *type) { ModulePrinter::printType(type); }
void printAttribute(const Attribute *attr) {
ModulePrinter::printAttribute(attr);
}
// Maps and sets are printed by reference (#mapN / @@setN) when the module
// state knows them, and inline otherwise.
void printAffineMap(AffineMap *map) {
return ModulePrinter::printAffineMapReference(map);
}
void printIntegerSet(IntegerSet *set) {
return ModulePrinter::printIntegerSetReference(set);
}
void printAffineExpr(AffineExprRef expr) {
return ModulePrinter::printAffineExpr(expr);
}
void printFunctionReference(const Function *func) {
return ModulePrinter::printFunctionReference(func);
}
void printFunctionAttributes(const Function *func) {
return ModulePrinter::printFunctionAttributes(func);
}
void printOperand(const SSAValue *value) { printValueID(value); }
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<const char *> elidedAttrs = {}) {
return ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs);
};
// Stored in valueIDs to mean "this value's name is in valueNames".
enum { nameSentinel = ~0U };
protected:
// Assign a printable identifier to `value`; each value may be numbered only
// once. Special names come from:
//  - integer constants: "true"/"false" for i1, otherwise c<value>_<type>,
//  - index constants: c<value>,
//  - other constants: "f" for function attributes, otherwise "cst",
//  - arguments of a function's entry block and ML function arguments: arg<N>,
//  - for-statement induction variables: i<N>.
// Every other value gets the next sequential number. A special name that was
// already used gets a "_<N>" suffix drawn from nextConflictID.
void numberValueID(const SSAValue *value) {
assert(!valueIDs.count(value) && "Value numbered multiple times");
SmallString<32> specialNameBuffer;
llvm::raw_svector_ostream specialName(specialNameBuffer);
// Give constant integers special names.
if (auto *op = value->getDefiningOperation()) {
if (auto intOp = op->getAs<ConstantIntOp>()) {
// i1 constants get special names.
if (intOp->getType()->isInteger(1)) {
specialName << (intOp->getValue() ? "true" : "false");
} else {
specialName << 'c' << intOp->getValue() << '_' << *intOp->getType();
}
} else if (auto intOp = op->getAs<ConstantIndexOp>()) {
specialName << 'c' << intOp->getValue();
} else if (auto constant = op->getAs<ConstantOp>()) {
if (isa<FunctionAttr>(constant->getValue()))
specialName << 'f';
else
specialName << "cst";
}
}
if (specialNameBuffer.empty()) {
switch (value->getKind()) {
case SSAValueKind::BBArgument:
// If this is an argument to the function, give it an 'arg' name.
if (auto *bb = cast<BBArgument>(value)->getOwner())
if (auto *fn = bb->getFunction())
if (&fn->front() == bb) {
specialName << "arg" << nextArgumentID++;
break;
}
// Otherwise number it normally.
LLVM_FALLTHROUGH;
case SSAValueKind::InstResult:
case SSAValueKind::StmtResult:
// This is an uninteresting result, give it a boring number and be
// done with it.
valueIDs[value] = nextValueID++;
return;
case SSAValueKind::MLFuncArgument:
specialName << "arg" << nextArgumentID++;
break;
case SSAValueKind::ForStmt:
specialName << 'i' << nextLoopID++;
break;
}
}
// Ok, this value had an interesting name. Remember it with a sentinel.
valueIDs[value] = nameSentinel;
// Remember that we've used this name, checking to see if we had a conflict.
auto insertRes = usedNames.insert(specialName.str());
if (insertRes.second) {
// If this is the first use of the name, then we're successful!
// The StringRef points into usedNames' storage, not the local buffer.
valueNames[value] = insertRes.first->first();
return;
}
// Otherwise, we had a conflict - probe until we find a unique name. This
// is guaranteed to terminate (and usually in a single iteration) because it
// generates new names by incrementing nextConflictID.
while (1) {
std::string probeName =
specialName.str().str() + "_" + llvm::utostr(nextConflictID++);
insertRes = usedNames.insert(probeName);
if (insertRes.second) {
// If this is the first use of the name, then we're successful!
valueNames[value] = insertRes.first->first();
return;
}
}
}
// Print the identifier assigned to `value` as "%<number>" or "%<name>".
// An instruction or statement result whose owner does not have exactly one
// result is looked up through result 0 (the only one numbered). It gets a
// "#<resultNo>" suffix unless `printResultNo` is false. Values that were
// never numbered print as "<<INVALID SSA VALUE>>".
void printValueID(const SSAValue *value, bool printResultNo = true) const {
int resultNo = -1;
auto lookupValue = value;
// If this is a reference to the result of a multi-result instruction or
// statement, print out the # identifier and make sure to map our lookup
// to the first result of the instruction.
if (auto *result = dyn_cast<InstResult>(value)) {
if (result->getOwner()->getNumResults() != 1) {
resultNo = result->getResultNumber();
lookupValue = result->getOwner()->getResult(0);
}
} else if (auto *result = dyn_cast<StmtResult>(value)) {
if (result->getOwner()->getNumResults() != 1) {
resultNo = result->getResultNumber();
lookupValue = result->getOwner()->getResult(0);
}
}
auto it = valueIDs.find(lookupValue);
if (it == valueIDs.end()) {
os << "<<INVALID SSA VALUE>>";
return;
}
os << '%';
if (it->second != nameSentinel) {
os << it->second;
} else {
auto nameIt = valueNames.find(lookupValue);
assert(nameIt != valueNames.end() && "Didn't have a name entry?");
os << nameIt->second;
}
if (resultNo != -1 && printResultNo)
os << '#' << resultNo;
}
private:
/// This is the value ID for each SSA value in the current function. If this
/// returns ~0, then the valueID has an entry in valueNames.
DenseMap<const SSAValue *, unsigned> valueIDs;
/// Special names, pointing into the strings owned by usedNames.
DenseMap<const SSAValue *, StringRef> valueNames;
/// This keeps track of all of the non-numeric names that are in flight,
/// allowing us to check for duplicates.
llvm::StringSet<> usedNames;
/// This is the next value ID to assign in numbering.
unsigned nextValueID = 0;
/// This is the ID to assign to the next induction variable.
unsigned nextLoopID = 0;
/// This is the next ID to assign to an MLFunction argument.
unsigned nextArgumentID = 0;
/// This is the next ID to assign when a name conflict is detected.
unsigned nextConflictID = 0;
};
} // end anonymous namespace
/// Print an operation, prefixed by "%result = " if it has results. A
/// registered custom assembly printer is used when one is available;
/// otherwise the generic form from printDefaultOp is used.
///
/// ModuleState::operationSet is null when no context could be determined. It
/// was previously dereferenced unconditionally; now the printer falls back to
/// the generic form in that case.
void FunctionPrinter::printOperation(const Operation *op) {
  if (op->getNumResults()) {
    printValueID(op->getResult(0), /*printResultNo=*/false);
    os << " = ";
  }

  // Check to see if this is a known operation. If so, use the registered
  // custom printer hook.
  if (auto *opSet = state.operationSet) {
    if (auto *opInfo = opSet->lookup(op->getName())) {
      opInfo->printAssembly(op, this);
      return;
    }
  }

  // Otherwise use the standard verbose printing approach.
  printDefaultOp(op);
}
/// Print the generic, always-parseable form of an operation:
///   "name"(%operands) {attrs} : (operand types) -> result type(s)
/// A single result type is printed bare; zero or several are parenthesized.
void FunctionPrinter::printDefaultOp(const Operation *op) {
  os << '"';
  printEscapedString(op->getName(), os);
  os << "\"(";
  interleaveComma(op->getOperands(),
                  [&](const SSAValue *operand) { printValueID(operand); });
  os << ')';
  printOptionalAttrDict(op->getAttrs());

  // Print the type signature of the operation.
  os << " : (";
  interleaveComma(op->getOperands(), [&](const SSAValue *operand) {
    printType(operand->getType());
  });
  os << ") -> ";
  if (op->getNumResults() != 1) {
    os << '(';
    interleaveComma(op->getResults(), [&](const SSAValue *resultValue) {
      printType(resultValue->getType());
    });
    os << ')';
    return;
  }
  printType(op->getResult(0)->getType());
}
//===----------------------------------------------------------------------===//
// CFG Function printing
//===----------------------------------------------------------------------===//
namespace {
// Prints a CFGFunction: its signature, then each basic block with its
// arguments, instructions, terminator and a predecessor summary comment.
// Blocks are referred to as bb<N>, where N is the block's position in the
// function.
class CFGFunctionPrinter : public FunctionPrinter {
public:
CFGFunctionPrinter(const CFGFunction *function, const ModulePrinter &other);
const CFGFunction *getFunction() const { return function; }
void print();
void print(const BasicBlock *block);
void print(const Instruction *inst);
void print(const OperationInst *inst);
void print(const ReturnInst *inst);
void print(const BranchInst *inst);
void print(const CondBranchInst *inst);
// Print the label "bb<N>" for `block`.
void printBBName(const BasicBlock *block) { os << "bb" << getBBID(block); }
// Return the number assigned to `block` by the constructor; asserts that the
// block belongs to this function.
unsigned getBBID(const BasicBlock *block) {
auto it = basicBlockIDs.find(block);
assert(it != basicBlockIDs.end() && "Block not in this function?");
return it->second;
}
private:
const CFGFunction *function;
// Block -> position in the function, filled in by the constructor.
DenseMap<const BasicBlock *, unsigned> basicBlockIDs;
void numberValuesInBlock(const BasicBlock *block);
};
} // end anonymous namespace
/// Numbers every block of `function` in order and, while doing so, numbers
/// the SSA values each block defines. Value IDs therefore follow the
/// function's block order.
CFGFunctionPrinter::CFGFunctionPrinter(const CFGFunction *function,
                                       const ModulePrinter &other)
    : FunctionPrinter(other), function(function) {
  unsigned nextBlockID = 0;
  for (const auto &block : *function) {
    basicBlockIDs.insert({&block, nextBlockID});
    ++nextBlockID;
    numberValuesInBlock(&block);
  }
}
/// Number all of the SSA values in the specified basic block.
/// Block arguments come first, then the first result of every instruction
/// that produces results. Later results are printed relative to result 0.
void CFGFunctionPrinter::numberValuesInBlock(const BasicBlock *block) {
  for (const auto *blockArg : block->getArguments())
    numberValueID(blockArg);
  for (const auto &inst : *block) {
    if (inst.getNumResults() == 0)
      continue;
    numberValueID(inst.getResult(0));
  }
  // Terminators do not define values.
}
void CFGFunctionPrinter::print() {
os << "cfgfunc ";
printFunctionSignature(getFunction());
printFunctionAttributes(getFunction());
os << " {\n";
for (auto &block : *function)
print(&block);
os << "}\n\n";
}
/// Print one basic block. The output has three parts:
///  - the label, with typed arguments if the block has any;
///  - a trailing comment describing its predecessors;
///  - each operation on its own indented line, then the terminator.
void CFGFunctionPrinter::print(const BasicBlock *block) {
  // Label, e.g. "bb1(%0: i32):".
  printBBName(block);
  if (!block->args_empty()) {
    os << '(';
    interleaveComma(block->getArguments(), [&](const BBArgument *blockArg) {
      printValueID(blockArg);
      os << ": ";
      printType(blockArg->getType());
    });
    os << ')';
  }
  os << ':';

  // Predecessor summary comment.
  auto *parentFn = block->getFunction();
  if (!parentFn) {
    os << "\t// block is not in a function!";
  } else if (block->hasNoPredecessors()) {
    // The entry block is expected to have no predecessors; say nothing there.
    if (block != &parentFn->front())
      os << "\t// no predecessors";
  } else if (auto *onlyPred = block->getSinglePredecessor()) {
    os << "\t// pred: ";
    printBBName(onlyPred);
  } else {
    // Use-list order is arbitrary; sort the IDs so the comment is stable.
    SmallVector<unsigned, 4> predIDs;
    for (auto *predBlock : block->getPredecessors())
      predIDs.push_back(getBBID(predBlock));
    llvm::array_pod_sort(predIDs.begin(), predIDs.end());
    os << "\t// " << predIDs.size() << " preds: ";
    interleaveComma(predIDs, [&](unsigned id) { os << "bb" << id; });
  }
  os << '\n';

  // Body: the operations, then the terminator.
  for (auto &op : block->getOperations()) {
    os << " ";
    print(&op);
    os << '\n';
  }
  os << " ";
  print(block->getTerminator());
  os << '\n';
}
/// Print any CFG instruction by dispatching on its kind. A null instruction
/// (for instance a block terminator that has not been set, presumably) prints
/// a placeholder instead of crashing.
void CFGFunctionPrinter::print(const Instruction *inst) {
  if (inst == nullptr) {
    os << "<<null instruction>>\n";
    return;
  }

  switch (inst->getKind()) {
  case Instruction::Kind::Operation:
    print(cast<OperationInst>(inst));
    return;
  case TerminatorInst::Kind::Branch:
    print(cast<BranchInst>(inst));
    return;
  case TerminatorInst::Kind::CondBranch:
    print(cast<CondBranchInst>(inst));
    return;
  case TerminatorInst::Kind::Return:
    print(cast<ReturnInst>(inst));
    return;
  }
}
/// Operation instructions are printed with the generic printOperation, the
/// same helper the ML function printer uses for operation statements.
void CFGFunctionPrinter::print(const OperationInst *inst) { printOperation(inst); }
/// Print an unconditional branch, e.g. "br bb2" or "br bb2(%0, %1 : i32, f32)".
///
/// The destination operands and their types share one pair of parentheses,
/// "(ops : types)". This matches the branch-use-list grammar and the format
/// the cond_br printer uses for each destination. The previous code printed
/// "(ops) : types", which does not match that grammar.
void CFGFunctionPrinter::print(const BranchInst *inst) {
  os << "br ";
  printBBName(inst->getDest());

  if (inst->getNumOperands() != 0) {
    os << '(';
    interleaveComma(inst->getOperands(),
                    [&](const CFGValue *operand) { printValueID(operand); });
    os << " : ";
    interleaveComma(inst->getOperands(), [&](const CFGValue *operand) {
      printType(operand->getType());
    });
    os << ')';
  }
}
/// Print a conditional branch:
///   cond_br %cond, bbT(%a : T), bbF(%b : U)
/// A destination's "(operands : types)" list is printed only when that
/// destination has operands.
void CFGFunctionPrinter::print(const CondBranchInst *inst) {
  os << "cond_br ";
  printValueID(inst->getCondition());

  // True destination and its operands.
  os << ", ";
  printBBName(inst->getTrueDest());
  if (inst->getNumTrueOperands() != 0) {
    os << '(';
    interleaveComma(inst->getTrueOperands(),
                    [&](const CFGValue *value) { printValueID(value); });
    os << " : ";
    interleaveComma(inst->getTrueOperands(),
                    [&](const CFGValue *value) { printType(value->getType()); });
    os << ")";
  }

  // False destination and its operands.
  os << ", ";
  printBBName(inst->getFalseDest());
  if (inst->getNumFalseOperands() != 0) {
    os << '(';
    interleaveComma(inst->getFalseOperands(),
                    [&](const CFGValue *value) { printValueID(value); });
    os << " : ";
    interleaveComma(inst->getFalseOperands(),
                    [&](const CFGValue *value) { printType(value->getType()); });
    os << ")";
  }
}
/// Print a return, e.g. "return" or "return %0, %1 : i32, f32".
void CFGFunctionPrinter::print(const ReturnInst *inst) {
  os << "return";
  if (inst->getNumOperands() != 0) {
    os << ' ';
    interleaveComma(inst->getOperands(),
                    [&](const CFGValue *value) { printValueID(value); });
    os << " : ";
    interleaveComma(inst->getOperands(),
                    [&](const CFGValue *value) { printType(value->getType()); });
  }
}
void ModulePrinter::print(const CFGFunction *fn) {
CFGFunctionPrinter(fn, *this).print();
}
//===----------------------------------------------------------------------===//
// ML Function printing
//===----------------------------------------------------------------------===//
namespace {
/// Printer for ML functions, whose bodies are nested statements (operation,
/// for, and if statements) rather than basic blocks. Builds on
/// FunctionPrinter and tracks the current indentation depth while printing
/// nested statement blocks.
class MLFunctionPrinter : public FunctionPrinter {
public:
MLFunctionPrinter(const MLFunction *function, const ModulePrinter &other);
const MLFunction *getFunction() const { return function; }
// Prints ML function.
void print();
// Prints ML function signature.
void printFunctionSignature();
// Methods to print ML function statements.
void print(const Statement *stmt);
void print(const OperationStmt *stmt);
void print(const ForStmt *stmt);
void print(const IfStmt *stmt);
void print(const StmtBlock *block);
// Print loop bounds.
void printDimAndSymbolList(ArrayRef<StmtOperand> ops, unsigned numDims);
void printBound(AffineBound bound, const char *prefix);
// Number of spaces used for indenting nested statements.
const static unsigned indentWidth = 2;
private:
// Assigns SSA value numbers to the function arguments, for statements and
// operation results; invoked once from the constructor.
void numberValues();
// The function being printed; the constructor asserts it is non-null.
const MLFunction *function;
// Current indentation in spaces. print(const StmtBlock *) raises it by
// indentWidth while printing a nested block and restores it afterwards.
int numSpaces;
};
} // end anonymous namespace
/// Build a printer for `function` (must be non-null), starting at zero
/// indentation, and number all of its SSA values up front.
MLFunctionPrinter::MLFunctionPrinter(const MLFunction *function,
                                     const ModulePrinter &other)
    : FunctionPrinter(other), function(function), numSpaces(0) {
  assert(function != nullptr && "Cannot print nullptr function");
  numberValues();
}
/// Number all of the SSA values in this ML function.
void MLFunctionPrinter::numberValues() {
// Numbers ML function arguments.
for (auto *arg : function->getArguments())
numberValueID(arg);
// Walks ML function statements and numbers for statements and
// the first result of the operation statements.
struct NumberValuesPass : public StmtWalker<NumberValuesPass> {
NumberValuesPass(MLFunctionPrinter *printer) : printer(printer) {}
void visitOperationStmt(OperationStmt *stmt) {
if (stmt->getNumResults() != 0)
printer->numberValueID(stmt->getResult(0));
}
void visitForStmt(ForStmt *stmt) { printer->numberValueID(stmt); }
MLFunctionPrinter *printer;
};
NumberValuesPass pass(this);
// TODO: it'd be cleaner to have constant visitor instead of using const_cast.
pass.walk(const_cast<MLFunction *>(function));
}
void MLFunctionPrinter::print() {
os << "mlfunc ";
printFunctionSignature();
printFunctionAttributes(getFunction());
os << " {\n";
print(function);
os << "}\n\n";
}
/// Print "@name(%a : T0, %b : T1, ...)" followed by whatever
/// printFunctionResultType emits for the function type.
void MLFunctionPrinter::printFunctionSignature() {
  os << "@" << function->getName() << '(';

  const unsigned numArgs = function->getNumArguments();
  for (unsigned idx = 0; idx != numArgs; ++idx) {
    auto *arg = function->getArgument(idx);
    if (idx != 0)
      os << ", ";
    printOperand(arg);
    os << " : ";
    printType(arg->getType());
  }

  os << ")";
  printFunctionResultType(function->getType());
}
/// Print each statement of `block` on its own line, one indentation level
/// deeper than the enclosing construct. The indentation is restored on exit.
void MLFunctionPrinter::print(const StmtBlock *block) {
  numSpaces += indentWidth;
  for (const auto &child : block->getStatements()) {
    print(&child);
    os << "\n";
  }
  numSpaces -= indentWidth;
}
/// Print any ML statement by dispatching on its kind.
void MLFunctionPrinter::print(const Statement *stmt) {
  switch (stmt->getKind()) {
  case Statement::Kind::Operation:
    print(cast<OperationStmt>(stmt));
    break;
  case Statement::Kind::For:
    print(cast<ForStmt>(stmt));
    break;
  case Statement::Kind::If:
    print(cast<IfStmt>(stmt));
    break;
  }
}
/// Print an operation statement at the current indentation using the generic
/// operation printer.
void MLFunctionPrinter::print(const OperationStmt *stmt) {
  os.indent(numSpaces);
  printOperation(stmt);
}
/// Print a for statement:
///   for %i = <lower> to <upper> [step N] { <body> }
/// A bound map without exactly one result is prefixed with "max" (lower
/// bound) or "min" (upper bound). The step clause is omitted when the step
/// is 1.
void MLFunctionPrinter::print(const ForStmt *stmt) {
  os.indent(numSpaces);
  os << "for ";
  printOperand(stmt);

  os << " = ";
  printBound(stmt->getLowerBound(), "max");
  os << " to ";
  printBound(stmt->getUpperBound(), "min");

  auto step = stmt->getStep();
  if (step != 1)
    os << " step " << step;

  os << " {\n";
  // The for statement is itself the StmtBlock that holds its body.
  print(static_cast<const StmtBlock *>(stmt));
  os.indent(numSpaces);
  os << "}";
}
/// Print affine map/set operands. The first `numDims` operands are dimension
/// operands, printed as "(d0, d1)" (parentheses are always printed). Any
/// remaining operands are symbols, printed as "[s0, s1]".
void MLFunctionPrinter::printDimAndSymbolList(ArrayRef<StmtOperand> ops,
                                              unsigned numDims) {
  os << '(';
  for (unsigned idx = 0; idx < numDims; ++idx) {
    if (idx != 0)
      os << ", ";
    printOperand(ops[idx].get());
  }
  os << ')';

  if (numDims >= ops.size())
    return;

  os << '[';
  for (unsigned idx = numDims, end = ops.size(); idx < end; ++idx) {
    if (idx != numDims)
      os << ", ";
    printOperand(ops[idx].get());
  }
  os << ']';
}
/// Print one loop bound. `prefix` ("max" or "min") is printed before the map
/// when the map does not have exactly one result.
///
/// Short-hand forms are limited to trivial cases so that binary -> text ->
/// binary round-trips are lossless:
///  - a zero-dim, zero-symbol map whose single result is a constant prints
///    as that constant;
///  - a zero-dim, one-symbol map whose single result is a symbol prints as
///    that symbol's SSA operand.
/// Every other bound prints as an affine map reference plus its operand list.
void MLFunctionPrinter::printBound(AffineBound bound, const char *prefix) {
  AffineMap *map = bound.getMap();

  if (map->getNumResults() != 1) {
    // Not a single-result map: emit the 'min'/'max' prefix before the map.
    os << prefix << ' ';
  } else if (map->getNumDims() == 0) {
    AffineExprRef result = map->getResult(0);
    auto numSymbols = map->getNumSymbols();
    if (numSymbols == 0) {
      if (auto constant = result.dyn_cast<AffineConstantExprRef>()) {
        os << constant->getValue();
        return;
      }
    } else if (numSymbols == 1 && result.dyn_cast<AffineSymbolExprRef>()) {
      printOperand(bound.getOperand(0));
      return;
    }
  }

  // General form: the map followed by its dimension and symbol operands.
  printAffineMapReference(map);
  printDimAndSymbolList(bound.getStmtOperands(), map->getNumDims());
}
void MLFunctionPrinter::print(const IfStmt *stmt) {
os.indent(numSpaces) << "if ";
IntegerSet *set = stmt->getIntegerSet();
printIntegerSetReference(set);
printDimAndSymbolList(stmt->getStmtOperands(), set->getNumDims());
os << " {\n";
print(stmt->getThen());
os.indent(numSpaces) << "}";
if (stmt->hasElse()) {
os << " else {\n";
print(stmt->getElse());
os.indent(numSpaces) << "}";
}
}
void ModulePrinter::print(const MLFunction *fn) {
MLFunctionPrinter(fn, *this).print();
}
//===----------------------------------------------------------------------===//
// print and dump methods
//===----------------------------------------------------------------------===//
/// Print this attribute to `os`. The printer state is built with a null
/// context, since no context is known here.
void Attribute::print(raw_ostream &os) const {
  ModuleState state(/*no context is known*/ nullptr);
  ModulePrinter printer(os, state);
  printer.printAttribute(this);
}
/// Print this attribute to stderr.
void Attribute::dump() const { print(llvm::errs()); }
/// Print this type to `os`, using printer state built from the type's own
/// context.
void Type::print(raw_ostream &os) const {
  ModuleState state(getContext());
  ModulePrinter printer(os, state);
  printer.printType(this);
}
/// Print this type to stderr.
void Type::dump() const { print(llvm::errs()); }
void AffineMap::dump() {
print(llvm::errs());
llvm::errs() << "\n";
}
void AffineExpr::dump() {
print(llvm::errs());
llvm::errs() << "\n";
}
void IntegerSet::dump() {
print(llvm::errs());
llvm::errs() << "\n";
}
/// Print this affine expression to `os` (the printer state has a null context).
void AffineExpr::print(raw_ostream &os) {
  ModuleState state(/*no context is known*/ nullptr);
  ModulePrinter printer(os, state);
  printer.printAffineExpr(this);
}
/// Print this affine map to `os` (the printer state has a null context).
void AffineMap::print(raw_ostream &os) {
  ModuleState state(/*no context is known*/ nullptr);
  ModulePrinter printer(os, state);
  printer.printAffineMap(this);
}
/// Print this integer set to `os` (the printer state has a null context).
void IntegerSet::print(raw_ostream &os) {
  ModuleState state(/*no context is known*/ nullptr);
  ModulePrinter printer(os, state);
  printer.printIntegerSet(this);
}
/// Print this SSA value. Results and for statements print their whole
/// defining instruction or statement. Block and function arguments currently
/// print only a placeholder.
void SSAValue::print(raw_ostream &os) const {
  switch (getKind()) {
  case SSAValueKind::InstResult:
    getDefiningInst()->print(os);
    break;
  case SSAValueKind::StmtResult:
    getDefiningStmt()->print(os);
    break;
  case SSAValueKind::ForStmt:
    cast<ForStmt>(this)->print(os);
    break;
  case SSAValueKind::BBArgument:
    // TODO: Improve this.
    os << "<bb argument>\n";
    break;
  case SSAValueKind::MLFuncArgument:
    // TODO: Improve this.
    os << "<function argument>\n";
    break;
  }
}
/// Print this SSA value to stderr.
void SSAValue::dump() const { print(llvm::errs()); }
void Instruction::print(raw_ostream &os) const {
if (!getFunction()) {
os << "<<UNLINKED INSTRUCTION>>\n";
return;
}
ModuleState state(getFunction()->getContext());
ModulePrinter modulePrinter(os, state);
CFGFunctionPrinter(getFunction(), modulePrinter).print(this);
}
void Instruction::dump() const {
print(llvm::errs());
llvm::errs() << "\n";
}
/// Print this basic block (label, predecessor comment and body) using a
/// printer for its parent CFG function. Blocks that are not in a function
/// print a placeholder.
void BasicBlock::print(raw_ostream &os) const {
  auto *parentFn = getFunction();
  if (parentFn == nullptr) {
    os << "<<UNLINKED BLOCK>>\n";
    return;
  }
  ModuleState state(parentFn->getContext());
  ModulePrinter modulePrinter(os, state);
  CFGFunctionPrinter(parentFn, modulePrinter).print(this);
}
/// Print this basic block to stderr.
void BasicBlock::dump() const { print(llvm::errs()); }
/// Print out the name of the basic block without printing its body.
/// Blocks that are not in a function print a placeholder.
/// NOTE(review): the `printType` argument is currently ignored.
void BasicBlock::printAsOperand(raw_ostream &os, bool printType) {
  auto *parentFn = getFunction();
  if (parentFn == nullptr) {
    os << "<<UNLINKED BLOCK>>\n";
    return;
  }
  ModuleState state(parentFn->getContext());
  ModulePrinter modulePrinter(os, state);
  CFGFunctionPrinter(parentFn, modulePrinter).printBBName(this);
}
/// Print this statement using a printer for the ML function containing it.
/// Statements that are not in a function print a placeholder.
void Statement::print(raw_ostream &os) const {
  MLFunction *parentFn = findFunction();
  if (parentFn == nullptr) {
    os << "<<UNLINKED STATEMENT>>\n";
    return;
  }
  ModuleState state(parentFn->getContext());
  ModulePrinter modulePrinter(os, state);
  MLFunctionPrinter(parentFn, modulePrinter).print(this);
}
/// Print this statement to stderr.
void Statement::dump() const { print(llvm::errs()); }
/// Print the statements of this block using a printer for the ML function
/// that contains it.
///
/// findFunction() returns null for a block that is not part of any function.
/// Such a block cannot be numbered, and MLFunctionPrinter asserts on a null
/// function. Print a placeholder instead of dereferencing null, consistent
/// with how Statement::print handles unlinked statements.
void StmtBlock::printBlock(raw_ostream &os) const {
  MLFunction *function = findFunction();
  if (!function) {
    os << "<<UNLINKED BLOCK>>\n";
    return;
  }
  ModuleState state(function->getContext());
  ModulePrinter modulePrinter(os, state);
  MLFunctionPrinter(function, modulePrinter).print(this);
}
/// Print this statement block to stderr.
void StmtBlock::dumpBlock() const { printBlock(llvm::errs()); }
/// Print this function (CFG or ML) to `os`, using printer state built from
/// the function's context.
void Function::print(raw_ostream &os) const {
  ModuleState state(getContext());
  ModulePrinter printer(os, state);
  printer.print(this);
}
/// Print this function to stderr.
void Function::dump() const { print(llvm::errs()); }
/// Print the whole module to `os`. Unlike the other entry points in this
/// file, this initializes the printer state from the module before printing
/// (presumably to collect module-level information such as aliases; confirm).
void Module::print(raw_ostream &os) const {
  ModuleState state(getContext());
  state.initialize(this);
  ModulePrinter printer(os, state);
  printer.print(this);
}
/// Print the whole module to stderr.
void Module::dump() const { print(llvm::errs()); }