diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp index b71102cde1cf..d1deb5abd541 100644 --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -843,6 +843,61 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) { return Value(); } +// Fold extractOp with source coming from ShapeCast op. +static Value foldExtractFromShapeCast(ExtractOp extractOp) { + auto shapeCastOp = extractOp.vector().getDefiningOp(); + if (!shapeCastOp) + return Value(); + // Get the nth dimension size starting from lowest dimension. + auto getDimReverse = [](VectorType type, int64_t n) { + return type.getDimSize(type.getRank() - n - 1); + }; + int64_t destinationRank = + extractOp.getVectorType().getRank() - extractOp.position().size(); + if (destinationRank > shapeCastOp.getSourceVectorType().getRank()) + return Value(); + if (destinationRank > 0) { + auto destinationType = extractOp.getResult().getType().cast(); + for (int64_t i = 0; i < destinationRank; i++) { + // The lowest dimension of of the destination must match the lowest + // dimension of the shapecast op source. + if (getDimReverse(shapeCastOp.getSourceVectorType(), i) != + getDimReverse(destinationType, i)) + return Value(); + } + } + // Extract the strides associated with the extract op vector source. Then use + // this to calculate a linearized position for the extract. + auto extractedPos = extractVector(extractOp.position()); + std::reverse(extractedPos.begin(), extractedPos.end()); + SmallVector strides; + int64_t stride = 1; + for (int64_t i = 0, e = extractedPos.size(); i < e; i++) { + strides.push_back(stride); + stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank); + } + + int64_t position = linearize(extractedPos, strides); + // Then extract the strides assoociated to the shapeCast op vector source and + // delinearize the position using those strides. + SmallVector newStrides; + int64_t numDimension = + shapeCastOp.getSourceVectorType().getRank() - destinationRank; + stride = 1; + for (int64_t i = 0; i < numDimension; i++) { + newStrides.push_back(stride); + stride *= + getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank); + } + std::reverse(newStrides.begin(), newStrides.end()); + SmallVector newPosition = delinearize(newStrides, position); + OpBuilder b(extractOp.getContext()); + extractOp.setAttr(ExtractOp::getPositionAttrName(), + b.getI64ArrayAttr(newPosition)); + extractOp.setOperand(shapeCastOp.source()); + return extractOp.getResult(); +} + OpFoldResult ExtractOp::fold(ArrayRef) { if (succeeded(foldExtractOpFromExtractChain(*this))) return getResult(); @@ -852,6 +907,8 @@ OpFoldResult ExtractOp::fold(ArrayRef) { return val; if (auto val = foldExtractFromBroadcast(*this)) return val; + if (auto val = foldExtractFromShapeCast(*this)) + return val; return OpFoldResult(); } diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 2f927a1bbc81..66bad06e6b60 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -394,6 +394,39 @@ func @fold_extract_broadcast_negative(%a : f32) -> vector<4xf32> { return %r : vector<4xf32> } +// ----- + +// CHECK-LABEL: func @fold_extract_shapecast +// CHECK-SAME: (%[[A0:.*]]: vector<5x1x3x2xf32>, %[[A1:.*]]: vector<8x4x2xf32> +// CHECK: %[[R0:.*]] = vector.extract %[[A0]][1, 0, 1, 1] : vector<5x1x3x2xf32> +// CHECK: %[[R1:.*]] = vector.extract %[[A0]][1, 0, 2] : vector<5x1x3x2xf32> +// CHECK: %[[R2:.*]] = vector.extract %[[A1]][7] : vector<8x4x2xf32> +// CHECK: return %[[R0]], %[[R1]], %[[R2]] : f32, vector<2xf32>, vector<4x2xf32> +func @fold_extract_shapecast(%arg0 : vector<5x1x3x2xf32>, + %arg1 : vector<8x4x2xf32>) + -> (f32, vector<2xf32>, vector<4x2xf32>) { + %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xf32> + %1 = vector.shape_cast %arg1 : vector<8x4x2xf32> to vector<4x2x4x2xf32> + %r1 = vector.extract %0[4, 1] : vector<15x2xf32> + %r2 = vector.extract %0[5] : vector<15x2xf32> + %r3 = vector.extract %1[3, 1] : vector<4x2x4x2xf32> + return %r1, %r2, %r3 : f32, vector<2xf32>, vector<4x2xf32> +} + +// ----- + +// CHECK-LABEL: fold_extract_shapecast_negative +// CHECK: %[[V:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32> +// CHECK: %[[R:.*]] = vector.extract %[[V]][1] : vector<2x4x2xf32> +// CHECK: return %[[R]] : vector<4x2xf32> +func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>, + %arg1 : vector<8x4x2xf32>) -> vector<4x2xf32> { + %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<2x4x2xf32> + %r = vector.extract %0[1] : vector<2x4x2xf32> + return %r : vector<4x2xf32> +} + + // ----- // CHECK-LABEL: fold_vector_transfers