forked from OSchip/llvm-project
[MLIR] Expose makeCanonicalStridedLayoutExpr in StandardTypes.h.
Differential Revision: https://reviews.llvm.org/D75575
This commit is contained in:
parent
bdad0a1b79
commit
e0ce852277
|
@ -658,6 +658,20 @@ AffineMap makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, int64_t offset,
|
|||
/// `t` with simplified layout.
|
||||
MemRefType canonicalizeStridedLayout(MemRefType t);
|
||||
|
||||
/// Given MemRef `sizes` that are either static or dynamic, returns the
|
||||
/// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
|
||||
/// once a dynamic dimension is encountered, all canonical strides become
|
||||
/// dynamic and need to be encoded with a different symbol.
|
||||
/// For canonical strides expressions, the offset is always 0 and the fastest
|
||||
/// varying stride is always `1`.
|
||||
///
|
||||
/// Examples:
|
||||
/// - memref<3x4x5xf32> has canonical stride expression `20*d0 + 5*d1 + d2`.
|
||||
/// - memref<3x?x5xf32> has canonical stride expression `s0*d0 + 5*d1 + d2`.
|
||||
/// - memref<3x4x?xf32> has canonical stride expression `s1*d0 + s0*d1 + d2`.
|
||||
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
|
||||
MLIRContext *context);
|
||||
|
||||
/// Return true if the layout for `t` is compatible with strided semantics.
|
||||
bool isStrided(MemRefType t);
|
||||
|
||||
|
|
|
@ -456,49 +456,6 @@ UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
|
|||
return success();
|
||||
}
|
||||
|
||||
/// Given MemRef `sizes` that are either static or dynamic, returns the
|
||||
/// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
|
||||
/// once a dynamic dimension is encountered, all canonical strides become
|
||||
/// dynamic and need to be encoded with a different symbol.
|
||||
/// For canonical strides expressions, the offset is always 0 and and fastest
|
||||
/// varying stride is always `1`.
|
||||
///
|
||||
/// Examples:
|
||||
/// - memref<3x4x5xf32> has canonical stride expression `20*d0 + 5*d1 + d2`.
|
||||
/// - memref<3x?x5xf32> has canonical stride expression `s0*d0 + 5*d1 + d2`.
|
||||
/// - memref<3x4x?xf32> has canonical stride expression `s1*d0 + s0*d1 + d2`.
|
||||
static AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
|
||||
MLIRContext *context) {
|
||||
AffineExpr expr;
|
||||
bool dynamicPoisonBit = false;
|
||||
unsigned nSymbols = 0;
|
||||
int64_t runningSize = 1;
|
||||
unsigned rank = sizes.size();
|
||||
for (auto en : llvm::enumerate(llvm::reverse(sizes))) {
|
||||
auto size = en.value();
|
||||
auto position = rank - 1 - en.index();
|
||||
// Degenerate case, no size =-> no stride
|
||||
if (size == 0)
|
||||
continue;
|
||||
auto d = getAffineDimExpr(position, context);
|
||||
// Static case: stride = runningSize and runningSize *= size.
|
||||
if (!dynamicPoisonBit) {
|
||||
auto cst = getAffineConstantExpr(runningSize, context);
|
||||
expr = expr ? expr + cst * d : cst * d;
|
||||
if (size > 0)
|
||||
runningSize *= size;
|
||||
else
|
||||
// From now on bail into dynamic mode.
|
||||
dynamicPoisonBit = true;
|
||||
continue;
|
||||
}
|
||||
// Dynamic case, new symbol for each new stride.
|
||||
auto sym = getAffineSymbolExpr(nSymbols++, context);
|
||||
expr = expr ? expr + d * sym : d * sym;
|
||||
}
|
||||
return simplifyAffineExpr(expr, rank, nSymbols);
|
||||
}
|
||||
|
||||
// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
|
||||
// i.e. single term). Accumulate the AffineExpr into the existing one.
|
||||
static void extractStridesFromTerm(AffineExpr e,
|
||||
|
@ -766,6 +723,38 @@ MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
|
|||
return MemRefType::Builder(t).setAffineMaps({});
|
||||
}
|
||||
|
||||
/// Builds the canonical "contiguous" strides expression for `sizes`: the
/// offset is 0, the fastest-varying stride is 1, and each outer stride is the
/// product of the inner sizes. As soon as one dynamic size is met, all
/// slower-varying strides turn symbolic, each encoded with a distinct symbol.
AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                MLIRContext *context) {
  const unsigned rank = sizes.size();
  AffineExpr layout;          // Accumulated sum of per-dimension stride terms.
  int64_t runningProduct = 1; // Product of inner static sizes seen so far.
  unsigned symbolCount = 0;   // Number of symbols minted for dynamic strides.
  bool dynamicMode = false;   // Set once a dynamic size poisons the strides.
  // Iterate from the innermost dimension outwards.
  for (int pos = static_cast<int>(rank) - 1; pos >= 0; --pos) {
    const int64_t dimSize = sizes[pos];
    // Degenerate case: a size-0 dimension produces no stride term.
    if (dimSize == 0)
      continue;
    AffineExpr d = getAffineDimExpr(pos, context);
    AffineExpr term;
    if (!dynamicMode) {
      // Static case: the stride equals the running product of inner sizes.
      term = getAffineConstantExpr(runningProduct, context) * d;
      if (dimSize > 0)
        runningProduct *= dimSize;
      else
        // From now on bail into dynamic mode.
        dynamicMode = true;
    } else {
      // Dynamic case: a new symbol for each new stride.
      term = d * getAffineSymbolExpr(symbolCount++, context);
    }
    layout = layout ? layout + term : term;
  }
  return simplifyAffineExpr(layout, rank, symbolCount);
}
|
||||
|
||||
/// Return true if the layout for `t` is compatible with strided semantics.
|
||||
bool mlir::isStrided(MemRefType t) {
|
||||
int64_t offset;
|
||||
|
|
Loading…
Reference in New Issue