EDSC: support multi-expression loop bounds

MLIR supports 'for' loops with lower(upper) bound defined by taking a
maximum(minimum) of a list of expressions, but does not have first-class affine
constructs for the maximum(minimum).  All these expressions must have affine
provenance, similarly to a single-expression bound.  Add support for
constructing such loops using EDSC.  The expression factory function is called
`edsc::MaxMinFor` to (1) highlight that the maximum(minimum) operation is
applied to the lower(upper) bound expressions and (2) differentiate it from a
`edsc::For` that creates multiple perfectly nested loops (and should arguably
be called `edsc::ForNest`).

PiperOrigin-RevId: 234785996
This commit is contained in:
Alex Zinenko 2019-02-20 06:54:36 -08:00 committed by jpienaar
parent a2a433652d
commit d055a4e100
8 changed files with 211 additions and 43 deletions

View File

@ -391,6 +391,15 @@ PYBIND11_MODULE(pybind, m) {
SmallVector<edsc_stmt_t, 8> owning;
return PythonStmt(::For(iv, lb, ub, step, makeCStmts(owning, stmts)));
});
m.def("MaxMinFor", [](PythonExpr iv, const py::list &lbs, const py::list &ubs,
PythonExpr step, const py::list &stmts) {
SmallVector<edsc_expr_t, 8> owningLBs;
SmallVector<edsc_expr_t, 8> owningUBs;
SmallVector<edsc_stmt_t, 8> owningStmts;
return PythonStmt(::MaxMinFor(iv, makeCExprs(owningLBs, lbs),
makeCExprs(owningUBs, ubs), step,
makeCStmts(owningStmts, stmts)));
});
m.def("Select", [](PythonExpr cond, PythonExpr e1, PythonExpr e2) {
return PythonExpr(::Select(cond, e1, e2));
});

View File

@ -69,6 +69,17 @@ class EdscTest(unittest.TestCase):
self.assertIn("for($1 = $2 to $3 step 42) {", str)
self.assertIn("= (($3 * 42) + $2 * -1);", str)
def testMaxMinLoop(self):
with E.ContextManager():
i = E.Expr(E.Bindable(self.indexType))
step = E.Expr(E.Bindable(self.indexType))
lbs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
ubs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)]))
loop = E.MaxMinFor(i, lbs, ubs, step, [])
s = str(loop)
self.assertIn("for($1 = max($3, $4, $5, $6) to min($7, $8, $9) step $2)",
s)
def testIndexed(self):
with E.ContextManager():
i, j, k = list(

View File

@ -227,17 +227,24 @@ edsc_stmt_t Return(edsc_expr_list_t values);
/// given list of statements. Block arguments are not currently supported.
edsc_block_t Block(edsc_stmt_list_t enclosedStmts);
/// Returns an opaque statement for an mlir::ForInst with `enclosedStmts` nested
/// below it.
/// Returns an opaque statement for an mlir::AffineForOp with `enclosedStmts`
/// nested below it.
edsc_stmt_t For(edsc_expr_t iv, edsc_expr_t lb, edsc_expr_t ub,
edsc_expr_t step, edsc_stmt_list_t enclosedStmts);
/// Returns an opaque statement for a perfectly nested set of mlir::ForInst with
/// `enclosedStmts` nested below it.
/// Returns an opaque statement for a perfectly nested set of mlir::AffineForOp
/// with `enclosedStmts` nested below it.
edsc_stmt_t ForNest(edsc_expr_list_t iv, edsc_expr_list_t lb,
edsc_expr_list_t ub, edsc_expr_list_t step,
edsc_stmt_list_t enclosedStmts);
/// Returns an opaque statement for an mlir::AffineForOp with the lower bound
/// `max(lbs)` and the upper bound `min(ubs)`, and with `enclosedStmts` nested
/// below it.
edsc_stmt_t MaxMinFor(edsc_expr_t iv, edsc_expr_list_t lbs,
edsc_expr_list_t ubs, edsc_expr_t step,
edsc_stmt_list_t enclosedStmts);
/// Returns an opaque expression for the corresponding Binary operation.
edsc_expr_t Add(edsc_expr_t e1, edsc_expr_t e2);
edsc_expr_t Sub(edsc_expr_t e1, edsc_expr_t e2);

View File

@ -183,6 +183,9 @@ public:
/// For debugging purposes.
const void *getStoragePtr() const { return storage; }
/// Explicit conversion to bool. Useful in conjunction with dyn_cast.
explicit operator bool() const { return storage != nullptr; }
friend ::llvm::hash_code hash_value(Expr arg);
protected:
@ -285,8 +288,19 @@ struct StmtBlockLikeExpr : public Expr {
friend class Expr;
StmtBlockLikeExpr(ExprKind kind, llvm::ArrayRef<Expr> exprs,
llvm::ArrayRef<Type> types = {});
/// Get the list of subexpressions.
/// StmtBlockLikeExprs can contain multiple groups of subexpressions separated
/// by null expressions and the result of this call will include them.
llvm::ArrayRef<Expr> getExprs() const;
/// Get the list of subexpression groups.
/// StmtBlockLikeExprs can contain multiple groups of subexpressions separated
/// by null expressions. This will identify those groups and return a list
/// of lists of subexpressions split around null expressions. Two null
/// expressions in a row identify an empty group.
SmallVector<llvm::ArrayRef<Expr>, 4> getExprGroups() const;
protected:
StmtBlockLikeExpr(Expr::ImplType *ptr) : Expr(ptr) {
assert(!ptr || isa<StmtBlockLikeExpr>() && "expected StmtBlockLikeExpr");
@ -605,6 +619,13 @@ Stmt For(llvm::ArrayRef<Expr> indices, llvm::ArrayRef<Expr> lbs,
llvm::ArrayRef<Expr> ubs, llvm::ArrayRef<Expr> steps,
llvm::ArrayRef<Stmt> enclosedStmts);
/// Define a 'for' loop from with multi-valued bounds.
///
/// for max(lbs...) to min(ubs...) {}
///
Stmt MaxMinFor(const Bindable &idx, ArrayRef<Expr> lbs, ArrayRef<Expr> ubs,
Expr step, ArrayRef<Stmt> enclosedStmts);
StmtBlock block(llvm::ArrayRef<Bindable> args, llvm::ArrayRef<Type> argTypes,
llvm::ArrayRef<Stmt> stmts);
inline StmtBlock block(llvm::ArrayRef<Stmt> stmts) {

View File

@ -118,6 +118,33 @@ PassResult LowerEDSCTestPass::runOnFunction(Function *f) {
return success();
}
if (f->getName().strref() == "max_min_for") {
assert(!f->getBlocks().empty() && "max_min_for should not be empty");
FuncBuilder builder(&f->getBlocks().front(),
f->getBlocks().front().begin());
assert(f->getNumArguments() == 4 && "max_min_for expected 4 arguments");
for (const auto *arg : f->getArguments())
assert(arg->getType().isIndex() &&
"max_min_for expected index arguments");
edsc::ScopedEDSCContext context;
edsc::Expr lb1(f->getArgument(0)->getType());
edsc::Expr lb2(f->getArgument(1)->getType());
edsc::Expr ub1(f->getArgument(2)->getType());
edsc::Expr ub2(f->getArgument(3)->getType());
edsc::Expr iv(builder.getIndexType());
edsc::Expr step = edsc::constantInteger(builder.getIndexType(), 1);
auto loop =
edsc::MaxMinFor(edsc::Bindable(iv), {lb1, lb2}, {ub1, ub2}, step, {});
edsc::MLIREmitter(&builder, f->getLoc())
.bind(edsc::Bindable(lb1), f->getArgument(0))
.bind(edsc::Bindable(lb2), f->getArgument(1))
.bind(edsc::Bindable(ub1), f->getArgument(2))
.bind(edsc::Bindable(ub2), f->getArgument(3))
.emitStmt(loop);
return success();
}
// Inject an EDSC-constructed computation that assigns Stmt and uses the LHS.
if (f->getName().strref().contains("assignments")) {

View File

@ -82,6 +82,38 @@ MLIREmitter &mlir::edsc::MLIREmitter::bind(Bindable e, Value *v) {
return *this;
}
static void checkAffineProvenance(ArrayRef<Value *> values) {
for (Value *v : values) {
auto *def = v->getDefiningInst();
// There may be no defining instruction if the value is a function
// argument. We accept such values.
assert((!def || def->isa<ConstantIndexOp>() || def->isa<AffineApplyOp>() ||
def->isa<AffineForOp>() || def->isa<DimOp>()) &&
"loop bound expression must have affine provenance");
}
}
static OpPointer<AffineForOp> emitStaticFor(FuncBuilder &builder, Location loc,
ArrayRef<Value *> lbs,
ArrayRef<Value *> ubs,
uint64_t step) {
if (lbs.size() != 1 || ubs.size() != 1)
return OpPointer<AffineForOp>();
auto *lbDef = lbs.front()->getDefiningInst();
auto *ubDef = ubs.front()->getDefiningInst();
if (!lbDef || !ubDef)
return OpPointer<AffineForOp>();
auto lbConst = lbDef->dyn_cast<ConstantIndexOp>();
auto ubConst = ubDef->dyn_cast<ConstantIndexOp>();
if (!lbConst || !ubConst)
return OpPointer<AffineForOp>();
return builder.create<AffineForOp>(loc, lbConst->getValue(),
ubConst->getValue(), step);
}
Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) {
// It is still necessary in case we try to emit a bindable directly
// FIXME: make sure isa<Bindable> works and use it below to delegate emission
@ -104,48 +136,37 @@ Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) {
if (auto expr = e.dyn_cast<StmtBlockLikeExpr>()) {
if (expr.getKind() == ExprKind::For) {
auto exprs = emitExprs(expr.getExprs());
if (llvm::any_of(exprs, [](Value *v) { return !v; })) {
return nullptr;
}
assert(exprs.size() == 3 && "Expected 3 exprs");
auto *lb = exprs[0];
auto *ub = exprs[1];
auto exprGroups = expr.getExprGroups();
assert(exprGroups.size() == 3 && "expected 3 expr groups in `for`");
assert(!exprGroups[0].empty() && "expected at least one lower bound");
assert(!exprGroups[1].empty() && "expected at least one upper bound");
assert(exprGroups[2].size() == 1 &&
"the third group (step) must have one element");
// There may be no defining instruction if the value is a function
// argument. We accept such values.
auto *lbDef = lb->getDefiningInst();
(void)lbDef;
assert((!lbDef || lbDef->isa<ConstantIndexOp>() ||
lbDef->isa<AffineApplyOp>() || lbDef->isa<AffineForOp>() ||
lbDef->isa<DimOp>()) &&
"lower bound expression does not have affine provenance");
auto *ubDef = ub->getDefiningInst();
(void)ubDef;
assert((!ubDef || ubDef->isa<ConstantIndexOp>() ||
ubDef->isa<AffineApplyOp>() || ubDef->isa<AffineForOp>() ||
ubDef->isa<DimOp>()) &&
"upper bound expression does not have affine provenance");
auto lbs = emitExprs(exprGroups[0]);
auto ubs = emitExprs(exprGroups[1]);
auto stepExpr = emitExpr(exprGroups[2][0]);
if (llvm::any_of(lbs, [](Value *v) { return !v; }) ||
llvm::any_of(ubs, [](Value *v) { return !v; }) || !stepExpr)
return nullptr;
checkAffineProvenance(lbs);
checkAffineProvenance(ubs);
// Step must be a static constant.
auto step =
exprs[2]->getDefiningInst()->cast<ConstantIndexOp>()->getValue();
stepExpr->getDefiningInst()->cast<ConstantIndexOp>()->getValue();
// Special case with more concise emitted code for static bounds.
OpPointer<AffineForOp> forOp;
if (lbDef && ubDef)
if (auto lbConst = lbDef->dyn_cast<ConstantIndexOp>())
if (auto ubConst = ubDef->dyn_cast<ConstantIndexOp>())
forOp = builder->create<AffineForOp>(location, lbConst->getValue(),
ubConst->getValue(), step);
OpPointer<AffineForOp> forOp =
emitStaticFor(*builder, location, lbs, ubs, step);
// General case.
if (!forOp) {
auto map = builder->getDimIdentityMap();
forOp =
builder->create<AffineForOp>(location, llvm::makeArrayRef(lb), map,
llvm::makeArrayRef(ub), map, step);
}
if (!forOp)
forOp = builder->create<AffineForOp>(
location, lbs, builder->getMultiDimIdentityMap(lbs.size()), ubs,
builder->getMultiDimIdentityMap(ubs.size()), step);
forOp->createBody();
res = forOp->getInductionVar();
}

View File

@ -359,7 +359,14 @@ Stmt mlir::edsc::For(Expr lb, Expr ub, Expr step, ArrayRef<Stmt> stmts) {
Stmt mlir::edsc::For(const Bindable &idx, Expr lb, Expr ub, Expr step,
ArrayRef<Stmt> stmts) {
return Stmt(idx, StmtBlockLikeExpr(ExprKind::For, {lb, ub, step}), stmts);
assert(lb);
assert(ub);
assert(step);
// Use a null expression as a sentinel between lower and upper bound
// expressions in the list of children.
return Stmt(
idx, StmtBlockLikeExpr(ExprKind::For, {lb, nullptr, ub, nullptr, step}),
stmts);
}
Stmt mlir::edsc::For(ArrayRef<Expr> indices, ArrayRef<Expr> lbs,
@ -380,6 +387,24 @@ Stmt mlir::edsc::For(ArrayRef<Expr> indices, ArrayRef<Expr> lbs,
return curStmt;
}
Stmt mlir::edsc::MaxMinFor(const Bindable &idx, ArrayRef<Expr> lbs,
ArrayRef<Expr> ubs, Expr step,
ArrayRef<Stmt> enclosedStmts) {
assert(!lbs.empty() && "'for' loop must have lower bounds");
assert(!ubs.empty() && "'for' loop must have upper bounds");
// Use a null expression as a sentinel between lower and upper bound
// expressions in the list of children.
SmallVector<Expr, 8> exprs;
exprs.insert(exprs.end(), lbs.begin(), lbs.end());
exprs.push_back(nullptr);
exprs.insert(exprs.end(), ubs.begin(), ubs.end());
exprs.push_back(nullptr);
exprs.push_back(step);
return Stmt(idx, StmtBlockLikeExpr(ExprKind::For, exprs), enclosedStmts);
}
edsc_stmt_t For(edsc_expr_t iv, edsc_expr_t lb, edsc_expr_t ub,
edsc_expr_t step, edsc_stmt_list_t enclosedStmts) {
llvm::SmallVector<Stmt, 8> stmts;
@ -397,6 +422,15 @@ edsc_stmt_t ForNest(edsc_expr_list_t ivs, edsc_expr_list_t lbs,
makeExprs(steps), stmts));
}
edsc_stmt_t MaxMinFor(edsc_expr_t iv, edsc_expr_list_t lbs,
edsc_expr_list_t ubs, edsc_expr_t step,
edsc_stmt_list_t enclosedStmts) {
llvm::SmallVector<Stmt, 8> stmts;
fillStmts(enclosedStmts, &stmts);
return Stmt(MaxMinFor(Expr(iv).cast<Bindable>(), makeExprs(lbs),
makeExprs(ubs), Expr(step), stmts));
}
StmtBlock mlir::edsc::block(ArrayRef<Bindable> args, ArrayRef<Type> argTypes,
ArrayRef<Stmt> stmts) {
assert(args.size() == argTypes.size() &&
@ -669,14 +703,26 @@ void mlir::edsc::Expr::print(raw_ostream &os) const {
os << ')';
return;
} else if (auto stmtLikeExpr = this->dyn_cast<StmtBlockLikeExpr>()) {
auto exprs = stmtLikeExpr.getExprs();
switch (stmtLikeExpr.getKind()) {
// We only print the lb, ub and step here, which are the StmtBlockLike
// part of the `for` StmtBlockLikeExpr.
case ExprKind::For:
assert(exprs.size() == 3 && "For StmtBlockLikeExpr expected 3 exprs");
os << exprs[0] << " to " << exprs[1] << " step " << exprs[2];
case ExprKind::For: {
auto exprGroups = stmtLikeExpr.getExprGroups();
assert(exprGroups.size() == 3 &&
"For StmtBlockLikeExpr expected 3 groups");
assert(exprGroups[2].size() == 1 && "expected 1 expr for loop step");
if (exprGroups[0].size() == 1 && exprGroups[1].size() == 1) {
os << exprGroups[0][0] << " to " << exprGroups[1][0] << " step "
<< exprGroups[2][0];
} else {
os << "max(";
interleaveComma(exprGroups[0], os);
os << ") to min(";
interleaveComma(exprGroups[1], os);
os << ") step " << exprGroups[2][0];
}
return;
}
default: {
os << "unknown_stmt";
}
@ -772,6 +818,20 @@ mlir::edsc::StmtBlockLikeExpr::StmtBlockLikeExpr(ExprKind kind,
ArrayRef<Expr> mlir::edsc::StmtBlockLikeExpr::getExprs() const {
return static_cast<ImplType *>(storage)->operands;
}
SmallVector<ArrayRef<Expr>, 4>
mlir::edsc::StmtBlockLikeExpr::getExprGroups() const {
SmallVector<ArrayRef<Expr>, 4> groups;
ArrayRef<Expr> exprs = getExprs();
int start = 0;
for (int i = 0, e = exprs.size(); i < e; ++i) {
if (!exprs[i]) {
groups.push_back(exprs.slice(start, i - start));
start = i + 1;
}
}
groups.push_back(exprs.slice(start, exprs.size() - start));
return groups;
}
mlir::edsc::Stmt::Stmt(const Bindable &lhs, const Expr &rhs,
llvm::ArrayRef<Stmt> enclosedStmts) {

View File

@ -6,6 +6,7 @@
// CHECK-DAG: #[[addmap:.*]] = (d0, d1) -> (d0 + d1)
// CHECK-DAG: #[[prodconstmap:.*]] = (d0, d1) -> (d0 * 3)
// CHECK-DAG: #[[addconstmap:.*]] = (d0, d1) -> (d1 + 3)
// CHECK-DAG: #[[id2dmap:.*]] = (d0, d1) -> (d0, d1)
// This function will be detected by the test pass that will insert
// EDSC-constructed blocks with arguments.
@ -70,3 +71,14 @@ func @assignments_1(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg2: memref<4x
func @assignments_2(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
return
}
// This function will be detected by the test pass that will insert an
// EDSC-constructed empty `for` loop with max/min bounds that correspond to
// for max(%arg0, %arg1) to (%arg2, %arg3) step 1
// before the `return` instruction.
//
// CHECK-LABEL: func @max_min_for(%arg0: index, %arg1: index, %arg2: index, %arg3: index) {
// CHECK: for %i0 = max #[[id2dmap]](%arg0, %arg1) to min #[[id2dmap]](%arg2, %arg3) {
func @max_min_for(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index) {
return
}