[mlir][Linalg] Fix Linalg EDSC builders

Summary:
This diff fixes the fact that the method `mlir::edsc::makeGenericLinalgOp`
incorrectly adds 2 blocks to Linalg ops.

Tests are updated accordingly.

Reviewers: ftynse, hanchung, herhut, pifon2a, asaadaldien

Reviewed By: asaadaldien

Subscribers: merge_guards_bot, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D72780
This commit is contained in:
Nicolas Vasilache 2020-01-16 09:30:17 -05:00
parent 4f5c65a5c8
commit 2b81d3c6c6
4 changed files with 56 additions and 8 deletions

View File

@ -251,6 +251,16 @@ public:
/// not yet bound to mlir::Value. /// not yet bound to mlir::Value.
BlockBuilder(BlockHandle *bh, ArrayRef<ValueHandle *> args); BlockBuilder(BlockHandle *bh, ArrayRef<ValueHandle *> args);
/// Constructs a new mlir::Block with argument types derived from `args` and
/// appends it as the last block in the region.
/// Captures the new block in `bh` and its arguments into `args`.
/// Enters the new mlir::Block* and sets the insertion point to its end.
///
/// Prerequisites:
/// The ValueHandle `args` are typed delayed ValueHandles; i.e. they are
/// not yet bound to mlir::Value.
BlockBuilder(BlockHandle *bh, Region &region, ArrayRef<ValueHandle *> args);
/// The only purpose of this operator is to serve as a sequence point so that /// The only purpose of this operator is to serve as a sequence point so that
/// the evaluation of `fun` (which build IR snippets in a scoped fashion) is /// the evaluation of `fun` (which build IR snippets in a scoped fashion) is
/// scoped within a BlockBuilder. /// scoped within a BlockBuilder.
@ -450,6 +460,9 @@ public:
/// Delegates block creation to MLIR and wrap the resulting mlir::Block. /// Delegates block creation to MLIR and wrap the resulting mlir::Block.
static BlockHandle create(ArrayRef<Type> argTypes); static BlockHandle create(ArrayRef<Type> argTypes);
/// Delegates block creation to MLIR and wrap the resulting mlir::Block.
static BlockHandle createInRegion(Region &region, ArrayRef<Type> argTypes);
operator bool() { return block != nullptr; } operator bool() { return block != nullptr; }
operator mlir::Block *() { return block; } operator mlir::Block *() { return block; }
mlir::Block *getBlock() { return block; } mlir::Block *getBlock() { return block; }

View File

@ -184,14 +184,16 @@ Operation *mlir::edsc::makeGenericLinalgOp(
? getElementTypeOrSelf(it.value()) ? getElementTypeOrSelf(it.value())
: it.value().getType()); : it.value().getType());
assert(op->getRegions().front().empty()); assert(op->getNumRegions() == 1);
op->getRegions().front().push_front(new Block); assert(op->getRegion(0).empty());
OpBuilder bb(op->getRegions().front()); OpBuilder opBuilder(op);
ScopedContext scope(bb, op->getLoc()); ScopedContext scope(opBuilder, op->getLoc());
BlockHandle b; BlockHandle b;
auto handles = makeValueHandles(blockTypes); auto handles = makeValueHandles(blockTypes);
BlockBuilder(&b, makeHandlePointers(MutableArrayRef<ValueHandle>(handles)))( BlockBuilder(&b, op->getRegion(0),
makeHandlePointers(MutableArrayRef<ValueHandle>(handles)))(
[&] { regionBuilder(b.getBlock()->getArguments()); }); [&] { regionBuilder(b.getBlock()->getArguments()); });
assert(op->getRegion(0).getBlocks().size() == 1);
return op; return op;
} }

View File

@ -133,6 +133,22 @@ BlockHandle mlir::edsc::BlockHandle::create(ArrayRef<Type> argTypes) {
return res; return res;
} }
BlockHandle mlir::edsc::BlockHandle::createInRegion(Region &region,
ArrayRef<Type> argTypes) {
auto &currentB = ScopedContext::getBuilder();
BlockHandle res;
region.push_back(new Block);
res.block = &region.back();
// createBlock sets the insertion point inside the block.
// We do not want this behavior when using declarative builders with nesting.
OpBuilder::InsertionGuard g(currentB);
currentB.setInsertionPoint(res.block, res.block->begin());
for (auto t : argTypes) {
res.block->addArgument(t);
}
return res;
}
static Optional<ValueHandle> emitStaticFor(ArrayRef<ValueHandle> lbs, static Optional<ValueHandle> emitStaticFor(ArrayRef<ValueHandle> lbs,
ArrayRef<ValueHandle> ubs, ArrayRef<ValueHandle> ubs,
int64_t step) { int64_t step) {
@ -285,6 +301,23 @@ mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle *bh,
enter(bh->getBlock()); enter(bh->getBlock());
} }
mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle *bh, Region &region,
ArrayRef<ValueHandle *> args) {
assert(!*bh && "BlockHandle already captures a block, use "
"the explicit BockBuilder(bh, Append())({}) syntax instead.");
SmallVector<Type, 8> types;
for (auto *a : args) {
assert(!a->hasValue() &&
"Expected delayed ValueHandle that has not yet captured.");
types.push_back(a->getType());
}
*bh = BlockHandle::createInRegion(region, types);
for (auto it : llvm::zip(args, bh->getBlock()->getArguments())) {
*(std::get<0>(it)) = ValueHandle(std::get<1>(it));
}
enter(bh->getBlock());
}
/// Only serves as an ordering point between entering nested block and creating /// Only serves as an ordering point between entering nested block and creating
/// stmts. /// stmts.
void mlir::edsc::BlockBuilder::operator()(function_ref<void(void)> fun) { void mlir::edsc::BlockBuilder::operator()(function_ref<void(void)> fun) {

View File

@ -876,7 +876,7 @@ TEST_FUNC(linalg_pointwise_test) {
// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64, // CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], // CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]} // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
/// CHECK: ^bb1(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32): /// CHECK: ^bb0(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32):
// CHECK: %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32 // CHECK: %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32
// CHECK: %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32 // CHECK: %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32
// CHECK: linalg.yield %[[a4]] : f32 // CHECK: linalg.yield %[[a4]] : f32
@ -906,7 +906,7 @@ TEST_FUNC(linalg_matmul_test) {
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d1)>, // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d1)>,
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d1)>], // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d1)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]}
/// CHECK: ^bb1(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32): /// CHECK: ^bb0(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32):
// CHECK: %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32 // CHECK: %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32
// CHECK: %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32 // CHECK: %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32
// CHECK: linalg.yield %[[a4]] : f32 // CHECK: linalg.yield %[[a4]] : f32
@ -937,7 +937,7 @@ TEST_FUNC(linalg_conv_nhwc) {
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d2, d1)>, // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d2, d1)>,
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d3, d4, d1 + d2 * 7)>], // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d3, d4, d1 + d2 * 7)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
// CHECK: ^bb1(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32): // CHECK: ^bb0(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32):
// CHECK: %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32 // CHECK: %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32
// CHECK: %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32 // CHECK: %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32
// CHECK: linalg.yield %[[a4]] : f32 // CHECK: linalg.yield %[[a4]] : f32