From 1cc9305c71e425d20b87d1e6e2e397d4141fecfd Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 21 Feb 2019 03:09:51 -0800 Subject: [PATCH] Exposed division and remainder operations in EDSC This change introduces three new operators in EDSC: Div (also exposed via Expr.__div__ aka /) -- floating-point division, FloorDiv and CeilDiv for flooring/ceiling index division. The lowering to LLVM will be implemented in b/124872679. PiperOrigin-RevId: 234963217 --- mlir/bindings/python/pybind.cpp | 15 ++++- mlir/bindings/python/test/test_py2and3.py | 4 +- mlir/include/mlir-c/Core.h | 6 +- mlir/include/mlir/EDSC/Types.h | 5 ++ mlir/include/mlir/StandardOps/standard_ops.td | 8 +++ mlir/lib/EDSC/MLIREmitter.cpp | 11 +++- mlir/lib/EDSC/Types.cpp | 55 +++++++++++++++---- 7 files changed, 86 insertions(+), 18 deletions(-) diff --git a/mlir/bindings/python/pybind.cpp b/mlir/bindings/python/pybind.cpp index 3d8f9a5761d9..ef2a4d92dcc8 100644 --- a/mlir/bindings/python/pybind.cpp +++ b/mlir/bindings/python/pybind.cpp @@ -494,7 +494,10 @@ PYBIND11_MODULE(pybind, m) { DEFINE_PYBIND_BINARY_OP("Add", Add); DEFINE_PYBIND_BINARY_OP("Mul", Mul); DEFINE_PYBIND_BINARY_OP("Sub", Sub); - // DEFINE_PYBIND_BINARY_OP("Div", Div); + DEFINE_PYBIND_BINARY_OP("Div", Div); + DEFINE_PYBIND_BINARY_OP("Rem", Rem); + DEFINE_PYBIND_BINARY_OP("FloorDiv", FloorDiv); + DEFINE_PYBIND_BINARY_OP("CeilDiv", CeilDiv); DEFINE_PYBIND_BINARY_OP("LT", LT); DEFINE_PYBIND_BINARY_OP("LE", LE); DEFINE_PYBIND_BINARY_OP("GT", GT); @@ -649,8 +652,14 @@ PYBIND11_MODULE(pybind, m) { PythonExpr e2) { return PythonExpr(::Sub(e1, e2)); }) .def("__mul__", [](PythonExpr e1, PythonExpr e2) { return PythonExpr(::Mul(e1, e2)); }) - // .def("__div__", [](PythonExpr e1, PythonExpr e2) { return - // PythonExpr(::Div(e1, e2)); }) + .def("__div__", [](PythonExpr e1, + PythonExpr e2) { return PythonExpr(::Div(e1, e2)); }) + .def("__truediv__", + [](PythonExpr e1, PythonExpr e2) { + return PythonExpr(::Div(e1, e2)); + }) + .def("__mod__", [](PythonExpr e1, + PythonExpr e2) { return PythonExpr(::Rem(e1, e2)); }) .def("__lt__", [](PythonExpr e1, PythonExpr e2) { return PythonExpr(::LT(e1, e2)); }) .def("__le__", [](PythonExpr e1, diff --git a/mlir/bindings/python/test/test_py2and3.py b/mlir/bindings/python/test/test_py2and3.py index 4fed50300bea..c004bf6d3c60 100644 --- a/mlir/bindings/python/test/test_py2and3.py +++ b/mlir/bindings/python/test/test_py2and3.py @@ -113,9 +113,9 @@ class EdscTest(unittest.TestCase): with E.ContextManager(): i, j, k, l = list( map(E.Expr, [E.Bindable(self.f32Type) for _ in range(4)])) - stmt = i + j * k - l + stmt = i % j + j * k - l / k str = stmt.__str__() - self.assertIn("(($1 + ($2 * $3)) - $4)", str) + self.assertIn("((($1 % $2) + ($2 * $3)) - ($4 / $3))", str) def testBoolean(self): with E.ContextManager(): diff --git a/mlir/include/mlir-c/Core.h b/mlir/include/mlir-c/Core.h index c0f5286bf058..09dc5c850e27 100644 --- a/mlir/include/mlir-c/Core.h +++ b/mlir/include/mlir-c/Core.h @@ -260,7 +260,8 @@ edsc_stmt_t MaxMinFor(edsc_expr_t iv, edsc_max_expr_t lb, edsc_min_expr_t ub, edsc_expr_t Add(edsc_expr_t e1, edsc_expr_t e2); edsc_expr_t Sub(edsc_expr_t e1, edsc_expr_t e2); edsc_expr_t Mul(edsc_expr_t e1, edsc_expr_t e2); -// edsc_expr_t Div(edsc_expr_t e1, edsc_expr_t e2); +edsc_expr_t Div(edsc_expr_t e1, edsc_expr_t e2); +edsc_expr_t Rem(edsc_expr_t e1, edsc_expr_t e2); edsc_expr_t LT(edsc_expr_t e1, edsc_expr_t e2); edsc_expr_t LE(edsc_expr_t e1, edsc_expr_t e2); edsc_expr_t GT(edsc_expr_t e1, edsc_expr_t e2); @@ -268,6 +269,9 @@ edsc_expr_t GE(edsc_expr_t e1, edsc_expr_t e2); edsc_expr_t EQ(edsc_expr_t e1, edsc_expr_t e2); edsc_expr_t NE(edsc_expr_t e1, edsc_expr_t e2); +edsc_expr_t FloorDiv(edsc_expr_t e1, edsc_expr_t e2); +edsc_expr_t CeilDiv(edsc_expr_t e1, edsc_expr_t e2); + edsc_expr_t And(edsc_expr_t e1, edsc_expr_t e2); edsc_expr_t Or(edsc_expr_t e1, edsc_expr_t e2); edsc_expr_t Negate(edsc_expr_t e); diff --git a/mlir/include/mlir/EDSC/Types.h b/mlir/include/mlir/EDSC/Types.h index 6234fe62d743..bb333a93938c 100644 --- a/mlir/include/mlir/EDSC/Types.h +++ b/mlir/include/mlir/EDSC/Types.h @@ -436,6 +436,8 @@ namespace op { Expr operator+(Expr lhs, Expr rhs); Expr operator-(Expr lhs, Expr rhs); Expr operator*(Expr lhs, Expr rhs); +Expr operator/(Expr lhs, Expr rhs); +Expr operator%(Expr lhs, Expr rhs); /// In particular operator==, operator!= return a new Expr and *not* a bool. Expr operator==(Expr lhs, Expr rhs); Expr operator!=(Expr lhs, Expr rhs); @@ -479,6 +481,9 @@ inline Expr operator||(Stmt lhs, Stmt rhs) { inline Expr operator!(Stmt stmt) { return !stmt.getLHS(); } } // end namespace op +Expr floorDiv(Expr lhs, Expr rhs); +Expr ceilDiv(Expr lhs, Expr rhs); + template bool Expr::isa() const { auto kind = getKind(); if (std::is_same::value) { diff --git a/mlir/include/mlir/StandardOps/standard_ops.td b/mlir/include/mlir/StandardOps/standard_ops.td index b7f2c3976efa..08ce19e671ed 100644 --- a/mlir/include/mlir/StandardOps/standard_ops.td +++ b/mlir/include/mlir/StandardOps/standard_ops.td @@ -82,6 +82,10 @@ def AddIOp : IntArithmeticOp<"addi", [Commutative]> { let hasConstantFolder = 0b1; } +def DivFOp : FloatArithmeticOp<"divf"> { + let summary = "floating point division operation"; +} + def DivISOp : IntArithmeticOp<"divis"> { let summary = "signed integer division operation"; let hasConstantFolder = 0b1; @@ -103,6 +107,10 @@ def MulIOp : IntArithmeticOp<"muli", [Commutative]> { let hasFolder = 1; } +def RemFOp : FloatArithmeticOp<"remf"> { + let summary = "floating point division remainder operation"; +} + def RemISOp : IntArithmeticOp<"remis"> { let summary = "signed integer division remainder operation"; let hasConstantFolder = 0b1; diff --git a/mlir/lib/EDSC/MLIREmitter.cpp b/mlir/lib/EDSC/MLIREmitter.cpp index 0519c725308b..aa4ca47e6609 100644 --- a/mlir/lib/EDSC/MLIREmitter.cpp +++ b/mlir/lib/EDSC/MLIREmitter.cpp @@ -534,7 +534,8 @@ void bindMemRefView(edsc_mlir_emitter_t emitter, edsc_expr_t boundMemRef, DEFINE_EDSL_BINARY_OP(Add, +); DEFINE_EDSL_BINARY_OP(Sub, -); DEFINE_EDSL_BINARY_OP(Mul, *); -// DEFINE_EDSL_BINARY_OP(Div, /); +DEFINE_EDSL_BINARY_OP(Div, /); +DEFINE_EDSL_BINARY_OP(Rem, %); DEFINE_EDSL_BINARY_OP(LT, <); DEFINE_EDSL_BINARY_OP(LE, <=); DEFINE_EDSL_BINARY_OP(GT, >); @@ -546,6 +547,14 @@ DEFINE_EDSL_BINARY_OP(Or, ||); #undef DEFINE_EDSL_BINARY_OP +edsc_expr_t FloorDiv(edsc_expr_t e1, edsc_expr_t e2) { + return edsc::floorDiv(Expr(e1), Expr(e2)); +} + +edsc_expr_t CeilDiv(edsc_expr_t e1, edsc_expr_t e2) { + return edsc::ceilDiv(Expr(e1), Expr(e2)); +} + #define DEFINE_EDSL_UNARY_OP(FUN_NAME, OP_SYMBOL) \ edsc_expr_t FUN_NAME(edsc_expr_t e) { \ using edsc::op::operator OP_SYMBOL; \ diff --git a/mlir/lib/EDSC/Types.cpp b/mlir/lib/EDSC/Types.cpp index 2f40001d064e..8dc4db1f22b1 100644 --- a/mlir/lib/EDSC/Types.cpp +++ b/mlir/lib/EDSC/Types.cpp @@ -205,6 +205,26 @@ static AffineExpr createOperandAffineExpr(Expr e, int64_t position, return getAffineDimExpr(position, context); } +static Expr createBinaryIndexExpr( + Expr lhs, Expr rhs, + std::function affCombiner) { + assert(lhs.getResultTypes().size() == 1 && rhs.getResultTypes().size() == 1 && + "only single-result exprs are supported in operators"); + auto thisType = lhs.getResultTypes().front(); + auto thatType = rhs.getResultTypes().front(); + assert(thisType == thatType && "cannot mix types in operators"); + assert(thisType.isIndex() && "expected exprs of index type"); + MLIRContext *context = thisType.getContext(); + auto lhsAff = createOperandAffineExpr(lhs, 0, context); + auto rhsAff = createOperandAffineExpr(rhs, 1, context); + auto map = AffineMap::get(2, 0, {affCombiner(lhsAff, rhsAff)}, {}); + auto attr = AffineMapAttr::get(map); + auto attrId = Identifier::get("map", context); + auto namedAttr = NamedAttribute{attrId, attr}; + return VariadicExpr("affine.apply", {lhs, rhs}, {IndexType::get(context)}, + {namedAttr}); +} + // Create a binary expression between the two arguments emitting `IOp` if // arguments are integers or vectors/tensors thereof, `FOp` if arguments are // floating-point or vectors/tensors thereof, and `AffineApplyOp` with an @@ -220,15 +240,7 @@ static Expr createBinaryExpr( auto thatType = rhs.getResultTypes().front(); assert(thisType == thatType && "cannot mix types in operators"); if (thisType.isIndex()) { - MLIRContext *context = thisType.getContext(); - auto lhsAff = createOperandAffineExpr(lhs, 0, context); - auto rhsAff = createOperandAffineExpr(rhs, 1, context); - auto map = AffineMap::get(2, 0, {affCombiner(lhsAff, rhsAff)}, {}); - auto attr = AffineMapAttr::get(map); - auto attrId = Identifier::get("map", context); - auto namedAttr = NamedAttribute{attrId, attr}; - return VariadicExpr("affine.apply", {lhs, rhs}, {IndexType::get(context)}, - {namedAttr}); + return createBinaryIndexExpr(lhs, rhs, affCombiner); } else if (thisType.isa()) { return BinaryExpr::make(thisType, lhs, rhs); } else if (thisType.isa()) { @@ -255,6 +267,25 @@ Expr mlir::edsc::op::operator*(Expr lhs, Expr rhs) { return createBinaryExpr( lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 * d1; }); } +Expr mlir::edsc::op::operator/(Expr lhs, Expr rhs) { + return createBinaryExpr( + lhs, rhs, [](AffineExpr d0, AffineExpr d1) -> AffineExpr { + llvm_unreachable("only exprs of non-index type support operator/"); + }); +} +Expr mlir::edsc::op::operator%(Expr lhs, Expr rhs) { + return createBinaryExpr( + lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 % d1; }); +} + +Expr mlir::edsc::floorDiv(Expr lhs, Expr rhs) { + return createBinaryIndexExpr( + lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.floorDiv(d1); }); +} +Expr mlir::edsc::ceilDiv(Expr lhs, Expr rhs) { + return createBinaryIndexExpr( + lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.ceilDiv(d1); }); +} static Expr createComparisonExpr(CmpIPredicate predicate, Expr lhs, Expr rhs) { assert(lhs.getResultTypes().size() == 1 && rhs.getResultTypes().size() == 1 && @@ -688,9 +719,11 @@ void mlir::edsc::Expr::print(raw_ostream &os) const { infix = "-"; else if (binExpr.is_op() || binExpr.is_op()) infix = binExpr.getResultTypes().front().isInteger(1) ? "&&" : "*"; - else if (binExpr.is_op() || binExpr.is_op()) + else if (binExpr.is_op() || binExpr.is_op() || + binExpr.is_op()) infix = "/"; - else if (binExpr.is_op() || binExpr.is_op()) + else if (binExpr.is_op() || binExpr.is_op() || + binExpr.is_op()) infix = "%"; else if (binExpr.is_op()) infix = getCmpIPredicateInfix(*this);