forked from OSchip/llvm-project
[MLIR][normalize-memrefs] Non-normalizable operations with identity map layouts do not block normalization of the entire function
The current approach is convervative in which whenever there is a non-normalizable operations in a function will the function be labelled as non-normalizable. It means it requires that all operations must have MemRefsNormalizable trait. This patch relaxes the requirement that if the memref map layouts of a non-normalizable operation are identity, this operation does not block the normalization of the other operations in the same function. Reviewed By: bondhugula Differential Revision: https://reviews.llvm.org/D125854
This commit is contained in:
parent
e941b031d3
commit
183c4a391e
|
@ -145,10 +145,10 @@ void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
|
|||
/// Check whether all the uses of AllocOps, CallOps and function arguments of a
|
||||
/// function are either of dereferencing type or are uses in: DeallocOp, CallOp
|
||||
/// or ReturnOp. Only if these constraints are satisfied will the function
|
||||
/// become a candidate for normalization. We follow a conservative approach here
|
||||
/// wherein even if the non-normalizable memref is not a part of the function's
|
||||
/// argument or return type, we still label the entire function as
|
||||
/// non-normalizable. We assume external functions to be normalizable.
|
||||
/// become a candidate for normalization. When the uses of a memref are
|
||||
/// non-normalizable and the memref map layout is trivial (identity), we can
|
||||
/// still label the entire function as normalizable. We assume external
|
||||
/// functions to be normalizable.
|
||||
bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
|
||||
// We assume external functions to be normalizable.
|
||||
if (funcOp.isExternal())
|
||||
|
@ -157,7 +157,11 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
|
|||
if (funcOp
|
||||
.walk([&](memref::AllocOp allocOp) -> WalkResult {
|
||||
Value oldMemRef = allocOp.getResult();
|
||||
if (!isMemRefNormalizable(oldMemRef.getUsers()))
|
||||
if (!oldMemRef.getType()
|
||||
.cast<MemRefType>()
|
||||
.getLayout()
|
||||
.isIdentity() &&
|
||||
!isMemRefNormalizable(oldMemRef.getUsers()))
|
||||
return WalkResult::interrupt();
|
||||
return WalkResult::advance();
|
||||
})
|
||||
|
@ -170,7 +174,11 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
|
|||
llvm::seq<unsigned>(0, callOp.getNumResults())) {
|
||||
Value oldMemRef = callOp.getResult(resIndex);
|
||||
if (oldMemRef.getType().isa<MemRefType>())
|
||||
if (!isMemRefNormalizable(oldMemRef.getUsers()))
|
||||
if (!oldMemRef.getType()
|
||||
.cast<MemRefType>()
|
||||
.getLayout()
|
||||
.isIdentity() &&
|
||||
!isMemRefNormalizable(oldMemRef.getUsers()))
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
return WalkResult::advance();
|
||||
|
@ -181,7 +189,8 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
|
|||
for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
|
||||
BlockArgument oldMemRef = funcOp.getArgument(argIndex);
|
||||
if (oldMemRef.getType().isa<MemRefType>())
|
||||
if (!isMemRefNormalizable(oldMemRef.getUsers()))
|
||||
if (!oldMemRef.getType().cast<MemRefType>().getLayout().isIdentity() &&
|
||||
!isMemRefNormalizable(oldMemRef.getUsers()))
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
@ -41,6 +41,24 @@ func.func @test_nonnorm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
|
|||
return
|
||||
}
|
||||
|
||||
// Test with op_nonnorm whose memref map layouts are identity. This op_nonnorm
|
||||
// does not block the normalization of other operations.
|
||||
|
||||
// CHECK-LABEL: test_nonnorm_identity_layout
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x1x1x32x64xf32>)
|
||||
func.func @test_nonnorm_identity_layout(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
|
||||
%0 = memref.alloc() : memref<1x16x14x14xf32>
|
||||
"test.op_nonnorm"(%0, %0) : (memref<1x16x14x14xf32>, memref<1x16x14x14xf32>) -> ()
|
||||
"test.op_norm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32>) -> ()
|
||||
memref.dealloc %0 : memref<1x16x14x14xf32>
|
||||
|
||||
// CHECK: %[[v0:.*]] = memref.alloc() : memref<1x16x14x14xf32>
|
||||
// CHECK: "test.op_nonnorm"(%[[v0]], %[[v0]]) : (memref<1x16x14x14xf32>, memref<1x16x14x14xf32>) -> ()
|
||||
// CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x14x14xf32>) -> ()
|
||||
// CHECK: memref.dealloc %[[v0]] : memref<1x16x14x14xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// Test with op_norm, with maps in the operations in the function.
|
||||
|
||||
// CHECK-LABEL: test_norm_mix
|
||||
|
|
Loading…
Reference in New Issue