diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_sample_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_sample_cpu_kernel.h
index 4a18c47b7c7..f3d611f95a2 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_sample_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_sample_cpu_kernel.h
@@ -16,6 +16,7 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RL_BUFFER_SAMPLE_CPU_KERNEL_H_
 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RL_BUFFER_SAMPLE_CPU_KERNEL_H_
+#include
 #include
 #include
 #include
@@ -27,13 +28,14 @@
 namespace mindspore {
 namespace kernel {
 class BufferCPUSampleKernel : public CPUKernel {
  public:
-  BufferCPUSampleKernel() : element_nums_(0), capacity_(0), batch_size_(0), exp_size_(0) {}
+  BufferCPUSampleKernel() : element_nums_(0), capacity_(0), batch_size_(0), exp_size_(0), seed_(0) {}
   ~BufferCPUSampleKernel() override = default;

   void Init(const CNodePtr &kernel_node) {
     auto shapes = AnfAlgo::GetNodeAttr>(kernel_node, "buffer_elements");
     auto types = AnfAlgo::GetNodeAttr>(kernel_node, "buffer_dtype");
     capacity_ = AnfAlgo::GetNodeAttr(kernel_node, "capacity");
+    seed_ = AnfAlgo::GetNodeAttr(kernel_node, "seed");
     batch_size_ = LongToSize(AnfAlgo::GetNodeAttr(kernel_node, "batch_size"));
     element_nums_ = shapes.size();
     for (size_t i = 0; i < element_nums_; i++) {
@@ -45,8 +47,6 @@ class BufferCPUSampleKernel : public CPUKernel {
       output_size_list_.push_back(i * batch_size_);
       exp_size_ += i;
     }
-    // index
-    input_size_list_.push_back(sizeof(int) * batch_size_);
     // count and head
     input_size_list_.push_back(sizeof(int));
     input_size_list_.push_back(sizeof(int));
@@ -54,18 +54,29 @@
   bool Launch(const std::vector &inputs, const std::vector &,
               const std::vector &outputs) {
-    auto indexes_addr = GetDeviceAddress(inputs, element_nums_);
-    auto count_addr = GetDeviceAddress(inputs, element_nums_ + 1);
-    auto head_addr = GetDeviceAddress(inputs, element_nums_ + 2);
+    auto count_addr = GetDeviceAddress(inputs, element_nums_);
+    auto head_addr = GetDeviceAddress(inputs, element_nums_ + 1);
     if ((head_addr[0] > 0 && SizeToLong(batch_size_) > capacity_) ||
         (head_addr[0] == 0 && SizeToLong(batch_size_) > count_addr[0])) {
       MS_LOG(ERROR) << "The batch size " << batch_size_ << " is larger than total buffer size "
                     << std::min(capacity_, IntToLong(count_addr[0]));
     }
+    // Generate random indexes
+    std::vector indexes;
+    for (size_t i = 0; i < IntToSize(count_addr[0]); ++i) {
+      indexes.push_back(i);
+    }
+    if (seed_ == 0) {
+      std::srand(time(nullptr));
+    } else {
+      std::srand(seed_);
+    }
+    random_shuffle(indexes.begin(), indexes.end());
+
     auto task = [&](size_t start, size_t end) {
       for (size_t j = start; j < end; j++) {
-        int64_t index = IntToSize(indexes_addr[j]);
+        size_t index = indexes[j];
         for (size_t i = 0; i < element_nums_; i++) {
           auto buffer_addr = GetDeviceAddress(inputs, i);
           auto output_addr = GetDeviceAddress(outputs, i);
@@ -92,6 +103,7 @@ class BufferCPUSampleKernel : public CPUKernel {
   int64_t capacity_;
   size_t batch_size_;
   int64_t exp_size_;
+  int64_t seed_;
   std::vector exp_element_list;
 };
 }  // namespace kernel
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cu
index e374c1eda24..7cfa49f6465 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cu
@@ -91,6 +91,14 @@ __global__ void BufferSampleKernel(const size_t size, const size_t one_element,
   }
 }

+__global__ void SrandUniformFloat(const int size, curandState *globalState, const int seedc, float *out) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    curand_init(seedc, threadIdx.x, 0, &globalState[i]);
+    out[i] = curand_uniform(&globalState[i]);
+  }
+  __syncthreads();
+}
+
 void BufferAppend(const int64_t capacity, const size_t size, const int *index, const int exp_batch,
                   unsigned char *buffer, const unsigned char *exp, cudaStream_t cuda_stream) {
   BufferAppendKernel<<>>(capacity, size, index, exp_batch, buffer, exp);
@@ -119,3 +127,7 @@ void BufferSample(const size_t size, const size_t one_element, const int *index,
                   unsigned char *out, cudaStream_t cuda_stream) {
   BufferSampleKernel<<>>(size, one_element, index, buffer, out);
 }
+
+void RandomGen(const int size, curandState *globalState, const int &seedc, float *out, cudaStream_t stream) {
+  SrandUniformFloat<<>>(size, globalState, seedc, out);
+}
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cuh
index 2f0e57a49dd..95e2846b384 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cuh
@@ -16,7 +16,7 @@
 #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RL_BUFFER_IMPL_H_
 #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RL_BUFFER_IMPL_H_
-
+#include
 #include "runtime/device/gpu/cuda_common.h"
 void BufferAppend(const int64_t capacity, const size_t size, const int *index, const int exp_batch,
                   unsigned char *buffer, const unsigned char *exp, cudaStream_t cuda_stream);
@@ -29,5 +29,5 @@ void CheckBatchSize(const int *count, const int *head, const size_t batch_size,
                     cudaStream_t cuda_stream);
 void BufferSample(const size_t size, const size_t one_element, const int *index, const unsigned char *buffer,
                   unsigned char *out, cudaStream_t cuda_stream);
-
+void RandomGen(const int size, curandState *globalState, const int &seedc, float *out, cudaStream_t stream);
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/rl/buffer_sample_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/rl/buffer_sample_gpu_kernel.cc
index 184ccfa9497..759ce95f195 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/rl/buffer_sample_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/rl/buffer_sample_gpu_kernel.cc
@@ -19,15 +19,18 @@
 #include
 #include
 #include
+#include
+#include
 #include "backend/kernel_compiler/common_utils.h"
 #include "backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cuh"
+#include "backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh"
 #include "runtime/device/gpu/gpu_common.h"

 namespace mindspore {
 namespace kernel {
-BufferSampleKernel::BufferSampleKernel() : element_nums_(0), capacity_(0), batch_size_(0) {}
+BufferSampleKernel::BufferSampleKernel() : element_nums_(0), capacity_(0), batch_size_(0), seed_(0) {}

 BufferSampleKernel::~BufferSampleKernel() {}

@@ -44,6 +47,7 @@ bool BufferSampleKernel::Init(const CNodePtr &kernel_node) {
   auto shapes = GetAttr>(kernel_node, "buffer_elements");
   auto types = GetAttr>(kernel_node, "buffer_dtype");
   capacity_ = GetAttr(kernel_node, "capacity");
+  seed_ = GetAttr(kernel_node, "seed");
"seed"); batch_size_ = LongToSize(GetAttr(kernel_node, "batch_size")); element_nums_ = shapes.size(); for (size_t i = 0; i < element_nums_; i++) { @@ -52,28 +56,52 @@ bool BufferSampleKernel::Init(const CNodePtr &kernel_node) { input_size_list_.push_back(capacity_ * element); output_size_list_.push_back(batch_size_ * element); } - // index - input_size_list_.push_back(sizeof(int) * batch_size_); // count and head input_size_list_.push_back(sizeof(int)); input_size_list_.push_back(sizeof(int)); + workspace_size_list_.push_back(capacity_ * sizeof(curandState)); + workspace_size_list_.push_back(capacity_ * sizeof(float)); + workspace_size_list_.push_back(capacity_ * sizeof(int)); + workspace_size_list_.push_back(capacity_ * sizeof(float)); return true; } void BufferSampleKernel::InitSizeLists() { return; } -bool BufferSampleKernel::Launch(const std::vector &inputs, const std::vector &, +bool BufferSampleKernel::Launch(const std::vector &inputs, const std::vector &workspaces, const std::vector &outputs, void *stream) { - int *index_addr = GetDeviceAddress(inputs, element_nums_); - int *count_addr = GetDeviceAddress(inputs, element_nums_ + 1); - int *head_addr = GetDeviceAddress(inputs, element_nums_ + 2); + int *count_addr = GetDeviceAddress(inputs, element_nums_); + int *head_addr = GetDeviceAddress(inputs, element_nums_ + 1); auto cuda_stream = reinterpret_cast(stream); CheckBatchSize(count_addr, head_addr, batch_size_, capacity_, cuda_stream); + int k_cut = 0; + CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, + cudaMemcpyAsync(&k_cut, count_addr, sizeof(int), cudaMemcpyDeviceToHost, cuda_stream), + "sync dev to host failed"); + // 1 Generate random floats + auto States = GetDeviceAddress(workspaces, 0); + auto random_f = GetDeviceAddress(workspaces, 1); + auto indexes = GetDeviceAddress(workspaces, 2); + auto useless_out = GetDeviceAddress(workspaces, 3); + int seedc = 0; + if (seed_ == 0) { + generator_.seed(std::chrono::system_clock::now().time_since_epoch().count()); + seedc = generator_(); + } else { + seedc = seed_; + } + + float init_k = std::numeric_limits::lowest(); + curandState *devStates = reinterpret_cast(States); + RandomGen(k_cut, devStates, seedc, random_f, cuda_stream); + // 2 Sort the random floats, and get the sorted indexes as the random indexes + FastTopK(1, k_cut, random_f, k_cut, useless_out, indexes, init_k, cuda_stream); + CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(), "cudaDeviceSync failed, sample-topk"); for (size_t i = 0; i < element_nums_; i++) { auto buffer_addr = GetDeviceAddress(inputs, i); auto out_addr = GetDeviceAddress(outputs, i); size_t size = batch_size_ * exp_element_list[i]; - BufferSample(size, exp_element_list[i], index_addr, buffer_addr, out_addr, cuda_stream); + BufferSample(size, exp_element_list[i], indexes, buffer_addr, out_addr, cuda_stream); } return true; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/rl/buffer_sample_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/rl/buffer_sample_gpu_kernel.h index b1e49cebb88..1388bec4318 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/rl/buffer_sample_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/rl/buffer_sample_gpu_kernel.h @@ -20,6 +20,7 @@ #include #include #include +#include #include "backend/kernel_compiler/gpu/gpu_kernel.h" #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" @@ -45,6 +46,8 @@ class BufferSampleKernel : public GpuKernel { size_t element_nums_; int64_t capacity_; size_t batch_size_; + int64_t seed_; + std::mt19937 
   std::vector exp_element_list;
   std::vector input_size_list_;
   std::vector output_size_list_;
diff --git a/mindspore/ops/operations/rl_ops.py b/mindspore/ops/operations/rl_ops.py
index 84ba2cc86ab..483b88c695a 100644
--- a/mindspore/ops/operations/rl_ops.py
+++ b/mindspore/ops/operations/rl_ops.py
@@ -36,12 +36,12 @@ class BufferSample(PrimitiveWithInfer):
         batch_size (int64): The size of the sampled data, lessequal to `capacity`.
         buffer_shape (tuple(shape)): The shape of an buffer.
         buffer_dtype (tuple(type)): The type of an buffer.
+        seed (int64): Random seed for sampling. Default: 0. If the default seed 0 is used, a random one is
+            generated in the kernel. Set a number other than `0` to keep a specific seed.

     Inputs:
         - **data** (tuple(Parameter(Tensor))) - The tuple(Tensor) represents replaybuffer,
           each tensor is described by the `buffer_shape` and `buffer_type`.
-        - **indexes** (tuple(int32)) - The position list in replaybuffer,
-          the size equal to `batch_size`.
         - **count** (Parameter) - The count mean the real available size of the buffer,
           data type: int32.
         - **head** (Parameter) - The position of the first data in buffer, data type: int32.
@@ -69,8 +69,7 @@ class BufferSample(PrimitiveWithInfer):
                       Parameter(Tensor(np.ones((100, 1)).astype(np.int32)), name="reward"),
                       Parameter(Tensor(np.arange(100 * 4).reshape(100, 4).astype(np.float32)), name="state_")]
         >>> buffer_sample = ops.BufferSample(capacity, batch_size, shapes, types)
-        >>> indexes = Parameter(Tensor([0, 2, 4, 3, 8], ms.int32), name="index")
-        >>> output = buffer_sample(buffer, indexes, count, head)
+        >>> output = buffer_sample(buffer, count, head)
         >>> print(output)
         (Tensor(shape=[5, 4], dtype=Float32, value=
         [[ 0.00000000e+00, 1.00000000e+00, 2.00000000e+00, 3.00000000e+00],
@@ -99,7 +98,7 @@ class BufferSample(PrimitiveWithInfer):
     """

     @prim_attr_register
-    def __init__(self, capacity, batch_size, buffer_shape, buffer_dtype):
+    def __init__(self, capacity, batch_size, buffer_shape, buffer_dtype, seed=0):
         """Initialize BufferSample."""
         self.init_prim_io_names(inputs=["buffer"], outputs=["sample"])
         validator.check_value_type("shape of init data", buffer_shape, [tuple, list], self.name)
@@ -110,6 +109,7 @@ class BufferSample(PrimitiveWithInfer):
         self._n = len(buffer_shape)
         validator.check_int(self._batch_size, capacity, Rel.LE, "batchsize", self.name)
         self.add_prim_attr('capacity', capacity)
+        self.add_prim_attr('seed', seed)
         buffer_elements = []
         for shape in buffer_shape:
             buffer_elements.append(reduce(lambda x, y: x * y, shape))
@@ -119,17 +119,16 @@ class BufferSample(PrimitiveWithInfer):
         if context.get_context('device_target') == "Ascend":
             self.add_prim_attr('device_target', "CPU")

-    def infer_shape(self, data_shape, index_shape, count_shape, head_shape):
+    def infer_shape(self, data_shape, count_shape, head_shape):
         validator.check_value_type("shape of data", data_shape, [tuple, list], self.name)
         out_shapes = []
         for i in range(self._n):
             out_shapes.append((self._batch_size,) + self._buffer_shape[i])
         return tuple(out_shapes)

-    def infer_dtype(self, data_type, index_type, count_type, head_type):
+    def infer_dtype(self, data_type, count_type, head_type):
         validator.check_type_name("count type", count_type, (mstype.int32), self.name)
         validator.check_type_name("head type", head_type, (mstype.int32), self.name)
-        validator.check_type_name("index type", index_type, (mstype.int64, mstype.int32), self.name)
         return tuple(self._buffer_dtype)

 class BufferAppend(PrimitiveWithInfer):
diff --git a/tests/st/ops/cpu/test_rl_buffer_net.py b/tests/st/ops/cpu/test_rl_buffer_net.py
index c46a6169bb8..fd43369c143 100644
--- a/tests/st/ops/cpu/test_rl_buffer_net.py
+++ b/tests/st/ops/cpu/test_rl_buffer_net.py
@@ -45,9 +45,6 @@ class RLBuffer(nn.Cell):
         self.buffer_get = P.BufferGetItem(self._capacity, shapes, types)
         self.buffer_sample = P.BufferSample(
             self._capacity, batch_size, shapes, types)
-        self.dummy_tensor = Tensor(np.ones(shape=[batch_size]), ms.bool_)
-        self.rnd_choice_mask = P.RandomChoiceWithMask(count=batch_size)
-        self.reshape = P.Reshape()

     @ms_function
     def append(self, exps):
@@ -59,9 +56,7 @@ class RLBuffer(nn.Cell):

     @ms_function
     def sample(self):
-        index, _ = self.rnd_choice_mask(self.dummy_tensor)
-        index = self.reshape(index, (self._batch_size,))
-        return self.buffer_sample(self.buffer, index, self.count, self.head)
+        return self.buffer_sample(self.buffer, self.count, self.head)


 s = Tensor(np.array([2, 2, 2, 2]), ms.float32)
diff --git a/tests/st/ops/cpu/test_rl_buffer_op.py b/tests/st/ops/cpu/test_rl_buffer_op.py
index fd8ed19783a..e818bc758e9 100644
--- a/tests/st/ops/cpu/test_rl_buffer_op.py
+++ b/tests/st/ops/cpu/test_rl_buffer_op.py
@@ -55,16 +55,13 @@ class RLBufferSample(nn.Cell):
     def __init__(self, capcity, batch_size, shapes, types):
         super(RLBufferSample, self).__init__()
         self._capacity = capcity
-        count = 5
         self.count = Parameter(Tensor(5, ms.int32), name="count")
         self.head = Parameter(Tensor(0, ms.int32), name="head")
-        self.input_x = Tensor(np.ones(shape=[count]), ms.bool_)
         self.buffer_sample = P.BufferSample(self._capacity, batch_size, shapes, types)
-        self.index = Parameter(Tensor([0, 2, 4], ms.int32), name="index")

     @ms_function
     def construct(self, buffer):
-        return self.buffer_sample(buffer, self.index, self.count, self.head)
+        return self.buffer_sample(buffer, self.count, self.head)


 states = Tensor(np.arange(4*5).reshape(5, 4).astype(np.float32)/10.0)
@@ -93,14 +90,7 @@ def test_BufferSample():
     buffer_sample = RLBufferSample(capcity=5, batch_size=3, shapes=[(4,), (2,), (1,), (4,)], types=[
         ms.float32, ms.int32, ms.int32, ms.float32])
     ss, aa, rr, ss_ = buffer_sample(b)
-    expect_s = [[0, 0.1, 0.2, 0.3], [0.8, 0.9, 1.0, 1.1], [1.6, 1.7, 1.8, 1.9]]
-    expect_a = [[0, 1], [4, 5], [8, 9]]
-    expect_r = [[1], [1], [1]]
-    expect_s_ = [[0, 1, 2, 3], [8, 9, 10, 11], [16, 17, 18, 19]]
-    np.testing.assert_almost_equal(ss.asnumpy(), expect_s)
-    np.testing.assert_almost_equal(aa.asnumpy(), expect_a)
-    np.testing.assert_almost_equal(rr.asnumpy(), expect_r)
-    np.testing.assert_almost_equal(ss_.asnumpy(), expect_s_)
+    print(ss, aa, rr, ss_)


 @ pytest.mark.level0
diff --git a/tests/st/ops/gpu/test_rl_buffer_net.py b/tests/st/ops/gpu/test_rl_buffer_net.py
index f39afc561a8..00d8a1c73a4 100644
--- a/tests/st/ops/gpu/test_rl_buffer_net.py
+++ b/tests/st/ops/gpu/test_rl_buffer_net.py
@@ -56,9 +56,7 @@ class RLBuffer(nn.Cell):

     @ms_function
     def sample(self):
-        count = self.reshape(self.count, (1,))
-        index = self.randperm(count)
-        return self.buffer_sample(self.buffer, index, self.count, self.head)
+        return self.buffer_sample(self.buffer, self.count, self.head)


 s = Tensor(np.array([2, 2, 2, 2]), ms.float32)
diff --git a/tests/st/ops/gpu/test_rl_buffer_op.py b/tests/st/ops/gpu/test_rl_buffer_op.py
index 97e4c7bdcb0..44c5b34978d 100644
--- a/tests/st/ops/gpu/test_rl_buffer_op.py
+++ b/tests/st/ops/gpu/test_rl_buffer_op.py
@@ -55,17 +55,14 @@ class RLBufferSample(nn.Cell):
     def __init__(self, capcity, batch_size, shapes, types):
         super(RLBufferSample, self).__init__()
         self._capacity = capcity
-        count = 5
         self.count = Parameter(Tensor(5, ms.int32), name="count")
         self.head = Parameter(Tensor(0, ms.int32), name="head")
-        self.input_x = Tensor(np.ones(shape=[count]), ms.bool_)
         self.buffer_sample = P.BufferSample(
             self._capacity, batch_size, shapes, types)
-        self.index = Parameter(Tensor([0, 2, 4], ms.int32), name="index")

     @ms_function
     def construct(self, buffer):
-        return self.buffer_sample(buffer, self.index, self.count, self.head)
+        return self.buffer_sample(buffer, self.count, self.head)


 states = Tensor(np.arange(4*5).reshape(5, 4).astype(np.float32)/10.0)
@@ -94,14 +91,7 @@ def test_BufferSample():
     buffer_sample = RLBufferSample(capcity=5, batch_size=3, shapes=[(4,), (2,), (1,), (4,)], types=[
         ms.float32, ms.int32, ms.int32, ms.float32])
     ss, aa, rr, ss_ = buffer_sample(b)
-    expect_s = [[0, 0.1, 0.2, 0.3], [0.8, 0.9, 1.0, 1.1], [1.6, 1.7, 1.8, 1.9]]
-    expect_a = [[0, 1], [4, 5], [8, 9]]
-    expect_r = [[1], [1], [1]]
-    expect_s_ = [[0, 1, 2, 3], [8, 9, 10, 11], [16, 17, 18, 19]]
-    np.testing.assert_almost_equal(ss.asnumpy(), expect_s)
-    np.testing.assert_almost_equal(aa.asnumpy(), expect_a)
-    np.testing.assert_almost_equal(rr.asnumpy(), expect_r)
-    np.testing.assert_almost_equal(ss_.asnumpy(), expect_s_)
+    print(ss, aa, rr, ss_)


 @ pytest.mark.level0
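
Note on the CPU path: with the `indexes` input removed, the sampling logic added above amounts to shuffling the index range [0, count) inside the kernel, seeded from the new `seed` attribute (or from the clock when `seed` is 0), and reading the first `batch_size` positions. A minimal standalone C++ sketch of that idea follows; it uses std::mt19937/std::shuffle here instead of the kernel's std::srand/random_shuffle pair, and the function and variable names are illustrative only, not part of the patch.

#include <algorithm>
#include <cstdint>
#include <ctime>
#include <iostream>
#include <random>
#include <vector>

// Sketch of the in-kernel index generation: shuffle [0, count) and keep the
// first batch_size positions as the sampled buffer rows.
std::vector<size_t> SampleIndexes(size_t count, size_t batch_size, int64_t seed) {
  std::vector<size_t> indexes(count);
  for (size_t i = 0; i < count; ++i) {
    indexes[i] = i;
  }
  // seed == 0 mirrors the kernel's std::srand(time(nullptr)) branch: derive a
  // seed from the clock; any other value yields a reproducible sample order.
  const std::uint64_t s = (seed == 0) ? static_cast<std::uint64_t>(std::time(nullptr))
                                      : static_cast<std::uint64_t>(seed);
  std::mt19937 gen(static_cast<std::mt19937::result_type>(s));
  std::shuffle(indexes.begin(), indexes.end(), gen);
  indexes.resize(batch_size);
  return indexes;
}

int main() {
  // Example: sample 3 rows out of a buffer holding 5 valid entries.
  for (size_t idx : SampleIndexes(5, 3, 42)) {
    std::cout << idx << ' ';
  }
  std::cout << '\n';
  return 0;
}
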
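Note on the GPU path: the same sampling is expressed as "random keys plus argsort": RandomGen fills one curand uniform float per valid row, FastTopK sorts those keys, and the resulting index order is what BufferSample gathers with. A host-side C++ sketch of that key-sorting idea follows; it is illustrative only, since the kernel does this on device with curand and the top-k primitive, and the names here are hypothetical.

#include <algorithm>
#include <iostream>
#include <numeric>
#include <random>
#include <vector>

// Illustration of sampling by sorting random keys: every valid row gets a
// uniform random float, and ordering rows by that key is a random permutation.
std::vector<int> SampleByRandomKeys(int count, int batch_size, unsigned seed) {
  std::mt19937 gen(seed);
  std::uniform_real_distribution<float> dist(0.0f, 1.0f);

  std::vector<float> keys(count);
  for (float &k : keys) {
    k = dist(gen);  // stands in for the per-row curand_uniform call
  }

  // Argsort by descending key, the role FastTopK plays on the device.
  std::vector<int> order(count);
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(), [&keys](int a, int b) { return keys[a] > keys[b]; });

  order.resize(batch_size);  // the first batch_size positions form the sample
  return order;
}

int main() {
  for (int idx : SampleByRandomKeys(5, 3, 123)) {
    std::cout << idx << ' ';
  }
  std::cout << '\n';
  return 0;
}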