forked from mindspore-Ecosystem/mindspore
!4847 [MD] Add some data validation
Merge pull request !4847 from xiefangqi/xfq_add_numworker_validate
This commit is contained in:
commit
98052f9d06
|
@ -309,6 +309,19 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
||||||
/// \param[in] num_workers The number of threads in this operator
|
/// \param[in] num_workers The number of threads in this operator
|
||||||
/// \return Shared pointer to the original object
|
/// \return Shared pointer to the original object
|
||||||
std::shared_ptr<Dataset> SetNumWorkers(int32_t num_workers) {
|
std::shared_ptr<Dataset> SetNumWorkers(int32_t num_workers) {
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
#ifndef ENABLE_ANDROID
|
||||||
|
int32_t cpu_count = sysconf(_SC_NPROCESSORS_CONF);
|
||||||
|
if (cpu_count < 0 || cpu_count > INT32_MAX) {
|
||||||
|
MS_LOG(ERROR) << "Error determining current CPU: " << cpu_count;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (num_workers < 1 || num_workers > cpu_count) {
|
||||||
|
MS_LOG(ERROR) << "num_workers exceeds the boundary between 1 and " << cpu_count;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
num_workers_ = num_workers;
|
num_workers_ = num_workers;
|
||||||
return shared_from_this();
|
return shared_from_this();
|
||||||
}
|
}
|
||||||
|
@ -336,7 +349,7 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
||||||
/// range would be kept. 0 <= min_frequency <= max_frequency <= total_words. min_frequency/max_frequency
|
/// range would be kept. 0 <= min_frequency <= max_frequency <= total_words. min_frequency/max_frequency
|
||||||
/// can be set to default, which corresponds to 0/total_words separately
|
/// can be set to default, which corresponds to 0/total_words separately
|
||||||
/// \param[in] top_k Number of words to be built into vocab. top_k most frequent words are
|
/// \param[in] top_k Number of words to be built into vocab. top_k most frequent words are
|
||||||
// taken. The top_k is taken after freq_range. If not enough top_k, all words will be taken
|
/// taken. The top_k is taken after freq_range. If not enough top_k, all words will be taken
|
||||||
/// \param[in] special_tokens A list of strings, each one is a special token
|
/// \param[in] special_tokens A list of strings, each one is a special token
|
||||||
/// \param[in] special_first Whether special_tokens will be prepended/appended to vocab, If special_tokens
|
/// \param[in] special_first Whether special_tokens will be prepended/appended to vocab, If special_tokens
|
||||||
/// is specified and special_first is set to default, special_tokens will be prepended
|
/// is specified and special_first is set to default, special_tokens will be prepended
|
||||||
|
|
|
@ -555,7 +555,7 @@ def check_map(method):
|
||||||
callbacks], _ = \
|
callbacks], _ = \
|
||||||
parse_user_args(method, *args, **kwargs)
|
parse_user_args(method, *args, **kwargs)
|
||||||
|
|
||||||
nreq_param_columns = ['input_columns', 'output_columns']
|
nreq_param_columns = ['input_columns', 'output_columns', 'columns_order']
|
||||||
|
|
||||||
if columns_order is not None:
|
if columns_order is not None:
|
||||||
type_check(columns_order, (list,), "columns_order")
|
type_check(columns_order, (list,), "columns_order")
|
||||||
|
@ -571,7 +571,7 @@ def check_map(method):
|
||||||
else:
|
else:
|
||||||
type_check(callbacks, (callback.DSCallback,), "callbacks")
|
type_check(callbacks, (callback.DSCallback,), "callbacks")
|
||||||
|
|
||||||
for param_name, param in zip(nreq_param_columns, [input_columns, output_columns]):
|
for param_name, param in zip(nreq_param_columns, [input_columns, output_columns, columns_order]):
|
||||||
if param is not None:
|
if param is not None:
|
||||||
check_columns(param, param_name)
|
check_columns(param, param_name)
|
||||||
if callbacks is not None:
|
if callbacks is not None:
|
||||||
|
|
|
@ -950,3 +950,25 @@ TEST_F(MindDataTestPipeline, TestZipSuccess2) {
|
||||||
// Manually terminate the pipeline
|
// Manually terminate the pipeline
|
||||||
iter->Stop();
|
iter->Stop();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
#ifndef ENABLE_ANDROID
|
||||||
|
TEST_F(MindDataTestPipeline, TestNumWorkersValidate) {
|
||||||
|
// Testing the num_workers validation in SetNumWorkers()
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestNumWorkersValidate.";
|
||||||
|
|
||||||
|
// Create an ImageFolder Dataset
|
||||||
|
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||||
|
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 9));
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
// test if set num_workers=-1
|
||||||
|
std::shared_ptr<Dataset> ds1 = ds->SetNumWorkers(-1);
|
||||||
|
EXPECT_EQ(ds1, nullptr);
|
||||||
|
|
||||||
|
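// test if set num_workers>cpu_count (NOTE(review): UINT32_MAX passed as int32_t typically narrows to -1, so this likely hits the num_workers < 1 check instead; consider INT32_MAX)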
|
||||||
|
std::shared_ptr<Dataset> ds2 = ds->SetNumWorkers(UINT32_MAX);
|
||||||
|
EXPECT_EQ(ds2, nullptr);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
|
Loading…
Reference in New Issue