[mlir] Tighten computation of inferred SubView result type.

The AffineMap in the MemRef type inferred by SubViewOp may have uncompressed symbols, which results in type mismatches on otherwise unused symbols. Make the computation of the AffineMap compress those unused symbols, which results in better canonical types.
Additionally, improve the error message to report which inferred type was expected.

Differential Revision: https://reviews.llvm.org/D96551
Nicolas Vasilache 2021-02-11 22:26:49 +00:00
parent 9e62c9146d
commit 5bc4f8846c
12 changed files with 204 additions and 93 deletions


@@ -305,14 +305,14 @@ public:
};
/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
/// `originalShape` with some `1` entries erased, return the vector of booleans
/// that specifies which of the entries of `originalShape` are kept to obtain
/// `originalShape` with some `1` entries erased, return the set of indices
/// that specifies which of the entries of `originalShape` are dropped to obtain
/// `reducedShape`. The returned mask can be applied as a projection to
/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
/// which dimensions must be kept when e.g. computing MemRef strides under
/// rank-reducing operations. Return None if reducedShape cannot be obtained
/// by dropping only `1` entries in `originalShape`.
llvm::Optional<SmallVector<bool, 4>>
llvm::Optional<llvm::SmallDenseSet<unsigned>>
computeRankReductionMask(ArrayRef<int64_t> originalShape,
ArrayRef<int64_t> reducedShape);
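
For illustration only (not part of this patch), a minimal C++ sketch of the intended behavior, assuming the usual MLIR headers; the shapes are hypothetical:

  // Dropping the single size-1 entry of {4, 1, 2} yields {4, 2}; the returned
  // set holds the dropped index 1.
  llvm::Optional<llvm::SmallDenseSet<unsigned>> mask =
      mlir::computeRankReductionMask({4, 1, 2}, {4, 2});
  assert(mask.hasValue() && mask->contains(1));
  // {4, 3} cannot be reached by erasing only `1` entries, so this is None.
  assert(!mlir::computeRankReductionMask({4, 1, 2}, {4, 3}).hasValue());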


@@ -127,6 +127,12 @@ public:
AffineExpr replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
ArrayRef<AffineExpr> symReplacements) const;
/// Dim-only version of replaceDimsAndSymbols.
AffineExpr replaceDims(ArrayRef<AffineExpr> dimReplacements) const;
/// Symbol-only version of replaceDimsAndSymbols.
AffineExpr replaceSymbols(ArrayRef<AffineExpr> symReplacements) const;
/// Sparse replace method. Replace `expr` by `replacement` and return the
/// modified expression tree.
AffineExpr replace(AffineExpr expr, AffineExpr replacement) const;
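
A hedged usage sketch of the new dim-only variant (assuming `using namespace mlir;` and an MLIRContext `ctx`); the expression is hypothetical:

  AffineExpr d0 = getAffineDimExpr(0, &ctx);
  AffineExpr expr = d0 * 2 + 1;
  // Substitute d0 + 3 for d0 throughout the expression tree.
  AffineExpr shifted = expr.replaceDims({d0 + 3});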


@@ -18,6 +18,7 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/DenseSet.h"
namespace mlir {
@@ -311,6 +312,20 @@ private:
/// Simplifies an affine map by simplifying its underlying AffineExpr results.
AffineMap simplifyAffineMap(AffineMap map);
/// Drop the dims that are not used.
AffineMap compressUnusedDims(AffineMap map);
/// Drop the dims that are listed in `unusedDims`.
AffineMap compressDims(AffineMap map,
const llvm::SmallDenseSet<unsigned> &unusedDims);
/// Drop the symbols that are not used.
AffineMap compressUnusedSymbols(AffineMap map);
/// Drop the symbols that are listed in `unusedSymbols`.
AffineMap compressSymbols(AffineMap map,
const llvm::SmallDenseSet<unsigned> &unusedSymbols);
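
These compress helpers are the crux of the patch: a layout such as (d0)[s0, s1, s2] -> (d0 * s1 + s0) carries a dead symbol s2 that makes otherwise-identical types compare unequal. A hedged sketch with a hypothetical map (same `ctx` assumption as above):

  AffineExpr d0 = getAffineDimExpr(0, &ctx);
  AffineExpr s0 = getAffineSymbolExpr(0, &ctx);
  AffineExpr s1 = getAffineSymbolExpr(1, &ctx);
  // (d0)[s0, s1, s2] -> (d0 * s1 + s0): s2 never occurs in any result.
  AffineMap map = AffineMap::get(1, 3, {d0 * s1 + s0}, &ctx);
  // Renumbered to the canonical (d0)[s0, s1] -> (d0 * s1 + s0).
  AffineMap compressed = compressUnusedSymbols(map);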
/// Returns a map with the same dimension and symbol count as `map`, but whose
/// results are the unique affine expressions of `map`.
AffineMap removeDuplicateExprs(AffineMap map);
@@ -390,8 +405,11 @@ AffineMap concatAffineMaps(ArrayRef<AffineMap> maps);
/// 3) map : affine_map<(d0, d1, d2) -> (d0, d1)>
/// projected_dimensions : {1}
/// result : affine_map<(d0, d1) -> (d0, 0)>
AffineMap getProjectedMap(AffineMap map,
ArrayRef<unsigned> projectedDimensions);
///
/// This function also compresses unused symbols away.
AffineMap
getProjectedMap(AffineMap map,
const llvm::SmallDenseSet<unsigned> &projectedDimensions);
inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
map.print(os);
@@ -402,7 +420,8 @@ inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
namespace llvm {
// AffineMap hash just like pointers
template <> struct DenseMapInfo<mlir::AffineMap> {
template <>
struct DenseMapInfo<mlir::AffineMap> {
static mlir::AffineMap getEmptyKey() {
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::AffineMap(static_cast<mlir::AffineMap::ImplType *>(pointer));


@@ -566,6 +566,10 @@ AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
/// Return true if the layout for `t` is compatible with strided semantics.
bool isStrided(MemRefType t);
/// Return the layout map in strided linear layout AffineMap form.
/// Return null if the layout is not compatible with a strided layout.
AffineMap getStridedLinearLayoutMap(MemRefType t);
} // end namespace mlir
#endif // MLIR_IR_BUILTINTYPES_H
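
A short usage sketch for the new helper (hypothetical type, same assumptions as the sketches above); the null return composes naturally with failure checks:

  // memref<4x8xf32> has canonical strides [8, 1] and offset 0, i.e. the
  // strided linear layout (d0, d1) -> (d0 * 8 + d1).
  MemRefType t = MemRefType::get({4, 8}, FloatType::getF32(&ctx));
  AffineMap layout = getStridedLinearLayoutMap(t);
  assert(layout && "expected a strided layout");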


@@ -3277,7 +3277,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
auto inferredShape = inferredType.getShape();
size_t inferredShapeRank = inferredShape.size();
size_t resultShapeRank = shape.size();
SmallVector<bool, 4> mask =
llvm::SmallDenseSet<unsigned> unusedDims =
computeRankReductionMask(inferredShape, shape).getValue();
// Extract strides needed to compute offset.
@@ -3318,7 +3318,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
"expected sizes and strides of equal length");
for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
i >= 0 && j >= 0; --i) {
if (!mask[i])
if (unusedDims.contains(i))
continue;
// `i` may overflow subViewOp.getMixedSizes because of trailing semantics.


@@ -536,10 +536,10 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
/// Prune all dimensions that are of reduction iterator type from `map`.
static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
AffineMap map) {
SmallVector<unsigned, 2> projectedDims;
llvm::SmallDenseSet<unsigned> projectedDims;
for (auto attr : llvm::enumerate(iteratorTypes)) {
if (!isParallelIterator(attr.value()))
projectedDims.push_back(attr.index());
projectedDims.insert(attr.index());
}
return getProjectedMap(map, projectedDims);
}


@@ -2957,35 +2957,44 @@ void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
/// For ViewLikeOpInterface.
Value SubViewOp::getViewSource() { return source(); }
llvm::Optional<SmallVector<bool, 4>>
/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
/// `originalShape` with some `1` entries erased, return the set of indices
/// that specifies which of the entries of `originalShape` are dropped to obtain
/// `reducedShape`. The returned mask can be applied as a projection to
/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
/// which dimensions must be kept when e.g. computing MemRef strides under
/// rank-reducing operations. Return None if reducedShape cannot be obtained
/// by dropping only `1` entries in `originalShape`.
llvm::Optional<llvm::SmallDenseSet<unsigned>>
mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
ArrayRef<int64_t> reducedShape) {
size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
SmallVector<bool, 4> mask(originalRank);
llvm::SmallDenseSet<unsigned> unusedDims;
unsigned reducedIdx = 0;
for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
// Skip matching dims greedily.
mask[originalIdx] =
(reducedIdx < reducedRank) &&
(originalShape[originalIdx] == reducedShape[reducedIdx]);
if (mask[originalIdx])
// Greedily insert `originalIdx` if no match.
if (reducedIdx < reducedRank &&
originalShape[originalIdx] == reducedShape[reducedIdx]) {
reducedIdx++;
// 1 is the only non-matching value allowed.
else if (originalShape[originalIdx] != 1)
return {};
continue;
}
unusedDims.insert(originalIdx);
// If no match on `originalIdx`, the `originalShape` at this dimension
// must be 1, otherwise we bail.
if (originalShape[originalIdx] != 1)
return llvm::None;
}
// The whole reducedShape must be scanned, otherwise we bail.
if (reducedIdx != reducedRank)
return {};
return mask;
return llvm::None;
return unusedDims;
}
enum SubViewVerificationResult {
Success,
RankTooLarge,
SizeMismatch,
StrideMismatch,
ElemTypeMismatch,
MemSpaceMismatch,
AffineMapMismatch
@@ -2994,8 +3003,9 @@ enum SubViewVerificationResult {
/// Checks if the `original` type can be rank-reduced to the `reduced` type.
/// This function is a slight variant of the `is subsequence` algorithm where
/// a non-matching dimension must be 1.
static SubViewVerificationResult isRankReducedType(Type originalType,
Type candidateReducedType) {
static SubViewVerificationResult
isRankReducedType(Type originalType, Type candidateReducedType,
std::string *errMsg = nullptr) {
if (originalType == candidateReducedType)
return SubViewVerificationResult::Success;
if (!originalType.isa<RankedTensorType>() && !originalType.isa<MemRefType>())
@@ -3019,13 +3029,17 @@ static SubViewVerificationResult isRankReducedType(Type originalType,
if (candidateReducedRank > originalRank)
return SubViewVerificationResult::RankTooLarge;
auto optionalMask =
auto optionalUnusedDimsMask =
computeRankReductionMask(originalShape, candidateReducedShape);
// Sizes cannot be matched if no rank-reduction mask could be computed.
if (!optionalMask.hasValue())
if (!optionalUnusedDimsMask.hasValue())
return SubViewVerificationResult::SizeMismatch;
if (originalShapedType.getElementType() !=
candidateReducedShapedType.getElementType())
return SubViewVerificationResult::ElemTypeMismatch;
// We are done for the tensor case.
if (originalType.isa<RankedTensorType>())
return SubViewVerificationResult::Success;
@@ -3033,74 +3047,54 @@ static SubViewVerificationResult isRankReducedType(Type originalType,
// Strided layout logic is relevant for MemRefType only.
MemRefType original = originalType.cast<MemRefType>();
MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
MLIRContext *c = original.getContext();
int64_t originalOffset, candidateReducedOffset;
SmallVector<int64_t, 4> originalStrides, candidateReducedStrides, keepStrides;
SmallVector<bool, 4> keepMask = optionalMask.getValue();
(void)getStridesAndOffset(original, originalStrides, originalOffset);
(void)getStridesAndOffset(candidateReduced, candidateReducedStrides,
candidateReducedOffset);
// Filter strides based on the mask and check that they are the same
// as candidateReduced ones.
unsigned candidateReducedIdx = 0;
for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
if (keepMask[originalIdx]) {
if (originalStrides[originalIdx] !=
candidateReducedStrides[candidateReducedIdx++])
return SubViewVerificationResult::StrideMismatch;
keepStrides.push_back(originalStrides[originalIdx]);
}
}
if (original.getElementType() != candidateReduced.getElementType())
return SubViewVerificationResult::ElemTypeMismatch;
if (original.getMemorySpace() != candidateReduced.getMemorySpace())
return SubViewVerificationResult::MemSpaceMismatch;
// reducedMap is obtained by projecting away the dimensions inferred from
// matching the 1's positions in candidateReducedType.
auto reducedMap = makeStridedLinearLayoutMap(keepStrides, originalOffset, c);
MemRefType expectedReducedType = MemRefType::get(
candidateReduced.getShape(), candidateReduced.getElementType(),
reducedMap, candidateReduced.getMemorySpace());
expectedReducedType = canonicalizeStridedLayout(expectedReducedType);
if (expectedReducedType != canonicalizeStridedLayout(candidateReduced))
llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue();
auto inferredType =
getProjectedMap(getStridedLinearLayoutMap(original), unusedDims);
AffineMap candidateLayout;
if (candidateReduced.getAffineMaps().empty())
candidateLayout = getStridedLinearLayoutMap(candidateReduced);
else
candidateLayout = candidateReduced.getAffineMaps().front();
if (inferredType != candidateLayout) {
if (errMsg) {
llvm::raw_string_ostream os(*errMsg);
os << "inferred type: " << inferredType;
}
return SubViewVerificationResult::AffineMapMismatch;
}
return SubViewVerificationResult::Success;
}
template <typename OpTy>
static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
OpTy op, Type expectedType) {
OpTy op, Type expectedType,
StringRef errMsg = "") {
auto memrefType = expectedType.cast<ShapedType>();
switch (result) {
case SubViewVerificationResult::Success:
return success();
case SubViewVerificationResult::RankTooLarge:
return op.emitError("expected result rank to be smaller or equal to ")
<< "the source rank.";
<< "the source rank. " << errMsg;
case SubViewVerificationResult::SizeMismatch:
return op.emitError("expected result type to be ")
<< expectedType
<< " or a rank-reduced version. (mismatch of result sizes)";
case SubViewVerificationResult::StrideMismatch:
return op.emitError("expected result type to be ")
<< expectedType
<< " or a rank-reduced version. (mismatch of result strides)";
<< " or a rank-reduced version. (mismatch of result sizes) "
<< errMsg;
case SubViewVerificationResult::ElemTypeMismatch:
return op.emitError("expected result element type to be ")
<< memrefType.getElementType();
<< memrefType.getElementType() << errMsg;
case SubViewVerificationResult::MemSpaceMismatch:
return op.emitError("expected result and source memory spaces to match.");
return op.emitError("expected result and source memory spaces to match.")
<< errMsg;
case SubViewVerificationResult::AffineMapMismatch:
return op.emitError("expected result type to be ")
<< expectedType
<< " or a rank-reduced version. (mismatch of result affine map)";
<< " or a rank-reduced version. (mismatch of result affine map) "
<< errMsg;
}
llvm_unreachable("unexpected subview verification result");
}
@@ -3126,8 +3120,9 @@ static LogicalResult verify(SubViewOp op) {
extractFromI64ArrayAttr(op.static_sizes()),
extractFromI64ArrayAttr(op.static_strides()));
auto result = isRankReducedType(expectedType, subViewType);
return produceSubViewErrorMsg(result, op, expectedType);
std::string errMsg;
auto result = isRankReducedType(expectedType, subViewType, &errMsg);
return produceSubViewErrorMsg(result, op, expectedType, errMsg);
}
raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {


@@ -92,6 +92,15 @@ AffineExpr::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
llvm_unreachable("Unknown AffineExpr");
}
AffineExpr AffineExpr::replaceDims(ArrayRef<AffineExpr> dimReplacements) const {
return replaceDimsAndSymbols(dimReplacements, {});
}
AffineExpr
AffineExpr::replaceSymbols(ArrayRef<AffineExpr> symReplacements) const {
return replaceDimsAndSymbols({}, symReplacements);
}
/// Replace dims[0 .. numDims - 1] by dims[shift .. shift + numDims - 1].
AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift) const {
SmallVector<AffineExpr, 4> dims;


@@ -420,6 +420,71 @@ AffineMap AffineMap::getMinorSubMap(unsigned numResults) const {
llvm::seq<unsigned>(getNumResults() - numResults, getNumResults())));
}
AffineMap mlir::compressDims(AffineMap map,
const llvm::SmallDenseSet<unsigned> &unusedDims) {
unsigned numDims = 0;
SmallVector<AffineExpr> dimReplacements;
dimReplacements.reserve(map.getNumDims());
MLIRContext *context = map.getContext();
for (unsigned dim = 0, e = map.getNumDims(); dim < e; ++dim) {
if (unusedDims.contains(dim))
dimReplacements.push_back(getAffineConstantExpr(0, context));
else
dimReplacements.push_back(getAffineDimExpr(numDims++, context));
}
SmallVector<AffineExpr> resultExprs;
resultExprs.reserve(map.getNumResults());
for (auto e : map.getResults())
resultExprs.push_back(e.replaceDims(dimReplacements));
return AffineMap::get(numDims, map.getNumSymbols(), resultExprs, context);
}
AffineMap mlir::compressUnusedDims(AffineMap map) {
llvm::SmallDenseSet<unsigned> usedDims;
map.walkExprs([&](AffineExpr expr) {
if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
usedDims.insert(dimExpr.getPosition());
});
llvm::SmallDenseSet<unsigned> unusedDims;
for (unsigned d = 0, e = map.getNumDims(); d != e; ++d)
if (!usedDims.contains(d))
unusedDims.insert(d);
return compressDims(map, unusedDims);
}
AffineMap
mlir::compressSymbols(AffineMap map,
const llvm::SmallDenseSet<unsigned> &unusedSymbols) {
unsigned numSymbols = 0;
SmallVector<AffineExpr> symReplacements;
symReplacements.reserve(map.getNumSymbols());
MLIRContext *context = map.getContext();
for (unsigned sym = 0, e = map.getNumSymbols(); sym < e; ++sym) {
if (unusedSymbols.contains(sym))
symReplacements.push_back(getAffineConstantExpr(0, context));
else
symReplacements.push_back(getAffineSymbolExpr(numSymbols++, context));
}
SmallVector<AffineExpr> resultExprs;
resultExprs.reserve(map.getNumResults());
for (auto e : map.getResults())
resultExprs.push_back(e.replaceSymbols(symReplacements));
return AffineMap::get(map.getNumDims(), numSymbols, resultExprs, context);
}
AffineMap mlir::compressUnusedSymbols(AffineMap map) {
llvm::SmallDenseSet<unsigned> usedSymbols;
map.walkExprs([&](AffineExpr expr) {
if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
usedSymbols.insert(symExpr.getPosition());
});
llvm::SmallDenseSet<unsigned> unusedSymbols;
for (unsigned d = 0, e = map.getNumSymbols(); d != e; ++d)
if (!usedSymbols.contains(d))
unusedSymbols.insert(d);
return compressSymbols(map, unusedSymbols);
}
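
The dim side mirrors the symbol side; a hedged sketch with a hypothetical map of how compressDims zeroes out the dropped dimensions and renumbers the remaining ones:

  AffineExpr d0 = getAffineDimExpr(0, &ctx);
  AffineExpr d2 = getAffineDimExpr(2, &ctx);
  // (d0, d1, d2) -> (d0, d2), with d1 slated for removal.
  AffineMap map = AffineMap::get(3, 0, {d0, d2}, &ctx);
  llvm::SmallDenseSet<unsigned> unusedDims;
  unusedDims.insert(1);
  // d1 -> 0, then d2 renumbers to d1: the result is (d0, d1) -> (d0, d1).
  AffineMap compressed = compressDims(map, unusedDims);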
AffineMap mlir::simplifyAffineMap(AffineMap map) {
SmallVector<AffineExpr, 8> exprs;
for (auto e : map.getResults()) {
@@ -480,20 +545,10 @@ AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
maps.front().getContext());
}
AffineMap mlir::getProjectedMap(AffineMap map,
ArrayRef<unsigned> projectedDimensions) {
DenseSet<unsigned> projectedDims(projectedDimensions.begin(),
projectedDimensions.end());
MLIRContext *context = map.getContext();
SmallVector<AffineExpr, 4> resultExprs;
for (auto dim : enumerate(llvm::seq<unsigned>(0, map.getNumDims()))) {
if (!projectedDims.count(dim.value()))
resultExprs.push_back(getAffineDimExpr(dim.index(), context));
else
resultExprs.push_back(getAffineConstantExpr(0, context));
}
return map.compose(AffineMap::get(
map.getNumDims() - projectedDimensions.size(), 0, resultExprs, context));
AffineMap
mlir::getProjectedMap(AffineMap map,
const llvm::SmallDenseSet<unsigned> &unusedDims) {
return compressUnusedSymbols(compressDims(map, unusedDims));
}
//===----------------------------------------------------------------------===//


@@ -829,7 +829,17 @@ AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
/// Return true if the layout for `t` is compatible with strided semantics.
bool mlir::isStrided(MemRefType t) {
int64_t offset;
SmallVector<int64_t, 4> stridesAndOffset;
auto res = getStridesAndOffset(t, stridesAndOffset, offset);
SmallVector<int64_t, 4> strides;
auto res = getStridesAndOffset(t, strides, offset);
return succeeded(res);
}
/// Return the layout map in strided linear layout AffineMap form.
/// Return null if the layout is not compatible with a strided layout.
AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
int64_t offset;
SmallVector<int64_t, 4> strides;
if (failed(getStridesAndOffset(t, strides, offset)))
return AffineMap();
return makeStridedLinearLayoutMap(strides, offset, t.getContext());
}


@@ -812,6 +812,10 @@ func @memref_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
// CHECK: subview %{{.*}}[%{{.*}}, 1] [1, 1] [1, 1] : memref<5x3xf32> to memref<f32, #[[$SUBVIEW_MAP12]]>
%28 = subview %24[%arg0, 1] [1, 1] [1, 1] : memref<5x3xf32> to memref<f32, affine_map<()[s0] -> (s0)>>
// CHECK: subview %{{.*}}[0, %{{.*}}] [%{{.*}}, 1] [1, 1] : memref<?x?xf32> to memref<?xf32, #[[$SUBVIEW_MAP1]]>
%a30 = alloc(%arg0, %arg0) : memref<?x?xf32>
%30 = subview %a30[0, %arg1][%arg2, 1][1, 1] : memref<?x?xf32> to memref<?xf32, affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>>
return
}


@@ -970,7 +970,7 @@ func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
%0 = alloc() : memref<8x16x4xf32>
// expected-error@+1 {{expected result type to be 'memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>' or a rank-reduced version. (mismatch of result strides)}}
// expected-error@+1 {{expected result type to be 'memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>' or a rank-reduced version. (mismatch of result affine map)}}
%1 = subview %0[%arg0, %arg1, %arg2][%arg0, %arg1, %arg2][%arg0, %arg1, %arg2]
: memref<8x16x4xf32> to
memref<?x?x?xf32, offset: ?, strides: [64, 4, 1]>
@@ -1022,13 +1022,22 @@ func @invalid_rank_reducing_subview(%arg0 : index, %arg1 : index, %arg2 : index)
// -----
func @invalid_rank_reducing_subview(%arg0 : memref<?x?xf32>, %arg1 : index, %arg2 : index) {
// expected-error@+1 {{expected result type to be 'memref<?x1xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result strides)}}
// expected-error@+1 {{expected result type to be 'memref<?x1xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result affine map)}}
%0 = subview %arg0[0, %arg1][%arg2, 1][1, 1] : memref<?x?xf32> to memref<?xf32>
return
}
// -----
// The affine map affine_map<(d0)[s0, s1, s2] -> (d0 * s1 + s0)> has an extra unused symbol.
func @invalid_rank_reducing_subview(%arg0 : memref<?x?xf32>, %arg1 : index, %arg2 : index) {
// expected-error@+1 {{expected result type to be 'memref<?x1xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result affine map) inferred type: (d0)[s0, s1] -> (d0 * s1 + s0)}}
%0 = subview %arg0[0, %arg1][%arg2, 1][1, 1] : memref<?x?xf32> to memref<?xf32, affine_map<(d0)[s0, s1, s2] -> (d0 * s1 + s0)>>
return
}
// -----
func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]>) {
// expected-error@+1{{operand type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>>' and result type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 128 + d1 * 32 + d2 * 2)>>' are cast incompatible}}
%0 = memref_cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:0, strides:[128, 32, 2]>