forked from OSSInnovation/mindspore
fix_api_problems
This commit is contained in:
parent
19bdba56f9
commit
8ea177e614
@@ -32,33 +32,32 @@ DEFAULT_BACKEND = Backend("hccl")


def _get_group(group):
-    """Get the global world group if the group is default world comm group."""
+    """Return the world communication group if the `group` is `DEFAULT_WORLD_COMM_GROUP`."""
    if group == DEFAULT_WORLD_COMM_GROUP:
        return GlobalComm.WORLD_COMM_GROUP
    return group


class GlobalComm:
-    """Global communication info."""
+    """World communication information."""
    BACKEND = DEFAULT_BACKEND
    WORLD_COMM_GROUP = DEFAULT_WORLD_COMM_GROUP


def init(backend_name=None):
    """
-    Init distributed backend, e.g., hccl/nccl, it is required before communication service can be used.
+    Initialize distributed backend, e.g. HCCL/NCCL, it is required before using the communication service.

    Note:
-        The full name of hccl is Huawei Collective Communication Library.
-        The full name of nccl is NVIDIA Collective Communication Library.
+        The full name of HCCL is Huawei Collective Communication Library.
+        The full name of NCCL is NVIDIA Collective Communication Library.

    Args:
        backend_name (str): Backend.

    Raises:
-        TypeError: If backen_name is not a string.
+        TypeError: If `backend_name` is not a string.
-        RuntimeError: If device target is invalid.
+        RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
-        RuntimeError: If backend is invalid or distributed init fails.
    """
    if _is_role_pserver() or _is_role_sched():
        return
@@ -88,17 +87,17 @@ def init(backend_name=None):

def release():
    """
-    Release distributed resource. e.g., hccl/nccl.
+    Release distributed resource. e.g. HCCL/NCCL.

    Raises:
-        RuntimeError: If distributed resource release fails.
+        RuntimeError: If failed to release distributed resource.
    """
    finalize_hccl()


def get_rank(group=GlobalComm.WORLD_COMM_GROUP):
    """
-    Gets rank ID for current device in specified collective communication group.
+    Get the rank ID for the current device in the specified collective communication group.

    Args:
        group (str): ProcessGroup, the process group to work on. Default: WORLD_COMM_GROUP.
@@ -109,7 +108,7 @@ def get_rank(group=GlobalComm.WORLD_COMM_GROUP):
    Raises:
        TypeError: If group is not a string.
        ValueError: If backend is invalid.
-        RuntimeError: If hccl/nccl is not available or nccl not supports.
+        RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
    """
    return _get_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND)

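For orientation, a minimal usage sketch of the helpers documented in these hunks, in the style of the docstring examples; it assumes a GPU/NCCL setup started by a standard distributed launcher, and the device settings are illustrative rather than part of this change:

>>> from mindspore import context
>>> from mindspore.communication.management import init, get_rank, get_group_size, release
>>>
>>> # Assumes the process was started by a distributed launcher (e.g. mpirun) on GPU.
>>> context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
>>> init("nccl")                  # or init() on Ascend, which uses the HCCL backend
>>> rank = get_rank()             # rank ID of this process in the world group
>>> size = get_group_size()       # total number of devices in the world group
>>> print(rank, size)
>>> release()                     # free HCCL/NCCL resources when finished
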
@@ -130,14 +129,14 @@ def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP):
    Raises:
        TypeError: If group is not a string.
        ValueError: If backend is invalid.
-        RuntimeError: If hccl/nccl is not available or nccl not supports.
+        RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
    """
    return _get_local_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND)


def get_group_size(group=GlobalComm.WORLD_COMM_GROUP):
    """
-    Gets rank size of the specified collective communication group.
+    Get the rank size of the specified collective communication group.

    Args:
        group (str): ProcessGroup, the process group to work on.
@@ -148,7 +147,7 @@ def get_group_size(group=GlobalComm.WORLD_COMM_GROUP):
    Raises:
        TypeError: If group is not a string.
        ValueError: If backend is invalid.
-        RuntimeError: If hccl/nccl is not available or nccl not supports.
+        RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
    """
    return _get_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND)

@@ -164,22 +163,23 @@ def get_local_rank_size(group=GlobalComm.WORLD_COMM_GROUP):
        group (str): ProcessGroup, the process group to work on.

    Returns:
-        int, the local rank size where the calling process is being within the group.
+        int, the local rank size where the calling process is within the group.

    Raises:
        TypeError: If group is not a string.
        ValueError: If backend is invalid.
-        RuntimeError: If hccl/nccl is not available or nccl not supports.
+        RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
    """
    return _get_local_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND)


def get_world_rank_from_group_rank(group, group_rank_id):
    """
-    Gets the rank ID in world communication group corresponding to the rank ID in specified user communication group.
+    Gets the rank ID in the world communication group corresponding to
+    the rank ID in the specified user communication group.

    Note:
-        Nccl is not supported.
+        NCCL is not supported.
        The parameter group should not be "hccl_world_group".

    Args:
@@ -190,52 +190,53 @@ def get_world_rank_from_group_rank(group, group_rank_id):
        int, the rank ID in world communication group.

    Raises:
-        TypeError: If group_rank_id is not a int or group is not a string.
+        TypeError: If `group_rank_id` is not an integer or the group is not a string.
        ValueError: If group is 'hccl_world_group' or backend is invalid.
-        RuntimeError: If hccl/nccl is not available or nccl not supports.
+        RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
    """
    return _get_world_rank_from_group_rank_helper(group=group, group_rank_id=group_rank_id, backend=GlobalComm.BACKEND)


def get_group_rank_from_world_rank(world_rank_id, group):
    """
-    Gets the rank ID in specified user communication group corresponding to the rank ID in world communication group.
+    Get the rank ID in the specified user communication group corresponding to
+    the rank ID in the world communication group.

    Note:
-        Nccl is not supported.
+        NCCL is not supported.
        The parameter group should not be "hccl_world_group".

    Args:
-        world_rank_id (int): A rank ID in world communication group.
+        world_rank_id (int): A rank ID in the world communication group.
        group (str): The user communication group.

    Returns:
-        int, the rank ID in user communication group.
+        int, the rank ID in the user communication group.

    Raises:
-        TypeError: If world_rank_id is not a int or group is not a string.
+        TypeError: If world_rank_id is not an integer or the group is not a string.
        ValueError: If group is 'hccl_world_group' or backend is invalid.
-        RuntimeError: If hccl/nccl is not available or nccl not supports.
+        RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
    """
    return _get_group_rank_from_world_rank_helper(world_rank_id=world_rank_id, group=group, backend=GlobalComm.BACKEND)


def create_group(group, rank_ids):
    """
-    Creates user collective communication group.
+    Create a user collective communication group.

    Note:
-        Nccl is not supported.
+        NCCL is not supported.
        The size of rank_ids should be larger than 1.
        Rank_ids should not have duplicate data.

    Args:
        group (str): ProcessGroup, the process group to create.
-        rank_ids (list): List of device ID.
+        rank_ids (list): A list of device IDs.

    Raises:
-        TypeError: If group is not a string or rank_ids is not a list.
+        TypeError: If group is not a string or `rank_ids` is not a list.
-        ValueError: If rank_ids size is not larger than 1 or rank_ids has duplicate data or backend is invalid.
+        ValueError: If `rank_ids` size is not larger than 1, or `rank_ids` has duplicate data, or backend is invalid.
        RuntimeError: If hccl/nccl is not available or nccl not supports.
    Examples:
        >>> group = "0-1"
@@ -247,7 +248,7 @@ def create_group(group, rank_ids):

def destroy_group(group):
    """
-    Destroys user collective communication group.
+    Destroy the user collective communication group.

    Note:
        Nccl is not supported.
@@ -259,6 +260,6 @@ def destroy_group(group):
    Raises:
        TypeError: If group is not a string.
        ValueError: If group is "hccl_world_group" or backend is invalid.
-        RuntimeError: If hccl/nccl is not available or nccl not supports.
+        RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
    """
    _destroy_group_helper(group, backend=GlobalComm.BACKEND)

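A short sketch tying the group APIs above together on Ascend/HCCL; the group name and rank list follow the "0-1" docstring example and are illustrative (NCCL does not support these group APIs, as the Notes state):

>>> from mindspore.communication.management import (init, create_group, destroy_group,
...                                                 get_world_rank_from_group_rank,
...                                                 get_group_rank_from_world_rank)
>>>
>>> init()                          # HCCL backend on Ascend
>>> group = "0-1"                   # user-defined group name (illustrative)
>>> create_group(group, [0, 1])     # rank_ids must contain more than one unique entry
>>> # Map between group-local rank IDs and world rank IDs.
>>> world_rank = get_world_rank_from_group_rank(group, 1)
>>> group_rank = get_group_rank_from_world_rank(world_rank, group)
>>> destroy_group(group)
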
@@ -336,6 +336,8 @@ def set_auto_parallel_context(**kwargs):
    """
    Set auto parallel context.

+    Auto parallel context should be configured before the initialization of your network.
+
    Note:
        Attribute name is required for setting attributes.
        If a program has tasks with different parallel modes, then before setting new parallel mode for the
@@ -344,12 +346,25 @@ def set_auto_parallel_context(**kwargs):
        Setting or changing parallel modes must be called before any creating Initializer, otherwise,
        RuntimeError may be raised when compiling the network.

+    Some configurations are parallel mode specific, see the below table for details:
+
+    =========================== =========================== =================
+    Common                      AUTO_PARALLEL               DATA_PARALLEL
+    =========================== =========================== =================
+    device_num                  gradient_fp32_sync          enable_parallel_optimizer
+    global_rank                 loss_repeated_mean
+    gradients_mean              auto_parallel_search_mode
+    parallel_mode               strategy_ckpt_load_file
+    all_reduce_fusion_config    strategy_ckpt_save_file
+    full_batch
+    =========================== =========================== =================
+
    Args:
        device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
        global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
-        gradients_mean (bool): Whether to perform mean operator after all-reduce of mirror.
-                     "stand_alone" does not support `gradients_mean`. Default: False.
+        gradients_mean (bool): Whether to perform mean operator after allreduce of gradients.
+                     "stand_alone" do not support gradients_mean. Default: False.
-        gradient_fp32_sync (bool): Gradients allreduce by fp32, even though gradients is fp16 if this flag is True..
+        gradient_fp32_sync (bool): Run allreduce of gradients in fp32.
                     "stand_alone", "data_parallel" and "hybrid_parallel" do not support
                     gradient_fp32_sync. Default: True.
        parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
@@ -364,8 +379,8 @@ def set_auto_parallel_context(**kwargs):
                     - semi_auto_parallel: Achieves data parallelism and model parallelism by
                       setting parallel strategies.

-                     - auto_parallel: Achieves parallelism automatically.
+                     - auto_parallel: Achieving parallelism automatically.
-        auto_parallel_search_mode (str): There are two kinds of search modes, "recursive_programming"
+        auto_parallel_search_mode (str): There are two kinds of shard strategy search modes, "recursive_programming"
                     and "dynamic_programming". Default: "dynamic_programming".

                     - recursive_programming: Recursive programming search mode.
@@ -376,9 +391,11 @@ def set_auto_parallel_context(**kwargs):
                     broadcast. Default: False.
        strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
        strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
-        full_batch (bool): Whether to load the whole batch on each device. Default: False.
+        full_batch (bool): If you load whole batch datasets in auto_parallel mode, this parameter
+                     should be set with True. Default: False.
-        enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation in
-                     data parallel training in the benefit of time and memory saving.
+        enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation for
+                     data parallel training in the benefit of time and memory saving. For now,
+                     `Lamb` and `AdamWeightDecay` are supported in data parallel mode.
        all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only support ReduceOp.SUM
                     and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP.

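To tie the table and the argument list above together, a minimal configuration sketch; the device count and the chosen options are illustrative placeholders, not part of this change:

>>> from mindspore import context
>>>
>>> # Data-parallel configuration: the "Common" and DATA_PARALLEL columns apply.
>>> context.set_auto_parallel_context(parallel_mode="data_parallel",
...                                   device_num=8,            # placeholder device count
...                                   gradients_mean=True,
...                                   enable_parallel_optimizer=False)
>>>
>>> # Auto-parallel configuration: the AUTO_PARALLEL column applies.
>>> context.set_auto_parallel_context(parallel_mode="auto_parallel",
...                                   device_num=8,
...                                   gradient_fp32_sync=True,
...                                   auto_parallel_search_mode="dynamic_programming",
...                                   full_batch=False)
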
@@ -479,7 +496,7 @@ def set_context(**kwargs):
    Some configurations are device specific, see the bellow table for details:

    =========================== =========================== =================
-    Common(CPU/GPU/Asecend)     Ascend                      GPU
+    Common(CPU/GPU/Ascend)      Ascend                      GPU
    =========================== =========================== =================
    check_bprop                 enable_auto_mixed_precision max_device_memory
    device_id                   enable_dump

@@ -33,7 +33,7 @@ __all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor

def add_flags(fn=None, **flags):
    """
-    An decorator to add flag for a function.
+    A decorator that adds a flag to the function.

    Note:
        Only supports bool value.
@@ -43,7 +43,7 @@ def add_flags(fn=None, **flags):
        flags (dict): Flags use kwargs. Default: None.

    Returns:
-        Function, the fn added flags.
+        Function, the function with added flags.

    Examples:
        >>> add_flags(net, predit=True)
@@ -63,9 +63,9 @@ def add_flags(fn=None, **flags):

def core(fn=None, **flags):
    """
-    A decorator to add flag to a function.
+    A decorator that adds a flag to the function.

-    By default, the function is marked core=True using this decorator to
+    By default, the function is marked as True, enabling to use this decorator to
    set flag to a graph.

    Args:
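A hedged sketch of how these decorators are typically applied; the `predit=True` flag follows the docstring example, `tr=True` is an arbitrary illustrative flag, and the import path is an assumption:

>>> from mindspore.ops.composite.base import add_flags, core   # import path assumed
>>>
>>> # Used directly on an existing callable, as in the Examples above.
>>> def predict_fn(x):
...     return x * 2
>>> predict_fn = add_flags(predict_fn, predit=True)   # flag name follows the docstring example
>>>
>>> # Used as a decorator: core marks the decorated function with core=True
>>> # plus any extra boolean flags passed in.
>>> @core(tr=True)
... def fn(x, y):
...     return x + y
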
@@ -91,11 +91,12 @@ def core(fn=None, **flags):

class GradOperation(GradOperation_):
    """
-    An higher-order function which is used to generate the gradient function for the input function.
+    A higher-order function which is used to generate the gradient function for the input function.

-    The gradient function generated by `GradOperation` higher-order function can be customized by construction args.
+    The gradient function generated by `GradOperation` higher-order function can be customized by
+    construction arguments.

-    Given an input function `net = Net()` that take `x` and `y` as inputs, and has a parameter `z`,
+    Given an input function `net = Net()` that takes `x` and `y` as inputs, and has a parameter `z`,
    see `Net` in Examples.

    To generate a gradient function that returns gradients with respect to the first input
@@ -126,7 +127,7 @@ class GradOperation(GradOperation_):
    1. Construct a `GradOperation` higher-order function with `get_by_list=True`:
       `grad_op = GradOperation(get_by_list=True)`.

-    2. Construct a `ParameterTuple` that will be passed along input function when constructing
+    2. Construct a `ParameterTuple` that will be passed to the input function when constructing
       `GradOperation` higher-order function, it will be used as a parameter filter that determine
       which gradient to return: `params = ParameterTuple(net.trainable_params())`.

@@ -151,20 +152,20 @@ class GradOperation(GradOperation_):
    4. Call the gradient function with input function's inputs
       to get the gradients with respect to all inputs and given parameters: `gradient_function(x, y)`.

-    We can configure the sensitiviy(gradient with respect to output) by setting `sens_param=True` and
-    passing in an extra sensitiviy input to the gradient function, the sensitiviy input should be
-    with same shape and type with input function's output(see `GradNetWrtXYWithSensParam` in Examples).
+    We can configure the sensitivity(gradient with respect to output) by setting `sens_param` as True and
+    passing an extra sensitivity input to the gradient function, the sensitivity input should has the
+    same shape and type with input function's output(see `GradNetWrtXYWithSensParam` in Examples).

    1. Construct a `GradOperation` higher-order function with `get_all=True` and `sens_param=True`:
       `grad_op = GradOperation(get_all=True, sens_param=True)`.

-    2. Define grad_wrt_output as sens_param which works as the gradient with respect to output:
+    2. Define `grad_wrt_output` as `sens_param` which works as the gradient with respect to output:
       `grad_wrt_output = Tensor(np.ones([2, 2]).astype(np.float32))`.

    3. Call it with input function as argument to get the gradient function:
       `gradient_function = grad_op(net)`.

-    4. Call the gradient function with input function's inputs and sens_param to
+    4. Call the gradient function with input function's inputs and `sens_param` to
       get the gradients with respect to all inputs:
       `gradient_function(x, y, grad_wrt_output)`.

@@ -175,8 +176,9 @@ class GradOperation(GradOperation_):
            If get_all and get_by_list are both True, get the gradients with respect to inputs and Parameter variables
            at the same time in the form of ((gradients with respect to inputs),
            (gradients with respect to parameters)). Default: False.
-        sens_param (bool): Whether append sensitivity(gradient with respect to output) as input. If sens_param is False,
-            a 'ones_like(outputs)' sensitivity will be attached automatically. Default: False.
+        sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input.
+            If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically.
+            Default: False.

    Returns:
        The higher-order function which takes a function as argument and returns gradient function for it.
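To make the numbered steps above concrete, a minimal sketch of the `get_all=True` case, written in the style of the docstring's `Net` example; the tiny network, the input values, and the `composite as C` import are illustrative assumptions:

>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> from mindspore.ops import composite as C
>>>
>>> class Net(nn.Cell):
...     def construct(self, x, y):
...         return x * y
...
>>> class GradNetWrtXY(nn.Cell):
...     def __init__(self, net):
...         super(GradNetWrtXY, self).__init__()
...         self.net = net
...         self.grad_op = C.GradOperation(get_all=True)   # gradients w.r.t. all inputs
...     def construct(self, x, y):
...         gradient_function = self.grad_op(self.net)
...         return gradient_function(x, y)
...
>>> x = Tensor(np.array([[1.0, 2.0]], dtype=np.float32))
>>> y = Tensor(np.array([[3.0, 4.0]], dtype=np.float32))
>>> grads = GradNetWrtXY(Net())(x, y)   # tuple: (d(out)/dx, d(out)/dy)
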
@@ -349,9 +351,9 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
    """
    Generate overloaded functions.

-    MultitypeFuncGraph is a class used to generate overloaded functions with different type as inputs.
+    MultitypeFuncGraph is a class used to generate overloaded functions, considering different types as inputs.
    Initialize an `MultitypeFuncGraph` object with name, and use `register` with input types as the decorator
-    for the function to be registed. And the object can be called with different type of inputs,
+    for the function to be registed. And the object can be called with different types of inputs,
    and work with `HyperMap` and `Map`.

    Args:
@@ -360,7 +362,7 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
            and all inputs will pass by value, set `read_value` to True. Default: False.

    Raises:
-        ValueError: Cannot find matching functions for the given args.
+        ValueError: If failed to find a matching function for the given arguments.

    Examples:
        >>> # `add` is a metagraph object which will add two objects according to
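Continuing the `add` metagraph example referenced above, a sketch of how registration and dispatch typically look; the type tags and the `functional as F` helpers are assumptions chosen to mirror the docstring example:

>>> from mindspore.ops import composite as C
>>> from mindspore.ops import functional as F
>>>
>>> add = C.MultitypeFuncGraph('add')
>>>
>>> @add.register("Number", "Number")
... def add_scalar(x, y):
...     return F.scalar_add(x, y)
...
>>> @add.register("Tensor", "Tensor")
... def add_tensor(x, y):
...     return F.tensor_add(x, y)
...
>>> add(1, 2)   # dispatches to the "Number" overload; Tensor inputs pick the "Tensor" one
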
@@ -431,7 +433,7 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):

class HyperMap(HyperMap_):
    """
-    Hypermap will apply the set operation on input sequences.
+    Hypermap will apply the set operation to input sequences.

    Apply the operations to every elements of the sequence or nested sequence. Different
    from `Map`, the `HyperMap` supports to apply on nested structure.
@@ -441,11 +443,10 @@ class HyperMap(HyperMap_):
        the operations should be put in the first input of the instance.

    Inputs:
-        - **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be the same length sequences,
-          and each row of the sequences. e.g. If args length is 2, and for `i` in length of each sequence
-          `(args[0][i], args[1][i])` will be the input of the operation.
+        - **args** (Tuple[sequence]) - If `ops` is `None`, all the inputs should be sequences with the same length.
+          And each row of the sequences will be the inputs of the operation.

          If `ops` is not `None`, the first input is the operation, and the others are inputs.

    Outputs:
        Sequence or nested sequence, the sequence of output after applying the function.

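A sketch of the two calling conventions described in Inputs, reusing the `add` MultitypeFuncGraph from the previous sketch; shapes and values are illustrative:

>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore.ops import composite as C
>>>
>>> x = Tensor(np.ones([2, 2]).astype(np.float32))
>>> y = Tensor(np.ones([2, 2]).astype(np.float32))
>>>
>>> common_map = C.HyperMap()          # ops is None: the operation is the first call input
>>> out = common_map(add, (x, x), (y, y))
>>>
>>> fixed_map = C.HyperMap(add)        # ops fixed at construction: only sequences are passed
>>> out = fixed_map((x, x), (y, y))
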
@@ -48,14 +48,15 @@ def normal(shape, mean, stddev, seed=0):
    Args:
        shape (tuple): The shape of random tensor to be generated.
        mean (Tensor): The mean μ distribution parameter, which specifies the location of the peak.
-            With float32 data type.
+            with float32 data type.
        stddev (Tensor): The deviation σ distribution parameter. It should be greater than 0.
-            With float32 data type.
+            with float32 data type.
-        seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
-            Must be non-negative. Default: 0.
+        seed (int): Seed is used as entropy source for the Random number engines to generate pseudo-random numbers.
+            must be non-negative. Default: 0.

    Returns:
-        Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of mean and stddev.
+        Tensor. The shape should be equal to the broadcasted shape between the input `shape` and shapes
+        of `mean` and `stddev`.
        The dtype is float32.

    Examples:
@@ -85,20 +86,21 @@ def uniform(shape, minval, maxval, seed=0, dtype=mstype.float32):

    Args:
        shape (tuple): The shape of random tensor to be generated.
-        minval (Tensor): The a distribution parameter.
-            It defines the minimum possibly generated value. With int32 or float32 data type.
+        minval (Tensor): The distribution parameter `a`.
+            It defines the minimum possible generated value, with int32 or float32 data type.
            If dtype is int32, only one number is allowed.
-        maxval (Tensor): The b distribution parameter.
-            It defines the maximum possibly generated value. With int32 or float32 data type.
+        maxval (Tensor): The distribution parameter `b`.
+            It defines the maximum possible generated value, with int32 or float32 data type.
            If dtype is int32, only one number is allowed.
-        seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
-            Must be non-negative. Default: 0.
+        seed (int): Seed is used as entropy source for the random number engines to generate pseudo-random numbers,
+            must be non-negative. Default: 0.
        dtype (mindspore.dtype): type of the Uniform distribution. If it is int32, it generates numbers from discrete
            uniform distribution; if it is float32, it generates numbers from continuous uniform distribution. It only
            supports these two data types. Default: mstype.float32.

    Returns:
-        Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of minval and maxval.
+        Tensor. The shape should be equal to the broadcasted shape between the input `shape` and shapes
+        of `minval` and `maxval`.
        The dtype is designated as the input `dtype`.

    Examples:
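For reference, a sketch of calling these composite random ops with broadcastable parameters; the `mindspore.ops.composite` import path and the concrete values are assumptions for illustration:

>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore.common import dtype as mstype
>>> from mindspore.ops import composite as C
>>>
>>> shape = (2, 3)
>>> mean = Tensor(np.array([0.0]), mstype.float32)
>>> stddev = Tensor(np.array([1.0]), mstype.float32)
>>> out = C.normal(shape, mean, stddev, seed=5)      # float32, broadcast to shape (2, 3)
>>>
>>> minval = Tensor(1, mstype.int32)                 # int32: only one number is allowed
>>> maxval = Tensor(10, mstype.int32)
>>> out = C.uniform(shape, minval, maxval, seed=5, dtype=mstype.int32)
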
@@ -137,13 +139,14 @@ def gamma(shape, alpha, beta, seed=0):

    Args:
        shape (tuple): The shape of random tensor to be generated.
-        alpha (Tensor): The alpha α distribution parameter. It should be greater than 0. With float32 data type.
+        alpha (Tensor): The alpha α distribution parameter. It should be greater than 0 with float32 data type.
-        beta (Tensor): The beta β distribution parameter. It should be greater than 0. With float32 data type.
+        beta (Tensor): The beta β distribution parameter. It should be greater than 0 with float32 data type.
-        seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
-            Must be non-negative. Default: 0.
+        seed (int): Seed is used as entropy source for the random number engines to generate
+            pseudo-random numbers, must be non-negative. Default: 0.

    Returns:
-        Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of alpha and beta.
+        Tensor. The shape should be equal to the broadcasted shape between the input "shape" and shapes
+        of `alpha` and `beta`.
        The dtype is float32.

    Examples:
@@ -165,12 +168,12 @@ def poisson(shape, mean, seed=0):

    Args:
        shape (tuple): The shape of random tensor to be generated.
-        mean (Tensor): The mean μ distribution parameter. It should be greater than 0. With float32 data type.
+        mean (Tensor): The mean μ distribution parameter. It should be greater than 0 with float32 data type.
-        seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
-            Must be non-negative. Default: 0.
+        seed (int): Seed is used as entropy source for the random number engines to generate pseudo-random numbers
+            and must be non-negative. Default: 0.

    Returns:
-        Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of mean.
+        Tensor. The shape should be equal to the broadcasted shape between the input "shape" and shapes of `mean`.
        The dtype is float32.

    Examples:
@@ -188,21 +191,23 @@ def poisson(shape, mean, seed=0):
def multinomial(inputs, num_sample, replacement=True, seed=0):
    r"""
    Returns a tensor sampled from the multinomial probability distribution located in the corresponding
-    row of tensor input.
+    row of the input tensor.

    Note:
        The rows of input do not need to sum to one (in which case we use the values as weights),
        but must be non-negative, finite and have a non-zero sum.

    Args:
-        inputs (Tensor): the input tensor containing probabilities, must be 1 or 2 dims. With float32 data type.
-        num_sample (int): number of samples to draw.
-        replacement (bool, optional): whether to draw with replacement or not, default True.
-        seed (int, optional): used as entropy source for Random number engines generating pseudo-random numbers.
-            Must be non-negative. Default: 0.
+        inputs (Tensor): The input tensor containing probabilities, must be 1 or 2 dimensions, with
+            float32 data type.
+        num_sample (int): Number of samples to draw.
+        replacement (bool, optional): Whether to draw with replacement or not, default True.
+        seed (int, optional): Seed is used as entropy source for the random number engines to generate
+            pseudo-random numbers,
+            must be non-negative. Default: 0.

    Outputs:
-        Tensor. have the same rows with input, each row has num_samples sampled indices.
+        Tensor, has the same rows with input. The number of sampled indices of each row is `num_samples`.
        The dtype is float32.

    Examples:

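A companion sketch for the remaining distributions documented here, with the same assumed import path and illustrative values:

>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore.common import dtype as mstype
>>> from mindspore.ops import composite as C
>>>
>>> shape = (4, 2)
>>> alpha = Tensor(np.array([1.0]), mstype.float32)
>>> beta = Tensor(np.array([1.0]), mstype.float32)
>>> gamma_out = C.gamma(shape, alpha, beta, seed=5)
>>>
>>> rate = Tensor(np.array([5.0]), mstype.float32)
>>> poisson_out = C.poisson(shape, rate, seed=5)
>>>
>>> probs = Tensor(np.array([0.1, 0.3, 0.6]), mstype.float32)   # 1-D probabilities
>>> samples = C.multinomial(probs, num_sample=4, replacement=True, seed=5)
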
@@ -197,6 +197,9 @@ class _AutoParallelContext:
            parameter_broadcast (bool): Parameter broadcast or not.
        """
        self.check_context_handle()
+        if parameter_broadcast is True and context.get_context("enable_ge") is False:
+            raise RuntimeError("Parameter broadcast is a developing feature. For now we suggest to"
+                               " use mindspore.common.set_seed() to share parameters among devices.")
        self._context_handle.set_parameter_broadcast(parameter_broadcast)

    def get_parameter_broadcast(self):

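The new check points callers to `mindspore.common.set_seed()` instead of `parameter_broadcast`; a sketch of that recommended pattern (the seed value and parallel settings are illustrative):

>>> from mindspore import context
>>> from mindspore.common import set_seed
>>> from mindspore.communication.management import init, get_group_size
>>>
>>> init()
>>> set_seed(1)     # same seed on every device, so parameters initialize identically
>>> context.set_auto_parallel_context(parallel_mode="data_parallel",
...                                   device_num=get_group_size(),
...                                   gradients_mean=True)
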
@@ -58,7 +58,7 @@ if __name__ == '__main__':
        cfg.group_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size,
-                                          parameter_broadcast=True, gradients_mean=True)
+                                          gradients_mean=True)
    else:
        cfg.rank = 0
        cfg.group_size = 1

@@ -61,7 +61,7 @@ if __name__ == '__main__':
        cfg.group_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size,
-                                          parameter_broadcast=True, gradients_mean=True)
+                                          gradients_mean=True)
    else:
        cfg.rank = 0
        cfg.group_size = 1

@@ -136,8 +136,7 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl):
    os.environ['RANK_SIZE'] = str(device_num)
    if enable_hccl:
        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          gradients_mean=True, parameter_broadcast=True,
-                                          all_reduce_fusion_config=[107, 160])
+                                          gradients_mean=True, all_reduce_fusion_config=[107, 160])
        init()

    # network
@@ -239,8 +238,7 @@ def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl):
    os.environ['RANK_SIZE'] = str(device_num)
    if enable_hccl:
        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          gradients_mean=True, parameter_broadcast=True,
-                                          all_reduce_fusion_config=[107])
+                                          gradients_mean=True, all_reduce_fusion_config=[107])
        init()

    # network