[mlir][linalg] Cleanup LinalgOp usage in drop unit dims.

Replace the uses of deprecated Structured Op Interface methods in DropUnitDims.cpp. This patch is based on https://reviews.llvm.org/D103394.

Differential Revision: https://reviews.llvm.org/D103448
This commit is contained in:
Tobias Gysi 2021-06-03 12:24:59 +00:00
parent 728cc0075e
commit c698505257
1 changed files with 14 additions and 16 deletions

View File

@ -183,9 +183,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
if (!invertedMap)
return failure();
SmallVector<int64_t, 4> dims;
for (ShapedType shapedType : genericOp.getShapedOperandTypes())
dims.append(shapedType.getShape().begin(), shapedType.getShape().end());
SmallVector<int64_t> dims = genericOp.getStaticShape();
// Find all the reduction iterators. Those need some special consideration
// (see below).
@ -267,17 +265,18 @@ struct UnitExtentReplacementInfo {
/// - modified index map that can be used to access the replaced result/operand
/// - the reassociation that converts from the original tensor type to the
/// modified tensor type.
static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
RankedTensorType type,
static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
OpOperand *opOperand,
MLIRContext *context) {
ArrayRef<int64_t> shape = type.getShape();
ArrayRef<AffineExpr> exprs = indexMap.getResults();
AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
ArrayRef<AffineExpr> exprs = indexingMap.getResults();
SmallVector<AffineExpr, 2> reassociations;
SmallVector<Attribute, 4> reassociationMaps;
SmallVector<AffineExpr, 4> newIndexExprs;
SmallVector<int64_t, 4> newShape;
int64_t origRank = type.getRank();
int64_t origRank = genericOp.getRank(opOperand);
AffineExpr zeroExpr = getAffineConstantExpr(0, context);
auto isUnitExtent = [&](int64_t dim) -> bool {
return shape[dim] == 1 && exprs[dim] == zeroExpr;
@ -302,8 +301,9 @@ static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
++dim;
}
UnitExtentReplacementInfo info = {
RankedTensorType::get(newShape, type.getElementType()),
AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(),
RankedTensorType::get(newShape,
getElementTypeOrSelf(opOperand->get().getType())),
AffineMap::get(indexingMap.getNumDims(), indexingMap.getNumSymbols(),
newIndexExprs, context),
ArrayAttr::get(context, reassociationMaps)};
return info;
@ -335,15 +335,13 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
SmallVector<ArrayAttr, 4> reassociationMaps;
SmallVector<ShapedType, 4> newInputOutputTypes;
bool doCanonicalization = false;
for (auto it : llvm::zip(genericOp.getIndexingMaps(),
genericOp.getShapedOperandTypes())) {
auto replacementInfo = replaceUnitExtents(
std::get<0>(it), std::get<1>(it).template cast<RankedTensorType>(),
context);
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
reassociationMaps.push_back(replacementInfo.reassociation);
newIndexingMaps.push_back(replacementInfo.indexMap);
newInputOutputTypes.push_back(replacementInfo.type);
doCanonicalization |= replacementInfo.type != std::get<1>(it);
doCanonicalization |= replacementInfo.type != opOperand->get().getType();
}
// If the indexing maps of the result operation are not invertible (i.e. not