forked from OSchip/llvm-project
Utility to normalize memrefs with non-identity layout maps
- Introduce a utility to convert memrefs with non-identity layout maps to ones with identity layout maps: convert the type and rewrite/remap all its uses. - Add this utility to the -simplify-affine-structures pass for testing purposes. Signed-off-by: Uday Bondhugula <uday@polymagelabs.com> Closes tensorflow/mlir#104 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/104 from bondhugula:memref-normalize f2c914aa1890e8860326c9e33f9aa160b3d65e6d PiperOrigin-RevId: 266985317
This commit is contained in:
parent
5593e005c6
commit
54d674f51e
|
@ -477,7 +477,13 @@ public:
|
|||
/// symbolic operands of vMap should match 1:1 (in the same order) with those
|
||||
/// of this constraint system, but the latter could have additional trailing
|
||||
/// operands.
|
||||
LogicalResult composeMap(AffineValueMap *vMap);
|
||||
LogicalResult composeMap(const AffineValueMap *vMap);
|
||||
|
||||
/// Composes an affine map whose dimensions match one to one to the
|
||||
/// dimensions of this FlatAffineConstraints. The results of the map 'other'
|
||||
/// are added as the leading dimensions of this constraint system. Returns
|
||||
/// failure if 'other' is a semi-affine map.
|
||||
LogicalResult composeMatchingMap(AffineMap other);
|
||||
|
||||
/// Projects out (aka eliminates) 'num' identifiers starting at position
|
||||
/// 'pos'. The resulting constraint system is the shadow along the dimensions
|
||||
|
|
|
@ -77,7 +77,9 @@ std::unique_ptr<FunctionPassBase> createLoopUnrollPass(
|
|||
std::unique_ptr<FunctionPassBase>
|
||||
createLoopUnrollAndJamPass(int unrollJamFactor = -1);
|
||||
|
||||
/// Creates an simplification pass for affine structures.
|
||||
/// Creates a simplification pass for affine structures (maps and sets). In
|
||||
/// addition, this pass also normalizes memrefs to have the trivial (identity)
|
||||
/// layout map.
|
||||
std::unique_ptr<FunctionPassBase> createSimplifyAffineStructuresPass();
|
||||
|
||||
/// Creates a loop fusion pass which fuses loops. Buffers of size less than or
|
||||
|
|
|
@ -81,6 +81,11 @@ LogicalResult replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
|
|||
AffineMap indexRemap = AffineMap(),
|
||||
ArrayRef<Value *> extraOperands = {});
|
||||
|
||||
/// Rewrites the memref defined by this alloc op to have an identity layout map
|
||||
/// and updates all its indexing uses. Returns failure if any of its uses
|
||||
/// escape (while leaving the IR in a valid state).
|
||||
LogicalResult normalizeMemRef(AllocOp op);
|
||||
|
||||
/// Creates and inserts into 'builder' a new AffineApplyOp, with the number of
|
||||
/// its results equal to the number of operands, as a composition
|
||||
/// of all other AffineApplyOps reachable from input parameter 'operands'. If
|
||||
|
|
|
@ -614,7 +614,7 @@ void FlatAffineConstraints::mergeAndAlignIdsWithOther(
|
|||
// This routine may add additional local variables if the flattened expression
|
||||
// corresponding to the map has such variables due to mod's, ceildiv's, and
|
||||
// floordiv's in it.
|
||||
LogicalResult FlatAffineConstraints::composeMap(AffineValueMap *vMap) {
|
||||
LogicalResult FlatAffineConstraints::composeMap(const AffineValueMap *vMap) {
|
||||
std::vector<SmallVector<int64_t, 8>> flatExprs;
|
||||
FlatAffineConstraints localCst;
|
||||
if (failed(getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs,
|
||||
|
@ -683,6 +683,75 @@ LogicalResult FlatAffineConstraints::composeMap(AffineValueMap *vMap) {
|
|||
return success();
|
||||
}
|
||||
|
||||
// Similar to composeMap except that no Value's need be associated with the
|
||||
// constraint system nor are they looked at -- since the dimensions and
|
||||
// symbols of 'other' are expected to correspond 1:1 to 'this' system. It
|
||||
// is thus not convenient to share code with composeMap.
|
||||
LogicalResult FlatAffineConstraints::composeMatchingMap(AffineMap other) {
|
||||
assert(other.getNumDims() == getNumDimIds() && "dim mismatch");
|
||||
assert(other.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
|
||||
|
||||
std::vector<SmallVector<int64_t, 8>> flatExprs;
|
||||
FlatAffineConstraints localCst;
|
||||
if (failed(getFlattenedAffineExprs(other, &flatExprs, &localCst))) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "composition unimplemented for semi-affine maps\n");
|
||||
return failure();
|
||||
}
|
||||
assert(flatExprs.size() == other.getNumResults());
|
||||
|
||||
// Add localCst information.
|
||||
if (localCst.getNumLocalIds() > 0) {
|
||||
// Place local id's of A after local id's of B.
|
||||
for (unsigned l = 0, e = localCst.getNumLocalIds(); l < e; l++) {
|
||||
addLocalId(0);
|
||||
}
|
||||
// Finally, append localCst to this constraint set.
|
||||
append(localCst);
|
||||
}
|
||||
|
||||
// Add dimensions corresponding to the map's results.
|
||||
for (unsigned t = 0, e = other.getNumResults(); t < e; t++) {
|
||||
addDimId(0);
|
||||
}
|
||||
|
||||
// We add one equality for each result connecting the result dim of the map to
|
||||
// the other identifiers.
|
||||
// For eg: if the expression is 16*i0 + i1, and this is the r^th
|
||||
// iteration/result of the value map, we are adding the equality:
|
||||
// d_r - 16*i0 - i1 = 0. Hence, when flattening say (i0 + 1, i0 + 8*i2), we
|
||||
// add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
|
||||
for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
|
||||
const auto &flatExpr = flatExprs[r];
|
||||
assert(flatExpr.size() >= other.getNumInputs() + 1);
|
||||
|
||||
// eqToAdd is the equality corresponding to the flattened affine expression.
|
||||
SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
|
||||
// Set the coefficient for this result to one.
|
||||
eqToAdd[r] = 1;
|
||||
|
||||
// Dims and symbols.
|
||||
for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) {
|
||||
// Negate 'eq[r]' since the newly added dimension will be set to this one.
|
||||
eqToAdd[e + i] = -flatExpr[i];
|
||||
}
|
||||
// Local vars common to eq and localCst are at the beginning.
|
||||
unsigned j = getNumDimIds() + getNumSymbolIds();
|
||||
unsigned end = flatExpr.size() - 1;
|
||||
for (unsigned i = other.getNumInputs(); i < end; i++, j++) {
|
||||
eqToAdd[j] = -flatExpr[i];
|
||||
}
|
||||
|
||||
// Constant term.
|
||||
eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];
|
||||
|
||||
// Add the equality connecting the result of the map to this constraint set.
|
||||
addEquality(eqToAdd);
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
// Turn a dimension into a symbol.
|
||||
static void turnDimIntoSymbol(FlatAffineConstraints *cst, Value &id) {
|
||||
unsigned pos;
|
||||
|
|
|
@ -20,12 +20,10 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Analysis/AffineStructures.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/Utils.h"
|
||||
|
||||
#define DEBUG_TYPE "simplify-affine-structure"
|
||||
|
||||
|
@ -33,10 +31,10 @@ using namespace mlir;
|
|||
|
||||
namespace {
|
||||
|
||||
/// Simplifies all affine expressions appearing in the operations of
|
||||
/// the Function. This is mainly to test the simplifyAffineExpr method.
|
||||
/// TODO(someone): This should just be defined as a canonicalization pattern
|
||||
/// on AffineMap and driven from the existing canonicalization pass.
|
||||
/// Simplifies affine maps and sets appearing in the operations of the Function.
|
||||
/// This part is mainly to test the simplifyAffineExpr method. In addition,
|
||||
/// all memrefs with non-trivial layout maps are converted to ones with trivial
|
||||
/// identity layout ones.
|
||||
struct SimplifyAffineStructures
|
||||
: public FunctionPass<SimplifyAffineStructures> {
|
||||
void runOnFunction() override;
|
||||
|
@ -93,8 +91,9 @@ std::unique_ptr<FunctionPassBase> mlir::createSimplifyAffineStructuresPass() {
|
|||
}
|
||||
|
||||
void SimplifyAffineStructures::runOnFunction() {
|
||||
auto func = getFunction();
|
||||
simplifiedAttributes.clear();
|
||||
getFunction().walk([&](Operation *opInst) {
|
||||
func.walk([&](Operation *opInst) {
|
||||
for (auto attr : opInst->getAttrs()) {
|
||||
if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>())
|
||||
simplifyAndUpdateAttribute(opInst, attr.first, mapAttr);
|
||||
|
@ -102,6 +101,9 @@ void SimplifyAffineStructures::runOnFunction() {
|
|||
simplifyAndUpdateAttribute(opInst, attr.first, setAttr);
|
||||
}
|
||||
});
|
||||
|
||||
// Turn memrefs' non-identity layouts maps into ones with identity.
|
||||
func.walk([](AllocOp op) { normalizeMemRef(op); });
|
||||
}
|
||||
|
||||
static PassRegistration<SimplifyAffineStructures>
|
||||
|
|
|
@ -388,3 +388,82 @@ void mlir::createAffineComputationSlice(
|
|||
opInst->setOperand(idx, newOperands[idx]);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Currently works for static memrefs with single non-identity layout map.
|
||||
LogicalResult mlir::normalizeMemRef(AllocOp allocOp) {
|
||||
MemRefType memrefType = allocOp.getType();
|
||||
unsigned rank = memrefType.getRank();
|
||||
if (rank == 0)
|
||||
return success();
|
||||
|
||||
auto layoutMaps = memrefType.getAffineMaps();
|
||||
OpBuilder b(allocOp);
|
||||
if (layoutMaps.size() != 1)
|
||||
return failure();
|
||||
|
||||
AffineMap layoutMap = layoutMaps.front();
|
||||
|
||||
if (layoutMap == b.getMultiDimIdentityMap(rank))
|
||||
return success();
|
||||
|
||||
if (layoutMap.getNumResults() < rank)
|
||||
// This is a sufficient condition for not being one-to-one; the map is thus
|
||||
// invalid. Leave it alone. (Undefined behavior?)
|
||||
return failure();
|
||||
|
||||
// We don't do any more non-trivial checks for one-to-one'ness; we
|
||||
// assume that it is one-to-one.
|
||||
|
||||
// TODO: Only for static memref's for now.
|
||||
if (memrefType.getNumDynamicDims() > 0)
|
||||
return failure();
|
||||
|
||||
// We have a single map that is not an identity map. Create a new memref with
|
||||
// the right shape and an identity layout map.
|
||||
auto shape = memrefType.getShape();
|
||||
FlatAffineConstraints fac(rank, 0);
|
||||
for (unsigned d = 0; d < rank; ++d) {
|
||||
fac.addConstantLowerBound(d, 0);
|
||||
fac.addConstantUpperBound(d, shape[d] - 1);
|
||||
}
|
||||
|
||||
// We compose this map with the original index (logical) space to derive the
|
||||
// upper bounds for the new index space.
|
||||
unsigned newRank = layoutMap.getNumResults();
|
||||
fac.composeMatchingMap(layoutMap);
|
||||
// Project out the old data dimensions.
|
||||
fac.projectOut(newRank, fac.getNumIds() - newRank - fac.getNumLocalIds());
|
||||
SmallVector<int64_t, 4> newShape(newRank);
|
||||
for (unsigned d = 0; d < newRank; ++d) {
|
||||
// The lower bound for the shape is always zero.
|
||||
auto ubConst = fac.getConstantUpperBound(d);
|
||||
// For a static memref and an affine map with no symbols, this is always
|
||||
// bounded.
|
||||
assert(ubConst.hasValue() && "should always have an upper bound");
|
||||
if (ubConst.getValue() < 0)
|
||||
// This is due to an invalid map that maps to a negative space.
|
||||
return failure();
|
||||
newShape[d] = ubConst.getValue() + 1;
|
||||
}
|
||||
|
||||
auto *oldMemRef = allocOp.getResult();
|
||||
auto newMemRefType = b.getMemRefType(newShape, memrefType.getElementType(),
|
||||
b.getMultiDimIdentityMap(newRank));
|
||||
auto newAlloc = b.create<AllocOp>(allocOp.getLoc(), newMemRefType);
|
||||
|
||||
// Replace all uses of the old memref.
|
||||
if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
|
||||
/*extraIndices=*/{},
|
||||
/*indexRemap=*/layoutMap))) {
|
||||
// If it failed (due to escapes for example), bail out.
|
||||
newAlloc.erase();
|
||||
return failure();
|
||||
}
|
||||
// Replace any uses of the original alloc op and erase it. All remaining uses
|
||||
// have to be dealloc's; RAMUW above would've failed otherwise.
|
||||
assert(std::all_of(oldMemRef->user_begin(), oldMemRef->user_end(),
|
||||
[](Operation *op) { return isa<DeallocOp>(op); }));
|
||||
oldMemRef->replaceAllUsesWith(newAlloc);
|
||||
allocOp.erase();
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
// RUN: mlir-opt -simplify-affine-structures %s | FileCheck %s

// Tests for memref normalization: allocs with non-identity layout maps are
// rewritten to identity-layout memrefs and all indexing uses are remapped.

// A transposed (permuted) layout becomes a plain memref with swapped extents.
// CHECK-LABEL: func @permute()
func @permute() {
  %A = alloc() : memref<64x256xf32, (d0, d1) -> (d1, d0)>
  affine.for %i = 0 to 64 {
    affine.for %j = 0 to 256 {
      affine.load %A[%i, %j] : memref<64x256xf32, (d0, d1) -> (d1, d0)>
    }
  }
  dealloc %A : memref<64x256xf32, (d0, d1) -> (d1, d0)>
  return
}
// The old memref alloc should disappear.
// CHECK-NOT: memref<64x256xf32>
// CHECK: [[MEM:%[0-9]+]] = alloc() : memref<256x64xf32>
// CHECK-NEXT: affine.for %[[I:arg[0-9]+]] = 0 to 64 {
// CHECK-NEXT:   affine.for %[[J:arg[0-9]+]] = 0 to 256 {
// CHECK-NEXT:     affine.load [[MEM]][%[[J]], %[[I]]] : memref<256x64xf32>
// CHECK-NEXT:   }
// CHECK-NEXT: }
// CHECK-NEXT: dealloc [[MEM]]
// CHECK-NEXT: return

// A shifted layout: the extent grows to cover the shifted index range.
// CHECK-LABEL: func @shift()
func @shift() {
  // CHECK-NOT: memref<64xf32, (d0) -> (d0 + 1)>
  %A = alloc() : memref<64xf32, (d0) -> (d0 + 1)>
  affine.for %i = 0 to 64 {
    affine.load %A[%i] : memref<64xf32, (d0) -> (d0 + 1)>
    // CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}} + 1] : memref<65xf32>
  }
  return
}

// A three-dimensional permutation.
// CHECK-LABEL: func @high_dim_permute()
func @high_dim_permute() {
  // CHECK-NOT: memref<64x128x256xf32,
  %A = alloc() : memref<64x128x256xf32, (d0, d1, d2) -> (d2, d0, d1)>
  // CHECK: %[[I:arg[0-9]+]]
  affine.for %i = 0 to 64 {
    // CHECK: %[[J:arg[0-9]+]]
    affine.for %j = 0 to 128 {
      // CHECK: %[[K:arg[0-9]+]]
      affine.for %k = 0 to 256 {
        affine.load %A[%i, %j, %k] : memref<64x128x256xf32, (d0, d1, d2) -> (d2, d0, d1)>
        // CHECK: %{{.*}} = affine.load %{{.*}}[%[[K]], %[[I]], %[[J]]] : memref<256x64x128xf32>
      }
    }
  }
  return
}

// A map into a negative space is invalid; the alloc must be left untouched.
// CHECK-LABEL: func @invalid_map
func @invalid_map() {
  %A = alloc() : memref<64x128xf32, (d0, d1) -> (d0, -d1 - 10)>
  // CHECK: %{{.*}} = alloc() : memref<64x128xf32,
  return
}

// A tiled layout.
// CHECK-LABEL: func @data_tiling()
func @data_tiling() {
  %A = alloc() : memref<64x512xf32, (d0, d1) -> (d0 floordiv 8, d1 floordiv 16, d0 mod 8, d1 mod 16)>
  // CHECK: %{{.*}} = alloc() : memref<8x32x8x16xf32>
  return
}


// Memref escapes; no normalization.
// CHECK-LABEL: func @escaping() -> memref<64xf32, #map{{[0-9]+}}>
func @escaping() -> memref<64xf32, (d0) -> (d0 + 2)> {
  // CHECK: %{{.*}} = alloc() : memref<64xf32, #map{{[0-9]+}}>
  %A = alloc() : memref<64xf32, (d0) -> (d0 + 2)>
  return %A : memref<64xf32, (d0) -> (d0 + 2)>
}
|
Loading…
Reference in New Issue