forked from OSchip/llvm-project
[mlir][SCF] Add a ParallelCombiningOpInterface to decouple scf::PerformConcurrently from its contained operations
This allows purging references of scf.ForeachThreadOp and scf.PerformConcurrentlyOp from ParallelInsertSliceOp. This will allow moving the op closer to tensor::InsertSliceOp, with which it should share much more code. In the future, the decoupling will also allow extending the type of ops that can be used in the parallel combinator, as well as semantics related to multiple concurrent inserts to the same result. Differential Revision: https://reviews.llvm.org/D128857
This commit is contained in:
parent
6a57d8fba5
commit
b994d388ae
|
@ -18,6 +18,7 @@
|
||||||
#include "mlir/IR/RegionKindInterface.h"
|
#include "mlir/IR/RegionKindInterface.h"
|
||||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||||
#include "mlir/Interfaces/LoopLikeInterface.h"
|
#include "mlir/Interfaces/LoopLikeInterface.h"
|
||||||
|
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
|
||||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||||
#include "mlir/Interfaces/ViewLikeInterface.h"
|
#include "mlir/Interfaces/ViewLikeInterface.h"
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||||
include "mlir/Interfaces/LoopLikeInterface.td"
|
include "mlir/Interfaces/LoopLikeInterface.td"
|
||||||
include "mlir/IR/RegionKindInterface.td"
|
include "mlir/IR/RegionKindInterface.td"
|
||||||
|
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
|
||||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
include "mlir/Interfaces/ViewLikeInterface.td"
|
include "mlir/Interfaces/ViewLikeInterface.td"
|
||||||
|
|
||||||
|
@ -468,6 +469,7 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
|
||||||
def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
|
def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
|
||||||
NoSideEffect,
|
NoSideEffect,
|
||||||
Terminator,
|
Terminator,
|
||||||
|
DeclareOpInterfaceMethods<ParallelCombiningOpInterface>,
|
||||||
HasParent<"ForeachThreadOp">,
|
HasParent<"ForeachThreadOp">,
|
||||||
] # GraphRegionNoTerminator.traits> {
|
] # GraphRegionNoTerminator.traits> {
|
||||||
let summary = "terminates a `foreach_thread` block";
|
let summary = "terminates a `foreach_thread` block";
|
||||||
|
@ -495,8 +497,9 @@ def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
|
||||||
// TODO: Add a `PerformConcurrentlyOpInterface` interface for ops that can
|
// TODO: Add a `PerformConcurrentlyOpInterface` interface for ops that can
|
||||||
// appear inside perform_concurrently.
|
// appear inside perform_concurrently.
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
SmallVector<Type> yieldedTypes();
|
::llvm::SmallVector<::mlir::Type> getYieldedTypes();
|
||||||
::llvm::iterator_range<Block::iterator> yieldingOps();
|
::llvm::iterator_range<::mlir::Block::iterator> getYieldingOps();
|
||||||
|
::mlir::OpResult getParentResult(int64_t idx);
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -508,7 +511,9 @@ def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
|
||||||
def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
|
def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
|
||||||
AttrSizedOperandSegments,
|
AttrSizedOperandSegments,
|
||||||
OffsetSizeAndStrideOpInterface,
|
OffsetSizeAndStrideOpInterface,
|
||||||
HasParent<"PerformConcurrentlyOp">]> {
|
// TODO: Cannot use an interface here atm, verify this manually for now.
|
||||||
|
// HasParent<"ParallelCombiningOpInterface">
|
||||||
|
]> {
|
||||||
let summary = [{
|
let summary = [{
|
||||||
Specify the tensor slice update of a single thread within the terminator of
|
Specify the tensor slice update of a single thread within the terminator of
|
||||||
an `scf.foreach_thread`.
|
an `scf.foreach_thread`.
|
||||||
|
@ -568,6 +573,11 @@ def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
|
||||||
return getSource().getType().cast<RankedTensorType>();
|
return getSource().getType().cast<RankedTensorType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ParallelCombiningOpInterface getParallelCombiningParent() {
|
||||||
|
return dyn_cast<ParallelCombiningOpInterface>(
|
||||||
|
getOperation()->getParentOp());
|
||||||
|
}
|
||||||
|
|
||||||
/// Return the expected rank of each of the `static_offsets`, `static_sizes`
|
/// Return the expected rank of each of the `static_offsets`, `static_sizes`
|
||||||
/// and `static_strides` attributes.
|
/// and `static_strides` attributes.
|
||||||
std::array<unsigned, 3> getArrayAttrMaxRanks() {
|
std::array<unsigned, 3> getArrayAttrMaxRanks() {
|
||||||
|
@ -599,6 +609,7 @@ def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
|
||||||
|
|
||||||
let hasCanonicalizer = 1;
|
let hasCanonicalizer = 1;
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -6,6 +6,7 @@ add_mlir_interface(DerivedAttributeOpInterface)
|
||||||
add_mlir_interface(InferIntRangeInterface)
|
add_mlir_interface(InferIntRangeInterface)
|
||||||
add_mlir_interface(InferTypeOpInterface)
|
add_mlir_interface(InferTypeOpInterface)
|
||||||
add_mlir_interface(LoopLikeInterface)
|
add_mlir_interface(LoopLikeInterface)
|
||||||
|
add_mlir_interface(ParallelCombiningOpInterface)
|
||||||
add_mlir_interface(SideEffectInterfaces)
|
add_mlir_interface(SideEffectInterfaces)
|
||||||
add_mlir_interface(TilingInterface)
|
add_mlir_interface(TilingInterface)
|
||||||
add_mlir_interface(VectorInterfaces)
|
add_mlir_interface(VectorInterfaces)
|
||||||
|
|
|
@ -0,0 +1,29 @@
|
||||||
|
//===- ParallelCombiningOpInterface.h - Parallel combining op interface ---===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file implements the operation interface for ops that parallel combining
|
||||||
|
// operations.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_
|
||||||
|
#define MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_
|
||||||
|
|
||||||
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace detail {
|
||||||
|
// TODO: Single region single block interface on interfaces ?
|
||||||
|
LogicalResult verifyParallelCombiningOpInterface(Operation *op);
|
||||||
|
} // namespace detail
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
/// Include the generated interface declarations.
|
||||||
|
#include "mlir/Interfaces/ParallelCombiningOpInterface.h.inc"
|
||||||
|
|
||||||
|
#endif // MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_
|
|
@ -0,0 +1,75 @@
|
||||||
|
//===- ParallelCombiningOpInterface.td - Parallel iface ----*- tablegen -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Defines the interface for ops that perform parallel combining operations.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE
|
||||||
|
#define MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE
|
||||||
|
|
||||||
|
include "mlir/IR/OpBase.td"
|
||||||
|
|
||||||
|
def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
|
||||||
|
let description = [{
|
||||||
|
A parallel combining op is an op with a region, that is not isolated from
|
||||||
|
above and yields values to its parent op without itself returning an SSA
|
||||||
|
value. The yielded values are determined by subvalues produced by the ops
|
||||||
|
contained in the region (the `yieldingOps`) and combined in any unspecified
|
||||||
|
order to produce the values yielded to the parent op.
|
||||||
|
|
||||||
|
This is useful as a terminator to parallel operations that iterate over
|
||||||
|
some set and return tensors while avoiding tight coupling between the
|
||||||
|
iterating op, the combining op and the individual subtensor producing ops.
|
||||||
|
}];
|
||||||
|
let cppNamespace = "::mlir";
|
||||||
|
|
||||||
|
let methods = [
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Return `idx`^th result of the parent operation.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"::mlir::OpResult",
|
||||||
|
/*methodName=*/"getParentResult",
|
||||||
|
/*args=*/(ins "int64_t":$idx),
|
||||||
|
/*methodBody=*/[{
|
||||||
|
return $_op.getParentResult(idx);
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Return the contained ops that yield subvalues that this op combines to
|
||||||
|
yield to its parent.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"::llvm::iterator_range<Block::iterator>",
|
||||||
|
/*methodName=*/"getYieldingOps",
|
||||||
|
/*args=*/(ins),
|
||||||
|
/*methodBody=*/[{
|
||||||
|
return $_op.getYieldingOps();
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Return the contained ops that yield subvalues that this op combines to
|
||||||
|
yield to its parent.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"::llvm::SmallVector<::mlir::Type>",
|
||||||
|
/*methodName=*/"getYieldedTypes",
|
||||||
|
/*args=*/(ins),
|
||||||
|
/*methodBody=*/[{
|
||||||
|
return $_op.getYieldedTypes();
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
];
|
||||||
|
// TODO: Single region single block interface on interfaces ?
|
||||||
|
let verify = [{
|
||||||
|
return verifyParallelCombiningOpInterface($_op);
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE
|
|
@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSCFDialect
|
||||||
MLIRControlFlowDialect
|
MLIRControlFlowDialect
|
||||||
MLIRIR
|
MLIRIR
|
||||||
MLIRLoopLikeInterface
|
MLIRLoopLikeInterface
|
||||||
|
MLIRParallelCombiningOpInterface
|
||||||
MLIRSideEffectInterfaces
|
MLIRSideEffectInterfaces
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1061,7 +1061,7 @@ LogicalResult ForeachThreadOp::verify() {
|
||||||
return emitOpError("region expects ") << getRank() << " arguments";
|
return emitOpError("region expects ") << getRank() << " arguments";
|
||||||
|
|
||||||
// Verify consistency between the result types and the terminator.
|
// Verify consistency between the result types and the terminator.
|
||||||
auto terminatorTypes = getTerminator().yieldedTypes();
|
auto terminatorTypes = getTerminator().getYieldedTypes();
|
||||||
auto opResults = getResults();
|
auto opResults = getResults();
|
||||||
if (opResults.size() != terminatorTypes.size())
|
if (opResults.size() != terminatorTypes.size())
|
||||||
return emitOpError("produces ")
|
return emitOpError("produces ")
|
||||||
|
@ -1182,7 +1182,7 @@ void ForeachThreadOp::build(
|
||||||
llvm::dyn_cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
|
llvm::dyn_cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
|
||||||
assert(terminator &&
|
assert(terminator &&
|
||||||
"expected bodyBuilder to create PerformConcurrentlyOp terminator");
|
"expected bodyBuilder to create PerformConcurrentlyOp terminator");
|
||||||
result.addTypes(terminator.yieldedTypes());
|
result.addTypes(terminator.getYieldedTypes());
|
||||||
}
|
}
|
||||||
|
|
||||||
// The ensureTerminator method generated by SingleBlockImplicitTerminator is
|
// The ensureTerminator method generated by SingleBlockImplicitTerminator is
|
||||||
|
@ -1216,15 +1216,15 @@ ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
OpResult ParallelInsertSliceOp::getTiedOpResult() {
|
OpResult ParallelInsertSliceOp::getTiedOpResult() {
|
||||||
auto foreachThreadOp = getOperation()->getParentOfType<ForeachThreadOp>();
|
ParallelCombiningOpInterface parallelCombiningParent =
|
||||||
assert(foreachThreadOp && "unlinked ParallelInsertSliceOp");
|
getParallelCombiningParent();
|
||||||
PerformConcurrentlyOp performConcurrentlyOp = foreachThreadOp.getTerminator();
|
for (const auto &it :
|
||||||
for (const auto &it : llvm::enumerate(performConcurrentlyOp.yieldingOps())) {
|
llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
|
||||||
Operation &nextOp = it.value();
|
Operation &nextOp = it.value();
|
||||||
if (&nextOp == getOperation())
|
if (&nextOp == getOperation())
|
||||||
return foreachThreadOp->getResult(it.index());
|
return parallelCombiningParent.getParentResult(it.index());
|
||||||
}
|
}
|
||||||
llvm_unreachable("ParallelInsertSliceOp not found");
|
llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
|
// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
|
||||||
|
@ -1262,6 +1262,13 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
|
||||||
build(b, result, source, dest, offsetValues, sizeValues, strideValues);
|
build(b, result, source, dest, offsetValues, sizeValues, strideValues);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult ParallelInsertSliceOp::verify() {
|
||||||
|
if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
|
||||||
|
return this->emitError("expected ParallelCombiningOpInterface parent, got:")
|
||||||
|
<< *(getOperation()->getParentOp());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
/// Pattern to rewrite a parallel_insert_slice op with constant arguments.
|
/// Pattern to rewrite a parallel_insert_slice op with constant arguments.
|
||||||
class ParallelInsertSliceOpConstantArgumentFolder final
|
class ParallelInsertSliceOpConstantArgumentFolder final
|
||||||
|
@ -1382,15 +1389,19 @@ ParseResult PerformConcurrentlyOp::parse(OpAsmParser &parser,
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Type> PerformConcurrentlyOp::yieldedTypes() {
|
OpResult PerformConcurrentlyOp::getParentResult(int64_t idx) {
|
||||||
|
return getOperation()->getParentOp()->getResult(idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Type> PerformConcurrentlyOp::getYieldedTypes() {
|
||||||
return llvm::to_vector<4>(
|
return llvm::to_vector<4>(
|
||||||
llvm::map_range(this->yieldingOps(), [](Operation &op) {
|
llvm::map_range(getYieldingOps(), [](Operation &op) {
|
||||||
auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(&op);
|
auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(&op);
|
||||||
return insertSliceOp ? insertSliceOp.yieldedType() : Type();
|
return insertSliceOp ? insertSliceOp.yieldedType() : Type();
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::iterator_range<Block::iterator> PerformConcurrentlyOp::yieldingOps() {
|
llvm::iterator_range<Block::iterator> PerformConcurrentlyOp::getYieldingOps() {
|
||||||
return getRegion().front().getOperations();
|
return getRegion().front().getOperations();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1043,8 +1043,7 @@ struct ParallelInsertSliceOpInterface
|
||||||
if (&opOperand != &op->getOpOperand(1) /*dest*/)
|
if (&opOperand != &op->getOpOperand(1) /*dest*/)
|
||||||
return {};
|
return {};
|
||||||
|
|
||||||
// ParallelInsertSliceOp itself has no results. Tensors are returned via
|
// ParallelInsertSliceOp itself has no results, query its tied op results.
|
||||||
// the parent op.
|
|
||||||
auto insertOp = cast<ParallelInsertSliceOp>(op);
|
auto insertOp = cast<ParallelInsertSliceOp>(op);
|
||||||
return {insertOp.getTiedOpResult()};
|
return {insertOp.getTiedOpResult()};
|
||||||
}
|
}
|
||||||
|
@ -1090,8 +1089,10 @@ struct ParallelInsertSliceOpInterface
|
||||||
// }
|
// }
|
||||||
|
|
||||||
OpBuilder::InsertionGuard g(rewriter);
|
OpBuilder::InsertionGuard g(rewriter);
|
||||||
auto insertOp = cast<ParallelInsertSliceOp>(op);
|
auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
|
||||||
auto foreachThreadOp = insertOp->getParentOfType<ForeachThreadOp>();
|
ParallelCombiningOpInterface parallelCombiningParent =
|
||||||
|
parallelInsertSliceOp.getParallelCombiningParent();
|
||||||
|
Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
|
||||||
|
|
||||||
// Nothing to do if the destination tensor is inplace.
|
// Nothing to do if the destination tensor is inplace.
|
||||||
assert(state.isInPlace(op->getOpOperand(0) /*src*/) &&
|
assert(state.isInPlace(op->getOpOperand(0) /*src*/) &&
|
||||||
|
@ -1100,20 +1101,21 @@ struct ParallelInsertSliceOpInterface
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
// Find corresponding OpResult.
|
// Find corresponding OpResult.
|
||||||
OpResult opResult = insertOp.getTiedOpResult();
|
OpResult opResult = parallelInsertSliceOp.getTiedOpResult();
|
||||||
|
|
||||||
// Insert tensor allocation right before the ForeachThreadOp.
|
// Insert tensor allocation right before the ForeachThreadOp.
|
||||||
rewriter.setInsertionPoint(foreachThreadOp);
|
rewriter.setInsertionPoint(parallelIteratingOp);
|
||||||
bool isYielded = state.isTensorYielded(opResult);
|
bool isYielded = state.isTensorYielded(opResult);
|
||||||
FailureOr<Value> alloc =
|
FailureOr<Value> alloc = allocateTensorForShapedValue(
|
||||||
allocateTensorForShapedValue(rewriter, op->getLoc(), insertOp.getDest(),
|
rewriter, op->getLoc(), parallelInsertSliceOp.getDest(),
|
||||||
/*escape=*/isYielded, state.getOptions());
|
/*escape=*/isYielded, state.getOptions());
|
||||||
if (failed(alloc))
|
if (failed(alloc))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Update destination operand.
|
// Update destination operand.
|
||||||
rewriter.updateRootInPlace(
|
rewriter.updateRootInPlace(parallelInsertSliceOp, [&]() {
|
||||||
insertOp, [&]() { insertOp.getDestMutable().assign(*alloc); });
|
parallelInsertSliceOp.getDestMutable().assign(*alloc);
|
||||||
|
});
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -1121,39 +1123,41 @@ struct ParallelInsertSliceOpInterface
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
const BufferizationOptions &options) const {
|
const BufferizationOptions &options) const {
|
||||||
OpBuilder::InsertionGuard g(rewriter);
|
OpBuilder::InsertionGuard g(rewriter);
|
||||||
auto insertOp = cast<ParallelInsertSliceOp>(op);
|
auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
|
||||||
auto performConcurrentlyOp = cast<PerformConcurrentlyOp>(op->getParentOp());
|
ParallelCombiningOpInterface parallelCombiningParent =
|
||||||
auto foreachThreadOp =
|
parallelInsertSliceOp.getParallelCombiningParent();
|
||||||
cast<ForeachThreadOp>(performConcurrentlyOp->getParentOp());
|
Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
|
||||||
|
|
||||||
// Get destination buffer.
|
// Get destination buffer.
|
||||||
FailureOr<Value> destBuffer =
|
FailureOr<Value> destBuffer =
|
||||||
getBuffer(rewriter, insertOp.getDest(), options);
|
getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
|
||||||
if (failed(destBuffer))
|
if (failed(destBuffer))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Bufferize the ParallelInsertSliceOp outside of the PerformConcurrentlyOp.
|
// Bufferize the ParallelInsertSliceOp outside of `parallelCombiningParent`.
|
||||||
rewriter.setInsertionPoint(performConcurrentlyOp);
|
rewriter.setInsertionPoint(parallelCombiningParent);
|
||||||
FailureOr<Value> srcBuffer =
|
FailureOr<Value> srcBuffer =
|
||||||
getBuffer(rewriter, insertOp.getSource(), options);
|
getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
|
||||||
if (failed(srcBuffer))
|
if (failed(srcBuffer))
|
||||||
return failure();
|
return failure();
|
||||||
Value subview = rewriter.create<memref::SubViewOp>(
|
Value subview = rewriter.create<memref::SubViewOp>(
|
||||||
insertOp.getLoc(), *destBuffer, insertOp.getMixedOffsets(),
|
parallelInsertSliceOp.getLoc(), *destBuffer,
|
||||||
insertOp.getMixedSizes(), insertOp.getMixedStrides());
|
parallelInsertSliceOp.getMixedOffsets(),
|
||||||
|
parallelInsertSliceOp.getMixedSizes(),
|
||||||
|
parallelInsertSliceOp.getMixedStrides());
|
||||||
// This memcpy will fold away if everything bufferizes in-place.
|
// This memcpy will fold away if everything bufferizes in-place.
|
||||||
if (failed(options.createMemCpy(rewriter, insertOp.getLoc(), *srcBuffer,
|
if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
|
||||||
subview)))
|
*srcBuffer, subview)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Replace all uses of ForeachThreadOp (just the corresponding result).
|
// Replace all uses of parallelIteratingOp (just the corresponding result).
|
||||||
rewriter.setInsertionPointAfter(foreachThreadOp);
|
rewriter.setInsertionPointAfter(parallelIteratingOp);
|
||||||
Value toTensorOp =
|
Value toTensorOp =
|
||||||
rewriter.create<ToTensorOp>(foreachThreadOp.getLoc(), *destBuffer);
|
rewriter.create<ToTensorOp>(parallelIteratingOp->getLoc(), *destBuffer);
|
||||||
// PerformConcurrentlyOp can have multiple ParallelInsertSliceOps.
|
// PerformConcurrentlyOp can have multiple ParallelInsertSliceOps.
|
||||||
SmallVector<OpOperand *> resultUses =
|
SmallVector<OpOperand *> resultUses = llvm::to_vector(
|
||||||
llvm::to_vector(llvm::map_range(insertOp.getTiedOpResult().getUses(),
|
llvm::map_range(parallelInsertSliceOp.getTiedOpResult().getUses(),
|
||||||
[](OpOperand &use) { return &use; }));
|
[](OpOperand &use) { return &use; }));
|
||||||
for (OpOperand *use : resultUses) {
|
for (OpOperand *use : resultUses) {
|
||||||
rewriter.updateRootInPlace(use->getOwner(),
|
rewriter.updateRootInPlace(use->getOwner(),
|
||||||
[&]() { use->set(toTensorOp); });
|
[&]() { use->set(toTensorOp); });
|
||||||
|
|
|
@ -8,6 +8,7 @@ set(LLVM_OPTIONAL_SOURCES
|
||||||
InferIntRangeInterface.cpp
|
InferIntRangeInterface.cpp
|
||||||
InferTypeOpInterface.cpp
|
InferTypeOpInterface.cpp
|
||||||
LoopLikeInterface.cpp
|
LoopLikeInterface.cpp
|
||||||
|
ParallelCombiningOpInterface.cpp
|
||||||
SideEffectInterfaces.cpp
|
SideEffectInterfaces.cpp
|
||||||
TilingInterface.cpp
|
TilingInterface.cpp
|
||||||
VectorInterfaces.cpp
|
VectorInterfaces.cpp
|
||||||
|
@ -38,6 +39,7 @@ add_mlir_interface_library(DataLayoutInterfaces)
|
||||||
add_mlir_interface_library(DerivedAttributeOpInterface)
|
add_mlir_interface_library(DerivedAttributeOpInterface)
|
||||||
add_mlir_interface_library(InferIntRangeInterface)
|
add_mlir_interface_library(InferIntRangeInterface)
|
||||||
add_mlir_interface_library(InferTypeOpInterface)
|
add_mlir_interface_library(InferTypeOpInterface)
|
||||||
|
add_mlir_interface_library(ParallelCombiningOpInterface)
|
||||||
add_mlir_interface_library(SideEffectInterfaces)
|
add_mlir_interface_library(SideEffectInterfaces)
|
||||||
add_mlir_interface_library(TilingInterface)
|
add_mlir_interface_library(TilingInterface)
|
||||||
add_mlir_interface_library(VectorInterfaces)
|
add_mlir_interface_library(VectorInterfaces)
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
//===- ParallelCombiningOpInterface.cpp - Parallel combining op interface -===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ParallelCombiningOpInterface
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// TODO: Single region single block interface on interfaces ?
|
||||||
|
LogicalResult mlir::detail::verifyParallelCombiningOpInterface(Operation *op) {
|
||||||
|
if (op->getNumRegions() != 1)
|
||||||
|
return op->emitError("expected single region op");
|
||||||
|
if (!op->getRegion(0).hasOneBlock())
|
||||||
|
return op->emitError("expected single block op region");
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Include the definitions of the interface.
|
||||||
|
#include "mlir/Interfaces/ParallelCombiningOpInterface.cpp.inc"
|
Loading…
Reference in New Issue