forked from OSchip/llvm-project
[mlir][sparse] Add Merger unit tests (with gcc5 build fix)
This is a fix of https://reviews.llvm.org/D104956, which broke the gcc5 build. We opt to use unit tests rather than check tests as the lattice/merger code is a small C++ component with a well-defined API. Testing this API via check tests would be far less direct and readable. In addition, as the check tests will only be able to test the API indirectly, the tests may break based on unrelated changes; e.g. changes in linalg. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D105828
This commit is contained in:
parent
d5d477780c
commit
40843347b3
|
@ -7,4 +7,5 @@ target_link_libraries(MLIRDialectTests
|
|||
MLIRDialect)
|
||||
|
||||
add_subdirectory(Quant)
|
||||
add_subdirectory(SparseTensor)
|
||||
add_subdirectory(SPIRV)
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
add_mlir_unittest(MLIRSparseTensorTests
|
||||
MergerTest.cpp
|
||||
)
|
||||
target_link_libraries(MLIRSparseTensorTests
|
||||
PRIVATE
|
||||
MLIRSparseTensorUtils
|
||||
)
|
|
@ -0,0 +1,252 @@
|
|||
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
|
||||
#include "gmock/gmock.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include <memory>
|
||||
|
||||
using namespace mlir::sparse_tensor;
|
||||
|
||||
namespace {
|
||||
|
||||
/// Simple recursive data structure used to match expressions in Mergers.
|
||||
struct Pattern {
|
||||
Kind kind;
|
||||
|
||||
/// Expressions representing tensors simply have a tensor number.
|
||||
unsigned tensorNum;
|
||||
|
||||
/// Tensor operations point to their children.
|
||||
std::shared_ptr<Pattern> e0;
|
||||
std::shared_ptr<Pattern> e1;
|
||||
|
||||
/// Constructors.
|
||||
/// Rather than using these, please use the readable helper constructor
|
||||
/// functions below to make tests more readable.
|
||||
Pattern(unsigned tensorNum) : kind(Kind::kTensor), tensorNum(tensorNum) {}
|
||||
Pattern(Kind kind, std::shared_ptr<Pattern> e0, std::shared_ptr<Pattern> e1)
|
||||
: kind(kind), e0(e0), e1(e1) {
|
||||
assert(kind >= Kind::kMulF);
|
||||
assert(e0 && e1);
|
||||
}
|
||||
};
|
||||
|
||||
///
|
||||
/// Readable Pattern builder functions.
|
||||
/// These should be preferred over the actual constructors.
|
||||
///
|
||||
|
||||
static std::shared_ptr<Pattern> tensorPattern(unsigned tensorNum) {
|
||||
return std::make_shared<Pattern>(tensorNum);
|
||||
}
|
||||
|
||||
static std::shared_ptr<Pattern> addfPattern(std::shared_ptr<Pattern> e0,
|
||||
std::shared_ptr<Pattern> e1) {
|
||||
return std::make_shared<Pattern>(Kind::kAddF, e0, e1);
|
||||
}
|
||||
|
||||
static std::shared_ptr<Pattern> mulfPattern(std::shared_ptr<Pattern> e0,
|
||||
std::shared_ptr<Pattern> e1) {
|
||||
return std::make_shared<Pattern>(Kind::kMulF, e0, e1);
|
||||
}
|
||||
|
||||
class MergerTestBase : public ::testing::Test {
|
||||
protected:
|
||||
MergerTestBase(unsigned numTensors, unsigned numLoops)
|
||||
: numTensors(numTensors), numLoops(numLoops),
|
||||
merger(numTensors, numLoops) {}
|
||||
|
||||
///
|
||||
/// Expression construction helpers.
|
||||
///
|
||||
|
||||
unsigned tensor(unsigned tensor) {
|
||||
return merger.addExp(Kind::kTensor, tensor);
|
||||
}
|
||||
|
||||
unsigned addf(unsigned e0, unsigned e1) {
|
||||
return merger.addExp(Kind::kAddF, e0, e1);
|
||||
}
|
||||
|
||||
unsigned mulf(unsigned e0, unsigned e1) {
|
||||
return merger.addExp(Kind::kMulF, e0, e1);
|
||||
}
|
||||
|
||||
///
|
||||
/// Comparison helpers.
|
||||
///
|
||||
|
||||
/// For readability of tests.
|
||||
unsigned lat(unsigned lat) { return lat; }
|
||||
|
||||
/// Returns true if a lattice point with an expression matching the given
|
||||
/// pattern and bits matching the given bits is present in lattice points
|
||||
/// [p, p+n) of lattice set s. This is useful for testing partial ordering
|
||||
/// constraints between lattice points. We generally know how contiguous
|
||||
/// groups of lattice points should be ordered with respect to other groups,
|
||||
/// but there is no required ordering within groups.
|
||||
bool latPointWithinRange(unsigned s, unsigned p, unsigned n,
|
||||
std::shared_ptr<Pattern> pattern,
|
||||
llvm::BitVector bits) {
|
||||
for (unsigned i = p; i < p + n; ++i) {
|
||||
if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) &&
|
||||
compareBits(s, i, bits))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Wrapper over latPointWithinRange for readability of tests.
|
||||
void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n,
|
||||
std::shared_ptr<Pattern> pattern,
|
||||
llvm::BitVector bits) {
|
||||
EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits));
|
||||
}
|
||||
|
||||
/// Wrapper over expectLatPointWithinRange for a single lat point.
|
||||
void expectLatPoint(unsigned s, unsigned p, std::shared_ptr<Pattern> pattern,
|
||||
llvm::BitVector bits) {
|
||||
EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits));
|
||||
}
|
||||
|
||||
/// Converts a vector of (loop, tensor) pairs to a bitvector with the
|
||||
/// corresponding bits set.
|
||||
llvm::BitVector
|
||||
loopsToBits(std::vector<std::pair<unsigned, unsigned>> loops) {
|
||||
llvm::BitVector testBits = llvm::BitVector(numTensors + 1, false);
|
||||
for (auto l : loops) {
|
||||
auto loop = std::get<0>(l);
|
||||
auto tensor = std::get<1>(l);
|
||||
testBits.set(numTensors * loop + tensor);
|
||||
}
|
||||
return testBits;
|
||||
}
|
||||
|
||||
/// Returns true if the bits of lattice point p in set s match the given bits.
|
||||
bool compareBits(unsigned s, unsigned p, llvm::BitVector bits) {
|
||||
return merger.lat(merger.set(s)[p]).bits == bits;
|
||||
}
|
||||
|
||||
/// Check that there are n lattice points in set s.
|
||||
void expectNumLatPoints(unsigned s, unsigned n) {
|
||||
EXPECT_THAT(merger.set(s).size(), n);
|
||||
}
|
||||
|
||||
/// Compares expressions for equality. Equality is defined recursively as:
|
||||
/// - Two expressions can only be equal if they have the same Kind.
|
||||
/// - Two binary expressions are equal if they have the same Kind and their
|
||||
/// children are equal.
|
||||
/// - Expressions with Kind invariant or tensor are equal if they have the
|
||||
/// same expression id.
|
||||
bool compareExpression(unsigned e, std::shared_ptr<Pattern> pattern) {
|
||||
auto tensorExp = merger.exp(e);
|
||||
if (tensorExp.kind != pattern->kind)
|
||||
return false;
|
||||
assert(tensorExp.kind != Kind::kInvariant &&
|
||||
"Invariant comparison not yet supported");
|
||||
switch (tensorExp.kind) {
|
||||
case Kind::kTensor:
|
||||
return tensorExp.tensor == pattern->tensorNum;
|
||||
case Kind::kZero:
|
||||
return true;
|
||||
case Kind::kMulF:
|
||||
case Kind::kMulI:
|
||||
case Kind::kAddF:
|
||||
case Kind::kAddI:
|
||||
case Kind::kSubF:
|
||||
case Kind::kSubI:
|
||||
return compareExpression(tensorExp.children.e0, pattern->e0) &&
|
||||
compareExpression(tensorExp.children.e1, pattern->e1);
|
||||
default:
|
||||
llvm_unreachable("Unhandled Kind");
|
||||
}
|
||||
}
|
||||
|
||||
unsigned numTensors;
|
||||
unsigned numLoops;
|
||||
Merger merger;
|
||||
};
|
||||
|
||||
class MergerTest3T1L : public MergerTestBase {
|
||||
protected:
|
||||
// Our three tensors (two inputs, one output).
|
||||
const unsigned t0 = 0, t1 = 1, t2 = 2;
|
||||
|
||||
// Our single loop.
|
||||
const unsigned l0 = 0;
|
||||
|
||||
MergerTest3T1L() : MergerTestBase(3, 1) {
|
||||
// Tensor 0: sparse input vector.
|
||||
merger.addExp(Kind::kTensor, t0, -1u);
|
||||
merger.setDim(t0, l0, Dim::kSparse);
|
||||
|
||||
// Tensor 1: sparse input vector.
|
||||
merger.addExp(Kind::kTensor, t1, -1u);
|
||||
merger.setDim(t1, l0, Dim::kSparse);
|
||||
|
||||
// Tensor 2: dense output vector.
|
||||
merger.addExp(Kind::kTensor, t2, -1u);
|
||||
merger.setDim(t2, l0, Dim::kDense);
|
||||
}
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
/// Vector addition of 2 vectors, i.e.:
|
||||
/// a(i) = b(i) + c(i)
|
||||
/// which should form the 3 lattice points
|
||||
/// {
|
||||
/// lat( i_00 i_01 / (tensor_0 + tensor_1) )
|
||||
/// lat( i_00 / tensor_0 )
|
||||
/// lat( i_01 / tensor_1 )
|
||||
/// }
|
||||
/// and after optimization, will reduce to the 2 lattice points
|
||||
/// {
|
||||
/// lat( i_00 i_01 / (tensor_0 + tensor_1) )
|
||||
/// lat( i_00 / tensor_0 )
|
||||
/// }
|
||||
TEST_F(MergerTest3T1L, VectorAdd2) {
|
||||
// Construct expression.
|
||||
auto e = addf(tensor(t0), tensor(t1));
|
||||
|
||||
// Build lattices and check.
|
||||
auto s = merger.buildLattices(e, l0);
|
||||
expectNumLatPoints(s, 3);
|
||||
expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)),
|
||||
loopsToBits({{l0, t0}, {l0, t1}}));
|
||||
expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0),
|
||||
loopsToBits({{l0, t0}}));
|
||||
expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1),
|
||||
loopsToBits({{l0, t1}}));
|
||||
|
||||
// Optimize lattices and check.
|
||||
s = merger.optimizeSet(s);
|
||||
expectNumLatPoints(s, 3);
|
||||
expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)),
|
||||
loopsToBits({{l0, t0}, {l0, t1}}));
|
||||
expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0),
|
||||
loopsToBits({{l0, t0}}));
|
||||
expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1),
|
||||
loopsToBits({{l0, t1}}));
|
||||
}
|
||||
|
||||
/// Vector multiplication of 2 vectors, i.e.:
|
||||
/// a(i) = b(i) * c(i)
|
||||
/// which should form the single lattice point
|
||||
/// {
|
||||
/// lat( i_00 i_01 / (tensor_0 * tensor_1) )
|
||||
/// }
|
||||
TEST_F(MergerTest3T1L, VectorMul2) {
|
||||
// Construct expression.
|
||||
auto e = mulf(t0, t1);
|
||||
|
||||
// Build lattices and check.
|
||||
auto s = merger.buildLattices(e, l0);
|
||||
expectNumLatPoints(s, 1);
|
||||
expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)),
|
||||
loopsToBits({{l0, t0}, {l0, t1}}));
|
||||
|
||||
// Optimize lattices and check.
|
||||
s = merger.optimizeSet(s);
|
||||
expectNumLatPoints(s, 1);
|
||||
expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)),
|
||||
loopsToBits({{l0, t0}, {l0, t1}}));
|
||||
}
|
Loading…
Reference in New Issue