forked from OSchip/llvm-project
[mlir] Factor out constraint set creation from hoist padding.
This revision adds a ``` FlatAffineValueConstraints(ValueRange ivs, ValueRange lbs, ValueRange ubs) ``` method and use it in hoist padding. Differential Revision: https://reviews.llvm.org/D110427
This commit is contained in:
parent
77aa9ca92a
commit
1b49a72de9
|
@ -592,6 +592,20 @@ public:
|
|||
FlatAffineValueConstraints(ArrayRef<const AffineValueMap *> avmRef,
|
||||
IntegerSet set);
|
||||
|
||||
// Construct a hyperrectangular constraint set from ValueRanges that represent
|
||||
// induction variables, lower and upper bounds. `ivs`, `lbs` and `ubs` are
|
||||
// expected to match one to one. The order of variables and constraints is:
|
||||
//
|
||||
// ivs | lbs | ubs | eq/ineq
|
||||
// ----+-----+-----+---------
|
||||
// 1 -1 0 >= 0
|
||||
// ----+-----+-----+---------
|
||||
// -1 0 1 >= 0
|
||||
//
|
||||
// All dimensions as set as DimId.
|
||||
static FlatAffineValueConstraints
|
||||
getHyperrectangular(ValueRange ivs, ValueRange lbs, ValueRange ubs);
|
||||
|
||||
/// Return the kind of this FlatAffineConstraints.
|
||||
Kind getKind() const override { return Kind::FlatAffineValueConstraints; }
|
||||
|
||||
|
|
|
@ -189,6 +189,48 @@ FlatAffineValueConstraints::FlatAffineValueConstraints(IntegerSet set)
|
|||
values.resize(numIds, None);
|
||||
}
|
||||
|
||||
// Construct a hyperrectangular constraint set from ValueRanges that represent
|
||||
// induction variables, lower and upper bounds. `ivs`, `lbs` and `ubs` are
|
||||
// expected to match one to one. The order of variables and constraints is:
|
||||
//
|
||||
// ivs | lbs | ubs | eq/ineq
|
||||
// ----+-----+-----+---------
|
||||
// 1 -1 0 >= 0
|
||||
// ----+-----+-----+---------
|
||||
// -1 0 1 >= 0
|
||||
//
|
||||
// All dimensions as set as DimId.
|
||||
FlatAffineValueConstraints
|
||||
FlatAffineValueConstraints::getHyperrectangular(ValueRange ivs, ValueRange lbs,
|
||||
ValueRange ubs) {
|
||||
FlatAffineValueConstraints res;
|
||||
unsigned nIvs = ivs.size();
|
||||
assert(nIvs == lbs.size() && "expected as many lower bounds as ivs");
|
||||
assert(nIvs == ubs.size() && "expected as many upper bounds as ivs");
|
||||
|
||||
if (nIvs == 0)
|
||||
return res;
|
||||
|
||||
res.appendDimId(ivs);
|
||||
unsigned lbsStart = res.appendDimId(lbs);
|
||||
unsigned ubsStart = res.appendDimId(ubs);
|
||||
|
||||
MLIRContext *ctx = ivs.front().getContext();
|
||||
for (int ivIdx = 0, e = nIvs; ivIdx < e; ++ivIdx) {
|
||||
// iv - lb >= 0
|
||||
AffineMap lb = AffineMap::get(/*dimCount=*/3 * nIvs, /*symbolCount=*/0,
|
||||
getAffineDimExpr(lbsStart + ivIdx, ctx));
|
||||
if (failed(res.addBound(BoundType::LB, ivIdx, lb)))
|
||||
llvm_unreachable("Unexpected FlatAffineValueConstraints creation error");
|
||||
// -iv + ub >= 0
|
||||
AffineMap ub = AffineMap::get(/*dimCount=*/3 * nIvs, /*symbolCount=*/0,
|
||||
getAffineDimExpr(ubsStart + ivIdx, ctx));
|
||||
if (failed(res.addBound(BoundType::UB, ivIdx, ub)))
|
||||
llvm_unreachable("Unexpected FlatAffineValueConstraints creation error");
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
void FlatAffineConstraints::reset(unsigned numReservedInequalities,
|
||||
unsigned numReservedEqualities,
|
||||
unsigned newNumReservedCols,
|
||||
|
|
|
@ -189,43 +189,6 @@ HoistingAnalysis::HoistingAnalysis(PadTensorOp padTensorOp, int nLevels)
|
|||
valid = true;
|
||||
}
|
||||
|
||||
/// Given a set of loops, assumed to be scf::ForOp, create a constraint set
|
||||
/// containing the inequalities `iv - lb >= 0` and `-iv + ub - 1 >= 0` for each
|
||||
/// loop. The order of the constraints follows:
|
||||
///
|
||||
/// ivs | lbs | ubs | eq/ineq
|
||||
/// ----+-----+-----+---------
|
||||
/// 1 -1 0 >= 0
|
||||
/// ----+-----+-----+---------
|
||||
/// -1 0 1 >= 0
|
||||
///
|
||||
static FlatAffineValueConstraints
|
||||
initLoopIvsAndBounds(ArrayRef<scf::ForOp> loops) {
|
||||
FlatAffineValueConstraints constraints;
|
||||
// Append dims for all ivs, lbs, ubs: the order is important.
|
||||
for (scf::ForOp op : loops)
|
||||
constraints.appendDimId(op.getInductionVar());
|
||||
for (scf::ForOp op : loops)
|
||||
constraints.appendDimId(op.lowerBound());
|
||||
for (scf::ForOp op : loops)
|
||||
constraints.appendDimId(op.upperBound());
|
||||
int numLoops = loops.size();
|
||||
for (int ivIdx = 0, e = numLoops; ivIdx < e; ++ivIdx) {
|
||||
// iv - lb >= 0
|
||||
SmallVector<int64_t, 8> ineqLb(constraints.getNumCols(), 0);
|
||||
ineqLb[ivIdx] = 1;
|
||||
ineqLb[ivIdx + numLoops] = -1;
|
||||
// -iv + ub >= 0
|
||||
SmallVector<int64_t, 8> ineqUb(constraints.getNumCols(), 0);
|
||||
ineqUb[ivIdx] = -1;
|
||||
ineqUb[ivIdx + 2 * numLoops] = 1;
|
||||
ineqUb[constraints.getNumCols() - 1] = -1;
|
||||
constraints.addInequality(ineqLb);
|
||||
constraints.addInequality(ineqUb);
|
||||
}
|
||||
return constraints;
|
||||
}
|
||||
|
||||
static bool isDefinedOutsideOrConstant(scf::ForOp outer, Value v) {
|
||||
return outer.isDefinedOutsideOfLoop(v) || v.getDefiningOp<ConstantOp>();
|
||||
}
|
||||
|
@ -317,8 +280,16 @@ foldUpperBoundsIntoConstraintsSet(FlatAffineValueConstraints &constraints,
|
|||
// `backwardSlice`.
|
||||
FailureOr<SmallVector<Value>>
|
||||
HoistingAnalysis::getPackedTensorSizes(ImplicitLocOpBuilder &b) {
|
||||
FlatAffineValueConstraints constraints =
|
||||
initLoopIvsAndBounds(packingLoops.getArrayRef());
|
||||
// Create the base affine constaints for the packedLoops.
|
||||
auto constraints = FlatAffineValueConstraints::getHyperrectangular(
|
||||
llvm::to_vector<8>(llvm::map_range(
|
||||
packingLoops, [](scf::ForOp op) { return op.getInductionVar(); })),
|
||||
llvm::to_vector<8>(llvm::map_range(
|
||||
packingLoops, [](scf::ForOp op) { return op.lowerBound(); })),
|
||||
llvm::to_vector<8>(llvm::map_range(
|
||||
packingLoops, [](scf::ForOp op) { return op.upperBound(); })));
|
||||
|
||||
// Iteratively try to fold the upper bounds into the constraints set.
|
||||
if (failed(foldUpperBoundsIntoConstraintsSet(
|
||||
constraints, outermostEnclosingForOp, packingLoops.getArrayRef())))
|
||||
return failure();
|
||||
|
|
Loading…
Reference in New Issue