Rectification API

This commit is contained in:
shenwei41 2021-06-15 14:51:48 +08:00
parent e7ea93dacd
commit a95f45c4b9
1 changed files with 34 additions and 26 deletions

View File

@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -40,7 +40,7 @@ def _init_device_info():
As rank_id need to pass into deep layer for numa and device_queue.
One process work with only one rank_id, In standalone scenario,
rank_id may come from env 'CUDA_VISIBLE_DEVICES', For distribute
scenario, rank_id come from _get_global_rank()
scenario, rank_id come from _get_global_rank().
"""
from mindspore import context
from mindspore.parallel._auto_parallel_context import auto_parallel_context
@ -75,7 +75,9 @@ def _init_device_info():
def set_seed(seed):
"""
Set the seed to be used in any random generator. This is used to produce deterministic results.
If the seed is set, the generated random numbers will be fixed; this helps to
produce deterministic results.
Note:
This set_seed function sets the seed in the Python random library and numpy.random library
@ -84,10 +86,10 @@ def set_seed(seed):
does not guarantee deterministic results with num_parallel_workers > 1.
Args:
seed(int): Seed to be set.
seed(int): Random number seed. It is used to generate deterministic random numbers.
Raises:
ValueError: If seed is invalid (< 0 or > MAX_UINT_32).
ValueError: If seed is invalid (seed < 0 or seed > MAX_UINT_32).
Examples:
>>> # Set a new global configuration value for the seed value.
@ -104,28 +106,30 @@ def set_seed(seed):
def get_seed():
"""
Get the seed.
Get the random number seed. If the seed has been set, then get_seed will
return the seed value that has been set; otherwise,
it will return std::mt19937::default_seed.
Returns:
int, seed.
int, random number seed.
"""
return _config.get_seed()
def set_prefetch_size(size):
"""
Set the number of rows to be prefetched.
Set the queue capacity of the thread in the pipeline.
Args:
size (int): Total number of rows to be prefetched per operator per parallel worker.
size (int): The length of the cache queue.
Raises:
ValueError: If prefetch_size is invalid (<= 0 or > MAX_INT_32).
ValueError: If the queue capacity of the thread is invalid (size <= 0 or size > MAX_INT_32).
Note:
Since total memory used for prefetch can grow very large with high number of workers,
when number of workers is > 4, the per worker prefetch size will be reduced. The actual
prefetch size at runtime per worker will be prefetchsize * (4 / num_parallel_workers).
when the number of workers is greater than 4, the per worker prefetch size will be reduced.
The actual prefetch size at runtime per worker will be prefetch_size * (4 / num_parallel_workers).
Examples:
>>> # Set a new global configuration value for the prefetch size.
@ -138,7 +142,7 @@ def set_prefetch_size(size):
def get_prefetch_size():
"""
Get the prefetch size in number of rows.
Get the prefetch size in number of rows.
Returns:
int, total number of rows to be prefetched.
@ -148,13 +152,14 @@ def get_prefetch_size():
def set_num_parallel_workers(num):
"""
Set the default number of parallel workers.
Set a new global configuration default value for the number of parallel workers.
This setting will affect the parallelism of all dataset operations.
Args:
num (int): Number of parallel workers to be used as a default for each operation.
Raises:
ValueError: If num_parallel_workers is invalid (<= 0 or > MAX_INT_32).
ValueError: If num_parallel_workers is invalid (num <= 0 or num > MAX_INT_32).
Examples:
>>> # Set a new global configuration value for the number of parallel workers.
@ -169,7 +174,8 @@ def set_num_parallel_workers(num):
def get_num_parallel_workers():
"""
Get the default number of parallel workers.
This is the DEFAULT num_parallel_workers value used for each op, it is not related to AutoNumWorker feature.
This is the DEFAULT num_parallel_workers value used for each operation, it is not related
to AutoNumWorker feature.
Returns:
int, number of parallel workers to be used as a default for each operation.
@ -216,7 +222,7 @@ def set_monitor_sampling_interval(interval):
interval (int): Interval (in milliseconds) to be used for performance monitor sampling.
Raises:
ValueError: If interval is invalid (<= 0 or > MAX_INT_32).
ValueError: If interval is invalid (interval <= 0 or interval > MAX_INT_32).
Examples:
>>> # Set a new global configuration value for the monitor sampling interval.
@ -239,13 +245,15 @@ def get_monitor_sampling_interval():
def set_auto_num_workers(enable):
"""
Set num_parallel_workers for each op automatically. (This feature is turned off by default)
Set num_parallel_workers for each op automatically (this feature is turned off by default).
If turned on, the num_parallel_workers in each op will be adjusted automatically, possibly overwriting the
num_parallel_workers passed in by user or the default value (if user doesn't pass anything) set by
ds.config.set_num_parallel_workers().
For now, this function is only optimized for YoloV3 dataset with per_batch_map (running map in batch).
This feature aims to provide a baseline for optimized num_workers assignment for each op.
Op whose num_parallel_workers is adjusted to a new value will be logged.
This feature aims to provide a baseline for optimized num_workers assignment for each operation.
Operation whose num_parallel_workers is adjusted to a new value will be logged.
Args:
enable (bool): Whether to enable auto num_workers feature or not.
@ -307,7 +315,7 @@ def set_callback_timeout(timeout):
timeout (int): Timeout (in seconds) to be used to end the wait in DSWaitedCallback in case of a deadlock.
Raises:
ValueError: If timeout is invalid (<= 0 or > MAX_INT_32).
ValueError: If timeout is invalid (timeout <= 0 or timeout > MAX_INT_32).
Examples:
>>> # Set a new global configuration value for the timeout value.
@ -324,7 +332,7 @@ def get_callback_timeout():
In case of a deadlock, the wait function will exit after the timeout period.
Returns:
int, the duration in seconds.
int, timeout (in seconds) to be used to end the wait in DSWaitedCallback in case of a deadlock.
"""
return _config.get_callback_timeout()
@ -341,7 +349,7 @@ def __str__():
def load(file):
"""
Load configurations from a file.
Load the project configuration from the file.
Args:
file (str): Path of the configuration file to be loaded.
@ -392,10 +400,10 @@ def get_enable_shared_mem():
def set_enable_shared_mem(enable):
"""
Set the default state of shared memory flag. If shared_mem_enable is True, will use shared memory queues
to pass data to processes that are created for operators that set multiprocessing=True.
to pass data to processes that are created for operators that set python_multiprocessing=True.
Args:
enable (bool): Whether to use shared memory in operators with "multiprocessing=True"
enable (bool): Whether to use shared memory in operators when python_multiprocessing=True.
Raises:
TypeError: If enable is not a boolean data type.
@ -414,7 +422,7 @@ def set_sending_batches(batch_num):
increase, default is 0 which means will send all batches in dataset.
Raises:
TypeError: If batch_num is not a int data type.
TypeError: If batch_num is not of type int.
Examples:
>>> # Set a new global configuration value for the sending batches