fix_api_problems

This commit is contained in:
Ziyan 2020-09-09 16:52:56 +08:00
parent 19bdba56f9
commit 8ea177e614
8 changed files with 125 additions and 100 deletions

View File

@ -32,33 +32,32 @@ DEFAULT_BACKEND = Backend("hccl")
def _get_group(group): def _get_group(group):
"""Get the global world group if the group is default world comm group.""" """Return the world communication group if the `group` is `DEFAULT_WORLD_COMM_GROUP`."""
if group == DEFAULT_WORLD_COMM_GROUP: if group == DEFAULT_WORLD_COMM_GROUP:
return GlobalComm.WORLD_COMM_GROUP return GlobalComm.WORLD_COMM_GROUP
return group return group
class GlobalComm: class GlobalComm:
"""Global communication info.""" """World communication information."""
BACKEND = DEFAULT_BACKEND BACKEND = DEFAULT_BACKEND
WORLD_COMM_GROUP = DEFAULT_WORLD_COMM_GROUP WORLD_COMM_GROUP = DEFAULT_WORLD_COMM_GROUP
def init(backend_name=None): def init(backend_name=None):
""" """
Init distributed backend, e.g., hccl/nccl, it is required before communication service can be used. Initialize distributed backend, e.g. HCCL/NCCL, it is required before using the communication service.
Note: Note:
The full name of hccl is Huawei Collective Communication Library. The full name of HCCL is Huawei Collective Communication Library.
The full name of nccl is NVIDIA Collective Communication Library. The full name of NCCL is NVIDIA Collective Communication Library.
Args: Args:
backend_name (str): Backend. backend_name (str): Backend.
Raises: Raises:
TypeError: If backen_name is not a string. TypeError: If `backend_name` is not a string.
RuntimeError: If device target is invalid. RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
RuntimeError: If backend is invalid or distributed init fails.
""" """
if _is_role_pserver() or _is_role_sched(): if _is_role_pserver() or _is_role_sched():
return return
@ -88,17 +87,17 @@ def init(backend_name=None):
def release(): def release():
""" """
Release distributed resource. e.g., hccl/nccl. Release distributed resource. e.g. HCCL/NCCL.
Raises: Raises:
RuntimeError: If distributed resource release fails. RuntimeError: If failed to release distributed resource.
""" """
finalize_hccl() finalize_hccl()
def get_rank(group=GlobalComm.WORLD_COMM_GROUP): def get_rank(group=GlobalComm.WORLD_COMM_GROUP):
""" """
Gets rank ID for current device in specified collective communication group. Get the rank ID for the current device in the specified collective communication group.
Args: Args:
group (str): ProcessGroup, the process group to work on. Default: WORLD_COMM_GROUP. group (str): ProcessGroup, the process group to work on. Default: WORLD_COMM_GROUP.
@ -109,7 +108,7 @@ def get_rank(group=GlobalComm.WORLD_COMM_GROUP):
Raises: Raises:
TypeError: If group is not a string. TypeError: If group is not a string.
ValueError: If backend is invalid. ValueError: If backend is invalid.
RuntimeError: If hccl/nccl is not available or nccl not supports. RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
""" """
return _get_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND) return _get_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
@ -130,14 +129,14 @@ def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP):
Raises: Raises:
TypeError: If group is not a string. TypeError: If group is not a string.
ValueError: If backend is invalid. ValueError: If backend is invalid.
RuntimeError: If hccl/nccl is not available or nccl not supports. RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
""" """
return _get_local_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND) return _get_local_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
def get_group_size(group=GlobalComm.WORLD_COMM_GROUP): def get_group_size(group=GlobalComm.WORLD_COMM_GROUP):
""" """
Gets rank size of the specified collective communication group. Get the rank size of the specified collective communication group.
Args: Args:
group (str): ProcessGroup, the process group to work on. group (str): ProcessGroup, the process group to work on.
@ -148,7 +147,7 @@ def get_group_size(group=GlobalComm.WORLD_COMM_GROUP):
Raises: Raises:
TypeError: If group is not a string. TypeError: If group is not a string.
ValueError: If backend is invalid. ValueError: If backend is invalid.
RuntimeError: If hccl/nccl is not available or nccl not supports. RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
""" """
return _get_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND) return _get_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
@ -164,22 +163,23 @@ def get_local_rank_size(group=GlobalComm.WORLD_COMM_GROUP):
group (str): ProcessGroup, the process group to work on. group (str): ProcessGroup, the process group to work on.
Returns: Returns:
int, the local rank size where the calling process is being within the group. int, the local rank size where the calling process is within the group.
Raises: Raises:
TypeError: If group is not a string. TypeError: If group is not a string.
ValueError: If backend is invalid. ValueError: If backend is invalid.
RuntimeError: If hccl/nccl is not available or nccl not supports. RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
""" """
return _get_local_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND) return _get_local_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
def get_world_rank_from_group_rank(group, group_rank_id): def get_world_rank_from_group_rank(group, group_rank_id):
""" """
Gets the rank ID in world communication group corresponding to the rank ID in specified user communication group. Gets the rank ID in the world communication group corresponding to
the rank ID in the specified user communication group.
Note: Note:
Nccl is not supported. NCCL is not supported.
The parameter group should not be "hccl_world_group". The parameter group should not be "hccl_world_group".
Args: Args:
@ -190,52 +190,53 @@ def get_world_rank_from_group_rank(group, group_rank_id):
int, the rank ID in world communication group. int, the rank ID in world communication group.
Raises: Raises:
TypeError: If group_rank_id is not a int or group is not a string. TypeError: If `group_rank_id` is not an integer or the group is not a string.
ValueError: If group is 'hccl_world_group' or backend is invalid. ValueError: If group is 'hccl_world_group' or backend is invalid.
RuntimeError: If hccl/nccl is not available or nccl not supports. RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
""" """
return _get_world_rank_from_group_rank_helper(group=group, group_rank_id=group_rank_id, backend=GlobalComm.BACKEND) return _get_world_rank_from_group_rank_helper(group=group, group_rank_id=group_rank_id, backend=GlobalComm.BACKEND)
def get_group_rank_from_world_rank(world_rank_id, group): def get_group_rank_from_world_rank(world_rank_id, group):
""" """
Gets the rank ID in specified user communication group corresponding to the rank ID in world communication group. Get the rank ID in the specified user communication group corresponding to
the rank ID in the world communication group.
Note: Note:
Nccl is not supported. NCCL is not supported.
The parameter group should not be "hccl_world_group". The parameter group should not be "hccl_world_group".
Args: Args:
world_rank_id (int): A rank ID in world communication group. world_rank_id (int): A rank ID in the world communication group.
group (str): The user communication group. group (str): The user communication group.
Returns: Returns:
int, the rank ID in user communication group. int, the rank ID in the user communication group.
Raises: Raises:
TypeError: If world_rank_id is not a int or group is not a string. TypeError: If world_rank_id is not an integer or the group is not a string.
ValueError: If group is 'hccl_world_group' or backend is invalid. ValueError: If group is 'hccl_world_group' or backend is invalid.
RuntimeError: If hccl/nccl is not available or nccl not supports. RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
""" """
return _get_group_rank_from_world_rank_helper(world_rank_id=world_rank_id, group=group, backend=GlobalComm.BACKEND) return _get_group_rank_from_world_rank_helper(world_rank_id=world_rank_id, group=group, backend=GlobalComm.BACKEND)
def create_group(group, rank_ids): def create_group(group, rank_ids):
""" """
Creates user collective communication group. Create a user collective communication group.
Note: Note:
Nccl is not supported. NCCL is not supported.
The size of rank_ids should be larger than 1. The size of rank_ids should be larger than 1.
Rank_ids should not have duplicate data. Rank_ids should not have duplicate data.
Args: Args:
group (str): ProcessGroup, the process group to create. group (str): ProcessGroup, the process group to create.
rank_ids (list): List of device ID. rank_ids (list): A list of device IDs.
Raises: Raises:
TypeError: If group is not a string or rank_ids is not a list. TypeError: If group is not a string or `rank_ids` is not a list.
ValueError: If rank_ids size is not larger than 1 or rank_ids has duplicate data or backend is invalid. ValueError: If `rank_ids` size is not larger than 1, or `rank_ids` has duplicate data, or backend is invalid.
RuntimeError: If hccl/nccl is not available or nccl not supports. RuntimeError: If hccl/nccl is not available or nccl not supports.
Examples: Examples:
>>> group = "0-1" >>> group = "0-1"
@ -247,7 +248,7 @@ def create_group(group, rank_ids):
def destroy_group(group): def destroy_group(group):
""" """
Destroys user collective communication group. Destroy the user collective communication group.
Note: Note:
Nccl is not supported. Nccl is not supported.
@ -259,6 +260,6 @@ def destroy_group(group):
Raises: Raises:
TypeError: If group is not a string. TypeError: If group is not a string.
ValueError: If group is "hccl_world_group" or backend is invalid. ValueError: If group is "hccl_world_group" or backend is invalid.
RuntimeError: If hccl/nccl is not available or nccl not supports. RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
""" """
_destroy_group_helper(group, backend=GlobalComm.BACKEND) _destroy_group_helper(group, backend=GlobalComm.BACKEND)

