Enable ReassociatingReshapeOpConversion with "non-identity" layouts.

Enable ReassociatingReshapeOpConversion with "non-identity" layouts.

This removes an early-return in this function, which seems unnecessary and is
preventing some memref.collapse_shape from converting to LLVM (see included lit test).

It seems unnecessary because the return message says "only empty layout map is supported"
but there actually is code in this function to deal with non-empty layout maps. Maybe
it refers to an earlier state of implementation and is just out of date?

Though, there is another concern about this early return: the condition that it actually
checks, `{src,dst}MemrefType.getLayout().isIdentity()`, is not quite the same as what the
return message says, "only empty layout map is supported". Stepping through this
`getLayout().isIdentity()` code in GDB, I found that it evaluates to `.getAffineMap().isIdentity()`
which does (AffineMap.cpp:271):

```
  if (getNumDims() != getNumResults())
    return false;
```

This seems that it would always return false for memrefs of rank greater than 1 ?

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D114808
This commit is contained in:
Benoit Jacob 2022-01-13 17:39:06 +00:00 committed by Lei Zhang
parent fced2744d3
commit 499703e9c0
2 changed files with 17 additions and 4 deletions

View File

@ -1168,10 +1168,14 @@ public:
ConversionPatternRewriter &rewriter) const override {
MemRefType dstType = reshapeOp.getResultType();
MemRefType srcType = reshapeOp.getSrcType();
if (!srcType.getLayout().isIdentity() ||
!dstType.getLayout().isIdentity()) {
return rewriter.notifyMatchFailure(reshapeOp,
"only empty layout map is supported");
// The condition on the layouts can be ignored when all shapes are static.
if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) {
if (!srcType.getLayout().isIdentity() ||
!dstType.getLayout().isIdentity()) {
return rewriter.notifyMatchFailure(
reshapeOp, "only empty layout map is supported");
}
}
int64_t offset;

View File

@ -883,3 +883,12 @@ func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fval :
// CHECK: llvm.atomicrmw _and %{{.*}}, %{{.*}} acq_rel
return
}
// -----
// CHECK-LABEL: func @collapse_static_shape_with_non_identity_layout
func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 64 + s0 + d1 * 64 + d2 * 8 + d3)>>) -> memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>> {
// CHECK-NOT: memref.collapse_shape
%1 = memref.collapse_shape %arg [[0, 1, 2, 3]] : memref<1x1x8x8xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 64 + s0 + d1 * 64 + d2 * 8 + d3)>> into memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
return %1 : memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
}