[mlir][Tensor] Add rewrites to extract slices through `tensor.collape_shape`

This change adds a set of utilities to replace the result of a
`tensor.collapse_shape -> tensor.extract_slice` chain with the
equivalent result formed by aggregating slices of the
`tensor.collapse_shape` source. In general, it is not possible to
commute `extract_slice` and `collapse_shape` if linearized dimensions
are sliced. The i-th dimension of the `tensor.collapse_shape`
result is a "linearized sliced dimension" if:

1) Reassociation indices of tensor.collapse_shape in the i'th position
   is greater than size 1 (multiple dimensions of the input are collapsed)
2) The i-th dimension is sliced by `tensor.extract_slice`.

We can work around this by stitching together the result of
`tensor.extract_slice` by iterating over any linearized sliced dimensions.
This is equivalent to "tiling" the linearized-and-sliced dimensions of
the `tensor.collapse_shape` operation in order to manifest the result
tile (the result of the `tensor.extract_slice`). The user of the
utilities must provide the mechanism to create the tiling (e.g. a loop).
In the tests, it is demonstrated how to apply the utilities using either
`scf.for` or `scf.foreach_thread`.

The below example illustrates the pattern using `scf.for`:

```
%0 = linalg.generic ... -> tensor<3x7x11x10xf32>
%1 = tensor.collapse_shape %0 [[0, 1, 2], [3]] : ... to tensor<341x10xf32>
%2 = tensor.extract_slice %1 [13, 0] [10, 10] [2, 1] : .... tensor<10x10xf32>
```

We can construct %2 by generating the following IR:

```
%dest = linalg.init_tensor() : tensor<10x10xf32>
%2 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0) -> tensor<10x10xf32> {
   // Step 1: Map this output idx (%iv) to a multi-index for the input (%3):
   %linear_index = affine.apply affine_map<(d0)[]->(d0*2 + 11)>(%iv)
   %3:3 = arith.delinearize_index %iv into (3, 7, 11)
   // Step 2: Extract the slice from the input
   %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
         tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
   %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
         tensor<1x1x1x10xf32> into tensor<1x10xf32>
   // Step 3: Insert the slice into the destination
   %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
         tensor<1x10xf32> into tensor<10x10xf32>
   scf.yield %6 : tensor<10x10xf32>
}
```

The pattern was discussed in the RFC here: https://discourse.llvm.org/t/rfc-tensor-extracting-slices-from-tensor-collapse-shape/64034

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D129699
This commit is contained in:
Christopher Bate 2022-09-08 15:21:57 -06:00
parent 7fa1d743d0
commit f4a478cd01
14 changed files with 940 additions and 3 deletions

View File

