forked from OSchip/llvm-project
[mlir][SCF] Fold tensor.cast feeding into scf.foreach_thread.parallel_insert_slice
Differential Revision: https://reviews.llvm.org/D128247
This commit is contained in:
parent
858be16670
commit
98dbaed1e6
|
@ -564,6 +564,7 @@ def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
|
||||||
];
|
];
|
||||||
|
|
||||||
let hasCanonicalizer = 1;
|
let hasCanonicalizer = 1;
|
||||||
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -1276,6 +1276,40 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
/// Fold a parallel_insert_slice source coming from a tensor.cast op.
|
||||||
|
///
|
||||||
|
/// Example:
|
||||||
|
/// ```
|
||||||
|
/// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
|
||||||
|
/// %1 = compute_some_tensor() : tensor<64xf32>
|
||||||
|
/// %2 = tensor.cast %1 : tensor<64xf32> to tensor<?xf32>
|
||||||
|
/// scf.foreach_thread.perform_concurrently {
|
||||||
|
/// scf.foreach_thread.parallel_insert_slice %2 into %out[...] [64] [1] :
|
||||||
|
/// tensor<?xf32> into tensor<128xf32>
|
||||||
|
/// }
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// is folded into:
|
||||||
|
/// ```
|
||||||
|
/// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
|
||||||
|
/// %1 = compute_some_tensor() : tensor<64xf32>
|
||||||
|
/// scf.foreach_thread.perform_concurrently {
|
||||||
|
/// scf.foreach_thread.parallel_insert_slice %1 into %out[...] [64] [1] :
|
||||||
|
/// tensor<64xf32> into tensor<128xf32>
|
||||||
|
/// }
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
LogicalResult
|
||||||
|
ParallelInsertSliceOp::fold(ArrayRef<Attribute> operands,
|
||||||
|
SmallVectorImpl<OpFoldResult> &results) {
|
||||||
|
auto sourceCast = getSource().getDefiningOp<tensor::CastOp>();
|
||||||
|
if (!sourceCast)
|
||||||
|
return failure();
|
||||||
|
getSourceMutable().assign(sourceCast.getSource());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
void ParallelInsertSliceOp::getCanonicalizationPatterns(
|
void ParallelInsertSliceOp::getCanonicalizationPatterns(
|
||||||
RewritePatternSet &results, MLIRContext *context) {
|
RewritePatternSet &results, MLIRContext *context) {
|
||||||
results.add<ParallelInsertSliceOpConstantArgumentFolder>(context);
|
results.add<ParallelInsertSliceOpConstantArgumentFolder>(context);
|
||||||
|
|
|
@ -26,9 +26,8 @@ func.func @reduce() -> tensor<128xf32> {
|
||||||
linalg.yield %14 : f32
|
linalg.yield %14 : f32
|
||||||
} -> tensor<?xf32>
|
} -> tensor<?xf32>
|
||||||
|
|
||||||
// TODO: canonicalize this cast away.
|
// CHECK-NOT: tensor.cast
|
||||||
// CHECK: %[[dyn_casted:.*]] = tensor.cast %{{.*}} : tensor<64xf32> to tensor<?xf32>
|
// CHECK: scf.foreach_thread.parallel_insert_slice %{{.*}} into %{{.*}}[%{{.*}}] [64] [1] : tensor<64xf32> into tensor<128xf32>
|
||||||
// CHECK: scf.foreach_thread.parallel_insert_slice %[[dyn_casted:.*]] into %{{.*}}[%{{.*}}] [64] [1] : tensor<?xf32> into tensor<128xf32>
|
|
||||||
scf.foreach_thread.perform_concurrently {
|
scf.foreach_thread.perform_concurrently {
|
||||||
scf.foreach_thread.parallel_insert_slice %13 into %3[%9] [%10] [1] : tensor<?xf32> into tensor<128xf32>
|
scf.foreach_thread.parallel_insert_slice %13 into %3[%9] [%10] [1] : tensor<?xf32> into tensor<128xf32>
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue