[mlir] Prevent assertion failure in DropUnitDims

Don't assert-fail on strided memrefs when dropping unit dims.
Instead, leave them unchanged.

Differential Revision: https://reviews.llvm.org/D108205
Tres Popp 2021-08-17 15:28:26 +02:00
parent 0080d2aa55
commit 44485fcd97
2 changed files with 80 additions and 12 deletions
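
Before the diff, here is the shape of the control-flow change in one place: the helper now declines operands it cannot handle instead of asserting, and the caller keeps such operands untouched. This is only a minimal, hypothetical sketch in plain C++17, with std::optional standing in for llvm::Optional and toy structs in place of the real MLIR types:

#include <iostream>
#include <optional>
#include <string>
#include <vector>

// Toy stand-ins for the MLIR types involved; the names are illustrative only.
struct Operand {
  std::string type;
  bool hasNonTrivialLayout; // analogous to a memref carrying an affine map
};

struct ReplacementInfo {
  std::string newType; // the type with unit dims dropped
};

// Analogous to replaceUnitExtents after this patch: return std::nullopt for
// unsupported operands instead of asserting.
std::optional<ReplacementInfo> tryReplaceUnitExtents(const Operand &operand) {
  if (operand.hasNonTrivialLayout)
    return std::nullopt; // previously an assert: "unsupported strided memrefs"
  return ReplacementInfo{"collapsed(" + operand.type + ")"};
}

int main() {
  std::vector<Operand> operands = {
      {"memref<?x1x?x1x?xf32>", /*hasNonTrivialLayout=*/false},
      {"memref<?x1x?xf32, #map0>", /*hasNonTrivialLayout=*/true}};
  for (const Operand &operand : operands) {
    // Analogous to the caller in ReplaceUnitExtents: fall back to keeping the
    // original operand when the helper declines.
    if (std::optional<ReplacementInfo> info = tryReplaceUnitExtents(operand))
      std::cout << operand.type << " -> " << info->newType << "\n";
    else
      std::cout << operand.type << " left unchanged\n";
  }
  return 0;
}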


@@ -267,9 +267,9 @@ struct UnitExtentReplacementInfo {
 /// - modified index map that can be used to access the replaced result/operand
 /// - the reassociation that converts from the original tensor type to the
 /// modified tensor type.
-static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
-                                                    OpOperand *opOperand,
-                                                    MLIRContext *context) {
+static llvm::Optional<UnitExtentReplacementInfo>
+replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
+                   MLIRContext *context) {
   AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
   ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
   ArrayRef<AffineExpr> exprs = indexingMap.getResults();
@@ -284,6 +284,14 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
     return shape[dim] == 1 && exprs[dim] == zeroExpr;
   };
 
+  // Early return for memrefs with affine maps to represent that we will always
+  // leave them unchanged.
+  Type actualType = opOperand->get().getType();
+  if (auto memref = actualType.dyn_cast<MemRefType>()) {
+    if (!memref.getAffineMaps().empty())
+      return llvm::None;
+  }
+
   int64_t dim = 0;
   // Fold dimensions that are unit-extent at the beginning of the tensor.
   while (dim < origRank && isUnitExtent(dim))
@@ -302,8 +310,8 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
     reassociations.clear();
     ++dim;
   }
+
   // Compute the tensor or scalar replacement type.
-  Type actualType = opOperand->get().getType();
   Type elementType = getElementTypeOrSelf(opOperand->get());
   Type replacementType;
   if (elementType == opOperand->get().getType()) {
@@ -311,8 +319,6 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
   } else if (actualType.isa<RankedTensorType>()) {
     replacementType = RankedTensorType::get(newShape, elementType);
   } else if (actualType.isa<MemRefType>()) {
-    assert(actualType.cast<MemRefType>().getAffineMaps().empty() &&
-           "unsupported strided memrefs");
     replacementType = MemRefType::get(newShape, elementType);
   }
   assert(replacementType && "unsupported shaped type");
@@ -390,12 +396,28 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
     SmallVector<Type> newInputOutputTypes;
     bool doCanonicalization = false;
     for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
-      UnitExtentReplacementInfo replacementInfo =
-          replaceUnitExtents(genericOp, opOperand, context);
-      reassociationMaps.push_back(replacementInfo.reassociation);
-      newIndexingMaps.push_back(replacementInfo.indexMap);
-      newInputOutputTypes.push_back(replacementInfo.type);
-      doCanonicalization |= replacementInfo.type != opOperand->get().getType();
+      auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
+      if (replacementInfo) {
+        reassociationMaps.push_back(replacementInfo->reassociation);
+        newIndexingMaps.push_back(replacementInfo->indexMap);
+        newInputOutputTypes.push_back(replacementInfo->type);
+        doCanonicalization |=
+            replacementInfo->type != opOperand->get().getType();
+      } else {
+        // If replaceUnitExtents cannot handle this case, maintain the same
+        // type, indexing map, and create a set of mappings representing an
+        // identity matrix.
+        newInputOutputTypes.push_back(opOperand->get().getType());
+        newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
+        int64_t origRank = genericOp.getRank(opOperand);
+        auto maps = llvm::to_vector<8>(llvm::map_range(
+            llvm::seq<int64_t>(0, origRank), [&](int64_t dim) -> Attribute {
+              return AffineMapAttr::get(
+                  AffineMap::get(origRank, /*symbolCount = */ 0,
+                                 getAffineDimExpr(dim, context), context));
+            }));
+        reassociationMaps.push_back(ArrayAttr::get(context, maps));
+      }
     }
     // If the indexing maps of the result operation are not invertible (i.e. not
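
A note on the fallback branch above: the "set of mappings representing an identity matrix" amounts to one single-result affine map per original dimension of the operand. The following standalone, hypothetical C++ sketch only prints MLIR's textual affine_map form of that reassociation (it does not use the real attribute-building API), for example for the rank-3 strided input in the test below:

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Illustrative only: print the identity reassociation built in the fallback
// branch, i.e. one affine map per original dimension mapping (d0..dN-1) -> (di).
std::vector<std::string> identityReassociation(int64_t rank) {
  std::string dims;
  for (int64_t d = 0; d < rank; ++d)
    dims += (d ? ", d" : "d") + std::to_string(d);
  std::vector<std::string> maps;
  for (int64_t d = 0; d < rank; ++d)
    maps.push_back("affine_map<(" + dims + ") -> (d" + std::to_string(d) + ")>");
  return maps;
}

int main() {
  // For a rank-3 operand this prints:
  //   affine_map<(d0, d1, d2) -> (d0)>
  //   affine_map<(d0, d1, d2) -> (d1)>
  //   affine_map<(d0, d1, d2) -> (d2)>
  for (const std::string &map : identityReassociation(3))
    std::cout << map << "\n";
  return 0;
}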


@@ -750,4 +750,50 @@ func @fold_unit_dim_for_init_memref(%input: memref<1x1000xf32>) -> memref<1xf32>
 // CHECK: return %[[INIT:.+]] : memref<1xf32>
 
 // -----
 
+// Test that nothing changes and no assertions are fired for memrefs with affine
+// maps while still changing the other operations.
+#map0 = affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 + d2)>
+#accesses = [
+  affine_map<(i, j, k, l, m) -> (i, k, m)>,
+  affine_map<(i, j, k, l, m) -> ()>,
+  affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
+]
+#trait = {
+  iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+  indexing_maps = #accesses,
+  library_call = "some_external_func"
+}
+func @input_stays_same(%arg0 : memref<?x1x?xf32, #map0>, %arg1 : f32, %shape: memref<?x1x?x1x?xf32>) -> memref<?x1x?x1x?xf32> {
+  linalg.generic #trait
+    ins(%arg0, %arg1 : memref<?x1x?xf32, #map0>, f32)
+    outs(%shape : memref<?x1x?x1x?xf32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32) :
+    linalg.yield %arg3 : f32
+  }
+  return %shape : memref<?x1x?x1x?xf32>
+}
+
+// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 + d2)>
+// CHECK: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d2)>
+// CHECK: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
+// CHECK: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: builtin.func @input_stays_same(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x1x?xf32, #[[MAP0]]>,
+// CHECK-SAME: %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<?x1x?x1x?xf32>)
+// CHECK-SAME: -> memref<?x1x?x1x?xf32> {
+// CHECK: %[[OUT:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1], [2, 3], [4]]
+// CHECK-SAME: : memref<?x1x?x1x?xf32> into memref<?x?x?xf32>
+// CHECK: linalg.generic
+// CHECK-SAME: {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : memref<?x1x?xf32, #[[MAP0]]>, f32)
+// CHECK-SAME: outs(%[[OUT]] : memref<?x?x?xf32>) {
+// CHECK: ^bb0(%{{.*}}: f32, %[[ARG:.*]]: f32, %{{.*}}: f32): // no predecessors
+// CHECK: linalg.yield %[[ARG]] : f32
+// CHECK: }
+// CHECK: return %[[ARG2]] : memref<?x1x?x1x?xf32>