diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/rl/priority_replay_buffer.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/rl/priority_replay_buffer.cu
index 129086e82d9..6f6777496c3 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/rl/priority_replay_buffer.cu
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/rl/priority_replay_buffer.cu
@@ -108,22 +108,25 @@ __forceinline__ __device__ size_t GetPrefixSumIdx(T *tree, size_t capacity, floa
 }
 
 template <typename T>
-__global__ void SumTreeSampleKernel(T *tree, curandState *state, size_t capacity, float *beta, size_t batch_size,
-                                    size_t *indices, float *weights) {
+__global__ void SumTreeSampleKernel(T *tree, curandState *state, size_t capacity, size_t round_start, float *beta,
+                                    size_t batch_size, size_t *indices, float *weights) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size; i += gridDim.x * blockDim.x) {
     size_t segment_len = tree[kRootIdx].sum / batch_size;
     float prefix_sum = (curand_uniform(&state[i]) + i) * segment_len;
     size_t idx = GetPrefixSumIdx(tree, capacity, prefix_sum);
-    indices[i] = idx;
+    indices[i] = idx + round_start;
     weights[i] = powf((tree[idx + capacity].sum / tree[kRootIdx].min), -beta[0]);
   }
 }
 
 template <typename T>
-__global__ void SumTreeUpdateKernel(T *tree, size_t capacity, float alpha, float *max_priority, size_t *indices,
-                                    float *priorities, size_t batch_size) {
+__global__ void SumTreeUpdateKernel(T *tree, size_t capacity, size_t last_idx, float alpha, float *max_priority,
+                                    size_t *indices, float *priorities, size_t batch_size) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size; i += gridDim.x * blockDim.x) {
     size_t idx = indices[i];
+    // skip if the transition is already replaced.
+    if (idx < last_idx) continue;
+
     float priority = powf(priorities[i], alpha);
     MsAtomicMax(max_priority, priority);
 
@@ -151,20 +154,22 @@ void SumTreePush(T *tree, const float &alpha, const size_t &idx, const size_t &c
 
 // Sample a batch item. Return indices and correction weights.
 template <typename T>
-void SumTreeSample(T *tree, curandState *state, const size_t &capacity, float *beta, const size_t &batch_size,
-                   size_t *indices, float *weights, cudaStream_t stream) {
+void SumTreeSample(T *tree, curandState *state, const size_t &capacity, const size_t &round_start, float *beta,
+                   const size_t &batch_size, size_t *indices, float *weights, cudaStream_t stream) {
   size_t block = std::min(batch_size, kMaxThreadPerBlock);
   size_t grid = (batch_size + block - 1) / block;
-  SumTreeSampleKernel<<<grid, block, 0, stream>>>(tree, state, capacity, beta, batch_size, indices, weights);
+  SumTreeSampleKernel<<<grid, block, 0, stream>>>(tree, state, capacity, round_start, beta, batch_size, indices,
+                                                  weights);
 }
 
 // Update item priority.
 template <typename T>
-void SumTreeUpdate(T *tree, const size_t &capacity, const float &alpha, float *max_priority, size_t *indices,
-                   float *priorities, const size_t &batch_size, cudaStream_t stream) {
+void SumTreeUpdate(T *tree, const size_t &capacity, const size_t &last_idx, const float &alpha, float *max_priority,
+                   size_t *indices, float *priorities, const size_t &batch_size, cudaStream_t stream) {
   size_t block = std::min(batch_size, kMaxThreadPerBlock);
   size_t grid = (batch_size + block - 1) / block;
-  SumTreeUpdateKernel<<<grid, block, 0, stream>>>(tree, capacity, alpha, max_priority, indices, priorities, batch_size);
+  SumTreeUpdateKernel<<<grid, block, 0, stream>>>(tree, capacity, last_idx, alpha, max_priority, indices, priorities,
+                                                  batch_size);
 }
 
 template CUDA_LIB_EXPORT void SumTreeInit<SumMinTree>(SumMinTree *tree, float *max_priority, const size_t &capacity,
@@ -173,8 +178,10 @@ template CUDA_LIB_EXPORT void SumTreePush<SumMinTree>(SumMinTree *tree, const fl
                                                       const size_t &capacity, float *priority, float *max_priority,
                                                       cudaStream_t stream);
 template CUDA_LIB_EXPORT void SumTreeSample<SumMinTree>(SumMinTree *tree, curandState *state, const size_t &capacity,
-                                                        float *beta, const size_t &batch_size, size_t *indices,
-                                                        float *weights, cudaStream_t stream);
-template CUDA_LIB_EXPORT void SumTreeUpdate<SumMinTree>(SumMinTree *tree, const size_t &capacity, const float &alpha,
-                                                        float *max_priority, size_t *indices, float *priorities,
-                                                        const size_t &batch_size, cudaStream_t stream);
+                                                        const size_t &round_start, float *beta,
+                                                        const size_t &batch_size, size_t *indices, float *weights,
+                                                        cudaStream_t stream);
+template CUDA_LIB_EXPORT void SumTreeUpdate<SumMinTree>(SumMinTree *tree, const size_t &capacity,
+                                                        const size_t &last_idx, const float &alpha, float *max_priority,
+                                                        size_t *indices, float *priorities, const size_t &batch_size,
+                                                        cudaStream_t stream);
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/rl/priority_replay_buffer.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/rl/priority_replay_buffer.cuh
index 209f2d73b35..bf41633bdcb 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/rl/priority_replay_buffer.cuh
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/rl/priority_replay_buffer.cuh
@@ -86,11 +86,13 @@ CUDA_LIB_EXPORT void SumTreePush(T *tree, const float &alpha, const size_t &idx,
 
 // Sample a batch item. Return indices and correction weights.
 template <typename T>
-CUDA_LIB_EXPORT void SumTreeSample(T *tree, curandState *state, const size_t &capacity, float *beta,
-                                   const size_t &batch_size, size_t *indices, float *weights, cudaStream_t stream);
+CUDA_LIB_EXPORT void SumTreeSample(T *tree, curandState *state, const size_t &capacity, const size_t &round_start,
+                                   float *beta, const size_t &batch_size, size_t *indices, float *weights,
+                                   cudaStream_t stream);
 
 // Update item priority.
 template <typename T>
-CUDA_LIB_EXPORT void SumTreeUpdate(T *tree, const size_t &capacity, const float &alpha, float *max_priority,
-                                   size_t *indices, float *priorities, const size_t &batch_size, cudaStream_t stream);
+CUDA_LIB_EXPORT void SumTreeUpdate(T *tree, const size_t &capacity, const size_t &last_idx, const float &alpha,
+                                   float *max_priority, size_t *indices, float *priorities, const size_t &batch_size,
+                                   cudaStream_t stream);
 #endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMP_PRIORITY_REPLAY_BUFFER_IMPL_H_
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/rl/priority_replay_buffer.h b/mindspore/ccsrc/plugin/device/gpu/kernel/rl/priority_replay_buffer.h
index 7e8e6a4eb7f..8b2ccafcb24 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/rl/priority_replay_buffer.h
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/rl/priority_replay_buffer.h
@@ -56,6 +56,11 @@ class PriorityReplayBuffer {
   bool UpdatePriorities(size_t *indices, float *priorities, const size_t &batch_size, cudaStream_t stream);
 
  private:
+  size_t GetStartIndex() const { return total_num_ - total_num_ % capacity_; }
+  size_t GetLastRoundIndex() const {
+    return std::max(SizeToLong(total_num_) - SizeToLong(capacity_), static_cast<int64_t>(0));
+  }
+
   float alpha_{1.};
   std::vector<size_t> schema_;
 
@@ -63,13 +68,15 @@ class PriorityReplayBuffer {
   curandState *rand_state_{nullptr};
 
   size_t capacity_{0};
-  size_t valid_size_{0};
-  size_t head_{-1UL};
   std::vector<uint8_t *> fifo_replay_buffer_;
 
   size_t capacity_pow_two_{0};
   float *max_priority_{nullptr};
   Tree *sum_tree_{nullptr};
+
+  // Member variables for distributed scenario:
+  // The operand of `UpdatePriorities()` is replaced by `Push()`.
+  size_t total_num_{-1UL};
 };
 
 template <typename Tree>
@@ -116,20 +123,19 @@ PriorityReplayBuffer<Tree>::~PriorityReplayBuffer() {
 
 template <typename Tree>
 bool PriorityReplayBuffer<Tree>::Push(const std::vector<AddressPtr> &transition, float *priority, cudaStream_t stream) {
-  // Head point to the latest item.
-  head_ = head_ >= capacity_ ? 0 : head_ + 1;
-  valid_size_ = valid_size_ >= capacity_ ? capacity_ : valid_size_ + 1;
+  total_num_++;
+  size_t idx = total_num_ % capacity_;
 
   // Copy transition to FIFO.
   for (size_t i = 0; i < transition.size(); i++) {
-    size_t offset = head_ * schema_[i];
+    size_t offset = idx * schema_[i];
     CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(fifo_replay_buffer_[i] + offset, transition[i]->addr, schema_[i],
                                                       cudaMemcpyDeviceToDevice, stream),
                                       "cudaMemcpyAsync failed.");
   }
 
   // Set max priority for the newest transition.
-  SumTreePush(sum_tree_, alpha_, head_, capacity_pow_two_, priority, max_priority_, stream);
+  SumTreePush(sum_tree_, alpha_, idx, capacity_pow_two_, priority, max_priority_, stream);
   return true;
 }
 
@@ -145,7 +151,8 @@ bool PriorityReplayBuffer<Tree>::Sample(const size_t &batch_size, float *beta, s
     InitRandState(batch_size, seed_, rand_state_, stream);
   }
 
-  SumTreeSample(sum_tree_, rand_state_, capacity_pow_two_, beta, batch_size, indices, weights, stream);
+  size_t base_idx = GetStartIndex();
+  SumTreeSample(sum_tree_, rand_state_, capacity_pow_two_, base_idx, beta, batch_size, indices, weights, stream);
 
   for (size_t i = 0; i < schema_.size(); i++) {
     auto output_addr = static_cast<uint8_t *>(transition[i]->addr);
@@ -158,7 +165,8 @@ bool PriorityReplayBuffer<Tree>::Sample(const size_t &batch_size, float *beta, s
 template <typename Tree>
 bool PriorityReplayBuffer<Tree>::UpdatePriorities(size_t *indices, float *priorities, const size_t &batch_size,
                                                   cudaStream_t stream) {
-  SumTreeUpdate(sum_tree_, capacity_pow_two_, alpha_, max_priority_, indices, priorities, batch_size, stream);
+  size_t last = GetLastRoundIndex();
+  SumTreeUpdate(sum_tree_, capacity_pow_two_, last, alpha_, max_priority_, indices, priorities, batch_size, stream);
   return true;
 }
 }  // namespace gpu