!31646 [dataset][bugfix] Fix the TypeError issue in the config interfaces

Merge pull request !31646 from xiefangqi/md_fix_interface_comments
i-robot 2022-03-23 02:45:01 +00:00 committed by Gitee
commit 4110c3d6d6
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 106 additions and 48 deletions


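@ -46,7 +46,8 @@ The import code for the modules required by the API examples is as follows
**Raises:** **Raises:**
- **ValueError** - `seed` is invalid when `seed` is less than 0 or greater than MAX_UINT_32. - **TypeError** - `seed` is not of type int.
- **ValueError** - `seed` is invalid when `seed` is less than 0 or greater than `UINT32_MAX(4294967295)`.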
.. py:function:: mindspore.dataset.config.get_seed() .. py:function:: mindspore.dataset.config.get_seed()
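@ -66,7 +67,8 @@ The import code for the modules required by the API examples is as follows
**Raises:** **Raises:**
- **ValueError** - The queue capacity of the thread is invalid when `size` is less than or equal to 0 or greater than `MAX_INT_32`. - **TypeError** - `size` is not of type int.
- **ValueError** - The queue capacity of the thread is invalid when `size` is less than or equal to 0 or greater than `INT32_MAX(2147483647)`.
.. note:: .. note::
The total memory used for prefetching can grow quickly as the number of worker threads increases, so when there are more than 4 worker threads, the prefetch size of each worker thread is reduced. The total memory used for prefetching can grow quickly as the number of worker threads increases, so when there are more than 4 worker threads, the prefetch size of each worker thread is reduced.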
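@ -91,7 +93,8 @@ The import code for the modules required by the API examples is as follows
**Raises:** **Raises:**
- **ValueError** - The number of parallel workers is invalid when `num` is less than or equal to 0 or greater than MAX_INT_32. - **TypeError** - `num` is not of type int.
- **ValueError** - The number of parallel workers is invalid when `num` is less than or equal to 0 or greater than `INT32_MAX(2147483647)`.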
.. py:function:: mindspore.dataset.config.get_num_parallel_workers() .. py:function:: mindspore.dataset.config.get_num_parallel_workers()
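@ -133,7 +136,8 @@ The import code for the modules required by the API examples is as follows
**Raises:** **Raises:**
- **ValueError** - `interval` is invalid when `interval` is less than or equal to 0 or greater than MAX_INT_32. - **TypeError** - `interval` is not of type int.
- **ValueError** - `interval` is invalid when `interval` is less than or equal to 0 or greater than `INT32_MAX(2147483647)`.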
.. py:function:: mindspore.dataset.config.get_monitor_sampling_interval() .. py:function:: mindspore.dataset.config.get_monitor_sampling_interval()
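@ -154,7 +158,8 @@ The import code for the modules required by the API examples is as follows
**Raises:** **Raises:**
- **ValueError** - `timeout` is invalid when `timeout` is less than or equal to 0 or greater than MAX_INT_32. - **TypeError** - `timeout` is not of type int.
- **ValueError** - `timeout` is invalid when `timeout` is less than or equal to 0 or greater than `INT32_MAX(2147483647)`.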
.. py:function:: mindspore.dataset.config.get_callback_timeout() .. py:function:: mindspore.dataset.config.get_callback_timeout()
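@ -295,23 +300,24 @@ The import code for the modules required by the API examples is as follows
.. py:function:: mindspore.dataset.config.set_multiprocessing_timeout_interval(interval) .. py:function:: mindspore.dataset.config.set_multiprocessing_timeout_interval(interval)
Set the default interval (in seconds) at which warning logs are printed when the main process times out while getting data in multiprocessing mode. Set the default interval (in seconds) at which warning logs are printed when the main process/main thread times out while getting data in multiprocessing/multithreading mode.
**Parameters:** **Parameters:**
- **interval** (int) - The interval (in seconds) at which warning logs are printed when the main process times out while getting data in multiprocessing mode. - **interval** (int) - The interval (in seconds) at which warning logs are printed when the main process/main thread times out while getting data in multiprocessing/multithreading mode.
**Raises:** **Raises:**
- **ValueError** - `interval` is invalid when `interval` is less than or equal to 0 or greater than MAX_INT_32. - **TypeError** - `interval` is not of type int.
- **ValueError** - `interval` is invalid when `interval` is less than or equal to 0 or greater than `INT32_MAX(2147483647)`.
.. py:function:: mindspore.dataset.config.get_multiprocessing_timeout_interval() .. py:function:: mindspore.dataset.config.get_multiprocessing_timeout_interval()
Get the global configuration of the interval at which warning logs are printed when the main process times out while getting data in multiprocessing mode. Get the global configuration of the interval at which warning logs are printed when the main process/main thread times out while getting data in multiprocessing/multithreading mode.
**Returns:** **Returns:**
int, the interval at which warning logs are printed when the main process times out while getting data in multiprocessing mode. int, the interval (in seconds) at which warning logs are printed when the main process/main thread times out while getting data in multiprocessing/multithreading mode.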
.. automodule:: mindspore.dataset.config .. automodule:: mindspore.dataset.config
:members: :members:
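
The numeric bounds quoted in the updated docs follow from the widths of 32-bit integers. A minimal check, assuming only the standard definitions of those limits (the constants below are defined locally for illustration and are not imported from MindSpore):

```python
# Local stand-ins for the limits named in the docs; not imported from MindSpore.
UINT32_MAX = 2**32 - 1  # largest unsigned 32-bit value
INT32_MAX = 2**31 - 1   # largest signed 32-bit value

print(UINT32_MAX)  # 4294967295
print(INT32_MAX)   # 2147483647
```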


@ -107,17 +107,18 @@ def set_seed(seed):
seed(int): Random number seed. It is used to generate deterministic random numbers. seed(int): Random number seed. It is used to generate deterministic random numbers.
Raises: Raises:
ValueError: If seed is invalid when seed < 0 or seed > MAX_UINT_32. TypeError: If seed isn't of type int.
ValueError: If seed < 0 or seed > UINT32_MAX(4294967295).
Examples: Examples:
>>> # Set a new global configuration value for the seed value. >>> # Set a new global configuration value for the seed value.
>>> # Operations with randomness will use the seed value to generate random values. >>> # Operations with randomness will use the seed value to generate random values.
>>> ds.config.set_seed(1000) >>> ds.config.set_seed(1000)
""" """
if not isinstance(seed, int): if not isinstance(seed, int) or isinstance(seed, bool):
raise ValueError("seed isn't of type int.") raise TypeError("seed isn't of type int.")
if seed < 0 or seed > UINT32_MAX: if seed < 0 or seed > UINT32_MAX:
raise ValueError("Seed given is not within the required range.") raise ValueError("seed given is not within the required range [0, UINT32_MAX(4294967295)].")
_config.set_seed(seed) _config.set_seed(seed)
random.seed(seed) random.seed(seed)
# numpy.random isn't thread safe # numpy.random isn't thread safe
@ -149,7 +150,8 @@ def set_prefetch_size(size):
size (int): The length of the cache queue. size (int): The length of the cache queue.
Raises: Raises:
ValueError: If the queue capacity of the thread is invalid when size <= 0 or size > MAX_INT_32. TypeError: If size is not of type int.
ValueError: If size <= 0 or size > INT32_MAX(2147483647).
Note: Note:
Since total memory used for prefetch can grow very large with high number of workers, Since total memory used for prefetch can grow very large with high number of workers,
@ -160,10 +162,10 @@ def set_prefetch_size(size):
>>> # Set a new global configuration value for the prefetch size. >>> # Set a new global configuration value for the prefetch size.
>>> ds.config.set_prefetch_size(1000) >>> ds.config.set_prefetch_size(1000)
""" """
if not isinstance(size, int): if not isinstance(size, int) or isinstance(size, bool):
raise ValueError("size isn't of type int.") raise TypeError("size isn't of type int.")
if size <= 0 or size > INT32_MAX: if size <= 0 or size > INT32_MAX:
raise ValueError("Prefetch size given is not within the required range.") raise ValueError("size is not within the required range (0, INT32_MAX(2147483647)].")
_config.set_op_connector_size(size) _config.set_op_connector_size(size)
@ -191,17 +193,19 @@ def set_num_parallel_workers(num):
num (int): Number of parallel workers to be used as a default for each operation. num (int): Number of parallel workers to be used as a default for each operation.
Raises: Raises:
ValueError: If num_parallel_workers is invalid when num <= 0 or num > MAX_INT_32. TypeError: If num is not of type int.
ValueError: If num <= 0 or num > INT32_MAX(2147483647).
Examples: Examples:
>>> # Set a new global configuration value for the number of parallel workers. >>> # Set a new global configuration value for the number of parallel workers.
>>> # Now parallel dataset operators will run with 8 workers. >>> # Now parallel dataset operators will run with 8 workers.
>>> ds.config.set_num_parallel_workers(8) >>> ds.config.set_num_parallel_workers(8)
""" """
if not isinstance(num, int): if not isinstance(num, int) or isinstance(num, bool):
raise ValueError("num isn't of type int.") raise TypeError("num isn't of type int.")
if num <= 0 or num > INT32_MAX: if num <= 0 or num > INT32_MAX:
raise ValueError("Number of parallel workers given is not within the required range.") raise ValueError("Number of parallel workers given is not within the required range"
" (0, INT32_MAX(2147483647)].")
_config.set_num_parallel_workers(num) _config.set_num_parallel_workers(num)
@ -265,16 +269,17 @@ def set_monitor_sampling_interval(interval):
interval (int): Interval (in milliseconds) to be used for performance monitor sampling. interval (int): Interval (in milliseconds) to be used for performance monitor sampling.
Raises: Raises:
ValueError: If interval is invalid when interval <= 0 or interval > MAX_INT_32. TypeError: If interval is not of type int.
ValueError: If interval <= 0 or interval > INT32_MAX(2147483647).
Examples: Examples:
>>> # Set a new global configuration value for the monitor sampling interval. >>> # Set a new global configuration value for the monitor sampling interval.
>>> ds.config.set_monitor_sampling_interval(100) >>> ds.config.set_monitor_sampling_interval(100)
""" """
if not isinstance(interval, int): if not isinstance(interval, int) or isinstance(interval, bool):
raise ValueError("interval isn't of type int.") raise TypeError("interval isn't of type int.")
if interval <= 0 or interval > INT32_MAX: if interval <= 0 or interval > INT32_MAX:
raise ValueError("Interval given is not within the required range.") raise ValueError("Interval given is not within the required range (0, INT32_MAX(2147483647)].")
_config.set_monitor_sampling_interval(interval) _config.set_monitor_sampling_interval(interval)
@ -334,10 +339,11 @@ def _set_auto_workers_config(option):
Args: Args:
option (int): The id of the profile to use. option (int): The id of the profile to use.
Raises: Raises:
ValueError: If option is not int or not within the range of [0, 6] TypeError: If option is not of type int.
ValueError: If option is not within the range of [0, 6].
""" """
if not isinstance(option, int): if not isinstance(option, int) or isinstance(option, bool):
raise ValueError("option isn't of type int.") raise TypeError("option isn't of type int.")
if option < 0 or option > 6: if option < 0 or option > 6:
raise ValueError("option isn't within the required range of [0, 6].") raise ValueError("option isn't within the required range of [0, 6].")
_config.set_auto_worker_config(option) _config.set_auto_worker_config(option)
@ -366,14 +372,15 @@ def set_callback_timeout(timeout):
timeout (int): Timeout (in seconds) to be used to end the wait in DSWaitedCallback in case of a deadlock. timeout (int): Timeout (in seconds) to be used to end the wait in DSWaitedCallback in case of a deadlock.
Raises: Raises:
ValueError: If timeout is invalid when timeout <= 0 or timeout > MAX_INT_32. TypeError: If timeout is not of type int.
ValueError: If timeout <= 0 or timeout > INT32_MAX(2147483647).
Examples: Examples:
>>> # Set a new global configuration value for the timeout value. >>> # Set a new global configuration value for the timeout value.
>>> ds.config.set_callback_timeout(100) >>> ds.config.set_callback_timeout(100)
""" """
if not isinstance(timeout, int): if not isinstance(timeout, int) or isinstance(timeout, bool):
raise ValueError("timeout isn't of type int.") raise TypeError("timeout isn't of type int.")
if timeout <= 0 or timeout > INT32_MAX: if timeout <= 0 or timeout > INT32_MAX:
raise ValueError("Timeout given is not within the required range.") raise ValueError("Timeout given is not within the required range.")
_config.set_callback_timeout(timeout) _config.set_callback_timeout(timeout)
@ -516,10 +523,10 @@ def set_autotune_interval(interval):
>>> # set a new interval for AutoTune >>> # set a new interval for AutoTune
>>> ds.config.set_autotune_interval(30) >>> ds.config.set_autotune_interval(30)
""" """
if not isinstance(interval, int): if not isinstance(interval, int) or isinstance(interval, bool):
raise TypeError("interval must be of type int.") raise TypeError("interval must be of type int.")
if interval < 0 or interval > INT32_MAX: if interval < 0 or interval > INT32_MAX:
raise ValueError("Interval given is not within the required range.") raise ValueError("Interval given is not within the required range [0, INT32_MAX(2147483647)].")
_config.set_autotune_interval(interval) _config.set_autotune_interval(interval)
@ -604,7 +611,7 @@ def set_sending_batches(batch_num):
>>> # Set a new global configuration value for the sending batches >>> # Set a new global configuration value for the sending batches
>>> ds.config.set_sending_batches(10) >>> ds.config.set_sending_batches(10)
""" """
if not isinstance(batch_num, int): if not isinstance(batch_num, int) or isinstance(batch_num, bool):
raise TypeError("batch_num must be an int dtype.") raise TypeError("batch_num must be an int dtype.")
_config.set_sending_batches(batch_num) _config.set_sending_batches(batch_num)
@ -680,37 +687,40 @@ def get_enable_watchdog():
def set_multiprocessing_timeout_interval(interval): def set_multiprocessing_timeout_interval(interval):
""" """
Set the default interval (in seconds) for multiprocessing timeout when main process gets data from subprocesses. Set the default interval (in seconds) for multiprocessing/multithreading timeout when main process/thread gets
data from subprocesses/child threads.
Args: Args:
interval (int): Interval (in seconds) to be used for multiprocessing timeout when main process gets data from interval (int): Interval (in seconds) to be used for multiprocessing/multithreading timeout when main
subprocess. System default: 300s. process/thread gets data from subprocess/child threads. System default: 300s.
Raises: Raises:
ValueError: If interval is invalid when interval <= 0 or interval > MAX_INT_32. ValueError: If interval <= 0 or interval > INT32_MAX(2147483647).
Examples: Examples:
>>> # Set a new global configuration value for multiprocessing timeout when getting data. >>> # Set a new global configuration value for multiprocessing/multithreading timeout when getting data.
>>> ds.config.set_multiprocessing_timeout_interval(300) >>> ds.config.set_multiprocessing_timeout_interval(300)
""" """
if not isinstance(interval, int): if not isinstance(interval, int) or isinstance(interval, bool):
raise ValueError("interval isn't of type int.") raise TypeError("interval isn't of type int.")
if interval <= 0 or interval > INT32_MAX: if interval <= 0 or interval > INT32_MAX:
raise ValueError("Interval given is not within the required range (0, INT32_MAX).") raise ValueError("Interval given is not within the required range (0, INT32_MAX(2147483647)].")
_config.set_multiprocessing_timeout_interval(interval) _config.set_multiprocessing_timeout_interval(interval)
def get_multiprocessing_timeout_interval(): def get_multiprocessing_timeout_interval():
""" """
Get the global configuration of multiprocessing timeout when main process gets data from subprocesses. Get the global configuration of multiprocessing/multithreading timeout when main process/thread gets data from
subprocesses/child threads.
Returns: Returns:
int, interval (in seconds) for multiprocessing timeout when main process gets data from subprocesses int, interval (in seconds) for multiprocessing/multithreading timeout when main process/thread gets data from
(default is 300s). subprocesses/child threads (default is 300s).
Examples: Examples:
>>> # Get the global configuration of multiprocessing timeout when main process gets data from subprocesses. >>> # Get the global configuration of multiprocessing/multithreading timeout when main process/thread gets data
>>> # If set_multiprocessing_timeout_interval() is never called before, the default value(300) will be returned. >>> # from subprocesses/child threads. If set_multiprocessing_timeout_interval() is never called before, the
>>> # default value(300) will be returned.
>>> multiprocessing_timeout_interval = ds.config.get_multiprocessing_timeout_interval() >>> multiprocessing_timeout_interval = ds.config.get_multiprocessing_timeout_interval()
""" """
return _config.get_multiprocessing_timeout_interval() return _config.get_multiprocessing_timeout_interval()
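
The extra `isinstance(..., bool)` test in these setters is needed because `bool` is a subclass of `int` in Python, so `isinstance(True, int)` alone lets `True` and `False` through. The sketch below illustrates the pattern with a hypothetical helper, `check_int_in_range`, which is not part of MindSpore and only mirrors the shape of the checks in this diff:

```python
def check_int_in_range(value, name, low, high):
    """Hypothetical helper mirroring the type and range checks in the diff."""
    # bool is a subclass of int, so isinstance(True, int) is True;
    # the explicit bool test is what rejects True/False.
    if not isinstance(value, int) or isinstance(value, bool):
        raise TypeError("{} isn't of type int.".format(name))
    if value < low or value > high:
        raise ValueError(
            "{} is not within the required range [{}, {}].".format(name, low, high))
    return value


print(issubclass(bool, int))   # True
print(isinstance(True, int))   # True
print(check_int_in_range(8, "num", 1, 2**31 - 1))  # 8
```

With this pattern, passing `True` raises `TypeError` before the range check runs, which matches the order of the checks in the commit.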


@ -32,6 +32,16 @@ DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
def config_error_func(config_interface, input_args, err_type, except_err_msg):
    err_msg = ""
    try:
        config_interface(input_args)
    except err_type as e:
        err_msg = str(e)

    assert except_err_msg in err_msg
def test_basic(): def test_basic():
""" """
Test basic configuration functions Test basic configuration functions
@ -429,6 +439,37 @@ def test_multiprocessing_timeout_interval():
assert saved_config == ds.config.get_multiprocessing_timeout_interval() assert saved_config == ds.config.get_multiprocessing_timeout_interval()
def test_config_bool_type_error():
    """
    Feature: Many config interfaces accept bool input even though the only valid input type is int.
    Description: Raise a TypeError when the input is a bool where an int is required.
    Expectation: TypeError is raised when the input is a bool.
    """
    # set_seed will raise TypeError if input is a boolean
    config_error_func(ds.config.set_seed, True, TypeError, "seed isn't of type int")
    # set_prefetch_size will raise TypeError if input is a boolean
    config_error_func(ds.config.set_prefetch_size, True, TypeError, "size isn't of type int")
    # set_num_parallel_workers will raise TypeError if input is a boolean
    config_error_func(ds.config.set_num_parallel_workers, True, TypeError, "num isn't of type int")
    # set_monitor_sampling_interval will raise TypeError if input is a boolean
    config_error_func(ds.config.set_monitor_sampling_interval, True, TypeError, "interval isn't of type int")
    # set_callback_timeout will raise TypeError if input is a boolean
    config_error_func(ds.config.set_callback_timeout, True, TypeError, "timeout isn't of type int")
    # set_autotune_interval will raise TypeError if input is a boolean
    config_error_func(ds.config.set_autotune_interval, True, TypeError, "interval must be of type int")
    # set_sending_batches will raise TypeError if input is a boolean
    config_error_func(ds.config.set_sending_batches, True, TypeError, "batch_num must be an int dtype")
    # set_multiprocessing_timeout_interval will raise TypeError if input is a boolean
    config_error_func(ds.config.set_multiprocessing_timeout_interval, True, TypeError, "interval isn't of type int")
if __name__ == '__main__': if __name__ == '__main__':
test_basic() test_basic()
test_get_seed() test_get_seed()
@ -443,3 +484,4 @@ if __name__ == '__main__':
test_auto_num_workers() test_auto_num_workers()
test_enable_watchdog() test_enable_watchdog()
test_multiprocessing_timeout_interval() test_multiprocessing_timeout_interval()
test_config_bool_type_error()
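
To see how the `config_error_func` helper behaves without MindSpore installed, the sketch below runs the same logic against a stand-in setter. `fake_set_seed` is hypothetical and only imitates the type check of `ds.config.set_seed`. Because `err_msg` starts as an empty string, the final assertion also fails when no exception is raised at all.

```python
def config_error_func(config_interface, input_args, err_type, except_err_msg):
    # Same logic as the helper added in this commit.
    err_msg = ""
    try:
        config_interface(input_args)
    except err_type as e:
        err_msg = str(e)
    assert except_err_msg in err_msg


def fake_set_seed(seed):
    # Hypothetical stand-in for ds.config.set_seed; not MindSpore code.
    if not isinstance(seed, int) or isinstance(seed, bool):
        raise TypeError("seed isn't of type int.")


# Passes: the stand-in raises TypeError with the expected message.
config_error_func(fake_set_seed, True, TypeError, "seed isn't of type int")

# A valid int raises nothing, so err_msg stays "" and the helper's assert fails.
try:
    config_error_func(fake_set_seed, 1, TypeError, "seed isn't of type int")
    helper_failed = False
except AssertionError:
    helper_failed = True

print(helper_failed)  # True
```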