forked from OSchip/llvm-project
[mlir][vector] Add folding for ExtractOp with ShapeCastOp source
Differential Revision: https://reviews.llvm.org/D89853
This commit is contained in:
parent
4b90a253c2
commit
8c72eea9a0
|
@ -843,6 +843,61 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
|
||||||
return Value();
|
return Value();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fold extractOp with source coming from ShapeCast op.
|
||||||
|
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
|
||||||
|
auto shapeCastOp = extractOp.vector().getDefiningOp<vector::ShapeCastOp>();
|
||||||
|
if (!shapeCastOp)
|
||||||
|
return Value();
|
||||||
|
// Get the nth dimension size starting from lowest dimension.
|
||||||
|
auto getDimReverse = [](VectorType type, int64_t n) {
|
||||||
|
return type.getDimSize(type.getRank() - n - 1);
|
||||||
|
};
|
||||||
|
int64_t destinationRank =
|
||||||
|
extractOp.getVectorType().getRank() - extractOp.position().size();
|
||||||
|
if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
|
||||||
|
return Value();
|
||||||
|
if (destinationRank > 0) {
|
||||||
|
auto destinationType = extractOp.getResult().getType().cast<VectorType>();
|
||||||
|
for (int64_t i = 0; i < destinationRank; i++) {
|
||||||
|
// The lowest dimension of of the destination must match the lowest
|
||||||
|
// dimension of the shapecast op source.
|
||||||
|
if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
|
||||||
|
getDimReverse(destinationType, i))
|
||||||
|
return Value();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Extract the strides associated with the extract op vector source. Then use
|
||||||
|
// this to calculate a linearized position for the extract.
|
||||||
|
auto extractedPos = extractVector<int64_t>(extractOp.position());
|
||||||
|
std::reverse(extractedPos.begin(), extractedPos.end());
|
||||||
|
SmallVector<int64_t, 4> strides;
|
||||||
|
int64_t stride = 1;
|
||||||
|
for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
|
||||||
|
strides.push_back(stride);
|
||||||
|
stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t position = linearize(extractedPos, strides);
|
||||||
|
// Then extract the strides assoociated to the shapeCast op vector source and
|
||||||
|
// delinearize the position using those strides.
|
||||||
|
SmallVector<int64_t, 4> newStrides;
|
||||||
|
int64_t numDimension =
|
||||||
|
shapeCastOp.getSourceVectorType().getRank() - destinationRank;
|
||||||
|
stride = 1;
|
||||||
|
for (int64_t i = 0; i < numDimension; i++) {
|
||||||
|
newStrides.push_back(stride);
|
||||||
|
stride *=
|
||||||
|
getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
|
||||||
|
}
|
||||||
|
std::reverse(newStrides.begin(), newStrides.end());
|
||||||
|
SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
|
||||||
|
OpBuilder b(extractOp.getContext());
|
||||||
|
extractOp.setAttr(ExtractOp::getPositionAttrName(),
|
||||||
|
b.getI64ArrayAttr(newPosition));
|
||||||
|
extractOp.setOperand(shapeCastOp.source());
|
||||||
|
return extractOp.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
|
OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
|
||||||
if (succeeded(foldExtractOpFromExtractChain(*this)))
|
if (succeeded(foldExtractOpFromExtractChain(*this)))
|
||||||
return getResult();
|
return getResult();
|
||||||
|
@ -852,6 +907,8 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
|
||||||
return val;
|
return val;
|
||||||
if (auto val = foldExtractFromBroadcast(*this))
|
if (auto val = foldExtractFromBroadcast(*this))
|
||||||
return val;
|
return val;
|
||||||
|
if (auto val = foldExtractFromShapeCast(*this))
|
||||||
|
return val;
|
||||||
return OpFoldResult();
|
return OpFoldResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -394,6 +394,39 @@ func @fold_extract_broadcast_negative(%a : f32) -> vector<4xf32> {
|
||||||
return %r : vector<4xf32>
|
return %r : vector<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @fold_extract_shapecast
|
||||||
|
// CHECK-SAME: (%[[A0:.*]]: vector<5x1x3x2xf32>, %[[A1:.*]]: vector<8x4x2xf32>
|
||||||
|
// CHECK: %[[R0:.*]] = vector.extract %[[A0]][1, 0, 1, 1] : vector<5x1x3x2xf32>
|
||||||
|
// CHECK: %[[R1:.*]] = vector.extract %[[A0]][1, 0, 2] : vector<5x1x3x2xf32>
|
||||||
|
// CHECK: %[[R2:.*]] = vector.extract %[[A1]][7] : vector<8x4x2xf32>
|
||||||
|
// CHECK: return %[[R0]], %[[R1]], %[[R2]] : f32, vector<2xf32>, vector<4x2xf32>
|
||||||
|
func @fold_extract_shapecast(%arg0 : vector<5x1x3x2xf32>,
|
||||||
|
%arg1 : vector<8x4x2xf32>)
|
||||||
|
-> (f32, vector<2xf32>, vector<4x2xf32>) {
|
||||||
|
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xf32>
|
||||||
|
%1 = vector.shape_cast %arg1 : vector<8x4x2xf32> to vector<4x2x4x2xf32>
|
||||||
|
%r1 = vector.extract %0[4, 1] : vector<15x2xf32>
|
||||||
|
%r2 = vector.extract %0[5] : vector<15x2xf32>
|
||||||
|
%r3 = vector.extract %1[3, 1] : vector<4x2x4x2xf32>
|
||||||
|
return %r1, %r2, %r3 : f32, vector<2xf32>, vector<4x2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_extract_shapecast_negative
|
||||||
|
// CHECK: %[[V:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32>
|
||||||
|
// CHECK: %[[R:.*]] = vector.extract %[[V]][1] : vector<2x4x2xf32>
|
||||||
|
// CHECK: return %[[R]] : vector<4x2xf32>
|
||||||
|
func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>,
|
||||||
|
%arg1 : vector<8x4x2xf32>) -> vector<4x2xf32> {
|
||||||
|
%0 = vector.shape_cast %arg0 : vector<16xf32> to vector<2x4x2xf32>
|
||||||
|
%r = vector.extract %0[1] : vector<2x4x2xf32>
|
||||||
|
return %r : vector<4x2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: fold_vector_transfers
|
// CHECK-LABEL: fold_vector_transfers
|
||||||
|
|
Loading…
Reference in New Issue