forked from OSchip/llvm-project
EDSC: support multi-expression loop bounds
MLIR supports 'for' loops with lower(upper) bound defined by taking a maximum(minimum) of a list of expressions, but does not have first-class affine constructs for the maximum(minimum). All these expressions must have affine provenance, similarly to a single-expression bound. Add support for constructing such loops using EDSC. The expression factory function is called `edsc::MaxMinFor` to (1) highlight that the maximum(minimum) operation is applied to the lower(upper) bound expressions and (2) differentiate it from a `edsc::For` that creates multiple perfectly nested loops (and should arguably be called `edsc::ForNest`). PiperOrigin-RevId: 234785996
This commit is contained in:
parent
a2a433652d
commit
d055a4e100
|
@ -391,6 +391,15 @@ PYBIND11_MODULE(pybind, m) {
|
|||
SmallVector<edsc_stmt_t, 8> owning;
|
||||
return PythonStmt(::For(iv, lb, ub, step, makeCStmts(owning, stmts)));
|
||||
});
|
||||
m.def("MaxMinFor", [](PythonExpr iv, const py::list &lbs, const py::list &ubs,
|
||||
PythonExpr step, const py::list &stmts) {
|
||||
SmallVector<edsc_expr_t, 8> owningLBs;
|
||||
SmallVector<edsc_expr_t, 8> owningUBs;
|
||||
SmallVector<edsc_stmt_t, 8> owningStmts;
|
||||
return PythonStmt(::MaxMinFor(iv, makeCExprs(owningLBs, lbs),
|
||||
makeCExprs(owningUBs, ubs), step,
|
||||
makeCStmts(owningStmts, stmts)));
|
||||
});
|
||||
m.def("Select", [](PythonExpr cond, PythonExpr e1, PythonExpr e2) {
|
||||
return PythonExpr(::Select(cond, e1, e2));
|
||||
});
|
||||
|
|
|
@ -69,6 +69,17 @@ class EdscTest(unittest.TestCase):
|
|||
self.assertIn("for($1 = $2 to $3 step 42) {", str)
|
||||
self.assertIn("= (($3 * 42) + $2 * -1);", str)
|
||||
|
||||
def testMaxMinLoop(self):
|
||||
with E.ContextManager():
|
||||
i = E.Expr(E.Bindable(self.indexType))
|
||||
step = E.Expr(E.Bindable(self.indexType))
|
||||
lbs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
|
||||
ubs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)]))
|
||||
loop = E.MaxMinFor(i, lbs, ubs, step, [])
|
||||
s = str(loop)
|
||||
self.assertIn("for($1 = max($3, $4, $5, $6) to min($7, $8, $9) step $2)",
|
||||
s)
|
||||
|
||||
def testIndexed(self):
|
||||
with E.ContextManager():
|
||||
i, j, k = list(
|
||||
|
|
|
@ -227,17 +227,24 @@ edsc_stmt_t Return(edsc_expr_list_t values);
|
|||
/// given list of statements. Block arguments are not currently supported.
|
||||
edsc_block_t Block(edsc_stmt_list_t enclosedStmts);
|
||||
|
||||
/// Returns an opaque statement for an mlir::ForInst with `enclosedStmts` nested
|
||||
/// below it.
|
||||
/// Returns an opaque statement for an mlir::AffineForOp with `enclosedStmts`
|
||||
/// nested below it.
|
||||
edsc_stmt_t For(edsc_expr_t iv, edsc_expr_t lb, edsc_expr_t ub,
|
||||
edsc_expr_t step, edsc_stmt_list_t enclosedStmts);
|
||||
|
||||
/// Returns an opaque statement for a perfectly nested set of mlir::ForInst with
|
||||
/// `enclosedStmts` nested below it.
|
||||
/// Returns an opaque statement for a perfectly nested set of mlir::AffineForOp
|
||||
/// with `enclosedStmts` nested below it.
|
||||
edsc_stmt_t ForNest(edsc_expr_list_t iv, edsc_expr_list_t lb,
|
||||
edsc_expr_list_t ub, edsc_expr_list_t step,
|
||||
edsc_stmt_list_t enclosedStmts);
|
||||
|
||||
/// Returns an opaque statement for an mlir::AffineForOp with the lower bound
|
||||
/// `max(lbs)` and the upper bound `min(ubs)`, and with `enclosedStmts` nested
|
||||
/// below it.
|
||||
edsc_stmt_t MaxMinFor(edsc_expr_t iv, edsc_expr_list_t lbs,
|
||||
edsc_expr_list_t ubs, edsc_expr_t step,
|
||||
edsc_stmt_list_t enclosedStmts);
|
||||
|
||||
/// Returns an opaque expression for the corresponding Binary operation.
|
||||
edsc_expr_t Add(edsc_expr_t e1, edsc_expr_t e2);
|
||||
edsc_expr_t Sub(edsc_expr_t e1, edsc_expr_t e2);
|
||||
|
|
|
@ -183,6 +183,9 @@ public:
|
|||
/// For debugging purposes.
|
||||
const void *getStoragePtr() const { return storage; }
|
||||
|
||||
/// Explicit conversion to bool. Useful in conjunction with dyn_cast.
|
||||
explicit operator bool() const { return storage != nullptr; }
|
||||
|
||||
friend ::llvm::hash_code hash_value(Expr arg);
|
||||
|
||||
protected:
|
||||
|
@ -285,8 +288,19 @@ struct StmtBlockLikeExpr : public Expr {
|
|||
friend class Expr;
|
||||
StmtBlockLikeExpr(ExprKind kind, llvm::ArrayRef<Expr> exprs,
|
||||
llvm::ArrayRef<Type> types = {});
|
||||
|
||||
/// Get the list of subexpressions.
|
||||
/// StmtBlockLikeExprs can contain multiple groups of subexpressions separated
|
||||
/// by null expressions and the result of this call will include them.
|
||||
llvm::ArrayRef<Expr> getExprs() const;
|
||||
|
||||
/// Get the list of subexpression groups.
|
||||
/// StmtBlockLikeExprs can contain multiple groups of subexpressions separated
|
||||
/// by null expressions. This will identify those groups and return a list
|
||||
/// of lists of subexpressions split around null expressions. Two null
|
||||
/// expressions in a row identify an empty group.
|
||||
SmallVector<llvm::ArrayRef<Expr>, 4> getExprGroups() const;
|
||||
|
||||
protected:
|
||||
StmtBlockLikeExpr(Expr::ImplType *ptr) : Expr(ptr) {
|
||||
assert(!ptr || isa<StmtBlockLikeExpr>() && "expected StmtBlockLikeExpr");
|
||||
|
@ -605,6 +619,13 @@ Stmt For(llvm::ArrayRef<Expr> indices, llvm::ArrayRef<Expr> lbs,
|
|||
llvm::ArrayRef<Expr> ubs, llvm::ArrayRef<Expr> steps,
|
||||
llvm::ArrayRef<Stmt> enclosedStmts);
|
||||
|
||||
/// Define a 'for' loop from with multi-valued bounds.
|
||||
///
|
||||
/// for max(lbs...) to min(ubs...) {}
|
||||
///
|
||||
Stmt MaxMinFor(const Bindable &idx, ArrayRef<Expr> lbs, ArrayRef<Expr> ubs,
|
||||
Expr step, ArrayRef<Stmt> enclosedStmts);
|
||||
|
||||
StmtBlock block(llvm::ArrayRef<Bindable> args, llvm::ArrayRef<Type> argTypes,
|
||||
llvm::ArrayRef<Stmt> stmts);
|
||||
inline StmtBlock block(llvm::ArrayRef<Stmt> stmts) {
|
||||
|
|
|
@ -118,6 +118,33 @@ PassResult LowerEDSCTestPass::runOnFunction(Function *f) {
|
|||
|
||||
return success();
|
||||
}
|
||||
if (f->getName().strref() == "max_min_for") {
|
||||
assert(!f->getBlocks().empty() && "max_min_for should not be empty");
|
||||
FuncBuilder builder(&f->getBlocks().front(),
|
||||
f->getBlocks().front().begin());
|
||||
assert(f->getNumArguments() == 4 && "max_min_for expected 4 arguments");
|
||||
for (const auto *arg : f->getArguments())
|
||||
assert(arg->getType().isIndex() &&
|
||||
"max_min_for expected index arguments");
|
||||
|
||||
edsc::ScopedEDSCContext context;
|
||||
edsc::Expr lb1(f->getArgument(0)->getType());
|
||||
edsc::Expr lb2(f->getArgument(1)->getType());
|
||||
edsc::Expr ub1(f->getArgument(2)->getType());
|
||||
edsc::Expr ub2(f->getArgument(3)->getType());
|
||||
edsc::Expr iv(builder.getIndexType());
|
||||
edsc::Expr step = edsc::constantInteger(builder.getIndexType(), 1);
|
||||
auto loop =
|
||||
edsc::MaxMinFor(edsc::Bindable(iv), {lb1, lb2}, {ub1, ub2}, step, {});
|
||||
edsc::MLIREmitter(&builder, f->getLoc())
|
||||
.bind(edsc::Bindable(lb1), f->getArgument(0))
|
||||
.bind(edsc::Bindable(lb2), f->getArgument(1))
|
||||
.bind(edsc::Bindable(ub1), f->getArgument(2))
|
||||
.bind(edsc::Bindable(ub2), f->getArgument(3))
|
||||
.emitStmt(loop);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
// Inject an EDSC-constructed computation that assigns Stmt and uses the LHS.
|
||||
if (f->getName().strref().contains("assignments")) {
|
||||
|
|
|
@ -82,6 +82,38 @@ MLIREmitter &mlir::edsc::MLIREmitter::bind(Bindable e, Value *v) {
|
|||
return *this;
|
||||
}
|
||||
|
||||
static void checkAffineProvenance(ArrayRef<Value *> values) {
|
||||
for (Value *v : values) {
|
||||
auto *def = v->getDefiningInst();
|
||||
// There may be no defining instruction if the value is a function
|
||||
// argument. We accept such values.
|
||||
assert((!def || def->isa<ConstantIndexOp>() || def->isa<AffineApplyOp>() ||
|
||||
def->isa<AffineForOp>() || def->isa<DimOp>()) &&
|
||||
"loop bound expression must have affine provenance");
|
||||
}
|
||||
}
|
||||
|
||||
static OpPointer<AffineForOp> emitStaticFor(FuncBuilder &builder, Location loc,
|
||||
ArrayRef<Value *> lbs,
|
||||
ArrayRef<Value *> ubs,
|
||||
uint64_t step) {
|
||||
if (lbs.size() != 1 || ubs.size() != 1)
|
||||
return OpPointer<AffineForOp>();
|
||||
|
||||
auto *lbDef = lbs.front()->getDefiningInst();
|
||||
auto *ubDef = ubs.front()->getDefiningInst();
|
||||
if (!lbDef || !ubDef)
|
||||
return OpPointer<AffineForOp>();
|
||||
|
||||
auto lbConst = lbDef->dyn_cast<ConstantIndexOp>();
|
||||
auto ubConst = ubDef->dyn_cast<ConstantIndexOp>();
|
||||
if (!lbConst || !ubConst)
|
||||
return OpPointer<AffineForOp>();
|
||||
|
||||
return builder.create<AffineForOp>(loc, lbConst->getValue(),
|
||||
ubConst->getValue(), step);
|
||||
}
|
||||
|
||||
Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) {
|
||||
// It is still necessary in case we try to emit a bindable directly
|
||||
// FIXME: make sure isa<Bindable> works and use it below to delegate emission
|
||||
|
@ -104,48 +136,37 @@ Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) {
|
|||
|
||||
if (auto expr = e.dyn_cast<StmtBlockLikeExpr>()) {
|
||||
if (expr.getKind() == ExprKind::For) {
|
||||
auto exprs = emitExprs(expr.getExprs());
|
||||
if (llvm::any_of(exprs, [](Value *v) { return !v; })) {
|
||||
return nullptr;
|
||||
}
|
||||
assert(exprs.size() == 3 && "Expected 3 exprs");
|
||||
auto *lb = exprs[0];
|
||||
auto *ub = exprs[1];
|
||||
auto exprGroups = expr.getExprGroups();
|
||||
assert(exprGroups.size() == 3 && "expected 3 expr groups in `for`");
|
||||
assert(!exprGroups[0].empty() && "expected at least one lower bound");
|
||||
assert(!exprGroups[1].empty() && "expected at least one upper bound");
|
||||
assert(exprGroups[2].size() == 1 &&
|
||||
"the third group (step) must have one element");
|
||||
|
||||
// There may be no defining instruction if the value is a function
|
||||
// argument. We accept such values.
|
||||
auto *lbDef = lb->getDefiningInst();
|
||||
(void)lbDef;
|
||||
assert((!lbDef || lbDef->isa<ConstantIndexOp>() ||
|
||||
lbDef->isa<AffineApplyOp>() || lbDef->isa<AffineForOp>() ||
|
||||
lbDef->isa<DimOp>()) &&
|
||||
"lower bound expression does not have affine provenance");
|
||||
auto *ubDef = ub->getDefiningInst();
|
||||
(void)ubDef;
|
||||
assert((!ubDef || ubDef->isa<ConstantIndexOp>() ||
|
||||
ubDef->isa<AffineApplyOp>() || ubDef->isa<AffineForOp>() ||
|
||||
ubDef->isa<DimOp>()) &&
|
||||
"upper bound expression does not have affine provenance");
|
||||
auto lbs = emitExprs(exprGroups[0]);
|
||||
auto ubs = emitExprs(exprGroups[1]);
|
||||
auto stepExpr = emitExpr(exprGroups[2][0]);
|
||||
|
||||
if (llvm::any_of(lbs, [](Value *v) { return !v; }) ||
|
||||
llvm::any_of(ubs, [](Value *v) { return !v; }) || !stepExpr)
|
||||
return nullptr;
|
||||
|
||||
checkAffineProvenance(lbs);
|
||||
checkAffineProvenance(ubs);
|
||||
|
||||
// Step must be a static constant.
|
||||
auto step =
|
||||
exprs[2]->getDefiningInst()->cast<ConstantIndexOp>()->getValue();
|
||||
stepExpr->getDefiningInst()->cast<ConstantIndexOp>()->getValue();
|
||||
|
||||
// Special case with more concise emitted code for static bounds.
|
||||
OpPointer<AffineForOp> forOp;
|
||||
if (lbDef && ubDef)
|
||||
if (auto lbConst = lbDef->dyn_cast<ConstantIndexOp>())
|
||||
if (auto ubConst = ubDef->dyn_cast<ConstantIndexOp>())
|
||||
forOp = builder->create<AffineForOp>(location, lbConst->getValue(),
|
||||
ubConst->getValue(), step);
|
||||
OpPointer<AffineForOp> forOp =
|
||||
emitStaticFor(*builder, location, lbs, ubs, step);
|
||||
|
||||
// General case.
|
||||
if (!forOp) {
|
||||
auto map = builder->getDimIdentityMap();
|
||||
forOp =
|
||||
builder->create<AffineForOp>(location, llvm::makeArrayRef(lb), map,
|
||||
llvm::makeArrayRef(ub), map, step);
|
||||
}
|
||||
if (!forOp)
|
||||
forOp = builder->create<AffineForOp>(
|
||||
location, lbs, builder->getMultiDimIdentityMap(lbs.size()), ubs,
|
||||
builder->getMultiDimIdentityMap(ubs.size()), step);
|
||||
forOp->createBody();
|
||||
res = forOp->getInductionVar();
|
||||
}
|
||||
|
|
|
@ -359,7 +359,14 @@ Stmt mlir::edsc::For(Expr lb, Expr ub, Expr step, ArrayRef<Stmt> stmts) {
|
|||
|
||||
Stmt mlir::edsc::For(const Bindable &idx, Expr lb, Expr ub, Expr step,
|
||||
ArrayRef<Stmt> stmts) {
|
||||
return Stmt(idx, StmtBlockLikeExpr(ExprKind::For, {lb, ub, step}), stmts);
|
||||
assert(lb);
|
||||
assert(ub);
|
||||
assert(step);
|
||||
// Use a null expression as a sentinel between lower and upper bound
|
||||
// expressions in the list of children.
|
||||
return Stmt(
|
||||
idx, StmtBlockLikeExpr(ExprKind::For, {lb, nullptr, ub, nullptr, step}),
|
||||
stmts);
|
||||
}
|
||||
|
||||
Stmt mlir::edsc::For(ArrayRef<Expr> indices, ArrayRef<Expr> lbs,
|
||||
|
@ -380,6 +387,24 @@ Stmt mlir::edsc::For(ArrayRef<Expr> indices, ArrayRef<Expr> lbs,
|
|||
return curStmt;
|
||||
}
|
||||
|
||||
Stmt mlir::edsc::MaxMinFor(const Bindable &idx, ArrayRef<Expr> lbs,
|
||||
ArrayRef<Expr> ubs, Expr step,
|
||||
ArrayRef<Stmt> enclosedStmts) {
|
||||
assert(!lbs.empty() && "'for' loop must have lower bounds");
|
||||
assert(!ubs.empty() && "'for' loop must have upper bounds");
|
||||
|
||||
// Use a null expression as a sentinel between lower and upper bound
|
||||
// expressions in the list of children.
|
||||
SmallVector<Expr, 8> exprs;
|
||||
exprs.insert(exprs.end(), lbs.begin(), lbs.end());
|
||||
exprs.push_back(nullptr);
|
||||
exprs.insert(exprs.end(), ubs.begin(), ubs.end());
|
||||
exprs.push_back(nullptr);
|
||||
exprs.push_back(step);
|
||||
|
||||
return Stmt(idx, StmtBlockLikeExpr(ExprKind::For, exprs), enclosedStmts);
|
||||
}
|
||||
|
||||
edsc_stmt_t For(edsc_expr_t iv, edsc_expr_t lb, edsc_expr_t ub,
|
||||
edsc_expr_t step, edsc_stmt_list_t enclosedStmts) {
|
||||
llvm::SmallVector<Stmt, 8> stmts;
|
||||
|
@ -397,6 +422,15 @@ edsc_stmt_t ForNest(edsc_expr_list_t ivs, edsc_expr_list_t lbs,
|
|||
makeExprs(steps), stmts));
|
||||
}
|
||||
|
||||
edsc_stmt_t MaxMinFor(edsc_expr_t iv, edsc_expr_list_t lbs,
|
||||
edsc_expr_list_t ubs, edsc_expr_t step,
|
||||
edsc_stmt_list_t enclosedStmts) {
|
||||
llvm::SmallVector<Stmt, 8> stmts;
|
||||
fillStmts(enclosedStmts, &stmts);
|
||||
return Stmt(MaxMinFor(Expr(iv).cast<Bindable>(), makeExprs(lbs),
|
||||
makeExprs(ubs), Expr(step), stmts));
|
||||
}
|
||||
|
||||
StmtBlock mlir::edsc::block(ArrayRef<Bindable> args, ArrayRef<Type> argTypes,
|
||||
ArrayRef<Stmt> stmts) {
|
||||
assert(args.size() == argTypes.size() &&
|
||||
|
@ -669,14 +703,26 @@ void mlir::edsc::Expr::print(raw_ostream &os) const {
|
|||
os << ')';
|
||||
return;
|
||||
} else if (auto stmtLikeExpr = this->dyn_cast<StmtBlockLikeExpr>()) {
|
||||
auto exprs = stmtLikeExpr.getExprs();
|
||||
switch (stmtLikeExpr.getKind()) {
|
||||
// We only print the lb, ub and step here, which are the StmtBlockLike
|
||||
// part of the `for` StmtBlockLikeExpr.
|
||||
case ExprKind::For:
|
||||
assert(exprs.size() == 3 && "For StmtBlockLikeExpr expected 3 exprs");
|
||||
os << exprs[0] << " to " << exprs[1] << " step " << exprs[2];
|
||||
case ExprKind::For: {
|
||||
auto exprGroups = stmtLikeExpr.getExprGroups();
|
||||
assert(exprGroups.size() == 3 &&
|
||||
"For StmtBlockLikeExpr expected 3 groups");
|
||||
assert(exprGroups[2].size() == 1 && "expected 1 expr for loop step");
|
||||
if (exprGroups[0].size() == 1 && exprGroups[1].size() == 1) {
|
||||
os << exprGroups[0][0] << " to " << exprGroups[1][0] << " step "
|
||||
<< exprGroups[2][0];
|
||||
} else {
|
||||
os << "max(";
|
||||
interleaveComma(exprGroups[0], os);
|
||||
os << ") to min(";
|
||||
interleaveComma(exprGroups[1], os);
|
||||
os << ") step " << exprGroups[2][0];
|
||||
}
|
||||
return;
|
||||
}
|
||||
default: {
|
||||
os << "unknown_stmt";
|
||||
}
|
||||
|
@ -772,6 +818,20 @@ mlir::edsc::StmtBlockLikeExpr::StmtBlockLikeExpr(ExprKind kind,
|
|||
ArrayRef<Expr> mlir::edsc::StmtBlockLikeExpr::getExprs() const {
|
||||
return static_cast<ImplType *>(storage)->operands;
|
||||
}
|
||||
SmallVector<ArrayRef<Expr>, 4>
|
||||
mlir::edsc::StmtBlockLikeExpr::getExprGroups() const {
|
||||
SmallVector<ArrayRef<Expr>, 4> groups;
|
||||
ArrayRef<Expr> exprs = getExprs();
|
||||
int start = 0;
|
||||
for (int i = 0, e = exprs.size(); i < e; ++i) {
|
||||
if (!exprs[i]) {
|
||||
groups.push_back(exprs.slice(start, i - start));
|
||||
start = i + 1;
|
||||
}
|
||||
}
|
||||
groups.push_back(exprs.slice(start, exprs.size() - start));
|
||||
return groups;
|
||||
}
|
||||
|
||||
mlir::edsc::Stmt::Stmt(const Bindable &lhs, const Expr &rhs,
|
||||
llvm::ArrayRef<Stmt> enclosedStmts) {
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
// CHECK-DAG: #[[addmap:.*]] = (d0, d1) -> (d0 + d1)
|
||||
// CHECK-DAG: #[[prodconstmap:.*]] = (d0, d1) -> (d0 * 3)
|
||||
// CHECK-DAG: #[[addconstmap:.*]] = (d0, d1) -> (d1 + 3)
|
||||
// CHECK-DAG: #[[id2dmap:.*]] = (d0, d1) -> (d0, d1)
|
||||
|
||||
// This function will be detected by the test pass that will insert
|
||||
// EDSC-constructed blocks with arguments.
|
||||
|
@ -70,3 +71,14 @@ func @assignments_1(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg2: memref<4x
|
|||
func @assignments_2(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
|
||||
return
|
||||
}
|
||||
|
||||
// This function will be detected by the test pass that will insert an
|
||||
// EDSC-constructed empty `for` loop with max/min bounds that correspond to
|
||||
// for max(%arg0, %arg1) to (%arg2, %arg3) step 1
|
||||
// before the `return` instruction.
|
||||
//
|
||||
// CHECK-LABEL: func @max_min_for(%arg0: index, %arg1: index, %arg2: index, %arg3: index) {
|
||||
// CHECK: for %i0 = max #[[id2dmap]](%arg0, %arg1) to min #[[id2dmap]](%arg2, %arg3) {
|
||||
func @max_min_for(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index) {
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue