[MLIR] Expose makeCanonicalStridedLayoutExpr in StandardTypes.h.

Differential Revision: https://reviews.llvm.org/D75575
This commit is contained in:
Alexander Belyaev 2020-03-04 00:37:50 +01:00
parent bdad0a1b79
commit e0ce852277
2 changed files with 46 additions and 43 deletions

View File

@ -658,6 +658,20 @@ AffineMap makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, int64_t offset,
/// `t` with simplified layout.
MemRefType canonicalizeStridedLayout(MemRefType t);
/// Given MemRef `sizes` that are either static or dynamic, returns the
/// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
/// once a dynamic dimension is encountered, all canonical strides become
/// dynamic and need to be encoded with a different symbol.
/// For canonical strides expressions, the offset is always 0 and and fastest
/// varying stride is always `1`.
///
/// Examples:
/// - memref<3x4x5xf32> has canonical stride expression `20*d0 + 5*d1 + d2`.
/// - memref<3x?x5xf32> has canonical stride expression `s0*d0 + 5*d1 + d2`.
/// - memref<3x4x?xf32> has canonical stride expression `s1*d0 + s0*d1 + d2`.
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
MLIRContext *context);
/// Return true if the layout for `t` is compatible with strided semantics.
bool isStrided(MemRefType t);

View File

@ -456,49 +456,6 @@ UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
return success();
}
/// Given MemRef `sizes` that are either static or dynamic, returns the
/// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
/// once a dynamic dimension is encountered, all canonical strides become
/// dynamic and need to be encoded with a different symbol.
/// For canonical strides expressions, the offset is always 0 and and fastest
/// varying stride is always `1`.
///
/// Examples:
/// - memref<3x4x5xf32> has canonical stride expression `20*d0 + 5*d1 + d2`.
/// - memref<3x?x5xf32> has canonical stride expression `s0*d0 + 5*d1 + d2`.
/// - memref<3x4x?xf32> has canonical stride expression `s1*d0 + s0*d1 + d2`.
static AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
MLIRContext *context) {
AffineExpr expr;
bool dynamicPoisonBit = false;
unsigned nSymbols = 0;
int64_t runningSize = 1;
unsigned rank = sizes.size();
for (auto en : llvm::enumerate(llvm::reverse(sizes))) {
auto size = en.value();
auto position = rank - 1 - en.index();
// Degenerate case, no size =-> no stride
if (size == 0)
continue;
auto d = getAffineDimExpr(position, context);
// Static case: stride = runningSize and runningSize *= size.
if (!dynamicPoisonBit) {
auto cst = getAffineConstantExpr(runningSize, context);
expr = expr ? expr + cst * d : cst * d;
if (size > 0)
runningSize *= size;
else
// From now on bail into dynamic mode.
dynamicPoisonBit = true;
continue;
}
// Dynamic case, new symbol for each new stride.
auto sym = getAffineSymbolExpr(nSymbols++, context);
expr = expr ? expr + d * sym : d * sym;
}
return simplifyAffineExpr(expr, rank, nSymbols);
}
// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
// i.e. single term). Accumulate the AffineExpr into the existing one.
static void extractStridesFromTerm(AffineExpr e,
@ -766,6 +723,38 @@ MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
return MemRefType::Builder(t).setAffineMaps({});
}
AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
MLIRContext *context) {
AffineExpr expr;
bool dynamicPoisonBit = false;
unsigned nSymbols = 0;
int64_t runningSize = 1;
unsigned rank = sizes.size();
for (auto en : llvm::enumerate(llvm::reverse(sizes))) {
auto size = en.value();
auto position = rank - 1 - en.index();
// Degenerate case, no size =-> no stride
if (size == 0)
continue;
auto d = getAffineDimExpr(position, context);
// Static case: stride = runningSize and runningSize *= size.
if (!dynamicPoisonBit) {
auto cst = getAffineConstantExpr(runningSize, context);
expr = expr ? expr + cst * d : cst * d;
if (size > 0)
runningSize *= size;
else
// From now on bail into dynamic mode.
dynamicPoisonBit = true;
continue;
}
// Dynamic case, new symbol for each new stride.
auto sym = getAffineSymbolExpr(nSymbols++, context);
expr = expr ? expr + d * sym : d * sym;
}
return simplifyAffineExpr(expr, rank, nSymbols);
}
/// Return true if the layout for `t` is compatible with strided semantics.
bool mlir::isStrided(MemRefType t) {
int64_t offset;