[mlir] Add basic block arguments

This patch adds support for basic block arguments including parsing and printing.

In doing so noticed that `ssa-id-and-type` is undefined in the MLIR spec; suggested an implementation in the spec doc.

PiperOrigin-RevId: 205593369
This commit is contained in:
James Molloy 2018-07-22 15:45:24 -07:00 committed by jpienaar
parent e402dcc47f
commit 4144c302db
10 changed files with 212 additions and 28 deletions

View File

@ -22,6 +22,7 @@
#include <memory>
namespace mlir {
class BBArgument;
/// Each basic block in a CFG function contains a list of basic block arguments,
/// normal instructions, and a terminator instruction.
@ -39,11 +40,33 @@ public:
return function;
}
// TODO: bb arguments
/// Unlink this BasicBlock from its CFGFunction and delete it.
void eraseFromFunction();
//===--------------------------------------------------------------------===//
// Block arguments management
//===--------------------------------------------------------------------===//
// This is the list of arguments to the block.
typedef ArrayRef<BBArgument *> BBArgListType;
BBArgListType getArguments() const { return arguments; }
using args_iterator = BBArgListType::iterator;
using reverse_args_iterator = BBArgListType::reverse_iterator;
args_iterator args_begin() const { return getArguments().begin(); }
args_iterator args_end() const { return getArguments().end(); }
reverse_args_iterator args_rbegin() const { return getArguments().rbegin(); }
reverse_args_iterator args_rend() const { return getArguments().rend(); }
bool args_empty() const { return arguments.empty(); }
BBArgument *addArgument(Type *type);
llvm::iterator_range<BBArgListType::iterator>
addArguments(ArrayRef<Type *> types);
unsigned getNumArguments() const { return arguments.size(); }
BBArgument *getArgument(unsigned i) { return arguments[i]; }
const BBArgument *getArgument(unsigned i) const { return arguments[i]; }
//===--------------------------------------------------------------------===//
// Operation list management
//===--------------------------------------------------------------------===//
@ -105,6 +128,9 @@ private:
/// This is the list of operations in the block.
OperationListType operations;
/// This is the list of arguments to the block.
std::vector<BBArgument *> arguments;
/// This is the owning reference to the terminator of the block.
TerminatorInst *terminator = nullptr;

View File

@ -26,6 +26,7 @@
#include "mlir/IR/SSAValue.h"
namespace mlir {
class BasicBlock;
class CFGValue;
class Instruction;
@ -33,7 +34,7 @@ class Instruction;
/// function. This should be kept as a proper subtype of SSAValueKind,
/// including having all of the values of the enumerators align.
enum class CFGValueKind {
// TODO: BBArg,
BBArgument = (int)SSAValueKind::BBArgument,
InstResult = (int)SSAValueKind::InstResult,
};
@ -45,6 +46,7 @@ class CFGValue : public SSAValueImpl<InstOperand, CFGValueKind> {
public:
static bool classof(const SSAValue *value) {
switch (value->getKind()) {
case SSAValueKind::BBArgument:
case SSAValueKind::InstResult:
return true;
}
@ -54,6 +56,27 @@ protected:
CFGValue(CFGValueKind kind, Type *type) : SSAValueImpl(kind, type) {}
};
/// Basic block arguments are CFG Values.
class BBArgument : public CFGValue {
public:
static bool classof(const SSAValue *value) {
return value->getKind() == SSAValueKind::BBArgument;
}
BasicBlock *getOwner() { return owner; }
const BasicBlock *getOwner() const { return owner; }
private:
friend class BasicBlock; // For access to private constructor.
BBArgument(Type *type, BasicBlock *owner)
: CFGValue(CFGValueKind::BBArgument, type), owner(owner) {}
/// The owner of this operand.
/// TODO: can encode this more efficiently to avoid the space hit of this
/// through bitpacking shenanigans.
BasicBlock *const owner;
};
/// Instruction results are CFG Values.
class InstResult : public CFGValue {
public:

View File

@ -34,7 +34,7 @@ template <typename OperandType, typename OwnerType> class SSAValueUseIterator;
/// This enumerates all of the SSA value kinds in the MLIR system.
enum class SSAValueKind {
// TODO: BBArg,
BBArgument,
InstResult,
// FnArg

View File

@ -665,7 +665,9 @@ CFGFunctionPrinter::CFGFunctionPrinter(const CFGFunction *function,
/// Number all of the SSA values in the specified basic block.
void CFGFunctionPrinter::numberValuesInBlock(const BasicBlock *block) {
// TODO: basic block arguments.
for (auto *arg : block->getArguments()) {
numberValueID(arg);
}
for (auto &op : *block) {
// We number instruction that have results, and we only number the first
// result.
@ -686,16 +688,26 @@ void CFGFunctionPrinter::print() {
}
void CFGFunctionPrinter::print(const BasicBlock *block) {
os << "bb" << getBBID(block) << ":\n";
os << "bb" << getBBID(block);
if (!block->args_empty()) {
os << '(';
interleaveComma(block->getArguments(), [&](const BBArgument *arg) {
printValueID(arg);
os << ": ";
ModulePrinter::print(arg->getType());
});
os << ')';
}
os << ":\n";
// TODO Print arguments.
for (auto &inst : block->getOperations()) {
print(&inst);
os << "\n";
os << '\n';
}
print(block->getTerminator());
os << "\n";
os << '\n';
}
void CFGFunctionPrinter::print(const Instruction *inst) {

View File

@ -19,12 +19,14 @@
#include "mlir/IR/CFGFunction.h"
using namespace mlir;
BasicBlock::BasicBlock() {
}
BasicBlock::BasicBlock() {}
BasicBlock::~BasicBlock() {
if (terminator)
terminator->eraseFromBlock();
for (BBArgument *arg : arguments)
delete arg;
arguments.clear();
}
/// Unlink this BasicBlock from its CFGFunction and delete it.
@ -84,3 +86,17 @@ transferNodesFromList(ilist_traits<BasicBlock> &otherList,
for (; first != last; ++first)
first->function = curParent;
}
BBArgument *BasicBlock::addArgument(Type *type) {
arguments.push_back(new BBArgument(type, this));
return arguments.back();
}
llvm::iterator_range<BasicBlock::BBArgListType::iterator>
BasicBlock::addArguments(ArrayRef<Type *> types) {
auto initial_size = arguments.size();
for (auto *type : types) {
addArgument(type);
}
return {arguments.data() + initial_size, arguments.data() + arguments.size()};
}

View File

@ -106,6 +106,26 @@ public:
bool CFGFuncVerifier::verify() {
// TODO: Lots to be done here, including verifying dominance information when
// we have uses and defs.
// TODO: Verify the first block has no predecessors.
if (fn.empty())
return failure("cfgfunc must have at least one basic block", fn);
// Verify that the argument list of the function and the arg list of the first
// block line up.
auto *firstBB = &fn.front();
auto fnInputTypes = fn.getType()->getInputs();
if (fnInputTypes.size() != firstBB->getNumArguments())
return failure("first block of cfgfunc must have " +
Twine(fnInputTypes.size()) +
" arguments to match function signature",
fn);
for (unsigned i = 0, e = firstBB->getNumArguments(); i != e; ++i)
if (fnInputTypes[i] != firstBB->getArgument(i)->getType())
return failure(
"type of argument #" + Twine(i) +
" must match corresponding argument in function signature",
fn);
for (auto &block : fn) {
if (verifyBlock(block))
@ -121,6 +141,11 @@ bool CFGFuncVerifier::verifyBlock(const BasicBlock &block) {
if (verifyTerminator(*block.getTerminator()))
return true;
for (auto *arg : block.getArguments()) {
if (arg->getOwner() != &block)
return failure("basic block argument not owned by block", block);
}
for (auto &inst : block) {
if (verifyOperation(inst))
return true;

View File

@ -1219,7 +1219,17 @@ public:
// SSA parsing productions.
ParseResult parseSSAUse(SSAUseInfo &result);
ParseResult parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results);
SSAValue *parseSSAUseAndType();
template <typename ResultType>
ResultType parseSSADefOrUseAndType(
const std::function<ResultType(SSAUseInfo, Type *)> &action);
SSAValue *parseSSAUseAndType() {
return parseSSADefOrUseAndType<SSAValue *>(
[&](SSAUseInfo useInfo, Type *type) -> SSAValue * {
return resolveSSAUse(useInfo, type);
});
}
template <typename ValueTy>
ParseResult
@ -1355,8 +1365,7 @@ ParseResult FunctionParser::finalizeFunction(Function *func, SMLoc loc) {
/// Parse a SSA operand for an instruction or statement.
///
/// ssa-use ::= ssa-id | ssa-constant
/// TODO: SSA Constants.
/// ssa-use ::= ssa-id
///
ParseResult FunctionParser::parseSSAUse(SSAUseInfo &result) {
result.name = getTokenSpelling();
@ -1398,7 +1407,9 @@ FunctionParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) {
/// Parse an SSA use with an associated type.
///
/// ssa-use-and-type ::= ssa-use `:` type
SSAValue *FunctionParser::parseSSAUseAndType() {
template <typename ResultType>
ResultType FunctionParser::parseSSADefOrUseAndType(
const std::function<ResultType(SSAUseInfo, Type *)> &action) {
SSAUseInfo useInfo;
if (parseSSAUse(useInfo))
return nullptr;
@ -1410,7 +1421,7 @@ SSAValue *FunctionParser::parseSSAUseAndType() {
if (!type)
return nullptr;
return resolveSSAUse(useInfo, type);
return action(useInfo, type);
}
/// Parse a (possibly empty) list of SSA operands with types.
@ -1570,12 +1581,39 @@ private:
return blockAndLoc.first;
}
ParseResult
parseOptionalBasicBlockArgList(SmallVectorImpl<BBArgument *> &results,
BasicBlock *owner);
ParseResult parseBasicBlock();
OperationInst *parseCFGOperation();
TerminatorInst *parseTerminator();
};
} // end anonymous namespace
/// Parse a (possibly empty) list of SSA operands with types as basic block
/// arguments. Unlike parseOptionalSsaUseAndTypeList the SSA IDs are treated as
/// defs, not uses.
///
/// ssa-id-and-type-list ::= ssa-id-and-type (`,` ssa-id-and-type)*
///
ParseResult CFGFunctionParser::parseOptionalBasicBlockArgList(
SmallVectorImpl<BBArgument *> &results, BasicBlock *owner) {
if (getToken().is(Token::r_brace))
return ParseSuccess;
return parseCommaSeparatedList([&]() -> ParseResult {
auto type = parseSSADefOrUseAndType<Type *>(
[&](SSAUseInfo useInfo, Type *type) -> Type * {
BBArgument *arg = owner->addArgument(type);
if (addDefinition(useInfo, arg) == ParseFailure)
return nullptr;
return type;
});
return type ? ParseSuccess : ParseFailure;
});
}
ParseResult CFGFunctionParser::parseFunctionBody() {
auto braceLoc = getToken().getLoc();
if (!consumeIf(Token::l_brace))
@ -1625,20 +1663,18 @@ ParseResult CFGFunctionParser::parseBasicBlock() {
if (block->getFunction())
return emitError(nameLoc, "redefinition of block '" + name.str() + "'");
// Add the block to the function.
function->push_back(block);
// If an argument list is present, parse it.
if (consumeIf(Token::l_paren)) {
SmallVector<SSAUseInfo, 8> bbArgs;
if (parseOptionalSSAUseList(bbArgs))
SmallVector<BBArgument *, 8> bbArgs;
if (parseOptionalBasicBlockArgList(bbArgs, block))
return ParseFailure;
if (!consumeIf(Token::r_paren))
return emitError("expected ')' to end argument list");
// TODO: attach it.
}
// Add the block to the function.
function->push_back(block);
if (!consumeIf(Token::colon))
return emitError("expected ':' after basic block name");

View File

@ -91,6 +91,27 @@ bb42: // expected-error {{expected operation name}}
// -----
cfgfunc @block_no_rparen() {
bb42 (%bb42 : i32: // expected-error {{expected ')' to end argument list}}
return
}
// -----
cfgfunc @block_arg_no_ssaid() {
bb42 (i32): // expected-error {{expected SSA operand}}
return
}
// -----
cfgfunc @block_arg_no_type() {
bb42 (%0): // expected-error {{expected ':' and type for SSA operand}}
return
}
// -----
mlfunc @foo()
mlfunc @bar() // expected-error {{expected '{' in ML function}}
@ -208,3 +229,17 @@ bb42:
}
// -----
cfgfunc @argError() {
bb1(%a: i64): // expected-error {{previously defined here}}
br bb2
bb2(%a: i64): // expected-error{{redefinition of SSA value '%a'}}
return
}
// -----
cfgfunc @bbargMismatch(i32, f32) { // expected-error {{first block of cfgfunc must have 2 arguments to match function signature}}
bb42(%0: f32):
return
}

View File

@ -68,17 +68,28 @@ extfunc @functions((memref<1x?x4x?x?xaffineint, #map0, 0>, memref<i8, #map1, 0>)
// CHECK-LABEL: cfgfunc @simpleCFG(i32, f32) {
cfgfunc @simpleCFG(i32, f32) {
// CHECK: bb0:
bb42: // (%0: i32, %f: f32): TODO(clattner): implement bbargs.
// CHECK: %0 = "foo"() : () -> i64
// CHECK: bb0(%0: i32, %1: f32):
bb42 (%0: i32, %f: f32):
// CHECK: %2 = "foo"() : () -> i64
%1 = "foo"() : ()->i64
// CHECK: "bar"(%0) : (i64) -> (i1, i1, i1)
// CHECK: "bar"(%2) : (i64) -> (i1, i1, i1)
%2 = "bar"(%1) : (i64) -> (i1,i1,i1)
// CHECK: return
return
// CHECK: }
}
// CHECK-LABEL: cfgfunc @simpleCFGUsingBBArgs(i32, i64) {
cfgfunc @simpleCFGUsingBBArgs(i32, i64) {
// CHECK: bb0(%0: i32, %1: i64):
bb42 (%0: i32, %f: i64):
// CHECK: "bar"(%1) : (i64) -> (i1, i1, i1)
%2 = "bar"(%f) : (i64) -> (i1,i1,i1)
// CHECK: return
return
// CHECK: }
}
// CHECK-LABEL: cfgfunc @multiblock() {
cfgfunc @multiblock() {
bb0: // CHECK: bb0:

View File

@ -118,7 +118,7 @@ splitMemoryBufferForErrorChecking(std::unique_ptr<MemoryBuffer> buffer) {
SMLoc());
// Extracing the expected errors.
llvm::Regex expected("expected-error(@[+-][0-9]+)? {{(.*)}}");
llvm::Regex expected("expected-error(@[+-][0-9]+)? *{{(.*)}}");
SmallVector<ExpectedError, 2> expectedErrors;
SmallVector<StringRef, 100> lines;
subbuffer.split(lines, '\n');