forked from OSchip/llvm-project
[mlir][linalg] Cleanup LinalgOp usage in drop unit dims.
Replace the uses of deprecated Structured Op Interface methods in DropUnitDims.cpp. This patch is based on https://reviews.llvm.org/D103394. Differential Revision: https://reviews.llvm.org/D103448
This commit is contained in:
parent
728cc0075e
commit
c698505257
|
@ -183,9 +183,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
|
|||
AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
|
||||
if (!invertedMap)
|
||||
return failure();
|
||||
SmallVector<int64_t, 4> dims;
|
||||
for (ShapedType shapedType : genericOp.getShapedOperandTypes())
|
||||
dims.append(shapedType.getShape().begin(), shapedType.getShape().end());
|
||||
SmallVector<int64_t> dims = genericOp.getStaticShape();
|
||||
|
||||
// Find all the reduction iterators. Those need some special consideration
|
||||
// (see below).
|
||||
|
@ -267,17 +265,18 @@ struct UnitExtentReplacementInfo {
|
|||
/// - modified index map that can be used to access the replaced result/operand
|
||||
/// - the reassociation that converts from the original tensor type to the
|
||||
/// modified tensor type.
|
||||
static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
|
||||
RankedTensorType type,
|
||||
static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
|
||||
OpOperand *opOperand,
|
||||
MLIRContext *context) {
|
||||
ArrayRef<int64_t> shape = type.getShape();
|
||||
ArrayRef<AffineExpr> exprs = indexMap.getResults();
|
||||
AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
|
||||
ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
|
||||
ArrayRef<AffineExpr> exprs = indexingMap.getResults();
|
||||
SmallVector<AffineExpr, 2> reassociations;
|
||||
SmallVector<Attribute, 4> reassociationMaps;
|
||||
SmallVector<AffineExpr, 4> newIndexExprs;
|
||||
SmallVector<int64_t, 4> newShape;
|
||||
|
||||
int64_t origRank = type.getRank();
|
||||
int64_t origRank = genericOp.getRank(opOperand);
|
||||
AffineExpr zeroExpr = getAffineConstantExpr(0, context);
|
||||
auto isUnitExtent = [&](int64_t dim) -> bool {
|
||||
return shape[dim] == 1 && exprs[dim] == zeroExpr;
|
||||
|
@ -302,8 +301,9 @@ static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
|
|||
++dim;
|
||||
}
|
||||
UnitExtentReplacementInfo info = {
|
||||
RankedTensorType::get(newShape, type.getElementType()),
|
||||
AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(),
|
||||
RankedTensorType::get(newShape,
|
||||
getElementTypeOrSelf(opOperand->get().getType())),
|
||||
AffineMap::get(indexingMap.getNumDims(), indexingMap.getNumSymbols(),
|
||||
newIndexExprs, context),
|
||||
ArrayAttr::get(context, reassociationMaps)};
|
||||
return info;
|
||||
|
@ -335,15 +335,13 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
|
|||
SmallVector<ArrayAttr, 4> reassociationMaps;
|
||||
SmallVector<ShapedType, 4> newInputOutputTypes;
|
||||
bool doCanonicalization = false;
|
||||
for (auto it : llvm::zip(genericOp.getIndexingMaps(),
|
||||
genericOp.getShapedOperandTypes())) {
|
||||
auto replacementInfo = replaceUnitExtents(
|
||||
std::get<0>(it), std::get<1>(it).template cast<RankedTensorType>(),
|
||||
context);
|
||||
|
||||
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
|
||||
auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
|
||||
reassociationMaps.push_back(replacementInfo.reassociation);
|
||||
newIndexingMaps.push_back(replacementInfo.indexMap);
|
||||
newInputOutputTypes.push_back(replacementInfo.type);
|
||||
doCanonicalization |= replacementInfo.type != std::get<1>(it);
|
||||
doCanonicalization |= replacementInfo.type != opOperand->get().getType();
|
||||
}
|
||||
|
||||
// If the indexing maps of the result operation are not invertible (i.e. not
|
||||
|
|
Loading…
Reference in New Issue