forked from OSchip/llvm-project
DynamicMemRefType: iteration and access by indices
The methods to perform such operations have been implemented for the DynamicMemRefType in a way that is similar to the implementation for StridedMemRefType. Up until here one could pass an unranked memref to the library, and thus obtain a “dynamic” memref descriptor, but then there would have been no possibility to operate on its content. Differential Revision: https://reviews.llvm.org/D131359
This commit is contained in:
parent
7a73ab5818
commit
b8ecf32f81
|
@ -36,6 +36,7 @@
|
|||
#include <cassert>
|
||||
#include <cstdint>
|
||||
#include <initializer_list>
|
||||
#include <vector>
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Codegen-compatible structures for Vector type.
|
||||
|
@ -209,13 +210,19 @@ struct StridedMemRefType<T, 0> {
|
|||
template <typename T, int Rank>
|
||||
class StridedMemrefIterator {
|
||||
public:
|
||||
using iterator_category = std::forward_iterator_tag;
|
||||
using value_type = T;
|
||||
using difference_type = std::ptrdiff_t;
|
||||
using pointer = T *;
|
||||
using reference = T &;
|
||||
|
||||
StridedMemrefIterator(StridedMemRefType<T, Rank> &descriptor,
|
||||
int64_t offset = 0)
|
||||
: offset(offset), descriptor(descriptor) {}
|
||||
: offset(offset), descriptor(&descriptor) {}
|
||||
StridedMemrefIterator<T, Rank> &operator++() {
|
||||
int dim = Rank - 1;
|
||||
while (dim >= 0 && indices[dim] == (descriptor.sizes[dim] - 1)) {
|
||||
offset -= indices[dim] * descriptor.strides[dim];
|
||||
while (dim >= 0 && indices[dim] == (descriptor->sizes[dim] - 1)) {
|
||||
offset -= indices[dim] * descriptor->strides[dim];
|
||||
indices[dim] = 0;
|
||||
--dim;
|
||||
}
|
||||
|
@ -224,17 +231,17 @@ public:
|
|||
return *this;
|
||||
}
|
||||
++indices[dim];
|
||||
offset += descriptor.strides[dim];
|
||||
offset += descriptor->strides[dim];
|
||||
return *this;
|
||||
}
|
||||
|
||||
T &operator*() { return descriptor.data[offset]; }
|
||||
T *operator->() { return &descriptor.data[offset]; }
|
||||
reference operator*() { return descriptor->data[offset]; }
|
||||
pointer operator->() { return &descriptor->data[offset]; }
|
||||
|
||||
const std::array<int64_t, Rank> &getIndices() { return indices; }
|
||||
|
||||
bool operator==(const StridedMemrefIterator &other) const {
|
||||
return other.offset == offset && &other.descriptor == &descriptor;
|
||||
return other.offset == offset && other.descriptor == descriptor;
|
||||
}
|
||||
|
||||
bool operator!=(const StridedMemrefIterator &other) const {
|
||||
|
@ -245,16 +252,24 @@ private:
|
|||
/// Offset in the buffer. This can be derived from the indices and the
|
||||
/// descriptor.
|
||||
int64_t offset = 0;
|
||||
|
||||
/// Array of indices in the multi-dimensional memref.
|
||||
std::array<int64_t, Rank> indices = {};
|
||||
|
||||
/// Descriptor for the strided memref.
|
||||
StridedMemRefType<T, Rank> &descriptor;
|
||||
StridedMemRefType<T, Rank> *descriptor;
|
||||
};
|
||||
|
||||
/// Iterate over all elements in a 0-ranked strided memref.
|
||||
template <typename T>
|
||||
class StridedMemrefIterator<T, 0> {
|
||||
public:
|
||||
using iterator_category = std::forward_iterator_tag;
|
||||
using value_type = T;
|
||||
using difference_type = std::ptrdiff_t;
|
||||
using pointer = T *;
|
||||
using reference = T &;
|
||||
|
||||
StridedMemrefIterator(StridedMemRefType<T, 0> &descriptor, int64_t offset = 0)
|
||||
: elt(descriptor.data + offset) {}
|
||||
|
||||
|
@ -263,8 +278,8 @@ public:
|
|||
return *this;
|
||||
}
|
||||
|
||||
T &operator*() { return *elt; }
|
||||
T *operator->() { return elt; }
|
||||
reference operator*() { return *elt; }
|
||||
pointer operator->() { return elt; }
|
||||
|
||||
// There are no indices for a 0-ranked memref, but this API is provided for
|
||||
// consistency with the general case.
|
||||
|
@ -301,10 +316,20 @@ struct UnrankedMemRefType {
|
|||
//===----------------------------------------------------------------------===//
|
||||
// DynamicMemRefType type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
template <typename T>
|
||||
class DynamicMemRefIterator;
|
||||
|
||||
// A reference to one of the StridedMemRef types.
|
||||
template <typename T>
|
||||
class DynamicMemRefType {
|
||||
public:
|
||||
int64_t rank;
|
||||
T *basePtr;
|
||||
T *data;
|
||||
int64_t offset;
|
||||
const int64_t *sizes;
|
||||
const int64_t *strides;
|
||||
|
||||
explicit DynamicMemRefType(const StridedMemRefType<T, 0> &memRef)
|
||||
: rank(0), basePtr(memRef.basePtr), data(memRef.data),
|
||||
offset(memRef.offset), sizes(nullptr), strides(nullptr) {}
|
||||
|
@ -322,12 +347,113 @@ public:
|
|||
strides = sizes + rank;
|
||||
}
|
||||
|
||||
int64_t rank;
|
||||
T *basePtr;
|
||||
T *data;
|
||||
int64_t offset;
|
||||
const int64_t *sizes;
|
||||
const int64_t *strides;
|
||||
template <typename Range,
|
||||
typename sfinae = decltype(std::declval<Range>().begin())>
|
||||
T &operator[](Range &&indices) {
|
||||
assert(indices.size() == rank &&
|
||||
"indices should match rank in memref subscript");
|
||||
if (rank == 0)
|
||||
return data[offset];
|
||||
|
||||
int64_t curOffset = offset;
|
||||
for (int dim = rank - 1; dim >= 0; --dim) {
|
||||
int64_t currentIndex = *(indices.begin() + dim);
|
||||
assert(currentIndex < sizes[dim] && "Index overflow");
|
||||
curOffset += currentIndex * strides[dim];
|
||||
}
|
||||
return data[curOffset];
|
||||
}
|
||||
|
||||
DynamicMemRefIterator<T> begin() { return {*this}; }
|
||||
DynamicMemRefIterator<T> end() { return {*this, -1}; }
|
||||
|
||||
// This operator[] is extremely slow and only for sugaring purposes.
|
||||
DynamicMemRefType<T> operator[](int64_t idx) {
|
||||
assert(rank > 0 && "can't make a subscript of a zero ranked array");
|
||||
|
||||
DynamicMemRefType<T> res(*this);
|
||||
--res.rank;
|
||||
res.offset += idx * res.strides[0];
|
||||
++res.sizes;
|
||||
++res.strides;
|
||||
return res;
|
||||
}
|
||||
|
||||
// This operator* can be used in conjunction with the previous operator[] in
|
||||
// order to access the underlying value in case of zero-ranked memref.
|
||||
T &operator*() {
|
||||
assert(rank == 0 && "not a zero-ranked memRef");
|
||||
return data[offset];
|
||||
}
|
||||
|
||||
private:
|
||||
DynamicMemRefType(const DynamicMemRefType<T> &other)
|
||||
: rank(other.rank), basePtr(other.basePtr), data(other.data),
|
||||
offset(other.offset), strides(other.strides) {}
|
||||
};
|
||||
|
||||
/// Iterate over all elements in a dynamic memref.
|
||||
template <typename T>
|
||||
class DynamicMemRefIterator {
|
||||
public:
|
||||
using iterator_category = std::forward_iterator_tag;
|
||||
using value_type = T;
|
||||
using difference_type = std::ptrdiff_t;
|
||||
using pointer = T *;
|
||||
using reference = T &;
|
||||
|
||||
DynamicMemRefIterator(DynamicMemRefType<T> &descriptor, int64_t offset = 0)
|
||||
: offset(offset), descriptor(&descriptor) {
|
||||
indices.resize(descriptor.rank, 0);
|
||||
}
|
||||
|
||||
DynamicMemRefIterator<T> &operator++() {
|
||||
if (descriptor->rank == 0) {
|
||||
offset = -1;
|
||||
return *this;
|
||||
}
|
||||
|
||||
int dim = descriptor->rank - 1;
|
||||
|
||||
while (dim >= 0 && indices[dim] == (descriptor->sizes[dim] - 1)) {
|
||||
offset -= indices[dim] * descriptor->strides[dim];
|
||||
indices[dim] = 0;
|
||||
--dim;
|
||||
}
|
||||
|
||||
if (dim < 0) {
|
||||
offset = -1;
|
||||
return *this;
|
||||
}
|
||||
|
||||
++indices[dim];
|
||||
offset += descriptor->strides[dim];
|
||||
return *this;
|
||||
}
|
||||
|
||||
reference operator*() { return descriptor->data[offset]; }
|
||||
pointer operator->() { return &descriptor->data[offset]; }
|
||||
|
||||
const std::vector<int64_t> &getIndices() { return indices; }
|
||||
|
||||
bool operator==(const DynamicMemRefIterator &other) const {
|
||||
return other.offset == offset && other.descriptor == descriptor;
|
||||
}
|
||||
|
||||
bool operator!=(const DynamicMemRefIterator &other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
|
||||
private:
|
||||
/// Offset in the buffer. This can be derived from the indices and the
|
||||
/// descriptor.
|
||||
int64_t offset = 0;
|
||||
|
||||
/// Array of indices in the multi-dimensional memref.
|
||||
std::vector<int64_t> indices = {};
|
||||
|
||||
/// Descriptor for the dynamic memref.
|
||||
DynamicMemRefType<T> *descriptor;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
add_mlir_unittest(MLIRExecutionEngineTests
|
||||
DynamicMemRef.cpp
|
||||
Invoke.cpp
|
||||
)
|
||||
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
//===- DynamicMemRef.cpp ----------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/ExecutionEngine/CRunnerUtils.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "gmock/gmock.h"
|
||||
|
||||
using namespace ::mlir;
|
||||
using namespace ::testing;
|
||||
|
||||
TEST(DynamicMemRef, rankZero) {
|
||||
int data = 57;
|
||||
|
||||
StridedMemRefType<int, 0> memRef;
|
||||
memRef.basePtr = &data;
|
||||
memRef.data = &data;
|
||||
memRef.offset = 0;
|
||||
|
||||
DynamicMemRefType<int> dynamicMemRef(memRef);
|
||||
|
||||
llvm::SmallVector<int, 1> values(dynamicMemRef.begin(), dynamicMemRef.end());
|
||||
EXPECT_THAT(values, ElementsAre(57));
|
||||
}
|
||||
|
||||
TEST(DynamicMemRef, rankOne) {
|
||||
std::array<int, 3> data;
|
||||
|
||||
for (size_t i = 0; i < data.size(); ++i) {
|
||||
data[i] = i;
|
||||
}
|
||||
|
||||
StridedMemRefType<int, 1> memRef;
|
||||
memRef.basePtr = data.data();
|
||||
memRef.data = data.data();
|
||||
memRef.offset = 0;
|
||||
memRef.sizes[0] = 3;
|
||||
memRef.strides[0] = 1;
|
||||
|
||||
DynamicMemRefType<int> dynamicMemRef(memRef);
|
||||
|
||||
llvm::SmallVector<int, 3> values(dynamicMemRef.begin(), dynamicMemRef.end());
|
||||
EXPECT_THAT(values, ElementsAreArray(data));
|
||||
|
||||
for (int64_t i = 0; i < 3; ++i) {
|
||||
EXPECT_EQ(*dynamicMemRef[i], data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(DynamicMemRef, rankTwo) {
|
||||
std::array<int, 6> data;
|
||||
|
||||
for (size_t i = 0; i < data.size(); ++i) {
|
||||
data[i] = i;
|
||||
}
|
||||
|
||||
StridedMemRefType<int, 2> memRef;
|
||||
memRef.basePtr = data.data();
|
||||
memRef.data = data.data();
|
||||
memRef.offset = 0;
|
||||
memRef.sizes[0] = 2;
|
||||
memRef.sizes[1] = 3;
|
||||
memRef.strides[0] = 3;
|
||||
memRef.strides[1] = 1;
|
||||
|
||||
DynamicMemRefType<int> dynamicMemRef(memRef);
|
||||
|
||||
llvm::SmallVector<int, 6> values(dynamicMemRef.begin(), dynamicMemRef.end());
|
||||
EXPECT_THAT(values, ElementsAreArray(data));
|
||||
}
|
||||
|
||||
TEST(DynamicMemRef, rankThree) {
|
||||
std::array<int, 24> data;
|
||||
|
||||
for (size_t i = 0; i < data.size(); ++i) {
|
||||
data[i] = i;
|
||||
}
|
||||
|
||||
StridedMemRefType<int, 3> memRef;
|
||||
memRef.basePtr = data.data();
|
||||
memRef.data = data.data();
|
||||
memRef.offset = 0;
|
||||
memRef.sizes[0] = 2;
|
||||
memRef.sizes[1] = 3;
|
||||
memRef.sizes[2] = 4;
|
||||
memRef.strides[0] = 12;
|
||||
memRef.strides[1] = 4;
|
||||
memRef.strides[2] = 1;
|
||||
|
||||
DynamicMemRefType<int> dynamicMemRef(memRef);
|
||||
|
||||
llvm::SmallVector<int, 24> values(dynamicMemRef.begin(), dynamicMemRef.end());
|
||||
EXPECT_THAT(values, ElementsAreArray(data));
|
||||
}
|
Loading…
Reference in New Issue