sync code

This commit is contained in:
liyong 2021-09-08 15:39:21 +08:00
parent 97eae06817
commit f752054c19
128 changed files with 556 additions and 518 deletions

View File

@ -34,7 +34,7 @@ bool set_seed(int32_t seed) {
MS_LOG(ERROR) << "Seed given is not within the required range: " << seed;
return false;
}
_config->set_seed((uint32_t)seed);
_config->set_seed(static_cast<uint32_t>(seed));
return true;
}
@ -73,7 +73,7 @@ bool set_monitor_sampling_interval(int32_t interval) {
MS_LOG(ERROR) << "Interval given is not within the required range: " << interval;
return false;
}
_config->set_monitor_sampling_interval((uint32_t)interval);
_config->set_monitor_sampling_interval(static_cast<uint32_t>(interval));
return true;
}
@ -86,7 +86,7 @@ bool set_callback_timeback(int32_t timeout) {
MS_LOG(ERROR) << "Timeout given is not within the required range: " << timeout;
return false;
}
_config->set_callback_timeout((uint32_t)timeout);
_config->set_callback_timeout(static_cast<uint32_t>(timeout));
return true;
}

View File

@ -209,7 +209,6 @@ bool Dataset::DeviceQueueCharIF(const std::vector<char> &queue_name, const std::
MS_LOG(ERROR) << "ToDevice: Failed to get consumer.";
return false;
}
rc = consumer->Init(ds);
if (rc.IsError()) {
MS_LOG(ERROR) << "ToDevice: Failed to init. Error status: " << rc;
@ -252,7 +251,6 @@ bool Dataset::SaveCharIF(const std::vector<char> &dataset_path, int32_t num_file
MS_LOG(ERROR) << "ToDevice: Failed to get consumer.";
return false;
}
rc = consumer->Init(ds->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "CreateSaver failed." << rc;
@ -283,11 +281,15 @@ bool Dataset::SaveCharIF(const std::vector<char> &dataset_path, int32_t num_file
Dataset::Dataset() { tree_getters_ = std::make_shared<TreeGetters>(); }
int64_t Dataset::GetDatasetSize(bool estimate) {
int64_t dataset_size;
int64_t dataset_size = -1;
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
std::shared_ptr<DatasetSizeGetter> size_getter = std::make_shared<DatasetSizeGetter>();
DatasetSizeGetter *consumer = size_getter.get();
if (consumer == nullptr) {
MS_LOG(ERROR) << "DatasetSizeGetter: Failed to get consumer.";
return -1;
}
runtime_context->AssignConsumer(size_getter);
RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), -1);
RETURN_SECOND_IF_ERROR(consumer->GetDatasetSize(&dataset_size, estimate), -1);
@ -299,6 +301,10 @@ std::vector<mindspore::DataType> Dataset::GetOutputTypes() {
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
TreeGetters *consumer = tree_getters_.get();
if (consumer == nullptr) {
MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
return std::vector<mindspore::DataType>();
}
runtime_context->AssignConsumer(tree_getters_);
RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
RETURN_SECOND_IF_ERROR(consumer->GetOutputTypes(&types), {});
@ -314,6 +320,10 @@ std::vector<std::vector<int64_t>> Dataset::GetOutputShapes() {
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
TreeGetters *consumer = tree_getters_.get();
if (consumer == nullptr) {
MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
return std::vector<std::vector<int64_t>>();
}
runtime_context->AssignConsumer(tree_getters_);
RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
RETURN_SECOND_IF_ERROR(consumer->GetOutputShapes(&shapes), {});
@ -324,10 +334,14 @@ std::vector<std::vector<int64_t>> Dataset::GetOutputShapes() {
}
int64_t Dataset::GetNumClasses() {
int64_t num_classes;
int64_t num_classes = -1;
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
TreeGetters *consumer = tree_getters_.get();
if (consumer == nullptr) {
MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
return -1;
}
runtime_context->AssignConsumer(tree_getters_);
RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), -1);
RETURN_SECOND_IF_ERROR(consumer->GetNumClasses(&num_classes), -1);
@ -339,6 +353,10 @@ std::vector<std::vector<char>> Dataset::GetColumnNamesCharIF() {
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
TreeGetters *consumer = tree_getters_.get();
if (consumer == nullptr) {
MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
return std::vector<std::vector<char>>();
}
runtime_context->AssignConsumer(tree_getters_);
RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
RETURN_SECOND_IF_ERROR(consumer->GetColumnNames(&col_names), {});
@ -350,6 +368,10 @@ std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> Dataset::GetClas
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
TreeGetters *consumer = tree_getters_.get();
if (consumer == nullptr) {
MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
return std::vector<std::pair<std::vector<char>, std::vector<int32_t>>>();
}
runtime_context->AssignConsumer(tree_getters_);
RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
RETURN_SECOND_IF_ERROR(consumer->GetClassIndexing(&output_class_indexing), {});
@ -487,10 +509,10 @@ TakeDataset::TakeDataset(std::shared_ptr<Dataset> input, int32_t count) {
ZipDataset::ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
std::vector<std::shared_ptr<DatasetNode>> all_datasets;
(void)std::transform(
datasets.begin(), datasets.end(), std::back_inserter(all_datasets),
[](std::shared_ptr<Dataset> dataset) -> std::shared_ptr<DatasetNode> { return dataset->IRNode(); });
(void)std::transform(datasets.begin(), datasets.end(), std::back_inserter(all_datasets),
[](std::shared_ptr<Dataset> dataset) -> std::shared_ptr<DatasetNode> {
return (dataset != nullptr) ? dataset->IRNode() : nullptr;
});
auto ds = std::make_shared<ZipNode>(all_datasets);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
@ -538,6 +560,10 @@ std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocabCharIF(
auto consumer = std::make_unique<BuildVocabConsumer>();
BuildVocabConsumer *bv_consumer = consumer.get();
if (bv_consumer == nullptr) {
MS_LOG(ERROR) << "BuildVocabConsumer: Failed to get bv_consumer.";
return nullptr;
}
rc = consumer->Init(ds);
if (rc.IsError()) {
MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init consumer. Error status: " << rc;
@ -571,6 +597,10 @@ std::shared_ptr<Vocab> Dataset::BuildVocabCharIF(const std::vector<std::vector<c
auto consumer = std::make_unique<BuildVocabConsumer>();
BuildVocabConsumer *bv_consumer = consumer.get();
if (bv_consumer == nullptr) {
MS_LOG(ERROR) << "BuildVocabConsumer: Failed to get bv_consumer.";
return nullptr;
}
rc = consumer->Init(ds);
if (rc.IsError()) {
MS_LOG(ERROR) << "BuildVocab: Failed to init consumer. Error status: " << rc;

View File

@ -54,20 +54,27 @@ struct Execute::ExtraInfo {
#endif
};
Execute::Execute(std::shared_ptr<TensorOperation> op, MapTargetDevice device_type, uint32_t device_id) {
ops_.emplace_back(std::move(op));
device_type_ = device_type;
info_ = std::make_shared<ExtraInfo>();
Status Execute::InitResource(MapTargetDevice device_type, uint32_t device_id) {
#ifdef ENABLE_ACL
if (device_type_ == MapTargetDevice::kAscend310) {
device_resource_ = std::make_shared<AscendResource>();
Status rc = device_resource_->InitResource(device_id);
if (!rc.IsOk()) {
device_resource_ = nullptr;
MS_LOG(ERROR) << "Initialize Ascend310 resource fail.";
std::string err_msg = "Initialize Ascend310 resource fail";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_UNEXPECTED(err_msg);
}
}
#endif
return Status::OK();
}
Execute::Execute(std::shared_ptr<TensorOperation> op, MapTargetDevice device_type, uint32_t device_id) {
ops_.emplace_back(std::move(op));
device_type_ = device_type;
info_ = std::make_shared<ExtraInfo>();
(void)InitResource(device_type, device_id);
}
Execute::Execute(std::shared_ptr<TensorTransform> op, MapTargetDevice device_type, uint32_t device_id) {
@ -76,16 +83,7 @@ Execute::Execute(std::shared_ptr<TensorTransform> op, MapTargetDevice device_typ
info_ = std::make_shared<ExtraInfo>();
device_type_ = device_type;
#ifdef ENABLE_ACL
if (device_type_ == MapTargetDevice::kAscend310) {
device_resource_ = std::make_shared<AscendResource>();
Status rc = device_resource_->InitResource(device_id);
if (!rc.IsOk()) {
device_resource_ = nullptr;
MS_LOG(ERROR) << "Initialize Ascend310 resource fail.";
}
}
#endif
(void)InitResource(device_type, device_id);
}
Execute::Execute(std::reference_wrapper<TensorTransform> op, MapTargetDevice device_type, uint32_t device_id) {
@ -96,16 +94,7 @@ Execute::Execute(std::reference_wrapper<TensorTransform> op, MapTargetDevice dev
info_ = std::make_shared<ExtraInfo>();
info_->init_with_shared_ptr_ = false;
device_type_ = device_type;
#ifdef ENABLE_ACL
if (device_type_ == MapTargetDevice::kAscend310) {
device_resource_ = std::make_shared<AscendResource>();
Status rc = device_resource_->InitResource(device_id);
if (!rc.IsOk()) {
device_resource_ = nullptr;
MS_LOG(ERROR) << "Initialize Ascend310 resource fail.";
}
}
#endif
(void)InitResource(device_type, device_id);
}
// Execute function for the example case: auto decode(new vision::Decode());
@ -116,31 +105,13 @@ Execute::Execute(TensorTransform *op, MapTargetDevice device_type, uint32_t devi
info_ = std::make_shared<ExtraInfo>();
device_type_ = device_type;
#ifdef ENABLE_ACL
if (device_type_ == MapTargetDevice::kAscend310) {
device_resource_ = std::make_shared<AscendResource>();
Status rc = device_resource_->InitResource(device_id);
if (!rc.IsOk()) {
device_resource_ = nullptr;
MS_LOG(ERROR) << "Initialize Ascend310 resource fail.";
}
}
#endif
(void)InitResource(device_type, device_id);
}
Execute::Execute(std::vector<std::shared_ptr<TensorOperation>> ops, MapTargetDevice device_type, uint32_t device_id)
: ops_(std::move(ops)), device_type_(device_type) {
info_ = std::make_shared<ExtraInfo>();
#ifdef ENABLE_ACL
if (device_type_ == MapTargetDevice::kAscend310) {
device_resource_ = std::make_shared<AscendResource>();
Status rc = device_resource_->InitResource(device_id);
if (!rc.IsOk()) {
device_resource_ = nullptr;
MS_LOG(ERROR) << "Initialize Ascend310 resource fail.";
}
}
#endif
(void)InitResource(device_type, device_id);
}
Execute::Execute(std::vector<std::shared_ptr<TensorTransform>> ops, MapTargetDevice device_type, uint32_t device_id) {
@ -149,16 +120,7 @@ Execute::Execute(std::vector<std::shared_ptr<TensorTransform>> ops, MapTargetDev
info_ = std::make_shared<ExtraInfo>();
device_type_ = device_type;
#ifdef ENABLE_ACL
if (device_type_ == MapTargetDevice::kAscend310) {
device_resource_ = std::make_shared<AscendResource>();
Status rc = device_resource_->InitResource(device_id);
if (!rc.IsOk()) {
device_resource_ = nullptr;
MS_LOG(ERROR) << "Initialize Ascend310 resource fail.";
}
}
#endif
(void)InitResource(device_type, device_id);
}
Execute::Execute(const std::vector<std::reference_wrapper<TensorTransform>> ops, MapTargetDevice device_type,
@ -177,16 +139,7 @@ Execute::Execute(const std::vector<std::reference_wrapper<TensorTransform>> ops,
info_ = std::make_shared<ExtraInfo>();
info_->init_with_shared_ptr_ = false;
device_type_ = device_type;
#ifdef ENABLE_ACL
if (device_type_ == MapTargetDevice::kAscend310) {
device_resource_ = std::make_shared<AscendResource>();
Status rc = device_resource_->InitResource(device_id);
if (!rc.IsOk()) {
device_resource_ = nullptr;
MS_LOG(ERROR) << "Initialize Ascend310 resource fail";
}
}
#endif
(void)InitResource(device_type, device_id);
}
// Execute function for the example vector case: auto decode(new vision::Decode());
@ -199,16 +152,7 @@ Execute::Execute(const std::vector<TensorTransform *> &ops, MapTargetDevice devi
info_ = std::make_shared<ExtraInfo>();
device_type_ = device_type;
#ifdef ENABLE_ACL
if (device_type_ == MapTargetDevice::kAscend310) {
device_resource_ = std::make_shared<AscendResource>();
Status rc = device_resource_->InitResource(device_id);
if (!rc.IsOk()) {
device_resource_ = nullptr;
MS_LOG(ERROR) << "Initialize Ascend310 resource fail";
}
}
#endif
(void)InitResource(device_type, device_id);
}
Execute::~Execute() {
@ -225,6 +169,7 @@ Execute::~Execute() {
Status Execute::operator()(const mindspore::MSTensor &input, mindspore::MSTensor *output) {
// Validate input tensor
RETURN_UNEXPECTED_IF_NULL(output);
CHECK_FAIL_RETURN_UNEXPECTED(input.DataSize() > 0, "Input Tensor has no data.");
CHECK_FAIL_RETURN_UNEXPECTED(ValidateDevice(), "Device Type should be 'Ascend310' or 'CPU'.");
@ -283,7 +228,8 @@ Status Execute::operator()(const mindspore::MSTensor &input, mindspore::MSTensor
RETURN_STATUS_UNEXPECTED(ss.str());
}
*output = mindspore::MSTensor(std::make_shared<DETensor>(de_tensor));
} else { // Ascend310 case, where we must set Ascend resource on each operators
} else if (device_type_ ==
MapTargetDevice::kAscend310) { // Ascend310 case, where we must set Ascend resource on each operators
#ifdef ENABLE_ACL
CHECK_FAIL_RETURN_UNEXPECTED(device_resource_, "Device resource is nullptr which is illegal under case Ascend310.");
// Sink data from host into device
@ -304,12 +250,17 @@ Status Execute::operator()(const mindspore::MSTensor &input, mindspore::MSTensor
*output = mindspore::MSTensor(std::make_shared<DETensor>(device_input, true));
#endif
} else {
std::string err_msg = "Your input device is not supported. (Option: CPU or Ascend310)";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_UNEXPECTED(err_msg);
}
return Status::OK();
}
Status Execute::operator()(const std::vector<MSTensor> &input_tensor_list, std::vector<MSTensor> *output_tensor_list) {
// Validate input tensor
RETURN_UNEXPECTED_IF_NULL(output_tensor_list);
CHECK_FAIL_RETURN_UNEXPECTED(!input_tensor_list.empty(), "Input Tensor is not valid.");
for (auto &tensor : input_tensor_list) {
CHECK_FAIL_RETURN_UNEXPECTED(tensor.DataSize() > 0, "Input Tensor has no data.");
@ -371,7 +322,8 @@ Status Execute::operator()(const std::vector<MSTensor> &input_tensor_list, std::
++idx;
}
CHECK_FAIL_RETURN_UNEXPECTED(!output_tensor_list->empty(), "Output Tensor is not valid.");
} else { // Case Ascend310
} else if (device_type_ ==
MapTargetDevice::kAscend310) { // Ascend310 case, where we must set Ascend resource on each operators
#ifdef ENABLE_ACL
CHECK_FAIL_RETURN_UNEXPECTED(device_resource_, "Device resource is nullptr which is illegal under case Ascend310.");
for (auto &input_tensor : input_tensor_list) {
@ -401,6 +353,10 @@ Status Execute::operator()(const std::vector<MSTensor> &input_tensor_list, std::
}
CHECK_FAIL_RETURN_UNEXPECTED(!output_tensor_list->empty(), "Output Tensor vector is empty.");
#endif
} else {
std::string err_msg = "Your input device is not supported. (Option: CPU or Ascend310)";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_UNEXPECTED(err_msg);
}
return Status::OK();
}
@ -517,9 +473,17 @@ Status AippInfoCollection(std::map<std::string, std::string> *aipp_options, cons
std::string Execute::AippCfgGenerator() {
std::string config_location = "./aipp.cfg";
if (info_ == nullptr) {
MS_LOG(ERROR) << "info_ is null";
return "";
}
#ifdef ENABLE_ACL
if (info_->init_with_shared_ptr_) {
ParseTransforms();
auto s = ParseTransforms();
if (s != Status::OK()) {
MS_LOG(ERROR) << "Error in ParseTransforms";
return "";
}
info_->init_with_shared_ptr_ = false;
}
std::vector<uint32_t> paras; // Record the parameters value of each Ascend operators
@ -548,8 +512,7 @@ std::string Execute::AippCfgGenerator() {
if (!outfile.is_open()) {
MS_LOG(ERROR) << "Fail to open Aipp config file, please verify your system config(including authority)."
<< "We will return empty string which represent the location of Aipp config file in this case.";
std::string except = "";
return except;
return "";
}
if (device_type_ == MapTargetDevice::kAscend310) {
@ -629,10 +592,14 @@ Status Execute::ParseTransforms() {
[](std::shared_ptr<TensorTransform> operation) -> std::shared_ptr<TensorOperation> {
return operation->Parse();
});
} else {
} else if (device_type_ == MapTargetDevice::kAscend310) {
for (auto &transform_ : transforms_) {
ops_.emplace_back(transform_->Parse(device_type_));
}
} else {
std::string err_msg = "Your input device is not supported. (Option: CPU or Ascend310)";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_UNEXPECTED(err_msg);
}
return Status::OK();

View File

@ -38,7 +38,7 @@ Status Iterator::GetNextRowCharIF(MSTensorMapChar *row) {
row->clear();
return rc;
}
for (auto de_tensor : md_map) {
for (auto &de_tensor : md_map) {
std::vector<char> col_name(de_tensor.first.begin(), de_tensor.first.end());
row->insert(std::make_pair(col_name, mindspore::MSTensor(std::make_shared<DETensor>(de_tensor.second))));
}
@ -48,8 +48,8 @@ Status Iterator::GetNextRowCharIF(MSTensorMapChar *row) {
// Get the next row from the data pipeline.
Status Iterator::GetNextRow(MSTensorVec *row) {
// Clean data row
RETURN_UNEXPECTED_IF_NULL(row);
// Clean data row
row->clear();
// create a dataset tensor row and fetch. Then we convert the output to MSTensor
std::vector<std::shared_ptr<dataset::Tensor>> md_row;
@ -76,6 +76,7 @@ void Iterator::Stop() {
// Function to build and launch the execution tree.
Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds, int32_t num_epochs) {
RETURN_UNEXPECTED_IF_NULL(ds);
runtime_context_ = std::make_unique<NativeRuntimeContext>();
CHECK_FAIL_RETURN_UNEXPECTED(runtime_context_ != nullptr, "Create runtime_context_ failed.");
RETURN_IF_NOT_OK(runtime_context_->Init());

View File

@ -102,7 +102,7 @@ Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) {
// clear the old tensor row
out_row->clear();
bool isProfilingEnable = root_->Tree()->GetProfilingManager()->IsProfilingEnable();
bool is_profiling_enable = root_->Tree()->GetProfilingManager()->IsProfilingEnable();
// Once eof is handled, always return empty row. Class must be destroyed and recreated if you
// want to iterate again.
@ -124,7 +124,7 @@ Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) {
// The next row in the pipeline might be an EOF or a TensorRow for next epoch
if (out_row->eoe()) {
MS_LOG(INFO) << "End of data iteration.";
if (isProfilingEnable) {
if (is_profiling_enable) {
root_->Tree()->SetEpochEnd();
}
return Status::OK();

View File

@ -250,9 +250,13 @@ Status BatchOp::WorkerEntry(int32_t workerId) {
Status BatchOp::MakeBatchedRow(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair, TensorRow *new_row) {
RETURN_UNEXPECTED_IF_NULL(table_pair.first);
#ifdef ENABLE_PYTHON
if (!in_col_names_.empty()) RETURN_IF_NOT_OK(MapColumns(&table_pair)); // pass it through pyfunc
if (!in_col_names_.empty()) {
RETURN_IF_NOT_OK(MapColumns(&table_pair));
} // pass it through pyfun
#endif
if (pad_) RETURN_IF_NOT_OK(PadColumns(&table_pair.first, pad_info_, column_name_id_map_)); // do padding if needed
if (pad_) {
RETURN_IF_NOT_OK(PadColumns(&table_pair.first, pad_info_, column_name_id_map_));
} // do padding if needed
RETURN_IF_NOT_OK(BatchRows(&table_pair.first, new_row, table_pair.first->size()));
return Status::OK();
}

View File

@ -242,7 +242,7 @@ class BatchOp : public ParallelOp {
// the number of thread pulling from the mOutConnector of the Op below
// @return int32_t, 1
int32_t num_consumers() const override { return 1; }
int32_t NumConsumers() const override { return 1; }
// get the batch size for next batch
// @return Status The status code returned

View File

@ -71,11 +71,11 @@ class BuildSentencePieceVocabOp : public PipelineOp {
// Getter
// @return the number of workers
int32_t num_producers() const override { return 1; }
int32_t NumProducers() const override { return 1; }
// Getter
// @return the number of threads consuming from the previous Connector
int32_t num_consumers() const override { return 1; }
int32_t NumConsumers() const override { return 1; }
Status Reset() override { RETURN_STATUS_UNEXPECTED("Reset shouldn't be called in BuildSentencePieceVocabOp"); }

View File

@ -68,11 +68,11 @@ class BuildVocabOp : public ParallelOp {
/// Getter
/// @return the number of workers
int32_t num_producers() const override { return 1; }
int32_t NumProducers() const override { return 1; }
/// Getter
/// @return the number of threads consuming from the previous Connector
int32_t num_consumers() const override { return 1; }
int32_t NumConsumers() const override { return 1; }
Status Reset() override { RETURN_STATUS_UNEXPECTED("Reset shouldn't be called in BuildVocabOp"); }

View File

@ -69,6 +69,7 @@ Status CacheBase::FetchSamplesToWorkers() {
int64_t wait_cnt = 0;
int64_t prefetch_cnt = 0;
// Kick off several threads which will prefetch prefetch_size_ rows in advance.
RETURN_UNEXPECTED_IF_NULL(tree_);
RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_prefetchers_, std::bind(&CacheBase::Prefetcher, this, std::placeholders::_1), Name()));
auto send_to_que = [](QueueList<std::unique_ptr<IOBlock>> &qList, int32_t worker_id,
@ -80,7 +81,7 @@ Status CacheBase::FetchSamplesToWorkers() {
// Instead of sending sampler id to WorkerEntry, we send them to the Prefetcher which will redirect them
// to the WorkerEntry.
do {
if (AllowCacheMiss() && wait_cnt > 0 && wait_cnt % op_num_repeats_per_epoch() == 0) {
if (AllowCacheMiss() && wait_cnt > 0 && wait_cnt % GetOpNumRepeatsPerEpoch() == 0) {
MS_LOG(INFO) << "Epoch: " << op_current_epochs_ << " Cache Miss : " << num_cache_miss_
<< " Total number of rows : " << row_cnt_;
}
@ -155,7 +156,7 @@ Status CacheBase::FetchSamplesToWorkers() {
}
// Dump the last epoch result (approximately) without waiting for the worker threads to come back.
if (AllowCacheMiss()) {
MS_LOG(INFO) << "Epoch: " << wait_cnt / op_num_repeats_per_epoch() << " Cache Miss : " << num_cache_miss_
MS_LOG(INFO) << "Epoch: " << wait_cnt / GetOpNumRepeatsPerEpoch() << " Cache Miss : " << num_cache_miss_
<< " Total number of rows : " << row_cnt_;
}
return Status::OK();
@ -202,6 +203,7 @@ Status CacheBase::FetchFromCache(int32_t worker_id) {
}
Status CacheBase::RegisterResources() {
RETURN_UNEXPECTED_IF_NULL(tree_);
RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(prefetch_queues_.Register(tree_->AllTasks()));

View File

@ -75,7 +75,7 @@ class CacheBase : public ParallelOp {
/// \brief Getter for the cache client
/// \return shared ptr to the cache client
std::shared_ptr<CacheClient> cache_client() { return cache_client_; }
std::shared_ptr<CacheClient> GetCacheClient() { return cache_client_; }
/// \brief Setter for the cache client
void SetCacheClient(std::shared_ptr<CacheClient> cache_client) { cache_client_ = std::move(cache_client); }
/// \brief Derived class must implement this method if a cache miss is treated as error

View File

@ -73,6 +73,7 @@ Status CacheOp::InitCache() { return Status::OK(); }
// This class functor will provide the master loop that drives the logic for performing the work
Status CacheOp::operator()() {
RETURN_UNEXPECTED_IF_NULL(tree_);
if (!sampler_) {
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Invalid parameter, CacheOp requires a sampler before it can be executed, but got nullptr.");
@ -199,6 +200,7 @@ Status CacheOp::WorkerEntry(int32_t worker_id) {
return Status::OK();
}
Status CacheOp::RegisterResources() {
RETURN_UNEXPECTED_IF_NULL(tree_);
RETURN_IF_NOT_OK(CacheBase::RegisterResources());
RETURN_IF_NOT_OK(rows_cache_done_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(keys_miss_->Register(tree_->AllTasks()));

View File

@ -194,7 +194,7 @@ Status ConcatOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe
return Status::OK();
}
int32_t ConcatOp::num_consumers() const {
int32_t ConcatOp::NumConsumers() const {
if (parent_.empty()) {
MS_LOG(DEBUG) << "Return operator, no parent node, assuming it's the root and returning 1.";
return 1;
@ -202,16 +202,16 @@ int32_t ConcatOp::num_consumers() const {
MS_LOG(DEBUG) << "Return operator, pointer to the first parent is null. Returning 0.";
return 0;
} else {
return parent_[0]->num_consumers();
return parent_[0]->NumConsumers();
}
}
int32_t ConcatOp::num_producers() const {
int32_t ConcatOp::NumProducers() const {
if (child_.empty() || child_[0] == nullptr) {
MS_LOG(DEBUG) << "Return operator, pointer to child node is null. Returning 0.";
return 0;
} else {
return child_[0]->num_producers();
return child_[0]->NumProducers();
}
}
} // namespace dataset

View File

@ -73,8 +73,8 @@ class ConcatOp : public PipelineOp {
Status GetNumClasses(int64_t *num_classes) override;
Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
int32_t num_consumers() const override;
int32_t num_producers() const override;
int32_t NumConsumers() const override;
int32_t NumProducers() const override;
/// Check if the current sample will be taken or dropped
/// \return bool

View File

@ -312,9 +312,9 @@ Status DatasetOp::PrepareOperator() {
// The consumer of the root node is assumed to be one thread.
// If multiple threads are consuming from the root node, they will get the ordered data in round robin fashion.
if (parent_.empty()) {
this->CreateConnector(num_producers(), 1);
this->CreateConnector(NumProducers(), 1);
} else {
this->CreateConnector(num_producers(), parent_[0]->num_consumers());
this->CreateConnector(NumProducers(), parent_[0]->NumConsumers());
}
if (out_connector_) {
RETURN_IF_NOT_OK(out_connector_->Register(tree_->AllTasks()));

View File

@ -209,33 +209,33 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
// \brief Getter function
// \return The number of workers in this op
virtual int32_t num_workers() const = 0;
virtual int32_t NumWorkers() const = 0;
// \brief Getter function
// \return The number of threads consuming from previous op.
virtual int32_t num_consumers() const = 0;
virtual int32_t NumConsumers() const = 0;
// \brief Getter function
// \return The number of threads producing to the output connector.
virtual int32_t num_producers() const = 0;
virtual int32_t NumProducers() const = 0;
// \brief Getter function
// \return T/F if this is an inlined operator
bool inlined() const { return (oc_queue_size_ == 0); }
// \brief Setter function, set the number of total repeats for the operator
void set_total_repeats(int32_t total_repeats) { op_total_repeats_ = total_repeats; }
void SetTotalRepeats(int32_t total_repeats) { op_total_repeats_ = total_repeats; }
// \brief Setter function, set the number of repeats per epoch for the operator
void set_num_repeats_per_epoch(int32_t num_repeats_per_epoch) { op_num_repeats_per_epoch_ = num_repeats_per_epoch; }
void SetNumRepeatsPerEpoch(int32_t num_repeats_per_epoch) { op_num_repeats_per_epoch_ = num_repeats_per_epoch; }
// \brief Getter function
// \return The number of required repeats for the operator
int32_t op_total_repeats() { return op_total_repeats_; }
int32_t GetOpTotalRepeats() { return op_total_repeats_; }
// \brief Getter function
// \return The number of repeats per epoch for the operator
int32_t op_num_repeats_per_epoch() const { return op_num_repeats_per_epoch_; }
int32_t GetOpNumRepeatsPerEpoch() const { return op_num_repeats_per_epoch_; }
// \brief Register the internal worker connectors. No op unless it is a parallel op
// \return Status
@ -393,7 +393,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
// \notes No public interface. Only the class itself, or it's friend the execution tree can set
// this
// \param op_id - the Id value to set into the operator
void set_id(int32_t op_id) { operator_id_ = op_id; }
void SetId(int32_t op_id) { operator_id_ = op_id; }
// Sets the tree into the op so that the operator has a back pointer to the tree.
// \param tree - the tree to assign to the op.

View File

@ -173,8 +173,8 @@ Status DeviceQueueOp::SendDataToAscend() {
int64_t sending_num = cfg->sending_batches(); // Get the current sending_num
std::shared_ptr<DeviceQueueTracing> profiling_node;
bool isProfilingEnable = tree_->GetProfilingManager()->IsProfilingEnable();
if (isProfilingEnable) {
bool is_profiling_enable = tree_->GetProfilingManager()->IsProfilingEnable();
if (is_profiling_enable) {
std::shared_ptr<Tracing> node;
RETURN_IF_NOT_OK(tree_->GetProfilingManager()->GetTracingNode(kDeviceQueueTracingName, &node));
profiling_node = std::dynamic_pointer_cast<DeviceQueueTracing>(node);
@ -197,12 +197,12 @@ Status DeviceQueueOp::SendDataToAscend() {
md_channel_info_->RecordPreprocessBatch(send_batch);
md_channel_info_->RecordPushStartTime();
#endif
RETURN_IF_NOT_OK(SendRowToTdt(curr_row, isProfilingEnable, &tdt_cost));
RETURN_IF_NOT_OK(SendRowToTdt(curr_row, is_profiling_enable, &tdt_cost));
if (first_push_flag_ != true) {
MS_LOG(INFO) << "Loading dataset and push first batch into device successful";
first_push_flag_ = true;
}
ProfilingRecorder(isProfilingEnable, profiling_node, send_batch, tdt_cost, &batch_start_time, &end_time,
ProfilingRecorder(is_profiling_enable, profiling_node, send_batch, tdt_cost, &batch_start_time, &end_time,
connector_capacity, connector_size);
send_batch++;
#ifdef ENABLE_DUMP_IR
@ -219,15 +219,15 @@ Status DeviceQueueOp::SendDataToAscend() {
// wait when sending num is not 0, and sending num no larger than already sending batch
LimitSendingBatches(send_batch, &sending_num, cfg);
if (isProfilingEnable) {
if (is_profiling_enable) {
connector_size = ChildOpConnectorSize();
connector_capacity = ChildOpConnectorCapacity();
}
RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&curr_row));
}
if (curr_row.eoe() && send_epoch_end_) {
TensorRow currRow;
auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost,
TensorRow dummy_row;
auto status = tdtInstancePtr->hostPush(dummy_row, true, channel_name_, is_profiling_enable, tdt_cost,
ACL_TENSOR_DATA_END_OF_SEQUENCE);
if (status != Status::OK()) {
if (stop_send_) {
@ -246,7 +246,7 @@ Status DeviceQueueOp::SendDataToAscend() {
MS_LOG(INFO) << "an epoch has already sent, now stop send data.";
stop_send_ = true;
}
if (isProfilingEnable) {
if (is_profiling_enable) {
connector_size = ChildOpConnectorSize();
connector_capacity = ChildOpConnectorCapacity();
tree_->SetEpochEnd();
@ -283,8 +283,8 @@ void DeviceQueueOp::LimitSendingBatches(int64_t send_batch, int64_t *sending_num
}
}
Status DeviceQueueOp::SendRowToTdt(TensorRow currRow, bool isProfilingEnable, int32_t *tdt_cost) {
auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, *tdt_cost);
Status DeviceQueueOp::SendRowToTdt(TensorRow curr_row, bool is_profiling_enable, int32_t *tdt_cost) {
auto status = tdtInstancePtr->hostPush(curr_row, true, channel_name_, is_profiling_enable, *tdt_cost);
if (status != Status::OK()) {
if (stop_send_) {
MS_LOG(INFO) << "stop_send received";
@ -300,7 +300,7 @@ Status DeviceQueueOp::SendRowToTdt(TensorRow currRow, bool isProfilingEnable, in
}
if (create_data_info_queue_) {
DATA_INFO data_info;
(void)std::transform(currRow.begin(), currRow.end(), std::back_inserter(data_info),
(void)std::transform(curr_row.begin(), curr_row.end(), std::back_inserter(data_info),
[](const std::shared_ptr<Tensor> &ts) { return std::make_pair(ts->type(), ts->shape()); });
RETURN_IF_NOT_OK(data_info_queue_ptr_->Add(data_info));
}
@ -348,6 +348,7 @@ Status DeviceQueueOp::SetThreadDevice() {
}
Status DeviceQueueOp::LaunchParallelCopyThread() {
RETURN_UNEXPECTED_IF_NULL(tree_);
// Every thread use cuda api should SetThreadDevice
RETURN_IF_NOT_OK(SetThreadDevice());
// CircularPool may not safe under multi-threads scenario, so one worker with one pool
@ -368,6 +369,7 @@ Status DeviceQueueOp::LaunchParallelCopyThread() {
}
Status DeviceQueueOp::PushDataToGPU() {
RETURN_UNEXPECTED_IF_NULL(tree_);
// Every thread use cuda api should SetThreadDevice
RETURN_IF_NOT_OK(SetThreadDevice());
TaskManager::FindMe()->Post();
@ -376,8 +378,8 @@ Status DeviceQueueOp::PushDataToGPU() {
int32_t connector_size = 0;
int32_t connector_capacity = 0;
std::shared_ptr<DeviceQueueTracing> profiling_node;
bool isProfilingEnable = tree_->GetProfilingManager()->IsProfilingEnable();
if (isProfilingEnable) {
bool is_profiling_enable = tree_->GetProfilingManager()->IsProfilingEnable();
if (is_profiling_enable) {
std::shared_ptr<Tracing> node;
RETURN_IF_NOT_OK(tree_->GetProfilingManager()->GetTracingNode(kDeviceQueueTracingName, &node));
profiling_node = std::dynamic_pointer_cast<DeviceQueueTracing>(node);
@ -421,7 +423,7 @@ Status DeviceQueueOp::PushDataToGPU() {
}
RETURN_IF_NOT_OK(RetryPushData(handle, items));
send_batch++;
if (isProfilingEnable) {
if (is_profiling_enable) {
uint64_t end_time = ProfilingTime::GetCurMilliSecond();
// record push data time
profiling_node->Record(TIME, TDT_PUSH_TIME, send_batch, push_cost, end_time);
@ -470,7 +472,7 @@ Status DeviceQueueOp::PushDataToGPU() {
}
Status DeviceQueueOp::RetryPushData(unsigned int handle, const std::vector<DataItemGpu> &items) {
bool flagLog = false;
bool flag_log = false;
while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) {
BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME);
if (ret) {
@ -479,9 +481,9 @@ Status DeviceQueueOp::RetryPushData(unsigned int handle, const std::vector<DataI
"Invalid data, check the output of dataset with creating iterator and print data item.");
} else {
if (!stop_send_) {
if (!flagLog) {
if (!flag_log) {
MS_LOG(DEBUG) << "Retry pushing data...";
flagLog = true;
flag_log = true;
}
continue;
}
@ -625,11 +627,11 @@ void DeviceQueueOp::Print(std::ostream &out, bool show_all) const {
}
}
void DeviceQueueOp::ProfilingRecorder(bool isProfilingEnable, std::shared_ptr<DeviceQueueTracing> profiling_node,
void DeviceQueueOp::ProfilingRecorder(bool is_profiling_enable, std::shared_ptr<DeviceQueueTracing> profiling_node,
int64_t send_batch, int32_t tdt_cost, uint64_t *batch_start_time,
uint64_t *end_time, int32_t connector_capacity, int32_t connector_size) {
// Record the pipeline profiling info
if (isProfilingEnable) {
if (is_profiling_enable) {
*end_time = ProfilingTime::GetCurMilliSecond();
// record push tdt time
profiling_node->Record(TIME, TDT_PUSH_TIME, send_batch + 1, tdt_cost, *end_time);

View File

@ -74,7 +74,7 @@ class DeviceQueueOp : public PipelineOp {
Status EoeReceived(int32_t worker_id) override;
const int32_t get_prefetch_size() { return prefetch_size_; }
const int32_t GetPrefetchSize() { return prefetch_size_; }
void StopSend() { stop_send_ = true; }
@ -103,9 +103,9 @@ class DeviceQueueOp : public PipelineOp {
Status operator()() override;
// Record the pipeline profiling info
void ProfilingRecorder(bool isProfilingEnable, std::shared_ptr<DeviceQueueTracing> profiling_node, int64_t send_batch,
int32_t tdt_cost, uint64_t *batch_start_time, uint64_t *end_time, int32_t connector_capacity,
int32_t connector_size);
void ProfilingRecorder(bool is_profiling_enable, std::shared_ptr<DeviceQueueTracing> profiling_node,
int64_t send_batch, int32_t tdt_cost, uint64_t *batch_start_time, uint64_t *end_time,
int32_t connector_capacity, int32_t connector_size);
// Op name getter
// @return Name of the current Op
@ -125,7 +125,7 @@ class DeviceQueueOp : public PipelineOp {
void WaitContinueSignal() const;
Status SendDataToAscend();
void LimitSendingBatches(int64_t send_batch, int64_t *sending_num, std::shared_ptr<ConfigManager> cfg);
Status SendRowToTdt(TensorRow currRow, bool isProfilingEnable, int32_t *tdt_cost);
Status SendRowToTdt(TensorRow curr_row, bool is_profiling_enable, int32_t *tdt_cost);
bool ascend_keep_waiting_;
#endif

View File

@ -218,7 +218,7 @@ Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate
return Status(StatusCode::kSuccess, "FilterOp predicate func call succeed");
}
int32_t FilterOp::num_consumers() const { return 1; }
int32_t FilterOp::NumConsumers() const { return 1; }
} // namespace dataset
} // namespace mindspore

View File

@ -69,7 +69,7 @@ class FilterOp : public ParallelOp {
// @return Name of the current Op
std::string Name() const override { return kFilterOp; }
int32_t num_consumers() const override;
int32_t NumConsumers() const override;
private:
// predicate_func python callable which returns a boolean value.

View File

@ -46,7 +46,7 @@ MapOp::MapOp(const std::vector<std::string> &in_col_names, const std::vector<std
}
// The number of threads consuming data from previous op's output Connector.
int32_t MapOp::num_consumers() const {
int32_t MapOp::NumConsumers() const {
// When Performance Mode is on, there is only one thread consuming from the previous Connector.
return 1;
}
@ -144,7 +144,7 @@ Status MapOp::operator()() {
RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
while (!new_row.eof()) {
if (op_current_repeats_ % op_num_repeats_per_epoch() == 0) {
if (op_current_repeats_ % GetOpNumRepeatsPerEpoch() == 0) {
RETURN_IF_NOT_OK(callback_manager_.EpochBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
}
while (!new_row.eoe()) {
@ -168,7 +168,7 @@ Status MapOp::operator()() {
}
// check whether this is the end of a real epoch (not all eoe signals end of epoch)
if ((op_current_repeats_ + 1) % op_num_repeats_per_epoch() == 0) {
if ((op_current_repeats_ + 1) % GetOpNumRepeatsPerEpoch() == 0) {
RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
ep_step = 0;

View File

@ -101,7 +101,7 @@ class MapOp : public ParallelOp {
// Getter
// @return the number of threads consuming data from previous op's output Connector.
int32_t num_consumers() const override;
int32_t NumConsumers() const override;
// Op name getter
// @return Name of the current Op

View File

@ -72,11 +72,11 @@ class ParallelOp : public DatasetOp {
// Getter
// @return the number of workers
int32_t num_workers() const override { return num_workers_; }
int32_t NumWorkers() const override { return num_workers_; }
// Getter
// @return the number of threads consuming from the previous Connector
int32_t num_consumers() const override { return num_workers_; }
int32_t NumConsumers() const override { return num_workers_; }
// Getter
// @return the number of producers pushing to the output Connector
@ -84,7 +84,7 @@ class ParallelOp : public DatasetOp {
// when a worker connector is set up. In that case, there are n workers, and a single master
// such that only 1 thread is a producer rather than the n workers.
// @return the number of producers
int32_t num_producers() const override { return num_producers_; }
int32_t NumProducers() const override { return num_producers_; }
// Register the internal worker connectors.
// @return Status

View File

@ -55,15 +55,15 @@ class PipelineOp : public DatasetOp {
// Getter
// @return The number of workers inside this op. Pipeline ops only have a single worker.
int32_t num_workers() const override { return 1; }
int32_t NumWorkers() const override { return 1; }
// Getter
// @return the number of threads consuming from the previous Connector
int32_t num_consumers() const override { return 1; }
int32_t NumConsumers() const override { return 1; }
// Getter
// @return The number of threads that push data to the output connector
int32_t num_producers() const override { return 1; }
int32_t NumProducers() const override { return 1; }
protected:
// *******************************************************************************

View File

@ -76,7 +76,7 @@ TensorRow ProjectOp::Project(const TensorRow &row) {
// ensure that it is not called by mistake (it will generate an error).
Status ProjectOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. ProjectOp is an inlined operator."); }
int32_t ProjectOp::num_consumers() const {
int32_t ProjectOp::NumConsumers() const {
if (parent_.empty()) {
MS_LOG(DEBUG) << "Project operator, no parent node, assuming it's the root and returning 1.";
return 1;
@ -84,16 +84,16 @@ int32_t ProjectOp::num_consumers() const {
MS_LOG(DEBUG) << "Project operator, pointer to the first parent is null. Returning 0.";
return 0;
} else {
return parent_[0]->num_consumers();
return parent_[0]->NumConsumers();
}
}
int32_t ProjectOp::num_producers() const {
int32_t ProjectOp::NumProducers() const {
if (child_.empty() || child_[0] == nullptr) {
MS_LOG(DEBUG) << "Project operator, pointer to child node is null. Returning 0.";
return 0;
} else {
return child_[0]->num_producers();
return child_[0]->NumProducers();
}
}

View File

@ -63,11 +63,11 @@ class ProjectOp : public PipelineOp {
// Base-class override. Return the number of workers in the first parent.
// @param workerId - The worker id
int32_t num_consumers() const override;
int32_t NumConsumers() const override;
// Base-class override. Return the number of producers in the first child.
// @param workerId - The worker id
int32_t num_producers() const override;
int32_t NumProducers() const override;
// Base-class override for special eoe handler.
// Inline operators must override this because there is no connector to push eoe onto.

View File

@ -113,7 +113,7 @@ void RenameOp::Print(std::ostream &out, // In: The output stream to print t
}
}
int32_t RenameOp::num_consumers() const {
int32_t RenameOp::NumConsumers() const {
if (parent_.empty()) {
MS_LOG(DEBUG) << "Rename operator, no parent node, assuming it's the root and returning 1.";
return 1;
@ -121,16 +121,16 @@ int32_t RenameOp::num_consumers() const {
MS_LOG(DEBUG) << "Rename operator, pointer to the first parent is null. Returning 0.";
return 0;
} else {
return parent_[0]->num_consumers();
return parent_[0]->NumConsumers();
}
}
int32_t RenameOp::num_producers() const {
int32_t RenameOp::NumProducers() const {
if (child_.empty() || child_[0] == nullptr) {
MS_LOG(DEBUG) << "Rename operator, pointer to child node is null. Returning 0.";
return 0;
} else {
return child_[0]->num_producers();
return child_[0]->NumProducers();
}
}
} // namespace dataset

View File

@ -63,8 +63,8 @@ class RenameOp : public PipelineOp {
// @param row - output pointer to the projected row.
// @param worker_id - The worker id
Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
int32_t num_consumers() const override;
int32_t num_producers() const override;
int32_t NumConsumers() const override;
int32_t NumProducers() const override;
protected:
// Rename core functionality

View File

@ -115,7 +115,7 @@ Status RepeatOp::EofReceived(int32_t worker_id) {
return Status::OK();
}
int32_t RepeatOp::num_consumers() const {
int32_t RepeatOp::NumConsumers() const {
if (parent_.empty()) {
MS_LOG(DEBUG) << "Repeat operator, no parent node, assuming it's root and returning 1.";
return 1;
@ -123,7 +123,7 @@ int32_t RepeatOp::num_consumers() const {
MS_LOG(DEBUG) << "Repeat operator, pointer to the first parent is null. Returning 0.";
return 0;
} else {
return parent_[0]->num_consumers();
return parent_[0]->NumConsumers();
}
}
@ -140,12 +140,12 @@ Status RepeatOp::Reset() {
return Status::OK();
}
int32_t RepeatOp::num_producers() const {
int32_t RepeatOp::NumProducers() const {
if (child_.empty() || child_[0] == nullptr) {
MS_LOG(DEBUG) << "Repeat operator, pointer to child node is null. Returning 0.";
return 0;
} else {
return child_[0]->num_producers();
return child_[0]->NumProducers();
}
}

View File

@ -79,11 +79,11 @@ class RepeatOp : public PipelineOp {
// Base-class override. Return the number of workers in the first parent.
// @param workerId - The worker id
int32_t num_consumers() const override;
int32_t NumConsumers() const override;
// Base-class override. Return the number of producers in the first child.
// @param workerId - The worker id
int32_t num_producers() const override;
int32_t NumProducers() const override;
// Op name getter
// @return Name of the current Op

View File

@ -66,7 +66,7 @@ Status SkipOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe)
return Status::OK();
}
int32_t SkipOp::num_consumers() const {
int32_t SkipOp::NumConsumers() const {
if (parent_.empty()) {
MS_LOG(DEBUG) << "Return operator, no parent node, assuming it's the root and returning 1.";
return 1;
@ -74,16 +74,16 @@ int32_t SkipOp::num_consumers() const {
MS_LOG(DEBUG) << "Return operator, pointer to the first parent is null. Returning 0.";
return 0;
} else {
return parent_[0]->num_consumers();
return parent_[0]->NumConsumers();
}
}
int32_t SkipOp::num_producers() const {
int32_t SkipOp::NumProducers() const {
if (child_.empty() || child_[0] == nullptr) {
MS_LOG(DEBUG) << "Return operator, pointer to child node is null. Returning 0.";
return 0;
} else {
return child_[0]->num_producers();
return child_[0]->NumProducers();
}
}

View File

@ -49,8 +49,8 @@ class SkipOp : public PipelineOp {
// @return Name of the current Op
std::string Name() const override { return kSkipOp; }
Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
int32_t num_consumers() const override;
int32_t num_producers() const override;
int32_t NumConsumers() const override;
int32_t NumProducers() const override;
private:
int32_t max_skips_; // The number of skips that the user requested

View File

@ -175,9 +175,9 @@ bool CelebAOp::CheckDatasetTypeValid() {
Status CelebAOp::ParseImageAttrInfo() {
std::vector<std::string> image_infos;
bool needMoreData = true;
bool need_more_data = true;
RETURN_IF_NOT_OK(attr_info_queue_->PopFront(&image_infos));
while (!image_infos.empty() && needMoreData) {
while (!image_infos.empty() && need_more_data) {
for (uint32_t index = 0; index < image_infos.size(); index++) {
std::string image_info = image_infos[index];
std::vector<std::string> split = Split(image_info);

View File

@ -201,13 +201,13 @@ Status CifarOp::ReadCifar100BlockData() {
}
Status CifarOp::GetCifarFiles() {
const std::string kExtension = ".bin";
const std::string extension = ".bin";
Path dir_path(folder_path_);
auto dirIt = Path::DirIterator::OpenDirectory(&dir_path);
if (dirIt) {
while (dirIt->HasNext()) {
Path file = dirIt->Next();
if (file.Extension() == kExtension) {
if (file.Extension() == extension) {
cifar_files_.push_back(file.ToString());
}
}

View File

@ -101,7 +101,7 @@ class CifarOp : public MappableLeafOp {
Status ParseCifarData();
/// Method derived from RandomAccess Op, enable Sampler to get all ids for each class
/// @param (std::map<uint64_t, std::vector<uint64_t >> * map - key label, val all ids for this class
/// @param (std::map<uint32_t, std::vector<uint32_t >> *cls_ids - val all ids for this class
/// @return Status The status code returned
Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;

View File

@ -119,20 +119,20 @@ Status ClueOp::LoadFile(const std::string &file, int64_t start_offset, int64_t e
RETURN_STATUS_UNEXPECTED("Invalid file, failed to parse JSON file: " + file);
}
int cols_count = cols_to_keyword_.size();
TensorRow tRow(cols_count, nullptr);
TensorRow t_row(cols_count, nullptr);
// Add file path info
std::vector<std::string> file_path(cols_count, file);
tRow.setPath(file_path);
t_row.setPath(file_path);
int cout = 0;
for (auto &p : cols_to_keyword_) {
std::shared_ptr<Tensor> tensor;
RETURN_IF_NOT_OK(GetValue(js, p.second, &tensor));
tRow[cout] = std::move(tensor);
t_row[cout] = std::move(tensor);
cout++;
}
rows_total++;
RETURN_IF_NOT_OK(jagged_rows_connector_->Add(worker_id, std::move(tRow)));
RETURN_IF_NOT_OK(jagged_rows_connector_->Add(worker_id, std::move(t_row)));
}
return Status::OK();

View File

@ -215,7 +215,7 @@ Status GeneratorOp::operator()() {
// Waiting for repeatOp to start new epoch
// If Reset() is called first by repeat op, this wait() will return right away.
// If Reset() is not called yet, this wait() will block until reset.
if (this->op_total_repeats() < 0) {
if (this->GetOpTotalRepeats() < 0) {
RETURN_IF_NOT_OK(wp_.Wait());
// Clear the status of the wait post
wp_.Clear();
@ -235,7 +235,7 @@ Status GeneratorOp::Reset() {
MS_LOG(DEBUG) << Name() << " performing a self-reset.";
// Create new generator object
RETURN_IF_NOT_OK(CreateGeneratorObject());
if (this->op_total_repeats() < 0) {
if (this->GetOpTotalRepeats() < 0) {
// Wake up master thread
wp_.Set();
}

View File

@ -84,20 +84,20 @@ Status ImageFolderOp::PrescanMasterEntry(const std::string &filedir) {
// Load 1 TensorRow (image,label) using 1 ImageLabelPair. 1 function call produces 1 TensorTow
Status ImageFolderOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
ImageLabelPair pairPtr = image_label_pairs_[row_id];
ImageLabelPair pair_ptr = image_label_pairs_[row_id];
std::shared_ptr<Tensor> image, label;
RETURN_IF_NOT_OK(Tensor::CreateScalar(pairPtr->second, &label));
RETURN_IF_NOT_OK(Tensor::CreateFromFile(folder_path_ + (pairPtr->first), &image));
RETURN_IF_NOT_OK(Tensor::CreateScalar(pair_ptr->second, &label));
RETURN_IF_NOT_OK(Tensor::CreateFromFile(folder_path_ + (pair_ptr->first), &image));
if (decode_ == true) {
Status rc = Decode(image, &image);
if (rc.IsError()) {
std::string err = "Invalid data, failed to decode image: " + folder_path_ + (pairPtr->first);
std::string err = "Invalid data, failed to decode image: " + folder_path_ + (pair_ptr->first);
RETURN_STATUS_UNEXPECTED(err);
}
}
(*trow) = TensorRow(row_id, {std::move(image), std::move(label)});
trow->setPath({folder_path_ + (pairPtr->first), std::string("")});
trow->setPath({folder_path_ + (pair_ptr->first), std::string("")});
return Status::OK();
}

View File

@ -57,11 +57,14 @@ class ImageFolderOp : public MappableLeafOp {
// @param int32_t num_wkrs - Num of workers reading images in parallel
// @param std::string - dir directory of ImageNetFolder
// @param int32_t queue_size - connector queue size
// @param std::set<std::string> exts - set of file extensions to read, if empty, read everything under the dir
// @param td::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
// @param bool recursive - read recursively
// @param bool do_decode - decode the images after reading
// @param std::set<std::string> &exts - set of file extensions to read, if empty, read everything under the dir
// @param std::map<std::string, int32_t> &map- map of folder name and class id
// @param std::unique_ptr<dataschema> data_schema - schema of data
ImageFolderOp(int32_t num_wkrs, std::string file_dir, int32_t queue_size, bool recursive, bool do_decode,
const std::set<std::string> &exts, const std::map<std::string, int32_t> &map,
std::unique_ptr<DataSchema>, std::shared_ptr<SamplerRT> sampler);
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);
/// Destructor.
~ImageFolderOp() = default;

View File

@ -260,7 +260,9 @@ Status MnistOp::WalkAllFiles() {
const std::string train_prefix = "train";
const std::string test_prefix = "t10k";
Path dir(folder_path_);
std::string real_path{""};
RETURN_IF_NOT_OK(Path::RealPath(folder_path_, real_path));
Path dir(real_path);
auto dir_it = Path::DirIterator::OpenDirectory(&dir);
std::string prefix; // empty string, used to match usage = "" (default) or usage == "all"
if (usage_ == "train" || usage_ == "test") prefix = (usage_ == "test" ? test_prefix : train_prefix);

View File

@ -68,15 +68,15 @@ void RandomDataOp::Print(std::ostream &out, bool show_all) const {
// Helper function to produce a default/random schema if one didn't exist
void RandomDataOp::GenerateSchema() {
const int32_t kTypeOffset = 2;
const int32_t type_offset = 2;
// To randomly create a schema, we need to choose:
// a) how many columns
// b) the type of each column
// c) the shape of each column (number of dimensions i.e. rank)
// d) the shape of each column (dimension values)
data_schema_ = std::make_unique<DataSchema>();
std::unique_ptr<TensorShape> newShape;
std::unique_ptr<ColDescriptor> newCol;
std::unique_ptr<TensorShape> new_shape;
std::unique_ptr<ColDescriptor> new_col;
// Loop over the number of chosen columns
int32_t numColumns = GenRandomInt(1, kMaxNumColumns);
@ -84,23 +84,26 @@ void RandomDataOp::GenerateSchema() {
// For each column:
// - choose a datatype
// - generate a shape that randomly chooses the number of dimensions and the dimension values.
DataType::Type newType = static_cast<DataType::Type>(GenRandomInt(1, DataType::NUM_OF_TYPES - kTypeOffset));
DataType::Type newType = static_cast<DataType::Type>(GenRandomInt(1, DataType::NUM_OF_TYPES - type_offset));
int32_t rank = GenRandomInt(1, kMaxRank);
std::vector<dsize_t> dims;
for (int32_t d = 0; d < rank; d++) {
// 0 is not a valid dimension value. however, we can support "*" or unknown, so map the random
// 0 value to the unknown attribute if 0 is chosen
dsize_t dim_value = static_cast<dsize_t>(GenRandomInt(0, kMaxDimValue));
if (dim_value == 0) dim_value = TensorShape::kDimUnknown;
if (dim_value == 0) {
dim_value = TensorShape::kDimUnknown;
}
dims.push_back(dim_value);
}
newShape = std::make_unique<TensorShape>(dims);
new_shape = std::make_unique<TensorShape>(dims);
// Create the column descriptor
std::string colName = "c" + std::to_string(i);
newCol = std::make_unique<ColDescriptor>(colName, DataType(newType), TensorImpl::kFlexible, rank, newShape.get());
std::string col_name = "c" + std::to_string(i);
new_col =
std::make_unique<ColDescriptor>(col_name, DataType(newType), TensorImpl::kFlexible, rank, new_shape.get());
Status rc = data_schema_->AddColumn(*newCol);
Status rc = data_schema_->AddColumn(*new_col);
if (rc.IsError()) MS_LOG(ERROR) << "Failed to generate a schema. Message:" << rc;
}
}
@ -125,6 +128,9 @@ Status RandomDataOp::operator()() {
DatasetOp::CreateConnector(num_producers_, num_workers_);
}
if (num_workers_ == 0) {
RETURN_STATUS_UNEXPECTED("Invalid data, num_workers_ is zero.");
}
// Assign the number of rows to each worker in a round robin fashion.
worker_max_rows_.reserve(num_workers_);
worker_rows_packed_.reserve(num_workers_);

View File

@ -681,7 +681,7 @@ Status TFReaderOp::CountTotalRows(int64_t *out_total_rows, const std::vector<std
int64_t TFReaderOp::CountTotalRowsSectioned(const std::vector<std::string> &filenames, int64_t begin, int64_t end) {
int64_t rows_read = 0;
for (int i = begin; i < end; i++) {
for (int64_t i = begin; i < end; ++i) {
auto realpath = Common::GetRealPath(filenames[i]);
if (!realpath.has_value()) {
MS_LOG(ERROR) << "Get real path failed, path=" << filenames[i];

View File

@ -196,10 +196,10 @@ Status VOCOp::ParseImageIds() {
Status VOCOp::ParseAnnotationIds() {
std::vector<std::string> new_image_ids;
for (auto id : image_ids_) {
const std::string kAnnotationName =
const std::string annotation_name =
folder_path_ + std::string(kAnnotationsFolder) + id + std::string(kAnnotationExtension);
RETURN_IF_NOT_OK(ParseAnnotationBbox(kAnnotationName));
if (annotation_map_.find(kAnnotationName) != annotation_map_.end()) {
RETURN_IF_NOT_OK(ParseAnnotationBbox(annotation_name));
if (annotation_map_.find(annotation_name) != annotation_map_.end()) {
new_image_ids.push_back(id);
}
}
@ -226,7 +226,9 @@ void VOCOp::ParseNodeValue(XMLElement *bbox_node, const char *name, float *value
*value = 0.0;
if (bbox_node != nullptr) {
XMLElement *node = bbox_node->FirstChildElement(name);
if (node != nullptr) *value = node->FloatText();
if (node != nullptr) {
*value = node->FloatText();
}
}
}

View File

@ -69,7 +69,7 @@ Status TakeOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe)
return Status::OK();
}
int32_t TakeOp::num_consumers() const {
int32_t TakeOp::NumConsumers() const {
if (parent_.empty()) {
MS_LOG(DEBUG) << "Return operator, no parent node, assuming it's the root and returning 1.";
return 1;
@ -77,16 +77,16 @@ int32_t TakeOp::num_consumers() const {
MS_LOG(DEBUG) << "Return operator, pointer to the first parent is null. Returning 0.";
return 0;
} else {
return parent_[0]->num_consumers();
return parent_[0]->NumConsumers();
}
}
int32_t TakeOp::num_producers() const {
int32_t TakeOp::NumProducers() const {
if (child_.empty() || child_[0] == nullptr) {
MS_LOG(DEBUG) << "Return operator, pointer to child node is null. Returning 0.";
return 0;
} else {
return child_[0]->num_producers();
return child_[0]->NumProducers();
}
}
} // namespace dataset

View File

@ -59,8 +59,8 @@ class TakeOp : public PipelineOp {
std::string Name() const override { return kTakeOp; }
Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
int32_t num_consumers() const override;
int32_t num_producers() const override;
int32_t NumConsumers() const override;
int32_t NumProducers() const override;
private:
int32_t max_takes_; // The number of takes that the user requested

View File

@ -132,7 +132,7 @@ Status ZipOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) {
return Status::OK();
}
int32_t ZipOp::num_consumers() const {
int32_t ZipOp::NumConsumers() const {
if (parent_.empty()) {
MS_LOG(DEBUG) << "Return operator, no parent node, assuming it's the root and returning 1.";
return 1;
@ -140,16 +140,16 @@ int32_t ZipOp::num_consumers() const {
MS_LOG(DEBUG) << "Return operator, pointer to the first parent is null. Returning 0.";
return 0;
} else {
return parent_[0]->num_consumers();
return parent_[0]->NumConsumers();
}
}
int32_t ZipOp::num_producers() const {
int32_t ZipOp::NumProducers() const {
if (child_.empty() || child_[0] == nullptr) {
MS_LOG(DEBUG) << "Return operator, pointer to child node is null. Returning 0.";
return 0;
} else {
return child_[0]->num_producers();
return child_[0]->NumProducers();
}
}
} // namespace dataset

View File

@ -65,8 +65,8 @@ class ZipOp : public PipelineOp {
std::string Name() const override { return kZipOp; }
Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
int32_t num_consumers() const override;
int32_t num_producers() const override;
int32_t NumConsumers() const override;
int32_t NumProducers() const override;
private:
// Special handle case where an empty row has been received from child iterator

View File

@ -79,7 +79,7 @@ Status ExecutionTree::AssociateNode(const std::shared_ptr<DatasetOp> &op) {
tree_state_ = kDeTStateBuilding;
// Assign an id to the operator
op->set_id(id_count_);
op->SetId(id_count_);
id_count_++;
// Assign our tree into the op so that each op has a link back to the tree

View File

@ -104,8 +104,8 @@ Status BatchNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
auto op = std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
in_col_names_, out_col_names_, batch_size_func_, batch_map_func_, pad_map_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
op->SetTotalRepeats(GetTotalRepeats());
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
#else
node_ops->push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,

View File

@ -87,8 +87,8 @@ Status BucketBatchByLengthNode::Build(std::vector<std::shared_ptr<DatasetOp>> *c
auto op = std::make_shared<BucketBatchByLengthOp>(column_names_, bucket_boundaries_, bucket_batch_sizes_,
element_length_function_, pad_info_, pad_to_bucket_boundary_,
drop_remainder_, connector_que_size_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
op->SetTotalRepeats(GetTotalRepeats());
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
if (bucket_boundaries_[0] == 0) {
bucket_boundaries_.erase(bucket_boundaries_.begin());

View File

@ -56,8 +56,8 @@ void BuildSentenceVocabNode::Print(std::ostream &out) const {
Status BuildSentenceVocabNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
auto op = std::make_shared<BuildSentencePieceVocabOp>(vocab_, col_names_, vocab_size_, character_coverage_,
model_type_, params_, connector_que_size_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
op->SetTotalRepeats(GetTotalRepeats());
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

View File

@ -54,8 +54,8 @@ Status BuildVocabNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node
std::shared_ptr<BuildVocabOp> build_vocab_op;
build_vocab_op = std::make_shared<BuildVocabOp>(vocab_, columns_, freq_range_, top_k_, special_tokens_,
special_first_, num_workers_, connector_que_size_);
build_vocab_op->set_total_repeats(GetTotalRepeats());
build_vocab_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
build_vocab_op->SetTotalRepeats(GetTotalRepeats());
build_vocab_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(build_vocab_op);
return Status::OK();
}

View File

@ -51,8 +51,8 @@ Status CacheLookupNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops)
"Internal error. Attempt to create a cache lookup node without cache client.");
RETURN_IF_NOT_OK(cache_->Build());
RETURN_IF_NOT_OK(cache_->CreateCacheLookupOp(num_workers_, &lookup_op_, sampler_));
lookup_op_->set_total_repeats(GetTotalRepeats());
lookup_op_->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
lookup_op_->SetTotalRepeats(GetTotalRepeats());
lookup_op_->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(lookup_op_);
return Status::OK();
}

View File

@ -48,8 +48,8 @@ Status CacheMergeNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops)
RETURN_IF_NOT_OK(cache_->Build());
std::shared_ptr<DatasetOp> merge_op = nullptr;
RETURN_IF_NOT_OK(cache_->CreateCacheMergeOp(num_workers_, &merge_op));
merge_op->set_total_repeats(GetTotalRepeats());
merge_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
merge_op->SetTotalRepeats(GetTotalRepeats());
merge_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(merge_op);
return Status::OK();
}

View File

@ -55,8 +55,8 @@ Status CacheNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
cache_op->SetSampler(sampler_rt);
cache_op->set_total_repeats(GetTotalRepeats());
cache_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
cache_op->SetTotalRepeats(GetTotalRepeats());
cache_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(cache_op);
return Status::OK();
}

View File

@ -130,8 +130,8 @@ Status ConcatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
op = std::make_shared<ConcatOp>(sampler_rt, children_flag_and_nums_, children_start_end_index_);
}
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
op->SetTotalRepeats(GetTotalRepeats());
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();

View File

@ -251,7 +251,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
void HasCacheAbove() { descendant_of_cache_ = true; }
/// \brief Getter of the number of workers
int32_t num_workers() { return num_workers_; }
int32_t NumWorkers() { return num_workers_; }
/// \brief Getter of dataset cache
std::shared_ptr<DatasetCache> GetDatasetCache() { return cache_; }

View File

@ -46,8 +46,8 @@ void EpochCtrlNode::Print(std::ostream &out) const {
// Function to build the EpochCtrlOp
Status EpochCtrlNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
auto new_op_ = std::make_shared<EpochCtrlOp>(repeat_count_);
new_op_->set_total_repeats(GetTotalRepeats());
new_op_->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
new_op_->SetTotalRepeats(GetTotalRepeats());
new_op_->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(new_op_);
op_ = new_op_;
return Status::OK();

View File

@ -45,8 +45,8 @@ void FilterNode::Print(std::ostream &out) const {
Status FilterNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
auto op = std::make_shared<FilterOp>(input_columns_, num_workers_, connector_que_size_, predicate_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
op->SetTotalRepeats(GetTotalRepeats());
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

View File

@ -86,12 +86,12 @@ Status MapNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
if (!project_columns_.empty()) {
auto project_op = std::make_shared<ProjectOp>(project_columns_);
project_op->set_total_repeats(GetTotalRepeats());
project_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
project_op->SetTotalRepeats(GetTotalRepeats());
project_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(project_op);
}
map_op->set_total_repeats(GetTotalRepeats());
map_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
map_op->SetTotalRepeats(GetTotalRepeats());
map_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(map_op);
return Status::OK();
}

View File

@ -54,8 +54,8 @@ Status ProjectNode::ValidateParams() {
Status ProjectNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
auto op = std::make_shared<ProjectOp>(columns_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
op->SetTotalRepeats(GetTotalRepeats());
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

View File

@ -59,8 +59,8 @@ Status RenameNode::ValidateParams() {
Status RenameNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
auto op = std::make_shared<RenameOp>(input_columns_, output_columns_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
op->SetTotalRepeats(GetTotalRepeats());
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

View File

@ -40,8 +40,8 @@ void RepeatNode::Print(std::ostream &out) const { out << (Name() + "(count:" + s
Status RepeatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
auto new_op = std::make_shared<RepeatOp>(repeat_count_);
new_op->set_total_repeats(GetTotalRepeats());
new_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
new_op->SetTotalRepeats(GetTotalRepeats());
new_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(new_op);
op_ = new_op;

View File

@ -45,8 +45,8 @@ void ShuffleNode::Print(std::ostream &out) const {
// Function to build the ShuffleOp
Status ShuffleNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
auto op = std::make_shared<ShuffleOp>(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
op->SetTotalRepeats(GetTotalRepeats());
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

View File

@ -40,8 +40,8 @@ void SkipNode::Print(std::ostream &out) const { out << (Name() + "(skip_count:"
// Function to build the SkipOp
Status SkipNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
auto op = std::make_shared<SkipOp>(skip_count_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
op->SetTotalRepeats(GetTotalRepeats());
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

View File

@ -76,8 +76,8 @@ Status AlbumNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
auto album_op = std::make_shared<AlbumOp>(num_workers_, dataset_dir_, connector_que_size_, decode_, extensions,
std::move(schema), std::move(sampler_rt));
album_op->set_total_repeats(GetTotalRepeats());
album_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
album_op->SetTotalRepeats(GetTotalRepeats());
album_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(album_op);
return Status::OK();
}

View File

@ -72,8 +72,8 @@ Status CelebANode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
auto celeba_op = std::make_shared<CelebAOp>(num_workers_, dataset_dir_, connector_que_size_, decode_, usage_,
extensions_, std::move(schema), std::move(sampler_rt));
celeba_op->set_total_repeats(GetTotalRepeats());
celeba_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
celeba_op->SetTotalRepeats(GetTotalRepeats());
celeba_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(celeba_op);
return Status::OK();

View File

@ -68,8 +68,8 @@ Status Cifar100Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
auto cifar_op = std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, dataset_dir_,
connector_que_size_, std::move(schema), std::move(sampler_rt));
cifar_op->set_total_repeats(GetTotalRepeats());
cifar_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
cifar_op->SetTotalRepeats(GetTotalRepeats());
cifar_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(cifar_op);
return Status::OK();

View File

@ -66,8 +66,8 @@ Status Cifar10Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_op
auto cifar_op = std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, dataset_dir_,
connector_que_size_, std::move(schema), std::move(sampler_rt));
cifar_op->set_total_repeats(GetTotalRepeats());
cifar_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
cifar_op->SetTotalRepeats(GetTotalRepeats());
cifar_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(cifar_op);
return Status::OK();

View File

@ -196,12 +196,12 @@ Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
// Add the shuffle op after this op
RETURN_IF_NOT_OK(
AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
shuffle_op->set_total_repeats(GetTotalRepeats());
shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
shuffle_op->SetTotalRepeats(GetTotalRepeats());
shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(shuffle_op);
}
clue_op->set_total_repeats(GetTotalRepeats());
clue_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
clue_op->SetTotalRepeats(GetTotalRepeats());
clue_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(clue_op);
return Status::OK();

View File

@ -134,8 +134,8 @@ Status CocoNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
std::shared_ptr<CocoOp> op =
std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, connector_que_size_, decode_,
std::move(schema), std::move(sampler_rt), extra_metadata_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
op->SetTotalRepeats(GetTotalRepeats());
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();

View File

@ -134,12 +134,12 @@ Status CSVNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
// Add the shuffle op after this op
RETURN_IF_NOT_OK(
AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
shuffle_op->set_total_repeats(GetTotalRepeats());
shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
shuffle_op->SetTotalRepeats(GetTotalRepeats());
shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(shuffle_op);
}
csv_op->set_total_repeats(GetTotalRepeats());
csv_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
csv_op->SetTotalRepeats(GetTotalRepeats());
csv_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(csv_op);
return Status::OK();

View File

@ -97,8 +97,8 @@ Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_
if (reset_ancestor_ != nullptr) {
reset_ancestor_->op_->AddToEoeList(op);
}
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
op->SetTotalRepeats(GetTotalRepeats());
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

View File

@ -74,8 +74,8 @@ Status ImageFolderNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const nod
auto op = std::make_shared<ImageFolderOp>(num_workers_, dataset_dir_, connector_que_size_, recursive_, decode_, exts_,
class_indexing_, std::move(schema), std::move(sampler_rt));
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
op->SetTotalRepeats(GetTotalRepeats());
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

View File

@ -96,8 +96,8 @@ Status ManifestNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
manifest_op = std::make_shared<ManifestOp>(num_workers_, dataset_file_, connector_que_size_, decode_, class_index_,
std::move(schema), std::move(sampler_rt), usage_);
manifest_op->set_total_repeats(GetTotalRepeats());
manifest_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
manifest_op->SetTotalRepeats(GetTotalRepeats());
manifest_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(manifest_op);
return Status::OK();

View File

@ -212,8 +212,8 @@ Status MindDataNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
}
RETURN_IF_NOT_OK(mindrecord_op->Init());
mindrecord_op->set_total_repeats(GetTotalRepeats());
mindrecord_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
mindrecord_op->SetTotalRepeats(GetTotalRepeats());
mindrecord_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(mindrecord_op);
return Status::OK();

View File

@ -62,8 +62,8 @@ Status MnistNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
auto op = std::make_shared<MnistOp>(usage_, num_workers_, dataset_dir_, connector_que_size_, std::move(schema),
std::move(sampler_rt));
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
op->SetTotalRepeats(GetTotalRepeats());
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();

View File

@ -110,8 +110,8 @@ Status RandomNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
std::shared_ptr<RandomDataOp> op;
op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, total_rows_, std::move(data_schema_));
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
op->SetTotalRepeats(GetTotalRepeats());
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();

View File

@ -102,12 +102,12 @@ Status TextFileNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
// Add the shuffle op after this op
RETURN_IF_NOT_OK(
AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
shuffle_op->set_total_repeats(GetTotalRepeats());
shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
shuffle_op->SetTotalRepeats(GetTotalRepeats());
shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(shuffle_op);
}
text_file_op->set_total_repeats(GetTotalRepeats());
text_file_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
text_file_op->SetTotalRepeats(GetTotalRepeats());
text_file_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
// Add TextFileOp
node_ops->push_back(text_file_op);

View File

@ -142,12 +142,12 @@ Status TFRecordNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
// Add the shuffle op after this op
RETURN_IF_NOT_OK(AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
shuffle_op->set_total_repeats(GetTotalRepeats());
shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
shuffle_op->SetTotalRepeats(GetTotalRepeats());
shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(shuffle_op);
}
tf_reader_op->set_total_repeats(GetTotalRepeats());
tf_reader_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
tf_reader_op->SetTotalRepeats(GetTotalRepeats());
tf_reader_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
// Add TFReaderOp
node_ops->push_back(tf_reader_op);
return Status::OK();

View File

@ -122,8 +122,8 @@ Status VOCNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
std::shared_ptr<VOCOp> voc_op;
voc_op = std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, connector_que_size_,
decode_, std::move(schema), std::move(sampler_rt), extra_metadata_);
voc_op->set_total_repeats(GetTotalRepeats());
voc_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
voc_op->SetTotalRepeats(GetTotalRepeats());
voc_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(voc_op);
return Status::OK();
}

View File

@ -46,8 +46,8 @@ Status SyncWaitNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
// The reason for this is because having it otherwise can lead to blocking issues
// See barrier_op.h for more details
auto op = std::make_shared<BarrierOp>(connector_que_size_, condition_name_, callback_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
op->SetTotalRepeats(GetTotalRepeats());
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

View File

@ -41,8 +41,8 @@ void TakeNode::Print(std::ostream &out) const { out << (Name() + "(num_rows:" +
// Function to build the TakeOp
Status TakeNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
auto op = std::make_shared<TakeOp>(take_count_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
op->SetTotalRepeats(GetTotalRepeats());
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

View File

@ -97,8 +97,8 @@ Status TransferNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
auto op = std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_,
total_batch_, create_data_info_queue_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
op->SetTotalRepeats(GetTotalRepeats());
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

View File

@ -60,8 +60,8 @@ Status ZipNode::ValidateParams() {
Status ZipNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
auto op = std::make_shared<ZipOp>(connector_que_size_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
op->SetTotalRepeats(GetTotalRepeats());
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

View File

@ -68,10 +68,10 @@ Status AutoWorkerPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *con
int32_t cur_node_num_worker = std::max(std::min(num_workers, cur_node_max), min_num_workers_);
// if the num_worker to set is same as original, skip setting and printing the logs
if (cur_node_num_worker == p.first->num_workers()) continue;
if (cur_node_num_worker == p.first->NumWorkers()) continue;
// log the change via warning msg so user can see what the num_worker is being set for which op
MS_LOG(WARNING) << "AutoNumWorker enabled, num_workers in " << p.first->Name() << " is auto-adjusted from "
<< std::to_string(p.first->num_workers()) + " to " + std::to_string(cur_node_num_worker);
<< std::to_string(p.first->NumWorkers()) + " to " + std::to_string(cur_node_num_worker);
p.first->SetNumWorkers(cur_node_num_worker);
}
return Status::OK();

View File

@ -49,7 +49,7 @@ Status DeepCopyPass::Visit(std::shared_ptr<DatasetNode> node, bool *const modifi
// Temporary fix to set the num_workers to each cloned node.
// This can be improved by adding a new method in the base class DatasetNode to transfer the properties to
// the cloned node. Each derived class's Copy() will need to include this method.
new_node->SetNumWorkers(node->num_workers());
new_node->SetNumWorkers(node->NumWorkers());
// This method below assumes a DFS walk and from the first child to the last child.
// Future: A more robust implementation that does not depend on the above assumption.
RETURN_IF_NOT_OK(parent_->AppendChild(new_node));

View File

@ -42,7 +42,7 @@ json ConnectorSize::ParseOpInfo(const DatasetOp &node, const std::vector<int32_t
json json_node;
json_node["op_id"] = node.id();
json_node["op_type"] = node.Name();
json_node["num_workers"] = node.num_workers();
json_node["num_workers"] = node.NumWorkers();
json metrics;
// DeviceQueueOp is a special op,it is not inlined but its output queue is invalid.
// So we should not output its queue size.

View File

@ -78,7 +78,7 @@ json ConnectorThroughput::ParseOpInfo(const DatasetOp &node, const std::vector<d
json json_node;
json_node["op_id"] = node.id();
json_node["op_type"] = node.Name();
json_node["num_workers"] = node.num_workers();
json_node["num_workers"] = node.NumWorkers();
json metrics;
// DeviceQueueOp is a special op,it is not inlined but its output queue is invalid.
// So we should not output its connector throughput.

View File

@ -264,7 +264,7 @@ Status OperatorCpu::Collect(const ExecutionTree *tree) {
for (auto iter = tree->begin(); iter != tree->end(); ++iter) {
id_count_++;
op_name_[iter->id()] = iter->NameWithID();
op_parallel_workers_[iter->id()] = iter->num_workers();
op_parallel_workers_[iter->id()] = iter->NumWorkers();
}
#if defined(USING_LINUX)
cpu_processor_num_ = get_nprocs_conf();

View File

@ -219,7 +219,7 @@ Status TreeAdapter::GetNext(TensorRow *row) {
RETURN_UNEXPECTED_IF_NULL(row);
row->clear(); // make sure row is empty
bool isProfilingEnable = tree_->GetProfilingManager()->IsProfilingEnable();
bool is_profiling_enable = tree_->GetProfilingManager()->IsProfilingEnable();
// When cur_db_ is a nullptr, it means this is the first call to get_next, launch ExecutionTree
if (!launched_) {
@ -229,7 +229,7 @@ Status TreeAdapter::GetNext(TensorRow *row) {
RETURN_IF_NOT_OK(tree_->root()->GetNextRow(row)); // first buf can't be eof or empty buf with none flag
if (row->eoe()) { // return empty tensor if 1st buf is a ctrl buf (no rows)
MS_LOG(INFO) << "End of data iteration.";
if (isProfilingEnable) {
if (is_profiling_enable) {
tree_->SetEpochEnd();
}
return Status::OK();

View File

@ -119,6 +119,9 @@ class Execute {
/// \brief The function to validate target device setting is valid or not.
Status ValidateDevice();
/// \brief Initialize 310 resource
Status InitResource(MapTargetDevice device_type, uint32_t device_id);
std::vector<std::shared_ptr<TensorTransform>> transforms_;
std::vector<std::shared_ptr<TensorOperation>> ops_;
MapTargetDevice device_type_;

View File

@ -251,6 +251,14 @@ static bool Conv2DImplement(const LiteMat &src, const LiteMat &kernel, T2 *dst,
int border_y = static_cast<int>(kernel.height_ / 2);
LiteMat pad_mat;
if ((border_x > INT_MAX / 2) || (src.width_ > INT_MAX - 2 * border_x)) {
return false;
}
if ((border_y > INT_MAX / 2) || (src.height_ > INT_MAX - 2 * border_y)) {
return false;
}
pad_mat.Init(src.width_ + 2 * border_x, src.height_ + 2 * border_y, src.channel_, src.data_type_);
if (!Pad(src, pad_mat, border_y, border_y, border_x, border_x, pad_type)) {

View File

@ -30,7 +30,7 @@ Status ComputeUpperAndLowerPercentiles(std::vector<int32_t> *hist, int32_t hi_p,
int32_t n = std::accumulate(hist->begin(), hist->end(), 0);
constexpr float kMaxPerc = 100.0;
int32_t cut = static_cast<int32_t>((low_p / kMaxPerc) * n);
for (int32_t lb = 0; lb < hist->size() + 1 && cut > 0; lb++) {
for (int32_t lb = 0; lb < hist->size() && cut > 0; lb++) {
if (cut > (*hist)[lb]) {
cut -= (*hist)[lb];
(*hist)[lb] = 0;

View File

@ -41,7 +41,9 @@ Status MixUpBatchOp::ComputeLabels(const TensorRow &input, std::shared_ptr<Tenso
const float lam, const size_t images_size) {
CHECK_FAIL_RETURN_UNEXPECTED(images_size <= static_cast<size_t>(std::numeric_limits<int64_t>::max()),
"The \"images_size\" must not be more than \"INT64_MAX\".");
for (int64_t i = 0; i < static_cast<int64_t>(images_size); i++) rand_indx->push_back(i);
for (int64_t i = 0; i < static_cast<int64_t>(images_size); i++) {
rand_indx->push_back(i);
}
std::shuffle(rand_indx->begin(), rand_indx->end(), rnd_);
RETURN_IF_NOT_OK(TypeCast(std::move(input.at(1)), out_labels, DataType(DataType::DE_FLOAT32)));

View File

@ -48,6 +48,12 @@ RandomAffineOp::RandomAffineOp(std::vector<float_t> degrees, std::vector<float_t
Status RandomAffineOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
CHECK_FAIL_RETURN_UNEXPECTED(translate_range_.size() == 4, "RandomAffine: the translate range size is not 4.");
CHECK_FAIL_RETURN_UNEXPECTED(degrees_range_.size() == 2, "RandomAffine: the degrees range size is not 2.");
CHECK_FAIL_RETURN_UNEXPECTED(scale_range_.size() == 2, "RandomAffine: the scale range size is not 2.");
CHECK_FAIL_RETURN_UNEXPECTED(shear_ranges_.size() == 4, "RandomAffine: the shear ranges size is not 4.");
dsize_t height = input->shape()[0];
dsize_t width = input->shape()[1];
CHECK_FAIL_RETURN_UNEXPECTED((std::numeric_limits<float_t>::max() / std::abs(translate_range_[0])) > width,

View File

@ -20,7 +20,10 @@
namespace mindspore {
namespace dataset {
IntrpService::IntrpService() : high_water_mark_(0) { (void)ServiceStart(); }
IntrpService::IntrpService() try : high_water_mark_(0) { (void)ServiceStart(); } catch (const std::exception &e) {
MS_LOG(ERROR) << "Interrupt service failed: " << e.what() << ".";
std::terminate();
}
IntrpService::~IntrpService() noexcept {
MS_LOG(INFO) << "Number of registered resources is " << high_water_mark_ << ".";

View File

@ -63,7 +63,6 @@ void RWLock::Unlock() noexcept {
void RWLock::Upgrade() {
std::unique_lock<std::mutex> lck(mtx_);
MS_ASSERT(status_);
if (status_ == -1) {
// I am a writer already.
return;
@ -81,7 +80,6 @@ void RWLock::Upgrade() {
void RWLock::Downgrade() {
std::unique_lock<std::mutex> lck(mtx_);
MS_ASSERT(status_);
if (status_ == -1) {
// If there are no other writers waiting, just change the status
if (waiting_writers_ == 0) {
@ -111,26 +109,18 @@ SharedLock::~SharedLock() {
}
void SharedLock::Unlock() {
MS_ASSERT(ownlock_ == true);
rw_->Unlock();
ownlock_ = false;
}
void SharedLock::Lock() {
MS_ASSERT(ownlock_ == false);
rw_->LockShared();
ownlock_ = true;
}
void SharedLock::Upgrade() {
MS_ASSERT(ownlock_ == true);
rw_->Upgrade();
}
void SharedLock::Upgrade() { rw_->Upgrade(); }
void SharedLock::Downgrade() {
MS_ASSERT(ownlock_ == true);
rw_->Downgrade();
}
void SharedLock::Downgrade() { rw_->Downgrade(); }
UniqueLock::UniqueLock(RWLock *rw) : rw_(rw), ownlock_(false) {
rw_->LockExclusive();
@ -146,13 +136,11 @@ UniqueLock::~UniqueLock() {
}
void UniqueLock::Unlock() {
MS_ASSERT(ownlock_ == true);
rw_->Unlock();
ownlock_ = false;
}
void UniqueLock::Lock() {
MS_ASSERT(ownlock_ == false);
rw_->LockExclusive();
ownlock_ = true;
}
@ -171,13 +159,11 @@ LockGuard::~LockGuard() {
}
void LockGuard::Unlock() {
MS_ASSERT(own_lock_);
lck_->Unlock();
own_lock_ = false;
}
void LockGuard::Lock() {
MS_ASSERT(own_lock_ == false);
lck_->Lock();
own_lock_ = true;
}

View File

@ -39,9 +39,9 @@ class Path {
private:
explicit DirIterator(Path *f);
Path *dir_;
DIR *dp_;
struct dirent *entry_;
Path *dir_ = nullptr;
DIR *dp_ = nullptr;
struct dirent *entry_ = nullptr;
};
explicit Path(const std::string &);

View File

@ -43,11 +43,11 @@ std::vector<std::string> StringSplit(const std::string &field, char separator) {
}
bool ValidateFieldName(const std::string &str) {
std::string::const_iterator it = str.begin();
if (it == str.end()) {
auto it = str.cbegin();
if (it == str.cend()) {
return false;
}
for (; it != str.end(); ++it) {
for (; it != str.cend(); ++it) {
if (*it == '_' || ((*it >= '0') && (*it <= '9')) || ((*it >= 'A') && (*it <= 'Z')) ||
((*it >= 'a') && (*it <= 'z'))) {
continue;
@ -83,8 +83,7 @@ std::pair<MSRStatus, std::string> GetFileName(const std::string &path) {
}
#endif
std::string s = real_path;
char sep = '/';
size_t i = s.rfind(sep, s.length());
size_t i = s.rfind(kPathSeparator, s.length());
if (i != std::string::npos) {
if (i + 1 < s.size()) {
return {SUCCESS, s.substr(i + 1)};
@ -119,10 +118,10 @@ std::pair<MSRStatus, std::string> GetParentDir(const std::string &path) {
}
#endif
std::string s = real_path;
if (s.rfind('/') + 1 <= s.size()) {
return {SUCCESS, s.substr(0, s.rfind('/') + 1)};
if (s.rfind(kPathSeparator) + 1 <= s.size()) {
return {SUCCESS, s.substr(0, s.rfind(kPathSeparator) + 1)};
}
return {SUCCESS, "/"};
return {SUCCESS, std::string()};
}
bool CheckIsValidUtf8(const std::string &str) {
@ -155,7 +154,7 @@ bool CheckIsValidUtf8(const std::string &str) {
bool IsLegalFile(const std::string &path) {
struct stat s;
if (stat(common::SafeCStr(path), &s) == 0) {
if (s.st_mode & S_IFDIR) {
if (S_ISDIR(s.st_mode)) {
return false;
}
return true;

Some files were not shown because too many files have changed in this diff Show More