forked from OSchip/llvm-project
[mlir] refactor common idiom into AffineMap method
motivated by a refactoring in the new sparse code (yet to be merged), this avoids some lengthy code dup Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D91465
This commit is contained in:
parent
e7ed276532
commit
9ddb464d37
|
@ -125,6 +125,10 @@ public:
|
||||||
ArrayRef<AffineExpr> getResults() const;
|
ArrayRef<AffineExpr> getResults() const;
|
||||||
AffineExpr getResult(unsigned idx) const;
|
AffineExpr getResult(unsigned idx) const;
|
||||||
|
|
||||||
|
/// Extracts the position of the dimensional expression at the given result,
|
||||||
|
/// when the caller knows it is safe to do so.
|
||||||
|
unsigned getDimPosition(unsigned idx) const;
|
||||||
|
|
||||||
/// Walk all of the AffineExpr's in this mapping. Each node in an expression
|
/// Walk all of the AffineExpr's in this mapping. Each node in an expression
|
||||||
/// tree is visited in postorder.
|
/// tree is visited in postorder.
|
||||||
void walkExprs(std::function<void(AffineExpr)> callback) const;
|
void walkExprs(std::function<void(AffineExpr)> callback) const;
|
||||||
|
|
|
@ -466,9 +466,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
|
||||||
unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
|
unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
|
||||||
AffineMap foldedDims = reassociationMaps[resultExpr.index()];
|
AffineMap foldedDims = reassociationMaps[resultExpr.index()];
|
||||||
numFoldedDims[pos] = foldedDims.getNumResults();
|
numFoldedDims[pos] = foldedDims.getNumResults();
|
||||||
ArrayRef<int64_t> shape = expandedShape.slice(
|
ArrayRef<int64_t> shape =
|
||||||
foldedDims.getResult(0).cast<AffineDimExpr>().getPosition(),
|
expandedShape.slice(foldedDims.getDimPosition(0), numFoldedDims[pos]);
|
||||||
numFoldedDims[pos]);
|
|
||||||
expandedDimsShape[pos].assign(shape.begin(), shape.end());
|
expandedDimsShape[pos].assign(shape.begin(), shape.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -336,7 +336,7 @@ static LogicalResult verifyOutputShape(
|
||||||
VectorType v = pair.first;
|
VectorType v = pair.first;
|
||||||
auto map = pair.second;
|
auto map = pair.second;
|
||||||
for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
|
for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
|
||||||
unsigned pos = map.getResult(idx).cast<AffineDimExpr>().getPosition();
|
unsigned pos = map.getDimPosition(idx);
|
||||||
if (!extents[pos])
|
if (!extents[pos])
|
||||||
extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
|
extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
|
||||||
}
|
}
|
||||||
|
@ -785,8 +785,7 @@ static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
|
||||||
if (insertedPos.size() == extractedPos.size()) {
|
if (insertedPos.size() == extractedPos.size()) {
|
||||||
bool fold = true;
|
bool fold = true;
|
||||||
for (unsigned idx = 0, sz = extractedPos.size(); idx < sz; ++idx) {
|
for (unsigned idx = 0, sz = extractedPos.size(); idx < sz; ++idx) {
|
||||||
auto pos =
|
auto pos = permutationMap.getDimPosition(idx);
|
||||||
permutationMap.getResult(idx).cast<AffineDimExpr>().getPosition();
|
|
||||||
if (pos >= sz || insertedPos[pos] != extractedPos[idx]) {
|
if (pos >= sz || insertedPos[pos] != extractedPos[idx]) {
|
||||||
fold = false;
|
fold = false;
|
||||||
break;
|
break;
|
||||||
|
|
|
@ -50,7 +50,7 @@ using llvm::dbgs;
|
||||||
// Helper to find an index in an affine map.
|
// Helper to find an index in an affine map.
|
||||||
static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
|
static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
|
||||||
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
|
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
|
||||||
int64_t idx = map.getResult(i).cast<AffineDimExpr>().getPosition();
|
int64_t idx = map.getDimPosition(i);
|
||||||
if (idx == index)
|
if (idx == index)
|
||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
|
@ -76,7 +76,7 @@ static AffineMap adjustMap(AffineMap map, int64_t index,
|
||||||
auto *ctx = rewriter.getContext();
|
auto *ctx = rewriter.getContext();
|
||||||
SmallVector<AffineExpr, 4> results;
|
SmallVector<AffineExpr, 4> results;
|
||||||
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
|
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
|
||||||
int64_t idx = map.getResult(i).cast<AffineDimExpr>().getPosition();
|
int64_t idx = map.getDimPosition(i);
|
||||||
if (idx == index)
|
if (idx == index)
|
||||||
continue;
|
continue;
|
||||||
// Re-insert remaining indices, but renamed when occurring
|
// Re-insert remaining indices, but renamed when occurring
|
||||||
|
@ -2016,16 +2016,13 @@ Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
|
||||||
int64_t iterIndex = -1;
|
int64_t iterIndex = -1;
|
||||||
int64_t dimSize = -1;
|
int64_t dimSize = -1;
|
||||||
if (lhsIndex >= 0) {
|
if (lhsIndex >= 0) {
|
||||||
iterIndex = iMap[0].getResult(lhsIndex).cast<AffineDimExpr>().getPosition();
|
iterIndex = iMap[0].getDimPosition(lhsIndex);
|
||||||
assert(
|
assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) &&
|
||||||
(rhsIndex < 0 ||
|
"parallel index should be free in LHS or batch in LHS/RHS");
|
||||||
iterIndex ==
|
|
||||||
iMap[1].getResult(rhsIndex).cast<AffineDimExpr>().getPosition()) &&
|
|
||||||
"parallel index should be free in LHS or batch in LHS/RHS");
|
|
||||||
dimSize = lhsType.getDimSize(lhsIndex);
|
dimSize = lhsType.getDimSize(lhsIndex);
|
||||||
} else {
|
} else {
|
||||||
assert(rhsIndex >= 0 && "missing parallel index");
|
assert(rhsIndex >= 0 && "missing parallel index");
|
||||||
iterIndex = iMap[1].getResult(rhsIndex).cast<AffineDimExpr>().getPosition();
|
iterIndex = iMap[1].getDimPosition(rhsIndex);
|
||||||
dimSize = rhsType.getDimSize(rhsIndex);
|
dimSize = rhsType.getDimSize(rhsIndex);
|
||||||
}
|
}
|
||||||
assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
|
assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
|
||||||
|
|
|
@ -227,6 +227,10 @@ AffineExpr AffineMap::getResult(unsigned idx) const {
|
||||||
return map->results[idx];
|
return map->results[idx];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
unsigned AffineMap::getDimPosition(unsigned idx) const {
|
||||||
|
return getResult(idx).cast<AffineDimExpr>().getPosition();
|
||||||
|
}
|
||||||
|
|
||||||
/// Folds the results of the application of an affine map on the provided
|
/// Folds the results of the application of an affine map on the provided
|
||||||
/// operands to a constant if possible. Returns false if the folding happens,
|
/// operands to a constant if possible. Returns false if the folding happens,
|
||||||
/// true otherwise.
|
/// true otherwise.
|
||||||
|
|
Loading…
Reference in New Issue