@ -283,7 +283,11 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
// Build an ExtractSliceOp with dynamic entries and inferred result type.
OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source,
"ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
// Build an ExtractSliceOp with mixed static and dynamic entries packed in
// a Range vector.
OpBuilder<(ins "Value":$source, "ArrayRef<Range>":$ranges,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
];
let extraClassDeclaration = extraBaseClassDeclaration # [{
@ -739,6 +743,11 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
// Build a InsertSliceOp with dynamic entries.
OpBuilder<(ins "Value":$source, "Value":$dest,
"ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
// Build an InsertSliceOp with mixed static and dynamic entries packed in
// a Range vector.
OpBuilder<(ins "Value":$source, "Value":$dest,
"ArrayRef<Range>":$ranges,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
];
@ -1337,7 +1346,11 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
"ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
"ArrayRef<OpFoldResult>":$strides,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
// Build a ParallelInsertSliceOp with mixed static and dynamic entries
// packed into a Range vector.
OpBuilder<(ins "Value":$source, "Value":$dest,
"ArrayRef<Range>":$ranges,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
// Build a ParallelInsertSliceOp with dynamic entries.
OpBuilder<(ins "Value":$source, "Value":$dest,
"ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,

View File

@ -0,0 +1,210 @@
//===- TransformsUtils.h - Tensor Transformation Utilities-------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMUTILS_H
#define MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMUTILS_H
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
namespace tensor {
//===----------------------------------------------------------------------===//
// Extract slice from `tensor.collapse_shape`
//===----------------------------------------------------------------------===//
/// This class assists with generating IR required to materialize an
/// arbitrary-sized slice from the result of a CollapseShapeOp. In order to
/// accomplish this, a loop nest or similar operation must be created by the
/// caller. The purpose of the loop nest is to generate a "tiling by 1" of all
/// sliced dimensions. The "tiling by 1" assembles all elements of the result
/// tile over dimensions that would have been impossible to directly slice.
///
/// The class provides three methods:
/// 1. `ExtractSliceFromCollapseHelper::create`: emits IR that should
/// appear before the loop nest and populates the internal state.
/// 2. `ExtractSliceFromCollapseHelper::getIterationSpaceSizes`: returns
/// parameters used by the caller to construct the loop nest.
/// 3. `ExtractSliceFromCollapseHelper::emitLoopNestBody`:
/// emits IR to construct a "size-1 tile" of the desired result and returns a
/// set of ranges where the tile should be inserted into the destination
/// tensor.
///
/// ### Intended usage:
///
/// The caller should first call `ExtractSliceFromCollapseHelper::create` and
/// then create a destination tensor that is the same size as the desired slice.
/// The caller then creates a loop nest that iterates over the multi-dimensional
/// iteration space defined by `[0, ub[0]) x [0, ub[1]] x ... x [0, ub[N-1]]`
/// where `ub` is the upper bound given by
/// `ExtractSliceFromCollapseHelper::getIterationSpaceSizes`. Inside the body of
/// the loop nest, the caller should call
/// `ExtractSliceFromCollapseHelper::emitLoopNestBody` and provide the induction
/// variables. This returns a sub-tile and a set of ranges that describe where
/// this tile should be inserted into the result by the caller. For a complete
/// example of usage, see the examples in the TestTensorTransforms pass.
///
/// ### Example:
/// Consider the following IR:
/// ```
/// %0 = linalg.generic ... -> tensor<3x?x?x11x?xf32>
/// %1 = tensor.collapse_shape %0 [[0, 1, 2], [3, 4]]
/// : tensor<3x?x?x11x?xf32> into tensor<?x?xf32>
/// %2 = tensor.extract_slice %1 [%offt0, %offt1][%size0, %size1][1, 1]
/// : tensor<?x?xf32> to tensor<?x?xf32>
/// ```
///
/// We can construct %2 by generating the following, which only uses `%0`:
///
/// ```
/// %dest = linalg.init_tensor [%size0, %size1] : tensor<?x?xf32>
/// %1 = tensor.dim %0, %c1 : tensor<3x?x?x11x?xf32>
/// %2 = tensor.dim %0, %c2 : tensor<3x?x?x11x?xf32>
/// %3 = tensor.dim %0, %c4 : tensor<3x?x?x11x?xf32>
///
/// %result = scf.for %iv0 = %c0 to %arg2 step %c1 iter_args(%arg6 = %dest) ->
/// (tensor<?x?xf32>) {
/// %5 = scf.for %iv1 = %c0 to %arg4 step %c1 iter_args(%arg8 = %arg6)
/// -> (tensor<?x?xf32>) {
/// %lin0 = (affine.apply) %iv0 + %offt0
/// %lin1 = (affine.apply) %iv1 + %offt1
///
/// %mi0:3 = affine.delinearize_index %lin0 into (%c3, %1, %2)
/// %mi1:2 = affine.delinearize_index %lin1 into (%c11, %3)
///
/// %sub_tile = tensor.extract_slice %0
/// [%mi0#0, %mi0#1, %mi0#2, %mi1#0, %mi1#1]
/// [1, 1, 1, 1, 1]
/// [1, 1, 1, 1, 1]
/// : tensor<3x?x?x11x?xf32> to tensor<1x1x1x1x1xf32>
/// %sub_tile_collapsed = tensor.collapse_shape %sub_tile
/// [[0, 1, 2], [3, 4]]
/// : tensor<1x1x1x1x1xf32> into tensor<1x1xf3
///
/// %12 = tensor.insert_slice %sub_tile_collapsed into
/// %arg8[%iv0, %iv1] [1, 1] [1, 1]
/// : tensor<1x1xf32> into tensor<?x?xf32>
/// scf.yield %12 : tensor<?x?xf32>
/// }
/// scf.yield %5 : tensor<?x?xf32>
/// }
/// ```
///
/// ### Explanation of example:
///
/// Each step above is explained below.
///
/// #### Step 0: Create %dest and materialization of shapes.
/// This step is self-explanatory and performed by the caller. It can be
/// done before or after calling `ExtractSliceFromCollapseHelper::create`,
/// which materializes the source shape (`%0, %1, %2`).
///
/// #### Step 1: Create loop nest.
///
/// The caller creates the loop nest (depicted here is `scf.for`, but any other
/// similar op can be used). The iteration should start at zero and proceed with
/// step size 1 to the upper bounds given by
/// `ExtractSliceFromCollapseHelper::getIterationSpaceSizes`. This forms the
/// basis for the "tiling by 1".
///
/// #### Step 2: Transform (%iv0, %iv1) from the index space of %3 to the index
/// space of %0.
///
/// This step is performed by
/// `ExtractSliceFromCollapseHelper::emitLoopNestBody`.
///
/// The induction variables `%iv0` and `%iv1` live in the
/// index space of %2 (for dimensions 0 and 1, respectively). `%lin0` and
/// `%lin1` are the result of inverting or resolve the index space
/// transformation represented by the slice operation, accounting for offset and
/// stride. Subsequently, `%mi0` and `%mi1` are the result of applying the
/// inverse index space transformation represented by `tensor.collapse_shape`.
/// This is accomplished using `affine.delinearize_index`. Note that %iv0
/// and %iv1 now correspond to multi-indices `%mi0:3` and `%mi1:2`.
///
/// #### Step 3: Extract a sub-tile slice from the source.
///
/// This step is also performed by
/// `ExtractSliceFromCollapseHelper::emitLoopNestBody`.
///
/// The indices `%mi0` and `%mi1` are used to extract a slice from %0. This
/// slice is then collapsed down to match the result rank.
///
/// #### Step 4: Insert sub-tile into the destination
///
/// This step is performed by the caller using the results of
/// `ExtractSliceFromCollapseHelper::emitLoopNestBody`.
///
/// In the above example, the slice insertion parameters are straightforward,
/// but in other possible situations, the slice parameters are more complicated,
/// which is why this helper calculates them for the caller. These other
/// situations correspond to:
/// 1. The presence of linearized dimensions that are not sliced
/// 2. The presence of non-linearized dimensions that are sliced.
class ExtractSliceFromCollapseHelper {
public:
/// Given a CollapseShapeOp and a set of ranges describing the desired slice
/// of its result, emits IR to materialize the shapes of the input and output
/// tensors, and returns an instance of the initialized class. Returns failure
/// if the slice is rank-reducing.
static FailureOr<ExtractSliceFromCollapseHelper>
create(OpBuilder &b, tensor::CollapseShapeOp op, ArrayRef<Range> sliceParams);
/// Given a CollapseShapeOp and an ExtractSliceOp acting on its result, emits
/// IR to materialize the shapes of the input and output tensors of the
/// CollapseShapeOp, and returns an instance of the initialized class. Returns
/// failure if the slice is rank-reducing.
static FailureOr<ExtractSliceFromCollapseHelper>
create(OpBuilder &b, tensor::CollapseShapeOp collapseOp,
tensor::ExtractSliceOp extractOp);
ExtractSliceFromCollapseHelper(
tensor::CollapseShapeOp collapseShapeOp,
ArrayRef<OpFoldResult> collapseShapeInputShape,
ArrayRef<OpFoldResult> collapseShapeOutputShape,
ArrayRef<Range> extractSliceParams,
const llvm::SmallBitVector &linearizedDimensions,
const llvm::SmallBitVector &slicedDimensions, ArrayRef<Value> tiledSizes)
: collapseShapeOp(collapseShapeOp),
collapseShapeInputShape(collapseShapeInputShape),
collapseShapeOutputShape(collapseShapeOutputShape),
sliceParams(extractSliceParams),
linearizedDimensions(linearizedDimensions),
slicedDimensions(slicedDimensions), tiledSizes(tiledSizes) {}
/// Return the upper bounds of the iteration space (with 0 offset and stride
/// 1) required to create the desired slice. Note that this is not the same
/// as the `sizes` parameters of the ExtractSliceOp because not all dimensions
/// of the slice are required to be tiled to form the result.
const SmallVector<Value> &getIterationSpaceSizes() { return tiledSizes; }
/// Generates the IR inside of the caller's loop nest for 1) inverting the
/// index mappings of the ExtractSliceOp->CollapseShapeOp chain and 2)
/// extracting the CollapseShapeOp source tensor tile for this specified
/// iteration space point `tileInductionVars` and 3) calculating where to
/// insert the extracted tile. The returned pair consists of the results of
/// (2) and (3) and should be used by the caller to insert into the
/// destination tensor.
std::pair<Value, SmallVector<Range>>
emitLoopNestBody(OpBuilder &builder, Location loc,
ValueRange tileInductionVars);
private:
tensor::CollapseShapeOp collapseShapeOp;
SmallVector<OpFoldResult> collapseShapeInputShape;
SmallVector<OpFoldResult> collapseShapeOutputShape;
SmallVector<Range> sliceParams;
llvm::SmallBitVector linearizedDimensions;
llvm::SmallBitVector slicedDimensions;
SmallVector<Value> tiledSizes;
};
} // namespace tensor
} // namespace mlir
#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMUTILS_H

View File

@ -14,6 +14,7 @@
#ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
#define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
@ -373,6 +374,90 @@ private:
}
};
/// The input parameters `offsets`, `sizes`, `strides` specify a rectangular
/// non rank-reducing slice of the collapse_shape output. Try to find which
/// dimensions have been sliced and which dimensions are not sliced (offset = 0,
/// size = dim, size = 1). Note that this conservative as it cannot detect if a
/// dynamic size corresponds to the full tensor dimension or not.
llvm::SmallBitVector getSlicedDimensions(ArrayRef<OpFoldResult> sliceInputShape,
ArrayRef<Range> sliceParams);
/// Determine which dimensions are linearized by a `tensor.collapse_shape` op by
/// inspecting its reassociation indices.
llvm::SmallBitVector
getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
/// Given the parameters for both operations in a `CollapseShape->ExtractSlice`
/// chain and reified source and result shapes of the CollapseShapeOp, this
/// class provides two functions that assist with directly forming the result
/// of the extract slice by "tiling the CollapseShapeOp by 1".
//// Example:
// clang-format off
/// ```
/// %0 = linalg.generic ... -> tensor<3x7x11x10xf32>
/// %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]] : ... to tensor<341x10xf32>
/// %2 = tensor.extract_slice %1 [13, 0] [10, 10] [2, 1] : .... tensor<10x10xf32>
/// ```
/// This class helps build the below IR to replace %2:
/// ```
/// %dest = linalg.init_tensor() : tensor<10x10xf32>
/// %2 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0) -> tensor<10x10xf32> {
/// %linear_index = affine.apply affine_map<(d0)[]->(d0*2 + 11)>(%iv)
/// %3:3 = arith.delinearize_index %iv into (3, 7, 11)
///
/// // This function takes %3 (multiIndices) and the parameters for the slice below.
/// %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
/// tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
///
/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
/// tensor<1x1x1x10xf32> into tensor<1x10xf32>
/// %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
/// tensor<1x10xf32> into tensor<10x10xf32>
/// scf.yield %6 : tensor<10x10xf32>
/// }
/// ```
// clang-format on
class SliceFromCollapseHelper {
public:
SliceFromCollapseHelper(ArrayRef<ReassociationIndices> reassociationIndices,
ArrayRef<OpFoldResult> collapseShapeInputShape,
ArrayRef<OpFoldResult> collapseShapeOutputShape,
ArrayRef<Range> extractSliceParams)
: reassociationIndices(reassociationIndices),
collapseShapeInputShape(collapseShapeInputShape),
collapseShapeOutputShape(collapseShapeOutputShape),
sliceParams(extractSliceParams),
linearizedDimensions(getLinearizedDimensions(reassociationIndices)),
slicedDimensions(getSlicedDimensions(collapseShapeOutputShape,
extractSliceParams)) {}
/// This function takes multi-indices and maps them to ExtractSlice parameters
/// in the index space of the CollapseShape's source tensor. This function's
/// signature can be described by `(D_0, D_1,.. D_{n-1}) -> (offsets, sizes,
/// strides)` where `n` the number of "tiled dimensions", which are the
/// dimensions of the output that are linearized by the collapse shape op and
/// are also sliced. Each `D_i` is a tuple that must represent a valid
/// multi-index for the `i-th` tiled dimension. In the example above, there is
/// only one tiled dimension (D_0) and `arith.delinearize_index` produces the
/// multi-index (%3) that would be passed to this function to generate the
/// parameters for the `tensor.extract_slice` op (%4).
SmallVector<Range> getExtractSliceParams(ArrayRef<ValueRange> multiIndices);
/// This function takes indices in the index space of the "tiled dimensions"
/// described above and returns a set of Range variables that describe how the
/// slice should be inserted into the destination. In the example above, `%iv`
/// would be passed to this function to generate the parameters for the
/// `tensor.insert_slice` op producing %6.
SmallVector<Range> getInsertSliceParams(ValueRange tileIndices);
private:
SmallVector<ReassociationIndices> reassociationIndices;
SmallVector<OpFoldResult> collapseShapeInputShape;
SmallVector<OpFoldResult> collapseShapeOutputShape;
SmallVector<Range> sliceParams;
llvm::SmallBitVector linearizedDimensions;
llvm::SmallBitVector slicedDimensions;
};
} // namespace mlir
#endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H

View File

@ -29,6 +29,13 @@ struct Range {
OpFoldResult stride;
};
/// Given an array of Range values, return a tuple of (offset vector, sizes
/// vector, and strides vector) formed by separating out the individual elements
/// of each range.
std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
SmallVector<OpFoldResult>>
getOffsetsSizesAndStrides(ArrayRef<Range> ranges);
/// Helper function to dispatch an OpFoldResult into `staticVec` if:
/// a) it is an IntegerAttr
/// In other cases, the OpFoldResult is dispached to the `dynamicVec`.

View File

@ -1210,6 +1210,15 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}
/// Build an ExtractSliceOp with mixed static and dynamic entries packed into a
/// Range vector.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
ArrayRef<Range> ranges,
ArrayRef<NamedAttribute> attrs) {
auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}
/// Build an ExtractSliceOp with dynamic entries and custom result type. If the
/// type passed is nullptr, it is inferred.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
@ -1597,6 +1606,15 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
result.addAttributes(attrs);
}
/// Build an InsertSliceOp with mixed static and dynamic entries packed into a
/// Range vector.
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
Value dest, ArrayRef<Range> ranges,
ArrayRef<NamedAttribute> attrs) {
auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
build(b, result, source, dest, offsets, sizes, strides, attrs);
}
// Build a InsertSliceOp with dynamic entries.
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
Value dest, ValueRange offsets, ValueRange sizes,
@ -2359,6 +2377,16 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
result.addAttributes(attrs);
}
/// Build an ParallelInsertSliceOp with mixed static and dynamic entries packed
/// into a Range vector.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
Value source, Value dest,
ArrayRef<Range> ranges,
ArrayRef<NamedAttribute> attrs) {
auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
build(b, result, source, dest, offsets, sizes, strides, attrs);
}
// Build a ParallelInsertSliceOp with dynamic entries.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
Value source, Value dest, ValueRange offsets,

View File

@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRTensorTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
ExtractSliceFromReshape.cpp
SplitPadding.cpp
SwapExtractSliceWithProducer.cpp

View File

@ -0,0 +1,179 @@
//===- ExtractSliceFromReshape.cpp - Slice reshape rewrites-------*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewrites that replace slices of reshape results with
// aggregated slices of the reshape source.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/STLExtras.h"
using namespace mlir;
using namespace mlir::tensor;
/// Get the dimension size of a value of RankedTensor type at the
OpFoldResult getShapeDimSize(OpBuilder &b, Location loc, Value rankedTensor,
int64_t dimIdx) {
RankedTensorType tensorType = rankedTensor.getType().cast<RankedTensorType>();
if (!tensorType.isDynamicDim(dimIdx)) {
return b.getIndexAttr(tensorType.getDimSize(dimIdx));
}
Value idxValue = b.create<arith::ConstantIndexOp>(loc, dimIdx);
return b.createOrFold<tensor::DimOp>(loc, rankedTensor, idxValue);
}
/// Get all the dimension sizes of a value of RankedTensor type.
static SmallVector<OpFoldResult> getShapeDimSizes(OpBuilder &b, Location loc,
Value rankedTensor) {
SmallVector<OpFoldResult> dimSizes;
RankedTensorType tensorType = rankedTensor.getType().cast<RankedTensorType>();
for (unsigned i = 0; i < tensorType.getRank(); i++)
dimSizes.push_back(getShapeDimSize(b, loc, rankedTensor, i));
return dimSizes;
}
/// A tuple that represents (dimension number, dimension value).
using DimAndIndex = std::tuple<unsigned, Value>;
/// Transform `dimAndIndex` from the output index space of a (non-rank-reducing)
/// slice described by `sliceParams` into the input index space.
static DimAndIndex invertSliceIndexing(OpBuilder &b, Location loc,
ArrayRef<Range> sliceParams,
const DimAndIndex &dimAndIndex) {
AffineExpr d0, s0, s1;
bindDims(b.getContext(), d0);
bindSymbols(b.getContext(), s0, s1);
auto [dim, indexValue] = dimAndIndex;
assert(dim < sliceParams.size() && "slice should be non rank-reducing");
return std::make_pair(
dim,
makeComposedAffineApply(
b, loc, s0 + d0 * s1,
{indexValue,
getValueOrCreateConstantIndexOp(b, loc, sliceParams[dim].offset),
getValueOrCreateConstantIndexOp(b, loc, sliceParams[dim].stride)}));
}
/// Transform `dimAndIndex` from the result tensor index space of a
/// CollapseShapeOp to the source tensor index space.
static ValueRange invertCollapseShapeIndexing(
OpBuilder &b, Location loc, ArrayRef<ReassociationIndices> reassociation,
ArrayRef<OpFoldResult> reshapeSourceShape, const DimAndIndex &dimAndIndex) {
const auto &[dim, indexValue] = dimAndIndex;
SmallVector<OpFoldResult> basis;
for (int64_t i : reassociation[dim])
basis.push_back(reshapeSourceShape[i]);
auto delinearized =
b.create<AffineDelinearizeIndexOp>(loc, indexValue, basis);
return delinearized->getResults();
}
FailureOr<ExtractSliceFromCollapseHelper>
tensor::ExtractSliceFromCollapseHelper::create(
OpBuilder &b, tensor::CollapseShapeOp collapseOp,
tensor::ExtractSliceOp extractOp) {
if (extractOp.getSource().getDefiningOp<tensor::CollapseShapeOp>() !=
collapseOp)
return failure();
SmallVector<Range> ranges;
ranges.reserve(extractOp.getSourceType().getRank());
for (const auto &[o, s, st] :
llvm::zip(extractOp.getMixedOffsets(), extractOp.getMixedSizes(),
extractOp.getMixedStrides())) {
ranges.push_back({o, s, st});
}
return ExtractSliceFromCollapseHelper::create(b, collapseOp, ranges);
}
FailureOr<ExtractSliceFromCollapseHelper>
tensor::ExtractSliceFromCollapseHelper::create(OpBuilder &b,
tensor::CollapseShapeOp op,
ArrayRef<Range> sliceParams) {
// Materialize the output shape of the collapse_shape operation. This will
// create IR describing the output shape in terms of the input shape.
ReifiedRankedShapedTypeDims reifiedShapes;
ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
dyn_cast<ReifyRankedShapedTypeOpInterface>(op.getOperation());
if (failed(reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes)))
return failure();
SmallVector<OpFoldResult> collapseShapeOutputShape =
getAsOpFoldResult(reifiedShapes[0]);
SmallVector<ReassociationIndices> reassociationIndices =
op.getReassociationIndices();
// Determine which of the CollapseShapeOp's result dimensions are sliced
// and/or linearized.
llvm::SmallBitVector linearizedDimensions =
getLinearizedDimensions(reassociationIndices);
llvm::SmallBitVector slicedDimensions =
getSlicedDimensions(collapseShapeOutputShape, sliceParams);
auto collapseShapeInputShape = getShapeDimSizes(b, op.getLoc(), op.getSrc());
SmallVector<OpFoldResult> srcShape =
getShapeDimSizes(b, op->getLoc(), op.getSrc());
SmallVector<Value> tileSizes;
for (unsigned i = 0; i < sliceParams.size(); i++) {
if (slicedDimensions[i] && linearizedDimensions[i])
tileSizes.push_back(
getValueOrCreateConstantIndexOp(b, op.getLoc(), sliceParams[i].size));
}
return ExtractSliceFromCollapseHelper(
op, collapseShapeInputShape, collapseShapeOutputShape, sliceParams,
linearizedDimensions, slicedDimensions, tileSizes);
}
std::pair<Value, SmallVector<Range>>
tensor::ExtractSliceFromCollapseHelper::emitLoopNestBody(
OpBuilder &builder, Location loc, ValueRange tileInductionVars) {
// Create the helper class for forming the slice parameters.
const SmallVector<ReassociationIndices> reassociationIndices =
collapseShapeOp.getReassociationIndices();
SliceFromCollapseHelper helper(reassociationIndices, collapseShapeInputShape,
collapseShapeOutputShape, sliceParams);
// Get the indices of the tiled dims (linearized by the collapse_shape
// and sliced by the extract_slice) invert the index spaces
// transformations.
SmallVector<ValueRange> multiIndices;
unsigned loopIdx = 0;
for (unsigned i = 0, e = linearizedDimensions.size(); i < e; i++) {
if (linearizedDimensions[i] && slicedDimensions[i]) {
DimAndIndex tb =
invertSliceIndexing(builder, loc, sliceParams,
std::make_tuple(i, tileInductionVars[loopIdx++]));
multiIndices.push_back(invertCollapseShapeIndexing(
builder, loc, reassociationIndices, collapseShapeInputShape, tb));
}
}
auto extractParams = helper.getExtractSliceParams(multiIndices);
Value subTileResult = builder.create<tensor::ExtractSliceOp>(
loc, collapseShapeOp.getSrc(), extractParams);
SmallVector<Range> insertParams =
helper.getInsertSliceParams(tileInductionVars);
// Collapse the dimensions of the source slice back down.
Value collapsedResult = builder.create<tensor::CollapseShapeOp>(
loc, subTileResult, reassociationIndices);
return std::make_pair(collapsedResult, insertParams);
}

View File

@ -270,3 +270,88 @@ bool mlir::hasNonIdentityLayout(Type type) {
return !memrefType.getLayout().isIdentity();
return false;
}
llvm::SmallBitVector
mlir::getSlicedDimensions(ArrayRef<OpFoldResult> sliceInputShape,
ArrayRef<Range> sliceParams) {
assert(sliceParams.size() == sliceInputShape.size() &&
"only supports non rank-reducing case");
llvm::SmallBitVector mask(sliceInputShape.size());
unsigned idx = 0;
for (const auto &[offset, size, stride] : sliceParams) {
Optional<int64_t> offsetConst = getConstantIntValue(offset);
Optional<int64_t> strideConst = getConstantIntValue(stride);
mask[idx] = !isEqualConstantIntOrValue(size, sliceInputShape[idx]) ||
(!strideConst || *strideConst != 1) ||
(!offsetConst || *offsetConst != 0);
idx++;
}
return mask;
}
llvm::SmallBitVector mlir::getLinearizedDimensions(
ArrayRef<ReassociationIndices> reassociationIndices) {
llvm::SmallBitVector result(reassociationIndices.size());
for (const auto &it : llvm::enumerate(reassociationIndices))
result[it.index()] = it.value().size() > 1;
return result;
}
SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
ArrayRef<ValueRange> multiIndices) {
assert(!multiIndices.empty() && !multiIndices[0].empty() &&
"multiIndices should not be empty");
unsigned loopIdx = 0;
MLIRContext *ctx = multiIndices[0][0].getContext();
auto oneAttr = IntegerAttr::get(IndexType::get(ctx), 1);
auto zeroAttr = IntegerAttr::get(IndexType::get(ctx), 0);
SmallVector<Range> offsetsSizesAndStrides;
offsetsSizesAndStrides.reserve(collapseShapeInputShape.size());
for (const auto &it : llvm::enumerate(reassociationIndices)) {
// Case 1: Linearized dimensions that have also been sliced. These
// are size of 1 because we are iterating over these dimensions. The
// offsets are exactly the de-linearized multi-indices.
if (slicedDimensions[it.index()] && linearizedDimensions[it.index()]) {
llvm::append_range(
offsetsSizesAndStrides,
llvm::map_range(multiIndices[loopIdx++], [&](Value v) -> Range {
return Range{getAsOpFoldResult(v), oneAttr, oneAttr};
}));
continue;
}
// Case 2: One or possibly multiple combined input dimensions, but we
// have proven that these are not sliced. In this case we just take
// the full extent of each dimension in the reassociation list.
if (linearizedDimensions[it.index()]) {
llvm::append_range(
offsetsSizesAndStrides,
llvm::map_range(it.value(), [&](int64_t idx) -> Range {
return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
}));
continue;
}
// Case 3: A single index, but it may be sliced.
offsetsSizesAndStrides.push_back(sliceParams[it.index()]);
}
return offsetsSizesAndStrides;
}
SmallVector<Range>
SliceFromCollapseHelper::getInsertSliceParams(ValueRange tileIndices) {
MLIRContext *ctx = tileIndices[0].getContext();
auto one = IntegerAttr::get(IndexType::get(ctx), 1);
auto zero = IntegerAttr::get(IndexType::get(ctx), 0);
SmallVector<Range> insertParams;
insertParams.reserve(linearizedDimensions.size());
unsigned loopIdx = 0;
for (unsigned i = 0; i < linearizedDimensions.size(); i++) {
if (linearizedDimensions[i] && slicedDimensions[i]) {
insertParams.push_back(Range{tileIndices[loopIdx++], one, one});
continue;
}
insertParams.push_back(Range{zero, sliceParams[i].size, one});
}
return insertParams;
}

View File

