Support non-identity layout maps for reshape ops in the MemRefToLLVM lowering

This change borrows the ideas from `computeExpandedLayoutMap` and
`computeCollapsedLayoutMap`, and computes the dynamic strides at runtime for
the memref descriptors.

Differential Revision: https://reviews.llvm.org/D124001
This commit is contained in:
Yi Zhang 2022-04-18 20:50:30 -04:00
parent e1836123a7
commit e1318078a4
3 changed files with 209 additions and 18 deletions

View File

@ -79,6 +79,9 @@ public:
void setConstantStride(OpBuilder &builder, Location loc, unsigned pos,
uint64_t stride);
/// Returns the type of array element in this descriptor.
Type getIndexType() { return indexType; };
/// Returns the (LLVM) pointer type this descriptor contains.
LLVM::LLVMPointerType getElementPtrType();

View File

@ -1301,7 +1301,7 @@ static OpFoldResult getCollapsedOutputDimSize(
/// Compute the shape of the collapsed (lower-rank) result, one OpFoldResult
/// per output dim, delegating per-dim work to getCollapsedOutputDimSize.
/// NOTE: the rendered diff had flattened old/new lines (both `reassocation`
/// and `reassociation` parameter/argument lines present) and a stray hunk
/// header inside the body; this is the deduplicated post-change function.
static SmallVector<OpFoldResult, 4>
getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                        ArrayRef<ReassociationIndices> reassociation,
                        ArrayRef<int64_t> inStaticShape,
                        MemRefDescriptor &inDesc,
                        ArrayRef<int64_t> outStaticShape) {
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                         outStaticShape[outDimIndex],
                                         inStaticShape, inDesc, reassociation);
      }));
}
/// Compute the shape of the expanded (higher-rank) result, one OpFoldResult
/// per output dim. Each expanded dim is first mapped back to the collapsed
/// source dim that owns it, then sized via getExpandedOutputDimSize.
/// NOTE: the rendered diff had flattened old/new lines (duplicate
/// `reassocation`/`reassociation` lines); this is the deduplicated
/// post-change function.
static SmallVector<OpFoldResult, 4>
getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                       ArrayRef<ReassociationIndices> reassociation,
                       ArrayRef<int64_t> inStaticShape,
                       MemRefDescriptor &inDesc,
                       ArrayRef<int64_t> outStaticShape) {
  // Map each expanded (output) dim index to its collapsed (input) dim index.
  DenseMap<int64_t, int64_t> outDimToInDimMap =
      getExpandedDimToCollapsedDimMap(reassociation);
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                        outStaticShape, inDesc, inStaticShape,
                                        reassociation, outDimToInDimMap);
      }));
}
/// Materialize the dynamic output shape as SSA values. A smaller output rank
/// than input rank means this is a collapse; otherwise an expansion.
/// NOTE: the rendered diff had flattened old/new lines (duplicate
/// `reassocation`/`reassociation` argument lines); this is the deduplicated
/// post-change function.
static SmallVector<Value>
getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                      ArrayRef<ReassociationIndices> reassociation,
                      ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
                      ArrayRef<int64_t> outStaticShape) {
  return outStaticShape.size() < inStaticShape.size()
             ? getAsValues(b, loc, llvmIndexType,
                           getCollapsedOutputShape(b, loc, llvmIndexType,
                                                   reassociation, inStaticShape,
                                                   inDesc, outStaticShape))
             : getAsValues(b, loc, llvmIndexType,
                           getExpandedOutputShape(b, loc, llvmIndexType,
                                                  reassociation, inStaticShape,
                                                  inDesc, outStaticShape));
}
/// Fill in the strides of `dstDesc` for an expanding reshape: within each
/// reassociation group the innermost expanded dim inherits the source dim's
/// stride, and each dim further out gets the inner stride multiplied by the
/// inner dim's size.
static void fillInStridesForExpandedMemDescriptor(
    OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc,
    MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
  // See comments for computeExpandedLayoutMap for details on how the strides
  // are calculated.
  for (int64_t srcIndex = 0, e = reassociation.size(); srcIndex < e;
       ++srcIndex) {
    // Running stride, seeded with the source dim's stride and grown as we
    // walk the group from innermost to outermost.
    Value runningStride = srcDesc.stride(b, loc, srcIndex);
    const ReassociationIndices &group = reassociation[srcIndex];
    for (auto it = group.rbegin(), end = group.rend(); it != end; ++it) {
      int64_t dstIndex = *it;
      dstDesc.setStride(b, loc, dstIndex, runningStride);
      Value dimSize = dstDesc.size(b, loc, dstIndex);
      runningStride = b.create<LLVM::MulOp>(loc, dimSize, runningStride);
    }
  }
}
/// Fill in the strides of `dstDesc` for a collapsing reshape. For each
/// reassociation group, when the innermost non-unit source dim is static the
/// collapsed stride is taken directly; otherwise blocks and branches are
/// emitted so the stride of the innermost dim whose size != 1 is selected at
/// runtime.
static void fillInStridesForCollapsedMemDescriptor(
    ConversionPatternRewriter &rewriter, Location loc, Operation *op,
    TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc,
    MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
  // See comments for computeCollapsedLayoutMap for details on how the strides
  // are calculated.
  auto srcShape = srcType.getShape();
  for (auto &en : llvm::enumerate(reassociation)) {
    // The CFG surgery below moves the insertion point; restart each group
    // from just before `op`.
    rewriter.setInsertionPoint(op);
    auto dstIndex = en.index();
    ArrayRef<int64_t> ref = llvm::makeArrayRef(en.value());
    // Drop trailing statically-known unit dims: they cannot determine the
    // collapsed stride (but keep at least one dim in the group).
    while (srcShape[ref.back()] == 1 && ref.size() > 1)
      ref = ref.drop_back();
    if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
      // Fast path: the innermost remaining dim is static, or the group
      // reduced to a single dim -- use its stride directly.
      dstDesc.setStride(rewriter, loc, dstIndex,
                        srcDesc.stride(rewriter, loc, ref.back()));
    } else {
      // Iterate over the source strides in reverse order. Skip over the
      // dimensions whose size is 1.
      // TODO: we should take the minimum stride in the reassociation group
      // instead of just the first where the dimension is not 1.
      //
      // +------------------------------------------------------+
      // | curEntry:                                            |
      // |   %srcStride = strides[srcIndex]                     |
      // |   %neOne = cmp sizes[srcIndex], 1                    +--+
      // |   cf.cond_br %neOne, continue(%srcStride), nextEntry |  |
      // +-------------------------+----------------------------+  |
      //                           |                               |
      //                           v                               |
      //            +-----------------------------+               |
      //            | nextEntry:                  |               |
      //            |   ...                       +---+           |
      //            +--------------+--------------+   |           |
      //                           |                  |           |
      //                           v                  |           |
      //            +-----------------------------+   |           |
      //            | nextEntry:                  |   |           |
      //            |   ...                       |   |  +--------+
      //            +--------------+--------------+   |  |
      //                           |                  |  |
      //                           v                  v  v
      //      +--------------------------------------------------+
      //      | continue(%newStride):                            |
      //      |   %newMemRefDes = setStride(%newStride,dstIndex) |
      //      +--------------------------------------------------+
      OpBuilder::InsertionGuard guard(rewriter);
      Block *initBlock = rewriter.getInsertionBlock();
      // Split off everything after the insertion point; the tail becomes the
      // `continue` block, which receives the chosen stride as a block arg.
      Block *continueBlock =
          rewriter.splitBlock(initBlock, rewriter.getInsertionPoint());
      continueBlock->insertArgument(unsigned(0), srcDesc.getIndexType(), loc);
      rewriter.setInsertionPointToStart(continueBlock);
      dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0));
      Block *curEntryBlock = initBlock;
      // Assigned before first use: every non-front iteration creates the
      // next entry block below.
      Block *nextEntryBlock;
      for (auto srcIndex : llvm::reverse(ref)) {
        // Statically-known unit dims never carry the stride -- skip them
        // (except the outermost dim, which is the unconditional fallback).
        if (srcShape[srcIndex] == 1 && srcIndex != ref.front())
          continue;
        rewriter.setInsertionPointToEnd(curEntryBlock);
        Value srcStride = srcDesc.stride(rewriter, loc, srcIndex);
        if (srcIndex == ref.front()) {
          // Last candidate: branch unconditionally with its stride.
          rewriter.create<LLVM::BrOp>(loc, srcStride, continueBlock);
          break;
        }
        // NOTE(review): an i32-typed attr on an index-typed constant; the
        // lowering emits `llvm.mlir.constant(1 : i32) : i64` -- presumably
        // intentional, but getIndexAttr would be the usual spelling; confirm.
        Value one = rewriter.create<LLVM::ConstantOp>(
            loc, typeConverter->convertType(rewriter.getI64Type()),
            rewriter.getI32IntegerAttr(1));
        Value predNeOne = rewriter.create<LLVM::ICmpOp>(
            loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex),
            one);
        {
          // Create the next entry block just before `continueBlock` without
          // disturbing the outer insertion point.
          OpBuilder::InsertionGuard guard(rewriter);
          nextEntryBlock = rewriter.createBlock(
              initBlock->getParent(), Region::iterator(continueBlock), {});
        }
        // size != 1 -> this dim's stride wins; otherwise keep probing.
        rewriter.create<LLVM::CondBrOp>(loc, predNeOne, continueBlock,
                                        srcStride, nextEntryBlock, llvm::None);
        curEntryBlock = nextEntryBlock;
      }
    }
  }
}
/// Dispatch dynamic-stride computation for a reshape: a rank-reducing reshape
/// is handled as a collapse, anything else as an expansion.
static void fillInDynamicStridesForMemDescriptor(
    ConversionPatternRewriter &b, Location loc, Operation *op,
    TypeConverter *typeConverter, MemRefType srcType, MemRefType dstType,
    MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc,
    ArrayRef<ReassociationIndices> reassociation) {
  const bool isCollapse = srcType.getRank() > dstType.getRank();
  if (!isCollapse) {
    fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc,
                                          reassociation);
    return;
  }
  fillInStridesForCollapsedMemDescriptor(b, loc, op, typeConverter, srcType,
                                         srcDesc, dstDesc, reassociation);
}
// ReshapeOp creates a new view descriptor of the proper rank.
// For now, the only conversion supported is for target MemRef with static sizes
// and strides.
@ -1361,15 +1474,6 @@ public:
MemRefType dstType = reshapeOp.getResultType();
MemRefType srcType = reshapeOp.getSrcType();
// The condition on the layouts can be ignored when all shapes are static.
if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) {
if (!srcType.getLayout().isIdentity() ||
!dstType.getLayout().isIdentity()) {
return rewriter.notifyMatchFailure(
reshapeOp, "only empty layout map is supported");
}
}
int64_t offset;
SmallVector<int64_t, 4> strides;
if (failed(getStridesAndOffset(dstType, strides, offset))) {
@ -1401,7 +1505,8 @@ public:
if (llvm::all_of(strides, isStaticStride)) {
for (auto &en : llvm::enumerate(strides))
dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
} else {
} else if (srcType.getLayout().isIdentity() &&
dstType.getLayout().isIdentity()) {
Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
rewriter.getIndexAttr(1));
Value stride = c1;
@ -1410,6 +1515,12 @@ public:
dstDesc.setStride(rewriter, loc, dimIndex, stride);
stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
}
} else {
// There could be mixed static/dynamic strides. For simplicity, we
// recompute all strides if there is at least one dynamic stride.
fillInDynamicStridesForMemDescriptor(
rewriter, loc, reshapeOp, this->typeConverter, srcType, dstType,
srcDesc, dstDesc, reshapeOp.getReassociationIndices());
}
rewriter.replaceOp(reshapeOp, {dstDesc});
return success();

View File

@ -706,6 +706,45 @@ func.func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf
// -----
// Collapse dims [1, 2] of a rank-3 memref carrying a non-identity (strided)
// layout. The collapsed group contains dynamic dims, so the result stride is
// selected at runtime via the cond_br chain checked below.
func.func @collapse_shape_dynamic_with_non_identity_layout(
%arg0 : memref<4x?x?xf32, affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s1 + s0 + d1 * 4 + d2)>>) ->
memref<4x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>> {
%0 = memref.collapse_shape %arg0 [[0], [1, 2]]:
memref<4x?x?xf32, affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s1 + s0 + d1 * 4 + d2)>> into
memref<4x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>
return %0 : memref<4x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>
}
// CHECK-LABEL: func @collapse_shape_dynamic_with_non_identity_layout(
// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.mlir.constant(1 : index) : i64
// CHECK: llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64
// CHECK: llvm.extractvalue %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64
// CHECK: llvm.mlir.constant(4 : index) : i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.extractvalue %{{.*}}[4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.mlir.constant(1 : i32) : i64
// CHECK: llvm.extractvalue %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.icmp "ne" %{{.*}}, %{{.*}} : i64
// CHECK: llvm.cond_br %{{.*}}, ^bb2(%{{.*}} : i64), ^bb1
// CHECK: ^bb1:
// CHECK: llvm.extractvalue %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.br ^bb2(%{{.*}} : i64)
// CHECK: ^bb2(%[[STRIDE:.*]]: i64):
// CHECK: llvm.insertvalue %[[STRIDE]], %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// -----
func.func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
// Reshapes that expand a contiguous tensor with some 1's.
%0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]]
@ -840,6 +879,44 @@ func.func @expand_shape_dynamic(%arg0 : memref<1x?xf32>) -> memref<1x2x?xf32> {
// -----
// Expand dim 1 of a rank-2 memref carrying a non-identity (strided) layout
// into dims [1, 2]. Result strides are computed at runtime from the source
// strides (innermost inherits, outer = inner stride * inner size).
func.func @expand_shape_dynamic_with_non_identity_layout(
%arg0 : memref<1x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>) ->
memref<1x2x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>> {
%0 = memref.expand_shape %arg0 [[0], [1, 2]]:
memref<1x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>> into
memref<1x2x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>
return %0 : memref<1x2x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>
}
// CHECK-LABEL: func @expand_shape_dynamic_with_non_identity_layout(
// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.mlir.constant(2 : index) : i64
// CHECK: llvm.sdiv %{{.*}}, %{{.*}} : i64
// CHECK: llvm.mlir.constant(1 : index) : i64
// CHECK: llvm.mlir.constant(2 : index) : i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.extractvalue %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64
// CHECK: llvm.extractvalue %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.extractvalue %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64
// -----
// CHECK-LABEL: func @rank_of_unranked
// CHECK32-LABEL: func @rank_of_unranked
func.func @rank_of_unranked(%unranked: memref<*xi32>) {