forked from mindspore-Ecosystem/mindspore
!19308 Serving, pangu alpha update
Merge pull request !19308 from 徐永飞/master
commit 03ba1d08b3
@ -138,12 +138,13 @@ ${FILE_PATH}/tokenizer/ ${FILE_PATH}/checkpoint_file filitered 2.6B fp16
In the serving directory:
- Download the [PanGu-Alpha tokenizer repository](https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha.git) and copy pangu-alpha/tokenizer to the directory flask/tokenizer.
- Use scripts/run_distribute_export.sh to export MindIR models, and copy all device* directories to serving_increment/models/.
- Download the [PanGu-Alpha tokenizer repository](https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha.git) and copy pangu-alpha/tokenizer to the directory pangu/tokenizer.
- Pip install the MindSpore and MindSpore Serving 1.2 whl packages.
- Pip install the flask, flask-apscheduler, jieba, and sentencepiece whl packages.
- Edit server_agent.py and update the path of the pangu-alpha models.
- Run 'bash start_pangu.sh' to start a new execution.
- Wait for serving to start successfully: observe the serving_server.log file until the message "Serving master: wait for Ctrl+C to exit" is output.
- Wait for serving to start successfully: observe the serving_server.log file until the message "Serving: gRPC server start success, listening on 127.0.0.1:5500" is output.
- If any error happens, the logs can be viewed in serving_server.log, serving_agent.log, and flask.log.
- If everything is all right, access the address {ip}:5000 in a browser (a minimal request example follows this list).
- Run 'bash stop_pangu.sh' to stop the existing execution.
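Once start_pangu.sh reports success, the flask service can also be exercised outside the browser page. A minimal sketch using only the Python standard library; the query parameter "u" and the "ok"/"rsvp" response fields follow the page's JavaScript in this PR, and the address assumes the service runs locally on the default port 5000:

import json
from urllib.parse import urlencode
from urllib.request import urlopen

query = urlencode({"u": "<your prompt>"})  # "u" is the parameter the web page sends
with urlopen("http://127.0.0.1:5000/query?" + query) as resp:
    reply = json.load(resp)
print(reply.get("ok"), reply.get("rsvp"))  # fields read by the page's JavaScript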
@ -16,23 +16,26 @@
|
|||
PanGu predict run
|
||||
"""
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.model import Model
|
||||
import mindspore.communication.management as D
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.serialization import load_distributed_checkpoint
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
|
||||
import mindspore.communication.management as D
|
||||
from mindspore import context, Tensor
|
||||
from mindspore import export
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.parallel import set_algo_parameters
|
||||
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_distributed_checkpoint
|
||||
from src.pangu_alpha import PanguAlpha, EvalNet
|
||||
from src.pangu_alpha_config import PANGUALPHAConfig, set_parse
|
||||
from src.utils import get_args
|
||||
|
||||
|
||||
def run_predict(args_opt):
|
||||
def load_model(args_opt):
|
||||
r"""
|
||||
The main function for running prediction
|
||||
The main function for loading the model
|
||||
"""
|
||||
device_id = int(os.getenv("DEVICE_ID"))
|
||||
rank_id_str = os.getenv('RANK_ID', '0')
|
||||
|
@ -73,6 +76,9 @@ def run_predict(args_opt):
|
|||
rank = 0
|
||||
device_num = 1
|
||||
|
||||
use_past = False
|
||||
if args_opt.export:
|
||||
use_past = True
|
||||
# Set model property
|
||||
model_parallel_num = args_opt.op_level_model_parallel_num
|
||||
data_parallel_num = int(device_num / model_parallel_num)
|
||||
|
@ -91,7 +97,7 @@ def run_predict(args_opt):
|
|||
post_layernorm_residual=False,
|
||||
dropout_rate=0.0,
|
||||
compute_dtype=mstype.float16,
|
||||
use_past=False,
|
||||
use_past=use_past,
|
||||
self_layernorm=True,
|
||||
stage_num=args_opt.stage_num,
|
||||
micro_size=args_opt.micro_size,
|
||||
|
@ -108,7 +114,6 @@ def run_predict(args_opt):
|
|||
eval_net = EvalNet(pangu_alpha)
|
||||
eval_net.set_train(False)
|
||||
model_predict = Model(eval_net)
|
||||
|
||||
# Compile network and obtain tensor layout for loading ckpt
|
||||
inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32)
|
||||
current_index = Tensor(np.array([0]), mstype.int32)
|
||||
|
@ -130,7 +135,29 @@ def run_predict(args_opt):
|
|||
# Load checkpoint files
|
||||
load_distributed_checkpoint(eval_net, ckpt_file_list, predict_layout)
|
||||
print("================load param ok=================", flush=True)
|
||||
return model_predict, config
|
||||
|
||||
|
||||
def export_mindir(model_predict, config):
|
||||
"""Export mindir model"""
|
||||
inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32)
|
||||
current_index = Tensor(np.array([0]), mstype.int32)
|
||||
|
||||
batch_valid_length = Tensor(np.array([0]), mstype.int32)
|
||||
init_true = Tensor([True], mstype.bool_)
|
||||
inputs_np_1 = Tensor(np.ones(shape=(config.batch_size, 1)), mstype.int32)
|
||||
|
||||
model_predict.predict_network.add_flags_recursive(is_first_iteration=True)
|
||||
export(model_predict.predict_network, inputs_np, current_index,
|
||||
init_true, batch_valid_length, file_name='pangu_alpha_1024', file_format='MINDIR')
|
||||
model_predict.predict_network.add_flags_recursive(is_first_iteration=False)
|
||||
export(model_predict.predict_network, inputs_np_1, current_index,
|
||||
init_true, batch_valid_length, file_name='pangu_alpha_1', file_format='MINDIR')
|
||||
print("Export finished and now exit.")
|
||||
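# Note on the two exports above: they correspond to the two servable methods
# (predict_sub0 / predict_sub1) registered in the new servable_config.py later in this
# PR; the (batch_size, seq_length) graph handles the first pass over the padded prompt,
# and the (batch_size, 1) graph handles each incrementally generated token.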
|
||||
|
||||
def run_predict(model_predict, config, args_opt):
|
||||
"""run predict"""
|
||||
from src.tokenization_jieba import JIEBATokenizer
|
||||
from src.generate import generate, generate_increment
|
||||
# Define tokenizer
|
||||
|
@ -144,12 +171,22 @@ def run_predict(args_opt):
|
|||
input_ids = np.array(start_sentence).reshape(1, -1)
|
||||
# Call inference
|
||||
generate_func = generate_increment if config.use_past else generate
|
||||
output_ids = generate_func(model_predict, input_ids, opt)
|
||||
output_ids = generate_func(model_predict, input_ids, args_opt)
|
||||
# Decode output ids to sentence
|
||||
output_samples = tokenizer.convert_ids_to_tokens(output_ids.tolist())
|
||||
print('Output is:', output_samples, flush=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def main():
|
||||
"""Main process for predict or export model"""
|
||||
opt = get_args(True)
|
||||
set_parse(opt)
|
||||
run_predict(opt)
|
||||
model_predict, config = load_model(opt)
|
||||
if opt.export:
|
||||
export_mindir(model_predict, config)
|
||||
else:
|
||||
run_predict(model_predict, config, opt)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
#!/bin/bash
|
||||
execute_path=$(pwd)
|
||||
script_self=$(readlink -f "$0")
|
||||
self_path=$(dirname "${script_self}")
|
||||
export RANK_SIZE=8
|
||||
export RANK_TABLE_FILE=${execute_path}/../serving_increment/hccl_8p.json
|
||||
export MODE=13B
|
||||
export STRATEGY=$1
|
||||
export CKPT_PATH=$2
|
||||
export CKPT_NAME=$3
|
||||
export PARAM_INIT_TYPE=$4
|
||||
|
||||
for((i=0;i<$RANK_SIZE;i++));
|
||||
do
|
||||
rm -rf ${execute_path}/device_$i/
|
||||
mkdir ${execute_path}/device_$i/
|
||||
cd ${execute_path}/device_$i/ || exit
|
||||
export RANK_ID=$i
|
||||
export DEVICE_ID=$i
|
||||
python -s ${self_path}/../predict.py --strategy_load_ckpt_path=$STRATEGY --load_ckpt_path=$CKPT_PATH \
|
||||
--load_ckpt_name=$CKPT_NAME --mode=$MODE --run_type=predict --param_init_type=$PARAM_INIT_TYPE \
|
||||
--export=1 >log$i.log 2>&1 &
|
||||
done
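Each iteration of the loop above leaves a device_$i directory containing log$i.log for that rank. As a quick, optional check that every rank finished exporting before the device* directories are copied to serving_increment/models/, one can scan the logs for the "Export finished and now exit." message printed by predict.py; a small sketch, assuming it is run from the directory where run_distribute_export.sh was launched:

import pathlib

for i in range(8):
    log = pathlib.Path(f"device_{i}/log{i}.log")
    text = log.read_text(errors="ignore") if log.exists() else ""
    status = "done" if "Export finished and now exit." in text else "not finished"
    print(log, status)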
|
|
@ -1,82 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""servable config for pangu alpha"""
|
||||
|
||||
|
||||
from mindspore_serving.worker import register
|
||||
from mindspore_serving.worker import distributed
|
||||
import numpy as np
|
||||
|
||||
# define preprocess pipeline, the function arg is multi instances, every instance is tuple of inputs
|
||||
# this example has one input and one output
|
||||
|
||||
seq_length = 1024
|
||||
|
||||
|
||||
def preprocess(input_tokens):
|
||||
"""Preprocess, padding for input"""
|
||||
_, valid_length = input_tokens.shape
|
||||
token_ids = np.pad(input_tokens, ((0, 0), (0, seq_length - valid_length)), 'constant', constant_values=(0, 6))
|
||||
token_ids = token_ids.astype(np.int32)
|
||||
return token_ids, valid_length
|
||||
|
||||
|
||||
def topk_fun(logits, valid_length, topk=5):
|
||||
"""Get topk"""
|
||||
target_column = logits[valid_length - 1, :].tolist()
|
||||
sorted_array = [(k, v) for k, v in enumerate(target_column)]
|
||||
sorted_array.sort(key=lambda x: x[1], reverse=True)
|
||||
topk_array = sorted_array[:topk]
|
||||
index, value = zip(*topk_array)
|
||||
index = np.array(index)
|
||||
value = np.array(value)
|
||||
return index, value
|
||||
|
||||
|
||||
def postprocess_topk(logits, valid_length):
|
||||
"""Postprocess for one output"""
|
||||
p_args, p = topk_fun(logits, valid_length, 5)
|
||||
p = p / sum(p)
|
||||
target_index = np.random.choice(len(p), p=p)
|
||||
target = p_args[target_index]
|
||||
return target
|
||||
|
||||
|
||||
def postprocess(p, p_args, valid_length):
|
||||
"""Postprocess for two output"""
|
||||
p = p[valid_length - 1]
|
||||
p_args = p_args[valid_length - 1]
|
||||
p = p / sum(p)
|
||||
target_index = np.random.choice(len(p), p=p)
|
||||
target = p_args[target_index]
|
||||
return target
|
||||
|
||||
|
||||
distributed.declare_distributed_servable(rank_size=8, stage_size=1, with_batch_dim=False)
|
||||
|
||||
|
||||
@register.register_method(output_names=["add_token"])
|
||||
def predict(input_tokens):
|
||||
"""register predict method in pangu-alpha"""
|
||||
token_ids, valid_length = register.call_preprocess(preprocess, input_tokens)
|
||||
############# two output ###################
|
||||
# p, p_args = register.call_servable(token_ids)
|
||||
# add_token = register.call_postprocess(postprocess, p, p_args, valid_length)
|
||||
#############################################
|
||||
|
||||
################# one output ####################
|
||||
logits = register.call_servable(token_ids)
|
||||
add_token = register.call_postprocess(postprocess_topk, logits, valid_length)
|
||||
return add_token
|
|
@ -14,65 +14,28 @@
|
|||
# ============================================================================
|
||||
"""flask server, Serving client"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import numpy as np
|
||||
from flask import Flask, request, jsonify, render_template
|
||||
from flask_apscheduler import APScheduler
|
||||
from mindspore_serving.client import Client
|
||||
from tokenization_jieba import JIEBATokenizer
|
||||
|
||||
cur_dir = os.path.abspath(os.path.dirname(__file__))
|
||||
tokenizer_path = os.path.join(cur_dir, "tokenizer")
|
||||
tokenizer = JIEBATokenizer(os.path.join(tokenizer_path, "vocab.vocab"), os.path.join(tokenizer_path, "vocab.model"))
|
||||
end_token = tokenizer.eot_id
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
|
||||
def generate(input_sentence):
|
||||
"""nerate sentence with given input_sentence"""
|
||||
client = Client("localhost", 5500, "pangu", "predict")
|
||||
client = Client("localhost:5500", "pangu", "predict")
|
||||
|
||||
print(f"----------------------------- begin {input_sentence} ---------", flush=True)
|
||||
tokens = tokenizer.tokenize(input_sentence)
|
||||
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
input_len = len(input_ids)
|
||||
|
||||
seq_length = 1024
|
||||
generate_number = 0
|
||||
end_flag = False
|
||||
|
||||
instance = {"input_sentence": input_sentence}
|
||||
time_start = time.time()
|
||||
while generate_number < 50:
|
||||
if len(input_ids) >= seq_length - 1:
|
||||
break
|
||||
result = client.infer(instance)
|
||||
reply = result[0]["output_sentence"]
|
||||
|
||||
time0 = time.time()
|
||||
instance = {"input_tokens": np.array([input_ids])}
|
||||
target = client.infer([instance])
|
||||
target = int(target[0]["add_token"])
|
||||
print(f"request '{input_sentence}' add token {generate_number}: {target}, "
|
||||
f"time cost {(time.time() - time0) * 1000}ms", flush=True)
|
||||
if target == end_token:
|
||||
if len(input_ids) == input_len:
|
||||
continue
|
||||
end_flag = True
|
||||
break
|
||||
print(f"time cost {(time.time() - time_start) * 1000}ms, request '{input_sentence}' get reply '{reply}'",
|
||||
flush=True)
|
||||
|
||||
input_ids.append(target)
|
||||
generate_number += 1
|
||||
|
||||
outputs = input_ids[input_len:]
|
||||
return_tokens = tokenizer.convert_ids_to_tokens(outputs)
|
||||
reply = "".join(return_tokens)
|
||||
if reply:
|
||||
break
|
||||
|
||||
print(f"time cost {(time.time() - time_start) * 1000}ms, request '{input_sentence}' get reply '{reply}'"
|
||||
f" end flag{end_flag}", flush=True)
|
||||
|
||||
return reply, end_flag
|
||||
return reply
|
||||
|
||||
|
||||
@app.route('/query')
|
|
@ -44,7 +44,7 @@
|
|||
}
|
||||
document.getElementById('inputId').onkeydown = onTextareaKeydown;
|
||||
|
||||
function get_sentence(resp, sentence, count, session_id) {
|
||||
function get_sentence(resp, sentence, session_id) {
|
||||
const payload = new URLSearchParams()
|
||||
payload.set("u", sentence)
|
||||
fetch('/query?' + payload.toString()).then(resp => resp.json()).then(rv => {
|
||||
|
@ -52,12 +52,9 @@
|
|||
return;
|
||||
}
|
||||
if (rv.ok) {
|
||||
new_sentence = sentence + rv.rsvp;
|
||||
new_sentence = rv.rsvp;
|
||||
resp.innerHTML = new_sentence
|
||||
$(".output").val(resp.innerHTML);//请求结果放入输出框中
|
||||
if (count > 0 && !rv.end_flag) {
|
||||
get_sentence(resp, new_sentence, count - 1, session_id)
|
||||
}
|
||||
} else {
|
||||
resp.innerHTML = "something is wrong.";
|
||||
}
|
||||
|
@ -74,7 +71,7 @@
|
|||
}
|
||||
newText = input.replace(new RegExp(/( )/g), "").replace(/\n/g, ' NEWLINE ');
|
||||
g_session_id += 1
|
||||
get_sentence(resp, newText, 200, g_session_id)
|
||||
get_sentence(resp, newText, g_session_id)
|
||||
}
|
||||
</script>
|
||||
</body>
|
|
@ -0,0 +1,203 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""servable config for pangu alpha"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from easydict import EasyDict
|
||||
import numpy as np
|
||||
from mindspore_serving.server import register
|
||||
from mindspore_serving.server import distributed
|
||||
|
||||
from pangu.tokenization_jieba import JIEBATokenizer
|
||||
|
||||
cur_dir = os.path.abspath(os.path.dirname(__file__))
|
||||
tokenizer_path = os.path.join(cur_dir, "tokenizer")
|
||||
tokenizer = JIEBATokenizer(os.path.join(tokenizer_path, "vocab.vocab"), os.path.join(tokenizer_path, "vocab.model"))
|
||||
end_token = tokenizer.eot_id
|
||||
|
||||
config = EasyDict({
|
||||
'frequency_penalty': 1.5,
|
||||
'presence_penalty': 0.3,
|
||||
'max_generate_length': 500,
|
||||
'top_k_num': 3,
|
||||
'top_p': 1.0,
|
||||
'end_token': 9,
|
||||
'seq_length': 1024,
|
||||
'vocab_size': 40000,
|
||||
})
|
||||
|
||||
|
||||
def topk_fun(logits, topk=5):
|
||||
"""Get topk"""
|
||||
target_column = logits[0].tolist()
|
||||
sorted_array = [(k, v) for k, v in enumerate(target_column)]
|
||||
sorted_array.sort(key=lambda x: x[1], reverse=True)
|
||||
topk_array = sorted_array[:topk]
|
||||
index, value = zip(*topk_array)
|
||||
index = np.array([index])
|
||||
value = np.array([value])
|
||||
return value, index
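# For example (illustrative values): topk_fun(np.array([[0.1, 0.5, 0.2, 0.9]]), topk=2)
# returns value = [[0.9, 0.5]] and index = [[3, 1]]; both outputs keep a leading batch
# dimension of 1, which the sampling code in generate_increment below relies on.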
|
||||
|
||||
|
||||
distributed.declare_servable(rank_size=8, stage_size=1, with_batch_dim=False)
|
||||
|
||||
|
||||
@register.register_method(output_names=["logits"])
|
||||
def predict_sub0(input_ids, current_index, init, batch_valid_length):
|
||||
logits = register.call_servable(input_ids, current_index, init, batch_valid_length, subgraph=0)
|
||||
return logits
|
||||
|
||||
|
||||
@register.register_method(output_names=["logits"])
|
||||
def predict_sub1(input_id, current_index, init, batch_valid_length):
|
||||
logits = register.call_servable(input_id, current_index, init, batch_valid_length, subgraph=1)
|
||||
return logits
|
||||
|
||||
|
||||
sub0_servable = register.PipelineServable(servable_name="pangu", method="predict_sub0")
|
||||
sub1_servable = register.PipelineServable(servable_name="pangu", method="predict_sub1")
|
||||
|
||||
|
||||
@register.register_pipeline(output_names=["output_sentence"])
|
||||
def predict(input_sentence):
|
||||
"""generate sentence with given input_sentence"""
|
||||
|
||||
print(f"----------------------------- begin {input_sentence} ---------", flush=True)
|
||||
time_start = time.time()
|
||||
|
||||
tokens = tokenizer.tokenize(input_sentence)
|
||||
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
|
||||
outputs = generate_increment(input_ids)
|
||||
|
||||
return_tokens = tokenizer.convert_ids_to_tokens(outputs)
|
||||
reply = "".join(return_tokens)
|
||||
|
||||
print(f"time cost {(time.time() - time_start) * 1000}ms, request '{input_sentence}' get reply '{reply}'",
|
||||
flush=True)
|
||||
|
||||
return reply
|
||||
|
||||
|
||||
def generate_increment(origin_inputs):
|
||||
"""
|
||||
Text generation for incremental inference
|
||||
|
||||
Inputs:
|
||||
origin_inputs: the original inputs based on which the model will continue writing
(the model is called through sub0_servable / sub1_servable, and the inference
configuration is taken from the module-level config above)
|
||||
|
||||
Returns:
|
||||
outputs: the ids for the generated text
|
||||
"""
|
||||
# Get configurations for inference
|
||||
frequency_penalty = config.frequency_penalty
|
||||
presence_penalty = config.presence_penalty
|
||||
top_p = config.top_p
|
||||
top_k_num = config.top_k_num
|
||||
max_generate_length = config.max_generate_length
|
||||
seq_length = config.seq_length
|
||||
vocab_size = config.vocab_size
|
||||
|
||||
# Init outputs with original inputs
|
||||
outputs = origin_inputs
|
||||
origin_inputs = np.array([origin_inputs])
|
||||
_, valid_length = origin_inputs.shape
|
||||
# If target length exceeds seq_length, use seq_length instead
|
||||
target_length = valid_length + max_generate_length
|
||||
target_length = seq_length if target_length > seq_length else target_length
|
||||
|
||||
# A list of the frequency of each token
|
||||
frequency_list = np.array([[0 for _ in range(vocab_size)]])
|
||||
pad_length = seq_length - origin_inputs.shape[-1]
|
||||
# Pad original inputs to seq_length
|
||||
input_ids = np.pad(origin_inputs, ((0, 0), (0, pad_length)), 'constant', constant_values=(0, 0))
|
||||
|
||||
# Indicate the exact token position
|
||||
current_index = valid_length - 1 if valid_length - 1 > 0 else 0
|
||||
current_index = np.array([current_index], np.int32)
|
||||
batch_valid_length = np.array([current_index], np.int32)
|
||||
# For first graph, not_init should be false
|
||||
init_true = True
|
||||
init_false = False
|
||||
init = init_false
|
||||
# Call a single inference with input size of (bs, seq_length)
|
||||
logits = sub0_servable.run(np.array(input_ids, np.int32), current_index, init, batch_valid_length)
|
||||
|
||||
# Claim the second graph and set not_init to true
|
||||
init = init_true
|
||||
|
||||
# A single loop generates one token, loop until reaching target seq_length or generating eod token
|
||||
while valid_length < target_length:
|
||||
# Reshape the output logits
|
||||
log_probs = logits.reshape(1, vocab_size)
|
||||
|
||||
# Get the revised log_probs considering frequency and presence penalties to eliminate duplicates in the generated results
|
||||
log_probs = log_probs.reshape(1, vocab_size)
|
||||
log_probs_revised = log_probs - frequency_list * frequency_penalty - (frequency_list > 0) * presence_penalty
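# Illustration with made-up numbers (vocab_size = 3, frequency_penalty = 1.5,
# presence_penalty = 0.3):
#   log_probs      = [[ 0.0, -1.0, -2.0]]
#   frequency_list = [[ 2,    0,    1  ]]   (token 0 generated twice, token 2 once)
#   revised        = [[-3.3, -1.0, -3.8]]   (already-generated tokens are pushed down)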
|
||||
|
||||
# Convert the log_probs to probability
|
||||
logits = np.power(10, np.array(log_probs_revised, np.float32))
|
||||
|
||||
# If top_p is less than 1.0, use top_p sampling
|
||||
if top_p < 1.0:
|
||||
# Only consider the 5000 largest logits to reduce computation
|
||||
sorted_logits, index = topk_fun(logits, 5000)
|
||||
cumsum_logits = np.cumsum(sorted_logits, 1)
|
||||
cumsum_logits = cumsum_logits[0]
|
||||
index = index[0]
|
||||
sorted_logits = sorted_logits[0]
|
||||
top_p_num = sum(cumsum_logits > top_p)
|
||||
# In case the distribution is smooth, the sum of the 5000 largest probabilities may not be large enough
|
||||
if top_p_num == 0:
|
||||
top_p_num = 5000
|
||||
# Get the corresponding probs and indices
|
||||
probs = sorted_logits[:top_p_num]
|
||||
p_args = index[:top_p_num]
|
||||
p = probs / sum(probs)
|
||||
# if top_p is set to 1.0, use top_k sampling
|
||||
else:
|
||||
# Get the corresponding probs and indices
|
||||
probs, p_args = topk_fun(logits, top_k_num)
|
||||
probs = probs[0]
|
||||
p_args = p_args[0]
|
||||
# Avoid rounding error
|
||||
if sum(probs) == 0:
|
||||
probs = np.array([1 / top_k_num for _ in range(top_k_num)])
|
||||
p = probs / sum(probs)
|
||||
|
||||
# Random select a token as final output for this round
|
||||
target_index = np.random.choice(len(p), p=p)
|
||||
# Stop judgment
|
||||
if p_args[target_index] == end_token or valid_length == target_length - 1:
|
||||
break
|
||||
|
||||
# Update frequency list
|
||||
target = p_args[target_index]
|
||||
frequency_list[0][target] = frequency_list[0][target] + 1
|
||||
valid_length += 1
|
||||
|
||||
batch_valid_length = np.array([valid_length - 1], np.int32)
|
||||
current_index = np.array([0], np.int32)
|
||||
input_id = np.array([[target]], np.int32)
|
||||
# Update outputs with current generated token
|
||||
outputs.append(int(target))
|
||||
|
||||
# Call a single inference with input size of (bs, 1)
|
||||
logits = sub1_servable.run(input_id, current_index, init, batch_valid_length)
|
||||
# Return valid outputs out of padded outputs
|
||||
return outputs
|
|
@ -23,6 +23,7 @@ import sentencepiece as spm
|
|||
|
||||
class JIEBATokenizer:
|
||||
"""jieba tokenizer for encode and decode text"""
|
||||
|
||||
def __init__(self, vocab_file, model_file, max_len=None):
|
||||
self.max_len = max_len if max_len is not None else int(1e12)
|
||||
# self.encoder = json.load(open(vocab_file))
|
|
@ -15,15 +15,16 @@
|
|||
# ============================================================================
|
||||
"""Serving agents startup code, load and execute models of pangu alpha"""
|
||||
|
||||
from mindspore_serving.worker.distributed import agent_startup
|
||||
from mindspore_serving.server import distributed
|
||||
|
||||
|
||||
def start():
|
||||
"""Start agents to load and execute models of pangu alpha"""
|
||||
model_files = []
|
||||
for i in range(8):
|
||||
model_files.append(f"models/device_{i}/pangu_alpha_graph.mindir")
|
||||
agent_startup.startup_worker_agents(worker_ip="0.0.0.0", worker_port=6200, model_files=model_files)
|
||||
model_files.append([f"models/device{i}/pangu_alpha_1024_graph.mindir",
|
||||
f"models/device{i}/pangu_alpha_1_graph.mindir"])
|
||||
distributed.startup_agents(distributed_address="0.0.0.0:6200", model_files=model_files)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
|
@ -15,19 +15,18 @@
|
|||
"""Serving server start code, serve service, and manage all agents which load and execute models"""
|
||||
|
||||
import os
|
||||
from mindspore_serving import master
|
||||
from mindspore_serving.worker.distributed import distributed_worker as worker
|
||||
from mindspore_serving import server
|
||||
from mindspore_serving.server import distributed
|
||||
|
||||
|
||||
def start():
|
||||
"""Start server to serve service, and manage all agents which load and execute models"""
|
||||
servable_dir = os.path.abspath(".")
|
||||
|
||||
worker.start_distributed_servable_in_master(servable_dir, "pangu", rank_table_json_file="hccl_8p.json",
|
||||
version_number=1, worker_ip="0.0.0.0", worker_port=6200,
|
||||
wait_agents_time_in_seconds=0)
|
||||
distributed.start_servable(servable_dir, "pangu", rank_table_json_file="hccl_8p.json",
|
||||
distributed_address="0.0.0.0:6200")
|
||||
|
||||
master.start_grpc_server("127.0.0.1", 5500)
|
||||
server.start_grpc_server("127.0.0.1:5500")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
|
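Once serving_server.log shows "Serving: gRPC server start success, listening on 127.0.0.1:5500", the pangu servable can also be called over gRPC directly, without going through flask. A minimal sketch, mirroring the Client usage in the flask client earlier in this PR (the prompt string is a placeholder):

from mindspore_serving.client import Client

client = Client("localhost:5500", "pangu", "predict")
result = client.infer({"input_sentence": "<your prompt>"})
print(result[0]["output_sentence"])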
@ -13,9 +13,9 @@ start_serving_server()
|
|||
echo "serving server failed to start."
|
||||
fi
|
||||
|
||||
result=`grep -E 'Begin waiting ready of all agents' serving_server.log | wc -l`
|
||||
result=`grep -E 'Master server start success' serving_server.log | wc -l`
|
||||
count=0
|
||||
while [[ ${result} -ne 1 && ${count} -lt 100 ]]
|
||||
while [[ ${result} -eq 0 && ${count} -lt 100 ]]
|
||||
do
|
||||
sleep 1
|
||||
|
||||
|
@ -26,7 +26,7 @@ start_serving_server()
|
|||
fi
|
||||
|
||||
count=$(($count+1))
|
||||
result=`grep -E 'Begin waiting ready of all agents' serving_server.log | wc -l`
|
||||
result=`grep -E 'Master server start success' serving_server.log | wc -l`
|
||||
done
|
||||
|
||||
if [ ${count} -eq 100 ]
|
||||
|
@ -38,7 +38,7 @@ start_serving_server()
|
|||
|
||||
start_serving_agent()
|
||||
{
|
||||
echo "### start serving agent, see serving_agent.log for detail ###"
|
||||
echo "### start serving agent, see and serving_logs/log_pangu_distributed.log for detail ###"
|
||||
python3 serving_agent.py > serving_agent.log 2>&1 &
|
||||
if [ $? -ne 0 ]
|
||||
then
|
||||
|
@ -54,7 +54,7 @@ start_serving_agent()
|
|||
if [ $num -eq 0 ]
|
||||
then
|
||||
bash stop_pangu.sh
|
||||
echo "start serving agent failed, see log serving_agent.log for more detail" && exit 1
|
||||
echo "start serving agent failed, see log and serving_logs/log_pangu_distributed.log for more detail" && exit 1
|
||||
fi
|
||||
|
||||
count=$(($count+1))
|
||||
|
@ -64,7 +64,7 @@ start_serving_agent()
|
|||
if [ ${count} -eq 1800 ]
|
||||
then
|
||||
bash stop_pangu.sh
|
||||
echo "start serving agent failed, see log serving_agent.log for more detail" && exit 1
|
||||
echo "start serving agent failed, see log and serving_logs/log_pangu_distributed.log for more detail" && exit 1
|
||||
fi
|
||||
echo "### start serving agent end ###"
|
||||
}
|
||||
|
@ -104,8 +104,37 @@ start_flask()
|
|||
cat flask.log
|
||||
}
|
||||
|
||||
wait_serving_ready()
|
||||
{
|
||||
echo "### waiting serving server ready, see and serving_logs/log_pangu_distributed.log for detail ###"
|
||||
result=`grep -E 'gRPC server start success' serving_server.log | wc -l`
|
||||
count=0
|
||||
while [[ ${result} -eq 0 && ${count} -lt 100 ]]
|
||||
do
|
||||
sleep 1
|
||||
|
||||
num=`ps -ef | grep 'serving_server.py' | grep -v grep | wc -l`
|
||||
if [ $num -eq 0 ]
|
||||
then
|
||||
bash stop_pangu.sh
|
||||
echo "waiting serving server ready failed, see log serving_server.log and serving_logs/log_pangu_distributed.log for more detail" && exit 1
|
||||
fi
|
||||
|
||||
count=$(($count+1))
|
||||
result=`grep -E 'gRPC server start success' serving_server.log | wc -l`
|
||||
done
|
||||
|
||||
if [ ${count} -eq 100 ]
|
||||
then
|
||||
bash stop_pangu.sh
|
||||
echo "waiting serving server ready failed, see log serving_server.log and serving_logs/log_pangu_distributed.log for more detail" && exit 1
|
||||
fi
|
||||
echo "### waiting serving server ready end ###"
|
||||
}
|
||||
|
||||
bash stop_pangu.sh
|
||||
rm -f serving_server.log serving_agent.log flask.log
|
||||
rm -rf serving_server.log serving_agent.log flask.log serving_logs
|
||||
start_serving_server
|
||||
start_serving_agent
|
||||
start_flask
|
||||
wait_serving_ready
|
|
@ -11,6 +11,13 @@ kill_serving_9()
|
|||
echo "Send kill -9 msg to serving_server.py process"
|
||||
fi
|
||||
|
||||
num=`ps -ef | grep start_distributed_worker.py | grep -v grep | wc -l`
|
||||
if [ $num -ne 0 ]
|
||||
then
|
||||
ps aux | grep 'start_distributed_worker.py' | grep ${CURRUSER} | grep -v grep | awk '{print $2}' | xargs kill -9
|
||||
echo "Send kill -9 msg to start_distributed_worker.py process"
|
||||
fi
|
||||
|
||||
num=`ps -ef | grep serving_agent.py | grep -v grep | wc -l`
|
||||
if [ $num -ne 0 ]
|
||||
then
|
|
@ -22,6 +22,64 @@ import mindspore.common.dtype as mstype
|
|||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
def topk_fun(logits, topk=5):
|
||||
"""Get topk"""
|
||||
target_column = logits[0].tolist()
|
||||
sorted_array = [(k, v) for k, v in enumerate(target_column)]
|
||||
sorted_array.sort(key=lambda x: x[1], reverse=True)
|
||||
topk_array = sorted_array[:topk]
|
||||
index, value = zip(*topk_array)
|
||||
index = np.array([index])
|
||||
value = np.array([value])
|
||||
return value, index
|
||||
|
||||
def sampler(log_probs_revised, top_p, top_k_num, use_pynative=False):
|
||||
"""Convert the log_probs to probability"""
|
||||
if use_pynative:
|
||||
logits = P.Pow()(10, Tensor(log_probs_revised, mstype.float32))
|
||||
else:
|
||||
logits = np.power(10, np.array(log_probs_revised, np.float32))
|
||||
|
||||
# If top_p is less than 1.0, use top_p sampling
|
||||
if top_p < 1.0:
|
||||
# Only consider the 5000 largest logits to reduce computation
|
||||
if use_pynative:
|
||||
sorted_logits, index = P.TopK(sorted=True)(logits, 5000)
|
||||
cumsum_logits = P.CumSum()(sorted_logits, 1)
|
||||
cumsum_logits = cumsum_logits.asnumpy()
|
||||
index = index.asnumpy()
|
||||
sorted_logits = sorted_logits.asnumpy()
|
||||
else:
|
||||
sorted_logits, index = topk_fun(logits, 5000)
|
||||
cumsum_logits = np.cumsum(sorted_logits, 1)
|
||||
cumsum_logits = cumsum_logits[0]
|
||||
index = index[0]
|
||||
sorted_logits = sorted_logits[0]
|
||||
top_p_num = sum(cumsum_logits > top_p)
|
||||
# In case the distribution is smooth, the sum of the 5000 largest probabilities may not be large enough
|
||||
if top_p_num == 0:
|
||||
top_p_num = 5000
|
||||
# Get the corresponding probs and indices
|
||||
probs = sorted_logits[:top_p_num]
|
||||
p_args = index[:top_p_num]
|
||||
p = probs / sum(probs)
|
||||
# if top_p is set to 1.0, use top_k sampling
|
||||
else:
|
||||
# Get the corresponding probs and indices
|
||||
if use_pynative:
|
||||
probs, p_args = P.TopK(sorted=True)(logits, top_k_num)
|
||||
probs = probs.asnumpy()
|
||||
p_args = p_args.asnumpy()
|
||||
else:
|
||||
probs, p_args = topk_fun(logits, top_k_num)
|
||||
probs = probs[0]
|
||||
p_args = p_args[0]
|
||||
# Avoid rounding error
|
||||
if sum(probs) == 0:
|
||||
probs = np.array([1 / top_k_num for _ in range(top_k_num)])
|
||||
p = probs / sum(probs)
|
||||
return p, p_args
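# Usage sketch (shapes only): log_probs_revised has shape (1, vocab_size); with
# top_p = 1.0 and top_k_num = 3,
#     p, p_args = sampler(log_probs_revised, 1.0, 3, use_pynative=False)
# returns three normalized probabilities and their token ids, from which generate()
# and generate_increment() below draw one token via np.random.choice(len(p), p=p).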
|
||||
|
||||
def generate(model, origin_inputs, config):
|
||||
"""
|
||||
Text generation
|
||||
|
@ -42,6 +100,7 @@ def generate(model, origin_inputs, config):
|
|||
max_generate_length = config.max_generate_length
|
||||
seq_length = config.seq_length
|
||||
end_token = config.end_token
|
||||
use_pynative = config.use_pynative_op
|
||||
|
||||
_, valid_length = origin_inputs.shape
|
||||
# If target length exceeds seq_length, use seq_length instead
|
||||
|
@ -67,36 +126,7 @@ def generate(model, origin_inputs, config):
|
|||
log_probs = log_probs.asnumpy().reshape(1, config.vocab_size)
|
||||
log_probs_revised = log_probs - frequency_list * frequency_penalty - (frequency_list > 0) * presence_penalty
|
||||
|
||||
# Convert the log_probs to probability
|
||||
logits = P.Pow()(10, Tensor(log_probs_revised, mstype.float32))
|
||||
|
||||
# If top_p is less than 1.0, use top_p sampling
|
||||
if top_p < 1.0:
|
||||
# Only consider the 5000 largest logits to reduce computation
|
||||
sorted_logits, index = P.TopK(sorted=True)(logits, 5000)
|
||||
cumsum_logits = P.CumSum()(sorted_logits, 1)
|
||||
cumsum_logits = cumsum_logits.asnumpy()[0]
|
||||
index = index.asnumpy()[0]
|
||||
sorted_logits = sorted_logits.asnumpy()[0]
|
||||
top_p_num = sum(cumsum_logits > top_p)
|
||||
# In case the probability is smooth, the sum of 5000 largest probabilities are not large enough
|
||||
if top_p_num == 0:
|
||||
top_p_num = 5000
|
||||
# Get the corresponding probs and indices
|
||||
probs = sorted_logits[:top_p_num]
|
||||
p_args = index[:top_p_num]
|
||||
p = probs / sum(probs)
|
||||
# if top_p is set to 1.0, use top_k sampling
|
||||
else:
|
||||
# Get the corresponding probs and indices
|
||||
probs, p_args = P.TopK(sorted=True)(logits, top_k_num)
|
||||
probs = probs.asnumpy()[0]
|
||||
p_args = p_args.asnumpy()[0]
|
||||
# Avoid rounding error
|
||||
if sum(probs) == 0:
|
||||
probs = np.array([1 / top_k_num for _ in range(top_k_num)])
|
||||
p = probs / sum(probs)
|
||||
|
||||
p, p_args = sampler(log_probs_revised, top_p, top_k_num, use_pynative)
|
||||
# Random select a token as final output for this round
|
||||
target_index = np.random.choice(len(p), p=p)
|
||||
# Stop judgment
|
||||
|
@ -135,6 +165,7 @@ def generate_increment(model, origin_inputs, config):
|
|||
max_generate_length = config.max_generate_length
|
||||
seq_length = config.seq_length
|
||||
end_token = config.end_token
|
||||
use_pynative = config.use_pynative_op
|
||||
|
||||
_, valid_length = origin_inputs.shape
|
||||
# Init outputs with original inputs
|
||||
|
@ -161,7 +192,7 @@ def generate_increment(model, origin_inputs, config):
|
|||
# Claim the first graph
|
||||
model.predict_network.add_flags_recursive(is_first_iteration=True)
|
||||
# Call a single inference with input size of (bs, seq_length)
|
||||
logits = model.predict(Tensor(input_ids, mstype.int32), init, batch_valid_length, current_index)
|
||||
logits = model.predict(Tensor(input_ids, mstype.int32), current_index, init, batch_valid_length)
|
||||
|
||||
# Claim the second graph and set not_init to true
|
||||
init = init_true
|
||||
|
@ -171,42 +202,13 @@ def generate_increment(model, origin_inputs, config):
|
|||
while valid_length < target_length:
|
||||
# Reshape the output logits
|
||||
logits = logits.asnumpy()
|
||||
log_probs = logits.reshape(1, vocab_size)
|
||||
log_probs = logits.reshape(1, config.vocab_size)
|
||||
|
||||
# Get the revised log_probs considering frequency and presence penalties to eliminate duplicates in the generated results
|
||||
log_probs = log_probs.asnumpy().reshape(1, config.vocab_size)
|
||||
log_probs_revised = log_probs - frequency_list * frequency_penalty - (frequency_list > 0) * presence_penalty
|
||||
|
||||
# Convert the log_probs to probability
|
||||
logits = P.Pow()(10, Tensor(log_probs_revised, mstype.float32))
|
||||
|
||||
# If top_p is less than 1.0, use top_p sampling
|
||||
if top_p < 1.0:
|
||||
# Only consider the 5000 largest logits to reduce computation
|
||||
sorted_logits, index = P.TopK(sorted=True)(logits, 5000)
|
||||
cumsum_logits = P.CumSum()(sorted_logits, 1)
|
||||
cumsum_logits = cumsum_logits.asnumpy()[0]
|
||||
index = index.asnumpy()[0]
|
||||
sorted_logits = sorted_logits.asnumpy()[0]
|
||||
top_p_num = sum(cumsum_logits > top_p)
|
||||
# In case the probability is smooth, the sum of 5000 largest probabilities are not large enough
|
||||
if top_p_num == 0:
|
||||
top_p_num = 5000
|
||||
# Get the corresponding probs and indices
|
||||
probs = sorted_logits[:top_p_num]
|
||||
p_args = index[:top_p_num]
|
||||
p = probs / sum(probs)
|
||||
# if top_p is set to 1.0, use top_k sampling
|
||||
else:
|
||||
# Get the corresponding probs and indices
|
||||
probs, p_args = P.TopK(sorted=True)(logits, top_k_num)
|
||||
probs = probs.asnumpy()[0]
|
||||
p_args = p_args.asnumpy()[0]
|
||||
# Avoid rounding error
|
||||
if sum(probs) == 0:
|
||||
probs = np.array([1 / top_k_num for _ in range(top_k_num)])
|
||||
p = probs / sum(probs)
|
||||
|
||||
p, p_args = sampler(log_probs_revised, top_p, top_k_num, use_pynative)
|
||||
# Random select a token as final output for this round
|
||||
target_index = np.random.choice(len(p), p=p)
|
||||
# Stop judgment
|
||||
|
|
|
@ -255,6 +255,10 @@ def add_inference_params(opt):
|
|||
type=int,
|
||||
default=9,
|
||||
help="the token id for <end of document>")
|
||||
opt.add_argument("--use_pynative_op",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Whether use pynative op for postproecess")
|
||||
|
||||
|
||||
def add_training_params(opt):
|
||||
|
@ -409,6 +413,10 @@ def get_args(inference=False):
|
|||
type=int,
|
||||
default=1,
|
||||
help="Running on cloud of not. Default 1.")
|
||||
parser.add_argument("--export",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Whether export mindir for serving.")
|
||||
add_training_params(parser)
|
||||
if inference:
|
||||
add_inference_params(parser)
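These new arguments are consumed by predict.py above: passing --export=1 (as run_distribute_export.sh does) makes main() call export_mindir() after load_model(), while the default --export=0 runs run_predict(); --use_pynative_op=1 switches sampler() in src/generate.py from the numpy path to the P.TopK/P.CumSum/P.Pow path.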