[mlir][linalg] Remove duplicate methods (NFC).

Remove duplicate methods used to check iterator types.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D108102
This commit is contained in:
Tobias Gysi 2021-08-17 07:04:21 +00:00
parent bcec4ccd04
commit 583a754248
5 changed files with 9 additions and 43 deletions

View File

@ -56,19 +56,6 @@ SmallVector<Value, 4> getDynOperands(Location loc, Value val, OpBuilder &b);
/// Otherwise return nullptr. /// Otherwise return nullptr.
IntegerAttr getSmallestBoundingIndex(Value size); IntegerAttr getSmallestBoundingIndex(Value size);
//===----------------------------------------------------------------------===//
// Iterator type utilities
//===----------------------------------------------------------------------===//
/// Checks if an iterator_type attribute is parallel.
bool isParallelIteratorType(Attribute attr);
/// Checks if an iterator_type attribute is parallel.
bool isReductionIteratorType(Attribute attr);
/// Checks if an iterator_type attribute is parallel.
bool isWindowIteratorType(Attribute attr);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Fusion utilities // Fusion utilities
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -53,7 +53,7 @@ struct DistributeTiledLoopPattern
if (procInfoCallback == options.procInfoMap.end()) if (procInfoCallback == options.procInfoMap.end())
continue; continue;
if (!isParallelIteratorType(op.iterator_types()[i])) { if (!isParallelIterator(op.iterator_types()[i])) {
op.emitOpError("only support for parallel loops is implemented"); op.emitOpError("only support for parallel loops is implemented");
return failure(); return failure();
} }

View File

@ -210,7 +210,7 @@ static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType,
unsigned idx = 0; unsigned idx = 0;
SmallVector<bool> reductionMask(linalgOp.iterator_types().size(), false); SmallVector<bool> reductionMask(linalgOp.iterator_types().size(), false);
for (auto attr : linalgOp.iterator_types()) { for (auto attr : linalgOp.iterator_types()) {
if (isReductionIteratorType(attr)) if (isReductionIterator(attr))
reductionMask[idx] = true; reductionMask[idx] = true;
++idx; ++idx;
} }
@ -615,7 +615,7 @@ static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
// TODO: probably need some extra checks for reduction followed by consumer // TODO: probably need some extra checks for reduction followed by consumer
// ops that may not commute (e.g. linear reduction + non-linear instructions). // ops that may not commute (e.g. linear reduction + non-linear instructions).
static LogicalResult reductionPreconditions(LinalgOp op) { static LogicalResult reductionPreconditions(LinalgOp op) {
if (llvm::none_of(op.iterator_types(), isReductionIteratorType)) if (llvm::none_of(op.iterator_types(), isReductionIterator))
return failure(); return failure();
for (OpOperand *opOperand : op.getOutputOperands()) { for (OpOperand *opOperand : op.getOutputOperands()) {
Operation *reductionOp = getSingleBinaryOpAssumedReduction(opOperand); Operation *reductionOp = getSingleBinaryOpAssumedReduction(opOperand);

View File

@ -116,27 +116,6 @@ RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
return llvm::None; return llvm::None;
} }
bool mlir::linalg::isParallelIteratorType(Attribute attr) {
if (auto strAttr = attr.dyn_cast<StringAttr>()) {
return strAttr.getValue() == getParallelIteratorTypeName();
}
return false;
}
bool mlir::linalg::isReductionIteratorType(Attribute attr) {
if (auto strAttr = attr.dyn_cast<StringAttr>()) {
return strAttr.getValue() == getReductionIteratorTypeName();
}
return false;
}
bool mlir::linalg::isWindowIteratorType(Attribute attr) {
if (auto strAttr = attr.dyn_cast<StringAttr>()) {
return strAttr.getValue() == getWindowIteratorTypeName();
}
return false;
}
/// Explicit instantiation of loop nest generator for different loop types. /// Explicit instantiation of loop nest generator for different loop types.
template struct mlir::linalg::GenerateLoopNest<scf::ForOp>; template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>; template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
@ -233,7 +212,7 @@ void GenerateLoopNest<scf::ForOp>::doit(
// Collect loop ranges for parallel dimensions. // Collect loop ranges for parallel dimensions.
SmallVector<Range, 2> parallelLoopRanges; SmallVector<Range, 2> parallelLoopRanges;
for (auto iteratorType : enumerate(iteratorTypes)) for (auto iteratorType : enumerate(iteratorTypes))
if (isParallelIteratorType(iteratorType.value())) if (isParallelIterator(iteratorType.value()))
parallelLoopRanges.push_back(loopRanges[iteratorType.index()]); parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
// Get their distribution schemes. // Get their distribution schemes.
@ -254,7 +233,7 @@ void GenerateLoopNest<scf::ForOp>::doit(
// Filter out scf.for loops that were created out of parallel dimensions. // Filter out scf.for loops that were created out of parallel dimensions.
SmallVector<scf::ForOp, 4> loops; SmallVector<scf::ForOp, 4> loops;
for (auto iteratorType : enumerate(iteratorTypes)) for (auto iteratorType : enumerate(iteratorTypes))
if (isParallelIteratorType(iteratorType.value())) if (isParallelIterator(iteratorType.value()))
loops.push_back(loopNest.loops[iteratorType.index()]); loops.push_back(loopNest.loops[iteratorType.index()]);
// Distribute - only supports cyclic distribution for now. // Distribute - only supports cyclic distribution for now.
@ -375,7 +354,7 @@ static void generateParallelLoopNest(
// Find the outermost parallel loops and drop their types from the list. // Find the outermost parallel loops and drop their types from the list.
unsigned nLoops = iteratorTypes.size(); unsigned nLoops = iteratorTypes.size();
unsigned nOuterPar = unsigned nOuterPar =
nLoops - iteratorTypes.drop_while(isParallelIteratorType).size(); nLoops - iteratorTypes.drop_while(isParallelIterator).size();
// If there are no outer parallel loops, generate one sequential loop and // If there are no outer parallel loops, generate one sequential loop and
// recurse. Note that we wouldn't have dropped anything from `iteratorTypes` // recurse. Note that we wouldn't have dropped anything from `iteratorTypes`
@ -502,7 +481,7 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
distributionOptions->distributionMethod.end()); distributionOptions->distributionMethod.end());
SmallVector<Range, 2> parallelLoopRanges; SmallVector<Range, 2> parallelLoopRanges;
for (auto iteratorType : enumerate(iteratorTypes)) { for (auto iteratorType : enumerate(iteratorTypes)) {
if (isParallelIteratorType(iteratorType.value())) if (isParallelIterator(iteratorType.value()))
parallelLoopRanges.push_back(loopRanges[iteratorType.index()]); parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
} }
if (distributionMethod.size() < parallelLoopRanges.size()) if (distributionMethod.size() < parallelLoopRanges.size())
@ -513,7 +492,7 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
for (auto iteratorType : enumerate(iteratorTypes)) { for (auto iteratorType : enumerate(iteratorTypes)) {
if (index >= procInfo.size()) if (index >= procInfo.size())
break; break;
if (isParallelIteratorType(iteratorType.value())) { if (isParallelIterator(iteratorType.value())) {
unsigned i = iteratorType.index(); unsigned i = iteratorType.index();
updateBoundsForCyclicDistribution(b, loc, procInfo[index].procId, updateBoundsForCyclicDistribution(b, loc, procInfo[index].procId,
procInfo[index].nprocs, lbsStorage[i], procInfo[index].nprocs, lbsStorage[i],

View File

@ -743,7 +743,7 @@ static Operation *genFor(Merger &merger, CodeGen &codegen,
unsigned tensor = merger.tensor(fb); unsigned tensor = merger.tensor(fb);
assert(idx == merger.index(fb)); assert(idx == merger.index(fb));
auto iteratorTypes = op.iterator_types().getValue(); auto iteratorTypes = op.iterator_types().getValue();
bool isReduction = linalg::isReductionIteratorType(iteratorTypes[idx]); bool isReduction = isReductionIterator(iteratorTypes[idx]);
bool isSparse = merger.isDim(fb, Dim::kSparse); bool isSparse = merger.isDim(fb, Dim::kSparse);
bool isVector = isVectorFor(codegen, isInner, isSparse) && bool isVector = isVectorFor(codegen, isInner, isSparse) &&
denseUnitStrides(merger, op, idx); denseUnitStrides(merger, op, idx);