View File

@ -336,6 +336,8 @@ def set_auto_parallel_context(**kwargs):
""" """
Set auto parallel context. Set auto parallel context.
Auto parallel context should be configured before the initialization of your network.
Note: Note:
Attribute name is required for setting attributes. Attribute name is required for setting attributes.
If a program has tasks with different parallel modes, then before setting new parallel mode for the If a program has tasks with different parallel modes, then before setting new parallel mode for the
@ -344,12 +346,25 @@ def set_auto_parallel_context(**kwargs):
Setting or changing parallel modes must be called before any creating Initializer, otherwise, Setting or changing parallel modes must be called before any creating Initializer, otherwise,
RuntimeError may be raised when compiling the network. RuntimeError may be raised when compiling the network.
Some configurations are parallel mode specific, see the below table for details:
=========================== =========================== =================
Common AUTO_PARALLEL DATA_PRALLEL
=========================== =========================== =================
device_num gradient_fp32_sync enable_parallel_optimizer
global_rank loss_repeated_mean
gradients_mean auto_parallel_search_mode
parallel_mode strategy_ckpt_load_file
all_reduce_fusion_config strategy_ckpt_save_file
full_batch
=========================== =========================== =================
Args: Args:
device_num (int): Available device number, the value must be in [1, 4096]. Default: 1. device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0. global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
gradients_mean (bool): Whether to perform mean operator after all-reduce of mirror. gradients_mean (bool): Whether to perform mean operator after allreduce of gradients.
"stand_alone" does not support `gradients_mean`. Default: False. "stand_alone" do not support gradients_mean. Default: False.
gradient_fp32_sync (bool): Gradients allreduce by fp32, even though gradients is fp16 if this flag is True.. gradient_fp32_sync (bool): Run allreduce of gradients in fp32.
"stand_alone", "data_parallel" and "hybrid_parallel" do not support "stand_alone", "data_parallel" and "hybrid_parallel" do not support
gradient_fp32_sync. Default: True. gradient_fp32_sync. Default: True.
parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel", parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
@ -364,8 +379,8 @@ def set_auto_parallel_context(**kwargs):
- semi_auto_parallel: Achieves data parallelism and model parallelism by - semi_auto_parallel: Achieves data parallelism and model parallelism by
setting parallel strategies. setting parallel strategies.
- auto_parallel: Achieves parallelism automatically. - auto_parallel: Achieving parallelism automatically.
auto_parallel_search_mode (str): There are two kinds of search modes, "recursive_programming" auto_parallel_search_mode (str): There are two kinds of shard strategy search modes, "recursive_programming"
and "dynamic_programming". Default: "dynamic_programming". and "dynamic_programming". Default: "dynamic_programming".
- recursive_programming: Recursive programming search mode. - recursive_programming: Recursive programming search mode.
@ -376,9 +391,11 @@ def set_auto_parallel_context(**kwargs):
broadcast. Default: False. broadcast. Default: False.
strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: '' strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: '' strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
full_batch (bool): Whether to load the whole batch on each device. Default: False. full_batch (bool): If you load whole batch datasets in auto_parallel mode, this parameter
enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation in should be set with True. Default: False.
data parallel training in the benefit of time and memory saving. enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation for
data parallel training in the benefit of time and memory saving. For now,
`Lamb` and `AdamWeightDecay` are supported in data parallel mode.
all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only support ReduceOp.SUM all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only support ReduceOp.SUM
and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP. and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP.
@ -479,7 +496,7 @@ def set_context(**kwargs):
Some configurations are device specific, see the bellow table for details: Some configurations are device specific, see the bellow table for details:
=========================== =========================== ================= =========================== =========================== =================
Common(CPU/GPU/Asecend) Ascend GPU Common(CPU/GPU/Ascend) Ascend GPU
=========================== =========================== ================= =========================== =========================== =================
check_bprop enable_auto_mixed_precision max_device_memory check_bprop enable_auto_mixed_precision max_device_memory
device_id enable_dump device_id enable_dump

View File

@ -33,7 +33,7 @@ __all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor
def add_flags(fn=None, **flags): def add_flags(fn=None, **flags):
""" """
An decorator to add flag for a function. A decorator that adds a flag to the function.
Note: Note:
Only supports bool value. Only supports bool value.
@ -43,7 +43,7 @@ def add_flags(fn=None, **flags):
flags (dict): Flags use kwargs. Default: None. flags (dict): Flags use kwargs. Default: None.
Returns: Returns:
Function, the fn added flags. Function, the function with added flags.
Examples: Examples:
>>> add_flags(net, predit=True) >>> add_flags(net, predit=True)
@ -63,9 +63,9 @@ def add_flags(fn=None, **flags):
def core(fn=None, **flags): def core(fn=None, **flags):
""" """
A decorator to add flag to a function. A decorator that adds a flag to the function.
By default, the function is marked core=True using this decorator to By default, the function is marked as True, enabling to use this decorator to
set flag to a graph. set flag to a graph.
Args: Args:
@ -91,11 +91,12 @@ def core(fn=None, **flags):
class GradOperation(GradOperation_): class GradOperation(GradOperation_):
""" """
An higher-order function which is used to generate the gradient function for the input function. A higher-order function which is used to generate the gradient function for the input function.
The gradient function generated by `GradOperation` higher-order function can be customized by construction args. The gradient function generated by `GradOperation` higher-order function can be customized by
construction arguments.
Given an input function `net = Net()` that take `x` and `y` as inputs, and has a parameter `z`, Given an input function `net = Net()` that takes `x` and `y` as inputs, and has a parameter `z`,
see `Net` in Examples. see `Net` in Examples.
To generate a gradient function that returns gradients with respect to the first input To generate a gradient function that returns gradients with respect to the first input
@ -126,7 +127,7 @@ class GradOperation(GradOperation_):
1. Construct a `GradOperation` higher-order function with `get_by_list=True`: 1. Construct a `GradOperation` higher-order function with `get_by_list=True`:
`grad_op = GradOperation(get_by_list=True)`. `grad_op = GradOperation(get_by_list=True)`.
2. Construct a `ParameterTuple` that will be passed along input function when constructing 2. Construct a `ParameterTuple` that will be passed to the input function when constructing
`GradOperation` higher-order function, it will be used as a parameter filter that determine `GradOperation` higher-order function, it will be used as a parameter filter that determine
which gradient to return: `params = ParameterTuple(net.trainable_params())`. which gradient to return: `params = ParameterTuple(net.trainable_params())`.
@ -151,20 +152,20 @@ class GradOperation(GradOperation_):
4. Call the gradient function with input function's inputs 4. Call the gradient function with input function's inputs
to get the gradients with respect to all inputs and given parameters: `gradient_function(x, y)`. to get the gradients with respect to all inputs and given parameters: `gradient_function(x, y)`.
We can configure the sensitiviy(gradient with respect to output) by setting `sens_param=True` and We can configure the sensitivity(gradient with respect to output) by setting `sens_param` as True and
passing in an extra sensitiviy input to the gradient function, the sensitiviy input should be passing an extra sensitivity input to the gradient function, the sensitivity input should has the
with same shape and type with input function's output(see `GradNetWrtXYWithSensParam` in Examples). same shape and type with input function's output(see `GradNetWrtXYWithSensParam` in Examples).
1. Construct a `GradOperation` higher-order function with `get_all=True` and `sens_param=True`: 1. Construct a `GradOperation` higher-order function with `get_all=True` and `sens_param=True`:
`grad_op = GradOperation(get_all=True, sens_param=True)`. `grad_op = GradOperation(get_all=True, sens_param=True)`.
2. Define grad_wrt_output as sens_param which works as the gradient with respect to output: 2. Define `grad_wrt_output` as `sens_param` which works as the gradient with respect to output:
`grad_wrt_output = Tensor(np.ones([2, 2]).astype(np.float32))`. `grad_wrt_output = Tensor(np.ones([2, 2]).astype(np.float32))`.
3. Call it with input function as argument to get the gradient function: 3. Call it with input function as argument to get the gradient function:
`gradient_function = grad_op(net)`. `gradient_function = grad_op(net)`.
4. Call the gradient function with input function's inputs and sens_param to 4. Call the gradient function with input function's inputs and `sens_param` to
get the gradients with respect to all inputs: get the gradients with respect to all inputs:
`gradient_function(x, y, grad_wrt_output)`. `gradient_function(x, y, grad_wrt_output)`.
@ -175,8 +176,9 @@ class GradOperation(GradOperation_):
If get_all and get_by_list are both True, get the gradients with respect to inputs and Parameter variables If get_all and get_by_list are both True, get the gradients with respect to inputs and Parameter variables
at the same time in the form of ((gradients with respect to inputs), at the same time in the form of ((gradients with respect to inputs),
(gradients with respect to parameters)). Default: False. (gradients with respect to parameters)). Default: False.
sens_param (bool): Whether append sensitivity(gradient with respect to output) as input. If sens_param is False, sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input.
a 'ones_like(outputs)' sensitivity will be attached automatically. Default: False. If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically.
Default: False.
Returns: Returns:
The higher-order function which takes a function as argument and returns gradient function for it. The higher-order function which takes a function as argument and returns gradient function for it.
@ -349,9 +351,9 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
""" """
Generate overloaded functions. Generate overloaded functions.
MultitypeFuncGraph is a class used to generate overloaded functions with different type as inputs. MultitypeFuncGraph is a class used to generate overloaded functions, considering different types as inputs.
Initialize an `MultitypeFuncGraph` object with name, and use `register` with input types as the decorator Initialize an `MultitypeFuncGraph` object with name, and use `register` with input types as the decorator
for the function to be registed. And the object can be called with different type of inputs, for the function to be registed. And the object can be called with different types of inputs,
and work with `HyperMap` and `Map`. and work with `HyperMap` and `Map`.
Args: Args:
@ -360,7 +362,7 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
and all inputs will pass by value, set `read_value` to True. Default: False. and all inputs will pass by value, set `read_value` to True. Default: False.
Raises: Raises:
ValueError: Cannot find matching functions for the given args. ValueError: If failed to find find a matching function for the given arguments.
Examples: Examples:
>>> # `add` is a metagraph object which will add two objects according to >>> # `add` is a metagraph object which will add two objects according to
@ -431,7 +433,7 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
class HyperMap(HyperMap_): class HyperMap(HyperMap_):
""" """
Hypermap will apply the set operation on input sequences. Hypermap will apply the set operation to input sequences.
Apply the operations to every elements of the sequence or nested sequence. Different Apply the operations to every elements of the sequence or nested sequence. Different
from `Map`, the `HyperMap` supports to apply on nested structure. from `Map`, the `HyperMap` supports to apply on nested structure.
@ -441,11 +443,10 @@ class HyperMap(HyperMap_):
the operations should be put in the first input of the instance. the operations should be put in the first input of the instance.
Inputs: Inputs:
- **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be the same length sequences, - **args** (Tuple[sequence]) - If `ops` is `None`, all the inputs should be sequences with the same length.
and each row of the sequences. e.g. If args length is 2, and for `i` in length of each sequence And each row of the sequences will be the inputs of the operation.
`(args[0][i], args[1][i])` will be the input of the operation.
If `ops` is not `None`, the first input is the operation, and the other is inputs. If `ops` is not `None`, the first input is the operation, and the others are inputs.
Outputs: Outputs:
Sequence or nested sequence, the sequence of output after applying the function. Sequence or nested sequence, the sequence of output after applying the function.

