update buffer sample gpu

This commit is contained in:
VectorSL 2021-08-13 20:54:04 +08:00
parent 2151b927ba
commit 680d319290
10 changed files with 85 additions and 58 deletions

View File

@ -16,6 +16,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RL_BUFFER_SAMPLE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RL_BUFFER_SAMPLE_CPU_KERNEL_H_
#include <stdlib.h>
#include <memory>
#include <string>
#include <vector>
@ -27,13 +28,14 @@ namespace mindspore {
namespace kernel {
class BufferCPUSampleKernel : public CPUKernel {
public:
BufferCPUSampleKernel() : element_nums_(0), capacity_(0), batch_size_(0), exp_size_(0) {}
BufferCPUSampleKernel() : element_nums_(0), capacity_(0), batch_size_(0), exp_size_(0), seed_(0) {}
~BufferCPUSampleKernel() override = default;
void Init(const CNodePtr &kernel_node) {
auto shapes = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "buffer_elements");
auto types = AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "buffer_dtype");
capacity_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "capacity");
seed_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "seed");
batch_size_ = LongToSize(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "batch_size"));
element_nums_ = shapes.size();
for (size_t i = 0; i < element_nums_; i++) {
@ -45,8 +47,6 @@ class BufferCPUSampleKernel : public CPUKernel {
output_size_list_.push_back(i * batch_size_);
exp_size_ += i;
}
// index
input_size_list_.push_back(sizeof(int) * batch_size_);
// count and head
input_size_list_.push_back(sizeof(int));
input_size_list_.push_back(sizeof(int));
@ -54,18 +54,29 @@ class BufferCPUSampleKernel : public CPUKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
auto indexes_addr = GetDeviceAddress<int>(inputs, element_nums_);
auto count_addr = GetDeviceAddress<int>(inputs, element_nums_ + 1);
auto head_addr = GetDeviceAddress<int>(inputs, element_nums_ + 2);
auto count_addr = GetDeviceAddress<int>(inputs, element_nums_);
auto head_addr = GetDeviceAddress<int>(inputs, element_nums_ + 1);
if ((head_addr[0] > 0 && SizeToLong(batch_size_) > capacity_) ||
(head_addr[0] == 0 && SizeToLong(batch_size_) > count_addr[0])) {
MS_LOG(ERROR) << "The batch size " << batch_size_ << " is larger than total buffer size "
<< std::min(capacity_, IntToLong(count_addr[0]));
}
// Generate random indexes
std::vector<size_t> indexes;
for (size_t i = 0; i < IntToSize(count_addr[0]); ++i) {
indexes.push_back(i);
}
if (seed_ == 0) {
std::srand(time(nullptr));
} else {
std::srand(seed_);
}
random_shuffle(indexes.begin(), indexes.end());
auto task = [&](size_t start, size_t end) {
for (size_t j = start; j < end; j++) {
int64_t index = IntToSize(indexes_addr[j]);
size_t index = indexes[j];
for (size_t i = 0; i < element_nums_; i++) {
auto buffer_addr = GetDeviceAddress<unsigned char>(inputs, i);
auto output_addr = GetDeviceAddress<unsigned char>(outputs, i);
@ -92,6 +103,7 @@ class BufferCPUSampleKernel : public CPUKernel {
int64_t capacity_;
size_t batch_size_;
int64_t exp_size_;
int64_t seed_;
std::vector<size_t> exp_element_list;
};
} // namespace kernel

View File

@ -91,6 +91,14 @@ __global__ void BufferSampleKernel(const size_t size, const size_t one_element,
}
}
__global__ void SrandUniformFloat(const int size, curandState *globalState, const int seedc, float *out) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
curand_init(seedc, threadIdx.x, 0, &globalState[i]);
out[i] = curand_uniform(&globalState[i]);
}
__syncthreads();
}
void BufferAppend(const int64_t capacity, const size_t size, const int *index, const int exp_batch,
unsigned char *buffer, const unsigned char *exp, cudaStream_t cuda_stream) {
BufferAppendKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(capacity, size, index, exp_batch, buffer, exp);
@ -119,3 +127,7 @@ void BufferSample(const size_t size, const size_t one_element, const int *index,
unsigned char *out, cudaStream_t cuda_stream) {
BufferSampleKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, one_element, index, buffer, out);
}
void RandomGen(const int size, curandState *globalState, const int &seedc, float *out, cudaStream_t stream) {
SrandUniformFloat<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(size, globalState, seedc, out);
}

View File

