forked from OSchip/llvm-project
Refactor the AsmParser to follow the pattern established in the parser:
there is now an explicit state class - which only has one instance per top level FooThing::print call. The FunctionPrinter's now subclass ModulePrinter so they can just call print on their types and other global stuff. This also makes the contract strict that the global FooThing::print calls are the public entrypoints and that the printer implementation is otherwise self contained. No Functionality Change. PiperOrigin-RevId: 205409317
This commit is contained in:
parent
a798b021f9
commit
3b7b3302c7
|
@ -105,8 +105,6 @@ public:
|
|||
return expr->getKind() <= Kind::LAST_AFFINE_BINARY_OP;
|
||||
}
|
||||
|
||||
void print(raw_ostream &os) const;
|
||||
|
||||
protected:
|
||||
explicit AffineBinaryOpExpr(Kind kind, AffineExpr *lhs, AffineExpr *rhs);
|
||||
|
||||
|
@ -143,7 +141,6 @@ public:
|
|||
static bool classof(const AffineExpr *expr) {
|
||||
return expr->getKind() == Kind::DimId;
|
||||
}
|
||||
void print(raw_ostream &os) const;
|
||||
|
||||
private:
|
||||
explicit AffineDimExpr(unsigned position)
|
||||
|
@ -168,7 +165,6 @@ public:
|
|||
static bool classof(const AffineExpr *expr) {
|
||||
return expr->getKind() == Kind::SymbolId;
|
||||
}
|
||||
void print(raw_ostream &os) const;
|
||||
|
||||
private:
|
||||
explicit AffineSymbolExpr(unsigned position)
|
||||
|
@ -189,7 +185,6 @@ public:
|
|||
static bool classof(const AffineExpr *expr) {
|
||||
return expr->getKind() == Kind::Constant;
|
||||
}
|
||||
void print(raw_ostream &os) const;
|
||||
|
||||
private:
|
||||
explicit AffineConstantExpr(int64_t constant)
|
||||
|
|
|
@ -74,8 +74,6 @@ public:
|
|||
static bool classof(const Function *func) {
|
||||
return func->getKind() == Kind::ExtFunc;
|
||||
}
|
||||
|
||||
void print(raw_ostream &os) const;
|
||||
};
|
||||
|
||||
|
||||
|
|
|
@ -44,8 +44,6 @@ public:
|
|||
static bool classof(const StmtBlock *block) {
|
||||
return block->getStmtBlockKind() == StmtBlockKind::MLFunc;
|
||||
}
|
||||
|
||||
void print(raw_ostream &os) const;
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
|
|
@ -38,37 +38,23 @@ void Identifier::print(raw_ostream &os) const { os << str(); }
|
|||
|
||||
void Identifier::dump() const { print(llvm::errs()); }
|
||||
|
||||
template <typename Container, typename UnaryFunctor>
|
||||
inline void interleaveComma(raw_ostream &os, const Container &c,
|
||||
UnaryFunctor each_fn) {
|
||||
interleave(c.begin(), c.end(), each_fn, [&]() { os << ", "; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Module printing
|
||||
// ModuleState
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
class ModuleState {
|
||||
public:
|
||||
ModuleState(raw_ostream &os);
|
||||
/// This is the operation set for the current context if it is knowable (a
|
||||
/// context could be determined), otherwise this is null.
|
||||
OperationSet *const operationSet;
|
||||
|
||||
explicit ModuleState(MLIRContext *context)
|
||||
: operationSet(context ? &OperationSet::get(context) : nullptr) {}
|
||||
|
||||
// Initializes module state, populating affine map state.
|
||||
void initialize(const Module *module);
|
||||
|
||||
void print(const Module *module);
|
||||
void print(const Attribute *attr) const;
|
||||
void print(const Type *type) const;
|
||||
void print(const Function *fn);
|
||||
void print(const ExtFunction *fn);
|
||||
void print(const CFGFunction *fn);
|
||||
void print(const MLFunction *fn);
|
||||
|
||||
void recordAffineMapReference(const AffineMap *affineMap) {
|
||||
if (affineMapIds.count(affineMap) == 0) {
|
||||
affineMapIds[affineMap] = nextAffineMapId++;
|
||||
}
|
||||
}
|
||||
|
||||
int getAffineMapId(const AffineMap *affineMap) const {
|
||||
auto it = affineMapIds.find(affineMap);
|
||||
if (it == affineMapIds.end()) {
|
||||
|
@ -77,7 +63,17 @@ public:
|
|||
return it->second;
|
||||
}
|
||||
|
||||
const DenseMap<const AffineMap *, int> &getAffineMapIds() const {
|
||||
return affineMapIds;
|
||||
}
|
||||
|
||||
private:
|
||||
void recordAffineMapReference(const AffineMap *affineMap) {
|
||||
if (affineMapIds.count(affineMap) == 0) {
|
||||
affineMapIds[affineMap] = nextAffineMapId++;
|
||||
}
|
||||
}
|
||||
|
||||
// Visit functions.
|
||||
void visitFunction(const Function *fn);
|
||||
void visitExtFunction(const ExtFunction *fn);
|
||||
|
@ -87,23 +83,11 @@ private:
|
|||
void visitAttribute(const Attribute *attr);
|
||||
void visitOperation(const Operation *op);
|
||||
|
||||
void printAffineMapId(int affineMapId) const;
|
||||
void printAffineMapReference(const AffineMap* affineMap) const;
|
||||
|
||||
raw_ostream &os;
|
||||
DenseMap<const AffineMap *, int> affineMapIds;
|
||||
int nextAffineMapId = 0;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
ModuleState::ModuleState(raw_ostream &os) : os(os) {}
|
||||
|
||||
// Initializes module state, populating affine map state.
|
||||
void ModuleState::initialize(const Module *module) {
|
||||
for (auto fn : module->functionList) {
|
||||
visitFunction(fn);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO Support visiting other types/instructions when implemented.
|
||||
void ModuleState::visitType(const Type *type) {
|
||||
|
@ -171,8 +155,54 @@ void ModuleState::visitFunction(const Function *fn) {
|
|||
}
|
||||
}
|
||||
|
||||
// Initializes module state, populating affine map state.
|
||||
void ModuleState::initialize(const Module *module) {
|
||||
for (auto fn : module->functionList) {
|
||||
visitFunction(fn);
|
||||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ModulePrinter
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
class ModulePrinter {
|
||||
public:
|
||||
ModulePrinter(raw_ostream &os, ModuleState &state) : os(os), state(state) {}
|
||||
explicit ModulePrinter(const ModulePrinter &printer)
|
||||
: os(printer.os), state(printer.state) {}
|
||||
|
||||
template <typename Container, typename UnaryFunctor>
|
||||
inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const {
|
||||
interleave(c.begin(), c.end(), each_fn, [&]() { os << ", "; });
|
||||
}
|
||||
|
||||
void print(const Module *module);
|
||||
void print(const Attribute *attr) const;
|
||||
void print(const Type *type) const;
|
||||
void print(const Function *fn);
|
||||
void print(const ExtFunction *fn);
|
||||
void print(const CFGFunction *fn);
|
||||
void print(const MLFunction *fn);
|
||||
|
||||
void print(const AffineMap *map);
|
||||
void print(const AffineExpr *expr) const;
|
||||
|
||||
protected:
|
||||
raw_ostream &os;
|
||||
ModuleState &state;
|
||||
|
||||
void printFunctionSignature(const Function *fn);
|
||||
void printAffineMapId(int affineMapId) const;
|
||||
void printAffineMapReference(const AffineMap *affineMap) const;
|
||||
|
||||
void print(const AffineBinaryOpExpr *expr) const;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
// Prints function with initialized module state.
|
||||
void ModuleState::print(const Function *fn) {
|
||||
void ModulePrinter::print(const Function *fn) {
|
||||
switch (fn->getKind()) {
|
||||
case Function::Kind::ExtFunc:
|
||||
return print(cast<ExtFunction>(fn));
|
||||
|
@ -184,12 +214,12 @@ void ModuleState::print(const Function *fn) {
|
|||
}
|
||||
|
||||
// Prints affine map identifier.
|
||||
void ModuleState::printAffineMapId(int affineMapId) const {
|
||||
void ModulePrinter::printAffineMapId(int affineMapId) const {
|
||||
os << "#map" << affineMapId;
|
||||
}
|
||||
|
||||
void ModuleState::printAffineMapReference(const AffineMap* affineMap) const {
|
||||
const int mapId = getAffineMapId(affineMap);
|
||||
void ModulePrinter::printAffineMapReference(const AffineMap *affineMap) const {
|
||||
int mapId = state.getAffineMapId(affineMap);
|
||||
if (mapId >= 0) {
|
||||
// Map will be printed at top of module so print reference to its id.
|
||||
printAffineMapId(mapId);
|
||||
|
@ -199,8 +229,8 @@ void ModuleState::printAffineMapReference(const AffineMap* affineMap) const {
|
|||
}
|
||||
}
|
||||
|
||||
void ModuleState::print(const Module *module) {
|
||||
for (const auto &mapAndId : affineMapIds) {
|
||||
void ModulePrinter::print(const Module *module) {
|
||||
for (const auto &mapAndId : state.getAffineMapIds()) {
|
||||
printAffineMapId(mapAndId.second);
|
||||
os << " = ";
|
||||
mapAndId.first->print(os);
|
||||
|
@ -209,7 +239,7 @@ void ModuleState::print(const Module *module) {
|
|||
for (auto *fn : module->functionList) print(fn);
|
||||
}
|
||||
|
||||
void ModuleState::print(const Attribute *attr) const {
|
||||
void ModulePrinter::print(const Attribute *attr) const {
|
||||
switch (attr->getKind()) {
|
||||
case Attribute::Kind::Bool:
|
||||
os << (cast<BoolAttr>(attr)->getValue() ? "true" : "false");
|
||||
|
@ -228,7 +258,7 @@ void ModuleState::print(const Attribute *attr) const {
|
|||
case Attribute::Kind::Array: {
|
||||
auto elts = cast<ArrayAttr>(attr)->getValue();
|
||||
os << '[';
|
||||
interleaveComma(os, elts, [&](Attribute *attr) { print(attr); });
|
||||
interleaveComma(elts, [&](Attribute *attr) { print(attr); });
|
||||
os << ']';
|
||||
break;
|
||||
}
|
||||
|
@ -238,7 +268,7 @@ void ModuleState::print(const Attribute *attr) const {
|
|||
}
|
||||
}
|
||||
|
||||
void ModuleState::print(const Type *type) const {
|
||||
void ModulePrinter::print(const Type *type) const {
|
||||
switch (type->getKind()) {
|
||||
case Type::Kind::AffineInt:
|
||||
os << "affineint";
|
||||
|
@ -264,14 +294,14 @@ void ModuleState::print(const Type *type) const {
|
|||
case Type::Kind::Function: {
|
||||
auto *func = cast<FunctionType>(type);
|
||||
os << '(';
|
||||
interleaveComma(os, func->getInputs(), [&](Type *type) { os << *type; });
|
||||
interleaveComma(func->getInputs(), [&](Type *type) { os << *type; });
|
||||
os << ") -> ";
|
||||
auto results = func->getResults();
|
||||
if (results.size() == 1)
|
||||
os << *results[0];
|
||||
else {
|
||||
os << '(';
|
||||
interleaveComma(os, results, [&](Type *type) { os << *type; });
|
||||
interleaveComma(results, [&](Type *type) { os << *type; });
|
||||
os << ')';
|
||||
}
|
||||
return;
|
||||
|
@ -323,18 +353,133 @@ void ModuleState::print(const Type *type) const {
|
|||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Affine expressions and maps
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void ModulePrinter::print(const AffineExpr *expr) const {
|
||||
switch (expr->getKind()) {
|
||||
case AffineExpr::Kind::SymbolId:
|
||||
os << 's' << cast<AffineSymbolExpr>(expr)->getPosition();
|
||||
return;
|
||||
case AffineExpr::Kind::DimId:
|
||||
os << 'd' << cast<AffineDimExpr>(expr)->getPosition();
|
||||
return;
|
||||
case AffineExpr::Kind::Constant:
|
||||
os << cast<AffineConstantExpr>(expr)->getValue();
|
||||
return;
|
||||
case AffineExpr::Kind::Add:
|
||||
case AffineExpr::Kind::Mul:
|
||||
case AffineExpr::Kind::FloorDiv:
|
||||
case AffineExpr::Kind::CeilDiv:
|
||||
case AffineExpr::Kind::Mod:
|
||||
return print(cast<AffineBinaryOpExpr>(expr));
|
||||
}
|
||||
}
|
||||
|
||||
void ModulePrinter::print(const AffineBinaryOpExpr *expr) const {
|
||||
if (expr->getKind() != AffineExpr::Kind::Add) {
|
||||
os << '(';
|
||||
print(expr->getLHS());
|
||||
switch (expr->getKind()) {
|
||||
case AffineExpr::Kind::Mul:
|
||||
os << " * ";
|
||||
break;
|
||||
case AffineExpr::Kind::FloorDiv:
|
||||
os << " floordiv ";
|
||||
break;
|
||||
case AffineExpr::Kind::CeilDiv:
|
||||
os << " ceildiv ";
|
||||
break;
|
||||
case AffineExpr::Kind::Mod:
|
||||
os << " mod ";
|
||||
break;
|
||||
default:
|
||||
llvm_unreachable("unexpected affine binary op expression");
|
||||
}
|
||||
|
||||
print(expr->getRHS());
|
||||
os << ')';
|
||||
return;
|
||||
}
|
||||
|
||||
// Print out special "pretty" forms for add.
|
||||
os << '(';
|
||||
print(expr->getLHS());
|
||||
|
||||
// Pretty print addition to a product that has a negative operand as a
|
||||
// subtraction.
|
||||
if (auto *rhs = dyn_cast<AffineBinaryOpExpr>(expr->getRHS())) {
|
||||
if (rhs->getKind() == AffineExpr::Kind::Mul) {
|
||||
if (auto *rrhs = dyn_cast<AffineConstantExpr>(rhs->getRHS())) {
|
||||
if (rrhs->getValue() < 0) {
|
||||
os << " - (";
|
||||
print(rhs->getLHS());
|
||||
os << " * " << -rrhs->getValue() << "))";
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pretty print addition to a negative number as a subtraction.
|
||||
if (auto *rhs = dyn_cast<AffineConstantExpr>(expr->getRHS())) {
|
||||
if (rhs->getValue() < 0) {
|
||||
os << " - " << -rhs->getValue() << ")";
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
os << " + ";
|
||||
print(expr->getRHS());
|
||||
os << ')';
|
||||
}
|
||||
|
||||
void ModulePrinter::print(const AffineMap *map) {
|
||||
// Dimension identifiers.
|
||||
os << '(';
|
||||
for (int i = 0; i < (int)map->getNumDims() - 1; i++)
|
||||
os << "d" << i << ", ";
|
||||
if (map->getNumDims() >= 1)
|
||||
os << "d" << map->getNumDims() - 1;
|
||||
os << ")";
|
||||
|
||||
// Symbolic identifiers.
|
||||
if (map->getNumSymbols() >= 1) {
|
||||
os << " [";
|
||||
for (int i = 0; i < (int)map->getNumSymbols() - 1; i++)
|
||||
os << "s" << i << ", ";
|
||||
if (map->getNumSymbols() >= 1)
|
||||
os << "s" << map->getNumSymbols() - 1;
|
||||
os << "]";
|
||||
}
|
||||
|
||||
// AffineMap should have at least one result.
|
||||
assert(!map->getResults().empty());
|
||||
// Result affine expressions.
|
||||
os << " -> (";
|
||||
interleaveComma(map->getResults(), [&](AffineExpr *expr) { print(expr); });
|
||||
os << ")";
|
||||
|
||||
if (!map->isBounded()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Print range sizes for bounded affine maps.
|
||||
os << " size (";
|
||||
interleaveComma(map->getRangeSizes(), [&](AffineExpr *expr) { print(expr); });
|
||||
os << ")";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Function printing
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printFunctionSignature(const Function *fn,
|
||||
const ModuleState *moduleState,
|
||||
raw_ostream &os) {
|
||||
void ModulePrinter::printFunctionSignature(const Function *fn) {
|
||||
auto type = fn->getType();
|
||||
|
||||
os << "@" << fn->getName() << '(';
|
||||
interleaveComma(os, type->getInputs(),
|
||||
[&](Type *eltType) { moduleState->print(eltType); });
|
||||
interleaveComma(type->getInputs(), [&](Type *eltType) { print(eltType); });
|
||||
os << ')';
|
||||
|
||||
switch (type->getResults().size()) {
|
||||
|
@ -342,20 +487,19 @@ static void printFunctionSignature(const Function *fn,
|
|||
break;
|
||||
case 1:
|
||||
os << " -> ";
|
||||
moduleState->print(type->getResults()[0]);
|
||||
print(type->getResults()[0]);
|
||||
break;
|
||||
default:
|
||||
os << " -> (";
|
||||
interleaveComma(os, type->getResults(),
|
||||
[&](Type *eltType) { moduleState->print(eltType); });
|
||||
interleaveComma(type->getResults(), [&](Type *eltType) { print(eltType); });
|
||||
os << ')';
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void ModuleState::print(const ExtFunction *fn) {
|
||||
void ModulePrinter::print(const ExtFunction *fn) {
|
||||
os << "extfunc ";
|
||||
printFunctionSignature(fn, this, os);
|
||||
printFunctionSignature(fn);
|
||||
os << '\n';
|
||||
}
|
||||
|
||||
|
@ -363,18 +507,13 @@ namespace {
|
|||
|
||||
// FunctionState contains common functionality for printing
|
||||
// CFG and ML functions.
|
||||
class FunctionState {
|
||||
class FunctionState : public ModulePrinter {
|
||||
public:
|
||||
FunctionState(MLIRContext *context, const ModuleState *moduleState,
|
||||
raw_ostream &os);
|
||||
FunctionState(const ModulePrinter &other) : ModulePrinter(other) {}
|
||||
|
||||
void printOperation(const Operation *op);
|
||||
|
||||
protected:
|
||||
raw_ostream &os;
|
||||
const ModuleState *moduleState;
|
||||
const OperationSet &operationSet;
|
||||
|
||||
void numberValueID(const SSAValue *value) {
|
||||
assert(!valueIDs.count(value) && "Value numbered multiple times");
|
||||
valueIDs[value] = nextValueID++;
|
||||
|
@ -397,12 +536,6 @@ private:
|
|||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
FunctionState::FunctionState(MLIRContext *context,
|
||||
const ModuleState *moduleState, raw_ostream &os)
|
||||
: os(os),
|
||||
moduleState(moduleState),
|
||||
operationSet(OperationSet::get(context)) {}
|
||||
|
||||
void FunctionState::printOperation(const Operation *op) {
|
||||
os << " ";
|
||||
|
||||
|
@ -417,7 +550,7 @@ void FunctionState::printOperation(const Operation *op) {
|
|||
|
||||
// Check to see if this is a known operation. If so, use the registered
|
||||
// custom printer hook.
|
||||
if (auto opInfo = operationSet.lookup(op->getName().str())) {
|
||||
if (auto opInfo = state.operationSet->lookup(op->getName().str())) {
|
||||
opInfo->printAssembly(op, os);
|
||||
return;
|
||||
}
|
||||
|
@ -431,18 +564,18 @@ void FunctionState::printOperation(const Operation *op) {
|
|||
// Operation this check can go away.
|
||||
if (auto *inst = dyn_cast<OperationInst>(op)) {
|
||||
// TODO: Use getOperands() when we have it.
|
||||
interleaveComma(
|
||||
os, inst->getInstOperands(),
|
||||
[&](const InstOperand &operand) { printValueID(operand.get()); });
|
||||
interleaveComma(inst->getInstOperands(), [&](const InstOperand &operand) {
|
||||
printValueID(operand.get());
|
||||
});
|
||||
}
|
||||
|
||||
os << ')';
|
||||
auto attrs = op->getAttrs();
|
||||
if (!attrs.empty()) {
|
||||
os << '{';
|
||||
interleaveComma(os, attrs, [&](NamedAttribute attr) {
|
||||
interleaveComma(attrs, [&](NamedAttribute attr) {
|
||||
os << attr.first << ": ";
|
||||
moduleState->print(attr.second);
|
||||
print(attr.second);
|
||||
});
|
||||
os << '}';
|
||||
}
|
||||
|
@ -453,20 +586,18 @@ void FunctionState::printOperation(const Operation *op) {
|
|||
// Print the type signature of the operation.
|
||||
os << " : (";
|
||||
// TODO: Switch to getOperands() when we have it.
|
||||
interleaveComma(os, inst->getInstOperands(), [&](const InstOperand &op) {
|
||||
moduleState->print(op.get()->getType());
|
||||
});
|
||||
interleaveComma(inst->getInstOperands(),
|
||||
[&](const InstOperand &op) { print(op.get()->getType()); });
|
||||
os << ") -> ";
|
||||
|
||||
// TODO: Switch to getResults() when we have it.
|
||||
if (inst->getNumResults() == 1) {
|
||||
moduleState->print(inst->getInstResult(0).getType());
|
||||
print(inst->getInstResult(0).getType());
|
||||
} else {
|
||||
os << '(';
|
||||
interleaveComma(os, inst->getInstResults(),
|
||||
[&](const InstResult &result) {
|
||||
moduleState->print(result.getType());
|
||||
});
|
||||
interleaveComma(inst->getInstResults(), [&](const InstResult &result) {
|
||||
print(result.getType());
|
||||
});
|
||||
os << ')';
|
||||
}
|
||||
}
|
||||
|
@ -477,10 +608,9 @@ void FunctionState::printOperation(const Operation *op) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
class CFGFunctionState : public FunctionState {
|
||||
class CFGFunctionPrinter : public FunctionState {
|
||||
public:
|
||||
CFGFunctionState(const CFGFunction *function, const ModuleState *moduleState,
|
||||
raw_ostream &os);
|
||||
CFGFunctionPrinter(const CFGFunction *function, const ModulePrinter &other);
|
||||
|
||||
const CFGFunction *getFunction() const { return function; }
|
||||
|
||||
|
@ -502,25 +632,23 @@ private:
|
|||
const CFGFunction *function;
|
||||
DenseMap<const BasicBlock *, unsigned> basicBlockIDs;
|
||||
|
||||
void numberBlock(const BasicBlock *block);
|
||||
void numberValuesInBlock(const BasicBlock *block);
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
CFGFunctionState::CFGFunctionState(const CFGFunction *function,
|
||||
const ModuleState *moduleState,
|
||||
raw_ostream &os)
|
||||
: FunctionState(function->getContext(), moduleState, os),
|
||||
function(function) {
|
||||
CFGFunctionPrinter::CFGFunctionPrinter(const CFGFunction *function,
|
||||
const ModulePrinter &other)
|
||||
: FunctionState(other), function(function) {
|
||||
// Each basic block gets a unique ID per function.
|
||||
unsigned blockID = 0;
|
||||
for (auto &block : *function) {
|
||||
basicBlockIDs[&block] = blockID++;
|
||||
numberBlock(&block);
|
||||
numberValuesInBlock(&block);
|
||||
}
|
||||
}
|
||||
|
||||
/// Number all of the SSA values in the specified basic block.
|
||||
void CFGFunctionState::numberBlock(const BasicBlock *block) {
|
||||
void CFGFunctionPrinter::numberValuesInBlock(const BasicBlock *block) {
|
||||
// TODO: basic block arguments.
|
||||
for (auto &op : *block) {
|
||||
// We number instruction that have results, and we only number the first
|
||||
|
@ -532,16 +660,16 @@ void CFGFunctionState::numberBlock(const BasicBlock *block) {
|
|||
// Terminators do not define values.
|
||||
}
|
||||
|
||||
void CFGFunctionState::print() {
|
||||
void CFGFunctionPrinter::print() {
|
||||
os << "cfgfunc ";
|
||||
printFunctionSignature(this->getFunction(), moduleState, os);
|
||||
printFunctionSignature(getFunction());
|
||||
os << " {\n";
|
||||
|
||||
for (auto &block : *function) print(&block);
|
||||
os << "}\n\n";
|
||||
}
|
||||
|
||||
void CFGFunctionState::print(const BasicBlock *block) {
|
||||
void CFGFunctionPrinter::print(const BasicBlock *block) {
|
||||
os << "bb" << getBBID(block) << ":\n";
|
||||
|
||||
// TODO Print arguments.
|
||||
|
@ -554,7 +682,7 @@ void CFGFunctionState::print(const BasicBlock *block) {
|
|||
os << "\n";
|
||||
}
|
||||
|
||||
void CFGFunctionState::print(const Instruction *inst) {
|
||||
void CFGFunctionPrinter::print(const Instruction *inst) {
|
||||
switch (inst->getKind()) {
|
||||
case Instruction::Kind::Operation:
|
||||
return print(cast<OperationInst>(inst));
|
||||
|
@ -565,17 +693,16 @@ void CFGFunctionState::print(const Instruction *inst) {
|
|||
}
|
||||
}
|
||||
|
||||
void CFGFunctionState::print(const OperationInst *inst) {
|
||||
void CFGFunctionPrinter::print(const OperationInst *inst) {
|
||||
printOperation(inst);
|
||||
}
|
||||
void CFGFunctionState::print(const BranchInst *inst) {
|
||||
void CFGFunctionPrinter::print(const BranchInst *inst) {
|
||||
os << " br bb" << getBBID(inst->getDest());
|
||||
}
|
||||
void CFGFunctionState::print(const ReturnInst *inst) { os << " return"; }
|
||||
void CFGFunctionPrinter::print(const ReturnInst *inst) { os << " return"; }
|
||||
|
||||
void ModuleState::print(const CFGFunction *fn) {
|
||||
CFGFunctionState state(fn, this, os);
|
||||
state.print();
|
||||
void ModulePrinter::print(const CFGFunction *fn) {
|
||||
CFGFunctionPrinter(fn, *this).print();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -583,10 +710,9 @@ void ModuleState::print(const CFGFunction *fn) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
class MLFunctionState : public FunctionState {
|
||||
class MLFunctionPrinter : public FunctionState {
|
||||
public:
|
||||
MLFunctionState(const MLFunction *function, const ModuleState *moduleState,
|
||||
raw_ostream &os);
|
||||
MLFunctionPrinter(const MLFunction *function, const ModulePrinter &other);
|
||||
|
||||
const MLFunction *getFunction() const { return function; }
|
||||
|
||||
|
@ -609,24 +735,21 @@ private:
|
|||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
MLFunctionState::MLFunctionState(const MLFunction *function,
|
||||
const ModuleState *moduleState,
|
||||
raw_ostream &os)
|
||||
: FunctionState(function->getContext(), moduleState, os),
|
||||
function(function),
|
||||
numSpaces(0) {}
|
||||
MLFunctionPrinter::MLFunctionPrinter(const MLFunction *function,
|
||||
const ModulePrinter &other)
|
||||
: FunctionState(other), function(function), numSpaces(0) {}
|
||||
|
||||
void MLFunctionState::print() {
|
||||
void MLFunctionPrinter::print() {
|
||||
os << "mlfunc ";
|
||||
// FIXME: should print argument names rather than just signature
|
||||
printFunctionSignature(function, moduleState, os);
|
||||
printFunctionSignature(function);
|
||||
os << " {\n";
|
||||
print(function);
|
||||
os << " return\n";
|
||||
os << "}\n\n";
|
||||
}
|
||||
|
||||
void MLFunctionState::print(const StmtBlock *block) {
|
||||
void MLFunctionPrinter::print(const StmtBlock *block) {
|
||||
numSpaces += indentWidth;
|
||||
for (auto &stmt : block->getStatements()) {
|
||||
print(&stmt);
|
||||
|
@ -635,7 +758,7 @@ void MLFunctionState::print(const StmtBlock *block) {
|
|||
numSpaces -= indentWidth;
|
||||
}
|
||||
|
||||
void MLFunctionState::print(const Statement *stmt) {
|
||||
void MLFunctionPrinter::print(const Statement *stmt) {
|
||||
switch (stmt->getKind()) {
|
||||
case Statement::Kind::Operation:
|
||||
return print(cast<OperationStmt>(stmt));
|
||||
|
@ -646,9 +769,11 @@ void MLFunctionState::print(const Statement *stmt) {
|
|||
}
|
||||
}
|
||||
|
||||
void MLFunctionState::print(const OperationStmt *stmt) { printOperation(stmt); }
|
||||
void MLFunctionPrinter::print(const OperationStmt *stmt) {
|
||||
printOperation(stmt);
|
||||
}
|
||||
|
||||
void MLFunctionState::print(const ForStmt *stmt) {
|
||||
void MLFunctionPrinter::print(const ForStmt *stmt) {
|
||||
os.indent(numSpaces) << "for x = " << *stmt->getLowerBound();
|
||||
os << " to " << *stmt->getUpperBound();
|
||||
if (stmt->getStep()->getValue() != 1)
|
||||
|
@ -659,7 +784,7 @@ void MLFunctionState::print(const ForStmt *stmt) {
|
|||
os.indent(numSpaces) << "}";
|
||||
}
|
||||
|
||||
void MLFunctionState::print(const IfStmt *stmt) {
|
||||
void MLFunctionPrinter::print(const IfStmt *stmt) {
|
||||
os.indent(numSpaces) << "if () {\n";
|
||||
print(stmt->getThenClause());
|
||||
os.indent(numSpaces) << "}";
|
||||
|
@ -670,9 +795,8 @@ void MLFunctionState::print(const IfStmt *stmt) {
|
|||
}
|
||||
}
|
||||
|
||||
void ModuleState::print(const MLFunction *fn) {
|
||||
MLFunctionState state(fn, this, os);
|
||||
state.print();
|
||||
void ModulePrinter::print(const MLFunction *fn) {
|
||||
MLFunctionPrinter(fn, *this).print();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -680,8 +804,8 @@ void ModuleState::print(const MLFunction *fn) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void Attribute::print(raw_ostream &os) const {
|
||||
ModuleState moduleState(os);
|
||||
moduleState.print(this);
|
||||
ModuleState state(/*no context is known*/ nullptr);
|
||||
ModulePrinter(os, state).print(this);
|
||||
}
|
||||
|
||||
void Attribute::dump() const {
|
||||
|
@ -689,23 +813,12 @@ void Attribute::dump() const {
|
|||
}
|
||||
|
||||
void Type::print(raw_ostream &os) const {
|
||||
ModuleState moduleState(os);
|
||||
moduleState.print(this);
|
||||
ModuleState state(getContext());
|
||||
ModulePrinter(os, state).print(this);
|
||||
}
|
||||
|
||||
void Type::dump() const { print(llvm::errs()); }
|
||||
|
||||
void Instruction::print(raw_ostream &os) const {
|
||||
ModuleState moduleState(os);
|
||||
CFGFunctionState state(getFunction(), &moduleState, os);
|
||||
state.print(this);
|
||||
}
|
||||
|
||||
void Instruction::dump() const {
|
||||
print(llvm::errs());
|
||||
llvm::errs() << "\n";
|
||||
}
|
||||
|
||||
void AffineMap::dump() const {
|
||||
print(llvm::errs());
|
||||
llvm::errs() << "\n";
|
||||
|
@ -716,163 +829,54 @@ void AffineExpr::dump() const {
|
|||
llvm::errs() << "\n";
|
||||
}
|
||||
|
||||
void AffineSymbolExpr::print(raw_ostream &os) const {
|
||||
os << 's' << getPosition();
|
||||
}
|
||||
|
||||
void AffineDimExpr::print(raw_ostream &os) const { os << 'd' << getPosition(); }
|
||||
|
||||
void AffineConstantExpr::print(raw_ostream &os) const { os << getValue(); }
|
||||
|
||||
static void printAdd(const AffineBinaryOpExpr *addExpr, raw_ostream &os) {
|
||||
os << '(' << *addExpr->getLHS();
|
||||
|
||||
// Pretty print addition to a product that has a negative operand as a
|
||||
// subtraction.
|
||||
if (auto *rhs = dyn_cast<AffineBinaryOpExpr>(addExpr->getRHS())) {
|
||||
if (rhs->getKind() == AffineExpr::Kind::Mul) {
|
||||
if (auto *rrhs = dyn_cast<AffineConstantExpr>(rhs->getRHS())) {
|
||||
if (rrhs->getValue() < 0) {
|
||||
os << " - (" << *rhs->getLHS() << " * " << -rrhs->getValue() << "))";
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pretty print addition to a negative number as a subtraction.
|
||||
if (auto *rhs = dyn_cast<AffineConstantExpr>(addExpr->getRHS())) {
|
||||
if (rhs->getValue() < 0) {
|
||||
os << " - " << -rhs->getValue() << ")";
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
os << " + " << *addExpr->getRHS() << ")";
|
||||
}
|
||||
|
||||
void AffineBinaryOpExpr::print(raw_ostream &os) const {
|
||||
switch (getKind()) {
|
||||
case Kind::Add:
|
||||
return printAdd(this, os);
|
||||
case Kind::Mul:
|
||||
os << "(" << *getLHS() << " * " << *getRHS() << ")";
|
||||
return;
|
||||
case Kind::FloorDiv:
|
||||
os << "(" << *getLHS() << " floordiv " << *getRHS() << ")";
|
||||
return;
|
||||
case Kind::CeilDiv:
|
||||
os << "(" << *getLHS() << " ceildiv " << *getRHS() << ")";
|
||||
return;
|
||||
case Kind::Mod:
|
||||
os << "(" << *getLHS() << " mod " << *getRHS() << ")";
|
||||
return;
|
||||
default:
|
||||
llvm_unreachable("unexpected affine binary op expression");
|
||||
}
|
||||
}
|
||||
|
||||
void AffineExpr::print(raw_ostream &os) const {
|
||||
switch (getKind()) {
|
||||
case Kind::SymbolId:
|
||||
return cast<AffineSymbolExpr>(this)->print(os);
|
||||
case Kind::DimId:
|
||||
return cast<AffineDimExpr>(this)->print(os);
|
||||
case Kind::Constant:
|
||||
return cast<AffineConstantExpr>(this)->print(os);
|
||||
case Kind::Add:
|
||||
case Kind::Mul:
|
||||
case Kind::FloorDiv:
|
||||
case Kind::CeilDiv:
|
||||
case Kind::Mod:
|
||||
return cast<AffineBinaryOpExpr>(this)->print(os);
|
||||
}
|
||||
ModuleState state(/*no context is known*/ nullptr);
|
||||
ModulePrinter(os, state).print(this);
|
||||
}
|
||||
|
||||
void AffineMap::print(raw_ostream &os) const {
|
||||
// Dimension identifiers.
|
||||
os << "(";
|
||||
for (int i = 0; i < (int)getNumDims() - 1; i++) os << "d" << i << ", ";
|
||||
if (getNumDims() >= 1) os << "d" << getNumDims() - 1;
|
||||
os << ")";
|
||||
ModuleState state(/*no context is known*/ nullptr);
|
||||
ModulePrinter(os, state).print(this);
|
||||
}
|
||||
|
||||
// Symbolic identifiers.
|
||||
if (getNumSymbols() >= 1) {
|
||||
os << " [";
|
||||
for (int i = 0; i < (int)getNumSymbols() - 1; i++) os << "s" << i << ", ";
|
||||
if (getNumSymbols() >= 1) os << "s" << getNumSymbols() - 1;
|
||||
os << "]";
|
||||
}
|
||||
void Instruction::print(raw_ostream &os) const {
|
||||
ModuleState state(getFunction()->getContext());
|
||||
ModulePrinter modulePrinter(os, state);
|
||||
CFGFunctionPrinter(getFunction(), modulePrinter).print(this);
|
||||
}
|
||||
|
||||
// AffineMap should have at least one result.
|
||||
assert(!getResults().empty());
|
||||
// Result affine expressions.
|
||||
os << " -> (";
|
||||
interleaveComma(os, getResults(), [&](AffineExpr *expr) { os << *expr; });
|
||||
os << ")";
|
||||
|
||||
if (!isBounded()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Print range sizes for bounded affine maps.
|
||||
os << " size (";
|
||||
interleaveComma(os, getRangeSizes(), [&](AffineExpr *expr) { os << *expr; });
|
||||
os << ")";
|
||||
void Instruction::dump() const {
|
||||
print(llvm::errs());
|
||||
llvm::errs() << "\n";
|
||||
}
|
||||
|
||||
void BasicBlock::print(raw_ostream &os) const {
|
||||
ModuleState moduleState(os);
|
||||
CFGFunctionState state(getFunction(), &moduleState, os);
|
||||
state.print();
|
||||
ModuleState state(getFunction()->getContext());
|
||||
ModulePrinter modulePrinter(os, state);
|
||||
CFGFunctionPrinter(getFunction(), modulePrinter).print(this);
|
||||
}
|
||||
|
||||
void BasicBlock::dump() const { print(llvm::errs()); }
|
||||
|
||||
void Statement::print(raw_ostream &os) const {
|
||||
ModuleState moduleState(os);
|
||||
MLFunctionState state(getFunction(), &moduleState, os);
|
||||
state.print(this);
|
||||
ModuleState state(getFunction()->getContext());
|
||||
ModulePrinter modulePrinter(os, state);
|
||||
MLFunctionPrinter(getFunction(), modulePrinter).print(this);
|
||||
}
|
||||
|
||||
void Statement::dump() const { print(llvm::errs()); }
|
||||
|
||||
void Function::print(raw_ostream &os) const {
|
||||
switch (getKind()) {
|
||||
case Kind::ExtFunc:
|
||||
return cast<ExtFunction>(this)->print(os);
|
||||
case Kind::CFGFunc:
|
||||
return cast<CFGFunction>(this)->print(os);
|
||||
case Kind::MLFunc:
|
||||
return cast<MLFunction>(this)->print(os);
|
||||
}
|
||||
ModuleState state(getContext());
|
||||
ModulePrinter(os, state).print(this);
|
||||
}
|
||||
|
||||
void Function::dump() const { print(llvm::errs()); }
|
||||
|
||||
void ExtFunction::print(raw_ostream &os) const {
|
||||
ModuleState moduleState(os);
|
||||
os << "extfunc ";
|
||||
printFunctionSignature(this, &moduleState, os);
|
||||
os << "\n";
|
||||
}
|
||||
|
||||
void CFGFunction::print(raw_ostream &os) const {
|
||||
ModuleState moduleState(os);
|
||||
CFGFunctionState state(this, &moduleState, os);
|
||||
state.print();
|
||||
}
|
||||
|
||||
void MLFunction::print(raw_ostream &os) const {
|
||||
ModuleState moduleState(os);
|
||||
MLFunctionState state(this, &moduleState, os);
|
||||
state.print();
|
||||
}
|
||||
|
||||
void Module::print(raw_ostream &os) const {
|
||||
ModuleState moduleState(os);
|
||||
moduleState.initialize(this);
|
||||
moduleState.print(this);
|
||||
ModuleState state(getContext());
|
||||
state.initialize(this);
|
||||
ModulePrinter(os, state).print(this);
|
||||
}
|
||||
|
||||
void Module::dump() const { print(llvm::errs()); }
|
||||
|
|
Loading…
Reference in New Issue