[mlir][SCF] Fold tensor.cast feeding into scf.foreach_thread.parallel_insert_slice

Differential Revision: https://reviews.llvm.org/D128247
Nicolas Vasilache 2022-06-21 00:41:31 -07:00
parent 858be16670
commit 98dbaed1e6
3 changed files with 37 additions and 3 deletions


@@ -564,6 +564,7 @@ def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
   ];
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 //===----------------------------------------------------------------------===//
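Setting hasFolder = 1 directs ODS to declare a fold hook on the op. Because parallel_insert_slice produces no results, the declaration takes the generic multi-result form, which the next hunk defines. A sketch of the generated declaration, for orientation only (not part of the diff):

  // Declared on ParallelInsertSliceOp when hasFolder = 1; ops with zero
  // (or multiple) results get the `results` vector form of the hook
  // rather than one returning a single OpFoldResult.
  LogicalResult fold(ArrayRef<Attribute> operands,
                     SmallVectorImpl<OpFoldResult> &results);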

View File

@@ -1276,6 +1276,40 @@ public:
 };
 } // namespace
+/// Fold a parallel_insert_slice source coming from a tensor.cast op.
+///
+/// Example:
+/// ```
+///   %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
+///     %1 = compute_some_tensor() : tensor<64xf32>
+///     %2 = tensor.cast %1 : tensor<64xf32> to tensor<?xf32>
+///     scf.foreach_thread.perform_concurrently {
+///       scf.foreach_thread.parallel_insert_slice %2 into %out[...] [64] [1] :
+///         tensor<?xf32> into tensor<128xf32>
+///     }
+///   }
+/// ```
+///
+/// is folded into:
+/// ```
+///   %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
+///     %1 = compute_some_tensor() : tensor<64xf32>
+///     scf.foreach_thread.perform_concurrently {
+///       scf.foreach_thread.parallel_insert_slice %1 into %out[...] [64] [1] :
+///         tensor<64xf32> into tensor<128xf32>
+///     }
+///   }
+/// ```
+LogicalResult
+ParallelInsertSliceOp::fold(ArrayRef<Attribute> operands,
+                            SmallVectorImpl<OpFoldResult> &results) {
+  auto sourceCast = getSource().getDefiningOp<tensor::CastOp>();
+  if (!sourceCast)
+    return failure();
+  getSourceMutable().assign(sourceCast.getSource());
+  return success();
+}
+
 void ParallelInsertSliceOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   results.add<ParallelInsertSliceOpConstantArgumentFolder>(context);
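No dedicated rewrite pattern is needed for the cast fold: op folders are picked up by the greedy rewrite driver, so running the canonicalizer is enough to exercise it. A minimal driver sketch, assuming a loaded ModuleOp with the relevant dialects registered (the function name is illustrative):

  #include "mlir/IR/BuiltinOps.h"
  #include "mlir/Pass/PassManager.h"
  #include "mlir/Transforms/Passes.h"

  using namespace mlir;

  // Canonicalization applies registered patterns and op folders,
  // including ParallelInsertSliceOp::fold defined above.
  LogicalResult runCanonicalizer(ModuleOp module) {
    PassManager pm(module.getContext());
    pm.addPass(createCanonicalizerPass());
    return pm.run(module);
  }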


@@ -26,9 +26,8 @@ func.func @reduce() -> tensor<128xf32> {
       linalg.yield %14 : f32
     } -> tensor<?xf32>
-    // TODO: canonicalize this cast away.
-    // CHECK: %[[dyn_casted:.*]] = tensor.cast %{{.*}} : tensor<64xf32> to tensor<?xf32>
-    // CHECK: scf.foreach_thread.parallel_insert_slice %[[dyn_casted:.*]] into %{{.*}}[%{{.*}}] [64] [1] : tensor<?xf32> into tensor<128xf32>
+    // CHECK-NOT: tensor.cast
+    // CHECK: scf.foreach_thread.parallel_insert_slice %{{.*}} into %{{.*}}[%{{.*}}] [64] [1] : tensor<64xf32> into tensor<128xf32>
     scf.foreach_thread.perform_concurrently {
       scf.foreach_thread.parallel_insert_slice %13 into %3[%9] [%10] [1] : tensor<?xf32> into tensor<128xf32>
     }