Fix all-reduce int tests by host-registering memrefs.

Reduce amount of boiler plate to register host memory.

Summary: Fix all-reduce int tests by host-registering memrefs.

Reviewers: herhut

Reviewed By: herhut

Subscribers: clementval, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D76563
This commit is contained in:
Christian Sigg 2020-03-22 20:18:23 +01:00
parent ea64ee0edb
commit b43ae21e60
7 changed files with 110 additions and 36 deletions

View File

@ -2,9 +2,7 @@
func @main() {
%data = alloc() : memref<2x6xi32>
%sum_and = alloc() : memref<2xi32>
%sum_or = alloc() : memref<2xi32>
%sum_min = alloc() : memref<2xi32>
%sum = alloc() : memref<2xi32>
%cst0 = constant 0 : i32
%cst1 = constant 1 : i32
%cst2 = constant 2 : i32
@ -25,7 +23,12 @@ func @main() {
%c4 = constant 4 : index
%c5 = constant 5 : index
%c6 = constant 6 : index
%cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
%cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
store %cst2, %data[%c0, %c2] : memref<2x6xi32>
@ -44,17 +47,19 @@ func @main() {
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c2, %grid_y = %c1, %grid_z = %c1)
threads(%tx, %ty, %tz) in (%block_x = %c6, %block_y = %c1, %block_z = %c1) {
%val = load %data[%bx, %tx] : memref<2x6xi32>
%reduced_and = "gpu.all_reduce"(%val) ({}) { op = "and" } : (i32) -> (i32)
store %reduced_and, %sum_and[%bx] : memref<2xi32>
%reduced = "gpu.all_reduce"(%val) ({}) { op = "and" } : (i32) -> (i32)
store %reduced, %sum[%bx] : memref<2xi32>
gpu.terminator
}
%ptr_and = memref_cast %sum_and : memref<2xi32> to memref<*xi32>
call @print_memref_i32(%ptr_and) : (memref<*xi32>) -> ()
%ptr = memref_cast %sum : memref<2xi32> to memref<*xi32>
call @print_memref_i32(%ptr) : (memref<*xi32>) -> ()
// CHECK: [0, 2]
return
}
func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
func @print_memref_i32(memref<*xi32>)

View File

@ -23,7 +23,12 @@ func @main() {
%c4 = constant 4 : index
%c5 = constant 5 : index
%c6 = constant 6 : index
%cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
%cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
store %cst2, %data[%c0, %c2] : memref<2x6xi32>
@ -54,5 +59,7 @@ func @main() {
return
}
func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
func @print_memref_i32(memref<*xi32>)

View File

@ -23,7 +23,12 @@ func @main() {
%c4 = constant 4 : index
%c5 = constant 5 : index
%c6 = constant 6 : index
%cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
%cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
store %cst2, %data[%c0, %c2] : memref<2x6xi32>
@ -54,5 +59,7 @@ func @main() {
return
}
func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
func @print_memref_i32(memref<*xi32>)

View File

@ -23,7 +23,12 @@ func @main() {
%c4 = constant 4 : index
%c5 = constant 5 : index
%c6 = constant 6 : index
%cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
%cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
store %cst2, %data[%c0, %c2] : memref<2x6xi32>
@ -54,5 +59,7 @@ func @main() {
return
}
func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
func @print_memref_i32(memref<*xi32>)

View File

@ -23,7 +23,12 @@ func @main() {
%c4 = constant 4 : index
%c5 = constant 5 : index
%c6 = constant 6 : index
%cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
%cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
store %cst2, %data[%c0, %c2] : memref<2x6xi32>
@ -54,5 +59,7 @@ func @main() {
return
}
func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
func @print_memref_i32(memref<*xi32>)

View File

@ -25,6 +25,13 @@ func @main() {
%c5 = constant 5 : index
%c6 = constant 6 : index
%cast_data = memref_cast %data : memref<2x6xf32> to memref<?x?xf32>
call @mcuMemHostRegisterMemRef2dFloat(%cast_data) : (memref<?x?xf32>) -> ()
%cast_sum = memref_cast %sum : memref<2xf32> to memref<?xf32>
call @mcuMemHostRegisterMemRef1dFloat(%cast_sum) : (memref<?xf32>) -> ()
%cast_mul = memref_cast %mul : memref<2xf32> to memref<?xf32>
call @mcuMemHostRegisterMemRef1dFloat(%cast_mul) : (memref<?xf32>) -> ()
store %cst0, %data[%c0, %c0] : memref<2x6xf32>
store %cst1, %data[%c0, %c1] : memref<2x6xf32>
store %cst2, %data[%c0, %c2] : memref<2x6xf32>
@ -61,4 +68,6 @@ func @main() {
return
}
func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref<?xf32>)
func @mcuMemHostRegisterMemRef2dFloat(%ptr : memref<?x?xf32>)
func @print_memref_f32(memref<*xf32>)

View File

@ -15,6 +15,7 @@
#include <cassert>
#include <numeric>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/raw_ostream.h"
#include "cuda.h"
@ -89,24 +90,39 @@ template <typename T, int N> struct MemRefType {
// Allows to register a MemRef with the CUDA runtime. Initializes array with
// value. Helpful until we have transfer functions implemented.
template <typename T, int N>
void mcuMemHostRegisterMemRef(const MemRefType<T, N> *arg, T value) {
auto count = std::accumulate(arg->sizes, arg->sizes + N, 1,
std::multiplies<int64_t>());
std::fill_n(arg->data, count, value);
mcuMemHostRegister(arg->data, count * sizeof(T));
template <typename T>
void mcuMemHostRegisterMemRef(T *pointer, llvm::ArrayRef<int64_t> sizes,
llvm::ArrayRef<int64_t> strides, T value) {
assert(sizes.size() == strides.size());
llvm::SmallVector<int64_t, 4> denseStrides(strides.size());
std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
std::multiplies<int64_t>());
auto count = denseStrides.front();
// Only densely packed tensors are currently supported.
std::rotate(denseStrides.begin(), denseStrides.begin() + 1,
denseStrides.end());
denseStrides.back() = 1;
assert(strides == llvm::makeArrayRef(denseStrides));
std::fill_n(pointer, count, value);
mcuMemHostRegister(pointer, count * sizeof(T));
}
extern "C" void mcuMemHostRegisterMemRef1dFloat(float *allocated,
float *aligned, int64_t offset,
int64_t size, int64_t stride) {
MemRefType<float, 1> descriptor;
descriptor.basePtr = allocated;
descriptor.data = aligned;
descriptor.offset = offset;
descriptor.sizes[0] = size;
descriptor.strides[0] = stride;
mcuMemHostRegisterMemRef(&descriptor, 1.23f);
mcuMemHostRegisterMemRef(aligned + offset, {size}, {stride}, 1.23f);
}
extern "C" void mcuMemHostRegisterMemRef2dFloat(float *allocated,
float *aligned, int64_t offset,
int64_t size0, int64_t size1,
int64_t stride0,
int64_t stride1) {
mcuMemHostRegisterMemRef(aligned + offset, {size0, size1}, {stride0, stride1},
1.23f);
}
extern "C" void mcuMemHostRegisterMemRef3dFloat(float *allocated,
@ -115,15 +131,31 @@ extern "C" void mcuMemHostRegisterMemRef3dFloat(float *allocated,
int64_t size2, int64_t stride0,
int64_t stride1,
int64_t stride2) {
MemRefType<float, 3> descriptor;
descriptor.basePtr = allocated;
descriptor.data = aligned;
descriptor.offset = offset;
descriptor.sizes[0] = size0;
descriptor.strides[0] = stride0;
descriptor.sizes[1] = size1;
descriptor.strides[1] = stride1;
descriptor.sizes[2] = size2;
descriptor.strides[2] = stride2;
mcuMemHostRegisterMemRef(&descriptor, 1.23f);
mcuMemHostRegisterMemRef(aligned + offset, {size0, size1, size2},
{stride0, stride1, stride2}, 1.23f);
}
extern "C" void mcuMemHostRegisterMemRef1dInt32(int32_t *allocated,
int32_t *aligned,
int64_t offset, int64_t size,
int64_t stride) {
mcuMemHostRegisterMemRef(aligned + offset, {size}, {stride}, 123);
}
extern "C" void mcuMemHostRegisterMemRef2dInt32(int32_t *allocated,
int32_t *aligned,
int64_t offset, int64_t size0,
int64_t size1, int64_t stride0,
int64_t stride1) {
mcuMemHostRegisterMemRef(aligned + offset, {size0, size1}, {stride0, stride1},
123);
}
extern "C" void
mcuMemHostRegisterMemRef3dInt32(int32_t *allocated, int32_t *aligned,
int64_t offset, int64_t size0, int64_t size1,
int64_t size2, int64_t stride0, int64_t stride1,
int64_t stride2) {
mcuMemHostRegisterMemRef(aligned + offset, {size0, size1, size2},
{stride0, stride1, stride2}, 123);
}