From b8ecf32f81bb8073320ad5d4722a1680f615d133 Mon Sep 17 00:00:00 2001 From: Michele Scuttari Date: Mon, 15 Aug 2022 17:17:54 +0200 Subject: [PATCH] DynamicMemRefType: iteration and access by indices MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The methods to perform such operations have been implemented for the DynamicMemRefType in a way that is similar to the implementation for StridedMemRefType. Up until here one could pass an unranked memref to the library, and thus obtain a “dynamic” memref descriptor, but then there would have been no possibility to operate on its content. Differential Revision: https://reviews.llvm.org/D131359 --- .../mlir/ExecutionEngine/CRunnerUtils.h | 158 ++++++++++++++++-- mlir/unittests/ExecutionEngine/CMakeLists.txt | 1 + .../ExecutionEngine/DynamicMemRef.cpp | 99 +++++++++++ 3 files changed, 242 insertions(+), 16 deletions(-) create mode 100644 mlir/unittests/ExecutionEngine/DynamicMemRef.cpp diff --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h index e536ae8fe115..9a24bbf4b8d5 100644 --- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h +++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h @@ -36,6 +36,7 @@ #include #include #include +#include //===----------------------------------------------------------------------===// // Codegen-compatible structures for Vector type. @@ -209,13 +210,19 @@ struct StridedMemRefType { template class StridedMemrefIterator { public: + using iterator_category = std::forward_iterator_tag; + using value_type = T; + using difference_type = std::ptrdiff_t; + using pointer = T *; + using reference = T &; + StridedMemrefIterator(StridedMemRefType &descriptor, int64_t offset = 0) - : offset(offset), descriptor(descriptor) {} + : offset(offset), descriptor(&descriptor) {} StridedMemrefIterator &operator++() { int dim = Rank - 1; - while (dim >= 0 && indices[dim] == (descriptor.sizes[dim] - 1)) { - offset -= indices[dim] * descriptor.strides[dim]; + while (dim >= 0 && indices[dim] == (descriptor->sizes[dim] - 1)) { + offset -= indices[dim] * descriptor->strides[dim]; indices[dim] = 0; --dim; } @@ -224,17 +231,17 @@ public: return *this; } ++indices[dim]; - offset += descriptor.strides[dim]; + offset += descriptor->strides[dim]; return *this; } - T &operator*() { return descriptor.data[offset]; } - T *operator->() { return &descriptor.data[offset]; } + reference operator*() { return descriptor->data[offset]; } + pointer operator->() { return &descriptor->data[offset]; } const std::array &getIndices() { return indices; } bool operator==(const StridedMemrefIterator &other) const { - return other.offset == offset && &other.descriptor == &descriptor; + return other.offset == offset && other.descriptor == descriptor; } bool operator!=(const StridedMemrefIterator &other) const { @@ -245,16 +252,24 @@ private: /// Offset in the buffer. This can be derived from the indices and the /// descriptor. int64_t offset = 0; + /// Array of indices in the multi-dimensional memref. std::array indices = {}; + /// Descriptor for the strided memref. - StridedMemRefType &descriptor; + StridedMemRefType *descriptor; }; /// Iterate over all elements in a 0-ranked strided memref. template class StridedMemrefIterator { public: + using iterator_category = std::forward_iterator_tag; + using value_type = T; + using difference_type = std::ptrdiff_t; + using pointer = T *; + using reference = T &; + StridedMemrefIterator(StridedMemRefType &descriptor, int64_t offset = 0) : elt(descriptor.data + offset) {} @@ -263,8 +278,8 @@ public: return *this; } - T &operator*() { return *elt; } - T *operator->() { return elt; } + reference operator*() { return *elt; } + pointer operator->() { return elt; } // There are no indices for a 0-ranked memref, but this API is provided for // consistency with the general case. @@ -301,10 +316,20 @@ struct UnrankedMemRefType { //===----------------------------------------------------------------------===// // DynamicMemRefType type. //===----------------------------------------------------------------------===// +template +class DynamicMemRefIterator; + // A reference to one of the StridedMemRef types. template class DynamicMemRefType { public: + int64_t rank; + T *basePtr; + T *data; + int64_t offset; + const int64_t *sizes; + const int64_t *strides; + explicit DynamicMemRefType(const StridedMemRefType &memRef) : rank(0), basePtr(memRef.basePtr), data(memRef.data), offset(memRef.offset), sizes(nullptr), strides(nullptr) {} @@ -322,12 +347,113 @@ public: strides = sizes + rank; } - int64_t rank; - T *basePtr; - T *data; - int64_t offset; - const int64_t *sizes; - const int64_t *strides; + template ().begin())> + T &operator[](Range &&indices) { + assert(indices.size() == rank && + "indices should match rank in memref subscript"); + if (rank == 0) + return data[offset]; + + int64_t curOffset = offset; + for (int dim = rank - 1; dim >= 0; --dim) { + int64_t currentIndex = *(indices.begin() + dim); + assert(currentIndex < sizes[dim] && "Index overflow"); + curOffset += currentIndex * strides[dim]; + } + return data[curOffset]; + } + + DynamicMemRefIterator begin() { return {*this}; } + DynamicMemRefIterator end() { return {*this, -1}; } + + // This operator[] is extremely slow and only for sugaring purposes. + DynamicMemRefType operator[](int64_t idx) { + assert(rank > 0 && "can't make a subscript of a zero ranked array"); + + DynamicMemRefType res(*this); + --res.rank; + res.offset += idx * res.strides[0]; + ++res.sizes; + ++res.strides; + return res; + } + + // This operator* can be used in conjunction with the previous operator[] in + // order to access the underlying value in case of zero-ranked memref. + T &operator*() { + assert(rank == 0 && "not a zero-ranked memRef"); + return data[offset]; + } + +private: + DynamicMemRefType(const DynamicMemRefType &other) + : rank(other.rank), basePtr(other.basePtr), data(other.data), + offset(other.offset), strides(other.strides) {} +}; + +/// Iterate over all elements in a dynamic memref. +template +class DynamicMemRefIterator { +public: + using iterator_category = std::forward_iterator_tag; + using value_type = T; + using difference_type = std::ptrdiff_t; + using pointer = T *; + using reference = T &; + + DynamicMemRefIterator(DynamicMemRefType &descriptor, int64_t offset = 0) + : offset(offset), descriptor(&descriptor) { + indices.resize(descriptor.rank, 0); + } + + DynamicMemRefIterator &operator++() { + if (descriptor->rank == 0) { + offset = -1; + return *this; + } + + int dim = descriptor->rank - 1; + + while (dim >= 0 && indices[dim] == (descriptor->sizes[dim] - 1)) { + offset -= indices[dim] * descriptor->strides[dim]; + indices[dim] = 0; + --dim; + } + + if (dim < 0) { + offset = -1; + return *this; + } + + ++indices[dim]; + offset += descriptor->strides[dim]; + return *this; + } + + reference operator*() { return descriptor->data[offset]; } + pointer operator->() { return &descriptor->data[offset]; } + + const std::vector &getIndices() { return indices; } + + bool operator==(const DynamicMemRefIterator &other) const { + return other.offset == offset && other.descriptor == descriptor; + } + + bool operator!=(const DynamicMemRefIterator &other) const { + return !(*this == other); + } + +private: + /// Offset in the buffer. This can be derived from the indices and the + /// descriptor. + int64_t offset = 0; + + /// Array of indices in the multi-dimensional memref. + std::vector indices = {}; + + /// Descriptor for the dynamic memref. + DynamicMemRefType *descriptor; }; //===----------------------------------------------------------------------===// diff --git a/mlir/unittests/ExecutionEngine/CMakeLists.txt b/mlir/unittests/ExecutionEngine/CMakeLists.txt index d17acb6647f8..32722d0dd958 100644 --- a/mlir/unittests/ExecutionEngine/CMakeLists.txt +++ b/mlir/unittests/ExecutionEngine/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_unittest(MLIRExecutionEngineTests + DynamicMemRef.cpp Invoke.cpp ) get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) diff --git a/mlir/unittests/ExecutionEngine/DynamicMemRef.cpp b/mlir/unittests/ExecutionEngine/DynamicMemRef.cpp new file mode 100644 index 000000000000..5f4f01270246 --- /dev/null +++ b/mlir/unittests/ExecutionEngine/DynamicMemRef.cpp @@ -0,0 +1,99 @@ +//===- DynamicMemRef.cpp ----------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/ExecutionEngine/CRunnerUtils.h" +#include "llvm/ADT/SmallVector.h" + +#include "gmock/gmock.h" + +using namespace ::mlir; +using namespace ::testing; + +TEST(DynamicMemRef, rankZero) { + int data = 57; + + StridedMemRefType memRef; + memRef.basePtr = &data; + memRef.data = &data; + memRef.offset = 0; + + DynamicMemRefType dynamicMemRef(memRef); + + llvm::SmallVector values(dynamicMemRef.begin(), dynamicMemRef.end()); + EXPECT_THAT(values, ElementsAre(57)); +} + +TEST(DynamicMemRef, rankOne) { + std::array data; + + for (size_t i = 0; i < data.size(); ++i) { + data[i] = i; + } + + StridedMemRefType memRef; + memRef.basePtr = data.data(); + memRef.data = data.data(); + memRef.offset = 0; + memRef.sizes[0] = 3; + memRef.strides[0] = 1; + + DynamicMemRefType dynamicMemRef(memRef); + + llvm::SmallVector values(dynamicMemRef.begin(), dynamicMemRef.end()); + EXPECT_THAT(values, ElementsAreArray(data)); + + for (int64_t i = 0; i < 3; ++i) { + EXPECT_EQ(*dynamicMemRef[i], data[i]); + } +} + +TEST(DynamicMemRef, rankTwo) { + std::array data; + + for (size_t i = 0; i < data.size(); ++i) { + data[i] = i; + } + + StridedMemRefType memRef; + memRef.basePtr = data.data(); + memRef.data = data.data(); + memRef.offset = 0; + memRef.sizes[0] = 2; + memRef.sizes[1] = 3; + memRef.strides[0] = 3; + memRef.strides[1] = 1; + + DynamicMemRefType dynamicMemRef(memRef); + + llvm::SmallVector values(dynamicMemRef.begin(), dynamicMemRef.end()); + EXPECT_THAT(values, ElementsAreArray(data)); +} + +TEST(DynamicMemRef, rankThree) { + std::array data; + + for (size_t i = 0; i < data.size(); ++i) { + data[i] = i; + } + + StridedMemRefType memRef; + memRef.basePtr = data.data(); + memRef.data = data.data(); + memRef.offset = 0; + memRef.sizes[0] = 2; + memRef.sizes[1] = 3; + memRef.sizes[2] = 4; + memRef.strides[0] = 12; + memRef.strides[1] = 4; + memRef.strides[2] = 1; + + DynamicMemRefType dynamicMemRef(memRef); + + llvm::SmallVector values(dynamicMemRef.begin(), dynamicMemRef.end()); + EXPECT_THAT(values, ElementsAreArray(data)); +} \ No newline at end of file