[mlir][Tensor] Move ParallelInsertSlice to the tensor dialect
This is mostly NFC and will allow tensor.parallel_insert_slice to gain rank-reducing semantics by reusing the vast majority of the tensor.insert_slice impl.

Depends on D128857

Differential Revision: https://reviews.llvm.org/D128920
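At the IR level the move amounts to an op name change inside the `scf.foreach_thread` terminator. A minimal before/after sketch (the `test.compute` op, values, and shapes are placeholders for illustration, not taken from the patch):

```mlir
// Before this commit: the op lived in the scf dialect.
%r = scf.foreach_thread (%tid) in (%n) -> (tensor<128xf32>) {
  %slice = "test.compute"() : () -> tensor<64xf32>
  scf.foreach_thread.perform_concurrently {
    scf.foreach_thread.parallel_insert_slice %slice into %dest[%tid] [64] [1]
      : tensor<64xf32> into tensor<128xf32>
  }
}

// After: the same update is expressed with the tensor dialect op.
%r = scf.foreach_thread (%tid) in (%n) -> (tensor<128xf32>) {
  %slice = "test.compute"() : () -> tensor<64xf32>
  scf.foreach_thread.perform_concurrently {
    tensor.parallel_insert_slice %slice into %dest[%tid] [64] [1]
      : tensor<64xf32> into tensor<128xf32>
  }
}
```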
@@ -503,115 +503,6 @@ def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
   }];
 }
 
-//===----------------------------------------------------------------------===//
-// ParallelInsertSliceOp
-//===----------------------------------------------------------------------===//
-
-// TODO: Implement PerformConcurrentlyOpInterface.
-def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
-    AttrSizedOperandSegments,
-    OffsetSizeAndStrideOpInterface,
-    // TODO: Cannot use an interface here atm, verify this manually for now.
-    // HasParent<"ParallelCombiningOpInterface">
-  ]> {
-  let summary = [{
-    Specify the tensor slice update of a single thread within the terminator of
-    an `scf.foreach_thread`.
-  }];
-  let description = [{
-    The parent `scf.foreach_thread` returns values that are formed by
-    aggregating the actions of all the ops contained within the
-    `perform_concurrently` terminator of all the threads, in some unspecified
-    order.
-    The `scf.foreach_thread.parallel_insert_slice` is one such op allowed in
-    the `scf.foreach_thread.perform_concurrently` terminator.
-
-    Conflicting writes result in undefined semantics, in that the indices
-    written to by multiple parallel updates might contain data from any of the
-    updates, or even a malformed bit pattern.
-
-    If an index is updated exactly once, the value contained at that index in
-    the resulting tensor will be equal to the value at a corresponding index
-    of a slice that was used for the update. If an index is not updated at
-    all, its value will be equal to the one in the original tensor.
-
-    This op does not create a new value, which allows maintaining a clean
-    separation between the subset and full tensor.
-    Note that we cannot mark this operation as pure (NoSideEffects), even
-    though it has no side effects, because it will get DCEd during
-    canonicalization.
-  }];
-
-  let arguments = (ins
-    AnyRankedTensor:$source,
-    AnyRankedTensor:$dest,
-    Variadic<Index>:$offsets,
-    Variadic<Index>:$sizes,
-    Variadic<Index>:$strides,
-    I64ArrayAttr:$static_offsets,
-    I64ArrayAttr:$static_sizes,
-    I64ArrayAttr:$static_strides
-  );
-  let assemblyFormat = [{
-    $source `into` $dest ``
-    custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
-    custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
-    custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
-    attr-dict `:` type($source) `into` type($dest)
-  }];
-
-  let extraClassDeclaration = [{
-    ::mlir::Operation::operand_range offsets() { return getOffsets(); }
-    ::mlir::Operation::operand_range sizes() { return getSizes(); }
-    ::mlir::Operation::operand_range strides() { return getStrides(); }
-    ::mlir::ArrayAttr static_offsets() { return getStaticOffsets(); }
-    ::mlir::ArrayAttr static_sizes() { return getStaticSizes(); }
-    ::mlir::ArrayAttr static_strides() { return getStaticStrides(); }
-
-    Type yieldedType() { return getDest().getType(); }
-
-    RankedTensorType getSourceType() {
-      return getSource().getType().cast<RankedTensorType>();
-    }
-
-    ParallelCombiningOpInterface getParallelCombiningParent() {
-      return dyn_cast<ParallelCombiningOpInterface>(
-          getOperation()->getParentOp());
-    }
-
-    /// Return the expected rank of each of the `static_offsets`,
-    /// `static_sizes` and `static_strides` attributes.
-    std::array<unsigned, 3> getArrayAttrMaxRanks() {
-      unsigned rank = getSourceType().getRank();
-      return {rank, rank, rank};
-    }
-
-    /// Return the number of leading operands before `offsets`, `sizes` and
-    /// `strides` operands.
-    static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
-
-    /// Return the OpResult of the enclosing ForeachThreadOp that corresponds
-    /// to this ParallelInsertSliceOp.
-    OpResult getTiedOpResult();
-  }];
-
-  let builders = [
-    // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
-    OpBuilder<(ins "Value":$source, "Value":$dest,
-      "ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
-      "ArrayRef<OpFoldResult>":$strides,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
-
-    // Build a ParallelInsertSliceOp with dynamic entries.
-    OpBuilder<(ins "Value":$source, "Value":$dest,
-      "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
-  ];
-
-  let hasCanonicalizer = 1;
-  let hasFolder = 1;
-  let hasVerifier = 1;
-}
-
 //===----------------------------------------------------------------------===//
 // IfOp
 //===----------------------------------------------------------------------===//
@@ -17,6 +17,7 @@
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
@@ -13,6 +13,7 @@ include "mlir/Dialect/Tensor/IR/TensorBase.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/ParallelCombiningOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/TilingInterface.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
@@ -1051,6 +1052,110 @@ def Tensor_PadOp : Tensor_Op<"pad", [AttrSizedOperandSegments, NoSideEffect,
   let hasRegionVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// ParallelInsertSliceOp
+//===----------------------------------------------------------------------===//
+
+// TODO: Implement PerformConcurrentlyOpInterface.
+def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
+    AttrSizedOperandSegments,
+    OffsetSizeAndStrideOpInterface,
+    // TODO: Cannot use an interface here atm, verify this manually for now.
+    // HasParent<"ParallelCombiningOpInterface">
+  ]> {
+  let summary = [{
+    Specify the tensor slice update of a single thread of a parent
+    ParallelCombiningOpInterface op.
+  }];
+  let description = [{
+    The `parallel_insert_slice` yields a subset tensor value to its parent
+    ParallelCombiningOpInterface. These subset tensor values are aggregated
+    in some unspecified order into a full tensor value returned by the parent
+    parallel iterating op.
+    The `parallel_insert_slice` is one such op allowed in the
+    ParallelCombiningOpInterface op.
+
+    Conflicting writes result in undefined semantics, in that the indices
+    written to by multiple parallel updates might contain data from any of
+    the updates, or even a malformed bit pattern.
+
+    If an index is updated exactly once, the value contained at that index
+    in the resulting tensor will be equal to the value at a corresponding
+    index of a slice that was used for the update. If an index is not updated
+    at all, its value will be equal to the one in the original tensor.
+
+    This op does not create a new value, which allows maintaining a clean
+    separation between the subset and full tensor.
+
+    Note that we cannot mark this operation as pure (NoSideEffects), even
+    though it has no side effects, because it will get DCEd during
+    canonicalization.
+  }];
+
+  let arguments = (ins
+    AnyRankedTensor:$source,
+    AnyRankedTensor:$dest,
+    Variadic<Index>:$offsets,
+    Variadic<Index>:$sizes,
+    Variadic<Index>:$strides,
+    I64ArrayAttr:$static_offsets,
+    I64ArrayAttr:$static_sizes,
+    I64ArrayAttr:$static_strides
+  );
+  let assemblyFormat = [{
+    $source `into` $dest ``
+    custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
+    custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
+    custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
+    attr-dict `:` type($source) `into` type($dest)
+  }];
+
+  let extraClassDeclaration = [{
+    Type yieldedType() { return getDest().getType(); }
+
+    RankedTensorType getSourceType() {
+      return getSource().getType().cast<RankedTensorType>();
+    }
+
+    ParallelCombiningOpInterface getParallelCombiningParent() {
+      return dyn_cast<ParallelCombiningOpInterface>(
+          getOperation()->getParentOp());
+    }
+
+    /// Return the expected rank of each of the `static_offsets`,
+    /// `static_sizes` and `static_strides` attributes.
+    std::array<unsigned, 3> getArrayAttrMaxRanks() {
+      unsigned rank = getSourceType().getRank();
+      return {rank, rank, rank};
+    }
+
+    /// Return the number of leading operands before `offsets`, `sizes` and
+    /// `strides` operands.
+    static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
+
+    /// Return the OpResult of the enclosing ForeachThreadOp that corresponds
+    /// to this ParallelInsertSliceOp.
+    OpResult getTiedOpResult();
+  }];
+
+  let builders = [
+    // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
+    OpBuilder<(ins "Value":$source, "Value":$dest,
+      "ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
+      "ArrayRef<OpFoldResult>":$strides,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+
+    // Build a ParallelInsertSliceOp with dynamic entries.
+    OpBuilder<(ins "Value":$source, "Value":$dest,
+      "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
+  ];
+
+  let hasCanonicalizer = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // SplatOp
 //===----------------------------------------------------------------------===//
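For reference, an instance of the assembly format declared in the hunk above would print along these lines; the values and shapes are invented for illustration:

```mlir
// Offsets, sizes, and strides each print as a bracketed list that may mix
// SSA values and integer constants, per the custom<...> directives.
tensor.parallel_insert_slice %source into %dest[%offset] [16] [1]
    : tensor<16xf32> into tensor<256xf32>
```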
@@ -13,7 +13,6 @@ add_mlir_dialect_library(MLIRSCFDialect
   MLIRControlFlowDialect
   MLIRIR
   MLIRLoopLikeInterface
-  MLIRParallelCombiningOpInterface
   MLIRSideEffectInterfaces
 )
@@ -1211,137 +1211,6 @@ ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) {
   return dyn_cast<ForeachThreadOp>(containingOp);
 }
 
-//===----------------------------------------------------------------------===//
-// ParallelInsertSliceOp
-//===----------------------------------------------------------------------===//
-
-OpResult ParallelInsertSliceOp::getTiedOpResult() {
-  ParallelCombiningOpInterface parallelCombiningParent =
-      getParallelCombiningParent();
-  for (const auto &it :
-       llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
-    Operation &nextOp = it.value();
-    if (&nextOp == getOperation())
-      return parallelCombiningParent.getParentResult(it.index());
-  }
-  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
-}
-
-// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
-void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
-                                  Value source, Value dest,
-                                  ArrayRef<OpFoldResult> offsets,
-                                  ArrayRef<OpFoldResult> sizes,
-                                  ArrayRef<OpFoldResult> strides,
-                                  ArrayRef<NamedAttribute> attrs) {
-  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
-  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
-  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
-                             ShapedType::kDynamicStrideOrOffset);
-  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
-                             ShapedType::kDynamicSize);
-  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
-                             ShapedType::kDynamicStrideOrOffset);
-  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
-        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
-        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
-  result.addAttributes(attrs);
-}
-
-// Build a ParallelInsertSliceOp with dynamic entries.
-void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
-                                  Value source, Value dest, ValueRange offsets,
-                                  ValueRange sizes, ValueRange strides,
-                                  ArrayRef<NamedAttribute> attrs) {
-  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
-      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
-  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
-      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
-  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
-      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
-  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
-}
-
-LogicalResult ParallelInsertSliceOp::verify() {
-  if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
-    return this->emitError("expected ParallelCombiningOpInterface parent, got:")
-           << *(getOperation()->getParentOp());
-  return success();
-}
-
-namespace {
-/// Pattern to rewrite a parallel_insert_slice op with constant arguments.
-class ParallelInsertSliceOpConstantArgumentFolder final
-    : public OpRewritePattern<ParallelInsertSliceOp> {
-public:
-  using OpRewritePattern<ParallelInsertSliceOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp,
-                                PatternRewriter &rewriter) const override {
-    // No constant operand, just return.
-    if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
-          return matchPattern(operand, matchConstantIndex());
-        }))
-      return failure();
-
-    // At least one of offsets/sizes/strides is a new constant.
-    // Form the new list of operands and constant attributes from the
-    // existing.
-    SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
-    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
-    SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
-    canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
-    canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
-    canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
-
-    // Create the new op in canonical form.
-    rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
-        insertSliceOp, insertSliceOp.getSource(), insertSliceOp.getDest(),
-        mixedOffsets, mixedSizes, mixedStrides);
-    return success();
-  }
-};
-} // namespace
-
-/// Fold a parallel_insert_slice source coming from a tensor.cast op.
-///
-/// Example:
-/// ```
-/// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
-///   %1 = compute_some_tensor() : tensor<64xf32>
-///   %2 = tensor.cast %1 : tensor<64xf32> to tensor<?xf32>
-///   scf.foreach_thread.perform_concurrently {
-///     scf.foreach_thread.parallel_insert_slice %2 into %out[...] [64] [1] :
-///       tensor<?xf32> into tensor<128xf32>
-///   }
-/// }
-/// ```
-///
-/// is folded into:
-/// ```
-/// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
-///   %1 = compute_some_tensor() : tensor<64xf32>
-///   scf.foreach_thread.perform_concurrently {
-///     scf.foreach_thread.parallel_insert_slice %1 into %out[...] [64] [1] :
-///       tensor<64xf32> into tensor<128xf32>
-///   }
-/// }
-/// ```
-LogicalResult
-ParallelInsertSliceOp::fold(ArrayRef<Attribute> operands,
-                            SmallVectorImpl<OpFoldResult> &results) {
-  auto sourceCast = getSource().getDefiningOp<tensor::CastOp>();
-  if (!sourceCast)
-    return failure();
-  getSourceMutable().assign(sourceCast.getSource());
-  return success();
-}
-
-void ParallelInsertSliceOp::getCanonicalizationPatterns(
-    RewritePatternSet &results, MLIRContext *context) {
-  results.add<ParallelInsertSliceOpConstantArgumentFolder>(context);
-}
-
 //===----------------------------------------------------------------------===//
 // PerformConcurrentlyOp
 //===----------------------------------------------------------------------===//
@@ -1355,10 +1224,12 @@ void PerformConcurrentlyOp::build(OpBuilder &b, OperationState &result) {
 
 LogicalResult PerformConcurrentlyOp::verify() {
   // TODO: PerformConcurrentlyOpInterface.
-  for (const Operation &op : getRegion().front().getOperations())
-    if (!isa<ParallelInsertSliceOp>(op))
-      return emitOpError(
-          "expected only scf.foreach_thread.parallel_insert_slice ops");
+  for (const Operation &op : getRegion().front().getOperations()) {
+    if (!isa<tensor::ParallelInsertSliceOp>(op)) {
+      return this->emitOpError("expected only ")
+             << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
+    }
+  }
   return success();
 }
@@ -1396,7 +1267,7 @@ OpResult PerformConcurrentlyOp::getParentResult(int64_t idx) {
 SmallVector<Type> PerformConcurrentlyOp::getYieldedTypes() {
   return llvm::to_vector<4>(
       llvm::map_range(getYieldingOps(), [](Operation &op) {
-        auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(&op);
+        auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
         return insertSliceOp ? insertSliceOp.yieldedType() : Type();
       }));
 }
@@ -927,7 +927,7 @@ static SmallVector<OpOperand *>
 getInsertionDest(ForeachThreadOp foreachThreadOp) {
   PerformConcurrentlyOp terminator = foreachThreadOp.getTerminator();
   SmallVector<OpOperand *> result;
-  terminator.walk([&](ParallelInsertSliceOp insertOp) {
+  terminator.walk([&](tensor::ParallelInsertSliceOp insertOp) {
     result.push_back(&insertOp->getOpOperand(1) /*dest*/);
   });
   return result;
@@ -1004,248 +1004,6 @@ struct PerformConcurrentlyOpInterface
   }
 };
 
-/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair matches
-/// (i.e. equivalent operand / result and same offset/sizes/strides
-/// specification).
-static bool areEquivalentExtractSliceOps(const AnalysisState &state,
-                                         ExtractSliceOp st,
-                                         ParallelInsertSliceOp sti) {
-  if (!st || !sti)
-    return false;
-  if (st != sti &&
-      !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
-    return false;
-  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
-    return false;
-  return true;
-}
-
-/// Return true if `value` originates from an ExtractSliceOp that matches
-/// the given InsertSliceOp.
-static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
-                                      ParallelInsertSliceOp insertOp) {
-  auto condition = [&](Value val) {
-    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
-      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
-        return true;
-    return false;
-  };
-
-  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
-                      condition);
-}
-
-/// Analysis of ParallelInsertSliceOp.
-struct ParallelInsertSliceOpInterface
-    : public BufferizableOpInterface::ExternalModel<
-          ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
-  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
-                                            const AnalysisState &state) const {
-    if (&opOperand != &op->getOpOperand(1) /*dest*/)
-      return {};
-
-    // ParallelInsertSliceOp itself has no results, query its tied op results.
-    auto insertOp = cast<ParallelInsertSliceOp>(op);
-    return {insertOp.getTiedOpResult()};
-  }
-
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
-                              const AnalysisState &state) const {
-    return true;
-  }
-
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
-                               const AnalysisState &state) const {
-    return &opOperand == &op->getOpOperand(1) /*dest*/;
-  }
-
-  BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const AnalysisState &state) const {
-    return BufferRelation::Equivalent;
-  }
-
-  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
-                                 const AnalysisState &state) const {
-    // This interface method is overridden because we want to set a custom
-    // insertion point for tensor copies. They should be inserted right before
-    // the ForeachThreadOp. E.g.:
-    //
-    // %r0, %r1 = foreach_thread ... {
-    //   ...
-    //   perform_concurrently {
-    //     parallel_insert_slice %a into %b ... {inplace = ["true", "true"]}
-    //     parallel_insert_slice %c into %d ... {inplace = ["true", "false"]}
-    //   }
-    // }
-    //
-    // After TensorCopyInsertion:
-    //
-    // %copy = bufferization.alloc_tensor() copy(%d)
-    // %r0, %r1 = foreach_thread ... {
-    //   ...
-    //   perform_concurrently {
-    //     parallel_insert_slice %a into %b ...
-    //     parallel_insert_slice %c into %copy ...
-    //   }
-    // }
-
-    OpBuilder::InsertionGuard g(rewriter);
-    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
-    ParallelCombiningOpInterface parallelCombiningParent =
-        parallelInsertSliceOp.getParallelCombiningParent();
-    Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
-
-    // Nothing to do if the destination tensor is inplace.
-    assert(state.isInPlace(op->getOpOperand(0) /*src*/) &&
-           "source is always in-place");
-    if (state.isInPlace(op->getOpOperand(1) /*dest*/))
-      return success();
-
-    // Find corresponding OpResult.
-    OpResult opResult = parallelInsertSliceOp.getTiedOpResult();
-
-    // Insert tensor allocation right before the ForeachThreadOp.
-    rewriter.setInsertionPoint(parallelIteratingOp);
-    bool isYielded = state.isTensorYielded(opResult);
-    FailureOr<Value> alloc = allocateTensorForShapedValue(
-        rewriter, op->getLoc(), parallelInsertSliceOp.getDest(),
-        /*escape=*/isYielded, state.getOptions());
-    if (failed(alloc))
-      return failure();
-
-    // Update destination operand.
-    rewriter.updateRootInPlace(parallelInsertSliceOp, [&]() {
-      parallelInsertSliceOp.getDestMutable().assign(*alloc);
-    });
-
-    return success();
-  }
-
-  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
-    OpBuilder::InsertionGuard g(rewriter);
-    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
-    ParallelCombiningOpInterface parallelCombiningParent =
-        parallelInsertSliceOp.getParallelCombiningParent();
-    Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
-
-    // Get destination buffer.
-    FailureOr<Value> destBuffer =
-        getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
-    if (failed(destBuffer))
-      return failure();
-
-    // Bufferize the ParallelInsertSliceOp outside of
-    // `parallelCombiningParent`.
-    rewriter.setInsertionPoint(parallelCombiningParent);
-    FailureOr<Value> srcBuffer =
-        getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
-    if (failed(srcBuffer))
-      return failure();
-    Value subview = rewriter.create<memref::SubViewOp>(
-        parallelInsertSliceOp.getLoc(), *destBuffer,
-        parallelInsertSliceOp.getMixedOffsets(),
-        parallelInsertSliceOp.getMixedSizes(),
-        parallelInsertSliceOp.getMixedStrides());
-    // This memcpy will fold away if everything bufferizes in-place.
-    if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
-                                    *srcBuffer, subview)))
-      return failure();
-
-    // Replace all uses of parallelIteratingOp (just the corresponding
-    // result).
-    rewriter.setInsertionPointAfter(parallelIteratingOp);
-    Value toTensorOp =
-        rewriter.create<ToTensorOp>(parallelIteratingOp->getLoc(), *destBuffer);
-    // PerformConcurrentlyOp can have multiple ParallelInsertSliceOps.
-    SmallVector<OpOperand *> resultUses = llvm::to_vector(
-        llvm::map_range(parallelInsertSliceOp.getTiedOpResult().getUses(),
-                        [](OpOperand &use) { return &use; }));
-    for (OpOperand *use : resultUses) {
-      rewriter.updateRootInPlace(use->getOwner(),
-                                 [&]() { use->set(toTensorOp); });
-    }
-    rewriter.eraseOp(op);
-    return success();
-  }
-
-  // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
-  // the code.
-  bool isNotConflicting(Operation *op, OpOperand *uRead,
-                        OpOperand *uConflictingWrite,
-                        const AnalysisState &state) const {
-    Operation *readingOp = uRead->getOwner();
-    Operation *conflictingWritingOp = uConflictingWrite->getOwner();
-
-    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
-    // uRead is an InsertSliceOp...
-    if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
-      // As an example, consider the following IR.
-      //
-      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
-      // %1 = linalg.fill %cst, %0 {inplace= [true] }
-      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
-      //     {inplace= [true] }
-
-      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
-      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
-                                    insertSliceOp))
-        // Case 1: The main insight is that InsertSliceOp reads only part of
-        // the destination tensor. The overwritten area is not read. If
-        // uConflictingWrite writes into exactly the memory location that is
-        // being read by uRead, this is not a conflict.
-        //
-        // In the above example:
-        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
-        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
-        //
-        // The read of %t does not conflict with the write of the FillOp
-        // (same aliases!) because the area that the FillOp operates on is
-        // exactly the one that is *not* read via %t.
-        return true;
-
-      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
-          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
-        // Case 2: The read of the source tensor and the write to the dest
-        // tensor via an InsertSliceOp is not a conflict if the read is
-        // reading exactly that part of an equivalent tensor that the
-        // InsertSliceOp is writing.
-        //
-        // In the above example:
-        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
-        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
-        return true;
-    }
-
-    // If uConflictingWrite is an InsertSliceOp...
-    if (auto insertSliceOp =
-            dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
-      // As an example, consider the following IR.
-      //
-      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
-      // %1 = linalg.fill %cst, %0 {inplace= [true] }
-      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
-      //     {inplace= [true] }
-      // %3 = vector.transfer_read %1, %cst
-      //
-      // In the above example:
-      // uRead             = OpOperand 0 (%1) of vector.transfer_read
-      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
-      // lastWrite         = %1
-      //
-      // This is not a conflict because the InsertSliceOp overwrites the
-      // memory segment of %1 with the exact same data. (Effectively, there
-      // is no memory write here.)
-      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          state.areEquivalentBufferizedValues(uRead->get(),
-                                              insertSliceOp.getSource()) &&
-          hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
-                                    insertSliceOp))
-        return true;
-
-    return false;
-  }
-};
-
 } // namespace
 } // namespace scf
 } // namespace mlir
@@ -1257,8 +1015,6 @@ void mlir::scf::registerBufferizableOpInterfaceExternalModels(
   ForOp::attachInterface<ForOpInterface>(*ctx);
   IfOp::attachInterface<IfOpInterface>(*ctx);
   ForeachThreadOp::attachInterface<ForeachThreadOpInterface>(*ctx);
-  ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
-      *ctx);
   PerformConcurrentlyOp::attachInterface<PerformConcurrentlyOpInterface>(
       *ctx);
   WhileOp::attachInterface<WhileOpInterface>(*ctx);
@@ -26,6 +26,7 @@ add_mlir_dialect_library(MLIRTensorDialect
   MLIRDialectUtils
   MLIRIR
   MLIRInferTypeOpInterface
+  MLIRParallelCombiningOpInterface
   MLIRSideEffectInterfaces
   MLIRSupport
   MLIRViewLikeInterface
@@ -2179,6 +2179,137 @@ OpFoldResult PadOp::fold(ArrayRef<Attribute>) {
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// ParallelInsertSliceOp
+//===----------------------------------------------------------------------===//
+
+OpResult ParallelInsertSliceOp::getTiedOpResult() {
+  ParallelCombiningOpInterface parallelCombiningParent =
+      getParallelCombiningParent();
+  for (const auto &it :
+       llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
+    Operation &nextOp = it.value();
+    if (&nextOp == getOperation())
+      return parallelCombiningParent.getParentResult(it.index());
+  }
+  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
+}
+
+// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
+void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
+                                  Value source, Value dest,
+                                  ArrayRef<OpFoldResult> offsets,
+                                  ArrayRef<OpFoldResult> sizes,
+                                  ArrayRef<OpFoldResult> strides,
+                                  ArrayRef<NamedAttribute> attrs) {
+  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
+  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
+  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
+                             ShapedType::kDynamicStrideOrOffset);
+  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
+                             ShapedType::kDynamicSize);
+  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
+                             ShapedType::kDynamicStrideOrOffset);
+  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
+        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
+        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
+  result.addAttributes(attrs);
+}
+
+// Build a ParallelInsertSliceOp with dynamic entries.
+void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
+                                  Value source, Value dest, ValueRange offsets,
+                                  ValueRange sizes, ValueRange strides,
+                                  ArrayRef<NamedAttribute> attrs) {
+  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
+      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
+  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
+      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
+  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
+      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
+  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
+}
+
+LogicalResult ParallelInsertSliceOp::verify() {
+  if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
+    return this->emitError("expected ParallelCombiningOpInterface parent, got:")
+           << *(getOperation()->getParentOp());
+  return success();
+}
+
+namespace {
+/// Pattern to rewrite a parallel_insert_slice op with constant arguments.
+class ParallelInsertSliceOpConstantArgumentFolder final
+    : public OpRewritePattern<ParallelInsertSliceOp> {
+public:
+  using OpRewritePattern<ParallelInsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp,
+                                PatternRewriter &rewriter) const override {
+    // No constant operand, just return.
+    if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
+          return matchPattern(operand, matchConstantIndex());
+        }))
+      return failure();
+
+    // At least one of offsets/sizes/strides is a new constant.
+    // Form the new list of operands and constant attributes from the
+    // existing.
+    SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
+    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
+    SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
+    canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
+    canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
+    canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
+
+    // Create the new op in canonical form.
+    rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
+        insertSliceOp, insertSliceOp.getSource(), insertSliceOp.getDest(),
+        mixedOffsets, mixedSizes, mixedStrides);
+    return success();
+  }
+};
+} // namespace
+
+/// Fold a parallel_insert_slice source coming from a tensor.cast op.
+///
+/// Example:
+/// ```
+/// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
+///   %1 = compute_some_tensor() : tensor<64xf32>
+///   %2 = tensor.cast %1 : tensor<64xf32> to tensor<?xf32>
+///   scf.foreach_thread.perform_concurrently {
+///     scf.foreach_thread.parallel_insert_slice %2 into %out[...] [64] [1] :
+///       tensor<?xf32> into tensor<128xf32>
+///   }
+/// }
+/// ```
+///
+/// is folded into:
+/// ```
+/// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
+///   %1 = compute_some_tensor() : tensor<64xf32>
+///   scf.foreach_thread.perform_concurrently {
+///     scf.foreach_thread.parallel_insert_slice %1 into %out[...] [64] [1] :
+///       tensor<64xf32> into tensor<128xf32>
+///   }
+/// }
+/// ```
+LogicalResult
+ParallelInsertSliceOp::fold(ArrayRef<Attribute> operands,
+                            SmallVectorImpl<OpFoldResult> &results) {
+  auto sourceCast = getSource().getDefiningOp<tensor::CastOp>();
+  if (!sourceCast)
+    return failure();
+  getSourceMutable().assign(sourceCast.getSource());
+  return success();
+}
+
+void ParallelInsertSliceOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<ParallelInsertSliceOpConstantArgumentFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // SplatOp
 //===----------------------------------------------------------------------===//
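As a usage sketch for the builders implemented above, the mixed static/dynamic form can be driven from C++ roughly as follows; `b`, `loc`, `source`, `dest`, and `tid` are assumed to exist in the surrounding code and are not part of the patch:

```cpp
// Each offset/size/stride entry is an OpFoldResult: either a Value (dynamic)
// or an Attribute (static). The build() above dispatches them into the
// variadic index operands and the static I64ArrayAttrs.
SmallVector<OpFoldResult> offsets = {tid};               // dynamic offset
SmallVector<OpFoldResult> sizes = {b.getIndexAttr(64)};  // static size
SmallVector<OpFoldResult> strides = {b.getIndexAttr(1)}; // static stride
b.create<tensor::ParallelInsertSliceOp>(loc, source, dest, offsets, sizes,
                                        strides);
```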
@@ -810,6 +810,248 @@ struct ReshapeOpInterface
   }
 };
 
+/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair matches
+/// (i.e. equivalent operand / result and same offset/sizes/strides
+/// specification).
+static bool areEquivalentExtractSliceOps(const AnalysisState &state,
+                                         ExtractSliceOp st,
+                                         ParallelInsertSliceOp sti) {
+  if (!st || !sti)
+    return false;
+  if (st != sti &&
+      !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
+    return false;
+  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
+    return false;
+  return true;
+}
+
+/// Return true if `value` originates from an ExtractSliceOp that matches
+/// the given InsertSliceOp.
+static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
+                                      ParallelInsertSliceOp insertOp) {
+  auto condition = [&](Value val) {
+    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
+      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
+        return true;
+    return false;
+  };
+
+  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
+                      condition);
+}
+
+/// Analysis of ParallelInsertSliceOp.
+struct ParallelInsertSliceOpInterface
+    : public BufferizableOpInterface::ExternalModel<
+          ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
+  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    if (&opOperand != &op->getOpOperand(1) /*dest*/)
+      return {};
+
+    // ParallelInsertSliceOp itself has no results, query its tied op results.
+    auto insertOp = cast<ParallelInsertSliceOp>(op);
+    return {insertOp.getTiedOpResult()};
+  }
+
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    return &opOperand == &op->getOpOperand(1) /*dest*/;
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const AnalysisState &state) const {
+    return BufferRelation::Equivalent;
+  }
+
+  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
+                                 const AnalysisState &state) const {
+    // This interface method is overridden because we want to set a custom
+    // insertion point for tensor copies. They should be inserted right before
+    // the ForeachThreadOp. E.g.:
+    //
+    // %r0, %r1 = foreach_thread ... {
+    //   ...
+    //   perform_concurrently {
+    //     parallel_insert_slice %a into %b ... {inplace = ["true", "true"]}
+    //     parallel_insert_slice %c into %d ... {inplace = ["true", "false"]}
+    //   }
+    // }
+    //
+    // After TensorCopyInsertion:
+    //
+    // %copy = bufferization.alloc_tensor() copy(%d)
+    // %r0, %r1 = foreach_thread ... {
+    //   ...
+    //   perform_concurrently {
+    //     parallel_insert_slice %a into %b ...
+    //     parallel_insert_slice %c into %copy ...
+    //   }
+    // }
+
+    OpBuilder::InsertionGuard g(rewriter);
+    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
+    ParallelCombiningOpInterface parallelCombiningParent =
+        parallelInsertSliceOp.getParallelCombiningParent();
+    Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
+
+    // Nothing to do if the destination tensor is inplace.
+    assert(state.isInPlace(op->getOpOperand(0) /*src*/) &&
+           "source is always in-place");
+    if (state.isInPlace(op->getOpOperand(1) /*dest*/))
+      return success();
+
+    // Find corresponding OpResult.
+    OpResult opResult = parallelInsertSliceOp.getTiedOpResult();
+
+    // Insert tensor allocation right before the ForeachThreadOp.
+    rewriter.setInsertionPoint(parallelIteratingOp);
+    bool isYielded = state.isTensorYielded(opResult);
+    FailureOr<Value> alloc = allocateTensorForShapedValue(
+        rewriter, op->getLoc(), parallelInsertSliceOp.getDest(),
+        /*escape=*/isYielded, state.getOptions());
+    if (failed(alloc))
+      return failure();
+
+    // Update destination operand.
+    rewriter.updateRootInPlace(parallelInsertSliceOp, [&]() {
+      parallelInsertSliceOp.getDestMutable().assign(*alloc);
+    });
+
+    return success();
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
+    ParallelCombiningOpInterface parallelCombiningParent =
+        parallelInsertSliceOp.getParallelCombiningParent();
+    Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
+
+    // Get destination buffer.
+    FailureOr<Value> destBuffer =
+        getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
+    if (failed(destBuffer))
+      return failure();
+
+    // Bufferize the ParallelInsertSliceOp outside of
+    // `parallelCombiningParent`.
+    rewriter.setInsertionPoint(parallelCombiningParent);
+    FailureOr<Value> srcBuffer =
+        getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
+    if (failed(srcBuffer))
+      return failure();
+    Value subview = rewriter.create<memref::SubViewOp>(
+        parallelInsertSliceOp.getLoc(), *destBuffer,
+        parallelInsertSliceOp.getMixedOffsets(),
+        parallelInsertSliceOp.getMixedSizes(),
+        parallelInsertSliceOp.getMixedStrides());
+    // This memcpy will fold away if everything bufferizes in-place.
+    if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
+                                    *srcBuffer, subview)))
+      return failure();
+
+    // Replace all uses of parallelIteratingOp (just the corresponding
+    // result).
+    rewriter.setInsertionPointAfter(parallelIteratingOp);
+    Value toTensorOp =
+        rewriter.create<ToTensorOp>(parallelIteratingOp->getLoc(), *destBuffer);
+    // PerformConcurrentlyOp can have multiple ParallelInsertSliceOps.
+    SmallVector<OpOperand *> resultUses = llvm::to_vector(
+        llvm::map_range(parallelInsertSliceOp.getTiedOpResult().getUses(),
+                        [](OpOperand &use) { return &use; }));
+    for (OpOperand *use : resultUses) {
+      rewriter.updateRootInPlace(use->getOwner(),
+                                 [&]() { use->set(toTensorOp); });
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+  // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
+  // the code.
+  bool isNotConflicting(Operation *op, OpOperand *uRead,
+                        OpOperand *uConflictingWrite,
+                        const AnalysisState &state) const {
+    Operation *readingOp = uRead->getOwner();
+    Operation *conflictingWritingOp = uConflictingWrite->getOwner();
+
+    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
+    // uRead is an InsertSliceOp...
+    if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
+      // As an example, consider the following IR.
+      //
+      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+      // %1 = linalg.fill %cst, %0 {inplace= [true] }
+      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+      //     {inplace= [true] }
+
+      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
+      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
+                                    insertSliceOp))
+        // Case 1: The main insight is that InsertSliceOp reads only part of
+        // the destination tensor. The overwritten area is not read. If
+        // uConflictingWrite writes into exactly the memory location that is
+        // being read by uRead, this is not a conflict.
+        //
+        // In the above example:
+        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
+        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
+        //
+        // The read of %t does not conflict with the write of the FillOp
+        // (same aliases!) because the area that the FillOp operates on is
+        // exactly the one that is *not* read via %t.
+        return true;
+
+      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
+          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
+        // Case 2: The read of the source tensor and the write to the dest
+        // tensor via an InsertSliceOp is not a conflict if the read is
+        // reading exactly that part of an equivalent tensor that the
+        // InsertSliceOp is writing.
+        //
+        // In the above example:
+        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
+        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+        return true;
+    }
+
+    // If uConflictingWrite is an InsertSliceOp...
+    if (auto insertSliceOp =
+            dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
+      // As an example, consider the following IR.
+      //
+      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+      // %1 = linalg.fill %cst, %0 {inplace= [true] }
+      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+      //     {inplace= [true] }
+      // %3 = vector.transfer_read %1, %cst
+      //
+      // In the above example:
+      // uRead             = OpOperand 0 (%1) of vector.transfer_read
+      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+      // lastWrite         = %1
+      //
+      // This is not a conflict because the InsertSliceOp overwrites the
+      // memory segment of %1 with the exact same data. (Effectively, there
+      // is no memory write here.)
+      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          state.areEquivalentBufferizedValues(uRead->get(),
+                                              insertSliceOp.getSource()) &&
+          hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
+                                    insertSliceOp))
+        return true;
+
+    return false;
+  }
+};
+
 } // namespace
 } // namespace tensor
 } // namespace mlir
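To make the `bufferize` method above concrete, here is a rough sketch of the rewrite it performs on a 1-D example; the names and shapes are illustrative, and the layout maps produced by actual bufferization may differ:

```mlir
// Tensor form: update a slice of the shared destination.
tensor.parallel_insert_slice %src into %dest[%off] [64] [1]
    : tensor<64xf32> into tensor<128xf32>

// Buffer form (conceptually): take a subview of the destination buffer and
// copy the source buffer into it. The copy folds away when everything
// bufferizes in-place.
%sub = memref.subview %dest_buffer[%off] [64] [1]
    : memref<128xf32> to memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
memref.copy %src_buffer, %sub
    : memref<64xf32> to memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
```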
@@ -827,6 +1069,8 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
     GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
     InsertOp::attachInterface<InsertOpInterface>(*ctx);
     InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
+    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
+        *ctx);
     RankOp::attachInterface<RankOpInterface>(*ctx);
     ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
   });
@@ -1457,28 +1457,3 @@ func.func @func_execute_region_elim_multi_yield() {
 // CHECK:     ^[[bb3]](%[[z:.+]]: i64):
 // CHECK:       "test.bar"(%[[z]])
 // CHECK:     return
-
-// -----
-
-// CHECK-LABEL: func.func @canonicalize_parallel_insert_slice_indices(
-//  CHECK-SAME:   %[[arg0:[0-9a-z]*]]: tensor<?x?xf32>,
-//  CHECK-SAME:   %[[arg1:[0-9a-z]*]]: tensor<?x?xf32>,
-//  CHECK-SAME:   %[[num_threads:[0-9a-z]*]]: index
-func.func @canonicalize_parallel_insert_slice_indices(
-    %arg0 : tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
-    %num_threads : index) -> tensor<?x?xf32>
-{
-  %cst = arith.constant 4.200000e+01 : f32
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-
-  // CHECK: scf.foreach_thread (%[[tidx:[0-9a-z]*]]) in (%[[num_threads]]) -> (tensor<?x?xf32>) {
-  // CHECK-NEXT: scf.foreach_thread.perform_concurrently {
-  // CHECK-NEXT: scf.foreach_thread.parallel_insert_slice %[[arg0]] into %[[arg1]][%[[tidx]], 0] [1, 5] [1, 1]
-  %2 = scf.foreach_thread (%tidx) in (%num_threads) -> (tensor<?x?xf32>) {
-    scf.foreach_thread.perform_concurrently {
-      scf.foreach_thread.parallel_insert_slice %arg0 into %arg1[%tidx, %c0] [%c1, 5] [%c1, %c1] : tensor<?x?xf32> into tensor<?x?xf32>
-    }
-  }
-  return %2 : tensor<?x?xf32>
-}
@@ -1,36 +1,36 @@
-// RUN: mlir-opt %s -scf-for-loop-canonicalization -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -scf-for-loop-canonicalization | FileCheck %s
 
-func.func @reduce() -> tensor<128xf32> {
+func.func @reduce() {
   // CHECK: %[[C64:.*]] = arith.constant 64 : index
   %c2 = arith.constant 2 : index
-  %cst = arith.constant dense<1.000000e+00> : tensor<1x128x384xf32>
   %cst_0 = arith.constant -0.000000e+00 : f32
-  %0 = linalg.init_tensor [128, 384] : tensor<128x384xf32>
-  %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<128x384xf32>) -> tensor<128x384xf32>
-  %2 = linalg.init_tensor [128] : tensor<128xf32>
-  %3 = linalg.fill ins(%cst_0 : f32) outs(%2 : tensor<128xf32>) -> tensor<128xf32>
-  %4 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
+  %0 = memref.alloc() : memref<128x384xf32>
+  linalg.fill ins(%cst_0 : f32) outs(%0 : memref<128x384xf32>)
+  %2 = memref.alloc() : memref<128xf32>
+  linalg.fill ins(%cst_0 : f32) outs(%2 : memref<128xf32>)
+  scf.foreach_thread (%arg0) in (%c2) {
     %7 = affine.min affine_map<(d0) -> (d0 * -64 + 128, 64)>(%arg0)
     %8 = affine.max affine_map<(d0) -> (0, d0)>(%7)
     %9 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg0)
     %10 = affine.min affine_map<(d0, d1) -> (d1 * -64 + 128, d0)>(%8, %arg0)
 
-    // CHECK: tensor.extract_slice %{{.*}}[%{{.*}}, 0] [64, 384] [1, 1] : tensor<128x384xf32> to tensor<64x384xf32>
-    // CHECK: tensor.extract_slice %{{.*}}[%{{.*}}] [64] [1] : tensor<128xf32> to tensor<64xf32>
-    %11 = tensor.extract_slice %1[%9, 0] [%10, 384] [1, 1] : tensor<128x384xf32> to tensor<?x384xf32>
-    %12 = tensor.extract_slice %3[%9] [%10] [1] : tensor<128xf32> to tensor<?xf32>
+    // CHECK: memref.subview %{{.*}}[%{{.*}}, 0] [%[[C64]], 384] [1, 1] : memref<128x384xf32> to memref<?x384xf32, {{.*}}>
+    // CHECK: memref.subview %{{.*}}[%{{.*}}] [%[[C64]]] [1] : memref<128xf32> to memref<?xf32, {{.*}}>
+    %11 = memref.subview %0[%9, 0] [%10, 384] [1, 1] :
+      memref<128x384xf32> to memref<?x384xf32, affine_map<(d0, d1)[s0] -> (d0 * 384 + s0 + d1)>>
+    %12 = memref.subview %2[%9] [%10] [1] :
+      memref<128xf32> to memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
 
-    // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<64x384xf32>) outs(%{{.*}} : tensor<64xf32>) {
-    %13 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%11 : tensor<?x384xf32>) outs(%12 : tensor<?xf32>) {
-    ^bb0(%arg1: f32, %arg2: f32):
-      %14 = arith.addf %arg1, %arg2 : f32
-      linalg.yield %14 : f32
-    } -> tensor<?xf32>
-
-    // CHECK-NOT: tensor.cast
-    // CHECK: scf.foreach_thread.parallel_insert_slice %{{.*}} into %{{.*}}[%{{.*}}] [64] [1] : tensor<64xf32> into tensor<128xf32>
-    scf.foreach_thread.perform_concurrently {
-      scf.foreach_thread.parallel_insert_slice %13 into %3[%9] [%10] [1] : tensor<?xf32> into tensor<128xf32>
-    }
+    // CHECK: linalg.generic {{.*}} ins(%{{.*}} : memref<?x384xf32, {{.*}}>) outs(%{{.*}} : memref<?xf32, {{.*}}>)
+    linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                     affine_map<(d0, d1) -> (d0)>],
+                    iterator_types = ["parallel", "reduction"]}
+      ins(%11 : memref<?x384xf32, affine_map<(d0, d1)[s0] -> (d0 * 384 + s0 + d1)>>)
+      outs(%12 : memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>>) {
+    ^bb0(%arg1: f32, %arg2: f32):
+      %14 = arith.addf %arg1, %arg2 : f32
+      linalg.yield %14 : f32
+    }
   }
-  return %4 : tensor<128xf32>
+  return
 }
@@ -531,7 +531,7 @@ func.func @wrong_num_results(%in: tensor<100xf32>, %out: tensor<100xf32>) {
   %result:2 = scf.foreach_thread (%thread_idx) in (%num_threads) -> (tensor<100xf32>, tensor<100xf32>) {
     %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
     scf.foreach_thread.perform_concurrently {
-      scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
+      tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
         tensor<1xf32> into tensor<100xf32>
     }
   }
@@ -548,7 +548,7 @@ func.func @wrong_type_result(%in: tensor<100xf32>, %out: tensor<100xf32>) {
   %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> (tensor<?xf32>) {
     %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
     scf.foreach_thread.perform_concurrently {
-      scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
+      tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
        tensor<1xf32> into tensor<100xf32>
     }
   }
@@ -563,9 +563,9 @@ func.func @wrong_terminator_op(%in: tensor<100xf32>, %out: tensor<100xf32>) {
 
   %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> (tensor<100xf32>) {
     %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
-    // expected-error @+1 {{expected only scf.foreach_thread.parallel_insert_slice ops}}
+    // expected-error @+1 {{expected only tensor.parallel_insert_slice ops}}
     scf.foreach_thread.perform_concurrently {
-      scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
+      tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
         tensor<1xf32> into tensor<100xf32>
       %0 = arith.constant 1: index
     }
@@ -124,10 +124,10 @@ func.func @scf_foreach_thread_out_of_place(%in: tensor<100xf32>,
   %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<100xf32> {
     // CHECK: tensor.extract_slice
     // CHECK: scf.foreach_thread.perform_concurrently
-    // CHECK: scf.foreach_thread.parallel_insert_slice %{{.*}} into %[[alloc]]
+    // CHECK: tensor.parallel_insert_slice %{{.*}} into %[[alloc]]
     %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
     scf.foreach_thread.perform_concurrently {
-      scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
+      tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
         tensor<1xf32> into tensor<100xf32>
     }
     // CHECK: } {thread_dim_mapping = [5]}
@@ -537,7 +537,7 @@ func.func @parallel_insert_slice_no_conflict(
     // CHECK-NOT: scf.foreach_thread.perform_concurrently
     // CHECK-NOT: parallel_insert_slice
     scf.foreach_thread.perform_concurrently {
-      scf.foreach_thread.parallel_insert_slice %8 into %arg2[5] [%idx] [%c1] :
+      tensor.parallel_insert_slice %8 into %arg2[5] [%idx] [%c1] :
         tensor<?xf32> into tensor<?xf32>
     }
   }
@@ -589,7 +589,7 @@ func.func @parallel_insert_slice_with_conflict(
     // CHECK-NOT: scf.foreach_thread.perform_concurrently
     // CHECK-NOT: parallel_insert_slice
    scf.foreach_thread.perform_concurrently {
-      scf.foreach_thread.parallel_insert_slice %8 into %arg2[5] [%idx] [%c1] :
+      tensor.parallel_insert_slice %8 into %arg2[5] [%idx] [%c1] :
        tensor<?xf32> into tensor<?xf32>
     }
   }
@@ -627,7 +627,7 @@ func.func @matmul(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>, %arg2: tensor<
       // CHECK: linalg.matmul ins({{.*}}memref<4x8xf32, #[[$DYN_LAYOUT_MAP]]>, memref<8x4xf32, #[[$DYN_LAYOUT_MAP]]>) outs({{.*}} : memref<4x4xf32, #[[$DYN_LAYOUT_MAP]]>)
       %8 = linalg.matmul ins(%3, %6 : tensor<4x8xf32>, tensor<8x4xf32>) outs(%7 : tensor<4x4xf32>) -> tensor<4x4xf32>
       scf.foreach_thread.perform_concurrently {
-        scf.foreach_thread.parallel_insert_slice %8 into %arg2[%1, %4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x8xf32>
+        tensor.parallel_insert_slice %8 into %arg2[%1, %4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x8xf32>
       }
     }
     return %0 : tensor<8x8xf32>
@@ -319,14 +319,14 @@ func.func @simple_example(%in: tensor<100xf32>, %out: tensor<100xf32>) {
   // CHECK:      scf.foreach_thread
   // CHECK-NEXT:   tensor.extract_slice
   // CHECK-NEXT:   scf.foreach_thread.perform_concurrently
-  // CHECK-NEXT:     scf.foreach_thread.parallel_insert_slice
+  // CHECK-NEXT:     tensor.parallel_insert_slice
   // CHECK-NEXT:   }
   // CHECK-NEXT: }
   // CHECK-NEXT: return
   %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<100xf32> {
     %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
     scf.foreach_thread.perform_concurrently {
-      scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
+      tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
         tensor<1xf32> into tensor<100xf32>
     }
   }
@@ -1425,3 +1425,28 @@ func.func @cast_extract_slice_rank_reduce(%arg0 : tensor<128x512xf32>, %s : index
   // CHECK: return %[[E]] : tensor<16xf32>
   return %1 : tensor<16xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @canonicalize_parallel_insert_slice_indices(
+//  CHECK-SAME:   %[[arg0:[0-9a-z]*]]: tensor<?x?xf32>,
+//  CHECK-SAME:   %[[arg1:[0-9a-z]*]]: tensor<?x?xf32>,
+//  CHECK-SAME:   %[[num_threads:[0-9a-z]*]]: index
+func.func @canonicalize_parallel_insert_slice_indices(
+    %arg0 : tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+    %num_threads : index) -> tensor<?x?xf32>
+{
+  %cst = arith.constant 4.200000e+01 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+
+  // CHECK: scf.foreach_thread (%[[tidx:[0-9a-z]*]]) in (%[[num_threads]]) -> (tensor<?x?xf32>) {
+  // CHECK-NEXT: scf.foreach_thread.perform_concurrently {
+  // CHECK-NEXT: tensor.parallel_insert_slice %[[arg0]] into %[[arg1]][%[[tidx]], 0] [1, 5] [1, 1]
+  %2 = scf.foreach_thread (%tidx) in (%num_threads) -> (tensor<?x?xf32>) {
+    scf.foreach_thread.perform_concurrently {
+      tensor.parallel_insert_slice %arg0 into %arg1[%tidx, %c0] [%c1, 5] [%c1, %c1] : tensor<?x?xf32> into tensor<?x?xf32>
+    }
+  }
+  return %2 : tensor<?x?xf32>
+}
@@ -4909,6 +4909,7 @@ td_library(
         ":ControlFlowInterfacesTdFiles",
         ":InferTypeOpInterfaceTdFiles",
         ":OpBaseTdFiles",
+        ":ParallelCombiningOpInterfaceTdFiles",
         ":SideEffectInterfacesTdFiles",
         ":TilingInterfaceTdFiles",
         ":ViewLikeInterfaceTdFiles",
@@ -4965,6 +4966,7 @@ cc_library(
        ":DialectUtils",
        ":IR",
        ":InferTypeOpInterface",
+        ":ParallelCombiningOpInterface",
        ":SideEffectInterfaces",
        ":TensorOpsIncGen",
        ":TilingInterface",