diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 819b8382432d..955ac113c03c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -650,6 +650,27 @@ mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op, // Misc. vectorization patterns. //----------------------------------------------------------------------------// +/// Helper function that retrieves the value of an IntegerAttr. +static int64_t getIntFromAttr(Attribute attr) { + return attr.cast().getInt(); +} + +/// Given an ArrayRef of OpFoldResults, return a vector of Values. IntegerAttrs +/// are converted to ConstantIndexOps. Other attribute types are not supported. +static SmallVector ofrToIndexValues(OpBuilder &builder, Location loc, + ArrayRef ofrs) { + SmallVector result; + llvm::for_each(ofrs, [&](auto o) { + if (auto val = o.template dyn_cast()) { + result.push_back(val); + } else { + result.push_back(builder.create( + loc, getIntFromAttr(o.template get()))); + } + }); + return result; +} + /// Rewrite a PadTensorOp into a sequence of InitTensorOp, TransferReadOp and /// TransferWriteOp. For now, this only applies when all low and high paddings /// are determined to be zero. @@ -763,12 +784,87 @@ struct PadTensorOpVectorizationWithTransferReadPattern } }; +/// Rewrite use of PadTensorOp result in SubtensorInsertOp. E.g.: +/// ``` +/// %0 = linalg.pad_tensor %src ... : tensor to tensor<17x5xf32> +/// %r = subtensor_insert %0 into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1] +/// : tensor<17x5xf32> into tensor +/// ``` +/// is rewritten to: +/// ``` +/// %0 = vector.transfer_read %src[%c0, %c0], %padding +/// : tensor, vector<17x5xf32> +/// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0] +/// {in_bounds = [true, true]} : vector<17x5xf32>, tensor +/// ``` +/// +/// This rewrite is possible if: +/// - Low padding is static 0. +/// - `padOp` result shape is static. +/// - The entire padded tensor is inserted. +/// (Implies that sizes of `insertOp` are all static.) +/// - Only unit strides in `insertOp`. +/// - Single, scalar padding value. +struct PadTensorOpVectorizationWithSubTensorInsertPattern + : public VectorizePadTensorOpUserPattern { + using VectorizePadTensorOpUserPattern< + SubTensorInsertOp>::VectorizePadTensorOpUserPattern; + + LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp, + SubTensorInsertOp insertOp) const override { + // Low padding must be static 0. + if (!padOp.hasZeroLowPad()) return failure(); + // Only unit stride supported. + if (!insertOp.hasUnitStride()) return failure(); + // Pad value must be a constant. + auto padValue = padOp.getConstantPaddingValue(); + if (!padValue) + return failure(); + // Dynamic shapes not supported. + if (!padOp.result().getType().cast().hasStaticShape()) + return failure(); + + auto vecType = VectorType::get(padOp.getType().getShape(), + padOp.getType().getElementType()); + unsigned vecRank = vecType.getRank(); + unsigned tensorRank = insertOp.getType().getRank(); + + // Check if sizes match: Insert the entire tensor into most minor dims. + // (No permutations allowed.) + SmallVector expectedSizes(tensorRank - vecRank, 1); + expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end()); + if (!llvm::all_of( + llvm::zip(insertOp.getMixedSizes(), expectedSizes), + [](auto it) { return isEqualConstantInt(std::get<0>(it), + std::get<1>(it)); })) + return failure(); + + // Generate TransferReadOp: Read entire source tensor and add high padding. + SmallVector readIndices( + vecRank, rewriter.create(padOp.getLoc(), 0)); + auto read = rewriter.create( + padOp.getLoc(), vecType, padOp.source(), readIndices, padValue); + + // Generate TransferWriteOp: Write to SubTensorInsertOp's dest tensor at + // specified offsets. Write is fully in-bounds because a SubTensorInsertOp's + // source must fit into the destination at the specified offsets. + auto writeIndices = + ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets()); + SmallVector inBounds(vecRank, true); + rewriter.replaceOpWithNewOp( + insertOp, read, insertOp.dest(), writeIndices, inBounds); + + return success(); + } +}; + void mlir::linalg::populatePadTensorOpVectorizationPatterns( RewritePatternSet &patterns, PatternBenefit baseBenefit) { patterns.add( patterns.getContext(), baseBenefit); // Try these specialized patterns first before resorting to the generic one. - patterns.add( + patterns.add( patterns.getContext(), baseBenefit.getBenefit() + 1); } diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir index ab55f879f8ae..04c3d8435884 100644 --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -580,6 +580,28 @@ func @pad_and_transfer_read(%arg0: tensor<5x6xf32>) -> vector<7x9xf32> { // ----- +// CHECK-LABEL: func @pad_and_subtensor_insert +// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>, %[[ARG1:.*]]: tensor<12x13xf32> +// CHECK-NOT: linalg.pad_tensor +// CHECK-DAG: %[[C0:.*]] = constant 0 : index +// CHECK-DAG: %[[C5:.*]] = constant 5.0 +// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32> +// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<7x9xf32>, tensor<12x13xf32> +// CHECK: return %[[WRITE]] +func @pad_and_subtensor_insert( + %arg0: tensor<5x6xf32>, %arg1: tensor<12x13xf32>) -> tensor<12x13xf32> { + %c0 = constant 0 : index + %c5 = constant 5.0 : f32 + %0 = linalg.pad_tensor %arg0 low[0, 0] high[2, 3] { + ^bb0(%arg2: index, %arg3: index): + linalg.yield %c5 : f32 + } : tensor<5x6xf32> to tensor<7x9xf32> + %r = subtensor_insert %0 into %arg1[0, 0][7, 9][1, 1] : tensor<7x9xf32> into tensor<12x13xf32> + return %r : tensor<12x13xf32> +} + +// ----- + // CHECK-DAG: #[[$M0:.*]] = affine_map<(d0, d1) -> (d0, d1, 0)> // CHECK-LABEL: func @sum_exp