forked from OSchip/llvm-project
[mlir][Vector] NFC - Extract rewrites related to insert/extract strided slice in a separate file.
Differential Revision: https://reviews.llvm.org/D112301
This commit is contained in:
parent
cac8808f15
commit
eda2ebd780
|
@ -0,0 +1,58 @@
|
|||
//===- VectorRewritePatterns.h - Vector rewrite patterns --------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_VECTOR_VECTORREWRITEPATTERNS_H_
#define MLIR_DIALECT_VECTOR_VECTORREWRITEPATTERNS_H_

namespace mlir {
class RewritePatternSet;

namespace vector {

/// Populate `patterns` with the following patterns.
///
/// [VectorInsertStridedSliceOpDifferentRankRewritePattern]
/// =======================================================
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
/// have different ranks.
///
/// When ranks are different, InsertStridedSlice needs to extract a properly
/// ranked vector from the destination vector into which to insert. This pattern
/// only takes care of this extraction part and forwards the rest to
/// [VectorInsertStridedSliceOpSameRankRewritePattern].
///
/// For a k-D source and n-D destination vector (k < n), we emit:
///   1. ExtractOp to extract the (unique) k-D subvector of the destination,
///      addressed by the leading n-k offsets, into which to insert the k-D
///      source.
///   2. k-D -> k-D InsertStridedSlice op using the trailing k offsets.
///   3. InsertOp that is the reverse of 1.
///
/// [VectorInsertStridedSliceOpSameRankRewritePattern]
/// ==================================================
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
/// have the same rank (k == n). For each outermost index in the slice:
///   begin    end             stride
/// [offset : offset+size*stride : stride]
///   1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
///   2. InsertStridedSlice (k-1)-D into (n-1)-D
///   3. InsertOp that is the reverse of 1: the updated destination subvector
///      is inserted back in the proper place.
///
/// [VectorExtractStridedSliceOpRewritePattern]
/// ===========================================
/// Progressive lowering of ExtractStridedSliceOp to either:
///   1. single offset extract as a direct vector::ShuffleOp.
///   2. ExtractOp/ExtractElementOp + lower rank ExtractStridedSliceOp +
///      InsertOp/InsertElementOp for the n-D case.
void populateVectorInsertExtractStridedSliceTransforms(
    RewritePatternSet &patterns);

} // namespace vector
} // namespace mlir

