[mlir] Prevent assertion failure in DropUnitDims

Don't fail an assertion on strided memrefs when dropping unit dims;
instead, leave them unchanged.

Differential Revision: https://reviews.llvm.org/D108205
commit 44485fcd97
parent 0080d2aa55
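The fix turns a hard assertion into a recoverable bail-out: replaceUnitExtents now returns llvm::Optional<UnitExtentReplacementInfo>, and llvm::None tells the caller to leave that operand untouched. A minimal sketch of the idiom, using std::optional so it stands alone (the pass itself uses llvm::Optional; Info, tryRewrite, and hasStridedLayout are illustrative names, not the actual API):

    #include <optional>

    struct Info {
      int newRank; // stand-in for the folded type/map/reassociation
    };

    // Return std::nullopt for operands the rewrite cannot handle
    // (e.g. strided layouts) instead of asserting.
    std::optional<Info> tryRewrite(bool hasStridedLayout, int rank) {
      if (hasStridedLayout)
        return std::nullopt;
      return Info{rank - 1}; // e.g. one unit dim dropped
    }

    void process(bool strided, int rank) {
      if (auto info = tryRewrite(strided, rank)) {
        // Use info->newRank (the folded result).
      } else {
        // Fall back: keep the original type and an identity mapping.
      }
    }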
@@ -267,9 +267,9 @@ struct UnitExtentReplacementInfo {
 /// - modified index map that can be used to access the replaced result/operand
 /// - the reassociation that converts from the original tensor type to the
 ///   modified tensor type.
-static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
-                                                    OpOperand *opOperand,
-                                                    MLIRContext *context) {
+static llvm::Optional<UnitExtentReplacementInfo>
+replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
+                   MLIRContext *context) {
   AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
   ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
   ArrayRef<AffineExpr> exprs = indexingMap.getResults();
@@ -284,6 +284,14 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
     return shape[dim] == 1 && exprs[dim] == zeroExpr;
   };
 
+  // Early return for memrefs with affine maps to represent that we will always
+  // leave them unchanged.
+  Type actualType = opOperand->get().getType();
+  if (auto memref = actualType.dyn_cast<MemRefType>()) {
+    if (!memref.getAffineMaps().empty())
+      return llvm::None;
+  }
+
   int64_t dim = 0;
   // Fold dimensions that are unit-extent at the beginning of the tensor.
   while (dim < origRank && isUnitExtent(dim))
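For context: in this revision of MLIR a memref's layout is carried as a list of affine maps on the type, and an empty list means the default identity (contiguous) layout, so any non-empty list, such as the strided #map0 in the test below, takes the early return above. A hedged sketch of that check in isolation (the helper name is illustrative; getAffineMaps() is the pre-layout-attribute API this commit targets):

    #include "mlir/IR/BuiltinTypes.h"

    // True iff `type` is a memref whose layout is expressed through explicit
    // affine maps (identity layouts have an empty map list in this API).
    static bool hasNonIdentityLayout(mlir::Type type) {
      if (auto memref = type.dyn_cast<mlir::MemRefType>())
        return !memref.getAffineMaps().empty();
      return false;
    }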
@@ -302,8 +310,8 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
     reassociations.clear();
     ++dim;
   }
+
   // Compute the tensor or scalar replacement type.
-  Type actualType = opOperand->get().getType();
   Type elementType = getElementTypeOrSelf(opOperand->get());
   Type replacementType;
   if (elementType == opOperand->get().getType()) {
@@ -311,8 +319,6 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
   } else if (actualType.isa<RankedTensorType>()) {
     replacementType = RankedTensorType::get(newShape, elementType);
   } else if (actualType.isa<MemRefType>()) {
-    assert(actualType.cast<MemRefType>().getAffineMaps().empty() &&
-           "unsupported strided memrefs");
     replacementType = MemRefType::get(newShape, elementType);
   }
   assert(replacementType && "unsupported shaped type");
@@ -390,12 +396,28 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
     SmallVector<Type> newInputOutputTypes;
     bool doCanonicalization = false;
     for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
-      UnitExtentReplacementInfo replacementInfo =
-          replaceUnitExtents(genericOp, opOperand, context);
-      reassociationMaps.push_back(replacementInfo.reassociation);
-      newIndexingMaps.push_back(replacementInfo.indexMap);
-      newInputOutputTypes.push_back(replacementInfo.type);
-      doCanonicalization |= replacementInfo.type != opOperand->get().getType();
+      auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
+      if (replacementInfo) {
+        reassociationMaps.push_back(replacementInfo->reassociation);
+        newIndexingMaps.push_back(replacementInfo->indexMap);
+        newInputOutputTypes.push_back(replacementInfo->type);
+        doCanonicalization |=
+            replacementInfo->type != opOperand->get().getType();
+      } else {
+        // If replaceUnitExtents cannot handle this case, maintain the same
+        // type, indexing map, and create a set of mappings representing an
+        // identity matrix.
+        newInputOutputTypes.push_back(opOperand->get().getType());
+        newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
+        int64_t origRank = genericOp.getRank(opOperand);
+        auto maps = llvm::to_vector<8>(llvm::map_range(
+            llvm::seq<int64_t>(0, origRank), [&](int64_t dim) -> Attribute {
+              return AffineMapAttr::get(
+                  AffineMap::get(origRank, /*symbolCount = */ 0,
+                                 getAffineDimExpr(dim, context), context));
+            }));
+        reassociationMaps.push_back(ArrayAttr::get(context, maps));
+      }
     }
 
     // If the indexing maps of the result operation are not invertible (i.e. not
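Worked example for the new fallback branch: for a rank-3 operand that is left unchanged, `maps` holds one single-result AffineMapAttr per dimension, so the recorded reassociation is the identity

    [affine_map<(d0, d1, d2) -> (d0)>,
     affine_map<(d0, d1, d2) -> (d1)>,
     affine_map<(d0, d1, d2) -> (d2)>]

i.e. every dimension maps to itself and nothing is collapsed for that operand.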
@@ -750,4 +750,50 @@ func @fold_unit_dim_for_init_memref(%input: memref<1x1000xf32>) -> memref<1xf32>
 // CHECK: return %[[INIT:.+]] : memref<1xf32>
 
+
+// -----
+// Test that nothing changes and no assertions are fired for memrefs with affine
+// maps while still changing the other operations.
+
+#map0 = affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 + d2)>
+
+#accesses = [
+  affine_map<(i, j, k, l, m) -> (i, k, m)>,
+  affine_map<(i, j, k, l, m) -> ()>,
+  affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
+]
+
+#trait = {
+  iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+  indexing_maps = #accesses,
+  library_call = "some_external_func"
+}
+
+func @input_stays_same(%arg0 : memref<?x1x?xf32, #map0>, %arg1 : f32, %shape: memref<?x1x?x1x?xf32>) -> memref<?x1x?x1x?xf32> {
+  linalg.generic #trait
+      ins(%arg0, %arg1 : memref<?x1x?xf32, #map0>, f32)
+     outs(%shape : memref<?x1x?x1x?xf32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32) :
+    linalg.yield %arg3 : f32
+  }
+  return %shape : memref<?x1x?x1x?xf32>
+}
+
+// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 + d2)>
+// CHECK: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d2)>
+// CHECK: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
+// CHECK: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: builtin.func @input_stays_same(
+// CHECK-SAME:   %[[ARG0:.*]]: memref<?x1x?xf32, #[[MAP0]]>,
+// CHECK-SAME:   %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<?x1x?x1x?xf32>)
+// CHECK-SAME:   -> memref<?x1x?x1x?xf32> {
+// CHECK: %[[OUT:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1], [2, 3], [4]]
+// CHECK-SAME:   : memref<?x1x?x1x?xf32> into memref<?x?x?xf32>
+// CHECK: linalg.generic
+// CHECK-SAME:   {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]],
+// CHECK-SAME:   iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:   ins(%[[ARG0]], %[[ARG1]] : memref<?x1x?xf32, #[[MAP0]]>, f32)
+// CHECK-SAME:   outs(%[[OUT]] : memref<?x?x?xf32>) {
+// CHECK: ^bb0(%{{.*}}: f32, %[[ARG:.*]]: f32, %{{.*}}: f32):  // no predecessors
+// CHECK:   linalg.yield %[[ARG]] : f32
+// CHECK: }
+// CHECK: return %[[ARG2]] : memref<?x1x?x1x?xf32>
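The test file's RUN line is outside this excerpt; for this test it would presumably be along the lines of

    // RUN: mlir-opt %s -split-input-file -linalg-fold-unit-extent-dims | FileCheck %s

(pass flag assumed from the pass under test, not shown in the diff). The CHECK lines then pin down the intended behavior: %arg0 keeps its strided type memref<?x1x?xf32, #map0>, while the output operand is still collapsed from memref<?x1x?x1x?xf32> to memref<?x?x?xf32>.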