forked from OSchip/llvm-project
Extract MemRefType::getStridesAndOffset as a free function and fix dynamic offset determination.
This also adds coverage with a missing test, which uncovered a bug in the conditional for testing whether an offset is dynamic or not. PiperOrigin-RevId: 272505798
This commit is contained in:
parent
f294e0e513
commit
9604bb6269
|
@ -367,31 +367,8 @@ public:
|
|||
/// Returns the memory space in which data referred to by this memref resides.
|
||||
unsigned getMemorySpace() const;
|
||||
|
||||
/// Returns the strides of the MemRef if the layout map is in strided form.
|
||||
/// MemRefs with layout maps in strided form include:
|
||||
/// 1. empty or identity layout map, in which case the stride information is
|
||||
/// the canonical form computed from sizes;
|
||||
/// 2. single affine map layout of the form `K + k0 * d0 + ... kn * dn`,
|
||||
/// where K and ki's are constants or symbols.
|
||||
///
|
||||
/// A stride specification is a list of integer values that are either static
|
||||
/// or dynamic (encoded with kDynamicStrideOrOffset). Strides encode the
|
||||
/// distance in the number of elements between successive entries along a
|
||||
/// particular dimension. For example, `memref<42x16xf32, (64 * d0 + d1)>`
|
||||
/// specifies a view into a non-contiguous memory region of `42` by `16` `f32`
|
||||
/// elements in which the distance between two consecutive elements along the
|
||||
/// outer dimension is `1` and the distance between two consecutive elements
|
||||
/// along the inner dimension is `64`.
|
||||
///
|
||||
/// If a simple strided form cannot be extracted from the composition of the
|
||||
/// layout map, returns llvm::None.
|
||||
///
|
||||
/// The convention is that the strides for dimensions d0, .. dn appear in
|
||||
/// order to make indexing intuitive into the result.
|
||||
static constexpr int64_t kDynamicStrideOrOffset =
|
||||
std::numeric_limits<int64_t>::min();
|
||||
LogicalResult getStridesAndOffset(SmallVectorImpl<int64_t> &strides,
|
||||
int64_t &offset) const;
|
||||
|
||||
static bool kindof(unsigned kind) { return kind == StandardTypes::MemRef; }
|
||||
|
||||
|
@ -492,6 +469,31 @@ public:
|
|||
static bool kindof(unsigned kind) { return kind == StandardTypes::None; }
|
||||
};
|
||||
|
||||
/// Returns the strides of the MemRef if the layout map is in strided form.
|
||||
/// MemRefs with layout maps in strided form include:
|
||||
/// 1. empty or identity layout map, in which case the stride information is
|
||||
/// the canonical form computed from sizes;
|
||||
/// 2. single affine map layout of the form `K + k0 * d0 + ... kn * dn`,
|
||||
/// where K and ki's are constants or symbols.
|
||||
///
|
||||
/// A stride specification is a list of integer values that are either static
|
||||
/// or dynamic (encoded with kDynamicStrideOrOffset). Strides encode the
|
||||
/// distance in the number of elements between successive entries along a
|
||||
/// particular dimension. For example, `memref<42x16xf32, (64 * d0 + d1)>`
|
||||
/// specifies a view into a non-contiguous memory region of `42` by `16` `f32`
|
||||
/// elements in which the distance between two consecutive elements along the
|
||||
/// outer dimension is `1` and the distance between two consecutive elements
|
||||
/// along the inner dimension is `64`.
|
||||
///
|
||||
/// If a simple strided form cannot be extracted from the composition of the
|
||||
/// layout map, returns llvm::None.
|
||||
///
|
||||
/// The convention is that the strides for dimensions d0, .. dn appear in
|
||||
/// order to make indexing intuitive into the result.
|
||||
LogicalResult getStridesAndOffset(MemRefType t,
|
||||
SmallVectorImpl<int64_t> &strides,
|
||||
int64_t &offset);
|
||||
|
||||
/// Given a list of strides (in which MemRefType::kDynamicStrideOrOffset
|
||||
/// represents a dynamic value), return the single result AffineMap which
|
||||
/// represents the linearized strided layout map. Dimensions correspond to the
|
||||
|
|
|
@ -152,7 +152,7 @@ static unsigned kStridePosInMemRefDescriptor = 3;
|
|||
Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
|
||||
int64_t offset;
|
||||
SmallVector<int64_t, 4> strides;
|
||||
bool strideSuccess = succeeded(type.getStridesAndOffset(strides, offset));
|
||||
bool strideSuccess = succeeded(getStridesAndOffset(type, strides, offset));
|
||||
assert(strideSuccess &&
|
||||
"Non-strided layout maps must have been normalized away");
|
||||
(void)strideSuccess;
|
||||
|
@ -571,14 +571,14 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
|||
|
||||
int64_t offset;
|
||||
SmallVector<int64_t, 4> strides;
|
||||
auto successStrides = type.getStridesAndOffset(strides, offset);
|
||||
auto successStrides = getStridesAndOffset(type, strides, offset);
|
||||
if (failed(successStrides))
|
||||
return matchFailure();
|
||||
|
||||
// Dynamic strides are ok if they can be deduced from dynamic sizes (which
|
||||
// is guaranteed when succeeded(successStrides)).
|
||||
// Dynamic offset however can never be alloc'ed.
|
||||
if (offset != MemRefType::kDynamicStrideOrOffset)
|
||||
// is guaranteed when succeeded(successStrides)). Dynamic offset however can
|
||||
// never be alloc'ed.
|
||||
if (offset == MemRefType::kDynamicStrideOrOffset)
|
||||
return matchFailure();
|
||||
|
||||
return matchSuccess();
|
||||
|
@ -652,7 +652,7 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
|||
|
||||
int64_t offset;
|
||||
SmallVector<int64_t, 4> strides;
|
||||
auto successStrides = type.getStridesAndOffset(strides, offset);
|
||||
auto successStrides = getStridesAndOffset(type, strides, offset);
|
||||
assert(succeeded(successStrides) && "unexpected non-strided memref");
|
||||
(void)successStrides;
|
||||
assert(offset != MemRefType::kDynamicStrideOrOffset &&
|
||||
|
@ -952,7 +952,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
|
|||
auto ptrType = getMemRefElementPtrType(type, this->lowering);
|
||||
int64_t offset;
|
||||
SmallVector<int64_t, 4> strides;
|
||||
auto successStrides = type.getStridesAndOffset(strides, offset);
|
||||
auto successStrides = getStridesAndOffset(type, strides, offset);
|
||||
assert(succeeded(successStrides) && "unexpected non-strided memref");
|
||||
(void)successStrides;
|
||||
return getStridedElementPtr(loc, ptrType, memRefDesc, indices, strides,
|
||||
|
|
|
@ -322,7 +322,7 @@ void mlir::linalg::SliceOp::build(Builder *b, OperationState &result,
|
|||
auto memRefType = base->getType().cast<MemRefType>();
|
||||
int64_t offset;
|
||||
SmallVector<int64_t, 4> strides;
|
||||
auto res = memRefType.getStridesAndOffset(strides, offset);
|
||||
auto res = getStridesAndOffset(memRefType, strides, offset);
|
||||
assert(succeeded(res) && strides.size() == indexings.size());
|
||||
(void)res;
|
||||
|
||||
|
@ -466,7 +466,7 @@ void mlir::linalg::TransposeOp::build(Builder *b, OperationState &result,
|
|||
// Compute permuted strides.
|
||||
int64_t offset;
|
||||
SmallVector<int64_t, 4> strides;
|
||||
auto res = memRefType.getStridesAndOffset(strides, offset);
|
||||
auto res = getStridesAndOffset(memRefType, strides, offset);
|
||||
assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
|
||||
(void)res;
|
||||
auto map = makeStridedLinearLayoutMap(strides, offset, b->getContext());
|
||||
|
|
|
@ -23,7 +23,6 @@
|
|||
#include "mlir/Support/STLExtras.h"
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::detail;
|
||||
|
@ -544,9 +543,10 @@ static void extractStridesFromTerm(AffineExpr e,
|
|||
llvm_unreachable("unexpected binary operation");
|
||||
}
|
||||
|
||||
LogicalResult MemRefType::getStridesAndOffset(SmallVectorImpl<int64_t> &strides,
|
||||
int64_t &offset) const {
|
||||
auto affineMaps = getAffineMaps();
|
||||
LogicalResult mlir::getStridesAndOffset(MemRefType t,
|
||||
SmallVectorImpl<int64_t> &strides,
|
||||
int64_t &offset) {
|
||||
auto affineMaps = t.getAffineMaps();
|
||||
// For now strides are only computed on a single affine map with a single
|
||||
// result (i.e. the closed subset of linearization maps that are compatible
|
||||
// with striding semantics).
|
||||
|
@ -555,12 +555,12 @@ LogicalResult MemRefType::getStridesAndOffset(SmallVectorImpl<int64_t> &strides,
|
|||
return failure();
|
||||
AffineExpr stridedExpr;
|
||||
if (affineMaps.empty() || affineMaps[0].isIdentity()) {
|
||||
if (getRank() == 0) {
|
||||
if (t.getRank() == 0) {
|
||||
// Handle 0-D corner case.
|
||||
offset = 0;
|
||||
return success();
|
||||
}
|
||||
stridedExpr = makeCanonicalStridedLayoutExpr(getShape(), getContext());
|
||||
stridedExpr = makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
|
||||
} else if (affineMaps[0].getNumResults() == 1) {
|
||||
stridedExpr = affineMaps[0].getResult(0);
|
||||
}
|
||||
|
@ -568,9 +568,9 @@ LogicalResult MemRefType::getStridesAndOffset(SmallVectorImpl<int64_t> &strides,
|
|||
return failure();
|
||||
|
||||
bool failed = false;
|
||||
strides = SmallVector<int64_t, 4>(getRank(), 0);
|
||||
strides = SmallVector<int64_t, 4>(t.getRank(), 0);
|
||||
bool seenOffset = false;
|
||||
SmallVector<bool, 4> seen(getRank(), false);
|
||||
SmallVector<bool, 4> seen(t.getRank(), false);
|
||||
if (stridedExpr.isa<AffineBinaryOpExpr>()) {
|
||||
stridedExpr.walk([&](AffineExpr e) {
|
||||
if (!failed)
|
||||
|
@ -688,6 +688,6 @@ AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
|
|||
bool mlir::isStrided(MemRefType t) {
|
||||
int64_t offset;
|
||||
SmallVector<int64_t, 4> stridesAndOffset;
|
||||
auto res = t.getStridesAndOffset(stridesAndOffset, offset);
|
||||
auto res = getStridesAndOffset(t, stridesAndOffset, offset);
|
||||
return succeeded(res);
|
||||
}
|
||||
|
|
|
@ -11,3 +11,9 @@ func @address_space(%arg0 : memref<32xf32, (d0) -> (d0), 7>) {
|
|||
std.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @strided_memref(
|
||||
func @strided_memref(%ind: index) {
|
||||
%0 = alloc()[%ind] : memref<32x64xf32, (i, j)[M] -> (32 + M * i + j)>
|
||||
std.return
|
||||
}
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ void TestMemRefStrideCalculation::runOnFunction() {
|
|||
auto memrefType = allocOp.getResult()->getType().cast<MemRefType>();
|
||||
int64_t offset;
|
||||
SmallVector<int64_t, 4> strides;
|
||||
if (failed(memrefType.getStridesAndOffset(strides, offset))) {
|
||||
if (failed(getStridesAndOffset(memrefType, strides, offset))) {
|
||||
llvm::outs() << "MemRefType " << memrefType << " cannot be converted to "
|
||||
<< "strided form\n";
|
||||
return;
|
||||
|
|
Loading…
Reference in New Issue