diff --git a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h index 28d4a0926007..ac54ee688813 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h @@ -79,6 +79,9 @@ public: void setConstantStride(OpBuilder &builder, Location loc, unsigned pos, uint64_t stride); + /// Returns the type of array element in this descriptor. + Type getIndexType() { return indexType; }; + /// Returns the (LLVM) pointer type this descriptor contains. LLVM::LLVMPointerType getElementPtrType(); diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index b0c881077efb..110b40adf777 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -1301,7 +1301,7 @@ static OpFoldResult getCollapsedOutputDimSize( static SmallVector getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType, - ArrayRef reassocation, + ArrayRef reassociation, ArrayRef inStaticShape, MemRefDescriptor &inDesc, ArrayRef outStaticShape) { @@ -1309,42 +1309,155 @@ getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType, llvm::seq(0, outStaticShape.size()), [&](int64_t outDimIndex) { return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex, outStaticShape[outDimIndex], - inStaticShape, inDesc, reassocation); + inStaticShape, inDesc, reassociation); })); } static SmallVector getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType, - ArrayRef reassocation, + ArrayRef reassociation, ArrayRef inStaticShape, MemRefDescriptor &inDesc, ArrayRef outStaticShape) { DenseMap outDimToInDimMap = - getExpandedDimToCollapsedDimMap(reassocation); + getExpandedDimToCollapsedDimMap(reassociation); return llvm::to_vector<4>(llvm::map_range( llvm::seq(0, outStaticShape.size()), [&](int64_t outDimIndex) { return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex, outStaticShape, inDesc, inStaticShape, - reassocation, outDimToInDimMap); + reassociation, outDimToInDimMap); })); } static SmallVector getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType, - ArrayRef reassocation, + ArrayRef reassociation, ArrayRef inStaticShape, MemRefDescriptor &inDesc, ArrayRef outStaticShape) { return outStaticShape.size() < inStaticShape.size() ? getAsValues(b, loc, llvmIndexType, getCollapsedOutputShape(b, loc, llvmIndexType, - reassocation, inStaticShape, + reassociation, inStaticShape, inDesc, outStaticShape)) : getAsValues(b, loc, llvmIndexType, getExpandedOutputShape(b, loc, llvmIndexType, - reassocation, inStaticShape, + reassociation, inStaticShape, inDesc, outStaticShape)); } +static void fillInStridesForExpandedMemDescriptor( + OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc, + MemRefDescriptor &dstDesc, ArrayRef reassociation) { + // See comments for computeExpandedLayoutMap for details on how the strides + // are calculated. + for (auto &en : llvm::enumerate(reassociation)) { + auto currentStrideToExpand = srcDesc.stride(b, loc, en.index()); + for (auto dstIndex : llvm::reverse(en.value())) { + dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand); + Value size = dstDesc.size(b, loc, dstIndex); + currentStrideToExpand = + b.create(loc, size, currentStrideToExpand); + } + } +} + +static void fillInStridesForCollapsedMemDescriptor( + ConversionPatternRewriter &rewriter, Location loc, Operation *op, + TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc, + MemRefDescriptor &dstDesc, ArrayRef reassociation) { + // See comments for computeCollapsedLayoutMap for details on how the strides + // are calculated. + auto srcShape = srcType.getShape(); + for (auto &en : llvm::enumerate(reassociation)) { + rewriter.setInsertionPoint(op); + auto dstIndex = en.index(); + ArrayRef ref = llvm::makeArrayRef(en.value()); + while (srcShape[ref.back()] == 1 && ref.size() > 1) + ref = ref.drop_back(); + if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) { + dstDesc.setStride(rewriter, loc, dstIndex, + srcDesc.stride(rewriter, loc, ref.back())); + } else { + // Iterate over the source strides in reverse order. Skip over the + // dimensions whose size is 1. + // TODO: we should take the minimum stride in the reassociation group + // instead of just the first where the dimension is not 1. + // + // +------------------------------------------------------+ + // | curEntry: | + // | %srcStride = strides[srcIndex] | + // | %neOne = cmp sizes[srcIndex],1 +--+ + // | cf.cond_br %neOne, continue(%srcStride), nextEntry | | + // +-------------------------+----------------------------+ | + // | | + // v | + // +-----------------------------+ | + // | nextEntry: | | + // | ... +---+ | + // +--------------+--------------+ | | + // | | | + // v | | + // +-----------------------------+ | | + // | nextEntry: | | | + // | ... | | | + // +--------------+--------------+ | +--------+ + // | | | + // v v v + // +--------------------------------------------------+ + // | continue(%newStride): | + // | %newMemRefDes = setStride(%newStride,dstIndex) | + // +--------------------------------------------------+ + OpBuilder::InsertionGuard guard(rewriter); + Block *initBlock = rewriter.getInsertionBlock(); + Block *continueBlock = + rewriter.splitBlock(initBlock, rewriter.getInsertionPoint()); + continueBlock->insertArgument(unsigned(0), srcDesc.getIndexType(), loc); + rewriter.setInsertionPointToStart(continueBlock); + dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0)); + + Block *curEntryBlock = initBlock; + Block *nextEntryBlock; + for (auto srcIndex : llvm::reverse(ref)) { + if (srcShape[srcIndex] == 1 && srcIndex != ref.front()) + continue; + rewriter.setInsertionPointToEnd(curEntryBlock); + Value srcStride = srcDesc.stride(rewriter, loc, srcIndex); + if (srcIndex == ref.front()) { + rewriter.create(loc, srcStride, continueBlock); + break; + } + Value one = rewriter.create( + loc, typeConverter->convertType(rewriter.getI64Type()), + rewriter.getI32IntegerAttr(1)); + Value predNeOne = rewriter.create( + loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex), + one); + { + OpBuilder::InsertionGuard guard(rewriter); + nextEntryBlock = rewriter.createBlock( + initBlock->getParent(), Region::iterator(continueBlock), {}); + } + rewriter.create(loc, predNeOne, continueBlock, + srcStride, nextEntryBlock, llvm::None); + curEntryBlock = nextEntryBlock; + } + } + } +} + +static void fillInDynamicStridesForMemDescriptor( + ConversionPatternRewriter &b, Location loc, Operation *op, + TypeConverter *typeConverter, MemRefType srcType, MemRefType dstType, + MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc, + ArrayRef reassociation) { + if (srcType.getRank() > dstType.getRank()) + fillInStridesForCollapsedMemDescriptor(b, loc, op, typeConverter, srcType, + srcDesc, dstDesc, reassociation); + else + fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc, + reassociation); +} + // ReshapeOp creates a new view descriptor of the proper rank. // For now, the only conversion supported is for target MemRef with static sizes // and strides. @@ -1361,15 +1474,6 @@ public: MemRefType dstType = reshapeOp.getResultType(); MemRefType srcType = reshapeOp.getSrcType(); - // The condition on the layouts can be ignored when all shapes are static. - if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) { - if (!srcType.getLayout().isIdentity() || - !dstType.getLayout().isIdentity()) { - return rewriter.notifyMatchFailure( - reshapeOp, "only empty layout map is supported"); - } - } - int64_t offset; SmallVector strides; if (failed(getStridesAndOffset(dstType, strides, offset))) { @@ -1401,7 +1505,8 @@ public: if (llvm::all_of(strides, isStaticStride)) { for (auto &en : llvm::enumerate(strides)) dstDesc.setConstantStride(rewriter, loc, en.index(), en.value()); - } else { + } else if (srcType.getLayout().isIdentity() && + dstType.getLayout().isIdentity()) { Value c1 = rewriter.create(loc, llvmIndexType, rewriter.getIndexAttr(1)); Value stride = c1; @@ -1410,6 +1515,12 @@ public: dstDesc.setStride(rewriter, loc, dimIndex, stride); stride = rewriter.create(loc, dstShape[dimIndex], stride); } + } else { + // There could be mixed static/dynamic strides. For simplicity, we + // recompute all strides if there is at least one dynamic stride. + fillInDynamicStridesForMemDescriptor( + rewriter, loc, reshapeOp, this->typeConverter, srcType, dstType, + srcDesc, dstDesc, reshapeOp.getReassociationIndices()); } rewriter.replaceOp(reshapeOp, {dstDesc}); return success(); diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir index 9b03c18a5880..8b215769e075 100644 --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -706,6 +706,45 @@ func.func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf // ----- +func.func @collapse_shape_dynamic_with_non_identity_layout( + %arg0 : memref<4x?x?xf32, affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s1 + s0 + d1 * 4 + d2)>>) -> + memref<4x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>> { + %0 = memref.collapse_shape %arg0 [[0], [1, 2]]: + memref<4x?x?xf32, affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s1 + s0 + d1 * 4 + d2)>> into + memref<4x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>> + return %0 : memref<4x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>> +} +// CHECK-LABEL: func @collapse_shape_dynamic_with_non_identity_layout( +// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.mlir.constant(1 : index) : i64 +// CHECK: llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 +// CHECK: llvm.extractvalue %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 +// CHECK: llvm.mlir.constant(4 : index) : i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.extractvalue %{{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.mlir.constant(1 : i32) : i64 +// CHECK: llvm.extractvalue %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.icmp "ne" %{{.*}}, %{{.*}} : i64 +// CHECK: llvm.cond_br %{{.*}}, ^bb2(%{{.*}} : i64), ^bb1 +// CHECK: ^bb1: +// CHECK: llvm.extractvalue %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.br ^bb2(%{{.*}} : i64) +// CHECK: ^bb2(%[[STRIDE:.*]]: i64): +// CHECK: llvm.insertvalue %[[STRIDE]], %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + +// ----- + func.func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> { // Reshapes that expand a contiguous tensor with some 1's. %0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] @@ -840,6 +879,44 @@ func.func @expand_shape_dynamic(%arg0 : memref<1x?xf32>) -> memref<1x2x?xf32> { // ----- +func.func @expand_shape_dynamic_with_non_identity_layout( + %arg0 : memref<1x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>) -> + memref<1x2x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>> { + %0 = memref.expand_shape %arg0 [[0], [1, 2]]: + memref<1x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>> into + memref<1x2x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>> + return %0 : memref<1x2x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>> +} +// CHECK-LABEL: func @expand_shape_dynamic_with_non_identity_layout( +// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.mlir.constant(2 : index) : i64 +// CHECK: llvm.sdiv %{{.*}}, %{{.*}} : i64 +// CHECK: llvm.mlir.constant(1 : index) : i64 +// CHECK: llvm.mlir.constant(2 : index) : i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.extractvalue %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 +// CHECK: llvm.extractvalue %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.extractvalue %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 + +// ----- + // CHECK-LABEL: func @rank_of_unranked // CHECK32-LABEL: func @rank_of_unranked func.func @rank_of_unranked(%unranked: memref<*xi32>) {