[mlir][vector] Add folding for ExtractOp with ShapeCastOp source

Differential Revision: https://reviews.llvm.org/D89853
This commit is contained in:
Thomas Raoux 2020-10-23 11:53:38 -07:00
parent 4b90a253c2
commit 8c72eea9a0
2 changed files with 90 additions and 0 deletions

View File

@ -843,6 +843,61 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
return Value(); return Value();
} }
// Fold extractOp with source coming from ShapeCast op.
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
auto shapeCastOp = extractOp.vector().getDefiningOp<vector::ShapeCastOp>();
if (!shapeCastOp)
return Value();
// Get the nth dimension size starting from lowest dimension.
auto getDimReverse = [](VectorType type, int64_t n) {
return type.getDimSize(type.getRank() - n - 1);
};
int64_t destinationRank =
extractOp.getVectorType().getRank() - extractOp.position().size();
if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
return Value();
if (destinationRank > 0) {
auto destinationType = extractOp.getResult().getType().cast<VectorType>();
for (int64_t i = 0; i < destinationRank; i++) {
// The lowest dimension of of the destination must match the lowest
// dimension of the shapecast op source.
if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
getDimReverse(destinationType, i))
return Value();
}
}
// Extract the strides associated with the extract op vector source. Then use
// this to calculate a linearized position for the extract.
auto extractedPos = extractVector<int64_t>(extractOp.position());
std::reverse(extractedPos.begin(), extractedPos.end());
SmallVector<int64_t, 4> strides;
int64_t stride = 1;
for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
strides.push_back(stride);
stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
}
int64_t position = linearize(extractedPos, strides);
// Then extract the strides assoociated to the shapeCast op vector source and
// delinearize the position using those strides.
SmallVector<int64_t, 4> newStrides;
int64_t numDimension =
shapeCastOp.getSourceVectorType().getRank() - destinationRank;
stride = 1;
for (int64_t i = 0; i < numDimension; i++) {
newStrides.push_back(stride);
stride *=
getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
}
std::reverse(newStrides.begin(), newStrides.end());
SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
OpBuilder b(extractOp.getContext());
extractOp.setAttr(ExtractOp::getPositionAttrName(),
b.getI64ArrayAttr(newPosition));
extractOp.setOperand(shapeCastOp.source());
return extractOp.getResult();
}
OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) { OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
if (succeeded(foldExtractOpFromExtractChain(*this))) if (succeeded(foldExtractOpFromExtractChain(*this)))
return getResult(); return getResult();
@ -852,6 +907,8 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
return val; return val;
if (auto val = foldExtractFromBroadcast(*this)) if (auto val = foldExtractFromBroadcast(*this))
return val; return val;
if (auto val = foldExtractFromShapeCast(*this))
return val;
return OpFoldResult(); return OpFoldResult();
} }

View File

@ -394,6 +394,39 @@ func @fold_extract_broadcast_negative(%a : f32) -> vector<4xf32> {
return %r : vector<4xf32> return %r : vector<4xf32>
} }
// -----
// CHECK-LABEL: func @fold_extract_shapecast
// CHECK-SAME: (%[[A0:.*]]: vector<5x1x3x2xf32>, %[[A1:.*]]: vector<8x4x2xf32>
// CHECK: %[[R0:.*]] = vector.extract %[[A0]][1, 0, 1, 1] : vector<5x1x3x2xf32>
// CHECK: %[[R1:.*]] = vector.extract %[[A0]][1, 0, 2] : vector<5x1x3x2xf32>
// CHECK: %[[R2:.*]] = vector.extract %[[A1]][7] : vector<8x4x2xf32>
// CHECK: return %[[R0]], %[[R1]], %[[R2]] : f32, vector<2xf32>, vector<4x2xf32>
func @fold_extract_shapecast(%arg0 : vector<5x1x3x2xf32>,
%arg1 : vector<8x4x2xf32>)
-> (f32, vector<2xf32>, vector<4x2xf32>) {
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xf32>
%1 = vector.shape_cast %arg1 : vector<8x4x2xf32> to vector<4x2x4x2xf32>
%r1 = vector.extract %0[4, 1] : vector<15x2xf32>
%r2 = vector.extract %0[5] : vector<15x2xf32>
%r3 = vector.extract %1[3, 1] : vector<4x2x4x2xf32>
return %r1, %r2, %r3 : f32, vector<2xf32>, vector<4x2xf32>
}
// -----
// CHECK-LABEL: fold_extract_shapecast_negative
// CHECK: %[[V:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32>
// CHECK: %[[R:.*]] = vector.extract %[[V]][1] : vector<2x4x2xf32>
// CHECK: return %[[R]] : vector<4x2xf32>
func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>,
%arg1 : vector<8x4x2xf32>) -> vector<4x2xf32> {
%0 = vector.shape_cast %arg0 : vector<16xf32> to vector<2x4x2xf32>
%r = vector.extract %0[1] : vector<2x4x2xf32>
return %r : vector<4x2xf32>
}
// ----- // -----
// CHECK-LABEL: fold_vector_transfers // CHECK-LABEL: fold_vector_transfers