forked from mindspore-Ecosystem/mindspore
!31646 [dataset][bugfix] 修复config 接口typeerror的问题
Merge pull request !31646 from xiefangqi/md_fix_interface_comments
This commit is contained in:
commit
4110c3d6d6
|
@ -46,7 +46,8 @@ API示例所需模块的导入代码如下:
|
|||
|
||||
**异常:**
|
||||
|
||||
- **ValueError** - `seed` 小于0或 `seed` 大于MAX_UINT_32时, `seed` 无效。
|
||||
- **TypeError** - `seed` 不是int类型。
|
||||
- **ValueError** - `seed` 小于0或 `seed` 大于 `UINT32_MAX(4294967295)` 时, `seed` 无效。
|
||||
|
||||
.. py:function:: mindspore.dataset.config.get_seed()
|
||||
|
||||
|
@ -66,7 +67,8 @@ API示例所需模块的导入代码如下:
|
|||
|
||||
**异常:**
|
||||
|
||||
- **ValueError** - 当 `size` 小于等于0或 `size` 大于 `MAX_INT_32` 时,线程的队列容量无效。
|
||||
- **TypeError** - `size` 不是int类型。
|
||||
- **ValueError** - `size` 小于等于0或 `size` 大于 `INT32_MAX(2147483647)` 时,线程的队列容量无效。
|
||||
|
||||
.. note::
|
||||
用于预取的总内存可能会随着工作线程数量的增加而快速增长,所以当工作线程数量大于4时,每个工作线程的预取大小将减少。
|
||||
|
@ -91,7 +93,8 @@ API示例所需模块的导入代码如下:
|
|||
|
||||
**异常:**
|
||||
|
||||
- **ValueError** - `num` 小于等于0或 `num` 大于MAX_INT_32时,并行工作线程数量设置无效。
|
||||
- **TypeError** - `num` 不是int类型。
|
||||
- **ValueError** - `num` 小于等于0或 `num` 大于 `INT32_MAX(2147483647)` 时,并行工作线程数量设置无效。
|
||||
|
||||
.. py:function:: mindspore.dataset.config.get_num_parallel_workers()
|
||||
|
||||
|
@ -133,7 +136,8 @@ API示例所需模块的导入代码如下:
|
|||
|
||||
**异常:**
|
||||
|
||||
- **ValueError** - `interval` 小于等于0或 `interval` 大于MAX_INT_32时, `interval` 无效。
|
||||
- **TypeError** - `interval` 不是int类型。
|
||||
- **ValueError** - `interval` 小于等于0或 `interval` 大于 `INT32_MAX(2147483647)` 时, `interval` 无效。
|
||||
|
||||
.. py:function:: mindspore.dataset.config.get_monitor_sampling_interval()
|
||||
|
||||
|
@ -154,7 +158,8 @@ API示例所需模块的导入代码如下:
|
|||
|
||||
**异常:**
|
||||
|
||||
- **ValueError** - `timeout` 小于等于0或 `timeout` 大于MAX_INT_32时 `timeout` 无效。
|
||||
- **TypeError** - `timeout` 不是int类型。
|
||||
- **ValueError** - `timeout` 小于等于0或 `timeout` 大于 `INT32_MAX(2147483647)` 时 `timeout` 无效。
|
||||
|
||||
.. py:function:: mindspore.dataset.config.get_callback_timeout()
|
||||
|
||||
|
@ -295,23 +300,24 @@ API示例所需模块的导入代码如下:
|
|||
|
||||
.. py:function:: mindspore.dataset.config.set_multiprocessing_timeout_interval(interval)
|
||||
|
||||
设置在多进程下,主进程获取数据超时时,告警日志打印的默认时间间隔(秒)。
|
||||
设置在多进程/多线程下,主进程/主线程获取数据超时时,告警日志打印的默认时间间隔(秒)。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **interval** (int) - 表示多进程下,主进程获取数据超时时,告警日志打印的时间间隔(秒)。
|
||||
- **interval** (int) - 表示多进程/多线程下,主进程/主线程获取数据超时时,告警日志打印的时间间隔(秒)。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **ValueError** - `interval` 小于等于0或 `interval` 大于MAX_INT_32时, `interval` 无效。
|
||||
- **TypeError** - `interval` 不是int类型。
|
||||
- **ValueError** - `interval` 小于等于0或 `interval` 大于 `INT32_MAX(2147483647)` 时, `interval` 无效。
|
||||
|
||||
.. py:function:: mindspore.dataset.config.get_multiprocessing_timeout_interval()
|
||||
|
||||
获取在多进程下,主进程获取数据超时时,告警日志打印的时间间隔的全局配置。
|
||||
获取在多进程/多线程下,主进程/主线程获取数据超时时,告警日志打印的时间间隔的全局配置。
|
||||
|
||||
**返回:**
|
||||
|
||||
int,表示多进程下,主进程获取数据超时时,告警日志打印的时间间隔(秒)。
|
||||
int,表示多进程/多线程下,主进程/主线程获取数据超时时,告警日志打印的时间间隔(秒)。
|
||||
|
||||
.. automodule:: mindspore.dataset.config
|
||||
:members:
|
||||
|
|
|
@ -107,17 +107,18 @@ def set_seed(seed):
|
|||
seed(int): Random number seed. It is used to generate deterministic random numbers.
|
||||
|
||||
Raises:
|
||||
ValueError: If seed is invalid when seed < 0 or seed > MAX_UINT_32.
|
||||
TypeError: If seed isn't of type int.
|
||||
ValueError: If seed < 0 or seed > UINT32_MAX(4294967295).
|
||||
|
||||
Examples:
|
||||
>>> # Set a new global configuration value for the seed value.
|
||||
>>> # Operations with randomness will use the seed value to generate random values.
|
||||
>>> ds.config.set_seed(1000)
|
||||
"""
|
||||
if not isinstance(seed, int):
|
||||
raise ValueError("seed isn't of type int.")
|
||||
if not isinstance(seed, int) or isinstance(seed, bool):
|
||||
raise TypeError("seed isn't of type int.")
|
||||
if seed < 0 or seed > UINT32_MAX:
|
||||
raise ValueError("Seed given is not within the required range.")
|
||||
raise ValueError("seed given is not within the required range [0, UINT32_MAX(4294967295)].")
|
||||
_config.set_seed(seed)
|
||||
random.seed(seed)
|
||||
# numpy.random isn't thread safe
|
||||
|
@ -149,7 +150,8 @@ def set_prefetch_size(size):
|
|||
size (int): The length of the cache queue.
|
||||
|
||||
Raises:
|
||||
ValueError: If the queue capacity of the thread is invalid when size <= 0 or size > MAX_INT_32.
|
||||
TypeError: If size is not of type int.
|
||||
ValueError: If size <= 0 or size > INT32_MAX(2147483647).
|
||||
|
||||
Note:
|
||||
Since total memory used for prefetch can grow very large with high number of workers,
|
||||
|
@ -160,10 +162,10 @@ def set_prefetch_size(size):
|
|||
>>> # Set a new global configuration value for the prefetch size.
|
||||
>>> ds.config.set_prefetch_size(1000)
|
||||
"""
|
||||
if not isinstance(size, int):
|
||||
raise ValueError("size isn't of type int.")
|
||||
if not isinstance(size, int) or isinstance(size, bool):
|
||||
raise TypeError("size isn't of type int.")
|
||||
if size <= 0 or size > INT32_MAX:
|
||||
raise ValueError("Prefetch size given is not within the required range.")
|
||||
raise ValueError("size is not within the required range (0, INT32_MAX(2147483647)].")
|
||||
_config.set_op_connector_size(size)
|
||||
|
||||
|
||||
|
@ -191,17 +193,19 @@ def set_num_parallel_workers(num):
|
|||
num (int): Number of parallel workers to be used as a default for each operation.
|
||||
|
||||
Raises:
|
||||
ValueError: If num_parallel_workers is invalid when num <= 0 or num > MAX_INT_32.
|
||||
TypeError: If num is not of type int.
|
||||
ValueError: If num <= 0 or num > INT32_MAX(2147483647).
|
||||
|
||||
Examples:
|
||||
>>> # Set a new global configuration value for the number of parallel workers.
|
||||
>>> # Now parallel dataset operators will run with 8 workers.
|
||||
>>> ds.config.set_num_parallel_workers(8)
|
||||
"""
|
||||
if not isinstance(num, int):
|
||||
raise ValueError("num isn't of type int.")
|
||||
if not isinstance(num, int) or isinstance(num, bool):
|
||||
raise TypeError("num isn't of type int.")
|
||||
if num <= 0 or num > INT32_MAX:
|
||||
raise ValueError("Number of parallel workers given is not within the required range.")
|
||||
raise ValueError("Number of parallel workers given is not within the required range"
|
||||
" (0, INT32_MAX(2147483647)].")
|
||||
_config.set_num_parallel_workers(num)
|
||||
|
||||
|
||||
|
@ -265,16 +269,17 @@ def set_monitor_sampling_interval(interval):
|
|||
interval (int): Interval (in milliseconds) to be used for performance monitor sampling.
|
||||
|
||||
Raises:
|
||||
ValueError: If interval is invalid when interval <= 0 or interval > MAX_INT_32.
|
||||
TypeError: If interval is not type int.
|
||||
ValueError: If interval <= 0 or interval > INT32_MAX(2147483647).
|
||||
|
||||
Examples:
|
||||
>>> # Set a new global configuration value for the monitor sampling interval.
|
||||
>>> ds.config.set_monitor_sampling_interval(100)
|
||||
"""
|
||||
if not isinstance(interval, int):
|
||||
raise ValueError("interval isn't of type int.")
|
||||
if not isinstance(interval, int) or isinstance(interval, bool):
|
||||
raise TypeError("interval isn't of type int.")
|
||||
if interval <= 0 or interval > INT32_MAX:
|
||||
raise ValueError("Interval given is not within the required range.")
|
||||
raise ValueError("Interval given is not within the required range (0, INT32_MAX(2147483647)].")
|
||||
_config.set_monitor_sampling_interval(interval)
|
||||
|
||||
|
||||
|
@ -334,10 +339,11 @@ def _set_auto_workers_config(option):
|
|||
Args:
|
||||
option (int): The id of the profile to use.
|
||||
Raises:
|
||||
ValueError: If option is not int or not within the range of [0, 6]
|
||||
TypeError: If option is not of type int.
|
||||
ValueError: If option is not within the range of [0, 6].
|
||||
"""
|
||||
if not isinstance(option, int):
|
||||
raise ValueError("option isn't of type int.")
|
||||
if not isinstance(option, int) or isinstance(option, bool):
|
||||
raise TypeError("option isn't of type int.")
|
||||
if option < 0 or option > 6:
|
||||
raise ValueError("option isn't within the required range of [0, 6].")
|
||||
_config.set_auto_worker_config(option)
|
||||
|
@ -366,14 +372,15 @@ def set_callback_timeout(timeout):
|
|||
timeout (int): Timeout (in seconds) to be used to end the wait in DSWaitedCallback in case of a deadlock.
|
||||
|
||||
Raises:
|
||||
ValueError: If timeout is invalid when timeout <= 0 or timeout > MAX_INT_32.
|
||||
TypeError: If timeout is not type int.
|
||||
ValueError: If timeout <= 0 or timeout > INT32_MAX(2147483647).
|
||||
|
||||
Examples:
|
||||
>>> # Set a new global configuration value for the timeout value.
|
||||
>>> ds.config.set_callback_timeout(100)
|
||||
"""
|
||||
if not isinstance(timeout, int):
|
||||
raise ValueError("timeout isn't of type int.")
|
||||
if not isinstance(timeout, int) or isinstance(timeout, bool):
|
||||
raise TypeError("timeout isn't of type int.")
|
||||
if timeout <= 0 or timeout > INT32_MAX:
|
||||
raise ValueError("Timeout given is not within the required range.")
|
||||
_config.set_callback_timeout(timeout)
|
||||
|
@ -516,10 +523,10 @@ def set_autotune_interval(interval):
|
|||
>>> # set a new interval for AutoTune
|
||||
>>> ds.config.set_autotune_interval(30)
|
||||
"""
|
||||
if not isinstance(interval, int):
|
||||
if not isinstance(interval, int) or isinstance(interval, bool):
|
||||
raise TypeError("interval must be of type int.")
|
||||
if interval < 0 or interval > INT32_MAX:
|
||||
raise ValueError("Interval given is not within the required range.")
|
||||
raise ValueError("Interval given is not within the required range [0, INT32_MAX(2147483647)].")
|
||||
_config.set_autotune_interval(interval)
|
||||
|
||||
|
||||
|
@ -604,7 +611,7 @@ def set_sending_batches(batch_num):
|
|||
>>> # Set a new global configuration value for the sending batches
|
||||
>>> ds.config.set_sending_batches(10)
|
||||
"""
|
||||
if not isinstance(batch_num, int):
|
||||
if not isinstance(batch_num, int) or isinstance(batch_num, bool):
|
||||
raise TypeError("batch_num must be an int dtype.")
|
||||
_config.set_sending_batches(batch_num)
|
||||
|
||||
|
@ -680,37 +687,40 @@ def get_enable_watchdog():
|
|||
|
||||
def set_multiprocessing_timeout_interval(interval):
|
||||
"""
|
||||
Set the default interval (in seconds) for multiprocessing timeout when main process gets data from subprocesses.
|
||||
Set the default interval (in seconds) for multiprocessing/multithreading timeout when main process/thread gets
|
||||
data from subprocesses/child threads.
|
||||
|
||||
Args:
|
||||
interval (int): Interval (in seconds) to be used for multiprocessing timeout when main process gets data from
|
||||
subprocess. System default: 300s.
|
||||
interval (int): Interval (in seconds) to be used for multiprocessing/multithreading timeout when main
|
||||
process/thread gets data from subprocess/child threads. System default: 300s.
|
||||
|
||||
Raises:
|
||||
ValueError: If interval is invalid when interval <= 0 or interval > MAX_INT_32.
|
||||
ValueError: If interval <= 0 or interval > INT32_MAX(2147483647).
|
||||
|
||||
Examples:
|
||||
>>> # Set a new global configuration value for multiprocessing timeout when getting data.
|
||||
>>> # Set a new global configuration value for multiprocessing/multithreading timeout when getting data.
|
||||
>>> ds.config.set_multiprocessing_timeout_interval(300)
|
||||
"""
|
||||
if not isinstance(interval, int):
|
||||
raise ValueError("interval isn't of type int.")
|
||||
if not isinstance(interval, int) or isinstance(interval, bool):
|
||||
raise TypeError("interval isn't of type int.")
|
||||
if interval <= 0 or interval > INT32_MAX:
|
||||
raise ValueError("Interval given is not within the required range (0, INT32_MAX).")
|
||||
raise ValueError("Interval given is not within the required range (0, INT32_MAX(2147483647)).")
|
||||
_config.set_multiprocessing_timeout_interval(interval)
|
||||
|
||||
|
||||
def get_multiprocessing_timeout_interval():
|
||||
"""
|
||||
Get the global configuration of multiprocessing timeout when main process gets data from subprocesses.
|
||||
Get the global configuration of multiprocessing/multithreading timeout when main process/thread gets data from
|
||||
subprocesses/child threads.
|
||||
|
||||
Returns:
|
||||
int, interval (in seconds) for multiprocessing timeout when main process gets data from subprocesses
|
||||
(default is 300s).
|
||||
int, interval (in seconds) for multiprocessing/multithreading timeout when main process/thread gets data from
|
||||
subprocesses/child threads (default is 300s).
|
||||
|
||||
Examples:
|
||||
>>> # Get the global configuration of multiprocessing timeout when main process gets data from subprocesses.
|
||||
>>> # If set_multiprocessing_timeout_interval() is never called before, the default value(300) will be returned.
|
||||
>>> # Get the global configuration of multiprocessing/multithreading timeout when main process/thread gets data
|
||||
>>> # from subprocesses/child threads. If set_multiprocessing_timeout_interval() is never called before, the
|
||||
>>> # default value(300) will be returned.
|
||||
>>> multiprocessing_timeout_interval = ds.config.get_multiprocessing_timeout_interval()
|
||||
"""
|
||||
return _config.get_multiprocessing_timeout_interval()
|
||||
|
|
|
@ -32,6 +32,16 @@ DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
|||
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||
|
||||
|
||||
def config_error_func(config_interface, input_args, err_type, except_err_msg):
|
||||
err_msg = ""
|
||||
try:
|
||||
config_interface(input_args)
|
||||
except err_type as e:
|
||||
err_msg = str(e)
|
||||
|
||||
assert except_err_msg in err_msg
|
||||
|
||||
|
||||
def test_basic():
|
||||
"""
|
||||
Test basic configuration functions
|
||||
|
@ -429,6 +439,37 @@ def test_multiprocessing_timeout_interval():
|
|||
assert saved_config == ds.config.get_multiprocessing_timeout_interval()
|
||||
|
||||
|
||||
def test_config_bool_type_error():
|
||||
"""
|
||||
Feature: Now many interfaces of config support bool input even its valid input is int.
|
||||
Description: We will raise a type error when input is a bool when it should be int.
|
||||
Expectation: TypeError will be raised when input is a bool.
|
||||
"""
|
||||
# set_seed will raise TypeError if input is a boolean
|
||||
config_error_func(ds.config.set_seed, True, TypeError, "seed isn't of type int")
|
||||
|
||||
# set_prefetch_size will raise TypeError if input is a boolean
|
||||
config_error_func(ds.config.set_prefetch_size, True, TypeError, "size isn't of type int")
|
||||
|
||||
# set_num_parallel_workers will raise TypeError if input is a boolean
|
||||
config_error_func(ds.config.set_num_parallel_workers, True, TypeError, "num isn't of type int")
|
||||
|
||||
# set_monitor_sampling_interval will raise TypeError if input is a boolean
|
||||
config_error_func(ds.config.set_monitor_sampling_interval, True, TypeError, "interval isn't of type int")
|
||||
|
||||
# set_callback_timeout will raise TypeError if input is a boolean
|
||||
config_error_func(ds.config.set_callback_timeout, True, TypeError, "timeout isn't of type int")
|
||||
|
||||
# set_autotune_interval will raise TypeError if input is a boolean
|
||||
config_error_func(ds.config.set_autotune_interval, True, TypeError, "interval must be of type int")
|
||||
|
||||
# set_sending_batches will raise TypeError if input is a boolean
|
||||
config_error_func(ds.config.set_sending_batches, True, TypeError, "batch_num must be an int dtype")
|
||||
|
||||
# set_multiprocessing_timeout_interval will raise TypeError if input is a boolean
|
||||
config_error_func(ds.config.set_multiprocessing_timeout_interval, True, TypeError, "interval isn't of type int")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_basic()
|
||||
test_get_seed()
|
||||
|
@ -443,3 +484,4 @@ if __name__ == '__main__':
|
|||
test_auto_num_workers()
|
||||
test_enable_watchdog()
|
||||
test_multiprocessing_timeout_interval()
|
||||
test_config_bool_type_error()
|
||||
|
|
Loading…
Reference in New Issue