From 83e8db2193bbbd5e7aacc6e2f318f648bbd93da4 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 25 Feb 2019 09:16:30 -0800 Subject: [PATCH] EDSC: support branch instructions The new implementation of blocks was designed to support blocks with arguments. More specifically, StmtBlock can be constructed with a list of Bindables that will be bound to block aguments upon construction. Leverage this functionality to implement branch instructions with arguments. This additionally requires the statement storage to have a list of successors, similarly to core IR operations. Becauase successor chains can form loops, we need a possibility to decouple block declaration, after which it becomes usable by branch instructions, from block body definition. This is achieved by creating an empty block and by resetting its body with a new list of instructions. Note that assigning a block from another block will not affect any instructions that may have designated this block as their successor (this behavior is necessary to make value-type semantics of EDSC types consistent). Combined, one can now write generators like EDSCContext context; Type indexType = ...; Bindable i(indexType), ii(indexType), zero(indexType), one(indexType); StmtBlock loopBlock({i}, {}); loopBlock.set({ii = i + one, Branch(loopBlock, {ii})}); MLIREmitter(&builder) .bindConstant(zero, 0) .bindConstant(one, 1) .emitStmt(Branch(loopBlock, {zero})); where the emitter will emit the statement and its successors, if present. PiperOrigin-RevId: 235541892 --- mlir/bindings/python/pybind.cpp | 33 ++- mlir/bindings/python/test/test_py2and3.py | 39 +++- mlir/include/mlir-c/Core.h | 11 +- mlir/include/mlir/EDSC/Types.h | 68 +++--- mlir/lib/EDSC/LowerEDSCTestPass.cpp | 36 ++- mlir/lib/EDSC/MLIREmitter.cpp | 18 +- mlir/lib/EDSC/Types.cpp | 261 +++++++++++++++++----- mlir/test/EDSC/for-loops.mlir | 17 +- 8 files changed, 377 insertions(+), 106 deletions(-) diff --git a/mlir/bindings/python/pybind.cpp b/mlir/bindings/python/pybind.cpp index ef2a4d92dcc8..9de78d865644 100644 --- a/mlir/bindings/python/pybind.cpp +++ b/mlir/bindings/python/pybind.cpp @@ -241,6 +241,8 @@ struct PythonBlock { return StmtBlock(blk).str(); } + PythonBlock set(const py::list &stmts); + edsc_block_t blk; }; @@ -317,6 +319,14 @@ static edsc_expr_list_t makeCExprs(llvm::SmallVectorImpl &owning, return edsc_expr_list_t{owning.data(), owning.size()}; } +static mlir_type_list_t makeCTypes(llvm::SmallVectorImpl &owning, + const py::list &types) { + for (auto &inp : types) { + owning.push_back(mlir_type_t{inp.cast()}); + } + return mlir_type_list_t{owning.data(), owning.size()}; +} + PythonExpr::PythonExpr(const PythonBindable &bindable) : expr{bindable.expr} {} PythonExpr MLIRFunctionEmitter::bindConstantBF16(double value) { @@ -410,6 +420,12 @@ void MLIRFunctionEmitter::emitBlockBody(PythonBlock block) { emitter.emitStmts(StmtBlock(block).getBody()); } +PythonBlock PythonBlock::set(const py::list &stmts) { + SmallVector owning; + ::BlockSetBody(blk, makeCStmts(owning, stmts)); + return *this; +} + PythonExpr dispatchCall(py::args args, py::kwargs kwargs) { assert(args.size() != 0); llvm::SmallVector exprs; @@ -438,10 +454,24 @@ PYBIND11_MODULE(pybind, m) { m.def("deleteContext", [](void *ctx) { delete reinterpret_cast(ctx); }); + m.def("Block", [](const py::list &args, const py::list &stmts) { + SmallVector owning; + SmallVector owningArgs; + return PythonBlock( + ::Block(makeCExprs(owningArgs, args), makeCStmts(owning, stmts))); + }); m.def("Block", [](const py::list &stmts) { SmallVector owning; - return PythonBlock(::Block(makeCStmts(owning, stmts))); + edsc_expr_list_t args{nullptr, 0}; + return PythonBlock(::Block(args, makeCStmts(owning, stmts))); }); + m.def( + "Branch", + [](PythonBlock destination, const py::list &operands) { + SmallVector owning; + return PythonStmt(::Branch(destination, makeCExprs(owning, operands))); + }, + py::arg("destination"), py::arg("operands") = py::list()); m.def("For", [](const py::list &ivs, const py::list &lbs, const py::list &ubs, const py::list &steps, const py::list &stmts) { SmallVector owningIVs; @@ -527,6 +557,7 @@ PYBIND11_MODULE(pybind, m) { py::class_(m, "StmtBlock", "Wrapping class for mlir::edsc::StmtBlock") .def(py::init()) + .def("set", &PythonBlock::set) .def("__str__", &PythonBlock::str); py::class_(m, "Type", "Wrapping class for mlir::Type.") diff --git a/mlir/bindings/python/test/test_py2and3.py b/mlir/bindings/python/test/test_py2and3.py index c6ab3ff25714..314e1b55b2e3 100644 --- a/mlir/bindings/python/test/test_py2and3.py +++ b/mlir/bindings/python/test/test_py2and3.py @@ -151,10 +151,47 @@ class EdscTest(unittest.TestCase): i, j = list(map(E.Expr, [E.Bindable(self.f32Type) for _ in range(2)])) stmt = E.Block([E.Stmt(i + j), E.Stmt(i - j)]) str = stmt.__str__() - self.assertIn("^bb:", str) + self.assertIn("^bb", str) self.assertIn(" = ($1 + $2)", str) self.assertIn(" = ($1 - $2)", str) + def testBlockArgs(self): + with E.ContextManager(): + module = E.MLIRModule() + t = module.make_scalar_type("i", 32) + i, j = list(map(E.Expr, [E.Bindable(t) for _ in range(2)])) + stmt = E.Block([i, j], [E.Stmt(i + j)]) + str = stmt.__str__() + self.assertIn("^bb", str) + self.assertIn("($1, $2):", str) + self.assertIn("($1 + $2)", str) + + def testBranch(self): + with E.ContextManager(): + i, j = list(map(E.Expr, [E.Bindable(self.i32Type) for _ in range(2)])) + b1 = E.Block([E.Stmt(i + j)]) + b2 = E.Block([E.Branch(b1)]) + str1 = b1.__str__() + str2 = b2.__str__() + self.assertIn("^bb1:\n" + "$4 = ($1 + $2)", str1) + self.assertIn("^bb2:\n" + "$6 = br ^bb1", str2) + + def testBranchArgs(self): + with E.ContextManager(): + b1arg, b2arg = (E.Expr(E.Bindable(self.i32Type)) for _ in range(2)) + # Declare empty blocks with arguments and bind those arguments. + b1 = E.Block([b1arg], []) + b2 = E.Block([b2arg], []) + one = E.ConstantInteger(self.i32Type, 1) + # Make blocks branch to each other in a sort of infinite loop. + # This checks that the EDSC implementation does not fall into such loop. + b1.set([E.Branch(b2, [b1arg + one])]) + b2.set([E.Branch(b1, [b2arg])]) + str1 = b1.__str__() + str2 = b2.__str__() + self.assertIn("^bb1($1):\n" + "$6 = br ^bb2(($1 + 1))", str1) + self.assertIn("^bb2($2):\n" + "$8 = br ^bb1($2)", str2) + def testMLIRScalarTypes(self): module = E.MLIRModule() t = module.make_scalar_type("bf16") diff --git a/mlir/include/mlir-c/Core.h b/mlir/include/mlir-c/Core.h index 09dc5c850e27..44ab5cb61beb 100644 --- a/mlir/include/mlir-c/Core.h +++ b/mlir/include/mlir-c/Core.h @@ -230,8 +230,15 @@ edsc_expr_t ConstantInteger(mlir_type_t type, int64_t value); edsc_stmt_t Return(edsc_expr_list_t values); /// Returns an opaque expression for an mlir::edsc::StmtBlock containing the -/// given list of statements. Block arguments are not currently supported. -edsc_block_t Block(edsc_stmt_list_t enclosedStmts); +/// given list of statements. +edsc_block_t Block(edsc_expr_list_t arguments, edsc_stmt_list_t enclosedStmts); + +/// Set the body of the block to the given statements and return the block. +edsc_block_t BlockSetBody(edsc_block_t, edsc_stmt_list_t stmts); + +/// Returns an opaque statement branching to `destination` and passing +/// `arguments` as block arguments. +edsc_stmt_t Branch(edsc_block_t destination, edsc_expr_list_t arguments); /// Returns an opaque statement for an mlir::AffineForOp with `enclosedStmts` /// nested below it. diff --git a/mlir/include/mlir/EDSC/Types.h b/mlir/include/mlir/EDSC/Types.h index bb333a93938c..9f29ab95bd0f 100644 --- a/mlir/include/mlir/EDSC/Types.h +++ b/mlir/include/mlir/EDSC/Types.h @@ -48,6 +48,8 @@ struct StmtBlockStorage; } // namespace detail +class StmtBlock; + /// EDSC Types closely mirror the core MLIR and uses an abstraction similar to /// AffineExpr: /// 1. a set of composable structs; @@ -164,7 +166,20 @@ public: ArrayRef getResultTypes() const; /// Returns the list of expressions used as arguments of this expression. - ArrayRef getChildExpressions() const; + ArrayRef getProperArguments() const; + + /// Returns the list of lists of expressions used as arguments of successors + /// of this expression (i.e., arguments passed to destination basic blocks in + /// terminator statements). + SmallVector, 4> getSuccessorArguments() const; + + /// Returns the list of expressions used as arguments of the `index`-th + /// successor of this expression. + ArrayRef getSuccessorArguments(int index) const; + + /// Returns the list of argument groups (includes the proper argument group, + /// followed by successor/block argument groups). + SmallVector, 4> getAllArgumentGroups() const; /// Returns the list of attributes of this expression. ArrayRef getAttributes() const; @@ -172,9 +187,13 @@ public: /// Returns the attribute with the given name, if any. Attribute getAttribute(StringRef name) const; + /// Returns the list of successors (StmtBlocks) of this expression. + ArrayRef getSuccessors() const; + /// Build the IR corresponding to this expression. SmallVector - build(FuncBuilder &b, const llvm::DenseMap &ssaBindings) const; + build(FuncBuilder &b, const llvm::DenseMap &ssaBindings, + const llvm::DenseMap &blockBindings) const; void print(raw_ostream &os) const; void dump() const; @@ -267,15 +286,18 @@ struct VariadicExpr : public Expr { friend class Expr; VariadicExpr(StringRef name, llvm::ArrayRef exprs, llvm::ArrayRef types = {}, - ArrayRef attrs = {}); + llvm::ArrayRef attrs = {}, + llvm::ArrayRef succ = {}); llvm::ArrayRef getExprs() const; llvm::ArrayRef getTypes() const; + llvm::ArrayRef getSuccessors() const; template static VariadicExpr make(llvm::ArrayRef exprs, llvm::ArrayRef types = {}, - llvm::ArrayRef attrs = {}) { - return VariadicExpr(T::getOperationName(), exprs, types, attrs); + llvm::ArrayRef attrs = {}, + llvm::ArrayRef succ = {}) { + return VariadicExpr(T::getOperationName(), exprs, types, attrs, succ); } protected: @@ -289,18 +311,6 @@ struct StmtBlockLikeExpr : public Expr { StmtBlockLikeExpr(ExprKind kind, llvm::ArrayRef exprs, llvm::ArrayRef types = {}); - /// Get the list of subexpressions. - /// StmtBlockLikeExprs can contain multiple groups of subexpressions separated - /// by null expressions and the result of this call will include them. - llvm::ArrayRef getExprs() const; - - /// Get the list of subexpression groups. - /// StmtBlockLikeExprs can contain multiple groups of subexpressions separated - /// by null expressions. This will identify those groups and return a list - /// of lists of subexpressions split around null expressions. Two null - /// expressions in a row identify an empty group. - SmallVector, 4> getExprGroups() const; - protected: StmtBlockLikeExpr(Expr::ImplType *ptr) : Expr(ptr) { assert(!ptr || isa() && "expected StmtBlockLikeExpr"); @@ -399,19 +409,23 @@ public: explicit StmtBlock(edsc_block_t st) : storage(reinterpret_cast(st)) {} StmtBlock(const StmtBlock &other) = default; - StmtBlock(llvm::ArrayRef stmts = {}); - StmtBlock(llvm::ArrayRef args, llvm::ArrayRef argTypes, - llvm::ArrayRef stmts = {}); + StmtBlock(llvm::ArrayRef stmts); + StmtBlock(llvm::ArrayRef args, llvm::ArrayRef stmts = {}); llvm::ArrayRef getArguments() const; llvm::ArrayRef getArgumentTypes() const; llvm::ArrayRef getBody() const; + uint64_t getId() const; void print(llvm::raw_ostream &os, Twine indent) const; std::string str() const; operator edsc_block_t() { return edsc_block_t{storage}; } + /// Reset the body of this block with the given list of statements. + StmtBlock &operator=(llvm::ArrayRef stmts); + void set(llvm::ArrayRef stmts) { *this = stmts; } + ImplType *getStoragePtr() const { return storage; } private: @@ -619,6 +633,8 @@ Expr call(Expr func, Type result, llvm::ArrayRef args); Expr call(Expr func, llvm::ArrayRef args); Stmt Return(ArrayRef values = {}); +Stmt Branch(StmtBlock destination, ArrayRef args = {}); + Stmt For(Expr lb, Expr ub, Expr step, llvm::ArrayRef enclosedStmts); Stmt For(const Bindable &idx, Expr lb, Expr ub, Expr step, llvm::ArrayRef enclosedStmts); @@ -633,11 +649,13 @@ Stmt For(llvm::ArrayRef indices, llvm::ArrayRef lbs, Stmt MaxMinFor(const Bindable &idx, ArrayRef lbs, ArrayRef ubs, Expr step, ArrayRef enclosedStmts); -StmtBlock block(llvm::ArrayRef args, llvm::ArrayRef argTypes, - llvm::ArrayRef stmts); -inline StmtBlock block(llvm::ArrayRef stmts) { - return block({}, {}, stmts); -} +/// Define an MLIR Block and bind its arguments to `args`. The types of block +/// arguments are those of `args`, each of which must have exactly one result +/// type. The body of the block may be empty and can be reset later. +StmtBlock block(llvm::ArrayRef args, llvm::ArrayRef stmts); +/// Define an MLIR Block without arguments. The body of the block can be empty +/// and can be reset later. +inline StmtBlock block(llvm::ArrayRef stmts) { return block({}, stmts); } /// This helper class exists purely for sugaring purposes and allows writing /// expressions such as: diff --git a/mlir/lib/EDSC/LowerEDSCTestPass.cpp b/mlir/lib/EDSC/LowerEDSCTestPass.cpp index 67d4fb380807..9ea9e1c392d1 100644 --- a/mlir/lib/EDSC/LowerEDSCTestPass.cpp +++ b/mlir/lib/EDSC/LowerEDSCTestPass.cpp @@ -42,21 +42,43 @@ struct LowerEDSCTestPass : public FunctionPass { #include "mlir/EDSC/reference-impl.inc" PassResult LowerEDSCTestPass::runOnFunction(Function *f) { - // Inject a EDSC-constructed list of blocks. + // Inject a EDSC-constructed infinite loop implemented by mutual branching + // between two blocks, following the pattern: + // + // br ^bb1 + // ^bb1: + // br ^bb2 + // ^bb2: + // br ^bb1 + // + // Use blocks with arguments. if (f->getName().strref() == "blocks") { using namespace edsc::op; FuncBuilder builder(f); edsc::ScopedEDSCContext context; + // Declare two blocks. Note that we must declare the blocks before creating + // branches to them. auto type = builder.getIntegerType(32); - edsc::Expr arg1(type), arg2(type), arg3(type), arg4(type); + edsc::Expr arg1(type), arg2(type), arg3(type), arg4(type), r(type); + edsc::StmtBlock b1 = edsc::block({arg1, arg2}, {}), + b2 = edsc::block({arg3, arg4}, {}); + auto c1 = edsc::constantInteger(type, 42); + auto c2 = edsc::constantInteger(type, 1234); - auto b1 = - edsc::block({arg1, arg2}, {type, type}, {arg1 + arg2, edsc::Return()}); - auto b2 = - edsc::block({arg3, arg4}, {type, type}, {arg3 - arg4, edsc::Return()}); + // Make an infinite loops by branching between the blocks. Note that copy- + // assigning a block won't work well with branches, update the body instead. + b1.set({r = arg1 + arg2, edsc::Branch(b2, {arg1, r})}); + b2.set({edsc::Branch(b1, {arg3, arg4})}); + auto instr = edsc::Branch(b2, {c1, c2}); - edsc::MLIREmitter(&builder, f->getLoc()).emitBlock(b1).emitBlock(b2); + // Remove the existing 'return' from the function, reset the builder after + // the instruction iterator invalidation and emit a branch to b2. This + // should also emit blocks b2 and b1 that appear as successors to the + // current block after the branch instruction is insterted. + f->begin()->clear(); + builder.setInsertionPoint(&*f->begin(), f->begin()->begin()); + edsc::MLIREmitter(&builder, f->getLoc()).emitStmt(instr); } // Inject a EDSC-constructed `for` loop with bounds coming from function diff --git a/mlir/lib/EDSC/MLIREmitter.cpp b/mlir/lib/EDSC/MLIREmitter.cpp index aa4ca47e6609..d52fdcd57c57 100644 --- a/mlir/lib/EDSC/MLIREmitter.cpp +++ b/mlir/lib/EDSC/MLIREmitter.cpp @@ -129,7 +129,16 @@ Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) { bool expectedEmpty = false; if (e.isa() || e.isa() || e.isa() || e.isa()) { - auto results = e.build(*builder, ssaBindings); + // Emit any successors before the instruction with successors. At this + // point, all values defined by the current block must have been bound, the + // current instruction with successors cannot define new values, so the + // successor can use those values. + assert(e.getSuccessors().empty() || e.getResultTypes().empty() && + "an operation with successors must " + "not have results and vice versa"); + for (StmtBlock block : e.getSuccessors()) + emitBlock(block); + auto results = e.build(*builder, ssaBindings, blockBindings); assert(results.size() <= 1 && "2+-result exprs are not supported"); expectedEmpty = results.empty(); if (!results.empty()) @@ -138,7 +147,7 @@ Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) { if (auto expr = e.dyn_cast()) { if (expr.getKind() == ExprKind::For) { - auto exprGroups = expr.getExprGroups(); + auto exprGroups = expr.getAllArgumentGroups(); assert(exprGroups.size() == 3 && "expected 3 expr groups in `for`"); assert(!exprGroups[0].empty() && "expected at least one lower bound"); assert(!exprGroups[1].empty() && "expected at least one upper bound"); @@ -213,8 +222,9 @@ mlir::edsc::MLIREmitter &mlir::edsc::MLIREmitter::emitStmt(const Stmt &stmt) { if (!val) { assert((stmt.getRHS().is_op() || stmt.getRHS().is_op() || stmt.getRHS().is_op() || - stmt.getRHS().is_op()) && - "dealloc, store, return or call_indirect expected as the only " + stmt.getRHS().is_op() || + stmt.getRHS().is_op()) && + "dealloc, store, return, br, or call_indirect expected as the only " "0-result ops"); if (stmt.getRHS().is_op()) { assert( diff --git a/mlir/lib/EDSC/Types.cpp b/mlir/lib/EDSC/Types.cpp index 39f09be8c5f6..9e1a9795823d 100644 --- a/mlir/lib/EDSC/Types.cpp +++ b/mlir/lib/EDSC/Types.cpp @@ -64,17 +64,23 @@ struct ExprStorage { unsigned id; StringRef opName; + + // Exprs can contain multiple groups of operands separated by null + // expressions. Two null expressions in a row identify an empty group. ArrayRef operands; + ArrayRef resultTypes; ArrayRef attributes; + ArrayRef successors; ExprStorage(ExprKind kind, StringRef name, ArrayRef results, ArrayRef children, ArrayRef attrs, - StringRef descr = "", unsigned exprId = Expr::newId()) + ArrayRef succ = {}, unsigned exprId = Expr::newId()) : kind(kind), id(exprId) { operands = copyIntoExprAllocator(children); resultTypes = copyIntoExprAllocator(results); attributes = copyIntoExprAllocator(attrs); + successors = copyIntoExprAllocator(succ); if (!name.empty()) { auto nameStorage = Expr::globalAllocator()->Allocate(name.size()); std::uninitialized_copy(name.begin(), name.end(), nameStorage); @@ -94,11 +100,24 @@ struct StmtStorage { struct StmtBlockStorage { StmtBlockStorage(ArrayRef args, ArrayRef argTypes, ArrayRef stmts) { + id = nextId(); arguments = copyIntoExprAllocator(args); argumentTypes = copyIntoExprAllocator(argTypes); statements = copyIntoExprAllocator(stmts); } + void replaceStmts(ArrayRef stmts) { + Expr::globalAllocator()->Deallocate(statements.data(), statements.size()); + statements = copyIntoExprAllocator(stmts); + } + + static uint64_t &nextId() { + static thread_local uint64_t next = 0; + return ++next; + } + static void resetIds() { nextId() = 0; } + + uint64_t id; ArrayRef arguments; ArrayRef argumentTypes; ArrayRef statements; @@ -111,6 +130,7 @@ struct StmtBlockStorage { mlir::edsc::ScopedEDSCContext::ScopedEDSCContext() { Expr::globalAllocator() = &allocator; Bindable::resetIds(); + StmtBlockStorage::resetIds(); } mlir::edsc::ScopedEDSCContext::~ScopedEDSCContext() { @@ -138,10 +158,6 @@ ArrayRef mlir::edsc::Expr::getResultTypes() const { return storage->resultTypes; } -ArrayRef mlir::edsc::Expr::getChildExpressions() const { - return storage->operands; -} - ArrayRef mlir::edsc::Expr::getAttributes() const { return storage->attributes; } @@ -153,26 +169,38 @@ Attribute mlir::edsc::Expr::getAttribute(StringRef name) const { return {}; } +ArrayRef mlir::edsc::Expr::getSuccessors() const { + return storage->successors; +} + StringRef mlir::edsc::Expr::getName() const { return static_cast(storage)->opName; } SmallVector -Expr::build(FuncBuilder &b, - const llvm::DenseMap &ssaBindings) const { +buildExprs(ArrayRef exprs, FuncBuilder &b, + const llvm::DenseMap &ssaBindings, + const llvm::DenseMap &blockBindings) { + SmallVector values; + values.reserve(exprs.size()); + for (auto child : exprs) { + auto subResults = child.build(b, ssaBindings, blockBindings); + assert(subResults.size() == 1 && + "expected single-result expression as operand"); + values.push_back(subResults.front()); + } + return values; +} + +SmallVector +Expr::build(FuncBuilder &b, const llvm::DenseMap &ssaBindings, + const llvm::DenseMap &blockBindings) const { auto it = ssaBindings.find(*this); if (it != ssaBindings.end()) return {it->second}; - auto *impl = static_cast(storage); - SmallVector operandValues; - operandValues.reserve(impl->operands.size()); - for (auto child : impl->operands) { - auto subResults = child.build(b, ssaBindings); - assert(subResults.size() == 1 && - "expected single-result expression as operand"); - operandValues.push_back(subResults.front()); - } + SmallVector operandValues = + buildExprs(getProperArguments(), b, ssaBindings, blockBindings); // Special case for emitting composed affine.applies. // FIXME: this should not be a special case, instead, define composed form as @@ -185,12 +213,24 @@ Expr::build(FuncBuilder &b, return {affInstr->getResult()}; } - auto state = OperationState(b.getContext(), b.getUnknownLoc(), impl->opName); + auto state = OperationState(b.getContext(), b.getUnknownLoc(), getName()); state.addOperands(operandValues); - state.addTypes(impl->resultTypes); - for (const auto &attr : impl->attributes) + state.addTypes(getResultTypes()); + for (const auto &attr : getAttributes()) state.addAttribute(attr.first, attr.second); + auto successors = getSuccessors(); + auto successorArgs = getSuccessorArguments(); + assert(successors.size() == successorArgs.size() && + "expected all successors to have a corresponding operand group"); + for (int i = 0, e = successors.size(); i < e; ++i) { + StmtBlock block = successors[i]; + assert(blockBindings.count(block) != 0 && "successor block does not exist"); + state.addSuccessor( + blockBindings.lookup(block), + buildExprs(successorArgs[i], b, ssaBindings, blockBindings)); + } + Instruction *inst = b.createOperation(state); return llvm::to_vector<4>(inst->getResults()); } @@ -499,17 +539,26 @@ edsc_stmt_t MaxMinFor(edsc_expr_t iv, edsc_max_expr_t lb, edsc_min_expr_t ub, Expr(step), stmts)); } -StmtBlock mlir::edsc::block(ArrayRef args, ArrayRef argTypes, - ArrayRef stmts) { - assert(args.size() == argTypes.size() && - "mismatching number of arguments and argument types"); - return StmtBlock(args, argTypes, stmts); +StmtBlock mlir::edsc::block(ArrayRef args, ArrayRef stmts) { + return StmtBlock(args, stmts); } -edsc_block_t Block(edsc_stmt_list_t enclosedStmts) { +edsc_block_t Block(edsc_expr_list_t arguments, edsc_stmt_list_t enclosedStmts) { llvm::SmallVector stmts; fillStmts(enclosedStmts, &stmts); - return StmtBlock(stmts); + + llvm::SmallVector args; + for (uint64_t i = 0; i < arguments.n; ++i) + args.emplace_back(Expr(arguments.exprs[i])); + + return StmtBlock(args, stmts); +} + +edsc_block_t BlockSetBody(edsc_block_t block, edsc_stmt_list_t stmts) { + llvm::SmallVector body; + fillStmts(stmts, &body); + StmtBlock(block).set(body); + return block; } Expr mlir::edsc::load(Expr m, ArrayRef indices) { @@ -593,6 +642,13 @@ edsc_stmt_t Return(edsc_expr_list_t values) { return Stmt(Return(makeExprs(values))); } +Stmt mlir::edsc::Branch(StmtBlock destination, ArrayRef args) { + SmallVector arguments; + arguments.push_back(nullptr); + arguments.insert(arguments.end(), args.begin(), args.end()); + return VariadicExpr::make(arguments, {}, {}, {destination}); +} + static raw_ostream &printBinaryExpr(raw_ostream &os, BinaryExpr e, StringRef infix) { os << '(' << e.getLHS() << ' ' << infix << ' ' << e.getRHS() << ')'; @@ -701,7 +757,12 @@ void printAffineApply(raw_ostream &os, mlir::edsc::Expr e) { assert(mapAttr && "expected a map in an affine apply expr"); printAffineMap(os, mapAttr.cast().getValue(), - e.getChildExpressions()); + e.getProperArguments()); +} + +edsc_stmt_t Branch(edsc_block_t destination, edsc_expr_list_t arguments) { + auto args = makeExprs(arguments); + return mlir::edsc::Branch(StmtBlock(destination), args); } void mlir::edsc::Expr::print(raw_ostream &os) const { @@ -737,15 +798,15 @@ void mlir::edsc::Expr::print(raw_ostream &os) const { // Handle known variadic ops with pretty forms. if (auto narExpr = this->dyn_cast()) { if (narExpr.is_op()) { - os << narExpr.getName() << '(' << getChildExpressions().front() << '['; - interleaveComma(getChildExpressions().drop_front(), os); + os << narExpr.getName() << '(' << getProperArguments().front() << '['; + interleaveComma(getProperArguments().drop_front(), os); os << "])"; return; } if (narExpr.is_op()) { - os << narExpr.getName() << '(' << getChildExpressions().front() << ", " - << getChildExpressions()[1] << '['; - interleaveComma(getChildExpressions().drop_front(2), os); + os << narExpr.getName() << '(' << getProperArguments().front() << ", " + << getProperArguments()[1] << '['; + interleaveComma(getProperArguments().drop_front(2), os); os << "])"; return; } @@ -756,11 +817,21 @@ void mlir::edsc::Expr::print(raw_ostream &os) const { return; } if (narExpr.is_op()) { - os << '@' << getChildExpressions().front() << '('; - interleaveComma(getChildExpressions().drop_front(), os); + os << '@' << getProperArguments().front() << '('; + interleaveComma(getProperArguments().drop_front(), os); os << ')'; return; } + if (narExpr.is_op()) { + os << "br ^bb" << narExpr.getSuccessors().front().getId(); + auto blockArgs = getSuccessorArguments(0); + if (!blockArgs.empty()) + os << '('; + interleaveComma(blockArgs, os); + if (!blockArgs.empty()) + os << ")"; + return; + } } // Special case for integer constants that are printed as is. Use @@ -778,7 +849,26 @@ void mlir::edsc::Expr::print(raw_ostream &os) const { if (this->isa() || this->isa() || this->isa() || this->isa()) { os << (getName().empty() ? "##unknown##" : getName()) << '('; - interleaveComma(getChildExpressions(), os); + interleaveComma(getProperArguments(), os); + auto successors = getSuccessors(); + if (!successors.empty()) { + os << '['; + interleave( + llvm::zip(successors, getSuccessorArguments()), + [&os](const std::tuple &> + &pair) { + const auto &block = std::get<0>(pair); + ArrayRef operands = std::get<1>(pair); + os << "^bb" << block.getId(); + if (!operands.empty()) { + os << '('; + interleaveComma(operands, os); + os << ')'; + } + }, + [&os]() { os << ", "; }); + os << ']'; + } auto attrs = getAttributes(); if (!attrs.empty()) { os << '{'; @@ -797,7 +887,7 @@ void mlir::edsc::Expr::print(raw_ostream &os) const { // We only print the lb, ub and step here, which are the StmtBlockLike // part of the `for` StmtBlockLikeExpr. case ExprKind::For: { - auto exprGroups = stmtLikeExpr.getExprGroups(); + auto exprGroups = stmtLikeExpr.getAllArgumentGroups(); assert(exprGroups.size() == 3 && "For StmtBlockLikeExpr expected 3 groups"); assert(exprGroups[2].size() == 1 && "expected 1 expr for loop step"); @@ -885,17 +975,21 @@ Expr mlir::edsc::TernaryExpr::getRHS() const { mlir::edsc::VariadicExpr::VariadicExpr(StringRef name, ArrayRef exprs, ArrayRef types, - ArrayRef attrs) + ArrayRef attrs, + ArrayRef succ) : Expr(Expr::globalAllocator()->Allocate()) { // Initialize with placement new. new (storage) - detail::ExprStorage(ExprKind::Variadic, name, types, exprs, attrs); + detail::ExprStorage(ExprKind::Variadic, name, types, exprs, attrs, succ); } ArrayRef mlir::edsc::VariadicExpr::getExprs() const { - return static_cast(storage)->operands; + return storage->operands; } ArrayRef mlir::edsc::VariadicExpr::getTypes() const { - return static_cast(storage)->resultTypes; + return storage->resultTypes; +} +ArrayRef mlir::edsc::VariadicExpr::getSuccessors() const { + return storage->successors; } mlir::edsc::StmtBlockLikeExpr::StmtBlockLikeExpr(ExprKind kind, @@ -905,24 +999,56 @@ mlir::edsc::StmtBlockLikeExpr::StmtBlockLikeExpr(ExprKind kind, // Initialize with placement new. new (storage) detail::ExprStorage(kind, "", types, exprs, {}); } -ArrayRef mlir::edsc::StmtBlockLikeExpr::getExprs() const { - return static_cast(storage)->operands; -} -SmallVector, 4> -mlir::edsc::StmtBlockLikeExpr::getExprGroups() const { - SmallVector, 4> groups; - ArrayRef exprs = getExprs(); - int start = 0; - for (int i = 0, e = exprs.size(); i < e; ++i) { - if (!exprs[i]) { - groups.push_back(exprs.slice(start, i - start)); - start = i + 1; - } + +static ArrayRef getOneArgumentGroupStartingFrom(int start, + ExprStorage *storage) { + for (int i = start, e = storage->operands.size(); i < e; ++i) { + if (!storage->operands[i]) + return storage->operands.slice(start, i - start); + } + return storage->operands.drop_front(start); +} + +static SmallVector, 4> +getAllArgumentGroupsStartingFrom(int start, ExprStorage *storage) { + SmallVector, 4> groups; + while (start < storage->operands.size()) { + auto group = getOneArgumentGroupStartingFrom(start, storage); + start += group.size() + 1; + groups.push_back(group); } - groups.push_back(exprs.slice(start, exprs.size() - start)); return groups; } +ArrayRef mlir::edsc::Expr::getProperArguments() const { + return getOneArgumentGroupStartingFrom(0, storage); +} + +SmallVector, 4> mlir::edsc::Expr::getSuccessorArguments() const { + // Skip the first group containing proper arguments. + // Note that +1 to size is necessary to step over the nullptrs in the list. + int start = getOneArgumentGroupStartingFrom(0, storage).size() + 1; + return getAllArgumentGroupsStartingFrom(start, storage); +} + +ArrayRef mlir::edsc::Expr::getSuccessorArguments(int index) const { + assert(index >= 0 && "argument group index is out of bounds"); + assert(!storage->operands.empty() && "argument list is empty"); + + // Skip over the first index + 1 groups (also includes proper arguments). + int start = 0; + for (int i = 0, e = index + 1; i < e; ++i) { + assert(start < storage->operands.size() && + "argument group index is out of bounds"); + start += getOneArgumentGroupStartingFrom(start, storage).size() + 1; + } + return getOneArgumentGroupStartingFrom(start, storage); +} + +SmallVector, 4> mlir::edsc::Expr::getAllArgumentGroups() const { + return getAllArgumentGroupsStartingFrom(0, storage); +} + mlir::edsc::Stmt::Stmt(const Bindable &lhs, const Expr &rhs, llvm::ArrayRef enclosedStmts) { storage = Expr::globalAllocator()->Allocate(); @@ -1012,15 +1138,29 @@ llvm::raw_ostream &mlir::edsc::operator<<(llvm::raw_ostream &os, } mlir::edsc::StmtBlock::StmtBlock(llvm::ArrayRef stmts) - : StmtBlock({}, {}, stmts) {} + : StmtBlock({}, stmts) {} mlir::edsc::StmtBlock::StmtBlock(llvm::ArrayRef args, - llvm::ArrayRef argTypes, llvm::ArrayRef stmts) { + // Extract block argument types from bindable types. + // Bindables must have a single type. + llvm::SmallVector argTypes; + argTypes.reserve(args.size()); + for (Bindable arg : args) { + auto argResults = arg.getResultTypes(); + assert(argResults.size() == 1 && + "only single-result expressions are supported"); + argTypes.push_back(argResults.front()); + } storage = Expr::globalAllocator()->Allocate(); new (storage) detail::StmtBlockStorage(args, argTypes, stmts); } +mlir::edsc::StmtBlock &mlir::edsc::StmtBlock::operator=(ArrayRef stmts) { + storage->replaceStmts(stmts); + return *this; +} + ArrayRef mlir::edsc::StmtBlock::getArguments() const { return storage->arguments; } @@ -1033,17 +1173,20 @@ ArrayRef mlir::edsc::StmtBlock::getBody() const { return storage->statements; } +uint64_t mlir::edsc::StmtBlock::getId() const { return storage->id; } + void mlir::edsc::StmtBlock::print(llvm::raw_ostream &os, Twine indent) const { - os << indent << "^bb"; + os << indent << "^bb" << storage->id; if (!getArgumentTypes().empty()) os << '('; interleaveComma(getArguments(), os); if (!getArgumentTypes().empty()) os << ')'; os << ":\n"; - - for (auto stmt : getBody()) + for (auto stmt : getBody()) { stmt.print(os, indent + " "); + os << '\n'; + } } std::string mlir::edsc::StmtBlock::str() const { diff --git a/mlir/test/EDSC/for-loops.mlir b/mlir/test/EDSC/for-loops.mlir index 7032b160316e..ea28ae1e4242 100644 --- a/mlir/test/EDSC/for-loops.mlir +++ b/mlir/test/EDSC/for-loops.mlir @@ -10,16 +10,19 @@ // CHECK-DAG: #[[id2dmap:.*]] = (d0, d1) -> (d0, d1) // This function will be detected by the test pass that will insert -// EDSC-constructed blocks with arguments. +// EDSC-constructed blocks with arguments forming an infinite loop. // CHECK-LABEL: @blocks func @blocks() { return -//CHECK: ^bb1(%0: i32, %1: i32): // no predecessors -//CHECK-NEXT: %2 = addi %0, %1 : i32 -//CHECK-NEXT: return -//CHECK: ^bb2(%3: i32, %4: i32): // no predecessors -//CHECK-NEXT: %5 = subi %3, %4 : i32 -//CHECK-NEXT: return +//CHECK: %c42_i32 = constant 42 : i32 +//CHECK-NEXT: %c1234_i32 = constant 1234 : i32 +//CHECK-NEXT: br ^bb1(%c42_i32, %c1234_i32 : i32, i32) +//CHECK-NEXT: ^bb1(%0: i32, %1: i32): // 2 preds: ^bb0, ^bb2 +//CHECK-NEXT: br ^bb2(%0, %1 : i32, i32) +//CHECK-NEXT: ^bb2(%2: i32, %3: i32): // pred: ^bb1 +//CHECK-NEXT: %4 = addi %2, %3 : i32 +//CHECK-NEXT: br ^bb1(%2, %4 : i32, i32) +//CHECK-NEXT: } } // This function will be detected by the test pass that will insert an