View File

@ -48,14 +48,15 @@ def normal(shape, mean, stddev, seed=0):
Args: Args:
shape (tuple): The shape of random tensor to be generated. shape (tuple): The shape of random tensor to be generated.
mean (Tensor): The mean μ distribution parameter, which specifies the location of the peak. mean (Tensor): The mean μ distribution parameter, which specifies the location of the peak.
With float32 data type. with float32 data type.
stddev (Tensor): The deviation σ distribution parameter. It should be greater than 0. stddev (Tensor): The deviation σ distribution parameter. It should be greater than 0.
With float32 data type. with float32 data type.
seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers. seed (int): Seed is used as entropy source for the Random number engines to generate pseudo-random numbers.
Must be non-negative. Default: 0. must be non-negative. Default: 0.
Returns: Returns:
Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of mean and stddev. Tensor. The shape should be equal to the broadcasted shape between the input `shape` and shapes
of `mean` and `stddev`.
The dtype is float32. The dtype is float32.
Examples: Examples:
@ -85,20 +86,21 @@ def uniform(shape, minval, maxval, seed=0, dtype=mstype.float32):
Args: Args:
shape (tuple): The shape of random tensor to be generated. shape (tuple): The shape of random tensor to be generated.
minval (Tensor): The a distribution parameter. minval (Tensor): The distribution parameter `a`.
It defines the minimum possibly generated value. With int32 or float32 data type. It defines the minimum possible generated value, with int32 or float32 data type.
If dtype is int32, only one number is allowed. If dtype is int32, only one number is allowed.
maxval (Tensor): The b distribution parameter. maxval (Tensor): The distribution parameter `b`.
It defines the maximum possibly generated value. With int32 or float32 data type. It defines the maximum possible generated value, with int32 or float32 data type.
If dtype is int32, only one number is allowed. If dtype is int32, only one number is allowed.
seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers. seed (int): Seed is used as entropy source for the random number engines to generate pseudo-random numbers,
Must be non-negative. Default: 0. must be non-negative. Default: 0.
dtype (mindspore.dtype): type of the Uniform distribution. If it is int32, it generates numbers from discrete dtype (mindspore.dtype): type of the Uniform distribution. If it is int32, it generates numbers from discrete
uniform distribution; if it is float32, it generates numbers from continuous uniform distribution. It only uniform distribution; if it is float32, it generates numbers from continuous uniform distribution. It only
supports these two data types. Default: mstype.float32. supports these two data types. Default: mstype.float32.
Returns: Returns:
Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of minval and maxval. Tensor. The shape should be equal to the broadcasted shape between the input `shape` and shapes
of `minval` and `maxval`.
The dtype is designated as the input `dtype`. The dtype is designated as the input `dtype`.
Examples: Examples:
@ -137,13 +139,14 @@ def gamma(shape, alpha, beta, seed=0):
Args: Args:
shape (tuple): The shape of random tensor to be generated. shape (tuple): The shape of random tensor to be generated.
alpha (Tensor): The alpha α distribution parameter. It should be greater than 0. With float32 data type. alpha (Tensor): The alpha α distribution parameter. It should be greater than 0 with float32 data type.
beta (Tensor): The beta β distribution parameter. It should be greater than 0. With float32 data type. beta (Tensor): The beta β distribution parameter. It should be greater than 0 with float32 data type.
seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers. seed (int): Seed is used as entropy source for the random number engines to generate
Must be non-negative. Default: 0. pseudo-random numbers, must be non-negative. Default: 0.
Returns: Returns:
Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of alpha and beta. Tensor. The shape should be equal to the broadcasted shape between the input "shape" and shapes
of `alpha` and `beta`.
The dtype is float32. The dtype is float32.
Examples: Examples:
@ -165,12 +168,12 @@ def poisson(shape, mean, seed=0):
Args: Args:
shape (tuple): The shape of random tensor to be generated. shape (tuple): The shape of random tensor to be generated.
mean (Tensor): The mean μ distribution parameter. It should be greater than 0. With float32 data type. mean (Tensor): The mean μ distribution parameter. It should be greater than 0 with float32 data type.
seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers. seed (int): Seed is used as entropy source for the random number engines to generate pseudo-random numbers
Must be non-negative. Default: 0. and must be non-negative. Default: 0.
Returns: Returns:
Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of mean. Tensor. The shape should be equal to the broadcasted shape between the input "shape" and shapes of `mean`.
The dtype is float32. The dtype is float32.
Examples: Examples:
@ -188,21 +191,23 @@ def poisson(shape, mean, seed=0):
def multinomial(inputs, num_sample, replacement=True, seed=0): def multinomial(inputs, num_sample, replacement=True, seed=0):
r""" r"""
Returns a tensor sampled from the multinomial probability distribution located in the corresponding Returns a tensor sampled from the multinomial probability distribution located in the corresponding
row of tensor input. row of the input tensor.
Note: Note:
The rows of input do not need to sum to one (in which case we use the values as weights), The rows of input do not need to sum to one (in which case we use the values as weights),
but must be non-negative, finite and have a non-zero sum. but must be non-negative, finite and have a non-zero sum.
Args: Args:
inputs (Tensor): the input tensor containing probabilities, must be 1 or 2 dims. With float32 data type. inputs (Tensor): The input tensor containing probabilities, must be 1 or 2 dimensions, with
num_sample (int): number of samples to draw. float32 data type.
replacement (bool, optional): whether to draw with replacement or not, default True. num_sample (int): Number of samples to draw.
seed (int, optional): used as entropy source for Random number engines generating pseudo-random numbers. replacement (bool, optional): Whether to draw with replacement or not, default True.
Must be non-negative. Default: 0. seed (int, optional): Seed is used as entropy source for the random number engines to generate
pseudo-random numbers,
must be non-negative. Default: 0.
Outputs: Outputs:
Tensor. have the same rows with input, each row has num_samples sampled indices. Tensor, has the same rows with input. The number of sampled indices of each row is `num_samples`.
The dtype is float32. The dtype is float32.
Examples: Examples:

