[mlir] Fix subview verifier.

The subview verifier was plainly skipping verification in the rank-reduced case whenever the result type was a memref with an empty affine map. This is generally incorrect. Instead, form the actual expected rank-reduced MemRefType, taking into account the projections of size-1 dimensions, and then check the canonicalized expected rank-reduced type against the canonicalized candidate type.

Differential Revision: https://reviews.llvm.org/D95316
parent 0024efc69e
commit 7e6fe5c48a
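To make the fixed check concrete, here is a hypothetical rank-reducing subview (this function is illustrative and not part of the commit). The verifier now derives the expected rank-reduced type, strides included, from the full-rank result type, instead of accepting any candidate memref whose affine map happens to be empty:

```mlir
func @rank_reducing_subview(%arg0 : memref<8x16x4xf32>) {
  // The full-rank result type is
  //   memref<8x16x1xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>.
  // Projecting away the trailing size-1 dimension keeps strides [64, 4], so
  // the rank-reduced result must carry exactly that strided layout.
  %0 = subview %arg0[0, 0, 0][8, 16, 1][1, 1, 1]
    : memref<8x16x4xf32>
    to memref<8x16xf32, affine_map<(d0, d1) -> (d0 * 64 + d1 * 4)>>
  return
}
```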
```diff
@@ -212,7 +212,6 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
   return LLVM::LLVMPointerType::get(converted);
 }
 
-
 // Function types are converted to LLVM Function types by recursively converting
 // argument and result types. If MLIR Function has zero results, the LLVM
 // Function has one VoidType result. If MLIR Function has more than one result,
```
```diff
@@ -525,10 +524,11 @@ MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
   auto result = getStridesAndOffset(type, strides, offset);
   (void)result;
   assert(succeeded(result) && "unexpected failure in stride computation");
-  assert(offset != MemRefType::getDynamicStrideOrOffset() &&
+  assert(!MemRefType::isDynamicStrideOrOffset(offset) &&
          "expected static offset");
-  assert(!llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) &&
-         "expected static strides");
+  assert(!llvm::any_of(strides, [](int64_t stride) {
+           return MemRefType::isDynamicStrideOrOffset(stride);
+         }) && "expected static strides");
 
   auto convertedType = typeConverter.convertType(type);
   assert(convertedType && "unexpected failure in memref type conversion");
```
```diff
@@ -1044,14 +1044,14 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
 
   Value index;
   if (offset != 0) // Skip if offset is zero.
-    index = offset == MemRefType::getDynamicStrideOrOffset()
+    index = MemRefType::isDynamicStrideOrOffset(offset)
                ? memRefDescriptor.offset(rewriter, loc)
                : createIndexConstant(rewriter, loc, offset);
 
   for (int i = 0, e = indices.size(); i < e; ++i) {
     Value increment = indices[i];
     if (strides[i] != 1) { // Skip if stride is 1.
-      Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset()
+      Value stride = MemRefType::isDynamicStrideOrOffset(strides[i])
                          ? memRefDescriptor.stride(rewriter, loc, i)
                          : createIndexConstant(rewriter, loc, strides[i]);
       increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
```
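For context on the hunk above: `getStridedElementPtr` linearizes an access as the offset plus a sum of index·stride products, loading any dynamic stride or offset from the memref descriptor. A minimal sketch with illustrative values (not taken from the patch):

```mlir
// With static offset 8 and strides [64, 4, 1], this load lowers to address
// arithmetic of the form base + (8 + %i * 64 + %j * 4 + %k) * sizeof(f32);
// a dynamic stride or offset would be read from the descriptor instead.
%v = load %m[%i, %j, %k]
  : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 8)>>
```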
```diff
@@ -3308,7 +3308,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
         extracted);
     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
 
-    // Copy the buffer pointer from the old descriptor to the new one.
+    // Copy the aligned pointer from the old descriptor to the new one.
     extracted = sourceMemRef.alignedPtr(rewriter, loc);
     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
         loc,
```
```diff
@@ -3487,7 +3487,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
                    ArrayRef<int64_t> strides, Value nextSize,
                    Value runningStride, unsigned idx) const {
     assert(idx < strides.size());
-    if (strides[idx] != MemRefType::getDynamicStrideOrOffset())
+    if (!MemRefType::isDynamicStrideOrOffset(strides[idx]))
       return createIndexConstant(rewriter, loc, strides[idx]);
     if (nextSize)
       return runningStride
```
```diff
@@ -3078,29 +3078,32 @@ enum SubViewVerificationResult {
 /// This function is slight variant of `is subsequence` algorithm where
 /// not matching dimension must be 1.
 static SubViewVerificationResult isRankReducedType(Type originalType,
-                                                   Type reducedType) {
-  if (originalType == reducedType)
+                                                   Type candidateReducedType) {
+  if (originalType == candidateReducedType)
     return SubViewVerificationResult::Success;
   if (!originalType.isa<RankedTensorType>() && !originalType.isa<MemRefType>())
     return SubViewVerificationResult::Success;
   if (originalType.isa<RankedTensorType>() &&
-      !reducedType.isa<RankedTensorType>())
+      !candidateReducedType.isa<RankedTensorType>())
     return SubViewVerificationResult::Success;
-  if (originalType.isa<MemRefType>() && !reducedType.isa<MemRefType>())
+  if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
     return SubViewVerificationResult::Success;
 
   ShapedType originalShapedType = originalType.cast<ShapedType>();
-  ShapedType reducedShapedType = reducedType.cast<ShapedType>();
+  ShapedType candidateReducedShapedType =
+      candidateReducedType.cast<ShapedType>();
 
   // Rank and size logic is valid for all ShapedTypes.
   ArrayRef<int64_t> originalShape = originalShapedType.getShape();
-  ArrayRef<int64_t> reducedShape = reducedShapedType.getShape();
+  ArrayRef<int64_t> candidateReducedShape =
+      candidateReducedShapedType.getShape();
   unsigned originalRank = originalShape.size(),
-           reducedRank = reducedShape.size();
-  if (reducedRank > originalRank)
+           candidateReducedRank = candidateReducedShape.size();
+  if (candidateReducedRank > originalRank)
     return SubViewVerificationResult::RankTooLarge;
 
-  auto optionalMask = computeRankReductionMask(originalShape, reducedShape);
+  auto optionalMask =
+      computeRankReductionMask(originalShape, candidateReducedShape);
 
   // Sizes cannot be matched in case empty vector is returned.
   if (!optionalMask.hasValue())
```
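A worked example of the rank-reduction mask used above (hypothetical, not from the patch): matching `memref<8x4xf32>` against `memref<8x1x4xf32>` keeps dimensions 0 and 2 and drops the size-1 dimension, so only the kept strides participate in the stride comparison:

```mlir
%0 = alloc() : memref<8x1x4xf32>
// The source strides are [4, 4, 1]; projecting away the unit dimension keeps
// [4, 1], the canonical strides for shape 8x4, so the default (empty) layout
// on the result type is accepted.
%1 = subview %0[0, 0, 0][8, 1, 4][1, 1, 1]
  : memref<8x1x4xf32> to memref<8x4xf32>
```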
```diff
@@ -3112,34 +3115,43 @@ static SubViewVerificationResult isRankReducedType(Type originalType,
 
   // Strided layout logic is relevant for MemRefType only.
   MemRefType original = originalType.cast<MemRefType>();
-  MemRefType reduced = reducedType.cast<MemRefType>();
+  MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
   MLIRContext *c = original.getContext();
-  int64_t originalOffset, reducedOffset;
-  SmallVector<int64_t, 4> originalStrides, reducedStrides, keepStrides;
+  int64_t originalOffset, candidateReducedOffset;
+  SmallVector<int64_t, 4> originalStrides, candidateReducedStrides, keepStrides;
   SmallVector<bool, 4> keepMask = optionalMask.getValue();
   getStridesAndOffset(original, originalStrides, originalOffset);
-  getStridesAndOffset(reduced, reducedStrides, reducedOffset);
+  getStridesAndOffset(candidateReduced, candidateReducedStrides,
+                      candidateReducedOffset);
 
   // Filter strides based on the mask and check that they are the same
-  // as reduced ones.
-  unsigned reducedIdx = 0;
+  // as candidateReduced ones.
+  unsigned candidateReducedIdx = 0;
   for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
     if (keepMask[originalIdx]) {
-      if (originalStrides[originalIdx] != reducedStrides[reducedIdx++])
+      if (originalStrides[originalIdx] !=
+          candidateReducedStrides[candidateReducedIdx++])
         return SubViewVerificationResult::StrideMismatch;
       keepStrides.push_back(originalStrides[originalIdx]);
     }
   }
 
-  if (original.getElementType() != reduced.getElementType())
+  if (original.getElementType() != candidateReduced.getElementType())
     return SubViewVerificationResult::ElemTypeMismatch;
 
-  if (original.getMemorySpace() != reduced.getMemorySpace())
+  if (original.getMemorySpace() != candidateReduced.getMemorySpace())
     return SubViewVerificationResult::MemSpaceMismatch;
 
+  // reducedMap is obtained by projecting away the dimensions inferred from
+  // matching the 1's positions in candidateReducedType.
   auto reducedMap = makeStridedLinearLayoutMap(keepStrides, originalOffset, c);
-  if (!reduced.getAffineMaps().empty() &&
-      reducedMap != reduced.getAffineMaps().front())
+
+  MemRefType expectedReducedType = MemRefType::get(
+      candidateReduced.getShape(), candidateReduced.getElementType(),
+      reducedMap, candidateReduced.getMemorySpace());
+  expectedReducedType = canonicalizeStridedLayout(expectedReducedType);
+
+  if (expectedReducedType != canonicalizeStridedLayout(candidateReduced))
     return SubViewVerificationResult::AffineMapMismatch;
 
   return SubViewVerificationResult::Success;
```
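Canonicalizing both sides before the final comparison makes the check robust to equivalent spellings of the same layout. A minimal sketch (the type aliases are illustrative only):

```mlir
// Both types canonicalize to the same layout: the explicit map equals the
// canonical strided form for shape 8x4 (strides [4, 1], offset 0), so it
// folds away to the default layout.
!with_map = type memref<8x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>
!default = type memref<8x4xf32>
```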
```diff
@@ -745,12 +745,20 @@ MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
   if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1)
     return t;
 
+  // Corner-case for 0-D affine maps.
+  auto m = affineMaps[0];
+  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
+    if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
+      if (cst.getValue() == 0)
+        return MemRefType::Builder(t).setAffineMaps({});
+    return t;
+  }
+
   // If the canonical strided layout for the sizes of `t` is equal to the
   // simplified layout of `t` we can just return an empty layout. Otherwise,
   // just simplify the existing layout.
   AffineExpr expr =
       makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
-  auto m = affineMaps[0];
   auto simplifiedLayoutExpr =
       simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
   if (expr != simplifiedLayoutExpr)
```
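The new corner case handles 0-D memrefs, whose layout maps have no dimensions: a constant-0 layout result is equivalent to having no layout map at all. A hedged illustration:

```mlir
// A 0-D memref with the trivial map () -> (0) canonicalizes to the map-free
// form, so the two aliases below denote equal types after canonicalization.
!zero_d_with_map = type memref<f32, affine_map<() -> (0)>>
!zero_d = type memref<f32>
```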
```diff
@@ -1011,6 +1011,16 @@ func @invalid_rank_reducing_subview(%arg0 : index, %arg1 : index, %arg2 : index)
 
 // -----
 
+func @invalid_rank_reducing_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
+  %0 = alloc() : memref<8x16x4xf32>
+  // expected-error@+1 {{expected result type to be 'memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 8)>>' or a rank-reduced version. (mismatch of result sizes)}}
+  %1 = subview %0[0, 2, 0][8, 16, 4][1, 1, 1]
+    : memref<8x16x4xf32> to memref<16x4xf32>
+  return
+}
+
+// -----
+
 func @invalid_rank_reducing_subview(%arg0 : memref<?x?xf32>, %arg1 : index, %arg2 : index) {
   // expected-error@+1 {{expected result type to be 'memref<?x1xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result strides)}}
   %0 = subview %arg0[0, %arg1][%arg2, 1][1, 1] : memref<?x?xf32> to memref<?xf32>
```
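For contrast with the second negative test, a sketch of a well-formed version (not part of the patch): keeping the dynamic dimension preserves the source's dynamic stride and offset, so the rank-reduced result needs an explicit 1-D strided layout rather than the default one implied by `memref<?xf32>`:

```mlir
func @valid_rank_reducing_subview(%arg0 : memref<?x?xf32>, %arg1 : index, %arg2 : index) {
  // The kept dimension carries a dynamic stride (s1) and the subview adds a
  // dynamic offset (s0), both of which must appear in the result layout.
  %0 = subview %arg0[0, %arg1][%arg2, 1][1, 1]
    : memref<?x?xf32>
    to memref<?xf32, affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>>
  return
}
```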