forked from OSchip/llvm-project
Fix all-reduce int tests by host-registering memrefs.
Reduce amount of boiler plate to register host memory. Summary: Fix all-reduce int tests by host-registering memrefs. Reviewers: herhut Reviewed By: herhut Subscribers: clementval, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D76563
This commit is contained in:
parent
ea64ee0edb
commit
b43ae21e60
|
@ -2,9 +2,7 @@
|
|||
|
||||
func @main() {
|
||||
%data = alloc() : memref<2x6xi32>
|
||||
%sum_and = alloc() : memref<2xi32>
|
||||
%sum_or = alloc() : memref<2xi32>
|
||||
%sum_min = alloc() : memref<2xi32>
|
||||
%sum = alloc() : memref<2xi32>
|
||||
%cst0 = constant 0 : i32
|
||||
%cst1 = constant 1 : i32
|
||||
%cst2 = constant 2 : i32
|
||||
|
@ -25,7 +23,12 @@ func @main() {
|
|||
%c4 = constant 4 : index
|
||||
%c5 = constant 5 : index
|
||||
%c6 = constant 6 : index
|
||||
|
||||
|
||||
%cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
|
||||
call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
|
||||
%cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
|
||||
call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
|
||||
|
||||
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
|
||||
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
|
||||
store %cst2, %data[%c0, %c2] : memref<2x6xi32>
|
||||
|
@ -44,17 +47,19 @@ func @main() {
|
|||
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c2, %grid_y = %c1, %grid_z = %c1)
|
||||
threads(%tx, %ty, %tz) in (%block_x = %c6, %block_y = %c1, %block_z = %c1) {
|
||||
%val = load %data[%bx, %tx] : memref<2x6xi32>
|
||||
%reduced_and = "gpu.all_reduce"(%val) ({}) { op = "and" } : (i32) -> (i32)
|
||||
store %reduced_and, %sum_and[%bx] : memref<2xi32>
|
||||
%reduced = "gpu.all_reduce"(%val) ({}) { op = "and" } : (i32) -> (i32)
|
||||
store %reduced, %sum[%bx] : memref<2xi32>
|
||||
gpu.terminator
|
||||
}
|
||||
|
||||
%ptr_and = memref_cast %sum_and : memref<2xi32> to memref<*xi32>
|
||||
call @print_memref_i32(%ptr_and) : (memref<*xi32>) -> ()
|
||||
%ptr = memref_cast %sum : memref<2xi32> to memref<*xi32>
|
||||
call @print_memref_i32(%ptr) : (memref<*xi32>) -> ()
|
||||
// CHECK: [0, 2]
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
|
||||
func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
|
||||
func @print_memref_i32(memref<*xi32>)
|
||||
|
||||
|
|
|
@ -23,7 +23,12 @@ func @main() {
|
|||
%c4 = constant 4 : index
|
||||
%c5 = constant 5 : index
|
||||
%c6 = constant 6 : index
|
||||
|
||||
|
||||
%cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
|
||||
call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
|
||||
%cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
|
||||
call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
|
||||
|
||||
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
|
||||
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
|
||||
store %cst2, %data[%c0, %c2] : memref<2x6xi32>
|
||||
|
@ -54,5 +59,7 @@ func @main() {
|
|||
return
|
||||
}
|
||||
|
||||
func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
|
||||
func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
|
||||
func @print_memref_i32(memref<*xi32>)
|
||||
|
||||
|
|
|
@ -23,7 +23,12 @@ func @main() {
|
|||
%c4 = constant 4 : index
|
||||
%c5 = constant 5 : index
|
||||
%c6 = constant 6 : index
|
||||
|
||||
|
||||
%cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
|
||||
call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
|
||||
%cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
|
||||
call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
|
||||
|
||||
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
|
||||
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
|
||||
store %cst2, %data[%c0, %c2] : memref<2x6xi32>
|
||||
|
@ -54,5 +59,7 @@ func @main() {
|
|||
return
|
||||
}
|
||||
|
||||
func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
|
||||
func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
|
||||
func @print_memref_i32(memref<*xi32>)
|
||||
|
||||
|
|
|
@ -23,7 +23,12 @@ func @main() {
|
|||
%c4 = constant 4 : index
|
||||
%c5 = constant 5 : index
|
||||
%c6 = constant 6 : index
|
||||
|
||||
|
||||
%cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
|
||||
call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
|
||||
%cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
|
||||
call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
|
||||
|
||||
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
|
||||
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
|
||||
store %cst2, %data[%c0, %c2] : memref<2x6xi32>
|
||||
|
@ -54,5 +59,7 @@ func @main() {
|
|||
return
|
||||
}
|
||||
|
||||
func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
|
||||
func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
|
||||
func @print_memref_i32(memref<*xi32>)
|
||||
|
||||
|
|
|
@ -23,7 +23,12 @@ func @main() {
|
|||
%c4 = constant 4 : index
|
||||
%c5 = constant 5 : index
|
||||
%c6 = constant 6 : index
|
||||
|
||||
|
||||
%cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
|
||||
call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
|
||||
%cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
|
||||
call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
|
||||
|
||||
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
|
||||
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
|
||||
store %cst2, %data[%c0, %c2] : memref<2x6xi32>
|
||||
|
@ -54,5 +59,7 @@ func @main() {
|
|||
return
|
||||
}
|
||||
|
||||
func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
|
||||
func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
|
||||
func @print_memref_i32(memref<*xi32>)
|
||||
|
||||
|
|
|
@ -25,6 +25,13 @@ func @main() {
|
|||
%c5 = constant 5 : index
|
||||
%c6 = constant 6 : index
|
||||
|
||||
%cast_data = memref_cast %data : memref<2x6xf32> to memref<?x?xf32>
|
||||
call @mcuMemHostRegisterMemRef2dFloat(%cast_data) : (memref<?x?xf32>) -> ()
|
||||
%cast_sum = memref_cast %sum : memref<2xf32> to memref<?xf32>
|
||||
call @mcuMemHostRegisterMemRef1dFloat(%cast_sum) : (memref<?xf32>) -> ()
|
||||
%cast_mul = memref_cast %mul : memref<2xf32> to memref<?xf32>
|
||||
call @mcuMemHostRegisterMemRef1dFloat(%cast_mul) : (memref<?xf32>) -> ()
|
||||
|
||||
store %cst0, %data[%c0, %c0] : memref<2x6xf32>
|
||||
store %cst1, %data[%c0, %c1] : memref<2x6xf32>
|
||||
store %cst2, %data[%c0, %c2] : memref<2x6xf32>
|
||||
|
@ -61,4 +68,6 @@ func @main() {
|
|||
return
|
||||
}
|
||||
|
||||
func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref<?xf32>)
|
||||
func @mcuMemHostRegisterMemRef2dFloat(%ptr : memref<?x?xf32>)
|
||||
func @print_memref_f32(memref<*xf32>)
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include <cassert>
|
||||
#include <numeric>
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include "cuda.h"
|
||||
|
@ -89,24 +90,39 @@ template <typename T, int N> struct MemRefType {
|
|||
|
||||
// Allows to register a MemRef with the CUDA runtime. Initializes array with
|
||||
// value. Helpful until we have transfer functions implemented.
|
||||
template <typename T, int N>
|
||||
void mcuMemHostRegisterMemRef(const MemRefType<T, N> *arg, T value) {
|
||||
auto count = std::accumulate(arg->sizes, arg->sizes + N, 1,
|
||||
std::multiplies<int64_t>());
|
||||
std::fill_n(arg->data, count, value);
|
||||
mcuMemHostRegister(arg->data, count * sizeof(T));
|
||||
template <typename T>
|
||||
void mcuMemHostRegisterMemRef(T *pointer, llvm::ArrayRef<int64_t> sizes,
|
||||
llvm::ArrayRef<int64_t> strides, T value) {
|
||||
assert(sizes.size() == strides.size());
|
||||
llvm::SmallVector<int64_t, 4> denseStrides(strides.size());
|
||||
|
||||
std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
|
||||
std::multiplies<int64_t>());
|
||||
auto count = denseStrides.front();
|
||||
|
||||
// Only densely packed tensors are currently supported.
|
||||
std::rotate(denseStrides.begin(), denseStrides.begin() + 1,
|
||||
denseStrides.end());
|
||||
denseStrides.back() = 1;
|
||||
assert(strides == llvm::makeArrayRef(denseStrides));
|
||||
|
||||
std::fill_n(pointer, count, value);
|
||||
mcuMemHostRegister(pointer, count * sizeof(T));
|
||||
}
|
||||
|
||||
extern "C" void mcuMemHostRegisterMemRef1dFloat(float *allocated,
|
||||
float *aligned, int64_t offset,
|
||||
int64_t size, int64_t stride) {
|
||||
MemRefType<float, 1> descriptor;
|
||||
descriptor.basePtr = allocated;
|
||||
descriptor.data = aligned;
|
||||
descriptor.offset = offset;
|
||||
descriptor.sizes[0] = size;
|
||||
descriptor.strides[0] = stride;
|
||||
mcuMemHostRegisterMemRef(&descriptor, 1.23f);
|
||||
mcuMemHostRegisterMemRef(aligned + offset, {size}, {stride}, 1.23f);
|
||||
}
|
||||
|
||||
extern "C" void mcuMemHostRegisterMemRef2dFloat(float *allocated,
|
||||
float *aligned, int64_t offset,
|
||||
int64_t size0, int64_t size1,
|
||||
int64_t stride0,
|
||||
int64_t stride1) {
|
||||
mcuMemHostRegisterMemRef(aligned + offset, {size0, size1}, {stride0, stride1},
|
||||
1.23f);
|
||||
}
|
||||
|
||||
extern "C" void mcuMemHostRegisterMemRef3dFloat(float *allocated,
|
||||
|
@ -115,15 +131,31 @@ extern "C" void mcuMemHostRegisterMemRef3dFloat(float *allocated,
|
|||
int64_t size2, int64_t stride0,
|
||||
int64_t stride1,
|
||||
int64_t stride2) {
|
||||
MemRefType<float, 3> descriptor;
|
||||
descriptor.basePtr = allocated;
|
||||
descriptor.data = aligned;
|
||||
descriptor.offset = offset;
|
||||
descriptor.sizes[0] = size0;
|
||||
descriptor.strides[0] = stride0;
|
||||
descriptor.sizes[1] = size1;
|
||||
descriptor.strides[1] = stride1;
|
||||
descriptor.sizes[2] = size2;
|
||||
descriptor.strides[2] = stride2;
|
||||
mcuMemHostRegisterMemRef(&descriptor, 1.23f);
|
||||
mcuMemHostRegisterMemRef(aligned + offset, {size0, size1, size2},
|
||||
{stride0, stride1, stride2}, 1.23f);
|
||||
}
|
||||
|
||||
extern "C" void mcuMemHostRegisterMemRef1dInt32(int32_t *allocated,
|
||||
int32_t *aligned,
|
||||
int64_t offset, int64_t size,
|
||||
int64_t stride) {
|
||||
mcuMemHostRegisterMemRef(aligned + offset, {size}, {stride}, 123);
|
||||
}
|
||||
|
||||
extern "C" void mcuMemHostRegisterMemRef2dInt32(int32_t *allocated,
|
||||
int32_t *aligned,
|
||||
int64_t offset, int64_t size0,
|
||||
int64_t size1, int64_t stride0,
|
||||
int64_t stride1) {
|
||||
mcuMemHostRegisterMemRef(aligned + offset, {size0, size1}, {stride0, stride1},
|
||||
123);
|
||||
}
|
||||
|
||||
extern "C" void
|
||||
mcuMemHostRegisterMemRef3dInt32(int32_t *allocated, int32_t *aligned,
|
||||
int64_t offset, int64_t size0, int64_t size1,
|
||||
int64_t size2, int64_t stride0, int64_t stride1,
|
||||
int64_t stride2) {
|
||||
mcuMemHostRegisterMemRef(aligned + offset, {size0, size1, size2},
|
||||
{stride0, stride1, stride2}, 123);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue