Introduce SDBMDirect expression into the SDBM expression hierarchy

Direct expressions are those that do not negate any of the variables they
involve.  They include input expressions (dimensions and symbols), stripe and
sum expressions, and combinations of those.  Reifying direct expressions as a
class is a precondition for enabling additions on the LHS of a stripe
expression.

PiperOrigin-RevId: 269336031
This commit is contained in:
Alex Zinenko 2019-09-16 08:08:22 -07:00 committed by A. Unique TensorFlower
parent 0ce64b0bf3
commit e94db619d9
2 changed files with 39 additions and 11 deletions

View File

@ -176,14 +176,30 @@ public:
}
};
/// SDBM direct expression includes exactly one variable (symbol or dimension),
/// which is not negated in the expression. It can be one of:
/// - term expression;
/// - sum expression.
class SDBMDirectExpr : public SDBMVaryingExpr {
public:
using SDBMVaryingExpr::SDBMVaryingExpr;
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::DimId ||
expr.getKind() == SDBMExprKind::SymbolId ||
expr.getKind() == SDBMExprKind::Stripe ||
expr.getKind() == SDBMExprKind::Add;
}
};
/// SDBM term expression can be one of:
/// - single variable expression;
/// - stripe expression.
/// Stripe expressions are treated as terms since, in the SDBM domain, they are
/// attached to temporary variables and can appear anywhere a variable can.
class SDBMTermExpr : public SDBMVaryingExpr {
class SDBMTermExpr : public SDBMDirectExpr {
public:
using SDBMVaryingExpr::SDBMVaryingExpr;
using SDBMDirectExpr::SDBMDirectExpr;
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::DimId ||
@ -194,10 +210,10 @@ public:
/// SDBM sum expression. LHS is a varying expression and RHS is always a
/// constant expression.
class SDBMSumExpr : public SDBMVaryingExpr {
class SDBMSumExpr : public SDBMDirectExpr {
public:
using ImplType = detail::SDBMBinaryExprStorage;
using SDBMVaryingExpr::SDBMVaryingExpr;
using SDBMDirectExpr::SDBMDirectExpr;
/// Obtain or create a sum expression unique'ed in the given context.
static SDBMSumExpr get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs);
@ -352,9 +368,19 @@ protected:
void visitNeg(SDBMNegExpr) {}
void visitConstant(SDBMConstantExpr) {}
/// Default implementation of visitTerm dispatches to the special
/// functions for stripes and other variables. Concrete visitors can override
/// it.
/// Default implementation of visitDirect dispatches to the dedicated for sums
/// or delegates to visitTerm for the other expression kinds. Concrete
/// visitors can overload it.
Result visitDirect(SDBMDirectExpr expr) {
auto *derived = static_cast<Derived *>(this);
if (auto sum = expr.dyn_cast<SDBMSumExpr>())
return derived->visitSum(sum);
else
return derived->visitTerm(expr.cast<SDBMTermExpr>());
}
/// Default implementation of visitTerm dispatches to the special functions
/// for stripes and other variables. Concrete visitors can override it.
Result visitTerm(SDBMTermExpr expr) {
auto *derived = static_cast<Derived *>(this);
if (expr.getKind() == SDBMExprKind::Stripe)
@ -379,12 +405,10 @@ protected:
/// override it to visit all variables and negations instead.
Result visitVarying(SDBMVaryingExpr expr) {
auto *derived = static_cast<Derived *>(this);
if (auto var = expr.dyn_cast<SDBMTermExpr>())
return derived->visitTerm(var);
if (auto var = expr.dyn_cast<SDBMDirectExpr>())
return derived->visitDirect(var);
else if (auto neg = expr.dyn_cast<SDBMNegExpr>())
return derived->visitNeg(neg);
else if (auto sum = expr.dyn_cast<SDBMSumExpr>())
return derived->visitSum(sum);
else if (auto diff = expr.dyn_cast<SDBMDiffExpr>())
return derived->visitDiff(diff);

View File

@ -174,6 +174,7 @@ TEST(SDBMExpr, Dim) {
EXPECT_TRUE(generic.isa<SDBMDimExpr>());
EXPECT_TRUE(generic.isa<SDBMInputExpr>());
EXPECT_TRUE(generic.isa<SDBMTermExpr>());
EXPECT_TRUE(generic.isa<SDBMDirectExpr>());
EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
// Dimensions are not Symbols.
@ -196,6 +197,7 @@ TEST(SDBMExpr, Symbol) {
EXPECT_TRUE(generic.isa<SDBMSymbolExpr>());
EXPECT_TRUE(generic.isa<SDBMInputExpr>());
EXPECT_TRUE(generic.isa<SDBMTermExpr>());
EXPECT_TRUE(generic.isa<SDBMDirectExpr>());
EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
// Dimensions are not Symbols.
@ -229,6 +231,7 @@ TEST(SDBMExpr, Stripe) {
auto generic = static_cast<SDBMExpr>(expr);
EXPECT_TRUE(generic.isa<SDBMStripeExpr>());
EXPECT_TRUE(generic.isa<SDBMTermExpr>());
EXPECT_TRUE(generic.isa<SDBMDirectExpr>());
EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
}
@ -271,6 +274,7 @@ TEST(SDBMExpr, Sum) {
// Hierarchy is okay.
auto generic = static_cast<SDBMExpr>(expr);
EXPECT_TRUE(generic.isa<SDBMSumExpr>());
EXPECT_TRUE(generic.isa<SDBMDirectExpr>());
EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
}