forked from mindspore-Ecosystem/mindspore
sync code
This commit is contained in:
parent dc6e62227d
commit 567f137a39
@@ -34,7 +34,7 @@ bool set_seed(int32_t seed) {
     MS_LOG(ERROR) << "Seed given is not within the required range: " << seed;
     return false;
   }
-  _config->set_seed((uint32_t)seed);
+  _config->set_seed(static_cast<uint32_t>(seed));
   return true;
 }
 
@@ -73,7 +73,7 @@ bool set_monitor_sampling_interval(int32_t interval) {
     MS_LOG(ERROR) << "Interval given is not within the required range: " << interval;
     return false;
   }
-  _config->set_monitor_sampling_interval((uint32_t)interval);
+  _config->set_monitor_sampling_interval(static_cast<uint32_t>(interval));
   return true;
 }
 
@@ -86,7 +86,7 @@ bool set_callback_timeback(int32_t timeout) {
     MS_LOG(ERROR) << "Timeout given is not within the required range: " << timeout;
     return false;
   }
-  _config->set_callback_timeout((uint32_t)timeout);
+  _config->set_callback_timeout(static_cast<uint32_t>(timeout));
   return true;
 }
 
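Note on the three config hunks above: each keeps its existing range check and only swaps the C-style cast for static_cast. A minimal sketch of the same validate-then-cast pattern follows; the FakeConfig object and the exact range condition are illustrative assumptions, only the set_seed signature and the static_cast come from the diff.

#include <cstdint>
#include <iostream>

// Hedged sketch: stand-in for the real config object used by set_seed().
struct FakeConfig {
  void set_seed(uint32_t seed) { seed_ = seed; }
  uint32_t seed_ = 0;
};

static FakeConfig *_config = new FakeConfig();

bool set_seed(int32_t seed) {
  if (seed < 0) {  // assumed range check; the real condition is not shown in the hunk
    std::cerr << "Seed given is not within the required range: " << seed << "\n";
    return false;
  }
  // static_cast makes the sign conversion explicit and searchable,
  // unlike the old C-style cast (uint32_t)seed.
  _config->set_seed(static_cast<uint32_t>(seed));
  return true;
}

int main() { return set_seed(42) ? 0 : 1; }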
@@ -215,7 +215,6 @@ bool Dataset::DeviceQueueCharIF(const std::vector<char> &queue_name, const std::
     MS_LOG(ERROR) << "ToDevice: Failed to get consumer.";
     return false;
   }
-
   rc = consumer->Init(ds);
   if (rc.IsError()) {
     MS_LOG(ERROR) << "ToDevice: Failed to init. Error status: " << rc;
@@ -258,7 +257,6 @@ bool Dataset::SaveCharIF(const std::vector<char> &dataset_path, int32_t num_file
     MS_LOG(ERROR) << "ToDevice: Failed to get consumer.";
     return false;
   }
-
   rc = consumer->Init(ds->IRNode());
   if (rc.IsError()) {
     MS_LOG(ERROR) << "CreateSaver failed." << rc;
@@ -289,11 +287,15 @@ bool Dataset::SaveCharIF(const std::vector<char> &dataset_path, int32_t num_file
 Dataset::Dataset() { tree_getters_ = std::make_shared<TreeGetters>(); }
 
 int64_t Dataset::GetDatasetSize(bool estimate) {
-  int64_t dataset_size;
+  int64_t dataset_size = -1;
   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
   RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
   std::shared_ptr<DatasetSizeGetter> size_getter = std::make_shared<DatasetSizeGetter>();
   DatasetSizeGetter *consumer = size_getter.get();
+  if (consumer == nullptr) {
+    MS_LOG(ERROR) << "DatasetSizeGetter: Failed to get consumer.";
+    return -1;
+  }
   runtime_context->AssignConsumer(size_getter);
   RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), -1);
   RETURN_SECOND_IF_ERROR(consumer->GetDatasetSize(&dataset_size, estimate), -1);
@@ -305,6 +307,10 @@ std::vector<mindspore::DataType> Dataset::GetOutputTypes() {
   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
   RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
   TreeGetters *consumer = tree_getters_.get();
+  if (consumer == nullptr) {
+    MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
+    return std::vector<mindspore::DataType>();
+  }
   runtime_context->AssignConsumer(tree_getters_);
   RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
   RETURN_SECOND_IF_ERROR(consumer->GetOutputTypes(&types), {});
@@ -320,6 +326,10 @@ std::vector<std::vector<int64_t>> Dataset::GetOutputShapes() {
   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
   RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
   TreeGetters *consumer = tree_getters_.get();
+  if (consumer == nullptr) {
+    MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
+    return std::vector<std::vector<int64_t>>();
+  }
   runtime_context->AssignConsumer(tree_getters_);
   RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
   RETURN_SECOND_IF_ERROR(consumer->GetOutputShapes(&shapes), {});
@@ -330,10 +340,14 @@ std::vector<std::vector<int64_t>> Dataset::GetOutputShapes() {
 }
 
 int64_t Dataset::GetNumClasses() {
-  int64_t num_classes;
+  int64_t num_classes = -1;
   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
   RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
   TreeGetters *consumer = tree_getters_.get();
+  if (consumer == nullptr) {
+    MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
+    return -1;
+  }
   runtime_context->AssignConsumer(tree_getters_);
   RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), -1);
   RETURN_SECOND_IF_ERROR(consumer->GetNumClasses(&num_classes), -1);
@@ -345,6 +359,10 @@ std::vector<std::vector<char>> Dataset::GetColumnNamesCharIF() {
   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
   RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
   TreeGetters *consumer = tree_getters_.get();
+  if (consumer == nullptr) {
+    MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
+    return std::vector<std::vector<char>>();
+  }
   runtime_context->AssignConsumer(tree_getters_);
   RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
   RETURN_SECOND_IF_ERROR(consumer->GetColumnNames(&col_names), {});
@@ -356,6 +374,10 @@ std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> Dataset::GetClas
   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
   RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
   TreeGetters *consumer = tree_getters_.get();
+  if (consumer == nullptr) {
+    MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
+    return std::vector<std::pair<std::vector<char>, std::vector<int32_t>>>();
+  }
   runtime_context->AssignConsumer(tree_getters_);
   RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
   RETURN_SECOND_IF_ERROR(consumer->GetClassIndexing(&output_class_indexing), {});
@@ -493,10 +515,10 @@ TakeDataset::TakeDataset(std::shared_ptr<Dataset> input, int32_t count) {
 
 ZipDataset::ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
   std::vector<std::shared_ptr<DatasetNode>> all_datasets;
-  (void)std::transform(
-    datasets.begin(), datasets.end(), std::back_inserter(all_datasets),
-    [](std::shared_ptr<Dataset> dataset) -> std::shared_ptr<DatasetNode> { return dataset->IRNode(); });
+  (void)std::transform(datasets.begin(), datasets.end(), std::back_inserter(all_datasets),
+                       [](std::shared_ptr<Dataset> dataset) -> std::shared_ptr<DatasetNode> {
+                         return (dataset != nullptr) ? dataset->IRNode() : nullptr;
+                       });
   auto ds = std::make_shared<ZipNode>(all_datasets);
 
   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
@@ -544,6 +566,10 @@ std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocabCharIF(
 
   auto consumer = std::make_unique<BuildVocabConsumer>();
   BuildVocabConsumer *bv_consumer = consumer.get();
+  if (bv_consumer == nullptr) {
+    MS_LOG(ERROR) << "BuildVocabConsumer: Failed to get bv_consumer.";
+    return nullptr;
+  }
   rc = consumer->Init(ds);
   if (rc.IsError()) {
     MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init consumer. Error status: " << rc;
@@ -577,6 +603,10 @@ std::shared_ptr<Vocab> Dataset::BuildVocabCharIF(const std::vector<std::vector<c
 
   auto consumer = std::make_unique<BuildVocabConsumer>();
   BuildVocabConsumer *bv_consumer = consumer.get();
+  if (bv_consumer == nullptr) {
+    MS_LOG(ERROR) << "BuildVocabConsumer: Failed to get bv_consumer.";
+    return nullptr;
+  }
   rc = consumer->Init(ds);
   if (rc.IsError()) {
     MS_LOG(ERROR) << "BuildVocab: Failed to init consumer. Error status: " << rc;
 
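Note on the Dataset getter hunks above: each getter now initializes its result to a failure value (-1 or an empty container) and bails out early when the consumer pointer is null, so callers can detect a failed query. A hedged sketch of checking those conventions follows; DatasetT is a stand-in template parameter, and only the getter names and failure values are taken from the diff.

#include <cstdint>
#include <iostream>
#include <memory>

// Sketch only: DatasetT stands in for the Dataset class patched above.
template <typename DatasetT>
void QueryDataset(const std::shared_ptr<DatasetT> &ds) {
  int64_t size = ds->GetDatasetSize(false);  // now returns -1 if the size getter cannot be set up
  if (size < 0) {
    std::cerr << "GetDatasetSize failed (runtime context or consumer error)\n";
    return;
  }
  int64_t num_classes = ds->GetNumClasses();  // -1 on failure for the same reason
  auto shapes = ds->GetOutputShapes();        // empty vector when the TreeGetters consumer is null
  std::cout << "rows=" << size << ", classes=" << num_classes << ", outputs=" << shapes.size() << "\n";
}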
@@ -54,20 +54,27 @@ struct Execute::ExtraInfo {
 #endif
 };
 
-Execute::Execute(std::shared_ptr<TensorOperation> op, MapTargetDevice device_type, uint32_t device_id) {
-  ops_.emplace_back(std::move(op));
-  device_type_ = device_type;
-  info_ = std::make_shared<ExtraInfo>();
+Status Execute::InitResource(MapTargetDevice device_type, uint32_t device_id) {
 #ifdef ENABLE_ACL
   if (device_type_ == MapTargetDevice::kAscend310) {
     device_resource_ = std::make_shared<AscendResource>();
     Status rc = device_resource_->InitResource(device_id);
     if (!rc.IsOk()) {
       device_resource_ = nullptr;
-      MS_LOG(ERROR) << "Initialize Ascend310 resource fail.";
+      std::string err_msg = "Initialize Ascend310 resource fail";
+      MS_LOG(ERROR) << err_msg;
+      RETURN_STATUS_UNEXPECTED(err_msg);
     }
   }
 #endif
+  return Status::OK();
+}
 
+Execute::Execute(std::shared_ptr<TensorOperation> op, MapTargetDevice device_type, uint32_t device_id) {
+  ops_.emplace_back(std::move(op));
+  device_type_ = device_type;
+  info_ = std::make_shared<ExtraInfo>();
+  (void)InitResource(device_type, device_id);
 }
 
 Execute::Execute(std::shared_ptr<TensorTransform> op, MapTargetDevice device_type, uint32_t device_id) {
@@ -76,16 +83,7 @@ Execute::Execute(std::shared_ptr<TensorTransform> op, MapTargetDevice device_typ
 
   info_ = std::make_shared<ExtraInfo>();
   device_type_ = device_type;
-#ifdef ENABLE_ACL
-  if (device_type_ == MapTargetDevice::kAscend310) {
-    device_resource_ = std::make_shared<AscendResource>();
-    Status rc = device_resource_->InitResource(device_id);
-    if (!rc.IsOk()) {
-      device_resource_ = nullptr;
-      MS_LOG(ERROR) << "Initialize Ascend310 resource fail.";
-    }
-  }
-#endif
+  (void)InitResource(device_type, device_id);
 }
 
 Execute::Execute(std::reference_wrapper<TensorTransform> op, MapTargetDevice device_type, uint32_t device_id) {
@@ -96,16 +94,7 @@ Execute::Execute(std::reference_wrapper<TensorTransform> op, MapTargetDevice dev
   info_ = std::make_shared<ExtraInfo>();
   info_->init_with_shared_ptr_ = false;
   device_type_ = device_type;
-#ifdef ENABLE_ACL
-  if (device_type_ == MapTargetDevice::kAscend310) {
-    device_resource_ = std::make_shared<AscendResource>();
-    Status rc = device_resource_->InitResource(device_id);
-    if (!rc.IsOk()) {
-      device_resource_ = nullptr;
-      MS_LOG(ERROR) << "Initialize Ascend310 resource fail.";
-    }
-  }
-#endif
+  (void)InitResource(device_type, device_id);
 }
 
 // Execute function for the example case: auto decode(new vision::Decode());
@@ -116,31 +105,13 @@ Execute::Execute(TensorTransform *op, MapTargetDevice device_type, uint32_t devi
 
   info_ = std::make_shared<ExtraInfo>();
   device_type_ = device_type;
-#ifdef ENABLE_ACL
-  if (device_type_ == MapTargetDevice::kAscend310) {
-    device_resource_ = std::make_shared<AscendResource>();
-    Status rc = device_resource_->InitResource(device_id);
-    if (!rc.IsOk()) {
-      device_resource_ = nullptr;
-      MS_LOG(ERROR) << "Initialize Ascend310 resource fail.";
-    }
-  }
-#endif
+  (void)InitResource(device_type, device_id);
 }
 
 Execute::Execute(std::vector<std::shared_ptr<TensorOperation>> ops, MapTargetDevice device_type, uint32_t device_id)
     : ops_(std::move(ops)), device_type_(device_type) {
   info_ = std::make_shared<ExtraInfo>();
-#ifdef ENABLE_ACL
-  if (device_type_ == MapTargetDevice::kAscend310) {
-    device_resource_ = std::make_shared<AscendResource>();
-    Status rc = device_resource_->InitResource(device_id);
-    if (!rc.IsOk()) {
-      device_resource_ = nullptr;
-      MS_LOG(ERROR) << "Initialize Ascend310 resource fail.";
-    }
-  }
-#endif
+  (void)InitResource(device_type, device_id);
 }
 
 Execute::Execute(std::vector<std::shared_ptr<TensorTransform>> ops, MapTargetDevice device_type, uint32_t device_id) {
@@ -149,16 +120,7 @@ Execute::Execute(std::vector<std::shared_ptr<TensorTransform>> ops, MapTargetDev
 
   info_ = std::make_shared<ExtraInfo>();
   device_type_ = device_type;
-#ifdef ENABLE_ACL
-  if (device_type_ == MapTargetDevice::kAscend310) {
-    device_resource_ = std::make_shared<AscendResource>();
-    Status rc = device_resource_->InitResource(device_id);
-    if (!rc.IsOk()) {
-      device_resource_ = nullptr;
-      MS_LOG(ERROR) << "Initialize Ascend310 resource fail.";
-    }
-  }
-#endif
+  (void)InitResource(device_type, device_id);
 }
 
 Execute::Execute(const std::vector<std::reference_wrapper<TensorTransform>> ops, MapTargetDevice device_type,
@@ -177,16 +139,7 @@ Execute::Execute(const std::vector<std::reference_wrapper<TensorTransform>> ops,
   info_ = std::make_shared<ExtraInfo>();
   info_->init_with_shared_ptr_ = false;
   device_type_ = device_type;
-#ifdef ENABLE_ACL
-  if (device_type_ == MapTargetDevice::kAscend310) {
-    device_resource_ = std::make_shared<AscendResource>();
-    Status rc = device_resource_->InitResource(device_id);
-    if (!rc.IsOk()) {
-      device_resource_ = nullptr;
-      MS_LOG(ERROR) << "Initialize Ascend310 resource fail";
-    }
-  }
-#endif
+  (void)InitResource(device_type, device_id);
 }
 
 // Execute function for the example vector case: auto decode(new vision::Decode());
@@ -199,16 +152,7 @@ Execute::Execute(const std::vector<TensorTransform *> &ops, MapTargetDevice devi
 
   info_ = std::make_shared<ExtraInfo>();
   device_type_ = device_type;
-#ifdef ENABLE_ACL
-  if (device_type_ == MapTargetDevice::kAscend310) {
-    device_resource_ = std::make_shared<AscendResource>();
-    Status rc = device_resource_->InitResource(device_id);
-    if (!rc.IsOk()) {
-      device_resource_ = nullptr;
-      MS_LOG(ERROR) << "Initialize Ascend310 resource fail";
-    }
-  }
-#endif
+  (void)InitResource(device_type, device_id);
 }
 
 Execute::~Execute() {
@@ -228,6 +172,7 @@ Execute::~Execute() {
 
 Status Execute::operator()(const mindspore::MSTensor &input, mindspore::MSTensor *output) {
   // Validate input tensor
+  RETURN_UNEXPECTED_IF_NULL(output);
   CHECK_FAIL_RETURN_UNEXPECTED(input.DataSize() > 0, "Input Tensor has no data.");
   CHECK_FAIL_RETURN_UNEXPECTED(output != nullptr, "Output Tensor can not be nullptr.");
   CHECK_FAIL_RETURN_UNEXPECTED(ValidateDevice(), "Device Type should be 'Ascend310' or 'CPU'.");
@@ -290,7 +235,8 @@ Status Execute::operator()(const mindspore::MSTensor &input, mindspore::MSTensor
       RETURN_STATUS_UNEXPECTED(ss.str());
     }
     *output = mindspore::MSTensor(std::make_shared<DETensor>(de_tensor));
-  } else {  // Ascend310 case, where we must set Ascend resource on each operators
+  } else if (device_type_ ==
+             MapTargetDevice::kAscend310) {  // Ascend310 case, where we must set Ascend resource on each operators
 #ifdef ENABLE_ACL
     CHECK_FAIL_RETURN_UNEXPECTED(device_resource_, "Device resource is nullptr which is illegal under case Ascend310.");
     // Sink data from host into device
@@ -311,12 +257,17 @@ Status Execute::operator()(const mindspore::MSTensor &input, mindspore::MSTensor
 
     *output = mindspore::MSTensor(std::make_shared<DETensor>(device_input, true));
 #endif
+  } else {
+    std::string err_msg = "Your input device is not supported. (Option: CPU or Ascend310)";
+    MS_LOG(ERROR) << err_msg;
+    RETURN_STATUS_UNEXPECTED(err_msg);
   }
   return Status::OK();
 }
 
 Status Execute::operator()(const std::vector<MSTensor> &input_tensor_list, std::vector<MSTensor> *output_tensor_list) {
   // Validate input tensor
+  RETURN_UNEXPECTED_IF_NULL(output_tensor_list);
   CHECK_FAIL_RETURN_UNEXPECTED(!input_tensor_list.empty(), "Input Tensor is not valid.");
   CHECK_FAIL_RETURN_UNEXPECTED(output_tensor_list != nullptr, "Output Tensor can not be nullptr.");
   output_tensor_list->clear();
@@ -380,7 +331,8 @@ Status Execute::operator()(const std::vector<MSTensor> &input_tensor_list, std::
       ++idx;
     }
     CHECK_FAIL_RETURN_UNEXPECTED(!output_tensor_list->empty(), "Output Tensor is not valid.");
-  } else {  // Case Ascend310
+  } else if (device_type_ ==
+             MapTargetDevice::kAscend310) {  // Ascend310 case, where we must set Ascend resource on each operators
 #ifdef ENABLE_ACL
     CHECK_FAIL_RETURN_UNEXPECTED(device_resource_, "Device resource is nullptr which is illegal under case Ascend310.");
     for (auto &input_tensor : input_tensor_list) {
@@ -410,6 +362,10 @@ Status Execute::operator()(const std::vector<MSTensor> &input_tensor_list, std::
     }
     CHECK_FAIL_RETURN_UNEXPECTED(!output_tensor_list->empty(), "Output Tensor vector is empty.");
 #endif
+  } else {
+    std::string err_msg = "Your input device is not supported. (Option: CPU or Ascend310)";
+    MS_LOG(ERROR) << err_msg;
+    RETURN_STATUS_UNEXPECTED(err_msg);
   }
   return Status::OK();
 }
@@ -526,6 +482,10 @@ Status AippInfoCollection(std::map<std::string, std::string> *aipp_options, cons
 
 std::string Execute::AippCfgGenerator() {
   std::string config_location = "./aipp.cfg";
+  if (info_ == nullptr) {
+    MS_LOG(ERROR) << "info_ is null";
+    return "";
+  }
 #ifdef ENABLE_ACL
   if (info_->init_with_shared_ptr_) {
     auto rc = ParseTransforms();
@@ -565,8 +525,7 @@ std::string Execute::AippCfgGenerator() {
   if (!outfile.is_open()) {
     MS_LOG(ERROR) << "Fail to open Aipp config file, please verify your system config(including authority)."
                   << "We will return empty string which represent the location of Aipp config file in this case.";
-    std::string except = "";
-    return except;
+    return "";
   }
 
   if (device_type_ == MapTargetDevice::kAscend310) {
@@ -650,10 +609,14 @@ Status Execute::ParseTransforms() {
                          [](std::shared_ptr<TensorTransform> operation) -> std::shared_ptr<TensorOperation> {
                            return operation->Parse();
                          });
-  } else {
+  } else if (device_type_ == MapTargetDevice::kAscend310) {
     for (auto &transform_ : transforms_) {
       ops_.emplace_back(transform_->Parse(device_type_));
     }
+  } else {
+    std::string err_msg = "Your input device is not supported. (Option: CPU or Ascend310)";
+    MS_LOG(ERROR) << err_msg;
+    RETURN_STATUS_UNEXPECTED(err_msg);
   }
 
   return Status::OK();
 
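Note on the Execute hunks above: the per-constructor #ifdef ENABLE_ACL blocks are folded into one InitResource(device_type, device_id) helper, and the unsupported-device branch of operator() now returns an error status instead of falling through. A hedged usage sketch follows; the constructor and operator() signatures come from the hunks, while vision::Decode (named only in the diff's own comments), the kCpu enumerator, and the omitted MindSpore headers are assumptions.

// Sketch only: MindSpore dataset headers are assumed and omitted here.
Status DecodeOnHost(const mindspore::MSTensor &raw_image, mindspore::MSTensor *decoded) {
  auto decode = std::make_shared<vision::Decode>();  // example transform named in the diff comments
  // Device setup now goes through InitResource(); device_id 0 is the single-device default.
  Execute transform(decode, MapTargetDevice::kCpu, 0);
  // operator() validates the output pointer (RETURN_UNEXPECTED_IF_NULL) and rejects
  // device types other than CPU/Ascend310 with an error status.
  return transform(raw_image, decoded);
}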
@@ -38,7 +38,7 @@ Status Iterator::GetNextRowCharIF(MSTensorMapChar *row) {
     row->clear();
     return rc;
   }
-  for (auto de_tensor : md_map) {
+  for (auto &de_tensor : md_map) {
     std::vector<char> col_name(de_tensor.first.begin(), de_tensor.first.end());
     row->insert(std::make_pair(col_name, mindspore::MSTensor(std::make_shared<DETensor>(de_tensor.second))));
   }
@@ -48,8 +48,8 @@ Status Iterator::GetNextRowCharIF(MSTensorMapChar *row) {
 
 // Get the next row from the data pipeline.
 Status Iterator::GetNextRow(MSTensorVec *row) {
-  // Clean data row
   RETURN_UNEXPECTED_IF_NULL(row);
+  // Clean data row
   row->clear();
   // create a dataset tensor row and fetch. Then we convert the output to MSTensor
   std::vector<std::shared_ptr<dataset::Tensor>> md_row;
@@ -76,6 +76,7 @@ void Iterator::Stop() {
 
 // Function to build and launch the execution tree.
 Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds, int32_t num_epochs) {
+  RETURN_UNEXPECTED_IF_NULL(ds);
   runtime_context_ = std::make_unique<NativeRuntimeContext>();
   CHECK_FAIL_RETURN_UNEXPECTED(runtime_context_ != nullptr, "Create runtime_context_ failed.");
   RETURN_IF_NOT_OK(runtime_context_->Init());
 
@@ -107,7 +107,7 @@ Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) {
   // clear the old tensor row
   out_row->clear();
 #ifndef ENABLE_SECURITY
-  bool isProfilingEnable = root_->Tree()->GetProfilingManager()->IsProfilingEnable();
+  bool is_profiling_enable = root_->Tree()->GetProfilingManager()->IsProfilingEnable();
 #endif
   // Once eof is handled, always return empty row. Class must be destroyed and recreated if you
   // want to iterate again.
@@ -131,7 +131,7 @@ Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) {
   if (out_row->eoe()) {
     MS_LOG(INFO) << "End of data iteration.";
 #ifndef ENABLE_SECURITY
-    if (isProfilingEnable) {
+    if (is_profiling_enable) {
       root_->Tree()->SetEpochEnd();
     }
 #endif
 
@@ -248,9 +248,13 @@ Status BatchOp::WorkerEntry(int32_t workerId) {
 Status BatchOp::MakeBatchedRow(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair, TensorRow *new_row) {
   RETURN_UNEXPECTED_IF_NULL(table_pair.first);
 #ifdef ENABLE_PYTHON
-  if (!in_col_names_.empty()) RETURN_IF_NOT_OK(MapColumns(&table_pair));  // pass it through pyfunc
+  if (!in_col_names_.empty()) {
+    RETURN_IF_NOT_OK(MapColumns(&table_pair));
+  }  // pass it through pyfunc
 #endif
-  if (pad_) RETURN_IF_NOT_OK(PadColumns(&table_pair.first, pad_info_, column_name_id_map_));  // do padding if needed
+  if (pad_) {
+    RETURN_IF_NOT_OK(PadColumns(&table_pair.first, pad_info_, column_name_id_map_));
+  }  // do padding if needed
   RETURN_IF_NOT_OK(BatchRows(&table_pair.first, new_row, table_pair.first->size()));
   return Status::OK();
 }
 
@@ -242,7 +242,7 @@ class BatchOp : public ParallelOp {
 
   // the number of thread pulling from the mOutConnector of the Op below
   // @return int32_t, 1
-  int32_t num_consumers() const override { return 1; }
+  int32_t NumConsumers() const override { return 1; }
 
   // get the batch size for next batch
   // @return Status The status code returned
 
@@ -71,11 +71,11 @@ class BuildSentencePieceVocabOp : public PipelineOp {
 
   // Getter
   // @return the number of workers
-  int32_t num_producers() const override { return 1; }
+  int32_t NumProducers() const override { return 1; }
 
   // Getter
   // @return the number of threads consuming from the previous Connector
-  int32_t num_consumers() const override { return 1; }
+  int32_t NumConsumers() const override { return 1; }
 
   Status Reset() override { RETURN_STATUS_UNEXPECTED("Reset shouldn't be called in BuildSentencePieceVocabOp"); }
 
@@ -68,11 +68,11 @@ class BuildVocabOp : public ParallelOp {
 
   /// Getter
   /// @return the number of workers
-  int32_t num_producers() const override { return 1; }
+  int32_t NumProducers() const override { return 1; }
 
   /// Getter
   /// @return the number of threads consuming from the previous Connector
-  int32_t num_consumers() const override { return 1; }
+  int32_t NumConsumers() const override { return 1; }
 
   Status Reset() override { RETURN_STATUS_UNEXPECTED("Reset shouldn't be called in BuildVocabOp"); }
 
@@ -69,6 +69,7 @@ Status CacheBase::FetchSamplesToWorkers() {
   int64_t wait_cnt = 0;
   int64_t prefetch_cnt = 0;
   // Kick off several threads which will prefetch prefetch_size_ rows in advance.
+  RETURN_UNEXPECTED_IF_NULL(tree_);
   RETURN_IF_NOT_OK(
     tree_->LaunchWorkers(num_prefetchers_, std::bind(&CacheBase::Prefetcher, this, std::placeholders::_1), Name()));
   auto send_to_que = [](QueueList<std::unique_ptr<IOBlock>> &qList, int32_t worker_id,
@@ -80,7 +81,7 @@ Status CacheBase::FetchSamplesToWorkers() {
   // Instead of sending sampler id to WorkerEntry, we send them to the Prefetcher which will redirect them
   // to the WorkerEntry.
   do {
-    if (AllowCacheMiss() && wait_cnt > 0 && wait_cnt % op_num_repeats_per_epoch() == 0) {
+    if (AllowCacheMiss() && wait_cnt > 0 && wait_cnt % GetOpNumRepeatsPerEpoch() == 0) {
      MS_LOG(INFO) << "Epoch: " << op_current_epochs_ << " Cache Miss : " << num_cache_miss_
                   << " Total number of rows : " << row_cnt_;
     }
@@ -155,7 +156,7 @@ Status CacheBase::FetchSamplesToWorkers() {
   }
   // Dump the last epoch result (approximately) without waiting for the worker threads to come back.
   if (AllowCacheMiss()) {
-    MS_LOG(INFO) << "Epoch: " << wait_cnt / op_num_repeats_per_epoch() << " Cache Miss : " << num_cache_miss_
+    MS_LOG(INFO) << "Epoch: " << wait_cnt / GetOpNumRepeatsPerEpoch() << " Cache Miss : " << num_cache_miss_
                  << " Total number of rows : " << row_cnt_;
   }
   return Status::OK();
@@ -202,6 +203,7 @@ Status CacheBase::FetchFromCache(int32_t worker_id) {
 }
 
 Status CacheBase::RegisterResources() {
+  RETURN_UNEXPECTED_IF_NULL(tree_);
   RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
   RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
   RETURN_IF_NOT_OK(prefetch_queues_.Register(tree_->AllTasks()));
 
|
@ -75,7 +75,7 @@ class CacheBase : public ParallelOp {
|
||||||
|
|
||||||
/// \brief Getter for the cache client
|
/// \brief Getter for the cache client
|
||||||
/// \return shared ptr to the cache client
|
/// \return shared ptr to the cache client
|
||||||
std::shared_ptr<CacheClient> cache_client() { return cache_client_; }
|
std::shared_ptr<CacheClient> GetCacheClient() { return cache_client_; }
|
||||||
/// \brief Setter for the cache client
|
/// \brief Setter for the cache client
|
||||||
void SetCacheClient(std::shared_ptr<CacheClient> cache_client) { cache_client_ = std::move(cache_client); }
|
void SetCacheClient(std::shared_ptr<CacheClient> cache_client) { cache_client_ = std::move(cache_client); }
|
||||||
/// \brief Derived class must implement this method if a cache miss is treated as error
|
/// \brief Derived class must implement this method if a cache miss is treated as error
|
||||||
|
|
|
@ -40,6 +40,7 @@ CacheOp::~CacheOp() = default;
|
||||||
|
|
||||||
// This class functor will provide the master loop that drives the logic for performing the work
|
// This class functor will provide the master loop that drives the logic for performing the work
|
||||||
Status CacheOp::operator()() {
|
Status CacheOp::operator()() {
|
||||||
|
RETURN_UNEXPECTED_IF_NULL(tree_);
|
||||||
if (!sampler_) {
|
if (!sampler_) {
|
||||||
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
|
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
|
||||||
"Invalid parameter, CacheOp requires a sampler before it can be executed, but got nullptr.");
|
"Invalid parameter, CacheOp requires a sampler before it can be executed, but got nullptr.");
|
||||||
|
@ -166,6 +167,7 @@ Status CacheOp::WorkerEntry(int32_t worker_id) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
Status CacheOp::RegisterResources() {
|
Status CacheOp::RegisterResources() {
|
||||||
|
RETURN_UNEXPECTED_IF_NULL(tree_);
|
||||||
RETURN_IF_NOT_OK(CacheBase::RegisterResources());
|
RETURN_IF_NOT_OK(CacheBase::RegisterResources());
|
||||||
RETURN_IF_NOT_OK(rows_cache_done_.Register(tree_->AllTasks()));
|
RETURN_IF_NOT_OK(rows_cache_done_.Register(tree_->AllTasks()));
|
||||||
RETURN_IF_NOT_OK(keys_miss_->Register(tree_->AllTasks()));
|
RETURN_IF_NOT_OK(keys_miss_->Register(tree_->AllTasks()));
|
||||||
|
|
|
@ -194,7 +194,7 @@ Status ConcatOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t ConcatOp::num_consumers() const {
|
int32_t ConcatOp::NumConsumers() const {
|
||||||
if (parent_.empty()) {
|
if (parent_.empty()) {
|
||||||
MS_LOG(DEBUG) << "Return operator, no parent node, assuming it's the root and returning 1.";
|
MS_LOG(DEBUG) << "Return operator, no parent node, assuming it's the root and returning 1.";
|
||||||
return 1;
|
return 1;
|
||||||
|
@ -202,16 +202,16 @@ int32_t ConcatOp::num_consumers() const {
|
||||||
MS_LOG(DEBUG) << "Return operator, pointer to the first parent is null. Returning 0.";
|
MS_LOG(DEBUG) << "Return operator, pointer to the first parent is null. Returning 0.";
|
||||||
return 0;
|
return 0;
|
||||||
} else {
|
} else {
|
||||||
return parent_[0]->num_consumers();
|
return parent_[0]->NumConsumers();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t ConcatOp::num_producers() const {
|
int32_t ConcatOp::NumProducers() const {
|
||||||
if (child_.empty() || child_[0] == nullptr) {
|
if (child_.empty() || child_[0] == nullptr) {
|
||||||
MS_LOG(DEBUG) << "Return operator, pointer to child node is null. Returning 0.";
|
MS_LOG(DEBUG) << "Return operator, pointer to child node is null. Returning 0.";
|
||||||
return 0;
|
return 0;
|
||||||
} else {
|
} else {
|
||||||
return child_[0]->num_producers();
|
return child_[0]->NumProducers();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
|
|
|
@ -73,8 +73,8 @@ class ConcatOp : public PipelineOp {
|
||||||
Status GetNumClasses(int64_t *num_classes) override;
|
Status GetNumClasses(int64_t *num_classes) override;
|
||||||
|
|
||||||
Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
|
Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
|
||||||
int32_t num_consumers() const override;
|
int32_t NumConsumers() const override;
|
||||||
int32_t num_producers() const override;
|
int32_t NumProducers() const override;
|
||||||
|
|
||||||
/// Check if the current sample will be taken or dropped
|
/// Check if the current sample will be taken or dropped
|
||||||
/// \return bool
|
/// \return bool
|
||||||
|
|
|
@ -312,9 +312,9 @@ Status DatasetOp::PrepareOperator() {
|
||||||
// The consumer of the root node is assumed to be one thread.
|
// The consumer of the root node is assumed to be one thread.
|
||||||
// If multiple threads are consuming from the root node, they will get the ordered data in round robin fashion.
|
// If multiple threads are consuming from the root node, they will get the ordered data in round robin fashion.
|
||||||
if (parent_.empty()) {
|
if (parent_.empty()) {
|
||||||
this->CreateConnector(num_producers(), 1);
|
this->CreateConnector(NumProducers(), 1);
|
||||||
} else {
|
} else {
|
||||||
this->CreateConnector(num_producers(), parent_[0]->num_consumers());
|
this->CreateConnector(NumProducers(), parent_[0]->NumConsumers());
|
||||||
}
|
}
|
||||||
if (out_connector_) {
|
if (out_connector_) {
|
||||||
RETURN_IF_NOT_OK(out_connector_->Register(tree_->AllTasks()));
|
RETURN_IF_NOT_OK(out_connector_->Register(tree_->AllTasks()));
|
||||||
|
|
|
@ -209,33 +209,33 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
||||||
|
|
||||||
// \brief Getter function
|
// \brief Getter function
|
||||||
// \return The number of workers in this op
|
// \return The number of workers in this op
|
||||||
virtual int32_t num_workers() const = 0;
|
virtual int32_t NumWorkers() const = 0;
|
||||||
|
|
||||||
// \brief Getter function
|
// \brief Getter function
|
||||||
// \return The number of threads consuming from previous op.
|
// \return The number of threads consuming from previous op.
|
||||||
virtual int32_t num_consumers() const = 0;
|
virtual int32_t NumConsumers() const = 0;
|
||||||
|
|
||||||
// \brief Getter function
|
// \brief Getter function
|
||||||
// \return The number of threads producing to the output connector.
|
// \return The number of threads producing to the output connector.
|
||||||
virtual int32_t num_producers() const = 0;
|
virtual int32_t NumProducers() const = 0;
|
||||||
|
|
||||||
// \brief Getter function
|
// \brief Getter function
|
||||||
// \return T/F if this is an inlined operator
|
// \return T/F if this is an inlined operator
|
||||||
bool inlined() const { return (oc_queue_size_ == 0); }
|
bool inlined() const { return (oc_queue_size_ == 0); }
|
||||||
|
|
||||||
// \brief Setter function, set the number of total repeats for the operator
|
// \brief Setter function, set the number of total repeats for the operator
|
||||||
void set_total_repeats(int32_t total_repeats) { op_total_repeats_ = total_repeats; }
|
void SetTotalRepeats(int32_t total_repeats) { op_total_repeats_ = total_repeats; }
|
||||||
|
|
||||||
// \brief Setter function, set the number of repeats per epoch for the operator
|
// \brief Setter function, set the number of repeats per epoch for the operator
|
||||||
void set_num_repeats_per_epoch(int32_t num_repeats_per_epoch) { op_num_repeats_per_epoch_ = num_repeats_per_epoch; }
|
void SetNumRepeatsPerEpoch(int32_t num_repeats_per_epoch) { op_num_repeats_per_epoch_ = num_repeats_per_epoch; }
|
||||||
|
|
||||||
// \brief Getter function
|
// \brief Getter function
|
||||||
// \return The number of required repeats for the operator
|
// \return The number of required repeats for the operator
|
||||||
int32_t op_total_repeats() { return op_total_repeats_; }
|
int32_t GetOpTotalRepeats() { return op_total_repeats_; }
|
||||||
|
|
||||||
// \brief Getter function
|
// \brief Getter function
|
||||||
// \return The number of repeats per epoch for the operator
|
// \return The number of repeats per epoch for the operator
|
||||||
int32_t op_num_repeats_per_epoch() const { return op_num_repeats_per_epoch_; }
|
int32_t GetOpNumRepeatsPerEpoch() const { return op_num_repeats_per_epoch_; }
|
||||||
|
|
||||||
// \brief Register the internal worker connectors. No op unless it is a parallel op
|
// \brief Register the internal worker connectors. No op unless it is a parallel op
|
||||||
// \return Status
|
// \return Status
|
||||||
|
@ -393,7 +393,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
||||||
// \notes No public interface. Only the class itself, or it's friend the execution tree can set
|
// \notes No public interface. Only the class itself, or it's friend the execution tree can set
|
||||||
// this
|
// this
|
||||||
// \param op_id - the Id value to set into the operator
|
// \param op_id - the Id value to set into the operator
|
||||||
void set_id(int32_t op_id) { operator_id_ = op_id; }
|
void SetId(int32_t op_id) { operator_id_ = op_id; }
|
||||||
|
|
||||||
// Sets the tree into the op so that the operator has a back pointer to the tree.
|
// Sets the tree into the op so that the operator has a back pointer to the tree.
|
||||||
// \param tree - the tree to assign to the op.
|
// \param tree - the tree to assign to the op.
|
||||||
|
|
|
@ -187,8 +187,8 @@ Status DeviceQueueOp::SendDataToAscend() {
|
||||||
|
|
||||||
#ifndef ENABLE_SECURITY
|
#ifndef ENABLE_SECURITY
|
||||||
std::shared_ptr<DeviceQueueTracing> profiling_node;
|
std::shared_ptr<DeviceQueueTracing> profiling_node;
|
||||||
bool isProfilingEnable = tree_->GetProfilingManager()->IsProfilingEnable();
|
bool is_profiling_enable = tree_->GetProfilingManager()->IsProfilingEnable();
|
||||||
if (isProfilingEnable) {
|
if (is_profiling_enable) {
|
||||||
std::shared_ptr<Tracing> node;
|
std::shared_ptr<Tracing> node;
|
||||||
RETURN_IF_NOT_OK(tree_->GetProfilingManager()->GetTracingNode(kDeviceQueueTracingName, &node));
|
RETURN_IF_NOT_OK(tree_->GetProfilingManager()->GetTracingNode(kDeviceQueueTracingName, &node));
|
||||||
profiling_node = std::dynamic_pointer_cast<DeviceQueueTracing>(node);
|
profiling_node = std::dynamic_pointer_cast<DeviceQueueTracing>(node);
|
||||||
|
@ -196,7 +196,7 @@ Status DeviceQueueOp::SendDataToAscend() {
|
||||||
connector_capacity = ChildOpConnectorCapacity();
|
connector_capacity = ChildOpConnectorCapacity();
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
bool isProfilingEnable = false;
|
bool is_profiling_enable = false;
|
||||||
#endif
|
#endif
|
||||||
#ifdef ENABLE_DUMP_IR
|
#ifdef ENABLE_DUMP_IR
|
||||||
md_channel_info_->RecordBatchQueue(ChildOpConnectorSize());
|
md_channel_info_->RecordBatchQueue(ChildOpConnectorSize());
|
||||||
|
@ -218,14 +218,14 @@ Status DeviceQueueOp::SendDataToAscend() {
|
||||||
md_channel_info_->RecordPreprocessBatch(send_batch);
|
md_channel_info_->RecordPreprocessBatch(send_batch);
|
||||||
md_channel_info_->RecordPushStartTime();
|
md_channel_info_->RecordPushStartTime();
|
||||||
#endif
|
#endif
|
||||||
RETURN_IF_NOT_OK(SendRowToTdt(curr_row, isProfilingEnable, &tdt_cost));
|
RETURN_IF_NOT_OK(SendRowToTdt(curr_row, is_profiling_enable, &tdt_cost));
|
||||||
if (first_push_flag_ != true) {
|
if (first_push_flag_ != true) {
|
||||||
MS_LOG(INFO) << "Loading dataset and push first batch into device successful.";
|
MS_LOG(INFO) << "Loading dataset and push first batch into device successful.";
|
||||||
first_push_flag_ = true;
|
first_push_flag_ = true;
|
||||||
}
|
}
|
||||||
#ifndef ENABLE_SECURITY
|
#ifndef ENABLE_SECURITY
|
||||||
DetectPerBatchTime(&batch_record_start, &batch_record_end);
|
DetectPerBatchTime(&batch_record_start, &batch_record_end);
|
||||||
ProfilingRecorder(isProfilingEnable, profiling_node, send_batch, tdt_cost, &batch_start_time, &end_time,
|
ProfilingRecorder(is_profiling_enable, profiling_node, send_batch, tdt_cost, &batch_start_time, &end_time,
|
||||||
connector_capacity, connector_size);
|
connector_capacity, connector_size);
|
||||||
#endif
|
#endif
|
||||||
send_batch++;
|
send_batch++;
|
||||||
|
@ -244,7 +244,7 @@ Status DeviceQueueOp::SendDataToAscend() {
|
||||||
LimitSendingBatches(send_batch, &sending_num, cfg);
|
LimitSendingBatches(send_batch, &sending_num, cfg);
|
||||||
|
|
||||||
#ifndef ENABLE_SECURITY
|
#ifndef ENABLE_SECURITY
|
||||||
if (isProfilingEnable) {
|
if (is_profiling_enable) {
|
||||||
connector_size = ChildOpConnectorSize();
|
connector_size = ChildOpConnectorSize();
|
||||||
connector_capacity = ChildOpConnectorCapacity();
|
connector_capacity = ChildOpConnectorCapacity();
|
||||||
}
|
}
|
||||||
|
@ -252,8 +252,8 @@ Status DeviceQueueOp::SendDataToAscend() {
|
||||||
RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&curr_row));
|
RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&curr_row));
|
||||||
}
|
}
|
||||||
if (curr_row.eoe() && send_epoch_end_) {
|
if (curr_row.eoe() && send_epoch_end_) {
|
||||||
TensorRow currRow;
|
TensorRow dummy_row;
|
||||||
auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost,
|
auto status = tdtInstancePtr->hostPush(dummy_row, true, channel_name_, is_profiling_enable, tdt_cost,
|
||||||
ACL_TENSOR_DATA_END_OF_SEQUENCE);
|
ACL_TENSOR_DATA_END_OF_SEQUENCE);
|
||||||
if (status != Status::OK()) {
|
if (status != Status::OK()) {
|
||||||
if (stop_send_) {
|
if (stop_send_) {
|
||||||
|
@ -273,7 +273,7 @@ Status DeviceQueueOp::SendDataToAscend() {
|
||||||
stop_send_ = true;
|
stop_send_ = true;
|
||||||
}
|
}
|
||||||
#ifndef ENABLE_SECURITY
|
#ifndef ENABLE_SECURITY
|
||||||
if (isProfilingEnable) {
|
if (is_profiling_enable) {
|
||||||
connector_size = ChildOpConnectorSize();
|
connector_size = ChildOpConnectorSize();
|
||||||
connector_capacity = ChildOpConnectorCapacity();
|
connector_capacity = ChildOpConnectorCapacity();
|
||||||
tree_->SetEpochEnd();
|
tree_->SetEpochEnd();
|
||||||
|
@ -311,8 +311,8 @@ void DeviceQueueOp::LimitSendingBatches(int64_t send_batch, int64_t *sending_num
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DeviceQueueOp::SendRowToTdt(TensorRow currRow, bool isProfilingEnable, int32_t *tdt_cost) {
|
Status DeviceQueueOp::SendRowToTdt(TensorRow curr_row, bool is_profiling_enable, int32_t *tdt_cost) {
|
||||||
auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, *tdt_cost);
|
auto status = tdtInstancePtr->hostPush(curr_row, true, channel_name_, is_profiling_enable, *tdt_cost);
|
||||||
if (status != Status::OK()) {
|
if (status != Status::OK()) {
|
||||||
if (stop_send_) {
|
if (stop_send_) {
|
||||||
MS_LOG(INFO) << "stop_send received";
|
MS_LOG(INFO) << "stop_send received";
|
||||||
|
@ -328,7 +328,7 @@ Status DeviceQueueOp::SendRowToTdt(TensorRow currRow, bool isProfilingEnable, in
|
||||||
}
|
}
|
||||||
if (create_data_info_queue_) {
|
if (create_data_info_queue_) {
|
||||||
DATA_INFO data_info;
|
DATA_INFO data_info;
|
||||||
(void)std::transform(currRow.begin(), currRow.end(), std::back_inserter(data_info),
|
(void)std::transform(curr_row.begin(), curr_row.end(), std::back_inserter(data_info),
|
||||||
[](const std::shared_ptr<Tensor> &ts) { return std::make_pair(ts->type(), ts->shape()); });
|
[](const std::shared_ptr<Tensor> &ts) { return std::make_pair(ts->type(), ts->shape()); });
|
||||||
RETURN_IF_NOT_OK(data_info_queue_ptr_->Add(data_info));
|
RETURN_IF_NOT_OK(data_info_queue_ptr_->Add(data_info));
|
||||||
}
|
}
|
||||||
|
@ -376,6 +376,7 @@ Status DeviceQueueOp::SetThreadDevice() {
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DeviceQueueOp::LaunchParallelCopyThread() {
|
Status DeviceQueueOp::LaunchParallelCopyThread() {
|
||||||
|
RETURN_UNEXPECTED_IF_NULL(tree_);
|
||||||
// Every thread use cuda api should SetThreadDevice
|
// Every thread use cuda api should SetThreadDevice
|
||||||
RETURN_IF_NOT_OK(SetThreadDevice());
|
RETURN_IF_NOT_OK(SetThreadDevice());
|
||||||
// CircularPool may not safe under multi-threads scenario, so one worker with one pool
|
// CircularPool may not safe under multi-threads scenario, so one worker with one pool
|
||||||
|
@ -396,6 +397,7 @@ Status DeviceQueueOp::LaunchParallelCopyThread() {
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DeviceQueueOp::PushDataToGPU() {
|
Status DeviceQueueOp::PushDataToGPU() {
|
||||||
|
RETURN_UNEXPECTED_IF_NULL(tree_);
|
||||||
// Every thread use cuda api should SetThreadDevice
|
// Every thread use cuda api should SetThreadDevice
|
||||||
RETURN_IF_NOT_OK(SetThreadDevice());
|
RETURN_IF_NOT_OK(SetThreadDevice());
|
||||||
TaskManager::FindMe()->Post();
|
TaskManager::FindMe()->Post();
|
||||||
|
@ -405,8 +407,8 @@ Status DeviceQueueOp::PushDataToGPU() {
|
||||||
int32_t connector_size = 0;
|
int32_t connector_size = 0;
|
||||||
int32_t connector_capacity = 0;
|
int32_t connector_capacity = 0;
|
||||||
std::shared_ptr<DeviceQueueTracing> profiling_node;
|
std::shared_ptr<DeviceQueueTracing> profiling_node;
|
||||||
bool isProfilingEnable = tree_->GetProfilingManager()->IsProfilingEnable();
|
bool is_profiling_enable = tree_->GetProfilingManager()->IsProfilingEnable();
|
||||||
if (isProfilingEnable) {
|
if (is_profiling_enable) {
|
||||||
std::shared_ptr<Tracing> node;
|
std::shared_ptr<Tracing> node;
|
||||||
RETURN_IF_NOT_OK(tree_->GetProfilingManager()->GetTracingNode(kDeviceQueueTracingName, &node));
|
RETURN_IF_NOT_OK(tree_->GetProfilingManager()->GetTracingNode(kDeviceQueueTracingName, &node));
|
||||||
profiling_node = std::dynamic_pointer_cast<DeviceQueueTracing>(node);
|
profiling_node = std::dynamic_pointer_cast<DeviceQueueTracing>(node);
|
||||||
|
@ -452,7 +454,7 @@ Status DeviceQueueOp::PushDataToGPU() {
|
||||||
RETURN_IF_NOT_OK(RetryPushData(handle, items));
|
RETURN_IF_NOT_OK(RetryPushData(handle, items));
|
||||||
send_batch++;
|
send_batch++;
|
||||||
#ifndef ENABLE_SECURITY
|
#ifndef ENABLE_SECURITY
|
||||||
if (isProfilingEnable) {
|
if (is_profiling_enable) {
|
||||||
uint64_t end_time = ProfilingTime::GetCurMilliSecond();
|
uint64_t end_time = ProfilingTime::GetCurMilliSecond();
|
||||||
// record push data time
|
// record push data time
|
||||||
profiling_node->Record(TIME, TDT_PUSH_TIME, send_batch, push_cost, end_time);
|
profiling_node->Record(TIME, TDT_PUSH_TIME, send_batch, push_cost, end_time);
|
||||||
|
@ -502,7 +504,7 @@ Status DeviceQueueOp::PushDataToGPU() {
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DeviceQueueOp::RetryPushData(unsigned int handle, const std::vector<DataItemGpu> &items) {
|
Status DeviceQueueOp::RetryPushData(unsigned int handle, const std::vector<DataItemGpu> &items) {
|
||||||
bool flagLog = false;
|
bool flag_log = false;
|
||||||
while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) {
|
while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) {
|
||||||
BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME);
|
BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME);
|
||||||
if (ret) {
|
if (ret) {
|
||||||
|
@ -511,9 +513,9 @@ Status DeviceQueueOp::RetryPushData(unsigned int handle, const std::vector<DataI
|
||||||
"Invalid data, check the output of dataset with creating iterator and print data item.");
|
"Invalid data, check the output of dataset with creating iterator and print data item.");
|
||||||
} else {
|
} else {
|
||||||
if (!stop_send_) {
|
if (!stop_send_) {
|
||||||
if (!flagLog) {
|
if (!flag_log) {
|
||||||
MS_LOG(DEBUG) << "Retry pushing data...";
|
MS_LOG(DEBUG) << "Retry pushing data...";
|
||||||
flagLog = true;
|
flag_log = true;
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -669,11 +671,11 @@ void DeviceQueueOp::Print(std::ostream &out, bool show_all) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifndef ENABLE_SECURITY
|
#ifndef ENABLE_SECURITY
|
||||||
void DeviceQueueOp::ProfilingRecorder(bool isProfilingEnable, std::shared_ptr<DeviceQueueTracing> profiling_node,
|
void DeviceQueueOp::ProfilingRecorder(bool is_profiling_enable, std::shared_ptr<DeviceQueueTracing> profiling_node,
|
||||||
int64_t send_batch, int32_t tdt_cost, uint64_t *batch_start_time,
|
int64_t send_batch, int32_t tdt_cost, uint64_t *batch_start_time,
|
||||||
uint64_t *end_time, int32_t connector_capacity, int32_t connector_size) {
|
uint64_t *end_time, int32_t connector_capacity, int32_t connector_size) {
|
||||||
// Record the pipeline profiling info
|
// Record the pipeline profiling info
|
||||||
if (isProfilingEnable) {
|
if (is_profiling_enable) {
|
||||||
*end_time = ProfilingTime::GetCurMilliSecond();
|
*end_time = ProfilingTime::GetCurMilliSecond();
|
||||||
// record push tdt time
|
// record push tdt time
|
||||||
profiling_node->Record(TIME, TDT_PUSH_TIME, send_batch + 1, tdt_cost, *end_time);
|
profiling_node->Record(TIME, TDT_PUSH_TIME, send_batch + 1, tdt_cost, *end_time);
|
||||||
|
|
|
@@ -76,7 +76,7 @@ class DeviceQueueOp : public PipelineOp {

 Status EoeReceived(int32_t worker_id) override;

-const int32_t get_prefetch_size() { return prefetch_size_; }
+const int32_t GetPrefetchSize() { return prefetch_size_; }

 void StopSend() { stop_send_ = true; }

@@ -105,11 +105,11 @@ class DeviceQueueOp : public PipelineOp {
 Status operator()() override;
 #ifndef ENABLE_SECURITY
 // Record the pipeline profiling info
-void ProfilingRecorder(bool isProfilingEnable, std::shared_ptr<DeviceQueueTracing> profiling_node, int64_t send_batch,
+void ProfilingRecorder(bool is_profiling_enable, std::shared_ptr<DeviceQueueTracing> profiling_node,
-int32_t tdt_cost, uint64_t *batch_start_time, uint64_t *end_time, int32_t connector_capacity,
+int64_t send_batch, int32_t tdt_cost, uint64_t *batch_start_time, uint64_t *end_time,
-int32_t connector_size);
+int32_t connector_capacity, int32_t connector_size);
-#endif

+#endif
 // Op name getter
 // @return Name of the current Op
 std::string Name() const override { return kDeviceQueueOp; }
@@ -128,7 +128,7 @@ class DeviceQueueOp : public PipelineOp {
 void WaitContinueSignal() const;
 Status SendDataToAscend();
 void LimitSendingBatches(int64_t send_batch, int64_t *sending_num, std::shared_ptr<ConfigManager> cfg);
-Status SendRowToTdt(TensorRow currRow, bool isProfilingEnable, int32_t *tdt_cost);
+Status SendRowToTdt(TensorRow curr_row, bool is_profiling_enable, int32_t *tdt_cost);
 bool ascend_keep_waiting_;
 #endif

@@ -218,7 +218,7 @@ Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate

 return Status(StatusCode::kSuccess, "FilterOp predicate func call succeed");
 }
-int32_t FilterOp::num_consumers() const { return 1; }
+int32_t FilterOp::NumConsumers() const { return 1; }

 } // namespace dataset
 } // namespace mindspore

@@ -69,7 +69,7 @@ class FilterOp : public ParallelOp {
 // @return Name of the current Op
 std::string Name() const override { return kFilterOp; }

-int32_t num_consumers() const override;
+int32_t NumConsumers() const override;

 private:
 // predicate_func python callable which returns a boolean value.

@@ -46,7 +46,7 @@ MapOp::MapOp(const std::vector<std::string> &in_col_names, const std::vector<std
 }

 // The number of threads consuming data from previous op's output Connector.
-int32_t MapOp::num_consumers() const {
+int32_t MapOp::NumConsumers() const {
 // When Performance Mode is on, there is only one thread consuming from the previous Connector.
 return 1;
 }
@@ -144,7 +144,7 @@ Status MapOp::operator()() {
 RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));

 while (!new_row.eof()) {
-if (op_current_repeats_ % op_num_repeats_per_epoch() == 0) {
+if (op_current_repeats_ % GetOpNumRepeatsPerEpoch() == 0) {
 RETURN_IF_NOT_OK(callback_manager_.EpochBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
 }
 while (!new_row.eoe()) {
@@ -168,7 +168,7 @@ Status MapOp::operator()() {
 }

 // check whether this is the end of a real epoch (not all eoe signals end of epoch)
-if ((op_current_repeats_ + 1) % op_num_repeats_per_epoch() == 0) {
+if ((op_current_repeats_ + 1) % GetOpNumRepeatsPerEpoch() == 0) {
 RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));

 ep_step = 0;

@@ -101,7 +101,7 @@ class MapOp : public ParallelOp {

 // Getter
 // @return the number of threads consuming data from previous op's output Connector.
-int32_t num_consumers() const override;
+int32_t NumConsumers() const override;

 // Op name getter
 // @return Name of the current Op

@@ -72,11 +72,11 @@ class ParallelOp : public DatasetOp {

 // Getter
 // @return the number of workers
-int32_t num_workers() const override { return num_workers_; }
+int32_t NumWorkers() const override { return num_workers_; }

 // Getter
 // @return the number of threads consuming from the previous Connector
-int32_t num_consumers() const override { return num_workers_; }
+int32_t NumConsumers() const override { return num_workers_; }

 // Getter
 // @return the number of producers pushing to the output Connector
@@ -84,7 +84,7 @@ class ParallelOp : public DatasetOp {
 // when a worker connector is set up. In that case, there are n workers, and a single master
 // such that only 1 thread is a producer rather than the n workers.
 // @return the number of producers
-int32_t num_producers() const override { return num_producers_; }
+int32_t NumProducers() const override { return num_producers_; }

 // Register the internal worker connectors.
 // @return Status

@@ -55,15 +55,15 @@ class PipelineOp : public DatasetOp {

 // Getter
 // @return The number of workers inside this op. Pipeline ops only have a single worker.
-int32_t num_workers() const override { return 1; }
+int32_t NumWorkers() const override { return 1; }

 // Getter
 // @return the number of threads consuming from the previous Connector
-int32_t num_consumers() const override { return 1; }
+int32_t NumConsumers() const override { return 1; }

 // Getter
 // @return The number of threads that push data to the output connector
-int32_t num_producers() const override { return 1; }
+int32_t NumProducers() const override { return 1; }

 protected:
 // *******************************************************************************

@@ -76,7 +76,7 @@ TensorRow ProjectOp::Project(const TensorRow &row) {
 // ensure that it is not called by mistake (it will generate an error).
 Status ProjectOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. ProjectOp is an inlined operator."); }

-int32_t ProjectOp::num_consumers() const {
+int32_t ProjectOp::NumConsumers() const {
 if (parent_.empty()) {
 MS_LOG(DEBUG) << "Project operator, no parent node, assuming it's the root and returning 1.";
 return 1;
@@ -84,16 +84,16 @@ int32_t ProjectOp::num_consumers() const {
 MS_LOG(DEBUG) << "Project operator, pointer to the first parent is null. Returning 0.";
 return 0;
 } else {
-return parent_[0]->num_consumers();
+return parent_[0]->NumConsumers();
 }
 }

-int32_t ProjectOp::num_producers() const {
+int32_t ProjectOp::NumProducers() const {
 if (child_.empty() || child_[0] == nullptr) {
 MS_LOG(DEBUG) << "Project operator, pointer to child node is null. Returning 0.";
 return 0;
 } else {
-return child_[0]->num_producers();
+return child_[0]->NumProducers();
 }
 }

@@ -63,11 +63,11 @@ class ProjectOp : public PipelineOp {

 // Base-class override. Return the number of workers in the first parent.
 // @param workerId - The worker id
-int32_t num_consumers() const override;
+int32_t NumConsumers() const override;

 // Base-class override. Return the number of producers in the first child.
 // @param workerId - The worker id
-int32_t num_producers() const override;
+int32_t NumProducers() const override;

 // Base-class override for special eoe handler.
 // Inline operators must override this because there is no connector to push eoe onto.

@@ -130,7 +130,7 @@ void RenameOp::Print(std::ostream &out, // In: The output stream to print t
 }
 }

-int32_t RenameOp::num_consumers() const {
+int32_t RenameOp::NumConsumers() const {
 if (parent_.empty()) {
 MS_LOG(DEBUG) << "Rename operator, no parent node, assuming it's the root and returning 1.";
 return 1;
@@ -138,16 +138,16 @@ int32_t RenameOp::num_consumers() const {
 MS_LOG(DEBUG) << "Rename operator, pointer to the first parent is null. Returning 0.";
 return 0;
 } else {
-return parent_[0]->num_consumers();
+return parent_[0]->NumConsumers();
 }
 }

-int32_t RenameOp::num_producers() const {
+int32_t RenameOp::NumProducers() const {
 if (child_.empty() || child_[0] == nullptr) {
 MS_LOG(DEBUG) << "Rename operator, pointer to child node is null. Returning 0.";
 return 0;
 } else {
-return child_[0]->num_producers();
+return child_[0]->NumProducers();
 }
 }
 } // namespace dataset

@@ -63,8 +63,8 @@ class RenameOp : public PipelineOp {
 // @param row - output pointer to the projected row.
 // @param worker_id - The worker id
 Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
-int32_t num_consumers() const override;
+int32_t NumConsumers() const override;
-int32_t num_producers() const override;
+int32_t NumProducers() const override;

 protected:
 // Rename core functionality

@@ -115,7 +115,7 @@ Status RepeatOp::EofReceived(int32_t worker_id) {
 return Status::OK();
 }

-int32_t RepeatOp::num_consumers() const {
+int32_t RepeatOp::NumConsumers() const {
 if (parent_.empty()) {
 MS_LOG(DEBUG) << "Repeat operator, no parent node, assuming it's root and returning 1.";
 return 1;
@@ -123,7 +123,7 @@ int32_t RepeatOp::num_consumers() const {
 MS_LOG(DEBUG) << "Repeat operator, pointer to the first parent is null. Returning 0.";
 return 0;
 } else {
-return parent_[0]->num_consumers();
+return parent_[0]->NumConsumers();
 }
 }

@@ -140,12 +140,12 @@ Status RepeatOp::Reset() {
 return Status::OK();
 }

-int32_t RepeatOp::num_producers() const {
+int32_t RepeatOp::NumProducers() const {
 if (child_.empty() || child_[0] == nullptr) {
 MS_LOG(DEBUG) << "Repeat operator, pointer to child node is null. Returning 0.";
 return 0;
 } else {
-return child_[0]->num_producers();
+return child_[0]->NumProducers();
 }
 }

@@ -79,11 +79,11 @@ class RepeatOp : public PipelineOp {

 // Base-class override. Return the number of workers in the first parent.
 // @param workerId - The worker id
-int32_t num_consumers() const override;
+int32_t NumConsumers() const override;

 // Base-class override. Return the number of producers in the first child.
 // @param workerId - The worker id
-int32_t num_producers() const override;
+int32_t NumProducers() const override;

 // Op name getter
 // @return Name of the current Op

@@ -66,7 +66,7 @@ Status SkipOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe)
 return Status::OK();
 }

-int32_t SkipOp::num_consumers() const {
+int32_t SkipOp::NumConsumers() const {
 if (parent_.empty()) {
 MS_LOG(DEBUG) << "Return operator, no parent node, assuming it's the root and returning 1.";
 return 1;
@@ -74,16 +74,16 @@ int32_t SkipOp::num_consumers() const {
 MS_LOG(DEBUG) << "Return operator, pointer to the first parent is null. Returning 0.";
 return 0;
 } else {
-return parent_[0]->num_consumers();
+return parent_[0]->NumConsumers();
 }
 }

-int32_t SkipOp::num_producers() const {
+int32_t SkipOp::NumProducers() const {
 if (child_.empty() || child_[0] == nullptr) {
 MS_LOG(DEBUG) << "Return operator, pointer to child node is null. Returning 0.";
 return 0;
 } else {
-return child_[0]->num_producers();
+return child_[0]->NumProducers();
 }
 }

@@ -49,8 +49,8 @@ class SkipOp : public PipelineOp {
 // @return Name of the current Op
 std::string Name() const override { return kSkipOp; }
 Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
-int32_t num_consumers() const override;
+int32_t NumConsumers() const override;
-int32_t num_producers() const override;
+int32_t NumProducers() const override;

 private:
 int32_t max_skips_; // The number of skips that the user requested

@@ -177,9 +177,9 @@ bool CelebAOp::CheckDatasetTypeValid() {

 Status CelebAOp::ParseImageAttrInfo() {
 std::vector<std::string> image_infos;
-bool needMoreData = true;
+bool need_more_data = true;
 RETURN_IF_NOT_OK(attr_info_queue_->PopFront(&image_infos));
-while (!image_infos.empty() && needMoreData) {
+while (!image_infos.empty() && need_more_data) {
 for (uint32_t index = 0; index < image_infos.size(); index++) {
 std::string image_info = image_infos[index];
 std::vector<std::string> split = Split(image_info);

@@ -201,13 +201,13 @@ Status CifarOp::ReadCifar100BlockData() {
 }

 Status CifarOp::GetCifarFiles() {
-const std::string kExtension = ".bin";
+const std::string extension = ".bin";
 Path dir_path(folder_path_);
 auto dirIt = Path::DirIterator::OpenDirectory(&dir_path);
 if (dirIt) {
 while (dirIt->HasNext()) {
 Path file = dirIt->Next();
-if (file.Extension() == kExtension) {
+if (file.Extension() == extension) {
 cifar_files_.push_back(file.ToString());
 }
 }

@@ -101,7 +101,7 @@ class CifarOp : public MappableLeafOp {
 Status ParseCifarData();

 /// Method derived from RandomAccess Op, enable Sampler to get all ids for each class
-/// @param (std::map<uint64_t, std::vector<uint64_t >> * map - key label, val all ids for this class
+/// @param (std::map<uint32_t, std::vector<uint32_t >> *cls_ids - val all ids for this class
 /// @return Status The status code returned
 Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;

@@ -120,20 +120,20 @@ Status ClueOp::LoadFile(const std::string &file, int64_t start_offset, int64_t e
 RETURN_STATUS_UNEXPECTED("Invalid file, failed to parse JSON file: " + file);
 }
 int cols_count = cols_to_keyword_.size();
-TensorRow tRow(cols_count, nullptr);
+TensorRow t_row(cols_count, nullptr);
 // Add file path info
 std::vector<std::string> file_path(cols_count, file);
-tRow.setPath(file_path);
+t_row.setPath(file_path);
 int cout = 0;
 for (auto &p : cols_to_keyword_) {
 std::shared_ptr<Tensor> tensor;
 RETURN_IF_NOT_OK(GetValue(js, p.second, &tensor));
-tRow[cout] = std::move(tensor);
+t_row[cout] = std::move(tensor);
 cout++;
 }

 rows_total++;
-RETURN_IF_NOT_OK(jagged_rows_connector_->Add(worker_id, std::move(tRow)));
+RETURN_IF_NOT_OK(jagged_rows_connector_->Add(worker_id, std::move(t_row)));
 }

 return Status::OK();

@@ -215,7 +215,7 @@ Status GeneratorOp::operator()() {
 // Waiting for repeatOp to start new epoch
 // If Reset() is called first by repeat op, this wait() will return right away.
 // If Reset() is not called yet, this wait() will block until reset.
-if (this->op_total_repeats() < 0) {
+if (this->GetOpTotalRepeats() < 0) {
 RETURN_IF_NOT_OK(wp_.Wait());
 // Clear the status of the wait post
 wp_.Clear();
@@ -235,7 +235,7 @@ Status GeneratorOp::Reset() {
 MS_LOG(DEBUG) << Name() << " performing a self-reset.";
 // Create new generator object
 RETURN_IF_NOT_OK(CreateGeneratorObject());
-if (this->op_total_repeats() < 0) {
+if (this->GetOpTotalRepeats() < 0) {
 // Wake up master thread
 wp_.Set();
 }

@@ -84,20 +84,20 @@ Status ImageFolderOp::PrescanMasterEntry(const std::string &filedir) {

 // Load 1 TensorRow (image,label) using 1 ImageLabelPair. 1 function call produces 1 TensorTow
 Status ImageFolderOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
-ImageLabelPair pairPtr = image_label_pairs_[row_id];
+ImageLabelPair pair_ptr = image_label_pairs_[row_id];
 std::shared_ptr<Tensor> image, label;
-RETURN_IF_NOT_OK(Tensor::CreateScalar(pairPtr->second, &label));
+RETURN_IF_NOT_OK(Tensor::CreateScalar(pair_ptr->second, &label));
-RETURN_IF_NOT_OK(Tensor::CreateFromFile(folder_path_ + (pairPtr->first), &image));
+RETURN_IF_NOT_OK(Tensor::CreateFromFile(folder_path_ + (pair_ptr->first), &image));

 if (decode_ == true) {
 Status rc = Decode(image, &image);
 if (rc.IsError()) {
-std::string err = "Invalid data, failed to decode image: " + folder_path_ + (pairPtr->first);
+std::string err = "Invalid data, failed to decode image: " + folder_path_ + (pair_ptr->first);
 RETURN_STATUS_UNEXPECTED(err);
 }
 }
 (*trow) = TensorRow(row_id, {std::move(image), std::move(label)});
-trow->setPath({folder_path_ + (pairPtr->first), std::string("")});
+trow->setPath({folder_path_ + (pair_ptr->first), std::string("")});
 return Status::OK();
 }

@@ -57,11 +57,14 @@ class ImageFolderOp : public MappableLeafOp {
 // @param int32_t num_wkrs - Num of workers reading images in parallel
 // @param std::string - dir directory of ImageNetFolder
 // @param int32_t queue_size - connector queue size
-// @param std::set<std::string> exts - set of file extensions to read, if empty, read everything under the dir
+// @param bool recursive - read recursively
-// @param td::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
+// @param bool do_decode - decode the images after reading
+// @param std::set<std::string> &exts - set of file extensions to read, if empty, read everything under the dir
+// @param std::map<std::string, int32_t> &map- map of folder name and class id
+// @param std::unique_ptr<dataschema> data_schema - schema of data
 ImageFolderOp(int32_t num_wkrs, std::string file_dir, int32_t queue_size, bool recursive, bool do_decode,
 const std::set<std::string> &exts, const std::map<std::string, int32_t> &map,
-std::unique_ptr<DataSchema>, std::shared_ptr<SamplerRT> sampler);
+std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);

 /// Destructor.
 ~ImageFolderOp() = default;

@@ -223,7 +223,9 @@ Status MnistOp::WalkAllFiles() {
 const std::string train_prefix = "train";
 const std::string test_prefix = "t10k";

-Path dir(folder_path_);
+std::string real_path{""};
+RETURN_IF_NOT_OK(Path::RealPath(folder_path_, real_path));
+Path dir(real_path);
 auto dir_it = Path::DirIterator::OpenDirectory(&dir);
 std::string prefix; // empty string, used to match usage = "" (default) or usage == "all"
 if (usage_ == "train" || usage_ == "test") prefix = (usage_ == "test" ? test_prefix : train_prefix);

@@ -68,15 +68,15 @@ void RandomDataOp::Print(std::ostream &out, bool show_all) const {

 // Helper function to produce a default/random schema if one didn't exist
 void RandomDataOp::GenerateSchema() {
-const int32_t kTypeOffset = 2;
+const int32_t type_offset = 2;
 // To randomly create a schema, we need to choose:
 // a) how many columns
 // b) the type of each column
 // c) the shape of each column (number of dimensions i.e. rank)
 // d) the shape of each column (dimension values)
 data_schema_ = std::make_unique<DataSchema>();
-std::unique_ptr<TensorShape> newShape;
+std::unique_ptr<TensorShape> new_shape;
-std::unique_ptr<ColDescriptor> newCol;
+std::unique_ptr<ColDescriptor> new_col;

 // Loop over the number of chosen columns
 int32_t numColumns = GenRandomInt(1, kMaxNumColumns);
@@ -84,23 +84,26 @@ void RandomDataOp::GenerateSchema() {
 // For each column:
 // - choose a datatype
 // - generate a shape that randomly chooses the number of dimensions and the dimension values.
-DataType::Type newType = static_cast<DataType::Type>(GenRandomInt(1, DataType::NUM_OF_TYPES - kTypeOffset));
+DataType::Type newType = static_cast<DataType::Type>(GenRandomInt(1, DataType::NUM_OF_TYPES - type_offset));
 int32_t rank = GenRandomInt(1, kMaxRank);
 std::vector<dsize_t> dims;
 for (int32_t d = 0; d < rank; d++) {
 // 0 is not a valid dimension value. however, we can support "*" or unknown, so map the random
 // 0 value to the unknown attribute if 0 is chosen
 dsize_t dim_value = static_cast<dsize_t>(GenRandomInt(0, kMaxDimValue));
-if (dim_value == 0) dim_value = TensorShape::kDimUnknown;
+if (dim_value == 0) {
+dim_value = TensorShape::kDimUnknown;
+}
 dims.push_back(dim_value);
 }
-newShape = std::make_unique<TensorShape>(dims);
+new_shape = std::make_unique<TensorShape>(dims);

 // Create the column descriptor
-std::string colName = "c" + std::to_string(i);
+std::string col_name = "c" + std::to_string(i);
-newCol = std::make_unique<ColDescriptor>(colName, DataType(newType), TensorImpl::kFlexible, rank, newShape.get());
+new_col =
+std::make_unique<ColDescriptor>(col_name, DataType(newType), TensorImpl::kFlexible, rank, new_shape.get());

-Status rc = data_schema_->AddColumn(*newCol);
+Status rc = data_schema_->AddColumn(*new_col);
 if (rc.IsError()) MS_LOG(ERROR) << "Failed to generate a schema. Message:" << rc;
 }
 }
@@ -125,6 +128,9 @@ Status RandomDataOp::operator()() {
 DatasetOp::CreateConnector(num_producers_, num_workers_);
 }

+if (num_workers_ == 0) {
+RETURN_STATUS_UNEXPECTED("Invalid data, num_workers_ is zero.");
+}
 // Assign the number of rows to each worker in a round robin fashion.
 worker_max_rows_.reserve(num_workers_);
 worker_rows_packed_.reserve(num_workers_);

@@ -148,10 +148,10 @@ Status VOCOp::ParseImageIds() {
 Status VOCOp::ParseAnnotationIds() {
 std::vector<std::string> new_image_ids;
 for (auto id : image_ids_) {
-const std::string kAnnotationName =
+const std::string annotation_name =
 folder_path_ + std::string(kAnnotationsFolder) + id + std::string(kAnnotationExtension);
-RETURN_IF_NOT_OK(ParseAnnotationBbox(kAnnotationName));
+RETURN_IF_NOT_OK(ParseAnnotationBbox(annotation_name));
-if (annotation_map_.find(kAnnotationName) != annotation_map_.end()) {
+if (annotation_map_.find(annotation_name) != annotation_map_.end()) {
 new_image_ids.push_back(id);
 }
 }
@@ -179,7 +179,9 @@ void VOCOp::ParseNodeValue(XMLElement *bbox_node, const char *name, float *value
 *value = 0.0;
 if (bbox_node != nullptr) {
 XMLElement *node = bbox_node->FirstChildElement(name);
-if (node != nullptr) *value = node->FloatText();
+if (node != nullptr) {
+*value = node->FloatText();
+}
 }
 }

@@ -69,7 +69,7 @@ Status TakeOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe)
 return Status::OK();
 }

-int32_t TakeOp::num_consumers() const {
+int32_t TakeOp::NumConsumers() const {
 if (parent_.empty()) {
 MS_LOG(DEBUG) << "Return operator, no parent node, assuming it's the root and returning 1.";
 return 1;
@@ -77,16 +77,16 @@ int32_t TakeOp::num_consumers() const {
 MS_LOG(DEBUG) << "Return operator, pointer to the first parent is null. Returning 0.";
 return 0;
 } else {
-return parent_[0]->num_consumers();
+return parent_[0]->NumConsumers();
 }
 }

-int32_t TakeOp::num_producers() const {
+int32_t TakeOp::NumProducers() const {
 if (child_.empty() || child_[0] == nullptr) {
 MS_LOG(DEBUG) << "Return operator, pointer to child node is null. Returning 0.";
 return 0;
 } else {
-return child_[0]->num_producers();
+return child_[0]->NumProducers();
 }
 }
 } // namespace dataset

@@ -59,8 +59,8 @@ class TakeOp : public PipelineOp {
 std::string Name() const override { return kTakeOp; }

 Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
-int32_t num_consumers() const override;
+int32_t NumConsumers() const override;
-int32_t num_producers() const override;
+int32_t NumProducers() const override;

 private:
 int32_t max_takes_; // The number of takes that the user requested

@@ -132,7 +132,7 @@ Status ZipOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) {
 return Status::OK();
 }

-int32_t ZipOp::num_consumers() const {
+int32_t ZipOp::NumConsumers() const {
 if (parent_.empty()) {
 MS_LOG(DEBUG) << "Return operator, no parent node, assuming it's the root and returning 1.";
 return 1;
@@ -140,16 +140,16 @@ int32_t ZipOp::num_consumers() const {
 MS_LOG(DEBUG) << "Return operator, pointer to the first parent is null. Returning 0.";
 return 0;
 } else {
-return parent_[0]->num_consumers();
+return parent_[0]->NumConsumers();
 }
 }

-int32_t ZipOp::num_producers() const {
+int32_t ZipOp::NumProducers() const {
 if (child_.empty() || child_[0] == nullptr) {
 MS_LOG(DEBUG) << "Return operator, pointer to child node is null. Returning 0.";
 return 0;
 } else {
-return child_[0]->num_producers();
+return child_[0]->NumProducers();
 }
 }
 } // namespace dataset

@@ -65,8 +65,8 @@ class ZipOp : public PipelineOp {
 std::string Name() const override { return kZipOp; }

 Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
-int32_t num_consumers() const override;
+int32_t NumConsumers() const override;
-int32_t num_producers() const override;
+int32_t NumProducers() const override;

 private:
 // Special handle case where an empty row has been received from child iterator

@@ -85,7 +85,7 @@ Status ExecutionTree::AssociateNode(const std::shared_ptr<DatasetOp> &op) {
 tree_state_ = kDeTStateBuilding;

 // Assign an id to the operator
-op->set_id(id_count_);
+op->SetId(id_count_);
 id_count_++;

 // Assign our tree into the op so that each op has a link back to the tree

@@ -104,8 +104,8 @@ Status BatchNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)

 auto op = std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
 in_col_names_, out_col_names_, batch_size_func_, batch_map_func_, pad_map_);
-op->set_total_repeats(GetTotalRepeats());
+op->SetTotalRepeats(GetTotalRepeats());
-op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
 node_ops->push_back(op);
 #else
 node_ops->push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,

@@ -87,8 +87,8 @@ Status BucketBatchByLengthNode::Build(std::vector<std::shared_ptr<DatasetOp>> *c
 auto op = std::make_shared<BucketBatchByLengthOp>(column_names_, bucket_boundaries_, bucket_batch_sizes_,
 element_length_function_, pad_info_, pad_to_bucket_boundary_,
 drop_remainder_, connector_que_size_);
-op->set_total_repeats(GetTotalRepeats());
+op->SetTotalRepeats(GetTotalRepeats());
-op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
 node_ops->push_back(op);
 if (bucket_boundaries_[0] == 0) {
 bucket_boundaries_.erase(bucket_boundaries_.begin());

@@ -56,8 +56,8 @@ void BuildSentenceVocabNode::Print(std::ostream &out) const {
 Status BuildSentenceVocabNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
 auto op = std::make_shared<BuildSentencePieceVocabOp>(vocab_, col_names_, vocab_size_, character_coverage_,
 model_type_, params_, connector_que_size_);
-op->set_total_repeats(GetTotalRepeats());
+op->SetTotalRepeats(GetTotalRepeats());
-op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
 node_ops->push_back(op);
 return Status::OK();
 }

@@ -54,8 +54,8 @@ Status BuildVocabNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node
 std::shared_ptr<BuildVocabOp> build_vocab_op;
 build_vocab_op = std::make_shared<BuildVocabOp>(vocab_, columns_, freq_range_, top_k_, special_tokens_,
 special_first_, num_workers_, connector_que_size_);
-build_vocab_op->set_total_repeats(GetTotalRepeats());
+build_vocab_op->SetTotalRepeats(GetTotalRepeats());
-build_vocab_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+build_vocab_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
 node_ops->push_back(build_vocab_op);
 return Status::OK();
 }

@@ -51,8 +51,8 @@ Status CacheLookupNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops)
 "Internal error. Attempt to create a cache lookup node without cache client.");
 RETURN_IF_NOT_OK(cache_->Build());
 RETURN_IF_NOT_OK(cache_->CreateCacheLookupOp(num_workers_, connector_que_size_, sampler_, &lookup_op_));
-lookup_op_->set_total_repeats(GetTotalRepeats());
+lookup_op_->SetTotalRepeats(GetTotalRepeats());
-lookup_op_->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+lookup_op_->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
 node_ops->push_back(lookup_op_);
 return Status::OK();
 }

@@ -48,8 +48,8 @@ Status CacheMergeNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops)
 RETURN_IF_NOT_OK(cache_->Build());
 std::shared_ptr<DatasetOp> merge_op = nullptr;
 RETURN_IF_NOT_OK(cache_->CreateCacheMergeOp(num_workers_, connector_que_size_, &merge_op));
-merge_op->set_total_repeats(GetTotalRepeats());
+merge_op->SetTotalRepeats(GetTotalRepeats());
-merge_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+merge_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
 node_ops->push_back(merge_op);
 return Status::OK();
 }

@@ -52,8 +52,8 @@ Status CacheNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
 RETURN_IF_NOT_OK(cache_->Build());
 std::shared_ptr<DatasetOp> cache_op = nullptr;
 RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, connector_que_size_, sampler_, &cache_op));
-cache_op->set_total_repeats(GetTotalRepeats());
+cache_op->SetTotalRepeats(GetTotalRepeats());
-cache_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+cache_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
 node_ops->push_back(cache_op);
 return Status::OK();
 }

@@ -133,8 +133,8 @@ Status ConcatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
 RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
 op = std::make_shared<ConcatOp>(sampler_rt, children_flag_and_nums_, children_start_end_index_);
 }
-op->set_total_repeats(GetTotalRepeats());
+op->SetTotalRepeats(GetTotalRepeats());
-op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
 node_ops->push_back(op);

 return Status::OK();

@@ -257,7 +257,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
 void HasCacheAbove() { descendant_of_cache_ = true; }

 /// \brief Getter of the number of workers
-int32_t num_workers() { return num_workers_; }
+int32_t NumWorkers() { return num_workers_; }

 /// \brief Getter of dataset cache
 std::shared_ptr<DatasetCache> GetDatasetCache() { return cache_; }

@@ -46,8 +46,8 @@ void EpochCtrlNode::Print(std::ostream &out) const {
 // Function to build the EpochCtrlOp
 Status EpochCtrlNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
 auto new_op_ = std::make_shared<EpochCtrlOp>(repeat_count_);
-new_op_->set_total_repeats(GetTotalRepeats());
+new_op_->SetTotalRepeats(GetTotalRepeats());
-new_op_->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+new_op_->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
 node_ops->push_back(new_op_);
 op_ = new_op_;
 return Status::OK();

@@ -45,8 +45,8 @@ void FilterNode::Print(std::ostream &out) const {

 Status FilterNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
 auto op = std::make_shared<FilterOp>(input_columns_, num_workers_, connector_que_size_, predicate_);
-op->set_total_repeats(GetTotalRepeats());
+op->SetTotalRepeats(GetTotalRepeats());
-op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
 node_ops->push_back(op);
 return Status::OK();
 }

@@ -89,12 +89,12 @@ Status MapNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {

 if (!project_columns_.empty()) {
 auto project_op = std::make_shared<ProjectOp>(project_columns_);
-project_op->set_total_repeats(GetTotalRepeats());
+project_op->SetTotalRepeats(GetTotalRepeats());
-project_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+project_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
 node_ops->push_back(project_op);
 }
-map_op->set_total_repeats(GetTotalRepeats());
+map_op->SetTotalRepeats(GetTotalRepeats());
-map_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+map_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
 node_ops->push_back(map_op);
 return Status::OK();
 }

@@ -54,8 +54,8 @@ Status ProjectNode::ValidateParams() {

 Status ProjectNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
 auto op = std::make_shared<ProjectOp>(columns_);
-op->set_total_repeats(GetTotalRepeats());
+op->SetTotalRepeats(GetTotalRepeats());
-op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
 node_ops->push_back(op);
 return Status::OK();
 }

@@ -59,8 +59,8 @@ Status RenameNode::ValidateParams() {

 Status RenameNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
 auto op = std::make_shared<RenameOp>(input_columns_, output_columns_);
-op->set_total_repeats(GetTotalRepeats());
+op->SetTotalRepeats(GetTotalRepeats());
-op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
 node_ops->push_back(op);
 return Status::OK();
 }

@@ -40,8 +40,8 @@ void RepeatNode::Print(std::ostream &out) const { out << (Name() + "(count:" + s

 Status RepeatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
 auto new_op = std::make_shared<RepeatOp>(repeat_count_);
-new_op->set_total_repeats(GetTotalRepeats());
+new_op->SetTotalRepeats(GetTotalRepeats());
-new_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+new_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
 node_ops->push_back(new_op);
 op_ = new_op;

@@ -45,8 +45,8 @@ void ShuffleNode::Print(std::ostream &out) const {
 // Function to build the ShuffleOp
 Status ShuffleNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
 auto op = std::make_shared<ShuffleOp>(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_);
-op->set_total_repeats(GetTotalRepeats());
+op->SetTotalRepeats(GetTotalRepeats());
-op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
 node_ops->push_back(op);
 return Status::OK();
 }

@ -40,8 +40,8 @@ void SkipNode::Print(std::ostream &out) const { out << (Name() + "(skip_count:"
|
||||||
// Function to build the SkipOp
|
// Function to build the SkipOp
|
||||||
Status SkipNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
Status SkipNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||||
auto op = std::make_shared<SkipOp>(skip_count_);
|
auto op = std::make_shared<SkipOp>(skip_count_);
|
||||||
op->set_total_repeats(GetTotalRepeats());
|
op->SetTotalRepeats(GetTotalRepeats());
|
||||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||||
node_ops->push_back(op);
|
node_ops->push_back(op);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@@ -79,8 +79,8 @@ Status AlbumNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
 
   auto album_op = std::make_shared<AlbumOp>(num_workers_, dataset_dir_, connector_que_size_, decode_, extensions,
                                             std::move(schema), std::move(sampler_rt));
-  album_op->set_total_repeats(GetTotalRepeats());
-  album_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  album_op->SetTotalRepeats(GetTotalRepeats());
+  album_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(album_op);
   return Status::OK();
 }

@@ -75,8 +75,8 @@ Status CelebANode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
 
   auto celeba_op = std::make_shared<CelebAOp>(num_workers_, dataset_dir_, connector_que_size_, decode_, usage_,
                                               extensions_, std::move(schema), std::move(sampler_rt));
-  celeba_op->set_total_repeats(GetTotalRepeats());
-  celeba_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  celeba_op->SetTotalRepeats(GetTotalRepeats());
+  celeba_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(celeba_op);
 
   return Status::OK();

@@ -71,8 +71,8 @@ Status Cifar100Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
 
   auto cifar_op = std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, dataset_dir_,
                                             connector_que_size_, std::move(schema), std::move(sampler_rt));
-  cifar_op->set_total_repeats(GetTotalRepeats());
-  cifar_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  cifar_op->SetTotalRepeats(GetTotalRepeats());
+  cifar_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(cifar_op);
 
   return Status::OK();

@@ -69,8 +69,8 @@ Status Cifar10Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_op
 
   auto cifar_op = std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, dataset_dir_,
                                             connector_que_size_, std::move(schema), std::move(sampler_rt));
-  cifar_op->set_total_repeats(GetTotalRepeats());
-  cifar_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  cifar_op->SetTotalRepeats(GetTotalRepeats());
+  cifar_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(cifar_op);
 
   return Status::OK();

@@ -88,8 +88,8 @@ Status CityscapesNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node
 
   auto cityscapes_op = std::make_shared<CityscapesOp>(num_workers_, dataset_dir_, usage_, quality_mode_, task_, decode_,
                                                       connector_que_size_, std::move(schema), std::move(sampler_rt));
-  cityscapes_op->set_total_repeats(GetTotalRepeats());
-  cityscapes_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  cityscapes_op->SetTotalRepeats(GetTotalRepeats());
+  cityscapes_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(cityscapes_op);
   return Status::OK();
 }
@@ -202,12 +202,12 @@ Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
     // Add the shuffle op after this op
     RETURN_IF_NOT_OK(
       AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
-    shuffle_op->set_total_repeats(GetTotalRepeats());
-    shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+    shuffle_op->SetTotalRepeats(GetTotalRepeats());
+    shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
     node_ops->push_back(shuffle_op);
   }
-  clue_op->set_total_repeats(GetTotalRepeats());
-  clue_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  clue_op->SetTotalRepeats(GetTotalRepeats());
+  clue_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(clue_op);
 
   return Status::OK();

@@ -142,8 +142,8 @@ Status CocoNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
   std::shared_ptr<CocoOp> op =
     std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, connector_que_size_, decode_,
                              std::move(schema), std::move(sampler_rt), extra_metadata_);
-  op->set_total_repeats(GetTotalRepeats());
-  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  op->SetTotalRepeats(GetTotalRepeats());
+  op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(op);
 
   return Status::OK();

@@ -140,12 +140,12 @@ Status CSVNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
     // Add the shuffle op after this op
     RETURN_IF_NOT_OK(
       AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
-    shuffle_op->set_total_repeats(GetTotalRepeats());
-    shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+    shuffle_op->SetTotalRepeats(GetTotalRepeats());
+    shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
     node_ops->push_back(shuffle_op);
   }
-  csv_op->set_total_repeats(GetTotalRepeats());
-  csv_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  csv_op->SetTotalRepeats(GetTotalRepeats());
+  csv_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(csv_op);
 
   return Status::OK();
@@ -100,8 +100,8 @@ Status DIV2KNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
 
   auto div2k_op = std::make_shared<DIV2KOp>(num_workers_, dataset_dir_, usage_, downgrade_, scale_, decode_,
                                             connector_que_size_, std::move(schema), std::move(sampler_rt));
-  div2k_op->set_total_repeats(GetTotalRepeats());
-  div2k_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  div2k_op->SetTotalRepeats(GetTotalRepeats());
+  div2k_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(div2k_op);
   return Status::OK();
 }

@@ -100,8 +100,8 @@ Status FlickrNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
 
   auto flickr_op = std::make_shared<FlickrOp>(num_workers_, dataset_dir_, annotation_file_, decode_,
                                               connector_que_size_, std::move(schema), std::move(sampler_rt));
-  flickr_op->set_total_repeats(GetTotalRepeats());
-  flickr_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  flickr_op->SetTotalRepeats(GetTotalRepeats());
+  flickr_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(flickr_op);
   return Status::OK();
 }

@@ -97,8 +97,8 @@ Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_
   if (reset_ancestor_ != nullptr) {
     reset_ancestor_->op_->AddToEoeList(op);
   }
-  op->set_total_repeats(GetTotalRepeats());
-  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  op->SetTotalRepeats(GetTotalRepeats());
+  op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(op);
   return Status::OK();
 }

@@ -77,8 +77,8 @@ Status ImageFolderNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const nod
 
   auto op = std::make_shared<ImageFolderOp>(num_workers_, dataset_dir_, connector_que_size_, recursive_, decode_, exts_,
                                             class_indexing_, std::move(schema), std::move(sampler_rt));
-  op->set_total_repeats(GetTotalRepeats());
-  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  op->SetTotalRepeats(GetTotalRepeats());
+  op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(op);
   return Status::OK();
 }

@@ -104,8 +104,8 @@ Status ManifestNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
 
   manifest_op = std::make_shared<ManifestOp>(num_workers_, dataset_file_, connector_que_size_, decode_, class_index_,
                                              std::move(schema), std::move(sampler_rt), usage_);
-  manifest_op->set_total_repeats(GetTotalRepeats());
-  manifest_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  manifest_op->SetTotalRepeats(GetTotalRepeats());
+  manifest_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(manifest_op);
 
   return Status::OK();
@@ -212,8 +212,8 @@ Status MindDataNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
   }
 
   RETURN_IF_NOT_OK(mindrecord_op->Init());
-  mindrecord_op->set_total_repeats(GetTotalRepeats());
-  mindrecord_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  mindrecord_op->SetTotalRepeats(GetTotalRepeats());
+  mindrecord_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(mindrecord_op);
 
   return Status::OK();

@@ -65,8 +65,8 @@ Status MnistNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
 
   auto op = std::make_shared<MnistOp>(usage_, num_workers_, dataset_dir_, connector_que_size_, std::move(schema),
                                       std::move(sampler_rt));
-  op->set_total_repeats(GetTotalRepeats());
-  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  op->SetTotalRepeats(GetTotalRepeats());
+  op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(op);
 
   return Status::OK();

@@ -75,8 +75,8 @@ Status QMnistNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
 
   auto op = std::make_shared<QMnistOp>(dataset_dir_, usage_, compat_, std::move(schema), std::move(sampler_rt),
                                        num_workers_, connector_que_size_);
-  op->set_total_repeats(GetTotalRepeats());
-  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  op->SetTotalRepeats(GetTotalRepeats());
+  op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(op);
 
   return Status::OK();

@@ -110,8 +110,8 @@ Status RandomNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
 
   std::shared_ptr<RandomDataOp> op;
   op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, total_rows_, std::move(data_schema_));
-  op->set_total_repeats(GetTotalRepeats());
-  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  op->SetTotalRepeats(GetTotalRepeats());
+  op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(op);
 
   return Status::OK();

@@ -69,8 +69,8 @@ Status SBUNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
 
   auto op = std::make_shared<SBUOp>(dataset_dir_, decode_, std::move(schema), std::move(sampler_rt), num_workers_,
                                     connector_que_size_);
-  op->set_total_repeats(GetTotalRepeats());
-  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  op->SetTotalRepeats(GetTotalRepeats());
+  op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(op);
 
   return Status::OK();
@@ -108,12 +108,12 @@ Status TextFileNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
     // Add the shuffle op after this op
     RETURN_IF_NOT_OK(
      AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
-    shuffle_op->set_total_repeats(GetTotalRepeats());
-    shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+    shuffle_op->SetTotalRepeats(GetTotalRepeats());
+    shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
     node_ops->push_back(shuffle_op);
   }
-  text_file_op->set_total_repeats(GetTotalRepeats());
-  text_file_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  text_file_op->SetTotalRepeats(GetTotalRepeats());
+  text_file_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   // Add TextFileOp
   node_ops->push_back(text_file_op);

@@ -149,12 +149,12 @@ Status TFRecordNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
 
     // Add the shuffle op after this op
     RETURN_IF_NOT_OK(AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
-    shuffle_op->set_total_repeats(GetTotalRepeats());
-    shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+    shuffle_op->SetTotalRepeats(GetTotalRepeats());
+    shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
     node_ops->push_back(shuffle_op);
   }
-  tf_reader_op->set_total_repeats(GetTotalRepeats());
-  tf_reader_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  tf_reader_op->SetTotalRepeats(GetTotalRepeats());
+  tf_reader_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   // Add TFReaderOp
   node_ops->push_back(tf_reader_op);
   return Status::OK();

@@ -97,12 +97,12 @@ Status USPSNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
 
     // Add the shuffle op after this op
     RETURN_IF_NOT_OK(AddShuffleOp(op->FileNames().size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
-    shuffle_op->set_total_repeats(GetTotalRepeats());
-    shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+    shuffle_op->SetTotalRepeats(GetTotalRepeats());
+    shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
     node_ops->push_back(shuffle_op);
   }
-  op->set_total_repeats(GetTotalRepeats());
-  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  op->SetTotalRepeats(GetTotalRepeats());
+  op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(op);
   return Status::OK();
 }
@@ -125,8 +125,8 @@ Status VOCNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
   std::shared_ptr<VOCOp> voc_op;
   voc_op = std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, connector_que_size_,
                                    decode_, std::move(schema), std::move(sampler_rt), extra_metadata_);
-  voc_op->set_total_repeats(GetTotalRepeats());
-  voc_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  voc_op->SetTotalRepeats(GetTotalRepeats());
+  voc_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(voc_op);
   return Status::OK();
 }

@@ -46,8 +46,8 @@ Status SyncWaitNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
   // The reason for this is because having it otherwise can lead to blocking issues
   // See barrier_op.h for more details
   auto op = std::make_shared<BarrierOp>(connector_que_size_, condition_name_, callback_);
-  op->set_total_repeats(GetTotalRepeats());
-  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  op->SetTotalRepeats(GetTotalRepeats());
+  op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(op);
   return Status::OK();
 }

@@ -41,8 +41,8 @@ void TakeNode::Print(std::ostream &out) const { out << (Name() + "(num_rows:" +
 // Function to build the TakeOp
 Status TakeNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
   auto op = std::make_shared<TakeOp>(take_count_);
-  op->set_total_repeats(GetTotalRepeats());
-  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  op->SetTotalRepeats(GetTotalRepeats());
+  op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(op);
   return Status::OK();
 }

@@ -97,8 +97,8 @@ Status TransferNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
 
   auto op = std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_,
                                             total_batch_, create_data_info_queue_);
-  op->set_total_repeats(GetTotalRepeats());
-  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  op->SetTotalRepeats(GetTotalRepeats());
+  op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(op);
   return Status::OK();
 }

@@ -60,8 +60,8 @@ Status ZipNode::ValidateParams() {
 
 Status ZipNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
   auto op = std::make_shared<ZipOp>();
-  op->set_total_repeats(GetTotalRepeats());
-  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
+  op->SetTotalRepeats(GetTotalRepeats());
+  op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(op);
   return Status::OK();
 }
@@ -68,10 +68,10 @@ Status AutoWorkerPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *con
     int32_t cur_node_num_worker = std::max(std::min(num_workers, cur_node_max), min_num_workers_);
 
     // if the num_worker to set is same as original, skip setting and printing the logs
-    if (cur_node_num_worker == p.first->num_workers()) continue;
+    if (cur_node_num_worker == p.first->NumWorkers()) continue;
     // log the change via warning msg so user can see what the num_worker is being set for which op
     MS_LOG(WARNING) << "AutoNumWorker enabled, num_workers in " << p.first->Name() << " is auto-adjusted from "
-                    << std::to_string(p.first->num_workers()) + " to " + std::to_string(cur_node_num_worker);
+                    << std::to_string(p.first->NumWorkers()) + " to " + std::to_string(cur_node_num_worker);
     p.first->SetNumWorkers(cur_node_num_worker);
   }
   return Status::OK();
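The line kept as context above clamps the suggested worker count into the range [min_num_workers_, cur_node_max] before it is logged and applied. A small standalone illustration of that clamp, with made-up values:

    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    int main() {
      int32_t num_workers = 12;     // hypothetical suggestion computed by the pass
      int32_t cur_node_max = 8;     // hypothetical per-op ceiling
      int32_t min_num_workers = 2;  // hypothetical pass-wide floor
      // Same expression as in AutoWorkerPass::RunOnTree: cap at the ceiling, then raise to the floor.
      int32_t cur_node_num_worker = std::max(std::min(num_workers, cur_node_max), min_num_workers);
      std::cout << cur_node_num_worker << std::endl;  // prints 8: capped by cur_node_max
      return 0;
    }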
@@ -49,7 +49,7 @@ Status DeepCopyPass::Visit(std::shared_ptr<DatasetNode> node, bool *const modifi
   // Temporary fix to set the num_workers to each cloned node.
   // This can be improved by adding a new method in the base class DatasetNode to transfer the properties to
   // the cloned node. Each derived class's Copy() will need to include this method.
-  new_node->SetNumWorkers(node->num_workers());
+  new_node->SetNumWorkers(node->NumWorkers());
   // This method below assumes a DFS walk and from the first child to the last child.
   // Future: A more robust implementation that does not depend on the above assumption.
   RETURN_IF_NOT_OK(parent_->AppendChild(new_node));
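The comment retained above suggests replacing this external fix-up with a method on the DatasetNode base class that copies shared properties onto the clone. A hedged sketch of that idea; the method and field names are hypothetical and not part of this change:

    #include <cstdint>

    // Hypothetical base class showing the helper the comment hints at; names are illustrative only.
    class DatasetNode {
     public:
      int32_t NumWorkers() const { return num_workers_; }
      void SetNumWorkers(int32_t n) { num_workers_ = n; }
      // Copy shared properties onto a cloned node, so DeepCopyPass would no longer
      // have to do it from the outside; each derived Copy() would call this.
      void TransferPropertiesTo(DatasetNode *cloned) const { cloned->SetNumWorkers(NumWorkers()); }

     private:
      int32_t num_workers_ = 1;  // assumed default, not confirmed by this diff
    };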
@@ -42,7 +42,7 @@ json ConnectorSize::ParseOpInfo(const DatasetOp &node, const std::vector<int32_t
   json json_node;
   json_node["op_id"] = node.id();
   json_node["op_type"] = node.Name();
-  json_node["num_workers"] = node.num_workers();
+  json_node["num_workers"] = node.NumWorkers();
   json metrics;
   // DeviceQueueOp is a special op,it is not inlined but its output queue is invalid.
   // So we should not output its queue size.

@@ -78,7 +78,7 @@ json ConnectorThroughput::ParseOpInfo(const DatasetOp &node, const std::vector<d
   json json_node;
   json_node["op_id"] = node.id();
   json_node["op_type"] = node.Name();
-  json_node["num_workers"] = node.num_workers();
+  json_node["num_workers"] = node.NumWorkers();
   json metrics;
   // DeviceQueueOp is a special op,it is not inlined but its output queue is invalid.
   // So we should not output its connector throughput.
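Both profiling hunks assemble the same per-op record; only the num_workers accessor changes. A hedged illustration of that record, built only from the keys visible here and assuming the json type is a nlohmann-style object; the values are invented and the metrics field is omitted because its contents are not shown in these hunks:

    #include <iostream>
    #include <nlohmann/json.hpp>

    int main() {
      nlohmann::json json_node;
      json_node["op_id"] = 3;          // node.id() (made-up value)
      json_node["op_type"] = "MapOp";  // node.Name() (made-up value)
      json_node["num_workers"] = 4;    // node.NumWorkers() after the rename (made-up value)
      std::cout << json_node.dump(2) << std::endl;
      return 0;
    }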
@@ -264,7 +264,7 @@ Status OperatorCpu::Collect(const ExecutionTree *tree) {
   for (auto iter = tree->begin(); iter != tree->end(); ++iter) {
     id_count_++;
     op_name_[iter->id()] = iter->NameWithID();
-    op_parallel_workers_[iter->id()] = iter->num_workers();
+    op_parallel_workers_[iter->id()] = iter->NumWorkers();
   }
 #if defined(USING_LINUX)
   cpu_processor_num_ = get_nprocs_conf();

@@ -221,7 +221,7 @@ Status TreeAdapter::GetNext(TensorRow *row) {
   RETURN_UNEXPECTED_IF_NULL(row);
   row->clear();  // make sure row is empty
 #ifndef ENABLE_SECURITY
-  bool isProfilingEnable = tree_->GetProfilingManager()->IsProfilingEnable();
+  bool is_profiling_enable = tree_->GetProfilingManager()->IsProfilingEnable();
 #endif
 
   // When cur_db_ is a nullptr, it means this is the first call to get_next, launch ExecutionTree

@@ -233,7 +233,7 @@ Status TreeAdapter::GetNext(TensorRow *row) {
   if (row->eoe()) {  // return empty tensor if 1st buf is a ctrl buf (no rows)
     MS_LOG(INFO) << "End of data iteration.";
 #ifndef ENABLE_SECURITY
-    if (isProfilingEnable) {
+    if (is_profiling_enable) {
       tree_->SetEpochEnd();
     }
 #endif
@@ -123,6 +123,9 @@ class Execute {
   /// \brief The function to validate target device setting is valid or not.
   Status ValidateDevice();
 
+  /// \brief Initialize 310 resource
+  Status InitResource(MapTargetDevice device_type, uint32_t device_id);
+
   std::vector<std::shared_ptr<TensorTransform>> transforms_;
   std::vector<std::shared_ptr<TensorOperation>> ops_;
   MapTargetDevice device_type_;
@@ -251,6 +251,14 @@ static bool Conv2DImplement(const LiteMat &src, const LiteMat &kernel, T2 *dst,
   int border_y = static_cast<int>(kernel.height_ / 2);
 
   LiteMat pad_mat;
+
+  if ((border_x > INT_MAX / 2) || (src.width_ > INT_MAX - 2 * border_x)) {
+    return false;
+  }
+  if ((border_y > INT_MAX / 2) || (src.height_ > INT_MAX - 2 * border_y)) {
+    return false;
+  }
+
   pad_mat.Init(src.width_ + 2 * border_x, src.height_ + 2 * border_y, src.channel_, src.data_type_);
 
   if (!Pad(src, pad_mat, border_y, border_y, border_x, border_x, pad_type)) {
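The added checks reject inputs whose padded extent (width or height plus twice the border) would overflow int before pad_mat.Init is called. A minimal standalone sketch of the same guard, with made-up sizes:

    #include <climits>
    #include <cstdio>

    // Same guard as the added lines: reject sizes where extent + 2 * border would overflow int.
    static bool PaddedExtentFits(int extent, int border) {
      if ((border > INT_MAX / 2) || (extent > INT_MAX - 2 * border)) {
        return false;
      }
      return true;
    }

    int main() {
      std::printf("%d\n", PaddedExtentFits(1920, 2));         // 1: fits
      std::printf("%d\n", PaddedExtentFits(INT_MAX - 1, 2));  // 0: would overflow
      return 0;
    }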
@@ -30,7 +30,7 @@ Status ComputeUpperAndLowerPercentiles(std::vector<int32_t> *hist, int32_t hi_p,
   int32_t n = std::accumulate(hist->begin(), hist->end(), 0);
   constexpr float kMaxPerc = 100.0;
   int32_t cut = static_cast<int32_t>((low_p / kMaxPerc) * n);
-  for (int32_t lb = 0; lb < hist->size() + 1 && cut > 0; lb++) {
+  for (int32_t lb = 0; lb < hist->size() && cut > 0; lb++) {
     if (cut > (*hist)[lb]) {
       cut -= (*hist)[lb];
       (*hist)[lb] = 0;
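The old loop bound ran one slot past the last histogram bucket, so the final iteration read (*hist)[hist->size()], which is out of bounds; the fix stops at hist->size(). A tiny self-contained sketch that mirrors only the fixed loop, with made-up values:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      std::vector<int32_t> hist = {5, 3, 2};  // made-up histogram buckets
      int32_t cut = 6;                        // made-up number of samples to trim
      // Loop bound mirrors the fixed code: lb stops at hist.size(), never hist.size() + 1.
      for (size_t lb = 0; lb < hist.size() && cut > 0; lb++) {
        if (cut > hist[lb]) {
          cut -= hist[lb];
          hist[lb] = 0;
        }
      }
      std::cout << hist[0] << " " << hist[1] << " " << hist[2] << std::endl;  // prints: 0 3 2
      return 0;
    }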
@@ -42,7 +42,9 @@ Status MixUpBatchOp::ComputeLabels(const TensorRow &input, std::shared_ptr<Tenso
   CHECK_FAIL_RETURN_UNEXPECTED(
     images_size <= static_cast<size_t>(std::numeric_limits<int64_t>::max()),
     "The \'images_size\' must not be more than \'INT64_MAX\', but got: " + std::to_string(images_size));
-  for (int64_t i = 0; i < static_cast<int64_t>(images_size); i++) rand_indx->push_back(i);
+  for (int64_t i = 0; i < static_cast<int64_t>(images_size); i++) {
+    rand_indx->push_back(i);
+  }
   std::shuffle(rand_indx->begin(), rand_indx->end(), rnd_);
 
   RETURN_IF_NOT_OK(TypeCast(std::move(input.at(1)), out_labels, DataType(DataType::DE_FLOAT32)));