sync code

liyong 2021-09-27 10:57:52 +08:00
parent dc6e62227d
commit 567f137a39
129 changed files with 532 additions and 504 deletions

View File

@@ -34,7 +34,7 @@ bool set_seed(int32_t seed) {
     MS_LOG(ERROR) << "Seed given is not within the required range: " << seed;
     return false;
   }
-  _config->set_seed((uint32_t)seed);
+  _config->set_seed(static_cast<uint32_t>(seed));
   return true;
 }
@@ -73,7 +73,7 @@ bool set_monitor_sampling_interval(int32_t interval) {
     MS_LOG(ERROR) << "Interval given is not within the required range: " << interval;
     return false;
   }
-  _config->set_monitor_sampling_interval((uint32_t)interval);
+  _config->set_monitor_sampling_interval(static_cast<uint32_t>(interval));
   return true;
 }
@@ -86,7 +86,7 @@ bool set_callback_timeback(int32_t timeout) {
     MS_LOG(ERROR) << "Timeout given is not within the required range: " << timeout;
     return false;
   }
-  _config->set_callback_timeout((uint32_t)timeout);
+  _config->set_callback_timeout(static_cast<uint32_t>(timeout));
   return true;
 }
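
Taken together, these three hunks replace C-style casts with static_cast after the existing range check. A minimal standalone sketch of that validate-then-narrow pattern follows; ConfigHolder and the std::cerr logging are illustrative stand-ins, not the real ConfigManager API.

#include <cstdint>
#include <iostream>

// Illustrative stand-in for the shared config object (_config in the real code).
struct ConfigHolder {
  uint32_t seed = 0;
  void set_seed(uint32_t s) { seed = s; }
};

static ConfigHolder g_config;

// Validate the signed input first, then narrow explicitly with static_cast so the
// signed-to-unsigned conversion is intentional and easy to grep for.
bool set_seed(int32_t seed) {
  if (seed < 0) {
    std::cerr << "Seed given is not within the required range: " << seed << "\n";
    return false;
  }
  g_config.set_seed(static_cast<uint32_t>(seed));
  return true;
}

int main() {
  std::cout << std::boolalpha << set_seed(42) << " " << set_seed(-1) << "\n";  // prints: true false
  return 0;
}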

View File

@@ -215,7 +215,6 @@ bool Dataset::DeviceQueueCharIF(const std::vector<char> &queue_name, const std::
     MS_LOG(ERROR) << "ToDevice: Failed to get consumer.";
     return false;
   }
   rc = consumer->Init(ds);
   if (rc.IsError()) {
     MS_LOG(ERROR) << "ToDevice: Failed to init. Error status: " << rc;
@@ -258,7 +257,6 @@ bool Dataset::SaveCharIF(const std::vector<char> &dataset_path, int32_t num_file
     MS_LOG(ERROR) << "ToDevice: Failed to get consumer.";
     return false;
   }
   rc = consumer->Init(ds->IRNode());
   if (rc.IsError()) {
     MS_LOG(ERROR) << "CreateSaver failed." << rc;
@@ -289,11 +287,15 @@ bool Dataset::SaveCharIF(const std::vector<char> &dataset_path, int32_t num_file
 Dataset::Dataset() { tree_getters_ = std::make_shared<TreeGetters>(); }
 int64_t Dataset::GetDatasetSize(bool estimate) {
-  int64_t dataset_size;
+  int64_t dataset_size = -1;
   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
   RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
   std::shared_ptr<DatasetSizeGetter> size_getter = std::make_shared<DatasetSizeGetter>();
   DatasetSizeGetter *consumer = size_getter.get();
+  if (consumer == nullptr) {
+    MS_LOG(ERROR) << "DatasetSizeGetter: Failed to get consumer.";
+    return -1;
+  }
   runtime_context->AssignConsumer(size_getter);
   RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), -1);
   RETURN_SECOND_IF_ERROR(consumer->GetDatasetSize(&dataset_size, estimate), -1);
@@ -305,6 +307,10 @@ std::vector<mindspore::DataType> Dataset::GetOutputTypes() {
   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
   RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
   TreeGetters *consumer = tree_getters_.get();
+  if (consumer == nullptr) {
+    MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
+    return std::vector<mindspore::DataType>();
+  }
   runtime_context->AssignConsumer(tree_getters_);
   RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
   RETURN_SECOND_IF_ERROR(consumer->GetOutputTypes(&types), {});
@@ -320,6 +326,10 @@ std::vector<std::vector<int64_t>> Dataset::GetOutputShapes() {
   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
   RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
   TreeGetters *consumer = tree_getters_.get();
+  if (consumer == nullptr) {
+    MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
+    return std::vector<std::vector<int64_t>>();
+  }
   runtime_context->AssignConsumer(tree_getters_);
   RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
   RETURN_SECOND_IF_ERROR(consumer->GetOutputShapes(&shapes), {});
@@ -330,10 +340,14 @@ std::vector<std::vector<int64_t>> Dataset::GetOutputShapes() {
 }
 int64_t Dataset::GetNumClasses() {
-  int64_t num_classes;
+  int64_t num_classes = -1;
   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
   RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
   TreeGetters *consumer = tree_getters_.get();
+  if (consumer == nullptr) {
+    MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
+    return -1;
+  }
   runtime_context->AssignConsumer(tree_getters_);
   RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), -1);
   RETURN_SECOND_IF_ERROR(consumer->GetNumClasses(&num_classes), -1);
@@ -345,6 +359,10 @@ std::vector<std::vector<char>> Dataset::GetColumnNamesCharIF() {
   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
   RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
   TreeGetters *consumer = tree_getters_.get();
+  if (consumer == nullptr) {
+    MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
+    return std::vector<std::vector<char>>();
+  }
   runtime_context->AssignConsumer(tree_getters_);
   RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
   RETURN_SECOND_IF_ERROR(consumer->GetColumnNames(&col_names), {});
@@ -356,6 +374,10 @@ std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> Dataset::GetClas
   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
   RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
   TreeGetters *consumer = tree_getters_.get();
+  if (consumer == nullptr) {
+    MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
+    return std::vector<std::pair<std::vector<char>, std::vector<int32_t>>>();
+  }
   runtime_context->AssignConsumer(tree_getters_);
   RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
   RETURN_SECOND_IF_ERROR(consumer->GetClassIndexing(&output_class_indexing), {});
@@ -493,10 +515,10 @@ TakeDataset::TakeDataset(std::shared_ptr<Dataset> input, int32_t count) {
 ZipDataset::ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
   std::vector<std::shared_ptr<DatasetNode>> all_datasets;
-  (void)std::transform(
-    datasets.begin(), datasets.end(), std::back_inserter(all_datasets),
-    [](std::shared_ptr<Dataset> dataset) -> std::shared_ptr<DatasetNode> { return dataset->IRNode(); });
+  (void)std::transform(datasets.begin(), datasets.end(), std::back_inserter(all_datasets),
+                       [](std::shared_ptr<Dataset> dataset) -> std::shared_ptr<DatasetNode> {
+                         return (dataset != nullptr) ? dataset->IRNode() : nullptr;
+                       });
   auto ds = std::make_shared<ZipNode>(all_datasets);
   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
@@ -544,6 +566,10 @@ std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocabCharIF(
   auto consumer = std::make_unique<BuildVocabConsumer>();
   BuildVocabConsumer *bv_consumer = consumer.get();
+  if (bv_consumer == nullptr) {
+    MS_LOG(ERROR) << "BuildVocabConsumer: Failed to get bv_consumer.";
+    return nullptr;
+  }
   rc = consumer->Init(ds);
   if (rc.IsError()) {
     MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init consumer. Error status: " << rc;
@@ -577,6 +603,10 @@ std::shared_ptr<Vocab> Dataset::BuildVocabCharIF(const std::vector<std::vector<c
   auto consumer = std::make_unique<BuildVocabConsumer>();
   BuildVocabConsumer *bv_consumer = consumer.get();
+  if (bv_consumer == nullptr) {
+    MS_LOG(ERROR) << "BuildVocabConsumer: Failed to get bv_consumer.";
+    return nullptr;
+  }
   rc = consumer->Init(ds);
   if (rc.IsError()) {
     MS_LOG(ERROR) << "BuildVocab: Failed to init consumer. Error status: " << rc;

View File

@@ -54,20 +54,27 @@ struct Execute::ExtraInfo {
 #endif
 };
-Execute::Execute(std::shared_ptr<TensorOperation> op, MapTargetDevice device_type, uint32_t device_id) {
-  ops_.emplace_back(std::move(op));
-  device_type_ = device_type;
-  info_ = std::make_shared<ExtraInfo>();
+Status Execute::InitResource(MapTargetDevice device_type, uint32_t device_id) {
 #ifdef ENABLE_ACL
   if (device_type_ == MapTargetDevice::kAscend310) {
     device_resource_ = std::make_shared<AscendResource>();
     Status rc = device_resource_->InitResource(device_id);
     if (!rc.IsOk()) {
       device_resource_ = nullptr;
-      MS_LOG(ERROR) << "Initialize Ascend310 resource fail.";
+      std::string err_msg = "Initialize Ascend310 resource fail";
+      MS_LOG(ERROR) << err_msg;
+      RETURN_STATUS_UNEXPECTED(err_msg);
     }
   }
 #endif
+  return Status::OK();
+}
+Execute::Execute(std::shared_ptr<TensorOperation> op, MapTargetDevice device_type, uint32_t device_id) {
+  ops_.emplace_back(std::move(op));
+  device_type_ = device_type;
+  info_ = std::make_shared<ExtraInfo>();
+  (void)InitResource(device_type, device_id);
 }
 Execute::Execute(std::shared_ptr<TensorTransform> op, MapTargetDevice device_type, uint32_t device_id) {
@@ -76,16 +83,7 @@ Execute::Execute(std::shared_ptr<TensorTransform> op, MapTargetDevice device_typ
   info_ = std::make_shared<ExtraInfo>();
   device_type_ = device_type;
-#ifdef ENABLE_ACL
-  if (device_type_ == MapTargetDevice::kAscend310) {
-    device_resource_ = std::make_shared<AscendResource>();
-    Status rc = device_resource_->InitResource(device_id);
-    if (!rc.IsOk()) {
-      device_resource_ = nullptr;
-      MS_LOG(ERROR) << "Initialize Ascend310 resource fail.";
-    }
-  }
-#endif
+  (void)InitResource(device_type, device_id);
 }
 Execute::Execute(std::reference_wrapper<TensorTransform> op, MapTargetDevice device_type, uint32_t device_id) {
@@ -96,16 +94,7 @@ Execute::Execute(std::reference_wrapper<TensorTransform> op, MapTargetDevice dev
   info_ = std::make_shared<ExtraInfo>();
   info_->init_with_shared_ptr_ = false;
   device_type_ = device_type;
-#ifdef ENABLE_ACL
-  if (device_type_ == MapTargetDevice::kAscend310) {
-    device_resource_ = std::make_shared<AscendResource>();
-    Status rc = device_resource_->InitResource(device_id);
-    if (!rc.IsOk()) {
-      device_resource_ = nullptr;
-      MS_LOG(ERROR) << "Initialize Ascend310 resource fail.";
-    }
-  }
-#endif
+  (void)InitResource(device_type, device_id);
 }
 // Execute function for the example case: auto decode(new vision::Decode());
@@ -116,31 +105,13 @@ Execute::Execute(TensorTransform *op, MapTargetDevice device_type, uint32_t devi
   info_ = std::make_shared<ExtraInfo>();
   device_type_ = device_type;
-#ifdef ENABLE_ACL
-  if (device_type_ == MapTargetDevice::kAscend310) {
-    device_resource_ = std::make_shared<AscendResource>();
-    Status rc = device_resource_->InitResource(device_id);
-    if (!rc.IsOk()) {
-      device_resource_ = nullptr;
-      MS_LOG(ERROR) << "Initialize Ascend310 resource fail.";
-    }
-  }
-#endif
+  (void)InitResource(device_type, device_id);
 }
 Execute::Execute(std::vector<std::shared_ptr<TensorOperation>> ops, MapTargetDevice device_type, uint32_t device_id)
     : ops_(std::move(ops)), device_type_(device_type) {
   info_ = std::make_shared<ExtraInfo>();
-#ifdef ENABLE_ACL
-  if (device_type_ == MapTargetDevice::kAscend310) {
-    device_resource_ = std::make_shared<AscendResource>();
-    Status rc = device_resource_->InitResource(device_id);
-    if (!rc.IsOk()) {
-      device_resource_ = nullptr;
-      MS_LOG(ERROR) << "Initialize Ascend310 resource fail.";
-    }
-  }
-#endif
+  (void)InitResource(device_type, device_id);
 }
 Execute::Execute(std::vector<std::shared_ptr<TensorTransform>> ops, MapTargetDevice device_type, uint32_t device_id) {
@@ -149,16 +120,7 @@ Execute::Execute(std::vector<std::shared_ptr<TensorTransform>> ops, MapTargetDev
   info_ = std::make_shared<ExtraInfo>();
   device_type_ = device_type;
-#ifdef ENABLE_ACL
-  if (device_type_ == MapTargetDevice::kAscend310) {
-    device_resource_ = std::make_shared<AscendResource>();
-    Status rc = device_resource_->InitResource(device_id);
-    if (!rc.IsOk()) {
-      device_resource_ = nullptr;
-      MS_LOG(ERROR) << "Initialize Ascend310 resource fail.";
-    }
-  }
-#endif
+  (void)InitResource(device_type, device_id);
 }
 Execute::Execute(const std::vector<std::reference_wrapper<TensorTransform>> ops, MapTargetDevice device_type,
@@ -177,16 +139,7 @@ Execute::Execute(const std::vector<std::reference_wrapper<TensorTransform>> ops,
   info_ = std::make_shared<ExtraInfo>();
   info_->init_with_shared_ptr_ = false;
   device_type_ = device_type;
-#ifdef ENABLE_ACL
-  if (device_type_ == MapTargetDevice::kAscend310) {
-    device_resource_ = std::make_shared<AscendResource>();
-    Status rc = device_resource_->InitResource(device_id);
-    if (!rc.IsOk()) {
-      device_resource_ = nullptr;
-      MS_LOG(ERROR) << "Initialize Ascend310 resource fail";
-    }
-  }
-#endif
+  (void)InitResource(device_type, device_id);
 }
 // Execute function for the example vector case: auto decode(new vision::Decode());
@@ -199,16 +152,7 @@ Execute::Execute(const std::vector<TensorTransform *> &ops, MapTargetDevice devi
   info_ = std::make_shared<ExtraInfo>();
   device_type_ = device_type;
-#ifdef ENABLE_ACL
-  if (device_type_ == MapTargetDevice::kAscend310) {
-    device_resource_ = std::make_shared<AscendResource>();
-    Status rc = device_resource_->InitResource(device_id);
-    if (!rc.IsOk()) {
-      device_resource_ = nullptr;
-      MS_LOG(ERROR) << "Initialize Ascend310 resource fail";
-    }
-  }
-#endif
+  (void)InitResource(device_type, device_id);
 }
 Execute::~Execute() {
@@ -228,6 +172,7 @@ Execute::~Execute() {
 Status Execute::operator()(const mindspore::MSTensor &input, mindspore::MSTensor *output) {
   // Validate input tensor
+  RETURN_UNEXPECTED_IF_NULL(output);
   CHECK_FAIL_RETURN_UNEXPECTED(input.DataSize() > 0, "Input Tensor has no data.");
   CHECK_FAIL_RETURN_UNEXPECTED(output != nullptr, "Output Tensor can not be nullptr.");
   CHECK_FAIL_RETURN_UNEXPECTED(ValidateDevice(), "Device Type should be 'Ascend310' or 'CPU'.");
@@ -290,7 +235,8 @@ Status Execute::operator()(const mindspore::MSTensor &input, mindspore::MSTensor
       RETURN_STATUS_UNEXPECTED(ss.str());
     }
     *output = mindspore::MSTensor(std::make_shared<DETensor>(de_tensor));
-  } else {  // Ascend310 case, where we must set Ascend resource on each operators
+  } else if (device_type_ ==
+             MapTargetDevice::kAscend310) {  // Ascend310 case, where we must set Ascend resource on each operators
 #ifdef ENABLE_ACL
     CHECK_FAIL_RETURN_UNEXPECTED(device_resource_, "Device resource is nullptr which is illegal under case Ascend310.");
     // Sink data from host into device
@@ -311,12 +257,17 @@ Status Execute::operator()(const mindspore::MSTensor &input, mindspore::MSTensor
     *output = mindspore::MSTensor(std::make_shared<DETensor>(device_input, true));
 #endif
+  } else {
+    std::string err_msg = "Your input device is not supported. (Option: CPU or Ascend310)";
+    MS_LOG(ERROR) << err_msg;
+    RETURN_STATUS_UNEXPECTED(err_msg);
   }
   return Status::OK();
 }
 Status Execute::operator()(const std::vector<MSTensor> &input_tensor_list, std::vector<MSTensor> *output_tensor_list) {
   // Validate input tensor
+  RETURN_UNEXPECTED_IF_NULL(output_tensor_list);
   CHECK_FAIL_RETURN_UNEXPECTED(!input_tensor_list.empty(), "Input Tensor is not valid.");
   CHECK_FAIL_RETURN_UNEXPECTED(output_tensor_list != nullptr, "Output Tensor can not be nullptr.");
   output_tensor_list->clear();
@@ -380,7 +331,8 @@ Status Execute::operator()(const std::vector<MSTensor> &input_tensor_list, std::
       ++idx;
     }
     CHECK_FAIL_RETURN_UNEXPECTED(!output_tensor_list->empty(), "Output Tensor is not valid.");
-  } else {  // Case Ascend310
+  } else if (device_type_ ==
+             MapTargetDevice::kAscend310) {  // Ascend310 case, where we must set Ascend resource on each operators
 #ifdef ENABLE_ACL
     CHECK_FAIL_RETURN_UNEXPECTED(device_resource_, "Device resource is nullptr which is illegal under case Ascend310.");
     for (auto &input_tensor : input_tensor_list) {
@@ -410,6 +362,10 @@ Status Execute::operator()(const std::vector<MSTensor> &input_tensor_list, std::
     }
     CHECK_FAIL_RETURN_UNEXPECTED(!output_tensor_list->empty(), "Output Tensor vector is empty.");
 #endif
+  } else {
+    std::string err_msg = "Your input device is not supported. (Option: CPU or Ascend310)";
+    MS_LOG(ERROR) << err_msg;
+    RETURN_STATUS_UNEXPECTED(err_msg);
   }
   return Status::OK();
 }
@@ -526,6 +482,10 @@ Status AippInfoCollection(std::map<std::string, std::string> *aipp_options, cons
 std::string Execute::AippCfgGenerator() {
   std::string config_location = "./aipp.cfg";
+  if (info_ == nullptr) {
+    MS_LOG(ERROR) << "info_ is null";
+    return "";
+  }
 #ifdef ENABLE_ACL
   if (info_->init_with_shared_ptr_) {
     auto rc = ParseTransforms();
@@ -565,8 +525,7 @@ std::string Execute::AippCfgGenerator() {
   if (!outfile.is_open()) {
     MS_LOG(ERROR) << "Fail to open Aipp config file, please verify your system config(including authority)."
                   << "We will return empty string which represent the location of Aipp config file in this case.";
-    std::string except = "";
-    return except;
+    return "";
   }
   if (device_type_ == MapTargetDevice::kAscend310) {
@@ -650,10 +609,14 @@ Status Execute::ParseTransforms() {
                          [](std::shared_ptr<TensorTransform> operation) -> std::shared_ptr<TensorOperation> {
                            return operation->Parse();
                          });
-  } else {
+  } else if (device_type_ == MapTargetDevice::kAscend310) {
     for (auto &transform_ : transforms_) {
       ops_.emplace_back(transform_->Parse(device_type_));
    }
+  } else {
+    std::string err_msg = "Your input device is not supported. (Option: CPU or Ascend310)";
+    MS_LOG(ERROR) << err_msg;
+    RETURN_STATUS_UNEXPECTED(err_msg);
   }
   return Status::OK();
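
The execute.cc hunks fold seven copies of the #ifdef ENABLE_ACL block into a single InitResource helper and turn the trailing else branches into explicit unsupported-device errors. A compact sketch of that overall shape follows; Pipeline, Resource, and DeviceType are hypothetical names, and error handling is reduced to a bool where the real code returns Status.

#include <iostream>
#include <memory>

enum class DeviceType { kCpu, kAscend310, kUnknown };  // illustrative enum, not the MindSpore one

struct Resource {
  bool Init() { return true; }  // stand-in for AscendResource::InitResource(device_id)
};

class Pipeline {
 public:
  // Every constructor overload delegates here instead of repeating the init block.
  explicit Pipeline(DeviceType device_type) : device_type_(device_type) { (void)InitResource(); }

  bool Run() {
    if (device_type_ == DeviceType::kCpu) {
      std::cout << "run on CPU\n";
    } else if (device_type_ == DeviceType::kAscend310) {
      std::cout << "run on Ascend310\n";
    } else {
      // Explicit failure instead of silently treating "anything else" as the device case.
      std::cerr << "Your input device is not supported. (Option: CPU or Ascend310)\n";
      return false;
    }
    return true;
  }

 private:
  bool InitResource() {
    if (device_type_ == DeviceType::kAscend310) {
      resource_ = std::make_shared<Resource>();
      if (!resource_->Init()) {
        resource_ = nullptr;
        std::cerr << "Initialize Ascend310 resource fail\n";
        return false;
      }
    }
    return true;
  }

  DeviceType device_type_;
  std::shared_ptr<Resource> resource_;
};

int main() {
  Pipeline ok(DeviceType::kCpu);
  Pipeline bad(DeviceType::kUnknown);
  bool a = ok.Run();
  bool b = bad.Run();
  std::cout << std::boolalpha << a << " " << b << "\n";  // true false
  return 0;
}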

View File

@@ -38,7 +38,7 @@ Status Iterator::GetNextRowCharIF(MSTensorMapChar *row) {
     row->clear();
     return rc;
   }
-  for (auto de_tensor : md_map) {
+  for (auto &de_tensor : md_map) {
     std::vector<char> col_name(de_tensor.first.begin(), de_tensor.first.end());
     row->insert(std::make_pair(col_name, mindspore::MSTensor(std::make_shared<DETensor>(de_tensor.second))));
   }
@@ -48,8 +48,8 @@ Status Iterator::GetNextRowCharIF(MSTensorMapChar *row) {
 // Get the next row from the data pipeline.
 Status Iterator::GetNextRow(MSTensorVec *row) {
-  // Clean data row
   RETURN_UNEXPECTED_IF_NULL(row);
+  // Clean data row
   row->clear();
   // create a dataset tensor row and fetch. Then we convert the output to MSTensor
   std::vector<std::shared_ptr<dataset::Tensor>> md_row;
@@ -76,6 +76,7 @@ void Iterator::Stop() {
 // Function to build and launch the execution tree.
 Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds, int32_t num_epochs) {
+  RETURN_UNEXPECTED_IF_NULL(ds);
   runtime_context_ = std::make_unique<NativeRuntimeContext>();
   CHECK_FAIL_RETURN_UNEXPECTED(runtime_context_ != nullptr, "Create runtime_context_ failed.");
   RETURN_IF_NOT_OK(runtime_context_->Init());
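
Many hunks in this commit add RETURN_UNEXPECTED_IF_NULL guards at the top of functions that receive pointers. A minimal sketch of how such a Status-returning early-return macro can look follows; the Status type and the macro body are simplified assumptions for illustration, not the actual MindSpore definitions.

#include <iostream>
#include <string>
#include <vector>

// Simplified stand-in for mindspore::Status.
struct Status {
  bool ok;
  std::string msg;
  static Status OK() { return {true, ""}; }
  static Status Error(const std::string &m) { return {false, m}; }
};

// Assumed shape of an early-return guard macro: bail out with an error Status if the pointer is null.
#define RETURN_UNEXPECTED_IF_NULL(ptr)                                \
  do {                                                                \
    if ((ptr) == nullptr) {                                           \
      return Status::Error(std::string(#ptr) + " must not be null");  \
    }                                                                 \
  } while (false)

Status FillRow(std::vector<int> *row) {
  RETURN_UNEXPECTED_IF_NULL(row);  // guard before any dereference
  row->clear();
  row->push_back(1);
  return Status::OK();
}

int main() {
  std::vector<int> row;
  std::cout << FillRow(&row).ok << " " << FillRow(nullptr).ok << "\n";  // 1 0
  return 0;
}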

View File

@@ -107,7 +107,7 @@ Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) {
   // clear the old tensor row
   out_row->clear();
 #ifndef ENABLE_SECURITY
-  bool isProfilingEnable = root_->Tree()->GetProfilingManager()->IsProfilingEnable();
+  bool is_profiling_enable = root_->Tree()->GetProfilingManager()->IsProfilingEnable();
 #endif
   // Once eof is handled, always return empty row. Class must be destroyed and recreated if you
   // want to iterate again.
@@ -131,7 +131,7 @@ Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) {
   if (out_row->eoe()) {
     MS_LOG(INFO) << "End of data iteration.";
 #ifndef ENABLE_SECURITY
-    if (isProfilingEnable) {
+    if (is_profiling_enable) {
       root_->Tree()->SetEpochEnd();
     }
 #endif

View File

@@ -248,9 +248,13 @@ Status BatchOp::WorkerEntry(int32_t workerId) {
 Status BatchOp::MakeBatchedRow(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair, TensorRow *new_row) {
   RETURN_UNEXPECTED_IF_NULL(table_pair.first);
 #ifdef ENABLE_PYTHON
-  if (!in_col_names_.empty()) RETURN_IF_NOT_OK(MapColumns(&table_pair));  // pass it through pyfunc
+  if (!in_col_names_.empty()) {
+    RETURN_IF_NOT_OK(MapColumns(&table_pair));
+  }  // pass it through pyfunc
 #endif
-  if (pad_) RETURN_IF_NOT_OK(PadColumns(&table_pair.first, pad_info_, column_name_id_map_));  // do padding if needed
+  if (pad_) {
+    RETURN_IF_NOT_OK(PadColumns(&table_pair.first, pad_info_, column_name_id_map_));
+  }  // do padding if needed
   RETURN_IF_NOT_OK(BatchRows(&table_pair.first, new_row, table_pair.first->size()));
   return Status::OK();
 }

View File

@@ -242,7 +242,7 @@ class BatchOp : public ParallelOp {
   // the number of thread pulling from the mOutConnector of the Op below
   // @return int32_t, 1
-  int32_t num_consumers() const override { return 1; }
+  int32_t NumConsumers() const override { return 1; }
   // get the batch size for next batch
   // @return Status The status code returned

View File

@@ -71,11 +71,11 @@ class BuildSentencePieceVocabOp : public PipelineOp {
   // Getter
   // @return the number of workers
-  int32_t num_producers() const override { return 1; }
+  int32_t NumProducers() const override { return 1; }
   // Getter
   // @return the number of threads consuming from the previous Connector
-  int32_t num_consumers() const override { return 1; }
+  int32_t NumConsumers() const override { return 1; }
   Status Reset() override { RETURN_STATUS_UNEXPECTED("Reset shouldn't be called in BuildSentencePieceVocabOp"); }

View File

@@ -68,11 +68,11 @@ class BuildVocabOp : public ParallelOp {
   /// Getter
   /// @return the number of workers
-  int32_t num_producers() const override { return 1; }
+  int32_t NumProducers() const override { return 1; }
   /// Getter
   /// @return the number of threads consuming from the previous Connector
-  int32_t num_consumers() const override { return 1; }
+  int32_t NumConsumers() const override { return 1; }
   Status Reset() override { RETURN_STATUS_UNEXPECTED("Reset shouldn't be called in BuildVocabOp"); }

View File

@@ -69,6 +69,7 @@ Status CacheBase::FetchSamplesToWorkers() {
   int64_t wait_cnt = 0;
   int64_t prefetch_cnt = 0;
   // Kick off several threads which will prefetch prefetch_size_ rows in advance.
+  RETURN_UNEXPECTED_IF_NULL(tree_);
   RETURN_IF_NOT_OK(
     tree_->LaunchWorkers(num_prefetchers_, std::bind(&CacheBase::Prefetcher, this, std::placeholders::_1), Name()));
   auto send_to_que = [](QueueList<std::unique_ptr<IOBlock>> &qList, int32_t worker_id,
@@ -80,7 +81,7 @@ Status CacheBase::FetchSamplesToWorkers() {
   // Instead of sending sampler id to WorkerEntry, we send them to the Prefetcher which will redirect them
   // to the WorkerEntry.
   do {
-    if (AllowCacheMiss() && wait_cnt > 0 && wait_cnt % op_num_repeats_per_epoch() == 0) {
+    if (AllowCacheMiss() && wait_cnt > 0 && wait_cnt % GetOpNumRepeatsPerEpoch() == 0) {
      MS_LOG(INFO) << "Epoch: " << op_current_epochs_ << " Cache Miss : " << num_cache_miss_
                   << " Total number of rows : " << row_cnt_;
     }
@@ -155,7 +156,7 @@ Status CacheBase::FetchSamplesToWorkers() {
   }
   // Dump the last epoch result (approximately) without waiting for the worker threads to come back.
   if (AllowCacheMiss()) {
-    MS_LOG(INFO) << "Epoch: " << wait_cnt / op_num_repeats_per_epoch() << " Cache Miss : " << num_cache_miss_
+    MS_LOG(INFO) << "Epoch: " << wait_cnt / GetOpNumRepeatsPerEpoch() << " Cache Miss : " << num_cache_miss_
                  << " Total number of rows : " << row_cnt_;
   }
   return Status::OK();
@@ -202,6 +203,7 @@ Status CacheBase::FetchFromCache(int32_t worker_id) {
 }
 Status CacheBase::RegisterResources() {
+  RETURN_UNEXPECTED_IF_NULL(tree_);
   RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
   RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
   RETURN_IF_NOT_OK(prefetch_queues_.Register(tree_->AllTasks()));

View File

@@ -75,7 +75,7 @@ class CacheBase : public ParallelOp {
   /// \brief Getter for the cache client
   /// \return shared ptr to the cache client
-  std::shared_ptr<CacheClient> cache_client() { return cache_client_; }
+  std::shared_ptr<CacheClient> GetCacheClient() { return cache_client_; }
   /// \brief Setter for the cache client
   void SetCacheClient(std::shared_ptr<CacheClient> cache_client) { cache_client_ = std::move(cache_client); }
   /// \brief Derived class must implement this method if a cache miss is treated as error

View File

@@ -40,6 +40,7 @@ CacheOp::~CacheOp() = default;
 // This class functor will provide the master loop that drives the logic for performing the work
 Status CacheOp::operator()() {
+  RETURN_UNEXPECTED_IF_NULL(tree_);
   if (!sampler_) {
     return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
                   "Invalid parameter, CacheOp requires a sampler before it can be executed, but got nullptr.");
@@ -166,6 +167,7 @@ Status CacheOp::WorkerEntry(int32_t worker_id) {
   return Status::OK();
 }
 Status CacheOp::RegisterResources() {
+  RETURN_UNEXPECTED_IF_NULL(tree_);
   RETURN_IF_NOT_OK(CacheBase::RegisterResources());
   RETURN_IF_NOT_OK(rows_cache_done_.Register(tree_->AllTasks()));
   RETURN_IF_NOT_OK(keys_miss_->Register(tree_->AllTasks()));

View File

@@ -194,7 +194,7 @@ Status ConcatOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe
   return Status::OK();
 }
-int32_t ConcatOp::num_consumers() const {
+int32_t ConcatOp::NumConsumers() const {
   if (parent_.empty()) {
     MS_LOG(DEBUG) << "Return operator, no parent node, assuming it's the root and returning 1.";
     return 1;
@@ -202,16 +202,16 @@ int32_t ConcatOp::num_consumers() const {
     MS_LOG(DEBUG) << "Return operator, pointer to the first parent is null. Returning 0.";
     return 0;
   } else {
-    return parent_[0]->num_consumers();
+    return parent_[0]->NumConsumers();
   }
 }
-int32_t ConcatOp::num_producers() const {
+int32_t ConcatOp::NumProducers() const {
   if (child_.empty() || child_[0] == nullptr) {
     MS_LOG(DEBUG) << "Return operator, pointer to child node is null. Returning 0.";
     return 0;
   } else {
-    return child_[0]->num_producers();
+    return child_[0]->NumProducers();
   }
 }
 }  // namespace dataset

View File

@@ -73,8 +73,8 @@ class ConcatOp : public PipelineOp {
   Status GetNumClasses(int64_t *num_classes) override;
   Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
-  int32_t num_consumers() const override;
-  int32_t num_producers() const override;
+  int32_t NumConsumers() const override;
+  int32_t NumProducers() const override;
   /// Check if the current sample will be taken or dropped
   /// \return bool

View File

@@ -312,9 +312,9 @@ Status DatasetOp::PrepareOperator() {
   // The consumer of the root node is assumed to be one thread.
   // If multiple threads are consuming from the root node, they will get the ordered data in round robin fashion.
   if (parent_.empty()) {
-    this->CreateConnector(num_producers(), 1);
+    this->CreateConnector(NumProducers(), 1);
   } else {
-    this->CreateConnector(num_producers(), parent_[0]->num_consumers());
+    this->CreateConnector(NumProducers(), parent_[0]->NumConsumers());
   }
   if (out_connector_) {
     RETURN_IF_NOT_OK(out_connector_->Register(tree_->AllTasks()));

View File

@@ -209,33 +209,33 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   // \brief Getter function
   // \return The number of workers in this op
-  virtual int32_t num_workers() const = 0;
+  virtual int32_t NumWorkers() const = 0;
   // \brief Getter function
   // \return The number of threads consuming from previous op.
-  virtual int32_t num_consumers() const = 0;
+  virtual int32_t NumConsumers() const = 0;
   // \brief Getter function
   // \return The number of threads producing to the output connector.
-  virtual int32_t num_producers() const = 0;
+  virtual int32_t NumProducers() const = 0;
   // \brief Getter function
   // \return T/F if this is an inlined operator
   bool inlined() const { return (oc_queue_size_ == 0); }
   // \brief Setter function, set the number of total repeats for the operator
-  void set_total_repeats(int32_t total_repeats) { op_total_repeats_ = total_repeats; }
+  void SetTotalRepeats(int32_t total_repeats) { op_total_repeats_ = total_repeats; }
   // \brief Setter function, set the number of repeats per epoch for the operator
-  void set_num_repeats_per_epoch(int32_t num_repeats_per_epoch) { op_num_repeats_per_epoch_ = num_repeats_per_epoch; }
+  void SetNumRepeatsPerEpoch(int32_t num_repeats_per_epoch) { op_num_repeats_per_epoch_ = num_repeats_per_epoch; }
   // \brief Getter function
   // \return The number of required repeats for the operator
-  int32_t op_total_repeats() { return op_total_repeats_; }
+  int32_t GetOpTotalRepeats() { return op_total_repeats_; }
   // \brief Getter function
   // \return The number of repeats per epoch for the operator
-  int32_t op_num_repeats_per_epoch() const { return op_num_repeats_per_epoch_; }
+  int32_t GetOpNumRepeatsPerEpoch() const { return op_num_repeats_per_epoch_; }
   // \brief Register the internal worker connectors. No op unless it is a parallel op
   // \return Status
@@ -393,7 +393,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   // \notes No public interface. Only the class itself, or it's friend the execution tree can set
   // this
   // \param op_id - the Id value to set into the operator
-  void set_id(int32_t op_id) { operator_id_ = op_id; }
+  void SetId(int32_t op_id) { operator_id_ = op_id; }
   // Sets the tree into the op so that the operator has a back pointer to the tree.
   // \param tree - the tree to assign to the op.
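
The header hunks above migrate snake_case accessors (num_workers, set_total_repeats, op_total_repeats, and so on) to the PascalCase method style used elsewhere in the codebase. A tiny sketch of the resulting accessor shape on a hypothetical class:

#include <cstdint>
#include <iostream>

// Hypothetical operator class; only the accessor naming is the point here.
class ExampleOp {
 public:
  int32_t NumWorkers() const { return num_workers_; }                 // was: num_workers()
  void SetTotalRepeats(int32_t total) { op_total_repeats_ = total; }  // was: set_total_repeats()
  int32_t GetOpTotalRepeats() const { return op_total_repeats_; }     // was: op_total_repeats()

 private:
  int32_t num_workers_ = 1;
  int32_t op_total_repeats_ = 0;
};

int main() {
  ExampleOp op;
  op.SetTotalRepeats(3);
  std::cout << op.NumWorkers() << " " << op.GetOpTotalRepeats() << "\n";  // 1 3
  return 0;
}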

View File

@@ -187,8 +187,8 @@ Status DeviceQueueOp::SendDataToAscend() {
 #ifndef ENABLE_SECURITY
   std::shared_ptr<DeviceQueueTracing> profiling_node;
-  bool isProfilingEnable = tree_->GetProfilingManager()->IsProfilingEnable();
-  if (isProfilingEnable) {
+  bool is_profiling_enable = tree_->GetProfilingManager()->IsProfilingEnable();
+  if (is_profiling_enable) {
     std::shared_ptr<Tracing> node;
     RETURN_IF_NOT_OK(tree_->GetProfilingManager()->GetTracingNode(kDeviceQueueTracingName, &node));
     profiling_node = std::dynamic_pointer_cast<DeviceQueueTracing>(node);
@@ -196,7 +196,7 @@ Status DeviceQueueOp::SendDataToAscend() {
     connector_capacity = ChildOpConnectorCapacity();
   }
 #else
-  bool isProfilingEnable = false;
+  bool is_profiling_enable = false;
 #endif
 #ifdef ENABLE_DUMP_IR
   md_channel_info_->RecordBatchQueue(ChildOpConnectorSize());
@@ -218,14 +218,14 @@ Status DeviceQueueOp::SendDataToAscend() {
       md_channel_info_->RecordPreprocessBatch(send_batch);
       md_channel_info_->RecordPushStartTime();
 #endif
-      RETURN_IF_NOT_OK(SendRowToTdt(curr_row, isProfilingEnable, &tdt_cost));
+      RETURN_IF_NOT_OK(SendRowToTdt(curr_row, is_profiling_enable, &tdt_cost));
       if (first_push_flag_ != true) {
         MS_LOG(INFO) << "Loading dataset and push first batch into device successful.";
         first_push_flag_ = true;
       }
 #ifndef ENABLE_SECURITY
       DetectPerBatchTime(&batch_record_start, &batch_record_end);
-      ProfilingRecorder(isProfilingEnable, profiling_node, send_batch, tdt_cost, &batch_start_time, &end_time,
+      ProfilingRecorder(is_profiling_enable, profiling_node, send_batch, tdt_cost, &batch_start_time, &end_time,
                         connector_capacity, connector_size);
 #endif
       send_batch++;
@@ -244,7 +244,7 @@ Status DeviceQueueOp::SendDataToAscend() {
       LimitSendingBatches(send_batch, &sending_num, cfg);
 #ifndef ENABLE_SECURITY
-      if (isProfilingEnable) {
+      if (is_profiling_enable) {
         connector_size = ChildOpConnectorSize();
         connector_capacity = ChildOpConnectorCapacity();
       }
@@ -252,8 +252,8 @@ Status DeviceQueueOp::SendDataToAscend() {
       RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&curr_row));
     }
     if (curr_row.eoe() && send_epoch_end_) {
-      TensorRow currRow;
-      auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost,
+      TensorRow dummy_row;
+      auto status = tdtInstancePtr->hostPush(dummy_row, true, channel_name_, is_profiling_enable, tdt_cost,
                                              ACL_TENSOR_DATA_END_OF_SEQUENCE);
       if (status != Status::OK()) {
         if (stop_send_) {
@@ -273,7 +273,7 @@ Status DeviceQueueOp::SendDataToAscend() {
         stop_send_ = true;
       }
 #ifndef ENABLE_SECURITY
-      if (isProfilingEnable) {
+      if (is_profiling_enable) {
         connector_size = ChildOpConnectorSize();
         connector_capacity = ChildOpConnectorCapacity();
         tree_->SetEpochEnd();
@@ -311,8 +311,8 @@ void DeviceQueueOp::LimitSendingBatches(int64_t send_batch, int64_t *sending_num
   }
 }
-Status DeviceQueueOp::SendRowToTdt(TensorRow currRow, bool isProfilingEnable, int32_t *tdt_cost) {
-  auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, *tdt_cost);
+Status DeviceQueueOp::SendRowToTdt(TensorRow curr_row, bool is_profiling_enable, int32_t *tdt_cost) {
+  auto status = tdtInstancePtr->hostPush(curr_row, true, channel_name_, is_profiling_enable, *tdt_cost);
   if (status != Status::OK()) {
     if (stop_send_) {
       MS_LOG(INFO) << "stop_send received";
@@ -328,7 +328,7 @@ Status DeviceQueueOp::SendRowToTdt(TensorRow currRow, bool isProfilingEnable, in
   }
   if (create_data_info_queue_) {
     DATA_INFO data_info;
-    (void)std::transform(currRow.begin(), currRow.end(), std::back_inserter(data_info),
+    (void)std::transform(curr_row.begin(), curr_row.end(), std::back_inserter(data_info),
                          [](const std::shared_ptr<Tensor> &ts) { return std::make_pair(ts->type(), ts->shape()); });
     RETURN_IF_NOT_OK(data_info_queue_ptr_->Add(data_info));
   }
@@ -376,6 +376,7 @@ Status DeviceQueueOp::SetThreadDevice() {
 }
 Status DeviceQueueOp::LaunchParallelCopyThread() {
+  RETURN_UNEXPECTED_IF_NULL(tree_);
   // Every thread use cuda api should SetThreadDevice
   RETURN_IF_NOT_OK(SetThreadDevice());
   // CircularPool may not safe under multi-threads scenario, so one worker with one pool
@@ -396,6 +397,7 @@ Status DeviceQueueOp::LaunchParallelCopyThread() {
 }
 Status DeviceQueueOp::PushDataToGPU() {
+  RETURN_UNEXPECTED_IF_NULL(tree_);
   // Every thread use cuda api should SetThreadDevice
   RETURN_IF_NOT_OK(SetThreadDevice());
   TaskManager::FindMe()->Post();
@@ -405,8 +407,8 @@ Status DeviceQueueOp::PushDataToGPU() {
   int32_t connector_size = 0;
   int32_t connector_capacity = 0;
   std::shared_ptr<DeviceQueueTracing> profiling_node;
-  bool isProfilingEnable = tree_->GetProfilingManager()->IsProfilingEnable();
-  if (isProfilingEnable) {
+  bool is_profiling_enable = tree_->GetProfilingManager()->IsProfilingEnable();
+  if (is_profiling_enable) {
     std::shared_ptr<Tracing> node;
     RETURN_IF_NOT_OK(tree_->GetProfilingManager()->GetTracingNode(kDeviceQueueTracingName, &node));
     profiling_node = std::dynamic_pointer_cast<DeviceQueueTracing>(node);
@@ -452,7 +454,7 @@ Status DeviceQueueOp::PushDataToGPU() {
       RETURN_IF_NOT_OK(RetryPushData(handle, items));
       send_batch++;
 #ifndef ENABLE_SECURITY
-      if (isProfilingEnable) {
+      if (is_profiling_enable) {
         uint64_t end_time = ProfilingTime::GetCurMilliSecond();
         // record push data time
         profiling_node->Record(TIME, TDT_PUSH_TIME, send_batch, push_cost, end_time);
@@ -502,7 +504,7 @@ Status DeviceQueueOp::PushDataToGPU() {
 }
 Status DeviceQueueOp::RetryPushData(unsigned int handle, const std::vector<DataItemGpu> &items) {
-  bool flagLog = false;
+  bool flag_log = false;
   while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) {
     BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME);
     if (ret) {
@@ -511,9 +513,9 @@ Status DeviceQueueOp::RetryPushData(unsigned int handle, const std::vector<DataI
                    "Invalid data, check the output of dataset with creating iterator and print data item.");
     } else {
       if (!stop_send_) {
-        if (!flagLog) {
+        if (!flag_log) {
           MS_LOG(DEBUG) << "Retry pushing data...";
-          flagLog = true;
+          flag_log = true;
         }
         continue;
       }
@@ -669,11 +671,11 @@ void DeviceQueueOp::Print(std::ostream &out, bool show_all) const {
 }
 #ifndef ENABLE_SECURITY
-void DeviceQueueOp::ProfilingRecorder(bool isProfilingEnable, std::shared_ptr<DeviceQueueTracing> profiling_node,
+void DeviceQueueOp::ProfilingRecorder(bool is_profiling_enable, std::shared_ptr<DeviceQueueTracing> profiling_node,
                                       int64_t send_batch, int32_t tdt_cost, uint64_t *batch_start_time,
                                       uint64_t *end_time, int32_t connector_capacity, int32_t connector_size) {
   // Record the pipeline profiling info
-  if (isProfilingEnable) {
+  if (is_profiling_enable) {
     *end_time = ProfilingTime::GetCurMilliSecond();
     // record push tdt time
     profiling_node->Record(TIME, TDT_PUSH_TIME, send_batch + 1, tdt_cost, *end_time);
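
Beyond the isProfilingEnable to is_profiling_enable renames, the recurring pattern in SendDataToAscend and PushDataToGPU is to gate every timing record behind the profiling flag so the hot path pays nothing when profiling is off. A stripped-down sketch of that guard follows; the TracingNode recorder interface is hypothetical, not the MindSpore class.

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Hypothetical recorder that stores (batch, cost_ms) samples.
struct TracingNode {
  std::vector<std::pair<int64_t, int64_t>> samples;
  void Record(int64_t batch, int64_t cost_ms) { samples.push_back({batch, cost_ms}); }
};

// Only touch the profiling node when profiling is enabled, and never dereference
// an unset recorder in the normal (non-profiling) case.
void ProfilingRecorder(bool is_profiling_enable, TracingNode *node, int64_t send_batch, int64_t cost_ms) {
  if (!is_profiling_enable || node == nullptr) {
    return;
  }
  node->Record(send_batch, cost_ms);
}

int main() {
  TracingNode node;
  ProfilingRecorder(false, &node, 1, 5);     // dropped: profiling disabled
  ProfilingRecorder(true, &node, 2, 7);      // recorded
  std::cout << node.samples.size() << "\n";  // 1
  return 0;
}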

View File

@@ -76,7 +76,7 @@ class DeviceQueueOp : public PipelineOp {
   Status EoeReceived(int32_t worker_id) override;
-  const int32_t get_prefetch_size() { return prefetch_size_; }
+  const int32_t GetPrefetchSize() { return prefetch_size_; }
   void StopSend() { stop_send_ = true; }
@@ -105,11 +105,11 @@ class DeviceQueueOp : public PipelineOp {
   Status operator()() override;
 #ifndef ENABLE_SECURITY
   // Record the pipeline profiling info
-  void ProfilingRecorder(bool isProfilingEnable, std::shared_ptr<DeviceQueueTracing> profiling_node, int64_t send_batch,
-                         int32_t tdt_cost, uint64_t *batch_start_time, uint64_t *end_time, int32_t connector_capacity,
-                         int32_t connector_size);
+  void ProfilingRecorder(bool is_profiling_enable, std::shared_ptr<DeviceQueueTracing> profiling_node,
+                         int64_t send_batch, int32_t tdt_cost, uint64_t *batch_start_time, uint64_t *end_time,
+                         int32_t connector_capacity, int32_t connector_size);
 #endif
   // Op name getter
   // @return Name of the current Op
   std::string Name() const override { return kDeviceQueueOp; }
@@ -128,7 +128,7 @@ class DeviceQueueOp : public PipelineOp {
   void WaitContinueSignal() const;
   Status SendDataToAscend();
   void LimitSendingBatches(int64_t send_batch, int64_t *sending_num, std::shared_ptr<ConfigManager> cfg);
-  Status SendRowToTdt(TensorRow currRow, bool isProfilingEnable, int32_t *tdt_cost);
+  Status SendRowToTdt(TensorRow curr_row, bool is_profiling_enable, int32_t *tdt_cost);
   bool ascend_keep_waiting_;
 #endif

View File

@@ -218,7 +218,7 @@ Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate
   return Status(StatusCode::kSuccess, "FilterOp predicate func call succeed");
 }
-int32_t FilterOp::num_consumers() const { return 1; }
+int32_t FilterOp::NumConsumers() const { return 1; }
 }  // namespace dataset
 }  // namespace mindspore

View File

@@ -69,7 +69,7 @@ class FilterOp : public ParallelOp {
   // @return Name of the current Op
   std::string Name() const override { return kFilterOp; }
-  int32_t num_consumers() const override;
+  int32_t NumConsumers() const override;
  private:
   // predicate_func python callable which returns a boolean value.

View File

@@ -46,7 +46,7 @@ MapOp::MapOp(const std::vector<std::string> &in_col_names, const std::vector<std
 }
 // The number of threads consuming data from previous op's output Connector.
-int32_t MapOp::num_consumers() const {
+int32_t MapOp::NumConsumers() const {
   // When Performance Mode is on, there is only one thread consuming from the previous Connector.
   return 1;
 }
@@ -144,7 +144,7 @@ Status MapOp::operator()() {
   RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
   while (!new_row.eof()) {
-    if (op_current_repeats_ % op_num_repeats_per_epoch() == 0) {
+    if (op_current_repeats_ % GetOpNumRepeatsPerEpoch() == 0) {
       RETURN_IF_NOT_OK(callback_manager_.EpochBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
     }
     while (!new_row.eoe()) {
@@ -168,7 +168,7 @@ Status MapOp::operator()() {
     }
     // check whether this is the end of a real epoch (not all eoe signals end of epoch)
-    if ((op_current_repeats_ + 1) % op_num_repeats_per_epoch() == 0) {
+    if ((op_current_repeats_ + 1) % GetOpNumRepeatsPerEpoch() == 0) {
       RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
       ep_step = 0;

View File

@@ -101,7 +101,7 @@ class MapOp : public ParallelOp {
   // Getter
   // @return the number of threads consuming data from previous op's output Connector.
-  int32_t num_consumers() const override;
+  int32_t NumConsumers() const override;
   // Op name getter
   // @return Name of the current Op

View File

@@ -72,11 +72,11 @@ class ParallelOp : public DatasetOp {
   // Getter
   // @return the number of workers
-  int32_t num_workers() const override { return num_workers_; }
+  int32_t NumWorkers() const override { return num_workers_; }
   // Getter
   // @return the number of threads consuming from the previous Connector
-  int32_t num_consumers() const override { return num_workers_; }
+  int32_t NumConsumers() const override { return num_workers_; }
   // Getter
   // @return the number of producers pushing to the output Connector
@@ -84,7 +84,7 @@ class ParallelOp : public DatasetOp {
   // when a worker connector is set up. In that case, there are n workers, and a single master
   // such that only 1 thread is a producer rather than the n workers.
   // @return the number of producers
-  int32_t num_producers() const override { return num_producers_; }
+  int32_t NumProducers() const override { return num_producers_; }
   // Register the internal worker connectors.
   // @return Status

View File

@ -55,15 +55,15 @@ class PipelineOp : public DatasetOp {
// Getter // Getter
// @return The number of workers inside this op. Pipeline ops only have a single worker. // @return The number of workers inside this op. Pipeline ops only have a single worker.
int32_t num_workers() const override { return 1; } int32_t NumWorkers() const override { return 1; }
// Getter // Getter
// @return the number of threads consuming from the previous Connector // @return the number of threads consuming from the previous Connector
int32_t num_consumers() const override { return 1; } int32_t NumConsumers() const override { return 1; }
// Getter // Getter
// @return The number of threads that push data to the output connector // @return The number of threads that push data to the output connector
int32_t num_producers() const override { return 1; } int32_t NumProducers() const override { return 1; }
protected: protected:
// ******************************************************************************* // *******************************************************************************
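Editor's note: the renamed getters follow one convention across the op hierarchy: a parallel op reports its worker count as its consumer count (producers may collapse to one master thread), while a pipeline op always reports 1. A minimal sketch of that pattern under simplified, hypothetical class names (not the real DatasetOp interface):

#include <cstdint>
#include <iostream>

class OpBase {
 public:
  virtual ~OpBase() = default;
  virtual int32_t NumWorkers() const = 0;
  virtual int32_t NumConsumers() const = 0;
  virtual int32_t NumProducers() const = 0;
};

class ParallelOpSketch : public OpBase {
 public:
  ParallelOpSketch(int32_t workers, int32_t producers) : num_workers_(workers), num_producers_(producers) {}
  int32_t NumWorkers() const override { return num_workers_; }
  int32_t NumConsumers() const override { return num_workers_; }    // each worker reads from the previous connector
  int32_t NumProducers() const override { return num_producers_; }  // may be 1 when a single master pushes output
 private:
  int32_t num_workers_;
  int32_t num_producers_;
};

class PipelineOpSketch : public OpBase {
 public:
  int32_t NumWorkers() const override { return 1; }
  int32_t NumConsumers() const override { return 1; }
  int32_t NumProducers() const override { return 1; }
};

int main() {
  ParallelOpSketch map_like(8, 8);
  PipelineOpSketch skip_like;
  std::cout << map_like.NumConsumers() << " " << skip_like.NumConsumers() << "\n";
  return 0;
}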

View File

@ -76,7 +76,7 @@ TensorRow ProjectOp::Project(const TensorRow &row) {
// ensure that it is not called by mistake (it will generate an error). // ensure that it is not called by mistake (it will generate an error).
Status ProjectOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. ProjectOp is an inlined operator."); } Status ProjectOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. ProjectOp is an inlined operator."); }
int32_t ProjectOp::num_consumers() const { int32_t ProjectOp::NumConsumers() const {
if (parent_.empty()) { if (parent_.empty()) {
MS_LOG(DEBUG) << "Project operator, no parent node, assuming it's the root and returning 1."; MS_LOG(DEBUG) << "Project operator, no parent node, assuming it's the root and returning 1.";
return 1; return 1;
@ -84,16 +84,16 @@ int32_t ProjectOp::num_consumers() const {
MS_LOG(DEBUG) << "Project operator, pointer to the first parent is null. Returning 0."; MS_LOG(DEBUG) << "Project operator, pointer to the first parent is null. Returning 0.";
return 0; return 0;
} else { } else {
return parent_[0]->num_consumers(); return parent_[0]->NumConsumers();
} }
} }
int32_t ProjectOp::num_producers() const { int32_t ProjectOp::NumProducers() const {
if (child_.empty() || child_[0] == nullptr) { if (child_.empty() || child_[0] == nullptr) {
MS_LOG(DEBUG) << "Project operator, pointer to child node is null. Returning 0."; MS_LOG(DEBUG) << "Project operator, pointer to child node is null. Returning 0.";
return 0; return 0;
} else { } else {
return child_[0]->num_producers(); return child_[0]->NumProducers();
} }
} }
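Editor's note: inlined ops such as ProjectOp have no connector of their own, so their consumer/producer counts are borrowed from the first parent and first child, with defensive checks for missing links. A standalone sketch of that delegation (hypothetical TreeNode type, not the MindSpore class):

#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

struct TreeNode {
  std::vector<TreeNode *> parent;
  std::vector<std::shared_ptr<TreeNode>> child;
  int32_t own_consumers = 4;
  int32_t own_producers = 4;

  int32_t NumConsumers() const {
    if (parent.empty()) return 1;         // root of the tree: assume a single consumer
    if (parent[0] == nullptr) return 0;   // broken link
    return parent[0]->own_consumers;      // borrow from the first parent
  }
  int32_t NumProducers() const {
    if (child.empty() || child[0] == nullptr) return 0;
    return child[0]->own_producers;       // borrow from the first child
  }
};

int main() {
  TreeNode parent_op, inlined_op;
  inlined_op.parent.push_back(&parent_op);
  std::cout << inlined_op.NumConsumers() << " " << inlined_op.NumProducers() << "\n";
  return 0;
}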

View File

@ -63,11 +63,11 @@ class ProjectOp : public PipelineOp {
// Base-class override. Return the number of workers in the first parent. // Base-class override. Return the number of workers in the first parent.
// @param workerId - The worker id // @param workerId - The worker id
int32_t num_consumers() const override; int32_t NumConsumers() const override;
// Base-class override. Return the number of producers in the first child. // Base-class override. Return the number of producers in the first child.
// @param workerId - The worker id // @param workerId - The worker id
int32_t num_producers() const override; int32_t NumProducers() const override;
// Base-class override for special eoe handler. // Base-class override for special eoe handler.
// Inline operators must override this because there is no connector to push eoe onto. // Inline operators must override this because there is no connector to push eoe onto.

View File

@ -130,7 +130,7 @@ void RenameOp::Print(std::ostream &out, // In: The output stream to print t
} }
} }
int32_t RenameOp::num_consumers() const { int32_t RenameOp::NumConsumers() const {
if (parent_.empty()) { if (parent_.empty()) {
MS_LOG(DEBUG) << "Rename operator, no parent node, assuming it's the root and returning 1."; MS_LOG(DEBUG) << "Rename operator, no parent node, assuming it's the root and returning 1.";
return 1; return 1;
@ -138,16 +138,16 @@ int32_t RenameOp::num_consumers() const {
MS_LOG(DEBUG) << "Rename operator, pointer to the first parent is null. Returning 0."; MS_LOG(DEBUG) << "Rename operator, pointer to the first parent is null. Returning 0.";
return 0; return 0;
} else { } else {
return parent_[0]->num_consumers(); return parent_[0]->NumConsumers();
} }
} }
int32_t RenameOp::num_producers() const { int32_t RenameOp::NumProducers() const {
if (child_.empty() || child_[0] == nullptr) { if (child_.empty() || child_[0] == nullptr) {
MS_LOG(DEBUG) << "Rename operator, pointer to child node is null. Returning 0."; MS_LOG(DEBUG) << "Rename operator, pointer to child node is null. Returning 0.";
return 0; return 0;
} else { } else {
return child_[0]->num_producers(); return child_[0]->NumProducers();
} }
} }
} // namespace dataset } // namespace dataset

View File

@ -63,8 +63,8 @@ class RenameOp : public PipelineOp {
// @param row - output pointer to the projected row. // @param row - output pointer to the projected row.
// @param worker_id - The worker id // @param worker_id - The worker id
Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override; Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
int32_t num_consumers() const override; int32_t NumConsumers() const override;
int32_t num_producers() const override; int32_t NumProducers() const override;
protected: protected:
// Rename core functionality // Rename core functionality

View File

@ -115,7 +115,7 @@ Status RepeatOp::EofReceived(int32_t worker_id) {
return Status::OK(); return Status::OK();
} }
int32_t RepeatOp::num_consumers() const { int32_t RepeatOp::NumConsumers() const {
if (parent_.empty()) { if (parent_.empty()) {
MS_LOG(DEBUG) << "Repeat operator, no parent node, assuming it's root and returning 1."; MS_LOG(DEBUG) << "Repeat operator, no parent node, assuming it's root and returning 1.";
return 1; return 1;
@ -123,7 +123,7 @@ int32_t RepeatOp::num_consumers() const {
MS_LOG(DEBUG) << "Repeat operator, pointer to the first parent is null. Returning 0."; MS_LOG(DEBUG) << "Repeat operator, pointer to the first parent is null. Returning 0.";
return 0; return 0;
} else { } else {
return parent_[0]->num_consumers(); return parent_[0]->NumConsumers();
} }
} }
@ -140,12 +140,12 @@ Status RepeatOp::Reset() {
return Status::OK(); return Status::OK();
} }
int32_t RepeatOp::num_producers() const { int32_t RepeatOp::NumProducers() const {
if (child_.empty() || child_[0] == nullptr) { if (child_.empty() || child_[0] == nullptr) {
MS_LOG(DEBUG) << "Repeat operator, pointer to child node is null. Returning 0."; MS_LOG(DEBUG) << "Repeat operator, pointer to child node is null. Returning 0.";
return 0; return 0;
} else { } else {
return child_[0]->num_producers(); return child_[0]->NumProducers();
} }
} }

View File

@ -79,11 +79,11 @@ class RepeatOp : public PipelineOp {
// Base-class override. Return the number of workers in the first parent. // Base-class override. Return the number of workers in the first parent.
// @param workerId - The worker id // @param workerId - The worker id
int32_t num_consumers() const override; int32_t NumConsumers() const override;
// Base-class override. Return the number of producers in the first child. // Base-class override. Return the number of producers in the first child.
// @param workerId - The worker id // @param workerId - The worker id
int32_t num_producers() const override; int32_t NumProducers() const override;
// Op name getter // Op name getter
// @return Name of the current Op // @return Name of the current Op

View File

@ -66,7 +66,7 @@ Status SkipOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe)
return Status::OK(); return Status::OK();
} }
int32_t SkipOp::num_consumers() const { int32_t SkipOp::NumConsumers() const {
if (parent_.empty()) { if (parent_.empty()) {
MS_LOG(DEBUG) << "Return operator, no parent node, assuming it's the root and returning 1."; MS_LOG(DEBUG) << "Return operator, no parent node, assuming it's the root and returning 1.";
return 1; return 1;
@ -74,16 +74,16 @@ int32_t SkipOp::num_consumers() const {
MS_LOG(DEBUG) << "Return operator, pointer to the first parent is null. Returning 0."; MS_LOG(DEBUG) << "Return operator, pointer to the first parent is null. Returning 0.";
return 0; return 0;
} else { } else {
return parent_[0]->num_consumers(); return parent_[0]->NumConsumers();
} }
} }
int32_t SkipOp::num_producers() const { int32_t SkipOp::NumProducers() const {
if (child_.empty() || child_[0] == nullptr) { if (child_.empty() || child_[0] == nullptr) {
MS_LOG(DEBUG) << "Return operator, pointer to child node is null. Returning 0."; MS_LOG(DEBUG) << "Return operator, pointer to child node is null. Returning 0.";
return 0; return 0;
} else { } else {
return child_[0]->num_producers(); return child_[0]->NumProducers();
} }
} }

View File

@ -49,8 +49,8 @@ class SkipOp : public PipelineOp {
// @return Name of the current Op // @return Name of the current Op
std::string Name() const override { return kSkipOp; } std::string Name() const override { return kSkipOp; }
Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override; Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
int32_t num_consumers() const override; int32_t NumConsumers() const override;
int32_t num_producers() const override; int32_t NumProducers() const override;
private: private:
int32_t max_skips_; // The number of skips that the user requested int32_t max_skips_; // The number of skips that the user requested

View File

@ -177,9 +177,9 @@ bool CelebAOp::CheckDatasetTypeValid() {
Status CelebAOp::ParseImageAttrInfo() { Status CelebAOp::ParseImageAttrInfo() {
std::vector<std::string> image_infos; std::vector<std::string> image_infos;
bool needMoreData = true; bool need_more_data = true;
RETURN_IF_NOT_OK(attr_info_queue_->PopFront(&image_infos)); RETURN_IF_NOT_OK(attr_info_queue_->PopFront(&image_infos));
while (!image_infos.empty() && needMoreData) { while (!image_infos.empty() && need_more_data) {
for (uint32_t index = 0; index < image_infos.size(); index++) { for (uint32_t index = 0; index < image_infos.size(); index++) {
std::string image_info = image_infos[index]; std::string image_info = image_infos[index];
std::vector<std::string> split = Split(image_info); std::vector<std::string> split = Split(image_info);

View File

@ -201,13 +201,13 @@ Status CifarOp::ReadCifar100BlockData() {
} }
Status CifarOp::GetCifarFiles() { Status CifarOp::GetCifarFiles() {
const std::string kExtension = ".bin"; const std::string extension = ".bin";
Path dir_path(folder_path_); Path dir_path(folder_path_);
auto dirIt = Path::DirIterator::OpenDirectory(&dir_path); auto dirIt = Path::DirIterator::OpenDirectory(&dir_path);
if (dirIt) { if (dirIt) {
while (dirIt->HasNext()) { while (dirIt->HasNext()) {
Path file = dirIt->Next(); Path file = dirIt->Next();
if (file.Extension() == kExtension) { if (file.Extension() == extension) {
cifar_files_.push_back(file.ToString()); cifar_files_.push_back(file.ToString());
} }
} }

View File

@ -101,7 +101,7 @@ class CifarOp : public MappableLeafOp {
Status ParseCifarData(); Status ParseCifarData();
/// Method derived from RandomAccess Op, enable Sampler to get all ids for each class /// Method derived from RandomAccess Op, enable Sampler to get all ids for each class
/// @param (std::map<uint64_t, std::vector<uint64_t >> * map - key label, val all ids for this class /// @param std::map<int32_t, std::vector<int64_t>> *cls_ids - key is label, val is all ids for this class
/// @return Status The status code returned /// @return Status The status code returned
Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override; Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;

View File

@ -120,20 +120,20 @@ Status ClueOp::LoadFile(const std::string &file, int64_t start_offset, int64_t e
RETURN_STATUS_UNEXPECTED("Invalid file, failed to parse JSON file: " + file); RETURN_STATUS_UNEXPECTED("Invalid file, failed to parse JSON file: " + file);
} }
int cols_count = cols_to_keyword_.size(); int cols_count = cols_to_keyword_.size();
TensorRow tRow(cols_count, nullptr); TensorRow t_row(cols_count, nullptr);
// Add file path info // Add file path info
std::vector<std::string> file_path(cols_count, file); std::vector<std::string> file_path(cols_count, file);
tRow.setPath(file_path); t_row.setPath(file_path);
int cout = 0; int cout = 0;
for (auto &p : cols_to_keyword_) { for (auto &p : cols_to_keyword_) {
std::shared_ptr<Tensor> tensor; std::shared_ptr<Tensor> tensor;
RETURN_IF_NOT_OK(GetValue(js, p.second, &tensor)); RETURN_IF_NOT_OK(GetValue(js, p.second, &tensor));
tRow[cout] = std::move(tensor); t_row[cout] = std::move(tensor);
cout++; cout++;
} }
rows_total++; rows_total++;
RETURN_IF_NOT_OK(jagged_rows_connector_->Add(worker_id, std::move(tRow))); RETURN_IF_NOT_OK(jagged_rows_connector_->Add(worker_id, std::move(t_row)));
} }
return Status::OK(); return Status::OK();
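Editor's note: LoadFile above fills one row per JSON line by walking the column-to-keyword mapping. A minimal standalone sketch of that extraction, assuming nlohmann/json is available (the hypothetical keywords "sentence" and "label" stand in for cols_to_keyword_):

#include <iostream>
#include <map>
#include <string>
#include <vector>
#include <nlohmann/json.hpp>

int main() {
  const std::string line = R"({"sentence": "hello", "label": "100"})";  // one line of a CLUE-style file
  const std::map<std::string, std::string> cols_to_keyword = {{"label", "label"}, {"sentence", "sentence"}};

  nlohmann::json js = nlohmann::json::parse(line);
  std::vector<std::string> row;
  row.reserve(cols_to_keyword.size());
  for (const auto &p : cols_to_keyword) {
    row.push_back(js[p.second].get<std::string>());  // one cell per keyword, in column order
  }
  for (const auto &cell : row) std::cout << cell << "\n";
  return 0;
}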

View File

@ -215,7 +215,7 @@ Status GeneratorOp::operator()() {
// Waiting for repeatOp to start new epoch // Waiting for repeatOp to start new epoch
// If Reset() is called first by repeat op, this wait() will return right away. // If Reset() is called first by repeat op, this wait() will return right away.
// If Reset() is not called yet, this wait() will block until reset. // If Reset() is not called yet, this wait() will block until reset.
if (this->op_total_repeats() < 0) { if (this->GetOpTotalRepeats() < 0) {
RETURN_IF_NOT_OK(wp_.Wait()); RETURN_IF_NOT_OK(wp_.Wait());
// Clear the status of the wait post // Clear the status of the wait post
wp_.Clear(); wp_.Clear();
@ -235,7 +235,7 @@ Status GeneratorOp::Reset() {
MS_LOG(DEBUG) << Name() << " performing a self-reset."; MS_LOG(DEBUG) << Name() << " performing a self-reset.";
// Create new generator object // Create new generator object
RETURN_IF_NOT_OK(CreateGeneratorObject()); RETURN_IF_NOT_OK(CreateGeneratorObject());
if (this->op_total_repeats() < 0) { if (this->GetOpTotalRepeats() < 0) {
// Wake up master thread // Wake up master thread
wp_.Set(); wp_.Set();
} }
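Editor's note: the generator's master thread blocks on a wait post until Reset() wakes it for the next epoch. A minimal wait-post built on std::condition_variable (a sketch of the idea, not the MindSpore WaitPost class):

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

class WaitPostSketch {
 public:
  void Wait() {
    std::unique_lock<std::mutex> lk(m_);
    cv_.wait(lk, [this] { return set_; });
  }
  void Set() {
    { std::lock_guard<std::mutex> lk(m_); set_ = true; }
    cv_.notify_all();
  }
  void Clear() { std::lock_guard<std::mutex> lk(m_); set_ = false; }
 private:
  std::mutex m_;
  std::condition_variable cv_;
  bool set_ = false;
};

int main() {
  WaitPostSketch wp;
  std::thread master([&] {
    wp.Wait();    // generator master waits for the next epoch
    wp.Clear();   // re-arm for the following epoch
    std::cout << "epoch started\n";
  });
  wp.Set();       // a Reset() from the repeat op would wake the master like this
  master.join();
  return 0;
}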

View File

@ -84,20 +84,20 @@ Status ImageFolderOp::PrescanMasterEntry(const std::string &filedir) {
// Load 1 TensorRow (image,label) using 1 ImageLabelPair. 1 function call produces 1 TensorRow // Load 1 TensorRow (image,label) using 1 ImageLabelPair. 1 function call produces 1 TensorRow
Status ImageFolderOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) { Status ImageFolderOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
ImageLabelPair pairPtr = image_label_pairs_[row_id]; ImageLabelPair pair_ptr = image_label_pairs_[row_id];
std::shared_ptr<Tensor> image, label; std::shared_ptr<Tensor> image, label;
RETURN_IF_NOT_OK(Tensor::CreateScalar(pairPtr->second, &label)); RETURN_IF_NOT_OK(Tensor::CreateScalar(pair_ptr->second, &label));
RETURN_IF_NOT_OK(Tensor::CreateFromFile(folder_path_ + (pairPtr->first), &image)); RETURN_IF_NOT_OK(Tensor::CreateFromFile(folder_path_ + (pair_ptr->first), &image));
if (decode_ == true) { if (decode_ == true) {
Status rc = Decode(image, &image); Status rc = Decode(image, &image);
if (rc.IsError()) { if (rc.IsError()) {
std::string err = "Invalid data, failed to decode image: " + folder_path_ + (pairPtr->first); std::string err = "Invalid data, failed to decode image: " + folder_path_ + (pair_ptr->first);
RETURN_STATUS_UNEXPECTED(err); RETURN_STATUS_UNEXPECTED(err);
} }
} }
(*trow) = TensorRow(row_id, {std::move(image), std::move(label)}); (*trow) = TensorRow(row_id, {std::move(image), std::move(label)});
trow->setPath({folder_path_ + (pairPtr->first), std::string("")}); trow->setPath({folder_path_ + (pair_ptr->first), std::string("")});
return Status::OK(); return Status::OK();
} }
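Editor's note: LoadTensorRow pairs raw image bytes read from disk with the class label, leaving decoding as an optional follow-up step. A self-contained sketch of that pairing (hypothetical file name and label, plain standard library instead of Tensor):

#include <cstdint>
#include <fstream>
#include <iostream>
#include <iterator>
#include <string>
#include <utility>
#include <vector>

std::pair<std::vector<uint8_t>, int32_t> LoadRow(const std::string &folder, const std::pair<std::string, int32_t> &item) {
  std::ifstream in(folder + item.first, std::ios::binary);
  std::vector<uint8_t> image((std::istreambuf_iterator<char>(in)), std::istreambuf_iterator<char>());
  return {std::move(image), item.second};  // (raw bytes, label); decoding would happen afterwards
}

int main() {
  auto row = LoadRow("./", {"sample.jpg", 3});  // hypothetical file and class id
  std::cout << "bytes=" << row.first.size() << " label=" << row.second << "\n";
  return 0;
}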

View File

@ -57,11 +57,14 @@ class ImageFolderOp : public MappableLeafOp {
// @param int32_t num_wkrs - Num of workers reading images in parallel // @param int32_t num_wkrs - Num of workers reading images in parallel
// @param std::string - dir directory of ImageNetFolder // @param std::string - dir directory of ImageNetFolder
// @param int32_t queue_size - connector queue size // @param int32_t queue_size - connector queue size
// @param std::set<std::string> exts - set of file extensions to read, if empty, read everything under the dir // @param bool recursive - read recursively
// @param td::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read // @param bool do_decode - decode the images after reading
// @param std::set<std::string> &exts - set of file extensions to read, if empty, read everything under the dir
// @param std::map<std::string, int32_t> &map - map of folder name and class id
// @param std::unique_ptr<DataSchema> data_schema - schema of data
ImageFolderOp(int32_t num_wkrs, std::string file_dir, int32_t queue_size, bool recursive, bool do_decode, ImageFolderOp(int32_t num_wkrs, std::string file_dir, int32_t queue_size, bool recursive, bool do_decode,
const std::set<std::string> &exts, const std::map<std::string, int32_t> &map, const std::set<std::string> &exts, const std::map<std::string, int32_t> &map,
std::unique_ptr<DataSchema>, std::shared_ptr<SamplerRT> sampler); std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);
/// Destructor. /// Destructor.
~ImageFolderOp() = default; ~ImageFolderOp() = default;

View File

@ -223,7 +223,9 @@ Status MnistOp::WalkAllFiles() {
const std::string train_prefix = "train"; const std::string train_prefix = "train";
const std::string test_prefix = "t10k"; const std::string test_prefix = "t10k";
Path dir(folder_path_); std::string real_path{""};
RETURN_IF_NOT_OK(Path::RealPath(folder_path_, real_path));
Path dir(real_path);
auto dir_it = Path::DirIterator::OpenDirectory(&dir); auto dir_it = Path::DirIterator::OpenDirectory(&dir);
std::string prefix; // empty string, used to match usage = "" (default) or usage == "all" std::string prefix; // empty string, used to match usage = "" (default) or usage == "all"
if (usage_ == "train" || usage_ == "test") prefix = (usage_ == "test" ? test_prefix : train_prefix); if (usage_ == "train" || usage_ == "test") prefix = (usage_ == "test" ? test_prefix : train_prefix);
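Editor's note: the new code resolves the dataset folder to its real path before opening a directory iterator, so a bad path fails early. A sketch of the same idea with std::filesystem (not the Path class used here; the folder name is hypothetical):

#include <filesystem>
#include <iostream>
#include <system_error>

int main() {
  namespace fs = std::filesystem;
  std::error_code ec;
  fs::path real_path = fs::canonical("./mnist_data", ec);  // resolve symlinks and relative segments first
  if (ec) {
    std::cerr << "invalid folder: " << ec.message() << "\n";
    return 1;
  }
  for (const auto &entry : fs::directory_iterator(real_path)) {
    std::cout << entry.path().filename() << "\n";          // e.g. train-images-idx3-ubyte
  }
  return 0;
}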

View File

@ -68,15 +68,15 @@ void RandomDataOp::Print(std::ostream &out, bool show_all) const {
// Helper function to produce a default/random schema if one didn't exist // Helper function to produce a default/random schema if one didn't exist
void RandomDataOp::GenerateSchema() { void RandomDataOp::GenerateSchema() {
const int32_t kTypeOffset = 2; const int32_t type_offset = 2;
// To randomly create a schema, we need to choose: // To randomly create a schema, we need to choose:
// a) how many columns // a) how many columns
// b) the type of each column // b) the type of each column
// c) the shape of each column (number of dimensions i.e. rank) // c) the shape of each column (number of dimensions i.e. rank)
// d) the shape of each column (dimension values) // d) the shape of each column (dimension values)
data_schema_ = std::make_unique<DataSchema>(); data_schema_ = std::make_unique<DataSchema>();
std::unique_ptr<TensorShape> newShape; std::unique_ptr<TensorShape> new_shape;
std::unique_ptr<ColDescriptor> newCol; std::unique_ptr<ColDescriptor> new_col;
// Loop over the number of chosen columns // Loop over the number of chosen columns
int32_t numColumns = GenRandomInt(1, kMaxNumColumns); int32_t numColumns = GenRandomInt(1, kMaxNumColumns);
@ -84,23 +84,26 @@ void RandomDataOp::GenerateSchema() {
// For each column: // For each column:
// - choose a datatype // - choose a datatype
// - generate a shape that randomly chooses the number of dimensions and the dimension values. // - generate a shape that randomly chooses the number of dimensions and the dimension values.
DataType::Type newType = static_cast<DataType::Type>(GenRandomInt(1, DataType::NUM_OF_TYPES - kTypeOffset)); DataType::Type newType = static_cast<DataType::Type>(GenRandomInt(1, DataType::NUM_OF_TYPES - type_offset));
int32_t rank = GenRandomInt(1, kMaxRank); int32_t rank = GenRandomInt(1, kMaxRank);
std::vector<dsize_t> dims; std::vector<dsize_t> dims;
for (int32_t d = 0; d < rank; d++) { for (int32_t d = 0; d < rank; d++) {
// 0 is not a valid dimension value. however, we can support "*" or unknown, so map the random // 0 is not a valid dimension value. however, we can support "*" or unknown, so map the random
// 0 value to the unknown attribute if 0 is chosen // 0 value to the unknown attribute if 0 is chosen
dsize_t dim_value = static_cast<dsize_t>(GenRandomInt(0, kMaxDimValue)); dsize_t dim_value = static_cast<dsize_t>(GenRandomInt(0, kMaxDimValue));
if (dim_value == 0) dim_value = TensorShape::kDimUnknown; if (dim_value == 0) {
dim_value = TensorShape::kDimUnknown;
}
dims.push_back(dim_value); dims.push_back(dim_value);
} }
newShape = std::make_unique<TensorShape>(dims); new_shape = std::make_unique<TensorShape>(dims);
// Create the column descriptor // Create the column descriptor
std::string colName = "c" + std::to_string(i); std::string col_name = "c" + std::to_string(i);
newCol = std::make_unique<ColDescriptor>(colName, DataType(newType), TensorImpl::kFlexible, rank, newShape.get()); new_col =
std::make_unique<ColDescriptor>(col_name, DataType(newType), TensorImpl::kFlexible, rank, new_shape.get());
Status rc = data_schema_->AddColumn(*newCol); Status rc = data_schema_->AddColumn(*new_col);
if (rc.IsError()) MS_LOG(ERROR) << "Failed to generate a schema. Message:" << rc; if (rc.IsError()) MS_LOG(ERROR) << "Failed to generate a schema. Message:" << rc;
} }
} }
@ -125,6 +128,9 @@ Status RandomDataOp::operator()() {
DatasetOp::CreateConnector(num_producers_, num_workers_); DatasetOp::CreateConnector(num_producers_, num_workers_);
} }
if (num_workers_ == 0) {
RETURN_STATUS_UNEXPECTED("Invalid data, num_workers_ is zero.");
}
// Assign the number of rows to each worker in a round robin fashion. // Assign the number of rows to each worker in a round robin fashion.
worker_max_rows_.reserve(num_workers_); worker_max_rows_.reserve(num_workers_);
worker_rows_packed_.reserve(num_workers_); worker_rows_packed_.reserve(num_workers_);
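Editor's note: after the zero-worker guard, the total row count is spread across workers in a round-robin fashion. A small arithmetic sketch of that distribution (hypothetical counts):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t total_rows = 10;
  const int32_t num_workers = 3;   // must be validated as non-zero before dividing
  if (num_workers == 0) return 1;
  std::vector<int64_t> worker_max_rows(num_workers, total_rows / num_workers);
  for (int64_t i = 0; i < total_rows % num_workers; ++i) {
    worker_max_rows[i]++;          // hand the remainder out one row at a time
  }
  for (int32_t w = 0; w < num_workers; ++w) {
    std::cout << "worker " << w << ": " << worker_max_rows[w] << " rows\n";
  }
  return 0;
}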

View File

@ -148,10 +148,10 @@ Status VOCOp::ParseImageIds() {
Status VOCOp::ParseAnnotationIds() { Status VOCOp::ParseAnnotationIds() {
std::vector<std::string> new_image_ids; std::vector<std::string> new_image_ids;
for (auto id : image_ids_) { for (auto id : image_ids_) {
const std::string kAnnotationName = const std::string annotation_name =
folder_path_ + std::string(kAnnotationsFolder) + id + std::string(kAnnotationExtension); folder_path_ + std::string(kAnnotationsFolder) + id + std::string(kAnnotationExtension);
RETURN_IF_NOT_OK(ParseAnnotationBbox(kAnnotationName)); RETURN_IF_NOT_OK(ParseAnnotationBbox(annotation_name));
if (annotation_map_.find(kAnnotationName) != annotation_map_.end()) { if (annotation_map_.find(annotation_name) != annotation_map_.end()) {
new_image_ids.push_back(id); new_image_ids.push_back(id);
} }
} }
@ -179,7 +179,9 @@ void VOCOp::ParseNodeValue(XMLElement *bbox_node, const char *name, float *value
*value = 0.0; *value = 0.0;
if (bbox_node != nullptr) { if (bbox_node != nullptr) {
XMLElement *node = bbox_node->FirstChildElement(name); XMLElement *node = bbox_node->FirstChildElement(name);
if (node != nullptr) *value = node->FloatText(); if (node != nullptr) {
*value = node->FloatText();
}
} }
} }

View File

@ -69,7 +69,7 @@ Status TakeOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe)
return Status::OK(); return Status::OK();
} }
int32_t TakeOp::num_consumers() const { int32_t TakeOp::NumConsumers() const {
if (parent_.empty()) { if (parent_.empty()) {
MS_LOG(DEBUG) << "Return operator, no parent node, assuming it's the root and returning 1."; MS_LOG(DEBUG) << "Return operator, no parent node, assuming it's the root and returning 1.";
return 1; return 1;
@ -77,16 +77,16 @@ int32_t TakeOp::num_consumers() const {
MS_LOG(DEBUG) << "Return operator, pointer to the first parent is null. Returning 0."; MS_LOG(DEBUG) << "Return operator, pointer to the first parent is null. Returning 0.";
return 0; return 0;
} else { } else {
return parent_[0]->num_consumers(); return parent_[0]->NumConsumers();
} }
} }
int32_t TakeOp::num_producers() const { int32_t TakeOp::NumProducers() const {
if (child_.empty() || child_[0] == nullptr) { if (child_.empty() || child_[0] == nullptr) {
MS_LOG(DEBUG) << "Return operator, pointer to child node is null. Returning 0."; MS_LOG(DEBUG) << "Return operator, pointer to child node is null. Returning 0.";
return 0; return 0;
} else { } else {
return child_[0]->num_producers(); return child_[0]->NumProducers();
} }
} }
} // namespace dataset } // namespace dataset

View File

@ -59,8 +59,8 @@ class TakeOp : public PipelineOp {
std::string Name() const override { return kTakeOp; } std::string Name() const override { return kTakeOp; }
Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override; Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
int32_t num_consumers() const override; int32_t NumConsumers() const override;
int32_t num_producers() const override; int32_t NumProducers() const override;
private: private:
int32_t max_takes_; // The number of takes that the user requested int32_t max_takes_; // The number of takes that the user requested

View File

@ -132,7 +132,7 @@ Status ZipOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) {
return Status::OK(); return Status::OK();
} }
int32_t ZipOp::num_consumers() const { int32_t ZipOp::NumConsumers() const {
if (parent_.empty()) { if (parent_.empty()) {
MS_LOG(DEBUG) << "Return operator, no parent node, assuming it's the root and returning 1."; MS_LOG(DEBUG) << "Return operator, no parent node, assuming it's the root and returning 1.";
return 1; return 1;
@ -140,16 +140,16 @@ int32_t ZipOp::num_consumers() const {
MS_LOG(DEBUG) << "Return operator, pointer to the first parent is null. Returning 0."; MS_LOG(DEBUG) << "Return operator, pointer to the first parent is null. Returning 0.";
return 0; return 0;
} else { } else {
return parent_[0]->num_consumers(); return parent_[0]->NumConsumers();
} }
} }
int32_t ZipOp::num_producers() const { int32_t ZipOp::NumProducers() const {
if (child_.empty() || child_[0] == nullptr) { if (child_.empty() || child_[0] == nullptr) {
MS_LOG(DEBUG) << "Return operator, pointer to child node is null. Returning 0."; MS_LOG(DEBUG) << "Return operator, pointer to child node is null. Returning 0.";
return 0; return 0;
} else { } else {
return child_[0]->num_producers(); return child_[0]->NumProducers();
} }
} }
} // namespace dataset } // namespace dataset

View File

@ -65,8 +65,8 @@ class ZipOp : public PipelineOp {
std::string Name() const override { return kZipOp; } std::string Name() const override { return kZipOp; }
Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override; Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
int32_t num_consumers() const override; int32_t NumConsumers() const override;
int32_t num_producers() const override; int32_t NumProducers() const override;
private: private:
// Special handle case where an empty row has been received from child iterator // Special handle case where an empty row has been received from child iterator

View File

@ -85,7 +85,7 @@ Status ExecutionTree::AssociateNode(const std::shared_ptr<DatasetOp> &op) {
tree_state_ = kDeTStateBuilding; tree_state_ = kDeTStateBuilding;
// Assign an id to the operator // Assign an id to the operator
op->set_id(id_count_); op->SetId(id_count_);
id_count_++; id_count_++;
// Assign our tree into the op so that each op has a link back to the tree // Assign our tree into the op so that each op has a link back to the tree

View File

@ -104,8 +104,8 @@ Status BatchNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
auto op = std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_, auto op = std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
in_col_names_, out_col_names_, batch_size_func_, batch_map_func_, pad_map_); in_col_names_, out_col_names_, batch_size_func_, batch_map_func_, pad_map_);
op->set_total_repeats(GetTotalRepeats()); op->SetTotalRepeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op); node_ops->push_back(op);
#else #else
node_ops->push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_, node_ops->push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,

View File

@ -87,8 +87,8 @@ Status BucketBatchByLengthNode::Build(std::vector<std::shared_ptr<DatasetOp>> *c
auto op = std::make_shared<BucketBatchByLengthOp>(column_names_, bucket_boundaries_, bucket_batch_sizes_, auto op = std::make_shared<BucketBatchByLengthOp>(column_names_, bucket_boundaries_, bucket_batch_sizes_,
element_length_function_, pad_info_, pad_to_bucket_boundary_, element_length_function_, pad_info_, pad_to_bucket_boundary_,
drop_remainder_, connector_que_size_); drop_remainder_, connector_que_size_);
op->set_total_repeats(GetTotalRepeats()); op->SetTotalRepeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op); node_ops->push_back(op);
if (bucket_boundaries_[0] == 0) { if (bucket_boundaries_[0] == 0) {
bucket_boundaries_.erase(bucket_boundaries_.begin()); bucket_boundaries_.erase(bucket_boundaries_.begin());

View File

@ -56,8 +56,8 @@ void BuildSentenceVocabNode::Print(std::ostream &out) const {
Status BuildSentenceVocabNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { Status BuildSentenceVocabNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
auto op = std::make_shared<BuildSentencePieceVocabOp>(vocab_, col_names_, vocab_size_, character_coverage_, auto op = std::make_shared<BuildSentencePieceVocabOp>(vocab_, col_names_, vocab_size_, character_coverage_,
model_type_, params_, connector_que_size_); model_type_, params_, connector_que_size_);
op->set_total_repeats(GetTotalRepeats()); op->SetTotalRepeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op); node_ops->push_back(op);
return Status::OK(); return Status::OK();
} }

View File

@ -54,8 +54,8 @@ Status BuildVocabNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node
std::shared_ptr<BuildVocabOp> build_vocab_op; std::shared_ptr<BuildVocabOp> build_vocab_op;
build_vocab_op = std::make_shared<BuildVocabOp>(vocab_, columns_, freq_range_, top_k_, special_tokens_, build_vocab_op = std::make_shared<BuildVocabOp>(vocab_, columns_, freq_range_, top_k_, special_tokens_,
special_first_, num_workers_, connector_que_size_); special_first_, num_workers_, connector_que_size_);
build_vocab_op->set_total_repeats(GetTotalRepeats()); build_vocab_op->SetTotalRepeats(GetTotalRepeats());
build_vocab_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); build_vocab_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(build_vocab_op); node_ops->push_back(build_vocab_op);
return Status::OK(); return Status::OK();
} }

View File

@ -51,8 +51,8 @@ Status CacheLookupNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops)
"Internal error. Attempt to create a cache lookup node without cache client."); "Internal error. Attempt to create a cache lookup node without cache client.");
RETURN_IF_NOT_OK(cache_->Build()); RETURN_IF_NOT_OK(cache_->Build());
RETURN_IF_NOT_OK(cache_->CreateCacheLookupOp(num_workers_, connector_que_size_, sampler_, &lookup_op_)); RETURN_IF_NOT_OK(cache_->CreateCacheLookupOp(num_workers_, connector_que_size_, sampler_, &lookup_op_));
lookup_op_->set_total_repeats(GetTotalRepeats()); lookup_op_->SetTotalRepeats(GetTotalRepeats());
lookup_op_->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); lookup_op_->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(lookup_op_); node_ops->push_back(lookup_op_);
return Status::OK(); return Status::OK();
} }

View File

@ -48,8 +48,8 @@ Status CacheMergeNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops)
RETURN_IF_NOT_OK(cache_->Build()); RETURN_IF_NOT_OK(cache_->Build());
std::shared_ptr<DatasetOp> merge_op = nullptr; std::shared_ptr<DatasetOp> merge_op = nullptr;
RETURN_IF_NOT_OK(cache_->CreateCacheMergeOp(num_workers_, connector_que_size_, &merge_op)); RETURN_IF_NOT_OK(cache_->CreateCacheMergeOp(num_workers_, connector_que_size_, &merge_op));
merge_op->set_total_repeats(GetTotalRepeats()); merge_op->SetTotalRepeats(GetTotalRepeats());
merge_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); merge_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(merge_op); node_ops->push_back(merge_op);
return Status::OK(); return Status::OK();
} }

View File

@ -52,8 +52,8 @@ Status CacheNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
RETURN_IF_NOT_OK(cache_->Build()); RETURN_IF_NOT_OK(cache_->Build());
std::shared_ptr<DatasetOp> cache_op = nullptr; std::shared_ptr<DatasetOp> cache_op = nullptr;
RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, connector_que_size_, sampler_, &cache_op)); RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, connector_que_size_, sampler_, &cache_op));
cache_op->set_total_repeats(GetTotalRepeats()); cache_op->SetTotalRepeats(GetTotalRepeats());
cache_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); cache_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(cache_op); node_ops->push_back(cache_op);
return Status::OK(); return Status::OK();
} }

View File

@ -133,8 +133,8 @@ Status ConcatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
op = std::make_shared<ConcatOp>(sampler_rt, children_flag_and_nums_, children_start_end_index_); op = std::make_shared<ConcatOp>(sampler_rt, children_flag_and_nums_, children_start_end_index_);
} }
op->set_total_repeats(GetTotalRepeats()); op->SetTotalRepeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op); node_ops->push_back(op);
return Status::OK(); return Status::OK();

View File

@ -257,7 +257,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
void HasCacheAbove() { descendant_of_cache_ = true; } void HasCacheAbove() { descendant_of_cache_ = true; }
/// \brief Getter of the number of workers /// \brief Getter of the number of workers
int32_t num_workers() { return num_workers_; } int32_t NumWorkers() { return num_workers_; }
/// \brief Getter of dataset cache /// \brief Getter of dataset cache
std::shared_ptr<DatasetCache> GetDatasetCache() { return cache_; } std::shared_ptr<DatasetCache> GetDatasetCache() { return cache_; }

View File

@ -46,8 +46,8 @@ void EpochCtrlNode::Print(std::ostream &out) const {
// Function to build the EpochCtrlOp // Function to build the EpochCtrlOp
Status EpochCtrlNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { Status EpochCtrlNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
auto new_op_ = std::make_shared<EpochCtrlOp>(repeat_count_); auto new_op_ = std::make_shared<EpochCtrlOp>(repeat_count_);
new_op_->set_total_repeats(GetTotalRepeats()); new_op_->SetTotalRepeats(GetTotalRepeats());
new_op_->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); new_op_->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(new_op_); node_ops->push_back(new_op_);
op_ = new_op_; op_ = new_op_;
return Status::OK(); return Status::OK();

View File

@ -45,8 +45,8 @@ void FilterNode::Print(std::ostream &out) const {
Status FilterNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { Status FilterNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
auto op = std::make_shared<FilterOp>(input_columns_, num_workers_, connector_que_size_, predicate_); auto op = std::make_shared<FilterOp>(input_columns_, num_workers_, connector_que_size_, predicate_);
op->set_total_repeats(GetTotalRepeats()); op->SetTotalRepeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op); node_ops->push_back(op);
return Status::OK(); return Status::OK();
} }

View File

@ -89,12 +89,12 @@ Status MapNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
if (!project_columns_.empty()) { if (!project_columns_.empty()) {
auto project_op = std::make_shared<ProjectOp>(project_columns_); auto project_op = std::make_shared<ProjectOp>(project_columns_);
project_op->set_total_repeats(GetTotalRepeats()); project_op->SetTotalRepeats(GetTotalRepeats());
project_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); project_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(project_op); node_ops->push_back(project_op);
} }
map_op->set_total_repeats(GetTotalRepeats()); map_op->SetTotalRepeats(GetTotalRepeats());
map_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); map_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(map_op); node_ops->push_back(map_op);
return Status::OK(); return Status::OK();
} }

View File

@ -54,8 +54,8 @@ Status ProjectNode::ValidateParams() {
Status ProjectNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { Status ProjectNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
auto op = std::make_shared<ProjectOp>(columns_); auto op = std::make_shared<ProjectOp>(columns_);
op->set_total_repeats(GetTotalRepeats()); op->SetTotalRepeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op); node_ops->push_back(op);
return Status::OK(); return Status::OK();
} }

View File

@ -59,8 +59,8 @@ Status RenameNode::ValidateParams() {
Status RenameNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { Status RenameNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
auto op = std::make_shared<RenameOp>(input_columns_, output_columns_); auto op = std::make_shared<RenameOp>(input_columns_, output_columns_);
op->set_total_repeats(GetTotalRepeats()); op->SetTotalRepeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op); node_ops->push_back(op);
return Status::OK(); return Status::OK();
} }

View File

@ -40,8 +40,8 @@ void RepeatNode::Print(std::ostream &out) const { out << (Name() + "(count:" + s
Status RepeatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { Status RepeatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
auto new_op = std::make_shared<RepeatOp>(repeat_count_); auto new_op = std::make_shared<RepeatOp>(repeat_count_);
new_op->set_total_repeats(GetTotalRepeats()); new_op->SetTotalRepeats(GetTotalRepeats());
new_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); new_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(new_op); node_ops->push_back(new_op);
op_ = new_op; op_ = new_op;

View File

@ -45,8 +45,8 @@ void ShuffleNode::Print(std::ostream &out) const {
// Function to build the ShuffleOp // Function to build the ShuffleOp
Status ShuffleNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { Status ShuffleNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
auto op = std::make_shared<ShuffleOp>(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_); auto op = std::make_shared<ShuffleOp>(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_);
op->set_total_repeats(GetTotalRepeats()); op->SetTotalRepeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op); node_ops->push_back(op);
return Status::OK(); return Status::OK();
} }

View File

@ -40,8 +40,8 @@ void SkipNode::Print(std::ostream &out) const { out << (Name() + "(skip_count:"
// Function to build the SkipOp // Function to build the SkipOp
Status SkipNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { Status SkipNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
auto op = std::make_shared<SkipOp>(skip_count_); auto op = std::make_shared<SkipOp>(skip_count_);
op->set_total_repeats(GetTotalRepeats()); op->SetTotalRepeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op); node_ops->push_back(op);
return Status::OK(); return Status::OK();
} }

View File

@ -79,8 +79,8 @@ Status AlbumNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
auto album_op = std::make_shared<AlbumOp>(num_workers_, dataset_dir_, connector_que_size_, decode_, extensions, auto album_op = std::make_shared<AlbumOp>(num_workers_, dataset_dir_, connector_que_size_, decode_, extensions,
std::move(schema), std::move(sampler_rt)); std::move(schema), std::move(sampler_rt));
album_op->set_total_repeats(GetTotalRepeats()); album_op->SetTotalRepeats(GetTotalRepeats());
album_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); album_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(album_op); node_ops->push_back(album_op);
return Status::OK(); return Status::OK();
} }

View File

@ -75,8 +75,8 @@ Status CelebANode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
auto celeba_op = std::make_shared<CelebAOp>(num_workers_, dataset_dir_, connector_que_size_, decode_, usage_, auto celeba_op = std::make_shared<CelebAOp>(num_workers_, dataset_dir_, connector_que_size_, decode_, usage_,
extensions_, std::move(schema), std::move(sampler_rt)); extensions_, std::move(schema), std::move(sampler_rt));
celeba_op->set_total_repeats(GetTotalRepeats()); celeba_op->SetTotalRepeats(GetTotalRepeats());
celeba_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); celeba_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(celeba_op); node_ops->push_back(celeba_op);
return Status::OK(); return Status::OK();

View File

@ -71,8 +71,8 @@ Status Cifar100Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
auto cifar_op = std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, dataset_dir_, auto cifar_op = std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, dataset_dir_,
connector_que_size_, std::move(schema), std::move(sampler_rt)); connector_que_size_, std::move(schema), std::move(sampler_rt));
cifar_op->set_total_repeats(GetTotalRepeats()); cifar_op->SetTotalRepeats(GetTotalRepeats());
cifar_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); cifar_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(cifar_op); node_ops->push_back(cifar_op);
return Status::OK(); return Status::OK();

View File

@ -69,8 +69,8 @@ Status Cifar10Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_op
auto cifar_op = std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, dataset_dir_, auto cifar_op = std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, dataset_dir_,
connector_que_size_, std::move(schema), std::move(sampler_rt)); connector_que_size_, std::move(schema), std::move(sampler_rt));
cifar_op->set_total_repeats(GetTotalRepeats()); cifar_op->SetTotalRepeats(GetTotalRepeats());
cifar_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); cifar_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(cifar_op); node_ops->push_back(cifar_op);
return Status::OK(); return Status::OK();

View File

@ -88,8 +88,8 @@ Status CityscapesNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node
auto cityscapes_op = std::make_shared<CityscapesOp>(num_workers_, dataset_dir_, usage_, quality_mode_, task_, decode_, auto cityscapes_op = std::make_shared<CityscapesOp>(num_workers_, dataset_dir_, usage_, quality_mode_, task_, decode_,
connector_que_size_, std::move(schema), std::move(sampler_rt)); connector_que_size_, std::move(schema), std::move(sampler_rt));
cityscapes_op->set_total_repeats(GetTotalRepeats()); cityscapes_op->SetTotalRepeats(GetTotalRepeats());
cityscapes_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); cityscapes_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(cityscapes_op); node_ops->push_back(cityscapes_op);
return Status::OK(); return Status::OK();
} }

View File

@ -202,12 +202,12 @@ Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
// Add the shuffle op after this op // Add the shuffle op after this op
RETURN_IF_NOT_OK( RETURN_IF_NOT_OK(
AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op)); AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
shuffle_op->set_total_repeats(GetTotalRepeats()); shuffle_op->SetTotalRepeats(GetTotalRepeats());
shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(shuffle_op); node_ops->push_back(shuffle_op);
} }
clue_op->set_total_repeats(GetTotalRepeats()); clue_op->SetTotalRepeats(GetTotalRepeats());
clue_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); clue_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(clue_op); node_ops->push_back(clue_op);
return Status::OK(); return Status::OK();
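Editor's note: for file-based readers (CLUE, CSV, TextFile, TFRecord, USPS), a global shuffle is realized by inserting a shuffle op ahead of the reader, sized from the counted row total. A tiny sketch of a seeded, reproducible shuffle over row indices (illustrative only; the real ShuffleOp works on buffered rows, not indices):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <random>
#include <vector>

int main() {
  const int64_t num_rows = 10;        // counted from the dataset files
  const uint32_t shuffle_seed = 42;   // fixed seed gives a reproducible order
  std::vector<int64_t> order(num_rows);
  std::iota(order.begin(), order.end(), 0);
  std::mt19937 rng(shuffle_seed);
  std::shuffle(order.begin(), order.end(), rng);  // the order rows would be emitted in
  for (auto idx : order) std::cout << idx << " ";
  std::cout << "\n";
  return 0;
}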

View File

@ -142,8 +142,8 @@ Status CocoNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
std::shared_ptr<CocoOp> op = std::shared_ptr<CocoOp> op =
std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, connector_que_size_, decode_, std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, connector_que_size_, decode_,
std::move(schema), std::move(sampler_rt), extra_metadata_); std::move(schema), std::move(sampler_rt), extra_metadata_);
op->set_total_repeats(GetTotalRepeats()); op->SetTotalRepeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op); node_ops->push_back(op);
return Status::OK(); return Status::OK();

View File

@ -140,12 +140,12 @@ Status CSVNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
// Add the shuffle op after this op // Add the shuffle op after this op
RETURN_IF_NOT_OK( RETURN_IF_NOT_OK(
AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op)); AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
shuffle_op->set_total_repeats(GetTotalRepeats()); shuffle_op->SetTotalRepeats(GetTotalRepeats());
shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(shuffle_op); node_ops->push_back(shuffle_op);
} }
csv_op->set_total_repeats(GetTotalRepeats()); csv_op->SetTotalRepeats(GetTotalRepeats());
csv_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); csv_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(csv_op); node_ops->push_back(csv_op);
return Status::OK(); return Status::OK();

View File

@ -100,8 +100,8 @@ Status DIV2KNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
auto div2k_op = std::make_shared<DIV2KOp>(num_workers_, dataset_dir_, usage_, downgrade_, scale_, decode_, auto div2k_op = std::make_shared<DIV2KOp>(num_workers_, dataset_dir_, usage_, downgrade_, scale_, decode_,
connector_que_size_, std::move(schema), std::move(sampler_rt)); connector_que_size_, std::move(schema), std::move(sampler_rt));
div2k_op->set_total_repeats(GetTotalRepeats()); div2k_op->SetTotalRepeats(GetTotalRepeats());
div2k_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); div2k_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(div2k_op); node_ops->push_back(div2k_op);
return Status::OK(); return Status::OK();
} }

View File

@ -100,8 +100,8 @@ Status FlickrNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
auto flickr_op = std::make_shared<FlickrOp>(num_workers_, dataset_dir_, annotation_file_, decode_, auto flickr_op = std::make_shared<FlickrOp>(num_workers_, dataset_dir_, annotation_file_, decode_,
connector_que_size_, std::move(schema), std::move(sampler_rt)); connector_que_size_, std::move(schema), std::move(sampler_rt));
flickr_op->set_total_repeats(GetTotalRepeats()); flickr_op->SetTotalRepeats(GetTotalRepeats());
flickr_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); flickr_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(flickr_op); node_ops->push_back(flickr_op);
return Status::OK(); return Status::OK();
} }

View File

@ -97,8 +97,8 @@ Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_
if (reset_ancestor_ != nullptr) { if (reset_ancestor_ != nullptr) {
reset_ancestor_->op_->AddToEoeList(op); reset_ancestor_->op_->AddToEoeList(op);
} }
op->set_total_repeats(GetTotalRepeats()); op->SetTotalRepeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op); node_ops->push_back(op);
return Status::OK(); return Status::OK();
} }

View File

@ -77,8 +77,8 @@ Status ImageFolderNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const nod
auto op = std::make_shared<ImageFolderOp>(num_workers_, dataset_dir_, connector_que_size_, recursive_, decode_, exts_, auto op = std::make_shared<ImageFolderOp>(num_workers_, dataset_dir_, connector_que_size_, recursive_, decode_, exts_,
class_indexing_, std::move(schema), std::move(sampler_rt)); class_indexing_, std::move(schema), std::move(sampler_rt));
op->set_total_repeats(GetTotalRepeats()); op->SetTotalRepeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op); node_ops->push_back(op);
return Status::OK(); return Status::OK();
} }

View File

@ -104,8 +104,8 @@ Status ManifestNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
manifest_op = std::make_shared<ManifestOp>(num_workers_, dataset_file_, connector_que_size_, decode_, class_index_, manifest_op = std::make_shared<ManifestOp>(num_workers_, dataset_file_, connector_que_size_, decode_, class_index_,
std::move(schema), std::move(sampler_rt), usage_); std::move(schema), std::move(sampler_rt), usage_);
manifest_op->set_total_repeats(GetTotalRepeats()); manifest_op->SetTotalRepeats(GetTotalRepeats());
manifest_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); manifest_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(manifest_op); node_ops->push_back(manifest_op);
return Status::OK(); return Status::OK();

View File

@ -212,8 +212,8 @@ Status MindDataNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
} }
RETURN_IF_NOT_OK(mindrecord_op->Init()); RETURN_IF_NOT_OK(mindrecord_op->Init());
mindrecord_op->set_total_repeats(GetTotalRepeats()); mindrecord_op->SetTotalRepeats(GetTotalRepeats());
mindrecord_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); mindrecord_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(mindrecord_op); node_ops->push_back(mindrecord_op);
return Status::OK(); return Status::OK();

View File

@ -65,8 +65,8 @@ Status MnistNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
auto op = std::make_shared<MnistOp>(usage_, num_workers_, dataset_dir_, connector_que_size_, std::move(schema), auto op = std::make_shared<MnistOp>(usage_, num_workers_, dataset_dir_, connector_que_size_, std::move(schema),
std::move(sampler_rt)); std::move(sampler_rt));
op->set_total_repeats(GetTotalRepeats()); op->SetTotalRepeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op); node_ops->push_back(op);
return Status::OK(); return Status::OK();

View File

@ -75,8 +75,8 @@ Status QMnistNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
auto op = std::make_shared<QMnistOp>(dataset_dir_, usage_, compat_, std::move(schema), std::move(sampler_rt), auto op = std::make_shared<QMnistOp>(dataset_dir_, usage_, compat_, std::move(schema), std::move(sampler_rt),
num_workers_, connector_que_size_); num_workers_, connector_que_size_);
op->set_total_repeats(GetTotalRepeats()); op->SetTotalRepeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op); node_ops->push_back(op);
return Status::OK(); return Status::OK();

View File

@ -110,8 +110,8 @@ Status RandomNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
std::shared_ptr<RandomDataOp> op; std::shared_ptr<RandomDataOp> op;
op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, total_rows_, std::move(data_schema_)); op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, total_rows_, std::move(data_schema_));
op->set_total_repeats(GetTotalRepeats()); op->SetTotalRepeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op); node_ops->push_back(op);
return Status::OK(); return Status::OK();

View File

@ -69,8 +69,8 @@ Status SBUNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
auto op = std::make_shared<SBUOp>(dataset_dir_, decode_, std::move(schema), std::move(sampler_rt), num_workers_, auto op = std::make_shared<SBUOp>(dataset_dir_, decode_, std::move(schema), std::move(sampler_rt), num_workers_,
connector_que_size_); connector_que_size_);
op->set_total_repeats(GetTotalRepeats()); op->SetTotalRepeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op); node_ops->push_back(op);
return Status::OK(); return Status::OK();

View File

@@ -108,12 +108,12 @@ Status TextFileNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
     // Add the shuffle op after this op
     RETURN_IF_NOT_OK(
       AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
-    shuffle_op->set_total_repeats(GetTotalRepeats());
-    shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+    shuffle_op->SetTotalRepeats(GetTotalRepeats());
+    shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
     node_ops->push_back(shuffle_op);
   }
-  text_file_op->set_total_repeats(GetTotalRepeats());
-  text_file_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  text_file_op->SetTotalRepeats(GetTotalRepeats());
+  text_file_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   // Add TextFileOp
   node_ops->push_back(text_file_op);

View File

@@ -149,12 +149,12 @@ Status TFRecordNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
     // Add the shuffle op after this op
     RETURN_IF_NOT_OK(AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
-    shuffle_op->set_total_repeats(GetTotalRepeats());
-    shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+    shuffle_op->SetTotalRepeats(GetTotalRepeats());
+    shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
     node_ops->push_back(shuffle_op);
   }
-  tf_reader_op->set_total_repeats(GetTotalRepeats());
-  tf_reader_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  tf_reader_op->SetTotalRepeats(GetTotalRepeats());
+  tf_reader_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   // Add TFReaderOp
   node_ops->push_back(tf_reader_op);
   return Status::OK();

View File

@@ -97,12 +97,12 @@ Status USPSNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
     // Add the shuffle op after this op
     RETURN_IF_NOT_OK(AddShuffleOp(op->FileNames().size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
-    shuffle_op->set_total_repeats(GetTotalRepeats());
-    shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+    shuffle_op->SetTotalRepeats(GetTotalRepeats());
+    shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
     node_ops->push_back(shuffle_op);
   }
-  op->set_total_repeats(GetTotalRepeats());
-  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  op->SetTotalRepeats(GetTotalRepeats());
+  op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(op);
   return Status::OK();
 }

View File

@@ -125,8 +125,8 @@ Status VOCNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
   std::shared_ptr<VOCOp> voc_op;
   voc_op = std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, connector_que_size_,
                                    decode_, std::move(schema), std::move(sampler_rt), extra_metadata_);
-  voc_op->set_total_repeats(GetTotalRepeats());
-  voc_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  voc_op->SetTotalRepeats(GetTotalRepeats());
+  voc_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(voc_op);
   return Status::OK();
 }

View File

@@ -46,8 +46,8 @@ Status SyncWaitNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
   // The reason for this is because having it otherwise can lead to blocking issues
   // See barrier_op.h for more details
   auto op = std::make_shared<BarrierOp>(connector_que_size_, condition_name_, callback_);
-  op->set_total_repeats(GetTotalRepeats());
-  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  op->SetTotalRepeats(GetTotalRepeats());
+  op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(op);
   return Status::OK();
 }

View File

@@ -41,8 +41,8 @@ void TakeNode::Print(std::ostream &out) const { out << (Name() + "(num_rows:" +
 // Function to build the TakeOp
 Status TakeNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
   auto op = std::make_shared<TakeOp>(take_count_);
-  op->set_total_repeats(GetTotalRepeats());
-  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  op->SetTotalRepeats(GetTotalRepeats());
+  op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(op);
   return Status::OK();
 }

View File

@@ -97,8 +97,8 @@ Status TransferNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
   auto op = std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_,
                                             total_batch_, create_data_info_queue_);
-  op->set_total_repeats(GetTotalRepeats());
-  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  op->SetTotalRepeats(GetTotalRepeats());
+  op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(op);
   return Status::OK();
 }

View File

@@ -60,8 +60,8 @@ Status ZipNode::ValidateParams()
 Status ZipNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
   auto op = std::make_shared<ZipOp>();
-  op->set_total_repeats(GetTotalRepeats());
-  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  op->SetTotalRepeats(GetTotalRepeats());
+  op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(op);
   return Status::OK();
 }

View File

@@ -68,10 +68,10 @@ Status AutoWorkerPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *con
     int32_t cur_node_num_worker = std::max(std::min(num_workers, cur_node_max), min_num_workers_);
     // if the num_worker to set is same as original, skip setting and printing the logs
-    if (cur_node_num_worker == p.first->num_workers()) continue;
+    if (cur_node_num_worker == p.first->NumWorkers()) continue;
     // log the change via warning msg so user can see what the num_worker is being set for which op
     MS_LOG(WARNING) << "AutoNumWorker enabled, num_workers in " << p.first->Name() << " is auto-adjusted from "
-                    << std::to_string(p.first->num_workers()) + " to " + std::to_string(cur_node_num_worker);
+                    << std::to_string(p.first->NumWorkers()) + " to " + std::to_string(cur_node_num_worker);
     p.first->SetNumWorkers(cur_node_num_worker);
   }
   return Status::OK();
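The pass above clamps the suggested worker count into [min_num_workers_, cur_node_max] before comparing it to the node's current value; only when it differs is the node updated and a warning logged. A standalone sketch of the clamping step (variable names follow the hunk; the weighting that produces the suggested num_workers is omitted):

#include <algorithm>
#include <cstdint>
#include <iostream>

// Clamp a suggested per-node worker count into [min_workers, max_workers].
int32_t ClampNumWorkers(int32_t suggested, int32_t max_workers, int32_t min_workers) {
  return std::max(std::min(suggested, max_workers), min_workers);
}

int main() {
  // Suggested 12 workers, node allows at most 8, pass enforces at least 2.
  std::cout << ClampNumWorkers(12, 8, 2) << std::endl;  // prints 8
  std::cout << ClampNumWorkers(1, 8, 2) << std::endl;   // prints 2
  return 0;
}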

View File

@@ -49,7 +49,7 @@ Status DeepCopyPass::Visit(std::shared_ptr<DatasetNode> node, bool *const modifi
   // Temporary fix to set the num_workers to each cloned node.
   // This can be improved by adding a new method in the base class DatasetNode to transfer the properties to
   // the cloned node. Each derived class's Copy() will need to include this method.
-  new_node->SetNumWorkers(node->num_workers());
+  new_node->SetNumWorkers(node->NumWorkers());
   // This method below assumes a DFS walk and from the first child to the last child.
   // Future: A more robust implementation that does not depend on the above assumption.
   RETURN_IF_NOT_OK(parent_->AppendChild(new_node));
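The temporary fix above copies only num_workers onto the freshly cloned node. The comment suggests moving this into a base-class helper; the sketch below is one hypothetical shape for that helper (Node and TransferPropertiesTo are illustrative names, not existing MindSpore API):

#include <cstdint>
#include <memory>

// Hypothetical base node with a property-transfer helper that each derived
// Copy() could call, instead of the pass copying fields one by one.
class Node {
 public:
  virtual ~Node() = default;
  void SetNumWorkers(int32_t n) { num_workers_ = n; }
  int32_t NumWorkers() const { return num_workers_; }

  // Transfer properties shared by all nodes onto a clone.
  void TransferPropertiesTo(Node *clone) const { clone->SetNumWorkers(num_workers_); }

 private:
  int32_t num_workers_ = 1;
};

int main() {
  Node original;
  original.SetNumWorkers(4);
  auto clone = std::make_unique<Node>();
  original.TransferPropertiesTo(clone.get());
  return clone->NumWorkers() == 4 ? 0 : 1;
}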

View File

@@ -42,7 +42,7 @@ json ConnectorSize::ParseOpInfo(const DatasetOp &node, const std::vector<int32_t
   json json_node;
   json_node["op_id"] = node.id();
   json_node["op_type"] = node.Name();
-  json_node["num_workers"] = node.num_workers();
+  json_node["num_workers"] = node.NumWorkers();
   json metrics;
   // DeviceQueueOp is a special op,it is not inlined but its output queue is invalid.
   // So we should not output its queue size.
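This hunk and the ConnectorThroughput hunk that follows serialize the same per-op header before attaching their metrics; only the accessor name changes (num_workers() to NumWorkers()). A minimal sketch of that header serialization, assuming the `json` alias refers to nlohmann::json and with a plain struct standing in for the fields read from a DatasetOp:

#include <cstdint>
#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

// Stand-in for the fields ParseOpInfo reads from a DatasetOp.
struct OpInfo {
  int32_t id;
  std::string name;
  int32_t num_workers;
};

json ParseOpInfo(const OpInfo &node) {
  json json_node;
  json_node["op_id"] = node.id;
  json_node["op_type"] = node.name;
  json_node["num_workers"] = node.num_workers;
  return json_node;
}

int main() {
  std::cout << ParseOpInfo({3, "BatchOp", 8}).dump() << std::endl;
  return 0;
}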

View File

@@ -78,7 +78,7 @@ json ConnectorThroughput::ParseOpInfo(const DatasetOp &node, const std::vector<d
   json json_node;
   json_node["op_id"] = node.id();
   json_node["op_type"] = node.Name();
-  json_node["num_workers"] = node.num_workers();
+  json_node["num_workers"] = node.NumWorkers();
   json metrics;
   // DeviceQueueOp is a special op,it is not inlined but its output queue is invalid.
   // So we should not output its connector throughput.

View File

@@ -264,7 +264,7 @@ Status OperatorCpu::Collect(const ExecutionTree *tree) {
   for (auto iter = tree->begin(); iter != tree->end(); ++iter) {
     id_count_++;
     op_name_[iter->id()] = iter->NameWithID();
-    op_parallel_workers_[iter->id()] = iter->num_workers();
+    op_parallel_workers_[iter->id()] = iter->NumWorkers();
   }
 #if defined(USING_LINUX)
   cpu_processor_num_ = get_nprocs_conf();

View File

@@ -221,7 +221,7 @@ Status TreeAdapter::GetNext(TensorRow *row) {
   RETURN_UNEXPECTED_IF_NULL(row);
   row->clear();  // make sure row is empty
 #ifndef ENABLE_SECURITY
-  bool isProfilingEnable = tree_->GetProfilingManager()->IsProfilingEnable();
+  bool is_profiling_enable = tree_->GetProfilingManager()->IsProfilingEnable();
 #endif
   // When cur_db_ is a nullptr, it means this is the first call to get_next, launch ExecutionTree
@@ -233,7 +233,7 @@ Status TreeAdapter::GetNext(TensorRow *row) {
   if (row->eoe()) {  // return empty tensor if 1st buf is a ctrl buf (no rows)
     MS_LOG(INFO) << "End of data iteration.";
 #ifndef ENABLE_SECURITY
-    if (isProfilingEnable) {
+    if (is_profiling_enable) {
       tree_->SetEpochEnd();
     }
 #endif

View File

@@ -123,6 +123,9 @@ class Execute {
   /// \brief The function to validate target device setting is valid or not.
   Status ValidateDevice();
+  /// \brief Initialize 310 resource
+  Status InitResource(MapTargetDevice device_type, uint32_t device_id);
   std::vector<std::shared_ptr<TensorTransform>> transforms_;
   std::vector<std::shared_ptr<TensorOperation>> ops_;
   MapTargetDevice device_type_;

View File

@@ -251,6 +251,14 @@ static bool Conv2DImplement(const LiteMat &src, const LiteMat &kernel, T2 *dst,
   int border_y = static_cast<int>(kernel.height_ / 2);
   LiteMat pad_mat;
+  if ((border_x > INT_MAX / 2) || (src.width_ > INT_MAX - 2 * border_x)) {
+    return false;
+  }
+  if ((border_y > INT_MAX / 2) || (src.height_ > INT_MAX - 2 * border_y)) {
+    return false;
+  }
   pad_mat.Init(src.width_ + 2 * border_x, src.height_ + 2 * border_y, src.channel_, src.data_type_);
   if (!Pad(src, pad_mat, border_y, border_y, border_x, border_x, pad_type)) {
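The new checks reject inputs whose padded size would overflow int before pad_mat.Init is called: the padded width is src.width_ + 2 * border_x, so the guard requires 2 * border_x itself to be representable and the sum to stay at or below INT_MAX (and likewise for the height). A standalone sketch of that guard:

#include <climits>
#include <iostream>

// Return true if width + 2 * border can be computed without int overflow.
bool PaddedSizeFits(int width, int border) {
  if (border > INT_MAX / 2) return false;          // 2 * border would overflow
  if (width > INT_MAX - 2 * border) return false;  // width + 2 * border would overflow
  return true;
}

int main() {
  std::cout << std::boolalpha;
  std::cout << PaddedSizeFits(1024, 1) << std::endl;              // true
  std::cout << PaddedSizeFits(INT_MAX - 1, 1) << std::endl;       // false
  std::cout << PaddedSizeFits(10, INT_MAX / 2 + 1) << std::endl;  // false
  return 0;
}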

View File

@@ -30,7 +30,7 @@ Status ComputeUpperAndLowerPercentiles(std::vector<int32_t> *hist, int32_t hi_p,
   int32_t n = std::accumulate(hist->begin(), hist->end(), 0);
   constexpr float kMaxPerc = 100.0;
   int32_t cut = static_cast<int32_t>((low_p / kMaxPerc) * n);
-  for (int32_t lb = 0; lb < hist->size() + 1 && cut > 0; lb++) {
+  for (int32_t lb = 0; lb < hist->size() && cut > 0; lb++) {
     if (cut > (*hist)[lb]) {
       cut -= (*hist)[lb];
       (*hist)[lb] = 0;
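The fix tightens the loop bound from hist->size() + 1 to hist->size(): with the old bound, the final iteration could index one past the end of the histogram. The loop walks bins from the low end, zeroing whole bins until the requested cut count is used up. A self-contained sketch of that trimming step; the else branch is not shown in the hunk and is assumed here to subtract the remaining cut from the current bin:

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Zero out histogram bins from the low end until `cut` counts have been removed.
// Mirrors the corrected loop bound: lb stays strictly below hist->size().
void TrimLowBins(std::vector<int32_t> *hist, int32_t cut) {
  for (size_t lb = 0; lb < hist->size() && cut > 0; lb++) {
    if (cut > (*hist)[lb]) {
      cut -= (*hist)[lb];
      (*hist)[lb] = 0;
    } else {
      (*hist)[lb] -= cut;
      cut = 0;
    }
  }
}

int main() {
  std::vector<int32_t> hist = {5, 3, 10, 2};
  int32_t n = std::accumulate(hist.begin(), hist.end(), 0);  // 20
  int32_t cut = static_cast<int32_t>((30.0 / 100.0) * n);    // low_p = 30% -> cut = 6
  TrimLowBins(&hist, cut);
  for (int32_t v : hist) std::cout << v << " ";  // prints: 0 2 10 2
  std::cout << std::endl;
  return 0;
}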

View File

@@ -42,7 +42,9 @@ Status MixUpBatchOp::ComputeLabels(const TensorRow &input, std::shared_ptr<Tenso
   CHECK_FAIL_RETURN_UNEXPECTED(
     images_size <= static_cast<size_t>(std::numeric_limits<int64_t>::max()),
     "The \'images_size\' must not be more than \'INT64_MAX\', but got: " + std::to_string(images_size));
-  for (int64_t i = 0; i < static_cast<int64_t>(images_size); i++) rand_indx->push_back(i);
+  for (int64_t i = 0; i < static_cast<int64_t>(images_size); i++) {
+    rand_indx->push_back(i);
+  }
   std::shuffle(rand_indx->begin(), rand_indx->end(), rnd_);
   RETURN_IF_NOT_OK(TypeCast(std::move(input.at(1)), out_labels, DataType(DataType::DE_FLOAT32)));
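This change only adds braces around the single-statement loop body, per the coding style; functionally the code still fills rand_indx with 0..images_size-1 and shuffles it to decide the pairing order for mix-up. A minimal sketch of that index-shuffling step, with a locally seeded engine standing in for the op's rnd_ member:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <random>
#include <vector>

int main() {
  const int64_t images_size = 8;
  std::mt19937 rnd(42);  // stand-in for the op's own seeded engine (rnd_)

  // Build the index list 0..images_size-1, then shuffle it to decide
  // which image each sample is mixed with.
  std::vector<int64_t> rand_indx;
  for (int64_t i = 0; i < images_size; i++) {
    rand_indx.push_back(i);
  }
  std::shuffle(rand_indx.begin(), rand_indx.end(), rnd);

  for (int64_t i : rand_indx) std::cout << i << " ";
  std::cout << std::endl;
  return 0;
}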

Some files were not shown because too many files have changed in this diff.