!19622 Serving, pangu alpha modelzoo

Merge pull request !19622 from 徐永飞/master
This commit is contained in:
i-robot 2021-07-08 11:34:23 +00:00 committed by Gitee
commit b4c04ef3a8
19 changed files with 890 additions and 158 deletions

View File

@ -1,4 +1,3 @@
# It is still under development
# Contents
@ -185,7 +184,7 @@ bash scripts/run_distribute_train_incremental_train.sh DATASET RANK_TABLE 8 fp32
Please refer to the [website](https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha) to download the following parts:
- tokenizer: vocab.txt and vocab.model
- checkpoint file: \*.part\[0-4\] and *.npy under the same parameter size
- checkpoint file: \*.part\[0-4\] (need to extract) and *.npy under the same parameter size
- strategy file: a file described how the parameters are sliced across different devices.
Here we suppose the downloaded checkpoint, tokenizer and strategy file is organized as follows:
@ -204,6 +203,12 @@ ckpts
└── vocab10.vocab
```
We provide two predict methods. The first one is the normal way which needs to pad the input to a certain length every
iteration. Due to the redundant calculation, the latency of this method is quite high and to accelerate the speed
performance, we provide the second state reuse (incremental inference) method.
The state reuse method is the default mode, and you can disable it by changing the argument 'use_past' to False.
### Run Prediction on Distributed mode
The following script will run prediction on 8 Ascend cards.
@ -217,28 +222,217 @@ ${FILE_PATH}/tokenizer/ ${FILE_PATH}/checkpoint_file filitered 2.6B fp32
### Run Prediction Using One Device
The following script will run prediction on 1 Ascend cards. The difference is the net is initialized with float16 type.
And the rank_table should be configured to one device.
```bash
$FILE_PATH=/home/your_path/ckpts
bash scripts/run_distribute_predict.sh 1 /home/config/rank_table_1p.json ${FILE_PATH}/strategy_load_ckpt/strategy.ckpt \
${FILE_PATH}/tokenizer/ ${FILE_PATH}/checkpoint_file filitered 2.6B fp16
bash scripts/run_standalone_predict.sh ${FILE_PATH}/strategy_load_ckpt/strategy.ckpt \
${FILE_PATH}/tokenizer/ ${FILE_PATH}/checkpoint_file filitered 2.6B
```
### Run Serving
In directory serving:
#### Preparation
- Use scripts/run_distribute_export.sh to export MindIR models, and copy all device* to serving_increment/models/.
- Download [PanGu-Alpha tokenizer repository](https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha.git) and copy pangu-alpha/tokenizer to directory pangu/tokenizer.
- Pip install MindSpore and MindSpore Serving 1.2 whl package.
- Pip install flask, flask-apscheduler, jieba, sentencepiece whl package.
- Edit server_agent.py and update the path of pangu-alpha models.
- Run 'bash start_pangu.sh' to start new execution.
- Wait for serving to start successfully: observe the serving_server.log file until the message "Serving: gRPC server start success, listening on 127.0.0.1:5500" is output.
- If any error happened, log can be viewed in serving_server.log, serving_agent.log and flask.log.
- If anything all right, access address {ip}:5000 in one browser.
- Run 'bash stop_pangu.sh' to stop the existing execution.
- Pip install MindSpore and MindSpore Serving 1.3 or later.
- Pip install flask, flask-apscheduler, jieba, sentencepiece and other whl package if needed.
- Download [PanGu-Alpha repository](https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha), we will need
`pangu-alpha/strategy_load_ckpt` and `pangu-alpha/tokenizer` in the following process.
- Download 13B or 2.6B checkpoint files and `*embedding` files
from [PanGu-Alpha repository](https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha).
For 13B, we will need `13B_part0` to `13B_part3`, `13B_word_embedding`, `13B_top_query_embedding`
, `13B_position_embedding`.
For 2.6B, we will need `2.6B_part0` to `2.6B_part3`, `13B_word_embedding`, `2.6B_top_query_embedding`
, `2.6B_position_embedding`.
Decompress all the `13B_part*` or `2.6B_part*` tar files and a large number of `*ckpt` files will generate. Move
all `*embedding` to the same directory of `*.ckpt` files.
#### Run 13B or 2.6B in standalone mode[Ascend910/Nvidia GPU]
- Use scripts/run_standalone_export.sh to export MindIR models, and move all device_0/* to
'serving_increment/pangu_standalone/pangu/1/'.
```shell
>>> cd scripts
>>> bash run_standalone_export.sh ${strategy_file_path} ${ckpt_dir_path}
```
Update the parameter `MODE` in `run_standalone_export.sh` from `13B` to `2.6B` if we want to export 2.6B model.
Update the parameter `DEVICE_TARGET` in `run_standalone_export.sh` from `Ascend` to `GPU` when running in GPU environment.
The `${strategy_file_path}` is file path of `pangu-alpha/strategy_load_ckpt/angu_alpha_13B_cktp_strategy.ckpt` for 13B
and `pangu-alpha/strategy_load_ckpt/angu_alpha_2.6B_cktp_strategy.ckpt` for 2.6B.
The `${ckpt_dir_path}` is the directory the `*ckpt` files generated by decompression and `*embedding` files.
The model will be exported for some minutes. Check log device_0/log0.log, confirm that there is no exception at last.
Confirm that mindir files have been generated in device_0/ which means that the model is exported successfully.
```shell
>>> ls device_0
pangu_alpha_1024_graph.mindir pangu_alpha_1024_variables pangu_alpha_1_graph.mindir pangu_alpha_1_variables
>>> cd - && mkdir serving_increment/pangu_standalone/pangu/1/
>>> mv scripts/device_0/* serving_increment/pangu_standalone/pangu/1/
>>> cd serving_increment
```
- Copy `pangu-alpha/tokenizer` to directory serving_increment/pangu_standalone/pangu/tokenizer.
The directory hierarchy of the required files is shown below. The pangu_alpha_1024_variables and pangu_alpha_1_variables are folded for easy display.
```shell
>>> tree pangu_distributed
pangu_standalone/
├── pangu
│ ├── 1
│ │ ├── pangu_alpha_1024_graph.mindir
│ │ ├── pangu_alpha_1024_variables/
│ │ ├── pangu_alpha_1_graph.mindir
│ │ └── pangu_alpha_1_variables/
│ ├── servable_config.py
│ ├── tokenization_jieba.py
│ └── tokenizer
│ ├── vocab.model
│ └── vocab.vocab
└── serving_server.py
```
- Run `bash start_pangu_standalone.sh` to start new execution, and wait until the serving and flask server are started
successfully.
- If any error happened, log can be viewed in serving_server.log, serving_logs/*.log and flask.log.
- If anything all right, access address {ip}:5000 in one browser. It will take some time to return the reply.
- Run `bash stop_pangu.sh` to stop the existing execution.
#### Run 13B or 2.6B in distributed mode[Ascend910 8 cards]
- Generate [rank table file](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
```shell
# mindspore/model_zoo/utils/hccl_tools/hccl_tools.py
>>> python3 ../../../utils/hccl_tools/hccl_tools.py --device_num "[0,8]"
>>> mv hccl_8p_01234567*.json serving_increment/pangu_distributed/hccl_8p.json
```
- Use scripts/run_distribute_export.sh to export MindIR models, and move all device* to
'serving_increment/pangu_distributed/models/'.
```shell
>>> cd scripts
>>> bash run_distribute_export.sh ${strategy_file_path} ${ckpt_dir_path}
```
Update the parameter `MODE` in `run_distribute_export.sh` from `13B` to `2.6B` if we want to export 2.6B model.
The `${strategy_file_path}` is file path of `pangu-alpha/strategy_load_ckpt/angu_alpha_13B_cktp_strategy.ckpt` for 13B
and `pangu-alpha/strategy_load_ckpt/angu_alpha_2.6B_cktp_strategy.ckpt` for 2.6B.
The `${ckpt_dir_path}` is the directory the *ckpt files generated by decompression and *embedding files.
The model will be exported for some minutes. Check log device_[0-7]/log[0-7].log, confirm that there is no exception at
last. Confirm that mindir files have been generated in device_[0-7]/ which means that the model is exported
successfully.
```shell
>>> cd - && mkdir serving_increment/pangu_distributed/models/
>>> mv scripts/device_* serving_increment/pangu_distributed/models/
>>> cd serving_increment
```
- Update MindIR file name serving_increment/pangu_distributed/serving_agent.py if needed.
- Copy `pangu-alpha/tokenizer` to directory serving_increment/pangu_distributed/pangu/tokenizer.
The directory hierarchy of the required files is shown below. The device_1 to device_7 are folded for easy display.
```shell
>>> tree pangu_distributed
pangu_distributed/
├── hccl_8p.json
├── models
│ ├── device_0
│ │ ├── pangu_alpha_1024_graph.mindir
│ │ ├── pangu_alpha_1024_variables
│ │ │ ├── data_0
│ │ │ ├── data_1
│ │ │ ├── data_2
│ │ │ ├── data_3
│ │ │ └── data_4
│ │ ├── pangu_alpha_1_graph.mindir
│ │ └── pangu_alpha_1_variables
│ │ ├── data_0
│ │ ├── data_1
│ │ ├── data_2
│ │ ├── data_3
│ │ └── data_4
│ ├── device_1/
│ ├── device_2/
│ ├── device_3/
│ ├── device_4/
│ ├── device_5/
│ ├── device_6/
│ └── device_7/
├── pangu
│ ├── servable_config.py
│ ├── tokenization_jieba.py
│ └── tokenizer
│ ├── vocab.model
│ └── vocab.vocab
├── serving_agent.py
└── serving_server.py
```
- Run `bash start_pangu_distributed.sh` to start new execution, and wait until the serving and flask server are started
successfully.
- If any error happened, log can be viewed in serving_server.log, serving_agent.log, serving_logs/*.log and flask.log.
- If anything all right, access address {ip}:5000 in one browser. It will take some time to return the reply.
- Run `bash stop_pangu.sh` to stop the existing execution.
#### Run in distributed mode[Ascend910 8 cards * N machine]
- Generate [rank table file](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
- In every machine, prepare for checkpoint files and embedding files. We can also use 13B as a test example.
- In every machine, use scripts/run_cluster_export.sh to export MindIR models, and move all device* to
'serving_increment/pangu_distributed/models/'.
```shell
>>> cd scripts
>>> bash run_cluster_export.sh ${strategy_file_path} ${ckpt_dir_path} ${rank_table_file} ${rank_size} ${rank_start}
```
Update the parameter `MODE` in `run_distribute_export.sh` from `200B` to `13B` if we want to export 13B model.
The `${rank_start}` is the first rank id in every machine, likes 0,8,16,24.
The model will be exported for some minutes. Check log device_[0-7]/log[0-7].log, confirm that there is no exception at
last. Confirm that mindir files have been generated in device_[0-7]/ which means that the model is exported
successfully.
```shell
>>> cd - && mkdir serving_increment/pangu_distributed/models/
>>> mv scripts/device_* serving_increment/pangu_distributed/models/
>>> cd serving_increment
```
- In the first machine, update the parameter `rank_size` and `stage_size`(Pipeline stage size) of `serving_increment/pangu_distributed/pangu/servable_config.py`.
- In the first machine, update the parameter `rank_table_json_file` of `serving_increment/pangu_distributed/serving_server.py`.
- In every machine, update MindIR file name serving_increment/pangu_distributed/serving_agent.py if needed.
- In every machine, update the parameter `distributed_address` of `serving_increment/pangu_distributed/serving_agent.py` and
`serving_increment/pangu_distributed/serving_server.py` to the first machine ip address.
- In the first machine, copy `pangu-alpha/tokenizer` to directory serving_increment/pangu_distributed/pangu/tokenizer.
- In the first machine, run `bash start_pangu_distributed.sh` to start new execution.
- Meanwhile, in other machines, run `python serving_agent.py` to start serving agent process.
```shell
>>> unset http_proxy && unset https_proxy
>>> python pangu_distributed/serving_agent.py > serving_agent.log 2>&1 &
```
- Wait until the serving and flask server are started successfully.
- If any error happened, log can be viewed in serving_server.log, serving_agent.log, serving_logs/*.log and flask.log.
- If anything all right, access address {first_machine_ip}:5000 in one browser. It will take some time to return the reply.
- Run `bash stop_pangu.sh` to stop the existing execution in every machine.
# [Script Description](#contents)

View File

@ -63,8 +63,12 @@ def load_model(args_opt):
else:
rank = 0
device_num = 1
context.reset_auto_parallel_context()
context.set_auto_parallel_context(
strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path)
use_past = (args_opt.use_past == "true")
print('local_rank:{}, start to run...'.format(rank), flush=True)
use_past = False
if args_opt.export:
use_past = True
# Set model property
@ -72,6 +76,9 @@ def load_model(args_opt):
data_parallel_num = int(device_num / model_parallel_num)
per_batch_size = args_opt.per_batch_size
batch_size = per_batch_size * data_parallel_num
# Now only support single batch_size for predict
if args_opt.run_type == "predict":
batch_size = 1
config = PANGUALPHAConfig(
data_parallel_num=data_parallel_num,
model_parallel_num=model_parallel_num,
@ -105,11 +112,15 @@ def load_model(args_opt):
inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32)
current_index = Tensor(np.array([0]), mstype.int32)
if config.use_past:
if args_opt.distribute == "false":
predict_layout = None
elif config.use_past:
batch_valid_length = Tensor(np.array([0]), mstype.int32)
init_true = Tensor([True], mstype.bool_)
inputs_np_1 = Tensor(np.ones(shape=(config.batch_size, 1)), mstype.int32)
model_predict.predict_network.add_flags_recursive(is_first_iteration=True)
predict_layout = model_predict.infer_predict_layout(inputs_np, current_index, init_true, batch_valid_length)
model_predict.predict_network.add_flags_recursive(is_first_iteration=False)
_ = model_predict.infer_predict_layout(inputs_np_1, current_index, init_true, batch_valid_length)
else:
predict_layout = model_predict.infer_predict_layout(inputs_np, current_index)

View File

@ -0,0 +1,40 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
execute_path=$(pwd)
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
MODE=200B # or 13B
PARAM_INIT_TYPE=fp16
STRATEGY=$1
CKPT_PATH=$2
export RANK_TABLE_FILE=$3
export RANK_SIZE=$4
RANK_START=$5 # 0,8,16,... for each machine
CKPT_NAME='filerted'
for((i=0;i<8;i++));
do
rm -rf ${execute_path}/device_$i/
mkdir ${execute_path}/device_$i/
cd ${execute_path}/device_$i/ || exit
export RANK_ID=$(($RANK_START+$i))
export DEVICE_ID=$i
python -s ${self_path}/../predict.py --strategy_load_ckpt_path=$STRATEGY --load_ckpt_path=$CKPT_PATH \
--load_ckpt_name=$CKPT_NAME --mode=$MODE --run_type=predict --param_init_type=$PARAM_INIT_TYPE \
--export=1 >log$i.log 2>&1 &
done

View File

@ -18,12 +18,12 @@ execute_path=$(pwd)
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
export RANK_SIZE=8
export RANK_TABLE_FILE=${execute_path}/../serving_increment/hccl_8p.json
export RANK_TABLE_FILE=${execute_path}/../serving_increment/pangu_distributed/hccl_8p.json
export MODE=13B
export PARAM_INIT_TYPE=fp16
export STRATEGY=$1
export CKPT_PATH=$2
export CKPT_NAME=$3
export PARAM_INIT_TYPE=$4
export CKPT_NAME='filerted'
for((i=0;i<$RANK_SIZE;i++));
do
@ -35,4 +35,4 @@ do
python -s ${self_path}/../predict.py --strategy_load_ckpt_path=$STRATEGY --load_ckpt_path=$CKPT_PATH \
--load_ckpt_name=$CKPT_NAME --mode=$MODE --run_type=predict --param_init_type=$PARAM_INIT_TYPE \
--export=1 >log$i.log 2>&1 &
done
done

View File

@ -0,0 +1,38 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
execute_path=$(pwd)
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
export RANK_SIZE=1
export MODE=13B # or 2.6B
export PARAM_INIT_TYPE=fp16
export STRATEGY=$1
export CKPT_PATH=$2
export DEVICE_TARGET=Ascend # or GPU
export CKPT_NAME='filerted'
for((i=0;i<$RANK_SIZE;i++));
do
rm -rf ${execute_path}/device_$i/
mkdir ${execute_path}/device_$i/
cd ${execute_path}/device_$i/ || exit
export RANK_ID=$i
export DEVICE_ID=$i
python -s ${self_path}/../predict.py --strategy_load_ckpt_path=$STRATEGY --load_ckpt_path=$CKPT_PATH \
--load_ckpt_name=$CKPT_NAME --mode=$MODE --run_type=predict --param_init_type=$PARAM_INIT_TYPE \
--export=1 --distribute=false --device_target=$DEVICE_TARGET >log$i.log 2>&1 &
done

View File

@ -0,0 +1,39 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
execute_path=$(pwd)
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
export RANK_SIZE=1
export STRATEGY=$1
export TOKENIZER=$2
export CKPT_PATH=$3
export CKPT_NAME=$4
export MODE=$5
export PARAM_INIT_TYPE=fp16
export DEVICE_TARGET=Ascend # or GPU
for((i=0;i<$RANK_SIZE;i++));
do
rm -rf ${execute_path}/device_$i/
mkdir ${execute_path}/device_$i/
cd ${execute_path}/device_$i/ || exit
export RANK_ID=$i
export DEVICE_ID=$i
python -s ${self_path}/../predict.py --strategy_load_ckpt_path=$STRATEGY --tokenizer_path=$TOKENIZER --load_ckpt_path=$CKPT_PATH \
--load_ckpt_name=$CKPT_NAME --mode=$MODE --run_type=predict --param_init_type=$PARAM_INIT_TYPE \
--distribute=false --device_target=$DEVICE_TARGET >log$i.log 2>&1 &
done

View File

@ -1,109 +0,0 @@
{
"board_id": "0x0020",
"chip_info": "910",
"deploy_mode": "lab",
"group_count": "1",
"group_list": [
{
"device_num": "8",
"server_num": "1",
"group_name": "",
"instance_count": "8",
"instance_list": [
{
"devices": [
{
"device_id": "0",
"device_ip": "192.98.92.121"
}
],
"rank_id": "0",
"server_id": "127.0.0.1"
},
{
"devices": [
{
"device_id": "1",
"device_ip": "192.98.93.121"
}
],
"rank_id": "1",
"server_id": "127.0.0.1"
},
{
"devices": [
{
"device_id": "2",
"device_ip": "192.98.94.121"
}
],
"rank_id": "2",
"server_id": "127.0.0.1"
},
{
"devices": [
{
"device_id": "3",
"device_ip": "192.98.95.121"
}
],
"rank_id": "3",
"server_id": "127.0.0.1"
},
{
"devices": [
{
"device_id": "4",
"device_ip": "192.98.92.122"
}
],
"rank_id": "4",
"server_id": "127.0.0.1"
},
{
"devices": [
{
"device_id": "5",
"device_ip": "192.98.93.122"
}
],
"rank_id": "5",
"server_id": "127.0.0.1"
},
{
"devices": [
{
"device_id": "6",
"device_ip": "192.98.94.122"
}
],
"rank_id": "6",
"server_id": "127.0.0.1"
},
{
"devices": [
{
"device_id": "7",
"device_ip": "192.98.95.122"
}
],
"rank_id": "7",
"server_id": "127.0.0.1"
}
]
}
],
"para_plane_nic_location": "device",
"para_plane_nic_name": [
"eth0",
"eth1",
"eth2",
"eth3",
"eth4",
"eth5",
"eth6",
"eth7"
],
"para_plane_nic_num": "8",
"status": "completed"
}

View File

@ -22,8 +22,8 @@ def start():
"""Start agents to load and execute models of pangu alpha"""
model_files = []
for i in range(8):
model_files.append([f"models/device{i}/pangu_alpha_1024_graph.mindir",
f"models/device{i}/pangu_alpha_1_graph.mindir"])
model_files.append([f"models/device_{i}/pangu_alpha_1024_graph.mindir",
f"models/device_{i}/pangu_alpha_1_graph.mindir"])
distributed.startup_agents(distributed_address="0.0.0.0:6200", model_files=model_files)

View File

@ -21,8 +21,7 @@ from mindspore_serving.server import distributed
def start():
"""Start server to serve service, and manage all agents which load and execute models"""
servable_dir = os.path.abspath(".")
servable_dir = os.path.dirname(os.path.realpath(__file__))
distributed.start_servable(servable_dir, "pangu", rank_table_json_file="hccl_8p.json",
distributed_address="0.0.0.0:6200")

View File

@ -0,0 +1,203 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""servable config for pangu alpha"""
import os
import time
from easydict import EasyDict
import numpy as np
from mindspore_serving.server import register
from pangu.tokenization_jieba import JIEBATokenizer
cur_dir = os.path.abspath(os.path.dirname(__file__))
tokenizer_path = os.path.join(cur_dir, "tokenizer")
tokenizer = JIEBATokenizer(os.path.join(tokenizer_path, "vocab.vocab"), os.path.join(tokenizer_path, "vocab.model"))
end_token = tokenizer.eot_id
config = EasyDict({
'frequency_penalty': 1.5,
'presence_penalty': 0.3,
'max_generate_length': 500,
'top_k_num': 3,
'top_p': 1.0,
'end_token': 9,
'seq_length': 1024,
'vocab_size': 40000,
})
def topk_fun(logits, topk=5):
"""Get topk"""
target_column = logits[0].tolist()
sorted_array = [(k, v) for k, v in enumerate(target_column)]
sorted_array.sort(key=lambda x: x[1], reverse=True)
topk_array = sorted_array[:topk]
index, value = zip(*topk_array)
index = np.array([index])
value = np.array([value])
return value, index
register.declare_servable(servable_file=["pangu_alpha_1024_graph.mindir", "pangu_alpha_1_graph.mindir"],
model_format="MINDIR", with_batch_dim=False)
@register.register_method(output_names=["logits"])
def predict_sub0(input_ids, current_index, init, batch_valid_length):
logits = register.call_servable(input_ids, current_index, init, batch_valid_length, subgraph=0)
return logits
@register.register_method(output_names=["logits"])
def predict_sub1(input_id, current_index, init, batch_valid_length):
logits = register.call_servable(input_id, current_index, init, batch_valid_length, subgraph=1)
return logits
sub0_servable = register.PipelineServable(servable_name="pangu", method="predict_sub0")
sub1_servable = register.PipelineServable(servable_name="pangu", method="predict_sub1")
@register.register_pipeline(output_names=["output_sentence"])
def predict(input_sentence):
"""generate sentence with given input_sentence"""
print(f"----------------------------- begin {input_sentence} ---------", flush=True)
time_start = time.time()
tokens = tokenizer.tokenize(input_sentence)
input_ids = tokenizer.convert_tokens_to_ids(tokens)
outputs = generate_increment(input_ids)
return_tokens = tokenizer.convert_ids_to_tokens(outputs)
reply = "".join(return_tokens)
print(f"time cost {(time.time() - time_start) * 1000}ms, request '{input_sentence}' get reply '{reply}'",
flush=True)
return reply
def generate_increment(origin_inputs):
"""
Text generation for incremental inference
Inputs:
model: the model for inferencing
origin_inputs: the original inputs based on which the model will continue writing
config: inference configurations
Returns:
outputs: the ids for the generated text
"""
# Get configurations for inference
frequency_penalty = config.frequency_penalty
presence_penalty = config.presence_penalty
top_p = config.top_p
top_k_num = config.top_k_num
max_generate_length = config.max_generate_length
seq_length = config.seq_length
vocab_size = config.vocab_size
# Init outputs with original inputs
outputs = origin_inputs
origin_inputs = np.array([origin_inputs])
_, valid_length = origin_inputs.shape
# If target length exceeds seq_length, use seq_length instead
target_length = valid_length + max_generate_length
target_length = seq_length if target_length > seq_length else target_length
# A list of the frequency of each token
frequency_list = np.array([[0 for _ in range(vocab_size)]])
pad_length = seq_length - origin_inputs.shape[-1]
# Pad original inputs to seq_length
input_ids = np.pad(origin_inputs, ((0, 0), (0, pad_length)), 'constant', constant_values=(0, 0))
# Indicate the exact token position
current_index = valid_length - 1 if valid_length - 1 > 0 else 0
current_index = np.array([current_index], np.int32)
batch_valid_length = np.array([current_index], np.int32)
# For first graph, not_init should be false
init_true = True
init_false = False
init = init_false
# Call a single inference with input size of (bs, seq_length)
logits = sub0_servable.run(np.array(input_ids, np.int32), current_index, init, batch_valid_length)
# Claim the second graph and set not_init to true
init = init_true
# A single loop generates one token, loop until reaching target seq_length or generating eod token
while valid_length < target_length:
# Reshape the output logits
log_probs = logits.reshape(1, vocab_size)
# Get the revised log_probs considering frequency and presence penalty to eliminate duplicate in generated results
log_probs = log_probs.reshape(1, vocab_size)
log_probs_revised = log_probs - frequency_list * frequency_penalty - (frequency_list > 0) * presence_penalty
# Convert the log_probs to probability
logits = np.power(10, np.array(log_probs_revised, np.float32))
# If top_p is less than 1.0, use top_p sampling
if top_p < 1.0:
# Only consider the 5000 largest logits to reduce computation
sorted_logits, index = topk_fun(logits, 5000)
cumsum_logits = np.cumsum(sorted_logits, 1)
cumsum_logits = cumsum_logits[0]
index = index[0]
sorted_logits = sorted_logits[0]
top_p_num = sum(cumsum_logits > top_p)
# In case the probability is smooth, the sum of 5000 largest probabilities are not large enough
if top_p_num == 0:
top_p_num = 5000
# Get the corresponding probs and indices
probs = sorted_logits[:top_p_num]
p_args = index[:top_p_num]
p = probs / sum(probs)
# if top_p is set to 1.0, use top_k sampling
else:
# Get the corresponding probs and indices
probs, p_args = topk_fun(logits, top_k_num)
probs = probs[0]
p_args = p_args[0]
# Avoid rounding error
if sum(probs) == 0:
probs = np.array([1 / top_k_num for _ in range(top_k_num)])
p = probs / sum(probs)
# Random select a token as final output for this round
target_index = np.random.choice(len(p), p=p)
# Stop judgment
if p_args[target_index] == end_token or valid_length == target_length - 1:
break
# Update frequency list
target = p_args[target_index]
frequency_list[0][target] = frequency_list[0][target] + 1
valid_length += 1
batch_valid_length = np.array([valid_length - 1], np.int32)
current_index = np.array([0], np.int32)
input_id = np.array([[target]], np.int32)
# Update outputs with current generated token
outputs.append(int(target))
# Call a single inference with input size of (bs, 1)
logits = sub1_servable.run(input_id, current_index, init, batch_valid_length)
# Return valid outputs out of padded outputs
return outputs

View File

@ -0,0 +1,76 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from io import open
import jieba
import sentencepiece as spm
class JIEBATokenizer:
"""jieba tokenizer for encode and decode text"""
def __init__(self, vocab_file, model_file, max_len=None):
self.max_len = max_len if max_len is not None else int(1e12)
# self.encoder = json.load(open(vocab_file))
f = open(vocab_file, 'r')
lines = f.readlines()
self.encoder = {}
for line in enumerate(lines):
key = line[1].split('\t')[0]
self.encoder[key] = line[0]
self.decoder = {v: k for k, v in self.encoder.items()}
self.sp = spm.SentencePieceProcessor(model_file=model_file)
self.translator = str.maketrans(" \n", "\u2582\u2583")
self.eod_id = self.encoder['<eod>']
self.eot_id = self.encoder['<eot>']
self.pad_id = self.encoder['<pad>']
@property
def vocab_size(self):
return len(self.encoder)
def __len__(self):
return len(self.encoder) + len(self.special_tokens)
@property
def eod(self):
return self.eod_id
def tokenize(self, text):
""" Tokenize a string. """
seg_list = [x.translate(self.translator) for x in jieba.cut(text, cut_all=False)]
new_seg = " ".join(seg_list)
return self.sp.encode(new_seg)
def convert_tokens_to_ids(self, tokens):
return tokens
def convert_ids_to_tokens(self, ids):
return self.decode(ids)
def encode(self, text):
res = self.tokenize(text)
return res
def decode(self, tokens):
text = self.sp.decode(tokens)
text = text.replace(' ', '').replace('\u2582', ' ').replace('\u2583', '\n')
return text

View File

@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Serving server start code, serve service, and manage all agents which load and execute models"""
import os
from mindspore_serving import server
def start():
servable_dir = os.path.dirname(os.path.realpath(__file__))
config = server.ServableStartConfig(servable_directory=servable_dir, servable_name="pangu", device_ids=(0,))
server.start_servables(config)
server.start_grpc_server("127.0.0.1:5500")
if __name__ == "__main__":
start()

View File

@ -21,7 +21,7 @@ export GLOG_v=1
start_serving_server()
{
echo "### start serving server, see serving_server.log for detail ###"
python3 serving_server.py > serving_server.log 2>&1 &
python3 pangu_distributed/serving_server.py > serving_server.log 2>&1 &
if [ $? -ne 0 ]
then
echo "serving server failed to start."
@ -52,8 +52,8 @@ start_serving_server()
start_serving_agent()
{
echo "### start serving agent, see and serving_logs/log_pangu_distributed.log for detail ###"
python3 serving_agent.py > serving_agent.log 2>&1 &
echo "### start serving agent, see serving_agent.log and serving_logs/log_pangu_distributed.log for detail ###"
python3 pangu_distributed/serving_agent.py > serving_agent.log 2>&1 &
if [ $? -ne 0 ]
then
echo "serving agent failed to start."
@ -68,7 +68,7 @@ start_serving_agent()
if [ $num -eq 0 ]
then
bash stop_pangu.sh
echo "start serving agent failed, see log and serving_logs/log_pangu_distributed.log for more detail" && exit 1
echo "start serving agent failed, see serving_agent.log and serving_logs/log_pangu_distributed.log for more detail" && exit 1
fi
count=$(($count+1))
@ -78,7 +78,7 @@ start_serving_agent()
if [ ${count} -eq 1800 ]
then
bash stop_pangu.sh
echo "start serving agent failed, see log and serving_logs/log_pangu_distributed.log for more detail" && exit 1
echo "start serving agent failed, see serving_agent.log and serving_logs/log_pangu_distributed.log for more detail" && exit 1
fi
echo "### start serving agent end ###"
}
@ -131,7 +131,7 @@ wait_serving_ready()
if [ $num -eq 0 ]
then
bash stop_pangu.sh
echo "waiting serving server ready failed, see log serving_server.log and serving_logs/log_pangu_distributed.log for more detail" && exit 1
echo "waiting serving server ready failed, see serving_server.log and serving_logs/log_pangu_distributed.log for more detail" && exit 1
fi
count=$(($count+1))
@ -141,7 +141,7 @@ wait_serving_ready()
if [ ${count} -eq 100 ]
then
bash stop_pangu.sh
echo "waiting serving server ready failed, see log serving_server.log and serving_logs/log_pangu_distributed.log for more detail" && exit 1
echo "waiting serving server ready failed, see serving_server.log and serving_logs/log_pangu_distributed.log for more detail" && exit 1
fi
echo "### waiting serving server ready end ###"
}
@ -150,5 +150,5 @@ bash stop_pangu.sh
rm -rf serving_server.log serving_agent.log flask.log serving_logs
start_serving_server
start_serving_agent
start_flask
wait_serving_ready
start_flask

View File

@ -0,0 +1,120 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
unset http_proxy
unset https_proxy
export GLOG_v=1
start_serving_server()
{
echo "### start serving server, see serving_server.log for detail ###"
python3 pangu_standalone/serving_server.py > serving_server.log 2>&1 &
if [ $? -ne 0 ]
then
echo "serving server failed to start."
fi
result=`grep -E 'Master server start success' serving_server.log | wc -l`
count=0
while [[ ${result} -eq 0 && ${count} -lt 100 ]]
do
sleep 1
num=`ps -ef | grep 'serving_server.py' | grep -v grep | wc -l`
if [ $num -eq 0 ]
then
echo "start serving server failed, see log serving_server.log for more detail" && exit 1
fi
count=$(($count+1))
result=`grep -E 'Master server start success' serving_server.log | wc -l`
done
if [ ${count} -eq 100 ]
then
echo "start serving server failed, see log serving_server.log for more detail" && exit 1
fi
echo "### start serving server end ###"
}
start_flask()
{
echo "### start flask server, see flask.log for detail ###"
python3 flask/client.py > flask.log 2>&1 &
if [ $? -ne 0 ]
then
echo "flask server failed to start."
fi
result=`grep -E 'Press CTRL\+C to quit' flask.log | wc -l`
count=0
while [[ ${result} -ne 1 && ${count} -lt 10 ]]
do
sleep 1
num=`ps -ef | grep 'flask/client.py' | grep -v grep | wc -l`
if [ $num -eq 0 ]
then
bash stop_pangu.sh
echo "start flask server failed, see log flask.log for more detail" && exit 1
fi
count=$(($count+1))
result=`grep -E 'Press CTRL\+C to quit' flask.log | wc -l`
done
if [ ${count} -eq 10 ]
then
bash stop_pangu.sh
echo "start flask server failed, see log flask.log for more detail" && exit 1
fi
echo "### start flask server end ###"
cat flask.log
}
wait_serving_ready()
{
echo "### waiting serving server ready, see and serving_logs/*.log for detail ###"
result=`grep -E 'gRPC server start success' serving_server.log | wc -l`
count=0
while [[ ${result} -eq 0 && ${count} -lt 1800 ]]
do
sleep 1
num=`ps -ef | grep 'serving_server.py' | grep -v grep | wc -l`
if [ $num -eq 0 ]
then
bash stop_pangu.sh
echo "waiting serving server ready failed, see log serving_server.log and serving_logs/*.log for more detail" && exit 1
fi
count=$(($count+1))
result=`grep -E 'gRPC server start success' serving_server.log | wc -l`
done
if [ ${count} -eq 1800 ]
then
bash stop_pangu.sh
echo "waiting serving server ready failed, see log serving_server.log and serving_logs/*.log for more detail" && exit 1
fi
echo "### waiting serving server ready end ###"
}
bash stop_pangu.sh
rm -rf serving_server.log flask.log serving_logs
start_serving_server
wait_serving_ready
start_flask

View File

@ -22,6 +22,7 @@ import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
def topk_fun(logits, topk=5):
"""Get topk"""
target_column = logits[0].tolist()
@ -33,6 +34,7 @@ def topk_fun(logits, topk=5):
value = np.array([value])
return value, index
def sampler(log_probs_revised, top_p, top_k_num, use_pynative=False):
"""Convert the log_probs to probability"""
if use_pynative:
@ -80,6 +82,7 @@ def sampler(log_probs_revised, top_p, top_k_num, use_pynative=False):
p = probs / sum(probs)
return p, p_args
def generate(model, origin_inputs, config):
"""
Text generation
@ -130,7 +133,7 @@ def generate(model, origin_inputs, config):
# Random select a token as final output for this round
target_index = np.random.choice(len(p), p=p)
# Stop judgment
if p_args[target_index] == end_token or valid_length == target_length-1:
if p_args[target_index] == end_token or valid_length == target_length - 1:
outputs = input_ids
break
@ -145,6 +148,7 @@ def generate(model, origin_inputs, config):
outputs = outputs[0][:length]
return outputs
def generate_increment(model, origin_inputs, config):
"""
Text generation for incremental inference
@ -183,8 +187,8 @@ def generate_increment(model, origin_inputs, config):
# Indicate the exact token position
current_index = valid_length - 1 if valid_length - 1 > 0 else 0
current_index = Tensor(np.array([current_index]), mstype.int32)
batch_valid_length = Tensor(np.array([current_index]), mstype.int32)
current_index = Tensor(np.array([current_index]), mstype.int32)
# For first graph, not_init should be false
init_true = Tensor([True], mstype.bool_)
init_false = Tensor([False], mstype.bool_)
@ -211,7 +215,7 @@ def generate_increment(model, origin_inputs, config):
# Random select a token as final output for this round
target_index = np.random.choice(len(p), p=p)
# Stop judgment
if p_args[target_index] == end_token or valid_length == target_length-1:
if p_args[target_index] == end_token or valid_length == target_length - 1:
break
# Update frequency list

View File

@ -27,10 +27,12 @@ from mindspore import context
from mindspore.common.seed import _get_graph_seed
from mindspore._checkparam import Validator
class Dropout(nn.Cell):
r"""
A Dropout Implements with P.DropoutGenMask and P.DropoutDoMask for parallel training.
"""
def __init__(self, keep_prob=0.5, dtype=mstype.float32):
super(Dropout, self).__init__()
if keep_prob <= 0 or keep_prob > 1:
@ -52,6 +54,7 @@ class Dropout(nn.Cell):
self.cast = P.Cast()
else:
self.dropout = P.Dropout(keep_prob)
def construct(self, x):
r"""
Input: a tensor
@ -83,10 +86,12 @@ class Dropout(nn.Cell):
else:
self.dropout.shard(strategy)
class LayerNorm(nn.Cell):
r"""
A self-defined layer norm operation using reduce sum and reduce mean
"""
def __init__(self, normalized_shape, dp=4, eps=1e-5, parallel_optimizer=False):
super(LayerNorm, self).__init__()
self.gamma = Parameter(initializer('ones', normalized_shape), name="gamma",
@ -102,6 +107,7 @@ class LayerNorm(nn.Cell):
self.add2 = P.TensorAdd().shard(((dp, 1, 1), (1,)))
self.real_div = P.RealDiv().shard(((dp, 1, 1), (dp, 1, 1)))
self.eps = eps
def construct(self, x):
mean = self.mean(x, -1)
diff = self.sub1(x, mean)
@ -111,6 +117,7 @@ class LayerNorm(nn.Cell):
output = self.add2(self.mul(output, self.gamma), self.beta)
return output
class Mapping(nn.Cell):
"""
A mapping function with a 3d input
@ -167,6 +174,7 @@ class MappingOutput(nn.Cell):
Returns:
output: Tensor, a 3d tensor after projection
"""
def __init__(self, config, input_size, output_size, scale=1.0):
super(MappingOutput, self).__init__()
self.output_size = output_size
@ -205,6 +213,7 @@ class FeedForwardLayer(nn.Cell):
Returns:
output: Tensor, the output of this layer after mapping
"""
def __init__(self, config, scale=1.0):
super(FeedForwardLayer, self).__init__()
input_size = config.embedding_size
@ -226,6 +235,7 @@ class FeedForwardLayer(nn.Cell):
output = self.dropout(output)
return output
class EmbeddingLookup(nn.Cell):
"""
The embedding lookup table for vocabulary
@ -236,6 +246,7 @@ class EmbeddingLookup(nn.Cell):
seq_length, embedding_size)
self.embedding_table: Tensor, the embedding table for the vocabulary
"""
def __init__(self):
super(EmbeddingLookup, self).__init__()
self.gather = P.GatherV2()
@ -244,6 +255,7 @@ class EmbeddingLookup(nn.Cell):
output = self.gather(table, input_ids, 0)
return output
class Attention(nn.Cell):
"""
Self-Attention module for each layer
@ -253,6 +265,7 @@ class Attention(nn.Cell):
scale: scale factor for initialization
layer_idx: current layer index
"""
def __init__(self, config, scale=1.0, layer_idx=None):
super(Attention, self).__init__()
# Output layer
@ -295,7 +308,6 @@ class Attention(nn.Cell):
self.softmax.softmax.shard(((config.dp, config.mp, 1),))
self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))
dense_shape = [config.embedding_size, config.embedding_size]
bias_shape = [config.embedding_size]
# Query
@ -399,7 +411,7 @@ class Attention(nn.Cell):
# The first graph with the input size of (bs, seq_length)
if self.is_first_iteration:
# Get the valid input length without padding
valid_length_vector = F.cast(self.less(self.range, batch_valid_length), self.dtype)
valid_length_vector = F.cast(self.less(self.range, batch_valid_length.view(1, 1, -1)), self.dtype)
# Cover the key and value numbers corresponding to the padding position
key_present = self.mul1(key, self.expand_dims(valid_length_vector, 2))
value_present = self.mul1(value, self.expand_dims(valid_length_vector, 3))
@ -464,7 +476,7 @@ class Attention(nn.Cell):
x_merge: the 3d output
"""
x = self.merger_head_transpose(
x, (0, 2, 1, 3)) #bs, seq_length, head, size_per_head
x, (0, 2, 1, 3)) # bs, seq_length, head, size_per_head
x_shape = P.Shape()(x)
new_shape = x_shape[:-2] + (x_shape[-2] * x_shape[-1],)
x_merge = self.reshape(x, new_shape)
@ -551,6 +563,7 @@ class Decoder(nn.Cell):
output: Tensor, the output logit of this layer
layer_present: Tensor, the feature map of current layer
"""
def __init__(self, config, layer_idx):
super(Decoder, self).__init__()
scale = 1 / math.sqrt(2.0 * config.num_layers)
@ -633,10 +646,12 @@ class Decoder(nn.Cell):
output = self.add(x, mlp_logit)
return output, layer_present
class Embedding(nn.Cell):
"""
Embedding
"""
def __init__(self, config):
super(Embedding, self).__init__()
self.word_embedding = EmbeddingLookup().set_comm_fusion(1)
@ -691,18 +706,22 @@ class Mask(nn.Cell):
"""
Mask
"""
def __init__(self, config):
super(Mask, self).__init__()
self.dtype = config.compute_dtype
self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))
def construct(self, attention_mask):
attention_mask = self.expand_dims(attention_mask, 1)
return attention_mask
class QueryLayerAttention(Attention):
r"""
Self-Attention module using input query vector.
"""
def construct(self, x, query_hidden_state, attention_mask, key_past=None, value_past=None, batch_valid_length=None):
original_shape = F.shape(x)
x = F.reshape(x, (-1, original_shape[-1]))
@ -732,7 +751,7 @@ class QueryLayerAttention(Attention):
# The first graph with the input size of (bs, seq_length)
if self.is_first_iteration:
# Get the valid input length without padding
valid_length_vector = F.cast(self.less(self.range, batch_valid_length), self.dtype)
valid_length_vector = F.cast(self.less(self.range, batch_valid_length.view(1, 1, -1)), self.dtype)
# Cover the key and value numbers corresponding to the padding position
key_present = self.mul1(key, self.expand_dims(valid_length_vector, 2))
value_present = self.mul1(value, self.expand_dims(valid_length_vector, 3))
@ -767,11 +786,13 @@ class QueryLayerAttention(Attention):
output = self.dropout(output)
return output, layer_present
class QueryLayer(nn.Cell):
r"""
A block usingooked out position embedding as query vector.
This is used as the final block.
"""
def __init__(self, config):
super(QueryLayer, self).__init__()
scale = 1 / math.sqrt(2.0 * config.num_layers)
@ -854,6 +875,7 @@ class QueryLayer(nn.Cell):
output = self.add(x, mlp_logit)
return output, layer_present
class PanguAlphaEmbedding(nn.Cell):
"""
Input embedding, i.e., word embedding and position embedding
@ -870,6 +892,7 @@ class PanguAlphaEmbedding(nn.Cell):
attention_mask: Tensor, attention_mask matrix
embedding_table: Tensor, embedding_table with shape of (vocab_size, embedding_size)
"""
def __init__(self, config):
super(PanguAlphaEmbedding, self).__init__()
self.embedding = Embedding(config)
@ -885,6 +908,7 @@ class PanguAlphaEmbedding(nn.Cell):
attention_mask = self.mask(attention_mask)
return hidden_states, attention_mask
class PanguAlpha_Model(nn.Cell):
"""
The backbone of PanguAlpha network
@ -899,6 +923,7 @@ class PanguAlpha_Model(nn.Cell):
present_layer: Tensor, the current feature map
embedding_table: Tensor, the embedding table for the vocabulary
"""
def __init__(self, config):
super(PanguAlpha_Model, self).__init__()
self.embedding = PanguAlphaEmbedding(config)
@ -971,13 +996,14 @@ class PanguAlpha_Model(nn.Cell):
hidden_states, attention_mask = self.embedding(input_ids, input_mask, table,
input_position, attention_mask,
batch_valid_length)
for i in range(self.num_layers-1):
for i in range(self.num_layers - 1):
hidden_states, _ = self.blocks[i](hidden_states,
attention_mask, init_reset, batch_valid_length)
if self.is_pipeline:
top_query_hidden_states = self.top_query_embedding(input_position.view(-1,), self.top_query_embedding_table)
hidden_states, _ = self.blocks[self.num_layers-1](hidden_states, top_query_hidden_states,
attention_mask, init_reset, batch_valid_length)
top_query_hidden_states = self.top_query_embedding(input_position.view(-1,),
self.top_query_embedding_table)
hidden_states, _ = self.blocks[self.num_layers - 1](hidden_states, top_query_hidden_states,
attention_mask, init_reset, batch_valid_length)
output_state = self.layernorm(hidden_states)
output_state = F.cast(output_state, self.dtype)
else:
@ -988,6 +1014,7 @@ class PanguAlpha_Model(nn.Cell):
attention_mask, init_reset, batch_valid_length)
return output_state
class PanguAlpha_Head(nn.Cell):
"""
Head for PanguAlpha to get the logits of each token in the vocab
@ -999,6 +1026,7 @@ class PanguAlpha_Head(nn.Cell):
Returns:
logits: Tensor, the logits of the corresponding inputs
"""
def __init__(self, config):
super(PanguAlpha_Head, self).__init__()
if config.word_emb_dp:
@ -1029,6 +1057,7 @@ class PanguAlpha(nn.Cell):
Returns:
logits: Tensor: the logits of the corresponding inputs with shape (batch_size, seq_length, vocab_size)
"""
def __init__(self, config):
super(PanguAlpha, self).__init__()
# Network head to get logits over vocabulary
@ -1065,6 +1094,7 @@ class PanguAlpha(nn.Cell):
logits = self.head(output_states, self.embedding_table)
return logits
class CrossEntropyLoss(nn.Cell):
"""
Calculate the cross entropy loss
@ -1077,6 +1107,7 @@ class CrossEntropyLoss(nn.Cell):
Returns:
loss: Tensor, the corrsponding cross entropy loss
"""
def __init__(self, config):
super(CrossEntropyLoss, self).__init__()
self.mean = P.ReduceMean()
@ -1137,6 +1168,7 @@ class CrossEntropyLoss(nn.Cell):
loss = self.div2(numerator, denominator)
return loss
class PanguAlphaWithLoss(nn.Cell):
"""
PanguAlpha training loss
@ -1150,6 +1182,7 @@ class PanguAlphaWithLoss(nn.Cell):
Returns:
output: Tensor, the loss of the network
"""
def __init__(self, config, network, loss, eos_token=6):
super(PanguAlphaWithLoss, self).__init__(auto_prefix=False)
self.network = network
@ -1172,6 +1205,40 @@ class PanguAlphaWithLoss(nn.Cell):
output = self.loss(logits, labels, input_mask)
return output
class AttentionMask(nn.Cell):
"""
Get the attention matrix for self-attention module
Args:
seq_length: the pre-defined sequence length
Inputs:
input_mask: the mask indicating whether each position is a valid input
Returns:
attention_mask: the attention mask matrix with shape (batch_size, seq_length, seq_length)
"""
def __init__(self, seq_length):
super(AttentionMask, self).__init__()
self.reshape = P.Reshape()
self.mul = P.BatchMatMul().shard(
((1, 1, 1), (1, 1, 1)))
self.expand_dim = P.ExpandDims().shard(((1, 1),))
ones = np.ones(shape=(seq_length, seq_length))
self.lower_triangle_mask = Tensor(np.tril(ones), mstype.float32)
self.multiply = P.Mul().shard(((1, 1, 1), (1, 1, 1)))
def construct(self, input_mask):
input_shape = P.Shape()(input_mask)
shape_right = (input_shape[0], 1, input_shape[1])
shape_left = input_shape + (1,)
mask_left = self.reshape(input_mask, shape_left)
mask_right = self.reshape(input_mask, shape_right)
attention_mask = self.mul(mask_left, mask_right)
lower_triangle = self.expand_dim(self.lower_triangle_mask, 0)
attention_mask = self.multiply(attention_mask, lower_triangle)
return attention_mask
class EvalNet(nn.Cell):
"""
PanguAlpha evaluation net
@ -1185,7 +1252,8 @@ class EvalNet(nn.Cell):
Returns:
outputs: Tensor, corresponding output for different tasks
"""
def __init__(self, backbone, generate=False, pad_token=6):
def __init__(self, backbone, generate=False, pad_token=6, seq_length=1024):
super(EvalNet, self).__init__(auto_prefix=False)
self.backbone = backbone
self.pad_token = pad_token
@ -1194,14 +1262,19 @@ class EvalNet(nn.Cell):
self.topk = P.TopK(sorted=True).shard(((1, 1),))
self.gather = P.GatherV2().shard(((1, 1), (1,)))
self.log_softmax = P.LogSoftmax().shard(((1, 1, 1),))
self.get_attention_mask = AttentionMask(seq_length)
def construct(self, input_ids, current_index, init_reset=True, batch_valid_length=None):
"""evaluation net"""
input_mask = F.cast(F.not_equal(input_ids, self.pad_token), mstype.float32)
logits = self.backbone(input_ids, input_mask)
bs, seq_length = F.shape(input_ids)
attention_mask = self.get_attention_mask(input_mask)
input_position = F.tuple_to_array(F.make_range(seq_length))
input_position = P.Tile()(input_position, (bs, 1))
logits = self.backbone(input_ids, input_mask, input_position, attention_mask,
init_reset, batch_valid_length)
index = current_index.view(1,)
logits = self.gather(logits, index, 0)
bs, _ = F.shape(input_ids)
logits = logits.view(bs, 1, -1)
log_probs = self.log_softmax(logits)
return log_probs

View File

@ -71,12 +71,14 @@ class FP32StateAdamWeightDecay(AdamWeightDecay):
get_square_sum = C.MultitypeFuncGraph("get_square_sum")
@get_square_sum.register("Tensor", "Number")
def _get_square_sum(grad, value):
norm = P.ReduceSum(False)(F.square(grad), ()) / value
norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
return norm
apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
@ -85,6 +87,7 @@ def _apply_global_norm(clip_norm, global_norm, grad):
grad = grad * clip_norm / global_norm
return grad
def _get_model_parallel_group(mp):
"""
@ -99,11 +102,12 @@ def _get_model_parallel_group(mp):
local_stage_rank_id = rank % per_stage_device_nums
index = local_stage_rank_id // mp
group = range(0, mp)
rank_str_list = [str(x + index * mp + stage_id * per_stage_device_nums) for x in group]
rank_str_list = [str(x + index * mp + stage_id * per_stage_device_nums) for x in group]
rank_list_str = "-".join(rank_str_list)
rank_list = [x + index * mp + stage_id * per_stage_device_nums for x in group]
return rank_list, rank_list_str
def _get_pipeline_group():
"""
@ -121,10 +125,12 @@ def _get_pipeline_group():
rank_list_str = "-".join(rank_str_list)
return rank_list, rank_list_str
class GlobalNorm(nn.Cell):
"""
Calculate the global norm value of given tensors
"""
def __init__(self, params, config):
super(GlobalNorm, self).__init__()
self.norm = nn.Norm()
@ -165,12 +171,14 @@ class GlobalNorm(nn.Cell):
global_norms = F.sqrt(P.AllReduce()(square_reduce_sum))
return global_norms
class ClipByGlobalNorm(nn.Cell):
"""
Clip grads by global norm
"""
def __init__(self, params, config, clip_norm=1.0):
super(ClipByGlobalNorm, self).__init__()
self.global_norm = GlobalNorm(params, config)
@ -185,6 +193,7 @@ class ClipByGlobalNorm(nn.Cell):
grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
return grads, global_norm_value
class LearningRate(LearningRateSchedule):
"""
Warmup-decay learning rate for PanguAlpha network.
@ -259,6 +268,11 @@ def add_inference_params(opt):
type=int,
default=0,
help="Whether use pynative op for postproecess")
opt.add_argument("--use_past",
type=str,
default="true",
choices=["true", "false"],
help="Whether enable state reuse")
def add_training_params(opt):