[mlir][SCF] foreach_thread: Capture shared output tensors explicitly

This change refines the semantics of scf.foreach_thread. Tensors into which the terminator inserts results must now be passed to the region explicitly via `shared_outs`. Inside the body of the op, these tensors are accessed via their corresponding block arguments.
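A minimal example of the new form (a sketch adapted from the test cases updated in this revision; %in, %out, %num_threads, and the slice sizes are illustrative):

%res = scf.foreach_thread (%tid) in (%num_threads) shared_outs(%o = %out) -> (tensor<100xf32>) {
  %1 = tensor.extract_slice %in[%tid] [1] [1] : tensor<100xf32> to tensor<1xf32>
  scf.foreach_thread.perform_concurrently {
    // Insertions go into the block argument %o, not into %out directly.
    tensor.parallel_insert_slice %1 into %o[%tid] [1] [1] :
      tensor<1xf32> into tensor<100xf32>
  }
}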

The body of an scf.foreach_thread is now treated as a repetitive region. That is, op dominance can no longer be used for conflict detection when a value defined outside of the body is used inside it. Such uses may now be considered conflicts (if there is at least one read and one write of the value in the body), effectively privatizing the tensor. Shared outputs are not privatized as long as they are accessed via their corresponding block arguments.
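For illustration, consider this sketch (adapted from the updated One-Shot Bufferize test below; %t, %off, and %c2 are assumed to be defined above the loop):

%0 = scf.foreach_thread (%tid) in (%c2) shared_outs(%o = %t) -> tensor<10xf32> {
  %slice = tensor.extract_slice %o[%off] [5] [1] : tensor<10xf32> to tensor<5xf32>
  // Direct use of %t (not via %o): this read targets the same future buffer
  // that the parallel_insert_slice below writes. Op dominance cannot rule out
  // the conflict in a repetitive region, so %t is privatized (copied).
  %v = tensor.extract %t[%tid] : tensor<10xf32>
  %i = tensor.insert %v into %slice[%c2] : tensor<5xf32>
  scf.foreach_thread.perform_concurrently {
    tensor.parallel_insert_slice %i into %o[%off] [5] [1]
      : tensor<5xf32> into tensor<10xf32>
  }
}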

As part of this change, it was also necessary to update the "tiling to scf.foreach_thread" transformation so that the generated tensor.extract_slice ops use the scf.foreach_thread op's block arguments. This is implemented by cloning the TilingInterface op inside the scf.foreach_thread, rewriting all of its outputs to the block arguments, and then calling the tiling implementation. Afterwards, the cloned op is deleted.
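The generated IR then looks roughly as follows (a schematic sketch modeled on the tile-to-foreach-thread tests below; the operands and the slice offsets/sizes, written as "...", are illustrative):

%res = scf.foreach_thread (%i, %j) in (%nt0, %nt1) shared_outs(%o = %C) -> (tensor<?x?xf32>) {
  %tA = tensor.extract_slice %A[...] : tensor<?x?xf32> to tensor<?x?xf32>
  %tB = tensor.extract_slice %B[...] : tensor<?x?xf32> to tensor<?x?xf32>
  // The destination slice is taken from the block argument %o, not from %C.
  %tC = tensor.extract_slice %o[...] : tensor<?x?xf32> to tensor<?x?xf32>
  %r = linalg.matmul ins(%tA, %tB : tensor<?x?xf32>, tensor<?x?xf32>)
                     outs(%tC : tensor<?x?xf32>) -> tensor<?x?xf32>
  scf.foreach_thread.perform_concurrently {
    tensor.parallel_insert_slice %r into %o[...] : tensor<?x?xf32> into tensor<?x?xf32>
  }
}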

Differential Revision: https://reviews.llvm.org/D133114
Matthias Springer 2022-09-02 14:48:35 +02:00
parent f7f0c7f7e3
commit 4cd7362083
15 changed files with 421 additions and 291 deletions


@ -324,6 +324,7 @@ def ForOp : SCF_Op<"for",
//===----------------------------------------------------------------------===//
def ForeachThreadOp : SCF_Op<"foreach_thread", [
AttrSizedOperandSegments,
SingleBlockImplicitTerminator<"scf::PerformConcurrentlyOp">,
RecursiveSideEffects,
AutomaticAllocationScope,
@ -335,6 +336,17 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
parallel body and it takes index operands that indicate how many parallel
instances of that function are created.
The op also takes a variadic number of tensor operands (`shared_outs`).
The future buffers corresponding to these tensors are shared among all
threads. Shared tensors should be accessed via their corresponding block
arguments. If multiple threads write to a shared buffer in a racy
fashion, these writes will execute in some unspecified order. Tensors that
are not shared can be used inside the body (i.e., the op is not isolated
from above); however, if a use of such a tensor bufferizes to a memory
write, the tensor is privatized, i.e., a thread-local copy of the tensor is
used. This ensures that memory side effects of a thread are not visible to
other threads (or in the parent body), apart from explicitly shared tensors.
The name "thread" conveys the fact that the parallel execution is mapped
(i.e. distributed) to a set of virtual threads of execution, one function
application per thread. Further lowerings are responsible for specifying
@ -349,26 +361,20 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
context of the concrete target the op is lowered to, or to ignore it when
the specification is ill-formed or unsupported for a particular target.
The only allowed terminator is `scf.foreach_thread.perform_concurrently`,
which dictates how the partial results of all parallel invocations should be
reconciled into a full value.
The only allowed terminator is `scf.foreach_thread.perform_concurrently`.
`scf.foreach_thread` returns one value per `shared_out` operand. The
actions of the `perform_concurrently` terminators specify how to combine the
partial results of all parallel invocations into a full value, in some
unspecified order. The "destination" of each such op must be a `shared_out`
block argument of the `scf.foreach_thread` op.
`scf.foreach_thread` returns values that are formed by aggregating the
actions of all the `perform_concurrently` terminator of all the virtual
threads, in some unspecified order.
In other words, `scf.foreach_thread` performs all actions specified in the
`perform_concurrently` terminator, after it receives the control back from
its body along each virtual thread of execution.
The actions involved in constructing the return values are further described
by [parallel_insert_slice](#parallelinsertslice-parallelinsertsliceop).
by `tensor.parallel_insert_slice`.
`scf.foreach_thread` acts as an implicit synchronization point.
Multi-value returns are encoded by including multiple operations inside the
`perform_concurrently` block.
When the parallel function body has side effects, the order of reads and
writes to memory is unspecified across threads.
When the parallel function body has side effects, their order is unspecified
across threads.
Example:
@ -377,7 +383,8 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
// Sequential context.
//
%matmul_and_pointwise:2 = scf.foreach_thread (%thread_id_1, %thread_id_2) in
(%num_threads_1, %numthread_id_2) -> (tensor<?x?xT>, tensor<?xT>) {
(%num_threads_1, %numthread_id_2) shared_outs(%o1 = %C, %o2 = %pointwise)
-> (tensor<?x?xT>, tensor<?xT>) {
//
// Parallel context, each thread with id = (%thread_id_1, %thread_id_2)
// runs its version of the code.
@ -386,21 +393,19 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
tensor<?x?xT> to tensor<?x?xT>
%sB = tensor.extract_slice %B[g((%thread_id_1, %thread_id_2))]:
tensor<?x?xT> to tensor<?x?xT>
%sC = tensor.extract_slice %C[h((%thread_id_1, %thread_id_2))]:
%sC = tensor.extract_slice %o1[h((%thread_id_1, %thread_id_2))]:
tensor<?x?xT> to tensor<?x?xT>
%sD = matmul ins(%sA, %sB) outs(%sC)
%spointwise = subtensor %pointwise[i((%thread_id_1, %thread_id_2))]:
%spointwise = subtensor %o2[i((%thread_id_1, %thread_id_2))]:
tensor<?xT> to tensor<?xT>
%sE = add ins(%spointwise) outs(%sD)
scf.foreach_thread.perform_concurrently {
// First op within the parallel terminator contributes to producing %matmul_and_pointwise#0.
scf.foreach_thread.parallel_insert_slice %sD into %C[h((%thread_id_1, %thread_id_2))]:
scf.foreach_thread.parallel_insert_slice %sD into %o1[h((%thread_id_1, %thread_id_2))]:
tensor<?x?xT> into tensor<?x?xT>
// Second op within the parallel terminator contributes to producing %matmul_and_pointwise#1.
scf.foreach_thread.parallel_insert_slice %spointwise into %pointwise[i((%thread_id_1, %thread_id_2))]:
scf.foreach_thread.parallel_insert_slice %spointwise into %o2[i((%thread_id_1, %thread_id_2))]:
tensor<?xT> into tensor<?xT>
}
}
@ -414,7 +419,8 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
// Sequential context.
//
%matmul_and_pointwise:2 = scf.foreach_thread (%thread_id_1, %thread_id_2) in
(%num_threads_1, %numthread_id_2) -> (tensor<?x?xT>, tensor<?xT>) {
(%num_threads_1, %numthread_id_2) shared_outs(...)
-> (tensor<?x?xT>, tensor<?xT>) {
//
// Parallel context, each thread with id = **(%thread_id_2, %thread_id_1)**
// runs its version of the code.
@ -426,9 +432,23 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
// Implicit synchronization point.
// Sequential context.
//
Example with privatized tensors:
%t0 = ...
%t1 = ...
%r = scf.foreach_thread ... shared_outs(%o = %t0) -> tensor<?xf32> {
// %t0 and %t1 are privatized. %t0 is definitely copied for each thread
// because the scf.foreach_thread op's %t0 use bufferizes to a memory
// write. In the absence of other conflicts, %t1 is copied only if there
// are uses of %t1 in the body that bufferize to a memory read and to a
// memory write.
"some_use"(%t0)
"some_use"(%t1)
}
}];
let arguments = (ins Variadic<Index>:$num_threads,
DefaultValuedAttr<I64ArrayAttr, "{}">:$thread_dim_mapping);
Variadic<AnyRankedTensor>:$outputs,
DefaultValuedAttr<I64ArrayAttr, "{}">:$thread_dim_mapping);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$region);
@ -439,19 +459,48 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
// The default builder does not add the proper body BBargs, roll our own.
let skipDefaultBuilders = 1;
let builders = [
// Bodyless builder, result types must be specified.
OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$num_threads,
// Bodyless builder, outputs must be specified.
OpBuilder<(ins "ValueRange":$outputs, "ValueRange":$num_threads,
CArg<"ArrayRef<int64_t>", "{}">:$thread_dim_mapping)>,
// Builder that takes a bodyBuilder lambda, result types are inferred from
// the terminator.
OpBuilder<(ins "ValueRange":$num_threads,
// Builder that takes a bodyBuilder lambda.
OpBuilder<(ins "ValueRange":$outputs, "ValueRange":$num_threads,
"ArrayRef<int64_t>":$thread_dim_mapping,
"function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuilder)>
];
let extraClassDeclaration = [{
int64_t getRank() { return getNumThreads().size(); }
::mlir::ValueRange getThreadIndices() { return getBody()->getArguments(); }
::mlir::Value getThreadIndex(int64_t idx) { return getBody()->getArgument(idx); }
OpResult getTiedOpResult(OpOperand *opOperand) {
assert(opOperand->getOperandNumber() >= getRank() && "invalid operand");
return getOperation()->getOpResult(
opOperand->getOperandNumber() - getRank());
}
OpOperand *getTiedOpOperand(BlockArgument bbArg) {
assert(bbArg.getArgNumber() >= getRank() && "invalid bbArg");
return &getOperation()->getOpOperand(bbArg.getArgNumber());
}
BlockArgument getTiedBlockArgument(OpOperand *opOperand) {
assert(opOperand->getOperandNumber() >= getRank() && "invalid operand");
return getBody()->getArgument(opOperand->getOperandNumber());
}
ArrayRef<BlockArgument> getOutputBlockArguments() {
return getBody()->getArguments().drop_front(getRank());
}
::mlir::ValueRange getThreadIndices() {
return getBody()->getArguments().take_front(getRank());
}
::mlir::Value getThreadIndex(int64_t idx) {
return getThreadIndices()[idx];
}
::mlir::Block::BlockArgListType getRegionOutArgs() {
return getBody()->getArguments().drop_front(getRank());
}
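As a hypothetical usage sketch (not part of this change), the tied entities of the i-th shared output can be navigated as follows, assuming `foreachThreadOp` and `i` are in scope:

// Operand #(rank + i) is the i-th shared output.
OpOperand &outOperand =
    foreachThreadOp->getOpOperand(foreachThreadOp.getRank() + i);
// Tied block argument: body argument #(rank + i).
BlockArgument bbArg = foreachThreadOp.getTiedBlockArgument(&outOperand);
// Tied result: result #i of the scf.foreach_thread op.
OpResult result = foreachThreadOp.getTiedOpResult(&outOperand);
// And back from the block argument to the output operand.
OpOperand *sameOperand = foreachThreadOp.getTiedOpOperand(bbArg);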
// The ensureTerminator method generated by SingleBlockImplicitTerminator is
// unaware of the fact that our terminator also needs a region to be
@ -497,7 +546,7 @@ def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
// TODO: Add a `PerformConcurrentlyOpInterface` interface for ops that can
// appear inside perform_concurrently.
let extraClassDeclaration = [{
::llvm::SmallVector<::mlir::Type> getYieldedTypes();
::llvm::SmallVector<::mlir::BlockArgument> getDests();
::llvm::iterator_range<::mlir::Block::iterator> getYieldingOps();
::mlir::OpResult getParentResult(int64_t idx);
}];


@ -17,11 +17,7 @@ include "mlir/IR/OpBase.td"
def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
let description = [{
A parallel combining op is an op with a region, that is not isolated from
above and yields values to its parent op without itself returning an SSA
value. The yielded values are determined by subvalues produced by the ops
contained in the region (the `yieldingOps`) and combined in any unspecified
order to produce the values yielded to the parent op.
A parallel combining op is an op with a region.
This is useful as a terminator to parallel operations that iterate over
some set and return tensors while avoiding tight coupling between the
@ -53,18 +49,6 @@ def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
return $_op.getYieldingOps();
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the contained ops that yield subvalues that this op combines to
yield to its parent.
}],
/*retTy=*/"::llvm::SmallVector<::mlir::Type>",
/*methodName=*/"getYieldedTypes",
/*args=*/(ins),
/*methodBody=*/[{
return $_op.getYieldedTypes();
}]
>,
];
// TODO: Single region single block interface on interfaces ?
let verify = [{


@ -235,8 +235,8 @@ static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
if (llvm::any_of(loopRanges, hasStrideOne))
return op->emitOpError("only stride-1 supported atm");
// TODO: support `getTiledImplementation` with >1 produced tiled ops.
auto destOperands = op.getDestinationOperands(b);
if (destOperands.size() != 1)
auto dest = op.getDestinationOperands(b);
if (dest.size() != 1)
return op->emitOpError("only single dest operand supported atm");
SmallVector<OpFoldResult> nonZeroNumThreads =
@ -255,8 +255,7 @@ static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
// version because we require the use of RewriterBase in the body, so we
// manually move the insertion point to the body below.
scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
loc, op->getResultTypes(), ValueRange(materializedNonZeroNumThreads),
threadDimMapping);
loc, dest, ValueRange(materializedNonZeroNumThreads), threadDimMapping);
// Fill out the ForeachThreadOp body.
b.setInsertionPointToStart(foreachThreadOp.getBody(0));
@ -317,17 +316,34 @@ static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
++threadIdIdx;
}
// Clone the tileable op and update its destination operands to use the output
// bbArgs of the ForeachThreadOp.
ArrayRef<BlockArgument> destBbArgs =
foreachThreadOp.getOutputBlockArguments();
Operation *clonedOp = b.clone(*op.getOperation());
auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
if (destinationStyleOp) {
for (OpOperand *outOperand : destinationStyleOp.getOutputOperands()) {
auto it = llvm::find(dest, outOperand->get());
assert(it != dest.end() && "dest operand not found in dest");
unsigned destNum = std::distance(dest.begin(), it);
outOperand->set(destBbArgs[destNum]);
}
}
// Tile the cloned op and delete the clone.
SmallVector<Operation *> tiledOps =
op.getTiledImplementation(b, tiledOffsets, tiledSizes);
cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
tiledSizes);
b.eraseOp(clonedOp);
assert(tiledOps.size() == 1 && "expected a single produced tiled op");
tiledOp = tiledOps.front();
auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
assert(tilingInterfaceOp && "Tiled op does not implement TilingInterface");
OpBuilder::InsertPoint insertPt = b.saveInsertionPoint();
for (auto it :
llvm::zip(llvm::seq(unsigned(0), unsigned(destOperands.size())),
tilingInterfaceOp->getResults(), destOperands)) {
for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())),
tilingInterfaceOp->getResults(), destBbArgs)) {
b.setInsertionPoint(insertPt.getBlock(), insertPt.getPoint());
SmallVector<OpFoldResult> resultOffsets, resultSizes;
if (failed(op.getResultTilePosition(b, std::get<0>(it), tiledOffsets,


@ -1055,26 +1055,25 @@ LogicalResult ForeachThreadOp::verify() {
if (failed(getTerminator().verify()))
return failure();
// Check that the body defines as single block argument for the thread index.
auto *body = getBody();
if (body->getNumArguments() != getRank())
return emitOpError("region expects ") << getRank() << " arguments";
// Verify consistency between the result types and the terminator.
auto terminatorTypes = getTerminator().getYieldedTypes();
auto opResults = getResults();
if (opResults.size() != terminatorTypes.size())
// Check number of outputs.
if (getNumResults() != getOutputs().size())
return emitOpError("produces ")
<< opResults.size() << " results, but its terminator yields "
<< terminatorTypes.size() << " value(s)";
unsigned i = 0;
for (auto e : llvm::zip(terminatorTypes, opResults)) {
if (std::get<0>(e) != std::get<1>(e).getType())
return emitOpError() << "type mismatch between result " << i << " ("
<< std::get<1>(e).getType() << ") and terminator ("
<< std::get<0>(e) << ")";
i++;
}
<< getNumResults() << " results, but has only "
<< getOutputs().size() << " outputs";
// Check that the body defines block arguments for thread indices and outputs.
auto *body = getBody();
if (body->getNumArguments() != getRank() + getOutputs().size())
return emitOpError("region expects ") << getRank() << " arguments";
for (int64_t i = 0; i < getRank(); ++i)
if (!body->getArgument(i).getType().isIndex())
return emitOpError("expects ")
<< i << "-th block argument to be an index";
for (unsigned i = 0; i < getOutputs().size(); ++i)
if (body->getArgument(i + getRank()).getType() != getOutputs()[i].getType())
return emitOpError("type mismatch between ")
<< i << "-th output and corresponding block argument";
return success();
}
@ -1083,11 +1082,16 @@ void ForeachThreadOp::print(OpAsmPrinter &p) {
llvm::interleaveComma(getThreadIndices(), p);
p << ") in (";
llvm::interleaveComma(getNumThreads(), p);
p << ") -> (" << getResultTypes() << ") ";
p << ")";
printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
p << " ";
if (!getRegionOutArgs().empty())
p << "-> (" << getResultTypes() << ") ";
p.printRegion(getRegion(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/getNumResults() > 0);
p.printOptionalAttrDict(getOperation()->getAttrs());
p.printOptionalAttrDict(getOperation()->getAttrs(),
{"operand_segment_sizes"});
}
ParseResult ForeachThreadOp::parse(OpAsmParser &parser,
@ -1109,15 +1113,34 @@ ParseResult ForeachThreadOp::parse(OpAsmParser &parser,
result.operands))
return failure();
// Parse optional results.
if (parser.parseOptionalArrowTypeList(result.types))
return failure();
// Parse out operands and results.
SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
SMLoc outOperandsLoc = parser.getCurrentLocation();
if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
if (outOperands.size() != result.types.size())
return parser.emitError(outOperandsLoc,
"mismatch between out operands and types");
if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
parser.parseOptionalArrowTypeList(result.types) ||
parser.resolveOperands(outOperands, result.types, outOperandsLoc,
result.operands))
return failure();
}
// Parse region.
SmallVector<OpAsmParser::Argument, 4> regionArgs;
std::unique_ptr<Region> region = std::make_unique<Region>();
for (auto &idx : threadIndices)
for (auto &idx : threadIndices) {
idx.type = builder.getIndexType();
if (parser.parseRegion(*region, threadIndices))
regionArgs.push_back(idx);
}
for (const auto &it : llvm::enumerate(regionOutArgs)) {
auto &out = it.value();
out.type = result.types[it.index()];
regionArgs.push_back(out);
}
if (parser.parseRegion(*region, regionArgs))
return failure();
// Ensure terminator and move region.
@ -1128,19 +1151,27 @@ ParseResult ForeachThreadOp::parse(OpAsmParser &parser,
// Parse the optional attribute list.
if (parser.parseOptionalAttrDict(result.attributes))
return failure();
result.addAttribute("operand_segment_sizes",
parser.getBuilder().getDenseI32ArrayAttr(
{static_cast<int32_t>(threadNums.size()),
static_cast<int32_t>(outOperands.size())}));
return success();
}
// Bodyless builder, result types must be specified.
// Bodyless builder, outputs must be specified.
void ForeachThreadOp::build(mlir::OpBuilder &builder,
mlir::OperationState &result, TypeRange resultTypes,
mlir::OperationState &result, ValueRange outputs,
ValueRange numThreads,
ArrayRef<int64_t> threadDimMapping) {
result.addOperands(numThreads);
result.addOperands(outputs);
result.addAttribute(ForeachThreadOp::getThreadDimMappingAttrName(result.name),
builder.getI64ArrayAttr(threadDimMapping));
result.addAttribute(
// TODO: getThreadDimMappingAttrName() but it is not a static member.
"thread_dim_mapping", builder.getI64ArrayAttr(threadDimMapping));
"operand_segment_sizes",
builder.getDenseI32ArrayAttr({static_cast<int32_t>(numThreads.size()),
static_cast<int32_t>(outputs.size())}));
result.addTypes(TypeRange(outputs));
Region *bodyRegion = result.addRegion();
OpBuilder::InsertionGuard g(builder);
@ -1149,40 +1180,51 @@ void ForeachThreadOp::build(mlir::OpBuilder &builder,
// expects it ..
builder.createBlock(bodyRegion);
Block &bodyBlock = bodyRegion->front();
// Add block arguments for indices and outputs.
bodyBlock.addArguments(
SmallVector<Type>(numThreads.size(), builder.getIndexType()),
SmallVector<Location>(numThreads.size(), result.location));
bodyBlock.addArguments(
TypeRange(outputs),
SmallVector<Location>(outputs.size(), result.location));
ForeachThreadOp::ensureTerminator(*bodyRegion, builder, result.location);
result.addTypes(resultTypes);
}
// Builder that takes a bodyBuilder lambda, result types are inferred from
// the terminator.
// Builder that takes a bodyBuilder lambda.
void ForeachThreadOp::build(
mlir::OpBuilder &builder, mlir::OperationState &result,
mlir::OpBuilder &builder, mlir::OperationState &result, ValueRange outputs,
ValueRange numThreads, ArrayRef<int64_t> threadDimMapping,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
result.addOperands(numThreads);
result.addOperands(outputs);
result.addAttribute(ForeachThreadOp::getThreadDimMappingAttrName(result.name),
builder.getI64ArrayAttr(threadDimMapping));
result.addAttribute(
// TODO: getThreadDimMappingAttrName() but it is not a static member.
"thread_dim_mapping", builder.getI64ArrayAttr(threadDimMapping));
"operand_segment_sizes",
builder.getDenseI32ArrayAttr({static_cast<int32_t>(numThreads.size()),
static_cast<int32_t>(outputs.size())}));
result.addTypes(TypeRange(outputs));
OpBuilder::InsertionGuard g(builder);
Region *bodyRegion = result.addRegion();
OpBuilder::InsertionGuard g(builder);
builder.createBlock(bodyRegion);
Block &bodyBlock = bodyRegion->front();
// Add block arguments for indices and outputs.
bodyBlock.addArguments(
SmallVector<Type>(numThreads.size(), builder.getIndexType()),
SmallVector<Location>(numThreads.size(), result.location));
bodyBlock.addArguments(
TypeRange(outputs),
SmallVector<Location>(outputs.size(), result.location));
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToStart(&bodyBlock);
bodyBuilder(builder, result.location, bodyBlock.getArguments());
#ifndef NDEBUG
auto terminator =
llvm::dyn_cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
assert(terminator &&
"expected bodyBuilder to create PerformConcurrentlyOp terminator");
result.addTypes(terminator.getYieldedTypes());
#endif // NDEBUG
}
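A hypothetical call-site sketch for this builder (not part of this change; `b`, `loc`, `outputs`, and `numThreads` are assumed to exist in the caller):

auto foreachThreadOp = b.create<scf::ForeachThreadOp>(
    loc, /*outputs=*/outputs, /*numThreads=*/numThreads,
    /*threadDimMapping=*/ArrayRef<int64_t>{},
    [](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
      // args = thread indices followed by the shared output block arguments.
      // The bodyBuilder must create the
      // scf.foreach_thread.perform_concurrently terminator itself.
    });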
// The ensureTerminator method generated by SingleBlockImplicitTerminator is
@ -1223,12 +1265,23 @@ void PerformConcurrentlyOp::build(OpBuilder &b, OperationState &result) {
}
LogicalResult PerformConcurrentlyOp::verify() {
scf::ForeachThreadOp foreachThreadOp =
dyn_cast<scf::ForeachThreadOp>(getOperation()->getParentOp());
if (!foreachThreadOp)
return this->emitOpError("expected foreach_thread op parent");
// TODO: PerformConcurrentlyOpInterface.
for (const Operation &op : getRegion().front().getOperations()) {
for (Operation &op : getRegion().front().getOperations()) {
if (!isa<tensor::ParallelInsertSliceOp>(op)) {
return this->emitOpError("expected only ")
<< tensor::ParallelInsertSliceOp::getOperationName() << " ops";
}
// Verify that inserts are into out block arguments.
Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
ArrayRef<BlockArgument> regionOutArgs = foreachThreadOp.getRegionOutArgs();
if (llvm::find(regionOutArgs, dest) == regionOutArgs.end())
return op.emitOpError("may only insert into an output block argument");
}
return success();
}
@ -1264,11 +1317,12 @@ OpResult PerformConcurrentlyOp::getParentResult(int64_t idx) {
return getOperation()->getParentOp()->getResult(idx);
}
SmallVector<Type> PerformConcurrentlyOp::getYieldedTypes() {
SmallVector<BlockArgument> PerformConcurrentlyOp::getDests() {
return llvm::to_vector<4>(
llvm::map_range(getYieldingOps(), [](Operation &op) {
auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
return insertSliceOp ? insertSliceOp.yieldedType() : Type();
// Add new ops here as needed.
auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
return insertSliceOp.getDest().cast<BlockArgument>();
}));
}


@ -1054,18 +1054,6 @@ struct YieldOpInterface
}
};
/// Return the destinations that an ForeachThreadOp is inserting into. One per
/// ParallelInsertSliceOp.
static SmallVector<OpOperand *>
getInsertionDest(ForeachThreadOp foreachThreadOp) {
PerformConcurrentlyOp terminator = foreachThreadOp.getTerminator();
SmallVector<OpOperand *> result;
terminator.walk([&](tensor::ParallelInsertSliceOp insertOp) {
result.push_back(&insertOp->getOpOperand(1) /*dest*/);
});
return result;
}
/// Bufferization of ForeachThreadOp. This also bufferizes the terminator of the
/// region. There are op interfaces for the terminators (PerformConcurrentlyOp
/// and ParallelInsertSliceOp), but these are only used during analysis. Not
@ -1073,57 +1061,114 @@ getInsertionDest(ForeachThreadOp foreachThreadOp) {
struct ForeachThreadOpInterface
: public BufferizableOpInterface::ExternalModel<ForeachThreadOpInterface,
ForeachThreadOp> {
SmallVector<OpOperand *>
getAliasingOpOperand(Operation *op, OpResult opResult,
const AnalysisState &state) const {
// Get OpOperand (dest) from corresponding ParallelInsertSliceOp.
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
// scf::ForeachThreadOp alone doesn't bufferize to a memory read, one of the
// uses of its matching bbArg may.
auto foreachThreadOp = cast<ForeachThreadOp>(op);
return {getInsertionDest(foreachThreadOp)[opResult.getResultNumber()]};
return state.isValueRead(foreachThreadOp.getTiedBlockArgument(&opOperand));
}
bool isMemoryWrite(Operation *op, OpResult opResult,
const AnalysisState &state) const {
// This op is a memory write. Stop lookup here to avoid finding false
// conflicts involving this op and one of the ops in the region. This is
// similar to how scf.if ops are analyzed.
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
// Outputs of scf::ForeachThreadOps are always considered as a write.
return true;
}
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
auto foreachThreadOp = cast<ForeachThreadOp>(op);
return {foreachThreadOp.getTiedOpResult(&opOperand)};
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const AnalysisState &state) const {
return BufferRelation::Equivalent;
}
bool isWritable(Operation *op, Value value,
const AnalysisState &state) const {
return true;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
OpBuilder::InsertionGuard guard(rewriter);
auto foreachThreadOp = cast<ForeachThreadOp>(op);
int64_t rank = foreachThreadOp.getRank();
#ifndef NDEBUG
// ParallelInsertSliceOpInterface replaces all uses.
for (OpResult opResult : foreachThreadOp->getOpResults())
assert(opResult.getUses().empty() &&
"expected that all uses were already replaced");
#endif // NDEBUG
// Get buffers for all output operands.
SmallVector<Value> buffers;
for (Value out : foreachThreadOp.getOutputs()) {
FailureOr<Value> buffer = getBuffer(rewriter, out, options);
if (failed(buffer))
return failure();
buffers.push_back(*buffer);
}
// Use buffers instead of block arguments.
rewriter.setInsertionPointToStart(foreachThreadOp.getBody());
for (const auto &it :
llvm::zip(foreachThreadOp.getBody()->getArguments().drop_front(rank),
buffers)) {
BlockArgument bbArg = std::get<0>(it);
Value buffer = std::get<1>(it);
Value bufferAsTensor =
rewriter.create<ToTensorOp>(foreachThreadOp.getLoc(), buffer);
bbArg.replaceAllUsesWith(bufferAsTensor);
}
// Create new ForeachThreadOp without any results and drop the automatically
// introduced terminator.
TypeRange newResultTypes;
rewriter.setInsertionPoint(foreachThreadOp);
auto newForeachThreadOp = rewriter.create<ForeachThreadOp>(
foreachThreadOp.getLoc(), newResultTypes,
foreachThreadOp.getLoc(), /*outputs=*/ValueRange(),
foreachThreadOp.getNumThreads(),
extractFromI64ArrayAttr(foreachThreadOp.getThreadDimMapping()));
newForeachThreadOp.getBody()->getTerminator()->erase();
// Move over block contents of the old op.
SmallVector<Value> replacementBbArgs;
replacementBbArgs.append(
newForeachThreadOp.getBody()->getArguments().begin(),
newForeachThreadOp.getBody()->getArguments().end());
replacementBbArgs.append(foreachThreadOp.getOutputs().size(), Value());
rewriter.mergeBlocks(foreachThreadOp.getBody(),
newForeachThreadOp.getBody(),
{newForeachThreadOp.getBody()->getArguments()});
newForeachThreadOp.getBody(), replacementBbArgs);
// Remove the old op.
rewriter.eraseOp(op);
// Remove the old op and replace all of its uses.
replaceOpWithBufferizedValues(rewriter, op, buffers);
return success();
}
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
auto foreachThreadOp = cast<ForeachThreadOp>(op);
if (auto bbArg = value.dyn_cast<BlockArgument>())
// A tensor block argument has the same bufferized type as the
// corresponding output operand.
return bufferization::getBufferType(
foreachThreadOp.getTiedOpOperand(bbArg)->get(), options, fixedTypes);
// The bufferized result type is the same as the bufferized type of the
// corresponding output operand.
return bufferization::getBufferType(
foreachThreadOp.getOutputs()[value.cast<OpResult>().getResultNumber()],
options, fixedTypes);
}
bool isRepetitiveRegion(Operation *op, unsigned index) const {
auto foreachThreadOp = cast<ForeachThreadOp>(op);
// This op is not repetitive if it has just a single thread.
if (llvm::all_of(foreachThreadOp.getNumThreads(), [](Value v) {
return getConstantIntValue(v) == static_cast<int64_t>(1);
}))
return false;
return true;
}
};
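Schematically (a sketch based on the updated One-Shot Bufferize tests below; %t, %n, %sz, and %cst are assumed, and buffer layouts are simplified), the op bufferizes as follows:

// Tensor form:
%r = scf.foreach_thread (%tid) in (%n) shared_outs(%o = %t) -> (tensor<?xf32>) {
  %s = tensor.extract_slice %o[5] [%sz] [1] : tensor<?xf32> to tensor<?xf32>
  %f = linalg.fill ins(%cst : f32) outs(%s : tensor<?xf32>) -> tensor<?xf32>
  scf.foreach_thread.perform_concurrently {
    tensor.parallel_insert_slice %f into %o[5] [%sz] [1] : tensor<?xf32> into tensor<?xf32>
  }
}
// After bufferization, the new scf.foreach_thread has no shared_outs and no
// results: %o is replaced by the buffer of %t, the extract_slice becomes a
// memref.subview of that buffer, linalg.fill runs on the subview, and the
// parallel_insert_slice turns into a self-copy that later folds away. Uses of
// %r are replaced by the buffer of %t.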
/// Nothing to do for PerformConcurrentlyOp.


@ -922,12 +922,7 @@ struct ParallelInsertSliceOpInterface
ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
if (&opOperand != &op->getOpOperand(1) /*dest*/)
return {};
// ParallelInsertSliceOp itself has no results, query its tied op results.
auto insertOp = cast<ParallelInsertSliceOp>(op);
return {insertOp.getTiedOpResult()};
}
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
@ -940,84 +935,21 @@ struct ParallelInsertSliceOpInterface
return &opOperand == &op->getOpOperand(1) /*dest*/;
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const AnalysisState &state) const {
return BufferRelation::Equivalent;
}
LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
const AnalysisState &state) const {
// This interface method is overridden because we want to set a custom
// insertion point for tensor copies. They should be inserted right before
// the ForeachThreadOp. E.g.:
//
// %r0, %r1 = foreach_thead ... {
// ...
// perform_concurrently {
// parallel_insert_slice %a into %b ... {inplace = ["true", "true"]}
// parallel_insert_slice %c into %d ... {inplace = ["true", "false"]}
// }
// }
//
// After TensorCopyInsertion:
//
// %copy = bufferization.alloc_tensor() copy(%d)
// %r0, %r1 = foreach_thead ... {
// ...
// perform_concurrently {
// parallel_insert_slice %a into %b ...
// parallel_insert_slice %c into %copy ...
// }
// }
OpBuilder::InsertionGuard g(rewriter);
auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
ParallelCombiningOpInterface parallelCombiningParent =
parallelInsertSliceOp.getParallelCombiningParent();
Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
// Nothing to do if the destination tensor is inplace.
assert(state.isInPlace(op->getOpOperand(0) /*src*/) &&
"source is always in-place");
if (state.isInPlace(op->getOpOperand(1) /*dest*/))
return success();
// Find corresponding OpResult.
OpResult opResult = parallelInsertSliceOp.getTiedOpResult();
// Insert tensor allocation right before the ForeachThreadOp.
rewriter.setInsertionPoint(parallelIteratingOp);
bool isYielded = state.isTensorYielded(opResult);
FailureOr<Value> alloc = allocateTensorForShapedValue(
rewriter, op->getLoc(), parallelInsertSliceOp.getDest(),
/*escape=*/isYielded, state.getOptions());
if (failed(alloc))
return failure();
// Update destination operand.
rewriter.updateRootInPlace(parallelInsertSliceOp, [&]() {
parallelInsertSliceOp.getDestMutable().assign(*alloc);
});
return success();
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
OpBuilder::InsertionGuard g(rewriter);
auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
ParallelCombiningOpInterface parallelCombiningParent =
parallelInsertSliceOp.getParallelCombiningParent();
Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
// Get destination buffer.
// Bufferize the op outside of the parallel combining terminator.
rewriter.setInsertionPoint(parallelCombiningParent);
// Get source and destination buffers.
FailureOr<Value> destBuffer =
getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
if (failed(destBuffer))
return failure();
// Bufferize the ParallelInsertSliceOp outside of `parallelCombiningParent`.
rewriter.setInsertionPoint(parallelCombiningParent);
FailureOr<Value> srcBuffer =
getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
if (failed(srcBuffer))
@ -1043,18 +975,7 @@ struct ParallelInsertSliceOpInterface
*srcBuffer, subview)))
return failure();
// Replace all uses of parallelIteratingOp (just the corresponding result).
rewriter.setInsertionPointAfter(parallelIteratingOp);
Value toTensorOp =
rewriter.create<ToTensorOp>(parallelIteratingOp->getLoc(), *destBuffer);
// PerformConcurrentlyOp can have multiple ParallelInsertSliceOps.
SmallVector<OpOperand *> resultUses = llvm::to_vector(
llvm::map_range(parallelInsertSliceOp.getTiedOpResult().getUses(),
[](OpOperand &use) { return &use; }));
for (OpOperand *use : resultUses) {
rewriter.updateRootInPlace(use->getOwner(),
[&]() { use->set(toTensorOp); });
}
// Delete the op.
rewriter.eraseOp(op);
return success();
}


@ -835,16 +835,16 @@ func.func @reduce_dispatch_0() -> tensor<4x2xf32> {
%c4 = arith.constant 4 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.init_tensor [4, 2] : tensor<4x2xf32>
%res = scf.foreach_thread (%arg0, %arg1) in (%c4, %c2) -> (tensor<4x2xf32>) {
%res = scf.foreach_thread (%arg0, %arg1) in (%c4, %c2) shared_outs(%o = %0) -> (tensor<4x2xf32>) {
%1 = linalg.init_tensor [1, 1] : tensor<1x1xf32>
%2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x1xf32>) -> tensor<1x1xf32>
scf.foreach_thread.perform_concurrently {
// CHECK: tensor.parallel_insert_slice %{{[0-9a-z]*}} into %{{[0-9a-z]*}}
// CHECK-SAME: [%{{.*}}, %{{.*}}] [1, 1] [1, 1] : tensor<f32> into tensor<4x2xf32>
tensor.parallel_insert_slice %2 into %0[%arg0, %arg1] [1, 1] [1, 1] :
tensor.parallel_insert_slice %2 into %o[%arg0, %arg1] [1, 1] [1, 1] :
tensor<1x1xf32> into tensor<4x2xf32>
}
}
}
return %res: tensor<4x2xf32>
}


@ -15,15 +15,15 @@ module {
func.func @matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
// CHECK-DAG: %[[C20:.*]] = arith.constant 20 : index
// CHECK: scf.foreach_thread ({{.*}}) in (%[[C10]], %[[C20]]) -> (tensor<?x?xf32>) {
// CHECK: scf.foreach_thread ({{.*}}) in (%[[C10]], %[[C20]]) shared_outs(%[[C_BLK:.*]] = %[[C]]) -> (tensor<?x?xf32>) {
// CHECK: %[[tA:.*]] = tensor.extract_slice %[[A]]{{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[tB:.*]] = tensor.extract_slice %[[B]]{{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[tC:.*]] = tensor.extract_slice %[[C]]{{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[tC:.*]] = tensor.extract_slice %[[C_BLK]]{{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[RES:.*]] = linalg.matmul
// CHECK-SAME: ins(%[[tA]], %[[tB]] : tensor<?x?xf32>, tensor<?x?xf32>)
// CHECK-SAME: outs(%[[tC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: scf.foreach_thread.perform_concurrently {
// CHECK-NEXT: tensor.parallel_insert_slice %[[RES]] into %[[C]]{{.*}} :
// CHECK-NEXT: tensor.parallel_insert_slice %[[RES]] into %[[C_BLK]]{{.*}} :
// CHECK-SAME: tensor<?x?xf32> into tensor<?x?xf32>
// CHECK-NEXT: }
// CHECK-NEXT: } {thread_dim_mapping = [1, 0]}
@ -55,10 +55,10 @@ module {
// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor
// CHECK-SAME: %[[B:[0-9a-z]+]]: tensor
// CHECK-SAME: %[[C:[0-9a-z]+]]: tensor
func.func @matmul_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: tensor<100x300xf32>) -> tensor<100x300xf32> {
func.func @matmul_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: tensor<100x300xf32>) -> tensor<100x300xf32> {
// CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index
// CHECK-DAG: %[[c21:.+]] = arith.constant 21 : index
// CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[c10]], %[[c21]])
// CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[c10]], %[[c21]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
// CHECK: %[[TSMIN:.+]] = affine.min #[[$map0]](%[[IV1]])
// CHECK: %[[TS:.+]] = affine.max #[[$map1]](%[[TSMIN]])
// CHECK-NOT: affine.min
@ -67,7 +67,7 @@ func.func @matmul_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: t
// CHECK: %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]])
// CHECK: %[[tA:.+]] = tensor.extract_slice %[[A]][%[[LB0]], 0] [10, 200] [1, 1] :
// CHECK: %[[tB:.+]] = tensor.extract_slice %[[B]][0, %[[LB1]]] [200, %[[TS]]] [1, 1] :
// CHECK: %[[tC:.+]] = tensor.extract_slice %[[C]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] :
// CHECK: %[[tC:.+]] = tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] :
// CHECK: linalg.matmul
// CHECK: scf.foreach_thread.perform_concurrently
// CHECK-NEXT: tensor.parallel_insert_slice
@ -104,14 +104,14 @@ func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C
// CHECK: %[[N:.+]] = tensor.dim %[[B]], %c1 :
// CHECK: %[[NT0:.+]] = affine.apply #map0()[%[[M]]]
// CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
// CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]])
// CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
// CHECK: %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
// CHECK: %[[TS1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
// CHECK: %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
// CHECK tensor.extract_slice %[[A]]
// CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
// CHECK tensor.extract_slice %[[B]]
// CHECK tensor.extract_slice %[[C]]
// CHECK: tensor.extract_slice %[[A]]
// CHECK: tensor.extract_slice %[[B]]
// CHECK: tensor.extract_slice %[[C_BLK]]
// CHECK: linalg.matmul
// CHECK: scf.foreach_thread.perform_concurrently
// CHECK-NEXT: tensor.parallel_insert_slice
@ -144,7 +144,7 @@ transform.with_pdl_patterns {
func.func @matmul_tile_size_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: tensor<100x300xf32>) -> tensor<100x300xf32> {
// CHECK-DAG: %[[c10:.+]] = arith.constant 10 :
// CHECK-DAG: %[[c15:.+]] = arith.constant 15 :
// CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[c10]], %[[c15]])
// CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[c10]], %[[c15]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
// CHECK: %[[TS:.+]] = affine.min #[[$map0]](%[[IV1]])
// CHECK-NOT: affine.max
// CHECK-NOT: affine.min
@ -152,7 +152,7 @@ func.func @matmul_tile_size_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf
// CHECK: %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]])
// CHECK: %[[tA:.+]] = tensor.extract_slice %[[A]][%[[LB0]], 0] [10, 200] [1, 1] :
// CHECK: %[[tB:.+]] = tensor.extract_slice %[[B]][0, %[[LB1]]] [200, %[[TS]]] [1, 1] :
// CHECK: %[[tC:.+]] = tensor.extract_slice %[[C]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] :
// CHECK: %[[tC:.+]] = tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] :
// CHECK: linalg.matmul
// CHECK: scf.foreach_thread.perform_concurrently
// CHECK-NEXT: tensor.parallel_insert_slice
@ -199,7 +199,7 @@ module {
// CHECK-LABEL: extract_source(
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: scf.foreach_thread (%[[ARG:.*]]) in (%[[C2]]) -> (tensor<4xf32>) {
// CHECK: scf.foreach_thread (%[[ARG:.*]]) in (%[[C2]]) shared_outs(%{{.*}} = %{{.*}}) -> (tensor<4xf32>) {
// CHECK: %[[OFF:.*]] = affine.apply #[[$map0]](%[[ARG]])
// CHECK: scf.foreach_thread.perform_concurrently {
// CHECK: tensor.parallel_insert_slice %{{.*}} into %{{.*}}[%[[OFF]]] [2] [1] : tensor<2xf32> into tensor<4xf32>
@ -227,10 +227,10 @@ func.func @matmul_tile_size_dynamic_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?x
// CHECK-DAG: %[[N:.+]] = tensor.dim %[[B]], %c1 :
// CHECK-DAG: %[[NT0:.+]] = affine.apply #[[$map0]]()[%[[M]], %[[tile_size]]]
// CHECK-DAG: %[[NT1:.+]] = affine.apply #[[$map1]]()[%[[N]]]
// CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]])
// CHECK tensor.extract_slice %[[A]]
// CHECK tensor.extract_slice %[[B]]
// CHECK tensor.extract_slice %[[C]]
// CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
// CHECK: tensor.extract_slice %[[A]]
// CHECK: tensor.extract_slice %[[B]]
// CHECK: tensor.extract_slice %[[C_BLK]]
// CHECK: linalg.matmul
// CHECK: scf.foreach_thread.perform_concurrently
// CHECK-NEXT: tensor.parallel_insert_slice


@ -17,10 +17,10 @@ module {
%1 = affine.apply #map0()[%d0, %arg0]
// CHECK: scf.foreach_thread {{.*}} {
%2 = scf.foreach_thread (%arg3) in (%1) -> (tensor<?xf32>) {
%2 = scf.foreach_thread (%arg3) in (%1) shared_outs(%o = %arg2) -> (tensor<?xf32>) {
%3 = affine.apply #map1(%arg3)[%arg0]
%4 = affine.min #map2(%arg3)[%d0, %arg0]
%5 = tensor.extract_slice %arg2[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
%5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
// CHECK: %[[T0:.*]] = tensor.extract_slice %[[IN]][%{{.*}}] [%{{.*}}] [{{.*}}]
// CHECK: %[[T1:.*]] = linalg.fill {{.*}} outs(%[[T0]]
@ -29,7 +29,7 @@ module {
// CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[T1]]
%7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
scf.foreach_thread.perform_concurrently {
tensor.parallel_insert_slice %7 into %arg2[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
}
}
// CHECK: }
@ -70,16 +70,16 @@ module {
%1 = affine.apply #map0()[%arg0]
// CHECK: scf.foreach_thread {{.*}} {
%2 = scf.foreach_thread (%arg3) in (%1) -> (tensor<64xf32>) {
%2 = scf.foreach_thread (%arg3) in (%1) shared_outs(%o = %arg2) -> (tensor<64xf32>) {
// CHECK: %[[INIT_TENSOR:.*]] = linalg.init_tensor
%3 = affine.apply #map1(%arg3)[%arg0]
%4 = affine.min #map2(%arg3)[%arg0]
%5 = tensor.extract_slice %arg2[%3] [%4] [1] : tensor<64xf32> to tensor<?xf32>
%5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<64xf32> to tensor<?xf32>
// CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[INIT_TENSOR]]
%7 = linalg.elemwise_unary ins(%0 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
scf.foreach_thread.perform_concurrently {
tensor.parallel_insert_slice %7 into %arg2[%3] [%4] [1] : tensor<?xf32> into tensor<64xf32>
tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<64xf32>
}
}
// CHECK: }


@ -527,11 +527,11 @@ func.func @wrong_num_results(%in: tensor<100xf32>, %out: tensor<100xf32>) {
%c1 = arith.constant 1 : index
%num_threads = arith.constant 100 : index
// expected-error @+1 {{produces 2 results, but its terminator yields 1 value(s)}}
%result:2 = scf.foreach_thread (%thread_idx) in (%num_threads) -> (tensor<100xf32>, tensor<100xf32>) {
// expected-error @+1 {{1 operands present, but expected 2}}
%result:2 = scf.foreach_thread (%thread_idx) in (%num_threads) shared_outs(%o = %out) -> (tensor<100xf32>, tensor<100xf32>) {
%1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
scf.foreach_thread.perform_concurrently {
tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] :
tensor<1xf32> into tensor<100xf32>
}
}
@ -540,14 +540,14 @@ func.func @wrong_num_results(%in: tensor<100xf32>, %out: tensor<100xf32>) {
// -----
func.func @wrong_type_result(%in: tensor<100xf32>, %out: tensor<100xf32>) {
func.func @invalid_insert_dest(%in: tensor<100xf32>, %out: tensor<100xf32>) {
%c1 = arith.constant 1 : index
%num_threads = arith.constant 100 : index
// expected-error @+1 {{type mismatch between result 0 ('tensor<?xf32>') and terminator ('tensor<100xf32>')}}
%result = scf.foreach_thread (%thread_idx) in (%num_threads) -> (tensor<?xf32>) {
%result = scf.foreach_thread (%thread_idx) in (%num_threads) shared_outs(%o = %out) -> (tensor<100xf32>) {
%1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
scf.foreach_thread.perform_concurrently {
// expected-error @+1 {{may only insert into an output block argument}}
tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
tensor<1xf32> into tensor<100xf32>
}
@ -561,11 +561,11 @@ func.func @wrong_terminator_op(%in: tensor<100xf32>, %out: tensor<100xf32>) {
%c1 = arith.constant 1 : index
%num_threads = arith.constant 100 : index
%result = scf.foreach_thread (%thread_idx) in (%num_threads) -> (tensor<100xf32>) {
%result = scf.foreach_thread (%thread_idx) in (%num_threads) shared_outs(%o = %out) -> (tensor<100xf32>) {
%1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
// expected-error @+1 {{expected only tensor.parallel_insert_slice ops}}
scf.foreach_thread.perform_concurrently {
tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] :
tensor<1xf32> into tensor<100xf32>
%0 = arith.constant 1: index
}


@ -120,14 +120,14 @@ func.func @scf_foreach_thread_out_of_place(%in: tensor<100xf32>,
// CHECK-FUNC-NOT: alloc_tensor
// CHECK: %[[alloc:.*]] = bufferization.alloc_tensor() copy(%[[arg1]]) {bufferization.escape = [false]} : tensor<100xf32>
// CHECK: scf.foreach_thread
%result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<100xf32> {
// CHECK: scf.foreach_thread {{.*}} shared_outs(%[[o:.*]] = %[[alloc]])
%result = scf.foreach_thread (%thread_idx) in (%num_threads) shared_outs(%o = %out) -> tensor<100xf32> {
// CHECK: tensor.extract_slice
// CHECK: scf.foreach_thread.perform_concurrently
// CHECK: tensor.parallel_insert_slice %{{.*}} into %[[alloc]]
// CHECK: tensor.parallel_insert_slice %{{.*}} into %[[o]]
%1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
scf.foreach_thread.perform_concurrently {
tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] :
tensor<1xf32> into tensor<100xf32>
}
// CHECK: } {thread_dim_mapping = [5]}


@ -525,10 +525,10 @@ func.func @parallel_insert_slice_no_conflict(
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
// CHECK: scf.foreach_thread (%[[tidx:.*]]) in (%[[idx2]]) -> ()
%2 = scf.foreach_thread (%arg3) in (%idx2) -> (tensor<?xf32>) {
// CHECK: scf.foreach_thread (%[[tidx:.*]]) in (%[[idx2]])
%2 = scf.foreach_thread (%arg3) in (%idx2) shared_outs(%o = %arg2) -> (tensor<?xf32>) {
// CHECK: %[[subview:.*]] = memref.subview %[[arg2]][5] [%[[idx]]] [1]
%6 = tensor.extract_slice %arg2[5] [%idx] [%c1] : tensor<?xf32> to tensor<?xf32>
%6 = tensor.extract_slice %o[5] [%idx] [%c1] : tensor<?xf32> to tensor<?xf32>
// CHECK: linalg.fill ins(%{{.*}}) outs(%[[subview]] : memref<?xf32
%8 = linalg.fill ins(%cst : f32) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
// Self-copy will DCE away later.
@ -538,7 +538,7 @@ func.func @parallel_insert_slice_no_conflict(
// CHECK-NOT: scf.foreach_thread.perform_concurrently
// CHECK-NOT: parallel_insert_slice
scf.foreach_thread.perform_concurrently {
tensor.parallel_insert_slice %8 into %arg2[5] [%idx] [%c1] :
tensor.parallel_insert_slice %8 into %o[5] [%idx] [%c1] :
tensor<?xf32> into tensor<?xf32>
}
}
@ -571,26 +571,22 @@ func.func @parallel_insert_slice_with_conflict(
// CHECK: %[[alloc1:.*]] = memref.alloc
// CHECK: memref.copy %[[arg2]], %[[alloc1]]
// CHECK: scf.foreach_thread (%[[tidx:.*]]) in (%[[idx2]]) -> ()
%2 = scf.foreach_thread (%arg3) in (%idx2) -> (tensor<?xf32>) {
// Another alloc for the extract_slice op.
// CHECK: %[[alloc2:.*]] = memref.alloc
%6 = tensor.extract_slice %arg2[5] [%idx] [%c1] : tensor<?xf32> to tensor<?xf32>
// CHECK: scf.foreach_thread (%[[tidx:.*]]) in (%[[idx2]])
%2 = scf.foreach_thread (%arg3) in (%idx2) shared_outs(%o = %arg2) -> (tensor<?xf32>) {
// CHECK: %[[subview1:.*]] = memref.subview %[[alloc1]][5] [%[[idx]]] [1]
%6 = tensor.extract_slice %o[5] [%idx] [%c1] : tensor<?xf32> to tensor<?xf32>
// CHECK: linalg.fill ins(%{{.*}}) outs(%[[alloc2]] : memref<?xf32
// CHECK: linalg.fill ins(%{{.*}}) outs(%[[subview1]] : memref<?xf32
%8 = linalg.fill ins(%cst : f32) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
// Now the copy of the actual insert_slice.
// CHECK: %[[subview1:.*]] = memref.subview %[[alloc1]][5] [%[[idx]]] [1]
//
// CHECK: memref.copy %[[alloc2]], %[[subview1]]
// CHECK: memref.dealloc %[[alloc2]]
// Now the copy of the actual insert_slice. (It will fold away.)
// CHECK: memref.copy %[[subview1]], %[[subview1]]
// Empty terminator is elided from pretty-printing.
// CHECK-NOT: scf.foreach_thread.perform_concurrently
// CHECK-NOT: parallel_insert_slice
scf.foreach_thread.perform_concurrently {
tensor.parallel_insert_slice %8 into %arg2[5] [%idx] [%c1] :
tensor.parallel_insert_slice %8 into %o[5] [%idx] [%c1] :
tensor<?xf32> into tensor<?xf32>
}
}
@ -617,18 +613,18 @@ func.func @matmul(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>, %arg2: tensor<
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
// CHECK: scf.foreach_thread {{.*}} -> ()
%0 = scf.foreach_thread (%arg3, %arg4) in (%c2, %c4) -> (tensor<8x8xf32>) {
// CHECK: scf.foreach_thread {{.*}}
%0 = scf.foreach_thread (%arg3, %arg4) in (%c2, %c4) shared_outs(%o = %arg2) -> (tensor<8x8xf32>) {
%1 = affine.apply #map0(%arg3)
%3 = tensor.extract_slice %arg0[%1, 0] [4, 8] [1, 1] : tensor<8x8xf32> to tensor<4x8xf32>
%4 = affine.apply #map1(%arg4)
%6 = tensor.extract_slice %arg1[0, %4] [8, 4] [1, 1] : tensor<8x8xf32> to tensor<8x4xf32>
%7 = tensor.extract_slice %arg2[%1, %4] [4, 4] [1, 1] : tensor<8x8xf32> to tensor<4x4xf32>
%7 = tensor.extract_slice %o[%1, %4] [4, 4] [1, 1] : tensor<8x8xf32> to tensor<4x4xf32>
// CHECK: linalg.matmul ins({{.*}}memref<4x8xf32, #[[$DYN_LAYOUT_MAP]]>, memref<8x4xf32, #[[$DYN_LAYOUT_MAP]]>) outs({{.*}} : memref<4x4xf32, #[[$DYN_LAYOUT_MAP]]>)
%8 = linalg.matmul ins(%3, %6 : tensor<4x8xf32>, tensor<8x4xf32>) outs(%7 : tensor<4x4xf32>) -> tensor<4x4xf32>
scf.foreach_thread.perform_concurrently {
tensor.parallel_insert_slice %8 into %arg2[%1, %4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x8xf32>
tensor.parallel_insert_slice %8 into %o[%1, %4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x8xf32>
}
}
return %0 : tensor<8x8xf32>
@ -636,6 +632,71 @@ func.func @matmul(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>, %arg2: tensor<
// -----
// CHECK-LABEL: func @scf_foreach_private_var(
// CHECK-SAME: %[[t:.*]]: memref<10xf32
func.func @scf_foreach_private_var(%t: tensor<10xf32>) -> f32 {
%c2 = arith.constant 2 : index
%c5 = arith.constant 5 : index
// A copy is inserted for the uses of %t in the loop.
// CHECK: %[[t_copy:.*]] = memref.alloc() {{.*}} : memref<10xf32>
// CHECK: memref.copy %[[t]], %[[t_copy]]
// CHECK: scf.foreach_thread (%{{.*}}) in (%{{.*}}) {
// Load from the copy and store into the shared output.
// CHECK: %[[subview:.*]] = memref.subview %[[t]]
// CHECK: memref.load %[[t_copy]]
// CHECK: memref.store %{{.*}}, %[[subview]]
%0 = scf.foreach_thread (%tid) in (%c2) shared_outs(%o = %t) -> tensor<10xf32> {
%offset = arith.muli %c5, %tid : index
%slice = tensor.extract_slice %o[%offset] [5] [1]
: tensor<10xf32> to tensor<5xf32>
%r2 = tensor.extract %t[%tid] : tensor<10xf32>
%i = tensor.insert %r2 into %slice[%c2] : tensor<5xf32>
scf.foreach_thread.perform_concurrently {
tensor.parallel_insert_slice %i into %o[%offset] [5] [1]
: tensor<5xf32> into tensor<10xf32>
}
}
%r = tensor.extract %0[%c2] : tensor<10xf32>
return %r : f32
}
// -----
// CHECK-LABEL: func.func @scf_foreach_privatized_but_not_copied(
// CHECK-SAME: %[[t0:.*]]: memref<10xf32, {{.*}}>, %[[t1:.*]]: memref<10xf32
func.func @scf_foreach_privatized_but_not_copied(
%t0: tensor<10xf32>, %t1: tensor<10xf32>) -> f32 {
%c2 = arith.constant 2 : index
%c5 = arith.constant 5 : index
// CHECK-NOT: memref.alloc
// CHECK-NOT: memref.copy
// CHECK: scf.foreach_thread {{.*}} {
%0 = scf.foreach_thread (%tid) in (%c2) shared_outs(%o = %t0) -> tensor<10xf32> {
%offset = arith.muli %c5, %tid : index
%slice = tensor.extract_slice %o[%offset] [5] [1]
: tensor<10xf32> to tensor<5xf32>
// %t1 is never written in here, so no copy is needed
// CHECK: memref.load %[[t1]]
%r2 = tensor.extract %t1[%tid] : tensor<10xf32>
%i = tensor.insert %r2 into %slice[%c2] : tensor<5xf32>
scf.foreach_thread.perform_concurrently {
tensor.parallel_insert_slice %i into %o[%offset] [5] [1]
: tensor<5xf32> into tensor<10xf32>
}
}
%r = tensor.extract %0[%c2] : tensor<10xf32>
return %r : f32
}
// -----
// CHECK-LABEL: func @scf_if_memory_space
func.func @scf_if_memory_space(%c: i1, %f: f32) -> (f32, f32)
{


@ -323,10 +323,10 @@ func.func @simple_example(%in: tensor<100xf32>, %out: tensor<100xf32>) {
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return
%result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<100xf32> {
%result = scf.foreach_thread (%thread_idx) in (%num_threads) shared_outs(%o = %out) -> tensor<100xf32> {
%1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
scf.foreach_thread.perform_concurrently {
tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] :
tensor<1xf32> into tensor<100xf32>
}
}
@ -340,7 +340,7 @@ func.func @elide_terminator() -> () {
// CHECK: scf.foreach_thread
// CHECK-NEXT: } {thread_dim_mapping = [42]}
// CHECK-NEXT: return
scf.foreach_thread (%thread_idx) in (%num_threads) -> () {
scf.foreach_thread (%thread_idx) in (%num_threads) {
scf.foreach_thread.perform_concurrently {
}
} {thread_dim_mapping = [42]}


@ -1455,13 +1455,13 @@ func.func @canonicalize_parallel_insert_slice_indices(
%c1 = arith.constant 1 : index
// CHECK-NOT: tensor.cast
// CHECK: scf.foreach_thread (%[[tidx:[0-9a-z]*]]) in (%[[num_threads]]) -> (tensor<?x?xf32>) {
// CHECK: scf.foreach_thread (%[[tidx:[0-9a-z]*]]) in (%[[num_threads]]) shared_outs(%[[o:.*]] = %[[arg1]]) -> (tensor<?x?xf32>) {
// CHECK-NEXT: scf.foreach_thread.perform_concurrently {
// CHECK-NEXT: tensor.parallel_insert_slice %[[arg0]] into %[[arg1]][%[[tidx]], 0] [1, 5] [1, 1]
%2 = scf.foreach_thread (%tidx) in (%num_threads) -> (tensor<?x?xf32>) {
// CHECK-NEXT: tensor.parallel_insert_slice %[[arg0]] into %[[o]][%[[tidx]], 0] [1, 5] [1, 1]
%2 = scf.foreach_thread (%tidx) in (%num_threads) shared_outs(%o = %arg1) -> (tensor<?x?xf32>) {
%3 = tensor.cast %arg0 : tensor<1x5xf32> to tensor<?x5xf32>
scf.foreach_thread.perform_concurrently {
tensor.parallel_insert_slice %3 into %arg1[%tidx, %c0] [%c1, 5] [%c1, %c1] : tensor<?x5xf32> into tensor<?x?xf32>
tensor.parallel_insert_slice %3 into %o[%tidx, %c0] [%c1, 5] [%c1, %c1] : tensor<?x5xf32> into tensor<?x?xf32>
}
}
return %2 : tensor<?x?xf32>
@ -1477,12 +1477,12 @@ func.func @dont_fold_parallel_insert_slice(
{
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
// CHECK: scf.foreach_thread () in () -> (tensor<1x5xf32>) {
// CHECK: scf.foreach_thread () in () shared_outs(%[[o:.*]] = %[[arg1]]) -> (tensor<1x5xf32>) {
// CHECK-NEXT: scf.foreach_thread.perform_concurrently {
// CHECK-NEXT: tensor.parallel_insert_slice %[[arg0]] into %[[arg1]][0, 0] [1, 5] [1, 1] : tensor<1x5xf32> into tensor<1x5xf32>
%2 = scf.foreach_thread () in () -> (tensor<1x5xf32>) {
// CHECK-NEXT: tensor.parallel_insert_slice %[[arg0]] into %[[o]][0, 0] [1, 5] [1, 1] : tensor<1x5xf32> into tensor<1x5xf32>
%2 = scf.foreach_thread () in () shared_outs(%o = %arg1) -> (tensor<1x5xf32>) {
scf.foreach_thread.perform_concurrently {
tensor.parallel_insert_slice %arg0 into %arg1[%c0, %c0] [1, 5] [%c1, %c1] : tensor<1x5xf32> into tensor<1x5xf32>
tensor.parallel_insert_slice %arg0 into %o[%c0, %c0] [1, 5] [%c1, %c1] : tensor<1x5xf32> into tensor<1x5xf32>
}
}
return %2 : tensor<1x5xf32>


@ -205,12 +205,12 @@ func.func @rank_reducing_parallel_insert_slice(%in: tensor<100xf32>, %out: tenso
%num_threads = arith.constant 100 : index
// CHECK: scf.foreach_thread {{.*}} {
%result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<200x100xf32> {
%result = scf.foreach_thread (%thread_idx) in (%num_threads) shared_outs (%o = %out) -> tensor<200x100xf32> {
%1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
scf.foreach_thread.perform_concurrently {
// CHECK: memref.subview %{{.*}}[%{{.*}}] [1] [1] : memref<100xf32, #[[$MAP0]]> to memref<1xf32, #[[$MAP0]]>
// CHECK: memref.subview %{{.*}}[1, %{{.*}}] [1, 1] [1, 1] : memref<200x100xf32, #[[$MAP1]]> to memref<1xf32, #[[$MAP0]]>
tensor.parallel_insert_slice %1 into %out[1, %thread_idx][1, 1][1, 1] :
tensor.parallel_insert_slice %1 into %o[1, %thread_idx][1, 1][1, 1] :
tensor<1xf32> into tensor<200x100xf32>
}
}