!32529 Fix optimizer limit for the Parallel Optimizer

Merge pull request !32529 from huangxinjing/fix_model_limit
i-robot 2022-04-06 03:39:36 +00:00 committed by Gitee
commit e09a674bcc
6 changed files with 122 additions and 8 deletions

View File

@ -129,6 +129,18 @@ CNodePtr FindNodeWithMircoSize(const AnfNodePtr &node_user, const NodeUsersMap &
return nullptr;
}
bool IsSourceUsedByMirror(const CNodePtr &node, const NodeUsersMap &node_user_map) {
if (node->inputs().size() < 2) return false;
auto parameter_node = node->input(1);
if (parameter_node->cast<ParameterPtr>()) {
for (auto &item : node_user_map.at(parameter_node)) {
if (IsPrimitiveCNode(item.first, prim::kPrimMirrorMicroStep)) {
return true;
}
}
}
return false;
}
void InsertVirtualAssignAdd(const std::pair<AnfNodePtr, int> &node_user, const FuncGraphManagerPtr &manager,
const AnfNodePtr &accu_parameter, const NodeUsersMap &node_user_map) {
auto cnode = node_user.first->cast<CNodePtr>();
@ -138,7 +150,8 @@ void InsertVirtualAssignAdd(const std::pair<AnfNodePtr, int> &node_user, const F
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
bool enable_parallel_optimizer = ParallelContext::GetInstance()->enable_parallel_optimizer();
bool grad_accumulation_shard = ParallelContext::GetInstance()->grad_accumulation_shard();
- if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && enable_parallel_optimizer) {
+ if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && enable_parallel_optimizer &&
+     IsSourceUsedByMirror(cnode, node_user_map)) {
return;
}
auto prim = GetCNodePrimitive(cnode);
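For context, the two ParallelContext flags read at the top of this hunk, enable_parallel_optimizer and grad_accumulation_shard, are configured from the Python side through the auto-parallel context. A minimal sketch, assuming the parallel_optimizer_config option of this MindSpore release; the concrete values are illustrative only:

from mindspore import context

# Enables the parallel optimizer pass that InsertVirtualAssignAdd belongs to
# (corresponds to ParallelContext::enable_parallel_optimizer above).
context.set_auto_parallel_context(
    parallel_mode="semi_auto_parallel",
    enable_parallel_optimizer=True,
    # Corresponds to ParallelContext::grad_accumulation_shard above; assumed
    # to be exposed via parallel_optimizer_config in this release.
    parallel_optimizer_config={"gradient_accumulation_shard": True})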

View File

@ -266,6 +266,7 @@ class AdaFactor(Optimizer):
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = Model(net, loss_fn=loss, optimizer=optim)
"""
_support_parallel_optimizer = True
@opt_init_args_register
def __init__(self,

View File

@ -563,6 +563,8 @@ class AdamWeightDecay(Optimizer):
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = Model(net, loss_fn=loss, optimizer=optim)
"""
_support_parallel_optimizer = True
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
super(AdamWeightDecay, self).__init__(learning_rate, params, weight_decay)
_check_param_value(beta1, beta2, eps, self.cls_name)

View File

@ -336,6 +336,7 @@ class Lamb(Optimizer):
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = Model(net, loss_fn=loss, optimizer=optim)
"""
_support_parallel_optimizer = True
@opt_init_args_register
def __init__(self, params, learning_rate, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):

View File

@ -135,10 +135,13 @@ class Optimizer(Cell):
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
_support_parallel_optimizer = False
def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0):
super(Optimizer, self).__init__(auto_prefix=False)
parameters = self._parameters_base_check(parameters, "parameters")
self.param_rank = None
self.optim_filter = None
if not all(isinstance(x, Parameter) for x in parameters) and not all(isinstance(x, dict) for x in parameters):
raise TypeError("For 'Optimizer', all elements of the argument 'parameters' must be 'Parameter' or 'dict',"
" please check the 'parameters'.")
@ -237,9 +240,9 @@ class Optimizer(Cell):
else:
self.use_parallel = False
if self.use_parallel:
- if self.cls_name not in ["Lamb", "AdamWeightDecay", "AdaFactor"]:
-     raise RuntimeError("For 'Optimizer', parallel optimizer only support optimizer 'Lamb' and "
-                        "'AdamWeightDecay' and 'AdaFactor', but got {}.".format(self.cls_name))
+ if not self._support_parallel_optimizer:
+     raise RuntimeError("For 'Optimizer', parallel optimizer shard does not support "
+                        "optimizer {}.".format(self.cls_name))
self.dev_num = _get_device_num()
if self.dev_num > self.param_length:
raise RuntimeError("Parallel optimizer can not be applied when the number of parameters {} is"

View File

@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import glob
import numpy as np
import pytest
@ -24,7 +27,6 @@ from mindspore.ops import functional as F
import mindspore.ops as P
from mindspore.parallel.nn import TransformerEncoder, TransformerDecoder, Transformer, TransformerOpParallelConfig, \
VocabEmbedding, CrossEntropyLoss, OpParallelConfig, EmbeddingOpParallelConfig, FixedSparseAttention
from mindspore.nn import Dense as Linear
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn.optim import AdamWeightDecay
from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell, TrainOneStepCell
@ -60,8 +62,6 @@ class Dataset(MindData):
class TransformerNet(nn.Cell):
def __init__(self, en_layer, de_layer, parallel_config):
super(TransformerNet, self).__init__()
self.embedding = VocabEmbedding(vocab_size=240, embedding_size=20,
parallel_config=config.embedding_dp_mp_config)
self.network = Transformer(encoder_layers=en_layer,
decoder_layers=de_layer,
batch_size=2,
@ -71,7 +71,6 @@ class TransformerNet(nn.Cell):
num_heads=8,
ffn_hidden_size=64,
parallel_config=parallel_config)
self.head = Linear(in_channels=64, out_channels=200)
self.loss = CrossEntropyLoss(parallel_config=config.dp_mp_config)
def construct(self, x1, x2, x3, x4, x5, y, mask):
@ -79,6 +78,32 @@ class TransformerNet(nn.Cell):
predict = P.Reshape()(predict, (-1, F.shape(predict)[-1]))
return self.loss(predict, y, mask)
class TransformerEncoderNet(nn.Cell):
def __init__(self, batch_size, en_layer, de_layer, parallel_config):
super(TransformerEncoderNet, self).__init__()
self.embedding = VocabEmbedding(vocab_size=240, embedding_size=64,
parallel_config=parallel_config.embedding_dp_mp_config)
self.network = Transformer(encoder_layers=en_layer,
decoder_layers=de_layer,
batch_size=batch_size,
src_seq_length=20,
tgt_seq_length=10,
hidden_size=64,
num_heads=8,
ffn_hidden_size=64,
parallel_config=parallel_config)
self.loss = CrossEntropyLoss(parallel_config=config.dp_mp_config)
def construct(self, x, encoder_mask, label, input_mask):
embedded, _ = self.embedding(x)
logits, _, = self.network(embedded, encoder_mask)
logits = P.Reshape()(logits, (-1, F.shape(logits)[-1]))
label = P.Reshape()(label, (-1,))
input_mask = P.Reshape()(input_mask, (-1,))
return self.loss(logits, label, input_mask)
config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, vocab_emb_dp=False)
pipeline_config = TransformerOpParallelConfig(data_parallel=2, model_parallel=8, pipeline_stage=4,
micro_batch_num=4, vocab_emb_dp=False)
@ -95,6 +120,18 @@ class NetWithLossFiveInputs(nn.Cell):
return self.loss(predict)
def run_network_function(dataset, pipeline_net):
"""
Feature: Test Transformer with a shared embedding.
Description: A basic helper function for compile testing.
Expectation: success.
"""
params = pipeline_net.trainable_params()
optimizer = nn.Lamb(params, learning_rate=0.01)
model = Model(pipeline_net, optimizer=optimizer)
model.train(2, dataset, dataset_sink_mode=False)
def run_total_transformer_model_head(e_layer,
d_layer,
arg_parallel_config,
@ -191,6 +228,63 @@ def test_transformer_model_2d_inputs():
model.train(1, dataset, dataset_sink_mode=False)
class TestTransformerEmbeddingHead:
def __init__(self):
self.output_path = None
def setup_method(self):
self.output_path = './graphs' + self.__str__()
context.set_context(save_graphs=True,
save_graphs_path=self.output_path)
def teardown_method(self):
shutil.rmtree(self.output_path)
def virtual_assign_add_from_ir(self, pattern, target_count):
"""
Checks that the AssignAdd count in the IR matches the golden one.
:param pattern: The match pattern for the specific count
:param target_count: The golden count in the IR files
"""
ir_files = glob.glob(os.path.join(self.output_path, 'rank_0', '*_validate*.ir'))
assert len(ir_files) == 1
appear_count = 0
with open(ir_files[0], 'r') as fp:
for line in fp:
if pattern in line:
appear_count += 1
assert appear_count == target_count
def test_pipeline_with_embedding(self):
"""
Feature: Test Transformer with a shared embedding.
Description: When pipeline training is combined with optimizer shard, the model-parallel embedding used to
raise a shape error. This test case ensures that no error is raised.
Expectation: Compilation succeeds and the number of AssignAdd nodes matches the expected count.
"""
bs = 16
pp = 2
context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=pp,
full_batch=True,
enable_parallel_optimizer=True)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
cf = TransformerOpParallelConfig(data_parallel=1, model_parallel=4, pipeline_stage=pp, vocab_emb_dp=False)
pipeline_net = TransformerEncoderNet(batch_size=bs // pp,
en_layer=2, de_layer=0, parallel_config=cf)
pipeline_net.embedding.pipeline_stage = 0
pipeline_net.network.encoder.blocks[0].pipeline_stage = 0
pipeline_net.network.encoder.blocks[1].pipeline_stage = 1
pipeline_cell_net = PipelineCell(pipeline_net, 2)
encoder_input_value = Tensor(np.ones((bs, 20)), mstype.int32)
encoder_input_mask = Tensor(np.ones((bs, 20, 20)), mstype.float16)
label = Tensor(np.ones((bs, 20)), mstype.int32)
mask = Tensor(np.ones((bs, 20)), mstype.float32)
dataset = Dataset(encoder_input_value, encoder_input_mask, label, mask)
run_network_function(dataset, pipeline_cell_net)
self.virtual_assign_add_from_ir(pattern=r'AssignAdd(', target_count=35)
def test_transformer_model_int64_inputs():
set_auto_parallel_context(device_num=8, global_rank=0,
full_batch=True,