forked from OSchip/llvm-project
[mlir][Linalg] Fix Linalg EDSC builders
Summary: This diff fixes the fact that the method `mlir::edsc::makeGenericLinalgOp` incorrectly adds 2 blocks to Linalg ops. Tests are updated accordingly. Reviewers: ftynse, hanchung, herhut, pifon2a, asaadaldien Reviewed By: asaadaldien Subscribers: merge_guards_bot, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D72780
This commit is contained in:
parent
4f5c65a5c8
commit
2b81d3c6c6
|
@ -251,6 +251,16 @@ public:
|
||||||
/// not yet bound to mlir::Value.
|
/// not yet bound to mlir::Value.
|
||||||
BlockBuilder(BlockHandle *bh, ArrayRef<ValueHandle *> args);
|
BlockBuilder(BlockHandle *bh, ArrayRef<ValueHandle *> args);
|
||||||
|
|
||||||
|
/// Constructs a new mlir::Block with argument types derived from `args` and
|
||||||
|
/// appends it as the last block in the region.
|
||||||
|
/// Captures the new block in `bh` and its arguments into `args`.
|
||||||
|
/// Enters the new mlir::Block* and sets the insertion point to its end.
|
||||||
|
///
|
||||||
|
/// Prerequisites:
|
||||||
|
/// The ValueHandle `args` are typed delayed ValueHandles; i.e. they are
|
||||||
|
/// not yet bound to mlir::Value.
|
||||||
|
BlockBuilder(BlockHandle *bh, Region ®ion, ArrayRef<ValueHandle *> args);
|
||||||
|
|
||||||
/// The only purpose of this operator is to serve as a sequence point so that
|
/// The only purpose of this operator is to serve as a sequence point so that
|
||||||
/// the evaluation of `fun` (which build IR snippets in a scoped fashion) is
|
/// the evaluation of `fun` (which build IR snippets in a scoped fashion) is
|
||||||
/// scoped within a BlockBuilder.
|
/// scoped within a BlockBuilder.
|
||||||
|
@ -450,6 +460,9 @@ public:
|
||||||
/// Delegates block creation to MLIR and wrap the resulting mlir::Block.
|
/// Delegates block creation to MLIR and wrap the resulting mlir::Block.
|
||||||
static BlockHandle create(ArrayRef<Type> argTypes);
|
static BlockHandle create(ArrayRef<Type> argTypes);
|
||||||
|
|
||||||
|
/// Delegates block creation to MLIR and wrap the resulting mlir::Block.
|
||||||
|
static BlockHandle createInRegion(Region ®ion, ArrayRef<Type> argTypes);
|
||||||
|
|
||||||
operator bool() { return block != nullptr; }
|
operator bool() { return block != nullptr; }
|
||||||
operator mlir::Block *() { return block; }
|
operator mlir::Block *() { return block; }
|
||||||
mlir::Block *getBlock() { return block; }
|
mlir::Block *getBlock() { return block; }
|
||||||
|
|
|
@ -184,14 +184,16 @@ Operation *mlir::edsc::makeGenericLinalgOp(
|
||||||
? getElementTypeOrSelf(it.value())
|
? getElementTypeOrSelf(it.value())
|
||||||
: it.value().getType());
|
: it.value().getType());
|
||||||
|
|
||||||
assert(op->getRegions().front().empty());
|
assert(op->getNumRegions() == 1);
|
||||||
op->getRegions().front().push_front(new Block);
|
assert(op->getRegion(0).empty());
|
||||||
OpBuilder bb(op->getRegions().front());
|
OpBuilder opBuilder(op);
|
||||||
ScopedContext scope(bb, op->getLoc());
|
ScopedContext scope(opBuilder, op->getLoc());
|
||||||
BlockHandle b;
|
BlockHandle b;
|
||||||
auto handles = makeValueHandles(blockTypes);
|
auto handles = makeValueHandles(blockTypes);
|
||||||
BlockBuilder(&b, makeHandlePointers(MutableArrayRef<ValueHandle>(handles)))(
|
BlockBuilder(&b, op->getRegion(0),
|
||||||
|
makeHandlePointers(MutableArrayRef<ValueHandle>(handles)))(
|
||||||
[&] { regionBuilder(b.getBlock()->getArguments()); });
|
[&] { regionBuilder(b.getBlock()->getArguments()); });
|
||||||
|
assert(op->getRegion(0).getBlocks().size() == 1);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -133,6 +133,22 @@ BlockHandle mlir::edsc::BlockHandle::create(ArrayRef<Type> argTypes) {
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
BlockHandle mlir::edsc::BlockHandle::createInRegion(Region ®ion,
|
||||||
|
ArrayRef<Type> argTypes) {
|
||||||
|
auto ¤tB = ScopedContext::getBuilder();
|
||||||
|
BlockHandle res;
|
||||||
|
region.push_back(new Block);
|
||||||
|
res.block = ®ion.back();
|
||||||
|
// createBlock sets the insertion point inside the block.
|
||||||
|
// We do not want this behavior when using declarative builders with nesting.
|
||||||
|
OpBuilder::InsertionGuard g(currentB);
|
||||||
|
currentB.setInsertionPoint(res.block, res.block->begin());
|
||||||
|
for (auto t : argTypes) {
|
||||||
|
res.block->addArgument(t);
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
static Optional<ValueHandle> emitStaticFor(ArrayRef<ValueHandle> lbs,
|
static Optional<ValueHandle> emitStaticFor(ArrayRef<ValueHandle> lbs,
|
||||||
ArrayRef<ValueHandle> ubs,
|
ArrayRef<ValueHandle> ubs,
|
||||||
int64_t step) {
|
int64_t step) {
|
||||||
|
@ -285,6 +301,23 @@ mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle *bh,
|
||||||
enter(bh->getBlock());
|
enter(bh->getBlock());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle *bh, Region ®ion,
|
||||||
|
ArrayRef<ValueHandle *> args) {
|
||||||
|
assert(!*bh && "BlockHandle already captures a block, use "
|
||||||
|
"the explicit BockBuilder(bh, Append())({}) syntax instead.");
|
||||||
|
SmallVector<Type, 8> types;
|
||||||
|
for (auto *a : args) {
|
||||||
|
assert(!a->hasValue() &&
|
||||||
|
"Expected delayed ValueHandle that has not yet captured.");
|
||||||
|
types.push_back(a->getType());
|
||||||
|
}
|
||||||
|
*bh = BlockHandle::createInRegion(region, types);
|
||||||
|
for (auto it : llvm::zip(args, bh->getBlock()->getArguments())) {
|
||||||
|
*(std::get<0>(it)) = ValueHandle(std::get<1>(it));
|
||||||
|
}
|
||||||
|
enter(bh->getBlock());
|
||||||
|
}
|
||||||
|
|
||||||
/// Only serves as an ordering point between entering nested block and creating
|
/// Only serves as an ordering point between entering nested block and creating
|
||||||
/// stmts.
|
/// stmts.
|
||||||
void mlir::edsc::BlockBuilder::operator()(function_ref<void(void)> fun) {
|
void mlir::edsc::BlockBuilder::operator()(function_ref<void(void)> fun) {
|
||||||
|
|
|
@ -876,7 +876,7 @@ TEST_FUNC(linalg_pointwise_test) {
|
||||||
// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
|
// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
|
||||||
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
|
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
|
||||||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
|
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
|
||||||
/// CHECK: ^bb1(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32):
|
/// CHECK: ^bb0(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32):
|
||||||
// CHECK: %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32
|
// CHECK: %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32
|
||||||
// CHECK: %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32
|
// CHECK: %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32
|
||||||
// CHECK: linalg.yield %[[a4]] : f32
|
// CHECK: linalg.yield %[[a4]] : f32
|
||||||
|
@ -906,7 +906,7 @@ TEST_FUNC(linalg_matmul_test) {
|
||||||
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d1)>,
|
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d1)>,
|
||||||
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d1)>],
|
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d1)>],
|
||||||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]}
|
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]}
|
||||||
/// CHECK: ^bb1(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32):
|
/// CHECK: ^bb0(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32):
|
||||||
// CHECK: %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32
|
// CHECK: %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32
|
||||||
// CHECK: %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32
|
// CHECK: %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32
|
||||||
// CHECK: linalg.yield %[[a4]] : f32
|
// CHECK: linalg.yield %[[a4]] : f32
|
||||||
|
@ -937,7 +937,7 @@ TEST_FUNC(linalg_conv_nhwc) {
|
||||||
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d2, d1)>,
|
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d2, d1)>,
|
||||||
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d3, d4, d1 + d2 * 7)>],
|
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d3, d4, d1 + d2 * 7)>],
|
||||||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
|
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
|
||||||
// CHECK: ^bb1(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32):
|
// CHECK: ^bb0(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32):
|
||||||
// CHECK: %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32
|
// CHECK: %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32
|
||||||
// CHECK: %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32
|
// CHECK: %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32
|
||||||
// CHECK: linalg.yield %[[a4]] : f32
|
// CHECK: linalg.yield %[[a4]] : f32
|
||||||
|
|
Loading…
Reference in New Issue