Exposed division and remainder operations in EDSC

This change introduces three new operators in EDSC: Div (also exposed
via Expr.__div__ aka /) -- floating-point division, FloorDiv and CeilDiv
for flooring/ceiling index division.

The lowering to LLVM will be implemented in b/124872679.

PiperOrigin-RevId: 234963217
This commit is contained in:
Sergei Lebedev 2019-02-21 03:09:51 -08:00 committed by jpienaar
parent 59a209721e
commit 1cc9305c71
7 changed files with 86 additions and 18 deletions

View File

@ -494,7 +494,10 @@ PYBIND11_MODULE(pybind, m) {
DEFINE_PYBIND_BINARY_OP("Add", Add);
DEFINE_PYBIND_BINARY_OP("Mul", Mul);
DEFINE_PYBIND_BINARY_OP("Sub", Sub);
// DEFINE_PYBIND_BINARY_OP("Div", Div);
DEFINE_PYBIND_BINARY_OP("Div", Div);
DEFINE_PYBIND_BINARY_OP("Rem", Rem);
DEFINE_PYBIND_BINARY_OP("FloorDiv", FloorDiv);
DEFINE_PYBIND_BINARY_OP("CeilDiv", CeilDiv);
DEFINE_PYBIND_BINARY_OP("LT", LT);
DEFINE_PYBIND_BINARY_OP("LE", LE);
DEFINE_PYBIND_BINARY_OP("GT", GT);
@ -649,8 +652,14 @@ PYBIND11_MODULE(pybind, m) {
PythonExpr e2) { return PythonExpr(::Sub(e1, e2)); })
.def("__mul__", [](PythonExpr e1,
PythonExpr e2) { return PythonExpr(::Mul(e1, e2)); })
// .def("__div__", [](PythonExpr e1, PythonExpr e2) { return
// PythonExpr(::Div(e1, e2)); })
.def("__div__", [](PythonExpr e1,
PythonExpr e2) { return PythonExpr(::Div(e1, e2)); })
.def("__truediv__",
[](PythonExpr e1, PythonExpr e2) {
return PythonExpr(::Div(e1, e2));
})
.def("__mod__", [](PythonExpr e1,
PythonExpr e2) { return PythonExpr(::Rem(e1, e2)); })
.def("__lt__", [](PythonExpr e1,
PythonExpr e2) { return PythonExpr(::LT(e1, e2)); })
.def("__le__", [](PythonExpr e1,

View File

@ -113,9 +113,9 @@ class EdscTest(unittest.TestCase):
with E.ContextManager():
i, j, k, l = list(
map(E.Expr, [E.Bindable(self.f32Type) for _ in range(4)]))
stmt = i + j * k - l
stmt = i % j + j * k - l / k
str = stmt.__str__()
self.assertIn("(($1 + ($2 * $3)) - $4)", str)
self.assertIn("((($1 % $2) + ($2 * $3)) - ($4 / $3))", str)
def testBoolean(self):
with E.ContextManager():

View File

@ -260,7 +260,8 @@ edsc_stmt_t MaxMinFor(edsc_expr_t iv, edsc_max_expr_t lb, edsc_min_expr_t ub,
edsc_expr_t Add(edsc_expr_t e1, edsc_expr_t e2);
edsc_expr_t Sub(edsc_expr_t e1, edsc_expr_t e2);
edsc_expr_t Mul(edsc_expr_t e1, edsc_expr_t e2);
// edsc_expr_t Div(edsc_expr_t e1, edsc_expr_t e2);
edsc_expr_t Div(edsc_expr_t e1, edsc_expr_t e2);
edsc_expr_t Rem(edsc_expr_t e1, edsc_expr_t e2);
edsc_expr_t LT(edsc_expr_t e1, edsc_expr_t e2);
edsc_expr_t LE(edsc_expr_t e1, edsc_expr_t e2);
edsc_expr_t GT(edsc_expr_t e1, edsc_expr_t e2);
@ -268,6 +269,9 @@ edsc_expr_t GE(edsc_expr_t e1, edsc_expr_t e2);
edsc_expr_t EQ(edsc_expr_t e1, edsc_expr_t e2);
edsc_expr_t NE(edsc_expr_t e1, edsc_expr_t e2);
edsc_expr_t FloorDiv(edsc_expr_t e1, edsc_expr_t e2);
edsc_expr_t CeilDiv(edsc_expr_t e1, edsc_expr_t e2);
edsc_expr_t And(edsc_expr_t e1, edsc_expr_t e2);
edsc_expr_t Or(edsc_expr_t e1, edsc_expr_t e2);
edsc_expr_t Negate(edsc_expr_t e);

View File

@ -436,6 +436,8 @@ namespace op {
Expr operator+(Expr lhs, Expr rhs);
Expr operator-(Expr lhs, Expr rhs);
Expr operator*(Expr lhs, Expr rhs);
Expr operator/(Expr lhs, Expr rhs);
Expr operator%(Expr lhs, Expr rhs);
/// In particular operator==, operator!= return a new Expr and *not* a bool.
Expr operator==(Expr lhs, Expr rhs);
Expr operator!=(Expr lhs, Expr rhs);
@ -479,6 +481,9 @@ inline Expr operator||(Stmt lhs, Stmt rhs) {
inline Expr operator!(Stmt stmt) { return !stmt.getLHS(); }
} // end namespace op
Expr floorDiv(Expr lhs, Expr rhs);
Expr ceilDiv(Expr lhs, Expr rhs);
template <typename U> bool Expr::isa() const {
auto kind = getKind();
if (std::is_same<U, Bindable>::value) {

View File

@ -82,6 +82,10 @@ def AddIOp : IntArithmeticOp<"addi", [Commutative]> {
let hasConstantFolder = 0b1;
}
def DivFOp : FloatArithmeticOp<"divf"> {
let summary = "floating point division operation";
}
def DivISOp : IntArithmeticOp<"divis"> {
let summary = "signed integer division operation";
let hasConstantFolder = 0b1;
@ -103,6 +107,10 @@ def MulIOp : IntArithmeticOp<"muli", [Commutative]> {
let hasFolder = 1;
}
def RemFOp : FloatArithmeticOp<"remf"> {
let summary = "floating point division remainder operation";
}
def RemISOp : IntArithmeticOp<"remis"> {
let summary = "signed integer division remainder operation";
let hasConstantFolder = 0b1;

View File

@ -534,7 +534,8 @@ void bindMemRefView(edsc_mlir_emitter_t emitter, edsc_expr_t boundMemRef,
DEFINE_EDSL_BINARY_OP(Add, +);
DEFINE_EDSL_BINARY_OP(Sub, -);
DEFINE_EDSL_BINARY_OP(Mul, *);
// DEFINE_EDSL_BINARY_OP(Div, /);
DEFINE_EDSL_BINARY_OP(Div, /);
DEFINE_EDSL_BINARY_OP(Rem, %);
DEFINE_EDSL_BINARY_OP(LT, <);
DEFINE_EDSL_BINARY_OP(LE, <=);
DEFINE_EDSL_BINARY_OP(GT, >);
@ -546,6 +547,14 @@ DEFINE_EDSL_BINARY_OP(Or, ||);
#undef DEFINE_EDSL_BINARY_OP
edsc_expr_t FloorDiv(edsc_expr_t e1, edsc_expr_t e2) {
return edsc::floorDiv(Expr(e1), Expr(e2));
}
edsc_expr_t CeilDiv(edsc_expr_t e1, edsc_expr_t e2) {
return edsc::ceilDiv(Expr(e1), Expr(e2));
}
#define DEFINE_EDSL_UNARY_OP(FUN_NAME, OP_SYMBOL) \
edsc_expr_t FUN_NAME(edsc_expr_t e) { \
using edsc::op::operator OP_SYMBOL; \

View File

@ -205,6 +205,26 @@ static AffineExpr createOperandAffineExpr(Expr e, int64_t position,
return getAffineDimExpr(position, context);
}
static Expr createBinaryIndexExpr(
Expr lhs, Expr rhs,
std::function<AffineExpr(AffineExpr, AffineExpr)> affCombiner) {
assert(lhs.getResultTypes().size() == 1 && rhs.getResultTypes().size() == 1 &&
"only single-result exprs are supported in operators");
auto thisType = lhs.getResultTypes().front();
auto thatType = rhs.getResultTypes().front();
assert(thisType == thatType && "cannot mix types in operators");
assert(thisType.isIndex() && "expected exprs of index type");
MLIRContext *context = thisType.getContext();
auto lhsAff = createOperandAffineExpr(lhs, 0, context);
auto rhsAff = createOperandAffineExpr(rhs, 1, context);
auto map = AffineMap::get(2, 0, {affCombiner(lhsAff, rhsAff)}, {});
auto attr = AffineMapAttr::get(map);
auto attrId = Identifier::get("map", context);
auto namedAttr = NamedAttribute{attrId, attr};
return VariadicExpr("affine.apply", {lhs, rhs}, {IndexType::get(context)},
{namedAttr});
}
// Create a binary expression between the two arguments emitting `IOp` if
// arguments are integers or vectors/tensors thereof, `FOp` if arguments are
// floating-point or vectors/tensors thereof, and `AffineApplyOp` with an
@ -220,15 +240,7 @@ static Expr createBinaryExpr(
auto thatType = rhs.getResultTypes().front();
assert(thisType == thatType && "cannot mix types in operators");
if (thisType.isIndex()) {
MLIRContext *context = thisType.getContext();
auto lhsAff = createOperandAffineExpr(lhs, 0, context);
auto rhsAff = createOperandAffineExpr(rhs, 1, context);
auto map = AffineMap::get(2, 0, {affCombiner(lhsAff, rhsAff)}, {});
auto attr = AffineMapAttr::get(map);
auto attrId = Identifier::get("map", context);
auto namedAttr = NamedAttribute{attrId, attr};
return VariadicExpr("affine.apply", {lhs, rhs}, {IndexType::get(context)},
{namedAttr});
return createBinaryIndexExpr(lhs, rhs, affCombiner);
} else if (thisType.isa<IntegerType>()) {
return BinaryExpr::make<IOp>(thisType, lhs, rhs);
} else if (thisType.isa<FloatType>()) {
@ -255,6 +267,25 @@ Expr mlir::edsc::op::operator*(Expr lhs, Expr rhs) {
return createBinaryExpr<MulIOp, MulFOp>(
lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 * d1; });
}
Expr mlir::edsc::op::operator/(Expr lhs, Expr rhs) {
return createBinaryExpr<DivISOp, DivFOp>(
lhs, rhs, [](AffineExpr d0, AffineExpr d1) -> AffineExpr {
llvm_unreachable("only exprs of non-index type support operator/");
});
}
Expr mlir::edsc::op::operator%(Expr lhs, Expr rhs) {
return createBinaryExpr<RemISOp, RemFOp>(
lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 % d1; });
}
Expr mlir::edsc::floorDiv(Expr lhs, Expr rhs) {
return createBinaryIndexExpr(
lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.floorDiv(d1); });
}
Expr mlir::edsc::ceilDiv(Expr lhs, Expr rhs) {
return createBinaryIndexExpr(
lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.ceilDiv(d1); });
}
static Expr createComparisonExpr(CmpIPredicate predicate, Expr lhs, Expr rhs) {
assert(lhs.getResultTypes().size() == 1 && rhs.getResultTypes().size() == 1 &&
@ -688,9 +719,11 @@ void mlir::edsc::Expr::print(raw_ostream &os) const {
infix = "-";
else if (binExpr.is_op<MulIOp>() || binExpr.is_op<MulFOp>())
infix = binExpr.getResultTypes().front().isInteger(1) ? "&&" : "*";
else if (binExpr.is_op<DivISOp>() || binExpr.is_op<DivIUOp>())
else if (binExpr.is_op<DivISOp>() || binExpr.is_op<DivIUOp>() ||
binExpr.is_op<DivFOp>())
infix = "/";
else if (binExpr.is_op<RemISOp>() || binExpr.is_op<RemIUOp>())
else if (binExpr.is_op<RemISOp>() || binExpr.is_op<RemIUOp>() ||
binExpr.is_op<RemFOp>())
infix = "%";
else if (binExpr.is_op<CmpIOp>())
infix = getCmpIPredicateInfix(*this);