[mlir][Tensor] Move ParallelInsertSlice to the tensor dialect

This is mostly NFC and will allow tensor.parallel_insert_slice to gain
rank-reducing semantics by reusing the vast majority of the tensor.insert_slice
implementation.

Depends on D128857

Differential Revision: https://reviews.llvm.org/D128920
Nicolas Vasilache 2022-06-30 04:27:41 -07:00
parent f0089fae1d
commit 7fbf55c927
17 changed files with 553 additions and 552 deletions
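For orientation, a minimal sketch of what the move looks like in IR, following the pattern of the tests updated below (`%1`, `%out`, and `%thread_idx` stand for values defined by the surrounding `scf.foreach_thread`):

```mlir
// Before this change: the op is spelled as part of the scf dialect.
scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx] [1] [1]
  : tensor<1xf32> into tensor<100xf32>

// After this change: the same op lives in the tensor dialect. It remains legal
// only inside a ParallelCombiningOpInterface parent such as
// scf.foreach_thread.perform_concurrently.
tensor.parallel_insert_slice %1 into %out[%thread_idx] [1] [1]
  : tensor<1xf32> into tensor<100xf32>
```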


@@ -503,115 +503,6 @@ def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
}];
}
//===----------------------------------------------------------------------===//
// ParallelInsertSliceOp
//===----------------------------------------------------------------------===//
// TODO: Implement PerformConcurrentlyOpInterface.
def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
AttrSizedOperandSegments,
OffsetSizeAndStrideOpInterface,
// TODO: Cannot use an interface here atm, verify this manually for now.
// HasParent<"ParallelCombiningOpInterface">
]> {
let summary = [{
Specify the tensor slice update of a single thread within the terminator of
an `scf.foreach_thread`.
}];
let description = [{
The parent `scf.foreach_thread` returns values that are formed by aggregating
the actions of all the ops contained within the `perform_concurrently`
terminator of all the threads, in some unspecified order.
The `scf.foreach_thread.parallel_insert_slice` is one such op allowed in
the `scf.foreach_thread.perform_concurrently` terminator.
Conflicting writes result in undefined semantics, in that the indices written
to by multiple parallel updates might contain data from any of the updates, or
even a malformed bit pattern.
If an index is updated exactly once, the value contained at that index
in the resulting tensor will be equal to the value at a corresponding index of a
slice that was used for the updated. If an index is not updated at all, its value
will be equal to the one in the original tensor.
This op does not create a new value, which allows maintaining a clean
separation between the subset and full tensor.
Note that we cannot mark this operation as pure (NoSideEffects), even
though it has no side effects, because it will get DCEd during
canonicalization.
}];
let arguments = (ins
AnyRankedTensor:$source,
AnyRankedTensor:$dest,
Variadic<Index>:$offsets,
Variadic<Index>:$sizes,
Variadic<Index>:$strides,
I64ArrayAttr:$static_offsets,
I64ArrayAttr:$static_sizes,
I64ArrayAttr:$static_strides
);
let assemblyFormat = [{
$source `into` $dest ``
custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
attr-dict `:` type($source) `into` type($dest)
}];
let extraClassDeclaration = [{
::mlir::Operation::operand_range offsets() { return getOffsets(); }
::mlir::Operation::operand_range sizes() { return getSizes(); }
::mlir::Operation::operand_range strides() { return getStrides(); }
::mlir::ArrayAttr static_offsets() { return getStaticOffsets(); }
::mlir::ArrayAttr static_sizes() { return getStaticSizes(); }
::mlir::ArrayAttr static_strides() { return getStaticStrides(); }
Type yieldedType() { return getDest().getType(); }
RankedTensorType getSourceType() {
return getSource().getType().cast<RankedTensorType>();
}
ParallelCombiningOpInterface getParallelCombiningParent() {
return dyn_cast<ParallelCombiningOpInterface>(
getOperation()->getParentOp());
}
/// Return the expected rank of each of the `static_offsets`, `static_sizes`
/// and `static_strides` attributes.
std::array<unsigned, 3> getArrayAttrMaxRanks() {
unsigned rank = getSourceType().getRank();
return {rank, rank, rank};
}
/// Return the number of leading operands before `offsets`, `sizes` and
/// `strides` operands.
static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
/// Return the OpResult of the enclosing ForeachThreadOp that is
/// corresponding to this ParallelInsertSliceOp.
OpResult getTiedOpResult();
}];
let builders = [
// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
OpBuilder<(ins "Value":$source, "Value":$dest,
"ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
"ArrayRef<OpFoldResult>":$strides,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
// Build a ParallelInsertSliceOp with dynamic entries.
OpBuilder<(ins "Value":$source, "Value":$dest,
"ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
];
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
// IfOp
//===----------------------------------------------------------------------===//


@@ -17,6 +17,7 @@
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"


@@ -13,6 +13,7 @@ include "mlir/Dialect/Tensor/IR/TensorBase.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/TilingInterface.td"
include "mlir/Interfaces/ViewLikeInterface.td"
@@ -1051,6 +1052,110 @@ def Tensor_PadOp : Tensor_Op<"pad", [AttrSizedOperandSegments, NoSideEffect,
let hasRegionVerifier = 1;
}
//===----------------------------------------------------------------------===//
// ParallelInsertSliceOp
//===----------------------------------------------------------------------===//
// TODO: Implement PerformConcurrentlyOpInterface.
def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
AttrSizedOperandSegments,
OffsetSizeAndStrideOpInterface,
// TODO: Cannot use an interface here atm, verify this manually for now.
// HasParent<"ParallelCombiningOpInterface">
]> {
let summary = [{
Specify the tensor slice update of a single thread of a parent
ParallelCombiningOpInterface op.
}];
let description = [{
The `parallel_insert_slice` yields a subset tensor value to its parent
ParallelCombiningOpInterface. These subset tensor values are aggregated,
in some unspecified order, into the full tensor value returned by the
parent parallel iterating op.
The `parallel_insert_slice` is one such op allowed in the
ParallelCombiningOpInterface op.
Conflicting writes result in undefined semantics, in that the indices written
to by multiple parallel updates might contain data from any of the updates,
or even a malformed bit pattern.
If an index is updated exactly once, the value contained at that index
in the resulting tensor will be equal to the value at a corresponding index
of the slice that was used for the update. If an index is not updated at all,
its value will be equal to the one in the original tensor.
This op does not create a new value, which allows maintaining a clean
separation between the subset and full tensor.
Note that we cannot mark this operation as pure (NoSideEffects), even
though it has no side effects, because it would then be DCE'd during
canonicalization.
}];
let arguments = (ins
AnyRankedTensor:$source,
AnyRankedTensor:$dest,
Variadic<Index>:$offsets,
Variadic<Index>:$sizes,
Variadic<Index>:$strides,
I64ArrayAttr:$static_offsets,
I64ArrayAttr:$static_sizes,
I64ArrayAttr:$static_strides
);
let assemblyFormat = [{
$source `into` $dest ``
custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
attr-dict `:` type($source) `into` type($dest)
}];
let extraClassDeclaration = [{
Type yieldedType() { return getDest().getType(); }
RankedTensorType getSourceType() {
return getSource().getType().cast<RankedTensorType>();
}
ParallelCombiningOpInterface getParallelCombiningParent() {
return dyn_cast<ParallelCombiningOpInterface>(
getOperation()->getParentOp());
}
/// Return the expected rank of each of the `static_offsets`, `static_sizes`
/// and `static_strides` attributes.
std::array<unsigned, 3> getArrayAttrMaxRanks() {
unsigned rank = getSourceType().getRank();
return {rank, rank, rank};
}
/// Return the number of leading operands before `offsets`, `sizes` and
/// `strides` operands.
static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
/// Return the OpResult of the enclosing ForeachThreadOp that is
/// corresponding to this ParallelInsertSliceOp.
OpResult getTiedOpResult();
}];
let builders = [
// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
OpBuilder<(ins "Value":$source, "Value":$dest,
"ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
"ArrayRef<OpFoldResult>":$strides,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
// Build a ParallelInsertSliceOp with dynamic entries.
OpBuilder<(ins "Value":$source, "Value":$dest,
"ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
];
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
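To make the description above concrete, here is a minimal usage sketch adapted from the tests touched by this commit (`%in`, `%out`, and `%num_threads` stand for values from an enclosing function). Note that the op itself produces no result; the tied result is that of the enclosing `scf.foreach_thread`:

```mlir
%result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<100xf32> {
  %1 = tensor.extract_slice %in[%thread_idx] [1] [1] : tensor<100xf32> to tensor<1xf32>
  scf.foreach_thread.perform_concurrently {
    // Each thread writes its tile into the shared destination; the aggregated
    // tensor is returned as %result by the parent scf.foreach_thread.
    tensor.parallel_insert_slice %1 into %out[%thread_idx] [1] [1]
      : tensor<1xf32> into tensor<100xf32>
  }
}
```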
//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//


@@ -13,7 +13,6 @@ add_mlir_dialect_library(MLIRSCFDialect
MLIRControlFlowDialect
MLIRIR
MLIRLoopLikeInterface
MLIRParallelCombiningOpInterface
MLIRSideEffectInterfaces
)


@@ -1211,137 +1211,6 @@ ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) {
return dyn_cast<ForeachThreadOp>(containingOp);
}
//===----------------------------------------------------------------------===//
// ParallelInsertSliceOp
//===----------------------------------------------------------------------===//
OpResult ParallelInsertSliceOp::getTiedOpResult() {
ParallelCombiningOpInterface parallelCombiningParent =
getParallelCombiningParent();
for (const auto &it :
llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
Operation &nextOp = it.value();
if (&nextOp == getOperation())
return parallelCombiningParent.getParentResult(it.index());
}
llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
}
// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
Value source, Value dest,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides,
ArrayRef<NamedAttribute> attrs) {
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
ShapedType::kDynamicStrideOrOffset);
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
ShapedType::kDynamicSize);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
ShapedType::kDynamicStrideOrOffset);
build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
dynamicStrides, b.getI64ArrayAttr(staticOffsets),
b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
result.addAttributes(attrs);
}
// Build a ParallelInsertSliceOp with dynamic entries.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
Value source, Value dest, ValueRange offsets,
ValueRange sizes, ValueRange strides,
ArrayRef<NamedAttribute> attrs) {
SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}
LogicalResult ParallelInsertSliceOp::verify() {
if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
return this->emitError("expected ParallelCombiningOpInterface parent, got:")
<< *(getOperation()->getParentOp());
return success();
}
namespace {
/// Pattern to rewrite a parallel_insert_slice op with constant arguments.
class ParallelInsertSliceOpConstantArgumentFolder final
: public OpRewritePattern<ParallelInsertSliceOp> {
public:
using OpRewritePattern<ParallelInsertSliceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp,
PatternRewriter &rewriter) const override {
// No constant operand, just return.
if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
return matchPattern(operand, matchConstantIndex());
}))
return failure();
// At least one of offsets/sizes/strides is a new constant.
// Form the new list of operands and constant attributes from the
// existing.
SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
// Create the new op in canonical form.
rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
insertSliceOp, insertSliceOp.getSource(), insertSliceOp.getDest(),
mixedOffsets, mixedSizes, mixedStrides);
return success();
}
};
} // namespace
/// Fold a parallel_insert_slice source coming from a tensor.cast op.
///
/// Example:
/// ```
/// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
/// %1 = compute_some_tensor() : tensor<64xf32>
/// %2 = tensor.cast %1 : tensor<64xf32> to tensor<?xf32>
/// scf.foreach_thread.perform_concurrently {
/// scf.foreach_thread.parallel_insert_slice %2 into %out[...] [64] [1] :
/// tensor<?xf32> into tensor<128xf32>
/// }
/// }
/// ```
///
/// is folded into:
/// ```
/// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
/// %1 = compute_some_tensor() : tensor<64xf32>
/// scf.foreach_thread.perform_concurrently {
/// scf.foreach_thread.parallel_insert_slice %1 into %out[...] [64] [1] :
/// tensor<64xf32> into tensor<128xf32>
/// }
/// }
/// ```
LogicalResult
ParallelInsertSliceOp::fold(ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
auto sourceCast = getSource().getDefiningOp<tensor::CastOp>();
if (!sourceCast)
return failure();
getSourceMutable().assign(sourceCast.getSource());
return success();
}
void ParallelInsertSliceOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<ParallelInsertSliceOpConstantArgumentFolder>(context);
}
//===----------------------------------------------------------------------===//
// PerformConcurrentlyOp
//===----------------------------------------------------------------------===//
@@ -1355,10 +1224,12 @@ void PerformConcurrentlyOp::build(OpBuilder &b, OperationState &result) {
LogicalResult PerformConcurrentlyOp::verify() {
// TODO: PerformConcurrentlyOpInterface.
for (const Operation &op : getRegion().front().getOperations())
if (!isa<ParallelInsertSliceOp>(op))
return emitOpError(
"expected only scf.foreach_thread.parallel_insert_slice ops");
for (const Operation &op : getRegion().front().getOperations()) {
if (!isa<tensor::ParallelInsertSliceOp>(op)) {
return this->emitOpError("expected only ")
<< tensor::ParallelInsertSliceOp::getOperationName() << " ops";
}
}
return success();
}
@@ -1396,7 +1267,7 @@ OpResult PerformConcurrentlyOp::getParentResult(int64_t idx) {
SmallVector<Type> PerformConcurrentlyOp::getYieldedTypes() {
return llvm::to_vector<4>(
llvm::map_range(getYieldingOps(), [](Operation &op) {
auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(&op);
auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
return insertSliceOp ? insertSliceOp.yieldedType() : Type();
}));
}


@@ -927,7 +927,7 @@ static SmallVector<OpOperand *>
getInsertionDest(ForeachThreadOp foreachThreadOp) {
PerformConcurrentlyOp terminator = foreachThreadOp.getTerminator();
SmallVector<OpOperand *> result;
terminator.walk([&](ParallelInsertSliceOp insertOp) {
terminator.walk([&](tensor::ParallelInsertSliceOp insertOp) {
result.push_back(&insertOp->getOpOperand(1) /*dest*/);
});
return result;
@@ -1004,248 +1004,6 @@ struct PerformConcurrentlyOpInterface
}
};
/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
ExtractSliceOp st,
ParallelInsertSliceOp sti) {
if (!st || !sti)
return false;
if (st != sti &&
!state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
return false;
if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
return false;
return true;
}
/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
ParallelInsertSliceOp insertOp) {
auto condition = [&](Value val) {
if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
return true;
return false;
};
return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
condition);
}
/// Analysis of ParallelInsertSliceOp.
struct ParallelInsertSliceOpInterface
: public BufferizableOpInterface::ExternalModel<
ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
if (&opOperand != &op->getOpOperand(1) /*dest*/)
return {};
// ParallelInsertSliceOp itself has no results, query its tied op results.
auto insertOp = cast<ParallelInsertSliceOp>(op);
return {insertOp.getTiedOpResult()};
}
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return &opOperand == &op->getOpOperand(1) /*dest*/;
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const AnalysisState &state) const {
return BufferRelation::Equivalent;
}
LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
const AnalysisState &state) const {
// This interface method is overridden because we want to set a custom
// insertion point for tensor copies. They should be inserted right before
// the ForeachThreadOp. E.g.:
//
// %r0, %r1 = foreach_thead ... {
// ...
// perform_concurrently {
// parallel_insert_slice %a into %b ... {inplace = ["true", "true"]}
// parallel_insert_slice %c into %d ... {inplace = ["true", "false"]}
// }
// }
//
// After TensorCopyInsertion:
//
// %copy = bufferization.alloc_tensor() copy(%d)
// %r0, %r1 = foreach_thead ... {
// ...
// perform_concurrently {
// parallel_insert_slice %a into %b ...
// parallel_insert_slice %c into %copy ...
// }
// }
OpBuilder::InsertionGuard g(rewriter);
auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
ParallelCombiningOpInterface parallelCombiningParent =
parallelInsertSliceOp.getParallelCombiningParent();
Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
// Nothing to do if the destination tensor is inplace.
assert(state.isInPlace(op->getOpOperand(0) /*src*/) &&
"source is always in-place");
if (state.isInPlace(op->getOpOperand(1) /*dest*/))
return success();
// Find corresponding OpResult.
OpResult opResult = parallelInsertSliceOp.getTiedOpResult();
// Insert tensor allocation right before the ForeachThreadOp.
rewriter.setInsertionPoint(parallelIteratingOp);
bool isYielded = state.isTensorYielded(opResult);
FailureOr<Value> alloc = allocateTensorForShapedValue(
rewriter, op->getLoc(), parallelInsertSliceOp.getDest(),
/*escape=*/isYielded, state.getOptions());
if (failed(alloc))
return failure();
// Update destination operand.
rewriter.updateRootInPlace(parallelInsertSliceOp, [&]() {
parallelInsertSliceOp.getDestMutable().assign(*alloc);
});
return success();
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
OpBuilder::InsertionGuard g(rewriter);
auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
ParallelCombiningOpInterface parallelCombiningParent =
parallelInsertSliceOp.getParallelCombiningParent();
Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
// Get destination buffer.
FailureOr<Value> destBuffer =
getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
if (failed(destBuffer))
return failure();
// Bufferize the ParallelInsertSliceOp outside of `parallelCombiningParent`.
rewriter.setInsertionPoint(parallelCombiningParent);
FailureOr<Value> srcBuffer =
getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
if (failed(srcBuffer))
return failure();
Value subview = rewriter.create<memref::SubViewOp>(
parallelInsertSliceOp.getLoc(), *destBuffer,
parallelInsertSliceOp.getMixedOffsets(),
parallelInsertSliceOp.getMixedSizes(),
parallelInsertSliceOp.getMixedStrides());
// This memcpy will fold away if everything bufferizes in-place.
if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
*srcBuffer, subview)))
return failure();
// Replace all uses of parallelIteratingOp (just the corresponding result).
rewriter.setInsertionPointAfter(parallelIteratingOp);
Value toTensorOp =
rewriter.create<ToTensorOp>(parallelIteratingOp->getLoc(), *destBuffer);
// PerformConcurrentlyOp can have multiple ParallelInsertSliceOps.
SmallVector<OpOperand *> resultUses = llvm::to_vector(
llvm::map_range(parallelInsertSliceOp.getTiedOpResult().getUses(),
[](OpOperand &use) { return &use; }));
for (OpOperand *use : resultUses) {
rewriter.updateRootInPlace(use->getOwner(),
[&]() { use->set(toTensorOp); });
}
rewriter.eraseOp(op);
return success();
}
// TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
// the code.
bool isNotConflicting(Operation *op, OpOperand *uRead,
OpOperand *uConflictingWrite,
const AnalysisState &state) const {
Operation *readingOp = uRead->getOwner();
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
// Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
// uRead is an InsertSliceOp...
if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
// As an example, consider the following IR.
//
// %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
// %1 = linalg.fill %cst, %0 {inplace= [true] }
// %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
// {inplace= [true] }
// TODO: Use insertSliceOp.getDestOpOperand etc. when available.
if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
insertSliceOp))
// Case 1: The main insight is that InsertSliceOp reads only part of
// the destination tensor. The overwritten area is not read. If
// uConflictingWrite writes into exactly the memory location that is
// being read by uRead, this is not a conflict.
//
// In the above example:
// uRead = OpOperand 1 (%t) of tensor.insert_slice
// uConflictingWrite = OpOperand 1 (%0) of linalg.fill
//
// The read of %t does not conflict with the write of the FillOp
// (same aliases!) because the area that the FillOp operates on is
// exactly the one that is *not* read via %t.
return true;
if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
// Case 2: The read of the source tensor and the write to the dest
// tensor via an InsertSliceOp is not a conflict if the read is
// reading exactly that part of an equivalent tensor that the
// InsertSliceOp is writing.
//
// In the above example:
// uRead = OpOperand 0 (%1) of tensor.insert_slice
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
return true;
}
// If uConflictingWrite is an InsertSliceOp...
if (auto insertSliceOp =
dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
// As an example, consider the following IR.
//
// %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
// %1 = linalg.fill %cst, %0 {inplace= [true] }
// %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
// {inplace= [true] }
// %3 = vector.transfer_read %1, %cst
//
// In the above example:
// uRead = OpOperand 0 (%1) of vector.transfer_read
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
// lastWrite = %1
//
// This is not a conflict because the InsertSliceOp overwrites the
// memory segment of %1 with the exact same data. (Effectively, there
// is no memory write here.)
if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
state.areEquivalentBufferizedValues(uRead->get(),
insertSliceOp.getSource()) &&
hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
insertSliceOp))
return true;
return false;
}
};
} // namespace
} // namespace scf
} // namespace mlir
@@ -1257,8 +1015,6 @@ void mlir::scf::registerBufferizableOpInterfaceExternalModels(
ForOp::attachInterface<ForOpInterface>(*ctx);
IfOp::attachInterface<IfOpInterface>(*ctx);
ForeachThreadOp::attachInterface<ForeachThreadOpInterface>(*ctx);
ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
*ctx);
PerformConcurrentlyOp::attachInterface<PerformConcurrentlyOpInterface>(
*ctx);
WhileOp::attachInterface<WhileOpInterface>(*ctx);


@@ -26,6 +26,7 @@ add_mlir_dialect_library(MLIRTensorDialect
MLIRDialectUtils
MLIRIR
MLIRInferTypeOpInterface
MLIRParallelCombiningOpInterface
MLIRSideEffectInterfaces
MLIRSupport
MLIRViewLikeInterface


@@ -2179,6 +2179,137 @@ OpFoldResult PadOp::fold(ArrayRef<Attribute>) {
return {};
}
//===----------------------------------------------------------------------===//
// ParallelInsertSliceOp
//===----------------------------------------------------------------------===//
OpResult ParallelInsertSliceOp::getTiedOpResult() {
ParallelCombiningOpInterface parallelCombiningParent =
getParallelCombiningParent();
for (const auto &it :
llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
Operation &nextOp = it.value();
if (&nextOp == getOperation())
return parallelCombiningParent.getParentResult(it.index());
}
llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
}
// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
Value source, Value dest,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides,
ArrayRef<NamedAttribute> attrs) {
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
ShapedType::kDynamicStrideOrOffset);
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
ShapedType::kDynamicSize);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
ShapedType::kDynamicStrideOrOffset);
build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
dynamicStrides, b.getI64ArrayAttr(staticOffsets),
b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
result.addAttributes(attrs);
}
// Build a ParallelInsertSliceOp with dynamic entries.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
Value source, Value dest, ValueRange offsets,
ValueRange sizes, ValueRange strides,
ArrayRef<NamedAttribute> attrs) {
SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}
LogicalResult ParallelInsertSliceOp::verify() {
if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
return this->emitError("expected ParallelCombiningOpInterface parent, got:")
<< *(getOperation()->getParentOp());
return success();
}
namespace {
/// Pattern to rewrite a parallel_insert_slice op with constant arguments.
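///
/// For illustration, a sketch of the rewrite this pattern performs, mirroring
/// the canonicalization test added in this commit: constant `index` operands
/// are folded into the static offsets/sizes/strides attributes.
///
/// ```
/// // Before (%c0 and %c1 are arith.constant index values):
/// tensor.parallel_insert_slice %arg0 into %arg1[%tidx, %c0] [%c1, 5] [%c1, %c1]
///   : tensor<?x?xf32> into tensor<?x?xf32>
///
/// // After folding the constants into the op's static attributes:
/// tensor.parallel_insert_slice %arg0 into %arg1[%tidx, 0] [1, 5] [1, 1]
///   : tensor<?x?xf32> into tensor<?x?xf32>
/// ```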
class ParallelInsertSliceOpConstantArgumentFolder final
: public OpRewritePattern<ParallelInsertSliceOp> {
public:
using OpRewritePattern<ParallelInsertSliceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp,
PatternRewriter &rewriter) const override {
// No constant operand, just return.
if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
return matchPattern(operand, matchConstantIndex());
}))
return failure();
// At least one of offsets/sizes/strides is a new constant.
// Form the new list of operands and constant attributes from the
// existing.
SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
// Create the new op in canonical form.
rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
insertSliceOp, insertSliceOp.getSource(), insertSliceOp.getDest(),
mixedOffsets, mixedSizes, mixedStrides);
return success();
}
};
} // namespace
/// Fold a parallel_insert_slice source coming from a tensor.cast op.
///
/// Example:
/// ```
/// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
/// %1 = compute_some_tensor() : tensor<64xf32>
/// %2 = tensor.cast %1 : tensor<64xf32> to tensor<?xf32>
/// scf.foreach_thread.perform_concurrently {
/// scf.foreach_thread.parallel_insert_slice %2 into %out[...] [64] [1] :
/// tensor<?xf32> into tensor<128xf32>
/// }
/// }
/// ```
///
/// is folded into:
/// ```
/// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
/// %1 = compute_some_tensor() : tensor<64xf32>
/// scf.foreach_thread.perform_concurrently {
/// scf.foreach_thread.parallel_insert_slice %1 into %out[...] [64] [1] :
/// tensor<64xf32> into tensor<128xf32>
/// }
/// }
/// ```
LogicalResult
ParallelInsertSliceOp::fold(ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
auto sourceCast = getSource().getDefiningOp<tensor::CastOp>();
if (!sourceCast)
return failure();
getSourceMutable().assign(sourceCast.getSource());
return success();
}
void ParallelInsertSliceOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<ParallelInsertSliceOpConstantArgumentFolder>(context);
}
//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//


@@ -810,6 +810,248 @@ struct ReshapeOpInterface
}
};
/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
ExtractSliceOp st,
ParallelInsertSliceOp sti) {
if (!st || !sti)
return false;
if (st != sti &&
!state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
return false;
if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
return false;
return true;
}
/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
ParallelInsertSliceOp insertOp) {
auto condition = [&](Value val) {
if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
return true;
return false;
};
return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
condition);
}
/// Analysis of ParallelInsertSliceOp.
struct ParallelInsertSliceOpInterface
: public BufferizableOpInterface::ExternalModel<
ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
if (&opOperand != &op->getOpOperand(1) /*dest*/)
return {};
// ParallelInsertSliceOp itself has no results, query its tied op results.
auto insertOp = cast<ParallelInsertSliceOp>(op);
return {insertOp.getTiedOpResult()};
}
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return &opOperand == &op->getOpOperand(1) /*dest*/;
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const AnalysisState &state) const {
return BufferRelation::Equivalent;
}
LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
const AnalysisState &state) const {
// This interface method is overridden because we want to set a custom
// insertion point for tensor copies. They should be inserted right before
// the ForeachThreadOp. E.g.:
//
// %r0, %r1 = foreach_thread ... {
// ...
// perform_concurrently {
// parallel_insert_slice %a into %b ... {inplace = ["true", "true"]}
// parallel_insert_slice %c into %d ... {inplace = ["true", "false"]}
// }
// }
//
// After TensorCopyInsertion:
//
// %copy = bufferization.alloc_tensor() copy(%d)
// %r0, %r1 = foreach_thread ... {
// ...
// perform_concurrently {
// parallel_insert_slice %a into %b ...
// parallel_insert_slice %c into %copy ...
// }
// }
OpBuilder::InsertionGuard g(rewriter);
auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
ParallelCombiningOpInterface parallelCombiningParent =
parallelInsertSliceOp.getParallelCombiningParent();
Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
// Nothing to do if the destination tensor is inplace.
assert(state.isInPlace(op->getOpOperand(0) /*src*/) &&
"source is always in-place");
if (state.isInPlace(op->getOpOperand(1) /*dest*/))
return success();
// Find corresponding OpResult.
OpResult opResult = parallelInsertSliceOp.getTiedOpResult();
// Insert tensor allocation right before the ForeachThreadOp.
rewriter.setInsertionPoint(parallelIteratingOp);
bool isYielded = state.isTensorYielded(opResult);
FailureOr<Value> alloc = allocateTensorForShapedValue(
rewriter, op->getLoc(), parallelInsertSliceOp.getDest(),
/*escape=*/isYielded, state.getOptions());
if (failed(alloc))
return failure();
// Update destination operand.
rewriter.updateRootInPlace(parallelInsertSliceOp, [&]() {
parallelInsertSliceOp.getDestMutable().assign(*alloc);
});
return success();
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
OpBuilder::InsertionGuard g(rewriter);
auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
ParallelCombiningOpInterface parallelCombiningParent =
parallelInsertSliceOp.getParallelCombiningParent();
Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
// Get destination buffer.
FailureOr<Value> destBuffer =
getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
if (failed(destBuffer))
return failure();
// Bufferize the ParallelInsertSliceOp outside of `parallelCombiningParent`.
rewriter.setInsertionPoint(parallelCombiningParent);
FailureOr<Value> srcBuffer =
getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
if (failed(srcBuffer))
return failure();
Value subview = rewriter.create<memref::SubViewOp>(
parallelInsertSliceOp.getLoc(), *destBuffer,
parallelInsertSliceOp.getMixedOffsets(),
parallelInsertSliceOp.getMixedSizes(),
parallelInsertSliceOp.getMixedStrides());
// This memcpy will fold away if everything bufferizes in-place.
if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
*srcBuffer, subview)))
return failure();
// Replace all uses of parallelIteratingOp (just the corresponding result).
rewriter.setInsertionPointAfter(parallelIteratingOp);
Value toTensorOp =
rewriter.create<ToTensorOp>(parallelIteratingOp->getLoc(), *destBuffer);
// PerformConcurrentlyOp can have multiple ParallelInsertSliceOps.
SmallVector<OpOperand *> resultUses = llvm::to_vector(
llvm::map_range(parallelInsertSliceOp.getTiedOpResult().getUses(),
[](OpOperand &use) { return &use; }));
for (OpOperand *use : resultUses) {
rewriter.updateRootInPlace(use->getOwner(),
[&]() { use->set(toTensorOp); });
}
rewriter.eraseOp(op);
return success();
}
// TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
// the code.
bool isNotConflicting(Operation *op, OpOperand *uRead,
OpOperand *uConflictingWrite,
const AnalysisState &state) const {
Operation *readingOp = uRead->getOwner();
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
// Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
// uRead is an InsertSliceOp...
if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
// As an example, consider the following IR.
//
// %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
// %1 = linalg.fill %cst, %0 {inplace= [true] }
// %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
// {inplace= [true] }
// TODO: Use insertSliceOp.getDestOpOperand etc. when available.
if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
insertSliceOp))
// Case 1: The main insight is that InsertSliceOp reads only part of
// the destination tensor. The overwritten area is not read. If
// uConflictingWrite writes into exactly the memory location that is
// being read by uRead, this is not a conflict.
//
// In the above example:
// uRead = OpOperand 1 (%t) of tensor.insert_slice
// uConflictingWrite = OpOperand 1 (%0) of linalg.fill
//
// The read of %t does not conflict with the write of the FillOp
// (same aliases!) because the area that the FillOp operates on is
// exactly the one that is *not* read via %t.
return true;
if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
// Case 2: The read of the source tensor and the write to the dest
// tensor via an InsertSliceOp is not a conflict if the read is
// reading exactly that part of an equivalent tensor that the
// InsertSliceOp is writing.
//
// In the above example:
// uRead = OpOperand 0 (%1) of tensor.insert_slice
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
return true;
}
// If uConflictingWrite is an InsertSliceOp...
if (auto insertSliceOp =
dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
// As an example, consider the following IR.
//
// %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
// %1 = linalg.fill %cst, %0 {inplace= [true] }
// %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
// {inplace= [true] }
// %3 = vector.transfer_read %1, %cst
//
// In the above example:
// uRead = OpOperand 0 (%1) of vector.transfer_read
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
// lastWrite = %1
//
// This is not a conflict because the InsertSliceOp overwrites the
// memory segment of %1 with the exact same data. (Effectively, there
// is no memory write here.)
if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
state.areEquivalentBufferizedValues(uRead->get(),
insertSliceOp.getSource()) &&
hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
insertSliceOp))
return true;
return false;
}
};
} // namespace
} // namespace tensor
} // namespace mlir
@@ -827,6 +1069,8 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
InsertOp::attachInterface<InsertOpInterface>(*ctx);
InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
*ctx);
RankOp::attachInterface<RankOpInterface>(*ctx);
ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
});


@@ -1457,28 +1457,3 @@ func.func @func_execute_region_elim_multi_yield() {
// CHECK: ^[[bb3]](%[[z:.+]]: i64):
// CHECK: "test.bar"(%[[z]])
// CHECK: return
// -----
// CHECK-LABEL: func.func @canonicalize_parallel_insert_slice_indices(
// CHECK-SAME: %[[arg0:[0-9a-z]*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[arg1:[0-9a-z]*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[num_threads:[0-9a-z]*]]: index
func.func @canonicalize_parallel_insert_slice_indices(
%arg0 : tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
%num_threads : index) -> tensor<?x?xf32>
{
%cst = arith.constant 4.200000e+01 : f32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
// CHECK: scf.foreach_thread (%[[tidx:[0-9a-z]*]]) in (%[[num_threads]]) -> (tensor<?x?xf32>) {
// CHECK-NEXT: scf.foreach_thread.perform_concurrently {
// CHECK-NEXT: scf.foreach_thread.parallel_insert_slice %[[arg0]] into %[[arg1]][%[[tidx]], 0] [1, 5] [1, 1]
%2 = scf.foreach_thread (%tidx) in (%num_threads) -> (tensor<?x?xf32>) {
scf.foreach_thread.perform_concurrently {
scf.foreach_thread.parallel_insert_slice %arg0 into %arg1[%tidx, %c0] [%c1, 5] [%c1, %c1] : tensor<?x?xf32> into tensor<?x?xf32>
}
}
return %2 : tensor<?x?xf32>
}


@@ -1,36 +1,36 @@
// RUN: mlir-opt %s -scf-for-loop-canonicalization -canonicalize | FileCheck %s
// RUN: mlir-opt %s -scf-for-loop-canonicalization | FileCheck %s
func.func @reduce() -> tensor<128xf32> {
func.func @reduce() {
// CHECK: %[[C64:.*]] = arith.constant 64 : index
%c2 = arith.constant 2 : index
%cst = arith.constant dense<1.000000e+00> : tensor<1x128x384xf32>
%cst_0 = arith.constant -0.000000e+00 : f32
%0 = linalg.init_tensor [128, 384] : tensor<128x384xf32>
%1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<128x384xf32>) -> tensor<128x384xf32>
%2 = linalg.init_tensor [128] : tensor<128xf32>
%3 = linalg.fill ins(%cst_0 : f32) outs(%2 : tensor<128xf32>) -> tensor<128xf32>
%4 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
%0 = memref.alloc() : memref<128x384xf32>
linalg.fill ins(%cst_0 : f32) outs(%0 : memref<128x384xf32>)
%2 = memref.alloc() : memref<128xf32>
linalg.fill ins(%cst_0 : f32) outs(%2 : memref<128xf32>)
scf.foreach_thread (%arg0) in (%c2) {
%7 = affine.min affine_map<(d0) -> (d0 * -64 + 128, 64)>(%arg0)
%8 = affine.max affine_map<(d0) -> (0, d0)>(%7)
%9 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg0)
%10 = affine.min affine_map<(d0, d1) -> (d1 * -64 + 128, d0)>(%8, %arg0)
// CHECK: tensor.extract_slice %{{.*}}[%{{.*}}, 0] [64, 384] [1, 1] : tensor<128x384xf32> to tensor<64x384xf32>
// CHECK: tensor.extract_slice %{{.*}}[%{{.*}}] [64] [1] : tensor<128xf32> to tensor<64xf32>
%11 = tensor.extract_slice %1[%9, 0] [%10, 384] [1, 1] : tensor<128x384xf32> to tensor<?x384xf32>
%12 = tensor.extract_slice %3[%9] [%10] [1] : tensor<128xf32> to tensor<?xf32>
// CHECK: memref.subview %{{.*}}[%{{.*}}, 0] [%[[C64]], 384] [1, 1] : memref<128x384xf32> to memref<?x384xf32, {{.*}}>
// CHECK: memref.subview %{{.*}}[%{{.*}}] [%[[C64]]] [1] : memref<128xf32> to memref<?xf32, {{.*}}>
%11 = memref.subview %0[%9, 0] [%10, 384] [1, 1] :
memref<128x384xf32> to memref<?x384xf32, affine_map<(d0, d1)[s0] -> (d0 * 384 + s0 + d1)>>
%12 = memref.subview %2[%9] [%10] [1] :
memref<128xf32> to memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
// CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<64x384xf32>) outs(%{{.*}} : tensor<64xf32>) {
%13 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%11 : tensor<?x384xf32>) outs(%12 : tensor<?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%14 = arith.addf %arg1, %arg2 : f32
linalg.yield %14 : f32
} -> tensor<?xf32>
// CHECK-NOT: tensor.cast
// CHECK: scf.foreach_thread.parallel_insert_slice %{{.*}} into %{{.*}}[%{{.*}}] [64] [1] : tensor<64xf32> into tensor<128xf32>
scf.foreach_thread.perform_concurrently {
scf.foreach_thread.parallel_insert_slice %13 into %3[%9] [%10] [1] : tensor<?xf32> into tensor<128xf32>
}
// CHECK: linalg.generic {{.*}} ins(%{{.*}} : memref<?x384xf32, {{.*}}>) outs(%{{.*}} : memref<?xf32, {{.*}}>)
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0)>],
iterator_types = ["parallel", "reduction"]}
ins(%11 : memref<?x384xf32, affine_map<(d0, d1)[s0] -> (d0 * 384 + s0 + d1)>>)
outs(%12 : memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>>) {
^bb0(%arg1: f32, %arg2: f32):
%14 = arith.addf %arg1, %arg2 : f32
linalg.yield %14 : f32
}
}
return %4 : tensor<128xf32>
return
}


@@ -531,7 +531,7 @@ func.func @wrong_num_results(%in: tensor<100xf32>, %out: tensor<100xf32>) {
%result:2 = scf.foreach_thread (%thread_idx) in (%num_threads) -> (tensor<100xf32>, tensor<100xf32>) {
%1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
scf.foreach_thread.perform_concurrently {
scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
tensor<1xf32> into tensor<100xf32>
}
}
@@ -548,7 +548,7 @@ func.func @wrong_type_result(%in: tensor<100xf32>, %out: tensor<100xf32>) {
%result = scf.foreach_thread (%thread_idx) in (%num_threads) -> (tensor<?xf32>) {
%1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
scf.foreach_thread.perform_concurrently {
scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
tensor<1xf32> into tensor<100xf32>
}
}
@@ -563,9 +563,9 @@ func.func @wrong_terminator_op(%in: tensor<100xf32>, %out: tensor<100xf32>) {
%result = scf.foreach_thread (%thread_idx) in (%num_threads) -> (tensor<100xf32>) {
%1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
// expected-error @+1 {{expected only scf.foreach_thread.parallel_insert_slice ops}}
// expected-error @+1 {{expected only tensor.parallel_insert_slice ops}}
scf.foreach_thread.perform_concurrently {
scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
tensor<1xf32> into tensor<100xf32>
%0 = arith.constant 1: index
}


@@ -124,10 +124,10 @@ func.func @scf_foreach_thread_out_of_place(%in: tensor<100xf32>,
%result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<100xf32> {
// CHECK: tensor.extract_slice
// CHECK: scf.foreach_thread.perform_concurrently
// CHECK: scf.foreach_thread.parallel_insert_slice %{{.*}} into %[[alloc]]
// CHECK: tensor.parallel_insert_slice %{{.*}} into %[[alloc]]
%1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
scf.foreach_thread.perform_concurrently {
scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
tensor<1xf32> into tensor<100xf32>
}
// CHECK: } {thread_dim_mapping = [5]}


@@ -537,7 +537,7 @@ func.func @parallel_insert_slice_no_conflict(
// CHECK-NOT: scf.foreach_thread.perform_concurrently
// CHECK-NOT: parallel_insert_slice
scf.foreach_thread.perform_concurrently {
scf.foreach_thread.parallel_insert_slice %8 into %arg2[5] [%idx] [%c1] :
tensor.parallel_insert_slice %8 into %arg2[5] [%idx] [%c1] :
tensor<?xf32> into tensor<?xf32>
}
}
@@ -589,7 +589,7 @@ func.func @parallel_insert_slice_with_conflict(
// CHECK-NOT: scf.foreach_thread.perform_concurrently
// CHECK-NOT: parallel_insert_slice
scf.foreach_thread.perform_concurrently {
scf.foreach_thread.parallel_insert_slice %8 into %arg2[5] [%idx] [%c1] :
tensor.parallel_insert_slice %8 into %arg2[5] [%idx] [%c1] :
tensor<?xf32> into tensor<?xf32>
}
}
@@ -627,7 +627,7 @@ func.func @matmul(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>, %arg2: tensor<
// CHECK: linalg.matmul ins({{.*}}memref<4x8xf32, #[[$DYN_LAYOUT_MAP]]>, memref<8x4xf32, #[[$DYN_LAYOUT_MAP]]>) outs({{.*}} : memref<4x4xf32, #[[$DYN_LAYOUT_MAP]]>)
%8 = linalg.matmul ins(%3, %6 : tensor<4x8xf32>, tensor<8x4xf32>) outs(%7 : tensor<4x4xf32>) -> tensor<4x4xf32>
scf.foreach_thread.perform_concurrently {
scf.foreach_thread.parallel_insert_slice %8 into %arg2[%1, %4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x8xf32>
tensor.parallel_insert_slice %8 into %arg2[%1, %4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x8xf32>
}
}
return %0 : tensor<8x8xf32>


@@ -319,14 +319,14 @@ func.func @simple_example(%in: tensor<100xf32>, %out: tensor<100xf32>) {
// CHECK: scf.foreach_thread
// CHECK-NEXT: tensor.extract_slice
// CHECK-NEXT: scf.foreach_thread.perform_concurrently
// CHECK-NEXT: scf.foreach_thread.parallel_insert_slice
// CHECK-NEXT: tensor.parallel_insert_slice
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return
%result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<100xf32> {
%1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
scf.foreach_thread.perform_concurrently {
scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
tensor<1xf32> into tensor<100xf32>
}
}


@@ -1425,3 +1425,28 @@ func.func @cast_extract_slice_rank_reduce(%arg0 : tensor<128x512xf32>, %s : inde
// CHECK: return %[[E]] : tensor<16xf32>
return %1 : tensor<16xf32>
}
// -----
// CHECK-LABEL: func.func @canonicalize_parallel_insert_slice_indices(
// CHECK-SAME: %[[arg0:[0-9a-z]*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[arg1:[0-9a-z]*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[num_threads:[0-9a-z]*]]: index
func.func @canonicalize_parallel_insert_slice_indices(
%arg0 : tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
%num_threads : index) -> tensor<?x?xf32>
{
%cst = arith.constant 4.200000e+01 : f32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
// CHECK: scf.foreach_thread (%[[tidx:[0-9a-z]*]]) in (%[[num_threads]]) -> (tensor<?x?xf32>) {
// CHECK-NEXT: scf.foreach_thread.perform_concurrently {
// CHECK-NEXT: tensor.parallel_insert_slice %[[arg0]] into %[[arg1]][%[[tidx]], 0] [1, 5] [1, 1]
%2 = scf.foreach_thread (%tidx) in (%num_threads) -> (tensor<?x?xf32>) {
scf.foreach_thread.perform_concurrently {
tensor.parallel_insert_slice %arg0 into %arg1[%tidx, %c0] [%c1, 5] [%c1, %c1] : tensor<?x?xf32> into tensor<?x?xf32>
}
}
return %2 : tensor<?x?xf32>
}


@@ -4909,6 +4909,7 @@ td_library(
":ControlFlowInterfacesTdFiles",
":InferTypeOpInterfaceTdFiles",
":OpBaseTdFiles",
":ParallelCombiningOpInterfaceTdFiles",
":SideEffectInterfacesTdFiles",
":TilingInterfaceTdFiles",
":ViewLikeInterfaceTdFiles",
@@ -4965,6 +4966,7 @@ cc_library(
":DialectUtils",
":IR",
":InferTypeOpInterface",
":ParallelCombiningOpInterface",
":SideEffectInterfaces",
":TensorOpsIncGen",
":TilingInterface",