From 9e69f9fac040476beb4388637452747d50a0f9a7 Mon Sep 17 00:00:00 2001
From: yuchaojie
Date: Mon, 21 Sep 2020 11:52:00 +0800
Subject: [PATCH] add transformer hub_conf

---
 .../nlp/transformer/mindspore_hub_conf.py     | 48 +++++++++++++++++++
 1 file changed, 48 insertions(+)
 create mode 100644 model_zoo/official/nlp/transformer/mindspore_hub_conf.py

diff --git a/model_zoo/official/nlp/transformer/mindspore_hub_conf.py b/model_zoo/official/nlp/transformer/mindspore_hub_conf.py
new file mode 100644
index 00000000000..a37f0261f32
--- /dev/null
+++ b/model_zoo/official/nlp/transformer/mindspore_hub_conf.py
@@ -0,0 +1,48 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+'''
+Transformer hub interface for transformer large
+'''
+from src.transformer_model import TransformerModel
+from src.transformer_model import TransformerConfig
+import mindspore.common.dtype as mstype
+transformer_net_cfg_large = TransformerConfig(
+    batch_size=96,
+    seq_length=128,
+    vocab_size=36560,
+    hidden_size=1024,
+    num_hidden_layers=6,
+    num_attention_heads=16,
+    intermediate_size=4096,
+    hidden_act="relu",
+    hidden_dropout_prob=0.2,
+    attention_probs_dropout_prob=0.2,
+    max_position_embeddings=128,
+    initializer_range=0.02,
+    label_smoothing=0.1,
+    input_mask_from_dataset=True,
+    dtype=mstype.float32,
+    compute_type=mstype.float16
+)
+def create_network(name, *args, **kwargs):
+    '''
+    Create transformer network for large.
+    '''
+    if name == 'transformer_large':
+        if "seq_length" in kwargs:
+            transformer_net_cfg_large.seq_length = kwargs["seq_length"]
+        is_training = kwargs.get("is_training", False)
+        return TransformerModel(transformer_net_cfg_large, is_training, *args)
+    raise NotImplementedError(f"{name} is not implemented in the repo")
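
A minimal usage sketch for the new hub entry point. MindSpore Hub conventionally discovers the create_network factory in mindspore_hub_conf.py, and the same factory can be called directly; the sketch below assumes it is run from model_zoo/official/nlp/transformer so that the src package imports inside the new file resolve, and that MindSpore is installed. The seq_length and is_training keyword names come straight from the factory added by this patch.

    # Build the transformer_large network defined by this hub_conf.
    from mindspore_hub_conf import create_network

    # seq_length overrides the default of 128 in transformer_net_cfg_large;
    # is_training defaults to False, giving an inference-mode TransformerModel.
    network = create_network("transformer_large", seq_length=64, is_training=False)

    # Any other model name is rejected by the factory.
    try:
        create_network("transformer_base")
    except NotImplementedError as err:
        print(err)

One design note: create_network mutates the module-level transformer_net_cfg_large in place, so a seq_length passed on one call carries over to later calls in the same process.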