[mlir][SCF] Add a ParallelCombiningOpInterface to decouple scf::PerformConcurrently from its contained operations

This allows purging references of scf.ForeachThreadOp and scf.PerformConcurrentlyOp from
ParallelInsertSliceOp.
This will allowmoving the op closer to tensor::InsertSliceOp with which it should share much more
code.

In the future, the decoupling will also allow extending the type of ops that can be used in the
parallel combinator as well as semantics related to multiple concurrent inserts to the same
result.

Differential Revision: https://reviews.llvm.org/D128857
This commit is contained in:
Nicolas Vasilache 2022-06-30 03:37:21 -07:00
parent 6a57d8fba5
commit b994d388ae
10 changed files with 205 additions and 43 deletions

View File

@ -18,6 +18,7 @@
#include "mlir/IR/RegionKindInterface.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"

View File

@ -16,6 +16,7 @@
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/IR/RegionKindInterface.td"
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
@ -468,6 +469,7 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
NoSideEffect,
Terminator,
DeclareOpInterfaceMethods<ParallelCombiningOpInterface>,
HasParent<"ForeachThreadOp">,
] # GraphRegionNoTerminator.traits> {
let summary = "terminates a `foreach_thread` block";
@ -495,8 +497,9 @@ def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
// TODO: Add a `PerformConcurrentlyOpInterface` interface for ops that can
// appear inside perform_concurrently.
let extraClassDeclaration = [{
SmallVector<Type> yieldedTypes();
::llvm::iterator_range<Block::iterator> yieldingOps();
::llvm::SmallVector<::mlir::Type> getYieldedTypes();
::llvm::iterator_range<::mlir::Block::iterator> getYieldingOps();
::mlir::OpResult getParentResult(int64_t idx);
}];
}
@ -508,7 +511,9 @@ def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
AttrSizedOperandSegments,
OffsetSizeAndStrideOpInterface,
HasParent<"PerformConcurrentlyOp">]> {
// TODO: Cannot use an interface here atm, verify this manually for now.
// HasParent<"ParallelCombiningOpInterface">
]> {
let summary = [{
Specify the tensor slice update of a single thread within the terminator of
an `scf.foreach_thread`.
@ -568,6 +573,11 @@ def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
return getSource().getType().cast<RankedTensorType>();
}
ParallelCombiningOpInterface getParallelCombiningParent() {
return dyn_cast<ParallelCombiningOpInterface>(
getOperation()->getParentOp());
}
/// Return the expected rank of each of the `static_offsets`, `static_sizes`
/// and `static_strides` attributes.
std::array<unsigned, 3> getArrayAttrMaxRanks() {
@ -599,6 +609,7 @@ def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//

View File

@ -6,6 +6,7 @@ add_mlir_interface(DerivedAttributeOpInterface)
add_mlir_interface(InferIntRangeInterface)
add_mlir_interface(InferTypeOpInterface)
add_mlir_interface(LoopLikeInterface)
add_mlir_interface(ParallelCombiningOpInterface)
add_mlir_interface(SideEffectInterfaces)
add_mlir_interface(TilingInterface)
add_mlir_interface(VectorInterfaces)

View File

@ -0,0 +1,29 @@
//===- ParallelCombiningOpInterface.h - Parallel combining op interface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the operation interface for ops that parallel combining
// operations.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_
#define MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_
#include "mlir/IR/OpDefinition.h"
namespace mlir {
namespace detail {
// TODO: Single region single block interface on interfaces ?
LogicalResult verifyParallelCombiningOpInterface(Operation *op);
} // namespace detail
} // namespace mlir
/// Include the generated interface declarations.
#include "mlir/Interfaces/ParallelCombiningOpInterface.h.inc"
#endif // MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_

View File

@ -0,0 +1,75 @@
//===- ParallelCombiningOpInterface.td - Parallel iface ----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the interface for ops that perform parallel combining operations.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE
#define MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE
include "mlir/IR/OpBase.td"
def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
let description = [{
A parallel combining op is an op with a region, that is not isolated from
above and yields values to its parent op without itself returning an SSA
value. The yielded values are determined by subvalues produced by the ops
contained in the region (the `yieldingOps`) and combined in any unspecified
order to produce the values yielded to the parent op.
This is useful as a terminator to parallel operations that iterate over
some set and return tensors while avoiding tight coupling between the
iterating op, the combining op and the individual subtensor producing ops.
}];
let cppNamespace = "::mlir";
let methods = [
InterfaceMethod<
/*desc=*/[{
Return `idx`^th result of the parent operation.
}],
/*retTy=*/"::mlir::OpResult",
/*methodName=*/"getParentResult",
/*args=*/(ins "int64_t":$idx),
/*methodBody=*/[{
return $_op.getParentResult(idx);
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the contained ops that yield subvalues that this op combines to
yield to its parent.
}],
/*retTy=*/"::llvm::iterator_range<Block::iterator>",
/*methodName=*/"getYieldingOps",
/*args=*/(ins),
/*methodBody=*/[{
return $_op.getYieldingOps();
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the contained ops that yield subvalues that this op combines to
yield to its parent.
}],
/*retTy=*/"::llvm::SmallVector<::mlir::Type>",
/*methodName=*/"getYieldedTypes",
/*args=*/(ins),
/*methodBody=*/[{
return $_op.getYieldedTypes();
}]
>,
];
// TODO: Single region single block interface on interfaces ?
let verify = [{
return verifyParallelCombiningOpInterface($_op);
}];
}
#endif // MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE

View File

@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSCFDialect
MLIRControlFlowDialect
MLIRIR
MLIRLoopLikeInterface
MLIRParallelCombiningOpInterface
MLIRSideEffectInterfaces
)

View File

@ -1061,7 +1061,7 @@ LogicalResult ForeachThreadOp::verify() {
return emitOpError("region expects ") << getRank() << " arguments";
// Verify consistency between the result types and the terminator.
auto terminatorTypes = getTerminator().yieldedTypes();
auto terminatorTypes = getTerminator().getYieldedTypes();
auto opResults = getResults();
if (opResults.size() != terminatorTypes.size())
return emitOpError("produces ")
@ -1182,7 +1182,7 @@ void ForeachThreadOp::build(
llvm::dyn_cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
assert(terminator &&
"expected bodyBuilder to create PerformConcurrentlyOp terminator");
result.addTypes(terminator.yieldedTypes());
result.addTypes(terminator.getYieldedTypes());
}
// The ensureTerminator method generated by SingleBlockImplicitTerminator is
@ -1216,15 +1216,15 @@ ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) {
//===----------------------------------------------------------------------===//
OpResult ParallelInsertSliceOp::getTiedOpResult() {
auto foreachThreadOp = getOperation()->getParentOfType<ForeachThreadOp>();
assert(foreachThreadOp && "unlinked ParallelInsertSliceOp");
PerformConcurrentlyOp performConcurrentlyOp = foreachThreadOp.getTerminator();
for (const auto &it : llvm::enumerate(performConcurrentlyOp.yieldingOps())) {
ParallelCombiningOpInterface parallelCombiningParent =
getParallelCombiningParent();
for (const auto &it :
llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
Operation &nextOp = it.value();
if (&nextOp == getOperation())
return foreachThreadOp->getResult(it.index());
return parallelCombiningParent.getParentResult(it.index());
}
llvm_unreachable("ParallelInsertSliceOp not found");
llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
}
// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
@ -1262,6 +1262,13 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}
LogicalResult ParallelInsertSliceOp::verify() {
if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
return this->emitError("expected ParallelCombiningOpInterface parent, got:")
<< *(getOperation()->getParentOp());
return success();
}
namespace {
/// Pattern to rewrite a parallel_insert_slice op with constant arguments.
class ParallelInsertSliceOpConstantArgumentFolder final
@ -1382,15 +1389,19 @@ ParseResult PerformConcurrentlyOp::parse(OpAsmParser &parser,
return success();
}
SmallVector<Type> PerformConcurrentlyOp::yieldedTypes() {
OpResult PerformConcurrentlyOp::getParentResult(int64_t idx) {
return getOperation()->getParentOp()->getResult(idx);
}
SmallVector<Type> PerformConcurrentlyOp::getYieldedTypes() {
return llvm::to_vector<4>(
llvm::map_range(this->yieldingOps(), [](Operation &op) {
llvm::map_range(getYieldingOps(), [](Operation &op) {
auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(&op);
return insertSliceOp ? insertSliceOp.yieldedType() : Type();
}));
}
llvm::iterator_range<Block::iterator> PerformConcurrentlyOp::yieldingOps() {
llvm::iterator_range<Block::iterator> PerformConcurrentlyOp::getYieldingOps() {
return getRegion().front().getOperations();
}

View File

@ -1043,8 +1043,7 @@ struct ParallelInsertSliceOpInterface
if (&opOperand != &op->getOpOperand(1) /*dest*/)
return {};
// ParallelInsertSliceOp itself has no results. Tensors are returned via
// the parent op.
// ParallelInsertSliceOp itself has no results, query its tied op results.
auto insertOp = cast<ParallelInsertSliceOp>(op);
return {insertOp.getTiedOpResult()};
}
@ -1090,8 +1089,10 @@ struct ParallelInsertSliceOpInterface
// }
OpBuilder::InsertionGuard g(rewriter);
auto insertOp = cast<ParallelInsertSliceOp>(op);
auto foreachThreadOp = insertOp->getParentOfType<ForeachThreadOp>();
auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
ParallelCombiningOpInterface parallelCombiningParent =
parallelInsertSliceOp.getParallelCombiningParent();
Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
// Nothing to do if the destination tensor is inplace.
assert(state.isInPlace(op->getOpOperand(0) /*src*/) &&
@ -1100,20 +1101,21 @@ struct ParallelInsertSliceOpInterface
return success();
// Find corresponding OpResult.
OpResult opResult = insertOp.getTiedOpResult();
OpResult opResult = parallelInsertSliceOp.getTiedOpResult();
// Insert tensor allocation right before the ForeachThreadOp.
rewriter.setInsertionPoint(foreachThreadOp);
rewriter.setInsertionPoint(parallelIteratingOp);
bool isYielded = state.isTensorYielded(opResult);
FailureOr<Value> alloc =
allocateTensorForShapedValue(rewriter, op->getLoc(), insertOp.getDest(),
/*escape=*/isYielded, state.getOptions());
FailureOr<Value> alloc = allocateTensorForShapedValue(
rewriter, op->getLoc(), parallelInsertSliceOp.getDest(),
/*escape=*/isYielded, state.getOptions());
if (failed(alloc))
return failure();
// Update destination operand.
rewriter.updateRootInPlace(
insertOp, [&]() { insertOp.getDestMutable().assign(*alloc); });
rewriter.updateRootInPlace(parallelInsertSliceOp, [&]() {
parallelInsertSliceOp.getDestMutable().assign(*alloc);
});
return success();
}
@ -1121,39 +1123,41 @@ struct ParallelInsertSliceOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
OpBuilder::InsertionGuard g(rewriter);
auto insertOp = cast<ParallelInsertSliceOp>(op);
auto performConcurrentlyOp = cast<PerformConcurrentlyOp>(op->getParentOp());
auto foreachThreadOp =
cast<ForeachThreadOp>(performConcurrentlyOp->getParentOp());
auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
ParallelCombiningOpInterface parallelCombiningParent =
parallelInsertSliceOp.getParallelCombiningParent();
Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
// Get destination buffer.
FailureOr<Value> destBuffer =
getBuffer(rewriter, insertOp.getDest(), options);
getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
if (failed(destBuffer))
return failure();
// Bufferize the ParallelInsertSliceOp outside of the PerformConcurrentlyOp.
rewriter.setInsertionPoint(performConcurrentlyOp);
// Bufferize the ParallelInsertSliceOp outside of `parallelCombiningParent`.
rewriter.setInsertionPoint(parallelCombiningParent);
FailureOr<Value> srcBuffer =
getBuffer(rewriter, insertOp.getSource(), options);
getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
if (failed(srcBuffer))
return failure();
Value subview = rewriter.create<memref::SubViewOp>(
insertOp.getLoc(), *destBuffer, insertOp.getMixedOffsets(),
insertOp.getMixedSizes(), insertOp.getMixedStrides());
parallelInsertSliceOp.getLoc(), *destBuffer,
parallelInsertSliceOp.getMixedOffsets(),
parallelInsertSliceOp.getMixedSizes(),
parallelInsertSliceOp.getMixedStrides());
// This memcpy will fold away if everything bufferizes in-place.
if (failed(options.createMemCpy(rewriter, insertOp.getLoc(), *srcBuffer,
subview)))
if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
*srcBuffer, subview)))
return failure();
// Replace all uses of ForeachThreadOp (just the corresponding result).
rewriter.setInsertionPointAfter(foreachThreadOp);
// Replace all uses of parallelIteratingOp (just the corresponding result).
rewriter.setInsertionPointAfter(parallelIteratingOp);
Value toTensorOp =
rewriter.create<ToTensorOp>(foreachThreadOp.getLoc(), *destBuffer);
rewriter.create<ToTensorOp>(parallelIteratingOp->getLoc(), *destBuffer);
// PerformConcurrentlyOp can have multiple ParallelInsertSliceOps.
SmallVector<OpOperand *> resultUses =
llvm::to_vector(llvm::map_range(insertOp.getTiedOpResult().getUses(),
[](OpOperand &use) { return &use; }));
SmallVector<OpOperand *> resultUses = llvm::to_vector(
llvm::map_range(parallelInsertSliceOp.getTiedOpResult().getUses(),
[](OpOperand &use) { return &use; }));
for (OpOperand *use : resultUses) {
rewriter.updateRootInPlace(use->getOwner(),
[&]() { use->set(toTensorOp); });

View File

@ -8,6 +8,7 @@ set(LLVM_OPTIONAL_SOURCES
InferIntRangeInterface.cpp
InferTypeOpInterface.cpp
LoopLikeInterface.cpp
ParallelCombiningOpInterface.cpp
SideEffectInterfaces.cpp
TilingInterface.cpp
VectorInterfaces.cpp
@ -38,6 +39,7 @@ add_mlir_interface_library(DataLayoutInterfaces)
add_mlir_interface_library(DerivedAttributeOpInterface)
add_mlir_interface_library(InferIntRangeInterface)
add_mlir_interface_library(InferTypeOpInterface)
add_mlir_interface_library(ParallelCombiningOpInterface)
add_mlir_interface_library(SideEffectInterfaces)
add_mlir_interface_library(TilingInterface)
add_mlir_interface_library(VectorInterfaces)

View File

@ -0,0 +1,27 @@
//===- ParallelCombiningOpInterface.cpp - Parallel combining op interface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// ParallelCombiningOpInterface
//===----------------------------------------------------------------------===//
// TODO: Single region single block interface on interfaces ?
LogicalResult mlir::detail::verifyParallelCombiningOpInterface(Operation *op) {
if (op->getNumRegions() != 1)
return op->emitError("expected single region op");
if (!op->getRegion(0).hasOneBlock())
return op->emitError("expected single block op region");
return success();
}
/// Include the definitions of the interface.
#include "mlir/Interfaces/ParallelCombiningOpInterface.cpp.inc"