[mlir][linalg] Fix crash in tileAndFuseLinalgOpToParallelLoopsAndSetMarker

Instead of using llvm_unreachable to guard against fusing linalg.conv,
reject fusing linalg.conv in isFusableInto.

tileLinalgOpImpl is a templated function now and it can operate on
loop.parellel. So we should avoid calling into getForInductionVarOwner
which always assumes loop.for.

Differential Revision: https://reviews.llvm.org/D78936
This commit is contained in:
Lei Zhang 2020-04-27 11:47:39 -04:00
parent 0852babc30
commit a5bfd32c07
2 changed files with 14 additions and 13 deletions

View File

@ -161,17 +161,6 @@ static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
assert(consumer.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
// TODO(ntv): add a level of indirection to linalg.generic.
if (convOp.padding())
llvm_unreachable("Unexpected conv with padding");
}
if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
// TODO(ntv): add a level of indirection to linalg.generic.
if (convOp.padding())
llvm_unreachable("Unexpected conv with padding");
}
auto subView = dyn_cast_or_null<SubViewOp>(
consumer.getBuffer(consumerIdx).getDefiningOp());
auto slice = dyn_cast_or_null<SliceOp>(
@ -287,6 +276,16 @@ bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
<< *producer.getOperation());
return false;
}
if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
// TODO(ntv): add a level of indirection to linalg.generic.
if (convOp.padding())
return false;
}
if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
// TODO(ntv): add a level of indirection to linalg.generic.
if (convOp.padding())
return false;
}
return true;
}

View File

@ -409,8 +409,10 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(OpBuilder &b, LinalgOp op,
// 5. Gather the newly created loops and return them with the new op.
SmallVector<Operation *, 8> loops;
loops.reserve(ivs.size());
for (auto iv : ivs)
loops.push_back(loop::getForInductionVarOwner(iv));
for (auto iv : ivs) {
loops.push_back(iv.cast<BlockArgument>().getOwner()->getParentOp());
assert(loops.back() && "no owner found for induction variable!");
}
return TiledLinalgOp{res, loops};
}