forked from mindspore-Ecosystem/mindspore
sync code
This commit is contained in:
parent
97eae06817
commit
f752054c19
|
@ -34,7 +34,7 @@ bool set_seed(int32_t seed) {
|
|||
MS_LOG(ERROR) << "Seed given is not within the required range: " << seed;
|
||||
return false;
|
||||
}
|
||||
_config->set_seed((uint32_t)seed);
|
||||
_config->set_seed(static_cast<uint32_t>(seed));
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -73,7 +73,7 @@ bool set_monitor_sampling_interval(int32_t interval) {
|
|||
MS_LOG(ERROR) << "Interval given is not within the required range: " << interval;
|
||||
return false;
|
||||
}
|
||||
_config->set_monitor_sampling_interval((uint32_t)interval);
|
||||
_config->set_monitor_sampling_interval(static_cast<uint32_t>(interval));
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -86,7 +86,7 @@ bool set_callback_timeback(int32_t timeout) {
|
|||
MS_LOG(ERROR) << "Timeout given is not within the required range: " << timeout;
|
||||
return false;
|
||||
}
|
||||
_config->set_callback_timeout((uint32_t)timeout);
|
||||
_config->set_callback_timeout(static_cast<uint32_t>(timeout));
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -209,7 +209,6 @@ bool Dataset::DeviceQueueCharIF(const std::vector<char> &queue_name, const std::
|
|||
MS_LOG(ERROR) << "ToDevice: Failed to get consumer.";
|
||||
return false;
|
||||
}
|
||||
|
||||
rc = consumer->Init(ds);
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "ToDevice: Failed to init. Error status: " << rc;
|
||||
|
@ -252,7 +251,6 @@ bool Dataset::SaveCharIF(const std::vector<char> &dataset_path, int32_t num_file
|
|||
MS_LOG(ERROR) << "ToDevice: Failed to get consumer.";
|
||||
return false;
|
||||
}
|
||||
|
||||
rc = consumer->Init(ds->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "CreateSaver failed." << rc;
|
||||
|
@ -283,11 +281,15 @@ bool Dataset::SaveCharIF(const std::vector<char> &dataset_path, int32_t num_file
|
|||
// Default constructor: allocates the TreeGetters instance stored in tree_getters_.
Dataset::Dataset() { tree_getters_ = std::make_shared<TreeGetters>(); }
|
||||
|
||||
int64_t Dataset::GetDatasetSize(bool estimate) {
|
||||
int64_t dataset_size;
|
||||
int64_t dataset_size = -1;
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
|
||||
std::shared_ptr<DatasetSizeGetter> size_getter = std::make_shared<DatasetSizeGetter>();
|
||||
DatasetSizeGetter *consumer = size_getter.get();
|
||||
if (consumer == nullptr) {
|
||||
MS_LOG(ERROR) << "DatasetSizeGetter: Failed to get consumer.";
|
||||
return -1;
|
||||
}
|
||||
runtime_context->AssignConsumer(size_getter);
|
||||
RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), -1);
|
||||
RETURN_SECOND_IF_ERROR(consumer->GetDatasetSize(&dataset_size, estimate), -1);
|
||||
|
@ -299,6 +301,10 @@ std::vector<mindspore::DataType> Dataset::GetOutputTypes() {
|
|||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
|
||||
TreeGetters *consumer = tree_getters_.get();
|
||||
if (consumer == nullptr) {
|
||||
MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
|
||||
return std::vector<mindspore::DataType>();
|
||||
}
|
||||
runtime_context->AssignConsumer(tree_getters_);
|
||||
RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
|
||||
RETURN_SECOND_IF_ERROR(consumer->GetOutputTypes(&types), {});
|
||||
|
@ -314,6 +320,10 @@ std::vector<std::vector<int64_t>> Dataset::GetOutputShapes() {
|
|||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
|
||||
TreeGetters *consumer = tree_getters_.get();
|
||||
if (consumer == nullptr) {
|
||||
MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
|
||||
return std::vector<std::vector<int64_t>>();
|
||||
}
|
||||
runtime_context->AssignConsumer(tree_getters_);
|
||||
RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
|
||||
RETURN_SECOND_IF_ERROR(consumer->GetOutputShapes(&shapes), {});
|
||||
|
@ -324,10 +334,14 @@ std::vector<std::vector<int64_t>> Dataset::GetOutputShapes() {
|
|||
}
|
||||
|
||||
int64_t Dataset::GetNumClasses() {
|
||||
int64_t num_classes;
|
||||
int64_t num_classes = -1;
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
|
||||
TreeGetters *consumer = tree_getters_.get();
|
||||
if (consumer == nullptr) {
|
||||
MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
|
||||
return -1;
|
||||
}
|
||||
runtime_context->AssignConsumer(tree_getters_);
|
||||
RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), -1);
|
||||
RETURN_SECOND_IF_ERROR(consumer->GetNumClasses(&num_classes), -1);
|
||||
|
@ -339,6 +353,10 @@ std::vector<std::vector<char>> Dataset::GetColumnNamesCharIF() {
|
|||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
|
||||
TreeGetters *consumer = tree_getters_.get();
|
||||
if (consumer == nullptr) {
|
||||
MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
|
||||
return std::vector<std::vector<char>>();
|
||||
}
|
||||
runtime_context->AssignConsumer(tree_getters_);
|
||||
RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
|
||||
RETURN_SECOND_IF_ERROR(consumer->GetColumnNames(&col_names), {});
|
||||
|
@ -350,6 +368,10 @@ std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> Dataset::GetClas
|
|||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
|
||||
TreeGetters *consumer = tree_getters_.get();
|
||||
if (consumer == nullptr) {
|
||||
MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
|
||||
return std::vector<std::pair<std::vector<char>, std::vector<int32_t>>>();
|
||||
}
|
||||
runtime_context->AssignConsumer(tree_getters_);
|
||||
RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
|
||||
RETURN_SECOND_IF_ERROR(consumer->GetClassIndexing(&output_class_indexing), {});
|
||||
|
@ -487,10 +509,10 @@ TakeDataset::TakeDataset(std::shared_ptr<Dataset> input, int32_t count) {
|
|||
|
||||
ZipDataset::ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
|
||||
std::vector<std::shared_ptr<DatasetNode>> all_datasets;
|
||||
(void)std::transform(
|
||||
datasets.begin(), datasets.end(), std::back_inserter(all_datasets),
|
||||
[](std::shared_ptr<Dataset> dataset) -> std::shared_ptr<DatasetNode> { return dataset->IRNode(); });
|
||||
|
||||
(void)std::transform(datasets.begin(), datasets.end(), std::back_inserter(all_datasets),
|
||||
[](std::shared_ptr<Dataset> dataset) -> std::shared_ptr<DatasetNode> {
|
||||
return (dataset != nullptr) ? dataset->IRNode() : nullptr;
|
||||
});
|
||||
auto ds = std::make_shared<ZipNode>(all_datasets);
|
||||
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
|
@ -538,6 +560,10 @@ std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocabCharIF(
|
|||
|
||||
auto consumer = std::make_unique<BuildVocabConsumer>();
|
||||
BuildVocabConsumer *bv_consumer = consumer.get();
|
||||
if (bv_consumer == nullptr) {
|
||||
MS_LOG(ERROR) << "BuildVocabConsumer: Failed to get bv_consumer.";
|
||||
return nullptr;
|
||||
}
|
||||
rc = consumer->Init(ds);
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init consumer. Error status: " << rc;
|
||||
|
@ -571,6 +597,10 @@ std::shared_ptr<Vocab> Dataset::BuildVocabCharIF(const std::vector<std::vector<c
|
|||
|
||||
auto consumer = std::make_unique<BuildVocabConsumer>();
|
||||
BuildVocabConsumer *bv_consumer = consumer.get();
|
||||
if (bv_consumer == nullptr) {
|
||||
MS_LOG(ERROR) << "BuildVocabConsumer: Failed to get bv_consumer.";
|
||||
return nullptr;
|
||||
}
|
||||
rc = consumer->Init(ds);
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "BuildVocab: Failed to init consumer. Error status: " << rc;
|
||||
|
|
|
@ -54,20 +54,27 @@ struct Execute::ExtraInfo {
|
|||
#endif
|
||||
};
|
||||
|
||||
Execute::Execute(std::shared_ptr<TensorOperation> op, MapTargetDevice device_type, uint32_t device_id) {
|
||||
ops_.emplace_back(std::move(op));
|
||||
device_type_ = device_type;
|
||||
info_ = std::make_shared<ExtraInfo>();
|
||||
// Sets up the device resource used by this Execute instance.
// When built with ENABLE_ACL and device_type_ is kAscend310, an AscendResource is created and
// initialized with device_id. If that initialization fails, device_resource_ is reset to nullptr,
// the failure is logged, and RETURN_STATUS_UNEXPECTED reports the error to the caller.
// In every other case (including builds without ENABLE_ACL) the function returns Status::OK().
// NOTE(review): the branch tests the member device_type_, not the device_type parameter, so the
// parameter is unused in this body — confirm callers always assign device_type_ before calling.
Status Execute::InitResource(MapTargetDevice device_type, uint32_t device_id) {
|
||||
#ifdef ENABLE_ACL
|
||||
if (device_type_ == MapTargetDevice::kAscend310) {
|
||||
device_resource_ = std::make_shared<AscendResource>();
|
||||
Status rc = device_resource_->InitResource(device_id);
|
||||
if (!rc.IsOk()) {
|
||||
// Drop the half-initialized resource so later null checks on device_resource_ catch the failure.
device_resource_ = nullptr;
|
||||
// NOTE(review): this page interleaves removed and added diff lines; the next log statement looks
// like the removed variant of the three statements that follow it — confirm against the repository.
MS_LOG(ERROR) << "Initialize Ascend310 resource fail.";
|
||||
std::string err_msg = "Initialize Ascend310 resource fail";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Constructs an Execute from a single TensorOperation.
// Stores the operation in ops_, records the target device type, creates the ExtraInfo holder,
// and then attempts to initialize the device resource for device_id.
Execute::Execute(std::shared_ptr<TensorOperation> op, MapTargetDevice device_type, uint32_t device_id) {
|
||||
ops_.emplace_back(std::move(op));
|
||||
// Must be assigned before InitResource, which reads the member device_type_.
device_type_ = device_type;
|
||||
info_ = std::make_shared<ExtraInfo>();
|
||||
// The returned Status is deliberately discarded with (void): a constructor cannot return it.
// NOTE(review): on failure InitResource leaves device_resource_ as nullptr — confirm every
// later use checks for that before dereferencing.
(void)InitResource(device_type, device_id);
|
||||
}
|
||||
|
||||
Execute::Execute(std::shared_ptr<TensorTransform> op, MapTargetDevice device_type, uint32_t device_id) {
|
||||
|
@ -76,16 +83,7 @@ Execute::Execute(std::shared_ptr<TensorTransform> op, MapTargetDevice device_typ
|
|||
|
||||
info_ = std::make_shared<ExtraInfo>();
|
||||
device_type_ = device_type;
|
||||
#ifdef ENABLE_ACL
|
||||
if (device_type_ == MapTargetDevice::kAscend310) {
|
||||
device_resource_ = std::make_shared<AscendResource>();
|
||||
Status rc = device_resource_->InitResource(device_id);
|
||||
if (!rc.IsOk()) {
|
||||
device_resource_ = nullptr;
|
||||
MS_LOG(ERROR) << "Initialize Ascend310 resource fail.";
|
||||
}
|
||||
}
|
||||
#endif
|
||||
(void)InitResource(device_type, device_id);
|
||||
}
|
||||
|
||||
Execute::Execute(std::reference_wrapper<TensorTransform> op, MapTargetDevice device_type, uint32_t device_id) {
|
||||
|
@ -96,16 +94,7 @@ Execute::Execute(std::reference_wrapper<TensorTransform> op, MapTargetDevice dev
|
|||
info_ = std::make_shared<ExtraInfo>();
|
||||
info_->init_with_shared_ptr_ = false;
|
||||
device_type_ = device_type;
|
||||
#ifdef ENABLE_ACL
|
||||
if (device_type_ == MapTargetDevice::kAscend310) {
|
||||
device_resource_ = std::make_shared<AscendResource>();
|
||||
Status rc = device_resource_->InitResource(device_id);
|
||||
if (!rc.IsOk()) {
|
||||
device_resource_ = nullptr;
|
||||
MS_LOG(ERROR) << "Initialize Ascend310 resource fail.";
|
||||
}
|
||||
}
|
||||
#endif
|
||||
(void)InitResource(device_type, device_id);
|
||||
}
|
||||
|
||||
// Execute function for the example case: auto decode(new vision::Decode());
|
||||
|
@ -116,31 +105,13 @@ Execute::Execute(TensorTransform *op, MapTargetDevice device_type, uint32_t devi
|
|||
|
||||
info_ = std::make_shared<ExtraInfo>();
|
||||
device_type_ = device_type;
|
||||
#ifdef ENABLE_ACL
|
||||
if (device_type_ == MapTargetDevice::kAscend310) {
|
||||
device_resource_ = std::make_shared<AscendResource>();
|
||||
Status rc = device_resource_->InitResource(device_id);
|
||||
if (!rc.IsOk()) {
|
||||
device_resource_ = nullptr;
|
||||
MS_LOG(ERROR) << "Initialize Ascend310 resource fail.";
|
||||
}
|
||||
}
|
||||
#endif
|
||||
(void)InitResource(device_type, device_id);
|
||||
}
|
||||
|
||||
// Constructs an Execute from a list of TensorOperations.
// ops_ and device_type_ are set in the member-initializer list, the ExtraInfo holder is created,
// and the device resource for device_id is then initialized.
// NOTE(review): this page interleaves removed and added diff lines; the inline ENABLE_ACL block
// below and the trailing InitResource call perform the same initialization, so presumably only
// the InitResource call exists in the current file — confirm against the repository.
Execute::Execute(std::vector<std::shared_ptr<TensorOperation>> ops, MapTargetDevice device_type, uint32_t device_id)
|
||||
: ops_(std::move(ops)), device_type_(device_type) {
|
||||
info_ = std::make_shared<ExtraInfo>();
|
||||
#ifdef ENABLE_ACL
|
||||
if (device_type_ == MapTargetDevice::kAscend310) {
|
||||
device_resource_ = std::make_shared<AscendResource>();
|
||||
Status rc = device_resource_->InitResource(device_id);
|
||||
if (!rc.IsOk()) {
|
||||
// Failure is only logged here; device_resource_ is cleared so it reads as "not available".
device_resource_ = nullptr;
|
||||
MS_LOG(ERROR) << "Initialize Ascend310 resource fail.";
|
||||
}
|
||||
}
|
||||
#endif
|
||||
// The returned Status is discarded with (void) because a constructor cannot return it.
(void)InitResource(device_type, device_id);
|
||||
}
|
||||
|
||||
Execute::Execute(std::vector<std::shared_ptr<TensorTransform>> ops, MapTargetDevice device_type, uint32_t device_id) {
|
||||
|
@ -149,16 +120,7 @@ Execute::Execute(std::vector<std::shared_ptr<TensorTransform>> ops, MapTargetDev
|
|||
|
||||
info_ = std::make_shared<ExtraInfo>();
|
||||
device_type_ = device_type;
|
||||
#ifdef ENABLE_ACL
|
||||
if (device_type_ == MapTargetDevice::kAscend310) {
|
||||
device_resource_ = std::make_shared<AscendResource>();
|
||||
Status rc = device_resource_->InitResource(device_id);
|
||||
if (!rc.IsOk()) {
|
||||
device_resource_ = nullptr;
|
||||
MS_LOG(ERROR) << "Initialize Ascend310 resource fail.";
|
||||
}
|
||||
}
|
||||
#endif
|
||||
(void)InitResource(device_type, device_id);
|
||||
}
|
||||
|
||||
Execute::Execute(const std::vector<std::reference_wrapper<TensorTransform>> ops, MapTargetDevice device_type,
|
||||
|
@ -177,16 +139,7 @@ Execute::Execute(const std::vector<std::reference_wrapper<TensorTransform>> ops,
|
|||
info_ = std::make_shared<ExtraInfo>();
|
||||
info_->init_with_shared_ptr_ = false;
|
||||
device_type_ = device_type;
|
||||
#ifdef ENABLE_ACL
|
||||
if (device_type_ == MapTargetDevice::kAscend310) {
|
||||
device_resource_ = std::make_shared<AscendResource>();
|
||||
Status rc = device_resource_->InitResource(device_id);
|
||||
if (!rc.IsOk()) {
|
||||
device_resource_ = nullptr;
|
||||
MS_LOG(ERROR) << "Initialize Ascend310 resource fail";
|
||||
}
|
||||
}
|
||||
#endif
|
||||
(void)InitResource(device_type, device_id);
|
||||
}
|
||||
|
||||
// Execute function for the example vector case: auto decode(new vision::Decode());
|
||||
|
@ -199,16 +152,7 @@ Execute::Execute(const std::vector<TensorTransform *> &ops, MapTargetDevice devi
|
|||
|
||||
info_ = std::make_shared<ExtraInfo>();
|
||||
device_type_ = device_type;
|
||||
#ifdef ENABLE_ACL
|
||||
if (device_type_ == MapTargetDevice::kAscend310) {
|
||||
device_resource_ = std::make_shared<AscendResource>();
|
||||
Status rc = device_resource_->InitResource(device_id);
|
||||
if (!rc.IsOk()) {
|
||||
device_resource_ = nullptr;
|
||||
MS_LOG(ERROR) << "Initialize Ascend310 resource fail";
|
||||
}
|
||||
}
|
||||
#endif
|
||||
(void)InitResource(device_type, device_id);
|
||||
}
|
||||
|
||||
Execute::~Execute() {
|
||||
|
@ -225,6 +169,7 @@ Execute::~Execute() {
|
|||
|
||||
Status Execute::operator()(const mindspore::MSTensor &input, mindspore::MSTensor *output) {
|
||||
// Validate input tensor
|
||||
RETURN_UNEXPECTED_IF_NULL(output);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(input.DataSize() > 0, "Input Tensor has no data.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(ValidateDevice(), "Device Type should be 'Ascend310' or 'CPU'.");
|
||||
|
||||
|
@ -283,7 +228,8 @@ Status Execute::operator()(const mindspore::MSTensor &input, mindspore::MSTensor
|
|||
RETURN_STATUS_UNEXPECTED(ss.str());
|
||||
}
|
||||
*output = mindspore::MSTensor(std::make_shared<DETensor>(de_tensor));
|
||||
} else { // Ascend310 case, where we must set Ascend resource on each operators
|
||||
} else if (device_type_ ==
|
||||
MapTargetDevice::kAscend310) { // Ascend310 case, where we must set Ascend resource on each operators
|
||||
#ifdef ENABLE_ACL
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(device_resource_, "Device resource is nullptr which is illegal under case Ascend310.");
|
||||
// Sink data from host into device
|
||||
|
@ -304,12 +250,17 @@ Status Execute::operator()(const mindspore::MSTensor &input, mindspore::MSTensor
|
|||
|
||||
*output = mindspore::MSTensor(std::make_shared<DETensor>(device_input, true));
|
||||
#endif
|
||||
} else {
|
||||
std::string err_msg = "Your input device is not supported. (Option: CPU or Ascend310)";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Execute::operator()(const std::vector<MSTensor> &input_tensor_list, std::vector<MSTensor> *output_tensor_list) {
|
||||
// Validate input tensor
|
||||
RETURN_UNEXPECTED_IF_NULL(output_tensor_list);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!input_tensor_list.empty(), "Input Tensor is not valid.");
|
||||
for (auto &tensor : input_tensor_list) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(tensor.DataSize() > 0, "Input Tensor has no data.");
|
||||
|
@ -371,7 +322,8 @@ Status Execute::operator()(const std::vector<MSTensor> &input_tensor_list, std::
|
|||
++idx;
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!output_tensor_list->empty(), "Output Tensor is not valid.");
|
||||
} else { // Case Ascend310
|
||||
} else if (device_type_ ==
|
||||
MapTargetDevice::kAscend310) { // Ascend310 case, where we must set Ascend resource on each operators
|
||||
#ifdef ENABLE_ACL
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(device_resource_, "Device resource is nullptr which is illegal under case Ascend310.");
|
||||
for (auto &input_tensor : input_tensor_list) {
|
||||
|
@ -401,6 +353,10 @@ Status Execute::operator()(const std::vector<MSTensor> &input_tensor_list, std::
|
|||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!output_tensor_list->empty(), "Output Tensor vector is empty.");
|
||||
#endif
|
||||
} else {
|
||||
std::string err_msg = "Your input device is not supported. (Option: CPU or Ascend310)";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -517,9 +473,17 @@ Status AippInfoCollection(std::map<std::string, std::string> *aipp_options, cons
|
|||
|
||||
std::string Execute::AippCfgGenerator() {
|
||||
std::string config_location = "./aipp.cfg";
|
||||
if (info_ == nullptr) {
|
||||
MS_LOG(ERROR) << "info_ is null";
|
||||
return "";
|
||||
}
|
||||
#ifdef ENABLE_ACL
|
||||
if (info_->init_with_shared_ptr_) {
|
||||
ParseTransforms();
|
||||
auto s = ParseTransforms();
|
||||
if (s != Status::OK()) {
|
||||
MS_LOG(ERROR) << "Error in ParseTransforms";
|
||||
return "";
|
||||
}
|
||||
info_->init_with_shared_ptr_ = false;
|
||||
}
|
||||
std::vector<uint32_t> paras; // Record the parameters value of each Ascend operators
|
||||
|
@ -548,8 +512,7 @@ std::string Execute::AippCfgGenerator() {
|
|||
if (!outfile.is_open()) {
|
||||
MS_LOG(ERROR) << "Fail to open Aipp config file, please verify your system config(including authority)."
|
||||
<< "We will return empty string which represent the location of Aipp config file in this case.";
|
||||
std::string except = "";
|
||||
return except;
|
||||
return "";
|
||||
}
|
||||
|
||||
if (device_type_ == MapTargetDevice::kAscend310) {
|
||||
|
@ -629,10 +592,14 @@ Status Execute::ParseTransforms() {
|
|||
[](std::shared_ptr<TensorTransform> operation) -> std::shared_ptr<TensorOperation> {
|
||||
return operation->Parse();
|
||||
});
|
||||
} else {
|
||||
} else if (device_type_ == MapTargetDevice::kAscend310) {
|
||||
for (auto &transform_ : transforms_) {
|
||||
ops_.emplace_back(transform_->Parse(device_type_));
|
||||
}
|
||||
} else {
|
||||
std::string err_msg = "Your input device is not supported. (Option: CPU or Ascend310)";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -38,7 +38,7 @@ Status Iterator::GetNextRowCharIF(MSTensorMapChar *row) {
|
|||
row->clear();
|
||||
return rc;
|
||||
}
|
||||
for (auto de_tensor : md_map) {
|
||||
for (auto &de_tensor : md_map) {
|
||||
std::vector<char> col_name(de_tensor.first.begin(), de_tensor.first.end());
|
||||
row->insert(std::make_pair(col_name, mindspore::MSTensor(std::make_shared<DETensor>(de_tensor.second))));
|
||||
}
|
||||
|
@ -48,8 +48,8 @@ Status Iterator::GetNextRowCharIF(MSTensorMapChar *row) {
|
|||
|
||||
// Get the next row from the data pipeline.
|
||||
Status Iterator::GetNextRow(MSTensorVec *row) {
|
||||
// Clean data row
|
||||
RETURN_UNEXPECTED_IF_NULL(row);
|
||||
// Clean data row
|
||||
row->clear();
|
||||
// create a dataset tensor row and fetch. Then we convert the output to MSTensor
|
||||
std::vector<std::shared_ptr<dataset::Tensor>> md_row;
|
||||
|
@ -76,6 +76,7 @@ void Iterator::Stop() {
|
|||
|
||||
// Function to build and launch the execution tree.
|
||||
Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds, int32_t num_epochs) {
|
||||
RETURN_UNEXPECTED_IF_NULL(ds);
|
||||
runtime_context_ = std::make_unique<NativeRuntimeContext>();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(runtime_context_ != nullptr, "Create runtime_context_ failed.");
|
||||
RETURN_IF_NOT_OK(runtime_context_->Init());
|
||||
|
|
|
@ -102,7 +102,7 @@ Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) {
|
|||
// clear the old tensor row
|
||||
out_row->clear();
|
||||
|
||||
bool isProfilingEnable = root_->Tree()->GetProfilingManager()->IsProfilingEnable();
|
||||
bool is_profiling_enable = root_->Tree()->GetProfilingManager()->IsProfilingEnable();
|
||||
|
||||
// Once eof is handled, always return empty row. Class must be destroyed and recreated if you
|
||||
// want to iterate again.
|
||||
|
@ -124,7 +124,7 @@ Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) {
|
|||
// The next row in the pipeline might be an EOF or a TensorRow for next epoch
|
||||
if (out_row->eoe()) {
|
||||
MS_LOG(INFO) << "End of data iteration.";
|
||||
if (isProfilingEnable) {
|
||||
if (is_profiling_enable) {
|
||||
root_->Tree()->SetEpochEnd();
|
||||
}
|
||||
return Status::OK();
|
||||
|
|
|
@ -250,9 +250,13 @@ Status BatchOp::WorkerEntry(int32_t workerId) {
|
|||
// Builds one batched TensorRow from the table held in table_pair.
// Steps, in order: reject a null table; when built with ENABLE_PYTHON and input column names are
// configured, run MapColumns on the pair; when pad_ is set, run PadColumns on the table; finally
// call BatchRows to write the result into new_row. Any failing step returns its Status early.
// NOTE(review): this page interleaves removed and added diff lines, so the single-line `if`
// statements and the braced `if` blocks that follow them are the old and new spelling of the
// same step — confirm against the repository that only one of each exists.
Status BatchOp::MakeBatchedRow(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair, TensorRow *new_row) {
|
||||
RETURN_UNEXPECTED_IF_NULL(table_pair.first);
|
||||
#ifdef ENABLE_PYTHON
|
||||
if (!in_col_names_.empty()) RETURN_IF_NOT_OK(MapColumns(&table_pair)); // pass it through pyfunc
|
||||
if (!in_col_names_.empty()) {
|
||||
RETURN_IF_NOT_OK(MapColumns(&table_pair));
|
||||
} // pass it through pyfunc
|
||||
#endif
|
||||
if (pad_) RETURN_IF_NOT_OK(PadColumns(&table_pair.first, pad_info_, column_name_id_map_)); // do padding if needed
|
||||
if (pad_) {
|
||||
RETURN_IF_NOT_OK(PadColumns(&table_pair.first, pad_info_, column_name_id_map_));
|
||||
} // do padding if needed
|
||||
// The row count passed to BatchRows is the current size of the table.
RETURN_IF_NOT_OK(BatchRows(&table_pair.first, new_row, table_pair.first->size()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -242,7 +242,7 @@ class BatchOp : public ParallelOp {
|
|||
|
||||
// the number of threads pulling from the mOutConnector of the Op below
|
||||
// @return int32_t, 1
|
||||
int32_t num_consumers() const override { return 1; }
|
||||
int32_t NumConsumers() const override { return 1; }
|
||||
|
||||
// get the batch size for next batch
|
||||
// @return Status The status code returned
|
||||
|
|
|
@ -71,11 +71,11 @@ class BuildSentencePieceVocabOp : public PipelineOp {
|
|||
|
||||
// Getter
|
||||
// @return the number of workers
|
||||
int32_t num_producers() const override { return 1; }
|
||||
int32_t NumProducers() const override { return 1; }
|
||||
|
||||
// Getter
|
||||
// @return the number of threads consuming from the previous Connector
|
||||
int32_t num_consumers() const override { return 1; }
|
||||
int32_t NumConsumers() const override { return 1; }
|
||||
|
||||
Status Reset() override { RETURN_STATUS_UNEXPECTED("Reset shouldn't be called in BuildSentencePieceVocabOp"); }
|
||||
|
||||
|
|
|
@ -68,11 +68,11 @@ class BuildVocabOp : public ParallelOp {
|
|||
|
||||
/// Getter
|
||||
/// @return the number of workers
|
||||
int32_t num_producers() const override { return 1; }
|
||||
int32_t NumProducers() const override { return 1; }
|
||||
|
||||
/// Getter
|
||||
/// @return the number of threads consuming from the previous Connector
|
||||
int32_t num_consumers() const override { return 1; }
|
||||
int32_t NumConsumers() const override { return 1; }
|
||||
|
||||
Status Reset() override { RETURN_STATUS_UNEXPECTED("Reset shouldn't be called in BuildVocabOp"); }
|
||||
|
||||
|
|
|
@ -69,6 +69,7 @@ Status CacheBase::FetchSamplesToWorkers() {
|
|||
int64_t wait_cnt = 0;
|
||||
int64_t prefetch_cnt = 0;
|
||||
// Kick off several threads which will prefetch prefetch_size_ rows in advance.
|
||||
RETURN_UNEXPECTED_IF_NULL(tree_);
|
||||
RETURN_IF_NOT_OK(
|
||||
tree_->LaunchWorkers(num_prefetchers_, std::bind(&CacheBase::Prefetcher, this, std::placeholders::_1), Name()));
|
||||
auto send_to_que = [](QueueList<std::unique_ptr<IOBlock>> &qList, int32_t worker_id,
|
||||
|
@ -80,7 +81,7 @@ Status CacheBase::FetchSamplesToWorkers() {
|
|||
// Instead of sending sampler id to WorkerEntry, we send them to the Prefetcher which will redirect them
|
||||
// to the WorkerEntry.
|
||||
do {
|
||||
if (AllowCacheMiss() && wait_cnt > 0 && wait_cnt % op_num_repeats_per_epoch() == 0) {
|
||||
if (AllowCacheMiss() && wait_cnt > 0 && wait_cnt % GetOpNumRepeatsPerEpoch() == 0) {
|
||||
MS_LOG(INFO) << "Epoch: " << op_current_epochs_ << " Cache Miss : " << num_cache_miss_
|
||||
<< " Total number of rows : " << row_cnt_;
|
||||
}
|
||||
|
@ -155,7 +156,7 @@ Status CacheBase::FetchSamplesToWorkers() {
|
|||
}
|
||||
// Dump the last epoch result (approximately) without waiting for the worker threads to come back.
|
||||
if (AllowCacheMiss()) {
|
||||
MS_LOG(INFO) << "Epoch: " << wait_cnt / op_num_repeats_per_epoch() << " Cache Miss : " << num_cache_miss_
|
||||
MS_LOG(INFO) << "Epoch: " << wait_cnt / GetOpNumRepeatsPerEpoch() << " Cache Miss : " << num_cache_miss_
|
||||
<< " Total number of rows : " << row_cnt_;
|
||||
}
|
||||
return Status::OK();
|
||||
|
@ -202,6 +203,7 @@ Status CacheBase::FetchFromCache(int32_t worker_id) {
|
|||
}
|
||||
|
||||
Status CacheBase::RegisterResources() {
|
||||
RETURN_UNEXPECTED_IF_NULL(tree_);
|
||||
RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
|
||||
RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
|
||||
RETURN_IF_NOT_OK(prefetch_queues_.Register(tree_->AllTasks()));
|
||||
|
|
|
@ -75,7 +75,7 @@ class CacheBase : public ParallelOp {
|
|||
|
||||
/// \brief Getter for the cache client
|
||||
/// \return shared ptr to the cache client
|
||||
std::shared_ptr<CacheClient> cache_client() { return cache_client_; }
|
||||
std::shared_ptr<CacheClient> GetCacheClient() { return cache_client_; }
|
||||
/// \brief Setter for the cache client
|
||||
void SetCacheClient(std::shared_ptr<CacheClient> cache_client) { cache_client_ = std::move(cache_client); }
|
||||
/// \brief Derived class must implement this method if a cache miss is treated as error
|
||||
|
|
|
@ -73,6 +73,7 @@ Status CacheOp::InitCache() { return Status::OK(); }
|
|||
|
||||
// This class functor will provide the master loop that drives the logic for performing the work
|
||||
Status CacheOp::operator()() {
|
||||
RETURN_UNEXPECTED_IF_NULL(tree_);
|
||||
if (!sampler_) {
|
||||
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
|
||||
"Invalid parameter, CacheOp requires a sampler before it can be executed, but got nullptr.");
|
||||
|
@ -199,6 +200,7 @@ Status CacheOp::WorkerEntry(int32_t worker_id) {
|
|||
return Status::OK();
|
||||
}
|
||||
Status CacheOp::RegisterResources() {
|
||||
RETURN_UNEXPECTED_IF_NULL(tree_);
|
||||
RETURN_IF_NOT_OK(CacheBase::RegisterResources());
|
||||
RETURN_IF_NOT_OK(rows_cache_done_.Register(tree_->AllTasks()));
|
||||
RETURN_IF_NOT_OK(keys_miss_->Register(tree_->AllTasks()));
|
||||
|
|
|
@ -194,7 +194,7 @@ Status ConcatOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
int32_t ConcatOp::num_consumers() const {
|
||||
int32_t ConcatOp::NumConsumers() const {
|
||||
if (parent_.empty()) {
|
||||
MS_LOG(DEBUG) << "Return operator, no parent node, assuming it's the root and returning 1.";
|
||||
return 1;
|
||||
|
@ -202,16 +202,16 @@ int32_t ConcatOp::num_consumers() const {
|
|||
MS_LOG(DEBUG) << "Return operator, pointer to the first parent is null. Returning 0.";
|
||||
return 0;
|
||||
} else {
|
||||
return parent_[0]->num_consumers();
|
||||
return parent_[0]->NumConsumers();
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the producer count reported by the first child operator.
// Returns 0 (after a debug log) when there are no children or the first child pointer is null.
// NOTE(review): this page interleaves removed and added diff lines; the num_producers spellings
// look like the pre-rename form of the NumProducers lines that follow them — confirm against
// the repository.
// NOTE(review): the log text says "Return operator" although this is ConcatOp — possibly copied
// from another operator; confirm whether the message should be updated.
int32_t ConcatOp::num_producers() const {
|
||||
int32_t ConcatOp::NumProducers() const {
|
||||
if (child_.empty() || child_[0] == nullptr) {
|
||||
MS_LOG(DEBUG) << "Return operator, pointer to child node is null. Returning 0.";
|
||||
return 0;
|
||||
} else {
|
||||
return child_[0]->num_producers();
|
||||
return child_[0]->NumProducers();
|
||||
}
|
||||
}
|
||||
} // namespace dataset
|
||||
|
|
|
@ -73,8 +73,8 @@ class ConcatOp : public PipelineOp {
|
|||
Status GetNumClasses(int64_t *num_classes) override;
|
||||
|
||||
Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
|
||||
int32_t num_consumers() const override;
|
||||
int32_t num_producers() const override;
|
||||
int32_t NumConsumers() const override;
|
||||
int32_t NumProducers() const override;
|
||||
|
||||
/// Check if the current sample will be taken or dropped
|
||||
/// \return bool
|
||||
|
|
|
@ -312,9 +312,9 @@ Status DatasetOp::PrepareOperator() {
|
|||
// The consumer of the root node is assumed to be one thread.
|
||||
// If multiple threads are consuming from the root node, they will get the ordered data in round robin fashion.
|
||||
if (parent_.empty()) {
|
||||
this->CreateConnector(num_producers(), 1);
|
||||
this->CreateConnector(NumProducers(), 1);
|
||||
} else {
|
||||
this->CreateConnector(num_producers(), parent_[0]->num_consumers());
|
||||
this->CreateConnector(NumProducers(), parent_[0]->NumConsumers());
|
||||
}
|
||||
if (out_connector_) {
|
||||
RETURN_IF_NOT_OK(out_connector_->Register(tree_->AllTasks()));
|
||||
|
|
|
@ -209,33 +209,33 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
|
||||
// \brief Getter function
|
||||
// \return The number of workers in this op
|
||||
virtual int32_t num_workers() const = 0;
|
||||
virtual int32_t NumWorkers() const = 0;
|
||||
|
||||
// \brief Getter function
|
||||
// \return The number of threads consuming from previous op.
|
||||
virtual int32_t num_consumers() const = 0;
|
||||
virtual int32_t NumConsumers() const = 0;
|
||||
|
||||
// \brief Getter function
|
||||
// \return The number of threads producing to the output connector.
|
||||
virtual int32_t num_producers() const = 0;
|
||||
virtual int32_t NumProducers() const = 0;
|
||||
|
||||
// \brief Getter function
|
||||
// \return T/F if this is an inlined operator
|
||||
bool inlined() const { return (oc_queue_size_ == 0); }
|
||||
|
||||
// \brief Setter function, set the number of total repeats for the operator
|
||||
void set_total_repeats(int32_t total_repeats) { op_total_repeats_ = total_repeats; }
|
||||
void SetTotalRepeats(int32_t total_repeats) { op_total_repeats_ = total_repeats; }
|
||||
|
||||
// \brief Setter function, set the number of repeats per epoch for the operator
|
||||
void set_num_repeats_per_epoch(int32_t num_repeats_per_epoch) { op_num_repeats_per_epoch_ = num_repeats_per_epoch; }
|
||||
void SetNumRepeatsPerEpoch(int32_t num_repeats_per_epoch) { op_num_repeats_per_epoch_ = num_repeats_per_epoch; }
|
||||
|
||||
// \brief Getter function
|
||||
// \return The number of required repeats for the operator
|
||||
int32_t op_total_repeats() { return op_total_repeats_; }
|
||||
int32_t GetOpTotalRepeats() { return op_total_repeats_; }
|
||||
|
||||
// \brief Getter function
|
||||
// \return The number of repeats per epoch for the operator
|
||||
int32_t op_num_repeats_per_epoch() const { return op_num_repeats_per_epoch_; }
|
||||
int32_t GetOpNumRepeatsPerEpoch() const { return op_num_repeats_per_epoch_; }
|
||||
|
||||
// \brief Register the internal worker connectors. No op unless it is a parallel op
|
||||
// \return Status
|
||||
|
@ -393,7 +393,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
// \notes No public interface. Only the class itself, or its friend the execution tree, can set
|
||||
// this
|
||||
// \param op_id - the Id value to set into the operator
|
||||
void set_id(int32_t op_id) { operator_id_ = op_id; }
|
||||
void SetId(int32_t op_id) { operator_id_ = op_id; }
|
||||
|
||||
// Sets the tree into the op so that the operator has a back pointer to the tree.
|
||||
// \param tree - the tree to assign to the op.
|
||||
|
|
|
@ -173,8 +173,8 @@ Status DeviceQueueOp::SendDataToAscend() {
|
|||
int64_t sending_num = cfg->sending_batches(); // Get the current sending_num
|
||||
|
||||
std::shared_ptr<DeviceQueueTracing> profiling_node;
|
||||
bool isProfilingEnable = tree_->GetProfilingManager()->IsProfilingEnable();
|
||||
if (isProfilingEnable) {
|
||||
bool is_profiling_enable = tree_->GetProfilingManager()->IsProfilingEnable();
|
||||
if (is_profiling_enable) {
|
||||
std::shared_ptr<Tracing> node;
|
||||
RETURN_IF_NOT_OK(tree_->GetProfilingManager()->GetTracingNode(kDeviceQueueTracingName, &node));
|
||||
profiling_node = std::dynamic_pointer_cast<DeviceQueueTracing>(node);
|
||||
|
@ -197,12 +197,12 @@ Status DeviceQueueOp::SendDataToAscend() {
|
|||
md_channel_info_->RecordPreprocessBatch(send_batch);
|
||||
md_channel_info_->RecordPushStartTime();
|
||||
#endif
|
||||
RETURN_IF_NOT_OK(SendRowToTdt(curr_row, isProfilingEnable, &tdt_cost));
|
||||
RETURN_IF_NOT_OK(SendRowToTdt(curr_row, is_profiling_enable, &tdt_cost));
|
||||
if (first_push_flag_ != true) {
|
||||
MS_LOG(INFO) << "Loading dataset and push first batch into device successful";
|
||||
first_push_flag_ = true;
|
||||
}
|
||||
ProfilingRecorder(isProfilingEnable, profiling_node, send_batch, tdt_cost, &batch_start_time, &end_time,
|
||||
ProfilingRecorder(is_profiling_enable, profiling_node, send_batch, tdt_cost, &batch_start_time, &end_time,
|
||||
connector_capacity, connector_size);
|
||||
send_batch++;
|
||||
#ifdef ENABLE_DUMP_IR
|
||||
|
@ -219,15 +219,15 @@ Status DeviceQueueOp::SendDataToAscend() {
|
|||
// wait when sending num is not 0, and sending num no larger than already sending batch
|
||||
LimitSendingBatches(send_batch, &sending_num, cfg);
|
||||
|
||||
if (isProfilingEnable) {
|
||||
if (is_profiling_enable) {
|
||||
connector_size = ChildOpConnectorSize();
|
||||
connector_capacity = ChildOpConnectorCapacity();
|
||||
}
|
||||
RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&curr_row));
|
||||
}
|
||||
if (curr_row.eoe() && send_epoch_end_) {
|
||||
TensorRow currRow;
|
||||
auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost,
|
||||
TensorRow dummy_row;
|
||||
auto status = tdtInstancePtr->hostPush(dummy_row, true, channel_name_, is_profiling_enable, tdt_cost,
|
||||
ACL_TENSOR_DATA_END_OF_SEQUENCE);
|
||||
if (status != Status::OK()) {
|
||||
if (stop_send_) {
|
||||
|
@ -246,7 +246,7 @@ Status DeviceQueueOp::SendDataToAscend() {
|
|||
MS_LOG(INFO) << "an epoch has already sent, now stop send data.";
|
||||
stop_send_ = true;
|
||||
}
|
||||
if (isProfilingEnable) {
|
||||
if (is_profiling_enable) {
|
||||
connector_size = ChildOpConnectorSize();
|
||||
connector_capacity = ChildOpConnectorCapacity();
|
||||
tree_->SetEpochEnd();
|
||||
|
@ -283,8 +283,8 @@ void DeviceQueueOp::LimitSendingBatches(int64_t send_batch, int64_t *sending_num
|
|||
}
|
||||
}
|
||||
|
||||
Status DeviceQueueOp::SendRowToTdt(TensorRow currRow, bool isProfilingEnable, int32_t *tdt_cost) {
|
||||
auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, *tdt_cost);
|
||||
Status DeviceQueueOp::SendRowToTdt(TensorRow curr_row, bool is_profiling_enable, int32_t *tdt_cost) {
|
||||
auto status = tdtInstancePtr->hostPush(curr_row, true, channel_name_, is_profiling_enable, *tdt_cost);
|
||||
if (status != Status::OK()) {
|
||||
if (stop_send_) {
|
||||
MS_LOG(INFO) << "stop_send received";
|
||||
|
@ -300,7 +300,7 @@ Status DeviceQueueOp::SendRowToTdt(TensorRow currRow, bool isProfilingEnable, in
|
|||
}
|
||||
if (create_data_info_queue_) {
|
||||
DATA_INFO data_info;
|
||||
(void)std::transform(currRow.begin(), currRow.end(), std::back_inserter(data_info),
|
||||
(void)std::transform(curr_row.begin(), curr_row.end(), std::back_inserter(data_info),
|
||||
[](const std::shared_ptr<Tensor> &ts) { return std::make_pair(ts->type(), ts->shape()); });
|
||||
RETURN_IF_NOT_OK(data_info_queue_ptr_->Add(data_info));
|
||||
}
|
||||
|
@ -348,6 +348,7 @@ Status DeviceQueueOp::SetThreadDevice() {
|
|||
}
|
||||
|
||||
Status DeviceQueueOp::LaunchParallelCopyThread() {
|
||||
RETURN_UNEXPECTED_IF_NULL(tree_);
|
||||
// Every thread that uses the CUDA API should call SetThreadDevice
|
||||
RETURN_IF_NOT_OK(SetThreadDevice());
|
||||
// CircularPool may not be safe in a multi-threaded scenario, so each worker uses its own pool
|
||||
|
@ -368,6 +369,7 @@ Status DeviceQueueOp::LaunchParallelCopyThread() {
|
|||
}
|
||||
|
||||
Status DeviceQueueOp::PushDataToGPU() {
|
||||
RETURN_UNEXPECTED_IF_NULL(tree_);
|
||||
// Every thread that uses the CUDA API should call SetThreadDevice
|
||||
RETURN_IF_NOT_OK(SetThreadDevice());
|
||||
TaskManager::FindMe()->Post();
|
||||
|
@ -376,8 +378,8 @@ Status DeviceQueueOp::PushDataToGPU() {
|
|||
int32_t connector_size = 0;
|
||||
int32_t connector_capacity = 0;
|
||||
std::shared_ptr<DeviceQueueTracing> profiling_node;
|
||||
bool isProfilingEnable = tree_->GetProfilingManager()->IsProfilingEnable();
|
||||
if (isProfilingEnable) {
|
||||
bool is_profiling_enable = tree_->GetProfilingManager()->IsProfilingEnable();
|
||||
if (is_profiling_enable) {
|
||||
std::shared_ptr<Tracing> node;
|
||||
RETURN_IF_NOT_OK(tree_->GetProfilingManager()->GetTracingNode(kDeviceQueueTracingName, &node));
|
||||
profiling_node = std::dynamic_pointer_cast<DeviceQueueTracing>(node);
|
||||
|
@ -421,7 +423,7 @@ Status DeviceQueueOp::PushDataToGPU() {
|
|||
}
|
||||
RETURN_IF_NOT_OK(RetryPushData(handle, items));
|
||||
send_batch++;
|
||||
if (isProfilingEnable) {
|
||||
if (is_profiling_enable) {
|
||||
uint64_t end_time = ProfilingTime::GetCurMilliSecond();
|
||||
// record push data time
|
||||
profiling_node->Record(TIME, TDT_PUSH_TIME, send_batch, push_cost, end_time);
|
||||
|
@ -470,7 +472,7 @@ Status DeviceQueueOp::PushDataToGPU() {
|
|||
}
|
||||
|
||||
Status DeviceQueueOp::RetryPushData(unsigned int handle, const std::vector<DataItemGpu> &items) {
|
||||
bool flagLog = false;
|
||||
bool flag_log = false;
|
||||
while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) {
|
||||
BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME);
|
||||
if (ret) {
|
||||
|
@ -479,9 +481,9 @@ Status DeviceQueueOp::RetryPushData(unsigned int handle, const std::vector<DataI
|
|||
"Invalid data, check the output of dataset with creating iterator and print data item.");
|
||||
} else {
|
||||
if (!stop_send_) {
|
||||
if (!flagLog) {
|
||||
if (!flag_log) {
|
||||
MS_LOG(DEBUG) << "Retry pushing data...";
|
||||
flagLog = true;
|
||||
flag_log = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
@ -625,11 +627,11 @@ void DeviceQueueOp::Print(std::ostream &out, bool show_all) const {
|
|||
}
|
||||
}
|
||||
|
||||
void DeviceQueueOp::ProfilingRecorder(bool isProfilingEnable, std::shared_ptr<DeviceQueueTracing> profiling_node,
|
||||
void DeviceQueueOp::ProfilingRecorder(bool is_profiling_enable, std::shared_ptr<DeviceQueueTracing> profiling_node,
|
||||
int64_t send_batch, int32_t tdt_cost, uint64_t *batch_start_time,
|
||||
uint64_t *end_time, int32_t connector_capacity, int32_t connector_size) {
|
||||
// Record the pipeline profiling info
|
||||
if (isProfilingEnable) {
|
||||
if (is_profiling_enable) {
|
||||
*end_time = ProfilingTime::GetCurMilliSecond();
|
||||
// record push tdt time
|
||||
profiling_node->Record(TIME, TDT_PUSH_TIME, send_batch + 1, tdt_cost, *end_time);
|
||||
|
|
|
@ -74,7 +74,7 @@ class DeviceQueueOp : public PipelineOp {
|
|||
|
||||
Status EoeReceived(int32_t worker_id) override;
|
||||
|
||||
const int32_t get_prefetch_size() { return prefetch_size_; }
|
||||
const int32_t GetPrefetchSize() { return prefetch_size_; }
|
||||
|
||||
void StopSend() { stop_send_ = true; }
|
||||
|
||||
|
@ -103,9 +103,9 @@ class DeviceQueueOp : public PipelineOp {
|
|||
Status operator()() override;
|
||||
|
||||
// Record the pipeline profiling info
|
||||
void ProfilingRecorder(bool isProfilingEnable, std::shared_ptr<DeviceQueueTracing> profiling_node, int64_t send_batch,
|
||||
int32_t tdt_cost, uint64_t *batch_start_time, uint64_t *end_time, int32_t connector_capacity,
|
||||
int32_t connector_size);
|
||||
void ProfilingRecorder(bool is_profiling_enable, std::shared_ptr<DeviceQueueTracing> profiling_node,
|
||||
int64_t send_batch, int32_t tdt_cost, uint64_t *batch_start_time, uint64_t *end_time,
|
||||
int32_t connector_capacity, int32_t connector_size);
|
||||
|
||||
// Op name getter
|
||||
// @return Name of the current Op
|
||||
|
@ -125,7 +125,7 @@ class DeviceQueueOp : public PipelineOp {
|
|||
void WaitContinueSignal() const;
|
||||
Status SendDataToAscend();
|
||||
void LimitSendingBatches(int64_t send_batch, int64_t *sending_num, std::shared_ptr<ConfigManager> cfg);
|
||||
Status SendRowToTdt(TensorRow currRow, bool isProfilingEnable, int32_t *tdt_cost);
|
||||
Status SendRowToTdt(TensorRow curr_row, bool is_profiling_enable, int32_t *tdt_cost);
|
||||
bool ascend_keep_waiting_;
|
||||
#endif
|
||||
|
||||
|
|
|
@ -218,7 +218,7 @@ Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate
|
|||
|
||||
return Status(StatusCode::kSuccess, "FilterOp predicate func call succeed");
|
||||
}
|
||||
int32_t FilterOp::num_consumers() const { return 1; }
|
||||
int32_t FilterOp::NumConsumers() const { return 1; }
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -69,7 +69,7 @@ class FilterOp : public ParallelOp {
|
|||
// @return Name of the current Op
|
||||
std::string Name() const override { return kFilterOp; }
|
||||
|
||||
int32_t num_consumers() const override;
|
||||
int32_t NumConsumers() const override;
|
||||
|
||||
private:
|
||||
// predicate_func python callable which returns a boolean value.
|
||||
|
|
|
@ -46,7 +46,7 @@ MapOp::MapOp(const std::vector<std::string> &in_col_names, const std::vector<std
|
|||
}
|
||||
|
||||
// The number of threads consuming data from previous op's output Connector.
|
||||
int32_t MapOp::num_consumers() const {
|
||||
int32_t MapOp::NumConsumers() const {
|
||||
// When Performance Mode is on, there is only one thread consuming from the previous Connector.
|
||||
return 1;
|
||||
}
|
||||
|
@ -144,7 +144,7 @@ Status MapOp::operator()() {
|
|||
RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
|
||||
|
||||
while (!new_row.eof()) {
|
||||
if (op_current_repeats_ % op_num_repeats_per_epoch() == 0) {
|
||||
if (op_current_repeats_ % GetOpNumRepeatsPerEpoch() == 0) {
|
||||
RETURN_IF_NOT_OK(callback_manager_.EpochBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
|
||||
}
|
||||
while (!new_row.eoe()) {
|
||||
|
@ -168,7 +168,7 @@ Status MapOp::operator()() {
|
|||
}
|
||||
|
||||
// check whether this is the end of a real epoch (not all eoe signals end of epoch)
|
||||
if ((op_current_repeats_ + 1) % op_num_repeats_per_epoch() == 0) {
|
||||
if ((op_current_repeats_ + 1) % GetOpNumRepeatsPerEpoch() == 0) {
|
||||
RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
|
||||
|
||||
ep_step = 0;
|
||||
|
|
|
@ -101,7 +101,7 @@ class MapOp : public ParallelOp {
|
|||
|
||||
// Getter
|
||||
// @return the number of threads consuming data from previous op's output Connector.
|
||||
int32_t num_consumers() const override;
|
||||
int32_t NumConsumers() const override;
|
||||
|
||||
// Op name getter
|
||||
// @return Name of the current Op
|
||||
|
|
|
@ -72,11 +72,11 @@ class ParallelOp : public DatasetOp {
|
|||
|
||||
// Getter
|
||||
// @return the number of workers
|
||||
int32_t num_workers() const override { return num_workers_; }
|
||||
int32_t NumWorkers() const override { return num_workers_; }
|
||||
|
||||
// Getter
|
||||
// @return the number of threads consuming from the previous Connector
|
||||
int32_t num_consumers() const override { return num_workers_; }
|
||||
int32_t NumConsumers() const override { return num_workers_; }
|
||||
|
||||
// Getter
|
||||
// @return the number of producers pushing to the output Connector
|
||||
|
@ -84,7 +84,7 @@ class ParallelOp : public DatasetOp {
|
|||
// when a worker connector is set up. In that case, there are n workers, and a single master
|
||||
// such that only 1 thread is a producer rather than the n workers.
|
||||
// @return the number of producers
|
||||
int32_t num_producers() const override { return num_producers_; }
|
||||
int32_t NumProducers() const override { return num_producers_; }
|
||||
|
||||
// Register the internal worker connectors.
|
||||
// @return Status
|
||||
|
|
|
@ -55,15 +55,15 @@ class PipelineOp : public DatasetOp {
|
|||
|
||||
// Getter
|
||||
// @return The number of workers inside this op. Pipeline ops only have a single worker.
|
||||
int32_t num_workers() const override { return 1; }
|
||||
int32_t NumWorkers() const override { return 1; }
|
||||
|
||||
// Getter
|
||||
// @return the number of threads consuming from the previous Connector
|
||||
int32_t num_consumers() const override { return 1; }
|
||||
int32_t NumConsumers() const override { return 1; }
|
||||
|
||||
// Getter
|
||||
// @return The number of threads that push data to the output connector
|
||||
int32_t num_producers() const override { return 1; }
|
||||
int32_t NumProducers() const override { return 1; }
|
||||
|
||||
protected:
|
||||
// *******************************************************************************
|
||||
|
|
|
@ -76,7 +76,7 @@ TensorRow ProjectOp::Project(const TensorRow &row) {
|
|||
// ensure that it is not called by mistake (it will generate an error).
|
||||
Status ProjectOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. ProjectOp is an inlined operator."); }
|
||||
|
||||
int32_t ProjectOp::num_consumers() const {
|
||||
int32_t ProjectOp::NumConsumers() const {
|
||||
if (parent_.empty()) {
|
||||
MS_LOG(DEBUG) << "Project operator, no parent node, assuming it's the root and returning 1.";
|
||||
return 1;
|
||||
|
@ -84,16 +84,16 @@ int32_t ProjectOp::num_consumers() const {
|
|||
MS_LOG(DEBUG) << "Project operator, pointer to the first parent is null. Returning 0.";
|
||||
return 0;
|
||||
} else {
|
||||
return parent_[0]->num_consumers();
|
||||
return parent_[0]->NumConsumers();
|
||||
}
|
||||
}
|
||||
|
||||
int32_t ProjectOp::num_producers() const {
|
||||
int32_t ProjectOp::NumProducers() const {
|
||||
if (child_.empty() || child_[0] == nullptr) {
|
||||
MS_LOG(DEBUG) << "Project operator, pointer to child node is null. Returning 0.";
|
||||
return 0;
|
||||
} else {
|
||||
return child_[0]->num_producers();
|
||||
return child_[0]->NumProducers();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -63,11 +63,11 @@ class ProjectOp : public PipelineOp {
|
|||
|
||||
// Base-class override. Return the number of workers in the first parent.
|
||||
// @param workerId - The worker id
|
||||
int32_t num_consumers() const override;
|
||||
int32_t NumConsumers() const override;
|
||||
|
||||
// Base-class override. Return the number of producers in the first child.
|
||||
// @param workerId - The worker id
|
||||
int32_t num_producers() const override;
|
||||
int32_t NumProducers() const override;
|
||||
|
||||
// Base-class override for special eoe handler.
|
||||
// Inline operators must override this because there is no connector to push eoe onto.
|
||||
|
|
|
@ -113,7 +113,7 @@ void RenameOp::Print(std::ostream &out, // In: The output stream to print t
|
|||
}
|
||||
}
|
||||
|
||||
int32_t RenameOp::num_consumers() const {
|
||||
int32_t RenameOp::NumConsumers() const {
|
||||
if (parent_.empty()) {
|
||||
MS_LOG(DEBUG) << "Rename operator, no parent node, assuming it's the root and returning 1.";
|
||||
return 1;
|
||||
|
@ -121,16 +121,16 @@ int32_t RenameOp::num_consumers() const {
|
|||
MS_LOG(DEBUG) << "Rename operator, pointer to the first parent is null. Returning 0.";
|
||||
return 0;
|
||||
} else {
|
||||
return parent_[0]->num_consumers();
|
||||
return parent_[0]->NumConsumers();
|
||||
}
|
||||
}
|
||||
|
||||
int32_t RenameOp::num_producers() const {
|
||||
int32_t RenameOp::NumProducers() const {
|
||||
if (child_.empty() || child_[0] == nullptr) {
|
||||
MS_LOG(DEBUG) << "Rename operator, pointer to child node is null. Returning 0.";
|
||||
return 0;
|
||||
} else {
|
||||
return child_[0]->num_producers();
|
||||
return child_[0]->NumProducers();
|
||||
}
|
||||
}
|
||||
} // namespace dataset
|
||||
|
|
|
@ -63,8 +63,8 @@ class RenameOp : public PipelineOp {
|
|||
// @param row - output pointer to the projected row.
|
||||
// @param worker_id - The worker id
|
||||
Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
|
||||
int32_t num_consumers() const override;
|
||||
int32_t num_producers() const override;
|
||||
int32_t NumConsumers() const override;
|
||||
int32_t NumProducers() const override;
|
||||
|
||||
protected:
|
||||
// Rename core functionality
|
||||
|
|
|
@ -115,7 +115,7 @@ Status RepeatOp::EofReceived(int32_t worker_id) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
int32_t RepeatOp::num_consumers() const {
|
||||
int32_t RepeatOp::NumConsumers() const {
|
||||
if (parent_.empty()) {
|
||||
MS_LOG(DEBUG) << "Repeat operator, no parent node, assuming it's root and returning 1.";
|
||||
return 1;
|
||||
|
@ -123,7 +123,7 @@ int32_t RepeatOp::num_consumers() const {
|
|||
MS_LOG(DEBUG) << "Repeat operator, pointer to the first parent is null. Returning 0.";
|
||||
return 0;
|
||||
} else {
|
||||
return parent_[0]->num_consumers();
|
||||
return parent_[0]->NumConsumers();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -140,12 +140,12 @@ Status RepeatOp::Reset() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
int32_t RepeatOp::num_producers() const {
|
||||
int32_t RepeatOp::NumProducers() const {
|
||||
if (child_.empty() || child_[0] == nullptr) {
|
||||
MS_LOG(DEBUG) << "Repeat operator, pointer to child node is null. Returning 0.";
|
||||
return 0;
|
||||
} else {
|
||||
return child_[0]->num_producers();
|
||||
return child_[0]->NumProducers();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -79,11 +79,11 @@ class RepeatOp : public PipelineOp {
|
|||
|
||||
// Base-class override. Return the number of workers in the first parent.
|
||||
// @param workerId - The worker id
|
||||
int32_t num_consumers() const override;
|
||||
int32_t NumConsumers() const override;
|
||||
|
||||
// Base-class override. Return the number of producers in the first child.
|
||||
// @param workerId - The worker id
|
||||
int32_t num_producers() const override;
|
||||
int32_t NumProducers() const override;
|
||||
|
||||
// Op name getter
|
||||
// @return Name of the current Op
|
||||
|
|
|
@ -66,7 +66,7 @@ Status SkipOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe)
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
int32_t SkipOp::num_consumers() const {
|
||||
int32_t SkipOp::NumConsumers() const {
|
||||
if (parent_.empty()) {
|
||||
MS_LOG(DEBUG) << "Return operator, no parent node, assuming it's the root and returning 1.";
|
||||
return 1;
|
||||
|
@ -74,16 +74,16 @@ int32_t SkipOp::num_consumers() const {
|
|||
MS_LOG(DEBUG) << "Return operator, pointer to the first parent is null. Returning 0.";
|
||||
return 0;
|
||||
} else {
|
||||
return parent_[0]->num_consumers();
|
||||
return parent_[0]->NumConsumers();
|
||||
}
|
||||
}
|
||||
|
||||
int32_t SkipOp::num_producers() const {
|
||||
int32_t SkipOp::NumProducers() const {
|
||||
if (child_.empty() || child_[0] == nullptr) {
|
||||
MS_LOG(DEBUG) << "Return operator, pointer to child node is null. Returning 0.";
|
||||
return 0;
|
||||
} else {
|
||||
return child_[0]->num_producers();
|
||||
return child_[0]->NumProducers();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -49,8 +49,8 @@ class SkipOp : public PipelineOp {
|
|||
// @return Name of the current Op
|
||||
std::string Name() const override { return kSkipOp; }
|
||||
Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
|
||||
int32_t num_consumers() const override;
|
||||
int32_t num_producers() const override;
|
||||
int32_t NumConsumers() const override;
|
||||
int32_t NumProducers() const override;
|
||||
|
||||
private:
|
||||
int32_t max_skips_; // The number of skips that the user requested
|
||||
|
|
|
@ -175,9 +175,9 @@ bool CelebAOp::CheckDatasetTypeValid() {
|
|||
|
||||
Status CelebAOp::ParseImageAttrInfo() {
|
||||
std::vector<std::string> image_infos;
|
||||
bool needMoreData = true;
|
||||
bool need_more_data = true;
|
||||
RETURN_IF_NOT_OK(attr_info_queue_->PopFront(&image_infos));
|
||||
while (!image_infos.empty() && needMoreData) {
|
||||
while (!image_infos.empty() && need_more_data) {
|
||||
for (uint32_t index = 0; index < image_infos.size(); index++) {
|
||||
std::string image_info = image_infos[index];
|
||||
std::vector<std::string> split = Split(image_info);
|
||||
|
|
|
@ -201,13 +201,13 @@ Status CifarOp::ReadCifar100BlockData() {
|
|||
}
|
||||
|
||||
Status CifarOp::GetCifarFiles() {
|
||||
const std::string kExtension = ".bin";
|
||||
const std::string extension = ".bin";
|
||||
Path dir_path(folder_path_);
|
||||
auto dirIt = Path::DirIterator::OpenDirectory(&dir_path);
|
||||
if (dirIt) {
|
||||
while (dirIt->HasNext()) {
|
||||
Path file = dirIt->Next();
|
||||
if (file.Extension() == kExtension) {
|
||||
if (file.Extension() == extension) {
|
||||
cifar_files_.push_back(file.ToString());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -101,7 +101,7 @@ class CifarOp : public MappableLeafOp {
|
|||
Status ParseCifarData();
|
||||
|
||||
/// Method derived from RandomAccess Op, enables Sampler to get all ids for each class
|
||||
/// @param (std::map<uint64_t, std::vector<uint64_t >> * map - key label, val all ids for this class
|
||||
/// @param std::map<int32_t, std::vector<int64_t>> *cls_ids - val all ids for this class
|
||||
/// @return Status The status code returned
|
||||
Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;
|
||||
|
||||
|
|
|
@ -119,20 +119,20 @@ Status ClueOp::LoadFile(const std::string &file, int64_t start_offset, int64_t e
|
|||
RETURN_STATUS_UNEXPECTED("Invalid file, failed to parse JSON file: " + file);
|
||||
}
|
||||
int cols_count = cols_to_keyword_.size();
|
||||
TensorRow tRow(cols_count, nullptr);
|
||||
TensorRow t_row(cols_count, nullptr);
|
||||
// Add file path info
|
||||
std::vector<std::string> file_path(cols_count, file);
|
||||
tRow.setPath(file_path);
|
||||
t_row.setPath(file_path);
|
||||
int cout = 0;
|
||||
for (auto &p : cols_to_keyword_) {
|
||||
std::shared_ptr<Tensor> tensor;
|
||||
RETURN_IF_NOT_OK(GetValue(js, p.second, &tensor));
|
||||
tRow[cout] = std::move(tensor);
|
||||
t_row[cout] = std::move(tensor);
|
||||
cout++;
|
||||
}
|
||||
|
||||
rows_total++;
|
||||
RETURN_IF_NOT_OK(jagged_rows_connector_->Add(worker_id, std::move(tRow)));
|
||||
RETURN_IF_NOT_OK(jagged_rows_connector_->Add(worker_id, std::move(t_row)));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -215,7 +215,7 @@ Status GeneratorOp::operator()() {
|
|||
// Waiting for repeatOp to start new epoch
|
||||
// If Reset() is called first by repeat op, this wait() will return right away.
|
||||
// If Reset() is not called yet, this wait() will block until reset.
|
||||
if (this->op_total_repeats() < 0) {
|
||||
if (this->GetOpTotalRepeats() < 0) {
|
||||
RETURN_IF_NOT_OK(wp_.Wait());
|
||||
// Clear the status of the wait post
|
||||
wp_.Clear();
|
||||
|
@ -235,7 +235,7 @@ Status GeneratorOp::Reset() {
|
|||
MS_LOG(DEBUG) << Name() << " performing a self-reset.";
|
||||
// Create new generator object
|
||||
RETURN_IF_NOT_OK(CreateGeneratorObject());
|
||||
if (this->op_total_repeats() < 0) {
|
||||
if (this->GetOpTotalRepeats() < 0) {
|
||||
// Wake up master thread
|
||||
wp_.Set();
|
||||
}
|
||||
|
|
|
@ -84,20 +84,20 @@ Status ImageFolderOp::PrescanMasterEntry(const std::string &filedir) {
|
|||
|
||||
// Load 1 TensorRow (image,label) using 1 ImageLabelPair. 1 function call produces 1 TensorRow
|
||||
Status ImageFolderOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
|
||||
ImageLabelPair pairPtr = image_label_pairs_[row_id];
|
||||
ImageLabelPair pair_ptr = image_label_pairs_[row_id];
|
||||
std::shared_ptr<Tensor> image, label;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateScalar(pairPtr->second, &label));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromFile(folder_path_ + (pairPtr->first), &image));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateScalar(pair_ptr->second, &label));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromFile(folder_path_ + (pair_ptr->first), &image));
|
||||
|
||||
if (decode_ == true) {
|
||||
Status rc = Decode(image, &image);
|
||||
if (rc.IsError()) {
|
||||
std::string err = "Invalid data, failed to decode image: " + folder_path_ + (pairPtr->first);
|
||||
std::string err = "Invalid data, failed to decode image: " + folder_path_ + (pair_ptr->first);
|
||||
RETURN_STATUS_UNEXPECTED(err);
|
||||
}
|
||||
}
|
||||
(*trow) = TensorRow(row_id, {std::move(image), std::move(label)});
|
||||
trow->setPath({folder_path_ + (pairPtr->first), std::string("")});
|
||||
trow->setPath({folder_path_ + (pair_ptr->first), std::string("")});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -57,11 +57,14 @@ class ImageFolderOp : public MappableLeafOp {
|
|||
// @param int32_t num_wkrs - Num of workers reading images in parallel
|
||||
// @param std::string - dir directory of ImageNetFolder
|
||||
// @param int32_t queue_size - connector queue size
|
||||
// @param std::set<std::string> exts - set of file extensions to read, if empty, read everything under the dir
|
||||
// @param td::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
|
||||
// @param bool recursive - read recursively
|
||||
// @param bool do_decode - decode the images after reading
|
||||
// @param std::set<std::string> &exts - set of file extensions to read, if empty, read everything under the dir
|
||||
// @param std::map<std::string, int32_t> &map- map of folder name and class id
|
||||
// @param std::unique_ptr<dataschema> data_schema - schema of data
|
||||
ImageFolderOp(int32_t num_wkrs, std::string file_dir, int32_t queue_size, bool recursive, bool do_decode,
|
||||
const std::set<std::string> &exts, const std::map<std::string, int32_t> &map,
|
||||
std::unique_ptr<DataSchema>, std::shared_ptr<SamplerRT> sampler);
|
||||
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);
|
||||
|
||||
/// Destructor.
|
||||
~ImageFolderOp() = default;
|
||||
|
|
|
@ -260,7 +260,9 @@ Status MnistOp::WalkAllFiles() {
|
|||
const std::string train_prefix = "train";
|
||||
const std::string test_prefix = "t10k";
|
||||
|
||||
Path dir(folder_path_);
|
||||
std::string real_path{""};
|
||||
RETURN_IF_NOT_OK(Path::RealPath(folder_path_, real_path));
|
||||
Path dir(real_path);
|
||||
auto dir_it = Path::DirIterator::OpenDirectory(&dir);
|
||||
std::string prefix; // empty string, used to match usage = "" (default) or usage == "all"
|
||||
if (usage_ == "train" || usage_ == "test") prefix = (usage_ == "test" ? test_prefix : train_prefix);
|
||||
|
|
|
@ -68,15 +68,15 @@ void RandomDataOp::Print(std::ostream &out, bool show_all) const {
|
|||
|
||||
// Helper function to produce a default/random schema if one didn't exist
|
||||
void RandomDataOp::GenerateSchema() {
|
||||
const int32_t kTypeOffset = 2;
|
||||
const int32_t type_offset = 2;
|
||||
// To randomly create a schema, we need to choose:
|
||||
// a) how many columns
|
||||
// b) the type of each column
|
||||
// c) the shape of each column (number of dimensions i.e. rank)
|
||||
// d) the shape of each column (dimension values)
|
||||
data_schema_ = std::make_unique<DataSchema>();
|
||||
std::unique_ptr<TensorShape> newShape;
|
||||
std::unique_ptr<ColDescriptor> newCol;
|
||||
std::unique_ptr<TensorShape> new_shape;
|
||||
std::unique_ptr<ColDescriptor> new_col;
|
||||
|
||||
// Loop over the number of chosen columns
|
||||
int32_t numColumns = GenRandomInt(1, kMaxNumColumns);
|
||||
|
@ -84,23 +84,26 @@ void RandomDataOp::GenerateSchema() {
|
|||
// For each column:
|
||||
// - choose a datatype
|
||||
// - generate a shape that randomly chooses the number of dimensions and the dimension values.
|
||||
DataType::Type newType = static_cast<DataType::Type>(GenRandomInt(1, DataType::NUM_OF_TYPES - kTypeOffset));
|
||||
DataType::Type newType = static_cast<DataType::Type>(GenRandomInt(1, DataType::NUM_OF_TYPES - type_offset));
|
||||
int32_t rank = GenRandomInt(1, kMaxRank);
|
||||
std::vector<dsize_t> dims;
|
||||
for (int32_t d = 0; d < rank; d++) {
|
||||
// 0 is not a valid dimension value. however, we can support "*" or unknown, so map the random
|
||||
// 0 value to the unknown attribute if 0 is chosen
|
||||
dsize_t dim_value = static_cast<dsize_t>(GenRandomInt(0, kMaxDimValue));
|
||||
if (dim_value == 0) dim_value = TensorShape::kDimUnknown;
|
||||
if (dim_value == 0) {
|
||||
dim_value = TensorShape::kDimUnknown;
|
||||
}
|
||||
dims.push_back(dim_value);
|
||||
}
|
||||
newShape = std::make_unique<TensorShape>(dims);
|
||||
new_shape = std::make_unique<TensorShape>(dims);
|
||||
|
||||
// Create the column descriptor
|
||||
std::string colName = "c" + std::to_string(i);
|
||||
newCol = std::make_unique<ColDescriptor>(colName, DataType(newType), TensorImpl::kFlexible, rank, newShape.get());
|
||||
std::string col_name = "c" + std::to_string(i);
|
||||
new_col =
|
||||
std::make_unique<ColDescriptor>(col_name, DataType(newType), TensorImpl::kFlexible, rank, new_shape.get());
|
||||
|
||||
Status rc = data_schema_->AddColumn(*newCol);
|
||||
Status rc = data_schema_->AddColumn(*new_col);
|
||||
if (rc.IsError()) MS_LOG(ERROR) << "Failed to generate a schema. Message:" << rc;
|
||||
}
|
||||
}
|
||||
|
@ -125,6 +128,9 @@ Status RandomDataOp::operator()() {
|
|||
DatasetOp::CreateConnector(num_producers_, num_workers_);
|
||||
}
|
||||
|
||||
if (num_workers_ == 0) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, num_workers_ is zero.");
|
||||
}
|
||||
// Assign the number of rows to each worker in a round robin fashion.
|
||||
worker_max_rows_.reserve(num_workers_);
|
||||
worker_rows_packed_.reserve(num_workers_);
|
||||
|
|
|
@ -681,7 +681,7 @@ Status TFReaderOp::CountTotalRows(int64_t *out_total_rows, const std::vector<std
|
|||
|
||||
int64_t TFReaderOp::CountTotalRowsSectioned(const std::vector<std::string> &filenames, int64_t begin, int64_t end) {
|
||||
int64_t rows_read = 0;
|
||||
for (int i = begin; i < end; i++) {
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
auto realpath = Common::GetRealPath(filenames[i]);
|
||||
if (!realpath.has_value()) {
|
||||
MS_LOG(ERROR) << "Get real path failed, path=" << filenames[i];
|
||||
|
|
|
@ -196,10 +196,10 @@ Status VOCOp::ParseImageIds() {
|
|||
Status VOCOp::ParseAnnotationIds() {
|
||||
std::vector<std::string> new_image_ids;
|
||||
for (auto id : image_ids_) {
|
||||
const std::string kAnnotationName =
|
||||
const std::string annotation_name =
|
||||
folder_path_ + std::string(kAnnotationsFolder) + id + std::string(kAnnotationExtension);
|
||||
RETURN_IF_NOT_OK(ParseAnnotationBbox(kAnnotationName));
|
||||
if (annotation_map_.find(kAnnotationName) != annotation_map_.end()) {
|
||||
RETURN_IF_NOT_OK(ParseAnnotationBbox(annotation_name));
|
||||
if (annotation_map_.find(annotation_name) != annotation_map_.end()) {
|
||||
new_image_ids.push_back(id);
|
||||
}
|
||||
}
|
||||
|
@ -226,7 +226,9 @@ void VOCOp::ParseNodeValue(XMLElement *bbox_node, const char *name, float *value
|
|||
*value = 0.0;
|
||||
if (bbox_node != nullptr) {
|
||||
XMLElement *node = bbox_node->FirstChildElement(name);
|
||||
if (node != nullptr) *value = node->FloatText();
|
||||
if (node != nullptr) {
|
||||
*value = node->FloatText();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -69,7 +69,7 @@ Status TakeOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe)
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
int32_t TakeOp::num_consumers() const {
|
||||
int32_t TakeOp::NumConsumers() const {
|
||||
if (parent_.empty()) {
|
||||
MS_LOG(DEBUG) << "Return operator, no parent node, assuming it's the root and returning 1.";
|
||||
return 1;
|
||||
|
@ -77,16 +77,16 @@ int32_t TakeOp::num_consumers() const {
|
|||
MS_LOG(DEBUG) << "Return operator, pointer to the first parent is null. Returning 0.";
|
||||
return 0;
|
||||
} else {
|
||||
return parent_[0]->num_consumers();
|
||||
return parent_[0]->NumConsumers();
|
||||
}
|
||||
}
|
||||
|
||||
int32_t TakeOp::num_producers() const {
|
||||
int32_t TakeOp::NumProducers() const {
|
||||
if (child_.empty() || child_[0] == nullptr) {
|
||||
MS_LOG(DEBUG) << "Return operator, pointer to child node is null. Returning 0.";
|
||||
return 0;
|
||||
} else {
|
||||
return child_[0]->num_producers();
|
||||
return child_[0]->NumProducers();
|
||||
}
|
||||
}
|
||||
} // namespace dataset
|
||||
|
|
|
@ -59,8 +59,8 @@ class TakeOp : public PipelineOp {
|
|||
std::string Name() const override { return kTakeOp; }
|
||||
|
||||
Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
|
||||
int32_t num_consumers() const override;
|
||||
int32_t num_producers() const override;
|
||||
int32_t NumConsumers() const override;
|
||||
int32_t NumProducers() const override;
|
||||
|
||||
private:
|
||||
int32_t max_takes_; // The number of takes that the user requested
|
||||
|
|
|
@ -132,7 +132,7 @@ Status ZipOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
int32_t ZipOp::num_consumers() const {
|
||||
int32_t ZipOp::NumConsumers() const {
|
||||
if (parent_.empty()) {
|
||||
MS_LOG(DEBUG) << "Return operator, no parent node, assuming it's the root and returning 1.";
|
||||
return 1;
|
||||
|
@ -140,16 +140,16 @@ int32_t ZipOp::num_consumers() const {
|
|||
MS_LOG(DEBUG) << "Return operator, pointer to the first parent is null. Returning 0.";
|
||||
return 0;
|
||||
} else {
|
||||
return parent_[0]->num_consumers();
|
||||
return parent_[0]->NumConsumers();
|
||||
}
|
||||
}
|
||||
|
||||
int32_t ZipOp::num_producers() const {
|
||||
int32_t ZipOp::NumProducers() const {
|
||||
if (child_.empty() || child_[0] == nullptr) {
|
||||
MS_LOG(DEBUG) << "Return operator, pointer to child node is null. Returning 0.";
|
||||
return 0;
|
||||
} else {
|
||||
return child_[0]->num_producers();
|
||||
return child_[0]->NumProducers();
|
||||
}
|
||||
}
|
||||
} // namespace dataset
|
||||
|
|
|
@ -65,8 +65,8 @@ class ZipOp : public PipelineOp {
|
|||
std::string Name() const override { return kZipOp; }
|
||||
|
||||
Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
|
||||
int32_t num_consumers() const override;
|
||||
int32_t num_producers() const override;
|
||||
int32_t NumConsumers() const override;
|
||||
int32_t NumProducers() const override;
|
||||
|
||||
private:
|
||||
// Special handle case where an empty row has been received from child iterator
|
||||
|
|
|
@ -79,7 +79,7 @@ Status ExecutionTree::AssociateNode(const std::shared_ptr<DatasetOp> &op) {
|
|||
tree_state_ = kDeTStateBuilding;
|
||||
|
||||
// Assign an id to the operator
|
||||
op->set_id(id_count_);
|
||||
op->SetId(id_count_);
|
||||
id_count_++;
|
||||
|
||||
// Assign our tree into the op so that each op has a link back to the tree
|
||||
|
|
|
@ -104,8 +104,8 @@ Status BatchNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
|
|||
|
||||
auto op = std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
|
||||
in_col_names_, out_col_names_, batch_size_func_, batch_map_func_, pad_map_);
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
op->SetTotalRepeats(GetTotalRepeats());
|
||||
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
#else
|
||||
node_ops->push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
|
||||
|
|
|
@ -87,8 +87,8 @@ Status BucketBatchByLengthNode::Build(std::vector<std::shared_ptr<DatasetOp>> *c
|
|||
auto op = std::make_shared<BucketBatchByLengthOp>(column_names_, bucket_boundaries_, bucket_batch_sizes_,
|
||||
element_length_function_, pad_info_, pad_to_bucket_boundary_,
|
||||
drop_remainder_, connector_que_size_);
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
op->SetTotalRepeats(GetTotalRepeats());
|
||||
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
if (bucket_boundaries_[0] == 0) {
|
||||
bucket_boundaries_.erase(bucket_boundaries_.begin());
|
||||
|
|
|
@ -56,8 +56,8 @@ void BuildSentenceVocabNode::Print(std::ostream &out) const {
|
|||
Status BuildSentenceVocabNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
auto op = std::make_shared<BuildSentencePieceVocabOp>(vocab_, col_names_, vocab_size_, character_coverage_,
|
||||
model_type_, params_, connector_que_size_);
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
op->SetTotalRepeats(GetTotalRepeats());
|
||||
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -54,8 +54,8 @@ Status BuildVocabNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node
|
|||
std::shared_ptr<BuildVocabOp> build_vocab_op;
|
||||
build_vocab_op = std::make_shared<BuildVocabOp>(vocab_, columns_, freq_range_, top_k_, special_tokens_,
|
||||
special_first_, num_workers_, connector_que_size_);
|
||||
build_vocab_op->set_total_repeats(GetTotalRepeats());
|
||||
build_vocab_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
build_vocab_op->SetTotalRepeats(GetTotalRepeats());
|
||||
build_vocab_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(build_vocab_op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -51,8 +51,8 @@ Status CacheLookupNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops)
|
|||
"Internal error. Attempt to create a cache lookup node without cache client.");
|
||||
RETURN_IF_NOT_OK(cache_->Build());
|
||||
RETURN_IF_NOT_OK(cache_->CreateCacheLookupOp(num_workers_, &lookup_op_, sampler_));
|
||||
lookup_op_->set_total_repeats(GetTotalRepeats());
|
||||
lookup_op_->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
lookup_op_->SetTotalRepeats(GetTotalRepeats());
|
||||
lookup_op_->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(lookup_op_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -48,8 +48,8 @@ Status CacheMergeNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops)
|
|||
RETURN_IF_NOT_OK(cache_->Build());
|
||||
std::shared_ptr<DatasetOp> merge_op = nullptr;
|
||||
RETURN_IF_NOT_OK(cache_->CreateCacheMergeOp(num_workers_, &merge_op));
|
||||
merge_op->set_total_repeats(GetTotalRepeats());
|
||||
merge_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
merge_op->SetTotalRepeats(GetTotalRepeats());
|
||||
merge_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(merge_op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -55,8 +55,8 @@ Status CacheNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
|
|||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
cache_op->SetSampler(sampler_rt);
|
||||
cache_op->set_total_repeats(GetTotalRepeats());
|
||||
cache_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
cache_op->SetTotalRepeats(GetTotalRepeats());
|
||||
cache_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(cache_op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -130,8 +130,8 @@ Status ConcatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
|
|||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
op = std::make_shared<ConcatOp>(sampler_rt, children_flag_and_nums_, children_start_end_index_);
|
||||
}
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
op->SetTotalRepeats(GetTotalRepeats());
|
||||
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -251,7 +251,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
|||
void HasCacheAbove() { descendant_of_cache_ = true; }
|
||||
|
||||
/// \brief Getter of the number of workers
|
||||
int32_t num_workers() { return num_workers_; }
|
||||
int32_t NumWorkers() { return num_workers_; }
|
||||
|
||||
/// \brief Getter of dataset cache
|
||||
std::shared_ptr<DatasetCache> GetDatasetCache() { return cache_; }
|
||||
|
|
|
@ -46,8 +46,8 @@ void EpochCtrlNode::Print(std::ostream &out) const {
|
|||
// Function to build the EpochCtrlOp
|
||||
Status EpochCtrlNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
auto new_op_ = std::make_shared<EpochCtrlOp>(repeat_count_);
|
||||
new_op_->set_total_repeats(GetTotalRepeats());
|
||||
new_op_->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
new_op_->SetTotalRepeats(GetTotalRepeats());
|
||||
new_op_->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(new_op_);
|
||||
op_ = new_op_;
|
||||
return Status::OK();
|
||||
|
|
|
@ -45,8 +45,8 @@ void FilterNode::Print(std::ostream &out) const {
|
|||
|
||||
Status FilterNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
auto op = std::make_shared<FilterOp>(input_columns_, num_workers_, connector_que_size_, predicate_);
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
op->SetTotalRepeats(GetTotalRepeats());
|
||||
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -86,12 +86,12 @@ Status MapNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
|||
|
||||
if (!project_columns_.empty()) {
|
||||
auto project_op = std::make_shared<ProjectOp>(project_columns_);
|
||||
project_op->set_total_repeats(GetTotalRepeats());
|
||||
project_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
project_op->SetTotalRepeats(GetTotalRepeats());
|
||||
project_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(project_op);
|
||||
}
|
||||
map_op->set_total_repeats(GetTotalRepeats());
|
||||
map_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
map_op->SetTotalRepeats(GetTotalRepeats());
|
||||
map_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(map_op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -54,8 +54,8 @@ Status ProjectNode::ValidateParams() {
|
|||
|
||||
Status ProjectNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
auto op = std::make_shared<ProjectOp>(columns_);
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
op->SetTotalRepeats(GetTotalRepeats());
|
||||
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -59,8 +59,8 @@ Status RenameNode::ValidateParams() {
|
|||
|
||||
Status RenameNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
auto op = std::make_shared<RenameOp>(input_columns_, output_columns_);
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
op->SetTotalRepeats(GetTotalRepeats());
|
||||
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -40,8 +40,8 @@ void RepeatNode::Print(std::ostream &out) const { out << (Name() + "(count:" + s
|
|||
|
||||
Status RepeatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
auto new_op = std::make_shared<RepeatOp>(repeat_count_);
|
||||
new_op->set_total_repeats(GetTotalRepeats());
|
||||
new_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
new_op->SetTotalRepeats(GetTotalRepeats());
|
||||
new_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(new_op);
|
||||
op_ = new_op;
|
||||
|
||||
|
|
|
@ -45,8 +45,8 @@ void ShuffleNode::Print(std::ostream &out) const {
|
|||
// Function to build the ShuffleOp
|
||||
Status ShuffleNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
auto op = std::make_shared<ShuffleOp>(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_);
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
op->SetTotalRepeats(GetTotalRepeats());
|
||||
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -40,8 +40,8 @@ void SkipNode::Print(std::ostream &out) const { out << (Name() + "(skip_count:"
|
|||
// Function to build the SkipOp
|
||||
Status SkipNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
auto op = std::make_shared<SkipOp>(skip_count_);
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
op->SetTotalRepeats(GetTotalRepeats());
|
||||
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -76,8 +76,8 @@ Status AlbumNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
|
|||
|
||||
auto album_op = std::make_shared<AlbumOp>(num_workers_, dataset_dir_, connector_que_size_, decode_, extensions,
|
||||
std::move(schema), std::move(sampler_rt));
|
||||
album_op->set_total_repeats(GetTotalRepeats());
|
||||
album_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
album_op->SetTotalRepeats(GetTotalRepeats());
|
||||
album_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(album_op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -72,8 +72,8 @@ Status CelebANode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
|
|||
|
||||
auto celeba_op = std::make_shared<CelebAOp>(num_workers_, dataset_dir_, connector_que_size_, decode_, usage_,
|
||||
extensions_, std::move(schema), std::move(sampler_rt));
|
||||
celeba_op->set_total_repeats(GetTotalRepeats());
|
||||
celeba_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
celeba_op->SetTotalRepeats(GetTotalRepeats());
|
||||
celeba_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(celeba_op);
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -68,8 +68,8 @@ Status Cifar100Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
|
||||
auto cifar_op = std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, dataset_dir_,
|
||||
connector_que_size_, std::move(schema), std::move(sampler_rt));
|
||||
cifar_op->set_total_repeats(GetTotalRepeats());
|
||||
cifar_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
cifar_op->SetTotalRepeats(GetTotalRepeats());
|
||||
cifar_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(cifar_op);
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -66,8 +66,8 @@ Status Cifar10Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_op
|
|||
|
||||
auto cifar_op = std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, dataset_dir_,
|
||||
connector_que_size_, std::move(schema), std::move(sampler_rt));
|
||||
cifar_op->set_total_repeats(GetTotalRepeats());
|
||||
cifar_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
cifar_op->SetTotalRepeats(GetTotalRepeats());
|
||||
cifar_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(cifar_op);
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -196,12 +196,12 @@ Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
|
|||
// Add the shuffle op after this op
|
||||
RETURN_IF_NOT_OK(
|
||||
AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
|
||||
shuffle_op->set_total_repeats(GetTotalRepeats());
|
||||
shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
shuffle_op->SetTotalRepeats(GetTotalRepeats());
|
||||
shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(shuffle_op);
|
||||
}
|
||||
clue_op->set_total_repeats(GetTotalRepeats());
|
||||
clue_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
clue_op->SetTotalRepeats(GetTotalRepeats());
|
||||
clue_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(clue_op);
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -134,8 +134,8 @@ Status CocoNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
|
|||
std::shared_ptr<CocoOp> op =
|
||||
std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, connector_que_size_, decode_,
|
||||
std::move(schema), std::move(sampler_rt), extra_metadata_);
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
op->SetTotalRepeats(GetTotalRepeats());
|
||||
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -134,12 +134,12 @@ Status CSVNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
|||
// Add the shuffle op after this op
|
||||
RETURN_IF_NOT_OK(
|
||||
AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
|
||||
shuffle_op->set_total_repeats(GetTotalRepeats());
|
||||
shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
shuffle_op->SetTotalRepeats(GetTotalRepeats());
|
||||
shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(shuffle_op);
|
||||
}
|
||||
csv_op->set_total_repeats(GetTotalRepeats());
|
||||
csv_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
csv_op->SetTotalRepeats(GetTotalRepeats());
|
||||
csv_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(csv_op);
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -97,8 +97,8 @@ Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_
|
|||
if (reset_ancestor_ != nullptr) {
|
||||
reset_ancestor_->op_->AddToEoeList(op);
|
||||
}
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
op->SetTotalRepeats(GetTotalRepeats());
|
||||
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -74,8 +74,8 @@ Status ImageFolderNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const nod
|
|||
|
||||
auto op = std::make_shared<ImageFolderOp>(num_workers_, dataset_dir_, connector_que_size_, recursive_, decode_, exts_,
|
||||
class_indexing_, std::move(schema), std::move(sampler_rt));
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
op->SetTotalRepeats(GetTotalRepeats());
|
||||
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -96,8 +96,8 @@ Status ManifestNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
|
||||
manifest_op = std::make_shared<ManifestOp>(num_workers_, dataset_file_, connector_que_size_, decode_, class_index_,
|
||||
std::move(schema), std::move(sampler_rt), usage_);
|
||||
manifest_op->set_total_repeats(GetTotalRepeats());
|
||||
manifest_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
manifest_op->SetTotalRepeats(GetTotalRepeats());
|
||||
manifest_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(manifest_op);
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -212,8 +212,8 @@ Status MindDataNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
}
|
||||
|
||||
RETURN_IF_NOT_OK(mindrecord_op->Init());
|
||||
mindrecord_op->set_total_repeats(GetTotalRepeats());
|
||||
mindrecord_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
mindrecord_op->SetTotalRepeats(GetTotalRepeats());
|
||||
mindrecord_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(mindrecord_op);
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -62,8 +62,8 @@ Status MnistNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
|
|||
|
||||
auto op = std::make_shared<MnistOp>(usage_, num_workers_, dataset_dir_, connector_que_size_, std::move(schema),
|
||||
std::move(sampler_rt));
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
op->SetTotalRepeats(GetTotalRepeats());
|
||||
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -110,8 +110,8 @@ Status RandomNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
|
|||
|
||||
std::shared_ptr<RandomDataOp> op;
|
||||
op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, total_rows_, std::move(data_schema_));
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
op->SetTotalRepeats(GetTotalRepeats());
|
||||
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -102,12 +102,12 @@ Status TextFileNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
// Add the shuffle op after this op
|
||||
RETURN_IF_NOT_OK(
|
||||
AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
|
||||
shuffle_op->set_total_repeats(GetTotalRepeats());
|
||||
shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
shuffle_op->SetTotalRepeats(GetTotalRepeats());
|
||||
shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(shuffle_op);
|
||||
}
|
||||
text_file_op->set_total_repeats(GetTotalRepeats());
|
||||
text_file_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
text_file_op->SetTotalRepeats(GetTotalRepeats());
|
||||
text_file_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
// Add TextFileOp
|
||||
node_ops->push_back(text_file_op);
|
||||
|
||||
|
|
|
@ -142,12 +142,12 @@ Status TFRecordNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
|
||||
// Add the shuffle op after this op
|
||||
RETURN_IF_NOT_OK(AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
|
||||
shuffle_op->set_total_repeats(GetTotalRepeats());
|
||||
shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
shuffle_op->SetTotalRepeats(GetTotalRepeats());
|
||||
shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(shuffle_op);
|
||||
}
|
||||
tf_reader_op->set_total_repeats(GetTotalRepeats());
|
||||
tf_reader_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
tf_reader_op->SetTotalRepeats(GetTotalRepeats());
|
||||
tf_reader_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
// Add TFReaderOp
|
||||
node_ops->push_back(tf_reader_op);
|
||||
return Status::OK();
|
||||
|
|
|
@ -122,8 +122,8 @@ Status VOCNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
|||
std::shared_ptr<VOCOp> voc_op;
|
||||
voc_op = std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, connector_que_size_,
|
||||
decode_, std::move(schema), std::move(sampler_rt), extra_metadata_);
|
||||
voc_op->set_total_repeats(GetTotalRepeats());
|
||||
voc_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
voc_op->SetTotalRepeats(GetTotalRepeats());
|
||||
voc_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(voc_op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -46,8 +46,8 @@ Status SyncWaitNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
// The reason for this is because having it otherwise can lead to blocking issues
|
||||
// See barrier_op.h for more details
|
||||
auto op = std::make_shared<BarrierOp>(connector_que_size_, condition_name_, callback_);
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
op->SetTotalRepeats(GetTotalRepeats());
|
||||
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -41,8 +41,8 @@ void TakeNode::Print(std::ostream &out) const { out << (Name() + "(num_rows:" +
|
|||
// Function to build the TakeOp
|
||||
Status TakeNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
auto op = std::make_shared<TakeOp>(take_count_);
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
op->SetTotalRepeats(GetTotalRepeats());
|
||||
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -97,8 +97,8 @@ Status TransferNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
|
||||
auto op = std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_,
|
||||
total_batch_, create_data_info_queue_);
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
op->SetTotalRepeats(GetTotalRepeats());
|
||||
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -60,8 +60,8 @@ Status ZipNode::ValidateParams() {
|
|||
|
||||
Status ZipNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
auto op = std::make_shared<ZipOp>(connector_que_size_);
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
op->SetTotalRepeats(GetTotalRepeats());
|
||||
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -68,10 +68,10 @@ Status AutoWorkerPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *con
|
|||
int32_t cur_node_num_worker = std::max(std::min(num_workers, cur_node_max), min_num_workers_);
|
||||
|
||||
// if the num_worker to set is same as original, skip setting and printing the logs
|
||||
if (cur_node_num_worker == p.first->num_workers()) continue;
|
||||
if (cur_node_num_worker == p.first->NumWorkers()) continue;
|
||||
// log the change via warning msg so user can see what the num_worker is being set for which op
|
||||
MS_LOG(WARNING) << "AutoNumWorker enabled, num_workers in " << p.first->Name() << " is auto-adjusted from "
|
||||
<< std::to_string(p.first->num_workers()) + " to " + std::to_string(cur_node_num_worker);
|
||||
<< std::to_string(p.first->NumWorkers()) + " to " + std::to_string(cur_node_num_worker);
|
||||
p.first->SetNumWorkers(cur_node_num_worker);
|
||||
}
|
||||
return Status::OK();
|
||||
|
|
|
@ -49,7 +49,7 @@ Status DeepCopyPass::Visit(std::shared_ptr<DatasetNode> node, bool *const modifi
|
|||
// Temporary fix to set the num_workers to each cloned node.
|
||||
// This can be improved by adding a new method in the base class DatasetNode to transfer the properties to
|
||||
// the cloned node. Each derived class's Copy() will need to include this method.
|
||||
new_node->SetNumWorkers(node->num_workers());
|
||||
new_node->SetNumWorkers(node->NumWorkers());
|
||||
// This method below assumes a DFS walk and from the first child to the last child.
|
||||
// Future: A more robust implementation that does not depend on the above assumption.
|
||||
RETURN_IF_NOT_OK(parent_->AppendChild(new_node));
|
||||
|
|
|
@ -42,7 +42,7 @@ json ConnectorSize::ParseOpInfo(const DatasetOp &node, const std::vector<int32_t
|
|||
json json_node;
|
||||
json_node["op_id"] = node.id();
|
||||
json_node["op_type"] = node.Name();
|
||||
json_node["num_workers"] = node.num_workers();
|
||||
json_node["num_workers"] = node.NumWorkers();
|
||||
json metrics;
|
||||
// DeviceQueueOp is a special op,it is not inlined but its output queue is invalid.
|
||||
// So we should not output its queue size.
|
||||
|
|
|
@ -78,7 +78,7 @@ json ConnectorThroughput::ParseOpInfo(const DatasetOp &node, const std::vector<d
|
|||
json json_node;
|
||||
json_node["op_id"] = node.id();
|
||||
json_node["op_type"] = node.Name();
|
||||
json_node["num_workers"] = node.num_workers();
|
||||
json_node["num_workers"] = node.NumWorkers();
|
||||
json metrics;
|
||||
// DeviceQueueOp is a special op,it is not inlined but its output queue is invalid.
|
||||
// So we should not output its connector throughput.
|
||||
|
|
|
@ -264,7 +264,7 @@ Status OperatorCpu::Collect(const ExecutionTree *tree) {
|
|||
for (auto iter = tree->begin(); iter != tree->end(); ++iter) {
|
||||
id_count_++;
|
||||
op_name_[iter->id()] = iter->NameWithID();
|
||||
op_parallel_workers_[iter->id()] = iter->num_workers();
|
||||
op_parallel_workers_[iter->id()] = iter->NumWorkers();
|
||||
}
|
||||
#if defined(USING_LINUX)
|
||||
cpu_processor_num_ = get_nprocs_conf();
|
||||
|
|
|
@ -219,7 +219,7 @@ Status TreeAdapter::GetNext(TensorRow *row) {
|
|||
RETURN_UNEXPECTED_IF_NULL(row);
|
||||
row->clear(); // make sure row is empty
|
||||
|
||||
bool isProfilingEnable = tree_->GetProfilingManager()->IsProfilingEnable();
|
||||
bool is_profiling_enable = tree_->GetProfilingManager()->IsProfilingEnable();
|
||||
|
||||
// When cur_db_ is a nullptr, it means this is the first call to get_next, launch ExecutionTree
|
||||
if (!launched_) {
|
||||
|
@ -229,7 +229,7 @@ Status TreeAdapter::GetNext(TensorRow *row) {
|
|||
RETURN_IF_NOT_OK(tree_->root()->GetNextRow(row)); // first buf can't be eof or empty buf with none flag
|
||||
if (row->eoe()) { // return empty tensor if 1st buf is a ctrl buf (no rows)
|
||||
MS_LOG(INFO) << "End of data iteration.";
|
||||
if (isProfilingEnable) {
|
||||
if (is_profiling_enable) {
|
||||
tree_->SetEpochEnd();
|
||||
}
|
||||
return Status::OK();
|
||||
|
|
|
@ -119,6 +119,9 @@ class Execute {
|
|||
/// \brief The function to validate target device setting is valid or not.
|
||||
Status ValidateDevice();
|
||||
|
||||
/// \brief Initialize 310 resource
|
||||
Status InitResource(MapTargetDevice device_type, uint32_t device_id);
|
||||
|
||||
std::vector<std::shared_ptr<TensorTransform>> transforms_;
|
||||
std::vector<std::shared_ptr<TensorOperation>> ops_;
|
||||
MapTargetDevice device_type_;
|
||||
|
|
|
@ -251,6 +251,14 @@ static bool Conv2DImplement(const LiteMat &src, const LiteMat &kernel, T2 *dst,
|
|||
int border_y = static_cast<int>(kernel.height_ / 2);
|
||||
|
||||
LiteMat pad_mat;
|
||||
|
||||
if ((border_x > INT_MAX / 2) || (src.width_ > INT_MAX - 2 * border_x)) {
|
||||
return false;
|
||||
}
|
||||
if ((border_y > INT_MAX / 2) || (src.height_ > INT_MAX - 2 * border_y)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
pad_mat.Init(src.width_ + 2 * border_x, src.height_ + 2 * border_y, src.channel_, src.data_type_);
|
||||
|
||||
if (!Pad(src, pad_mat, border_y, border_y, border_x, border_x, pad_type)) {
|
||||
|
|
|
@ -30,7 +30,7 @@ Status ComputeUpperAndLowerPercentiles(std::vector<int32_t> *hist, int32_t hi_p,
|
|||
int32_t n = std::accumulate(hist->begin(), hist->end(), 0);
|
||||
constexpr float kMaxPerc = 100.0;
|
||||
int32_t cut = static_cast<int32_t>((low_p / kMaxPerc) * n);
|
||||
for (int32_t lb = 0; lb < hist->size() + 1 && cut > 0; lb++) {
|
||||
for (int32_t lb = 0; lb < hist->size() && cut > 0; lb++) {
|
||||
if (cut > (*hist)[lb]) {
|
||||
cut -= (*hist)[lb];
|
||||
(*hist)[lb] = 0;
|
||||
|
|
|
@ -41,7 +41,9 @@ Status MixUpBatchOp::ComputeLabels(const TensorRow &input, std::shared_ptr<Tenso
|
|||
const float lam, const size_t images_size) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(images_size <= static_cast<size_t>(std::numeric_limits<int64_t>::max()),
|
||||
"The \"images_size\" must not be more than \"INT64_MAX\".");
|
||||
for (int64_t i = 0; i < static_cast<int64_t>(images_size); i++) rand_indx->push_back(i);
|
||||
for (int64_t i = 0; i < static_cast<int64_t>(images_size); i++) {
|
||||
rand_indx->push_back(i);
|
||||
}
|
||||
std::shuffle(rand_indx->begin(), rand_indx->end(), rnd_);
|
||||
|
||||
RETURN_IF_NOT_OK(TypeCast(std::move(input.at(1)), out_labels, DataType(DataType::DE_FLOAT32)));
|
||||
|
|
|
@ -48,6 +48,12 @@ RandomAffineOp::RandomAffineOp(std::vector<float_t> degrees, std::vector<float_t
|
|||
|
||||
Status RandomAffineOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(translate_range_.size() == 4, "RandomAffine: the translate range size is not 4.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(degrees_range_.size() == 2, "RandomAffine: the degrees range size is not 2.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(scale_range_.size() == 2, "RandomAffine: the scale range size is not 2.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(shear_ranges_.size() == 4, "RandomAffine: the shear ranges size is not 4.");
|
||||
|
||||
dsize_t height = input->shape()[0];
|
||||
dsize_t width = input->shape()[1];
|
||||
CHECK_FAIL_RETURN_UNEXPECTED((std::numeric_limits<float_t>::max() / std::abs(translate_range_[0])) > width,
|
||||
|
|
|
@ -20,7 +20,10 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
IntrpService::IntrpService() : high_water_mark_(0) { (void)ServiceStart(); }
|
||||
IntrpService::IntrpService() try : high_water_mark_(0) { (void)ServiceStart(); } catch (const std::exception &e) {
|
||||
MS_LOG(ERROR) << "Interrupt service failed: " << e.what() << ".";
|
||||
std::terminate();
|
||||
}
|
||||
|
||||
IntrpService::~IntrpService() noexcept {
|
||||
MS_LOG(INFO) << "Number of registered resources is " << high_water_mark_ << ".";
|
||||
|
|
|
@ -63,7 +63,6 @@ void RWLock::Unlock() noexcept {
|
|||
|
||||
void RWLock::Upgrade() {
|
||||
std::unique_lock<std::mutex> lck(mtx_);
|
||||
MS_ASSERT(status_);
|
||||
if (status_ == -1) {
|
||||
// I am a writer already.
|
||||
return;
|
||||
|
@ -81,7 +80,6 @@ void RWLock::Upgrade() {
|
|||
|
||||
void RWLock::Downgrade() {
|
||||
std::unique_lock<std::mutex> lck(mtx_);
|
||||
MS_ASSERT(status_);
|
||||
if (status_ == -1) {
|
||||
// If there are no other writers waiting, just change the status
|
||||
if (waiting_writers_ == 0) {
|
||||
|
@ -111,26 +109,18 @@ SharedLock::~SharedLock() {
|
|||
}
|
||||
|
||||
void SharedLock::Unlock() {
|
||||
MS_ASSERT(ownlock_ == true);
|
||||
rw_->Unlock();
|
||||
ownlock_ = false;
|
||||
}
|
||||
|
||||
void SharedLock::Lock() {
|
||||
MS_ASSERT(ownlock_ == false);
|
||||
rw_->LockShared();
|
||||
ownlock_ = true;
|
||||
}
|
||||
|
||||
void SharedLock::Upgrade() {
|
||||
MS_ASSERT(ownlock_ == true);
|
||||
rw_->Upgrade();
|
||||
}
|
||||
void SharedLock::Upgrade() { rw_->Upgrade(); }
|
||||
|
||||
void SharedLock::Downgrade() {
|
||||
MS_ASSERT(ownlock_ == true);
|
||||
rw_->Downgrade();
|
||||
}
|
||||
void SharedLock::Downgrade() { rw_->Downgrade(); }
|
||||
|
||||
UniqueLock::UniqueLock(RWLock *rw) : rw_(rw), ownlock_(false) {
|
||||
rw_->LockExclusive();
|
||||
|
@ -146,13 +136,11 @@ UniqueLock::~UniqueLock() {
|
|||
}
|
||||
|
||||
void UniqueLock::Unlock() {
|
||||
MS_ASSERT(ownlock_ == true);
|
||||
rw_->Unlock();
|
||||
ownlock_ = false;
|
||||
}
|
||||
|
||||
void UniqueLock::Lock() {
|
||||
MS_ASSERT(ownlock_ == false);
|
||||
rw_->LockExclusive();
|
||||
ownlock_ = true;
|
||||
}
|
||||
|
@ -171,13 +159,11 @@ LockGuard::~LockGuard() {
|
|||
}
|
||||
|
||||
void LockGuard::Unlock() {
|
||||
MS_ASSERT(own_lock_);
|
||||
lck_->Unlock();
|
||||
own_lock_ = false;
|
||||
}
|
||||
|
||||
void LockGuard::Lock() {
|
||||
MS_ASSERT(own_lock_ == false);
|
||||
lck_->Lock();
|
||||
own_lock_ = true;
|
||||
}
|
||||
|
|
|
@ -39,9 +39,9 @@ class Path {
|
|||
private:
|
||||
explicit DirIterator(Path *f);
|
||||
|
||||
Path *dir_;
|
||||
DIR *dp_;
|
||||
struct dirent *entry_;
|
||||
Path *dir_ = nullptr;
|
||||
DIR *dp_ = nullptr;
|
||||
struct dirent *entry_ = nullptr;
|
||||
};
|
||||
|
||||
explicit Path(const std::string &);
|
||||
|
|
|
@ -43,11 +43,11 @@ std::vector<std::string> StringSplit(const std::string &field, char separator) {
|
|||
}
|
||||
|
||||
bool ValidateFieldName(const std::string &str) {
|
||||
std::string::const_iterator it = str.begin();
|
||||
if (it == str.end()) {
|
||||
auto it = str.cbegin();
|
||||
if (it == str.cend()) {
|
||||
return false;
|
||||
}
|
||||
for (; it != str.end(); ++it) {
|
||||
for (; it != str.cend(); ++it) {
|
||||
if (*it == '_' || ((*it >= '0') && (*it <= '9')) || ((*it >= 'A') && (*it <= 'Z')) ||
|
||||
((*it >= 'a') && (*it <= 'z'))) {
|
||||
continue;
|
||||
|
@ -83,8 +83,7 @@ std::pair<MSRStatus, std::string> GetFileName(const std::string &path) {
|
|||
}
|
||||
#endif
|
||||
std::string s = real_path;
|
||||
char sep = '/';
|
||||
size_t i = s.rfind(sep, s.length());
|
||||
size_t i = s.rfind(kPathSeparator, s.length());
|
||||
if (i != std::string::npos) {
|
||||
if (i + 1 < s.size()) {
|
||||
return {SUCCESS, s.substr(i + 1)};
|
||||
|
@ -119,10 +118,10 @@ std::pair<MSRStatus, std::string> GetParentDir(const std::string &path) {
|
|||
}
|
||||
#endif
|
||||
std::string s = real_path;
|
||||
if (s.rfind('/') + 1 <= s.size()) {
|
||||
return {SUCCESS, s.substr(0, s.rfind('/') + 1)};
|
||||
if (s.rfind(kPathSeparator) + 1 <= s.size()) {
|
||||
return {SUCCESS, s.substr(0, s.rfind(kPathSeparator) + 1)};
|
||||
}
|
||||
return {SUCCESS, "/"};
|
||||
return {SUCCESS, std::string()};
|
||||
}
|
||||
|
||||
bool CheckIsValidUtf8(const std::string &str) {
|
||||
|
@ -155,7 +154,7 @@ bool CheckIsValidUtf8(const std::string &str) {
|
|||
bool IsLegalFile(const std::string &path) {
|
||||
struct stat s;
|
||||
if (stat(common::SafeCStr(path), &s) == 0) {
|
||||
if (s.st_mode & S_IFDIR) {
|
||||
if (S_ISDIR(s.st_mode)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue