forked from OSchip/llvm-project
[mlir][vector] Add folding for ExtractOp with ShapeCastOp source
Differential Revision: https://reviews.llvm.org/D89853
This commit is contained in:
parent
4b90a253c2
commit
8c72eea9a0
|
@ -843,6 +843,61 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
|
|||
return Value();
|
||||
}
|
||||
|
||||
// Fold extractOp with source coming from ShapeCast op.
//
// A shape_cast only reshuffles the (row-major) linearization of its operand's
// elements, so an extract of the cast result can often be rewritten as an
// extract directly from the cast's source: linearize the extract position
// using the cast-result strides, then delinearize it using the cast-source
// strides. Returns the (rewritten-in-place) extract result on success, or a
// null Value when the fold does not apply.
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
  auto shapeCastOp = extractOp.vector().getDefiningOp<vector::ShapeCastOp>();
  if (!shapeCastOp)
    return Value();
  // Get the nth dimension size starting from lowest dimension.
  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getDimSize(type.getRank() - n - 1);
  };
  // Rank of the value produced by the extract (0 means a scalar result).
  int64_t destinationRank =
      extractOp.getVectorType().getRank() - extractOp.position().size();
  // Cannot extract a result with more dims than the cast source provides.
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
    return Value();
  if (destinationRank > 0) {
    auto destinationType = extractOp.getResult().getType().cast<VectorType>();
    for (int64_t i = 0; i < destinationRank; i++) {
      // The lowest dimensions of the destination must match the lowest
      // dimensions of the shapecast op source, so that the extracted sub-vector
      // is a contiguous slice in both layouts.
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))
        return Value();
    }
  }
  // Extract the strides associated with the extract op vector source. Then use
  // this to calculate a linearized position for the extract.
  // Positions and strides are processed lowest-dimension-first (reversed).
  auto extractedPos = extractVector<int64_t>(extractOp.position());
  std::reverse(extractedPos.begin(), extractedPos.end());
  SmallVector<int64_t, 4> strides;
  int64_t stride = 1;
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
    // Skip the `destinationRank` lowest dims: they are kept by the extract.
    stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
  }

  int64_t position = linearize(extractedPos, strides);
  // Then extract the strides associated to the shapeCast op vector source and
  // delinearize the position using those strides.
  SmallVector<int64_t, 4> newStrides;
  int64_t numDimension =
      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  stride = 1;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
    stride *=
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  }
  // delinearize expects strides highest-dimension-first; undo the reversal.
  std::reverse(newStrides.begin(), newStrides.end());
  SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
  // Rewrite the extract in place: new position attribute, and take the
  // shape_cast's source as the operand, bypassing the cast entirely.
  OpBuilder b(extractOp.getContext());
  extractOp.setAttr(ExtractOp::getPositionAttrName(),
                    b.getI64ArrayAttr(newPosition));
  extractOp.setOperand(shapeCastOp.source());
  return extractOp.getResult();
}
|
||||
|
||||
OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
|
||||
if (succeeded(foldExtractOpFromExtractChain(*this)))
|
||||
return getResult();
|
||||
|
@ -852,6 +907,8 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
|
|||
return val;
|
||||
if (auto val = foldExtractFromBroadcast(*this))
|
||||
return val;
|
||||
if (auto val = foldExtractFromShapeCast(*this))
|
||||
return val;
|
||||
return OpFoldResult();
|
||||
}
|
||||
|
||||
|
|
|
@ -394,6 +394,39 @@ func @fold_extract_broadcast_negative(%a : f32) -> vector<4xf32> {
|
|||
return %r : vector<4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Positive test: vector.extract of a vector.shape_cast result folds to an
// extract directly from the cast source, with the position re-linearized
// for the source shape. The shape_cast ops themselves become dead.
// CHECK-LABEL: func @fold_extract_shapecast
// CHECK-SAME: (%[[A0:.*]]: vector<5x1x3x2xf32>, %[[A1:.*]]: vector<8x4x2xf32>
// CHECK: %[[R0:.*]] = vector.extract %[[A0]][1, 0, 1, 1] : vector<5x1x3x2xf32>
// CHECK: %[[R1:.*]] = vector.extract %[[A0]][1, 0, 2] : vector<5x1x3x2xf32>
// CHECK: %[[R2:.*]] = vector.extract %[[A1]][7] : vector<8x4x2xf32>
// CHECK: return %[[R0]], %[[R1]], %[[R2]] : f32, vector<2xf32>, vector<4x2xf32>
func @fold_extract_shapecast(%arg0 : vector<5x1x3x2xf32>,
                             %arg1 : vector<8x4x2xf32>)
    -> (f32, vector<2xf32>, vector<4x2xf32>) {
  // Scalar extract (%r1), 1-D sub-vector extract (%r2), and 2-D sub-vector
  // extract (%r3) — all three are expected to fold through the casts.
  %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xf32>
  %1 = vector.shape_cast %arg1 : vector<8x4x2xf32> to vector<4x2x4x2xf32>
  %r1 = vector.extract %0[4, 1] : vector<15x2xf32>
  %r2 = vector.extract %0[5] : vector<15x2xf32>
  %r3 = vector.extract %1[3, 1] : vector<4x2x4x2xf32>
  return %r1, %r2, %r3 : f32, vector<2xf32>, vector<4x2xf32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test: the extracted sub-vector shape (vector<4x2xf32>) does not
// match the trailing dimensions of the cast source (vector<16xf32>), so the
// fold must NOT fire — the shape_cast and extract are expected to remain.
// CHECK-LABEL: fold_extract_shapecast_negative
// CHECK: %[[V:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32>
// CHECK: %[[R:.*]] = vector.extract %[[V]][1] : vector<2x4x2xf32>
// CHECK: return %[[R]] : vector<4x2xf32>
func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>,
                                      %arg1 : vector<8x4x2xf32>) -> vector<4x2xf32> {
  %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<2x4x2xf32>
  %r = vector.extract %0[1] : vector<2x4x2xf32>
  return %r : vector<4x2xf32>
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: fold_vector_transfers
|
||||
|
|
Loading…
Reference in New Issue