!41004 GPU priority replay buffer: support distributed training

Merge pull request !41004 from chenweifeng/prb-multi-process
i-robot 2022-08-28 09:48:09 +00:00 committed by Gitee
commit 96bd8af06a
3 changed files with 46 additions and 29 deletions
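
In short: the buffer drops its local write pointer (`head_`) in favor of a monotonically increasing counter (`total_num_`). `Sample()` now reports global indices (`round_start + slot`) instead of raw ring-buffer slots, and `UpdatePriorities()` receives the index boundary of the last round so the update kernel can skip transitions that a later `Push()` has already overwritten. That race appears once several processes push and update priorities concurrently during distributed training.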

File 1 of 3: CUDA kernel implementation

@@ -108,22 +108,25 @@ __forceinline__ __device__ size_t GetPrefixSumIdx(T *tree, size_t capacity, floa
 }
 
 template <typename T>
-__global__ void SumTreeSampleKernel(T *tree, curandState *state, size_t capacity, float *beta, size_t batch_size,
-                                    size_t *indices, float *weights) {
+__global__ void SumTreeSampleKernel(T *tree, curandState *state, size_t capacity, size_t round_start, float *beta,
+                                    size_t batch_size, size_t *indices, float *weights) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size; i += gridDim.x * blockDim.x) {
     size_t segment_len = tree[kRootIdx].sum / batch_size;
     float prefix_sum = (curand_uniform(&state[i]) + i) * segment_len;
     size_t idx = GetPrefixSumIdx(tree, capacity, prefix_sum);
-    indices[i] = idx;
+    indices[i] = idx + round_start;
     weights[i] = powf((tree[idx + capacity].sum / tree[kRootIdx].min), -beta[0]);
   }
 }
 
 template <typename T>
-__global__ void SumTreeUpdateKernel(T *tree, size_t capacity, float alpha, float *max_priority, size_t *indices,
-                                    float *priorities, size_t batch_size) {
+__global__ void SumTreeUpdateKernel(T *tree, size_t capacity, size_t last_idx, float alpha, float *max_priority,
+                                    size_t *indices, float *priorities, size_t batch_size) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size; i += gridDim.x * blockDim.x) {
     size_t idx = indices[i];
+    // skip if the transition is already replaced.
+    if (idx < last_idx) continue;
     float priority = powf(priorities[i], alpha);
     MsAtomicMax(max_priority, priority);
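
To make the kernel change concrete, here is a minimal CPU sketch of the same stratified-sampling step, not the kernel itself: it assumes a plain priority array in place of the GPU sum tree, `PrefixSumIdx` is a hypothetical linear-scan stand-in for `GetPrefixSumIdx`, and all numbers are invented.

```cpp
// CPU sketch of SumTreeSampleKernel's per-sample logic (illustrative only).
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <random>
#include <vector>

// Hypothetical stand-in for GetPrefixSumIdx: find the first slot whose
// running priority sum reaches the target.
size_t PrefixSumIdx(const std::vector<float> &prio, float target) {
  float acc = 0.0f;
  for (size_t i = 0; i + 1 < prio.size(); ++i) {
    acc += prio[i];
    if (target <= acc) return i;
  }
  return prio.size() - 1;
}

int main() {
  std::vector<float> prio = {0.5f, 2.0f, 1.0f, 0.5f};  // one priority per slot
  const size_t batch_size = 2;
  const size_t round_start = 8;  // first global index of the current round
  const float total = 4.0f, min_prio = 0.5f, beta = 0.4f;
  std::mt19937 gen(42);
  std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
  for (size_t i = 0; i < batch_size; ++i) {
    float segment_len = total / batch_size;                // one stratum per sample
    float prefix_sum = (uniform(gen) + i) * segment_len;   // draw inside stratum i
    size_t slot = PrefixSumIdx(prio, prefix_sum);
    size_t global_idx = slot + round_start;                // what the kernel now reports
    float weight = std::pow(prio[slot] / min_prio, -beta); // IS correction weight
    std::printf("sample %zu -> slot %zu, global %zu, weight %.3f\n", i, slot, global_idx, weight);
  }
  return 0;
}
```

The weight matches the kernel's `powf(p_i / p_min, -beta)` importance-sampling correction; the only behavioral change in this hunk is the `+ round_start` offset on the reported index.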
@@ -151,20 +154,22 @@ void SumTreePush(T *tree, const float &alpha, const size_t &idx, const size_t &c
 // Sample a batch item. Return indices and correction weights.
 template <typename T>
-void SumTreeSample(T *tree, curandState *state, const size_t &capacity, float *beta, const size_t &batch_size,
-                   size_t *indices, float *weights, cudaStream_t stream) {
+void SumTreeSample(T *tree, curandState *state, const size_t &capacity, const size_t &round_start, float *beta,
+                   const size_t &batch_size, size_t *indices, float *weights, cudaStream_t stream) {
   size_t block = std::min(batch_size, kMaxThreadPerBlock);
   size_t grid = (batch_size + block - 1) / block;
-  SumTreeSampleKernel<<<grid, block, 0, stream>>>(tree, state, capacity, beta, batch_size, indices, weights);
+  SumTreeSampleKernel<<<grid, block, 0, stream>>>(tree, state, capacity, round_start, beta, batch_size, indices,
+                                                  weights);
 }
 
 // Update item priority.
 template <typename T>
-void SumTreeUpdate(T *tree, const size_t &capacity, const float &alpha, float *max_priority, size_t *indices,
-                   float *priorities, const size_t &batch_size, cudaStream_t stream) {
+void SumTreeUpdate(T *tree, const size_t &capacity, const size_t &last_idx, const float &alpha, float *max_priority,
+                   size_t *indices, float *priorities, const size_t &batch_size, cudaStream_t stream) {
   size_t block = std::min(batch_size, kMaxThreadPerBlock);
   size_t grid = (batch_size + block - 1) / block;
-  SumTreeUpdateKernel<<<grid, block, 0, stream>>>(tree, capacity, alpha, max_priority, indices, priorities, batch_size);
+  SumTreeUpdateKernel<<<grid, block, 0, stream>>>(tree, capacity, last_idx, alpha, max_priority, indices, priorities,
+                                                  batch_size);
 }
 
 template CUDA_LIB_EXPORT void SumTreeInit<SumMinTree>(SumMinTree *tree, float *max_priority, const size_t &capacity,
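
Both host wrappers keep the same ceil-division launch configuration. As a worked example (assuming `kMaxThreadPerBlock` is 1024, the usual CUDA per-block limit; the actual constant is defined elsewhere in this file): `batch_size = 2048` gives `block = min(2048, 1024) = 1024` and `grid = (2048 + 1023) / 1024 = 2`, so each sample gets one thread and the grid-stride loops inside the kernels absorb any mismatch.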
@@ -173,8 +178,10 @@ template CUDA_LIB_EXPORT void SumTreePush<SumMinTree>(SumMinTree *tree, const fl
                                                       const size_t &capacity, float *priority, float *max_priority,
                                                       cudaStream_t stream);
 template CUDA_LIB_EXPORT void SumTreeSample<SumMinTree>(SumMinTree *tree, curandState *state, const size_t &capacity,
-                                                        float *beta, const size_t &batch_size, size_t *indices,
-                                                        float *weights, cudaStream_t stream);
-template CUDA_LIB_EXPORT void SumTreeUpdate<SumMinTree>(SumMinTree *tree, const size_t &capacity, const float &alpha,
-                                                        float *max_priority, size_t *indices, float *priorities,
-                                                        const size_t &batch_size, cudaStream_t stream);
+                                                        const size_t &round_start, float *beta,
+                                                        const size_t &batch_size, size_t *indices, float *weights,
+                                                        cudaStream_t stream);
+template CUDA_LIB_EXPORT void SumTreeUpdate<SumMinTree>(SumMinTree *tree, const size_t &capacity,
+                                                        const size_t &last_idx, const float &alpha, float *max_priority,
+                                                        size_t *indices, float *priorities, const size_t &batch_size,
+                                                        cudaStream_t stream);

File 2 of 3: CUDA header declarations

@@ -86,11 +86,13 @@ CUDA_LIB_EXPORT void SumTreePush(T *tree, const float &alpha, const size_t &idx,
 // Sample a batch item. Return indices and correction weights.
 template <typename T>
-CUDA_LIB_EXPORT void SumTreeSample(T *tree, curandState *state, const size_t &capacity, float *beta,
-                                   const size_t &batch_size, size_t *indices, float *weights, cudaStream_t stream);
+CUDA_LIB_EXPORT void SumTreeSample(T *tree, curandState *state, const size_t &capacity, const size_t &round_start,
+                                   float *beta, const size_t &batch_size, size_t *indices, float *weights,
+                                   cudaStream_t stream);
 
 // Update item priority.
 template <typename T>
-CUDA_LIB_EXPORT void SumTreeUpdate(T *tree, const size_t &capacity, const float &alpha, float *max_priority,
-                                   size_t *indices, float *priorities, const size_t &batch_size, cudaStream_t stream);
+CUDA_LIB_EXPORT void SumTreeUpdate(T *tree, const size_t &capacity, const size_t &last_idx, const float &alpha,
+                                   float *max_priority, size_t *indices, float *priorities, const size_t &batch_size,
+                                   cudaStream_t stream);
 #endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMP_PRIORITY_REPLAY_BUFFER_IMPL_H_

File 3 of 3: PriorityReplayBuffer class template

@@ -56,6 +56,11 @@ class PriorityReplayBuffer {
   bool UpdatePriorities(size_t *indices, float *priorities, const size_t &batch_size, cudaStream_t stream);
 
  private:
+  size_t GetStartIndex() const { return total_num_ - total_num_ % capacity_; }
+  size_t GetLastRoundIndex() const {
+    return std::max(SizeToLong(total_num_) - SizeToLong(capacity_), static_cast<int64_t>(0));
+  }
+
   float alpha_{1.};
   std::vector<size_t> schema_;
@@ -63,13 +68,15 @@
   curandState *rand_state_{nullptr};
   size_t capacity_{0};
   size_t valid_size_{0};
-  size_t head_{-1UL};
   std::vector<uint8_t *> fifo_replay_buffer_;
 
   size_t capacity_pow_two_{0};
   float *max_priority_{nullptr};
   Tree *sum_tree_{nullptr};
+
+  // Member variable for the distributed scenario:
+  // a transition referenced by `UpdatePriorities()` may already have been overwritten by `Push()`.
+  size_t total_num_{-1UL};
 };
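
A quick worked check of the two new helpers, as a self-contained sketch: `SizeToLong` is MindSpore's signed cast, replaced here by a local stand-in, and the capacity/counter values are invented.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Stand-in for MindSpore's SizeToLong helper.
int64_t SizeToLong(size_t v) { return static_cast<int64_t>(v); }

int main() {
  const size_t capacity = 1000;
  const size_t total_num = 2500;  // 2501 pushes so far (the counter starts at -1UL)
  // GetStartIndex: first global index of the round currently being written.
  size_t start = total_num - total_num % capacity;  // 2500 - 500 = 2000
  // GetLastRoundIndex: global indices below this have certainly been overwritten.
  int64_t last = std::max(SizeToLong(total_num) - SizeToLong(capacity),
                          static_cast<int64_t>(0));  // 2500 - 1000 = 1500
  std::printf("round start = %zu, last round index = %lld\n", start,
              static_cast<long long>(last));
  return 0;
}
```

With these values, indices 2000 through 2500 belong to the current, partially written round, and anything below 1500 has certainly been overwritten, which is exactly what `SumTreeUpdateKernel` now skips.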
template <typename Tree>
@@ -116,20 +123,19 @@ PriorityReplayBuffer<Tree>::~PriorityReplayBuffer() {
 template <typename Tree>
 bool PriorityReplayBuffer<Tree>::Push(const std::vector<AddressPtr> &transition, float *priority, cudaStream_t stream) {
-  // Head point to the latest item.
-  head_ = head_ >= capacity_ ? 0 : head_ + 1;
   valid_size_ = valid_size_ >= capacity_ ? capacity_ : valid_size_ + 1;
+  total_num_++;
+  size_t idx = total_num_ % capacity_;
 
   // Copy transition to FIFO.
   for (size_t i = 0; i < transition.size(); i++) {
-    size_t offset = head_ * schema_[i];
+    size_t offset = idx * schema_[i];
     CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(fifo_replay_buffer_[i] + offset, transition[i]->addr, schema_[i],
                                                       cudaMemcpyDeviceToDevice, stream),
                                       "cudaMemcpyAsync failed.");
   }
 
   // Set max priority for the newest transition.
-  SumTreePush(sum_tree_, alpha_, head_, capacity_pow_two_, priority, max_priority_, stream);
+  SumTreePush(sum_tree_, alpha_, idx, capacity_pow_two_, priority, max_priority_, stream);
   return true;
 }
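
The effect of replacing `head_` with the monotone counter can be seen in a few iterations of the new indexing, sketched here with an invented capacity of 4:

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  const size_t capacity = 4;
  size_t total_num = static_cast<size_t>(-1);  // -1UL, as in the class initializer
  for (int push = 0; push < 6; ++push) {
    ++total_num;                         // first increment wraps -1UL to 0
    size_t slot = total_num % capacity;  // ring-buffer slot actually written
    std::printf("push %d: global index %zu -> slot %zu\n", push, total_num, slot);
  }
  return 0;
}
```

Unlike `head_`, the global index never repeats, so a sampled index uniquely identifies one pushed transition even after its slot has been reused.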
@@ -145,7 +151,8 @@ bool PriorityReplayBuffer<Tree>::Sample(const size_t &batch_size, float *beta, s
     InitRandState(batch_size, seed_, rand_state_, stream);
   }
 
-  SumTreeSample(sum_tree_, rand_state_, capacity_pow_two_, beta, batch_size, indices, weights, stream);
+  size_t base_idx = GetStartIndex();
+  SumTreeSample(sum_tree_, rand_state_, capacity_pow_two_, base_idx, beta, batch_size, indices, weights, stream);
   for (size_t i = 0; i < schema_.size(); i++) {
     auto output_addr = static_cast<uint8_t *>(transition[i]->addr);
@@ -158,7 +165,8 @@ bool PriorityReplayBuffer<Tree>::Sample(const size_t &batch_size, float *beta, s
 template <typename Tree>
 bool PriorityReplayBuffer<Tree>::UpdatePriorities(size_t *indices, float *priorities, const size_t &batch_size,
                                                   cudaStream_t stream) {
-  SumTreeUpdate(sum_tree_, capacity_pow_two_, alpha_, max_priority_, indices, priorities, batch_size, stream);
+  size_t last = GetLastRoundIndex();
+  SumTreeUpdate(sum_tree_, capacity_pow_two_, last, alpha_, max_priority_, indices, priorities, batch_size, stream);
   return true;
 }
 }  // namespace gpu
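
Taken together: `Sample()` now labels each drawn transition with a global index (`base_idx` plus the sampled slot), and `UpdatePriorities()` drops any update whose global index predates the last full round, i.e. whose slot `Push()` has since rewritten. A delayed priority update arriving from another process thus becomes a harmless no-op instead of an update to the wrong transition. A host-side sketch of that staleness filter, with invented values (capacity 1000, `GetLastRoundIndex()` returning 1500):

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  const size_t last_round_idx = 1500;          // would come from GetLastRoundIndex()
  const size_t indices[] = {900, 1500, 2400};  // global indices returned by Sample()
  for (size_t idx : indices) {
    if (idx < last_round_idx) {                // same check as SumTreeUpdateKernel
      std::printf("index %zu: stale, update skipped\n", idx);
      continue;
    }
    // The kernel would apply the priority update here (that code is outside the hunk).
    std::printf("index %zu: priority updated at slot %zu\n", idx, idx % 1000);
  }
  return 0;
}
```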