diff --git a/mindspore/communication/management.py b/mindspore/communication/management.py index 5d13bb5e7b1..88e5ed104eb 100755 --- a/mindspore/communication/management.py +++ b/mindspore/communication/management.py @@ -32,33 +32,32 @@ DEFAULT_BACKEND = Backend("hccl") def _get_group(group): - """Get the global world group if the group is default world comm group.""" + """Return the world communication group if the `group` is `DEFAULT_WORLD_COMM_GROUP`.""" if group == DEFAULT_WORLD_COMM_GROUP: return GlobalComm.WORLD_COMM_GROUP return group class GlobalComm: - """Global communication info.""" + """World communication information.""" BACKEND = DEFAULT_BACKEND WORLD_COMM_GROUP = DEFAULT_WORLD_COMM_GROUP def init(backend_name=None): """ - Init distributed backend, e.g., hccl/nccl, it is required before communication service can be used. + Initialize distributed backend, e.g. HCCL/NCCL, it is required before using the communication service. Note: - The full name of hccl is Huawei Collective Communication Library. - The full name of nccl is NVIDIA Collective Communication Library. + The full name of HCCL is Huawei Collective Communication Library. + The full name of NCCL is NVIDIA Collective Communication Library. Args: backend_name (str): Backend. Raises: - TypeError: If backen_name is not a string. - RuntimeError: If device target is invalid. - RuntimeError: If backend is invalid or distributed init fails. + TypeError: If `backend_name` is not a string. + RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails. """ if _is_role_pserver() or _is_role_sched(): return @@ -88,17 +87,17 @@ def init(backend_name=None): def release(): """ - Release distributed resource. e.g., hccl/nccl. + Release distributed resource. e.g. HCCL/NCCL. Raises: - RuntimeError: If distributed resource release fails. + RuntimeError: If failed to release distributed resource. """ finalize_hccl() def get_rank(group=GlobalComm.WORLD_COMM_GROUP): """ - Gets rank ID for current device in specified collective communication group. + Get the rank ID for the current device in the specified collective communication group. Args: group (str): ProcessGroup, the process group to work on. Default: WORLD_COMM_GROUP. @@ -109,7 +108,7 @@ def get_rank(group=GlobalComm.WORLD_COMM_GROUP): Raises: TypeError: If group is not a string. ValueError: If backend is invalid. - RuntimeError: If hccl/nccl is not available or nccl not supports. + RuntimeError: If HCCL/NCCL is not available or NCCL is not supported. """ return _get_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND) @@ -130,14 +129,14 @@ def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP): Raises: TypeError: If group is not a string. ValueError: If backend is invalid. - RuntimeError: If hccl/nccl is not available or nccl not supports. + RuntimeError: If HCCL/NCCL is not available or NCCL is not supported. """ return _get_local_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND) def get_group_size(group=GlobalComm.WORLD_COMM_GROUP): """ - Gets rank size of the specified collective communication group. + Get the rank size of the specified collective communication group. Args: group (str): ProcessGroup, the process group to work on. @@ -148,7 +147,7 @@ def get_group_size(group=GlobalComm.WORLD_COMM_GROUP): Raises: TypeError: If group is not a string. ValueError: If backend is invalid. - RuntimeError: If hccl/nccl is not available or nccl not supports. + RuntimeError: If HCCL/NCCL is not available or NCCL is not supported. """ return _get_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND) @@ -164,22 +163,23 @@ def get_local_rank_size(group=GlobalComm.WORLD_COMM_GROUP): group (str): ProcessGroup, the process group to work on. Returns: - int, the local rank size where the calling process is being within the group. + int, the local rank size where the calling process is within the group. Raises: TypeError: If group is not a string. ValueError: If backend is invalid. - RuntimeError: If hccl/nccl is not available or nccl not supports. + RuntimeError: If HCCL/NCCL is not available or NCCL is not supported. """ return _get_local_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND) def get_world_rank_from_group_rank(group, group_rank_id): """ - Gets the rank ID in world communication group corresponding to the rank ID in specified user communication group. + Gets the rank ID in the world communication group corresponding to + the rank ID in the specified user communication group. Note: - Nccl is not supported. + NCCL is not supported. The parameter group should not be "hccl_world_group". Args: @@ -190,52 +190,53 @@ def get_world_rank_from_group_rank(group, group_rank_id): int, the rank ID in world communication group. Raises: - TypeError: If group_rank_id is not a int or group is not a string. + TypeError: If `group_rank_id` is not an integer or the group is not a string. ValueError: If group is 'hccl_world_group' or backend is invalid. - RuntimeError: If hccl/nccl is not available or nccl not supports. + RuntimeError: If HCCL/NCCL is not available or NCCL is not supported. """ return _get_world_rank_from_group_rank_helper(group=group, group_rank_id=group_rank_id, backend=GlobalComm.BACKEND) def get_group_rank_from_world_rank(world_rank_id, group): """ - Gets the rank ID in specified user communication group corresponding to the rank ID in world communication group. + Get the rank ID in the specified user communication group corresponding to + the rank ID in the world communication group. Note: - Nccl is not supported. + NCCL is not supported. The parameter group should not be "hccl_world_group". Args: - world_rank_id (int): A rank ID in world communication group. + world_rank_id (int): A rank ID in the world communication group. group (str): The user communication group. Returns: - int, the rank ID in user communication group. + int, the rank ID in the user communication group. Raises: - TypeError: If world_rank_id is not a int or group is not a string. + TypeError: If world_rank_id is not an integer or the group is not a string. ValueError: If group is 'hccl_world_group' or backend is invalid. - RuntimeError: If hccl/nccl is not available or nccl not supports. + RuntimeError: If HCCL/NCCL is not available or NCCL is not supported. """ return _get_group_rank_from_world_rank_helper(world_rank_id=world_rank_id, group=group, backend=GlobalComm.BACKEND) def create_group(group, rank_ids): """ - Creates user collective communication group. + Create a user collective communication group. Note: - Nccl is not supported. + NCCL is not supported. The size of rank_ids should be larger than 1. Rank_ids should not have duplicate data. Args: group (str): ProcessGroup, the process group to create. - rank_ids (list): List of device ID. + rank_ids (list): A list of device IDs. Raises: - TypeError: If group is not a string or rank_ids is not a list. - ValueError: If rank_ids size is not larger than 1 or rank_ids has duplicate data or backend is invalid. + TypeError: If group is not a string or `rank_ids` is not a list. + ValueError: If `rank_ids` size is not larger than 1, or `rank_ids` has duplicate data, or backend is invalid. RuntimeError: If hccl/nccl is not available or nccl not supports. Examples: >>> group = "0-1" @@ -247,7 +248,7 @@ def create_group(group, rank_ids): def destroy_group(group): """ - Destroys user collective communication group. + Destroy the user collective communication group. Note: Nccl is not supported. @@ -259,6 +260,6 @@ def destroy_group(group): Raises: TypeError: If group is not a string. ValueError: If group is "hccl_world_group" or backend is invalid. - RuntimeError: If hccl/nccl is not available or nccl not supports. + RuntimeError: If HCCL/NCCL is not available or NCCL is not supported. """ _destroy_group_helper(group, backend=GlobalComm.BACKEND) diff --git a/mindspore/context.py b/mindspore/context.py index d6784c9e395..5889b27b1d7 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -336,6 +336,8 @@ def set_auto_parallel_context(**kwargs): """ Set auto parallel context. + Auto parallel context should be configured before the initialization of your network. + Note: Attribute name is required for setting attributes. If a program has tasks with different parallel modes, then before setting new parallel mode for the @@ -344,12 +346,25 @@ def set_auto_parallel_context(**kwargs): Setting or changing parallel modes must be called before any creating Initializer, otherwise, RuntimeError may be raised when compiling the network. + Some configurations are parallel mode specific, see the below table for details: + + =========================== =========================== ================= + Common AUTO_PARALLEL DATA_PRALLEL + =========================== =========================== ================= + device_num gradient_fp32_sync enable_parallel_optimizer + global_rank loss_repeated_mean + gradients_mean auto_parallel_search_mode + parallel_mode strategy_ckpt_load_file + all_reduce_fusion_config strategy_ckpt_save_file + full_batch + =========================== =========================== ================= + Args: device_num (int): Available device number, the value must be in [1, 4096]. Default: 1. global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0. - gradients_mean (bool): Whether to perform mean operator after all-reduce of mirror. - "stand_alone" does not support `gradients_mean`. Default: False. - gradient_fp32_sync (bool): Gradients allreduce by fp32, even though gradients is fp16 if this flag is True.. + gradients_mean (bool): Whether to perform mean operator after allreduce of gradients. + "stand_alone" do not support gradients_mean. Default: False. + gradient_fp32_sync (bool): Run allreduce of gradients in fp32. "stand_alone", "data_parallel" and "hybrid_parallel" do not support gradient_fp32_sync. Default: True. parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel", @@ -364,8 +379,8 @@ def set_auto_parallel_context(**kwargs): - semi_auto_parallel: Achieves data parallelism and model parallelism by setting parallel strategies. - - auto_parallel: Achieves parallelism automatically. - auto_parallel_search_mode (str): There are two kinds of search modes, "recursive_programming" + - auto_parallel: Achieving parallelism automatically. + auto_parallel_search_mode (str): There are two kinds of shard strategy search modes, "recursive_programming" and "dynamic_programming". Default: "dynamic_programming". - recursive_programming: Recursive programming search mode. @@ -376,9 +391,11 @@ def set_auto_parallel_context(**kwargs): broadcast. Default: False. strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: '' strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: '' - full_batch (bool): Whether to load the whole batch on each device. Default: False. - enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation in - data parallel training in the benefit of time and memory saving. + full_batch (bool): If you load whole batch datasets in auto_parallel mode, this parameter + should be set with True. Default: False. + enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation for + data parallel training in the benefit of time and memory saving. For now, + `Lamb` and `AdamWeightDecay` are supported in data parallel mode. all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only support ReduceOp.SUM and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP. @@ -479,7 +496,7 @@ def set_context(**kwargs): Some configurations are device specific, see the bellow table for details: =========================== =========================== ================= - Common(CPU/GPU/Asecend) Ascend GPU + Common(CPU/GPU/Ascend) Ascend GPU =========================== =========================== ================= check_bprop enable_auto_mixed_precision max_device_memory device_id enable_dump diff --git a/mindspore/ops/composite/base.py b/mindspore/ops/composite/base.py index 0666174cdd5..d067fcc4ef4 100644 --- a/mindspore/ops/composite/base.py +++ b/mindspore/ops/composite/base.py @@ -33,7 +33,7 @@ __all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor def add_flags(fn=None, **flags): """ - An decorator to add flag for a function. + A decorator that adds a flag to the function. Note: Only supports bool value. @@ -43,7 +43,7 @@ def add_flags(fn=None, **flags): flags (dict): Flags use kwargs. Default: None. Returns: - Function, the fn added flags. + Function, the function with added flags. Examples: >>> add_flags(net, predit=True) @@ -63,9 +63,9 @@ def add_flags(fn=None, **flags): def core(fn=None, **flags): """ - A decorator to add flag to a function. + A decorator that adds a flag to the function. - By default, the function is marked core=True using this decorator to + By default, the function is marked as True, enabling to use this decorator to set flag to a graph. Args: @@ -91,11 +91,12 @@ def core(fn=None, **flags): class GradOperation(GradOperation_): """ - An higher-order function which is used to generate the gradient function for the input function. + A higher-order function which is used to generate the gradient function for the input function. - The gradient function generated by `GradOperation` higher-order function can be customized by construction args. + The gradient function generated by `GradOperation` higher-order function can be customized by + construction arguments. - Given an input function `net = Net()` that take `x` and `y` as inputs, and has a parameter `z`, + Given an input function `net = Net()` that takes `x` and `y` as inputs, and has a parameter `z`, see `Net` in Examples. To generate a gradient function that returns gradients with respect to the first input @@ -126,7 +127,7 @@ class GradOperation(GradOperation_): 1. Construct a `GradOperation` higher-order function with `get_by_list=True`: `grad_op = GradOperation(get_by_list=True)`. - 2. Construct a `ParameterTuple` that will be passed along input function when constructing + 2. Construct a `ParameterTuple` that will be passed to the input function when constructing `GradOperation` higher-order function, it will be used as a parameter filter that determine which gradient to return: `params = ParameterTuple(net.trainable_params())`. @@ -151,20 +152,20 @@ class GradOperation(GradOperation_): 4. Call the gradient function with input function's inputs to get the gradients with respect to all inputs and given parameters: `gradient_function(x, y)`. - We can configure the sensitiviy(gradient with respect to output) by setting `sens_param=True` and - passing in an extra sensitiviy input to the gradient function, the sensitiviy input should be - with same shape and type with input function's output(see `GradNetWrtXYWithSensParam` in Examples). + We can configure the sensitivity(gradient with respect to output) by setting `sens_param` as True and + passing an extra sensitivity input to the gradient function, the sensitivity input should has the + same shape and type with input function's output(see `GradNetWrtXYWithSensParam` in Examples). 1. Construct a `GradOperation` higher-order function with `get_all=True` and `sens_param=True`: `grad_op = GradOperation(get_all=True, sens_param=True)`. - 2. Define grad_wrt_output as sens_param which works as the gradient with respect to output: + 2. Define `grad_wrt_output` as `sens_param` which works as the gradient with respect to output: `grad_wrt_output = Tensor(np.ones([2, 2]).astype(np.float32))`. 3. Call it with input function as argument to get the gradient function: `gradient_function = grad_op(net)`. - 4. Call the gradient function with input function's inputs and sens_param to + 4. Call the gradient function with input function's inputs and `sens_param` to get the gradients with respect to all inputs: `gradient_function(x, y, grad_wrt_output)`. @@ -175,8 +176,9 @@ class GradOperation(GradOperation_): If get_all and get_by_list are both True, get the gradients with respect to inputs and Parameter variables at the same time in the form of ((gradients with respect to inputs), (gradients with respect to parameters)). Default: False. - sens_param (bool): Whether append sensitivity(gradient with respect to output) as input. If sens_param is False, - a 'ones_like(outputs)' sensitivity will be attached automatically. Default: False. + sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input. + If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically. + Default: False. Returns: The higher-order function which takes a function as argument and returns gradient function for it. @@ -349,9 +351,9 @@ class MultitypeFuncGraph(MultitypeFuncGraph_): """ Generate overloaded functions. - MultitypeFuncGraph is a class used to generate overloaded functions with different type as inputs. + MultitypeFuncGraph is a class used to generate overloaded functions, considering different types as inputs. Initialize an `MultitypeFuncGraph` object with name, and use `register` with input types as the decorator - for the function to be registed. And the object can be called with different type of inputs, + for the function to be registed. And the object can be called with different types of inputs, and work with `HyperMap` and `Map`. Args: @@ -360,7 +362,7 @@ class MultitypeFuncGraph(MultitypeFuncGraph_): and all inputs will pass by value, set `read_value` to True. Default: False. Raises: - ValueError: Cannot find matching functions for the given args. + ValueError: If failed to find find a matching function for the given arguments. Examples: >>> # `add` is a metagraph object which will add two objects according to @@ -431,7 +433,7 @@ class MultitypeFuncGraph(MultitypeFuncGraph_): class HyperMap(HyperMap_): """ - Hypermap will apply the set operation on input sequences. + Hypermap will apply the set operation to input sequences. Apply the operations to every elements of the sequence or nested sequence. Different from `Map`, the `HyperMap` supports to apply on nested structure. @@ -441,11 +443,10 @@ class HyperMap(HyperMap_): the operations should be put in the first input of the instance. Inputs: - - **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be the same length sequences, - and each row of the sequences. e.g. If args length is 2, and for `i` in length of each sequence - `(args[0][i], args[1][i])` will be the input of the operation. + - **args** (Tuple[sequence]) - If `ops` is `None`, all the inputs should be sequences with the same length. + And each row of the sequences will be the inputs of the operation. - If `ops` is not `None`, the first input is the operation, and the other is inputs. + If `ops` is not `None`, the first input is the operation, and the others are inputs. Outputs: Sequence or nested sequence, the sequence of output after applying the function. diff --git a/mindspore/ops/composite/random_ops.py b/mindspore/ops/composite/random_ops.py index ea93bb100f8..f2ed85c20eb 100644 --- a/mindspore/ops/composite/random_ops.py +++ b/mindspore/ops/composite/random_ops.py @@ -48,14 +48,15 @@ def normal(shape, mean, stddev, seed=0): Args: shape (tuple): The shape of random tensor to be generated. mean (Tensor): The mean μ distribution parameter, which specifies the location of the peak. - With float32 data type. + with float32 data type. stddev (Tensor): The deviation σ distribution parameter. It should be greater than 0. - With float32 data type. - seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers. - Must be non-negative. Default: 0. + with float32 data type. + seed (int): Seed is used as entropy source for the Random number engines to generate pseudo-random numbers. + must be non-negative. Default: 0. Returns: - Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of mean and stddev. + Tensor. The shape should be equal to the broadcasted shape between the input `shape` and shapes + of `mean` and `stddev`. The dtype is float32. Examples: @@ -123,20 +124,21 @@ def uniform(shape, minval, maxval, seed=0, dtype=mstype.float32): Args: shape (tuple): The shape of random tensor to be generated. - minval (Tensor): The a distribution parameter. - It defines the minimum possibly generated value. With int32 or float32 data type. + minval (Tensor): The distribution parameter `a`. + It defines the minimum possible generated value, with int32 or float32 data type. If dtype is int32, only one number is allowed. - maxval (Tensor): The b distribution parameter. - It defines the maximum possibly generated value. With int32 or float32 data type. + maxval (Tensor): The distribution parameter `b`. + It defines the maximum possible generated value, with int32 or float32 data type. If dtype is int32, only one number is allowed. - seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers. - Must be non-negative. Default: 0. + seed (int): Seed is used as entropy source for the random number engines to generate pseudo-random numbers, + must be non-negative. Default: 0. dtype (mindspore.dtype): type of the Uniform distribution. If it is int32, it generates numbers from discrete uniform distribution; if it is float32, it generates numbers from continuous uniform distribution. It only supports these two data types. Default: mstype.float32. Returns: - Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of minval and maxval. + Tensor. The shape should be equal to the broadcasted shape between the input `shape` and shapes + of `minval` and `maxval`. The dtype is designated as the input `dtype`. Examples: @@ -175,13 +177,14 @@ def gamma(shape, alpha, beta, seed=0): Args: shape (tuple): The shape of random tensor to be generated. - alpha (Tensor): The alpha α distribution parameter. It should be greater than 0. With float32 data type. - beta (Tensor): The beta β distribution parameter. It should be greater than 0. With float32 data type. - seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers. - Must be non-negative. Default: 0. + alpha (Tensor): The alpha α distribution parameter. It should be greater than 0 with float32 data type. + beta (Tensor): The beta β distribution parameter. It should be greater than 0 with float32 data type. + seed (int): Seed is used as entropy source for the random number engines to generate + pseudo-random numbers, must be non-negative. Default: 0. Returns: - Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of alpha and beta. + Tensor. The shape should be equal to the broadcasted shape between the input "shape" and shapes + of `alpha` and `beta`. The dtype is float32. Examples: @@ -203,12 +206,12 @@ def poisson(shape, mean, seed=0): Args: shape (tuple): The shape of random tensor to be generated. - mean (Tensor): The mean μ distribution parameter. It should be greater than 0. With float32 data type. - seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers. - Must be non-negative. Default: 0. + mean (Tensor): The mean μ distribution parameter. It should be greater than 0 with float32 data type. + seed (int): Seed is used as entropy source for the random number engines to generate pseudo-random numbers + and must be non-negative. Default: 0. Returns: - Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of mean. + Tensor. The shape should be equal to the broadcasted shape between the input "shape" and shapes of `mean`. The dtype is float32. Examples: @@ -226,21 +229,23 @@ def poisson(shape, mean, seed=0): def multinomial(inputs, num_sample, replacement=True, seed=0): r""" Returns a tensor sampled from the multinomial probability distribution located in the corresponding - row of tensor input. + row of the input tensor. Note: The rows of input do not need to sum to one (in which case we use the values as weights), but must be non-negative, finite and have a non-zero sum. Args: - inputs (Tensor): the input tensor containing probabilities, must be 1 or 2 dims. With float32 data type. - num_sample (int): number of samples to draw. - replacement (bool, optional): whether to draw with replacement or not, default True. - seed (int, optional): used as entropy source for Random number engines generating pseudo-random numbers. - Must be non-negative. Default: 0. + inputs (Tensor): The input tensor containing probabilities, must be 1 or 2 dimensions, with + float32 data type. + num_sample (int): Number of samples to draw. + replacement (bool, optional): Whether to draw with replacement or not, default True. + seed (int, optional): Seed is used as entropy source for the random number engines to generate + pseudo-random numbers, + must be non-negative. Default: 0. Outputs: - Tensor. have the same rows with input, each row has num_samples sampled indices. + Tensor, has the same rows with input. The number of sampled indices of each row is `num_samples`. The dtype is float32. Examples: diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py index 0e543eb54a7..aed133ee26f 100644 --- a/mindspore/parallel/_auto_parallel_context.py +++ b/mindspore/parallel/_auto_parallel_context.py @@ -197,6 +197,9 @@ class _AutoParallelContext: parameter_broadcast (bool): Parameter broadcast or not. """ self.check_context_handle() + if parameter_broadcast is True and context.get_context("enable_ge") is False: + raise RuntimeError("Parameter broadcast is a developing feature. For now we suggest to" + " use mindspore.common.set_seed() to share parameters among devices.") self._context_handle.set_parameter_broadcast(parameter_broadcast) def get_parameter_broadcast(self): diff --git a/model_zoo/official/cv/nasnet/train.py b/model_zoo/official/cv/nasnet/train.py index f962ea63b93..b343e3880cc 100755 --- a/model_zoo/official/cv/nasnet/train.py +++ b/model_zoo/official/cv/nasnet/train.py @@ -58,7 +58,7 @@ if __name__ == '__main__': cfg.group_size = get_group_size() parallel_mode = ParallelMode.DATA_PARALLEL context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size, - parameter_broadcast=True, gradients_mean=True) + gradients_mean=True) else: cfg.rank = 0 cfg.group_size = 1 diff --git a/model_zoo/official/cv/shufflenetv2/train.py b/model_zoo/official/cv/shufflenetv2/train.py index dca1cb132de..13ff3c39d11 100644 --- a/model_zoo/official/cv/shufflenetv2/train.py +++ b/model_zoo/official/cv/shufflenetv2/train.py @@ -61,7 +61,7 @@ if __name__ == '__main__': cfg.group_size = get_group_size() parallel_mode = ParallelMode.DATA_PARALLEL context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size, - parameter_broadcast=True, gradients_mean=True) + gradients_mean=True) else: cfg.rank = 0 cfg.group_size = 1 diff --git a/tests/st/networks/models/resnet50/test_resnet50_imagenet.py b/tests/st/networks/models/resnet50/test_resnet50_imagenet.py index 78ee7fdaf24..c1b682a31b4 100644 --- a/tests/st/networks/models/resnet50/test_resnet50_imagenet.py +++ b/tests/st/networks/models/resnet50/test_resnet50_imagenet.py @@ -136,8 +136,7 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl): os.environ['RANK_SIZE'] = str(device_num) if enable_hccl: context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, - gradients_mean=True, parameter_broadcast=True, - all_reduce_fusion_config=[107, 160]) + gradients_mean=True, all_reduce_fusion_config=[107, 160]) init() # network @@ -239,8 +238,7 @@ def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl): os.environ['RANK_SIZE'] = str(device_num) if enable_hccl: context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, - gradients_mean=True, parameter_broadcast=True, - all_reduce_fusion_config=[107]) + gradients_mean=True, all_reduce_fusion_config=[107]) init() # network