forked from mindspore-Ecosystem/mindspore
commit e902c7c731

@@ -303,6 +303,7 @@ class Router(Cell):
parallel_config=None):
super(Router, self).__init__()
if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
dp = parallel_config.data_parallel
self.d_model = d_model
self.expert_dim = moe_config.expert_num
self.capacity_factor = moe_config.capacity_factor
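
Note (editor's sketch, not part of the commit): capacity_factor bounds how many tokens each expert may accept during routing. A common formulation, assumed here purely for illustration, derives the per-expert token budget as follows; the helper name is hypothetical.

    import math

    def expert_capacity(tokens_per_group, expert_num, capacity_factor):
        # Each expert gets at most its even share of the tokens, scaled by capacity_factor.
        return math.ceil(tokens_per_group / expert_num * capacity_factor)

    # Example: 512 tokens, 4 experts, capacity_factor 1.1 -> 141 tokens per expert.
    print(expert_capacity(512, 4, 1.1))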

@@ -314,6 +315,7 @@ class Router(Cell):
self.noise = Tensor(np.random.uniform(1 - self.noisy_epsilon, 1 + self.noisy_epsilon, (d_model,)))

self.dense = Dense(in_channels=self.d_model, out_channels=self.expert_dim, has_bias=False)
self.dense.matmul.shard(((dp, 1), (1, 1)))
self.mul = P.Mul()
self.cast = P.Cast()

@@ -430,6 +432,8 @@ class TopkRouter(Cell):
self.reduce_sum_keep2 = P.ReduceSum(keep_dims=True)
self.expand = P.ExpandDims()
self.expand2 = P.ExpandDims()
self.add_scala = P.Add()
self.init_loss = Tensor(0.0, mstype.float32)
else:
dp = parallel_config.data_parallel
self.d_model = d_model

@@ -479,6 +483,8 @@ class TopkRouter(Cell):
self.reduce_sum_keep2 = P.ReduceSum(keep_dims=True).shard(((dp, 1, 1, 1),))
self.expand = P.ExpandDims().shard(((dp, 1),))
self.expand2 = P.ExpandDims().shard(((dp, 1, 1),))
self.add_scala = P.Add().shard(((), ()))
self.init_loss = Tensor(0.0, mstype.float32)

def _auxiliary_loss(self, expert_mask, router_prob):
"""

@@ -525,7 +531,7 @@ class TopkRouter(Cell):
accum_expert_mask = 0
accum_expert_gate = 0
loss = 0
loss = self.init_loss
mask_count = 0
accum_combine_tensor = 0
# Probabilities for each token of which expert it should be sent to

@@ -542,7 +548,7 @@ class TopkRouter(Cell):
router_prob_normal = self.div1(router_prob, self.add1(self.reduce_sum_keep(router_prob, -1), 1e-9))

# the balance loss is computed at each routing step
loss += self._auxiliary_loss(expert_mask, router_prob_normal)
loss = self.add_scala(loss, self._auxiliary_loss(expert_mask, router_prob_normal))

output = self._maskout_overflowed_tokens(expert_mask, expert_capacity, expert_gate,
mask_count, expert_chosen_index)
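
Note (editor's sketch, not part of the commit): the balance loss accumulated above is, in the usual MoE formulation, the product of the fraction of tokens dispatched to each expert and the mean router probability for that expert, summed over experts and scaled by the expert count. A minimal NumPy version, with illustrative names only:

    import numpy as np

    def auxiliary_loss(expert_mask, router_prob, expert_num):
        # expert_mask: (tokens, expert_num) one-hot dispatch mask
        # router_prob: (tokens, expert_num) softmax probabilities from the router
        density = expert_mask.mean(axis=0)         # fraction of tokens routed to each expert
        density_proxy = router_prob.mean(axis=0)   # mean routing probability per expert
        return expert_num * float(np.sum(density * density_proxy))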

@@ -22,6 +22,7 @@ from mindspore.context import set_auto_parallel_context, ParallelMode
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.parallel.nn import Transformer, TransformerOpParallelConfig, MoEConfig, CrossEntropyLoss
from mindspore.parallel import set_algo_parameters
from mindspore.nn.optim import AdamWeightDecay
from mindspore.nn.wrap.cell_wrapper import TrainOneStepCell, _VirtualDatasetCell
from mindspore.train import Model

@@ -80,15 +81,7 @@ class NetWithLossMoe(nn.Cell):
return self.add(predict, moe_loss)

def test_transformer_model():
"""
Feature: Test Transformer+MoE, with All2All enabled.
Description: 3-dim input.
Expectation: Successful graph compilation with All2All included.
"""
set_auto_parallel_context(device_num=16, global_rank=0,
full_batch=True, enable_alltoall=True,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
def run_transformer_model():
net = Transformer(encoder_layers=1,
decoder_layers=1,
batch_size=2,

@@ -116,15 +109,32 @@ def test_transformer_model():
model.train(1, dataset, dataset_sink_mode=False)

def test_transformer_model_2d():
def test_transformer_model_semi():
"""
Feature: Test Transformer+MoE, with All2All enabled.
Description: 2-dim input.
Description: 3-dim input.
Expectation: Successful graph compilation with All2All included.
"""
set_auto_parallel_context(device_num=16, global_rank=0,
full_batch=True, enable_alltoall=True,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
run_transformer_model()

def test_transformer_model_sp():
"""
Feature: Test Transformer+MoE, with All2All enabled and sharding propagation.
Description: 3-dim input.
Expectation: Successful graph compilation with All2All included.
"""
set_auto_parallel_context(device_num=16, global_rank=0, search_mode="sharding_propagation",
full_batch=True, enable_alltoall=True,
parallel_mode=ParallelMode.AUTO_PARALLEL)
set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)
run_transformer_model()

def run_transformer_model_2d():
net = Transformer(encoder_layers=1,
decoder_layers=1,
batch_size=2,

@@ -153,6 +163,31 @@ def test_transformer_model_2d():
model.train(1, dataset, dataset_sink_mode=False)

def test_transformer_model_2d_semi():
"""
Feature: Test Transformer+MoE, with All2All enabled.
Description: 2-dim input.
Expectation: Successful graph compilation with All2All included.
"""
set_auto_parallel_context(device_num=16, global_rank=0,
full_batch=True, enable_alltoall=True,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
run_transformer_model_2d()

def test_transformer_model_2d_sp():
"""
Feature: Test Transformer+MoE, with All2All enabled and sharding propagation.
Description: 2-dim input.
Expectation: Successful graph compilation with All2All included.
"""
set_auto_parallel_context(device_num=16, global_rank=0, search_mode="sharding_propagation",
full_batch=True, enable_alltoall=True,
parallel_mode=ParallelMode.AUTO_PARALLEL)
set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)
run_transformer_model_2d()

class TransformerNet(nn.Cell):
"""Transformer with loss"""
def __init__(self, en_layer, de_layer, parallel_config):
@@ -176,9 +211,6 @@ class TransformerNet(nn.Cell):

def moe_with_loss_plus_mutiparallel(local_parallel_config):
set_auto_parallel_context(device_num=16, enable_alltoall=True,
full_batch=True, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)

encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)

@@ -205,6 +237,8 @@ def test_moe_expert_parallel1():
Description: 3-dim input.
Expectation: Successful graph compilation with All2All included.
"""
set_auto_parallel_context(device_num=16, enable_alltoall=True,
full_batch=True, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
local_p_config = TransformerOpParallelConfig(data_parallel=2, model_parallel=4, expert_parallel=2)
moe_with_loss_plus_mutiparallel(local_p_config)

@@ -215,21 +249,52 @@ def test_moe_expert_parallel2():
Description: 3-dim input.
Expectation: Successful graph compilation with All2All included.
"""
set_auto_parallel_context(device_num=16, enable_alltoall=True,
full_batch=True, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
local_p_config = TransformerOpParallelConfig(data_parallel=2, model_parallel=8, expert_parallel=1)
moe_with_loss_plus_mutiparallel(local_p_config)

def test_moe_expert_parallel3():
"""
Feature: Test Transformer+MoE for data_parallel plus expert_parallel, with All2All enabled
and sharding propagation.
Description: 3-dim input.
Expectation: Successful graph compilation with All2All included.
"""
set_auto_parallel_context(device_num=16, enable_alltoall=True, search_mode="sharding_propagation",
full_batch=True, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)
local_p_config = TransformerOpParallelConfig(data_parallel=2, model_parallel=4, expert_parallel=2)
moe_with_loss_plus_mutiparallel(local_p_config)

def test_moe_expert_parallel4():
"""
Feature: Test Transformer+MoE for data_parallel plus expert_parallel, with All2All enabled
and sharding propagation.
Description: 3-dim input.
Expectation: Successful graph compilation with All2All included.
"""
set_auto_parallel_context(device_num=16, enable_alltoall=True, search_mode="sharding_propagation",
full_batch=True, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)
local_p_config = TransformerOpParallelConfig(data_parallel=2, model_parallel=8, expert_parallel=1)
moe_with_loss_plus_mutiparallel(local_p_config)

def test_moe_expert_parallel_exception1():
"""
Feature: Test Transformer+MoE for data_parallel plus expert_parallel, with All2All enabled.
Description: 3-dim input.
Expectation: Successful graph compilation.
Expectation: Raise ValueError.
"""
local_p_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, expert_parallel=2)
with pytest.raises(ValueError):
moe_with_loss_plus_mutiparallel(local_p_config)

def test_moe_expert_parallel_exception():

def test_moe_expert_parallel_exception2():
"""
Feature: Test Transformer+MoE for data_parallel plus expert_parallel, with All2All enabled.
Description: data_parallel*model_parallel*expert_parallel > device_num
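
Note (editor's sketch, not part of the commit): the exception tests above exercise the parallel-configuration validation; the docstring of the second one names the check that data_parallel * model_parallel * expert_parallel must not exceed device_num. A stand-alone version of just that arithmetic, with a hypothetical helper name:

    def check_parallel_config(data_parallel, model_parallel, expert_parallel, device_num):
        # Reject configurations whose combined parallel degree exceeds the device count.
        product = data_parallel * model_parallel * expert_parallel
        if product > device_num:
            raise ValueError(f"data_parallel * model_parallel * expert_parallel = {product} "
                             f"exceeds device_num = {device_num}")

    try:
        # 2 * 8 * 2 = 32 > 16 devices, so this configuration is rejected.
        check_parallel_config(data_parallel=2, model_parallel=8, expert_parallel=2, device_num=16)
    except ValueError as err:
        print(err)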

@@ -166,10 +166,7 @@ def run_total_transformer_model_head(e_layer,
model.train(1, dataset, dataset_sink_mode=False)

def test_transformer_model():
set_auto_parallel_context(device_num=8, global_rank=0,
full_batch=True,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
def run_transformer_model():
net = Transformer(encoder_layers=1,
decoder_layers=2,
batch_size=2,

@@ -197,10 +194,32 @@ def test_transformer_model():
model.train(1, dataset, dataset_sink_mode=False)

def test_transformer_model_2d_inputs():
def test_transformer_model_semi():
"""
Feature: Test Transformer.
Description: 3-dim input.
Expectation: Successful graph compilation.
"""
set_auto_parallel_context(device_num=8, global_rank=0,
full_batch=True,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
run_transformer_model()

def test_transformer_model_sp():
"""
Feature: Test Transformer with sharding propagation.
Description: 3-dim input.
Expectation: Successful graph compilation.
"""
set_auto_parallel_context(device_num=8, global_rank=0,
full_batch=True, search_mode="sharding_propagation",
parallel_mode=ParallelMode.AUTO_PARALLEL)
set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)
run_transformer_model()

def run_transformer_model_2d_inputs():
net = Transformer(encoder_layers=1,
decoder_layers=2,
batch_size=2,

@@ -228,6 +247,31 @@ def test_transformer_model_2d_inputs():
model.train(1, dataset, dataset_sink_mode=False)

def test_transformer_model_2d_semi():
"""
Feature: Test Transformer.
Description: 2-dim input.
Expectation: Successful graph compilation.
"""
set_auto_parallel_context(device_num=8, global_rank=0,
full_batch=True,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
run_transformer_model_2d_inputs()

def test_transformer_model_2d_sp():
"""
Feature: Test Transformer with sharding propagation.
Description: 2-dim input.
Expectation: Successful graph compilation.
"""
set_auto_parallel_context(device_num=8, global_rank=0,
full_batch=True, search_mode="sharding_propagation",
parallel_mode=ParallelMode.AUTO_PARALLEL)
set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)
run_transformer_model_2d_inputs()

class TestTransformerEmbeddingHead:
def __init__(self):
self.output_path = None

@@ -255,19 +299,12 @@ class TestTransformerEmbeddingHead:
appear_count += 1
assert appear_count == target_count

def test_pipeline_with_embedding(self):
"""
Feature: Test Transformer with a shared embedding.
Description: When doing pipeline training with optimizer shard applied, the model-parallel embedding will
raise a shape error. This test case ensures that no error is raised.
Expectation: The number of AssignAdd operators matches the expected count.
"""
def run_pipeline_with_embedding(self):
bs = 16
pp = 2
context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=pp,
full_batch=True,
enable_parallel_optimizer=True)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
cf = TransformerOpParallelConfig(data_parallel=1, model_parallel=4, pipeline_stage=pp, vocab_emb_dp=False)
pipeline_net = TransformerEncoderNet(batch_size=bs // pp,
en_layer=2, de_layer=0, parallel_config=cf)

@@ -284,11 +321,29 @@ class TestTransformerEmbeddingHead:
run_network_function(dataset, pipeline_cell_net)
self.virtual_assign_add_from_ir(pattern=r'AssignAdd(', target_count=35)
def test_pipeline_with_embedding_semi(self):
"""
Feature: Test Transformer with a shared embedding.
Description: When doing pipeline training with optimizer shard applied, the model-parallel embedding will
raise a shape error. This test case ensures that no error is raised.
Expectation: The number of AssignAdd operators matches the expected count.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
self.run_pipeline_with_embedding()

def test_transformer_model_int64_inputs():
set_auto_parallel_context(device_num=8, global_rank=0,
full_batch=True,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
def test_pipeline_with_embedding_sp(self):
"""
Feature: Test Transformer with a shared embedding, using sharding propagation.
Description: When doing pipeline training with optimizer shard applied, the model-parallel embedding will
raise a shape error. This test case ensures that no error is raised.
Expectation: The number of AssignAdd operators matches the expected count.
"""
context.set_auto_parallel_context(parallel_mode="auto_parallel", search_mode="sharding_propagation")
set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)
self.run_pipeline_with_embedding()

def run_transformer_model_int64_inputs():
net = Transformer(encoder_layers=1,
decoder_layers=2,
batch_size=2,

@@ -317,16 +372,69 @@ def test_transformer_model_int64_inputs():
model.train(1, dataset, dataset_sink_mode=False)

def test_transformer_model_int64_semi():
"""
Feature: Test Transformer.
Description: int64 input.
Expectation: Successful graph compilation.
"""
set_auto_parallel_context(device_num=8, global_rank=0,
full_batch=True,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
run_transformer_model_int64_inputs()

def test_transformer_model_int64_sp():
"""
Feature: Test Transformer with sharding propagation.
Description: int64 input.
Expectation: Successful graph compilation.
"""
set_auto_parallel_context(device_num=8, global_rank=0,
full_batch=True, search_mode="sharding_propagation",
parallel_mode=ParallelMode.AUTO_PARALLEL)
set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)
run_transformer_model_int64_inputs()

def test_transformer_model_head_parallel_only_encoder():
local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
run_total_transformer_model_head(e_layer=2, d_layer=0, arg_parallel_config=local_config)

def test_transformer_model_head_parallel_only_encoder_sp():
"""
Feature: Test Transformer with sharding propagation.
Description: encoder only.
Expectation: Successful graph compilation.
"""
local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
set_auto_parallel_context(search_mode="sharding_propagation",
parallel_mode=ParallelMode.AUTO_PARALLEL)
set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)
run_total_transformer_model_head(e_layer=2, d_layer=0, arg_parallel_config=local_config,
mode=ParallelMode.AUTO_PARALLEL)

def test_transformer_model_head_parallel():
local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
run_total_transformer_model_head(e_layer=1, d_layer=1, arg_parallel_config=local_config)

def test_transformer_model_head_parallel_sp():
"""
Feature: Test Transformer with sharding propagation.
Description: 1 encoder and 1 decoder.
Expectation: Successful graph compilation.
"""
local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
set_auto_parallel_context(search_mode="sharding_propagation",
parallel_mode=ParallelMode.AUTO_PARALLEL)
set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)
run_total_transformer_model_head(e_layer=1, d_layer=1, arg_parallel_config=local_config,
mode=ParallelMode.AUTO_PARALLEL)

def test_transformer_model_head_parallel_decoder():
local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
with pytest.raises(ValueError):

@@ -353,8 +461,7 @@ def pipeline_single_transformer(grad_accumulation_shard=False):
"""
set_auto_parallel_context(device_num=64,
full_batch=True,
pipeline_stages=pipeline_config.pipeline_stage, global_rank=0,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
pipeline_stages=pipeline_config.pipeline_stage, global_rank=0)
context.set_auto_parallel_context(parallel_optimizer_config=
{"gradient_accumulation_shard": grad_accumulation_shard})

@@ -394,6 +501,19 @@ def test_pipeline_transformer_gradient_shard_true():
Description: Test a single transformer model with pipeline parallel with grad_accumulation_shard True
Expectation: The compile passed
"""
set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
pipeline_single_transformer(grad_accumulation_shard=True)

def test_pipeline_transformer_gradient_shard_true_sp():
"""
Feature: Gradient Accumulation Shard for Pipeline and Gradient Accumulation with sharding propagation
Description: Test a single transformer model with pipeline parallel with grad_accumulation_shard True
Expectation: The compile passed
"""
set_auto_parallel_context(search_mode="sharding_propagation",
parallel_mode=ParallelMode.AUTO_PARALLEL)
set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)
pipeline_single_transformer(grad_accumulation_shard=True)

@@ -403,6 +523,19 @@ def test_pipeline_transformer_gradient_shard_false():
Description: Test a single transformer model with pipeline parallel with grad_accumulation_shard False
Expectation: The compile passed
"""
set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
pipeline_single_transformer(grad_accumulation_shard=False)

def test_pipeline_transformer_gradient_shard_false_sp():
"""
Feature: Gradient Accumulation Shard for Pipeline and Gradient Accumulation with sharding propagation
Description: Test a single transformer model with pipeline parallel with grad_accumulation_shard False
Expectation: The compile passed
"""
set_auto_parallel_context(search_mode="sharding_propagation",
parallel_mode=ParallelMode.AUTO_PARALLEL)
set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)
pipeline_single_transformer(grad_accumulation_shard=False)