[MLIR] Simplify semi-affine expressions

Simplify semi-affine expression for the operations like ceildiv,
floordiv and modulo by any given symbol by checking divisibilty by that
symbol.

Some properties used in simplification are:

1) Commutative property of the floordiv and ceildiv:
((expr1 floordiv expr2) floordiv expr3 ) = ((expr1 floordiv expr3) floordiv expr2)
((expr1 ceildiv expr2) ceildiv expr3 ) = ((expr1 ceildiv expr3) ceildiv expr2)

While simplification if operations are different no simplification is
possible as there is no property that simplify expressions like these:
((expr1 ceildiv expr2) floordiv expr3) or  ((expr1 floordiv expr2)
ceildiv expr3).

2) If both expr1 and expr2 are divisible by the expr3 then:
(expr1 % expr2) / expr3 = ((expr1 / expr3) % (expr2 / expr3))
where / is divide symbol.

3) If expr1 is divisible by expr2 then expr1 % expr2 = 0.

Signed-off-by: Yash Jain <yash.jain@polymagelabs.com>

Differential Revision: https://reviews.llvm.org/D84920
This commit is contained in:
Yash Jain 2020-08-04 20:21:13 +05:30 committed by Uday Bondhugula
parent 724b035fe4
commit 56593fa370
2 changed files with 213 additions and 2 deletions

View File

