diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td index 426ec656b0e7..7de48c07e447 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -162,6 +162,12 @@ def AllocOp : Std_Op<"alloc"> { unsigned getNumSymbolicOperands() { return getNumOperands() - getType().getNumDynamicDims(); } + + /// Returns the symbolic operands (the ones in square brackets), which bind + /// to the symbols of the memref's layout map. + operand_range getSymbolicOperands() { + return {operand_begin() + getType().getNumDynamicDims(), operand_end()}; + } }]; let hasCanonicalizer = 1; diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h index 0644bc8064fc..c682b48f331c 100644 --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -40,15 +40,15 @@ class OpBuilder; /// Replaces all "dereferencing" uses of `oldMemRef` with `newMemRef` while /// optionally remapping the old memref's indices using the supplied affine map, /// `indexRemap`. The new memref could be of a different shape or rank. -/// `extraIndices` provides additional access indices to be added to the start. +/// `extraIndices` provides any additional access indices to be added to the +/// start. /// /// `indexRemap` remaps indices of the old memref access to a new set of indices /// that are used to index the memref. Additional input operands to indexRemap -/// can be optionally provided, and they are added at the start of its input -/// list. `indexRemap` is expected to have only dimensional inputs, and the -/// number of its inputs equal to extraOperands.size() plus rank of the memref. -/// 'extraOperands' is an optional argument that corresponds to additional -/// operands (inputs) for indexRemap at the beginning of its input list. +/// can be optionally provided in `extraOperands`, and they occupy the start +/// of its input list. `indexRemap`'s dimensional inputs are expected to +/// correspond to memref's indices, and its symbolic inputs if any should be +/// provided in `symbolOperands`. /// /// `domInstFilter`, if non-null, restricts the replacement to only those /// operations that are dominated by the former; similarly, `postDomInstFilter` @@ -70,6 +70,7 @@ LogicalResult replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, ArrayRef extraIndices = {}, AffineMap indexRemap = AffineMap(), ArrayRef extraOperands = {}, + ArrayRef symbolOperands = {}, Operation *domInstFilter = nullptr, Operation *postDomInstFilter = nullptr); @@ -79,7 +80,8 @@ LogicalResult replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, Operation *op, ArrayRef extraIndices = {}, AffineMap indexRemap = AffineMap(), - ArrayRef extraOperands = {}); + ArrayRef extraOperands = {}, + ArrayRef symbolOperands = {}); /// Rewrites the memref defined by this alloc op to have an identity layout map /// and updates all its indexing uses. Returns failure if any of its uses diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 8257bf05f5d4..188165b94e1d 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -955,6 +955,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, LogicalResult res = replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap, /*extraOperands=*/outerIVs, + /*symbolOperands=*/{}, /*domInstFilter=*/&*forOp.getBody()->begin()); assert(succeeded(res) && "replaceAllMemrefUsesWith should always succeed here"); diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index d8d8dba96207..b4d67262c17d 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -122,6 +122,7 @@ static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) { /*extraIndices=*/{ivModTwoOp}, /*indexRemap=*/AffineMap(), /*extraOperands=*/{}, + /*symbolOperands=*/{}, /*domInstFilter=*/&*forOp.getBody()->begin()))) { LLVM_DEBUG( forOp.emitError("memref replacement for double buffering failed")); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index e038512c0c0a..0c9a666a6ec1 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -1548,6 +1548,7 @@ static LogicalResult generateCopy( replaceAllMemRefUsesWith(memref, fastMemRef, /*extraIndices=*/{}, indexRemap, /*extraOperands=*/regionSymbols, + /*symbolOperands=*/{}, /*domInstFilter=*/&*begin, /*postDomInstFilter=*/&*postDomFilter); diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index e57d40e5a1c1..d6400ac50ed4 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -62,14 +62,17 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, Operation *op, ArrayRef extraIndices, AffineMap indexRemap, - ArrayRef extraOperands) { + ArrayRef extraOperands, + ArrayRef symbolOperands) { unsigned newMemRefRank = newMemRef->getType().cast().getRank(); (void)newMemRefRank; // unused in opt mode unsigned oldMemRefRank = oldMemRef->getType().cast().getRank(); - (void)oldMemRefRank; + (void)oldMemRefRank; // unused in opt mode if (indexRemap) { - assert(indexRemap.getNumSymbols() == 0 && "pure dimensional map expected"); - assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank); + assert(indexRemap.getNumSymbols() == symbolOperands.size() && + "symbolic operand count mistmatch"); + assert(indexRemap.getNumInputs() == + extraOperands.size() + oldMemRefRank + symbolOperands.size()); assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank); } else { assert(oldMemRefRank + extraIndices.size() == newMemRefRank); @@ -131,9 +134,11 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, // provided. The indices of a memref come right after it, i.e., // at position memRefOperandPos + 1. SmallVector remapOperands; - remapOperands.reserve(extraOperands.size() + oldMemRefRank); + remapOperands.reserve(extraOperands.size() + oldMemRefRank + + symbolOperands.size()); remapOperands.append(extraOperands.begin(), extraOperands.end()); remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end()); + remapOperands.append(symbolOperands.begin(), symbolOperands.end()); SmallVector remapOutputs; remapOutputs.reserve(oldMemRefRank); @@ -226,6 +231,7 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, ArrayRef extraIndices, AffineMap indexRemap, ArrayRef extraOperands, + ArrayRef symbolOperands, Operation *domInstFilter, Operation *postDomInstFilter) { unsigned newMemRefRank = newMemRef->getType().cast().getRank(); @@ -233,8 +239,10 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, unsigned oldMemRefRank = oldMemRef->getType().cast().getRank(); (void)oldMemRefRank; if (indexRemap) { - assert(indexRemap.getNumSymbols() == 0 && "pure dimensional map expected"); - assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank); + assert(indexRemap.getNumSymbols() == symbolOperands.size() && + "symbol operand count mismatch"); + assert(indexRemap.getNumInputs() == + extraOperands.size() + oldMemRefRank + symbolOperands.size()); assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank); } else { assert(oldMemRefRank + extraIndices.size() == newMemRefRank); @@ -287,7 +295,8 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, for (auto *op : opsToReplace) { if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef, op, extraIndices, - indexRemap, extraOperands))) + indexRemap, extraOperands, + symbolOperands))) llvm_unreachable("memref replacement guaranteed to succeed here"); } @@ -446,6 +455,8 @@ LogicalResult mlir::normalizeMemRef(AllocOp allocOp) { } auto *oldMemRef = allocOp.getResult(); + SmallVector symbolOperands(allocOp.getSymbolicOperands()); + auto newMemRefType = b.getMemRefType(newShape, memrefType.getElementType(), b.getMultiDimIdentityMap(newRank)); auto newAlloc = b.create(allocOp.getLoc(), newMemRefType); @@ -453,7 +464,9 @@ LogicalResult mlir::normalizeMemRef(AllocOp allocOp) { // Replace all uses of the old memref. if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc, /*extraIndices=*/{}, - /*indexRemap=*/layoutMap))) { + /*indexRemap=*/layoutMap, + /*extraOperands=*/{}, + /*symbolOperands=*/symbolOperands))) { // If it failed (due to escapes for example), bail out. newAlloc.erase(); return failure(); diff --git a/mlir/test/Transforms/memref-normalize.mlir b/mlir/test/Transforms/memref-normalize.mlir index c4973e8eceea..e9b63624120d 100644 --- a/mlir/test/Transforms/memref-normalize.mlir +++ b/mlir/test/Transforms/memref-normalize.mlir @@ -96,6 +96,21 @@ func @strided_cumulative() { return } +// Symbolic operand for alloc, although unused. Tests replaceAllMemRefUsesWith +// when the index remap has symbols. +// CHECK-LABEL: func @symbolic_operands +func @symbolic_operands(%s : index) { + // CHECK: alloc() : memref<100xf32> + %A = alloc()[%s] : memref<10x10xf32, (d0,d1)[s0] -> (10*d0 + d1)> + affine.for %i = 0 to 10 { + affine.for %j = 0 to 10 { + // CHECK: affine.load %{{.*}}[%{{.*}} * 10 + %{{.*}}] : memref<100xf32> + affine.load %A[%i, %j] : memref<10x10xf32, (d0,d1)[s0] -> (10*d0 + d1)> + } + } + return +} + // Memref escapes; no normalization. // CHECK-LABEL: func @escaping() -> memref<64xf32, #map{{[0-9]+}}> func @escaping() -> memref<64xf32, (d0) -> (d0 + 2)> {