!40197 PriorityReplayBuffer support beta annealing

Merge pull request !40197 from chenweifeng/priority-replay-buffer-improve
i-robot 2022-08-12 01:21:40 +00:00 committed by Gitee
commit 67fc9d00c1
17 changed files with 98 additions and 102 deletions
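
Summary of the change: the importance-sampling exponent beta is no longer fixed at buffer creation. The beta attribute is dropped from PriorityReplayBufferCreate and from the PriorityReplayBuffer constructors, and Sample now takes beta as a runtime input on the AICPU, CPU and GPU backends, so callers can anneal it over training. A minimal sketch of how a caller might drive the new interface; the linear_beta schedule and its hyperparameters are illustrative only and not part of this commit, and prb stands for the PriorityReplayBuffer cell defined in the tests further down:

def linear_beta(step, total_steps, beta_start=0.4, beta_end=1.0):
    """Anneal the prioritized-replay importance-sampling exponent towards 1.0."""
    frac = min(step / float(total_steps), 1.0)
    return beta_start + frac * (beta_end - beta_start)

# With the updated operators, beta is passed per sample call, e.g.:
#   indices, weights, states, actions = prb.sample(linear_beta(step, total_steps=10000))
for step in (0, 2500, 5000, 10000):
    print(step, linear_beta(step, total_steps=10000))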

View File

@ -46,9 +46,9 @@ size_t PriorityTree::GetPrefixSumIdx(float prefix_sum) const {
return idx - capacity_;
}
PriorityReplayBuffer::PriorityReplayBuffer(uint32_t seed, float alpha, float beta, size_t capacity,
PriorityReplayBuffer::PriorityReplayBuffer(uint32_t seed, float alpha, size_t capacity,
const std::vector<size_t> &schema)
: alpha_(alpha), beta_(beta), capacity_(capacity), max_priority_(1.0), schema_(schema) {
: alpha_(alpha), capacity_(capacity), max_priority_(1.0), schema_(schema) {
random_engine_.seed(seed);
fifo_replay_buffer_ = std::make_unique<FIFOReplayBuffer>(capacity, schema);
priority_tree_ = std::make_unique<PriorityTree>(capacity);
@ -85,7 +85,7 @@ bool PriorityReplayBuffer::UpdatePriorities(const std::vector<size_t> &indices,
}
std::tuple<std::vector<size_t>, std::vector<float>, std::vector<std::vector<AddressPtr>>> PriorityReplayBuffer::Sample(
size_t batch_size) {
size_t batch_size, float beta) {
if (batch_size == 0) {
AICPU_LOGD("The batch size can not be zero.");
}
@ -93,7 +93,7 @@ std::tuple<std::vector<size_t>, std::vector<float>, std::vector<std::vector<Addr
float sum_priority = root.sum_priority;
float min_priority = root.min_priority;
size_t size = fifo_replay_buffer_->size();
float max_weight = Weight(min_priority, sum_priority, size);
float max_weight = Weight(min_priority, sum_priority, size, beta);
float segment_len = root.sum_priority / batch_size;
std::vector<size_t> indices;
@ -110,20 +110,20 @@ std::tuple<std::vector<size_t>, std::vector<float>, std::vector<std::vector<Addr
AICPU_LOGW("The sum priority is %f. It may leads to converge issue.");
max_weight = std::numeric_limits<decltype(max_weight)>::epsilon();
}
(void)weights.emplace_back(Weight(priority, sum_priority, size) / max_weight);
(void)weights.emplace_back(Weight(priority, sum_priority, size, beta) / max_weight);
(void)items.emplace_back(fifo_replay_buffer_->GetItem(idx));
}
return std::forward_as_tuple(indices, weights, items);
}
inline float PriorityReplayBuffer::Weight(float priority, float sum_priority, size_t size) const {
inline float PriorityReplayBuffer::Weight(float priority, float sum_priority, size_t size, float beta) const {
if (sum_priority <= 0.0f) {
AICPU_LOGW("The sum priority is %f. It may leads to converge issue.");
sum_priority = std::numeric_limits<decltype(sum_priority)>::epsilon();
}
float sample_prob = priority / sum_priority;
float weight = static_cast<float>(pow(sample_prob * size, -beta_));
float weight = static_cast<float>(pow(sample_prob * size, -beta));
return weight;
}
} // namespace aicpu
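
With beta threaded through Sample, the bias-correction weight for a sampled transition is (N * P(i)) ** (-beta), where P(i) = priority_i / sum_priority and N is the current buffer size; Sample normalises each weight by the maximum weight, which corresponds to the minimum priority. A small self-contained sketch of that arithmetic (names and values are illustrative, not part of the commit):

def is_weight(priority, sum_priority, size, beta):
    # Mirrors Weight() above: (N * P(i)) ** (-beta), with P(i) = priority / sum_priority.
    sample_prob = priority / sum_priority
    return (sample_prob * size) ** (-beta)

priorities = [0.1, 0.4, 0.5]
sum_p, n, beta = sum(priorities), len(priorities), 0.4
# The minimum priority gives the largest weight, which Sample() uses for normalisation.
max_w = is_weight(min(priorities), sum_p, n, beta)
print([is_weight(p, sum_p, n, beta) / max_w for p in priorities])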

View File

@ -54,22 +54,22 @@ class PriorityTree : public SegmentTree<PriorityItem> {
class PriorityReplayBuffer {
public:
// Construct a fixed-length priority replay buffer.
PriorityReplayBuffer(uint32_t seed, float alpha, float beta, size_t capacity, const std::vector<size_t> &schema);
PriorityReplayBuffer(uint32_t seed, float alpha, size_t capacity, const std::vector<size_t> &schema);
// Push an experience transition to the buffer which will be given the highest priority.
bool Push(const std::vector<AddressPtr> &items);
// Sample a batch of transitions with indices and bias-correction weights.
std::tuple<std::vector<size_t>, std::vector<float>, std::vector<std::vector<AddressPtr>>> Sample(size_t batch_size);
std::tuple<std::vector<size_t>, std::vector<float>, std::vector<std::vector<AddressPtr>>> Sample(size_t batch_size,
float beta);
// Update the priorities of experience transitions.
bool UpdatePriorities(const std::vector<size_t> &indices, const std::vector<float> &priorities);
private:
inline float Weight(float priority, float sum_priority, size_t size) const;
inline float Weight(float priority, float sum_priority, size_t size, float beta) const;
float alpha_;
float beta_;
size_t capacity_;
float max_priority_;
std::vector<size_t> schema_;

View File

@ -23,9 +23,10 @@
namespace aicpu {
using PriorityReplayBufferFactory = ReplayBufferFactory<PriorityReplayBuffer>;
constexpr size_t kIndicesIndex = 0;
constexpr size_t kInWeightsIndex = 1;
constexpr size_t kTransitionIndex = 2;
constexpr size_t kBetaIndex = 0;
constexpr size_t kIndicesIndex = 1;
constexpr size_t kWeightsIndex = 2;
constexpr size_t kTransitionIndex = 3;
constexpr size_t kUpdateOpInputNum = 2;
uint32_t PriorityReplayBufferCreate::ParseKernelParam() {
@ -34,7 +35,6 @@ uint32_t PriorityReplayBufferCreate::ParseKernelParam() {
::google::protobuf::Map<::std::string, ::aicpuops::AttrValue> attrs = node_def_.attrs();
capacity_ = attrs["capacity"].i();
alpha_ = attrs["alpha"].f();
beta_ = attrs["beta"].f();
int64_t seed1 = attrs["seed"].i();
int64_t seed2 = attrs["seed2"].i();
@ -54,7 +54,7 @@ uint32_t PriorityReplayBufferCreate::DoCompute() {
int64_t handle;
std::shared_ptr<PriorityReplayBuffer> prioriory_replay_buffer;
auto &factory = PriorityReplayBufferFactory::GetInstance();
std::tie(handle, prioriory_replay_buffer) = factory.Create(seed_, alpha_, beta_, capacity_, schema_);
std::tie(handle, prioriory_replay_buffer) = factory.Create(seed_, alpha_, capacity_, schema_);
auto *output_data = reinterpret_cast<int64_t *>(io_addrs_[0]);
output_data[0] = handle;
@ -114,16 +114,17 @@ uint32_t PriorityReplayBufferSample::DoCompute() {
std::vector<float> weights;
std::vector<std::vector<AddressPtr>> samples;
auto prioriory_replay_buffer = PriorityReplayBufferFactory::GetInstance().GetByHandle(handle_);
std::tie(indices, weights, samples) = prioriory_replay_buffer->Sample(batch_size_);
auto beta = reinterpret_cast<float *>(io_addrs_[kBetaIndex]);
std::tie(indices, weights, samples) = prioriory_replay_buffer->Sample(batch_size_, beta[0]);
auto *indices_data = reinterpret_cast<void *>(io_addrs_[0]);
auto *indices_data = reinterpret_cast<void *>(io_addrs_[kIndicesIndex]);
auto ret = memcpy_s(indices_data, batch_size_ * sizeof(int64_t), indices.data(), batch_size_ * sizeof(int64_t));
if (ret != EOK) {
AICPU_LOGE("memcpy_s() failed: %d.", ret);
return kAicpuKernelStateInternalError;
}
auto *weights_data = reinterpret_cast<void *>(io_addrs_[1]);
auto *weights_data = reinterpret_cast<void *>(io_addrs_[kWeightsIndex]);
ret = memcpy_s(weights_data, batch_size_ * sizeof(float), weights.data(), batch_size_ * sizeof(float));
if (ret != EOK) {
AICPU_LOGE("memcpy_s() failed: %d.", ret);

View File

@ -48,9 +48,9 @@ size_t PriorityTree::GetPrefixSumIdx(float prefix_sum) const {
return idx - capacity_;
}
PriorityReplayBuffer::PriorityReplayBuffer(uint32_t seed, float alpha, float beta, size_t capacity,
PriorityReplayBuffer::PriorityReplayBuffer(uint32_t seed, float alpha, size_t capacity,
const std::vector<size_t> &schema)
: alpha_(alpha), beta_(beta), capacity_(capacity), max_priority_(1.0), schema_(schema) {
: alpha_(alpha), capacity_(capacity), max_priority_(1.0), schema_(schema) {
random_engine_.seed(seed);
fifo_replay_buffer_ = std::make_unique<FIFOReplayBuffer>(capacity, schema);
priority_tree_ = std::make_unique<PriorityTree>(capacity);
@ -87,13 +87,13 @@ bool PriorityReplayBuffer::UpdatePriorities(const std::vector<size_t> &indices,
}
std::tuple<std::vector<size_t>, std::vector<float>, std::vector<std::vector<AddressPtr>>> PriorityReplayBuffer::Sample(
size_t batch_size) {
size_t batch_size, float beta) {
MS_EXCEPTION_IF_ZERO("batch size", batch_size);
const PriorityItem &root = priority_tree_->Root();
float sum_priority = root.sum_priority;
float min_priority = root.min_priority;
size_t size = fifo_replay_buffer_->size();
float max_weight = Weight(min_priority, sum_priority, size);
float max_weight = Weight(min_priority, sum_priority, size, beta);
float segment_len = root.sum_priority / batch_size;
std::vector<size_t> indices;
@ -110,20 +110,20 @@ std::tuple<std::vector<size_t>, std::vector<float>, std::vector<std::vector<Addr
MS_LOG(WARNING) << "The max priority is " << max_weight << ". It may leads to converge issue.";
max_weight = kMinPriority;
}
(void)weights.emplace_back(Weight(priority, sum_priority, size) / max_weight);
(void)weights.emplace_back(Weight(priority, sum_priority, size, beta) / max_weight);
(void)items.emplace_back(fifo_replay_buffer_->GetItem(idx));
}
return std::forward_as_tuple(indices, weights, items);
}
inline float PriorityReplayBuffer::Weight(float priority, float sum_priority, size_t size) const {
inline float PriorityReplayBuffer::Weight(float priority, float sum_priority, size_t size, float beta) const {
if (sum_priority <= 0.0f) {
MS_LOG(WARNING) << "The sum priority is " << sum_priority << ". It may leads to converge issue.";
sum_priority = kMinPriority;
}
float sample_prob = priority / sum_priority;
float weight = static_cast<float>(pow(sample_prob * size, -beta_));
float weight = static_cast<float>(pow(sample_prob * size, -beta));
return weight;
}
} // namespace kernel

View File

@ -57,22 +57,22 @@ class PriorityTree : public SegmentTree<PriorityItem> {
class PriorityReplayBuffer {
public:
// Construct a fixed-length priority replay buffer.
PriorityReplayBuffer(uint32_t seed, float alpha, float beta, size_t capacity, const std::vector<size_t> &schema);
PriorityReplayBuffer(uint32_t seed, float alpha, size_t capacity, const std::vector<size_t> &schema);
// Push an experience transition to the buffer which will be given the highest priority.
bool Push(const std::vector<AddressPtr> &items);
// Sample a batch of transitions with indices and bias-correction weights.
std::tuple<std::vector<size_t>, std::vector<float>, std::vector<std::vector<AddressPtr>>> Sample(size_t batch_size);
std::tuple<std::vector<size_t>, std::vector<float>, std::vector<std::vector<AddressPtr>>> Sample(size_t batch_size,
float beta);
// Update the priorities of experience transitions.
bool UpdatePriorities(const std::vector<size_t> &indices, const std::vector<float> &priorities);
private:
inline float Weight(float priority, float sum_priority, size_t size) const;
inline float Weight(float priority, float sum_priority, size_t size, float beta) const;
float alpha_;
float beta_;
size_t capacity_;
float max_priority_;
std::vector<size_t> schema_;

View File

@ -33,7 +33,6 @@ constexpr size_t kTransitionIndex = 2;
void PriorityReplayBufferCreateCpuKernel::InitKernel(const CNodePtr &kernel_node) {
const int64_t &capacity = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "capacity");
const float &alpha = common::AnfAlgo::GetNodeAttr<float>(kernel_node, "alpha");
const float &beta = common::AnfAlgo::GetNodeAttr<float>(kernel_node, "beta");
const auto &dtypes = common::AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "dtypes");
const auto &shapes = common::AnfAlgo::GetNodeAttr<std::vector<std::vector<int64_t>>>(kernel_node, "shapes");
const int64_t &seed0 = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "seed0");
@ -58,7 +57,7 @@ void PriorityReplayBufferCreateCpuKernel::InitKernel(const CNodePtr &kernel_node
}
auto &factory = PriorityReplayBufferFactory::GetInstance();
std::tie(handle_, prioriory_replay_buffer_) = factory.Create(seed, alpha, beta, capacity, schema);
std::tie(handle_, prioriory_replay_buffer_) = factory.Create(seed, alpha, capacity, schema);
MS_EXCEPTION_IF_NULL(prioriory_replay_buffer_);
}
@ -100,12 +99,14 @@ void PriorityReplayBufferSampleCpuKernel::InitKernel(const CNodePtr &kernel_node
}
}
bool PriorityReplayBufferSampleCpuKernel::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
bool PriorityReplayBufferSampleCpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
std::vector<size_t> indices;
std::vector<float> weights;
std::vector<std::vector<AddressPtr>> samples;
std::tie(indices, weights, samples) = prioriory_replay_buffer_->Sample(batch_size_);
auto beta = reinterpret_cast<float *>(inputs[0]->addr);
std::tie(indices, weights, samples) = prioriory_replay_buffer_->Sample(batch_size_, beta[0]);
MS_EXCEPTION_IF_CHECK_FAIL(outputs.size() == schema_.size() + kTransitionIndex,
"The dtype and shapes must be the same.");

View File

@ -54,9 +54,16 @@ __forceinline__ __device__ void SumTreeInsert(SumTree *tree, size_t idx, float p
}
}
__global__ void SumTreePushKernel(SumTree *tree, float alpha, size_t idx, float *max_priority) {
float priority = powf(*max_priority, alpha);
SumTreeInsert(tree, idx, priority);
__global__ void SumTreePushKernel(SumTree *tree, float alpha, size_t idx, float *priority, float *max_priority) {
float prio;
if (!priority) {
prio = powf(*max_priority, alpha);
} else {
*max_priority = max(*max_priority, *priority);
prio = powf(*priority, alpha);
}
SumTreeInsert(tree, idx, prio);
}
__forceinline__ __device__ size_t GetPrefixSumIdx(SumTree *tree, size_t capacity, float prefix_sum) {
@ -73,14 +80,14 @@ __forceinline__ __device__ size_t GetPrefixSumIdx(SumTree *tree, size_t capacity
return idx - capacity;
}
__global__ void SumTreeSampleKernel(SumTree *tree, curandState *state, size_t capacity, float beta, size_t batch_size,
__global__ void SumTreeSampleKernel(SumTree *tree, curandState *state, size_t capacity, float *beta, size_t batch_size,
size_t *indices, float *weights) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size; i += gridDim.x * blockDim.x) {
size_t segment_len = tree[kRootIdx].sum / batch_size;
float prefix_sum = (curand_uniform(&state[i]) + i) * segment_len;
size_t idx = GetPrefixSumIdx(tree, capacity, prefix_sum);
indices[i] = idx;
weights[i] = powf((tree[idx + capacity].sum / tree[kRootIdx].min), -beta);
weights[i] = powf((tree[idx + capacity].sum / tree[kRootIdx].min), -beta[0]);
}
}
@ -109,14 +116,14 @@ void InitRandState(const size_t &batch_size, const uint64_t &seed, curandState *
InitRandStateKernel<<<grid, block, 0, stream>>>(seed, state);
}
void SumTreePush(SumTree *tree, const float &alpha, const size_t &idx, const size_t &capacity, float *max_priority,
cudaStream_t stream) {
void SumTreePush(SumTree *tree, const float &alpha, const size_t &idx, const size_t &capacity, float *priority,
float *max_priority, cudaStream_t stream) {
size_t idx_in_tree = idx + capacity;
SumTreePushKernel<<<1, 1, 0, stream>>>(tree, alpha, idx_in_tree, max_priority);
SumTreePushKernel<<<1, 1, 0, stream>>>(tree, alpha, idx_in_tree, priority, max_priority);
}
void SumTreeSample(SumTree *tree, curandState *state, const size_t &capacity, const float &beta,
const size_t &batch_size, size_t *indices, float *weights, cudaStream_t stream) {
void SumTreeSample(SumTree *tree, curandState *state, const size_t &capacity, float *beta, const size_t &batch_size,
size_t *indices, float *weights, cudaStream_t stream) {
size_t block = std::min(batch_size, kMaxThreadPerBlock);
size_t grid = (batch_size + block - 1) / block;
SumTreeSampleKernel<<<grid, block, 0, stream>>>(tree, state, capacity, beta, batch_size, indices, weights);
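
SumTreeSampleKernel now reads beta from device memory (beta[0]) rather than taking the stored beta_ member by value, so the per-call value supplied to Sample is the one used. The kernel also folds the normalisation into a single expression: since max_weight = (N * P_min) ** (-beta) and w_i = (N * P_i) ** (-beta), the normalised weight reduces to (P_i / P_min) ** (-beta), which is exactly the powf term above. A quick numerical check of that identity (illustrative values only):

# (N*P) ** (-b) / (N*P_min) ** (-b) == (P / P_min) ** (-b)
n, beta = 1000, 0.7
p, p_min = 0.03, 0.002
lhs = (n * p) ** (-beta) / (n * p_min) ** (-beta)
rhs = (p / p_min) ** (-beta)
print(abs(lhs - rhs) < 1e-9)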

View File

@ -26,10 +26,10 @@ struct SumTree {
void SumTreeInit(SumTree *tree, float *max_priority, const size_t &capacity, cudaStream_t stream);
void InitRandState(const size_t &batch_size, const uint64_t &seed, curandState *state, cudaStream_t stream);
void SumTreePush(SumTree *tree, const float &alpha, const size_t &idx, const size_t &capacity, float *max_priority,
cudaStream_t stream);
void SumTreeSample(SumTree *tree, curandState *state, const size_t &capacity, const float &beta,
const size_t &batch_size, size_t *indices, float *weights, cudaStream_t stream);
void SumTreePush(SumTree *tree, const float &alpha, const size_t &idx, const size_t &capacity, float *priority,
float *max_priority, cudaStream_t stream);
void SumTreeSample(SumTree *tree, curandState *state, const size_t &capacity, float *beta, const size_t &batch_size,
size_t *indices, float *weights, cudaStream_t stream);
void SumTreeUpdate(SumTree *tree, const size_t &capacity, const float &alpha, float *max_priority, size_t *indices,
float *priorities, const size_t &batch_size, cudaStream_t stream);
void FifoSlice(const uint8_t *input, const size_t *indice, uint8_t *output, size_t batch_size, size_t column,

View File

@ -31,10 +31,9 @@ namespace gpu {
constexpr float kMinPriority = 1e-7;
constexpr size_t kNumSubNode = 2;
PriorityReplayBuffer::PriorityReplayBuffer(const uint64_t &seed, const float &alpha, const float &beta,
const size_t &capacity, const std::vector<size_t> &schema) {
PriorityReplayBuffer::PriorityReplayBuffer(const uint64_t &seed, const float &alpha, const size_t &capacity,
const std::vector<size_t> &schema) {
alpha_ = alpha;
beta_ = beta;
schema_ = schema;
seed_ = seed;
capacity_ = capacity;
@ -72,7 +71,7 @@ PriorityReplayBuffer::~PriorityReplayBuffer() {
allocator.FreeTensorMem(max_priority_);
}
bool PriorityReplayBuffer::Push(const std::vector<AddressPtr> &transition, cudaStream_t stream) {
bool PriorityReplayBuffer::Push(const std::vector<AddressPtr> &transition, float *priority, cudaStream_t stream) {
// Head points to the latest item.
head_ = head_ >= capacity_ ? 0 : head_ + 1;
valid_size_ = valid_size_ >= capacity_ ? capacity_ : valid_size_ + 1;
@ -86,11 +85,11 @@ bool PriorityReplayBuffer::Push(const std::vector<AddressPtr> &transition, cudaS
}
// Set max priority for the newest transition.
SumTreePush(sum_tree_, alpha_, head_, capacity_pow_two_, max_priority_, stream);
SumTreePush(sum_tree_, alpha_, head_, capacity_pow_two_, priority, max_priority_, stream);
return true;
}
bool PriorityReplayBuffer::Sample(const size_t &batch_size, size_t *indices, float *weights,
bool PriorityReplayBuffer::Sample(const size_t &batch_size, float *beta, size_t *indices, float *weights,
const std::vector<AddressPtr> &transition, cudaStream_t stream) {
MS_EXCEPTION_IF_ZERO("batch size", batch_size);
@ -101,7 +100,7 @@ bool PriorityReplayBuffer::Sample(const size_t &batch_size, size_t *indices, flo
InitRandState(batch_size, seed_, rand_state_, stream);
}
SumTreeSample(sum_tree_, rand_state_, capacity_pow_two_, beta_, batch_size, indices, weights, stream);
SumTreeSample(sum_tree_, rand_state_, capacity_pow_two_, beta, batch_size, indices, weights, stream);
for (size_t i = 0; i < schema_.size(); i++) {
auto output_addr = static_cast<uint8_t *>(transition[i]->addr);

View File

@ -34,23 +34,22 @@ namespace gpu {
class PriorityReplayBuffer {
public:
// Construct a fixed-length priority replay buffer.
PriorityReplayBuffer(const uint64_t &seed, const float &alpha, const float &beta, const size_t &capacity,
PriorityReplayBuffer(const uint64_t &seed, const float &alpha, const size_t &capacity,
const std::vector<size_t> &schema);
~PriorityReplayBuffer();
// Push an experience transition to the buffer which will be given the highest priority.
bool Push(const std::vector<AddressPtr> &transition, cudaStream_t stream);
bool Push(const std::vector<AddressPtr> &transition, float *priority, cudaStream_t stream);
// Sample a batch of transitions with indices and bias-correction weights.
bool Sample(const size_t &batch_size, size_t *indices, float *weights, const std::vector<AddressPtr> &transition,
cudaStream_t stream);
bool Sample(const size_t &batch_size, float *beta, size_t *indices, float *weights,
const std::vector<AddressPtr> &transition, cudaStream_t stream);
// Update the priorities of experience transitions.
bool UpdatePriorities(size_t *indices, float *priorities, const size_t &batch_size, cudaStream_t stream);
private:
float alpha_{1.};
float beta_{1.};
std::vector<size_t> schema_;
uint64_t seed_{42};

View File

@ -43,7 +43,6 @@ bool PriorityReplayBufferCreateGpuKernel::Init(const BaseOperatorPtr &base_opera
const int64_t &capacity = kernel_ptr->get_capacity();
const float &alpha = kernel_ptr->get_alpha();
const float &beta = kernel_ptr->get_beta();
const std::vector<int64_t> &schema = kernel_ptr->get_schema();
const int64_t &seed0 = kernel_ptr->get_seed0();
const int64_t &seed1 = kernel_ptr->get_seed1();
@ -63,7 +62,7 @@ bool PriorityReplayBufferCreateGpuKernel::Init(const BaseOperatorPtr &base_opera
[](const int64_t &arg) -> size_t { return LongToSize(arg); });
auto &factory = PriorityReplayBufferFactory::GetInstance();
std::tie(handle_, prioriory_replay_buffer_) = factory.Create(seed, alpha, beta, capacity, schema_in_size);
std::tie(handle_, prioriory_replay_buffer_) = factory.Create(seed, alpha, capacity, schema_in_size);
MS_EXCEPTION_IF_NULL(prioriory_replay_buffer_);
auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance();
@ -131,10 +130,11 @@ bool PriorityReplayBufferPushGpuKernel::Launch(const std::vector<AddressPtr> &in
auto stream = reinterpret_cast<cudaStream_t>(stream_ptr);
// Return a placeholder in case of dead-code-elimination optimization.
auto handle = GetDeviceAddress<int64_t>(outputs, 0);
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
cudaMemcpyAsync(handle, handle_device_, sizeof(handle_), cudaMemcpyDeviceToDevice, stream), "cudaMemcpy failed.");
return prioriory_replay_buffer_->Push(inputs, stream);
return prioriory_replay_buffer_->Push(inputs, nullptr, stream);
}
std::vector<KernelAttr> PriorityReplayBufferPushGpuKernel::GetOpSupport() {
@ -169,6 +169,7 @@ bool PriorityReplayBufferSampleGpuKernel::Init(const BaseOperatorPtr &base_opera
bool PriorityReplayBufferSampleGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
auto beta = GetDeviceAddress<float>(inputs, 0);
auto indices = GetDeviceAddress<size_t>(outputs, 0);
auto weights = GetDeviceAddress<float>(outputs, 1);
std::vector<AddressPtr> transition;
@ -176,7 +177,7 @@ bool PriorityReplayBufferSampleGpuKernel::Launch(const std::vector<AddressPtr> &
transition.push_back(outputs[i]);
}
return prioriory_replay_buffer_->Sample(batch_size_, indices, weights, transition,
return prioriory_replay_buffer_->Sample(batch_size_, beta, indices, weights, transition,
reinterpret_cast<cudaStream_t>(stream_ptr));
}

View File

@ -33,8 +33,6 @@ void PriorityReplayBufferCreate::set_capacity(const int64_t &capacity) {
void PriorityReplayBufferCreate::set_alpha(const float &alpha) { (void)this->AddAttr(kAlpha, api::MakeValue(alpha)); }
void PriorityReplayBufferCreate::set_beta(const float &beta) { (void)this->AddAttr(kBeta, api::MakeValue(beta)); }
void PriorityReplayBufferCreate::set_shapes(const std::vector<std::vector<int64_t>> &shapes) {
(void)this->AddAttr(kShapes, api::MakeValue(shapes));
}
@ -65,12 +63,6 @@ float PriorityReplayBufferCreate::get_alpha() const {
return GetValue<float>(value_ptr);
}
float PriorityReplayBufferCreate::get_beta() const {
auto value_ptr = GetAttr(kBeta);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<float>(value_ptr);
}
std::vector<std::vector<int64_t>> PriorityReplayBufferCreate::get_shapes() const {
auto value_ptr = GetAttr(kShapes);
MS_EXCEPTION_IF_NULL(value_ptr);
@ -101,7 +93,7 @@ int64_t PriorityReplayBufferCreate::get_seed1() const {
return GetValue<int64_t>(value_ptr);
}
void PriorityReplayBufferCreate::Init(const int64_t &capacity, const float &alpha, const float &beta,
void PriorityReplayBufferCreate::Init(const int64_t &capacity, const float &alpha,
std::vector<std::vector<int64_t>> &shapes, const std::vector<TypePtr> &types,
const int64_t &seed0, const int64_t &seed1) {
auto op_name = this->name();
@ -120,7 +112,6 @@ void PriorityReplayBufferCreate::Init(const int64_t &capacity, const float &alph
this->set_capacity(capacity);
this->set_alpha(alpha);
this->set_beta(beta);
this->set_shapes(shapes);
this->set_types(types);
this->set_schema(schema);

View File

@ -39,12 +39,11 @@ class MIND_API PriorityReplayBufferCreate : public BaseOperator {
PriorityReplayBufferCreate() : BaseOperator(kNamePriorityReplayBufferCreate) { InitIOName({}, {"handle"}); }
/// \brief Init.
/// Refer to the parameters of python API @ref mindspore.ops._rl_inner_ops.PriorityReplayBufferCreate for the inputs.
void Init(const int64_t &capacity, const float &alpha, const float &beta, std::vector<std::vector<int64_t>> &shapes,
void Init(const int64_t &capacity, const float &alpha, std::vector<std::vector<int64_t>> &shapes,
const std::vector<TypePtr> &types, const int64_t &seed0, const int64_t &seed1);
void set_capacity(const int64_t &capacity);
void set_alpha(const float &alpha);
void set_beta(const float &beta);
void set_shapes(const std::vector<std::vector<int64_t>> &shapes);
void set_types(const std::vector<TypePtr> &types);
void set_schema(const std::vector<int64_t> &schema);
@ -53,7 +52,6 @@ class MIND_API PriorityReplayBufferCreate : public BaseOperator {
int64_t get_capacity() const;
float get_alpha() const;
float get_beta() const;
std::vector<std::vector<int64_t>> get_shapes() const;
std::vector<TypePtr> get_types() const;
std::vector<int64_t> get_schema() const;

View File

@ -339,11 +339,10 @@ class PriorityReplayBufferCreate(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self, capacity, alpha, beta, shapes, dtypes, seed0, seed1):
def __init__(self, capacity, alpha, shapes, dtypes, seed0, seed1):
"""Initialize PriorityReplaBufferCreate."""
validator.check_int(capacity, 1, Rel.GE, "capacity", self.name)
validator.check_float_range(alpha, 0.0, 1.0, Rel.INC_BOTH)
validator.check_float_range(beta, 0.0, 1.0, Rel.INC_BOTH)
validator.check_value_type("shape of init data", shapes, [tuple, list], self.name)
validator.check_value_type("dtypes of init data", dtypes, [tuple, list], self.name)
validator.check_non_negative_int(seed0, "seed0", self.name)
@ -428,14 +427,14 @@ class PriorityReplayBufferSample(PrimitiveWithInfer):
schema.append(num_element * type_size_in_bytes(dtype))
self.add_prim_attr("schema", schema)
def infer_shape(self):
def infer_shape(self, beta):
output_shape = [(self.batch_size,), (self.batch_size,)]
for shape in self.shapes:
output_shape.append((self.batch_size,) + shape)
# indices, weights, transitions
return tuple(output_shape)
def infer_dtype(self):
def infer_dtype(self, beta):
return (mstype.int64, mstype.float32) + self.dtypes

View File

@ -25,9 +25,9 @@ from mindspore.ops.operations._rl_inner_ops import PriorityReplayBufferDestroy
class PriorityReplayBuffer(nn.Cell):
def __init__(self, capacity, alpha, beta, sample_size, shapes, dtypes, seed0, seed1):
def __init__(self, capacity, alpha, sample_size, shapes, dtypes, seed0, seed1):
super(PriorityReplayBuffer, self).__init__()
handle = PriorityReplayBufferCreate(capacity, alpha, beta, shapes, dtypes, seed0, seed1)().asnumpy().item()
handle = PriorityReplayBufferCreate(capacity, alpha, shapes, dtypes, seed0, seed1)().asnumpy().item()
self.push_op = PriorityReplayBufferPush(handle).add_prim_attr('side_effect_io', True)
self.sample_op = PriorityReplayBufferSample(handle, sample_size, shapes, dtypes)
self.update_op = PriorityReplayBufferUpdate(handle).add_prim_attr('side_effect_io', True)
@ -36,8 +36,8 @@ class PriorityReplayBuffer(nn.Cell):
def push(self, *transition):
return self.push_op(transition)
def sample(self):
return self.sample_op()
def sample(self, beta):
return self.sample_op(beta)
def update_priorities(self, indices, priorities):
return self.update_op(indices, priorities)
@ -65,7 +65,7 @@ def test_priority_replay_buffer_ops():
action_shape, action_dtype = (6,), mindspore.int32
shapes = (state_shape, action_shape)
dtypes = (state_dtype, action_dtype)
prb = PriorityReplayBuffer(capacity, 1., 1., batch_size, shapes, dtypes, seed0=0, seed1=42)
prb = PriorityReplayBuffer(capacity, 1., batch_size, shapes, dtypes, seed0=0, seed1=42)
# Push 100 timestep transitions to priority replay buffer.
for i in range(100):
@ -74,7 +74,7 @@ def test_priority_replay_buffer_ops():
prb.push(state, action)
# Sample a batch of transitions; the indices should be consistent with the transitions.
indices, weights, states, actions = prb.sample()
indices, weights, states, actions = prb.sample(1.)
assert np.all(indices.asnumpy() < 100)
states_expect = np.broadcast_to(indices.asnumpy().reshape(-1, 1), states.shape)
actions_expect = np.broadcast_to(indices.asnumpy().reshape(-1, 1), actions.shape)
@ -85,7 +85,7 @@ def test_priority_replay_buffer_ops():
priorities = Tensor(np.ones(weights.shape) * 1e-7, mindspore.float32)
prb.update_priorities(indices, priorities)
indices_new, _, states_new, actions_new = prb.sample()
indices_new, _, states_new, actions_new = prb.sample(1.)
assert np.all(indices_new.asnumpy() < 100)
assert np.all(indices.asnumpy() != indices_new.asnumpy())
states_expect = np.broadcast_to(indices_new.asnumpy().reshape(-1, 1), states.shape)

View File

@ -25,9 +25,9 @@ from mindspore.ops.operations._rl_inner_ops import PriorityReplayBufferDestroy
class PriorityReplayBuffer(nn.Cell):
def __init__(self, capacity, alpha, beta, sample_size, shapes, dtypes, seed0, seed1):
def __init__(self, capacity, alpha, sample_size, shapes, dtypes, seed0, seed1):
super(PriorityReplayBuffer, self).__init__()
handle = PriorityReplayBufferCreate(capacity, alpha, beta, shapes, dtypes, seed0, seed1)().asnumpy().item()
handle = PriorityReplayBufferCreate(capacity, alpha, shapes, dtypes, seed0, seed1)().asnumpy().item()
self.push_op = PriorityReplayBufferPush(handle).add_prim_attr('side_effect_io', True)
self.sample_op = PriorityReplayBufferSample(handle, sample_size, shapes, dtypes)
self.update_op = PriorityReplayBufferUpdate(handle).add_prim_attr('side_effect_io', True)
@ -36,8 +36,8 @@ class PriorityReplayBuffer(nn.Cell):
def push(self, *transition):
return self.push_op(transition)
def sample(self):
return self.sample_op()
def sample(self, beta):
return self.sample_op(beta)
def update_priorities(self, indices, priorities):
return self.update_op(indices, priorities)
@ -64,7 +64,7 @@ def test_priority_replay_buffer_ops():
action_shape, action_dtype = (6,), mindspore.int32
shapes = (state_shape, action_shape)
dtypes = (state_dtype, action_dtype)
prb = PriorityReplayBuffer(capacity, 1., 1., batch_size, shapes, dtypes, seed0=0, seed1=42)
prb = PriorityReplayBuffer(capacity, 1., batch_size, shapes, dtypes, seed0=0, seed1=42)
# Push 100 timestep transitions to priority replay buffer.
for i in range(100):
@ -73,7 +73,7 @@ def test_priority_replay_buffer_ops():
prb.push(state, action)
# Sample a batch of transitions; the indices should be consistent with the transitions.
indices, weights, states, actions = prb.sample()
indices, weights, states, actions = prb.sample(1.)
assert np.all(indices.asnumpy() < 100)
states_expect = np.broadcast_to(indices.asnumpy().reshape(-1, 1), states.shape)
actions_expect = np.broadcast_to(indices.asnumpy().reshape(-1, 1), actions.shape)
@ -84,7 +84,7 @@ def test_priority_replay_buffer_ops():
priorities = Tensor(np.ones(weights.shape) * 1e-7, mindspore.float32)
prb.update_priorities(indices, priorities)
indices_new, _, states_new, actions_new = prb.sample()
indices_new, _, states_new, actions_new = prb.sample(1.)
assert np.all(indices_new.asnumpy() < 100)
assert np.all(indices.asnumpy() != indices_new.asnumpy())
states_expect = np.broadcast_to(indices_new.asnumpy().reshape(-1, 1), states.shape)

View File

@ -25,9 +25,9 @@ from mindspore.ops.operations._rl_inner_ops import PriorityReplayBufferDestroy
class PriorityReplayBuffer(nn.Cell):
def __init__(self, capacity, alpha, beta, sample_size, shapes, dtypes, seed0, seed1):
def __init__(self, capacity, alpha, sample_size, shapes, dtypes, seed0, seed1):
super(PriorityReplayBuffer, self).__init__()
handle = PriorityReplayBufferCreate(capacity, alpha, beta, shapes, dtypes, seed0, seed1)().asnumpy().item()
handle = PriorityReplayBufferCreate(capacity, alpha, shapes, dtypes, seed0, seed1)().asnumpy().item()
self.push_op = PriorityReplayBufferPush(handle).add_prim_attr('side_effect_io', True)
self.sample_op = PriorityReplayBufferSample(handle, sample_size, shapes, dtypes)
self.update_op = PriorityReplayBufferUpdate(handle).add_prim_attr('side_effect_io', True)
@ -36,8 +36,8 @@ class PriorityReplayBuffer(nn.Cell):
def push(self, *transition):
return self.push_op(transition)
def sample(self):
return self.sample_op()
def sample(self, beta):
return self.sample_op(beta)
def update_priorities(self, indices, priorities):
return self.update_op(indices, priorities)
@ -64,7 +64,7 @@ def test_priority_replay_buffer_ops():
action_shape, action_dtype = (6,), mindspore.int32
shapes = (state_shape, action_shape)
dtypes = (state_dtype, action_dtype)
prb = PriorityReplayBuffer(capacity, 1., 1., batch_size, shapes, dtypes, seed0=0, seed1=42)
prb = PriorityReplayBuffer(capacity, 1., batch_size, shapes, dtypes, seed0=0, seed1=42)
# Push 100 timestep transitions to priority replay buffer.
for i in range(100):
@ -73,7 +73,7 @@ def test_priority_replay_buffer_ops():
prb.push(state, action)
# Sample a batch of transitions; the indices should be consistent with the transitions.
indices, weights, states, actions = prb.sample()
indices, weights, states, actions = prb.sample(1.)
assert np.all(indices.asnumpy() < 100)
states_expect = np.broadcast_to(indices.asnumpy().reshape(-1, 1), states.shape)
actions_expect = np.broadcast_to(indices.asnumpy().reshape(-1, 1), actions.shape)
@ -84,7 +84,7 @@ def test_priority_replay_buffer_ops():
priorities = Tensor(np.ones(weights.shape) * 1e-7, mindspore.float32)
prb.update_priorities(indices, priorities)
indices_new, _, states_new, actions_new = prb.sample()
indices_new, _, states_new, actions_new = prb.sample(1.)
assert np.all(indices_new.asnumpy() < 100)
assert np.all(indices.asnumpy() != indices_new.asnumpy())
states_expect = np.broadcast_to(indices_new.asnumpy().reshape(-1, 1), states.shape)