@ -245,6 +245,170 @@ unsigned AffineDimExpr::getPosition() const {
return static_cast<ImplType *>(expr)->position;
}
/// Returns true if the expression is divisible by the given symbol with
/// position `symbolPos`. The argument `opKind` specifies here what kind of
/// division or mod operation called this division. It helps in implementing the
/// commutative property of the floordiv and ceildiv operations. If the argument
///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv
/// operation, then the commutative property can be used otherwise, the floordiv
/// operation is not divisible. The same argument holds for ceildiv operation.
static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
AffineExprKind opKind) {
// The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
opKind == AffineExprKind::CeilDiv) &&
"unexpected opKind");
switch (expr.getKind()) {
case AffineExprKind::Constant:
if (expr.cast<AffineConstantExpr>().getValue())
return false;
return true;
case AffineExprKind::DimId:
return false;
case AffineExprKind::SymbolId:
return (expr.cast<AffineSymbolExpr>().getPosition() == symbolPos);
// Checks divisibility by the given symbol for both operands.
case AffineExprKind::Add: {
AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
}
// Checks divisibility by the given symbol for both operands. Consider the
// expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
// this is a division by s1 and both the operands of modulo are divisible by
// s1 but it is not divisible by s1 always. The third argument is
// `AffineExprKind::Mod` for this reason.
case AffineExprKind::Mod: {
AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
AffineExprKind::Mod) &&
isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
AffineExprKind::Mod);
}
// Checks if any of the operand divisible by the given symbol.
case AffineExprKind::Mul: {
AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
}
// Floordiv and ceildiv are divisible by the given symbol when the first
// operand is divisible, and the affine expression kind of the argument expr
// is same as the argument `opKind`. This can be inferred from commutative
// property of floordiv and ceildiv operations and are as follow:
// (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
// (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
// It will fail if operations are not same. For example:
// (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv: {
AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
if (opKind != expr.getKind())
return false;
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
}
}
llvm_unreachable("Unknown AffineExpr");
}
/// Divides the given expression by the given symbol at position `symbolPos`. It
/// considers the divisibility condition is checked before calling itself. A
/// null expression is returned whenever the divisibility condition fails.
static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
AffineExprKind opKind) {
// THe argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
opKind == AffineExprKind::CeilDiv) &&
"unexpected opKind");
switch (expr.getKind()) {
case AffineExprKind::Constant:
if (expr.cast<AffineConstantExpr>().getValue() != 0)
return nullptr;
return getAffineConstantExpr(0, expr.getContext());
case AffineExprKind::DimId:
return nullptr;
case AffineExprKind::SymbolId:
return getAffineConstantExpr(1, expr.getContext());
// Dividing both operands by the given symbol.
case AffineExprKind::Add: {
AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
return getAffineBinaryOpExpr(
expr.getKind(), symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind),
symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind));
}
// Dividing both operands by the given symbol.
case AffineExprKind::Mod: {
AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
return getAffineBinaryOpExpr(
expr.getKind(),
symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
symbolicDivide(binaryExpr.getRHS(), symbolPos, expr.getKind()));
}
// Dividing any of the operand by the given symbol.
case AffineExprKind::Mul: {
AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
return binaryExpr.getLHS() *
symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind);
return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) *
binaryExpr.getRHS();
}
// Dividing first operand only by the given symbol.
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv: {
AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
return getAffineBinaryOpExpr(
expr.getKind(),
symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
binaryExpr.getRHS());
}
}
llvm_unreachable("Unknown AffineExpr");
}
/// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv
/// operations when the second operand simplifies to a symbol and the first
/// operand is divisible by that symbol. It can be applied to any semi-affine
/// expression. Returned expression can either be a semi-affine or pure affine
/// expression.
static AffineExpr simplifySemiAffine(AffineExpr expr) {
switch (expr.getKind()) {
case AffineExprKind::Constant:
case AffineExprKind::DimId:
case AffineExprKind::SymbolId:
return expr;
case AffineExprKind::Add:
case AffineExprKind::Mul: {
AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
return getAffineBinaryOpExpr(expr.getKind(),
simplifySemiAffine(binaryExpr.getLHS()),
simplifySemiAffine(binaryExpr.getRHS()));
}
// Check if the simplification of the second operand is a symbol, and the
// first operand is divisible by it. If the operation is a modulo, a constant
// zero expression is returned. In the case of floordiv and ceildiv, the
// symbol from the simplification of the second operand divides the first
// operand. Otherwise, simplification is not possible.
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv:
case AffineExprKind::Mod: {
AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
AffineExpr sLHS = simplifySemiAffine(binaryExpr.getLHS());
AffineExpr sRHS = simplifySemiAffine(binaryExpr.getRHS());
AffineSymbolExpr symbolExpr =
simplifySemiAffine(binaryExpr.getRHS()).dyn_cast<AffineSymbolExpr>();
if (!symbolExpr)
return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
unsigned symbolPos = symbolExpr.getPosition();
if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind()))
return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
if (expr.getKind() == AffineExprKind::Mod)
return getAffineConstantExpr(0, expr.getContext());
return symbolicDivide(sLHS, symbolPos, expr.getKind());
}
}
llvm_unreachable("Unknown AffineExpr");
}
static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
MLIRContext *context) {
auto assignCtx = [context](AffineDimExprStorage *storage) {
@ -878,8 +1042,9 @@ int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
/// Simplify the affine expression by flattening it and reconstructing it.
AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
unsigned numSymbols) {
// TODO: only pure affine for now. The simplification here can
// be extended to semi-affine maps in the future.
// Simplify semi-affine expressions separately.
if (!expr.isPureAffine())
expr = simplifySemiAffine(expr);
if (!expr.isPureAffine())
return expr;

View File

@ -281,3 +281,49 @@ func @simplify_zero_dim_map(%in : memref<f32>) -> f32 {
%out = affine.load %in[] : memref<f32>
return %out : f32
}
// -----
// Tests the simplification of a semi-affine expression in various cases.
// CHECK-DAG: #[[$map0:.*]] = affine_map<()[s0, s1] -> (-(s1 floordiv s0) + 2)>
// CHECK-DAG: #[[$map1:.*]] = affine_map<()[s0, s1] -> (-(s1 floordiv s0) + 42)>
// Tests the simplification of a semi-affine expression with a modulo operartion on a floordiv and multiplication.
// CHECK-LABEL: func @semiaffine_mod
func @semiaffine_mod(%arg0: index, %arg1: index) -> index {
%a = affine.apply affine_map<(d0)[s0] ->((-((d0 floordiv s0) * s0) + s0 * s0) mod s0)> (%arg0)[%arg1]
// CHECK: %[[CST:.*]] = constant 0
return %a : index
}
// Tests the simplification of a semi-affine expression with a nested floordiv and a floordiv on modulo operation.
// CHECK-LABEL: func @semiaffine_floordiv
func @semiaffine_floordiv(%arg0: index, %arg1: index) -> index {
%a = affine.apply affine_map<(d0)[s0] ->((-((d0 floordiv s0) * s0) + ((2 * s0) mod (3 * s0))) floordiv s0)> (%arg0)[%arg1]
// CHECK: affine.apply #[[$map0]]()[%arg1, %arg0]
return %a : index
}
// Tests the simplification of a semi-affine expression with a ceildiv operation and a division of constant 0 by a symbol.
// CHECK-LABEL: func @semiaffine_ceildiv
func @semiaffine_ceildiv(%arg0: index, %arg1: index) -> index {
%a = affine.apply affine_map<(d0)[s0] ->((-((d0 floordiv s0) * s0) + s0 * 42 + ((5-5) floordiv s0)) ceildiv s0)> (%arg0)[%arg1]
// CHECK: affine.apply #[[$map1]]()[%arg1, %arg0]
return %a : index
}
// Tests the simplification of a semi-affine expression with a nested ceildiv operation and further simplifications after performing ceildiv.
// CHECK-LABEL: func @semiaffine_composite_floor
func @semiaffine_composite_floor(%arg0: index, %arg1: index) -> index {
%a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) ceildiv 4) * 5) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
// CHECK: %[[CST:.*]] = constant 47
return %a : index
}
// Tests the simplification of a semi-affine expression with a modulo operation with a second operand that simplifies to symbol.
// CHECK-LABEL: func @semiaffine_unsimplified_symbol
func @semiaffine_unsimplified_symbol(%arg0: index, %arg1: index) -> index {
%a = affine.apply affine_map<(d0)[s0] ->(s0 mod (2 * s0 - s0))> (%arg0)[%arg1]
// CHECK: %[[CST:.*]] = constant 0
return %a : index
}