diff --git a/mlir/include/mlir/EDSC/Intrinsics.h b/mlir/include/mlir/EDSC/Intrinsics.h index 00c36001b502..fcaf7a1c5010 100644 --- a/mlir/include/mlir/EDSC/Intrinsics.h +++ b/mlir/include/mlir/EDSC/Intrinsics.h @@ -43,6 +43,22 @@ namespace intrinsics { /// All Handles have already captured previously constructed IR objects. ValueHandle BR(BlockHandle bh, ArrayRef operands); +/// Creates a new mlir::Block* and branches to it from the current block. +/// Argument types are specified by `operands`. +/// Captures the new block in `bh` and the actual `operands` in `captures`. To +/// insert the new mlir::Block*, a local ScopedContext is constructed and +/// released to the current block. The branch instruction is then added to the +/// new block. +/// +/// Prerequisites: +/// `b` has not yet captured an mlir::Block*. +/// No `captures` have captured any mlir::Value*. +/// All `operands` have already captured an mlir::Value* +/// captures.size() == operands.size() +/// captures and operands are pairwise of the same type. +ValueHandle BR(BlockHandle *bh, ArrayRef captures, + ArrayRef operands); + /// Branches into the mlir::Block* captured by BlockHandle `trueBranch` with /// `trueOperands` if `cond` evaluates to `true` (resp. `falseBranch` and /// `falseOperand` if `cond` evaluates to `false`). @@ -53,6 +69,29 @@ ValueHandle COND_BR(ValueHandle cond, BlockHandle trueBranch, ArrayRef trueOperands, BlockHandle falseBranch, ArrayRef falseOperands); +/// Eagerly creates new mlir::Block* with argument types specified by +/// `trueOperands`/`falseOperands`. +/// Captures the new blocks in `trueBranch`/`falseBranch` and the arguments in +/// `trueCaptures/falseCaptures`. +/// To insert the new mlir::Block*, a local ScopedContext is constructed and +/// released. The branch instruction is then added in the original location and +/// targeting the eagerly constructed blocks. +/// +/// Prerequisites: +/// `trueBranch`/`falseBranch` has not yet captured an mlir::Block*. +/// No `trueCaptures`/`falseCaptures` have captured any mlir::Value*. +/// All `trueOperands`/`trueOperands` have already captured an mlir::Value* +/// `trueCaptures`.size() == `trueOperands`.size() +/// `falseCaptures`.size() == `falseOperands`.size() +/// `trueCaptures` and `trueOperands` are pairwise of the same type +/// `falseCaptures` and `falseOperands` are pairwise of the same type. +ValueHandle COND_BR(ValueHandle cond, BlockHandle *trueBranch, + ArrayRef trueCaptures, + ArrayRef trueOperands, + BlockHandle *falseBranch, + ArrayRef falseCaptures, + ArrayRef falseOperands); + //////////////////////////////////////////////////////////////////////////////// // TODO(ntv): Intrinsics below this line should be TableGen'd. //////////////////////////////////////////////////////////////////////////////// diff --git a/mlir/lib/EDSC/Intrinsics.cpp b/mlir/lib/EDSC/Intrinsics.cpp index 431609b144ba..ea3287980f11 100644 --- a/mlir/lib/EDSC/Intrinsics.cpp +++ b/mlir/lib/EDSC/Intrinsics.cpp @@ -31,6 +31,36 @@ ValueHandle mlir::edsc::intrinsics::BR(BlockHandle bh, SmallVector ops(operands.begin(), operands.end()); return ValueHandle::create(bh.getBlock(), ops); } +static void enforceEmptyCapturesMatchOperands(ArrayRef captures, + ArrayRef operands) { + assert(captures.size() == operands.size() && + "Expected same number of captures as operands"); + for (auto it : llvm::zip(captures, operands)) { + (void)it; + assert(!std::get<0>(it)->hasValue() && + "Unexpected already captured ValueHandle"); + assert(std::get<1>(it) && "Expected already captured ValueHandle"); + assert(std::get<0>(it)->getType() == std::get<1>(it).getType() && + "Expected the same type for capture and operand"); + } +} + +ValueHandle mlir::edsc::intrinsics::BR(BlockHandle *bh, + ArrayRef captures, + ArrayRef operands) { + assert(!*bh && "Unexpected already captured BlockHandle"); + enforceEmptyCapturesMatchOperands(captures, operands); + { // Clone the scope explicitly to avoid modifying the insertion point in the + // current scope which result in surprising usage. + auto *currentB = ScopedContext::getBuilder(); + FuncBuilder b(currentB->getInsertionBlock(), currentB->getInsertionPoint()); + Location loc = ScopedContext::getLocation(); + ScopedContext scope(b, loc); + BlockBuilder(bh, captures)({/* no body */}); + } // Release before adding the branch to the eagerly created block. + SmallVector ops(operands.begin(), operands.end()); + return ValueHandle::create(bh->getBlock(), ops); +} ValueHandle mlir::edsc::intrinsics::COND_BR(ValueHandle cond, BlockHandle trueBranch, @@ -43,9 +73,33 @@ mlir::edsc::intrinsics::COND_BR(ValueHandle cond, BlockHandle trueBranch, falseBranch.getBlock(), falseOps); } +ValueHandle mlir::edsc::intrinsics::COND_BR( + ValueHandle cond, BlockHandle *trueBranch, + ArrayRef trueCaptures, ArrayRef trueOperands, + BlockHandle *falseBranch, ArrayRef falseCaptures, + ArrayRef falseOperands) { + assert(!*trueBranch && "Unexpected already captured BlockHandle"); + assert(!*falseBranch && "Unexpected already captured BlockHandle"); + enforceEmptyCapturesMatchOperands(trueCaptures, trueOperands); + enforceEmptyCapturesMatchOperands(falseCaptures, falseOperands); + { // Clone the scope explicitly. + auto *currentB = ScopedContext::getBuilder(); + FuncBuilder b(currentB->getInsertionBlock(), currentB->getInsertionPoint()); + Location loc = ScopedContext::getLocation(); + ScopedContext scope(b, loc); + BlockBuilder(trueBranch, trueCaptures)({/* no body */}); + BlockBuilder(falseBranch, falseCaptures)({/* no body */}); + } // Release before adding the branch to the eagerly created block. + SmallVector trueOps(trueOperands.begin(), trueOperands.end()); + SmallVector falseOps(falseOperands.begin(), falseOperands.end()); + return ValueHandle::create( + cond, trueBranch->getBlock(), trueOps, falseBranch->getBlock(), falseOps); +} + //////////////////////////////////////////////////////////////////////////////// // TODO(ntv): Intrinsics below this line should be TableGen'd. //////////////////////////////////////////////////////////////////////////////// + ValueHandle mlir::edsc::intrinsics::RETURN(ArrayRef operands) { SmallVector ops(operands.begin(), operands.end()); return ValueHandle::create(ops); diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp index 30ad22522143..09cb26777750 100644 --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -197,6 +197,50 @@ TEST_FUNC(builder_blocks) { f->print(llvm::outs()); } +TEST_FUNC(builder_blocks_eager) { + using namespace edsc; + using namespace edsc::intrinsics; + using namespace edsc::op; + auto f = makeFunction("builder_blocks_eager"); + + ScopedContext scope(f.get()); + ValueHandle c1(ValueHandle::create(42, 32)), + c2(ValueHandle::create(1234, 32)); + ValueHandle arg1(c1.getType()), arg2(c1.getType()), arg3(c1.getType()), + arg4(c1.getType()), r(c1.getType()); + + // clang-format off + BlockHandle b1, b2; + { // Toplevel function scope. + BR(&b1, {&arg1, &arg2}, {c1, c2}); // eagerly builds a new block for b1 + // We cannot construct b2 eagerly with a `BR(&b2, ...)` call from within b1 + // because it would result in b2 being nested under b1 which is not what we + // want in this test. + BlockBuilder(&b2, {&arg3, &arg4})({ + // Instead, construct explicitly + BR(b1, {arg3, arg4}), + }); + /// And come back to append into b1 once b2 exists. + BlockBuilder(b1, Append())({ + r = arg1 + arg2, + BR(b2, {arg1, r}), + }); + } + + // CHECK-LABEL: @builder_blocks_eager + // CHECK: %c42_i32 = constant 42 : i32 + // CHECK-NEXT: %c1234_i32 = constant 1234 : i32 + // CHECK-NEXT: br ^bb1(%c42_i32, %c1234_i32 : i32, i32) + // CHECK-NEXT: ^bb1(%0: i32, %1: i32): // 2 preds: ^bb0, ^bb2 + // CHECK-NEXT: %2 = addi %0, %1 : i32 + // CHECK-NEXT: br ^bb2(%0, %2 : i32, i32) + // CHECK-NEXT: ^bb2(%3: i32, %4: i32): // pred: ^bb1 + // CHECK-NEXT: br ^bb1(%3, %4 : i32, i32) + // CHECK-NEXT: } + // clang-format on + f->print(llvm::outs()); +} + TEST_FUNC(builder_cond_branch) { using namespace edsc; using namespace edsc::intrinsics; @@ -211,7 +255,6 @@ TEST_FUNC(builder_cond_branch) { ValueHandle arg1(c32.getType()), arg2(c64.getType()), arg3(c32.getType()); BlockHandle b1, b2, functionBlock(&f->front()); - ; BlockBuilder(&b1, {&arg1})({ RETURN({}), }); @@ -237,6 +280,43 @@ TEST_FUNC(builder_cond_branch) { f->print(llvm::outs()); } +TEST_FUNC(builder_cond_branch_eager) { + using namespace edsc; + using namespace edsc::intrinsics; + using namespace edsc::op; + auto f = makeFunction("builder_cond_branch_eager", {}, + {IntegerType::get(1, &globalContext())}); + + ScopedContext scope(f.get()); + ValueHandle funcArg(f->getArgument(0)); + ValueHandle c32(ValueHandle::create(32, 32)), + c64(ValueHandle::create(64, 64)), + c42(ValueHandle::create(42, 32)); + ValueHandle arg1(c32.getType()), arg2(c64.getType()), arg3(c32.getType()); + + // clang-format off + BlockHandle b1, b2; + COND_BR(funcArg, &b1, {&arg1}, {c32}, &b2, {&arg2, &arg3}, {c64, c42}); + BlockBuilder(b1, Append())({ + RETURN({}), + }); + BlockBuilder(b2, Append())({ + RETURN({}), + }); + + // CHECK-LABEL: @builder_cond_branch_eager + // CHECK: %c32_i32 = constant 32 : i32 + // CHECK-NEXT: %c64_i64 = constant 64 : i64 + // CHECK-NEXT: %c42_i32 = constant 42 : i32 + // CHECK-NEXT: cond_br %arg0, ^bb1(%c32_i32 : i32), ^bb2(%c64_i64, %c42_i32 : i64, i32) + // CHECK-NEXT: ^bb1(%0: i32): // pred: ^bb0 + // CHECK-NEXT: return + // CHECK-NEXT: ^bb2(%1: i64, %2: i32): // pred: ^bb0 + // CHECK-NEXT: return + // clang-format on + f->print(llvm::outs()); +} + int main() { RUN_TESTS(); return 0;