forked from OSchip/llvm-project
Support symbolic operands for memref replacement; fix memrefNormalize
- allow symbols in index remapping provided for memref replacement - fix memref normalize crash on cases with layout maps with symbols Signed-off-by: Uday Bondhugula <uday@polymagelabs.com> Reported by: Alex Zinenko Closes tensorflow/mlir#139 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/139 from bondhugula:memref-rep-symbols 2f48c1fdb5d4c58915bbddbd9f07b18541819233 PiperOrigin-RevId: 269851182
This commit is contained in:
parent
1c73be76d8
commit
727a50ae2d
|
@ -162,6 +162,12 @@ def AllocOp : Std_Op<"alloc"> {
|
|||
unsigned getNumSymbolicOperands() {
|
||||
return getNumOperands() - getType().getNumDynamicDims();
|
||||
}
|
||||
|
||||
/// Returns the symbolic operands (the ones in square brackets), which bind
|
||||
/// to the symbols of the memref's layout map.
|
||||
operand_range getSymbolicOperands() {
|
||||
return {operand_begin() + getType().getNumDynamicDims(), operand_end()};
|
||||
}
|
||||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
|
|
|
@ -40,15 +40,15 @@ class OpBuilder;
|
|||
/// Replaces all "dereferencing" uses of `oldMemRef` with `newMemRef` while
|
||||
/// optionally remapping the old memref's indices using the supplied affine map,
|
||||
/// `indexRemap`. The new memref could be of a different shape or rank.
|
||||
/// `extraIndices` provides additional access indices to be added to the start.
|
||||
/// `extraIndices` provides any additional access indices to be added to the
|
||||
/// start.
|
||||
///
|
||||
/// `indexRemap` remaps indices of the old memref access to a new set of indices
|
||||
/// that are used to index the memref. Additional input operands to indexRemap
|
||||
/// can be optionally provided, and they are added at the start of its input
|
||||
/// list. `indexRemap` is expected to have only dimensional inputs, and the
|
||||
/// number of its inputs equal to extraOperands.size() plus rank of the memref.
|
||||
/// 'extraOperands' is an optional argument that corresponds to additional
|
||||
/// operands (inputs) for indexRemap at the beginning of its input list.
|
||||
/// can be optionally provided in `extraOperands`, and they occupy the start
|
||||
/// of its input list. `indexRemap`'s dimensional inputs are expected to
|
||||
/// correspond to memref's indices, and its symbolic inputs if any should be
|
||||
/// provided in `symbolOperands`.
|
||||
///
|
||||
/// `domInstFilter`, if non-null, restricts the replacement to only those
|
||||
/// operations that are dominated by the former; similarly, `postDomInstFilter`
|
||||
|
@ -70,6 +70,7 @@ LogicalResult replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
|
|||
ArrayRef<Value *> extraIndices = {},
|
||||
AffineMap indexRemap = AffineMap(),
|
||||
ArrayRef<Value *> extraOperands = {},
|
||||
ArrayRef<Value *> symbolOperands = {},
|
||||
Operation *domInstFilter = nullptr,
|
||||
Operation *postDomInstFilter = nullptr);
|
||||
|
||||
|
@ -79,7 +80,8 @@ LogicalResult replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
|
|||
Operation *op,
|
||||
ArrayRef<Value *> extraIndices = {},
|
||||
AffineMap indexRemap = AffineMap(),
|
||||
ArrayRef<Value *> extraOperands = {});
|
||||
ArrayRef<Value *> extraOperands = {},
|
||||
ArrayRef<Value *> symbolOperands = {});
|
||||
|
||||
/// Rewrites the memref defined by this alloc op to have an identity layout map
|
||||
/// and updates all its indexing uses. Returns failure if any of its uses
|
||||
|
|
|
@ -955,6 +955,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
|
|||
LogicalResult res =
|
||||
replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
|
||||
/*extraOperands=*/outerIVs,
|
||||
/*symbolOperands=*/{},
|
||||
/*domInstFilter=*/&*forOp.getBody()->begin());
|
||||
assert(succeeded(res) &&
|
||||
"replaceAllMemrefUsesWith should always succeed here");
|
||||
|
|
|
@ -122,6 +122,7 @@ static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) {
|
|||
/*extraIndices=*/{ivModTwoOp},
|
||||
/*indexRemap=*/AffineMap(),
|
||||
/*extraOperands=*/{},
|
||||
/*symbolOperands=*/{},
|
||||
/*domInstFilter=*/&*forOp.getBody()->begin()))) {
|
||||
LLVM_DEBUG(
|
||||
forOp.emitError("memref replacement for double buffering failed"));
|
||||
|
|
|
@ -1548,6 +1548,7 @@ static LogicalResult generateCopy(
|
|||
replaceAllMemRefUsesWith(memref, fastMemRef,
|
||||
/*extraIndices=*/{}, indexRemap,
|
||||
/*extraOperands=*/regionSymbols,
|
||||
/*symbolOperands=*/{},
|
||||
/*domInstFilter=*/&*begin,
|
||||
/*postDomInstFilter=*/&*postDomFilter);
|
||||
|
||||
|
|
|
@ -62,14 +62,17 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
|
|||
Operation *op,
|
||||
ArrayRef<Value *> extraIndices,
|
||||
AffineMap indexRemap,
|
||||
ArrayRef<Value *> extraOperands) {
|
||||
ArrayRef<Value *> extraOperands,
|
||||
ArrayRef<Value *> symbolOperands) {
|
||||
unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
|
||||
(void)newMemRefRank; // unused in opt mode
|
||||
unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
|
||||
(void)oldMemRefRank;
|
||||
(void)oldMemRefRank; // unused in opt mode
|
||||
if (indexRemap) {
|
||||
assert(indexRemap.getNumSymbols() == 0 && "pure dimensional map expected");
|
||||
assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank);
|
||||
assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
|
||||
"symbolic operand count mistmatch");
|
||||
assert(indexRemap.getNumInputs() ==
|
||||
extraOperands.size() + oldMemRefRank + symbolOperands.size());
|
||||
assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
|
||||
} else {
|
||||
assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
|
||||
|
@ -131,9 +134,11 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
|
|||
// provided. The indices of a memref come right after it, i.e.,
|
||||
// at position memRefOperandPos + 1.
|
||||
SmallVector<Value *, 4> remapOperands;
|
||||
remapOperands.reserve(extraOperands.size() + oldMemRefRank);
|
||||
remapOperands.reserve(extraOperands.size() + oldMemRefRank +
|
||||
symbolOperands.size());
|
||||
remapOperands.append(extraOperands.begin(), extraOperands.end());
|
||||
remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
|
||||
remapOperands.append(symbolOperands.begin(), symbolOperands.end());
|
||||
|
||||
SmallVector<Value *, 4> remapOutputs;
|
||||
remapOutputs.reserve(oldMemRefRank);
|
||||
|
@ -226,6 +231,7 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
|
|||
ArrayRef<Value *> extraIndices,
|
||||
AffineMap indexRemap,
|
||||
ArrayRef<Value *> extraOperands,
|
||||
ArrayRef<Value *> symbolOperands,
|
||||
Operation *domInstFilter,
|
||||
Operation *postDomInstFilter) {
|
||||
unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
|
||||
|
@ -233,8 +239,10 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
|
|||
unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
|
||||
(void)oldMemRefRank;
|
||||
if (indexRemap) {
|
||||
assert(indexRemap.getNumSymbols() == 0 && "pure dimensional map expected");
|
||||
assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank);
|
||||
assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
|
||||
"symbol operand count mismatch");
|
||||
assert(indexRemap.getNumInputs() ==
|
||||
extraOperands.size() + oldMemRefRank + symbolOperands.size());
|
||||
assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
|
||||
} else {
|
||||
assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
|
||||
|
@ -287,7 +295,8 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
|
|||
|
||||
for (auto *op : opsToReplace) {
|
||||
if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef, op, extraIndices,
|
||||
indexRemap, extraOperands)))
|
||||
indexRemap, extraOperands,
|
||||
symbolOperands)))
|
||||
llvm_unreachable("memref replacement guaranteed to succeed here");
|
||||
}
|
||||
|
||||
|
@ -446,6 +455,8 @@ LogicalResult mlir::normalizeMemRef(AllocOp allocOp) {
|
|||
}
|
||||
|
||||
auto *oldMemRef = allocOp.getResult();
|
||||
SmallVector<Value *, 4> symbolOperands(allocOp.getSymbolicOperands());
|
||||
|
||||
auto newMemRefType = b.getMemRefType(newShape, memrefType.getElementType(),
|
||||
b.getMultiDimIdentityMap(newRank));
|
||||
auto newAlloc = b.create<AllocOp>(allocOp.getLoc(), newMemRefType);
|
||||
|
@ -453,7 +464,9 @@ LogicalResult mlir::normalizeMemRef(AllocOp allocOp) {
|
|||
// Replace all uses of the old memref.
|
||||
if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
|
||||
/*extraIndices=*/{},
|
||||
/*indexRemap=*/layoutMap))) {
|
||||
/*indexRemap=*/layoutMap,
|
||||
/*extraOperands=*/{},
|
||||
/*symbolOperands=*/symbolOperands))) {
|
||||
// If it failed (due to escapes for example), bail out.
|
||||
newAlloc.erase();
|
||||
return failure();
|
||||
|
|
|
@ -96,6 +96,21 @@ func @strided_cumulative() {
|
|||
return
|
||||
}
|
||||
|
||||
// Symbolic operand for alloc, although unused. Tests replaceAllMemRefUsesWith
|
||||
// when the index remap has symbols.
|
||||
// CHECK-LABEL: func @symbolic_operands
|
||||
func @symbolic_operands(%s : index) {
|
||||
// CHECK: alloc() : memref<100xf32>
|
||||
%A = alloc()[%s] : memref<10x10xf32, (d0,d1)[s0] -> (10*d0 + d1)>
|
||||
affine.for %i = 0 to 10 {
|
||||
affine.for %j = 0 to 10 {
|
||||
// CHECK: affine.load %{{.*}}[%{{.*}} * 10 + %{{.*}}] : memref<100xf32>
|
||||
affine.load %A[%i, %j] : memref<10x10xf32, (d0,d1)[s0] -> (10*d0 + d1)>
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Memref escapes; no normalization.
|
||||
// CHECK-LABEL: func @escaping() -> memref<64xf32, #map{{[0-9]+}}>
|
||||
func @escaping() -> memref<64xf32, (d0) -> (d0 + 2)> {
|
||||
|
|
Loading…
Reference in New Issue