[mlir][Tensor] Move ParallelInsertSlice to the tensor dialect
This is mostly NFC and will allow tensor.parallel_insert_slice to gain rank-reducing semantics by reusing the vast majority of the tensor.insert_slice implementation.

Depends on D128857.

Differential Revision: https://reviews.llvm.org/D128920
commit 7fbf55c927 (parent f0089fae1d)
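Before this change the op was spelled `scf.foreach_thread.parallel_insert_slice`; afterwards it is `tensor.parallel_insert_slice`, while still only being legal inside the `scf.foreach_thread.perform_concurrently` terminator. A minimal sketch of IR using the new spelling (the `%out` and `%num_threads` values and the `"test.compute"` producer are illustrative placeholders, not taken from this commit):

  %res = scf.foreach_thread (%tid) in (%num_threads) -> (tensor<128xf32>) {
    // Each thread produces one 64-element slice (hypothetical producer op).
    %slice = "test.compute"() : () -> tensor<64xf32>
    %off = affine.apply affine_map<(d0) -> (d0 * 64)>(%tid)
    scf.foreach_thread.perform_concurrently {
      // The op now lives in the tensor dialect.
      tensor.parallel_insert_slice %slice into %out[%off] [64] [1]
        : tensor<64xf32> into tensor<128xf32>
    }
  }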
@@ -503,115 +503,6 @@ def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
   }];
 }
 
-//===----------------------------------------------------------------------===//
-// ParallelInsertSliceOp
-//===----------------------------------------------------------------------===//
-
-// TODO: Implement PerformConcurrentlyOpInterface.
-def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
-    AttrSizedOperandSegments,
-    OffsetSizeAndStrideOpInterface,
-    // TODO: Cannot use an interface here atm, verify this manually for now.
-    // HasParent<"ParallelCombiningOpInterface">
-  ]> {
-  let summary = [{
-    Specify the tensor slice update of a single thread within the terminator of
-    an `scf.foreach_thread`.
-  }];
-  let description = [{
-    The parent `scf.foreach_thread` returns values that are formed by aggregating
-    the actions of all the ops contained within the `perform_concurrently`
-    terminator of all the threads, in some unspecified order.
-    The `scf.foreach_thread.parallel_insert_slice` is one such op allowed in
-    the `scf.foreach_thread.perform_concurrently` terminator.
-
-    Conflicting writes result in undefined semantics, in that the indices written
-    to by multiple parallel updates might contain data from any of the updates, or
-    even a malformed bit pattern.
-
-    If an index is updated exactly once, the value contained at that index
-    in the resulting tensor will be equal to the value at a corresponding index of a
-    slice that was used for the updated. If an index is not updated at all, its value
-    will be equal to the one in the original tensor.
-
-    This op does not create a new value, which allows maintaining a clean
-    separation between the subset and full tensor.
-    Note that we cannot mark this operation as pure (NoSideEffects), even
-    though it has no side effects, because it will get DCEd during
-    canonicalization.
-  }];
-
-  let arguments = (ins
-    AnyRankedTensor:$source,
-    AnyRankedTensor:$dest,
-    Variadic<Index>:$offsets,
-    Variadic<Index>:$sizes,
-    Variadic<Index>:$strides,
-    I64ArrayAttr:$static_offsets,
-    I64ArrayAttr:$static_sizes,
-    I64ArrayAttr:$static_strides
-  );
-  let assemblyFormat = [{
-    $source `into` $dest ``
-    custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
-    custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
-    custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
-    attr-dict `:` type($source) `into` type($dest)
-  }];
-
-  let extraClassDeclaration = [{
-    ::mlir::Operation::operand_range offsets() { return getOffsets(); }
-    ::mlir::Operation::operand_range sizes() { return getSizes(); }
-    ::mlir::Operation::operand_range strides() { return getStrides(); }
-    ::mlir::ArrayAttr static_offsets() { return getStaticOffsets(); }
-    ::mlir::ArrayAttr static_sizes() { return getStaticSizes(); }
-    ::mlir::ArrayAttr static_strides() { return getStaticStrides(); }
-
-    Type yieldedType() { return getDest().getType(); }
-
-    RankedTensorType getSourceType() {
-      return getSource().getType().cast<RankedTensorType>();
-    }
-
-    ParallelCombiningOpInterface getParallelCombiningParent() {
-      return dyn_cast<ParallelCombiningOpInterface>(
-          getOperation()->getParentOp());
-    }
-
-    /// Return the expected rank of each of the `static_offsets`, `static_sizes`
-    /// and `static_strides` attributes.
-    std::array<unsigned, 3> getArrayAttrMaxRanks() {
-      unsigned rank = getSourceType().getRank();
-      return {rank, rank, rank};
-    }
-
-    /// Return the number of leading operands before `offsets`, `sizes` and
-    /// `strides` operands.
-    static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
-
-    /// Return the OpResult of the enclosing ForeachThreadOp that is
-    /// corresponding to this ParallelInsertSliceOp.
-    OpResult getTiedOpResult();
-  }];
-
-  let builders = [
-    // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
-    OpBuilder<(ins "Value":$source, "Value":$dest,
-      "ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
-      "ArrayRef<OpFoldResult>":$strides,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
-
-    // Build a ParallelInsertSliceOp with dynamic entries.
-    OpBuilder<(ins "Value":$source, "Value":$dest,
-      "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
-  ];
-
-  let hasCanonicalizer = 1;
-  let hasFolder = 1;
-  let hasVerifier = 1;
-}
-
 //===----------------------------------------------------------------------===//
 // IfOp
 //===----------------------------------------------------------------------===//
@@ -17,6 +17,7 @@
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
@@ -13,6 +13,7 @@ include "mlir/Dialect/Tensor/IR/TensorBase.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/ParallelCombiningOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/TilingInterface.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
@@ -1051,6 +1052,110 @@ def Tensor_PadOp : Tensor_Op<"pad", [AttrSizedOperandSegments, NoSideEffect,
   let hasRegionVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// ParallelInsertSliceOp
+//===----------------------------------------------------------------------===//
+
+// TODO: Implement PerformConcurrentlyOpInterface.
+def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
+    AttrSizedOperandSegments,
+    OffsetSizeAndStrideOpInterface,
+    // TODO: Cannot use an interface here atm, verify this manually for now.
+    // HasParent<"ParallelCombiningOpInterface">
+  ]> {
+  let summary = [{
+    Specify the tensor slice update of a single thread of a parent
+    ParallelCombiningOpInterface op.
+  }];
+  let description = [{
+    The `parallel_insert_slice` yields a subset tensor value to its parent
+    ParallelCombiningOpInterface. These subset tensor values are aggregated to
+    in some unspecified order into a full tensor value returned by the parent
+    parallel iterating op.
+    The `parallel_insert_slice` is one such op allowed in the
+    ParallelCombiningOpInterface op.
+
+    Conflicting writes result in undefined semantics, in that the indices written
+    to by multiple parallel updates might contain data from any of the updates,
+    or even a malformed bit pattern.
+
+    If an index is updated exactly once, the value contained at that index
+    in the resulting tensor will be equal to the value at a corresponding index
+    of a slice that was used for the updated. If an index is not updated at all,
+    its value will be equal to the one in the original tensor.
+
+    This op does not create a new value, which allows maintaining a clean
+    separation between the subset and full tensor.
+
+    Note that we cannot mark this operation as pure (NoSideEffects), even
+    though it has no side effects, because it will get DCEd during
+    canonicalization.
+  }];
+
+  let arguments = (ins
+    AnyRankedTensor:$source,
+    AnyRankedTensor:$dest,
+    Variadic<Index>:$offsets,
+    Variadic<Index>:$sizes,
+    Variadic<Index>:$strides,
+    I64ArrayAttr:$static_offsets,
+    I64ArrayAttr:$static_sizes,
+    I64ArrayAttr:$static_strides
+  );
+  let assemblyFormat = [{
+    $source `into` $dest ``
+    custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
+    custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
+    custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
+    attr-dict `:` type($source) `into` type($dest)
+  }];
+
+  let extraClassDeclaration = [{
+    Type yieldedType() { return getDest().getType(); }
+
+    RankedTensorType getSourceType() {
+      return getSource().getType().cast<RankedTensorType>();
+    }
+
+    ParallelCombiningOpInterface getParallelCombiningParent() {
+      return dyn_cast<ParallelCombiningOpInterface>(
+          getOperation()->getParentOp());
+    }
+
+    /// Return the expected rank of each of the `static_offsets`, `static_sizes`
+    /// and `static_strides` attributes.
+    std::array<unsigned, 3> getArrayAttrMaxRanks() {
+      unsigned rank = getSourceType().getRank();
+      return {rank, rank, rank};
+    }
+
+    /// Return the number of leading operands before `offsets`, `sizes` and
+    /// `strides` operands.
+    static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
+
+    /// Return the OpResult of the enclosing ForeachThreadOp that is
+    /// corresponding to this ParallelInsertSliceOp.
+    OpResult getTiedOpResult();
+  }];
+
+  let builders = [
+    // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
+    OpBuilder<(ins "Value":$source, "Value":$dest,
+      "ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
+      "ArrayRef<OpFoldResult>":$strides,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+
+    // Build a ParallelInsertSliceOp with dynamic entries.
+    OpBuilder<(ins "Value":$source, "Value":$dest,
+      "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
+  ];
+
+  let hasCanonicalizer = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // SplatOp
 //===----------------------------------------------------------------------===//
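Given the `assemblyFormat` above, a single update prints in `insert_slice` style, offsets first, then sizes, then strides, with static values inlined as integers. A rough sketch of the printed form (value names and shapes are made up for illustration):

  tensor.parallel_insert_slice %source into %dest[%off, 0] [%sz, 384] [1, 1]
      : tensor<?x384xf32> into tensor<128x384xf32>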
@@ -13,7 +13,6 @@ add_mlir_dialect_library(MLIRSCFDialect
   MLIRControlFlowDialect
   MLIRIR
   MLIRLoopLikeInterface
-  MLIRParallelCombiningOpInterface
   MLIRSideEffectInterfaces
 )
@@ -1211,137 +1211,6 @@ ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) {
   return dyn_cast<ForeachThreadOp>(containingOp);
 }
 
-//===----------------------------------------------------------------------===//
-// ParallelInsertSliceOp
-//===----------------------------------------------------------------------===//
-
-OpResult ParallelInsertSliceOp::getTiedOpResult() {
-  ParallelCombiningOpInterface parallelCombiningParent =
-      getParallelCombiningParent();
-  for (const auto &it :
-       llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
-    Operation &nextOp = it.value();
-    if (&nextOp == getOperation())
-      return parallelCombiningParent.getParentResult(it.index());
-  }
-  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
-}
-
-// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
-void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
-                                  Value source, Value dest,
-                                  ArrayRef<OpFoldResult> offsets,
-                                  ArrayRef<OpFoldResult> sizes,
-                                  ArrayRef<OpFoldResult> strides,
-                                  ArrayRef<NamedAttribute> attrs) {
-  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
-  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
-  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
-                             ShapedType::kDynamicStrideOrOffset);
-  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
-                             ShapedType::kDynamicSize);
-  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
-                             ShapedType::kDynamicStrideOrOffset);
-  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
-        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
-        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
-  result.addAttributes(attrs);
-}
-
-// Build a ParallelInsertSliceOp with dynamic entries.
-void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
-                                  Value source, Value dest, ValueRange offsets,
-                                  ValueRange sizes, ValueRange strides,
-                                  ArrayRef<NamedAttribute> attrs) {
-  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
-      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
-  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
-      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
-  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
-      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
-  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
-}
-
-LogicalResult ParallelInsertSliceOp::verify() {
-  if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
-    return this->emitError("expected ParallelCombiningOpInterface parent, got:")
-           << *(getOperation()->getParentOp());
-  return success();
-}
-
-namespace {
-/// Pattern to rewrite a parallel_insert_slice op with constant arguments.
-class ParallelInsertSliceOpConstantArgumentFolder final
-    : public OpRewritePattern<ParallelInsertSliceOp> {
-public:
-  using OpRewritePattern<ParallelInsertSliceOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp,
-                                PatternRewriter &rewriter) const override {
-    // No constant operand, just return.
-    if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
-          return matchPattern(operand, matchConstantIndex());
-        }))
-      return failure();
-
-    // At least one of offsets/sizes/strides is a new constant.
-    // Form the new list of operands and constant attributes from the
-    // existing.
-    SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
-    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
-    SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
-    canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
-    canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
-    canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
-
-    // Create the new op in canonical form.
-    rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
-        insertSliceOp, insertSliceOp.getSource(), insertSliceOp.getDest(),
-        mixedOffsets, mixedSizes, mixedStrides);
-    return success();
-  }
-};
-} // namespace
-
-/// Fold a parallel_insert_slice source coming from a tensor.cast op.
-///
-/// Example:
-/// ```
-///   %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
-///     %1 = compute_some_tensor() : tensor<64xf32>
-///     %2 = tensor.cast %1 : tensor<64xf32> to tensor<?xf32>
-///     scf.foreach_thread.perform_concurrently {
-///       scf.foreach_thread.parallel_insert_slice %2 into %out[...] [64] [1] :
-///         tensor<?xf32> into tensor<128xf32>
-///     }
-///   }
-/// ```
-///
-/// is folded into:
-/// ```
-///   %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
-///     %1 = compute_some_tensor() : tensor<64xf32>
-///     scf.foreach_thread.perform_concurrently {
-///       scf.foreach_thread.parallel_insert_slice %1 into %out[...] [64] [1] :
-///         tensor<64xf32> into tensor<128xf32>
-///     }
-///   }
-/// ```
-LogicalResult
-ParallelInsertSliceOp::fold(ArrayRef<Attribute> operands,
-                            SmallVectorImpl<OpFoldResult> &results) {
-  auto sourceCast = getSource().getDefiningOp<tensor::CastOp>();
-  if (!sourceCast)
-    return failure();
-  getSourceMutable().assign(sourceCast.getSource());
-  return success();
-}
-
-void ParallelInsertSliceOp::getCanonicalizationPatterns(
-    RewritePatternSet &results, MLIRContext *context) {
-  results.add<ParallelInsertSliceOpConstantArgumentFolder>(context);
-}
-
 //===----------------------------------------------------------------------===//
 // PerformConcurrentlyOp
 //===----------------------------------------------------------------------===//
@@ -1355,10 +1224,12 @@ void PerformConcurrentlyOp::build(OpBuilder &b, OperationState &result) {
 
 LogicalResult PerformConcurrentlyOp::verify() {
   // TODO: PerformConcurrentlyOpInterface.
-  for (const Operation &op : getRegion().front().getOperations())
-    if (!isa<ParallelInsertSliceOp>(op))
-      return emitOpError(
-          "expected only scf.foreach_thread.parallel_insert_slice ops");
+  for (const Operation &op : getRegion().front().getOperations()) {
+    if (!isa<tensor::ParallelInsertSliceOp>(op)) {
+      return this->emitOpError("expected only ")
+             << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
+    }
+  }
   return success();
 }
 
@@ -1396,7 +1267,7 @@ OpResult PerformConcurrentlyOp::getParentResult(int64_t idx) {
 SmallVector<Type> PerformConcurrentlyOp::getYieldedTypes() {
   return llvm::to_vector<4>(
       llvm::map_range(getYieldingOps(), [](Operation &op) {
-        auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(&op);
+        auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
         return insertSliceOp ? insertSliceOp.yieldedType() : Type();
       }));
 }
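With the updated verifier, `scf.foreach_thread.perform_concurrently` accepts only `tensor.parallel_insert_slice` ops in its region, and the diagnostic now reports the tensor-dialect op name. A hedged sketch of what the terminator is expected to contain after this commit (values and shapes are placeholders):

  scf.foreach_thread.perform_concurrently {
    // OK after this commit: the tensor-dialect op.
    tensor.parallel_insert_slice %a into %b[0] [4] [1]
        : tensor<4xf32> into tensor<16xf32>
    // Any other op in this region would now fail verification with
    // "expected only tensor.parallel_insert_slice ops".
  }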
@@ -927,7 +927,7 @@ static SmallVector<OpOperand *>
 getInsertionDest(ForeachThreadOp foreachThreadOp) {
   PerformConcurrentlyOp terminator = foreachThreadOp.getTerminator();
   SmallVector<OpOperand *> result;
-  terminator.walk([&](ParallelInsertSliceOp insertOp) {
+  terminator.walk([&](tensor::ParallelInsertSliceOp insertOp) {
     result.push_back(&insertOp->getOpOperand(1) /*dest*/);
   });
   return result;
@@ -1004,248 +1004,6 @@ struct PerformConcurrentlyOpInterface
   }
 };
 
-/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair match (i.e.
-/// equivalent operand / result and same offset/sizes/strides specification).
-static bool areEquivalentExtractSliceOps(const AnalysisState &state,
-                                         ExtractSliceOp st,
-                                         ParallelInsertSliceOp sti) {
-  if (!st || !sti)
-    return false;
-  if (st != sti &&
-      !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
-    return false;
-  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
-    return false;
-  return true;
-}
-
-/// Return true if `value` is originating from an ExtractSliceOp that matches
-/// the given InsertSliceOp.
-static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
-                                      ParallelInsertSliceOp insertOp) {
-  auto condition = [&](Value val) {
-    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
-      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
-        return true;
-    return false;
-  };
-
-  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
-                      condition);
-}
-
-/// Analysis of ParallelInsertSliceOp.
-struct ParallelInsertSliceOpInterface
-    : public BufferizableOpInterface::ExternalModel<
-          ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
-  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
-                                            const AnalysisState &state) const {
-    if (&opOperand != &op->getOpOperand(1) /*dest*/)
-      return {};
-
-    // ParallelInsertSliceOp itself has no results, query its tied op results.
-    auto insertOp = cast<ParallelInsertSliceOp>(op);
-    return {insertOp.getTiedOpResult()};
-  }
-
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
-                              const AnalysisState &state) const {
-    return true;
-  }
-
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
-                               const AnalysisState &state) const {
-    return &opOperand == &op->getOpOperand(1) /*dest*/;
-  }
-
-  BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const AnalysisState &state) const {
-    return BufferRelation::Equivalent;
-  }
-
-  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
-                                 const AnalysisState &state) const {
-    // This interface method is overridden because we want to set a custom
-    // insertion point for tensor copies. They should be inserted right before
-    // the ForeachThreadOp. E.g.:
-    //
-    // %r0, %r1 = foreach_thead ... {
-    //   ...
-    //   perform_concurrently {
-    //     parallel_insert_slice %a into %b ... {inplace = ["true", "true"]}
-    //     parallel_insert_slice %c into %d ... {inplace = ["true", "false"]}
-    //   }
-    // }
-    //
-    // After TensorCopyInsertion:
-    //
-    // %copy = bufferization.alloc_tensor() copy(%d)
-    // %r0, %r1 = foreach_thead ... {
-    //   ...
-    //   perform_concurrently {
-    //     parallel_insert_slice %a into %b ...
-    //     parallel_insert_slice %c into %copy ...
-    //   }
-    // }
-
-    OpBuilder::InsertionGuard g(rewriter);
-    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
-    ParallelCombiningOpInterface parallelCombiningParent =
-        parallelInsertSliceOp.getParallelCombiningParent();
-    Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
-
-    // Nothing to do if the destination tensor is inplace.
-    assert(state.isInPlace(op->getOpOperand(0) /*src*/) &&
-           "source is always in-place");
-    if (state.isInPlace(op->getOpOperand(1) /*dest*/))
-      return success();
-
-    // Find corresponding OpResult.
-    OpResult opResult = parallelInsertSliceOp.getTiedOpResult();
-
-    // Insert tensor allocation right before the ForeachThreadOp.
-    rewriter.setInsertionPoint(parallelIteratingOp);
-    bool isYielded = state.isTensorYielded(opResult);
-    FailureOr<Value> alloc = allocateTensorForShapedValue(
-        rewriter, op->getLoc(), parallelInsertSliceOp.getDest(),
-        /*escape=*/isYielded, state.getOptions());
-    if (failed(alloc))
-      return failure();
-
-    // Update destination operand.
-    rewriter.updateRootInPlace(parallelInsertSliceOp, [&]() {
-      parallelInsertSliceOp.getDestMutable().assign(*alloc);
-    });
-
-    return success();
-  }
-
-  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
-    OpBuilder::InsertionGuard g(rewriter);
-    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
-    ParallelCombiningOpInterface parallelCombiningParent =
-        parallelInsertSliceOp.getParallelCombiningParent();
-    Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
-
-    // Get destination buffer.
-    FailureOr<Value> destBuffer =
-        getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
-    if (failed(destBuffer))
-      return failure();
-
-    // Bufferize the ParallelInsertSliceOp outside of `parallelCombiningParent`.
-    rewriter.setInsertionPoint(parallelCombiningParent);
-    FailureOr<Value> srcBuffer =
-        getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
-    if (failed(srcBuffer))
-      return failure();
-    Value subview = rewriter.create<memref::SubViewOp>(
-        parallelInsertSliceOp.getLoc(), *destBuffer,
-        parallelInsertSliceOp.getMixedOffsets(),
-        parallelInsertSliceOp.getMixedSizes(),
-        parallelInsertSliceOp.getMixedStrides());
-    // This memcpy will fold away if everything bufferizes in-place.
-    if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
-                                    *srcBuffer, subview)))
-      return failure();
-
-    // Replace all uses of parallelIteratingOp (just the corresponding result).
-    rewriter.setInsertionPointAfter(parallelIteratingOp);
-    Value toTensorOp =
-        rewriter.create<ToTensorOp>(parallelIteratingOp->getLoc(), *destBuffer);
-    // PerformConcurrentlyOp can have multiple ParallelInsertSliceOps.
-    SmallVector<OpOperand *> resultUses = llvm::to_vector(
-        llvm::map_range(parallelInsertSliceOp.getTiedOpResult().getUses(),
-                        [](OpOperand &use) { return &use; }));
-    for (OpOperand *use : resultUses) {
-      rewriter.updateRootInPlace(use->getOwner(),
-                                 [&]() { use->set(toTensorOp); });
-    }
-    rewriter.eraseOp(op);
-    return success();
-  }
-
-  // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
-  // the code.
-  bool isNotConflicting(Operation *op, OpOperand *uRead,
-                        OpOperand *uConflictingWrite,
-                        const AnalysisState &state) const {
-    Operation *readingOp = uRead->getOwner();
-    Operation *conflictingWritingOp = uConflictingWrite->getOwner();
-
-    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
-    // uRead is an InsertSliceOp...
-    if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
-      // As an example, consider the following IR.
-      //
-      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
-      // %1 = linalg.fill %cst, %0 {inplace= [true] }
-      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
-      //     {inplace= [true] }
-
-      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
-      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
-                                    insertSliceOp))
-        // Case 1: The main insight is that InsertSliceOp reads only part of
-        // the destination tensor. The overwritten area is not read. If
-        // uConflictingWrite writes into exactly the memory location that is
-        // being read by uRead, this is not a conflict.
-        //
-        // In the above example:
-        // uRead = OpOperand 1 (%t) of tensor.insert_slice
-        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
-        //
-        // The read of %t does not conflict with the write of the FillOp
-        // (same aliases!) because the area that the FillOp operates on is
-        // exactly the one that is *not* read via %t.
-        return true;
-
-      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
-          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
-        // Case 2: The read of the source tensor and the write to the dest
-        // tensor via an InsertSliceOp is not a conflict if the read is
-        // reading exactly that part of an equivalent tensor that the
-        // InsertSliceOp is writing.
-        //
-        // In the above example:
-        // uRead = OpOperand 0 (%1) of tensor.insert_slice
-        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
-        return true;
-    }
-
-    // If uConflictingWrite is an InsertSliceOp...
-    if (auto insertSliceOp =
-            dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
-      // As an example, consider the following IR.
-      //
-      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
-      // %1 = linalg.fill %cst, %0 {inplace= [true] }
-      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
-      //     {inplace= [true] }
-      // %3 = vector.transfer_read %1, %cst
-      //
-      // In the above example:
-      // uRead = OpOperand 0 (%1) of vector.transfer_read
-      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
-      // lastWrite = %1
-      //
-      // This is not a conflict because the InsertSliceOp overwrites the
-      // memory segment of %1 with the exact same data. (Effectively, there
-      // is no memory write here.)
-      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          state.areEquivalentBufferizedValues(uRead->get(),
-                                              insertSliceOp.getSource()) &&
-          hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
-                                    insertSliceOp))
-        return true;
-
-    return false;
-  }
-};
-
 } // namespace
 } // namespace scf
 } // namespace mlir
@@ -1257,8 +1015,6 @@ void mlir::scf::registerBufferizableOpInterfaceExternalModels(
     ForOp::attachInterface<ForOpInterface>(*ctx);
     IfOp::attachInterface<IfOpInterface>(*ctx);
     ForeachThreadOp::attachInterface<ForeachThreadOpInterface>(*ctx);
-    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
-        *ctx);
     PerformConcurrentlyOp::attachInterface<PerformConcurrentlyOpInterface>(
         *ctx);
     WhileOp::attachInterface<WhileOpInterface>(*ctx);
@@ -26,6 +26,7 @@ add_mlir_dialect_library(MLIRTensorDialect
   MLIRDialectUtils
   MLIRIR
   MLIRInferTypeOpInterface
+  MLIRParallelCombiningOpInterface
   MLIRSideEffectInterfaces
   MLIRSupport
   MLIRViewLikeInterface
@@ -2179,6 +2179,137 @@ OpFoldResult PadOp::fold(ArrayRef<Attribute>) {
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// ParallelInsertSliceOp
+//===----------------------------------------------------------------------===//
+
+OpResult ParallelInsertSliceOp::getTiedOpResult() {
+  ParallelCombiningOpInterface parallelCombiningParent =
+      getParallelCombiningParent();
+  for (const auto &it :
+       llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
+    Operation &nextOp = it.value();
+    if (&nextOp == getOperation())
+      return parallelCombiningParent.getParentResult(it.index());
+  }
+  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
+}
+
+// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
+void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
+                                  Value source, Value dest,
+                                  ArrayRef<OpFoldResult> offsets,
+                                  ArrayRef<OpFoldResult> sizes,
+                                  ArrayRef<OpFoldResult> strides,
+                                  ArrayRef<NamedAttribute> attrs) {
+  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
+  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
+  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
+                             ShapedType::kDynamicStrideOrOffset);
+  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
+                             ShapedType::kDynamicSize);
+  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
+                             ShapedType::kDynamicStrideOrOffset);
+  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
+        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
+        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
+  result.addAttributes(attrs);
+}
+
+// Build a ParallelInsertSliceOp with dynamic entries.
+void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
+                                  Value source, Value dest, ValueRange offsets,
+                                  ValueRange sizes, ValueRange strides,
+                                  ArrayRef<NamedAttribute> attrs) {
+  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
+      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
+  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
+      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
+  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
+      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
+  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
+}
+
+LogicalResult ParallelInsertSliceOp::verify() {
+  if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
+    return this->emitError("expected ParallelCombiningOpInterface parent, got:")
+           << *(getOperation()->getParentOp());
+  return success();
+}
+
+namespace {
+/// Pattern to rewrite a parallel_insert_slice op with constant arguments.
+class ParallelInsertSliceOpConstantArgumentFolder final
+    : public OpRewritePattern<ParallelInsertSliceOp> {
+public:
+  using OpRewritePattern<ParallelInsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp,
+                                PatternRewriter &rewriter) const override {
+    // No constant operand, just return.
+    if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
+          return matchPattern(operand, matchConstantIndex());
+        }))
+      return failure();
+
+    // At least one of offsets/sizes/strides is a new constant.
+    // Form the new list of operands and constant attributes from the
+    // existing.
+    SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
+    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
+    SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
+    canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
+    canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
+    canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
+
+    // Create the new op in canonical form.
+    rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
+        insertSliceOp, insertSliceOp.getSource(), insertSliceOp.getDest(),
+        mixedOffsets, mixedSizes, mixedStrides);
+    return success();
+  }
+};
+} // namespace
+
+/// Fold a parallel_insert_slice source coming from a tensor.cast op.
+///
+/// Example:
+/// ```
+///   %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
+///     %1 = compute_some_tensor() : tensor<64xf32>
+///     %2 = tensor.cast %1 : tensor<64xf32> to tensor<?xf32>
+///     scf.foreach_thread.perform_concurrently {
+///       scf.foreach_thread.parallel_insert_slice %2 into %out[...] [64] [1] :
+///         tensor<?xf32> into tensor<128xf32>
+///     }
+///   }
+/// ```
+///
+/// is folded into:
+/// ```
+///   %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
+///     %1 = compute_some_tensor() : tensor<64xf32>
+///     scf.foreach_thread.perform_concurrently {
+///       scf.foreach_thread.parallel_insert_slice %1 into %out[...] [64] [1] :
+///         tensor<64xf32> into tensor<128xf32>
+///     }
+///   }
+/// ```
+LogicalResult
+ParallelInsertSliceOp::fold(ArrayRef<Attribute> operands,
+                            SmallVectorImpl<OpFoldResult> &results) {
+  auto sourceCast = getSource().getDefiningOp<tensor::CastOp>();
+  if (!sourceCast)
+    return failure();
+  getSourceMutable().assign(sourceCast.getSource());
+  return success();
+}
+
+void ParallelInsertSliceOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<ParallelInsertSliceOpConstantArgumentFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // SplatOp
 //===----------------------------------------------------------------------===//
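For reference, `ParallelInsertSliceOpConstantArgumentFolder` behaves like the existing `insert_slice` folder: offsets, sizes, and strides that are defined by `arith.constant` index values are folded into the static attributes. A sketch, assuming `%c0` and `%c1` are `arith.constant 0 : index` and `arith.constant 1 : index`:

  // Before canonicalization:
  tensor.parallel_insert_slice %src into %dst[%tidx, %c0] [%c1, 5] [%c1, %c1]
      : tensor<?x?xf32> into tensor<?x?xf32>
  // After canonicalization:
  tensor.parallel_insert_slice %src into %dst[%tidx, 0] [1, 5] [1, 1]
      : tensor<?x?xf32> into tensor<?x?xf32>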
@@ -810,6 +810,248 @@ struct ReshapeOpInterface
   }
 };
 
+/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair match (i.e.
+/// equivalent operand / result and same offset/sizes/strides specification).
+static bool areEquivalentExtractSliceOps(const AnalysisState &state,
+                                         ExtractSliceOp st,
+                                         ParallelInsertSliceOp sti) {
+  if (!st || !sti)
+    return false;
+  if (st != sti &&
+      !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
+    return false;
+  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
+    return false;
+  return true;
+}
+
+/// Return true if `value` is originating from an ExtractSliceOp that matches
+/// the given InsertSliceOp.
+static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
+                                      ParallelInsertSliceOp insertOp) {
+  auto condition = [&](Value val) {
+    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
+      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
+        return true;
+    return false;
+  };
+
+  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
+                      condition);
+}
+
+/// Analysis of ParallelInsertSliceOp.
+struct ParallelInsertSliceOpInterface
+    : public BufferizableOpInterface::ExternalModel<
+          ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
+  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    if (&opOperand != &op->getOpOperand(1) /*dest*/)
+      return {};
+
+    // ParallelInsertSliceOp itself has no results, query its tied op results.
+    auto insertOp = cast<ParallelInsertSliceOp>(op);
+    return {insertOp.getTiedOpResult()};
+  }
+
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    return &opOperand == &op->getOpOperand(1) /*dest*/;
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const AnalysisState &state) const {
+    return BufferRelation::Equivalent;
+  }
+
+  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
+                                 const AnalysisState &state) const {
+    // This interface method is overridden because we want to set a custom
+    // insertion point for tensor copies. They should be inserted right before
+    // the ForeachThreadOp. E.g.:
+    //
+    // %r0, %r1 = foreach_thead ... {
+    //   ...
+    //   perform_concurrently {
+    //     parallel_insert_slice %a into %b ... {inplace = ["true", "true"]}
+    //     parallel_insert_slice %c into %d ... {inplace = ["true", "false"]}
+    //   }
+    // }
+    //
+    // After TensorCopyInsertion:
+    //
+    // %copy = bufferization.alloc_tensor() copy(%d)
+    // %r0, %r1 = foreach_thead ... {
+    //   ...
+    //   perform_concurrently {
+    //     parallel_insert_slice %a into %b ...
+    //     parallel_insert_slice %c into %copy ...
+    //   }
+    // }
+
+    OpBuilder::InsertionGuard g(rewriter);
+    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
+    ParallelCombiningOpInterface parallelCombiningParent =
+        parallelInsertSliceOp.getParallelCombiningParent();
+    Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
+
+    // Nothing to do if the destination tensor is inplace.
+    assert(state.isInPlace(op->getOpOperand(0) /*src*/) &&
+           "source is always in-place");
+    if (state.isInPlace(op->getOpOperand(1) /*dest*/))
+      return success();
+
+    // Find corresponding OpResult.
+    OpResult opResult = parallelInsertSliceOp.getTiedOpResult();
+
+    // Insert tensor allocation right before the ForeachThreadOp.
+    rewriter.setInsertionPoint(parallelIteratingOp);
+    bool isYielded = state.isTensorYielded(opResult);
+    FailureOr<Value> alloc = allocateTensorForShapedValue(
+        rewriter, op->getLoc(), parallelInsertSliceOp.getDest(),
+        /*escape=*/isYielded, state.getOptions());
+    if (failed(alloc))
+      return failure();
+
+    // Update destination operand.
+    rewriter.updateRootInPlace(parallelInsertSliceOp, [&]() {
+      parallelInsertSliceOp.getDestMutable().assign(*alloc);
+    });
+
+    return success();
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
+    ParallelCombiningOpInterface parallelCombiningParent =
+        parallelInsertSliceOp.getParallelCombiningParent();
+    Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
+
+    // Get destination buffer.
+    FailureOr<Value> destBuffer =
+        getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
+    if (failed(destBuffer))
+      return failure();
+
+    // Bufferize the ParallelInsertSliceOp outside of `parallelCombiningParent`.
+    rewriter.setInsertionPoint(parallelCombiningParent);
+    FailureOr<Value> srcBuffer =
+        getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
+    if (failed(srcBuffer))
+      return failure();
+    Value subview = rewriter.create<memref::SubViewOp>(
+        parallelInsertSliceOp.getLoc(), *destBuffer,
+        parallelInsertSliceOp.getMixedOffsets(),
+        parallelInsertSliceOp.getMixedSizes(),
+        parallelInsertSliceOp.getMixedStrides());
+    // This memcpy will fold away if everything bufferizes in-place.
+    if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
+                                    *srcBuffer, subview)))
+      return failure();
+
+    // Replace all uses of parallelIteratingOp (just the corresponding result).
+    rewriter.setInsertionPointAfter(parallelIteratingOp);
+    Value toTensorOp =
+        rewriter.create<ToTensorOp>(parallelIteratingOp->getLoc(), *destBuffer);
+    // PerformConcurrentlyOp can have multiple ParallelInsertSliceOps.
+    SmallVector<OpOperand *> resultUses = llvm::to_vector(
+        llvm::map_range(parallelInsertSliceOp.getTiedOpResult().getUses(),
+                        [](OpOperand &use) { return &use; }));
+    for (OpOperand *use : resultUses) {
+      rewriter.updateRootInPlace(use->getOwner(),
+                                 [&]() { use->set(toTensorOp); });
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+  // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
+  // the code.
+  bool isNotConflicting(Operation *op, OpOperand *uRead,
+                        OpOperand *uConflictingWrite,
+                        const AnalysisState &state) const {
+    Operation *readingOp = uRead->getOwner();
+    Operation *conflictingWritingOp = uConflictingWrite->getOwner();
+
+    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
+    // uRead is an InsertSliceOp...
+    if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
+      // As an example, consider the following IR.
+      //
+      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+      // %1 = linalg.fill %cst, %0 {inplace= [true] }
+      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+      //     {inplace= [true] }
+
+      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
+      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
+                                    insertSliceOp))
+        // Case 1: The main insight is that InsertSliceOp reads only part of
+        // the destination tensor. The overwritten area is not read. If
+        // uConflictingWrite writes into exactly the memory location that is
+        // being read by uRead, this is not a conflict.
+        //
+        // In the above example:
+        // uRead = OpOperand 1 (%t) of tensor.insert_slice
+        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
+        //
+        // The read of %t does not conflict with the write of the FillOp
+        // (same aliases!) because the area that the FillOp operates on is
+        // exactly the one that is *not* read via %t.
+        return true;
+
+      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
+          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
+        // Case 2: The read of the source tensor and the write to the dest
+        // tensor via an InsertSliceOp is not a conflict if the read is
+        // reading exactly that part of an equivalent tensor that the
+        // InsertSliceOp is writing.
+        //
+        // In the above example:
+        // uRead = OpOperand 0 (%1) of tensor.insert_slice
+        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+        return true;
+    }
+
+    // If uConflictingWrite is an InsertSliceOp...
+    if (auto insertSliceOp =
+            dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
+      // As an example, consider the following IR.
+      //
+      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+      // %1 = linalg.fill %cst, %0 {inplace= [true] }
+      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+      //     {inplace= [true] }
+      // %3 = vector.transfer_read %1, %cst
+      //
+      // In the above example:
+      // uRead = OpOperand 0 (%1) of vector.transfer_read
+      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+      // lastWrite = %1
+      //
+      // This is not a conflict because the InsertSliceOp overwrites the
+      // memory segment of %1 with the exact same data. (Effectively, there
+      // is no memory write here.)
+      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          state.areEquivalentBufferizedValues(uRead->get(),
+                                              insertSliceOp.getSource()) &&
+          hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
+                                    insertSliceOp))
+        return true;
+
+    return false;
+  }
+};
+
 } // namespace
 } // namespace tensor
 } // namespace mlir
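At the buffer level, `bufferize` materializes each insertion as a `memref.subview` of the destination buffer followed by a copy placed right before the parent combining op; the copy folds away when everything bufferizes in place. Roughly, for a 64-element update into a 128-element destination (the buffer names, the `%off` value, and the use of `memref.copy` as the default memcpy are illustrative assumptions, not output copied from this commit):

  %view = memref.subview %dest_buf[%off] [64] [1]
      : memref<128xf32> to memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
  memref.copy %src_buf, %view
      : memref<64xf32> to memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
  // Uses of the corresponding foreach_thread result are rewired to:
  %t = bufferization.to_tensor %dest_buf : memref<128xf32>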
@@ -827,6 +1069,8 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
     GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
     InsertOp::attachInterface<InsertOpInterface>(*ctx);
     InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
+    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
+        *ctx);
     RankOp::attachInterface<RankOpInterface>(*ctx);
     ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
   });
@@ -1457,28 +1457,3 @@ func.func @func_execute_region_elim_multi_yield() {
 // CHECK: ^[[bb3]](%[[z:.+]]: i64):
 // CHECK:   "test.bar"(%[[z]])
 // CHECK: return
-
-// -----
-
-// CHECK-LABEL: func.func @canonicalize_parallel_insert_slice_indices(
-// CHECK-SAME:   %[[arg0:[0-9a-z]*]]: tensor<?x?xf32>,
-// CHECK-SAME:   %[[arg1:[0-9a-z]*]]: tensor<?x?xf32>,
-// CHECK-SAME:   %[[num_threads:[0-9a-z]*]]: index
-func.func @canonicalize_parallel_insert_slice_indices(
-    %arg0 : tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
-    %num_threads : index) -> tensor<?x?xf32>
-{
-  %cst = arith.constant 4.200000e+01 : f32
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-
-  // CHECK:      scf.foreach_thread (%[[tidx:[0-9a-z]*]]) in (%[[num_threads]]) -> (tensor<?x?xf32>) {
-  // CHECK-NEXT:   scf.foreach_thread.perform_concurrently {
-  // CHECK-NEXT:     scf.foreach_thread.parallel_insert_slice %[[arg0]] into %[[arg1]][%[[tidx]], 0] [1, 5] [1, 1]
-  %2 = scf.foreach_thread (%tidx) in (%num_threads) -> (tensor<?x?xf32>) {
-    scf.foreach_thread.perform_concurrently {
-      scf.foreach_thread.parallel_insert_slice %arg0 into %arg1[%tidx, %c0] [%c1, 5] [%c1, %c1] : tensor<?x?xf32> into tensor<?x?xf32>
-    }
-  }
-  return %2 : tensor<?x?xf32>
-}
@@ -1,36 +1,36 @@
-// RUN: mlir-opt %s -scf-for-loop-canonicalization -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -scf-for-loop-canonicalization | FileCheck %s

-func.func @reduce() -> tensor<128xf32> {
+func.func @reduce() {
+  // CHECK: %[[C64:.*]] = arith.constant 64 : index
   %c2 = arith.constant 2 : index
-  %cst = arith.constant dense<1.000000e+00> : tensor<1x128x384xf32>
   %cst_0 = arith.constant -0.000000e+00 : f32
-  %0 = linalg.init_tensor [128, 384] : tensor<128x384xf32>
-  %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<128x384xf32>) -> tensor<128x384xf32>
-  %2 = linalg.init_tensor [128] : tensor<128xf32>
-  %3 = linalg.fill ins(%cst_0 : f32) outs(%2 : tensor<128xf32>) -> tensor<128xf32>
-  %4 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
+  %0 = memref.alloc() : memref<128x384xf32>
+  linalg.fill ins(%cst_0 : f32) outs(%0 : memref<128x384xf32>)
+  %2 = memref.alloc() : memref<128xf32>
+  linalg.fill ins(%cst_0 : f32) outs(%2 : memref<128xf32>)
+  scf.foreach_thread (%arg0) in (%c2) {
     %7 = affine.min affine_map<(d0) -> (d0 * -64 + 128, 64)>(%arg0)
     %8 = affine.max affine_map<(d0) -> (0, d0)>(%7)
     %9 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg0)
     %10 = affine.min affine_map<(d0, d1) -> (d1 * -64 + 128, d0)>(%8, %arg0)

-    // CHECK: tensor.extract_slice %{{.*}}[%{{.*}}, 0] [64, 384] [1, 1] : tensor<128x384xf32> to tensor<64x384xf32>
-    // CHECK: tensor.extract_slice %{{.*}}[%{{.*}}] [64] [1] : tensor<128xf32> to tensor<64xf32>
-    %11 = tensor.extract_slice %1[%9, 0] [%10, 384] [1, 1] : tensor<128x384xf32> to tensor<?x384xf32>
-    %12 = tensor.extract_slice %3[%9] [%10] [1] : tensor<128xf32> to tensor<?xf32>
+    // CHECK: memref.subview %{{.*}}[%{{.*}}, 0] [%[[C64]], 384] [1, 1] : memref<128x384xf32> to memref<?x384xf32, {{.*}}>
+    // CHECK: memref.subview %{{.*}}[%{{.*}}] [%[[C64]]] [1] : memref<128xf32> to memref<?xf32, {{.*}}>
+    %11 = memref.subview %0[%9, 0] [%10, 384] [1, 1] :
+      memref<128x384xf32> to memref<?x384xf32, affine_map<(d0, d1)[s0] -> (d0 * 384 + s0 + d1)>>
+    %12 = memref.subview %2[%9] [%10] [1] :
+      memref<128xf32> to memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>>

-    // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<64x384xf32>) outs(%{{.*}} : tensor<64xf32>) {
-    %13 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%11 : tensor<?x384xf32>) outs(%12 : tensor<?xf32>) {
-    ^bb0(%arg1: f32, %arg2: f32):
-      %14 = arith.addf %arg1, %arg2 : f32
-      linalg.yield %14 : f32
-    } -> tensor<?xf32>
-    // CHECK-NOT: tensor.cast
-    // CHECK: scf.foreach_thread.parallel_insert_slice %{{.*}} into %{{.*}}[%{{.*}}] [64] [1] : tensor<64xf32> into tensor<128xf32>
-    scf.foreach_thread.perform_concurrently {
-      scf.foreach_thread.parallel_insert_slice %13 into %3[%9] [%10] [1] : tensor<?xf32> into tensor<128xf32>
-    }
+    // CHECK: linalg.generic {{.*}} ins(%{{.*}} : memref<?x384xf32, {{.*}}>) outs(%{{.*}} : memref<?xf32, {{.*}}>)
+    linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                     affine_map<(d0, d1) -> (d0)>],
+                    iterator_types = ["parallel", "reduction"]}
+      ins(%11 : memref<?x384xf32, affine_map<(d0, d1)[s0] -> (d0 * 384 + s0 + d1)>>)
+      outs(%12 : memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>>) {
+    ^bb0(%arg1: f32, %arg2: f32):
+      %14 = arith.addf %arg1, %arg2 : f32
+      linalg.yield %14 : f32
+    }
   }
-  return %4 : tensor<128xf32>
+  return
 }
@@ -531,7 +531,7 @@ func.func @wrong_num_results(%in: tensor<100xf32>, %out: tensor<100xf32>) {
   %result:2 = scf.foreach_thread (%thread_idx) in (%num_threads) -> (tensor<100xf32>, tensor<100xf32>) {
     %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
     scf.foreach_thread.perform_concurrently {
-      scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
+      tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
         tensor<1xf32> into tensor<100xf32>
     }
   }
@@ -548,7 +548,7 @@ func.func @wrong_type_result(%in: tensor<100xf32>, %out: tensor<100xf32>) {
   %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> (tensor<?xf32>) {
     %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
     scf.foreach_thread.perform_concurrently {
-      scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
+      tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
         tensor<1xf32> into tensor<100xf32>
     }
   }
@@ -563,9 +563,9 @@ func.func @wrong_terminator_op(%in: tensor<100xf32>, %out: tensor<100xf32>) {

   %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> (tensor<100xf32>) {
     %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
-    // expected-error @+1 {{expected only scf.foreach_thread.parallel_insert_slice ops}}
+    // expected-error @+1 {{expected only tensor.parallel_insert_slice ops}}
     scf.foreach_thread.perform_concurrently {
-      scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
+      tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
         tensor<1xf32> into tensor<100xf32>
       %0 = arith.constant 1: index
     }
@@ -124,10 +124,10 @@ func.func @scf_foreach_thread_out_of_place(%in: tensor<100xf32>,
   %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<100xf32> {
     // CHECK: tensor.extract_slice
     // CHECK: scf.foreach_thread.perform_concurrently
-    // CHECK: scf.foreach_thread.parallel_insert_slice %{{.*}} into %[[alloc]]
+    // CHECK: tensor.parallel_insert_slice %{{.*}} into %[[alloc]]
     %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
     scf.foreach_thread.perform_concurrently {
-      scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
+      tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
         tensor<1xf32> into tensor<100xf32>
     }
     // CHECK: } {thread_dim_mapping = [5]}
@@ -537,7 +537,7 @@ func.func @parallel_insert_slice_no_conflict(
     // CHECK-NOT: scf.foreach_thread.perform_concurrently
     // CHECK-NOT: parallel_insert_slice
     scf.foreach_thread.perform_concurrently {
-      scf.foreach_thread.parallel_insert_slice %8 into %arg2[5] [%idx] [%c1] :
+      tensor.parallel_insert_slice %8 into %arg2[5] [%idx] [%c1] :
         tensor<?xf32> into tensor<?xf32>
     }
   }
@@ -589,7 +589,7 @@ func.func @parallel_insert_slice_with_conflict(
     // CHECK-NOT: scf.foreach_thread.perform_concurrently
     // CHECK-NOT: parallel_insert_slice
     scf.foreach_thread.perform_concurrently {
-      scf.foreach_thread.parallel_insert_slice %8 into %arg2[5] [%idx] [%c1] :
+      tensor.parallel_insert_slice %8 into %arg2[5] [%idx] [%c1] :
         tensor<?xf32> into tensor<?xf32>
     }
   }
@@ -627,7 +627,7 @@ func.func @matmul(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>, %arg2: tensor<
     // CHECK: linalg.matmul ins({{.*}}memref<4x8xf32, #[[$DYN_LAYOUT_MAP]]>, memref<8x4xf32, #[[$DYN_LAYOUT_MAP]]>) outs({{.*}} : memref<4x4xf32, #[[$DYN_LAYOUT_MAP]]>)
     %8 = linalg.matmul ins(%3, %6 : tensor<4x8xf32>, tensor<8x4xf32>) outs(%7 : tensor<4x4xf32>) -> tensor<4x4xf32>
     scf.foreach_thread.perform_concurrently {
-      scf.foreach_thread.parallel_insert_slice %8 into %arg2[%1, %4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x8xf32>
+      tensor.parallel_insert_slice %8 into %arg2[%1, %4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x8xf32>
     }
   }
   return %0 : tensor<8x8xf32>
@@ -319,14 +319,14 @@ func.func @simple_example(%in: tensor<100xf32>, %out: tensor<100xf32>) {
   // CHECK: scf.foreach_thread
   // CHECK-NEXT: tensor.extract_slice
   // CHECK-NEXT: scf.foreach_thread.perform_concurrently
-  // CHECK-NEXT: scf.foreach_thread.parallel_insert_slice
+  // CHECK-NEXT: tensor.parallel_insert_slice
   // CHECK-NEXT: }
   // CHECK-NEXT: }
   // CHECK-NEXT: return
   %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<100xf32> {
     %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
     scf.foreach_thread.perform_concurrently {
-      scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
+      tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
         tensor<1xf32> into tensor<100xf32>
     }
   }
@@ -1425,3 +1425,28 @@ func.func @cast_extract_slice_rank_reduce(%arg0 : tensor<128x512xf32>, %s : inde
   // CHECK: return %[[E]] : tensor<16xf32>
   return %1 : tensor<16xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @canonicalize_parallel_insert_slice_indices(
+// CHECK-SAME: %[[arg0:[0-9a-z]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[arg1:[0-9a-z]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[num_threads:[0-9a-z]*]]: index
+func.func @canonicalize_parallel_insert_slice_indices(
+    %arg0 : tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+    %num_threads : index) -> tensor<?x?xf32>
+{
+  %cst = arith.constant 4.200000e+01 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+
+  // CHECK: scf.foreach_thread (%[[tidx:[0-9a-z]*]]) in (%[[num_threads]]) -> (tensor<?x?xf32>) {
+  // CHECK-NEXT: scf.foreach_thread.perform_concurrently {
+  // CHECK-NEXT: tensor.parallel_insert_slice %[[arg0]] into %[[arg1]][%[[tidx]], 0] [1, 5] [1, 1]
+  %2 = scf.foreach_thread (%tidx) in (%num_threads) -> (tensor<?x?xf32>) {
+    scf.foreach_thread.perform_concurrently {
+      tensor.parallel_insert_slice %arg0 into %arg1[%tidx, %c0] [%c1, 5] [%c1, %c1] : tensor<?x?xf32> into tensor<?x?xf32>
+    }
+  }
+  return %2 : tensor<?x?xf32>
+}
@@ -4909,6 +4909,7 @@ td_library(
         ":ControlFlowInterfacesTdFiles",
         ":InferTypeOpInterfaceTdFiles",
         ":OpBaseTdFiles",
+        ":ParallelCombiningOpInterfaceTdFiles",
         ":SideEffectInterfacesTdFiles",
         ":TilingInterfaceTdFiles",
         ":ViewLikeInterfaceTdFiles",
@@ -4965,6 +4966,7 @@ cc_library(
         ":DialectUtils",
         ":IR",
         ":InferTypeOpInterface",
+        ":ParallelCombiningOpInterface",
        ":SideEffectInterfaces",
        ":TensorOpsIncGen",
        ":TilingInterface",