#endif // MLIR_DIALECT_VECTOR_VECTORREWRITEPATTERNS_H_
|
|
@ -24,13 +24,6 @@ namespace scf {
|
|||
class IfOp;
|
||||
} // namespace scf
|
||||
|
||||
/// Collect a set of patterns to convert from the Vector dialect to itself.
|
||||
/// Should be merged with populateVectorToSCFLoweringPattern.
|
||||
void populateVectorToVectorConversionPatterns(
|
||||
MLIRContext *context, RewritePatternSet &patterns,
|
||||
ArrayRef<int64_t> coarseVectorShape = {},
|
||||
ArrayRef<int64_t> fineVectorShape = {});
|
||||
|
||||
namespace vector {
|
||||
|
||||
/// Options that control the vector unrolling.
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#ifndef MLIR_DIALECT_VECTOR_VECTORUTILS_H_
|
||||
#define MLIR_DIALECT_VECTOR_VECTORUTILS_H_
|
||||
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
@ -184,6 +185,11 @@ bool checkSameValueRAW(vector::TransferWriteOp defWrite,
|
|||
bool checkSameValueWAW(vector::TransferWriteOp write,
|
||||
vector::TransferWriteOp priorWrite);
|
||||
|
||||
// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
|
||||
SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
|
||||
unsigned dropFront = 0,
|
||||
unsigned dropBack = 0);
|
||||
|
||||
namespace matcher {
|
||||
|
||||
/// Matches vector.transfer_read, vector.transfer_write and ops that return a
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/Support/MathExtras.h"
|
||||
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
|
||||
|
@ -52,17 +53,6 @@ static Value insertOne(ConversionPatternRewriter &rewriter,
|
|||
rewriter.getI64ArrayAttr(pos));
|
||||
}
|
||||
|
||||
// Helper that picks the proper sequence for inserting.
|
||||
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
|
||||
Value into, int64_t offset) {
|
||||
auto vectorType = into.getType().cast<VectorType>();
|
||||
if (vectorType.getRank() > 1)
|
||||
return rewriter.create<InsertOp>(loc, from, into, offset);
|
||||
return rewriter.create<vector::InsertElementOp>(
|
||||
loc, vectorType, from, into,
|
||||
rewriter.create<arith::ConstantIndexOp>(loc, offset));
|
||||
}
|
||||
|
||||
// Helper that picks the proper sequence for extracting.
|
||||
static Value extractOne(ConversionPatternRewriter &rewriter,
|
||||
LLVMTypeConverter &typeConverter, Location loc,
|
||||
|
@ -79,32 +69,6 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
|
|||
rewriter.getI64ArrayAttr(pos));
|
||||
}
|
||||
|
||||
// Helper that picks the proper sequence for extracting.
|
||||
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
|
||||
int64_t offset) {
|
||||
auto vectorType = vector.getType().cast<VectorType>();
|
||||
if (vectorType.getRank() > 1)
|
||||
return rewriter.create<ExtractOp>(loc, vector, offset);
|
||||
return rewriter.create<vector::ExtractElementOp>(
|
||||
loc, vectorType.getElementType(), vector,
|
||||
rewriter.create<arith::ConstantIndexOp>(loc, offset));
|
||||
}
|
||||
|
||||
// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
|
||||
// TODO: Better support for attribute subtype forwarding + slicing.
|
||||
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
|
||||
unsigned dropFront = 0,
|
||||
unsigned dropBack = 0) {
|
||||
assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
|
||||
auto range = arrayAttr.getAsRange<IntegerAttr>();
|
||||
SmallVector<int64_t, 4> res;
|
||||
res.reserve(arrayAttr.size() - dropFront - dropBack);
|
||||
for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
|
||||
it != eit; ++it)
|
||||
res.push_back((*it).getValue().getSExtValue());
|
||||
return res;
|
||||
}
|
||||
|
||||
// Helper that returns data layout alignment of a memref.
|
||||
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
|
||||
MemRefType memrefType, unsigned &align) {
|
||||
|
@ -813,132 +777,6 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
// When ranks are different, InsertStridedSlice needs to extract a properly
|
||||
// ranked vector from the destination vector into which to insert. This pattern
|
||||
// only takes care of this part and forwards the rest of the conversion to
|
||||
// another pattern that converts InsertStridedSlice for operands of the same
|
||||
// rank.
|
||||
//
|
||||
// RewritePattern for InsertStridedSliceOp where source and destination vectors
|
||||
// have different ranks. In this case:
|
||||
// 1. the proper subvector is extracted from the destination vector
|
||||
// 2. a new InsertStridedSlice op is created to insert the source in the
|
||||
// destination subvector
|
||||
// 3. the destination subvector is inserted back in the proper place
|
||||
// 4. the op is replaced by the result of step 3.
|
||||
// The new InsertStridedSlice from step 2. will be picked up by a
|
||||
// `VectorInsertStridedSliceOpSameRankRewritePattern`.
|
||||
class VectorInsertStridedSliceOpDifferentRankRewritePattern
|
||||
: public OpRewritePattern<InsertStridedSliceOp> {
|
||||
public:
|
||||
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(InsertStridedSliceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto srcType = op.getSourceVectorType();
|
||||
auto dstType = op.getDestVectorType();
|
||||
|
||||
if (op.offsets().getValue().empty())
|
||||
return failure();
|
||||
|
||||
auto loc = op.getLoc();
|
||||
int64_t rankDiff = dstType.getRank() - srcType.getRank();
|
||||
assert(rankDiff >= 0);
|
||||
if (rankDiff == 0)
|
||||
return failure();
|
||||
|
||||
int64_t rankRest = dstType.getRank() - rankDiff;
|
||||
// Extract / insert the subvector of matching rank and InsertStridedSlice
|
||||
// on it.
|
||||
Value extracted =
|
||||
rewriter.create<ExtractOp>(loc, op.dest(),
|
||||
getI64SubArray(op.offsets(), /*dropFront=*/0,
|
||||
/*dropBack=*/rankRest));
|
||||
// A different pattern will kick in for InsertStridedSlice with matching
|
||||
// ranks.
|
||||
auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
|
||||
loc, op.source(), extracted,
|
||||
getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
|
||||
getI64SubArray(op.strides(), /*dropFront=*/0));
|
||||
rewriter.replaceOpWithNewOp<InsertOp>(
|
||||
op, stridedSliceInnerOp.getResult(), op.dest(),
|
||||
getI64SubArray(op.offsets(), /*dropFront=*/0,
|
||||
/*dropBack=*/rankRest));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// RewritePattern for InsertStridedSliceOp where source and destination vectors
|
||||
// have the same rank. In this case, we reduce
|
||||
// 1. the proper subvector is extracted from the destination vector
|
||||
// 2. a new InsertStridedSlice op is created to insert the source in the
|
||||
// destination subvector
|
||||
// 3. the destination subvector is inserted back in the proper place
|
||||
// 4. the op is replaced by the result of step 3.
|
||||
// The new InsertStridedSlice from step 2. will be picked up by a
|
||||
// `VectorInsertStridedSliceOpSameRankRewritePattern`.
|
||||
class VectorInsertStridedSliceOpSameRankRewritePattern
|
||||
: public OpRewritePattern<InsertStridedSliceOp> {
|
||||
public:
|
||||
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
|
||||
|
||||
void initialize() {
|
||||
// This pattern creates recursive InsertStridedSliceOp, but the recursion is
|
||||
// bounded as the rank is strictly decreasing.
|
||||
setHasBoundedRewriteRecursion();
|
||||
}
|
||||
|
||||
LogicalResult matchAndRewrite(InsertStridedSliceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto srcType = op.getSourceVectorType();
|
||||
auto dstType = op.getDestVectorType();
|
||||
|
||||
if (op.offsets().getValue().empty())
|
||||
return failure();
|
||||
|
||||
int64_t rankDiff = dstType.getRank() - srcType.getRank();
|
||||
assert(rankDiff >= 0);
|
||||
if (rankDiff != 0)
|
||||
return failure();
|
||||
|
||||
if (srcType == dstType) {
|
||||
rewriter.replaceOp(op, op.source());
|
||||
return success();
|
||||
}
|
||||
|
||||
int64_t offset =
|
||||
op.offsets().getValue().front().cast<IntegerAttr>().getInt();
|
||||
int64_t size = srcType.getShape().front();
|
||||
int64_t stride =
|
||||
op.strides().getValue().front().cast<IntegerAttr>().getInt();
|
||||
|
||||
auto loc = op.getLoc();
|
||||
Value res = op.dest();
|
||||
// For each slice of the source vector along the most major dimension.
|
||||
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
|
||||
off += stride, ++idx) {
|
||||
// 1. extract the proper subvector (or element) from source
|
||||
Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
|
||||
if (extractedSource.getType().isa<VectorType>()) {
|
||||
// 2. If we have a vector, extract the proper subvector from destination
|
||||
// Otherwise we are at the element level and no need to recurse.
|
||||
Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
|
||||
// 3. Reduce the problem to lowering a new InsertStridedSlice op with
|
||||
// smaller rank.
|
||||
extractedSource = rewriter.create<InsertStridedSliceOp>(
|
||||
loc, extractedSource, extractedDest,
|
||||
getI64SubArray(op.offsets(), /* dropFront=*/1),
|
||||
getI64SubArray(op.strides(), /* dropFront=*/1));
|
||||
}
|
||||
// 4. Insert the extractedSource into the res vector.
|
||||
res = insertOne(rewriter, loc, extractedSource, res, off);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, res);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Returns the strides if the memory underlying `memRefType` has a contiguous
|
||||
/// static layout.
|
||||
static llvm::Optional<SmallVector<int64_t, 4>>
|
||||
|
@ -1189,67 +1027,6 @@ private:
|
|||
}
|
||||
};
|
||||
|
||||
/// Progressive lowering of ExtractStridedSliceOp to either:
|
||||
/// 1. express single offset extract as a direct shuffle.
|
||||
/// 2. extract + lower rank strided_slice + insert for the n-D case.
|
||||
class VectorExtractStridedSliceOpConversion
|
||||
: public OpRewritePattern<ExtractStridedSliceOp> {
|
||||
public:
|
||||
using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
|
||||
|
||||
void initialize() {
|
||||
// This pattern creates recursive ExtractStridedSliceOp, but the recursion
|
||||
// is bounded as the rank is strictly decreasing.
|
||||
setHasBoundedRewriteRecursion();
|
||||
}
|
||||
|
||||
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto dstType = op.getType();
|
||||
|
||||
assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
|
||||
|
||||
int64_t offset =
|
||||
op.offsets().getValue().front().cast<IntegerAttr>().getInt();
|
||||
int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
|
||||
int64_t stride =
|
||||
op.strides().getValue().front().cast<IntegerAttr>().getInt();
|
||||
|
||||
auto loc = op.getLoc();
|
||||
auto elemType = dstType.getElementType();
|
||||
assert(elemType.isSignlessIntOrIndexOrFloat());
|
||||
|
||||
// Single offset can be more efficiently shuffled.
|
||||
if (op.offsets().getValue().size() == 1) {
|
||||
SmallVector<int64_t, 4> offsets;
|
||||
offsets.reserve(size);
|
||||
for (int64_t off = offset, e = offset + size * stride; off < e;
|
||||
off += stride)
|
||||
offsets.push_back(off);
|
||||
rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
|
||||
op.vector(),
|
||||
rewriter.getI64ArrayAttr(offsets));
|
||||
return success();
|
||||
}
|
||||
|
||||
// Extract/insert on a lower ranked extract strided slice op.
|
||||
Value zero = rewriter.create<arith::ConstantOp>(
|
||||
loc, elemType, rewriter.getZeroAttr(elemType));
|
||||
Value res = rewriter.create<SplatOp>(loc, dstType, zero);
|
||||
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
|
||||
off += stride, ++idx) {
|
||||
Value one = extractOne(rewriter, loc, op.vector(), off);
|
||||
Value extracted = rewriter.create<ExtractStridedSliceOp>(
|
||||
loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
|
||||
getI64SubArray(op.sizes(), /* dropFront=*/1),
|
||||
getI64SubArray(op.strides(), /* dropFront=*/1));
|
||||
res = insertOne(rewriter, loc, extracted, res, idx);
|
||||
}
|
||||
rewriter.replaceOp(op, res);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Populate the given list with patterns that convert from Vector to LLVM.
|
||||
|
@ -1257,10 +1034,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
|
|||
LLVMTypeConverter &converter, RewritePatternSet &patterns,
|
||||
bool reassociateFPReductions) {
|
||||
MLIRContext *ctx = converter.getDialect()->getContext();
|
||||
patterns.add<VectorFMAOpNDRewritePattern,
|
||||
VectorInsertStridedSliceOpDifferentRankRewritePattern,
|
||||
VectorInsertStridedSliceOpSameRankRewritePattern,
|
||||
VectorExtractStridedSliceOpConversion>(ctx);
|
||||
patterns.add<VectorFMAOpNDRewritePattern>(ctx);
|
||||
populateVectorInsertExtractStridedSliceTransforms(patterns);
|
||||
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
|
||||
patterns
|
||||
.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
add_mlir_dialect_library(MLIRVector
|
||||
VectorOps.cpp
|
||||
VectorInsertExtractStridedSliceRewritePatterns.cpp
|
||||
VectorMultiDimReductionTransforms.cpp
|
||||
VectorOps.cpp
|
||||
VectorTransferOpTransforms.cpp
|
||||
VectorTransforms.cpp
|
||||
VectorUtils.cpp
|
||||
|
|
|
@ -0,0 +1,236 @@
|
|||
//===- VectorInsertExtractStridedSliceRewritePatterns.cpp - Rewrites ------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
|
||||
#include "mlir/Dialect/Vector/VectorUtils.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::vector;
|
||||
|
||||
// Helper that picks the proper sequence for inserting.
|
||||
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
|
||||
Value into, int64_t offset) {
|
||||
auto vectorType = into.getType().cast<VectorType>();
|
||||
if (vectorType.getRank() > 1)
|
||||
return rewriter.create<InsertOp>(loc, from, into, offset);
|
||||
return rewriter.create<vector::InsertElementOp>(
|
||||
loc, vectorType, from, into,
|
||||
rewriter.create<arith::ConstantIndexOp>(loc, offset));
|
||||
}
|
||||
|
||||
// Helper that picks the proper sequence for extracting.
|
||||
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
|
||||
int64_t offset) {
|
||||
auto vectorType = vector.getType().cast<VectorType>();
|
||||
if (vectorType.getRank() > 1)
|
||||
return rewriter.create<ExtractOp>(loc, vector, offset);
|
||||
return rewriter.create<vector::ExtractElementOp>(
|
||||
loc, vectorType.getElementType(), vector,
|
||||
rewriter.create<arith::ConstantIndexOp>(loc, offset));
|
||||
}
|
||||
|
||||
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
|
||||
/// have different ranks.
|
||||
///
|
||||
/// When ranks are different, InsertStridedSlice needs to extract a properly
|
||||
/// ranked vector from the destination vector into which to insert. This pattern
|
||||
/// only takes care of this extraction part and forwards the rest to
|
||||
/// [VectorInsertStridedSliceOpSameRankRewritePattern].
|
||||
///
|
||||
/// For a k-D source and n-D destination vector (k < n), we emit:
|
||||
/// 1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
|
||||
/// insert the k-D source.
|
||||
/// 2. k-D -> (n-1)-D InsertStridedSlice op
|
||||
/// 3. InsertOp that is the reverse of 1.
|
||||
class VectorInsertStridedSliceOpDifferentRankRewritePattern
|
||||
: public OpRewritePattern<InsertStridedSliceOp> {
|
||||
public:
|
||||
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(InsertStridedSliceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto srcType = op.getSourceVectorType();
|
||||
auto dstType = op.getDestVectorType();
|
||||
|
||||
if (op.offsets().getValue().empty())
|
||||
return failure();
|
||||
|
||||
auto loc = op.getLoc();
|
||||
int64_t rankDiff = dstType.getRank() - srcType.getRank();
|
||||
assert(rankDiff >= 0);
|
||||
if (rankDiff == 0)
|
||||
return failure();
|
||||
|
||||
int64_t rankRest = dstType.getRank() - rankDiff;
|
||||
// Extract / insert the subvector of matching rank and InsertStridedSlice
|
||||
// on it.
|
||||
Value extracted =
|
||||
rewriter.create<ExtractOp>(loc, op.dest(),
|
||||
getI64SubArray(op.offsets(), /*dropFront=*/0,
|
||||
/*dropBack=*/rankRest));
|
||||
|
||||
// A different pattern will kick in for InsertStridedSlice with matching
|
||||
// ranks.
|
||||
auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
|
||||
loc, op.source(), extracted,
|
||||
getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
|
||||
getI64SubArray(op.strides(), /*dropFront=*/0));
|
||||
|
||||
rewriter.replaceOpWithNewOp<InsertOp>(
|
||||
op, stridedSliceInnerOp.getResult(), op.dest(),
|
||||
getI64SubArray(op.offsets(), /*dropFront=*/0,
|
||||
/*dropBack=*/rankRest));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
|
||||
/// have the same rank. For each outermost index in the slice:
|
||||
/// begin end stride
|
||||
/// [offset : offset+size*stride : stride]
|
||||
/// 1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
|
||||
/// 2. InsertStridedSlice (k-1)-D into (n-1)-D
|
||||
/// 3. the destination subvector is inserted back in the proper place
|
||||
/// 3. InsertOp that is the reverse of 1.
|
||||
class VectorInsertStridedSliceOpSameRankRewritePattern
|
||||
: public OpRewritePattern<InsertStridedSliceOp> {
|
||||
public:
|
||||
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
|
||||
|
||||
void initialize() {
|
||||
// This pattern creates recursive InsertStridedSliceOp, but the recursion is
|
||||
// bounded as the rank is strictly decreasing.
|
||||
setHasBoundedRewriteRecursion();
|
||||
}
|
||||
|
||||
LogicalResult matchAndRewrite(InsertStridedSliceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto srcType = op.getSourceVectorType();
|
||||
auto dstType = op.getDestVectorType();
|
||||
|
||||
if (op.offsets().getValue().empty())
|
||||
return failure();
|
||||
|
||||
int64_t rankDiff = dstType.getRank() - srcType.getRank();
|
||||
assert(rankDiff >= 0);
|
||||
if (rankDiff != 0)
|
||||
return failure();
|
||||
|
||||
if (srcType == dstType) {
|
||||
rewriter.replaceOp(op, op.source());
|
||||
return success();
|
||||
}
|
||||
|
||||
int64_t offset =
|
||||
op.offsets().getValue().front().cast<IntegerAttr>().getInt();
|
||||
int64_t size = srcType.getShape().front();
|
||||
int64_t stride =
|
||||
op.strides().getValue().front().cast<IntegerAttr>().getInt();
|
||||
|
||||
auto loc = op.getLoc();
|
||||
Value res = op.dest();
|
||||
// For each slice of the source vector along the most major dimension.
|
||||
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
|
||||
off += stride, ++idx) {
|
||||
// 1. extract the proper subvector (or element) from source
|
||||
Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
|
||||
if (extractedSource.getType().isa<VectorType>()) {
|
||||
// 2. If we have a vector, extract the proper subvector from destination
|
||||
// Otherwise we are at the element level and no need to recurse.
|
||||
Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
|
||||
// 3. Reduce the problem to lowering a new InsertStridedSlice op with
|
||||
// smaller rank.
|
||||
extractedSource = rewriter.create<InsertStridedSliceOp>(
|
||||
loc, extractedSource, extractedDest,
|
||||
getI64SubArray(op.offsets(), /* dropFront=*/1),
|
||||
getI64SubArray(op.strides(), /* dropFront=*/1));
|
||||
}
|
||||
// 4. Insert the extractedSource into the res vector.
|
||||
res = insertOne(rewriter, loc, extractedSource, res, off);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, res);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Progressive lowering of ExtractStridedSliceOp to either:
|
||||
/// 1. single offset extract as a direct vector::ShuffleOp.
|
||||
/// 2. ExtractOp/ExtractElementOp + lower rank ExtractStridedSliceOp +
|
||||
/// InsertOp/InsertElementOp for the n-D case.
|
||||
class VectorExtractStridedSliceOpRewritePattern
|
||||
: public OpRewritePattern<ExtractStridedSliceOp> {
|
||||
public:
|
||||
using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
|
||||
|
||||
void initialize() {
|
||||
// This pattern creates recursive ExtractStridedSliceOp, but the recursion
|
||||
// is bounded as the rank is strictly decreasing.
|
||||
setHasBoundedRewriteRecursion();
|
||||
}
|
||||
|
||||
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto dstType = op.getType();
|
||||
|
||||
assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
|
||||
|
||||
int64_t offset =
|
||||
op.offsets().getValue().front().cast<IntegerAttr>().getInt();
|
||||
int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
|
||||
int64_t stride =
|
||||
op.strides().getValue().front().cast<IntegerAttr>().getInt();
|
||||
|
||||
auto loc = op.getLoc();
|
||||
auto elemType = dstType.getElementType();
|
||||
assert(elemType.isSignlessIntOrIndexOrFloat());
|
||||
|
||||
// Single offset can be more efficiently shuffled.
|
||||
if (op.offsets().getValue().size() == 1) {
|
||||
SmallVector<int64_t, 4> offsets;
|
||||
offsets.reserve(size);
|
||||
for (int64_t off = offset, e = offset + size * stride; off < e;
|
||||
off += stride)
|
||||
offsets.push_back(off);
|
||||
rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
|
||||
op.vector(),
|
||||
rewriter.getI64ArrayAttr(offsets));
|
||||
return success();
|
||||
}
|
||||
|
||||
// Extract/insert on a lower ranked extract strided slice op.
|
||||
Value zero = rewriter.create<arith::ConstantOp>(
|
||||
loc, elemType, rewriter.getZeroAttr(elemType));
|
||||
Value res = rewriter.create<SplatOp>(loc, dstType, zero);
|
||||
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
|
||||
off += stride, ++idx) {
|
||||
Value one = extractOne(rewriter, loc, op.vector(), off);
|
||||
Value extracted = rewriter.create<ExtractStridedSliceOp>(
|
||||
loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
|
||||
getI64SubArray(op.sizes(), /* dropFront=*/1),
|
||||
getI64SubArray(op.strides(), /* dropFront=*/1));
|
||||
res = insertOne(rewriter, loc, extracted, res, idx);
|
||||
}
|
||||
rewriter.replaceOp(op, res);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Populate the given list with patterns that convert from Vector to LLVM.
|
||||
void mlir::vector::populateVectorInsertExtractStridedSliceTransforms(
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<VectorInsertStridedSliceOpDifferentRankRewritePattern,
|
||||
VectorInsertStridedSliceOpSameRankRewritePattern,
|
||||
VectorExtractStridedSliceOpRewritePattern>(
|
||||
patterns.getContext());
|
||||
}
|
|
@ -2204,20 +2204,6 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
|
||||
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
|
||||
unsigned dropFront = 0,
|
||||
unsigned dropBack = 0) {
|
||||
assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
|
||||
auto range = arrayAttr.getAsRange<IntegerAttr>();
|
||||
SmallVector<int64_t, 4> res;
|
||||
res.reserve(arrayAttr.size() - dropFront - dropBack);
|
||||
for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
|
||||
it != eit; ++it)
|
||||
res.push_back((*it).getValue().getSExtValue());
|
||||
return res;
|
||||
}
|
||||
|
||||
// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
|
||||
// BroadcastOp(ExtractStrideSliceOp).
|
||||
class StridedSliceBroadcast final
|
||||
|
|
|
@ -1034,10 +1034,11 @@ public:
|
|||
};
|
||||
|
||||
/// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
|
||||
/// vectors progressively on the way from targeting llvm.matrix intrinsics.
|
||||
/// vectors progressively.
|
||||
/// This iterates over the most major dimension of the 2-D vector and performs
|
||||
/// rewrites into:
|
||||
/// vector.strided_slice from 1-D + vector.insert into 2-D
|
||||
/// vector.extract_strided_slice from 1-D + vector.insert into 2-D
|
||||
/// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
|
||||
class ShapeCastOp2DUpCastRewritePattern
|
||||
: public OpRewritePattern<vector::ShapeCastOp> {
|
||||
public:
|
||||
|
|
|
@ -362,3 +362,16 @@ bool mlir::checkSameValueWAW(vector::TransferWriteOp write,
|
|||
priorWrite.getVectorType() == write.getVectorType() &&
|
||||
priorWrite.permutation_map() == write.permutation_map();
|
||||
}
|
||||
|
||||
/// Returns the integer values held by `arrayAttr` as a vector of int64_t,
/// skipping the first `dropFront` and the last `dropBack` entries. Asserts
/// that at least one entry remains after dropping.
SmallVector<int64_t, 4> mlir::getI64SubArray(ArrayAttr arrayAttr,
                                             unsigned dropFront,
                                             unsigned dropBack) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  SmallVector<int64_t, 4> values;
  values.reserve(arrayAttr.size() - dropFront - dropBack);
  auto attrs = arrayAttr.getAsRange<IntegerAttr>();
  auto cursor = attrs.begin() + dropFront;
  auto stop = attrs.end() - dropBack;
  while (cursor != stop) {
    values.push_back((*cursor).getValue().getSExtValue());
    ++cursor;
  }
  return values;
}
|
||||
|
|
Loading…
Reference in New Issue