forked from OSchip/llvm-project
[MLIR] Simplify semi-affine expressions
Simplify semi-affine expression for the operations like ceildiv, floordiv and modulo by any given symbol by checking divisibilty by that symbol. Some properties used in simplification are: 1) Commutative property of the floordiv and ceildiv: ((expr1 floordiv expr2) floordiv expr3 ) = ((expr1 floordiv expr3) floordiv expr2) ((expr1 ceildiv expr2) ceildiv expr3 ) = ((expr1 ceildiv expr3) ceildiv expr2) While simplification if operations are different no simplification is possible as there is no property that simplify expressions like these: ((expr1 ceildiv expr2) floordiv expr3) or ((expr1 floordiv expr2) ceildiv expr3). 2) If both expr1 and expr2 are divisible by the expr3 then: (expr1 % expr2) / expr3 = ((expr1 / expr3) % (expr2 / expr3)) where / is divide symbol. 3) If expr1 is divisible by expr2 then expr1 % expr2 = 0. Signed-off-by: Yash Jain <yash.jain@polymagelabs.com> Differential Revision: https://reviews.llvm.org/D84920
This commit is contained in:
parent
724b035fe4
commit
56593fa370
|
@ -245,6 +245,170 @@ unsigned AffineDimExpr::getPosition() const {
|
|||
return static_cast<ImplType *>(expr)->position;
|
||||
}
|
||||
|
||||
/// Returns true if the expression is divisible by the given symbol with
|
||||
/// position `symbolPos`. The argument `opKind` specifies here what kind of
|
||||
/// division or mod operation called this division. It helps in implementing the
|
||||
/// commutative property of the floordiv and ceildiv operations. If the argument
|
||||
///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv
|
||||
/// operation, then the commutative property can be used otherwise, the floordiv
|
||||
/// operation is not divisible. The same argument holds for ceildiv operation.
|
||||
static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
|
||||
AffineExprKind opKind) {
|
||||
// The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
|
||||
assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
|
||||
opKind == AffineExprKind::CeilDiv) &&
|
||||
"unexpected opKind");
|
||||
switch (expr.getKind()) {
|
||||
case AffineExprKind::Constant:
|
||||
if (expr.cast<AffineConstantExpr>().getValue())
|
||||
return false;
|
||||
return true;
|
||||
case AffineExprKind::DimId:
|
||||
return false;
|
||||
case AffineExprKind::SymbolId:
|
||||
return (expr.cast<AffineSymbolExpr>().getPosition() == symbolPos);
|
||||
// Checks divisibility by the given symbol for both operands.
|
||||
case AffineExprKind::Add: {
|
||||
AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
|
||||
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
|
||||
isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
|
||||
}
|
||||
// Checks divisibility by the given symbol for both operands. Consider the
|
||||
// expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
|
||||
// this is a division by s1 and both the operands of modulo are divisible by
|
||||
// s1 but it is not divisible by s1 always. The third argument is
|
||||
// `AffineExprKind::Mod` for this reason.
|
||||
case AffineExprKind::Mod: {
|
||||
AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
|
||||
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
|
||||
AffineExprKind::Mod) &&
|
||||
isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
|
||||
AffineExprKind::Mod);
|
||||
}
|
||||
// Checks if any of the operand divisible by the given symbol.
|
||||
case AffineExprKind::Mul: {
|
||||
AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
|
||||
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
|
||||
isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
|
||||
}
|
||||
// Floordiv and ceildiv are divisible by the given symbol when the first
|
||||
// operand is divisible, and the affine expression kind of the argument expr
|
||||
// is same as the argument `opKind`. This can be inferred from commutative
|
||||
// property of floordiv and ceildiv operations and are as follow:
|
||||
// (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
|
||||
// (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
|
||||
// It will fail if operations are not same. For example:
|
||||
// (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
|
||||
case AffineExprKind::FloorDiv:
|
||||
case AffineExprKind::CeilDiv: {
|
||||
AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
|
||||
if (opKind != expr.getKind())
|
||||
return false;
|
||||
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
|
||||
}
|
||||
}
|
||||
llvm_unreachable("Unknown AffineExpr");
|
||||
}
|
||||
|
||||
/// Divides the given expression by the given symbol at position `symbolPos`. It
|
||||
/// considers the divisibility condition is checked before calling itself. A
|
||||
/// null expression is returned whenever the divisibility condition fails.
|
||||
static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
|
||||
AffineExprKind opKind) {
|
||||
// THe argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
|
||||
assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
|
||||
opKind == AffineExprKind::CeilDiv) &&
|
||||
"unexpected opKind");
|
||||
switch (expr.getKind()) {
|
||||
case AffineExprKind::Constant:
|
||||
if (expr.cast<AffineConstantExpr>().getValue() != 0)
|
||||
return nullptr;
|
||||
return getAffineConstantExpr(0, expr.getContext());
|
||||
case AffineExprKind::DimId:
|
||||
return nullptr;
|
||||
case AffineExprKind::SymbolId:
|
||||
return getAffineConstantExpr(1, expr.getContext());
|
||||
// Dividing both operands by the given symbol.
|
||||
case AffineExprKind::Add: {
|
||||
AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
|
||||
return getAffineBinaryOpExpr(
|
||||
expr.getKind(), symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind),
|
||||
symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind));
|
||||
}
|
||||
// Dividing both operands by the given symbol.
|
||||
case AffineExprKind::Mod: {
|
||||
AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
|
||||
return getAffineBinaryOpExpr(
|
||||
expr.getKind(),
|
||||
symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
|
||||
symbolicDivide(binaryExpr.getRHS(), symbolPos, expr.getKind()));
|
||||
}
|
||||
// Dividing any of the operand by the given symbol.
|
||||
case AffineExprKind::Mul: {
|
||||
AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
|
||||
if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
|
||||
return binaryExpr.getLHS() *
|
||||
symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind);
|
||||
return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) *
|
||||
binaryExpr.getRHS();
|
||||
}
|
||||
// Dividing first operand only by the given symbol.
|
||||
case AffineExprKind::FloorDiv:
|
||||
case AffineExprKind::CeilDiv: {
|
||||
AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
|
||||
return getAffineBinaryOpExpr(
|
||||
expr.getKind(),
|
||||
symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
|
||||
binaryExpr.getRHS());
|
||||
}
|
||||
}
|
||||
llvm_unreachable("Unknown AffineExpr");
|
||||
}
|
||||
|
||||
/// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv
|
||||
/// operations when the second operand simplifies to a symbol and the first
|
||||
/// operand is divisible by that symbol. It can be applied to any semi-affine
|
||||
/// expression. Returned expression can either be a semi-affine or pure affine
|
||||
/// expression.
|
||||
static AffineExpr simplifySemiAffine(AffineExpr expr) {
|
||||
switch (expr.getKind()) {
|
||||
case AffineExprKind::Constant:
|
||||
case AffineExprKind::DimId:
|
||||
case AffineExprKind::SymbolId:
|
||||
return expr;
|
||||
case AffineExprKind::Add:
|
||||
case AffineExprKind::Mul: {
|
||||
AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
|
||||
return getAffineBinaryOpExpr(expr.getKind(),
|
||||
simplifySemiAffine(binaryExpr.getLHS()),
|
||||
simplifySemiAffine(binaryExpr.getRHS()));
|
||||
}
|
||||
// Check if the simplification of the second operand is a symbol, and the
|
||||
// first operand is divisible by it. If the operation is a modulo, a constant
|
||||
// zero expression is returned. In the case of floordiv and ceildiv, the
|
||||
// symbol from the simplification of the second operand divides the first
|
||||
// operand. Otherwise, simplification is not possible.
|
||||
case AffineExprKind::FloorDiv:
|
||||
case AffineExprKind::CeilDiv:
|
||||
case AffineExprKind::Mod: {
|
||||
AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
|
||||
AffineExpr sLHS = simplifySemiAffine(binaryExpr.getLHS());
|
||||
AffineExpr sRHS = simplifySemiAffine(binaryExpr.getRHS());
|
||||
AffineSymbolExpr symbolExpr =
|
||||
simplifySemiAffine(binaryExpr.getRHS()).dyn_cast<AffineSymbolExpr>();
|
||||
if (!symbolExpr)
|
||||
return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
|
||||
unsigned symbolPos = symbolExpr.getPosition();
|
||||
if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind()))
|
||||
return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
|
||||
if (expr.getKind() == AffineExprKind::Mod)
|
||||
return getAffineConstantExpr(0, expr.getContext());
|
||||
return symbolicDivide(sLHS, symbolPos, expr.getKind());
|
||||
}
|
||||
}
|
||||
llvm_unreachable("Unknown AffineExpr");
|
||||
}
|
||||
|
||||
static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
|
||||
MLIRContext *context) {
|
||||
auto assignCtx = [context](AffineDimExprStorage *storage) {
|
||||
|
@ -878,8 +1042,9 @@ int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
|
|||
/// Simplify the affine expression by flattening it and reconstructing it.
|
||||
AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
|
||||
unsigned numSymbols) {
|
||||
// TODO: only pure affine for now. The simplification here can
|
||||
// be extended to semi-affine maps in the future.
|
||||
// Simplify semi-affine expressions separately.
|
||||
if (!expr.isPureAffine())
|
||||
expr = simplifySemiAffine(expr);
|
||||
if (!expr.isPureAffine())
|
||||
return expr;
|
||||
|
||||
|
|
|
@ -281,3 +281,49 @@ func @simplify_zero_dim_map(%in : memref<f32>) -> f32 {
|
|||
%out = affine.load %in[] : memref<f32>
|
||||
return %out : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests the simplification of a semi-affine expression in various cases.
|
||||
// CHECK-DAG: #[[$map0:.*]] = affine_map<()[s0, s1] -> (-(s1 floordiv s0) + 2)>
|
||||
// CHECK-DAG: #[[$map1:.*]] = affine_map<()[s0, s1] -> (-(s1 floordiv s0) + 42)>
|
||||
|
||||
// Tests the simplification of a semi-affine expression with a modulo operartion on a floordiv and multiplication.
|
||||
// CHECK-LABEL: func @semiaffine_mod
|
||||
func @semiaffine_mod(%arg0: index, %arg1: index) -> index {
|
||||
%a = affine.apply affine_map<(d0)[s0] ->((-((d0 floordiv s0) * s0) + s0 * s0) mod s0)> (%arg0)[%arg1]
|
||||
// CHECK: %[[CST:.*]] = constant 0
|
||||
return %a : index
|
||||
}
|
||||
|
||||
// Tests the simplification of a semi-affine expression with a nested floordiv and a floordiv on modulo operation.
|
||||
// CHECK-LABEL: func @semiaffine_floordiv
|
||||
func @semiaffine_floordiv(%arg0: index, %arg1: index) -> index {
|
||||
%a = affine.apply affine_map<(d0)[s0] ->((-((d0 floordiv s0) * s0) + ((2 * s0) mod (3 * s0))) floordiv s0)> (%arg0)[%arg1]
|
||||
// CHECK: affine.apply #[[$map0]]()[%arg1, %arg0]
|
||||
return %a : index
|
||||
}
|
||||
|
||||
// Tests the simplification of a semi-affine expression with a ceildiv operation and a division of constant 0 by a symbol.
|
||||
// CHECK-LABEL: func @semiaffine_ceildiv
|
||||
func @semiaffine_ceildiv(%arg0: index, %arg1: index) -> index {
|
||||
%a = affine.apply affine_map<(d0)[s0] ->((-((d0 floordiv s0) * s0) + s0 * 42 + ((5-5) floordiv s0)) ceildiv s0)> (%arg0)[%arg1]
|
||||
// CHECK: affine.apply #[[$map1]]()[%arg1, %arg0]
|
||||
return %a : index
|
||||
}
|
||||
|
||||
// Tests the simplification of a semi-affine expression with a nested ceildiv operation and further simplifications after performing ceildiv.
|
||||
// CHECK-LABEL: func @semiaffine_composite_floor
|
||||
func @semiaffine_composite_floor(%arg0: index, %arg1: index) -> index {
|
||||
%a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) ceildiv 4) * 5) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
|
||||
// CHECK: %[[CST:.*]] = constant 47
|
||||
return %a : index
|
||||
}
|
||||
|
||||
// Tests the simplification of a semi-affine expression with a modulo operation with a second operand that simplifies to symbol.
|
||||
// CHECK-LABEL: func @semiaffine_unsimplified_symbol
|
||||
func @semiaffine_unsimplified_symbol(%arg0: index, %arg1: index) -> index {
|
||||
%a = affine.apply affine_map<(d0)[s0] ->(s0 mod (2 * s0 - s0))> (%arg0)[%arg1]
|
||||
// CHECK: %[[CST:.*]] = constant 0
|
||||
return %a : index
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue