[mlir][SCF][bufferize][NFC] Implement resolveConflicts for ParallelInsertSliceOp
This was previously implemented as part of the BufferizableOpInterface implementation of ForeachThreadOp. The implementation is moved to ParallelInsertSliceOp to be consistent with the remaining ops and to provide a nice example op that can serve as a blueprint for other ops.

Differential Revision: https://reviews.llvm.org/D128666
parent f6f53e990d
commit 04dac2ca7c
@@ -578,6 +578,10 @@ def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
     /// Return the number of leading operands before `offsets`, `sizes` and
     /// `strides` operands.
     static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
+
+    /// Return the OpResult of the enclosing ForeachThreadOp that corresponds
+    /// to this ParallelInsertSliceOp.
+    OpResult getTiedOpResult();
   }];

   let builders = [
@@ -1215,6 +1215,18 @@ ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) {
 // ParallelInsertSliceOp
 //===----------------------------------------------------------------------===//

+OpResult ParallelInsertSliceOp::getTiedOpResult() {
+  auto foreachThreadOp = getOperation()->getParentOfType<ForeachThreadOp>();
+  assert(foreachThreadOp && "unlinked ParallelInsertSliceOp");
+  PerformConcurrentlyOp performConcurrentlyOp = foreachThreadOp.getTerminator();
+  for (const auto &it : llvm::enumerate(performConcurrentlyOp.yieldingOps())) {
+    Operation &nextOp = it.value();
+    if (&nextOp == getOperation())
+      return foreachThreadOp->getResult(it.index());
+  }
+  llvm_unreachable("ParallelInsertSliceOp not found");
+}
+
 // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                   Value source, Value dest,
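As a usage illustration, here is a minimal sketch of how getTiedOpResult pairs each yielding op in the terminator with the ForeachThreadOp result it produces. The helper name collectTiedResults is hypothetical and not part of this commit, and the include path is assumed to match the SCF dialect headers of this LLVM revision:

#include "mlir/Dialect/SCF/IR/SCF.h" // assumed header path; may vary by version
#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace mlir::scf;

// Hypothetical helper: collect, for every parallel_insert_slice nested in the
// terminator of `foreachThreadOp`, the ForeachThreadOp result it is tied to.
static void collectTiedResults(ForeachThreadOp foreachThreadOp,
                               SmallVectorImpl<OpResult> &tiedResults) {
  PerformConcurrentlyOp terminator = foreachThreadOp.getTerminator();
  for (Operation &yieldingOp : terminator.yieldingOps()) {
    // Yielding ops of perform_concurrently are parallel_insert_slice ops.
    auto insertOp = cast<ParallelInsertSliceOp>(&yieldingOp);
    // The i-th yielding op is tied to the i-th result of the ForeachThreadOp.
    tiedResults.push_back(insertOp.getTiedOpResult());
  }
}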
@@ -961,42 +961,6 @@ struct ForeachThreadOpInterface
     return BufferRelation::Equivalent;
   }

-  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
-                                 const AnalysisState &state) const {
-    auto bufferizableOp = cast<BufferizableOpInterface>(op);
-    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
-      return failure();
-
-    OpBuilder::InsertionGuard g(rewriter);
-    auto foreachThreadOp = cast<ForeachThreadOp>(op);
-    for (OpResult opResult : foreachThreadOp->getOpResults()) {
-      SmallVector<OpOperand *> destOperands =
-          state.getAliasingOpOperand(opResult);
-      assert(destOperands.size() == 1 &&
-             "expected exactly one aliasing OpOperand");
-      assert(isa<ParallelInsertSliceOp>(destOperands.front()->getOwner()) &&
-             "expected ParallelInsertSliceOp");
-
-      // Nothing to do if there is no conflict.
-      if (state.isInPlace(*destOperands.front()))
-        continue;
-
-      // Insert tensor allocation.
-      bool isYielded = state.isTensorYielded(opResult);
-      FailureOr<Value> alloc = allocateTensorForShapedValue(
-          rewriter, op->getLoc(), destOperands.front()->get(),
-          /*escape=*/isYielded, state.getOptions());
-      if (failed(alloc))
-        return failure();
-
-      // Update terminator operand.
-      rewriter.updateRootInPlace(destOperands.front()->getOwner(),
-                                 [&]() { destOperands.front()->set(*alloc); });
-    }
-
-    return success();
-  }
-
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto foreachThreadOp = cast<ForeachThreadOp>(op);
@@ -1118,7 +1082,55 @@ struct ParallelInsertSliceOpInterface

+  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
+                                 const AnalysisState &state) const {
+    // RaW conflicts are resolved as part of ForeachThreadOp.
+    // This interface method is overridden because we want to set a custom
+    // insertion point for tensor copies. They should be inserted right before
+    // the ForeachThreadOp. E.g.:
+    //
+    // %r0, %r1 = foreach_thread ... {
+    //   ...
+    //   perform_concurrently {
+    //     parallel_insert_slice %a into %b ... {inplace = ["true", "true"]}
+    //     parallel_insert_slice %c into %d ... {inplace = ["true", "false"]}
+    //   }
+    // }
+    //
+    // After TensorCopyInsertion:
+    //
+    // %copy = bufferization.alloc_tensor() copy(%d)
+    // %r0, %r1 = foreach_thread ... {
+    //   ...
+    //   perform_concurrently {
+    //     parallel_insert_slice %a into %b ...
+    //     parallel_insert_slice %c into %copy ...
+    //   }
+    // }
+
+    OpBuilder::InsertionGuard g(rewriter);
+    auto insertOp = cast<ParallelInsertSliceOp>(op);
+    auto foreachThreadOp = insertOp->getParentOfType<ForeachThreadOp>();
+
+    // Nothing to do if the destination tensor is inplace.
+    assert(state.isInPlace(op->getOpOperand(0) /*src*/) &&
+           "source is always in-place");
+    if (state.isInPlace(op->getOpOperand(1) /*dest*/))
+      return success();
+
+    // Find corresponding OpResult.
+    OpResult opResult = insertOp.getTiedOpResult();
+
+    // Insert tensor allocation right before the ForeachThreadOp.
+    rewriter.setInsertionPoint(foreachThreadOp);
+    bool isYielded = state.isTensorYielded(opResult);
+    FailureOr<Value> alloc =
+        allocateTensorForShapedValue(rewriter, op->getLoc(), insertOp.getDest(),
+                                     /*escape=*/isYielded, state.getOptions());
+    if (failed(alloc))
+      return failure();
+
+    // Update destination operand.
+    rewriter.updateRootInPlace(
+        insertOp, [&]() { insertOp.getDestMutable().assign(*alloc); });
+
+    return success();
+  }
+
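For context, a BufferizableOpInterface implementation like the one above is attached to the op as an external model when the dialect's bufferization models are registered. The following is only a hedged sketch of that common registration pattern; the function name registerSCFBufferizationModels is illustrative, and the actual SCF registration hook in the tree may be named and organized differently:

#include "mlir/Dialect/SCF/IR/SCF.h" // assumed header path; may vary by version
#include "mlir/IR/DialectRegistry.h"

using namespace mlir;

// Illustrative registration sketch: attach the external model defined above to
// scf.foreach_thread.parallel_insert_slice when the SCF dialect is loaded.
void registerSCFBufferizationModels(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    scf::ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
        *ctx);
  });
}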
@@ -1149,29 +1161,20 @@ struct ParallelInsertSliceOpInterface
     if (failed(options.createMemCpy(rewriter, insertOp.getLoc(), *srcBuffer,
                                     subview)))
       return failure();
-    rewriter.eraseOp(op);

     // Replace all uses of ForeachThreadOp (just the corresponding result).
     rewriter.setInsertionPointAfter(foreachThreadOp);
     Value toTensorOp =
         rewriter.create<ToTensorOp>(foreachThreadOp.getLoc(), *destBuffer);
-    // PerformConcurrentlyOp can have multiple ParallelInsertSliceOps. Find the
-    // index of `op` within yielding ops.
-    unsigned resultNum = 0;
-    for (Operation &nextOp : performConcurrentlyOp.yieldingOps()) {
-      if (&nextOp == op)
-        break;
-      resultNum++;
-    }
-    assert(resultNum < foreachThreadOp->getNumResults() &&
-           "ParallelInsertSliceOp not found in PerformConcurrentlyOp");
-    SmallVector<OpOperand *> resultUses = llvm::to_vector(
-        llvm::map_range(foreachThreadOp->getResult(resultNum).getUses(),
-                        [](OpOperand &use) { return &use; }));
+    // PerformConcurrentlyOp can have multiple ParallelInsertSliceOps.
+    SmallVector<OpOperand *> resultUses =
+        llvm::to_vector(llvm::map_range(insertOp.getTiedOpResult().getUses(),
+                                        [](OpOperand &use) { return &use; }));
     for (OpOperand *use : resultUses) {
       rewriter.updateRootInPlace(use->getOwner(),
                                  [&]() { use->set(toTensorOp); });
     }
+    rewriter.eraseOp(op);
     return success();
   }
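One detail worth calling out in the hunk above: the uses of the tied result are first copied into a SmallVector and only then rewritten, because redirecting a use while iterating getUses() would invalidate the use-list iterator. A minimal sketch of the same pattern in isolation; the helper name replaceAllUsesVia is hypothetical and used only for illustration:

#include "mlir/IR/PatternMatch.h" // RewriterBase
#include "llvm/ADT/STLExtras.h"   // llvm::to_vector, llvm::map_range
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Snapshot the uses of `result`, then redirect each one to `replacement`
// through the rewriter so that the in-place modification is properly notified.
static void replaceAllUsesVia(RewriterBase &rewriter, OpResult result,
                              Value replacement) {
  SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
      result.getUses(), [](OpOperand &use) { return &use; }));
  for (OpOperand *use : uses)
    rewriter.updateRootInPlace(use->getOwner(),
                               [&]() { use->set(replacement); });
}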