[mlir][SCF] Add a ParallelCombiningOpInterface to decouple scf::PerformConcurrently from its contained operations

This allows purging references to scf.ForeachThreadOp and scf.PerformConcurrentlyOp from
ParallelInsertSliceOp.
This will allow moving the op closer to tensor::InsertSliceOp, with which it should share much more
code.

In the future, the decoupling will also allow extending the set of ops that can be used in the
parallel combinator, as well as the semantics related to multiple concurrent inserts to the same
result.

Differential Revision: https://reviews.llvm.org/D128857
Author: Nicolas Vasilache
Date:   2022-06-30 03:37:21 -07:00
Commit: b994d388ae (parent 6a57d8fba5)
10 changed files with 205 additions and 43 deletions
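For orientation (not part of the patch itself): the heart of the change is that a contained op such
as ParallelInsertSliceOp can now resolve the parent result it is tied to purely through the new
ParallelCombiningOpInterface, without naming scf.PerformConcurrentlyOp or scf.ForeachThreadOp. A
minimal C++ sketch of that lookup, mirroring ParallelInsertSliceOp::getTiedOpResult below (the free
function and its name are hypothetical):

  #include <cassert>

  #include "llvm/ADT/STLExtras.h"
  #include "llvm/Support/ErrorHandling.h"
  #include "mlir/IR/Operation.h"
  #include "mlir/Interfaces/ParallelCombiningOpInterface.h"

  // Hypothetical helper: given an op nested directly under any op implementing
  // ParallelCombiningOpInterface, return the parent result that this op's
  // yielded subvalue is combined into.
  static mlir::OpResult getTiedParentResult(mlir::Operation *op) {
    auto combining =
        llvm::dyn_cast<mlir::ParallelCombiningOpInterface>(op->getParentOp());
    assert(combining && "expected a ParallelCombiningOpInterface parent");
    for (const auto &it : llvm::enumerate(combining.getYieldingOps()))
      if (&it.value() == op)
        return combining.getParentResult(it.index());
    llvm_unreachable("op is not one of its parent's yielding ops");
  }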


@@ -18,6 +18,7 @@
 #include "mlir/IR/RegionKindInterface.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"


@@ -16,6 +16,7 @@
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/IR/RegionKindInterface.td"
+include "mlir/Interfaces/ParallelCombiningOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
@@ -468,6 +469,7 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
 def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
     NoSideEffect,
     Terminator,
+    DeclareOpInterfaceMethods<ParallelCombiningOpInterface>,
     HasParent<"ForeachThreadOp">,
   ] # GraphRegionNoTerminator.traits> {
   let summary = "terminates a `foreach_thread` block";
@@ -495,8 +497,9 @@ def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
   // TODO: Add a `PerformConcurrentlyOpInterface` interface for ops that can
   // appear inside perform_concurrently.
   let extraClassDeclaration = [{
-    SmallVector<Type> yieldedTypes();
-    ::llvm::iterator_range<Block::iterator> yieldingOps();
+    ::llvm::SmallVector<::mlir::Type> getYieldedTypes();
+    ::llvm::iterator_range<::mlir::Block::iterator> getYieldingOps();
+    ::mlir::OpResult getParentResult(int64_t idx);
   }];
 }
@@ -508,7 +511,9 @@ def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
 def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
     AttrSizedOperandSegments,
     OffsetSizeAndStrideOpInterface,
-    HasParent<"PerformConcurrentlyOp">]> {
+    // TODO: Cannot use an interface here atm, verify this manually for now.
+    // HasParent<"ParallelCombiningOpInterface">
+  ]> {
   let summary = [{
     Specify the tensor slice update of a single thread within the terminator of
     an `scf.foreach_thread`.
@@ -568,6 +573,11 @@ def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
       return getSource().getType().cast<RankedTensorType>();
     }
+
+    ParallelCombiningOpInterface getParallelCombiningParent() {
+      return dyn_cast<ParallelCombiningOpInterface>(
+          getOperation()->getParentOp());
+    }
     /// Return the expected rank of each of the `static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
     std::array<unsigned, 3> getArrayAttrMaxRanks() {
@@ -599,6 +609,7 @@ def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 //===----------------------------------------------------------------------===//


@@ -6,6 +6,7 @@ add_mlir_interface(DerivedAttributeOpInterface)
 add_mlir_interface(InferIntRangeInterface)
 add_mlir_interface(InferTypeOpInterface)
 add_mlir_interface(LoopLikeInterface)
+add_mlir_interface(ParallelCombiningOpInterface)
 add_mlir_interface(SideEffectInterfaces)
 add_mlir_interface(TilingInterface)
 add_mlir_interface(VectorInterfaces)


@ -0,0 +1,29 @@
//===- ParallelCombiningOpInterface.h - Parallel combining op interface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the operation interface for ops that perform parallel
// combining operations.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_
#define MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_
#include "mlir/IR/OpDefinition.h"
namespace mlir {
namespace detail {
// TODO: Single region single block interface on interfaces ?
LogicalResult verifyParallelCombiningOpInterface(Operation *op);
} // namespace detail
} // namespace mlir
/// Include the generated interface declarations.
#include "mlir/Interfaces/ParallelCombiningOpInterface.h.inc"
#endif // MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_
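A hedged usage sketch of what these generated declarations enable: client code can query the
interface from a plain Operation * and stay independent of the SCF dialect. The walker below is
illustrative only:

  #include "llvm/Support/raw_ostream.h"
  #include "mlir/IR/Operation.h"
  #include "mlir/Interfaces/ParallelCombiningOpInterface.h"

  // Illustrative sketch: report every parallel combining op under `root`
  // together with the number of values it yields to its parent.
  static void reportParallelCombiningOps(mlir::Operation *root) {
    root->walk([](mlir::ParallelCombiningOpInterface combiningOp) {
      llvm::errs() << combiningOp->getName().getStringRef() << " yields "
                   << combiningOp.getYieldedTypes().size() << " value(s)\n";
    });
  }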


@ -0,0 +1,75 @@
//===- ParallelCombiningOpInterface.td - Parallel iface ----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the interface for ops that perform parallel combining operations.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE
#define MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE
include "mlir/IR/OpBase.td"
def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
let description = [{
A parallel combining op is an op with a region that is not isolated from
above and yields values to its parent op without itself returning an SSA
value. The yielded values are determined by subvalues produced by the ops
contained in the region (the `yieldingOps`) and combined in an unspecified
order to produce the values yielded to the parent op.
This is useful as a terminator to parallel operations that iterate over
some set and return tensors while avoiding tight coupling between the
iterating op, the combining op and the individual subtensor producing ops.
}];
let cppNamespace = "::mlir";
let methods = [
InterfaceMethod<
/*desc=*/[{
Return the `idx`^th result of the parent operation.
}],
/*retTy=*/"::mlir::OpResult",
/*methodName=*/"getParentResult",
/*args=*/(ins "int64_t":$idx),
/*methodBody=*/[{
return $_op.getParentResult(idx);
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the contained ops that yield subvalues that this op combines to
yield to its parent.
}],
/*retTy=*/"::llvm::iterator_range<Block::iterator>",
/*methodName=*/"getYieldingOps",
/*args=*/(ins),
/*methodBody=*/[{
return $_op.getYieldingOps();
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the types of the values yielded to the parent op, as combined from
the subvalues produced by the contained `yieldingOps`.
}],
/*retTy=*/"::llvm::SmallVector<::mlir::Type>",
/*methodName=*/"getYieldedTypes",
/*args=*/(ins),
/*methodBody=*/[{
return $_op.getYieldedTypes();
}]
>,
];
// TODO: Single region single block interface on interfaces ?
let verify = [{
return verifyParallelCombiningOpInterface($_op);
}];
}
#endif // MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE
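The interface deliberately leaves the combining order unspecified; what an enclosing parallel op
can rely on is the correspondence between the terminator's yielded types and its own results. A
sketch of that consistency check, written generically against the interface in the spirit of
ForeachThreadOp::verify further down (the helper is illustrative, not part of the patch):

  #include "llvm/ADT/STLExtras.h"
  #include "llvm/ADT/SmallVector.h"
  #include "mlir/IR/Operation.h"
  #include "mlir/Interfaces/ParallelCombiningOpInterface.h"
  #include "mlir/Support/LogicalResult.h"

  // Illustrative check: the combining terminator must yield exactly one value,
  // of matching type, per result of the enclosing parallel op.
  static mlir::LogicalResult
  checkYieldConsistency(mlir::Operation *parallelOp,
                        mlir::ParallelCombiningOpInterface terminator) {
    llvm::SmallVector<mlir::Type> yieldedTypes = terminator.getYieldedTypes();
    if (yieldedTypes.size() != parallelOp->getNumResults())
      return parallelOp->emitOpError("produces ")
             << parallelOp->getNumResults()
             << " results, but its terminator yields " << yieldedTypes.size()
             << " value(s)";
    for (const auto &en : llvm::enumerate(yieldedTypes))
      if (en.value() != parallelOp->getResult(en.index()).getType())
        return parallelOp->emitOpError("type mismatch for result #")
               << en.index();
    return mlir::success();
  }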


@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSCFDialect
   MLIRControlFlowDialect
   MLIRIR
   MLIRLoopLikeInterface
+  MLIRParallelCombiningOpInterface
   MLIRSideEffectInterfaces
 )


@@ -1061,7 +1061,7 @@ LogicalResult ForeachThreadOp::verify() {
     return emitOpError("region expects ") << getRank() << " arguments";
   // Verify consistency between the result types and the terminator.
-  auto terminatorTypes = getTerminator().yieldedTypes();
+  auto terminatorTypes = getTerminator().getYieldedTypes();
   auto opResults = getResults();
   if (opResults.size() != terminatorTypes.size())
     return emitOpError("produces ")
@@ -1182,7 +1182,7 @@ void ForeachThreadOp::build(
         llvm::dyn_cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
     assert(terminator &&
            "expected bodyBuilder to create PerformConcurrentlyOp terminator");
-    result.addTypes(terminator.yieldedTypes());
+    result.addTypes(terminator.getYieldedTypes());
   }
   // The ensureTerminator method generated by SingleBlockImplicitTerminator is
@@ -1216,15 +1216,15 @@ ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) {
 //===----------------------------------------------------------------------===//
 OpResult ParallelInsertSliceOp::getTiedOpResult() {
-  auto foreachThreadOp = getOperation()->getParentOfType<ForeachThreadOp>();
-  assert(foreachThreadOp && "unlinked ParallelInsertSliceOp");
-  PerformConcurrentlyOp performConcurrentlyOp = foreachThreadOp.getTerminator();
-  for (const auto &it : llvm::enumerate(performConcurrentlyOp.yieldingOps())) {
+  ParallelCombiningOpInterface parallelCombiningParent =
+      getParallelCombiningParent();
+  for (const auto &it :
+       llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
     Operation &nextOp = it.value();
     if (&nextOp == getOperation())
-      return foreachThreadOp->getResult(it.index());
+      return parallelCombiningParent.getParentResult(it.index());
   }
-  llvm_unreachable("ParallelInsertSliceOp not found");
+  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
 }
// Build a ParallelInsertSliceOp with mixed static and dynamic entries. // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
@@ -1262,6 +1262,13 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
 }
+
+LogicalResult ParallelInsertSliceOp::verify() {
+  if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
+    return this->emitError("expected ParallelCombiningOpInterface parent, got:")
+           << *(getOperation()->getParentOp());
+  return success();
+}
 namespace {
 /// Pattern to rewrite a parallel_insert_slice op with constant arguments.
 class ParallelInsertSliceOpConstantArgumentFolder final
@@ -1382,15 +1389,19 @@ ParseResult PerformConcurrentlyOp::parse(OpAsmParser &parser,
   return success();
 }
-SmallVector<Type> PerformConcurrentlyOp::yieldedTypes() {
+OpResult PerformConcurrentlyOp::getParentResult(int64_t idx) {
+  return getOperation()->getParentOp()->getResult(idx);
+}
+
+SmallVector<Type> PerformConcurrentlyOp::getYieldedTypes() {
   return llvm::to_vector<4>(
-      llvm::map_range(this->yieldingOps(), [](Operation &op) {
+      llvm::map_range(getYieldingOps(), [](Operation &op) {
         auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(&op);
         return insertSliceOp ? insertSliceOp.yieldedType() : Type();
       }));
 }
-llvm::iterator_range<Block::iterator> PerformConcurrentlyOp::yieldingOps() {
+llvm::iterator_range<Block::iterator> PerformConcurrentlyOp::getYieldingOps() {
   return getRegion().front().getOperations();
 }


@@ -1043,8 +1043,7 @@ struct ParallelInsertSliceOpInterface
     if (&opOperand != &op->getOpOperand(1) /*dest*/)
       return {};
-    // ParallelInsertSliceOp itself has no results. Tensors are returned via
-    // the parent op.
+    // ParallelInsertSliceOp itself has no results, query its tied op results.
     auto insertOp = cast<ParallelInsertSliceOp>(op);
     return {insertOp.getTiedOpResult()};
   }
@@ -1090,8 +1089,10 @@ struct ParallelInsertSliceOpInterface
     // }
     OpBuilder::InsertionGuard g(rewriter);
-    auto insertOp = cast<ParallelInsertSliceOp>(op);
-    auto foreachThreadOp = insertOp->getParentOfType<ForeachThreadOp>();
+    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
+    ParallelCombiningOpInterface parallelCombiningParent =
+        parallelInsertSliceOp.getParallelCombiningParent();
+    Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
     // Nothing to do if the destination tensor is inplace.
     assert(state.isInPlace(op->getOpOperand(0) /*src*/) &&
@@ -1100,20 +1101,21 @@ struct ParallelInsertSliceOpInterface
       return success();
     // Find corresponding OpResult.
-    OpResult opResult = insertOp.getTiedOpResult();
+    OpResult opResult = parallelInsertSliceOp.getTiedOpResult();
     // Insert tensor allocation right before the ForeachThreadOp.
-    rewriter.setInsertionPoint(foreachThreadOp);
+    rewriter.setInsertionPoint(parallelIteratingOp);
     bool isYielded = state.isTensorYielded(opResult);
-    FailureOr<Value> alloc =
-        allocateTensorForShapedValue(rewriter, op->getLoc(), insertOp.getDest(),
-                                     /*escape=*/isYielded, state.getOptions());
+    FailureOr<Value> alloc = allocateTensorForShapedValue(
+        rewriter, op->getLoc(), parallelInsertSliceOp.getDest(),
+        /*escape=*/isYielded, state.getOptions());
     if (failed(alloc))
       return failure();
     // Update destination operand.
-    rewriter.updateRootInPlace(
-        insertOp, [&]() { insertOp.getDestMutable().assign(*alloc); });
+    rewriter.updateRootInPlace(parallelInsertSliceOp, [&]() {
+      parallelInsertSliceOp.getDestMutable().assign(*alloc);
+    });
     return success();
   }
@@ -1121,39 +1123,41 @@ struct ParallelInsertSliceOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     OpBuilder::InsertionGuard g(rewriter);
-    auto insertOp = cast<ParallelInsertSliceOp>(op);
-    auto performConcurrentlyOp = cast<PerformConcurrentlyOp>(op->getParentOp());
-    auto foreachThreadOp =
-        cast<ForeachThreadOp>(performConcurrentlyOp->getParentOp());
+    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
+    ParallelCombiningOpInterface parallelCombiningParent =
+        parallelInsertSliceOp.getParallelCombiningParent();
+    Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
     // Get destination buffer.
     FailureOr<Value> destBuffer =
-        getBuffer(rewriter, insertOp.getDest(), options);
+        getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
     if (failed(destBuffer))
       return failure();
-    // Bufferize the ParallelInsertSliceOp outside of the PerformConcurrentlyOp.
-    rewriter.setInsertionPoint(performConcurrentlyOp);
+    // Bufferize the ParallelInsertSliceOp outside of `parallelCombiningParent`.
+    rewriter.setInsertionPoint(parallelCombiningParent);
     FailureOr<Value> srcBuffer =
-        getBuffer(rewriter, insertOp.getSource(), options);
+        getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
     if (failed(srcBuffer))
       return failure();
     Value subview = rewriter.create<memref::SubViewOp>(
-        insertOp.getLoc(), *destBuffer, insertOp.getMixedOffsets(),
-        insertOp.getMixedSizes(), insertOp.getMixedStrides());
+        parallelInsertSliceOp.getLoc(), *destBuffer,
+        parallelInsertSliceOp.getMixedOffsets(),
+        parallelInsertSliceOp.getMixedSizes(),
+        parallelInsertSliceOp.getMixedStrides());
     // This memcpy will fold away if everything bufferizes in-place.
-    if (failed(options.createMemCpy(rewriter, insertOp.getLoc(), *srcBuffer,
-                                    subview)))
+    if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
+                                    *srcBuffer, subview)))
       return failure();
-    // Replace all uses of ForeachThreadOp (just the corresponding result).
-    rewriter.setInsertionPointAfter(foreachThreadOp);
+    // Replace all uses of parallelIteratingOp (just the corresponding result).
+    rewriter.setInsertionPointAfter(parallelIteratingOp);
     Value toTensorOp =
-        rewriter.create<ToTensorOp>(foreachThreadOp.getLoc(), *destBuffer);
+        rewriter.create<ToTensorOp>(parallelIteratingOp->getLoc(), *destBuffer);
     // PerformConcurrentlyOp can have multiple ParallelInsertSliceOps.
-    SmallVector<OpOperand *> resultUses =
-        llvm::to_vector(llvm::map_range(insertOp.getTiedOpResult().getUses(),
-                                        [](OpOperand &use) { return &use; }));
+    SmallVector<OpOperand *> resultUses = llvm::to_vector(
+        llvm::map_range(parallelInsertSliceOp.getTiedOpResult().getUses(),
+                        [](OpOperand &use) { return &use; }));
     for (OpOperand *use : resultUses) {
       rewriter.updateRootInPlace(use->getOwner(),
                                  [&]() { use->set(toTensorOp); });


@@ -8,6 +8,7 @@ set(LLVM_OPTIONAL_SOURCES
   InferIntRangeInterface.cpp
   InferTypeOpInterface.cpp
   LoopLikeInterface.cpp
+  ParallelCombiningOpInterface.cpp
   SideEffectInterfaces.cpp
   TilingInterface.cpp
   VectorInterfaces.cpp
@@ -38,6 +39,7 @@ add_mlir_interface_library(DataLayoutInterfaces)
 add_mlir_interface_library(DerivedAttributeOpInterface)
 add_mlir_interface_library(InferIntRangeInterface)
 add_mlir_interface_library(InferTypeOpInterface)
+add_mlir_interface_library(ParallelCombiningOpInterface)
 add_mlir_interface_library(SideEffectInterfaces)
 add_mlir_interface_library(TilingInterface)
 add_mlir_interface_library(VectorInterfaces)


@ -0,0 +1,27 @@
//===- ParallelCombiningOpInterface.cpp - Parallel combining op interface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// ParallelCombiningOpInterface
//===----------------------------------------------------------------------===//
// TODO: Single region single block interface on interfaces ?
LogicalResult mlir::detail::verifyParallelCombiningOpInterface(Operation *op) {
if (op->getNumRegions() != 1)
return op->emitError("expected single region op");
if (!op->getRegion(0).hasOneBlock())
return op->emitError("expected single block op region");
return success();
}
/// Include the definitions of the interface.
#include "mlir/Interfaces/ParallelCombiningOpInterface.cpp.inc"