[mlir][SCF] Further simplify affine maps during `for-loop-canonicalization`
* Implement `FlatAffineConstraints::getConstantBound(EQ)`.
* Inject a simpler constraint for loops that have at most 1 iteration.
* Take constant EQ bounds of FlatAffineConstraints dims/symbols into account
  when canonicalizing the resulting affine map in `canonicalizeMinMaxOp`.

Differential Revision: https://reviews.llvm.org/D114138
commit ee1bf18672 (parent 8a8c655fe7)
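Worked example (the numbers come from the new test case at the end of this
patch): for a loop with lb = 0, ub = 4, step = 4, we have lb + step = 4 >= ub,
so the loop runs at most once and the induction variable is pinned to
iv == lb. With that equality available, `affine.min (4, ub - iv)` reduces to
min(4, 4 - 0) = 4, a constant.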
@@ -385,7 +385,6 @@ public:
   /// Returns the constant bound for the pos^th identifier if there is one;
   /// None otherwise.
-  // TODO: Support EQ bounds.
   Optional<int64_t> getConstantBound(BoundType type, unsigned pos) const;
 
   /// Gets the lower and upper bound of the `offset` + `pos`th identifier
@@ -2836,11 +2836,22 @@ FlatAffineConstraints::computeConstantLowerOrUpperBound(unsigned pos) {
 
 Optional<int64_t> FlatAffineConstraints::getConstantBound(BoundType type,
                                                           unsigned pos) const {
-  assert(type != BoundType::EQ && "EQ not implemented");
-  FlatAffineConstraints tmpCst(*this);
   if (type == BoundType::LB)
-    return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
-  return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
+    return FlatAffineConstraints(*this)
+        .computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
+  if (type == BoundType::UB)
+    return FlatAffineConstraints(*this)
+        .computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
+
+  assert(type == BoundType::EQ && "expected EQ");
+  Optional<int64_t> lb =
+      FlatAffineConstraints(*this)
+          .computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
+  Optional<int64_t> ub =
+      FlatAffineConstraints(*this)
+          .computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
+  return (lb && ub && *lb == *ub) ? Optional<int64_t>(*ub) : None;
 }
 
 // A simple (naive and conservative) check for hyper-rectangularity.
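A minimal usage sketch of the new EQ bound (illustrative only, not part of
the patch; it assumes the single-argument FlatAffineConstraints constructor
and the constant-value addBound overload of this MLIR revision):

#include "mlir/Analysis/AffineStructures.h"
#include <cassert>

using namespace mlir;

void eqBoundSketch() {
  // Sketch only: one dimension, pinned to the single value 4 by an equality.
  FlatAffineConstraints cst(/*numDims=*/1);
  cst.addBound(FlatAffineConstraints::EQ, /*pos=*/0, /*value=*/4);

  // The constant lower and upper bounds coincide (both 4), so the new EQ
  // query returns 4; it returns None whenever they differ.
  Optional<int64_t> eq = cst.getConstantBound(FlatAffineConstraints::EQ, 0);
  assert(eq && *eq == 4);
}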
@@ -305,6 +305,16 @@ canonicalizeMinMaxOp(RewriterBase &rewriter, Operation *op, AffineMap map,
   AffineMap newMap = alignedBoundMap;
   SmallVector<Value> newOperands;
   unpackOptionalValues(constraints.getMaybeDimAndSymbolValues(), newOperands);
+  // If dims/symbols have known constant values, use those in order to simplify
+  // the affine map further.
+  for (int64_t i = 0; i < constraints.getNumDimAndSymbolIds(); ++i) {
+    // Skip unused operands and operands that are already constants.
+    if (!newOperands[i] || getConstantIntValue(newOperands[i]))
+      continue;
+    if (auto bound = constraints.getConstantBound(FlatAffineConstraints::EQ, i))
+      newOperands[i] =
+          rewriter.create<arith::ConstantIndexOp>(op->getLoc(), *bound);
+  }
   mlir::canonicalizeMapAndOperands(&newMap, &newOperands);
   rewriter.setInsertionPoint(op);
   rewriter.replaceOpWithNewOp<AffineApplyOp>(op, newMap, newOperands);
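For example, if an operand of the min/max op is known to equal 4 through an
EQ bound, the loop above replaces it with `arith.constant 4 : index`;
`canonicalizeMapAndOperands` can then fold that constant directly into the
map, e.g. turning `(d0, d1) -> (4, d1 - d0)` applied to `(%iv, %ub)` into
`(d0) -> (4, 4 - d0)` applied to `(%iv)`.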
@@ -457,19 +467,30 @@ mlir::scf::canonicalizeMinMaxOpInLoop(RewriterBase &rewriter, Operation *op,
   if (ubInt)
     constraints.addBound(FlatAffineConstraints::EQ, dimUb, *ubInt);
 
-  // iv >= lb (equiv.: iv - lb >= 0)
+  // Lower bound: iv >= lb (equiv.: iv - lb >= 0)
   SmallVector<int64_t> ineqLb(constraints.getNumCols(), 0);
   ineqLb[dimIv] = 1;
   ineqLb[dimLb] = -1;
   constraints.addInequality(ineqLb);
 
-  // iv < lb + step * ((ub - lb - 1) floorDiv step) + 1
-  AffineExpr exprLb = lbInt ? rewriter.getAffineConstantExpr(*lbInt)
-                            : rewriter.getAffineDimExpr(dimLb);
-  AffineExpr exprUb = ubInt ? rewriter.getAffineConstantExpr(*ubInt)
-                            : rewriter.getAffineDimExpr(dimUb);
-  AffineExpr ivUb =
-      exprLb + 1 + (*stepInt * ((exprUb - exprLb - 1).floorDiv(*stepInt)));
+  // Upper bound
+  AffineExpr ivUb;
+  if (lbInt && ubInt && (*lbInt + *stepInt >= *ubInt)) {
+    // The loop has at most one iteration.
+    // iv < lb + 1
+    // TODO: Try to derive this constraint by simplifying the expression in
+    // the else-branch.
+    ivUb = rewriter.getAffineDimExpr(dimLb) + 1;
+  } else {
+    // The loop may have more than one iteration.
+    // iv < lb + step * ((ub - lb - 1) floorDiv step) + 1
+    AffineExpr exprLb = lbInt ? rewriter.getAffineConstantExpr(*lbInt)
+                              : rewriter.getAffineDimExpr(dimLb);
+    AffineExpr exprUb = ubInt ? rewriter.getAffineConstantExpr(*ubInt)
+                              : rewriter.getAffineDimExpr(dimUb);
+    ivUb =
+        exprLb + 1 + (*stepInt * ((exprUb - exprLb - 1).floorDiv(*stepInt)));
+  }
   auto map = AffineMap::get(
       /*dimCount=*/constraints.getNumDimIds(),
       /*symbolCount=*/constraints.getNumSymbolIds(), /*result=*/ivUb);
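Plugging in the test values below (lb = 0, ub = 4, step = 4):
`*lbInt + *stepInt` = 4 >= `*ubInt` = 4, so the one-trip branch fires and the
injected constraint is iv < lb + 1, which together with the lower bound
iv >= lb pins iv to lb. The general formula in the else-branch would yield
the same bound here, 0 + 1 + 4 * ((4 - 0 - 1) floorDiv 4) = 1, but per the
TODO the simpler dim-relative form is not yet derived from it automatically.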
@@ -348,3 +348,22 @@ func @tensor_dim_of_loop_result_no_canonicalize(%t : tensor<?x?xf32>,
   %dim = tensor.dim %1, %c0 : tensor<?x?xf32>
   return %dim : index
 }
+
+// -----
+
+// CHECK-LABEL: func @one_trip_scf_for_canonicalize_min
+// CHECK: %[[C4:.*]] = arith.constant 4 : i64
+// CHECK: scf.for
+// CHECK: memref.store %[[C4]], %{{.*}}[] : memref<i64>
+func @one_trip_scf_for_canonicalize_min(%A : memref<i64>) {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+
+  scf.for %i = %c0 to %c4 step %c4 {
+    %1 = affine.min affine_map<(d0, d1)[] -> (4, d1 - d0)> (%i, %c4)
+    %2 = arith.index_cast %1: index to i64
+    memref.store %2, %A[]: memref<i64>
+  }
+  return
+}
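In this test the canonicalized `affine.min` folds to the constant 4: the
CHECK lines verify that the resulting `arith.constant 4 : i64` (the folded
`index_cast` of the min) appears before the `scf.for`, leaving only the
store in the loop body.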