forked from OSchip/llvm-project
Add affine-to-standard lowerings for affine.load/store/dma_start/dma_wait.
PiperOrigin-RevId: 255960171
This commit is contained in:
parent
2dc5e19426
commit
f487d20bf0
|
@ -597,11 +597,140 @@ public:
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Apply the affine map from an 'affine.load' operation to its operands, and
|
||||||
|
// feed the results to a newly created 'std.load' operation (which replaces the
|
||||||
|
// original 'affine.load').
|
||||||
|
class AffineLoadLowering : public ConversionPattern {
|
||||||
|
public:
|
||||||
|
AffineLoadLowering(MLIRContext *ctx)
|
||||||
|
: ConversionPattern(AffineLoadOp::getOperationName(), 1, ctx) {}
|
||||||
|
|
||||||
|
virtual PatternMatchResult
|
||||||
|
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto affineLoadOp = cast<AffineLoadOp>(op);
|
||||||
|
// Expand affine map from 'affineLoadOp'.
|
||||||
|
auto maybeExpandedMap =
|
||||||
|
expandAffineMap(rewriter, op->getLoc(), affineLoadOp.getAffineMap(),
|
||||||
|
operands.drop_front());
|
||||||
|
if (!maybeExpandedMap)
|
||||||
|
return matchFailure();
|
||||||
|
// Build std.load memref[expandedMap.results].
|
||||||
|
rewriter.replaceOpWithNewOp<LoadOp>(op, operands[0], *maybeExpandedMap);
|
||||||
|
return matchSuccess();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Apply the affine map from an 'affine.store' operation to its operands, and
|
||||||
|
// feed the results to a newly created 'std.store' operation (which replaces the
|
||||||
|
// original 'affine.store').
|
||||||
|
class AffineStoreLowering : public ConversionPattern {
|
||||||
|
public:
|
||||||
|
AffineStoreLowering(MLIRContext *ctx)
|
||||||
|
: ConversionPattern(AffineStoreOp::getOperationName(), 1, ctx) {}
|
||||||
|
|
||||||
|
virtual PatternMatchResult
|
||||||
|
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto affineStoreOp = cast<AffineStoreOp>(op);
|
||||||
|
// Expand affine map from 'affineStoreOp'.
|
||||||
|
auto maybeExpandedMap =
|
||||||
|
expandAffineMap(rewriter, op->getLoc(), affineStoreOp.getAffineMap(),
|
||||||
|
operands.drop_front(2));
|
||||||
|
if (!maybeExpandedMap)
|
||||||
|
return matchFailure();
|
||||||
|
// Build std.store valutToStore, memref[expandedMap.results].
|
||||||
|
rewriter.replaceOpWithNewOp<StoreOp>(op, operands[0], operands[1],
|
||||||
|
*maybeExpandedMap);
|
||||||
|
return matchSuccess();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Apply the affine maps from an 'affine.dma_start' operation to each of their
|
||||||
|
// respective map operands, and feed the results to a newly created
|
||||||
|
// 'std.dma_start' operation (which replaces the original 'affine.dma_start').
|
||||||
|
class AffineDmaStartLowering : public ConversionPattern {
|
||||||
|
public:
|
||||||
|
AffineDmaStartLowering(MLIRContext *ctx)
|
||||||
|
: ConversionPattern(AffineDmaStartOp::getOperationName(), 1, ctx) {}
|
||||||
|
|
||||||
|
virtual PatternMatchResult
|
||||||
|
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto affineDmaStartOp = cast<AffineDmaStartOp>(op);
|
||||||
|
// Expand affine map for DMA source memref.
|
||||||
|
auto maybeExpandedSrcMap = expandAffineMap(
|
||||||
|
rewriter, op->getLoc(), affineDmaStartOp.getSrcMap(),
|
||||||
|
operands.drop_front(affineDmaStartOp.getSrcMemRefOperandIndex() + 1));
|
||||||
|
if (!maybeExpandedSrcMap)
|
||||||
|
return matchFailure();
|
||||||
|
// Expand affine map for DMA destination memref.
|
||||||
|
auto maybeExpandedDstMap = expandAffineMap(
|
||||||
|
rewriter, op->getLoc(), affineDmaStartOp.getDstMap(),
|
||||||
|
operands.drop_front(affineDmaStartOp.getDstMemRefOperandIndex() + 1));
|
||||||
|
if (!maybeExpandedDstMap)
|
||||||
|
return matchFailure();
|
||||||
|
// Expand affine map for DMA tag memref.
|
||||||
|
auto maybeExpandedTagMap = expandAffineMap(
|
||||||
|
rewriter, op->getLoc(), affineDmaStartOp.getTagMap(),
|
||||||
|
operands.drop_front(affineDmaStartOp.getTagMemRefOperandIndex() + 1));
|
||||||
|
if (!maybeExpandedTagMap)
|
||||||
|
return matchFailure();
|
||||||
|
|
||||||
|
// Build std.dma_start operation with affine map results.
|
||||||
|
auto *srcMemRef = operands[affineDmaStartOp.getSrcMemRefOperandIndex()];
|
||||||
|
auto *dstMemRef = operands[affineDmaStartOp.getDstMemRefOperandIndex()];
|
||||||
|
auto *tagMemRef = operands[affineDmaStartOp.getTagMemRefOperandIndex()];
|
||||||
|
unsigned numElementsIndex = affineDmaStartOp.getTagMemRefOperandIndex() +
|
||||||
|
1 + affineDmaStartOp.getTagMap().getNumInputs();
|
||||||
|
auto *numElements = operands[numElementsIndex];
|
||||||
|
auto *stride =
|
||||||
|
affineDmaStartOp.isStrided() ? operands[numElementsIndex + 1] : nullptr;
|
||||||
|
auto *eltsPerStride =
|
||||||
|
affineDmaStartOp.isStrided() ? operands[numElementsIndex + 2] : nullptr;
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<DmaStartOp>(
|
||||||
|
op, srcMemRef, *maybeExpandedSrcMap, dstMemRef, *maybeExpandedDstMap,
|
||||||
|
numElements, tagMemRef, *maybeExpandedTagMap, stride, eltsPerStride);
|
||||||
|
return matchSuccess();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Apply the affine map from an 'affine.dma_wait' operation tag memref,
|
||||||
|
// and feed the results to a newly created 'std.dma_wait' operation (which
|
||||||
|
// replaces the original 'affine.dma_wait').
|
||||||
|
class AffineDmaWaitLowering : public ConversionPattern {
|
||||||
|
public:
|
||||||
|
AffineDmaWaitLowering(MLIRContext *ctx)
|
||||||
|
: ConversionPattern(AffineDmaWaitOp::getOperationName(), 1, ctx) {}
|
||||||
|
|
||||||
|
virtual PatternMatchResult
|
||||||
|
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto affineDmaWaitOp = cast<AffineDmaWaitOp>(op);
|
||||||
|
// Expand affine map for DMA tag memref.
|
||||||
|
auto maybeExpandedTagMap =
|
||||||
|
expandAffineMap(rewriter, op->getLoc(), affineDmaWaitOp.getTagMap(),
|
||||||
|
operands.drop_front());
|
||||||
|
if (!maybeExpandedTagMap)
|
||||||
|
return matchFailure();
|
||||||
|
|
||||||
|
// Build std.dma_wait operation with affine map results.
|
||||||
|
unsigned numElementsIndex = 1 + affineDmaWaitOp.getTagMap().getNumInputs();
|
||||||
|
rewriter.replaceOpWithNewOp<DmaWaitOp>(
|
||||||
|
op, operands[0], *maybeExpandedTagMap, operands[numElementsIndex]);
|
||||||
|
return matchSuccess();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
} // end namespace
|
} // end namespace
|
||||||
|
|
||||||
LogicalResult mlir::lowerAffineConstructs(Function &function) {
|
LogicalResult mlir::lowerAffineConstructs(Function &function) {
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
RewriteListBuilder<AffineApplyLowering, AffineForLowering, AffineIfLowering,
|
RewriteListBuilder<AffineApplyLowering, AffineDmaStartLowering,
|
||||||
|
AffineDmaWaitLowering, AffineLoadLowering,
|
||||||
|
AffineStoreLowering, AffineForLowering, AffineIfLowering,
|
||||||
AffineTerminatorLowering>::build(patterns,
|
AffineTerminatorLowering>::build(patterns,
|
||||||
function.getContext());
|
function.getContext());
|
||||||
ConversionTarget target(*function.getContext());
|
ConversionTarget target(*function.getContext());
|
||||||
|
|
|
@ -637,3 +637,65 @@ func @affine_apply_ceildiv(%arg0 : index) -> (index) {
|
||||||
%0 = affine.apply #mapceildiv (%arg0)
|
%0 = affine.apply #mapceildiv (%arg0)
|
||||||
return %0 : index
|
return %0 : index
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Lowering test for 'affine.load': the load below is indexed by
// %i0 + symbol(%arg0) + 7 inside an affine.for. The FileCheck directives
// expect that index to appear as addi, constant 7, addi, feeding a plain
// 'load' on the same memref.
// CHECK-LABEL: func @affine_load
|
||||||
|
func @affine_load(%arg0 : index) {
|
||||||
|
%0 = alloc() : memref<10xf32>
|
||||||
|
affine.for %i0 = 0 to 10 {
|
||||||
|
%1 = affine.load %0[%i0 + symbol(%arg0) + 7] : memref<10xf32>
|
||||||
|
}
|
||||||
|
// CHECK: %3 = addi %1, %arg0 : index
|
||||||
|
// CHECK-NEXT: %c7 = constant 7 : index
|
||||||
|
// CHECK-NEXT: %4 = addi %3, %c7 : index
|
||||||
|
// CHECK-NEXT: %5 = load %0[%4] : memref<10xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lowering test for 'affine.store': the store below is indexed by
// %i0 - symbol(%arg0) + 7 inside an affine.for. The FileCheck directives
// expect the subtraction to appear as a multiply by constant -1 followed by
// addi, then constant 7 and a second addi, feeding a plain 'store'.
// CHECK-LABEL: func @affine_store
|
||||||
|
func @affine_store(%arg0 : index) {
|
||||||
|
%0 = alloc() : memref<10xf32>
|
||||||
|
%1 = constant 11.0 : f32
|
||||||
|
affine.for %i0 = 0 to 10 {
|
||||||
|
affine.store %1, %0[%i0 - symbol(%arg0) + 7] : memref<10xf32>
|
||||||
|
}
|
||||||
|
// CHECK: %c-1 = constant -1 : index
|
||||||
|
// CHECK-NEXT: %3 = muli %arg0, %c-1 : index
|
||||||
|
// CHECK-NEXT: %4 = addi %1, %3 : index
|
||||||
|
// CHECK-NEXT: %c7 = constant 7 : index
|
||||||
|
// CHECK-NEXT: %5 = addi %4, %c7 : index
|
||||||
|
// CHECK-NEXT: store %cst, %0[%5] : memref<10xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lowering test for 'affine.dma_start': the source is indexed by %i0 + 7,
// the destination by %arg0 + 11 and the tag by %c0, with %c64 as the trailing
// operand. The FileCheck directives expect constant 7 / addi for the source
// index, constant 11 / addi for the destination index, and a 'dma_start' in
// which %c64 is printed before the tag memref.
// CHECK-LABEL: func @affine_dma_start
|
||||||
|
func @affine_dma_start(%arg0 : index) {
|
||||||
|
%0 = alloc() : memref<100xf32>
|
||||||
|
%1 = alloc() : memref<100xf32, 2>
|
||||||
|
%2 = alloc() : memref<1xi32>
|
||||||
|
%c0 = constant 0 : index
|
||||||
|
%c64 = constant 64 : index
|
||||||
|
affine.for %i0 = 0 to 10 {
|
||||||
|
affine.dma_start %0[%i0 + 7], %1[%arg0 + 11], %2[%c0], %c64
|
||||||
|
: memref<100xf32>, memref<100xf32, 2>, memref<1xi32>
|
||||||
|
}
|
||||||
|
// CHECK: %c7 = constant 7 : index
|
||||||
|
// CHECK-NEXT: %5 = addi %3, %c7 : index
|
||||||
|
// CHECK-NEXT: %c11 = constant 11 : index
|
||||||
|
// CHECK-NEXT: %6 = addi %arg0, %c11 : index
|
||||||
|
// CHECK-NEXT: dma_start %0[%5], %1[%6], %c64, %2[%c0] : memref<100xf32>, memref<100xf32, 2>, memref<1xi32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lowering test for 'affine.dma_wait': the tag memref is indexed by
// %i0 + %arg0 + 17, with %c64 as the trailing operand. The FileCheck
// directives expect that index to appear as addi, constant 17, addi, feeding
// a plain 'dma_wait' that keeps %c64.
// CHECK-LABEL: func @affine_dma_wait
|
||||||
|
func @affine_dma_wait(%arg0 : index) {
|
||||||
|
%2 = alloc() : memref<1xi32>
|
||||||
|
%c64 = constant 64 : index
|
||||||
|
affine.for %i0 = 0 to 10 {
|
||||||
|
affine.dma_wait %2[%i0 + %arg0 + 17], %c64 : memref<1xi32>
|
||||||
|
}
|
||||||
|
// CHECK: %3 = addi %1, %arg0 : index
|
||||||
|
// CHECK-NEXT: %c17 = constant 17 : index
|
||||||
|
// CHECK-NEXT: %4 = addi %3, %c17 : index
|
||||||
|
// CHECK-NEXT: dma_wait %0[%4], %c64 : memref<1xi32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue