!19622 Serving, pangu alpha modelzoo
Merge pull request !19622 from 徐永飞/master
This commit is contained in:
commit
b4c04ef3a8
|
@ -1,4 +1,3 @@
|
|||
|
||||
# It is still under development
|
||||
|
||||
# Contents
|
||||
|
@ -185,7 +184,7 @@ bash scripts/run_distribute_train_incremental_train.sh DATASET RANK_TABLE 8 fp32
|
|||
Please refer to the [website](https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha) to download the following parts:
|
||||
|
||||
- tokenizer: vocab.txt and vocab.model
|
||||
- checkpoint file: \*.part\[0-4\] and *.npy under the same parameter size
|
||||
- checkpoint file: \*.part\[0-4\] (need to extract) and *.npy under the same parameter size
|
||||
- strategy file: a file described how the parameters are sliced across different devices.
|
||||
|
||||
Here we suppose the downloaded checkpoint, tokenizer and strategy file is organized as follows:
|
||||
|
@ -204,6 +203,12 @@ ckpts
|
|||
└── vocab10.vocab
|
||||
```
|
||||
|
||||
We provide two predict methods. The first one is the normal way which needs to pad the input to a certain length every
|
||||
iteration. Due to the redundant calculation, the latency of this method is quite high and to accelerate the speed
|
||||
performance, we provide the second state reuse (incremental inference) method.
|
||||
|
||||
The state reuse method is the default mode, and you can disable it by changing the argument 'use_past' to False.
|
||||
|
||||
### Run Prediction on Distributed mode
|
||||
|
||||
The following script will run prediction on 8 Ascend cards.
|
||||
|
@ -217,28 +222,217 @@ ${FILE_PATH}/tokenizer/ ${FILE_PATH}/checkpoint_file filitered 2.6B fp32
|
|||
### Run Prediction Using One Device
|
||||
|
||||
The following script will run prediction on 1 Ascend cards. The difference is the net is initialized with float16 type.
|
||||
And the rank_table should be configured to one device.
|
||||
|
||||
```bash
|
||||
$FILE_PATH=/home/your_path/ckpts
|
||||
bash scripts/run_distribute_predict.sh 1 /home/config/rank_table_1p.json ${FILE_PATH}/strategy_load_ckpt/strategy.ckpt \
|
||||
${FILE_PATH}/tokenizer/ ${FILE_PATH}/checkpoint_file filitered 2.6B fp16
|
||||
bash scripts/run_standalone_predict.sh ${FILE_PATH}/strategy_load_ckpt/strategy.ckpt \
|
||||
${FILE_PATH}/tokenizer/ ${FILE_PATH}/checkpoint_file filitered 2.6B
|
||||
```
|
||||
|
||||
### Run Serving
|
||||
|
||||
In directory serving:
|
||||
#### Preparation
|
||||
|
||||
- Use scripts/run_distribute_export.sh to export MindIR models, and copy all device* to serving_increment/models/.
|
||||
- Download [PanGu-Alpha tokenizer repository](https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha.git) and copy pangu-alpha/tokenizer to directory pangu/tokenizer.
|
||||
- Pip install MindSpore and MindSpore Serving 1.2 whl package.
|
||||
- Pip install flask, flask-apscheduler, jieba, sentencepiece whl package.
|
||||
- Edit server_agent.py and update the path of pangu-alpha models.
|
||||
- Run 'bash start_pangu.sh' to start new execution.
|
||||
- Wait for serving to start successfully: observe the serving_server.log file until the message "Serving: gRPC server start success, listening on 127.0.0.1:5500" is output.
|
||||
- If any error happened, log can be viewed in serving_server.log, serving_agent.log and flask.log.
|
||||
- If anything all right, access address {ip}:5000 in one browser.
|
||||
- Run 'bash stop_pangu.sh' to stop the existing execution.
|
||||
- Pip install MindSpore and MindSpore Serving 1.3 or later.
|
||||
- Pip install flask, flask-apscheduler, jieba, sentencepiece and other whl package if needed.
|
||||
- Download [PanGu-Alpha repository](https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha), we will need
|
||||
`pangu-alpha/strategy_load_ckpt` and `pangu-alpha/tokenizer` in the following process.
|
||||
- Download 13B or 2.6B checkpoint files and `*embedding` files
|
||||
from [PanGu-Alpha repository](https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha).
|
||||
|
||||
For 13B, we will need `13B_part0` to `13B_part3`, `13B_word_embedding`, `13B_top_query_embedding`
|
||||
, `13B_position_embedding`.
|
||||
|
||||
For 2.6B, we will need `2.6B_part0` to `2.6B_part3`, `13B_word_embedding`, `2.6B_top_query_embedding`
|
||||
, `2.6B_position_embedding`.
|
||||
|
||||
Decompress all the `13B_part*` or `2.6B_part*` tar files and a large number of `*ckpt` files will generate. Move
|
||||
all `*embedding` to the same directory of `*.ckpt` files.
|
||||
|
||||
#### Run 13B or 2.6B in standalone mode[Ascend910/Nvidia GPU]
|
||||
|
||||
- Use scripts/run_standalone_export.sh to export MindIR models, and move all device_0/* to
|
||||
'serving_increment/pangu_standalone/pangu/1/'.
|
||||
|
||||
```shell
|
||||
>>> cd scripts
|
||||
>>> bash run_standalone_export.sh ${strategy_file_path} ${ckpt_dir_path}
|
||||
```
|
||||
|
||||
Update the parameter `MODE` in `run_standalone_export.sh` from `13B` to `2.6B` if we want to export 2.6B model.
|
||||
|
||||
Update the parameter `DEVICE_TARGET` in `run_standalone_export.sh` from `Ascend` to `GPU` when running in GPU environment.
|
||||
|
||||
The `${strategy_file_path}` is file path of `pangu-alpha/strategy_load_ckpt/angu_alpha_13B_cktp_strategy.ckpt` for 13B
|
||||
and `pangu-alpha/strategy_load_ckpt/angu_alpha_2.6B_cktp_strategy.ckpt` for 2.6B.
|
||||
|
||||
The `${ckpt_dir_path}` is the directory the `*ckpt` files generated by decompression and `*embedding` files.
|
||||
|
||||
The model will be exported for some minutes. Check log device_0/log0.log, confirm that there is no exception at last.
|
||||
Confirm that mindir files have been generated in device_0/ which means that the model is exported successfully.
|
||||
|
||||
```shell
|
||||
>>> ls device_0
|
||||
pangu_alpha_1024_graph.mindir pangu_alpha_1024_variables pangu_alpha_1_graph.mindir pangu_alpha_1_variables
|
||||
>>> cd - && mkdir serving_increment/pangu_standalone/pangu/1/
|
||||
>>> mv scripts/device_0/* serving_increment/pangu_standalone/pangu/1/
|
||||
>>> cd serving_increment
|
||||
```
|
||||
|
||||
- Copy `pangu-alpha/tokenizer` to directory serving_increment/pangu_standalone/pangu/tokenizer.
|
||||
|
||||
The directory hierarchy of the required files is shown below. The pangu_alpha_1024_variables and pangu_alpha_1_variables are folded for easy display.
|
||||
|
||||
```shell
|
||||
>>> tree pangu_distributed
|
||||
pangu_standalone/
|
||||
├── pangu
|
||||
│ ├── 1
|
||||
│ │ ├── pangu_alpha_1024_graph.mindir
|
||||
│ │ ├── pangu_alpha_1024_variables/
|
||||
│ │ ├── pangu_alpha_1_graph.mindir
|
||||
│ │ └── pangu_alpha_1_variables/
|
||||
│ ├── servable_config.py
|
||||
│ ├── tokenization_jieba.py
|
||||
│ └── tokenizer
|
||||
│ ├── vocab.model
|
||||
│ └── vocab.vocab
|
||||
└── serving_server.py
|
||||
```
|
||||
|
||||
- Run `bash start_pangu_standalone.sh` to start new execution, and wait until the serving and flask server are started
|
||||
successfully.
|
||||
- If any error happened, log can be viewed in serving_server.log, serving_logs/*.log and flask.log.
|
||||
- If anything all right, access address {ip}:5000 in one browser. It will take some time to return the reply.
|
||||
- Run `bash stop_pangu.sh` to stop the existing execution.
|
||||
|
||||
#### Run 13B or 2.6B in distributed mode[Ascend910 8 cards]
|
||||
|
||||
- Generate [rank table file](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
|
||||
|
||||
```shell
|
||||
# mindspore/model_zoo/utils/hccl_tools/hccl_tools.py
|
||||
>>> python3 ../../../utils/hccl_tools/hccl_tools.py --device_num "[0,8]"
|
||||
>>> mv hccl_8p_01234567*.json serving_increment/pangu_distributed/hccl_8p.json
|
||||
```
|
||||
|
||||
- Use scripts/run_distribute_export.sh to export MindIR models, and move all device* to
|
||||
'serving_increment/pangu_distributed/models/'.
|
||||
|
||||
```shell
|
||||
>>> cd scripts
|
||||
>>> bash run_distribute_export.sh ${strategy_file_path} ${ckpt_dir_path}
|
||||
```
|
||||
|
||||
Update the parameter `MODE` in `run_distribute_export.sh` from `13B` to `2.6B` if we want to export 2.6B model.
|
||||
|
||||
The `${strategy_file_path}` is file path of `pangu-alpha/strategy_load_ckpt/angu_alpha_13B_cktp_strategy.ckpt` for 13B
|
||||
and `pangu-alpha/strategy_load_ckpt/angu_alpha_2.6B_cktp_strategy.ckpt` for 2.6B.
|
||||
|
||||
The `${ckpt_dir_path}` is the directory the *ckpt files generated by decompression and *embedding files.
|
||||
|
||||
The model will be exported for some minutes. Check log device_[0-7]/log[0-7].log, confirm that there is no exception at
|
||||
last. Confirm that mindir files have been generated in device_[0-7]/ which means that the model is exported
|
||||
successfully.
|
||||
|
||||
```shell
|
||||
>>> cd - && mkdir serving_increment/pangu_distributed/models/
|
||||
>>> mv scripts/device_* serving_increment/pangu_distributed/models/
|
||||
>>> cd serving_increment
|
||||
```
|
||||
|
||||
- Update MindIR file name serving_increment/pangu_distributed/serving_agent.py if needed.
|
||||
- Copy `pangu-alpha/tokenizer` to directory serving_increment/pangu_distributed/pangu/tokenizer.
|
||||
|
||||
The directory hierarchy of the required files is shown below. The device_1 to device_7 are folded for easy display.
|
||||
|
||||
```shell
|
||||
>>> tree pangu_distributed
|
||||
pangu_distributed/
|
||||
├── hccl_8p.json
|
||||
├── models
|
||||
│ ├── device_0
|
||||
│ │ ├── pangu_alpha_1024_graph.mindir
|
||||
│ │ ├── pangu_alpha_1024_variables
|
||||
│ │ │ ├── data_0
|
||||
│ │ │ ├── data_1
|
||||
│ │ │ ├── data_2
|
||||
│ │ │ ├── data_3
|
||||
│ │ │ └── data_4
|
||||
│ │ ├── pangu_alpha_1_graph.mindir
|
||||
│ │ └── pangu_alpha_1_variables
|
||||
│ │ ├── data_0
|
||||
│ │ ├── data_1
|
||||
│ │ ├── data_2
|
||||
│ │ ├── data_3
|
||||
│ │ └── data_4
|
||||
│ ├── device_1/
|
||||
│ ├── device_2/
|
||||
│ ├── device_3/
|
||||
│ ├── device_4/
|
||||
│ ├── device_5/
|
||||
│ ├── device_6/
|
||||
│ └── device_7/
|
||||
├── pangu
|
||||
│ ├── servable_config.py
|
||||
│ ├── tokenization_jieba.py
|
||||
│ └── tokenizer
|
||||
│ ├── vocab.model
|
||||
│ └── vocab.vocab
|
||||
├── serving_agent.py
|
||||
└── serving_server.py
|
||||
```
|
||||
|
||||
- Run `bash start_pangu_distributed.sh` to start new execution, and wait until the serving and flask server are started
|
||||
successfully.
|
||||
- If any error happened, log can be viewed in serving_server.log, serving_agent.log, serving_logs/*.log and flask.log.
|
||||
- If anything all right, access address {ip}:5000 in one browser. It will take some time to return the reply.
|
||||
- Run `bash stop_pangu.sh` to stop the existing execution.
|
||||
|
||||
#### Run in distributed mode[Ascend910 8 cards * N machine]
|
||||
|
||||
- Generate [rank table file](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
|
||||
- In every machine, prepare for checkpoint files and embedding files. We can also use 13B as a test example.
|
||||
- In every machine, use scripts/run_cluster_export.sh to export MindIR models, and move all device* to
|
||||
'serving_increment/pangu_distributed/models/'.
|
||||
|
||||
```shell
|
||||
>>> cd scripts
|
||||
>>> bash run_cluster_export.sh ${strategy_file_path} ${ckpt_dir_path} ${rank_table_file} ${rank_size} ${rank_start}
|
||||
```
|
||||
|
||||
Update the parameter `MODE` in `run_distribute_export.sh` from `200B` to `13B` if we want to export 13B model.
|
||||
|
||||
The `${rank_start}` is the first rank id in every machine, likes 0,8,16,24.
|
||||
|
||||
The model will be exported for some minutes. Check log device_[0-7]/log[0-7].log, confirm that there is no exception at
|
||||
last. Confirm that mindir files have been generated in device_[0-7]/ which means that the model is exported
|
||||
successfully.
|
||||
|
||||
```shell
|
||||
>>> cd - && mkdir serving_increment/pangu_distributed/models/
|
||||
>>> mv scripts/device_* serving_increment/pangu_distributed/models/
|
||||
>>> cd serving_increment
|
||||
```
|
||||
|
||||
- In the first machine, update the parameter `rank_size` and `stage_size`(Pipeline stage size) of `serving_increment/pangu_distributed/pangu/servable_config.py`.
|
||||
- In the first machine, update the parameter `rank_table_json_file` of `serving_increment/pangu_distributed/serving_server.py`.
|
||||
- In every machine, update MindIR file name serving_increment/pangu_distributed/serving_agent.py if needed.
|
||||
- In every machine, update the parameter `distributed_address` of `serving_increment/pangu_distributed/serving_agent.py` and
|
||||
`serving_increment/pangu_distributed/serving_server.py` to the first machine ip address.
|
||||
- In the first machine, copy `pangu-alpha/tokenizer` to directory serving_increment/pangu_distributed/pangu/tokenizer.
|
||||
- In the first machine, run `bash start_pangu_distributed.sh` to start new execution.
|
||||
- Meanwhile, in other machines, run `python serving_agent.py` to start serving agent process.
|
||||
|
||||
```shell
|
||||
>>> unset http_proxy && unset https_proxy
|
||||
>>> python pangu_distributed/serving_agent.py > serving_agent.log 2>&1 &
|
||||
```
|
||||
|
||||
- Wait until the serving and flask server are started successfully.
|
||||
- If any error happened, log can be viewed in serving_server.log, serving_agent.log, serving_logs/*.log and flask.log.
|
||||
- If anything all right, access address {first_machine_ip}:5000 in one browser. It will take some time to return the reply.
|
||||
- Run `bash stop_pangu.sh` to stop the existing execution in every machine.
|
||||
|
||||
# [Script Description](#contents)
|
||||
|
||||
|
|
|
@ -63,8 +63,12 @@ def load_model(args_opt):
|
|||
else:
|
||||
rank = 0
|
||||
device_num = 1
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(
|
||||
strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path)
|
||||
|
||||
use_past = (args_opt.use_past == "true")
|
||||
print('local_rank:{}, start to run...'.format(rank), flush=True)
|
||||
use_past = False
|
||||
if args_opt.export:
|
||||
use_past = True
|
||||
# Set model property
|
||||
|
@ -72,6 +76,9 @@ def load_model(args_opt):
|
|||
data_parallel_num = int(device_num / model_parallel_num)
|
||||
per_batch_size = args_opt.per_batch_size
|
||||
batch_size = per_batch_size * data_parallel_num
|
||||
# Now only support single batch_size for predict
|
||||
if args_opt.run_type == "predict":
|
||||
batch_size = 1
|
||||
config = PANGUALPHAConfig(
|
||||
data_parallel_num=data_parallel_num,
|
||||
model_parallel_num=model_parallel_num,
|
||||
|
@ -105,11 +112,15 @@ def load_model(args_opt):
|
|||
inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32)
|
||||
current_index = Tensor(np.array([0]), mstype.int32)
|
||||
|
||||
if config.use_past:
|
||||
if args_opt.distribute == "false":
|
||||
predict_layout = None
|
||||
elif config.use_past:
|
||||
batch_valid_length = Tensor(np.array([0]), mstype.int32)
|
||||
init_true = Tensor([True], mstype.bool_)
|
||||
inputs_np_1 = Tensor(np.ones(shape=(config.batch_size, 1)), mstype.int32)
|
||||
model_predict.predict_network.add_flags_recursive(is_first_iteration=True)
|
||||
predict_layout = model_predict.infer_predict_layout(inputs_np, current_index, init_true, batch_valid_length)
|
||||
model_predict.predict_network.add_flags_recursive(is_first_iteration=False)
|
||||
_ = model_predict.infer_predict_layout(inputs_np_1, current_index, init_true, batch_valid_length)
|
||||
else:
|
||||
predict_layout = model_predict.infer_predict_layout(inputs_np, current_index)
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
execute_path=$(pwd)
|
||||
script_self=$(readlink -f "$0")
|
||||
self_path=$(dirname "${script_self}")
|
||||
|
||||
MODE=200B # or 13B
|
||||
PARAM_INIT_TYPE=fp16
|
||||
STRATEGY=$1
|
||||
CKPT_PATH=$2
|
||||
export RANK_TABLE_FILE=$3
|
||||
export RANK_SIZE=$4
|
||||
RANK_START=$5 # 0,8,16,... for each machine
|
||||
CKPT_NAME='filerted'
|
||||
|
||||
for((i=0;i<8;i++));
|
||||
do
|
||||
rm -rf ${execute_path}/device_$i/
|
||||
mkdir ${execute_path}/device_$i/
|
||||
cd ${execute_path}/device_$i/ || exit
|
||||
export RANK_ID=$(($RANK_START+$i))
|
||||
export DEVICE_ID=$i
|
||||
python -s ${self_path}/../predict.py --strategy_load_ckpt_path=$STRATEGY --load_ckpt_path=$CKPT_PATH \
|
||||
--load_ckpt_name=$CKPT_NAME --mode=$MODE --run_type=predict --param_init_type=$PARAM_INIT_TYPE \
|
||||
--export=1 >log$i.log 2>&1 &
|
||||
done
|
|
@ -18,12 +18,12 @@ execute_path=$(pwd)
|
|||
script_self=$(readlink -f "$0")
|
||||
self_path=$(dirname "${script_self}")
|
||||
export RANK_SIZE=8
|
||||
export RANK_TABLE_FILE=${execute_path}/../serving_increment/hccl_8p.json
|
||||
export RANK_TABLE_FILE=${execute_path}/../serving_increment/pangu_distributed/hccl_8p.json
|
||||
export MODE=13B
|
||||
export PARAM_INIT_TYPE=fp16
|
||||
export STRATEGY=$1
|
||||
export CKPT_PATH=$2
|
||||
export CKPT_NAME=$3
|
||||
export PARAM_INIT_TYPE=$4
|
||||
export CKPT_NAME='filerted'
|
||||
|
||||
for((i=0;i<$RANK_SIZE;i++));
|
||||
do
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
execute_path=$(pwd)
|
||||
script_self=$(readlink -f "$0")
|
||||
self_path=$(dirname "${script_self}")
|
||||
export RANK_SIZE=1
|
||||
export MODE=13B # or 2.6B
|
||||
export PARAM_INIT_TYPE=fp16
|
||||
export STRATEGY=$1
|
||||
export CKPT_PATH=$2
|
||||
export DEVICE_TARGET=Ascend # or GPU
|
||||
export CKPT_NAME='filerted'
|
||||
|
||||
for((i=0;i<$RANK_SIZE;i++));
|
||||
do
|
||||
rm -rf ${execute_path}/device_$i/
|
||||
mkdir ${execute_path}/device_$i/
|
||||
cd ${execute_path}/device_$i/ || exit
|
||||
export RANK_ID=$i
|
||||
export DEVICE_ID=$i
|
||||
python -s ${self_path}/../predict.py --strategy_load_ckpt_path=$STRATEGY --load_ckpt_path=$CKPT_PATH \
|
||||
--load_ckpt_name=$CKPT_NAME --mode=$MODE --run_type=predict --param_init_type=$PARAM_INIT_TYPE \
|
||||
--export=1 --distribute=false --device_target=$DEVICE_TARGET >log$i.log 2>&1 &
|
||||
done
|
|
@ -0,0 +1,39 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
execute_path=$(pwd)
|
||||
script_self=$(readlink -f "$0")
|
||||
self_path=$(dirname "${script_self}")
|
||||
export RANK_SIZE=1
|
||||
export STRATEGY=$1
|
||||
export TOKENIZER=$2
|
||||
export CKPT_PATH=$3
|
||||
export CKPT_NAME=$4
|
||||
export MODE=$5
|
||||
export PARAM_INIT_TYPE=fp16
|
||||
export DEVICE_TARGET=Ascend # or GPU
|
||||
|
||||
for((i=0;i<$RANK_SIZE;i++));
|
||||
do
|
||||
rm -rf ${execute_path}/device_$i/
|
||||
mkdir ${execute_path}/device_$i/
|
||||
cd ${execute_path}/device_$i/ || exit
|
||||
export RANK_ID=$i
|
||||
export DEVICE_ID=$i
|
||||
python -s ${self_path}/../predict.py --strategy_load_ckpt_path=$STRATEGY --tokenizer_path=$TOKENIZER --load_ckpt_path=$CKPT_PATH \
|
||||
--load_ckpt_name=$CKPT_NAME --mode=$MODE --run_type=predict --param_init_type=$PARAM_INIT_TYPE \
|
||||
--distribute=false --device_target=$DEVICE_TARGET >log$i.log 2>&1 &
|
||||
done
|
|
@ -1,109 +0,0 @@
|
|||
{
|
||||
"board_id": "0x0020",
|
||||
"chip_info": "910",
|
||||
"deploy_mode": "lab",
|
||||
"group_count": "1",
|
||||
"group_list": [
|
||||
{
|
||||
"device_num": "8",
|
||||
"server_num": "1",
|
||||
"group_name": "",
|
||||
"instance_count": "8",
|
||||
"instance_list": [
|
||||
{
|
||||
"devices": [
|
||||
{
|
||||
"device_id": "0",
|
||||
"device_ip": "192.98.92.121"
|
||||
}
|
||||
],
|
||||
"rank_id": "0",
|
||||
"server_id": "127.0.0.1"
|
||||
},
|
||||
{
|
||||
"devices": [
|
||||
{
|
||||
"device_id": "1",
|
||||
"device_ip": "192.98.93.121"
|
||||
}
|
||||
],
|
||||
"rank_id": "1",
|
||||
"server_id": "127.0.0.1"
|
||||
},
|
||||
{
|
||||
"devices": [
|
||||
{
|
||||
"device_id": "2",
|
||||
"device_ip": "192.98.94.121"
|
||||
}
|
||||
],
|
||||
"rank_id": "2",
|
||||
"server_id": "127.0.0.1"
|
||||
},
|
||||
{
|
||||
"devices": [
|
||||
{
|
||||
"device_id": "3",
|
||||
"device_ip": "192.98.95.121"
|
||||
}
|
||||
],
|
||||
"rank_id": "3",
|
||||
"server_id": "127.0.0.1"
|
||||
},
|
||||
{
|
||||
"devices": [
|
||||
{
|
||||
"device_id": "4",
|
||||
"device_ip": "192.98.92.122"
|
||||
}
|
||||
],
|
||||
"rank_id": "4",
|
||||
"server_id": "127.0.0.1"
|
||||
},
|
||||
{
|
||||
"devices": [
|
||||
{
|
||||
"device_id": "5",
|
||||
"device_ip": "192.98.93.122"
|
||||
}
|
||||
],
|
||||
"rank_id": "5",
|
||||
"server_id": "127.0.0.1"
|
||||
},
|
||||
{
|
||||
"devices": [
|
||||
{
|
||||
"device_id": "6",
|
||||
"device_ip": "192.98.94.122"
|
||||
}
|
||||
],
|
||||
"rank_id": "6",
|
||||
"server_id": "127.0.0.1"
|
||||
},
|
||||
{
|
||||
"devices": [
|
||||
{
|
||||
"device_id": "7",
|
||||
"device_ip": "192.98.95.122"
|
||||
}
|
||||
],
|
||||
"rank_id": "7",
|
||||
"server_id": "127.0.0.1"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"para_plane_nic_location": "device",
|
||||
"para_plane_nic_name": [
|
||||
"eth0",
|
||||
"eth1",
|
||||
"eth2",
|
||||
"eth3",
|
||||
"eth4",
|
||||
"eth5",
|
||||
"eth6",
|
||||
"eth7"
|
||||
],
|
||||
"para_plane_nic_num": "8",
|
||||
"status": "completed"
|
||||
}
|
|
@ -22,8 +22,8 @@ def start():
|
|||
"""Start agents to load and execute models of pangu alpha"""
|
||||
model_files = []
|
||||
for i in range(8):
|
||||
model_files.append([f"models/device{i}/pangu_alpha_1024_graph.mindir",
|
||||
f"models/device{i}/pangu_alpha_1_graph.mindir"])
|
||||
model_files.append([f"models/device_{i}/pangu_alpha_1024_graph.mindir",
|
||||
f"models/device_{i}/pangu_alpha_1_graph.mindir"])
|
||||
distributed.startup_agents(distributed_address="0.0.0.0:6200", model_files=model_files)
|
||||
|
||||
|
|
@ -21,8 +21,7 @@ from mindspore_serving.server import distributed
|
|||
|
||||
def start():
|
||||
"""Start server to serve service, and manage all agents which load and execute models"""
|
||||
servable_dir = os.path.abspath(".")
|
||||
|
||||
servable_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
distributed.start_servable(servable_dir, "pangu", rank_table_json_file="hccl_8p.json",
|
||||
distributed_address="0.0.0.0:6200")
|
||||
|
|
@ -0,0 +1,203 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""servable config for pangu alpha"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from easydict import EasyDict
|
||||
import numpy as np
|
||||
from mindspore_serving.server import register
|
||||
|
||||
from pangu.tokenization_jieba import JIEBATokenizer
|
||||
|
||||
cur_dir = os.path.abspath(os.path.dirname(__file__))
|
||||
tokenizer_path = os.path.join(cur_dir, "tokenizer")
|
||||
tokenizer = JIEBATokenizer(os.path.join(tokenizer_path, "vocab.vocab"), os.path.join(tokenizer_path, "vocab.model"))
|
||||
end_token = tokenizer.eot_id
|
||||
|
||||
config = EasyDict({
|
||||
'frequency_penalty': 1.5,
|
||||
'presence_penalty': 0.3,
|
||||
'max_generate_length': 500,
|
||||
'top_k_num': 3,
|
||||
'top_p': 1.0,
|
||||
'end_token': 9,
|
||||
'seq_length': 1024,
|
||||
'vocab_size': 40000,
|
||||
})
|
||||
|
||||
|
||||
def topk_fun(logits, topk=5):
|
||||
"""Get topk"""
|
||||
target_column = logits[0].tolist()
|
||||
sorted_array = [(k, v) for k, v in enumerate(target_column)]
|
||||
sorted_array.sort(key=lambda x: x[1], reverse=True)
|
||||
topk_array = sorted_array[:topk]
|
||||
index, value = zip(*topk_array)
|
||||
index = np.array([index])
|
||||
value = np.array([value])
|
||||
return value, index
|
||||
|
||||
|
||||
register.declare_servable(servable_file=["pangu_alpha_1024_graph.mindir", "pangu_alpha_1_graph.mindir"],
|
||||
model_format="MINDIR", with_batch_dim=False)
|
||||
|
||||
|
||||
@register.register_method(output_names=["logits"])
|
||||
def predict_sub0(input_ids, current_index, init, batch_valid_length):
|
||||
logits = register.call_servable(input_ids, current_index, init, batch_valid_length, subgraph=0)
|
||||
return logits
|
||||
|
||||
|
||||
@register.register_method(output_names=["logits"])
|
||||
def predict_sub1(input_id, current_index, init, batch_valid_length):
|
||||
logits = register.call_servable(input_id, current_index, init, batch_valid_length, subgraph=1)
|
||||
return logits
|
||||
|
||||
|
||||
sub0_servable = register.PipelineServable(servable_name="pangu", method="predict_sub0")
|
||||
sub1_servable = register.PipelineServable(servable_name="pangu", method="predict_sub1")
|
||||
|
||||
|
||||
@register.register_pipeline(output_names=["output_sentence"])
|
||||
def predict(input_sentence):
|
||||
"""generate sentence with given input_sentence"""
|
||||
|
||||
print(f"----------------------------- begin {input_sentence} ---------", flush=True)
|
||||
time_start = time.time()
|
||||
|
||||
tokens = tokenizer.tokenize(input_sentence)
|
||||
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
|
||||
outputs = generate_increment(input_ids)
|
||||
|
||||
return_tokens = tokenizer.convert_ids_to_tokens(outputs)
|
||||
reply = "".join(return_tokens)
|
||||
|
||||
print(f"time cost {(time.time() - time_start) * 1000}ms, request '{input_sentence}' get reply '{reply}'",
|
||||
flush=True)
|
||||
|
||||
return reply
|
||||
|
||||
|
||||
def generate_increment(origin_inputs):
|
||||
"""
|
||||
Text generation for incremental inference
|
||||
|
||||
Inputs:
|
||||
model: the model for inferencing
|
||||
origin_inputs: the original inputs based on which the model will continue writing
|
||||
config: inference configurations
|
||||
|
||||
Returns:
|
||||
outputs: the ids for the generated text
|
||||
"""
|
||||
# Get configurations for inference
|
||||
frequency_penalty = config.frequency_penalty
|
||||
presence_penalty = config.presence_penalty
|
||||
top_p = config.top_p
|
||||
top_k_num = config.top_k_num
|
||||
max_generate_length = config.max_generate_length
|
||||
seq_length = config.seq_length
|
||||
vocab_size = config.vocab_size
|
||||
|
||||
# Init outputs with original inputs
|
||||
outputs = origin_inputs
|
||||
origin_inputs = np.array([origin_inputs])
|
||||
_, valid_length = origin_inputs.shape
|
||||
# If target length exceeds seq_length, use seq_length instead
|
||||
target_length = valid_length + max_generate_length
|
||||
target_length = seq_length if target_length > seq_length else target_length
|
||||
|
||||
# A list of the frequency of each token
|
||||
frequency_list = np.array([[0 for _ in range(vocab_size)]])
|
||||
pad_length = seq_length - origin_inputs.shape[-1]
|
||||
# Pad original inputs to seq_length
|
||||
input_ids = np.pad(origin_inputs, ((0, 0), (0, pad_length)), 'constant', constant_values=(0, 0))
|
||||
|
||||
# Indicate the exact token position
|
||||
current_index = valid_length - 1 if valid_length - 1 > 0 else 0
|
||||
current_index = np.array([current_index], np.int32)
|
||||
batch_valid_length = np.array([current_index], np.int32)
|
||||
# For first graph, not_init should be false
|
||||
init_true = True
|
||||
init_false = False
|
||||
init = init_false
|
||||
# Call a single inference with input size of (bs, seq_length)
|
||||
logits = sub0_servable.run(np.array(input_ids, np.int32), current_index, init, batch_valid_length)
|
||||
|
||||
# Claim the second graph and set not_init to true
|
||||
init = init_true
|
||||
|
||||
# A single loop generates one token, loop until reaching target seq_length or generating eod token
|
||||
while valid_length < target_length:
|
||||
# Reshape the output logits
|
||||
log_probs = logits.reshape(1, vocab_size)
|
||||
|
||||
# Get the revised log_probs considering frequency and presence penalty to eliminate duplicate in generated results
|
||||
log_probs = log_probs.reshape(1, vocab_size)
|
||||
log_probs_revised = log_probs - frequency_list * frequency_penalty - (frequency_list > 0) * presence_penalty
|
||||
|
||||
# Convert the log_probs to probability
|
||||
logits = np.power(10, np.array(log_probs_revised, np.float32))
|
||||
|
||||
# If top_p is less than 1.0, use top_p sampling
|
||||
if top_p < 1.0:
|
||||
# Only consider the 5000 largest logits to reduce computation
|
||||
sorted_logits, index = topk_fun(logits, 5000)
|
||||
cumsum_logits = np.cumsum(sorted_logits, 1)
|
||||
cumsum_logits = cumsum_logits[0]
|
||||
index = index[0]
|
||||
sorted_logits = sorted_logits[0]
|
||||
top_p_num = sum(cumsum_logits > top_p)
|
||||
# In case the probability is smooth, the sum of 5000 largest probabilities are not large enough
|
||||
if top_p_num == 0:
|
||||
top_p_num = 5000
|
||||
# Get the corresponding probs and indices
|
||||
probs = sorted_logits[:top_p_num]
|
||||
p_args = index[:top_p_num]
|
||||
p = probs / sum(probs)
|
||||
# if top_p is set to 1.0, use top_k sampling
|
||||
else:
|
||||
# Get the corresponding probs and indices
|
||||
probs, p_args = topk_fun(logits, top_k_num)
|
||||
probs = probs[0]
|
||||
p_args = p_args[0]
|
||||
# Avoid rounding error
|
||||
if sum(probs) == 0:
|
||||
probs = np.array([1 / top_k_num for _ in range(top_k_num)])
|
||||
p = probs / sum(probs)
|
||||
|
||||
# Random select a token as final output for this round
|
||||
target_index = np.random.choice(len(p), p=p)
|
||||
# Stop judgment
|
||||
if p_args[target_index] == end_token or valid_length == target_length - 1:
|
||||
break
|
||||
|
||||
# Update frequency list
|
||||
target = p_args[target_index]
|
||||
frequency_list[0][target] = frequency_list[0][target] + 1
|
||||
valid_length += 1
|
||||
|
||||
batch_valid_length = np.array([valid_length - 1], np.int32)
|
||||
current_index = np.array([0], np.int32)
|
||||
input_id = np.array([[target]], np.int32)
|
||||
# Update outputs with current generated token
|
||||
outputs.append(int(target))
|
||||
|
||||
# Call a single inference with input size of (bs, 1)
|
||||
logits = sub1_servable.run(input_id, current_index, init, batch_valid_length)
|
||||
# Return valid outputs out of padded outputs
|
||||
return outputs
|
|
@ -0,0 +1,76 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Tokenization classes for OpenAI GPT."""
|
||||
from __future__ import (absolute_import, division, print_function,
|
||||
unicode_literals)
|
||||
|
||||
from io import open
|
||||
import jieba
|
||||
import sentencepiece as spm
|
||||
|
||||
|
||||
class JIEBATokenizer:
|
||||
"""jieba tokenizer for encode and decode text"""
|
||||
|
||||
def __init__(self, vocab_file, model_file, max_len=None):
|
||||
self.max_len = max_len if max_len is not None else int(1e12)
|
||||
# self.encoder = json.load(open(vocab_file))
|
||||
f = open(vocab_file, 'r')
|
||||
lines = f.readlines()
|
||||
self.encoder = {}
|
||||
for line in enumerate(lines):
|
||||
key = line[1].split('\t')[0]
|
||||
self.encoder[key] = line[0]
|
||||
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
|
||||
self.sp = spm.SentencePieceProcessor(model_file=model_file)
|
||||
self.translator = str.maketrans(" \n", "\u2582\u2583")
|
||||
|
||||
self.eod_id = self.encoder['<eod>']
|
||||
self.eot_id = self.encoder['<eot>']
|
||||
self.pad_id = self.encoder['<pad>']
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return len(self.encoder)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.encoder) + len(self.special_tokens)
|
||||
|
||||
@property
|
||||
def eod(self):
|
||||
return self.eod_id
|
||||
|
||||
def tokenize(self, text):
|
||||
""" Tokenize a string. """
|
||||
seg_list = [x.translate(self.translator) for x in jieba.cut(text, cut_all=False)]
|
||||
new_seg = " ".join(seg_list)
|
||||
return self.sp.encode(new_seg)
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
return tokens
|
||||
|
||||
def convert_ids_to_tokens(self, ids):
|
||||
return self.decode(ids)
|
||||
|
||||
def encode(self, text):
|
||||
res = self.tokenize(text)
|
||||
return res
|
||||
|
||||
def decode(self, tokens):
|
||||
text = self.sp.decode(tokens)
|
||||
text = text.replace(' ', '').replace('\u2582', ' ').replace('\u2583', '\n')
|
||||
return text
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Serving server start code, serve service, and manage all agents which load and execute models"""
|
||||
|
||||
import os
|
||||
from mindspore_serving import server
|
||||
|
||||
|
||||
def start():
|
||||
servable_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
config = server.ServableStartConfig(servable_directory=servable_dir, servable_name="pangu", device_ids=(0,))
|
||||
server.start_servables(config)
|
||||
|
||||
server.start_grpc_server("127.0.0.1:5500")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
start()
|
|
@ -21,7 +21,7 @@ export GLOG_v=1
|
|||
start_serving_server()
|
||||
{
|
||||
echo "### start serving server, see serving_server.log for detail ###"
|
||||
python3 serving_server.py > serving_server.log 2>&1 &
|
||||
python3 pangu_distributed/serving_server.py > serving_server.log 2>&1 &
|
||||
if [ $? -ne 0 ]
|
||||
then
|
||||
echo "serving server failed to start."
|
||||
|
@ -52,8 +52,8 @@ start_serving_server()
|
|||
|
||||
start_serving_agent()
|
||||
{
|
||||
echo "### start serving agent, see and serving_logs/log_pangu_distributed.log for detail ###"
|
||||
python3 serving_agent.py > serving_agent.log 2>&1 &
|
||||
echo "### start serving agent, see serving_agent.log and serving_logs/log_pangu_distributed.log for detail ###"
|
||||
python3 pangu_distributed/serving_agent.py > serving_agent.log 2>&1 &
|
||||
if [ $? -ne 0 ]
|
||||
then
|
||||
echo "serving agent failed to start."
|
||||
|
@ -68,7 +68,7 @@ start_serving_agent()
|
|||
if [ $num -eq 0 ]
|
||||
then
|
||||
bash stop_pangu.sh
|
||||
echo "start serving agent failed, see log and serving_logs/log_pangu_distributed.log for more detail" && exit 1
|
||||
echo "start serving agent failed, see serving_agent.log and serving_logs/log_pangu_distributed.log for more detail" && exit 1
|
||||
fi
|
||||
|
||||
count=$(($count+1))
|
||||
|
@ -78,7 +78,7 @@ start_serving_agent()
|
|||
if [ ${count} -eq 1800 ]
|
||||
then
|
||||
bash stop_pangu.sh
|
||||
echo "start serving agent failed, see log and serving_logs/log_pangu_distributed.log for more detail" && exit 1
|
||||
echo "start serving agent failed, see serving_agent.log and serving_logs/log_pangu_distributed.log for more detail" && exit 1
|
||||
fi
|
||||
echo "### start serving agent end ###"
|
||||
}
|
||||
|
@ -131,7 +131,7 @@ wait_serving_ready()
|
|||
if [ $num -eq 0 ]
|
||||
then
|
||||
bash stop_pangu.sh
|
||||
echo "waiting serving server ready failed, see log serving_server.log and serving_logs/log_pangu_distributed.log for more detail" && exit 1
|
||||
echo "waiting serving server ready failed, see serving_server.log and serving_logs/log_pangu_distributed.log for more detail" && exit 1
|
||||
fi
|
||||
|
||||
count=$(($count+1))
|
||||
|
@ -141,7 +141,7 @@ wait_serving_ready()
|
|||
if [ ${count} -eq 100 ]
|
||||
then
|
||||
bash stop_pangu.sh
|
||||
echo "waiting serving server ready failed, see log serving_server.log and serving_logs/log_pangu_distributed.log for more detail" && exit 1
|
||||
echo "waiting serving server ready failed, see serving_server.log and serving_logs/log_pangu_distributed.log for more detail" && exit 1
|
||||
fi
|
||||
echo "### waiting serving server ready end ###"
|
||||
}
|
||||
|
@ -150,5 +150,5 @@ bash stop_pangu.sh
|
|||
rm -rf serving_server.log serving_agent.log flask.log serving_logs
|
||||
start_serving_server
|
||||
start_serving_agent
|
||||
start_flask
|
||||
wait_serving_ready
|
||||
start_flask
|
|
@ -0,0 +1,120 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
unset http_proxy
|
||||
unset https_proxy
|
||||
export GLOG_v=1
|
||||
|
||||
start_serving_server()
|
||||
{
|
||||
echo "### start serving server, see serving_server.log for detail ###"
|
||||
python3 pangu_standalone/serving_server.py > serving_server.log 2>&1 &
|
||||
if [ $? -ne 0 ]
|
||||
then
|
||||
echo "serving server failed to start."
|
||||
fi
|
||||
|
||||
result=`grep -E 'Master server start success' serving_server.log | wc -l`
|
||||
count=0
|
||||
while [[ ${result} -eq 0 && ${count} -lt 100 ]]
|
||||
do
|
||||
sleep 1
|
||||
|
||||
num=`ps -ef | grep 'serving_server.py' | grep -v grep | wc -l`
|
||||
if [ $num -eq 0 ]
|
||||
then
|
||||
echo "start serving server failed, see log serving_server.log for more detail" && exit 1
|
||||
fi
|
||||
|
||||
count=$(($count+1))
|
||||
result=`grep -E 'Master server start success' serving_server.log | wc -l`
|
||||
done
|
||||
|
||||
if [ ${count} -eq 100 ]
|
||||
then
|
||||
echo "start serving server failed, see log serving_server.log for more detail" && exit 1
|
||||
fi
|
||||
echo "### start serving server end ###"
|
||||
}
|
||||
|
||||
start_flask()
|
||||
{
|
||||
echo "### start flask server, see flask.log for detail ###"
|
||||
python3 flask/client.py > flask.log 2>&1 &
|
||||
if [ $? -ne 0 ]
|
||||
then
|
||||
echo "flask server failed to start."
|
||||
fi
|
||||
|
||||
result=`grep -E 'Press CTRL\+C to quit' flask.log | wc -l`
|
||||
count=0
|
||||
while [[ ${result} -ne 1 && ${count} -lt 10 ]]
|
||||
do
|
||||
sleep 1
|
||||
|
||||
num=`ps -ef | grep 'flask/client.py' | grep -v grep | wc -l`
|
||||
if [ $num -eq 0 ]
|
||||
then
|
||||
bash stop_pangu.sh
|
||||
echo "start flask server failed, see log flask.log for more detail" && exit 1
|
||||
fi
|
||||
|
||||
count=$(($count+1))
|
||||
result=`grep -E 'Press CTRL\+C to quit' flask.log | wc -l`
|
||||
done
|
||||
|
||||
if [ ${count} -eq 10 ]
|
||||
then
|
||||
bash stop_pangu.sh
|
||||
echo "start flask server failed, see log flask.log for more detail" && exit 1
|
||||
fi
|
||||
echo "### start flask server end ###"
|
||||
cat flask.log
|
||||
}
|
||||
|
||||
wait_serving_ready()
|
||||
{
|
||||
echo "### waiting serving server ready, see and serving_logs/*.log for detail ###"
|
||||
result=`grep -E 'gRPC server start success' serving_server.log | wc -l`
|
||||
count=0
|
||||
while [[ ${result} -eq 0 && ${count} -lt 1800 ]]
|
||||
do
|
||||
sleep 1
|
||||
|
||||
num=`ps -ef | grep 'serving_server.py' | grep -v grep | wc -l`
|
||||
if [ $num -eq 0 ]
|
||||
then
|
||||
bash stop_pangu.sh
|
||||
echo "waiting serving server ready failed, see log serving_server.log and serving_logs/*.log for more detail" && exit 1
|
||||
fi
|
||||
|
||||
count=$(($count+1))
|
||||
result=`grep -E 'gRPC server start success' serving_server.log | wc -l`
|
||||
done
|
||||
|
||||
if [ ${count} -eq 1800 ]
|
||||
then
|
||||
bash stop_pangu.sh
|
||||
echo "waiting serving server ready failed, see log serving_server.log and serving_logs/*.log for more detail" && exit 1
|
||||
fi
|
||||
echo "### waiting serving server ready end ###"
|
||||
}
|
||||
|
||||
bash stop_pangu.sh
|
||||
rm -rf serving_server.log flask.log serving_logs
|
||||
start_serving_server
|
||||
wait_serving_ready
|
||||
start_flask
|
|
@ -22,6 +22,7 @@ import mindspore.common.dtype as mstype
|
|||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
def topk_fun(logits, topk=5):
|
||||
"""Get topk"""
|
||||
target_column = logits[0].tolist()
|
||||
|
@ -33,6 +34,7 @@ def topk_fun(logits, topk=5):
|
|||
value = np.array([value])
|
||||
return value, index
|
||||
|
||||
|
||||
def sampler(log_probs_revised, top_p, top_k_num, use_pynative=False):
|
||||
"""Convert the log_probs to probability"""
|
||||
if use_pynative:
|
||||
|
@ -80,6 +82,7 @@ def sampler(log_probs_revised, top_p, top_k_num, use_pynative=False):
|
|||
p = probs / sum(probs)
|
||||
return p, p_args
|
||||
|
||||
|
||||
def generate(model, origin_inputs, config):
|
||||
"""
|
||||
Text generation
|
||||
|
@ -130,7 +133,7 @@ def generate(model, origin_inputs, config):
|
|||
# Random select a token as final output for this round
|
||||
target_index = np.random.choice(len(p), p=p)
|
||||
# Stop judgment
|
||||
if p_args[target_index] == end_token or valid_length == target_length-1:
|
||||
if p_args[target_index] == end_token or valid_length == target_length - 1:
|
||||
outputs = input_ids
|
||||
break
|
||||
|
||||
|
@ -145,6 +148,7 @@ def generate(model, origin_inputs, config):
|
|||
outputs = outputs[0][:length]
|
||||
return outputs
|
||||
|
||||
|
||||
def generate_increment(model, origin_inputs, config):
|
||||
"""
|
||||
Text generation for incremental inference
|
||||
|
@ -183,8 +187,8 @@ def generate_increment(model, origin_inputs, config):
|
|||
|
||||
# Indicate the exact token position
|
||||
current_index = valid_length - 1 if valid_length - 1 > 0 else 0
|
||||
current_index = Tensor(np.array([current_index]), mstype.int32)
|
||||
batch_valid_length = Tensor(np.array([current_index]), mstype.int32)
|
||||
current_index = Tensor(np.array([current_index]), mstype.int32)
|
||||
# For first graph, not_init should be false
|
||||
init_true = Tensor([True], mstype.bool_)
|
||||
init_false = Tensor([False], mstype.bool_)
|
||||
|
@ -211,7 +215,7 @@ def generate_increment(model, origin_inputs, config):
|
|||
# Random select a token as final output for this round
|
||||
target_index = np.random.choice(len(p), p=p)
|
||||
# Stop judgment
|
||||
if p_args[target_index] == end_token or valid_length == target_length-1:
|
||||
if p_args[target_index] == end_token or valid_length == target_length - 1:
|
||||
break
|
||||
|
||||
# Update frequency list
|
||||
|
|
|
@ -27,10 +27,12 @@ from mindspore import context
|
|||
from mindspore.common.seed import _get_graph_seed
|
||||
from mindspore._checkparam import Validator
|
||||
|
||||
|
||||
class Dropout(nn.Cell):
|
||||
r"""
|
||||
A Dropout Implements with P.DropoutGenMask and P.DropoutDoMask for parallel training.
|
||||
"""
|
||||
|
||||
def __init__(self, keep_prob=0.5, dtype=mstype.float32):
|
||||
super(Dropout, self).__init__()
|
||||
if keep_prob <= 0 or keep_prob > 1:
|
||||
|
@ -52,6 +54,7 @@ class Dropout(nn.Cell):
|
|||
self.cast = P.Cast()
|
||||
else:
|
||||
self.dropout = P.Dropout(keep_prob)
|
||||
|
||||
def construct(self, x):
|
||||
r"""
|
||||
Input: a tensor
|
||||
|
@ -83,10 +86,12 @@ class Dropout(nn.Cell):
|
|||
else:
|
||||
self.dropout.shard(strategy)
|
||||
|
||||
|
||||
class LayerNorm(nn.Cell):
|
||||
r"""
|
||||
A self-defined layer norm operation using reduce sum and reduce mean
|
||||
"""
|
||||
|
||||
def __init__(self, normalized_shape, dp=4, eps=1e-5, parallel_optimizer=False):
|
||||
super(LayerNorm, self).__init__()
|
||||
self.gamma = Parameter(initializer('ones', normalized_shape), name="gamma",
|
||||
|
@ -102,6 +107,7 @@ class LayerNorm(nn.Cell):
|
|||
self.add2 = P.TensorAdd().shard(((dp, 1, 1), (1,)))
|
||||
self.real_div = P.RealDiv().shard(((dp, 1, 1), (dp, 1, 1)))
|
||||
self.eps = eps
|
||||
|
||||
def construct(self, x):
|
||||
mean = self.mean(x, -1)
|
||||
diff = self.sub1(x, mean)
|
||||
|
@ -111,6 +117,7 @@ class LayerNorm(nn.Cell):
|
|||
output = self.add2(self.mul(output, self.gamma), self.beta)
|
||||
return output
|
||||
|
||||
|
||||
class Mapping(nn.Cell):
|
||||
"""
|
||||
A mapping function with a 3d input
|
||||
|
@ -167,6 +174,7 @@ class MappingOutput(nn.Cell):
|
|||
Returns:
|
||||
output: Tensor, a 3d tensor after projection
|
||||
"""
|
||||
|
||||
def __init__(self, config, input_size, output_size, scale=1.0):
|
||||
super(MappingOutput, self).__init__()
|
||||
self.output_size = output_size
|
||||
|
@ -205,6 +213,7 @@ class FeedForwardLayer(nn.Cell):
|
|||
Returns:
|
||||
output: Tensor, the output of this layer after mapping
|
||||
"""
|
||||
|
||||
def __init__(self, config, scale=1.0):
|
||||
super(FeedForwardLayer, self).__init__()
|
||||
input_size = config.embedding_size
|
||||
|
@ -226,6 +235,7 @@ class FeedForwardLayer(nn.Cell):
|
|||
output = self.dropout(output)
|
||||
return output
|
||||
|
||||
|
||||
class EmbeddingLookup(nn.Cell):
|
||||
"""
|
||||
The embedding lookup table for vocabulary
|
||||
|
@ -236,6 +246,7 @@ class EmbeddingLookup(nn.Cell):
|
|||
seq_length, embedding_size)
|
||||
self.embedding_table: Tensor, the embedding table for the vocabulary
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(EmbeddingLookup, self).__init__()
|
||||
self.gather = P.GatherV2()
|
||||
|
@ -244,6 +255,7 @@ class EmbeddingLookup(nn.Cell):
|
|||
output = self.gather(table, input_ids, 0)
|
||||
return output
|
||||
|
||||
|
||||
class Attention(nn.Cell):
|
||||
"""
|
||||
Self-Attention module for each layer
|
||||
|
@ -253,6 +265,7 @@ class Attention(nn.Cell):
|
|||
scale: scale factor for initialization
|
||||
layer_idx: current layer index
|
||||
"""
|
||||
|
||||
def __init__(self, config, scale=1.0, layer_idx=None):
|
||||
super(Attention, self).__init__()
|
||||
# Output layer
|
||||
|
@ -295,7 +308,6 @@ class Attention(nn.Cell):
|
|||
self.softmax.softmax.shard(((config.dp, config.mp, 1),))
|
||||
self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))
|
||||
|
||||
|
||||
dense_shape = [config.embedding_size, config.embedding_size]
|
||||
bias_shape = [config.embedding_size]
|
||||
# Query
|
||||
|
@ -399,7 +411,7 @@ class Attention(nn.Cell):
|
|||
# The first graph with the input size of (bs, seq_length)
|
||||
if self.is_first_iteration:
|
||||
# Get the valid input length without padding
|
||||
valid_length_vector = F.cast(self.less(self.range, batch_valid_length), self.dtype)
|
||||
valid_length_vector = F.cast(self.less(self.range, batch_valid_length.view(1, 1, -1)), self.dtype)
|
||||
# Cover the key and value numbers corresponding to the padding position
|
||||
key_present = self.mul1(key, self.expand_dims(valid_length_vector, 2))
|
||||
value_present = self.mul1(value, self.expand_dims(valid_length_vector, 3))
|
||||
|
@ -464,7 +476,7 @@ class Attention(nn.Cell):
|
|||
x_merge: the 3d output
|
||||
"""
|
||||
x = self.merger_head_transpose(
|
||||
x, (0, 2, 1, 3)) #bs, seq_length, head, size_per_head
|
||||
x, (0, 2, 1, 3)) # bs, seq_length, head, size_per_head
|
||||
x_shape = P.Shape()(x)
|
||||
new_shape = x_shape[:-2] + (x_shape[-2] * x_shape[-1],)
|
||||
x_merge = self.reshape(x, new_shape)
|
||||
|
@ -551,6 +563,7 @@ class Decoder(nn.Cell):
|
|||
output: Tensor, the output logit of this layer
|
||||
layer_present: Tensor, the feature map of current layer
|
||||
"""
|
||||
|
||||
def __init__(self, config, layer_idx):
|
||||
super(Decoder, self).__init__()
|
||||
scale = 1 / math.sqrt(2.0 * config.num_layers)
|
||||
|
@ -633,10 +646,12 @@ class Decoder(nn.Cell):
|
|||
output = self.add(x, mlp_logit)
|
||||
return output, layer_present
|
||||
|
||||
|
||||
class Embedding(nn.Cell):
|
||||
"""
|
||||
Embedding
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(Embedding, self).__init__()
|
||||
self.word_embedding = EmbeddingLookup().set_comm_fusion(1)
|
||||
|
@ -691,18 +706,22 @@ class Mask(nn.Cell):
|
|||
"""
|
||||
Mask
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(Mask, self).__init__()
|
||||
self.dtype = config.compute_dtype
|
||||
self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))
|
||||
|
||||
def construct(self, attention_mask):
|
||||
attention_mask = self.expand_dims(attention_mask, 1)
|
||||
return attention_mask
|
||||
|
||||
|
||||
class QueryLayerAttention(Attention):
|
||||
r"""
|
||||
Self-Attention module using input query vector.
|
||||
"""
|
||||
|
||||
def construct(self, x, query_hidden_state, attention_mask, key_past=None, value_past=None, batch_valid_length=None):
|
||||
original_shape = F.shape(x)
|
||||
x = F.reshape(x, (-1, original_shape[-1]))
|
||||
|
@ -732,7 +751,7 @@ class QueryLayerAttention(Attention):
|
|||
# The first graph with the input size of (bs, seq_length)
|
||||
if self.is_first_iteration:
|
||||
# Get the valid input length without padding
|
||||
valid_length_vector = F.cast(self.less(self.range, batch_valid_length), self.dtype)
|
||||
valid_length_vector = F.cast(self.less(self.range, batch_valid_length.view(1, 1, -1)), self.dtype)
|
||||
# Cover the key and value numbers corresponding to the padding position
|
||||
key_present = self.mul1(key, self.expand_dims(valid_length_vector, 2))
|
||||
value_present = self.mul1(value, self.expand_dims(valid_length_vector, 3))
|
||||
|
@ -767,11 +786,13 @@ class QueryLayerAttention(Attention):
|
|||
output = self.dropout(output)
|
||||
return output, layer_present
|
||||
|
||||
|
||||
class QueryLayer(nn.Cell):
|
||||
r"""
|
||||
A block usingooked out position embedding as query vector.
|
||||
This is used as the final block.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(QueryLayer, self).__init__()
|
||||
scale = 1 / math.sqrt(2.0 * config.num_layers)
|
||||
|
@ -854,6 +875,7 @@ class QueryLayer(nn.Cell):
|
|||
output = self.add(x, mlp_logit)
|
||||
return output, layer_present
|
||||
|
||||
|
||||
class PanguAlphaEmbedding(nn.Cell):
|
||||
"""
|
||||
Input embedding, i.e., word embedding and position embedding
|
||||
|
@ -870,6 +892,7 @@ class PanguAlphaEmbedding(nn.Cell):
|
|||
attention_mask: Tensor, attention_mask matrix
|
||||
embedding_table: Tensor, embedding_table with shape of (vocab_size, embedding_size)
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(PanguAlphaEmbedding, self).__init__()
|
||||
self.embedding = Embedding(config)
|
||||
|
@ -885,6 +908,7 @@ class PanguAlphaEmbedding(nn.Cell):
|
|||
attention_mask = self.mask(attention_mask)
|
||||
return hidden_states, attention_mask
|
||||
|
||||
|
||||
class PanguAlpha_Model(nn.Cell):
|
||||
"""
|
||||
The backbone of PanguAlpha network
|
||||
|
@ -899,6 +923,7 @@ class PanguAlpha_Model(nn.Cell):
|
|||
present_layer: Tensor, the current feature map
|
||||
embedding_table: Tensor, the embedding table for the vocabulary
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(PanguAlpha_Model, self).__init__()
|
||||
self.embedding = PanguAlphaEmbedding(config)
|
||||
|
@ -971,12 +996,13 @@ class PanguAlpha_Model(nn.Cell):
|
|||
hidden_states, attention_mask = self.embedding(input_ids, input_mask, table,
|
||||
input_position, attention_mask,
|
||||
batch_valid_length)
|
||||
for i in range(self.num_layers-1):
|
||||
for i in range(self.num_layers - 1):
|
||||
hidden_states, _ = self.blocks[i](hidden_states,
|
||||
attention_mask, init_reset, batch_valid_length)
|
||||
if self.is_pipeline:
|
||||
top_query_hidden_states = self.top_query_embedding(input_position.view(-1,), self.top_query_embedding_table)
|
||||
hidden_states, _ = self.blocks[self.num_layers-1](hidden_states, top_query_hidden_states,
|
||||
top_query_hidden_states = self.top_query_embedding(input_position.view(-1,),
|
||||
self.top_query_embedding_table)
|
||||
hidden_states, _ = self.blocks[self.num_layers - 1](hidden_states, top_query_hidden_states,
|
||||
attention_mask, init_reset, batch_valid_length)
|
||||
output_state = self.layernorm(hidden_states)
|
||||
output_state = F.cast(output_state, self.dtype)
|
||||
|
@ -988,6 +1014,7 @@ class PanguAlpha_Model(nn.Cell):
|
|||
attention_mask, init_reset, batch_valid_length)
|
||||
return output_state
|
||||
|
||||
|
||||
class PanguAlpha_Head(nn.Cell):
|
||||
"""
|
||||
Head for PanguAlpha to get the logits of each token in the vocab
|
||||
|
@ -999,6 +1026,7 @@ class PanguAlpha_Head(nn.Cell):
|
|||
Returns:
|
||||
logits: Tensor, the logits of the corresponding inputs
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(PanguAlpha_Head, self).__init__()
|
||||
if config.word_emb_dp:
|
||||
|
@ -1029,6 +1057,7 @@ class PanguAlpha(nn.Cell):
|
|||
Returns:
|
||||
logits: Tensor: the logits of the corresponding inputs with shape (batch_size, seq_length, vocab_size)
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(PanguAlpha, self).__init__()
|
||||
# Network head to get logits over vocabulary
|
||||
|
@ -1065,6 +1094,7 @@ class PanguAlpha(nn.Cell):
|
|||
logits = self.head(output_states, self.embedding_table)
|
||||
return logits
|
||||
|
||||
|
||||
class CrossEntropyLoss(nn.Cell):
|
||||
"""
|
||||
Calculate the cross entropy loss
|
||||
|
@ -1077,6 +1107,7 @@ class CrossEntropyLoss(nn.Cell):
|
|||
Returns:
|
||||
loss: Tensor, the corrsponding cross entropy loss
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(CrossEntropyLoss, self).__init__()
|
||||
self.mean = P.ReduceMean()
|
||||
|
@ -1137,6 +1168,7 @@ class CrossEntropyLoss(nn.Cell):
|
|||
loss = self.div2(numerator, denominator)
|
||||
return loss
|
||||
|
||||
|
||||
class PanguAlphaWithLoss(nn.Cell):
|
||||
"""
|
||||
PanguAlpha training loss
|
||||
|
@ -1150,6 +1182,7 @@ class PanguAlphaWithLoss(nn.Cell):
|
|||
Returns:
|
||||
output: Tensor, the loss of the network
|
||||
"""
|
||||
|
||||
def __init__(self, config, network, loss, eos_token=6):
|
||||
super(PanguAlphaWithLoss, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
|
@ -1172,6 +1205,40 @@ class PanguAlphaWithLoss(nn.Cell):
|
|||
output = self.loss(logits, labels, input_mask)
|
||||
return output
|
||||
|
||||
|
||||
class AttentionMask(nn.Cell):
|
||||
"""
|
||||
Get the attention matrix for self-attention module
|
||||
Args:
|
||||
seq_length: the pre-defined sequence length
|
||||
Inputs:
|
||||
input_mask: the mask indicating whether each position is a valid input
|
||||
Returns:
|
||||
attention_mask: the attention mask matrix with shape (batch_size, seq_length, seq_length)
|
||||
"""
|
||||
|
||||
def __init__(self, seq_length):
|
||||
super(AttentionMask, self).__init__()
|
||||
self.reshape = P.Reshape()
|
||||
self.mul = P.BatchMatMul().shard(
|
||||
((1, 1, 1), (1, 1, 1)))
|
||||
self.expand_dim = P.ExpandDims().shard(((1, 1),))
|
||||
ones = np.ones(shape=(seq_length, seq_length))
|
||||
self.lower_triangle_mask = Tensor(np.tril(ones), mstype.float32)
|
||||
self.multiply = P.Mul().shard(((1, 1, 1), (1, 1, 1)))
|
||||
|
||||
def construct(self, input_mask):
|
||||
input_shape = P.Shape()(input_mask)
|
||||
shape_right = (input_shape[0], 1, input_shape[1])
|
||||
shape_left = input_shape + (1,)
|
||||
mask_left = self.reshape(input_mask, shape_left)
|
||||
mask_right = self.reshape(input_mask, shape_right)
|
||||
attention_mask = self.mul(mask_left, mask_right)
|
||||
lower_triangle = self.expand_dim(self.lower_triangle_mask, 0)
|
||||
attention_mask = self.multiply(attention_mask, lower_triangle)
|
||||
return attention_mask
|
||||
|
||||
|
||||
class EvalNet(nn.Cell):
|
||||
"""
|
||||
PanguAlpha evaluation net
|
||||
|
@ -1185,7 +1252,8 @@ class EvalNet(nn.Cell):
|
|||
Returns:
|
||||
outputs: Tensor, corresponding output for different tasks
|
||||
"""
|
||||
def __init__(self, backbone, generate=False, pad_token=6):
|
||||
|
||||
def __init__(self, backbone, generate=False, pad_token=6, seq_length=1024):
|
||||
super(EvalNet, self).__init__(auto_prefix=False)
|
||||
self.backbone = backbone
|
||||
self.pad_token = pad_token
|
||||
|
@ -1194,14 +1262,19 @@ class EvalNet(nn.Cell):
|
|||
self.topk = P.TopK(sorted=True).shard(((1, 1),))
|
||||
self.gather = P.GatherV2().shard(((1, 1), (1,)))
|
||||
self.log_softmax = P.LogSoftmax().shard(((1, 1, 1),))
|
||||
self.get_attention_mask = AttentionMask(seq_length)
|
||||
|
||||
def construct(self, input_ids, current_index, init_reset=True, batch_valid_length=None):
|
||||
"""evaluation net"""
|
||||
input_mask = F.cast(F.not_equal(input_ids, self.pad_token), mstype.float32)
|
||||
logits = self.backbone(input_ids, input_mask)
|
||||
bs, seq_length = F.shape(input_ids)
|
||||
attention_mask = self.get_attention_mask(input_mask)
|
||||
input_position = F.tuple_to_array(F.make_range(seq_length))
|
||||
input_position = P.Tile()(input_position, (bs, 1))
|
||||
logits = self.backbone(input_ids, input_mask, input_position, attention_mask,
|
||||
init_reset, batch_valid_length)
|
||||
index = current_index.view(1,)
|
||||
logits = self.gather(logits, index, 0)
|
||||
bs, _ = F.shape(input_ids)
|
||||
logits = logits.view(bs, 1, -1)
|
||||
log_probs = self.log_softmax(logits)
|
||||
return log_probs
|
||||
|
|
|
@ -71,12 +71,14 @@ class FP32StateAdamWeightDecay(AdamWeightDecay):
|
|||
|
||||
get_square_sum = C.MultitypeFuncGraph("get_square_sum")
|
||||
|
||||
|
||||
@get_square_sum.register("Tensor", "Number")
|
||||
def _get_square_sum(grad, value):
|
||||
norm = P.ReduceSum(False)(F.square(grad), ()) / value
|
||||
norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
|
||||
return norm
|
||||
|
||||
|
||||
apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
|
||||
|
||||
|
||||
|
@ -85,6 +87,7 @@ def _apply_global_norm(clip_norm, global_norm, grad):
|
|||
grad = grad * clip_norm / global_norm
|
||||
return grad
|
||||
|
||||
|
||||
def _get_model_parallel_group(mp):
|
||||
"""
|
||||
|
||||
|
@ -104,6 +107,7 @@ def _get_model_parallel_group(mp):
|
|||
rank_list = [x + index * mp + stage_id * per_stage_device_nums for x in group]
|
||||
return rank_list, rank_list_str
|
||||
|
||||
|
||||
def _get_pipeline_group():
|
||||
"""
|
||||
|
||||
|
@ -121,10 +125,12 @@ def _get_pipeline_group():
|
|||
rank_list_str = "-".join(rank_str_list)
|
||||
return rank_list, rank_list_str
|
||||
|
||||
|
||||
class GlobalNorm(nn.Cell):
|
||||
"""
|
||||
Calculate the global norm value of given tensors
|
||||
"""
|
||||
|
||||
def __init__(self, params, config):
|
||||
super(GlobalNorm, self).__init__()
|
||||
self.norm = nn.Norm()
|
||||
|
@ -165,12 +171,14 @@ class GlobalNorm(nn.Cell):
|
|||
global_norms = F.sqrt(P.AllReduce()(square_reduce_sum))
|
||||
return global_norms
|
||||
|
||||
|
||||
class ClipByGlobalNorm(nn.Cell):
|
||||
"""
|
||||
|
||||
Clip grads by global norm
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, params, config, clip_norm=1.0):
|
||||
super(ClipByGlobalNorm, self).__init__()
|
||||
self.global_norm = GlobalNorm(params, config)
|
||||
|
@ -185,6 +193,7 @@ class ClipByGlobalNorm(nn.Cell):
|
|||
grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
|
||||
return grads, global_norm_value
|
||||
|
||||
|
||||
class LearningRate(LearningRateSchedule):
|
||||
"""
|
||||
Warmup-decay learning rate for PanguAlpha network.
|
||||
|
@ -259,6 +268,11 @@ def add_inference_params(opt):
|
|||
type=int,
|
||||
default=0,
|
||||
help="Whether use pynative op for postproecess")
|
||||
opt.add_argument("--use_past",
|
||||
type=str,
|
||||
default="true",
|
||||
choices=["true", "false"],
|
||||
help="Whether enable state reuse")
|
||||
|
||||
|
||||
def add_training_params(opt):
|
||||
|
|
Loading…
Reference in New Issue