Add affine-to-standard lowerings for affine.load/store/dma_start/dma_wait.

PiperOrigin-RevId: 255960171
This commit is contained in:
Andy Davis 2019-07-01 08:32:44 -07:00 committed by jpienaar
parent 2dc5e19426
commit f487d20bf0
2 changed files with 192 additions and 1 deletions

View File

@ -597,11 +597,140 @@ public:
return matchSuccess(); return matchSuccess();
} }
}; };
// Apply the affine map from an 'affine.load' operation to its operands, and
// feed the results to a newly created 'std.load' operation (which replaces the
// original 'affine.load').
class AffineLoadLowering : public ConversionPattern {
public:
AffineLoadLowering(MLIRContext *ctx)
: ConversionPattern(AffineLoadOp::getOperationName(), 1, ctx) {}
virtual PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto affineLoadOp = cast<AffineLoadOp>(op);
// Expand affine map from 'affineLoadOp'.
auto maybeExpandedMap =
expandAffineMap(rewriter, op->getLoc(), affineLoadOp.getAffineMap(),
operands.drop_front());
if (!maybeExpandedMap)
return matchFailure();
// Build std.load memref[expandedMap.results].
rewriter.replaceOpWithNewOp<LoadOp>(op, operands[0], *maybeExpandedMap);
return matchSuccess();
}
};
// Apply the affine map from an 'affine.store' operation to its operands, and
// feed the results to a newly created 'std.store' operation (which replaces the
// original 'affine.store').
class AffineStoreLowering : public ConversionPattern {
public:
AffineStoreLowering(MLIRContext *ctx)
: ConversionPattern(AffineStoreOp::getOperationName(), 1, ctx) {}
virtual PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto affineStoreOp = cast<AffineStoreOp>(op);
// Expand affine map from 'affineStoreOp'.
auto maybeExpandedMap =
expandAffineMap(rewriter, op->getLoc(), affineStoreOp.getAffineMap(),
operands.drop_front(2));
if (!maybeExpandedMap)
return matchFailure();
// Build std.store valutToStore, memref[expandedMap.results].
rewriter.replaceOpWithNewOp<StoreOp>(op, operands[0], operands[1],
*maybeExpandedMap);
return matchSuccess();
}
};
// Apply the affine maps from an 'affine.dma_start' operation to each of their
// respective map operands, and feed the results to a newly created
// 'std.dma_start' operation (which replaces the original 'affine.dma_start').
class AffineDmaStartLowering : public ConversionPattern {
public:
AffineDmaStartLowering(MLIRContext *ctx)
: ConversionPattern(AffineDmaStartOp::getOperationName(), 1, ctx) {}
virtual PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto affineDmaStartOp = cast<AffineDmaStartOp>(op);
// Expand affine map for DMA source memref.
auto maybeExpandedSrcMap = expandAffineMap(
rewriter, op->getLoc(), affineDmaStartOp.getSrcMap(),
operands.drop_front(affineDmaStartOp.getSrcMemRefOperandIndex() + 1));
if (!maybeExpandedSrcMap)
return matchFailure();
// Expand affine map for DMA destination memref.
auto maybeExpandedDstMap = expandAffineMap(
rewriter, op->getLoc(), affineDmaStartOp.getDstMap(),
operands.drop_front(affineDmaStartOp.getDstMemRefOperandIndex() + 1));
if (!maybeExpandedDstMap)
return matchFailure();
// Expand affine map for DMA tag memref.
auto maybeExpandedTagMap = expandAffineMap(
rewriter, op->getLoc(), affineDmaStartOp.getTagMap(),
operands.drop_front(affineDmaStartOp.getTagMemRefOperandIndex() + 1));
if (!maybeExpandedTagMap)
return matchFailure();
// Build std.dma_start operation with affine map results.
auto *srcMemRef = operands[affineDmaStartOp.getSrcMemRefOperandIndex()];
auto *dstMemRef = operands[affineDmaStartOp.getDstMemRefOperandIndex()];
auto *tagMemRef = operands[affineDmaStartOp.getTagMemRefOperandIndex()];
unsigned numElementsIndex = affineDmaStartOp.getTagMemRefOperandIndex() +
1 + affineDmaStartOp.getTagMap().getNumInputs();
auto *numElements = operands[numElementsIndex];
auto *stride =
affineDmaStartOp.isStrided() ? operands[numElementsIndex + 1] : nullptr;
auto *eltsPerStride =
affineDmaStartOp.isStrided() ? operands[numElementsIndex + 2] : nullptr;
rewriter.replaceOpWithNewOp<DmaStartOp>(
op, srcMemRef, *maybeExpandedSrcMap, dstMemRef, *maybeExpandedDstMap,
numElements, tagMemRef, *maybeExpandedTagMap, stride, eltsPerStride);
return matchSuccess();
}
};
// Apply the affine map from an 'affine.dma_wait' operation tag memref,
// and feed the results to a newly created 'std.dma_wait' operation (which
// replaces the original 'affine.dma_wait').
class AffineDmaWaitLowering : public ConversionPattern {
public:
AffineDmaWaitLowering(MLIRContext *ctx)
: ConversionPattern(AffineDmaWaitOp::getOperationName(), 1, ctx) {}
virtual PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto affineDmaWaitOp = cast<AffineDmaWaitOp>(op);
// Expand affine map for DMA tag memref.
auto maybeExpandedTagMap =
expandAffineMap(rewriter, op->getLoc(), affineDmaWaitOp.getTagMap(),
operands.drop_front());
if (!maybeExpandedTagMap)
return matchFailure();
// Build std.dma_wait operation with affine map results.
unsigned numElementsIndex = 1 + affineDmaWaitOp.getTagMap().getNumInputs();
rewriter.replaceOpWithNewOp<DmaWaitOp>(
op, operands[0], *maybeExpandedTagMap, operands[numElementsIndex]);
return matchSuccess();
}
};
} // end namespace } // end namespace
LogicalResult mlir::lowerAffineConstructs(Function &function) { LogicalResult mlir::lowerAffineConstructs(Function &function) {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
RewriteListBuilder<AffineApplyLowering, AffineForLowering, AffineIfLowering, RewriteListBuilder<AffineApplyLowering, AffineDmaStartLowering,
AffineDmaWaitLowering, AffineLoadLowering,
AffineStoreLowering, AffineForLowering, AffineIfLowering,
AffineTerminatorLowering>::build(patterns, AffineTerminatorLowering>::build(patterns,
function.getContext()); function.getContext());
ConversionTarget target(*function.getContext()); ConversionTarget target(*function.getContext());

View File

@ -637,3 +637,65 @@ func @affine_apply_ceildiv(%arg0 : index) -> (index) {
%0 = affine.apply #mapceildiv (%arg0) %0 = affine.apply #mapceildiv (%arg0)
return %0 : index return %0 : index
} }
// CHECK-LABEL: func @affine_load
func @affine_load(%arg0 : index) {
%0 = alloc() : memref<10xf32>
affine.for %i0 = 0 to 10 {
%1 = affine.load %0[%i0 + symbol(%arg0) + 7] : memref<10xf32>
}
// CHECK: %3 = addi %1, %arg0 : index
// CHECK-NEXT: %c7 = constant 7 : index
// CHECK-NEXT: %4 = addi %3, %c7 : index
// CHECK-NEXT: %5 = load %0[%4] : memref<10xf32>
return
}
// CHECK-LABEL: func @affine_store
func @affine_store(%arg0 : index) {
%0 = alloc() : memref<10xf32>
%1 = constant 11.0 : f32
affine.for %i0 = 0 to 10 {
affine.store %1, %0[%i0 - symbol(%arg0) + 7] : memref<10xf32>
}
// CHECK: %c-1 = constant -1 : index
// CHECK-NEXT: %3 = muli %arg0, %c-1 : index
// CHECK-NEXT: %4 = addi %1, %3 : index
// CHECK-NEXT: %c7 = constant 7 : index
// CHECK-NEXT: %5 = addi %4, %c7 : index
// CHECK-NEXT: store %cst, %0[%5] : memref<10xf32>
return
}
// CHECK-LABEL: func @affine_dma_start
func @affine_dma_start(%arg0 : index) {
%0 = alloc() : memref<100xf32>
%1 = alloc() : memref<100xf32, 2>
%2 = alloc() : memref<1xi32>
%c0 = constant 0 : index
%c64 = constant 64 : index
affine.for %i0 = 0 to 10 {
affine.dma_start %0[%i0 + 7], %1[%arg0 + 11], %2[%c0], %c64
: memref<100xf32>, memref<100xf32, 2>, memref<1xi32>
}
// CHECK: %c7 = constant 7 : index
// CHECK-NEXT: %5 = addi %3, %c7 : index
// CHECK-NEXT: %c11 = constant 11 : index
// CHECK-NEXT: %6 = addi %arg0, %c11 : index
// CHECK-NEXT: dma_start %0[%5], %1[%6], %c64, %2[%c0] : memref<100xf32>, memref<100xf32, 2>, memref<1xi32>
return
}
// CHECK-LABEL: func @affine_dma_wait
func @affine_dma_wait(%arg0 : index) {
%2 = alloc() : memref<1xi32>
%c64 = constant 64 : index
affine.for %i0 = 0 to 10 {
affine.dma_wait %2[%i0 + %arg0 + 17], %c64 : memref<1xi32>
}
// CHECK: %3 = addi %1, %arg0 : index
// CHECK-NEXT: %c17 = constant 17 : index
// CHECK-NEXT: %4 = addi %3, %c17 : index
// CHECK-NEXT: dma_wait %0[%4], %c64 : memref<1xi32>
return
}