Replace TerminatorInst with builtin terminator operations.

Note: Terminators will be merged into the operations list in a follow up patch.
PiperOrigin-RevId: 221670037
This commit is contained in:
River Riddle 2018-11-15 12:32:21 -08:00 committed by jpienaar
parent de828dd259
commit 503caf0722
15 changed files with 77 additions and 742 deletions

View File

@ -116,9 +116,9 @@ public:
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
/// Change the terminator of this block to the specified instruction. /// Change the terminator of this block to the specified instruction.
void setTerminator(TerminatorInst *inst); void setTerminator(OperationInst *inst);
TerminatorInst *getTerminator() const { return terminator; } OperationInst *getTerminator() const { return terminator; }
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
// Predecessors and successors. // Predecessors and successors.
@ -217,7 +217,7 @@ private:
std::vector<BBArgument *> arguments; std::vector<BBArgument *> arguments;
/// This is the owning reference to the terminator of the block. /// This is the owning reference to the terminator of the block.
TerminatorInst *terminator = nullptr; OperationInst *terminator = nullptr;
BasicBlock(const BasicBlock&) = delete; BasicBlock(const BasicBlock&) = delete;
void operator=(const BasicBlock&) = delete; void operator=(const BasicBlock&) = delete;

View File

@ -267,32 +267,7 @@ public:
return op; return op;
} }
// Terminators.
ReturnInst *createReturn(Location location, ArrayRef<CFGValue *> operands) {
return insertTerminator(ReturnInst::create(location, operands));
}
BranchInst *createBranch(Location location, BasicBlock *dest,
ArrayRef<CFGValue *> operands = {}) {
return insertTerminator(BranchInst::create(location, dest, operands));
}
CondBranchInst *createCondBranch(Location location, CFGValue *condition,
BasicBlock *trueDest,
BasicBlock *falseDest) {
return insertTerminator(
CondBranchInst::create(location, condition, trueDest, falseDest));
}
private: private:
template <typename T> T *insertTerminator(T *term) {
// FIXME: b/118738403
assert(!block->getTerminator() && "cannot insert the second terminator");
block->setTerminator(term);
return term;
}
CFGFunction *function; CFGFunction *function;
BasicBlock *block = nullptr; BasicBlock *block = nullptr;
BasicBlock::iterator insertPoint; BasicBlock::iterator insertPoint;

View File

@ -69,9 +69,6 @@ class Instruction : public IROperandOwner {
public: public:
enum class Kind { enum class Kind {
Operation = (int)IROperandOwner::Kind::OperationInst, Operation = (int)IROperandOwner::Kind::OperationInst,
Branch = (int)IROperandOwner::Kind::BranchInst,
CondBranch = (int)IROperandOwner::Kind::CondBranchInst,
Return = (int)IROperandOwner::Kind::ReturnInst
}; };
Kind getKind() const { return (Kind)IROperandOwner::getKind(); } Kind getKind() const { return (Kind)IROperandOwner::getKind(); }
@ -444,282 +441,6 @@ private:
size_t numTrailingObjects(OverloadToken<unsigned>) const { return numSuccs; } size_t numTrailingObjects(OverloadToken<unsigned>) const { return numSuccs; }
}; };
/// Terminator instructions are the last part of a basic block, used to
/// represent control flow and returns.
class TerminatorInst : public Instruction {
public:
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Instruction *inst) {
return inst->getKind() != Kind::Operation;
}
/// Remove this terminator from its BasicBlock and delete it.
void erase();
/// Return the list of BasicBlockOperand operands of this terminator that
/// this terminator holds.
MutableArrayRef<BasicBlockOperand> getBasicBlockOperands();
ArrayRef<BasicBlockOperand> getBasicBlockOperands() const {
return const_cast<TerminatorInst *>(this)->getBasicBlockOperands();
}
unsigned getNumSuccessors() const { return getBasicBlockOperands().size(); }
const BasicBlock *getSuccessor(unsigned i) const {
return getBasicBlockOperands()[i].get();
}
BasicBlock *getSuccessor(unsigned i) {
return getBasicBlockOperands()[i].get();
}
protected:
TerminatorInst(Kind kind, Location location) : Instruction(kind, location) {}
~TerminatorInst() {}
};
/// The 'br' instruction is an unconditional from one basic block to another,
/// and may pass basic block arguments to the successor.
class BranchInst : public TerminatorInst {
public:
static BranchInst *create(Location location, BasicBlock *dest,
ArrayRef<CFGValue *> operands = {}) {
return new BranchInst(location, dest, operands);
}
~BranchInst() {}
/// Return the block this branch jumps to.
BasicBlock *getDest() const { return dest.get(); }
void setDest(BasicBlock *block);
unsigned getNumOperands() const { return operands.size(); }
ArrayRef<InstOperand> getInstOperands() const { return operands; }
MutableArrayRef<InstOperand> getInstOperands() { return operands; }
/// Add one value to the operand list.
void addOperand(CFGValue *value);
/// Add a list of values to the operand list.
void addOperands(ArrayRef<CFGValue *> values);
/// Erase a specific argument from the arg list.
// TODO: void eraseArgument(int Index);
MutableArrayRef<BasicBlockOperand> getBasicBlockOperands() { return dest; }
ArrayRef<BasicBlockOperand> getBasicBlockOperands() const { return dest; }
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const IROperandOwner *ptr) {
return ptr->getKind() == IROperandOwner::Kind::BranchInst;
}
private:
explicit BranchInst(Location location, BasicBlock *dest,
ArrayRef<CFGValue *> operands);
BasicBlockOperand dest;
std::vector<InstOperand> operands;
};
/// The 'cond_br' instruction is a conditional branch based on a boolean
/// condition to one of two possible successors. It may pass arguments to each
/// successor.
class CondBranchInst : public TerminatorInst {
// These are the indices into the dests list.
enum { trueIndex = 0, falseIndex = 1 };
public:
static CondBranchInst *create(Location location, CFGValue *condition,
BasicBlock *trueDest, BasicBlock *falseDest) {
return new CondBranchInst(location, condition, trueDest, falseDest);
}
~CondBranchInst() {}
/// Return the i1 condition.
CFGValue *getCondition() { return condition; }
const CFGValue *getCondition() const { return condition; }
/// Return the destination if the condition is true.
BasicBlock *getTrueDest() const { return dests[trueIndex].get(); }
/// Return the destination if the condition is false.
BasicBlock *getFalseDest() const { return dests[falseIndex].get(); }
// Support non-const operand iteration.
using operand_iterator = OperandIterator<CondBranchInst, CFGValue>;
// Support const operand iteration.
using const_operand_iterator =
OperandIterator<const CondBranchInst, const CFGValue>;
ArrayRef<InstOperand> getInstOperands() const { return operands; }
MutableArrayRef<InstOperand> getInstOperands() { return operands; }
unsigned getNumOperands() const { return operands.size(); }
// Accessors for operands to the 'true' destination.
CFGValue *getTrueOperand(unsigned idx) {
return getTrueInstOperand(idx).get();
}
const CFGValue *getTrueOperand(unsigned idx) const {
return getTrueInstOperand(idx).get();
}
void setTrueOperand(unsigned idx, CFGValue *value) {
return getTrueInstOperand(idx).set(value);
}
operand_iterator true_operand_begin() { return operand_iterator(this, 0); }
operand_iterator true_operand_end() {
return operand_iterator(this, getNumTrueOperands());
}
llvm::iterator_range<operand_iterator> getTrueOperands() {
return {true_operand_begin(), true_operand_end()};
}
const_operand_iterator true_operand_begin() const {
return const_operand_iterator(this, 0);
}
const_operand_iterator true_operand_end() const {
return const_operand_iterator(this, getNumTrueOperands());
}
llvm::iterator_range<const_operand_iterator> getTrueOperands() const {
return {true_operand_begin(), true_operand_end()};
}
ArrayRef<InstOperand> getTrueInstOperands() const {
return const_cast<CondBranchInst *>(this)->getTrueInstOperands();
}
MutableArrayRef<InstOperand> getTrueInstOperands() {
return {operands.data(), operands.data() + getNumTrueOperands()};
}
InstOperand &getTrueInstOperand(unsigned idx) { return operands[idx]; }
const InstOperand &getTrueInstOperand(unsigned idx) const {
return operands[idx];
}
unsigned getNumTrueOperands() const { return numTrueOperands; }
/// Add one value to the true operand list.
void addTrueOperand(CFGValue *value);
/// Add a list of values to the operand list.
void addTrueOperands(ArrayRef<CFGValue *> values);
// Accessors for operands to the 'false' destination.
CFGValue *getFalseOperand(unsigned idx) {
return getFalseInstOperand(idx).get();
}
const CFGValue *getFalseOperand(unsigned idx) const {
return getFalseInstOperand(idx).get();
}
void setFalseOperand(unsigned idx, CFGValue *value) {
return getFalseInstOperand(idx).set(value);
}
operand_iterator false_operand_begin() {
return operand_iterator(this, getNumTrueOperands());
}
operand_iterator false_operand_end() {
return operand_iterator(this, getNumOperands());
}
llvm::iterator_range<operand_iterator> getFalseOperands() {
return {false_operand_begin(), false_operand_end()};
}
const_operand_iterator false_operand_begin() const {
return const_operand_iterator(this, getNumTrueOperands());
}
const_operand_iterator false_operand_end() const {
return const_operand_iterator(this, getNumOperands());
}
llvm::iterator_range<const_operand_iterator> getFalseOperands() const {
return {false_operand_begin(), false_operand_end()};
}
ArrayRef<InstOperand> getFalseInstOperands() const {
return const_cast<CondBranchInst *>(this)->getFalseInstOperands();
}
MutableArrayRef<InstOperand> getFalseInstOperands() {
return {operands.data() + getNumTrueOperands(),
operands.data() + getNumOperands()};
}
InstOperand &getFalseInstOperand(unsigned idx) {
return operands[idx + getNumTrueOperands()];
}
const InstOperand &getFalseInstOperand(unsigned idx) const {
return operands[idx + getNumTrueOperands()];
}
unsigned getNumFalseOperands() const {
return operands.size() - numTrueOperands;
}
/// Add one value to the false operand list.
void addFalseOperand(CFGValue *value);
/// Add a list of values to the operand list.
void addFalseOperands(ArrayRef<CFGValue *> values);
MutableArrayRef<BasicBlockOperand> getBasicBlockOperands() { return dests; }
ArrayRef<BasicBlockOperand> getBasicBlockOperands() const { return dests; }
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const IROperandOwner *ptr) {
return ptr->getKind() == IROperandOwner::Kind::CondBranchInst;
}
private:
CondBranchInst(Location location, CFGValue *condition, BasicBlock *trueDest,
BasicBlock *falseDest);
CFGValue *condition;
BasicBlockOperand dests[2]; // 0 is the true dest, 1 is the false dest.
// Operand list. The true operands are stored first, followed by the false
// operands.
std::vector<InstOperand> operands;
unsigned numTrueOperands;
};
/// The 'return' instruction represents the end of control flow within the
/// current function, and can return zero or more results. The result list is
/// required to align with the result list of the containing function's type.
class ReturnInst final
: public TerminatorInst,
private llvm::TrailingObjects<ReturnInst, InstOperand> {
public:
/// Create a new ReturnInst with the specific fields.
static ReturnInst *create(Location location, ArrayRef<CFGValue *> operands);
unsigned getNumOperands() const { return numOperands; }
ArrayRef<InstOperand> getInstOperands() const {
return {getTrailingObjects<InstOperand>(), numOperands};
}
MutableArrayRef<InstOperand> getInstOperands() {
return {getTrailingObjects<InstOperand>(), numOperands};
}
void destroy();
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const IROperandOwner *ptr) {
return ptr->getKind() == IROperandOwner::Kind::ReturnInst;
}
private:
// This stuff is used by the TrailingObjects template.
friend llvm::TrailingObjects<ReturnInst, InstOperand>;
size_t numTrailingObjects(OverloadToken<InstOperand>) const {
return numOperands;
}
ReturnInst(Location location, unsigned numOperands);
~ReturnInst();
unsigned numOperands;
};
} // end namespace mlir } // end namespace mlir
#endif // MLIR_IR_INSTRUCTIONS_H #endif // MLIR_IR_INSTRUCTIONS_H

View File

@ -80,9 +80,6 @@ public:
ForStmt, ForStmt,
IfStmt, IfStmt,
OperationInst, OperationInst,
BranchInst,
CondBranchInst,
ReturnInst,
/// These enums define ranges used for classof implementations. /// These enums define ranges used for classof implementations.
STMT_LAST = IfStmt, STMT_LAST = IfStmt,

View File

@ -49,12 +49,12 @@ bool DominanceInfo::properlyDominates(const Instruction *a,
return false; return false;
// If one is a terminator, then the other dominates it. // If one is a terminator, then the other dominates it.
auto *aOp = dyn_cast<OperationInst>(a); auto *aOp = cast<OperationInst>(a);
if (!aOp) if (aOp->isTerminator())
return false; return false;
auto *bOp = dyn_cast<OperationInst>(b); auto *bOp = cast<OperationInst>(b);
if (!bOp) if (bOp->isTerminator())
return true; return true;
// Otherwise, do a linear scan to determine whether B comes after A. // Otherwise, do a linear scan to determine whether B comes after A.

View File

@ -75,7 +75,7 @@ public:
// If the code is properly formed, there will be a terminator. Use its // If the code is properly formed, there will be a terminator. Use its
// location. // location.
if (auto *termInst = bb.getTerminator()) if (auto *termInst = bb.getTerminator())
return failure(message, *termInst); return (termInst->emitError(message), true);
// Worst case, fall back to using the function's location. // Worst case, fall back to using the function's location.
return failure(message, fn); return failure(message, fn);
@ -166,14 +166,7 @@ struct CFGFuncVerifier : public Verifier {
bool verify(); bool verify();
bool verifyBlock(const BasicBlock &block); bool verifyBlock(const BasicBlock &block);
bool verifyTerminator(const TerminatorInst &term);
bool verifyInstOperands(const Instruction &inst); bool verifyInstOperands(const Instruction &inst);
bool verifyBBArguments(ArrayRef<InstOperand> operands,
const BasicBlock *destBB, const TerminatorInst &term);
bool verifyReturn(const ReturnInst &inst);
bool verifyBranch(const BranchInst &inst);
bool verifyCondBranch(const CondBranchInst &inst);
}; };
} // end anonymous namespace } // end anonymous namespace
@ -237,7 +230,10 @@ bool CFGFuncVerifier::verifyBlock(const BasicBlock &block) {
if (!block.getTerminator()) if (!block.getTerminator())
return failure("basic block with no terminator", block); return failure("basic block with no terminator", block);
if (verifyTerminator(*block.getTerminator())) // TODO(riverriddle) Remove this when terminators are inside of the block
// operation list.
auto &term = *block.getTerminator();
if (verifyOperation(term) || verifyInstOperands(term))
return true; return true;
for (auto *arg : block.getArguments()) { for (auto *arg : block.getArguments()) {
@ -252,101 +248,6 @@ bool CFGFuncVerifier::verifyBlock(const BasicBlock &block) {
return false; return false;
} }
bool CFGFuncVerifier::verifyTerminator(const TerminatorInst &term) {
if (term.getFunction() != &fn)
return failure("terminator in the wrong function", term);
// Check that operands are non-nil and structurally ok.
for (const auto *operand : term.getOperands()) {
if (!operand)
return failure("null operand found", term);
if (operand->getFunction() != &fn)
return failure("reference to operand defined in another function", term);
}
// Verify dominance of values.
verifyInstOperands(term);
// Check that successors are in the right function.
for (auto *succ : term.getBlock()->getSuccessors()) {
if (succ->getFunction() != &fn)
return failure("reference to block defined in another function", term);
}
if (auto *ret = dyn_cast<ReturnInst>(&term))
return verifyReturn(*ret);
if (auto *br = dyn_cast<BranchInst>(&term))
return verifyBranch(*br);
if (auto *br = dyn_cast<CondBranchInst>(&term))
return verifyCondBranch(*br);
return false;
}
/// Check a set of basic block arguments against the expected list in in the
/// destination basic block.
bool CFGFuncVerifier::verifyBBArguments(ArrayRef<InstOperand> operands,
const BasicBlock *destBB,
const TerminatorInst &term) {
if (operands.size() != destBB->getNumArguments())
return failure("branch has " + Twine(operands.size()) +
" operands, but target block has " +
Twine(destBB->getNumArguments()),
term);
for (unsigned i = 0, e = operands.size(); i != e; ++i)
if (operands[i].get()->getType() != destBB->getArgument(i)->getType())
return failure("type mismatch in bb argument #" + Twine(i), term);
return false;
}
bool CFGFuncVerifier::verifyReturn(const ReturnInst &inst) {
// Verify that the return operands match the results of the function.
auto results = fn.getType().getResults();
if (inst.getNumOperands() != results.size())
return failure("return has " + Twine(inst.getNumOperands()) +
" operands, but enclosing function returns " +
Twine(results.size()),
inst);
for (unsigned i = 0, e = results.size(); i != e; ++i)
if (inst.getOperand(i)->getType() != results[i])
return failure("type of return operand " + Twine(i) +
" doesn't match function result type",
inst);
return false;
}
bool CFGFuncVerifier::verifyBranch(const BranchInst &inst) {
// Verify that the number of operands lines up with the number of BB arguments
// in the successor.
if (verifyBBArguments(inst.getInstOperands(), inst.getDest(), inst))
return true;
return false;
}
bool CFGFuncVerifier::verifyCondBranch(const CondBranchInst &inst) {
// Verify that the number of operands lines up with the number of BB arguments
// in the true successor.
if (verifyBBArguments(inst.getTrueInstOperands(), inst.getTrueDest(), inst))
return true;
// And the false successor.
if (verifyBBArguments(inst.getFalseInstOperands(), inst.getFalseDest(), inst))
return true;
if (inst.getCondition()->getType() != Type::getInteger(1, fn.getContext()))
return failure("type of condition is not boolean (i1)", inst);
return false;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ML Functions // ML Functions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -1164,9 +1164,6 @@ public:
void print(const Instruction *inst); void print(const Instruction *inst);
void print(const OperationInst *inst); void print(const OperationInst *inst);
void print(const ReturnInst *inst);
void print(const BranchInst *inst);
void print(const CondBranchInst *inst);
void printSuccessorAndUseList(const Operation *term, unsigned index); void printSuccessorAndUseList(const Operation *term, unsigned index);
@ -1282,12 +1279,6 @@ void CFGFunctionPrinter::print(const Instruction *inst) {
switch (inst->getKind()) { switch (inst->getKind()) {
case Instruction::Kind::Operation: case Instruction::Kind::Operation:
return print(cast<OperationInst>(inst)); return print(cast<OperationInst>(inst));
case TerminatorInst::Kind::Branch:
return print(cast<BranchInst>(inst));
case TerminatorInst::Kind::CondBranch:
return print(cast<CondBranchInst>(inst));
case TerminatorInst::Kind::Return:
return print(cast<ReturnInst>(inst));
} }
} }
@ -1319,40 +1310,6 @@ void CFGFunctionPrinter::printSuccessorAndUseList(const Operation *term,
printBranchOperands(term->getSuccessorOperands(index)); printBranchOperands(term->getSuccessorOperands(index));
} }
void CFGFunctionPrinter::print(const BranchInst *inst) {
os << "br ";
printBBName(inst->getDest());
printBranchOperands(inst->getOperands());
}
void CFGFunctionPrinter::print(const CondBranchInst *inst) {
os << "cond_br ";
printValueID(inst->getCondition());
os << ", ";
printBBName(inst->getTrueDest());
printBranchOperands(inst->getTrueOperands());
os << ", ";
printBBName(inst->getFalseDest());
printBranchOperands(inst->getFalseOperands());
}
void CFGFunctionPrinter::print(const ReturnInst *inst) {
os << "return";
if (inst->getNumOperands() == 0)
return;
os << ' ';
interleaveComma(inst->getOperands(),
[&](const CFGValue *operand) { printValueID(operand); });
os << " : ";
interleaveComma(inst->getOperands(), [&](const CFGValue *operand) {
printType(operand->getType());
});
}
void ModulePrinter::print(const CFGFunction *fn) { void ModulePrinter::print(const CFGFunction *fn) {
CFGFunctionPrinter(fn, *this).print(); CFGFunctionPrinter(fn, *this).print();
} }

View File

@ -16,6 +16,8 @@
// ============================================================================= // =============================================================================
#include "mlir/IR/BasicBlock.h" #include "mlir/IR/BasicBlock.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/CFGFunction.h" #include "mlir/IR/CFGFunction.h"
using namespace mlir; using namespace mlir;
@ -54,7 +56,7 @@ auto BasicBlock::addArguments(ArrayRef<Type> types)
// Terminator management // Terminator management
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void BasicBlock::setTerminator(TerminatorInst *inst) { void BasicBlock::setTerminator(OperationInst *inst) {
assert((!inst || !inst->block) && "terminator already inserted into a block"); assert((!inst || !inst->block) && "terminator already inserted into a block");
// If we already had a terminator, abandon it. // If we already had a terminator, abandon it.
if (terminator) if (terminator)
@ -161,9 +163,12 @@ BasicBlock *BasicBlock::splitBasicBlock(iterator splitBefore) {
// to the new block. // to the new block.
auto branchLoc = auto branchLoc =
splitBefore == end() ? getTerminator()->getLoc() : splitBefore->getLoc(); splitBefore == end() ? getTerminator()->getLoc() : splitBefore->getLoc();
// TODO(riverriddle) Remove this when terminators are a part of the operations
// list.
auto oldTerm = getTerminator(); auto oldTerm = getTerminator();
setTerminator(BranchInst::create(branchLoc, newBB)); setTerminator(nullptr);
newBB->setTerminator(oldTerm); newBB->setTerminator(oldTerm);
CFGFuncBuilder(this).create<BranchOp>(branchLoc, newBB);
// Move all of the operations from the split point to the end of the function // Move all of the operations from the split point to the end of the function
// into the new block. // into the new block.

View File

@ -19,6 +19,7 @@
#include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h" #include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IntegerSet.h" #include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Location.h" #include "mlir/IR/Location.h"
#include "mlir/IR/Module.h" #include "mlir/IR/Module.h"
@ -277,7 +278,15 @@ OperationInst *CFGFuncBuilder::createOperation(const OperationState &state) {
auto *op = auto *op =
OperationInst::create(state.location, state.name, operands, state.types, OperationInst::create(state.location, state.name, operands, state.types,
state.attributes, state.successors, context); state.attributes, state.successors, context);
block->getOperations().insert(insertPoint, op); // TODO(riverriddle) Remove this when the terminators are in basic block
// operation lists.
if (op->isTerminator()) {
// FIXME(b/118738403)
assert(!block->getTerminator() && "cannot insert the second terminator");
block->setTerminator(op);
} else {
block->getOperations().insert(insertPoint, op);
}
return op; return op;
} }

View File

@ -53,15 +53,6 @@ void Instruction::destroy() {
case Kind::Operation: case Kind::Operation:
cast<OperationInst>(this)->destroy(); cast<OperationInst>(this)->destroy();
break; break;
case Kind::Branch:
delete cast<BranchInst>(this);
break;
case Kind::CondBranch:
delete cast<CondBranchInst>(this);
break;
case Kind::Return:
cast<ReturnInst>(this)->destroy();
break;
} }
} }
@ -79,12 +70,6 @@ unsigned Instruction::getNumOperands() const {
switch (getKind()) { switch (getKind()) {
case Kind::Operation: case Kind::Operation:
return cast<OperationInst>(this)->getNumOperands(); return cast<OperationInst>(this)->getNumOperands();
case Kind::Branch:
return cast<BranchInst>(this)->getNumOperands();
case Kind::CondBranch:
return cast<CondBranchInst>(this)->getNumOperands();
case Kind::Return:
return cast<ReturnInst>(this)->getNumOperands();
} }
} }
@ -92,12 +77,6 @@ MutableArrayRef<InstOperand> Instruction::getInstOperands() {
switch (getKind()) { switch (getKind()) {
case Kind::Operation: case Kind::Operation:
return cast<OperationInst>(this)->getInstOperands(); return cast<OperationInst>(this)->getInstOperands();
case Kind::Branch:
return cast<BranchInst>(this)->getInstOperands();
case Kind::CondBranch:
return cast<CondBranchInst>(this)->getInstOperands();
case Kind::Return:
return cast<ReturnInst>(this)->getInstOperands();
} }
} }
@ -108,10 +87,6 @@ void Instruction::dropAllReferences() {
for (auto &op : getInstOperands()) for (auto &op : getInstOperands())
op.drop(); op.drop();
if (auto *term = dyn_cast<TerminatorInst>(this))
for (auto &dest : term->getBasicBlockOperands())
dest.drop();
if (OperationInst *opInst = dyn_cast<OperationInst>(this)) { if (OperationInst *opInst = dyn_cast<OperationInst>(this)) {
if (opInst->isTerminator()) if (opInst->isTerminator())
for (auto &dest : opInst->getBasicBlockOperands()) for (auto &dest : opInst->getBasicBlockOperands())
@ -348,7 +323,14 @@ void llvm::ilist_traits<::mlir::OperationInst>::transferNodesFromList(
/// Unlink this instruction from its BasicBlock and delete it. /// Unlink this instruction from its BasicBlock and delete it.
void OperationInst::erase() { void OperationInst::erase() {
assert(getBlock() && "Instruction has no parent"); assert(getBlock() && "Instruction has no parent");
getBlock()->getOperations().erase(this); // TODO(riverriddle) Remove this when terminators are a part of the operations
// list.
if (isTerminator()) {
getBlock()->setTerminator(nullptr);
destroy();
} else {
getBlock()->getOperations().erase(this);
}
} }
/// Unlink this operation instruction from its current basic block and insert /// Unlink this operation instruction from its current basic block and insert
@ -356,6 +338,12 @@ void OperationInst::erase() {
/// in the same function. /// in the same function.
void OperationInst::moveBefore(OperationInst *existingInst) { void OperationInst::moveBefore(OperationInst *existingInst) {
assert(existingInst && "Cannot move before a null instruction"); assert(existingInst && "Cannot move before a null instruction");
// TODO(riverriddle) Remove this when terminators are a part of the operations
// list.
if (existingInst->isTerminator()) {
return moveBefore(existingInst->getBlock(),
existingInst->getBlock()->end());
}
return moveBefore(existingInst->getBlock(), existingInst->getIterator()); return moveBefore(existingInst->getBlock(), existingInst->getIterator());
} }
@ -366,126 +354,3 @@ void OperationInst::moveBefore(BasicBlock *block,
block->getOperations().splice(iterator, getBlock()->getOperations(), block->getOperations().splice(iterator, getBlock()->getOperations(),
getIterator()); getIterator());
} }
//===----------------------------------------------------------------------===//
// TerminatorInst
//===----------------------------------------------------------------------===//
/// Remove this terminator from its BasicBlock and delete it.
void TerminatorInst::erase() {
assert(getBlock() && "Instruction has no parent");
getBlock()->setTerminator(nullptr);
destroy();
}
/// Return the list of destination entries that this terminator branches to.
MutableArrayRef<BasicBlockOperand> TerminatorInst::getBasicBlockOperands() {
switch (getKind()) {
case Kind::Operation:
llvm_unreachable("not a terminator");
case Kind::Branch:
return cast<BranchInst>(this)->getBasicBlockOperands();
case Kind::CondBranch:
return cast<CondBranchInst>(this)->getBasicBlockOperands();
case Kind::Return:
// Return has no basic block successors.
return {};
}
}
//===----------------------------------------------------------------------===//
// ReturnInst
//===----------------------------------------------------------------------===//
/// Create a new OperationInst with the specific fields.
ReturnInst *ReturnInst::create(Location location,
ArrayRef<CFGValue *> operands) {
auto byteSize = totalSizeToAlloc<InstOperand>(operands.size());
void *rawMem = malloc(byteSize);
// Initialize the ReturnInst part of the instruction.
auto inst = ::new (rawMem) ReturnInst(location, operands.size());
// Initialize the operands and results.
auto instOperands = inst->getInstOperands();
for (unsigned i = 0, e = operands.size(); i != e; ++i)
new (&instOperands[i]) InstOperand(inst, operands[i]);
return inst;
}
ReturnInst::ReturnInst(Location location, unsigned numOperands)
: TerminatorInst(Kind::Return, location), numOperands(numOperands) {}
void ReturnInst::destroy() {
this->~ReturnInst();
free(this);
}
ReturnInst::~ReturnInst() {
// Explicitly run the destructors for the operands.
for (auto &operand : getInstOperands())
operand.~InstOperand();
}
//===----------------------------------------------------------------------===//
// BranchInst
//===----------------------------------------------------------------------===//
BranchInst::BranchInst(Location location, BasicBlock *dest,
ArrayRef<CFGValue *> operands)
: TerminatorInst(Kind::Branch, location), dest(this, dest) {
addOperands(operands);
}
void BranchInst::setDest(BasicBlock *block) { dest.set(block); }
/// Add one value to the operand list.
void BranchInst::addOperand(CFGValue *value) {
operands.emplace_back(InstOperand(this, value));
}
/// Add a list of values to the operand list.
void BranchInst::addOperands(ArrayRef<CFGValue *> values) {
operands.reserve(operands.size() + values.size());
for (auto *value : values)
addOperand(value);
}
//===----------------------------------------------------------------------===//
// CondBranchInst
//===----------------------------------------------------------------------===//
CondBranchInst::CondBranchInst(Location location, CFGValue *condition,
BasicBlock *trueDest, BasicBlock *falseDest)
: TerminatorInst(Kind::CondBranch, location),
condition(condition), dests{{this}, {this}}, numTrueOperands(0) {
dests[falseIndex].set(falseDest);
dests[trueIndex].set(trueDest);
}
/// Add one value to the true operand list.
void CondBranchInst::addTrueOperand(CFGValue *value) {
assert(getNumFalseOperands() == 0 &&
"Must insert all true operands before false operands!");
operands.emplace_back(InstOperand(this, value));
++numTrueOperands;
}
/// Add a list of values to the true operand list.
void CondBranchInst::addTrueOperands(ArrayRef<CFGValue *> values) {
operands.reserve(operands.size() + values.size());
for (auto *value : values)
addTrueOperand(value);
}
/// Add one value to the false operand list.
void CondBranchInst::addFalseOperand(CFGValue *value) {
operands.emplace_back(InstOperand(this, value));
}
/// Add a list of values to the false operand list.
void CondBranchInst::addFalseOperands(ArrayRef<CFGValue *> values) {
operands.reserve(operands.size() + values.size());
for (auto *value : values)
addFalseOperand(value);
}

View File

@ -542,7 +542,10 @@ bool OpTrait::impl::verifyIsTerminator(const Operation *op) {
} else { } else {
const OperationInst *inst = cast<OperationInst>(op); const OperationInst *inst = cast<OperationInst>(op);
const BasicBlock *block = inst->getBlock(); const BasicBlock *block = inst->getBlock();
if (!block || &block->back() != inst) // TODO(riverriddle) Check the instruction at the back of the block when
// terminators are in the operations list.
// if (!block || &block->back() != inst)
if (!block || block->getTerminator() != inst)
return op->emitOpError( return op->emitOpError(
"must be the last instruction in the parent basic block."); "must be the last instruction in the parent basic block.");
} }

View File

@ -78,9 +78,6 @@ MLIRContext *IROperandOwner::getContext() const {
return cast<IfStmt>(this)->getContext(); return cast<IfStmt>(this)->getContext();
case Kind::OperationInst: case Kind::OperationInst:
case Kind::BranchInst:
case Kind::CondBranchInst:
case Kind::ReturnInst:
// If we have an instruction, we can efficiently get this from the function // If we have an instruction, we can efficiently get this from the function
// the instruction is in. // the instruction is in.
auto *fn = cast<Instruction>(this)->getFunction(); auto *fn = cast<Instruction>(this)->getFunction();

View File

@ -2133,7 +2133,10 @@ FunctionParser::parseOperation(const CreateOperationFunction &createOpFunc) {
// is structurally as we expect. If not, produce an error with a reasonable // is structurally as we expect. If not, produce an error with a reasonable
// source location. // source location.
if (auto *opInfo = op->getAbstractOperation()) { if (auto *opInfo = op->getAbstractOperation()) {
if (opInfo->verifyInvariants(op)) // We don't wan't to verify branching terminators at this time because
// the successors may not have been fully parsed yet.
if (!(op->isTerminator() && op->getNumSuccessors() != 0) &&
opInfo->verifyInvariants(op))
return ParseFailure; return ParseFailure;
} }
@ -2528,11 +2531,8 @@ private:
ParseResult ParseResult
parseOptionalBasicBlockArgList(SmallVectorImpl<BBArgument *> &results, parseOptionalBasicBlockArgList(SmallVectorImpl<BBArgument *> &results,
BasicBlock *owner); BasicBlock *owner);
ParseResult parseBranchBlockAndUseList(BasicBlock *&block,
SmallVectorImpl<CFGValue *> &values);
ParseResult parseBasicBlock(); ParseResult parseBasicBlock();
TerminatorInst *parseTerminator();
}; };
} // end anonymous namespace } // end anonymous namespace
@ -2610,6 +2610,15 @@ ParseResult CFGFunctionParser::parseFunctionBody() {
return ParseFailure; return ParseFailure;
} }
// Now that the function body has been fully parsed we check the invariants
// of any branching terminators.
for (auto &block : *function) {
auto *term = block.getTerminator();
auto *abstractOp = term->getAbstractOperation();
if (term->getNumSuccessors() != 0 && abstractOp)
abstractOp->verifyInvariants(term);
}
return finalizeFunction(function, braceLoc); return finalizeFunction(function, braceLoc);
} }
@ -2657,101 +2666,13 @@ ParseResult CFGFunctionParser::parseBasicBlock() {
return ParseFailure; return ParseFailure;
} }
if (!parseTerminator()) // Parse the terminator operation.
if (parseOperation(createOpFunc))
return ParseFailure; return ParseFailure;
return ParseSuccess; return ParseSuccess;
} }
ParseResult CFGFunctionParser::parseBranchBlockAndUseList(
BasicBlock *&block, SmallVectorImpl<CFGValue *> &values) {
// Verify branch is identifier and get the matching block.
if (!getToken().is(Token::bare_identifier))
return emitError("expected basic block name");
block = getBlockNamed(getTokenSpelling(), getToken().getLoc());
consumeToken();
// Handle optional arguments.
if (consumeIf(Token::l_paren) &&
(parseOptionalSSAUseAndTypeList(values) ||
parseToken(Token::r_paren, "expected ')' to close argument list"))) {
return ParseFailure;
}
return ParseSuccess;
}
/// Parse the terminator instruction for a basic block.
///
/// terminator-stmt ::= `br` bb-id branch-use-list?
/// branch-use-list ::= `(` ssa-use-list ':' type-list-no-parens `)`
/// terminator-stmt ::=
/// `cond_br` ssa-use `,` bb-id branch-use-list? `,` bb-id
/// branch-use-list?
/// terminator-stmt ::= `return` ssa-use-and-type-list?
///
TerminatorInst *CFGFunctionParser::parseTerminator() {
auto loc = getToken().getLoc();
switch (getToken().getKind()) {
default:
return (emitError("expected terminator at end of basic block"), nullptr);
case Token::kw_return: {
consumeToken(Token::kw_return);
// Parse any operands.
SmallVector<CFGValue *, 8> operands;
if (parseOptionalSSAUseAndTypeList(operands))
return nullptr;
return builder.createReturn(getEncodedSourceLocation(loc), operands);
}
case Token::kw_br: {
consumeToken(Token::kw_br);
BasicBlock *destBB;
SmallVector<CFGValue *, 4> values;
if (parseBranchBlockAndUseList(destBB, values))
return nullptr;
auto branch = builder.createBranch(getEncodedSourceLocation(loc), destBB);
branch->addOperands(values);
return branch;
}
case Token::kw_cond_br: {
consumeToken(Token::kw_cond_br);
SSAUseInfo ssaUse;
if (parseSSAUse(ssaUse))
return nullptr;
auto *cond = resolveSSAUse(ssaUse, builder.getIntegerType(1));
if (!cond)
return (emitError("expected type was boolean (i1)"), nullptr);
if (parseToken(Token::comma, "expected ',' in conditional branch"))
return nullptr;
BasicBlock *trueBlock;
SmallVector<CFGValue *, 4> trueOperands;
if (parseBranchBlockAndUseList(trueBlock, trueOperands))
return nullptr;
if (parseToken(Token::comma, "expected ',' in conditional branch"))
return nullptr;
BasicBlock *falseBlock;
SmallVector<CFGValue *, 4> falseOperands;
if (parseBranchBlockAndUseList(falseBlock, falseOperands))
return nullptr;
auto branch =
builder.createCondBranch(getEncodedSourceLocation(loc),
cast<CFGValue>(cond), trueBlock, falseBlock);
branch->addTrueOperands(trueOperands);
branch->addFalseOperands(falseOperands);
return branch;
}
}
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ML Functions // ML Functions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -62,19 +62,13 @@ private:
}; };
} // end anonymous namespace } // end anonymous namespace
// Return a vector of OperationStmt's arguments as the CFGValues or SSAValues // Return a vector of OperationStmt's arguments as SSAValues. For each
// depending on the template argument. For each statement operands, represented // statement operands, represented as MLValue, lookup its CFGValue conterpart in
// as MLValue, lookup its CFGValue conterpart in the valueRemapping table. // the valueRemapping table.
// The return type parameterization is necessary because some instructions static llvm::SmallVector<SSAValue *, 4>
// accept vectors of SSAValues while others accept vectors of CFGValues.
template <typename SSAValueTy>
static llvm::SmallVector<SSAValueTy *, 4>
operandsAs(OperationStmt *opStmt, operandsAs(OperationStmt *opStmt,
const llvm::DenseMap<const MLValue *, CFGValue *> &valueRemapping) { const llvm::DenseMap<const MLValue *, CFGValue *> &valueRemapping) {
static_assert(std::is_same<SSAValueTy, SSAValue>::value || llvm::SmallVector<SSAValue *, 4> operands;
std::is_same<SSAValueTy, CFGValue>::value,
"can only cast statement operands to CFGValue or SSAValue");
llvm::SmallVector<SSAValueTy *, 4> operands;
for (const MLValue *operand : opStmt->getOperands()) { for (const MLValue *operand : opStmt->getOperands()) {
assert(valueRemapping.count(operand) != 0 && "operand is not defined"); assert(valueRemapping.count(operand) != 0 && "operand is not defined");
operands.push_back(valueRemapping.lookup(operand)); operands.push_back(valueRemapping.lookup(operand));
@ -89,20 +83,10 @@ operandsAs(OperationStmt *opStmt,
// mapping MLValue->CFGValue as the conversion is performed. The operation // mapping MLValue->CFGValue as the conversion is performed. The operation
// instruction is appended to current block (end of SESE region). // instruction is appended to current block (end of SESE region).
void FunctionConverter::visitOperationStmt(OperationStmt *opStmt) { void FunctionConverter::visitOperationStmt(OperationStmt *opStmt) {
// Handle returns separately, they are transformed into a specially-typed
// return instruction.
// TODO(zinenko): after terminators and operations are merged, remove this
// special case and de-template operandsAs.
if (opStmt->getName().getStringRef() == ReturnOp::getOperationName()) {
builder.createReturn(opStmt->getLoc(),
operandsAs<CFGValue>(opStmt, valueRemapping));
return;
}
// Set up basic operation state (context, name, operands). // Set up basic operation state (context, name, operands).
OperationState state(cfgFunc->getContext(), opStmt->getLoc(), OperationState state(cfgFunc->getContext(), opStmt->getLoc(),
opStmt->getName()); opStmt->getName());
state.addOperands(operandsAs<SSAValue>(opStmt, valueRemapping)); state.addOperands(operandsAs(opStmt, valueRemapping));
// Set up operation return types. The corresponding SSAValues will become // Set up operation return types. The corresponding SSAValues will become
// available after the operation is created. // available after the operation is created.
@ -206,7 +190,7 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
// At the loop insertion location, branch immediately to the loop init block. // At the loop insertion location, branch immediately to the loop init block.
builder.setInsertionPoint(loopInsertionPoint); builder.setInsertionPoint(loopInsertionPoint);
builder.createBranch(builder.getUnknownLoc(), loopInitBlock); builder.create<BranchOp>(builder.getUnknownLoc(), loopInitBlock);
// The loop condition block has an argument for loop induction variable. // The loop condition block has an argument for loop induction variable.
// Create it upfront and make the loop induction variable -> basic block // Create it upfront and make the loop induction variable -> basic block
@ -230,8 +214,8 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
CFGValue *step = getConstantIndexValue(forStmt->getStep()); CFGValue *step = getConstantIndexValue(forStmt->getStep());
auto stepOp = builder.create<AddIOp>(forStmt->getLoc(), iv, step); auto stepOp = builder.create<AddIOp>(forStmt->getLoc(), iv, step);
CFGValue *nextIvValue = cast<CFGValue>(stepOp->getResult()); CFGValue *nextIvValue = cast<CFGValue>(stepOp->getResult());
builder.createBranch(builder.getUnknownLoc(), loopConditionBlock, builder.create<BranchOp>(builder.getUnknownLoc(), loopConditionBlock,
{nextIvValue}); nextIvValue);
// Create post-loop block here so that it appears after all loop body blocks. // Create post-loop block here so that it appears after all loop body blocks.
BasicBlock *postLoopBlock = builder.createBlock(); BasicBlock *postLoopBlock = builder.createBlock();
@ -243,15 +227,15 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
getConstantIndexValue(forStmt->getConstantLowerBound()); getConstantIndexValue(forStmt->getConstantLowerBound());
CFGValue *upperBound = CFGValue *upperBound =
getConstantIndexValue(forStmt->getConstantUpperBound()); getConstantIndexValue(forStmt->getConstantUpperBound());
builder.createBranch(builder.getUnknownLoc(), loopConditionBlock, builder.create<BranchOp>(builder.getUnknownLoc(), loopConditionBlock,
{lowerBound}); lowerBound);
builder.setInsertionPoint(loopConditionBlock); builder.setInsertionPoint(loopConditionBlock);
auto comparisonOp = builder.create<CmpIOp>( auto comparisonOp = builder.create<CmpIOp>(
forStmt->getLoc(), CmpIPredicate::SLT, iv, upperBound); forStmt->getLoc(), CmpIPredicate::SLT, iv, upperBound);
auto comparisonResult = cast<CFGValue>(comparisonOp->getResult()); auto comparisonResult = cast<CFGValue>(comparisonOp->getResult());
builder.createCondBranch(builder.getUnknownLoc(), comparisonResult, builder.create<CondBranchOp>(builder.getUnknownLoc(), comparisonResult,
loopBodyFirstBlock, postLoopBlock); loopBodyFirstBlock, postLoopBlock);
// Finally, make sure building can continue by setting the post-loop block // Finally, make sure building can continue by setting the post-loop block
// (end of loop SESE region) as the insertion point. // (end of loop SESE region) as the insertion point.

View File

@ -338,7 +338,7 @@ mlfunc @malformed_type(%a : intt) { // expected-error {{expected type}}
cfgfunc @resulterror() -> i32 { cfgfunc @resulterror() -> i32 {
bb42: bb42:
return // expected-error {{return has 0 operands, but enclosing function returns 1}} return // expected-error {{'return' op has 0 operands, but enclosing function returns 1}}
} }
// ----- // -----
@ -387,7 +387,7 @@ cfgfunc @condbr_notbool() {
bb0: bb0:
%a = "foo"() : () -> i32 // expected-error {{prior use here}} %a = "foo"() : () -> i32 // expected-error {{prior use here}}
cond_br %a, bb0, bb0 // expected-error {{use of value '%a' expects different type than prior uses}} cond_br %a, bb0, bb0 // expected-error {{use of value '%a' expects different type than prior uses}}
// expected-error@-1 {{expected type was boolean (i1)}} // expected-error@-1 {{expected condition type was boolean (i1)}}
} }
// ----- // -----