EDSC: support branch instructions

The new implementation of blocks was designed to support blocks with arguments.
More specifically, StmtBlock can be constructed with a list of Bindables that
will be bound to block aguments upon construction.  Leverage this functionality
to implement branch instructions with arguments.

This additionally requires the statement storage to have a list of successors,
similarly to core IR operations.

Becauase successor chains can form loops, we need a possibility to decouple
block declaration, after which it becomes usable by branch instructions, from
block body definition.  This is achieved by creating an empty block and by
resetting its body with a new list of instructions.  Note that assigning a
block from another block will not affect any instructions that may have
designated this block as their successor (this behavior is necessary to make
value-type semantics of EDSC types consistent).  Combined, one can now write
generators like

    EDSCContext context;
    Type indexType = ...;
    Bindable i(indexType), ii(indexType), zero(indexType), one(indexType);
    StmtBlock loopBlock({i}, {});
    loopBlock.set({ii = i + one,
                   Branch(loopBlock, {ii})});
    MLIREmitter(&builder)
        .bindConstant<ConstantIndexOp>(zero, 0)
        .bindConstant<ConstantIndexOp>(one, 1)
	.emitStmt(Branch(loopBlock, {zero}));

where the emitter will emit the statement and its successors, if present.

PiperOrigin-RevId: 235541892
This commit is contained in:
Alex Zinenko 2019-02-25 09:16:30 -08:00 committed by jpienaar
parent 8b99d1bdbf
commit 83e8db2193
8 changed files with 377 additions and 106 deletions

View File

