forked from OSchip/llvm-project
Exposing logical operators in EDSC all the way up to Python.
PiperOrigin-RevId: 232299839
This commit is contained in:
parent
b26900dce5
commit
9ca0691b06
|
@ -388,9 +388,18 @@ PYBIND11_MODULE(pybind, m) {
|
||||||
DEFINE_PYBIND_BINARY_OP("LE", LE);
|
DEFINE_PYBIND_BINARY_OP("LE", LE);
|
||||||
DEFINE_PYBIND_BINARY_OP("GT", GT);
|
DEFINE_PYBIND_BINARY_OP("GT", GT);
|
||||||
DEFINE_PYBIND_BINARY_OP("GE", GE);
|
DEFINE_PYBIND_BINARY_OP("GE", GE);
|
||||||
|
DEFINE_PYBIND_BINARY_OP("And", And);
|
||||||
|
DEFINE_PYBIND_BINARY_OP("Or", Or);
|
||||||
|
|
||||||
#undef DEFINE_PYBIND_BINARY_OP
|
#undef DEFINE_PYBIND_BINARY_OP
|
||||||
|
|
||||||
|
#define DEFINE_PYBIND_UNARY_OP(PYTHON_NAME, C_NAME) \
|
||||||
|
m.def(PYTHON_NAME, [](PythonExpr e1) { return PythonExpr(::C_NAME(e1)); });
|
||||||
|
|
||||||
|
DEFINE_PYBIND_UNARY_OP("Negate", Negate);
|
||||||
|
|
||||||
|
#undef DEFINE_PYBIND_UNARY_OP
|
||||||
|
|
||||||
py::class_<PythonFunction>(m, "Function",
|
py::class_<PythonFunction>(m, "Function",
|
||||||
"Wrapping class for mlir::Function.")
|
"Wrapping class for mlir::Function.")
|
||||||
.def(py::init<PythonFunction>())
|
.def(py::init<PythonFunction>())
|
||||||
|
|
|
@ -90,6 +90,14 @@ class EdscTest(unittest.TestCase):
|
||||||
str = stmt.__str__()
|
str = stmt.__str__()
|
||||||
self.assertIn("(($1 + ($2 * $3)) - $4)", str)
|
self.assertIn("(($1 + ($2 * $3)) - $4)", str)
|
||||||
|
|
||||||
|
def testBoolean(self):
|
||||||
|
with E.ContextManager():
|
||||||
|
i, j, k, l = list(map(E.Expr, [E.Bindable() for _ in range(4)]))
|
||||||
|
stmt1 = E.And(i < j, j < k)
|
||||||
|
stmt2 = E.Negate(E.Or(stmt1, k < l))
|
||||||
|
str = stmt2.__str__()
|
||||||
|
self.assertIn("~((($1 < $2) && ($2 < $3)) || ($3 < $4))", str)
|
||||||
|
|
||||||
def testSelect(self):
|
def testSelect(self):
|
||||||
with E.ContextManager():
|
with E.ContextManager():
|
||||||
i, j, k = list(map(E.Expr, [E.Bindable() for _ in range(3)]))
|
i, j, k = list(map(E.Expr, [E.Bindable() for _ in range(3)]))
|
||||||
|
@ -164,6 +172,36 @@ class EdscTest(unittest.TestCase):
|
||||||
self.assertIn("constant 123 : i32", str)
|
self.assertIn("constant 123 : i32", str)
|
||||||
self.assertIn("constant 123 : index", str)
|
self.assertIn("constant 123 : index", str)
|
||||||
|
|
||||||
|
def testMLIRBooleanEmission(self):
|
||||||
|
module = E.MLIRModule()
|
||||||
|
t = module.make_scalar_type("i", 1)
|
||||||
|
m = module.make_memref_type(t, [10]) # i1 tensor
|
||||||
|
f = module.make_function("mkbooltensor", [m, m], [])
|
||||||
|
with E.ContextManager():
|
||||||
|
emitter = E.MLIRFunctionEmitter(f)
|
||||||
|
input, output = list(map(E.Indexed, emitter.bind_function_arguments()))
|
||||||
|
i = E.Expr(E.Bindable())
|
||||||
|
j = E.Expr(E.Bindable())
|
||||||
|
k = E.Expr(E.Bindable())
|
||||||
|
idxs = [i, j, k]
|
||||||
|
zero = emitter.bind_constant_index(0)
|
||||||
|
one = emitter.bind_constant_index(1)
|
||||||
|
ten = emitter.bind_constant_index(10)
|
||||||
|
b1 = E.And(i < j, j < k)
|
||||||
|
b2 = E.Negate(b1)
|
||||||
|
b3 = E.Or(b2, k < j)
|
||||||
|
loop = E.Block([
|
||||||
|
E.For(idxs, [zero]*3, [ten]*3, [one]*3, [
|
||||||
|
output.store([i], E.And(input.load([i]), b3))
|
||||||
|
]),
|
||||||
|
E.Return()
|
||||||
|
])
|
||||||
|
emitter.emit(loop)
|
||||||
|
# str = f.__str__()
|
||||||
|
# print(str)
|
||||||
|
module.compile()
|
||||||
|
self.assertNotEqual(module.get_engine_address(), 0)
|
||||||
|
|
||||||
# TODO(ntv): support symbolic For bounds with EDSCs
|
# TODO(ntv): support symbolic For bounds with EDSCs
|
||||||
def testMLIREmission(self):
|
def testMLIREmission(self):
|
||||||
shape = [3, 4, 5]
|
shape = [3, 4, 5]
|
||||||
|
@ -200,6 +238,5 @@ class EdscTest(unittest.TestCase):
|
||||||
module.compile()
|
module.compile()
|
||||||
self.assertNotEqual(module.get_engine_address(), 0)
|
self.assertNotEqual(module.get_engine_address(), 0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -241,6 +241,10 @@ edsc_expr_t LE(edsc_expr_t e1, edsc_expr_t e2);
|
||||||
edsc_expr_t GT(edsc_expr_t e1, edsc_expr_t e2);
|
edsc_expr_t GT(edsc_expr_t e1, edsc_expr_t e2);
|
||||||
edsc_expr_t GE(edsc_expr_t e1, edsc_expr_t e2);
|
edsc_expr_t GE(edsc_expr_t e1, edsc_expr_t e2);
|
||||||
|
|
||||||
|
edsc_expr_t And(edsc_expr_t e1, edsc_expr_t e2);
|
||||||
|
edsc_expr_t Or(edsc_expr_t e1, edsc_expr_t e2);
|
||||||
|
edsc_expr_t Negate(edsc_expr_t e);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
} // end extern "C"
|
} // end extern "C"
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -178,9 +178,6 @@ public:
|
||||||
Expr(const Expr &other) = default;
|
Expr(const Expr &other) = default;
|
||||||
Expr &operator=(const Expr &other) = default;
|
Expr &operator=(const Expr &other) = default;
|
||||||
|
|
||||||
explicit operator bool() { return storage; }
|
|
||||||
bool operator!() { return storage == nullptr; }
|
|
||||||
|
|
||||||
template <typename U> bool isa() const;
|
template <typename U> bool isa() const;
|
||||||
template <typename U> U dyn_cast() const;
|
template <typename U> U dyn_cast() const;
|
||||||
template <typename U> U cast() const;
|
template <typename U> U cast() const;
|
||||||
|
@ -206,8 +203,10 @@ public:
|
||||||
Expr operator<=(Expr other) const;
|
Expr operator<=(Expr other) const;
|
||||||
Expr operator>(Expr other) const;
|
Expr operator>(Expr other) const;
|
||||||
Expr operator>=(Expr other) const;
|
Expr operator>=(Expr other) const;
|
||||||
|
/// NB: Unlike boolean && and || these do not short-circuit.
|
||||||
Expr operator&&(Expr other) const;
|
Expr operator&&(Expr other) const;
|
||||||
Expr operator||(Expr other) const;
|
Expr operator||(Expr other) const;
|
||||||
|
Expr operator!() const;
|
||||||
|
|
||||||
/// For debugging purposes.
|
/// For debugging purposes.
|
||||||
const void *getStoragePtr() const { return storage; }
|
const void *getStoragePtr() const { return storage; }
|
||||||
|
@ -376,12 +375,13 @@ struct Stmt {
|
||||||
Expr operator-(Stmt other) const { return getLHS() - other.getLHS(); }
|
Expr operator-(Stmt other) const { return getLHS() - other.getLHS(); }
|
||||||
Expr operator*(Stmt other) const { return getLHS() * other.getLHS(); }
|
Expr operator*(Stmt other) const { return getLHS() * other.getLHS(); }
|
||||||
|
|
||||||
Expr operator<(Stmt other) const { return getLHS() + other.getLHS(); }
|
Expr operator<(Stmt other) const { return getLHS() < other.getLHS(); }
|
||||||
Expr operator<=(Stmt other) const { return getLHS() + other.getLHS(); }
|
Expr operator<=(Stmt other) const { return getLHS() <= other.getLHS(); }
|
||||||
Expr operator>(Stmt other) const { return getLHS() + other.getLHS(); }
|
Expr operator>(Stmt other) const { return getLHS() > other.getLHS(); }
|
||||||
Expr operator>=(Stmt other) const { return getLHS() + other.getLHS(); }
|
Expr operator>=(Stmt other) const { return getLHS() >= other.getLHS(); }
|
||||||
Expr operator&&(Stmt other) const { return getLHS() + other.getLHS(); }
|
Expr operator&&(Stmt other) const { return getLHS() && other.getLHS(); }
|
||||||
Expr operator||(Stmt other) const { return getLHS() + other.getLHS(); }
|
Expr operator||(Stmt other) const { return getLHS() || other.getLHS(); }
|
||||||
|
Expr operator!() const { return !getLHS(); }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
ImplType *storage;
|
ImplType *storage;
|
||||||
|
|
|
@ -174,10 +174,21 @@ Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) {
|
||||||
if (un.getKind() == ExprKind::Dealloc) {
|
if (un.getKind() == ExprKind::Dealloc) {
|
||||||
builder->create<DeallocOp>(location, emitExpr(un.getExpr()));
|
builder->create<DeallocOp>(location, emitExpr(un.getExpr()));
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
} else if (un.getKind() == ExprKind::Negate) {
|
||||||
|
auto ctrue = builder->create<mlir::ConstantIntOp>(location, 1,
|
||||||
|
builder->getI1Type());
|
||||||
|
// TODO(dvytin): worth binding constant in ssaBindings in the future?
|
||||||
|
// TODO(dvytin): no need to cast getExpr() to I1?
|
||||||
|
auto val = emitExpr(un.getExpr());
|
||||||
|
assert(val->getType().isInteger(1) &&
|
||||||
|
"Logical Negate expects i1 operand");
|
||||||
|
return sub(builder, location, ctrue, val);
|
||||||
}
|
}
|
||||||
} else if (auto bin = e.dyn_cast<BinaryExpr>()) {
|
} else if (auto bin = e.dyn_cast<BinaryExpr>()) {
|
||||||
auto *a = emitExpr(bin.getLHS());
|
auto lhs = bin.getLHS();
|
||||||
auto *b = emitExpr(bin.getRHS());
|
auto rhs = bin.getRHS();
|
||||||
|
auto *a = emitExpr(lhs);
|
||||||
|
auto *b = emitExpr(rhs);
|
||||||
if (!a || !b) {
|
if (!a || !b) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -187,22 +198,19 @@ Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) {
|
||||||
res = sub(builder, location, a, b);
|
res = sub(builder, location, a, b);
|
||||||
} else if (bin.getKind() == ExprKind::Mul) {
|
} else if (bin.getKind() == ExprKind::Mul) {
|
||||||
res = mul(builder, location, a, b);
|
res = mul(builder, location, a, b);
|
||||||
}
|
} else if (bin.getKind() == ExprKind::And) {
|
||||||
// Vanilla comparisons operators.
|
// Operands should both be i1
|
||||||
// else if (bin.getKind() == ExprKind::And) {
|
assert(a->getType().isInteger(1) && "Logical And expects i1 LHS");
|
||||||
// // impl i1
|
assert(b->getType().isInteger(1) && "Logical And expects i1 RHS");
|
||||||
// res = add(builder, location, a, b); // MulIOp on i1
|
res = mul(builder, location, a, b);
|
||||||
// }
|
} else if (bin.getKind() == ExprKind::Or) {
|
||||||
// else if (bin.getKind() == ExprKind::Not) {
|
assert(a->getType().isInteger(1) && "Logical Or expects i1 LHS");
|
||||||
// res = ...; // 1 - cast<i1>()
|
assert(b->getType().isInteger(1) && "Logical Or expects i1 RHS");
|
||||||
// }
|
// a || b = not (not a && not b)
|
||||||
// else if (bin.getKind() == ExprKind::Or) {
|
res = emitExpr(!(!lhs && !rhs));
|
||||||
// res = ...; // not(not(a) and not(b))
|
} // TODO(ntv): signed vs unsiged ??
|
||||||
// }
|
// TODO(ntv): integer vs not ??
|
||||||
|
// TODO(ntv): float cmp
|
||||||
// TODO(ntv): signed vs unsiged ??
|
|
||||||
// TODO(ntv): integer vs not ??
|
|
||||||
// TODO(ntv): float cmp
|
|
||||||
else if (bin.getKind() == ExprKind::EQ) {
|
else if (bin.getKind() == ExprKind::EQ) {
|
||||||
res = builder->create<CmpIOp>(location, mlir::CmpIPredicate::EQ, a, b);
|
res = builder->create<CmpIOp>(location, mlir::CmpIPredicate::EQ, a, b);
|
||||||
} else if (bin.getKind() == ExprKind::NE) {
|
} else if (bin.getKind() == ExprKind::NE) {
|
||||||
|
@ -628,5 +636,14 @@ DEFINE_EDSL_BINARY_OP(LT, <);
|
||||||
DEFINE_EDSL_BINARY_OP(LE, <=);
|
DEFINE_EDSL_BINARY_OP(LE, <=);
|
||||||
DEFINE_EDSL_BINARY_OP(GT, >);
|
DEFINE_EDSL_BINARY_OP(GT, >);
|
||||||
DEFINE_EDSL_BINARY_OP(GE, >=);
|
DEFINE_EDSL_BINARY_OP(GE, >=);
|
||||||
|
DEFINE_EDSL_BINARY_OP(And, &&);
|
||||||
|
DEFINE_EDSL_BINARY_OP(Or, ||);
|
||||||
|
|
||||||
#undef DEFINE_EDSL_BINARY_OP
|
#undef DEFINE_EDSL_BINARY_OP
|
||||||
|
|
||||||
|
#define DEFINE_EDSL_UNARY_OP(FUN_NAME, OP_SYMBOL) \
|
||||||
|
edsc_expr_t FUN_NAME(edsc_expr_t e) { return (OP_SYMBOL(Expr(e))); }
|
||||||
|
|
||||||
|
DEFINE_EDSL_UNARY_OP(Negate, !);
|
||||||
|
|
||||||
|
#undef DEFINE_EDSL_UNARY_OP
|
||||||
|
|
|
@ -142,8 +142,10 @@ Expr mlir::edsc::Expr::operator&&(Expr other) const {
|
||||||
Expr mlir::edsc::Expr::operator||(Expr other) const {
|
Expr mlir::edsc::Expr::operator||(Expr other) const {
|
||||||
return BinaryExpr(ExprKind::Or, *this, other);
|
return BinaryExpr(ExprKind::Or, *this, other);
|
||||||
}
|
}
|
||||||
|
Expr mlir::edsc::Expr::operator!() const {
|
||||||
|
return UnaryExpr(ExprKind::Negate, *this);
|
||||||
|
}
|
||||||
|
|
||||||
// Free functions.
|
|
||||||
llvm::SmallVector<Expr, 8> mlir::edsc::makeNewExprs(unsigned n) {
|
llvm::SmallVector<Expr, 8> mlir::edsc::makeNewExprs(unsigned n) {
|
||||||
llvm::SmallVector<Expr, 8> res;
|
llvm::SmallVector<Expr, 8> res;
|
||||||
res.reserve(n);
|
res.reserve(n);
|
||||||
|
@ -288,7 +290,15 @@ void mlir::edsc::Expr::print(raw_ostream &os) const {
|
||||||
os << "$" << unbound.getId();
|
os << "$" << unbound.getId();
|
||||||
return;
|
return;
|
||||||
} else if (auto un = this->dyn_cast<UnaryExpr>()) {
|
} else if (auto un = this->dyn_cast<UnaryExpr>()) {
|
||||||
os << "unknown_unary";
|
switch (un.getKind()) {
|
||||||
|
case ExprKind::Negate:
|
||||||
|
os << "~";
|
||||||
|
break;
|
||||||
|
default: {
|
||||||
|
os << "unknown_unary";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
os << un.getExpr();
|
||||||
} else if (auto bin = this->dyn_cast<BinaryExpr>()) {
|
} else if (auto bin = this->dyn_cast<BinaryExpr>()) {
|
||||||
os << "(" << bin.getLHS();
|
os << "(" << bin.getLHS();
|
||||||
switch (bin.getKind()) {
|
switch (bin.getKind()) {
|
||||||
|
@ -316,6 +326,12 @@ void mlir::edsc::Expr::print(raw_ostream &os) const {
|
||||||
case ExprKind::GE:
|
case ExprKind::GE:
|
||||||
os << " >= ";
|
os << " >= ";
|
||||||
break;
|
break;
|
||||||
|
case ExprKind::And:
|
||||||
|
os << " && ";
|
||||||
|
break;
|
||||||
|
case ExprKind::Or:
|
||||||
|
os << " || ";
|
||||||
|
break;
|
||||||
default: {
|
default: {
|
||||||
os << "unknown_binary";
|
os << "unknown_binary";
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue