[mlir][SCF][bufferize][NFC] Implement resolveConflicts for ParallelInsertSliceOp

This was previously implemented as part of the BufferizableOpInterface of ForEachThreadOp. Moving the implementation to ParallelInsertSliceOp to be consistent with the remaining ops and to have a nice example op that can serve as a blueprint for other ops.

Differential Revision: https://reviews.llvm.org/D128666
This commit is contained in:
Matthias Springer 2022-06-28 12:02:28 +02:00
parent f6f53e990d
commit 04dac2ca7c
3 changed files with 70 additions and 51 deletions

View File

@ -578,6 +578,10 @@ def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
/// Return the number of leading operands before `offsets`, `sizes` and
/// `strides` operands.
static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
/// Return the OpResult of the enclosing ForeachThreadOp that is
/// corresponding to this ParallelInsertSliceOp.
OpResult getTiedOpResult();
}];
let builders = [

View File

@ -1215,6 +1215,18 @@ ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) {
// ParallelInsertSliceOp
//===----------------------------------------------------------------------===//
/// Return the OpResult of the enclosing ForeachThreadOp that corresponds to
/// this ParallelInsertSliceOp. The i-th op yielded by the terminator is tied
/// to the i-th result of the ForeachThreadOp.
OpResult ParallelInsertSliceOp::getTiedOpResult() {
  auto parentForeachThreadOp =
      getOperation()->getParentOfType<ForeachThreadOp>();
  assert(parentForeachThreadOp && "unlinked ParallelInsertSliceOp");
  PerformConcurrentlyOp terminator = parentForeachThreadOp.getTerminator();
  // Walk the yielding ops of the terminator and match this op by position.
  unsigned resultIdx = 0;
  for (Operation &yieldingOp : terminator.yieldingOps()) {
    if (&yieldingOp == getOperation())
      return parentForeachThreadOp->getResult(resultIdx);
    ++resultIdx;
  }
  llvm_unreachable("ParallelInsertSliceOp not found");
}
// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
Value source, Value dest,

View File

@ -961,42 +961,6 @@ struct ForeachThreadOpInterface
return BufferRelation::Equivalent;
}
/// Resolve RaW conflicts on the destination tensors of the enclosed
/// ParallelInsertSliceOps by inserting tensor copies for out-of-place
/// destinations.
LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                               const AnalysisState &state) const {
  // First run the default conflict resolution on this op's own operands.
  auto bufferizableOp = cast<BufferizableOpInterface>(op);
  if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
    return failure();

  OpBuilder::InsertionGuard guard(rewriter);
  auto foreachThreadOp = cast<ForeachThreadOp>(op);
  for (OpResult result : foreachThreadOp->getOpResults()) {
    // Each result is tied to exactly one ParallelInsertSliceOp destination.
    SmallVector<OpOperand *> aliasingOperands =
        state.getAliasingOpOperand(result);
    assert(aliasingOperands.size() == 1 &&
           "expected exactly one aliasing OpOperand");
    OpOperand *destOperand = aliasingOperands.front();
    assert(isa<ParallelInsertSliceOp>(destOperand->getOwner()) &&
           "expected ParallelInsertSliceOp");

    // In-place destinations have no conflict; nothing to do.
    if (state.isInPlace(*destOperand))
      continue;

    // Materialize a tensor copy of the conflicting destination.
    FailureOr<Value> copiedDest = allocateTensorForShapedValue(
        rewriter, op->getLoc(), destOperand->get(),
        /*escape=*/state.isTensorYielded(result), state.getOptions());
    if (failed(copiedDest))
      return failure();

    // Redirect the terminator's destination operand to the copy.
    rewriter.updateRootInPlace(destOperand->getOwner(),
                               [&]() { destOperand->set(*copiedDest); });
  }
  return success();
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto foreachThreadOp = cast<ForeachThreadOp>(op);
@ -1118,7 +1082,55 @@ struct ParallelInsertSliceOpInterface
/// Resolve a RaW conflict on the destination tensor by copying it.
///
/// This interface method is overridden (instead of relying on the default
/// implementation) because the tensor copy must be placed at a custom
/// insertion point: right before the enclosing ForeachThreadOp. The RaW
/// conflicts themselves are detected as part of ForeachThreadOp. E.g.:
///
/// %r0, %r1 = foreach_thread ... {
///   ...
///   perform_concurrently {
///     parallel_insert_slice %a into %b ... {inplace = ["true", "true"]}
///     parallel_insert_slice %c into %d ... {inplace = ["true", "false"]}
///   }
/// }
///
/// After TensorCopyInsertion:
///
/// %copy = bufferization.alloc_tensor() copy(%d)
/// %r0, %r1 = foreach_thread ... {
///   ...
///   perform_concurrently {
///     parallel_insert_slice %a into %b ...
///     parallel_insert_slice %c into %copy ...
///   }
/// }
LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                               const AnalysisState &state) const {
  auto insertOp = cast<ParallelInsertSliceOp>(op);

  // The source operand always bufferizes in-place.
  assert(state.isInPlace(op->getOpOperand(0) /*src*/) &&
         "source is always in-place");
  // An in-place destination requires no copy.
  if (state.isInPlace(op->getOpOperand(1) /*dest*/))
    return success();

  // Allocate a copy of the destination tensor right before the enclosing
  // ForeachThreadOp. Whether the buffer may escape depends on whether the
  // tied ForeachThreadOp result is yielded.
  OpBuilder::InsertionGuard guard(rewriter);
  auto enclosingForeachThreadOp =
      insertOp->getParentOfType<ForeachThreadOp>();
  rewriter.setInsertionPoint(enclosingForeachThreadOp);
  FailureOr<Value> copiedDest = allocateTensorForShapedValue(
      rewriter, op->getLoc(), insertOp.getDest(),
      /*escape=*/state.isTensorYielded(insertOp.getTiedOpResult()),
      state.getOptions());
  if (failed(copiedDest))
    return failure();

  // Redirect the destination operand to the new copy.
  rewriter.updateRootInPlace(
      insertOp, [&]() { insertOp.getDestMutable().assign(*copiedDest); });
  return success();
}
@ -1149,29 +1161,20 @@ struct ParallelInsertSliceOpInterface
if (failed(options.createMemCpy(rewriter, insertOp.getLoc(), *srcBuffer,
subview)))
return failure();
rewriter.eraseOp(op);
// Replace all uses of ForeachThreadOp (just the corresponding result).
rewriter.setInsertionPointAfter(foreachThreadOp);
Value toTensorOp =
rewriter.create<ToTensorOp>(foreachThreadOp.getLoc(), *destBuffer);
// PerformConcurrentlyOp can have multiple ParallelInserSliceOps. Find the
// index of `op` within yielding ops.
unsigned resultNum = 0;
for (Operation &nextOp : performConcurrentlyOp.yieldingOps()) {
if (&nextOp == op)
break;
resultNum++;
}
assert(resultNum < foreachThreadOp->getNumResults() &&
"ParallelInsertSliceOp not found in PerformConcurrentlyOp");
SmallVector<OpOperand *> resultUses = llvm::to_vector(
llvm::map_range(foreachThreadOp->getResult(resultNum).getUses(),
[](OpOperand &use) { return &use; }));
// PerformConcurrentlyOp can have multiple ParallelInsertSliceOps.
SmallVector<OpOperand *> resultUses =
llvm::to_vector(llvm::map_range(insertOp.getTiedOpResult().getUses(),
[](OpOperand &use) { return &use; }));
for (OpOperand *use : resultUses) {
rewriter.updateRootInPlace(use->getOwner(),
[&]() { use->set(toTensorOp); });
}
rewriter.eraseOp(op);
return success();
}