diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index b03fde6e9b37..81ab7eaa0886 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -56,19 +56,6 @@ SmallVector getDynOperands(Location loc, Value val, OpBuilder &b); /// Otherwise return nullptr. IntegerAttr getSmallestBoundingIndex(Value size); -//===----------------------------------------------------------------------===// -// Iterator type utilities -//===----------------------------------------------------------------------===// - -/// Checks if an iterator_type attribute is parallel. -bool isParallelIteratorType(Attribute attr); - -/// Checks if an iterator_type attribute is parallel. -bool isReductionIteratorType(Attribute attr); - -/// Checks if an iterator_type attribute is parallel. -bool isWindowIteratorType(Attribute attr); - //===----------------------------------------------------------------------===// // Fusion utilities //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/Distribution.cpp b/mlir/lib/Dialect/Linalg/Transforms/Distribution.cpp index 994f7c76ddfd..e951d6882022 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Distribution.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Distribution.cpp @@ -53,7 +53,7 @@ struct DistributeTiledLoopPattern if (procInfoCallback == options.procInfoMap.end()) continue; - if (!isParallelIteratorType(op.iterator_types()[i])) { + if (!isParallelIterator(op.iterator_types()[i])) { op.emitOpError("only support for parallel loops is implemented"); return failure(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index abb4328b08f1..0a622e8335c9 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -210,7 +210,7 @@ static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType, unsigned idx = 0; SmallVector reductionMask(linalgOp.iterator_types().size(), false); for (auto attr : linalgOp.iterator_types()) { - if (isReductionIteratorType(attr)) + if (isReductionIterator(attr)) reductionMask[idx] = true; ++idx; } @@ -615,7 +615,7 @@ static bool allIndexingsAreProjectedPermutation(LinalgOp op) { // TODO: probably need some extra checks for reduction followed by consumer // ops that may not commute (e.g. linear reduction + non-linear instructions). static LogicalResult reductionPreconditions(LinalgOp op) { - if (llvm::none_of(op.iterator_types(), isReductionIteratorType)) + if (llvm::none_of(op.iterator_types(), isReductionIterator)) return failure(); for (OpOperand *opOperand : op.getOutputOperands()) { Operation *reductionOp = getSingleBinaryOpAssumedReduction(opOperand); diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 1620a047390b..596ae49232c6 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -116,27 +116,6 @@ RegionMatcher::matchAsScalarBinaryOp(GenericOp op) { return llvm::None; } -bool mlir::linalg::isParallelIteratorType(Attribute attr) { - if (auto strAttr = attr.dyn_cast()) { - return strAttr.getValue() == getParallelIteratorTypeName(); - } - return false; -} - -bool mlir::linalg::isReductionIteratorType(Attribute attr) { - if (auto strAttr = attr.dyn_cast()) { - return strAttr.getValue() == getReductionIteratorTypeName(); - } - return false; -} - -bool mlir::linalg::isWindowIteratorType(Attribute attr) { - if (auto strAttr = attr.dyn_cast()) { - return strAttr.getValue() == getWindowIteratorTypeName(); - } - return false; -} - /// Explicit instantiation of loop nest generator for different loop types. template struct mlir::linalg::GenerateLoopNest; template struct mlir::linalg::GenerateLoopNest; @@ -233,7 +212,7 @@ void GenerateLoopNest::doit( // Collect loop ranges for parallel dimensions. SmallVector parallelLoopRanges; for (auto iteratorType : enumerate(iteratorTypes)) - if (isParallelIteratorType(iteratorType.value())) + if (isParallelIterator(iteratorType.value())) parallelLoopRanges.push_back(loopRanges[iteratorType.index()]); // Get their distribution schemes. @@ -254,7 +233,7 @@ void GenerateLoopNest::doit( // Filter out scf.for loops that were created out of parallel dimensions. SmallVector loops; for (auto iteratorType : enumerate(iteratorTypes)) - if (isParallelIteratorType(iteratorType.value())) + if (isParallelIterator(iteratorType.value())) loops.push_back(loopNest.loops[iteratorType.index()]); // Distribute - only supports cyclic distribution for now. @@ -375,7 +354,7 @@ static void generateParallelLoopNest( // Find the outermost parallel loops and drop their types from the list. unsigned nLoops = iteratorTypes.size(); unsigned nOuterPar = - nLoops - iteratorTypes.drop_while(isParallelIteratorType).size(); + nLoops - iteratorTypes.drop_while(isParallelIterator).size(); // If there are no outer parallel loops, generate one sequential loop and // recurse. Note that we wouldn't have dropped anything from `iteratorTypes` @@ -502,7 +481,7 @@ void GenerateLoopNest::doit( distributionOptions->distributionMethod.end()); SmallVector parallelLoopRanges; for (auto iteratorType : enumerate(iteratorTypes)) { - if (isParallelIteratorType(iteratorType.value())) + if (isParallelIterator(iteratorType.value())) parallelLoopRanges.push_back(loopRanges[iteratorType.index()]); } if (distributionMethod.size() < parallelLoopRanges.size()) @@ -513,7 +492,7 @@ void GenerateLoopNest::doit( for (auto iteratorType : enumerate(iteratorTypes)) { if (index >= procInfo.size()) break; - if (isParallelIteratorType(iteratorType.value())) { + if (isParallelIterator(iteratorType.value())) { unsigned i = iteratorType.index(); updateBoundsForCyclicDistribution(b, loc, procInfo[index].procId, procInfo[index].nprocs, lbsStorage[i], diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index b2c64e450a84..2567693c4b64 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -743,7 +743,7 @@ static Operation *genFor(Merger &merger, CodeGen &codegen, unsigned tensor = merger.tensor(fb); assert(idx == merger.index(fb)); auto iteratorTypes = op.iterator_types().getValue(); - bool isReduction = linalg::isReductionIteratorType(iteratorTypes[idx]); + bool isReduction = isReductionIterator(iteratorTypes[idx]); bool isSparse = merger.isDim(fb, Dim::kSparse); bool isVector = isVectorFor(codegen, isInner, isSparse) && denseUnitStrides(merger, op, idx);