[mlir:Transforms] Move NormalizeMemRefs to MemRef/Transforms/

Transforms/ should only contain transformations that are dialect-independent,
and this pass interacts with MemRef operations, which makes it a better fit for
living in that dialect.

Differential Revision: https://reviews.llvm.org/D117841
This commit is contained in:
River Riddle 2022-01-20 15:16:17 -08:00
parent 0e9a4a3b65
commit 2e2c0738e8
9 changed files with 168 additions and 126 deletions

View File

@ -55,6 +55,10 @@ void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
/// load/store ops into `patterns`.
std::unique_ptr<Pass> createFoldSubViewOpsPass();
/// Creates an interprocedural pass to normalize memrefs to have a trivial
/// (identity) layout map.
std::unique_ptr<OperationPass<ModuleOp>> createNormalizeMemRefsPass();
/// Creates an operation pass to resolve `memref.dim` operations with values
/// that are defined by operations that implement the
/// `ReifyRankedShapeTypeShapeOpInterface`, in terms of shapes of its input

View File

@ -23,6 +23,122 @@ def FoldSubViewOps : Pass<"fold-memref-subview-ops"> {
];
}
def NormalizeMemRefs : Pass<"normalize-memrefs", "ModuleOp"> {
let summary = "Normalize memrefs";
let description = [{
This pass transforms memref types with a non-trivial
[layout map](https://mlir.llvm.org/docs/LangRef/#layout-map) into
memref types with an identity layout map, e.g. (i, j) -> (i, j). This
pass is inter-procedural, in the sense that it can modify function
interfaces and call sites that pass memref types. In order to modify
memref types while preserving the original behavior, users of those
memref types are also modified to incorporate the resulting layout map.
For instance, an [AffineLoadOp]
(https://mlir.llvm.org/docs/Dialects/Affine/#affineload-affineloadop)
will be updated to compose the layout map with the affine expression
contained in the op. Operations marked with the [MemRefsNormalizable]
(https://mlir.llvm.org/docs/Traits/#memrefsnormalizable) trait are
expected to be normalizable. Supported operations include affine
operations, memref.alloc, memref.dealloc, and std.return.
Given an appropriate layout map specified in the code, this transformation
can express tiled or linearized access to multi-dimensional data
structures, but will not modify memref types without an explicit layout
map.
Currently this pass is limited to only modify
functions where all memref types can be normalized. If a function
contains any operations that are not MemRefsNormalizable, then the function
and any functions that call it or are called by it will not be modified.
Input
```mlir
#tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
func @matmul(%A: memref<16xf64, #tile>,
%B: index, %C: memref<16xf64>) -> (memref<16xf64, #tile>) {
affine.for %arg3 = 0 to 16 {
%a = affine.load %A[%arg3] : memref<16xf64, #tile>
%p = arith.mulf %a, %a : f64
affine.store %p, %A[%arg3] : memref<16xf64, #tile>
}
%c = memref.alloc() : memref<16xf64, #tile>
%d = affine.load %c[0] : memref<16xf64, #tile>
return %A: memref<16xf64, #tile>
}
```
Output
```mlir
func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>)
-> memref<4x4xf64> {
affine.for %arg3 = 0 to 16 {
%3 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4]: memref<4x4xf64>
%4 = arith.mulf %3, %3 : f64
affine.store %4, %arg0[%arg3 floordiv 4, %arg3 mod 4]: memref<4x4xf64>
}
%0 = memref.alloc() : memref<4x4xf64>
%1 = affine.apply #map1()
%2 = affine.load %0[0, 0] : memref<4x4xf64>
return %arg0 : memref<4x4xf64>
}
```
Input
```mlir
#linear8 = affine_map<(i, j) -> (i * 8 + j)>
func @linearize(%arg0: memref<8x8xi32, #linear8>,
%arg1: memref<8x8xi32, #linear8>,
%arg2: memref<8x8xi32, #linear8>) {
%c8 = arith.constant 8 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
affine.for %arg3 = %c0 to %c8 {
affine.for %arg4 = %c0 to %c8 {
affine.for %arg5 = %c0 to %c8 {
%0 = affine.load %arg0[%arg3, %arg5] : memref<8x8xi32, #linear8>
%1 = affine.load %arg1[%arg5, %arg4] : memref<8x8xi32, #linear8>
%2 = affine.load %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8>
%3 = arith.muli %0, %1 : i32
%4 = arith.addi %2, %3 : i32
affine.store %4, %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8>
}
}
}
return
}
```
Output
```mlir
func @linearize(%arg0: memref<64xi32>,
%arg1: memref<64xi32>,
%arg2: memref<64xi32>) {
%c8 = arith.constant 8 : index
%c0 = arith.constant 0 : index
affine.for %arg3 = %c0 to %c8 {
affine.for %arg4 = %c0 to %c8 {
affine.for %arg5 = %c0 to %c8 {
%0 = affine.load %arg0[%arg3 * 8 + %arg5] : memref<64xi32>
%1 = affine.load %arg1[%arg5 * 8 + %arg4] : memref<64xi32>
%2 = affine.load %arg2[%arg3 * 8 + %arg4] : memref<64xi32>
%3 = arith.muli %0, %1 : i32
%4 = arith.addi %2, %3 : i32
affine.store %4, %arg2[%arg3 * 8 + %arg4] : memref<64xi32>
}
}
}
return
}
```
}];
let constructor = "mlir::memref::createNormalizeMemRefsPass()";
let dependentDialects = ["AffineDialect"];
}
def ResolveRankedShapeTypeResultDims :
Pass<"resolve-ranked-shaped-type-result-dims"> {
let summary = "Resolve memref.dim of result values of ranked shape type";

View File

@ -113,10 +113,6 @@ std::unique_ptr<Pass> createSCCPPass();
/// pass may *only* be scheduled on an operation that defines a SymbolTable.
std::unique_ptr<Pass> createSymbolDCEPass();
/// Creates an interprocedural pass to normalize memrefs to have a trivial
/// (identity) layout map.
std::unique_ptr<OperationPass<ModuleOp>> createNormalizeMemRefsPass();
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//

View File

@ -351,122 +351,6 @@ def LoopInvariantCodeMotion : Pass<"loop-invariant-code-motion"> {
let constructor = "mlir::createLoopInvariantCodeMotionPass()";
}
def NormalizeMemRefs : Pass<"normalize-memrefs", "ModuleOp"> {
let summary = "Normalize memrefs";
let description = [{
This pass transforms memref types with a non-trivial
[layout map](https://mlir.llvm.org/docs/LangRef/#layout-map) into
memref types with an identity layout map, e.g. (i, j) -> (i, j). This
pass is inter-procedural, in the sense that it can modify function
interfaces and call sites that pass memref types. In order to modify
memref types while preserving the original behavior, users of those
memref types are also modified to incorporate the resulting layout map.
For instance, an [AffineLoadOp]
(https://mlir.llvm.org/docs/Dialects/Affine/#affineload-affineloadop)
will be updated to compose the layout map with the affine expression
contained in the op. Operations marked with the [MemRefsNormalizable]
(https://mlir.llvm.org/docs/Traits/#memrefsnormalizable) trait are
expected to be normalizable. Supported operations include affine
operations, memref.alloc, memref.dealloc, and std.return.
Given an appropriate layout map specified in the code, this transformation
can express tiled or linearized access to multi-dimensional data
structures, but will not modify memref types without an explicit layout
map.
Currently this pass is limited to only modify
functions where all memref types can be normalized. If a function
contains any operations that are not MemRefsNormalizable, then the function
and any functions that call it or are called by it will not be modified.
Input
```mlir
#tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
func @matmul(%A: memref<16xf64, #tile>,
%B: index, %C: memref<16xf64>) -> (memref<16xf64, #tile>) {
affine.for %arg3 = 0 to 16 {
%a = affine.load %A[%arg3] : memref<16xf64, #tile>
%p = arith.mulf %a, %a : f64
affine.store %p, %A[%arg3] : memref<16xf64, #tile>
}
%c = memref.alloc() : memref<16xf64, #tile>
%d = affine.load %c[0] : memref<16xf64, #tile>
return %A: memref<16xf64, #tile>
}
```
Output
```mlir
func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>)
-> memref<4x4xf64> {
affine.for %arg3 = 0 to 16 {
%3 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4]: memref<4x4xf64>
%4 = arith.mulf %3, %3 : f64
affine.store %4, %arg0[%arg3 floordiv 4, %arg3 mod 4]: memref<4x4xf64>
}
%0 = memref.alloc() : memref<4x4xf64>
%1 = affine.apply #map1()
%2 = affine.load %0[0, 0] : memref<4x4xf64>
return %arg0 : memref<4x4xf64>
}
```
Input
```mlir
#linear8 = affine_map<(i, j) -> (i * 8 + j)>
func @linearize(%arg0: memref<8x8xi32, #linear8>,
%arg1: memref<8x8xi32, #linear8>,
%arg2: memref<8x8xi32, #linear8>) {
%c8 = arith.constant 8 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
affine.for %arg3 = %c0 to %c8 {
affine.for %arg4 = %c0 to %c8 {
affine.for %arg5 = %c0 to %c8 {
%0 = affine.load %arg0[%arg3, %arg5] : memref<8x8xi32, #linear8>
%1 = affine.load %arg1[%arg5, %arg4] : memref<8x8xi32, #linear8>
%2 = affine.load %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8>
%3 = arith.muli %0, %1 : i32
%4 = arith.addi %2, %3 : i32
affine.store %4, %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8>
}
}
}
return
}
```
Output
```mlir
func @linearize(%arg0: memref<64xi32>,
%arg1: memref<64xi32>,
%arg2: memref<64xi32>) {
%c8 = arith.constant 8 : index
%c0 = arith.constant 0 : index
affine.for %arg3 = %c0 to %c8 {
affine.for %arg4 = %c0 to %c8 {
affine.for %arg5 = %c0 to %c8 {
%0 = affine.load %arg0[%arg3 * 8 + %arg5] : memref<64xi32>
%1 = affine.load %arg1[%arg5 * 8 + %arg4] : memref<64xi32>
%2 = affine.load %arg2[%arg3 * 8 + %arg4] : memref<64xi32>
%3 = arith.muli %0, %1 : i32
%4 = arith.addi %2, %3 : i32
affine.store %4, %arg2[%arg3 * 8 + %arg4] : memref<64xi32>
}
}
}
return
}
```
}];
let constructor = "mlir::createNormalizeMemRefsPass()";
let dependentDialects = ["AffineDialect"];
}
def ParallelLoopCollapsing : Pass<"parallel-loop-collapsing"> {
let summary = "Collapse parallel loops to use less induction variables";
let constructor = "mlir::createParallelLoopCollapsingPass()";

View File

@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRMemRefTransforms
FoldSubViewOps.cpp
NormalizeMemRefs.cpp
ResolveShapedTypeResultDims.cpp
ADDITIONAL_HEADER_DIRS

View File

@ -14,7 +14,7 @@
#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"
@ -43,7 +43,8 @@ struct NormalizeMemRefs : public NormalizeMemRefsBase<NormalizeMemRefs> {
} // namespace
std::unique_ptr<OperationPass<ModuleOp>> mlir::createNormalizeMemRefsPass() {
std::unique_ptr<OperationPass<ModuleOp>>
mlir::memref::createNormalizeMemRefsPass() {
return std::make_unique<NormalizeMemRefs>();
}

View File

@ -0,0 +1,43 @@
//===- PassDetail.h - MemRef Pass class details -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef DIALECT_MEMREF_TRANSFORMS_PASSDETAIL_H_
#define DIALECT_MEMREF_TRANSFORMS_PASSDETAIL_H_
#include "mlir/Pass/Pass.h"
// Shared declarations for the MemRef transformation passes. Every pass source
// in this directory includes this header to get the generated pass base
// classes from one place.
namespace mlir {
// Forward declarations of dialect classes, so this header does not have to
// include the full dialect headers.
// NOTE(review): presumably these are the dialects named in the
// `dependentDialects` lists of the passes in Passes.td (NormalizeMemRefs lists
// "AffineDialect") -- confirm against the .td file.
class AffineDialect;
// Forward declaration from Dialect.h
template <typename ConcreteDialect>
void registerDialect(DialectRegistry &registry);
namespace arith {
class ArithmeticDialect;
} // namespace arith
namespace memref {
class MemRefDialect;
} // namespace memref
namespace tensor {
class TensorDialect;
} // namespace tensor
namespace vector {
class VectorDialect;
} // namespace vector
// GEN_PASS_CLASSES selects the pass base-class section of the TableGen-
// generated Passes.h.inc (e.g. NormalizeMemRefsBase,
// ResolveRankedShapeTypeResultDimsBase). The include sits inside `namespace
// mlir`, so the generated classes are declared in that namespace.
#define GEN_PASS_CLASSES
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace mlir
#endif // DIALECT_MEMREF_TRANSFORMS_PASSDETAIL_H_

View File

@ -11,6 +11,7 @@
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@ -107,9 +108,6 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
//===----------------------------------------------------------------------===//
namespace {
#define GEN_PASS_CLASSES
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
struct ResolveRankedShapeTypeResultDimsPass final
: public ResolveRankedShapeTypeResultDimsBase<
ResolveRankedShapeTypeResultDimsPass> {

View File

@ -9,7 +9,6 @@ add_mlir_library(MLIRTransforms
LoopCoalescing.cpp
LoopFusion.cpp
LoopInvariantCodeMotion.cpp
NormalizeMemRefs.cpp
OpStats.cpp
ParallelLoopCollapsing.cpp
PipelineDataTransfer.cpp