[mlir][SCF] Fold tensor.cast feeding into scf.foreach_thread.parallel_insert_slice
Differential Revision: https://reviews.llvm.org/D128247
commit 98dbaed1e6
parent 858be16670
@@ -564,6 +564,7 @@ def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
   ];
 
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
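Note on this hunk: `let hasFolder = 1` asks ODS to declare a fold hook on the generated op class; the C++ in the next hunk provides the definition. A minimal sketch of what the generated declaration looks like (an assumption about ODS output, which varies across MLIR versions; the signature is the generic form used for ops without exactly one result, matching the implementation below):

    // Sketch of the ODS-generated op class (traits and other members elided).
    class ParallelInsertSliceOp
        : public ::mlir::Op<ParallelInsertSliceOp /* , traits... */> {
    public:
      // `operands` holds constant operand values (null for non-constant
      // operands). Returning success with `results` left empty means the
      // op was folded in place.
      ::mlir::LogicalResult
      fold(::llvm::ArrayRef<::mlir::Attribute> operands,
           ::llvm::SmallVectorImpl<::mlir::OpFoldResult> &results);
    };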
@@ -1276,6 +1276,40 @@ public:
 };
 } // namespace
 
+/// Fold a parallel_insert_slice source coming from a tensor.cast op.
+///
+/// Example:
+/// ```
+/// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
+///   %1 = compute_some_tensor() : tensor<64xf32>
+///   %2 = tensor.cast %1 : tensor<64xf32> to tensor<?xf32>
+///   scf.foreach_thread.perform_concurrently {
+///     scf.foreach_thread.parallel_insert_slice %2 into %out[...] [64] [1] :
+///       tensor<?xf32> into tensor<128xf32>
+///   }
+/// }
+/// ```
+///
+/// is folded into:
+/// ```
+/// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
+///   %1 = compute_some_tensor() : tensor<64xf32>
+///   scf.foreach_thread.perform_concurrently {
+///     scf.foreach_thread.parallel_insert_slice %1 into %out[...] [64] [1] :
+///       tensor<64xf32> into tensor<128xf32>
+///   }
+/// }
+/// ```
+LogicalResult
+ParallelInsertSliceOp::fold(ArrayRef<Attribute> operands,
+                            SmallVectorImpl<OpFoldResult> &results) {
+  auto sourceCast = getSource().getDefiningOp<tensor::CastOp>();
+  if (!sourceCast)
+    return failure();
+  getSourceMutable().assign(sourceCast.getSource());
+  return success();
+}
+
 void ParallelInsertSliceOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   results.add<ParallelInsertSliceOpConstantArgumentFolder>(context);
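Note on the fold just added: it returns success() while leaving `results` empty, which is the fold-hook convention for "the op was updated in place" (here, the source operand is rewired past the tensor.cast); the now-bypassed cast then becomes dead. A minimal sketch of exercising the hook directly (hypothetical helper, not part of this commit; the include path is an assumption about the SCF dialect layout at the time):

    #include "mlir/Dialect/SCF/SCF.h" // assumed header location
    
    using namespace mlir;
    
    // Hypothetical helper: run the fold hook on one op and report whether
    // it rewired the source operand in place.
    static bool foldCastIntoParallelInsert(scf::ParallelInsertSliceOp op) {
      SmallVector<OpFoldResult> foldResults;
      // This fold inspects no constant operands, so an empty list suffices.
      if (failed(op.fold(/*operands=*/{}, foldResults)))
        return false; // no tensor.cast feeds the source; nothing folded
      // success() with no fold results signals an in-place update.
      return foldResults.empty();
    }

In practice the hook runs as part of canonicalization (e.g. via mlir-opt -canonicalize), which is presumably how the test updated below exercises it.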
@@ -26,9 +26,8 @@ func.func @reduce() -> tensor<128xf32> {
       linalg.yield %14 : f32
     } -> tensor<?xf32>
 
-    // TODO: canonicalize this cast away.
-    // CHECK: %[[dyn_casted:.*]] = tensor.cast %{{.*}} : tensor<64xf32> to tensor<?xf32>
-    // CHECK: scf.foreach_thread.parallel_insert_slice %[[dyn_casted:.*]] into %{{.*}}[%{{.*}}] [64] [1] : tensor<?xf32> into tensor<128xf32>
+    // CHECK-NOT: tensor.cast
+    // CHECK: scf.foreach_thread.parallel_insert_slice %{{.*}} into %{{.*}}[%{{.*}}] [64] [1] : tensor<64xf32> into tensor<128xf32>
     scf.foreach_thread.perform_concurrently {
       scf.foreach_thread.parallel_insert_slice %13 into %3[%9] [%10] [1] : tensor<?xf32> into tensor<128xf32>
     }