forked from OSchip/llvm-project
Implement OperationStmt. Refactor function printing to use FunctionState class for operation printing. FunctionState class is a base class for CFGFunctionState and MLFunctionState classes. No parsing yet - will add once cl/203785893 is in.
PiperOrigin-RevId: 203862427
This commit is contained in:
parent
178fd24813
commit
6d93615678
|
@ -24,6 +24,9 @@
|
|||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/PointerUnion.h"
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace mlir {
|
||||
|
@ -64,6 +67,19 @@ private:
|
|||
ParentType parent;
|
||||
};
|
||||
|
||||
/// Operation statements represent operations inside ML functions.
|
||||
class OperationStmt : public Operation, public Statement {
|
||||
public:
|
||||
explicit OperationStmt(ParentType parent, Identifier name,
|
||||
ArrayRef<NamedAttribute> attrs, MLIRContext *context)
|
||||
: Operation(name, attrs, context), Statement(Kind::Operation, parent) {}
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Statement *stmt) {
|
||||
return stmt->getKind() == Kind::Operation;
|
||||
}
|
||||
};
|
||||
|
||||
/// Node statement represents a statement that may contain other statements.
|
||||
class NodeStmt : public Statement {
|
||||
public:
|
||||
|
|
|
@ -76,12 +76,56 @@ void ExtFunction::print(raw_ostream &os) const {
|
|||
os << "\n";
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// FunctionState contains common functionality for printing
|
||||
// CFG and ML functions.
|
||||
class FunctionState {
|
||||
public:
|
||||
FunctionState(MLIRContext *context, raw_ostream &os);
|
||||
|
||||
void printOperation(const Operation *op);
|
||||
|
||||
protected:
|
||||
raw_ostream &os;
|
||||
const OperationSet &operationSet;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
FunctionState::FunctionState(MLIRContext *context, raw_ostream &os)
|
||||
: os(os), operationSet(OperationSet::get(context)) {}
|
||||
|
||||
void FunctionState::printOperation(const Operation *op) {
|
||||
// Check to see if this is a known operation. If so, use the registered
|
||||
// custom printer hook.
|
||||
if (auto opInfo = operationSet.lookup(op->getName().str())) {
|
||||
os << " ";
|
||||
opInfo->printAssembly(op, os);
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: escape name if necessary.
|
||||
os << " \"" << op->getName().str() << "\"()";
|
||||
|
||||
auto attrs = op->getAttrs();
|
||||
if (!attrs.empty()) {
|
||||
os << '{';
|
||||
interleave(
|
||||
attrs,
|
||||
[&](NamedAttribute attr) { os << attr.first << ": " << *attr.second; },
|
||||
[&]() { os << ", "; });
|
||||
os << '}';
|
||||
}
|
||||
|
||||
os << '\n';
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CFG Function printing
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
class CFGFunctionState {
|
||||
class CFGFunctionState : public FunctionState {
|
||||
public:
|
||||
CFGFunctionState(const CFGFunction *function, raw_ostream &os);
|
||||
|
||||
|
@ -103,16 +147,12 @@ public:
|
|||
|
||||
private:
|
||||
const CFGFunction *function;
|
||||
raw_ostream &os;
|
||||
const OperationSet &operationSet;
|
||||
DenseMap<const BasicBlock*, unsigned> basicBlockIDs;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
CFGFunctionState::CFGFunctionState(const CFGFunction *function, raw_ostream &os)
|
||||
: function(function), os(os),
|
||||
operationSet(OperationSet::get(function->getContext())) {
|
||||
|
||||
: FunctionState(function->getContext(), os), function(function) {
|
||||
// Each basic block gets a unique ID per function.
|
||||
unsigned blockID = 0;
|
||||
for (auto &block : *function)
|
||||
|
@ -151,27 +191,7 @@ void CFGFunctionState::print(const Instruction *inst) {
|
|||
}
|
||||
|
||||
void CFGFunctionState::print(const OperationInst *inst) {
|
||||
// Check to see if this is a known operation. If so, use the registered
|
||||
// custom printer hook.
|
||||
if (auto opInfo = operationSet.lookup(inst->getName().str())) {
|
||||
os << " ";
|
||||
opInfo->printAssembly(inst, os);
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: escape name if necessary.
|
||||
os << " \"" << inst->getName().str() << "\"()";
|
||||
|
||||
auto attrs = inst->getAttrs();
|
||||
if (!attrs.empty()) {
|
||||
os << '{';
|
||||
interleave(attrs, [&](NamedAttribute attr) {
|
||||
os << attr.first << ": " << *attr.second;
|
||||
}, [&]() { os << ", "; });
|
||||
os << '}';
|
||||
}
|
||||
|
||||
os << '\n';
|
||||
printOperation(inst);
|
||||
}
|
||||
|
||||
void CFGFunctionState::print(const BranchInst *inst) {
|
||||
|
@ -186,7 +206,7 @@ void CFGFunctionState::print(const ReturnInst *inst) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
class MLFunctionState {
|
||||
class MLFunctionState : public FunctionState {
|
||||
public:
|
||||
MLFunctionState(const MLFunction *function, raw_ostream &os);
|
||||
|
||||
|
@ -195,6 +215,7 @@ public:
|
|||
void print();
|
||||
|
||||
void print(const Statement *stmt);
|
||||
void print(const OperationStmt *stmt);
|
||||
void print(const ForStmt *stmt);
|
||||
void print(const IfStmt *stmt);
|
||||
void print(const ElseClause *stmt, bool isLast);
|
||||
|
@ -204,13 +225,13 @@ private:
|
|||
void printNestedStatements(const NodeStmt *stmt);
|
||||
|
||||
const MLFunction *function;
|
||||
raw_ostream &os;
|
||||
int numSpaces;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
MLFunctionState::MLFunctionState(const MLFunction *function, raw_ostream &os)
|
||||
: function(function), os(os), numSpaces(2) {}
|
||||
: FunctionState(function->getContext(), os), function(function),
|
||||
numSpaces(2) {}
|
||||
|
||||
void MLFunctionState::print() {
|
||||
os << "mlfunc ";
|
||||
|
@ -246,6 +267,8 @@ void MLFunctionState::printNestedStatements(const NodeStmt *stmt) {
|
|||
os.indent(numSpaces) << "}";
|
||||
}
|
||||
|
||||
void MLFunctionState::print(const OperationStmt *stmt) { printOperation(stmt); }
|
||||
|
||||
void MLFunctionState::print(const ForStmt *stmt) {
|
||||
os << "for ";
|
||||
printNestedStatements(stmt);
|
||||
|
|
Loading…
Reference in New Issue