Use "standard" load and stores in LowerVectorTransfers

Clipping creates non-affine memory accesses, so use std_load and std_store instead of affine_load and affine_store.
In the future we may also want to fill with the neutral element rather than clip; that would keep the accesses affine if we wanted more analyses and transformations to happen after lowering to pointwise copies.
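
As a rough sketch (illustrative only; the SSA names are invented and the clip bounds are simplified to one dimension of the memref<7x42xf32> used in the test below), the clipping materializes as cmpi/select pairs, so the resulting indices are no longer affine functions of the enclosing induction variables and only a standard load may consume them:

    %c0 = constant 0 : index
    %c6 = constant 6 : index
    %lt = cmpi "slt", %i, %c0 : index
    %lo = select %lt, %c0, %i : index         // clamp below at 0
    %gt = cmpi "slt", %c6, %lo : index
    %clipped = select %gt, %c6, %lo : index   // clamp above at dim - 1
    %v = load %A[%clipped, %j] : memref<7x42xf32>
    // affine.load %A[%clipped, %j] would be illegal here, since %clipped is
    // not an affine expression of the surrounding loop induction variables.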

PiperOrigin-RevId: 260110503
Nicolas Vasilache 2019-07-26 02:33:58 -07:00 committed by A. Unique TensorFlower
parent 0f1624697b
commit fae4d94990
3 changed files with 12 additions and 6 deletions

@@ -198,6 +198,8 @@ using dim = ValueBuilder<DimOp>;
 using muli = ValueBuilder<MulIOp>;
 using ret = OperationBuilder<ReturnOp>;
 using select = ValueBuilder<SelectOp>;
+using std_load = ValueBuilder<LoadOp>;
+using std_store = OperationBuilder<StoreOp>;
 using subi = ValueBuilder<SubIOp>;
 using vector_type_cast = ValueBuilder<VectorTypeCastOp>;

@@ -263,6 +263,8 @@ VectorTransferRewriter<VectorTransferReadOp>::matchAndRewrite(
 using namespace mlir::edsc;
 using namespace mlir::edsc::op;
 using namespace mlir::edsc::intrinsics;
+using IndexedValue =
+    TemplatedIndexedValue<intrinsics::std_load, intrinsics::std_store>;
 VectorTransferReadOp transfer = cast<VectorTransferReadOp>(op);
@@ -289,7 +291,7 @@ VectorTransferRewriter<VectorTransferReadOp>::matchAndRewrite(
 // Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
 local(ivs) = remote(clip(transfer, view, ivs));
 });
-ValueHandle vectorValue = affine_load(vec, {constant_index(0)});
+ValueHandle vectorValue = std_load(vec, {constant_index(0)});
 (dealloc(tmp)); // vexing parse
 // 3. Propagate.
@@ -322,6 +324,8 @@ VectorTransferRewriter<VectorTransferWriteOp>::matchAndRewrite(
 using namespace mlir::edsc;
 using namespace mlir::edsc::op;
 using namespace mlir::edsc::intrinsics;
+using IndexedValue =
+    TemplatedIndexedValue<intrinsics::std_load, intrinsics::std_store>;
 VectorTransferWriteOp transfer = cast<VectorTransferWriteOp>(op);
@@ -345,7 +349,7 @@ VectorTransferRewriter<VectorTransferWriteOp>::matchAndRewrite(
 ValueHandle tmp = alloc(tmpMemRefType(transfer));
 IndexedValue local(tmp);
 ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer));
-affine_store(vectorValue, vec, {constant_index(0)});
+std_store(vectorValue, vec, {constant_index(0)});
 LoopNestBuilder(pivs, lbs, ubs, steps)([&] {
 // Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
 remote(clip(transfer, view, ivs)) = local(ivs);
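
With IndexedValue instantiated over std_load/std_store above, the EDSC assignments in these copy loops (local(ivs) = remote(clip(...)) on the read path, remote(clip(...)) = local(ivs) here) should now materialize as standard load/store pairs. A hedged sketch of one iteration of the write-path copy, mirroring the updated CHECK lines in the test below (SSA names are invented):

    %e = load %buf[%i6, %i5, %i4] : memref<5x4x3xf32>          // read the staged scalar
    store %e, %dst[%s0, %s1, %s2, %s3] : memref<?x?x?x?xf32>   // store to the clipped remote position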

@@ -20,7 +20,7 @@ func @materialize_read_1d() {
 // CHECK: %[[FILTERED1:.*]] = select
 // CHECK: {{.*}} = select
 // CHECK: %[[FILTERED2:.*]] = select
-// CHECK-NEXT: %{{.*}} = affine.load {{.*}}[%[[FILTERED1]], %[[FILTERED2]]] : memref<7x42xf32>
+// CHECK-NEXT: %{{.*}} = load {{.*}}[%[[FILTERED1]], %[[FILTERED2]]] : memref<7x42xf32>
 }
 }
 return
@@ -94,12 +94,12 @@ func @materialize_read(%M: index, %N: index, %O: index, %P: index) {
 // CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
 // CHECK-NEXT: %[[L3:.*]] = select
 //
-// CHECK-NEXT: {{.*}} = affine.load %{{.*}}[%[[L0]], %[[L1]], %[[L2]], %[[L3]]] : memref<?x?x?x?xf32>
+// CHECK-NEXT: {{.*}} = load %{{.*}}[%[[L0]], %[[L1]], %[[L2]], %[[L3]]] : memref<?x?x?x?xf32>
 // CHECK-NEXT: store {{.*}}, %[[ALLOC]][%[[I6]], %[[I5]], %[[I4]]] : memref<5x4x3xf32>
 // CHECK-NEXT: }
 // CHECK-NEXT: }
 // CHECK-NEXT: }
-// CHECK: {{.*}} = affine.load %[[VECTOR_VIEW]][{{.*}}] : memref<1xvector<5x4x3xf32>>
+// CHECK: {{.*}} = load %[[VECTOR_VIEW]][{{.*}}] : memref<1xvector<5x4x3xf32>>
 // CHECK-NEXT: dealloc %[[ALLOC]] : memref<5x4x3xf32>
 // CHECK-NEXT: }
 // CHECK-NEXT: }
@@ -170,7 +170,7 @@ func @materialize_write(%M: index, %N: index, %O: index, %P: index) {
 // CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
 // CHECK-NEXT: %[[S3:.*]] = select {{.*}}, %[[C0]], {{.*}} : index
 //
-// CHECK-NEXT: {{.*}} = affine.load {{.*}}[%[[I6]], %[[I5]], %[[I4]]] : memref<5x4x3xf32>
+// CHECK-NEXT: {{.*}} = load {{.*}}[%[[I6]], %[[I5]], %[[I4]]] : memref<5x4x3xf32>
 // CHECK: store {{.*}}, {{.*}}[%[[S0]], %[[S1]], %[[S2]], %[[S3]]] : memref<?x?x?x?xf32>
 // CHECK-NEXT: }
 // CHECK-NEXT: }