From 40843347b37cc73fd0208c63a19359df6b2bf67b Mon Sep 17 00:00:00 2001 From: Gus Smith Date: Mon, 12 Jul 2021 18:18:23 +0000 Subject: [PATCH] [mlir][sparse] Add Merger unit tests (with gcc5 build fix) This is a fix of https://reviews.llvm.org/D104956, which broke the gcc5 build. We opt to use unit tests rather than check tests as the lattice/merger code is a small C++ component with a well-defined API. Testing this API via check tests would be far less direct and readable. In addition, as the check tests will only be able to test the API indirectly, the tests may break based on unrelated changes; e.g. changes in linalg. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D105828 --- mlir/unittests/Dialect/CMakeLists.txt | 1 + .../Dialect/SparseTensor/CMakeLists.txt | 7 + .../Dialect/SparseTensor/MergerTest.cpp | 252 ++++++++++++++++++ 3 files changed, 260 insertions(+) create mode 100644 mlir/unittests/Dialect/SparseTensor/CMakeLists.txt create mode 100644 mlir/unittests/Dialect/SparseTensor/MergerTest.cpp diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt index 22f1475a1393..6b441567b548 100644 --- a/mlir/unittests/Dialect/CMakeLists.txt +++ b/mlir/unittests/Dialect/CMakeLists.txt @@ -7,4 +7,5 @@ target_link_libraries(MLIRDialectTests MLIRDialect) add_subdirectory(Quant) +add_subdirectory(SparseTensor) add_subdirectory(SPIRV) diff --git a/mlir/unittests/Dialect/SparseTensor/CMakeLists.txt b/mlir/unittests/Dialect/SparseTensor/CMakeLists.txt new file mode 100644 index 000000000000..f9594aab3bbc --- /dev/null +++ b/mlir/unittests/Dialect/SparseTensor/CMakeLists.txt @@ -0,0 +1,7 @@ +add_mlir_unittest(MLIRSparseTensorTests + MergerTest.cpp +) +target_link_libraries(MLIRSparseTensorTests + PRIVATE + MLIRSparseTensorUtils +) diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp new file mode 100644 index 000000000000..b544ce16469a --- /dev/null +++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp @@ -0,0 +1,252 @@ +#include "mlir/Dialect/SparseTensor/Utils/Merger.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include + +using namespace mlir::sparse_tensor; + +namespace { + +/// Simple recursive data structure used to match expressions in Mergers. +struct Pattern { + Kind kind; + + /// Expressions representing tensors simply have a tensor number. + unsigned tensorNum; + + /// Tensor operations point to their children. + std::shared_ptr e0; + std::shared_ptr e1; + + /// Constructors. + /// Rather than using these, please use the readable helper constructor + /// functions below to make tests more readable. + Pattern(unsigned tensorNum) : kind(Kind::kTensor), tensorNum(tensorNum) {} + Pattern(Kind kind, std::shared_ptr e0, std::shared_ptr e1) + : kind(kind), e0(e0), e1(e1) { + assert(kind >= Kind::kMulF); + assert(e0 && e1); + } +}; + +/// +/// Readable Pattern builder functions. +/// These should be preferred over the actual constructors. +/// + +static std::shared_ptr tensorPattern(unsigned tensorNum) { + return std::make_shared(tensorNum); +} + +static std::shared_ptr addfPattern(std::shared_ptr e0, + std::shared_ptr e1) { + return std::make_shared(Kind::kAddF, e0, e1); +} + +static std::shared_ptr mulfPattern(std::shared_ptr e0, + std::shared_ptr e1) { + return std::make_shared(Kind::kMulF, e0, e1); +} + +class MergerTestBase : public ::testing::Test { +protected: + MergerTestBase(unsigned numTensors, unsigned numLoops) + : numTensors(numTensors), numLoops(numLoops), + merger(numTensors, numLoops) {} + + /// + /// Expression construction helpers. + /// + + unsigned tensor(unsigned tensor) { + return merger.addExp(Kind::kTensor, tensor); + } + + unsigned addf(unsigned e0, unsigned e1) { + return merger.addExp(Kind::kAddF, e0, e1); + } + + unsigned mulf(unsigned e0, unsigned e1) { + return merger.addExp(Kind::kMulF, e0, e1); + } + + /// + /// Comparison helpers. + /// + + /// For readability of tests. + unsigned lat(unsigned lat) { return lat; } + + /// Returns true if a lattice point with an expression matching the given + /// pattern and bits matching the given bits is present in lattice points + /// [p, p+n) of lattice set s. This is useful for testing partial ordering + /// constraints between lattice points. We generally know how contiguous + /// groups of lattice points should be ordered with respect to other groups, + /// but there is no required ordering within groups. + bool latPointWithinRange(unsigned s, unsigned p, unsigned n, + std::shared_ptr pattern, + llvm::BitVector bits) { + for (unsigned i = p; i < p + n; ++i) { + if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) && + compareBits(s, i, bits)) + return true; + } + return false; + } + + /// Wrapper over latPointWithinRange for readability of tests. + void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n, + std::shared_ptr pattern, + llvm::BitVector bits) { + EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits)); + } + + /// Wrapper over expectLatPointWithinRange for a single lat point. + void expectLatPoint(unsigned s, unsigned p, std::shared_ptr pattern, + llvm::BitVector bits) { + EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits)); + } + + /// Converts a vector of (loop, tensor) pairs to a bitvector with the + /// corresponding bits set. + llvm::BitVector + loopsToBits(std::vector> loops) { + llvm::BitVector testBits = llvm::BitVector(numTensors + 1, false); + for (auto l : loops) { + auto loop = std::get<0>(l); + auto tensor = std::get<1>(l); + testBits.set(numTensors * loop + tensor); + } + return testBits; + } + + /// Returns true if the bits of lattice point p in set s match the given bits. + bool compareBits(unsigned s, unsigned p, llvm::BitVector bits) { + return merger.lat(merger.set(s)[p]).bits == bits; + } + + /// Check that there are n lattice points in set s. + void expectNumLatPoints(unsigned s, unsigned n) { + EXPECT_THAT(merger.set(s).size(), n); + } + + /// Compares expressions for equality. Equality is defined recursively as: + /// - Two expressions can only be equal if they have the same Kind. + /// - Two binary expressions are equal if they have the same Kind and their + /// children are equal. + /// - Expressions with Kind invariant or tensor are equal if they have the + /// same expression id. + bool compareExpression(unsigned e, std::shared_ptr pattern) { + auto tensorExp = merger.exp(e); + if (tensorExp.kind != pattern->kind) + return false; + assert(tensorExp.kind != Kind::kInvariant && + "Invariant comparison not yet supported"); + switch (tensorExp.kind) { + case Kind::kTensor: + return tensorExp.tensor == pattern->tensorNum; + case Kind::kZero: + return true; + case Kind::kMulF: + case Kind::kMulI: + case Kind::kAddF: + case Kind::kAddI: + case Kind::kSubF: + case Kind::kSubI: + return compareExpression(tensorExp.children.e0, pattern->e0) && + compareExpression(tensorExp.children.e1, pattern->e1); + default: + llvm_unreachable("Unhandled Kind"); + } + } + + unsigned numTensors; + unsigned numLoops; + Merger merger; +}; + +class MergerTest3T1L : public MergerTestBase { +protected: + // Our three tensors (two inputs, one output). + const unsigned t0 = 0, t1 = 1, t2 = 2; + + // Our single loop. + const unsigned l0 = 0; + + MergerTest3T1L() : MergerTestBase(3, 1) { + // Tensor 0: sparse input vector. + merger.addExp(Kind::kTensor, t0, -1u); + merger.setDim(t0, l0, Dim::kSparse); + + // Tensor 1: sparse input vector. + merger.addExp(Kind::kTensor, t1, -1u); + merger.setDim(t1, l0, Dim::kSparse); + + // Tensor 2: dense output vector. + merger.addExp(Kind::kTensor, t2, -1u); + merger.setDim(t2, l0, Dim::kDense); + } +}; + +} // anonymous namespace + +/// Vector addition of 2 vectors, i.e.: +/// a(i) = b(i) + c(i) +/// which should form the 3 lattice points +/// { +/// lat( i_00 i_01 / (tensor_0 + tensor_1) ) +/// lat( i_00 / tensor_0 ) +/// lat( i_01 / tensor_1 ) +/// } +/// and after optimization, will reduce to the 2 lattice points +/// { +/// lat( i_00 i_01 / (tensor_0 + tensor_1) ) +/// lat( i_00 / tensor_0 ) +/// } +TEST_F(MergerTest3T1L, VectorAdd2) { + // Construct expression. + auto e = addf(tensor(t0), tensor(t1)); + + // Build lattices and check. + auto s = merger.buildLattices(e, l0); + expectNumLatPoints(s, 3); + expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)), + loopsToBits({{l0, t0}, {l0, t1}})); + expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0), + loopsToBits({{l0, t0}})); + expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1), + loopsToBits({{l0, t1}})); + + // Optimize lattices and check. + s = merger.optimizeSet(s); + expectNumLatPoints(s, 3); + expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)), + loopsToBits({{l0, t0}, {l0, t1}})); + expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0), + loopsToBits({{l0, t0}})); + expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1), + loopsToBits({{l0, t1}})); +} + +/// Vector multiplication of 2 vectors, i.e.: +/// a(i) = b(i) * c(i) +/// which should form the single lattice point +/// { +/// lat( i_00 i_01 / (tensor_0 * tensor_1) ) +/// } +TEST_F(MergerTest3T1L, VectorMul2) { + // Construct expression. + auto e = mulf(t0, t1); + + // Build lattices and check. + auto s = merger.buildLattices(e, l0); + expectNumLatPoints(s, 1); + expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)), + loopsToBits({{l0, t0}, {l0, t1}})); + + // Optimize lattices and check. + s = merger.optimizeSet(s); + expectNumLatPoints(s, 1); + expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)), + loopsToBits({{l0, t0}, {l0, t1}})); +}