Merge pull request !42145 from chenweifeng/r1.9
i-robot 2022-09-19 06:43:55 +00:00 committed by Gitee
commit cefc3b9dc8
4 changed files with 12 additions and 11 deletions


@@ -43,8 +43,10 @@ FIFOReplayBuffer::FIFOReplayBuffer(size_t capacity, const std::vector<size_t> &s
 FIFOReplayBuffer::~FIFOReplayBuffer() {
   for (const auto &item : buffer_) {
-    if (item->addr) device::cpu::CPUMemoryPool::GetInstance().FreeTensorMem(item->addr);
-    item->addr = nullptr;
+    if (item->addr) {
+      device::cpu::CPUMemoryPool::GetInstance().FreeTensorMem(item->addr);
+      item->addr = nullptr;
+    }
   }
 }
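The destructor change is small but worth spelling out: the old code guarded only the FreeTensorMem call with the null check, while the pointer reset ran unconditionally; the new braced block frees and resets together, so a null slot is never touched and a freed slot never keeps its stale address. A minimal standalone sketch of the same guarded-free pattern, with std::malloc/std::free and a made-up AddressItem standing in for the CPU memory pool and the real buffer entries:

#include <cstddef>
#include <cstdlib>
#include <vector>

// Made-up stand-in for the real buffer entry; in the kernel the buffer holds
// addresses allocated from the CPU memory pool.
struct AddressItem {
  void *addr;
  size_t size;
};

class FifoBufferSketch {
 public:
  explicit FifoBufferSketch(const std::vector<size_t> &schema) {
    for (size_t size : schema) {
      buffer_.push_back({std::malloc(size), size});
    }
  }

  ~FifoBufferSketch() {
    for (auto &item : buffer_) {
      // Free and reset inside one null check, as the patched destructor does,
      // so a freed slot never keeps a dangling address.
      if (item.addr) {
        std::free(item.addr);
        item.addr = nullptr;
      }
    }
  }

 private:
  std::vector<AddressItem> buffer_;
};

int main() {
  FifoBufferSketch buffer({1024, 4096});  // two made-up slot sizes
  return 0;  // destructor frees both slots exactly once
}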


@@ -40,7 +40,6 @@ bool ReservoirReplayBufferCreateCpuKernel::Init(const BaseOperatorPtr &base_oper
     return false;
   }
-  const int64_t &capacity = kernel_ptr->get_capacity();
   const std::vector<int64_t> &schema = kernel_ptr->get_schema();
   const int64_t &seed0 = kernel_ptr->get_seed0();
   const int64_t &seed1 = kernel_ptr->get_seed1();
@@ -60,6 +59,7 @@ bool ReservoirReplayBufferCreateCpuKernel::Init(const BaseOperatorPtr &base_oper
                  [](const int64_t &arg) -> size_t { return LongToSize(arg); });
   auto &factory = ReservoirReplayBufferFactory::GetInstance();
+  const int64_t &capacity = kernel_ptr->get_capacity();
   std::tie(handle_, reservoir_replay_buffer_) = factory.Create(seed, capacity, schema_in_size);
   MS_EXCEPTION_IF_NULL(reservoir_replay_buffer_);
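The only functional difference in this file is where capacity is read: it now sits directly above the factory.Create call that consumes it instead of in the earlier attribute block. A rough, self-contained sketch of that Init flow; every type below (KernelStub, FactoryStub, ReservoirBufferStub) is a hypothetical stand-in for the real MindSpore classes:

#include <cstdint>
#include <memory>
#include <tuple>
#include <vector>

struct ReservoirBufferStub {};

struct FactoryStub {
  // Signature loosely mirrors factory.Create(seed, capacity, schema) from the diff.
  std::tuple<int64_t, std::shared_ptr<ReservoirBufferStub>> Create(
      uint64_t /*seed*/, int64_t /*capacity*/, const std::vector<size_t> & /*schema*/) {
    return std::make_tuple(next_handle_++, std::make_shared<ReservoirBufferStub>());
  }
  int64_t next_handle_ = 0;
};

struct KernelStub {
  // Made-up operator attributes.
  std::vector<int64_t> schema = {64, 4};
  int64_t capacity = 1000;
  uint64_t seed = 42;

  bool Init() {
    // Convert the int64_t schema attribute into byte sizes, as the real kernel does.
    std::vector<size_t> schema_in_size(schema.begin(), schema.end());
    FactoryStub factory;
    // Capacity is read right before the Create call, mirroring the patch.
    const int64_t capacity_attr = capacity;
    std::tie(handle_, buffer_) = factory.Create(seed, capacity_attr, schema_in_size);
    return buffer_ != nullptr;
  }

  int64_t handle_ = -1;
  std::shared_ptr<ReservoirBufferStub> buffer_;
};

int main() { return KernelStub().Init() ? 0 : 1; }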


@@ -193,12 +193,11 @@ void ReservoirReplayBufferSample::Init(const int64_t &handle, const int64_t &bat
     size_t tensor_size = std::accumulate(shapes[i].begin(), shapes[i].end(), type_size, std::multiplies<int64_t>());
     schema.push_back(tensor_size);
   }
-  this->set_schema(schema);
   this->set_handle(handle);
   this->set_batch_size(batch_size);
   this->set_shapes(shapes);
   this->set_types(types);
+  this->set_schema(schema);
 }
 void ReservoirReplayBufferDestroy::set_handle(const int64_t &handle) {
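For readers skimming the Init body above: each schema entry is the tensor's size in bytes, computed by folding the element size into the product of the dimensions with std::accumulate, and after this change the setters run with set_schema last. A standalone example of just the size computation, with made-up shapes and a 4-byte (float32) element size:

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Made-up per-tensor shapes and element size.
  const std::vector<std::vector<int64_t>> shapes = {{4, 84, 84}, {1}, {1}};
  const int64_t type_size = 4;  // e.g. float32

  std::vector<size_t> schema;
  for (const auto &shape : shapes) {
    // Byte size = element size * product of dimensions, folded by std::accumulate.
    int64_t tensor_size =
        std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<int64_t>());
    schema.push_back(static_cast<size_t>(tensor_size));
  }

  for (size_t size : schema) {
    std::cout << size << " bytes" << std::endl;  // 112896, 4, 4
  }
  return 0;
}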
@@ -258,7 +257,7 @@ AbstractBasePtr SampleInfer(const abstract::AnalysisEnginePtr &, const Primitive
     (void)shape.emplace(shape.begin(), batch_size);
     auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, types[i]);
     auto tensor = std::make_shared<abstract::AbstractTensor>(element, std::make_shared<abstract::Shape>(shape));
-    output.emplace_back(tensor);
+    (void)output.emplace_back(tensor);
   }
   return std::make_shared<abstract::AbstractTuple>(output);
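SampleInfer derives each output by prepending the batch size to the per-tensor shape and wrapping the result in an abstract tensor; the added (void) cast only marks the emplace_back return value as intentionally discarded, matching the cast already present on shape.emplace. A plain sketch of the shape arithmetic, with std::vector standing in for the abstract shape and tuple classes:

#include <cstdint>
#include <iostream>
#include <vector>

// Each inferred output shape is the per-tensor shape with the batch size prepended.
std::vector<std::vector<int64_t>> InferSampleShapes(int64_t batch_size,
                                                    const std::vector<std::vector<int64_t>> &shapes) {
  std::vector<std::vector<int64_t>> output;
  for (auto shape : shapes) {  // copy, so the input shapes stay untouched
    (void)shape.emplace(shape.begin(), batch_size);
    (void)output.emplace_back(shape);
  }
  return output;
}

int main() {
  // Made-up shapes: sampling 32 transitions of a {4, 84, 84} observation and a scalar reward.
  for (const auto &shape : InferSampleShapes(32, {{4, 84, 84}, {1}})) {
    for (int64_t dim : shape) std::cout << dim << ' ';
    std::cout << '\n';  // prints "32 4 84 84" then "32 1"
  }
  return 0;
}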


@@ -19,11 +19,11 @@ from __future__ import absolute_import
 import functools
 from mindspore.common.dtype import type_size_in_bytes
 import mindspore.context as context
-from ..._checkparam import Validator as validator
-from ...common import dtype as mstype
-from ..primitive import prim_attr_register, PrimitiveWithInfer, Primitive
-from ..._checkparam import Rel
-from ...communication.management import GlobalComm
+from mindspore._checkparam import Validator as validator
+from mindspore.common import dtype as mstype
+from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer, Primitive
+from mindspore._checkparam import Rel
+from mindspore.communication.management import GlobalComm
 class EnvCreate(PrimitiveWithInfer):