Fix MLIR's floordiv, ceildiv, and mod for constant inputs (for negative lhs's)

- introduce mlir::{floorDiv, ceilDiv, mod} for constant inputs in
  mlir/Support/MathExtras.h
- consistently use these everywhere in IR, Analysis, and Transforms.

PiperOrigin-RevId: 215580677
This commit is contained in:
Uday Bondhugula 2018-10-03 10:07:54 -07:00 committed by jpienaar
parent 7d016fd352
commit 0ebc927f2f
5 changed files with 79 additions and 21 deletions

View File

@ -0,0 +1,56 @@
//===- MathExtras.h - Math functions relevant to MLIR -----------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file contains math functions relevant to MLIR.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_SUPPORT_MATHEXTRAS_H_
#define MLIR_SUPPORT_MATHEXTRAS_H_
#include "mlir/Support/LLVM.h"
namespace mlir {
/// Returns the result of MLIR's ceildiv operation on constants. The RHS is
/// expected to be positive.
inline int64_t ceilDiv(int64_t lhs, int64_t rhs) {
assert(rhs >= 1);
// C/C++'s integer division rounds towards 0.
return lhs % rhs > 0 ? lhs / rhs + 1 : lhs / rhs;
}
/// Returns the result of MLIR's floordiv operation on constants. The RHS is
/// expected to be positive.
inline int64_t floorDiv(int64_t lhs, int64_t rhs) {
assert(rhs >= 1);
// C/C++'s integer division rounds towards 0.
return lhs % rhs < 0 ? lhs / rhs - 1 : lhs / rhs;
}
/// Returns MLIR's mod operation on constants. MLIR's mod operation yields the
/// remainder of the Euclidean division of 'lhs' by 'rhs', and is therefore not
/// C's % operator. The RHS is always expected to be positive, and the result
/// is always non-negative.
inline int64_t mod(int64_t lhs, int64_t rhs) {
assert(rhs >= 1);
return lhs % rhs < 0 ? lhs % rhs + rhs : lhs % rhs;
}
} // end namespace mlir
#endif // MLIR_SUPPORT_MATHEXTRAS_H_

View File

@ -25,6 +25,7 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Statements.h"
#include "mlir/Support/MathExtras.h"
using mlir::AffineExpr;
@ -74,10 +75,8 @@ AffineExpr *mlir::getTripCountExpr(const ForStmt &forStmt) {
if (loopSpan < 0)
return 0;
return AffineConstantExpr::get(
static_cast<uint64_t>(loopSpan % step == 0 ? loopSpan / step
: loopSpan / step + 1),
context);
return AffineConstantExpr::get(static_cast<uint64_t>(ceilDiv(loopSpan, step)),
context);
}
/// Returns the trip count of the loop if it's a constant, None otherwise. This

View File

@ -17,8 +17,8 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/MathExtras.h"
using namespace mlir;
@ -165,8 +165,8 @@ AffineExpr *AffineBinaryOpExpr::simplifyFloorDiv(AffineExpr *lhs,
auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
if (lhsConst && rhsConst)
return AffineConstantExpr::get(lhsConst->getValue() / rhsConst->getValue(),
context);
return AffineConstantExpr::get(
floorDiv(lhsConst->getValue(), rhsConst->getValue()), context);
// Fold floordiv of a multiply with a constant that is a multiple of the
// divisor. Eg: (i * 128) floordiv 64 = i * 2.
@ -199,9 +199,7 @@ AffineExpr *AffineBinaryOpExpr::simplifyCeilDiv(AffineExpr *lhs,
if (lhsConst && rhsConst)
return AffineConstantExpr::get(
(int64_t)llvm::divideCeil((uint64_t)lhsConst->getValue(),
(uint64_t)rhsConst->getValue()),
context);
ceilDiv(lhsConst->getValue(), rhsConst->getValue()), context);
// Fold ceildiv of a multiply with a constant that is a multiple of the
// divisor. Eg: (i * 128) ceildiv 64 = i * 2.
@ -232,8 +230,8 @@ AffineExpr *AffineBinaryOpExpr::simplifyMod(AffineExpr *lhs, AffineExpr *rhs,
auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
if (lhsConst && rhsConst)
return AffineConstantExpr::get(lhsConst->getValue() % rhsConst->getValue(),
context);
return AffineConstantExpr::get(
mod(lhsConst->getValue(), rhsConst->getValue()), context);
// Fold modulo of an expression that is known to be a multiple of a constant
// to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)

View File

@ -23,6 +23,7 @@
#include "mlir/IR/OperationSet.h"
#include "mlir/IR/SSAValue.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
@ -201,17 +202,14 @@ public:
return constantFoldBinExpr(
expr, [](int64_t lhs, int64_t rhs) { return lhs * rhs; });
case AffineExpr::Kind::Mod:
return constantFoldBinExpr(expr, [](int64_t lhs, uint64_t rhs) {
return lhs % rhs < 0 ? lhs % rhs + rhs : lhs % rhs;
});
return constantFoldBinExpr(
expr, [](int64_t lhs, uint64_t rhs) { return mod(lhs, rhs); });
case AffineExpr::Kind::FloorDiv:
return constantFoldBinExpr(expr, [](int64_t lhs, uint64_t rhs) {
return lhs % rhs < 0 ? lhs / rhs - 1 : lhs / rhs;
});
return constantFoldBinExpr(
expr, [](int64_t lhs, uint64_t rhs) { return floorDiv(lhs, rhs); });
case AffineExpr::Kind::CeilDiv:
return constantFoldBinExpr(expr, [](int64_t lhs, uint64_t rhs) {
return lhs % rhs == 0 ? lhs / rhs : lhs / rhs + 1;
});
return constantFoldBinExpr(
expr, [](int64_t lhs, uint64_t rhs) { return ceilDiv(lhs, rhs); });
case AffineExpr::Kind::Constant:
return IntegerAttr::get(cast<AffineConstantExpr>(expr)->getValue(),
context);

View File

@ -171,6 +171,10 @@
// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0 * 2 + 1, d1 + s0)
#map50 = (i, j)[s0] -> ( (i * 2 + 1) ceildiv 1, (j + s0) floordiv 1)
// floordiv, ceildiv, and mod where LHS is negative.
// CHECK: #map{{[0-9]+}} = (d0) -> (-2, 1, -1)
#map51 = (i) -> (-5 floordiv 3, -5 mod 3, -5 ceildiv 3)
// CHECK: extfunc @f0(memref<2x4xi8, #map{{[0-9]+}}, 1>)
extfunc @f0(memref<2x4xi8, #map0, 1>)
@ -338,3 +342,6 @@ extfunc @f49(memref<100x100xi8, #map49>)
// CHECK: extfunc @f50(memref<100x100xi8, #map{{[0-9]+}}>)
extfunc @f50(memref<100x100xi8, #map50>)
// CHECK: extfunc @f51(memref<1xi8, #map{{[0-9]+}}>)
extfunc @f51(memref<1xi8, #map51>)