@ -16,7 +16,7 @@
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RL_BUFFER_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RL_BUFFER_IMPL_H_
#include <curand_kernel.h>
#include "runtime/device/gpu/cuda_common.h"
void BufferAppend(const int64_t capacity, const size_t size, const int *index, const int exp_batch,
unsigned char *buffer, const unsigned char *exp, cudaStream_t cuda_stream);
@ -29,5 +29,5 @@ void CheckBatchSize(const int *count, const int *head, const size_t batch_size,
cudaStream_t cuda_stream);
void BufferSample(const size_t size, const size_t one_element, const int *index, const unsigned char *buffer,
unsigned char *out, cudaStream_t cuda_stream);
void RandomGen(const int size, curandState *globalState, const int &seedc, float *out, cudaStream_t stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_

View File

@ -19,15 +19,18 @@
#include <string>
#include <vector>
#include <algorithm>
#include <limits>
#include <chrono>
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh"
#include "runtime/device/gpu/gpu_common.h"
namespace mindspore {
namespace kernel {
BufferSampleKernel::BufferSampleKernel() : element_nums_(0), capacity_(0), batch_size_(0) {}
BufferSampleKernel::BufferSampleKernel() : element_nums_(0), capacity_(0), batch_size_(0), seed_(0) {}
BufferSampleKernel::~BufferSampleKernel() {}
@ -44,6 +47,7 @@ bool BufferSampleKernel::Init(const CNodePtr &kernel_node) {
auto shapes = GetAttr<std::vector<int64_t>>(kernel_node, "buffer_elements");
auto types = GetAttr<std::vector<TypePtr>>(kernel_node, "buffer_dtype");
capacity_ = GetAttr<int64_t>(kernel_node, "capacity");
seed_ = GetAttr<int64_t>(kernel_node, "seed");
batch_size_ = LongToSize(GetAttr<int64_t>(kernel_node, "batch_size"));
element_nums_ = shapes.size();
for (size_t i = 0; i < element_nums_; i++) {
@ -52,28 +56,52 @@ bool BufferSampleKernel::Init(const CNodePtr &kernel_node) {
input_size_list_.push_back(capacity_ * element);
output_size_list_.push_back(batch_size_ * element);
}
// index
input_size_list_.push_back(sizeof(int) * batch_size_);
// count and head
input_size_list_.push_back(sizeof(int));
input_size_list_.push_back(sizeof(int));
workspace_size_list_.push_back(capacity_ * sizeof(curandState));
workspace_size_list_.push_back(capacity_ * sizeof(float));
workspace_size_list_.push_back(capacity_ * sizeof(int));
workspace_size_list_.push_back(capacity_ * sizeof(float));
return true;
}
void BufferSampleKernel::InitSizeLists() { return; }
bool BufferSampleKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
bool BufferSampleKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
const std::vector<AddressPtr> &outputs, void *stream) {
int *index_addr = GetDeviceAddress<int>(inputs, element_nums_);
int *count_addr = GetDeviceAddress<int>(inputs, element_nums_ + 1);
int *head_addr = GetDeviceAddress<int>(inputs, element_nums_ + 2);
int *count_addr = GetDeviceAddress<int>(inputs, element_nums_);
int *head_addr = GetDeviceAddress<int>(inputs, element_nums_ + 1);
auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
CheckBatchSize(count_addr, head_addr, batch_size_, capacity_, cuda_stream);
int k_cut = 0;
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(&k_cut, count_addr, sizeof(int), cudaMemcpyDeviceToHost, cuda_stream),
"sync dev to host failed");
// 1 Generate random floats
auto States = GetDeviceAddress<void *>(workspaces, 0);
auto random_f = GetDeviceAddress<float>(workspaces, 1);
auto indexes = GetDeviceAddress<int>(workspaces, 2);
auto useless_out = GetDeviceAddress<float>(workspaces, 3);
int seedc = 0;
if (seed_ == 0) {
generator_.seed(std::chrono::system_clock::now().time_since_epoch().count());
seedc = generator_();
} else {
seedc = seed_;
}
float init_k = std::numeric_limits<float>::lowest();
curandState *devStates = reinterpret_cast<curandState *>(States);
RandomGen(k_cut, devStates, seedc, random_f, cuda_stream);
// 2 Sort the random floats, and get the sorted indexes as the random indexes
FastTopK(1, k_cut, random_f, k_cut, useless_out, indexes, init_k, cuda_stream);
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(), "cudaDeviceSync failed, sample-topk");
for (size_t i = 0; i < element_nums_; i++) {
auto buffer_addr = GetDeviceAddress<unsigned char>(inputs, i);
auto out_addr = GetDeviceAddress<unsigned char>(outputs, i);
size_t size = batch_size_ * exp_element_list[i];
BufferSample(size, exp_element_list[i], index_addr, buffer_addr, out_addr, cuda_stream);
BufferSample(size, exp_element_list[i], indexes, buffer_addr, out_addr, cuda_stream);
}
return true;
}

View File

@ -20,6 +20,7 @@
#include <memory>
#include <string>
#include <vector>
#include <random>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
@ -45,6 +46,8 @@ class BufferSampleKernel : public GpuKernel {
size_t element_nums_;
int64_t capacity_;
size_t batch_size_;
int64_t seed_;
std::mt19937 generator_;
std::vector<size_t> exp_element_list;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;

View File

@ -36,12 +36,12 @@ class BufferSample(PrimitiveWithInfer):
batch_size (int64): The size of the sampled data, lessequal to `capacity`.
buffer_shape (tuple(shape)): The shape of an buffer.
buffer_dtype (tuple(type)): The type of an buffer.
seed (int64): Random seed for sample. Default: 0. If use the default seed, it will generate a ramdom
one in kernel. Set a number other than `0` to keep a specific seed.
Inputs:
- **data** (tuple(Parameter(Tensor))) - The tuple(Tensor) represents replaybuffer,
each tensor is described by the `buffer_shape` and `buffer_type`.
- **indexes** (tuple(int32)) - The position list in replaybuffer,
the size equal to `batch_size`.
- **count** (Parameter) - The count mean the real available size of the buffer,
data type: int32.
- **head** (Parameter) - The position of the first data in buffer, data type: int32.
@ -69,8 +69,7 @@ class BufferSample(PrimitiveWithInfer):
Parameter(Tensor(np.ones((100, 1)).astype(np.int32)), name="reward"),
Parameter(Tensor(np.arange(100 * 4).reshape(100, 4).astype(np.float32)), name="state_")]
>>> buffer_sample = ops.BufferSample(capacity, batch_size, shapes, types)
>>> indexes = Parameter(Tensor([0, 2, 4, 3, 8], ms.int32), name="index")
>>> output = buffer_sample(buffer, indexes, count, head)
>>> output = buffer_sample(buffer, count, head)
>>> print(output)
(Tensor(shape=[5, 4], dtype=Float32, value=
[[ 0.00000000e+00, 1.00000000e+00, 2.00000000e+00, 3.00000000e+00],
@ -99,7 +98,7 @@ class BufferSample(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self, capacity, batch_size, buffer_shape, buffer_dtype):
def __init__(self, capacity, batch_size, buffer_shape, buffer_dtype, seed=0):
"""Initialize BufferSample."""
self.init_prim_io_names(inputs=["buffer"], outputs=["sample"])
validator.check_value_type("shape of init data", buffer_shape, [tuple, list], self.name)
@ -110,6 +109,7 @@ class BufferSample(PrimitiveWithInfer):
self._n = len(buffer_shape)
validator.check_int(self._batch_size, capacity, Rel.LE, "batchsize", self.name)
self.add_prim_attr('capacity', capacity)
self.add_prim_attr('seed', seed)
buffer_elements = []
for shape in buffer_shape:
buffer_elements.append(reduce(lambda x, y: x * y, shape))
@ -119,17 +119,16 @@ class BufferSample(PrimitiveWithInfer):
if context.get_context('device_target') == "Ascend":
self.add_prim_attr('device_target', "CPU")
def infer_shape(self, data_shape, index_shape, count_shape, head_shape):
def infer_shape(self, data_shape, count_shape, head_shape):
validator.check_value_type("shape of data", data_shape, [tuple, list], self.name)
out_shapes = []
for i in range(self._n):
out_shapes.append((self._batch_size,) + self._buffer_shape[i])
return tuple(out_shapes)
def infer_dtype(self, data_type, index_type, count_type, head_type):
def infer_dtype(self, data_type, count_type, head_type):
validator.check_type_name("count type", count_type, (mstype.int32), self.name)
validator.check_type_name("head type", head_type, (mstype.int32), self.name)
validator.check_type_name("index type", index_type, (mstype.int64, mstype.int32), self.name)
return tuple(self._buffer_dtype)
class BufferAppend(PrimitiveWithInfer):

View File

@ -45,9 +45,6 @@ class RLBuffer(nn.Cell):
self.buffer_get = P.BufferGetItem(self._capacity, shapes, types)
self.buffer_sample = P.BufferSample(
self._capacity, batch_size, shapes, types)
self.dummy_tensor = Tensor(np.ones(shape=[batch_size]), ms.bool_)
self.rnd_choice_mask = P.RandomChoiceWithMask(count=batch_size)
self.reshape = P.Reshape()
@ms_function
def append(self, exps):
@ -59,9 +56,7 @@ class RLBuffer(nn.Cell):
@ms_function
def sample(self):
index, _ = self.rnd_choice_mask(self.dummy_tensor)
index = self.reshape(index, (self._batch_size,))
return self.buffer_sample(self.buffer, index, self.count, self.head)
return self.buffer_sample(self.buffer, self.count, self.head)
s = Tensor(np.array([2, 2, 2, 2]), ms.float32)

View File

@ -55,16 +55,13 @@ class RLBufferSample(nn.Cell):
def __init__(self, capcity, batch_size, shapes, types):
super(RLBufferSample, self).__init__()
self._capacity = capcity
count = 5
self.count = Parameter(Tensor(5, ms.int32), name="count")
self.head = Parameter(Tensor(0, ms.int32), name="head")
self.input_x = Tensor(np.ones(shape=[count]), ms.bool_)
self.buffer_sample = P.BufferSample(self._capacity, batch_size, shapes, types)
self.index = Parameter(Tensor([0, 2, 4], ms.int32), name="index")
@ms_function
def construct(self, buffer):
return self.buffer_sample(buffer, self.index, self.count, self.head)
return self.buffer_sample(buffer, self.count, self.head)
states = Tensor(np.arange(4*5).reshape(5, 4).astype(np.float32)/10.0)
@ -93,14 +90,7 @@ def test_BufferSample():
buffer_sample = RLBufferSample(capcity=5, batch_size=3, shapes=[(4,), (2,), (1,), (4,)], types=[
ms.float32, ms.int32, ms.int32, ms.float32])
ss, aa, rr, ss_ = buffer_sample(b)
expect_s = [[0, 0.1, 0.2, 0.3], [0.8, 0.9, 1.0, 1.1], [1.6, 1.7, 1.8, 1.9]]
expect_a = [[0, 1], [4, 5], [8, 9]]
expect_r = [[1], [1], [1]]
expect_s_ = [[0, 1, 2, 3], [8, 9, 10, 11], [16, 17, 18, 19]]
np.testing.assert_almost_equal(ss.asnumpy(), expect_s)
np.testing.assert_almost_equal(aa.asnumpy(), expect_a)
np.testing.assert_almost_equal(rr.asnumpy(), expect_r)
np.testing.assert_almost_equal(ss_.asnumpy(), expect_s_)
print(ss, aa, rr, ss_)
@ pytest.mark.level0

View File

@ -56,9 +56,7 @@ class RLBuffer(nn.Cell):
@ms_function
def sample(self):
count = self.reshape(self.count, (1,))
index = self.randperm(count)
return self.buffer_sample(self.buffer, index, self.count, self.head)
return self.buffer_sample(self.buffer, self.count, self.head)
s = Tensor(np.array([2, 2, 2, 2]), ms.float32)

View File

@ -55,17 +55,14 @@ class RLBufferSample(nn.Cell):
def __init__(self, capcity, batch_size, shapes, types):
super(RLBufferSample, self).__init__()
self._capacity = capcity
count = 5
self.count = Parameter(Tensor(5, ms.int32), name="count")
self.head = Parameter(Tensor(0, ms.int32), name="head")
self.input_x = Tensor(np.ones(shape=[count]), ms.bool_)
self.buffer_sample = P.BufferSample(
self._capacity, batch_size, shapes, types)
self.index = Parameter(Tensor([0, 2, 4], ms.int32), name="index")
@ms_function
def construct(self, buffer):
return self.buffer_sample(buffer, self.index, self.count, self.head)
return self.buffer_sample(buffer, self.count, self.head)
states = Tensor(np.arange(4*5).reshape(5, 4).astype(np.float32)/10.0)
@ -94,14 +91,7 @@ def test_BufferSample():
buffer_sample = RLBufferSample(capcity=5, batch_size=3, shapes=[(4,), (2,), (1,), (4,)], types=[
ms.float32, ms.int32, ms.int32, ms.float32])
ss, aa, rr, ss_ = buffer_sample(b)
expect_s = [[0, 0.1, 0.2, 0.3], [0.8, 0.9, 1.0, 1.1], [1.6, 1.7, 1.8, 1.9]]
expect_a = [[0, 1], [4, 5], [8, 9]]
expect_r = [[1], [1], [1]]
expect_s_ = [[0, 1, 2, 3], [8, 9, 10, 11], [16, 17, 18, 19]]
np.testing.assert_almost_equal(ss.asnumpy(), expect_s)
np.testing.assert_almost_equal(aa.asnumpy(), expect_a)
np.testing.assert_almost_equal(rr.asnumpy(), expect_r)
np.testing.assert_almost_equal(ss_.asnumpy(), expect_s_)
print(ss, aa, rr, ss_)
@ pytest.mark.level0