[MLIR] Updates around MemRef Normalization

The documentation for the NormalizeMemRefs pass and the associated MemRefsNormalizable
trait was confusing and not on the website.  This update clarifies the language
around the difference between a MemRef type and an operation that accesses a value of
MemRef type, and better documents the limitations of the current implementation.
This patch also includes some basic debugging information for the pass so people
might have a chance of figuring out why it doesn't work on their code.

Differential Revision: https://reviews.llvm.org/D88532
Stephen Neuendorffer 2020-09-29 17:14:42 -07:00
parent b8ac19cf1c
commit 47df8c57e4
4 changed files with 130 additions and 42 deletions


@@ -251,13 +251,15 @@ to have [passes](PassManagement.md) scheduled under them.
* `OpTrait::MemRefsNormalizable` -- `MemRefsNormalizable`
This trait is used to flag operations that can accommodate `MemRefs` with
non-identity memory-layout specifications. This trait indicates that the
normalization of memory layout can be performed for such operations.
`MemRefs` normalization consists of replacing an original memory reference
with layout specifications to an equivalent memory reference where
the specified memory layout is applied by rewriting accesses and types
associated with that memory reference.
This trait is used to flag operations that consume or produce
values of `MemRef` type where those references can be 'normalized'.
In cases where an associated `MemRef` has a
non-identity memory-layout specification, such normalizable operations can be
modified so that the `MemRef` has an identity layout specification.
This can be implemented by associating the operation with its own
index expression that can express the equivalent of the memory-layout
specification of the MemRef type. See
[the -normalize-memrefs pass](https://mlir.llvm.org/docs/Passes/#-normalize-memrefs-normalize-memrefs).
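
As a rough illustration (not taken from this patch, and with a made-up op
name), a dialect operation opts into this behavior simply by listing the
trait when its op class is defined; ops defined declaratively can likewise
add `MemRefsNormalizable` to their ODS trait list:

```c++
// Hypothetical op: only mlir::OpTrait::MemRefsNormalizable is real here; the
// op name and the other trait are placeholders for illustration.
#include "mlir/IR/OpDefinition.h"

class ToyCopyOp
    : public mlir::Op<ToyCopyOp, mlir::OpTrait::ZeroResult,
                      mlir::OpTrait::MemRefsNormalizable> {
public:
  using Op::Op;
  static llvm::StringRef getOperationName() { return "toy.copy"; }
};
```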
### Single Block with Implicit Terminator


@@ -1212,13 +1212,8 @@ struct NoRegionArguments : public TraitBase<ConcrentType, NoRegionArguments> {
}
};
/// This trait is used to flag operations that can accommodate MemRefs with
/// non-identity memory-layout specifications. This trait indicates that the
/// normalization of memory layout can be performed for such operations.
/// MemRefs normalization consists of replacing an original memory reference
/// with layout specifications to an equivalent memory reference where the
/// specified memory layout is applied by rewriting accesses and types
/// associated with that memory reference.
// This trait is used to flag operations that consume or produce
// values of `MemRef` type where those references can be 'normalized'.
// TODO: Right now, the operands of an operation are either all normalizable,
// or not. In the future, we may want to allow some of the operands to be
// normalizable.


@@ -313,6 +313,116 @@ def MemRefDataFlowOpt : FunctionPass<"memref-dataflow-opt"> {
def NormalizeMemRefs : Pass<"normalize-memrefs", "ModuleOp"> {
let summary = "Normalize memrefs";
let description = [{
This pass transforms memref types with a non-trivial
[layout map](https://mlir.llvm.org/docs/LangRef/#layout-map) into
memref types with an identity layout map, e.g. (i, j) -> (i, j). This
pass is inter-procedural, in the sense that it can modify function
interfaces and call sites that pass memref types. In order to modify
memref types while preserving the original behavior, users of those
memref types are also modified to incorporate the resulting layout map.
For instance, an
[AffineLoadOp](https://mlir.llvm.org/docs/Dialects/Affine/#affineload-affineloadop)
will be updated to compose the layout map with the affine expression
contained in the op. Operations marked with the
[MemRefsNormalizable](https://mlir.llvm.org/docs/Traits/#memrefsnormalizable)
trait are expected to be normalizable. Supported operations include affine
operations, std.alloc, std.dealloc, and std.return.
Given an appropriate layout map specified in the code, this transformation
can express tiled or linearized access to multi-dimensional data
structures, but will not modify memref types without an explicit layout
map.
Currently this pass is limited to modifying only those functions whose memref
types can all be normalized. If a function contains any operations that are
not MemRefsNormalizable, then neither that function nor any function that
calls it or is called by it will be modified.
Input
```mlir
#tile = affine_map<(i) -> (i floordiv 4, i mod 4)>

func @matmul(%A: memref<16xf64, #tile>,
             %B: index, %C: memref<16xf64>) -> (memref<16xf64, #tile>) {
  affine.for %arg3 = 0 to 16 {
    %a = affine.load %A[%arg3] : memref<16xf64, #tile>
    %p = mulf %a, %a : f64
    affine.store %p, %A[%arg3] : memref<16xf64, #tile>
  }
  %c = alloc() : memref<16xf64, #tile>
  %d = affine.load %c[0] : memref<16xf64, #tile>
  return %A : memref<16xf64, #tile>
}
```
Output
```mlir
func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>)
    -> memref<4x4xf64> {
  affine.for %arg3 = 0 to 16 {
    %3 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64>
    %4 = mulf %3, %3 : f64
    affine.store %4, %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64>
  }
  %0 = alloc() : memref<4x4xf64>
  %1 = affine.apply #map1()
  %2 = affine.load %0[0, 0] : memref<4x4xf64>
  return %arg0 : memref<4x4xf64>
}
```
Input
```mlir
#linear8 = affine_map<(i, j) -> (i * 8 + j)>

func @linearize(%arg0: memref<8x8xi32, #linear8>,
                %arg1: memref<8x8xi32, #linear8>,
                %arg2: memref<8x8xi32, #linear8>) {
  %c8 = constant 8 : index
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  affine.for %arg3 = %c0 to %c8 {
    affine.for %arg4 = %c0 to %c8 {
      affine.for %arg5 = %c0 to %c8 {
        %0 = affine.load %arg0[%arg3, %arg5] : memref<8x8xi32, #linear8>
        %1 = affine.load %arg1[%arg5, %arg4] : memref<8x8xi32, #linear8>
        %2 = affine.load %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8>
        %3 = muli %0, %1 : i32
        %4 = addi %2, %3 : i32
        affine.store %4, %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8>
      }
    }
  }
  return
}
```
Output
```mlir
func @linearize(%arg0: memref<64xi32>,
                %arg1: memref<64xi32>,
                %arg2: memref<64xi32>) {
  %c8 = constant 8 : index
  %c0 = constant 0 : index
  affine.for %arg3 = %c0 to %c8 {
    affine.for %arg4 = %c0 to %c8 {
      affine.for %arg5 = %c0 to %c8 {
        %0 = affine.load %arg0[%arg3 * 8 + %arg5] : memref<64xi32>
        %1 = affine.load %arg1[%arg5 * 8 + %arg4] : memref<64xi32>
        %2 = affine.load %arg2[%arg3 * 8 + %arg4] : memref<64xi32>
        %3 = muli %0, %1 : i32
        %4 = addi %2, %3 : i32
        affine.store %4, %arg2[%arg3 * 8 + %arg4] : memref<64xi32>
      }
    }
  }
  return
}
```
}];
let constructor = "mlir::createNormalizeMemRefsPass()";
}
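
For completeness, a minimal sketch (not part of this change) of scheduling the
pass from C++ using the constructor referenced above; the header paths match
the MLIR tree of this era and may differ in newer versions:

```c++
#include "mlir/IR/Module.h"          // ModuleOp (mlir/IR/BuiltinOps.h in newer trees)
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"  // declares mlir::createNormalizeMemRefsPass()

// Run -normalize-memrefs over a module; returns failure if the pipeline fails.
static mlir::LogicalResult normalizeAllMemRefs(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::createNormalizeMemRefsPass());
  return pm.run(module);
}
```

The same transformation is available from the command line as
`mlir-opt -normalize-memrefs`.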


@@ -29,34 +29,6 @@ namespace {
/// such functions as normalizable. Also, if a normalizable function is known
/// to call a non-normalizable function, we treat that function as
/// non-normalizable as well. We assume external functions to be normalizable.
///
/// Input :-
/// #tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
/// func @matmul(%A: memref<16xf64, #tile>, %B: index, %C: memref<16xf64>) ->
/// (memref<16xf64, #tile>) {
/// affine.for %arg3 = 0 to 16 {
/// %a = affine.load %A[%arg3] : memref<16xf64, #tile>
/// %p = mulf %a, %a : f64
/// affine.store %p, %A[%arg3] : memref<16xf64, #tile>
/// }
/// %c = alloc() : memref<16xf64, #tile>
/// %d = affine.load %c[0] : memref<16xf64, #tile>
/// return %A: memref<16xf64, #tile>
/// }
///
/// Output :-
/// func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>)
/// -> memref<4x4xf64> {
/// affine.for %arg3 = 0 to 16 {
/// %2 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64>
/// %3 = mulf %2, %2 : f64
/// affine.store %3, %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64>
/// }
/// %0 = alloc() : memref<16xf64, #map0>
/// %1 = affine.load %0[0] : memref<16xf64, #map0>
/// return %arg0 : memref<4x4xf64>
/// }
///
struct NormalizeMemRefs : public NormalizeMemRefsBase<NormalizeMemRefs> {
void runOnOperation() override;
void normalizeFuncOpMemRefs(FuncOp funcOp, ModuleOp moduleOp);
@@ -73,6 +45,7 @@ std::unique_ptr<OperationPass<ModuleOp>> mlir::createNormalizeMemRefsPass() {
}
void NormalizeMemRefs::runOnOperation() {
LLVM_DEBUG(llvm::dbgs() << "Normalizing Memrefs...\n");
ModuleOp moduleOp = getOperation();
// We maintain all normalizable FuncOps in a DenseSet. It is initialized
// with all the functions within a module and then functions which are not
@@ -92,6 +65,9 @@ void NormalizeMemRefs::runOnOperation() {
moduleOp.walk([&](FuncOp funcOp) {
if (normalizableFuncs.contains(funcOp)) {
if (!areMemRefsNormalizable(funcOp)) {
LLVM_DEBUG(llvm::dbgs()
<< "@" << funcOp.getName()
<< " contains ops that cannot normalize MemRefs\n");
// Since this function is not normalizable, we set all the caller
// functions and the callees of this function as not normalizable.
// TODO: Drop this conservative assumption in the future.
@@ -101,6 +77,8 @@ void NormalizeMemRefs::runOnOperation() {
}
});
LLVM_DEBUG(llvm::dbgs() << "Normalizing " << normalizableFuncs.size()
<< " functions\n");
// Those functions which can be normalized are subjected to normalization.
for (FuncOp &funcOp : normalizableFuncs)
normalizeFuncOpMemRefs(funcOp, moduleOp);
@@ -127,6 +105,9 @@ void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
if (!normalizableFuncs.contains(funcOp))
return;
LLVM_DEBUG(
llvm::dbgs() << "@" << funcOp.getName()
<< " calls or is called by non-normalizable function\n");
normalizableFuncs.erase(funcOp);
// Caller of the function.
Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
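
The LLVM_DEBUG messages added above are compiled only into assertion-enabled
builds and are surfaced through mlir-opt's standard `-debug` / `-debug-only`
machinery. As a rough sketch of the per-operation question this function-level
bookkeeping is built on (this is not the pass's actual areMemRefsNormalizable
logic, and the helper name is made up):

```c++
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"  // MemRefType (mlir/IR/BuiltinTypes.h in newer trees)

// An op that produces or consumes memref values can only tolerate a layout
// rewrite if it opted in via the MemRefsNormalizable trait.
static bool blocksMemRefNormalization(mlir::Operation *op) {
  auto isMemRef = [](mlir::Type type) { return type.isa<mlir::MemRefType>(); };
  bool touchesMemRefs = llvm::any_of(op->getOperandTypes(), isMemRef) ||
                        llvm::any_of(op->getResultTypes(), isMemRef);
  return touchesMemRefs && !op->hasTrait<mlir::OpTrait::MemRefsNormalizable>();
}
```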