!6185 fix api comments
Merge pull request !6185 from gziyan/fix_api_comments
Commit: 5a76bd717d
@@ -32,33 +32,32 @@ DEFAULT_BACKEND = Backend("hccl")


 def _get_group(group):
-    """Get the global world group if the group is default world comm group."""
+    """Return the world communication group if the `group` is `DEFAULT_WORLD_COMM_GROUP`."""
     if group == DEFAULT_WORLD_COMM_GROUP:
         return GlobalComm.WORLD_COMM_GROUP
     return group


 class GlobalComm:
-    """Global communication info."""
+    """World communication information."""
     BACKEND = DEFAULT_BACKEND
     WORLD_COMM_GROUP = DEFAULT_WORLD_COMM_GROUP


 def init(backend_name=None):
     """
-    Init distributed backend, e.g., hccl/nccl, it is required before communication service can be used.
+    Initialize distributed backend, e.g. HCCL/NCCL, it is required before using the communication service.

     Note:
-        The full name of hccl is Huawei Collective Communication Library.
-        The full name of nccl is NVIDIA Collective Communication Library.
+        The full name of HCCL is Huawei Collective Communication Library.
+        The full name of NCCL is NVIDIA Collective Communication Library.

     Args:
         backend_name (str): Backend.

     Raises:
-        TypeError: If backen_name is not a string.
-        RuntimeError: If device target is invalid.
-        RuntimeError: If backend is invalid or distributed init fails.
+        TypeError: If `backend_name` is not a string.
+        RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
     """
     if _is_role_pserver() or _is_role_sched():
         return
@@ -88,17 +87,17 @@ def init(backend_name=None):

 def release():
     """
-    Release distributed resource. e.g., hccl/nccl.
+    Release distributed resource. e.g. HCCL/NCCL.

     Raises:
-        RuntimeError: If distributed resource release fails.
+        RuntimeError: If failed to release distributed resource.
     """
     finalize_hccl()


 def get_rank(group=GlobalComm.WORLD_COMM_GROUP):
     """
-    Gets rank ID for current device in specified collective communication group.
+    Get the rank ID for the current device in the specified collective communication group.

     Args:
         group (str): ProcessGroup, the process group to work on. Default: WORLD_COMM_GROUP.
@@ -109,7 +108,7 @@ def get_rank(group=GlobalComm.WORLD_COMM_GROUP):
     Raises:
         TypeError: If group is not a string.
         ValueError: If backend is invalid.
-        RuntimeError: If hccl/nccl is not available or nccl not supports.
+        RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
     """
     return _get_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND)

@@ -130,14 +129,14 @@ def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP):
     Raises:
         TypeError: If group is not a string.
         ValueError: If backend is invalid.
-        RuntimeError: If hccl/nccl is not available or nccl not supports.
+        RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
     """
     return _get_local_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND)


 def get_group_size(group=GlobalComm.WORLD_COMM_GROUP):
     """
-    Gets rank size of the specified collective communication group.
+    Get the rank size of the specified collective communication group.

     Args:
         group (str): ProcessGroup, the process group to work on.
@@ -148,7 +147,7 @@ def get_group_size(group=GlobalComm.WORLD_COMM_GROUP):
     Raises:
         TypeError: If group is not a string.
        ValueError: If backend is invalid.
-        RuntimeError: If hccl/nccl is not available or nccl not supports.
+        RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
     """
     return _get_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND)

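For orientation, a minimal sketch of how the functions touched above are typically combined in a distributed entry script. The Ascend/HCCL setup shown here is an assumption for illustration, not part of this commit; on GPU the backend string would be "nccl".

# Minimal sketch (assumes an Ascend environment with HCCL already configured).
from mindspore import context
from mindspore.communication.management import init, get_rank, get_group_size

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
init("hccl")                     # must run before any other communication call
rank_id = get_rank()             # rank ID of this device in the world group
rank_size = get_group_size()     # number of devices in the world group
print("rank {} of {}".format(rank_id, rank_size))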
@@ -164,22 +163,23 @@ def get_local_rank_size(group=GlobalComm.WORLD_COMM_GROUP):
         group (str): ProcessGroup, the process group to work on.

     Returns:
-        int, the local rank size where the calling process is being within the group.
+        int, the local rank size where the calling process is within the group.

     Raises:
         TypeError: If group is not a string.
         ValueError: If backend is invalid.
-        RuntimeError: If hccl/nccl is not available or nccl not supports.
+        RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
     """
     return _get_local_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND)


 def get_world_rank_from_group_rank(group, group_rank_id):
     """
-    Gets the rank ID in world communication group corresponding to the rank ID in specified user communication group.
+    Gets the rank ID in the world communication group corresponding to
+    the rank ID in the specified user communication group.

     Note:
-        Nccl is not supported.
+        NCCL is not supported.
         The parameter group should not be "hccl_world_group".

     Args:
@@ -190,52 +190,53 @@ def get_world_rank_from_group_rank(group, group_rank_id):
         int, the rank ID in world communication group.

     Raises:
-        TypeError: If group_rank_id is not a int or group is not a string.
+        TypeError: If `group_rank_id` is not an integer or the group is not a string.
         ValueError: If group is 'hccl_world_group' or backend is invalid.
-        RuntimeError: If hccl/nccl is not available or nccl not supports.
+        RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
     """
     return _get_world_rank_from_group_rank_helper(group=group, group_rank_id=group_rank_id, backend=GlobalComm.BACKEND)


 def get_group_rank_from_world_rank(world_rank_id, group):
     """
-    Gets the rank ID in specified user communication group corresponding to the rank ID in world communication group.
+    Get the rank ID in the specified user communication group corresponding to
+    the rank ID in the world communication group.

     Note:
-        Nccl is not supported.
+        NCCL is not supported.
         The parameter group should not be "hccl_world_group".

     Args:
-        world_rank_id (int): A rank ID in world communication group.
+        world_rank_id (int): A rank ID in the world communication group.
         group (str): The user communication group.

     Returns:
-        int, the rank ID in user communication group.
+        int, the rank ID in the user communication group.

     Raises:
-        TypeError: If world_rank_id is not a int or group is not a string.
+        TypeError: If world_rank_id is not an integer or the group is not a string.
         ValueError: If group is 'hccl_world_group' or backend is invalid.
-        RuntimeError: If hccl/nccl is not available or nccl not supports.
+        RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
     """
     return _get_group_rank_from_world_rank_helper(world_rank_id=world_rank_id, group=group, backend=GlobalComm.BACKEND)


 def create_group(group, rank_ids):
     """
-    Creates user collective communication group.
+    Create a user collective communication group.

     Note:
-        Nccl is not supported.
+        NCCL is not supported.
         The size of rank_ids should be larger than 1.
         Rank_ids should not have duplicate data.

     Args:
         group (str): ProcessGroup, the process group to create.
-        rank_ids (list): List of device ID.
+        rank_ids (list): A list of device IDs.

     Raises:
-        TypeError: If group is not a string or rank_ids is not a list.
-        ValueError: If rank_ids size is not larger than 1 or rank_ids has duplicate data or backend is invalid.
+        TypeError: If group is not a string or `rank_ids` is not a list.
+        ValueError: If `rank_ids` size is not larger than 1, or `rank_ids` has duplicate data, or backend is invalid.
         RuntimeError: If hccl/nccl is not available or nccl not supports.
     Examples:
         >>> group = "0-1"
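A short sketch of the create_group flow documented above, reusing the group name from the Examples block; HCCL only, since the Note says NCCL is not supported. The rank list is illustrative.

# Sketch: build a two-device sub-group and map a group rank back to a world rank.
from mindspore.communication.management import (create_group, get_group_size,
                                                get_world_rank_from_group_rank)

group = "0-1"
rank_ids = [0, 1]                    # must contain more than one unique device ID
create_group(group, rank_ids)
group_size = get_group_size(group)   # 2 for this group
world_rank = get_world_rank_from_group_rank(group, group_rank_id=1)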
@@ -247,7 +248,7 @@ def create_group(group, rank_ids):

 def destroy_group(group):
     """
-    Destroys user collective communication group.
+    Destroy the user collective communication group.

     Note:
         Nccl is not supported.
@@ -259,6 +260,6 @@ def destroy_group(group):
     Raises:
         TypeError: If group is not a string.
         ValueError: If group is "hccl_world_group" or backend is invalid.
-        RuntimeError: If hccl/nccl is not available or nccl not supports.
+        RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
     """
     _destroy_group_helper(group, backend=GlobalComm.BACKEND)
@@ -336,6 +336,8 @@ def set_auto_parallel_context(**kwargs):
     """
     Set auto parallel context.

+    Auto parallel context should be configured before the initialization of your network.
+
     Note:
         Attribute name is required for setting attributes.
         If a program has tasks with different parallel modes, then before setting new parallel mode for the
@@ -344,12 +346,25 @@ def set_auto_parallel_context(**kwargs):
     Setting or changing parallel modes must be called before any creating Initializer, otherwise,
     RuntimeError may be raised when compiling the network.

+    Some configurations are parallel mode specific, see the below table for details:
+
+    =========================== =========================== =================
+    Common                      AUTO_PARALLEL               DATA_PRALLEL
+    =========================== =========================== =================
+    device_num                  gradient_fp32_sync          enable_parallel_optimizer
+    global_rank                 loss_repeated_mean
+    gradients_mean              auto_parallel_search_mode
+    parallel_mode               strategy_ckpt_load_file
+    all_reduce_fusion_config    strategy_ckpt_save_file
+                                full_batch
+    =========================== =========================== =================
+
     Args:
         device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
         global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
-        gradients_mean (bool): Whether to perform mean operator after all-reduce of mirror.
-                               "stand_alone" does not support `gradients_mean`. Default: False.
-        gradient_fp32_sync (bool): Gradients allreduce by fp32, even though gradients is fp16 if this flag is True..
+        gradients_mean (bool): Whether to perform mean operator after allreduce of gradients.
+                               "stand_alone" do not support gradients_mean. Default: False.
+        gradient_fp32_sync (bool): Run allreduce of gradients in fp32.
                                    "stand_alone", "data_parallel" and "hybrid_parallel" do not support
                                    gradient_fp32_sync. Default: True.
         parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
@@ -364,8 +379,8 @@ def set_auto_parallel_context(**kwargs):
                      - semi_auto_parallel: Achieves data parallelism and model parallelism by
                        setting parallel strategies.

-                     - auto_parallel: Achieves parallelism automatically.
-        auto_parallel_search_mode (str): There are two kinds of search modes, "recursive_programming"
+                     - auto_parallel: Achieving parallelism automatically.
+        auto_parallel_search_mode (str): There are two kinds of shard strategy search modes, "recursive_programming"
                                          and "dynamic_programming". Default: "dynamic_programming".

                      - recursive_programming: Recursive programming search mode.
@@ -376,9 +391,11 @@ def set_auto_parallel_context(**kwargs):
                              broadcast. Default: False.
         strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
         strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
-        full_batch (bool): Whether to load the whole batch on each device. Default: False.
-        enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation in
-                                          data parallel training in the benefit of time and memory saving.
+        full_batch (bool): If you load whole batch datasets in auto_parallel mode, this parameter
+                           should be set with True. Default: False.
+        enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation for
+                                          data parallel training in the benefit of time and memory saving. For now,
+                                          `Lamb` and `AdamWeightDecay` are supported in data parallel mode.
         all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only support ReduceOp.SUM
                                          and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP.

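A minimal sketch of a data-parallel configuration built only from the arguments documented above; the device number and fusion indices are illustrative values, not taken from this commit.

# Sketch: configure data-parallel training before building the network.
from mindspore import context

context.set_auto_parallel_context(parallel_mode="data_parallel",
                                  device_num=8,
                                  gradients_mean=True,
                                  all_reduce_fusion_config=[85, 160])
# AUTO_PARALLEL-only options such as auto_parallel_search_mode or full_batch
# would be set the same way when parallel_mode is "auto_parallel".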
@@ -479,7 +496,7 @@ def set_context(**kwargs):
     Some configurations are device specific, see the bellow table for details:

     =========================== =========================== =================
-    Common(CPU/GPU/Asecend)     Ascend                      GPU
+    Common(CPU/GPU/Ascend)      Ascend                      GPU
     =========================== =========================== =================
     check_bprop                 enable_auto_mixed_precision max_device_memory
     device_id                   enable_dump
@@ -33,7 +33,7 @@ __all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor

 def add_flags(fn=None, **flags):
     """
-    An decorator to add flag for a function.
+    A decorator that adds a flag to the function.

     Note:
         Only supports bool value.
@@ -43,7 +43,7 @@ def add_flags(fn=None, **flags):
         flags (dict): Flags use kwargs. Default: None.

     Returns:
-        Function, the fn added flags.
+        Function, the function with added flags.

     Examples:
         >>> add_flags(net, predit=True)
@@ -63,9 +63,9 @@ def add_flags(fn=None, **flags):

 def core(fn=None, **flags):
     """
-    A decorator to add flag to a function.
+    A decorator that adds a flag to the function.

-    By default, the function is marked core=True using this decorator to
+    By default, the function is marked as True, enabling to use this decorator to
     set flag to a graph.

     Args:
@@ -91,11 +91,12 @@ def core(fn=None, **flags):

 class GradOperation(GradOperation_):
     """
-    An higher-order function which is used to generate the gradient function for the input function.
+    A higher-order function which is used to generate the gradient function for the input function.

-    The gradient function generated by `GradOperation` higher-order function can be customized by construction args.
+    The gradient function generated by `GradOperation` higher-order function can be customized by
+    construction arguments.

-    Given an input function `net = Net()` that take `x` and `y` as inputs, and has a parameter `z`,
+    Given an input function `net = Net()` that takes `x` and `y` as inputs, and has a parameter `z`,
     see `Net` in Examples.

     To generate a gradient function that returns gradients with respect to the first input
@@ -126,7 +127,7 @@ class GradOperation(GradOperation_):
     1. Construct a `GradOperation` higher-order function with `get_by_list=True`:
        `grad_op = GradOperation(get_by_list=True)`.

-    2. Construct a `ParameterTuple` that will be passed along input function when constructing
+    2. Construct a `ParameterTuple` that will be passed to the input function when constructing
        `GradOperation` higher-order function, it will be used as a parameter filter that determine
        which gradient to return: `params = ParameterTuple(net.trainable_params())`.

@@ -151,20 +152,20 @@ class GradOperation(GradOperation_):
     4. Call the gradient function with input function's inputs
        to get the gradients with respect to all inputs and given parameters: `gradient_function(x, y)`.

-    We can configure the sensitiviy(gradient with respect to output) by setting `sens_param=True` and
-    passing in an extra sensitiviy input to the gradient function, the sensitiviy input should be
-    with same shape and type with input function's output(see `GradNetWrtXYWithSensParam` in Examples).
+    We can configure the sensitivity(gradient with respect to output) by setting `sens_param` as True and
+    passing an extra sensitivity input to the gradient function, the sensitivity input should has the
+    same shape and type with input function's output(see `GradNetWrtXYWithSensParam` in Examples).

     1. Construct a `GradOperation` higher-order function with `get_all=True` and `sens_param=True`:
        `grad_op = GradOperation(get_all=True, sens_param=True)`.

-    2. Define grad_wrt_output as sens_param which works as the gradient with respect to output:
+    2. Define `grad_wrt_output` as `sens_param` which works as the gradient with respect to output:
       `grad_wrt_output = Tensor(np.ones([2, 2]).astype(np.float32))`.

     3. Call it with input function as argument to get the gradient function:
        `gradient_function = grad_op(net)`.

-    4. Call the gradient function with input function's inputs and sens_param to
+    4. Call the gradient function with input function's inputs and `sens_param` to
        get the gradients with respect to all inputs:
        `gradient_function(x, y, grad_wrt_output)`.

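A sketch that follows the numbered steps above for the `sens_param` case. The small MatMul cell stands in for the `Net` referenced in the docstring Examples; shapes and values are illustrative.

# Sketch: gradients w.r.t. all inputs, with an explicit sensitivity input.
import numpy as np
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import Tensor
from mindspore.ops.composite import GradOperation

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.matmul = P.MatMul()
    def construct(self, x, y):
        return self.matmul(x, y)

net = Net()
grad_op = GradOperation(get_all=True, sens_param=True)
gradient_function = grad_op(net)

x = Tensor(np.ones([2, 2]).astype(np.float32))
y = Tensor(np.ones([2, 2]).astype(np.float32))
grad_wrt_output = Tensor(np.ones([2, 2]).astype(np.float32))
grads = gradient_function(x, y, grad_wrt_output)   # (grad_x, grad_y)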
@@ -175,8 +176,9 @@ class GradOperation(GradOperation_):
                            If get_all and get_by_list are both True, get the gradients with respect to inputs and Parameter variables
                            at the same time in the form of ((gradients with respect to inputs),
                            (gradients with respect to parameters)). Default: False.
-        sens_param (bool): Whether append sensitivity(gradient with respect to output) as input. If sens_param is False,
-                           a 'ones_like(outputs)' sensitivity will be attached automatically. Default: False.
+        sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input.
+                           If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically.
+                           Default: False.

     Returns:
         The higher-order function which takes a function as argument and returns gradient function for it.
@@ -349,9 +351,9 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
     """
     Generate overloaded functions.

-    MultitypeFuncGraph is a class used to generate overloaded functions with different type as inputs.
+    MultitypeFuncGraph is a class used to generate overloaded functions, considering different types as inputs.
     Initialize an `MultitypeFuncGraph` object with name, and use `register` with input types as the decorator
-    for the function to be registed. And the object can be called with different type of inputs,
+    for the function to be registed. And the object can be called with different types of inputs,
     and work with `HyperMap` and `Map`.

     Args:
@@ -360,7 +362,7 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
             and all inputs will pass by value, set `read_value` to True. Default: False.

     Raises:
-        ValueError: Cannot find matching functions for the given args.
+        ValueError: If failed to find find a matching function for the given arguments.

     Examples:
         >>> # `add` is a metagraph object which will add two objects according to
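A sketch of the register pattern the docstring describes, modeled on the `add` example it mentions; the concrete type strings and the TensorAdd-based overload are illustrative assumptions.

# Sketch: one overloaded `add` that dispatches on input types.
from mindspore.ops.composite import MultitypeFuncGraph
from mindspore.ops import operations as P

add = MultitypeFuncGraph('add')
tensor_add = P.TensorAdd()

@add.register("Number", "Number")
def add_scalars(x, y):
    # scalar + scalar branch
    return x + y

@add.register("Tensor", "Tensor")
def add_tensors(x, y):
    # tensor + tensor branch
    return tensor_add(x, y)
# `add` can now be used with HyperMap/Map or called with either input pairing.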
@@ -431,7 +433,7 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):

 class HyperMap(HyperMap_):
     """
-    Hypermap will apply the set operation on input sequences.
+    Hypermap will apply the set operation to input sequences.

     Apply the operations to every elements of the sequence or nested sequence. Different
     from `Map`, the `HyperMap` supports to apply on nested structure.
@@ -441,11 +443,10 @@ class HyperMap(HyperMap_):
     the operations should be put in the first input of the instance.

     Inputs:
-        - **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be the same length sequences,
-          and each row of the sequences. e.g. If args length is 2, and for `i` in length of each sequence
-          `(args[0][i], args[1][i])` will be the input of the operation.
+        - **args** (Tuple[sequence]) - If `ops` is `None`, all the inputs should be sequences with the same length.
+          And each row of the sequences will be the inputs of the operation.

-          If `ops` is not `None`, the first input is the operation, and the other is inputs.
+          If `ops` is not `None`, the first input is the operation, and the others are inputs.

     Outputs:
         Sequence or nested sequence, the sequence of output after applying the function.
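A sketch of the usual HyperMap pattern implied by the Inputs description above (operation first, sequences after), wrapped in a Cell as MindSpore's own composites do; the specific operation and data are illustrative.

# Sketch: apply one operation to every element of a tuple of tensors.
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore.ops import functional as F

class SquareAll(nn.Cell):
    """Squares every element of a (possibly nested) tuple of tensors."""
    def __init__(self):
        super(SquareAll, self).__init__()
        self.hyper_map = C.HyperMap()
    def construct(self, inputs):
        # first input is the operation, the rest are the sequences it maps over
        return self.hyper_map(F.square, inputs)

net = SquareAll()
out = net((Tensor(np.array([1.0, 2.0], np.float32)),
           Tensor(np.array([3.0, 4.0], np.float32))))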
@@ -48,14 +48,15 @@ def normal(shape, mean, stddev, seed=0):
     Args:
         shape (tuple): The shape of random tensor to be generated.
         mean (Tensor): The mean μ distribution parameter, which specifies the location of the peak.
-            With float32 data type.
+            with float32 data type.
         stddev (Tensor): The deviation σ distribution parameter. It should be greater than 0.
-            With float32 data type.
-        seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
-            Must be non-negative. Default: 0.
+            with float32 data type.
+        seed (int): Seed is used as entropy source for the Random number engines to generate pseudo-random numbers.
+            must be non-negative. Default: 0.

     Returns:
-        Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of mean and stddev.
+        Tensor. The shape should be equal to the broadcasted shape between the input `shape` and shapes
+        of `mean` and `stddev`.
         The dtype is float32.

     Examples:
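A minimal sketch of calling `normal` with the argument types documented above; the shape and distribution parameters are illustrative values.

# Sketch: draw a float32 normal sample of shape (4, 16).
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C

shape = (4, 16)
mean = Tensor(1.0, mstype.float32)
stddev = Tensor(1.0, mstype.float32)
output = C.normal(shape, mean, stddev, seed=5)
# output.shape == (4, 16): the broadcast of `shape` with the shapes of mean/stddev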
@@ -123,20 +124,21 @@ def uniform(shape, minval, maxval, seed=0, dtype=mstype.float32):

     Args:
         shape (tuple): The shape of random tensor to be generated.
-        minval (Tensor): The a distribution parameter.
-            It defines the minimum possibly generated value. With int32 or float32 data type.
+        minval (Tensor): The distribution parameter `a`.
+            It defines the minimum possible generated value, with int32 or float32 data type.
             If dtype is int32, only one number is allowed.
-        maxval (Tensor): The b distribution parameter.
-            It defines the maximum possibly generated value. With int32 or float32 data type.
+        maxval (Tensor): The distribution parameter `b`.
+            It defines the maximum possible generated value, with int32 or float32 data type.
             If dtype is int32, only one number is allowed.
-        seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
-            Must be non-negative. Default: 0.
+        seed (int): Seed is used as entropy source for the random number engines to generate pseudo-random numbers,
+            must be non-negative. Default: 0.
         dtype (mindspore.dtype): type of the Uniform distribution. If it is int32, it generates numbers from discrete
             uniform distribution; if it is float32, it generates numbers from continuous uniform distribution. It only
             supports these two data types. Default: mstype.float32.

     Returns:
-        Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of minval and maxval.
+        Tensor. The shape should be equal to the broadcasted shape between the input `shape` and shapes
+        of `minval` and `maxval`.
         The dtype is designated as the input `dtype`.

     Examples:
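A sketch covering the two dtype branches described above (discrete int32 and continuous float32); bounds and seeds are illustrative.

# Sketch: discrete and continuous uniform sampling.
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C

shape = (2, 3)
# Discrete uniform in [1, 5): minval/maxval each hold a single int32 number.
ints = C.uniform(shape, Tensor(1, mstype.int32), Tensor(5, mstype.int32),
                 seed=10, dtype=mstype.int32)
# Continuous uniform in [0.0, 1.0).
floats = C.uniform(shape, Tensor(0.0, mstype.float32), Tensor(1.0, mstype.float32),
                   seed=10)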
@@ -175,13 +177,14 @@ def gamma(shape, alpha, beta, seed=0):

     Args:
         shape (tuple): The shape of random tensor to be generated.
-        alpha (Tensor): The alpha α distribution parameter. It should be greater than 0. With float32 data type.
-        beta (Tensor): The beta β distribution parameter. It should be greater than 0. With float32 data type.
-        seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
-            Must be non-negative. Default: 0.
+        alpha (Tensor): The alpha α distribution parameter. It should be greater than 0 with float32 data type.
+        beta (Tensor): The beta β distribution parameter. It should be greater than 0 with float32 data type.
+        seed (int): Seed is used as entropy source for the random number engines to generate
+            pseudo-random numbers, must be non-negative. Default: 0.

     Returns:
-        Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of alpha and beta.
+        Tensor. The shape should be equal to the broadcasted shape between the input "shape" and shapes
+        of `alpha` and `beta`.
         The dtype is float32.

     Examples:
@@ -203,12 +206,12 @@ def poisson(shape, mean, seed=0):

     Args:
         shape (tuple): The shape of random tensor to be generated.
-        mean (Tensor): The mean μ distribution parameter. It should be greater than 0. With float32 data type.
-        seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
-            Must be non-negative. Default: 0.
+        mean (Tensor): The mean μ distribution parameter. It should be greater than 0 with float32 data type.
+        seed (int): Seed is used as entropy source for the random number engines to generate pseudo-random numbers
+            and must be non-negative. Default: 0.

     Returns:
-        Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of mean.
+        Tensor. The shape should be equal to the broadcasted shape between the input "shape" and shapes of `mean`.
         The dtype is float32.

     Examples:
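A combined sketch for `gamma` and `poisson` using the argument types documented above; parameter values are illustrative.

# Sketch: gamma and Poisson sampling with float32 parameters.
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C

shape = (3, 2)
alpha = Tensor(1.0, mstype.float32)    # must be > 0
beta = Tensor(1.0, mstype.float32)     # must be > 0
gamma_samples = C.gamma(shape, alpha, beta, seed=5)

mean = Tensor(5.0, mstype.float32)     # must be > 0
poisson_samples = C.poisson(shape, mean, seed=5)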
@@ -226,21 +229,23 @@ def poisson(shape, mean, seed=0):
 def multinomial(inputs, num_sample, replacement=True, seed=0):
     r"""
     Returns a tensor sampled from the multinomial probability distribution located in the corresponding
-    row of tensor input.
+    row of the input tensor.

     Note:
         The rows of input do not need to sum to one (in which case we use the values as weights),
         but must be non-negative, finite and have a non-zero sum.

     Args:
-        inputs (Tensor): the input tensor containing probabilities, must be 1 or 2 dims. With float32 data type.
-        num_sample (int): number of samples to draw.
-        replacement (bool, optional): whether to draw with replacement or not, default True.
-        seed (int, optional): used as entropy source for Random number engines generating pseudo-random numbers.
-            Must be non-negative. Default: 0.
+        inputs (Tensor): The input tensor containing probabilities, must be 1 or 2 dimensions, with
+            float32 data type.
+        num_sample (int): Number of samples to draw.
+        replacement (bool, optional): Whether to draw with replacement or not, default True.
+        seed (int, optional): Seed is used as entropy source for the random number engines to generate
+            pseudo-random numbers,
+            must be non-negative. Default: 0.

     Outputs:
-        Tensor. have the same rows with input, each row has num_samples sampled indices.
+        Tensor, has the same rows with input. The number of sampled indices of each row is `num_samples`.
         The dtype is float32.

     Examples:
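A sketch of the 1-D case described above; the probability row is illustrative and, per the Note, need not sum to one.

# Sketch: draw 4 indices from a single probability row.
import numpy as np
from mindspore import Tensor
from mindspore.ops import composite as C

weights = Tensor(np.array([0.1, 0.6, 0.3], np.float32))
samples = C.multinomial(weights, num_sample=4, replacement=True, seed=2)
# samples holds 4 indices drawn from {0, 1, 2} with the given relative weights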
@@ -197,6 +197,9 @@ class _AutoParallelContext:
             parameter_broadcast (bool): Parameter broadcast or not.
         """
         self.check_context_handle()
+        if parameter_broadcast is True and context.get_context("enable_ge") is False:
+            raise RuntimeError("Parameter broadcast is a developing feature. For now we suggest to"
+                               " use mindspore.common.set_seed() to share parameters among devices.")
         self._context_handle.set_parameter_broadcast(parameter_broadcast)

     def get_parameter_broadcast(self):
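The guard added above points callers at `mindspore.common.set_seed()`; a minimal sketch of that replacement, with an illustrative seed value.

# Sketch: seed every process identically instead of relying on parameter_broadcast,
# so randomly initialized parameters start out the same across devices.
from mindspore.common import set_seed

set_seed(1)   # call in every process before the network is built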
@@ -58,7 +58,7 @@ if __name__ == '__main__':
         cfg.group_size = get_group_size()
         parallel_mode = ParallelMode.DATA_PARALLEL
         context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size,
-                                          parameter_broadcast=True, gradients_mean=True)
+                                          gradients_mean=True)
     else:
         cfg.rank = 0
         cfg.group_size = 1
@@ -61,7 +61,7 @@ if __name__ == '__main__':
         cfg.group_size = get_group_size()
         parallel_mode = ParallelMode.DATA_PARALLEL
         context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size,
-                                          parameter_broadcast=True, gradients_mean=True)
+                                          gradients_mean=True)
     else:
         cfg.rank = 0
         cfg.group_size = 1
@@ -136,8 +136,7 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl):
     os.environ['RANK_SIZE'] = str(device_num)
     if enable_hccl:
         context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          gradients_mean=True, parameter_broadcast=True,
-                                          all_reduce_fusion_config=[107, 160])
+                                          gradients_mean=True, all_reduce_fusion_config=[107, 160])
         init()

     # network
@@ -239,8 +238,7 @@ def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl):
     os.environ['RANK_SIZE'] = str(device_num)
     if enable_hccl:
         context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          gradients_mean=True, parameter_broadcast=True,
-                                          all_reduce_fusion_config=[107])
+                                          gradients_mean=True, all_reduce_fusion_config=[107])
         init()

     # network