@ -241,6 +241,8 @@ struct PythonBlock {
return StmtBlock(blk).str();
}
PythonBlock set(const py::list &stmts);
edsc_block_t blk;
};
@ -317,6 +319,14 @@ static edsc_expr_list_t makeCExprs(llvm::SmallVectorImpl<edsc_expr_t> &owning,
return edsc_expr_list_t{owning.data(), owning.size()};
}
static mlir_type_list_t makeCTypes(llvm::SmallVectorImpl<mlir_type_t> &owning,
const py::list &types) {
for (auto &inp : types) {
owning.push_back(mlir_type_t{inp.cast<PythonType>()});
}
return mlir_type_list_t{owning.data(), owning.size()};
}
PythonExpr::PythonExpr(const PythonBindable &bindable) : expr{bindable.expr} {}
PythonExpr MLIRFunctionEmitter::bindConstantBF16(double value) {
@ -410,6 +420,12 @@ void MLIRFunctionEmitter::emitBlockBody(PythonBlock block) {
emitter.emitStmts(StmtBlock(block).getBody());
}
PythonBlock PythonBlock::set(const py::list &stmts) {
SmallVector<edsc_stmt_t, 8> owning;
::BlockSetBody(blk, makeCStmts(owning, stmts));
return *this;
}
PythonExpr dispatchCall(py::args args, py::kwargs kwargs) {
assert(args.size() != 0);
llvm::SmallVector<edsc_expr_t, 8> exprs;
@ -438,10 +454,24 @@ PYBIND11_MODULE(pybind, m) {
m.def("deleteContext",
[](void *ctx) { delete reinterpret_cast<ScopedEDSCContext *>(ctx); });
m.def("Block", [](const py::list &args, const py::list &stmts) {
SmallVector<edsc_stmt_t, 8> owning;
SmallVector<edsc_expr_t, 8> owningArgs;
return PythonBlock(
::Block(makeCExprs(owningArgs, args), makeCStmts(owning, stmts)));
});
m.def("Block", [](const py::list &stmts) {
SmallVector<edsc_stmt_t, 8> owning;
return PythonBlock(::Block(makeCStmts(owning, stmts)));
edsc_expr_list_t args{nullptr, 0};
return PythonBlock(::Block(args, makeCStmts(owning, stmts)));
});
m.def(
"Branch",
[](PythonBlock destination, const py::list &operands) {
SmallVector<edsc_expr_t, 8> owning;
return PythonStmt(::Branch(destination, makeCExprs(owning, operands)));
},
py::arg("destination"), py::arg("operands") = py::list());
m.def("For", [](const py::list &ivs, const py::list &lbs, const py::list &ubs,
const py::list &steps, const py::list &stmts) {
SmallVector<edsc_expr_t, 8> owningIVs;
@ -527,6 +557,7 @@ PYBIND11_MODULE(pybind, m) {
py::class_<PythonBlock>(m, "StmtBlock",
"Wrapping class for mlir::edsc::StmtBlock")
.def(py::init<PythonBlock>())
.def("set", &PythonBlock::set)
.def("__str__", &PythonBlock::str);
py::class_<PythonType>(m, "Type", "Wrapping class for mlir::Type.")

View File

@ -151,10 +151,47 @@ class EdscTest(unittest.TestCase):
i, j = list(map(E.Expr, [E.Bindable(self.f32Type) for _ in range(2)]))
stmt = E.Block([E.Stmt(i + j), E.Stmt(i - j)])
str = stmt.__str__()
self.assertIn("^bb:", str)
self.assertIn("^bb", str)
self.assertIn(" = ($1 + $2)", str)
self.assertIn(" = ($1 - $2)", str)
def testBlockArgs(self):
with E.ContextManager():
module = E.MLIRModule()
t = module.make_scalar_type("i", 32)
i, j = list(map(E.Expr, [E.Bindable(t) for _ in range(2)]))
stmt = E.Block([i, j], [E.Stmt(i + j)])
str = stmt.__str__()
self.assertIn("^bb", str)
self.assertIn("($1, $2):", str)
self.assertIn("($1 + $2)", str)
def testBranch(self):
with E.ContextManager():
i, j = list(map(E.Expr, [E.Bindable(self.i32Type) for _ in range(2)]))
b1 = E.Block([E.Stmt(i + j)])
b2 = E.Block([E.Branch(b1)])
str1 = b1.__str__()
str2 = b2.__str__()
self.assertIn("^bb1:\n" + "$4 = ($1 + $2)", str1)
self.assertIn("^bb2:\n" + "$6 = br ^bb1", str2)
def testBranchArgs(self):
with E.ContextManager():
b1arg, b2arg = (E.Expr(E.Bindable(self.i32Type)) for _ in range(2))
# Declare empty blocks with arguments and bind those arguments.
b1 = E.Block([b1arg], [])
b2 = E.Block([b2arg], [])
one = E.ConstantInteger(self.i32Type, 1)
# Make blocks branch to each other in a sort of infinite loop.
# This checks that the EDSC implementation does not fall into such loop.
b1.set([E.Branch(b2, [b1arg + one])])
b2.set([E.Branch(b1, [b2arg])])
str1 = b1.__str__()
str2 = b2.__str__()
self.assertIn("^bb1($1):\n" + "$6 = br ^bb2(($1 + 1))", str1)
self.assertIn("^bb2($2):\n" + "$8 = br ^bb1($2)", str2)
def testMLIRScalarTypes(self):
module = E.MLIRModule()
t = module.make_scalar_type("bf16")

View File

@ -230,8 +230,15 @@ edsc_expr_t ConstantInteger(mlir_type_t type, int64_t value);
edsc_stmt_t Return(edsc_expr_list_t values);
/// Returns an opaque expression for an mlir::edsc::StmtBlock containing the
/// given list of statements. Block arguments are not currently supported.
edsc_block_t Block(edsc_stmt_list_t enclosedStmts);
/// given list of statements.
edsc_block_t Block(edsc_expr_list_t arguments, edsc_stmt_list_t enclosedStmts);
/// Set the body of the block to the given statements and return the block.
edsc_block_t BlockSetBody(edsc_block_t, edsc_stmt_list_t stmts);
/// Returns an opaque statement branching to `destination` and passing
/// `arguments` as block arguments.
edsc_stmt_t Branch(edsc_block_t destination, edsc_expr_list_t arguments);
/// Returns an opaque statement for an mlir::AffineForOp with `enclosedStmts`
/// nested below it.

View File

@ -48,6 +48,8 @@ struct StmtBlockStorage;
} // namespace detail
class StmtBlock;
/// EDSC Types closely mirror the core MLIR and uses an abstraction similar to
/// AffineExpr:
/// 1. a set of composable structs;
@ -164,7 +166,20 @@ public:
ArrayRef<Type> getResultTypes() const;
/// Returns the list of expressions used as arguments of this expression.
ArrayRef<Expr> getChildExpressions() const;
ArrayRef<Expr> getProperArguments() const;
/// Returns the list of lists of expressions used as arguments of successors
/// of this expression (i.e., arguments passed to destination basic blocks in
/// terminator statements).
SmallVector<ArrayRef<Expr>, 4> getSuccessorArguments() const;
/// Returns the list of expressions used as arguments of the `index`-th
/// successor of this expression.
ArrayRef<Expr> getSuccessorArguments(int index) const;
/// Returns the list of argument groups (includes the proper argument group,
/// followed by successor/block argument groups).
SmallVector<ArrayRef<Expr>, 4> getAllArgumentGroups() const;
/// Returns the list of attributes of this expression.
ArrayRef<NamedAttribute> getAttributes() const;
@ -172,9 +187,13 @@ public:
/// Returns the attribute with the given name, if any.
Attribute getAttribute(StringRef name) const;
/// Returns the list of successors (StmtBlocks) of this expression.
ArrayRef<StmtBlock> getSuccessors() const;
/// Build the IR corresponding to this expression.
SmallVector<Value *, 4>
build(FuncBuilder &b, const llvm::DenseMap<Expr, Value *> &ssaBindings) const;
build(FuncBuilder &b, const llvm::DenseMap<Expr, Value *> &ssaBindings,
const llvm::DenseMap<StmtBlock, Block *> &blockBindings) const;
void print(raw_ostream &os) const;
void dump() const;
@ -267,15 +286,18 @@ struct VariadicExpr : public Expr {
friend class Expr;
VariadicExpr(StringRef name, llvm::ArrayRef<Expr> exprs,
llvm::ArrayRef<Type> types = {},
ArrayRef<NamedAttribute> attrs = {});
llvm::ArrayRef<NamedAttribute> attrs = {},
llvm::ArrayRef<StmtBlock> succ = {});
llvm::ArrayRef<Expr> getExprs() const;
llvm::ArrayRef<Type> getTypes() const;
llvm::ArrayRef<StmtBlock> getSuccessors() const;
template <typename T>
static VariadicExpr make(llvm::ArrayRef<Expr> exprs,
llvm::ArrayRef<Type> types = {},
llvm::ArrayRef<NamedAttribute> attrs = {}) {
return VariadicExpr(T::getOperationName(), exprs, types, attrs);
llvm::ArrayRef<NamedAttribute> attrs = {},
llvm::ArrayRef<StmtBlock> succ = {}) {
return VariadicExpr(T::getOperationName(), exprs, types, attrs, succ);
}
protected:
@ -289,18 +311,6 @@ struct StmtBlockLikeExpr : public Expr {
StmtBlockLikeExpr(ExprKind kind, llvm::ArrayRef<Expr> exprs,
llvm::ArrayRef<Type> types = {});
/// Get the list of subexpressions.
/// StmtBlockLikeExprs can contain multiple groups of subexpressions separated
/// by null expressions and the result of this call will include them.
llvm::ArrayRef<Expr> getExprs() const;
/// Get the list of subexpression groups.
/// StmtBlockLikeExprs can contain multiple groups of subexpressions separated
/// by null expressions. This will identify those groups and return a list
/// of lists of subexpressions split around null expressions. Two null
/// expressions in a row identify an empty group.
SmallVector<llvm::ArrayRef<Expr>, 4> getExprGroups() const;
protected:
StmtBlockLikeExpr(Expr::ImplType *ptr) : Expr(ptr) {
assert(!ptr || isa<StmtBlockLikeExpr>() && "expected StmtBlockLikeExpr");
@ -399,19 +409,23 @@ public:
explicit StmtBlock(edsc_block_t st)
: storage(reinterpret_cast<ImplType *>(st)) {}
StmtBlock(const StmtBlock &other) = default;
StmtBlock(llvm::ArrayRef<Stmt> stmts = {});
StmtBlock(llvm::ArrayRef<Bindable> args, llvm::ArrayRef<Type> argTypes,
llvm::ArrayRef<Stmt> stmts = {});
StmtBlock(llvm::ArrayRef<Stmt> stmts);
StmtBlock(llvm::ArrayRef<Bindable> args, llvm::ArrayRef<Stmt> stmts = {});
llvm::ArrayRef<Bindable> getArguments() const;
llvm::ArrayRef<Type> getArgumentTypes() const;
llvm::ArrayRef<Stmt> getBody() const;
uint64_t getId() const;
void print(llvm::raw_ostream &os, Twine indent) const;
std::string str() const;
operator edsc_block_t() { return edsc_block_t{storage}; }
/// Reset the body of this block with the given list of statements.
StmtBlock &operator=(llvm::ArrayRef<Stmt> stmts);
void set(llvm::ArrayRef<Stmt> stmts) { *this = stmts; }
ImplType *getStoragePtr() const { return storage; }
private:
@ -619,6 +633,8 @@ Expr call(Expr func, Type result, llvm::ArrayRef<Expr> args);
Expr call(Expr func, llvm::ArrayRef<Expr> args);
Stmt Return(ArrayRef<Expr> values = {});
Stmt Branch(StmtBlock destination, ArrayRef<Expr> args = {});
Stmt For(Expr lb, Expr ub, Expr step, llvm::ArrayRef<Stmt> enclosedStmts);
Stmt For(const Bindable &idx, Expr lb, Expr ub, Expr step,
llvm::ArrayRef<Stmt> enclosedStmts);
@ -633,11 +649,13 @@ Stmt For(llvm::ArrayRef<Expr> indices, llvm::ArrayRef<Expr> lbs,
Stmt MaxMinFor(const Bindable &idx, ArrayRef<Expr> lbs, ArrayRef<Expr> ubs,
Expr step, ArrayRef<Stmt> enclosedStmts);
StmtBlock block(llvm::ArrayRef<Bindable> args, llvm::ArrayRef<Type> argTypes,
llvm::ArrayRef<Stmt> stmts);
inline StmtBlock block(llvm::ArrayRef<Stmt> stmts) {
return block({}, {}, stmts);
}
/// Define an MLIR Block and bind its arguments to `args`. The types of block
/// arguments are those of `args`, each of which must have exactly one result
/// type. The body of the block may be empty and can be reset later.
StmtBlock block(llvm::ArrayRef<Bindable> args, llvm::ArrayRef<Stmt> stmts);
/// Define an MLIR Block without arguments. The body of the block can be empty
/// and can be reset later.
inline StmtBlock block(llvm::ArrayRef<Stmt> stmts) { return block({}, stmts); }
/// This helper class exists purely for sugaring purposes and allows writing
/// expressions such as:

View File

@ -42,21 +42,43 @@ struct LowerEDSCTestPass : public FunctionPass {
#include "mlir/EDSC/reference-impl.inc"
PassResult LowerEDSCTestPass::runOnFunction(Function *f) {
// Inject a EDSC-constructed list of blocks.
// Inject a EDSC-constructed infinite loop implemented by mutual branching
// between two blocks, following the pattern:
//
// br ^bb1
// ^bb1:
// br ^bb2
// ^bb2:
// br ^bb1
//
// Use blocks with arguments.
if (f->getName().strref() == "blocks") {
using namespace edsc::op;
FuncBuilder builder(f);
edsc::ScopedEDSCContext context;
// Declare two blocks. Note that we must declare the blocks before creating
// branches to them.
auto type = builder.getIntegerType(32);
edsc::Expr arg1(type), arg2(type), arg3(type), arg4(type);
edsc::Expr arg1(type), arg2(type), arg3(type), arg4(type), r(type);
edsc::StmtBlock b1 = edsc::block({arg1, arg2}, {}),
b2 = edsc::block({arg3, arg4}, {});
auto c1 = edsc::constantInteger(type, 42);
auto c2 = edsc::constantInteger(type, 1234);
auto b1 =
edsc::block({arg1, arg2}, {type, type}, {arg1 + arg2, edsc::Return()});
auto b2 =
edsc::block({arg3, arg4}, {type, type}, {arg3 - arg4, edsc::Return()});
// Make an infinite loops by branching between the blocks. Note that copy-
// assigning a block won't work well with branches, update the body instead.
b1.set({r = arg1 + arg2, edsc::Branch(b2, {arg1, r})});
b2.set({edsc::Branch(b1, {arg3, arg4})});
auto instr = edsc::Branch(b2, {c1, c2});
edsc::MLIREmitter(&builder, f->getLoc()).emitBlock(b1).emitBlock(b2);
// Remove the existing 'return' from the function, reset the builder after
// the instruction iterator invalidation and emit a branch to b2. This
// should also emit blocks b2 and b1 that appear as successors to the
// current block after the branch instruction is insterted.
f->begin()->clear();
builder.setInsertionPoint(&*f->begin(), f->begin()->begin());
edsc::MLIREmitter(&builder, f->getLoc()).emitStmt(instr);
}
// Inject a EDSC-constructed `for` loop with bounds coming from function

View File

@ -129,7 +129,16 @@ Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) {
bool expectedEmpty = false;
if (e.isa<UnaryExpr>() || e.isa<BinaryExpr>() || e.isa<TernaryExpr>() ||
e.isa<VariadicExpr>()) {
auto results = e.build(*builder, ssaBindings);
// Emit any successors before the instruction with successors. At this
// point, all values defined by the current block must have been bound, the
// current instruction with successors cannot define new values, so the
// successor can use those values.
assert(e.getSuccessors().empty() || e.getResultTypes().empty() &&
"an operation with successors must "
"not have results and vice versa");
for (StmtBlock block : e.getSuccessors())
emitBlock(block);
auto results = e.build(*builder, ssaBindings, blockBindings);
assert(results.size() <= 1 && "2+-result exprs are not supported");
expectedEmpty = results.empty();
if (!results.empty())
@ -138,7 +147,7 @@ Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) {
if (auto expr = e.dyn_cast<StmtBlockLikeExpr>()) {
if (expr.getKind() == ExprKind::For) {
auto exprGroups = expr.getExprGroups();
auto exprGroups = expr.getAllArgumentGroups();
assert(exprGroups.size() == 3 && "expected 3 expr groups in `for`");
assert(!exprGroups[0].empty() && "expected at least one lower bound");
assert(!exprGroups[1].empty() && "expected at least one upper bound");
@ -213,8 +222,9 @@ mlir::edsc::MLIREmitter &mlir::edsc::MLIREmitter::emitStmt(const Stmt &stmt) {
if (!val) {
assert((stmt.getRHS().is_op<DeallocOp>() ||
stmt.getRHS().is_op<StoreOp>() || stmt.getRHS().is_op<ReturnOp>() ||
stmt.getRHS().is_op<CallIndirectOp>()) &&
"dealloc, store, return or call_indirect expected as the only "
stmt.getRHS().is_op<CallIndirectOp>() ||
stmt.getRHS().is_op<BranchOp>()) &&
"dealloc, store, return, br, or call_indirect expected as the only "
"0-result ops");
if (stmt.getRHS().is_op<CallIndirectOp>()) {
assert(

View File

@ -64,17 +64,23 @@ struct ExprStorage {
unsigned id;
StringRef opName;
// Exprs can contain multiple groups of operands separated by null
// expressions. Two null expressions in a row identify an empty group.
ArrayRef<Expr> operands;
ArrayRef<Type> resultTypes;
ArrayRef<NamedAttribute> attributes;
ArrayRef<StmtBlock> successors;
ExprStorage(ExprKind kind, StringRef name, ArrayRef<Type> results,
ArrayRef<Expr> children, ArrayRef<NamedAttribute> attrs,
StringRef descr = "", unsigned exprId = Expr::newId())
ArrayRef<StmtBlock> succ = {}, unsigned exprId = Expr::newId())
: kind(kind), id(exprId) {
operands = copyIntoExprAllocator(children);
resultTypes = copyIntoExprAllocator(results);
attributes = copyIntoExprAllocator(attrs);
successors = copyIntoExprAllocator(succ);
if (!name.empty()) {
auto nameStorage = Expr::globalAllocator()->Allocate<char>(name.size());
std::uninitialized_copy(name.begin(), name.end(), nameStorage);
@ -94,11 +100,24 @@ struct StmtStorage {
struct StmtBlockStorage {
StmtBlockStorage(ArrayRef<Bindable> args, ArrayRef<Type> argTypes,
ArrayRef<Stmt> stmts) {
id = nextId();
arguments = copyIntoExprAllocator(args);
argumentTypes = copyIntoExprAllocator(argTypes);
statements = copyIntoExprAllocator(stmts);
}
void replaceStmts(ArrayRef<Stmt> stmts) {
Expr::globalAllocator()->Deallocate(statements.data(), statements.size());
statements = copyIntoExprAllocator(stmts);
}
static uint64_t &nextId() {
static thread_local uint64_t next = 0;
return ++next;
}
static void resetIds() { nextId() = 0; }
uint64_t id;
ArrayRef<Bindable> arguments;
ArrayRef<Type> argumentTypes;
ArrayRef<Stmt> statements;
@ -111,6 +130,7 @@ struct StmtBlockStorage {
mlir::edsc::ScopedEDSCContext::ScopedEDSCContext() {
Expr::globalAllocator() = &allocator;
Bindable::resetIds();
StmtBlockStorage::resetIds();
}
mlir::edsc::ScopedEDSCContext::~ScopedEDSCContext() {
@ -138,10 +158,6 @@ ArrayRef<Type> mlir::edsc::Expr::getResultTypes() const {
return storage->resultTypes;
}
ArrayRef<Expr> mlir::edsc::Expr::getChildExpressions() const {
return storage->operands;
}
ArrayRef<NamedAttribute> mlir::edsc::Expr::getAttributes() const {
return storage->attributes;
}
@ -153,26 +169,38 @@ Attribute mlir::edsc::Expr::getAttribute(StringRef name) const {
return {};
}
ArrayRef<StmtBlock> mlir::edsc::Expr::getSuccessors() const {
return storage->successors;
}
StringRef mlir::edsc::Expr::getName() const {
return static_cast<ImplType *>(storage)->opName;
}
SmallVector<Value *, 4>
Expr::build(FuncBuilder &b,
const llvm::DenseMap<Expr, Value *> &ssaBindings) const {
buildExprs(ArrayRef<Expr> exprs, FuncBuilder &b,
const llvm::DenseMap<Expr, Value *> &ssaBindings,
const llvm::DenseMap<StmtBlock, mlir::Block *> &blockBindings) {
SmallVector<Value *, 4> values;
values.reserve(exprs.size());
for (auto child : exprs) {
auto subResults = child.build(b, ssaBindings, blockBindings);
assert(subResults.size() == 1 &&
"expected single-result expression as operand");
values.push_back(subResults.front());
}
return values;
}
SmallVector<Value *, 4>
Expr::build(FuncBuilder &b, const llvm::DenseMap<Expr, Value *> &ssaBindings,
const llvm::DenseMap<StmtBlock, Block *> &blockBindings) const {
auto it = ssaBindings.find(*this);
if (it != ssaBindings.end())
return {it->second};
auto *impl = static_cast<ImplType *>(storage);
SmallVector<Value *, 4> operandValues;
operandValues.reserve(impl->operands.size());
for (auto child : impl->operands) {
auto subResults = child.build(b, ssaBindings);
assert(subResults.size() == 1 &&
"expected single-result expression as operand");
operandValues.push_back(subResults.front());
}
SmallVector<Value *, 4> operandValues =
buildExprs(getProperArguments(), b, ssaBindings, blockBindings);
// Special case for emitting composed affine.applies.
// FIXME: this should not be a special case, instead, define composed form as
@ -185,12 +213,24 @@ Expr::build(FuncBuilder &b,
return {affInstr->getResult()};
}
auto state = OperationState(b.getContext(), b.getUnknownLoc(), impl->opName);
auto state = OperationState(b.getContext(), b.getUnknownLoc(), getName());
state.addOperands(operandValues);
state.addTypes(impl->resultTypes);
for (const auto &attr : impl->attributes)
state.addTypes(getResultTypes());
for (const auto &attr : getAttributes())
state.addAttribute(attr.first, attr.second);
auto successors = getSuccessors();
auto successorArgs = getSuccessorArguments();
assert(successors.size() == successorArgs.size() &&
"expected all successors to have a corresponding operand group");
for (int i = 0, e = successors.size(); i < e; ++i) {
StmtBlock block = successors[i];
assert(blockBindings.count(block) != 0 && "successor block does not exist");
state.addSuccessor(
blockBindings.lookup(block),
buildExprs(successorArgs[i], b, ssaBindings, blockBindings));
}
Instruction *inst = b.createOperation(state);
return llvm::to_vector<4>(inst->getResults());
}
@ -499,17 +539,26 @@ edsc_stmt_t MaxMinFor(edsc_expr_t iv, edsc_max_expr_t lb, edsc_min_expr_t ub,
Expr(step), stmts));
}
StmtBlock mlir::edsc::block(ArrayRef<Bindable> args, ArrayRef<Type> argTypes,
ArrayRef<Stmt> stmts) {
assert(args.size() == argTypes.size() &&
"mismatching number of arguments and argument types");
return StmtBlock(args, argTypes, stmts);
StmtBlock mlir::edsc::block(ArrayRef<Bindable> args, ArrayRef<Stmt> stmts) {
return StmtBlock(args, stmts);
}
edsc_block_t Block(edsc_stmt_list_t enclosedStmts) {
edsc_block_t Block(edsc_expr_list_t arguments, edsc_stmt_list_t enclosedStmts) {
llvm::SmallVector<Stmt, 8> stmts;
fillStmts(enclosedStmts, &stmts);
return StmtBlock(stmts);
llvm::SmallVector<Bindable, 8> args;
for (uint64_t i = 0; i < arguments.n; ++i)
args.emplace_back(Expr(arguments.exprs[i]));
return StmtBlock(args, stmts);
}
edsc_block_t BlockSetBody(edsc_block_t block, edsc_stmt_list_t stmts) {
llvm::SmallVector<Stmt, 8> body;
fillStmts(stmts, &body);
StmtBlock(block).set(body);
return block;
}
Expr mlir::edsc::load(Expr m, ArrayRef<Expr> indices) {
@ -593,6 +642,13 @@ edsc_stmt_t Return(edsc_expr_list_t values) {
return Stmt(Return(makeExprs(values)));
}
Stmt mlir::edsc::Branch(StmtBlock destination, ArrayRef<Expr> args) {
SmallVector<Expr, 4> arguments;
arguments.push_back(nullptr);
arguments.insert(arguments.end(), args.begin(), args.end());
return VariadicExpr::make<BranchOp>(arguments, {}, {}, {destination});
}
static raw_ostream &printBinaryExpr(raw_ostream &os, BinaryExpr e,
StringRef infix) {
os << '(' << e.getLHS() << ' ' << infix << ' ' << e.getRHS() << ')';
@ -701,7 +757,12 @@ void printAffineApply(raw_ostream &os, mlir::edsc::Expr e) {
assert(mapAttr && "expected a map in an affine apply expr");
printAffineMap(os, mapAttr.cast<AffineMapAttr>().getValue(),
e.getChildExpressions());
e.getProperArguments());
}
edsc_stmt_t Branch(edsc_block_t destination, edsc_expr_list_t arguments) {
auto args = makeExprs(arguments);
return mlir::edsc::Branch(StmtBlock(destination), args);
}
void mlir::edsc::Expr::print(raw_ostream &os) const {
@ -737,15 +798,15 @@ void mlir::edsc::Expr::print(raw_ostream &os) const {
// Handle known variadic ops with pretty forms.
if (auto narExpr = this->dyn_cast<VariadicExpr>()) {
if (narExpr.is_op<LoadOp>()) {
os << narExpr.getName() << '(' << getChildExpressions().front() << '[';
interleaveComma(getChildExpressions().drop_front(), os);
os << narExpr.getName() << '(' << getProperArguments().front() << '[';
interleaveComma(getProperArguments().drop_front(), os);
os << "])";
return;
}
if (narExpr.is_op<StoreOp>()) {
os << narExpr.getName() << '(' << getChildExpressions().front() << ", "
<< getChildExpressions()[1] << '[';
interleaveComma(getChildExpressions().drop_front(2), os);
os << narExpr.getName() << '(' << getProperArguments().front() << ", "
<< getProperArguments()[1] << '[';
interleaveComma(getProperArguments().drop_front(2), os);
os << "])";
return;
}
@ -756,11 +817,21 @@ void mlir::edsc::Expr::print(raw_ostream &os) const {
return;
}
if (narExpr.is_op<CallIndirectOp>()) {
os << '@' << getChildExpressions().front() << '(';
interleaveComma(getChildExpressions().drop_front(), os);
os << '@' << getProperArguments().front() << '(';
interleaveComma(getProperArguments().drop_front(), os);
os << ')';
return;
}
if (narExpr.is_op<BranchOp>()) {
os << "br ^bb" << narExpr.getSuccessors().front().getId();
auto blockArgs = getSuccessorArguments(0);
if (!blockArgs.empty())
os << '(';
interleaveComma(blockArgs, os);
if (!blockArgs.empty())
os << ")";
return;
}
}
// Special case for integer constants that are printed as is. Use
@ -778,7 +849,26 @@ void mlir::edsc::Expr::print(raw_ostream &os) const {
if (this->isa<UnaryExpr>() || this->isa<BinaryExpr>() ||
this->isa<TernaryExpr>() || this->isa<VariadicExpr>()) {
os << (getName().empty() ? "##unknown##" : getName()) << '(';
interleaveComma(getChildExpressions(), os);
interleaveComma(getProperArguments(), os);
auto successors = getSuccessors();
if (!successors.empty()) {
os << '[';
interleave(
llvm::zip(successors, getSuccessorArguments()),
[&os](const std::tuple<const StmtBlock &, const ArrayRef<Expr> &>
&pair) {
const auto &block = std::get<0>(pair);
ArrayRef<Expr> operands = std::get<1>(pair);
os << "^bb" << block.getId();
if (!operands.empty()) {
os << '(';
interleaveComma(operands, os);
os << ')';
}
},
[&os]() { os << ", "; });
os << ']';
}
auto attrs = getAttributes();
if (!attrs.empty()) {
os << '{';
@ -797,7 +887,7 @@ void mlir::edsc::Expr::print(raw_ostream &os) const {
// We only print the lb, ub and step here, which are the StmtBlockLike
// part of the `for` StmtBlockLikeExpr.
case ExprKind::For: {
auto exprGroups = stmtLikeExpr.getExprGroups();
auto exprGroups = stmtLikeExpr.getAllArgumentGroups();
assert(exprGroups.size() == 3 &&
"For StmtBlockLikeExpr expected 3 groups");
assert(exprGroups[2].size() == 1 && "expected 1 expr for loop step");
@ -885,17 +975,21 @@ Expr mlir::edsc::TernaryExpr::getRHS() const {
mlir::edsc::VariadicExpr::VariadicExpr(StringRef name, ArrayRef<Expr> exprs,
ArrayRef<Type> types,
ArrayRef<NamedAttribute> attrs)
ArrayRef<NamedAttribute> attrs,
ArrayRef<StmtBlock> succ)
: Expr(Expr::globalAllocator()->Allocate<detail::ExprStorage>()) {
// Initialize with placement new.
new (storage)
detail::ExprStorage(ExprKind::Variadic, name, types, exprs, attrs);
detail::ExprStorage(ExprKind::Variadic, name, types, exprs, attrs, succ);
}
ArrayRef<Expr> mlir::edsc::VariadicExpr::getExprs() const {
return static_cast<ImplType *>(storage)->operands;
return storage->operands;
}
ArrayRef<Type> mlir::edsc::VariadicExpr::getTypes() const {
return static_cast<ImplType *>(storage)->resultTypes;
return storage->resultTypes;
}
ArrayRef<StmtBlock> mlir::edsc::VariadicExpr::getSuccessors() const {
return storage->successors;
}
mlir::edsc::StmtBlockLikeExpr::StmtBlockLikeExpr(ExprKind kind,
@ -905,24 +999,56 @@ mlir::edsc::StmtBlockLikeExpr::StmtBlockLikeExpr(ExprKind kind,
// Initialize with placement new.
new (storage) detail::ExprStorage(kind, "", types, exprs, {});
}
ArrayRef<Expr> mlir::edsc::StmtBlockLikeExpr::getExprs() const {
return static_cast<ImplType *>(storage)->operands;
}
SmallVector<ArrayRef<Expr>, 4>
mlir::edsc::StmtBlockLikeExpr::getExprGroups() const {
SmallVector<ArrayRef<Expr>, 4> groups;
ArrayRef<Expr> exprs = getExprs();
int start = 0;
for (int i = 0, e = exprs.size(); i < e; ++i) {
if (!exprs[i]) {
groups.push_back(exprs.slice(start, i - start));
start = i + 1;
}
static ArrayRef<Expr> getOneArgumentGroupStartingFrom(int start,
ExprStorage *storage) {
for (int i = start, e = storage->operands.size(); i < e; ++i) {
if (!storage->operands[i])
return storage->operands.slice(start, i - start);
}
return storage->operands.drop_front(start);
}
static SmallVector<ArrayRef<Expr>, 4>
getAllArgumentGroupsStartingFrom(int start, ExprStorage *storage) {
SmallVector<ArrayRef<Expr>, 4> groups;
while (start < storage->operands.size()) {
auto group = getOneArgumentGroupStartingFrom(start, storage);
start += group.size() + 1;
groups.push_back(group);
}
groups.push_back(exprs.slice(start, exprs.size() - start));
return groups;
}
ArrayRef<Expr> mlir::edsc::Expr::getProperArguments() const {
return getOneArgumentGroupStartingFrom(0, storage);
}
SmallVector<ArrayRef<Expr>, 4> mlir::edsc::Expr::getSuccessorArguments() const {
// Skip the first group containing proper arguments.
// Note that +1 to size is necessary to step over the nullptrs in the list.
int start = getOneArgumentGroupStartingFrom(0, storage).size() + 1;
return getAllArgumentGroupsStartingFrom(start, storage);
}
ArrayRef<Expr> mlir::edsc::Expr::getSuccessorArguments(int index) const {
assert(index >= 0 && "argument group index is out of bounds");
assert(!storage->operands.empty() && "argument list is empty");
// Skip over the first index + 1 groups (also includes proper arguments).
int start = 0;
for (int i = 0, e = index + 1; i < e; ++i) {
assert(start < storage->operands.size() &&
"argument group index is out of bounds");
start += getOneArgumentGroupStartingFrom(start, storage).size() + 1;
}
return getOneArgumentGroupStartingFrom(start, storage);
}
SmallVector<ArrayRef<Expr>, 4> mlir::edsc::Expr::getAllArgumentGroups() const {
return getAllArgumentGroupsStartingFrom(0, storage);
}
mlir::edsc::Stmt::Stmt(const Bindable &lhs, const Expr &rhs,
llvm::ArrayRef<Stmt> enclosedStmts) {
storage = Expr::globalAllocator()->Allocate<detail::StmtStorage>();
@ -1012,15 +1138,29 @@ llvm::raw_ostream &mlir::edsc::operator<<(llvm::raw_ostream &os,
}
mlir::edsc::StmtBlock::StmtBlock(llvm::ArrayRef<Stmt> stmts)
: StmtBlock({}, {}, stmts) {}
: StmtBlock({}, stmts) {}
mlir::edsc::StmtBlock::StmtBlock(llvm::ArrayRef<Bindable> args,
llvm::ArrayRef<Type> argTypes,
llvm::ArrayRef<Stmt> stmts) {
// Extract block argument types from bindable types.
// Bindables must have a single type.
llvm::SmallVector<Type, 8> argTypes;
argTypes.reserve(args.size());
for (Bindable arg : args) {
auto argResults = arg.getResultTypes();
assert(argResults.size() == 1 &&
"only single-result expressions are supported");
argTypes.push_back(argResults.front());
}
storage = Expr::globalAllocator()->Allocate<detail::StmtBlockStorage>();
new (storage) detail::StmtBlockStorage(args, argTypes, stmts);
}
mlir::edsc::StmtBlock &mlir::edsc::StmtBlock::operator=(ArrayRef<Stmt> stmts) {
storage->replaceStmts(stmts);
return *this;
}
ArrayRef<mlir::edsc::Bindable> mlir::edsc::StmtBlock::getArguments() const {
return storage->arguments;
}
@ -1033,17 +1173,20 @@ ArrayRef<mlir::edsc::Stmt> mlir::edsc::StmtBlock::getBody() const {
return storage->statements;
}
uint64_t mlir::edsc::StmtBlock::getId() const { return storage->id; }
void mlir::edsc::StmtBlock::print(llvm::raw_ostream &os, Twine indent) const {
os << indent << "^bb";
os << indent << "^bb" << storage->id;
if (!getArgumentTypes().empty())
os << '(';
interleaveComma(getArguments(), os);
if (!getArgumentTypes().empty())
os << ')';
os << ":\n";
for (auto stmt : getBody())
for (auto stmt : getBody()) {
stmt.print(os, indent + " ");
os << '\n';
}
}
std::string mlir::edsc::StmtBlock::str() const {

View File

@ -10,16 +10,19 @@
// CHECK-DAG: #[[id2dmap:.*]] = (d0, d1) -> (d0, d1)
// This function will be detected by the test pass that will insert
// EDSC-constructed blocks with arguments.
// EDSC-constructed blocks with arguments forming an infinite loop.
// CHECK-LABEL: @blocks
func @blocks() {
return
//CHECK: ^bb1(%0: i32, %1: i32): // no predecessors
//CHECK-NEXT: %2 = addi %0, %1 : i32
//CHECK-NEXT: return
//CHECK: ^bb2(%3: i32, %4: i32): // no predecessors
//CHECK-NEXT: %5 = subi %3, %4 : i32
//CHECK-NEXT: return
//CHECK: %c42_i32 = constant 42 : i32
//CHECK-NEXT: %c1234_i32 = constant 1234 : i32
//CHECK-NEXT: br ^bb1(%c42_i32, %c1234_i32 : i32, i32)
//CHECK-NEXT: ^bb1(%0: i32, %1: i32): // 2 preds: ^bb0, ^bb2
//CHECK-NEXT: br ^bb2(%0, %1 : i32, i32)
//CHECK-NEXT: ^bb2(%2: i32, %3: i32): // pred: ^bb1
//CHECK-NEXT: %4 = addi %2, %3 : i32
//CHECK-NEXT: br ^bb1(%2, %4 : i32, i32)
//CHECK-NEXT: }
}
// This function will be detected by the test pass that will insert an