[MD][UT] Use default error_samples_mode and debug_mode_flag config values in declient.cfg

This commit is contained in:
Cathy Wong 2022-12-12 11:05:10 -05:00
parent fabc9304e8
commit 0b2afe87f3
3 changed files with 14 additions and 18 deletions

View File

@ -273,7 +273,7 @@ TEST_F(MindDataTestProfiler, TestProfilerManagerByEpoch) {
/// Feature: MindData Profiling Support /// Feature: MindData Profiling Support
/// Description: Test MindData Profiling GetByStep Methods /// Description: Test MindData Profiling GetByStep Methods
/// Expectation: Results are successfully outputted. /// Expectation: Results are successfully outputted.
TEST_F(MindDataTestProfiler, DISABLED_TestProfilerManagerByStep) { TEST_F(MindDataTestProfiler, TestProfilerManagerByStep) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestProfilerManagerByStep."; MS_LOG(INFO) << "Doing MindDataTestPipeline-TestProfilerManagerByStep.";
// Enable profiler and check // Enable profiler and check
common::SetEnv("RANK_ID", "2"); common::SetEnv("RANK_ID", "2");

View File

@ -7,6 +7,6 @@
"seed": 5489, "seed": 5489,
"monitorSamplingInterval": 15, "monitorSamplingInterval": 15,
"fast_recovery": true, "fast_recovery": true,
"debug_mode_flag": true, "debug_mode_flag": false,
"error_samples_mode": 1 "error_samples_mode": 0
} }

View File

@ -29,6 +29,9 @@ import mindspore.dataset.core.config as config
from mindspore import log as logger from mindspore import log as logger
from util import dataset_equal from util import dataset_equal
# Need to run all these tests in separate processes since tests are modifying config parameters
pytestmark = pytest.mark.forked
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
@ -43,7 +46,6 @@ def config_error_func(config_interface, input_args, err_type, except_err_msg):
assert except_err_msg in err_msg assert except_err_msg in err_msg
@pytest.mark.forked
def test_basic(): def test_basic():
""" """
Feature: Config Feature: Config
@ -67,8 +69,8 @@ def test_basic():
assert ds.config.get_seed() == 5489 assert ds.config.get_seed() == 5489
assert ds.config.get_monitor_sampling_interval() == 15 assert ds.config.get_monitor_sampling_interval() == 15
assert ds.config.get_fast_recovery() assert ds.config.get_fast_recovery()
assert ds.config.get_debug_mode() assert not ds.config.get_debug_mode()
assert ds.config.get_error_samples_mode() == config.ErrorSamplesMode.REPLACE assert ds.config.get_error_samples_mode() == config.ErrorSamplesMode.RETURN
ds.config.set_num_parallel_workers(2) ds.config.set_num_parallel_workers(2)
# ds.config.set_worker_connector_size(3) # ds.config.set_worker_connector_size(3)
@ -76,8 +78,8 @@ def test_basic():
ds.config.set_seed(5) ds.config.set_seed(5)
ds.config.set_monitor_sampling_interval(45) ds.config.set_monitor_sampling_interval(45)
ds.config.set_fast_recovery(False) ds.config.set_fast_recovery(False)
ds.config.set_debug_mode(False) ds.config.set_debug_mode(True)
ds.config.set_error_samples_mode(config.ErrorSamplesMode.RETURN) ds.config.set_error_samples_mode(config.ErrorSamplesMode.REPLACE)
assert ds.config.get_num_parallel_workers() == 2 assert ds.config.get_num_parallel_workers() == 2
# assert ds.config.get_worker_connector_size() == 3 # assert ds.config.get_worker_connector_size() == 3
@ -85,13 +87,15 @@ def test_basic():
assert ds.config.get_seed() == 5 assert ds.config.get_seed() == 5
assert ds.config.get_monitor_sampling_interval() == 45 assert ds.config.get_monitor_sampling_interval() == 45
assert not ds.config.get_fast_recovery() assert not ds.config.get_fast_recovery()
assert not ds.config.get_debug_mode() assert ds.config.get_debug_mode()
assert ds.config.get_error_samples_mode() == config.ErrorSamplesMode.RETURN assert ds.config.get_error_samples_mode() == config.ErrorSamplesMode.REPLACE
ds.config.set_fast_recovery(True) ds.config.set_fast_recovery(True)
ds.config.set_debug_mode(False)
ds.config.set_error_samples_mode(config.ErrorSamplesMode.SKIP) ds.config.set_error_samples_mode(config.ErrorSamplesMode.SKIP)
assert ds.config.get_fast_recovery() assert ds.config.get_fast_recovery()
assert not ds.config.get_debug_mode()
assert ds.config.get_error_samples_mode() == config.ErrorSamplesMode.SKIP assert ds.config.get_error_samples_mode() == config.ErrorSamplesMode.SKIP
# Restore original configuration values # Restore original configuration values
@ -534,18 +538,12 @@ def test_fast_recovery():
assert "set_fast_recovery() missing 1 required positional argument: 'fast_recovery'" in str(error_info.value) assert "set_fast_recovery() missing 1 required positional argument: 'fast_recovery'" in str(error_info.value)
@pytest.mark.forked
def test_debug_mode(): def test_debug_mode():
""" """
Feature: Test the debug mode setter/getter function Feature: Test the debug mode setter/getter function
Description: This function only accepts a boolean as input and outputs error otherwise Description: This function only accepts a boolean as input and outputs error otherwise
Expectation: TypeError will be raised when input argument is missing or is not a boolean Expectation: TypeError will be raised when input argument is missing or is not a boolean
""" """
origin_debug_mode_flag = ds.config.get_debug_mode()
# set_debug_mode() to True and then check if the value is indeed True with get_debug_mode().
debug_mode_flag = True
ds.config.set_debug_mode(debug_mode_flag)
assert ds.config.get_debug_mode() == debug_mode_flag
# set_debug_mode will raise TypeError if input is an integer # set_debug_mode will raise TypeError if input is an integer
config_error_func(ds.config.set_debug_mode, 0, TypeError, "debug_mode_flag isn't of type boolean.") config_error_func(ds.config.set_debug_mode, 0, TypeError, "debug_mode_flag isn't of type boolean.")
# set_debug_mode will raise TypeError if input is a string # set_debug_mode will raise TypeError if input is a string
@ -558,8 +556,6 @@ def test_debug_mode():
with pytest.raises(TypeError) as error_info: with pytest.raises(TypeError) as error_info:
ds.config.set_debug_mode() ds.config.set_debug_mode()
assert "set_debug_mode() missing 1 required positional argument: 'debug_mode_flag'" in str(error_info.value) assert "set_debug_mode() missing 1 required positional argument: 'debug_mode_flag'" in str(error_info.value)
# restore to original debug_mode_flag
ds.config.set_debug_mode(origin_debug_mode_flag)
def test_error_samples_mode(): def test_error_samples_mode():