Exposing logical operators in EDSC all the way up to Python.

PiperOrigin-RevId: 232299839
This commit is contained in:
Dimitrios Vytiniotis 2019-02-04 08:15:46 -08:00 committed by jpienaar
parent b26900dce5
commit 9ca0691b06
6 changed files with 113 additions and 30 deletions

View File

@ -388,9 +388,18 @@ PYBIND11_MODULE(pybind, m) {
DEFINE_PYBIND_BINARY_OP("LE", LE); DEFINE_PYBIND_BINARY_OP("LE", LE);
DEFINE_PYBIND_BINARY_OP("GT", GT); DEFINE_PYBIND_BINARY_OP("GT", GT);
DEFINE_PYBIND_BINARY_OP("GE", GE); DEFINE_PYBIND_BINARY_OP("GE", GE);
DEFINE_PYBIND_BINARY_OP("And", And);
DEFINE_PYBIND_BINARY_OP("Or", Or);
#undef DEFINE_PYBIND_BINARY_OP #undef DEFINE_PYBIND_BINARY_OP
#define DEFINE_PYBIND_UNARY_OP(PYTHON_NAME, C_NAME) \
m.def(PYTHON_NAME, [](PythonExpr e1) { return PythonExpr(::C_NAME(e1)); });
DEFINE_PYBIND_UNARY_OP("Negate", Negate);
#undef DEFINE_PYBIND_UNARY_OP
py::class_<PythonFunction>(m, "Function", py::class_<PythonFunction>(m, "Function",
"Wrapping class for mlir::Function.") "Wrapping class for mlir::Function.")
.def(py::init<PythonFunction>()) .def(py::init<PythonFunction>())

View File

@ -90,6 +90,14 @@ class EdscTest(unittest.TestCase):
str = stmt.__str__() str = stmt.__str__()
self.assertIn("(($1 + ($2 * $3)) - $4)", str) self.assertIn("(($1 + ($2 * $3)) - $4)", str)
def testBoolean(self):
with E.ContextManager():
i, j, k, l = list(map(E.Expr, [E.Bindable() for _ in range(4)]))
stmt1 = E.And(i < j, j < k)
stmt2 = E.Negate(E.Or(stmt1, k < l))
str = stmt2.__str__()
self.assertIn("~((($1 < $2) && ($2 < $3)) || ($3 < $4))", str)
def testSelect(self): def testSelect(self):
with E.ContextManager(): with E.ContextManager():
i, j, k = list(map(E.Expr, [E.Bindable() for _ in range(3)])) i, j, k = list(map(E.Expr, [E.Bindable() for _ in range(3)]))
@ -164,6 +172,36 @@ class EdscTest(unittest.TestCase):
self.assertIn("constant 123 : i32", str) self.assertIn("constant 123 : i32", str)
self.assertIn("constant 123 : index", str) self.assertIn("constant 123 : index", str)
def testMLIRBooleanEmission(self):
module = E.MLIRModule()
t = module.make_scalar_type("i", 1)
m = module.make_memref_type(t, [10]) # i1 tensor
f = module.make_function("mkbooltensor", [m, m], [])
with E.ContextManager():
emitter = E.MLIRFunctionEmitter(f)
input, output = list(map(E.Indexed, emitter.bind_function_arguments()))
i = E.Expr(E.Bindable())
j = E.Expr(E.Bindable())
k = E.Expr(E.Bindable())
idxs = [i, j, k]
zero = emitter.bind_constant_index(0)
one = emitter.bind_constant_index(1)
ten = emitter.bind_constant_index(10)
b1 = E.And(i < j, j < k)
b2 = E.Negate(b1)
b3 = E.Or(b2, k < j)
loop = E.Block([
E.For(idxs, [zero]*3, [ten]*3, [one]*3, [
output.store([i], E.And(input.load([i]), b3))
]),
E.Return()
])
emitter.emit(loop)
# str = f.__str__()
# print(str)
module.compile()
self.assertNotEqual(module.get_engine_address(), 0)
# TODO(ntv): support symbolic For bounds with EDSCs # TODO(ntv): support symbolic For bounds with EDSCs
def testMLIREmission(self): def testMLIREmission(self):
shape = [3, 4, 5] shape = [3, 4, 5]
@ -200,6 +238,5 @@ class EdscTest(unittest.TestCase):
module.compile() module.compile()
self.assertNotEqual(module.get_engine_address(), 0) self.assertNotEqual(module.get_engine_address(), 0)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -241,6 +241,10 @@ edsc_expr_t LE(edsc_expr_t e1, edsc_expr_t e2);
edsc_expr_t GT(edsc_expr_t e1, edsc_expr_t e2); edsc_expr_t GT(edsc_expr_t e1, edsc_expr_t e2);
edsc_expr_t GE(edsc_expr_t e1, edsc_expr_t e2); edsc_expr_t GE(edsc_expr_t e1, edsc_expr_t e2);
edsc_expr_t And(edsc_expr_t e1, edsc_expr_t e2);
edsc_expr_t Or(edsc_expr_t e1, edsc_expr_t e2);
edsc_expr_t Negate(edsc_expr_t e);
#ifdef __cplusplus #ifdef __cplusplus
} // end extern "C" } // end extern "C"
#endif #endif

