Implement OperationStmt. Refactor function printing to use FunctionState class for operation printing. FunctionState class is a base class for CFGFunctionState and MLFunctionState classes. No parsing yet - will add once cl/203785893 is in.

PiperOrigin-RevId: 203862427
This commit is contained in:
Tatiana Shpeisman 2018-07-09 17:42:46 -07:00 committed by jpienaar
parent 178fd24813
commit 6d93615678
2 changed files with 69 additions and 30 deletions

View File

@ -24,6 +24,9 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/PointerUnion.h"
#include "mlir/IR/Operation.h"
#include <vector>
namespace mlir {
@ -64,6 +67,19 @@ private:
ParentType parent;
};
/// Operation statements represent operations inside ML functions.
class OperationStmt : public Operation, public Statement {
public:
explicit OperationStmt(ParentType parent, Identifier name,
ArrayRef<NamedAttribute> attrs, MLIRContext *context)
: Operation(name, attrs, context), Statement(Kind::Operation, parent) {}
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Statement *stmt) {
return stmt->getKind() == Kind::Operation;
}
};
/// Node statement represents a statement that may contain other statements.
class NodeStmt : public Statement {
public:

View File

@ -76,12 +76,56 @@ void ExtFunction::print(raw_ostream &os) const {
os << "\n";
}
namespace {
// FunctionState contains common functionality for printing
// CFG and ML functions.
class FunctionState {
public:
FunctionState(MLIRContext *context, raw_ostream &os);
void printOperation(const Operation *op);
protected:
raw_ostream &os;
const OperationSet &operationSet;
};
} // end anonymous namespace
FunctionState::FunctionState(MLIRContext *context, raw_ostream &os)
: os(os), operationSet(OperationSet::get(context)) {}
void FunctionState::printOperation(const Operation *op) {
// Check to see if this is a known operation. If so, use the registered
// custom printer hook.
if (auto opInfo = operationSet.lookup(op->getName().str())) {
os << " ";
opInfo->printAssembly(op, os);
return;
}
// TODO: escape name if necessary.
os << " \"" << op->getName().str() << "\"()";
auto attrs = op->getAttrs();
if (!attrs.empty()) {
os << '{';
interleave(
attrs,
[&](NamedAttribute attr) { os << attr.first << ": " << *attr.second; },
[&]() { os << ", "; });
os << '}';
}
os << '\n';
}
//===----------------------------------------------------------------------===//
// CFG Function printing
//===----------------------------------------------------------------------===//
namespace {
class CFGFunctionState {
class CFGFunctionState : public FunctionState {
public:
CFGFunctionState(const CFGFunction *function, raw_ostream &os);
@ -103,16 +147,12 @@ public:
private:
const CFGFunction *function;
raw_ostream &os;
const OperationSet &operationSet;
DenseMap<const BasicBlock*, unsigned> basicBlockIDs;
};
} // end anonymous namespace
CFGFunctionState::CFGFunctionState(const CFGFunction *function, raw_ostream &os)
: function(function), os(os),
operationSet(OperationSet::get(function->getContext())) {
: FunctionState(function->getContext(), os), function(function) {
// Each basic block gets a unique ID per function.
unsigned blockID = 0;
for (auto &block : *function)
@ -151,27 +191,7 @@ void CFGFunctionState::print(const Instruction *inst) {
}
void CFGFunctionState::print(const OperationInst *inst) {
// Check to see if this is a known operation. If so, use the registered
// custom printer hook.
if (auto opInfo = operationSet.lookup(inst->getName().str())) {
os << " ";
opInfo->printAssembly(inst, os);
return;
}
// TODO: escape name if necessary.
os << " \"" << inst->getName().str() << "\"()";
auto attrs = inst->getAttrs();
if (!attrs.empty()) {
os << '{';
interleave(attrs, [&](NamedAttribute attr) {
os << attr.first << ": " << *attr.second;
}, [&]() { os << ", "; });
os << '}';
}
os << '\n';
printOperation(inst);
}
void CFGFunctionState::print(const BranchInst *inst) {
@ -186,7 +206,7 @@ void CFGFunctionState::print(const ReturnInst *inst) {
//===----------------------------------------------------------------------===//
namespace {
class MLFunctionState {
class MLFunctionState : public FunctionState {
public:
MLFunctionState(const MLFunction *function, raw_ostream &os);
@ -195,6 +215,7 @@ public:
void print();
void print(const Statement *stmt);
void print(const OperationStmt *stmt);
void print(const ForStmt *stmt);
void print(const IfStmt *stmt);
void print(const ElseClause *stmt, bool isLast);
@ -204,13 +225,13 @@ private:
void printNestedStatements(const NodeStmt *stmt);
const MLFunction *function;
raw_ostream &os;
int numSpaces;
};
} // end anonymous namespace
MLFunctionState::MLFunctionState(const MLFunction *function, raw_ostream &os)
: function(function), os(os), numSpaces(2) {}
: FunctionState(function->getContext(), os), function(function),
numSpaces(2) {}
void MLFunctionState::print() {
os << "mlfunc ";
@ -246,6 +267,8 @@ void MLFunctionState::printNestedStatements(const NodeStmt *stmt) {
os.indent(numSpaces) << "}";
}
void MLFunctionState::print(const OperationStmt *stmt) { printOperation(stmt); }
void MLFunctionState::print(const ForStmt *stmt) {
os << "for ";
printNestedStatements(stmt);