forked from mindspore-Ecosystem/mindspore
!41004 gpu priority replay buffer support distributed training
Merge pull request !41004 from chenweifeng/prb-multi-process
This commit is contained in:
commit
96bd8af06a
|
@ -108,22 +108,25 @@ __forceinline__ __device__ size_t GetPrefixSumIdx(T *tree, size_t capacity, floa
|
|||
}
|
||||
|
||||
// Sampling kernel: each loop iteration `i` draws one prefix sum from the i-th
// segment of the root sum, maps it to an entry with GetPrefixSumIdx(), then
// writes the sampled index and its correction weight.
//
// NOTE(review): this span appears to be a rendered diff. The first signature and
// `indices[i] = idx;` look like the pre-change lines; the second signature (with
// `round_start`) and `indices[i] = idx + round_start;` look like their replacements.
//
// tree        : tree node array; the root is read at kRootIdx and the sampled
//               entry at `idx + capacity`.
// state       : one curandState per sample, indexed by `i`.
// round_start : offset added to every returned index.
// beta        : only beta[0] is read.
// indices     : output, batch_size entries.
// weights     : output, batch_size entries.
template <typename T>
|
||||
__global__ void SumTreeSampleKernel(T *tree, curandState *state, size_t capacity, float *beta, size_t batch_size,
|
||||
size_t *indices, float *weights) {
|
||||
__global__ void SumTreeSampleKernel(T *tree, curandState *state, size_t capacity, size_t round_start, float *beta,
|
||||
size_t batch_size, size_t *indices, float *weights) {
|
||||
// Grid-stride loop: the stride is the total number of launched threads, so any
// launch configuration covers all batch_size samples.
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size; i += gridDim.x * blockDim.x) {
|
||||
// NOTE(review): the quotient is stored in a size_t. If `sum` is a floating-point
// priority total (TODO confirm against the tree node type), the fractional part
// is dropped, and segment_len becomes 0 whenever sum < batch_size, which would
// make every prefix_sum below 0.
size_t segment_len = tree[kRootIdx].sum / batch_size;
|
||||
// Point inside the i-th segment: (random fraction + i) * segment_len.
float prefix_sum = (curand_uniform(&state[i]) + i) * segment_len;
|
||||
size_t idx = GetPrefixSumIdx(tree, capacity, prefix_sum);
|
||||
indices[i] = idx;
|
||||
// The returned index is the in-tree position shifted by round_start.
indices[i] = idx + round_start;
|
||||
// weight = (entry sum / root min) ^ (-beta[0]).
// NOTE(review): no guard is visible here for a zero root `min`.
weights[i] = powf((tree[idx + capacity].sum / tree[kRootIdx].min), -beta[0]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void SumTreeUpdateKernel(T *tree, size_t capacity, float alpha, float *max_priority, size_t *indices,
|
||||
float *priorities, size_t batch_size) {
|
||||
__global__ void SumTreeUpdateKernel(T *tree, size_t capacity, size_t last_idx, float alpha, float *max_priority,
|
||||
size_t *indices, float *priorities, size_t batch_size) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size; i += gridDim.x * blockDim.x) {
|
||||
size_t idx = indices[i];
|
||||
// skip if the transition is already replaced.
|
||||
if (idx < last_idx) continue;
|
||||
|
||||
float priority = powf(priorities[i], alpha);
|
||||
MsAtomicMax(max_priority, priority);
|
||||
|
||||
|
@ -151,20 +154,22 @@ void SumTreePush(T *tree, const float &alpha, const size_t &idx, const size_t &c
|
|||
|
||||
// Host launcher for SumTreeSampleKernel: samples `batch_size` items and returns
// their indices and correction weights through `indices` and `weights`.
// The kernel is enqueued on `stream`; this function does not synchronize.
//
// NOTE(review): this span appears to be a rendered diff. The signature and launch
// without `round_start` look like the pre-change lines; the ones with
// `round_start` look like their replacements.
|
||||
template <typename T>
|
||||
void SumTreeSample(T *tree, curandState *state, const size_t &capacity, float *beta, const size_t &batch_size,
|
||||
size_t *indices, float *weights, cudaStream_t stream) {
|
||||
void SumTreeSample(T *tree, curandState *state, const size_t &capacity, const size_t &round_start, float *beta,
|
||||
const size_t &batch_size, size_t *indices, float *weights, cudaStream_t stream) {
|
||||
// One thread per sample, capped at kMaxThreadPerBlock threads per block.
// NOTE(review): batch_size == 0 gives block == 0, so the ceil-division below
// divides by zero; confirm callers never pass an empty batch.
size_t block = std::min(batch_size, kMaxThreadPerBlock);
|
||||
// Ceil-division so that grid * block >= batch_size.
size_t grid = (batch_size + block - 1) / block;
|
||||
SumTreeSampleKernel<<<grid, block, 0, stream>>>(tree, state, capacity, beta, batch_size, indices, weights);
|
||||
// NOTE(review): no launch-error check (e.g. cudaGetLastError) is visible after the launch.
SumTreeSampleKernel<<<grid, block, 0, stream>>>(tree, state, capacity, round_start, beta, batch_size, indices,
|
||||
weights);
|
||||
}
|
||||
|
||||
// Host launcher for SumTreeUpdateKernel: updates the priorities of the
// `batch_size` entries listed in `indices` with the values in `priorities`.
// The kernel is enqueued on `stream`; this function does not synchronize.
//
// last_idx : forwarded to the kernel, which skips entries whose index is below it.
// alpha    : forwarded to the kernel, which uses it as the exponent applied to
//            each incoming priority.
//
// NOTE(review): this span appears to be a rendered diff. The signature and launch
// without `last_idx` look like the pre-change lines; the ones with `last_idx`
// look like their replacements.
|
||||
template <typename T>
|
||||
void SumTreeUpdate(T *tree, const size_t &capacity, const float &alpha, float *max_priority, size_t *indices,
|
||||
float *priorities, const size_t &batch_size, cudaStream_t stream) {
|
||||
void SumTreeUpdate(T *tree, const size_t &capacity, const size_t &last_idx, const float &alpha, float *max_priority,
|
||||
size_t *indices, float *priorities, const size_t &batch_size, cudaStream_t stream) {
|
||||
// One thread per entry, capped at kMaxThreadPerBlock threads per block.
// NOTE(review): batch_size == 0 gives block == 0, so the ceil-division below
// divides by zero; confirm callers never pass an empty batch.
size_t block = std::min(batch_size, kMaxThreadPerBlock);
|
||||
// Ceil-division so that grid * block >= batch_size.
size_t grid = (batch_size + block - 1) / block;
|
||||
SumTreeUpdateKernel<<<grid, block, 0, stream>>>(tree, capacity, alpha, max_priority, indices, priorities, batch_size);
|
||||
// NOTE(review): no launch-error check (e.g. cudaGetLastError) is visible after the launch.
SumTreeUpdateKernel<<<grid, block, 0, stream>>>(tree, capacity, last_idx, alpha, max_priority, indices, priorities,
|
||||
batch_size);
|
||||
}
|
||||
|
||||
template CUDA_LIB_EXPORT void SumTreeInit<SumMinTree>(SumMinTree *tree, float *max_priority, const size_t &capacity,
|
||||
|
@ -173,8 +178,10 @@ template CUDA_LIB_EXPORT void SumTreePush<SumMinTree>(SumMinTree *tree, const fl
|
|||
const size_t &capacity, float *priority, float *max_priority,
|
||||
cudaStream_t stream);
|
||||
// Explicit instantiations of SumTreeSample and SumTreeUpdate for SumMinTree.
// NOTE(review): this span appears to be a rendered diff. The first SumTreeSample
// parameter list (starting with `float *beta`) and the first SumTreeUpdate
// instantiation look like the pre-change lines; the parameter lists containing
// `round_start` and `last_idx` look like their replacements.
template CUDA_LIB_EXPORT void SumTreeSample<SumMinTree>(SumMinTree *tree, curandState *state, const size_t &capacity,
|
||||
float *beta, const size_t &batch_size, size_t *indices,
|
||||
float *weights, cudaStream_t stream);
|
||||
template CUDA_LIB_EXPORT void SumTreeUpdate<SumMinTree>(SumMinTree *tree, const size_t &capacity, const float &alpha,
|
||||
float *max_priority, size_t *indices, float *priorities,
|
||||
const size_t &batch_size, cudaStream_t stream);
|
||||
const size_t &round_start, float *beta,
|
||||
const size_t &batch_size, size_t *indices, float *weights,
|
||||
cudaStream_t stream);
|
||||
template CUDA_LIB_EXPORT void SumTreeUpdate<SumMinTree>(SumMinTree *tree, const size_t &capacity,
|
||||
const size_t &last_idx, const float &alpha, float *max_priority,
|
||||
size_t *indices, float *priorities, const size_t &batch_size,
|
||||
cudaStream_t stream);
|
||||
|
|
|
@ -86,11 +86,13 @@ CUDA_LIB_EXPORT void SumTreePush(T *tree, const float &alpha, const size_t &idx,
|
|||
|
||||
// Samples `batch_size` items from the tree on `stream`.
// Outputs: `indices` receives the sampled indices and `weights` the matching
// correction weights. In the declaration that takes `round_start`, that value is
// the offset applied to the returned indices.
// NOTE(review): this span appears to be a rendered diff. The declaration without
// `round_start` looks like the pre-change form; the one with it looks like its
// replacement.
|
||||
template <typename T>
|
||||
CUDA_LIB_EXPORT void SumTreeSample(T *tree, curandState *state, const size_t &capacity, float *beta,
|
||||
const size_t &batch_size, size_t *indices, float *weights, cudaStream_t stream);
|
||||
CUDA_LIB_EXPORT void SumTreeSample(T *tree, curandState *state, const size_t &capacity, const size_t &round_start,
|
||||
float *beta, const size_t &batch_size, size_t *indices, float *weights,
|
||||
cudaStream_t stream);
|
||||
|
||||
// Updates the priorities of the `batch_size` entries listed in `indices` with
// the values in `priorities`, on `stream`.
// In the declaration that takes `last_idx`, entries whose index is below
// `last_idx` are skipped by the update kernel.
// NOTE(review): this span appears to be a rendered diff. The declaration without
// `last_idx` looks like the pre-change form; the one with it looks like its
// replacement.
|
||||
template <typename T>
|
||||
CUDA_LIB_EXPORT void SumTreeUpdate(T *tree, const size_t &capacity, const float &alpha, float *max_priority,
|
||||
size_t *indices, float *priorities, const size_t &batch_size, cudaStream_t stream);
|
||||
CUDA_LIB_EXPORT void SumTreeUpdate(T *tree, const size_t &capacity, const size_t &last_idx, const float &alpha,
|
||||
float *max_priority, size_t *indices, float *priorities, const size_t &batch_size,
|
||||
cudaStream_t stream);
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMP_PRIORITY_REPLAY_BUFFER_IMPL_H_
|
||||
|
|
|
@ -56,6 +56,11 @@ class PriorityReplayBuffer {
|
|||
bool UpdatePriorities(size_t *indices, float *priorities, const size_t &batch_size, cudaStream_t stream);
|
||||
|
||||
private:
|
||||
// Returns total_num_ rounded down to a multiple of capacity_, i.e. the index at
// which the current pass over the buffer began.
size_t GetStartIndex() const {
  const size_t offset_in_round = total_num_ % capacity_;
  return total_num_ - offset_in_round;
}
|
||||
size_t GetLastRoundIndex() const {
|
||||
return std::max(SizeToLong(total_num_) - SizeToLong(capacity_), static_cast<int64_t>(0));
|
||||
}
|
||||
|
||||
float alpha_{1.};
|
||||
|
||||
std::vector<size_t> schema_;
|
||||
|
@ -63,13 +68,15 @@ class PriorityReplayBuffer {
|
|||
curandState *rand_state_{nullptr};
|
||||
|
||||
size_t capacity_{0};
|
||||
size_t valid_size_{0};
|
||||
size_t head_{-1UL};
|
||||
std::vector<uint8_t *> fifo_replay_buffer_;
|
||||
|
||||
size_t capacity_pow_two_{0};
|
||||
float *max_priority_{nullptr};
|
||||
Tree *sum_tree_{nullptr};
|
||||
|
||||
// Member variables for distributed scenario:
|
||||
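// `total_num_` is the running index of the newest pushed transition (it wraps from -1 to 0 on the first `Push()`).
// It is used to detect entries passed to `UpdatePriorities()` that `Push()` has already overwritten.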
|
||||
size_t total_num_{-1UL};
|
||||
};
|
||||
|
||||
template <typename Tree>
|
||||
|
@ -116,20 +123,19 @@ PriorityReplayBuffer<Tree>::~PriorityReplayBuffer() {
|
|||
|
||||
// Appends one transition to the buffer and registers its priority in the sum tree.
// All copies and the tree update are enqueued on `stream`. Returns true.
//
// NOTE(review): this span appears to be a rendered diff. The `head_` / `valid_size_`
// statements and the calls that use `head_` look like the pre-change lines; the
// `total_num_` / `idx` statements look like their replacements.
template <typename Tree>
|
||||
bool PriorityReplayBuffer<Tree>::Push(const std::vector<AddressPtr> &transition, float *priority, cudaStream_t stream) {
|
||||
// Head points to the latest item.
|
||||
head_ = head_ >= capacity_ ? 0 : head_ + 1;
|
||||
valid_size_ = valid_size_ >= capacity_ ? capacity_ : valid_size_ + 1;
|
||||
// total_num_ is declared with the initial value -1UL, so the first increment wraps it to 0.
total_num_++;
|
||||
// Slot to overwrite: the running index folded into [0, capacity_).
// NOTE(review): capacity_ == 0 would make this a modulo by zero; confirm it is rejected at construction.
size_t idx = total_num_ % capacity_;
|
||||
|
||||
// Copy transition to FIFO.
|
||||
// Element i is copied as schema_[i] bytes into slot `idx` of fifo_replay_buffer_[i].
// NOTE(review): assumes transition.size() does not exceed schema_.size() - TODO confirm.
for (size_t i = 0; i < transition.size(); i++) {
|
||||
size_t offset = head_ * schema_[i];
|
||||
size_t offset = idx * schema_[i];
|
||||
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(fifo_replay_buffer_[i] + offset, transition[i]->addr, schema_[i],
|
||||
cudaMemcpyDeviceToDevice, stream),
|
||||
"cudaMemcpyAsync failed.");
|
||||
}
|
||||
|
||||
// Set max priority for the newest transition.
|
||||
SumTreePush(sum_tree_, alpha_, head_, capacity_pow_two_, priority, max_priority_, stream);
|
||||
SumTreePush(sum_tree_, alpha_, idx, capacity_pow_two_, priority, max_priority_, stream);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -145,7 +151,8 @@ bool PriorityReplayBuffer<Tree>::Sample(const size_t &batch_size, float *beta, s
|
|||
InitRandState(batch_size, seed_, rand_state_, stream);
|
||||
}
|
||||
|
||||
SumTreeSample(sum_tree_, rand_state_, capacity_pow_two_, beta, batch_size, indices, weights, stream);
|
||||
size_t base_idx = GetStartIndex();
|
||||
SumTreeSample(sum_tree_, rand_state_, capacity_pow_two_, base_idx, beta, batch_size, indices, weights, stream);
|
||||
|
||||
for (size_t i = 0; i < schema_.size(); i++) {
|
||||
auto output_addr = static_cast<uint8_t *>(transition[i]->addr);
|
||||
|
@ -158,7 +165,8 @@ bool PriorityReplayBuffer<Tree>::Sample(const size_t &batch_size, float *beta, s
|
|||
// Updates the priorities of the `batch_size` entries listed in `indices` with the
// values in `priorities`. The work is enqueued on `stream`. Returns true.
//
// NOTE(review): this span appears to be a rendered diff. The SumTreeUpdate call
// without `last` looks like the pre-change line; the GetLastRoundIndex() lookup
// and the call that passes `last` look like its replacement.
template <typename Tree>
|
||||
bool PriorityReplayBuffer<Tree>::UpdatePriorities(size_t *indices, float *priorities, const size_t &batch_size,
|
||||
cudaStream_t stream) {
|
||||
SumTreeUpdate(sum_tree_, capacity_pow_two_, alpha_, max_priority_, indices, priorities, batch_size, stream);
|
||||
// `last` is total_num_ - capacity_ clamped at zero; the update kernel skips entries with index < last.
// NOTE(review): total_num_ is the index of the newest transition, so index
// total_num_ - capacity_ maps to the slot Push() just overwrote, yet the strict
// `<` comparison does not skip it. Confirm whether the boundary should be inclusive.
size_t last = GetLastRoundIndex();
|
||||
SumTreeUpdate(sum_tree_, capacity_pow_two_, last, alpha_, max_priority_, indices, priorities, batch_size, stream);
|
||||
return true;
|
||||
}
|
||||
} // namespace gpu
|
||||
|
|
Loading…
Reference in New Issue