diff --git a/mlir/bindings/python/pybind.cpp b/mlir/bindings/python/pybind.cpp index 25b9f56fb743..d3ffebbb50f8 100644 --- a/mlir/bindings/python/pybind.cpp +++ b/mlir/bindings/python/pybind.cpp @@ -391,6 +391,15 @@ PYBIND11_MODULE(pybind, m) { SmallVector owning; return PythonStmt(::For(iv, lb, ub, step, makeCStmts(owning, stmts))); }); + m.def("MaxMinFor", [](PythonExpr iv, const py::list &lbs, const py::list &ubs, + PythonExpr step, const py::list &stmts) { + SmallVector owningLBs; + SmallVector owningUBs; + SmallVector owningStmts; + return PythonStmt(::MaxMinFor(iv, makeCExprs(owningLBs, lbs), + makeCExprs(owningUBs, ubs), step, + makeCStmts(owningStmts, stmts))); + }); m.def("Select", [](PythonExpr cond, PythonExpr e1, PythonExpr e2) { return PythonExpr(::Select(cond, e1, e2)); }); diff --git a/mlir/bindings/python/test/test_py2and3.py b/mlir/bindings/python/test/test_py2and3.py index ad3868ad9090..564fb716c591 100644 --- a/mlir/bindings/python/test/test_py2and3.py +++ b/mlir/bindings/python/test/test_py2and3.py @@ -69,6 +69,17 @@ class EdscTest(unittest.TestCase): self.assertIn("for($1 = $2 to $3 step 42) {", str) self.assertIn("= (($3 * 42) + $2 * -1);", str) + def testMaxMinLoop(self): + with E.ContextManager(): + i = E.Expr(E.Bindable(self.indexType)) + step = E.Expr(E.Bindable(self.indexType)) + lbs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)])) + ubs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)])) + loop = E.MaxMinFor(i, lbs, ubs, step, []) + s = str(loop) + self.assertIn("for($1 = max($3, $4, $5, $6) to min($7, $8, $9) step $2)", + s) + def testIndexed(self): with E.ContextManager(): i, j, k = list( diff --git a/mlir/include/mlir-c/Core.h b/mlir/include/mlir-c/Core.h index c8967860f47c..8505cdc0ae6b 100644 --- a/mlir/include/mlir-c/Core.h +++ b/mlir/include/mlir-c/Core.h @@ -227,17 +227,24 @@ edsc_stmt_t Return(edsc_expr_list_t values); /// given list of statements. Block arguments are not currently supported. edsc_block_t Block(edsc_stmt_list_t enclosedStmts); -/// Returns an opaque statement for an mlir::ForInst with `enclosedStmts` nested -/// below it. +/// Returns an opaque statement for an mlir::AffineForOp with `enclosedStmts` +/// nested below it. edsc_stmt_t For(edsc_expr_t iv, edsc_expr_t lb, edsc_expr_t ub, edsc_expr_t step, edsc_stmt_list_t enclosedStmts); -/// Returns an opaque statement for a perfectly nested set of mlir::ForInst with -/// `enclosedStmts` nested below it. +/// Returns an opaque statement for a perfectly nested set of mlir::AffineForOp +/// with `enclosedStmts` nested below it. edsc_stmt_t ForNest(edsc_expr_list_t iv, edsc_expr_list_t lb, edsc_expr_list_t ub, edsc_expr_list_t step, edsc_stmt_list_t enclosedStmts); +/// Returns an opaque statement for an mlir::AffineForOp with the lower bound +/// `max(lbs)` and the upper bound `min(ubs)`, and with `enclosedStmts` nested +/// below it. +edsc_stmt_t MaxMinFor(edsc_expr_t iv, edsc_expr_list_t lbs, + edsc_expr_list_t ubs, edsc_expr_t step, + edsc_stmt_list_t enclosedStmts); + /// Returns an opaque expression for the corresponding Binary operation. edsc_expr_t Add(edsc_expr_t e1, edsc_expr_t e2); edsc_expr_t Sub(edsc_expr_t e1, edsc_expr_t e2); diff --git a/mlir/include/mlir/EDSC/Types.h b/mlir/include/mlir/EDSC/Types.h index bd57aa415f3e..599d6e3b1bdb 100644 --- a/mlir/include/mlir/EDSC/Types.h +++ b/mlir/include/mlir/EDSC/Types.h @@ -183,6 +183,9 @@ public: /// For debugging purposes. const void *getStoragePtr() const { return storage; } + /// Explicit conversion to bool. Useful in conjunction with dyn_cast. + explicit operator bool() const { return storage != nullptr; } + friend ::llvm::hash_code hash_value(Expr arg); protected: @@ -285,8 +288,19 @@ struct StmtBlockLikeExpr : public Expr { friend class Expr; StmtBlockLikeExpr(ExprKind kind, llvm::ArrayRef exprs, llvm::ArrayRef types = {}); + + /// Get the list of subexpressions. + /// StmtBlockLikeExprs can contain multiple groups of subexpressions separated + /// by null expressions and the result of this call will include them. llvm::ArrayRef getExprs() const; + /// Get the list of subexpression groups. + /// StmtBlockLikeExprs can contain multiple groups of subexpressions separated + /// by null expressions. This will identify those groups and return a list + /// of lists of subexpressions split around null expressions. Two null + /// expressions in a row identify an empty group. + SmallVector, 4> getExprGroups() const; + protected: StmtBlockLikeExpr(Expr::ImplType *ptr) : Expr(ptr) { assert(!ptr || isa() && "expected StmtBlockLikeExpr"); @@ -605,6 +619,13 @@ Stmt For(llvm::ArrayRef indices, llvm::ArrayRef lbs, llvm::ArrayRef ubs, llvm::ArrayRef steps, llvm::ArrayRef enclosedStmts); +/// Define a 'for' loop from with multi-valued bounds. +/// +/// for max(lbs...) to min(ubs...) {} +/// +Stmt MaxMinFor(const Bindable &idx, ArrayRef lbs, ArrayRef ubs, + Expr step, ArrayRef enclosedStmts); + StmtBlock block(llvm::ArrayRef args, llvm::ArrayRef argTypes, llvm::ArrayRef stmts); inline StmtBlock block(llvm::ArrayRef stmts) { diff --git a/mlir/lib/EDSC/LowerEDSCTestPass.cpp b/mlir/lib/EDSC/LowerEDSCTestPass.cpp index 987ad5a39ad2..41cb1734c92e 100644 --- a/mlir/lib/EDSC/LowerEDSCTestPass.cpp +++ b/mlir/lib/EDSC/LowerEDSCTestPass.cpp @@ -118,6 +118,33 @@ PassResult LowerEDSCTestPass::runOnFunction(Function *f) { return success(); } + if (f->getName().strref() == "max_min_for") { + assert(!f->getBlocks().empty() && "max_min_for should not be empty"); + FuncBuilder builder(&f->getBlocks().front(), + f->getBlocks().front().begin()); + assert(f->getNumArguments() == 4 && "max_min_for expected 4 arguments"); + for (const auto *arg : f->getArguments()) + assert(arg->getType().isIndex() && + "max_min_for expected index arguments"); + + edsc::ScopedEDSCContext context; + edsc::Expr lb1(f->getArgument(0)->getType()); + edsc::Expr lb2(f->getArgument(1)->getType()); + edsc::Expr ub1(f->getArgument(2)->getType()); + edsc::Expr ub2(f->getArgument(3)->getType()); + edsc::Expr iv(builder.getIndexType()); + edsc::Expr step = edsc::constantInteger(builder.getIndexType(), 1); + auto loop = + edsc::MaxMinFor(edsc::Bindable(iv), {lb1, lb2}, {ub1, ub2}, step, {}); + edsc::MLIREmitter(&builder, f->getLoc()) + .bind(edsc::Bindable(lb1), f->getArgument(0)) + .bind(edsc::Bindable(lb2), f->getArgument(1)) + .bind(edsc::Bindable(ub1), f->getArgument(2)) + .bind(edsc::Bindable(ub2), f->getArgument(3)) + .emitStmt(loop); + + return success(); + } // Inject an EDSC-constructed computation that assigns Stmt and uses the LHS. if (f->getName().strref().contains("assignments")) { diff --git a/mlir/lib/EDSC/MLIREmitter.cpp b/mlir/lib/EDSC/MLIREmitter.cpp index cbaca8201640..dfd7ce59f46c 100644 --- a/mlir/lib/EDSC/MLIREmitter.cpp +++ b/mlir/lib/EDSC/MLIREmitter.cpp @@ -82,6 +82,38 @@ MLIREmitter &mlir::edsc::MLIREmitter::bind(Bindable e, Value *v) { return *this; } +static void checkAffineProvenance(ArrayRef values) { + for (Value *v : values) { + auto *def = v->getDefiningInst(); + // There may be no defining instruction if the value is a function + // argument. We accept such values. + assert((!def || def->isa() || def->isa() || + def->isa() || def->isa()) && + "loop bound expression must have affine provenance"); + } +} + +static OpPointer emitStaticFor(FuncBuilder &builder, Location loc, + ArrayRef lbs, + ArrayRef ubs, + uint64_t step) { + if (lbs.size() != 1 || ubs.size() != 1) + return OpPointer(); + + auto *lbDef = lbs.front()->getDefiningInst(); + auto *ubDef = ubs.front()->getDefiningInst(); + if (!lbDef || !ubDef) + return OpPointer(); + + auto lbConst = lbDef->dyn_cast(); + auto ubConst = ubDef->dyn_cast(); + if (!lbConst || !ubConst) + return OpPointer(); + + return builder.create(loc, lbConst->getValue(), + ubConst->getValue(), step); +} + Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) { // It is still necessary in case we try to emit a bindable directly // FIXME: make sure isa works and use it below to delegate emission @@ -104,48 +136,37 @@ Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) { if (auto expr = e.dyn_cast()) { if (expr.getKind() == ExprKind::For) { - auto exprs = emitExprs(expr.getExprs()); - if (llvm::any_of(exprs, [](Value *v) { return !v; })) { - return nullptr; - } - assert(exprs.size() == 3 && "Expected 3 exprs"); - auto *lb = exprs[0]; - auto *ub = exprs[1]; + auto exprGroups = expr.getExprGroups(); + assert(exprGroups.size() == 3 && "expected 3 expr groups in `for`"); + assert(!exprGroups[0].empty() && "expected at least one lower bound"); + assert(!exprGroups[1].empty() && "expected at least one upper bound"); + assert(exprGroups[2].size() == 1 && + "the third group (step) must have one element"); - // There may be no defining instruction if the value is a function - // argument. We accept such values. - auto *lbDef = lb->getDefiningInst(); - (void)lbDef; - assert((!lbDef || lbDef->isa() || - lbDef->isa() || lbDef->isa() || - lbDef->isa()) && - "lower bound expression does not have affine provenance"); - auto *ubDef = ub->getDefiningInst(); - (void)ubDef; - assert((!ubDef || ubDef->isa() || - ubDef->isa() || ubDef->isa() || - ubDef->isa()) && - "upper bound expression does not have affine provenance"); + auto lbs = emitExprs(exprGroups[0]); + auto ubs = emitExprs(exprGroups[1]); + auto stepExpr = emitExpr(exprGroups[2][0]); + + if (llvm::any_of(lbs, [](Value *v) { return !v; }) || + llvm::any_of(ubs, [](Value *v) { return !v; }) || !stepExpr) + return nullptr; + + checkAffineProvenance(lbs); + checkAffineProvenance(ubs); // Step must be a static constant. auto step = - exprs[2]->getDefiningInst()->cast()->getValue(); + stepExpr->getDefiningInst()->cast()->getValue(); // Special case with more concise emitted code for static bounds. - OpPointer forOp; - if (lbDef && ubDef) - if (auto lbConst = lbDef->dyn_cast()) - if (auto ubConst = ubDef->dyn_cast()) - forOp = builder->create(location, lbConst->getValue(), - ubConst->getValue(), step); + OpPointer forOp = + emitStaticFor(*builder, location, lbs, ubs, step); // General case. - if (!forOp) { - auto map = builder->getDimIdentityMap(); - forOp = - builder->create(location, llvm::makeArrayRef(lb), map, - llvm::makeArrayRef(ub), map, step); - } + if (!forOp) + forOp = builder->create( + location, lbs, builder->getMultiDimIdentityMap(lbs.size()), ubs, + builder->getMultiDimIdentityMap(ubs.size()), step); forOp->createBody(); res = forOp->getInductionVar(); } diff --git a/mlir/lib/EDSC/Types.cpp b/mlir/lib/EDSC/Types.cpp index 5a580d556169..451b0917fb22 100644 --- a/mlir/lib/EDSC/Types.cpp +++ b/mlir/lib/EDSC/Types.cpp @@ -359,7 +359,14 @@ Stmt mlir::edsc::For(Expr lb, Expr ub, Expr step, ArrayRef stmts) { Stmt mlir::edsc::For(const Bindable &idx, Expr lb, Expr ub, Expr step, ArrayRef stmts) { - return Stmt(idx, StmtBlockLikeExpr(ExprKind::For, {lb, ub, step}), stmts); + assert(lb); + assert(ub); + assert(step); + // Use a null expression as a sentinel between lower and upper bound + // expressions in the list of children. + return Stmt( + idx, StmtBlockLikeExpr(ExprKind::For, {lb, nullptr, ub, nullptr, step}), + stmts); } Stmt mlir::edsc::For(ArrayRef indices, ArrayRef lbs, @@ -380,6 +387,24 @@ Stmt mlir::edsc::For(ArrayRef indices, ArrayRef lbs, return curStmt; } +Stmt mlir::edsc::MaxMinFor(const Bindable &idx, ArrayRef lbs, + ArrayRef ubs, Expr step, + ArrayRef enclosedStmts) { + assert(!lbs.empty() && "'for' loop must have lower bounds"); + assert(!ubs.empty() && "'for' loop must have upper bounds"); + + // Use a null expression as a sentinel between lower and upper bound + // expressions in the list of children. + SmallVector exprs; + exprs.insert(exprs.end(), lbs.begin(), lbs.end()); + exprs.push_back(nullptr); + exprs.insert(exprs.end(), ubs.begin(), ubs.end()); + exprs.push_back(nullptr); + exprs.push_back(step); + + return Stmt(idx, StmtBlockLikeExpr(ExprKind::For, exprs), enclosedStmts); +} + edsc_stmt_t For(edsc_expr_t iv, edsc_expr_t lb, edsc_expr_t ub, edsc_expr_t step, edsc_stmt_list_t enclosedStmts) { llvm::SmallVector stmts; @@ -397,6 +422,15 @@ edsc_stmt_t ForNest(edsc_expr_list_t ivs, edsc_expr_list_t lbs, makeExprs(steps), stmts)); } +edsc_stmt_t MaxMinFor(edsc_expr_t iv, edsc_expr_list_t lbs, + edsc_expr_list_t ubs, edsc_expr_t step, + edsc_stmt_list_t enclosedStmts) { + llvm::SmallVector stmts; + fillStmts(enclosedStmts, &stmts); + return Stmt(MaxMinFor(Expr(iv).cast(), makeExprs(lbs), + makeExprs(ubs), Expr(step), stmts)); +} + StmtBlock mlir::edsc::block(ArrayRef args, ArrayRef argTypes, ArrayRef stmts) { assert(args.size() == argTypes.size() && @@ -669,14 +703,26 @@ void mlir::edsc::Expr::print(raw_ostream &os) const { os << ')'; return; } else if (auto stmtLikeExpr = this->dyn_cast()) { - auto exprs = stmtLikeExpr.getExprs(); switch (stmtLikeExpr.getKind()) { // We only print the lb, ub and step here, which are the StmtBlockLike // part of the `for` StmtBlockLikeExpr. - case ExprKind::For: - assert(exprs.size() == 3 && "For StmtBlockLikeExpr expected 3 exprs"); - os << exprs[0] << " to " << exprs[1] << " step " << exprs[2]; + case ExprKind::For: { + auto exprGroups = stmtLikeExpr.getExprGroups(); + assert(exprGroups.size() == 3 && + "For StmtBlockLikeExpr expected 3 groups"); + assert(exprGroups[2].size() == 1 && "expected 1 expr for loop step"); + if (exprGroups[0].size() == 1 && exprGroups[1].size() == 1) { + os << exprGroups[0][0] << " to " << exprGroups[1][0] << " step " + << exprGroups[2][0]; + } else { + os << "max("; + interleaveComma(exprGroups[0], os); + os << ") to min("; + interleaveComma(exprGroups[1], os); + os << ") step " << exprGroups[2][0]; + } return; + } default: { os << "unknown_stmt"; } @@ -772,6 +818,20 @@ mlir::edsc::StmtBlockLikeExpr::StmtBlockLikeExpr(ExprKind kind, ArrayRef mlir::edsc::StmtBlockLikeExpr::getExprs() const { return static_cast(storage)->operands; } +SmallVector, 4> +mlir::edsc::StmtBlockLikeExpr::getExprGroups() const { + SmallVector, 4> groups; + ArrayRef exprs = getExprs(); + int start = 0; + for (int i = 0, e = exprs.size(); i < e; ++i) { + if (!exprs[i]) { + groups.push_back(exprs.slice(start, i - start)); + start = i + 1; + } + } + groups.push_back(exprs.slice(start, exprs.size() - start)); + return groups; +} mlir::edsc::Stmt::Stmt(const Bindable &lhs, const Expr &rhs, llvm::ArrayRef enclosedStmts) { diff --git a/mlir/test/EDSC/for-loops.mlir b/mlir/test/EDSC/for-loops.mlir index 12b31b2f52f3..9444a56cdd77 100644 --- a/mlir/test/EDSC/for-loops.mlir +++ b/mlir/test/EDSC/for-loops.mlir @@ -6,6 +6,7 @@ // CHECK-DAG: #[[addmap:.*]] = (d0, d1) -> (d0 + d1) // CHECK-DAG: #[[prodconstmap:.*]] = (d0, d1) -> (d0 * 3) // CHECK-DAG: #[[addconstmap:.*]] = (d0, d1) -> (d1 + 3) +// CHECK-DAG: #[[id2dmap:.*]] = (d0, d1) -> (d0, d1) // This function will be detected by the test pass that will insert // EDSC-constructed blocks with arguments. @@ -70,3 +71,14 @@ func @assignments_1(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg2: memref<4x func @assignments_2(%arg0: memref, %arg1: memref, %arg2: memref) { return } + +// This function will be detected by the test pass that will insert an +// EDSC-constructed empty `for` loop with max/min bounds that correspond to +// for max(%arg0, %arg1) to (%arg2, %arg3) step 1 +// before the `return` instruction. +// +// CHECK-LABEL: func @max_min_for(%arg0: index, %arg1: index, %arg2: index, %arg3: index) { +// CHECK: for %i0 = max #[[id2dmap]](%arg0, %arg1) to min #[[id2dmap]](%arg2, %arg3) { +func @max_min_for(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index) { + return +}