!19622 Serving, pangu alpha modelzoo
Merge pull request !19622 from 徐永飞/master
This commit is contained in commit b4c04ef3a8.
@@ -1,4 +1,3 @@
-
 # It is still under development
 
 # Contents
@@ -185,7 +184,7 @@ bash scripts/run_distribute_train_incremental_train.sh DATASET RANK_TABLE 8 fp32
 Please refer to the [website](https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha) to download the following parts:
 
 - tokenizer: vocab.txt and vocab.model
-- checkpoint file: \*.part\[0-4\] and *.npy under the same parameter size
+- checkpoint file: \*.part\[0-4\] (need to extract) and *.npy under the same parameter size
 - strategy file: a file describing how the parameters are sliced across different devices.
 
 Here we suppose the downloaded checkpoint, tokenizer and strategy files are organized as follows:
@@ -204,6 +203,12 @@ ckpts
 └── vocab10.vocab
 ```
 
+We provide two prediction methods. The first is the normal way, which needs to pad the input to a fixed length at every
+iteration. Due to the redundant computation, the latency of this method is quite high; to accelerate inference, we
+provide a second, state-reuse (incremental inference) method.
+
+The state-reuse method is the default mode; you can disable it by setting the argument `use_past` to False.
+
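To make the latency difference concrete, here is a minimal, self-contained sketch (an editor's illustration, not code from this commit): `run_model` is a hypothetical stand-in for one forward pass of the network, whose cost grows with the number of input positions.

```python
import numpy as np

SEQ_LENGTH = 1024
VOCAB = 40000

def run_model(token_ids):
    """Hypothetical stand-in for one forward pass; cost grows with len(token_ids)."""
    return np.random.rand(len(token_ids), VOCAB)

def generate_padded(prompt_ids, steps):
    """Normal mode: every step re-runs the full padded sequence."""
    ids = list(prompt_ids)
    for _ in range(steps):
        padded = ids + [0] * (SEQ_LENGTH - len(ids))   # pad to a fixed length
        logits = run_model(padded)                     # O(SEQ_LENGTH) work per step
        ids.append(int(logits[len(ids) - 1].argmax()))
    return ids

def generate_state_reuse(prompt_ids, steps):
    """use_past mode: one full pass, then one token per step with cached states."""
    ids = list(prompt_ids)
    logits = run_model(ids)                            # full pass only once
    for _ in range(steps):
        next_id = int(logits[-1].argmax())
        ids.append(next_id)
        logits = run_model([next_id])                  # O(1) work per step
    return ids
```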
 ### Run Prediction in Distributed Mode
 
 The following script will run prediction on 8 Ascend cards.
@@ -217,28 +222,217 @@ ${FILE_PATH}/tokenizer/ ${FILE_PATH}/checkpoint_file filitered 2.6B fp32
 ### Run Prediction Using One Device
 
 The following script will run prediction on 1 Ascend card. The difference is that the network is initialized with the float16 type.
-And the rank_table should be configured to one device.
 
 ```bash
 FILE_PATH=/home/your_path/ckpts
-bash scripts/run_distribute_predict.sh 1 /home/config/rank_table_1p.json ${FILE_PATH}/strategy_load_ckpt/strategy.ckpt \
-${FILE_PATH}/tokenizer/ ${FILE_PATH}/checkpoint_file filitered 2.6B fp16
+bash scripts/run_standalone_predict.sh ${FILE_PATH}/strategy_load_ckpt/strategy.ckpt \
+${FILE_PATH}/tokenizer/ ${FILE_PATH}/checkpoint_file filitered 2.6B
 ```
 
 ### Run Serving
 
-In directory serving:
-
-- Use scripts/run_distribute_export.sh to export MindIR models, and copy all device* to serving_increment/models/.
-- Download [PanGu-Alpha tokenizer repository](https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha.git) and copy pangu-alpha/tokenizer to directory pangu/tokenizer.
-- Pip install MindSpore and MindSpore Serving 1.2 whl package.
-- Pip install flask, flask-apscheduler, jieba, sentencepiece whl package.
-- Edit server_agent.py and update the path of pangu-alpha models.
-- Run 'bash start_pangu.sh' to start new execution.
-- Wait for serving to start successfully: observe the serving_server.log file until the message "Serving: gRPC server start success, listening on 127.0.0.1:5500" is output.
-- If any error happened, log can be viewed in serving_server.log, serving_agent.log and flask.log.
-- If anything all right, access address {ip}:5000 in one browser.
-- Run 'bash stop_pangu.sh' to stop the existing execution.
+#### Preparation
+
+- Pip install MindSpore and MindSpore Serving 1.3 or later.
+- Pip install flask, flask-apscheduler, jieba, sentencepiece and other whl packages if needed.
+- Download the [PanGu-Alpha repository](https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha); we will need
+`pangu-alpha/strategy_load_ckpt` and `pangu-alpha/tokenizer` in the following process.
+- Download the 13B or 2.6B checkpoint files and `*embedding` files
+from the [PanGu-Alpha repository](https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha).
+
+For 13B, we will need `13B_part0` to `13B_part3`, `13B_word_embedding`, `13B_top_query_embedding`
+and `13B_position_embedding`.
+
+For 2.6B, we will need `2.6B_part0` to `2.6B_part3`, `2.6B_word_embedding`, `2.6B_top_query_embedding`
+and `2.6B_position_embedding`.
+
+Decompress all the `13B_part*` or `2.6B_part*` tar files; a large number of `*.ckpt` files will be generated. Move
+all `*embedding` files to the same directory as the `*.ckpt` files.
+
+#### Run 13B or 2.6B in standalone mode [Ascend910/Nvidia GPU]
+
+- Use scripts/run_standalone_export.sh to export MindIR models, and move all device_0/* to
+'serving_increment/pangu_standalone/pangu/1/'.
+
+```shell
+>>> cd scripts
+>>> bash run_standalone_export.sh ${strategy_file_path} ${ckpt_dir_path}
+```
+
+Update the parameter `MODE` in `run_standalone_export.sh` from `13B` to `2.6B` if you want to export the 2.6B model.
+
+Update the parameter `DEVICE_TARGET` in `run_standalone_export.sh` from `Ascend` to `GPU` when running in a GPU environment.
+
+The `${strategy_file_path}` is the file path of `pangu-alpha/strategy_load_ckpt/angu_alpha_13B_cktp_strategy.ckpt` for 13B
+and `pangu-alpha/strategy_load_ckpt/angu_alpha_2.6B_cktp_strategy.ckpt` for 2.6B.
+
+The `${ckpt_dir_path}` is the directory containing the `*.ckpt` files generated by decompression and the `*embedding` files.
+
+The export takes some minutes. Check the log device_0/log0.log and confirm that no exception is reported at the end.
+Confirm that the MindIR files have been generated in device_0/, which means the model has been exported successfully.
+
+```shell
+>>> ls device_0
+pangu_alpha_1024_graph.mindir pangu_alpha_1024_variables pangu_alpha_1_graph.mindir pangu_alpha_1_variables
+>>> cd - && mkdir serving_increment/pangu_standalone/pangu/1/
+>>> mv scripts/device_0/* serving_increment/pangu_standalone/pangu/1/
+>>> cd serving_increment
+```
+
+- Copy `pangu-alpha/tokenizer` to directory serving_increment/pangu_standalone/pangu/tokenizer.
+
+The directory hierarchy of the required files is shown below. The pangu_alpha_1024_variables and pangu_alpha_1_variables directories are folded for easy display.
+
+```shell
+>>> tree pangu_standalone
+pangu_standalone/
+├── pangu
+│   ├── 1
+│   │   ├── pangu_alpha_1024_graph.mindir
+│   │   ├── pangu_alpha_1024_variables/
+│   │   ├── pangu_alpha_1_graph.mindir
+│   │   └── pangu_alpha_1_variables/
+│   ├── servable_config.py
+│   ├── tokenization_jieba.py
+│   └── tokenizer
+│       ├── vocab.model
+│       └── vocab.vocab
+└── serving_server.py
+```
+
+- Run `bash start_pangu_standalone.sh` to start a new execution, and wait until the serving and flask servers are started
+successfully.
+- If any error happens, the logs can be viewed in serving_server.log, serving_logs/*.log and flask.log.
+- If everything is all right, access address {ip}:5000 in a browser (see the client sketch after this section). It will take some time to return the reply.
+- Run `bash stop_pangu.sh` to stop the existing execution.
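For scripted access instead of a browser, a client along the following lines can be used. This is an editor's sketch, not code from this commit: the query-parameter name and response handling are assumptions, so check the actual route in flask/client.py before relying on it.

```python
import requests

def ask(server, sentence, timeout=600):
    """Send one sentence to the flask frontend and return the raw reply body."""
    # Hypothetical query parameter; adapt it to the actual route in flask/client.py.
    resp = requests.get(server, params={"sentence": sentence}, timeout=timeout)
    resp.raise_for_status()
    return resp.text

if __name__ == "__main__":
    print(ask("http://127.0.0.1:5000", "四川的省会是"))
```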
+
+#### Run 13B or 2.6B in distributed mode [Ascend910, 8 cards]
+
+- Generate the [rank table file](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
+
+```shell
+# mindspore/model_zoo/utils/hccl_tools/hccl_tools.py
+>>> python3 ../../../utils/hccl_tools/hccl_tools.py --device_num "[0,8]"
+>>> mv hccl_8p_01234567*.json serving_increment/pangu_distributed/hccl_8p.json
+```
+
+- Use scripts/run_distribute_export.sh to export MindIR models, and move all device* to
+'serving_increment/pangu_distributed/models/'.
+
+```shell
+>>> cd scripts
+>>> bash run_distribute_export.sh ${strategy_file_path} ${ckpt_dir_path}
+```
+
+Update the parameter `MODE` in `run_distribute_export.sh` from `13B` to `2.6B` if you want to export the 2.6B model.
+
+The `${strategy_file_path}` is the file path of `pangu-alpha/strategy_load_ckpt/angu_alpha_13B_cktp_strategy.ckpt` for 13B
+and `pangu-alpha/strategy_load_ckpt/angu_alpha_2.6B_cktp_strategy.ckpt` for 2.6B.
+
+The `${ckpt_dir_path}` is the directory containing the `*.ckpt` files generated by decompression and the `*embedding` files.
+
+The export takes some minutes. Check the logs device_[0-7]/log[0-7].log and confirm that no exception is reported at
+the end. Confirm that the MindIR files have been generated in device_[0-7]/, which means the model has been exported
+successfully.
+
+```shell
+>>> cd - && mkdir serving_increment/pangu_distributed/models/
+>>> mv scripts/device_* serving_increment/pangu_distributed/models/
+>>> cd serving_increment
+```
+
+- Update the MindIR file names in serving_increment/pangu_distributed/serving_agent.py if needed.
+- Copy `pangu-alpha/tokenizer` to directory serving_increment/pangu_distributed/pangu/tokenizer.
+
+The directory hierarchy of the required files is shown below. device_1 to device_7 are folded for easy display.
+
+```shell
+>>> tree pangu_distributed
+pangu_distributed/
+├── hccl_8p.json
+├── models
+│   ├── device_0
+│   │   ├── pangu_alpha_1024_graph.mindir
+│   │   ├── pangu_alpha_1024_variables
+│   │   │   ├── data_0
+│   │   │   ├── data_1
+│   │   │   ├── data_2
+│   │   │   ├── data_3
+│   │   │   └── data_4
+│   │   ├── pangu_alpha_1_graph.mindir
+│   │   └── pangu_alpha_1_variables
+│   │       ├── data_0
+│   │       ├── data_1
+│   │       ├── data_2
+│   │       ├── data_3
+│   │       └── data_4
+│   ├── device_1/
+│   ├── device_2/
+│   ├── device_3/
+│   ├── device_4/
+│   ├── device_5/
+│   ├── device_6/
+│   └── device_7/
+├── pangu
+│   ├── servable_config.py
+│   ├── tokenization_jieba.py
+│   └── tokenizer
+│       ├── vocab.model
+│       └── vocab.vocab
+├── serving_agent.py
+└── serving_server.py
+```
+
+- Run `bash start_pangu_distributed.sh` to start a new execution, and wait until the serving and flask servers are started
+successfully.
+- If any error happens, the logs can be viewed in serving_server.log, serving_agent.log, serving_logs/*.log and flask.log.
+- If everything is all right, access address {ip}:5000 in a browser. It will take some time to return the reply.
+- Run `bash stop_pangu.sh` to stop the existing execution.
+
+#### Run in distributed mode [Ascend910, 8 cards * N machines]
+
+- Generate the [rank table file](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
+- On every machine, prepare the checkpoint files and embedding files. We can also use 13B as a test example.
+- On every machine, use scripts/run_cluster_export.sh to export MindIR models, and move all device* to
+'serving_increment/pangu_distributed/models/'.
+
+```shell
+>>> cd scripts
+>>> bash run_cluster_export.sh ${strategy_file_path} ${ckpt_dir_path} ${rank_table_file} ${rank_size} ${rank_start}
+```
+
+Update the parameter `MODE` in `run_cluster_export.sh` from `200B` to `13B` if you want to export the 13B model.
+
+The `${rank_start}` is the first rank id on every machine, such as 0, 8, 16 or 24.
+
+The export takes some minutes. Check the logs device_[0-7]/log[0-7].log and confirm that no exception is reported at
+the end. Confirm that the MindIR files have been generated in device_[0-7]/, which means the model has been exported
+successfully.
+
+```shell
+>>> cd - && mkdir serving_increment/pangu_distributed/models/
+>>> mv scripts/device_* serving_increment/pangu_distributed/models/
+>>> cd serving_increment
+```
+
+- On the first machine, update the parameters `rank_size` and `stage_size` (pipeline stage size) of `serving_increment/pangu_distributed/pangu/servable_config.py`.
+- On the first machine, update the parameter `rank_table_json_file` of `serving_increment/pangu_distributed/serving_server.py`.
+- On every machine, update the MindIR file names in serving_increment/pangu_distributed/serving_agent.py if needed.
+- On every machine, update the parameter `distributed_address` of `serving_increment/pangu_distributed/serving_agent.py` and
+`serving_increment/pangu_distributed/serving_server.py` to the first machine's IP address.
+- On the first machine, copy `pangu-alpha/tokenizer` to directory serving_increment/pangu_distributed/pangu/tokenizer.
+- On the first machine, run `bash start_pangu_distributed.sh` to start a new execution.
+- Meanwhile, on the other machines, run `python serving_agent.py` to start the serving agent process.
+
+```shell
+>>> unset http_proxy && unset https_proxy
+>>> python pangu_distributed/serving_agent.py > serving_agent.log 2>&1 &
+```
+
+- Wait until the serving and flask servers are started successfully.
+- If any error happens, the logs can be viewed in serving_server.log, serving_agent.log, serving_logs/*.log and flask.log.
+- If everything is all right, access address {first_machine_ip}:5000 in a browser. It will take some time to return the reply.
+- Run `bash stop_pangu.sh` on every machine to stop the existing execution.
 # [Script Description](#contents)
 
@@ -63,8 +63,12 @@ def load_model(args_opt):
     else:
         rank = 0
         device_num = 1
+        context.reset_auto_parallel_context()
+        context.set_auto_parallel_context(
+            strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path)
+
+    use_past = (args_opt.use_past == "true")
     print('local_rank:{}, start to run...'.format(rank), flush=True)
-    use_past = False
     if args_opt.export:
         use_past = True
     # Set model property
@@ -72,6 +76,9 @@ def load_model(args_opt):
     data_parallel_num = int(device_num / model_parallel_num)
     per_batch_size = args_opt.per_batch_size
     batch_size = per_batch_size * data_parallel_num
+    # Now only support single batch_size for predict
+    if args_opt.run_type == "predict":
+        batch_size = 1
     config = PANGUALPHAConfig(
         data_parallel_num=data_parallel_num,
         model_parallel_num=model_parallel_num,
@@ -105,11 +112,15 @@ def load_model(args_opt):
     inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32)
     current_index = Tensor(np.array([0]), mstype.int32)
 
-    if config.use_past:
+    if args_opt.distribute == "false":
+        predict_layout = None
+    elif config.use_past:
         batch_valid_length = Tensor(np.array([0]), mstype.int32)
         init_true = Tensor([True], mstype.bool_)
         inputs_np_1 = Tensor(np.ones(shape=(config.batch_size, 1)), mstype.int32)
+        model_predict.predict_network.add_flags_recursive(is_first_iteration=True)
         predict_layout = model_predict.infer_predict_layout(inputs_np, current_index, init_true, batch_valid_length)
+        model_predict.predict_network.add_flags_recursive(is_first_iteration=False)
         _ = model_predict.infer_predict_layout(inputs_np_1, current_index, init_true, batch_valid_length)
     else:
         predict_layout = model_predict.infer_predict_layout(inputs_np, current_index)
@@ -0,0 +1,40 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+execute_path=$(pwd)
+script_self=$(readlink -f "$0")
+self_path=$(dirname "${script_self}")
+
+MODE=200B # or 13B
+PARAM_INIT_TYPE=fp16
+STRATEGY=$1
+CKPT_PATH=$2
+export RANK_TABLE_FILE=$3
+export RANK_SIZE=$4
+RANK_START=$5 # 0,8,16,... for each machine
+CKPT_NAME='filerted'
+
+for((i=0;i<8;i++));
+do
+    rm -rf ${execute_path}/device_$i/
+    mkdir ${execute_path}/device_$i/
+    cd ${execute_path}/device_$i/ || exit
+    export RANK_ID=$(($RANK_START+$i))
+    export DEVICE_ID=$i
+    python -s ${self_path}/../predict.py --strategy_load_ckpt_path=$STRATEGY --load_ckpt_path=$CKPT_PATH \
+    --load_ckpt_name=$CKPT_NAME --mode=$MODE --run_type=predict --param_init_type=$PARAM_INIT_TYPE \
+    --export=1 >log$i.log 2>&1 &
+done
@@ -18,12 +18,12 @@ execute_path=$(pwd)
 script_self=$(readlink -f "$0")
 self_path=$(dirname "${script_self}")
 export RANK_SIZE=8
-export RANK_TABLE_FILE=${execute_path}/../serving_increment/hccl_8p.json
+export RANK_TABLE_FILE=${execute_path}/../serving_increment/pangu_distributed/hccl_8p.json
 export MODE=13B
+export PARAM_INIT_TYPE=fp16
 export STRATEGY=$1
 export CKPT_PATH=$2
-export CKPT_NAME=$3
-export PARAM_INIT_TYPE=$4
+export CKPT_NAME='filerted'
 
 for((i=0;i<$RANK_SIZE;i++));
 do
@@ -35,4 +35,4 @@ do
     python -s ${self_path}/../predict.py --strategy_load_ckpt_path=$STRATEGY --load_ckpt_path=$CKPT_PATH \
     --load_ckpt_name=$CKPT_NAME --mode=$MODE --run_type=predict --param_init_type=$PARAM_INIT_TYPE \
     --export=1 >log$i.log 2>&1 &
 done
@@ -0,0 +1,38 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+execute_path=$(pwd)
+script_self=$(readlink -f "$0")
+self_path=$(dirname "${script_self}")
+export RANK_SIZE=1
+export MODE=13B # or 2.6B
+export PARAM_INIT_TYPE=fp16
+export STRATEGY=$1
+export CKPT_PATH=$2
+export DEVICE_TARGET=Ascend # or GPU
+export CKPT_NAME='filerted'
+
+for((i=0;i<$RANK_SIZE;i++));
+do
+    rm -rf ${execute_path}/device_$i/
+    mkdir ${execute_path}/device_$i/
+    cd ${execute_path}/device_$i/ || exit
+    export RANK_ID=$i
+    export DEVICE_ID=$i
+    python -s ${self_path}/../predict.py --strategy_load_ckpt_path=$STRATEGY --load_ckpt_path=$CKPT_PATH \
+    --load_ckpt_name=$CKPT_NAME --mode=$MODE --run_type=predict --param_init_type=$PARAM_INIT_TYPE \
+    --export=1 --distribute=false --device_target=$DEVICE_TARGET >log$i.log 2>&1 &
+done
@@ -0,0 +1,39 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+execute_path=$(pwd)
+script_self=$(readlink -f "$0")
+self_path=$(dirname "${script_self}")
+export RANK_SIZE=1
+export STRATEGY=$1
+export TOKENIZER=$2
+export CKPT_PATH=$3
+export CKPT_NAME=$4
+export MODE=$5
+export PARAM_INIT_TYPE=fp16
+export DEVICE_TARGET=Ascend # or GPU
+
+for((i=0;i<$RANK_SIZE;i++));
+do
+    rm -rf ${execute_path}/device_$i/
+    mkdir ${execute_path}/device_$i/
+    cd ${execute_path}/device_$i/ || exit
+    export RANK_ID=$i
+    export DEVICE_ID=$i
+    python -s ${self_path}/../predict.py --strategy_load_ckpt_path=$STRATEGY --tokenizer_path=$TOKENIZER --load_ckpt_path=$CKPT_PATH \
+    --load_ckpt_name=$CKPT_NAME --mode=$MODE --run_type=predict --param_init_type=$PARAM_INIT_TYPE \
+    --distribute=false --device_target=$DEVICE_TARGET >log$i.log 2>&1 &
+done
@@ -1,109 +0,0 @@
-{
-    "board_id": "0x0020",
-    "chip_info": "910",
-    "deploy_mode": "lab",
-    "group_count": "1",
-    "group_list": [
-        {
-            "device_num": "8",
-            "server_num": "1",
-            "group_name": "",
-            "instance_count": "8",
-            "instance_list": [
-                {
-                    "devices": [
-                        {
-                            "device_id": "0",
-                            "device_ip": "192.98.92.121"
-                        }
-                    ],
-                    "rank_id": "0",
-                    "server_id": "127.0.0.1"
-                },
-                {
-                    "devices": [
-                        {
-                            "device_id": "1",
-                            "device_ip": "192.98.93.121"
-                        }
-                    ],
-                    "rank_id": "1",
-                    "server_id": "127.0.0.1"
-                },
-                {
-                    "devices": [
-                        {
-                            "device_id": "2",
-                            "device_ip": "192.98.94.121"
-                        }
-                    ],
-                    "rank_id": "2",
-                    "server_id": "127.0.0.1"
-                },
-                {
-                    "devices": [
-                        {
-                            "device_id": "3",
-                            "device_ip": "192.98.95.121"
-                        }
-                    ],
-                    "rank_id": "3",
-                    "server_id": "127.0.0.1"
-                },
-                {
-                    "devices": [
-                        {
-                            "device_id": "4",
-                            "device_ip": "192.98.92.122"
-                        }
-                    ],
-                    "rank_id": "4",
-                    "server_id": "127.0.0.1"
-                },
-                {
-                    "devices": [
-                        {
-                            "device_id": "5",
-                            "device_ip": "192.98.93.122"
-                        }
-                    ],
-                    "rank_id": "5",
-                    "server_id": "127.0.0.1"
-                },
-                {
-                    "devices": [
-                        {
-                            "device_id": "6",
-                            "device_ip": "192.98.94.122"
-                        }
-                    ],
-                    "rank_id": "6",
-                    "server_id": "127.0.0.1"
-                },
-                {
-                    "devices": [
-                        {
-                            "device_id": "7",
-                            "device_ip": "192.98.95.122"
-                        }
-                    ],
-                    "rank_id": "7",
-                    "server_id": "127.0.0.1"
-                }
-            ]
-        }
-    ],
-    "para_plane_nic_location": "device",
-    "para_plane_nic_name": [
-        "eth0",
-        "eth1",
-        "eth2",
-        "eth3",
-        "eth4",
-        "eth5",
-        "eth6",
-        "eth7"
-    ],
-    "para_plane_nic_num": "8",
-    "status": "completed"
-}
|
@ -22,8 +22,8 @@ def start():
|
||||||
"""Start agents to load and execute models of pangu alpha"""
|
"""Start agents to load and execute models of pangu alpha"""
|
||||||
model_files = []
|
model_files = []
|
||||||
for i in range(8):
|
for i in range(8):
|
||||||
model_files.append([f"models/device{i}/pangu_alpha_1024_graph.mindir",
|
model_files.append([f"models/device_{i}/pangu_alpha_1024_graph.mindir",
|
||||||
f"models/device{i}/pangu_alpha_1_graph.mindir"])
|
f"models/device_{i}/pangu_alpha_1_graph.mindir"])
|
||||||
distributed.startup_agents(distributed_address="0.0.0.0:6200", model_files=model_files)
|
distributed.startup_agents(distributed_address="0.0.0.0:6200", model_files=model_files)
|
||||||
|
|
||||||
|
|
|
@@ -21,8 +21,7 @@ from mindspore_serving.server import distributed
 
 def start():
     """Start server to serve service, and manage all agents which load and execute models"""
-    servable_dir = os.path.abspath(".")
-
+    servable_dir = os.path.dirname(os.path.realpath(__file__))
     distributed.start_servable(servable_dir, "pangu", rank_table_json_file="hccl_8p.json",
                                distributed_address="0.0.0.0:6200")
@@ -0,0 +1,203 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""servable config for pangu alpha"""
+
+import os
+import time
+from easydict import EasyDict
+import numpy as np
+from mindspore_serving.server import register
+
+from pangu.tokenization_jieba import JIEBATokenizer
+
+cur_dir = os.path.abspath(os.path.dirname(__file__))
+tokenizer_path = os.path.join(cur_dir, "tokenizer")
+tokenizer = JIEBATokenizer(os.path.join(tokenizer_path, "vocab.vocab"), os.path.join(tokenizer_path, "vocab.model"))
+end_token = tokenizer.eot_id
+
+config = EasyDict({
+    'frequency_penalty': 1.5,
+    'presence_penalty': 0.3,
+    'max_generate_length': 500,
+    'top_k_num': 3,
+    'top_p': 1.0,
+    'end_token': 9,
+    'seq_length': 1024,
+    'vocab_size': 40000,
+})
+
+
+def topk_fun(logits, topk=5):
+    """Get topk"""
+    target_column = logits[0].tolist()
+    sorted_array = [(k, v) for k, v in enumerate(target_column)]
+    sorted_array.sort(key=lambda x: x[1], reverse=True)
+    topk_array = sorted_array[:topk]
+    index, value = zip(*topk_array)
+    index = np.array([index])
+    value = np.array([value])
+    return value, index
+
+
+register.declare_servable(servable_file=["pangu_alpha_1024_graph.mindir", "pangu_alpha_1_graph.mindir"],
+                          model_format="MINDIR", with_batch_dim=False)
+
+
+@register.register_method(output_names=["logits"])
+def predict_sub0(input_ids, current_index, init, batch_valid_length):
+    logits = register.call_servable(input_ids, current_index, init, batch_valid_length, subgraph=0)
+    return logits
+
+
+@register.register_method(output_names=["logits"])
+def predict_sub1(input_id, current_index, init, batch_valid_length):
+    logits = register.call_servable(input_id, current_index, init, batch_valid_length, subgraph=1)
+    return logits
+
+
+sub0_servable = register.PipelineServable(servable_name="pangu", method="predict_sub0")
+sub1_servable = register.PipelineServable(servable_name="pangu", method="predict_sub1")
+
+
+@register.register_pipeline(output_names=["output_sentence"])
+def predict(input_sentence):
+    """generate sentence with given input_sentence"""
+    print(f"----------------------------- begin {input_sentence} ---------", flush=True)
+    time_start = time.time()
+
+    tokens = tokenizer.tokenize(input_sentence)
+    input_ids = tokenizer.convert_tokens_to_ids(tokens)
+
+    outputs = generate_increment(input_ids)
+
+    return_tokens = tokenizer.convert_ids_to_tokens(outputs)
+    reply = "".join(return_tokens)
+
+    print(f"time cost {(time.time() - time_start) * 1000}ms, request '{input_sentence}' get reply '{reply}'",
+          flush=True)
+
+    return reply
+
+
+def generate_increment(origin_inputs):
+    """
+    Text generation for incremental inference
+
+    Inputs:
+        origin_inputs: the original inputs based on which the model will continue writing
+        (sampling settings come from the module-level `config`)
+
+    Returns:
+        outputs: the ids for the generated text
+    """
+    # Get configurations for inference
+    frequency_penalty = config.frequency_penalty
+    presence_penalty = config.presence_penalty
+    top_p = config.top_p
+    top_k_num = config.top_k_num
+    max_generate_length = config.max_generate_length
+    seq_length = config.seq_length
+    vocab_size = config.vocab_size
+
+    # Init outputs with original inputs
+    outputs = origin_inputs
+    origin_inputs = np.array([origin_inputs])
+    _, valid_length = origin_inputs.shape
+    # If target length exceeds seq_length, use seq_length instead
+    target_length = valid_length + max_generate_length
+    target_length = seq_length if target_length > seq_length else target_length
+
+    # A list of the frequency of each token
+    frequency_list = np.array([[0 for _ in range(vocab_size)]])
+    pad_length = seq_length - origin_inputs.shape[-1]
+    # Pad original inputs to seq_length
+    input_ids = np.pad(origin_inputs, ((0, 0), (0, pad_length)), 'constant', constant_values=(0, 0))
+
+    # Indicate the exact token position
+    current_index = valid_length - 1 if valid_length - 1 > 0 else 0
+    current_index = np.array([current_index], np.int32)
+    batch_valid_length = np.array([current_index], np.int32)
+    # For the first graph, not_init should be false
+    init_true = True
+    init_false = False
+    init = init_false
+    # Call a single inference with input size of (bs, seq_length)
+    logits = sub0_servable.run(np.array(input_ids, np.int32), current_index, init, batch_valid_length)
+
+    # Claim the second graph and set not_init to true
+    init = init_true
+
+    # A single loop generates one token; loop until reaching target seq_length or generating the eod token
+    while valid_length < target_length:
+        # Reshape the output logits
+        log_probs = logits.reshape(1, vocab_size)
+
+        # Get the revised log_probs considering frequency and presence penalty to eliminate duplicates in generated results
+        log_probs_revised = log_probs - frequency_list * frequency_penalty - (frequency_list > 0) * presence_penalty
+
+        # Convert the log_probs to probability
+        logits = np.power(10, np.array(log_probs_revised, np.float32))
+
+        # If top_p is less than 1.0, use top_p sampling
+        if top_p < 1.0:
+            # Only consider the 5000 largest logits to reduce computation
+            sorted_logits, index = topk_fun(logits, 5000)
+            cumsum_logits = np.cumsum(sorted_logits, 1)
+            cumsum_logits = cumsum_logits[0]
+            index = index[0]
+            sorted_logits = sorted_logits[0]
+            top_p_num = sum(cumsum_logits > top_p)
+            # In case the probability is smooth, the sum of the 5000 largest probabilities is not large enough
+            if top_p_num == 0:
+                top_p_num = 5000
+            # Get the corresponding probs and indices
+            probs = sorted_logits[:top_p_num]
+            p_args = index[:top_p_num]
+            p = probs / sum(probs)
+        # If top_p is set to 1.0, use top_k sampling
+        else:
+            # Get the corresponding probs and indices
+            probs, p_args = topk_fun(logits, top_k_num)
+            probs = probs[0]
+            p_args = p_args[0]
+            # Avoid rounding error
+            if sum(probs) == 0:
+                probs = np.array([1 / top_k_num for _ in range(top_k_num)])
+            p = probs / sum(probs)
+
+        # Randomly select a token as the final output for this round
+        target_index = np.random.choice(len(p), p=p)
+        # Stop judgment
+        if p_args[target_index] == end_token or valid_length == target_length - 1:
+            break
+
+        # Update frequency list
+        target = p_args[target_index]
+        frequency_list[0][target] = frequency_list[0][target] + 1
+        valid_length += 1
+
+        batch_valid_length = np.array([valid_length - 1], np.int32)
+        current_index = np.array([0], np.int32)
+        input_id = np.array([[target]], np.int32)
+        # Update outputs with the current generated token
+        outputs.append(int(target))
+
+        # Call a single inference with input size of (bs, 1)
+        logits = sub1_servable.run(input_id, current_index, init, batch_valid_length)
+    # Return valid outputs out of padded outputs
+    return outputs
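The sampling logic above is plain NumPy, so it can be exercised stand-alone. The following harness is an editor's illustration (not part of this commit) that reruns the same `topk_fun` selection on fake probabilities to show the shapes involved:

```python
import numpy as np

def topk_fun(logits, topk=5):
    """Same helper as in servable_config.py: top-k values/indices of a (1, vocab) array."""
    target_column = logits[0].tolist()
    sorted_array = [(k, v) for k, v in enumerate(target_column)]
    sorted_array.sort(key=lambda x: x[1], reverse=True)
    index, value = zip(*sorted_array[:topk])
    return np.array([value]), np.array([index])

rng = np.random.default_rng(0)
fake_probs = rng.random((1, 16)).astype(np.float32)  # stands in for np.power(10, log_probs)

# top-k branch (used when top_p == 1.0): keep the 3 most likely tokens
probs, p_args = topk_fun(fake_probs, topk=3)
probs, p_args = probs[0], p_args[0]
p = probs / sum(probs)
target = p_args[np.random.choice(len(p), p=p)]
print("candidates:", p_args, "normalized probs:", p, "sampled token:", target)
```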
@@ -0,0 +1,76 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tokenization classes for OpenAI GPT."""
+from __future__ import (absolute_import, division, print_function,
+                        unicode_literals)
+
+from io import open
+import jieba
+import sentencepiece as spm
+
+
+class JIEBATokenizer:
+    """jieba tokenizer for encoding and decoding text"""
+
+    def __init__(self, vocab_file, model_file, max_len=None):
+        self.max_len = max_len if max_len is not None else int(1e12)
+        # Build the token-to-id map from the tab-separated vocab file
+        with open(vocab_file, 'r') as f:
+            lines = f.readlines()
+        self.encoder = {}
+        for idx, line in enumerate(lines):
+            key = line.split('\t')[0]
+            self.encoder[key] = idx
+
+        self.decoder = {v: k for k, v in self.encoder.items()}
+
+        self.sp = spm.SentencePieceProcessor(model_file=model_file)
+        self.translator = str.maketrans(" \n", "\u2582\u2583")
+
+        self.eod_id = self.encoder['<eod>']
+        self.eot_id = self.encoder['<eot>']
+        self.pad_id = self.encoder['<pad>']
+
+    @property
+    def vocab_size(self):
+        return len(self.encoder)
+
+    def __len__(self):
+        return len(self.encoder)
+
+    @property
+    def eod(self):
+        return self.eod_id
+
+    def tokenize(self, text):
+        """Tokenize a string."""
+        seg_list = [x.translate(self.translator) for x in jieba.cut(text, cut_all=False)]
+        new_seg = " ".join(seg_list)
+        return self.sp.encode(new_seg)
+
+    def convert_tokens_to_ids(self, tokens):
+        return tokens
+
+    def convert_ids_to_tokens(self, ids):
+        return self.decode(ids)
+
+    def encode(self, text):
+        res = self.tokenize(text)
+        return res
+
+    def decode(self, tokens):
+        text = self.sp.decode(tokens)
+        text = text.replace(' ', '').replace('\u2582', ' ').replace('\u2583', '\n')
+        return text
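A quick round-trip through the tokenizer looks like this (an editor's illustration, not part of this commit; the vocab paths are hypothetical and should point at the files downloaded in the Preparation step):

```python
from tokenization_jieba import JIEBATokenizer

# Hypothetical paths: point these at the downloaded vocab.vocab / vocab.model.
tok = JIEBATokenizer("tokenizer/vocab.vocab", "tokenizer/vocab.model")

ids = tok.tokenize("今天天气不错")   # jieba word segmentation, then sentencepiece ids
print(ids)                           # a list of integer token ids
print(tok.decode(ids))               # '\u2582'/'\u2583' marks map back to space/newline
```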
@@ -0,0 +1,30 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Serving server start code, serve service, and manage all agents which load and execute models"""
+
+import os
+from mindspore_serving import server
+
+
+def start():
+    servable_dir = os.path.dirname(os.path.realpath(__file__))
+    config = server.ServableStartConfig(servable_directory=servable_dir, servable_name="pangu", device_ids=(0,))
+    server.start_servables(config)
+
+    server.start_grpc_server("127.0.0.1:5500")
+
+
+if __name__ == "__main__":
+    start()
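Besides the flask frontend, the gRPC port opened above can be queried directly. The sketch below is an editor's assumption based on MindSpore Serving 1.3-era client examples (`Client(address, servable_name, method_name)`); verify the signature and the method name against the installed version before use:

```python
from mindspore_serving.client import Client

# Assumed 1.3-era API; "predict" matches the pipeline registered in servable_config.py.
client = Client("127.0.0.1:5500", "pangu", "predict")
instances = [{"input_sentence": "四川的省会是"}]
result = client.infer(instances)
print(result)
```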
@@ -21,7 +21,7 @@ export GLOG_v=1
 start_serving_server()
 {
     echo "### start serving server, see serving_server.log for detail ###"
-    python3 serving_server.py > serving_server.log 2>&1 &
+    python3 pangu_distributed/serving_server.py > serving_server.log 2>&1 &
     if [ $? -ne 0 ]
     then
         echo "serving server failed to start."
@@ -52,8 +52,8 @@ start_serving_server()
 
 start_serving_agent()
 {
-    echo "### start serving agent, see and serving_logs/log_pangu_distributed.log for detail ###"
-    python3 serving_agent.py > serving_agent.log 2>&1 &
+    echo "### start serving agent, see serving_agent.log and serving_logs/log_pangu_distributed.log for detail ###"
+    python3 pangu_distributed/serving_agent.py > serving_agent.log 2>&1 &
     if [ $? -ne 0 ]
     then
         echo "serving agent failed to start."
@@ -68,7 +68,7 @@ start_serving_agent()
         if [ $num -eq 0 ]
         then
            bash stop_pangu.sh
-           echo "start serving agent failed, see log and serving_logs/log_pangu_distributed.log for more detail" && exit 1
+           echo "start serving agent failed, see serving_agent.log and serving_logs/log_pangu_distributed.log for more detail" && exit 1
        fi
 
        count=$(($count+1))
@@ -78,7 +78,7 @@ start_serving_agent()
     if [ ${count} -eq 1800 ]
     then
        bash stop_pangu.sh
-       echo "start serving agent failed, see log and serving_logs/log_pangu_distributed.log for more detail" && exit 1
+       echo "start serving agent failed, see serving_agent.log and serving_logs/log_pangu_distributed.log for more detail" && exit 1
     fi
     echo "### start serving agent end ###"
 }
@@ -131,7 +131,7 @@ wait_serving_ready()
        if [ $num -eq 0 ]
        then
           bash stop_pangu.sh
-          echo "waiting serving server ready failed, see log serving_server.log and serving_logs/log_pangu_distributed.log for more detail" && exit 1
+          echo "waiting serving server ready failed, see serving_server.log and serving_logs/log_pangu_distributed.log for more detail" && exit 1
       fi
 
       count=$(($count+1))
@@ -141,7 +141,7 @@ wait_serving_ready()
     if [ ${count} -eq 100 ]
     then
        bash stop_pangu.sh
-       echo "waiting serving server ready failed, see log serving_server.log and serving_logs/log_pangu_distributed.log for more detail" && exit 1
+       echo "waiting serving server ready failed, see serving_server.log and serving_logs/log_pangu_distributed.log for more detail" && exit 1
     fi
     echo "### waiting serving server ready end ###"
 }
@@ -150,5 +150,5 @@ bash stop_pangu.sh
 rm -rf serving_server.log serving_agent.log flask.log serving_logs
 start_serving_server
 start_serving_agent
-start_flask
 wait_serving_ready
+start_flask
@@ -0,0 +1,120 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+unset http_proxy
+unset https_proxy
+export GLOG_v=1
+
+start_serving_server()
+{
+    echo "### start serving server, see serving_server.log for detail ###"
+    python3 pangu_standalone/serving_server.py > serving_server.log 2>&1 &
+    if [ $? -ne 0 ]
+    then
+        echo "serving server failed to start."
+    fi
+
+    result=`grep -E 'Master server start success' serving_server.log | wc -l`
+    count=0
+    while [[ ${result} -eq 0 && ${count} -lt 100 ]]
+    do
+        sleep 1
+
+        num=`ps -ef | grep 'serving_server.py' | grep -v grep | wc -l`
+        if [ $num -eq 0 ]
+        then
+            echo "start serving server failed, see log serving_server.log for more detail" && exit 1
+        fi
+
+        count=$(($count+1))
+        result=`grep -E 'Master server start success' serving_server.log | wc -l`
+    done
+
+    if [ ${count} -eq 100 ]
+    then
+        echo "start serving server failed, see log serving_server.log for more detail" && exit 1
+    fi
+    echo "### start serving server end ###"
+}
+
+start_flask()
+{
+    echo "### start flask server, see flask.log for detail ###"
+    python3 flask/client.py > flask.log 2>&1 &
+    if [ $? -ne 0 ]
+    then
+        echo "flask server failed to start."
+    fi
+
+    result=`grep -E 'Press CTRL\+C to quit' flask.log | wc -l`
+    count=0
+    while [[ ${result} -ne 1 && ${count} -lt 10 ]]
+    do
+        sleep 1
+
+        num=`ps -ef | grep 'flask/client.py' | grep -v grep | wc -l`
+        if [ $num -eq 0 ]
+        then
+            bash stop_pangu.sh
+            echo "start flask server failed, see log flask.log for more detail" && exit 1
+        fi
+
+        count=$(($count+1))
+        result=`grep -E 'Press CTRL\+C to quit' flask.log | wc -l`
+    done
+
+    if [ ${count} -eq 10 ]
+    then
+        bash stop_pangu.sh
+        echo "start flask server failed, see log flask.log for more detail" && exit 1
+    fi
+    echo "### start flask server end ###"
+    cat flask.log
+}
+
+wait_serving_ready()
+{
+    echo "### waiting serving server ready, see serving_logs/*.log for detail ###"
+    result=`grep -E 'gRPC server start success' serving_server.log | wc -l`
+    count=0
+    while [[ ${result} -eq 0 && ${count} -lt 1800 ]]
+    do
+        sleep 1
+
+        num=`ps -ef | grep 'serving_server.py' | grep -v grep | wc -l`
+        if [ $num -eq 0 ]
+        then
+            bash stop_pangu.sh
+            echo "waiting serving server ready failed, see serving_server.log and serving_logs/*.log for more detail" && exit 1
+        fi
+
+        count=$(($count+1))
+        result=`grep -E 'gRPC server start success' serving_server.log | wc -l`
+    done
+
+    if [ ${count} -eq 1800 ]
+    then
+        bash stop_pangu.sh
+        echo "waiting serving server ready failed, see serving_server.log and serving_logs/*.log for more detail" && exit 1
+    fi
+    echo "### waiting serving server ready end ###"
+}
+
+bash stop_pangu.sh
+rm -rf serving_server.log flask.log serving_logs
+start_serving_server
+wait_serving_ready
+start_flask
@@ -22,6 +22,7 @@ import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
 from mindspore.ops import operations as P
 
+
 def topk_fun(logits, topk=5):
     """Get topk"""
     target_column = logits[0].tolist()
@@ -33,6 +34,7 @@ def topk_fun(logits, topk=5):
     value = np.array([value])
     return value, index
 
+
 def sampler(log_probs_revised, top_p, top_k_num, use_pynative=False):
     """Convert the log_probs to probability"""
     if use_pynative:
@@ -80,6 +82,7 @@ def sampler(log_probs_revised, top_p, top_k_num, use_pynative=False):
     p = probs / sum(probs)
     return p, p_args
 
+
 def generate(model, origin_inputs, config):
     """
     Text generation
@@ -130,7 +133,7 @@ def generate(model, origin_inputs, config):
         # Random select a token as final output for this round
         target_index = np.random.choice(len(p), p=p)
         # Stop judgment
-        if p_args[target_index] == end_token or valid_length == target_length-1:
+        if p_args[target_index] == end_token or valid_length == target_length - 1:
             outputs = input_ids
             break
@@ -145,6 +148,7 @@ def generate(model, origin_inputs, config):
     outputs = outputs[0][:length]
     return outputs
 
+
 def generate_increment(model, origin_inputs, config):
     """
     Text generation for incremental inference
@@ -183,8 +187,8 @@ def generate_increment(model, origin_inputs, config):
 
     # Indicate the exact token position
     current_index = valid_length - 1 if valid_length - 1 > 0 else 0
-    current_index = Tensor(np.array([current_index]), mstype.int32)
     batch_valid_length = Tensor(np.array([current_index]), mstype.int32)
+    current_index = Tensor(np.array([current_index]), mstype.int32)
     # For first graph, not_init should be false
     init_true = Tensor([True], mstype.bool_)
     init_false = Tensor([False], mstype.bool_)
@@ -211,7 +215,7 @@ def generate_increment(model, origin_inputs, config):
         # Random select a token as final output for this round
         target_index = np.random.choice(len(p), p=p)
         # Stop judgment
-        if p_args[target_index] == end_token or valid_length == target_length-1:
+        if p_args[target_index] == end_token or valid_length == target_length - 1:
             break
 
         # Update frequency list
@ -27,10 +27,12 @@ from mindspore import context
|
||||||
from mindspore.common.seed import _get_graph_seed
|
from mindspore.common.seed import _get_graph_seed
|
||||||
from mindspore._checkparam import Validator
|
from mindspore._checkparam import Validator
|
||||||
|
|
||||||
|
|
||||||
class Dropout(nn.Cell):
|
class Dropout(nn.Cell):
|
||||||
r"""
|
r"""
|
||||||
A Dropout Implements with P.DropoutGenMask and P.DropoutDoMask for parallel training.
|
A Dropout Implements with P.DropoutGenMask and P.DropoutDoMask for parallel training.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, keep_prob=0.5, dtype=mstype.float32):
|
def __init__(self, keep_prob=0.5, dtype=mstype.float32):
|
||||||
super(Dropout, self).__init__()
|
super(Dropout, self).__init__()
|
||||||
if keep_prob <= 0 or keep_prob > 1:
|
if keep_prob <= 0 or keep_prob > 1:
|
||||||
|
@ -52,6 +54,7 @@ class Dropout(nn.Cell):
|
||||||
self.cast = P.Cast()
|
self.cast = P.Cast()
|
||||||
else:
|
else:
|
||||||
self.dropout = P.Dropout(keep_prob)
|
self.dropout = P.Dropout(keep_prob)
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
r"""
|
r"""
|
||||||
Input: a tensor
|
Input: a tensor
|
||||||
|
@@ -83,10 +86,12 @@ class Dropout(nn.Cell):
         else:
             self.dropout.shard(strategy)


 class LayerNorm(nn.Cell):
     r"""
     A self-defined layer norm operation using reduce sum and reduce mean
     """

     def __init__(self, normalized_shape, dp=4, eps=1e-5, parallel_optimizer=False):
         super(LayerNorm, self).__init__()
         self.gamma = Parameter(initializer('ones', normalized_shape), name="gamma",
@@ -102,6 +107,7 @@ class LayerNorm(nn.Cell):
         self.add2 = P.TensorAdd().shard(((dp, 1, 1), (1,)))
         self.real_div = P.RealDiv().shard(((dp, 1, 1), (dp, 1, 1)))
         self.eps = eps

     def construct(self, x):
         mean = self.mean(x, -1)
         diff = self.sub1(x, mean)
@@ -111,6 +117,7 @@ class LayerNorm(nn.Cell):
         output = self.add2(self.mul(output, self.gamma), self.beta)
         return output


 class Mapping(nn.Cell):
     """
     A mapping function with a 3d input
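Since the LayerNorm hunks only show fragments of `construct`, it may help to restate what it computes. The arithmetic below is the standard layer-norm formula in NumPy, assuming the usual reduce-mean/variance reading of the class above:

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalize the last axis, then rescale: the same steps construct() runs
    # with its sharded ReduceMean/Sub/RealDiv operators.
    mean = x.mean(axis=-1, keepdims=True)
    diff = x - mean
    variance = (diff ** 2).mean(axis=-1, keepdims=True)
    return (diff / np.sqrt(variance + eps)) * gamma + beta

x = np.random.randn(2, 4, 8).astype(np.float32)
out = layer_norm(x, np.ones(8, np.float32), np.zeros(8, np.float32))
print(out.mean(-1).round(5), out.std(-1).round(2))  # per-vector ~0 mean, ~1 std
```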
@@ -167,6 +174,7 @@ class MappingOutput(nn.Cell):
     Returns:
         output: Tensor, a 3d tensor after projection
     """

     def __init__(self, config, input_size, output_size, scale=1.0):
         super(MappingOutput, self).__init__()
         self.output_size = output_size
@@ -205,6 +213,7 @@ class FeedForwardLayer(nn.Cell):
     Returns:
         output: Tensor, the output of this layer after mapping
     """

     def __init__(self, config, scale=1.0):
         super(FeedForwardLayer, self).__init__()
         input_size = config.embedding_size
@@ -226,6 +235,7 @@ class FeedForwardLayer(nn.Cell):
         output = self.dropout(output)
         return output


 class EmbeddingLookup(nn.Cell):
     """
     The embedding lookup table for vocabulary
@@ -236,6 +246,7 @@ class EmbeddingLookup(nn.Cell):
                      seq_length, embedding_size)
         self.embedding_table: Tensor, the embedding table for the vocabulary
     """

     def __init__(self):
         super(EmbeddingLookup, self).__init__()
         self.gather = P.GatherV2()
@@ -244,6 +255,7 @@ class EmbeddingLookup(nn.Cell):
         output = self.gather(table, input_ids, 0)
         return output


 class Attention(nn.Cell):
     """
     Self-Attention module for each layer
@@ -253,6 +265,7 @@ class Attention(nn.Cell):
         scale: scale factor for initialization
         layer_idx: current layer index
     """

     def __init__(self, config, scale=1.0, layer_idx=None):
         super(Attention, self).__init__()
         # Output layer
@@ -295,7 +308,6 @@ class Attention(nn.Cell):
         self.softmax.softmax.shard(((config.dp, config.mp, 1),))
         self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))

-
         dense_shape = [config.embedding_size, config.embedding_size]
         bias_shape = [config.embedding_size]
         # Query
@@ -399,7 +411,7 @@ class Attention(nn.Cell):
             # The first graph with the input size of (bs, seq_length)
             if self.is_first_iteration:
                 # Get the valid input length without padding
-                valid_length_vector = F.cast(self.less(self.range, batch_valid_length), self.dtype)
+                valid_length_vector = F.cast(self.less(self.range, batch_valid_length.view(1, 1, -1)), self.dtype)
                 # Cover the key and value numbers corresponding to the padding position
                 key_present = self.mul1(key, self.expand_dims(valid_length_vector, 2))
                 value_present = self.mul1(value, self.expand_dims(valid_length_vector, 3))
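The `.view(1, 1, -1)` change makes the comparison broadcast cleanly against `self.range` (positions 0..seq_length-1). A NumPy sketch of the effect, assuming the usual cache layouts in this file (key as (bs, head, size_per_head, seq), value as (bs, head, seq, size_per_head)):

```python
import numpy as np

# Positions past the valid prompt length contribute zeros to the cached
# key/value, mirroring self.less(self.range, batch_valid_length.view(1, 1, -1)).
seq_length = 8
rng = np.arange(seq_length).reshape(1, 1, seq_length)            # self.range
batch_valid_length = np.array([3])
valid = (rng < batch_valid_length.reshape(1, 1, -1)).astype(np.float32)

key = np.ones((1, 2, 4, seq_length), np.float32)
value = np.ones((1, 2, seq_length, 4), np.float32)
key_present = key * np.expand_dims(valid, 2)      # mask the seq axis of key
value_present = value * np.expand_dims(valid, 3)  # mask the seq axis of value
print(key_present[0, 0, 0], value_present[0, 0, :, 0])
```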
@@ -464,7 +476,7 @@ class Attention(nn.Cell):
             x_merge: the 3d output
         """
         x = self.merger_head_transpose(
-            x, (0, 2, 1, 3)) #bs, seq_length, head, size_per_head
+            x, (0, 2, 1, 3))  # bs, seq_length, head, size_per_head
         x_shape = P.Shape()(x)
         new_shape = x_shape[:-2] + (x_shape[-2] * x_shape[-1],)
         x_merge = self.reshape(x, new_shape)
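The comment fixed above describes the head-merge step; in NumPy terms the transpose-and-fold looks like this (shapes are illustrative):

```python
import numpy as np

# Merge attention heads: back to (bs, seq, head, size_per_head), then fold
# the last two axes into one hidden dimension.
bs, head, seq_length, size_per_head = 2, 4, 8, 16
x = np.random.randn(bs, head, seq_length, size_per_head)
x = x.transpose(0, 2, 1, 3)                       # bs, seq, head, size_per_head
new_shape = x.shape[:-2] + (x.shape[-2] * x.shape[-1],)
x_merge = x.reshape(new_shape)
print(x_merge.shape)                              # (2, 8, 64)
```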
@@ -551,6 +563,7 @@ class Decoder(nn.Cell):
         output: Tensor, the output logit of this layer
         layer_present: Tensor, the feature map of current layer
     """

     def __init__(self, config, layer_idx):
         super(Decoder, self).__init__()
         scale = 1 / math.sqrt(2.0 * config.num_layers)
@@ -633,10 +646,12 @@ class Decoder(nn.Cell):
         output = self.add(x, mlp_logit)
         return output, layer_present


 class Embedding(nn.Cell):
     """
     Embedding
     """

     def __init__(self, config):
         super(Embedding, self).__init__()
         self.word_embedding = EmbeddingLookup().set_comm_fusion(1)
@@ -691,18 +706,22 @@ class Mask(nn.Cell):
     """
     Mask
     """

     def __init__(self, config):
         super(Mask, self).__init__()
         self.dtype = config.compute_dtype
         self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))

     def construct(self, attention_mask):
         attention_mask = self.expand_dims(attention_mask, 1)
         return attention_mask


 class QueryLayerAttention(Attention):
     r"""
     Self-Attention module using input query vector.
     """

     def construct(self, x, query_hidden_state, attention_mask, key_past=None, value_past=None, batch_valid_length=None):
         original_shape = F.shape(x)
         x = F.reshape(x, (-1, original_shape[-1]))
@@ -732,7 +751,7 @@ class QueryLayerAttention(Attention):
             # The first graph with the input size of (bs, seq_length)
             if self.is_first_iteration:
                 # Get the valid input length without padding
-                valid_length_vector = F.cast(self.less(self.range, batch_valid_length), self.dtype)
+                valid_length_vector = F.cast(self.less(self.range, batch_valid_length.view(1, 1, -1)), self.dtype)
                 # Cover the key and value numbers corresponding to the padding position
                 key_present = self.mul1(key, self.expand_dims(valid_length_vector, 2))
                 value_present = self.mul1(value, self.expand_dims(valid_length_vector, 3))
@@ -767,11 +786,13 @@ class QueryLayerAttention(Attention):
         output = self.dropout(output)
         return output, layer_present


 class QueryLayer(nn.Cell):
     r"""
     A block using looked-out position embedding as query vector.
     This is used as the final block.
     """

     def __init__(self, config):
         super(QueryLayer, self).__init__()
         scale = 1 / math.sqrt(2.0 * config.num_layers)
@@ -854,6 +875,7 @@ class QueryLayer(nn.Cell):
         output = self.add(x, mlp_logit)
         return output, layer_present


 class PanguAlphaEmbedding(nn.Cell):
     """
     Input embedding, i.e., word embedding and position embedding
@@ -870,6 +892,7 @@ class PanguAlphaEmbedding(nn.Cell):
         attention_mask: Tensor, attention_mask matrix
         embedding_table: Tensor, embedding_table with shape of (vocab_size, embedding_size)
     """

     def __init__(self, config):
         super(PanguAlphaEmbedding, self).__init__()
         self.embedding = Embedding(config)
@@ -885,6 +908,7 @@ class PanguAlphaEmbedding(nn.Cell):
         attention_mask = self.mask(attention_mask)
         return hidden_states, attention_mask


 class PanguAlpha_Model(nn.Cell):
     """
     The backbone of PanguAlpha network
|
||||||
present_layer: Tensor, the current feature map
|
present_layer: Tensor, the current feature map
|
||||||
embedding_table: Tensor, the embedding table for the vocabulary
|
embedding_table: Tensor, the embedding table for the vocabulary
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(PanguAlpha_Model, self).__init__()
|
super(PanguAlpha_Model, self).__init__()
|
||||||
self.embedding = PanguAlphaEmbedding(config)
|
self.embedding = PanguAlphaEmbedding(config)
|
||||||
|
@ -971,13 +996,14 @@ class PanguAlpha_Model(nn.Cell):
|
||||||
hidden_states, attention_mask = self.embedding(input_ids, input_mask, table,
|
hidden_states, attention_mask = self.embedding(input_ids, input_mask, table,
|
||||||
input_position, attention_mask,
|
input_position, attention_mask,
|
||||||
batch_valid_length)
|
batch_valid_length)
|
||||||
for i in range(self.num_layers-1):
|
for i in range(self.num_layers - 1):
|
||||||
hidden_states, _ = self.blocks[i](hidden_states,
|
hidden_states, _ = self.blocks[i](hidden_states,
|
||||||
attention_mask, init_reset, batch_valid_length)
|
attention_mask, init_reset, batch_valid_length)
|
||||||
if self.is_pipeline:
|
if self.is_pipeline:
|
||||||
top_query_hidden_states = self.top_query_embedding(input_position.view(-1,), self.top_query_embedding_table)
|
top_query_hidden_states = self.top_query_embedding(input_position.view(-1,),
|
||||||
hidden_states, _ = self.blocks[self.num_layers-1](hidden_states, top_query_hidden_states,
|
self.top_query_embedding_table)
|
||||||
attention_mask, init_reset, batch_valid_length)
|
hidden_states, _ = self.blocks[self.num_layers - 1](hidden_states, top_query_hidden_states,
|
||||||
|
attention_mask, init_reset, batch_valid_length)
|
||||||
output_state = self.layernorm(hidden_states)
|
output_state = self.layernorm(hidden_states)
|
||||||
output_state = F.cast(output_state, self.dtype)
|
output_state = F.cast(output_state, self.dtype)
|
||||||
else:
|
else:
|
||||||
|
@ -988,6 +1014,7 @@ class PanguAlpha_Model(nn.Cell):
|
||||||
attention_mask, init_reset, batch_valid_length)
|
attention_mask, init_reset, batch_valid_length)
|
||||||
return output_state
|
return output_state
|
||||||
|
|
||||||
|
|
||||||
class PanguAlpha_Head(nn.Cell):
|
class PanguAlpha_Head(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Head for PanguAlpha to get the logits of each token in the vocab
|
Head for PanguAlpha to get the logits of each token in the vocab
|
||||||
|
@ -999,6 +1026,7 @@ class PanguAlpha_Head(nn.Cell):
|
||||||
Returns:
|
Returns:
|
||||||
logits: Tensor, the logits of the corresponding inputs
|
logits: Tensor, the logits of the corresponding inputs
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(PanguAlpha_Head, self).__init__()
|
super(PanguAlpha_Head, self).__init__()
|
||||||
if config.word_emb_dp:
|
if config.word_emb_dp:
|
||||||
|
@ -1029,6 +1057,7 @@ class PanguAlpha(nn.Cell):
|
||||||
Returns:
|
Returns:
|
||||||
logits: Tensor: the logits of the corresponding inputs with shape (batch_size, seq_length, vocab_size)
|
logits: Tensor: the logits of the corresponding inputs with shape (batch_size, seq_length, vocab_size)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(PanguAlpha, self).__init__()
|
super(PanguAlpha, self).__init__()
|
||||||
# Network head to get logits over vocabulary
|
# Network head to get logits over vocabulary
|
||||||
|
@ -1065,6 +1094,7 @@ class PanguAlpha(nn.Cell):
|
||||||
logits = self.head(output_states, self.embedding_table)
|
logits = self.head(output_states, self.embedding_table)
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
class CrossEntropyLoss(nn.Cell):
|
class CrossEntropyLoss(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Calculate the cross entropy loss
|
Calculate the cross entropy loss
|
||||||
|
@ -1077,6 +1107,7 @@ class CrossEntropyLoss(nn.Cell):
|
||||||
Returns:
|
Returns:
|
||||||
loss: Tensor, the corrsponding cross entropy loss
|
loss: Tensor, the corrsponding cross entropy loss
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(CrossEntropyLoss, self).__init__()
|
super(CrossEntropyLoss, self).__init__()
|
||||||
self.mean = P.ReduceMean()
|
self.mean = P.ReduceMean()
|
||||||
|
@@ -1137,6 +1168,7 @@ class CrossEntropyLoss(nn.Cell):
         loss = self.div2(numerator, denominator)
         return loss


 class PanguAlphaWithLoss(nn.Cell):
     """
     PanguAlpha training loss
@@ -1150,6 +1182,7 @@ class PanguAlphaWithLoss(nn.Cell):
     Returns:
         output: Tensor, the loss of the network
     """

     def __init__(self, config, network, loss, eos_token=6):
         super(PanguAlphaWithLoss, self).__init__(auto_prefix=False)
         self.network = network
@@ -1172,6 +1205,40 @@ class PanguAlphaWithLoss(nn.Cell):
         output = self.loss(logits, labels, input_mask)
         return output


+class AttentionMask(nn.Cell):
+    """
+    Get the attention matrix for self-attention module
+    Args:
+        seq_length: the pre-defined sequence length
+    Inputs:
+        input_mask: the mask indicating whether each position is a valid input
+    Returns:
+        attention_mask: the attention mask matrix with shape (batch_size, seq_length, seq_length)
+    """
+
+    def __init__(self, seq_length):
+        super(AttentionMask, self).__init__()
+        self.reshape = P.Reshape()
+        self.mul = P.BatchMatMul().shard(
+            ((1, 1, 1), (1, 1, 1)))
+        self.expand_dim = P.ExpandDims().shard(((1, 1),))
+        ones = np.ones(shape=(seq_length, seq_length))
+        self.lower_triangle_mask = Tensor(np.tril(ones), mstype.float32)
+        self.multiply = P.Mul().shard(((1, 1, 1), (1, 1, 1)))
+
+    def construct(self, input_mask):
+        input_shape = P.Shape()(input_mask)
+        shape_right = (input_shape[0], 1, input_shape[1])
+        shape_left = input_shape + (1,)
+        mask_left = self.reshape(input_mask, shape_left)
+        mask_right = self.reshape(input_mask, shape_right)
+        attention_mask = self.mul(mask_left, mask_right)
+        lower_triangle = self.expand_dim(self.lower_triangle_mask, 0)
+        attention_mask = self.multiply(attention_mask, lower_triangle)
+        return attention_mask
+
+
 class EvalNet(nn.Cell):
     """
     PanguAlpha evaluation net
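The new `AttentionMask` cell composes two masks: an outer product of the input mask with itself (both the query and the key position must be real tokens) and a lower triangle (causality). A NumPy restatement of `construct`:

```python
import numpy as np

def attention_mask(input_mask):
    # Outer product marks (query, key) pairs where both tokens are valid;
    # the lower triangle then blocks attention to future positions.
    bs, seq = input_mask.shape
    mask_left = input_mask.reshape(bs, seq, 1)
    mask_right = input_mask.reshape(bs, 1, seq)
    mask = mask_left * mask_right                  # (bs, seq, seq)
    return mask * np.tril(np.ones((seq, seq), np.float32))

m = np.array([[1, 1, 1, 0]], np.float32)           # 3 valid tokens, 1 pad
print(attention_mask(m)[0])
```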
@@ -1185,7 +1252,8 @@ class EvalNet(nn.Cell):
     Returns:
         outputs: Tensor, corresponding output for different tasks
     """
-    def __init__(self, backbone, generate=False, pad_token=6):
+
+    def __init__(self, backbone, generate=False, pad_token=6, seq_length=1024):
         super(EvalNet, self).__init__(auto_prefix=False)
         self.backbone = backbone
         self.pad_token = pad_token
@@ -1194,14 +1262,19 @@ class EvalNet(nn.Cell):
         self.topk = P.TopK(sorted=True).shard(((1, 1),))
         self.gather = P.GatherV2().shard(((1, 1), (1,)))
         self.log_softmax = P.LogSoftmax().shard(((1, 1, 1),))
+        self.get_attention_mask = AttentionMask(seq_length)

     def construct(self, input_ids, current_index, init_reset=True, batch_valid_length=None):
         """evaluation net"""
         input_mask = F.cast(F.not_equal(input_ids, self.pad_token), mstype.float32)
-        logits = self.backbone(input_ids, input_mask)
+        bs, seq_length = F.shape(input_ids)
+        attention_mask = self.get_attention_mask(input_mask)
+        input_position = F.tuple_to_array(F.make_range(seq_length))
+        input_position = P.Tile()(input_position, (bs, 1))
+        logits = self.backbone(input_ids, input_mask, input_position, attention_mask,
+                               init_reset, batch_valid_length)
         index = current_index.view(1,)
         logits = self.gather(logits, index, 0)
-        bs, _ = F.shape(input_ids)
         logits = logits.view(bs, 1, -1)
         log_probs = self.log_softmax(logits)
         return log_probs
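With this change `EvalNet` derives the position ids and causal mask itself instead of expecting a two-argument backbone. Roughly, the preparation amounts to the following, with NumPy stand-ins for `F.make_range`/`P.Tile` and `pad_token=6` as in the constructor:

```python
import numpy as np

# Hypothetical inputs illustrating what construct() now builds per batch.
input_ids = np.array([[31, 42, 17, 6], [11, 6, 6, 6]])
bs, seq_length = input_ids.shape
input_mask = (input_ids != 6).astype(np.float32)            # 1 = real token
input_position = np.tile(np.arange(seq_length), (bs, 1))    # P.Tile equivalent
print(input_position)  # [[0 1 2 3], [0 1 2 3]]
```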
@@ -71,12 +71,14 @@ class FP32StateAdamWeightDecay(AdamWeightDecay):

 get_square_sum = C.MultitypeFuncGraph("get_square_sum")


 @get_square_sum.register("Tensor", "Number")
 def _get_square_sum(grad, value):
     norm = P.ReduceSum(False)(F.square(grad), ()) / value
     norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
     return norm


 apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
@@ -85,6 +87,7 @@ def _apply_global_norm(clip_norm, global_norm, grad):
     grad = grad * clip_norm / global_norm
     return grad


 def _get_model_parallel_group(mp):
     """
|
||||||
local_stage_rank_id = rank % per_stage_device_nums
|
local_stage_rank_id = rank % per_stage_device_nums
|
||||||
index = local_stage_rank_id // mp
|
index = local_stage_rank_id // mp
|
||||||
group = range(0, mp)
|
group = range(0, mp)
|
||||||
rank_str_list = [str(x + index * mp + stage_id * per_stage_device_nums) for x in group]
|
rank_str_list = [str(x + index * mp + stage_id * per_stage_device_nums) for x in group]
|
||||||
rank_list_str = "-".join(rank_str_list)
|
rank_list_str = "-".join(rank_str_list)
|
||||||
rank_list = [x + index * mp + stage_id * per_stage_device_nums for x in group]
|
rank_list = [x + index * mp + stage_id * per_stage_device_nums for x in group]
|
||||||
return rank_list, rank_list_str
|
return rank_list, rank_list_str
|
||||||
|
|
||||||
|
|
||||||
def _get_pipeline_group():
|
def _get_pipeline_group():
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
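A worked example of the rank arithmetic in `_get_model_parallel_group`, under assumed values (16 devices, 8 per pipeline stage, mp = 4, global rank 9; `stage_id` is presumably derived from the rank as below):

```python
# Assumed setup: 2 pipeline stages of 8 devices each, model parallel size 4.
rank, per_stage_device_nums, mp = 9, 8, 4
stage_id = rank // per_stage_device_nums             # 1 (assumed derivation)
local_stage_rank_id = rank % per_stage_device_nums   # 1
index = local_stage_rank_id // mp                    # 0: first mp group in stage
rank_list = [x + index * mp + stage_id * per_stage_device_nums
             for x in range(0, mp)]
print(rank_list)                                     # [8, 9, 10, 11]
```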
@@ -121,10 +125,12 @@ def _get_pipeline_group():
     rank_list_str = "-".join(rank_str_list)
     return rank_list, rank_list_str


 class GlobalNorm(nn.Cell):
     """
     Calculate the global norm value of given tensors
     """

     def __init__(self, params, config):
         super(GlobalNorm, self).__init__()
         self.norm = nn.Norm()
|
||||||
global_norms = F.sqrt(P.AllReduce()(square_reduce_sum))
|
global_norms = F.sqrt(P.AllReduce()(square_reduce_sum))
|
||||||
return global_norms
|
return global_norms
|
||||||
|
|
||||||
|
|
||||||
class ClipByGlobalNorm(nn.Cell):
|
class ClipByGlobalNorm(nn.Cell):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Clip grads by global norm
|
Clip grads by global norm
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, params, config, clip_norm=1.0):
|
def __init__(self, params, config, clip_norm=1.0):
|
||||||
super(ClipByGlobalNorm, self).__init__()
|
super(ClipByGlobalNorm, self).__init__()
|
||||||
self.global_norm = GlobalNorm(params, config)
|
self.global_norm = GlobalNorm(params, config)
|
||||||
|
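`GlobalNorm` and `ClipByGlobalNorm` together implement the familiar clip-by-global-norm rule. A NumPy sketch of that rule follows; the AllReduce bookkeeping across model-parallel devices is elided, and the conditional form here is one common variant rather than the exact graph these cells build:

```python
import numpy as np

def clip_by_global_norm(grads, clip_norm=1.0):
    # Global norm: sqrt of the summed squared entries of every gradient.
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    if global_norm > clip_norm:
        grads = [g * clip_norm / global_norm for g in grads]
    return grads, global_norm

grads = [np.array([3.0, 4.0]), np.array([12.0])]
clipped, norm = clip_by_global_norm(grads)
print(norm, clipped)   # 13.0, every gradient scaled by 1/13
```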
@@ -185,6 +193,7 @@ class ClipByGlobalNorm(nn.Cell):
         grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
         return grads, global_norm_value


 class LearningRate(LearningRateSchedule):
     """
     Warmup-decay learning rate for PanguAlpha network.
@@ -259,6 +268,11 @@ def add_inference_params(opt):
                      type=int,
                      default=0,
                      help="Whether use pynative op for postprocess")
+    opt.add_argument("--use_past",
+                     type=str,
+                     default="true",
+                     choices=["true", "false"],
+                     help="Whether enable state reuse")


 def add_training_params(opt):
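Note that `--use_past` is declared as a string with `choices`, not `type=bool` (argparse would treat any non-empty string as truthy). Callers are therefore expected to map it to a bool themselves; a minimal, assumed usage pattern:

```python
import argparse

# Rebuild just this flag to show the intended parsing; names match the hunk.
parser = argparse.ArgumentParser()
parser.add_argument("--use_past", type=str, default="true",
                    choices=["true", "false"],
                    help="Whether enable state reuse")
args = parser.parse_args(["--use_past", "false"])
use_past = (args.use_past == "true")   # explicit string-to-bool conversion
print(use_past)                        # False
```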