!39192 reservoir replay buffer primitive

Merge pull request !39192 from chenweifeng/reservior-replay-buffer-primitive
This commit is contained in:
i-robot 2022-08-01 10:59:37 +00:00 committed by Gitee
commit f198da7489
6 changed files with 537 additions and 10 deletions


@@ -1330,6 +1330,10 @@ GVAR_DEF(PrimitivePtr, kPrimTensorArray, std::make_shared<Primitive>("TensorArray"));
GVAR_DEF(PrimitivePtr, kPrimTensorArrayWrite, std::make_shared<Primitive>("TensorArrayWrite"));
GVAR_DEF(PrimitivePtr, kPrimTensorArrayGather, std::make_shared<Primitive>("TensorArrayGather"));
GVAR_DEF(PrimitivePtr, kPrimKMeansCentroids, std::make_shared<Primitive>("KMeansCentroids"));
GVAR_DEF(PrimitivePtr, kPrimReservoirReplayBufferCreate, std::make_shared<Primitive>("ReservoirReplayBufferCreate"));
GVAR_DEF(PrimitivePtr, kPrimReservoirReplayBufferPush, std::make_shared<Primitive>("ReservoirReplayBufferPush"));
GVAR_DEF(PrimitivePtr, kPrimReservoirReplayBufferSample, std::make_shared<Primitive>("ReservoirReplayBufferSample"));
GVAR_DEF(PrimitivePtr, kPrimReservoirReplayBufferDestroy, std::make_shared<Primitive>("ReservoirReplayBufferDestroy"));
// AdamApplyOne
GVAR_DEF(PrimitivePtr, kPrimAdamApplyOne, std::make_shared<Primitive>("AdamApplyOne"));


@@ -327,6 +327,14 @@ constexpr auto kSearchStep = "search_step";
constexpr auto kWithOffset = "with_offset";
constexpr auto kLinearSumAssignment = "linear_sum_assignment";
constexpr auto kNbins = "nbins";
constexpr auto kCapacity = "capacity";
constexpr auto kShapes = "shapes";
constexpr auto kTypes = "types";
constexpr auto kSchema = "schema";
constexpr auto kSeed0 = "seed0";
constexpr auto kSeed1 = "seed1";
constexpr auto kHandle = "handle";
constexpr auto kBatchSize = "batch_size";
constexpr size_t kInputIndex0 = 0;
constexpr size_t kInputIndex1 = 1;


@@ -32,15 +32,6 @@ constexpr auto kNamePriorityReplayBufferSample = "PriorityReplayBufferSample";
constexpr auto kNamePriorityReplayBufferUpdate = "PriorityReplayBufferUpdate";
constexpr auto kNamePriorityReplayBufferDestroy = "PriorityReplayBufferDestroy";
constexpr auto kCapacity = "capacity";
constexpr auto kShapes = "shapes";
constexpr auto kTypes = "types";
constexpr auto kSchema = "schema";
constexpr auto kSeed0 = "seed0";
constexpr auto kSeed1 = "seed1";
constexpr auto kHandle = "handle";
constexpr auto kBatchSize = "batch_size";
class MIND_API PriorityReplayBufferCreate : public BaseOperator {
public:
MIND_API_BASE_MEMBER(PriorityReplayBufferCreate);


@@ -0,0 +1,285 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/reservoir_replay_buffer.h"
#include <string>
#include <algorithm>
#include <functional>
#include <memory>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
void ReservoirReplayBufferCreate::set_capacity(const int64_t &capacity) {
(void)this->AddAttr(kCapacity, api::MakeValue(capacity));
}
void ReservoirReplayBufferCreate::set_shapes(const std::vector<std::vector<int64_t>> &shapes) {
(void)this->AddAttr(kShapes, api::MakeValue(shapes));
}
void ReservoirReplayBufferCreate::set_types(const std::vector<TypePtr> &types) {
auto res = std::dynamic_pointer_cast<PrimitiveC>(impl_);
MS_EXCEPTION_IF_NULL(res);
res->AddAttr(kTypes, MakeValue(types));
}
void ReservoirReplayBufferCreate::set_schema(const std::vector<int64_t> &schema) {
(void)this->AddAttr(kSchema, api::MakeValue(schema));
}
void ReservoirReplayBufferCreate::set_seed0(const int64_t &seed0) {
(void)this->AddAttr(kSeed0, api::MakeValue(seed0));
}
void ReservoirReplayBufferCreate::set_seed1(const int64_t &seed1) {
(void)this->AddAttr(kSeed1, api::MakeValue(seed1));
}
int64_t ReservoirReplayBufferCreate::get_capacity() const {
auto value_ptr = GetAttr(kCapacity);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
std::vector<std::vector<int64_t>> ReservoirReplayBufferCreate::get_shapes() const {
auto value_ptr = GetAttr(kShapes);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<std::vector<int64_t>>>(value_ptr);
}
std::vector<TypePtr> ReservoirReplayBufferCreate::get_types() const {
auto res = std::dynamic_pointer_cast<PrimitiveC>(impl_);
MS_EXCEPTION_IF_NULL(res);
return GetValue<std::vector<TypePtr>>(res->GetAttr(kTypes));
}
std::vector<int64_t> ReservoirReplayBufferCreate::get_schema() const {
auto value_ptr = GetAttr(kSchema);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr);
}
int64_t ReservoirReplayBufferCreate::get_seed0() const {
auto value_ptr = GetAttr(kSeed0);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
int64_t ReservoirReplayBufferCreate::get_seed1() const {
auto value_ptr = GetAttr(kSeed1);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
void ReservoirReplayBufferCreate::Init(const int64_t &capacity, std::vector<std::vector<int64_t>> &shapes,
const std::vector<TypePtr> &types, const int64_t &seed0, const int64_t &seed1) {
auto op_name = this->name();
if (shapes.size() != types.size()) {
MS_LOG(EXCEPTION) << "For " << op_name
<< " the rank of shapes and types should be the same, but got the rank of shapes is "
<< shapes.size() << ", and types is " << types.size();
}
std::vector<int64_t> schema;
for (size_t i = 0; i < shapes.size(); i++) {
size_t type_size = GetTypeByte(types[i]);
size_t tensor_size = std::accumulate(shapes[i].begin(), shapes[i].end(), type_size, std::multiplies<int64_t>());
schema.push_back(tensor_size);
}
this->set_capacity(capacity);
this->set_shapes(shapes);
this->set_types(types);
this->set_schema(schema);
this->set_seed0(seed0);
this->set_seed1(seed1);
}
void ReservoirReplayBufferPush::set_handle(const int64_t &handle) {
(void)this->AddAttr(kHandle, api::MakeValue(handle));
}
int64_t ReservoirReplayBufferPush::get_handle() const {
auto value_ptr = GetAttr(kHandle);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
void ReservoirReplayBufferPush::Init(const int64_t &handle) { this->set_handle(handle); }
void ReservoirReplayBufferSample::set_handle(const int64_t &handle) {
(void)this->AddAttr(kHandle, api::MakeValue(handle));
}
void ReservoirReplayBufferSample::set_batch_size(const int64_t &batch_size) {
(void)this->AddAttr(kBatchSize, api::MakeValue(batch_size));
}
void ReservoirReplayBufferSample::set_shapes(const std::vector<std::vector<int64_t>> &shapes) {
(void)this->AddAttr(kShapes, api::MakeValue(shapes));
}
void ReservoirReplayBufferSample::set_types(const std::vector<TypePtr> &types) {
auto res = std::dynamic_pointer_cast<PrimitiveC>(impl_);
MS_EXCEPTION_IF_NULL(res);
res->AddAttr(kTypes, MakeValue(types));
}
void ReservoirReplayBufferSample::set_schema(const std::vector<int64_t> &schema) {
(void)this->AddAttr(kSchema, api::MakeValue(schema));
}
int64_t ReservoirReplayBufferSample::get_handle() const {
auto value_ptr = GetAttr(kHandle);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
int64_t ReservoirReplayBufferSample::get_batch_size() const {
auto value_ptr = GetAttr(kBatchSize);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
std::vector<std::vector<int64_t>> ReservoirReplayBufferSample::get_shapes() const {
auto value_ptr = GetAttr(kShapes);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<std::vector<int64_t>>>(value_ptr);
}
std::vector<TypePtr> ReservoirReplayBufferSample::get_types() const {
auto res = std::dynamic_pointer_cast<PrimitiveC>(impl_);
MS_EXCEPTION_IF_NULL(res);
return GetValue<std::vector<TypePtr>>(res->GetAttr(kTypes));
}
std::vector<int64_t> ReservoirReplayBufferSample::get_schema() const {
auto value_ptr = GetAttr(kSchema);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr);
}
void ReservoirReplayBufferSample::Init(const int64_t &handle, const int64_t &batch_size,
const std::vector<std::vector<int64_t>> &shapes,
const std::vector<TypePtr> &types) {
auto op_name = this->name();
if (shapes.size() != types.size()) {
MS_LOG(EXCEPTION) << "For " << op_name
<< " the rank of shapes and types should be the same, but got the rank of shapes is "
<< shapes.size() << ", and types is " << types.size();
}
std::vector<int64_t> schema;
for (size_t i = 0; i < shapes.size(); i++) {
size_t type_size = GetTypeByte(types[i]);
size_t tensor_size = std::accumulate(shapes[i].begin(), shapes[i].end(), type_size, std::multiplies<int64_t>());
schema.push_back(tensor_size);
}
this->set_handle(handle);
this->set_batch_size(batch_size);
this->set_shapes(shapes);
this->set_types(types);
this->set_schema(schema);
}
void ReservoirReplayBufferDestroy::set_handle(const int64_t &handle) {
(void)this->AddAttr(kHandle, api::MakeValue(handle));
}
int64_t ReservoirReplayBufferDestroy::get_handle() const {
auto value_ptr = GetAttr(kHandle);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
void ReservoirReplayBufferDestroy::Init(const int64_t &handle) { this->set_handle(handle); }
MIND_API_OPERATOR_IMPL(ReservoirReplayBufferCreate, BaseOperator);
MIND_API_OPERATOR_IMPL(ReservoirReplayBufferPush, BaseOperator);
MIND_API_OPERATOR_IMPL(ReservoirReplayBufferSample, BaseOperator);
MIND_API_OPERATOR_IMPL(ReservoirReplayBufferDestroy, BaseOperator);
namespace {
AbstractBasePtr CreateInfer(const abstract::AnalysisEnginePtr &prim, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const std::string &prim_name = primitive->name();
if (input_args.size() != 0) {
MS_LOG(EXCEPTION) << "For Primitive[" << prim_name << "], the input should be empty.";
}
const ShapeVector &shape = {1};
BaseShapePtr out_shape = std::make_shared<abstract::Shape>(shape);
return abstract::MakeAbstract(out_shape, kInt64);
}
AbstractBasePtr PushInfer(const abstract::AnalysisEnginePtr &prim, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &) {
MS_EXCEPTION_IF_NULL(primitive);
const ShapeVector &shape = {1};
BaseShapePtr out_shape = std::make_shared<abstract::Shape>(shape);
return abstract::MakeAbstract(out_shape, kInt64);
}
AbstractBasePtr SampleInfer(const abstract::AnalysisEnginePtr &prim, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const std::string &prim_name = primitive->name();
auto types = GetValue<std::vector<TypePtr>>(primitive->GetAttr("dtypes"));
auto shapes = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr("shapes"));
if (types.size() != shapes.size()) {
MS_LOG(EXCEPTION) << "For Primitive[" << prim_name << "], the types and shapes rank should be same.";
}
auto batch_size = GetValue<int64_t>(primitive->GetAttr("batch_size"));
AbstractBasePtrList output;
for (size_t i = 0; i < shapes.size(); ++i) {
auto shape = shapes[i];
shape.emplace(shape.begin(), batch_size);
auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, types[i]);
auto tensor = std::make_shared<abstract::AbstractTensor>(element, std::make_shared<abstract::Shape>(shape));
output.emplace_back(tensor);
}
return std::make_shared<abstract::AbstractTuple>(output);
}
AbstractBasePtr DestroyInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const ShapeVector &shape = {1};
BaseShapePtr out_shape = std::make_shared<abstract::Shape>(shape);
return abstract::MakeAbstract(out_shape, kInt64);
}
} // namespace
REGISTER_PRIMITIVE_EVAL_IMPL(ReservoirReplayBufferCreate, prim::kPrimReservoirReplayBufferCreate, CreateInfer, nullptr,
true);
REGISTER_PRIMITIVE_EVAL_IMPL(ReservoirReplayBufferPush, prim::kPrimReservoirReplayBufferPush, PushInfer, nullptr, true);
REGISTER_PRIMITIVE_EVAL_IMPL(ReservoirReplayBufferSample, prim::kPrimReservoirReplayBufferSample, SampleInfer, nullptr,
true);
REGISTER_PRIMITIVE_EVAL_IMPL(ReservoirReplayBufferDestroy, prim::kPrimReservoirReplayBufferDestroy, DestroyInfer,
nullptr, true);
} // namespace ops
} // namespace mindspore


@@ -0,0 +1,114 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_RESERVOIR_REPLAY_BUFFER_H_
#define MINDSPORE_CORE_OPS_RESERVOIR_REPLAY_BUFFER_H_
#include <memory>
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
#include "mindapi/ir/common.h"
#include "ir/dtype/type.h"
namespace mindspore {
namespace ops {
constexpr auto kNameReservoirReplayBufferCreate = "ReservoirReplayBufferCreate";
constexpr auto kNameReservoirReplayBufferPush = "ReservoirReplayBufferPush";
constexpr auto kNameReservoirReplayBufferSample = "ReservoirReplayBufferSample";
constexpr auto kNameReservoirReplayBufferDestroy = "ReservoirReplayBufferDestroy";
class MIND_API ReservoirReplayBufferCreate : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ReservoirReplayBufferCreate);
/// \brief Constructor.
ReservoirReplayBufferCreate() : BaseOperator(kNameReservoirReplayBufferCreate) { InitIOName({}, {"handle"}); }
/// \brief Init.
/// Refer to the parameters of python API @ref mindspore.ops._rl_inner_ops.ReservoirReplayBufferCreate for the inputs.
void Init(const int64_t &capacity, std::vector<std::vector<int64_t>> &shapes, const std::vector<TypePtr> &types,
const int64_t &seed0, const int64_t &seed1);
void set_capacity(const int64_t &capacity);
void set_shapes(const std::vector<std::vector<int64_t>> &shapes);
void set_types(const std::vector<TypePtr> &types);
void set_schema(const std::vector<int64_t> &schema);
void set_seed0(const int64_t &seed0);
void set_seed1(const int64_t &seed1);
int64_t get_capacity() const;
std::vector<std::vector<int64_t>> get_shapes() const;
std::vector<TypePtr> get_types() const;
std::vector<int64_t> get_schema() const;
int64_t get_seed0() const;
int64_t get_seed1() const;
};
class MIND_API ReservoirReplayBufferPush : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ReservoirReplayBufferPush);
/// \brief Constructor.
ReservoirReplayBufferPush() : BaseOperator(kNameReservoirReplayBufferPush) { InitIOName({"transition"}, {"handle"}); }
/// \brief Init.
/// Refer to the parameters of python API @ref mindspore.ops._rl_inner_ops.ReservoirReplayBufferPush for the inputs.
void Init(const int64_t &handle);
void set_handle(const int64_t &handle);
int64_t get_handle() const;
};
class MIND_API ReservoirReplayBufferSample : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ReservoirReplayBufferSample);
/// \brief Constructor.
ReservoirReplayBufferSample() : BaseOperator(kNameReservoirReplayBufferSample) {
InitIOName({}, {"indices", "weights"});
}
/// \brief Init.
/// Refer to the parameters of python API @ref mindspore.ops._rl_inner_ops.ReservoirReplayBufferSample for the inputs.
void Init(const int64_t &handle, const int64_t &batch_size, const std::vector<std::vector<int64_t>> &shapes,
const std::vector<TypePtr> &types);
void set_handle(const int64_t &handle);
void set_batch_size(const int64_t &batch_size);
void set_shapes(const std::vector<std::vector<int64_t>> &shapes);
void set_types(const std::vector<TypePtr> &types);
void set_schema(const std::vector<int64_t> &schema);
int64_t get_handle() const;
int64_t get_batch_size() const;
std::vector<std::vector<int64_t>> get_shapes() const;
std::vector<TypePtr> get_types() const;
std::vector<int64_t> get_schema() const;
};
class MIND_API ReservoirReplayBufferDestroy : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ReservoirReplayBufferDestroy);
/// \brief Constructor.
ReservoirReplayBufferDestroy() : BaseOperator(kNameReservoirReplayBufferDestroy) {
InitIOName({"handle"}, {"handle"});
}
/// \brief Init.
/// Refer to the parameters of python API @ref mindspore.ops._rl_inner_ops.ReservoirReplayBufferDestroy for the inputs.
void Init(const int64_t &handle);
void set_handle(const int64_t &handle);
int64_t get_handle() const;
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_RESERVOIR_REPLAY_BUFFER_H_


@@ -21,7 +21,7 @@ from mindspore.common.dtype import type_size_in_bytes
import mindspore.context as context
from ..._checkparam import Validator as validator
from ...common import dtype as mstype
from ..primitive import prim_attr_register, PrimitiveWithInfer
from ..primitive import prim_attr_register, PrimitiveWithInfer, Primitive
from ..._checkparam import Rel
@@ -501,6 +501,131 @@ class PriorityReplayBufferDestroy(PrimitiveWithInfer):
return mstype.int64
class ReservoirReplayBufferCreate(Primitive):
r"""
ReservoirReplayBufferCreate is an experience container used in reinforcement learning.
The algorithm is proposed in `Random sampling with a reservoir <https://dl.acm.org/doi/pdf/10.1145/3147.3165>`_,
which is used in `Deep Counterfactual Regret Minimization <https://arxiv.org/abs/1811.00164>`_.
It lets reinforcement learning agents remember and reuse experiences from the past. Besides, it keeps an
'unbiased' sample of previous iterations.
Args:
capacity (int64): Capacity of the buffer.
shapes (list[tuple[int]]): The dimensionality of the transition.
dtypes (list[:class:`mindspore.dtype`]): The type of the transition.
seed0 (int): Random seed0, must be non-negative. Default: 0.
seed1 (int): Random seed1, must be non-negative. Default: 0.
Outputs:
handle(Tensor): Handle of created replay buffer instance with dtype int64 and shape (1,).
Raises:
TypeError: The arguments are not provided.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
@prim_attr_register
def __init__(self, capacity, shapes, dtypes, seed0, seed1):
"""Initialize ReservoirReplayBufferCreate."""
validator.check_int(capacity, 1, Rel.GE, "capacity", self.name)
validator.check_value_type("shape of init data", shapes, [tuple, list], self.name)
validator.check_value_type("dtypes of init data", dtypes, [tuple, list], self.name)
validator.check_non_negative_int(seed0, "seed0", self.name)
validator.check_non_negative_int(seed1, "seed1", self.name)
schema = []
for shape, dtype in zip(shapes, dtypes):
num_element = functools.reduce(lambda x, y: x * y, shape)
schema.append(num_element * type_size_in_bytes(dtype))
self.add_prim_attr("schema", schema)
class ReservoirReplayBufferPush(Primitive):
r"""
Push a transition to the replay buffer.
Args:
handle(Tensor): The replay buffer instance handle with dtype int64 and shape (1,).
Outputs:
handle(Tensor): The replay buffer instance handle with dtype int64 and shape (1,).
Raises:
TypeError: The replay buffer has not been created.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
@prim_attr_register
def __init__(self, handle):
"""Initialize ReservoirReplayBufferPush."""
validator.check_int(handle, 0, Rel.GE, "handle", self.name)
class ReservoirReplayBufferSample(Primitive):
r"""
Sample transitions from the replay buffer.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
Args:
handle(Tensor): Reservoir replay buffer instance handle with dtype int64 and shape (1,).
batch_size (int): The size of the sampled transitions.
shapes (list[tuple[int]]): The dimensionality of the transition.
dtypes (list[:class:`mindspore.dtype`]): The type of the transition.
Outputs:
tuple(Tensor): The sampled transitions.
Raises:
TypeError: The replay buffer has not been created.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
@prim_attr_register
def __init__(self, handle, batch_size, shapes, dtypes):
"""Initialize PriorityReplaBufferSample."""
validator.check_int(handle, 0, Rel.GE, "capacity", self.name)
validator.check_int(batch_size, 1, Rel.GE, "batch_size", self.name)
validator.check_value_type("shape of init data", shapes, [tuple, list], self.name)
validator.check_value_type("dtypes of init data", dtypes, [tuple, list], self.name)
schema = []
for shape, dtype in zip(shapes, dtypes):
num_element = functools.reduce(lambda x, y: x * y, shape)
schema.append(num_element * type_size_in_bytes(dtype))
self.add_prim_attr("schema", schema)
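A sampling sketch for the same buffer; shapes and dtypes must match the ones passed at creation. Per the C++ SampleInfer above, each output tensor gets batch_size prepended to its per-item shape. The import path is an assumption.

import mindspore as ms
# Hypothetical import path, as above.
from mindspore.ops.operations._rl_inner_ops import ReservoirReplayBufferSample

shapes = [(4,), (1,)]
dtypes = [ms.float32, ms.int32]
sample_op = ReservoirReplayBufferSample(handle=0, batch_size=32, shapes=shapes, dtypes=dtypes)
obs, act = sample_op()    # expected shapes: (32, 4) float32 and (32, 1) int32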
class ReservoirReplayBufferDestroy(PrimitiveWithInfer):
r"""
Destroy the replay buffer.
Args:
handle(Tensor): The replay buffer instance handle with dtype int64 and shape (1,).
Outputs:
handle(Tensor): The replay buffer instance handle with dtype int64 and shape (1,).
Raises:
TypeError: The replay buffer has not been created.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
@prim_attr_register
def __init__(self, handle):
"""Initialize ReservoirReplayBufferDestroy."""
validator.check_int(handle, 0, Rel.GE, "handle", self.name)
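Finally, a teardown sketch that releases the buffer; the import path and the no-argument call are assumptions (the handle is carried as a primitive attribute).

# Hypothetical import path, as above.
from mindspore.ops.operations._rl_inner_ops import ReservoirReplayBufferDestroy

destroy_op = ReservoirReplayBufferDestroy(handle=0)
destroy_op()              # expected to free the buffer and return the handle: int64 Tensor, shape (1,)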
class BatchAssign(PrimitiveWithInfer):
"""
Assign the parameters of the source to overwrite the target.