forked from OSchip/llvm-project
Exposed division and remainder operations in EDSC
This change introduces three new operators in EDSC: Div (also exposed via Expr.__div__ aka /) -- floating-point division, FloorDiv and CeilDiv for flooring/ceiling index division. The lowering to LLVM will be implemented in b/124872679. PiperOrigin-RevId: 234963217
This commit is contained in:
parent
59a209721e
commit
1cc9305c71
|
@ -494,7 +494,10 @@ PYBIND11_MODULE(pybind, m) {
|
|||
DEFINE_PYBIND_BINARY_OP("Add", Add);
|
||||
DEFINE_PYBIND_BINARY_OP("Mul", Mul);
|
||||
DEFINE_PYBIND_BINARY_OP("Sub", Sub);
|
||||
// DEFINE_PYBIND_BINARY_OP("Div", Div);
|
||||
DEFINE_PYBIND_BINARY_OP("Div", Div);
|
||||
DEFINE_PYBIND_BINARY_OP("Rem", Rem);
|
||||
DEFINE_PYBIND_BINARY_OP("FloorDiv", FloorDiv);
|
||||
DEFINE_PYBIND_BINARY_OP("CeilDiv", CeilDiv);
|
||||
DEFINE_PYBIND_BINARY_OP("LT", LT);
|
||||
DEFINE_PYBIND_BINARY_OP("LE", LE);
|
||||
DEFINE_PYBIND_BINARY_OP("GT", GT);
|
||||
|
@ -649,8 +652,14 @@ PYBIND11_MODULE(pybind, m) {
|
|||
PythonExpr e2) { return PythonExpr(::Sub(e1, e2)); })
|
||||
.def("__mul__", [](PythonExpr e1,
|
||||
PythonExpr e2) { return PythonExpr(::Mul(e1, e2)); })
|
||||
// .def("__div__", [](PythonExpr e1, PythonExpr e2) { return
|
||||
// PythonExpr(::Div(e1, e2)); })
|
||||
.def("__div__", [](PythonExpr e1,
|
||||
PythonExpr e2) { return PythonExpr(::Div(e1, e2)); })
|
||||
.def("__truediv__",
|
||||
[](PythonExpr e1, PythonExpr e2) {
|
||||
return PythonExpr(::Div(e1, e2));
|
||||
})
|
||||
.def("__mod__", [](PythonExpr e1,
|
||||
PythonExpr e2) { return PythonExpr(::Rem(e1, e2)); })
|
||||
.def("__lt__", [](PythonExpr e1,
|
||||
PythonExpr e2) { return PythonExpr(::LT(e1, e2)); })
|
||||
.def("__le__", [](PythonExpr e1,
|
||||
|
|
|
@ -113,9 +113,9 @@ class EdscTest(unittest.TestCase):
|
|||
with E.ContextManager():
|
||||
i, j, k, l = list(
|
||||
map(E.Expr, [E.Bindable(self.f32Type) for _ in range(4)]))
|
||||
stmt = i + j * k - l
|
||||
stmt = i % j + j * k - l / k
|
||||
str = stmt.__str__()
|
||||
self.assertIn("(($1 + ($2 * $3)) - $4)", str)
|
||||
self.assertIn("((($1 % $2) + ($2 * $3)) - ($4 / $3))", str)
|
||||
|
||||
def testBoolean(self):
|
||||
with E.ContextManager():
|
||||
|
|
|
@ -260,7 +260,8 @@ edsc_stmt_t MaxMinFor(edsc_expr_t iv, edsc_max_expr_t lb, edsc_min_expr_t ub,
|
|||
edsc_expr_t Add(edsc_expr_t e1, edsc_expr_t e2);
|
||||
edsc_expr_t Sub(edsc_expr_t e1, edsc_expr_t e2);
|
||||
edsc_expr_t Mul(edsc_expr_t e1, edsc_expr_t e2);
|
||||
// edsc_expr_t Div(edsc_expr_t e1, edsc_expr_t e2);
|
||||
edsc_expr_t Div(edsc_expr_t e1, edsc_expr_t e2);
|
||||
edsc_expr_t Rem(edsc_expr_t e1, edsc_expr_t e2);
|
||||
edsc_expr_t LT(edsc_expr_t e1, edsc_expr_t e2);
|
||||
edsc_expr_t LE(edsc_expr_t e1, edsc_expr_t e2);
|
||||
edsc_expr_t GT(edsc_expr_t e1, edsc_expr_t e2);
|
||||
|
@ -268,6 +269,9 @@ edsc_expr_t GE(edsc_expr_t e1, edsc_expr_t e2);
|
|||
edsc_expr_t EQ(edsc_expr_t e1, edsc_expr_t e2);
|
||||
edsc_expr_t NE(edsc_expr_t e1, edsc_expr_t e2);
|
||||
|
||||
edsc_expr_t FloorDiv(edsc_expr_t e1, edsc_expr_t e2);
|
||||
edsc_expr_t CeilDiv(edsc_expr_t e1, edsc_expr_t e2);
|
||||
|
||||
edsc_expr_t And(edsc_expr_t e1, edsc_expr_t e2);
|
||||
edsc_expr_t Or(edsc_expr_t e1, edsc_expr_t e2);
|
||||
edsc_expr_t Negate(edsc_expr_t e);
|
||||
|
|
|
@ -436,6 +436,8 @@ namespace op {
|
|||
Expr operator+(Expr lhs, Expr rhs);
|
||||
Expr operator-(Expr lhs, Expr rhs);
|
||||
Expr operator*(Expr lhs, Expr rhs);
|
||||
Expr operator/(Expr lhs, Expr rhs);
|
||||
Expr operator%(Expr lhs, Expr rhs);
|
||||
/// In particular operator==, operator!= return a new Expr and *not* a bool.
|
||||
Expr operator==(Expr lhs, Expr rhs);
|
||||
Expr operator!=(Expr lhs, Expr rhs);
|
||||
|
@ -479,6 +481,9 @@ inline Expr operator||(Stmt lhs, Stmt rhs) {
|
|||
inline Expr operator!(Stmt stmt) { return !stmt.getLHS(); }
|
||||
} // end namespace op
|
||||
|
||||
Expr floorDiv(Expr lhs, Expr rhs);
|
||||
Expr ceilDiv(Expr lhs, Expr rhs);
|
||||
|
||||
template <typename U> bool Expr::isa() const {
|
||||
auto kind = getKind();
|
||||
if (std::is_same<U, Bindable>::value) {
|
||||
|
|
|
@ -82,6 +82,10 @@ def AddIOp : IntArithmeticOp<"addi", [Commutative]> {
|
|||
let hasConstantFolder = 0b1;
|
||||
}
|
||||
|
||||
def DivFOp : FloatArithmeticOp<"divf"> {
|
||||
let summary = "floating point division operation";
|
||||
}
|
||||
|
||||
def DivISOp : IntArithmeticOp<"divis"> {
|
||||
let summary = "signed integer division operation";
|
||||
let hasConstantFolder = 0b1;
|
||||
|
@ -103,6 +107,10 @@ def MulIOp : IntArithmeticOp<"muli", [Commutative]> {
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def RemFOp : FloatArithmeticOp<"remf"> {
|
||||
let summary = "floating point division remainder operation";
|
||||
}
|
||||
|
||||
def RemISOp : IntArithmeticOp<"remis"> {
|
||||
let summary = "signed integer division remainder operation";
|
||||
let hasConstantFolder = 0b1;
|
||||
|
|
|
@ -534,7 +534,8 @@ void bindMemRefView(edsc_mlir_emitter_t emitter, edsc_expr_t boundMemRef,
|
|||
DEFINE_EDSL_BINARY_OP(Add, +);
|
||||
DEFINE_EDSL_BINARY_OP(Sub, -);
|
||||
DEFINE_EDSL_BINARY_OP(Mul, *);
|
||||
// DEFINE_EDSL_BINARY_OP(Div, /);
|
||||
DEFINE_EDSL_BINARY_OP(Div, /);
|
||||
DEFINE_EDSL_BINARY_OP(Rem, %);
|
||||
DEFINE_EDSL_BINARY_OP(LT, <);
|
||||
DEFINE_EDSL_BINARY_OP(LE, <=);
|
||||
DEFINE_EDSL_BINARY_OP(GT, >);
|
||||
|
@ -546,6 +547,14 @@ DEFINE_EDSL_BINARY_OP(Or, ||);
|
|||
|
||||
#undef DEFINE_EDSL_BINARY_OP
|
||||
|
||||
edsc_expr_t FloorDiv(edsc_expr_t e1, edsc_expr_t e2) {
|
||||
return edsc::floorDiv(Expr(e1), Expr(e2));
|
||||
}
|
||||
|
||||
edsc_expr_t CeilDiv(edsc_expr_t e1, edsc_expr_t e2) {
|
||||
return edsc::ceilDiv(Expr(e1), Expr(e2));
|
||||
}
|
||||
|
||||
#define DEFINE_EDSL_UNARY_OP(FUN_NAME, OP_SYMBOL) \
|
||||
edsc_expr_t FUN_NAME(edsc_expr_t e) { \
|
||||
using edsc::op::operator OP_SYMBOL; \
|
||||
|
|
|
@ -205,6 +205,26 @@ static AffineExpr createOperandAffineExpr(Expr e, int64_t position,
|
|||
return getAffineDimExpr(position, context);
|
||||
}
|
||||
|
||||
static Expr createBinaryIndexExpr(
|
||||
Expr lhs, Expr rhs,
|
||||
std::function<AffineExpr(AffineExpr, AffineExpr)> affCombiner) {
|
||||
assert(lhs.getResultTypes().size() == 1 && rhs.getResultTypes().size() == 1 &&
|
||||
"only single-result exprs are supported in operators");
|
||||
auto thisType = lhs.getResultTypes().front();
|
||||
auto thatType = rhs.getResultTypes().front();
|
||||
assert(thisType == thatType && "cannot mix types in operators");
|
||||
assert(thisType.isIndex() && "expected exprs of index type");
|
||||
MLIRContext *context = thisType.getContext();
|
||||
auto lhsAff = createOperandAffineExpr(lhs, 0, context);
|
||||
auto rhsAff = createOperandAffineExpr(rhs, 1, context);
|
||||
auto map = AffineMap::get(2, 0, {affCombiner(lhsAff, rhsAff)}, {});
|
||||
auto attr = AffineMapAttr::get(map);
|
||||
auto attrId = Identifier::get("map", context);
|
||||
auto namedAttr = NamedAttribute{attrId, attr};
|
||||
return VariadicExpr("affine.apply", {lhs, rhs}, {IndexType::get(context)},
|
||||
{namedAttr});
|
||||
}
|
||||
|
||||
// Create a binary expression between the two arguments emitting `IOp` if
|
||||
// arguments are integers or vectors/tensors thereof, `FOp` if arguments are
|
||||
// floating-point or vectors/tensors thereof, and `AffineApplyOp` with an
|
||||
|
@ -220,15 +240,7 @@ static Expr createBinaryExpr(
|
|||
auto thatType = rhs.getResultTypes().front();
|
||||
assert(thisType == thatType && "cannot mix types in operators");
|
||||
if (thisType.isIndex()) {
|
||||
MLIRContext *context = thisType.getContext();
|
||||
auto lhsAff = createOperandAffineExpr(lhs, 0, context);
|
||||
auto rhsAff = createOperandAffineExpr(rhs, 1, context);
|
||||
auto map = AffineMap::get(2, 0, {affCombiner(lhsAff, rhsAff)}, {});
|
||||
auto attr = AffineMapAttr::get(map);
|
||||
auto attrId = Identifier::get("map", context);
|
||||
auto namedAttr = NamedAttribute{attrId, attr};
|
||||
return VariadicExpr("affine.apply", {lhs, rhs}, {IndexType::get(context)},
|
||||
{namedAttr});
|
||||
return createBinaryIndexExpr(lhs, rhs, affCombiner);
|
||||
} else if (thisType.isa<IntegerType>()) {
|
||||
return BinaryExpr::make<IOp>(thisType, lhs, rhs);
|
||||
} else if (thisType.isa<FloatType>()) {
|
||||
|
@ -255,6 +267,25 @@ Expr mlir::edsc::op::operator*(Expr lhs, Expr rhs) {
|
|||
return createBinaryExpr<MulIOp, MulFOp>(
|
||||
lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 * d1; });
|
||||
}
|
||||
Expr mlir::edsc::op::operator/(Expr lhs, Expr rhs) {
|
||||
return createBinaryExpr<DivISOp, DivFOp>(
|
||||
lhs, rhs, [](AffineExpr d0, AffineExpr d1) -> AffineExpr {
|
||||
llvm_unreachable("only exprs of non-index type support operator/");
|
||||
});
|
||||
}
|
||||
Expr mlir::edsc::op::operator%(Expr lhs, Expr rhs) {
|
||||
return createBinaryExpr<RemISOp, RemFOp>(
|
||||
lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 % d1; });
|
||||
}
|
||||
|
||||
Expr mlir::edsc::floorDiv(Expr lhs, Expr rhs) {
|
||||
return createBinaryIndexExpr(
|
||||
lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.floorDiv(d1); });
|
||||
}
|
||||
Expr mlir::edsc::ceilDiv(Expr lhs, Expr rhs) {
|
||||
return createBinaryIndexExpr(
|
||||
lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.ceilDiv(d1); });
|
||||
}
|
||||
|
||||
static Expr createComparisonExpr(CmpIPredicate predicate, Expr lhs, Expr rhs) {
|
||||
assert(lhs.getResultTypes().size() == 1 && rhs.getResultTypes().size() == 1 &&
|
||||
|
@ -688,9 +719,11 @@ void mlir::edsc::Expr::print(raw_ostream &os) const {
|
|||
infix = "-";
|
||||
else if (binExpr.is_op<MulIOp>() || binExpr.is_op<MulFOp>())
|
||||
infix = binExpr.getResultTypes().front().isInteger(1) ? "&&" : "*";
|
||||
else if (binExpr.is_op<DivISOp>() || binExpr.is_op<DivIUOp>())
|
||||
else if (binExpr.is_op<DivISOp>() || binExpr.is_op<DivIUOp>() ||
|
||||
binExpr.is_op<DivFOp>())
|
||||
infix = "/";
|
||||
else if (binExpr.is_op<RemISOp>() || binExpr.is_op<RemIUOp>())
|
||||
else if (binExpr.is_op<RemISOp>() || binExpr.is_op<RemIUOp>() ||
|
||||
binExpr.is_op<RemFOp>())
|
||||
infix = "%";
|
||||
else if (binExpr.is_op<CmpIOp>())
|
||||
infix = getCmpIPredicateInfix(*this);
|
||||
|
|
Loading…
Reference in New Issue