forked from mindspore-Ecosystem/mindspore
Rectification API
This commit is contained in:
parent
e7ea93dacd
commit
a95f45c4b9
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
# Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -40,7 +40,7 @@ def _init_device_info():
|
|||
As rank_id need to pass into deep layer for numa and device_queue.
|
||||
One process work with only one rank_id, In standalone scenario,
|
||||
rank_id may come from env 'CUDA_VISIBLE_DEVICES', For distribute
|
||||
scenario, rank_id come from _get_global_rank()
|
||||
scenario, rank_id come from _get_global_rank().
|
||||
"""
|
||||
from mindspore import context
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
|
@ -75,7 +75,9 @@ def _init_device_info():
|
|||
|
||||
def set_seed(seed):
|
||||
"""
|
||||
Set the seed to be used in any random generator. This is used to produce deterministic results.
|
||||
If the seed is set, the generated random number will be fixed, this helps to
|
||||
produce deterministic results.
|
||||
|
||||
|
||||
Note:
|
||||
This set_seed function sets the seed in the Python random library and numpy.random library
|
||||
|
@ -84,10 +86,10 @@ def set_seed(seed):
|
|||
does not guarantee deterministic results with num_parallel_workers > 1.
|
||||
|
||||
Args:
|
||||
seed(int): Seed to be set.
|
||||
seed(int): Random number seed. It is used to generate deterministic random numbers.
|
||||
|
||||
Raises:
|
||||
ValueError: If seed is invalid (< 0 or > MAX_UINT_32).
|
||||
ValueError: If seed is invalid when seed < 0 or seed > MAX_UINT_32.
|
||||
|
||||
Examples:
|
||||
>>> # Set a new global configuration value for the seed value.
|
||||
|
@ -104,28 +106,30 @@ def set_seed(seed):
|
|||
|
||||
def get_seed():
|
||||
"""
|
||||
Get the seed.
|
||||
Get random number seed. If seed has been set, then get_seed will
|
||||
get the seed value that has been set; if the seed is not set,
|
||||
it will return std::mt19937::default_seed.
|
||||
|
||||
Returns:
|
||||
int, seed.
|
||||
int, random number seed.
|
||||
"""
|
||||
return _config.get_seed()
|
||||
|
||||
|
||||
def set_prefetch_size(size):
|
||||
"""
|
||||
Set the number of rows to be prefetched.
|
||||
Set the queue capacity of the thread in pipeline.
|
||||
|
||||
Args:
|
||||
size (int): Total number of rows to be prefetched per operator per parallel worker.
|
||||
size (int): The length of the cache queue.
|
||||
|
||||
Raises:
|
||||
ValueError: If prefetch_size is invalid (<= 0 or > MAX_INT_32).
|
||||
ValueError: If the queue capacity of the thread is invalid when size <= 0 or size > MAX_INT_32.
|
||||
|
||||
Note:
|
||||
Since total memory used for prefetch can grow very large with high number of workers,
|
||||
when number of workers is > 4, the per worker prefetch size will be reduced. The actual
|
||||
prefetch size at runtime per worker will be prefetchsize * (4 / num_parallel_workers).
|
||||
when number of workers is greater than 4, the per worker prefetch size will be reduced.
|
||||
The actual prefetch size at runtime per worker will be prefetchsize * (4 / num_parallel_workers).
|
||||
|
||||
Examples:
|
||||
>>> # Set a new global configuration value for the prefetch size.
|
||||
|
@ -138,7 +142,7 @@ def set_prefetch_size(size):
|
|||
|
||||
def get_prefetch_size():
|
||||
"""
|
||||
Get the prefetch size in number of rows.
|
||||
Get the prefetch size as a number of rows.
|
||||
|
||||
Returns:
|
||||
int, total number of rows to be prefetched.
|
||||
|
@ -148,13 +152,14 @@ def get_prefetch_size():
|
|||
|
||||
def set_num_parallel_workers(num):
|
||||
"""
|
||||
Set the default number of parallel workers.
|
||||
Set a new global configuration default value for the number of parallel workers.
|
||||
This setting will affect the parallelism of all dataset operations.
|
||||
|
||||
Args:
|
||||
num (int): Number of parallel workers to be used as a default for each operation.
|
||||
|
||||
Raises:
|
||||
ValueError: If num_parallel_workers is invalid (<= 0 or > MAX_INT_32).
|
||||
ValueError: If num_parallel_workers is invalid when num <= 0 or num > MAX_INT_32.
|
||||
|
||||
Examples:
|
||||
>>> # Set a new global configuration value for the number of parallel workers.
|
||||
|
@ -169,7 +174,8 @@ def set_num_parallel_workers(num):
|
|||
def get_num_parallel_workers():
|
||||
"""
|
||||
Get the default number of parallel workers.
|
||||
This is the DEFAULT num_parallel_workers value used for each op, it is not related to AutoNumWorker feature.
|
||||
This is the DEFAULT num_parallel_workers value used for each operation, it is not related
|
||||
to AutoNumWorker feature.
|
||||
|
||||
Returns:
|
||||
int, number of parallel workers to be used as a default for each operation.
|
||||
|
@ -216,7 +222,7 @@ def set_monitor_sampling_interval(interval):
|
|||
interval (int): Interval (in milliseconds) to be used for performance monitor sampling.
|
||||
|
||||
Raises:
|
||||
ValueError: If interval is invalid (<= 0 or > MAX_INT_32).
|
||||
ValueError: If interval is invalid when interval <= 0 or interval > MAX_INT_32.
|
||||
|
||||
Examples:
|
||||
>>> # Set a new global configuration value for the monitor sampling interval.
|
||||
|
@ -239,13 +245,15 @@ def get_monitor_sampling_interval():
|
|||
|
||||
def set_auto_num_workers(enable):
|
||||
"""
|
||||
Set num_parallel_workers for each op automatically. (This feature is turned off by default)
|
||||
Set num_parallel_workers for each op automatically (This feature is turned off by default).
|
||||
|
||||
If turned on, the num_parallel_workers in each op will be adjusted automatically, possibly overwriting the
|
||||
num_parallel_workers passed in by user or the default value (if user doesn't pass anything) set by
|
||||
ds.config.set_num_parallel_workers().
|
||||
|
||||
For now, this function is only optimized for YoloV3 dataset with per_batch_map (running map in batch).
|
||||
This feature aims to provide a baseline for optimized num_workers assignment for each op.
|
||||
Op whose num_parallel_workers is adjusted to a new value will be logged.
|
||||
This feature aims to provide a baseline for optimized num_workers assignment for each operation.
|
||||
Operation whose num_parallel_workers is adjusted to a new value will be logged.
|
||||
|
||||
Args:
|
||||
enable (bool): Whether to enable auto num_workers feature or not.
|
||||
|
@ -307,7 +315,7 @@ def set_callback_timeout(timeout):
|
|||
timeout (int): Timeout (in seconds) to be used to end the wait in DSWaitedCallback in case of a deadlock.
|
||||
|
||||
Raises:
|
||||
ValueError: If timeout is invalid (<= 0 or > MAX_INT_32).
|
||||
ValueError: If timeout is invalid when timeout <= 0 or timeout > MAX_INT_32.
|
||||
|
||||
Examples:
|
||||
>>> # Set a new global configuration value for the timeout value.
|
||||
|
@ -324,7 +332,7 @@ def get_callback_timeout():
|
|||
In case of a deadlock, the wait function will exit after the timeout period.
|
||||
|
||||
Returns:
|
||||
int, the duration in seconds.
|
||||
int, Timeout (in seconds) to be used to end the wait in DSWaitedCallback in case of a deadlock.
|
||||
"""
|
||||
return _config.get_callback_timeout()
|
||||
|
||||
|
@ -341,7 +349,7 @@ def __str__():
|
|||
|
||||
def load(file):
|
||||
"""
|
||||
Load configurations from a file.
|
||||
Load the project configuration from the file.
|
||||
|
||||
Args:
|
||||
file (str): Path of the configuration file to be loaded.
|
||||
|
@ -392,10 +400,10 @@ def get_enable_shared_mem():
|
|||
def set_enable_shared_mem(enable):
|
||||
"""
|
||||
Set the default state of shared memory flag. If shared_mem_enable is True, will use shared memory queues
|
||||
to pass data to processes that are created for operators that set multiprocessing=True.
|
||||
to pass data to processes that are created for operators that set python_multiprocessing=True.
|
||||
|
||||
Args:
|
||||
enable (bool): Whether to use shared memory in operators with "multiprocessing=True"
|
||||
enable (bool): Whether to use shared memory in operators when python_multiprocessing=True.
|
||||
|
||||
Raises:
|
||||
TypeError: If enable is not a boolean data type.
|
||||
|
@ -414,7 +422,7 @@ def set_sending_batches(batch_num):
|
|||
increase, default is 0 which means will send all batches in dataset.
|
||||
|
||||
Raises:
|
||||
TypeError: If batch_num is not a int data type.
|
||||
TypeError: If batch_num is not of type int.
|
||||
|
||||
Examples:
|
||||
>>> # Set a new global configuration value for the sending batches
|
||||
|
|
Loading…
Reference in New Issue