@ -13,6 +13,21 @@
namespace mlir {
std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
SmallVector<OpFoldResult>>
getOffsetsSizesAndStrides(ArrayRef<Range> ranges) {
SmallVector<OpFoldResult> offsets, sizes, strides;
offsets.reserve(ranges.size());
sizes.reserve(ranges.size());
strides.reserve(ranges.size());
for (const auto &[offset, size, stride] : ranges) {
offsets.push_back(offset);
sizes.push_back(size);
strides.push_back(stride);
}
return std::make_tuple(offsets, sizes, strides);
}
/// Helper function to dispatch an OpFoldResult into `staticVec` if:
/// a) it is an IntegerAttr
/// In other cases, the OpFoldResult is dispached to the `dynamicVec`.

View File

@ -0,0 +1,164 @@
// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-rewrite-extract-slice-from-collapse-shape %s | FileCheck %s
// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns="test-rewrite-extract-slice-from-collapse-shape use-foreach" %s | FileCheck %s --check-prefix=FOREACH
func.func @extract_slice_static(%input: tensor<3x5x7x11xf32>) -> tensor<20x11xf32> {
%collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x5x7x11xf32> into tensor<105x11xf32>
%slice = tensor.extract_slice %collapsed [0, 0] [20, 11] [1, 1] : tensor<105x11xf32> to tensor<20x11xf32>
return %slice : tensor<20x11xf32>
}
// CHECK: func.func @extract_slice_static(%[[arg0:.+]]:
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[c20:.+]] = arith.constant 20 : index
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index
// CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index
// CHECK-DAG: %[[init:.+]] = linalg.init_tensor [20, 11] :
// CHECK-DAG: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[c20]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]])
// CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (%[[c3]], %[[c5]], %[[c7]]
// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] :
// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 11] [1, 1] :
// CHECK: scf.yield %[[update]] :
// CHECK: return %[[tile]]
// FOREACH: func.func @extract_slice_static(%[[arg0:.+]]:
// FOREACH-DAG: %[[c20:.+]] = arith.constant 20 : index
// FOREACH-DAG: %[[c3:.+]] = arith.constant 3 : index
// FOREACH-DAG: %[[c5:.+]] = arith.constant 5 : index
// FOREACH-DAG: %[[c7:.+]] = arith.constant 7 : index
// FOREACH-DAG: %[[init:.+]] = linalg.init_tensor [20, 11] :
// FOREACH: %[[tile:.+]] = scf.foreach_thread (%[[iv:.+]]) in (%[[c20]]) shared_outs(%[[dest:.+]] = %[[init]])
// FOREACH: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (%[[c3]], %[[c5]], %[[c7]]
// FOREACH: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] :
// FOREACH: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
// FOREACH: perform_concurrently
// FOREACH-NEXT: tensor.parallel_insert_slice %[[sliceFlat]] into %[[dest]][%[[iv]], 0] [1, 11] [1, 1] :
// FOREACH: return %[[tile]]
// -----
func.func @extract_slice_static_strided(%input: tensor<3x5x7x11xf32>) -> tensor<10x5xf32> {
%collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x5x7x11xf32> into tensor<105x11xf32>
%slice = tensor.extract_slice %collapsed [13, 0] [10, 5] [2, 2] : tensor<105x11xf32> to tensor<10x5xf32>
return %slice : tensor<10x5xf32>
}
// CHECK: #[[$map0:.+]] = affine_map<(d0) -> (d0 * 2 + 13)>
// CHECK: func.func @extract_slice_static_strided(%[[arg0:.+]]:
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index
// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index
// CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index
// CHECK: %[[init:.+]] = linalg.init_tensor [10, 5] :
// CHECK: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[c10]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]])
// CHECK: %[[inputIv:.+]] = affine.apply #[[$map0]](%[[iv]])
// CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[inputIv]] into (%[[c3]], %[[c5]], %[[c7]]
// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 5] [1, 1, 1, 2] :
// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 5] [1, 1] :
// CHECK: scf.yield %[[update]] :
// CHECK: return %[[tile]]
// -----
func.func @extract_slice_dynamic(%input: tensor<3x?x?x11xf32>, %offt: index, %size: index) -> tensor<?x5xf32> {
%collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x?x?x11xf32> into tensor<?x11xf32>
%slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 5] [2, 2] : tensor<?x11xf32> to tensor<?x5xf32>
return %slice : tensor<?x5xf32>
}
// CHECK: #[[map0:.+]] = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
// CHECK: func.func @extract_slice_dynamic(%[[arg0:.+]]: tensor<{{.*}}>, %[[lb:.+]]: index, %[[sz:.+]]: index)
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
// CHECK: %[[init:.+]] = linalg.init_tensor [%[[sz]], 5] : tensor<?x5xf32>
// CHECK-DAG: %[[d1:.+]] = tensor.dim %arg0, %[[c1]] : tensor<3x?x?x11xf32>
// CHECK-DAG: %[[d2:.+]] = tensor.dim %arg0, %[[c2]] : tensor<3x?x?x11xf32>
// CHECK: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[sz]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]])
// CHECK: %[[inputIv:.+]] = affine.apply #[[map0]](%[[iv]])[%[[lb]]]
// CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[inputIv]] into (%[[c3]], %[[d1]], %[[d2]]) :
// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 5] [1, 1, 1, 2] :
// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 5] [1, 1] :
// CHECK: scf.yield %[[update]] :
// CHECK: return %[[tile]] :
// -----
func.func @extract_slice_dynamic_multidim(%input: tensor<3x?x?x11x?xf32>, %offt0: index, %size0: index, %offt1: index, %size1: index) -> tensor<?x?xf32> {
%collapsed = tensor.collapse_shape %input [[0, 1, 2], [3, 4]] : tensor<3x?x?x11x?xf32> into tensor<?x?xf32>
%slice = tensor.extract_slice %collapsed [%offt0, %offt1] [%size0, %size1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
return %slice : tensor<?x?xf32>
}
// CHECK: #[[map0:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
// CHECK: func.func @extract_slice_dynamic_multidim(%[[arg0:.+]]: tensor<3x?x?x11x?xf32>, %[[lb1:.+]]: index, %[[sz1:.+]]: index, %[[lb2:.+]]: index, %[[sz2:.+]]: index)
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[c4:.+]] = arith.constant 4 : index
// CHECK-DAG: %[[c11:.+]] = arith.constant 11 : index
// CHECK: %[[init:.+]] = linalg.init_tensor [%[[sz1]], %[[sz2]]] : tensor<?x?xf32>
// CHECK-DAG: %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] :
// CHECK-DAG: %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] :
// CHECK-DAG: %[[d4:.+]] = tensor.dim %[[arg0]], %[[c4]] :
// CHECK: %[[tile1:.+]] = scf.for %[[iv1:.+]] = %[[c0]] to %[[sz1]] step %[[c1]] iter_args(%[[iterArg1:.+]] = %[[init]])
// CHECK: %[[tile2:.+]] = scf.for %[[iv2:.+]] = %[[c0]] to %[[sz2]] step %[[c1]] iter_args(%[[iterArg2:.+]] = %[[iterArg1]])
// CHECK: %[[inputIv1:.+]] = affine.apply #[[map0:.+]](%[[iv1]])[%[[lb1]]]
// CHECK: %[[multiIndex1:.+]]:3 = affine.delinearize_index %[[inputIv1]] into (%[[c3]], %[[d1]], %[[d2]]) :
// CHECK: %[[inputIv2:.+]] = affine.apply #[[map0:.+]](%[[iv2]])[%[[lb2]]]
// CHECK: %[[multiIndex2:.+]]:2 = affine.delinearize_index %[[inputIv2]] into (%[[c11]], %[[d4]]) :
// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex1]]#0, %[[multiIndex1]]#1, %[[multiIndex1]]#2, %[[multiIndex2]]#0, %[[multiIndex2]]#1] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] :
// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3, 4]{{\]}} :
// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg2]][%[[iv1]], %[[iv2]]] [1, 1] [1, 1] :
// CHECK: scf.yield %[[update]] :
// CHECK: scf.yield %[[tile2]] :
// CHECK: return %[[tile1]] :
// FOREACH: #[[map1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
// FOREACH: func.func @extract_slice_dynamic_multidim(%[[arg0:.+]]: tensor<3x?x?x11x?xf32>, %[[lb1:.+]]: index, %[[sz1:.+]]: index, %[[lb2:.+]]: index, %[[sz2:.+]]: index)
// FOREACH-DAG: %[[c1:.+]] = arith.constant 1 : index
// FOREACH-DAG: %[[c2:.+]] = arith.constant 2 : index
// FOREACH-DAG: %[[c3:.+]] = arith.constant 3 : index
// FOREACH-DAG: %[[c4:.+]] = arith.constant 4 : index
// FOREACH-DAG: %[[c11:.+]] = arith.constant 11 : index
// FOREACH: %[[init:.+]] = linalg.init_tensor [%[[sz1]], %[[sz2]]] : tensor<?x?xf32>
// FOREACH-DAG: %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] :
// FOREACH-DAG: %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] :
// FOREACH-DAG: %[[d4:.+]] = tensor.dim %[[arg0]], %[[c4]] :
// FOREACH: %[[tile1:.+]] = scf.foreach_thread (%[[tid1:.+]], %[[tid2:.+]]) in (%[[sz1]], %[[sz2]]) shared_outs(%[[dest:.+]] = %[[init]])
// FOREACH-DAG: %[[iv1:.+]] = affine.apply #[[map1]](%[[tid1]])[%[[lb1]]]
// FOREACH: %[[multiIndex1:.+]]:3 = affine.delinearize_index %[[iv1]] into (%[[c3]], %[[d1]], %[[d2]]) :
// FOREACH-DAG: %[[iv2:.+]] = affine.apply #[[map1]](%[[tid2]])[%[[lb2]]]
// FOREACH: %[[multiIndex2:.+]]:2 = affine.delinearize_index %[[iv2]] into (%[[c11]], %[[d4]]) :
// FOREACH: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex1]]#0, %[[multiIndex1]]#1, %[[multiIndex1]]#2, %[[multiIndex2]]#0, %[[multiIndex2]]#1] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] :
// FOREACH: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3, 4]{{\]}} :
// FOREACH: perform_concurrently
//FOREACH-NEXT: tensor.parallel_insert_slice %[[sliceFlat]] into %[[dest]][%[[tid1]], %[[tid2]]] [1, 1] [1, 1] :
// -----
// Verifies that a linearized dimension that is not sliced does not generate a loop. Note that this
// only works for static shapes.
// CHECK: @extract_slice_non_sliced_linearized_dim(%[[arg0:.+]]: tensor<{{.*}}>,
func.func @extract_slice_non_sliced_linearized_dim(%input: tensor<3x?x?x11x2xf32>, %offt: index, %size: index) -> tensor<?x22xf32> {
%collapsed = tensor.collapse_shape %input [[0, 1, 2], [3, 4]] : tensor<3x?x?x11x2xf32> into tensor<?x22xf32>
%slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 22] [1, 1] : tensor<?x22xf32> to tensor<?x22xf32>
// CHECK: scf.for
// CHECK-NOT: scf.for
// CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index
// CHECK: tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0, 0] [1, 1, 1, 11, 2] [1, 1, 1, 1, 1]
return %slice : tensor<?x22xf32>
}

