forked from OSchip/llvm-project
[mlir][VectorOps] Implement insert_strided_slice conversion
Summary: This diff implements the progressive lowering of insert_strided_slice. Two cases appear: 1. when the source and dest vectors have different ranks, extract the dest subvector at the proper offset and reduce to case 2. 2. when they have the same rank N: a. if the source and dest type are the same, the insertion is trivial: just forward the source b. otherwise, iterate over all N-1 D subvectors and create an extract/insert_strided_slice/insert replacement, reducing the problem to vecotrs of the same N-1 rank. This combines properly with the other conversion patterns to lower all the way to LLVM. Reviewers: ftynse, rriddle, AlexEichenberger, andydavis1, tetuante, nicolasvasilache Reviewed By: andydavis1 Subscribers: merge_guards_bot, mehdi_amini, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D72317
This commit is contained in:
parent
65678d9384
commit
2d515e49d8
|
@ -70,6 +70,17 @@ static Value insertOne(ConversionPatternRewriter &rewriter,
|
|||
rewriter.getI64ArrayAttr(pos));
|
||||
}
|
||||
|
||||
// Helper that picks the proper sequence for inserting.
|
||||
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
|
||||
Value into, int64_t offset) {
|
||||
auto vectorType = into.getType().cast<VectorType>();
|
||||
if (vectorType.getRank() > 1)
|
||||
return rewriter.create<InsertOp>(loc, from, into, offset);
|
||||
return rewriter.create<vector::InsertElementOp>(
|
||||
loc, vectorType, from, into,
|
||||
rewriter.create<ConstantIndexOp>(loc, offset));
|
||||
}
|
||||
|
||||
// Helper that picks the proper sequence for extracting.
|
||||
static Value extractOne(ConversionPatternRewriter &rewriter,
|
||||
LLVMTypeConverter &lowering, Location loc, Value val,
|
||||
|
@ -86,6 +97,32 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
|
|||
rewriter.getI64ArrayAttr(pos));
|
||||
}
|
||||
|
||||
// Helper that picks the proper sequence for extracting.
|
||||
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
|
||||
int64_t offset) {
|
||||
auto vectorType = vector.getType().cast<VectorType>();
|
||||
if (vectorType.getRank() > 1)
|
||||
return rewriter.create<ExtractOp>(loc, vector, offset);
|
||||
return rewriter.create<vector::ExtractElementOp>(
|
||||
loc, vectorType.getElementType(), vector,
|
||||
rewriter.create<ConstantIndexOp>(loc, offset));
|
||||
}
|
||||
|
||||
// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
|
||||
// TODO(rriddle): Better support for attribute subtype forwarding + slicing.
|
||||
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
|
||||
unsigned dropFront = 0,
|
||||
unsigned dropBack = 0) {
|
||||
assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
|
||||
auto range = arrayAttr.getAsRange<IntegerAttr>();
|
||||
SmallVector<int64_t, 4> res;
|
||||
res.reserve(arrayAttr.size() - dropFront - dropBack);
|
||||
for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
|
||||
it != eit; ++it)
|
||||
res.push_back((*it).getValue().getSExtValue());
|
||||
return res;
|
||||
}
|
||||
|
||||
class VectorBroadcastOpConversion : public LLVMOpLowering {
|
||||
public:
|
||||
explicit VectorBroadcastOpConversion(MLIRContext *context,
|
||||
|
@ -464,6 +501,139 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
// When ranks are different, InsertStridedSlice needs to extract a properly
|
||||
// ranked vector from the destination vector into which to insert. This pattern
|
||||
// only takes care of this part and forwards the rest of the conversion to
|
||||
// another pattern that converts InsertStridedSlice for operands of the same
|
||||
// rank.
|
||||
//
|
||||
// RewritePattern for InsertStridedSliceOp where source and destination vectors
|
||||
// have different ranks. In this case:
|
||||
// 1. the proper subvector is extracted from the destination vector
|
||||
// 2. a new InsertStridedSlice op is created to insert the source in the
|
||||
// destination subvector
|
||||
// 3. the destination subvector is inserted back in the proper place
|
||||
// 4. the op is replaced by the result of step 3.
|
||||
// The new InsertStridedSlice from step 2. will be picked up by a
|
||||
// `VectorInsertStridedSliceOpSameRankRewritePattern`.
|
||||
class VectorInsertStridedSliceOpDifferentRankRewritePattern
|
||||
: public OpRewritePattern<InsertStridedSliceOp> {
|
||||
public:
|
||||
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(InsertStridedSliceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto srcType = op.getSourceVectorType();
|
||||
auto dstType = op.getDestVectorType();
|
||||
|
||||
if (op.offsets().getValue().empty())
|
||||
return matchFailure();
|
||||
|
||||
auto loc = op.getLoc();
|
||||
int64_t rankDiff = dstType.getRank() - srcType.getRank();
|
||||
assert(rankDiff >= 0);
|
||||
if (rankDiff == 0)
|
||||
return matchFailure();
|
||||
|
||||
int64_t rankRest = dstType.getRank() - rankDiff;
|
||||
// Extract / insert the subvector of matching rank and InsertStridedSlice
|
||||
// on it.
|
||||
Value extracted =
|
||||
rewriter.create<ExtractOp>(loc, op.dest(),
|
||||
getI64SubArray(op.offsets(), /*dropFront=*/0,
|
||||
/*dropFront=*/rankRest));
|
||||
// A different pattern will kick in for InsertStridedSlice with matching
|
||||
// ranks.
|
||||
auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
|
||||
loc, op.source(), extracted,
|
||||
getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
|
||||
getI64SubArray(op.strides(), /*dropFront=*/rankDiff));
|
||||
rewriter.replaceOpWithNewOp<InsertOp>(
|
||||
op, stridedSliceInnerOp.getResult(), op.dest(),
|
||||
getI64SubArray(op.offsets(), /*dropFront=*/0,
|
||||
/*dropFront=*/rankRest));
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
// RewritePattern for InsertStridedSliceOp where source and destination vectors
|
||||
// have the same rank. In this case, we reduce
|
||||
// 1. the proper subvector is extracted from the destination vector
|
||||
// 2. a new InsertStridedSlice op is created to insert the source in the
|
||||
// destination subvector
|
||||
// 3. the destination subvector is inserted back in the proper place
|
||||
// 4. the op is replaced by the result of step 3.
|
||||
// The new InsertStridedSlice from step 2. will be picked up by a
|
||||
// `VectorInsertStridedSliceOpSameRankRewritePattern`.
|
||||
class VectorInsertStridedSliceOpSameRankRewritePattern
|
||||
: public OpRewritePattern<InsertStridedSliceOp> {
|
||||
public:
|
||||
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(InsertStridedSliceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto srcType = op.getSourceVectorType();
|
||||
auto dstType = op.getDestVectorType();
|
||||
|
||||
if (op.offsets().getValue().empty())
|
||||
return matchFailure();
|
||||
|
||||
int64_t rankDiff = dstType.getRank() - srcType.getRank();
|
||||
assert(rankDiff >= 0);
|
||||
if (rankDiff != 0)
|
||||
return matchFailure();
|
||||
|
||||
if (srcType == dstType) {
|
||||
rewriter.replaceOp(op, op.source());
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
int64_t offset =
|
||||
op.offsets().getValue().front().cast<IntegerAttr>().getInt();
|
||||
int64_t size = srcType.getShape().front();
|
||||
int64_t stride =
|
||||
op.strides().getValue().front().cast<IntegerAttr>().getInt();
|
||||
|
||||
auto loc = op.getLoc();
|
||||
Value res = op.dest();
|
||||
// For each slice of the source vector along the most major dimension.
|
||||
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
|
||||
off += stride, ++idx) {
|
||||
// 1. extract the proper subvector (or element) from source
|
||||
Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
|
||||
if (extractedSource.getType().isa<VectorType>()) {
|
||||
// 2. If we have a vector, extract the proper subvector from destination
|
||||
// Otherwise we are at the element level and no need to recurse.
|
||||
Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
|
||||
// 3. Reduce the problem to lowering a new InsertStridedSlice op with
|
||||
// smaller rank.
|
||||
InsertStridedSliceOp insertStridedSliceOp =
|
||||
rewriter.create<InsertStridedSliceOp>(
|
||||
loc, extractedSource, extractedDest,
|
||||
getI64SubArray(op.offsets(), /* dropFront=*/1),
|
||||
getI64SubArray(op.strides(), /* dropFront=*/1));
|
||||
// Call matchAndRewrite recursively from within the pattern. This
|
||||
// circumvents the current limitation that a given pattern cannot
|
||||
// be called multiple times by the PatternRewrite infrastructure (to
|
||||
// avoid infinite recursion, but in this case, infinite recursion
|
||||
// cannot happen because the rank is strictly decreasing).
|
||||
// TODO(rriddle, nicolasvasilache) Implement something like a hook for
|
||||
// a potential function that must decrease and allow the same pattern
|
||||
// multiple times.
|
||||
auto success = matchAndRewrite(insertStridedSliceOp, rewriter);
|
||||
(void)success;
|
||||
assert(success && "Unexpected failure");
|
||||
extractedSource = insertStridedSliceOp;
|
||||
}
|
||||
// 4. Insert the extractedSource into the res vector.
|
||||
res = insertOne(rewriter, loc, extractedSource, res, off);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, res);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
class VectorOuterProductOpConversion : public LLVMOpLowering {
|
||||
public:
|
||||
explicit VectorOuterProductOpConversion(MLIRContext *context,
|
||||
|
@ -725,49 +895,10 @@ private:
|
|||
}
|
||||
};
|
||||
|
||||
// TODO(rriddle): Better support for attribute subtype forwarding + slicing.
|
||||
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
|
||||
unsigned dropFront = 0,
|
||||
unsigned dropBack = 0) {
|
||||
assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
|
||||
auto range = arrayAttr.getAsRange<IntegerAttr>();
|
||||
SmallVector<int64_t, 4> res;
|
||||
res.reserve(arrayAttr.size() - dropFront - dropBack);
|
||||
for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
|
||||
it != eit; ++it)
|
||||
res.push_back((*it).getValue().getSExtValue());
|
||||
return res;
|
||||
}
|
||||
|
||||
/// Emit the proper `ExtractOp` or `ExtractElementOp` depending on the rank
|
||||
/// of `vector`.
|
||||
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
|
||||
int64_t offset) {
|
||||
auto vectorType = vector.getType().cast<VectorType>();
|
||||
if (vectorType.getRank() > 1)
|
||||
return rewriter.create<ExtractOp>(loc, vector, offset);
|
||||
return rewriter.create<vector::ExtractElementOp>(
|
||||
loc, vectorType.getElementType(), vector,
|
||||
rewriter.create<ConstantIndexOp>(loc, offset));
|
||||
}
|
||||
|
||||
/// Emit the proper `InsertOp` or `InsertElementOp` depending on the rank
|
||||
/// of `vector`.
|
||||
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
|
||||
Value into, int64_t offset) {
|
||||
auto vectorType = into.getType().cast<VectorType>();
|
||||
if (vectorType.getRank() > 1)
|
||||
return rewriter.create<InsertOp>(loc, from, into, offset);
|
||||
return rewriter.create<vector::InsertElementOp>(
|
||||
loc, vectorType, from, into,
|
||||
rewriter.create<ConstantIndexOp>(loc, offset));
|
||||
}
|
||||
|
||||
/// Progressive lowering of StridedSliceOp to either:
|
||||
/// 1. extractelement + insertelement for the 1-D case
|
||||
/// 2. extract + optional strided_slice + insert for the n-D case.
|
||||
class VectorStridedSliceOpRewritePattern
|
||||
: public OpRewritePattern<StridedSliceOp> {
|
||||
class VectorStridedSliceOpConversion : public OpRewritePattern<StridedSliceOp> {
|
||||
public:
|
||||
using OpRewritePattern<StridedSliceOp>::OpRewritePattern;
|
||||
|
||||
|
@ -821,7 +952,9 @@ public:
|
|||
void mlir::populateVectorToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
|
||||
MLIRContext *ctx = converter.getDialect()->getContext();
|
||||
patterns.insert<VectorStridedSliceOpRewritePattern>(ctx);
|
||||
patterns.insert<VectorInsertStridedSliceOpDifferentRankRewritePattern,
|
||||
VectorInsertStridedSliceOpSameRankRewritePattern,
|
||||
VectorStridedSliceOpConversion>(ctx);
|
||||
patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion,
|
||||
VectorExtractElementOpConversion, VectorExtractOpConversion,
|
||||
VectorInsertElementOpConversion, VectorInsertOpConversion,
|
||||
|
|
|
@ -427,7 +427,6 @@ func @vector_print_vector(%arg0: vector<2x2xf32>) {
|
|||
|
||||
func @strided_slice(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>, %arg2: vector<4x8x16xf32>) {
|
||||
// CHECK-LABEL: llvm.func @strided_slice(
|
||||
|
||||
%0 = vector.strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
|
||||
// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
|
||||
// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm<"<2 x float>">
|
||||
|
@ -483,4 +482,45 @@ func @strided_slice(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>, %arg2: vector<
|
|||
return
|
||||
}
|
||||
|
||||
func @insert_strided_slice(%a: vector<2x2xf32>, %b: vector<4x4xf32>, %c: vector<4x4x4xf32>) {
|
||||
// CHECK-LABEL: @insert_strided_slice
|
||||
|
||||
%0 = vector.insert_strided_slice %b, %c {offsets = [2, 0, 0], strides = [1, 1]} : vector<4x4xf32> into vector<4x4x4xf32>
|
||||
// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x [4 x <4 x float>]]">
|
||||
// CHECK-NEXT: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x [4 x <4 x float>]]">
|
||||
|
||||
%1 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
|
||||
//
|
||||
// Subvector vector<2xf32> @0 into vector<4xf32> @2
|
||||
// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[2 x <2 x float>]">
|
||||
// CHECK-NEXT: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x <4 x float>]">
|
||||
// Element @0 -> element @2
|
||||
// CHECK-NEXT: llvm.mlir.constant(0 : index) : !llvm.i64
|
||||
// CHECK-NEXT: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
|
||||
// CHECK-NEXT: llvm.mlir.constant(2 : index) : !llvm.i64
|
||||
// CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>">
|
||||
// Element @1 -> element @3
|
||||
// CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK-NEXT: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
|
||||
// CHECK-NEXT: llvm.mlir.constant(3 : index) : !llvm.i64
|
||||
// CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>">
|
||||
// CHECK-NEXT: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x <4 x float>]">
|
||||
//
|
||||
// Subvector vector<2xf32> @1 into vector<4xf32> @3
|
||||
// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"[2 x <2 x float>]">
|
||||
// CHECK-NEXT: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x <4 x float>]">
|
||||
// Element @0 -> element @2
|
||||
// CHECK-NEXT: llvm.mlir.constant(0 : index) : !llvm.i64
|
||||
// CHECK-NEXT: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
|
||||
// CHECK-NEXT: llvm.mlir.constant(2 : index) : !llvm.i64
|
||||
// CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>">
|
||||
// Element @1 -> element @3
|
||||
// CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK-NEXT: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
|
||||
// CHECK-NEXT: llvm.mlir.constant(3 : index) : !llvm.i64
|
||||
// CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>">
|
||||
// CHECK-NEXT: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x <4 x float>]">
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue