forked from mindspore-Ecosystem/mindspore
commit
cefc3b9dc8
|
@ -43,8 +43,10 @@ FIFOReplayBuffer::FIFOReplayBuffer(size_t capacity, const std::vector<size_t> &s
|
|||
|
||||
FIFOReplayBuffer::~FIFOReplayBuffer() {
|
||||
for (const auto &item : buffer_) {
|
||||
if (item->addr) device::cpu::CPUMemoryPool::GetInstance().FreeTensorMem(item->addr);
|
||||
item->addr = nullptr;
|
||||
if (item->addr) {
|
||||
device::cpu::CPUMemoryPool::GetInstance().FreeTensorMem(item->addr);
|
||||
item->addr = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -40,7 +40,6 @@ bool ReservoirReplayBufferCreateCpuKernel::Init(const BaseOperatorPtr &base_oper
|
|||
return false;
|
||||
}
|
||||
|
||||
const int64_t &capacity = kernel_ptr->get_capacity();
|
||||
const std::vector<int64_t> &schema = kernel_ptr->get_schema();
|
||||
const int64_t &seed0 = kernel_ptr->get_seed0();
|
||||
const int64_t &seed1 = kernel_ptr->get_seed1();
|
||||
|
@ -60,6 +59,7 @@ bool ReservoirReplayBufferCreateCpuKernel::Init(const BaseOperatorPtr &base_oper
|
|||
[](const int64_t &arg) -> size_t { return LongToSize(arg); });
|
||||
|
||||
auto &factory = ReservoirReplayBufferFactory::GetInstance();
|
||||
const int64_t &capacity = kernel_ptr->get_capacity();
|
||||
std::tie(handle_, reservoir_replay_buffer_) = factory.Create(seed, capacity, schema_in_size);
|
||||
MS_EXCEPTION_IF_NULL(reservoir_replay_buffer_);
|
||||
|
||||
|
|
|
@ -193,12 +193,11 @@ void ReservoirReplayBufferSample::Init(const int64_t &handle, const int64_t &bat
|
|||
size_t tensor_size = std::accumulate(shapes[i].begin(), shapes[i].end(), type_size, std::multiplies<int64_t>());
|
||||
schema.push_back(tensor_size);
|
||||
}
|
||||
|
||||
this->set_schema(schema);
|
||||
this->set_handle(handle);
|
||||
this->set_batch_size(batch_size);
|
||||
this->set_shapes(shapes);
|
||||
this->set_types(types);
|
||||
this->set_schema(schema);
|
||||
}
|
||||
|
||||
void ReservoirReplayBufferDestroy::set_handle(const int64_t &handle) {
|
||||
|
@ -258,7 +257,7 @@ AbstractBasePtr SampleInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
(void)shape.emplace(shape.begin(), batch_size);
|
||||
auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, types[i]);
|
||||
auto tensor = std::make_shared<abstract::AbstractTensor>(element, std::make_shared<abstract::Shape>(shape));
|
||||
output.emplace_back(tensor);
|
||||
(void)output.emplace_back(tensor);
|
||||
}
|
||||
|
||||
return std::make_shared<abstract::AbstractTuple>(output);
|
||||
|
|
|
@ -19,11 +19,11 @@ from __future__ import absolute_import
|
|||
import functools
|
||||
from mindspore.common.dtype import type_size_in_bytes
|
||||
import mindspore.context as context
|
||||
from ..._checkparam import Validator as validator
|
||||
from ...common import dtype as mstype
|
||||
from ..primitive import prim_attr_register, PrimitiveWithInfer, Primitive
|
||||
from ..._checkparam import Rel
|
||||
from ...communication.management import GlobalComm
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer, Primitive
|
||||
from mindspore._checkparam import Rel
|
||||
from mindspore.communication.management import GlobalComm
|
||||
|
||||
|
||||
class EnvCreate(PrimitiveWithInfer):
|
||||
|
|
Loading…
Reference in New Issue