View File

@ -6,6 +6,7 @@ add_mlir_library(MLIRTensorTestPasses
LINK_LIBS PUBLIC
MLIRArithmeticDialect
MLIRLinalgDialect
MLIRPass
MLIRSCFDialect
MLIRTensorDialect

View File

@ -11,8 +11,10 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@ -28,7 +30,8 @@ struct TestTensorTransforms
TestTensorTransforms(const TestTensorTransforms &pass) : PassWrapper(pass) {}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<arith::ArithmeticDialect, scf::SCFDialect>();
registry.insert<arith::ArithmeticDialect, scf::SCFDialect,
linalg::LinalgDialect>();
}
StringRef getArgument() const final {
@ -49,6 +52,19 @@ struct TestTensorTransforms
*this, "test-fold-constant-extract-slice",
llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"),
llvm::cl::init(false)};
Option<bool> testRewriteExtractSliceWithTiledCollapseShape{
*this, "test-rewrite-extract-slice-from-collapse-shape",
llvm::cl::desc("Test swapping tensor.extract_slice of a collapse_shape "
"with loop nest"),
llvm::cl::init(false)};
Option<bool> useForeach{
*this, "use-foreach",
llvm::cl::desc(
"Use the scf.foreach_thread operation when generating loop nests for "
"the extract_slice of collapse_shape pattern"),
llvm::cl::init(false)};
};
} // namespace
@ -74,12 +90,142 @@ static void applyFoldConstantExtractSlicePatterns(Operation *rootOp) {
(void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}
namespace {
/// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
/// The `tensor.extract_slice` is replaced by a loop or gather operation that
/// stitches together the desired tile from slices of the source of the collapse
/// shape op.
struct RewriteExtractSliceFromCollapseShapeBase
: public OpRewritePattern<tensor::ExtractSliceOp> {
RewriteExtractSliceFromCollapseShapeBase(MLIRContext *context)
: mlir::OpRewritePattern<tensor::ExtractSliceOp>(context) {}
/// Emit a loop or gather operation that uses `helper` to take each point in
/// the parallel iteration space bounds, extract a slice from the source
/// tensor and insert it into `dest`. For examples, see below for `scf.for`
/// and `scf.foreach`.
virtual LogicalResult
emitReplacement(tensor::ExtractSliceOp op, Value dest,
tensor::ExtractSliceFromCollapseHelper &helper,
PatternRewriter &rewriter) const = 0;
LogicalResult matchAndRewrite(tensor::ExtractSliceOp op,
PatternRewriter &rewriter) const override {
auto collapseOp = op.getSource().getDefiningOp<tensor::CollapseShapeOp>();
if (!collapseOp)
return rewriter.notifyMatchFailure(
op, "producer is not a tensor.collapse_shape op");
// Materialize the output shape values of the slice operation.a
ReifiedRankedShapedTypeDims reifiedShapes;
if (failed(op.reifyResultShapes(rewriter, reifiedShapes)))
return rewriter.notifyMatchFailure(op, "failed to reify result shapes");
// Create the destination tensor using the above values.
Type elementType = op.getSourceType().getElementType();
SmallVector<OpFoldResult> outputShape = getAsOpFoldResult(reifiedShapes[0]);
Value dest = rewriter.create<linalg::InitTensorOp>(
op->getLoc(), outputShape, elementType);
// Calculate the parameters for the tile loop nest.
FailureOr<tensor::ExtractSliceFromCollapseHelper> params =
tensor::ExtractSliceFromCollapseHelper::create(rewriter, collapseOp,
op);
if (failed(params))
return rewriter.notifyMatchFailure(
op, "could not calculate tiling parameters");
return emitReplacement(op, dest, *params, rewriter);
}
};
struct RewriteExtractSliceFromCollapseShapeUsingScfFor
: public RewriteExtractSliceFromCollapseShapeBase {
RewriteExtractSliceFromCollapseShapeUsingScfFor(MLIRContext *context)
: RewriteExtractSliceFromCollapseShapeBase(context) {}
LogicalResult emitReplacement(tensor::ExtractSliceOp op, Value dest,
tensor::ExtractSliceFromCollapseHelper &helper,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
const unsigned numTiledDims = helper.getIterationSpaceSizes().size();
auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
SmallVector<Value> lbs(numTiledDims, zero);
SmallVector<Value> steps(numTiledDims, one);
scf::LoopNest nest = scf::buildLoopNest(
rewriter, loc, lbs, helper.getIterationSpaceSizes(), steps, dest,
[&](OpBuilder &nestedBuilder, Location loc, ValueRange outputIvs,
ValueRange iterArgs) -> scf::ValueVector {
auto [tile, insertParams] =
helper.emitLoopNestBody(nestedBuilder, loc, outputIvs);
// Insert the slice into the destination.
Value result = nestedBuilder.create<tensor::InsertSliceOp>(
loc, tile, iterArgs[0], insertParams);
return {result};
});
rewriter.replaceOp(op, nest.getResults()[0]);
return success();
}
};
struct RewriteExtractSliceFromCollapseShapeUsingScfForeach
: public RewriteExtractSliceFromCollapseShapeBase {
RewriteExtractSliceFromCollapseShapeUsingScfForeach(MLIRContext *context)
: RewriteExtractSliceFromCollapseShapeBase(context) {}
LogicalResult emitReplacement(tensor::ExtractSliceOp op, Value dest,
tensor::ExtractSliceFromCollapseHelper &helper,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto foreachOp = rewriter.create<scf::ForeachThreadOp>(
loc, /*outputs=*/dest, /*numThreads=*/helper.getIterationSpaceSizes(),
/*threadDimMapping=*/ArrayRef<int64_t>{},
[&](OpBuilder &nestedBuilder, Location loc, ValueRange regionArgs) {
unsigned numThreadIdRegionArgs =
helper.getIterationSpaceSizes().size();
unsigned numOutputRegionArgs =
regionArgs.size() - numThreadIdRegionArgs;
ValueRange outputIvs = regionArgs.take_front(numThreadIdRegionArgs);
ValueRange outputArgs = regionArgs.take_back(numOutputRegionArgs);
assert(outputArgs.size() == 1 &&
"there should only be one output region argument");
auto [tile, insertParams] =
helper.emitLoopNestBody(nestedBuilder, loc, outputIvs);
// Insert the slice into the destination.
auto term = nestedBuilder.create<scf::PerformConcurrentlyOp>(loc);
nestedBuilder.setInsertionPointToStart(term.getBody());
nestedBuilder.create<tensor::ParallelInsertSliceOp>(
loc, tile, outputArgs[0], insertParams);
});
rewriter.replaceOp(op, foreachOp->getResult(0));
return success();
}
};
} // namespace
static LogicalResult
applyRewriteExtractFromCollapseShapePatterns(Operation *rootOp,
bool useForeach) {
RewritePatternSet patterns(rootOp->getContext());
if (useForeach)
patterns.add<RewriteExtractSliceFromCollapseShapeUsingScfForeach>(
rootOp->getContext());
else
patterns.add<RewriteExtractSliceFromCollapseShapeUsingScfFor>(
rootOp->getContext());
return applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}
void TestTensorTransforms::runOnOperation() {
Operation *rootOp = getOperation();
if (testSplitPaddingPatterns)
applySplitPaddingPatterns(rootOp);
if (testFoldConstantExtractSlice)
applyFoldConstantExtractSlicePatterns(rootOp);
if (testRewriteExtractSliceWithTiledCollapseShape) {
if (failed(
applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))
return signalPassFailure();
}
}
namespace mlir {

View File

@ -5061,11 +5061,13 @@ cc_library(
"include/mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h",
"include/mlir/Dialect/Tensor/Transforms/Passes.h",
"include/mlir/Dialect/Tensor/Transforms/Transforms.h",
"include/mlir/Dialect/Tensor/Transforms/TransformUtils.h"
],
includes = ["include"],
deps = [
":AffineDialect",
":ArithmeticDialect",
":ArithmeticUtils",
":BufferizationDialect",
":BufferizationTransforms",
":DialectUtils",

View File

@ -620,6 +620,7 @@ cc_library(
includes = ["lib/Dialect/Test"],
deps = [
"//mlir:ArithmeticDialect",
"//mlir:LinalgDialect",
"//mlir:Pass",
"//mlir:SCFDialect",
"//mlir:TensorDialect",