From a95f45c4b9b3f946d4864e2af4611a75231584a7 Mon Sep 17 00:00:00 2001 From: shenwei41 Date: Tue, 15 Jun 2021 14:51:48 +0800 Subject: [PATCH] Rectification API --- mindspore/dataset/core/config.py | 60 ++++++++++++++++++-------------- 1 file changed, 34 insertions(+), 26 deletions(-) diff --git a/mindspore/dataset/core/config.py b/mindspore/dataset/core/config.py index 8adb1505c73..f5111ea7ec4 100644 --- a/mindspore/dataset/core/config.py +++ b/mindspore/dataset/core/config.py @@ -1,4 +1,4 @@ -# Copyright 2019 Huawei Technologies Co., Ltd +# Copyright 2019-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -40,7 +40,7 @@ def _init_device_info(): As rank_id need to pass into deep layer for numa and device_queue. One process work with only one rank_id, In standalone scenario, rank_id may come from env 'CUDA_VISIBLE_DEVICES', For distribute - scenario, rank_id come from _get_global_rank() + scenario, rank_id come from _get_global_rank(). """ from mindspore import context from mindspore.parallel._auto_parallel_context import auto_parallel_context @@ -75,7 +75,9 @@ def _init_device_info(): def set_seed(seed): """ - Set the seed to be used in any random generator. This is used to produce deterministic results. + If the seed is set, the generated random number will be fixed, this helps to + produce deterministic results. + Note: This set_seed function sets the seed in the Python random library and numpy.random library @@ -84,10 +86,10 @@ def set_seed(seed): does not guarantee deterministic results with num_parallel_workers > 1. Args: - seed(int): Seed to be set. + seed(int): Random number seed. It is used to generate deterministic random numbers. Raises: - ValueError: If seed is invalid (< 0 or > MAX_UINT_32). + ValueError: If seed is invalid when seed < 0 or seed > MAX_UINT_32. Examples: >>> # Set a new global configuration value for the seed value. 
@@ -104,28 +106,30 @@ def set_seed(seed): def get_seed(): """ - Get the seed. + Get random number seed. If seed has been set, then get_seed will + get the seed value that has been set, if the seed is not set, + it will return std::mt19937::default_seed. Returns: - int, seed. + int, random number seed. """ return _config.get_seed() def set_prefetch_size(size): """ - Set the number of rows to be prefetched. + Set the queue capacity of the thread in pipeline. Args: - size (int): Total number of rows to be prefetched per operator per parallel worker. + size (int): The length of the cache queue. Raises: - ValueError: If prefetch_size is invalid (<= 0 or > MAX_INT_32). + ValueError: If the queue capacity of the thread is invalid when size <= 0 or size > MAX_INT_32. Note: Since total memory used for prefetch can grow very large with high number of workers, - when number of workers is > 4, the per worker prefetch size will be reduced. The actual - prefetch size at runtime per worker will be prefetchsize * (4 / num_parallel_workers). + when number of workers is greater than 4, the per worker prefetch size will be reduced. + The actual prefetch size at runtime per worker will be prefetchsize * (4 / num_parallel_workers). Examples: >>> # Set a new global configuration value for the prefetch size. @@ -138,7 +142,7 @@ def set_prefetch_size(size): def get_prefetch_size(): """ - Get the prefetch size in number of rows. + Get the prefetch size in terms of number of rows. Returns: int, total number of rows to be prefetched. @@ -148,13 +152,14 @@ def get_prefetch_size(): def set_num_parallel_workers(num): """ - Set the default number of parallel workers. + Set a new global configuration default value for the number of parallel workers. + This setting will affect the parallelism of all dataset operations. Args: num (int): Number of parallel workers to be used as a default for each operation. Raises: - ValueError: If num_parallel_workers is invalid (<= 0 or > MAX_INT_32). 
+ ValueError: If num_parallel_workers is invalid when num <= 0 or num > MAX_INT_32. Examples: >>> # Set a new global configuration value for the number of parallel workers. @@ -169,7 +174,8 @@ def set_num_parallel_workers(num): def get_num_parallel_workers(): """ Get the default number of parallel workers. - This is the DEFAULT num_parallel_workers value used for each op, it is not related to AutoNumWorker feature. + This is the DEFAULT num_parallel_workers value used for each operation, it is not related + to the AutoNumWorker feature. Returns: int, number of parallel workers to be used as a default for each operation. @@ -216,7 +222,7 @@ def set_monitor_sampling_interval(interval): interval (int): Interval (in milliseconds) to be used for performance monitor sampling. Raises: - ValueError: If interval is invalid (<= 0 or > MAX_INT_32). + ValueError: If interval is invalid when interval <= 0 or interval > MAX_INT_32. Examples: >>> # Set a new global configuration value for the monitor sampling interval. @@ -239,13 +245,15 @@ def get_monitor_sampling_interval(): def set_auto_num_workers(enable): """ - Set num_parallel_workers for each op automatically. (This feature is turned off by default) + Set num_parallel_workers for each op automatically (This feature is turned off by default). + If turned on, the num_parallel_workers in each op will be adjusted automatically, possibly overwriting the num_parallel_workers passed in by user or the default value (if user doesn't pass anything) set by ds.config.set_num_parallel_workers(). + For now, this function is only optimized for YoloV3 dataset with per_batch_map (running map in batch). - This feature aims to provide a baseline for optimized num_workers assignment for each op. - Op whose num_parallel_workers is adjusted to a new value will be logged. + This feature aims to provide a baseline for optimized num_workers assignment for each operation. + Operations whose num_parallel_workers is adjusted to a new value will be logged. 
Args: enable (bool): Whether to enable auto num_workers feature or not. @@ -307,7 +315,7 @@ def set_callback_timeout(timeout): timeout (int): Timeout (in seconds) to be used to end the wait in DSWaitedCallback in case of a deadlock. Raises: - ValueError: If timeout is invalid (<= 0 or > MAX_INT_32). + ValueError: If timeout is invalid when timeout <= 0 or timeout > MAX_INT_32. Examples: >>> # Set a new global configuration value for the timeout value. @@ -324,7 +332,7 @@ def get_callback_timeout(): In case of a deadlock, the wait function will exit after the timeout period. Returns: - int, the duration in seconds. + int, timeout (in seconds) to be used to end the wait in DSWaitedCallback in case of a deadlock. """ return _config.get_callback_timeout() @@ -341,7 +349,7 @@ def __str__(): def load(file): """ - Load configurations from a file. + Load the project configuration from the file. Args: file (str): Path of the configuration file to be loaded. @@ -392,10 +400,10 @@ def get_enable_shared_mem(): def set_enable_shared_mem(enable): """ Set the default state of shared memory flag. If shared_mem_enable is True, will use shared memory queues - to pass data to processes that are created for operators that set multiprocessing=True. + to pass data to processes that are created for operators that set python_multiprocessing=True. Args: - enable (bool): Whether to use shared memory in operators with "multiprocessing=True" + enable (bool): Whether to use shared memory in operators when python_multiprocessing=True. Raises: TypeError: If enable is not a boolean data type. @@ -414,7 +422,7 @@ def set_sending_batches(batch_num): increase, default is 0 which means will send all batches in dataset. Raises: - TypeError: If batch_num is not a int data type. + TypeError: If batch_num is not of type int. Examples: >>> # Set a new global configuration value for the sending batches