View File

@ -197,6 +197,9 @@ class _AutoParallelContext:
parameter_broadcast (bool): Parameter broadcast or not. parameter_broadcast (bool): Parameter broadcast or not.
""" """
self.check_context_handle() self.check_context_handle()
if parameter_broadcast is True and context.get_context("enable_ge") is False:
raise RuntimeError("Parameter broadcast is a developing feature. For now we suggest to"
" use mindspore.common.set_seed() to share parameters among devices.")
self._context_handle.set_parameter_broadcast(parameter_broadcast) self._context_handle.set_parameter_broadcast(parameter_broadcast)
def get_parameter_broadcast(self): def get_parameter_broadcast(self):

View File

@ -58,7 +58,7 @@ if __name__ == '__main__':
cfg.group_size = get_group_size() cfg.group_size = get_group_size()
parallel_mode = ParallelMode.DATA_PARALLEL parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size, context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size,
parameter_broadcast=True, gradients_mean=True) gradients_mean=True)
else: else:
cfg.rank = 0 cfg.rank = 0
cfg.group_size = 1 cfg.group_size = 1

View File

@ -61,7 +61,7 @@ if __name__ == '__main__':
cfg.group_size = get_group_size() cfg.group_size = get_group_size()
parallel_mode = ParallelMode.DATA_PARALLEL parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size, context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size,
parameter_broadcast=True, gradients_mean=True) gradients_mean=True)
else: else:
cfg.rank = 0 cfg.rank = 0
cfg.group_size = 1 cfg.group_size = 1

View File

@ -136,8 +136,7 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl):
os.environ['RANK_SIZE'] = str(device_num) os.environ['RANK_SIZE'] = str(device_num)
if enable_hccl: if enable_hccl:
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True, parameter_broadcast=True, gradients_mean=True, all_reduce_fusion_config=[107, 160])
all_reduce_fusion_config=[107, 160])
init() init()
# network # network
@ -239,8 +238,7 @@ def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl):
os.environ['RANK_SIZE'] = str(device_num) os.environ['RANK_SIZE'] = str(device_num)
if enable_hccl: if enable_hccl:
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True, parameter_broadcast=True, gradients_mean=True, all_reduce_fusion_config=[107])
all_reduce_fusion_config=[107])
init() init()
# network # network