diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp
index 6b6ba9065373..77a23b156b0f 100644
--- a/mlir/lib/Transforms/LowerAffine.cpp
+++ b/mlir/lib/Transforms/LowerAffine.cpp
@@ -597,11 +597,140 @@ public:
     return matchSuccess();
   }
 };
+
+// Apply the affine map from an 'affine.load' operation to its operands, and
+// feed the results to a newly created 'std.load' operation (which replaces
+// the original 'affine.load').
+class AffineLoadLowering : public ConversionPattern {
+public:
+  AffineLoadLowering(MLIRContext *ctx)
+      : ConversionPattern(AffineLoadOp::getOperationName(), 1, ctx) {}
+
+  virtual PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  PatternRewriter &rewriter) const override {
+    auto affineLoadOp = cast<AffineLoadOp>(op);
+    // Expand affine map from 'affineLoadOp'.
+    auto maybeExpandedMap =
+        expandAffineMap(rewriter, op->getLoc(), affineLoadOp.getAffineMap(),
+                        operands.drop_front());
+    if (!maybeExpandedMap)
+      return matchFailure();
+    // Build std.load memref[expandedMap.results].
+    rewriter.replaceOpWithNewOp<LoadOp>(op, operands[0], *maybeExpandedMap);
+    return matchSuccess();
+  }
+};
+
+// Apply the affine map from an 'affine.store' operation to its operands, and
+// feed the results to a newly created 'std.store' operation (which replaces
+// the original 'affine.store').
+class AffineStoreLowering : public ConversionPattern {
+public:
+  AffineStoreLowering(MLIRContext *ctx)
+      : ConversionPattern(AffineStoreOp::getOperationName(), 1, ctx) {}
+
+  virtual PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  PatternRewriter &rewriter) const override {
+    auto affineStoreOp = cast<AffineStoreOp>(op);
+    // Expand affine map from 'affineStoreOp'.
+    auto maybeExpandedMap =
+        expandAffineMap(rewriter, op->getLoc(), affineStoreOp.getAffineMap(),
+                        operands.drop_front(2));
+    if (!maybeExpandedMap)
+      return matchFailure();
+    // Build std.store valueToStore, memref[expandedMap.results].
+    rewriter.replaceOpWithNewOp<StoreOp>(op, operands[0], operands[1],
+                                         *maybeExpandedMap);
+    return matchSuccess();
+  }
+};
+
+// Apply the affine maps from an 'affine.dma_start' operation to their
+// respective map operands, and feed the results to a newly created
+// 'std.dma_start' operation (which replaces the original 'affine.dma_start').
+class AffineDmaStartLowering : public ConversionPattern {
+public:
+  AffineDmaStartLowering(MLIRContext *ctx)
+      : ConversionPattern(AffineDmaStartOp::getOperationName(), 1, ctx) {}
+
+  virtual PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  PatternRewriter &rewriter) const override {
+    auto affineDmaStartOp = cast<AffineDmaStartOp>(op);
+    // Expand affine map for DMA source memref.
+    auto maybeExpandedSrcMap = expandAffineMap(
+        rewriter, op->getLoc(), affineDmaStartOp.getSrcMap(),
+        operands.drop_front(affineDmaStartOp.getSrcMemRefOperandIndex() + 1));
+    if (!maybeExpandedSrcMap)
+      return matchFailure();
+    // Expand affine map for DMA destination memref.
+    auto maybeExpandedDstMap = expandAffineMap(
+        rewriter, op->getLoc(), affineDmaStartOp.getDstMap(),
+        operands.drop_front(affineDmaStartOp.getDstMemRefOperandIndex() + 1));
+    if (!maybeExpandedDstMap)
+      return matchFailure();
+    // Expand affine map for DMA tag memref.
+    auto maybeExpandedTagMap = expandAffineMap(
+        rewriter, op->getLoc(), affineDmaStartOp.getTagMap(),
+        operands.drop_front(affineDmaStartOp.getTagMemRefOperandIndex() + 1));
+    if (!maybeExpandedTagMap)
+      return matchFailure();
+
+    // Build std.dma_start operation with affine map results.
+    auto *srcMemRef = operands[affineDmaStartOp.getSrcMemRefOperandIndex()];
+    auto *dstMemRef = operands[affineDmaStartOp.getDstMemRefOperandIndex()];
+    auto *tagMemRef = operands[affineDmaStartOp.getTagMemRefOperandIndex()];
+    unsigned numElementsIndex = affineDmaStartOp.getTagMemRefOperandIndex() +
+                                1 + affineDmaStartOp.getTagMap().getNumInputs();
+    auto *numElements = operands[numElementsIndex];
+    auto *stride =
+        affineDmaStartOp.isStrided() ? operands[numElementsIndex + 1] : nullptr;
+    auto *eltsPerStride =
+        affineDmaStartOp.isStrided() ? operands[numElementsIndex + 2] : nullptr;
+
+    rewriter.replaceOpWithNewOp<DmaStartOp>(
+        op, srcMemRef, *maybeExpandedSrcMap, dstMemRef, *maybeExpandedDstMap,
+        numElements, tagMemRef, *maybeExpandedTagMap, stride, eltsPerStride);
+    return matchSuccess();
+  }
+};
+
+// Apply the affine map from an 'affine.dma_wait' operation's tag memref,
+// and feed the results to a newly created 'std.dma_wait' operation (which
+// replaces the original 'affine.dma_wait').
+class AffineDmaWaitLowering : public ConversionPattern {
+public:
+  AffineDmaWaitLowering(MLIRContext *ctx)
+      : ConversionPattern(AffineDmaWaitOp::getOperationName(), 1, ctx) {}
+
+  virtual PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  PatternRewriter &rewriter) const override {
+    auto affineDmaWaitOp = cast<AffineDmaWaitOp>(op);
+    // Expand affine map for DMA tag memref.
+    auto maybeExpandedTagMap =
+        expandAffineMap(rewriter, op->getLoc(), affineDmaWaitOp.getTagMap(),
+                        operands.drop_front());
+    if (!maybeExpandedTagMap)
+      return matchFailure();
+
+    // Build std.dma_wait operation with affine map results.
+    unsigned numElementsIndex = 1 + affineDmaWaitOp.getTagMap().getNumInputs();
+    rewriter.replaceOpWithNewOp<DmaWaitOp>(
+        op, operands[0], *maybeExpandedTagMap, operands[numElementsIndex]);
+    return matchSuccess();
+  }
+};
+
 } // end namespace
 
 LogicalResult mlir::lowerAffineConstructs(Function &function) {
   OwningRewritePatternList patterns;
-  RewriteListBuilder::build(patterns, function.getContext());
   ConversionTarget target(*function.getContext());
diff --git a/mlir/test/Transforms/lower-affine.mlir b/mlir/test/Transforms/lower-affine.mlir
index fc6afbd9b68e..8538a94a2616 100644
--- a/mlir/test/Transforms/lower-affine.mlir
+++ b/mlir/test/Transforms/lower-affine.mlir
@@ -637,3 +637,65 @@ func @affine_apply_ceildiv(%arg0 : index) -> (index) {
   %0 = affine.apply #mapceildiv (%arg0)
   return %0 : index
 }
+
+// CHECK-LABEL: func @affine_load
+func @affine_load(%arg0 : index) {
+  %0 = alloc() : memref<10xf32>
+  affine.for %i0 = 0 to 10 {
+    %1 = affine.load %0[%i0 + symbol(%arg0) + 7] : memref<10xf32>
+  }
+// CHECK: %3 = addi %1, %arg0 : index
+// CHECK-NEXT: %c7 = constant 7 : index
+// CHECK-NEXT: %4 = addi %3, %c7 : index
+// CHECK-NEXT: %5 = load %0[%4] : memref<10xf32>
+  return
+}
+
+// CHECK-LABEL: func @affine_store
+func @affine_store(%arg0 : index) {
+  %0 = alloc() : memref<10xf32>
+  %1 = constant 11.0 : f32
+  affine.for %i0 = 0 to 10 {
+    affine.store %1, %0[%i0 - symbol(%arg0) + 7] : memref<10xf32>
+  }
+// CHECK: %c-1 = constant -1 : index
+// CHECK-NEXT: %3 = muli %arg0, %c-1 : index
+// CHECK-NEXT: %4 = addi %1, %3 : index
+// CHECK-NEXT: %c7 = constant 7 : index
+// CHECK-NEXT: %5 = addi %4, %c7 : index
+// CHECK-NEXT: store %cst, %0[%5] : memref<10xf32>
+  return
+}
+
+// CHECK-LABEL: func @affine_dma_start
+func @affine_dma_start(%arg0 : index) {
+  %0 = alloc() : memref<100xf32>
+  %1 = alloc() : memref<100xf32, 2>
+  %2 = alloc() : memref<1xi32>
+  %c0 = constant 0 : index
+  %c64 = constant 64 : index
+  affine.for %i0 = 0 to 10 {
+    affine.dma_start %0[%i0 + 7], %1[%arg0 + 11], %2[%c0], %c64
+        : memref<100xf32>, memref<100xf32, 2>, memref<1xi32>
+  }
+// CHECK: %c7 = constant 7 : index
+// CHECK-NEXT: %5 = addi %3, %c7 : index
+// CHECK-NEXT: %c11 = constant 11 : index
+// CHECK-NEXT: %6 = addi %arg0, %c11 : index
+// CHECK-NEXT: dma_start %0[%5], %1[%6], %c64, %2[%c0] : memref<100xf32>, memref<100xf32, 2>, memref<1xi32>
+  return
+}
+
+// CHECK-LABEL: func @affine_dma_wait
+func @affine_dma_wait(%arg0 : index) {
+  %2 = alloc() : memref<1xi32>
+  %c64 = constant 64 : index
+  affine.for %i0 = 0 to 10 {
+    affine.dma_wait %2[%i0 + %arg0 + 17], %c64 : memref<1xi32>
+  }
+// CHECK: %3 = addi %1, %arg0 : index
+// CHECK-NEXT: %c17 = constant 17 : index
+// CHECK-NEXT: %4 = addi %3, %c17 : index
+// CHECK-NEXT: dma_wait %0[%4], %c64 : memref<1xi32>
+  return
+}
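
For reference, a minimal before/after sketch of what these patterns produce, based on the @affine_load test above; the SSA value names below are illustrative placeholders rather than the numbered results checked by FileCheck:

    // Affine dialect input: the index expression lives in the load's map.
    %v = affine.load %0[%i0 + symbol(%arg0) + 7] : memref<10xf32>

    // After lowering: the map is expanded into standard arithmetic ops and
    // the resulting index feeds a plain std.load.
    %sum = addi %i0, %arg0 : index
    %c7 = constant 7 : index
    %idx = addi %sum, %c7 : index
    %v = load %0[%idx] : memref<10xf32>

The store, dma_start, and dma_wait lowerings follow the same recipe: expand each affine map over its map operands with expandAffineMap, then rebuild the op in the standard dialect with the expanded indices.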