Add StdIndexedValue to EDSC helpers

Add StdIndexedValue to EDSC helper so that we can use it
to generated std.load and std.store in EDSC.

Closes tensorflow/mlir#59

PiperOrigin-RevId: 261324965
This commit is contained in:
Diego Caballero 2019-08-02 08:23:48 -07:00 committed by A. Unique TensorFlower
parent 58e66d71e7
commit c19b72d3f3
2 changed files with 40 additions and 1 deletions

View File

@ -33,9 +33,12 @@ namespace edsc {
template <typename Load, typename Store> class TemplatedIndexedValue;
// By default, edsc::IndexedValue provides an index notation around the affine
// load and stores.
// load and stores. edsc::StdIndexedValue provides the standard load/store
// counterpart.
using IndexedValue =
TemplatedIndexedValue<intrinsics::affine_load, intrinsics::affine_store>;
using StdIndexedValue =
TemplatedIndexedValue<intrinsics::std_load, intrinsics::std_store>;
// Base class for MemRefView and VectorView.
class View {

View File

@ -603,6 +603,42 @@ memref<?x?x?xf32>, index, index, index) -> ()
f.erase();
}
*/
// Exercise StdIndexedValue for loads and stores.
TEST_FUNC(indirect_access) {
using namespace edsc;
using namespace edsc::intrinsics;
using namespace edsc::op;
auto memrefType =
MemRefType::get({-1}, FloatType::getF32(&globalContext()), {}, 0);
auto f = makeFunction("indirect_access", {},
{memrefType, memrefType, memrefType, memrefType});
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
ValueHandle zero = constant_index(0);
MemRefView vC(f.getArgument(2));
IndexedValue B(f.getArgument(1)), D(f.getArgument(3));
StdIndexedValue A(f.getArgument(0)), C(f.getArgument(2));
IndexHandle i, N(vC.ub(0));
// clang-format off
LoopBuilder(&i, zero, N, 1)([&]{
C((ValueHandle)D(i)) = A((ValueHandle)B(i));
});
// clang-format on
// clang-format off
// CHECK-LABEL: func @indirect_access(
// CHECK: [[B:%.*]] = affine.load
// CHECK: [[D:%.*]] = affine.load
// CHECK: load %{{.*}}{{\[}}[[B]]{{\]}}
// CHECK: store %{{.*}}, %{{.*}}{{\[}}[[D]]{{\]}}
// clang-format on
f.print(llvm::outs());
f.erase();
}
int main() {
RUN_TESTS();
return 0;