!39192 reservoir replay buffer primitive
Merge pull request !39192 from chenweifeng/reservior-replay-buffer-primitive
This commit is contained in:
commit
f198da7489
|
@ -1330,6 +1330,10 @@ GVAR_DEF(PrimitivePtr, kPrimTensorArray, std::make_shared<Primitive>("TensorArra
|
|||
GVAR_DEF(PrimitivePtr, kPrimTensorArrayWrite, std::make_shared<Primitive>("TensorArrayWrite"));
GVAR_DEF(PrimitivePtr, kPrimTensorArrayGather, std::make_shared<Primitive>("TensorArrayGather"));
GVAR_DEF(PrimitivePtr, kPrimKMeansCentroids, std::make_shared<Primitive>("KMeansCentroids"));
// Reservoir replay buffer primitives (reinforcement-learning experience container):
// create/push/sample/destroy lifecycle, matching the operators registered in
// ops/reservoir_replay_buffer.cc.
GVAR_DEF(PrimitivePtr, kPrimReservoirReplayBufferCreate, std::make_shared<Primitive>("ReservoirReplayBufferCreate"));
GVAR_DEF(PrimitivePtr, kPrimReservoirReplayBufferPush, std::make_shared<Primitive>("ReservoirReplayBufferPush"));
GVAR_DEF(PrimitivePtr, kPrimReservoirReplayBufferSample, std::make_shared<Primitive>("ReservoirReplayBufferSample"));
GVAR_DEF(PrimitivePtr, kPrimReservoirReplayBufferDestroy, std::make_shared<Primitive>("ReservoirReplayBufferDestroy"));

// AdamApplyOne
GVAR_DEF(PrimitivePtr, kPrimAdamApplyOne, std::make_shared<Primitive>("AdamApplyOne"));
|
||||
|
|
|
@ -327,6 +327,14 @@ constexpr auto kSearchStep = "search_step";
|
|||
constexpr auto kWithOffset = "with_offset";
constexpr auto kLinearSumAssignment = "linear_sum_assignment";
constexpr auto kNbins = "nbins";
// Attribute keys shared by the replay-buffer operators
// (PriorityReplayBuffer* and ReservoirReplayBuffer*).
constexpr auto kCapacity = "capacity";
constexpr auto kShapes = "shapes";
constexpr auto kTypes = "types";
constexpr auto kSchema = "schema";
constexpr auto kSeed0 = "seed0";
constexpr auto kSeed1 = "seed1";
constexpr auto kHandle = "handle";
constexpr auto kBatchSize = "batch_size";

// Common input-argument indices.
constexpr size_t kInputIndex0 = 0;
constexpr size_t kInputIndex1 = 1;
|
||||
|
|
|
@ -32,15 +32,6 @@ constexpr auto kNamePriorityReplayBufferSample = "PriorityReplayBufferSample";
|
|||
constexpr auto kNamePriorityReplayBufferUpdate = "PriorityReplayBufferUpdate";
constexpr auto kNamePriorityReplayBufferDestroy = "PriorityReplayBufferDestroy";

// Attribute keys used by the PriorityReplayBuffer operators.
constexpr auto kCapacity = "capacity";
constexpr auto kShapes = "shapes";
constexpr auto kTypes = "types";
constexpr auto kSchema = "schema";
constexpr auto kSeed0 = "seed0";
constexpr auto kSeed1 = "seed1";
constexpr auto kHandle = "handle";
constexpr auto kBatchSize = "batch_size";
|
||||
|
||||
class MIND_API PriorityReplayBufferCreate : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(PriorityReplayBufferCreate);
|
||||
|
|
|
@ -0,0 +1,285 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/reservoir_replay_buffer.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <vector>

#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
void ReservoirReplayBufferCreate::set_capacity(const int64_t &capacity) {
|
||||
(void)this->AddAttr(kCapacity, api::MakeValue(capacity));
|
||||
}
|
||||
|
||||
void ReservoirReplayBufferCreate::set_shapes(const std::vector<std::vector<int64_t>> &shapes) {
|
||||
(void)this->AddAttr(kShapes, api::MakeValue(shapes));
|
||||
}
|
||||
|
||||
void ReservoirReplayBufferCreate::set_types(const std::vector<TypePtr> &types) {
|
||||
auto res = std::dynamic_pointer_cast<PrimitiveC>(impl_);
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
res->AddAttr(kTypes, MakeValue(types));
|
||||
}
|
||||
|
||||
void ReservoirReplayBufferCreate::set_schema(const std::vector<int64_t> &schema) {
|
||||
(void)this->AddAttr(kSchema, api::MakeValue(schema));
|
||||
}
|
||||
|
||||
void ReservoirReplayBufferCreate::set_seed0(const int64_t &seed0) {
|
||||
(void)this->AddAttr(kSeed0, api::MakeValue(seed0));
|
||||
}
|
||||
|
||||
void ReservoirReplayBufferCreate::set_seed1(const int64_t &seed1) {
|
||||
(void)this->AddAttr(kSeed1, api::MakeValue(seed1));
|
||||
}
|
||||
|
||||
int64_t ReservoirReplayBufferCreate::get_capacity() const {
|
||||
auto value_ptr = GetAttr(kCapacity);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> ReservoirReplayBufferCreate::get_shapes() const {
|
||||
auto value_ptr = GetAttr(kShapes);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<std::vector<int64_t>>>(value_ptr);
|
||||
}
|
||||
|
||||
std::vector<TypePtr> ReservoirReplayBufferCreate::get_types() const {
|
||||
auto res = std::dynamic_pointer_cast<PrimitiveC>(impl_);
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
return GetValue<std::vector<TypePtr>>(res->GetAttr(kTypes));
|
||||
}
|
||||
|
||||
std::vector<int64_t> ReservoirReplayBufferCreate::get_schema() const {
|
||||
auto value_ptr = GetAttr(kSchema);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
int64_t ReservoirReplayBufferCreate::get_seed0() const {
|
||||
auto value_ptr = GetAttr(kSeed0);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
int64_t ReservoirReplayBufferCreate::get_seed1() const {
|
||||
auto value_ptr = GetAttr(kSeed1);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
void ReservoirReplayBufferCreate::Init(const int64_t &capacity, std::vector<std::vector<int64_t>> &shapes,
|
||||
const std::vector<TypePtr> &types, const int64_t &seed0, const int64_t &seed1) {
|
||||
auto op_name = this->name();
|
||||
if (shapes.size() != types.size()) {
|
||||
MS_LOG(EXCEPTION) << "For " << op_name
|
||||
<< " the rank of shapes and types should be the same, but got the rank of shapes is "
|
||||
<< shapes.size() << ", and types is " << types.size();
|
||||
}
|
||||
|
||||
std::vector<int64_t> schema;
|
||||
for (size_t i = 0; i < shapes.size(); i++) {
|
||||
size_t type_size = GetTypeByte(types[i]);
|
||||
size_t tensor_size = std::accumulate(shapes[i].begin(), shapes[i].end(), type_size, std::multiplies<int64_t>());
|
||||
schema.push_back(tensor_size);
|
||||
}
|
||||
|
||||
this->set_capacity(capacity);
|
||||
this->set_shapes(shapes);
|
||||
this->set_types(types);
|
||||
this->set_schema(schema);
|
||||
this->set_seed0(seed0);
|
||||
this->set_seed1(seed1);
|
||||
}
|
||||
|
||||
void ReservoirReplayBufferPush::set_handle(const int64_t &handle) {
|
||||
(void)this->AddAttr(kHandle, api::MakeValue(handle));
|
||||
}
|
||||
|
||||
int64_t ReservoirReplayBufferPush::get_handle() const {
|
||||
auto value_ptr = GetAttr(kHandle);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
void ReservoirReplayBufferPush::Init(const int64_t &handle) { this->set_handle(handle); }
|
||||
|
||||
void ReservoirReplayBufferSample::set_handle(const int64_t &handle) {
|
||||
(void)this->AddAttr(kHandle, api::MakeValue(handle));
|
||||
}
|
||||
|
||||
void ReservoirReplayBufferSample::set_batch_size(const int64_t &batch_size) {
|
||||
(void)this->AddAttr(kBatchSize, api::MakeValue(batch_size));
|
||||
}
|
||||
|
||||
void ReservoirReplayBufferSample::set_shapes(const std::vector<std::vector<int64_t>> &shapes) {
|
||||
(void)this->AddAttr(kShapes, api::MakeValue(shapes));
|
||||
}
|
||||
|
||||
void ReservoirReplayBufferSample::set_types(const std::vector<TypePtr> &types) {
|
||||
auto res = std::dynamic_pointer_cast<PrimitiveC>(impl_);
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
res->AddAttr(kTypes, MakeValue(types));
|
||||
}
|
||||
|
||||
void ReservoirReplayBufferSample::set_schema(const std::vector<int64_t> &schema) {
|
||||
(void)this->AddAttr(kSchema, api::MakeValue(schema));
|
||||
}
|
||||
|
||||
int64_t ReservoirReplayBufferSample::get_handle() const {
|
||||
auto value_ptr = GetAttr(kHandle);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
int64_t ReservoirReplayBufferSample::get_batch_size() const {
|
||||
auto value_ptr = GetAttr(kBatchSize);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> ReservoirReplayBufferSample::get_shapes() const {
|
||||
auto value_ptr = GetAttr(kShapes);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<std::vector<int64_t>>>(value_ptr);
|
||||
}
|
||||
|
||||
std::vector<TypePtr> ReservoirReplayBufferSample::get_types() const {
|
||||
auto res = std::dynamic_pointer_cast<PrimitiveC>(impl_);
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
return GetValue<std::vector<TypePtr>>(res->GetAttr(kTypes));
|
||||
}
|
||||
|
||||
std::vector<int64_t> ReservoirReplayBufferSample::get_schema() const {
|
||||
auto value_ptr = GetAttr(kSchema);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
void ReservoirReplayBufferSample::Init(const int64_t &handle, const int64_t &batch_size,
|
||||
const std::vector<std::vector<int64_t>> &shapes,
|
||||
const std::vector<TypePtr> &types) {
|
||||
auto op_name = this->name();
|
||||
if (shapes.size() != types.size()) {
|
||||
MS_LOG(EXCEPTION) << "For " << op_name
|
||||
<< " the rank of shapes and types should be the same, but got the rank of shapes is "
|
||||
<< shapes.size() << ", and types is " << types.size();
|
||||
}
|
||||
|
||||
std::vector<int64_t> schema;
|
||||
for (size_t i = 0; i < shapes.size(); i++) {
|
||||
size_t type_size = GetTypeByte(types[i]);
|
||||
size_t tensor_size = std::accumulate(shapes[i].begin(), shapes[i].end(), type_size, std::multiplies<int64_t>());
|
||||
schema.push_back(tensor_size);
|
||||
}
|
||||
|
||||
this->set_handle(handle);
|
||||
this->set_batch_size(batch_size);
|
||||
this->set_shapes(shapes);
|
||||
this->set_types(types);
|
||||
this->set_schema(schema);
|
||||
}
|
||||
|
||||
void ReservoirReplayBufferDestroy::set_handle(const int64_t &handle) {
|
||||
(void)this->AddAttr(kHandle, api::MakeValue(handle));
|
||||
}
|
||||
|
||||
int64_t ReservoirReplayBufferDestroy::get_handle() const {
|
||||
auto value_ptr = GetAttr(kHandle);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
void ReservoirReplayBufferDestroy::Init(const int64_t &handle) { this->set_handle(handle); }
|
||||
|
||||
// Bind each front-end operator class to its BaseOperator implementation.
MIND_API_OPERATOR_IMPL(ReservoirReplayBufferCreate, BaseOperator);
MIND_API_OPERATOR_IMPL(ReservoirReplayBufferPush, BaseOperator);
MIND_API_OPERATOR_IMPL(ReservoirReplayBufferSample, BaseOperator);
MIND_API_OPERATOR_IMPL(ReservoirReplayBufferDestroy, BaseOperator);
|
||||
|
||||
namespace {
|
||||
AbstractBasePtr CreateInfer(const abstract::AnalysisEnginePtr &prim, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
||||
const std::string &prim_name = primitive->name();
|
||||
if (input_args.size() != 0) {
|
||||
MS_LOG(EXCEPTION) << "For Primitive[" << prim_name << "], the input should be empty.";
|
||||
}
|
||||
const ShapeVector &shape = {1};
|
||||
BaseShapePtr out_shape = std::make_shared<abstract::Shape>(shape);
|
||||
return abstract::MakeAbstract(out_shape, kInt64);
|
||||
}
|
||||
|
||||
AbstractBasePtr PushInfer(const abstract::AnalysisEnginePtr &prim, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const ShapeVector &shape = {1};
|
||||
BaseShapePtr out_shape = std::make_shared<abstract::Shape>(shape);
|
||||
return abstract::MakeAbstract(out_shape, kInt64);
|
||||
}
|
||||
|
||||
AbstractBasePtr SampleInfer(const abstract::AnalysisEnginePtr &prim, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
||||
const std::string &prim_name = primitive->name();
|
||||
auto types = GetValue<std::vector<TypePtr>>(primitive->GetAttr("dtypes"));
|
||||
auto shapes = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr("shapes"));
|
||||
if (types.size() != shapes.size()) {
|
||||
MS_LOG(EXCEPTION) << "For Primitive[" << prim_name << "], the types and shapes rank should be same.";
|
||||
}
|
||||
|
||||
auto batch_size = GetValue<int64_t>(primitive->GetAttr("batch_size"));
|
||||
AbstractBasePtrList output;
|
||||
for (size_t i = 0; i < shapes.size(); ++i) {
|
||||
auto shape = shapes[i];
|
||||
shape.emplace(shape.begin(), batch_size);
|
||||
auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, types[i]);
|
||||
auto tensor = std::make_shared<abstract::AbstractTensor>(element, std::make_shared<abstract::Shape>(shape));
|
||||
output.emplace_back(tensor);
|
||||
}
|
||||
|
||||
return std::make_shared<abstract::AbstractTuple>(output);
|
||||
}
|
||||
|
||||
AbstractBasePtr DestroyInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
||||
const ShapeVector &shape = {1};
|
||||
BaseShapePtr out_shape = std::make_shared<abstract::Shape>(shape);
|
||||
return abstract::MakeAbstract(out_shape, kInt64);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ReservoirReplayBufferCreate, prim::kPrimReservoirReplayBufferCreate, CreateInfer, nullptr,
|
||||
true);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ReservoirReplayBufferPush, prim::kPrimReservoirReplayBufferPush, PushInfer, nullptr, true);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ReservoirReplayBufferSample, prim::kPrimReservoirReplayBufferSample, SampleInfer, nullptr,
|
||||
true);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ReservoirReplayBufferDestroy, prim::kPrimReservoirReplayBufferDestroy, DestroyInfer,
|
||||
nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,114 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
 * limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_RESERVOIR_REPLAY_BUFFER_H_
|
||||
#define MINDSPORE_CORE_OPS_RESERVOIR_REPLAY_BUFFER_H_
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
#include "mindapi/ir/common.h"
|
||||
#include "ir/dtype/type.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
// Operator names for the reservoir replay buffer lifecycle.
constexpr auto kNameReservoirReplayBufferCreate = "ReservoirReplayBufferCreate";
constexpr auto kNameReservoirReplayBufferPush = "ReservoirReplayBufferPush";
constexpr auto kNameReservoirReplayBufferSample = "ReservoirReplayBufferSample";
constexpr auto kNameReservoirReplayBufferDestroy = "ReservoirReplayBufferDestroy";
|
||||
|
||||
/// \brief Create a reservoir replay buffer instance; outputs its int64 handle.
class MIND_API ReservoirReplayBufferCreate : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(ReservoirReplayBufferCreate);
  /// \brief Constructor.
  ReservoirReplayBufferCreate() : BaseOperator(kNameReservoirReplayBufferCreate) { InitIOName({}, {"handle"}); }
  /// \brief Init.
  /// Refer to the parameters of python API @ref mindspore.ops._rl_inner_ops.ReservoirReplayBufferCreate for the inputs.
  void Init(const int64_t &capacity, std::vector<std::vector<int64_t>> &shapes, const std::vector<TypePtr> &types,
            const int64_t &seed0, const int64_t &seed1);

  // Attribute setters: capacity, per-tensor shapes/types, derived byte-size
  // schema and the two random seeds.
  void set_capacity(const int64_t &capacity);
  void set_shapes(const std::vector<std::vector<int64_t>> &shapes);
  void set_types(const std::vector<TypePtr> &types);
  void set_schema(const std::vector<int64_t> &schema);
  void set_seed0(const int64_t &seed0);
  void set_seed1(const int64_t &seed1);

  // Attribute getters, mirroring the setters above.
  int64_t get_capacity() const;
  std::vector<std::vector<int64_t>> get_shapes() const;
  std::vector<TypePtr> get_types() const;
  std::vector<int64_t> get_schema() const;
  int64_t get_seed0() const;
  int64_t get_seed1() const;
};
|
||||
|
||||
/// \brief Push one transition into an existing reservoir replay buffer.
class MIND_API ReservoirReplayBufferPush : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(ReservoirReplayBufferPush);
  /// \brief Constructor.
  ReservoirReplayBufferPush() : BaseOperator(kNameReservoirReplayBufferPush) { InitIOName({"transition"}, {"handle"}); }
  /// \brief Init.
  /// Refer to the parameters of python API @ref mindspore.ops._rl_inner_ops.ReservoirReplayBufferPush for the inputs.
  void Init(const int64_t &handle);

  // Handle of the buffer instance created by ReservoirReplayBufferCreate.
  void set_handle(const int64_t &handle);
  int64_t get_handle() const;
};
|
||||
|
||||
/// \brief Sample a batch of transitions from a reservoir replay buffer.
class MIND_API ReservoirReplayBufferSample : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(ReservoirReplayBufferSample);
  /// \brief Constructor.
  ReservoirReplayBufferSample() : BaseOperator(kNameReservoirReplayBufferSample) {
    // NOTE(review): output names look copied from PriorityReplayBufferSample;
    // the infer returns sampled transition tensors — confirm the names.
    InitIOName({}, {"indices", "weights"});
  }
  /// \brief Init.
  /// Refer to the parameters of python API @ref mindspore.ops._rl_inner_ops.ReservoirReplayBufferSample for the inputs.
  void Init(const int64_t &handle, const int64_t &batch_size, const std::vector<std::vector<int64_t>> &shapes,
            const std::vector<TypePtr> &types);

  // Attribute setters: buffer handle, batch size, per-tensor shapes/types and
  // the derived byte-size schema.
  void set_handle(const int64_t &handle);
  void set_batch_size(const int64_t &batch_size);
  void set_shapes(const std::vector<std::vector<int64_t>> &shapes);
  void set_types(const std::vector<TypePtr> &types);
  // Fixed parameter-name typo ("schama") to match the definition in the .cc.
  void set_schema(const std::vector<int64_t> &schema);

  // Attribute getters, mirroring the setters above.
  int64_t get_handle() const;
  int64_t get_batch_size() const;
  std::vector<std::vector<int64_t>> get_shapes() const;
  std::vector<TypePtr> get_types() const;
  std::vector<int64_t> get_schema() const;
};
|
||||
|
||||
/// \brief Destroy a reservoir replay buffer instance identified by its handle.
class MIND_API ReservoirReplayBufferDestroy : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(ReservoirReplayBufferDestroy);
  /// \brief Constructor.
  ReservoirReplayBufferDestroy() : BaseOperator(kNameReservoirReplayBufferDestroy) {
    InitIOName({"handle"}, {"handle"});
  }
  /// \brief Init.
  /// Refer to the parameters of python API @ref mindspore.ops._rl_inner_ops.ReservoirReplayBufferDestroy for the inputs.
  void Init(const int64_t &handle);

  // Handle of the buffer instance to destroy.
  void set_handle(const int64_t &handle);
  int64_t get_handle() const;
};
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_RESERVOIR_REPLAY_BUFFER_H_
|
|
@ -21,7 +21,7 @@ from mindspore.common.dtype import type_size_in_bytes
|
|||
import mindspore.context as context
|
||||
from ..._checkparam import Validator as validator
|
||||
from ...common import dtype as mstype
|
||||
from ..primitive import prim_attr_register, PrimitiveWithInfer
|
||||
from ..primitive import prim_attr_register, PrimitiveWithInfer, Primitive
|
||||
from ..._checkparam import Rel
|
||||
|
||||
|
||||
|
@ -501,6 +501,131 @@ class PriorityReplayBufferDestroy(PrimitiveWithInfer):
|
|||
return mstype.int64
|
||||
|
||||
|
||||
class ReservoirReplayBufferCreate(Primitive):
    r"""
    ReservoirReplayBufferCreate is experience container used in reinforcement learning.
    The algorithm is proposed in `Random sampling with a reservoir <https://dl.acm.org/doi/pdf/10.1145/3147.3165>`
    which used in `Deep Counterfactual Regret Minimization <https://arxiv.org/abs/1811.00164>`.
    It lets the reinforcement learning agents remember and reuse experiences from the past. Besides, It keeps an
    'unbiased' sample of previous iterations.

    Args:
        capacity (int64): Capacity of the buffer.
        shapes (list[tuple[int]]): The dimensionality of the transition.
        dtypes (list[:class:`mindspore.dtype`]): The type of the transition.
        seed0 (int): Random seed0, must be non-negative. Default: 0.
        seed1 (int): Random seed1, must be non-negative. Default: 0.

    Outputs:
        handle(Tensor): Handle of created replay buffer instance with dtype int64 and shape (1,).

    Raises:
        TypeError: The args not provided.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    @prim_attr_register
    def __init__(self, capacity, shapes, dtypes, seed0, seed1):
        """Initialize ReservoirReplayBufferCreate."""
        validator.check_int(capacity, 1, Rel.GE, "capacity", self.name)
        validator.check_value_type("shape of init data", shapes, [tuple, list], self.name)
        validator.check_value_type("dtypes of init data", dtypes, [tuple, list], self.name)
        validator.check_non_negative_int(seed0, "seed0", self.name)
        validator.check_non_negative_int(seed1, "seed1", self.name)

        # schema[i]: byte size of transition tensor i = prod(shape) * dtype size.
        # The initializer 1 keeps reduce() valid for scalar (empty) shapes.
        schema = []
        for shape, dtype in zip(shapes, dtypes):
            num_element = functools.reduce(lambda x, y: x * y, shape, 1)
            schema.append(num_element * type_size_in_bytes(dtype))
        self.add_prim_attr("schema", schema)
|
||||
|
||||
|
||||
class ReservoirReplayBufferPush(Primitive):
    r"""
    Push a transition to the replay buffer.

    Args:
        handle(Tensor): The replay buffer instance handle with dtype int64 and shape (1,).

    Outputs:
        handle(Tensor): The replay buffer instance handle with dtype int64 and shape (1,).

    Raises:
        TypeError: The replay buffer not created before.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    @prim_attr_register
    def __init__(self, handle):
        """Initialize ReservoirReplayBufferPush."""
        # Handle must refer to a buffer created by ReservoirReplayBufferCreate.
        validator.check_int(handle, 0, Rel.GE, "handle", self.name)
|
||||
|
||||
|
||||
class ReservoirReplayBufferSample(Primitive):
    r"""
    Sample a batch of transitions from the replay buffer.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Args:
        handle(Tensor): Reservoir replay buffer instance handle with dtype int64 and shape (1,).
        batch_size (int): The size of the sampled transitions.
        shapes (list[tuple[int]]): The dimensionality of the transition.
        dtypes (list[:class:`mindspore.dtype`]): The type of the transition.

    Outputs:
        tuple(Tensor): Sampled transition tensors, each with `batch_size` as the leading dimension.

    Raises:
        TypeError: The replay buffer not created before.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    @prim_attr_register
    def __init__(self, handle, batch_size, shapes, dtypes):
        """Initialize ReservoirReplayBufferSample."""
        # Fixed wrong arg name in the error message: this checks `handle`, not `capacity`.
        validator.check_int(handle, 0, Rel.GE, "handle", self.name)
        validator.check_int(batch_size, 1, Rel.GE, "batch_size", self.name)
        validator.check_value_type("shape of init data", shapes, [tuple, list], self.name)
        validator.check_value_type("dtypes of init data", dtypes, [tuple, list], self.name)

        # schema[i]: byte size of transition tensor i = prod(shape) * dtype size.
        # The initializer 1 keeps reduce() valid for scalar (empty) shapes.
        schema = []
        for shape, dtype in zip(shapes, dtypes):
            num_element = functools.reduce(lambda x, y: x * y, shape, 1)
            schema.append(num_element * type_size_in_bytes(dtype))
        self.add_prim_attr("schema", schema)
|
||||
|
||||
|
||||
class ReservoirReplayBufferDestroy(PrimitiveWithInfer):
    r"""
    Destroy the replay buffer.

    Args:
        handle(Tensor): The Replay buffer instance handle with dtype int64 and shape (1,).

    Outputs:
        Replay buffer instance handle with dtype int64 and shape (1,).

    Raises:
        TypeError: The replay buffer not created before.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    # NOTE(review): this class extends PrimitiveWithInfer but defines no
    # infer_shape/infer_dtype, unlike the sibling Priority* classes; the other
    # Reservoir* primitives extend Primitive — confirm the intended base class.
    @prim_attr_register
    def __init__(self, handle):
        """Initialize ReservoirReplayBufferDestroy."""
        # Handle must refer to a buffer created by ReservoirReplayBufferCreate.
        validator.check_int(handle, 0, Rel.GE, "handle", self.name)
||||
|
||||
|
||||
class BatchAssign(PrimitiveWithInfer):
|
||||
"""
|
||||
Assign the parameters of the source to overwrite the target.
|
||||
|
|
Loading…
Reference in New Issue