[mlir][vector] Address post-commit review comments on vector ops folding patterns

Differential Revision: https://reviews.llvm.org/D90183
This commit is contained in:
Thomas Raoux 2020-11-02 10:18:38 -08:00
parent b85f2f5c5f
commit 9081e7594d
2 changed files with 15 additions and 11 deletions

View File

@ -850,10 +850,12 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
return Value();
// Get the nth dimension size starting from lowest dimension.
auto getDimReverse = [](VectorType type, int64_t n) {
return type.getDimSize(type.getRank() - n - 1);
return type.getShape().take_back(n+1).front();
};
int64_t destinationRank =
extractOp.getVectorType().getRank() - extractOp.position().size();
extractOp.getType().isa<VectorType>()
? extractOp.getType().cast<VectorType>().getRank()
: 0;
if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
return Value();
if (destinationRank > 0) {
@ -861,6 +863,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
for (int64_t i = 0; i < destinationRank; i++) {
// The lowest dimension of of the destination must match the lowest
// dimension of the shapecast op source.
// TODO: This case could be support in a canonicalization pattern.
if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
getDimReverse(destinationType, i))
return Value();
@ -891,6 +894,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
}
std::reverse(newStrides.begin(), newStrides.end());
SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
extractOp.setAttr(ExtractOp::getPositionAttrName(),
b.getI64ArrayAttr(newPosition));
@ -1632,8 +1636,8 @@ static LogicalResult verify(ExtractStridedSliceOp op) {
}
// When the source of ExtractStrided comes from a chain of InsertStrided ops try
// to use the source o the InsertStrided ops if we can detect that the extracted
// vector is a subset of one of the vector inserted.
// to use the source of the InsertStrided ops if we can detect that the
// extracted vector is a subset of one of the vector inserted.
static LogicalResult
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
// Helper to extract integer out of ArrayAttr.

View File

@ -160,20 +160,20 @@ func @extract_strided_fold_negative(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
// Case where we need to go through 2 level of insert element.
// CHECK-LABEL: extract_strided_fold_insert
// CHECK-SAME: (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<1x4xf32>,
// CHECK-SAME: (%[[ARG0:.*]]: vector<2x8xf32>, %[[ARG1:.*]]: vector<1x4xf32>,
// CHECK-NEXT: %[[EXT:.*]] = vector.extract_strided_slice %[[ARG1]]
// CHECK-SAME: {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]}
// CHECK-SAME: {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]}
// CHECK-SAME: : vector<1x4xf32> to vector<1x1xf32>
// CHECK-NEXT: return %[[EXT]] : vector<1x1xf32>
func @extract_strided_fold_insert(%a: vector<2x4xf32>, %b: vector<1x4xf32>,
func @extract_strided_fold_insert(%a: vector<2x8xf32>, %b: vector<1x4xf32>,
%c : vector<1x4xf32>) -> (vector<1x1xf32>) {
%0 = vector.insert_strided_slice %b, %a {offsets = [0, 0], strides = [1, 1]}
: vector<1x4xf32> into vector<2x4xf32>
%0 = vector.insert_strided_slice %b, %a {offsets = [0, 1], strides = [1, 1]}
: vector<1x4xf32> into vector<2x8xf32>
%1 = vector.insert_strided_slice %c, %0 {offsets = [1, 0], strides = [1, 1]}
: vector<1x4xf32> into vector<2x4xf32>
: vector<1x4xf32> into vector<2x8xf32>
%2 = vector.extract_strided_slice %1
{offsets = [0, 1], sizes = [1, 1], strides = [1, 1]}
: vector<2x4xf32> to vector<1x1xf32>
: vector<2x8xf32> to vector<1x1xf32>
return %2 : vector<1x1xf32>
}