!39199 reservoir replay buffer gpu kernel

Merge pull request !39199 from chenweifeng/reservoir-replay-buffer-gpu-kernel
This commit is contained in:
i-robot 2022-08-03 01:57:07 +00:00 committed by Gitee
commit 6bbdfb3c5a
6 changed files with 519 additions and 5 deletions

View File

@@ -108,11 +108,12 @@ __global__ void SrandUInt(const int size, curandState *globalState, unsigned int
}
}
__global__ void SrandUniformInt(const int size, curandState *globalState, const int upBound, unsigned int *out) {
template <typename T>
__global__ void SrandUniformInt(const int size, curandState *globalState, const int upBound, T *out) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
// curand_uniform returns a pseudorandom float uniformly distributed between 0.0 and 1.0, where 1.0 is
// included and 0.0 is excluded, so decrease upBound by 1 to avoid an out-of-range index.
out[i] = static_cast<unsigned int>(curand_uniform(&globalState[i]) * (upBound - 1));
out[i] = static_cast<T>(curand_uniform(&globalState[i]) * (upBound - 1));
}
}
@@ -159,7 +160,12 @@ void RandomGen(const int size, curandState *globalState, unsigned int *value, un
thrust::sort_by_key(policy, dev_key_ptr, dev_key_ptr + size, dev_data_ptr);
}
void RandomGenUniform(const int size, curandState *globalState, const int up_bound, unsigned int *indexes,
cudaStream_t stream) {
template <typename T>
void RandomGenUniform(const int size, curandState *globalState, const int up_bound, T *indexes, cudaStream_t stream) {
SrandUniformInt<<<(size + 255) / 256, 256, 0, stream>>>(size, globalState, up_bound, indexes);
}
template void RandomGenUniform<unsigned int>(const int size, curandState *globalState, const int up_bound,
unsigned int *indexes, cudaStream_t stream);
template void RandomGenUniform<size_t>(const int size, curandState *globalState, const int up_bound, size_t *indexes,
cudaStream_t stream);
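For reference, the (upBound - 1) scaling can be sanity-checked on the host. The snippet below is an illustrative sketch, not part of this commit; it only mirrors the device-side arithmetic: curand_uniform() returns a float in (0.0, 1.0], so scaling by (upBound - 1) and truncating keeps every generated index inside [0, upBound - 1].

#include <cassert>
#include <cstdio>

// Host-side mirror of the scaling used in SrandUniformInt (illustrative only).
int main() {
  const int upBound = 10;
  // Representative values from the (0.0, 1.0] range produced by curand_uniform().
  const float samples[] = {1e-7f, 0.5f, 1.0f};
  for (float u : samples) {
    int index = static_cast<int>(u * (upBound - 1));
    assert(index >= 0 && index < upBound);
    std::printf("u=%.7f -> index=%d\n", u, index);
  }
  return 0;
}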

View File

@@ -31,6 +31,8 @@ void BufferSample(const size_t size, const size_t one_element, const unsigned in
unsigned char *out, cudaStream_t cuda_stream);
void RandomGen(const int size, curandState *globalState, unsigned int *value, unsigned int *key, cudaStream_t stream);
void RandInit(const int size, const int seed, curandState *state, cudaStream_t stream);
void RandomGenUniform(const int size, curandState *globalState, const int up_bound, unsigned int *indexes,
template <typename T>
void RandomGenUniform(const int size, curandState *globalState, const int up_bound, T *indexes,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_

View File

@@ -0,0 +1,122 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/rl/reservoir_replay_buffer.h"
#include <curand_kernel.h>
#include <vector>
#include <memory>
#include <algorithm>
#include "kernel/kernel.h"
#include "plugin/device/gpu/hal/device/gpu_common.h"
#include "plugin/device/gpu/hal/device/gpu_memory_allocator.h"
#include "plugin/device/gpu/kernel/cuda_impl/rl/rl_buffer_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/rl/priority_replay_buffer.cuh"
namespace mindspore {
namespace kernel {
namespace gpu {
ReservoirReplayBuffer::ReservoirReplayBuffer(const uint64_t &seed, const size_t &capacity,
const std::vector<size_t> &schema) {
schema_ = schema;
seed_ = seed;
capacity_ = capacity;
// Init random generator.
generator_.seed(seed_);
// Allocate device memory.
auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance();
for (const auto &size : schema) {
fifo_replay_buffer_.emplace_back(static_cast<uint8_t *>(allocator.AllocTensorMem(size * capacity)));
}
}
ReservoirReplayBuffer::~ReservoirReplayBuffer() {
auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance();
if (indices_) {
allocator.FreeTensorMem(indices_);
indices_ = nullptr;
}
if (rand_state_) {
allocator.FreeTensorMem(rand_state_);
rand_state_ = nullptr;
}
for (auto item : fifo_replay_buffer_) {
allocator.FreeTensorMem(item);
}
}
bool ReservoirReplayBuffer::Insert(const size_t &pos, const std::vector<AddressPtr> &transition, cudaStream_t stream) {
for (size_t i = 0; i < transition.size(); i++) {
size_t offset = pos * schema_[i];
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(fifo_replay_buffer_[i] + offset, transition[i]->addr, schema_[i],
cudaMemcpyDeviceToDevice, stream),
"cudaMemcpyAsync failed.");
}
return true;
}
bool ReservoirReplayBuffer::Push(const std::vector<AddressPtr> &transition, cudaStream_t stream) {
// The buffer is not full: push the transition at the end of the buffer.
if (total_ < capacity_) {
auto ret = Insert(total_, transition, stream);
total_++;
return ret;
}
// The buffer is full: count the incoming sample, then either discard it or let it replace a
// uniformly chosen old one with probability capacity_ / total_ (reservoir sampling).
total_++;
auto replace_threshold = static_cast<float>(capacity_) / static_cast<float>(total_);
std::uniform_real_distribution<float> keep_dist(0, 1);
auto prob = keep_dist(generator_);
if (prob < replace_threshold) {
std::uniform_int_distribution<size_t> pos_dist(0, capacity_ - 1);
size_t pos = pos_dist(generator_);
return Insert(pos, transition, stream);
}
return true;
}
bool ReservoirReplayBuffer::Sample(const size_t &batch_size, const std::vector<AddressPtr> &transition,
cudaStream_t stream) {
if (!rand_state_) {
auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance();
rand_state_ = static_cast<curandState *>(allocator.AllocTensorMem(sizeof(curandState) * batch_size));
RandInit(batch_size, seed_, rand_state_, stream);
}
if (!indices_) {
auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance();
indices_ = static_cast<size_t *>(allocator.AllocTensorMem(sizeof(size_t) * batch_size));
}
size_t valid_size = std::min(total_, capacity_);
RandomGenUniform(batch_size, rand_state_, valid_size, indices_, stream);
for (size_t i = 0; i < schema_.size(); i++) {
auto output_addr = static_cast<uint8_t *>(transition[i]->addr);
FifoSlice(fifo_replay_buffer_[i], indices_, output_addr, batch_size, schema_[i], stream);
}
return true;
}
} // namespace gpu
} // namespace kernel
} // namespace mindspore
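The replacement rule in Push above is standard reservoir sampling (Algorithm R): once the buffer holds capacity items, the t-th pushed transition is kept with probability capacity / t and overwrites a uniformly chosen slot, so every pushed transition remains in the buffer with equal probability. A minimal host-only sketch of the same rule, illustrative only, with plain integers standing in for device-side transitions:

#include <cstdio>
#include <random>
#include <vector>

// Illustrative host-side reservoir sampling mirroring ReservoirReplayBuffer::Push.
int main() {
  const size_t capacity = 4;
  std::vector<int> reservoir;
  std::default_random_engine generator(42);
  std::uniform_real_distribution<float> keep_dist(0.0f, 1.0f);

  size_t total = 0;
  for (int item = 0; item < 100; ++item) {
    if (total < capacity) {
      // Buffer not full yet: append at the end.
      reservoir.push_back(item);
      ++total;
      continue;
    }
    // Buffer full: keep the item with probability capacity / total and
    // overwrite a uniformly chosen slot.
    ++total;
    float replace_threshold = static_cast<float>(capacity) / static_cast<float>(total);
    if (keep_dist(generator) < replace_threshold) {
      std::uniform_int_distribution<size_t> pos_dist(0, capacity - 1);
      reservoir[pos_dist(generator)] = item;
    }
  }
  for (int v : reservoir) std::printf("%d ", v);
  std::printf("\n");
  return 0;
}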

View File

@@ -0,0 +1,60 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_RL_RESERVOIR_REPLAY_BUFFER_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_RL_RESERVOIR_REPLAY_BUFFER_H_
#include <cuda_runtime_api.h>
#include <curand_kernel.h>
#include <vector>
#include <memory>
#include <random>
#include "kernel/kernel.h"
namespace mindspore {
namespace kernel {
namespace gpu {
class ReservoirReplayBuffer {
public:
// Construct a fixed-length reservoir replay buffer.
ReservoirReplayBuffer(const uint64_t &seed, const size_t &capacity, const std::vector<size_t> &schema);
~ReservoirReplayBuffer();
// Push an experience transition into the buffer. Once the buffer is full, the new transition
// replaces a uniformly chosen old one with probability capacity / total (reservoir sampling).
bool Push(const std::vector<AddressPtr> &transition, cudaStream_t stream);
// Sample a batch of transitions uniformly at random from the buffer.
bool Sample(const size_t &batch_size, const std::vector<AddressPtr> &transition, cudaStream_t stream);
private:
bool Insert(const size_t &pos, const std::vector<AddressPtr> &transition, cudaStream_t stream);
// Random generator
std::default_random_engine generator_;
curandState *rand_state_{nullptr};
uint64_t seed_{42};
size_t capacity_{0};
size_t total_{0};
size_t *indices_{nullptr};
std::vector<size_t> schema_;
std::vector<uint8_t *> fifo_replay_buffer_;
};
} // namespace gpu
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_RL_RESERVOIR_REPLAY_BUFFER_H_

View File

@@ -0,0 +1,210 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/rl/reservoir_replay_buffer_gpu_kernel.h"
#include <vector>
#include <random>
#include <algorithm>
#include <functional>
#include "mindspore/core/ops/reservoir_replay_buffer.h"
namespace mindspore {
namespace kernel {
using ReservoirReplayBufferFactory = ReplayBufferFactory<ReservoirReplayBuffer>;
ReservoirReplayBufferCreateGpuKernel::~ReservoirReplayBufferCreateGpuKernel() {
auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance();
if (handle_device_) {
allocator.FreeTensorMem(handle_device_);
}
}
bool ReservoirReplayBufferCreateGpuKernel::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::ReservoirReplayBufferCreate>(base_operator);
if (!kernel_ptr) {
MS_LOG(ERROR) << "cast ReservoirReplayBufferCreate ops failed!";
return false;
}
const int64_t &capacity = kernel_ptr->get_capacity();
const std::vector<int64_t> &schema = kernel_ptr->get_schema();
const int64_t &seed0 = kernel_ptr->get_seed0();
const int64_t &seed1 = kernel_ptr->get_seed1();
unsigned int seed = 0;
std::random_device rd;
if (seed1 != 0) {
seed = static_cast<unsigned int>(seed1);
} else if (seed0 != 0) {
seed = static_cast<unsigned int>(seed0);
} else {
seed = rd();
}
std::vector<size_t> schema_in_size;
std::transform(schema.begin(), schema.end(), std::back_inserter(schema_in_size),
[](const int64_t &arg) -> size_t { return LongToSize(arg); });
auto &factory = ReservoirReplayBufferFactory::GetInstance();
std::tie(handle_, reservoir_replay_buffer_) = factory.Create(seed, capacity, schema_in_size);
MS_EXCEPTION_IF_NULL(reservoir_replay_buffer_);
auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance();
handle_device_ = static_cast<int64_t *>(allocator.AllocTensorMem(sizeof(handle_)));
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpy(handle_device_, &handle_, sizeof(handle_), cudaMemcpyHostToDevice),
"cudaMemcpy failed.");
output_size_list_.push_back(sizeof(handle_));
return true;
}
std::vector<KernelAttr> ReservoirReplayBufferCreateGpuKernel::GetOpSupport() {
std::vector<KernelAttr> support_list = {KernelAttr().AddOutputAttr(kNumberTypeInt64)};
return support_list;
}
bool ReservoirReplayBufferCreateGpuKernel::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
auto handle = GetDeviceAddress<int64_t>(outputs, 0);
auto stream = reinterpret_cast<cudaStream_t>(stream_ptr);
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
cudaMemcpyAsync(handle, handle_device_, sizeof(handle_), cudaMemcpyDeviceToDevice, stream), "cudaMemcpy failed.");
return true;
}
ReservoirReplayBufferPushGpuKernel::~ReservoirReplayBufferPushGpuKernel() {
auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance();
if (handle_device_) {
allocator.FreeTensorMem(handle_device_);
}
}
bool ReservoirReplayBufferPushGpuKernel::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::ReservoirReplayBufferPush>(base_operator);
if (!kernel_ptr) {
MS_LOG(ERROR) << "cast ReservoirReplayBufferPush ops failed!";
return false;
}
handle_ = kernel_ptr->get_handle();
reservoir_replay_buffer_ = ReservoirReplayBufferFactory::GetInstance().GetByHandle(handle_);
MS_EXCEPTION_IF_NULL(reservoir_replay_buffer_);
auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance();
handle_device_ = static_cast<int64_t *>(allocator.AllocTensorMem(sizeof(handle_)));
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpy(handle_device_, &handle_, sizeof(handle_), cudaMemcpyHostToDevice),
"cudaMemcpy failed.");
for (size_t i = 0; i < inputs.size(); i++) {
TypeId type_id = inputs[i]->GetDtype();
size_t type_size = GetTypeByte(TypeIdToType(type_id));
const std::vector<int64_t> &shape = inputs[i]->GetShapeVector();
size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
input_size_list_.push_back(tensor_size);
}
output_size_list_.push_back(sizeof(handle_));
return true;
}
bool ReservoirReplayBufferPushGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
auto stream = reinterpret_cast<cudaStream_t>(stream_ptr);
// Return a placeholder output so that dead-code-elimination passes do not optimize the kernel away.
auto handle = GetDeviceAddress<int64_t>(outputs, 0);
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
cudaMemcpyAsync(handle, handle_device_, sizeof(handle_), cudaMemcpyDeviceToDevice, stream), "cudaMemcpy failed.");
return reservoir_replay_buffer_->Push(inputs, stream);
}
std::vector<KernelAttr> ReservoirReplayBufferPushGpuKernel::GetOpSupport() {
std::vector<KernelAttr> support_list = {KernelAttr().AddSkipCheckAttr(true)};
return support_list;
}
bool ReservoirReplayBufferSampleGpuKernel::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::ReservoirReplayBufferSample>(base_operator);
if (!kernel_ptr) {
MS_LOG(ERROR) << "cast ReservoirReplayBufferSample ops failed!";
return false;
}
handle_ = kernel_ptr->get_handle();
batch_size_ = kernel_ptr->get_batch_size();
reservoir_replay_buffer_ = ReservoirReplayBufferFactory::GetInstance().GetByHandle(handle_);
MS_EXCEPTION_IF_NULL(reservoir_replay_buffer_);
for (size_t i = 0; i < outputs.size(); i++) {
TypeId type_id = outputs[i]->GetDtype();
size_t type_size = GetTypeByte(TypeIdToType(type_id));
const std::vector<int64_t> &shape = outputs[i]->GetShapeVector();
size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
output_size_list_.push_back(tensor_size);
}
return true;
}
bool ReservoirReplayBufferSampleGpuKernel::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
return reservoir_replay_buffer_->Sample(batch_size_, outputs, reinterpret_cast<cudaStream_t>(stream_ptr));
}
std::vector<KernelAttr> ReservoirReplayBufferSampleGpuKernel::GetOpSupport() {
std::vector<KernelAttr> support_list = {KernelAttr().AddSkipCheckAttr(true)};
return support_list;
}
bool ReservoirReplayBufferDestroyGpuKernel::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::ReservoirReplayBufferDestroy>(base_operator);
if (!kernel_ptr) {
MS_LOG(ERROR) << "cast ReservoirReplayBufferDestroy ops failed!";
return false;
}
handle_ = kernel_ptr->get_handle();
output_size_list_.push_back(sizeof(handle_));
return true;
}
bool ReservoirReplayBufferDestroyGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
ReservoirReplayBufferFactory::GetInstance().Delete(handle_);
// A host-to-device memory copy is acceptable here since this is not a performance-critical path.
auto handle = GetDeviceAddress<int64_t>(outputs, 0);
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(handle, &handle_, sizeof(handle_), cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpy failed.");
return true;
}
std::vector<KernelAttr> ReservoirReplayBufferDestroyGpuKernel::GetOpSupport() {
std::vector<KernelAttr> support_list = {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64)};
return support_list;
}
} // namespace kernel
} // namespace mindspore
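The four kernels above share one buffer instance through an int64 handle: Create allocates the buffer via ReplayBufferFactory and publishes the handle, Push and Sample look the instance up with GetByHandle, and Destroy releases it with Delete. As an illustration of that pattern only (the real ReplayBufferFactory interface may differ), a minimal handle registry could look like this:

#include <cstdint>
#include <map>
#include <memory>
#include <tuple>
#include <utility>

// Minimal handle-based registry illustrating how Create/Push/Sample/Destroy can
// share one buffer instance via an int64 handle (sketch only; not the real
// ReplayBufferFactory interface).
template <typename Buffer>
class HandleRegistry {
 public:
  template <typename... Args>
  std::tuple<int64_t, std::shared_ptr<Buffer>> Create(Args &&... args) {
    auto buffer = std::make_shared<Buffer>(std::forward<Args>(args)...);
    int64_t handle = next_handle_++;
    buffers_[handle] = buffer;
    return std::make_tuple(handle, buffer);
  }
  std::shared_ptr<Buffer> GetByHandle(int64_t handle) const {
    auto iter = buffers_.find(handle);
    return iter == buffers_.end() ? nullptr : iter->second;
  }
  void Delete(int64_t handle) { buffers_.erase(handle); }

 private:
  int64_t next_handle_{0};
  std::map<int64_t, std::shared_ptr<Buffer>> buffers_;
};

A Push kernel would then recover the shared instance with GetByHandle(handle_) during Init, as ReservoirReplayBufferPushGpuKernel does above.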

View File

@@ -0,0 +1,114 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_RL_RESERVOIR_REPLAY_BUFFER_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_RL_RESERVOIR_REPLAY_BUFFER_GPU_KERNEL_H_
#include "plugin/device/gpu/kernel/rl/reservoir_replay_buffer.h"
#include <vector>
#include <memory>
#include "plugin/factory/replay_buffer_factory.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
using gpu::ReservoirReplayBuffer;
class ReservoirReplayBufferCreateGpuKernel : public NativeGpuKernelMod {
public:
ReservoirReplayBufferCreateGpuKernel() = default;
~ReservoirReplayBufferCreateGpuKernel() override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream) override;
std::vector<KernelAttr> GetOpSupport() override;
private:
int64_t handle_{-1};
int64_t *handle_device_{nullptr};
std::shared_ptr<ReservoirReplayBuffer> reservoir_replay_buffer_{nullptr};
};
class ReservoirReplayBufferPushGpuKernel : public NativeGpuKernelMod {
public:
ReservoirReplayBufferPushGpuKernel() = default;
~ReservoirReplayBufferPushGpuKernel() override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream) override;
std::vector<KernelAttr> GetOpSupport() override;
private:
int64_t handle_{-1};
int64_t *handle_device_{nullptr};
std::shared_ptr<ReservoirReplayBuffer> reservoir_replay_buffer_{nullptr};
};
class ReservoirReplayBufferSampleGpuKernel : public NativeGpuKernelMod {
public:
ReservoirReplayBufferSampleGpuKernel() = default;
~ReservoirReplayBufferSampleGpuKernel() override = default;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream) override;
std::vector<KernelAttr> GetOpSupport() override;
private:
int64_t handle_{-1};
size_t batch_size_{0};
std::vector<size_t> schema_;
std::shared_ptr<ReservoirReplayBuffer> reservoir_replay_buffer_{nullptr};
};
class ReservoirReplayBufferDestroyGpuKernel : public NativeGpuKernelMod {
public:
ReservoirReplayBufferDestroyGpuKernel() = default;
~ReservoirReplayBufferDestroyGpuKernel() override = default;
// Init kernel from CNode.
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
// Execute kernel.
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream) override;
std::vector<KernelAttr> GetOpSupport() override;
private:
int64_t handle_{-1};
};
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, ReservoirReplayBufferCreate, ReservoirReplayBufferCreateGpuKernel);
MS_REG_GPU_KERNEL(ReservoirReplayBufferPush, ReservoirReplayBufferPushGpuKernel)
MS_REG_GPU_KERNEL(ReservoirReplayBufferSample, ReservoirReplayBufferSampleGpuKernel)
MS_REG_GPU_KERNEL(ReservoirReplayBufferDestroy, ReservoirReplayBufferDestroyGpuKernel)
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_RL_RESERVOIR_REPLAY_BUFFER_GPU_KERNEL_H_