diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h index c450024dcb57..f1f267ff0fc2 100644 --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -125,6 +125,10 @@ public: ArrayRef getResults() const; AffineExpr getResult(unsigned idx) const; + /// Extracts the position of the dimensional expression at the given result, + /// when the caller knows it is safe to do so. + unsigned getDimPosition(unsigned idx) const; + /// Walk all of the AffineExpr's in this mapping. Each node in an expression /// tree is visited in postorder. void walkExprs(std::function callback) const; diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp index abc10e8f486a..8e1dbf17d3f1 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -466,9 +466,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp, unsigned pos = resultExpr.value().cast().getPosition(); AffineMap foldedDims = reassociationMaps[resultExpr.index()]; numFoldedDims[pos] = foldedDims.getNumResults(); - ArrayRef shape = expandedShape.slice( - foldedDims.getResult(0).cast().getPosition(), - numFoldedDims[pos]); + ArrayRef shape = + expandedShape.slice(foldedDims.getDimPosition(0), numFoldedDims[pos]); expandedDimsShape[pos].assign(shape.begin(), shape.end()); } diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp index 39aed718de0a..0cc1e7c07aba 100644 --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -336,7 +336,7 @@ static LogicalResult verifyOutputShape( VectorType v = pair.first; auto map = pair.second; for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) { - unsigned pos = map.getResult(idx).cast().getPosition(); + unsigned pos = map.getDimPosition(idx); if (!extents[pos]) extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx); } @@ -785,8 +785,7 @@ static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) { if (insertedPos.size() == extractedPos.size()) { bool fold = true; for (unsigned idx = 0, sz = extractedPos.size(); idx < sz; ++idx) { - auto pos = - permutationMap.getResult(idx).cast().getPosition(); + auto pos = permutationMap.getDimPosition(idx); if (pos >= sz || insertedPos[pos] != extractedPos[idx]) { fold = false; break; diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp index 49865fddba4c..e488db677fe5 100644 --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -50,7 +50,7 @@ using llvm::dbgs; // Helper to find an index in an affine map. static Optional getResultIndex(AffineMap map, int64_t index) { for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { - int64_t idx = map.getResult(i).cast().getPosition(); + int64_t idx = map.getDimPosition(i); if (idx == index) return i; } @@ -76,7 +76,7 @@ static AffineMap adjustMap(AffineMap map, int64_t index, auto *ctx = rewriter.getContext(); SmallVector results; for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { - int64_t idx = map.getResult(i).cast().getPosition(); + int64_t idx = map.getDimPosition(i); if (idx == index) continue; // Re-insert remaining indices, but renamed when occurring @@ -2016,16 +2016,13 @@ Value ContractionOpLowering::lowerParallel(vector::ContractionOp op, int64_t iterIndex = -1; int64_t dimSize = -1; if (lhsIndex >= 0) { - iterIndex = iMap[0].getResult(lhsIndex).cast().getPosition(); - assert( - (rhsIndex < 0 || - iterIndex == - iMap[1].getResult(rhsIndex).cast().getPosition()) && - "parallel index should be free in LHS or batch in LHS/RHS"); + iterIndex = iMap[0].getDimPosition(lhsIndex); + assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) && + "parallel index should be free in LHS or batch in LHS/RHS"); dimSize = lhsType.getDimSize(lhsIndex); } else { assert(rhsIndex >= 0 && "missing parallel index"); - iterIndex = iMap[1].getResult(rhsIndex).cast().getPosition(); + iterIndex = iMap[1].getDimPosition(rhsIndex); dimSize = rhsType.getDimSize(rhsIndex); } assert(iterIndex >= 0 && "parallel index not listed in operand mapping"); diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp index 1f73d07cc8ff..cc2cb8be4f3c 100644 --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -227,6 +227,10 @@ AffineExpr AffineMap::getResult(unsigned idx) const { return map->results[idx]; } +unsigned AffineMap::getDimPosition(unsigned idx) const { + return getResult(idx).cast().getPosition(); +} + /// Folds the results of the application of an affine map on the provided /// operands to a constant if possible. Returns false if the folding happens, /// true otherwise.