View File

@ -178,9 +178,6 @@ public:
Expr(const Expr &other) = default; Expr(const Expr &other) = default;
Expr &operator=(const Expr &other) = default; Expr &operator=(const Expr &other) = default;
explicit operator bool() { return storage; }
bool operator!() { return storage == nullptr; }
template <typename U> bool isa() const; template <typename U> bool isa() const;
template <typename U> U dyn_cast() const; template <typename U> U dyn_cast() const;
template <typename U> U cast() const; template <typename U> U cast() const;
@ -206,8 +203,10 @@ public:
Expr operator<=(Expr other) const; Expr operator<=(Expr other) const;
Expr operator>(Expr other) const; Expr operator>(Expr other) const;
Expr operator>=(Expr other) const; Expr operator>=(Expr other) const;
/// NB: Unlike boolean && and || these do not short-circuit.
Expr operator&&(Expr other) const; Expr operator&&(Expr other) const;
Expr operator||(Expr other) const; Expr operator||(Expr other) const;
Expr operator!() const;
/// For debugging purposes. /// For debugging purposes.
const void *getStoragePtr() const { return storage; } const void *getStoragePtr() const { return storage; }
@ -376,12 +375,13 @@ struct Stmt {
Expr operator-(Stmt other) const { return getLHS() - other.getLHS(); } Expr operator-(Stmt other) const { return getLHS() - other.getLHS(); }
Expr operator*(Stmt other) const { return getLHS() * other.getLHS(); } Expr operator*(Stmt other) const { return getLHS() * other.getLHS(); }
Expr operator<(Stmt other) const { return getLHS() + other.getLHS(); } Expr operator<(Stmt other) const { return getLHS() < other.getLHS(); }
Expr operator<=(Stmt other) const { return getLHS() + other.getLHS(); } Expr operator<=(Stmt other) const { return getLHS() <= other.getLHS(); }
Expr operator>(Stmt other) const { return getLHS() + other.getLHS(); } Expr operator>(Stmt other) const { return getLHS() > other.getLHS(); }
Expr operator>=(Stmt other) const { return getLHS() + other.getLHS(); } Expr operator>=(Stmt other) const { return getLHS() >= other.getLHS(); }
Expr operator&&(Stmt other) const { return getLHS() + other.getLHS(); } Expr operator&&(Stmt other) const { return getLHS() && other.getLHS(); }
Expr operator||(Stmt other) const { return getLHS() + other.getLHS(); } Expr operator||(Stmt other) const { return getLHS() || other.getLHS(); }
Expr operator!() const { return !getLHS(); }
protected: protected:
ImplType *storage; ImplType *storage;

View File

@ -174,10 +174,21 @@ Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) {
if (un.getKind() == ExprKind::Dealloc) { if (un.getKind() == ExprKind::Dealloc) {
builder->create<DeallocOp>(location, emitExpr(un.getExpr())); builder->create<DeallocOp>(location, emitExpr(un.getExpr()));
return nullptr; return nullptr;
} else if (un.getKind() == ExprKind::Negate) {
auto ctrue = builder->create<mlir::ConstantIntOp>(location, 1,
builder->getI1Type());
// TODO(dvytin): worth binding constant in ssaBindings in the future?
// TODO(dvytin): no need to cast getExpr() to I1?
auto val = emitExpr(un.getExpr());
assert(val->getType().isInteger(1) &&
"Logical Negate expects i1 operand");
return sub(builder, location, ctrue, val);
} }
} else if (auto bin = e.dyn_cast<BinaryExpr>()) { } else if (auto bin = e.dyn_cast<BinaryExpr>()) {
auto *a = emitExpr(bin.getLHS()); auto lhs = bin.getLHS();
auto *b = emitExpr(bin.getRHS()); auto rhs = bin.getRHS();
auto *a = emitExpr(lhs);
auto *b = emitExpr(rhs);
if (!a || !b) { if (!a || !b) {
return nullptr; return nullptr;
} }
@ -187,22 +198,19 @@ Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) {
res = sub(builder, location, a, b); res = sub(builder, location, a, b);
} else if (bin.getKind() == ExprKind::Mul) { } else if (bin.getKind() == ExprKind::Mul) {
res = mul(builder, location, a, b); res = mul(builder, location, a, b);
} } else if (bin.getKind() == ExprKind::And) {
// Vanilla comparisons operators. // Operands should both be i1
// else if (bin.getKind() == ExprKind::And) { assert(a->getType().isInteger(1) && "Logical And expects i1 LHS");
// // impl i1 assert(b->getType().isInteger(1) && "Logical And expects i1 RHS");
// res = add(builder, location, a, b); // MulIOp on i1 res = mul(builder, location, a, b);
// } } else if (bin.getKind() == ExprKind::Or) {
// else if (bin.getKind() == ExprKind::Not) { assert(a->getType().isInteger(1) && "Logical Or expects i1 LHS");
// res = ...; // 1 - cast<i1>() assert(b->getType().isInteger(1) && "Logical Or expects i1 RHS");
// } // a || b = not (not a && not b)
// else if (bin.getKind() == ExprKind::Or) { res = emitExpr(!(!lhs && !rhs));
// res = ...; // not(not(a) and not(b)) } // TODO(ntv): signed vs unsiged ??
// } // TODO(ntv): integer vs not ??
// TODO(ntv): float cmp
// TODO(ntv): signed vs unsiged ??
// TODO(ntv): integer vs not ??
// TODO(ntv): float cmp
else if (bin.getKind() == ExprKind::EQ) { else if (bin.getKind() == ExprKind::EQ) {
res = builder->create<CmpIOp>(location, mlir::CmpIPredicate::EQ, a, b); res = builder->create<CmpIOp>(location, mlir::CmpIPredicate::EQ, a, b);
} else if (bin.getKind() == ExprKind::NE) { } else if (bin.getKind() == ExprKind::NE) {
@ -628,5 +636,14 @@ DEFINE_EDSL_BINARY_OP(LT, <);
DEFINE_EDSL_BINARY_OP(LE, <=); DEFINE_EDSL_BINARY_OP(LE, <=);
DEFINE_EDSL_BINARY_OP(GT, >); DEFINE_EDSL_BINARY_OP(GT, >);
DEFINE_EDSL_BINARY_OP(GE, >=); DEFINE_EDSL_BINARY_OP(GE, >=);
DEFINE_EDSL_BINARY_OP(And, &&);
DEFINE_EDSL_BINARY_OP(Or, ||);
#undef DEFINE_EDSL_BINARY_OP #undef DEFINE_EDSL_BINARY_OP
#define DEFINE_EDSL_UNARY_OP(FUN_NAME, OP_SYMBOL) \
edsc_expr_t FUN_NAME(edsc_expr_t e) { return (OP_SYMBOL(Expr(e))); }
DEFINE_EDSL_UNARY_OP(Negate, !);
#undef DEFINE_EDSL_UNARY_OP

View File

@ -142,8 +142,10 @@ Expr mlir::edsc::Expr::operator&&(Expr other) const {
Expr mlir::edsc::Expr::operator||(Expr other) const { Expr mlir::edsc::Expr::operator||(Expr other) const {
return BinaryExpr(ExprKind::Or, *this, other); return BinaryExpr(ExprKind::Or, *this, other);
} }
Expr mlir::edsc::Expr::operator!() const {
return UnaryExpr(ExprKind::Negate, *this);
}
// Free functions.
llvm::SmallVector<Expr, 8> mlir::edsc::makeNewExprs(unsigned n) { llvm::SmallVector<Expr, 8> mlir::edsc::makeNewExprs(unsigned n) {
llvm::SmallVector<Expr, 8> res; llvm::SmallVector<Expr, 8> res;
res.reserve(n); res.reserve(n);
@ -288,7 +290,15 @@ void mlir::edsc::Expr::print(raw_ostream &os) const {
os << "$" << unbound.getId(); os << "$" << unbound.getId();
return; return;
} else if (auto un = this->dyn_cast<UnaryExpr>()) { } else if (auto un = this->dyn_cast<UnaryExpr>()) {
os << "unknown_unary"; switch (un.getKind()) {
case ExprKind::Negate:
os << "~";
break;
default: {
os << "unknown_unary";
}
}
os << un.getExpr();
} else if (auto bin = this->dyn_cast<BinaryExpr>()) { } else if (auto bin = this->dyn_cast<BinaryExpr>()) {
os << "(" << bin.getLHS(); os << "(" << bin.getLHS();
switch (bin.getKind()) { switch (bin.getKind()) {
@ -316,6 +326,12 @@ void mlir::edsc::Expr::print(raw_ostream &os) const {
case ExprKind::GE: case ExprKind::GE:
os << " >= "; os << " >= ";
break; break;
case ExprKind::And:
os << " && ";
break;
case ExprKind::Or:
os << " || ";
break;
default: { default: {
os << "unknown_binary"; os << "unknown_binary";
} }