forked from mindspore-Ecosystem/mindspore
!23357 Add DatasetName method for out developer
Merge pull request !23357 from xiefangqi/md_white_list_optimization
This commit is contained in:
commit
5abd4821ce
|
@ -84,7 +84,7 @@ Status AlbumOp::PrescanEntry() {
|
||||||
num_rows_ = image_rows_.size();
|
num_rows_ = image_rows_.size();
|
||||||
if (num_rows_ == 0) {
|
if (num_rows_ == 0) {
|
||||||
RETURN_STATUS_UNEXPECTED(
|
RETURN_STATUS_UNEXPECTED(
|
||||||
"Invalid data, AlbumDataset API can't read the data file(interface mismatch or no data found). "
|
"Invalid data, AlbumDataset API can't read the data file (interface mismatch or no data found). "
|
||||||
"Check file path:" +
|
"Check file path:" +
|
||||||
folder_path_ + ".");
|
folder_path_ + ".");
|
||||||
}
|
}
|
||||||
|
|
|
@ -217,7 +217,7 @@ Status CelebAOp::ParseImageAttrInfo() {
|
||||||
num_rows_ = image_labels_vec_.size();
|
num_rows_ = image_labels_vec_.size();
|
||||||
if (num_rows_ == 0) {
|
if (num_rows_ == 0) {
|
||||||
RETURN_STATUS_UNEXPECTED(
|
RETURN_STATUS_UNEXPECTED(
|
||||||
"Invalid data, CelebADataset API can't read the data file(interface mismatch or no data found). "
|
"Invalid data, CelebADataset API can't read the data file (interface mismatch or no data found). "
|
||||||
"Check file path: " +
|
"Check file path: " +
|
||||||
folder_path_);
|
folder_path_);
|
||||||
}
|
}
|
||||||
|
|
|
@ -256,7 +256,7 @@ Status CifarOp::ParseCifarData() {
|
||||||
if (num_rows_ == 0) {
|
if (num_rows_ == 0) {
|
||||||
std::string api = cifar_type_ == kCifar10 ? "Cifar10Dataset" : "Cifar100Dataset";
|
std::string api = cifar_type_ == kCifar10 ? "Cifar10Dataset" : "Cifar100Dataset";
|
||||||
RETURN_STATUS_UNEXPECTED("Invalid data, " + api +
|
RETURN_STATUS_UNEXPECTED("Invalid data, " + api +
|
||||||
" API can't read the data file(interface mismatch or no data found). "
|
" API can't read the data file (interface mismatch or no data found). "
|
||||||
"Check file in directory:" +
|
"Check file in directory:" +
|
||||||
folder_path_);
|
folder_path_);
|
||||||
}
|
}
|
||||||
|
|
|
@ -221,7 +221,7 @@ Status ClueOp::CalculateNumRowsPerShard() {
|
||||||
}
|
}
|
||||||
std::string file_list = ss.str();
|
std::string file_list = ss.str();
|
||||||
RETURN_STATUS_UNEXPECTED(
|
RETURN_STATUS_UNEXPECTED(
|
||||||
"Invalid data, ClueDataset API can't read the data file(interface mismatch or no data found). "
|
"Invalid data, ClueDataset API can't read the data file (interface mismatch or no data found). "
|
||||||
"Check file path:" +
|
"Check file path:" +
|
||||||
file_list);
|
file_list);
|
||||||
}
|
}
|
||||||
|
|
|
@ -322,7 +322,7 @@ Status CocoOp::ParseAnnotationIds() {
|
||||||
num_rows_ = image_ids_.size();
|
num_rows_ = image_ids_.size();
|
||||||
if (num_rows_ == 0) {
|
if (num_rows_ == 0) {
|
||||||
RETURN_STATUS_UNEXPECTED(
|
RETURN_STATUS_UNEXPECTED(
|
||||||
"Invalid data, CocoDataset API can't read the data file(interface mismatch or no data found). "
|
"Invalid data, CocoDataset API can't read the data file (interface mismatch or no data found). "
|
||||||
"Check file in directory: " +
|
"Check file in directory: " +
|
||||||
image_folder_path_ + ".");
|
image_folder_path_ + ".");
|
||||||
}
|
}
|
||||||
|
|
|
@ -454,14 +454,14 @@ Status CsvOp::LoadFile(const std::string &file, int64_t start_offset, int64_t en
|
||||||
|
|
||||||
auto realpath = FileUtils::GetRealPath(file.data());
|
auto realpath = FileUtils::GetRealPath(file.data());
|
||||||
if (!realpath.has_value()) {
|
if (!realpath.has_value()) {
|
||||||
MS_LOG(ERROR) << "Invalid file, get real path failed, path=" << file;
|
MS_LOG(ERROR) << "Invalid file, " + DatasetName() + " file get real path failed, path=" << file;
|
||||||
RETURN_STATUS_UNEXPECTED("Invalid file, get real path failed, path=" + file);
|
RETURN_STATUS_UNEXPECTED("Invalid file, " + DatasetName() + " file get real path failed, path=" + file);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::ifstream ifs;
|
std::ifstream ifs;
|
||||||
ifs.open(realpath.value(), std::ifstream::in);
|
ifs.open(realpath.value(), std::ifstream::in);
|
||||||
if (!ifs.is_open()) {
|
if (!ifs.is_open()) {
|
||||||
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file);
|
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open " + DatasetName() + " file: " + file);
|
||||||
}
|
}
|
||||||
if (column_name_list_.empty()) {
|
if (column_name_list_.empty()) {
|
||||||
std::string tmp;
|
std::string tmp;
|
||||||
|
@ -505,7 +505,8 @@ void CsvOp::Print(std::ostream &out, bool show_all) const {
|
||||||
ParallelOp::Print(out, show_all);
|
ParallelOp::Print(out, show_all);
|
||||||
// Then show any custom derived-internal stuff
|
// Then show any custom derived-internal stuff
|
||||||
out << "\nSample count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
|
out << "\nSample count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
|
||||||
<< "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nCsv files list:\n";
|
<< "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\n"
|
||||||
|
<< DatasetName(true) << " files list:\n";
|
||||||
for (int i = 0; i < csv_files_list_.size(); ++i) {
|
for (int i = 0; i < csv_files_list_.size(); ++i) {
|
||||||
out << " " << csv_files_list_[i];
|
out << " " << csv_files_list_[i];
|
||||||
}
|
}
|
||||||
|
@ -574,10 +575,9 @@ Status CsvOp::CalculateNumRowsPerShard() {
|
||||||
ss << " " << csv_files_list_[i];
|
ss << " " << csv_files_list_[i];
|
||||||
}
|
}
|
||||||
std::string file_list = ss.str();
|
std::string file_list = ss.str();
|
||||||
RETURN_STATUS_UNEXPECTED(
|
RETURN_STATUS_UNEXPECTED("Invalid data, " + DatasetName(true) +
|
||||||
"Invalid data, CSVDataset API can't read the data file(interface mismatch or no data found). "
|
"Dataset API can't read the data file (interface mismatch or no data found). Check " +
|
||||||
"Check file path: " +
|
DatasetName() + " file path: " + file_list + ".");
|
||||||
file_list + ".");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_));
|
num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_));
|
||||||
|
@ -589,13 +589,13 @@ int64_t CsvOp::CountTotalRows(const std::string &file) {
|
||||||
CsvParser csv_parser(0, jagged_rows_connector_.get(), field_delim_, column_default_list_, file);
|
CsvParser csv_parser(0, jagged_rows_connector_.get(), field_delim_, column_default_list_, file);
|
||||||
Status rc = csv_parser.InitCsvParser();
|
Status rc = csv_parser.InitCsvParser();
|
||||||
if (rc.IsError()) {
|
if (rc.IsError()) {
|
||||||
MS_LOG(ERROR) << "[Internal ERROR], failed to initialize CSV Parser. Error:" << rc;
|
MS_LOG(ERROR) << "[Internal ERROR], failed to initialize " + DatasetName(true) + " Parser. Error:" << rc;
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto realpath = FileUtils::GetRealPath(file.data());
|
auto realpath = FileUtils::GetRealPath(file.data());
|
||||||
if (!realpath.has_value()) {
|
if (!realpath.has_value()) {
|
||||||
MS_LOG(ERROR) << "Invalid file, get real path failed, path=" << file;
|
MS_LOG(ERROR) << "Invalid file, " + DatasetName() + " file get real path failed, path=" << file;
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -659,7 +659,8 @@ Status CsvOp::ComputeColMap() {
|
||||||
// Set the column name mapping (base class field)
|
// Set the column name mapping (base class field)
|
||||||
if (column_name_id_map_.empty()) {
|
if (column_name_id_map_.empty()) {
|
||||||
if (!ColumnNameValidate()) {
|
if (!ColumnNameValidate()) {
|
||||||
RETURN_STATUS_UNEXPECTED("Invalid file, failed to obtain column name from input CSV file list.");
|
RETURN_STATUS_UNEXPECTED("Invalid file, failed to obtain column name from input " + DatasetName() +
|
||||||
|
" file list.");
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto &csv_file : csv_files_list_) {
|
for (auto &csv_file : csv_files_list_) {
|
||||||
|
@ -697,8 +698,9 @@ Status CsvOp::ColMapAnalyse(const std::string &csv_file_name) {
|
||||||
if (!check_flag_) {
|
if (!check_flag_) {
|
||||||
auto realpath = FileUtils::GetRealPath(csv_file_name.data());
|
auto realpath = FileUtils::GetRealPath(csv_file_name.data());
|
||||||
if (!realpath.has_value()) {
|
if (!realpath.has_value()) {
|
||||||
MS_LOG(ERROR) << "Invalid file, get real path failed, path=" << csv_file_name;
|
std::string err_msg = "Invalid file, " + DatasetName() + " file get real path failed, path=" + csv_file_name;
|
||||||
RETURN_STATUS_UNEXPECTED("Invalid file, get real path failed, path=" + csv_file_name);
|
MS_LOG(ERROR) << err_msg;
|
||||||
|
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string line;
|
std::string line;
|
||||||
|
@ -757,7 +759,7 @@ bool CsvOp::ColumnNameValidate() {
|
||||||
for (auto &csv_file : csv_files_list_) {
|
for (auto &csv_file : csv_files_list_) {
|
||||||
auto realpath = FileUtils::GetRealPath(csv_file.data());
|
auto realpath = FileUtils::GetRealPath(csv_file.data());
|
||||||
if (!realpath.has_value()) {
|
if (!realpath.has_value()) {
|
||||||
MS_LOG(ERROR) << "Invalid file, get real path failed, path=" << csv_file;
|
MS_LOG(ERROR) << "Invalid file, " + DatasetName() + " file get real path failed, path=" << csv_file;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -179,6 +179,10 @@ class CsvOp : public NonMappableLeafOp {
|
||||||
/// @return Name of the current Op
|
/// @return Name of the current Op
|
||||||
std::string Name() const override { return "CsvOp"; }
|
std::string Name() const override { return "CsvOp"; }
|
||||||
|
|
||||||
|
// DatasetName name getter
|
||||||
|
// \return DatasetName of the current Op
|
||||||
|
virtual std::string DatasetName(bool upper = false) const { return upper ? "CSV" : "csv"; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Parses a single row and puts the data into a tensor table.
|
// Parses a single row and puts the data into a tensor table.
|
||||||
// @param line - the content of the row.
|
// @param line - the content of the row.
|
||||||
|
|
|
@ -72,10 +72,9 @@ Status ImageFolderOp::PrescanMasterEntry(const std::string &filedir) {
|
||||||
image_label_pairs_.shrink_to_fit();
|
image_label_pairs_.shrink_to_fit();
|
||||||
num_rows_ = image_label_pairs_.size();
|
num_rows_ = image_label_pairs_.size();
|
||||||
if (num_rows_ == 0) {
|
if (num_rows_ == 0) {
|
||||||
RETURN_STATUS_UNEXPECTED(
|
RETURN_STATUS_UNEXPECTED("Invalid data, " + DatasetName(true) +
|
||||||
"Invalid data, ImageFolderDataset API can't read the data file(interface mismatch or no data found). "
|
"Dataset API can't read the data file (interface mismatch or no data found). Check " +
|
||||||
"Check file path: " +
|
DatasetName() + " file path: " + folder_path_);
|
||||||
folder_path_);
|
|
||||||
}
|
}
|
||||||
// free memory of two queues used for pre-scan
|
// free memory of two queues used for pre-scan
|
||||||
folder_name_queue_->Reset();
|
folder_name_queue_->Reset();
|
||||||
|
@ -112,8 +111,8 @@ void ImageFolderOp::Print(std::ostream &out, bool show_all) const {
|
||||||
// Call the super class for displaying any common detailed info
|
// Call the super class for displaying any common detailed info
|
||||||
ParallelOp::Print(out, show_all);
|
ParallelOp::Print(out, show_all);
|
||||||
// Then show any custom derived-internal stuff
|
// Then show any custom derived-internal stuff
|
||||||
out << "\nNumber of rows:" << num_rows_ << "\nImageFolder directory: " << folder_path_
|
out << "\nNumber of rows:" << num_rows_ << "\n"
|
||||||
<< "\nDecode: " << (decode_ ? "yes" : "no") << "\n\n";
|
<< DatasetName(true) << " directory: " << folder_path_ << "\nDecode: " << (decode_ ? "yes" : "no") << "\n\n";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -121,10 +120,9 @@ void ImageFolderOp::Print(std::ostream &out, bool show_all) const {
|
||||||
Status ImageFolderOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
|
Status ImageFolderOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
|
||||||
if (cls_ids == nullptr || !cls_ids->empty() || image_label_pairs_.empty()) {
|
if (cls_ids == nullptr || !cls_ids->empty() || image_label_pairs_.empty()) {
|
||||||
if (image_label_pairs_.empty()) {
|
if (image_label_pairs_.empty()) {
|
||||||
RETURN_STATUS_UNEXPECTED(
|
RETURN_STATUS_UNEXPECTED("Invalid data, " + DatasetName(true) +
|
||||||
"Invalid data, ImageFolderDataset API can't read the data file(interface mismatch or no data found). "
|
"Dataset API can't read the data file(interface mismatch or no data found). Check " +
|
||||||
"Check file path: " +
|
DatasetName() + " file path: " + folder_path_);
|
||||||
folder_path_);
|
|
||||||
} else {
|
} else {
|
||||||
RETURN_STATUS_UNEXPECTED(
|
RETURN_STATUS_UNEXPECTED(
|
||||||
"[Internal ERROR], Map containing image-index pair is nullptr or has been set in other place,"
|
"[Internal ERROR], Map containing image-index pair is nullptr or has been set in other place,"
|
||||||
|
@ -157,7 +155,7 @@ Status ImageFolderOp::PrescanWorkerEntry(int32_t worker_id) {
|
||||||
Path folder(folder_path_ + folder_name);
|
Path folder(folder_path_ + folder_name);
|
||||||
std::shared_ptr<Path::DirIterator> dirItr = Path::DirIterator::OpenDirectory(&folder);
|
std::shared_ptr<Path::DirIterator> dirItr = Path::DirIterator::OpenDirectory(&folder);
|
||||||
if (folder.Exists() == false || dirItr == nullptr) {
|
if (folder.Exists() == false || dirItr == nullptr) {
|
||||||
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open folder: " + folder_name);
|
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open " + DatasetName() + ": " + folder_name);
|
||||||
}
|
}
|
||||||
std::set<std::string> imgs; // use this for ordering
|
std::set<std::string> imgs; // use this for ordering
|
||||||
while (dirItr->HasNext()) {
|
while (dirItr->HasNext()) {
|
||||||
|
@ -165,7 +163,7 @@ Status ImageFolderOp::PrescanWorkerEntry(int32_t worker_id) {
|
||||||
if (extensions_.empty() || extensions_.find(file.Extension()) != extensions_.end()) {
|
if (extensions_.empty() || extensions_.find(file.Extension()) != extensions_.end()) {
|
||||||
(void)imgs.insert(file.ToString().substr(dirname_offset_));
|
(void)imgs.insert(file.ToString().substr(dirname_offset_));
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(WARNING) << "Image folder operator unsupported file found: " << file.ToString()
|
MS_LOG(WARNING) << DatasetName(true) << " operator unsupported file found: " << file.ToString()
|
||||||
<< ", extension: " << file.Extension() << ".";
|
<< ", extension: " << file.Extension() << ".";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -207,7 +205,7 @@ Status ImageFolderOp::StartAsyncWalk() {
|
||||||
TaskManager::FindMe()->Post();
|
TaskManager::FindMe()->Post();
|
||||||
Path dir(folder_path_);
|
Path dir(folder_path_);
|
||||||
if (dir.Exists() == false || dir.IsDirectory() == false) {
|
if (dir.Exists() == false || dir.IsDirectory() == false) {
|
||||||
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open image folder: " + folder_path_);
|
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open " + DatasetName() + ": " + folder_path_);
|
||||||
}
|
}
|
||||||
dirname_offset_ = folder_path_.length();
|
dirname_offset_ = folder_path_.length();
|
||||||
RETURN_IF_NOT_OK(RecursiveWalkFolder(&dir));
|
RETURN_IF_NOT_OK(RecursiveWalkFolder(&dir));
|
||||||
|
@ -250,7 +248,7 @@ Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const std::se
|
||||||
std::string err_msg = "";
|
std::string err_msg = "";
|
||||||
int64_t row_cnt = 0;
|
int64_t row_cnt = 0;
|
||||||
err_msg += (dir.Exists() == false || dir.IsDirectory() == false)
|
err_msg += (dir.Exists() == false || dir.IsDirectory() == false)
|
||||||
? "Invalid parameter, image folder path is invalid or not set, path: " + path
|
? "Invalid parameter, input path is invalid or not set, path: " + path
|
||||||
: "";
|
: "";
|
||||||
err_msg +=
|
err_msg +=
|
||||||
(num_classes == nullptr && num_rows == nullptr) ? "Invalid parameter, num_class and num_rows are null.\n" : "";
|
(num_classes == nullptr && num_rows == nullptr) ? "Invalid parameter, num_class and num_rows are null.\n" : "";
|
||||||
|
|
|
@ -96,6 +96,10 @@ class ImageFolderOp : public MappableLeafOp {
|
||||||
/// @return Name of the current Op
|
/// @return Name of the current Op
|
||||||
std::string Name() const override { return "ImageFolderOp"; }
|
std::string Name() const override { return "ImageFolderOp"; }
|
||||||
|
|
||||||
|
// DatasetName name getter
|
||||||
|
// \return DatasetName of the current Op
|
||||||
|
virtual std::string DatasetName(bool upper = false) const { return upper ? "ImageFolder" : "image folder"; }
|
||||||
|
|
||||||
//// \brief Base-class override for GetNumClasses
|
//// \brief Base-class override for GetNumClasses
|
||||||
//// \param[out] num_classes the number of classes
|
//// \param[out] num_classes the number of classes
|
||||||
//// \return Status of the function
|
//// \return Status of the function
|
||||||
|
|
|
@ -258,7 +258,7 @@ Status ManifestOp::CountDatasetInfo() {
|
||||||
num_rows_ = static_cast<int64_t>(image_labelname_.size());
|
num_rows_ = static_cast<int64_t>(image_labelname_.size());
|
||||||
if (num_rows_ == 0) {
|
if (num_rows_ == 0) {
|
||||||
RETURN_STATUS_UNEXPECTED(
|
RETURN_STATUS_UNEXPECTED(
|
||||||
"Invalid data, ManifestDataset API can't read the data file(interface mismatch or no data found). "
|
"Invalid data, ManifestDataset API can't read the data file (interface mismatch or no data found). "
|
||||||
"Check file path: " +
|
"Check file path: " +
|
||||||
file_);
|
file_);
|
||||||
}
|
}
|
||||||
|
|
|
@ -66,7 +66,7 @@ void MnistOp::Print(std::ostream &out, bool show_all) const {
|
||||||
// Call the super class for displaying any common detailed info
|
// Call the super class for displaying any common detailed info
|
||||||
ParallelOp::Print(out, show_all);
|
ParallelOp::Print(out, show_all);
|
||||||
// Then show any custom derived-internal stuff
|
// Then show any custom derived-internal stuff
|
||||||
out << "\nNumber of rows:" << num_rows_ << "\nMNIST Directory: " << folder_path_ << "\n\n";
|
out << "\nNumber of rows:" << num_rows_ << "\n" << DatasetName(true) << " Directory: " << folder_path_ << "\n\n";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -74,7 +74,7 @@ void MnistOp::Print(std::ostream &out, bool show_all) const {
|
||||||
Status MnistOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
|
Status MnistOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
|
||||||
if (cls_ids == nullptr || !cls_ids->empty() || image_label_pairs_.empty()) {
|
if (cls_ids == nullptr || !cls_ids->empty() || image_label_pairs_.empty()) {
|
||||||
if (image_label_pairs_.empty()) {
|
if (image_label_pairs_.empty()) {
|
||||||
RETURN_STATUS_UNEXPECTED("Invalid data, no image found in dataset.");
|
RETURN_STATUS_UNEXPECTED("Invalid data, no image found in " + DatasetName() + " file.");
|
||||||
} else {
|
} else {
|
||||||
RETURN_STATUS_UNEXPECTED(
|
RETURN_STATUS_UNEXPECTED(
|
||||||
"[Internal ERROR] Map for containing image-index pair is nullptr or has been set in other place,"
|
"[Internal ERROR] Map for containing image-index pair is nullptr or has been set in other place,"
|
||||||
|
@ -93,7 +93,8 @@ Status MnistOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) co
|
||||||
Status MnistOp::ReadFromReader(std::ifstream *reader, uint32_t *result) {
|
Status MnistOp::ReadFromReader(std::ifstream *reader, uint32_t *result) {
|
||||||
uint32_t res = 0;
|
uint32_t res = 0;
|
||||||
reader->read(reinterpret_cast<char *>(&res), 4);
|
reader->read(reinterpret_cast<char *>(&res), 4);
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(!reader->fail(), "Invalid data, failed to read 4 bytes from file.");
|
CHECK_FAIL_RETURN_UNEXPECTED(!reader->fail(),
|
||||||
|
"Invalid data, failed to read 4 bytes from " + DatasetName() + " file.");
|
||||||
*result = SwapEndian(res);
|
*result = SwapEndian(res);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -104,16 +105,17 @@ uint32_t MnistOp::SwapEndian(uint32_t val) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MnistOp::CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images) {
|
Status MnistOp::CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images) {
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(image_reader->is_open(), "Invalid file, failed to open mnist image file: " + file_name);
|
CHECK_FAIL_RETURN_UNEXPECTED(image_reader->is_open(),
|
||||||
|
"Invalid file, failed to open " + DatasetName() + " image file: " + file_name);
|
||||||
int64_t image_len = image_reader->seekg(0, std::ios::end).tellg();
|
int64_t image_len = image_reader->seekg(0, std::ios::end).tellg();
|
||||||
(void)image_reader->seekg(0, std::ios::beg);
|
(void)image_reader->seekg(0, std::ios::beg);
|
||||||
// The first 16 bytes of the image file are type, number, row and column
|
// The first 16 bytes of the image file are type, number, row and column
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(image_len >= 16, "Invalid file, Mnist file is corrupted: " + file_name);
|
CHECK_FAIL_RETURN_UNEXPECTED(image_len >= 16, "Invalid file, " + DatasetName() + " file is corrupted: " + file_name);
|
||||||
|
|
||||||
uint32_t magic_number;
|
uint32_t magic_number;
|
||||||
RETURN_IF_NOT_OK(ReadFromReader(image_reader, &magic_number));
|
RETURN_IF_NOT_OK(ReadFromReader(image_reader, &magic_number));
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(magic_number == kMnistImageFileMagicNumber,
|
CHECK_FAIL_RETURN_UNEXPECTED(magic_number == kMnistImageFileMagicNumber,
|
||||||
"Invalid file, this is not the mnist image file: " + file_name);
|
"Invalid file, this is not the " + DatasetName() + " image file: " + file_name);
|
||||||
|
|
||||||
uint32_t num_items;
|
uint32_t num_items;
|
||||||
RETURN_IF_NOT_OK(ReadFromReader(image_reader, &num_items));
|
RETURN_IF_NOT_OK(ReadFromReader(image_reader, &num_items));
|
||||||
|
@ -132,15 +134,16 @@ Status MnistOp::CheckImage(const std::string &file_name, std::ifstream *image_re
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MnistOp::CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels) {
|
Status MnistOp::CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels) {
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(label_reader->is_open(), "Invalid file, failed to open mnist label file: " + file_name);
|
CHECK_FAIL_RETURN_UNEXPECTED(label_reader->is_open(),
|
||||||
|
"Invalid file, failed to open " + DatasetName() + " label file: " + file_name);
|
||||||
int64_t label_len = label_reader->seekg(0, std::ios::end).tellg();
|
int64_t label_len = label_reader->seekg(0, std::ios::end).tellg();
|
||||||
(void)label_reader->seekg(0, std::ios::beg);
|
(void)label_reader->seekg(0, std::ios::beg);
|
||||||
// The first 8 bytes of the image file are type and number
|
// The first 8 bytes of the image file are type and number
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(label_len >= 8, "Invalid file, Mnist file is corrupted: " + file_name);
|
CHECK_FAIL_RETURN_UNEXPECTED(label_len >= 8, "Invalid file, " + DatasetName() + " file is corrupted: " + file_name);
|
||||||
uint32_t magic_number;
|
uint32_t magic_number;
|
||||||
RETURN_IF_NOT_OK(ReadFromReader(label_reader, &magic_number));
|
RETURN_IF_NOT_OK(ReadFromReader(label_reader, &magic_number));
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(magic_number == kMnistLabelFileMagicNumber,
|
CHECK_FAIL_RETURN_UNEXPECTED(magic_number == kMnistLabelFileMagicNumber,
|
||||||
"Invalid file, this is not the mnist label file: " + file_name);
|
"Invalid file, this is not the " + DatasetName() + " label file: " + file_name);
|
||||||
uint32_t num_items;
|
uint32_t num_items;
|
||||||
RETURN_IF_NOT_OK(ReadFromReader(label_reader, &num_items));
|
RETURN_IF_NOT_OK(ReadFromReader(label_reader, &num_items));
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED((label_len - 8) == num_items, "Invalid data, number of labels is wrong.");
|
CHECK_FAIL_RETURN_UNEXPECTED((label_len - 8) == num_items, "Invalid data, number of labels is wrong.");
|
||||||
|
@ -159,18 +162,18 @@ Status MnistOp::ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *la
|
||||||
auto images_buf = std::make_unique<char[]>(size * num_images);
|
auto images_buf = std::make_unique<char[]>(size * num_images);
|
||||||
auto labels_buf = std::make_unique<char[]>(num_images);
|
auto labels_buf = std::make_unique<char[]>(num_images);
|
||||||
if (images_buf == nullptr || labels_buf == nullptr) {
|
if (images_buf == nullptr || labels_buf == nullptr) {
|
||||||
std::string err_msg = "[Internal ERROR] Failed to allocate memory for MNIST buffer.";
|
std::string err_msg = "[Internal ERROR] Failed to allocate memory for " + DatasetName() + " buffer.";
|
||||||
MS_LOG(ERROR) << err_msg.c_str();
|
MS_LOG(ERROR) << err_msg.c_str();
|
||||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||||
}
|
}
|
||||||
(void)image_reader->read(images_buf.get(), size * num_images);
|
(void)image_reader->read(images_buf.get(), size * num_images);
|
||||||
if (image_reader->fail()) {
|
if (image_reader->fail()) {
|
||||||
RETURN_STATUS_UNEXPECTED("Invalid file, failed to read image: " + image_names_[index] +
|
RETURN_STATUS_UNEXPECTED("Invalid file, failed to read " + DatasetName() + " image: " + image_names_[index] +
|
||||||
", size:" + std::to_string(size * num_images) + ". Ensure data file is not damaged.");
|
", size:" + std::to_string(size * num_images) + ". Ensure data file is not damaged.");
|
||||||
}
|
}
|
||||||
(void)label_reader->read(labels_buf.get(), num_images);
|
(void)label_reader->read(labels_buf.get(), num_images);
|
||||||
if (label_reader->fail()) {
|
if (label_reader->fail()) {
|
||||||
RETURN_STATUS_UNEXPECTED("Invalid file, failed to read label:" + label_names_[index] +
|
RETURN_STATUS_UNEXPECTED("Invalid file, failed to read " + DatasetName() + " label:" + label_names_[index] +
|
||||||
", size: " + std::to_string(num_images) + ". Ensure data file is not damaged.");
|
", size: " + std::to_string(num_images) + ". Ensure data file is not damaged.");
|
||||||
}
|
}
|
||||||
TensorShape img_tensor_shape = TensorShape({kMnistImageRows, kMnistImageCols, 1});
|
TensorShape img_tensor_shape = TensorShape({kMnistImageRows, kMnistImageCols, 1});
|
||||||
|
@ -207,10 +210,9 @@ Status MnistOp::ParseMnistData() {
|
||||||
image_label_pairs_.shrink_to_fit();
|
image_label_pairs_.shrink_to_fit();
|
||||||
num_rows_ = image_label_pairs_.size();
|
num_rows_ = image_label_pairs_.size();
|
||||||
if (num_rows_ == 0) {
|
if (num_rows_ == 0) {
|
||||||
RETURN_STATUS_UNEXPECTED(
|
RETURN_STATUS_UNEXPECTED("Invalid data, " + DatasetName(true) +
|
||||||
"Invalid data, MnistDataset API can't read the data file(interface mismatch or no data found). "
|
"Dataset API can't read the data file (interface mismatch or no data found). Check " +
|
||||||
"Check file in directory: " +
|
DatasetName() + " file in directory: " + folder_path_);
|
||||||
folder_path_);
|
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -231,14 +233,14 @@ Status MnistOp::WalkAllFiles() {
|
||||||
std::string fname = file.Basename(); // name of the mnist file
|
std::string fname = file.Basename(); // name of the mnist file
|
||||||
if ((fname.find(prefix + "-images") != std::string::npos) && (fname.find(img_ext) != std::string::npos)) {
|
if ((fname.find(prefix + "-images") != std::string::npos) && (fname.find(img_ext) != std::string::npos)) {
|
||||||
image_names_.push_back(file.ToString());
|
image_names_.push_back(file.ToString());
|
||||||
MS_LOG(INFO) << "Mnist operator found image file at " << fname << ".";
|
MS_LOG(INFO) << DatasetName(true) << " operator found image file at " << fname << ".";
|
||||||
} else if ((fname.find(prefix + "-labels") != std::string::npos) && (fname.find(lbl_ext) != std::string::npos)) {
|
} else if ((fname.find(prefix + "-labels") != std::string::npos) && (fname.find(lbl_ext) != std::string::npos)) {
|
||||||
label_names_.push_back(file.ToString());
|
label_names_.push_back(file.ToString());
|
||||||
MS_LOG(INFO) << "Mnist Operator found label file at " << fname << ".";
|
MS_LOG(INFO) << DatasetName(true) << " Operator found label file at " << fname << ".";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(WARNING) << "Mnist operator unable to open directory " << dir.ToString() << ".";
|
MS_LOG(WARNING) << DatasetName(true) << " operator unable to open directory " << dir.ToString() << ".";
|
||||||
}
|
}
|
||||||
|
|
||||||
std::sort(image_names_.begin(), image_names_.end());
|
std::sort(image_names_.begin(), image_names_.end());
|
||||||
|
@ -252,7 +254,7 @@ Status MnistOp::WalkAllFiles() {
|
||||||
|
|
||||||
Status MnistOp::LaunchThreadsAndInitOp() {
|
Status MnistOp::LaunchThreadsAndInitOp() {
|
||||||
if (tree_ == nullptr) {
|
if (tree_ == nullptr) {
|
||||||
RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set.");
|
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Pipeline init failed, Execution tree not set.");
|
||||||
}
|
}
|
||||||
RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
|
RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
|
||||||
RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
|
RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
|
||||||
|
|
|
@ -77,6 +77,10 @@ class MnistOp : public MappableLeafOp {
|
||||||
// @return Name of the current Op
|
// @return Name of the current Op
|
||||||
std::string Name() const override { return "MnistOp"; }
|
std::string Name() const override { return "MnistOp"; }
|
||||||
|
|
||||||
|
// DatasetName name getter
|
||||||
|
// \return DatasetName of the current Op
|
||||||
|
virtual std::string DatasetName(bool upper = false) const { return upper ? "Mnist" : "mnist"; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Load a tensor row according to a pair
|
// Load a tensor row according to a pair
|
||||||
// @param row_id_type row_id - id for this tensor row
|
// @param row_id_type row_id - id for this tensor row
|
||||||
|
|
|
@ -50,7 +50,8 @@ void TextFileOp::Print(std::ostream &out, bool show_all) const {
|
||||||
ParallelOp::Print(out, show_all);
|
ParallelOp::Print(out, show_all);
|
||||||
// Then show any custom derived-internal stuff
|
// Then show any custom derived-internal stuff
|
||||||
out << "\nRow count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
|
out << "\nRow count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
|
||||||
<< "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nText files list:\n";
|
<< "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\n"
|
||||||
|
<< DatasetName(true) << " list:\n";
|
||||||
for (size_t i = 0; i < text_files_list_.size(); ++i) {
|
for (size_t i = 0; i < text_files_list_.size(); ++i) {
|
||||||
out << " " << text_files_list_[i];
|
out << " " << text_files_list_[i];
|
||||||
}
|
}
|
||||||
|
@ -81,13 +82,13 @@ Status TextFileOp::LoadTensor(const std::string &line, TensorRow *out_row) {
|
||||||
Status TextFileOp::LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
|
Status TextFileOp::LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
|
||||||
auto realpath = FileUtils::GetRealPath(file.data());
|
auto realpath = FileUtils::GetRealPath(file.data());
|
||||||
if (!realpath.has_value()) {
|
if (!realpath.has_value()) {
|
||||||
MS_LOG(ERROR) << "Invalid file, get real path failed, path=" << file;
|
MS_LOG(ERROR) << "Invalid file, " + DatasetName() + " get real path failed, path=" << file;
|
||||||
RETURN_STATUS_UNEXPECTED("Invalid file, get real path failed, path=" + file);
|
RETURN_STATUS_UNEXPECTED("Invalid file, " + DatasetName() + " get real path failed, path=" + file);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::ifstream handle(realpath.value());
|
std::ifstream handle(realpath.value());
|
||||||
if (!handle.is_open()) {
|
if (!handle.is_open()) {
|
||||||
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file);
|
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open " + DatasetName() + ": " + file);
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t rows_total = 0;
|
int64_t rows_total = 0;
|
||||||
|
@ -204,10 +205,9 @@ Status TextFileOp::CalculateNumRowsPerShard() {
|
||||||
ss << " " << text_files_list_[i];
|
ss << " " << text_files_list_[i];
|
||||||
}
|
}
|
||||||
std::string file_list = ss.str();
|
std::string file_list = ss.str();
|
||||||
RETURN_STATUS_UNEXPECTED(
|
RETURN_STATUS_UNEXPECTED("Invalid data, " + DatasetName(true) +
|
||||||
"Invalid data, TextDataset API can't read the data file(interface mismatch or no data found). "
|
"Dataset API can't read the data file (interface mismatch or no data found). Check " +
|
||||||
"Check file: " +
|
DatasetName() + ": " + file_list);
|
||||||
file_list);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_));
|
num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_));
|
||||||
|
|
|
@ -74,6 +74,10 @@ class TextFileOp : public NonMappableLeafOp {
|
||||||
// @return Name of the current Op
|
// @return Name of the current Op
|
||||||
std::string Name() const override { return "TextFileOp"; }
|
std::string Name() const override { return "TextFileOp"; }
|
||||||
|
|
||||||
|
// DatasetName name getter
|
||||||
|
// \return DatasetName of the current Op
|
||||||
|
virtual std::string DatasetName(bool upper = false) const { return upper ? "TextFile" : "text file"; }
|
||||||
|
|
||||||
// File names getter
|
// File names getter
|
||||||
// @return Vector of the input file names
|
// @return Vector of the input file names
|
||||||
std::vector<std::string> FileNames() { return text_files_list_; }
|
std::vector<std::string> FileNames() { return text_files_list_; }
|
||||||
|
|
|
@ -167,7 +167,7 @@ Status TFReaderOp::CalculateNumRowsPerShard() {
|
||||||
}
|
}
|
||||||
std::string file_list = ss.str();
|
std::string file_list = ss.str();
|
||||||
RETURN_STATUS_UNEXPECTED(
|
RETURN_STATUS_UNEXPECTED(
|
||||||
"Invalid data, TFRecordDataset API can't read the data file(interface mismatch or no data under the file). "
|
"Invalid data, TFRecordDataset API can't read the data file (interface mismatch or no data under the file). "
|
||||||
"Check file path." +
|
"Check file path." +
|
||||||
file_list);
|
file_list);
|
||||||
}
|
}
|
||||||
|
|
|
@ -277,7 +277,7 @@ Status USPSOp::CalculateNumRowsPerShard() {
|
||||||
}
|
}
|
||||||
std::string file_list = ss.str();
|
std::string file_list = ss.str();
|
||||||
RETURN_STATUS_UNEXPECTED(
|
RETURN_STATUS_UNEXPECTED(
|
||||||
"Invalid data, USPSDataset API can't read the data file(interface mismatch or no data found). "
|
"Invalid data, USPSDataset API can't read the data file (interface mismatch or no data found). "
|
||||||
"Check file: " +
|
"Check file: " +
|
||||||
file_list);
|
file_list);
|
||||||
}
|
}
|
||||||
|
|
|
@ -167,7 +167,7 @@ Status VOCOp::ParseAnnotationIds() {
|
||||||
num_rows_ = image_ids_.size();
|
num_rows_ = image_ids_.size();
|
||||||
if (num_rows_ == 0) {
|
if (num_rows_ == 0) {
|
||||||
RETURN_STATUS_UNEXPECTED(
|
RETURN_STATUS_UNEXPECTED(
|
||||||
"Invalid data, VOCDataset API can't read the data file(interface mismatch or no data found). "
|
"Invalid data, VOCDataset API can't read the data file (interface mismatch or no data found). "
|
||||||
"Check file in directory:" +
|
"Check file in directory:" +
|
||||||
folder_path_);
|
folder_path_);
|
||||||
}
|
}
|
||||||
|
|
|
@ -417,7 +417,7 @@ def test_cifar_usage():
|
||||||
assert test_config("all") == 10000
|
assert test_config("all") == 10000
|
||||||
assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid")
|
assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid")
|
||||||
assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])
|
assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])
|
||||||
assert "Cifar10Dataset API can't read the data file(interface mismatch or no data found)" in test_config("test")
|
assert "Cifar10Dataset API can't read the data file (interface mismatch or no data found)" in test_config("test")
|
||||||
|
|
||||||
# test the usage of CIFAR10
|
# test the usage of CIFAR10
|
||||||
assert test_config("test", False) == 10000
|
assert test_config("test", False) == 10000
|
||||||
|
|
|
@ -270,7 +270,7 @@ def test_mnist_usage():
|
||||||
|
|
||||||
assert test_config("test") == 10000
|
assert test_config("test") == 10000
|
||||||
assert test_config("all") == 10000
|
assert test_config("all") == 10000
|
||||||
assert "MnistDataset API can't read the data file(interface mismatch or no data found)" in test_config("train")
|
assert "MnistDataset API can't read the data file (interface mismatch or no data found)" in test_config("train")
|
||||||
assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid")
|
assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid")
|
||||||
assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])
|
assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue