forked from OSchip/llvm-project
Add affine-to-standard lowerings for affine.load/store/dma_start/dma_wait.
PiperOrigin-RevId: 255960171
This commit is contained in:
parent
2dc5e19426
commit
f487d20bf0
|
@ -597,11 +597,140 @@ public:
|
|||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
// Apply the affine map from an 'affine.load' operation to its operands, and
|
||||
// feed the results to a newly created 'std.load' operation (which replaces the
|
||||
// original 'affine.load').
|
||||
class AffineLoadLowering : public ConversionPattern {
|
||||
public:
|
||||
AffineLoadLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(AffineLoadOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
virtual PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto affineLoadOp = cast<AffineLoadOp>(op);
|
||||
// Expand affine map from 'affineLoadOp'.
|
||||
auto maybeExpandedMap =
|
||||
expandAffineMap(rewriter, op->getLoc(), affineLoadOp.getAffineMap(),
|
||||
operands.drop_front());
|
||||
if (!maybeExpandedMap)
|
||||
return matchFailure();
|
||||
// Build std.load memref[expandedMap.results].
|
||||
rewriter.replaceOpWithNewOp<LoadOp>(op, operands[0], *maybeExpandedMap);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
// Apply the affine map from an 'affine.store' operation to its operands, and
|
||||
// feed the results to a newly created 'std.store' operation (which replaces the
|
||||
// original 'affine.store').
|
||||
class AffineStoreLowering : public ConversionPattern {
|
||||
public:
|
||||
AffineStoreLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(AffineStoreOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
virtual PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto affineStoreOp = cast<AffineStoreOp>(op);
|
||||
// Expand affine map from 'affineStoreOp'.
|
||||
auto maybeExpandedMap =
|
||||
expandAffineMap(rewriter, op->getLoc(), affineStoreOp.getAffineMap(),
|
||||
operands.drop_front(2));
|
||||
if (!maybeExpandedMap)
|
||||
return matchFailure();
|
||||
// Build std.store valutToStore, memref[expandedMap.results].
|
||||
rewriter.replaceOpWithNewOp<StoreOp>(op, operands[0], operands[1],
|
||||
*maybeExpandedMap);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
// Apply the affine maps from an 'affine.dma_start' operation to each of their
|
||||
// respective map operands, and feed the results to a newly created
|
||||
// 'std.dma_start' operation (which replaces the original 'affine.dma_start').
|
||||
class AffineDmaStartLowering : public ConversionPattern {
|
||||
public:
|
||||
AffineDmaStartLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(AffineDmaStartOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
virtual PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto affineDmaStartOp = cast<AffineDmaStartOp>(op);
|
||||
// Expand affine map for DMA source memref.
|
||||
auto maybeExpandedSrcMap = expandAffineMap(
|
||||
rewriter, op->getLoc(), affineDmaStartOp.getSrcMap(),
|
||||
operands.drop_front(affineDmaStartOp.getSrcMemRefOperandIndex() + 1));
|
||||
if (!maybeExpandedSrcMap)
|
||||
return matchFailure();
|
||||
// Expand affine map for DMA destination memref.
|
||||
auto maybeExpandedDstMap = expandAffineMap(
|
||||
rewriter, op->getLoc(), affineDmaStartOp.getDstMap(),
|
||||
operands.drop_front(affineDmaStartOp.getDstMemRefOperandIndex() + 1));
|
||||
if (!maybeExpandedDstMap)
|
||||
return matchFailure();
|
||||
// Expand affine map for DMA tag memref.
|
||||
auto maybeExpandedTagMap = expandAffineMap(
|
||||
rewriter, op->getLoc(), affineDmaStartOp.getTagMap(),
|
||||
operands.drop_front(affineDmaStartOp.getTagMemRefOperandIndex() + 1));
|
||||
if (!maybeExpandedTagMap)
|
||||
return matchFailure();
|
||||
|
||||
// Build std.dma_start operation with affine map results.
|
||||
auto *srcMemRef = operands[affineDmaStartOp.getSrcMemRefOperandIndex()];
|
||||
auto *dstMemRef = operands[affineDmaStartOp.getDstMemRefOperandIndex()];
|
||||
auto *tagMemRef = operands[affineDmaStartOp.getTagMemRefOperandIndex()];
|
||||
unsigned numElementsIndex = affineDmaStartOp.getTagMemRefOperandIndex() +
|
||||
1 + affineDmaStartOp.getTagMap().getNumInputs();
|
||||
auto *numElements = operands[numElementsIndex];
|
||||
auto *stride =
|
||||
affineDmaStartOp.isStrided() ? operands[numElementsIndex + 1] : nullptr;
|
||||
auto *eltsPerStride =
|
||||
affineDmaStartOp.isStrided() ? operands[numElementsIndex + 2] : nullptr;
|
||||
|
||||
rewriter.replaceOpWithNewOp<DmaStartOp>(
|
||||
op, srcMemRef, *maybeExpandedSrcMap, dstMemRef, *maybeExpandedDstMap,
|
||||
numElements, tagMemRef, *maybeExpandedTagMap, stride, eltsPerStride);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
// Apply the affine map from an 'affine.dma_wait' operation tag memref,
|
||||
// and feed the results to a newly created 'std.dma_wait' operation (which
|
||||
// replaces the original 'affine.dma_wait').
|
||||
class AffineDmaWaitLowering : public ConversionPattern {
|
||||
public:
|
||||
AffineDmaWaitLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(AffineDmaWaitOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
virtual PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto affineDmaWaitOp = cast<AffineDmaWaitOp>(op);
|
||||
// Expand affine map for DMA tag memref.
|
||||
auto maybeExpandedTagMap =
|
||||
expandAffineMap(rewriter, op->getLoc(), affineDmaWaitOp.getTagMap(),
|
||||
operands.drop_front());
|
||||
if (!maybeExpandedTagMap)
|
||||
return matchFailure();
|
||||
|
||||
// Build std.dma_wait operation with affine map results.
|
||||
unsigned numElementsIndex = 1 + affineDmaWaitOp.getTagMap().getNumInputs();
|
||||
rewriter.replaceOpWithNewOp<DmaWaitOp>(
|
||||
op, operands[0], *maybeExpandedTagMap, operands[numElementsIndex]);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace
|
||||
|
||||
LogicalResult mlir::lowerAffineConstructs(Function &function) {
|
||||
OwningRewritePatternList patterns;
|
||||
RewriteListBuilder<AffineApplyLowering, AffineForLowering, AffineIfLowering,
|
||||
RewriteListBuilder<AffineApplyLowering, AffineDmaStartLowering,
|
||||
AffineDmaWaitLowering, AffineLoadLowering,
|
||||
AffineStoreLowering, AffineForLowering, AffineIfLowering,
|
||||
AffineTerminatorLowering>::build(patterns,
|
||||
function.getContext());
|
||||
ConversionTarget target(*function.getContext());
|
||||
|
|
|
@ -637,3 +637,65 @@ func @affine_apply_ceildiv(%arg0 : index) -> (index) {
|
|||
%0 = affine.apply #mapceildiv (%arg0)
|
||||
return %0 : index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @affine_load
// Lowering of affine.load: the access map (d0 + s0 + 7) is expanded into
// explicit addi/constant ops whose result feeds a plain std 'load'.
// NOTE: the CHECK lines below are FileCheck directives and must match the
// lowered output exactly — do not edit them as ordinary comments.
func @affine_load(%arg0 : index) {
  %0 = alloc() : memref<10xf32>
  affine.for %i0 = 0 to 10 {
    %1 = affine.load %0[%i0 + symbol(%arg0) + 7] : memref<10xf32>
  }
// CHECK: %3 = addi %1, %arg0 : index
// CHECK-NEXT: %c7 = constant 7 : index
// CHECK-NEXT: %4 = addi %3, %c7 : index
// CHECK-NEXT: %5 = load %0[%4] : memref<10xf32>
  return
}
|
||||
|
||||
// CHECK-LABEL: func @affine_store
// Lowering of affine.store: the access map (d0 - s0 + 7) is expanded into a
// muli-by-(-1) plus addi/constant ops whose result feeds a plain std 'store'.
// NOTE: the CHECK lines below are FileCheck directives and must match the
// lowered output exactly — do not edit them as ordinary comments.
func @affine_store(%arg0 : index) {
  %0 = alloc() : memref<10xf32>
  %1 = constant 11.0 : f32
  affine.for %i0 = 0 to 10 {
    affine.store %1, %0[%i0 - symbol(%arg0) + 7] : memref<10xf32>
  }
// CHECK: %c-1 = constant -1 : index
// CHECK-NEXT: %3 = muli %arg0, %c-1 : index
// CHECK-NEXT: %4 = addi %1, %3 : index
// CHECK-NEXT: %c7 = constant 7 : index
// CHECK-NEXT: %5 = addi %4, %c7 : index
// CHECK-NEXT: store %cst, %0[%5] : memref<10xf32>
  return
}
|
||||
|
||||
// CHECK-LABEL: func @affine_dma_start
// Lowering of affine.dma_start: the source map (d0 + 7) and destination map
// (s0 + 11) are each expanded into addi/constant ops feeding a plain std
// 'dma_start' (note the operand reordering: num_elements precedes the tag in
// the std form).
// NOTE: the CHECK lines below are FileCheck directives and must match the
// lowered output exactly — do not edit them as ordinary comments.
func @affine_dma_start(%arg0 : index) {
  %0 = alloc() : memref<100xf32>
  %1 = alloc() : memref<100xf32, 2>
  %2 = alloc() : memref<1xi32>
  %c0 = constant 0 : index
  %c64 = constant 64 : index
  affine.for %i0 = 0 to 10 {
    affine.dma_start %0[%i0 + 7], %1[%arg0 + 11], %2[%c0], %c64
        : memref<100xf32>, memref<100xf32, 2>, memref<1xi32>
  }
// CHECK: %c7 = constant 7 : index
// CHECK-NEXT: %5 = addi %3, %c7 : index
// CHECK-NEXT: %c11 = constant 11 : index
// CHECK-NEXT: %6 = addi %arg0, %c11 : index
// CHECK-NEXT: dma_start %0[%5], %1[%6], %c64, %2[%c0] : memref<100xf32>, memref<100xf32, 2>, memref<1xi32>
  return
}
|
||||
|
||||
// CHECK-LABEL: func @affine_dma_wait
// Lowering of affine.dma_wait: the tag map (d0 + s0 + 17) is expanded into
// addi/constant ops whose result feeds a plain std 'dma_wait'.
// NOTE: the CHECK lines below are FileCheck directives and must match the
// lowered output exactly — do not edit them as ordinary comments.
func @affine_dma_wait(%arg0 : index) {
  %2 = alloc() : memref<1xi32>
  %c64 = constant 64 : index
  affine.for %i0 = 0 to 10 {
    affine.dma_wait %2[%i0 + %arg0 + 17], %c64 : memref<1xi32>
  }
// CHECK: %3 = addi %1, %arg0 : index
// CHECK-NEXT: %c17 = constant 17 : index
// CHECK-NEXT: %4 = addi %3, %c17 : index
// CHECK-NEXT: dma_wait %0[%4], %c64 : memref<1xi32>
  return
}
|
||||
|
|
Loading…
Reference in New Issue