!39199 reservoir replay buffer gpu kernel
Merge pull request !39199 from chenweifeng/reservoir-replay-buffer-gpu-kernel
Commit 6bbdfb3c5a
@@ -108,11 +108,12 @@ __global__ void SrandUInt(const int size, curandState *globalState, unsigned int
   }
 }
 
-__global__ void SrandUniformInt(const int size, curandState *globalState, const int upBound, unsigned int *out) {
+template <typename T>
+__global__ void SrandUniformInt(const int size, curandState *globalState, const int upBound, T *out) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
     // curand_uniform returns pseudorandom floats uniformly distributed between 0.0 and 1.0, where 1.0 is
     // included and 0.0 is excluded. So decrease upBound by 1 to avoid going out of range.
-    out[i] = static_cast<unsigned int>(curand_uniform(&globalState[i]) * (upBound - 1));
+    out[i] = static_cast<T>(curand_uniform(&globalState[i]) * (upBound - 1));
   }
 }
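The kernel above maps a curand_uniform draw in (0.0, 1.0] onto an integer index in [0, upBound - 1]. Below is a minimal host-side sketch of the same mapping using the C++ standard library (illustrative only, not part of the patch; up_bound and the seed are placeholder values):

// Host-side illustration of the index mapping used by SrandUniformInt: a draw u is
// scaled by (upBound - 1) and truncated, so every produced index stays in range.
#include <cstdio>
#include <random>

int main() {
  const int up_bound = 10;  // e.g. number of valid transitions in the buffer
  std::mt19937 gen(42);
  // std::uniform_real_distribution draws from [0.0, 1.0); the edge behaviour differs
  // slightly from curand_uniform's (0.0, 1.0], but the scaling is the same.
  std::uniform_real_distribution<float> dist(0.0f, 1.0f);
  for (int i = 0; i < 5; ++i) {
    float u = dist(gen);
    unsigned int index = static_cast<unsigned int>(u * (up_bound - 1));
    std::printf("u=%f -> index=%u\n", u, index);  // always within [0, up_bound - 1]
  }
  return 0;
}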
@@ -159,7 +160,12 @@ void RandomGen(const int size, curandState *globalState, unsigned int *value, un
   thrust::sort_by_key(policy, dev_key_ptr, dev_key_ptr + size, dev_data_ptr);
 }
 
-void RandomGenUniform(const int size, curandState *globalState, const int up_bound, unsigned int *indexes,
-                      cudaStream_t stream) {
+template <typename T>
+void RandomGenUniform(const int size, curandState *globalState, const int up_bound, T *indexes, cudaStream_t stream) {
   SrandUniformInt<<<(size + 255) / 256, 256, 0, stream>>>(size, globalState, up_bound, indexes);
 }
+
+template void RandomGenUniform<unsigned int>(const int size, curandState *globalState, const int up_bound,
+                                             unsigned int *indexes, cudaStream_t stream);
+template void RandomGenUniform<size_t>(const int size, curandState *globalState, const int up_bound, size_t *indexes,
+                                       cudaStream_t stream);
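Because the RandomGenUniform template is defined in a .cu translation unit while callers only see the declaration in the header, the two explicit instantiations above are what make the unsigned int and size_t versions available at link time. A minimal .cu sketch of the pattern, using hypothetical Fill/FillKernel names:

// Sketch of explicit template instantiation in a CUDA translation unit: the template
// body is only visible here, so every T other object files link against must be
// instantiated explicitly, otherwise callers get an undefined-symbol link error.
#include <cuda_runtime.h>
#include <cstddef>

template <typename T>
__global__ void FillKernel(T *out, int size) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    out[i] = static_cast<T>(i);
  }
}

// Host wrapper declared in a header and defined only in this .cu file.
template <typename T>
void Fill(T *out, int size, cudaStream_t stream) {
  FillKernel<<<(size + 255) / 256, 256, 0, stream>>>(out, size);
}

// Explicit instantiations for the element types callers actually use.
template void Fill<unsigned int>(unsigned int *out, int size, cudaStream_t stream);
template void Fill<size_t>(size_t *out, int size, cudaStream_t stream);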
@@ -31,6 +31,8 @@ void BufferSample(const size_t size, const size_t one_element, const unsigned in
                    unsigned char *out, cudaStream_t cuda_stream);
 void RandomGen(const int size, curandState *globalState, unsigned int *value, unsigned int *key, cudaStream_t stream);
 void RandInit(const int size, const int seed, curandState *state, cudaStream_t stream);
-void RandomGenUniform(const int size, curandState *globalState, const int up_bound, unsigned int *indexes,
+
+template <typename T>
+void RandomGenUniform(const int size, curandState *globalState, const int up_bound, T *indexes,
                       cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_
@@ -0,0 +1,122 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/rl/reservoir_replay_buffer.h"

#include <curand_kernel.h>
#include <vector>
#include <memory>
#include <algorithm>
#include "kernel/kernel.h"
#include "plugin/device/gpu/hal/device/gpu_common.h"
#include "plugin/device/gpu/hal/device/gpu_memory_allocator.h"
#include "plugin/device/gpu/kernel/cuda_impl/rl/rl_buffer_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/rl/priority_replay_buffer.cuh"

namespace mindspore {
namespace kernel {
namespace gpu {
ReservoirReplayBuffer::ReservoirReplayBuffer(const uint64_t &seed, const size_t &capacity,
                                             const std::vector<size_t> &schema) {
  schema_ = schema;
  seed_ = seed;
  capacity_ = capacity;

  // Init random generator.
  generator_.seed(seed_);

  // Allocate device memory.
  auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance();
  for (const auto &size : schema) {
    fifo_replay_buffer_.emplace_back(static_cast<uint8_t *>(allocator.AllocTensorMem(size * capacity)));
  }
}

ReservoirReplayBuffer::~ReservoirReplayBuffer() {
  auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance();
  if (indices_) {
    allocator.FreeTensorMem(indices_);
    indices_ = nullptr;
  }

  if (rand_state_) {
    allocator.FreeTensorMem(rand_state_);
    rand_state_ = nullptr;
  }

  for (auto item : fifo_replay_buffer_) {
    allocator.FreeTensorMem(item);
  }
}

bool ReservoirReplayBuffer::Insert(const size_t &pos, const std::vector<AddressPtr> &transition, cudaStream_t stream) {
  for (size_t i = 0; i < transition.size(); i++) {
    size_t offset = pos * schema_[i];
    CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(fifo_replay_buffer_[i] + offset, transition[i]->addr, schema_[i],
                                                      cudaMemcpyDeviceToDevice, stream),
                                      "cudaMemcpyAsync failed.");
  }
  return true;
}

bool ReservoirReplayBuffer::Push(const std::vector<AddressPtr> &transition, cudaStream_t stream) {
  // The buffer is not full: push the transition at the end of the buffer.
  if (total_ < capacity_) {
    auto ret = Insert(total_, transition, stream);
    total_++;
    return ret;
  }

  // The buffer is full: randomly discard this sample or replace an old one.
  auto replace_threshold = static_cast<float>(capacity_) / static_cast<float>(total_);
  std::uniform_real_distribution<float> keep_dist(0, 1);
  auto prob = keep_dist(generator_);
  if (prob < replace_threshold) {
    total_++;
    std::uniform_int_distribution<size_t> pos_dist(0, capacity_ - 1);
    size_t pos = pos_dist(generator_);
    return Insert(pos, transition, stream);
  }

  return true;
}

bool ReservoirReplayBuffer::Sample(const size_t &batch_size, const std::vector<AddressPtr> &transition,
                                   cudaStream_t stream) {
  if (!rand_state_) {
    auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance();
    rand_state_ = static_cast<curandState *>(allocator.AllocTensorMem(sizeof(curandState) * batch_size));
    RandInit(batch_size, seed_, rand_state_, stream);
  }

  if (!indices_) {
    auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance();
    indices_ = static_cast<size_t *>(allocator.AllocTensorMem(sizeof(size_t) * batch_size));
  }

  size_t valid_size = std::min(total_, capacity_);
  RandomGenUniform(batch_size, rand_state_, valid_size, indices_, stream);

  for (size_t i = 0; i < schema_.size(); i++) {
    auto output_addr = static_cast<uint8_t *>(transition[i]->addr);
    FifoSlice(fifo_replay_buffer_[i], indices_, output_addr, batch_size, schema_[i], stream);
  }

  return true;
}
}  // namespace gpu
}  // namespace kernel
}  // namespace mindspore
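For reference, the keep-or-replace decision in ReservoirReplayBuffer::Push can be read in isolation from the device-memory plumbing. A minimal host-only sketch of the same logic, with a plain std::vector standing in for the GPU-side storage (illustrative, not part of the patch):

// Host-only sketch of the keep-or-replace decision used in ReservoirReplayBuffer::Push.
#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

class TinyReservoir {
 public:
  explicit TinyReservoir(size_t capacity, uint64_t seed = 42) : capacity_(capacity) { generator_.seed(seed); }

  void Push(int transition) {
    // Buffer not full yet: always keep the new transition.
    if (total_ < capacity_) {
      buffer_.push_back(transition);
      total_++;
      return;
    }
    // Buffer full: keep the new transition with probability capacity / total,
    // overwriting a uniformly chosen old slot (mirrors the Push logic above).
    float replace_threshold = static_cast<float>(capacity_) / static_cast<float>(total_);
    std::uniform_real_distribution<float> keep_dist(0.0f, 1.0f);
    if (keep_dist(generator_) < replace_threshold) {
      total_++;
      std::uniform_int_distribution<size_t> pos_dist(0, capacity_ - 1);
      buffer_[pos_dist(generator_)] = transition;
    }
  }

 private:
  size_t capacity_;
  size_t total_{0};
  std::default_random_engine generator_;
  std::vector<int> buffer_;
};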
@@ -0,0 +1,60 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_RL_RESERVOIR_REPLAY_BUFFER_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_RL_RESERVOIR_REPLAY_BUFFER_H_

#include <cuda_runtime_api.h>
#include <curand_kernel.h>
#include <vector>
#include <memory>
#include <random>
#include "kernel/kernel.h"

namespace mindspore {
namespace kernel {
namespace gpu {
class ReservoirReplayBuffer {
 public:
  // Construct a fixed-length reservoir replay buffer.
  ReservoirReplayBuffer(const uint64_t &seed, const size_t &capacity, const std::vector<size_t> &schema);
  ~ReservoirReplayBuffer();

  // Push an experience transition into the reservoir buffer.
  bool Push(const std::vector<AddressPtr> &transition, cudaStream_t stream);

  // Sample a batch of transitions uniformly from the valid part of the buffer.
  bool Sample(const size_t &batch_size, const std::vector<AddressPtr> &transition, cudaStream_t stream);

 private:
  bool Insert(const size_t &pos, const std::vector<AddressPtr> &transition, cudaStream_t stream);

  // Random generator
  std::default_random_engine generator_;
  curandState *rand_state_{nullptr};

  uint64_t seed_{42};
  size_t capacity_{0};
  size_t total_{0};
  size_t *indices_{nullptr};
  std::vector<size_t> schema_;
  std::vector<uint8_t *> fifo_replay_buffer_;
};
}  // namespace gpu
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_RL_RESERVOIR_REPLAY_BUFFER_H_
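The schema_/fifo_replay_buffer_ pair declared above describes a column-per-tensor layout: entry i of the schema is the byte size of one tensor in a transition, and the matching buffer is a single contiguous allocation of capacity * schema[i] bytes, so transition pos starts at byte offset pos * schema[i]. A small host-only sketch of that layout (illustrative assumptions only, sizes are placeholders):

// Host-only sketch of the per-column byte layout implied by schema_ and fifo_replay_buffer_.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

int main() {
  const size_t capacity = 4;
  const std::vector<size_t> schema = {8, 16};  // e.g. two tensors of 8 and 16 bytes per transition

  // One flat byte buffer per schema entry, as in the constructor above.
  std::vector<std::vector<uint8_t>> columns;
  for (size_t item_size : schema) {
    columns.emplace_back(item_size * capacity);
  }

  // Writing transition `pos` mirrors ReservoirReplayBuffer::Insert.
  size_t pos = 2;
  std::vector<uint8_t> payload0(schema[0], 0xAB);
  std::memcpy(columns[0].data() + pos * schema[0], payload0.data(), schema[0]);
  return 0;
}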
@@ -0,0 +1,210 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/rl/reservoir_replay_buffer_gpu_kernel.h"
#include <vector>
#include <random>
#include <algorithm>
#include <functional>
#include "mindspore/core/ops/reservoir_replay_buffer.h"

namespace mindspore {
namespace kernel {
using ReservoirReplayBufferFactory = ReplayBufferFactory<ReservoirReplayBuffer>;

ReservoirReplayBufferCreateGpuKernel::~ReservoirReplayBufferCreateGpuKernel() {
  auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance();
  if (handle_device_) {
    allocator.FreeTensorMem(handle_device_);
  }
}

bool ReservoirReplayBufferCreateGpuKernel::Init(const BaseOperatorPtr &base_operator,
                                                const std::vector<KernelTensorPtr> &inputs,
                                                const std::vector<KernelTensorPtr> &outputs) {
  auto kernel_ptr = std::dynamic_pointer_cast<ops::ReservoirReplayBufferCreate>(base_operator);
  if (!kernel_ptr) {
    MS_LOG(ERROR) << "cast ReservoirReplayBufferCreate ops failed!";
    return false;
  }

  const int64_t &capacity = kernel_ptr->get_capacity();
  const std::vector<int64_t> &schema = kernel_ptr->get_schema();
  const int64_t &seed0 = kernel_ptr->get_seed0();
  const int64_t &seed1 = kernel_ptr->get_seed1();

  unsigned int seed = 0;
  std::random_device rd;
  if (seed1 != 0) {
    seed = static_cast<unsigned int>(seed1);
  } else if (seed0 != 0) {
    seed = static_cast<unsigned int>(seed0);
  } else {
    seed = rd();
  }

  std::vector<size_t> schema_in_size;
  std::transform(schema.begin(), schema.end(), std::back_inserter(schema_in_size),
                 [](const int64_t &arg) -> size_t { return LongToSize(arg); });

  auto &factory = ReservoirReplayBufferFactory::GetInstance();
  std::tie(handle_, reservoir_replay_buffer_) = factory.Create(seed, capacity, schema_in_size);
  MS_EXCEPTION_IF_NULL(reservoir_replay_buffer_);

  auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance();
  handle_device_ = static_cast<int64_t *>(allocator.AllocTensorMem(sizeof(handle_)));
  CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpy(handle_device_, &handle_, sizeof(handle_), cudaMemcpyHostToDevice),
                                    "cudaMemcpy failed.");

  output_size_list_.push_back(sizeof(handle_));
  return true;
}

std::vector<KernelAttr> ReservoirReplayBufferCreateGpuKernel::GetOpSupport() {
  std::vector<KernelAttr> support_list = {KernelAttr().AddOutputAttr(kNumberTypeInt64)};
  return support_list;
}

bool ReservoirReplayBufferCreateGpuKernel::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
                                                  const std::vector<AddressPtr> &outputs, void *stream_ptr) {
  auto handle = GetDeviceAddress<int64_t>(outputs, 0);
  auto stream = reinterpret_cast<cudaStream_t>(stream_ptr);
  CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
    cudaMemcpyAsync(handle, handle_device_, sizeof(handle_), cudaMemcpyDeviceToDevice, stream), "cudaMemcpy failed.");
  return true;
}

ReservoirReplayBufferPushGpuKernel::~ReservoirReplayBufferPushGpuKernel() {
  auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance();
  if (handle_device_) {
    allocator.FreeTensorMem(handle_device_);
  }
}

bool ReservoirReplayBufferPushGpuKernel::Init(const BaseOperatorPtr &base_operator,
                                              const std::vector<KernelTensorPtr> &inputs,
                                              const std::vector<KernelTensorPtr> &outputs) {
  auto kernel_ptr = std::dynamic_pointer_cast<ops::ReservoirReplayBufferPush>(base_operator);
  if (!kernel_ptr) {
    MS_LOG(ERROR) << "cast ReservoirReplayBufferPush ops failed!";
    return false;
  }

  handle_ = kernel_ptr->get_handle();
  reservoir_replay_buffer_ = ReservoirReplayBufferFactory::GetInstance().GetByHandle(handle_);
  MS_EXCEPTION_IF_NULL(reservoir_replay_buffer_);

  auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance();
  handle_device_ = static_cast<int64_t *>(allocator.AllocTensorMem(sizeof(handle_)));
  CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpy(handle_device_, &handle_, sizeof(handle_), cudaMemcpyHostToDevice),
                                    "cudaMemcpy failed.");

  for (size_t i = 0; i < inputs.size(); i++) {
    TypeId type_id = inputs[i]->GetDtype();
    size_t type_size = GetTypeByte(TypeIdToType(type_id));
    const std::vector<int64_t> &shape = inputs[i]->GetShapeVector();
    size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
    input_size_list_.push_back(tensor_size);
  }

  output_size_list_.push_back(sizeof(handle_));
  return true;
}

bool ReservoirReplayBufferPushGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
                                                const std::vector<AddressPtr> &outputs, void *stream_ptr) {
  auto stream = reinterpret_cast<cudaStream_t>(stream_ptr);
  // Return a placeholder handle so that dead-code-elimination passes do not remove this op.
  auto handle = GetDeviceAddress<int64_t>(outputs, 0);
  CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
    cudaMemcpyAsync(handle, handle_device_, sizeof(handle_), cudaMemcpyDeviceToDevice, stream), "cudaMemcpy failed.");

  return reservoir_replay_buffer_->Push(inputs, stream);
}

std::vector<KernelAttr> ReservoirReplayBufferPushGpuKernel::GetOpSupport() {
  std::vector<KernelAttr> support_list = {KernelAttr().AddSkipCheckAttr(true)};
  return support_list;
}

bool ReservoirReplayBufferSampleGpuKernel::Init(const BaseOperatorPtr &base_operator,
                                                const std::vector<KernelTensorPtr> &inputs,
                                                const std::vector<KernelTensorPtr> &outputs) {
  auto kernel_ptr = std::dynamic_pointer_cast<ops::ReservoirReplayBufferSample>(base_operator);
  if (!kernel_ptr) {
    MS_LOG(ERROR) << "cast ReservoirReplayBufferSample ops failed!";
    return false;
  }

  handle_ = kernel_ptr->get_handle();
  batch_size_ = kernel_ptr->get_batch_size();
  reservoir_replay_buffer_ = ReservoirReplayBufferFactory::GetInstance().GetByHandle(handle_);
  MS_EXCEPTION_IF_NULL(reservoir_replay_buffer_);

  for (size_t i = 0; i < outputs.size(); i++) {
    TypeId type_id = outputs[i]->GetDtype();
    size_t type_size = GetTypeByte(TypeIdToType(type_id));
    const std::vector<int64_t> &shape = outputs[i]->GetShapeVector();
    size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
    output_size_list_.push_back(tensor_size);
  }

  return true;
}

bool ReservoirReplayBufferSampleGpuKernel::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
                                                  const std::vector<AddressPtr> &outputs, void *stream_ptr) {
  return reservoir_replay_buffer_->Sample(batch_size_, outputs, reinterpret_cast<cudaStream_t>(stream_ptr));
}

std::vector<KernelAttr> ReservoirReplayBufferSampleGpuKernel::GetOpSupport() {
  std::vector<KernelAttr> support_list = {KernelAttr().AddSkipCheckAttr(true)};
  return support_list;
}

bool ReservoirReplayBufferDestroyGpuKernel::Init(const BaseOperatorPtr &base_operator,
                                                 const std::vector<KernelTensorPtr> &inputs,
                                                 const std::vector<KernelTensorPtr> &outputs) {
  auto kernel_ptr = std::dynamic_pointer_cast<ops::ReservoirReplayBufferDestroy>(base_operator);
  if (!kernel_ptr) {
    MS_LOG(ERROR) << "cast ReservoirReplayBufferDestroy ops failed!";
    return false;
  }

  handle_ = kernel_ptr->get_handle();
  output_size_list_.push_back(sizeof(handle_));
  return true;
}

bool ReservoirReplayBufferDestroyGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
                                                   const std::vector<AddressPtr> &,
                                                   const std::vector<AddressPtr> &outputs, void *stream_ptr) {
  ReservoirReplayBufferFactory::GetInstance().Delete(handle_);

  // A host-to-device memory copy is acceptable here since this is not a performance-critical path.
  auto handle = GetDeviceAddress<int64_t>(outputs, 0);
  CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(handle, &handle_, sizeof(handle_), cudaMemcpyHostToDevice,
                                                    reinterpret_cast<cudaStream_t>(stream_ptr)),
                                    "cudaMemcpy failed.");
  return true;
}

std::vector<KernelAttr> ReservoirReplayBufferDestroyGpuKernel::GetOpSupport() {
  std::vector<KernelAttr> support_list = {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64)};
  return support_list;
}
}  // namespace kernel
}  // namespace mindspore
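The four kernels above share one handle-based lifecycle: Create builds a buffer through ReplayBufferFactory and returns an int64 handle, Push and Sample look the buffer up by that handle, and Destroy deletes it. The toy factory below sketches that pattern; ToyFactory and its members are hypothetical, and the real ReplayBufferFactory in plugin/factory/replay_buffer_factory.h may differ in detail:

// Toy host-side sketch of a handle-based instance registry.
#include <cstdint>
#include <map>
#include <memory>
#include <tuple>
#include <utility>

template <typename T>
class ToyFactory {
 public:
  static ToyFactory &GetInstance() {
    static ToyFactory instance;
    return instance;
  }

  // Create a new instance and hand back (handle, instance).
  template <typename... Args>
  std::tuple<int64_t, std::shared_ptr<T>> Create(Args &&... args) {
    auto instance = std::make_shared<T>(std::forward<Args>(args)...);
    instances_[++last_handle_] = instance;
    return {last_handle_, instance};
  }

  // Look an existing instance up by handle; nullptr if it was never created or already destroyed.
  std::shared_ptr<T> GetByHandle(int64_t handle) const {
    auto iter = instances_.find(handle);
    return iter == instances_.end() ? nullptr : iter->second;
  }

  void Delete(int64_t handle) { instances_.erase(handle); }

 private:
  int64_t last_handle_{-1};
  std::map<int64_t, std::shared_ptr<T>> instances_;
};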
@@ -0,0 +1,114 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_RL_RESERVOIR_REPLAY_BUFFER_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_RL_RESERVOIR_REPLAY_BUFFER_GPU_KERNEL_H_

#include "plugin/device/gpu/kernel/rl/reservoir_replay_buffer.h"
#include <vector>
#include <memory>
#include "plugin/factory/replay_buffer_factory.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
using gpu::ReservoirReplayBuffer;
class ReservoirReplayBufferCreateGpuKernel : public NativeGpuKernelMod {
 public:
  ReservoirReplayBufferCreateGpuKernel() = default;
  ~ReservoirReplayBufferCreateGpuKernel() override;

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream) override;

  std::vector<KernelAttr> GetOpSupport() override;

 private:
  int64_t handle_{-1};
  int64_t *handle_device_{nullptr};
  std::shared_ptr<ReservoirReplayBuffer> reservoir_replay_buffer_{nullptr};
};

class ReservoirReplayBufferPushGpuKernel : public NativeGpuKernelMod {
 public:
  ReservoirReplayBufferPushGpuKernel() = default;
  ~ReservoirReplayBufferPushGpuKernel() override;

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream) override;

  std::vector<KernelAttr> GetOpSupport() override;

 private:
  int64_t handle_{-1};
  int64_t *handle_device_{nullptr};
  std::shared_ptr<ReservoirReplayBuffer> reservoir_replay_buffer_{nullptr};
};

class ReservoirReplayBufferSampleGpuKernel : public NativeGpuKernelMod {
 public:
  ReservoirReplayBufferSampleGpuKernel() = default;
  ~ReservoirReplayBufferSampleGpuKernel() override = default;

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream) override;

  std::vector<KernelAttr> GetOpSupport() override;

 private:
  int64_t handle_{-1};
  size_t batch_size_{0};
  std::vector<size_t> schema_;
  std::shared_ptr<ReservoirReplayBuffer> reservoir_replay_buffer_{nullptr};
};

class ReservoirReplayBufferDestroyGpuKernel : public NativeGpuKernelMod {
 public:
  ReservoirReplayBufferDestroyGpuKernel() = default;
  ~ReservoirReplayBufferDestroyGpuKernel() override = default;

  // Init kernel from CNode.
  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  // Execute kernel.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream) override;

  std::vector<KernelAttr> GetOpSupport() override;

 private:
  int64_t handle_{-1};
};

MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, ReservoirReplayBufferCreate, ReservoirReplayBufferCreateGpuKernel);
MS_REG_GPU_KERNEL(ReservoirReplayBufferPush, ReservoirReplayBufferPushGpuKernel)
MS_REG_GPU_KERNEL(ReservoirReplayBufferSample, ReservoirReplayBufferSampleGpuKernel)
MS_REG_GPU_KERNEL(ReservoirReplayBufferDestroy, ReservoirReplayBufferDestroyGpuKernel)
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_RL_RESERVOIR_REPLAY_BUFFER_GPU_KERNEL_H_