forked from OSchip/llvm-project
[mlir] Add basic block arguments
This patch adds support for basic block arguments including parsing and printing. In doing so noticed that `ssa-id-and-type` is undefined in the MLIR spec; suggested an implementation in the spec doc. PiperOrigin-RevId: 205593369
This commit is contained in:
parent
e402dcc47f
commit
4144c302db
|
@ -22,6 +22,7 @@
|
|||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
class BBArgument;
|
||||
|
||||
/// Each basic block in a CFG function contains a list of basic block arguments,
|
||||
/// normal instructions, and a terminator instruction.
|
||||
|
@ -39,11 +40,33 @@ public:
|
|||
return function;
|
||||
}
|
||||
|
||||
// TODO: bb arguments
|
||||
|
||||
/// Unlink this BasicBlock from its CFGFunction and delete it.
|
||||
void eraseFromFunction();
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Block arguments management
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
// This is the list of arguments to the block.
|
||||
typedef ArrayRef<BBArgument *> BBArgListType;
|
||||
BBArgListType getArguments() const { return arguments; }
|
||||
|
||||
using args_iterator = BBArgListType::iterator;
|
||||
using reverse_args_iterator = BBArgListType::reverse_iterator;
|
||||
args_iterator args_begin() const { return getArguments().begin(); }
|
||||
args_iterator args_end() const { return getArguments().end(); }
|
||||
reverse_args_iterator args_rbegin() const { return getArguments().rbegin(); }
|
||||
reverse_args_iterator args_rend() const { return getArguments().rend(); }
|
||||
|
||||
bool args_empty() const { return arguments.empty(); }
|
||||
BBArgument *addArgument(Type *type);
|
||||
llvm::iterator_range<BBArgListType::iterator>
|
||||
addArguments(ArrayRef<Type *> types);
|
||||
|
||||
unsigned getNumArguments() const { return arguments.size(); }
|
||||
BBArgument *getArgument(unsigned i) { return arguments[i]; }
|
||||
const BBArgument *getArgument(unsigned i) const { return arguments[i]; }
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Operation list management
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -105,6 +128,9 @@ private:
|
|||
/// This is the list of operations in the block.
|
||||
OperationListType operations;
|
||||
|
||||
/// This is the list of arguments to the block.
|
||||
std::vector<BBArgument *> arguments;
|
||||
|
||||
/// This is the owning reference to the terminator of the block.
|
||||
TerminatorInst *terminator = nullptr;
|
||||
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "mlir/IR/SSAValue.h"
|
||||
|
||||
namespace mlir {
|
||||
class BasicBlock;
|
||||
class CFGValue;
|
||||
class Instruction;
|
||||
|
||||
|
@ -33,7 +34,7 @@ class Instruction;
|
|||
/// function. This should be kept as a proper subtype of SSAValueKind,
|
||||
/// including having all of the values of the enumerators align.
|
||||
enum class CFGValueKind {
|
||||
// TODO: BBArg,
|
||||
BBArgument = (int)SSAValueKind::BBArgument,
|
||||
InstResult = (int)SSAValueKind::InstResult,
|
||||
};
|
||||
|
||||
|
@ -45,6 +46,7 @@ class CFGValue : public SSAValueImpl<InstOperand, CFGValueKind> {
|
|||
public:
|
||||
static bool classof(const SSAValue *value) {
|
||||
switch (value->getKind()) {
|
||||
case SSAValueKind::BBArgument:
|
||||
case SSAValueKind::InstResult:
|
||||
return true;
|
||||
}
|
||||
|
@ -54,6 +56,27 @@ protected:
|
|||
CFGValue(CFGValueKind kind, Type *type) : SSAValueImpl(kind, type) {}
|
||||
};
|
||||
|
||||
/// Basic block arguments are CFG Values.
|
||||
class BBArgument : public CFGValue {
|
||||
public:
|
||||
static bool classof(const SSAValue *value) {
|
||||
return value->getKind() == SSAValueKind::BBArgument;
|
||||
}
|
||||
|
||||
BasicBlock *getOwner() { return owner; }
|
||||
const BasicBlock *getOwner() const { return owner; }
|
||||
|
||||
private:
|
||||
friend class BasicBlock; // For access to private constructor.
|
||||
BBArgument(Type *type, BasicBlock *owner)
|
||||
: CFGValue(CFGValueKind::BBArgument, type), owner(owner) {}
|
||||
|
||||
/// The owner of this operand.
|
||||
/// TODO: can encode this more efficiently to avoid the space hit of this
|
||||
/// through bitpacking shenanigans.
|
||||
BasicBlock *const owner;
|
||||
};
|
||||
|
||||
/// Instruction results are CFG Values.
|
||||
class InstResult : public CFGValue {
|
||||
public:
|
||||
|
|
|
@ -34,7 +34,7 @@ template <typename OperandType, typename OwnerType> class SSAValueUseIterator;
|
|||
|
||||
/// This enumerates all of the SSA value kinds in the MLIR system.
|
||||
enum class SSAValueKind {
|
||||
// TODO: BBArg,
|
||||
BBArgument,
|
||||
InstResult,
|
||||
|
||||
// FnArg
|
||||
|
|
|
@ -665,7 +665,9 @@ CFGFunctionPrinter::CFGFunctionPrinter(const CFGFunction *function,
|
|||
|
||||
/// Number all of the SSA values in the specified basic block.
|
||||
void CFGFunctionPrinter::numberValuesInBlock(const BasicBlock *block) {
|
||||
// TODO: basic block arguments.
|
||||
for (auto *arg : block->getArguments()) {
|
||||
numberValueID(arg);
|
||||
}
|
||||
for (auto &op : *block) {
|
||||
// We number instruction that have results, and we only number the first
|
||||
// result.
|
||||
|
@ -686,16 +688,26 @@ void CFGFunctionPrinter::print() {
|
|||
}
|
||||
|
||||
void CFGFunctionPrinter::print(const BasicBlock *block) {
|
||||
os << "bb" << getBBID(block) << ":\n";
|
||||
os << "bb" << getBBID(block);
|
||||
|
||||
if (!block->args_empty()) {
|
||||
os << '(';
|
||||
interleaveComma(block->getArguments(), [&](const BBArgument *arg) {
|
||||
printValueID(arg);
|
||||
os << ": ";
|
||||
ModulePrinter::print(arg->getType());
|
||||
});
|
||||
os << ')';
|
||||
}
|
||||
os << ":\n";
|
||||
|
||||
// TODO Print arguments.
|
||||
for (auto &inst : block->getOperations()) {
|
||||
print(&inst);
|
||||
os << "\n";
|
||||
os << '\n';
|
||||
}
|
||||
|
||||
print(block->getTerminator());
|
||||
os << "\n";
|
||||
os << '\n';
|
||||
}
|
||||
|
||||
void CFGFunctionPrinter::print(const Instruction *inst) {
|
||||
|
|
|
@ -19,12 +19,14 @@
|
|||
#include "mlir/IR/CFGFunction.h"
|
||||
using namespace mlir;
|
||||
|
||||
BasicBlock::BasicBlock() {
|
||||
}
|
||||
BasicBlock::BasicBlock() {}
|
||||
|
||||
BasicBlock::~BasicBlock() {
|
||||
if (terminator)
|
||||
terminator->eraseFromBlock();
|
||||
for (BBArgument *arg : arguments)
|
||||
delete arg;
|
||||
arguments.clear();
|
||||
}
|
||||
|
||||
/// Unlink this BasicBlock from its CFGFunction and delete it.
|
||||
|
@ -84,3 +86,17 @@ transferNodesFromList(ilist_traits<BasicBlock> &otherList,
|
|||
for (; first != last; ++first)
|
||||
first->function = curParent;
|
||||
}
|
||||
|
||||
BBArgument *BasicBlock::addArgument(Type *type) {
|
||||
arguments.push_back(new BBArgument(type, this));
|
||||
return arguments.back();
|
||||
}
|
||||
|
||||
llvm::iterator_range<BasicBlock::BBArgListType::iterator>
|
||||
BasicBlock::addArguments(ArrayRef<Type *> types) {
|
||||
auto initial_size = arguments.size();
|
||||
for (auto *type : types) {
|
||||
addArgument(type);
|
||||
}
|
||||
return {arguments.data() + initial_size, arguments.data() + arguments.size()};
|
||||
}
|
||||
|
|
|
@ -106,6 +106,26 @@ public:
|
|||
bool CFGFuncVerifier::verify() {
|
||||
// TODO: Lots to be done here, including verifying dominance information when
|
||||
// we have uses and defs.
|
||||
// TODO: Verify the first block has no predecessors.
|
||||
|
||||
if (fn.empty())
|
||||
return failure("cfgfunc must have at least one basic block", fn);
|
||||
|
||||
// Verify that the argument list of the function and the arg list of the first
|
||||
// block line up.
|
||||
auto *firstBB = &fn.front();
|
||||
auto fnInputTypes = fn.getType()->getInputs();
|
||||
if (fnInputTypes.size() != firstBB->getNumArguments())
|
||||
return failure("first block of cfgfunc must have " +
|
||||
Twine(fnInputTypes.size()) +
|
||||
" arguments to match function signature",
|
||||
fn);
|
||||
for (unsigned i = 0, e = firstBB->getNumArguments(); i != e; ++i)
|
||||
if (fnInputTypes[i] != firstBB->getArgument(i)->getType())
|
||||
return failure(
|
||||
"type of argument #" + Twine(i) +
|
||||
" must match corresponding argument in function signature",
|
||||
fn);
|
||||
|
||||
for (auto &block : fn) {
|
||||
if (verifyBlock(block))
|
||||
|
@ -121,6 +141,11 @@ bool CFGFuncVerifier::verifyBlock(const BasicBlock &block) {
|
|||
if (verifyTerminator(*block.getTerminator()))
|
||||
return true;
|
||||
|
||||
for (auto *arg : block.getArguments()) {
|
||||
if (arg->getOwner() != &block)
|
||||
return failure("basic block argument not owned by block", block);
|
||||
}
|
||||
|
||||
for (auto &inst : block) {
|
||||
if (verifyOperation(inst))
|
||||
return true;
|
||||
|
|
|
@ -1219,7 +1219,17 @@ public:
|
|||
// SSA parsing productions.
|
||||
ParseResult parseSSAUse(SSAUseInfo &result);
|
||||
ParseResult parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results);
|
||||
SSAValue *parseSSAUseAndType();
|
||||
|
||||
template <typename ResultType>
|
||||
ResultType parseSSADefOrUseAndType(
|
||||
const std::function<ResultType(SSAUseInfo, Type *)> &action);
|
||||
|
||||
SSAValue *parseSSAUseAndType() {
|
||||
return parseSSADefOrUseAndType<SSAValue *>(
|
||||
[&](SSAUseInfo useInfo, Type *type) -> SSAValue * {
|
||||
return resolveSSAUse(useInfo, type);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename ValueTy>
|
||||
ParseResult
|
||||
|
@ -1355,8 +1365,7 @@ ParseResult FunctionParser::finalizeFunction(Function *func, SMLoc loc) {
|
|||
|
||||
/// Parse a SSA operand for an instruction or statement.
|
||||
///
|
||||
/// ssa-use ::= ssa-id | ssa-constant
|
||||
/// TODO: SSA Constants.
|
||||
/// ssa-use ::= ssa-id
|
||||
///
|
||||
ParseResult FunctionParser::parseSSAUse(SSAUseInfo &result) {
|
||||
result.name = getTokenSpelling();
|
||||
|
@ -1398,7 +1407,9 @@ FunctionParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) {
|
|||
/// Parse an SSA use with an associated type.
|
||||
///
|
||||
/// ssa-use-and-type ::= ssa-use `:` type
|
||||
SSAValue *FunctionParser::parseSSAUseAndType() {
|
||||
template <typename ResultType>
|
||||
ResultType FunctionParser::parseSSADefOrUseAndType(
|
||||
const std::function<ResultType(SSAUseInfo, Type *)> &action) {
|
||||
SSAUseInfo useInfo;
|
||||
if (parseSSAUse(useInfo))
|
||||
return nullptr;
|
||||
|
@ -1410,7 +1421,7 @@ SSAValue *FunctionParser::parseSSAUseAndType() {
|
|||
if (!type)
|
||||
return nullptr;
|
||||
|
||||
return resolveSSAUse(useInfo, type);
|
||||
return action(useInfo, type);
|
||||
}
|
||||
|
||||
/// Parse a (possibly empty) list of SSA operands with types.
|
||||
|
@ -1570,12 +1581,39 @@ private:
|
|||
return blockAndLoc.first;
|
||||
}
|
||||
|
||||
ParseResult
|
||||
parseOptionalBasicBlockArgList(SmallVectorImpl<BBArgument *> &results,
|
||||
BasicBlock *owner);
|
||||
|
||||
ParseResult parseBasicBlock();
|
||||
OperationInst *parseCFGOperation();
|
||||
TerminatorInst *parseTerminator();
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Parse a (possibly empty) list of SSA operands with types as basic block
|
||||
/// arguments. Unlike parseOptionalSsaUseAndTypeList the SSA IDs are treated as
|
||||
/// defs, not uses.
|
||||
///
|
||||
/// ssa-id-and-type-list ::= ssa-id-and-type (`,` ssa-id-and-type)*
|
||||
///
|
||||
ParseResult CFGFunctionParser::parseOptionalBasicBlockArgList(
|
||||
SmallVectorImpl<BBArgument *> &results, BasicBlock *owner) {
|
||||
if (getToken().is(Token::r_brace))
|
||||
return ParseSuccess;
|
||||
|
||||
return parseCommaSeparatedList([&]() -> ParseResult {
|
||||
auto type = parseSSADefOrUseAndType<Type *>(
|
||||
[&](SSAUseInfo useInfo, Type *type) -> Type * {
|
||||
BBArgument *arg = owner->addArgument(type);
|
||||
if (addDefinition(useInfo, arg) == ParseFailure)
|
||||
return nullptr;
|
||||
return type;
|
||||
});
|
||||
return type ? ParseSuccess : ParseFailure;
|
||||
});
|
||||
}
|
||||
|
||||
ParseResult CFGFunctionParser::parseFunctionBody() {
|
||||
auto braceLoc = getToken().getLoc();
|
||||
if (!consumeIf(Token::l_brace))
|
||||
|
@ -1625,20 +1663,18 @@ ParseResult CFGFunctionParser::parseBasicBlock() {
|
|||
if (block->getFunction())
|
||||
return emitError(nameLoc, "redefinition of block '" + name.str() + "'");
|
||||
|
||||
// Add the block to the function.
|
||||
function->push_back(block);
|
||||
|
||||
// If an argument list is present, parse it.
|
||||
if (consumeIf(Token::l_paren)) {
|
||||
SmallVector<SSAUseInfo, 8> bbArgs;
|
||||
if (parseOptionalSSAUseList(bbArgs))
|
||||
SmallVector<BBArgument *, 8> bbArgs;
|
||||
if (parseOptionalBasicBlockArgList(bbArgs, block))
|
||||
return ParseFailure;
|
||||
if (!consumeIf(Token::r_paren))
|
||||
return emitError("expected ')' to end argument list");
|
||||
|
||||
// TODO: attach it.
|
||||
}
|
||||
|
||||
// Add the block to the function.
|
||||
function->push_back(block);
|
||||
|
||||
if (!consumeIf(Token::colon))
|
||||
return emitError("expected ':' after basic block name");
|
||||
|
||||
|
|
|
@ -91,6 +91,27 @@ bb42: // expected-error {{expected operation name}}
|
|||
|
||||
// -----
|
||||
|
||||
cfgfunc @block_no_rparen() {
|
||||
bb42 (%bb42 : i32: // expected-error {{expected ')' to end argument list}}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
cfgfunc @block_arg_no_ssaid() {
|
||||
bb42 (i32): // expected-error {{expected SSA operand}}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
cfgfunc @block_arg_no_type() {
|
||||
bb42 (%0): // expected-error {{expected ':' and type for SSA operand}}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
mlfunc @foo()
|
||||
mlfunc @bar() // expected-error {{expected '{' in ML function}}
|
||||
|
||||
|
@ -208,3 +229,17 @@ bb42:
|
|||
}
|
||||
|
||||
// -----
|
||||
|
||||
cfgfunc @argError() {
|
||||
bb1(%a: i64): // expected-error {{previously defined here}}
|
||||
br bb2
|
||||
bb2(%a: i64): // expected-error{{redefinition of SSA value '%a'}}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
cfgfunc @bbargMismatch(i32, f32) { // expected-error {{first block of cfgfunc must have 2 arguments to match function signature}}
|
||||
bb42(%0: f32):
|
||||
return
|
||||
}
|
||||
|
|
|
@ -68,17 +68,28 @@ extfunc @functions((memref<1x?x4x?x?xaffineint, #map0, 0>, memref<i8, #map1, 0>)
|
|||
|
||||
// CHECK-LABEL: cfgfunc @simpleCFG(i32, f32) {
|
||||
cfgfunc @simpleCFG(i32, f32) {
|
||||
// CHECK: bb0:
|
||||
bb42: // (%0: i32, %f: f32): TODO(clattner): implement bbargs.
|
||||
// CHECK: %0 = "foo"() : () -> i64
|
||||
// CHECK: bb0(%0: i32, %1: f32):
|
||||
bb42 (%0: i32, %f: f32):
|
||||
// CHECK: %2 = "foo"() : () -> i64
|
||||
%1 = "foo"() : ()->i64
|
||||
// CHECK: "bar"(%0) : (i64) -> (i1, i1, i1)
|
||||
// CHECK: "bar"(%2) : (i64) -> (i1, i1, i1)
|
||||
%2 = "bar"(%1) : (i64) -> (i1,i1,i1)
|
||||
// CHECK: return
|
||||
return
|
||||
// CHECK: }
|
||||
}
|
||||
|
||||
// CHECK-LABEL: cfgfunc @simpleCFGUsingBBArgs(i32, i64) {
|
||||
cfgfunc @simpleCFGUsingBBArgs(i32, i64) {
|
||||
// CHECK: bb0(%0: i32, %1: i64):
|
||||
bb42 (%0: i32, %f: i64):
|
||||
// CHECK: "bar"(%1) : (i64) -> (i1, i1, i1)
|
||||
%2 = "bar"(%f) : (i64) -> (i1,i1,i1)
|
||||
// CHECK: return
|
||||
return
|
||||
// CHECK: }
|
||||
}
|
||||
|
||||
// CHECK-LABEL: cfgfunc @multiblock() {
|
||||
cfgfunc @multiblock() {
|
||||
bb0: // CHECK: bb0:
|
||||
|
|
|
@ -118,7 +118,7 @@ splitMemoryBufferForErrorChecking(std::unique_ptr<MemoryBuffer> buffer) {
|
|||
SMLoc());
|
||||
|
||||
// Extracing the expected errors.
|
||||
llvm::Regex expected("expected-error(@[+-][0-9]+)? {{(.*)}}");
|
||||
llvm::Regex expected("expected-error(@[+-][0-9]+)? *{{(.*)}}");
|
||||
SmallVector<ExpectedError, 2> expectedErrors;
|
||||
SmallVector<StringRef, 100> lines;
|
||||
subbuffer.split(lines, '\n');
|
||||
|
|
Loading…
Reference in New Issue