DynamicMemRefType: iteration and access by indices

The methods to perform such operations have been implemented for the DynamicMemRefType in a way that is similar to the implementation for StridedMemRefType. Up until here one could pass an unranked memref to the library, and thus obtain a “dynamic” memref descriptor, but then there would have been no possibility to operate on its content.

Differential Revision: https://reviews.llvm.org/D131359
This commit is contained in:
Michele Scuttari 2022-08-15 17:17:54 +02:00
parent 7a73ab5818
commit b8ecf32f81
No known key found for this signature in database
GPG Key ID: E79E7BDFEE4B62D4
3 changed files with 242 additions and 16 deletions

View File

@ -36,6 +36,7 @@
#include <cassert>
#include <cstdint>
#include <initializer_list>
#include <vector>
//===----------------------------------------------------------------------===//
// Codegen-compatible structures for Vector type.
@ -209,13 +210,19 @@ struct StridedMemRefType<T, 0> {
template <typename T, int Rank>
class StridedMemrefIterator {
public:
using iterator_category = std::forward_iterator_tag;
using value_type = T;
using difference_type = std::ptrdiff_t;
using pointer = T *;
using reference = T &;
StridedMemrefIterator(StridedMemRefType<T, Rank> &descriptor,
int64_t offset = 0)
: offset(offset), descriptor(descriptor) {}
: offset(offset), descriptor(&descriptor) {}
StridedMemrefIterator<T, Rank> &operator++() {
int dim = Rank - 1;
while (dim >= 0 && indices[dim] == (descriptor.sizes[dim] - 1)) {
offset -= indices[dim] * descriptor.strides[dim];
while (dim >= 0 && indices[dim] == (descriptor->sizes[dim] - 1)) {
offset -= indices[dim] * descriptor->strides[dim];
indices[dim] = 0;
--dim;
}
@ -224,17 +231,17 @@ public:
return *this;
}
++indices[dim];
offset += descriptor.strides[dim];
offset += descriptor->strides[dim];
return *this;
}
T &operator*() { return descriptor.data[offset]; }
T *operator->() { return &descriptor.data[offset]; }
reference operator*() { return descriptor->data[offset]; }
pointer operator->() { return &descriptor->data[offset]; }
const std::array<int64_t, Rank> &getIndices() { return indices; }
bool operator==(const StridedMemrefIterator &other) const {
return other.offset == offset && &other.descriptor == &descriptor;
return other.offset == offset && other.descriptor == descriptor;
}
bool operator!=(const StridedMemrefIterator &other) const {
@ -245,16 +252,24 @@ private:
/// Offset in the buffer. This can be derived from the indices and the
/// descriptor.
int64_t offset = 0;
/// Array of indices in the multi-dimensional memref.
std::array<int64_t, Rank> indices = {};
/// Descriptor for the strided memref.
StridedMemRefType<T, Rank> &descriptor;
StridedMemRefType<T, Rank> *descriptor;
};
/// Iterate over all elements in a 0-ranked strided memref.
template <typename T>
class StridedMemrefIterator<T, 0> {
public:
using iterator_category = std::forward_iterator_tag;
using value_type = T;
using difference_type = std::ptrdiff_t;
using pointer = T *;
using reference = T &;
StridedMemrefIterator(StridedMemRefType<T, 0> &descriptor, int64_t offset = 0)
: elt(descriptor.data + offset) {}
@ -263,8 +278,8 @@ public:
return *this;
}
T &operator*() { return *elt; }
T *operator->() { return elt; }
reference operator*() { return *elt; }
pointer operator->() { return elt; }
// There are no indices for a 0-ranked memref, but this API is provided for
// consistency with the general case.
@ -301,10 +316,20 @@ struct UnrankedMemRefType {
//===----------------------------------------------------------------------===//
// DynamicMemRefType type.
//===----------------------------------------------------------------------===//
template <typename T>
class DynamicMemRefIterator;
// A reference to one of the StridedMemRef types.
template <typename T>
class DynamicMemRefType {
public:
int64_t rank;
T *basePtr;
T *data;
int64_t offset;
const int64_t *sizes;
const int64_t *strides;
explicit DynamicMemRefType(const StridedMemRefType<T, 0> &memRef)
: rank(0), basePtr(memRef.basePtr), data(memRef.data),
offset(memRef.offset), sizes(nullptr), strides(nullptr) {}
@ -322,12 +347,113 @@ public:
strides = sizes + rank;
}
int64_t rank;
T *basePtr;
T *data;
int64_t offset;
const int64_t *sizes;
const int64_t *strides;
template <typename Range,
typename sfinae = decltype(std::declval<Range>().begin())>
T &operator[](Range &&indices) {
assert(indices.size() == rank &&
"indices should match rank in memref subscript");
if (rank == 0)
return data[offset];
int64_t curOffset = offset;
for (int dim = rank - 1; dim >= 0; --dim) {
int64_t currentIndex = *(indices.begin() + dim);
assert(currentIndex < sizes[dim] && "Index overflow");
curOffset += currentIndex * strides[dim];
}
return data[curOffset];
}
DynamicMemRefIterator<T> begin() { return {*this}; }
DynamicMemRefIterator<T> end() { return {*this, -1}; }
// This operator[] is extremely slow and only for sugaring purposes.
DynamicMemRefType<T> operator[](int64_t idx) {
assert(rank > 0 && "can't make a subscript of a zero ranked array");
DynamicMemRefType<T> res(*this);
--res.rank;
res.offset += idx * res.strides[0];
++res.sizes;
++res.strides;
return res;
}
// This operator* can be used in conjunction with the previous operator[] in
// order to access the underlying value in case of zero-ranked memref.
T &operator*() {
assert(rank == 0 && "not a zero-ranked memRef");
return data[offset];
}
private:
DynamicMemRefType(const DynamicMemRefType<T> &other)
: rank(other.rank), basePtr(other.basePtr), data(other.data),
offset(other.offset), strides(other.strides) {}
};
/// Iterate over all elements in a dynamic memref.
template <typename T>
class DynamicMemRefIterator {
public:
using iterator_category = std::forward_iterator_tag;
using value_type = T;
using difference_type = std::ptrdiff_t;
using pointer = T *;
using reference = T &;
DynamicMemRefIterator(DynamicMemRefType<T> &descriptor, int64_t offset = 0)
: offset(offset), descriptor(&descriptor) {
indices.resize(descriptor.rank, 0);
}
DynamicMemRefIterator<T> &operator++() {
if (descriptor->rank == 0) {
offset = -1;
return *this;
}
int dim = descriptor->rank - 1;
while (dim >= 0 && indices[dim] == (descriptor->sizes[dim] - 1)) {
offset -= indices[dim] * descriptor->strides[dim];
indices[dim] = 0;
--dim;
}
if (dim < 0) {
offset = -1;
return *this;
}
++indices[dim];
offset += descriptor->strides[dim];
return *this;
}
reference operator*() { return descriptor->data[offset]; }
pointer operator->() { return &descriptor->data[offset]; }
const std::vector<int64_t> &getIndices() { return indices; }
bool operator==(const DynamicMemRefIterator &other) const {
return other.offset == offset && other.descriptor == descriptor;
}
bool operator!=(const DynamicMemRefIterator &other) const {
return !(*this == other);
}
private:
/// Offset in the buffer. This can be derived from the indices and the
/// descriptor.
int64_t offset = 0;
/// Array of indices in the multi-dimensional memref.
std::vector<int64_t> indices = {};
/// Descriptor for the dynamic memref.
DynamicMemRefType<T> *descriptor;
};
//===----------------------------------------------------------------------===//

View File

@ -1,4 +1,5 @@
add_mlir_unittest(MLIRExecutionEngineTests
DynamicMemRef.cpp
Invoke.cpp
)
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)

View File

@ -0,0 +1,99 @@
//===- DynamicMemRef.cpp ----------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "llvm/ADT/SmallVector.h"
#include "gmock/gmock.h"
using namespace ::mlir;
using namespace ::testing;
TEST(DynamicMemRef, rankZero) {
int data = 57;
StridedMemRefType<int, 0> memRef;
memRef.basePtr = &data;
memRef.data = &data;
memRef.offset = 0;
DynamicMemRefType<int> dynamicMemRef(memRef);
llvm::SmallVector<int, 1> values(dynamicMemRef.begin(), dynamicMemRef.end());
EXPECT_THAT(values, ElementsAre(57));
}
TEST(DynamicMemRef, rankOne) {
std::array<int, 3> data;
for (size_t i = 0; i < data.size(); ++i) {
data[i] = i;
}
StridedMemRefType<int, 1> memRef;
memRef.basePtr = data.data();
memRef.data = data.data();
memRef.offset = 0;
memRef.sizes[0] = 3;
memRef.strides[0] = 1;
DynamicMemRefType<int> dynamicMemRef(memRef);
llvm::SmallVector<int, 3> values(dynamicMemRef.begin(), dynamicMemRef.end());
EXPECT_THAT(values, ElementsAreArray(data));
for (int64_t i = 0; i < 3; ++i) {
EXPECT_EQ(*dynamicMemRef[i], data[i]);
}
}
TEST(DynamicMemRef, rankTwo) {
std::array<int, 6> data;
for (size_t i = 0; i < data.size(); ++i) {
data[i] = i;
}
StridedMemRefType<int, 2> memRef;
memRef.basePtr = data.data();
memRef.data = data.data();
memRef.offset = 0;
memRef.sizes[0] = 2;
memRef.sizes[1] = 3;
memRef.strides[0] = 3;
memRef.strides[1] = 1;
DynamicMemRefType<int> dynamicMemRef(memRef);
llvm::SmallVector<int, 6> values(dynamicMemRef.begin(), dynamicMemRef.end());
EXPECT_THAT(values, ElementsAreArray(data));
}
TEST(DynamicMemRef, rankThree) {
std::array<int, 24> data;
for (size_t i = 0; i < data.size(); ++i) {
data[i] = i;
}
StridedMemRefType<int, 3> memRef;
memRef.basePtr = data.data();
memRef.data = data.data();
memRef.offset = 0;
memRef.sizes[0] = 2;
memRef.sizes[1] = 3;
memRef.sizes[2] = 4;
memRef.strides[0] = 12;
memRef.strides[1] = 4;
memRef.strides[2] = 1;
DynamicMemRefType<int> dynamicMemRef(memRef);
llvm::SmallVector<int, 24> values(dynamicMemRef.begin(), dynamicMemRef.end());
EXPECT_THAT(values, ElementsAreArray(data));
}