diff --git a/docs/api/api_python/mindspore.dataset.config.rst b/docs/api/api_python/mindspore.dataset.config.rst index d9156ba5da0..a8b73070ee6 100644 --- a/docs/api/api_python/mindspore.dataset.config.rst +++ b/docs/api/api_python/mindspore.dataset.config.rst @@ -46,7 +46,8 @@ API示例所需模块的导入代码如下: **异常:** - - **ValueError** - `seed` 小于0或 `seed` 大于MAX_UINT_32时, `seed` 无效。 + - **TypeError** - `seed` 不是int类型。 + - **ValueError** - `seed` 小于0或 `seed` 大于 `UINT32_MAX(4294967295)` 时, `seed` 无效。 .. py:function:: mindspore.dataset.config.get_seed() @@ -66,7 +67,8 @@ API示例所需模块的导入代码如下: **异常:** - - **ValueError** - 当 `size` 小于等于0或 `size` 大于 `MAX_INT_32` 时,线程的队列容量无效。 + - **TypeError** - `size` 不是int类型。 + - **ValueError** - `size` 小于等于0或 `size` 大于 `INT32_MAX(2147483647)` 时,线程的队列容量无效。 .. note:: 用于预取的总内存可能会随着工作线程数量的增加而快速增长,所以当工作线程数量大于4时,每个工作线程的预取大小将减少。 @@ -91,7 +93,8 @@ API示例所需模块的导入代码如下: **异常:** - - **ValueError** - `num` 小于等于0或 `num` 大于MAX_INT_32时,并行工作线程数量设置无效。 + - **TypeError** - `num` 不是int类型。 + - **ValueError** - `num` 小于等于0或 `num` 大于 `INT32_MAX(2147483647)` 时,并行工作线程数量设置无效。 .. py:function:: mindspore.dataset.config.get_num_parallel_workers() @@ -133,7 +136,8 @@ API示例所需模块的导入代码如下: **异常:** - - **ValueError** - `interval` 小于等于0或 `interval` 大于MAX_INT_32时, `interval` 无效。 + - **TypeError** - `interval` 不是int类型。 + - **ValueError** - `interval` 小于等于0或 `interval` 大于 `INT32_MAX(2147483647)` 时, `interval` 无效。 .. py:function:: mindspore.dataset.config.get_monitor_sampling_interval() @@ -154,7 +158,8 @@ API示例所需模块的导入代码如下: **异常:** - - **ValueError** - `timeout` 小于等于0或 `timeout` 大于MAX_INT_32时 `timeout` 无效。 + - **TypeError** - `timeout` 不是int类型。 + - **ValueError** - `timeout` 小于等于0或 `timeout` 大于 `INT32_MAX(2147483647)` 时 `timeout` 无效。 .. py:function:: mindspore.dataset.config.get_callback_timeout() @@ -295,23 +300,24 @@ API示例所需模块的导入代码如下: .. py:function:: mindspore.dataset.config.set_multiprocessing_timeout_interval(interval) - 设置在多进程下,主进程获取数据超时时,告警日志打印的默认时间间隔(秒)。 + 设置在多进程/多线程下,主进程/主线程获取数据超时时,告警日志打印的默认时间间隔(秒)。 **参数:** - - **interval** (int) - 表示多进程下,主进程获取数据超时时,告警日志打印的时间间隔(秒)。 + - **interval** (int) - 表示多进程/多线程下,主进程/主线程获取数据超时时,告警日志打印的时间间隔(秒)。 **异常:** - - **ValueError** - `interval` 小于等于0或 `interval` 大于MAX_INT_32时, `interval` 无效。 + - **TypeError** - `interval` 不是int类型。 + - **ValueError** - `interval` 小于等于0或 `interval` 大于 `INT32_MAX(2147483647)` 时, `interval` 无效。 .. py:function:: mindspore.dataset.config.get_multiprocessing_timeout_interval() - 获取在多进程下,主进程获取数据超时时,告警日志打印的时间间隔的全局配置。 + 获取在多进程/多线程下,主进程/主线程获取数据超时时,告警日志打印的时间间隔的全局配置。 **返回:** - int,表示多进程下,主进程获取数据超时时,告警日志打印的时间间隔(秒)。 + int,表示多进程/多线程下,主进程/主线程获取数据超时时,告警日志打印的时间间隔(秒)。 .. automodule:: mindspore.dataset.config :members: diff --git a/mindspore/python/mindspore/dataset/core/config.py b/mindspore/python/mindspore/dataset/core/config.py index 1c0a2ef934c..909a2b81b80 100644 --- a/mindspore/python/mindspore/dataset/core/config.py +++ b/mindspore/python/mindspore/dataset/core/config.py @@ -107,17 +107,18 @@ def set_seed(seed): seed(int): Random number seed. It is used to generate deterministic random numbers. Raises: - ValueError: If seed is invalid when seed < 0 or seed > MAX_UINT_32. + TypeError: If seed isn't of type int. + ValueError: If seed < 0 or seed > UINT32_MAX(4294967295). Examples: >>> # Set a new global configuration value for the seed value. >>> # Operations with randomness will use the seed value to generate random values. >>> ds.config.set_seed(1000) """ - if not isinstance(seed, int): - raise ValueError("seed isn't of type int.") + if not isinstance(seed, int) or isinstance(seed, bool): + raise TypeError("seed isn't of type int.") if seed < 0 or seed > UINT32_MAX: - raise ValueError("Seed given is not within the required range.") + raise ValueError("seed given is not within the required range [0, UINT32_MAX(4294967295)].") _config.set_seed(seed) random.seed(seed) # numpy.random isn't thread safe @@ -149,7 +150,8 @@ def set_prefetch_size(size): size (int): The length of the cache queue. Raises: - ValueError: If the queue capacity of the thread is invalid when size <= 0 or size > MAX_INT_32. + TypeError: If size is not of type int. + ValueError: If size <= 0 or size > INT32_MAX(2147483647). Note: Since total memory used for prefetch can grow very large with high number of workers, @@ -160,10 +162,10 @@ def set_prefetch_size(size): >>> # Set a new global configuration value for the prefetch size. >>> ds.config.set_prefetch_size(1000) """ - if not isinstance(size, int): - raise ValueError("size isn't of type int.") + if not isinstance(size, int) or isinstance(size, bool): + raise TypeError("size isn't of type int.") if size <= 0 or size > INT32_MAX: - raise ValueError("Prefetch size given is not within the required range.") + raise ValueError("size is not within the required range (0, INT32_MAX(2147483647)].") _config.set_op_connector_size(size) @@ -191,17 +193,19 @@ def set_num_parallel_workers(num): num (int): Number of parallel workers to be used as a default for each operation. Raises: - ValueError: If num_parallel_workers is invalid when num <= 0 or num > MAX_INT_32. + TypeError: If num is not of type int. + ValueError: If num <= 0 or num > INT32_MAX(2147483647). Examples: >>> # Set a new global configuration value for the number of parallel workers. >>> # Now parallel dataset operators will run with 8 workers. >>> ds.config.set_num_parallel_workers(8) """ - if not isinstance(num, int): - raise ValueError("num isn't of type int.") + if not isinstance(num, int) or isinstance(num, bool): + raise TypeError("num isn't of type int.") if num <= 0 or num > INT32_MAX: - raise ValueError("Number of parallel workers given is not within the required range.") + raise ValueError("Number of parallel workers given is not within the required range" + " (0, INT32_MAX(2147483647)].") _config.set_num_parallel_workers(num) @@ -265,16 +269,17 @@ def set_monitor_sampling_interval(interval): interval (int): Interval (in milliseconds) to be used for performance monitor sampling. Raises: - ValueError: If interval is invalid when interval <= 0 or interval > MAX_INT_32. + TypeError: If interval is not type int. + ValueError: If interval <= 0 or interval > INT32_MAX(2147483647). Examples: >>> # Set a new global configuration value for the monitor sampling interval. >>> ds.config.set_monitor_sampling_interval(100) """ - if not isinstance(interval, int): - raise ValueError("interval isn't of type int.") + if not isinstance(interval, int) or isinstance(interval, bool): + raise TypeError("interval isn't of type int.") if interval <= 0 or interval > INT32_MAX: - raise ValueError("Interval given is not within the required range.") + raise ValueError("Interval given is not within the required range (0, INT32_MAX(2147483647)].") _config.set_monitor_sampling_interval(interval) @@ -334,10 +339,11 @@ def _set_auto_workers_config(option): Args: option (int): The id of the profile to use. Raises: - ValueError: If option is not int or not within the range of [0, 6] + TypeError: If option is not of type int. + ValueError: If option is not within the range of [0, 6]. """ - if not isinstance(option, int): - raise ValueError("option isn't of type int.") + if not isinstance(option, int) or isinstance(option, bool): + raise TypeError("option isn't of type int.") if option < 0 or option > 6: raise ValueError("option isn't within the required range of [0, 6].") _config.set_auto_worker_config(option) @@ -366,14 +372,15 @@ def set_callback_timeout(timeout): timeout (int): Timeout (in seconds) to be used to end the wait in DSWaitedCallback in case of a deadlock. Raises: - ValueError: If timeout is invalid when timeout <= 0 or timeout > MAX_INT_32. + TypeError: If timeout is not type int. + ValueError: If timeout <= 0 or timeout > INT32_MAX(2147483647). Examples: >>> # Set a new global configuration value for the timeout value. >>> ds.config.set_callback_timeout(100) """ - if not isinstance(timeout, int): - raise ValueError("timeout isn't of type int.") + if not isinstance(timeout, int) or isinstance(timeout, bool): + raise TypeError("timeout isn't of type int.") if timeout <= 0 or timeout > INT32_MAX: raise ValueError("Timeout given is not within the required range.") _config.set_callback_timeout(timeout) @@ -516,10 +523,10 @@ def set_autotune_interval(interval): >>> # set a new interval for AutoTune >>> ds.config.set_autotune_interval(30) """ - if not isinstance(interval, int): + if not isinstance(interval, int) or isinstance(interval, bool): raise TypeError("interval must be of type int.") if interval < 0 or interval > INT32_MAX: - raise ValueError("Interval given is not within the required range.") + raise ValueError("Interval given is not within the required range [0, INT32_MAX(2147483647)].") _config.set_autotune_interval(interval) @@ -604,7 +611,7 @@ def set_sending_batches(batch_num): >>> # Set a new global configuration value for the sending batches >>> ds.config.set_sending_batches(10) """ - if not isinstance(batch_num, int): + if not isinstance(batch_num, int) or isinstance(batch_num, bool): raise TypeError("batch_num must be an int dtype.") _config.set_sending_batches(batch_num) @@ -680,37 +687,40 @@ def get_enable_watchdog(): def set_multiprocessing_timeout_interval(interval): """ - Set the default interval (in seconds) for multiprocessing timeout when main process gets data from subprocesses. + Set the default interval (in seconds) for multiprocessing/multithreading timeout when main process/thread gets + data from subprocesses/child threads. Args: - interval (int): Interval (in seconds) to be used for multiprocessing timeout when main process gets data from - subprocess. System default: 300s. + interval (int): Interval (in seconds) to be used for multiprocessing/multithreading timeout when main + process/thread gets data from subprocess/child threads. System default: 300s. Raises: - ValueError: If interval is invalid when interval <= 0 or interval > MAX_INT_32. + ValueError: If interval <= 0 or interval > INT32_MAX(2147483647). Examples: - >>> # Set a new global configuration value for multiprocessing timeout when getting data. + >>> # Set a new global configuration value for multiprocessing/multithreading timeout when getting data. >>> ds.config.set_multiprocessing_timeout_interval(300) """ - if not isinstance(interval, int): - raise ValueError("interval isn't of type int.") + if not isinstance(interval, int) or isinstance(interval, bool): + raise TypeError("interval isn't of type int.") if interval <= 0 or interval > INT32_MAX: - raise ValueError("Interval given is not within the required range (0, INT32_MAX).") + raise ValueError("Interval given is not within the required range (0, INT32_MAX(2147483647)).") _config.set_multiprocessing_timeout_interval(interval) def get_multiprocessing_timeout_interval(): """ - Get the global configuration of multiprocessing timeout when main process gets data from subprocesses. + Get the global configuration of multiprocessing/multithreading timeout when main process/thread gets data from + subprocesses/child threads. Returns: - int, interval (in seconds) for multiprocessing timeout when main process gets data from subprocesses - (default is 300s). + int, interval (in seconds) for multiprocessing/multithreading timeout when main process/thread gets data from + subprocesses/child threads (default is 300s). Examples: - >>> # Get the global configuration of multiprocessing timeout when main process gets data from subprocesses. - >>> # If set_multiprocessing_timeout_interval() is never called before, the default value(300) will be returned. + >>> # Get the global configuration of multiprocessing/multithreading timeout when main process/thread gets data + >>> # from subprocesses/child threads. If set_multiprocessing_timeout_interval() is never called before, the + >>> # default value(300) will be returned. >>> multiprocessing_timeout_interval = ds.config.get_multiprocessing_timeout_interval() """ return _config.get_multiprocessing_timeout_interval() diff --git a/tests/ut/python/dataset/test_config.py b/tests/ut/python/dataset/test_config.py index 267cfcbb28f..3be82363cd2 100644 --- a/tests/ut/python/dataset/test_config.py +++ b/tests/ut/python/dataset/test_config.py @@ -32,6 +32,16 @@ DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" +def config_error_func(config_interface, input_args, err_type, except_err_msg): + err_msg = "" + try: + config_interface(input_args) + except err_type as e: + err_msg = str(e) + + assert except_err_msg in err_msg + + def test_basic(): """ Test basic configuration functions @@ -429,6 +439,37 @@ def test_multiprocessing_timeout_interval(): assert saved_config == ds.config.get_multiprocessing_timeout_interval() +def test_config_bool_type_error(): + """ + Feature: Now many interfaces of config support bool input even its valid input is int. + Description: We will raise a type error when input is a bool when it should be int. + Expectation: TypeError will be raised when input is a bool. + """ + # set_seed will raise TypeError if input is a boolean + config_error_func(ds.config.set_seed, True, TypeError, "seed isn't of type int") + + # set_prefetch_size will raise TypeError if input is a boolean + config_error_func(ds.config.set_prefetch_size, True, TypeError, "size isn't of type int") + + # set_num_parallel_workers will raise TypeError if input is a boolean + config_error_func(ds.config.set_num_parallel_workers, True, TypeError, "num isn't of type int") + + # set_monitor_sampling_interval will raise TypeError if input is a boolean + config_error_func(ds.config.set_monitor_sampling_interval, True, TypeError, "interval isn't of type int") + + # set_callback_timeout will raise TypeError if input is a boolean + config_error_func(ds.config.set_callback_timeout, True, TypeError, "timeout isn't of type int") + + # set_autotune_interval will raise TypeError if input is a boolean + config_error_func(ds.config.set_autotune_interval, True, TypeError, "interval must be of type int") + + # set_sending_batches will raise TypeError if input is a boolean + config_error_func(ds.config.set_sending_batches, True, TypeError, "batch_num must be an int dtype") + + # set_multiprocessing_timeout_interval will raise TypeError if input is a boolean + config_error_func(ds.config.set_multiprocessing_timeout_interval, True, TypeError, "interval isn't of type int") + + if __name__ == '__main__': test_basic() test_get_seed() @@ -443,3 +484,4 @@ if __name__ == '__main__': test_auto_num_workers() test_enable_watchdog() test_multiprocessing_timeout_interval() + test_config_bool_type_error()