Fix segfaults when printing unlinked statements, instructions and blocks. Fancy printing requires a pointer to the function since SSA values get function-specific names. This CL adds checks to ensure that we don't dereference null pointers in unliked objects. Unlinked statements, instructions and blocks are printed as <<UNLINKED STATEMENT>> etc.

PiperOrigin-RevId: 207293992
This commit is contained in:
Tatiana Shpeisman 2018-08-03 11:12:34 -07:00 committed by jpienaar
parent b4dea892f2
commit 2dcdec8910
5 changed files with 28 additions and 7 deletions

View File

@ -45,7 +45,6 @@ public:
};
Kind getKind() const { return kind; }
/// Remove this statement from its block and delete it.
void eraseFromBlock();
@ -58,6 +57,7 @@ public:
/// Returns the function that this statement is part of.
/// The function is determined by traversing the chain of parent statements.
/// Returns nullptr if the statement is unlinked.
MLFunction *findFunction() const;
/// Returns true if there are no more loops nested under this stmt.

View File

@ -582,6 +582,7 @@ public:
}
void printOperand(const SSAValue *value) { printValueID(value); }
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<const char *> elidedAttrs = {}) override;
@ -1182,6 +1183,10 @@ void SSAValue::print(raw_ostream &os) const {
void SSAValue::dump() const { print(llvm::errs()); }
void Instruction::print(raw_ostream &os) const {
if (!getFunction()) {
os << "<<UNLINKED INSTRUCTION>>\n";
return;
}
ModuleState state(getFunction()->getContext());
ModulePrinter modulePrinter(os, state);
CFGFunctionPrinter(getFunction(), modulePrinter).print(this);
@ -1193,6 +1198,10 @@ void Instruction::dump() const {
}
void BasicBlock::print(raw_ostream &os) const {
if (!getFunction()) {
os << "<<UNLINKED BLOCK>>\n";
return;
}
ModuleState state(getFunction()->getContext());
ModulePrinter modulePrinter(os, state);
CFGFunctionPrinter(getFunction(), modulePrinter).print(this);
@ -1202,6 +1211,11 @@ void BasicBlock::dump() const { print(llvm::errs()); }
void Statement::print(raw_ostream &os) const {
MLFunction *function = findFunction();
if (!function) {
os << "<<UNLINKED STATEMENT>>\n";
return;
}
ModuleState state(function->getContext());
ModulePrinter modulePrinter(os, state);
MLFunctionPrinter(function, modulePrinter).print(this);

View File

@ -71,11 +71,13 @@ void OperationInst::destroy() {
/// Return the context this operation is associated with.
MLIRContext *Instruction::getContext() const {
return getFunction()->getContext();
auto *fn = getFunction();
return fn ? fn->getContext() : nullptr;
}
CFGFunction *Instruction::getFunction() const {
return getBlock()->getFunction();
auto *block = getBlock();
return block ? block->getFunction() : nullptr;
}
unsigned Instruction::getNumOperands() const {

View File

@ -57,10 +57,12 @@ void Statement::destroy() {
}
}
Statement *Statement::getParentStmt() const { return block->getParentStmt(); }
Statement *Statement::getParentStmt() const {
return block ? block->getParentStmt() : nullptr;
}
MLFunction *Statement::findFunction() const {
return this->getBlock()->findFunction();
return block ? block->findFunction() : nullptr;
}
bool Statement::isInnermost() const {

View File

@ -38,7 +38,10 @@ Statement *StmtBlock::getParentStmt() const {
MLFunction *StmtBlock::findFunction() const {
StmtBlock *block = const_cast<StmtBlock *>(this);
while (block->getParentStmt() != nullptr)
while (block->getParentStmt()) {
block = block->getParentStmt()->getBlock();
return static_cast<MLFunction *>(block);
if (!block)
return nullptr;
}
return dyn_cast<MLFunction>(block);
}