forked from OSchip/llvm-project
[mlir][sparse] add more unittest cases to sparse dialect merger
Reviewed By: aartbik, wrengr Differential Revision: https://reviews.llvm.org/D128058
This commit is contained in:
parent
43dc319049
commit
daeb2dcea0
|
@ -8,6 +8,68 @@ using namespace mlir::sparse_tensor;
|
|||
|
||||
namespace {
|
||||
|
||||
///
|
||||
/// Defines macros to iterate binary and the combination of binary operations.
|
||||
///
|
||||
|
||||
#define FOREVERY_BINOP(DO) \
|
||||
DO(mulf, Kind::kMulF) \
|
||||
DO(mulc, Kind::kMulC) \
|
||||
DO(muli, Kind::kMulI) \
|
||||
DO(addf, Kind::kAddF) \
|
||||
DO(addc, Kind::kAddC) \
|
||||
DO(addi, Kind::kAddI) \
|
||||
DO(subf, Kind::kSubF) \
|
||||
DO(subc, Kind::kSubC) \
|
||||
DO(subi, Kind::kSubI) \
|
||||
DO(andi, Kind::kAndI) \
|
||||
DO(xori, Kind::kXorI) \
|
||||
DO(ori, Kind::kOrI)
|
||||
|
||||
// TODO: Disjunctive binary operations that need special handling are not
|
||||
// included, e.g., Division are not tested (for now) as it need a constant
|
||||
// non-zero dividend.
|
||||
// ##__VA_ARGS__ handles cases when __VA_ARGS__ is empty.
|
||||
#define FOREVERY_COMMON_DISJ_BINOP(TEST, ...) \
|
||||
TEST(addf, ##__VA_ARGS__) \
|
||||
TEST(addc, ##__VA_ARGS__) \
|
||||
TEST(addi, ##__VA_ARGS__) \
|
||||
TEST(xori, ##__VA_ARGS__) \
|
||||
TEST(ori, ##__VA_ARGS__)
|
||||
|
||||
// TODO: Conjunctive binary operations that need special handling are not
|
||||
// included, e.g., substraction yields a different pattern as it is mapped to
|
||||
// negate operation.
|
||||
#define FOREVERY_COMMON_CONJ_BINOP(TEST, ...) \
|
||||
TEST(mulf, ##__VA_ARGS__) \
|
||||
TEST(mulc, ##__VA_ARGS__) \
|
||||
TEST(muli, ##__VA_ARGS__) \
|
||||
TEST(andi, ##__VA_ARGS__)
|
||||
|
||||
#define FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(TEST) \
|
||||
FOREVERY_COMMON_CONJ_BINOP(TEST, addf) \
|
||||
FOREVERY_COMMON_CONJ_BINOP(TEST, addc) \
|
||||
FOREVERY_COMMON_CONJ_BINOP(TEST, addi) \
|
||||
FOREVERY_COMMON_CONJ_BINOP(TEST, xori) \
|
||||
FOREVERY_COMMON_CONJ_BINOP(TEST, ori)
|
||||
|
||||
#define FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(TEST) \
|
||||
FOREVERY_COMMON_CONJ_BINOP(TEST, mulf) \
|
||||
FOREVERY_COMMON_CONJ_BINOP(TEST, mulc) \
|
||||
FOREVERY_COMMON_CONJ_BINOP(TEST, muli) \
|
||||
FOREVERY_COMMON_CONJ_BINOP(TEST, andi)
|
||||
|
||||
#define FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(TEST) \
|
||||
FOREVERY_COMMON_DISJ_BINOP(TEST, addf) \
|
||||
FOREVERY_COMMON_DISJ_BINOP(TEST, addc) \
|
||||
FOREVERY_COMMON_DISJ_BINOP(TEST, addi) \
|
||||
FOREVERY_COMMON_DISJ_BINOP(TEST, ori) \
|
||||
FOREVERY_COMMON_DISJ_BINOP(TEST, xori)
|
||||
|
||||
///
|
||||
/// Helper classes/functions for testing Merger.
|
||||
///
|
||||
|
||||
/// Simple recursive data structure used to match expressions in Mergers.
|
||||
struct Pattern {
|
||||
Kind kind;
|
||||
|
@ -40,17 +102,16 @@ static std::shared_ptr<Pattern> tensorPattern(unsigned tensorNum) {
|
|||
return std::make_shared<Pattern>(tensorNum);
|
||||
}
|
||||
|
||||
static std::shared_ptr<Pattern>
|
||||
addfPattern(const std::shared_ptr<Pattern> &e0,
|
||||
const std::shared_ptr<Pattern> &e1) {
|
||||
return std::make_shared<Pattern>(Kind::kAddF, e0, e1);
|
||||
}
|
||||
#define IMPL_BINOP_PATTERN(OP, KIND) \
|
||||
static std::shared_ptr<Pattern> OP##Pattern( \
|
||||
const std::shared_ptr<Pattern> &e0, \
|
||||
const std::shared_ptr<Pattern> &e1) { \
|
||||
return std::make_shared<Pattern>(KIND, e0, e1); \
|
||||
}
|
||||
|
||||
static std::shared_ptr<Pattern>
|
||||
mulfPattern(const std::shared_ptr<Pattern> &e0,
|
||||
const std::shared_ptr<Pattern> &e1) {
|
||||
return std::make_shared<Pattern>(Kind::kMulF, e0, e1);
|
||||
}
|
||||
FOREVERY_BINOP(IMPL_BINOP_PATTERN)
|
||||
|
||||
#undef IMPL_BINOP_PATTERN
|
||||
|
||||
class MergerTestBase : public ::testing::Test {
|
||||
protected:
|
||||
|
@ -66,13 +127,14 @@ protected:
|
|||
return merger.addExp(Kind::kTensor, tensor);
|
||||
}
|
||||
|
||||
unsigned addf(unsigned e0, unsigned e1) {
|
||||
return merger.addExp(Kind::kAddF, e0, e1);
|
||||
#define IMPL_BINOP_EXPR(OP, KIND) \
|
||||
unsigned OP##Expr(unsigned e0, unsigned e1) { \
|
||||
return merger.addExp(KIND, e0, e1); \
|
||||
}
|
||||
|
||||
unsigned mulf(unsigned e0, unsigned e1) {
|
||||
return merger.addExp(Kind::kMulF, e0, e1);
|
||||
}
|
||||
FOREVERY_BINOP(IMPL_BINOP_EXPR)
|
||||
|
||||
#undef IMPL_BINOP_EXPR
|
||||
|
||||
///
|
||||
/// Comparison helpers.
|
||||
|
@ -87,12 +149,14 @@ protected:
|
|||
/// constraints between lattice points. We generally know how contiguous
|
||||
/// groups of lattice points should be ordered with respect to other groups,
|
||||
/// but there is no required ordering within groups.
|
||||
/// If simple is true, then compare the lat.simple field instead to test the
|
||||
/// result after optimization
|
||||
bool latPointWithinRange(unsigned s, unsigned p, unsigned n,
|
||||
const std::shared_ptr<Pattern> &pattern,
|
||||
const BitVector &bits) {
|
||||
const BitVector &bits, bool simple) {
|
||||
for (unsigned i = p; i < p + n; ++i) {
|
||||
if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) &&
|
||||
compareBits(s, i, bits))
|
||||
compareBits(s, i, bits, simple))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
@ -101,15 +165,15 @@ protected:
|
|||
/// Wrapper over latPointWithinRange for readability of tests.
|
||||
void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n,
|
||||
const std::shared_ptr<Pattern> &pattern,
|
||||
const BitVector &bits) {
|
||||
EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits));
|
||||
const BitVector &bits, bool simple = false) {
|
||||
EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits, simple));
|
||||
}
|
||||
|
||||
/// Wrapper over expectLatPointWithinRange for a single lat point.
|
||||
void expectLatPoint(unsigned s, unsigned p,
|
||||
const std::shared_ptr<Pattern> &pattern,
|
||||
const BitVector &bits) {
|
||||
EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits));
|
||||
const BitVector &bits, bool simple = false) {
|
||||
EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits, simple));
|
||||
}
|
||||
|
||||
/// Converts a vector of (loop, tensor) pairs to a bitvector with the
|
||||
|
@ -126,7 +190,11 @@ protected:
|
|||
}
|
||||
|
||||
/// Returns true if the bits of lattice point p in set s match the given bits.
|
||||
bool compareBits(unsigned s, unsigned p, const BitVector &bits) {
|
||||
/// If simple is true, then compare the lat.simple field instead to test the
|
||||
/// result after optimization
|
||||
bool compareBits(unsigned s, unsigned p, const BitVector &bits, bool simple) {
|
||||
if (simple)
|
||||
return merger.lat(merger.set(s)[p]).simple == bits;
|
||||
return merger.lat(merger.set(s)[p]).bits == bits;
|
||||
}
|
||||
|
||||
|
@ -215,6 +283,10 @@ protected:
|
|||
Merger merger;
|
||||
};
|
||||
|
||||
///
|
||||
/// Tests with all sparse inputs.
|
||||
///
|
||||
|
||||
class MergerTest3T1L : public MergerTestBase {
|
||||
protected:
|
||||
// Our three tensors (two inputs, one output).
|
||||
|
@ -238,9 +310,63 @@ protected:
|
|||
}
|
||||
};
|
||||
|
||||
class MergerTest4T1L : public MergerTestBase {
|
||||
protected:
|
||||
// Our four tensors (three inputs, one output).
|
||||
const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3;
|
||||
|
||||
// Our single loop.
|
||||
const unsigned l0 = 0;
|
||||
|
||||
MergerTest4T1L() : MergerTestBase(4, 1) {
|
||||
// Tensor 0: sparse input vector.
|
||||
merger.addExp(Kind::kTensor, t0, -1u);
|
||||
merger.setDim(t0, l0, Dim::kSparse);
|
||||
|
||||
// Tensor 1: sparse input vector.
|
||||
merger.addExp(Kind::kTensor, t1, -1u);
|
||||
merger.setDim(t1, l0, Dim::kSparse);
|
||||
|
||||
// Tensor 2: sparse input vector
|
||||
merger.addExp(Kind::kTensor, t2, -1u);
|
||||
merger.setDim(t2, l0, Dim::kSparse);
|
||||
|
||||
// Tensor 3: dense output vector
|
||||
merger.addExp(Kind::kTensor, t3, -1u);
|
||||
merger.setDim(t3, l0, Dim::kDense);
|
||||
}
|
||||
};
|
||||
|
||||
///
|
||||
/// Tests with both sparse and dense input.
|
||||
///
|
||||
|
||||
class MergerTest3T1LD : public MergerTestBase {
|
||||
protected:
|
||||
// Our three tensors (two inputs, one output).
|
||||
const unsigned t0 = 0, t1 = 1, t2 = 2;
|
||||
|
||||
// Our single loop.
|
||||
const unsigned l0 = 0;
|
||||
|
||||
MergerTest3T1LD() : MergerTestBase(3, 1) {
|
||||
// Tensor 0: sparse input vector.
|
||||
merger.addExp(Kind::kTensor, t0, -1u);
|
||||
merger.setDim(t0, l0, Dim::kSparse);
|
||||
|
||||
// Tensor 1: dense input vector.
|
||||
merger.addExp(Kind::kTensor, t1, -1u);
|
||||
merger.setDim(t1, l0, Dim::kDense);
|
||||
|
||||
// Tensor 2: dense output vector.
|
||||
merger.addExp(Kind::kTensor, t2, -1u);
|
||||
merger.setDim(t2, l0, Dim::kDense);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Vector addition of 2 vectors, i.e.:
|
||||
/// Vector addition (disjunction) of 2 vectors. i.e.;
|
||||
/// a(i) = b(i) + c(i)
|
||||
/// which should form the 3 lattice points
|
||||
/// {
|
||||
|
@ -248,55 +374,254 @@ protected:
|
|||
/// lat( i_00 / tensor_0 )
|
||||
/// lat( i_01 / tensor_1 )
|
||||
/// }
|
||||
/// and after optimization, will reduce to the 2 lattice points
|
||||
/// and after optimization, the lattice points do not change (as there is no
|
||||
/// duplicated point and all input vectors are sparse vector).
|
||||
/// {
|
||||
/// lat( i_00 i_01 / (tensor_0 + tensor_1) )
|
||||
/// lat( i_00 / tensor_0 )
|
||||
/// lat( i_01 / tensor_1 )
|
||||
/// }
|
||||
TEST_F(MergerTest3T1L, VectorAdd2) {
|
||||
// Construct expression.
|
||||
auto e = addf(tensor(t0), tensor(t1));
|
||||
#define IMPL_MERGER_TEST_DISJ(OP) \
|
||||
TEST_F(MergerTest3T1L, vector_##OP) { \
|
||||
auto e = OP##Expr(tensor(t0), tensor(t1)); \
|
||||
auto p0 = tensorPattern(t0); \
|
||||
auto p1 = tensorPattern(t1); \
|
||||
auto s = merger.buildLattices(e, l0); \
|
||||
\
|
||||
expectNumLatPoints(s, 3); \
|
||||
expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}})); \
|
||||
expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}})); \
|
||||
expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}})); \
|
||||
\
|
||||
s = merger.optimizeSet(s); \
|
||||
expectNumLatPoints(s, 3); \
|
||||
expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}}), true); \
|
||||
expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}}), \
|
||||
true); \
|
||||
expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}}), \
|
||||
true); \
|
||||
}
|
||||
|
||||
// Build lattices and check.
|
||||
auto s = merger.buildLattices(e, l0);
|
||||
expectNumLatPoints(s, 3);
|
||||
expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)),
|
||||
loopsToBits({{l0, t0}, {l0, t1}}));
|
||||
expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0),
|
||||
loopsToBits({{l0, t0}}));
|
||||
expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1),
|
||||
loopsToBits({{l0, t1}}));
|
||||
FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
|
||||
|
||||
// Optimize lattices and check.
|
||||
s = merger.optimizeSet(s);
|
||||
expectNumLatPoints(s, 3);
|
||||
expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)),
|
||||
loopsToBits({{l0, t0}, {l0, t1}}));
|
||||
expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0),
|
||||
loopsToBits({{l0, t0}}));
|
||||
expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1),
|
||||
loopsToBits({{l0, t1}}));
|
||||
}
|
||||
#undef IMPL_MERGER_TEST_DISJ
|
||||
|
||||
/// Vector multiplication of 2 vectors, i.e.:
|
||||
/// Vector multiplication (conjunction) of 2 vectors, i.e.;
|
||||
/// a(i) = b(i) * c(i)
|
||||
/// which should form the single lattice point
|
||||
/// {
|
||||
/// lat( i_00 i_01 / (tensor_0 * tensor_1) )
|
||||
/// }
|
||||
TEST_F(MergerTest3T1L, VectorMul2) {
|
||||
// Construct expression.
|
||||
auto e = mulf(t0, t1);
|
||||
#define IMPL_MERGER_TEST_CONJ(OP) \
|
||||
TEST_F(MergerTest3T1L, vector_##OP) { \
|
||||
auto e = OP##Expr(t0, t1); \
|
||||
auto p0 = tensorPattern(t0); \
|
||||
auto p1 = tensorPattern(t1); \
|
||||
auto s = merger.buildLattices(e, l0); \
|
||||
\
|
||||
expectNumLatPoints(s, 1); \
|
||||
expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}})); \
|
||||
\
|
||||
s = merger.optimizeSet(s); \
|
||||
expectNumLatPoints(s, 1); \
|
||||
expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}}), true); \
|
||||
}
|
||||
|
||||
// Build lattices and check.
|
||||
auto s = merger.buildLattices(e, l0);
|
||||
expectNumLatPoints(s, 1);
|
||||
expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)),
|
||||
loopsToBits({{l0, t0}, {l0, t1}}));
|
||||
FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
|
||||
|
||||
// Optimize lattices and check.
|
||||
s = merger.optimizeSet(s);
|
||||
expectNumLatPoints(s, 1);
|
||||
expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)),
|
||||
loopsToBits({{l0, t0}, {l0, t1}}));
|
||||
}
|
||||
#undef IMPL_MERGER_TEST_CONJ
|
||||
|
||||
/// Vector multiplication (conjunction) then addition (disjunction), i.e.;
|
||||
/// a(i) = b(i) * c(i) + d(i);
|
||||
/// which should form
|
||||
/// {
|
||||
/// lat( i_00 i_01 i_02 / (tensor_0 * tensor_1) + tensor_2 )
|
||||
/// lat( i_00 i_01 / tensor_0 * tensor_1
|
||||
/// lat( i_02 / tensor_2 )
|
||||
/// }
|
||||
#define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ) \
|
||||
TEST_F(MergerTest4T1L, vector_##CONJ##_##DISJ) { \
|
||||
auto em = CONJ##Expr(t0, t1); \
|
||||
auto e = DISJ##Expr(em, t2); \
|
||||
auto p0 = tensorPattern(t0); \
|
||||
auto p1 = tensorPattern(t1); \
|
||||
auto p2 = tensorPattern(t2); \
|
||||
auto s = merger.buildLattices(e, l0); \
|
||||
\
|
||||
expectNumLatPoints(s, 3); \
|
||||
expectLatPoint(s, lat(0), DISJ##Pattern(CONJ##Pattern(p0, p1), p2), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
|
||||
expectLatPointWithinRange(s, lat(1), 2, CONJ##Pattern(p0, p1), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}})); \
|
||||
expectLatPointWithinRange(s, lat(1), 2, p2, loopsToBits({{l0, t2}})); \
|
||||
\
|
||||
s = merger.optimizeSet(s); \
|
||||
expectNumLatPoints(s, 3); \
|
||||
expectLatPoint(s, lat(0), DISJ##Pattern(CONJ##Pattern(p0, p1), p2), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
|
||||
expectLatPointWithinRange(s, lat(1), 2, CONJ##Pattern(p0, p1), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}})); \
|
||||
expectLatPointWithinRange(s, lat(1), 2, p2, loopsToBits({{l0, t2}})); \
|
||||
}
|
||||
|
||||
FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
|
||||
|
||||
#undef IMPL_MERGER_TEST_CONJ_DISJ
|
||||
|
||||
/// Vector addition (disjunction) then addition (disjunction), i.e.;
|
||||
/// a(i) = b(i) + c(i) + d(i)
|
||||
/// which should form
|
||||
/// {
|
||||
/// lat( i_00 i_01 i_02 / (tensor_0 + tensor_1) + tensor_2 )
|
||||
/// lat( i_02 i_01 / tensor_2 + tensor_1 )
|
||||
/// lat( i_02 i_00 / tensor_2 + tensor_0 )
|
||||
/// lat( i_01 i_00 / tensor_1 + tensor_0 )
|
||||
/// lat( i_02 / tensor_2 )
|
||||
/// lat( i_01 / tensor_1 )
|
||||
/// lat( i_00 / tensor_0 )
|
||||
/// }
|
||||
#define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2) \
|
||||
TEST_F(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) { \
|
||||
auto em = DISJ1##Expr(t0, t1); \
|
||||
auto e = DISJ2##Expr(em, t2); \
|
||||
auto p0 = tensorPattern(t0); \
|
||||
auto p1 = tensorPattern(t1); \
|
||||
auto p2 = tensorPattern(t2); \
|
||||
auto s = merger.buildLattices(e, l0); \
|
||||
\
|
||||
expectNumLatPoints(s, 7); \
|
||||
expectLatPoint(s, lat(0), DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
|
||||
expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p1, p2), \
|
||||
loopsToBits({{l0, t1}, {l0, t2}})); \
|
||||
expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p0, p2), \
|
||||
loopsToBits({{l0, t0}, {l0, t2}})); \
|
||||
expectLatPointWithinRange(s, lat(1), 6, DISJ1##Pattern(p0, p1), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}})); \
|
||||
expectLatPointWithinRange(s, lat(1), 6, p2, loopsToBits({{l0, t2}})); \
|
||||
expectLatPointWithinRange(s, lat(1), 6, p1, loopsToBits({{l0, t1}})); \
|
||||
expectLatPointWithinRange(s, lat(1), 6, p0, loopsToBits({{l0, t0}})); \
|
||||
\
|
||||
s = merger.optimizeSet(s); \
|
||||
expectNumLatPoints(s, 7); \
|
||||
expectLatPoint(s, lat(0), DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
|
||||
expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p1, p2), \
|
||||
loopsToBits({{l0, t1}, {l0, t2}})); \
|
||||
expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p0, p2), \
|
||||
loopsToBits({{l0, t0}, {l0, t2}})); \
|
||||
expectLatPointWithinRange(s, lat(1), 6, DISJ1##Pattern(p0, p1), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}})); \
|
||||
expectLatPointWithinRange(s, lat(1), 6, p2, loopsToBits({{l0, t2}})); \
|
||||
expectLatPointWithinRange(s, lat(1), 6, p1, loopsToBits({{l0, t1}})); \
|
||||
expectLatPointWithinRange(s, lat(1), 6, p0, loopsToBits({{l0, t0}})); \
|
||||
}
|
||||
|
||||
FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
|
||||
|
||||
#undef IMPL_MERGER_TEST_DISJ_DISJ
|
||||
|
||||
/// Vector multiplication (conjunction) then multiplication (conjunction), i.e.;
|
||||
/// a(i) = b(i) * c(i) * d(i);
|
||||
/// which should form
|
||||
/// {
|
||||
/// lat( i_00 i_01 i_02 / tensor_0 * tensor_1 * tensor_2 )
|
||||
/// }
|
||||
#define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2) \
|
||||
TEST_F(MergerTest4T1L, vector_##CONJ1##_##CONJ2) { \
|
||||
auto em = CONJ1##Expr(t0, t1); \
|
||||
auto e = CONJ2##Expr(em, t2); \
|
||||
auto p0 = tensorPattern(t0); \
|
||||
auto p1 = tensorPattern(t1); \
|
||||
auto p2 = tensorPattern(t2); \
|
||||
auto s = merger.buildLattices(e, l0); \
|
||||
expectNumLatPoints(s, 1); \
|
||||
expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
|
||||
s = merger.optimizeSet(s); \
|
||||
expectNumLatPoints(s, 1); \
|
||||
expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}), true); \
|
||||
}
|
||||
|
||||
FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
|
||||
|
||||
#undef IMPL_MERGER_TEST_CONJ_CONJ
|
||||
|
||||
/// Vector addition (disjunction) of 2 vectors, i.e.;
|
||||
/// a(i) = b(i) + c(i)
|
||||
/// which should form the 3 lattice points
|
||||
/// {
|
||||
/// lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) )
|
||||
/// lat( i_00 / sparse_tensor_0 )
|
||||
/// lat( i_01 / dense_tensor_1 )
|
||||
/// }
|
||||
/// which should be optimized to
|
||||
/// {
|
||||
/// lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ) (not singleton)
|
||||
/// lat( i_01 / dense_tensor_0 ) (no sparse dimension)
|
||||
/// }
|
||||
///
|
||||
/// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
|
||||
/// with lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ).
|
||||
#define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP) \
|
||||
TEST_F(MergerTest3T1LD, vector_opted_##OP) { \
|
||||
auto e = OP##Expr(tensor(t0), tensor(t1)); \
|
||||
auto p0 = tensorPattern(t0); \
|
||||
auto p1 = tensorPattern(t1); \
|
||||
auto s = merger.buildLattices(e, l0); \
|
||||
\
|
||||
expectNumLatPoints(s, 3); \
|
||||
expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}})); \
|
||||
expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}})); \
|
||||
expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}})); \
|
||||
\
|
||||
s = merger.optimizeSet(s); \
|
||||
expectNumLatPoints(s, 2); \
|
||||
expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}}), true); \
|
||||
expectLatPoint(s, lat(1), p1, loopsToBits({{l0, t1}}), true); \
|
||||
}
|
||||
|
||||
FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
|
||||
|
||||
#undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
|
||||
|
||||
/// Vector multiplication (conjunction) of 2 vectors, i.e.:
|
||||
/// a(i) = b(i) * c(i)
|
||||
/// which should form the single lattice point
|
||||
/// {
|
||||
/// lat( i_00 i_01 / (sparse_tensor_0 * dense_tensor_1) )
|
||||
/// }
|
||||
/// it should be optimized to
|
||||
/// {
|
||||
/// lat( i_00 / (sparse_tensor_0 * dense_tensor_1) )
|
||||
/// }
|
||||
/// since i_01 is a dense dimension.
|
||||
#define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP) \
|
||||
TEST_F(MergerTest3T1LD, vector_opted_##OP) { \
|
||||
auto e = OP##Expr(t0, t1); \
|
||||
auto p0 = tensorPattern(t0); \
|
||||
auto p1 = tensorPattern(t1); \
|
||||
auto s = merger.buildLattices(e, l0); \
|
||||
\
|
||||
expectNumLatPoints(s, 1); \
|
||||
expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}})); \
|
||||
\
|
||||
s = merger.optimizeSet(s); \
|
||||
expectNumLatPoints(s, 1); \
|
||||
expectLatPoint(s, lat(0), OP##Pattern(p0, p1), loopsToBits({{l0, t0}}), \
|
||||
true); \
|
||||
}
|
||||
|
||||
FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)
|
||||
|
||||
#undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
|
||||
|
||||
// TODO: mult-dim tests
|
||||
|
|
Loading…
Reference in New Issue