forked from OSchip/llvm-project
[mlir] Prevent assertion failure in DropUnitDims
Don't trigger an assertion failure on strided memrefs (memrefs with affine maps) when dropping unit dims; instead, leave such operands unchanged. Differential Revision: https://reviews.llvm.org/D108205
This commit is contained in:
parent
0080d2aa55
commit
44485fcd97
|
@ -267,9 +267,9 @@ struct UnitExtentReplacementInfo {
|
||||||
/// - modified index map that can be used to access the replaced result/operand
|
/// - modified index map that can be used to access the replaced result/operand
|
||||||
/// - the reassociation that converts from the original tensor type to the
|
/// - the reassociation that converts from the original tensor type to the
|
||||||
/// modified tensor type.
|
/// modified tensor type.
|
||||||
static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
|
static llvm::Optional<UnitExtentReplacementInfo>
|
||||||
OpOperand *opOperand,
|
replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
|
AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
|
||||||
ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
|
ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
|
||||||
ArrayRef<AffineExpr> exprs = indexingMap.getResults();
|
ArrayRef<AffineExpr> exprs = indexingMap.getResults();
|
||||||
|
@ -284,6 +284,14 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
|
||||||
return shape[dim] == 1 && exprs[dim] == zeroExpr;
|
return shape[dim] == 1 && exprs[dim] == zeroExpr;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Bail out early on memrefs with affine maps: returning llvm::None signals
|
||||||
|
// to the caller that such operands are always left unchanged.
|
||||||
|
Type actualType = opOperand->get().getType();
|
||||||
|
if (auto memref = actualType.dyn_cast<MemRefType>()) {
|
||||||
|
if (!memref.getAffineMaps().empty())
|
||||||
|
return llvm::None;
|
||||||
|
}
|
||||||
|
|
||||||
int64_t dim = 0;
|
int64_t dim = 0;
|
||||||
// Fold dimensions that are unit-extent at the beginning of the tensor.
|
// Fold dimensions that are unit-extent at the beginning of the tensor.
|
||||||
while (dim < origRank && isUnitExtent(dim))
|
while (dim < origRank && isUnitExtent(dim))
|
||||||
|
@ -302,8 +310,8 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
|
||||||
reassociations.clear();
|
reassociations.clear();
|
||||||
++dim;
|
++dim;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compute the tensor or scalar replacement type.
|
// Compute the tensor or scalar replacement type.
|
||||||
Type actualType = opOperand->get().getType();
|
|
||||||
Type elementType = getElementTypeOrSelf(opOperand->get());
|
Type elementType = getElementTypeOrSelf(opOperand->get());
|
||||||
Type replacementType;
|
Type replacementType;
|
||||||
if (elementType == opOperand->get().getType()) {
|
if (elementType == opOperand->get().getType()) {
|
||||||
|
@ -311,8 +319,6 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
|
||||||
} else if (actualType.isa<RankedTensorType>()) {
|
} else if (actualType.isa<RankedTensorType>()) {
|
||||||
replacementType = RankedTensorType::get(newShape, elementType);
|
replacementType = RankedTensorType::get(newShape, elementType);
|
||||||
} else if (actualType.isa<MemRefType>()) {
|
} else if (actualType.isa<MemRefType>()) {
|
||||||
assert(actualType.cast<MemRefType>().getAffineMaps().empty() &&
|
|
||||||
"unsupported strided memrefs");
|
|
||||||
replacementType = MemRefType::get(newShape, elementType);
|
replacementType = MemRefType::get(newShape, elementType);
|
||||||
}
|
}
|
||||||
assert(replacementType && "unsupported shaped type");
|
assert(replacementType && "unsupported shaped type");
|
||||||
|
@ -390,12 +396,28 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
|
||||||
SmallVector<Type> newInputOutputTypes;
|
SmallVector<Type> newInputOutputTypes;
|
||||||
bool doCanonicalization = false;
|
bool doCanonicalization = false;
|
||||||
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
|
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
|
||||||
UnitExtentReplacementInfo replacementInfo =
|
auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
|
||||||
replaceUnitExtents(genericOp, opOperand, context);
|
if (replacementInfo) {
|
||||||
reassociationMaps.push_back(replacementInfo.reassociation);
|
reassociationMaps.push_back(replacementInfo->reassociation);
|
||||||
newIndexingMaps.push_back(replacementInfo.indexMap);
|
newIndexingMaps.push_back(replacementInfo->indexMap);
|
||||||
newInputOutputTypes.push_back(replacementInfo.type);
|
newInputOutputTypes.push_back(replacementInfo->type);
|
||||||
doCanonicalization |= replacementInfo.type != opOperand->get().getType();
|
doCanonicalization |=
|
||||||
|
replacementInfo->type != opOperand->get().getType();
|
||||||
|
} else {
|
||||||
|
// If replaceUnitExtents cannot handle this operand, keep its type and
|
||||||
|
// indexing map unchanged, and build one reassociation map per dimension
|
||||||
|
// (dim -> dim), i.e. an identity reassociation.
|
||||||
|
newInputOutputTypes.push_back(opOperand->get().getType());
|
||||||
|
newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
|
||||||
|
int64_t origRank = genericOp.getRank(opOperand);
|
||||||
|
auto maps = llvm::to_vector<8>(llvm::map_range(
|
||||||
|
llvm::seq<int64_t>(0, origRank), [&](int64_t dim) -> Attribute {
|
||||||
|
return AffineMapAttr::get(
|
||||||
|
AffineMap::get(origRank, /*symbolCount = */ 0,
|
||||||
|
getAffineDimExpr(dim, context), context));
|
||||||
|
}));
|
||||||
|
reassociationMaps.push_back(ArrayAttr::get(context, maps));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the indexing maps of the result operation are not invertible (i.e. not
|
// If the indexing maps of the result operation are not invertible (i.e. not
|
||||||
|
|
|
@ -750,4 +750,50 @@ func @fold_unit_dim_for_init_memref(%input: memref<1x1000xf32>) -> memref<1xf32>
|
||||||
// CHECK: return %[[INIT:.+]] : memref<1xf32>
|
// CHECK: return %[[INIT:.+]] : memref<1xf32>
|
||||||
|
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// Test that memrefs with affine maps are left unchanged and no assertion
|
||||||
|
// fires, while unit dims are still dropped from the other operands.
|
||||||
|
|
||||||
|
#map0 = affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 + d2)>
|
||||||
|
|
||||||
|
#accesses = [
|
||||||
|
affine_map<(i, j, k, l, m) -> (i, k, m)>,
|
||||||
|
affine_map<(i, j, k, l, m) -> ()>,
|
||||||
|
affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
|
||||||
|
]
|
||||||
|
|
||||||
|
#trait = {
|
||||||
|
iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
|
||||||
|
indexing_maps = #accesses,
|
||||||
|
library_call = "some_external_func"
|
||||||
|
}
|
||||||
|
|
||||||
|
func @input_stays_same(%arg0 : memref<?x1x?xf32, #map0>, %arg1 : f32, %shape: memref<?x1x?x1x?xf32>) -> memref<?x1x?x1x?xf32> {
|
||||||
|
linalg.generic #trait
|
||||||
|
ins(%arg0, %arg1 : memref<?x1x?xf32, #map0>, f32)
|
||||||
|
outs(%shape : memref<?x1x?x1x?xf32>) {
|
||||||
|
^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32) :
|
||||||
|
linalg.yield %arg3 : f32
|
||||||
|
}
|
||||||
|
return %shape : memref<?x1x?x1x?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 + d2)>
|
||||||
|
// CHECK: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d2)>
|
||||||
|
// CHECK: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
|
||||||
|
// CHECK: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||||
|
// CHECK: builtin.func @input_stays_same(
|
||||||
|
// CHECK-SAME: %[[ARG0:.*]]: memref<?x1x?xf32, #[[MAP0]]>,
|
||||||
|
// CHECK-SAME: %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<?x1x?x1x?xf32>)
|
||||||
|
// CHECK-SAME: -> memref<?x1x?x1x?xf32> {
|
||||||
|
// CHECK: %[[OUT:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1], [2, 3], [4]]
|
||||||
|
// CHECK-SAME: : memref<?x1x?x1x?xf32> into memref<?x?x?xf32>
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-SAME: {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]],
|
||||||
|
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
|
||||||
|
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : memref<?x1x?xf32, #[[MAP0]]>, f32)
|
||||||
|
// CHECK-SAME: outs(%[[OUT]] : memref<?x?x?xf32>) {
|
||||||
|
// CHECK: ^bb0(%{{.*}}: f32, %[[ARG:.*]]: f32, %{{.*}}: f32): // no predecessors
|
||||||
|
// CHECK: linalg.yield %[[ARG]] : f32
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: return %[[ARG2]] : memref<?x1x?x1x?xf32>
|
||||||
|
|
Loading…
Reference in New Issue