Utility to normalize memrefs with non-identity layout maps

- introduce a utility to convert memrefs with non-identity layout maps to
  ones with identity layout maps: it converts the type and rewrites/remaps
  all uses of the memref

- add this utility to the -simplify-affine-structures pass for testing
  purposes
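
For example (taken from the test cases added in this change):

  %A = alloc() : memref<64x256xf32, (d0, d1) -> (d1, d0)>
  affine.load %A[%i, %j] : memref<64x256xf32, (d0, d1) -> (d1, d0)>

is normalized to

  %A = alloc() : memref<256x64xf32>
  affine.load %A[%j, %i] : memref<256x64xf32>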

Signed-off-by: Uday Bondhugula <uday@polymagelabs.com>

Closes tensorflow/mlir#104

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/104 from bondhugula:memref-normalize f2c914aa1890e8860326c9e33f9aa160b3d65e6d
PiperOrigin-RevId: 266985317
Uday Bondhugula, 2019-09-03 12:13:59 -07:00; committed by A. Unique TensorFlower
parent 5593e005c6
commit 54d674f51e
7 changed files with 250 additions and 11 deletions

View File

@@ -477,7 +477,13 @@ public:
/// symbolic operands of vMap should match 1:1 (in the same order) with those
/// of this constraint system, but the latter could have additional trailing
/// operands.
LogicalResult composeMap(AffineValueMap *vMap);
LogicalResult composeMap(const AffineValueMap *vMap);
/// Composes an affine map whose dimensions match one to one to the
/// dimensions of this FlatAffineConstraints. The results of the map 'other'
/// are added as the leading dimensions of this constraint system. Returns
/// failure if 'other' is a semi-affine map.
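/// For example (an illustrative sketch), composing (d0, d1) -> (d0 floordiv 4,
/// d1) prepends new leading dims (r0, r1) together with the equalities
/// r0 = d0 floordiv 4 and r1 = d1; flattening the floordiv introduces a local
/// identifier.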
LogicalResult composeMatchingMap(AffineMap other);
/// Projects out (aka eliminates) 'num' identifiers starting at position
/// 'pos'. The resulting constraint system is the shadow along the dimensions

View File

@@ -77,7 +77,9 @@ std::unique_ptr<FunctionPassBase> createLoopUnrollPass(
std::unique_ptr<FunctionPassBase>
createLoopUnrollAndJamPass(int unrollJamFactor = -1);
/// Creates an simplification pass for affine structures.
/// Creates a simplification pass for affine structures (maps and sets). In
/// addition, this pass also normalizes memrefs to have the trivial (identity)
/// layout map.
std::unique_ptr<FunctionPassBase> createSimplifyAffineStructuresPass();
/// Creates a loop fusion pass which fuses loops. Buffers of size less than or

View File

@@ -81,6 +81,11 @@ LogicalResult replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
AffineMap indexRemap = AffineMap(),
ArrayRef<Value *> extraOperands = {});
/// Rewrites the memref defined by this alloc op to have an identity layout map
/// and updates all its indexing uses. Returns failure if any of its uses
/// escape (while leaving the IR in a valid state).
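/// For example, the -simplify-affine-structures pass in this change drives
/// this utility by walking every AllocOp of a function.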
LogicalResult normalizeMemRef(AllocOp op);
/// Creates and inserts into 'builder' a new AffineApplyOp, with the number of
/// its results equal to the number of operands, as a composition
/// of all other AffineApplyOps reachable from input parameter 'operands'. If

View File

@@ -614,7 +614,7 @@ void FlatAffineConstraints::mergeAndAlignIdsWithOther(
// This routine may add additional local variables if the flattened expression
// corresponding to the map has such variables due to mod's, ceildiv's, and
// floordiv's in it.
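// E.g., flattening d0 floordiv 4 introduces a local identifier q along with
// the constraints 0 <= d0 - 4 * q <= 3.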
LogicalResult FlatAffineConstraints::composeMap(AffineValueMap *vMap) {
LogicalResult FlatAffineConstraints::composeMap(const AffineValueMap *vMap) {
std::vector<SmallVector<int64_t, 8>> flatExprs;
FlatAffineConstraints localCst;
if (failed(getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs,
@@ -683,6 +683,75 @@ LogicalResult FlatAffineConstraints::composeMap(AffineValueMap *vMap)
return success();
}
// Similar to composeMap except that no Values need be associated with the
// constraint system nor are they looked at -- since the dimensions and
// symbols of 'other' are expected to correspond 1:1 to 'this' system. It
// is thus not convenient to share code with composeMap.
LogicalResult FlatAffineConstraints::composeMatchingMap(AffineMap other) {
assert(other.getNumDims() == getNumDimIds() && "dim mismatch");
assert(other.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
std::vector<SmallVector<int64_t, 8>> flatExprs;
FlatAffineConstraints localCst;
if (failed(getFlattenedAffineExprs(other, &flatExprs, &localCst))) {
LLVM_DEBUG(llvm::dbgs()
<< "composition unimplemented for semi-affine maps\n");
return failure();
}
assert(flatExprs.size() == other.getNumResults());
// Add localCst information.
if (localCst.getNumLocalIds() > 0) {
// Place the local ids of 'localCst' at the start of the local columns.
for (unsigned l = 0, e = localCst.getNumLocalIds(); l < e; l++) {
addLocalId(0);
}
// Finally, append localCst to this constraint set.
append(localCst);
}
// Add dimensions corresponding to the map's results.
for (unsigned t = 0, e = other.getNumResults(); t < e; t++) {
addDimId(0);
}
// We add one equality for each result connecting the result dim of the map to
// the other identifiers.
// For example, if the expression is 16*i0 + i1 and it is the r'th result of
// the map, we add the equality: d_r - 16*i0 - i1 = 0. Hence, when flattening,
// say, (i0 + 1, i0 + 8*i2), we add two equalities overall: d_0 - i0 - 1 == 0
// and d_1 - i0 - 8*i2 == 0.
for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
const auto &flatExpr = flatExprs[r];
assert(flatExpr.size() >= other.getNumInputs() + 1);
// eqToAdd is the equality corresponding to the flattened affine expression.
SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
// Set the coefficient for this result to one.
eqToAdd[r] = 1;
// Dims and symbols.
for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) {
// Negate the input's coefficient, since the equality sets the newly added
// dimension equal to this expression.
eqToAdd[e + i] = -flatExpr[i];
}
// Local vars common to eq and localCst are at the beginning.
unsigned j = getNumDimIds() + getNumSymbolIds();
unsigned end = flatExpr.size() - 1;
for (unsigned i = other.getNumInputs(); i < end; i++, j++) {
eqToAdd[j] = -flatExpr[i];
}
// Constant term.
eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];
// Add the equality connecting the result of the map to this constraint set.
addEquality(eqToAdd);
}
return success();
}
// Turn a dimension into a symbol.
static void turnDimIntoSymbol(FlatAffineConstraints *cst, Value &id) {
unsigned pos;

View File

@@ -20,12 +20,10 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
#define DEBUG_TYPE "simplify-affine-structure"
@@ -33,10 +31,10 @@ using namespace mlir;
namespace {
/// Simplifies all affine expressions appearing in the operations of
/// the Function. This is mainly to test the simplifyAffineExpr method.
/// TODO(someone): This should just be defined as a canonicalization pattern
/// on AffineMap and driven from the existing canonicalization pass.
/// Simplifies affine maps and sets appearing in the operations of the Function.
/// This part is mainly to test the simplifyAffineExpr method. In addition,
/// all memrefs with non-trivial layout maps are converted to ones with the
/// trivial identity layout map.
struct SimplifyAffineStructures
: public FunctionPass<SimplifyAffineStructures> {
void runOnFunction() override;
@@ -93,8 +91,9 @@ std::unique_ptr<FunctionPassBase> mlir::createSimplifyAffineStructuresPass() {
}
void SimplifyAffineStructures::runOnFunction() {
auto func = getFunction();
simplifiedAttributes.clear();
getFunction().walk([&](Operation *opInst) {
func.walk([&](Operation *opInst) {
for (auto attr : opInst->getAttrs()) {
if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>())
simplifyAndUpdateAttribute(opInst, attr.first, mapAttr);
@@ -102,6 +101,9 @@ void SimplifyAffineStructures::runOnFunction() {
simplifyAndUpdateAttribute(opInst, attr.first, setAttr);
}
});
// Turn memrefs' non-identity layout maps into identity ones.
func.walk([](AllocOp op) { normalizeMemRef(op); });
}
static PassRegistration<SimplifyAffineStructures>

View File

@@ -388,3 +388,82 @@ void mlir::createAffineComputationSlice(
opInst->setOperand(idx, newOperands[idx]);
}
}
// TODO: Currently works only for static memrefs with a single non-identity
// layout map.
LogicalResult mlir::normalizeMemRef(AllocOp allocOp) {
MemRefType memrefType = allocOp.getType();
unsigned rank = memrefType.getRank();
if (rank == 0)
return success();
auto layoutMaps = memrefType.getAffineMaps();
OpBuilder b(allocOp);
if (layoutMaps.size() != 1)
return failure();
AffineMap layoutMap = layoutMaps.front();
if (layoutMap == b.getMultiDimIdentityMap(rank))
return success();
if (layoutMap.getNumResults() < rank)
// This is a sufficient condition for not being one-to-one; the map is thus
// invalid. Leave it alone. (Undefined behavior?)
return failure();
// We don't do any more non-trivial checks for one-to-one-ness; we
// assume that it is one-to-one.
// TODO: Only for static memrefs for now.
if (memrefType.getNumDynamicDims() > 0)
return failure();
// We have a single map that is not an identity map. Create a new memref with
// the right shape and an identity layout map.
auto shape = memrefType.getShape();
FlatAffineConstraints fac(rank, 0);
for (unsigned d = 0; d < rank; ++d) {
fac.addConstantLowerBound(d, 0);
fac.addConstantUpperBound(d, shape[d] - 1);
}
// Compose the layout map with the original (logical) index space to derive
// the upper bounds of the new index space.
unsigned newRank = layoutMap.getNumResults();
fac.composeMatchingMap(layoutMap);
// Project out the old data dimensions.
fac.projectOut(newRank, fac.getNumIds() - newRank - fac.getNumLocalIds());
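// For example, with the tiled layout (d0, d1) -> (d0 floordiv 8, d1 floordiv
// 16, d0 mod 8, d1 mod 16) on a 64x512 memref (the @data_tiling test case),
// the projected system has the upper bounds 7, 31, 7, and 15, yielding the
// new shape 8x32x8x16 below.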
SmallVector<int64_t, 4> newShape(newRank);
for (unsigned d = 0; d < newRank; ++d) {
// The lower bound for the shape is always zero.
auto ubConst = fac.getConstantUpperBound(d);
// For a static memref and an affine map with no symbols, this is always
// bounded.
assert(ubConst.hasValue() && "should always have an upper bound");
if (ubConst.getValue() < 0)
// This is due to an invalid map that maps to a negative space.
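// (E.g., (d0, d1) -> (d0, -d1 - 10), as in the @invalid_map test case.)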
return failure();
newShape[d] = ubConst.getValue() + 1;
}
auto *oldMemRef = allocOp.getResult();
auto newMemRefType = b.getMemRefType(newShape, memrefType.getElementType(),
b.getMultiDimIdentityMap(newRank));
auto newAlloc = b.create<AllocOp>(allocOp.getLoc(), newMemRefType);
// Replace all uses of the old memref.
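// E.g., with layoutMap = (d0, d1) -> (d1, d0), an affine.load %A[%i, %j]
// becomes an affine.load from the new memref at [%j, %i] (the @permute test
// case).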
if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
/*extraIndices=*/{},
/*indexRemap=*/layoutMap))) {
// If it failed (due to escapes for example), bail out.
newAlloc.erase();
return failure();
}
// Replace any uses of the original alloc op and erase it. All remaining uses
// have to be dealloc's; replaceAllMemRefUsesWith above would have failed
// otherwise.
assert(std::all_of(oldMemRef->user_begin(), oldMemRef->user_end(),
[](Operation *op) { return isa<DeallocOp>(op); }));
oldMemRef->replaceAllUsesWith(newAlloc);
allocOp.erase();
return success();
}

View File

@@ -0,0 +1,76 @@
// RUN: mlir-opt -simplify-affine-structures %s | FileCheck %s
// CHECK-LABEL: func @permute()
func @permute() {
%A = alloc() : memref<64x256xf32, (d0, d1) -> (d1, d0)>
affine.for %i = 0 to 64 {
affine.for %j = 0 to 256 {
affine.load %A[%i, %j] : memref<64x256xf32, (d0, d1) -> (d1, d0)>
}
}
dealloc %A : memref<64x256xf32, (d0, d1) -> (d1, d0)>
return
}
// The old memref alloc should disappear.
// CHECK-NOT: memref<64x256xf32>
// CHECK: [[MEM:%[0-9]+]] = alloc() : memref<256x64xf32>
// CHECK-NEXT: affine.for %[[I:arg[0-9]+]] = 0 to 64 {
// CHECK-NEXT: affine.for %[[J:arg[0-9]+]] = 0 to 256 {
// CHECK-NEXT: affine.load [[MEM]][%[[J]], %[[I]]] : memref<256x64xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: dealloc [[MEM]]
// CHECK-NEXT: return
// CHECK-LABEL: func @shift()
func @shift() {
// CHECK-NOT: memref<64xf32, (d0) -> (d0 + 1)>
%A = alloc() : memref<64xf32, (d0) -> (d0 + 1)>
affine.for %i = 0 to 64 {
affine.load %A[%i] : memref<64xf32, (d0) -> (d0 + 1)>
// CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}} + 1] : memref<65xf32>
}
return
}
// CHECK-LABEL: func @high_dim_permute()
func @high_dim_permute() {
// CHECK-NOT: memref<64x128x256xf32,
%A = alloc() : memref<64x128x256xf32, (d0, d1, d2) -> (d2, d0, d1)>
// CHECK: %[[I:arg[0-9]+]]
affine.for %i = 0 to 64 {
// CHECK: %[[J:arg[0-9]+]]
affine.for %j = 0 to 128 {
// CHECK: %[[K:arg[0-9]+]]
affine.for %k = 0 to 256 {
affine.load %A[%i, %j, %k] : memref<64x128x256xf32, (d0, d1, d2) -> (d2, d0, d1)>
// CHECK: %{{.*}} = affine.load %{{.*}}[%[[K]], %[[I]], %[[J]]] : memref<256x64x128xf32>
}
}
}
return
}
// CHECK-LABEL: func @invalid_map
func @invalid_map() {
%A = alloc() : memref<64x128xf32, (d0, d1) -> (d0, -d1 - 10)>
// CHECK: %{{.*}} = alloc() : memref<64x128xf32,
return
}
// A tiled layout.
// CHECK-LABEL: func @data_tiling()
func @data_tiling() {
%A = alloc() : memref<64x512xf32, (d0, d1) -> (d0 floordiv 8, d1 floordiv 16, d0 mod 8, d1 mod 16)>
// CHECK: %{{.*}} = alloc() : memref<8x32x8x16xf32>
return
}
// Memref escapes; no normalization.
// CHECK-LABEL: func @escaping() -> memref<64xf32, #map{{[0-9]+}}>
func @escaping() -> memref<64xf32, (d0) -> (d0 + 2)> {
// CHECK: %{{.*}} = alloc() : memref<64xf32, #map{{[0-9]+}}>
%A = alloc() : memref<64xf32, (d0) -> (d0 + 2)>
return %A : memref<64xf32, (d0) -> (d0 + 2)>
}