Adds loop attribute as a temporary work around to prevent slice fusion of loop nests containing instructions with side effects (the proper solution will be do use memref read/write regions in the future).

PiperOrigin-RevId: 236733739
This commit is contained in:
MLIR Team 2019-03-04 15:14:12 -08:00 committed by jpienaar
parent 12b9dece8d
commit 39a1ddeb1c
2 changed files with 56 additions and 14 deletions

View File

@ -435,6 +435,7 @@ static Instruction *getInstAtPosition(ArrayRef<unsigned> positions,
return nullptr; return nullptr;
} }
const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier";
// Computes memref dependence between 'srcAccess' and 'dstAccess', projects // Computes memref dependence between 'srcAccess' and 'dstAccess', projects
// out any dst loop IVs at depth greater than 'dstLoopDepth', and computes slice // out any dst loop IVs at depth greater than 'dstLoopDepth', and computes slice
// bounds in 'sliceState' which represent the src IVs in terms of the dst IVs, // bounds in 'sliceState' which represent the src IVs in terms of the dst IVs,
@ -491,24 +492,28 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess,
sliceState->lbOperands.resize(numSrcLoopIVs, sliceBoundOperands); sliceState->lbOperands.resize(numSrcLoopIVs, sliceBoundOperands);
sliceState->ubOperands.resize(numSrcLoopIVs, sliceBoundOperands); sliceState->ubOperands.resize(numSrcLoopIVs, sliceBoundOperands);
// For read-read access pairs, clear any slice bounds on sequential loops. llvm::SmallDenseSet<Value *, 8> sequentialLoops;
if (readReadAccesses) { if (readReadAccesses) {
// For read-read access pairs, clear any slice bounds on sequential loops.
// Get sequential loops in loop nest rooted at 'srcLoopIVs[0]'. // Get sequential loops in loop nest rooted at 'srcLoopIVs[0]'.
llvm::SmallDenseSet<Value *, 8> sequentialLoops;
getSequentialLoops(srcLoopIVs[0], &sequentialLoops); getSequentialLoops(srcLoopIVs[0], &sequentialLoops);
// Clear all sliced loop bounds beginning at the first sequential loop.
for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
Value *iv = srcLoopIVs[i]->getInductionVar();
if (sequentialLoops.count(iv) == 0)
continue;
for (unsigned j = i; j < numSrcLoopIVs; ++j) {
sliceState->lbs[j] = AffineMap();
sliceState->ubs[j] = AffineMap();
}
break;
}
} }
// Clear all sliced loop bounds beginning at the first sequential loop, or
// first loop with a slice fusion barrier attribute..
// TODO(andydavis, bondhugula) Use MemRef read/write regions instead of
// using 'kSliceFusionBarrierAttrName'.
for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
Value *iv = srcLoopIVs[i]->getInductionVar();
if (sequentialLoops.count(iv) == 0 &&
srcLoopIVs[i]->getAttr(kSliceFusionBarrierAttrName) == nullptr)
continue;
for (unsigned j = i; j < numSrcLoopIVs; ++j) {
sliceState->lbs[j] = AffineMap();
sliceState->ubs[j] = AffineMap();
}
break;
}
return true; return true;
} }

View File

@ -2062,3 +2062,40 @@ func @two_matrix_vector_products() {
// CHECK-NEXT: return // CHECK-NEXT: return
return return
} }
// -----
// CHECK-DAG: [[MAP3:#map[0-9]+]] = (d0, d1, d2) -> (-d0 + d1)
// CHECK-DAG: [[MAP4:#map[0-9]+]] = (d0, d1, d2) -> (d2)
func @should_not_slice_past_slice_barrier() {
%0 = alloc() : memref<100x16xf32>
for %i0 = 0 to 100 {
for %i1 = 0 to 16 {
%1 = "op1"() : () -> f32
store %1, %0[%i0, %i1] : memref<100x16xf32>
} {slice_fusion_barrier: true}
}
for %i2 = 0 to 100 {
for %i3 = 0 to 16 {
%2 = load %0[%i2, %i3] : memref<100x16xf32>
"op2"(%2) : (f32) -> ()
}
}
// The 'slice_fusion_barrier' attribute on '%i1' prevents slicing the
// iteration space of '%i1' and any enclosing loop nests.
// CHECK: for %i0 = 0 to 100 {
// CHECK-NEXT: for %i1 = 0 to 16 {
// CHECK-NEXT: %1 = "op1"() : () -> f32
// CHECK-NEXT: %2 = affine.apply [[MAP3]](%i0, %i0, %i1)
// CHECK-NEXT: %3 = affine.apply [[MAP4]](%i0, %i0, %i1)
// CHECK-NEXT: store %1, %0[%2, %3] : memref<1x16xf32>
// CHECK-NEXT: } {slice_fusion_barrier: true}
// CHECK-NEXT: for %i2 = 0 to 16 {
// CHECK-NEXT: %4 = affine.apply [[MAP3]](%i0, %i0, %i2)
// CHECK-NEXT: %5 = affine.apply [[MAP4]](%i0, %i0, %i2)
// CHECK-NEXT: %6 = load %0[%4, %5] : memref<1x16xf32>
// CHECK-NEXT: "op2"(%6) : (f32) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: }
return
}